llm/backends/
cohere.rs

1//! Cohere API client implementation using the OpenAI-compatible base
2//!
3//! This module provides integration with Cohere's LLM models through their API.
4
5use crate::providers::openai_compatible::{OpenAICompatibleProvider, OpenAIProviderConfig};
6use crate::{
7    chat::{StructuredOutputFormat, Tool, ToolChoice},
8    completion::{CompletionProvider, CompletionRequest, CompletionResponse},
9    embedding::EmbeddingProvider,
10    error::LLMError,
11    models::ModelsProvider,
12    stt::SpeechToTextProvider,
13    tts::TextToSpeechProvider,
14    LLMProvider,
15};
16use async_trait::async_trait;
17use serde::{Deserialize, Serialize};
18
/// Cohere configuration for the generic provider
///
/// Zero-sized marker type: all provider-specific settings are carried by the
/// associated constants on the [`OpenAIProviderConfig`] impl below.
pub struct CohereConfig;

impl OpenAIProviderConfig for CohereConfig {
    const PROVIDER_NAME: &'static str = "Cohere";
    // Cohere's OpenAI-compatible endpoint; trailing slash matters for URL joining.
    const DEFAULT_BASE_URL: &'static str = "https://api.cohere.ai/compatibility/v1/";
    // NOTE: upgrading to v2 (not OpenAI-compatible) is required to get usage in streaming responses
    // const DEFAULT_BASE_URL: &'static str = "https://api.cohere.com/v2/chat/";
    const DEFAULT_MODEL: &'static str = "command-r7b-12-2024";
    // Capability flags consumed by the generic OpenAI-compatible provider.
    const SUPPORTS_REASONING_EFFORT: bool = false;
    const SUPPORTS_STRUCTURED_OUTPUT: bool = true;
    const SUPPORTS_PARALLEL_TOOL_CALLS: bool = false;
}
32
/// Type alias for Cohere client using the generic OpenAI-compatible provider
///
/// All chat/streaming behavior comes from [`OpenAICompatibleProvider`];
/// Cohere-specific overrides live in the impls further down this file.
pub type Cohere = OpenAICompatibleProvider<CohereConfig>;
35
impl Cohere {
    /// Creates a new Cohere client with the specified configuration.
    ///
    /// Every parameter except `api_key` is optional; `None` presumably falls
    /// back to the defaults declared on [`CohereConfig`] (base URL, model) —
    /// the fallback itself happens inside the generic provider's `new`.
    ///
    /// NOTE(review): the argument order passed to `new` below intentionally
    /// differs from this signature's order (`reasoning_effort`/`json_schema`
    /// come before the embedding options, and a `None` voice slot is
    /// inserted). Keep the two orders in sync with the generic provider's
    /// `new` signature when editing.
    #[allow(clippy::too_many_arguments)]
    pub fn with_config(
        api_key: impl Into<String>,
        base_url: Option<String>,
        model: Option<String>,
        max_tokens: Option<u32>,
        temperature: Option<f32>,
        timeout_seconds: Option<u64>,
        system: Option<String>,
        top_p: Option<f32>,
        top_k: Option<u32>,
        tools: Option<Vec<Tool>>,
        tool_choice: Option<ToolChoice>,
        embedding_encoding_format: Option<String>,
        embedding_dimensions: Option<u32>,
        reasoning_effort: Option<String>,
        json_schema: Option<StructuredOutputFormat>,
        parallel_tool_calls: Option<bool>,
    ) -> Self {
        <OpenAICompatibleProvider<CohereConfig>>::new(
            api_key,
            base_url,
            model,
            max_tokens,
            temperature,
            timeout_seconds,
            system,
            top_p,
            top_k,
            tools,
            tool_choice,
            reasoning_effort,
            json_schema,
            None, // voice - not supported by Cohere
            parallel_tool_calls,
            embedding_encoding_format,
            embedding_dimensions,
        )
    }
}
78
79// Cohere-specific implementations that don't fit in the generic OpenAI-compatible provider
80
/// Request payload for Cohere's OpenAI-compatible `embeddings` endpoint.
#[derive(Serialize)]
struct CohereEmbeddingRequest {
    // Model name sent to the API.
    model: String,
    // Texts to embed; the API returns one embedding per entry.
    input: Vec<String>,
    // Optional encoding format; key is omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    encoding_format: Option<String>,
    // Optional output dimensionality; key is omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    dimensions: Option<u32>,
}
90
/// One embedding entry in the `data` array of the embeddings response.
#[derive(Deserialize, Debug)]
struct CohereEmbeddingData {
    // The embedding vector for a single input string.
    embedding: Vec<f32>,
}
95
/// Top-level embeddings response; only the `data` array is deserialized,
/// other fields the API may return are ignored.
#[derive(Deserialize, Debug)]
struct CohereEmbeddingResponse {
    data: Vec<CohereEmbeddingData>,
}
100
impl LLMProvider for Cohere {
    /// Returns the tool definitions configured on this client, if any.
    fn tools(&self) -> Option<&[Tool]> {
        self.tools.as_deref()
    }
}
106
107#[async_trait]
108impl CompletionProvider for Cohere {
109    async fn complete(&self, _req: &CompletionRequest) -> Result<CompletionResponse, LLMError> {
110        Ok(CompletionResponse {
111            text: "Cohere completion not implemented.".into(),
112        })
113    }
114}
115
116#[async_trait]
117impl SpeechToTextProvider for Cohere {
118    async fn transcribe(&self, _audio: Vec<u8>) -> Result<String, LLMError> {
119        Err(LLMError::ProviderError(
120            "Cohere does not support speech-to-text".into(),
121        ))
122    }
123
124    async fn transcribe_file(&self, _file_path: &str) -> Result<String, LLMError> {
125        Err(LLMError::ProviderError(
126            "Cohere does not support speech-to-text".into(),
127        ))
128    }
129}
130
#[cfg(feature = "cohere")]
#[async_trait]
impl EmbeddingProvider for Cohere {
    /// Requests embeddings for `input` from the `embeddings` endpoint of the
    /// OpenAI-compatible API and returns one vector per input string.
    ///
    /// # Errors
    ///
    /// * `AuthError` when no API key is configured.
    /// * `HttpError` when the endpoint URL cannot be constructed.
    /// * Transport / non-2xx / decode failures propagate via `?`.
    async fn embed(&self, input: Vec<String>) -> Result<Vec<Vec<f32>>, LLMError> {
        if self.api_key.is_empty() {
            return Err(LLMError::AuthError("Missing Cohere API key".into()));
        }

        // Relative join onto the configured base URL (trailing slash required).
        let endpoint = self
            .base_url
            .join("embeddings")
            .map_err(|e| LLMError::HttpError(e.to_string()))?;

        let request_body = CohereEmbeddingRequest {
            model: self.model.clone(),
            input,
            encoding_format: self.embedding_encoding_format.clone(),
            dimensions: self.embedding_dimensions,
        };

        let response = self
            .client
            .post(endpoint)
            .bearer_auth(&self.api_key)
            .json(&request_body)
            .send()
            .await?
            .error_for_status()?;

        let parsed: CohereEmbeddingResponse = response.json().await?;
        Ok(parsed.data.into_iter().map(|item| item.embedding).collect())
    }
}
165
166#[async_trait]
167impl ModelsProvider for Cohere {
168    async fn list_models(
169        &self,
170        _request: Option<&crate::models::ModelListRequest>,
171    ) -> Result<Box<dyn crate::models::ModelListResponse>, LLMError> {
172        Err(LLMError::ProviderError(
173            "Cohere does not provide a models listing endpoint".into(),
174        ))
175    }
176}
177
178#[async_trait]
179impl TextToSpeechProvider for Cohere {
180    async fn speech(&self, _text: &str) -> Result<Vec<u8>, LLMError> {
181        Err(LLMError::ProviderError(
182            "Cohere does not support text-to-speech".into(),
183        ))
184    }
185}