1use crate::{
2 OneOrMany,
3 completion::{self, CompletionError, GetTokenUsage},
4 http_client::{self, HttpClientExt},
5 json_utils,
6 message::{self, Reasoning, ToolChoice},
7 telemetry::SpanCombinator,
8};
9use std::collections::HashMap;
10
11use super::client::Client;
12use crate::completion::CompletionRequest;
13use crate::providers::cohere::streaming::StreamingCompletionResponse;
14use serde::{Deserialize, Serialize};
15use tracing::{Instrument, Level, enabled, info_span};
16
/// Body of a successful, non-streaming response from Cohere's `/v2/chat` endpoint.
///
/// The assistant message itself is private; read it through [`CompletionResponse::message`].
#[derive(Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    /// Response identifier; also reported as the telemetry response id.
    pub id: String,
    /// Why generation stopped (e.g. `TOOL_CALL` when the model requested tool calls).
    pub finish_reason: FinishReason,
    /// The generated message. Expected to be [`Message::Assistant`]; `message()` panics otherwise.
    message: Message,
    /// Token accounting; deserializes to `None` when the field is absent.
    #[serde(default)]
    pub usage: Option<Usage>,
}
25
26impl CompletionResponse {
27 pub fn message(&self) -> (Vec<AssistantContent>, Vec<Citation>, Vec<ToolCall>) {
29 let Message::Assistant {
30 content,
31 citations,
32 tool_calls,
33 ..
34 } = self.message.clone()
35 else {
36 unreachable!("Completion responses will only return an assistant message")
37 };
38
39 (content, citations, tool_calls)
40 }
41}
42
43impl crate::telemetry::ProviderResponseExt for CompletionResponse {
44 type OutputMessage = Message;
45 type Usage = Usage;
46
47 fn get_response_id(&self) -> Option<String> {
48 Some(self.id.clone())
49 }
50
51 fn get_response_model_name(&self) -> Option<String> {
52 None
53 }
54
55 fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
56 vec![self.message.clone()]
57 }
58
59 fn get_text_response(&self) -> Option<String> {
60 let Message::Assistant { ref content, .. } = self.message else {
61 return None;
62 };
63
64 let res = content
65 .iter()
66 .filter_map(|x| {
67 if let AssistantContent::Text { text } = x {
68 Some(text.to_string())
69 } else {
70 None
71 }
72 })
73 .collect::<Vec<String>>()
74 .join("\n");
75
76 if res.is_empty() { None } else { Some(res) }
77 }
78
79 fn get_usage(&self) -> Option<Self::Usage> {
80 self.usage.clone()
81 }
82}
83
/// Reason Cohere gives for ending generation.
///
/// Serialized in `SCREAMING_SNAKE_CASE` (e.g. `ToolCall` <-> `"TOOL_CALL"`).
#[derive(Debug, Deserialize, PartialEq, Eq, Clone, Serialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum FinishReason {
    MaxTokens,
    StopSequence,
    Complete,
    Error,
    ToolCall,
}
93
/// Usage block of a Cohere response.
///
/// Both sections are optional and default to `None` when absent.
#[derive(Debug, Deserialize, Clone, Serialize)]
pub struct Usage {
    /// Billed unit counts (presumably what the account is charged for).
    #[serde(default)]
    pub billed_units: Option<BilledUnits>,
    /// Raw token counts for the request/response.
    #[serde(default)]
    pub tokens: Option<Tokens>,
}
101
102impl GetTokenUsage for Usage {
103 fn token_usage(&self) -> Option<crate::completion::Usage> {
104 let mut usage = crate::completion::Usage::new();
105
106 if let Some(ref billed_units) = self.billed_units {
107 usage.input_tokens = billed_units.input_tokens.unwrap_or_default() as u64;
108 usage.output_tokens = billed_units.output_tokens.unwrap_or_default() as u64;
109 usage.total_tokens = usage.input_tokens + usage.output_tokens;
110 }
111
112 Some(usage)
113 }
114}
115
/// Billed unit counts reported by Cohere.
///
/// Counts are parsed as `f64` because they arrive as JSON numbers; callers cast them to
/// integers. Every field defaults to `None` when absent.
#[derive(Debug, Deserialize, Clone, Serialize)]
pub struct BilledUnits {
    #[serde(default)]
    pub output_tokens: Option<f64>,
    #[serde(default)]
    pub classifications: Option<f64>,
    #[serde(default)]
    pub search_units: Option<f64>,
    #[serde(default)]
    pub input_tokens: Option<f64>,
}

