1use crate::{
2 OneOrMany,
3 completion::{self, CompletionError, GetTokenUsage},
4 http_client::{self, HttpClientExt},
5 json_utils,
6 message::{self, Reasoning, ToolChoice},
7 telemetry::SpanCombinator,
8};
9use std::collections::HashMap;
10
11use super::client::Client;
12use crate::completion::CompletionRequest;
13use crate::providers::cohere::streaming::StreamingCompletionResponse;
14use serde::{Deserialize, Serialize};
15use tracing::{Instrument, Level, enabled, info_span};
16
/// Response body of a non-streaming Cohere v2 chat request.
#[derive(Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    /// Identifier of this generation as returned by the API.
    pub id: String,
    /// Why the model stopped generating.
    pub finish_reason: FinishReason,
    // Kept private; read through `message()` or the telemetry impl.
    message: Message,
    /// Token accounting; `None` when the response carried no `usage` block.
    #[serde(default)]
    pub usage: Option<Usage>,
}
25
26type AssistantMessageParts = (Vec<AssistantContent>, Vec<Citation>, Vec<ToolCall>);
27
28impl CompletionResponse {
29 pub fn message(&self) -> Result<AssistantMessageParts, CompletionError> {
31 let Message::Assistant {
32 content,
33 citations,
34 tool_calls,
35 ..
36 } = self.message.clone()
37 else {
38 return Err(CompletionError::ResponseError(
39 "completion response did not contain an assistant message".into(),
40 ));
41 };
42
43 Ok((content, citations, tool_calls))
44 }
45}
46
47impl crate::telemetry::ProviderResponseExt for CompletionResponse {
48 type OutputMessage = Message;
49 type Usage = Usage;
50
51 fn get_response_id(&self) -> Option<String> {
52 Some(self.id.clone())
53 }
54
55 fn get_response_model_name(&self) -> Option<String> {
56 None
57 }
58
59 fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
60 vec![self.message.clone()]
61 }
62
63 fn get_text_response(&self) -> Option<String> {
64 let Message::Assistant { ref content, .. } = self.message else {
65 return None;
66 };
67
68 let res = content
69 .iter()
70 .filter_map(|x| {
71 if let AssistantContent::Text { text } = x {
72 Some(text.to_string())
73 } else {
74 None
75 }
76 })
77 .collect::<Vec<String>>()
78 .join("\n");
79
80 if res.is_empty() { None } else { Some(res) }
81 }
82
83 fn get_usage(&self) -> Option<Self::Usage> {
84 self.usage.clone()
85 }
86}
87
/// Why a Cohere generation finished, (de)serialized as `SCREAMING_SNAKE_CASE`
/// (e.g. `"TOOL_CALL"`).
///
/// NOTE(review): a value not listed here makes deserialization of the whole response fail;
/// confirm this list still matches the API's current set of finish reasons.
#[derive(Debug, Deserialize, PartialEq, Eq, Clone, Serialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum FinishReason {
    MaxTokens,
    StopSequence,
    Complete,
    Error,
    ToolCall,
}
97
98#[derive(Debug, Deserialize, Clone, Serialize)]
99pub struct Usage {
100 #[serde(default)]
101 pub billed_units: Option<BilledUnits>,
102 #[serde(default)]
103 pub tokens: Option<Tokens>,
104}
105
106impl GetTokenUsage for Usage {
107 fn token_usage(&self) -> crate::completion::Usage {
108 let mut usage = crate::completion::Usage::new();
109
110 if let Some(ref billed_units) = self.billed_units {
111 usage.input_tokens = billed_units.input_tokens.unwrap_or_default() as u64;
112 usage.output_tokens = billed_units.output_tokens.unwrap_or_default() as u64;
113 usage.total_tokens = usage.input_tokens + usage.output_tokens;
114 }
115
116 usage
117 }
118}
119
/// Billed usage reported under `usage.billed_units`.
///
/// Counts are held as `f64` and truncated to `u64` where this module converts them to
/// generic usage. Every field is optional.
#[derive(Debug, Deserialize, Clone, Serialize)]
pub struct BilledUnits {
    #[serde(default)]
    pub output_tokens: Option<f64>,
    #[serde(default)]
    pub classifications: Option<f64>,
    #[serde(default)]
    pub search_units: Option<f64>,
    #[serde(default)]
    pub input_tokens: Option<f64>,
}

