rig/providers/mistral/completion.rs

use async_stream::stream;
use serde::{Deserialize, Serialize};
use std::{convert::Infallible, str::FromStr};
use tracing::{Instrument, Level, enabled, info_span};

use super::client::{Client, Usage};
use crate::completion::GetTokenUsage;
use crate::http_client::{self, HttpClientExt};
use crate::streaming::{RawStreamingChoice, RawStreamingToolCall, StreamingCompletionResponse};
use crate::{
    OneOrMany,
    completion::{self, CompletionError, CompletionRequest},
    json_utils, message,
    providers::mistral::client::ApiResponse,
    telemetry::SpanCombinator,
};

/// The latest version of the `codestral` Mistral model
pub const CODESTRAL: &str = "codestral-latest";
/// The latest version of the `mistral-large` Mistral model
pub const MISTRAL_LARGE: &str = "mistral-large-latest";
/// The latest version of the `pixtral-large` Mistral multimodal model
pub const PIXTRAL_LARGE: &str = "pixtral-large-latest";
/// The latest version of the `mistral-saba` Mistral model, trained on datasets from the Middle East & South Asia
pub const MISTRAL_SABA: &str = "mistral-saba-latest";
/// The latest version of the `ministral-3b` Mistral completions model
pub const MINISTRAL_3B: &str = "ministral-3b-latest";
/// The latest version of the `ministral-8b` Mistral completions model
pub const MINISTRAL_8B: &str = "ministral-8b-latest";

/// The latest version of the `mistral-small` Mistral completions model
pub const MISTRAL_SMALL: &str = "mistral-small-latest";
/// The `24-09` version of the `pixtral-12b` Mistral multimodal model
pub const PIXTRAL_SMALL: &str = "pixtral-12b-2409";
/// The `open-mistral-nemo` model
pub const MISTRAL_NEMO: &str = "open-mistral-nemo";
/// The `open-codestral-mamba` model
pub const CODESTRAL_MAMBA: &str = "open-codestral-mamba";
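// Usage sketch: these constants are Mistral model identifiers and can be passed as the
// `model` argument when constructing a completion model, e.g.
// `CompletionModel::new(client, MISTRAL_SMALL)`.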

// =================================================================
// Rig Implementation Types
// =================================================================

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub struct AssistantContent {
    text: String,
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum UserContent {
    Text { text: String },
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Choice {
    pub index: usize,
    pub message: Message,
    pub logprobs: Option<serde_json::Value>,
    pub finish_reason: String,
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    User {
        content: String,
    },
    Assistant {
        content: String,
        #[serde(
            default,
            deserialize_with = "json_utils::null_or_vec",
            skip_serializing_if = "Vec::is_empty"
        )]
        tool_calls: Vec<ToolCall>,
        #[serde(default)]
        prefix: bool,
    },
    System {
        content: String,
    },
    Tool {
        /// The name of the tool that was called
        name: String,
        /// The content of the tool call
        content: String,
        /// The id of the tool call
        tool_call_id: String,
    },
}
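
// Wire-format sketch (derived from the serde attributes above; field values are illustrative,
// not taken from the Mistral docs): a tool message serializes as
// `{"role":"tool","name":"get_weather","content":"42","tool_call_id":"call_1"}`,
// and an assistant message with no tool calls as `{"role":"assistant","content":"Hello","prefix":false}`.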

impl Message {
    pub fn user(content: String) -> Self {
        Message::User { content }
    }

    pub fn assistant(content: String, tool_calls: Vec<ToolCall>, prefix: bool) -> Self {
        Message::Assistant {
            content,
            tool_calls,
            prefix,
        }
    }

    pub fn system(content: String) -> Self {
        Message::System { content }
    }
}

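// Note: the conversion below can expand a single Rig message into several Mistral messages.
// Tool results inside a user message become separate `role: "tool"` messages and are emitted
// ahead of any plain text content from the same message.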
impl TryFrom<message::Message> for Vec<Message> {
    type Error = message::MessageError;

    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
        match message {
            message::Message::User { content } => {
                let mut tool_result_messages = Vec::new();
                let mut other_messages = Vec::new();

                for content_item in content {
                    match content_item {
                        message::UserContent::ToolResult(message::ToolResult {
                            id,
                            call_id,
                            content: tool_content,
                        }) => {
                            let call_id_key = call_id.unwrap_or_else(|| id.clone());
                            let content_text = tool_content
                                .into_iter()
                                .find_map(|content_item| match content_item {
                                    message::ToolResultContent::Text(text) => Some(text.text),
                                    message::ToolResultContent::Image(_) => None,
                                })
                                .unwrap_or_default();
                            tool_result_messages.push(Message::Tool {
                                name: id,
                                content: content_text,
                                tool_call_id: call_id_key,
                            });
                        }
                        message::UserContent::Text(message::Text { text }) => {
                            other_messages.push(Message::User { content: text });
                        }
                        _ => {}
                    }
                }

                tool_result_messages.append(&mut other_messages);
                Ok(tool_result_messages)
            }
            message::Message::Assistant { content, .. } => {
                let (text_content, tool_calls) = content.into_iter().fold(
                    (Vec::new(), Vec::new()),
                    |(mut texts, mut tools), content| {
                        match content {
                            message::AssistantContent::Text(text) => texts.push(text),
                            message::AssistantContent::ToolCall(tool_call) => tools.push(tool_call),
                            message::AssistantContent::Reasoning(_) => {
                                panic!("Reasoning content is not currently supported on Mistral via Rig");
                            }
                            message::AssistantContent::Image(_) => {
                                panic!("Image content is not currently supported on Mistral via Rig");
                            }
                        }
                        (texts, tools)
                    },
                );

                Ok(vec![Message::Assistant {
                    content: text_content
                        .into_iter()
                        .next()
                        .map(|content| content.text)
                        .unwrap_or_default(),
                    tool_calls: tool_calls
                        .into_iter()
                        .map(|tool_call| tool_call.into())
                        .collect::<Vec<_>>(),
                    prefix: false,
                }])
            }
        }
    }
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolCall {
    pub id: String,
    #[serde(default)]
    pub r#type: ToolType,
    pub function: Function,
}

impl From<message::ToolCall> for ToolCall {
    fn from(tool_call: message::ToolCall) -> Self {
        Self {
            id: tool_call.id,
            r#type: ToolType::default(),
            function: Function {
                name: tool_call.function.name,
                arguments: tool_call.function.arguments,
            },
        }
    }
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct Function {
    pub name: String,
    #[serde(with = "json_utils::stringified_json")]
    pub arguments: serde_json::Value,
}

#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}

#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ToolDefinition {
    pub r#type: String,
    pub function: completion::ToolDefinition,
}

impl From<completion::ToolDefinition> for ToolDefinition {
    fn from(tool: completion::ToolDefinition) -> Self {
        Self {
            r#type: "function".into(),
            function: tool,
        }
    }
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolResultContent {
    #[serde(default)]
    r#type: ToolResultContentType,
    text: String,
}

#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolResultContentType {
    #[default]
    Text,
}

impl From<String> for ToolResultContent {
    fn from(s: String) -> Self {
        ToolResultContent {
            r#type: ToolResultContentType::default(),
            text: s,
        }
    }
}

impl From<String> for UserContent {
    fn from(s: String) -> Self {
        UserContent::Text { text: s }
    }
}

impl FromStr for UserContent {
    type Err = Infallible;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Ok(UserContent::Text {
            text: s.to_string(),
        })
    }
}

impl From<String> for AssistantContent {
    fn from(s: String) -> Self {
        AssistantContent { text: s }
    }
}

impl FromStr for AssistantContent {
    type Err = Infallible;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Ok(AssistantContent {
            text: s.to_string(),
        })
    }
}

#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    pub(crate) client: Client<T>,
    pub model: String,
}

#[derive(Debug, Default, Serialize, Deserialize)]
pub enum ToolChoice {
    #[default]
    Auto,
    None,
    Any,
}

impl TryFrom<message::ToolChoice> for ToolChoice {
    type Error = CompletionError;

    fn try_from(value: message::ToolChoice) -> Result<Self, Self::Error> {
        let res = match value {
            message::ToolChoice::Auto => Self::Auto,
            message::ToolChoice::None => Self::None,
            message::ToolChoice::Required => Self::Any,
            message::ToolChoice::Specific { .. } => {
                return Err(CompletionError::ProviderError(
                    "Mistral doesn't support requiring specific tools to be called".to_string(),
                ));
            }
        };

        Ok(res)
    }
}

#[derive(Debug, Serialize, Deserialize)]
pub(super) struct MistralCompletionRequest {
    model: String,
    pub messages: Vec<Message>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<ToolDefinition>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<crate::providers::openai::completion::ToolChoice>,
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub additional_params: Option<serde_json::Value>,
}
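
// Serialized body sketch (derived from the serde attributes above; values are illustrative):
// `{"model":"mistral-small-latest","messages":[...],"temperature":0.7}` plus `tools` and
// `tool_choice` when set. Unset optional fields are omitted entirely, and `additional_params`
// is flattened into the top-level object.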

impl TryFrom<(&str, CompletionRequest)> for MistralCompletionRequest {
    type Error = CompletionError;

    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
        let mut full_history: Vec<Message> = match &req.preamble {
            Some(preamble) => vec![Message::system(preamble.clone())],
            None => vec![],
        };
        if let Some(docs) = req.normalized_documents() {
            let docs: Vec<Message> = docs.try_into()?;
            full_history.extend(docs);
        }

        let chat_history: Vec<Message> = req
            .chat_history
            .clone()
            .into_iter()
            .map(|message| message.try_into())
            .collect::<Result<Vec<Vec<Message>>, _>>()?
            .into_iter()
            .flatten()
            .collect();

        full_history.extend(chat_history);

        let tool_choice = req
            .tool_choice
            .clone()
            .map(crate::providers::openai::completion::ToolChoice::try_from)
            .transpose()?;

        Ok(Self {
            model: model.to_string(),
            messages: full_history,
            temperature: req.temperature,
            tools: req
                .tools
                .clone()
                .into_iter()
                .map(ToolDefinition::from)
                .collect::<Vec<_>>(),
            tool_choice,
            additional_params: req.additional_params,
        })
    }
}

impl<T> CompletionModel<T> {
    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
        Self {
            client,
            model: model.into(),
        }
    }

    pub fn with_model(client: Client<T>, model: &str) -> Self {
        Self {
            client,
            model: model.into(),
        }
    }
}
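
// Construction sketch (illustrative; assumes an already-configured Mistral `Client` and a
// rig `CompletionRequest` named `request`, built elsewhere):
//   let model = CompletionModel::new(client, MISTRAL_SMALL);
//   let response = model.completion(request).await?;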

#[derive(Debug, Deserialize, Clone, Serialize)]
pub struct CompletionResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub system_fingerprint: Option<String>,
    pub choices: Vec<Choice>,
    pub usage: Option<Usage>,
}

impl crate::telemetry::ProviderResponseExt for CompletionResponse {
    type OutputMessage = Choice;
    type Usage = Usage;

    fn get_response_id(&self) -> Option<String> {
        Some(self.id.clone())
    }

    fn get_response_model_name(&self) -> Option<String> {
        Some(self.model.clone())
    }

    fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
        self.choices.clone()
    }

    fn get_text_response(&self) -> Option<String> {
        let res = self
            .choices
            .iter()
            .filter_map(|choice| match choice.message {
                Message::Assistant { ref content, .. } => {
                    if content.is_empty() {
                        None
                    } else {
                        Some(content.to_string())
                    }
                }
                _ => None,
            })
            .collect::<Vec<String>>()
            .join("\n");

        if res.is_empty() { None } else { Some(res) }
    }

    fn get_usage(&self) -> Option<Self::Usage> {
        self.usage.clone()
    }
}

impl GetTokenUsage for CompletionResponse {
    fn token_usage(&self) -> Option<crate::completion::Usage> {
        let api_usage = self.usage.clone()?;

        let mut usage = crate::completion::Usage::new();
        usage.input_tokens = api_usage.prompt_tokens as u64;
        usage.output_tokens = api_usage.completion_tokens as u64;
        usage.total_tokens = api_usage.total_tokens as u64;

        Some(usage)
    }
}

impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
    type Error = CompletionError;

    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
        let choice = response.choices.first().ok_or_else(|| {
            CompletionError::ResponseError("Response contained no choices".to_owned())
        })?;
        let content = match &choice.message {
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                let mut content = if content.is_empty() {
                    vec![]
                } else {
                    vec![completion::AssistantContent::text(content.clone())]
                };

                content.extend(
                    tool_calls
                        .iter()
                        .map(|call| {
                            completion::AssistantContent::tool_call(
                                &call.id,
                                &call.function.name,
                                call.function.arguments.clone(),
                            )
                        })
                        .collect::<Vec<_>>(),
                );
                Ok(content)
            }
            _ => Err(CompletionError::ResponseError(
                "Response did not contain a valid message or tool call".into(),
            )),
        }?;

        let choice = OneOrMany::many(content).map_err(|_| {
            CompletionError::ResponseError(
                "Response contained no message or tool call (empty)".to_owned(),
            )
        })?;

        let usage = response
            .usage
            .as_ref()
            .map(|usage| completion::Usage {
                input_tokens: usage.prompt_tokens as u64,
                output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
                total_tokens: usage.total_tokens as u64,
            })
            .unwrap_or_default();

        Ok(completion::CompletionResponse {
            choice,
            usage,
            raw_response: response,
        })
    }
}

impl<T> completion::CompletionModel for CompletionModel<T>
where
    T: HttpClientExt + Send + Clone + std::fmt::Debug + 'static,
{
    type Response = CompletionResponse;
    type StreamingResponse = CompletionResponse;

    type Client = Client<T>;

    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model.into())
    }

    async fn completion(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
        let preamble = completion_request.preamble.clone();
        let request =
            MistralCompletionRequest::try_from((self.model.as_ref(), completion_request))?;

        if enabled!(Level::TRACE) {
            tracing::trace!(
                target: "rig::completions",
                "Mistral completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "mistral",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = &preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        let body = serde_json::to_vec(&request)?;

        let request = self
            .client
            .post("v1/chat/completions")?
            .body(body)
            .map_err(|e| CompletionError::HttpError(e.into()))?;

        async move {
            let response = self.client.send(request).await?;

            if response.status().is_success() {
                let text = http_client::text(response).await?;
                match serde_json::from_str::<ApiResponse<CompletionResponse>>(&text)? {
                    ApiResponse::Ok(response) => {
                        let span = tracing::Span::current();
                        span.record_token_usage(&response);
                        span.record_response_metadata(&response);
                        response.try_into()
                    }
                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
                }
            } else {
                let text = http_client::text(response).await?;
                Err(CompletionError::ProviderError(text))
            }
        }
        .instrument(span)
        .await
    }

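    // Note: true server-side streaming is not wired up here. `stream` runs a full
    // (non-streaming) completion and then replays the resulting choices as a single pass
    // over the stream, ending with the raw response as the final item.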
    async fn stream(
        &self,
        request: CompletionRequest,
    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
        let resp = self.completion(request).await?;

        let stream = stream! {
            for c in resp.choice.clone() {
                match c {
                    message::AssistantContent::Text(t) => {
                        yield Ok(RawStreamingChoice::Message(t.text.clone()))
                    }
                    message::AssistantContent::ToolCall(tc) => {
                        yield Ok(RawStreamingChoice::ToolCall(
                                RawStreamingToolCall::new(
                                    tc.id.clone(),
                                    tc.function.name.clone(),
                                    tc.function.arguments.clone(),
                                )
                        ))
                    }
                    message::AssistantContent::Reasoning(_) => {
                        panic!("Reasoning is not supported on Mistral via Rig")
                    }
                    message::AssistantContent::Image(_) => {
                        panic!("Image content is not supported on Mistral via Rig")
                    }
                }
            }

            yield Ok(RawStreamingChoice::FinalResponse(resp.raw_response.clone()));
        };

        Ok(StreamingCompletionResponse::stream(Box::pin(stream)))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_response_deserialization() {
        // https://docs.mistral.ai/api/#tag/chat/operation/chat_completion_v1_chat_completions_post
        let json_data = r#"
        {
            "id": "cmpl-e5cc70bb28c444948073e77776eb30ef",
            "object": "chat.completion",
            "model": "mistral-small-latest",
            "usage": {
                "prompt_tokens": 16,
                "completion_tokens": 34,
                "total_tokens": 50
            },
            "created": 1702256327,
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "content": "string",
                        "tool_calls": [
                            {
                                "id": "null",
                                "type": "function",
                                "function": {
                                    "name": "string",
                                    "arguments": "{ }"
                                },
                                "index": 0
                            }
                        ],
                        "prefix": false,
                        "role": "assistant"
                    },
                    "finish_reason": "stop"
                }
            ]
        }
        "#;
        let completion_response = serde_json::from_str::<CompletionResponse>(json_data).unwrap();
        assert_eq!(completion_response.model, MISTRAL_SMALL);

        let CompletionResponse {
            id,
            object,
            created,
            choices,
            usage,
            ..
        } = completion_response;

        assert_eq!(id, "cmpl-e5cc70bb28c444948073e77776eb30ef");

        let Usage {
            completion_tokens,
            prompt_tokens,
            total_tokens,
        } = usage.unwrap();

        assert_eq!(prompt_tokens, 16);
        assert_eq!(completion_tokens, 34);
        assert_eq!(total_tokens, 50);
        assert_eq!(object, "chat.completion".to_string());
        assert_eq!(created, 1702256327);
        assert_eq!(choices.len(), 1);
    }
}