Skip to main content

rig_core/providers/mistral/
completion.rs

1use serde::{Deserialize, Serialize};
2use std::{convert::Infallible, str::FromStr};
3use tracing::{Instrument, Level, enabled, info_span};
4
5use super::client::{Client, Usage};
6use crate::completion::GetTokenUsage;
7use crate::http_client::{self, HttpClientExt};
8use crate::providers::internal::buffered;
9use crate::streaming::{RawStreamingChoice, RawStreamingToolCall, StreamingCompletionResponse};
10use crate::{
11    OneOrMany,
12    completion::{self, CompletionError, CompletionRequest},
13    json_utils, message,
14    providers::mistral::client::ApiResponse,
15    telemetry::SpanCombinator,
16};
17
/// The latest version of the `codestral` Mistral model
pub const CODESTRAL: &str = "codestral-latest";
/// The latest version of the `mistral-large` Mistral model
pub const MISTRAL_LARGE: &str = "mistral-large-latest";
/// The latest version of the `pixtral-large` Mistral multimodal model
pub const PIXTRAL_LARGE: &str = "pixtral-large-latest";
/// The latest version of the `mistral-saba` Mistral model, trained on datasets from the Middle East & South Asia
pub const MISTRAL_SABA: &str = "mistral-saba-latest";
/// The latest version of the `ministral-3b` Mistral completions model
pub const MINISTRAL_3B: &str = "ministral-3b-latest";
/// The latest version of the `ministral-8b` Mistral completions model
pub const MINISTRAL_8B: &str = "ministral-8b-latest";