/// Raw token counts reported under `usage.tokens`.
///
/// The generic completion response built in this module reads its usage from here, not from
/// `billed_units`.
#[derive(Debug, Deserialize, Clone, Serialize)]
pub struct Tokens {
    #[serde(default)]
    pub input_tokens: Option<f64>,
    #[serde(default)]
    pub output_tokens: Option<f64>,
}
139
140impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
141 type Error = CompletionError;
142
143 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
144 let (content, _, tool_calls) = response.message()?;
145
146 let model_response = if !tool_calls.is_empty() {
147 OneOrMany::many(
148 tool_calls
149 .into_iter()
150 .filter_map(|tool_call| {
151 let ToolCallFunction { name, arguments } = tool_call.function?;
152 let id = tool_call.id.unwrap_or_else(|| name.clone());
153
154 Some(completion::AssistantContent::tool_call(id, name, arguments))
155 })
156 .collect::<Vec<_>>(),
157 )
158 .map_err(|_| {
159 CompletionError::ResponseError(
160 "response contained tool call metadata without any callable tool content"
161 .to_owned(),
162 )
163 })?
164 } else {
165 OneOrMany::many(content.into_iter().map(|content| match content {
166 AssistantContent::Text { text } => completion::AssistantContent::text(text),
167 AssistantContent::Thinking { thinking } => {
168 completion::AssistantContent::Reasoning(Reasoning::new(&thinking))
169 }
170 }))
171 .map_err(|_| {
172 CompletionError::ResponseError(
173 "Response contained no message or tool call (empty)".to_owned(),
174 )
175 })?
176 };
177
178 let usage = response
179 .usage
180 .as_ref()
181 .and_then(|usage| usage.tokens.as_ref())
182 .map(|tokens| {
183 let input_tokens = tokens.input_tokens.unwrap_or(0.0);
184 let output_tokens = tokens.output_tokens.unwrap_or(0.0);
185
186 completion::Usage {
187 input_tokens: input_tokens as u64,
188 output_tokens: output_tokens as u64,
189 total_tokens: (input_tokens + output_tokens) as u64,
190 cached_input_tokens: 0,
191 cache_creation_input_tokens: 0,
192 tool_use_prompt_tokens: 0,
193 reasoning_tokens: 0,
194 }
195 })
196 .unwrap_or_default();
197
198 Ok(completion::CompletionResponse {
199 choice: model_response,
200 usage,
201 raw_response: response,
202 message_id: None,
203 })
204 }
205}
206
207#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
208pub struct Document {
209 pub id: String,
210 pub data: HashMap<String, serde_json::Value>,
211}
212
213impl From<completion::Document> for Document {
214 fn from(document: completion::Document) -> Self {
215 let mut data: HashMap<String, serde_json::Value> = HashMap::new();
216
217 document
220 .additional_props
221 .into_iter()
222 .for_each(|(key, value)| {
223 data.insert(key, value.into());
224 });
225
226 data.insert("text".to_string(), document.text.into());
227
228 Self {
229 id: document.id,
230 data,
231 }
232 }
233}
234
/// A tool call made by the assistant. It is read from responses and also built when sending
/// assistant history back to Cohere.
///
/// Every field is optional on deserialization. Calls without a `function` are skipped when
/// converted into rig types.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
pub struct ToolCall {
    #[serde(default)]
    pub id: Option<String>,
    #[serde(default)]
    pub r#type: Option<ToolType>,
    #[serde(default)]
    pub function: Option<ToolCallFunction>,
}

/// Function name and arguments of a tool call.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
pub struct ToolCallFunction {
    pub name: String,
    // On the wire the arguments are a JSON-encoded *string*; `stringified_json` converts them
    // to and from a parsed value.
    #[serde(with = "json_utils::stringified_json")]
    pub arguments: serde_json::Value,
}

