//! Cohere provider adapter (`ai_lib/provider/cohere.rs`).

1use crate::api::{ChatApi, ChatCompletionChunk, ModelInfo, ModelPermission};
2use crate::metrics::{Metrics, NoopMetrics};
3use crate::transport::{DynHttpTransportRef, HttpTransport};
4use crate::types::{
5    AiLibError, ChatCompletionRequest, ChatCompletionResponse, Choice, Message, Role, Usage, UsageStatus,
6};
7use futures::stream::Stream;
8use std::clone::Clone;
9use std::collections::HashMap;
10use std::sync::Arc;
11#[cfg(feature = "unified_transport")]
12use std::time::Duration;
13
#[cfg(not(feature = "unified_sse"))]
#[allow(dead_code)]
/// Locate the end of the first complete SSE event in `buffer`.
///
/// Events are terminated by a blank line: either `\n\n` or `\r\n\r\n`.
/// Returns the index one past the terminator (i.e. the start of the next
/// event), or `None` if no complete event is buffered yet.
fn find_event_boundary(buffer: &[u8]) -> Option<usize> {
    for start in 0..buffer.len() {
        let rest = &buffer[start..];
        if rest.starts_with(b"\n\n") {
            return Some(start + 2);
        }
        if rest.starts_with(b"\r\n\r\n") {
            return Some(start + 4);
        }
    }
    None
}
34
35#[cfg(not(feature = "unified_sse"))]
36#[allow(dead_code)]
37fn parse_sse_event(event_text: &str) -> Option<Result<Option<ChatCompletionChunk>, AiLibError>> {
38    for line in event_text.lines() {
39        let line = line.trim();
40        if let Some(stripped) = line.strip_prefix("data: ") {
41            let data = stripped;
42            if data == "[DONE]" {
43                return Some(Ok(None));
44            }
45            return Some(parse_chunk_data(data));
46        }
47    }
48    None
49}
50
51#[cfg(not(feature = "unified_sse"))]
52fn parse_chunk_data(data: &str) -> Result<Option<ChatCompletionChunk>, AiLibError> {
53    match serde_json::from_str::<serde_json::Value>(data) {
54        Ok(json) => {
55            let choices = json["choices"]
56                .as_array()
57                .map(|arr| {
58                    arr.iter()
59                        .enumerate()
60                        .map(|(index, choice)| {
61                            let delta = &choice["delta"];
62                            crate::api::ChoiceDelta {
63                                index: index as u32,
64                                delta: crate::api::MessageDelta {
65                                    role: delta["role"].as_str().map(|r| match r {
66                                        "assistant" => crate::types::Role::Assistant,
67                                        "user" => crate::types::Role::User,
68                                        "system" => crate::types::Role::System,
69                                        _ => crate::types::Role::Assistant,
70                                    }),
71                                    content: delta["content"].as_str().map(str::to_string),
72                                },
73                                finish_reason: choice["finish_reason"].as_str().map(str::to_string),
74                            }
75                        })
76                        .collect()
77                })
78                .unwrap_or_default();
79
80            Ok(Some(ChatCompletionChunk {
81                id: json["id"].as_str().unwrap_or_default().to_string(),
82                object: json["object"]
83                    .as_str()
84                    .unwrap_or("chat.completion.chunk")
85                    .to_string(),
86                created: json["created"].as_u64().unwrap_or(0),
87                model: json["model"].as_str().unwrap_or_default().to_string(),
88                choices,
89            }))
90        }
91        Err(e) => Err(AiLibError::ProviderError(format!(
92            "JSON parse error: {}",
93            e
94        ))),
95    }
96}
97
/// Split `text` into chunks of at most `max_len` bytes without ever cutting a
/// UTF-8 code point in half.
///
/// When a chunk boundary would fall inside the tail of the text, the cut is
/// preferentially moved back to the last ASCII space in the chunk, and that
/// space is skipped so it is not duplicated at the start of the next chunk.
/// A single code point wider than `max_len` is emitted as its own oversized
/// chunk rather than being split.
///
/// Fix: `max_len == 0` previously caused an infinite loop that pushed empty
/// strings forever (`end == i` left the cursor stuck); it is now clamped to 1.
fn split_text_into_chunks(text: &str, max_len: usize) -> Vec<String> {
    // Guard against the zero-width request that made the old loop never advance.
    let max_len = max_len.max(1);
    let len = text.len();
    let mut chunks = Vec::new();
    let mut i = 0usize;
    while i < len {
        // Tentative end, then retreat to the nearest char boundary at or before it.
        let mut end = (i + max_len).min(len);
        while end > i && !text.is_char_boundary(end) {
            end -= 1;
        }
        if end == i {
            // The code point at `i` is wider than max_len: advance forward to
            // the next boundary so the loop always makes progress.
            end = (i + 1).min(len);
            while end < len && !text.is_char_boundary(end) {
                end += 1;
            }
        }

        // Prefer to cut at the last space inside the chunk (only when more
        // text follows; the final chunk is taken whole).
        let mut cut = end;
        if end < len {
            if let Some(pos) = text[i..end].rfind(' ') {
                cut = i + pos;
            }
        }
        if cut == i {
            cut = end;
        }

        // `i` and `cut` are char boundaries, so slicing cannot panic.
        chunks.push(text[i..cut].to_string());
        i = cut;
        // Swallow a single separator space so the next chunk doesn't start with it.
        if text.as_bytes().get(i) == Some(&b' ') {
            i += 1;
        }
    }
    chunks
}
149
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn split_text_handles_multibyte_without_panic() {
        // Mixed CJK + emoji text with NO ASCII spaces: no bytes can be dropped
        // at whitespace cut points, so the chunks must reassemble exactly.
        let s = "这是一个包含中文字符的长文本,用于测试边界处理。🚀✨";
        // Use a small max_len to force potential mid-codepoint boundaries
        let chunks = split_text_into_chunks(s, 8);
        assert!(!chunks.is_empty());
        for c in &chunks {
            // Chunks are &str slices, so UTF-8 validity is guaranteed; check
            // they are non-empty and never exceed max_len (no UTF-8 code point
            // is wider than 4 bytes, so the oversize fallback can't trigger).
            assert!(!c.is_empty());
            assert!(c.len() <= 8);
        }
        // The previous assertion (`contains || contains || !is_empty`) was
        // vacuously true for any non-empty output; require exact round-trip.
        let combined = chunks.concat();
        assert_eq!(combined, s);
    }
}
172use futures::StreamExt;
173use tokio::sync::mpsc;
174use tokio_stream::wrappers::UnboundedReceiverStream;
175
/// Adapter for the Cohere HTTP API (v1/generate, v1/chat, v1/models endpoints).
pub struct CohereAdapter {
    #[allow(dead_code)] // Injectable for tests; used by the ChatApi impl via post_json/post_stream/get_json.
    transport: DynHttpTransportRef,
    // Bearer token sent in the Authorization header of every request.
    api_key: String,
    // API root, e.g. "https://api.cohere.ai"; overridable via COHERE_BASE_URL.
    base_url: String,
    // Metrics sink; convenience constructors default this to NoopMetrics.
    metrics: Arc<dyn Metrics>,
}
183
impl CohereAdapter {
    /// HTTP timeout in seconds from `AI_HTTP_TIMEOUT_SECS`; defaults to 30
    /// when the variable is unset or unparsable.
    #[allow(dead_code)]
    fn build_default_timeout_secs() -> u64 {
        std::env::var("AI_HTTP_TIMEOUT_SECS")
            .ok()
            .and_then(|s| s.parse::<u64>().ok())
            .unwrap_or(30)
    }

