rig/providers/mistral/completion.rs

1use async_stream::stream;
2use serde::{Deserialize, Serialize};
3use serde_json::{Value, json};
4use std::{convert::Infallible, str::FromStr};
5use tracing::{Instrument, info_span};
6
7use super::client::{Client, Usage};
8use crate::completion::GetTokenUsage;
9use crate::http_client::{self, HttpClientExt};
10use crate::streaming::{RawStreamingChoice, StreamingCompletionResponse};
11use crate::{
12    OneOrMany,
13    completion::{self, CompletionError, CompletionRequest},
14    json_utils, message,
15    providers::mistral::client::ApiResponse,
16    telemetry::SpanCombinator,
17};
18
/// Model ID `codestral-latest`.
pub const CODESTRAL: &str = "codestral-latest";
/// Model ID `mistral-large-latest`.
pub const MISTRAL_LARGE: &str = "mistral-large-latest";
/// Model ID `pixtral-large-latest`.
pub const PIXTRAL_LARGE: &str = "pixtral-large-latest";
/// Model ID `mistral-saba-latest`.
pub const MISTRAL_SABA: &str = "mistral-saba-latest";
/// Model ID `ministral-3b-latest`.
pub const MINISTRAL_3B: &str = "ministral-3b-latest";
/// Model ID `ministral-8b-latest`.
pub const MINISTRAL_8B: &str = "ministral-8b-latest";