/// Kind of tool; only `function` (serialized lowercase) exists.
#[derive(Clone, Default, Debug, Deserialize, Serialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}
258
259#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
260pub struct Tool {
261 pub r#type: ToolType,
262 pub function: Function,
263}
264
265#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
266pub struct Function {
267 pub name: String,
268 #[serde(default)]
269 pub description: Option<String>,
270 pub parameters: serde_json::Value,
271}
272
273impl From<completion::ToolDefinition> for Tool {
274 fn from(tool: completion::ToolDefinition) -> Self {
275 Self {
276 r#type: ToolType::default(),
277 function: Function {
278 name: tool.name,
279 description: Some(tool.description),
280 parameters: tool.parameters,
281 },
282 }
283 }
284}
285
/// A Cohere v2 chat message, internally tagged by `"role"`
/// (`user`, `assistant`, `tool`, `system`).
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    /// A user turn made of one or more content parts.
    User {
        content: OneOrMany<UserContent>,
    },

    /// An assistant turn. Every field defaults when absent, so a response carrying only
    /// `tool_calls` (plus a `tool_plan`) still deserializes.
    Assistant {
        #[serde(default)]
        content: Vec<AssistantContent>,
        #[serde(default)]
        citations: Vec<Citation>,
        #[serde(default)]
        tool_calls: Vec<ToolCall>,
        #[serde(default)]
        tool_plan: Option<String>,
    },

    /// The result of a tool call, linked to the call by `tool_call_id`.
    Tool {
        content: OneOrMany<ToolResultContent>,
        tool_call_id: String,
    },

    /// A system instruction; the request preamble is sent as one of these.
    System {
        content: String,
    },
}
313
314#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
315#[serde(tag = "type", rename_all = "lowercase")]
316pub enum UserContent {
317 Text { text: String },
318 ImageUrl { image_url: ImageUrl },
319}
320
/// A single part of a Cohere assistant message, internally tagged on `"type"`
/// (`text` or `thinking`).
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum AssistantContent {
    Text { text: String },
    Thinking { thinking: String },
}

/// An image reference used by [`UserContent::ImageUrl`].
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
pub struct ImageUrl {
    pub url: String,
}
332
333#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
334pub enum ToolResultContent {
335 Text { text: String },
336 Document { document: Document },
337}
338
/// A citation attached to an assistant message.
///
/// NOTE(review): `start`/`end` are presumably offsets into the generated text; confirm the unit
/// (characters vs bytes) against the API docs.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct Citation {
    #[serde(default)]
    pub start: Option<u32>,
    #[serde(default)]
    pub end: Option<u32>,
    #[serde(default)]
    pub text: Option<String>,
    // Wire name is `type`; as an `Option`, a missing field deserializes to `None`.
    #[serde(rename = "type")]
    pub citation_type: Option<CitationType>,
    #[serde(default)]
    pub sources: Vec<Source>,
}

/// What a citation points at, internally tagged on `"type"` (`document` or `tool`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum Source {
    Document {
        id: Option<String>,
        document: Option<serde_json::Map<String, serde_json::Value>>,
    },
    Tool {
        id: Option<String>,
        tool_output: Option<serde_json::Map<String, serde_json::Value>>,
    },
}

