rig/providers/cohere/
completion.rs

1use crate::{
2    OneOrMany,
3    completion::{self, CompletionError},
4    json_utils, message,
5};
6use std::collections::HashMap;
7
8use super::client::Client;
9use crate::completion::CompletionRequest;
10use crate::providers::cohere::streaming::StreamingCompletionResponse;
11use serde::{Deserialize, Serialize};
12use serde_json::{Value, json};
13
14#[derive(Debug, Deserialize)]
15pub struct CompletionResponse {
16    pub id: String,
17    pub finish_reason: FinishReason,
18    message: Message,
19    #[serde(default)]
20    pub usage: Option<Usage>,
21}
22
23impl CompletionResponse {
24    /// Return that parts of the response for assistant messages w/o dealing with the other variants
25    pub fn message(&self) -> (Vec<AssistantContent>, Vec<Citation>, Vec<ToolCall>) {
26        let Message::Assistant {
27            content,
28            citations,
29            tool_calls,
30            ..
31        } = self.message.clone()
32        else {
33            unreachable!("Completion responses will only return an assistant message")
34        };
35
36        (content, citations, tool_calls)
37    }
38}
39
40#[derive(Debug, Deserialize, PartialEq, Eq, Clone)]
41#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
42pub enum FinishReason {
43    MaxTokens,
44    StopSequence,
45    Complete,
46    Error,
47    ToolCall,
48}
49
50#[derive(Debug, Deserialize, Clone)]
51pub struct Usage {
52    #[serde(default)]
53    pub billed_units: Option<BilledUnits>,
54    #[serde(default)]
55    pub tokens: Option<Tokens>,
56}
57
58#[derive(Debug, Deserialize, Clone)]
59pub struct BilledUnits {
60    #[serde(default)]
61    pub output_tokens: Option<f64>,
62    #[serde(default)]
63    pub classifications: Option<f64>,
64    #[serde(default)]
65    pub search_units: Option<f64>,
66    #[serde(default)]
67    pub input_tokens: Option<f64>,
68}
69
70#[derive(Debug, Deserialize, Clone)]
71pub struct Tokens {
72    #[serde(default)]
73    pub input_tokens: Option<f64>,
74    #[serde(default)]
75    pub output_tokens: Option<f64>,
76}
77
78impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
79    type Error = CompletionError;
80
81    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
82        let (content, _, tool_calls) = response.message();
83
84        let model_response = if !tool_calls.is_empty() {
85            OneOrMany::many(
86                tool_calls
87                    .into_iter()
88                    .filter_map(|tool_call| {
89                        let ToolCallFunction { name, arguments } = tool_call.function?;
90                        let id = tool_call.id.unwrap_or_else(|| name.clone());
91
92                        Some(completion::AssistantContent::tool_call(id, name, arguments))
93                    })
94                    .collect::<Vec<_>>(),
95            )
96            .expect("We have atleast 1 tool call in this if block")
97        } else {
98            OneOrMany::many(content.into_iter().map(|content| match content {
99                AssistantContent::Text { text } => completion::AssistantContent::text(text),
100            }))
101            .map_err(|_| {
102                CompletionError::ResponseError(
103                    "Response contained no message or tool call (empty)".to_owned(),
104                )
105            })?
106        };
107
108        Ok(completion::CompletionResponse {
109            choice: OneOrMany::many(model_response).expect("There is atleast one content"),
110            raw_response: response,
111        })
112    }
113}
114
115#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
116pub struct Document {
117    pub id: String,
118    pub data: HashMap<String, serde_json::Value>,
119}
120
121impl From<completion::Document> for Document {
122    fn from(document: completion::Document) -> Self {
123        let mut data: HashMap<String, serde_json::Value> = HashMap::new();
124
125        // We use `.into()` here explicitly since the `document.additional_props` type will likely
126        //  evolve into `serde_json::Value` in the future.
127        document
128            .additional_props
129            .into_iter()
130            .for_each(|(key, value)| {
131                data.insert(key, value.into());
132            });
133
134        data.insert("text".to_string(), document.text.into());
135
136        Self {
137            id: document.id,
138            data,
139        }
140    }
141}
142
143#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
144pub struct ToolCall {
145    #[serde(default)]
146    pub id: Option<String>,
147    #[serde(default)]
148    pub r#type: Option<ToolType>,
149    #[serde(default)]
150    pub function: Option<ToolCallFunction>,
151}
152
153#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
154pub struct ToolCallFunction {
155    pub name: String,
156    #[serde(with = "json_utils::stringified_json")]
157    pub arguments: serde_json::Value,
158}
159
160#[derive(Clone, Default, Debug, Deserialize, Serialize, PartialEq, Eq)]
161#[serde(rename_all = "lowercase")]
162pub enum ToolType {
163    #[default]
164    Function,
165}
166
167#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
168pub struct Tool {
169    pub r#type: ToolType,
170    pub function: Function,
171}
172
173#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
174pub struct Function {
175    pub name: String,
176    #[serde(default)]
177    pub description: Option<String>,
178    pub parameters: serde_json::Value,
179}
180
181impl From<completion::ToolDefinition> for Tool {
182    fn from(tool: completion::ToolDefinition) -> Self {
183        Self {
184            r#type: ToolType::default(),
185            function: Function {
186                name: tool.name,
187                description: Some(tool.description),
188                parameters: tool.parameters,
189            },
190        }
191    }
192}
193
194#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
195#[serde(tag = "role", rename_all = "lowercase")]
196pub enum Message {
197    User {
198        content: OneOrMany<UserContent>,
199    },
200
201    Assistant {
202        #[serde(default)]
203        content: Vec<AssistantContent>,
204        #[serde(default)]
205        citations: Vec<Citation>,
206        #[serde(default)]
207        tool_calls: Vec<ToolCall>,
208        #[serde(default)]
209        tool_plan: Option<String>,
210    },
211
212    Tool {
213        content: OneOrMany<ToolResultContent>,
214        tool_call_id: String,
215    },
216
217    System {
218        content: String,
219    },
220}
221
222#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
223#[serde(tag = "type", rename_all = "lowercase")]
224pub enum UserContent {
225    Text { text: String },
226    ImageUrl { image_url: ImageUrl },
227}
228
229#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
230#[serde(tag = "type", rename_all = "lowercase")]
231pub enum AssistantContent {
232    Text { text: String },
233}
234
235#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
236pub struct ImageUrl {
237    pub url: String,
238}
239
240#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
241pub enum ToolResultContent {
242    Text { text: String },
243    Document { document: Document },
244}
245
246#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
247pub struct Citation {
248    #[serde(default)]
249    pub start: Option<u32>,
250    #[serde(default)]
251    pub end: Option<u32>,
252    #[serde(default)]
253    pub text: Option<String>,
254    #[serde(rename = "type")]
255    pub citation_type: Option<CitationType>,
256    #[serde(default)]
257    pub sources: Vec<Source>,
258}
259
260#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
261#[serde(tag = "type", rename_all = "lowercase")]
262pub enum Source {
263    Document {
264        id: Option<String>,
265        document: Option<serde_json::Map<String, serde_json::Value>>,
266    },
267    Tool {
268        id: Option<String>,
269        tool_output: Option<serde_json::Map<String, serde_json::Value>>,
270    },
271}
272
273#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
274#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
275pub enum CitationType {
276    TextContent,
277    Plan,
278}
279
280impl TryFrom<message::Message> for Vec<Message> {
281    type Error = message::MessageError;
282
283    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
284        Ok(match message {
285            message::Message::User { content } => content
286                .into_iter()
287                .map(|content| match content {
288                    message::UserContent::Text(message::Text { text }) => Ok(Message::User {
289                        content: OneOrMany::one(UserContent::Text { text }),
290                    }),
291                    message::UserContent::ToolResult(message::ToolResult {
292                        id, content, ..
293                    }) => Ok(Message::Tool {
294                        tool_call_id: id,
295                        content: content.try_map(|content| match content {
296                            message::ToolResultContent::Text(text) => {
297                                Ok(ToolResultContent::Text { text: text.text })
298                            }
299                            _ => Err(message::MessageError::ConversionError(
300                                "Only text tool result content is supported by Cohere".to_owned(),
301                            )),
302                        })?,
303                    }),
304                    _ => Err(message::MessageError::ConversionError(
305                        "Only text content is supported by Cohere".to_owned(),
306                    )),
307                })
308                .collect::<Result<Vec<_>, _>>()?,
309            message::Message::Assistant { content, .. } => {
310                let mut text_content = vec![];
311                let mut tool_calls = vec![];
312                content.into_iter().for_each(|content| match content {
313                    message::AssistantContent::Text(message::Text { text }) => {
314                        text_content.push(AssistantContent::Text { text });
315                    }
316                    message::AssistantContent::ToolCall(message::ToolCall {
317                        id,
318                        function:
319                            message::ToolFunction {
320                                name, arguments, ..
321                            },
322                        ..
323                    }) => {
324                        tool_calls.push(ToolCall {
325                            id: Some(id),
326                            r#type: Some(ToolType::Function),
327                            function: Some(ToolCallFunction {
328                                name,
329                                arguments: serde_json::to_value(arguments).unwrap_or_default(),
330                            }),
331                        });
332                    }
333                });
334
335                vec![Message::Assistant {
336                    content: text_content,
337                    citations: vec![],
338                    tool_calls,
339                    tool_plan: None,
340                }]
341            }
342        })
343    }
344}
345
346impl TryFrom<Message> for message::Message {
347    type Error = message::MessageError;
348
349    fn try_from(message: Message) -> Result<Self, Self::Error> {
350        match message {
351            Message::User { content } => Ok(message::Message::User {
352                content: content.map(|content| match content {
353                    UserContent::Text { text } => {
354                        message::UserContent::Text(message::Text { text })
355                    }
356                    UserContent::ImageUrl { image_url } => message::UserContent::image(
357                        image_url.url,
358                        Some(message::ContentFormat::String),
359                        None,
360                        None,
361                    ),
362                }),
363            }),
364            Message::Assistant {
365                content,
366                tool_calls,
367                ..
368            } => {
369                let mut content = content
370                    .into_iter()
371                    .map(|content| match content {
372                        AssistantContent::Text { text } => message::AssistantContent::text(text),
373                    })
374                    .collect::<Vec<_>>();
375
376                content.extend(tool_calls.into_iter().filter_map(|tool_call| {
377                    let ToolCallFunction { name, arguments } = tool_call.function?;
378
379                    Some(message::AssistantContent::tool_call(
380                        tool_call.id.unwrap_or_else(|| name.clone()),
381                        name,
382                        arguments,
383                    ))
384                }));
385
386                let content = OneOrMany::many(content).map_err(|_| {
387                    message::MessageError::ConversionError(
388                        "Expected either text content or tool calls".to_string(),
389                    )
390                })?;
391
392                Ok(message::Message::Assistant { id: None, content })
393            }
394            Message::Tool {
395                content,
396                tool_call_id,
397            } => {
398                let content = content.try_map(|content| {
399                    Ok(match content {
400                        ToolResultContent::Text { text } => message::ToolResultContent::text(text),
401                        ToolResultContent::Document { document } => {
402                            message::ToolResultContent::text(
403                                serde_json::to_string(&document.data).map_err(|e| {
404                                    message::MessageError::ConversionError(
405                                        format!("Failed to convert tool result document content into text: {e}"),
406                                    )
407                                })?,
408                            )
409                        }
410                    })
411                })?;
412
413                Ok(message::Message::User {
414                    content: OneOrMany::one(message::UserContent::tool_result(
415                        tool_call_id,
416                        content,
417                    )),
418                })
419            }
420            Message::System { content } => Ok(message::Message::user(content)),
421        }
422    }
423}
424
425#[derive(Clone)]
426pub struct CompletionModel {
427    pub(crate) client: Client,
428    pub model: String,
429}
430
431impl CompletionModel {
432    pub fn new(client: Client, model: &str) -> Self {
433        Self {
434            client,
435            model: model.to_string(),
436        }
437    }
438
439    pub(crate) fn create_completion_request(
440        &self,
441        completion_request: CompletionRequest,
442    ) -> Result<Value, CompletionError> {
443        // Build up the order of messages (context, chat_history)
444        let mut partial_history = vec![];
445        if let Some(docs) = completion_request.normalized_documents() {
446            partial_history.push(docs);
447        }
448        partial_history.extend(completion_request.chat_history);
449
450        // Initialize full history with preamble (or empty if non-existent)
451        let mut full_history: Vec<Message> = completion_request
452            .preamble
453            .map_or_else(Vec::new, |preamble| {
454                vec![Message::System { content: preamble }]
455            });
456
457        // Convert and extend the rest of the history
458        full_history.extend(
459            partial_history
460                .into_iter()
461                .map(message::Message::try_into)
462                .collect::<Result<Vec<Vec<Message>>, _>>()?
463                .into_iter()
464                .flatten()
465                .collect::<Vec<_>>(),
466        );
467
468        let request = json!({
469            "model": self.model,
470            "messages": full_history,
471            "documents": completion_request.documents,
472            "temperature": completion_request.temperature,
473            "tools": completion_request.tools.into_iter().map(Tool::from).collect::<Vec<_>>(),
474        });
475
476        if let Some(ref params) = completion_request.additional_params {
477            Ok(json_utils::merge(request.clone(), params.clone()))
478        } else {
479            Ok(request)
480        }
481    }
482}
483
484impl completion::CompletionModel for CompletionModel {
485    type Response = CompletionResponse;
486    type StreamingResponse = StreamingCompletionResponse;
487
488    #[cfg_attr(feature = "worker", worker::send)]
489    async fn completion(
490        &self,
491        completion_request: completion::CompletionRequest,
492    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
493        let request = self.create_completion_request(completion_request)?;
494        tracing::debug!(
495            "Cohere request: {}",
496            serde_json::to_string_pretty(&request)?
497        );
498
499        let response = self.client.post("/v2/chat").json(&request).send().await?;
500
501        if response.status().is_success() {
502            let text_response = response.text().await?;
503            tracing::debug!("Cohere response text: {}", text_response);
504
505            let json_response: CompletionResponse = serde_json::from_str(&text_response)?;
506            let completion: completion::CompletionResponse<CompletionResponse> =
507                json_response.try_into()?;
508            Ok(completion)
509        } else {
510            Err(CompletionError::ProviderError(response.text().await?))
511        }
512    }
513
514    #[cfg_attr(feature = "worker", worker::send)]
515    async fn stream(
516        &self,
517        request: CompletionRequest,
518    ) -> Result<
519        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
520        CompletionError,
521    > {
522        CompletionModel::stream(self, request).await
523    }
524}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_path_to_error::deserialize;

    /// A tool-call response deserializes with usage, finish reason and decoded arguments.
    #[test]
    fn test_deserialize_completion_response() {
        let json_data = r#"
        {
            "id": "abc123",
            "message": {
                "role": "assistant",
                "tool_plan": "I will use the subtract tool to find the difference between 2 and 5.",
                "tool_calls": [
                        {
                            "id": "subtract_sm6ps6fb6y9f",
                            "type": "function",
                            "function": {
                                "name": "subtract",
                                "arguments": "{\"x\":5,\"y\":2}"
                            }
                        }
                    ]
                },
                "finish_reason": "TOOL_CALL",
                "usage": {
                "billed_units": {
                    "input_tokens": 78,
                    "output_tokens": 27
                },
                "tokens": {
                    "input_tokens": 1028,
                    "output_tokens": 63
                }
            }
        }
        "#;

        let mut deserializer = serde_json::Deserializer::from_str(json_data);
        let result: Result<CompletionResponse, _> = deserialize(&mut deserializer);

        let response = result.unwrap();
        let (_, citations, tool_calls) = response.message();
        let CompletionResponse {
            id,
            finish_reason,
            usage,
            ..
        } = response;

        assert_eq!(id, "abc123");
        assert_eq!(finish_reason, FinishReason::ToolCall);

        let Usage {
            billed_units,
            tokens,
        } = usage.unwrap();
        let BilledUnits {
            input_tokens: billed_input_tokens,
            output_tokens: billed_output_tokens,
            ..
        } = billed_units.unwrap();
        let Tokens {
            input_tokens,
            output_tokens,
        } = tokens.unwrap();

        assert_eq!(billed_input_tokens.unwrap(), 78.0);
        assert_eq!(billed_output_tokens.unwrap(), 27.0);
        assert_eq!(input_tokens.unwrap(), 1028.0);
        assert_eq!(output_tokens.unwrap(), 63.0);

        assert!(citations.is_empty());
        assert_eq!(tool_calls.len(), 1);

        let ToolCallFunction { name, arguments } = tool_calls[0].function.clone().unwrap();

        assert_eq!(name, "subtract");
        assert_eq!(arguments, serde_json::json!({"x": 5, "y": 2}));
    }

    /// rig user text → Cohere messages → rig messages yields the original message.
    #[test]
    fn test_convert_completion_message_to_message_and_back() {
        let completion_message = completion::Message::User {
            content: OneOrMany::one(completion::message::UserContent::Text(
                completion::message::Text {
                    text: "Hello, world!".to_string(),
                },
            )),
        };

        let messages: Vec<Message> = completion_message.clone().try_into().unwrap();
        assert_eq!(
            messages,
            vec![Message::User {
                content: OneOrMany::one(UserContent::Text {
                    text: "Hello, world!".to_string(),
                }),
            }]
        );

        let converted_back = messages
            .into_iter()
            .map(|msg| msg.try_into().unwrap())
            .collect::<Vec<completion::Message>>();
        assert_eq!(converted_back, vec![completion_message]);
    }

    /// Cohere user text → rig message → Cohere messages yields the original message.
    #[test]
    fn test_convert_message_to_completion_message_and_back() {
        let message = Message::User {
            content: OneOrMany::one(UserContent::Text {
                text: "Hello, world!".to_string(),
            }),
        };

        let completion_message: completion::Message = message.clone().try_into().unwrap();
        let converted_back: Vec<Message> = completion_message.try_into().unwrap();
        assert_eq!(converted_back, vec![message]);
    }
}