1use crate::{
2 OneOrMany,
3 completion::{self, CompletionError},
4 json_utils, message,
5};
6use std::collections::HashMap;
7
8use super::client::Client;
9use crate::completion::CompletionRequest;
10use crate::providers::cohere::streaming::StreamingCompletionResponse;
11use serde::{Deserialize, Serialize};
12use serde_json::{Value, json};
13
/// Response body of a non-streaming Cohere `/v2/chat` call.
#[derive(Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    /// Identifier of the response as reported by the provider.
    pub id: String,
    /// Reason the generation stopped.
    pub finish_reason: FinishReason,
    // Private on purpose: read it through `CompletionResponse::message()`.
    message: Message,
    /// Token accounting; `None` when the payload carries no `usage` field.
    #[serde(default)]
    pub usage: Option<Usage>,
}
22
23impl CompletionResponse {
24 pub fn message(&self) -> (Vec<AssistantContent>, Vec<Citation>, Vec<ToolCall>) {
26 let Message::Assistant {
27 content,
28 citations,
29 tool_calls,
30 ..
31 } = self.message.clone()
32 else {
33 unreachable!("Completion responses will only return an assistant message")
34 };
35
36 (content, citations, tool_calls)
37 }
38}
39
/// Reason reported by the provider for ending generation.
///
/// Variants are (de)serialized in `SCREAMING_SNAKE_CASE`, e.g. `ToolCall` <-> `"TOOL_CALL"`.
#[derive(Debug, Deserialize, PartialEq, Eq, Clone, Serialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum FinishReason {
    /// Wire value `"MAX_TOKENS"`.
    MaxTokens,
    /// Wire value `"STOP_SEQUENCE"`.
    StopSequence,
    /// Wire value `"COMPLETE"`.
    Complete,
    /// Wire value `"ERROR"`.
    Error,
    /// Wire value `"TOOL_CALL"`.
    ToolCall,
}
49
/// Usage section of a completion response.
///
/// Both sub-sections are optional and default to `None` when missing from the payload.
#[derive(Debug, Deserialize, Clone, Serialize)]
pub struct Usage {
    /// Billed quantities, if reported.
    #[serde(default)]
    pub billed_units: Option<BilledUnits>,
    /// Token counts, if reported. These are the numbers used to build the
    /// provider-agnostic `completion::Usage`.
    #[serde(default)]
    pub tokens: Option<Tokens>,
}
57
/// Billed quantities reported in the `usage.billed_units` section.
///
/// Every field is optional and defaults to `None` when absent. Values are kept as `f64`
/// exactly as deserialized.
#[derive(Debug, Deserialize, Clone, Serialize)]
pub struct BilledUnits {
    /// Billed output tokens.
    #[serde(default)]
    pub output_tokens: Option<f64>,
    /// Billed classification units.
    #[serde(default)]
    pub classifications: Option<f64>,
    /// Billed search units.
    #[serde(default)]
    pub search_units: Option<f64>,
    /// Billed input tokens.
    #[serde(default)]
    pub input_tokens: Option<f64>,
}
69
/// Token counts reported in the `usage.tokens` section.
///
/// Missing fields deserialize to `None`; the conversion into `completion::Usage`
/// treats `None` as zero.
#[derive(Debug, Deserialize, Clone, Serialize)]
pub struct Tokens {
    /// Number of input tokens.
    #[serde(default)]
    pub input_tokens: Option<f64>,
    /// Number of output tokens.
    #[serde(default)]
    pub output_tokens: Option<f64>,
}
77
78impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
79 type Error = CompletionError;
80
81 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
82 let (content, _, tool_calls) = response.message();
83
84 let model_response = if !tool_calls.is_empty() {
85 OneOrMany::many(
86 tool_calls
87 .into_iter()
88 .filter_map(|tool_call| {
89 let ToolCallFunction { name, arguments } = tool_call.function?;
90 let id = tool_call.id.unwrap_or_else(|| name.clone());
91
92 Some(completion::AssistantContent::tool_call(id, name, arguments))
93 })
94 .collect::<Vec<_>>(),
95 )
96 .expect("We have atleast 1 tool call in this if block")
97 } else {
98 OneOrMany::many(content.into_iter().map(|content| match content {
99 AssistantContent::Text { text } => completion::AssistantContent::text(text),
100 }))
101 .map_err(|_| {
102 CompletionError::ResponseError(
103 "Response contained no message or tool call (empty)".to_owned(),
104 )
105 })?
106 };
107
108 let usage = response
109 .usage
110 .as_ref()
111 .and_then(|usage| usage.tokens.as_ref())
112 .map(|tokens| {
113 let input_tokens = tokens.input_tokens.unwrap_or(0.0);
114 let output_tokens = tokens.output_tokens.unwrap_or(0.0);
115
116 completion::Usage {
117 input_tokens: input_tokens as u64,
118 output_tokens: output_tokens as u64,
119 total_tokens: (input_tokens + output_tokens) as u64,
120 }
121 })
122 .unwrap_or_default();
123
124 Ok(completion::CompletionResponse {
125 choice: OneOrMany::many(model_response).expect("There is atleast one content"),
126 usage,
127 raw_response: response,
128 })
129 }
130}
131
/// Document in the Cohere wire shape: an identifier plus a free-form key/value payload.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
pub struct Document {
    /// Identifier of the document.
    pub id: String,
    /// Arbitrary JSON fields describing the document.
    pub data: HashMap<String, serde_json::Value>,
}
137
138impl From<completion::Document> for Document {
139 fn from(document: completion::Document) -> Self {
140 let mut data: HashMap<String, serde_json::Value> = HashMap::new();
141
142 document
145 .additional_props
146 .into_iter()
147 .for_each(|(key, value)| {
148 data.insert(key, value.into());
149 });
150
151 data.insert("text".to_string(), document.text.into());
152
153 Self {
154 id: document.id,
155 data,
156 }
157 }
158}
159
/// A tool call attached to an assistant message.
///
/// All fields are optional and default to `None` when missing from the payload.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
pub struct ToolCall {
    /// Identifier of the call. The conversions in this module fall back to the function
    /// name when it is absent.
    #[serde(default)]
    pub id: Option<String>,
    /// Kind of tool being called.
    #[serde(default)]
    pub r#type: Option<ToolType>,
    /// Function name and arguments. The conversions in this module skip calls where
    /// this is `None`.
    #[serde(default)]
    pub function: Option<ToolCallFunction>,
}
169
/// Name and arguments of a function invoked by a tool call.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
pub struct ToolCallFunction {
    /// Name of the function to call.
    pub name: String,
    /// Call arguments as a JSON value. On the wire they arrive as a JSON-encoded string
    /// (e.g. `"{\"x\":5,\"y\":2}"`); `json_utils::stringified_json` handles the
    /// conversion in both directions.
    #[serde(with = "json_utils::stringified_json")]
    pub arguments: serde_json::Value,
}
176
/// Kind of tool. Serialized in lowercase (`"function"`).
#[derive(Clone, Default, Debug, Deserialize, Serialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    /// A callable function; the only variant and therefore the default.
    #[default]
    Function,
}
183
/// Tool definition as sent in the `tools` array of a chat request.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
pub struct Tool {
    /// Kind of tool.
    pub r#type: ToolType,
    /// Description of the callable function.
    pub function: Function,
}
189
/// Function description carried by a [`Tool`] definition.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
pub struct Function {
    /// Name of the function.
    pub name: String,
    /// Optional human-readable description; `None` when missing from the payload.
    #[serde(default)]
    pub description: Option<String>,
    /// JSON value describing the function parameters (taken verbatim from the
    /// generic tool definition).
    pub parameters: serde_json::Value,
}
197
198impl From<completion::ToolDefinition> for Tool {
199 fn from(tool: completion::ToolDefinition) -> Self {
200 Self {
201 r#type: ToolType::default(),
202 function: Function {
203 name: tool.name,
204 description: Some(tool.description),
205 parameters: tool.parameters,
206 },
207 }
208 }
209}
210
/// A chat message in the Cohere wire format.
///
/// Internally tagged by the `role` field, with lowercase variant names
/// (`"user"`, `"assistant"`, `"tool"`, `"system"`).
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    /// Message authored by the user.
    User {
        content: OneOrMany<UserContent>,
    },

    /// Message authored by the model. Every field defaults when absent, so a payload
    /// with only `tool_calls` (and no `content`) still deserializes.
    Assistant {
        #[serde(default)]
        content: Vec<AssistantContent>,
        #[serde(default)]
        citations: Vec<Citation>,
        #[serde(default)]
        tool_calls: Vec<ToolCall>,
        #[serde(default)]
        tool_plan: Option<String>,
    },

    /// Result of a tool call, linked to the call through `tool_call_id`.
    Tool {
        content: OneOrMany<ToolResultContent>,
        tool_call_id: String,
    },

    /// System instructions; the request builder emits the preamble as this variant.
    System {
        content: String,
    },
}
238
239#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
240#[serde(tag = "type", rename_all = "lowercase")]
241pub enum UserContent {
242 Text { text: String },
243 ImageUrl { image_url: ImageUrl },
244}
245
/// A single content part of an assistant message, internally tagged by the `type` field.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum AssistantContent {
    /// Plain text (`"type": "text"`).
    Text { text: String },
}
251
/// Image reference carried by [`UserContent::ImageUrl`].
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
pub struct ImageUrl {
    /// Location of the image.
    pub url: String,
}
256
257#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
258pub enum ToolResultContent {
259 Text { text: String },
260 Document { document: Document },
261}
262
/// Citation attached to an assistant message.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct Citation {
    /// Start of the cited span.
    // NOTE(review): presumably an offset into the generated text — confirm against the
    // provider documentation.
    #[serde(default)]
    pub start: Option<u32>,
    /// End of the cited span (see the note on `start`).
    #[serde(default)]
    pub end: Option<u32>,
    /// Text of the citation, if provided.
    #[serde(default)]
    pub text: Option<String>,
    /// Kind of citation; carried in the `type` field on the wire.
    #[serde(rename = "type")]
    pub citation_type: Option<CitationType>,
    /// Sources backing the citation; empty when the field is absent.
    #[serde(default)]
    pub sources: Vec<Source>,
}
276
/// Origin of a citation, internally tagged by the `type` field
/// (`"document"` or `"tool"`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum Source {
    /// The citation refers to a document.
    Document {
        id: Option<String>,
        /// Raw document fields as a JSON object.
        document: Option<serde_json::Map<String, serde_json::Value>>,
    },
    /// The citation refers to the output of a tool.
    Tool {
        id: Option<String>,
        /// Raw tool output as a JSON object.
        tool_output: Option<serde_json::Map<String, serde_json::Value>>,
    },
}
289
/// Kind of citation. Serialized in `SCREAMING_SNAKE_CASE`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum CitationType {
    /// Wire value `"TEXT_CONTENT"`.
    TextContent,
    /// Wire value `"PLAN"`.
    Plan,
}
296
297impl TryFrom<message::Message> for Vec<Message> {
298 type Error = message::MessageError;
299
300 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
301 Ok(match message {
302 message::Message::User { content } => content
303 .into_iter()
304 .map(|content| match content {
305 message::UserContent::Text(message::Text { text }) => Ok(Message::User {
306 content: OneOrMany::one(UserContent::Text { text }),
307 }),
308 message::UserContent::ToolResult(message::ToolResult {
309 id, content, ..
310 }) => Ok(Message::Tool {
311 tool_call_id: id,
312 content: content.try_map(|content| match content {
313 message::ToolResultContent::Text(text) => {
314 Ok(ToolResultContent::Text { text: text.text })
315 }
316 _ => Err(message::MessageError::ConversionError(
317 "Only text tool result content is supported by Cohere".to_owned(),
318 )),
319 })?,
320 }),
321 _ => Err(message::MessageError::ConversionError(
322 "Only text content is supported by Cohere".to_owned(),
323 )),
324 })
325 .collect::<Result<Vec<_>, _>>()?,
326 message::Message::Assistant { content, .. } => {
327 let mut text_content = vec![];
328 let mut tool_calls = vec![];
329 content.into_iter().for_each(|content| match content {
330 message::AssistantContent::Text(message::Text { text }) => {
331 text_content.push(AssistantContent::Text { text });
332 }
333 message::AssistantContent::ToolCall(message::ToolCall {
334 id,
335 function:
336 message::ToolFunction {
337 name, arguments, ..
338 },
339 ..
340 }) => {
341 tool_calls.push(ToolCall {
342 id: Some(id),
343 r#type: Some(ToolType::Function),
344 function: Some(ToolCallFunction {
345 name,
346 arguments: serde_json::to_value(arguments).unwrap_or_default(),
347 }),
348 });
349 }
350 message::AssistantContent::Reasoning(_) => {
351 unimplemented!("Reasoning is not natively supported on Cohere V2");
352 }
353 });
354
355 vec![Message::Assistant {
356 content: text_content,
357 citations: vec![],
358 tool_calls,
359 tool_plan: None,
360 }]
361 }
362 })
363 }
364}
365
366impl TryFrom<Message> for message::Message {
367 type Error = message::MessageError;
368
369 fn try_from(message: Message) -> Result<Self, Self::Error> {
370 match message {
371 Message::User { content } => Ok(message::Message::User {
372 content: content.map(|content| match content {
373 UserContent::Text { text } => {
374 message::UserContent::Text(message::Text { text })
375 }
376 UserContent::ImageUrl { image_url } => message::UserContent::image(
377 image_url.url,
378 Some(message::ContentFormat::String),
379 None,
380 None,
381 ),
382 }),
383 }),
384 Message::Assistant {
385 content,
386 tool_calls,
387 ..
388 } => {
389 let mut content = content
390 .into_iter()
391 .map(|content| match content {
392 AssistantContent::Text { text } => message::AssistantContent::text(text),
393 })
394 .collect::<Vec<_>>();
395
396 content.extend(tool_calls.into_iter().filter_map(|tool_call| {
397 let ToolCallFunction { name, arguments } = tool_call.function?;
398
399 Some(message::AssistantContent::tool_call(
400 tool_call.id.unwrap_or_else(|| name.clone()),
401 name,
402 arguments,
403 ))
404 }));
405
406 let content = OneOrMany::many(content).map_err(|_| {
407 message::MessageError::ConversionError(
408 "Expected either text content or tool calls".to_string(),
409 )
410 })?;
411
412 Ok(message::Message::Assistant { id: None, content })
413 }
414 Message::Tool {
415 content,
416 tool_call_id,
417 } => {
418 let content = content.try_map(|content| {
419 Ok(match content {
420 ToolResultContent::Text { text } => message::ToolResultContent::text(text),
421 ToolResultContent::Document { document } => {
422 message::ToolResultContent::text(
423 serde_json::to_string(&document.data).map_err(|e| {
424 message::MessageError::ConversionError(
425 format!("Failed to convert tool result document content into text: {e}"),
426 )
427 })?,
428 )
429 }
430 })
431 })?;
432
433 Ok(message::Message::User {
434 content: OneOrMany::one(message::UserContent::tool_result(
435 tool_call_id,
436 content,
437 )),
438 })
439 }
440 Message::System { content } => Ok(message::Message::user(content)),
441 }
442 }
443}
444
/// Cohere completion model: a client handle paired with a model name.
#[derive(Clone)]
pub struct CompletionModel {
    /// Client used to send requests to the provider.
    pub(crate) client: Client,
    /// Model name sent in the `model` field of each request.
    pub model: String,
}
450
451impl CompletionModel {
452 pub fn new(client: Client, model: &str) -> Self {
453 Self {
454 client,
455 model: model.to_string(),
456 }
457 }
458
459 pub(crate) fn create_completion_request(
460 &self,
461 completion_request: CompletionRequest,
462 ) -> Result<Value, CompletionError> {
463 let mut partial_history = vec![];
465 if let Some(docs) = completion_request.normalized_documents() {
466 partial_history.push(docs);
467 }
468 partial_history.extend(completion_request.chat_history);
469
470 let mut full_history: Vec<Message> = completion_request
472 .preamble
473 .map_or_else(Vec::new, |preamble| {
474 vec![Message::System { content: preamble }]
475 });
476
477 full_history.extend(
479 partial_history
480 .into_iter()
481 .map(message::Message::try_into)
482 .collect::<Result<Vec<Vec<Message>>, _>>()?
483 .into_iter()
484 .flatten()
485 .collect::<Vec<_>>(),
486 );
487
488 let request = json!({
489 "model": self.model,
490 "messages": full_history,
491 "documents": completion_request.documents,
492 "temperature": completion_request.temperature,
493 "tools": completion_request.tools.into_iter().map(Tool::from).collect::<Vec<_>>(),
494 });
495
496 if let Some(ref params) = completion_request.additional_params {
497 Ok(json_utils::merge(request.clone(), params.clone()))
498 } else {
499 Ok(request)
500 }
501 }
502}
503
504impl completion::CompletionModel for CompletionModel {
505 type Response = CompletionResponse;
506 type StreamingResponse = StreamingCompletionResponse;
507
508 #[cfg_attr(feature = "worker", worker::send)]
509 async fn completion(
510 &self,
511 completion_request: completion::CompletionRequest,
512 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
513 let request = self.create_completion_request(completion_request)?;
514 tracing::debug!(
515 "Cohere request: {}",
516 serde_json::to_string_pretty(&request)?
517 );
518
519 let response = self.client.post("/v2/chat").json(&request).send().await?;
520
521 if response.status().is_success() {
522 let text_response = response.text().await?;
523 tracing::debug!("Cohere response text: {}", text_response);
524
525 let json_response: CompletionResponse = serde_json::from_str(&text_response)?;
526 let completion: completion::CompletionResponse<CompletionResponse> =
527 json_response.try_into()?;
528 Ok(completion)
529 } else {
530 Err(CompletionError::ProviderError(response.text().await?))
531 }
532 }
533
534 #[cfg_attr(feature = "worker", worker::send)]
535 async fn stream(
536 &self,
537 request: CompletionRequest,
538 ) -> Result<
539 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
540 CompletionError,
541 > {
542 CompletionModel::stream(self, request).await
543 }
544}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_path_to_error::deserialize;

    /// A tool-call response (no `content`, no `citations`) deserializes, and the
    /// stringified tool arguments are parsed into a JSON value.
    #[test]
    fn test_deserialize_completion_response() {
        let json_data = r#"
        {
            "id": "abc123",
            "message": {
                "role": "assistant",
                "tool_plan": "I will use the subtract tool to find the difference between 2 and 5.",
                "tool_calls": [
                    {
                        "id": "subtract_sm6ps6fb6y9f",
                        "type": "function",
                        "function": {
                            "name": "subtract",
                            "arguments": "{\"x\":5,\"y\":2}"
                        }
                    }
                ]
            },
            "finish_reason": "TOOL_CALL",
            "usage": {
                "billed_units": {
                    "input_tokens": 78,
                    "output_tokens": 27
                },
                "tokens": {
                    "input_tokens": 1028,
                    "output_tokens": 63
                }
            }
        }
        "#;

        let mut deserializer = serde_json::Deserializer::from_str(json_data);
        let result: Result<CompletionResponse, _> = deserialize(&mut deserializer);

        let response = result.unwrap();
        let (_, citations, tool_calls) = response.message();
        let CompletionResponse {
            id,
            finish_reason,
            usage,
            ..
        } = response;

        assert_eq!(id, "abc123");
        assert_eq!(finish_reason, FinishReason::ToolCall);

        let Usage {
            billed_units,
            tokens,
        } = usage.unwrap();
        let BilledUnits {
            input_tokens: billed_input_tokens,
            output_tokens: billed_output_tokens,
            ..
        } = billed_units.unwrap();
        let Tokens {
            input_tokens,
            output_tokens,
        } = tokens.unwrap();

        assert_eq!(billed_input_tokens.unwrap(), 78.0);
        assert_eq!(billed_output_tokens.unwrap(), 27.0);
        assert_eq!(input_tokens.unwrap(), 1028.0);
        assert_eq!(output_tokens.unwrap(), 63.0);

        assert!(citations.is_empty());
        assert_eq!(tool_calls.len(), 1);

        let ToolCallFunction { name, arguments } = tool_calls[0].function.clone().unwrap();

        assert_eq!(name, "subtract");
        assert_eq!(arguments, serde_json::json!({"x": 5, "y": 2}));
    }

    /// Generic user text message -> Cohere messages -> generic messages.
    #[test]
    fn test_convert_completion_message_to_message_and_back() {
        let completion_message = completion::Message::User {
            content: OneOrMany::one(completion::message::UserContent::Text(
                completion::message::Text {
                    text: "Hello, world!".to_string(),
                },
            )),
        };

        let messages: Vec<Message> = completion_message.clone().try_into().unwrap();

        // A single text part becomes exactly one Cohere user message.
        assert_eq!(
            messages,
            vec![Message::User {
                content: OneOrMany::one(UserContent::Text {
                    text: "Hello, world!".to_string(),
                }),
            }]
        );

        let converted_back: Vec<completion::Message> = messages
            .into_iter()
            .map(|msg| msg.try_into().unwrap())
            .collect::<Vec<_>>();

        assert!(matches!(
            converted_back.as_slice(),
            [completion::Message::User { .. }]
        ));
    }

    /// Cohere user text message -> generic message -> Cohere messages.
    #[test]
    fn test_convert_message_to_completion_message_and_back() {
        let message = Message::User {
            content: OneOrMany::one(UserContent::Text {
                text: "Hello, world!".to_string(),
            }),
        };

        let completion_message: completion::Message = message.clone().try_into().unwrap();
        let converted_back: Vec<Message> = completion_message.try_into().unwrap();

        // The round trip must reproduce the original message unchanged.
        assert_eq!(converted_back, vec![message]);
    }
}