1use serde::{Deserialize, Serialize};
2use std::{convert::Infallible, str::FromStr};
3use tracing::{Instrument, Level, enabled, info_span};
4
5use super::client::{Client, Usage};
6use crate::completion::GetTokenUsage;
7use crate::http_client::{self, HttpClientExt};
8use crate::providers::internal::buffered;
9use crate::streaming::{RawStreamingChoice, RawStreamingToolCall, StreamingCompletionResponse};
10use crate::{
11 OneOrMany,
12 completion::{self, CompletionError, CompletionRequest},
13 json_utils, message,
14 providers::mistral::client::ApiResponse,
15 telemetry::SpanCombinator,
16};
17
/// Mistral model identifier `codestral-latest`.
pub const CODESTRAL: &str = "codestral-latest";
/// Mistral model identifier `mistral-large-latest`.
pub const MISTRAL_LARGE: &str = "mistral-large-latest";
/// Mistral model identifier `pixtral-large-latest`.
pub const PIXTRAL_LARGE: &str = "pixtral-large-latest";
/// Mistral model identifier `mistral-saba-latest`.
pub const MISTRAL_SABA: &str = "mistral-saba-latest";
/// Mistral model identifier `ministral-3b-latest`.
pub const MINISTRAL_3B: &str = "ministral-3b-latest";
/// Mistral model identifier `ministral-8b-latest`.
pub const MINISTRAL_8B: &str = "ministral-8b-latest";