/// Raw token counts reported by Cohere (parsed as `f64`; defaults to `None` when absent).
#[derive(Debug, Deserialize, Clone, Serialize)]
pub struct Tokens {
    #[serde(default)]
    pub input_tokens: Option<f64>,
    #[serde(default)]
    pub output_tokens: Option<f64>,
}
135
136impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
137 type Error = CompletionError;
138
139 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
140 let (content, _, tool_calls) = response.message();
141
142 let model_response = if !tool_calls.is_empty() {
143 OneOrMany::many(
144 tool_calls
145 .into_iter()
146 .filter_map(|tool_call| {
147 let ToolCallFunction { name, arguments } = tool_call.function?;
148 let id = tool_call.id.unwrap_or_else(|| name.clone());
149
150 Some(completion::AssistantContent::tool_call(id, name, arguments))
151 })
152 .collect::<Vec<_>>(),
153 )
154 .expect("We have atleast 1 tool call in this if block")
155 } else {
156 OneOrMany::many(content.into_iter().map(|content| match content {
157 AssistantContent::Text { text } => completion::AssistantContent::text(text),
158 AssistantContent::Thinking { thinking } => {
159 completion::AssistantContent::Reasoning(Reasoning::new(&thinking))
160 }
161 }))
162 .map_err(|_| {
163 CompletionError::ResponseError(
164 "Response contained no message or tool call (empty)".to_owned(),
165 )
166 })?
167 };
168
169 let usage = response
170 .usage
171 .as_ref()
172 .and_then(|usage| usage.tokens.as_ref())
173 .map(|tokens| {
174 let input_tokens = tokens.input_tokens.unwrap_or(0.0);
175 let output_tokens = tokens.output_tokens.unwrap_or(0.0);
176
177 completion::Usage {
178 input_tokens: input_tokens as u64,
179 output_tokens: output_tokens as u64,
180 total_tokens: (input_tokens + output_tokens) as u64,
181 cached_input_tokens: 0,
182 cache_creation_input_tokens: 0,
183 }
184 })
185 .unwrap_or_default();
186
187 Ok(completion::CompletionResponse {
188 choice: OneOrMany::many(model_response).expect("There is atleast one content"),
189 usage,
190 raw_response: response,
191 message_id: None,
192 })
193 }
194}
195
/// A document passed to Cohere: an id plus free-form key/value data.
///
/// When built from a rig document, the document text is stored under the `"text"` key.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
pub struct Document {
    pub id: String,
    pub data: HashMap<String, serde_json::Value>,
}
201
202impl From<completion::Document> for Document {
203 fn from(document: completion::Document) -> Self {
204 let mut data: HashMap<String, serde_json::Value> = HashMap::new();
205
206 document
209 .additional_props
210 .into_iter()
211 .for_each(|(key, value)| {
212 data.insert(key, value.into());
213 });
214
215 data.insert("text".to_string(), document.text.into());
216
217 Self {
218 id: document.id,
219 data,
220 }
221 }
222}
223
224#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
225pub struct ToolCall {
226 #[serde(default)]
227 pub id: Option<String>,
228 #[serde(default)]
229 pub r#type: Option<ToolType>,
230 #[serde(default)]
231 pub function: Option<ToolCallFunction>,
232}
233
234#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
235pub struct ToolCallFunction {
236 pub name: String,
237 #[serde(with = "json_utils::stringified_json")]
238 pub arguments: serde_json::Value,
239}
240
241#[derive(Clone, Default, Debug, Deserialize, Serialize, PartialEq, Eq)]
242#[serde(rename_all = "lowercase")]
243pub enum ToolType {
244 #[default]
245 Function,
246}
247
248#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
249pub struct Tool {
250 pub r#type: ToolType,
251 pub function: Function,
252}
253
254#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
255pub struct Function {
256 pub name: String,
257 #[serde(default)]
258 pub description: Option<String>,
259 pub parameters: serde_json::Value,
260}
261
262impl From<completion::ToolDefinition> for Tool {
263 fn from(tool: completion::ToolDefinition) -> Self {
264 Self {
265 r#type: ToolType::default(),
266 function: Function {
267 name: tool.name,
268 description: Some(tool.description),
269 parameters: tool.parameters,
270 },
271 }
272 }
273}
274
275#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
276#[serde(tag = "role", rename_all = "lowercase")]
277pub enum Message {
278 User {
279 content: OneOrMany<UserContent>,
280 },
281
282 Assistant {
283 #[serde(default)]
284 content: Vec<AssistantContent>,
285 #[serde(default)]
286 citations: Vec<Citation>,
287 #[serde(default)]
288 tool_calls: Vec<ToolCall>,
289 #[serde(default)]
290 tool_plan: Option<String>,
291 },
292
293 Tool {
294 content: OneOrMany<ToolResultContent>,
295 tool_call_id: String,
296 },
297
298 System {
299 content: String,
300 },
301}
302
303#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
304#[serde(tag = "type", rename_all = "lowercase")]
305pub enum UserContent {
306 Text { text: String },
307 ImageUrl { image_url: ImageUrl },
308}
309
310#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
311#[serde(tag = "type", rename_all = "lowercase")]
312pub enum AssistantContent {
313 Text { text: String },
314 Thinking { thinking: String },
315}
316
317#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
318pub struct ImageUrl {
319 pub url: String,
320}
321
322#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
323pub enum ToolResultContent {
324 Text { text: String },
325 Document { document: Document },
326}
327
328#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
329pub struct Citation {
330 #[serde(default)]
331 pub start: Option<u32>,
332 #[serde(default)]
333 pub end: Option<u32>,
334 #[serde(default)]
335 pub text: Option<String>,
336 #[serde(rename = "type")]
337 pub citation_type: Option<CitationType>,
338 #[serde(default)]
339 pub sources: Vec<Source>,
340}
341
342#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
343#[serde(tag = "type", rename_all = "lowercase")]
344pub enum Source {
345 Document {
346 id: Option<String>,
347 document: Option<serde_json::Map<String, serde_json::Value>>,
348 },
349 Tool {
350 id: Option<String>,
351 tool_output: Option<serde_json::Map<String, serde_json::Value>>,
352 },
353}
354
355#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
356#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
357pub enum CitationType {
358 TextContent,
359 Plan,
360}
361
362impl TryFrom<message::Message> for Vec<Message> {
363 type Error = message::MessageError;
364
365 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
366 Ok(match message {
367 message::Message::User { content } => content
368 .into_iter()
369 .map(|content| match content {
370 message::UserContent::Text(message::Text { text }) => Ok(Message::User {
371 content: OneOrMany::one(UserContent::Text { text }),
372 }),
373 message::UserContent::ToolResult(message::ToolResult {
374 id, content, ..
375 }) => Ok(Message::Tool {
376 tool_call_id: id,
377 content: content.try_map(|content| match content {
378 message::ToolResultContent::Text(text) => {
379 Ok(ToolResultContent::Text { text: text.text })
380 }
381 _ => Err(message::MessageError::ConversionError(
382 "Only text tool result content is supported by Cohere".to_owned(),
383 )),
384 })?,
385 }),
386 _ => Err(message::MessageError::ConversionError(
387 "Only text content is supported by Cohere".to_owned(),
388 )),
389 })
390 .collect::<Result<Vec<_>, _>>()?,
391 message::Message::System { content } => {
392 vec![Message::System { content }]
393 }
394 message::Message::Assistant { content, .. } => {
395 let mut text_content = vec![];
396 let mut tool_calls = vec![];
397
398 for content in content.into_iter() {
399 match content {
400 message::AssistantContent::Text(message::Text { text }) => {
401 text_content.push(AssistantContent::Text { text });
402 }
403 message::AssistantContent::ToolCall(message::ToolCall {
404 id,
405 function:
406 message::ToolFunction {
407 name, arguments, ..
408 },
409 ..
410 }) => {
411 tool_calls.push(ToolCall {
412 id: Some(id),
413 r#type: Some(ToolType::Function),
414 function: Some(ToolCallFunction {
415 name,
416 arguments: serde_json::to_value(arguments).unwrap_or_default(),
417 }),
418 });
419 }
420 message::AssistantContent::Reasoning(reasoning) => {
421 let thinking = reasoning.display_text();
422 text_content.push(AssistantContent::Thinking { thinking });
423 }
424 message::AssistantContent::Image(_) => {
425 return Err(message::MessageError::ConversionError(
426 "Cohere currently doesn't support images.".to_owned(),
427 ));
428 }
429 }
430 }
431
432 vec![Message::Assistant {
433 content: text_content,
434 citations: vec![],
435 tool_calls,
436 tool_plan: None,
437 }]
438 }
439 })
440 }
441}
442
443impl TryFrom<Message> for message::Message {
444 type Error = message::MessageError;
445
446 fn try_from(message: Message) -> Result<Self, Self::Error> {
447 match message {
448 Message::User { content } => Ok(message::Message::User {
449 content: content.map(|content| match content {
450 UserContent::Text { text } => {
451 message::UserContent::Text(message::Text { text })
452 }
453 UserContent::ImageUrl { image_url } => {
454 message::UserContent::image_url(image_url.url, None, None)
455 }
456 }),
457 }),
458 Message::Assistant {
459 content,
460 tool_calls,
461 ..
462 } => {
463 let mut content = content
464 .into_iter()
465 .map(|content| match content {
466 AssistantContent::Text { text } => message::AssistantContent::text(text),
467 AssistantContent::Thinking { thinking } => {
468 message::AssistantContent::Reasoning(Reasoning::new(&thinking))
469 }
470 })
471 .collect::<Vec<_>>();
472
473 content.extend(tool_calls.into_iter().filter_map(|tool_call| {
474 let ToolCallFunction { name, arguments } = tool_call.function?;
475
476 Some(message::AssistantContent::tool_call(
477 tool_call.id.unwrap_or_else(|| name.clone()),
478 name,
479 arguments,
480 ))
481 }));
482
483 let content = OneOrMany::many(content).map_err(|_| {
484 message::MessageError::ConversionError(
485 "Expected either text content or tool calls".to_string(),
486 )
487 })?;
488
489 Ok(message::Message::Assistant { id: None, content })
490 }
491 Message::Tool {
492 content,
493 tool_call_id,
494 } => {
495 let content = content.try_map(|content| {
496 Ok(match content {
497 ToolResultContent::Text { text } => message::ToolResultContent::text(text),
498 ToolResultContent::Document { document } => {
499 message::ToolResultContent::text(
500 serde_json::to_string(&document.data).map_err(|e| {
501 message::MessageError::ConversionError(
502 format!("Failed to convert tool result document content into text: {e}"),
503 )
504 })?,
505 )
506 }
507 })
508 })?;
509
510 Ok(message::Message::User {
511 content: OneOrMany::one(message::UserContent::tool_result(
512 tool_call_id,
513 content,
514 )),
515 })
516 }
517 Message::System { content } => Ok(message::Message::user(content)),
518 }
519 }
520}
521
522#[derive(Clone)]
523pub struct CompletionModel<T = reqwest::Client> {
524 pub(crate) client: Client<T>,
525 pub model: String,
526}
527
528#[derive(Debug, Serialize, Deserialize)]
529pub(super) struct CohereCompletionRequest {
530 model: String,
531 pub messages: Vec<Message>,
532 documents: Vec<crate::completion::Document>,
533 #[serde(skip_serializing_if = "Option::is_none")]
534 temperature: Option<f64>,
535 #[serde(skip_serializing_if = "Vec::is_empty")]
536 tools: Vec<Tool>,
537 #[serde(skip_serializing_if = "Option::is_none")]
538 tool_choice: Option<ToolChoice>,
539 #[serde(flatten, skip_serializing_if = "Option::is_none")]
540 pub additional_params: Option<serde_json::Value>,
541}
542
543impl TryFrom<(&str, CompletionRequest)> for CohereCompletionRequest {
544 type Error = CompletionError;
545
546 fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
547 if req.output_schema.is_some() {
548 tracing::warn!("Structured outputs currently not supported for Cohere");
549 }
550
551 let model = req.model.clone().unwrap_or_else(|| model.to_string());
552 let mut partial_history = vec![];
553 if let Some(docs) = req.normalized_documents() {
554 partial_history.push(docs);
555 }
556 partial_history.extend(req.chat_history);
557
558 let mut full_history: Vec<Message> = req.preamble.map_or_else(Vec::new, |preamble| {
559 vec![Message::System { content: preamble }]
560 });
561
562 full_history.extend(
563 partial_history
564 .into_iter()
565 .map(message::Message::try_into)
566 .collect::<Result<Vec<Vec<Message>>, _>>()?
567 .into_iter()
568 .flatten()
569 .collect::<Vec<_>>(),
570 );
571
572 let tool_choice = if let Some(tool_choice) = req.tool_choice {
573 if !matches!(tool_choice, ToolChoice::Auto) {
574 Some(tool_choice)
575 } else {
576 return Err(CompletionError::RequestError(
577 "\"auto\" is not an allowed tool_choice value in the Cohere API".into(),
578 ));
579 }
580 } else {
581 None
582 };
583
584 Ok(Self {
585 model: model.to_string(),
586 messages: full_history,
587 documents: req.documents,
588 temperature: req.temperature,
589 tools: req.tools.into_iter().map(Tool::from).collect::<Vec<_>>(),
590 tool_choice,
591 additional_params: req.additional_params,
592 })
593 }
594}
595
596impl<T> CompletionModel<T>
597where
598 T: HttpClientExt,
599{
600 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
601 Self {
602 client,
603 model: model.into(),
604 }
605 }
606}
607
608impl<T> completion::CompletionModel for CompletionModel<T>
609where
610 T: HttpClientExt + Clone + 'static,
611{
612 type Response = CompletionResponse;
613 type StreamingResponse = StreamingCompletionResponse;
614 type Client = Client<T>;
615
616 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
617 Self::new(client.clone(), model.into())
618 }
619
620 async fn completion(
621 &self,
622 completion_request: completion::CompletionRequest,
623 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
624 let request = CohereCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
625
626 let llm_span = if tracing::Span::current().is_disabled() {
627 info_span!(
628 target: "rig::completions",
629 "chat",
630 gen_ai.operation.name = "chat",
631 gen_ai.provider.name = "cohere",
632 gen_ai.request.model = self.model,
633 gen_ai.response.id = tracing::field::Empty,
634 gen_ai.response.model = self.model,
635 gen_ai.usage.output_tokens = tracing::field::Empty,
636 gen_ai.usage.input_tokens = tracing::field::Empty,
637 gen_ai.usage.cached_tokens = tracing::field::Empty,
638 )
639 } else {
640 tracing::Span::current()
641 };
642
643 if enabled!(Level::TRACE) {
644 tracing::trace!(
645 "Cohere completion request: {}",
646 serde_json::to_string_pretty(&request)?
647 );
648 }
649
650 let req_body = serde_json::to_vec(&request)?;
651
652 let req = self.client.post("/v2/chat")?.body(req_body).unwrap();
653
654 async {
655 let response = self
656 .client
657 .send::<_, bytes::Bytes>(req)
658 .await
659 .map_err(|e| http_client::Error::Instance(e.into()))?;
660
661 let status = response.status();
662 let body = response.into_body().into_future().await?.to_owned();
663
664 if status.is_success() {
665 let json_response: CompletionResponse = serde_json::from_slice(&body)?;
666 let span = tracing::Span::current();
667 span.record_token_usage(&json_response.usage);
668 span.record_response_metadata(&json_response);
669
670 if enabled!(Level::TRACE) {
671 tracing::trace!(
672 target: "rig::completions",
673 "Cohere completion response: {}",
674 serde_json::to_string_pretty(&json_response)?
675 );
676 }
677
678 let completion: completion::CompletionResponse<CompletionResponse> =
679 json_response.try_into()?;
680 Ok(completion)
681 } else {
682 Err(CompletionError::ProviderError(
683 String::from_utf8_lossy(&body).to_string(),
684 ))
685 }
686 }
687 .instrument(llm_span)
688 .await
689 }
690
691 async fn stream(
692 &self,
693 request: CompletionRequest,
694 ) -> Result<
695 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
696 CompletionError,
697 > {
698 CompletionModel::stream(self, request).await
699 }
700}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_path_to_error::deserialize;

    /// Deserializes a representative tool-call response and checks every field the provider
    /// relies on. This includes the stringified JSON tool arguments.
    #[test]
    fn test_deserialize_completion_response() {
        let json_data = r#"
        {
            "id": "abc123",
            "message": {
                "role": "assistant",
                "tool_plan": "I will use the subtract tool to find the difference between 2 and 5.",
                "tool_calls": [
                    {
                        "id": "subtract_sm6ps6fb6y9f",
                        "type": "function",
                        "function": {
                            "name": "subtract",
                            "arguments": "{\"x\":5,\"y\":2}"
                        }
                    }
                ]
            },
            "finish_reason": "TOOL_CALL",
            "usage": {
                "billed_units": {
                    "input_tokens": 78,
                    "output_tokens": 27
                },
                "tokens": {
                    "input_tokens": 1028,
                    "output_tokens": 63
                }
            }
        }
        "#;

        let mut deserializer = serde_json::Deserializer::from_str(json_data);
        let result: Result<CompletionResponse, _> = deserialize(&mut deserializer);

        let response = result.unwrap();
        let (_, citations, tool_calls) = response.message();
        let CompletionResponse {
            id,
            finish_reason,
            usage,
            ..
        } = response;

        assert_eq!(id, "abc123");
        assert_eq!(finish_reason, FinishReason::ToolCall);

        let Usage {
            billed_units,
            tokens,
        } = usage.unwrap();
        let BilledUnits {
            input_tokens: billed_input_tokens,
            output_tokens: billed_output_tokens,
            ..
        } = billed_units.unwrap();
        let Tokens {
            input_tokens,
            output_tokens,
        } = tokens.unwrap();

        assert_eq!(billed_input_tokens.unwrap(), 78.0);
        assert_eq!(billed_output_tokens.unwrap(), 27.0);
        assert_eq!(input_tokens.unwrap(), 1028.0);
        assert_eq!(output_tokens.unwrap(), 63.0);

        assert!(citations.is_empty());
        assert_eq!(tool_calls.len(), 1);

        let ToolCallFunction { name, arguments } = tool_calls[0].function.clone().unwrap();

        assert_eq!(name, "subtract");
        assert_eq!(arguments, serde_json::json!({"x": 5, "y": 2}));
    }

    /// A plain user text message must survive rig -> Cohere -> rig unchanged.
    #[test]
    fn test_convert_completion_message_to_message_and_back() {
        let completion_message = completion::Message::User {
            content: OneOrMany::one(completion::message::UserContent::Text(
                completion::message::Text {
                    text: "Hello, world!".to_string(),
                },
            )),
        };

        let messages: Vec<Message> = completion_message.clone().try_into().unwrap();
        assert_eq!(messages.len(), 1);

        let converted_back: Vec<completion::Message> = messages
            .into_iter()
            .map(|msg| msg.try_into().unwrap())
            .collect::<Vec<_>>();

        assert_eq!(converted_back, vec![completion_message]);
    }

    /// A Cohere user text message must survive Cohere -> rig -> Cohere unchanged.
    #[test]
    fn test_convert_message_to_completion_message_and_back() {
        let message = Message::User {
            content: OneOrMany::one(UserContent::Text {
                text: "Hello, world!".to_string(),
            }),
        };

        let completion_message: completion::Message = message.clone().try_into().unwrap();
        let converted_back: Vec<Message> = completion_message.try_into().unwrap();

        assert_eq!(converted_back, vec![message]);
    }
}