Skip to main content

llmsdk_mistral/chat/
model.rs

1//! [`LanguageModel`] implementation for Mistral Chat Completions.
2//!
3//! Mirrors `mistral-chat-language-model.ts`. Entry: [`MistralChatModel::new`]
4//! via [`crate::Mistral::chat`].
5// Rust guideline compliant 2026-05-25
6
7use std::sync::Arc;
8
9use async_trait::async_trait;
10use llmsdk_provider::ProviderError;
11use llmsdk_provider::language_model::{
12    CallOptions, GenerateResult, LanguageModel, ReasoningEffort, ResponseFormat, StreamResult,
13    SupportedUrls, UrlPattern,
14};
15use llmsdk_provider::shared::Warning;
16use llmsdk_provider_utils::http::{JsonRequest, post_for_stream, post_json, response_byte_stream};
17use llmsdk_provider_utils::sse::{SseEvent, sse_json_stream};
18
19use crate::PROVIDER_ID;
20use crate::config::Inner;
21
22use super::convert_prompt::convert_prompt;
23use super::options::{MistralChatOptions, parse as parse_mistral_options};
24use super::parse_response::parse_response;
25use super::prepare_tools::prepare as prepare_tools;
26use super::stream::StreamState;
27use super::wire::{
28    ChatChunk, ChatRequest, ChatResponse, ResponseFormat as WireResponseFormat, WireJsonSchema,
29};
30
31/// Mistral Chat Completions model handle.
32///
33/// Cheap to clone. Multiple clones share the underlying HTTP client and
34/// authentication state via [`Mistral`](crate::Mistral)'s `Arc`.
35#[derive(Debug, Clone)]
36pub struct MistralChatModel {
37    pub(crate) inner: Arc<Inner>,
38    pub(crate) model_id: String,
39}
40
41impl MistralChatModel {
42    /// Construct from shared provider state and a model id.
43    pub(crate) fn new(inner: Arc<Inner>, model_id: String) -> Self {
44        Self { inner, model_id }
45    }
46
47    fn endpoint(&self) -> String {
48        format!("{}/chat/completions", self.inner.base_url)
49    }
50}
51
52#[async_trait]
53impl LanguageModel for MistralChatModel {
54    fn provider(&self) -> &str {
55        PROVIDER_ID
56    }
57
58    fn model_id(&self) -> &str {
59        &self.model_id
60    }
61
62    async fn supported_urls(&self) -> SupportedUrls {
63        // Mirrors upstream's `supportedUrls = { 'application/pdf': [/^https:\/\/.*$/] }`.
64        let mut map = SupportedUrls::default();
65        map.insert(
66            "application/pdf".to_owned(),
67            vec![UrlPattern::new("^https://.*$")],
68        );
69        map
70    }
71
72    async fn do_generate(&self, options: CallOptions) -> Result<GenerateResult, ProviderError> {
73        let (request, warnings) = build_request(&self.model_id, &options)?;
74        let request_body_value = serde_json::to_value(&request).ok();
75        let endpoint = self.endpoint();
76
77        let mut request_headers = self.inner.headers.clone();
78        if let Some(headers) = &options.headers {
79            for (name, value) in headers {
80                request_headers.insert(name.clone(), value.clone());
81            }
82        }
83
84        let mut http_request = JsonRequest::new(endpoint, request);
85        http_request.headers = request_headers;
86
87        let response = post_json::<_, ChatResponse>(&self.inner.http, http_request).await?;
88
89        parse_response(
90            response.value,
91            response.headers,
92            request_body_value,
93            warnings,
94        )
95    }
96
97    async fn do_stream(&self, options: CallOptions) -> Result<StreamResult, ProviderError> {
98        let (mut request, warnings) = build_request(&self.model_id, &options)?;
99        request.stream = Some(true);
100        let request_body_value = serde_json::to_value(&request).ok();
101
102        let mut request_headers = self.inner.headers.clone();
103        if let Some(headers) = &options.headers {
104            for (name, value) in headers {
105                request_headers.insert(name.clone(), value.clone());
106            }
107        }
108
109        let mut http_request = JsonRequest::new(self.endpoint(), request);
110        http_request.headers = request_headers;
111
112        let stream_response = post_for_stream(&self.inner.http, http_request).await?;
113        let stream_headers = stream_response.headers.clone();
114
115        let byte_stream = response_byte_stream(stream_response.response);
116        let event_stream = sse_json_stream::<ChatChunk>(byte_stream);
117        let state = StreamState::with_generate_id(warnings, self.inner.generate_id.clone());
118        let parts = build_part_stream(state, event_stream);
119
120        Ok(StreamResult {
121            stream: Box::pin(parts),
122            request: Some(llmsdk_provider::shared::RequestInfo {
123                body: request_body_value,
124            }),
125            response: Some(llmsdk_provider::language_model::StreamResponse {
126                headers: Some(headers_to_provider(stream_headers)),
127            }),
128        })
129    }
130}
131
132fn headers_to_provider(
133    raw: std::collections::HashMap<String, String>,
134) -> llmsdk_provider::shared::Headers {
135    raw.into_iter().map(|(k, v)| (k, Some(v))).collect()
136}
137
138fn build_part_stream<S>(
139    mut state: StreamState,
140    events: S,
141) -> impl futures::Stream<Item = Result<llmsdk_provider::language_model::StreamPart, ProviderError>> + Send
142where
143    S: futures::Stream<Item = Result<SseEvent<ChatChunk>, ProviderError>> + Send + 'static,
144{
145    async_stream::stream! {
146        for part in state.start_frames() {
147            yield Ok(part);
148        }
149
150        let mut events = Box::pin(events);
151        while let Some(event) = futures::StreamExt::next(&mut events).await {
152            match event {
153                Ok(SseEvent::Data(chunk)) => {
154                    for part in state.on_chunk(chunk) {
155                        yield Ok(part);
156                    }
157                }
158                Ok(SseEvent::ParseError { raw, message }) => {
159                    for part in state.on_parse_error(&raw, &message) {
160                        yield Ok(part);
161                    }
162                }
163                Err(e) => {
164                    yield Err(e);
165                    return;
166                }
167            }
168        }
169
170        for part in state.flush() {
171            yield Ok(part);
172        }
173    }
174}
175
176/// Build the wire request and collect warnings about dropped settings.
177fn build_request(
178    model_id: &str,
179    options: &CallOptions,
180) -> Result<(ChatRequest, Vec<Warning>), ProviderError> {
181    let mistral_opts = parse_mistral_options(options.provider_options.as_ref());
182    let mut warnings: Vec<Warning> = Vec::new();
183
184    // Mistral does not accept these three sampling parameters.
185    for (val, name) in [
186        (options.top_k.is_some(), "topK"),
187        (options.frequency_penalty.is_some(), "frequencyPenalty"),
188        (options.presence_penalty.is_some(), "presencePenalty"),
189    ] {
190        if val {
191            warnings.push(Warning::Unsupported {
192                feature: name.to_owned(),
193                details: Some(format!("Mistral chat completions does not accept {name}")),
194            });
195        }
196    }
197
198    let reasoning_effort =
199        resolve_reasoning_effort(model_id, &mistral_opts, options.reasoning, &mut warnings);
200
201    let (mut messages, msg_warnings) = convert_prompt(&options.prompt)?;
202    warnings.extend(msg_warnings);
203
204    // Mistral needs an explicit instruction when caller asks for a generic
205    // JSON response (json mode without a schema) — mirrors ai-sdk's
206    // `injectJsonInstructionIntoMessages`. See:
207    //   https://docs.mistral.ai/capabilities/structured-output/structured_output_overview/
208    if matches!(
209        options.response_format.as_ref(),
210        Some(ResponseFormat::Json { schema: None, .. })
211    ) {
212        inject_json_instruction(&mut messages);
213    }
214
215    let prepared = prepare_tools(
216        options.tools.as_deref().unwrap_or(&[]),
217        options.tool_choice.as_ref(),
218    );
219    warnings.extend(prepared.warnings);
220
221    let response_format = options
222        .response_format
223        .as_ref()
224        .and_then(|fmt| convert_response_format(fmt, &mistral_opts));
225
226    let parallel_tool_calls = if prepared.tools.is_some() {
227        mistral_opts.parallel_tool_calls
228    } else {
229        None
230    };
231
232    let request = ChatRequest {
233        model: model_id.to_owned(),
234        messages,
235        stream: None,
236        safe_prompt: mistral_opts.safe_prompt,
237        max_tokens: options.max_output_tokens,
238        temperature: options.temperature,
239        top_p: options.top_p,
240        stop: options.stop_sequences.clone(),
241        random_seed: options.seed,
242        reasoning_effort,
243        response_format,
244        document_image_limit: mistral_opts.document_image_limit,
245        document_page_limit: mistral_opts.document_page_limit,
246        tools: prepared.tools,
247        tool_choice: prepared.tool_choice,
248        parallel_tool_calls,
249    };
250
251    Ok((request, warnings))
252}
253
254/// Append a JSON instruction to the leading system message, creating one if
255/// absent. Mirrors `injectJsonInstructionIntoMessages` in
256/// `@ai-sdk/provider-utils`.
257fn inject_json_instruction(messages: &mut Vec<super::wire::WireMessage>) {
258    const SUFFIX: &str = "You MUST answer with JSON.";
259    match messages.first_mut() {
260        Some(super::wire::WireMessage::System { content }) => {
261            if content.is_empty() {
262                SUFFIX.clone_into(content);
263            } else {
264                content.push('\n');
265                content.push_str(SUFFIX);
266            }
267        }
268        _ => {
269            messages.insert(
270                0,
271                super::wire::WireMessage::System {
272                    content: SUFFIX.to_owned(),
273                },
274            );
275        }
276    }
277}
278
279fn convert_response_format(
280    fmt: &ResponseFormat,
281    mistral: &MistralChatOptions,
282) -> Option<WireResponseFormat> {
283    match fmt {
284        ResponseFormat::Text => None,
285        ResponseFormat::Json {
286            schema,
287            name,
288            description,
289        } => {
290            let structured_outputs = mistral.structured_outputs.unwrap_or(true);
291            let strict_json_schema = mistral.strict_json_schema.unwrap_or(false);
292            Some(match schema {
293                Some(schema) if structured_outputs => WireResponseFormat::JsonSchema {
294                    json_schema: WireJsonSchema {
295                        name: name.clone().unwrap_or_else(|| "response".to_owned()),
296                        schema: serde_json::to_value(schema).unwrap_or(serde_json::Value::Null),
297                        strict: strict_json_schema,
298                        description: description.clone(),
299                    },
300                },
301                _ => WireResponseFormat::JsonObject,
302            })
303        }
304    }
305}
306
307fn resolve_reasoning_effort(
308    model_id: &str,
309    mistral: &MistralChatOptions,
310    top_level: Option<ReasoningEffort>,
311    warnings: &mut Vec<Warning>,
312) -> Option<String> {
313    let supports = supports_reasoning_effort(model_id);
314
315    if !supports {
316        if top_level.is_some() && !matches!(top_level, Some(ReasoningEffort::ProviderDefault)) {
317            warnings.push(Warning::Unsupported {
318                feature: "reasoning".to_owned(),
319                details: Some("This model does not support reasoning configuration.".to_owned()),
320            });
321        }
322        return None;
323    }
324
325    if let Some(effort) = &mistral.reasoning_effort {
326        return Some(effort.clone());
327    }
328    match top_level? {
329        ReasoningEffort::ProviderDefault => None,
330        ReasoningEffort::None => Some("none".to_owned()),
331        ReasoningEffort::Minimal
332        | ReasoningEffort::Low
333        | ReasoningEffort::Medium
334        | ReasoningEffort::High
335        | ReasoningEffort::Xhigh => Some("high".to_owned()),
336    }
337}
338
/// Whether `model_id` accepts the `reasoning_effort` parameter.
///
/// Mirrors the upstream `supportsReasoningEffort` allowlist. Matching is
/// exact and case-sensitive.
fn supports_reasoning_effort(model_id: &str) -> bool {
    const ALLOWLIST: &[&str] = &[
        "mistral-small-latest",
        "mistral-small-2603",
        "mistral-medium-3",
        "mistral-medium-3.5",
    ];
    ALLOWLIST.contains(&model_id)
}
346
#[cfg(test)]
mod tests {
    use super::*;
    use llmsdk_provider::language_model::TextPart;
    use llmsdk_provider::language_model::{FunctionTool, Message, Tool, ToolChoice, UserPart};
    use serde_json::json;