/// Mistral model identifier `mistral-small-latest`.
pub const MISTRAL_SMALL: &str = "mistral-small-latest";
/// Mistral model identifier `pixtral-12b-2409`.
///
/// Note: unlike most constants here, this pins a dated snapshot rather than a
/// `-latest` alias.
pub const PIXTRAL_SMALL: &str = "pixtral-12b-2409";
/// Mistral model identifier `open-mistral-nemo`.
pub const MISTRAL_NEMO: &str = "open-mistral-nemo";
/// Mistral model identifier `open-codestral-mamba`.
pub const CODESTRAL_MAMBA: &str = "open-codestral-mamba";
39
/// Plain-text assistant content, convertible from `String` / `&str`.
///
/// NOTE(review): serde's `tag = "type"` on a *struct* emits an extra `type` field
/// (holding the struct name) — confirm this is the intended wire shape.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub struct AssistantContent {
    text: String,
}
49
/// User content part, internally tagged by a lowercase `type` field
/// (e.g. `{"type": "text", "text": "..."}`).
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum UserContent {
    /// A plain-text part.
    Text { text: String },
}
55
/// One entry of the `choices` array in a Mistral chat completion response.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Choice {
    /// Position of this choice within `choices`.
    pub index: usize,
    /// The generated message; conversion to a rig response expects the `assistant` variant.
    pub message: Message,
    /// Log-probability payload, if returned; kept as untyped JSON.
    pub logprobs: Option<serde_json::Value>,
    /// Why generation stopped, as reported by the API (e.g. `"stop"`).
    pub finish_reason: String,
}
63
/// A chat message in Mistral's wire format, tagged by a lowercase `role` field
/// (`user`, `assistant`, `system`, `tool`).
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    /// A user turn with plain-text content.
    User {
        content: String,
    },
    /// An assistant turn, optionally carrying tool calls.
    Assistant {
        content: String,
        // Omitted from the serialized body when empty; `null_or_vec` presumably
        // maps a JSON `null` to an empty vec on input.
        #[serde(
            default,
            deserialize_with = "json_utils::null_or_vec",
            skip_serializing_if = "Vec::is_empty"
        )]
        tool_calls: Vec<ToolCall>,
        // Defaults to `false` when absent from the payload.
        #[serde(default)]
        prefix: bool,
    },
    /// A system instruction (used for the preamble and normalized documents).
    System {
        content: String,
    },
    /// The result of a tool invocation, linked to its call via `tool_call_id`.
    Tool {
        /// Tool name field sent to Mistral.
        name: String,
        /// Text output of the tool.
        content: String,
        /// Id of the assistant tool call this result answers.
        tool_call_id: String,
    },
}
93
94impl Message {
95 pub fn user(content: String) -> Self {
96 Message::User { content }
97 }
98
99 pub fn assistant(content: String, tool_calls: Vec<ToolCall>, prefix: bool) -> Self {
100 Message::Assistant {
101 content,
102 tool_calls,
103 prefix,
104 }
105 }
106
107 pub fn system(content: String) -> Self {
108 Message::System { content }
109 }
110}
111
112impl TryFrom<message::Message> for Vec<Message> {
113 type Error = message::MessageError;
114
115 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
116 match message {
117 message::Message::System { content } => Ok(vec![Message::System { content }]),
118 message::Message::User { content } => {
119 let mut tool_result_messages = Vec::new();
120 let mut other_messages = Vec::new();
121
122 for content_item in content {
123 match content_item {
124 message::UserContent::ToolResult(message::ToolResult {
125 id,
126 call_id,
127 content: tool_content,
128 }) => {
129 let call_id_key = call_id.unwrap_or_else(|| id.clone());
130 let content_text = tool_content
131 .into_iter()
132 .find_map(|content_item| match content_item {
133 message::ToolResultContent::Text(text) => Some(text.text),
134 message::ToolResultContent::Image(_) => None,
135 })
136 .unwrap_or_default();
137 tool_result_messages.push(Message::Tool {
138 name: id,
139 content: content_text,
140 tool_call_id: call_id_key,
141 });
142 }
143 message::UserContent::Text(message::Text { text, .. }) => {
144 other_messages.push(Message::User { content: text });
145 }
146 _ => {}
147 }
148 }
149
150 tool_result_messages.append(&mut other_messages);
151 Ok(tool_result_messages)
152 }
153 message::Message::Assistant { content, .. } => {
154 let mut text_content = Vec::new();
155 let mut tool_calls = Vec::new();
156
157 for content in content {
158 match content {
159 message::AssistantContent::Text(text) => text_content.push(text),
160 message::AssistantContent::ToolCall(tool_call) => {
161 tool_calls.push(tool_call)
162 }
163 message::AssistantContent::Reasoning(_) => {
164 }
167 message::AssistantContent::Image(_) => {
168 return Err(message::MessageError::ConversionError(
169 "Mistral assistant messages do not support image content".into(),
170 ));
171 }
172 }
173 }
174
175 if text_content.is_empty() && tool_calls.is_empty() {
176 return Ok(vec![]);
177 }
178
179 Ok(vec![Message::Assistant {
180 content: text_content
181 .into_iter()
182 .next()
183 .map(|content| content.text)
184 .unwrap_or_default(),
185 tool_calls: tool_calls
186 .into_iter()
187 .map(|tool_call| tool_call.into())
188 .collect::<Vec<_>>(),
189 prefix: false,
190 }])
191 }
192 }
193 }
194}
195
/// A tool call requested by the assistant.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolCall {
    /// Call id, echoed back in the corresponding `Tool` message's `tool_call_id`.
    pub id: String,
    /// Tool kind; defaults to `function` when absent.
    #[serde(default)]
    pub r#type: ToolType,
    /// Function name and arguments.
    pub function: Function,
}
203
204impl From<message::ToolCall> for ToolCall {
205 fn from(tool_call: message::ToolCall) -> Self {
206 Self {
207 id: tool_call.id,
208 r#type: ToolType::default(),
209 function: Function {
210 name: tool_call.function.name,
211 arguments: tool_call.function.arguments,
212 },
213 }
214 }
215}
216
/// The function part of a [`ToolCall`].
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct Function {
    /// Name of the function to invoke.
    pub name: String,
    /// Arguments as JSON. On the wire they appear as a JSON-encoded string
    /// (e.g. `"{ }"`), hence the `stringified_json` adapter.
    #[serde(with = "json_utils::stringified_json")]
    pub arguments: serde_json::Value,
}
223
/// Kind of tool call; only `function` exists (serialized lowercase).
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}
230
/// A tool advertised to Mistral in the request's `tools` array.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ToolDefinition {
    /// Tool kind; set to `"function"` by the `From<completion::ToolDefinition>` impl.
    pub r#type: String,
    /// The rig tool definition, serialized as-is.
    pub function: completion::ToolDefinition,
}
236
237impl From<completion::ToolDefinition> for ToolDefinition {
238 fn from(tool: completion::ToolDefinition) -> Self {
239 Self {
240 r#type: "function".into(),
241 function: tool,
242 }
243 }
244}
245
/// Text content of a tool result, tagged with a `type` field (always `text`).
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolResultContent {
    // Defaults to `text` when the field is missing.
    #[serde(default)]
    r#type: ToolResultContentType,
    text: String,
}

