1use async_stream::stream;
2use serde::{Deserialize, Serialize};
3use std::{convert::Infallible, str::FromStr};
4use tracing::{Instrument, Level, enabled, info_span};
5
6use super::client::{Client, Usage};
7use crate::completion::GetTokenUsage;
8use crate::http_client::{self, HttpClientExt};
9use crate::streaming::{RawStreamingChoice, RawStreamingToolCall, StreamingCompletionResponse};
10use crate::{
11 OneOrMany,
12 completion::{self, CompletionError, CompletionRequest},
13 json_utils, message,
14 providers::mistral::client::ApiResponse,
15 telemetry::SpanCombinator,
16};
17
// Premier (paid-tier) Mistral model identifiers.
pub const CODESTRAL: &str = "codestral-latest";
pub const MISTRAL_LARGE: &str = "mistral-large-latest";
pub const PIXTRAL_LARGE: &str = "pixtral-large-latest";
pub const MISTRAL_SABA: &str = "mistral-saba-latest";
pub const MINISTRAL_3B: &str = "ministral-3b-latest";
pub const MINISTRAL_8B: &str = "ministral-8b-latest";

// Free / open-weight Mistral model identifiers.
pub const MISTRAL_SMALL: &str = "mistral-small-latest";
pub const PIXTRAL_SMALL: &str = "pixtral-12b-2409";
pub const MISTRAL_NEMO: &str = "open-mistral-nemo";
pub const CODESTRAL_MAMBA: &str = "open-codestral-mamba";
/// Text content produced by the assistant, serialized with a `"type"` tag
/// (internally-tagged serde representation).
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub struct AssistantContent {
    text: String,
}
49
/// User-supplied content, tagged by `"type"`; currently only plain text is
/// representable.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum UserContent {
    Text { text: String },
}
55
/// One candidate completion in a Mistral chat response.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Choice {
    /// Position of this choice within the response's `choices` array.
    pub index: usize,
    /// The generated message for this choice.
    pub message: Message,
    /// Log-probability payload; kept opaque as raw JSON.
    pub logprobs: Option<serde_json::Value>,
    /// Why generation stopped (e.g. `"stop"` as seen in the test fixture below).
    pub finish_reason: String,
}
63
/// A chat message in Mistral's wire format, discriminated by the `"role"` field.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    /// End-user input.
    User {
        content: String,
    },
    /// Model output; may carry tool calls alongside (or instead of) text.
    Assistant {
        content: String,
        // `null_or_vec` tolerates `tool_calls: null` in responses; empty
        // vectors are omitted when serializing requests.
        #[serde(
            default,
            deserialize_with = "json_utils::null_or_vec",
            skip_serializing_if = "Vec::is_empty"
        )]
        tool_calls: Vec<ToolCall>,
        // NOTE(review): presumably Mistral's "assistant prefix" flag (the
        // model continues `content`) — confirm against the Mistral API docs.
        #[serde(default)]
        prefix: bool,
    },
    /// System / instruction message.
    System {
        content: String,
    },
    /// Result of a previously issued tool call, correlated via `tool_call_id`.
    Tool {
        name: String,
        content: String,
        tool_call_id: String,
    },
}
93
94impl Message {
95 pub fn user(content: String) -> Self {
96 Message::User { content }
97 }
98
99 pub fn assistant(content: String, tool_calls: Vec<ToolCall>, prefix: bool) -> Self {
100 Message::Assistant {
101 content,
102 tool_calls,
103 prefix,
104 }
105 }
106
107 pub fn system(content: String) -> Self {
108 Message::System { content }
109 }
110}
111
impl TryFrom<message::Message> for Vec<Message> {
    type Error = message::MessageError;