    /// Build the default transport. With the `unified_transport` feature this
    /// shares a pooled reqwest client with the configured timeout; otherwise a
    /// plain `HttpTransport` is used.
    fn build_default_transport() -> Result<DynHttpTransportRef, AiLibError> {
        #[cfg(feature = "unified_transport")]
        {
            let timeout = Duration::from_secs(Self::build_default_timeout_secs());
            let client = crate::transport::client_factory::build_shared_client()
                .map_err(|e| AiLibError::NetworkError(format!("Failed to build http client: {}", e)))?;
            let t = HttpTransport::with_reqwest_client(client, timeout);
            return Ok(t.boxed());
        }
        #[cfg(not(feature = "unified_transport"))]
        {
            let t = HttpTransport::new();
            return Ok(t.boxed());
        }
    }

    /// Create Cohere adapter. Requires COHERE_API_KEY env var.
    ///
    /// # Errors
    /// Returns `AuthenticationError` when `COHERE_API_KEY` is missing.
    pub fn new() -> Result<Self, AiLibError> {
        let api_key = std::env::var("COHERE_API_KEY").map_err(|_| {
            AiLibError::AuthenticationError(
                "COHERE_API_KEY environment variable not set".to_string(),
            )
        })?;
        // Base URL may be overridden for proxies/self-hosted gateways.
        let base_url = std::env::var("COHERE_BASE_URL")
            .unwrap_or_else(|_| "https://api.cohere.ai".to_string());
        Ok(Self {
            transport: Self::build_default_transport()?,
            api_key,
            base_url,
            metrics: Arc::new(NoopMetrics::new()),
        })
    }