// Free models
/// Model ID `mistral-small-latest`.
pub const MISTRAL_SMALL: &str = "mistral-small-latest";
/// Model ID `pixtral-12b-2409` (a dated ID rather than a `-latest` alias).
pub const PIXTRAL_SMALL: &str = "pixtral-12b-2409";
/// Model ID `open-mistral-nemo`.
pub const MISTRAL_NEMO: &str = "open-mistral-nemo";
/// Model ID `open-codestral-mamba`.
pub const CODESTRAL_MAMBA: &str = "open-codestral-mamba";
31
32// =================================================================
33// Rig Implementation Types
34// =================================================================
35
36#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
37#[serde(tag = "type", rename_all = "lowercase")]
38pub struct AssistantContent {
39    text: String,
40}
41
42#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
43#[serde(tag = "type", rename_all = "lowercase")]
44pub enum UserContent {
45    Text { text: String },
46}
47
48#[derive(Debug, Serialize, Deserialize, Clone)]
49pub struct Choice {
50    pub index: usize,
51    pub message: Message,
52    pub logprobs: Option<serde_json::Value>,
53    pub finish_reason: String,
54}
55
56#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
57#[serde(tag = "role", rename_all = "lowercase")]
58pub enum Message {
59    User {
60        content: String,
61    },
62    Assistant {
63        content: String,
64        #[serde(
65            default,
66            deserialize_with = "json_utils::null_or_vec",
67            skip_serializing_if = "Vec::is_empty"
68        )]
69        tool_calls: Vec<ToolCall>,
70        #[serde(default)]
71        prefix: bool,
72    },
73    System {
74        content: String,
75    },
76}
77
78impl Message {
79    pub fn user(content: String) -> Self {
80        Message::User { content }
81    }
82
83    pub fn assistant(content: String, tool_calls: Vec<ToolCall>, prefix: bool) -> Self {
84        Message::Assistant {
85            content,
86            tool_calls,
87            prefix,
88        }
89    }
90
91    pub fn system(content: String) -> Self {
92        Message::System { content }
93    }
94}
95
96impl TryFrom<message::Message> for Vec<Message> {
97    type Error = message::MessageError;
98
99    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
100        match message {
101            message::Message::User { content } => {
102                let (_, other_content): (Vec<_>, Vec<_>) = content
103                    .into_iter()
104                    .partition(|content| matches!(content, message::UserContent::ToolResult(_)));
105
106                let messages = other_content
107                    .into_iter()
108                    .filter_map(|content| match content {
109                        message::UserContent::Text(message::Text { text }) => {
110                            Some(Message::User { content: text })
111                        }
112                        _ => None,
113                    })
114                    .collect::<Vec<_>>();
115
116                Ok(messages)
117            }
118            message::Message::Assistant { content, .. } => {
119                let (text_content, tool_calls) = content.into_iter().fold(
120                    (Vec::new(), Vec::new()),
121                    |(mut texts, mut tools), content| {
122                        match content {
123                            message::AssistantContent::Text(text) => texts.push(text),
124                            message::AssistantContent::ToolCall(tool_call) => tools.push(tool_call),
125                            message::AssistantContent::Reasoning(_) => {
126                                unimplemented!("Reasoning content is not currently supported on Mistral via Rig");
127                            }
128                        }
129                        (texts, tools)
130                    },
131                );
132
133                Ok(vec![Message::Assistant {
134                    content: text_content
135                        .into_iter()
136                        .next()
137                        .map(|content| content.text)
138                        .unwrap_or_default(),
139                    tool_calls: tool_calls
140                        .into_iter()
141                        .map(|tool_call| tool_call.into())
142                        .collect::<Vec<_>>(),
143                    prefix: false,
144                }])
145            }
146        }
147    }
148}
149
150#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
151pub struct ToolCall {
152    pub id: String,
153    #[serde(default)]
154    pub r#type: ToolType,
155    pub function: Function,
156}
157
158impl From<message::ToolCall> for ToolCall {
159    fn from(tool_call: message::ToolCall) -> Self {
160        Self {
161            id: tool_call.id,
162            r#type: ToolType::default(),
163            function: Function {
164                name: tool_call.function.name,
165                arguments: tool_call.function.arguments,
166            },
167        }
168    }
169}
170
171#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
172pub struct Function {
173    pub name: String,
174    #[serde(with = "json_utils::stringified_json")]
175    pub arguments: serde_json::Value,
176}
177
178#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
179#[serde(rename_all = "lowercase")]
180pub enum ToolType {
181    #[default]
182    Function,
183}
184
185#[derive(Debug, Deserialize, Serialize, Clone)]
186pub struct ToolDefinition {
187    pub r#type: String,
188    pub function: completion::ToolDefinition,
189}
190
191impl From<completion::ToolDefinition> for ToolDefinition {
192    fn from(tool: completion::ToolDefinition) -> Self {
193        Self {
194            r#type: "function".into(),
195            function: tool,
196        }
197    }
198}
199
200#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
201pub struct ToolResultContent {
202    #[serde(default)]
203    r#type: ToolResultContentType,
204    text: String,
205}
206
207#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
208#[serde(rename_all = "lowercase")]
209pub enum ToolResultContentType {
210    #[default]
211    Text,
212}
213
214impl From<String> for ToolResultContent {
215    fn from(s: String) -> Self {
216        ToolResultContent {
217            r#type: ToolResultContentType::default(),
218            text: s,
219        }
220    }
221}
222
223impl From<String> for UserContent {
224    fn from(s: String) -> Self {
225        UserContent::Text { text: s }
226    }
227}
228
229impl FromStr for UserContent {
230    type Err = Infallible;
231
232    fn from_str(s: &str) -> Result<Self, Self::Err> {
233        Ok(UserContent::Text {
234            text: s.to_string(),
235        })
236    }
237}
238
239impl From<String> for AssistantContent {
240    fn from(s: String) -> Self {
241        AssistantContent { text: s }
242    }
243}
244
245impl FromStr for AssistantContent {
246    type Err = Infallible;
247
248    fn from_str(s: &str) -> Result<Self, Self::Err> {
249        Ok(AssistantContent {
250            text: s.to_string(),
251        })
252    }
253}
254
255#[derive(Clone)]
256pub struct CompletionModel<T = reqwest::Client> {
257    pub(crate) client: Client<T>,
258    pub model: String,
259}
260
261#[derive(Debug, Default, Serialize, Deserialize)]
262pub enum ToolChoice {
263    #[default]
264    Auto,
265    None,
266    Any,
267}
268
269impl TryFrom<message::ToolChoice> for ToolChoice {
270    type Error = CompletionError;
271
272    fn try_from(value: message::ToolChoice) -> Result<Self, Self::Error> {
273        let res = match value {
274            message::ToolChoice::Auto => Self::Auto,
275            message::ToolChoice::None => Self::None,
276            message::ToolChoice::Required => Self::Any,
277            message::ToolChoice::Specific { .. } => {
278                return Err(CompletionError::ProviderError(
279                    "Mistral doesn't support requiring specific tools to be called".to_string(),
280                ));
281            }
282        };
283
284        Ok(res)
285    }
286}
287
288impl<T> CompletionModel<T> {
289    pub fn new(client: Client<T>, model: &str) -> Self {
290        Self {
291            client,
292            model: model.to_string(),
293        }
294    }
295
296    pub(crate) fn create_completion_request(
297        &self,
298        completion_request: CompletionRequest,
299    ) -> Result<Value, CompletionError> {
300        let mut partial_history = vec![];
301        if let Some(docs) = completion_request.normalized_documents() {
302            partial_history.push(docs);
303        }
304
305        partial_history.extend(completion_request.chat_history);
306
307        let mut full_history: Vec<Message> = match &completion_request.preamble {
308            Some(preamble) => vec![Message::system(preamble.clone())],
309            None => vec![],
310        };
311
312        full_history.extend(
313            partial_history
314                .into_iter()
315                .map(message::Message::try_into)
316                .collect::<Result<Vec<Vec<Message>>, _>>()?
317                .into_iter()
318                .flatten()
319                .collect::<Vec<_>>(),
320        );
321
322        let tool_choice = completion_request
323            .tool_choice
324            .map(ToolChoice::try_from)
325            .transpose()?;
326
327        let request = if completion_request.tools.is_empty() {
328            json!({
329                "model": self.model,
330                "messages": full_history,
331
332            })
333        } else {
334            json!({
335                "model": self.model,
336                "messages": full_history,
337                "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
338                "tool_choice": tool_choice,
339            })
340        };
341
342        let request = if let Some(temperature) = completion_request.temperature {
343            json_utils::merge(
344                request,
345                json!({
346                    "temperature": temperature,
347                }),
348            )
349        } else {
350            request
351        };
352
353        let request = if let Some(params) = completion_request.additional_params {
354            json_utils::merge(request, params)
355        } else {
356            request
357        };
358
359        Ok(request)
360    }
361}
362
363#[derive(Debug, Deserialize, Clone, Serialize)]
364pub struct CompletionResponse {
365    pub id: String,
366    pub object: String,
367    pub created: u64,
368    pub model: String,
369    pub system_fingerprint: Option<String>,
370    pub choices: Vec<Choice>,
371    pub usage: Option<Usage>,
372}
373
374impl crate::telemetry::ProviderResponseExt for CompletionResponse {
375    type OutputMessage = Choice;
376    type Usage = Usage;
377
378    fn get_response_id(&self) -> Option<String> {
379        Some(self.id.clone())
380    }
381
382    fn get_response_model_name(&self) -> Option<String> {
383        Some(self.model.clone())
384    }
385
386    fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
387        self.choices.clone()
388    }
389
390    fn get_text_response(&self) -> Option<String> {
391        let res = self
392            .choices
393            .iter()
394            .filter_map(|choice| match choice.message {
395                Message::Assistant { ref content, .. } => {
396                    if content.is_empty() {
397                        None
398                    } else {
399                        Some(content.to_string())
400                    }
401                }
402                _ => None,
403            })
404            .collect::<Vec<String>>()
405            .join("\n");
406
407        if res.is_empty() { None } else { Some(res) }
408    }
409
410    fn get_usage(&self) -> Option<Self::Usage> {
411        self.usage.clone()
412    }
413}
414
415impl GetTokenUsage for CompletionResponse {
416    fn token_usage(&self) -> Option<crate::completion::Usage> {
417        let api_usage = self.usage.clone()?;
418
419        let mut usage = crate::completion::Usage::new();
420        usage.input_tokens = api_usage.prompt_tokens as u64;
421        usage.output_tokens = api_usage.completion_tokens as u64;
422        usage.total_tokens = api_usage.total_tokens as u64;
423
424        Some(usage)
425    }
426}
427
428impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
429    type Error = CompletionError;
430
431    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
432        let choice = response.choices.first().ok_or_else(|| {
433            CompletionError::ResponseError("Response contained no choices".to_owned())
434        })?;
435        let content = match &choice.message {
436            Message::Assistant {
437                content,
438                tool_calls,
439                ..
440            } => {
441                let mut content = if content.is_empty() {
442                    vec![]
443                } else {
444                    vec![completion::AssistantContent::text(content.clone())]
445                };
446
447                content.extend(
448                    tool_calls
449                        .iter()
450                        .map(|call| {
451                            completion::AssistantContent::tool_call(
452                                &call.id,
453                                &call.function.name,
454                                call.function.arguments.clone(),
455                            )
456                        })
457                        .collect::<Vec<_>>(),
458                );
459                Ok(content)
460            }
461            _ => Err(CompletionError::ResponseError(
462                "Response did not contain a valid message or tool call".into(),
463            )),
464        }?;
465
466        let choice = OneOrMany::many(content).map_err(|_| {
467            CompletionError::ResponseError(
468                "Response contained no message or tool call (empty)".to_owned(),
469            )
470        })?;
471
472        let usage = response
473            .usage
474            .as_ref()
475            .map(|usage| completion::Usage {
476                input_tokens: usage.prompt_tokens as u64,
477                output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
478                total_tokens: usage.total_tokens as u64,
479            })
480            .unwrap_or_default();
481
482        Ok(completion::CompletionResponse {
483            choice,
484            usage,
485            raw_response: response,
486        })
487    }
488}
489
490impl<T> completion::CompletionModel for CompletionModel<T>
491where
492    T: HttpClientExt + Send + Clone + std::fmt::Debug + 'static,
493{
494    type Response = CompletionResponse;
495    type StreamingResponse = CompletionResponse;
496
497    #[cfg_attr(feature = "worker", worker::send)]
498    async fn completion(
499        &self,
500        completion_request: CompletionRequest,
501    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
502        let preamble = completion_request.preamble.clone();
503        let request = self.create_completion_request(completion_request)?;
504        let body = serde_json::to_vec(&request)?;
505
506        let span = if tracing::Span::current().is_disabled() {
507            info_span!(
508                target: "rig::completions",
509                "chat",
510                gen_ai.operation.name = "chat",
511                gen_ai.provider.name = "mistral",
512                gen_ai.request.model = self.model,
513                gen_ai.system_instructions = &preamble,
514                gen_ai.response.id = tracing::field::Empty,
515                gen_ai.response.model = tracing::field::Empty,
516                gen_ai.usage.output_tokens = tracing::field::Empty,
517                gen_ai.usage.input_tokens = tracing::field::Empty,
518                gen_ai.input.messages = serde_json::to_string(request.get("messages").unwrap()).unwrap(),
519                gen_ai.output.messages = tracing::field::Empty,
520            )
521        } else {
522            tracing::Span::current()
523        };
524
525        let request = self
526            .client
527            .post("v1/chat/completions")?
528            .header("Content-Type", "application/json")
529            .body(body)
530            .map_err(|e| CompletionError::HttpError(e.into()))?;
531
532        async move {
533            let response = self.client.send(request).await?;
534
535            if response.status().is_success() {
536                let text = http_client::text(response).await?;
537                match serde_json::from_str::<ApiResponse<CompletionResponse>>(&text)? {
538                    ApiResponse::Ok(response) => {
539                        let span = tracing::Span::current();
540                        span.record_token_usage(&response);
541                        span.record_model_output(&response.choices);
542                        span.record_response_metadata(&response);
543                        response.try_into()
544                    }
545                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
546                }
547            } else {
548                let text = http_client::text(response).await?;
549                Err(CompletionError::ProviderError(text))
550            }
551        }
552        .instrument(span)
553        .await
554    }
555
556    #[cfg_attr(feature = "worker", worker::send)]
557    async fn stream(
558        &self,
559        request: CompletionRequest,
560    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
561        let resp = self.completion(request).await?;
562
563        let stream = stream! {
564            for c in resp.choice.clone() {
565                match c {
566                    message::AssistantContent::Text(t) => {
567                        yield Ok(RawStreamingChoice::Message(t.text.clone()))
568                    }
569                    message::AssistantContent::ToolCall(tc) => {
570                        yield Ok(RawStreamingChoice::ToolCall {
571                            id: tc.id.clone(),
572                            name: tc.function.name.clone(),
573                            arguments: tc.function.arguments.clone(),
574                             call_id: None
575                        })
576                    }
577                    message::AssistantContent::Reasoning(_) => {
578                        unimplemented!("Reasoning is not supported on Mistral via Rig")
579                    }
580                }
581            }
582
583            yield Ok(RawStreamingChoice::FinalResponse(resp.raw_response.clone()));
584        };
585
586        Ok(StreamingCompletionResponse::stream(Box::pin(stream)))
587    }
588}
589
590#[cfg(test)]
591mod tests {
592    use super::*;
593
594    #[test]
595    fn test_response_deserialization() {
596        //https://docs.mistral.ai/api/#tag/chat/operation/chat_completion_v1_chat_completions_post
597        let json_data = r#"
598        {
599            "id": "cmpl-e5cc70bb28c444948073e77776eb30ef",
600            "object": "chat.completion",
601            "model": "mistral-small-latest",
602            "usage": {
603                "prompt_tokens": 16,
604                "completion_tokens": 34,
605                "total_tokens": 50
606            },
607            "created": 1702256327,
608            "choices": [
609                {
610                    "index": 0,
611                    "message": {
612                        "content": "string",
613                        "tool_calls": [
614                            {
615                                "id": "null",
616                                "type": "function",
617                                "function": {
618                                    "name": "string",
619                                    "arguments": "{ }"
620                                },
621                                "index": 0
622                            }
623                        ],
624                        "prefix": false,
625                        "role": "assistant"
626                    },
627                    "finish_reason": "stop"
628                }
629            ]
630        }
631        "#;
632        let completion_response = serde_json::from_str::<CompletionResponse>(json_data).unwrap();
633        assert_eq!(completion_response.model, MISTRAL_SMALL);
634
635        let CompletionResponse {
636            id,
637            object,
638            created,
639            choices,
640            usage,
641            ..
642        } = completion_response;
643
644        assert_eq!(id, "cmpl-e5cc70bb28c444948073e77776eb30ef");
645
646        let Usage {
647            completion_tokens,
648            prompt_tokens,
649            total_tokens,
650        } = usage.unwrap();
651
652        assert_eq!(prompt_tokens, 16);
653        assert_eq!(completion_tokens, 34);
654        assert_eq!(total_tokens, 50);
655        assert_eq!(object, "chat.completion".to_string());
656        assert_eq!(created, 1702256327);
657        assert_eq!(choices.len(), 1);
658    }
659}