rig/providers/mistral/
completion.rs

1use async_stream::stream;
2use serde::{Deserialize, Serialize};
3use serde_json::{Value, json};
4use std::{convert::Infallible, str::FromStr};
5
6use super::client::{Client, Usage};
7use crate::streaming::{RawStreamingChoice, StreamingCompletionResponse};
8use crate::{
9    OneOrMany,
10    completion::{self, CompletionError, CompletionRequest},
11    json_utils, message,
12    providers::mistral::client::ApiResponse,
13};
14
/// Model identifier `codestral-latest`.
pub const CODESTRAL: &str = "codestral-latest";
/// Model identifier `mistral-large-latest`.
pub const MISTRAL_LARGE: &str = "mistral-large-latest";
/// Model identifier `pixtral-large-latest`.
pub const PIXTRAL_LARGE: &str = "pixtral-large-latest";
/// Model identifier `mistral-saba-latest`.
pub const MISTRAL_SABA: &str = "mistral-saba-latest";
/// Model identifier `ministral-3b-latest`.
pub const MINISTRAL_3B: &str = "ministral-3b-latest";
/// Model identifier `ministral-8b-latest`.
pub const MINISTRAL_8B: &str = "ministral-8b-latest";

// Free models
/// Model identifier `mistral-small-latest`.
pub const MISTRAL_SMALL: &str = "mistral-small-latest";
/// Model identifier `pixtral-12b-2409`.
pub const PIXTRAL_SMALL: &str = "pixtral-12b-2409";
/// Model identifier `open-mistral-nemo`.
pub const MISTRAL_NEMO: &str = "open-mistral-nemo";
/// Model identifier `open-codestral-mamba`.
pub const CODESTRAL_MAMBA: &str = "open-codestral-mamba";
27
28// =================================================================
29// Rig Implementation Types
30// =================================================================
31
/// Text-only assistant content wrapping a single `text` string.
///
/// NOTE(review): `#[serde(tag = "type")]` is applied to a struct here (not an enum); confirm
/// that the `type` value serde emits for it is what the consumer of this JSON expects.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub struct AssistantContent {
    /// The assistant's text.
    text: String,
}
37
/// User content, internally tagged by a `type` field holding the lowercased variant name.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum UserContent {
    /// Plain text content.
    Text { text: String },
}
43
/// One element of the `choices` list of a [`CompletionResponse`].
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Choice {
    /// Position of this choice as reported in the response.
    pub index: usize,
    /// The message carried by this choice.
    pub message: Message,
    /// Log-probability data kept as raw JSON; `None` when absent.
    pub logprobs: Option<serde_json::Value>,
    /// Finish reason string as reported in the response (e.g. `"stop"`).
    pub finish_reason: String,
}
51
/// A chat message, internally tagged by a `role` field holding the lowercased variant name
/// (`user`, `assistant`, `system`).
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    /// A message with the `user` role.
    User {
        content: String,
    },
    /// A message with the `assistant` role.
    Assistant {
        content: String,
        // Falls back to an empty vec when the field is missing and is left out of the
        // serialized output when empty. Deserialization goes through
        // `json_utils::null_or_vec` (presumably mapping `null` to an empty vec; confirm).
        #[serde(
            default,
            deserialize_with = "json_utils::null_or_vec",
            skip_serializing_if = "Vec::is_empty"
        )]
        tool_calls: Vec<ToolCall>,
        // Defaults to `false` when the field is missing.
        #[serde(default)]
        prefix: bool,
    },
    /// A message with the `system` role.
    System {
        content: String,
    },
}
73
74impl Message {
75    pub fn user(content: String) -> Self {
76        Message::User { content }
77    }
78
79    pub fn assistant(content: String, tool_calls: Vec<ToolCall>, prefix: bool) -> Self {
80        Message::Assistant {
81            content,
82            tool_calls,
83            prefix,
84        }
85    }
86
87    pub fn system(content: String) -> Self {
88        Message::System { content }
89    }
90}
91
92impl TryFrom<message::Message> for Vec<Message> {
93    type Error = message::MessageError;
94
95    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
96        match message {
97            message::Message::User { content } => {
98                let (_, other_content): (Vec<_>, Vec<_>) = content
99                    .into_iter()
100                    .partition(|content| matches!(content, message::UserContent::ToolResult(_)));
101
102                let messages = other_content
103                    .into_iter()
104                    .filter_map(|content| match content {
105                        message::UserContent::Text(message::Text { text }) => {
106                            Some(Message::User { content: text })
107                        }
108                        _ => None,
109                    })
110                    .collect::<Vec<_>>();
111
112                Ok(messages)
113            }
114            message::Message::Assistant { content, .. } => {
115                let (text_content, tool_calls) = content.into_iter().fold(
116                    (Vec::new(), Vec::new()),
117                    |(mut texts, mut tools), content| {
118                        match content {
119                            message::AssistantContent::Text(text) => texts.push(text),
120                            message::AssistantContent::ToolCall(tool_call) => tools.push(tool_call),
121                            message::AssistantContent::Reasoning(_) => {
122                                unimplemented!("Reasoning content is not currently supported on Mistral via Rig");
123                            }
124                        }
125                        (texts, tools)
126                    },
127                );
128
129                Ok(vec![Message::Assistant {
130                    content: text_content
131                        .into_iter()
132                        .next()
133                        .map(|content| content.text)
134                        .unwrap_or_default(),
135                    tool_calls: tool_calls
136                        .into_iter()
137                        .map(|tool_call| tool_call.into())
138                        .collect::<Vec<_>>(),
139                    prefix: false,
140                }])
141            }
142        }
143    }
144}
145
/// A tool invocation attached to an assistant message.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolCall {
    /// Identifier of the tool call.
    pub id: String,
    /// Kind of tool; falls back to [`ToolType::default`] when the field is missing.
    #[serde(default)]
    pub r#type: ToolType,
    /// The function name and arguments of the call.
    pub function: Function,
}
153
154impl From<message::ToolCall> for ToolCall {
155    fn from(tool_call: message::ToolCall) -> Self {
156        Self {
157            id: tool_call.id,
158            r#type: ToolType::default(),
159            function: Function {
160                name: tool_call.function.name,
161                arguments: tool_call.function.arguments,
162            },
163        }
164    }
165}
166
/// Name and arguments of a function-style tool call.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct Function {
    /// Name of the function to call.
    pub name: String,
    /// Call arguments as a JSON value; (de)serialized through
    /// `json_utils::stringified_json` (presumably as a JSON-encoded string; confirm).
    #[serde(with = "json_utils::stringified_json")]
    pub arguments: serde_json::Value,
}
173
/// Kind of tool referenced by a [`ToolCall`]; serialized as the lowercased variant name.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    /// A function tool (the default and only variant).
    #[default]
    Function,
}
180
/// Wrapper pairing a generic tool definition with a `type` string for the request body.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ToolDefinition {
    /// Tool kind as a string; the `From` conversion in this file sets it to `"function"`.
    pub r#type: String,
    /// The wrapped generic tool definition.
    pub function: completion::ToolDefinition,
}
186
187impl From<completion::ToolDefinition> for ToolDefinition {
188    fn from(tool: completion::ToolDefinition) -> Self {
189        Self {
190            r#type: "function".into(),
191            function: tool,
192        }
193    }
194}
195
/// Textual content of a tool result.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolResultContent {
    /// Content kind; falls back to [`ToolResultContentType::default`] when missing.
    #[serde(default)]
    r#type: ToolResultContentType,
    /// The text of the tool result.
    text: String,
}
202
/// Kind of a [`ToolResultContent`]; serialized as the lowercased variant name.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolResultContentType {
    /// Text content (the default and only variant).
    #[default]
    Text,
}
209
210impl From<String> for ToolResultContent {
211    fn from(s: String) -> Self {
212        ToolResultContent {
213            r#type: ToolResultContentType::default(),
214            text: s,
215        }
216    }
217}
218
219impl From<String> for UserContent {
220    fn from(s: String) -> Self {
221        UserContent::Text { text: s }
222    }
223}
224
225impl FromStr for UserContent {
226    type Err = Infallible;
227
228    fn from_str(s: &str) -> Result<Self, Self::Err> {
229        Ok(UserContent::Text {
230            text: s.to_string(),
231        })
232    }
233}
234
235impl From<String> for AssistantContent {
236    fn from(s: String) -> Self {
237        AssistantContent { text: s }
238    }
239}
240
241impl FromStr for AssistantContent {
242    type Err = Infallible;
243
244    fn from_str(s: &str) -> Result<Self, Self::Err> {
245        Ok(AssistantContent {
246            text: s.to_string(),
247        })
248    }
249}
250
/// Completion model handle: a client plus the name of the model to query.
#[derive(Clone)]
pub struct CompletionModel {
    /// Client used to send requests.
    pub(crate) client: Client,
    /// Model name placed in the `model` field of each request.
    pub model: String,
}
256
257impl CompletionModel {
258    pub fn new(client: Client, model: &str) -> Self {
259        Self {
260            client,
261            model: model.to_string(),
262        }
263    }
264
265    pub(crate) fn create_completion_request(
266        &self,
267        completion_request: CompletionRequest,
268    ) -> Result<Value, CompletionError> {
269        let mut partial_history = vec![];
270        if let Some(docs) = completion_request.normalized_documents() {
271            partial_history.push(docs);
272        }
273
274        partial_history.extend(completion_request.chat_history);
275
276        let mut full_history: Vec<Message> = match &completion_request.preamble {
277            Some(preamble) => vec![Message::system(preamble.clone())],
278            None => vec![],
279        };
280
281        full_history.extend(
282            partial_history
283                .into_iter()
284                .map(message::Message::try_into)
285                .collect::<Result<Vec<Vec<Message>>, _>>()?
286                .into_iter()
287                .flatten()
288                .collect::<Vec<_>>(),
289        );
290
291        let request = if completion_request.tools.is_empty() {
292            json!({
293                "model": self.model,
294                "messages": full_history,
295
296            })
297        } else {
298            json!({
299                "model": self.model,
300                "messages": full_history,
301                "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
302                "tool_choice": "auto",
303            })
304        };
305
306        let request = if let Some(temperature) = completion_request.temperature {
307            json_utils::merge(
308                request,
309                json!({
310                    "temperature": temperature,
311                }),
312            )
313        } else {
314            request
315        };
316
317        let request = if let Some(params) = completion_request.additional_params {
318            json_utils::merge(request, params)
319        } else {
320            request
321        };
322
323        Ok(request)
324    }
325}
326
/// Deserialized body of a chat-completion response.
#[derive(Debug, Deserialize, Clone, Serialize)]
pub struct CompletionResponse {
    /// Identifier of the completion.
    pub id: String,
    /// Object type string (e.g. `"chat.completion"`).
    pub object: String,
    /// Creation timestamp as reported in the response.
    pub created: u64,
    /// Name of the model reported in the response.
    pub model: String,
    /// System fingerprint; `None` when absent.
    pub system_fingerprint: Option<String>,
    /// The returned choices.
    pub choices: Vec<Choice>,
    /// Token usage; `None` when absent.
    pub usage: Option<Usage>,
}
337
338impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
339    type Error = CompletionError;
340
341    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
342        let choice = response.choices.first().ok_or_else(|| {
343            CompletionError::ResponseError("Response contained no choices".to_owned())
344        })?;
345        let content = match &choice.message {
346            Message::Assistant {
347                content,
348                tool_calls,
349                ..
350            } => {
351                let mut content = if content.is_empty() {
352                    vec![]
353                } else {
354                    vec![completion::AssistantContent::text(content.clone())]
355                };
356
357                content.extend(
358                    tool_calls
359                        .iter()
360                        .map(|call| {
361                            completion::AssistantContent::tool_call(
362                                &call.id,
363                                &call.function.name,
364                                call.function.arguments.clone(),
365                            )
366                        })
367                        .collect::<Vec<_>>(),
368                );
369                Ok(content)
370            }
371            _ => Err(CompletionError::ResponseError(
372                "Response did not contain a valid message or tool call".into(),
373            )),
374        }?;
375
376        let choice = OneOrMany::many(content).map_err(|_| {
377            CompletionError::ResponseError(
378                "Response contained no message or tool call (empty)".to_owned(),
379            )
380        })?;
381
382        let usage = response
383            .usage
384            .as_ref()
385            .map(|usage| completion::Usage {
386                input_tokens: usage.prompt_tokens as u64,
387                output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
388                total_tokens: usage.total_tokens as u64,
389            })
390            .unwrap_or_default();
391
392        Ok(completion::CompletionResponse {
393            choice,
394            usage,
395            raw_response: response,
396        })
397    }
398}
399
400impl completion::CompletionModel for CompletionModel {
401    type Response = CompletionResponse;
402    type StreamingResponse = CompletionResponse;
403
404    #[cfg_attr(feature = "worker", worker::send)]
405    async fn completion(
406        &self,
407        completion_request: CompletionRequest,
408    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
409        let request = self.create_completion_request(completion_request)?;
410
411        let response = self
412            .client
413            .post("v1/chat/completions")
414            .json(&request)
415            .send()
416            .await?;
417
418        if response.status().is_success() {
419            let text = response.text().await?;
420            match serde_json::from_str::<ApiResponse<CompletionResponse>>(&text)? {
421                ApiResponse::Ok(response) => {
422                    tracing::debug!(target: "rig",
423                        "Mistral completion token usage: {:?}",
424                        response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
425                    );
426                    response.try_into()
427                }
428                ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
429            }
430        } else {
431            Err(CompletionError::ProviderError(response.text().await?))
432        }
433    }
434
435    #[cfg_attr(feature = "worker", worker::send)]
436    async fn stream(
437        &self,
438        request: CompletionRequest,
439    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
440        let resp = self.completion(request).await?;
441
442        let stream = Box::pin(stream! {
443            for c in resp.choice.clone() {
444                match c {
445                    message::AssistantContent::Text(t) => {
446                        yield Ok(RawStreamingChoice::Message(t.text.clone()))
447                    }
448                    message::AssistantContent::ToolCall(tc) => {
449                        yield Ok(RawStreamingChoice::ToolCall {
450                            id: tc.id.clone(),
451                            name: tc.function.name.clone(),
452                            arguments: tc.function.arguments.clone(),
453                             call_id: None
454                        })
455                    }
456                    message::AssistantContent::Reasoning(_) => {
457                        unimplemented!("Reasoning is not supported on Mistral via Rig")
458                    }
459                }
460            }
461
462            yield Ok(RawStreamingChoice::FinalResponse(resp.raw_response.clone()));
463        });
464
465        Ok(StreamingCompletionResponse::stream(stream))
466    }
467}
468
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks that a sample chat-completion body deserializes into [`CompletionResponse`]
    /// with the expected top-level fields and usage counts.
    #[test]
    fn test_response_deserialization() {
        //https://docs.mistral.ai/api/#tag/chat/operation/chat_completion_v1_chat_completions_post
        let json_data = r#"
        {
            "id": "cmpl-e5cc70bb28c444948073e77776eb30ef",
            "object": "chat.completion",
            "model": "mistral-small-latest",
            "usage": {
                "prompt_tokens": 16,
                "completion_tokens": 34,
                "total_tokens": 50
            },
            "created": 1702256327,
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "content": "string",
                        "tool_calls": [
                            {
                                "id": "null",
                                "type": "function",
                                "function": {
                                    "name": "string",
                                    "arguments": "{ }"
                                },
                                "index": 0
                            }
                        ],
                        "prefix": false,
                        "role": "assistant"
                    },
                    "finish_reason": "stop"
                }
            ]
        }
        "#;
        let parsed: CompletionResponse = serde_json::from_str(json_data).unwrap();

        // Top-level fields.
        assert_eq!(parsed.model, MISTRAL_SMALL);
        assert_eq!(parsed.id, "cmpl-e5cc70bb28c444948073e77776eb30ef");
        assert_eq!(parsed.object, "chat.completion");
        assert_eq!(parsed.created, 1702256327);
        assert_eq!(parsed.choices.len(), 1);

        // Token usage.
        let usage = parsed.usage.unwrap();
        assert_eq!(usage.prompt_tokens, 16);
        assert_eq!(usage.completion_tokens, 34);
        assert_eq!(usage.total_tokens, 50);
    }
}