1use crate::{
4 EmbeddingProvider, EmbeddingRequest, EmbeddingResponse, EmbeddingUsage, LlmChunk, LlmError,
5 LlmProvider, LlmRequest, LlmResponse, LlmStream, Result, StreamUsage, StreamingLlmProvider,
6 Usage,
7};
8use async_trait::async_trait;
9use futures::stream::StreamExt;
10use serde::{Deserialize, Serialize};
11use std::time::Duration;
12
/// Provider for the Cohere HTTP API, backing chat completion, streaming
/// completion, and embedding requests.
pub struct CohereProvider {
    /// API key sent as a bearer token in the `Authorization` header.
    api_key: String,
    /// Model id sent with each request (also the embedding fallback model).
    model: String,
    /// HTTP client, constructed once in `new` and reused for all requests.
    client: reqwest::Client,
    /// API root; defaults to `https://api.cohere.ai/v1`, overridable via
    /// `with_base_url` (e.g. for proxies or test servers).
    base_url: String,
}
20
/// JSON request body for Cohere's non-streaming `/chat` endpoint.
#[derive(Serialize)]
struct CohereRequest {
    /// User prompt, copied from `LlmRequest::prompt`.
    message: String,
    /// Cohere model id.
    model: String,
    /// Optional system prompt; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    preamble: Option<String>,
    /// Optional sampling temperature; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    /// Optional cap on generated tokens; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
}
32
/// JSON response body from Cohere's non-streaming `/chat` endpoint.
#[derive(Deserialize)]
struct CohereResponse {
    /// The generated completion text.
    text: String,
    /// Optional metadata (token accounting); absent in some responses.
    meta: Option<CohereMeta>,
}
38
/// Metadata envelope on a non-streaming chat response.
#[derive(Deserialize)]
struct CohereMeta {
    /// Token counts the API billed for, when reported.
    billed_units: Option<CohereBilledUnits>,
}
43
/// Billed token counts for a chat request; either field may be absent,
/// in which case callers treat it as zero.
#[derive(Deserialize)]
struct CohereBilledUnits {
    input_tokens: Option<u32>,
    output_tokens: Option<u32>,
}
49
50impl CohereProvider {
51 pub fn new(api_key: String, model: String) -> Self {
53 Self {
54 api_key,
55 model,
56 client: reqwest::Client::new(),
57 base_url: "https://api.cohere.ai/v1".to_string(),
58 }
59 }
60
61 pub fn for_embeddings(api_key: String) -> Self {
63 Self::new(api_key, "embed-english-v3.0".to_string())
64 }
65
66 pub fn with_base_url(mut self, base_url: String) -> Self {
68 self.base_url = base_url;
69 self
70 }
71}
72
73#[async_trait]
74impl LlmProvider for CohereProvider {
75 async fn complete(&self, request: LlmRequest) -> Result<LlmResponse> {
76 let cohere_request = CohereRequest {
77 message: request.prompt.clone(),
78 model: self.model.clone(),
79 preamble: request.system_prompt,
80 temperature: request.temperature,
81 max_tokens: request.max_tokens,
82 };
83
84 let response = self
85 .client
86 .post(format!("{}/chat", self.base_url))
87 .header("Authorization", format!("Bearer {}", self.api_key))
88 .header("Content-Type", "application/json")
89 .json(&cohere_request)
90 .send()
91 .await?;
92
93 let status = response.status();
94
95 if status == 429 {
96 let retry_after = response
98 .headers()
99 .get("retry-after")
100 .and_then(|v| v.to_str().ok())
101 .and_then(|s| s.parse::<u64>().ok())
102 .map(Duration::from_secs);
103
104 return Err(LlmError::RateLimited(retry_after));
105 }
106
107 let body = response.text().await?;
108
109 if !status.is_success() {
110 return Err(LlmError::ApiError(format!("HTTP {}: {}", status, body)));
111 }
112
113 let cohere_response: CohereResponse =
114 serde_json::from_str(&body).map_err(|e| LlmError::SerializationError(e.to_string()))?;
115
116 let usage = cohere_response.meta.and_then(|m| m.billed_units).map(|bu| {
117 let input = bu.input_tokens.unwrap_or(0);
118 let output = bu.output_tokens.unwrap_or(0);
119 Usage {
120 prompt_tokens: input,
121 completion_tokens: output,
122 total_tokens: input + output,
123 }
124 });
125
126 Ok(LlmResponse {
127 content: cohere_response.text,
128 model: self.model.clone(),
129 usage,
130 tool_calls: Vec::new(),
131 })
132 }
133}
134
/// JSON request body for Cohere's streaming `/chat` endpoint; identical to
/// the non-streaming request plus the `stream` flag.
#[derive(Serialize)]
struct CohereStreamRequest {
    /// User prompt, copied from `LlmRequest::prompt`.
    message: String,
    /// Cohere model id.
    model: String,
    /// Optional system prompt; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    preamble: Option<String>,
    /// Optional sampling temperature; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    /// Optional cap on generated tokens; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
    /// Always set to `true` to request a streamed response.
    stream: bool,
}
149
/// One JSON event from the streaming `/chat` response, discriminated by
/// its `event_type` field.
#[derive(Deserialize)]
#[serde(tag = "event_type")]
enum CohereStreamEvent {
    /// An incremental piece of generated text.
    #[serde(rename = "text-generation")]
    TextGeneration { text: String },
    /// Terminal event carrying final response metadata (token accounting).
    #[serde(rename = "stream-end")]
    StreamEnd { response: CohereStreamEndResponse },
    /// Any other event type; ignored by the stream consumer.
    #[serde(other)]
    Other,
}
160
/// Payload of the `stream-end` event.
#[derive(Deserialize)]
struct CohereStreamEndResponse {
    /// Optional metadata (token accounting) for the whole stream.
    meta: Option<CohereStreamMeta>,
}
165
/// Metadata envelope inside the `stream-end` payload.
#[derive(Deserialize)]
struct CohereStreamMeta {
    /// Token counts the API billed for, when reported.
    billed_units: Option<CohereBilledUnits>,
}
170
171#[async_trait]
172impl StreamingLlmProvider for CohereProvider {
173 async fn complete_stream(&self, request: LlmRequest) -> Result<LlmStream> {
174 let cohere_request = CohereStreamRequest {
175 message: request.prompt.clone(),
176 model: self.model.clone(),
177 preamble: request.system_prompt,
178 temperature: request.temperature,
179 max_tokens: request.max_tokens,
180 stream: true,
181 };
182
183 let response = self
184 .client
185 .post(format!("{}/chat", self.base_url))
186 .header("Authorization", format!("Bearer {}", self.api_key))
187 .header("Content-Type", "application/json")
188 .json(&cohere_request)
189 .send()
190 .await?;
191
192 let status = response.status();
193
194 if status == 429 {
195 let retry_after = response
197 .headers()
198 .get("retry-after")
199 .and_then(|v| v.to_str().ok())
200 .and_then(|s| s.parse::<u64>().ok())
201 .map(Duration::from_secs);
202
203 return Err(LlmError::RateLimited(retry_after));
204 }
205
206 if !status.is_success() {
207 let body = response.text().await?;
208 return Err(LlmError::ApiError(format!("HTTP {}: {}", status, body)));
209 }
210
211 let stream = response.bytes_stream();
212 let model_name = self.model.clone();
213
214 let parsed_stream = stream.filter_map(move |chunk_result| {
215 let model_name = model_name.clone();
216 async move {
217 match chunk_result {
218 Ok(bytes) => {
219 let text = String::from_utf8_lossy(&bytes);
220 for line in text.lines() {
221 if line.trim().is_empty() {
222 continue;
223 }
224
225 if let Ok(event) = serde_json::from_str::<CohereStreamEvent>(line) {
226 match event {
227 CohereStreamEvent::TextGeneration { text } => {
228 return Some(Ok(LlmChunk {
229 content: text,
230 done: false,
231 model: None,
232 usage: None,
233 }));
234 }
235 CohereStreamEvent::StreamEnd { response } => {
236 let usage =
237 response.meta.and_then(|m| m.billed_units).map(|bu| {
238 let input = bu.input_tokens.unwrap_or(0);
239 let output = bu.output_tokens.unwrap_or(0);
240 StreamUsage {
241 prompt_tokens: Some(input),
242 completion_tokens: Some(output),
243 total_tokens: Some(input + output),
244 }
245 });
246
247 return Some(Ok(LlmChunk {
248 content: String::new(),
249 done: true,
250 model: Some(model_name),
251 usage,
252 }));
253 }
254 CohereStreamEvent::Other => {}
255 }
256 }
257 }
258 None
259 }
260 Err(e) => Some(Err(LlmError::NetworkError(e))),
261 }
262 }
263 });
264
265 Ok(Box::pin(parsed_stream))
266 }
267}
268
/// JSON request body for Cohere's `/embed` endpoint.
#[derive(Serialize)]
struct CohereEmbeddingRequest {
    /// Texts to embed in a single batch.
    texts: Vec<String>,
    /// Embedding model id.
    model: String,
    /// Cohere input-type hint; this provider always sends "search_document".
    input_type: String,
}
277
/// JSON response body from Cohere's `/embed` endpoint.
#[derive(Deserialize)]
struct CohereEmbeddingResponse {
    /// One embedding vector per input text, in request order.
    embeddings: Vec<Vec<f32>>,
    /// Optional metadata (token accounting).
    meta: Option<CohereEmbeddingMeta>,
}
283
/// Metadata envelope on an embedding response.
#[derive(Deserialize)]
struct CohereEmbeddingMeta {
    /// Token counts the API billed for, when reported.
    billed_units: Option<CohereEmbeddingBilledUnits>,
}
288
/// Billed token count for an embedding request; embeddings only consume
/// input tokens. A missing count is treated as zero by the caller.
#[derive(Deserialize)]
struct CohereEmbeddingBilledUnits {
    input_tokens: Option<u32>,
}
293
294#[async_trait]
295impl EmbeddingProvider for CohereProvider {
296 async fn embed(&self, request: EmbeddingRequest) -> Result<EmbeddingResponse> {
297 let model = request.model.unwrap_or_else(|| self.model.clone());
298
299 let cohere_request = CohereEmbeddingRequest {
300 texts: request.texts,
301 model: model.clone(),
302 input_type: "search_document".to_string(),
303 };
304
305 let response = self
306 .client
307 .post(format!("{}/embed", self.base_url))
308 .header("Authorization", format!("Bearer {}", self.api_key))
309 .header("Content-Type", "application/json")
310 .json(&cohere_request)
311 .send()
312 .await?;
313
314 let status = response.status();
315
316 if status == 429 {
317 let retry_after = response
319 .headers()
320 .get("retry-after")
321 .and_then(|v| v.to_str().ok())
322 .and_then(|s| s.parse::<u64>().ok())
323 .map(Duration::from_secs);
324
325 return Err(LlmError::RateLimited(retry_after));
326 }
327
328 let body = response.text().await?;
329
330 if !status.is_success() {
331 return Err(LlmError::ApiError(format!("HTTP {}: {}", status, body)));
332 }
333
334 let cohere_response: CohereEmbeddingResponse =
335 serde_json::from_str(&body).map_err(|e| LlmError::SerializationError(e.to_string()))?;
336
337 let usage = cohere_response.meta.and_then(|m| m.billed_units).map(|bu| {
338 let input = bu.input_tokens.unwrap_or(0);
339 EmbeddingUsage {
340 prompt_tokens: input,
341 total_tokens: input,
342 }
343 });
344
345 Ok(EmbeddingResponse {
346 embeddings: cohere_response.embeddings,
347 model,
348 usage,
349 })
350 }
351}