rig/providers/cohere/
completion.rs

1use crate::{
2    completion::{self, CompletionError},
3    json_utils, message, OneOrMany,
4};
5use std::collections::HashMap;
6
7use super::client::Client;
8use crate::completion::CompletionRequest;
9use crate::providers::cohere::streaming::StreamingCompletionResponse;
10use serde::{Deserialize, Serialize};
11use serde_json::{json, Value};
12
/// Response body returned by Cohere's `/v2/chat` endpoint (non-streaming).
#[derive(Debug, Deserialize)]
pub struct CompletionResponse {
    /// Identifier of the generation, as returned by the API.
    pub id: String,
    /// Why generation stopped (e.g. `TOOL_CALL`).
    pub finish_reason: FinishReason,
    /// The generated message; read it through `CompletionResponse::message`.
    message: Message,
    /// Token accounting; `None` when the API omitted it.
    #[serde(default)]
    pub usage: Option<Usage>,
}
21
22impl CompletionResponse {
23    /// Return that parts of the response for assistant messages w/o dealing with the other variants
24    pub fn message(&self) -> (Vec<AssistantContent>, Vec<Citation>, Vec<ToolCall>) {
25        let Message::Assistant {
26            content,
27            citations,
28            tool_calls,
29            ..
30        } = self.message.clone()
31        else {
32            unreachable!("Completion responses will only return an assistant message")
33        };
34
35        (content, citations, tool_calls)
36    }
37}
38
/// Reason Cohere reports for why generation stopped.
///
/// Deserialized from SCREAMING_SNAKE_CASE strings, e.g. `"TOOL_CALL"` -> `FinishReason::ToolCall`.
#[derive(Debug, Deserialize, PartialEq, Eq, Clone)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum FinishReason {
    MaxTokens,
    StopSequence,
    Complete,
    Error,
    ToolCall,
}
48
/// Usage information attached to a Cohere response.
#[derive(Debug, Deserialize, Clone)]
pub struct Usage {
    /// Units billed for the request; `None` when absent.
    #[serde(default)]
    pub billed_units: Option<BilledUnits>,
    /// Token counts reported separately from the billed units; `None` when absent.
    #[serde(default)]
    pub tokens: Option<Tokens>,
}
56
/// Billed units reported by Cohere.
///
/// Every field is optional and defaults to `None` when missing. Counts are stored as `f64`;
/// the API's integer counts (e.g. `78`) deserialize fine into it.
#[derive(Debug, Deserialize, Clone)]
pub struct BilledUnits {
    #[serde(default)]
    pub output_tokens: Option<f64>,
    #[serde(default)]
    pub classifications: Option<f64>,
    #[serde(default)]
    pub search_units: Option<f64>,
    #[serde(default)]
    pub input_tokens: Option<f64>,
}
68
/// Input/output token counts reported by Cohere (stored as `f64`, `None` when missing).
#[derive(Debug, Deserialize, Clone)]
pub struct Tokens {
    #[serde(default)]
    pub input_tokens: Option<f64>,
    #[serde(default)]
    pub output_tokens: Option<f64>,
}
76
77impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
78    type Error = CompletionError;
79
80    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
81        let (content, _, tool_calls) = response.message();
82
83        let model_response = if !tool_calls.is_empty() {
84            OneOrMany::many(
85                tool_calls
86                    .into_iter()
87                    .filter_map(|tool_call| {
88                        let ToolCallFunction { name, arguments } = tool_call.function?;
89                        let id = tool_call.id.unwrap_or_else(|| name.clone());
90
91                        Some(completion::AssistantContent::tool_call(id, name, arguments))
92                    })
93                    .collect::<Vec<_>>(),
94            )
95            .expect("We have atleast 1 tool call in this if block")
96        } else {
97            OneOrMany::many(content.into_iter().map(|content| match content {
98                AssistantContent::Text { text } => completion::AssistantContent::text(text),
99            }))
100            .map_err(|_| {
101                CompletionError::ResponseError(
102                    "Response contained no message or tool call (empty)".to_owned(),
103                )
104            })?
105        };
106
107        Ok(completion::CompletionResponse {
108            choice: OneOrMany::many(model_response).expect("There is atleast one content"),
109            raw_response: response,
110        })
111    }
112}
113
/// A document in Cohere's format: an `id` plus an arbitrary JSON `data` object.
///
/// Built from rig documents by the `From<completion::Document>` impl, which stores the
/// document text under the `"text"` key of `data`.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
pub struct Document {
    pub id: String,
    pub data: HashMap<String, serde_json::Value>,
}
119
120impl From<completion::Document> for Document {
121    fn from(document: completion::Document) -> Self {
122        let mut data: HashMap<String, serde_json::Value> = HashMap::new();
123
124        // We use `.into()` here explicitly since the `document.additional_props` type will likely
125        //  evolve into `serde_json::Value` in the future.
126        document
127            .additional_props
128            .into_iter()
129            .for_each(|(key, value)| {
130                data.insert(key, value.into());
131            });
132
133        data.insert("text".to_string(), document.text.into());
134
135        Self {
136            id: document.id,
137            data,
138        }
139    }
140}
141
/// A tool call requested by the model inside an assistant message.
///
/// Every field is optional and defaults to `None` when absent from the JSON.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
pub struct ToolCall {
    /// Call identifier, echoed back as `tool_call_id` in the matching `tool` message.
    #[serde(default)]
    pub id: Option<String>,
    #[serde(default)]
    pub r#type: Option<ToolType>,
    #[serde(default)]
    pub function: Option<ToolCallFunction>,
}
151
/// Name and arguments of a tool call.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
pub struct ToolCallFunction {
    pub name: String,
    /// On the wire the arguments are a JSON-encoded string (e.g. `"{\"x\":5,\"y\":2}"`);
    /// `stringified_json` converts between that string and a `serde_json::Value`.
    #[serde(with = "json_utils::stringified_json")]
    pub arguments: serde_json::Value,
}
158
/// Kind of tool. Only `"function"` exists, and it is the default.
#[derive(Clone, Default, Debug, Deserialize, Serialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}
165
/// A tool definition, as sent in the `tools` array of a chat request.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
pub struct Tool {
    pub r#type: ToolType,
    pub function: Function,
}
171
/// The function part of a [`Tool`] definition.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
pub struct Function {
    pub name: String,
    #[serde(default)]
    pub description: Option<String>,
    /// Parameter schema, passed through unchanged (presumably a JSON Schema object — not
    /// validated here).
    pub parameters: serde_json::Value,
}
179
180impl From<completion::ToolDefinition> for Tool {
181    fn from(tool: completion::ToolDefinition) -> Self {
182        Self {
183            r#type: ToolType::default(),
184            function: Function {
185                name: tool.name,
186                description: Some(tool.description),
187                parameters: tool.parameters,
188            },
189        }
190    }
191}
192
/// A chat message in Cohere's v2 format, tagged by its `role`
/// (`"user"`, `"assistant"`, `"tool"` or `"system"`).
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    User {
        content: OneOrMany<UserContent>,
    },

    /// Model output. Every field defaults to empty/`None` when absent from the JSON.
    Assistant {
        #[serde(default)]
        content: Vec<AssistantContent>,
        #[serde(default)]
        citations: Vec<Citation>,
        #[serde(default)]
        tool_calls: Vec<ToolCall>,
        /// Free-text plan the model emits alongside tool calls.
        #[serde(default)]
        tool_plan: Option<String>,
    },

    /// Result of a tool call; `tool_call_id` refers to the originating [`ToolCall`] id.
    Tool {
        content: OneOrMany<ToolResultContent>,
        tool_call_id: String,
    },

    System {
        content: String,
    },
}
220
221#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
222#[serde(tag = "type", rename_all = "lowercase")]
223pub enum UserContent {
224    Text { text: String },
225    ImageUrl { image_url: ImageUrl },
226}
227
/// A content part of an `assistant` message, e.g. `{"type": "text", "text": ...}`.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum AssistantContent {
    Text { text: String },
}
233
/// Image reference used by `UserContent::ImageUrl`, serialized as `{"url": ...}`.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
pub struct ImageUrl {
    pub url: String,
}
238
239#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
240pub enum ToolResultContent {
241    Text { text: String },
242    Document { document: Document },
243}
244
/// A citation attached to an assistant message.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct Citation {
    /// Start offset of the cited span (unit not stated here — TODO confirm against API docs).
    #[serde(default)]
    pub start: Option<u32>,
    /// End offset of the cited span.
    #[serde(default)]
    pub end: Option<u32>,
    /// The cited text, if provided.
    #[serde(default)]
    pub text: Option<String>,
    /// Kind of citation, read from/written to the JSON key `type`.
    #[serde(rename = "type")]
    pub citation_type: Option<CitationType>,
    /// Sources backing the citation; empty when absent.
    #[serde(default)]
    pub sources: Vec<Source>,
}
258
/// Origin of a citation, tagged by `type` (`"document"` or `"tool"`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum Source {
    Document {
        id: Option<String>,
        document: Option<serde_json::Map<String, serde_json::Value>>,
    },
    Tool {
        id: Option<String>,
        tool_output: Option<serde_json::Map<String, serde_json::Value>>,
    },
}
271
/// Kind of citation: `"TEXT_CONTENT"` or `"PLAN"` on the wire.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum CitationType {
    TextContent,
    Plan,
}
278
279impl TryFrom<message::Message> for Vec<Message> {
280    type Error = message::MessageError;
281
282    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
283        Ok(match message {
284            message::Message::User { content } => content
285                .into_iter()
286                .map(|content| match content {
287                    message::UserContent::Text(message::Text { text }) => Ok(Message::User {
288                        content: OneOrMany::one(UserContent::Text { text }),
289                    }),
290                    message::UserContent::ToolResult(message::ToolResult { id, content }) => {
291                        Ok(Message::Tool {
292                            tool_call_id: id,
293                            content: content.try_map(|content| match content {
294                                message::ToolResultContent::Text(text) => {
295                                    Ok(ToolResultContent::Text { text: text.text })
296                                }
297                                _ => Err(message::MessageError::ConversionError(
298                                    "Only text tool result content is supported by Cohere"
299                                        .to_owned(),
300                                )),
301                            })?,
302                        })
303                    }
304                    _ => Err(message::MessageError::ConversionError(
305                        "Only text content is supported by Cohere".to_owned(),
306                    )),
307                })
308                .collect::<Result<Vec<_>, _>>()?,
309            message::Message::Assistant { content } => {
310                let mut text_content = vec![];
311                let mut tool_calls = vec![];
312                content.into_iter().for_each(|content| match content {
313                    message::AssistantContent::Text(message::Text { text }) => {
314                        text_content.push(AssistantContent::Text { text });
315                    }
316                    message::AssistantContent::ToolCall(message::ToolCall {
317                        id,
318                        function: message::ToolFunction { name, arguments },
319                    }) => {
320                        tool_calls.push(ToolCall {
321                            id: Some(id),
322                            r#type: Some(ToolType::Function),
323                            function: Some(ToolCallFunction {
324                                name,
325                                arguments: serde_json::to_value(arguments).unwrap_or_default(),
326                            }),
327                        });
328                    }
329                });
330
331                vec![Message::Assistant {
332                    content: text_content,
333                    citations: vec![],
334                    tool_calls,
335                    tool_plan: None,
336                }]
337            }
338        })
339    }
340}
341
342impl TryFrom<Message> for message::Message {
343    type Error = message::MessageError;
344
345    fn try_from(message: Message) -> Result<Self, Self::Error> {
346        match message {
347            Message::User { content } => Ok(message::Message::User {
348                content: content.map(|content| match content {
349                    UserContent::Text { text } => {
350                        message::UserContent::Text(message::Text { text })
351                    }
352                    UserContent::ImageUrl { image_url } => message::UserContent::image(
353                        image_url.url,
354                        Some(message::ContentFormat::String),
355                        None,
356                        None,
357                    ),
358                }),
359            }),
360            Message::Assistant {
361                content,
362                tool_calls,
363                ..
364            } => {
365                let mut content = content
366                    .into_iter()
367                    .map(|content| match content {
368                        AssistantContent::Text { text } => message::AssistantContent::text(text),
369                    })
370                    .collect::<Vec<_>>();
371
372                content.extend(tool_calls.into_iter().filter_map(|tool_call| {
373                    let ToolCallFunction { name, arguments } = tool_call.function?;
374
375                    Some(message::AssistantContent::tool_call(
376                        tool_call.id.unwrap_or_else(|| name.clone()),
377                        name,
378                        arguments,
379                    ))
380                }));
381
382                let content = OneOrMany::many(content).map_err(|_| {
383                    message::MessageError::ConversionError(
384                        "Expected either text content or tool calls".to_string(),
385                    )
386                })?;
387
388                Ok(message::Message::Assistant { content })
389            }
390            Message::Tool {
391                content,
392                tool_call_id,
393            } => {
394                let content = content.try_map(|content| {
395                    Ok(match content {
396                        ToolResultContent::Text { text } => message::ToolResultContent::text(text),
397                        ToolResultContent::Document { document } => {
398                            message::ToolResultContent::text(
399                                serde_json::to_string(&document.data).map_err(|e| {
400                                    message::MessageError::ConversionError(
401                                        format!("Failed to convert tool result document content into text: {e}"),
402                                    )
403                                })?,
404                            )
405                        }
406                    })
407                })?;
408
409                Ok(message::Message::User {
410                    content: OneOrMany::one(message::UserContent::tool_result(
411                        tool_call_id,
412                        content,
413                    )),
414                })
415            }
416            Message::System { content } => Ok(message::Message::user(content)),
417        }
418    }
419}
420
/// Cohere chat completion model: a client plus the model name sent with each request.
#[derive(Clone)]
pub struct CompletionModel {
    /// Client used to reach the Cohere API (requests go to `/v2/chat`).
    pub(crate) client: Client,
    /// Model identifier sent as `model` in each request body.
    pub model: String,
}
426
427impl CompletionModel {
428    pub fn new(client: Client, model: &str) -> Self {
429        Self {
430            client,
431            model: model.to_string(),
432        }
433    }
434
435    pub(crate) fn create_completion_request(
436        &self,
437        completion_request: CompletionRequest,
438    ) -> Result<Value, CompletionError> {
439        // Build up the order of messages (context, chat_history)
440        let mut partial_history = vec![];
441        if let Some(docs) = completion_request.normalized_documents() {
442            partial_history.push(docs);
443        }
444        partial_history.extend(completion_request.chat_history);
445
446        // Initialize full history with preamble (or empty if non-existent)
447        let mut full_history: Vec<Message> = completion_request
448            .preamble
449            .map_or_else(Vec::new, |preamble| {
450                vec![Message::System { content: preamble }]
451            });
452
453        // Convert and extend the rest of the history
454        full_history.extend(
455            partial_history
456                .into_iter()
457                .map(message::Message::try_into)
458                .collect::<Result<Vec<Vec<Message>>, _>>()?
459                .into_iter()
460                .flatten()
461                .collect::<Vec<_>>(),
462        );
463
464        let request = json!({
465            "model": self.model,
466            "messages": full_history,
467            "documents": completion_request.documents,
468            "temperature": completion_request.temperature,
469            "tools": completion_request.tools.into_iter().map(Tool::from).collect::<Vec<_>>(),
470        });
471
472        if let Some(ref params) = completion_request.additional_params {
473            Ok(json_utils::merge(request.clone(), params.clone()))
474        } else {
475            Ok(request)
476        }
477    }
478}
479
480impl completion::CompletionModel for CompletionModel {
481    type Response = CompletionResponse;
482    type StreamingResponse = StreamingCompletionResponse;
483
484    #[cfg_attr(feature = "worker", worker::send)]
485    async fn completion(
486        &self,
487        completion_request: completion::CompletionRequest,
488    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
489        let request = self.create_completion_request(completion_request)?;
490        tracing::debug!(
491            "Cohere request: {}",
492            serde_json::to_string_pretty(&request)?
493        );
494
495        let response = self.client.post("/v2/chat").json(&request).send().await?;
496
497        if response.status().is_success() {
498            let text_response = response.text().await?;
499            tracing::debug!("Cohere response text: {}", text_response);
500
501            let json_response: CompletionResponse = serde_json::from_str(&text_response)?;
502            let completion: completion::CompletionResponse<CompletionResponse> =
503                json_response.try_into()?;
504            Ok(completion)
505        } else {
506            Err(CompletionError::ProviderError(response.text().await?))
507        }
508    }
509
510    #[cfg_attr(feature = "worker", worker::send)]
511    async fn stream(
512        &self,
513        request: CompletionRequest,
514    ) -> Result<
515        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
516        CompletionError,
517    > {
518        CompletionModel::stream(self, request).await
519    }
520}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_path_to_error::deserialize;

    /// A tool-call response deserializes with its id, finish reason, usage and arguments.
    #[test]
    fn test_deserialize_completion_response() {
        let json_data = r#"
        {
            "id": "abc123",
            "message": {
                "role": "assistant",
                "tool_plan": "I will use the subtract tool to find the difference between 2 and 5.",
                "tool_calls": [
                        {
                            "id": "subtract_sm6ps6fb6y9f",
                            "type": "function",
                            "function": {
                                "name": "subtract",
                                "arguments": "{\"x\":5,\"y\":2}"
                            }
                        }
                    ]
                },
                "finish_reason": "TOOL_CALL",
                "usage": {
                "billed_units": {
                    "input_tokens": 78,
                    "output_tokens": 27
                },
                "tokens": {
                    "input_tokens": 1028,
                    "output_tokens": 63
                }
            }
        }
        "#;

        let mut deserializer = serde_json::Deserializer::from_str(json_data);
        let result: Result<CompletionResponse, _> = deserialize(&mut deserializer);

        let response = result.unwrap();
        let (_, citations, tool_calls) = response.message();
        let CompletionResponse {
            id,
            finish_reason,
            usage,
            ..
        } = response;

        assert_eq!(id, "abc123");
        assert_eq!(finish_reason, FinishReason::ToolCall);

        let Usage {
            billed_units,
            tokens,
        } = usage.unwrap();
        let BilledUnits {
            input_tokens: billed_input_tokens,
            output_tokens: billed_output_tokens,
            ..
        } = billed_units.unwrap();
        let Tokens {
            input_tokens,
            output_tokens,
        } = tokens.unwrap();

        assert_eq!(billed_input_tokens.unwrap(), 78.0);
        assert_eq!(billed_output_tokens.unwrap(), 27.0);
        assert_eq!(input_tokens.unwrap(), 1028.0);
        assert_eq!(output_tokens.unwrap(), 63.0);

        assert!(citations.is_empty());
        assert_eq!(tool_calls.len(), 1);

        let ToolCallFunction { name, arguments } = tool_calls[0].function.clone().unwrap();

        assert_eq!(name, "subtract");
        assert_eq!(arguments, serde_json::json!({"x": 5, "y": 2}));
    }

    /// rig user text -> Cohere -> rig must reproduce the original message exactly.
    #[test]
    fn test_convert_completion_message_to_message_and_back() {
        let completion_message = completion::Message::User {
            content: OneOrMany::one(completion::message::UserContent::Text(
                completion::message::Text {
                    text: "Hello, world!".to_string(),
                },
            )),
        };

        let messages: Vec<Message> = completion_message.clone().try_into().unwrap();
        assert_eq!(messages.len(), 1);

        let converted_back: Vec<completion::Message> = messages
            .into_iter()
            .map(|msg| msg.try_into().unwrap())
            .collect::<Vec<_>>();
        assert_eq!(converted_back, vec![completion_message]);
    }

    /// Cohere user text -> rig -> Cohere must reproduce the original message exactly.
    #[test]
    fn test_convert_message_to_completion_message_and_back() {
        let message = Message::User {
            content: OneOrMany::one(UserContent::Text {
                text: "Hello, world!".to_string(),
            }),
        };

        let completion_message: completion::Message = message.clone().try_into().unwrap();
        let converted_back: Vec<Message> = completion_message.try_into().unwrap();
        assert_eq!(converted_back, vec![message]);
    }
}
628        let _converted_back: Vec<Message> = completion_message.try_into().unwrap();
629    }
630}