1use crate::{
2 OneOrMany,
3 completion::{self, CompletionError, GetTokenUsage},
4 http_client::{self, HttpClientExt},
5 json_utils,
6 message::{self, Reasoning, ToolChoice},
7 telemetry::SpanCombinator,
8};
9use std::collections::HashMap;
10
11use super::client::Client;
12use crate::completion::CompletionRequest;
13use crate::providers::cohere::streaming::StreamingCompletionResponse;
14use serde::{Deserialize, Serialize};
15use serde_json::{Value, json};
16use tracing::{Instrument, info_span};
17
18#[derive(Debug, Deserialize, Serialize)]
19pub struct CompletionResponse {
20 pub id: String,
21 pub finish_reason: FinishReason,
22 message: Message,
23 #[serde(default)]
24 pub usage: Option<Usage>,
25}
26
27impl CompletionResponse {
28 pub fn message(&self) -> (Vec<AssistantContent>, Vec<Citation>, Vec<ToolCall>) {
30 let Message::Assistant {
31 content,
32 citations,
33 tool_calls,
34 ..
35 } = self.message.clone()
36 else {
37 unreachable!("Completion responses will only return an assistant message")
38 };
39
40 (content, citations, tool_calls)
41 }
42}
43
44impl crate::telemetry::ProviderResponseExt for CompletionResponse {
45 type OutputMessage = Message;
46 type Usage = Usage;
47
48 fn get_response_id(&self) -> Option<String> {
49 Some(self.id.clone())
50 }
51
52 fn get_response_model_name(&self) -> Option<String> {
53 None
54 }
55
56 fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
57 vec![self.message.clone()]
58 }
59
60 fn get_text_response(&self) -> Option<String> {
61 let Message::Assistant { ref content, .. } = self.message else {
62 return None;
63 };
64
65 let res = content
66 .iter()
67 .filter_map(|x| {
68 if let AssistantContent::Text { text } = x {
69 Some(text.to_string())
70 } else {
71 None
72 }
73 })
74 .collect::<Vec<String>>()
75 .join("\n");
76
77 if res.is_empty() { None } else { Some(res) }
78 }
79
80 fn get_usage(&self) -> Option<Self::Usage> {
81 self.usage.clone()
82 }
83}
84
85#[derive(Debug, Deserialize, PartialEq, Eq, Clone, Serialize)]
86#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
87pub enum FinishReason {
88 MaxTokens,
89 StopSequence,
90 Complete,
91 Error,
92 ToolCall,
93}
94
95#[derive(Debug, Deserialize, Clone, Serialize)]
96pub struct Usage {
97 #[serde(default)]
98 pub billed_units: Option<BilledUnits>,
99 #[serde(default)]
100 pub tokens: Option<Tokens>,
101}
102
103impl GetTokenUsage for Usage {
104 fn token_usage(&self) -> Option<crate::completion::Usage> {
105 let mut usage = crate::completion::Usage::new();
106
107 if let Some(ref billed_units) = self.billed_units {
108 usage.input_tokens = billed_units.input_tokens.unwrap_or_default() as u64;
109 usage.output_tokens = billed_units.output_tokens.unwrap_or_default() as u64;
110 usage.total_tokens = usage.input_tokens + usage.output_tokens;
111 }
112
113 Some(usage)
114 }
115}
116
117#[derive(Debug, Deserialize, Clone, Serialize)]
118pub struct BilledUnits {
119 #[serde(default)]
120 pub output_tokens: Option<f64>,
121 #[serde(default)]
122 pub classifications: Option<f64>,
123 #[serde(default)]
124 pub search_units: Option<f64>,
125 #[serde(default)]
126 pub input_tokens: Option<f64>,
127}
128
129#[derive(Debug, Deserialize, Clone, Serialize)]
130pub struct Tokens {
131 #[serde(default)]
132 pub input_tokens: Option<f64>,
133 #[serde(default)]
134 pub output_tokens: Option<f64>,
135}
136
137impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
138 type Error = CompletionError;
139
140 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
141 let (content, _, tool_calls) = response.message();
142
143 let model_response = if !tool_calls.is_empty() {
144 OneOrMany::many(
145 tool_calls
146 .into_iter()
147 .filter_map(|tool_call| {
148 let ToolCallFunction { name, arguments } = tool_call.function?;
149 let id = tool_call.id.unwrap_or_else(|| name.clone());
150
151 Some(completion::AssistantContent::tool_call(id, name, arguments))
152 })
153 .collect::<Vec<_>>(),
154 )
155 .expect("We have atleast 1 tool call in this if block")
156 } else {
157 OneOrMany::many(content.into_iter().map(|content| match content {
158 AssistantContent::Text { text } => completion::AssistantContent::text(text),
159 AssistantContent::Thinking { thinking } => {
160 completion::AssistantContent::Reasoning(Reasoning {
161 id: None,
162 reasoning: vec![thinking],
163 })
164 }
165 }))
166 .map_err(|_| {
167 CompletionError::ResponseError(
168 "Response contained no message or tool call (empty)".to_owned(),
169 )
170 })?
171 };
172
173 let usage = response
174 .usage
175 .as_ref()
176 .and_then(|usage| usage.tokens.as_ref())
177 .map(|tokens| {
178 let input_tokens = tokens.input_tokens.unwrap_or(0.0);
179 let output_tokens = tokens.output_tokens.unwrap_or(0.0);
180
181 completion::Usage {
182 input_tokens: input_tokens as u64,
183 output_tokens: output_tokens as u64,
184 total_tokens: (input_tokens + output_tokens) as u64,
185 }
186 })
187 .unwrap_or_default();
188
189 Ok(completion::CompletionResponse {
190 choice: OneOrMany::many(model_response).expect("There is atleast one content"),
191 usage,
192 raw_response: response,
193 })
194 }
195}
196
197#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
198pub struct Document {
199 pub id: String,
200 pub data: HashMap<String, serde_json::Value>,
201}
202
203impl From<completion::Document> for Document {
204 fn from(document: completion::Document) -> Self {
205 let mut data: HashMap<String, serde_json::Value> = HashMap::new();
206
207 document
210 .additional_props
211 .into_iter()
212 .for_each(|(key, value)| {
213 data.insert(key, value.into());
214 });
215
216 data.insert("text".to_string(), document.text.into());
217
218 Self {
219 id: document.id,
220 data,
221 }
222 }
223}
224
225#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
226pub struct ToolCall {
227 #[serde(default)]
228 pub id: Option<String>,
229 #[serde(default)]
230 pub r#type: Option<ToolType>,
231 #[serde(default)]
232 pub function: Option<ToolCallFunction>,
233}
234
235#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
236pub struct ToolCallFunction {
237 pub name: String,
238 #[serde(with = "json_utils::stringified_json")]
239 pub arguments: serde_json::Value,
240}
241
242#[derive(Clone, Default, Debug, Deserialize, Serialize, PartialEq, Eq)]
243#[serde(rename_all = "lowercase")]
244pub enum ToolType {
245 #[default]
246 Function,
247}
248
249#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
250pub struct Tool {
251 pub r#type: ToolType,
252 pub function: Function,
253}
254
255#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
256pub struct Function {
257 pub name: String,
258 #[serde(default)]
259 pub description: Option<String>,
260 pub parameters: serde_json::Value,
261}
262
263impl From<completion::ToolDefinition> for Tool {
264 fn from(tool: completion::ToolDefinition) -> Self {
265 Self {
266 r#type: ToolType::default(),
267 function: Function {
268 name: tool.name,
269 description: Some(tool.description),
270 parameters: tool.parameters,
271 },
272 }
273 }
274}
275
276#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
277#[serde(tag = "role", rename_all = "lowercase")]
278pub enum Message {
279 User {
280 content: OneOrMany<UserContent>,
281 },
282
283 Assistant {
284 #[serde(default)]
285 content: Vec<AssistantContent>,
286 #[serde(default)]
287 citations: Vec<Citation>,
288 #[serde(default)]
289 tool_calls: Vec<ToolCall>,
290 #[serde(default)]
291 tool_plan: Option<String>,
292 },
293
294 Tool {
295 content: OneOrMany<ToolResultContent>,
296 tool_call_id: String,
297 },
298
299 System {
300 content: String,
301 },
302}
303
304#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
305#[serde(tag = "type", rename_all = "lowercase")]
306pub enum UserContent {
307 Text { text: String },
308 ImageUrl { image_url: ImageUrl },
309}
310
311#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
312#[serde(tag = "type", rename_all = "lowercase")]
313pub enum AssistantContent {
314 Text { text: String },
315 Thinking { thinking: String },
316}
317
318#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
319pub struct ImageUrl {
320 pub url: String,
321}
322
323#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
324pub enum ToolResultContent {
325 Text { text: String },
326 Document { document: Document },
327}
328
329#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
330pub struct Citation {
331 #[serde(default)]
332 pub start: Option<u32>,
333 #[serde(default)]
334 pub end: Option<u32>,
335 #[serde(default)]
336 pub text: Option<String>,
337 #[serde(rename = "type")]
338 pub citation_type: Option<CitationType>,
339 #[serde(default)]
340 pub sources: Vec<Source>,
341}
342
343#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
344#[serde(tag = "type", rename_all = "lowercase")]
345pub enum Source {
346 Document {
347 id: Option<String>,
348 document: Option<serde_json::Map<String, serde_json::Value>>,
349 },
350 Tool {
351 id: Option<String>,
352 tool_output: Option<serde_json::Map<String, serde_json::Value>>,
353 },
354}
355
356#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
357#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
358pub enum CitationType {
359 TextContent,
360 Plan,
361}
362
363impl TryFrom<message::Message> for Vec<Message> {
364 type Error = message::MessageError;
365
366 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
367 Ok(match message {
368 message::Message::User { content } => content
369 .into_iter()
370 .map(|content| match content {
371 message::UserContent::Text(message::Text { text }) => Ok(Message::User {
372 content: OneOrMany::one(UserContent::Text { text }),
373 }),
374 message::UserContent::ToolResult(message::ToolResult {
375 id, content, ..
376 }) => Ok(Message::Tool {
377 tool_call_id: id,
378 content: content.try_map(|content| match content {
379 message::ToolResultContent::Text(text) => {
380 Ok(ToolResultContent::Text { text: text.text })
381 }
382 _ => Err(message::MessageError::ConversionError(
383 "Only text tool result content is supported by Cohere".to_owned(),
384 )),
385 })?,
386 }),
387 _ => Err(message::MessageError::ConversionError(
388 "Only text content is supported by Cohere".to_owned(),
389 )),
390 })
391 .collect::<Result<Vec<_>, _>>()?,
392 message::Message::Assistant { content, .. } => {
393 let mut text_content = vec![];
394 let mut tool_calls = vec![];
395 content.into_iter().for_each(|content| match content {
396 message::AssistantContent::Text(message::Text { text }) => {
397 text_content.push(AssistantContent::Text { text });
398 }
399 message::AssistantContent::ToolCall(message::ToolCall {
400 id,
401 function:
402 message::ToolFunction {
403 name, arguments, ..
404 },
405 ..
406 }) => {
407 tool_calls.push(ToolCall {
408 id: Some(id),
409 r#type: Some(ToolType::Function),
410 function: Some(ToolCallFunction {
411 name,
412 arguments: serde_json::to_value(arguments).unwrap_or_default(),
413 }),
414 });
415 }
416 message::AssistantContent::Reasoning(Reasoning { reasoning, .. }) => {
417 let thinking = reasoning.join("\n");
418 text_content.push(AssistantContent::Thinking { thinking });
419 }
420 });
421
422 vec![Message::Assistant {
423 content: text_content,
424 citations: vec![],
425 tool_calls,
426 tool_plan: None,
427 }]
428 }
429 })
430 }
431}
432
433impl TryFrom<Message> for message::Message {
434 type Error = message::MessageError;
435
436 fn try_from(message: Message) -> Result<Self, Self::Error> {
437 match message {
438 Message::User { content } => Ok(message::Message::User {
439 content: content.map(|content| match content {
440 UserContent::Text { text } => {
441 message::UserContent::Text(message::Text { text })
442 }
443 UserContent::ImageUrl { image_url } => {
444 message::UserContent::image_url(image_url.url, None, None)
445 }
446 }),
447 }),
448 Message::Assistant {
449 content,
450 tool_calls,
451 ..
452 } => {
453 let mut content = content
454 .into_iter()
455 .map(|content| match content {
456 AssistantContent::Text { text } => message::AssistantContent::text(text),
457 AssistantContent::Thinking { thinking } => {
458 message::AssistantContent::Reasoning(Reasoning {
459 id: None,
460 reasoning: vec![thinking],
461 })
462 }
463 })
464 .collect::<Vec<_>>();
465
466 content.extend(tool_calls.into_iter().filter_map(|tool_call| {
467 let ToolCallFunction { name, arguments } = tool_call.function?;
468
469 Some(message::AssistantContent::tool_call(
470 tool_call.id.unwrap_or_else(|| name.clone()),
471 name,
472 arguments,
473 ))
474 }));
475
476 let content = OneOrMany::many(content).map_err(|_| {
477 message::MessageError::ConversionError(
478 "Expected either text content or tool calls".to_string(),
479 )
480 })?;
481
482 Ok(message::Message::Assistant { id: None, content })
483 }
484 Message::Tool {
485 content,
486 tool_call_id,
487 } => {
488 let content = content.try_map(|content| {
489 Ok(match content {
490 ToolResultContent::Text { text } => message::ToolResultContent::text(text),
491 ToolResultContent::Document { document } => {
492 message::ToolResultContent::text(
493 serde_json::to_string(&document.data).map_err(|e| {
494 message::MessageError::ConversionError(
495 format!("Failed to convert tool result document content into text: {e}"),
496 )
497 })?,
498 )
499 }
500 })
501 })?;
502
503 Ok(message::Message::User {
504 content: OneOrMany::one(message::UserContent::tool_result(
505 tool_call_id,
506 content,
507 )),
508 })
509 }
510 Message::System { content } => Ok(message::Message::user(content)),
511 }
512 }
513}
514
515#[derive(Clone)]
516pub struct CompletionModel<T = reqwest::Client> {
517 pub(crate) client: Client<T>,
518 pub model: String,
519}
520
521impl<T> CompletionModel<T>
522where
523 T: HttpClientExt,
524{
525 pub fn new(client: Client<T>, model: &str) -> Self {
526 Self {
527 client,
528 model: model.to_string(),
529 }
530 }
531
532 pub(crate) fn create_completion_request(
533 &self,
534 completion_request: CompletionRequest,
535 ) -> Result<Value, CompletionError> {
536 let mut partial_history = vec![];
538 if let Some(docs) = completion_request.normalized_documents() {
539 partial_history.push(docs);
540 }
541 partial_history.extend(completion_request.chat_history);
542
543 let mut full_history: Vec<Message> = completion_request
545 .preamble
546 .map_or_else(Vec::new, |preamble| {
547 vec![Message::System { content: preamble }]
548 });
549
550 full_history.extend(
552 partial_history
553 .into_iter()
554 .map(message::Message::try_into)
555 .collect::<Result<Vec<Vec<Message>>, _>>()?
556 .into_iter()
557 .flatten()
558 .collect::<Vec<_>>(),
559 );
560
561 let request = json!({
562 "model": self.model,
563 "messages": full_history,
564 "documents": completion_request.documents,
565 "temperature": completion_request.temperature,
566 "tools": completion_request.tools.into_iter().map(Tool::from).collect::<Vec<_>>(),
567 "tool_choice": if let Some(tool_choice) = completion_request.tool_choice && !matches!(tool_choice, ToolChoice::Auto) { tool_choice } else {
568 return Err(CompletionError::RequestError("\"auto\" is not an allowed tool_choice value in the Cohere API".into()))
569 },
570 });
571
572 if let Some(ref params) = completion_request.additional_params {
573 Ok(json_utils::merge(request.clone(), params.clone()))
574 } else {
575 Ok(request)
576 }
577 }
578}
579
580impl completion::CompletionModel for CompletionModel<reqwest::Client> {
581 type Response = CompletionResponse;
582 type StreamingResponse = StreamingCompletionResponse;
583
584 #[cfg_attr(feature = "worker", worker::send)]
585 async fn completion(
586 &self,
587 completion_request: completion::CompletionRequest,
588 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
589 let request = self.create_completion_request(completion_request)?;
590
591 let llm_span = if tracing::Span::current().is_disabled() {
592 info_span!(
593 target: "rig::completions",
594 "chat",
595 gen_ai.operation.name = "chat",
596 gen_ai.provider.name = "cohere",
597 gen_ai.request.model = self.model,
598 gen_ai.response.id = tracing::field::Empty,
599 gen_ai.response.model = self.model,
600 gen_ai.usage.output_tokens = tracing::field::Empty,
601 gen_ai.usage.input_tokens = tracing::field::Empty,
602 gen_ai.input.messages = serde_json::to_string(request.get("messages").expect("Converting request messages to JSON should not fail!")).unwrap(),
603 gen_ai.output.messages = tracing::field::Empty,
604 )
605 } else {
606 tracing::Span::current()
607 };
608
609 tracing::debug!(
610 "Cohere request: {}",
611 serde_json::to_string_pretty(&request)?
612 );
613
614 async {
615 let response = self
616 .client
617 .client()
618 .post("/v2/chat")
619 .json(&request)
620 .send()
621 .await
622 .map_err(|e| http_client::Error::Instance(e.into()))?;
623
624 if response.status().is_success() {
625 let text_response = response.text().await.map_err(|e| {
626 CompletionError::HttpError(http_client::Error::Instance(e.into()))
627 })?;
628 tracing::debug!("Cohere completion request: {}", text_response);
629
630 let json_response: CompletionResponse = serde_json::from_str(&text_response)?;
631 let span = tracing::Span::current();
632 span.record_token_usage(&json_response.usage);
633 span.record_model_output(&json_response.message);
634 span.record_response_metadata(&json_response);
635 tracing::debug!("Cohere completion response: {}", text_response);
636 let completion: completion::CompletionResponse<CompletionResponse> =
637 json_response.try_into()?;
638 Ok(completion)
639 } else {
640 Err(CompletionError::ProviderError(
641 response.text().await.map_err(|e| {
642 CompletionError::HttpError(http_client::Error::Instance(e.into()))
643 })?,
644 ))
645 }
646 }
647 .instrument(llm_span)
648 .await
649 }
650
651 #[cfg_attr(feature = "worker", worker::send)]
652 async fn stream(
653 &self,
654 request: CompletionRequest,
655 ) -> Result<
656 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
657 CompletionError,
658 > {
659 CompletionModel::stream(self, request).await
660 }
661}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_path_to_error::deserialize;

    /// A tool-call response deserializes with its id, finish reason, usage and tool call intact.
    #[test]
    fn test_deserialize_completion_response() {
        let json_data = r#"
        {
            "id": "abc123",
            "message": {
                "role": "assistant",
                "tool_plan": "I will use the subtract tool to find the difference between 2 and 5.",
                "tool_calls": [
                    {
                        "id": "subtract_sm6ps6fb6y9f",
                        "type": "function",
                        "function": {
                            "name": "subtract",
                            "arguments": "{\"x\":5,\"y\":2}"
                        }
                    }
                ]
            },
            "finish_reason": "TOOL_CALL",
            "usage": {
                "billed_units": {
                    "input_tokens": 78,
                    "output_tokens": 27
                },
                "tokens": {
                    "input_tokens": 1028,
                    "output_tokens": 63
                }
            }
        }
        "#;

        let mut deserializer = serde_json::Deserializer::from_str(json_data);
        let result: Result<CompletionResponse, _> = deserialize(&mut deserializer);

        let response = result.unwrap();
        let (_, citations, tool_calls) = response.message();
        let CompletionResponse {
            id,
            finish_reason,
            usage,
            ..
        } = response;

        assert_eq!(id, "abc123");
        assert_eq!(finish_reason, FinishReason::ToolCall);

        let Usage {
            billed_units,
            tokens,
        } = usage.unwrap();
        let BilledUnits {
            input_tokens: billed_input_tokens,
            output_tokens: billed_output_tokens,
            ..
        } = billed_units.unwrap();
        let Tokens {
            input_tokens,
            output_tokens,
        } = tokens.unwrap();

        assert_eq!(billed_input_tokens.unwrap(), 78.0);
        assert_eq!(billed_output_tokens.unwrap(), 27.0);
        assert_eq!(input_tokens.unwrap(), 1028.0);
        assert_eq!(output_tokens.unwrap(), 63.0);

        assert!(citations.is_empty());
        assert_eq!(tool_calls.len(), 1);

        let ToolCallFunction { name, arguments } = tool_calls[0].function.clone().unwrap();

        assert_eq!(name, "subtract");
        assert_eq!(arguments, serde_json::json!({"x": 5, "y": 2}));
    }

    /// A generic user text message maps to exactly one Cohere user message and back.
    #[test]
    fn test_convert_completion_message_to_message_and_back() {
        let completion_message = completion::Message::User {
            content: OneOrMany::one(completion::message::UserContent::Text(
                completion::message::Text {
                    text: "Hello, world!".to_string(),
                },
            )),
        };

        let messages: Vec<Message> = completion_message.try_into().unwrap();
        assert_eq!(
            messages,
            vec![Message::User {
                content: OneOrMany::one(UserContent::Text {
                    text: "Hello, world!".to_string(),
                }),
            }]
        );

        let converted_back: Vec<completion::Message> = messages
            .into_iter()
            .map(|msg| msg.try_into().unwrap())
            .collect::<Vec<_>>();

        assert_eq!(converted_back.len(), 1);
        assert!(matches!(
            converted_back[0],
            completion::Message::User { .. }
        ));
    }

    /// A Cohere user text message survives a round trip through the generic message type.
    #[test]
    fn test_convert_message_to_completion_message_and_back() {
        let message = Message::User {
            content: OneOrMany::one(UserContent::Text {
                text: "Hello, world!".to_string(),
            }),
        };

        let completion_message: completion::Message = message.clone().try_into().unwrap();
        let converted_back: Vec<Message> = completion_message.try_into().unwrap();

        assert_eq!(converted_back, vec![message]);
    }
}