1use crate::{
2 OneOrMany,
3 completion::{self, CompletionError, GetTokenUsage},
4 http_client::{self, HttpClientExt},
5 json_utils,
6 message::{self, Reasoning, ToolChoice},
7 telemetry::SpanCombinator,
8};
9use std::collections::HashMap;
10
11use super::client::Client;
12use crate::completion::CompletionRequest;
13use crate::providers::cohere::streaming::StreamingCompletionResponse;
14use serde::{Deserialize, Serialize};
15use tracing::{Instrument, Level, enabled, info_span};
16
/// Body of a non-streaming chat completion as returned by the provider.
#[derive(Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    /// Identifier of this response.
    pub id: String,
    /// Why generation stopped.
    pub finish_reason: FinishReason,
    /// The generated message. Private; read it through [`CompletionResponse::message`].
    message: Message,
    /// Token accounting; `None` when the payload carries no `usage` field.
    #[serde(default)]
    pub usage: Option<Usage>,
}
25
26impl CompletionResponse {
27 pub fn message(&self) -> (Vec<AssistantContent>, Vec<Citation>, Vec<ToolCall>) {
29 let Message::Assistant {
30 content,
31 citations,
32 tool_calls,
33 ..
34 } = self.message.clone()
35 else {
36 unreachable!("Completion responses will only return an assistant message")
37 };
38
39 (content, citations, tool_calls)
40 }
41}
42
43impl crate::telemetry::ProviderResponseExt for CompletionResponse {
44 type OutputMessage = Message;
45 type Usage = Usage;
46
47 fn get_response_id(&self) -> Option<String> {
48 Some(self.id.clone())
49 }
50
51 fn get_response_model_name(&self) -> Option<String> {
52 None
53 }
54
55 fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
56 vec![self.message.clone()]
57 }
58
59 fn get_text_response(&self) -> Option<String> {
60 let Message::Assistant { ref content, .. } = self.message else {
61 return None;
62 };
63
64 let res = content
65 .iter()
66 .filter_map(|x| {
67 if let AssistantContent::Text { text } = x {
68 Some(text.to_string())
69 } else {
70 None
71 }
72 })
73 .collect::<Vec<String>>()
74 .join("\n");
75
76 if res.is_empty() { None } else { Some(res) }
77 }
78
79 fn get_usage(&self) -> Option<Self::Usage> {
80 self.usage.clone()
81 }
82}
83
/// Reason generation stopped, (de)serialized in `SCREAMING_SNAKE_CASE`.
#[derive(Debug, Deserialize, PartialEq, Eq, Clone, Serialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum FinishReason {
    /// Wire value `MAX_TOKENS`.
    MaxTokens,
    /// Wire value `STOP_SEQUENCE`.
    StopSequence,
    /// Wire value `COMPLETE`.
    Complete,
    /// Wire value `ERROR`.
    Error,
    /// Wire value `TOOL_CALL`.
    ToolCall,
}
93
/// Usage section of a response; both sub-sections are optional.
#[derive(Debug, Deserialize, Clone, Serialize)]
pub struct Usage {
    /// Billed quantities; `None` when absent from the payload.
    #[serde(default)]
    pub billed_units: Option<BilledUnits>,
    /// Token counts; `None` when absent from the payload.
    #[serde(default)]
    pub tokens: Option<Tokens>,
}
101
102impl GetTokenUsage for Usage {
103 fn token_usage(&self) -> Option<crate::completion::Usage> {
104 let mut usage = crate::completion::Usage::new();
105
106 if let Some(ref billed_units) = self.billed_units {
107 usage.input_tokens = billed_units.input_tokens.unwrap_or_default() as u64;
108 usage.output_tokens = billed_units.output_tokens.unwrap_or_default() as u64;
109 usage.total_tokens = usage.input_tokens + usage.output_tokens;
110 }
111
112 Some(usage)
113 }
114}
115
/// Billed quantities reported in a response's usage section.
///
/// Every field is optional and defaults to `None` when missing.
#[derive(Debug, Deserialize, Clone, Serialize)]
pub struct BilledUnits {
    /// Billed output tokens.
    #[serde(default)]
    pub output_tokens: Option<f64>,
    /// Billed classifications.
    #[serde(default)]
    pub classifications: Option<f64>,
    /// Billed search units.
    #[serde(default)]
    pub search_units: Option<f64>,
    /// Billed input tokens.
    #[serde(default)]
    pub input_tokens: Option<f64>,
}
127
/// Token counts reported in a response's usage section.
#[derive(Debug, Deserialize, Clone, Serialize)]
pub struct Tokens {
    /// Input token count; `None` when missing.
    #[serde(default)]
    pub input_tokens: Option<f64>,
    /// Output token count; `None` when missing.
    #[serde(default)]
    pub output_tokens: Option<f64>,
}
135
136impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
137 type Error = CompletionError;
138
139 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
140 let (content, _, tool_calls) = response.message();
141
142 let model_response = if !tool_calls.is_empty() {
143 OneOrMany::many(
144 tool_calls
145 .into_iter()
146 .filter_map(|tool_call| {
147 let ToolCallFunction { name, arguments } = tool_call.function?;
148 let id = tool_call.id.unwrap_or_else(|| name.clone());
149
150 Some(completion::AssistantContent::tool_call(id, name, arguments))
151 })
152 .collect::<Vec<_>>(),
153 )
154 .expect("We have atleast 1 tool call in this if block")
155 } else {
156 OneOrMany::many(content.into_iter().map(|content| match content {
157 AssistantContent::Text { text } => completion::AssistantContent::text(text),
158 AssistantContent::Thinking { thinking } => {
159 completion::AssistantContent::Reasoning(Reasoning {
160 id: None,
161 reasoning: vec![thinking],
162 signature: None,
163 })
164 }
165 }))
166 .map_err(|_| {
167 CompletionError::ResponseError(
168 "Response contained no message or tool call (empty)".to_owned(),
169 )
170 })?
171 };
172
173 let usage = response
174 .usage
175 .as_ref()
176 .and_then(|usage| usage.tokens.as_ref())
177 .map(|tokens| {
178 let input_tokens = tokens.input_tokens.unwrap_or(0.0);
179 let output_tokens = tokens.output_tokens.unwrap_or(0.0);
180
181 completion::Usage {
182 input_tokens: input_tokens as u64,
183 output_tokens: output_tokens as u64,
184 total_tokens: (input_tokens + output_tokens) as u64,
185 }
186 })
187 .unwrap_or_default();
188
189 Ok(completion::CompletionResponse {
190 choice: OneOrMany::many(model_response).expect("There is atleast one content"),
191 usage,
192 raw_response: response,
193 })
194 }
195}
196
/// Document in `{ id, data }` form, where `data` is a free-form JSON map.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
pub struct Document {
    /// Document identifier.
    pub id: String,
    /// Document fields keyed by name.
    pub data: HashMap<String, serde_json::Value>,
}
202
203impl From<completion::Document> for Document {
204 fn from(document: completion::Document) -> Self {
205 let mut data: HashMap<String, serde_json::Value> = HashMap::new();
206
207 document
210 .additional_props
211 .into_iter()
212 .for_each(|(key, value)| {
213 data.insert(key, value.into());
214 });
215
216 data.insert("text".to_string(), document.text.into());
217
218 Self {
219 id: document.id,
220 data,
221 }
222 }
223}
224
/// A tool invocation attached to an assistant message.
///
/// Every field is optional and defaults to `None` when missing.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
pub struct ToolCall {
    /// Identifier of the call.
    #[serde(default)]
    pub id: Option<String>,
    /// Kind of tool being called.
    #[serde(default)]
    pub r#type: Option<ToolType>,
    /// Function name and arguments.
    #[serde(default)]
    pub function: Option<ToolCallFunction>,
}
234
/// Function name and arguments of a [`ToolCall`].
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
pub struct ToolCallFunction {
    /// Name of the function to call.
    pub name: String,
    /// Call arguments; (de)serialized through `json_utils::stringified_json`, i.e. the
    /// wire value is a JSON document encoded as a string.
    #[serde(with = "json_utils::stringified_json")]
    pub arguments: serde_json::Value,
}
241
/// Kind of tool; (de)serialized in lowercase.
#[derive(Clone, Default, Debug, Deserialize, Serialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    /// Wire value `function`; also the default.
    #[default]
    Function,
}
248
/// Tool definition sent with a request.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
pub struct Tool {
    /// Kind of tool.
    pub r#type: ToolType,
    /// The function this tool exposes.
    pub function: Function,
}
254
/// Function exposed by a [`Tool`].
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
pub struct Function {
    /// Function name.
    pub name: String,
    /// Optional description; `None` when missing.
    #[serde(default)]
    pub description: Option<String>,
    /// Parameter description as a JSON value.
    pub parameters: serde_json::Value,
}
262
263impl From<completion::ToolDefinition> for Tool {
264 fn from(tool: completion::ToolDefinition) -> Self {
265 Self {
266 r#type: ToolType::default(),
267 function: Function {
268 name: tool.name,
269 description: Some(tool.description),
270 parameters: tool.parameters,
271 },
272 }
273 }
274}
275
/// Chat message, internally tagged by a lowercase `role` field.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    /// Message with role `user`.
    User {
        /// One or more content parts.
        content: OneOrMany<UserContent>,
    },

    /// Message with role `assistant`; every field defaults when missing.
    Assistant {
        /// Text and thinking parts.
        #[serde(default)]
        content: Vec<AssistantContent>,
        /// Citations attached to the message.
        #[serde(default)]
        citations: Vec<Citation>,
        /// Tool calls requested by the assistant.
        #[serde(default)]
        tool_calls: Vec<ToolCall>,
        /// Optional tool plan text.
        #[serde(default)]
        tool_plan: Option<String>,
    },

    /// Message with role `tool`, carrying the result of a tool call.
    Tool {
        /// One or more result parts.
        content: OneOrMany<ToolResultContent>,
        /// Id of the tool call this result answers.
        tool_call_id: String,
    },

    /// Message with role `system`.
    System {
        /// System text.
        content: String,
    },
}
303
304#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
305#[serde(tag = "type", rename_all = "lowercase")]
306pub enum UserContent {
307 Text { text: String },
308 ImageUrl { image_url: ImageUrl },
309}
310
/// Content part of an assistant message, internally tagged by a lowercase `type` field.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum AssistantContent {
    /// Text part (`"type": "text"`).
    Text { text: String },
    /// Thinking part (`"type": "thinking"`).
    Thinking { thinking: String },
}
317
/// Image reference carried by [`UserContent::ImageUrl`].
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
pub struct ImageUrl {
    /// Location of the image.
    pub url: String,
}
322
323#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
324pub enum ToolResultContent {
325 Text { text: String },
326 Document { document: Document },
327}
328
/// Citation attached to an assistant message.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct Citation {
    /// Start position of the cited span; `None` when missing.
    #[serde(default)]
    pub start: Option<u32>,
    /// End position of the cited span; `None` when missing.
    #[serde(default)]
    pub end: Option<u32>,
    /// Cited text; `None` when missing.
    #[serde(default)]
    pub text: Option<String>,
    /// Kind of citation; carried on the wire under the key `type`.
    #[serde(rename = "type")]
    pub citation_type: Option<CitationType>,
    /// Sources backing the citation; empty when missing.
    #[serde(default)]
    pub sources: Vec<Source>,
}
342
/// Source of a [`Citation`], internally tagged by a lowercase `type` field.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum Source {
    /// Source of type `document`.
    Document {
        /// Optional source id.
        id: Option<String>,
        /// Optional document payload as a JSON object.
        document: Option<serde_json::Map<String, serde_json::Value>>,
    },
    /// Source of type `tool`.
    Tool {
        /// Optional source id.
        id: Option<String>,
        /// Optional tool output as a JSON object.
        tool_output: Option<serde_json::Map<String, serde_json::Value>>,
    },
}
355
/// Kind of a [`Citation`], (de)serialized in `SCREAMING_SNAKE_CASE`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum CitationType {
    /// Wire value `TEXT_CONTENT`.
    TextContent,
    /// Wire value `PLAN`.
    Plan,
}
362
363impl TryFrom<message::Message> for Vec<Message> {
364 type Error = message::MessageError;
365
366 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
367 Ok(match message {
368 message::Message::User { content } => content
369 .into_iter()
370 .map(|content| match content {
371 message::UserContent::Text(message::Text { text }) => Ok(Message::User {
372 content: OneOrMany::one(UserContent::Text { text }),
373 }),
374 message::UserContent::ToolResult(message::ToolResult {
375 id, content, ..
376 }) => Ok(Message::Tool {
377 tool_call_id: id,
378 content: content.try_map(|content| match content {
379 message::ToolResultContent::Text(text) => {
380 Ok(ToolResultContent::Text { text: text.text })
381 }
382 _ => Err(message::MessageError::ConversionError(
383 "Only text tool result content is supported by Cohere".to_owned(),
384 )),
385 })?,
386 }),
387 _ => Err(message::MessageError::ConversionError(
388 "Only text content is supported by Cohere".to_owned(),
389 )),
390 })
391 .collect::<Result<Vec<_>, _>>()?,
392 message::Message::Assistant { content, .. } => {
393 let mut text_content = vec![];
394 let mut tool_calls = vec![];
395
396 for content in content.into_iter() {
397 match content {
398 message::AssistantContent::Text(message::Text { text }) => {
399 text_content.push(AssistantContent::Text { text });
400 }
401 message::AssistantContent::ToolCall(message::ToolCall {
402 id,
403 function:
404 message::ToolFunction {
405 name, arguments, ..
406 },
407 ..
408 }) => {
409 tool_calls.push(ToolCall {
410 id: Some(id),
411 r#type: Some(ToolType::Function),
412 function: Some(ToolCallFunction {
413 name,
414 arguments: serde_json::to_value(arguments).unwrap_or_default(),
415 }),
416 });
417 }
418 message::AssistantContent::Reasoning(Reasoning { reasoning, .. }) => {
419 let thinking = reasoning.join("\n");
420 text_content.push(AssistantContent::Thinking { thinking });
421 }
422 message::AssistantContent::Image(_) => {
423 return Err(message::MessageError::ConversionError(
424 "Cohere currently doesn't support images.".to_owned(),
425 ));
426 }
427 }
428 }
429
430 vec![Message::Assistant {
431 content: text_content,
432 citations: vec![],
433 tool_calls,
434 tool_plan: None,
435 }]
436 }
437 })
438 }
439}
440
441impl TryFrom<Message> for message::Message {
442 type Error = message::MessageError;
443
444 fn try_from(message: Message) -> Result<Self, Self::Error> {
445 match message {
446 Message::User { content } => Ok(message::Message::User {
447 content: content.map(|content| match content {
448 UserContent::Text { text } => {
449 message::UserContent::Text(message::Text { text })
450 }
451 UserContent::ImageUrl { image_url } => {
452 message::UserContent::image_url(image_url.url, None, None)
453 }
454 }),
455 }),
456 Message::Assistant {
457 content,
458 tool_calls,
459 ..
460 } => {
461 let mut content = content
462 .into_iter()
463 .map(|content| match content {
464 AssistantContent::Text { text } => message::AssistantContent::text(text),
465 AssistantContent::Thinking { thinking } => {
466 message::AssistantContent::Reasoning(Reasoning {
467 id: None,
468 reasoning: vec![thinking],
469 signature: None,
470 })
471 }
472 })
473 .collect::<Vec<_>>();
474
475 content.extend(tool_calls.into_iter().filter_map(|tool_call| {
476 let ToolCallFunction { name, arguments } = tool_call.function?;
477
478 Some(message::AssistantContent::tool_call(
479 tool_call.id.unwrap_or_else(|| name.clone()),
480 name,
481 arguments,
482 ))
483 }));
484
485 let content = OneOrMany::many(content).map_err(|_| {
486 message::MessageError::ConversionError(
487 "Expected either text content or tool calls".to_string(),
488 )
489 })?;
490
491 Ok(message::Message::Assistant { id: None, content })
492 }
493 Message::Tool {
494 content,
495 tool_call_id,
496 } => {
497 let content = content.try_map(|content| {
498 Ok(match content {
499 ToolResultContent::Text { text } => message::ToolResultContent::text(text),
500 ToolResultContent::Document { document } => {
501 message::ToolResultContent::text(
502 serde_json::to_string(&document.data).map_err(|e| {
503 message::MessageError::ConversionError(
504 format!("Failed to convert tool result document content into text: {e}"),
505 )
506 })?,
507 )
508 }
509 })
510 })?;
511
512 Ok(message::Message::User {
513 content: OneOrMany::one(message::UserContent::tool_result(
514 tool_call_id,
515 content,
516 )),
517 })
518 }
519 Message::System { content } => Ok(message::Message::user(content)),
520 }
521 }
522}
523
/// Completion model handle: a provider client paired with a model name.
///
/// `T` is the HTTP client type and defaults to `reqwest::Client`.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    /// Client used to send requests.
    pub(crate) client: Client<T>,
    /// Name of the model to query.
    pub model: String,
}
529
/// Request body sent to the chat endpoint.
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct CohereCompletionRequest {
    /// Name of the model to query.
    model: String,
    /// Conversation in provider message form.
    pub messages: Vec<Message>,
    /// Documents forwarded with the request, in the crate-level document form.
    documents: Vec<crate::completion::Document>,
    /// Sampling temperature; omitted from the body when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    /// Tool definitions; omitted from the body when empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<Tool>,
    /// Tool choice; omitted from the body when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<ToolChoice>,
    /// Extra parameters flattened into the top level of the body; omitted when `None`.
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub additional_params: Option<serde_json::Value>,
}
544
545impl TryFrom<(&str, CompletionRequest)> for CohereCompletionRequest {
546 type Error = CompletionError;
547
548 fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
549 let mut partial_history = vec![];
550 if let Some(docs) = req.normalized_documents() {
551 partial_history.push(docs);
552 }
553 partial_history.extend(req.chat_history);
554
555 let mut full_history: Vec<Message> = req.preamble.map_or_else(Vec::new, |preamble| {
556 vec![Message::System { content: preamble }]
557 });
558
559 full_history.extend(
560 partial_history
561 .into_iter()
562 .map(message::Message::try_into)
563 .collect::<Result<Vec<Vec<Message>>, _>>()?
564 .into_iter()
565 .flatten()
566 .collect::<Vec<_>>(),
567 );
568
569 let tool_choice = if let Some(tool_choice) = req.tool_choice {
570 if !matches!(tool_choice, ToolChoice::Auto) {
571 Some(tool_choice)
572 } else {
573 return Err(CompletionError::RequestError(
574 "\"auto\" is not an allowed tool_choice value in the Cohere API".into(),
575 ));
576 }
577 } else {
578 None
579 };
580
581 Ok(Self {
582 model: model.to_string(),
583 messages: full_history,
584 documents: req.documents,
585 temperature: req.temperature,
586 tools: req.tools.into_iter().map(Tool::from).collect::<Vec<_>>(),
587 tool_choice,
588 additional_params: req.additional_params,
589 })
590 }
591}
592
593impl<T> CompletionModel<T>
594where
595 T: HttpClientExt,
596{
597 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
598 Self {
599 client,
600 model: model.into(),
601 }
602 }
603}
604
605impl<T> completion::CompletionModel for CompletionModel<T>
606where
607 T: HttpClientExt + Clone + 'static,
608{
609 type Response = CompletionResponse;
610 type StreamingResponse = StreamingCompletionResponse;
611 type Client = Client<T>;
612
613 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
614 Self::new(client.clone(), model.into())
615 }
616
617 #[cfg_attr(feature = "worker", worker::send)]
618 async fn completion(
619 &self,
620 completion_request: completion::CompletionRequest,
621 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
622 let request = CohereCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
623
624 let llm_span = if tracing::Span::current().is_disabled() {
625 info_span!(
626 target: "rig::completions",
627 "chat",
628 gen_ai.operation.name = "chat",
629 gen_ai.provider.name = "cohere",
630 gen_ai.request.model = self.model,
631 gen_ai.response.id = tracing::field::Empty,
632 gen_ai.response.model = self.model,
633 gen_ai.usage.output_tokens = tracing::field::Empty,
634 gen_ai.usage.input_tokens = tracing::field::Empty,
635 )
636 } else {
637 tracing::Span::current()
638 };
639
640 if enabled!(Level::TRACE) {
641 tracing::trace!(
642 "Cohere completion request: {}",
643 serde_json::to_string_pretty(&request)?
644 );
645 }
646
647 let req_body = serde_json::to_vec(&request)?;
648
649 let req = self.client.post("/v2/chat")?.body(req_body).unwrap();
650
651 async {
652 let response = self
653 .client
654 .send::<_, bytes::Bytes>(req)
655 .await
656 .map_err(|e| http_client::Error::Instance(e.into()))?;
657
658 let status = response.status();
659 let body = response.into_body().into_future().await?.to_owned();
660
661 if status.is_success() {
662 let json_response: CompletionResponse = serde_json::from_slice(&body)?;
663 let span = tracing::Span::current();
664 span.record_token_usage(&json_response.usage);
665 span.record_response_metadata(&json_response);
666
667 if enabled!(Level::TRACE) {
668 tracing::trace!(
669 target: "rig::completions",
670 "Cohere completion response: {}",
671 serde_json::to_string_pretty(&json_response)?
672 );
673 }
674
675 let completion: completion::CompletionResponse<CompletionResponse> =
676 json_response.try_into()?;
677 Ok(completion)
678 } else {
679 Err(CompletionError::ProviderError(
680 String::from_utf8_lossy(&body).to_string(),
681 ))
682 }
683 }
684 .instrument(llm_span)
685 .await
686 }
687
688 #[cfg_attr(feature = "worker", worker::send)]
689 async fn stream(
690 &self,
691 request: CompletionRequest,
692 ) -> Result<
693 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
694 CompletionError,
695 > {
696 CompletionModel::stream(self, request).await
697 }
698}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_path_to_error::deserialize;

    /// A tool-call response deserializes with its id, finish reason, usage and
    /// stringified tool-call arguments intact.
    #[test]
    fn test_deserialize_completion_response() {
        let json_data = r#"
        {
            "id": "abc123",
            "message": {
                "role": "assistant",
                "tool_plan": "I will use the subtract tool to find the difference between 2 and 5.",
                "tool_calls": [
                    {
                        "id": "subtract_sm6ps6fb6y9f",
                        "type": "function",
                        "function": {
                            "name": "subtract",
                            "arguments": "{\"x\":5,\"y\":2}"
                        }
                    }
                ]
            },
            "finish_reason": "TOOL_CALL",
            "usage": {
                "billed_units": {
                    "input_tokens": 78,
                    "output_tokens": 27
                },
                "tokens": {
                    "input_tokens": 1028,
                    "output_tokens": 63
                }
            }
        }
        "#;

        let mut deserializer = serde_json::Deserializer::from_str(json_data);
        let result: Result<CompletionResponse, _> = deserialize(&mut deserializer);

        let response = result.unwrap();
        let (_, citations, tool_calls) = response.message();
        let CompletionResponse {
            id,
            finish_reason,
            usage,
            ..
        } = response;

        assert_eq!(id, "abc123");
        assert_eq!(finish_reason, FinishReason::ToolCall);

        let Usage {
            billed_units,
            tokens,
        } = usage.unwrap();
        let BilledUnits {
            input_tokens: billed_input_tokens,
            output_tokens: billed_output_tokens,
            ..
        } = billed_units.unwrap();
        let Tokens {
            input_tokens,
            output_tokens,
        } = tokens.unwrap();

        assert_eq!(billed_input_tokens.unwrap(), 78.0);
        assert_eq!(billed_output_tokens.unwrap(), 27.0);
        assert_eq!(input_tokens.unwrap(), 1028.0);
        assert_eq!(output_tokens.unwrap(), 63.0);

        assert!(citations.is_empty());
        assert_eq!(tool_calls.len(), 1);

        let ToolCallFunction { name, arguments } = tool_calls[0].function.clone().unwrap();

        assert_eq!(name, "subtract");
        assert_eq!(arguments, serde_json::json!({"x": 5, "y": 2}));
    }

    /// A crate-level user text message maps to exactly one provider user message, and
    /// that message converts back into a single user message.
    #[test]
    fn test_convert_completion_message_to_message_and_back() {
        let completion_message = completion::Message::User {
            content: OneOrMany::one(completion::message::UserContent::Text(
                completion::message::Text {
                    text: "Hello, world!".to_string(),
                },
            )),
        };

        let messages: Vec<Message> = completion_message.try_into().unwrap();
        assert_eq!(
            messages,
            vec![Message::User {
                content: OneOrMany::one(UserContent::Text {
                    text: "Hello, world!".to_string(),
                }),
            }]
        );

        let converted_back: Vec<completion::Message> = messages
            .into_iter()
            .map(|msg| msg.try_into().unwrap())
            .collect::<Vec<_>>();
        assert_eq!(converted_back.len(), 1);
        assert!(matches!(
            converted_back[0],
            completion::Message::User { .. }
        ));
    }

    /// A provider user text message survives a round trip through the crate-level
    /// message type unchanged.
    #[test]
    fn test_convert_message_to_completion_message_and_back() {
        let message = Message::User {
            content: OneOrMany::one(UserContent::Text {
                text: "Hello, world!".to_string(),
            }),
        };

        let completion_message: completion::Message = message.clone().try_into().unwrap();
        let converted_back: Vec<Message> = completion_message.try_into().unwrap();

        assert_eq!(converted_back, vec![message]);
    }
}