Skip to main content

rig_core/providers/cohere/
completion.rs

1use crate::{
2    OneOrMany,
3    completion::{self, CompletionError, GetTokenUsage},
4    http_client::{self, HttpClientExt},
5    json_utils,
6    message::{self, Reasoning, ToolChoice},
7    telemetry::SpanCombinator,
8};
9use std::collections::HashMap;
10
11use super::client::Client;
12use crate::completion::CompletionRequest;
13use crate::providers::cohere::streaming::StreamingCompletionResponse;
14use serde::{Deserialize, Serialize};
15use tracing::{Instrument, Level, enabled, info_span};
16
/// Body of a successful Cohere chat response, as deserialized from `POST /v2/chat`.
///
/// The `message` field is private; read the assistant parts through
/// [`CompletionResponse::message`].
#[derive(Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    /// Response identifier reported by the API.
    pub id: String,
    /// Why generation stopped.
    pub finish_reason: FinishReason,
    /// The single message returned by the model (expected to be an assistant message).
    message: Message,
    /// Token accounting; `None` when the field is absent from the body.
    #[serde(default)]
    pub usage: Option<Usage>,
}
25
26type AssistantMessageParts = (Vec<AssistantContent>, Vec<Citation>, Vec<ToolCall>);
27
28impl CompletionResponse {
29    /// Return that parts of the response for assistant messages w/o dealing with the other variants
30    pub fn message(&self) -> Result<AssistantMessageParts, CompletionError> {
31        let Message::Assistant {
32            content,
33            citations,
34            tool_calls,
35            ..
36        } = self.message.clone()
37        else {
38            return Err(CompletionError::ResponseError(
39                "completion response did not contain an assistant message".into(),
40            ));
41        };
42
43        Ok((content, citations, tool_calls))
44    }
45}
46
47impl crate::telemetry::ProviderResponseExt for CompletionResponse {
48    type OutputMessage = Message;
49    type Usage = Usage;
50
51    fn get_response_id(&self) -> Option<String> {
52        Some(self.id.clone())
53    }
54
55    fn get_response_model_name(&self) -> Option<String> {
56        None
57    }
58
59    fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
60        vec![self.message.clone()]
61    }
62
63    fn get_text_response(&self) -> Option<String> {
64        let Message::Assistant { ref content, .. } = self.message else {
65            return None;
66        };
67
68        let res = content
69            .iter()
70            .filter_map(|x| {
71                if let AssistantContent::Text { text } = x {
72                    Some(text.to_string())
73                } else {
74                    None
75                }
76            })
77            .collect::<Vec<String>>()
78            .join("\n");
79
80        if res.is_empty() { None } else { Some(res) }
81    }
82
83    fn get_usage(&self) -> Option<Self::Usage> {
84        self.usage.clone()
85    }
86}
87
/// Why Cohere stopped generating, read from the response's `finish_reason` field.
///
/// Variants are (de)serialized as `SCREAMING_SNAKE_CASE` strings, e.g. `ToolCall` <->
/// `"TOOL_CALL"`.
///
/// NOTE(review): there is no catch-all variant, so any value not listed here makes
/// deserialization of the whole response fail.
#[derive(Debug, Deserialize, PartialEq, Eq, Clone, Serialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum FinishReason {
    MaxTokens,
    StopSequence,
    Complete,
    Error,
    ToolCall,
}
97
/// Token accounting attached to a Cohere chat response.
///
/// Both sections are optional and default to `None` when missing from the body.
#[derive(Debug, Deserialize, Clone, Serialize)]
pub struct Usage {
    /// The `billed_units` section of the response.
    #[serde(default)]
    pub billed_units: Option<BilledUnits>,
    /// The `tokens` section of the response (raw input/output token counts).
    #[serde(default)]
    pub tokens: Option<Tokens>,
}
105
106impl GetTokenUsage for Usage {
107    fn token_usage(&self) -> Option<crate::completion::Usage> {
108        let mut usage = crate::completion::Usage::new();
109
110        if let Some(ref billed_units) = self.billed_units {
111            usage.input_tokens = billed_units.input_tokens.unwrap_or_default() as u64;
112            usage.output_tokens = billed_units.output_tokens.unwrap_or_default() as u64;
113            usage.total_tokens = usage.input_tokens + usage.output_tokens;
114        }
115
116        Some(usage)
117    }
118}
119
/// The `billed_units` section of [`Usage`].
///
/// Every count is optional and deserialized as `f64`. Code in this module casts the token
/// counts to integers where it uses them.
#[derive(Debug, Deserialize, Clone, Serialize)]
pub struct BilledUnits {
    /// Billed output tokens.
    #[serde(default)]
    pub output_tokens: Option<f64>,
    /// Billed classification units, when reported.
    #[serde(default)]
    pub classifications: Option<f64>,
    /// Billed search units, when reported.
    #[serde(default)]
    pub search_units: Option<f64>,
    /// Billed input tokens.
    #[serde(default)]
    pub input_tokens: Option<f64>,
}
131
/// The `tokens` section of [`Usage`]: input and output token counts, each optional and
/// deserialized as `f64`.
#[derive(Debug, Deserialize, Clone, Serialize)]
pub struct Tokens {
    #[serde(default)]
    pub input_tokens: Option<f64>,
    #[serde(default)]
    pub output_tokens: Option<f64>,
}
139
140impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
141    type Error = CompletionError;
142
143    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
144        let (content, _, tool_calls) = response.message()?;
145
146        let model_response = if !tool_calls.is_empty() {
147            OneOrMany::many(
148                tool_calls
149                    .into_iter()
150                    .filter_map(|tool_call| {
151                        let ToolCallFunction { name, arguments } = tool_call.function?;
152                        let id = tool_call.id.unwrap_or_else(|| name.clone());
153
154                        Some(completion::AssistantContent::tool_call(id, name, arguments))
155                    })
156                    .collect::<Vec<_>>(),
157            )
158            .map_err(|_| {
159                CompletionError::ResponseError(
160                    "response contained tool call metadata without any callable tool content"
161                        .to_owned(),
162                )
163            })?
164        } else {
165            OneOrMany::many(content.into_iter().map(|content| match content {
166                AssistantContent::Text { text } => completion::AssistantContent::text(text),
167                AssistantContent::Thinking { thinking } => {
168                    completion::AssistantContent::Reasoning(Reasoning::new(&thinking))
169                }
170            }))
171            .map_err(|_| {
172                CompletionError::ResponseError(
173                    "Response contained no message or tool call (empty)".to_owned(),
174                )
175            })?
176        };
177
178        let usage = response
179            .usage
180            .as_ref()
181            .and_then(|usage| usage.tokens.as_ref())
182            .map(|tokens| {
183                let input_tokens = tokens.input_tokens.unwrap_or(0.0);
184                let output_tokens = tokens.output_tokens.unwrap_or(0.0);
185
186                completion::Usage {
187                    input_tokens: input_tokens as u64,
188                    output_tokens: output_tokens as u64,
189                    total_tokens: (input_tokens + output_tokens) as u64,
190                    cached_input_tokens: 0,
191                    cache_creation_input_tokens: 0,
192                    tool_use_prompt_tokens: 0,
193                    reasoning_tokens: 0,
194                }
195            })
196            .unwrap_or_default();
197
198        Ok(completion::CompletionResponse {
199            choice: model_response,
200            usage,
201            raw_response: response,
202            message_id: None,
203        })
204    }
205}
206
/// A Cohere document: an `id` plus arbitrary JSON key/value `data`.
///
/// Built from rig's [`completion::Document`] via `From`, where the document text is stored
/// under the `"text"` key. Used as tool-result document content.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
pub struct Document {
    pub id: String,
    pub data: HashMap<String, serde_json::Value>,
}
212
213impl From<completion::Document> for Document {
214    fn from(document: completion::Document) -> Self {
215        let mut data: HashMap<String, serde_json::Value> = HashMap::new();
216
217        // We use `.into()` here explicitly since the `document.additional_props` type will likely
218        //  evolve into `serde_json::Value` in the future.
219        document
220            .additional_props
221            .into_iter()
222            .for_each(|(key, value)| {
223                data.insert(key, value.into());
224            });
225
226        data.insert("text".to_string(), document.text.into());
227
228        Self {
229            id: document.id,
230            data,
231        }
232    }
233}
234
/// A tool call emitted by the model in an assistant message.
///
/// Every field is optional on the wire. Conversions in this module skip calls without a
/// `function` and fall back to the function name when `id` is missing.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
pub struct ToolCall {
    #[serde(default)]
    pub id: Option<String>,
    #[serde(default)]
    pub r#type: Option<ToolType>,
    #[serde(default)]
    pub function: Option<ToolCallFunction>,
}
244
/// Name and arguments of a [`ToolCall`].
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
pub struct ToolCallFunction {
    pub name: String,
    /// Parsed call arguments.
    ///
    /// On the wire they arrive as a JSON-encoded string (e.g. `"{\"x\":5,\"y\":2}"` in the
    /// tests), which is why the `stringified_json` (de)serializer is used.
    #[serde(with = "json_utils::stringified_json")]
    pub arguments: serde_json::Value,
}
251
/// Kind of tool. `Function` is the only variant and the default; it serializes as
/// `"function"`.
#[derive(Clone, Default, Debug, Deserialize, Serialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}
258
/// A tool definition sent in the request's `tools` list.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
pub struct Tool {
    pub r#type: ToolType,
    pub function: Function,
}
264
/// Function description of a [`Tool`]: name, optional description and a JSON `parameters`
/// value.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
pub struct Function {
    pub name: String,
    #[serde(default)]
    pub description: Option<String>,
    /// Taken verbatim from rig's `ToolDefinition::parameters`; presumably a JSON Schema
    /// object (TODO confirm).
    pub parameters: serde_json::Value,
}
272
273impl From<completion::ToolDefinition> for Tool {
274    fn from(tool: completion::ToolDefinition) -> Self {
275        Self {
276            r#type: ToolType::default(),
277            function: Function {
278                name: tool.name,
279                description: Some(tool.description),
280                parameters: tool.parameters,
281            },
282        }
283    }
284}
285
/// A Cohere chat message, internally tagged by `"role"` (`"user"`, `"assistant"`, `"tool"`
/// or `"system"`).
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    User {
        content: OneOrMany<UserContent>,
    },

    /// Every assistant field defaults to empty/`None` when absent from the JSON.
    Assistant {
        #[serde(default)]
        content: Vec<AssistantContent>,
        #[serde(default)]
        citations: Vec<Citation>,
        #[serde(default)]
        tool_calls: Vec<ToolCall>,
        #[serde(default)]
        tool_plan: Option<String>,
    },

    /// Result of a tool call. `tool_call_id` is filled from the rig tool-result `id` in
    /// `TryFrom<message::Message> for Vec<Message>`.
    Tool {
        content: OneOrMany<ToolResultContent>,
        tool_call_id: String,
    },

    System {
        content: String,
    },
}
313
314#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
315#[serde(tag = "type", rename_all = "lowercase")]
316pub enum UserContent {
317    Text { text: String },
318    ImageUrl { image_url: ImageUrl },
319}
320
/// A single part of an assistant message's `content`, tagged by `"type"` (`"text"` or
/// `"thinking"`).
///
/// Thinking parts are mapped to rig reasoning content by this module's conversions.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum AssistantContent {
    Text { text: String },
    Thinking { thinking: String },
}
327
/// URL wrapper carried by [`UserContent::ImageUrl`].
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
pub struct ImageUrl {
    pub url: String,
}
332
333#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
334pub enum ToolResultContent {
335    Text { text: String },
336    Document { document: Document },
337}
338
/// A citation attached to an assistant message.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct Citation {
    /// Presumably the start character offset of the cited span in the response text —
    /// TODO confirm against the Cohere API reference.
    #[serde(default)]
    pub start: Option<u32>,
    /// Presumably the end offset of the cited span — TODO confirm.
    #[serde(default)]
    pub end: Option<u32>,
    /// The cited text, when provided.
    #[serde(default)]
    pub text: Option<String>,
    /// Serialized under the JSON key `"type"`.
    #[serde(rename = "type")]
    pub citation_type: Option<CitationType>,
    /// What the citation refers to; empty when absent.
    #[serde(default)]
    pub sources: Vec<Source>,
}
352
/// What a [`Citation`] points at, tagged by `"type"`: a `"document"` or a `"tool"` output.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum Source {
    Document {
        id: Option<String>,
        document: Option<serde_json::Map<String, serde_json::Value>>,
    },
    Tool {
        id: Option<String>,
        tool_output: Option<serde_json::Map<String, serde_json::Value>>,
    },
}
365
/// Kind of citation, (de)serialized as `"TEXT_CONTENT"` or `"PLAN"`.
///
/// NOTE(review): there is no catch-all variant, so an unlisted value makes the enclosing
/// response fail to deserialize.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum CitationType {
    TextContent,
    Plan,
}
372
impl TryFrom<message::Message> for Vec<Message> {
    type Error = message::MessageError;

