rig/providers/mistral/completion.rs

1use async_stream::stream;
2use serde::{Deserialize, Serialize};
3use serde_json::{Value, json};
4use std::{convert::Infallible, str::FromStr};
5
6use super::client::{Client, Usage};
7use crate::streaming::{RawStreamingChoice, StreamingCompletionResponse};
8use crate::{
9    OneOrMany,
10    completion::{self, CompletionError, CompletionRequest},
11    json_utils, message,
12    providers::mistral::client::ApiResponse,
13};
14
/// Model id `codestral-latest`.
pub const CODESTRAL: &str = "codestral-latest";
/// Model id `mistral-large-latest`.
pub const MISTRAL_LARGE: &str = "mistral-large-latest";
/// Model id `pixtral-large-latest`.
pub const PIXTRAL_LARGE: &str = "pixtral-large-latest";
/// Model id `mistral-saba-latest`.
pub const MISTRAL_SABA: &str = "mistral-saba-latest";
/// Model id `ministral-3b-latest`.
pub const MINISTRAL_3B: &str = "ministral-3b-latest";
/// Model id `ministral-8b-latest`.
pub const MINISTRAL_8B: &str = "ministral-8b-latest";

// ---- Free models ----

/// Model id `mistral-small-latest`.
pub const MISTRAL_SMALL: &str = "mistral-small-latest";
/// Model id `pixtral-12b-2409` (a dated release rather than a `-latest` alias).
pub const PIXTRAL_SMALL: &str = "pixtral-12b-2409";
/// Model id `open-mistral-nemo`.
pub const MISTRAL_NEMO: &str = "open-mistral-nemo";
/// Model id `open-codestral-mamba`.
pub const CODESTRAL_MAMBA: &str = "open-codestral-mamba";
27
28// =================================================================
29// Rig Implementation Types
30// =================================================================
31
32#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
33#[serde(tag = "type", rename_all = "lowercase")]
34pub struct AssistantContent {
35    text: String,
36}
37
38#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
39#[serde(tag = "type", rename_all = "lowercase")]
40pub enum UserContent {
41    Text { text: String },
42}
43
44#[derive(Debug, Serialize, Deserialize, Clone)]
45pub struct Choice {
46    pub index: usize,
47    pub message: Message,
48    pub logprobs: Option<serde_json::Value>,
49    pub finish_reason: String,
50}
51
52#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
53#[serde(tag = "role", rename_all = "lowercase")]
54pub enum Message {
55    User {
56        content: String,
57    },
58    Assistant {
59        content: String,
60        #[serde(
61            default,
62            deserialize_with = "json_utils::null_or_vec",
63            skip_serializing_if = "Vec::is_empty"
64        )]
65        tool_calls: Vec<ToolCall>,
66        #[serde(default)]
67        prefix: bool,
68    },
69    System {
70        content: String,
71    },
72}
73
74impl Message {
75    pub fn user(content: String) -> Self {
76        Message::User { content }
77    }
78
79    pub fn assistant(content: String, tool_calls: Vec<ToolCall>, prefix: bool) -> Self {
80        Message::Assistant {
81            content,
82            tool_calls,
83            prefix,
84        }
85    }
86
87    pub fn system(content: String) -> Self {
88        Message::System { content }
89    }
90}
91
92impl TryFrom<message::Message> for Vec<Message> {
93    type Error = message::MessageError;
94
95    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
96        match message {
97            message::Message::User { content } => {
98                let (_, other_content): (Vec<_>, Vec<_>) = content
99                    .into_iter()
100                    .partition(|content| matches!(content, message::UserContent::ToolResult(_)));
101
102                let messages = other_content
103                    .into_iter()
104                    .filter_map(|content| match content {
105                        message::UserContent::Text(message::Text { text }) => {
106                            Some(Message::User { content: text })
107                        }
108                        _ => None,
109                    })
110                    .collect::<Vec<_>>();
111
112                Ok(messages)
113            }
114            message::Message::Assistant { content, .. } => {
115                let (text_content, tool_calls) = content.into_iter().fold(
116                    (Vec::new(), Vec::new()),
117                    |(mut texts, mut tools), content| {
118                        match content {
119                            message::AssistantContent::Text(text) => texts.push(text),
120                            message::AssistantContent::ToolCall(tool_call) => tools.push(tool_call),
121                        }
122                        (texts, tools)
123                    },
124                );
125
126                Ok(vec![Message::Assistant {
127                    content: text_content
128                        .into_iter()
129                        .next()
130                        .map(|content| content.text)
131                        .unwrap_or_default(),
132                    tool_calls: tool_calls
133                        .into_iter()
134                        .map(|tool_call| tool_call.into())
135                        .collect::<Vec<_>>(),
136                    prefix: false,
137                }])
138            }
139        }
140    }
141}
142
143#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
144pub struct ToolCall {
145    pub id: String,
146    #[serde(default)]
147    pub r#type: ToolType,
148    pub function: Function,
149}
150
151impl From<message::ToolCall> for ToolCall {
152    fn from(tool_call: message::ToolCall) -> Self {
153        Self {
154            id: tool_call.id,
155            r#type: ToolType::default(),
156            function: Function {
157                name: tool_call.function.name,
158                arguments: tool_call.function.arguments,
159            },
160        }
161    }
162}
163
164#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
165pub struct Function {
166    pub name: String,
167    #[serde(with = "json_utils::stringified_json")]
168    pub arguments: serde_json::Value,
169}
170
171#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
172#[serde(rename_all = "lowercase")]
173pub enum ToolType {
174    #[default]
175    Function,
176}
177
178#[derive(Debug, Deserialize, Serialize, Clone)]
179pub struct ToolDefinition {
180    pub r#type: String,
181    pub function: completion::ToolDefinition,
182}
183
184impl From<completion::ToolDefinition> for ToolDefinition {
185    fn from(tool: completion::ToolDefinition) -> Self {
186        Self {
187            r#type: "function".into(),
188            function: tool,
189        }
190    }
191}
192
193#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
194pub struct ToolResultContent {
195    #[serde(default)]
196    r#type: ToolResultContentType,
197    text: String,
198}
199
200#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
201#[serde(rename_all = "lowercase")]
202pub enum ToolResultContentType {
203    #[default]
204    Text,
205}
206
207impl From<String> for ToolResultContent {
208    fn from(s: String) -> Self {
209        ToolResultContent {
210            r#type: ToolResultContentType::default(),
211            text: s,
212        }
213    }
214}
215
216impl From<String> for UserContent {
217    fn from(s: String) -> Self {
218        UserContent::Text { text: s }
219    }
220}
221
222impl FromStr for UserContent {
223    type Err = Infallible;
224
225    fn from_str(s: &str) -> Result<Self, Self::Err> {
226        Ok(UserContent::Text {
227            text: s.to_string(),
228        })
229    }
230}
231
232impl From<String> for AssistantContent {
233    fn from(s: String) -> Self {
234        AssistantContent { text: s }
235    }
236}
237
238impl FromStr for AssistantContent {
239    type Err = Infallible;
240
241    fn from_str(s: &str) -> Result<Self, Self::Err> {
242        Ok(AssistantContent {
243            text: s.to_string(),
244        })
245    }
246}
247
248#[derive(Clone)]
249pub struct CompletionModel {
250    pub(crate) client: Client,
251    pub model: String,
252}
253
254impl CompletionModel {
255    pub fn new(client: Client, model: &str) -> Self {
256        Self {
257            client,
258            model: model.to_string(),
259        }
260    }
261
262    pub(crate) fn create_completion_request(
263        &self,
264        completion_request: CompletionRequest,
265    ) -> Result<Value, CompletionError> {
266        let mut partial_history = vec![];
267        if let Some(docs) = completion_request.normalized_documents() {
268            partial_history.push(docs);
269        }
270
271        partial_history.extend(completion_request.chat_history);
272
273        let mut full_history: Vec<Message> = match &completion_request.preamble {
274            Some(preamble) => vec![Message::system(preamble.clone())],
275            None => vec![],
276        };
277
278        full_history.extend(
279            partial_history
280                .into_iter()
281                .map(message::Message::try_into)
282                .collect::<Result<Vec<Vec<Message>>, _>>()?
283                .into_iter()
284                .flatten()
285                .collect::<Vec<_>>(),
286        );
287
288        let request = if completion_request.tools.is_empty() {
289            json!({
290                "model": self.model,
291                "messages": full_history,
292
293            })
294        } else {
295            json!({
296                "model": self.model,
297                "messages": full_history,
298                "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
299                "tool_choice": "auto",
300            })
301        };
302
303        let request = if let Some(temperature) = completion_request.temperature {
304            json_utils::merge(
305                request,
306                json!({
307                    "temperature": temperature,
308                }),
309            )
310        } else {
311            request
312        };
313
314        let request = if let Some(params) = completion_request.additional_params {
315            json_utils::merge(request, params)
316        } else {
317            request
318        };
319
320        Ok(request)
321    }
322}
323
324#[derive(Debug, Deserialize, Clone)]
325pub struct CompletionResponse {
326    pub id: String,
327    pub object: String,
328    pub created: u64,
329    pub model: String,
330    pub system_fingerprint: Option<String>,
331    pub choices: Vec<Choice>,
332    pub usage: Option<Usage>,
333}
334
335impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
336    type Error = CompletionError;
337
338    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
339        let choice = response.choices.first().ok_or_else(|| {
340            CompletionError::ResponseError("Response contained no choices".to_owned())
341        })?;
342        let content = match &choice.message {
343            Message::Assistant {
344                content,
345                tool_calls,
346                ..
347            } => {
348                let mut content = if content.is_empty() {
349                    vec![]
350                } else {
351                    vec![completion::AssistantContent::text(content.clone())]
352                };
353
354                content.extend(
355                    tool_calls
356                        .iter()
357                        .map(|call| {
358                            completion::AssistantContent::tool_call(
359                                &call.id,
360                                &call.function.name,
361                                call.function.arguments.clone(),
362                            )
363                        })
364                        .collect::<Vec<_>>(),
365                );
366                Ok(content)
367            }
368            _ => Err(CompletionError::ResponseError(
369                "Response did not contain a valid message or tool call".into(),
370            )),
371        }?;
372
373        let choice = OneOrMany::many(content).map_err(|_| {
374            CompletionError::ResponseError(
375                "Response contained no message or tool call (empty)".to_owned(),
376            )
377        })?;
378
379        Ok(completion::CompletionResponse {
380            choice,
381            raw_response: response,
382        })
383    }
384}
385
386impl completion::CompletionModel for CompletionModel {
387    type Response = CompletionResponse;
388    type StreamingResponse = CompletionResponse;
389
390    #[cfg_attr(feature = "worker", worker::send)]
391    async fn completion(
392        &self,
393        completion_request: CompletionRequest,
394    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
395        let request = self.create_completion_request(completion_request)?;
396
397        let response = self
398            .client
399            .post("v1/chat/completions")
400            .json(&request)
401            .send()
402            .await?;
403
404        if response.status().is_success() {
405            let text = response.text().await?;
406            match serde_json::from_str::<ApiResponse<CompletionResponse>>(&text)? {
407                ApiResponse::Ok(response) => {
408                    tracing::debug!(target: "rig",
409                        "Mistral completion token usage: {:?}",
410                        response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
411                    );
412                    response.try_into()
413                }
414                ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
415            }
416        } else {
417            Err(CompletionError::ProviderError(response.text().await?))
418        }
419    }
420
421    #[cfg_attr(feature = "worker", worker::send)]
422    async fn stream(
423        &self,
424        request: CompletionRequest,
425    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
426        let resp = self.completion(request).await?;
427
428        let stream = Box::pin(stream! {
429            for c in resp.choice.clone() {
430                match c {
431                    message::AssistantContent::Text(t) => {
432                        yield Ok(RawStreamingChoice::Message(t.text.clone()))
433                    }
434                    message::AssistantContent::ToolCall(tc) => {
435                        yield Ok(RawStreamingChoice::ToolCall {
436                            id: tc.id.clone(),
437                            name: tc.function.name.clone(),
438                            arguments: tc.function.arguments.clone(),
439                             call_id: None
440                        })
441                    }
442                }
443            }
444
445            yield Ok(RawStreamingChoice::FinalResponse(resp.raw_response.clone()));
446        });
447
448        Ok(StreamingCompletionResponse::stream(stream))
449    }
450}
451
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_response_deserialization() {
        // https://docs.mistral.ai/api/#tag/chat/operation/chat_completion_v1_chat_completions_post
        let json_data = r#"
        {
            "id": "cmpl-e5cc70bb28c444948073e77776eb30ef",
            "object": "chat.completion",
            "model": "mistral-small-latest",
            "usage": {
                "prompt_tokens": 16,
                "completion_tokens": 34,
                "total_tokens": 50
            },
            "created": 1702256327,
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "content": "string",
                        "tool_calls": [
                            {
                                "id": "null",
                                "type": "function",
                                "function": {
                                    "name": "string",
                                    "arguments": "{ }"
                                },
                                "index": 0
                            }
                        ],
                        "prefix": false,
                        "role": "assistant"
                    },
                    "finish_reason": "stop"
                }
            ]
        }
        "#;
        let completion_response = serde_json::from_str::<CompletionResponse>(json_data).unwrap();
        assert_eq!(completion_response.model, MISTRAL_SMALL);

        let CompletionResponse {
            id,
            object,
            created,
            choices,
            usage,
            ..
        } = completion_response;

        assert_eq!(id, "cmpl-e5cc70bb28c444948073e77776eb30ef");

        let Usage {
            completion_tokens,
            prompt_tokens,
            total_tokens,
        } = usage.unwrap();

        assert_eq!(prompt_tokens, 16);
        assert_eq!(completion_tokens, 34);
        assert_eq!(total_tokens, 50);
        assert_eq!(object, "chat.completion".to_string());
        assert_eq!(created, 1702256327);
        assert_eq!(choices.len(), 1);

        // The assistant message and its tool call must deserialize too.
        let choice = &choices[0];
        assert_eq!(choice.index, 0);
        assert_eq!(choice.finish_reason, "stop");
        let Message::Assistant {
            content,
            tool_calls,
            prefix,
        } = &choice.message
        else {
            panic!("expected an assistant message, got {:?}", choice.message);
        };
        assert_eq!(content, "string");
        assert!(!*prefix);
        assert_eq!(tool_calls.len(), 1);
        assert_eq!(tool_calls[0].id, "null");
        assert_eq!(tool_calls[0].r#type, ToolType::Function);
        assert_eq!(tool_calls[0].function.name, "string");
        // `arguments` arrives as a JSON-encoded string and is decoded into a value.
        assert_eq!(tool_calls[0].function.arguments, serde_json::json!({}));
    }

    #[test]
    fn test_assistant_message_omits_empty_tool_calls() {
        let message = Message::assistant("hi".to_string(), vec![], false);
        let value = serde_json::to_value(&message).unwrap();
        assert_eq!(
            value,
            serde_json::json!({ "role": "assistant", "content": "hi", "prefix": false })
        );
    }
}