1use crate::{
2 OneOrMany,
3 completion::{self, CompletionError, GetTokenUsage},
4 http_client::{self, HttpClientExt},
5 json_utils,
6 message::{self, Reasoning, ToolChoice},
7 telemetry::SpanCombinator,
8};
9use std::collections::HashMap;
10
11use super::client::Client;
12use crate::completion::CompletionRequest;
13use crate::providers::cohere::streaming::StreamingCompletionResponse;
14use http::Method;
15use serde::{Deserialize, Serialize};
16use serde_json::{Value, json};
17use tracing::{Instrument, info_span};
18
19#[derive(Debug, Deserialize, Serialize)]
20pub struct CompletionResponse {
21 pub id: String,
22 pub finish_reason: FinishReason,
23 message: Message,
24 #[serde(default)]
25 pub usage: Option<Usage>,
26}
27
28impl CompletionResponse {
29 pub fn message(&self) -> (Vec<AssistantContent>, Vec<Citation>, Vec<ToolCall>) {
31 let Message::Assistant {
32 content,
33 citations,
34 tool_calls,
35 ..
36 } = self.message.clone()
37 else {
38 unreachable!("Completion responses will only return an assistant message")
39 };
40
41 (content, citations, tool_calls)
42 }
43}
44
45impl crate::telemetry::ProviderResponseExt for CompletionResponse {
46 type OutputMessage = Message;
47 type Usage = Usage;
48
49 fn get_response_id(&self) -> Option<String> {
50 Some(self.id.clone())
51 }
52
53 fn get_response_model_name(&self) -> Option<String> {
54 None
55 }
56
57 fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
58 vec![self.message.clone()]
59 }
60
61 fn get_text_response(&self) -> Option<String> {
62 let Message::Assistant { ref content, .. } = self.message else {
63 return None;
64 };
65
66 let res = content
67 .iter()
68 .filter_map(|x| {
69 if let AssistantContent::Text { text } = x {
70 Some(text.to_string())
71 } else {
72 None
73 }
74 })
75 .collect::<Vec<String>>()
76 .join("\n");
77
78 if res.is_empty() { None } else { Some(res) }
79 }
80
81 fn get_usage(&self) -> Option<Self::Usage> {
82 self.usage.clone()
83 }
84}
85
86#[derive(Debug, Deserialize, PartialEq, Eq, Clone, Serialize)]
87#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
88pub enum FinishReason {
89 MaxTokens,
90 StopSequence,
91 Complete,
92 Error,
93 ToolCall,
94}
95
96#[derive(Debug, Deserialize, Clone, Serialize)]
97pub struct Usage {
98 #[serde(default)]
99 pub billed_units: Option<BilledUnits>,
100 #[serde(default)]
101 pub tokens: Option<Tokens>,
102}
103
104impl GetTokenUsage for Usage {
105 fn token_usage(&self) -> Option<crate::completion::Usage> {
106 let mut usage = crate::completion::Usage::new();
107
108 if let Some(ref billed_units) = self.billed_units {
109 usage.input_tokens = billed_units.input_tokens.unwrap_or_default() as u64;
110 usage.output_tokens = billed_units.output_tokens.unwrap_or_default() as u64;
111 usage.total_tokens = usage.input_tokens + usage.output_tokens;
112 }
113
114 Some(usage)
115 }
116}
117
118#[derive(Debug, Deserialize, Clone, Serialize)]
119pub struct BilledUnits {
120 #[serde(default)]
121 pub output_tokens: Option<f64>,
122 #[serde(default)]
123 pub classifications: Option<f64>,
124 #[serde(default)]
125 pub search_units: Option<f64>,
126 #[serde(default)]
127 pub input_tokens: Option<f64>,
128}
129
130#[derive(Debug, Deserialize, Clone, Serialize)]
131pub struct Tokens {
132 #[serde(default)]
133 pub input_tokens: Option<f64>,
134 #[serde(default)]
135 pub output_tokens: Option<f64>,
136}
137
138impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
139 type Error = CompletionError;
140
141 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
142 let (content, _, tool_calls) = response.message();
143
144 let model_response = if !tool_calls.is_empty() {
145 OneOrMany::many(
146 tool_calls
147 .into_iter()
148 .filter_map(|tool_call| {
149 let ToolCallFunction { name, arguments } = tool_call.function?;
150 let id = tool_call.id.unwrap_or_else(|| name.clone());
151
152 Some(completion::AssistantContent::tool_call(id, name, arguments))
153 })
154 .collect::<Vec<_>>(),
155 )
156 .expect("We have atleast 1 tool call in this if block")
157 } else {
158 OneOrMany::many(content.into_iter().map(|content| match content {
159 AssistantContent::Text { text } => completion::AssistantContent::text(text),
160 AssistantContent::Thinking { thinking } => {
161 completion::AssistantContent::Reasoning(Reasoning {
162 id: None,
163 reasoning: vec![thinking],
164 signature: None,
165 })
166 }
167 }))
168 .map_err(|_| {
169 CompletionError::ResponseError(
170 "Response contained no message or tool call (empty)".to_owned(),
171 )
172 })?
173 };
174
175 let usage = response
176 .usage
177 .as_ref()
178 .and_then(|usage| usage.tokens.as_ref())
179 .map(|tokens| {
180 let input_tokens = tokens.input_tokens.unwrap_or(0.0);
181 let output_tokens = tokens.output_tokens.unwrap_or(0.0);
182
183 completion::Usage {
184 input_tokens: input_tokens as u64,
185 output_tokens: output_tokens as u64,
186 total_tokens: (input_tokens + output_tokens) as u64,
187 }
188 })
189 .unwrap_or_default();
190
191 Ok(completion::CompletionResponse {
192 choice: OneOrMany::many(model_response).expect("There is atleast one content"),
193 usage,
194 raw_response: response,
195 })
196 }
197}
198
199#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
200pub struct Document {
201 pub id: String,
202 pub data: HashMap<String, serde_json::Value>,
203}
204
205impl From<completion::Document> for Document {
206 fn from(document: completion::Document) -> Self {
207 let mut data: HashMap<String, serde_json::Value> = HashMap::new();
208
209 document
212 .additional_props
213 .into_iter()
214 .for_each(|(key, value)| {
215 data.insert(key, value.into());
216 });
217
218 data.insert("text".to_string(), document.text.into());
219
220 Self {
221 id: document.id,
222 data,
223 }
224 }
225}
226
227#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
228pub struct ToolCall {
229 #[serde(default)]
230 pub id: Option<String>,
231 #[serde(default)]
232 pub r#type: Option<ToolType>,
233 #[serde(default)]
234 pub function: Option<ToolCallFunction>,
235}
236
237#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
238pub struct ToolCallFunction {
239 pub name: String,
240 #[serde(with = "json_utils::stringified_json")]
241 pub arguments: serde_json::Value,
242}
243
244#[derive(Clone, Default, Debug, Deserialize, Serialize, PartialEq, Eq)]
245#[serde(rename_all = "lowercase")]
246pub enum ToolType {
247 #[default]
248 Function,
249}
250
251#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
252pub struct Tool {
253 pub r#type: ToolType,
254 pub function: Function,
255}
256
257#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
258pub struct Function {
259 pub name: String,
260 #[serde(default)]
261 pub description: Option<String>,
262 pub parameters: serde_json::Value,
263}
264
265impl From<completion::ToolDefinition> for Tool {
266 fn from(tool: completion::ToolDefinition) -> Self {
267 Self {
268 r#type: ToolType::default(),
269 function: Function {
270 name: tool.name,
271 description: Some(tool.description),
272 parameters: tool.parameters,
273 },
274 }
275 }
276}
277
278#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
279#[serde(tag = "role", rename_all = "lowercase")]
280pub enum Message {
281 User {
282 content: OneOrMany<UserContent>,
283 },
284
285 Assistant {
286 #[serde(default)]
287 content: Vec<AssistantContent>,
288 #[serde(default)]
289 citations: Vec<Citation>,
290 #[serde(default)]
291 tool_calls: Vec<ToolCall>,
292 #[serde(default)]
293 tool_plan: Option<String>,
294 },
295
296 Tool {
297 content: OneOrMany<ToolResultContent>,
298 tool_call_id: String,
299 },
300
301 System {
302 content: String,
303 },
304}
305
306#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
307#[serde(tag = "type", rename_all = "lowercase")]
308pub enum UserContent {
309 Text { text: String },
310 ImageUrl { image_url: ImageUrl },
311}
312
313#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
314#[serde(tag = "type", rename_all = "lowercase")]
315pub enum AssistantContent {
316 Text { text: String },
317 Thinking { thinking: String },
318}
319
320#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
321pub struct ImageUrl {
322 pub url: String,
323}
324
325#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
326pub enum ToolResultContent {
327 Text { text: String },
328 Document { document: Document },
329}
330
331#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
332pub struct Citation {
333 #[serde(default)]
334 pub start: Option<u32>,
335 #[serde(default)]
336 pub end: Option<u32>,
337 #[serde(default)]
338 pub text: Option<String>,
339 #[serde(rename = "type")]
340 pub citation_type: Option<CitationType>,
341 #[serde(default)]
342 pub sources: Vec<Source>,
343}
344
345#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
346#[serde(tag = "type", rename_all = "lowercase")]
347pub enum Source {
348 Document {
349 id: Option<String>,
350 document: Option<serde_json::Map<String, serde_json::Value>>,
351 },
352 Tool {
353 id: Option<String>,
354 tool_output: Option<serde_json::Map<String, serde_json::Value>>,
355 },
356}
357
358#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
359#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
360pub enum CitationType {
361 TextContent,
362 Plan,
363}
364
365impl TryFrom<message::Message> for Vec<Message> {
366 type Error = message::MessageError;
367
368 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
369 Ok(match message {
370 message::Message::User { content } => content
371 .into_iter()
372 .map(|content| match content {
373 message::UserContent::Text(message::Text { text }) => Ok(Message::User {
374 content: OneOrMany::one(UserContent::Text { text }),
375 }),
376 message::UserContent::ToolResult(message::ToolResult {
377 id, content, ..
378 }) => Ok(Message::Tool {
379 tool_call_id: id,
380 content: content.try_map(|content| match content {
381 message::ToolResultContent::Text(text) => {
382 Ok(ToolResultContent::Text { text: text.text })
383 }
384 _ => Err(message::MessageError::ConversionError(
385 "Only text tool result content is supported by Cohere".to_owned(),
386 )),
387 })?,
388 }),
389 _ => Err(message::MessageError::ConversionError(
390 "Only text content is supported by Cohere".to_owned(),
391 )),
392 })
393 .collect::<Result<Vec<_>, _>>()?,
394 message::Message::Assistant { content, .. } => {
395 let mut text_content = vec![];
396 let mut tool_calls = vec![];
397 content.into_iter().for_each(|content| match content {
398 message::AssistantContent::Text(message::Text { text }) => {
399 text_content.push(AssistantContent::Text { text });
400 }
401 message::AssistantContent::ToolCall(message::ToolCall {
402 id,
403 function:
404 message::ToolFunction {
405 name, arguments, ..
406 },
407 ..
408 }) => {
409 tool_calls.push(ToolCall {
410 id: Some(id),
411 r#type: Some(ToolType::Function),
412 function: Some(ToolCallFunction {
413 name,
414 arguments: serde_json::to_value(arguments).unwrap_or_default(),
415 }),
416 });
417 }
418 message::AssistantContent::Reasoning(Reasoning { reasoning, .. }) => {
419 let thinking = reasoning.join("\n");
420 text_content.push(AssistantContent::Thinking { thinking });
421 }
422 });
423
424 vec![Message::Assistant {
425 content: text_content,
426 citations: vec![],
427 tool_calls,
428 tool_plan: None,
429 }]
430 }
431 })
432 }
433}
434
435impl TryFrom<Message> for message::Message {
436 type Error = message::MessageError;
437
438 fn try_from(message: Message) -> Result<Self, Self::Error> {
439 match message {
440 Message::User { content } => Ok(message::Message::User {
441 content: content.map(|content| match content {
442 UserContent::Text { text } => {
443 message::UserContent::Text(message::Text { text })
444 }
445 UserContent::ImageUrl { image_url } => {
446 message::UserContent::image_url(image_url.url, None, None)
447 }
448 }),
449 }),
450 Message::Assistant {
451 content,
452 tool_calls,
453 ..
454 } => {
455 let mut content = content
456 .into_iter()
457 .map(|content| match content {
458 AssistantContent::Text { text } => message::AssistantContent::text(text),
459 AssistantContent::Thinking { thinking } => {
460 message::AssistantContent::Reasoning(Reasoning {
461 id: None,
462 reasoning: vec![thinking],
463 signature: None,
464 })
465 }
466 })
467 .collect::<Vec<_>>();
468
469 content.extend(tool_calls.into_iter().filter_map(|tool_call| {
470 let ToolCallFunction { name, arguments } = tool_call.function?;
471
472 Some(message::AssistantContent::tool_call(
473 tool_call.id.unwrap_or_else(|| name.clone()),
474 name,
475 arguments,
476 ))
477 }));
478
479 let content = OneOrMany::many(content).map_err(|_| {
480 message::MessageError::ConversionError(
481 "Expected either text content or tool calls".to_string(),
482 )
483 })?;
484
485 Ok(message::Message::Assistant { id: None, content })
486 }
487 Message::Tool {
488 content,
489 tool_call_id,
490 } => {
491 let content = content.try_map(|content| {
492 Ok(match content {
493 ToolResultContent::Text { text } => message::ToolResultContent::text(text),
494 ToolResultContent::Document { document } => {
495 message::ToolResultContent::text(
496 serde_json::to_string(&document.data).map_err(|e| {
497 message::MessageError::ConversionError(
498 format!("Failed to convert tool result document content into text: {e}"),
499 )
500 })?,
501 )
502 }
503 })
504 })?;
505
506 Ok(message::Message::User {
507 content: OneOrMany::one(message::UserContent::tool_result(
508 tool_call_id,
509 content,
510 )),
511 })
512 }
513 Message::System { content } => Ok(message::Message::user(content)),
514 }
515 }
516}
517
518#[derive(Clone)]
519pub struct CompletionModel<T = reqwest::Client> {
520 pub(crate) client: Client<T>,
521 pub model: String,
522}
523
524impl<T> CompletionModel<T>
525where
526 T: HttpClientExt,
527{
528 pub fn new(client: Client<T>, model: &str) -> Self {
529 Self {
530 client,
531 model: model.to_string(),
532 }
533 }
534
535 pub(crate) fn create_completion_request(
536 &self,
537 completion_request: CompletionRequest,
538 ) -> Result<Value, CompletionError> {
539 let mut partial_history = vec![];
541 if let Some(docs) = completion_request.normalized_documents() {
542 partial_history.push(docs);
543 }
544 partial_history.extend(completion_request.chat_history);
545
546 let mut full_history: Vec<Message> = completion_request
548 .preamble
549 .map_or_else(Vec::new, |preamble| {
550 vec![Message::System { content: preamble }]
551 });
552
553 full_history.extend(
555 partial_history
556 .into_iter()
557 .map(message::Message::try_into)
558 .collect::<Result<Vec<Vec<Message>>, _>>()?
559 .into_iter()
560 .flatten()
561 .collect::<Vec<_>>(),
562 );
563
564 let request = json!({
565 "model": self.model,
566 "messages": full_history,
567 "documents": completion_request.documents,
568 "temperature": completion_request.temperature,
569 "tools": completion_request.tools.into_iter().map(Tool::from).collect::<Vec<_>>(),
570 "tool_choice": if let Some(tool_choice) = completion_request.tool_choice && !matches!(tool_choice, ToolChoice::Auto) { tool_choice } else {
571 return Err(CompletionError::RequestError("\"auto\" is not an allowed tool_choice value in the Cohere API".into()))
572 },
573 });
574
575 if let Some(ref params) = completion_request.additional_params {
576 Ok(json_utils::merge(request.clone(), params.clone()))
577 } else {
578 Ok(request)
579 }
580 }
581}
582
583impl<T> completion::CompletionModel for CompletionModel<T>
584where
585 T: HttpClientExt + Clone + 'static,
586{
587 type Response = CompletionResponse;
588 type StreamingResponse = StreamingCompletionResponse;
589
590 #[cfg_attr(feature = "worker", worker::send)]
591 async fn completion(
592 &self,
593 completion_request: completion::CompletionRequest,
594 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
595 let request = self.create_completion_request(completion_request)?;
596
597 let llm_span = if tracing::Span::current().is_disabled() {
598 info_span!(
599 target: "rig::completions",
600 "chat",
601 gen_ai.operation.name = "chat",
602 gen_ai.provider.name = "cohere",
603 gen_ai.request.model = self.model,
604 gen_ai.response.id = tracing::field::Empty,
605 gen_ai.response.model = self.model,
606 gen_ai.usage.output_tokens = tracing::field::Empty,
607 gen_ai.usage.input_tokens = tracing::field::Empty,
608 gen_ai.input.messages = serde_json::to_string(request.get("messages").expect("Converting request messages to JSON should not fail!")).unwrap(),
609 gen_ai.output.messages = tracing::field::Empty,
610 )
611 } else {
612 tracing::Span::current()
613 };
614
615 tracing::debug!(
616 "Cohere request: {}",
617 serde_json::to_string_pretty(&request)?
618 );
619
620 let req_body = serde_json::to_vec(&request)?;
621
622 let req = self
623 .client
624 .req(Method::POST, "/v2/chat")?
625 .body(req_body)
626 .unwrap();
627
628 async {
629 let response = self
630 .client
631 .http_client()
632 .send::<_, bytes::Bytes>(req)
633 .await
634 .map_err(|e| http_client::Error::Instance(e.into()))?;
635
636 let status = response.status();
637 let body = response.into_body().into_future().await?.to_owned();
638
639 if status.is_success() {
640 let json_response: CompletionResponse = serde_json::from_slice(&body)?;
641 let span = tracing::Span::current();
642 span.record_token_usage(&json_response.usage);
643 span.record_model_output(&json_response.message);
644 span.record_response_metadata(&json_response);
645 tracing::debug!(
646 "Cohere completion response: {}",
647 serde_json::to_string_pretty(&json_response)?
648 );
649 let completion: completion::CompletionResponse<CompletionResponse> =
650 json_response.try_into()?;
651 Ok(completion)
652 } else {
653 Err(CompletionError::ProviderError(
654 String::from_utf8_lossy(&body).to_string(),
655 ))
656 }
657 }
658 .instrument(llm_span)
659 .await
660 }
661
662 #[cfg_attr(feature = "worker", worker::send)]
663 async fn stream(
664 &self,
665 request: CompletionRequest,
666 ) -> Result<
667 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
668 CompletionError,
669 > {
670 CompletionModel::stream(self, request).await
671 }
672}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_path_to_error::deserialize;

    /// A tool-call response deserializes with the expected id, finish reason,
    /// usage numbers and tool call, and with empty content and citations.
    #[test]
    fn test_deserialize_completion_response() {
        let json_data = r#"
        {
            "id": "abc123",
            "message": {
                "role": "assistant",
                "tool_plan": "I will use the subtract tool to find the difference between 2 and 5.",
                "tool_calls": [
                    {
                        "id": "subtract_sm6ps6fb6y9f",
                        "type": "function",
                        "function": {
                            "name": "subtract",
                            "arguments": "{\"x\":5,\"y\":2}"
                        }
                    }
                ]
            },
            "finish_reason": "TOOL_CALL",
            "usage": {
                "billed_units": {
                    "input_tokens": 78,
                    "output_tokens": 27
                },
                "tokens": {
                    "input_tokens": 1028,
                    "output_tokens": 63
                }
            }
        }
        "#;

        let mut deserializer = serde_json::Deserializer::from_str(json_data);
        let result: Result<CompletionResponse, _> = deserialize(&mut deserializer);

        let response = result.unwrap();
        let (content, citations, tool_calls) = response.message();
        let CompletionResponse {
            id,
            finish_reason,
            usage,
            ..
        } = response;

        assert_eq!(id, "abc123");
        assert_eq!(finish_reason, FinishReason::ToolCall);

        let Usage {
            billed_units,
            tokens,
        } = usage.unwrap();
        let BilledUnits {
            input_tokens: billed_input_tokens,
            output_tokens: billed_output_tokens,
            ..
        } = billed_units.unwrap();
        let Tokens {
            input_tokens,
            output_tokens,
        } = tokens.unwrap();

        assert_eq!(billed_input_tokens, Some(78.0));
        assert_eq!(billed_output_tokens, Some(27.0));
        assert_eq!(input_tokens, Some(1028.0));
        assert_eq!(output_tokens, Some(63.0));

        // The payload has no `content` or `citations`, so both default to empty.
        assert!(content.is_empty());
        assert!(citations.is_empty());
        assert_eq!(tool_calls.len(), 1);

        let ToolCallFunction { name, arguments } = tool_calls[0].function.clone().unwrap();

        assert_eq!(name, "subtract");
        assert_eq!(arguments, serde_json::json!({"x": 5, "y": 2}));
    }

    /// A generic user text message converts to a single Cohere user message,
    /// and converting back keeps the text.
    #[test]
    fn test_convert_completion_message_to_message_and_back() {
        let completion_message = completion::Message::User {
            content: OneOrMany::one(completion::message::UserContent::Text(
                completion::message::Text {
                    text: "Hello, world!".to_string(),
                },
            )),
        };

        let messages: Vec<Message> = completion_message.try_into().unwrap();
        assert_eq!(
            messages,
            vec![Message::User {
                content: OneOrMany::one(UserContent::Text {
                    text: "Hello, world!".to_string(),
                }),
            }]
        );

        let converted_back: Vec<completion::Message> = messages
            .into_iter()
            .map(|msg| msg.try_into().unwrap())
            .collect::<Vec<_>>();
        assert_eq!(converted_back.len(), 1);

        for message in converted_back {
            let completion::Message::User { content } = message else {
                panic!("expected a user message after the round trip");
            };

            let texts: Vec<String> = content
                .into_iter()
                .filter_map(|part| match part {
                    completion::message::UserContent::Text(completion::message::Text { text }) => {
                        Some(text)
                    }
                    _ => None,
                })
                .collect();

            assert_eq!(texts, vec!["Hello, world!".to_string()]);
        }
    }

    /// A Cohere user text message survives a round trip through the generic
    /// message type unchanged.
    #[test]
    fn test_convert_message_to_completion_message_and_back() {
        let message = Message::User {
            content: OneOrMany::one(UserContent::Text {
                text: "Hello, world!".to_string(),
            }),
        };

        let completion_message: completion::Message = message.clone().try_into().unwrap();
        let converted_back: Vec<Message> = completion_message.try_into().unwrap();

        assert_eq!(converted_back, vec![message]);
    }
}