1use serde::{Deserialize, Serialize};
2use std::{convert::Infallible, str::FromStr};
3use tracing::{Instrument, Level, enabled, info_span};
4
5use super::client::{Client, Usage};
6use crate::completion::GetTokenUsage;
7use crate::http_client::{self, HttpClientExt};
8use crate::providers::internal::buffered;
9use crate::streaming::{RawStreamingChoice, RawStreamingToolCall, StreamingCompletionResponse};
10use crate::{
11 OneOrMany,
12 completion::{self, CompletionError, CompletionRequest},
13 json_utils, message,
14 providers::mistral::client::ApiResponse,
15 telemetry::SpanCombinator,
16};
17
/// Mistral model identifier `codestral-latest`.
pub const CODESTRAL: &str = "codestral-latest";
/// Mistral model identifier `mistral-large-latest`.
pub const MISTRAL_LARGE: &str = "mistral-large-latest";
/// Mistral model identifier `pixtral-large-latest`.
pub const PIXTRAL_LARGE: &str = "pixtral-large-latest";
/// Mistral model identifier `mistral-saba-latest`.
pub const MISTRAL_SABA: &str = "mistral-saba-latest";
/// Mistral model identifier `ministral-3b-latest`.
pub const MINISTRAL_3B: &str = "ministral-3b-latest";
/// Mistral model identifier `ministral-8b-latest`.
pub const MINISTRAL_8B: &str = "ministral-8b-latest";

