1use std::time::Duration;
2
3use api::{
4 chat::{ChatRequest, ChatStreamRequest, ChatStreamResponse},
5 classify::{Classification, ClassifyRequest, ClassifyResponse},
6 detokenize::{DetokenizeRequest, DetokenizeResponse},
7 embed::{EmbedRequest, EmbedResponse},
8 generate::{GenerateRequest, GenerateResponse, Generation},
9 rerank::{ReRankRequest, ReRankResponse, ReRankResult},
10 tokenize::{TokenizeRequest, TokenizeResponse},
11};
12use reqwest::{header, ClientBuilder, StatusCode, Url};
13use tokio::sync::mpsc::{channel, Receiver};
14
/// Root of the public Cohere API; the version segment is appended separately.
const COHERE_API_BASE_URL: &str = "https://api.cohere.com";
/// API version segment joined onto [`COHERE_API_BASE_URL`] by `Cohere::default`.
const COHERE_API_V1: &str = "v1";
/// Total per-request timeout configured on the HTTP client (four minutes).
const COHERE_API_TIMEOUT: Duration = Duration::from_secs(4 * 60);
18
19use serde::{de::DeserializeOwned, Deserialize, Serialize};
20use thiserror::Error;
21
22pub mod api;
23
24#[derive(Error, Debug)]
25pub enum CohereApiError {
26 #[error("Unexpected request error")]
27 RequestError(#[from] reqwest::Error),
28 #[error("API request failed with status code `{0}` and error message `{1}`")]
29 ApiError(StatusCode, String),
30 #[error("API key is invalid")]
31 InvalidApiKey,
32 #[error("Unknown error")]
33 Unknown,
34}
35
36#[derive(Error, Debug)]
37pub enum CohereStreamError {
38 #[error("Unexpected deserialization error")]
39 RequestError(#[from] serde_json::error::Error),
40 #[error("Unknown error `{0}`")]
41 Unknown(String),
42}
43
44pub struct Cohere {
46 api_url: String,
47 client: reqwest::Client,
48}
49
50#[derive(Deserialize, Debug)]
51struct CohereCheckApiKeyResponse {
52 valid: bool,
53}
54
55#[derive(Deserialize, Debug)]
56struct CohereApiErrorResponse {
57 message: String,
58}
59
60impl Default for Cohere {
61 fn default() -> Self {
62 let api_key = std::env::var("COHERE_API_KEY")
63 .expect("please provide a Cohere API key with the 'COHERE_API_KEY' env variable");
64 Cohere::new(format!("{COHERE_API_BASE_URL}/{COHERE_API_V1}"), api_key)
65 }
66}
67
68impl Cohere {
69 pub fn new<U: Into<String>, K: Into<String>>(api_url: U, api_key: K) -> Self {
70 let api_url: String = api_url.into();
71 let api_key: String = api_key.into();
72
73 let mut headers = header::HeaderMap::new();
74
75 let mut authorization = header::HeaderValue::from_str(&format!("Bearer {api_key}"))
76 .expect("failed to construct authorization header!");
77 authorization.set_sensitive(true);
78 headers.insert(header::AUTHORIZATION, authorization);
79
80 headers.insert(
81 "Request-Source",
82 header::HeaderValue::from_static("rust-sdk"),
83 );
84
85 headers.insert(
86 header::ACCEPT,
87 header::HeaderValue::from_static("application/json"),
88 );
89 headers.insert(
90 header::CONTENT_TYPE,
91 header::HeaderValue::from_static("application/json"),
92 );
93
94 let client = ClientBuilder::new()
95 .default_headers(headers)
96 .use_rustls_tls()
97 .timeout(COHERE_API_TIMEOUT)
98 .build()
99 .expect("failed to initialize HTTP client!");
100
101 Cohere { api_url, client }
102 }
103
104 async fn request<Request: Serialize, Response: DeserializeOwned>(
105 &self,
106 route: &'static str,
107 payload: Request,
108 ) -> Result<Response, CohereApiError> {
109 let url =
110 Url::parse(&format!("{}/{route}", self.api_url)).expect("api url should be valid");
111
112 let response = self.client.post(url).json(&payload).send().await?;
113
114 if let Some(warning) = response.headers().get("X-API-Warning") {
116 eprintln!("Warning: {:?}", String::from_utf8_lossy(warning.as_bytes()));
117 }
118
119 if response.status().is_client_error() || response.status().is_server_error() {
120 Err(self.parse_error(response).await)
121 } else {
122 Ok(response.json::<Response>().await?)
123 }
124 }
125
126 async fn request_stream<Request: Serialize, Response: DeserializeOwned + Send + 'static>(
127 &self,
128 route: &'static str,
129 payload: Request,
130 ) -> Result<Receiver<Result<Response, CohereStreamError>>, CohereApiError> {
131 let url =
132 Url::parse(&format!("{}/{route}", self.api_url)).expect("api url should be valid");
133
134 let mut response = self.client.post(url).json(&payload).send().await?;
135
136 if response.status().is_client_error() || response.status().is_server_error() {
137 return Err(self.parse_error(response).await);
138 }
139
140 let (tx, rx) = channel::<Result<Response, CohereStreamError>>(32);
141 tokio::spawn(async move {
142 let mut buf = bytes::BytesMut::with_capacity(1024);
143 while let Ok(Some(chunk)) = response.chunk().await {
144 if chunk.is_empty() {
145 break;
146 }
147 buf.extend_from_slice(&chunk);
148 if !chunk.ends_with(b"\n") {
149 continue;
150 }
151 match serde_json::from_slice::<Response>(&buf) {
152 Ok(v) => tx
153 .send(Ok(v))
154 .await
155 .expect("Failed to send message to channel"),
156 Err(e) => tx
157 .send(Err(CohereStreamError::from(e)))
158 .await
159 .expect("Failed to send error to channel"),
160 }
161 buf.clear()
162 }
163 });
164
165 Ok(rx)
166 }
167
168 async fn parse_error(&self, response: reqwest::Response) -> CohereApiError {
169 let status = response.status();
170 let text = response.text().await;
171 match text {
172 Err(_) => CohereApiError::Unknown,
173 Ok(text) => CohereApiError::ApiError(
174 status,
175 serde_json::from_str::<CohereApiErrorResponse>(&text)
176 .unwrap_or(CohereApiErrorResponse {
177 message: format!("Unknown API Error: {}", text),
178 })
179 .message,
180 ),
181 }
182 }
183
184 pub async fn check_api_key(&self) -> Result<(), CohereApiError> {
186 let response = self
187 .request::<(), CohereCheckApiKeyResponse>("check-api-key", ())
188 .await?;
189
190 match response.valid {
191 true => Ok(()),
192 false => Err(CohereApiError::InvalidApiKey),
193 }
194 }
195
196 pub async fn generate<'input>(
198 &self,
199 request: &GenerateRequest<'input>,
200 ) -> Result<Vec<Generation>, CohereApiError> {
201 let response = self
202 .request::<_, GenerateResponse>("generate", request)
203 .await?;
204
205 Ok(response.generations)
206 }
207
208 pub async fn chat<'input>(
210 &self,
211 request: &ChatRequest<'input>,
212 ) -> Result<Receiver<Result<ChatStreamResponse, CohereStreamError>>, CohereApiError> {
213 let stream_request = ChatStreamRequest {
214 request,
215 stream: true,
216 };
217 let response = self
218 .request_stream::<_, ChatStreamResponse>("chat", stream_request)
219 .await?;
220
221 Ok(response)
222 }
223
224 pub async fn embed<'input>(
228 &self,
229 request: &EmbedRequest<'input>,
230 ) -> Result<Vec<Vec<f64>>, CohereApiError> {
231 let response = self.request::<_, EmbedResponse>("embed", request).await?;
232
233 Ok(response.embeddings)
234 }
235
236 pub async fn classify<'input>(
239 &self,
240 request: &ClassifyRequest<'input>,
241 ) -> Result<Vec<Classification>, CohereApiError> {
242 let response = self
243 .request::<_, ClassifyResponse>("classify", request)
244 .await?;
245
246 Ok(response.classifications)
247 }
248
249 pub async fn tokenize<'input>(
251 &self,
252 request: &TokenizeRequest<'input>,
253 ) -> Result<TokenizeResponse, CohereApiError> {
254 let response = self.request("tokenize", request).await?;
255
256 Ok(response)
257 }
258
259 pub async fn detokenize<'input>(
261 &self,
262 request: &DetokenizeRequest<'input>,
263 ) -> Result<String, CohereApiError> {
264 let response = self
265 .request::<_, DetokenizeResponse>("detokenize", request)
266 .await?;
267
268 Ok(response.text)
269 }
270
271 pub async fn rerank<'input>(
273 &self,
274 request: &ReRankRequest<'input>,
275 ) -> Result<Vec<ReRankResult>, CohereApiError> {
276 let response = self.request::<_, ReRankResponse>("rerank", request).await?;
277
278 Ok(response.results)
279 }
280}