1use crate::{
2 OneOrMany,
3 completion::{self, CompletionError},
4 json_utils, message,
5};
6use std::collections::HashMap;
7
8use super::client::Client;
9use crate::completion::CompletionRequest;
10use crate::providers::cohere::streaming::StreamingCompletionResponse;
11use serde::{Deserialize, Serialize};
12use serde_json::{Value, json};
13
/// Raw (non-streaming) response body returned by Cohere's `/v2/chat` endpoint.
///
/// The assistant message is private; read its content, citations and tool calls
/// through [`CompletionResponse::message`].
#[derive(Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    /// Response identifier as returned by Cohere.
    pub id: String,
    /// Why generation stopped (e.g. `TOOL_CALL` on the wire).
    pub finish_reason: FinishReason,
    message: Message,
    /// Token accounting; `None` when the field is absent from the payload.
    #[serde(default)]
    pub usage: Option<Usage>,
}
22
23impl CompletionResponse {
24 pub fn message(&self) -> (Vec<AssistantContent>, Vec<Citation>, Vec<ToolCall>) {
26 let Message::Assistant {
27 content,
28 citations,
29 tool_calls,
30 ..
31 } = self.message.clone()
32 else {
33 unreachable!("Completion responses will only return an assistant message")
34 };
35
36 (content, citations, tool_calls)
37 }
38}
39
/// Reason a Cohere generation finished.
///
/// Variants are (de)serialized in SCREAMING_SNAKE_CASE, e.g. `ToolCall` <-> `"TOOL_CALL"`.
#[derive(Debug, Deserialize, PartialEq, Eq, Clone, Serialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum FinishReason {
    MaxTokens,
    StopSequence,
    Complete,
    Error,
    ToolCall,
}
49
/// Usage section of a Cohere chat response.
///
/// Both parts are optional; a missing field deserializes as `None`.
#[derive(Debug, Deserialize, Clone, Serialize)]
pub struct Usage {
    /// Units Cohere reports under `billed_units`.
    #[serde(default)]
    pub billed_units: Option<BilledUnits>,
    /// Token counts under `tokens`; this section (not `billed_units`) is what the
    /// conversion into `completion::Usage` reads.
    #[serde(default)]
    pub tokens: Option<Tokens>,
}

/// Billed-unit counters. Each counter is optional and held as `f64`.
#[derive(Debug, Deserialize, Clone, Serialize)]
pub struct BilledUnits {
    #[serde(default)]
    pub output_tokens: Option<f64>,
    #[serde(default)]
    pub classifications: Option<f64>,
    #[serde(default)]
    pub search_units: Option<f64>,
    #[serde(default)]
    pub input_tokens: Option<f64>,
}

