rig/providers/cohere/
completion.rs

1use crate::{
2    OneOrMany,
3    completion::{self, CompletionError, GetTokenUsage},
4    http_client::{self, HttpClientExt},
5    json_utils,
6    message::{self, Reasoning, ToolChoice},
7    telemetry::SpanCombinator,
8};
9use std::collections::HashMap;
10
11use super::client::Client;
12use crate::completion::CompletionRequest;
13use crate::providers::cohere::streaming::StreamingCompletionResponse;
14use http::Method;
15use serde::{Deserialize, Serialize};
16use serde_json::{Value, json};
17use tracing::{Instrument, info_span};
18
/// Body of a successful chat response, as deserialized by `completion` in this file.
#[derive(Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    /// Identifier of the response; exposed through `get_response_id`.
    pub id: String,
    /// Reason reported for the end of generation.
    pub finish_reason: FinishReason,
    // Kept private; callers read it through `CompletionResponse::message`.
    message: Message,
    /// Token accounting; `None` when the payload carries no `usage` object.
    #[serde(default)]
    pub usage: Option<Usage>,
}
27
28impl CompletionResponse {
29    /// Return that parts of the response for assistant messages w/o dealing with the other variants
30    pub fn message(&self) -> (Vec<AssistantContent>, Vec<Citation>, Vec<ToolCall>) {
31        let Message::Assistant {
32            content,
33            citations,
34            tool_calls,
35            ..
36        } = self.message.clone()
37        else {
38            unreachable!("Completion responses will only return an assistant message")
39        };
40
41        (content, citations, tool_calls)
42    }
43}
44
45impl crate::telemetry::ProviderResponseExt for CompletionResponse {
46    type OutputMessage = Message;
47    type Usage = Usage;
48
49    fn get_response_id(&self) -> Option<String> {
50        Some(self.id.clone())
51    }
52
53    fn get_response_model_name(&self) -> Option<String> {
54        None
55    }
56
57    fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
58        vec![self.message.clone()]
59    }
60
61    fn get_text_response(&self) -> Option<String> {
62        let Message::Assistant { ref content, .. } = self.message else {
63            return None;
64        };
65
66        let res = content
67            .iter()
68            .filter_map(|x| {
69                if let AssistantContent::Text { text } = x {
70                    Some(text.to_string())
71                } else {
72                    None
73                }
74            })
75            .collect::<Vec<String>>()
76            .join("\n");
77
78        if res.is_empty() { None } else { Some(res) }
79    }
80
81    fn get_usage(&self) -> Option<Self::Usage> {
82        self.usage.clone()
83    }
84}
85
/// Value of the `finish_reason` field of a response.
///
/// Variants are (de)serialized in `SCREAMING_SNAKE_CASE`, e.g. `ToolCall` <-> `"TOOL_CALL"`.
#[derive(Debug, Deserialize, PartialEq, Eq, Clone, Serialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum FinishReason {
    MaxTokens,
    StopSequence,
    Complete,
    Error,
    ToolCall,
}
95
/// Usage block of a response; either part may be absent from the payload.
#[derive(Debug, Deserialize, Clone, Serialize)]
pub struct Usage {
    /// Contents of the `billed_units` object, when present.
    #[serde(default)]
    pub billed_units: Option<BilledUnits>,
    /// Contents of the `tokens` object, when present.
    #[serde(default)]
    pub tokens: Option<Tokens>,
}
103
104impl GetTokenUsage for Usage {
105    fn token_usage(&self) -> Option<crate::completion::Usage> {
106        let mut usage = crate::completion::Usage::new();
107
108        if let Some(ref billed_units) = self.billed_units {
109            usage.input_tokens = billed_units.input_tokens.unwrap_or_default() as u64;
110            usage.output_tokens = billed_units.output_tokens.unwrap_or_default() as u64;
111            usage.total_tokens = usage.input_tokens + usage.output_tokens;
112        }
113
114        Some(usage)
115    }
116}
117
/// Counts from the `billed_units` object of a usage block.
///
/// Every field is optional and deserialized as `f64`.
#[derive(Debug, Deserialize, Clone, Serialize)]
pub struct BilledUnits {
    /// Billed output tokens, if reported.
    #[serde(default)]
    pub output_tokens: Option<f64>,
    /// Billed classification units, if reported.
    #[serde(default)]
    pub classifications: Option<f64>,
    /// Billed search units, if reported.
    #[serde(default)]
    pub search_units: Option<f64>,
    /// Billed input tokens, if reported.
    #[serde(default)]
    pub input_tokens: Option<f64>,
}
129
/// Counts from the `tokens` object of a usage block; both are optional `f64` values.
#[derive(Debug, Deserialize, Clone, Serialize)]
pub struct Tokens {
    /// Input token count, if reported.
    #[serde(default)]
    pub input_tokens: Option<f64>,
    /// Output token count, if reported.
    #[serde(default)]
    pub output_tokens: Option<f64>,
}
137
138impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
139    type Error = CompletionError;
140
141    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
142        let (content, _, tool_calls) = response.message();
143
144        let model_response = if !tool_calls.is_empty() {
145            OneOrMany::many(
146                tool_calls
147                    .into_iter()
148                    .filter_map(|tool_call| {
149                        let ToolCallFunction { name, arguments } = tool_call.function?;
150                        let id = tool_call.id.unwrap_or_else(|| name.clone());
151
152                        Some(completion::AssistantContent::tool_call(id, name, arguments))
153                    })
154                    .collect::<Vec<_>>(),
155            )
156            .expect("We have atleast 1 tool call in this if block")
157        } else {
158            OneOrMany::many(content.into_iter().map(|content| match content {
159                AssistantContent::Text { text } => completion::AssistantContent::text(text),
160                AssistantContent::Thinking { thinking } => {
161                    completion::AssistantContent::Reasoning(Reasoning {
162                        id: None,
163                        reasoning: vec![thinking],
164                        signature: None,
165                    })
166                }
167            }))
168            .map_err(|_| {
169                CompletionError::ResponseError(
170                    "Response contained no message or tool call (empty)".to_owned(),
171                )
172            })?
173        };
174
175        let usage = response
176            .usage
177            .as_ref()
178            .and_then(|usage| usage.tokens.as_ref())
179            .map(|tokens| {
180                let input_tokens = tokens.input_tokens.unwrap_or(0.0);
181                let output_tokens = tokens.output_tokens.unwrap_or(0.0);
182
183                completion::Usage {
184                    input_tokens: input_tokens as u64,
185                    output_tokens: output_tokens as u64,
186                    total_tokens: (input_tokens + output_tokens) as u64,
187                }
188            })
189            .unwrap_or_default();
190
191        Ok(completion::CompletionResponse {
192            choice: OneOrMany::many(model_response).expect("There is atleast one content"),
193            usage,
194            raw_response: response,
195        })
196    }
197}
198
/// Document in the shape used by this provider module: an id plus a map of JSON values.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
pub struct Document {
    /// Identifier of the document.
    pub id: String,
    /// Document fields; `From<completion::Document>` stores the text under the `"text"` key.
    pub data: HashMap<String, serde_json::Value>,
}
204
205impl From<completion::Document> for Document {
206    fn from(document: completion::Document) -> Self {
207        let mut data: HashMap<String, serde_json::Value> = HashMap::new();
208
209        // We use `.into()` here explicitly since the `document.additional_props` type will likely
210        //  evolve into `serde_json::Value` in the future.
211        document
212            .additional_props
213            .into_iter()
214            .for_each(|(key, value)| {
215                data.insert(key, value.into());
216            });
217
218        data.insert("text".to_string(), document.text.into());
219
220        Self {
221            id: document.id,
222            data,
223        }
224    }
225}
226
/// Tool call attached to an assistant message. Every field may be absent on the wire.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
pub struct ToolCall {
    /// Identifier of the call; conversions in this file fall back to the function name
    /// when it is missing.
    #[serde(default)]
    pub id: Option<String>,
    /// Kind of tool being called.
    #[serde(default)]
    pub r#type: Option<ToolType>,
    /// Function name and arguments; conversions in this file skip calls without it.
    #[serde(default)]
    pub function: Option<ToolCallFunction>,
}
236
/// Function invocation carried by a [`ToolCall`].
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
pub struct ToolCallFunction {
    /// Name of the function to call.
    pub name: String,
    /// Call arguments, (de)serialized through `json_utils::stringified_json`; the test
    /// fixture in this file shows the wire form as a JSON-encoded string.
    #[serde(with = "json_utils::stringified_json")]
    pub arguments: serde_json::Value,
}
243
/// Kind of a tool; serialized in lowercase (`"function"`).
#[derive(Clone, Default, Debug, Deserialize, Serialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    /// The only kind defined here, and the default.
    #[default]
    Function,
}
250
/// Tool definition placed in the `tools` array of a request.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
pub struct Tool {
    /// Kind of the tool.
    pub r#type: ToolType,
    /// Description of the callable function.
    pub function: Function,
}
256
/// Function description inside a [`Tool`].
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
pub struct Function {
    /// Name of the function.
    pub name: String,
    /// Optional human-readable description.
    #[serde(default)]
    pub description: Option<String>,
    /// Parameter description as an arbitrary JSON value (presumably a JSON schema; confirm
    /// against the API reference).
    pub parameters: serde_json::Value,
}
264
265impl From<completion::ToolDefinition> for Tool {
266    fn from(tool: completion::ToolDefinition) -> Self {
267        Self {
268            r#type: ToolType::default(),
269            function: Function {
270                name: tool.name,
271                description: Some(tool.description),
272                parameters: tool.parameters,
273            },
274        }
275    }
276}
277
/// Chat message in this provider's wire shape.
///
/// The variant is selected by the `role` field, using the lowercase variant name
/// (`"user"`, `"assistant"`, `"tool"`, `"system"`).
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    /// Message written by the user.
    User {
        content: OneOrMany<UserContent>,
    },

    /// Message produced by the model. Every field defaults when missing from the payload.
    Assistant {
        /// Text and thinking parts.
        #[serde(default)]
        content: Vec<AssistantContent>,
        /// Citations attached to the message.
        #[serde(default)]
        citations: Vec<Citation>,
        /// Tool calls requested by the model.
        #[serde(default)]
        tool_calls: Vec<ToolCall>,
        /// Optional `tool_plan` text from the payload.
        #[serde(default)]
        tool_plan: Option<String>,
    },

    /// Result of a tool call, linked to the call through `tool_call_id`.
    Tool {
        content: OneOrMany<ToolResultContent>,
        tool_call_id: String,
    },

    /// System message; `create_completion_request` emits one for the request preamble.
    System {
        content: String,
    },
}
305
306#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
307#[serde(tag = "type", rename_all = "lowercase")]
308pub enum UserContent {
309    Text { text: String },
310    ImageUrl { image_url: ImageUrl },
311}
312
/// One part of an assistant message, tagged by its `type` field (`"text"` or `"thinking"`).
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum AssistantContent {
    /// Plain text part.
    Text { text: String },
    /// Thinking part; conversions in this file map it to the generic `Reasoning` content.
    Thinking { thinking: String },
}
319
/// Payload of [`UserContent::ImageUrl`].
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
pub struct ImageUrl {
    /// Location of the image.
    pub url: String,
}
324
325#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
326pub enum ToolResultContent {
327    Text { text: String },
328    Document { document: Document },
329}
330
/// Citation attached to an assistant message. Every field may be missing from the payload.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct Citation {
    /// Start of the cited span (presumably an offset into the response text; confirm
    /// against the API reference).
    #[serde(default)]
    pub start: Option<u32>,
    /// End of the cited span (same caveat as `start`).
    #[serde(default)]
    pub end: Option<u32>,
    /// Cited text.
    #[serde(default)]
    pub text: Option<String>,
    /// Kind of citation; read from and written to the `type` field.
    #[serde(rename = "type")]
    pub citation_type: Option<CitationType>,
    /// Sources backing the citation; empty when the field is absent.
    #[serde(default)]
    pub sources: Vec<Source>,
}
344
/// Source of a [`Citation`], tagged by its `type` field (`"document"` or `"tool"`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum Source {
    /// Source backed by a document.
    Document {
        /// Identifier of the source, if given.
        id: Option<String>,
        /// Document contents as a free-form JSON object.
        document: Option<serde_json::Map<String, serde_json::Value>>,
    },
    /// Source backed by the output of a tool.
    Tool {
        /// Identifier of the source, if given.
        id: Option<String>,
        /// Tool output as a free-form JSON object.
        tool_output: Option<serde_json::Map<String, serde_json::Value>>,
    },
}
357
/// Kind of a [`Citation`]; (de)serialized in `SCREAMING_SNAKE_CASE`
/// (`"TEXT_CONTENT"`, `"PLAN"`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum CitationType {
    TextContent,
    Plan,
}
364
365impl TryFrom<message::Message> for Vec<Message> {
366    type Error = message::MessageError;
367
368    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
369        Ok(match message {
370            message::Message::User { content } => content
371                .into_iter()
372                .map(|content| match content {
373                    message::UserContent::Text(message::Text { text }) => Ok(Message::User {
374                        content: OneOrMany::one(UserContent::Text { text }),
375                    }),
376                    message::UserContent::ToolResult(message::ToolResult {
377                        id, content, ..
378                    }) => Ok(Message::Tool {
379                        tool_call_id: id,
380                        content: content.try_map(|content| match content {
381                            message::ToolResultContent::Text(text) => {
382                                Ok(ToolResultContent::Text { text: text.text })
383                            }
384                            _ => Err(message::MessageError::ConversionError(
385                                "Only text tool result content is supported by Cohere".to_owned(),
386                            )),
387                        })?,
388                    }),
389                    _ => Err(message::MessageError::ConversionError(
390                        "Only text content is supported by Cohere".to_owned(),
391                    )),
392                })
393                .collect::<Result<Vec<_>, _>>()?,
394            message::Message::Assistant { content, .. } => {
395                let mut text_content = vec![];
396                let mut tool_calls = vec![];
397                content.into_iter().for_each(|content| match content {
398                    message::AssistantContent::Text(message::Text { text }) => {
399                        text_content.push(AssistantContent::Text { text });
400                    }
401                    message::AssistantContent::ToolCall(message::ToolCall {
402                        id,
403                        function:
404                            message::ToolFunction {
405                                name, arguments, ..
406                            },
407                        ..
408                    }) => {
409                        tool_calls.push(ToolCall {
410                            id: Some(id),
411                            r#type: Some(ToolType::Function),
412                            function: Some(ToolCallFunction {
413                                name,
414                                arguments: serde_json::to_value(arguments).unwrap_or_default(),
415                            }),
416                        });
417                    }
418                    message::AssistantContent::Reasoning(Reasoning { reasoning, .. }) => {
419                        let thinking = reasoning.join("\n");
420                        text_content.push(AssistantContent::Thinking { thinking });
421                    }
422                });
423
424                vec![Message::Assistant {
425                    content: text_content,
426                    citations: vec![],
427                    tool_calls,
428                    tool_plan: None,
429                }]
430            }
431        })
432    }
433}
434
435impl TryFrom<Message> for message::Message {
436    type Error = message::MessageError;
437
438    fn try_from(message: Message) -> Result<Self, Self::Error> {
439        match message {
440            Message::User { content } => Ok(message::Message::User {
441                content: content.map(|content| match content {
442                    UserContent::Text { text } => {
443                        message::UserContent::Text(message::Text { text })
444                    }
445                    UserContent::ImageUrl { image_url } => {
446                        message::UserContent::image_url(image_url.url, None, None)
447                    }
448                }),
449            }),
450            Message::Assistant {
451                content,
452                tool_calls,
453                ..
454            } => {
455                let mut content = content
456                    .into_iter()
457                    .map(|content| match content {
458                        AssistantContent::Text { text } => message::AssistantContent::text(text),
459                        AssistantContent::Thinking { thinking } => {
460                            message::AssistantContent::Reasoning(Reasoning {
461                                id: None,
462                                reasoning: vec![thinking],
463                                signature: None,
464                            })
465                        }
466                    })
467                    .collect::<Vec<_>>();
468
469                content.extend(tool_calls.into_iter().filter_map(|tool_call| {
470                    let ToolCallFunction { name, arguments } = tool_call.function?;
471
472                    Some(message::AssistantContent::tool_call(
473                        tool_call.id.unwrap_or_else(|| name.clone()),
474                        name,
475                        arguments,
476                    ))
477                }));
478
479                let content = OneOrMany::many(content).map_err(|_| {
480                    message::MessageError::ConversionError(
481                        "Expected either text content or tool calls".to_string(),
482                    )
483                })?;
484
485                Ok(message::Message::Assistant { id: None, content })
486            }
487            Message::Tool {
488                content,
489                tool_call_id,
490            } => {
491                let content = content.try_map(|content| {
492                    Ok(match content {
493                        ToolResultContent::Text { text } => message::ToolResultContent::text(text),
494                        ToolResultContent::Document { document } => {
495                            message::ToolResultContent::text(
496                                serde_json::to_string(&document.data).map_err(|e| {
497                                    message::MessageError::ConversionError(
498                                        format!("Failed to convert tool result document content into text: {e}"),
499                                    )
500                                })?,
501                            )
502                        }
503                    })
504                })?;
505
506                Ok(message::Message::User {
507                    content: OneOrMany::one(message::UserContent::tool_result(
508                        tool_call_id,
509                        content,
510                    )),
511                })
512            }
513            Message::System { content } => Ok(message::Message::user(content)),
514        }
515    }
516}
517
/// Completion model handle: a client plus the model name sent with each request.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    /// Client used to build and send requests.
    pub(crate) client: Client<T>,
    /// Model name, sent as the `model` field of the request body.
    pub model: String,
}
523
524impl<T> CompletionModel<T>
525where
526    T: HttpClientExt,
527{
528    pub fn new(client: Client<T>, model: &str) -> Self {
529        Self {
530            client,
531            model: model.to_string(),
532        }
533    }
534
535    pub(crate) fn create_completion_request(
536        &self,
537        completion_request: CompletionRequest,
538    ) -> Result<Value, CompletionError> {
539        // Build up the order of messages (context, chat_history)
540        let mut partial_history = vec![];
541        if let Some(docs) = completion_request.normalized_documents() {
542            partial_history.push(docs);
543        }
544        partial_history.extend(completion_request.chat_history);
545
546        // Initialize full history with preamble (or empty if non-existent)
547        let mut full_history: Vec<Message> = completion_request
548            .preamble
549            .map_or_else(Vec::new, |preamble| {
550                vec![Message::System { content: preamble }]
551            });
552
553        // Convert and extend the rest of the history
554        full_history.extend(
555            partial_history
556                .into_iter()
557                .map(message::Message::try_into)
558                .collect::<Result<Vec<Vec<Message>>, _>>()?
559                .into_iter()
560                .flatten()
561                .collect::<Vec<_>>(),
562        );
563
564        let request = json!({
565            "model": self.model,
566            "messages": full_history,
567            "documents": completion_request.documents,
568            "temperature": completion_request.temperature,
569            "tools": completion_request.tools.into_iter().map(Tool::from).collect::<Vec<_>>(),
570            "tool_choice": if let Some(tool_choice) = completion_request.tool_choice && !matches!(tool_choice, ToolChoice::Auto) { tool_choice } else {
571                return Err(CompletionError::RequestError("\"auto\" is not an allowed tool_choice value in the Cohere API".into()))
572            },
573        });
574
575        if let Some(ref params) = completion_request.additional_params {
576            Ok(json_utils::merge(request.clone(), params.clone()))
577        } else {
578            Ok(request)
579        }
580    }
581}
582
583impl<T> completion::CompletionModel for CompletionModel<T>
584where
585    T: HttpClientExt + Clone + 'static,
586{
587    type Response = CompletionResponse;
588    type StreamingResponse = StreamingCompletionResponse;
589
590    #[cfg_attr(feature = "worker", worker::send)]
591    async fn completion(
592        &self,
593        completion_request: completion::CompletionRequest,
594    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
595        let request = self.create_completion_request(completion_request)?;
596
597        let llm_span = if tracing::Span::current().is_disabled() {
598            info_span!(
599            target: "rig::completions",
600            "chat",
601            gen_ai.operation.name = "chat",
602            gen_ai.provider.name = "cohere",
603            gen_ai.request.model = self.model,
604            gen_ai.response.id = tracing::field::Empty,
605            gen_ai.response.model = self.model,
606            gen_ai.usage.output_tokens = tracing::field::Empty,
607            gen_ai.usage.input_tokens = tracing::field::Empty,
608            gen_ai.input.messages = serde_json::to_string(request.get("messages").expect("Converting request messages to JSON should not fail!")).unwrap(),
609            gen_ai.output.messages = tracing::field::Empty,
610            )
611        } else {
612            tracing::Span::current()
613        };
614
615        tracing::debug!(
616            "Cohere request: {}",
617            serde_json::to_string_pretty(&request)?
618        );
619
620        let req_body = serde_json::to_vec(&request)?;
621
622        let req = self
623            .client
624            .req(Method::POST, "/v2/chat")?
625            .body(req_body)
626            .unwrap();
627
628        async {
629            let response = self
630                .client
631                .http_client()
632                .send::<_, bytes::Bytes>(req)
633                .await
634                .map_err(|e| http_client::Error::Instance(e.into()))?;
635
636            let status = response.status();
637            let body = response.into_body().into_future().await?.to_owned();
638
639            if status.is_success() {
640                let json_response: CompletionResponse = serde_json::from_slice(&body)?;
641                let span = tracing::Span::current();
642                span.record_token_usage(&json_response.usage);
643                span.record_model_output(&json_response.message);
644                span.record_response_metadata(&json_response);
645                tracing::debug!(
646                    "Cohere completion response: {}",
647                    serde_json::to_string_pretty(&json_response)?
648                );
649                let completion: completion::CompletionResponse<CompletionResponse> =
650                    json_response.try_into()?;
651                Ok(completion)
652            } else {
653                Err(CompletionError::ProviderError(
654                    String::from_utf8_lossy(&body).to_string(),
655                ))
656            }
657        }
658        .instrument(llm_span)
659        .await
660    }
661
662    #[cfg_attr(feature = "worker", worker::send)]
663    async fn stream(
664        &self,
665        request: CompletionRequest,
666    ) -> Result<
667        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
668        CompletionError,
669    > {
670        CompletionModel::stream(self, request).await
671    }
672}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_path_to_error::deserialize;

    /// A tool-call response deserializes with its id, finish reason, usage and tool call.
    #[test]
    fn test_deserialize_completion_response() {
        let json_data = r#"
        {
            "id": "abc123",
            "message": {
                "role": "assistant",
                "tool_plan": "I will use the subtract tool to find the difference between 2 and 5.",
                "tool_calls": [
                        {
                            "id": "subtract_sm6ps6fb6y9f",
                            "type": "function",
                            "function": {
                                "name": "subtract",
                                "arguments": "{\"x\":5,\"y\":2}"
                            }
                        }
                    ]
                },
                "finish_reason": "TOOL_CALL",
                "usage": {
                "billed_units": {
                    "input_tokens": 78,
                    "output_tokens": 27
                },
                "tokens": {
                    "input_tokens": 1028,
                    "output_tokens": 63
                }
            }
        }
        "#;

        let mut deserializer = serde_json::Deserializer::from_str(json_data);
        let result: Result<CompletionResponse, _> = deserialize(&mut deserializer);

        let response = result.unwrap();
        let (_, citations, tool_calls) = response.message();
        let CompletionResponse {
            id,
            finish_reason,
            usage,
            ..
        } = response;

        assert_eq!(id, "abc123");
        assert_eq!(finish_reason, FinishReason::ToolCall);

        let Usage {
            billed_units,
            tokens,
        } = usage.unwrap();
        let BilledUnits {
            input_tokens: billed_input_tokens,
            output_tokens: billed_output_tokens,
            ..
        } = billed_units.unwrap();
        let Tokens {
            input_tokens,
            output_tokens,
        } = tokens.unwrap();

        assert_eq!(billed_input_tokens.unwrap(), 78.0);
        assert_eq!(billed_output_tokens.unwrap(), 27.0);
        assert_eq!(input_tokens.unwrap(), 1028.0);
        assert_eq!(output_tokens.unwrap(), 63.0);

        assert!(citations.is_empty());
        assert_eq!(tool_calls.len(), 1);

        let ToolCallFunction { name, arguments } = tool_calls[0].function.clone().unwrap();

        assert_eq!(name, "subtract");
        assert_eq!(arguments, serde_json::json!({"x": 5, "y": 2}));
    }

    /// A generic text message becomes exactly one provider user message and converts back
    /// to exactly one generic user message.
    #[test]
    fn test_convert_completion_message_to_message_and_back() {
        let completion_message = completion::Message::User {
            content: OneOrMany::one(completion::message::UserContent::Text(
                completion::message::Text {
                    text: "Hello, world!".to_string(),
                },
            )),
        };

        let messages: Vec<Message> = completion_message.try_into().unwrap();
        assert_eq!(
            messages,
            vec![Message::User {
                content: OneOrMany::one(UserContent::Text {
                    text: "Hello, world!".to_string(),
                }),
            }]
        );

        let converted_back: Vec<completion::Message> = messages
            .into_iter()
            .map(|msg| msg.try_into().unwrap())
            .collect::<Vec<_>>();
        assert!(matches!(
            converted_back.as_slice(),
            [completion::Message::User { .. }]
        ));
    }

    /// A provider text message survives the round trip through the generic message type
    /// unchanged.
    #[test]
    fn test_convert_message_to_completion_message_and_back() {
        let message = Message::User {
            content: OneOrMany::one(UserContent::Text {
                text: "Hello, world!".to_string(),
            }),
        };

        let completion_message: completion::Message = message.clone().try_into().unwrap();
        let converted_back: Vec<Message> = completion_message.try_into().unwrap();

        assert_eq!(converted_back, vec![message]);
    }
}