    /// Explicit overrides for api_key and optional base_url (takes precedence over env vars)
    pub fn new_with_overrides(
        api_key: String,
        base_url: Option<String>,
    ) -> Result<Self, AiLibError> {
        // Explicit argument wins; then COHERE_BASE_URL; then the public default.
        let resolved_base = base_url.unwrap_or_else(|| {
            std::env::var("COHERE_BASE_URL").unwrap_or_else(|_| "https://api.cohere.ai".to_string())
        });
        Ok(Self {
            transport: Self::build_default_transport()?,
            api_key,
            base_url: resolved_base,
            metrics: Arc::new(NoopMetrics::new()),
        })
    }

    /// Create adapter with injectable transport (for testing)
    pub fn with_transport(transport: HttpTransport, api_key: String, base_url: String) -> Self {
        Self {
            transport: transport.boxed(),
            api_key,
            base_url,
            metrics: Arc::new(NoopMetrics::new()),
        }
    }

    /// Construct using object-safe transport reference
    pub fn with_transport_ref(
        transport: DynHttpTransportRef,
        api_key: String,
        base_url: String,
    ) -> Self {
        Self {
            transport,
            api_key,
            base_url,
            metrics: Arc::new(NoopMetrics::new()),
        }
    }

    /// Construct with an injected transport reference and metrics implementation
    pub fn with_transport_ref_and_metrics(
        transport: DynHttpTransportRef,
        api_key: String,
        base_url: String,
        metrics: Arc<dyn Metrics>,
    ) -> Self {
        Self {
            transport,
            api_key,
            base_url,
            metrics,
        }
    }

    /// Serialize `request` into Cohere v2/chat JSON: a "messages" array of
    /// role/content objects plus optional "temperature" and "max_tokens".
    fn convert_request(&self, request: &ChatCompletionRequest) -> serde_json::Value {
        // Build v2/chat format according to Cohere documentation
        let msgs: Vec<serde_json::Value> = request
            .messages
            .iter()
            .map(|msg| {
                serde_json::json!({
                    "role": match msg.role {
                        Role::System => "system",
                        Role::User => "user",
                        Role::Assistant => "assistant",
                    },
                    "content": msg.content.as_text()
                })
            })
            .collect();

        let mut chat_body = serde_json::json!({
            "model": request.model,
            "messages": msgs,
        });

        // Add optional parameters
        if let Some(temp) = request.temperature {
            // NOTE(review): Number::from_f64 returns None for NaN/infinity, so
            // a non-finite temperature would panic here — confirm upstream
            // validation guarantees a finite value.
            chat_body["temperature"] =
                serde_json::Value::Number(serde_json::Number::from_f64(temp.into()).unwrap());
        }
        if let Some(max_tokens) = request.max_tokens {
            chat_body["max_tokens"] =
                serde_json::Value::Number(serde_json::Number::from(max_tokens));
        }

        chat_body
    }

