oxify_connect_llm/
cohere.rs

//! Cohere LLM provider: chat completions, streaming, and embeddings over
//! Cohere's v1 REST API.

use crate::{
    EmbeddingProvider, EmbeddingRequest, EmbeddingResponse, EmbeddingUsage, LlmChunk, LlmError,
    LlmProvider, LlmRequest, LlmResponse, LlmStream, Result, StreamUsage, StreamingLlmProvider,
    Usage,
};
use async_trait::async_trait;
use futures::stream::StreamExt;
use serde::{Deserialize, Serialize};
use std::time::Duration;

/// Cohere provider
pub struct CohereProvider {
    api_key: String,
    model: String,
    client: reqwest::Client,
    base_url: String,
}

/// Request body for Cohere's `/chat` endpoint.
#[derive(Serialize)]
struct CohereRequest {
    message: String,
    model: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    preamble: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
}

/// Response body from Cohere's `/chat` endpoint.
#[derive(Deserialize)]
struct CohereResponse {
    text: String,
    meta: Option<CohereMeta>,
}

/// Response metadata; Cohere reports token usage under `meta.billed_units`.
#[derive(Deserialize)]
struct CohereMeta {
    billed_units: Option<CohereBilledUnits>,
}

/// Token counts as billed by Cohere.
#[derive(Deserialize)]
struct CohereBilledUnits {
    input_tokens: Option<u32>,
    output_tokens: Option<u32>,
}
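
// Illustrative wire shape for the structs above (a sketch inferred from the
// serde derives, not captured from a live response):
//
//     {"text":"...","meta":{"billed_units":{"input_tokens":10,"output_tokens":20}}}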

impl CohereProvider {
    /// Create a new Cohere provider
    pub fn new(api_key: String, model: String) -> Self {
        Self {
            api_key,
            model,
            client: reqwest::Client::new(),
            base_url: "https://api.cohere.ai/v1".to_string(),
        }
    }

    /// Create a provider preconfigured for embeddings, defaulting to the
    /// `embed-english-v3.0` model.
    pub fn for_embeddings(api_key: String) -> Self {
        Self::new(api_key, "embed-english-v3.0".to_string())
    }

    /// Set a custom base URL (e.g. for a proxy or test server)
    pub fn with_base_url(mut self, base_url: String) -> Self {
        self.base_url = base_url;
        self
    }
}
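
// Illustrative construction (a sketch, not part of the provider API; the
// model name below is an assumption, substitute any Cohere chat model):
//
//     let provider = CohereProvider::new(
//         std::env::var("COHERE_API_KEY").expect("COHERE_API_KEY not set"),
//         "command-r".to_string(),
//     )
//     .with_base_url("https://api.cohere.ai/v1".to_string());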

#[async_trait]
impl LlmProvider for CohereProvider {
    async fn complete(&self, request: LlmRequest) -> Result<LlmResponse> {
        let cohere_request = CohereRequest {
            message: request.prompt.clone(),
            model: self.model.clone(),
            preamble: request.system_prompt,
            temperature: request.temperature,
            max_tokens: request.max_tokens,
        };

        let response = self
            .client
            .post(format!("{}/chat", self.base_url))
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&cohere_request)
            .send()
            .await?;

        let status = response.status();

        if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
            // Extract the Retry-After header (seconds form) if present
            let retry_after = response
                .headers()
                .get("retry-after")
                .and_then(|v| v.to_str().ok())
                .and_then(|s| s.parse::<u64>().ok())
                .map(Duration::from_secs);

            return Err(LlmError::RateLimited(retry_after));
        }

        let body = response.text().await?;

        if !status.is_success() {
            return Err(LlmError::ApiError(format!("HTTP {}: {}", status, body)));
        }

        let cohere_response: CohereResponse =
            serde_json::from_str(&body).map_err(|e| LlmError::SerializationError(e.to_string()))?;

        // Cohere reports usage under meta.billed_units rather than a
        // top-level usage object, so map it into the crate's Usage type.
        let usage = cohere_response.meta.and_then(|m| m.billed_units).map(|bu| {
            let input = bu.input_tokens.unwrap_or(0);
            let output = bu.output_tokens.unwrap_or(0);
            Usage {
                prompt_tokens: input,
                completion_tokens: output,
                total_tokens: input + output,
            }
        });

        Ok(LlmResponse {
            content: cohere_response.text,
            model: self.model.clone(),
            usage,
            tool_calls: Vec::new(),
        })
    }
}
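
// Sketch of a non-streaming call (assumes a Tokio runtime and that
// `LlmRequest` exposes at least the fields read above; any other fields
// would need their own defaults):
//
//     let request = LlmRequest {
//         prompt: "Summarize Rust's ownership model.".to_string(),
//         system_prompt: None,
//         temperature: Some(0.3),
//         max_tokens: Some(256),
//         ..Default::default() // assumes LlmRequest: Default
//     };
//     let response = provider.complete(request).await?;
//     println!("{}", response.content);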

// ===== Cohere Streaming Implementation =====

/// Request body for Cohere's streaming `/chat` endpoint.
#[derive(Serialize)]
struct CohereStreamRequest {
    message: String,
    model: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    preamble: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
    stream: bool,
}

/// A single newline-delimited JSON event from the stream, discriminated by
/// its `event_type` field. Unknown event types fall through to `Other`.
#[derive(Deserialize)]
#[serde(tag = "event_type")]
enum CohereStreamEvent {
    #[serde(rename = "text-generation")]
    TextGeneration { text: String },
    #[serde(rename = "stream-end")]
    StreamEnd { response: CohereStreamEndResponse },
    #[serde(other)]
    Other,
}
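
// Wire-format lines these variants match (shapes inferred from the serde
// attributes above; values are illustrative):
//
//     {"event_type":"text-generation","text":"Hello"}
//     {"event_type":"stream-end","response":{"meta":{"billed_units":{"input_tokens":4,"output_tokens":12}}}}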

/// Payload of the `stream-end` event, carrying final usage metadata.
#[derive(Deserialize)]
struct CohereStreamEndResponse {
    meta: Option<CohereStreamMeta>,
}

#[derive(Deserialize)]
struct CohereStreamMeta {
    billed_units: Option<CohereBilledUnits>,
}

#[async_trait]
impl StreamingLlmProvider for CohereProvider {
    async fn complete_stream(&self, request: LlmRequest) -> Result<LlmStream> {
        let cohere_request = CohereStreamRequest {
            message: request.prompt.clone(),
            model: self.model.clone(),
            preamble: request.system_prompt,
            temperature: request.temperature,
            max_tokens: request.max_tokens,
            stream: true,
        };

        let response = self
            .client
            .post(format!("{}/chat", self.base_url))
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&cohere_request)
            .send()
            .await?;

        let status = response.status();

        if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
            // Extract the Retry-After header (seconds form) if present
            let retry_after = response
                .headers()
                .get("retry-after")
                .and_then(|v| v.to_str().ok())
                .and_then(|s| s.parse::<u64>().ok())
                .map(Duration::from_secs);

            return Err(LlmError::RateLimited(retry_after));
        }

        if !status.is_success() {
            let body = response.text().await?;
            return Err(LlmError::ApiError(format!("HTTP {}: {}", status, body)));
        }

        let stream = response.bytes_stream();
        let model_name = self.model.clone();

        // Cohere streams newline-delimited JSON events, and a single network
        // chunk may carry several of them, so parse every line in each chunk
        // and flatten the results into one stream. Note: this assumes event
        // lines are not split across chunk boundaries; a hardened
        // implementation would buffer partial lines.
        let parsed_stream = stream
            .map(move |chunk_result| {
                let model_name = model_name.clone();
                let items: Vec<Result<LlmChunk>> = match chunk_result {
                    Ok(bytes) => {
                        let text = String::from_utf8_lossy(&bytes);
                        text.lines()
                            .filter(|line| !line.trim().is_empty())
                            .filter_map(|line| serde_json::from_str::<CohereStreamEvent>(line).ok())
                            .filter_map(|event| match event {
                                CohereStreamEvent::TextGeneration { text } => Some(Ok(LlmChunk {
                                    content: text,
                                    done: false,
                                    model: None,
                                    usage: None,
                                })),
                                CohereStreamEvent::StreamEnd { response } => {
                                    let usage =
                                        response.meta.and_then(|m| m.billed_units).map(|bu| {
                                            let input = bu.input_tokens.unwrap_or(0);
                                            let output = bu.output_tokens.unwrap_or(0);
                                            StreamUsage {
                                                prompt_tokens: Some(input),
                                                completion_tokens: Some(output),
                                                total_tokens: Some(input + output),
                                            }
                                        });

                                    Some(Ok(LlmChunk {
                                        content: String::new(),
                                        done: true,
                                        model: Some(model_name.clone()),
                                        usage,
                                    }))
                                }
                                CohereStreamEvent::Other => None,
                            })
                            .collect()
                    }
                    Err(e) => vec![Err(LlmError::NetworkError(e))],
                };
                futures::stream::iter(items)
            })
            .flatten();

        Ok(Box::pin(parsed_stream))
    }
}
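
// Sketch of consuming the stream (assumes a Tokio runtime; `StreamExt` is
// already in scope from the imports at the top of this file):
//
//     let mut stream = provider.complete_stream(request).await?;
//     while let Some(chunk) = stream.next().await {
//         let chunk = chunk?;
//         print!("{}", chunk.content);
//         if chunk.done {
//             break;
//         }
//     }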

// ===== Cohere Embeddings Implementation =====

/// Request body for Cohere's `/embed` endpoint.
#[derive(Serialize)]
struct CohereEmbeddingRequest {
    texts: Vec<String>,
    model: String,
    input_type: String,
}

/// Response body from Cohere's `/embed` endpoint.
#[derive(Deserialize)]
struct CohereEmbeddingResponse {
    embeddings: Vec<Vec<f32>>,
    meta: Option<CohereEmbeddingMeta>,
}

#[derive(Deserialize)]
struct CohereEmbeddingMeta {
    billed_units: Option<CohereEmbeddingBilledUnits>,
}

/// Embeddings are billed on input tokens only.
#[derive(Deserialize)]
struct CohereEmbeddingBilledUnits {
    input_tokens: Option<u32>,
}
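
// Wire shapes these types map to (illustrative):
//
//     request:  {"texts":["..."],"model":"embed-english-v3.0","input_type":"search_document"}
//     response: {"embeddings":[[0.1,-0.2]],"meta":{"billed_units":{"input_tokens":3}}}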

#[async_trait]
impl EmbeddingProvider for CohereProvider {
    async fn embed(&self, request: EmbeddingRequest) -> Result<EmbeddingResponse> {
        let model = request.model.unwrap_or_else(|| self.model.clone());

        let cohere_request = CohereEmbeddingRequest {
            texts: request.texts,
            model: model.clone(),
            // Cohere v3 embedding models require an input_type;
            // "search_document" is hard-coded here.
            input_type: "search_document".to_string(),
        };

        let response = self
            .client
            .post(format!("{}/embed", self.base_url))
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&cohere_request)
            .send()
            .await?;

        let status = response.status();

        if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
            // Extract the Retry-After header (seconds form) if present
            let retry_after = response
                .headers()
                .get("retry-after")
                .and_then(|v| v.to_str().ok())
                .and_then(|s| s.parse::<u64>().ok())
                .map(Duration::from_secs);

            return Err(LlmError::RateLimited(retry_after));
        }

        let body = response.text().await?;

        if !status.is_success() {
            return Err(LlmError::ApiError(format!("HTTP {}: {}", status, body)));
        }

        let cohere_response: CohereEmbeddingResponse =
            serde_json::from_str(&body).map_err(|e| LlmError::SerializationError(e.to_string()))?;

        let usage = cohere_response.meta.and_then(|m| m.billed_units).map(|bu| {
            let input = bu.input_tokens.unwrap_or(0);
            EmbeddingUsage {
                prompt_tokens: input,
                total_tokens: input,
            }
        });

        Ok(EmbeddingResponse {
            embeddings: cohere_response.embeddings,
            model,
            usage,
        })
    }
}
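
// Sketch of an embedding call (assumes a Tokio runtime and that
// `EmbeddingRequest` has at least the `texts` and `model` fields read above;
// any other fields would need their own defaults):
//
//     let provider = CohereProvider::for_embeddings(api_key);
//     let response = provider
//         .embed(EmbeddingRequest {
//             texts: vec!["hello world".to_string()],
//             model: None, // falls back to embed-english-v3.0
//         })
//         .await?;
//     assert_eq!(response.embeddings.len(), 1);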