1use crate::{
2 OneOrMany,
3 completion::{self, CompletionError, GetTokenUsage},
4 http_client::{self, HttpClientExt},
5 json_utils,
6 message::{self, Reasoning, ToolChoice},
7 telemetry::SpanCombinator,
8};
9use std::collections::HashMap;
10
11use super::client::Client;
12use crate::completion::CompletionRequest;
13use crate::providers::cohere::streaming::StreamingCompletionResponse;
14use serde::{Deserialize, Serialize};
15use tracing::{Instrument, Level, enabled, info_span};
16
/// Raw, non-streaming response body returned by Cohere's `/v2/chat` endpoint.
///
/// The generated output lives in the private `message` field; read it through
/// [`CompletionResponse::message`].
#[derive(Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    /// Response identifier (mandatory in the payload).
    pub id: String,
    /// Why generation stopped.
    pub finish_reason: FinishReason,
    // The output message; expected to carry the `assistant` role.
    message: Message,
    /// Token accounting; `None` when the payload has no `usage` field.
    #[serde(default)]
    pub usage: Option<Usage>,
}
25
26impl CompletionResponse {
27 pub fn message(&self) -> (Vec<AssistantContent>, Vec<Citation>, Vec<ToolCall>) {
29 let Message::Assistant {
30 content,
31 citations,
32 tool_calls,
33 ..
34 } = self.message.clone()
35 else {
36 unreachable!("Completion responses will only return an assistant message")
37 };
38
39 (content, citations, tool_calls)
40 }
41}
42
43impl crate::telemetry::ProviderResponseExt for CompletionResponse {
44 type OutputMessage = Message;
45 type Usage = Usage;
46
47 fn get_response_id(&self) -> Option<String> {
48 Some(self.id.clone())
49 }
50
51 fn get_response_model_name(&self) -> Option<String> {
52 None
53 }
54
55 fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
56 vec![self.message.clone()]
57 }
58
59 fn get_text_response(&self) -> Option<String> {
60 let Message::Assistant { ref content, .. } = self.message else {
61 return None;
62 };
63
64 let res = content
65 .iter()
66 .filter_map(|x| {
67 if let AssistantContent::Text { text } = x {
68 Some(text.to_string())
69 } else {
70 None
71 }
72 })
73 .collect::<Vec<String>>()
74 .join("\n");
75
76 if res.is_empty() { None } else { Some(res) }
77 }
78
79 fn get_usage(&self) -> Option<Self::Usage> {
80 self.usage.clone()
81 }
82}
83
/// Reason Cohere reports for ending generation.
///
/// Serialized in SCREAMING_SNAKE_CASE, e.g. `ToolCall` <-> `"TOOL_CALL"`. A value not listed
/// here makes deserialization of the whole response fail.
#[derive(Debug, Deserialize, PartialEq, Eq, Clone, Serialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum FinishReason {
    MaxTokens,
    StopSequence,
    Complete,
    Error,
    ToolCall,
}
93
/// Token accounting attached to a Cohere response.
///
/// The API reports `billed_units` and raw `tokens` separately, and they can differ a lot
/// (this module's test payload has 78 billed vs 1028 raw input tokens). Both are optional.
#[derive(Debug, Deserialize, Clone, Serialize)]
pub struct Usage {
    #[serde(default)]
    pub billed_units: Option<BilledUnits>,
    #[serde(default)]
    pub tokens: Option<Tokens>,
}
101
102impl GetTokenUsage for Usage {
103 fn token_usage(&self) -> Option<crate::completion::Usage> {
104 let mut usage = crate::completion::Usage::new();
105
106 if let Some(ref billed_units) = self.billed_units {
107 usage.input_tokens = billed_units.input_tokens.unwrap_or_default() as u64;
108 usage.output_tokens = billed_units.output_tokens.unwrap_or_default() as u64;
109 usage.total_tokens = usage.input_tokens + usage.output_tokens;
110 }
111
112 Some(usage)
113 }
114}
115
/// Quantities Cohere reports under `usage.billed_units`.
///
/// Every field is optional; numbers are kept as `f64` as they arrive in JSON.
#[derive(Debug, Deserialize, Clone, Serialize)]
pub struct BilledUnits {
    #[serde(default)]
    pub output_tokens: Option<f64>,
    #[serde(default)]
    pub classifications: Option<f64>,
    #[serde(default)]
    pub search_units: Option<f64>,
    #[serde(default)]
    pub input_tokens: Option<f64>,
}

