rig/providers/mistral/
completion.rs

1use async_stream::stream;
2use serde::{Deserialize, Serialize};
3use std::{convert::Infallible, str::FromStr};
4use tracing::{Instrument, info_span};
5
6use super::client::{Client, Usage};
7use crate::completion::GetTokenUsage;
8use crate::http_client::{self, HttpClientExt};
9use crate::streaming::{RawStreamingChoice, StreamingCompletionResponse};
10use crate::{
11    OneOrMany,
12    completion::{self, CompletionError, CompletionRequest},
13    json_utils, message,
14    providers::mistral::client::ApiResponse,
15    telemetry::SpanCombinator,
16};
17
/// The latest version of the `codestral` Mistral model
pub const CODESTRAL: &str = "codestral-latest";
/// The latest version of the `mistral-large` Mistral model
pub const MISTRAL_LARGE: &str = "mistral-large-latest";
/// The latest version of the `pixtral-large` Mistral multimodal model
pub const PIXTRAL_LARGE: &str = "pixtral-large-latest";
/// The latest version of the `mistral-saba` Mistral model, trained on datasets from the Middle East & South Asia
pub const MISTRAL_SABA: &str = "mistral-saba-latest";
/// The latest version of the `ministral-3b` Mistral completions model
pub const MINISTRAL_3B: &str = "ministral-3b-latest";
/// The latest version of the `ministral-8b` Mistral completions model
pub const MINISTRAL_8B: &str = "ministral-8b-latest";

