1use crate::{
2 completion::{self, CompletionError},
3 json_utils, message, OneOrMany,
4};
5use std::collections::HashMap;
6
7use super::client::Client;
8use crate::completion::CompletionRequest;
9use crate::providers::cohere::streaming::StreamingCompletionResponse;
10use serde::{Deserialize, Serialize};
11use serde_json::{json, Value};
12
13#[derive(Debug, Deserialize)]
14pub struct CompletionResponse {
15 pub id: String,
16 pub finish_reason: FinishReason,
17 message: Message,
18 #[serde(default)]
19 pub usage: Option<Usage>,
20}
21
22impl CompletionResponse {
23 pub fn message(&self) -> (Vec<AssistantContent>, Vec<Citation>, Vec<ToolCall>) {
25 let Message::Assistant {
26 content,
27 citations,
28 tool_calls,
29 ..
30 } = self.message.clone()
31 else {
32 unreachable!("Completion responses will only return an assistant message")
33 };
34
35 (content, citations, tool_calls)
36 }
37}
38
39#[derive(Debug, Deserialize, PartialEq, Eq, Clone)]
40#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
41pub enum FinishReason {
42 MaxTokens,
43 StopSequence,
44 Complete,
45 Error,
46 ToolCall,
47}
48
49#[derive(Debug, Deserialize, Clone)]
50pub struct Usage {
51 #[serde(default)]
52 pub billed_units: Option<BilledUnits>,
53 #[serde(default)]
54 pub tokens: Option<Tokens>,
55}
56
57#[derive(Debug, Deserialize, Clone)]
58pub struct BilledUnits {
59 #[serde(default)]
60 pub output_tokens: Option<f64>,
61 #[serde(default)]
62 pub classifications: Option<f64>,
63 #[serde(default)]
64 pub search_units: Option<f64>,
65 #[serde(default)]
66 pub input_tokens: Option<f64>,
67}
68
69#[derive(Debug, Deserialize, Clone)]
70pub struct Tokens {
71 #[serde(default)]
72 pub input_tokens: Option<f64>,
73 #[serde(default)]
74 pub output_tokens: Option<f64>,
75}
76
77impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
78 type Error = CompletionError;
79
80 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
81 let (content, _, tool_calls) = response.message();
82
83 let model_response = if !tool_calls.is_empty() {
84 OneOrMany::many(
85 tool_calls
86 .into_iter()
87 .filter_map(|tool_call| {
88 let ToolCallFunction { name, arguments } = tool_call.function?;
89 let id = tool_call.id.unwrap_or_else(|| name.clone());
90
91 Some(completion::AssistantContent::tool_call(id, name, arguments))
92 })
93 .collect::<Vec<_>>(),
94 )
95 .expect("We have atleast 1 tool call in this if block")
96 } else {
97 OneOrMany::many(content.into_iter().map(|content| match content {
98 AssistantContent::Text { text } => completion::AssistantContent::text(text),
99 }))
100 .map_err(|_| {
101 CompletionError::ResponseError(
102 "Response contained no message or tool call (empty)".to_owned(),
103 )
104 })?
105 };
106
107 Ok(completion::CompletionResponse {
108 choice: OneOrMany::many(model_response).expect("There is atleast one content"),
109 raw_response: response,
110 })
111 }
112}
113
114#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
115pub struct Document {
116 pub id: String,
117 pub data: HashMap<String, serde_json::Value>,
118}
119
120impl From<completion::Document> for Document {
121 fn from(document: completion::Document) -> Self {
122 let mut data: HashMap<String, serde_json::Value> = HashMap::new();
123
124 document
127 .additional_props
128 .into_iter()
129 .for_each(|(key, value)| {
130 data.insert(key, value.into());
131 });
132
133 data.insert("text".to_string(), document.text.into());
134
135 Self {
136 id: document.id,
137 data,
138 }
139 }
140}
141
142#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
143pub struct ToolCall {
144 #[serde(default)]
145 pub id: Option<String>,
146 #[serde(default)]
147 pub r#type: Option<ToolType>,
148 #[serde(default)]
149 pub function: Option<ToolCallFunction>,
150}
151
152#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
153pub struct ToolCallFunction {
154 pub name: String,
155 #[serde(with = "json_utils::stringified_json")]
156 pub arguments: serde_json::Value,
157}
158
159#[derive(Clone, Default, Debug, Deserialize, Serialize, PartialEq, Eq)]
160#[serde(rename_all = "lowercase")]
161pub enum ToolType {
162 #[default]
163 Function,
164}
165
166#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
167pub struct Tool {
168 pub r#type: ToolType,
169 pub function: Function,
170}
171
172#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
173pub struct Function {
174 pub name: String,
175 #[serde(default)]
176 pub description: Option<String>,
177 pub parameters: serde_json::Value,
178}
179
180impl From<completion::ToolDefinition> for Tool {
181 fn from(tool: completion::ToolDefinition) -> Self {
182 Self {
183 r#type: ToolType::default(),
184 function: Function {
185 name: tool.name,
186 description: Some(tool.description),
187 parameters: tool.parameters,
188 },
189 }
190 }
191}
192
193#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
194#[serde(tag = "role", rename_all = "lowercase")]
195pub enum Message {
196 User {
197 content: OneOrMany<UserContent>,
198 },
199
200 Assistant {
201 #[serde(default)]
202 content: Vec<AssistantContent>,
203 #[serde(default)]
204 citations: Vec<Citation>,
205 #[serde(default)]
206 tool_calls: Vec<ToolCall>,
207 #[serde(default)]
208 tool_plan: Option<String>,
209 },
210
211 Tool {
212 content: OneOrMany<ToolResultContent>,
213 tool_call_id: String,
214 },
215
216 System {
217 content: String,
218 },
219}
220
221#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
222#[serde(tag = "type", rename_all = "lowercase")]
223pub enum UserContent {
224 Text { text: String },
225 ImageUrl { image_url: ImageUrl },
226}
227
228#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
229#[serde(tag = "type", rename_all = "lowercase")]
230pub enum AssistantContent {
231 Text { text: String },
232}
233
234#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
235pub struct ImageUrl {
236 pub url: String,
237}
238
239#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
240pub enum ToolResultContent {
241 Text { text: String },
242 Document { document: Document },
243}
244
245#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
246pub struct Citation {
247 #[serde(default)]
248 pub start: Option<u32>,
249 #[serde(default)]
250 pub end: Option<u32>,
251 #[serde(default)]
252 pub text: Option<String>,
253 #[serde(rename = "type")]
254 pub citation_type: Option<CitationType>,
255 #[serde(default)]
256 pub sources: Vec<Source>,
257}
258
259#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
260#[serde(tag = "type", rename_all = "lowercase")]
261pub enum Source {
262 Document {
263 id: Option<String>,
264 document: Option<serde_json::Map<String, serde_json::Value>>,
265 },
266 Tool {
267 id: Option<String>,
268 tool_output: Option<serde_json::Map<String, serde_json::Value>>,
269 },
270}
271
272#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
273#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
274pub enum CitationType {
275 TextContent,
276 Plan,
277}
278
279impl TryFrom<message::Message> for Vec<Message> {
280 type Error = message::MessageError;
281
282 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
283 Ok(match message {
284 message::Message::User { content } => content
285 .into_iter()
286 .map(|content| match content {
287 message::UserContent::Text(message::Text { text }) => Ok(Message::User {
288 content: OneOrMany::one(UserContent::Text { text }),
289 }),
290 message::UserContent::ToolResult(message::ToolResult { id, content }) => {
291 Ok(Message::Tool {
292 tool_call_id: id,
293 content: content.try_map(|content| match content {
294 message::ToolResultContent::Text(text) => {
295 Ok(ToolResultContent::Text { text: text.text })
296 }
297 _ => Err(message::MessageError::ConversionError(
298 "Only text tool result content is supported by Cohere"
299 .to_owned(),
300 )),
301 })?,
302 })
303 }
304 _ => Err(message::MessageError::ConversionError(
305 "Only text content is supported by Cohere".to_owned(),
306 )),
307 })
308 .collect::<Result<Vec<_>, _>>()?,
309 message::Message::Assistant { content } => {
310 let mut text_content = vec![];
311 let mut tool_calls = vec![];
312 content.into_iter().for_each(|content| match content {
313 message::AssistantContent::Text(message::Text { text }) => {
314 text_content.push(AssistantContent::Text { text });
315 }
316 message::AssistantContent::ToolCall(message::ToolCall {
317 id,
318 function: message::ToolFunction { name, arguments },
319 }) => {
320 tool_calls.push(ToolCall {
321 id: Some(id),
322 r#type: Some(ToolType::Function),
323 function: Some(ToolCallFunction {
324 name,
325 arguments: serde_json::to_value(arguments).unwrap_or_default(),
326 }),
327 });
328 }
329 });
330
331 vec![Message::Assistant {
332 content: text_content,
333 citations: vec![],
334 tool_calls,
335 tool_plan: None,
336 }]
337 }
338 })
339 }
340}
341
342impl TryFrom<Message> for message::Message {
343 type Error = message::MessageError;
344
345 fn try_from(message: Message) -> Result<Self, Self::Error> {
346 match message {
347 Message::User { content } => Ok(message::Message::User {
348 content: content.map(|content| match content {
349 UserContent::Text { text } => {
350 message::UserContent::Text(message::Text { text })
351 }
352 UserContent::ImageUrl { image_url } => message::UserContent::image(
353 image_url.url,
354 Some(message::ContentFormat::String),
355 None,
356 None,
357 ),
358 }),
359 }),
360 Message::Assistant {
361 content,
362 tool_calls,
363 ..
364 } => {
365 let mut content = content
366 .into_iter()
367 .map(|content| match content {
368 AssistantContent::Text { text } => message::AssistantContent::text(text),
369 })
370 .collect::<Vec<_>>();
371
372 content.extend(tool_calls.into_iter().filter_map(|tool_call| {
373 let ToolCallFunction { name, arguments } = tool_call.function?;
374
375 Some(message::AssistantContent::tool_call(
376 tool_call.id.unwrap_or_else(|| name.clone()),
377 name,
378 arguments,
379 ))
380 }));
381
382 let content = OneOrMany::many(content).map_err(|_| {
383 message::MessageError::ConversionError(
384 "Expected either text content or tool calls".to_string(),
385 )
386 })?;
387
388 Ok(message::Message::Assistant { content })
389 }
390 Message::Tool {
391 content,
392 tool_call_id,
393 } => {
394 let content = content.try_map(|content| {
395 Ok(match content {
396 ToolResultContent::Text { text } => message::ToolResultContent::text(text),
397 ToolResultContent::Document { document } => {
398 message::ToolResultContent::text(
399 serde_json::to_string(&document.data).map_err(|e| {
400 message::MessageError::ConversionError(
401 format!("Failed to convert tool result document content into text: {e}"),
402 )
403 })?,
404 )
405 }
406 })
407 })?;
408
409 Ok(message::Message::User {
410 content: OneOrMany::one(message::UserContent::tool_result(
411 tool_call_id,
412 content,
413 )),
414 })
415 }
416 Message::System { content } => Ok(message::Message::user(content)),
417 }
418 }
419}
420
421#[derive(Clone)]
422pub struct CompletionModel {
423 pub(crate) client: Client,
424 pub model: String,
425}
426
427impl CompletionModel {
428 pub fn new(client: Client, model: &str) -> Self {
429 Self {
430 client,
431 model: model.to_string(),
432 }
433 }
434
435 pub(crate) fn create_completion_request(
436 &self,
437 completion_request: CompletionRequest,
438 ) -> Result<Value, CompletionError> {
439 let mut partial_history = vec![];
441 if let Some(docs) = completion_request.normalized_documents() {
442 partial_history.push(docs);
443 }
444 partial_history.extend(completion_request.chat_history);
445
446 let mut full_history: Vec<Message> = completion_request
448 .preamble
449 .map_or_else(Vec::new, |preamble| {
450 vec![Message::System { content: preamble }]
451 });
452
453 full_history.extend(
455 partial_history
456 .into_iter()
457 .map(message::Message::try_into)
458 .collect::<Result<Vec<Vec<Message>>, _>>()?
459 .into_iter()
460 .flatten()
461 .collect::<Vec<_>>(),
462 );
463
464 let request = json!({
465 "model": self.model,
466 "messages": full_history,
467 "documents": completion_request.documents,
468 "temperature": completion_request.temperature,
469 "tools": completion_request.tools.into_iter().map(Tool::from).collect::<Vec<_>>(),
470 });
471
472 if let Some(ref params) = completion_request.additional_params {
473 Ok(json_utils::merge(request.clone(), params.clone()))
474 } else {
475 Ok(request)
476 }
477 }
478}
479
480impl completion::CompletionModel for CompletionModel {
481 type Response = CompletionResponse;
482 type StreamingResponse = StreamingCompletionResponse;
483
484 #[cfg_attr(feature = "worker", worker::send)]
485 async fn completion(
486 &self,
487 completion_request: completion::CompletionRequest,
488 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
489 let request = self.create_completion_request(completion_request)?;
490 tracing::debug!(
491 "Cohere request: {}",
492 serde_json::to_string_pretty(&request)?
493 );
494
495 let response = self.client.post("/v2/chat").json(&request).send().await?;
496
497 if response.status().is_success() {
498 let text_response = response.text().await?;
499 tracing::debug!("Cohere response text: {}", text_response);
500
501 let json_response: CompletionResponse = serde_json::from_str(&text_response)?;
502 let completion: completion::CompletionResponse<CompletionResponse> =
503 json_response.try_into()?;
504 Ok(completion)
505 } else {
506 Err(CompletionError::ProviderError(response.text().await?))
507 }
508 }
509
510 #[cfg_attr(feature = "worker", worker::send)]
511 async fn stream(
512 &self,
513 request: CompletionRequest,
514 ) -> Result<
515 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
516 CompletionError,
517 > {
518 CompletionModel::stream(self, request).await
519 }
520}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_path_to_error::deserialize;

    /// Parses a tool-call response and checks id, finish reason, usage and tool arguments.
    #[test]
    fn test_deserialize_completion_response() {
        let json_data = r#"
        {
            "id": "abc123",
            "message": {
                "role": "assistant",
                "tool_plan": "I will use the subtract tool to find the difference between 2 and 5.",
                "tool_calls": [
                    {
                        "id": "subtract_sm6ps6fb6y9f",
                        "type": "function",
                        "function": {
                            "name": "subtract",
                            "arguments": "{\"x\":5,\"y\":2}"
                        }
                    }
                ]
            },
            "finish_reason": "TOOL_CALL",
            "usage": {
                "billed_units": {
                    "input_tokens": 78,
                    "output_tokens": 27
                },
                "tokens": {
                    "input_tokens": 1028,
                    "output_tokens": 63
                }
            }
        }
        "#;

        let mut deserializer = serde_json::Deserializer::from_str(json_data);
        let result: Result<CompletionResponse, _> = deserialize(&mut deserializer);

        let response = result.unwrap();
        let (_, citations, tool_calls) = response.message();
        let CompletionResponse {
            id,
            finish_reason,
            usage,
            ..
        } = response;

        assert_eq!(id, "abc123");
        assert_eq!(finish_reason, FinishReason::ToolCall);

        let Usage {
            billed_units,
            tokens,
        } = usage.unwrap();
        let BilledUnits {
            input_tokens: billed_input_tokens,
            output_tokens: billed_output_tokens,
            ..
        } = billed_units.unwrap();
        let Tokens {
            input_tokens,
            output_tokens,
        } = tokens.unwrap();

        assert_eq!(billed_input_tokens.unwrap(), 78.0);
        assert_eq!(billed_output_tokens.unwrap(), 27.0);
        assert_eq!(input_tokens.unwrap(), 1028.0);
        assert_eq!(output_tokens.unwrap(), 63.0);

        assert!(citations.is_empty());
        assert_eq!(tool_calls.len(), 1);

        let ToolCallFunction { name, arguments } = tool_calls[0].function.clone().unwrap();

        assert_eq!(name, "subtract");
        assert_eq!(arguments, serde_json::json!({"x": 5, "y": 2}));
    }

    /// rig user text -> Cohere `user` message -> rig user message.
    #[test]
    fn test_convert_completion_message_to_message_and_back() {
        let completion_message = completion::Message::User {
            content: OneOrMany::one(completion::message::UserContent::Text(
                completion::message::Text {
                    text: "Hello, world!".to_string(),
                },
            )),
        };

        let messages: Vec<Message> = completion_message.try_into().unwrap();
        assert_eq!(
            messages,
            vec![Message::User {
                content: OneOrMany::one(UserContent::Text {
                    text: "Hello, world!".to_string(),
                }),
            }]
        );

        let converted_back: Vec<completion::Message> = messages
            .into_iter()
            .map(|msg| msg.try_into().unwrap())
            .collect::<Vec<_>>();
        assert_eq!(converted_back.len(), 1);
        assert!(matches!(
            converted_back[0],
            completion::Message::User { .. }
        ));
    }

    /// Cohere `user` message -> rig message -> Cohere messages must round-trip exactly.
    #[test]
    fn test_convert_message_to_completion_message_and_back() {
        let message = Message::User {
            content: OneOrMany::one(UserContent::Text {
                text: "Hello, world!".to_string(),
            }),
        };

        let completion_message: completion::Message = message.clone().try_into().unwrap();
        let converted_back: Vec<Message> = completion_message.try_into().unwrap();
        assert_eq!(converted_back, vec![message]);
    }
}