    /// Normalize a provider response (OpenAI-style "choices", Cohere v2/chat
    /// "message", or Cohere v1 "generations") into a `ChatCompletionResponse`
    /// with a single Assistant choice. Token usage is always zeroed and marked
    /// `UsageStatus::Unsupported` because these response shapes carry no usage
    /// data that this parser reads.
    fn parse_response(
        &self,
        response: serde_json::Value,
    ) -> Result<ChatCompletionResponse, AiLibError> {
        // Try different response formats: OpenAI-like choices, Cohere v2/chat message, or Cohere v1 generations
        let content = if let Some(c) = response.get("choices") {
            // OpenAI format: choices[0].message.content, falling back to choices[0].text
            c[0]["message"]["content"]
                .as_str()
                .map(|s| s.to_string())
                .or_else(|| c[0]["text"].as_str().map(|s| s.to_string()))
        } else if let Some(msg) = response.get("message") {
            // Cohere v2/chat format: message.content[0].text
            msg.get("content").and_then(|content_array| {
                content_array
                    .as_array()
                    .and_then(|arr| arr.first())
                    .and_then(|content_obj| {
                        content_obj
                            .get("text")
                            .and_then(|t| t.as_str())
                            .map(|text| text.to_string())
                    })
            })
        } else if let Some(gens) = response.get("generations") {
            // Cohere v1 format: generations[0].text
            gens[0]["text"].as_str().map(|s| s.to_string())
        } else {
            None
        };

        // Unknown shapes degrade to an empty message rather than an error.
        let content = content.unwrap_or_default();

        // Extract a tool/function call if present, trying in order:
        // 1) a directly deserializable "function_call" object,
        // 2) a loose {name, arguments} pair (arguments may be a JSON string),
        // 3) OpenAI-style "tool_calls" array (first entry only).
        let mut function_call: Option<crate::types::function_call::FunctionCall> = None;
        if let Some(fc_val) = response.get("function_call") {
            if let Ok(fc) =
                serde_json::from_value::<crate::types::function_call::FunctionCall>(fc_val.clone())
            {
                function_call = Some(fc);
            } else if let Some(name) = fc_val
                .get("name")
                .and_then(|v| v.as_str())
                .map(|s| s.to_string())
            {
                // Arguments may arrive as a JSON-encoded string; decode it.
                let args = fc_val.get("arguments").and_then(|a| {
                    if a.is_string() {
                        serde_json::from_str::<serde_json::Value>(a.as_str().unwrap()).ok()
                    } else {
                        Some(a.clone())
                    }
                });
                function_call = Some(crate::types::function_call::FunctionCall {
                    name,
                    arguments: args,
                });
            }
        } else if let Some(tool_calls) = response.get("tool_calls").and_then(|v| v.as_array()) {
            if let Some(first) = tool_calls.first() {
                if let Some(func) = first.get("function").or_else(|| first.get("function_call")) {
                    if let Some(name) = func.get("name").and_then(|v| v.as_str()) {
                        let mut args_opt = func.get("arguments").cloned();
                        if let Some(args_val) = &args_opt {
                            if args_val.is_string() {
                                if let Some(s) = args_val.as_str() {
                                    if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(s) {
                                        args_opt = Some(parsed);
                                    }
                                }
                            }
                        }
                        function_call = Some(crate::types::function_call::FunctionCall {
                            name: name.to_string(),
                            arguments: args_opt,
                        });
                    }
                }
            }
        }

        let choice = Choice {
            index: 0,
            message: Message {
                role: Role::Assistant,
                content: crate::types::common::Content::Text(content.clone()),
                function_call,
            },
            finish_reason: None,
        };

        // No usage information is parsed from any of the supported shapes.
        let usage = Usage {
            prompt_tokens: 0,
            completion_tokens: 0,
            total_tokens: 0,
        };

        Ok(ChatCompletionResponse {
            id: response["id"].as_str().unwrap_or_default().to_string(),
            object: response["object"].as_str().unwrap_or_default().to_string(),
            created: response["created"].as_u64().unwrap_or(0),
            model: response["model"].as_str().unwrap_or_default().to_string(),
            choices: vec![choice],
            usage,
            usage_status: UsageStatus::Unsupported, // Cohere doesn't provide usage data in this format
        })
    }
}
422
#[async_trait::async_trait]
impl ChatApi for CohereAdapter {
    /// Non-streaming completion via Cohere's `v1/generate` endpoint.
    ///
    /// Messages are flattened into a single "System:/Human:/Assistant:" prompt
    /// string, and the JSON response is normalized by `parse_response`.
    async fn chat_completion(
        &self,
        request: ChatCompletionRequest,
    ) -> Result<ChatCompletionResponse, AiLibError> {
        self.metrics.incr_counter("cohere.requests", 1).await;
        let timer = self.metrics.start_timer("cohere.request_duration_ms").await;

        // NOTE(review): this v2/chat body is built and immediately discarded;
        // the request is re-encoded below in v1/generate form. Looks like dead
        // work that could be removed — confirm nothing relies on side effects.
        let _body = self.convert_request(&request);

        // Use v1/generate endpoint (fallback for older API keys)
        let url_generate = format!("{}/v1/generate", self.base_url);

        let mut headers = HashMap::new();
        headers.insert(
            "Authorization".to_string(),
            format!("Bearer {}", self.api_key),
        );
        headers.insert("Content-Type".to_string(), "application/json".to_string());
        headers.insert("Accept".to_string(), "application/json".to_string());

        // Convert messages to prompt string for v1/generate endpoint
        let prompt = request
            .messages
            .iter()
            .map(|msg| match msg.role {
                Role::System => format!("System: {}", msg.content.as_text()),
                Role::User => format!("Human: {}", msg.content.as_text()),
                Role::Assistant => format!("Assistant: {}", msg.content.as_text()),
            })
            .collect::<Vec<_>>()
            .join("\n");

        let mut generate_body = serde_json::json!({
            "model": request.model,
            "prompt": prompt,
        });

        // Optional sampling parameters mirror convert_request.
        if let Some(temp) = request.temperature {
            // NOTE(review): unwrap panics on non-finite temperature (see convert_request).
            generate_body["temperature"] =
                serde_json::Value::Number(serde_json::Number::from_f64(temp.into()).unwrap());
        }
        if let Some(max_tokens) = request.max_tokens {
            generate_body["max_tokens"] =
                serde_json::Value::Number(serde_json::Number::from(max_tokens));
        }

        // Use unified transport
        let response_json = self
            .transport
            .post_json(&url_generate, Some(headers), generate_body)
            .await?;

        // Stop the latency timer before parsing (parse cost excluded).
        if let Some(t) = timer {
            t.stop();
        }

        self.parse_response(response_json)
    }