    /// Minimal call options: a single user message saying "hi".
    fn opts() -> CallOptions {
        CallOptions {
            prompt: vec![Message::User {
                content: vec![UserPart::Text(TextPart {
                    text: "hi".into(),
                    provider_options: None,
                })],
                provider_options: None,
            }],
            ..Default::default()
        }
    }

    #[test]
    fn warns_on_topk_frequency_presence() {
        let mut o = opts();
        o.top_k = Some(5);
        o.frequency_penalty = Some(0.1);
        o.presence_penalty = Some(0.1);
        let (_, warnings) = build_request("mistral-small-latest", &o).unwrap();
        assert_eq!(warnings.len(), 3);
    }

    #[test]
    fn stop_sequences_pass_through_without_warning() {
        let mut o = opts();
        o.stop_sequences = Some(vec!["END".into()]);
        let (req, warnings) = build_request("mistral-small-latest", &o).unwrap();
        assert_eq!(req.stop, Some(vec!["END".into()]));
        assert!(warnings.iter().all(
            |w| !matches!(w, Warning::Unsupported { feature, .. } if feature == "stopSequences")
        ));
    }

    #[test]
    fn seed_serializes_as_random_seed() {
        let mut o = opts();
        o.seed = Some(42);
        let (req, _) = build_request("mistral-small-latest", &o).unwrap();
        assert_eq!(req.random_seed, Some(42));
        let body = serde_json::to_value(&req).unwrap();
        assert_eq!(body["random_seed"], 42);
        assert!(body.get("seed").is_none());
    }

