ai_lib/provider/cohere.rs

use crate::api::{ChatCompletionChunk, ChatProvider, ModelInfo, ModelPermission};
use crate::metrics::{Metrics, NoopMetrics};
use crate::transport::{DynHttpTransportRef, HttpTransport};
use crate::types::{
    AiLibError, ChatCompletionRequest, ChatCompletionResponse, Choice, Message, Role, Usage,
    UsageStatus,
};
use futures::stream::Stream;
use std::clone::Clone;
use std::collections::HashMap;
use std::sync::Arc;
#[cfg(feature = "unified_transport")]
use std::time::Duration;

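/// Fallback SSE framing helper: returns the index just past the first blank line
/// (`\n\n` or `\r\n\r\n`) in `buffer`, i.e. the end of one complete event, or `None`
/// if no full event has been buffered yet.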
#[cfg(not(feature = "unified_sse"))]
#[allow(dead_code)]
fn find_event_boundary(buffer: &[u8]) -> Option<usize> {
    let mut i = 0;
    while i < buffer.len().saturating_sub(1) {
        if buffer[i] == b'\n' && buffer[i + 1] == b'\n' {
            return Some(i + 2);
        }
        if i < buffer.len().saturating_sub(3)
            && buffer[i] == b'\r'
            && buffer[i + 1] == b'\n'
            && buffer[i + 2] == b'\r'
            && buffer[i + 3] == b'\n'
        {
            return Some(i + 4);
        }
        i += 1;
    }
    None
}

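/// Fallback SSE parsing helper: extracts the first `data:` payload from an event
/// block. Returns `None` when the event carries no data line, `Some(Ok(None))` for
/// the `[DONE]` sentinel, and otherwise the result of decoding the payload.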
#[cfg(not(feature = "unified_sse"))]
#[allow(dead_code)]
fn parse_sse_event(event_text: &str) -> Option<Result<Option<ChatCompletionChunk>, AiLibError>> {
    for line in event_text.lines() {
        let line = line.trim();
        if let Some(data) = line.strip_prefix("data: ") {
            if data == "[DONE]" {
                return Some(Ok(None));
            }
            return Some(parse_chunk_data(data));
        }
    }
    None
}

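/// Decode a single OpenAI-style streaming payload (the JSON after `data:`) into a
/// `ChatCompletionChunk`, mapping each `delta` entry onto a `ChoiceDelta`.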
#[cfg(not(feature = "unified_sse"))]
fn parse_chunk_data(data: &str) -> Result<Option<ChatCompletionChunk>, AiLibError> {
    match serde_json::from_str::<serde_json::Value>(data) {
        Ok(json) => {
            let choices = json["choices"]
                .as_array()
                .map(|arr| {
                    arr.iter()
                        .enumerate()
                        .map(|(index, choice)| {
                            let delta = &choice["delta"];
                            crate::api::ChoiceDelta {
                                index: index as u32,
                                delta: crate::api::MessageDelta {
                                    role: delta["role"].as_str().map(|r| match r {
                                        "assistant" => crate::types::Role::Assistant,
                                        "user" => crate::types::Role::User,
                                        "system" => crate::types::Role::System,
                                        _ => crate::types::Role::Assistant,
                                    }),
                                    content: delta["content"].as_str().map(str::to_string),
                                },
                                finish_reason: choice["finish_reason"].as_str().map(str::to_string),
                            }
                        })
                        .collect()
                })
                .unwrap_or_default();

            Ok(Some(ChatCompletionChunk {
                id: json["id"].as_str().unwrap_or_default().to_string(),
                object: json["object"]
                    .as_str()
                    .unwrap_or("chat.completion.chunk")
                    .to_string(),
                created: json["created"].as_u64().unwrap_or(0),
                model: json["model"].as_str().unwrap_or_default().to_string(),
                choices,
            }))
        }
        Err(e) => Err(AiLibError::ProviderError(format!(
            "JSON parse error: {}",
            e
        ))),
    }
}

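/// Split `text` into pieces of roughly `max_len` bytes without breaking a UTF-8 code
/// point, preferring to cut at the last space inside each window. Illustrative
/// example: `split_text_into_chunks("hello world", 6)` yields `["hello", "world"]`,
/// dropping the separating space.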
fn split_text_into_chunks(text: &str, max_len: usize) -> Vec<String> {
    // Split `text` into chunks where each chunk's byte length is around `max_len` but
    // never splits a UTF-8 code point. This implementation works on byte indices but
    // always ensures slicing is done on `is_char_boundary` positions.
    let mut chunks = Vec::new();
    let bytes = text.as_bytes();
    let mut i = 0usize;
    let len = bytes.len();
    while i < len {
        // initial tentative end
        let mut end = std::cmp::min(i + max_len, len);
        // if end is not a char boundary, move it backward to the previous boundary
        if end < len && !text.is_char_boundary(end) {
            while end > i && !text.is_char_boundary(end) {
                end -= 1;
            }
            // if we couldn't move backward (very small max_len), move forward to the next boundary
            if end == i {
                end = std::cmp::min(i + max_len, len);
                while end < len && !text.is_char_boundary(end) {
                    end += 1;
                }
                if end == i {
                    // give up and take the remainder to avoid an infinite loop
                    end = len;
                }
            }
        }

        // try to cut at the last whitespace within i..end (char-safe slice)
        let mut cut = end;
        if end < len {
            if let Some(pos) = text[i..end].rfind(' ') {
                cut = i + pos;
            }
        }
        if cut == i {
            cut = end;
        }

        // Safe slice because i..cut are guaranteed to be char boundaries
        let chunk = text[i..cut].to_string();
        chunks.push(chunk);
        i = cut;
        // skip a single leading space on the next chunk if present
        if i < len && bytes[i] == b' ' {
            i += 1;
        }
    }
    chunks
}
use futures::StreamExt;
use tokio::sync::mpsc;
use tokio_stream::wrappers::UnboundedReceiverStream;

pub struct CohereAdapter {
    /// Object-safe HTTP transport used for all Cohere requests.
    transport: DynHttpTransportRef,
    api_key: String,
    base_url: String,
    metrics: Arc<dyn Metrics>,
}

impl CohereAdapter {
    #[allow(dead_code)]
    fn build_default_timeout_secs() -> u64 {
        std::env::var("AI_HTTP_TIMEOUT_SECS")
            .ok()
            .and_then(|s| s.parse::<u64>().ok())
            .unwrap_or(30)
    }

    fn build_default_transport() -> Result<DynHttpTransportRef, AiLibError> {
        #[cfg(feature = "unified_transport")]
        {
            let timeout = Duration::from_secs(Self::build_default_timeout_secs());
            let client = crate::transport::client_factory::build_shared_client().map_err(|e| {
                AiLibError::NetworkError(format!("Failed to build http client: {}", e))
            })?;
            let t = HttpTransport::with_reqwest_client(client, timeout);
            Ok(t.boxed())
        }
        #[cfg(not(feature = "unified_transport"))]
        {
            let t = HttpTransport::new();
            return Ok(t.boxed());
        }
    }

    /// Create a Cohere adapter. Requires the `COHERE_API_KEY` environment variable;
    /// `COHERE_BASE_URL` optionally overrides the default `https://api.cohere.ai` base URL.
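    ///
    /// Illustrative usage (doctest-ignored sketch; assumes `COHERE_API_KEY` is set, a
    /// Tokio runtime is running, and `request` is a previously built `ChatCompletionRequest`):
    ///
    /// ```ignore
    /// let adapter = CohereAdapter::new()?;
    /// let response = adapter.chat(request).await?;
    /// ```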
    pub fn new() -> Result<Self, AiLibError> {
        let api_key = std::env::var("COHERE_API_KEY").map_err(|_| {
            AiLibError::AuthenticationError(
                "COHERE_API_KEY environment variable not set".to_string(),
            )
        })?;
        let base_url = std::env::var("COHERE_BASE_URL")
            .unwrap_or_else(|_| "https://api.cohere.ai".to_string());
        Ok(Self {
            transport: Self::build_default_transport()?,
            api_key,
            base_url,
            metrics: Arc::new(NoopMetrics::new()),
        })
    }

    /// Explicit overrides for `api_key` and optional `base_url` (these take precedence over env vars)
    pub fn new_with_overrides(
        api_key: String,
        base_url: Option<String>,
    ) -> Result<Self, AiLibError> {
        let resolved_base = base_url.unwrap_or_else(|| {
            std::env::var("COHERE_BASE_URL").unwrap_or_else(|_| "https://api.cohere.ai".to_string())
        });
        Ok(Self {
            transport: Self::build_default_transport()?,
            api_key,
            base_url: resolved_base,
            metrics: Arc::new(NoopMetrics::new()),
        })
    }

    /// Create adapter with injectable transport (for testing)
    pub fn with_transport(transport: HttpTransport, api_key: String, base_url: String) -> Self {
        Self {
            transport: transport.boxed(),
            api_key,
            base_url,
            metrics: Arc::new(NoopMetrics::new()),
        }
    }

    /// Construct using an object-safe transport reference
    pub fn with_transport_ref(
        transport: DynHttpTransportRef,
        api_key: String,
        base_url: String,
    ) -> Self {
        Self {
            transport,
            api_key,
            base_url,
            metrics: Arc::new(NoopMetrics::new()),
        }
    }

    /// Construct with an injected transport reference and metrics implementation
    pub fn with_transport_ref_and_metrics(
        transport: DynHttpTransportRef,
        api_key: String,
        base_url: String,
        metrics: Arc<dyn Metrics>,
    ) -> Self {
        Self {
            transport,
            api_key,
            base_url,
            metrics,
        }
    }

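    /// Build the Cohere v2/chat-style request body. For a single user message the
    /// output looks roughly like (illustrative):
    /// `{"model": "<model>", "messages": [{"role": "user", "content": "Hi"}]}`,
    /// with `temperature`/`max_tokens` included only when set on the request, and any
    /// request extensions applied on top.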
    fn convert_request(&self, request: &ChatCompletionRequest) -> serde_json::Value {
        // Build v2/chat format according to Cohere documentation
        let msgs: Vec<serde_json::Value> = request
            .messages
            .iter()
            .map(|msg| {
                serde_json::json!({
                    "role": match msg.role {
                        Role::System => "system",
                        Role::User => "user",
                        Role::Assistant => "assistant",
                    },
                    "content": msg.content.as_text()
                })
            })
            .collect();

        let mut chat_body = serde_json::json!({
            "model": request.model,
            "messages": msgs,
        });

        // Add optional parameters
        if let Some(temp) = request.temperature {
            chat_body["temperature"] =
                serde_json::Value::Number(serde_json::Number::from_f64(temp.into()).unwrap());
        }
        if let Some(max_tokens) = request.max_tokens {
            chat_body["max_tokens"] =
                serde_json::Value::Number(serde_json::Number::from(max_tokens));
        }

        request.apply_extensions(&mut chat_body);

        chat_body
    }

    fn parse_response(
        &self,
        response: serde_json::Value,
    ) -> Result<ChatCompletionResponse, AiLibError> {
        // Try different response formats: OpenAI-like choices, Cohere v2/chat message, or Cohere v1 generations
        let content = if let Some(c) = response.get("choices") {
            // OpenAI format
            c[0]["message"]["content"]
                .as_str()
                .map(|s| s.to_string())
                .or_else(|| c[0]["text"].as_str().map(|s| s.to_string()))
        } else if let Some(msg) = response.get("message") {
            // Cohere v2/chat format
            msg.get("content").and_then(|content_array| {
                content_array
                    .as_array()
                    .and_then(|arr| arr.first())
                    .and_then(|content_obj| {
                        content_obj
                            .get("text")
                            .and_then(|t| t.as_str())
                            .map(|text| text.to_string())
                    })
            })
        } else if let Some(gens) = response.get("generations") {
            // Cohere v1 format
            gens[0]["text"].as_str().map(|s| s.to_string())
        } else {
            None
        };

        let content = content.unwrap_or_default();

        let mut function_call: Option<crate::types::function_call::FunctionCall> = None;
        if let Some(fc_val) = response.get("function_call") {
            if let Ok(fc) =
                serde_json::from_value::<crate::types::function_call::FunctionCall>(fc_val.clone())
            {
                function_call = Some(fc);
            } else if let Some(name) = fc_val
                .get("name")
                .and_then(|v| v.as_str())
                .map(|s| s.to_string())
            {
                let args = fc_val.get("arguments").and_then(|a| {
                    if a.is_string() {
                        serde_json::from_str::<serde_json::Value>(a.as_str().unwrap()).ok()
                    } else {
                        Some(a.clone())
                    }
                });
                function_call = Some(crate::types::function_call::FunctionCall {
                    name,
                    arguments: args,
                });
            }
        } else if let Some(tool_calls) = response.get("tool_calls").and_then(|v| v.as_array()) {
            if let Some(first) = tool_calls.first() {
                if let Some(func) = first.get("function").or_else(|| first.get("function_call")) {
                    if let Some(name) = func.get("name").and_then(|v| v.as_str()) {
                        let mut args_opt = func.get("arguments").cloned();
                        if let Some(args_val) = &args_opt {
                            if args_val.is_string() {
                                if let Some(s) = args_val.as_str() {
                                    if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(s)
                                    {
                                        args_opt = Some(parsed);
                                    }
                                }
                            }
                        }
                        function_call = Some(crate::types::function_call::FunctionCall {
                            name: name.to_string(),
                            arguments: args_opt,
                        });
                    }
                }
            }
        }

        let choice = Choice {
            index: 0,
            message: Message {
                role: Role::Assistant,
                content: crate::types::common::Content::Text(content.clone()),
                function_call,
            },
            finish_reason: None,
        };

        let usage = Usage {
            prompt_tokens: 0,
            completion_tokens: 0,
            total_tokens: 0,
        };

        Ok(ChatCompletionResponse {
            id: response["id"].as_str().unwrap_or_default().to_string(),
            object: response["object"].as_str().unwrap_or_default().to_string(),
            created: response["created"].as_u64().unwrap_or(0),
            model: response["model"].as_str().unwrap_or_default().to_string(),
            choices: vec![choice],
            usage,
            usage_status: UsageStatus::Unsupported, // Cohere doesn't provide usage data in this format
        })
    }
}

#[async_trait::async_trait]
impl ChatProvider for CohereAdapter {
    fn name(&self) -> &str {
        "Cohere"
    }

    async fn chat(
        &self,
        request: ChatCompletionRequest,
    ) -> Result<ChatCompletionResponse, AiLibError> {
        self.metrics.incr_counter("cohere.requests", 1).await;
        let timer = self.metrics.start_timer("cohere.request_duration_ms").await;

        // Use the v1/generate endpoint (fallback for older API keys)
        let url_generate = format!("{}/v1/generate", self.base_url);

        let mut headers = HashMap::new();
        headers.insert(
            "Authorization".to_string(),
            format!("Bearer {}", self.api_key),
        );
        headers.insert("Content-Type".to_string(), "application/json".to_string());
        headers.insert("Accept".to_string(), "application/json".to_string());

        // Convert messages to a single prompt string for the v1/generate endpoint
        let prompt = request
            .messages
            .iter()
            .map(|msg| match msg.role {
                Role::System => format!("System: {}", msg.content.as_text()),
                Role::User => format!("Human: {}", msg.content.as_text()),
                Role::Assistant => format!("Assistant: {}", msg.content.as_text()),
            })
            .collect::<Vec<_>>()
            .join("\n");

        let mut generate_body = serde_json::json!({
            "model": request.model,
            "prompt": prompt,
        });

        if let Some(temp) = request.temperature {
            generate_body["temperature"] =
                serde_json::Value::Number(serde_json::Number::from_f64(temp.into()).unwrap());
        }
        if let Some(max_tokens) = request.max_tokens {
            generate_body["max_tokens"] =
                serde_json::Value::Number(serde_json::Number::from(max_tokens));
        }

        request.apply_extensions(&mut generate_body);

        // Send via the unified transport
        let response_json = self
            .transport
            .post_json(&url_generate, Some(headers), generate_body)
            .await?;

        if let Some(t) = timer {
            t.stop();
        }

        self.parse_response(response_json)
    }

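    /// Stream a chat completion. Tries the transport's streaming endpoint first and
    /// decodes the SSE events into chunks; if that call fails, falls back to a single
    /// non-streaming `chat` request whose text is re-emitted as simulated chunks.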
    async fn stream(
        &self,
        request: ChatCompletionRequest,
    ) -> Result<
        Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
        AiLibError,
    > {
        // Build the streaming request: the v2/chat-format body with `stream` set to true
        let mut stream_request = self.convert_request(&request);
        stream_request["stream"] = serde_json::Value::Bool(true);

        let url = format!("{}/v1/chat", self.base_url);

        let mut headers = HashMap::new();
        headers.insert(
            "Authorization".to_string(),
            format!("Bearer {}", self.api_key),
        );
        // Try unified transport streaming first
        if let Ok(byte_stream) = self
            .transport
            .post_stream(&url, Some(headers.clone()), stream_request.clone())
            .await
        {
            let (tx, rx) = mpsc::unbounded_channel();
            tokio::spawn(async move {
                let mut buffer = Vec::new();
                futures::pin_mut!(byte_stream);
                while let Some(item) = byte_stream.next().await {
                    match item {
                        Ok(bytes) => {
                            buffer.extend_from_slice(&bytes);
                            #[cfg(feature = "unified_sse")]
                            {
                                while let Some(boundary) =
                                    crate::sse::parser::find_event_boundary(&buffer)
                                {
                                    let event_bytes = buffer.drain(..boundary).collect::<Vec<_>>();
                                    if let Ok(event_text) = std::str::from_utf8(&event_bytes) {
                                        if let Some(parsed) =
                                            crate::sse::parser::parse_sse_event(event_text)
                                        {
                                            match parsed {
                                                Ok(Some(chunk)) => {
                                                    if tx.send(Ok(chunk)).is_err() {
                                                        return;
                                                    }
                                                }
                                                Ok(None) => return,
                                                Err(e) => {
                                                    let _ = tx.send(Err(e));
                                                    return;
                                                }
                                            }
                                        }
                                    }
                                }
                            }
                            #[cfg(not(feature = "unified_sse"))]
                            {
                                while let Some(boundary) = find_event_boundary(&buffer) {
                                    let event_bytes = buffer.drain(..boundary).collect::<Vec<_>>();
                                    if let Ok(event_text) = std::str::from_utf8(&event_bytes) {
                                        if let Some(parsed) = parse_sse_event(event_text) {
                                            match parsed {
                                                Ok(Some(chunk)) => {
                                                    if tx.send(Ok(chunk)).is_err() {
                                                        return;
                                                    }
                                                }
                                                Ok(None) => return,
                                                Err(e) => {
                                                    let _ = tx.send(Err(e));
                                                    return;
                                                }
                                            }
                                        }
                                    }
                                }
                            }
                        }
                        Err(e) => {
                            let _ = tx.send(Err(AiLibError::ProviderError(format!(
                                "Stream error: {}",
                                e
                            ))));
                            break;
                        }
                    }
                }
            });
            let stream = UnboundedReceiverStream::new(rx);
            return Ok(Box::new(Box::pin(stream)));
        }

        // Fallback: perform a non-streaming chat call and re-emit the result in chunks
        let finished = self.chat(request.clone()).await?;
        let text = finished
            .choices
            .first()
            .map(|c| c.message.content.as_text())
            .unwrap_or_default();

        let (tx, rx) = mpsc::unbounded_channel();

        tokio::spawn(async move {
            let chunks = split_text_into_chunks(&text, 80);
            for chunk in chunks {
                // construct a ChatCompletionChunk carrying a single delta
                let delta = crate::api::ChoiceDelta {
                    index: 0,
                    delta: crate::api::MessageDelta {
                        role: Some(crate::types::Role::Assistant),
                        content: Some(chunk.clone()),
                    },
                    finish_reason: None,
                };
                let chunk_obj = ChatCompletionChunk {
                    id: "simulated".to_string(),
                    object: "chat.completion.chunk".to_string(),
                    created: 0,
                    model: finished.model.clone(),
                    choices: vec![delta],
                };

                if tx.send(Ok(chunk_obj)).is_err() {
                    return;
                }
                tokio::time::sleep(std::time::Duration::from_millis(50)).await;
            }
        });

        let stream = UnboundedReceiverStream::new(rx);
        Ok(Box::new(Box::pin(stream)))
    }

    async fn list_models(&self) -> Result<Vec<String>, AiLibError> {
        // Use v1/models endpoint for listing models via unified transport
        let url = format!("{}/v1/models", self.base_url);
        let mut headers = HashMap::new();
        headers.insert(
            "Authorization".to_string(),
            format!("Bearer {}", self.api_key),
        );

        let response = self.transport.get_json(&url, Some(headers)).await?;

        Ok(response["models"]
            .as_array()
            .unwrap_or(&vec![])
            .iter()
            .filter_map(|m| {
                m["id"]
                    .as_str()
                    .map(|s| s.to_string())
                    .or_else(|| m["name"].as_str().map(|s| s.to_string()))
            })
            .collect())
    }

    async fn get_model_info(&self, model_id: &str) -> Result<crate::api::ModelInfo, AiLibError> {
        Ok(ModelInfo {
            id: model_id.to_string(),
            object: "model".to_string(),
            created: 0,
            owned_by: "cohere".to_string(),
            permission: vec![ModelPermission {
                id: "default".to_string(),
                object: "model_permission".to_string(),
                created: 0,
                allow_create_engine: false,
                allow_sampling: true,
                allow_logprobs: false,
                allow_search_indices: false,
                allow_view: true,
                allow_fine_tuning: false,
                organization: "*".to_string(),
                group: None,
                is_blocking: false,
            }],
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn split_text_handles_multibyte_without_panic() {
        let s = "这是一个包含中文字符的长文本,用于测试边界处理。🚀✨";
        // Use a small max_len to force potential mid-codepoint boundaries
        let chunks = split_text_into_chunks(s, 8);
        assert!(!chunks.is_empty());
        // ensure every chunk is valid UTF-8 and reasonably sized
        for c in &chunks {
            assert!(std::str::from_utf8(c.as_bytes()).is_ok());
            // allow small slack: chunk bytes should not greatly exceed the requested max_len
            assert!(c.len() <= 8 + 8);
        }
        // The input contains no ASCII spaces, so no separator characters are dropped at
        // chunk boundaries: simple concatenation must reproduce the original exactly.
        let combined = chunks.join("");
        assert_eq!(combined, s);
    }
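
    // Sketch test for the fallback SSE framing helper (only compiled when the
    // `unified_sse` feature is disabled): checks LF and CRLF blank-line detection
    // and the "no complete event buffered yet" case.
    #[cfg(not(feature = "unified_sse"))]
    #[test]
    fn find_event_boundary_detects_blank_lines() {
        // Boundary index is just past the blank line that terminates the event.
        assert_eq!(find_event_boundary(b"data: {}\n\nrest"), Some(10));
        assert_eq!(find_event_boundary(b"data: {}\r\n\r\nrest"), Some(12));
        // No terminating blank line yet.
        assert_eq!(find_event_boundary(b"data: {"), None);
    }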
}