    /// Converts one rig message into one or more Cohere messages.
    ///
    /// - User messages: every content part becomes its own Cohere message. Text parts become
    ///   `Message::User` and tool results become `Message::Tool`; other parts are rejected.
    /// - System messages map one-to-one onto `Message::System`.
    /// - Assistant messages become a single `Message::Assistant`. Text and reasoning go into
    ///   `content`, tool calls go into `tool_calls`, and images are rejected.
    ///
    /// # Errors
    ///
    /// Returns `MessageError::ConversionError` for non-text user parts, non-text tool-result
    /// parts and assistant images.
    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
        Ok(match message {
            message::Message::User { content } => content
                .into_iter()
                .map(|content| match content {
                    message::UserContent::Text(message::Text { text, .. }) => Ok(Message::User {
                        content: OneOrMany::one(UserContent::Text { text }),
                    }),
                    // The rig tool-result `id` becomes Cohere's `tool_call_id`.
                    message::UserContent::ToolResult(message::ToolResult {
                        id, content, ..
                    }) => Ok(Message::Tool {
                        tool_call_id: id,
                        content: content.try_map(|content| match content {
                            message::ToolResultContent::Text(text) => {
                                Ok(ToolResultContent::Text { text: text.text })
                            }
                            _ => Err(message::MessageError::ConversionError(
                                "Only text tool result content is supported by Cohere".to_owned(),
                            )),
                        })?,
                    }),
                    _ => Err(message::MessageError::ConversionError(
                        "Only text content is supported by Cohere".to_owned(),
                    )),
                })
                // Stops at the first part that fails to convert.
                .collect::<Result<Vec<_>, _>>()?,
            message::Message::System { content } => {
                vec![Message::System { content }]
            }
            message::Message::Assistant { content, .. } => {
                // Holds both text and thinking parts, despite the name.
                let mut text_content = vec![];
                let mut tool_calls = vec![];

                for content in content.into_iter() {
                    match content {
                        message::AssistantContent::Text(message::Text { text, .. }) => {
                            text_content.push(AssistantContent::Text { text });
                        }
                        message::AssistantContent::ToolCall(message::ToolCall {
                            id,
                            function:
                                message::ToolFunction {
                                    name, arguments, ..
                                },
                            ..
                        }) => {
                            tool_calls.push(ToolCall {
                                id: Some(id),
                                r#type: Some(ToolType::Function),
                                function: Some(ToolCallFunction {
                                    name,
                                    // Falls back to `Value::Null` (the `Default` of
                                    // `serde_json::Value`) if serialization fails.
                                    arguments: serde_json::to_value(arguments).unwrap_or_default(),
                                }),
                            });
                        }
                        message::AssistantContent::Reasoning(reasoning) => {
                            // Reasoning is flattened to its display text and sent as a
                            // `thinking` part.
                            let thinking = reasoning.display_text();
                            text_content.push(AssistantContent::Thinking { thinking });
                        }
                        message::AssistantContent::Image(_) => {
                            return Err(message::MessageError::ConversionError(
                                "Cohere currently doesn't support images.".to_owned(),
                            ));
                        }
                    }
                }

                // Citations and tool plans are not reconstructed from rig messages.
                vec![Message::Assistant {
                    content: text_content,
                    citations: vec![],
                    tool_calls,
                    tool_plan: None,
                }]
            }
        })
    }
}
453
454impl TryFrom<Message> for message::Message {
455    type Error = message::MessageError;
456
457    fn try_from(message: Message) -> Result<Self, Self::Error> {
458        match message {
459            Message::User { content } => Ok(message::Message::User {
460                content: content.map(|content| match content {
461                    UserContent::Text { text } => {
462                        message::UserContent::Text(message::Text::new(text))
463                    }
464                    UserContent::ImageUrl { image_url } => {
465                        message::UserContent::image_url(image_url.url, None, None)
466                    }
467                }),
468            }),
469            Message::Assistant {
470                content,
471                tool_calls,
472                ..
473            } => {
474                let mut content = content
475                    .into_iter()
476                    .map(|content| match content {
477                        AssistantContent::Text { text } => message::AssistantContent::text(text),
478                        AssistantContent::Thinking { thinking } => {
479                            message::AssistantContent::Reasoning(Reasoning::new(&thinking))
480                        }
481                    })
482                    .collect::<Vec<_>>();
483
484                content.extend(tool_calls.into_iter().filter_map(|tool_call| {
485                    let ToolCallFunction { name, arguments } = tool_call.function?;
486
487                    Some(message::AssistantContent::tool_call(
488                        tool_call.id.unwrap_or_else(|| name.clone()),
489                        name,
490                        arguments,
491                    ))
492                }));
493
494                let content = OneOrMany::many(content).map_err(|_| {
495                    message::MessageError::ConversionError(
496                        "Expected either text content or tool calls".to_string(),
497                    )
498                })?;
499
500                Ok(message::Message::Assistant { id: None, content })
501            }
502            Message::Tool {
503                content,
504                tool_call_id,
505            } => {
506                let content = content.try_map(|content| {
507                    Ok(match content {
508                        ToolResultContent::Text { text } => message::ToolResultContent::text(text),
509                        ToolResultContent::Document { document } => {
510                            message::ToolResultContent::text(
511                                serde_json::to_string(&document.data).map_err(|e| {
512                                    message::MessageError::ConversionError(
513                                        format!("Failed to convert tool result document content into text: {e}"),
514                                    )
515                                })?,
516                            )
517                        }
518                    })
519                })?;
520
521                Ok(message::Message::User {
522                    content: OneOrMany::one(message::UserContent::tool_result(
523                        tool_call_id,
524                        content,
525                    )),
526                })
527            }
528            Message::System { content } => Ok(message::Message::user(content)),
529        }
530    }
531}
532
/// Cohere chat completion model: a provider [`Client`] plus the model name used for requests.
///
/// `T` is the underlying HTTP client type, defaulting to `reqwest::Client`.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    pub(crate) client: Client<T>,
    pub model: String,
}
538
/// JSON body for `POST /v2/chat`, built from a rig [`CompletionRequest`] via `TryFrom`.
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct CohereCompletionRequest {
    model: String,
    pub messages: Vec<Message>,
    /// rig documents, passed through unchanged and always serialized (even when empty).
    ///
    /// NOTE(review): the same documents are also folded into `messages` through
    /// `normalized_documents()` when the request is built — confirm that sending them twice
    /// is intended.
    documents: Vec<crate::completion::Document>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    /// Omitted from the body when no tools are defined.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<Tool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<ToolChoice>,
    /// Extra provider parameters, flattened into the top level of the body.
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub additional_params: Option<serde_json::Value>,
}
553
554impl TryFrom<(&str, CompletionRequest)> for CohereCompletionRequest {
555    type Error = CompletionError;
556
557    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
558        if req.output_schema.is_some() {
559            tracing::warn!("Structured outputs currently not supported for Cohere");
560        }
561
562        let model = req.model.clone().unwrap_or_else(|| model.to_string());
563        let mut partial_history = vec![];
564        if let Some(docs) = req.normalized_documents() {
565            partial_history.push(docs);
566        }
567        partial_history.extend(req.chat_history);
568
569        let mut full_history: Vec<Message> = req.preamble.map_or_else(Vec::new, |preamble| {
570            vec![Message::System { content: preamble }]
571        });
572
573        full_history.extend(
574            partial_history
575                .into_iter()
576                .map(message::Message::try_into)
577                .collect::<Result<Vec<Vec<Message>>, _>>()?
578                .into_iter()
579                .flatten()
580                .collect::<Vec<_>>(),
581        );
582
583        let tool_choice = if let Some(tool_choice) = req.tool_choice {
584            if !matches!(tool_choice, ToolChoice::Auto) {
585                Some(tool_choice)
586            } else {
587                return Err(CompletionError::RequestError(
588                    "\"auto\" is not an allowed tool_choice value in the Cohere API".into(),
589                ));
590            }
591        } else {
592            None
593        };
594
595        Ok(Self {
596            model: model.to_string(),
597            messages: full_history,
598            documents: req.documents,
599            temperature: req.temperature,
600            tools: req.tools.into_iter().map(Tool::from).collect::<Vec<_>>(),
601            tool_choice,
602            additional_params: req.additional_params,
603        })
604    }
605}
606
607impl<T> CompletionModel<T>
608where
609    T: HttpClientExt,
610{
611    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
612        Self {
613            client,
614            model: model.into(),
615        }
616    }
617}
618
619impl<T> completion::CompletionModel for CompletionModel<T>
620where
621    T: HttpClientExt + Clone + 'static,
622{
623    type Response = CompletionResponse;
624    type StreamingResponse = StreamingCompletionResponse;
625    type Client = Client<T>;
626
627    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
628        Self::new(client.clone(), model.into())
629    }
630
631    async fn completion(
632        &self,
633        completion_request: completion::CompletionRequest,
634    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
635        let request = CohereCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
636
637        let llm_span = if tracing::Span::current().is_disabled() {
638            info_span!(
639            target: "rig::completions",
640            "chat",
641            gen_ai.operation.name = "chat",
642            gen_ai.provider.name = "cohere",
643            gen_ai.request.model = self.model,
644            gen_ai.response.id = tracing::field::Empty,
645            gen_ai.response.model = self.model,
646            gen_ai.usage.output_tokens = tracing::field::Empty,
647            gen_ai.usage.input_tokens = tracing::field::Empty,
648            gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
649            )
650        } else {
651            tracing::Span::current()
652        };
653
654        if enabled!(Level::TRACE) {
655            tracing::trace!(
656                "Cohere completion request: {}",
657                serde_json::to_string_pretty(&request)?
658            );
659        }
660
661        let req_body = serde_json::to_vec(&request)?;
662
663        let req = self
664            .client
665            .post("/v2/chat")?
666            .body(req_body)
667            .map_err(|e| CompletionError::HttpError(e.into()))?;
668
669        async {
670            let response = self
671                .client
672                .send::<_, bytes::Bytes>(req)
673                .await
674                .map_err(|e| http_client::Error::Instance(e.into()))?;
675
676            let status = response.status();
677            let body = response.into_body().into_future().await?.to_owned();
678
679            if status.is_success() {
680                let json_response: CompletionResponse = serde_json::from_slice(&body)?;
681                let span = tracing::Span::current();
682                span.record_token_usage(&json_response.usage);
683                span.record_response_metadata(&json_response);
684
685                if enabled!(Level::TRACE) {
686                    tracing::trace!(
687                        target: "rig::completions",
688                        "Cohere completion response: {}",
689                        serde_json::to_string_pretty(&json_response)?
690                    );
691                }
692
693                let completion: completion::CompletionResponse<CompletionResponse> =
694                    json_response.try_into()?;
695                Ok(completion)
696            } else {
697                Err(CompletionError::ProviderError(
698                    String::from_utf8_lossy(&body).to_string(),
699                ))
700            }
701        }
702        .instrument(llm_span)
703        .await
704    }
705
706    async fn stream(
707        &self,
708        request: CompletionRequest,
709    ) -> Result<
710        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
711        CompletionError,
712    > {
713        CompletionModel::stream(self, request).await
714    }
715}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_path_to_error::deserialize;

    #[test]
    fn test_deserialize_completion_response() {
        let json_data = r#"
        {
            "id": "abc123",
            "message": {
                "role": "assistant",
                "tool_plan": "I will use the subtract tool to find the difference between 2 and 5.",
                "tool_calls": [
                        {
                            "id": "subtract_sm6ps6fb6y9f",
                            "type": "function",
                            "function": {
                                "name": "subtract",
                                "arguments": "{\"x\":5,\"y\":2}"
                            }
                        }
                    ]
                },
                "finish_reason": "TOOL_CALL",
                "usage": {
                "billed_units": {
                    "input_tokens": 78,
                    "output_tokens": 27
                },
                "tokens": {
                    "input_tokens": 1028,
                    "output_tokens": 63
                }
            }
        }
        "#;

        let mut deserializer = serde_json::Deserializer::from_str(json_data);
        let result: Result<CompletionResponse, _> = deserialize(&mut deserializer);

        let response = result.unwrap();
        let (_, citations, tool_calls) = response.message().expect("assistant message");
        let CompletionResponse {
            id,
            finish_reason,
            usage,
            ..
        } = response;

        assert_eq!(id, "abc123");
        assert_eq!(finish_reason, FinishReason::ToolCall);

        let Usage {
            billed_units,
            tokens,
        } = usage.unwrap();
        let BilledUnits {
            input_tokens: billed_input_tokens,
            output_tokens: billed_output_tokens,
            ..
        } = billed_units.unwrap();
        let Tokens {
            input_tokens,
            output_tokens,
        } = tokens.unwrap();

        assert_eq!(billed_input_tokens.unwrap(), 78.0);
        assert_eq!(billed_output_tokens.unwrap(), 27.0);
        assert_eq!(input_tokens.unwrap(), 1028.0);
        assert_eq!(output_tokens.unwrap(), 63.0);

        assert!(citations.is_empty());
        assert_eq!(tool_calls.len(), 1);

        let ToolCallFunction { name, arguments } = tool_calls[0].function.clone().unwrap();

        assert_eq!(name, "subtract");
        assert_eq!(arguments, serde_json::json!({"x": 5, "y": 2}));
    }

    #[test]
    fn test_convert_completion_message_to_message_and_back() {
        let completion_message = completion::Message::User {
            content: OneOrMany::one(completion::message::UserContent::Text(
                completion::message::Text::new("Hello, world!".to_string()),
            )),
        };

        let messages: Vec<Message> = completion_message.try_into().unwrap();
        // A single text part maps to exactly one Cohere user message.
        assert_eq!(
            messages,
            vec![Message::User {
                content: OneOrMany::one(UserContent::Text {
                    text: "Hello, world!".to_string(),
                }),
            }]
        );

        let mut converted_back: Vec<completion::Message> = messages
            .into_iter()
            .map(|msg| msg.try_into().unwrap())
            .collect::<Vec<_>>();
        assert_eq!(converted_back.len(), 1);

        let Some(completion::Message::User { content }) = converted_back.pop() else {
            panic!("expected the round trip to yield a user message");
        };
        let parts = content.into_iter().collect::<Vec<_>>();
        match parts.as_slice() {
            [completion::message::UserContent::Text(text)] => {
                assert_eq!(text.text, "Hello, world!");
            }
            _ => panic!("expected a single text part after the round trip"),
        }
    }

    #[test]
    fn test_convert_message_to_completion_message_and_back() {
        let message = Message::User {
            content: OneOrMany::one(UserContent::Text {
                text: "Hello, world!".to_string(),
            }),
        };

        let completion_message: completion::Message = message.clone().try_into().unwrap();
        let converted_back: Vec<Message> = completion_message.try_into().unwrap();

        // The round trip must reproduce the original Cohere message exactly.
        assert_eq!(converted_back, vec![message]);
    }
}