1use crate::providers::openai_compatible::{OpenAICompatibleProvider, OpenAIProviderConfig};
6use crate::{
7 chat::{StructuredOutputFormat, Tool, ToolChoice},
8 completion::{CompletionProvider, CompletionRequest, CompletionResponse},
9 embedding::EmbeddingProvider,
10 error::LLMError,
11 models::ModelsProvider,
12 stt::SpeechToTextProvider,
13 tts::TextToSpeechProvider,
14 LLMProvider,
15};
16use async_trait::async_trait;
17use serde::{Deserialize, Serialize};
18
/// Marker type that carries the Cohere-specific settings exposed through
/// [`OpenAIProviderConfig`].
pub struct CohereConfig;
21
22impl OpenAIProviderConfig for CohereConfig {
23 const PROVIDER_NAME: &'static str = "Cohere";
24 const DEFAULT_BASE_URL: &'static str = "https://api.cohere.ai/compatibility/v1/";
25 const DEFAULT_MODEL: &'static str = "command-r7b-12-2024";
28 const SUPPORTS_REASONING_EFFORT: bool = false;
29 const SUPPORTS_STRUCTURED_OUTPUT: bool = true;
30 const SUPPORTS_PARALLEL_TOOL_CALLS: bool = false;
31}
32
/// Cohere client: the generic [`OpenAICompatibleProvider`] specialised with
/// [`CohereConfig`].
pub type Cohere = OpenAICompatibleProvider<CohereConfig>;
35
36impl Cohere {
37 #[allow(clippy::too_many_arguments)]
39 pub fn with_config(
40 api_key: impl Into<String>,
41 base_url: Option<String>,
42 model: Option<String>,
43 max_tokens: Option<u32>,
44 temperature: Option<f32>,
45 timeout_seconds: Option<u64>,
46 system: Option<String>,
47 top_p: Option<f32>,
48 top_k: Option<u32>,
49 tools: Option<Vec<Tool>>,
50 tool_choice: Option<ToolChoice>,
51 embedding_encoding_format: Option<String>,
52 embedding_dimensions: Option<u32>,
53 reasoning_effort: Option<String>,
54 json_schema: Option<StructuredOutputFormat>,
55 parallel_tool_calls: Option<bool>,
56 ) -> Self {
57 <OpenAICompatibleProvider<CohereConfig>>::new(
58 api_key,
59 base_url,
60 model,
61 max_tokens,
62 temperature,
63 timeout_seconds,
64 system,
65 top_p,
66 top_k,
67 tools,
68 tool_choice,
69 reasoning_effort,
70 json_schema,
71 None, parallel_tool_calls,
73 embedding_encoding_format,
74 embedding_dimensions,
75 )
76 }
77}
78
/// JSON body posted to the embeddings endpoint.
#[derive(Serialize)]
struct CohereEmbeddingRequest {
    /// Model identifier to embed with.
    model: String,
    /// Texts to embed.
    input: Vec<String>,
    /// Optional encoding format; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    encoding_format: Option<String>,
    /// Optional dimensions value; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    dimensions: Option<u32>,
}
90
/// One element of the `data` array in an embeddings response.
#[derive(Deserialize, Debug)]
struct CohereEmbeddingData {
    /// The embedding vector carried by this element.
    embedding: Vec<f32>,
}
95
/// Subset of the embeddings response that is deserialized: only the `data`
/// array is read.
#[derive(Deserialize, Debug)]
struct CohereEmbeddingResponse {
    /// Returned entries, each holding one embedding vector.
    data: Vec<CohereEmbeddingData>,
}
100
101impl LLMProvider for Cohere {
102 fn tools(&self) -> Option<&[Tool]> {
103 self.tools.as_deref()
104 }
105}
106
107#[async_trait]
108impl CompletionProvider for Cohere {
109 async fn complete(&self, _req: &CompletionRequest) -> Result<CompletionResponse, LLMError> {
110 Ok(CompletionResponse {
111 text: "Cohere completion not implemented.".into(),
112 })
113 }
114}
115
116#[async_trait]
117impl SpeechToTextProvider for Cohere {
118 async fn transcribe(&self, _audio: Vec<u8>) -> Result<String, LLMError> {
119 Err(LLMError::ProviderError(
120 "Cohere does not support speech-to-text".into(),
121 ))
122 }
123
124 async fn transcribe_file(&self, _file_path: &str) -> Result<String, LLMError> {
125 Err(LLMError::ProviderError(
126 "Cohere does not support speech-to-text".into(),
127 ))
128 }
129}
130
#[cfg(feature = "cohere")]
#[async_trait]
impl EmbeddingProvider for Cohere {
    /// Requests embeddings for `input` and returns one vector per entry of the
    /// response's `data` array, in the order the response lists them.
    ///
    /// The request is a bearer-authenticated JSON `POST` to the URL obtained by
    /// joining `embeddings` onto the client's base URL. It carries the
    /// configured model plus the optional encoding format and dimensions.
    ///
    /// # Errors
    ///
    /// * [`LLMError::AuthError`] when the API key is empty (checked before any
    ///   network activity).
    /// * [`LLMError::HttpError`] when the endpoint URL cannot be built.
    /// * Whatever `LLMError` the transport, a non-success HTTP status, or JSON
    ///   decoding converts into via `?`.
    async fn embed(&self, input: Vec<String>) -> Result<Vec<Vec<f32>>, LLMError> {
        if self.api_key.is_empty() {
            return Err(LLMError::AuthError("Missing Cohere API key".into()));
        }

        let endpoint = self
            .base_url
            .join("embeddings")
            .map_err(|err| LLMError::HttpError(err.to_string()))?;

        let payload = CohereEmbeddingRequest {
            model: self.model.clone(),
            input,
            encoding_format: self.embedding_encoding_format.clone(),
            dimensions: self.embedding_dimensions,
        };

        let pending = self
            .client
            .post(endpoint)
            .bearer_auth(&self.api_key)
            .json(&payload);

        // Transport failures and non-success statuses both surface as errors
        // before the body is decoded.
        let response = pending.send().await?;
        let response = response.error_for_status()?;
        let parsed: CohereEmbeddingResponse = response.json().await?;

        let mut vectors = Vec::with_capacity(parsed.data.len());
        for CohereEmbeddingData { embedding } in parsed.data {
            vectors.push(embedding);
        }
        Ok(vectors)
    }
}
165
166#[async_trait]
167impl ModelsProvider for Cohere {
168 async fn list_models(
169 &self,
170 _request: Option<&crate::models::ModelListRequest>,
171 ) -> Result<Box<dyn crate::models::ModelListResponse>, LLMError> {
172 Err(LLMError::ProviderError(
173 "Cohere does not provide a models listing endpoint".into(),
174 ))
175 }
176}
177
178#[async_trait]
179impl TextToSpeechProvider for Cohere {
180 async fn speech(&self, _text: &str) -> Result<Vec<u8>, LLMError> {
181 Err(LLMError::ProviderError(
182 "Cohere does not support text-to-speech".into(),
183 ))
184 }
185}