rig/providers/cohere/completion.rs

1use crate::{
2    OneOrMany,
3    completion::{self, CompletionError},
4    json_utils, message,
5};
6use std::collections::HashMap;
7
8use super::client::Client;
9use crate::completion::CompletionRequest;
10use crate::providers::cohere::streaming::StreamingCompletionResponse;
11use serde::{Deserialize, Serialize};
12use serde_json::{Value, json};
13
/// Parsed body of a successful (non-streaming) response from Cohere's `/v2/chat` endpoint.
14#[derive(Debug, Deserialize, Serialize)]
15pub struct CompletionResponse {
16    pub id: String,
17    pub finish_reason: FinishReason,
    /// The returned message; read its parts through [`CompletionResponse::message`].
18    message: Message,
    /// Token accounting; `None` when the payload has no `usage` field.
19    #[serde(default)]
20    pub usage: Option<Usage>,
21}
22
23impl CompletionResponse {
24    /// Return that parts of the response for assistant messages w/o dealing with the other variants
25    pub fn message(&self) -> (Vec<AssistantContent>, Vec<Citation>, Vec<ToolCall>) {
26        let Message::Assistant {
27            content,
28            citations,
29            tool_calls,
30            ..
31        } = self.message.clone()
32        else {
33            unreachable!("Completion responses will only return an assistant message")
34        };
35
36        (content, citations, tool_calls)
37    }
38}
39
/// Why Cohere stopped generating, as reported in `finish_reason`.
///
/// Deserialized from `SCREAMING_SNAKE_CASE` strings, e.g. `"TOOL_CALL"` -> [`FinishReason::ToolCall`].
40#[derive(Debug, Deserialize, PartialEq, Eq, Clone, Serialize)]
41#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
42pub enum FinishReason {
43    MaxTokens,
44    StopSequence,
45    Complete,
46    Error,
47    ToolCall,
48}
49
/// Usage block of a Cohere response; either part may be absent from the payload.
50#[derive(Debug, Deserialize, Clone, Serialize)]
51pub struct Usage {
    /// Billed units. These can differ from the raw counts in `tokens`.
52    #[serde(default)]
53    pub billed_units: Option<BilledUnits>,
    /// Raw token counts; these are the values converted into rig's `completion::Usage`.
54    #[serde(default)]
55    pub tokens: Option<Tokens>,
56}
57
/// Billed usage units. Every count is optional and parsed as `f64`.
58#[derive(Debug, Deserialize, Clone, Serialize)]
59pub struct BilledUnits {
60    #[serde(default)]
61    pub output_tokens: Option<f64>,
62    #[serde(default)]
63    pub classifications: Option<f64>,
64    #[serde(default)]
65    pub search_units: Option<f64>,
66    #[serde(default)]
67    pub input_tokens: Option<f64>,
68}
69
/// Raw input/output token counts. Both are optional and parsed as `f64`.
70#[derive(Debug, Deserialize, Clone, Serialize)]
71pub struct Tokens {
72    #[serde(default)]
73    pub input_tokens: Option<f64>,
74    #[serde(default)]
75    pub output_tokens: Option<f64>,
76}
77
78impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
79    type Error = CompletionError;
80
81    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
82        let (content, _, tool_calls) = response.message();
83
84        let model_response = if !tool_calls.is_empty() {
85            OneOrMany::many(
86                tool_calls
87                    .into_iter()
88                    .filter_map(|tool_call| {
89                        let ToolCallFunction { name, arguments } = tool_call.function?;
90                        let id = tool_call.id.unwrap_or_else(|| name.clone());
91
92                        Some(completion::AssistantContent::tool_call(id, name, arguments))
93                    })
94                    .collect::<Vec<_>>(),
95            )
96            .expect("We have atleast 1 tool call in this if block")
97        } else {
98            OneOrMany::many(content.into_iter().map(|content| match content {
99                AssistantContent::Text { text } => completion::AssistantContent::text(text),
100            }))
101            .map_err(|_| {
102                CompletionError::ResponseError(
103                    "Response contained no message or tool call (empty)".to_owned(),
104                )
105            })?
106        };
107
108        let usage = response
109            .usage
110            .as_ref()
111            .and_then(|usage| usage.tokens.as_ref())
112            .map(|tokens| {
113                let input_tokens = tokens.input_tokens.unwrap_or(0.0);
114                let output_tokens = tokens.output_tokens.unwrap_or(0.0);
115
116                completion::Usage {
117                    input_tokens: input_tokens as u64,
118                    output_tokens: output_tokens as u64,
119                    total_tokens: (input_tokens + output_tokens) as u64,
120                }
121            })
122            .unwrap_or_default();
123
124        Ok(completion::CompletionResponse {
125            choice: OneOrMany::many(model_response).expect("There is atleast one content"),
126            usage,
127            raw_response: response,
128        })
129    }
130}
131
/// A Cohere document: an `id` plus a map of arbitrary JSON fields in `data`.
///
/// Used as the payload of [`ToolResultContent::Document`] and built from rig documents via `From`.
132#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
133pub struct Document {
134    pub id: String,
135    pub data: HashMap<String, serde_json::Value>,
136}
137
138impl From<completion::Document> for Document {
139    fn from(document: completion::Document) -> Self {
140        let mut data: HashMap<String, serde_json::Value> = HashMap::new();
141
142        // We use `.into()` here explicitly since the `document.additional_props` type will likely
143        //  evolve into `serde_json::Value` in the future.
144        document
145            .additional_props
146            .into_iter()
147            .for_each(|(key, value)| {
148                data.insert(key, value.into());
149            });
150
151        data.insert("text".to_string(), document.text.into());
152
153        Self {
154            id: document.id,
155            data,
156        }
157    }
158}
159
/// A tool call inside an assistant message. Every field is optional in the payload.
///
/// Conversions in this file skip calls without a `function` and use the function name as the
/// id when `id` is missing.
160#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
161pub struct ToolCall {
162    #[serde(default)]
163    pub id: Option<String>,
164    #[serde(default)]
165    pub r#type: Option<ToolType>,
166    #[serde(default)]
167    pub function: Option<ToolCallFunction>,
168}
169
/// Function name and arguments of a tool call.
170#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
171pub struct ToolCallFunction {
172    pub name: String,
    /// On the wire this is a JSON-encoded string (e.g. `"{\"x\":5}"`); `stringified_json`
    /// converts it to and from a `serde_json::Value`.
173    #[serde(with = "json_utils::stringified_json")]
174    pub arguments: serde_json::Value,
175}
176
/// Kind of a tool or tool call. Only `function` exists (serialized as `"function"`), and it is
/// the default.
177#[derive(Clone, Default, Debug, Deserialize, Serialize, PartialEq, Eq)]
178#[serde(rename_all = "lowercase")]
179pub enum ToolType {
180    #[default]
181    Function,
182}
183
/// A tool definition sent in the request's `tools` array.
184#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
185pub struct Tool {
186    pub r#type: ToolType,
187    pub function: Function,
188}
189
/// Name, optional description and `parameters` JSON of a tool, copied from rig's
/// `ToolDefinition` by the `From` impl.
190#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
191pub struct Function {
192    pub name: String,
193    #[serde(default)]
194    pub description: Option<String>,
195    pub parameters: serde_json::Value,
196}
197
198impl From<completion::ToolDefinition> for Tool {
199    fn from(tool: completion::ToolDefinition) -> Self {
200        Self {
201            r#type: ToolType::default(),
202            function: Function {
203                name: tool.name,
204                description: Some(tool.description),
205                parameters: tool.parameters,
206            },
207        }
208    }
209}
210
/// A Cohere v2 chat message, tagged by its `role` field (`user`, `assistant`, `tool`, `system`).
211#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
212#[serde(tag = "role", rename_all = "lowercase")]
213pub enum Message {
214    User {
215        content: OneOrMany<UserContent>,
216    },
217
    /// Assistant turn. `content`, `citations` and `tool_calls` default to empty and
    /// `tool_plan` to `None` when absent.
218    Assistant {
219        #[serde(default)]
220        content: Vec<AssistantContent>,
221        #[serde(default)]
222        citations: Vec<Citation>,
223        #[serde(default)]
224        tool_calls: Vec<ToolCall>,
225        #[serde(default)]
226        tool_plan: Option<String>,
227    },
228
    /// Tool output. `tool_call_id` is filled from rig's `ToolResult` id when converting.
229    Tool {
230        content: OneOrMany<ToolResultContent>,
231        tool_call_id: String,
232    },
233
    /// System text; the request builder creates it from the request preamble.
234    System {
235        content: String,
236    },
237}
238
239#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
240#[serde(tag = "type", rename_all = "lowercase")]
241pub enum UserContent {
242    Text { text: String },
243    ImageUrl { image_url: ImageUrl },
244}
245
/// A content part of a Cohere assistant message. Only `{"type": "text", "text": ...}` is modelled.
246#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
247#[serde(tag = "type", rename_all = "lowercase")]
248pub enum AssistantContent {
249    Text { text: String },
250}
251
/// Image reference carried by [`UserContent::ImageUrl`].
252#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
253pub struct ImageUrl {
254    pub url: String,
255}
256
257#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
258pub enum ToolResultContent {
259    Text { text: String },
260    Document { document: Document },
261}
262
/// A citation attached to an assistant message. Every field may be absent in the payload.
263#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
264pub struct Citation {
    // NOTE(review): `start`/`end` are presumably character offsets into the cited text —
    // confirm against Cohere's docs.
265    #[serde(default)]
266    pub start: Option<u32>,
267    #[serde(default)]
268    pub end: Option<u32>,
269    #[serde(default)]
270    pub text: Option<String>,
271    #[serde(rename = "type")]
272    pub citation_type: Option<CitationType>,
273    #[serde(default)]
274    pub sources: Vec<Source>,
275}
276
/// Origin of a citation, tagged by `type` (`document` or `tool`).
///
/// The cited `document` / `tool_output` bodies are kept as raw JSON objects.
277#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
278#[serde(tag = "type", rename_all = "lowercase")]
279pub enum Source {
280    Document {
281        id: Option<String>,
282        document: Option<serde_json::Map<String, serde_json::Value>>,
283    },
284    Tool {
285        id: Option<String>,
286        tool_output: Option<serde_json::Map<String, serde_json::Value>>,
287    },
288}
289
/// Kind of text a citation refers to: `TEXT_CONTENT` or `PLAN`.
290#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
291#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
292pub enum CitationType {
293    TextContent,
294    Plan,
295}
296
297impl TryFrom<message::Message> for Vec<Message> {
298    type Error = message::MessageError;
299
300    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
301        Ok(match message {
302            message::Message::User { content } => content
303                .into_iter()
304                .map(|content| match content {
305                    message::UserContent::Text(message::Text { text }) => Ok(Message::User {
306                        content: OneOrMany::one(UserContent::Text { text }),
307                    }),
308                    message::UserContent::ToolResult(message::ToolResult {
309                        id, content, ..
310                    }) => Ok(Message::Tool {
311                        tool_call_id: id,
312                        content: content.try_map(|content| match content {
313                            message::ToolResultContent::Text(text) => {
314                                Ok(ToolResultContent::Text { text: text.text })
315                            }
316                            _ => Err(message::MessageError::ConversionError(
317                                "Only text tool result content is supported by Cohere".to_owned(),
318                            )),
319                        })?,
320                    }),
321                    _ => Err(message::MessageError::ConversionError(
322                        "Only text content is supported by Cohere".to_owned(),
323                    )),
324                })
325                .collect::<Result<Vec<_>, _>>()?,
326            message::Message::Assistant { content, .. } => {
327                let mut text_content = vec![];
328                let mut tool_calls = vec![];
329                content.into_iter().for_each(|content| match content {
330                    message::AssistantContent::Text(message::Text { text }) => {
331                        text_content.push(AssistantContent::Text { text });
332                    }
333                    message::AssistantContent::ToolCall(message::ToolCall {
334                        id,
335                        function:
336                            message::ToolFunction {
337                                name, arguments, ..
338                            },
339                        ..
340                    }) => {
341                        tool_calls.push(ToolCall {
342                            id: Some(id),
343                            r#type: Some(ToolType::Function),
344                            function: Some(ToolCallFunction {
345                                name,
346                                arguments: serde_json::to_value(arguments).unwrap_or_default(),
347                            }),
348                        });
349                    }
350                    message::AssistantContent::Reasoning(_) => {
351                        unimplemented!("Reasoning is not natively supported on Cohere V2");
352                    }
353                });
354
355                vec![Message::Assistant {
356                    content: text_content,
357                    citations: vec![],
358                    tool_calls,
359                    tool_plan: None,
360                }]
361            }
362        })
363    }
364}
365
366impl TryFrom<Message> for message::Message {
367    type Error = message::MessageError;
368
369    fn try_from(message: Message) -> Result<Self, Self::Error> {
370        match message {
371            Message::User { content } => Ok(message::Message::User {
372                content: content.map(|content| match content {
373                    UserContent::Text { text } => {
374                        message::UserContent::Text(message::Text { text })
375                    }
376                    UserContent::ImageUrl { image_url } => message::UserContent::image(
377                        image_url.url,
378                        Some(message::ContentFormat::String),
379                        None,
380                        None,
381                    ),
382                }),
383            }),
384            Message::Assistant {
385                content,
386                tool_calls,
387                ..
388            } => {
389                let mut content = content
390                    .into_iter()
391                    .map(|content| match content {
392                        AssistantContent::Text { text } => message::AssistantContent::text(text),
393                    })
394                    .collect::<Vec<_>>();
395
396                content.extend(tool_calls.into_iter().filter_map(|tool_call| {
397                    let ToolCallFunction { name, arguments } = tool_call.function?;
398
399                    Some(message::AssistantContent::tool_call(
400                        tool_call.id.unwrap_or_else(|| name.clone()),
401                        name,
402                        arguments,
403                    ))
404                }));
405
406                let content = OneOrMany::many(content).map_err(|_| {
407                    message::MessageError::ConversionError(
408                        "Expected either text content or tool calls".to_string(),
409                    )
410                })?;
411
412                Ok(message::Message::Assistant { id: None, content })
413            }
414            Message::Tool {
415                content,
416                tool_call_id,
417            } => {
418                let content = content.try_map(|content| {
419                    Ok(match content {
420                        ToolResultContent::Text { text } => message::ToolResultContent::text(text),
421                        ToolResultContent::Document { document } => {
422                            message::ToolResultContent::text(
423                                serde_json::to_string(&document.data).map_err(|e| {
424                                    message::MessageError::ConversionError(
425                                        format!("Failed to convert tool result document content into text: {e}"),
426                                    )
427                                })?,
428                            )
429                        }
430                    })
431                })?;
432
433                Ok(message::Message::User {
434                    content: OneOrMany::one(message::UserContent::tool_result(
435                        tool_call_id,
436                        content,
437                    )),
438                })
439            }
440            Message::System { content } => Ok(message::Message::user(content)),
441        }
442    }
443}
444
/// A Cohere chat model handle.
///
/// `client` is used to POST requests to `/v2/chat`; `model` is placed in each request body.
445#[derive(Clone)]
446pub struct CompletionModel {
447    pub(crate) client: Client,
448    pub model: String,
449}
450
451impl CompletionModel {
452    pub fn new(client: Client, model: &str) -> Self {
453        Self {
454            client,
455            model: model.to_string(),
456        }
457    }
458
459    pub(crate) fn create_completion_request(
460        &self,
461        completion_request: CompletionRequest,
462    ) -> Result<Value, CompletionError> {
463        // Build up the order of messages (context, chat_history)
464        let mut partial_history = vec![];
465        if let Some(docs) = completion_request.normalized_documents() {
466            partial_history.push(docs);
467        }
468        partial_history.extend(completion_request.chat_history);
469
470        // Initialize full history with preamble (or empty if non-existent)
471        let mut full_history: Vec<Message> = completion_request
472            .preamble
473            .map_or_else(Vec::new, |preamble| {
474                vec![Message::System { content: preamble }]
475            });
476
477        // Convert and extend the rest of the history
478        full_history.extend(
479            partial_history
480                .into_iter()
481                .map(message::Message::try_into)
482                .collect::<Result<Vec<Vec<Message>>, _>>()?
483                .into_iter()
484                .flatten()
485                .collect::<Vec<_>>(),
486        );
487
488        let request = json!({
489            "model": self.model,
490            "messages": full_history,
491            "documents": completion_request.documents,
492            "temperature": completion_request.temperature,
493            "tools": completion_request.tools.into_iter().map(Tool::from).collect::<Vec<_>>(),
494        });
495
496        if let Some(ref params) = completion_request.additional_params {
497            Ok(json_utils::merge(request.clone(), params.clone()))
498        } else {
499            Ok(request)
500        }
501    }
502}
503
504impl completion::CompletionModel for CompletionModel {
505    type Response = CompletionResponse;
506    type StreamingResponse = StreamingCompletionResponse;
507
508    #[cfg_attr(feature = "worker", worker::send)]
509    async fn completion(
510        &self,
511        completion_request: completion::CompletionRequest,
512    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
513        let request = self.create_completion_request(completion_request)?;
514        tracing::debug!(
515            "Cohere request: {}",
516            serde_json::to_string_pretty(&request)?
517        );
518
519        let response = self.client.post("/v2/chat").json(&request).send().await?;
520
521        if response.status().is_success() {
522            let text_response = response.text().await?;
523            tracing::debug!("Cohere response text: {}", text_response);
524
525            let json_response: CompletionResponse = serde_json::from_str(&text_response)?;
526            let completion: completion::CompletionResponse<CompletionResponse> =
527                json_response.try_into()?;
528            Ok(completion)
529        } else {
530            Err(CompletionError::ProviderError(response.text().await?))
531        }
532    }
533
534    #[cfg_attr(feature = "worker", worker::send)]
535    async fn stream(
536        &self,
537        request: CompletionRequest,
538    ) -> Result<
539        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
540        CompletionError,
541    > {
542        CompletionModel::stream(self, request).await
543    }
544}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_path_to_error::deserialize;

    /// Deserializes a Cohere v2 tool-call response. Checks the id, finish reason, both usage
    /// blocks and the decoded (stringified-JSON) tool-call arguments.
    #[test]
    fn test_deserialize_completion_response() {
        let json_data = r#"
        {
            "id": "abc123",
            "message": {
                "role": "assistant",
                "tool_plan": "I will use the subtract tool to find the difference between 2 and 5.",
                "tool_calls": [
                        {
                            "id": "subtract_sm6ps6fb6y9f",
                            "type": "function",
                            "function": {
                                "name": "subtract",
                                "arguments": "{\"x\":5,\"y\":2}"
                            }
                        }
                    ]
                },
                "finish_reason": "TOOL_CALL",
                "usage": {
                "billed_units": {
                    "input_tokens": 78,
                    "output_tokens": 27
                },
                "tokens": {
                    "input_tokens": 1028,
                    "output_tokens": 63
                }
            }
        }
        "#;

        let mut deserializer = serde_json::Deserializer::from_str(json_data);
        let result: Result<CompletionResponse, _> = deserialize(&mut deserializer);

        let response = result.unwrap();
        let (_, citations, tool_calls) = response.message();
        let CompletionResponse {
            id,
            finish_reason,
            usage,
            ..
        } = response;

        assert_eq!(id, "abc123");
        assert_eq!(finish_reason, FinishReason::ToolCall);

        let Usage {
            billed_units,
            tokens,
        } = usage.unwrap();
        let BilledUnits {
            input_tokens: billed_input_tokens,
            output_tokens: billed_output_tokens,
            ..
        } = billed_units.unwrap();
        let Tokens {
            input_tokens,
            output_tokens,
        } = tokens.unwrap();

        assert_eq!(billed_input_tokens.unwrap(), 78.0);
        assert_eq!(billed_output_tokens.unwrap(), 27.0);
        assert_eq!(input_tokens.unwrap(), 1028.0);
        assert_eq!(output_tokens.unwrap(), 63.0);

        assert!(citations.is_empty());
        assert_eq!(tool_calls.len(), 1);

        let ToolCallFunction { name, arguments } = tool_calls[0].function.clone().unwrap();

        assert_eq!(name, "subtract");
        assert_eq!(arguments, serde_json::json!({"x": 5, "y": 2}));
    }

    /// A plain rig user text message must survive rig -> Cohere -> rig unchanged.
    #[test]
    fn test_convert_completion_message_to_message_and_back() {
        let completion_message = completion::Message::User {
            content: OneOrMany::one(completion::message::UserContent::Text(
                completion::message::Text {
                    text: "Hello, world!".to_string(),
                },
            )),
        };

        let messages: Vec<Message> = completion_message.clone().try_into().unwrap();
        // A single text item maps to exactly one Cohere user message.
        assert_eq!(messages.len(), 1);

        let converted_back: Vec<completion::Message> = messages
            .into_iter()
            .map(|msg| msg.try_into().unwrap())
            .collect::<Vec<_>>();
        assert_eq!(converted_back, vec![completion_message]);
    }

    /// A plain Cohere user text message must survive Cohere -> rig -> Cohere unchanged.
    #[test]
    fn test_convert_message_to_completion_message_and_back() {
        let message = Message::User {
            content: OneOrMany::one(UserContent::Text {
                text: "Hello, world!".to_string(),
            }),
        };

        let completion_message: completion::Message = message.clone().try_into().unwrap();
        let converted_back: Vec<Message> = completion_message.try_into().unwrap();
        assert_eq!(converted_back, vec![message]);
    }
}