/// Content-type tag for [`ToolResultContent`]; only `text` exists (serialized lowercase).
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolResultContentType {
    #[default]
    Text,
}
259
260impl From<String> for ToolResultContent {
261 fn from(s: String) -> Self {
262 ToolResultContent {
263 r#type: ToolResultContentType::default(),
264 text: s,
265 }
266 }
267}
268
269impl From<String> for UserContent {
270 fn from(s: String) -> Self {
271 UserContent::Text { text: s }
272 }
273}
274
275impl FromStr for UserContent {
276 type Err = Infallible;
277
278 fn from_str(s: &str) -> Result<Self, Self::Err> {
279 Ok(UserContent::Text {
280 text: s.to_string(),
281 })
282 }
283}
284
285impl From<String> for AssistantContent {
286 fn from(s: String) -> Self {
287 AssistantContent { text: s }
288 }
289}
290
291impl FromStr for AssistantContent {
292 type Err = Infallible;
293
294 fn from_str(s: &str) -> Result<Self, Self::Err> {
295 Ok(AssistantContent {
296 text: s.to_string(),
297 })
298 }
299}
300
/// Mistral chat-completion model handle.
///
/// Generic over the HTTP client type `T` (defaults to `reqwest::Client`).
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    pub(crate) client: Client<T>,
    /// Default model id; a request's own `model` field takes precedence.
    pub model: String,
}
306
307#[derive(Debug, Default, Serialize, Deserialize)]
308pub enum ToolChoice {
309 #[default]
310 Auto,
311 None,
312 Any,
313}
314
315impl TryFrom<message::ToolChoice> for ToolChoice {
316 type Error = CompletionError;
317
318 fn try_from(value: message::ToolChoice) -> Result<Self, Self::Error> {
319 let res = match value {
320 message::ToolChoice::Auto => Self::Auto,
321 message::ToolChoice::None => Self::None,
322 message::ToolChoice::Required => Self::Any,
323 message::ToolChoice::Specific { .. } => {
324 return Err(CompletionError::ProviderError(
325 "Mistral doesn't support requiring specific tools to be called".to_string(),
326 ));
327 }
328 };
329
330 Ok(res)
331 }
332}
333
334#[derive(Debug, Serialize, Deserialize)]
335pub(super) struct MistralCompletionRequest {
336 model: String,
337 pub messages: Vec<Message>,
338 #[serde(skip_serializing_if = "Option::is_none")]
339 temperature: Option<f64>,
340 #[serde(skip_serializing_if = "Vec::is_empty")]
341 tools: Vec<ToolDefinition>,
342 #[serde(skip_serializing_if = "Option::is_none")]
343 tool_choice: Option<crate::providers::openai::completion::ToolChoice>,
344 #[serde(flatten, skip_serializing_if = "Option::is_none")]
345 pub additional_params: Option<serde_json::Value>,
346}
347
348impl TryFrom<(&str, CompletionRequest)> for MistralCompletionRequest {
349 type Error = CompletionError;
350
351 fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
352 if req.output_schema.is_some() {
353 tracing::warn!("Structured outputs currently not supported for Mistral");
354 }
355 let model = req.model.clone().unwrap_or_else(|| model.to_string());
356 let mut full_history: Vec<Message> = match &req.preamble {
357 Some(preamble) => vec![Message::system(preamble.clone())],
358 None => vec![],
359 };
360 if let Some(docs) = req.normalized_documents() {
361 let docs: Vec<Message> = docs.try_into()?;
362 full_history.extend(docs);
363 }
364
365 let chat_history: Vec<Message> = req
366 .chat_history
367 .clone()
368 .into_iter()
369 .map(|message| message.try_into())
370 .collect::<Result<Vec<Vec<Message>>, _>>()?
371 .into_iter()
372 .flatten()
373 .collect();
374
375 full_history.extend(chat_history);
376
377 if full_history.is_empty() {
378 return Err(CompletionError::RequestError(
379 std::io::Error::new(
380 std::io::ErrorKind::InvalidInput,
381 "Mistral request has no provider-compatible messages after conversion",
382 )
383 .into(),
384 ));
385 }
386
387 let tool_choice = req
388 .tool_choice
389 .clone()
390 .map(crate::providers::openai::completion::ToolChoice::try_from)
391 .transpose()?;
392
393 Ok(Self {
394 model: model.to_string(),
395 messages: full_history,
396 temperature: req.temperature,
397 tools: req
398 .tools
399 .clone()
400 .into_iter()
401 .map(ToolDefinition::from)
402 .collect::<Vec<_>>(),
403 tool_choice,
404 additional_params: req.additional_params,
405 })
406 }
407}
408
409impl<T> CompletionModel<T> {
410 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
411 Self {
412 client,
413 model: model.into(),
414 }
415 }
416
417 pub fn with_model(client: Client<T>, model: &str) -> Self {
418 Self {
419 client,
420 model: model.into(),
421 }
422 }
423}
424
/// Raw response body of Mistral's chat completions endpoint.
#[derive(Debug, Deserialize, Clone, Serialize)]
pub struct CompletionResponse {
    /// Completion id assigned by Mistral.
    pub id: String,
    /// Object kind (e.g. `"chat.completion"`).
    pub object: String,
    /// Creation timestamp as reported by the API (presumably Unix seconds — confirm).
    pub created: u64,
    /// Model that produced the response.
    pub model: String,
    pub system_fingerprint: Option<String>,
    /// Generated choices; conversion to a rig response uses only the first.
    pub choices: Vec<Choice>,
    /// Token accounting, when provided.
    pub usage: Option<Usage>,
}
435
436impl crate::telemetry::ProviderResponseExt for CompletionResponse {
437 type OutputMessage = Choice;
438 type Usage = Usage;
439
440 fn get_response_id(&self) -> Option<String> {
441 Some(self.id.clone())
442 }
443
444 fn get_response_model_name(&self) -> Option<String> {
445 Some(self.model.clone())
446 }
447
448 fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
449 self.choices.clone()
450 }
451
452 fn get_text_response(&self) -> Option<String> {
453 let res = self
454 .choices
455 .iter()
456 .filter_map(|choice| match choice.message {
457 Message::Assistant { ref content, .. } => {
458 if content.is_empty() {
459 None
460 } else {
461 Some(content.to_string())
462 }
463 }
464 _ => None,
465 })
466 .collect::<Vec<String>>()
467 .join("\n");
468
469 if res.is_empty() { None } else { Some(res) }
470 }
471
472 fn get_usage(&self) -> Option<Self::Usage> {
473 self.usage.clone()
474 }
475}
476
477impl GetTokenUsage for CompletionResponse {
478 fn token_usage(&self) -> Option<crate::completion::Usage> {
479 let api_usage = self.usage.as_ref()?;
480
481 let mut usage = crate::completion::Usage::new();
482 usage.input_tokens = api_usage.prompt_tokens as u64;
483 usage.output_tokens = api_usage.completion_tokens as u64;
484 usage.total_tokens = api_usage.total_tokens as u64;
485 usage.cached_input_tokens = api_usage.cached_tokens();
486
487 Some(usage)
488 }
489}
490
491impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
492 type Error = CompletionError;
493
494 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
495 let choice = response.choices.first().ok_or_else(|| {
496 CompletionError::ResponseError("Response contained no choices".to_owned())
497 })?;
498 let content = match &choice.message {
499 Message::Assistant {
500 content,
501 tool_calls,
502 ..
503 } => {
504 let mut content = if content.is_empty() {
505 vec![]
506 } else {
507 vec![completion::AssistantContent::text(content.clone())]
508 };
509
510 content.extend(
511 tool_calls
512 .iter()
513 .map(|call| {
514 completion::AssistantContent::tool_call(
515 &call.id,
516 &call.function.name,
517 call.function.arguments.clone(),
518 )
519 })
520 .collect::<Vec<_>>(),
521 );
522 Ok(content)
523 }
524 _ => Err(CompletionError::ResponseError(
525 "Response did not contain a valid message or tool call".into(),
526 )),
527 }?;
528
529 let choice = OneOrMany::many(content).map_err(|_| {
530 CompletionError::ResponseError(
531 "Response contained no message or tool call (empty)".to_owned(),
532 )
533 })?;
534
535 let usage = response
536 .usage
537 .as_ref()
538 .map(|usage| completion::Usage {
539 input_tokens: usage.prompt_tokens as u64,
540 output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
541 total_tokens: usage.total_tokens as u64,
542 cached_input_tokens: usage.cached_tokens(),
543 cache_creation_input_tokens: 0,
544 tool_use_prompt_tokens: 0,
545 reasoning_tokens: 0,
546 })
547 .unwrap_or_default();
548
549 Ok(completion::CompletionResponse {
550 choice,
551 usage,
552 raw_response: response,
553 message_id: None,
554 })
555 }
556}
557
558fn assistant_content_to_streaming_choices(
559 content: message::AssistantContent,
560) -> Result<Vec<RawStreamingChoice<CompletionResponse>>, CompletionError> {
561 match content {
562 message::AssistantContent::Text(t) => Ok(vec![RawStreamingChoice::Message(t.text)]),
563 message::AssistantContent::ToolCall(tc) => Ok(vec![RawStreamingChoice::ToolCall(
564 RawStreamingToolCall::new(tc.id, tc.function.name, tc.function.arguments),
565 )]),
566 message::AssistantContent::Reasoning(_) => Ok(Vec::new()),
567 message::AssistantContent::Image(_) => Err(CompletionError::ResponseError(
568 "Image content is not supported on Mistral via Rig".into(),
569 )),
570 }
571}
572
/// rig `CompletionModel` implementation for Mistral's chat completions endpoint.
impl<T> completion::CompletionModel for CompletionModel<T>
where
    T: HttpClientExt + Send + Clone + std::fmt::Debug + 'static,
{
    type Response = CompletionResponse;
    type StreamingResponse = CompletionResponse;

    type Client = Client<T>;

    /// Builds a model handle from a client and a model id.
    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model.into())
    }

    /// Sends one non-streaming request to `v1/chat/completions`.
    ///
    /// # Errors
    ///
    /// - Request-conversion errors.
    /// - Serialization errors.
    /// - HTTP errors.
    /// - `ProviderError` carrying:
    ///   - the API's error message, when a successful-status body is an error payload;
    ///   - the raw body text, on a non-success status.
    /// - Response-conversion errors.
    async fn completion(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
        // Captured before `completion_request` is consumed, for the tracing span.
        let preamble = completion_request.preamble.clone();
        let request =
            MistralCompletionRequest::try_from((self.model.as_ref(), completion_request))?;

        // Only pay for pretty-printing when TRACE is enabled.
        if enabled!(Level::TRACE) {
            tracing::trace!(
                target: "rig::completions",
                "Mistral completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        // Reuse the caller's span if one is active; otherwise open a `chat` span
        // whose response/usage fields are filled in after the call.
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "mistral",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = &preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        let body = serde_json::to_vec(&request)?;

        let request = self
            .client
            .post("v1/chat/completions")?
            .body(body)
            .map_err(|e| CompletionError::HttpError(e.into()))?;

        async move {
            let response = self.client.send(request).await?;

            if response.status().is_success() {
                let text = http_client::text(response).await?;
                // A successful HTTP status may still carry an API-level error payload.
                match serde_json::from_str::<ApiResponse<CompletionResponse>>(&text)? {
                    ApiResponse::Ok(response) => {
                        let span = tracing::Span::current();
                        span.record_token_usage(&response);
                        span.record_response_metadata(&response);
                        response.try_into()
                    }
                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
                }
            } else {
                let text = http_client::text(response).await?;
                Err(CompletionError::ProviderError(text))
            }
        }
        .instrument(span)
        .await
    }

    /// Streaming entry point.
    ///
    /// This is not incremental. It awaits a full `completion` first and then
    /// replays the finished response as a stream.
    async fn stream(
        &self,
        request: CompletionRequest,
    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
        let resp = self.completion(request).await?;
        buffered::stream_from_completion_response(resp, assistant_content_to_streaming_choices)
    }
}
659
// Unit tests covering:
// - deserialization of responses and usage blocks (including cached-token variants);
// - message conversion;
// - streaming-choice mapping;
// - request-conversion error handling.
#[cfg(test)]
mod tests {
    use super::*;

    // A full chat-completion payload with one tool call deserializes, and the
    // top-level fields and usage counters come through intact.
    #[test]
    fn test_response_deserialization() {
        let json_data = r#"
        {
            "id": "cmpl-e5cc70bb28c444948073e77776eb30ef",
            "object": "chat.completion",
            "model": "mistral-small-latest",
            "usage": {
                "prompt_tokens": 16,
                "completion_tokens": 34,
                "total_tokens": 50
            },
            "created": 1702256327,
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "content": "string",
                        "tool_calls": [
                            {
                                "id": "null",
                                "type": "function",
                                "function": {
                                    "name": "string",
                                    "arguments": "{ }"
                                },
                                "index": 0
                            }
                        ],
                        "prefix": false,
                        "role": "assistant"
                    },
                    "finish_reason": "stop"
                }
            ]
        }
        "#;
        let completion_response = serde_json::from_str::<CompletionResponse>(json_data).unwrap();
        assert_eq!(completion_response.model, MISTRAL_SMALL);

        let CompletionResponse {
            id,
            object,
            created,
            choices,
            usage,
            ..
        } = completion_response;

        assert_eq!(id, "cmpl-e5cc70bb28c444948073e77776eb30ef");

        let usage = usage.unwrap();
        assert_eq!(usage.prompt_tokens, 16);
        assert_eq!(usage.completion_tokens, 34);
        assert_eq!(usage.total_tokens, 50);
        // No cache fields in the payload: cached count falls back to zero.
        assert_eq!(usage.cached_tokens(), 0);
        assert!(usage.prompt_tokens_details.is_none());
        assert!(usage.num_cached_tokens.is_none());
        assert_eq!(object, "chat.completion".to_string());
        assert_eq!(created, 1702256327);
        assert_eq!(choices.len(), 1);
    }

    // The canonical `prompt_tokens_details.cached_tokens` field is read.
    #[test]
    fn test_usage_deserializes_prompt_tokens_details_cached_tokens() {
        let json = r#"{
            "prompt_tokens": 100,
            "completion_tokens": 20,
            "total_tokens": 120,
            "prompt_tokens_details": { "cached_tokens": 42 }
        }"#;
        let usage: Usage = serde_json::from_str(json).unwrap();
        assert_eq!(usage.prompt_tokens, 100);
        assert_eq!(
            usage.prompt_tokens_details.as_ref().unwrap().cached_tokens,
            42
        );
        assert_eq!(usage.cached_tokens(), 42);
    }

    // The singular `prompt_token_details` spelling is accepted as an alias.
    #[test]
    fn test_usage_accepts_singular_prompt_token_details_alias() {
        let json = r#"{
            "prompt_tokens": 100,
            "completion_tokens": 20,
            "total_tokens": 120,
            "prompt_token_details": { "cached_tokens": 7 }
        }"#;
        let usage: Usage = serde_json::from_str(json).unwrap();
        assert_eq!(
            usage.prompt_tokens_details.as_ref().unwrap().cached_tokens,
            7
        );
        assert_eq!(usage.cached_tokens(), 7);
    }

    // Without a details block, `num_cached_tokens` supplies the cached count.
    #[test]
    fn test_usage_falls_back_to_num_cached_tokens() {
        let json = r#"{
            "prompt_tokens": 100,
            "completion_tokens": 20,
            "total_tokens": 120,
            "num_cached_tokens": 13
        }"#;
        let usage: Usage = serde_json::from_str(json).unwrap();
        assert_eq!(usage.num_cached_tokens, Some(13));
        assert!(usage.prompt_tokens_details.is_none());
        assert_eq!(usage.cached_tokens(), 13);
    }

    // When both sources are present, the details block wins.
    #[test]
    fn test_usage_prefers_prompt_tokens_details_over_num_cached_tokens() {
        let json = r#"{
            "prompt_tokens": 100,
            "completion_tokens": 20,
            "total_tokens": 120,
            "num_cached_tokens": 1,
            "prompt_tokens_details": { "cached_tokens": 99 }
        }"#;
        let usage: Usage = serde_json::from_str(json).unwrap();
        assert_eq!(usage.cached_tokens(), 99);
    }

    // `GetTokenUsage` carries all counters, including cached tokens, into rig's Usage.
    #[test]
    fn test_token_usage_threads_cached_tokens_into_completion_usage() {
        let json = r#"{
            "id": "cmpl-x",
            "object": "chat.completion",
            "model": "mistral-small-latest",
            "created": 1700000000,
            "choices": [{
                "index": 0,
                "message": { "content": "hi", "role": "assistant", "prefix": false },
                "finish_reason": "stop"
            }],
            "usage": {
                "prompt_tokens": 100,
                "completion_tokens": 20,
                "total_tokens": 120,
                "prompt_tokens_details": { "cached_tokens": 42 }
            }
        }"#;
        let response: CompletionResponse = serde_json::from_str(json).unwrap();
        let usage = response.token_usage().unwrap();
        assert_eq!(usage.input_tokens, 100);
        assert_eq!(usage.output_tokens, 20);
        assert_eq!(usage.total_tokens, 120);
        assert_eq!(usage.cached_input_tokens, 42);
    }

    // A reasoning-only assistant message converts to no Mistral messages at all.
    #[test]
    fn test_assistant_reasoning_is_skipped_in_message_conversion() {
        let assistant = message::Message::Assistant {
            id: None,
            content: OneOrMany::one(message::AssistantContent::reasoning("hidden")),
        };

        let converted: Vec<Message> = assistant.try_into().expect("conversion should work");
        assert!(converted.is_empty());
    }

    // Reasoning is dropped while text and tool calls survive conversion.
    #[test]
    fn test_assistant_text_and_tool_call_are_preserved_when_reasoning_present() {
        let assistant = message::Message::Assistant {
            id: None,
            content: OneOrMany::many(vec![
                message::AssistantContent::reasoning("hidden"),
                message::AssistantContent::text("visible"),
                message::AssistantContent::tool_call(
                    "call_1",
                    "subtract",
                    serde_json::json!({"x": 2, "y": 1}),
                ),
            ])
            .expect("non-empty assistant content"),
        };

        let converted: Vec<Message> = assistant.try_into().expect("conversion should work");
        assert_eq!(converted.len(), 1);

        match &converted[0] {
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                assert_eq!(content, "visible");
                assert_eq!(tool_calls.len(), 1);
                assert_eq!(tool_calls[0].id, "call_1");
                assert_eq!(tool_calls[0].function.name, "subtract");
                assert_eq!(
                    tool_calls[0].function.arguments,
                    serde_json::json!({"x": 2, "y": 1})
                );
            }
            _ => panic!("expected assistant message"),
        }
    }

    // Streaming mapping: reasoning yields nothing, text and tool calls map 1:1.
    #[test]
    fn test_streaming_choice_mapping_skips_reasoning_and_preserves_other_content() {
        let reasoning_choices =
            assistant_content_to_streaming_choices(message::AssistantContent::reasoning("hidden"))
                .expect("reasoning should be ignored");
        assert!(reasoning_choices.is_empty());

        let text_choices =
            assistant_content_to_streaming_choices(message::AssistantContent::text("visible"))
                .expect("text should be preserved");
        match text_choices.as_slice() {
            [RawStreamingChoice::Message(text)] => assert_eq!(text, "visible"),
            _ => panic!("expected text streaming choice"),
        }

        let tool_choices =
            assistant_content_to_streaming_choices(message::AssistantContent::tool_call(
                "call_2",
                "add",
                serde_json::json!({"x": 2, "y": 3}),
            ))
            .expect("tool call should be preserved");
        match tool_choices.as_slice() {
            [RawStreamingChoice::ToolCall(call)] => {
                assert_eq!(call.id, "call_2");
                assert_eq!(call.name, "add");
                assert_eq!(call.arguments, serde_json::json!({"x": 2, "y": 3}));
            }
            _ => panic!("expected tool-call streaming choice"),
        }
    }

    // If every message is filtered out during conversion, request building fails
    // with `RequestError` rather than sending an empty message list.
    #[test]
    fn test_request_conversion_errors_when_all_messages_are_filtered() {
        let request = CompletionRequest {
            preamble: None,
            chat_history: OneOrMany::one(message::Message::Assistant {
                id: None,
                content: OneOrMany::one(message::AssistantContent::reasoning("hidden")),
            }),
            documents: vec![],
            tools: vec![],
            temperature: None,
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
            model: None,
            output_schema: None,
        };

        let result = MistralCompletionRequest::try_from((MISTRAL_SMALL, request));
        assert!(matches!(result, Err(CompletionError::RequestError(_))));
    }
}
917}