use async_stream::stream;
use serde::{Deserialize, Serialize};
use std::{convert::Infallible, str::FromStr};
use tracing::{Instrument, Level, enabled, info_span};

use super::client::{Client, Usage};
use crate::completion::GetTokenUsage;
use crate::http_client::{self, HttpClientExt};
use crate::streaming::{RawStreamingChoice, RawStreamingToolCall, StreamingCompletionResponse};
use crate::{
    OneOrMany,
    completion::{self, CompletionError, CompletionRequest},
    json_utils, message,
    providers::mistral::client::ApiResponse,
    telemetry::SpanCombinator,
};

pub const CODESTRAL: &str = "codestral-latest";
pub const MISTRAL_LARGE: &str = "mistral-large-latest";
pub const PIXTRAL_LARGE: &str = "pixtral-large-latest";
pub const MISTRAL_SABA: &str = "mistral-saba-latest";
pub const MINISTRAL_3B: &str = "ministral-3b-latest";
pub const MINISTRAL_8B: &str = "ministral-8b-latest";

pub const MISTRAL_SMALL: &str = "mistral-small-latest";
pub const PIXTRAL_SMALL: &str = "pixtral-12b-2409";
pub const MISTRAL_NEMO: &str = "open-mistral-nemo";
pub const CODESTRAL_MAMBA: &str = "open-codestral-mamba";

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub struct AssistantContent {
    text: String,
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum UserContent {
    Text { text: String },
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Choice {
    pub index: usize,
    pub message: Message,
    pub logprobs: Option<serde_json::Value>,
    pub finish_reason: String,
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    User {
        content: String,
    },
    Assistant {
        content: String,
        #[serde(
            default,
            deserialize_with = "json_utils::null_or_vec",
            skip_serializing_if = "Vec::is_empty"
        )]
        tool_calls: Vec<ToolCall>,
        #[serde(default)]
        prefix: bool,
    },
    System {
        content: String,
    },
    Tool {
        name: String,
        content: String,
        tool_call_id: String,
    },
}

impl Message {
    pub fn user(content: String) -> Self {
        Message::User { content }
    }

    pub fn assistant(content: String, tool_calls: Vec<ToolCall>, prefix: bool) -> Self {
        Message::Assistant {
            content,
            tool_calls,
            prefix,
        }
    }

    pub fn system(content: String) -> Self {
        Message::System { content }
    }
}

impl TryFrom<message::Message> for Vec<Message> {
    type Error = message::MessageError;

    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
        match message {
            message::Message::User { content } => {
                let mut tool_result_messages = Vec::new();
                let mut other_messages = Vec::new();

                for content_item in content {
                    match content_item {
                        message::UserContent::ToolResult(message::ToolResult {
                            id,
                            call_id,
                            content: tool_content,
                        }) => {
                            let call_id_key = call_id.unwrap_or_else(|| id.clone());
                            let content_text = tool_content
                                .into_iter()
                                .find_map(|content_item| match content_item {
                                    message::ToolResultContent::Text(text) => Some(text.text),
                                    message::ToolResultContent::Image(_) => None,
                                })
                                .unwrap_or_default();
                            tool_result_messages.push(Message::Tool {
                                name: id,
                                content: content_text,
                                tool_call_id: call_id_key,
                            });
                        }
                        message::UserContent::Text(message::Text { text }) => {
                            other_messages.push(Message::User { content: text });
                        }
                        _ => {}
                    }
                }

                tool_result_messages.append(&mut other_messages);
                Ok(tool_result_messages)
            }
            message::Message::Assistant { content, .. } => {
                let mut text_content = Vec::new();
                let mut tool_calls = Vec::new();

                for content in content {
                    match content {
                        message::AssistantContent::Text(text) => text_content.push(text),
                        message::AssistantContent::ToolCall(tool_call) => {
                            tool_calls.push(tool_call)
                        }
                        message::AssistantContent::Reasoning(_) => {
                            // Reasoning content has no Mistral message equivalent, so it is dropped.
                        }
                        message::AssistantContent::Image(_) => {
                            panic!("Image content is not currently supported on Mistral via Rig");
                        }
                    }
                }

                if text_content.is_empty() && tool_calls.is_empty() {
                    return Ok(vec![]);
                }

                Ok(vec![Message::Assistant {
                    content: text_content
                        .into_iter()
                        .next()
                        .map(|content| content.text)
                        .unwrap_or_default(),
                    tool_calls: tool_calls
                        .into_iter()
                        .map(|tool_call| tool_call.into())
                        .collect::<Vec<_>>(),
                    prefix: false,
                }])
            }
        }
    }
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolCall {
    pub id: String,
    #[serde(default)]
    pub r#type: ToolType,
    pub function: Function,
}

impl From<message::ToolCall> for ToolCall {
    fn from(tool_call: message::ToolCall) -> Self {
        Self {
            id: tool_call.id,
            r#type: ToolType::default(),
            function: Function {
                name: tool_call.function.name,
                arguments: tool_call.function.arguments,
            },
        }
    }
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct Function {
    pub name: String,
    #[serde(with = "json_utils::stringified_json")]
    pub arguments: serde_json::Value,
}

#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}

#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ToolDefinition {
    pub r#type: String,
    pub function: completion::ToolDefinition,
}

impl From<completion::ToolDefinition> for ToolDefinition {
    fn from(tool: completion::ToolDefinition) -> Self {
        Self {
            r#type: "function".into(),
            function: tool,
        }
    }
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolResultContent {
    #[serde(default)]
    r#type: ToolResultContentType,
    text: String,
}

#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolResultContentType {
    #[default]
    Text,
}

impl From<String> for ToolResultContent {
    fn from(s: String) -> Self {
        ToolResultContent {
            r#type: ToolResultContentType::default(),
            text: s,
        }
    }
}

impl From<String> for UserContent {
    fn from(s: String) -> Self {
        UserContent::Text { text: s }
    }
}

impl FromStr for UserContent {
    type Err = Infallible;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Ok(UserContent::Text {
            text: s.to_string(),
        })
    }
}

impl From<String> for AssistantContent {
    fn from(s: String) -> Self {
        AssistantContent { text: s }
    }
}

impl FromStr for AssistantContent {
    type Err = Infallible;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Ok(AssistantContent {
            text: s.to_string(),
        })
    }
}

#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    pub(crate) client: Client<T>,
    pub model: String,
}

#[derive(Debug, Default, Serialize, Deserialize)]
pub enum ToolChoice {
    #[default]
    Auto,
    None,
    Any,
}

impl TryFrom<message::ToolChoice> for ToolChoice {
    type Error = CompletionError;

    fn try_from(value: message::ToolChoice) -> Result<Self, Self::Error> {
        let res = match value {
            message::ToolChoice::Auto => Self::Auto,
            message::ToolChoice::None => Self::None,
            message::ToolChoice::Required => Self::Any,
            message::ToolChoice::Specific { .. } => {
                return Err(CompletionError::ProviderError(
                    "Mistral doesn't support requiring specific tools to be called".to_string(),
                ));
            }
        };

        Ok(res)
    }
}

#[derive(Debug, Serialize, Deserialize)]
pub(super) struct MistralCompletionRequest {
    model: String,
    pub messages: Vec<Message>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<ToolDefinition>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<crate::providers::openai::completion::ToolChoice>,
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub additional_params: Option<serde_json::Value>,
}

impl TryFrom<(&str, CompletionRequest)> for MistralCompletionRequest {
    type Error = CompletionError;

    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
        if req.output_schema.is_some() {
            tracing::warn!("Structured outputs currently not supported for Mistral");
        }
        let model = req.model.clone().unwrap_or_else(|| model.to_string());
        let mut full_history: Vec<Message> = match &req.preamble {
            Some(preamble) => vec![Message::system(preamble.clone())],
            None => vec![],
        };
        if let Some(docs) = req.normalized_documents() {
            let docs: Vec<Message> = docs.try_into()?;
            full_history.extend(docs);
        }

        let chat_history: Vec<Message> = req
            .chat_history
            .clone()
            .into_iter()
            .map(|message| message.try_into())
            .collect::<Result<Vec<Vec<Message>>, _>>()?
            .into_iter()
            .flatten()
            .collect();

        full_history.extend(chat_history);

        if full_history.is_empty() {
            return Err(CompletionError::RequestError(
                std::io::Error::new(
                    std::io::ErrorKind::InvalidInput,
                    "Mistral request has no provider-compatible messages after conversion",
                )
                .into(),
            ));
        }

        let tool_choice = req
            .tool_choice
            .clone()
            .map(crate::providers::openai::completion::ToolChoice::try_from)
            .transpose()?;

        Ok(Self {
            model: model.to_string(),
            messages: full_history,
            temperature: req.temperature,
            tools: req
                .tools
                .clone()
                .into_iter()
                .map(ToolDefinition::from)
                .collect::<Vec<_>>(),
            tool_choice,
            additional_params: req.additional_params,
        })
    }
}

impl<T> CompletionModel<T> {
    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
        Self {
            client,
            model: model.into(),
        }
    }

    pub fn with_model(client: Client<T>, model: &str) -> Self {
        Self {
            client,
            model: model.into(),
        }
    }
}

#[derive(Debug, Deserialize, Clone, Serialize)]
pub struct CompletionResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub system_fingerprint: Option<String>,
    pub choices: Vec<Choice>,
    pub usage: Option<Usage>,
}

impl crate::telemetry::ProviderResponseExt for CompletionResponse {
    type OutputMessage = Choice;
    type Usage = Usage;

    fn get_response_id(&self) -> Option<String> {
        Some(self.id.clone())
    }

    fn get_response_model_name(&self) -> Option<String> {
        Some(self.model.clone())
    }

    fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
        self.choices.clone()
    }

    fn get_text_response(&self) -> Option<String> {
        let res = self
            .choices
            .iter()
            .filter_map(|choice| match choice.message {
                Message::Assistant { ref content, .. } => {
                    if content.is_empty() {
                        None
                    } else {
                        Some(content.to_string())
                    }
                }
                _ => None,
            })
            .collect::<Vec<String>>()
            .join("\n");

        if res.is_empty() { None } else { Some(res) }
    }

    fn get_usage(&self) -> Option<Self::Usage> {
        self.usage.clone()
    }
}

impl GetTokenUsage for CompletionResponse {
    fn token_usage(&self) -> Option<crate::completion::Usage> {
        let api_usage = self.usage.clone()?;

        let mut usage = crate::completion::Usage::new();
        usage.input_tokens = api_usage.prompt_tokens as u64;
        usage.output_tokens = api_usage.completion_tokens as u64;
        usage.total_tokens = api_usage.total_tokens as u64;

        Some(usage)
    }
}

impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
    type Error = CompletionError;

    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
        let choice = response.choices.first().ok_or_else(|| {
            CompletionError::ResponseError("Response contained no choices".to_owned())
        })?;
        let content = match &choice.message {
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                let mut content = if content.is_empty() {
                    vec![]
                } else {
                    vec![completion::AssistantContent::text(content.clone())]
                };

                content.extend(
                    tool_calls
                        .iter()
                        .map(|call| {
                            completion::AssistantContent::tool_call(
                                &call.id,
                                &call.function.name,
                                call.function.arguments.clone(),
                            )
                        })
                        .collect::<Vec<_>>(),
                );
                Ok(content)
            }
            _ => Err(CompletionError::ResponseError(
                "Response did not contain a valid message or tool call".into(),
            )),
        }?;

        let choice = OneOrMany::many(content).map_err(|_| {
            CompletionError::ResponseError(
                "Response contained no message or tool call (empty)".to_owned(),
            )
        })?;

        let usage = response
            .usage
            .as_ref()
            .map(|usage| completion::Usage {
                input_tokens: usage.prompt_tokens as u64,
                output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
                total_tokens: usage.total_tokens as u64,
                cached_input_tokens: 0,
            })
            .unwrap_or_default();

        Ok(completion::CompletionResponse {
            choice,
            usage,
            raw_response: response,
            message_id: None,
        })
    }
}

fn assistant_content_to_streaming_choice(
    content: message::AssistantContent,
) -> Option<RawStreamingChoice<CompletionResponse>> {
    match content {
        message::AssistantContent::Text(t) => Some(RawStreamingChoice::Message(t.text)),
        message::AssistantContent::ToolCall(tc) => Some(RawStreamingChoice::ToolCall(
            RawStreamingToolCall::new(tc.id, tc.function.name, tc.function.arguments),
        )),
        message::AssistantContent::Reasoning(_) => None,
        message::AssistantContent::Image(_) => {
            panic!("Image content is not supported on Mistral via Rig")
        }
    }
}

impl<T> completion::CompletionModel for CompletionModel<T>
where
    T: HttpClientExt + Send + Clone + std::fmt::Debug + 'static,
{
    type Response = CompletionResponse;
    type StreamingResponse = CompletionResponse;

    type Client = Client<T>;

    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model.into())
    }

    async fn completion(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
        let preamble = completion_request.preamble.clone();
        let request =
            MistralCompletionRequest::try_from((self.model.as_ref(), completion_request))?;

        if enabled!(Level::TRACE) {
            tracing::trace!(
                target: "rig::completions",
                "Mistral completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "mistral",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = &preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        let body = serde_json::to_vec(&request)?;

        let request = self
            .client
            .post("v1/chat/completions")?
            .body(body)
            .map_err(|e| CompletionError::HttpError(e.into()))?;

        async move {
            let response = self.client.send(request).await?;

            if response.status().is_success() {
                let text = http_client::text(response).await?;
                match serde_json::from_str::<ApiResponse<CompletionResponse>>(&text)? {
                    ApiResponse::Ok(response) => {
                        let span = tracing::Span::current();
                        span.record_token_usage(&response);
                        span.record_response_metadata(&response);
                        response.try_into()
                    }
                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
                }
            } else {
                let text = http_client::text(response).await?;
                Err(CompletionError::ProviderError(text))
            }
        }
        .instrument(span)
        .await
    }

    async fn stream(
        &self,
        request: CompletionRequest,
    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
        let resp = self.completion(request).await?;

        let stream = stream! {
            for c in resp.choice.clone() {
                if let Some(choice) = assistant_content_to_streaming_choice(c) {
                    yield Ok(choice);
                }
            }

            yield Ok(RawStreamingChoice::FinalResponse(resp.raw_response.clone()));
        };

        Ok(StreamingCompletionResponse::stream(Box::pin(stream)))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_response_deserialization() {
        let json_data = r#"
        {
            "id": "cmpl-e5cc70bb28c444948073e77776eb30ef",
            "object": "chat.completion",
            "model": "mistral-small-latest",
            "usage": {
                "prompt_tokens": 16,
                "completion_tokens": 34,
                "total_tokens": 50
            },
            "created": 1702256327,
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "content": "string",
                        "tool_calls": [
                            {
                                "id": "null",
                                "type": "function",
                                "function": {
                                    "name": "string",
                                    "arguments": "{ }"
                                },
                                "index": 0
                            }
                        ],
                        "prefix": false,
                        "role": "assistant"
                    },
                    "finish_reason": "stop"
                }
            ]
        }
        "#;
        let completion_response = serde_json::from_str::<CompletionResponse>(json_data).unwrap();
        assert_eq!(completion_response.model, MISTRAL_SMALL);

        let CompletionResponse {
            id,
            object,
            created,
            choices,
            usage,
            ..
        } = completion_response;

        assert_eq!(id, "cmpl-e5cc70bb28c444948073e77776eb30ef");

        let Usage {
            completion_tokens,
            prompt_tokens,
            total_tokens,
        } = usage.unwrap();

        assert_eq!(prompt_tokens, 16);
        assert_eq!(completion_tokens, 34);
        assert_eq!(total_tokens, 50);
        assert_eq!(object, "chat.completion".to_string());
        assert_eq!(created, 1702256327);
        assert_eq!(choices.len(), 1);
    }
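
    // Added sanity check: the expected JSON shapes below follow directly from the
    // `#[serde(tag = "role", rename_all = "lowercase")]` attributes on `Message` and the
    // `skip_serializing_if` annotation on `tool_calls`; it is a sketch, not captured provider output.
    #[test]
    fn test_message_serialization_uses_lowercase_role_tag() {
        let user = serde_json::to_value(Message::user("hello".to_string())).unwrap();
        assert_eq!(user, serde_json::json!({"role": "user", "content": "hello"}));

        let system = serde_json::to_value(Message::system("be brief".to_string())).unwrap();
        assert_eq!(system, serde_json::json!({"role": "system", "content": "be brief"}));

        // Empty `tool_calls` should be omitted entirely, while `prefix` is always serialized.
        let assistant =
            serde_json::to_value(Message::assistant("hi".to_string(), vec![], false)).unwrap();
        assert_eq!(
            assistant,
            serde_json::json!({"role": "assistant", "content": "hi", "prefix": false})
        );
    }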

    #[test]
    fn test_assistant_reasoning_is_skipped_in_message_conversion() {
        let assistant = message::Message::Assistant {
            id: None,
            content: OneOrMany::one(message::AssistantContent::reasoning("hidden")),
        };

        let converted: Vec<Message> = assistant.try_into().expect("conversion should work");
        assert!(converted.is_empty());
    }
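
    // Added sketch: exercises the `GetTokenUsage` mapping above with a hand-built response.
    // The `Usage` field names come from the destructuring in `test_response_deserialization`;
    // the numeric values are arbitrary.
    #[test]
    fn test_token_usage_maps_prompt_and_completion_tokens() {
        let response = CompletionResponse {
            id: "cmpl-test".to_string(),
            object: "chat.completion".to_string(),
            created: 0,
            model: MISTRAL_SMALL.to_string(),
            system_fingerprint: None,
            choices: vec![],
            usage: Some(Usage {
                prompt_tokens: 16,
                completion_tokens: 34,
                total_tokens: 50,
            }),
        };

        let usage = response.token_usage().expect("usage should be present");
        assert_eq!(usage.input_tokens, 16);
        assert_eq!(usage.output_tokens, 34);
        assert_eq!(usage.total_tokens, 50);
    }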

    #[test]
    fn test_assistant_text_and_tool_call_are_preserved_when_reasoning_present() {
        let assistant = message::Message::Assistant {
            id: None,
            content: OneOrMany::many(vec![
                message::AssistantContent::reasoning("hidden"),
                message::AssistantContent::text("visible"),
                message::AssistantContent::tool_call(
                    "call_1",
                    "subtract",
                    serde_json::json!({"x": 2, "y": 1}),
                ),
            ])
            .expect("non-empty assistant content"),
        };

        let converted: Vec<Message> = assistant.try_into().expect("conversion should work");
        assert_eq!(converted.len(), 1);

        match &converted[0] {
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                assert_eq!(content, "visible");
                assert_eq!(tool_calls.len(), 1);
                assert_eq!(tool_calls[0].id, "call_1");
                assert_eq!(tool_calls[0].function.name, "subtract");
                assert_eq!(
                    tool_calls[0].function.arguments,
                    serde_json::json!({"x": 2, "y": 1})
                );
            }
            _ => panic!("expected assistant message"),
        }
    }
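
    // Added round-trip sketch for `ToolCall`: it assumes `json_utils::stringified_json`
    // (used on `Function::arguments`) writes the arguments as an embedded JSON string and
    // parses them back symmetrically on deserialization.
    #[test]
    fn test_tool_call_serialization_round_trip() {
        let call = ToolCall {
            id: "call_42".to_string(),
            r#type: ToolType::default(),
            function: Function {
                name: "add".to_string(),
                arguments: serde_json::json!({"x": 1, "y": 2}),
            },
        };

        let serialized = serde_json::to_string(&call).expect("tool call should serialize");
        let deserialized: ToolCall =
            serde_json::from_str(&serialized).expect("tool call should deserialize");
        assert_eq!(call, deserialized);
    }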

    #[test]
    fn test_streaming_choice_mapping_skips_reasoning_and_preserves_other_content() {
        assert!(
            assistant_content_to_streaming_choice(message::AssistantContent::reasoning("hidden"))
                .is_none()
        );

        let text_choice =
            assistant_content_to_streaming_choice(message::AssistantContent::text("visible"))
                .expect("text should be preserved");
        match text_choice {
            RawStreamingChoice::Message(text) => assert_eq!(text, "visible"),
            _ => panic!("expected text streaming choice"),
        }

        let tool_choice =
            assistant_content_to_streaming_choice(message::AssistantContent::tool_call(
                "call_2",
                "add",
                serde_json::json!({"x": 2, "y": 3}),
            ))
            .expect("tool call should be preserved");
        match tool_choice {
            RawStreamingChoice::ToolCall(call) => {
                assert_eq!(call.id, "call_2");
                assert_eq!(call.name, "add");
                assert_eq!(call.arguments, serde_json::json!({"x": 2, "y": 3}));
            }
            _ => panic!("expected tool-call streaming choice"),
        }
    }
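
    // Added check of the `ToolChoice` conversion above: `Required` maps to Mistral's `Any`,
    // while the unsupported `Specific` variant is already covered by the impl's error path.
    #[test]
    fn test_tool_choice_conversion_maps_required_to_any() {
        assert!(matches!(
            ToolChoice::try_from(message::ToolChoice::Auto),
            Ok(ToolChoice::Auto)
        ));
        assert!(matches!(
            ToolChoice::try_from(message::ToolChoice::None),
            Ok(ToolChoice::None)
        ));
        assert!(matches!(
            ToolChoice::try_from(message::ToolChoice::Required),
            Ok(ToolChoice::Any)
        ));
    }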

    #[test]
    fn test_request_conversion_errors_when_all_messages_are_filtered() {
        let request = CompletionRequest {
            preamble: None,
            chat_history: OneOrMany::one(message::Message::Assistant {
                id: None,
                content: OneOrMany::one(message::AssistantContent::reasoning("hidden")),
            }),
            documents: vec![],
            tools: vec![],
            temperature: None,
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
            model: None,
            output_schema: None,
        };

        let result = MistralCompletionRequest::try_from((MISTRAL_SMALL, request));
        assert!(matches!(result, Err(CompletionError::RequestError(_))));
    }
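
    // Added serialization sketch for the request body, asserting only what the serde
    // attributes on `MistralCompletionRequest` imply: `None`/empty optional fields are
    // omitted from the JSON sent to the API.
    #[test]
    fn test_request_serialization_omits_empty_optional_fields() {
        let request = MistralCompletionRequest {
            model: MISTRAL_SMALL.to_string(),
            messages: vec![Message::user("hello".to_string())],
            temperature: None,
            tools: vec![],
            tool_choice: None,
            additional_params: None,
        };

        let value = serde_json::to_value(&request).expect("request should serialize");
        assert_eq!(value["model"], MISTRAL_SMALL);
        assert_eq!(value["messages"][0]["role"], "user");
        assert!(value.get("temperature").is_none());
        assert!(value.get("tools").is_none());
        assert!(value.get("tool_choice").is_none());
    }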
}