    #[test]
    fn max_output_tokens_serializes_as_max_tokens() {
        let mut o = opts();
        o.max_output_tokens = Some(123);
        let (req, _) = build_request("mistral-small-latest", &o).unwrap();
        assert_eq!(req.max_tokens, Some(123));
    }

    #[test]
    fn safe_prompt_provider_option_pass_through() {
        let mut o = opts();
        let mut po = llmsdk_provider::shared::ProviderOptions::new();
        po.insert(
            "mistral".into(),
            json!({"safePrompt": true}).as_object().cloned().unwrap(),
        );
        o.provider_options = Some(po);
        let (req, _) = build_request("mistral-small-latest", &o).unwrap();
        assert_eq!(req.safe_prompt, Some(true));
    }

    #[test]
    fn unsupported_reasoning_warns_for_non_reasoning_model() {
        let mut o = opts();
        o.reasoning = Some(ReasoningEffort::High);
        let (req, warnings) = build_request("mistral-large-latest", &o).unwrap();
        assert!(req.reasoning_effort.is_none());
        assert!(warnings.iter().any(|w| matches!(
            w,
            Warning::Unsupported { feature, .. } if feature == "reasoning"
        )));
    }

    #[test]
    fn reasoning_effort_coerces_to_high_for_supported_model() {
        let mut o = opts();
        o.reasoning = Some(ReasoningEffort::Low);
        let (req, _) = build_request("mistral-small-latest", &o).unwrap();
        assert_eq!(req.reasoning_effort.as_deref(), Some("high"));
    }