/// Kind of cited span, (de)serialized as `TEXT_CONTENT` / `PLAN`.
///
/// NOTE(review): an unlisted value fails deserialization of the whole response; confirm the API
/// does not emit other citation types.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum CitationType {
    TextContent,
    Plan,
}
372
373impl TryFrom<message::Message> for Vec<Message> {
374 type Error = message::MessageError;
375
376 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
377 Ok(match message {
378 message::Message::User { content } => content
379 .into_iter()
380 .map(|content| match content {
381 message::UserContent::Text(message::Text { text, .. }) => Ok(Message::User {
382 content: OneOrMany::one(UserContent::Text { text }),
383 }),
384 message::UserContent::ToolResult(message::ToolResult {
385 id, content, ..
386 }) => Ok(Message::Tool {
387 tool_call_id: id,
388 content: content.try_map(|content| match content {
389 message::ToolResultContent::Text(text) => {
390 Ok(ToolResultContent::Text { text: text.text })
391 }
392 _ => Err(message::MessageError::ConversionError(
393 "Only text tool result content is supported by Cohere".to_owned(),
394 )),
395 })?,
396 }),
397 _ => Err(message::MessageError::ConversionError(
398 "Only text content is supported by Cohere".to_owned(),
399 )),
400 })
401 .collect::<Result<Vec<_>, _>>()?,
402 message::Message::System { content } => {
403 vec![Message::System { content }]
404 }
405 message::Message::Assistant { content, .. } => {
406 let mut text_content = vec![];
407 let mut tool_calls = vec![];
408
409 for content in content.into_iter() {
410 match content {
411 message::AssistantContent::Text(message::Text { text, .. }) => {
412 text_content.push(AssistantContent::Text { text });
413 }
414 message::AssistantContent::ToolCall(message::ToolCall {
415 id,
416 function:
417 message::ToolFunction {
418 name, arguments, ..
419 },
420 ..
421 }) => {
422 tool_calls.push(ToolCall {
423 id: Some(id),
424 r#type: Some(ToolType::Function),
425 function: Some(ToolCallFunction {
426 name,
427 arguments: serde_json::to_value(arguments).unwrap_or_default(),
428 }),
429 });
430 }
431 message::AssistantContent::Reasoning(reasoning) => {
432 let thinking = reasoning.display_text();
433 text_content.push(AssistantContent::Thinking { thinking });
434 }
435 message::AssistantContent::Image(_) => {
436 return Err(message::MessageError::ConversionError(
437 "Cohere currently doesn't support images.".to_owned(),
438 ));
439 }
440 }
441 }
442
443 vec![Message::Assistant {
444 content: text_content,
445 citations: vec![],
446 tool_calls,
447 tool_plan: None,
448 }]
449 }
450 })
451 }
452}
453
454impl TryFrom<Message> for message::Message {
455 type Error = message::MessageError;
456
457 fn try_from(message: Message) -> Result<Self, Self::Error> {
458 match message {
459 Message::User { content } => Ok(message::Message::User {
460 content: content.map(|content| match content {
461 UserContent::Text { text } => {
462 message::UserContent::Text(message::Text::new(text))
463 }
464 UserContent::ImageUrl { image_url } => {
465 message::UserContent::image_url(image_url.url, None, None)
466 }
467 }),
468 }),
469 Message::Assistant {
470 content,
471 tool_calls,
472 ..
473 } => {
474 let mut content = content
475 .into_iter()
476 .map(|content| match content {
477 AssistantContent::Text { text } => message::AssistantContent::text(text),
478 AssistantContent::Thinking { thinking } => {
479 message::AssistantContent::Reasoning(Reasoning::new(&thinking))
480 }
481 })
482 .collect::<Vec<_>>();
483
484 content.extend(tool_calls.into_iter().filter_map(|tool_call| {
485 let ToolCallFunction { name, arguments } = tool_call.function?;
486
487 Some(message::AssistantContent::tool_call(
488 tool_call.id.unwrap_or_else(|| name.clone()),
489 name,
490 arguments,
491 ))
492 }));
493
494 let content = OneOrMany::many(content).map_err(|_| {
495 message::MessageError::ConversionError(
496 "Expected either text content or tool calls".to_string(),
497 )
498 })?;
499
500 Ok(message::Message::Assistant { id: None, content })
501 }
502 Message::Tool {
503 content,
504 tool_call_id,
505 } => {
506 let content = content.try_map(|content| {
507 Ok(match content {
508 ToolResultContent::Text { text } => message::ToolResultContent::text(text),
509 ToolResultContent::Document { document } => {
510 message::ToolResultContent::text(
511 serde_json::to_string(&document.data).map_err(|e| {
512 message::MessageError::ConversionError(
513 format!("Failed to convert tool result document content into text: {e}"),
514 )
515 })?,
516 )
517 }
518 })
519 })?;
520
521 Ok(message::Message::User {
522 content: OneOrMany::one(message::UserContent::tool_result(
523 tool_call_id,
524 content,
525 )),
526 })
527 }
528 Message::System { content } => Ok(message::Message::user(content)),
529 }
530 }
531}
532
/// A Cohere chat completion model bound to a client.
///
/// `T` is the HTTP client backend; it defaults to `reqwest::Client`.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    pub(crate) client: Client<T>,
    pub model: String,
}