/// Mistral model identifier `mistral-small-latest`.
pub const MISTRAL_SMALL: &str = "mistral-small-latest";
/// Mistral model identifier `pixtral-12b-2409` (a pinned, dated release name).
pub const PIXTRAL_SMALL: &str = "pixtral-12b-2409";
/// Mistral model identifier `open-mistral-nemo`.
pub const MISTRAL_NEMO: &str = "open-mistral-nemo";
/// Mistral model identifier `open-codestral-mamba`.
pub const CODESTRAL_MAMBA: &str = "open-codestral-mamba";
39
40#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
45#[serde(tag = "type", rename_all = "lowercase")]
46pub struct AssistantContent {
47 text: String,
48}
49
50#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
51#[serde(tag = "type", rename_all = "lowercase")]
52pub enum UserContent {
53 Text { text: String },
54}
55
56#[derive(Debug, Serialize, Deserialize, Clone)]
57pub struct Choice {
58 pub index: usize,
59 pub message: Message,
60 pub logprobs: Option<serde_json::Value>,
61 pub finish_reason: String,
62}
63
64#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
65#[serde(tag = "role", rename_all = "lowercase")]
66pub enum Message {
67 User {
68 content: String,
69 },
70 Assistant {
71 content: String,
72 #[serde(
73 default,
74 deserialize_with = "json_utils::null_or_vec",
75 skip_serializing_if = "Vec::is_empty"
76 )]
77 tool_calls: Vec<ToolCall>,
78 #[serde(default)]
79 prefix: bool,
80 },
81 System {
82 content: String,
83 },
84 Tool {
85 name: String,
87 content: String,
89 tool_call_id: String,
91 },
92}
93
94impl Message {
95 pub fn user(content: String) -> Self {
96 Message::User { content }
97 }
98
99 pub fn assistant(content: String, tool_calls: Vec<ToolCall>, prefix: bool) -> Self {
100 Message::Assistant {
101 content,
102 tool_calls,
103 prefix,
104 }
105 }
106
107 pub fn system(content: String) -> Self {
108 Message::System { content }
109 }
110}
111
112impl TryFrom<message::Message> for Vec<Message> {
113 type Error = message::MessageError;
114
115 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
116 match message {
117 message::Message::System { content } => Ok(vec![Message::System { content }]),
118 message::Message::User { content } => {
119 let mut tool_result_messages = Vec::new();
120 let mut other_messages = Vec::new();
121
122 for content_item in content {
123 match content_item {
124 message::UserContent::ToolResult(message::ToolResult {
125 id,
126 call_id,
127 content: tool_content,
128 }) => {
129 let call_id_key = call_id.unwrap_or_else(|| id.clone());
130 let content_text = tool_content
131 .into_iter()
132 .find_map(|content_item| match content_item {
133 message::ToolResultContent::Text(text) => Some(text.text),
134 message::ToolResultContent::Image(_) => None,
135 })
136 .unwrap_or_default();
137 tool_result_messages.push(Message::Tool {
138 name: id,
139 content: content_text,
140 tool_call_id: call_id_key,
141 });
142 }
143 message::UserContent::Text(message::Text { text, .. }) => {
144 other_messages.push(Message::User { content: text });
145 }
146 _ => {}
147 }
148 }
149
150 tool_result_messages.append(&mut other_messages);
151 Ok(tool_result_messages)
152 }
153 message::Message::Assistant { content, .. } => {
154 let mut text_content = Vec::new();
155 let mut tool_calls = Vec::new();
156
157 for content in content {
158 match content {
159 message::AssistantContent::Text(text) => text_content.push(text),
160 message::AssistantContent::ToolCall(tool_call) => {
161 tool_calls.push(tool_call)
162 }
163 message::AssistantContent::Reasoning(_) => {
164 }
167 message::AssistantContent::Image(_) => {
168 return Err(message::MessageError::ConversionError(
169 "Mistral assistant messages do not support image content".into(),
170 ));
171 }
172 }
173 }
174
175 if text_content.is_empty() && tool_calls.is_empty() {
176 return Ok(vec![]);
177 }
178
179 Ok(vec![Message::Assistant {
180 content: text_content
181 .into_iter()
182 .next()
183 .map(|content| content.text)
184 .unwrap_or_default(),
185 tool_calls: tool_calls
186 .into_iter()
187 .map(|tool_call| tool_call.into())
188 .collect::<Vec<_>>(),
189 prefix: false,
190 }])
191 }
192 }
193 }
194}
195
196#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
197pub struct ToolCall {
198 pub id: String,
199 #[serde(default)]
200 pub r#type: ToolType,
201 pub function: Function,
202}
203
204impl From<message::ToolCall> for ToolCall {
205 fn from(tool_call: message::ToolCall) -> Self {
206 Self {
207 id: tool_call.id,
208 r#type: ToolType::default(),
209 function: Function {
210 name: tool_call.function.name,
211 arguments: tool_call.function.arguments,
212 },
213 }
214 }
215}
216
217#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
218pub struct Function {
219 pub name: String,
220 #[serde(with = "json_utils::stringified_json")]
221 pub arguments: serde_json::Value,
222}
223
224#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
225#[serde(rename_all = "lowercase")]
226pub enum ToolType {
227 #[default]
228 Function,
229}
230
231#[derive(Debug, Deserialize, Serialize, Clone)]
232pub struct ToolDefinition {
233 pub r#type: String,
234 pub function: completion::ToolDefinition,
235}
236
237impl From<completion::ToolDefinition> for ToolDefinition {
238 fn from(tool: completion::ToolDefinition) -> Self {
239 Self {
240 r#type: "function".into(),
241 function: tool,
242 }
243 }
244}
245
246#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
247pub struct ToolResultContent {
248 #[serde(default)]
249 r#type: ToolResultContentType,
250 text: String,
251}
252
253#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
254#[serde(rename_all = "lowercase")]
255pub enum ToolResultContentType {
256 #[default]
257 Text,
258}
259
260impl From<String> for ToolResultContent {
261 fn from(s: String) -> Self {
262 ToolResultContent {
263 r#type: ToolResultContentType::default(),
264 text: s,
265 }
266 }
267}
268
269impl From<String> for UserContent {
270 fn from(s: String) -> Self {
271 UserContent::Text { text: s }
272 }
273}
274
275impl FromStr for UserContent {
276 type Err = Infallible;
277
278 fn from_str(s: &str) -> Result<Self, Self::Err> {
279 Ok(UserContent::Text {
280 text: s.to_string(),
281 })
282 }
283}
284
285impl From<String> for AssistantContent {
286 fn from(s: String) -> Self {
287 AssistantContent { text: s }
288 }
289}
290
291impl FromStr for AssistantContent {
292 type Err = Infallible;
293
294 fn from_str(s: &str) -> Result<Self, Self::Err> {
295 Ok(AssistantContent {
296 text: s.to_string(),
297 })
298 }
299}
300
301#[derive(Clone)]
302pub struct CompletionModel<T = reqwest::Client> {
303 pub(crate) client: Client<T>,
304 pub model: String,
305}
306
307#[derive(Debug, Default, Serialize, Deserialize)]
308pub enum ToolChoice {
309 #[default]
310 Auto,
311 None,
312 Any,
313}
314
315impl TryFrom<message::ToolChoice> for ToolChoice {
316 type Error = CompletionError;
317
318 fn try_from(value: message::ToolChoice) -> Result<Self, Self::Error> {
319 let res = match value {
320 message::ToolChoice::Auto => Self::Auto,
321 message::ToolChoice::None => Self::None,
322 message::ToolChoice::Required => Self::Any,
323 message::ToolChoice::Specific { .. } => {
324 return Err(CompletionError::ProviderError(
325 "Mistral doesn't support requiring specific tools to be called".to_string(),
326 ));
327 }
328 };
329
330 Ok(res)
331 }
332}
333
334#[derive(Debug, Serialize, Deserialize)]
335pub(super) struct MistralCompletionRequest {
336 model: String,
337 pub messages: Vec<Message>,
338 #[serde(skip_serializing_if = "Option::is_none")]
339 temperature: Option<f64>,
340 #[serde(skip_serializing_if = "Vec::is_empty")]
341 tools: Vec<ToolDefinition>,
342 #[serde(skip_serializing_if = "Option::is_none")]
343 tool_choice: Option<crate::providers::openai::completion::ToolChoice>,
344 #[serde(flatten, skip_serializing_if = "Option::is_none")]
345 pub additional_params: Option<serde_json::Value>,
346}
347
348impl TryFrom<(&str, CompletionRequest)> for MistralCompletionRequest {
349 type Error = CompletionError;
350
351 fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
352 let chat_history = req.chat_history_with_documents();
353 if req.output_schema.is_some() {
354 tracing::warn!("Structured outputs currently not supported for Mistral");
355 }
356 let model = req.model.clone().unwrap_or_else(|| model.to_string());
357 let mut full_history: Vec<Message> = match &req.preamble {
358 Some(preamble) => vec![Message::system(preamble.clone())],
359 None => vec![],
360 };
361 let chat_history: Vec<Message> = chat_history
362 .into_iter()
363 .map(|message| message.try_into())
364 .collect::<Result<Vec<Vec<Message>>, _>>()?
365 .into_iter()
366 .flatten()
367 .collect();
368
369 full_history.extend(chat_history);
370
371 if full_history.is_empty() {
372 return Err(CompletionError::RequestError(
373 std::io::Error::new(
374 std::io::ErrorKind::InvalidInput,
375 "Mistral request has no provider-compatible messages after conversion",
376 )
377 .into(),
378 ));
379 }
380
381 let tool_choice = req
382 .tool_choice
383 .clone()
384 .map(crate::providers::openai::completion::ToolChoice::try_from)
385 .transpose()?;
386
387 Ok(Self {
388 model: model.to_string(),
389 messages: full_history,
390 temperature: req.temperature,
391 tools: req
392 .tools
393 .clone()
394 .into_iter()
395 .map(ToolDefinition::from)
396 .collect::<Vec<_>>(),
397 tool_choice,
398 additional_params: req.additional_params,
399 })
400 }
401}
402
403impl<T> CompletionModel<T> {
404 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
405 Self {
406 client,
407 model: model.into(),
408 }
409 }
410
411 pub fn with_model(client: Client<T>, model: &str) -> Self {
412 Self {
413 client,
414 model: model.into(),
415 }
416 }
417}
418
419#[derive(Debug, Deserialize, Clone, Serialize)]
420pub struct CompletionResponse {
421 pub id: String,
422 pub object: String,
423 pub created: u64,
424 pub model: String,
425 pub system_fingerprint: Option<String>,
426 pub choices: Vec<Choice>,
427 pub usage: Option<Usage>,
428}
429
430impl crate::telemetry::ProviderResponseExt for CompletionResponse {
431 type OutputMessage = Choice;
432 type Usage = Usage;
433
434 fn get_response_id(&self) -> Option<String> {
435 Some(self.id.clone())
436 }
437
438 fn get_response_model_name(&self) -> Option<String> {
439 Some(self.model.clone())
440 }
441
442 fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
443 self.choices.clone()
444 }
445
446 fn get_text_response(&self) -> Option<String> {
447 let res = self
448 .choices
449 .iter()
450 .filter_map(|choice| match choice.message {
451 Message::Assistant { ref content, .. } => {
452 if content.is_empty() {
453 None
454 } else {
455 Some(content.to_string())
456 }
457 }
458 _ => None,
459 })
460 .collect::<Vec<String>>()
461 .join("\n");
462
463 if res.is_empty() { None } else { Some(res) }
464 }
465
466 fn get_usage(&self) -> Option<Self::Usage> {
467 self.usage.clone()
468 }
469}
470
471impl GetTokenUsage for CompletionResponse {
472 fn token_usage(&self) -> crate::completion::Usage {
473 let Some(api_usage) = self.usage.as_ref() else {
474 return crate::completion::Usage::new();
475 };
476
477 let mut usage = crate::completion::Usage::new();
478 usage.input_tokens = api_usage.prompt_tokens as u64;
479 usage.output_tokens = api_usage.completion_tokens as u64;
480 usage.total_tokens = api_usage.total_tokens as u64;
481 usage.cached_input_tokens = api_usage.cached_tokens();
482
483 usage
484 }
485}
486
487impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
488 type Error = CompletionError;
489
490 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
491 let choice = response.choices.first().ok_or_else(|| {
492 CompletionError::ResponseError("Response contained no choices".to_owned())
493 })?;
494 let content = match &choice.message {
495 Message::Assistant {
496 content,
497 tool_calls,
498 ..
499 } => {
500 let mut content = if content.is_empty() {
501 vec![]
502 } else {
503 vec![completion::AssistantContent::text(content.clone())]
504 };
505
506 content.extend(
507 tool_calls
508 .iter()
509 .map(|call| {
510 completion::AssistantContent::tool_call(
511 &call.id,
512 &call.function.name,
513 call.function.arguments.clone(),
514 )
515 })
516 .collect::<Vec<_>>(),
517 );
518 Ok(content)
519 }
520 _ => Err(CompletionError::ResponseError(
521 "Response did not contain a valid message or tool call".into(),
522 )),
523 }?;
524
525 let choice = OneOrMany::many(content).map_err(|_| {
526 CompletionError::ResponseError(
527 "Response contained no message or tool call (empty)".to_owned(),
528 )
529 })?;
530
531 let usage = response
532 .usage
533 .as_ref()
534 .map(|usage| completion::Usage {
535 input_tokens: usage.prompt_tokens as u64,
536 output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
537 total_tokens: usage.total_tokens as u64,
538 cached_input_tokens: usage.cached_tokens(),
539 cache_creation_input_tokens: 0,
540 tool_use_prompt_tokens: 0,
541 reasoning_tokens: 0,
542 })
543 .unwrap_or_default();
544
545 Ok(completion::CompletionResponse {
546 choice,
547 usage,
548 raw_response: response,
549 message_id: None,
550 })
551 }
552}
553
554fn assistant_content_to_streaming_choices(
555 content: message::AssistantContent,
556) -> Result<Vec<RawStreamingChoice<CompletionResponse>>, CompletionError> {
557 match content {
558 message::AssistantContent::Text(t) => Ok(vec![RawStreamingChoice::Message(t.text)]),
559 message::AssistantContent::ToolCall(tc) => Ok(vec![RawStreamingChoice::ToolCall(
560 RawStreamingToolCall::new(tc.id, tc.function.name, tc.function.arguments),
561 )]),
562 message::AssistantContent::Reasoning(_) => Ok(Vec::new()),
563 message::AssistantContent::Image(_) => Err(CompletionError::ResponseError(
564 "Image content is not supported on Mistral via Rig".into(),
565 )),
566 }
567}
568
/// Implements rig's completion model trait for Mistral.
///
/// Regular completions are POSTed to `v1/chat/completions`. Streaming runs a
/// regular completion and replays its content through
/// `buffered::stream_from_completion_response`.
impl<T> completion::CompletionModel for CompletionModel<T>
where
    T: HttpClientExt + Send + Clone + std::fmt::Debug + 'static,
{
    type Response = CompletionResponse;
    type StreamingResponse = CompletionResponse;

    type Client = Client<T>;

    /// Builds a model handle from a shared client (the client is cloned).
    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model.into())
    }

    /// Sends one chat completion request to Mistral and converts the reply.
    ///
    /// # Errors
    ///
    /// - Request conversion and JSON serialization errors.
    /// - HTTP transport errors.
    /// - `CompletionError::ProviderError` for a non-success status (raw body text)
    ///   or an API error payload.
    /// - Errors from converting the response into rig's format.
    async fn completion(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
        // Captured before the request is consumed; recorded as the span's system instructions.
        let preamble = completion_request.preamble.clone();
        let request =
            MistralCompletionRequest::try_from((self.model.as_ref(), completion_request))?;

        // Only pretty-print the body when TRACE is enabled; the `?` inside the macro
        // arguments runs only if the event is actually emitted.
        if enabled!(Level::TRACE) {
            tracing::trace!(
                target: "rig::completions",
                "Mistral completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        // If the current span is disabled (e.g. there is none), open a dedicated
        // "chat" span. Otherwise record into the caller's span. Response fields
        // start empty and are filled after the reply arrives.
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "mistral",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = &preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        let body = serde_json::to_vec(&request)?;

        let request = self
            .client
            .post("v1/chat/completions")?
            .body(body)
            .map_err(|e| CompletionError::HttpError(e.into()))?;

        async move {
            let response = self.client.send(request).await?;

            if response.status().is_success() {
                let text = http_client::text(response).await?;
                // A success status can still carry an API error payload, so decode the
                // `ApiResponse` wrapper rather than `CompletionResponse` directly.
                match serde_json::from_str::<ApiResponse<CompletionResponse>>(&text)? {
                    ApiResponse::Ok(response) => {
                        let span = tracing::Span::current();
                        span.record_token_usage(&response);
                        span.record_response_metadata(&response);
                        response.try_into()
                    }
                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
                }
            } else {
                // Non-success status: surface the raw body text as the provider error.
                let text = http_client::text(response).await?;
                Err(CompletionError::ProviderError(text))
            }
        }
        .instrument(span)
        .await
    }

    /// Streaming is emulated: run a full completion, then replay its content as
    /// stream chunks via `assistant_content_to_streaming_choices`.
    async fn stream(
        &self,
        request: CompletionRequest,
    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
        let resp = self.completion(request).await?;
        buffered::stream_from_completion_response(resp, assistant_content_to_streaming_choices)
    }
}
655
// Unit tests covering:
// - response and usage deserialization (including the cached-token variants);
// - message conversion (dropping reasoning);
// - streaming-choice mapping;
// - request validation when every message is filtered out.
#[cfg(test)]
mod tests {
    use super::*;

    // Deserializes a representative chat completion payload that includes a tool
    // call, and checks top-level fields and usage. No cache details are present.
    #[test]
    fn test_response_deserialization() {
        let json_data = r#"
        {
            "id": "cmpl-e5cc70bb28c444948073e77776eb30ef",
            "object": "chat.completion",
            "model": "mistral-small-latest",
            "usage": {
                "prompt_tokens": 16,
                "completion_tokens": 34,
                "total_tokens": 50
            },
            "created": 1702256327,
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "content": "string",
                        "tool_calls": [
                            {
                                "id": "null",
                                "type": "function",
                                "function": {
                                    "name": "string",
                                    "arguments": "{ }"
                                },
                                "index": 0
                            }
                        ],
                        "prefix": false,
                        "role": "assistant"
                    },
                    "finish_reason": "stop"
                }
            ]
        }
        "#;
        let completion_response = serde_json::from_str::<CompletionResponse>(json_data).unwrap();
        assert_eq!(completion_response.model, MISTRAL_SMALL);

        let CompletionResponse {
            id,
            object,
            created,
            choices,
            usage,
            ..
        } = completion_response;

        assert_eq!(id, "cmpl-e5cc70bb28c444948073e77776eb30ef");

        let usage = usage.unwrap();
        assert_eq!(usage.prompt_tokens, 16);
        assert_eq!(usage.completion_tokens, 34);
        assert_eq!(usage.total_tokens, 50);
        assert_eq!(usage.cached_tokens(), 0);
        assert!(usage.prompt_tokens_details.is_none());
        assert!(usage.num_cached_tokens.is_none());
        assert_eq!(object, "chat.completion".to_string());
        assert_eq!(created, 1702256327);
        assert_eq!(choices.len(), 1);
    }

    // Cached tokens reported under `prompt_tokens_details.cached_tokens`.
    #[test]
    fn test_usage_deserializes_prompt_tokens_details_cached_tokens() {
        let json = r#"{
            "prompt_tokens": 100,
            "completion_tokens": 20,
            "total_tokens": 120,
            "prompt_tokens_details": { "cached_tokens": 42 }
        }"#;
        let usage: Usage = serde_json::from_str(json).unwrap();
        assert_eq!(usage.prompt_tokens, 100);
        assert_eq!(
            usage.prompt_tokens_details.as_ref().unwrap().cached_tokens,
            42
        );
        assert_eq!(usage.cached_tokens(), 42);
    }

    // The singular `prompt_token_details` key is accepted as an alias.
    #[test]
    fn test_usage_accepts_singular_prompt_token_details_alias() {
        let json = r#"{
            "prompt_tokens": 100,
            "completion_tokens": 20,
            "total_tokens": 120,
            "prompt_token_details": { "cached_tokens": 7 }
        }"#;
        let usage: Usage = serde_json::from_str(json).unwrap();
        assert_eq!(
            usage.prompt_tokens_details.as_ref().unwrap().cached_tokens,
            7
        );
        assert_eq!(usage.cached_tokens(), 7);
    }

    // Without details, `cached_tokens()` falls back to `num_cached_tokens`.
    #[test]
    fn test_usage_falls_back_to_num_cached_tokens() {
        let json = r#"{
            "prompt_tokens": 100,
            "completion_tokens": 20,
            "total_tokens": 120,
            "num_cached_tokens": 13
        }"#;
        let usage: Usage = serde_json::from_str(json).unwrap();
        assert_eq!(usage.num_cached_tokens, Some(13));
        assert!(usage.prompt_tokens_details.is_none());
        assert_eq!(usage.cached_tokens(), 13);
    }

    // When both sources are present, the details value wins.
    #[test]
    fn test_usage_prefers_prompt_tokens_details_over_num_cached_tokens() {
        let json = r#"{
            "prompt_tokens": 100,
            "completion_tokens": 20,
            "total_tokens": 120,
            "num_cached_tokens": 1,
            "prompt_tokens_details": { "cached_tokens": 99 }
        }"#;
        let usage: Usage = serde_json::from_str(json).unwrap();
        assert_eq!(usage.cached_tokens(), 99);
    }

    // `GetTokenUsage::token_usage` carries every counter, including cached tokens.
    #[test]
    fn test_token_usage_threads_cached_tokens_into_completion_usage() {
        let json = r#"{
            "id": "cmpl-x",
            "object": "chat.completion",
            "model": "mistral-small-latest",
            "created": 1700000000,
            "choices": [{
                "index": 0,
                "message": { "content": "hi", "role": "assistant", "prefix": false },
                "finish_reason": "stop"
            }],
            "usage": {
                "prompt_tokens": 100,
                "completion_tokens": 20,
                "total_tokens": 120,
                "prompt_tokens_details": { "cached_tokens": 42 }
            }
        }"#;
        let response: CompletionResponse = serde_json::from_str(json).unwrap();
        let usage = response.token_usage();
        assert_eq!(usage.input_tokens, 100);
        assert_eq!(usage.output_tokens, 20);
        assert_eq!(usage.total_tokens, 120);
        assert_eq!(usage.cached_input_tokens, 42);
    }

    // A reasoning-only assistant message converts to no Mistral messages.
    #[test]
    fn test_assistant_reasoning_is_skipped_in_message_conversion() {
        let assistant = message::Message::Assistant {
            id: None,
            content: OneOrMany::one(message::AssistantContent::reasoning("hidden")),
        };

        let converted: Vec<Message> = assistant.try_into().expect("conversion should work");
        assert!(converted.is_empty());
    }

    // Reasoning is dropped while the text and tool call survive in one assistant message.
    #[test]
    fn test_assistant_text_and_tool_call_are_preserved_when_reasoning_present() {
        let assistant = message::Message::Assistant {
            id: None,
            content: OneOrMany::many(vec![
                message::AssistantContent::reasoning("hidden"),
                message::AssistantContent::text("visible"),
                message::AssistantContent::tool_call(
                    "call_1",
                    "subtract",
                    serde_json::json!({"x": 2, "y": 1}),
                ),
            ])
            .expect("non-empty assistant content"),
        };

        let converted: Vec<Message> = assistant.try_into().expect("conversion should work");
        assert_eq!(converted.len(), 1);

        match &converted[0] {
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                assert_eq!(content, "visible");
                assert_eq!(tool_calls.len(), 1);
                assert_eq!(tool_calls[0].id, "call_1");
                assert_eq!(tool_calls[0].function.name, "subtract");
                assert_eq!(
                    tool_calls[0].function.arguments,
                    serde_json::json!({"x": 2, "y": 1})
                );
            }
            _ => panic!("expected assistant message"),
        }
    }

    // Streaming replay: reasoning yields nothing; text and tool calls map one-to-one.
    #[test]
    fn test_streaming_choice_mapping_skips_reasoning_and_preserves_other_content() {
        let reasoning_choices =
            assistant_content_to_streaming_choices(message::AssistantContent::reasoning("hidden"))
                .expect("reasoning should be ignored");
        assert!(reasoning_choices.is_empty());

        let text_choices =
            assistant_content_to_streaming_choices(message::AssistantContent::text("visible"))
                .expect("text should be preserved");
        match text_choices.as_slice() {
            [RawStreamingChoice::Message(text)] => assert_eq!(text, "visible"),
            _ => panic!("expected text streaming choice"),
        }

        let tool_choices =
            assistant_content_to_streaming_choices(message::AssistantContent::tool_call(
                "call_2",
                "add",
                serde_json::json!({"x": 2, "y": 3}),
            ))
            .expect("tool call should be preserved");
        match tool_choices.as_slice() {
            [RawStreamingChoice::ToolCall(call)] => {
                assert_eq!(call.id, "call_2");
                assert_eq!(call.name, "add");
                assert_eq!(call.arguments, serde_json::json!({"x": 2, "y": 3}));
            }
            _ => panic!("expected tool-call streaming choice"),
        }
    }

    // A request whose whole history is filtered out is rejected with `RequestError`.
    #[test]
    fn test_request_conversion_errors_when_all_messages_are_filtered() {
        let request = CompletionRequest {
            preamble: None,
            chat_history: OneOrMany::one(message::Message::Assistant {
                id: None,
                content: OneOrMany::one(message::AssistantContent::reasoning("hidden")),
            }),
            documents: vec![],
            tools: vec![],
            temperature: None,
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
            model: None,
            output_schema: None,
        };

        let result = MistralCompletionRequest::try_from((MISTRAL_SMALL, request));
        assert!(matches!(result, Err(CompletionError::RequestError(_))));
    }
}