/// The latest version of the `mistral-small` Mistral completions model
pub const MISTRAL_SMALL: &str = "mistral-small-latest";
/// The `24-09` version of the `pixtral-12b` Mistral multimodal model
pub const PIXTRAL_SMALL: &str = "pixtral-12b-2409";
/// The `open-mistral-nemo` model
pub const MISTRAL_NEMO: &str = "open-mistral-nemo";
/// The `open-codestral-mamba` model
pub const CODESTRAL_MAMBA: &str = "open-codestral-mamba";
39
40// =================================================================
41// Rig Implementation Types
42// =================================================================
43
/// A fragment of assistant-authored text in Mistral's wire format.
///
/// NOTE(review): `#[serde(tag = "type")]` on a struct emits the struct's
/// own name as the tag value — confirm the resulting `"type"` value matches
/// what the Mistral API expects for assistant content.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub struct AssistantContent {
    // Raw message text.
    text: String,
}
49
/// User-supplied content parts in Mistral's wire format, tagged by `type`
/// (currently only `{"type": "text", "text": "..."}`).
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum UserContent {
    /// Plain text content.
    Text { text: String },
}
55
/// One completion choice returned by the Mistral chat API.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Choice {
    /// Zero-based position of this choice in the response.
    pub index: usize,
    /// The generated message for this choice.
    pub message: Message,
    /// Token log-probabilities, if requested; kept as raw JSON.
    pub logprobs: Option<serde_json::Value>,
    /// Why generation stopped (e.g. "stop" in the sample payload below).
    pub finish_reason: String,
}
63
/// A chat message in Mistral's wire format, tagged by `role`
/// ("user" / "assistant" / "system").
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    User {
        content: String,
    },
    Assistant {
        content: String,
        /// Tool calls requested by the model. The custom deserializer
        /// (`json_utils::null_or_vec`) tolerates a `null` payload; empty
        /// lists are omitted when serializing.
        #[serde(
            default,
            deserialize_with = "json_utils::null_or_vec",
            skip_serializing_if = "Vec::is_empty"
        )]
        tool_calls: Vec<ToolCall>,
        /// Mistral's assistant-prefix (prefill) flag; defaults to `false`.
        #[serde(default)]
        prefix: bool,
    },
    System {
        content: String,
    },
}
85
86impl Message {
87    pub fn user(content: String) -> Self {
88        Message::User { content }
89    }
90
91    pub fn assistant(content: String, tool_calls: Vec<ToolCall>, prefix: bool) -> Self {
92        Message::Assistant {
93            content,
94            tool_calls,
95            prefix,
96        }
97    }
98
99    pub fn system(content: String) -> Self {
100        Message::System { content }
101    }
102}
103
/// Expands a single Rig [`message::Message`] into the (possibly several)
/// Mistral wire messages it corresponds to.
impl TryFrom<message::Message> for Vec<Message> {
    type Error = message::MessageError;

    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
        match message {
            message::Message::User { content } => {
                // Partition off tool-result parts; only the remainder is kept.
                // NOTE(review): tool results (and any non-text user content)
                // are silently dropped here — confirm that is intentional for
                // this provider.
                let (_, other_content): (Vec<_>, Vec<_>) = content
                    .into_iter()
                    .partition(|content| matches!(content, message::UserContent::ToolResult(_)));

                // Each remaining text part becomes its own user message.
                let messages = other_content
                    .into_iter()
                    .filter_map(|content| match content {
                        message::UserContent::Text(message::Text { text }) => {
                            Some(Message::User { content: text })
                        }
                        _ => None,
                    })
                    .collect::<Vec<_>>();

                Ok(messages)
            }
            message::Message::Assistant { content, .. } => {
                // Split assistant content into text fragments and tool calls.
                // NOTE(review): reasoning/image content panics via
                // `unimplemented!` rather than returning a recoverable error.
                let (text_content, tool_calls) = content.into_iter().fold(
                    (Vec::new(), Vec::new()),
                    |(mut texts, mut tools), content| {
                        match content {
                            message::AssistantContent::Text(text) => texts.push(text),
                            message::AssistantContent::ToolCall(tool_call) => tools.push(tool_call),
                            message::AssistantContent::Reasoning(_) => {
                                unimplemented!("Reasoning content is not currently supported on Mistral via Rig");
                            }
                            message::AssistantContent::Image(_) => {
                                unimplemented!("Image content is not currently supported on Mistral via Rig");
                            }
                        }
                        (texts, tools)
                    },
                );

                // Mistral's assistant message holds one content string: only
                // the first text fragment is kept (any extras are discarded).
                Ok(vec![Message::Assistant {
                    content: text_content
                        .into_iter()
                        .next()
                        .map(|content| content.text)
                        .unwrap_or_default(),
                    tool_calls: tool_calls
                        .into_iter()
                        .map(|tool_call| tool_call.into())
                        .collect::<Vec<_>>(),
                    prefix: false,
                }])
            }
        }
    }
}
160
/// A tool invocation emitted by the model.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolCall {
    /// Provider-assigned identifier correlating the call with its result.
    pub id: String,
    // Defaults to `function` when absent from the payload.
    #[serde(default)]
    pub r#type: ToolType,
    /// The function name and arguments to invoke.
    pub function: Function,
}
168
169impl From<message::ToolCall> for ToolCall {
170    fn from(tool_call: message::ToolCall) -> Self {
171        Self {
172            id: tool_call.id,
173            r#type: ToolType::default(),
174            function: Function {
175                name: tool_call.function.name,
176                arguments: tool_call.function.arguments,
177            },
178        }
179    }
180}
181
/// The callable target of a [`ToolCall`].
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct Function {
    pub name: String,
    /// JSON arguments; transported as a stringified JSON blob on the wire
    /// (see `json_utils::stringified_json`).
    #[serde(with = "json_utils::stringified_json")]
    pub arguments: serde_json::Value,
}
188
/// Kind of tool; only `function` is defined here.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}
195
/// A tool made available to the model, wrapped in Mistral's
/// `{ "type": "...", "function": { ... } }` envelope.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ToolDefinition {
    pub r#type: String,
    pub function: completion::ToolDefinition,
}
201
202impl From<completion::ToolDefinition> for ToolDefinition {
203    fn from(tool: completion::ToolDefinition) -> Self {
204        Self {
205            r#type: "function".into(),
206            function: tool,
207        }
208    }
209}
210
/// Content of a tool result (currently always text).
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolResultContent {
    // Defaults to `text` when absent from the payload.
    #[serde(default)]
    r#type: ToolResultContentType,
    text: String,
}
217
/// Kind of tool-result content; only `text` is defined.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolResultContentType {
    #[default]
    Text,
}
224
225impl From<String> for ToolResultContent {
226    fn from(s: String) -> Self {
227        ToolResultContent {
228            r#type: ToolResultContentType::default(),
229            text: s,
230        }
231    }
232}
233
234impl From<String> for UserContent {
235    fn from(s: String) -> Self {
236        UserContent::Text { text: s }
237    }
238}
239
240impl FromStr for UserContent {
241    type Err = Infallible;
242
243    fn from_str(s: &str) -> Result<Self, Self::Err> {
244        Ok(UserContent::Text {
245            text: s.to_string(),
246        })
247    }
248}
249
250impl From<String> for AssistantContent {
251    fn from(s: String) -> Self {
252        AssistantContent { text: s }
253    }
254}
255
256impl FromStr for AssistantContent {
257    type Err = Infallible;
258
259    fn from_str(s: &str) -> Result<Self, Self::Err> {
260        Ok(AssistantContent {
261            text: s.to_string(),
262        })
263    }
264}
265
/// A Mistral completion model handle: an API client plus a model identifier.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    // Shared API client; transport is generic, defaulting to reqwest.
    pub(crate) client: Client<T>,
    /// Model name sent with each request, e.g. [`MISTRAL_SMALL`].
    pub model: String,
}
271
272#[derive(Debug, Default, Serialize, Deserialize)]
273pub enum ToolChoice {
274    #[default]
275    Auto,
276    None,
277    Any,
278}
279
280impl TryFrom<message::ToolChoice> for ToolChoice {
281    type Error = CompletionError;
282
283    fn try_from(value: message::ToolChoice) -> Result<Self, Self::Error> {
284        let res = match value {
285            message::ToolChoice::Auto => Self::Auto,
286            message::ToolChoice::None => Self::None,
287            message::ToolChoice::Required => Self::Any,
288            message::ToolChoice::Specific { .. } => {
289                return Err(CompletionError::ProviderError(
290                    "Mistral doesn't support requiring specific tools to be called".to_string(),
291                ));
292            }
293        };
294
295        Ok(res)
296    }
297}
298
299#[derive(Debug, Serialize, Deserialize)]
300pub(super) struct MistralCompletionRequest {
301    model: String,
302    pub messages: Vec<Message>,
303    #[serde(flatten, skip_serializing_if = "Option::is_none")]
304    temperature: Option<f64>,
305    #[serde(skip_serializing_if = "Vec::is_empty")]
306    tools: Vec<ToolDefinition>,
307    #[serde(flatten, skip_serializing_if = "Option::is_none")]
308    tool_choice: Option<crate::providers::openai::completion::ToolChoice>,
309    #[serde(flatten, skip_serializing_if = "Option::is_none")]
310    pub additional_params: Option<serde_json::Value>,
311}
312
313impl TryFrom<(&str, CompletionRequest)> for MistralCompletionRequest {
314    type Error = CompletionError;
315
316    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
317        let mut full_history: Vec<Message> = match &req.preamble {
318            Some(preamble) => vec![Message::system(preamble.clone())],
319            None => vec![],
320        };
321        if let Some(docs) = req.normalized_documents() {
322            let docs: Vec<Message> = docs.try_into()?;
323            full_history.extend(docs);
324        }
325
326        let chat_history: Vec<Message> = req
327            .chat_history
328            .clone()
329            .into_iter()
330            .map(|message| message.try_into())
331            .collect::<Result<Vec<Vec<Message>>, _>>()?
332            .into_iter()
333            .flatten()
334            .collect();
335
336        full_history.extend(chat_history);
337
338        let tool_choice = req
339            .tool_choice
340            .clone()
341            .map(crate::providers::openai::completion::ToolChoice::try_from)
342            .transpose()?;
343
344        Ok(Self {
345            model: model.to_string(),
346            messages: full_history,
347            temperature: req.temperature,
348            tools: req
349                .tools
350                .clone()
351                .into_iter()
352                .map(ToolDefinition::from)
353                .collect::<Vec<_>>(),
354            tool_choice,
355            additional_params: req.additional_params,
356        })
357    }
358}
359
360impl<T> CompletionModel<T> {
361    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
362        Self {
363            client,
364            model: model.into(),
365        }
366    }
367
368    pub fn with_model(client: Client<T>, model: &str) -> Self {
369        Self {
370            client,
371            model: model.into(),
372        }
373    }
374}
375
/// Raw (deserialized) response from Mistral's chat completion endpoint.
#[derive(Debug, Deserialize, Clone, Serialize)]
pub struct CompletionResponse {
    pub id: String,
    /// Object kind, e.g. "chat.completion".
    pub object: String,
    /// Unix timestamp (seconds) at which the completion was created.
    pub created: u64,
    pub model: String,
    pub system_fingerprint: Option<String>,
    /// One or more generated choices; the first is used downstream.
    pub choices: Vec<Choice>,
    /// Token accounting; optional in the payload.
    pub usage: Option<Usage>,
}
386
387impl crate::telemetry::ProviderResponseExt for CompletionResponse {
388    type OutputMessage = Choice;
389    type Usage = Usage;
390
391    fn get_response_id(&self) -> Option<String> {
392        Some(self.id.clone())
393    }
394
395    fn get_response_model_name(&self) -> Option<String> {
396        Some(self.model.clone())
397    }
398
399    fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
400        self.choices.clone()
401    }
402
403    fn get_text_response(&self) -> Option<String> {
404        let res = self
405            .choices
406            .iter()
407            .filter_map(|choice| match choice.message {
408                Message::Assistant { ref content, .. } => {
409                    if content.is_empty() {
410                        None
411                    } else {
412                        Some(content.to_string())
413                    }
414                }
415                _ => None,
416            })
417            .collect::<Vec<String>>()
418            .join("\n");
419
420        if res.is_empty() { None } else { Some(res) }
421    }
422
423    fn get_usage(&self) -> Option<Self::Usage> {
424        self.usage.clone()
425    }
426}
427
428impl GetTokenUsage for CompletionResponse {
429    fn token_usage(&self) -> Option<crate::completion::Usage> {
430        let api_usage = self.usage.clone()?;
431
432        let mut usage = crate::completion::Usage::new();
433        usage.input_tokens = api_usage.prompt_tokens as u64;
434        usage.output_tokens = api_usage.completion_tokens as u64;
435        usage.total_tokens = api_usage.total_tokens as u64;
436
437        Some(usage)
438    }
439}
440
441impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
442    type Error = CompletionError;
443
444    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
445        let choice = response.choices.first().ok_or_else(|| {
446            CompletionError::ResponseError("Response contained no choices".to_owned())
447        })?;
448        let content = match &choice.message {
449            Message::Assistant {
450                content,
451                tool_calls,
452                ..
453            } => {
454                let mut content = if content.is_empty() {
455                    vec![]
456                } else {
457                    vec![completion::AssistantContent::text(content.clone())]
458                };
459
460                content.extend(
461                    tool_calls
462                        .iter()
463                        .map(|call| {
464                            completion::AssistantContent::tool_call(
465                                &call.id,
466                                &call.function.name,
467                                call.function.arguments.clone(),
468                            )
469                        })
470                        .collect::<Vec<_>>(),
471                );
472                Ok(content)
473            }
474            _ => Err(CompletionError::ResponseError(
475                "Response did not contain a valid message or tool call".into(),
476            )),
477        }?;
478
479        let choice = OneOrMany::many(content).map_err(|_| {
480            CompletionError::ResponseError(
481                "Response contained no message or tool call (empty)".to_owned(),
482            )
483        })?;
484
485        let usage = response
486            .usage
487            .as_ref()
488            .map(|usage| completion::Usage {
489                input_tokens: usage.prompt_tokens as u64,
490                output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
491                total_tokens: usage.total_tokens as u64,
492            })
493            .unwrap_or_default();
494
495        Ok(completion::CompletionResponse {
496            choice,
497            usage,
498            raw_response: response,
499        })
500    }
501}
502
impl<T> completion::CompletionModel for CompletionModel<T>
where
    T: HttpClientExt + Send + Clone + std::fmt::Debug + 'static,
{
    type Response = CompletionResponse;
    type StreamingResponse = CompletionResponse;

    type Client = Client<T>;

    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model.into())
    }

    /// Sends a chat-completion request to Mistral's `v1/chat/completions`
    /// endpoint and converts the reply into Rig's generic response type,
    /// recording telemetry on the active tracing span.
    #[cfg_attr(feature = "worker", worker::send)]
    async fn completion(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
        let preamble = completion_request.preamble.clone();
        let request =
            MistralCompletionRequest::try_from((self.model.as_ref(), completion_request))?;

        // Only open a fresh telemetry span when no caller span is active,
        // so nested instrumentation isn't duplicated.
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "mistral",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = &preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.input.messages = serde_json::to_string(&request.messages)?,
                gen_ai.output.messages = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        let body = serde_json::to_vec(&request)?;

        let request = self
            .client
            .post("v1/chat/completions")?
            .body(body)
            .map_err(|e| CompletionError::HttpError(e.into()))?;

        async move {
            let response = self.client.send(request).await?;

            if response.status().is_success() {
                let text = http_client::text(response).await?;
                match serde_json::from_str::<ApiResponse<CompletionResponse>>(&text)? {
                    ApiResponse::Ok(response) => {
                        // Record usage / output / metadata on the active span
                        // before converting the response for the caller.
                        let span = tracing::Span::current();
                        span.record_token_usage(&response);
                        span.record_model_output(&response.choices);
                        span.record_response_metadata(&response);
                        response.try_into()
                    }
                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
                }
            } else {
                // Non-2xx: surface the raw body as a provider error.
                let text = http_client::text(response).await?;
                Err(CompletionError::ProviderError(text))
            }
        }
        .instrument(span)
        .await
    }

    /// Pseudo-streaming: performs a full (non-streaming) completion, then
    /// replays the finished choice as stream events. True server-sent-event
    /// streaming is not implemented here.
    #[cfg_attr(feature = "worker", worker::send)]
    async fn stream(
        &self,
        request: CompletionRequest,
    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
        let resp = self.completion(request).await?;

        let stream = stream! {
            for c in resp.choice.clone() {
                match c {
                    message::AssistantContent::Text(t) => {
                        yield Ok(RawStreamingChoice::Message(t.text.clone()))
                    }
                    message::AssistantContent::ToolCall(tc) => {
                        yield Ok(RawStreamingChoice::ToolCall {
                            id: tc.id.clone(),
                            name: tc.function.name.clone(),
                            arguments: tc.function.arguments.clone(),
                             call_id: None
                        })
                    }
                    // NOTE(review): these two arms panic at runtime rather
                    // than yielding an error through the stream.
                    message::AssistantContent::Reasoning(_) => {
                        unimplemented!("Reasoning is not supported on Mistral via Rig")
                    }
                    message::AssistantContent::Image(_) => {
                        unimplemented!("Image content is not supported on Mistral via Rig")
                    }
                }
            }

            // Emit the raw response once all content has been replayed.
            yield Ok(RawStreamingChoice::FinalResponse(resp.raw_response.clone()));
        };

        Ok(StreamingCompletionResponse::stream(Box::pin(stream)))
    }
}
612
#[cfg(test)]
mod tests {
    use super::*;

