1use std::time::Duration;
5
6#[cfg(feature = "cohere")]
7use crate::{
8 chat::Tool,
9 chat::{ChatMessage, ChatProvider, ChatRole, MessageType, StructuredOutputFormat},
10 completion::{CompletionProvider, CompletionRequest, CompletionResponse},
11 embedding::EmbeddingProvider,
12 error::LLMError,
13 models::{ModelsProvider},
14 stt::SpeechToTextProvider,
15 tts::TextToSpeechProvider,
16 LLMProvider,
17};
18#[cfg(feature = "cohere")]
19use crate::{
20 chat::{ChatResponse, ToolChoice},
21 ToolCall,
22};
23use async_trait::async_trait;
24use either::*;
25use futures::stream::Stream;
26use reqwest::{Client, Url};
27use serde::{Deserialize, Serialize};
28
29pub struct Cohere {
34 pub api_key: String,
35 pub base_url: Url,
36 pub model: String,
37 pub max_tokens: Option<u32>,
38 pub temperature: Option<f32>,
39 pub system: Option<String>,
40 pub timeout_seconds: Option<u64>,
41 pub stream: Option<bool>,
42 pub top_p: Option<f32>,
43 pub top_k: Option<u32>,
44 pub tools: Option<Vec<Tool>>,
45 pub tool_choice: Option<ToolChoice>,
46 pub embedding_encoding_format: Option<String>,
48 pub embedding_dimensions: Option<u32>,
49 pub reasoning_effort: Option<String>,
50 pub json_schema: Option<StructuredOutputFormat>,
52 client: Client,
53}
54
55#[derive(Serialize, Debug)]
57struct CohereChatMessage<'a> {
58 #[allow(dead_code)]
59 role: &'a str,
60 #[serde(
61 skip_serializing_if = "Option::is_none",
62 with = "either::serde_untagged_optional"
63 )]
64 content: Option<Either<Vec<CohereMessageContent<'a>>, String>>,
65 #[serde(skip_serializing_if = "Option::is_none")]
66 tool_calls: Option<Vec<CohereFunctionCall<'a>>>,
67 #[serde(skip_serializing_if = "Option::is_none")]
68 tool_call_id: Option<String>,
69}
70
71#[derive(Serialize, Debug)]
72struct CohereFunctionPayload<'a> {
73 name: &'a str,
74 arguments: &'a str,
75}
76
77#[derive(Serialize, Debug)]
78struct CohereFunctionCall<'a> {
79 id: &'a str,
80 #[serde(rename = "type")]
81 content_type: &'a str,
82 function: CohereFunctionPayload<'a>,
83}
84
85#[derive(Serialize, Debug)]
86struct CohereMessageContent<'a> {
87 #[serde(rename = "type", skip_serializing_if = "Option::is_none")]
88 message_type: Option<&'a str>,
89 #[serde(skip_serializing_if = "Option::is_none")]
90 text: Option<&'a str>,
91 #[serde(skip_serializing_if = "Option::is_none")]
92 image_url: Option<ImageUrlContent<'a>>,
93 #[serde(skip_serializing_if = "Option::is_none", rename = "tool_call_id")]
94 tool_call_id: Option<&'a str>,
95 #[serde(skip_serializing_if = "Option::is_none", rename = "content")]
96 tool_output: Option<&'a str>,
97}
98
99#[derive(Serialize, Debug)]
101struct ImageUrlContent<'a> {
102 url: &'a str,
103}
104
105#[derive(Serialize)]
106struct CohereEmbeddingRequest {
107 model: String,
108 input: Vec<String>,
109 #[serde(skip_serializing_if = "Option::is_none")]
110 encoding_format: Option<String>,
111 #[serde(skip_serializing_if = "Option::is_none")]
112 dimensions: Option<u32>,
113}
114
115#[derive(Serialize, Debug)]
117struct CohereChatRequest<'a> {
118 model: &'a str,
119 messages: Vec<CohereChatMessage<'a>>,
120 #[serde(skip_serializing_if = "Option::is_none")]
121 max_tokens: Option<u32>,
122 #[serde(skip_serializing_if = "Option::is_none")]
123 temperature: Option<f32>,
124 stream: bool,
125 #[serde(skip_serializing_if = "Option::is_none")]
126 top_p: Option<f32>,
127 #[serde(skip_serializing_if = "Option::is_none")]
128 top_k: Option<u32>,
129 #[serde(skip_serializing_if = "Option::is_none")]
130 tools: Option<Vec<Tool>>,
131 #[serde(skip_serializing_if = "Option::is_none")]
132 tool_choice: Option<ToolChoice>,
133 #[serde(skip_serializing_if = "Option::is_none")]
134 reasoning_effort: Option<String>,
135 #[serde(skip_serializing_if = "Option::is_none")]
136 response_format: Option<CohereResponseFormat>,
137}
138
139#[derive(Deserialize, Debug)]
141struct CohereChatResponse {
142 choices: Vec<CohereChatChoice>,
143}
144
145#[derive(Deserialize, Debug)]
147struct CohereChatChoice {
148 message: CohereChatMsg,
149}
150
151#[derive(Deserialize, Debug)]
153struct CohereChatMsg {
154 #[allow(dead_code)]
155 role: String,
156 content: Option<String>,
157 tool_calls: Option<Vec<ToolCall>>,
158}
159
160#[derive(Deserialize, Debug)]
162struct CohereEmbeddingData {
163 embedding: Vec<f32>,
164}
165#[derive(Deserialize, Debug)]
166struct CohereEmbeddingResponse {
167 data: Vec<CohereEmbeddingData>,
168}
169
170#[derive(Deserialize, Debug, Serialize)]
172enum CohereResponseType {
173 #[serde(rename = "text")]
174 Text,
175 #[serde(rename = "json_schema")]
176 JsonSchema,
177 #[serde(rename = "json_object")]
178 JsonObject,
179}
180
181#[derive(Deserialize, Debug, Serialize)]
183struct CohereResponseFormat {
184 #[serde(rename = "type")]
185 response_type: CohereResponseType,
186 #[serde(skip_serializing_if = "Option::is_none")]
187 json_schema: Option<StructuredOutputFormat>,
188}
189
190impl From<StructuredOutputFormat> for CohereResponseFormat {
191 fn from(structured_response_format: StructuredOutputFormat) -> Self {
192 match structured_response_format.schema {
193 None => CohereResponseFormat {
194 response_type: CohereResponseType::JsonSchema,
195 json_schema: Some(structured_response_format),
196 },
197 Some(mut schema) => {
198 if schema.get("additionalProperties").is_none() {
200 schema["additionalProperties"] = serde_json::json!(false);
201 }
202 CohereResponseFormat {
203 response_type: CohereResponseType::JsonSchema,
204 json_schema: Some(StructuredOutputFormat {
205 name: structured_response_format.name,
206 description: structured_response_format.description,
207 schema: Some(schema),
208 strict: structured_response_format.strict,
209 }),
210 }
211 }
212 }
213 }
214}
215
216impl ChatResponse for CohereChatResponse {
217 fn text(&self) -> Option<String> {
218 self.choices.first().and_then(|c| c.message.content.clone())
219 }
220 fn tool_calls(&self) -> Option<Vec<ToolCall>> {
221 self.choices.first().and_then(|c| c.message.tool_calls.clone())
222 }
223}
224
225impl std::fmt::Display for CohereChatResponse {
226 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
227 match (
228 &self.choices.first().unwrap().message.content,
229 &self.choices.first().unwrap().message.tool_calls,
230 ) {
231 (Some(content), Some(tool_calls)) => {
232 for tool_call in tool_calls {
233 write!(f, "{}", tool_call)?;
234 }
235 write!(f, "{}", content)
236 }
237 (Some(content), None) => write!(f, "{}", content),
238 (None, Some(tool_calls)) => {
239 for tool_call in tool_calls {
240 write!(f, "{}", tool_call)?;
241 }
242 Ok(())
243 }
244 (None, None) => write!(f, ""),
245 }
246 }
247}
248
249impl Cohere {
250 #[allow(clippy::too_many_arguments)]
271 pub fn new(
272 api_key: impl Into<String>,
273 base_url: Option<String>,
274 model: Option<String>,
275 max_tokens: Option<u32>,
276 temperature: Option<f32>,
277 timeout_seconds: Option<u64>,
278 system: Option<String>,
279 stream: Option<bool>,
280 top_p: Option<f32>,
281 top_k: Option<u32>,
282 embedding_encoding_format: Option<String>,
283 embedding_dimensions: Option<u32>,
284 tools: Option<Vec<Tool>>,
285 tool_choice: Option<ToolChoice>,
286 reasoning_effort: Option<String>,
287 json_schema: Option<StructuredOutputFormat>,
288 ) -> Self {
289 let mut builder = Client::builder();
290 if let Some(sec) = timeout_seconds {
291 builder = builder.timeout(Duration::from_secs(sec));
292 }
293 Self {
294 api_key: api_key.into(),
295 base_url: Url::parse(
296 &base_url.unwrap_or_else(|| "https://api.cohere.ai/compatibility/v1/".to_owned()),
297 )
298 .expect("Failed to parse base Url"),
299 model: model.unwrap_or("command-light".to_string()),
300 max_tokens,
301 temperature,
302 system,
303 timeout_seconds,
304 stream,
305 top_p,
306 top_k,
307 tools,
308 tool_choice,
309 embedding_encoding_format,
310 embedding_dimensions,
311 reasoning_effort,
312 json_schema,
313 client: builder.build().expect("Failed to build reqwest Client"),
314 }
315 }
316}
317
318#[async_trait]
319impl ChatProvider for Cohere {
320 async fn chat_with_tools(
322 &self,
323 messages: &[ChatMessage],
324 tools: Option<&[Tool]>,
325 ) -> Result<Box<dyn ChatResponse>, LLMError> {
326 if self.api_key.is_empty() {
327 return Err(LLMError::AuthError("Missing Cohere API key".to_string()));
328 }
329 let messages = messages.to_vec();
331 let mut cohere_msgs: Vec<CohereChatMessage> = vec![];
332
333 for msg in messages {
334 if let MessageType::ToolResult(ref results) = msg.message_type {
335 for result in results {
337 cohere_msgs.push(CohereChatMessage {
338 role: "tool",
339 tool_call_id: Some(result.id.clone()),
340 tool_calls: None,
341 content: Some(Right(result.function.arguments.clone())),
342 });
343 }
344 } else {
345 cohere_msgs.push(chat_message_to_api_message(msg));
346 }
347 }
348
349 if let Some(system) = &self.system {
351 cohere_msgs.insert(
352 0,
353 CohereChatMessage {
354 role: "developer",
355 content: Some(Left(vec![CohereMessageContent {
356 message_type: Some("text"),
357 text: Some(system),
358 image_url: None,
359 tool_call_id: None,
360 tool_output: None,
361 }])),
362 tool_calls: None,
363 tool_call_id: None,
364 },
365 );
366 }
367
368 let response_format: Option<CohereResponseFormat> =
369 self.json_schema.clone().map(|s| s.into());
370 let request_tools = tools.map(|t| t.to_vec()).or_else(|| self.tools.clone());
371 let request_tool_choice = if request_tools.is_some() {
372 self.tool_choice.clone()
373 } else {
374 None
375 };
376
377 let body = CohereChatRequest {
379 model: &self.model,
380 messages: cohere_msgs,
381 max_tokens: self.max_tokens,
382 temperature: self.temperature,
383 stream: self.stream.unwrap_or(false),
384 top_p: self.top_p,
385 top_k: self.top_k,
386 tools: request_tools,
387 tool_choice: request_tool_choice,
388 reasoning_effort: self.reasoning_effort.clone(),
389 response_format,
390 };
391
392 let url = self
393 .base_url
394 .join("chat/completions")
395 .map_err(|e| LLMError::HttpError(e.to_string()))?;
396 let mut request = self.client.post(url).bearer_auth(&self.api_key).json(&body);
397
398 if log::log_enabled!(log::Level::Trace) {
399 if let Ok(json) = serde_json::to_string(&body) {
400 log::trace!("Cohere request payload: {}", json);
401 }
402 }
403 if let Some(timeout) = self.timeout_seconds {
404 request = request.timeout(Duration::from_secs(timeout));
405 }
406 let response = request.send().await?;
407 log::debug!("Cohere HTTP status: {}", response.status());
408
409 if !response.status().is_success() {
410 let status = response.status();
411 let error_text = response.text().await?;
412 return Err(LLMError::ResponseFormatError {
413 message: format!("Cohere API returned error status: {}", status),
414 raw_response: error_text,
415 });
416 }
417 let resp_text = response.text().await?;
419 let json_resp: Result<CohereChatResponse, serde_json::Error> =
420 serde_json::from_str(&resp_text);
421 match json_resp {
422 Ok(res) => Ok(Box::new(res)),
423 Err(e) => Err(LLMError::ResponseFormatError {
424 message: format!("Failed to decode Cohere API response: {}", e),
425 raw_response: resp_text,
426 }),
427 }
428 }
429
430 async fn chat(&self, messages: &[ChatMessage]) -> Result<Box<dyn ChatResponse>, LLMError> {
431 self.chat_with_tools(messages, None).await
432 }
433
434 async fn chat_stream(
439 &self,
440 messages: &[ChatMessage],
441 ) -> Result<std::pin::Pin<Box<dyn Stream<Item = Result<String, LLMError>> + Send>>, LLMError>
442 {
443 if self.api_key.is_empty() {
444 return Err(LLMError::AuthError("Missing Cohere API key".to_string()));
445 }
446 let messages = messages.to_vec();
447 let mut cohere_msgs: Vec<CohereChatMessage> = vec![];
448
449 for msg in messages {
450 if let MessageType::ToolResult(ref results) = msg.message_type {
451 for result in results {
452 cohere_msgs.push(CohereChatMessage {
453 role: "tool",
454 tool_call_id: Some(result.id.clone()),
455 tool_calls: None,
456 content: Some(Right(result.function.arguments.clone())),
457 });
458 }
459 } else {
460 cohere_msgs.push(chat_message_to_api_message(msg));
461 }
462 }
463 if let Some(system) = &self.system {
464 cohere_msgs.insert(
465 0,
466 CohereChatMessage {
467 role: "developer",
468 content: Some(Left(vec![CohereMessageContent {
469 message_type: Some("text"),
470 text: Some(system),
471 image_url: None,
472 tool_call_id: None,
473 tool_output: None,
474 }])),
475 tool_calls: None,
476 tool_call_id: None,
477 },
478 );
479 }
480
481 let body = CohereChatRequest {
482 model: &self.model,
483 messages: cohere_msgs,
484 max_tokens: self.max_tokens,
485 temperature: self.temperature,
486 stream: true,
487 top_p: self.top_p,
488 top_k: self.top_k,
489 tools: self.tools.clone(),
490 tool_choice: self.tool_choice.clone(),
491 reasoning_effort: self.reasoning_effort.clone(),
492 response_format: None,
493 };
494 let url = self
495 .base_url
496 .join("chat/completions")
497 .map_err(|e| LLMError::HttpError(e.to_string()))?;
498 let mut request = self.client.post(url).bearer_auth(&self.api_key).json(&body);
499 if let Some(timeout) = self.timeout_seconds {
500 request = request.timeout(Duration::from_secs(timeout));
501 }
502 let response = request.send().await?;
503 if !response.status().is_success() {
504 let status = response.status();
505 let error_text = response.text().await?;
506 return Err(LLMError::ResponseFormatError {
507 message: format!("Cohere API returned error status: {}", status),
508 raw_response: error_text,
509 });
510 }
511 Ok(crate::chat::create_sse_stream(response, parse_sse_chunk))
513 }
514}
515
516fn chat_message_to_api_message(chat_msg: ChatMessage) -> CohereChatMessage<'static> {
518 CohereChatMessage {
519 role: match chat_msg.role {
520 ChatRole::User => "user",
521 ChatRole::Assistant => "assistant",
522 },
523 tool_call_id: None,
524 content: match &chat_msg.message_type {
525 MessageType::Text => Some(Right(chat_msg.content.clone())),
526 MessageType::Image(_) => unreachable!(),
527 MessageType::Pdf(_) => unimplemented!(),
528 MessageType::ImageURL(url) => {
529 let owned_url = url.clone();
530 let url_str = Box::leak(owned_url.into_boxed_str());
531 Some(Left(vec![CohereMessageContent {
532 message_type: Some("image_url"),
533 text: None,
534 image_url: Some(ImageUrlContent { url: url_str }),
535 tool_output: None,
536 tool_call_id: None,
537 }]))
538 }
539 MessageType::ToolUse(_) => None,
540 MessageType::ToolResult(_) => None,
541 },
542 tool_calls: match &chat_msg.message_type {
543 MessageType::ToolUse(calls) => {
544 let owned_calls: Vec<CohereFunctionCall<'static>> = calls
545 .iter()
546 .map(|c| {
547 let owned_id = c.id.clone();
548 let owned_name = c.function.name.clone();
549 let owned_args = c.function.arguments.clone();
550 let id_str = Box::leak(owned_id.into_boxed_str());
552 let name_str = Box::leak(owned_name.into_boxed_str());
553 let args_str = Box::leak(owned_args.into_boxed_str());
554 CohereFunctionCall {
555 id: id_str,
556 content_type: "function",
557 function: CohereFunctionPayload {
558 name: name_str,
559 arguments: args_str,
560 },
561 }
562 })
563 .collect();
564 Some(owned_calls)
565 }
566 _ => None,
567 },
568 }
569}
570
571fn parse_sse_chunk(chunk: &str) -> Result<Option<String>, LLMError> {
575 let mut collected_content = String::new();
576 for line in chunk.lines() {
577 let line = line.trim();
578 if let Some(data) = line.strip_prefix("data: ") {
579 if data == "[DONE]" {
580 return if collected_content.is_empty() {
581 Ok(None)
582 } else {
583 Ok(Some(collected_content))
584 };
585 }
586 match serde_json::from_str::<CohereChatStreamResponse>(data) {
587 Ok(response) => {
588 if let Some(choice) = response.choices.first() {
589 if let Some(content) = &choice.delta.content {
590 collected_content.push_str(content);
591 }
592 }
593 }
594 Err(_) => continue,
595 }
596 }
597 }
598 if collected_content.is_empty() {
599 Ok(None)
600 } else {
601 Ok(Some(collected_content))
602 }
603}
604
605#[derive(Deserialize, Debug)]
606struct CohereChatStreamResponse {
607 choices: Vec<CohereChatStreamChoice>,
608}
609#[derive(Deserialize, Debug)]
610struct CohereChatStreamChoice {
611 delta: CohereChatStreamDelta,
612}
613#[derive(Deserialize, Debug)]
614struct CohereChatStreamDelta {
615 content: Option<String>,
616}
617
618#[async_trait]
619impl CompletionProvider for Cohere {
620 async fn complete(&self, _req: & CompletionRequest) -> Result<CompletionResponse, LLMError> {
622 Ok(CompletionResponse {
623 text: "Cohere completion not implemented.".into(),
624 })
625 }
626}
627
628#[async_trait]
629impl EmbeddingProvider for Cohere {
630 async fn embed(&self, input: Vec<String>) -> Result<Vec<Vec<f32>>, LLMError> {
632 if self.api_key.is_empty() {
633 return Err(LLMError::AuthError("Missing Cohere API key".into()));
634 }
635 let emb_format = self
636 .embedding_encoding_format
637 .clone()
638 .unwrap_or_else(|| "float".to_string());
639 let body = CohereEmbeddingRequest {
640 model: self.model.clone(),
641 input,
642 encoding_format: Some(emb_format),
643 dimensions: self.embedding_dimensions,
644 };
645 let url = self
646 .base_url
647 .join("embeddings")
648 .map_err(|e| LLMError::HttpError(e.to_string()))?;
649 let resp = self
650 .client
651 .post(url)
652 .bearer_auth(&self.api_key)
653 .json(&body)
654 .send()
655 .await?
656 .error_for_status()?;
657 let json_resp: CohereEmbeddingResponse = resp.json().await?;
658 let embeddings = json_resp.data.into_iter().map(|d| d.embedding).collect();
659 Ok(embeddings)
660 }
661}
662
663impl LLMProvider for Cohere {
664 fn tools(&self) -> Option<&[Tool]> {
665 self.tools.as_deref()
666 }
667}
668
669#[async_trait]
670impl SpeechToTextProvider for Cohere {
671 async fn transcribe(&self, _audio: Vec<u8>) -> Result<String, LLMError> {
673 Err(LLMError::ProviderError(
674 "Cohere does not implement speech-to-text.".into(),
675 ))
676 }
677}
678
679#[async_trait]
680impl TextToSpeechProvider for Cohere {
681 async fn speech(&self, _text: &str) -> Result<Vec<u8>, LLMError> {
683 Err(LLMError::ProviderError(
684 "Text-to-speech not supported by Cohere.".into(),
685 ))
686 }
687}
688
689#[async_trait]
690impl ModelsProvider for Cohere {
691 }