/// Input/output token counts. Each count is optional and held as `f64`.
#[derive(Debug, Deserialize, Clone, Serialize)]
pub struct Tokens {
    #[serde(default)]
    pub input_tokens: Option<f64>,
    #[serde(default)]
    pub output_tokens: Option<f64>,
}
77
78impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
79 type Error = CompletionError;
80
81 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
82 let (content, _, tool_calls) = response.message();
83
84 let model_response = if !tool_calls.is_empty() {
85 OneOrMany::many(
86 tool_calls
87 .into_iter()
88 .filter_map(|tool_call| {
89 let ToolCallFunction { name, arguments } = tool_call.function?;
90 let id = tool_call.id.unwrap_or_else(|| name.clone());
91
92 Some(completion::AssistantContent::tool_call(id, name, arguments))
93 })
94 .collect::<Vec<_>>(),
95 )
96 .expect("We have atleast 1 tool call in this if block")
97 } else {
98 OneOrMany::many(content.into_iter().map(|content| match content {
99 AssistantContent::Text { text } => completion::AssistantContent::text(text),
100 }))
101 .map_err(|_| {
102 CompletionError::ResponseError(
103 "Response contained no message or tool call (empty)".to_owned(),
104 )
105 })?
106 };
107
108 let usage = response
109 .usage
110 .as_ref()
111 .and_then(|usage| usage.tokens.as_ref())
112 .map(|tokens| {
113 let input_tokens = tokens.input_tokens.unwrap_or(0.0);
114 let output_tokens = tokens.output_tokens.unwrap_or(0.0);
115
116 completion::Usage {
117 input_tokens: input_tokens as u64,
118 output_tokens: output_tokens as u64,
119 total_tokens: (input_tokens + output_tokens) as u64,
120 }
121 })
122 .unwrap_or_default();
123
124 Ok(completion::CompletionResponse {
125 choice: OneOrMany::many(model_response).expect("There is atleast one content"),
126 usage,
127 raw_response: response,
128 })
129 }
130}
131
/// Document in the `{ id, data }` shape used by this provider (e.g. as tool-result
/// content). `data` is an arbitrary string-keyed JSON map.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
pub struct Document {
    pub id: String,
    pub data: HashMap<String, serde_json::Value>,
}
137
138impl From<completion::Document> for Document {
139 fn from(document: completion::Document) -> Self {
140 let mut data: HashMap<String, serde_json::Value> = HashMap::new();
141
142 document
145 .additional_props
146 .into_iter()
147 .for_each(|(key, value)| {
148 data.insert(key, value.into());
149 });
150
151 data.insert("text".to_string(), document.text.into());
152
153 Self {
154 id: document.id,
155 data,
156 }
157 }
158}
159
/// Tool call emitted by the assistant. Every field is optional on the wire.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
pub struct ToolCall {
    #[serde(default)]
    pub id: Option<String>,
    #[serde(default)]
    pub r#type: Option<ToolType>,
    #[serde(default)]
    pub function: Option<ToolCallFunction>,
}

/// Function name plus arguments of a tool call.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
pub struct ToolCallFunction {
    pub name: String,
    /// Carried on the wire as a JSON-encoded string (e.g. `"{\"x\":5}"`) and exposed
    /// here as a parsed `serde_json::Value`.
    #[serde(with = "json_utils::stringified_json")]
    pub arguments: serde_json::Value,
}

/// Kind of tool; only `"function"` exists, and it is the default.
#[derive(Clone, Default, Debug, Deserialize, Serialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}
183
/// Tool definition sent in the request's `tools` array.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
pub struct Tool {
    pub r#type: ToolType,
    pub function: Function,
}

/// Function signature of a [`Tool`]; `parameters` is passed through as raw JSON.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
pub struct Function {
    pub name: String,
    #[serde(default)]
    pub description: Option<String>,
    pub parameters: serde_json::Value,
}
197
198impl From<completion::ToolDefinition> for Tool {
199 fn from(tool: completion::ToolDefinition) -> Self {
200 Self {
201 r#type: ToolType::default(),
202 function: Function {
203 name: tool.name,
204 description: Some(tool.description),
205 parameters: tool.parameters,
206 },
207 }
208 }
209}
210
/// Cohere v2 chat message, internally tagged by `role`
/// (`"user"`, `"assistant"`, `"tool"`, `"system"`).
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    /// Message authored by the user.
    User {
        content: OneOrMany<UserContent>,
    },

    /// Message authored by the model. Every field defaults to empty/`None` when absent.
    Assistant {
        #[serde(default)]
        content: Vec<AssistantContent>,
        #[serde(default)]
        citations: Vec<Citation>,
        #[serde(default)]
        tool_calls: Vec<ToolCall>,
        /// Optional plan text that may accompany tool calls.
        #[serde(default)]
        tool_plan: Option<String>,
    },

    /// Result of a tool invocation, linked to the call via `tool_call_id`.
    Tool {
        content: OneOrMany<ToolResultContent>,
        tool_call_id: String,
    },

    /// System instruction (this module builds one from the request preamble).
    System {
        content: String,
    },
}
238
239#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
240#[serde(tag = "type", rename_all = "lowercase")]
241pub enum UserContent {
242 Text { text: String },
243 ImageUrl { image_url: ImageUrl },
244}
245
/// Content item of a Cohere v2 `assistant` message; only text is modelled.
/// Internally tagged by `type` (`{"type": "text", "text": ...}`).
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum AssistantContent {
    Text { text: String },
}

