// ai_lib/provider/cohere.rs

use crate::api::{ChatApi, ChatCompletionChunk, ModelInfo, ModelPermission};
use crate::metrics::{Metrics, NoopMetrics};
use crate::transport::{DynHttpTransportRef, HttpTransport};
use crate::types::{
    AiLibError, ChatCompletionRequest, ChatCompletionResponse, Choice, Message, Role, Usage,
};
use futures::stream::Stream;
use futures::StreamExt;
use std::collections::HashMap;
use std::sync::Arc;
#[cfg(feature = "unified_transport")]
use std::time::Duration;
use tokio::sync::mpsc;
use tokio_stream::wrappers::UnboundedReceiverStream;

/// Find the end of the next SSE event in `buffer`, i.e. the index just past a
/// blank-line separator (`\n\n` or `\r\n\r\n`). Returns `None` if no complete
/// event is buffered yet.
#[cfg(not(feature = "unified_sse"))]
fn find_event_boundary(buffer: &[u8]) -> Option<usize> {
    let mut i = 0;
    while i < buffer.len().saturating_sub(1) {
        if buffer[i] == b'\n' && buffer[i + 1] == b'\n' {
            return Some(i + 2);
        }
        if i < buffer.len().saturating_sub(3)
            && buffer[i] == b'\r'
            && buffer[i + 1] == b'\n'
            && buffer[i + 2] == b'\r'
            && buffer[i + 3] == b'\n'
        {
            return Some(i + 4);
        }
        i += 1;
    }
    None
}

/// Parse a single SSE event. Returns `None` for events without a `data:` line,
/// `Some(Ok(None))` for the terminating `[DONE]` event, and `Some(Ok(Some(..)))`
/// for a decoded chunk.
#[cfg(not(feature = "unified_sse"))]
fn parse_sse_event(event_text: &str) -> Option<Result<Option<ChatCompletionChunk>, AiLibError>> {
    for line in event_text.lines() {
        let line = line.trim();
        if let Some(data) = line.strip_prefix("data: ") {
            if data == "[DONE]" {
                return Some(Ok(None));
            }
            return Some(parse_chunk_data(data));
        }
    }
    None
}

/// Decode one OpenAI-style streaming chunk from the JSON payload of an SSE
/// `data:` line.
#[cfg(not(feature = "unified_sse"))]
fn parse_chunk_data(data: &str) -> Result<Option<ChatCompletionChunk>, AiLibError> {
    match serde_json::from_str::<serde_json::Value>(data) {
        Ok(json) => {
            let choices = json["choices"]
                .as_array()
                .map(|arr| {
                    arr.iter()
                        .enumerate()
                        .map(|(index, choice)| {
                            let delta = &choice["delta"];
                            crate::api::ChoiceDelta {
                                index: index as u32,
                                delta: crate::api::MessageDelta {
                                    role: delta["role"].as_str().map(|r| match r {
                                        "assistant" => crate::types::Role::Assistant,
                                        "user" => crate::types::Role::User,
                                        "system" => crate::types::Role::System,
                                        _ => crate::types::Role::Assistant,
                                    }),
                                    content: delta["content"].as_str().map(str::to_string),
                                },
                                finish_reason: choice["finish_reason"].as_str().map(str::to_string),
                            }
                        })
                        .collect()
                })
                .unwrap_or_default();

            Ok(Some(ChatCompletionChunk {
                id: json["id"].as_str().unwrap_or_default().to_string(),
                object: json["object"]
                    .as_str()
                    .unwrap_or("chat.completion.chunk")
                    .to_string(),
                created: json["created"].as_u64().unwrap_or(0),
                model: json["model"].as_str().unwrap_or_default().to_string(),
                choices,
            }))
        }
        Err(e) => Err(AiLibError::ProviderError(format!(
            "JSON parse error: {}",
            e
        ))),
    }
}
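
// A typical OpenAI-compatible SSE event handled by the fallback parser above
// looks like the following (sketch; Cohere's native stream format may differ):
//
//     data: {"id":"...","object":"chat.completion.chunk","created":0,
//            "model":"command-r","choices":[{"delta":{"content":"Hi"},"finish_reason":null}]}
//
// terminated by a blank line, with a literal `data: [DONE]` event closing the stream.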

/// Split `text` into chunks of at most `max_len` bytes, preferring to break at
/// whitespace and never splitting inside a multi-byte UTF-8 character. Used by
/// the simulated-streaming fallback below.
fn split_text_into_chunks(text: &str, max_len: usize) -> Vec<String> {
    let mut chunks = Vec::new();
    let mut start = 0;
    while start < text.len() {
        let mut end = std::cmp::min(start + max_len, text.len());
        // Back up onto a char boundary so the string slices below cannot panic
        // on multi-byte UTF-8 content.
        while end > start && !text.is_char_boundary(end) {
            end -= 1;
        }
        if end == start {
            // max_len is smaller than the current character: take it whole.
            end = text[start..]
                .chars()
                .next()
                .map(|c| start + c.len_utf8())
                .unwrap_or(text.len());
        }
        // Prefer to cut at the last whitespace inside the window.
        let mut cut = end;
        if end < text.len() {
            if let Some(pos) = text[start..end].rfind(' ') {
                cut = start + pos;
            }
        }
        if cut == start {
            cut = end;
        }
        chunks.push(text[start..cut].to_string());
        start = cut;
        // Skip the separating space between chunks.
        if start < text.len() && text.as_bytes()[start] == b' ' {
            start += 1;
        }
    }
    chunks
}

pub struct CohereAdapter {
    transport: DynHttpTransportRef,
    api_key: String,
    base_url: String,
    metrics: Arc<dyn Metrics>,
}

impl CohereAdapter {
    /// Default HTTP timeout in seconds; override with the `AI_HTTP_TIMEOUT_SECS`
    /// environment variable (unset or unparsable values fall back to 30).
    fn build_default_timeout_secs() -> u64 {
        std::env::var("AI_HTTP_TIMEOUT_SECS")
            .ok()
            .and_then(|s| s.parse::<u64>().ok())
            .unwrap_or(30)
    }

    /// Build the default transport: a shared reqwest-backed client when the
    /// `unified_transport` feature is enabled, otherwise a plain `HttpTransport`.
    fn build_default_transport() -> Result<DynHttpTransportRef, AiLibError> {
        #[cfg(feature = "unified_transport")]
        {
            let timeout = Duration::from_secs(Self::build_default_timeout_secs());
            let client = crate::transport::client_factory::build_shared_client()
                .map_err(|e| AiLibError::NetworkError(format!("Failed to build http client: {}", e)))?;
            let t = HttpTransport::with_reqwest_client(client, timeout);
            return Ok(t.boxed());
        }
        #[cfg(not(feature = "unified_transport"))]
        {
            return Ok(HttpTransport::new().boxed());
        }
    }

    /// Create a Cohere adapter. Requires the `COHERE_API_KEY` environment
    /// variable; `COHERE_BASE_URL` optionally overrides the default endpoint
    /// (`https://api.cohere.ai`).
    pub fn new() -> Result<Self, AiLibError> {
        let api_key = std::env::var("COHERE_API_KEY").map_err(|_| {
            AiLibError::AuthenticationError(
                "COHERE_API_KEY environment variable not set".to_string(),
            )
        })?;
        let base_url = std::env::var("COHERE_BASE_URL")
            .unwrap_or_else(|_| "https://api.cohere.ai".to_string());
        Ok(Self {
            transport: Self::build_default_transport()?,
            api_key,
            base_url,
            metrics: Arc::new(NoopMetrics::new()),
        })
    }

    /// Create an adapter with an explicit API key and an optional base URL
    /// override (both take precedence over the environment variables).
    pub fn new_with_overrides(
        api_key: String,
        base_url: Option<String>,
    ) -> Result<Self, AiLibError> {
        let resolved_base = base_url.unwrap_or_else(|| {
            std::env::var("COHERE_BASE_URL").unwrap_or_else(|_| "https://api.cohere.ai".to_string())
        });
        Ok(Self {
            transport: Self::build_default_transport()?,
            api_key,
            base_url: resolved_base,
            metrics: Arc::new(NoopMetrics::new()),
        })
    }
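
    // Example (sketch): constructing the adapter with an explicit key instead
    // of environment variables and issuing a request. `request` is assumed to
    // be a previously built `ChatCompletionRequest`; calling `chat_completion`
    // requires the `ChatApi` trait to be in scope and an async context.
    //
    //     let adapter = CohereAdapter::new_with_overrides(
    //         "my-cohere-key".to_string(),
    //         None, // fall back to COHERE_BASE_URL or https://api.cohere.ai
    //     )?;
    //     let response = adapter.chat_completion(request).await?;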

    /// Create adapter with injectable transport (for testing)
    pub fn with_transport(transport: HttpTransport, api_key: String, base_url: String) -> Self {
        Self {
            transport: transport.boxed(),
            api_key,
            base_url,
            metrics: Arc::new(NoopMetrics::new()),
        }
    }

    /// Construct using an object-safe transport reference
    pub fn with_transport_ref(
        transport: DynHttpTransportRef,
        api_key: String,
        base_url: String,
    ) -> Self {
        Self {
            transport,
            api_key,
            base_url,
            metrics: Arc::new(NoopMetrics::new()),
        }
    }

    /// Construct with an injected transport reference and metrics implementation
    pub fn with_transport_ref_and_metrics(
        transport: DynHttpTransportRef,
        api_key: String,
        base_url: String,
        metrics: Arc<dyn Metrics>,
    ) -> Self {
        Self {
            transport,
            api_key,
            base_url,
            metrics,
        }
    }

    fn convert_request(&self, request: &ChatCompletionRequest) -> serde_json::Value {
        // Build a v2/chat-style body per the Cohere documentation
        let msgs: Vec<serde_json::Value> = request
            .messages
            .iter()
            .map(|msg| {
                serde_json::json!({
                    "role": match msg.role {
                        Role::System => "system",
                        Role::User => "user",
                        Role::Assistant => "assistant",
                    },
                    "content": msg.content.as_text()
                })
            })
            .collect();

        let mut chat_body = serde_json::json!({
            "model": request.model,
            "messages": msgs,
        });

        // Add optional sampling parameters
        if let Some(temp) = request.temperature {
            chat_body["temperature"] =
                serde_json::Value::Number(serde_json::Number::from_f64(temp.into()).unwrap());
        }
        if let Some(max_tokens) = request.max_tokens {
            chat_body["max_tokens"] =
                serde_json::Value::Number(serde_json::Number::from(max_tokens));
        }

        chat_body
    }
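
    // For reference, a single-user-message request with temperature and
    // max_tokens set produces a body shaped like the following (sketch;
    // "command-r" is only an illustrative model name):
    //
    //     {
    //         "model": "command-r",
    //         "messages": [{"role": "user", "content": "Hello"}],
    //         "temperature": 0.7,
    //         "max_tokens": 256
    //     }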

    fn parse_response(
        &self,
        response: serde_json::Value,
    ) -> Result<ChatCompletionResponse, AiLibError> {
        // Try different response formats: OpenAI-like choices, Cohere v2/chat message, or Cohere v1 generations
        let content = if let Some(c) = response.get("choices") {
            // OpenAI format
            c[0]["message"]["content"]
                .as_str()
                .map(|s| s.to_string())
                .or_else(|| c[0]["text"].as_str().map(|s| s.to_string()))
        } else if let Some(msg) = response.get("message") {
            // Cohere v2/chat format
            msg.get("content").and_then(|content_array| {
                content_array
                    .as_array()
                    .and_then(|arr| arr.first())
                    .and_then(|content_obj| {
                        content_obj
                            .get("text")
                            .and_then(|t| t.as_str())
                            .map(|text| text.to_string())
                    })
            })
        } else if let Some(gens) = response.get("generations") {
            // Cohere v1 format
            gens[0]["text"].as_str().map(|s| s.to_string())
        } else {
            None
        };

        let content = content.unwrap_or_default();

        let mut function_call: Option<crate::types::function_call::FunctionCall> = None;
        if let Some(fc_val) = response.get("function_call") {
            if let Ok(fc) =
                serde_json::from_value::<crate::types::function_call::FunctionCall>(fc_val.clone())
            {
                function_call = Some(fc);
            } else if let Some(name) = fc_val
                .get("name")
                .and_then(|v| v.as_str())
                .map(|s| s.to_string())
            {
                let args = fc_val.get("arguments").and_then(|a| {
                    if a.is_string() {
                        serde_json::from_str::<serde_json::Value>(a.as_str().unwrap()).ok()
                    } else {
                        Some(a.clone())
                    }
                });
                function_call = Some(crate::types::function_call::FunctionCall {
                    name,
                    arguments: args,
                });
            }
        } else if let Some(tool_calls) = response.get("tool_calls").and_then(|v| v.as_array()) {
            if let Some(first) = tool_calls.first() {
                if let Some(func) = first.get("function").or_else(|| first.get("function_call")) {
                    if let Some(name) = func.get("name").and_then(|v| v.as_str()) {
                        let mut args_opt = func.get("arguments").cloned();
                        if let Some(args_val) = &args_opt {
                            if args_val.is_string() {
                                if let Some(s) = args_val.as_str() {
                                    if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(s) {
                                        args_opt = Some(parsed);
                                    }
                                }
                            }
                        }
                        function_call = Some(crate::types::function_call::FunctionCall {
                            name: name.to_string(),
                            arguments: args_opt,
                        });
                    }
                }
            }
        }

        let choice = Choice {
            index: 0,
            message: Message {
                role: Role::Assistant,
                content: crate::types::common::Content::Text(content.clone()),
                function_call,
            },
            finish_reason: None,
        };

        let usage = Usage {
            prompt_tokens: 0,
            completion_tokens: 0,
            total_tokens: 0,
        };

        Ok(ChatCompletionResponse {
            id: response["id"].as_str().unwrap_or_default().to_string(),
            object: response["object"].as_str().unwrap_or_default().to_string(),
            created: response["created"].as_u64().unwrap_or(0),
            model: response["model"].as_str().unwrap_or_default().to_string(),
            choices: vec![choice],
            usage,
        })
    }
}
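
// `parse_response` above accepts any of three response shapes. As a sketch
// (field names taken from the parsing code, not an exhaustive schema):
//
//     OpenAI-style:   {"choices": [{"message": {"content": "Hello!"}}]}
//     Cohere v2/chat: {"message": {"content": [{"text": "Hello!"}]}}
//     Cohere v1:      {"generations": [{"text": "Hello!"}]}
//
// Token usage is not populated from any of these shapes; `Usage` is zeroed.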

#[async_trait::async_trait]
impl ChatApi for CohereAdapter {
    async fn chat_completion(
        &self,
        request: ChatCompletionRequest,
    ) -> Result<ChatCompletionResponse, AiLibError> {
        self.metrics.incr_counter("cohere.requests", 1).await;
        let timer = self.metrics.start_timer("cohere.request_duration_ms").await;

        // Use the v1/generate endpoint (fallback that works with older API keys)
        let url_generate = format!("{}/v1/generate", self.base_url);

        let mut headers = HashMap::new();
        headers.insert(
            "Authorization".to_string(),
            format!("Bearer {}", self.api_key),
        );
        headers.insert("Content-Type".to_string(), "application/json".to_string());
        headers.insert("Accept".to_string(), "application/json".to_string());

        // Flatten the chat history into a single prompt string for v1/generate
        let prompt = request
            .messages
            .iter()
            .map(|msg| match msg.role {
                Role::System => format!("System: {}", msg.content.as_text()),
                Role::User => format!("Human: {}", msg.content.as_text()),
                Role::Assistant => format!("Assistant: {}", msg.content.as_text()),
            })
            .collect::<Vec<_>>()
            .join("\n");

        let mut generate_body = serde_json::json!({
            "model": request.model,
            "prompt": prompt,
        });

        if let Some(temp) = request.temperature {
            generate_body["temperature"] =
                serde_json::Value::Number(serde_json::Number::from_f64(temp.into()).unwrap());
        }
        if let Some(max_tokens) = request.max_tokens {
            generate_body["max_tokens"] =
                serde_json::Value::Number(serde_json::Number::from(max_tokens));
        }

        // Send via the unified transport
        let response_json = self
            .transport
            .post_json(&url_generate, Some(headers), generate_body)
            .await?;

        if let Some(t) = timer {
            t.stop();
        }

        self.parse_response(response_json)
    }

    async fn chat_completion_stream(
        &self,
        request: ChatCompletionRequest,
    ) -> Result<
        Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
        AiLibError,
    > {
        // Build the stream request like chat_completion but with stream=true
        let mut stream_request = self.convert_request(&request);
        stream_request["stream"] = serde_json::Value::Bool(true);

        let url = format!("{}/v1/chat", self.base_url);

        let mut headers = HashMap::new();
        headers.insert(
            "Authorization".to_string(),
            format!("Bearer {}", self.api_key),
        );
        // Try unified transport streaming first
        if let Ok(byte_stream) = self
            .transport
            .post_stream(&url, Some(headers.clone()), stream_request.clone())
            .await
        {
            let (tx, rx) = mpsc::unbounded_channel();
            tokio::spawn(async move {
                let mut buffer = Vec::new();
                futures::pin_mut!(byte_stream);
                while let Some(item) = byte_stream.next().await {
                    match item {
                        Ok(bytes) => {
                            buffer.extend_from_slice(&bytes);
                            #[cfg(feature = "unified_sse")]
                            {
                                while let Some(boundary) =
                                    crate::sse::parser::find_event_boundary(&buffer)
                                {
                                    let event_bytes = buffer.drain(..boundary).collect::<Vec<_>>();
                                    if let Ok(event_text) = std::str::from_utf8(&event_bytes) {
                                        if let Some(parsed) =
                                            crate::sse::parser::parse_sse_event(event_text)
                                        {
                                            match parsed {
                                                Ok(Some(chunk)) => {
                                                    if tx.send(Ok(chunk)).is_err() {
                                                        return;
                                                    }
                                                }
                                                Ok(None) => return,
                                                Err(e) => {
                                                    let _ = tx.send(Err(e));
                                                    return;
                                                }
                                            }
                                        }
                                    }
                                }
                            }
                            #[cfg(not(feature = "unified_sse"))]
                            {
                                while let Some(boundary) = find_event_boundary(&buffer) {
                                    let event_bytes = buffer.drain(..boundary).collect::<Vec<_>>();
                                    if let Ok(event_text) = std::str::from_utf8(&event_bytes) {
                                        if let Some(parsed) = parse_sse_event(event_text) {
                                            match parsed {
                                                Ok(Some(chunk)) => {
                                                    if tx.send(Ok(chunk)).is_err() {
                                                        return;
                                                    }
                                                }
                                                Ok(None) => return,
                                                Err(e) => {
                                                    let _ = tx.send(Err(e));
                                                    return;
                                                }
                                            }
                                        }
                                    }
                                }
                            }
                        }
                        Err(e) => {
                            let _ = tx.send(Err(AiLibError::ProviderError(format!(
                                "Stream error: {}",
                                e
                            ))));
                            break;
                        }
                    }
                }
            });
            let stream = UnboundedReceiverStream::new(rx);
            return Ok(Box::new(Box::pin(stream)));
        }

        // Fallback: call the non-streaming chat_completion and replay the result
        // as simulated streaming chunks
        let finished = self.chat_completion(request.clone()).await?;
        let text = finished
            .choices
            .first()
            .map(|c| c.message.content.as_text())
            .unwrap_or_default();

        let (tx, rx) = mpsc::unbounded_channel();

        tokio::spawn(async move {
            let chunks = split_text_into_chunks(&text, 80);
            for chunk in chunks {
                // Construct a ChatCompletionChunk carrying a single delta
                let delta = crate::api::ChoiceDelta {
                    index: 0,
                    delta: crate::api::MessageDelta {
                        role: Some(crate::types::Role::Assistant),
                        content: Some(chunk.clone()),
                    },
                    finish_reason: None,
                };
                let chunk_obj = ChatCompletionChunk {
                    id: "simulated".to_string(),
                    object: "chat.completion.chunk".to_string(),
                    created: 0,
                    model: finished.model.clone(),
                    choices: vec![delta],
                };

                if tx.send(Ok(chunk_obj)).is_err() {
                    return;
                }
                tokio::time::sleep(std::time::Duration::from_millis(50)).await;
            }
        });

        let stream = UnboundedReceiverStream::new(rx);
        Ok(Box::new(Box::pin(stream)))
    }

    async fn list_models(&self) -> Result<Vec<String>, AiLibError> {
        // Use the v1/models endpoint via the unified transport
        let url = format!("{}/v1/models", self.base_url);
        let mut headers = HashMap::new();
        headers.insert(
            "Authorization".to_string(),
            format!("Bearer {}", self.api_key),
        );

        let response = self.transport.get_json(&url, Some(headers)).await?;

        Ok(response["models"]
            .as_array()
            .unwrap_or(&vec![])
            .iter()
            .filter_map(|m| {
                m["id"]
                    .as_str()
                    .map(|s| s.to_string())
                    .or_else(|| m["name"].as_str().map(|s| s.to_string()))
            })
            .collect())
    }
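
    // Note: the parsing above accepts either `{"models": [{"id": "..."}]}` or
    // `{"models": [{"name": "..."}]}`; entries with neither field are skipped,
    // and a missing or non-array `models` field yields an empty list.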

    async fn get_model_info(&self, model_id: &str) -> Result<ModelInfo, AiLibError> {
        // This adapter returns a static descriptor with permissive defaults
        // rather than querying the API for per-model metadata.
        Ok(ModelInfo {
            id: model_id.to_string(),
            object: "model".to_string(),
            created: 0,
            owned_by: "cohere".to_string(),
            permission: vec![ModelPermission {
                id: "default".to_string(),
                object: "model_permission".to_string(),
                created: 0,
                allow_create_engine: false,
                allow_sampling: true,
                allow_logprobs: false,
                allow_search_indices: false,
                allow_view: true,
                allow_fine_tuning: false,
                organization: "*".to_string(),
                group: None,
                is_blocking: false,
            }],
        })
    }
}
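
// A minimal test sketch for the pure helpers in this file. These tests only
// exercise local functions (no network, no Cohere credentials); the SSE test is
// gated the same way as the fallback parser it covers.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn split_respects_max_len_and_whitespace() {
        let chunks = split_text_into_chunks("hello world from cohere", 11);
        // No chunk exceeds the byte budget and joining on spaces restores the text.
        assert!(chunks.iter().all(|c| c.len() <= 11));
        assert_eq!(chunks.join(" "), "hello world from cohere");
    }

    #[test]
    fn split_does_not_panic_on_multibyte_utf8() {
        // The byte budget falls inside a multi-byte character; the char-boundary
        // clamp keeps slicing safe and no content is lost.
        let chunks = split_text_into_chunks("ééééééé", 3);
        assert_eq!(chunks.concat(), "ééééééé");
    }

    #[cfg(not(feature = "unified_sse"))]
    #[test]
    fn sse_event_boundary_and_done_marker() {
        let buf = b"data: [DONE]\n\nrest".to_vec();
        let boundary = find_event_boundary(&buf).expect("complete event buffered");
        let event = std::str::from_utf8(&buf[..boundary]).unwrap();
        // The [DONE] sentinel is reported as a clean end-of-stream.
        assert!(matches!(parse_sse_event(event), Some(Ok(None))));
    }
}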