// rig/providers/mistral/completion.rs

1use async_stream::stream;
2use serde::{Deserialize, Serialize};
3use serde_json::{Value, json};
4use std::{convert::Infallible, str::FromStr};
5
6use super::client::{Client, Usage};
7use crate::completion::GetTokenUsage;
8use crate::streaming::{RawStreamingChoice, StreamingCompletionResponse};
9use crate::{
10    OneOrMany,
11    completion::{self, CompletionError, CompletionRequest},
12    json_utils, message,
13    providers::mistral::client::ApiResponse,
14};
15
// Mistral model identifiers. The chosen id is sent verbatim as the `model`
// field of every chat-completion request built by `CompletionModel`.
pub const CODESTRAL: &str = "codestral-latest";
pub const MISTRAL_LARGE: &str = "mistral-large-latest";
pub const PIXTRAL_LARGE: &str = "pixtral-large-latest";
pub const MISTRAL_SABA: &str = "mistral-saba-latest";
pub const MINISTRAL_3B: &str = "ministral-3b-latest";
pub const MINISTRAL_8B: &str = "ministral-8b-latest";

// Free models
pub const MISTRAL_SMALL: &str = "mistral-small-latest";
pub const PIXTRAL_SMALL: &str = "pixtral-12b-2409";
pub const MISTRAL_NEMO: &str = "open-mistral-nemo";
pub const CODESTRAL_MAMBA: &str = "open-codestral-mamba";
28
29// =================================================================
30// Rig Implementation Types
31// =================================================================
32
33#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
34#[serde(tag = "type", rename_all = "lowercase")]
35pub struct AssistantContent {
36    text: String,
37}
38
39#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
40#[serde(tag = "type", rename_all = "lowercase")]
41pub enum UserContent {
42    Text { text: String },
43}
44
45#[derive(Debug, Serialize, Deserialize, Clone)]
46pub struct Choice {
47    pub index: usize,
48    pub message: Message,
49    pub logprobs: Option<serde_json::Value>,
50    pub finish_reason: String,
51}
52
53#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
54#[serde(tag = "role", rename_all = "lowercase")]
55pub enum Message {
56    User {
57        content: String,
58    },
59    Assistant {
60        content: String,
61        #[serde(
62            default,
63            deserialize_with = "json_utils::null_or_vec",
64            skip_serializing_if = "Vec::is_empty"
65        )]
66        tool_calls: Vec<ToolCall>,
67        #[serde(default)]
68        prefix: bool,
69    },
70    System {
71        content: String,
72    },
73}
74
75impl Message {
76    pub fn user(content: String) -> Self {
77        Message::User { content }
78    }
79
80    pub fn assistant(content: String, tool_calls: Vec<ToolCall>, prefix: bool) -> Self {
81        Message::Assistant {
82            content,
83            tool_calls,
84            prefix,
85        }
86    }
87
88    pub fn system(content: String) -> Self {
89        Message::System { content }
90    }
91}
92
93impl TryFrom<message::Message> for Vec<Message> {
94    type Error = message::MessageError;
95
96    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
97        match message {
98            message::Message::User { content } => {
99                let (_, other_content): (Vec<_>, Vec<_>) = content
100                    .into_iter()
101                    .partition(|content| matches!(content, message::UserContent::ToolResult(_)));
102
103                let messages = other_content
104                    .into_iter()
105                    .filter_map(|content| match content {
106                        message::UserContent::Text(message::Text { text }) => {
107                            Some(Message::User { content: text })
108                        }
109                        _ => None,
110                    })
111                    .collect::<Vec<_>>();
112
113                Ok(messages)
114            }
115            message::Message::Assistant { content, .. } => {
116                let (text_content, tool_calls) = content.into_iter().fold(
117                    (Vec::new(), Vec::new()),
118                    |(mut texts, mut tools), content| {
119                        match content {
120                            message::AssistantContent::Text(text) => texts.push(text),
121                            message::AssistantContent::ToolCall(tool_call) => tools.push(tool_call),
122                            message::AssistantContent::Reasoning(_) => {
123                                unimplemented!("Reasoning content is not currently supported on Mistral via Rig");
124                            }
125                        }
126                        (texts, tools)
127                    },
128                );
129
130                Ok(vec![Message::Assistant {
131                    content: text_content
132                        .into_iter()
133                        .next()
134                        .map(|content| content.text)
135                        .unwrap_or_default(),
136                    tool_calls: tool_calls
137                        .into_iter()
138                        .map(|tool_call| tool_call.into())
139                        .collect::<Vec<_>>(),
140                    prefix: false,
141                }])
142            }
143        }
144    }
145}
146
147#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
148pub struct ToolCall {
149    pub id: String,
150    #[serde(default)]
151    pub r#type: ToolType,
152    pub function: Function,
153}
154
155impl From<message::ToolCall> for ToolCall {
156    fn from(tool_call: message::ToolCall) -> Self {
157        Self {
158            id: tool_call.id,
159            r#type: ToolType::default(),
160            function: Function {
161                name: tool_call.function.name,
162                arguments: tool_call.function.arguments,
163            },
164        }
165    }
166}
167
168#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
169pub struct Function {
170    pub name: String,
171    #[serde(with = "json_utils::stringified_json")]
172    pub arguments: serde_json::Value,
173}
174
175#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
176#[serde(rename_all = "lowercase")]
177pub enum ToolType {
178    #[default]
179    Function,
180}
181
182#[derive(Debug, Deserialize, Serialize, Clone)]
183pub struct ToolDefinition {
184    pub r#type: String,
185    pub function: completion::ToolDefinition,
186}
187
188impl From<completion::ToolDefinition> for ToolDefinition {
189    fn from(tool: completion::ToolDefinition) -> Self {
190        Self {
191            r#type: "function".into(),
192            function: tool,
193        }
194    }
195}
196
197#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
198pub struct ToolResultContent {
199    #[serde(default)]
200    r#type: ToolResultContentType,
201    text: String,
202}
203
204#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
205#[serde(rename_all = "lowercase")]
206pub enum ToolResultContentType {
207    #[default]
208    Text,
209}
210
211impl From<String> for ToolResultContent {
212    fn from(s: String) -> Self {
213        ToolResultContent {
214            r#type: ToolResultContentType::default(),
215            text: s,
216        }
217    }
218}
219
220impl From<String> for UserContent {
221    fn from(s: String) -> Self {
222        UserContent::Text { text: s }
223    }
224}
225
226impl FromStr for UserContent {
227    type Err = Infallible;
228
229    fn from_str(s: &str) -> Result<Self, Self::Err> {
230        Ok(UserContent::Text {
231            text: s.to_string(),
232        })
233    }
234}
235
236impl From<String> for AssistantContent {
237    fn from(s: String) -> Self {
238        AssistantContent { text: s }
239    }
240}
241
242impl FromStr for AssistantContent {
243    type Err = Infallible;
244
245    fn from_str(s: &str) -> Result<Self, Self::Err> {
246        Ok(AssistantContent {
247            text: s.to_string(),
248        })
249    }
250}
251
252#[derive(Clone)]
253pub struct CompletionModel {
254    pub(crate) client: Client,
255    pub model: String,
256}
257
258impl CompletionModel {
259    pub fn new(client: Client, model: &str) -> Self {
260        Self {
261            client,
262            model: model.to_string(),
263        }
264    }
265
266    pub(crate) fn create_completion_request(
267        &self,
268        completion_request: CompletionRequest,
269    ) -> Result<Value, CompletionError> {
270        let mut partial_history = vec![];
271        if let Some(docs) = completion_request.normalized_documents() {
272            partial_history.push(docs);
273        }
274
275        partial_history.extend(completion_request.chat_history);
276
277        let mut full_history: Vec<Message> = match &completion_request.preamble {
278            Some(preamble) => vec![Message::system(preamble.clone())],
279            None => vec![],
280        };
281
282        full_history.extend(
283            partial_history
284                .into_iter()
285                .map(message::Message::try_into)
286                .collect::<Result<Vec<Vec<Message>>, _>>()?
287                .into_iter()
288                .flatten()
289                .collect::<Vec<_>>(),
290        );
291
292        let request = if completion_request.tools.is_empty() {
293            json!({
294                "model": self.model,
295                "messages": full_history,
296
297            })
298        } else {
299            json!({
300                "model": self.model,
301                "messages": full_history,
302                "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
303                "tool_choice": "auto",
304            })
305        };
306
307        let request = if let Some(temperature) = completion_request.temperature {
308            json_utils::merge(
309                request,
310                json!({
311                    "temperature": temperature,
312                }),
313            )
314        } else {
315            request
316        };
317
318        let request = if let Some(params) = completion_request.additional_params {
319            json_utils::merge(request, params)
320        } else {
321            request
322        };
323
324        Ok(request)
325    }
326}
327
328#[derive(Debug, Deserialize, Clone, Serialize)]
329pub struct CompletionResponse {
330    pub id: String,
331    pub object: String,
332    pub created: u64,
333    pub model: String,
334    pub system_fingerprint: Option<String>,
335    pub choices: Vec<Choice>,
336    pub usage: Option<Usage>,
337}
338
339impl GetTokenUsage for CompletionResponse {
340    fn token_usage(&self) -> Option<crate::completion::Usage> {
341        let api_usage = self.usage.clone()?;
342
343        let mut usage = crate::completion::Usage::new();
344        usage.input_tokens = api_usage.prompt_tokens as u64;
345        usage.output_tokens = api_usage.completion_tokens as u64;
346        usage.total_tokens = api_usage.total_tokens as u64;
347
348        Some(usage)
349    }
350}
351
352impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
353    type Error = CompletionError;
354
355    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
356        let choice = response.choices.first().ok_or_else(|| {
357            CompletionError::ResponseError("Response contained no choices".to_owned())
358        })?;
359        let content = match &choice.message {
360            Message::Assistant {
361                content,
362                tool_calls,
363                ..
364            } => {
365                let mut content = if content.is_empty() {
366                    vec![]
367                } else {
368                    vec![completion::AssistantContent::text(content.clone())]
369                };
370
371                content.extend(
372                    tool_calls
373                        .iter()
374                        .map(|call| {
375                            completion::AssistantContent::tool_call(
376                                &call.id,
377                                &call.function.name,
378                                call.function.arguments.clone(),
379                            )
380                        })
381                        .collect::<Vec<_>>(),
382                );
383                Ok(content)
384            }
385            _ => Err(CompletionError::ResponseError(
386                "Response did not contain a valid message or tool call".into(),
387            )),
388        }?;
389
390        let choice = OneOrMany::many(content).map_err(|_| {
391            CompletionError::ResponseError(
392                "Response contained no message or tool call (empty)".to_owned(),
393            )
394        })?;
395
396        let usage = response
397            .usage
398            .as_ref()
399            .map(|usage| completion::Usage {
400                input_tokens: usage.prompt_tokens as u64,
401                output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
402                total_tokens: usage.total_tokens as u64,
403            })
404            .unwrap_or_default();
405
406        Ok(completion::CompletionResponse {
407            choice,
408            usage,
409            raw_response: response,
410        })
411    }
412}
413
414impl completion::CompletionModel for CompletionModel {
415    type Response = CompletionResponse;
416    type StreamingResponse = CompletionResponse;
417
418    #[cfg_attr(feature = "worker", worker::send)]
419    async fn completion(
420        &self,
421        completion_request: CompletionRequest,
422    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
423        let request = self.create_completion_request(completion_request)?;
424
425        let response = self
426            .client
427            .post("v1/chat/completions")
428            .json(&request)
429            .send()
430            .await?;
431
432        if response.status().is_success() {
433            let text = response.text().await?;
434            match serde_json::from_str::<ApiResponse<CompletionResponse>>(&text)? {
435                ApiResponse::Ok(response) => {
436                    tracing::debug!(target: "rig",
437                        "Mistral completion token usage: {:?}",
438                        response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
439                    );
440                    response.try_into()
441                }
442                ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
443            }
444        } else {
445            Err(CompletionError::ProviderError(response.text().await?))
446        }
447    }
448
449    #[cfg_attr(feature = "worker", worker::send)]
450    async fn stream(
451        &self,
452        request: CompletionRequest,
453    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
454        let resp = self.completion(request).await?;
455
456        let stream = Box::pin(stream! {
457            for c in resp.choice.clone() {
458                match c {
459                    message::AssistantContent::Text(t) => {
460                        yield Ok(RawStreamingChoice::Message(t.text.clone()))
461                    }
462                    message::AssistantContent::ToolCall(tc) => {
463                        yield Ok(RawStreamingChoice::ToolCall {
464                            id: tc.id.clone(),
465                            name: tc.function.name.clone(),
466                            arguments: tc.function.arguments.clone(),
467                             call_id: None
468                        })
469                    }
470                    message::AssistantContent::Reasoning(_) => {
471                        unimplemented!("Reasoning is not supported on Mistral via Rig")
472                    }
473                }
474            }
475
476            yield Ok(RawStreamingChoice::FinalResponse(resp.raw_response.clone()));
477        });
478
479        Ok(StreamingCompletionResponse::stream(stream))
480    }
481}
482
#[cfg(test)]
mod tests {
    use super::*;