    #[test]
    fn reasoning_effort_none_passes_through() {
        let mut o = opts();
        o.reasoning = Some(ReasoningEffort::None);
        let (req, _) = build_request("mistral-medium-3.5", &o).unwrap();
        assert_eq!(req.reasoning_effort.as_deref(), Some("none"));
    }

    #[test]
    fn provider_options_reasoning_effort_wins() {
        let mut o = opts();
        o.reasoning = Some(ReasoningEffort::Low);
        let mut po = llmsdk_provider::shared::ProviderOptions::new();
        po.insert(
            "mistral".into(),
            json!({"reasoningEffort": "none"})
                .as_object()
                .cloned()
                .unwrap(),
        );
        o.provider_options = Some(po);
        let (req, _) = build_request("mistral-small-latest", &o).unwrap();
        assert_eq!(req.reasoning_effort.as_deref(), Some("none"));
    }

    #[test]
    fn parallel_tool_calls_only_when_tools_present() {
        let mut o = opts();
        let mut po = llmsdk_provider::shared::ProviderOptions::new();
        po.insert(
            "mistral".into(),
            json!({"parallelToolCalls": false})
                .as_object()
                .cloned()
                .unwrap(),
        );
        o.provider_options = Some(po.clone());
        let (req, _) = build_request("mistral-small-latest", &o).unwrap();
        assert_eq!(req.parallel_tool_calls, None);

        o.tools = Some(vec![Tool::Function(FunctionTool {
            name: "weather".into(),
            description: None,
            input_schema: serde_json::from_value(json!({"type":"object"})).unwrap(),
            input_examples: None,
            strict: None,
            provider_options: None,
        })]);
        let (req, _) = build_request("mistral-small-latest", &o).unwrap();
        assert_eq!(req.parallel_tool_calls, Some(false));
    }