/// Raw token counts Cohere reports under `usage.tokens`; both fields are optional.
#[derive(Debug, Deserialize, Clone, Serialize)]
pub struct Tokens {
    #[serde(default)]
    pub input_tokens: Option<f64>,
    #[serde(default)]
    pub output_tokens: Option<f64>,
}
135
136impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
137 type Error = CompletionError;
138
139 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
140 let (content, _, tool_calls) = response.message();
141
142 let model_response = if !tool_calls.is_empty() {
143 OneOrMany::many(
144 tool_calls
145 .into_iter()
146 .filter_map(|tool_call| {
147 let ToolCallFunction { name, arguments } = tool_call.function?;
148 let id = tool_call.id.unwrap_or_else(|| name.clone());
149
150 Some(completion::AssistantContent::tool_call(id, name, arguments))
151 })
152 .collect::<Vec<_>>(),
153 )
154 .expect("We have atleast 1 tool call in this if block")
155 } else {
156 OneOrMany::many(content.into_iter().map(|content| match content {
157 AssistantContent::Text { text } => completion::AssistantContent::text(text),
158 AssistantContent::Thinking { thinking } => {
159 completion::AssistantContent::Reasoning(Reasoning::new(&thinking))
160 }
161 }))
162 .map_err(|_| {
163 CompletionError::ResponseError(
164 "Response contained no message or tool call (empty)".to_owned(),
165 )
166 })?
167 };
168
169 let usage = response
170 .usage
171 .as_ref()
172 .and_then(|usage| usage.tokens.as_ref())
173 .map(|tokens| {
174 let input_tokens = tokens.input_tokens.unwrap_or(0.0);
175 let output_tokens = tokens.output_tokens.unwrap_or(0.0);
176
177 completion::Usage {
178 input_tokens: input_tokens as u64,
179 output_tokens: output_tokens as u64,
180 total_tokens: (input_tokens + output_tokens) as u64,
181 cached_input_tokens: 0,
182 }
183 })
184 .unwrap_or_default();
185
186 Ok(completion::CompletionResponse {
187 choice: OneOrMany::many(model_response).expect("There is atleast one content"),
188 usage,
189 raw_response: response,
190 message_id: None,
191 })
192 }
193}
194
/// Cohere's document representation: an identifier plus arbitrary key/value `data`.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
pub struct Document {
    pub id: String,
    pub data: HashMap<String, serde_json::Value>,
}
200
201impl From<completion::Document> for Document {
202 fn from(document: completion::Document) -> Self {
203 let mut data: HashMap<String, serde_json::Value> = HashMap::new();
204
205 document
208 .additional_props
209 .into_iter()
210 .for_each(|(key, value)| {
211 data.insert(key, value.into());
212 });
213
214 data.insert("text".to_string(), document.text.into());
215
216 Self {
217 id: document.id,
218 data,
219 }
220 }
221}
222
/// A tool invocation emitted by the assistant. Every field is optional on the wire.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
pub struct ToolCall {
    #[serde(default)]
    pub id: Option<String>,
    #[serde(default)]
    pub r#type: Option<ToolType>,
    #[serde(default)]
    pub function: Option<ToolCallFunction>,
}

/// Name and arguments of a tool invocation.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
pub struct ToolCallFunction {
    pub name: String,
    /// Sent as a JSON-encoded string (e.g. `"{\"x\":5}"`) but held here as parsed JSON.
    #[serde(with = "json_utils::stringified_json")]
    pub arguments: serde_json::Value,
}

/// Kind of tool. Only `"function"` exists, and it is also the default.
#[derive(Clone, Default, Debug, Deserialize, Serialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}

/// A tool definition offered to the model in a chat request.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
pub struct Tool {
    pub r#type: ToolType,
    pub function: Function,
}