    /// Expand one provider-agnostic message into zero or more Mistral wire
    /// messages.
    ///
    /// - `System` maps 1:1.
    /// - `User` content is split: tool results become `Message::Tool` entries
    ///   and are emitted *before* any plain-text `Message::User` entries;
    ///   other user content kinds are silently dropped.
    /// - `Assistant` collapses to a single message; reasoning content is
    ///   skipped, and an empty result yields an empty vector (no message).
    ///
    /// # Panics
    /// Panics on assistant image content, which is unsupported here.
    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
        match message {
            message::Message::System { content } => Ok(vec![Message::System { content }]),
            message::Message::User { content } => {
                let mut tool_result_messages = Vec::new();
                let mut other_messages = Vec::new();

                for content_item in content {
                    match content_item {
                        message::UserContent::ToolResult(message::ToolResult {
                            id,
                            call_id,
                            content: tool_content,
                        }) => {
                            // Prefer the explicit call id; fall back to the
                            // tool-result id for correlation.
                            let call_id_key = call_id.unwrap_or_else(|| id.clone());
                            // Only the first text part is forwarded; image
                            // results are ignored.
                            let content_text = tool_content
                                .into_iter()
                                .find_map(|content_item| match content_item {
                                    message::ToolResultContent::Text(text) => Some(text.text),
                                    message::ToolResultContent::Image(_) => None,
                                })
                                .unwrap_or_default();
                            tool_result_messages.push(Message::Tool {
                                name: id,
                                content: content_text,
                                tool_call_id: call_id_key,
                            });
                        }
                        message::UserContent::Text(message::Text { text }) => {
                            other_messages.push(Message::User { content: text });
                        }
                        // Other user content kinds (e.g. media) are dropped.
                        _ => {}
                    }
                }

                // Tool results must precede the user text messages.
                tool_result_messages.append(&mut other_messages);
                Ok(tool_result_messages)
            }
            message::Message::Assistant { content, .. } => {
                let mut text_content = Vec::new();
                let mut tool_calls = Vec::new();

                for content in content {
                    match content {
                        message::AssistantContent::Text(text) => text_content.push(text),
                        message::AssistantContent::ToolCall(tool_call) => {
                            tool_calls.push(tool_call)
                        }
                        // Reasoning traces are internal and never sent to the
                        // provider.
                        message::AssistantContent::Reasoning(_) => {
                        }
                        message::AssistantContent::Image(_) => {
                            panic!("Image content is not currently supported on Mistral via Rig");
                        }
                    }
                }

                // A reasoning-only assistant turn converts to no messages.
                if text_content.is_empty() && tool_calls.is_empty() {
                    return Ok(vec![]);
                }

                Ok(vec![Message::Assistant {
                    // Only the first text part is kept; Mistral assistant
                    // messages carry a single content string.
                    content: text_content
                        .into_iter()
                        .next()
                        .map(|content| content.text)
                        .unwrap_or_default(),
                    tool_calls: tool_calls
                        .into_iter()
                        .map(|tool_call| tool_call.into())
                        .collect::<Vec<_>>(),
                    prefix: false,
                }])
            }
        }
    }
}
193
/// A tool invocation requested by the model.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolCall {
    /// Provider-assigned call identifier, echoed back in tool results.
    pub id: String,
    // Defaults to `function` when absent from the payload.
    #[serde(default)]
    pub r#type: ToolType,
    pub function: Function,
}
201
202impl From<message::ToolCall> for ToolCall {
203 fn from(tool_call: message::ToolCall) -> Self {
204 Self {
205 id: tool_call.id,
206 r#type: ToolType::default(),
207 function: Function {
208 name: tool_call.function.name,
209 arguments: tool_call.function.arguments,
210 },
211 }
212 }
213}
214
/// The function name/arguments pair inside a tool call.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct Function {
    pub name: String,
    // On the wire, `arguments` is a JSON-encoded *string* (see the test
    // fixture below); `stringified_json` converts to/from a real JSON value.
    #[serde(with = "json_utils::stringified_json")]
    pub arguments: serde_json::Value,
}
221
/// Kind of tool call; Mistral currently only defines `"function"` here.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}
228
/// A tool definition as sent in the request body: the generic rig definition
/// wrapped in a `{"type": ..., "function": ...}` envelope.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ToolDefinition {
    pub r#type: String,
    pub function: completion::ToolDefinition,
}
234
235impl From<completion::ToolDefinition> for ToolDefinition {
236 fn from(tool: completion::ToolDefinition) -> Self {
237 Self {
238 r#type: "function".into(),
239 function: tool,
240 }
241 }
242}
243
/// A single piece of tool-result content (currently always text).
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolResultContent {
    // Defaults to `text` when the field is absent.
    #[serde(default)]
    r#type: ToolResultContentType,
    text: String,
}
250
/// Discriminator for [`ToolResultContent`]; only `"text"` is defined.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolResultContentType {
    #[default]
    Text,
}
257
258impl From<String> for ToolResultContent {
259 fn from(s: String) -> Self {
260 ToolResultContent {
261 r#type: ToolResultContentType::default(),
262 text: s,
263 }
264 }
265}
266
267impl From<String> for UserContent {
268 fn from(s: String) -> Self {
269 UserContent::Text { text: s }
270 }
271}
272
273impl FromStr for UserContent {
274 type Err = Infallible;
275
276 fn from_str(s: &str) -> Result<Self, Self::Err> {
277 Ok(UserContent::Text {
278 text: s.to_string(),
279 })
280 }
281}
282
283impl From<String> for AssistantContent {
284 fn from(s: String) -> Self {
285 AssistantContent { text: s }
286 }
287}
288
289impl FromStr for AssistantContent {
290 type Err = Infallible;
291
292 fn from_str(s: &str) -> Result<Self, Self::Err> {
293 Ok(AssistantContent {
294 text: s.to_string(),
295 })
296 }
297}
298
/// A Mistral completion model: an API client paired with a model identifier
/// (one of the constants above, or any model name the API accepts).
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    pub(crate) client: Client<T>,
    pub model: String,
}
304
/// Mistral's tool-choice modes (`auto` / `none` / `any`).
///
/// NOTE(review): `MistralCompletionRequest` below serializes the OpenAI
/// `ToolChoice` type instead of this one — confirm whether this enum is still
/// used anywhere, and note that without `rename_all` it serializes with
/// capitalized variant names.
#[derive(Debug, Default, Serialize, Deserialize)]
pub enum ToolChoice {
    #[default]
    Auto,
    None,
    Any,
}
312
313impl TryFrom<message::ToolChoice> for ToolChoice {
314 type Error = CompletionError;
315
316 fn try_from(value: message::ToolChoice) -> Result<Self, Self::Error> {
317 let res = match value {
318 message::ToolChoice::Auto => Self::Auto,
319 message::ToolChoice::None => Self::None,
320 message::ToolChoice::Required => Self::Any,
321 message::ToolChoice::Specific { .. } => {
322 return Err(CompletionError::ProviderError(
323 "Mistral doesn't support requiring specific tools to be called".to_string(),
324 ));
325 }
326 };
327
328 Ok(res)
329 }
330}
331
/// Request body for Mistral's chat-completions endpoint.
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct MistralCompletionRequest {
    model: String,
    pub messages: Vec<Message>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<ToolDefinition>,
    // Reuses the OpenAI-compatible tool-choice serialization rather than the
    // local `ToolChoice` enum above.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<crate::providers::openai::completion::ToolChoice>,
    // Arbitrary extra parameters, flattened into the top-level request object.
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub additional_params: Option<serde_json::Value>,
}
345
346impl TryFrom<(&str, CompletionRequest)> for MistralCompletionRequest {
347 type Error = CompletionError;
348
349 fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
350 if req.output_schema.is_some() {
351 tracing::warn!("Structured outputs currently not supported for Mistral");
352 }
353 let model = req.model.clone().unwrap_or_else(|| model.to_string());
354 let mut full_history: Vec<Message> = match &req.preamble {
355 Some(preamble) => vec![Message::system(preamble.clone())],
356 None => vec![],
357 };
358 if let Some(docs) = req.normalized_documents() {
359 let docs: Vec<Message> = docs.try_into()?;
360 full_history.extend(docs);
361 }
362
363 let chat_history: Vec<Message> = req
364 .chat_history
365 .clone()
366 .into_iter()
367 .map(|message| message.try_into())
368 .collect::<Result<Vec<Vec<Message>>, _>>()?
369 .into_iter()
370 .flatten()
371 .collect();
372
373 full_history.extend(chat_history);
374
375 if full_history.is_empty() {
376 return Err(CompletionError::RequestError(
377 std::io::Error::new(
378 std::io::ErrorKind::InvalidInput,
379 "Mistral request has no provider-compatible messages after conversion",
380 )
381 .into(),
382 ));
383 }
384
385 let tool_choice = req
386 .tool_choice
387 .clone()
388 .map(crate::providers::openai::completion::ToolChoice::try_from)
389 .transpose()?;
390
391 Ok(Self {
392 model: model.to_string(),
393 messages: full_history,
394 temperature: req.temperature,
395 tools: req
396 .tools
397 .clone()
398 .into_iter()
399 .map(ToolDefinition::from)
400 .collect::<Vec<_>>(),
401 tool_choice,
402 additional_params: req.additional_params,
403 })
404 }
405}
406
407impl<T> CompletionModel<T> {
408 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
409 Self {
410 client,
411 model: model.into(),
412 }
413 }
414
415 pub fn with_model(client: Client<T>, model: &str) -> Self {
416 Self {
417 client,
418 model: model.into(),
419 }
420 }
421}
422
/// Raw Mistral chat-completion response payload.
#[derive(Debug, Deserialize, Clone, Serialize)]
pub struct CompletionResponse {
    pub id: String,
    /// Object kind, e.g. `"chat.completion"` (see the test fixture below).
    pub object: String,
    /// Creation time as a Unix timestamp.
    pub created: u64,
    pub model: String,
    pub system_fingerprint: Option<String>,
    pub choices: Vec<Choice>,
    pub usage: Option<Usage>,
}
433
434impl crate::telemetry::ProviderResponseExt for CompletionResponse {
435 type OutputMessage = Choice;
436 type Usage = Usage;
437
438 fn get_response_id(&self) -> Option<String> {
439 Some(self.id.clone())
440 }
441
442 fn get_response_model_name(&self) -> Option<String> {
443 Some(self.model.clone())
444 }
445
446 fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
447 self.choices.clone()
448 }
449
450 fn get_text_response(&self) -> Option<String> {
451 let res = self
452 .choices
453 .iter()
454 .filter_map(|choice| match choice.message {
455 Message::Assistant { ref content, .. } => {
456 if content.is_empty() {
457 None
458 } else {
459 Some(content.to_string())
460 }
461 }
462 _ => None,
463 })
464 .collect::<Vec<String>>()
465 .join("\n");
466
467 if res.is_empty() { None } else { Some(res) }
468 }
469
470 fn get_usage(&self) -> Option<Self::Usage> {
471 self.usage.clone()
472 }
473}
474
475impl GetTokenUsage for CompletionResponse {
476 fn token_usage(&self) -> Option<crate::completion::Usage> {
477 let api_usage = self.usage.clone()?;
478
479 let mut usage = crate::completion::Usage::new();
480 usage.input_tokens = api_usage.prompt_tokens as u64;
481 usage.output_tokens = api_usage.completion_tokens as u64;
482 usage.total_tokens = api_usage.total_tokens as u64;
483
484 Some(usage)
485 }
486}
487
488impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
489 type Error = CompletionError;
490
491 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
492 let choice = response.choices.first().ok_or_else(|| {
493 CompletionError::ResponseError("Response contained no choices".to_owned())
494 })?;
495 let content = match &choice.message {
496 Message::Assistant {
497 content,
498 tool_calls,
499 ..
500 } => {
501 let mut content = if content.is_empty() {
502 vec![]
503 } else {
504 vec![completion::AssistantContent::text(content.clone())]
505 };
506
507 content.extend(
508 tool_calls
509 .iter()
510 .map(|call| {
511 completion::AssistantContent::tool_call(
512 &call.id,
513 &call.function.name,
514 call.function.arguments.clone(),
515 )
516 })
517 .collect::<Vec<_>>(),
518 );
519 Ok(content)
520 }
521 _ => Err(CompletionError::ResponseError(
522 "Response did not contain a valid message or tool call".into(),
523 )),
524 }?;
525
526 let choice = OneOrMany::many(content).map_err(|_| {
527 CompletionError::ResponseError(
528 "Response contained no message or tool call (empty)".to_owned(),
529 )
530 })?;
531
532 let usage = response
533 .usage
534 .as_ref()
535 .map(|usage| completion::Usage {
536 input_tokens: usage.prompt_tokens as u64,
537 output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
538 total_tokens: usage.total_tokens as u64,
539 cached_input_tokens: 0,
540 })
541 .unwrap_or_default();
542
543 Ok(completion::CompletionResponse {
544 choice,
545 usage,
546 raw_response: response,
547 message_id: None,
548 })
549 }
550}
551
552fn assistant_content_to_streaming_choice(
553 content: message::AssistantContent,
554) -> Option<RawStreamingChoice<CompletionResponse>> {
555 match content {
556 message::AssistantContent::Text(t) => Some(RawStreamingChoice::Message(t.text)),
557 message::AssistantContent::ToolCall(tc) => Some(RawStreamingChoice::ToolCall(
558 RawStreamingToolCall::new(tc.id, tc.function.name, tc.function.arguments),
559 )),
560 message::AssistantContent::Reasoning(_) => None,
561 message::AssistantContent::Image(_) => {
562 panic!("Image content is not supported on Mistral via Rig")
563 }
564 }
565}
566
impl<T> completion::CompletionModel for CompletionModel<T>
where
    T: HttpClientExt + Send + Clone + std::fmt::Debug + 'static,
{
    type Response = CompletionResponse;
    type StreamingResponse = CompletionResponse;

    type Client = Client<T>;

    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model.into())
    }

    /// Send a chat-completion request to `v1/chat/completions` and convert
    /// the result into rig's generic response type.
    ///
    /// # Errors
    /// Propagates request-conversion, HTTP, JSON, and provider errors; a
    /// non-success HTTP status is surfaced as a `ProviderError` carrying the
    /// raw body text.
    async fn completion(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
        let preamble = completion_request.preamble.clone();
        let request =
            MistralCompletionRequest::try_from((self.model.as_ref(), completion_request))?;

        // Only pay for pretty-printing when TRACE logging is actually on.
        if enabled!(Level::TRACE) {
            tracing::trace!(
                target: "rig::completions",
                "Mistral completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        // Reuse an active span if one exists; otherwise open a fresh
        // GenAI-convention "chat" span with empty response fields to be
        // recorded after the call.
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "mistral",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = &preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.cached_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        let body = serde_json::to_vec(&request)?;

        let request = self
            .client
            .post("v1/chat/completions")?
            .body(body)
            .map_err(|e| CompletionError::HttpError(e.into()))?;

        async move {
            let response = self.client.send(request).await?;

            if response.status().is_success() {
                let text = http_client::text(response).await?;
                match serde_json::from_str::<ApiResponse<CompletionResponse>>(&text)? {
                    ApiResponse::Ok(response) => {
                        // Record telemetry on the span before converting.
                        let span = tracing::Span::current();
                        span.record_token_usage(&response);
                        span.record_response_metadata(&response);
                        response.try_into()
                    }
                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
                }
            } else {
                let text = http_client::text(response).await?;
                Err(CompletionError::ProviderError(text))
            }
        }
        .instrument(span)
        .await
    }

    /// Pseudo-streaming: performs a full (non-streaming) completion, then
    /// replays each piece of assistant content as a stream item, ending with
    /// the raw response as the final item.
    async fn stream(
        &self,
        request: CompletionRequest,
    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
        let resp = self.completion(request).await?;

        let stream = stream! {
            for c in resp.choice.clone() {
                // Reasoning content maps to None and is skipped.
                if let Some(choice) = assistant_content_to_streaming_choice(c) {
                    yield Ok(choice);
                }
            }

            yield Ok(RawStreamingChoice::FinalResponse(resp.raw_response.clone()));
        };

        Ok(StreamingCompletionResponse::stream(Box::pin(stream)))
    }
}
664
#[cfg(test)]
mod tests {
    use super::*;

    // Deserializes a representative chat-completion payload and checks every
    // top-level field, including the stringified tool-call arguments.
    #[test]
    fn test_response_deserialization() {
        let json_data = r#"
        {
          "id": "cmpl-e5cc70bb28c444948073e77776eb30ef",
          "object": "chat.completion",
          "model": "mistral-small-latest",
          "usage": {
            "prompt_tokens": 16,
            "completion_tokens": 34,
            "total_tokens": 50
          },
          "created": 1702256327,
          "choices": [
            {
              "index": 0,
              "message": {
                "content": "string",
                "tool_calls": [
                  {
                    "id": "null",
                    "type": "function",
                    "function": {
                      "name": "string",
                      "arguments": "{ }"
                    },
                    "index": 0
                  }
                ],
                "prefix": false,
                "role": "assistant"
              },
              "finish_reason": "stop"
            }
          ]
        }
        "#;
        let completion_response = serde_json::from_str::<CompletionResponse>(json_data).unwrap();
        assert_eq!(completion_response.model, MISTRAL_SMALL);

        let CompletionResponse {
            id,
            object,
            created,
            choices,
            usage,
            ..
        } = completion_response;

        assert_eq!(id, "cmpl-e5cc70bb28c444948073e77776eb30ef");

        let Usage {
            completion_tokens,
            prompt_tokens,
            total_tokens,
        } = usage.unwrap();

        assert_eq!(prompt_tokens, 16);
        assert_eq!(completion_tokens, 34);
        assert_eq!(total_tokens, 50);
        assert_eq!(object, "chat.completion".to_string());
        assert_eq!(created, 1702256327);
        assert_eq!(choices.len(), 1);
    }

    // A reasoning-only assistant turn must convert to zero wire messages.
    #[test]
    fn test_assistant_reasoning_is_skipped_in_message_conversion() {
        let assistant = message::Message::Assistant {
            id: None,
            content: OneOrMany::one(message::AssistantContent::reasoning("hidden")),
        };

        let converted: Vec<Message> = assistant.try_into().expect("conversion should work");
        assert!(converted.is_empty());
    }

    // Reasoning is dropped while sibling text and tool calls survive,
    // collapsed into a single assistant message.
    #[test]
    fn test_assistant_text_and_tool_call_are_preserved_when_reasoning_present() {
        let assistant = message::Message::Assistant {
            id: None,
            content: OneOrMany::many(vec![
                message::AssistantContent::reasoning("hidden"),
                message::AssistantContent::text("visible"),
                message::AssistantContent::tool_call(
                    "call_1",
                    "subtract",
                    serde_json::json!({"x": 2, "y": 1}),
                ),
            ])
            .expect("non-empty assistant content"),
        };

        let converted: Vec<Message> = assistant.try_into().expect("conversion should work");
        assert_eq!(converted.len(), 1);

        match &converted[0] {
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                assert_eq!(content, "visible");
                assert_eq!(tool_calls.len(), 1);
                assert_eq!(tool_calls[0].id, "call_1");
                assert_eq!(tool_calls[0].function.name, "subtract");
                assert_eq!(
                    tool_calls[0].function.arguments,
                    serde_json::json!({"x": 2, "y": 1})
                );
            }
            _ => panic!("expected assistant message"),
        }
    }

    // Streaming mapping: reasoning -> None, text -> Message, tool call ->
    // ToolCall with id/name/arguments intact.
    #[test]
    fn test_streaming_choice_mapping_skips_reasoning_and_preserves_other_content() {
        assert!(
            assistant_content_to_streaming_choice(message::AssistantContent::reasoning("hidden"))
                .is_none()
        );

        let text_choice =
            assistant_content_to_streaming_choice(message::AssistantContent::text("visible"))
                .expect("text should be preserved");
        match text_choice {
            RawStreamingChoice::Message(text) => assert_eq!(text, "visible"),
            _ => panic!("expected text streaming choice"),
        }

        let tool_choice =
            assistant_content_to_streaming_choice(message::AssistantContent::tool_call(
                "call_2",
                "add",
                serde_json::json!({"x": 2, "y": 3}),
            ))
            .expect("tool call should be preserved");
        match tool_choice {
            RawStreamingChoice::ToolCall(call) => {
                assert_eq!(call.id, "call_2");
                assert_eq!(call.name, "add");
                assert_eq!(call.arguments, serde_json::json!({"x": 2, "y": 3}));
            }
            _ => panic!("expected tool-call streaming choice"),
        }
    }

    // When every message is filtered out by conversion (here: reasoning-only
    // history, no preamble), request construction must fail rather than send
    // an empty `messages` array.
    #[test]
    fn test_request_conversion_errors_when_all_messages_are_filtered() {
        let request = CompletionRequest {
            preamble: None,
            chat_history: OneOrMany::one(message::Message::Assistant {
                id: None,
                content: OneOrMany::one(message::AssistantContent::reasoning("hidden")),
            }),
            documents: vec![],
            tools: vec![],
            temperature: None,
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
            model: None,
            output_schema: None,
        };

        let result = MistralCompletionRequest::try_from((MISTRAL_SMALL, request));
        assert!(matches!(result, Err(CompletionError::RequestError(_))));
    }
}