    /// Streaming completion. Tries SSE streaming against `v1/chat` first; if
    /// the transport cannot stream, falls back to one blocking completion whose
    /// text is re-emitted as simulated chunks.
    async fn chat_completion_stream(
        &self,
        _request: ChatCompletionRequest,
    ) -> Result<
        Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
        AiLibError,
    > {
        // Build stream request similar to chat_completion but with stream=true
        let mut stream_request = self.convert_request(&_request);
        stream_request["stream"] = serde_json::Value::Bool(true);

        let url = format!("{}/v1/chat", self.base_url);

        let mut headers = HashMap::new();
        headers.insert(
            "Authorization".to_string(),
            format!("Bearer {}", self.api_key),
        );
        // NOTE(review): unlike chat_completion, no Content-Type/Accept headers
        // are set here — verify the transport adds them, or the server may
        // reject the JSON body.
        // Try unified transport streaming first
        if let Ok(byte_stream) = self
            .transport
            .post_stream(&url, Some(headers.clone()), stream_request.clone())
            .await
        {
            // Bridge the byte stream to a chunk stream through an unbounded
            // channel; the spawned task owns SSE framing and parsing.
            let (tx, rx) = mpsc::unbounded_channel();
            tokio::spawn(async move {
                // Accumulates raw bytes until a full SSE event boundary appears.
                let mut buffer = Vec::new();
                futures::pin_mut!(byte_stream);
                while let Some(item) = byte_stream.next().await {
                    match item {
                        Ok(bytes) => {
                            buffer.extend_from_slice(&bytes);
                            // Two identical drain loops, selected at compile time:
                            // the shared crate::sse parser vs. the local fallback.
                            #[cfg(feature = "unified_sse")]
                            {
                                while let Some(boundary) =
                                    crate::sse::parser::find_event_boundary(&buffer)
                                {
                                    // Remove one complete event from the front of the buffer.
                                    let event_bytes = buffer.drain(..boundary).collect::<Vec<_>>();
                                    if let Ok(event_text) = std::str::from_utf8(&event_bytes) {
                                        if let Some(parsed) =
                                            crate::sse::parser::parse_sse_event(event_text)
                                        {
                                            match parsed {
                                                Ok(Some(chunk)) => {
                                                    // Receiver dropped: stop streaming.
                                                    if tx.send(Ok(chunk)).is_err() {
                                                        return;
                                                    }
                                                }
                                                // [DONE] sentinel: end of stream.
                                                Ok(None) => return,
                                                Err(e) => {
                                                    let _ = tx.send(Err(e));
                                                    return;
                                                }
                                            }
                                        }
                                    }
                                }
                            }
                            #[cfg(not(feature = "unified_sse"))]
                            {
                                while let Some(boundary) = find_event_boundary(&buffer) {
                                    let event_bytes = buffer.drain(..boundary).collect::<Vec<_>>();
                                    if let Ok(event_text) = std::str::from_utf8(&event_bytes) {
                                        if let Some(parsed) = parse_sse_event(event_text) {
                                            match parsed {
                                                Ok(Some(chunk)) => {
                                                    if tx.send(Ok(chunk)).is_err() {
                                                        return;
                                                    }
                                                }
                                                Ok(None) => return,
                                                Err(e) => {
                                                    let _ = tx.send(Err(e));
                                                    return;
                                                }
                                            }
                                        }
                                    }
                                }
                            }
                        }
                        Err(e) => {
                            // Surface the transport error to the consumer, then stop.
                            let _ = tx.send(Err(AiLibError::ProviderError(format!(
                                "Stream error: {}",
                                e
                            ))));
                            break;
                        }
                    }
                }
            });
            let stream = UnboundedReceiverStream::new(rx);
            return Ok(Box::new(Box::pin(stream)));
        }

        // Fallback: call non-streaming chat_completion and stream the result in chunks
        let finished = self.chat_completion(_request.clone()).await?;
        let text = finished
            .choices
            .first()
            .map(|c| c.message.content.as_text())
            .unwrap_or_default();

        let (tx, rx) = mpsc::unbounded_channel();

        tokio::spawn(async move {
            // Emit ~80-byte chunks with a 50ms cadence to imitate live streaming.
            let chunks = split_text_into_chunks(&text, 80);
            for chunk in chunks {
                // construct ChatCompletionChunk with single delta
                let delta = crate::api::ChoiceDelta {
                    index: 0,
                    delta: crate::api::MessageDelta {
                        role: Some(crate::types::Role::Assistant),
                        content: Some(chunk.clone()),
                    },
                    finish_reason: None,
                };
                let chunk_obj = ChatCompletionChunk {
                    id: "simulated".to_string(),
                    object: "chat.completion.chunk".to_string(),
                    created: 0,
                    model: finished.model.clone(),
                    choices: vec![delta],
                };

                if tx.send(Ok(chunk_obj)).is_err() {
                    return;
                }
                tokio::time::sleep(std::time::Duration::from_millis(50)).await;
            }
        });

        let stream = UnboundedReceiverStream::new(rx);
        Ok(Box::new(Box::pin(stream)))
    }