/// Callable part of a [`Tool`]: name, optional description and parameters.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
pub struct Function {
    pub name: String,
    #[serde(default)]
    pub description: Option<String>,
    // Forwarded verbatim from the rig `ToolDefinition::parameters`.
    pub parameters: serde_json::Value,
}
260
261impl From<completion::ToolDefinition> for Tool {
262 fn from(tool: completion::ToolDefinition) -> Self {
263 Self {
264 r#type: ToolType::default(),
265 function: Function {
266 name: tool.name,
267 description: Some(tool.description),
268 parameters: tool.parameters,
269 },
270 }
271 }
272}
273
/// A chat message in Cohere's v2 format, tagged by its `role` field
/// (`"user"`, `"assistant"`, `"tool"` or `"system"`).
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    /// User input.
    User {
        content: OneOrMany<UserContent>,
    },

    /// Model output. Every field defaults to empty when absent from the payload.
    Assistant {
        #[serde(default)]
        content: Vec<AssistantContent>,
        #[serde(default)]
        citations: Vec<Citation>,
        #[serde(default)]
        tool_calls: Vec<ToolCall>,
        /// Free-text plan the model may emit alongside tool calls.
        #[serde(default)]
        tool_plan: Option<String>,
    },

    /// Result of a tool call, linked to the call through `tool_call_id`.
    Tool {
        content: OneOrMany<ToolResultContent>,
        tool_call_id: String,
    },

    /// System prompt; this module builds it from the request preamble.
    System {
        content: String,
    },
}
301
302#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
303#[serde(tag = "type", rename_all = "lowercase")]
304pub enum UserContent {
305 Text { text: String },
306 ImageUrl { image_url: ImageUrl },
307}
308
/// A piece of assistant output, tagged by `type`: `"text"` or `"thinking"`.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum AssistantContent {
    Text { text: String },
    Thinking { thinking: String },
}

/// An image referenced by URL inside user content.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
pub struct ImageUrl {
    pub url: String,
}
320
321#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
322pub enum ToolResultContent {
323 Text { text: String },
324 Document { document: Document },
325}
326
/// A citation attached to assistant output. All fields are optional on the wire.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct Citation {
    // NOTE(review): `start`/`end` look like offsets into the generated text; confirm the unit
    // (characters vs bytes) against the Cohere API docs.
    #[serde(default)]
    pub start: Option<u32>,
    #[serde(default)]
    pub end: Option<u32>,
    #[serde(default)]
    pub text: Option<String>,
    /// Serialized under the JSON key `type`.
    #[serde(rename = "type")]
    pub citation_type: Option<CitationType>,
    #[serde(default)]
    pub sources: Vec<Source>,
}

/// Origin of a citation, tagged by `type`: `"document"` or `"tool"`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum Source {
    Document {
        id: Option<String>,
        document: Option<serde_json::Map<String, serde_json::Value>>,
    },
    Tool {
        id: Option<String>,
        tool_output: Option<serde_json::Map<String, serde_json::Value>>,
    },
}