/// Image reference used by [`UserContent::ImageUrl`].
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
pub struct ImageUrl {
    pub url: String,
}
256
257#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
258pub enum ToolResultContent {
259 Text { text: String },
260 Document { document: Document },
261}
262
/// Citation attached to an assistant message. All fields are optional on the wire.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct Citation {
    /// Presumably the start offset of the cited span in the response text — TODO confirm units.
    #[serde(default)]
    pub start: Option<u32>,
    /// Presumably the end offset of the cited span — TODO confirm units/inclusivity.
    #[serde(default)]
    pub end: Option<u32>,
    /// The cited text, if provided.
    #[serde(default)]
    pub text: Option<String>,
    /// Read from / written to the JSON key `"type"`.
    #[serde(rename = "type")]
    pub citation_type: Option<CitationType>,
    /// Sources backing the citation; empty when absent.
    #[serde(default)]
    pub sources: Vec<Source>,
}

/// Origin of a citation, internally tagged by `type` (`"document"` or `"tool"`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum Source {
    Document {
        id: Option<String>,
        document: Option<serde_json::Map<String, serde_json::Value>>,
    },
    Tool {
        id: Option<String>,
        tool_output: Option<serde_json::Map<String, serde_json::Value>>,
    },
}

/// Kind of citation; on the wire `"TEXT_CONTENT"` or `"PLAN"`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum CitationType {
    TextContent,
    Plan,
}
296
297impl TryFrom<message::Message> for Vec<Message> {
298 type Error = message::MessageError;
299
300 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
301 Ok(match message {
302 message::Message::User { content } => content
303 .into_iter()
304 .map(|content| match content {
305 message::UserContent::Text(message::Text { text }) => Ok(Message::User {
306 content: OneOrMany::one(UserContent::Text { text }),
307 }),
308 message::UserContent::ToolResult(message::ToolResult {
309 id, content, ..
310 }) => Ok(Message::Tool {
311 tool_call_id: id,
312 content: content.try_map(|content| match content {
313 message::ToolResultContent::Text(text) => {
314 Ok(ToolResultContent::Text { text: text.text })
315 }
316 _ => Err(message::MessageError::ConversionError(
317 "Only text tool result content is supported by Cohere".to_owned(),
318 )),
319 })?,
320 }),
321 _ => Err(message::MessageError::ConversionError(
322 "Only text content is supported by Cohere".to_owned(),
323 )),
324 })
325 .collect::<Result<Vec<_>, _>>()?,
326 message::Message::Assistant { content, .. } => {
327 let mut text_content = vec![];
328 let mut tool_calls = vec![];
329 content.into_iter().for_each(|content| match content {
330 message::AssistantContent::Text(message::Text { text }) => {
331 text_content.push(AssistantContent::Text { text });
332 }
333 message::AssistantContent::ToolCall(message::ToolCall {
334 id,
335 function:
336 message::ToolFunction {
337 name, arguments, ..
338 },
339 ..
340 }) => {
341 tool_calls.push(ToolCall {
342 id: Some(id),
343 r#type: Some(ToolType::Function),
344 function: Some(ToolCallFunction {
345 name,
346 arguments: serde_json::to_value(arguments).unwrap_or_default(),
347 }),
348 });
349 }
350 message::AssistantContent::Reasoning(_) => {
351 unimplemented!("Reasoning is not natively supported on Cohere V2");
352 }
353 });
354
355 vec![Message::Assistant {
356 content: text_content,
357 citations: vec![],
358 tool_calls,
359 tool_plan: None,
360 }]
361 }
362 })
363 }
364}
365
366impl TryFrom<Message> for message::Message {
367 type Error = message::MessageError;
368
369 fn try_from(message: Message) -> Result<Self, Self::Error> {
370 match message {
371 Message::User { content } => Ok(message::Message::User {
372 content: content.map(|content| match content {
373 UserContent::Text { text } => {
374 message::UserContent::Text(message::Text { text })
375 }
376 UserContent::ImageUrl { image_url } => {
377 message::UserContent::image_url(image_url.url, None, None)
378 }
379 }),
380 }),
381 Message::Assistant {
382 content,
383 tool_calls,
384 ..
385 } => {
386 let mut content = content
387 .into_iter()
388 .map(|content| match content {
389 AssistantContent::Text { text } => message::AssistantContent::text(text),
390 })
391 .collect::<Vec<_>>();
392
393 content.extend(tool_calls.into_iter().filter_map(|tool_call| {
394 let ToolCallFunction { name, arguments } = tool_call.function?;
395
396 Some(message::AssistantContent::tool_call(
397 tool_call.id.unwrap_or_else(|| name.clone()),
398 name,
399 arguments,
400 ))
401 }));
402
403 let content = OneOrMany::many(content).map_err(|_| {
404 message::MessageError::ConversionError(
405 "Expected either text content or tool calls".to_string(),
406 )
407 })?;
408
409 Ok(message::Message::Assistant { id: None, content })
410 }
411 Message::Tool {
412 content,
413 tool_call_id,
414 } => {
415 let content = content.try_map(|content| {
416 Ok(match content {
417 ToolResultContent::Text { text } => message::ToolResultContent::text(text),
418 ToolResultContent::Document { document } => {
419 message::ToolResultContent::text(
420 serde_json::to_string(&document.data).map_err(|e| {
421 message::MessageError::ConversionError(
422 format!("Failed to convert tool result document content into text: {e}"),
423 )
424 })?,
425 )
426 }
427 })
428 })?;
429
430 Ok(message::Message::User {
431 content: OneOrMany::one(message::UserContent::tool_result(
432 tool_call_id,
433 content,
434 )),
435 })
436 }
437 Message::System { content } => Ok(message::Message::user(content)),
438 }
439 }
440}
441
/// Cohere v2 chat completion model bound to a client.
#[derive(Clone)]
pub struct CompletionModel {
    /// HTTP client used to reach the Cohere API.
    pub(crate) client: Client,
    /// Model identifier sent as the request's `model` field.
    pub model: String,
}
447
448impl CompletionModel {
449 pub fn new(client: Client, model: &str) -> Self {
450 Self {
451 client,
452 model: model.to_string(),
453 }
454 }
455
456 pub(crate) fn create_completion_request(
457 &self,
458 completion_request: CompletionRequest,
459 ) -> Result<Value, CompletionError> {
460 let mut partial_history = vec![];
462 if let Some(docs) = completion_request.normalized_documents() {
463 partial_history.push(docs);
464 }
465 partial_history.extend(completion_request.chat_history);
466
467 let mut full_history: Vec<Message> = completion_request
469 .preamble
470 .map_or_else(Vec::new, |preamble| {
471 vec![Message::System { content: preamble }]
472 });
473
474 full_history.extend(
476 partial_history
477 .into_iter()
478 .map(message::Message::try_into)
479 .collect::<Result<Vec<Vec<Message>>, _>>()?
480 .into_iter()
481 .flatten()
482 .collect::<Vec<_>>(),
483 );
484
485 let request = json!({
486 "model": self.model,
487 "messages": full_history,
488 "documents": completion_request.documents,
489 "temperature": completion_request.temperature,
490 "tools": completion_request.tools.into_iter().map(Tool::from).collect::<Vec<_>>(),
491 });
492
493 if let Some(ref params) = completion_request.additional_params {
494 Ok(json_utils::merge(request.clone(), params.clone()))
495 } else {
496 Ok(request)
497 }
498 }
499}
500
501impl completion::CompletionModel for CompletionModel {
502 type Response = CompletionResponse;
503 type StreamingResponse = StreamingCompletionResponse;
504
505 #[cfg_attr(feature = "worker", worker::send)]
506 async fn completion(
507 &self,
508 completion_request: completion::CompletionRequest,
509 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
510 let request = self.create_completion_request(completion_request)?;
511 tracing::debug!(
512 "Cohere request: {}",
513 serde_json::to_string_pretty(&request)?
514 );
515
516 let response = self.client.post("/v2/chat").json(&request).send().await?;
517
518 if response.status().is_success() {
519 let text_response = response.text().await?;
520 tracing::debug!("Cohere response text: {}", text_response);
521
522 let json_response: CompletionResponse = serde_json::from_str(&text_response)?;
523 let completion: completion::CompletionResponse<CompletionResponse> =
524 json_response.try_into()?;
525 Ok(completion)
526 } else {
527 Err(CompletionError::ProviderError(response.text().await?))
528 }
529 }
530
531 #[cfg_attr(feature = "worker", worker::send)]
532 async fn stream(
533 &self,
534 request: CompletionRequest,
535 ) -> Result<
536 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
537 CompletionError,
538 > {
539 CompletionModel::stream(self, request).await
540 }
541}
542#[cfg(test)]
543mod tests {
544 use super::*;
545 use serde_path_to_error::deserialize;
546
547 #[test]
548 fn test_deserialize_completion_response() {
549 let json_data = r#"
550 {
551 "id": "abc123",
552 "message": {
553 "role": "assistant",
554 "tool_plan": "I will use the subtract tool to find the difference between 2 and 5.",
555 "tool_calls": [
556 {
557 "id": "subtract_sm6ps6fb6y9f",
558 "type": "function",
559 "function": {
560 "name": "subtract",
561 "arguments": "{\"x\":5,\"y\":2}"
562 }
563 }
564 ]
565 },
566 "finish_reason": "TOOL_CALL",
567 "usage": {
568 "billed_units": {
569 "input_tokens": 78,
570 "output_tokens": 27
571 },
572 "tokens": {
573 "input_tokens": 1028,
574 "output_tokens": 63
575 }
576 }
577 }
578 "#;
579
580 let mut deserializer = serde_json::Deserializer::from_str(json_data);
581 let result: Result<CompletionResponse, _> = deserialize(&mut deserializer);
582
583 let response = result.unwrap();
584 let (_, citations, tool_calls) = response.message();
585 let CompletionResponse {
586 id,
587 finish_reason,
588 usage,
589 ..
590 } = response;
591
592 assert_eq!(id, "abc123");
593 assert_eq!(finish_reason, FinishReason::ToolCall);
594
595 let Usage {
596 billed_units,
597 tokens,
598 } = usage.unwrap();
599 let BilledUnits {
600 input_tokens: billed_input_tokens,
601 output_tokens: billed_output_tokens,
602 ..
603 } = billed_units.unwrap();
604 let Tokens {
605 input_tokens,
606 output_tokens,
607 } = tokens.unwrap();
608
609 assert_eq!(billed_input_tokens.unwrap(), 78.0);
610 assert_eq!(billed_output_tokens.unwrap(), 27.0);
611 assert_eq!(input_tokens.unwrap(), 1028.0);
612 assert_eq!(output_tokens.unwrap(), 63.0);
613
614 assert!(citations.is_empty());
615 assert_eq!(tool_calls.len(), 1);
616
617 let ToolCallFunction { name, arguments } = tool_calls[0].function.clone().unwrap();
618
619 assert_eq!(name, "subtract");
620 assert_eq!(arguments, serde_json::json!({"x": 5, "y": 2}));
621 }
622
623 #[test]
624 fn test_convert_completion_message_to_message_and_back() {
625 let completion_message = completion::Message::User {
626 content: OneOrMany::one(completion::message::UserContent::Text(
627 completion::message::Text {
628 text: "Hello, world!".to_string(),
629 },
630 )),
631 };
632
633 let messages: Vec<Message> = completion_message.clone().try_into().unwrap();
634 let _converted_back: Vec<completion::Message> = messages
635 .into_iter()
636 .map(|msg| msg.try_into().unwrap())
637 .collect::<Vec<_>>();
638 }
639
640 #[test]
641 fn test_convert_message_to_completion_message_and_back() {
642 let message = Message::User {
643 content: OneOrMany::one(UserContent::Text {
644 text: "Hello, world!".to_string(),
645 }),
646 };
647
648 let completion_message: completion::Message = message.clone().try_into().unwrap();
649 let _converted_back: Vec<Message> = completion_message.try_into().unwrap();
650 }
651}