1use crate::{
2 OneOrMany,
3 completion::{self, CompletionError},
4 json_utils, message,
5};
6use std::collections::HashMap;
7
8use super::client::Client;
9use crate::completion::CompletionRequest;
10use crate::providers::cohere::streaming::StreamingCompletionResponse;
11use serde::{Deserialize, Serialize};
12use serde_json::{Value, json};
13
14#[derive(Debug, Deserialize)]
15pub struct CompletionResponse {
16 pub id: String,
17 pub finish_reason: FinishReason,
18 message: Message,
19 #[serde(default)]
20 pub usage: Option<Usage>,
21}
22
23impl CompletionResponse {
24 pub fn message(&self) -> (Vec<AssistantContent>, Vec<Citation>, Vec<ToolCall>) {
26 let Message::Assistant {
27 content,
28 citations,
29 tool_calls,
30 ..
31 } = self.message.clone()
32 else {
33 unreachable!("Completion responses will only return an assistant message")
34 };
35
36 (content, citations, tool_calls)
37 }
38}
39
40#[derive(Debug, Deserialize, PartialEq, Eq, Clone)]
41#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
42pub enum FinishReason {
43 MaxTokens,
44 StopSequence,
45 Complete,
46 Error,
47 ToolCall,
48}
49
50#[derive(Debug, Deserialize, Clone)]
51pub struct Usage {
52 #[serde(default)]
53 pub billed_units: Option<BilledUnits>,
54 #[serde(default)]
55 pub tokens: Option<Tokens>,
56}
57
58#[derive(Debug, Deserialize, Clone)]
59pub struct BilledUnits {
60 #[serde(default)]
61 pub output_tokens: Option<f64>,
62 #[serde(default)]
63 pub classifications: Option<f64>,
64 #[serde(default)]
65 pub search_units: Option<f64>,
66 #[serde(default)]
67 pub input_tokens: Option<f64>,
68}
69
70#[derive(Debug, Deserialize, Clone)]
71pub struct Tokens {
72 #[serde(default)]
73 pub input_tokens: Option<f64>,
74 #[serde(default)]
75 pub output_tokens: Option<f64>,
76}
77
78impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
79 type Error = CompletionError;
80
81 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
82 let (content, _, tool_calls) = response.message();
83
84 let model_response = if !tool_calls.is_empty() {
85 OneOrMany::many(
86 tool_calls
87 .into_iter()
88 .filter_map(|tool_call| {
89 let ToolCallFunction { name, arguments } = tool_call.function?;
90 let id = tool_call.id.unwrap_or_else(|| name.clone());
91
92 Some(completion::AssistantContent::tool_call(id, name, arguments))
93 })
94 .collect::<Vec<_>>(),
95 )
96 .expect("We have atleast 1 tool call in this if block")
97 } else {
98 OneOrMany::many(content.into_iter().map(|content| match content {
99 AssistantContent::Text { text } => completion::AssistantContent::text(text),
100 }))
101 .map_err(|_| {
102 CompletionError::ResponseError(
103 "Response contained no message or tool call (empty)".to_owned(),
104 )
105 })?
106 };
107
108 Ok(completion::CompletionResponse {
109 choice: OneOrMany::many(model_response).expect("There is atleast one content"),
110 raw_response: response,
111 })
112 }
113}
114
115#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
116pub struct Document {
117 pub id: String,
118 pub data: HashMap<String, serde_json::Value>,
119}
120
121impl From<completion::Document> for Document {
122 fn from(document: completion::Document) -> Self {
123 let mut data: HashMap<String, serde_json::Value> = HashMap::new();
124
125 document
128 .additional_props
129 .into_iter()
130 .for_each(|(key, value)| {
131 data.insert(key, value.into());
132 });
133
134 data.insert("text".to_string(), document.text.into());
135
136 Self {
137 id: document.id,
138 data,
139 }
140 }
141}
142
143#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
144pub struct ToolCall {
145 #[serde(default)]
146 pub id: Option<String>,
147 #[serde(default)]
148 pub r#type: Option<ToolType>,
149 #[serde(default)]
150 pub function: Option<ToolCallFunction>,
151}
152
153#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
154pub struct ToolCallFunction {
155 pub name: String,
156 #[serde(with = "json_utils::stringified_json")]
157 pub arguments: serde_json::Value,
158}
159
160#[derive(Clone, Default, Debug, Deserialize, Serialize, PartialEq, Eq)]
161#[serde(rename_all = "lowercase")]
162pub enum ToolType {
163 #[default]
164 Function,
165}
166
167#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
168pub struct Tool {
169 pub r#type: ToolType,
170 pub function: Function,
171}
172
173#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
174pub struct Function {
175 pub name: String,
176 #[serde(default)]
177 pub description: Option<String>,
178 pub parameters: serde_json::Value,
179}
180
181impl From<completion::ToolDefinition> for Tool {
182 fn from(tool: completion::ToolDefinition) -> Self {
183 Self {
184 r#type: ToolType::default(),
185 function: Function {
186 name: tool.name,
187 description: Some(tool.description),
188 parameters: tool.parameters,
189 },
190 }
191 }
192}
193
194#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
195#[serde(tag = "role", rename_all = "lowercase")]
196pub enum Message {
197 User {
198 content: OneOrMany<UserContent>,
199 },
200
201 Assistant {
202 #[serde(default)]
203 content: Vec<AssistantContent>,
204 #[serde(default)]
205 citations: Vec<Citation>,
206 #[serde(default)]
207 tool_calls: Vec<ToolCall>,
208 #[serde(default)]
209 tool_plan: Option<String>,
210 },
211
212 Tool {
213 content: OneOrMany<ToolResultContent>,
214 tool_call_id: String,
215 },
216
217 System {
218 content: String,
219 },
220}
221
222#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
223#[serde(tag = "type", rename_all = "lowercase")]
224pub enum UserContent {
225 Text { text: String },
226 ImageUrl { image_url: ImageUrl },
227}
228
229#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
230#[serde(tag = "type", rename_all = "lowercase")]
231pub enum AssistantContent {
232 Text { text: String },
233}
234
235#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
236pub struct ImageUrl {
237 pub url: String,
238}
239
240#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
241pub enum ToolResultContent {
242 Text { text: String },
243 Document { document: Document },
244}
245
246#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
247pub struct Citation {
248 #[serde(default)]
249 pub start: Option<u32>,
250 #[serde(default)]
251 pub end: Option<u32>,
252 #[serde(default)]
253 pub text: Option<String>,
254 #[serde(rename = "type")]
255 pub citation_type: Option<CitationType>,
256 #[serde(default)]
257 pub sources: Vec<Source>,
258}
259
260#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
261#[serde(tag = "type", rename_all = "lowercase")]
262pub enum Source {
263 Document {
264 id: Option<String>,
265 document: Option<serde_json::Map<String, serde_json::Value>>,
266 },
267 Tool {
268 id: Option<String>,
269 tool_output: Option<serde_json::Map<String, serde_json::Value>>,
270 },
271}
272
273#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
274#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
275pub enum CitationType {
276 TextContent,
277 Plan,
278}
279
280impl TryFrom<message::Message> for Vec<Message> {
281 type Error = message::MessageError;
282
283 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
284 Ok(match message {
285 message::Message::User { content } => content
286 .into_iter()
287 .map(|content| match content {
288 message::UserContent::Text(message::Text { text }) => Ok(Message::User {
289 content: OneOrMany::one(UserContent::Text { text }),
290 }),
291 message::UserContent::ToolResult(message::ToolResult {
292 id, content, ..
293 }) => Ok(Message::Tool {
294 tool_call_id: id,
295 content: content.try_map(|content| match content {
296 message::ToolResultContent::Text(text) => {
297 Ok(ToolResultContent::Text { text: text.text })
298 }
299 _ => Err(message::MessageError::ConversionError(
300 "Only text tool result content is supported by Cohere".to_owned(),
301 )),
302 })?,
303 }),
304 _ => Err(message::MessageError::ConversionError(
305 "Only text content is supported by Cohere".to_owned(),
306 )),
307 })
308 .collect::<Result<Vec<_>, _>>()?,
309 message::Message::Assistant { content, .. } => {
310 let mut text_content = vec![];
311 let mut tool_calls = vec![];
312 content.into_iter().for_each(|content| match content {
313 message::AssistantContent::Text(message::Text { text }) => {
314 text_content.push(AssistantContent::Text { text });
315 }
316 message::AssistantContent::ToolCall(message::ToolCall {
317 id,
318 function:
319 message::ToolFunction {
320 name, arguments, ..
321 },
322 ..
323 }) => {
324 tool_calls.push(ToolCall {
325 id: Some(id),
326 r#type: Some(ToolType::Function),
327 function: Some(ToolCallFunction {
328 name,
329 arguments: serde_json::to_value(arguments).unwrap_or_default(),
330 }),
331 });
332 }
333 });
334
335 vec![Message::Assistant {
336 content: text_content,
337 citations: vec![],
338 tool_calls,
339 tool_plan: None,
340 }]
341 }
342 })
343 }
344}
345
346impl TryFrom<Message> for message::Message {
347 type Error = message::MessageError;
348
349 fn try_from(message: Message) -> Result<Self, Self::Error> {
350 match message {
351 Message::User { content } => Ok(message::Message::User {
352 content: content.map(|content| match content {
353 UserContent::Text { text } => {
354 message::UserContent::Text(message::Text { text })
355 }
356 UserContent::ImageUrl { image_url } => message::UserContent::image(
357 image_url.url,
358 Some(message::ContentFormat::String),
359 None,
360 None,
361 ),
362 }),
363 }),
364 Message::Assistant {
365 content,
366 tool_calls,
367 ..
368 } => {
369 let mut content = content
370 .into_iter()
371 .map(|content| match content {
372 AssistantContent::Text { text } => message::AssistantContent::text(text),
373 })
374 .collect::<Vec<_>>();
375
376 content.extend(tool_calls.into_iter().filter_map(|tool_call| {
377 let ToolCallFunction { name, arguments } = tool_call.function?;
378
379 Some(message::AssistantContent::tool_call(
380 tool_call.id.unwrap_or_else(|| name.clone()),
381 name,
382 arguments,
383 ))
384 }));
385
386 let content = OneOrMany::many(content).map_err(|_| {
387 message::MessageError::ConversionError(
388 "Expected either text content or tool calls".to_string(),
389 )
390 })?;
391
392 Ok(message::Message::Assistant { id: None, content })
393 }
394 Message::Tool {
395 content,
396 tool_call_id,
397 } => {
398 let content = content.try_map(|content| {
399 Ok(match content {
400 ToolResultContent::Text { text } => message::ToolResultContent::text(text),
401 ToolResultContent::Document { document } => {
402 message::ToolResultContent::text(
403 serde_json::to_string(&document.data).map_err(|e| {
404 message::MessageError::ConversionError(
405 format!("Failed to convert tool result document content into text: {e}"),
406 )
407 })?,
408 )
409 }
410 })
411 })?;
412
413 Ok(message::Message::User {
414 content: OneOrMany::one(message::UserContent::tool_result(
415 tool_call_id,
416 content,
417 )),
418 })
419 }
420 Message::System { content } => Ok(message::Message::user(content)),
421 }
422 }
423}
424
425#[derive(Clone)]
426pub struct CompletionModel {
427 pub(crate) client: Client,
428 pub model: String,
429}
430
431impl CompletionModel {
432 pub fn new(client: Client, model: &str) -> Self {
433 Self {
434 client,
435 model: model.to_string(),
436 }
437 }
438
439 pub(crate) fn create_completion_request(
440 &self,
441 completion_request: CompletionRequest,
442 ) -> Result<Value, CompletionError> {
443 let mut partial_history = vec![];
445 if let Some(docs) = completion_request.normalized_documents() {
446 partial_history.push(docs);
447 }
448 partial_history.extend(completion_request.chat_history);
449
450 let mut full_history: Vec<Message> = completion_request
452 .preamble
453 .map_or_else(Vec::new, |preamble| {
454 vec![Message::System { content: preamble }]
455 });
456
457 full_history.extend(
459 partial_history
460 .into_iter()
461 .map(message::Message::try_into)
462 .collect::<Result<Vec<Vec<Message>>, _>>()?
463 .into_iter()
464 .flatten()
465 .collect::<Vec<_>>(),
466 );
467
468 let request = json!({
469 "model": self.model,
470 "messages": full_history,
471 "documents": completion_request.documents,
472 "temperature": completion_request.temperature,
473 "tools": completion_request.tools.into_iter().map(Tool::from).collect::<Vec<_>>(),
474 });
475
476 if let Some(ref params) = completion_request.additional_params {
477 Ok(json_utils::merge(request.clone(), params.clone()))
478 } else {
479 Ok(request)
480 }
481 }
482}
483
484impl completion::CompletionModel for CompletionModel {
485 type Response = CompletionResponse;
486 type StreamingResponse = StreamingCompletionResponse;
487
488 #[cfg_attr(feature = "worker", worker::send)]
489 async fn completion(
490 &self,
491 completion_request: completion::CompletionRequest,
492 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
493 let request = self.create_completion_request(completion_request)?;
494 tracing::debug!(
495 "Cohere request: {}",
496 serde_json::to_string_pretty(&request)?
497 );
498
499 let response = self.client.post("/v2/chat").json(&request).send().await?;
500
501 if response.status().is_success() {
502 let text_response = response.text().await?;
503 tracing::debug!("Cohere response text: {}", text_response);
504
505 let json_response: CompletionResponse = serde_json::from_str(&text_response)?;
506 let completion: completion::CompletionResponse<CompletionResponse> =
507 json_response.try_into()?;
508 Ok(completion)
509 } else {
510 Err(CompletionError::ProviderError(response.text().await?))
511 }
512 }
513
514 #[cfg_attr(feature = "worker", worker::send)]
515 async fn stream(
516 &self,
517 request: CompletionRequest,
518 ) -> Result<
519 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
520 CompletionError,
521 > {
522 CompletionModel::stream(self, request).await
523 }
524}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_path_to_error::deserialize;

    #[test]
    fn test_deserialize_completion_response() {
        let json_data = r#"
        {
            "id": "abc123",
            "message": {
                "role": "assistant",
                "tool_plan": "I will use the subtract tool to find the difference between 2 and 5.",
                "tool_calls": [
                        {
                            "id": "subtract_sm6ps6fb6y9f",
                            "type": "function",
                            "function": {
                                "name": "subtract",
                                "arguments": "{\"x\":5,\"y\":2}"
                            }
                        }
                    ]
                },
                "finish_reason": "TOOL_CALL",
                "usage": {
                "billed_units": {
                    "input_tokens": 78,
                    "output_tokens": 27
                },
                "tokens": {
                    "input_tokens": 1028,
                    "output_tokens": 63
                }
            }
        }
        "#;

        let mut deserializer = serde_json::Deserializer::from_str(json_data);
        let result: Result<CompletionResponse, _> = deserialize(&mut deserializer);

        let response = result.unwrap();
        let (_, citations, tool_calls) = response.message();
        let CompletionResponse {
            id,
            finish_reason,
            usage,
            ..
        } = response;

        assert_eq!(id, "abc123");
        assert_eq!(finish_reason, FinishReason::ToolCall);

        let Usage {
            billed_units,
            tokens,
        } = usage.unwrap();
        let BilledUnits {
            input_tokens: billed_input_tokens,
            output_tokens: billed_output_tokens,
            ..
        } = billed_units.unwrap();
        let Tokens {
            input_tokens,
            output_tokens,
        } = tokens.unwrap();

        assert_eq!(billed_input_tokens.unwrap(), 78.0);
        assert_eq!(billed_output_tokens.unwrap(), 27.0);
        assert_eq!(input_tokens.unwrap(), 1028.0);
        assert_eq!(output_tokens.unwrap(), 63.0);

        assert!(citations.is_empty());
        assert_eq!(tool_calls.len(), 1);

        let ToolCallFunction { name, arguments } = tool_calls[0].function.clone().unwrap();

        assert_eq!(name, "subtract");
        assert_eq!(arguments, serde_json::json!({"x": 5, "y": 2}));
    }

    /// A rig user text message must survive rig -> Cohere -> rig unchanged.
    #[test]
    fn test_convert_completion_message_to_message_and_back() {
        let completion_message = completion::Message::User {
            content: OneOrMany::one(completion::message::UserContent::Text(
                completion::message::Text {
                    text: "Hello, world!".to_string(),
                },
            )),
        };

        let messages: Vec<Message> = completion_message.clone().try_into().unwrap();
        assert_eq!(
            messages,
            vec![Message::User {
                content: OneOrMany::one(UserContent::Text {
                    text: "Hello, world!".to_string(),
                }),
            }]
        );

        let converted_back: Vec<completion::Message> = messages
            .into_iter()
            .map(|msg| msg.try_into().unwrap())
            .collect::<Vec<_>>();
        assert_eq!(converted_back, vec![completion_message]);
    }

    /// A Cohere user text message must survive Cohere -> rig -> Cohere unchanged.
    #[test]
    fn test_convert_message_to_completion_message_and_back() {
        let message = Message::User {
            content: OneOrMany::one(UserContent::Text {
                text: "Hello, world!".to_string(),
            }),
        };

        let completion_message: completion::Message = message.clone().try_into().unwrap();
        let converted_back: Vec<Message> = completion_message.try_into().unwrap();
        assert_eq!(converted_back, vec![message]);
    }
}