rig/providers/cohere/
completion.rs

1use crate::{
2    OneOrMany,
3    completion::{self, CompletionError, GetTokenUsage},
4    http_client::{self, HttpClientExt},
5    json_utils,
6    message::{self, Reasoning, ToolChoice},
7    telemetry::SpanCombinator,
8};
9use std::collections::HashMap;
10
11use super::client::Client;
12use crate::completion::CompletionRequest;
13use crate::providers::cohere::streaming::StreamingCompletionResponse;
14use serde::{Deserialize, Serialize};
15use serde_json::{Value, json};
16use tracing::{Instrument, info_span};
17
/// Body of a successful, non-streaming response from Cohere's `/v2/chat` endpoint.
#[derive(Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    /// Identifier Cohere assigned to this response (reported as the response id in telemetry).
    pub id: String,
    /// Why generation stopped.
    pub finish_reason: FinishReason,
    /// The generated message; read it through [`CompletionResponse::message`].
    message: Message,
    /// Token accounting; `None` when the key is absent from the payload.
    #[serde(default)]
    pub usage: Option<Usage>,
}
26
27impl CompletionResponse {
28    /// Return that parts of the response for assistant messages w/o dealing with the other variants
29    pub fn message(&self) -> (Vec<AssistantContent>, Vec<Citation>, Vec<ToolCall>) {
30        let Message::Assistant {
31            content,
32            citations,
33            tool_calls,
34            ..
35        } = self.message.clone()
36        else {
37            unreachable!("Completion responses will only return an assistant message")
38        };
39
40        (content, citations, tool_calls)
41    }
42}
43
44impl crate::telemetry::ProviderResponseExt for CompletionResponse {
45    type OutputMessage = Message;
46    type Usage = Usage;
47
48    fn get_response_id(&self) -> Option<String> {
49        Some(self.id.clone())
50    }
51
52    fn get_response_model_name(&self) -> Option<String> {
53        None
54    }
55
56    fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
57        vec![self.message.clone()]
58    }
59
60    fn get_text_response(&self) -> Option<String> {
61        let Message::Assistant { ref content, .. } = self.message else {
62            return None;
63        };
64
65        let res = content
66            .iter()
67            .filter_map(|x| {
68                if let AssistantContent::Text { text } = x {
69                    Some(text.to_string())
70                } else {
71                    None
72                }
73            })
74            .collect::<Vec<String>>()
75            .join("\n");
76
77        if res.is_empty() { None } else { Some(res) }
78    }
79
80    fn get_usage(&self) -> Option<Self::Usage> {
81        self.usage.clone()
82    }
83}
84
/// Reason Cohere gives for ending generation, taken from the response's `finish_reason`.
///
/// Variants are (de)serialized in `SCREAMING_SNAKE_CASE`, e.g. `ToolCall` <-> `"TOOL_CALL"`.
/// Any string not listed here makes deserialization of the enclosing response fail.
#[derive(Debug, Deserialize, PartialEq, Eq, Clone, Serialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum FinishReason {
    MaxTokens,
    StopSequence,
    Complete,
    Error,
    ToolCall,
}
94
/// Token accounting attached to a Cohere chat response.
///
/// Cohere reports two sets of counts that can differ considerably; the sample payload in the
/// tests has 78 billed vs. 1028 raw input tokens. The `GetTokenUsage` impl reads
/// `billed_units`, while the conversion into `completion::CompletionResponse` reads `tokens`.
#[derive(Debug, Deserialize, Clone, Serialize)]
pub struct Usage {
    /// Billed counts; `None` when the key is absent.
    #[serde(default)]
    pub billed_units: Option<BilledUnits>,
    /// Raw token counts; `None` when the key is absent.
    #[serde(default)]
    pub tokens: Option<Tokens>,
}
102
103impl GetTokenUsage for Usage {
104    fn token_usage(&self) -> Option<crate::completion::Usage> {
105        let mut usage = crate::completion::Usage::new();
106
107        if let Some(ref billed_units) = self.billed_units {
108            usage.input_tokens = billed_units.input_tokens.unwrap_or_default() as u64;
109            usage.output_tokens = billed_units.output_tokens.unwrap_or_default() as u64;
110            usage.total_tokens = usage.input_tokens + usage.output_tokens;
111        }
112
113        Some(usage)
114    }
115}
116
/// Billed usage units reported by Cohere.
///
/// Values arrive as JSON numbers and are kept as `f64`; each field is `None` when its key is
/// missing.
#[derive(Debug, Deserialize, Clone, Serialize)]
pub struct BilledUnits {
    #[serde(default)]
    pub output_tokens: Option<f64>,
    #[serde(default)]
    pub classifications: Option<f64>,
    #[serde(default)]
    pub search_units: Option<f64>,
    #[serde(default)]
    pub input_tokens: Option<f64>,
}
128
/// Raw token counts reported by Cohere (as opposed to [`BilledUnits`]).
///
/// Values are kept as `f64` exactly as received; each is `None` when its key is missing.
#[derive(Debug, Deserialize, Clone, Serialize)]
pub struct Tokens {
    #[serde(default)]
    pub input_tokens: Option<f64>,
    #[serde(default)]
    pub output_tokens: Option<f64>,
}
136
137impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
138    type Error = CompletionError;
139
140    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
141        let (content, _, tool_calls) = response.message();
142
143        let model_response = if !tool_calls.is_empty() {
144            OneOrMany::many(
145                tool_calls
146                    .into_iter()
147                    .filter_map(|tool_call| {
148                        let ToolCallFunction { name, arguments } = tool_call.function?;
149                        let id = tool_call.id.unwrap_or_else(|| name.clone());
150
151                        Some(completion::AssistantContent::tool_call(id, name, arguments))
152                    })
153                    .collect::<Vec<_>>(),
154            )
155            .expect("We have atleast 1 tool call in this if block")
156        } else {
157            OneOrMany::many(content.into_iter().map(|content| match content {
158                AssistantContent::Text { text } => completion::AssistantContent::text(text),
159                AssistantContent::Thinking { thinking } => {
160                    completion::AssistantContent::Reasoning(Reasoning {
161                        id: None,
162                        reasoning: vec![thinking],
163                    })
164                }
165            }))
166            .map_err(|_| {
167                CompletionError::ResponseError(
168                    "Response contained no message or tool call (empty)".to_owned(),
169                )
170            })?
171        };
172
173        let usage = response
174            .usage
175            .as_ref()
176            .and_then(|usage| usage.tokens.as_ref())
177            .map(|tokens| {
178                let input_tokens = tokens.input_tokens.unwrap_or(0.0);
179                let output_tokens = tokens.output_tokens.unwrap_or(0.0);
180
181                completion::Usage {
182                    input_tokens: input_tokens as u64,
183                    output_tokens: output_tokens as u64,
184                    total_tokens: (input_tokens + output_tokens) as u64,
185                }
186            })
187            .unwrap_or_default();
188
189        Ok(completion::CompletionResponse {
190            choice: OneOrMany::many(model_response).expect("There is atleast one content"),
191            usage,
192            raw_response: response,
193        })
194    }
195}
196
/// A document in Cohere's `{ id, data }` shape.
///
/// `data` is a free-form JSON map. When built from a rig document, it also holds the document
/// text under the `"text"` key.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
pub struct Document {
    pub id: String,
    pub data: HashMap<String, serde_json::Value>,
}
202
203impl From<completion::Document> for Document {
204    fn from(document: completion::Document) -> Self {
205        let mut data: HashMap<String, serde_json::Value> = HashMap::new();
206
207        // We use `.into()` here explicitly since the `document.additional_props` type will likely
208        //  evolve into `serde_json::Value` in the future.
209        document
210            .additional_props
211            .into_iter()
212            .for_each(|(key, value)| {
213                data.insert(key, value.into());
214            });
215
216        data.insert("text".to_string(), document.text.into());
217
218        Self {
219            id: document.id,
220            data,
221        }
222    }
223}
224
/// A tool call in Cohere's format, either returned by the model or replayed from chat history.
///
/// All fields are optional when deserializing. When converting into rig types, calls without a
/// `function` are skipped and a missing `id` falls back to the function name.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
pub struct ToolCall {
    #[serde(default)]
    pub id: Option<String>,
    /// Tool kind; [`ToolType`] only models `"function"`.
    #[serde(default)]
    pub r#type: Option<ToolType>,
    #[serde(default)]
    pub function: Option<ToolCallFunction>,
}
234
/// Name and arguments of a function tool call.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
pub struct ToolCallFunction {
    pub name: String,
    /// Arguments as parsed JSON. On the wire they are a JSON-encoded *string*
    /// (e.g. `"{\"x\":5,\"y\":2}"`); `json_utils::stringified_json` converts to and from it.
    #[serde(with = "json_utils::stringified_json")]
    pub arguments: serde_json::Value,
}
241
/// Kind of tool. Only function tools are modelled; serialized as `"function"`.
#[derive(Clone, Default, Debug, Deserialize, Serialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}
248
/// A tool definition sent to Cohere in the request's `tools` array.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
pub struct Tool {
    pub r#type: ToolType,
    pub function: Function,
}
254
/// Function part of a [`Tool`]: its name, an optional description, and the `parameters` JSON
/// value (copied verbatim from the rig tool definition).
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
pub struct Function {
    pub name: String,
    #[serde(default)]
    pub description: Option<String>,
    pub parameters: serde_json::Value,
}
262
263impl From<completion::ToolDefinition> for Tool {
264    fn from(tool: completion::ToolDefinition) -> Self {
265        Self {
266            r#type: ToolType::default(),
267            function: Function {
268                name: tool.name,
269                description: Some(tool.description),
270                parameters: tool.parameters,
271            },
272        }
273    }
274}
275
/// A chat message in Cohere's v2 format, tagged by its lowercase `"role"`
/// (`"user"`, `"assistant"`, `"tool"` or `"system"`).
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    /// End-user input.
    User {
        content: OneOrMany<UserContent>,
    },

    /// Model output. Every field falls back to empty / `None` when missing from the payload.
    Assistant {
        #[serde(default)]
        content: Vec<AssistantContent>,
        #[serde(default)]
        citations: Vec<Citation>,
        #[serde(default)]
        tool_calls: Vec<ToolCall>,
        /// Free-text plan Cohere may send alongside tool calls (see the sample in the tests).
        #[serde(default)]
        tool_plan: Option<String>,
    },

    /// Result of a tool call; `tool_call_id` links it to the [`ToolCall`] that requested it.
    Tool {
        content: OneOrMany<ToolResultContent>,
        tool_call_id: String,
    },

    /// System instructions; the request builder sends the preamble as this variant.
    System {
        content: String,
    },
}
303
304#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
305#[serde(tag = "type", rename_all = "lowercase")]
306pub enum UserContent {
307    Text { text: String },
308    ImageUrl { image_url: ImageUrl },
309}
310
/// A content part of a Cohere assistant message, tagged by `"type"` (`"text"` or `"thinking"`).
///
/// The conversions in this module map `Thinking` parts to and from rig `Reasoning` content.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum AssistantContent {
    Text { text: String },
    Thinking { thinking: String },
}
317
/// Image reference inside a user message; only the `url` is modelled.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
pub struct ImageUrl {
    pub url: String,
}
322
323#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
324pub enum ToolResultContent {
325    Text { text: String },
326    Document { document: Document },
327}
328
/// A citation attached to assistant output. Every field tolerates being absent from the payload.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct Citation {
    /// Start offset of the cited span (unit not established here — TODO confirm: chars vs bytes).
    #[serde(default)]
    pub start: Option<u32>,
    /// End offset of the cited span.
    #[serde(default)]
    pub end: Option<u32>,
    /// The cited text.
    #[serde(default)]
    pub text: Option<String>,
    /// Citation kind (see [`CitationType`]), carried under the `"type"` key.
    #[serde(rename = "type")]
    pub citation_type: Option<CitationType>,
    /// Where the cited information came from.
    #[serde(default)]
    pub sources: Vec<Source>,
}
342
/// Origin of a citation, tagged by `"type"` (`"document"` or `"tool"`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum Source {
    /// Cited from a document; `document` holds its fields as a JSON object.
    Document {
        id: Option<String>,
        document: Option<serde_json::Map<String, serde_json::Value>>,
    },
    /// Cited from a tool's output; `tool_output` holds it as a JSON object.
    Tool {
        id: Option<String>,
        tool_output: Option<serde_json::Map<String, serde_json::Value>>,
    },
}
355
/// Kind of a [`Citation`], (de)serialized as `"TEXT_CONTENT"` or `"PLAN"`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum CitationType {
    TextContent,
    Plan,
}
362
363impl TryFrom<message::Message> for Vec<Message> {
364    type Error = message::MessageError;
365
366    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
367        Ok(match message {
368            message::Message::User { content } => content
369                .into_iter()
370                .map(|content| match content {
371                    message::UserContent::Text(message::Text { text }) => Ok(Message::User {
372                        content: OneOrMany::one(UserContent::Text { text }),
373                    }),
374                    message::UserContent::ToolResult(message::ToolResult {
375                        id, content, ..
376                    }) => Ok(Message::Tool {
377                        tool_call_id: id,
378                        content: content.try_map(|content| match content {
379                            message::ToolResultContent::Text(text) => {
380                                Ok(ToolResultContent::Text { text: text.text })
381                            }
382                            _ => Err(message::MessageError::ConversionError(
383                                "Only text tool result content is supported by Cohere".to_owned(),
384                            )),
385                        })?,
386                    }),
387                    _ => Err(message::MessageError::ConversionError(
388                        "Only text content is supported by Cohere".to_owned(),
389                    )),
390                })
391                .collect::<Result<Vec<_>, _>>()?,
392            message::Message::Assistant { content, .. } => {
393                let mut text_content = vec![];
394                let mut tool_calls = vec![];
395                content.into_iter().for_each(|content| match content {
396                    message::AssistantContent::Text(message::Text { text }) => {
397                        text_content.push(AssistantContent::Text { text });
398                    }
399                    message::AssistantContent::ToolCall(message::ToolCall {
400                        id,
401                        function:
402                            message::ToolFunction {
403                                name, arguments, ..
404                            },
405                        ..
406                    }) => {
407                        tool_calls.push(ToolCall {
408                            id: Some(id),
409                            r#type: Some(ToolType::Function),
410                            function: Some(ToolCallFunction {
411                                name,
412                                arguments: serde_json::to_value(arguments).unwrap_or_default(),
413                            }),
414                        });
415                    }
416                    message::AssistantContent::Reasoning(Reasoning { reasoning, .. }) => {
417                        let thinking = reasoning.join("\n");
418                        text_content.push(AssistantContent::Thinking { thinking });
419                    }
420                });
421
422                vec![Message::Assistant {
423                    content: text_content,
424                    citations: vec![],
425                    tool_calls,
426                    tool_plan: None,
427                }]
428            }
429        })
430    }
431}
432
433impl TryFrom<Message> for message::Message {
434    type Error = message::MessageError;
435
436    fn try_from(message: Message) -> Result<Self, Self::Error> {
437        match message {
438            Message::User { content } => Ok(message::Message::User {
439                content: content.map(|content| match content {
440                    UserContent::Text { text } => {
441                        message::UserContent::Text(message::Text { text })
442                    }
443                    UserContent::ImageUrl { image_url } => {
444                        message::UserContent::image_url(image_url.url, None, None)
445                    }
446                }),
447            }),
448            Message::Assistant {
449                content,
450                tool_calls,
451                ..
452            } => {
453                let mut content = content
454                    .into_iter()
455                    .map(|content| match content {
456                        AssistantContent::Text { text } => message::AssistantContent::text(text),
457                        AssistantContent::Thinking { thinking } => {
458                            message::AssistantContent::Reasoning(Reasoning {
459                                id: None,
460                                reasoning: vec![thinking],
461                            })
462                        }
463                    })
464                    .collect::<Vec<_>>();
465
466                content.extend(tool_calls.into_iter().filter_map(|tool_call| {
467                    let ToolCallFunction { name, arguments } = tool_call.function?;
468
469                    Some(message::AssistantContent::tool_call(
470                        tool_call.id.unwrap_or_else(|| name.clone()),
471                        name,
472                        arguments,
473                    ))
474                }));
475
476                let content = OneOrMany::many(content).map_err(|_| {
477                    message::MessageError::ConversionError(
478                        "Expected either text content or tool calls".to_string(),
479                    )
480                })?;
481
482                Ok(message::Message::Assistant { id: None, content })
483            }
484            Message::Tool {
485                content,
486                tool_call_id,
487            } => {
488                let content = content.try_map(|content| {
489                    Ok(match content {
490                        ToolResultContent::Text { text } => message::ToolResultContent::text(text),
491                        ToolResultContent::Document { document } => {
492                            message::ToolResultContent::text(
493                                serde_json::to_string(&document.data).map_err(|e| {
494                                    message::MessageError::ConversionError(
495                                        format!("Failed to convert tool result document content into text: {e}"),
496                                    )
497                                })?,
498                            )
499                        }
500                    })
501                })?;
502
503                Ok(message::Message::User {
504                    content: OneOrMany::one(message::UserContent::tool_result(
505                        tool_call_id,
506                        content,
507                    )),
508                })
509            }
510            Message::System { content } => Ok(message::Message::user(content)),
511        }
512    }
513}
514
/// Cohere chat completion model bound to a [`Client`].
///
/// `T` is the HTTP client type and defaults to `reqwest::Client`. In this file,
/// `completion::CompletionModel` is implemented only for that default.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    pub(crate) client: Client<T>,
    /// Model identifier sent as the `model` field of every request.
    pub model: String,
}
520
521impl<T> CompletionModel<T>
522where
523    T: HttpClientExt,
524{
525    pub fn new(client: Client<T>, model: &str) -> Self {
526        Self {
527            client,
528            model: model.to_string(),
529        }
530    }
531
532    pub(crate) fn create_completion_request(
533        &self,
534        completion_request: CompletionRequest,
535    ) -> Result<Value, CompletionError> {
536        // Build up the order of messages (context, chat_history)
537        let mut partial_history = vec![];
538        if let Some(docs) = completion_request.normalized_documents() {
539            partial_history.push(docs);
540        }
541        partial_history.extend(completion_request.chat_history);
542
543        // Initialize full history with preamble (or empty if non-existent)
544        let mut full_history: Vec<Message> = completion_request
545            .preamble
546            .map_or_else(Vec::new, |preamble| {
547                vec![Message::System { content: preamble }]
548            });
549
550        // Convert and extend the rest of the history
551        full_history.extend(
552            partial_history
553                .into_iter()
554                .map(message::Message::try_into)
555                .collect::<Result<Vec<Vec<Message>>, _>>()?
556                .into_iter()
557                .flatten()
558                .collect::<Vec<_>>(),
559        );
560
561        let request = json!({
562            "model": self.model,
563            "messages": full_history,
564            "documents": completion_request.documents,
565            "temperature": completion_request.temperature,
566            "tools": completion_request.tools.into_iter().map(Tool::from).collect::<Vec<_>>(),
567            "tool_choice": if let Some(tool_choice) = completion_request.tool_choice && !matches!(tool_choice, ToolChoice::Auto) { tool_choice } else {
568                return Err(CompletionError::RequestError("\"auto\" is not an allowed tool_choice value in the Cohere API".into()))
569            },
570        });
571
572        if let Some(ref params) = completion_request.additional_params {
573            Ok(json_utils::merge(request.clone(), params.clone()))
574        } else {
575            Ok(request)
576        }
577    }
578}
579
580impl completion::CompletionModel for CompletionModel<reqwest::Client> {
581    type Response = CompletionResponse;
582    type StreamingResponse = StreamingCompletionResponse;
583
584    #[cfg_attr(feature = "worker", worker::send)]
585    async fn completion(
586        &self,
587        completion_request: completion::CompletionRequest,
588    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
589        let request = self.create_completion_request(completion_request)?;
590
591        let llm_span = if tracing::Span::current().is_disabled() {
592            info_span!(
593            target: "rig::completions",
594            "chat",
595            gen_ai.operation.name = "chat",
596            gen_ai.provider.name = "cohere",
597            gen_ai.request.model = self.model,
598            gen_ai.response.id = tracing::field::Empty,
599            gen_ai.response.model = self.model,
600            gen_ai.usage.output_tokens = tracing::field::Empty,
601            gen_ai.usage.input_tokens = tracing::field::Empty,
602            gen_ai.input.messages = serde_json::to_string(request.get("messages").expect("Converting request messages to JSON should not fail!")).unwrap(),
603            gen_ai.output.messages = tracing::field::Empty,
604            )
605        } else {
606            tracing::Span::current()
607        };
608
609        tracing::debug!(
610            "Cohere request: {}",
611            serde_json::to_string_pretty(&request)?
612        );
613
614        async {
615            let response = self
616                .client
617                .client()
618                .post("/v2/chat")
619                .json(&request)
620                .send()
621                .await
622                .map_err(|e| http_client::Error::Instance(e.into()))?;
623
624            if response.status().is_success() {
625                let text_response = response.text().await.map_err(|e| {
626                    CompletionError::HttpError(http_client::Error::Instance(e.into()))
627                })?;
628                tracing::debug!("Cohere completion request: {}", text_response);
629
630                let json_response: CompletionResponse = serde_json::from_str(&text_response)?;
631                let span = tracing::Span::current();
632                span.record_token_usage(&json_response.usage);
633                span.record_model_output(&json_response.message);
634                span.record_response_metadata(&json_response);
635                tracing::debug!("Cohere completion response: {}", text_response);
636                let completion: completion::CompletionResponse<CompletionResponse> =
637                    json_response.try_into()?;
638                Ok(completion)
639            } else {
640                Err(CompletionError::ProviderError(
641                    response.text().await.map_err(|e| {
642                        CompletionError::HttpError(http_client::Error::Instance(e.into()))
643                    })?,
644                ))
645            }
646        }
647        .instrument(llm_span)
648        .await
649    }
650
651    #[cfg_attr(feature = "worker", worker::send)]
652    async fn stream(
653        &self,
654        request: CompletionRequest,
655    ) -> Result<
656        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
657        CompletionError,
658    > {
659        CompletionModel::stream(self, request).await
660    }
661}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_path_to_error::deserialize;

    /// Sample Cohere v2 chat response in which the model requests one `subtract` tool call.
    const TOOL_CALL_RESPONSE: &str = r#"
    {
        "id": "abc123",
        "message": {
            "role": "assistant",
            "tool_plan": "I will use the subtract tool to find the difference between 2 and 5.",
            "tool_calls": [
                {
                    "id": "subtract_sm6ps6fb6y9f",
                    "type": "function",
                    "function": {
                        "name": "subtract",
                        "arguments": "{\"x\":5,\"y\":2}"
                    }
                }
            ]
        },
        "finish_reason": "TOOL_CALL",
        "usage": {
            "billed_units": {
                "input_tokens": 78,
                "output_tokens": 27
            },
            "tokens": {
                "input_tokens": 1028,
                "output_tokens": 63
            }
        }
    }
    "#;

    #[test]
    fn test_deserialize_completion_response() {
        let mut deserializer = serde_json::Deserializer::from_str(TOOL_CALL_RESPONSE);
        let result: Result<CompletionResponse, _> = deserialize(&mut deserializer);

        let response = result.unwrap();
        let (_, citations, tool_calls) = response.message();
        let CompletionResponse {
            id,
            finish_reason,
            usage,
            ..
        } = response;

        assert_eq!(id, "abc123");
        assert_eq!(finish_reason, FinishReason::ToolCall);

        let Usage {
            billed_units,
            tokens,
        } = usage.unwrap();
        let BilledUnits {
            input_tokens: billed_input_tokens,
            output_tokens: billed_output_tokens,
            ..
        } = billed_units.unwrap();
        let Tokens {
            input_tokens,
            output_tokens,
        } = tokens.unwrap();

        assert_eq!(billed_input_tokens.unwrap(), 78.0);
        assert_eq!(billed_output_tokens.unwrap(), 27.0);
        assert_eq!(input_tokens.unwrap(), 1028.0);
        assert_eq!(output_tokens.unwrap(), 63.0);

        assert!(citations.is_empty());
        assert_eq!(tool_calls.len(), 1);

        let ToolCallFunction { name, arguments } = tool_calls[0].function.clone().unwrap();

        assert_eq!(name, "subtract");
        assert_eq!(arguments, serde_json::json!({"x": 5, "y": 2}));
    }

    #[test]
    fn test_tool_call_response_converts_to_completion_response() {
        let response: CompletionResponse = serde_json::from_str(TOOL_CALL_RESPONSE).unwrap();
        let completion: completion::CompletionResponse<CompletionResponse> =
            response.try_into().unwrap();

        // The generic usage is built from the raw `tokens`, not `billed_units`.
        assert_eq!(completion.usage.input_tokens, 1028);
        assert_eq!(completion.usage.output_tokens, 63);
        assert_eq!(completion.usage.total_tokens, 1091);

        let choice: Vec<_> = completion.choice.into_iter().collect();
        assert_eq!(choice.len(), 1);
        let completion::AssistantContent::ToolCall(tool_call) = &choice[0] else {
            panic!("expected a tool call, got {:?}", choice[0]);
        };
        assert_eq!(tool_call.id, "subtract_sm6ps6fb6y9f");
        assert_eq!(tool_call.function.name, "subtract");
        assert_eq!(
            tool_call.function.arguments,
            serde_json::json!({"x": 5, "y": 2})
        );
    }

    #[test]
    fn test_token_usage_uses_billed_units() {
        let response: CompletionResponse = serde_json::from_str(TOOL_CALL_RESPONSE).unwrap();
        let usage = response.usage.unwrap().token_usage().unwrap();

        assert_eq!(usage.input_tokens, 78);
        assert_eq!(usage.output_tokens, 27);
        assert_eq!(usage.total_tokens, 105);
    }

    #[test]
    fn test_convert_completion_message_to_message_and_back() {
        let completion_message = completion::Message::User {
            content: OneOrMany::one(completion::message::UserContent::Text(
                completion::message::Text {
                    text: "Hello, world!".to_string(),
                },
            )),
        };

        let messages: Vec<Message> = completion_message.clone().try_into().unwrap();
        assert_eq!(messages.len(), 1);

        let converted_back: Vec<completion::Message> = messages
            .into_iter()
            .map(|msg| msg.try_into().unwrap())
            .collect::<Vec<_>>();

        assert_eq!(converted_back, vec![completion_message]);
    }

    #[test]
    fn test_convert_message_to_completion_message_and_back() {
        let message = Message::User {
            content: OneOrMany::one(UserContent::Text {
                text: "Hello, world!".to_string(),
            }),
        };

        let completion_message: completion::Message = message.clone().try_into().unwrap();
        let converted_back: Vec<Message> = completion_message.try_into().unwrap();

        assert_eq!(converted_back, vec![message]);
    }
}