llm/backends/
cohere.rs

1//! Cohere API client implementation for chat and completion functionality.
2//!
3//! This module provides integration with Cohere's LLM models through their Compatibility API.
4use std::time::Duration;
5
6#[cfg(feature = "cohere")]
7use crate::{
8    chat::Tool,
9    chat::{ChatMessage, ChatProvider, ChatRole, MessageType, StructuredOutputFormat},
10    completion::{CompletionProvider, CompletionRequest, CompletionResponse},
11    embedding::EmbeddingProvider,
12    error::LLMError,
13    models::{ModelsProvider},
14    stt::SpeechToTextProvider,
15    tts::TextToSpeechProvider,
16    LLMProvider,
17};
18#[cfg(feature = "cohere")]
19use crate::{
20    chat::{ChatResponse, ToolChoice},
21    ToolCall,
22};
23use async_trait::async_trait;
24use either::*;
25use futures::stream::Stream;
26use reqwest::{Client, Url};
27use serde::{Deserialize, Serialize};
28
/// Client for interacting with Cohere's API (OpenAI compatibility mode).
///
/// Provides methods for chat and embedding requests using Cohere's models.
/// **Note:** Cohere expects system instructions to use the `developer` role
/// instead of `system` (see how the system prompt is injected in `ChatProvider`).
pub struct Cohere {
    /// API key sent as a bearer token on every request.
    pub api_key: String,
    /// Base URL of the compatibility API; endpoint paths are joined onto it.
    pub base_url: Url,
    /// Model identifier (e.g. "command-light").
    pub model: String,
    /// Maximum number of tokens to generate, if limited.
    pub max_tokens: Option<u32>,
    /// Sampling temperature.
    pub temperature: Option<f32>,
    /// System prompt, sent as a leading `developer`-role message.
    pub system: Option<String>,
    /// Request timeout in seconds (applied to the client and per request).
    pub timeout_seconds: Option<u64>,
    /// Whether non-streaming chat requests set `stream` in the payload.
    pub stream: Option<bool>,
    /// Top-p (nucleus) sampling parameter.
    pub top_p: Option<f32>,
    /// Top-k sampling parameter.
    pub top_k: Option<u32>,
    /// Tools available to the model, if any.
    pub tools: Option<Vec<Tool>>,
    /// How the model is allowed to use tools.
    pub tool_choice: Option<ToolChoice>,
    /// Embedding parameters
    pub embedding_encoding_format: Option<String>,
    pub embedding_dimensions: Option<u32>,
    /// Reasoning effort level forwarded in the request payload.
    pub reasoning_effort: Option<String>,
    /// JSON schema for structured output
    pub json_schema: Option<StructuredOutputFormat>,
    client: Client,
}
54
/// Individual message in a Cohere chat conversation.
#[derive(Serialize, Debug)]
struct CohereChatMessage<'a> {
    /// Message role: "user", "assistant", "developer", or "tool".
    #[allow(dead_code)]
    role: &'a str,
    /// Body: either a list of typed content parts or a plain string
    /// (serialized untagged, so either shape is emitted directly).
    #[serde(
        skip_serializing_if = "Option::is_none",
        with = "either::serde_untagged_optional"
    )]
    content: Option<Either<Vec<CohereMessageContent<'a>>, String>>,
    /// Tool invocations attached to an assistant message.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_calls: Option<Vec<CohereFunctionCall<'a>>>,
    /// Id of the tool call this message responds to (role "tool" only).
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_call_id: Option<String>,
}
70
/// Function name and argument payload inside a serialized tool call.
#[derive(Serialize, Debug)]
struct CohereFunctionPayload<'a> {
    /// Name of the function being invoked.
    name: &'a str,
    /// Raw argument string, passed through verbatim.
    arguments: &'a str,
}
76
/// Serialized tool/function call attached to an assistant message.
#[derive(Serialize, Debug)]
struct CohereFunctionCall<'a> {
    /// Identifier of this tool call.
    id: &'a str,
    /// Call kind, serialized as "type"; always "function" in this client.
    #[serde(rename = "type")]
    content_type: &'a str,
    /// The function being called and its arguments.
    function: CohereFunctionPayload<'a>,
}
84
/// One typed content part of a Cohere chat message.
#[derive(Serialize, Debug)]
struct CohereMessageContent<'a> {
    /// Part kind, serialized as "type" (e.g. "text" or "image_url").
    #[serde(rename = "type", skip_serializing_if = "Option::is_none")]
    message_type: Option<&'a str>,
    /// Text body for "text" parts.
    #[serde(skip_serializing_if = "Option::is_none")]
    text: Option<&'a str>,
    /// Image reference for "image_url" parts.
    #[serde(skip_serializing_if = "Option::is_none")]
    image_url: Option<ImageUrlContent<'a>>,
    /// Id of the tool call this part answers.
    #[serde(skip_serializing_if = "Option::is_none", rename = "tool_call_id")]
    tool_call_id: Option<&'a str>,
    /// Tool output payload, serialized under the key "content".
    #[serde(skip_serializing_if = "Option::is_none", rename = "content")]
    tool_output: Option<&'a str>,
}
98
/// Individual image message (URL) in a Cohere chat conversation.
#[derive(Serialize, Debug)]
struct ImageUrlContent<'a> {
    /// URL of the referenced image.
    url: &'a str,
}
104
/// Request payload for Cohere's embeddings endpoint.
#[derive(Serialize)]
struct CohereEmbeddingRequest {
    /// Model used to produce the embeddings.
    model: String,
    /// Texts to embed.
    input: Vec<String>,
    /// Output encoding, e.g. "float" or "base64".
    #[serde(skip_serializing_if = "Option::is_none")]
    encoding_format: Option<String>,
    /// Requested embedding dimensionality, when configured.
    #[serde(skip_serializing_if = "Option::is_none")]
    dimensions: Option<u32>,
}
114
/// Request payload for Cohere's chat API endpoint.
#[derive(Serialize, Debug)]
struct CohereChatRequest<'a> {
    /// Model identifier.
    model: &'a str,
    /// Conversation history, oldest first (optionally led by a
    /// `developer`-role system message).
    messages: Vec<CohereChatMessage<'a>>,
    /// Maximum tokens to generate, if limited.
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
    /// Sampling temperature.
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    /// Whether the response should be streamed as SSE.
    stream: bool,
    /// Top-p (nucleus) sampling parameter.
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f32>,
    /// Top-k sampling parameter.
    #[serde(skip_serializing_if = "Option::is_none")]
    top_k: Option<u32>,
    /// Tools the model may call.
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<Tool>>,
    /// Tool usage policy; only sent when tools are present.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<ToolChoice>,
    /// Reasoning effort level, forwarded verbatim.
    #[serde(skip_serializing_if = "Option::is_none")]
    reasoning_effort: Option<String>,
    /// Structured-output (JSON schema) configuration.
    #[serde(skip_serializing_if = "Option::is_none")]
    response_format: Option<CohereResponseFormat>,
}
138
/// Response from Cohere's chat API endpoint.
#[derive(Deserialize, Debug)]
struct CohereChatResponse {
    /// Candidate completions; only the first is consumed by this client.
    choices: Vec<CohereChatChoice>,
}
144
/// Individual choice within a Cohere chat API response.
#[derive(Deserialize, Debug)]
struct CohereChatChoice {
    /// The message generated for this choice.
    message: CohereChatMsg,
}
150
/// Message content within a Cohere chat API response.
#[derive(Deserialize, Debug)]
struct CohereChatMsg {
    /// Role of the responding message (deserialized but currently unused).
    #[allow(dead_code)]
    role: String,
    /// Generated text, if the model produced any.
    content: Option<String>,
    /// Tool calls requested by the model, if any.
    tool_calls: Option<Vec<ToolCall>>,
}
159
/// One embedding entry within Cohere's embeddings response.
#[derive(Deserialize, Debug)]
struct CohereEmbeddingData {
    /// A single embedding vector.
    embedding: Vec<f32>,
}
/// Response from Cohere's embedding API endpoint.
#[derive(Deserialize, Debug)]
struct CohereEmbeddingResponse {
    /// Embedding entries, one per requested input.
    data: Vec<CohereEmbeddingData>,
}
169
/// Output format type for structured responses in Cohere.
#[derive(Deserialize, Debug, Serialize)]
enum CohereResponseType {
    /// Plain text output.
    #[serde(rename = "text")]
    Text,
    /// JSON output constrained by an explicit schema.
    #[serde(rename = "json_schema")]
    JsonSchema,
    /// Free-form JSON object output.
    #[serde(rename = "json_object")]
    JsonObject,
}
180
/// Configuration for forcing the model output format (e.g., JSON schema).
#[derive(Deserialize, Debug, Serialize)]
struct CohereResponseFormat {
    /// Which output mode to enforce, serialized as "type".
    #[serde(rename = "type")]
    response_type: CohereResponseType,
    /// Schema details; only meaningful for the `json_schema` mode.
    #[serde(skip_serializing_if = "Option::is_none")]
    json_schema: Option<StructuredOutputFormat>,
}
189
190impl From<StructuredOutputFormat> for CohereResponseFormat {
191    fn from(structured_response_format: StructuredOutputFormat) -> Self {
192        match structured_response_format.schema {
193            None => CohereResponseFormat {
194                response_type: CohereResponseType::JsonSchema,
195                json_schema: Some(structured_response_format),
196            },
197            Some(mut schema) => {
198                // Ensure "additionalProperties": false in schema if missing
199                if schema.get("additionalProperties").is_none() {
200                    schema["additionalProperties"] = serde_json::json!(false);
201                }
202                CohereResponseFormat {
203                    response_type: CohereResponseType::JsonSchema,
204                    json_schema: Some(StructuredOutputFormat {
205                        name: structured_response_format.name,
206                        description: structured_response_format.description,
207                        schema: Some(schema),
208                        strict: structured_response_format.strict,
209                    }),
210                }
211            }
212        }
213    }
214}
215
216impl ChatResponse for CohereChatResponse {
217    fn text(&self) -> Option<String> {
218        self.choices.first().and_then(|c| c.message.content.clone())
219    }
220    fn tool_calls(&self) -> Option<Vec<ToolCall>> {
221        self.choices.first().and_then(|c| c.message.tool_calls.clone())
222    }
223}
224
225impl std::fmt::Display for CohereChatResponse {
226    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
227        match (
228            &self.choices.first().unwrap().message.content,
229            &self.choices.first().unwrap().message.tool_calls,
230        ) {
231            (Some(content), Some(tool_calls)) => {
232                for tool_call in tool_calls {
233                    write!(f, "{}", tool_call)?;
234                }
235                write!(f, "{}", content)
236            }
237            (Some(content), None) => write!(f, "{}", content),
238            (None, Some(tool_calls)) => {
239                for tool_call in tool_calls {
240                    write!(f, "{}", tool_call)?;
241                }
242                Ok(())
243            }
244            (None, None) => write!(f, ""),
245        }
246    }
247}
248
249impl Cohere {
250    /// Creates a new Cohere client with the specified configuration.
251    ///
252    /// # Arguments
253    ///
254    /// * `api_key` - Cohere API key  
255    /// * `base_url` - Base URL for Cohere API (defaults to Cohere compatibility API endpoint)  
256    /// * `model` - Model to use (e.g., "command-xlarge")  
257    /// * `max_tokens` - Maximum tokens to generate  
258    /// * `temperature` - Sampling temperature  
259    /// * `timeout_seconds` - Request timeout in seconds  
260    /// * `system` - System prompt (sent as a developer role message)  
261    /// * `stream` - Whether to stream responses  
262    /// * `top_p` - Top-p sampling parameter  
263    /// * `top_k` - Top-k sampling parameter  
264    /// * `embedding_encoding_format` - Format for embedding outputs (`float` or `base64`)  
265    /// * `embedding_dimensions` - (Unused by Cohere) Dimensions for embedding vectors  
266    /// * `tools` - Function tools available to the model  
267    /// * `tool_choice` - Determines how the model uses tools  
268    /// * `reasoning_effort` - Reasoning effort level (unsupported by Cohere)  
269    /// * `json_schema` - JSON schema for structured output
270    #[allow(clippy::too_many_arguments)]
271    pub fn new(
272        api_key: impl Into<String>,
273        base_url: Option<String>,
274        model: Option<String>,
275        max_tokens: Option<u32>,
276        temperature: Option<f32>,
277        timeout_seconds: Option<u64>,
278        system: Option<String>,
279        stream: Option<bool>,
280        top_p: Option<f32>,
281        top_k: Option<u32>,
282        embedding_encoding_format: Option<String>,
283        embedding_dimensions: Option<u32>,
284        tools: Option<Vec<Tool>>,
285        tool_choice: Option<ToolChoice>,
286        reasoning_effort: Option<String>,
287        json_schema: Option<StructuredOutputFormat>,
288    ) -> Self {
289        let mut builder = Client::builder();
290        if let Some(sec) = timeout_seconds {
291            builder = builder.timeout(Duration::from_secs(sec));
292        }
293        Self {
294            api_key: api_key.into(),
295            base_url: Url::parse(
296                &base_url.unwrap_or_else(|| "https://api.cohere.ai/compatibility/v1/".to_owned()),
297            )
298            .expect("Failed to parse base Url"),
299            model: model.unwrap_or("command-light".to_string()),
300            max_tokens,
301            temperature,
302            system,
303            timeout_seconds,
304            stream,
305            top_p,
306            top_k,
307            tools,
308            tool_choice,
309            embedding_encoding_format,
310            embedding_dimensions,
311            reasoning_effort,
312            json_schema,
313            client: builder.build().expect("Failed to build reqwest Client"),
314        }
315    }
316}
317
#[async_trait]
impl ChatProvider for Cohere {
    /// Sends a chat request to Cohere's API (optionally with tool usage).
    ///
    /// Per-call `tools` take precedence over the client-level `self.tools`;
    /// the configured `tool_choice` is only sent when some tools are present.
    ///
    /// # Errors
    /// Returns `LLMError::AuthError` when no API key is configured, an HTTP
    /// error when the request fails, or `LLMError::ResponseFormatError` for
    /// non-2xx statuses and undecodable bodies.
    async fn chat_with_tools(
        &self,
        messages: &[ChatMessage],
        tools: Option<&[Tool]>,
    ) -> Result<Box<dyn ChatResponse>, LLMError> {
        if self.api_key.is_empty() {
            return Err(LLMError::AuthError("Missing Cohere API key".to_string()));
        }
        // Clone messages to own them
        let messages = messages.to_vec();
        let mut cohere_msgs: Vec<CohereChatMessage> = vec![];

        for msg in messages {
            if let MessageType::ToolResult(ref results) = msg.message_type {
                // Include tool result as a message with role "tool"
                // (one message per result, keyed by the originating call id).
                for result in results {
                    cohere_msgs.push(CohereChatMessage {
                        role: "tool",
                        tool_call_id: Some(result.id.clone()),
                        tool_calls: None,
                        // NOTE(review): this forwards the call's *arguments*
                        // string as the tool message body — confirm this is
                        // the intended tool-result payload.
                        content: Some(Right(result.function.arguments.clone())),
                    });
                }
            } else {
                cohere_msgs.push(chat_message_to_api_message(msg));
            }
        }

        // Prepend system prompt as a "developer" role message if provided
        if let Some(system) = &self.system {
            cohere_msgs.insert(
                0,
                CohereChatMessage {
                    role: "developer",
                    content: Some(Left(vec![CohereMessageContent {
                        message_type: Some("text"),
                        text: Some(system),
                        image_url: None,
                        tool_call_id: None,
                        tool_output: None,
                    }])),
                    tool_calls: None,
                    tool_call_id: None,
                },
            );
        }

        // Optional structured-output configuration derived from the schema.
        let response_format: Option<CohereResponseFormat> =
            self.json_schema.clone().map(|s| s.into());
        // Per-call tools override the client-wide tool configuration.
        let request_tools = tools.map(|t| t.to_vec()).or_else(|| self.tools.clone());
        // Only send a tool_choice when there are tools to choose from.
        let request_tool_choice = if request_tools.is_some() {
            self.tool_choice.clone()
        } else {
            None
        };

        // Build the request payload
        let body = CohereChatRequest {
            model: &self.model,
            messages: cohere_msgs,
            max_tokens: self.max_tokens,
            temperature: self.temperature,
            stream: self.stream.unwrap_or(false),
            top_p: self.top_p,
            top_k: self.top_k,
            tools: request_tools,
            tool_choice: request_tool_choice,
            reasoning_effort: self.reasoning_effort.clone(),
            response_format,
        };

        let url = self
            .base_url
            .join("chat/completions")
            .map_err(|e| LLMError::HttpError(e.to_string()))?;
        let mut request = self.client.post(url).bearer_auth(&self.api_key).json(&body);

        // Serialize the payload for logging only when tracing is enabled, to
        // avoid the extra serialization cost in the common case.
        if log::log_enabled!(log::Level::Trace) {
            if let Ok(json) = serde_json::to_string(&body) {
                log::trace!("Cohere request payload: {}", json);
            }
        }
        // Per-request timeout in addition to the client-builder timeout.
        if let Some(timeout) = self.timeout_seconds {
            request = request.timeout(Duration::from_secs(timeout));
        }
        let response = request.send().await?;
        log::debug!("Cohere HTTP status: {}", response.status());

        // Non-2xx: surface the status together with the raw body for
        // diagnostics instead of discarding it.
        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await?;
            return Err(LLMError::ResponseFormatError {
                message: format!("Cohere API returned error status: {}", status),
                raw_response: error_text,
            });
        }
        // Parse the successful response
        let resp_text = response.text().await?;
        let json_resp: Result<CohereChatResponse, serde_json::Error> =
            serde_json::from_str(&resp_text);
        match json_resp {
            Ok(res) => Ok(Box::new(res)),
            Err(e) => Err(LLMError::ResponseFormatError {
                message: format!("Failed to decode Cohere API response: {}", e),
                raw_response: resp_text,
            }),
        }
    }

    /// Sends a chat request without per-call tools (client-level tools, if
    /// configured, still apply through `chat_with_tools`).
    async fn chat(&self, messages: &[ChatMessage]) -> Result<Box<dyn ChatResponse>, LLMError> {
        self.chat_with_tools(messages, None).await
    }

    /// Sends a streaming chat request to Cohere's API.
    ///
    /// Builds the same message list as `chat_with_tools` (tool-result
    /// messages plus converted chat messages, with an optional leading
    /// `developer` system message) but forces `stream: true`, uses only the
    /// client-level tool configuration, and sends no `response_format`.
    ///
    /// # Returns
    /// A stream of response text chunks or an error if the request fails.
    async fn chat_stream(
        &self,
        messages: &[ChatMessage],
    ) -> Result<std::pin::Pin<Box<dyn Stream<Item = Result<String, LLMError>> + Send>>, LLMError>
    {
        if self.api_key.is_empty() {
            return Err(LLMError::AuthError("Missing Cohere API key".to_string()));
        }
        let messages = messages.to_vec();
        let mut cohere_msgs: Vec<CohereChatMessage> = vec![];

        for msg in messages {
            if let MessageType::ToolResult(ref results) = msg.message_type {
                // Tool results become "tool"-role messages, as in
                // chat_with_tools above.
                for result in results {
                    cohere_msgs.push(CohereChatMessage {
                        role: "tool",
                        tool_call_id: Some(result.id.clone()),
                        tool_calls: None,
                        content: Some(Right(result.function.arguments.clone())),
                    });
                }
            } else {
                cohere_msgs.push(chat_message_to_api_message(msg));
            }
        }
        // Prepend the system prompt as a "developer" message, mirroring the
        // non-streaming path.
        if let Some(system) = &self.system {
            cohere_msgs.insert(
                0,
                CohereChatMessage {
                    role: "developer",
                    content: Some(Left(vec![CohereMessageContent {
                        message_type: Some("text"),
                        text: Some(system),
                        image_url: None,
                        tool_call_id: None,
                        tool_output: None,
                    }])),
                    tool_calls: None,
                    tool_call_id: None,
                },
            );
        }

        let body = CohereChatRequest {
            model: &self.model,
            messages: cohere_msgs,
            max_tokens: self.max_tokens,
            temperature: self.temperature,
            stream: true,
            top_p: self.top_p,
            top_k: self.top_k,
            tools: self.tools.clone(),
            tool_choice: self.tool_choice.clone(),
            reasoning_effort: self.reasoning_effort.clone(),
            response_format: None,
        };
        let url = self
            .base_url
            .join("chat/completions")
            .map_err(|e| LLMError::HttpError(e.to_string()))?;
        let mut request = self.client.post(url).bearer_auth(&self.api_key).json(&body);
        if let Some(timeout) = self.timeout_seconds {
            request = request.timeout(Duration::from_secs(timeout));
        }
        let response = request.send().await?;
        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await?;
            return Err(LLMError::ResponseFormatError {
                message: format!("Cohere API returned error status: {}", status),
                raw_response: error_text,
            });
        }
        // Return a Server-Sent Events stream of the response content
        Ok(crate::chat::create_sse_stream(response, parse_sse_chunk))
    }
}
515
/// Convert a ChatMessage into a CohereChatMessage with 'static lifetime.
///
/// NOTE(review): image URLs and tool-call strings are promoted to 'static via
/// `Box::leak`, i.e. deliberately leaked once per conversion to satisfy the
/// borrowed `CohereChatMessage<'a>` representation — confirm call volume
/// makes this acceptable, or consider an owned message type.
fn chat_message_to_api_message(chat_msg: ChatMessage) -> CohereChatMessage<'static> {
    CohereChatMessage {
        role: match chat_msg.role {
            ChatRole::User => "user",
            ChatRole::Assistant => "assistant",
        },
        tool_call_id: None,
        content: match &chat_msg.message_type {
            // Plain text is sent as a bare string body.
            MessageType::Text => Some(Right(chat_msg.content.clone())),
            // NOTE(review): panics if a raw image message reaches this point —
            // confirm callers never pass MessageType::Image here.
            MessageType::Image(_) => unreachable!(),
            // PDF input has no Cohere mapping yet; panics if encountered.
            MessageType::Pdf(_) => unimplemented!(),
            MessageType::ImageURL(url) => {
                let owned_url = url.clone();
                // Leak the URL string so the borrowed content can be 'static.
                let url_str = Box::leak(owned_url.into_boxed_str());
                Some(Left(vec![CohereMessageContent {
                    message_type: Some("image_url"),
                    text: None,
                    image_url: Some(ImageUrlContent { url: url_str }),
                    tool_output: None,
                    tool_call_id: None,
                }]))
            }
            // Tool use carries no content body; its data goes in `tool_calls`.
            MessageType::ToolUse(_) => None,
            // Tool results are handled by the callers before conversion.
            MessageType::ToolResult(_) => None,
        },
        tool_calls: match &chat_msg.message_type {
            MessageType::ToolUse(calls) => {
                let owned_calls: Vec<CohereFunctionCall<'static>> = calls
                    .iter()
                    .map(|c| {
                        let owned_id = c.id.clone();
                        let owned_name = c.function.name.clone();
                        let owned_args = c.function.arguments.clone();
                        // Leak strings to static lifetime
                        let id_str = Box::leak(owned_id.into_boxed_str());
                        let name_str = Box::leak(owned_name.into_boxed_str());
                        let args_str = Box::leak(owned_args.into_boxed_str());
                        CohereFunctionCall {
                            id: id_str,
                            content_type: "function",
                            function: CohereFunctionPayload {
                                name: name_str,
                                arguments: args_str,
                            },
                        }
                    })
                    .collect();
                Some(owned_calls)
            }
            _ => None,
        },
    }
}
570
571/// SSE (Server-Sent Events) chunk parser for Cohere's streaming responses.
572///
573/// Parses an SSE data chunk and extracts any generated content.
574fn parse_sse_chunk(chunk: &str) -> Result<Option<String>, LLMError> {
575    let mut collected_content = String::new();
576    for line in chunk.lines() {
577        let line = line.trim();
578        if let Some(data) = line.strip_prefix("data: ") {
579            if data == "[DONE]" {
580                return if collected_content.is_empty() {
581                    Ok(None)
582                } else {
583                    Ok(Some(collected_content))
584                };
585            }
586            match serde_json::from_str::<CohereChatStreamResponse>(data) {
587                Ok(response) => {
588                    if let Some(choice) = response.choices.first() {
589                        if let Some(content) = &choice.delta.content {
590                            collected_content.push_str(content);
591                        }
592                    }
593                }
594                Err(_) => continue,
595            }
596        }
597    }
598    if collected_content.is_empty() {
599        Ok(None)
600    } else {
601        Ok(Some(collected_content))
602    }
603}
604
/// Top-level frame of a streaming (SSE) chat response.
#[derive(Deserialize, Debug)]
struct CohereChatStreamResponse {
    /// Incremental choices carried by this frame.
    choices: Vec<CohereChatStreamChoice>,
}
/// Single choice inside a streaming response frame.
#[derive(Deserialize, Debug)]
struct CohereChatStreamChoice {
    /// Incremental update for this choice.
    delta: CohereChatStreamDelta,
}
/// Incremental content carried by a streaming choice.
#[derive(Deserialize, Debug)]
struct CohereChatStreamDelta {
    /// Text fragment added by this frame, if any.
    content: Option<String>,
}
617
618#[async_trait]
619impl CompletionProvider for Cohere {
620    /// Sends a completion request to Cohere's API (not supported in compatibility mode).
621    async fn complete(&self, _req: & CompletionRequest) -> Result<CompletionResponse, LLMError> {
622        Ok(CompletionResponse {
623            text: "Cohere completion not implemented.".into(),
624        })
625    }
626}
627
628#[async_trait]
629impl EmbeddingProvider for Cohere {
630    /// Generates embeddings for the given input texts using Cohere's API.
631    async fn embed(&self, input: Vec<String>) -> Result<Vec<Vec<f32>>, LLMError> {
632        if self.api_key.is_empty() {
633            return Err(LLMError::AuthError("Missing Cohere API key".into()));
634        }
635        let emb_format = self
636            .embedding_encoding_format
637            .clone()
638            .unwrap_or_else(|| "float".to_string());
639        let body = CohereEmbeddingRequest {
640            model: self.model.clone(),
641            input,
642            encoding_format: Some(emb_format),
643            dimensions: self.embedding_dimensions,
644        };
645        let url = self
646            .base_url
647            .join("embeddings")
648            .map_err(|e| LLMError::HttpError(e.to_string()))?;
649        let resp = self
650            .client
651            .post(url)
652            .bearer_auth(&self.api_key)
653            .json(&body)
654            .send()
655            .await?
656            .error_for_status()?;
657        let json_resp: CohereEmbeddingResponse = resp.json().await?;
658        let embeddings = json_resp.data.into_iter().map(|d| d.embedding).collect();
659        Ok(embeddings)
660    }
661}
662
663impl LLMProvider for Cohere {
664    fn tools(&self) -> Option<&[Tool]> {
665        self.tools.as_deref()
666    }
667}
668
669#[async_trait]
670impl SpeechToTextProvider for Cohere {
671    /// Transcribing audio is not supported by Cohere.
672    async fn transcribe(&self, _audio: Vec<u8>) -> Result<String, LLMError> {
673        Err(LLMError::ProviderError(
674            "Cohere does not implement speech-to-text.".into(),
675        ))
676    }
677}
678
679#[async_trait]
680impl TextToSpeechProvider for Cohere {
681    /// Text-to-speech conversion is not supported by Cohere.
682    async fn speech(&self, _text: &str) -> Result<Vec<u8>, LLMError> {
683        Err(LLMError::ProviderError(
684            "Text-to-speech not supported by Cohere.".into(),
685        ))
686    }
687}
688
#[async_trait]
impl ModelsProvider for Cohere {
    // Relies entirely on the trait's default implementation: no model-listing
    // endpoint is wired up for Cohere in this client.
}