/// Kind of citation, serialized as `"TEXT_CONTENT"` or `"PLAN"`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum CitationType {
    TextContent,
    Plan,
}
360
361impl TryFrom<message::Message> for Vec<Message> {
362 type Error = message::MessageError;
363
364 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
365 Ok(match message {
366 message::Message::User { content } => content
367 .into_iter()
368 .map(|content| match content {
369 message::UserContent::Text(message::Text { text }) => Ok(Message::User {
370 content: OneOrMany::one(UserContent::Text { text }),
371 }),
372 message::UserContent::ToolResult(message::ToolResult {
373 id, content, ..
374 }) => Ok(Message::Tool {
375 tool_call_id: id,
376 content: content.try_map(|content| match content {
377 message::ToolResultContent::Text(text) => {
378 Ok(ToolResultContent::Text { text: text.text })
379 }
380 _ => Err(message::MessageError::ConversionError(
381 "Only text tool result content is supported by Cohere".to_owned(),
382 )),
383 })?,
384 }),
385 _ => Err(message::MessageError::ConversionError(
386 "Only text content is supported by Cohere".to_owned(),
387 )),
388 })
389 .collect::<Result<Vec<_>, _>>()?,
390 message::Message::Assistant { content, .. } => {
391 let mut text_content = vec![];
392 let mut tool_calls = vec![];
393
394 for content in content.into_iter() {
395 match content {
396 message::AssistantContent::Text(message::Text { text }) => {
397 text_content.push(AssistantContent::Text { text });
398 }
399 message::AssistantContent::ToolCall(message::ToolCall {
400 id,
401 function:
402 message::ToolFunction {
403 name, arguments, ..
404 },
405 ..
406 }) => {
407 tool_calls.push(ToolCall {
408 id: Some(id),
409 r#type: Some(ToolType::Function),
410 function: Some(ToolCallFunction {
411 name,
412 arguments: serde_json::to_value(arguments).unwrap_or_default(),
413 }),
414 });
415 }
416 message::AssistantContent::Reasoning(reasoning) => {
417 let thinking = reasoning.display_text();
418 text_content.push(AssistantContent::Thinking { thinking });
419 }
420 message::AssistantContent::Image(_) => {
421 return Err(message::MessageError::ConversionError(
422 "Cohere currently doesn't support images.".to_owned(),
423 ));
424 }
425 }
426 }
427
428 vec![Message::Assistant {
429 content: text_content,
430 citations: vec![],
431 tool_calls,
432 tool_plan: None,
433 }]
434 }
435 })
436 }
437}
438
439impl TryFrom<Message> for message::Message {
440 type Error = message::MessageError;
441
442 fn try_from(message: Message) -> Result<Self, Self::Error> {
443 match message {
444 Message::User { content } => Ok(message::Message::User {
445 content: content.map(|content| match content {
446 UserContent::Text { text } => {
447 message::UserContent::Text(message::Text { text })
448 }
449 UserContent::ImageUrl { image_url } => {
450 message::UserContent::image_url(image_url.url, None, None)
451 }
452 }),
453 }),
454 Message::Assistant {
455 content,
456 tool_calls,
457 ..
458 } => {
459 let mut content = content
460 .into_iter()
461 .map(|content| match content {
462 AssistantContent::Text { text } => message::AssistantContent::text(text),
463 AssistantContent::Thinking { thinking } => {
464 message::AssistantContent::Reasoning(Reasoning::new(&thinking))
465 }
466 })
467 .collect::<Vec<_>>();
468
469 content.extend(tool_calls.into_iter().filter_map(|tool_call| {
470 let ToolCallFunction { name, arguments } = tool_call.function?;
471
472 Some(message::AssistantContent::tool_call(
473 tool_call.id.unwrap_or_else(|| name.clone()),
474 name,
475 arguments,
476 ))
477 }));
478
479 let content = OneOrMany::many(content).map_err(|_| {
480 message::MessageError::ConversionError(
481 "Expected either text content or tool calls".to_string(),
482 )
483 })?;
484
485 Ok(message::Message::Assistant { id: None, content })
486 }
487 Message::Tool {
488 content,
489 tool_call_id,
490 } => {
491 let content = content.try_map(|content| {
492 Ok(match content {
493 ToolResultContent::Text { text } => message::ToolResultContent::text(text),
494 ToolResultContent::Document { document } => {
495 message::ToolResultContent::text(
496 serde_json::to_string(&document.data).map_err(|e| {
497 message::MessageError::ConversionError(
498 format!("Failed to convert tool result document content into text: {e}"),
499 )
500 })?,
501 )
502 }
503 })
504 })?;
505
506 Ok(message::Message::User {
507 content: OneOrMany::one(message::UserContent::tool_result(
508 tool_call_id,
509 content,
510 )),
511 })
512 }
513 Message::System { content } => Ok(message::Message::user(content)),
514 }
515 }
516}
517
/// Cohere chat completion model handle.
///
/// `T` is the HTTP client backend used by [`Client`]; it defaults to `reqwest::Client`.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    pub(crate) client: Client<T>,
    /// Model identifier sent with each request, unless the request names its own model.
    pub model: String,
}
523
/// JSON body sent to Cohere's `/v2/chat` endpoint, built from a rig [`CompletionRequest`].
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct CohereCompletionRequest {
    model: String,
    pub messages: Vec<Message>,
    // Rig documents are forwarded here unchanged and always serialized, even when empty.
    // NOTE(review): they are also folded into `messages` via `normalized_documents()` when the
    // request is built, so each document is sent twice. Confirm this is intended.
    documents: Vec<crate::completion::Document>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<Tool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<ToolChoice>,
    /// Provider-specific extras, flattened into the top level of the JSON body.
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub additional_params: Option<serde_json::Value>,
}
538
539impl TryFrom<(&str, CompletionRequest)> for CohereCompletionRequest {
540 type Error = CompletionError;
541
542 fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
543 if req.output_schema.is_some() {
544 tracing::warn!("Structured outputs currently not supported for Cohere");
545 }
546
547 let model = req.model.clone().unwrap_or_else(|| model.to_string());
548 let mut partial_history = vec![];
549 if let Some(docs) = req.normalized_documents() {
550 partial_history.push(docs);
551 }
552 partial_history.extend(req.chat_history);
553
554 let mut full_history: Vec<Message> = req.preamble.map_or_else(Vec::new, |preamble| {
555 vec![Message::System { content: preamble }]
556 });
557
558 full_history.extend(
559 partial_history
560 .into_iter()
561 .map(message::Message::try_into)
562 .collect::<Result<Vec<Vec<Message>>, _>>()?
563 .into_iter()
564 .flatten()
565 .collect::<Vec<_>>(),
566 );
567
568 let tool_choice = if let Some(tool_choice) = req.tool_choice {
569 if !matches!(tool_choice, ToolChoice::Auto) {
570 Some(tool_choice)
571 } else {
572 return Err(CompletionError::RequestError(
573 "\"auto\" is not an allowed tool_choice value in the Cohere API".into(),
574 ));
575 }
576 } else {
577 None
578 };
579
580 Ok(Self {
581 model: model.to_string(),
582 messages: full_history,
583 documents: req.documents,
584 temperature: req.temperature,
585 tools: req.tools.into_iter().map(Tool::from).collect::<Vec<_>>(),
586 tool_choice,
587 additional_params: req.additional_params,
588 })
589 }
590}
591
592impl<T> CompletionModel<T>
593where
594 T: HttpClientExt,
595{
596 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
597 Self {
598 client,
599 model: model.into(),
600 }
601 }
602}
603
604impl<T> completion::CompletionModel for CompletionModel<T>
605where
606 T: HttpClientExt + Clone + 'static,
607{
608 type Response = CompletionResponse;
609 type StreamingResponse = StreamingCompletionResponse;
610 type Client = Client<T>;
611
612 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
613 Self::new(client.clone(), model.into())
614 }
615
616 async fn completion(
617 &self,
618 completion_request: completion::CompletionRequest,
619 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
620 let request = CohereCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
621
622 let llm_span = if tracing::Span::current().is_disabled() {
623 info_span!(
624 target: "rig::completions",
625 "chat",
626 gen_ai.operation.name = "chat",
627 gen_ai.provider.name = "cohere",
628 gen_ai.request.model = self.model,
629 gen_ai.response.id = tracing::field::Empty,
630 gen_ai.response.model = self.model,
631 gen_ai.usage.output_tokens = tracing::field::Empty,
632 gen_ai.usage.input_tokens = tracing::field::Empty,
633 )
634 } else {
635 tracing::Span::current()
636 };
637
638 if enabled!(Level::TRACE) {
639 tracing::trace!(
640 "Cohere completion request: {}",
641 serde_json::to_string_pretty(&request)?
642 );
643 }
644
645 let req_body = serde_json::to_vec(&request)?;
646
647 let req = self.client.post("/v2/chat")?.body(req_body).unwrap();
648
649 async {
650 let response = self
651 .client
652 .send::<_, bytes::Bytes>(req)
653 .await
654 .map_err(|e| http_client::Error::Instance(e.into()))?;
655
656 let status = response.status();
657 let body = response.into_body().into_future().await?.to_owned();
658
659 if status.is_success() {
660 let json_response: CompletionResponse = serde_json::from_slice(&body)?;
661 let span = tracing::Span::current();
662 span.record_token_usage(&json_response.usage);
663 span.record_response_metadata(&json_response);
664
665 if enabled!(Level::TRACE) {
666 tracing::trace!(
667 target: "rig::completions",
668 "Cohere completion response: {}",
669 serde_json::to_string_pretty(&json_response)?
670 );
671 }
672
673 let completion: completion::CompletionResponse<CompletionResponse> =
674 json_response.try_into()?;
675 Ok(completion)
676 } else {
677 Err(CompletionError::ProviderError(
678 String::from_utf8_lossy(&body).to_string(),
679 ))
680 }
681 }
682 .instrument(llm_span)
683 .await
684 }
685
686 async fn stream(
687 &self,
688 request: CompletionRequest,
689 ) -> Result<
690 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
691 CompletionError,
692 > {
693 CompletionModel::stream(self, request).await
694 }
695}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_path_to_error::deserialize;

    /// A tool-call response deserializes with its id, finish reason, usage and parsed arguments.
    #[test]
    fn test_deserialize_completion_response() {
        let json_data = r#"
        {
            "id": "abc123",
            "message": {
                "role": "assistant",
                "tool_plan": "I will use the subtract tool to find the difference between 2 and 5.",
                "tool_calls": [
                        {
                            "id": "subtract_sm6ps6fb6y9f",
                            "type": "function",
                            "function": {
                                "name": "subtract",
                                "arguments": "{\"x\":5,\"y\":2}"
                            }
                        }
                    ]
                },
                "finish_reason": "TOOL_CALL",
                "usage": {
                "billed_units": {
                    "input_tokens": 78,
                    "output_tokens": 27
                },
                "tokens": {
                    "input_tokens": 1028,
                    "output_tokens": 63
                }
            }
        }
        "#;

        let mut deserializer = serde_json::Deserializer::from_str(json_data);
        let result: Result<CompletionResponse, _> = deserialize(&mut deserializer);

        let response = result.unwrap();
        let (_, citations, tool_calls) = response.message();
        let CompletionResponse {
            id,
            finish_reason,
            usage,
            ..
        } = response;

        assert_eq!(id, "abc123");
        assert_eq!(finish_reason, FinishReason::ToolCall);

        let Usage {
            billed_units,
            tokens,
        } = usage.unwrap();
        let BilledUnits {
            input_tokens: billed_input_tokens,
            output_tokens: billed_output_tokens,
            ..
        } = billed_units.unwrap();
        let Tokens {
            input_tokens,
            output_tokens,
        } = tokens.unwrap();

        assert_eq!(billed_input_tokens.unwrap(), 78.0);
        assert_eq!(billed_output_tokens.unwrap(), 27.0);
        assert_eq!(input_tokens.unwrap(), 1028.0);
        assert_eq!(output_tokens.unwrap(), 63.0);

        assert!(citations.is_empty());
        assert_eq!(tool_calls.len(), 1);

        let ToolCallFunction { name, arguments } = tool_calls[0].function.clone().unwrap();

        assert_eq!(name, "subtract");
        assert_eq!(arguments, serde_json::json!({"x": 5, "y": 2}));
    }

    /// A rig user text message becomes exactly one Cohere user text message and converts back
    /// to a single rig user message.
    #[test]
    fn test_convert_completion_message_to_message_and_back() {
        let completion_message = completion::Message::User {
            content: OneOrMany::one(completion::message::UserContent::Text(
                completion::message::Text {
                    text: "Hello, world!".to_string(),
                },
            )),
        };

        let messages: Vec<Message> = completion_message.clone().try_into().unwrap();
        assert_eq!(
            messages,
            vec![Message::User {
                content: OneOrMany::one(UserContent::Text {
                    text: "Hello, world!".to_string(),
                }),
            }]
        );

        let converted_back: Vec<completion::Message> = messages
            .into_iter()
            .map(|msg| msg.try_into().unwrap())
            .collect::<Vec<_>>();
        assert!(matches!(
            converted_back.as_slice(),
            [completion::Message::User { .. }]
        ));
    }

    /// A Cohere user text message survives a round trip through the rig message type unchanged.
    #[test]
    fn test_convert_message_to_completion_message_and_back() {
        let message = Message::User {
            content: OneOrMany::one(UserContent::Text {
                text: "Hello, world!".to_string(),
            }),
        };

        let completion_message: completion::Message = message.clone().try_into().unwrap();
        let converted_back: Vec<Message> = completion_message.try_into().unwrap();
        assert_eq!(converted_back, vec![message]);
    }
}