/// The latest version of the `mistral-small` Mistral completions model
pub const MISTRAL_SMALL: &str = "mistral-small-latest";
/// The `24-09` version of the `pixtral-12b` Mistral multimodal model
pub const PIXTRAL_SMALL: &str = "pixtral-12b-2409";
/// The `open-mistral-nemo` model
pub const MISTRAL_NEMO: &str = "open-mistral-nemo";
/// The `open-codestral-mamba` model
pub const CODESTRAL_MAMBA: &str = "open-codestral-mamba";
39
40// =================================================================
41// Rig Implementation Types
42// =================================================================
43
/// Text-only assistant content.
///
/// NOTE(review): `#[serde(tag = "type")]` on a struct makes serde emit the
/// struct's own name as the tag value, and `rename_all` here applies to field
/// names only — confirm this is the intended wire shape.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub struct AssistantContent {
    /// The assistant's text.
    text: String,
}
49
/// User content, internally tagged by a lowercase `type` field.
///
/// Only a text variant is defined.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum UserContent {
    /// A plain-text part (`"type": "text"`).
    Text { text: String },
}
55
/// One entry of the `choices` array in a [`CompletionResponse`].
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Choice {
    /// Position of this choice in the `choices` array.
    pub index: usize,
    /// The message carried by this choice.
    pub message: Message,
    /// Log-probability payload, kept as untyped JSON; may be absent.
    pub logprobs: Option<serde_json::Value>,
    /// Provider-reported finish reason (e.g. `"stop"` in the test fixture).
    pub finish_reason: String,
}
63
/// A chat message as used in [`MistralCompletionRequest`] and [`Choice`].
///
/// Internally tagged by a lowercase `role` field (`user`, `assistant`,
/// `system`, `tool`).
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    /// A message authored by the user.
    User {
        content: String,
    },
    /// A message authored by the assistant, optionally carrying tool calls.
    Assistant {
        content: String,
        /// Tool calls requested by the assistant. Deserialized through
        /// `json_utils::null_or_vec`, defaults to empty when the key is
        /// missing, and is omitted from serialized output when empty.
        #[serde(
            default,
            deserialize_with = "json_utils::null_or_vec",
            skip_serializing_if = "Vec::is_empty"
        )]
        tool_calls: Vec<ToolCall>,
        /// Defaults to `false` when the key is missing.
        #[serde(default)]
        prefix: bool,
    },
    /// A system message.
    System {
        content: String,
    },
    /// The result of a tool call.
    Tool {
        /// The name of the tool that was called
        name: String,
        /// The content of the tool call
        content: String,
        /// The id of the tool call
        tool_call_id: String,
    },
}
93
94impl Message {
95    pub fn user(content: String) -> Self {
96        Message::User { content }
97    }
98
99    pub fn assistant(content: String, tool_calls: Vec<ToolCall>, prefix: bool) -> Self {
100        Message::Assistant {
101            content,
102            tool_calls,
103            prefix,
104        }
105    }
106
107    pub fn system(content: String) -> Self {
108        Message::System { content }
109    }
110}
111
112impl TryFrom<message::Message> for Vec<Message> {
113    type Error = message::MessageError;
114
115    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
116        match message {
117            message::Message::System { content } => Ok(vec![Message::System { content }]),
118            message::Message::User { content } => {
119                let mut tool_result_messages = Vec::new();
120                let mut other_messages = Vec::new();
121
122                for content_item in content {
123                    match content_item {
124                        message::UserContent::ToolResult(message::ToolResult {
125                            id,
126                            call_id,
127                            content: tool_content,
128                        }) => {
129                            let call_id_key = call_id.unwrap_or_else(|| id.clone());
130                            let content_text = tool_content
131                                .into_iter()
132                                .find_map(|content_item| match content_item {
133                                    message::ToolResultContent::Text(text) => Some(text.text),
134                                    message::ToolResultContent::Image(_) => None,
135                                })
136                                .unwrap_or_default();
137                            tool_result_messages.push(Message::Tool {
138                                name: id,
139                                content: content_text,
140                                tool_call_id: call_id_key,
141                            });
142                        }
143                        message::UserContent::Text(message::Text { text, .. }) => {
144                            other_messages.push(Message::User { content: text });
145                        }
146                        _ => {}
147                    }
148                }
149
150                tool_result_messages.append(&mut other_messages);
151                Ok(tool_result_messages)
152            }
153            message::Message::Assistant { content, .. } => {
154                let mut text_content = Vec::new();
155                let mut tool_calls = Vec::new();
156
157                for content in content {
158                    match content {
159                        message::AssistantContent::Text(text) => text_content.push(text),
160                        message::AssistantContent::ToolCall(tool_call) => {
161                            tool_calls.push(tool_call)
162                        }
163                        message::AssistantContent::Reasoning(_) => {
164                            // Mistral conversion path currently does not support assistant-history
165                            // reasoning items. Silently skip to avoid crashing the process.
166                        }
167                        message::AssistantContent::Image(_) => {
168                            return Err(message::MessageError::ConversionError(
169                                "Mistral assistant messages do not support image content".into(),
170                            ));
171                        }
172                    }
173                }
174
175                if text_content.is_empty() && tool_calls.is_empty() {
176                    return Ok(vec![]);
177                }
178
179                Ok(vec![Message::Assistant {
180                    content: text_content
181                        .into_iter()
182                        .next()
183                        .map(|content| content.text)
184                        .unwrap_or_default(),
185                    tool_calls: tool_calls
186                        .into_iter()
187                        .map(|tool_call| tool_call.into())
188                        .collect::<Vec<_>>(),
189                    prefix: false,
190                }])
191            }
192        }
193    }
194}
195
/// A tool call attached to an assistant [`Message`].
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolCall {
    /// Identifier of this call.
    pub id: String,
    /// Kind of tool; falls back to [`ToolType::Function`] when missing.
    #[serde(default)]
    pub r#type: ToolType,
    /// The function being invoked and its arguments.
    pub function: Function,
}
203
204impl From<message::ToolCall> for ToolCall {
205    fn from(tool_call: message::ToolCall) -> Self {
206        Self {
207            id: tool_call.id,
208            r#type: ToolType::default(),
209            function: Function {
210                name: tool_call.function.name,
211                arguments: tool_call.function.arguments,
212            },
213        }
214    }
215}
216
/// The function named by a [`ToolCall`], together with its arguments.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct Function {
    /// Name of the function to call.
    pub name: String,
    /// Call arguments; (de)serialized through `json_utils::stringified_json`.
    #[serde(with = "json_utils::stringified_json")]
    pub arguments: serde_json::Value,
}
223
/// Kind of tool referenced by a [`ToolCall`]; serialized in lowercase.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    /// A function tool (`"function"`); the default and only variant.
    #[default]
    Function,
}
230
/// Request-side wrapper pairing a generic tool definition with a `type` string.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ToolDefinition {
    /// Tool kind; set to `"function"` by the `From<completion::ToolDefinition>` impl.
    pub r#type: String,
    /// The wrapped generic tool definition.
    pub function: completion::ToolDefinition,
}
236
237impl From<completion::ToolDefinition> for ToolDefinition {
238    fn from(tool: completion::ToolDefinition) -> Self {
239        Self {
240            r#type: "function".into(),
241            function: tool,
242        }
243    }
244}
245
/// Text content of a tool result.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolResultContent {
    /// Content kind; falls back to [`ToolResultContentType::Text`] when missing.
    #[serde(default)]
    r#type: ToolResultContentType,
    /// The result text.
    text: String,
}
252
/// Kind of a [`ToolResultContent`]; serialized in lowercase.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolResultContentType {
    /// Plain text (`"text"`); the default and only variant.
    #[default]
    Text,
}
259
260impl From<String> for ToolResultContent {
261    fn from(s: String) -> Self {
262        ToolResultContent {
263            r#type: ToolResultContentType::default(),
264            text: s,
265        }
266    }
267}
268
269impl From<String> for UserContent {
270    fn from(s: String) -> Self {
271        UserContent::Text { text: s }
272    }
273}
274
275impl FromStr for UserContent {
276    type Err = Infallible;
277
278    fn from_str(s: &str) -> Result<Self, Self::Err> {
279        Ok(UserContent::Text {
280            text: s.to_string(),
281        })
282    }
283}
284
285impl From<String> for AssistantContent {
286    fn from(s: String) -> Self {
287        AssistantContent { text: s }
288    }
289}
290
291impl FromStr for AssistantContent {
292    type Err = Infallible;
293
294    fn from_str(s: &str) -> Result<Self, Self::Err> {
295        Ok(AssistantContent {
296            text: s.to_string(),
297        })
298    }
299}
300
/// A completion model handle: a provider [`Client`] paired with a model name.
///
/// `T` is the underlying HTTP client type and defaults to `reqwest::Client`.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    /// Client used to build and send requests.
    pub(crate) client: Client<T>,
    /// Model name used when a request does not specify its own.
    pub model: String,
}
306
/// Tool-choice modes: `Auto` (default), `None`, and `Any`.
///
/// NOTE(review): `MistralCompletionRequest` in this file stores the OpenAI
/// provider's `ToolChoice` rather than this enum, and no `rename_all` is set
/// here, so variants serialize with their capitalized names — confirm whether
/// this type is still needed.
#[derive(Debug, Default, Serialize, Deserialize)]
pub enum ToolChoice {
    #[default]
    Auto,
    None,
    Any,
}
314
315impl TryFrom<message::ToolChoice> for ToolChoice {
316    type Error = CompletionError;
317
318    fn try_from(value: message::ToolChoice) -> Result<Self, Self::Error> {
319        let res = match value {
320            message::ToolChoice::Auto => Self::Auto,
321            message::ToolChoice::None => Self::None,
322            message::ToolChoice::Required => Self::Any,
323            message::ToolChoice::Specific { .. } => {
324                return Err(CompletionError::ProviderError(
325                    "Mistral doesn't support requiring specific tools to be called".to_string(),
326                ));
327            }
328        };
329
330        Ok(res)
331    }
332}
333
/// JSON body sent to the chat completions endpoint.
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct MistralCompletionRequest {
    /// Name of the model to run.
    model: String,
    /// Full conversation, in order.
    pub messages: Vec<Message>,
    /// Sampling temperature; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    /// Tools offered to the model; omitted when empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<ToolDefinition>,
    /// Tool-choice setting (the OpenAI provider's type is reused); omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<crate::providers::openai::completion::ToolChoice>,
    /// Extra caller-supplied parameters, flattened into the top-level object.
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub additional_params: Option<serde_json::Value>,
}
347
348impl TryFrom<(&str, CompletionRequest)> for MistralCompletionRequest {
349    type Error = CompletionError;
350
351    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
352        if req.output_schema.is_some() {
353            tracing::warn!("Structured outputs currently not supported for Mistral");
354        }
355        let model = req.model.clone().unwrap_or_else(|| model.to_string());
356        let mut full_history: Vec<Message> = match &req.preamble {
357            Some(preamble) => vec![Message::system(preamble.clone())],
358            None => vec![],
359        };
360        if let Some(docs) = req.normalized_documents() {
361            let docs: Vec<Message> = docs.try_into()?;
362            full_history.extend(docs);
363        }
364
365        let chat_history: Vec<Message> = req
366            .chat_history
367            .clone()
368            .into_iter()
369            .map(|message| message.try_into())
370            .collect::<Result<Vec<Vec<Message>>, _>>()?
371            .into_iter()
372            .flatten()
373            .collect();
374
375        full_history.extend(chat_history);
376
377        if full_history.is_empty() {
378            return Err(CompletionError::RequestError(
379                std::io::Error::new(
380                    std::io::ErrorKind::InvalidInput,
381                    "Mistral request has no provider-compatible messages after conversion",
382                )
383                .into(),
384            ));
385        }
386
387        let tool_choice = req
388            .tool_choice
389            .clone()
390            .map(crate::providers::openai::completion::ToolChoice::try_from)
391            .transpose()?;
392
393        Ok(Self {
394            model: model.to_string(),
395            messages: full_history,
396            temperature: req.temperature,
397            tools: req
398                .tools
399                .clone()
400                .into_iter()
401                .map(ToolDefinition::from)
402                .collect::<Vec<_>>(),
403            tool_choice,
404            additional_params: req.additional_params,
405        })
406    }
407}
408
409impl<T> CompletionModel<T> {
410    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
411        Self {
412            client,
413            model: model.into(),
414        }
415    }
416
417    pub fn with_model(client: Client<T>, model: &str) -> Self {
418        Self {
419            client,
420            model: model.into(),
421        }
422    }
423}
424
/// Deserialized body of a successful chat completion response.
#[derive(Debug, Deserialize, Clone, Serialize)]
pub struct CompletionResponse {
    /// Response identifier.
    pub id: String,
    /// Object kind (`"chat.completion"` in the test fixture).
    pub object: String,
    /// Creation timestamp as reported by the provider.
    pub created: u64,
    /// Name of the model that produced the response.
    pub model: String,
    /// System fingerprint, when the provider returns one.
    pub system_fingerprint: Option<String>,
    /// Generated choices.
    pub choices: Vec<Choice>,
    /// Token accounting, when the provider returns it.
    pub usage: Option<Usage>,
}
435
436impl crate::telemetry::ProviderResponseExt for CompletionResponse {
437    type OutputMessage = Choice;
438    type Usage = Usage;
439
440    fn get_response_id(&self) -> Option<String> {
441        Some(self.id.clone())
442    }
443
444    fn get_response_model_name(&self) -> Option<String> {
445        Some(self.model.clone())
446    }
447
448    fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
449        self.choices.clone()
450    }
451
452    fn get_text_response(&self) -> Option<String> {
453        let res = self
454            .choices
455            .iter()
456            .filter_map(|choice| match choice.message {
457                Message::Assistant { ref content, .. } => {
458                    if content.is_empty() {
459                        None
460                    } else {
461                        Some(content.to_string())
462                    }
463                }
464                _ => None,
465            })
466            .collect::<Vec<String>>()
467            .join("\n");
468
469        if res.is_empty() { None } else { Some(res) }
470    }
471
472    fn get_usage(&self) -> Option<Self::Usage> {
473        self.usage.clone()
474    }
475}
476
477impl GetTokenUsage for CompletionResponse {
478    fn token_usage(&self) -> Option<crate::completion::Usage> {
479        let api_usage = self.usage.as_ref()?;
480
481        let mut usage = crate::completion::Usage::new();
482        usage.input_tokens = api_usage.prompt_tokens as u64;
483        usage.output_tokens = api_usage.completion_tokens as u64;
484        usage.total_tokens = api_usage.total_tokens as u64;
485        usage.cached_input_tokens = api_usage.cached_tokens();
486
487        Some(usage)
488    }
489}
490
491impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
492    type Error = CompletionError;
493
494    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
495        let choice = response.choices.first().ok_or_else(|| {
496            CompletionError::ResponseError("Response contained no choices".to_owned())
497        })?;
498        let content = match &choice.message {
499            Message::Assistant {
500                content,
501                tool_calls,
502                ..
503            } => {
504                let mut content = if content.is_empty() {
505                    vec![]
506                } else {
507                    vec![completion::AssistantContent::text(content.clone())]
508                };
509
510                content.extend(
511                    tool_calls
512                        .iter()
513                        .map(|call| {
514                            completion::AssistantContent::tool_call(
515                                &call.id,
516                                &call.function.name,
517                                call.function.arguments.clone(),
518                            )
519                        })
520                        .collect::<Vec<_>>(),
521                );
522                Ok(content)
523            }
524            _ => Err(CompletionError::ResponseError(
525                "Response did not contain a valid message or tool call".into(),
526            )),
527        }?;
528
529        let choice = OneOrMany::many(content).map_err(|_| {
530            CompletionError::ResponseError(
531                "Response contained no message or tool call (empty)".to_owned(),
532            )
533        })?;
534
535        let usage = response
536            .usage
537            .as_ref()
538            .map(|usage| completion::Usage {
539                input_tokens: usage.prompt_tokens as u64,
540                output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
541                total_tokens: usage.total_tokens as u64,
542                cached_input_tokens: usage.cached_tokens(),
543                cache_creation_input_tokens: 0,
544                tool_use_prompt_tokens: 0,
545                reasoning_tokens: 0,
546            })
547            .unwrap_or_default();
548
549        Ok(completion::CompletionResponse {
550            choice,
551            usage,
552            raw_response: response,
553            message_id: None,
554        })
555    }
556}
557
558fn assistant_content_to_streaming_choices(
559    content: message::AssistantContent,
560) -> Result<Vec<RawStreamingChoice<CompletionResponse>>, CompletionError> {
561    match content {
562        message::AssistantContent::Text(t) => Ok(vec![RawStreamingChoice::Message(t.text)]),
563        message::AssistantContent::ToolCall(tc) => Ok(vec![RawStreamingChoice::ToolCall(
564            RawStreamingToolCall::new(tc.id, tc.function.name, tc.function.arguments),
565        )]),
566        message::AssistantContent::Reasoning(_) => Ok(Vec::new()),
567        message::AssistantContent::Image(_) => Err(CompletionError::ResponseError(
568            "Image content is not supported on Mistral via Rig".into(),
569        )),
570    }
571}
572
573impl<T> completion::CompletionModel for CompletionModel<T>
574where
575    T: HttpClientExt + Send + Clone + std::fmt::Debug + 'static,
576{
577    type Response = CompletionResponse;
578    type StreamingResponse = CompletionResponse;
579
580    type Client = Client<T>;
581
582    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
583        Self::new(client.clone(), model.into())
584    }
585
586    async fn completion(
587        &self,
588        completion_request: CompletionRequest,
589    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
590        let preamble = completion_request.preamble.clone();
591        let request =
592            MistralCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
593
594        if enabled!(Level::TRACE) {
595            tracing::trace!(
596                target: "rig::completions",
597                "Mistral completion request: {}",
598                serde_json::to_string_pretty(&request)?
599            );
600        }
601
602        let span = if tracing::Span::current().is_disabled() {
603            info_span!(
604                target: "rig::completions",
605                "chat",
606                gen_ai.operation.name = "chat",
607                gen_ai.provider.name = "mistral",
608                gen_ai.request.model = self.model,
609                gen_ai.system_instructions = &preamble,
610                gen_ai.response.id = tracing::field::Empty,
611                gen_ai.response.model = tracing::field::Empty,
612                gen_ai.usage.output_tokens = tracing::field::Empty,
613                gen_ai.usage.input_tokens = tracing::field::Empty,
614                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
615            )
616        } else {
617            tracing::Span::current()
618        };
619
620        let body = serde_json::to_vec(&request)?;
621
622        let request = self
623            .client
624            .post("v1/chat/completions")?
625            .body(body)
626            .map_err(|e| CompletionError::HttpError(e.into()))?;
627
628        async move {
629            let response = self.client.send(request).await?;
630
631            if response.status().is_success() {
632                let text = http_client::text(response).await?;
633                match serde_json::from_str::<ApiResponse<CompletionResponse>>(&text)? {
634                    ApiResponse::Ok(response) => {
635                        let span = tracing::Span::current();
636                        span.record_token_usage(&response);
637                        span.record_response_metadata(&response);
638                        response.try_into()
639                    }
640                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
641                }
642            } else {
643                let text = http_client::text(response).await?;
644                Err(CompletionError::ProviderError(text))
645            }
646        }
647        .instrument(span)
648        .await
649    }
650
651    async fn stream(
652        &self,
653        request: CompletionRequest,
654    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
655        let resp = self.completion(request).await?;
656        buffered::stream_from_completion_response(resp, assistant_content_to_streaming_choices)
657    }
658}
659
/// Unit tests for response deserialization, usage accounting and message
/// conversion.
#[cfg(test)]
mod tests {
    use super::*;