    #[test]
    fn function_tool_pass_through_with_tool_choice_required() {
        let mut o = opts();
        o.tools = Some(vec![Tool::Function(FunctionTool {
            name: "weather".into(),
            description: Some("get weather".into()),
            input_schema: serde_json::from_value(
                json!({"type":"object","properties":{"c":{"type":"string"}}}),
            )
            .unwrap(),
            input_examples: None,
            strict: None,
            provider_options: None,
        })]);
        o.tool_choice = Some(ToolChoice::Required);
        let (req, _) = build_request("mistral-small-latest", &o).unwrap();
        assert!(req.tools.is_some());
        let choice = serde_json::to_value(req.tool_choice.unwrap()).unwrap();
        assert_eq!(choice, json!("any"));
    }

    #[test]
    fn json_response_format_object_default() {
        let mut o = opts();
        o.response_format = Some(ResponseFormat::Json {
            schema: None,
            name: None,
            description: None,
        });
        let (req, _) = build_request("mistral-small-latest", &o).unwrap();
        let body = serde_json::to_value(req.response_format).unwrap();
        assert_eq!(body["type"], "json_object");
    }

    #[test]
    fn json_mode_without_schema_prepends_system_instruction() {
        // The user-only prompt has no system message, so one must be created.
        let mut o = opts();
        o.response_format = Some(ResponseFormat::Json {
            schema: None,
            name: None,
            description: None,
        });
        let (req, _) = build_request("mistral-small-latest", &o).unwrap();
        match req.messages.first() {
            Some(super::super::wire::WireMessage::System { content }) => {
                assert_eq!(content.as_str(), "You MUST answer with JSON.");
            }
            _ => panic!("expected a leading system message carrying the JSON instruction"),
        }
    }

    #[test]
    fn json_mode_with_schema_does_not_inject_instruction() {
        let mut o = opts();
        o.response_format = Some(ResponseFormat::Json {
            schema: Some(serde_json::from_value(json!({"type":"object"})).unwrap()),
            name: None,
            description: None,
        });
        let (req, _) = build_request("mistral-small-latest", &o).unwrap();
        assert!(!matches!(
            req.messages.first(),
            Some(super::super::wire::WireMessage::System { .. })
        ));
    }

    #[test]
    fn json_response_format_schema_when_structured_outputs() {
        let mut o = opts();
        o.response_format = Some(ResponseFormat::Json {
            schema: Some(serde_json::from_value(json!({"type":"object"})).unwrap()),
            name: Some("MySchema".into()),
            description: Some("a schema".into()),
        });
        let (req, _) = build_request("mistral-small-latest", &o).unwrap();
        let body = serde_json::to_value(req.response_format).unwrap();
        assert_eq!(body["type"], "json_schema");
        assert_eq!(body["json_schema"]["name"], "MySchema");
        assert_eq!(body["json_schema"]["description"], "a schema");
        assert_eq!(body["json_schema"]["strict"], false);
    }

    #[test]
    fn json_response_format_strict_pass_through() {
        let mut o = opts();
        o.response_format = Some(ResponseFormat::Json {
            schema: Some(serde_json::from_value(json!({"type":"object"})).unwrap()),
            name: None,
            description: None,
        });
        let mut po = llmsdk_provider::shared::ProviderOptions::new();
        po.insert(
            "mistral".into(),
            json!({"strictJsonSchema": true})
                .as_object()
                .cloned()
                .unwrap(),
        );
        o.provider_options = Some(po);
        let (req, _) = build_request("mistral-small-latest", &o).unwrap();
        let body = serde_json::to_value(req.response_format).unwrap();
        assert_eq!(body["json_schema"]["strict"], true);
    }
}