    /// Deserializes the sample payload from Mistral's API docs and checks
    /// every field of [`CompletionResponse`] survives the round trip.
    #[test]
    fn test_response_deserialization() {
        //https://docs.mistral.ai/api/#tag/chat/operation/chat_completion_v1_chat_completions_post
        let json_data = r#"
        {
            "id": "cmpl-e5cc70bb28c444948073e77776eb30ef",
            "object": "chat.completion",
            "model": "mistral-small-latest",
            "usage": {
                "prompt_tokens": 16,
                "completion_tokens": 34,
                "total_tokens": 50
            },
            "created": 1702256327,
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "content": "string",
                        "tool_calls": [
                            {
                                "id": "null",
                                "type": "function",
                                "function": {
                                    "name": "string",
                                    "arguments": "{ }"
                                },
                                "index": 0
                            }
                        ],
                        "prefix": false,
                        "role": "assistant"
                    },
                    "finish_reason": "stop"
                }
            ]
        }
        "#;
        let completion_response = serde_json::from_str::<CompletionResponse>(json_data).unwrap();
        assert_eq!(completion_response.model, MISTRAL_SMALL);

        let CompletionResponse {
            id,
            object,
            created,
            choices,
            usage,
            ..
        } = completion_response;

        assert_eq!(id, "cmpl-e5cc70bb28c444948073e77776eb30ef");

        // Usage numbers must match the fixture exactly.
        let Usage {
            completion_tokens,
            prompt_tokens,
            total_tokens,
        } = usage.unwrap();

        assert_eq!(prompt_tokens, 16);
        assert_eq!(completion_tokens, 34);
        assert_eq!(total_tokens, 50);
        assert_eq!(object, "chat.completion".to_string());
        assert_eq!(created, 1702256327);
        assert_eq!(choices.len(), 1);
    }
}