    /// A documented-style response body deserializes and exposes its top-level
    /// and usage fields.
    #[test]
    fn test_response_deserialization() {
        //https://docs.mistral.ai/api/#tag/chat/operation/chat_completion_v1_chat_completions_post
        let json_data = r#"
        {
            "id": "cmpl-e5cc70bb28c444948073e77776eb30ef",
            "object": "chat.completion",
            "model": "mistral-small-latest",
            "usage": {
                "prompt_tokens": 16,
                "completion_tokens": 34,
                "total_tokens": 50
            },
            "created": 1702256327,
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "content": "string",
                        "tool_calls": [
                            {
                                "id": "null",
                                "type": "function",
                                "function": {
                                    "name": "string",
                                    "arguments": "{ }"
                                },
                                "index": 0
                            }
                        ],
                        "prefix": false,
                        "role": "assistant"
                    },
                    "finish_reason": "stop"
                }
            ]
        }
        "#;
        let completion_response = serde_json::from_str::<CompletionResponse>(json_data).unwrap();
        assert_eq!(completion_response.model, MISTRAL_SMALL);

        let CompletionResponse {
            id,
            object,
            created,
            choices,
            usage,
            ..
        } = completion_response;

        assert_eq!(id, "cmpl-e5cc70bb28c444948073e77776eb30ef");

        let usage = usage.unwrap();
        assert_eq!(usage.prompt_tokens, 16);
        assert_eq!(usage.completion_tokens, 34);
        assert_eq!(usage.total_tokens, 50);
        // No cache fields in the payload, so the cached count is zero.
        assert_eq!(usage.cached_tokens(), 0);
        assert!(usage.prompt_tokens_details.is_none());
        assert!(usage.num_cached_tokens.is_none());
        assert_eq!(object, "chat.completion".to_string());
        assert_eq!(created, 1702256327);
        assert_eq!(choices.len(), 1);
    }

    /// `prompt_tokens_details.cached_tokens` is parsed and surfaced by `cached_tokens()`.
    #[test]
    fn test_usage_deserializes_prompt_tokens_details_cached_tokens() {
        let json = r#"{
            "prompt_tokens": 100,
            "completion_tokens": 20,
            "total_tokens": 120,
            "prompt_tokens_details": { "cached_tokens": 42 }
        }"#;
        let usage: Usage = serde_json::from_str(json).unwrap();
        assert_eq!(usage.prompt_tokens, 100);
        assert_eq!(
            usage.prompt_tokens_details.as_ref().unwrap().cached_tokens,
            42
        );
        assert_eq!(usage.cached_tokens(), 42);
    }

    /// The singular key `prompt_token_details` is accepted as an alias.
    #[test]
    fn test_usage_accepts_singular_prompt_token_details_alias() {
        let json = r#"{
            "prompt_tokens": 100,
            "completion_tokens": 20,
            "total_tokens": 120,
            "prompt_token_details": { "cached_tokens": 7 }
        }"#;
        let usage: Usage = serde_json::from_str(json).unwrap();
        assert_eq!(
            usage.prompt_tokens_details.as_ref().unwrap().cached_tokens,
            7
        );
        assert_eq!(usage.cached_tokens(), 7);
    }

    /// Without `prompt_tokens_details`, `cached_tokens()` uses `num_cached_tokens`.
    #[test]
    fn test_usage_falls_back_to_num_cached_tokens() {
        let json = r#"{
            "prompt_tokens": 100,
            "completion_tokens": 20,
            "total_tokens": 120,
            "num_cached_tokens": 13
        }"#;
        let usage: Usage = serde_json::from_str(json).unwrap();
        assert_eq!(usage.num_cached_tokens, Some(13));
        assert!(usage.prompt_tokens_details.is_none());
        assert_eq!(usage.cached_tokens(), 13);
    }

    /// When both cache fields are present, `prompt_tokens_details` wins.
    #[test]
    fn test_usage_prefers_prompt_tokens_details_over_num_cached_tokens() {
        let json = r#"{
            "prompt_tokens": 100,
            "completion_tokens": 20,
            "total_tokens": 120,
            "num_cached_tokens": 1,
            "prompt_tokens_details": { "cached_tokens": 99 }
        }"#;
        let usage: Usage = serde_json::from_str(json).unwrap();
        assert_eq!(usage.cached_tokens(), 99);
    }

    /// `token_usage()` carries prompt, completion, total and cached counts
    /// into the crate-wide usage type.
    #[test]
    fn test_token_usage_threads_cached_tokens_into_completion_usage() {
        let json = r#"{
            "id": "cmpl-x",
            "object": "chat.completion",
            "model": "mistral-small-latest",
            "created": 1700000000,
            "choices": [{
                "index": 0,
                "message": { "content": "hi", "role": "assistant", "prefix": false },
                "finish_reason": "stop"
            }],
            "usage": {
                "prompt_tokens": 100,
                "completion_tokens": 20,
                "total_tokens": 120,
                "prompt_tokens_details": { "cached_tokens": 42 }
            }
        }"#;
        let response: CompletionResponse = serde_json::from_str(json).unwrap();
        let usage = response.token_usage().unwrap();
        assert_eq!(usage.input_tokens, 100);
        assert_eq!(usage.output_tokens, 20);
        assert_eq!(usage.total_tokens, 120);
        assert_eq!(usage.cached_input_tokens, 42);
    }

    /// An assistant message holding only reasoning converts to no provider messages.
    #[test]
    fn test_assistant_reasoning_is_skipped_in_message_conversion() {
        let assistant = message::Message::Assistant {
            id: None,
            content: OneOrMany::one(message::AssistantContent::reasoning("hidden")),
        };

        let converted: Vec<Message> = assistant.try_into().expect("conversion should work");
        assert!(converted.is_empty());
    }

    /// Reasoning is dropped while sibling text and tool calls survive conversion.
    #[test]
    fn test_assistant_text_and_tool_call_are_preserved_when_reasoning_present() {
        let assistant = message::Message::Assistant {
            id: None,
            content: OneOrMany::many(vec![
                message::AssistantContent::reasoning("hidden"),
                message::AssistantContent::text("visible"),
                message::AssistantContent::tool_call(
                    "call_1",
                    "subtract",
                    serde_json::json!({"x": 2, "y": 1}),
                ),
            ])
            .expect("non-empty assistant content"),
        };

        let converted: Vec<Message> = assistant.try_into().expect("conversion should work");
        assert_eq!(converted.len(), 1);

        match &converted[0] {
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                assert_eq!(content, "visible");
                assert_eq!(tool_calls.len(), 1);
                assert_eq!(tool_calls[0].id, "call_1");
                assert_eq!(tool_calls[0].function.name, "subtract");
                assert_eq!(
                    tool_calls[0].function.arguments,
                    serde_json::json!({"x": 2, "y": 1})
                );
            }
            _ => panic!("expected assistant message"),
        }
    }

    /// The streaming mapper yields nothing for reasoning and one choice each
    /// for text and tool calls.
    #[test]
    fn test_streaming_choice_mapping_skips_reasoning_and_preserves_other_content() {
        let reasoning_choices =
            assistant_content_to_streaming_choices(message::AssistantContent::reasoning("hidden"))
                .expect("reasoning should be ignored");
        assert!(reasoning_choices.is_empty());

        let text_choices =
            assistant_content_to_streaming_choices(message::AssistantContent::text("visible"))
                .expect("text should be preserved");
        match text_choices.as_slice() {
            [RawStreamingChoice::Message(text)] => assert_eq!(text, "visible"),
            _ => panic!("expected text streaming choice"),
        }

        let tool_choices =
            assistant_content_to_streaming_choices(message::AssistantContent::tool_call(
                "call_2",
                "add",
                serde_json::json!({"x": 2, "y": 3}),
            ))
            .expect("tool call should be preserved");
        match tool_choices.as_slice() {
            [RawStreamingChoice::ToolCall(call)] => {
                assert_eq!(call.id, "call_2");
                assert_eq!(call.name, "add");
                assert_eq!(call.arguments, serde_json::json!({"x": 2, "y": 3}));
            }
            _ => panic!("expected tool-call streaming choice"),
        }
    }

    /// Request conversion fails with `RequestError` when every message is
    /// filtered out and there is no preamble.
    #[test]
    fn test_request_conversion_errors_when_all_messages_are_filtered() {
        let request = CompletionRequest {
            preamble: None,
            chat_history: OneOrMany::one(message::Message::Assistant {
                id: None,
                content: OneOrMany::one(message::AssistantContent::reasoning("hidden")),
            }),
            documents: vec![],
            tools: vec![],
            temperature: None,
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
            model: None,
            output_schema: None,
        };

        let result = MistralCompletionRequest::try_from((MISTRAL_SMALL, request));
        assert!(matches!(result, Err(CompletionError::RequestError(_))));
    }
}