/// JSON body for `POST /v2/chat`.
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct CohereCompletionRequest {
    model: String,
    pub messages: Vec<Message>,
    // Forwarded as rig's native document type (not this module's `Document`).
    documents: Vec<crate::completion::Document>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<Tool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<ToolChoice>,
    // Provider-specific extras are merged into the top-level JSON object.
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub additional_params: Option<serde_json::Value>,
}
553
554impl TryFrom<(&str, CompletionRequest)> for CohereCompletionRequest {
555 type Error = CompletionError;
556
557 fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
558 let documents = req.documents.clone();
559 if req.output_schema.is_some() {
560 tracing::warn!("Structured outputs currently not supported for Cohere");
561 }
562
563 let model = req.model.clone().unwrap_or_else(|| model.to_string());
564 let mut partial_history = vec![];
565 partial_history.extend(req.chat_history);
566
567 let mut full_history: Vec<Message> = req.preamble.map_or_else(Vec::new, |preamble| {
568 vec![Message::System { content: preamble }]
569 });
570
571 full_history.extend(
572 partial_history
573 .into_iter()
574 .map(message::Message::try_into)
575 .collect::<Result<Vec<Vec<Message>>, _>>()?
576 .into_iter()
577 .flatten()
578 .collect::<Vec<_>>(),
579 );
580
581 let tool_choice = if let Some(tool_choice) = req.tool_choice {
582 if !matches!(tool_choice, ToolChoice::Auto) {
583 Some(tool_choice)
584 } else {
585 return Err(CompletionError::RequestError(
586 "\"auto\" is not an allowed tool_choice value in the Cohere API".into(),
587 ));
588 }
589 } else {
590 None
591 };
592
593 Ok(Self {
594 model: model.to_string(),
595 messages: full_history,
596 documents,
597 temperature: req.temperature,
598 tools: req.tools.into_iter().map(Tool::from).collect::<Vec<_>>(),
599 tool_choice,
600 additional_params: req.additional_params,
601 })
602 }
603}
604
605impl<T> CompletionModel<T>
606where
607 T: HttpClientExt,
608{
609 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
610 Self {
611 client,
612 model: model.into(),
613 }
614 }
615}
616
/// Chat completions against Cohere's `/v2/chat` endpoint.
impl<T> completion::CompletionModel for CompletionModel<T>
where
    T: HttpClientExt + Clone + 'static,
{
    type Response = CompletionResponse;
    type StreamingResponse = StreamingCompletionResponse;
    type Client = Client<T>;

    /// Builds a model handle from a shared client (the client is cloned).
    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model.into())
    }

    /// Sends one chat request and converts the reply into rig's generic response.
    ///
    /// # Errors
    /// - the request cannot be converted to Cohere's format;
    /// - serialization, transport or body-read failures;
    /// - a non-success HTTP status, reported as `CompletionError::ProviderError` carrying the
    ///   response body (decoded lossily as UTF-8);
    /// - the successful response cannot be converted (e.g. it has no assistant content).
    async fn completion(
        &self,
        completion_request: completion::CompletionRequest,
    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
        let request = CohereCompletionRequest::try_from((self.model.as_ref(), completion_request))?;

        // Reuse the caller's span when one is active; otherwise open a GenAI "chat" span.
        // Its `Empty` fields are recorded later (presumably by the `record_*` span helpers
        // below; confirm in `telemetry`).
        let llm_span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "cohere",
                gen_ai.request.model = self.model,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = self.model,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        // Pretty-print only when TRACE is enabled, to avoid the formatting cost otherwise.
        if enabled!(Level::TRACE) {
            tracing::trace!(
                "Cohere completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let req_body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post("/v2/chat")?
            .body(req_body)
            .map_err(|e| CompletionError::HttpError(e.into()))?;

        // The network round-trip runs inside `llm_span` so usage/metadata are recorded on it.
        async {
            let response = self
                .client
                .send::<_, bytes::Bytes>(req)
                .await
                .map_err(|e| http_client::Error::Instance(e.into()))?;

            let status = response.status();
            let body = response.into_body().into_future().await?.to_owned();

            if status.is_success() {
                let json_response: CompletionResponse = serde_json::from_slice(&body)?;
                let span = tracing::Span::current();
                span.record_token_usage(&json_response.usage);
                span.record_response_metadata(&json_response);

                if enabled!(Level::TRACE) {
                    tracing::trace!(
                        target: "rig::completions",
                        "Cohere completion response: {}",
                        serde_json::to_string_pretty(&json_response)?
                    );
                }

                let completion: completion::CompletionResponse<CompletionResponse> =
                    json_response.try_into()?;
                Ok(completion)
            } else {
                // Surface the raw error body so callers can see Cohere's message.
                Err(CompletionError::ProviderError(
                    String::from_utf8_lossy(&body).to_string(),
                ))
            }
        }
        .instrument(llm_span)
        .await
    }

    /// Delegates to a `stream` method on `CompletionModel` defined outside this view
    /// (presumably the inherent one in the streaming module; confirm).
    async fn stream(
        &self,
        request: CompletionRequest,
    ) -> Result<
        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
        CompletionError,
    > {
        CompletionModel::stream(self, request).await
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_path_to_error::deserialize;

    // A tool-call-only response (no `content`) must deserialize. Usage counts and the
    // stringified tool arguments must be parsed correctly.
    #[test]
    fn test_deserialize_completion_response() {
        let json_data = r#"
 {
 "id": "abc123",
 "message": {
 "role": "assistant",
 "tool_plan": "I will use the subtract tool to find the difference between 2 and 5.",
 "tool_calls": [
 {
 "id": "subtract_sm6ps6fb6y9f",
 "type": "function",
 "function": {
 "name": "subtract",
 "arguments": "{\"x\":5,\"y\":2}"
 }
 }
 ]
 },
 "finish_reason": "TOOL_CALL",
 "usage": {
 "billed_units": {
 "input_tokens": 78,
 "output_tokens": 27
 },
 "tokens": {
 "input_tokens": 1028,
 "output_tokens": 63
 }
 }
 }
 "#;

        // `serde_path_to_error` reports the JSON path of any failing field.
        let mut deserializer = serde_json::Deserializer::from_str(json_data);
        let result: Result<CompletionResponse, _> = deserialize(&mut deserializer);

        let response = result.unwrap();
        let (_, citations, tool_calls) = response.message().expect("assistant message");
        let CompletionResponse {
            id,
            finish_reason,
            usage,
            ..
        } = response;

        assert_eq!(id, "abc123");
        assert_eq!(finish_reason, FinishReason::ToolCall);

        let Usage {
            billed_units,
            tokens,
        } = usage.unwrap();
        let BilledUnits {
            input_tokens: billed_input_tokens,
            output_tokens: billed_output_tokens,
            ..
        } = billed_units.unwrap();
        let Tokens {
            input_tokens,
            output_tokens,
        } = tokens.unwrap();

        assert_eq!(billed_input_tokens.unwrap(), 78.0);
        assert_eq!(billed_output_tokens.unwrap(), 27.0);
        assert_eq!(input_tokens.unwrap(), 1028.0);
        assert_eq!(output_tokens.unwrap(), 63.0);

        assert!(citations.is_empty());
        assert_eq!(tool_calls.len(), 1);

        let ToolCallFunction { name, arguments } = tool_calls[0].function.clone().unwrap();

        assert_eq!(name, "subtract");
        // The stringified arguments must come back as a parsed JSON object.
        assert_eq!(arguments, serde_json::json!({"x": 5, "y": 2}));
    }

    // rig user text → Cohere messages → rig messages must convert without error.
    #[test]
    fn test_convert_completion_message_to_message_and_back() {
        let completion_message = completion::Message::User {
            content: OneOrMany::one(completion::message::UserContent::Text(
                completion::message::Text::new("Hello, world!".to_string()),
            )),
        };

        let messages: Vec<Message> = completion_message.clone().try_into().unwrap();
        let _converted_back: Vec<completion::Message> = messages
            .into_iter()
            .map(|msg| msg.try_into().unwrap())
            .collect::<Vec<_>>();
    }

    // Cohere user text → rig message → Cohere messages must convert without error.
    #[test]
    fn test_convert_message_to_completion_message_and_back() {
        let message = Message::User {
            content: OneOrMany::one(UserContent::Text {
                text: "Hello, world!".to_string(),
            }),
        };

        let completion_message: completion::Message = message.clone().try_into().unwrap();
        let _converted_back: Vec<Message> = completion_message.try_into().unwrap();
    }

    // Documents attached through the request builder must reach the Cohere request body.
    #[test]
    fn cohere_builder_request_preserves_native_documents() {
        let request = crate::completion::CompletionRequestBuilder::new(
            crate::test_utils::MockCompletionModel::default(),
            "What is glarb-glarb?",
        )
        .document(crate::completion::request::Document {
            id: "doc_1".to_string(),
            text: "Definition of glarb-glarb: an ancient tool.".to_string(),
            additional_props: Default::default(),
        })
        .build();

        let request = CohereCompletionRequest::try_from(("command-r", request))
            .expect("request conversion should succeed");

        assert_eq!(request.documents.len(), 1);
        assert_eq!(request.documents[0].id, "doc_1");
    }
}