    // Sample chat-completion response modelled on the Mistral API reference:
    // https://docs.mistral.ai/api/#tag/chat/operation/chat_completion_v1_chat_completions_post
    const RESPONSE_JSON: &str = r#"
    {
        "id": "cmpl-e5cc70bb28c444948073e77776eb30ef",
        "object": "chat.completion",
        "model": "mistral-small-latest",
        "usage": {
            "prompt_tokens": 16,
            "completion_tokens": 34,
            "total_tokens": 50
        },
        "created": 1702256327,
        "choices": [
            {
                "index": 0,
                "message": {
                    "content": "string",
                    "tool_calls": [
                        {
                            "id": "null",
                            "type": "function",
                            "function": {
                                "name": "string",
                                "arguments": "{ }"
                            },
                            "index": 0
                        }
                    ],
                    "prefix": false,
                    "role": "assistant"
                },
                "finish_reason": "stop"
            }
        ]
    }
    "#;

    #[test]
    fn test_response_deserialization() {
        let completion_response =
            serde_json::from_str::<CompletionResponse>(RESPONSE_JSON).unwrap();
        assert_eq!(completion_response.model, MISTRAL_SMALL);

        let CompletionResponse {
            id,
            object,
            created,
            choices,
            usage,
            ..
        } = completion_response;

        assert_eq!(id, "cmpl-e5cc70bb28c444948073e77776eb30ef");

        let Usage {
            completion_tokens,
            prompt_tokens,
            total_tokens,
        } = usage.unwrap();

        assert_eq!(prompt_tokens, 16);
        assert_eq!(completion_tokens, 34);
        assert_eq!(total_tokens, 50);
        assert_eq!(object, "chat.completion".to_string());
        assert_eq!(created, 1702256327);
        assert_eq!(choices.len(), 1);
    }

    /// The Rig conversion must emit the text part before the tool call and map
    /// usage field-for-field.
    #[test]
    fn test_response_conversion() {
        let raw = serde_json::from_str::<CompletionResponse>(RESPONSE_JSON).unwrap();
        let converted: completion::CompletionResponse<CompletionResponse> =
            raw.try_into().unwrap();

        assert_eq!(converted.usage.input_tokens, 16);
        assert_eq!(converted.usage.output_tokens, 34);
        assert_eq!(converted.usage.total_tokens, 50);

        let parts: Vec<_> = converted.choice.into_iter().collect();
        assert_eq!(parts.len(), 2);
        assert!(matches!(
            &parts[0],
            message::AssistantContent::Text(t) if t.text == "string"
        ));
        assert!(matches!(
            &parts[1],
            message::AssistantContent::ToolCall(tc)
                if tc.id == "null"
                    && tc.function.name == "string"
                    && tc.function.arguments == serde_json::json!({})
        ));
    }
}