cohere_rust/
lib.rs

1use std::time::Duration;
2
3use api::{
4    chat::{ChatRequest, ChatStreamRequest, ChatStreamResponse},
5    classify::{Classification, ClassifyRequest, ClassifyResponse},
6    detokenize::{DetokenizeRequest, DetokenizeResponse},
7    embed::{EmbedRequest, EmbedResponse},
8    generate::{GenerateRequest, GenerateResponse, Generation},
9    rerank::{ReRankRequest, ReRankResponse, ReRankResult},
10    tokenize::{TokenizeRequest, TokenizeResponse},
11};
12use reqwest::{header, ClientBuilder, StatusCode, Url};
13use tokio::sync::mpsc::{channel, Receiver};
14
/// Scheme and host of the hosted Cohere API.
const COHERE_API_BASE_URL: &str = "https://api.cohere.com";
/// API version path segment appended to the base URL by `Cohere::default`.
const COHERE_API_V1: &str = "v1";
/// Timeout (four minutes) configured on the HTTP client built in `Cohere::new`.
const COHERE_API_TIMEOUT: Duration = Duration::from_secs(4 * 60);
18
19use serde::{de::DeserializeOwned, Deserialize, Serialize};
20use thiserror::Error;
21
22pub mod api;
23
24#[derive(Error, Debug)]
25pub enum CohereApiError {
26    #[error("Unexpected request error")]
27    RequestError(#[from] reqwest::Error),
28    #[error("API request failed with status code `{0}` and error message `{1}`")]
29    ApiError(StatusCode, String),
30    #[error("API key is invalid")]
31    InvalidApiKey,
32    #[error("Unknown error")]
33    Unknown,
34}
35
36#[derive(Error, Debug)]
37pub enum CohereStreamError {
38    #[error("Unexpected deserialization error")]
39    RequestError(#[from] serde_json::error::Error),
40    #[error("Unknown error `{0}`")]
41    Unknown(String),
42}
43
44/// Cohere Rust SDK to build natural language understanding and generation into your product with a few lines of code.
45pub struct Cohere {
46    api_url: String,
47    client: reqwest::Client,
48}
49
50#[derive(Deserialize, Debug)]
51struct CohereCheckApiKeyResponse {
52    valid: bool,
53}
54
55#[derive(Deserialize, Debug)]
56struct CohereApiErrorResponse {
57    message: String,
58}
59
60impl Default for Cohere {
61    fn default() -> Self {
62        let api_key = std::env::var("COHERE_API_KEY")
63            .expect("please provide a Cohere API key with the 'COHERE_API_KEY' env variable");
64        Cohere::new(format!("{COHERE_API_BASE_URL}/{COHERE_API_V1}"), api_key)
65    }
66}
67
68impl Cohere {
69    pub fn new<U: Into<String>, K: Into<String>>(api_url: U, api_key: K) -> Self {
70        let api_url: String = api_url.into();
71        let api_key: String = api_key.into();
72
73        let mut headers = header::HeaderMap::new();
74
75        let mut authorization = header::HeaderValue::from_str(&format!("Bearer {api_key}"))
76            .expect("failed to construct authorization header!");
77        authorization.set_sensitive(true);
78        headers.insert(header::AUTHORIZATION, authorization);
79
80        headers.insert(
81            "Request-Source",
82            header::HeaderValue::from_static("rust-sdk"),
83        );
84
85        headers.insert(
86            header::ACCEPT,
87            header::HeaderValue::from_static("application/json"),
88        );
89        headers.insert(
90            header::CONTENT_TYPE,
91            header::HeaderValue::from_static("application/json"),
92        );
93
94        let client = ClientBuilder::new()
95            .default_headers(headers)
96            .use_rustls_tls()
97            .timeout(COHERE_API_TIMEOUT)
98            .build()
99            .expect("failed to initialize HTTP client!");
100
101        Cohere { api_url, client }
102    }
103
104    async fn request<Request: Serialize, Response: DeserializeOwned>(
105        &self,
106        route: &'static str,
107        payload: Request,
108    ) -> Result<Response, CohereApiError> {
109        let url =
110            Url::parse(&format!("{}/{route}", self.api_url)).expect("api url should be valid");
111
112        let response = self.client.post(url).json(&payload).send().await?;
113
114        // Check for any API Warnings
115        if let Some(warning) = response.headers().get("X-API-Warning") {
116            eprintln!("Warning: {:?}", String::from_utf8_lossy(warning.as_bytes()));
117        }
118
119        if response.status().is_client_error() || response.status().is_server_error() {
120            Err(self.parse_error(response).await)
121        } else {
122            Ok(response.json::<Response>().await?)
123        }
124    }
125
126    async fn request_stream<Request: Serialize, Response: DeserializeOwned + Send + 'static>(
127        &self,
128        route: &'static str,
129        payload: Request,
130    ) -> Result<Receiver<Result<Response, CohereStreamError>>, CohereApiError> {
131        let url =
132            Url::parse(&format!("{}/{route}", self.api_url)).expect("api url should be valid");
133
134        let mut response = self.client.post(url).json(&payload).send().await?;
135
136        if response.status().is_client_error() || response.status().is_server_error() {
137            return Err(self.parse_error(response).await);
138        }
139
140        let (tx, rx) = channel::<Result<Response, CohereStreamError>>(32);
141        tokio::spawn(async move {
142            let mut buf = bytes::BytesMut::with_capacity(1024);
143            while let Ok(Some(chunk)) = response.chunk().await {
144                if chunk.is_empty() {
145                    break;
146                }
147                buf.extend_from_slice(&chunk);
148                if !chunk.ends_with(b"\n") {
149                    continue;
150                }
151                match serde_json::from_slice::<Response>(&buf) {
152                    Ok(v) => tx
153                        .send(Ok(v))
154                        .await
155                        .expect("Failed to send message to channel"),
156                    Err(e) => tx
157                        .send(Err(CohereStreamError::from(e)))
158                        .await
159                        .expect("Failed to send error to channel"),
160                }
161                buf.clear()
162            }
163        });
164
165        Ok(rx)
166    }
167
168    async fn parse_error(&self, response: reqwest::Response) -> CohereApiError {
169        let status = response.status();
170        let text = response.text().await;
171        match text {
172            Err(_) => CohereApiError::Unknown,
173            Ok(text) => CohereApiError::ApiError(
174                status,
175                serde_json::from_str::<CohereApiErrorResponse>(&text)
176                    .unwrap_or(CohereApiErrorResponse {
177                        message: format!("Unknown API Error: {}", text),
178                    })
179                    .message,
180            ),
181        }
182    }
183
184    /// Verify that the Cohere API key being used is valid
185    pub async fn check_api_key(&self) -> Result<(), CohereApiError> {
186        let response = self
187            .request::<(), CohereCheckApiKeyResponse>("check-api-key", ())
188            .await?;
189
190        match response.valid {
191            true => Ok(()),
192            false => Err(CohereApiError::InvalidApiKey),
193        }
194    }
195
196    /// Generates realistic text conditioned on a given input.
197    pub async fn generate<'input>(
198        &self,
199        request: &GenerateRequest<'input>,
200    ) -> Result<Vec<Generation>, CohereApiError> {
201        let response = self
202            .request::<_, GenerateResponse>("generate", request)
203            .await?;
204
205        Ok(response.generations)
206    }
207
208    /// Chat with Cohere's LLM
209    pub async fn chat<'input>(
210        &self,
211        request: &ChatRequest<'input>,
212    ) -> Result<Receiver<Result<ChatStreamResponse, CohereStreamError>>, CohereApiError> {
213        let stream_request = ChatStreamRequest {
214            request,
215            stream: true,
216        };
217        let response = self
218            .request_stream::<_, ChatStreamResponse>("chat", stream_request)
219            .await?;
220
221        Ok(response)
222    }
223
224    /// Returns text embeddings.
225    /// An embedding is a list of floating point numbers that captures semantic information about the text that it represents.
226    /// Embeddings can be used to create text classifiers as well as empower semantic search.
227    pub async fn embed<'input>(
228        &self,
229        request: &EmbedRequest<'input>,
230    ) -> Result<Vec<Vec<f64>>, CohereApiError> {
231        let response = self.request::<_, EmbedResponse>("embed", request).await?;
232
233        Ok(response.embeddings)
234    }
235
236    /// Makes a prediction about which label fits the specified text inputs best.
237    /// To make a prediction, classify uses the provided examples of text + label pairs as a reference.
238    pub async fn classify<'input>(
239        &self,
240        request: &ClassifyRequest<'input>,
241    ) -> Result<Vec<Classification>, CohereApiError> {
242        let response = self
243            .request::<_, ClassifyResponse>("classify", request)
244            .await?;
245
246        Ok(response.classifications)
247    }
248
249    /// Splits input text into smaller units called tokens using byte-pair encoding (BPE).
250    pub async fn tokenize<'input>(
251        &self,
252        request: &TokenizeRequest<'input>,
253    ) -> Result<TokenizeResponse, CohereApiError> {
254        let response = self.request("tokenize", request).await?;
255
256        Ok(response)
257    }
258
259    /// Takes tokens using byte-pair encoding and returns their text representation.
260    pub async fn detokenize<'input>(
261        &self,
262        request: &DetokenizeRequest<'input>,
263    ) -> Result<String, CohereApiError> {
264        let response = self
265            .request::<_, DetokenizeResponse>("detokenize", request)
266            .await?;
267
268        Ok(response.text)
269    }
270
271    /// Takes a query plus an list of texts and return an ordered array with each text assigned a relevance score.
272    pub async fn rerank<'input>(
273        &self,
274        request: &ReRankRequest<'input>,
275    ) -> Result<Vec<ReRankResult>, CohereApiError> {
276        let response = self.request::<_, ReRankResponse>("rerank", request).await?;
277
278        Ok(response.results)
279    }
280}