    /// List model identifiers from `v1/models`, accepting either "id" or
    /// "name" per entry; entries with neither are skipped.
    async fn list_models(&self) -> Result<Vec<String>, AiLibError> {
        // Use v1/models endpoint for listing models via unified transport
        let url = format!("{}/v1/models", self.base_url);
        let mut headers = HashMap::new();
        headers.insert(
            "Authorization".to_string(),
            format!("Bearer {}", self.api_key),
        );

        let response = self.transport.get_json(&url, Some(headers)).await?;

        // Missing/non-array "models" yields an empty list rather than an error.
        Ok(response["models"]
            .as_array()
            .unwrap_or(&vec![])
            .iter()
            .filter_map(|m| {
                m["id"]
                    .as_str()
                    .map(|s| s.to_string())
                    .or_else(|| m["name"].as_str().map(|s| s.to_string()))
            })
            .collect())
    }

    /// Return a static, locally constructed `ModelInfo` placeholder — no
    /// network call is made and `model_id` is echoed back unchecked.
    async fn get_model_info(&self, model_id: &str) -> Result<crate::api::ModelInfo, AiLibError> {
        Ok(ModelInfo {
            id: model_id.to_string(),
            object: "model".to_string(),
            created: 0,
            owned_by: "cohere".to_string(),
            permission: vec![ModelPermission {
                id: "default".to_string(),
                object: "model_permission".to_string(),
                created: 0,
                allow_create_engine: false,
                allow_sampling: true,
                allow_logprobs: false,
                allow_search_indices: false,
                allow_view: true,
                allow_fine_tuning: false,
                organization: "*".to_string(),
                group: None,
                is_blocking: false,
            }],
        })
    }
}