1use std::sync::Arc;
8
9use async_trait::async_trait;
10use llmsdk_provider::ProviderError;
11use llmsdk_provider::language_model::{
12 CallOptions, GenerateResult, LanguageModel, ReasoningEffort, ResponseFormat, StreamResult,
13 SupportedUrls, UrlPattern,
14};
15use llmsdk_provider::shared::Warning;
16use llmsdk_provider_utils::http::{JsonRequest, post_for_stream, post_json, response_byte_stream};
17use llmsdk_provider_utils::sse::{SseEvent, sse_json_stream};
18
19use crate::PROVIDER_ID;
20use crate::config::Inner;
21
22use super::convert_prompt::convert_prompt;
23use super::options::{MistralChatOptions, parse as parse_mistral_options};
24use super::parse_response::parse_response;
25use super::prepare_tools::prepare as prepare_tools;
26use super::stream::StreamState;
27use super::wire::{
28 ChatChunk, ChatRequest, ChatResponse, ResponseFormat as WireResponseFormat, WireJsonSchema,
29};
30
/// Chat language model that talks to Mistral's `/chat/completions` endpoint.
///
/// Constructed through the crate-internal `MistralChatModel::new`. Cloning shares the
/// provider configuration (it sits behind an [`Arc`]) and copies the model id.
#[derive(Debug, Clone)]
pub struct MistralChatModel {
    /// Shared provider configuration; this model reads its `base_url`, default `headers`,
    /// `http` client and `generate_id` from it.
    pub(crate) inner: Arc<Inner>,
    /// Model identifier, sent as the `model` field of every request.
    pub(crate) model_id: String,
}
40
41impl MistralChatModel {
42 pub(crate) fn new(inner: Arc<Inner>, model_id: String) -> Self {
44 Self { inner, model_id }
45 }
46
47 fn endpoint(&self) -> String {
48 format!("{}/chat/completions", self.inner.base_url)
49 }
50}
51
52#[async_trait]
53impl LanguageModel for MistralChatModel {
54 fn provider(&self) -> &str {
55 PROVIDER_ID
56 }
57
58 fn model_id(&self) -> &str {
59 &self.model_id
60 }
61
62 async fn supported_urls(&self) -> SupportedUrls {
63 let mut map = SupportedUrls::default();
65 map.insert(
66 "application/pdf".to_owned(),
67 vec![UrlPattern::new("^https://.*$")],
68 );
69 map
70 }
71
72 async fn do_generate(&self, options: CallOptions) -> Result<GenerateResult, ProviderError> {
73 let (request, warnings) = build_request(&self.model_id, &options)?;
74 let request_body_value = serde_json::to_value(&request).ok();
75 let endpoint = self.endpoint();
76
77 let mut request_headers = self.inner.headers.clone();
78 if let Some(headers) = &options.headers {
79 for (name, value) in headers {
80 request_headers.insert(name.clone(), value.clone());
81 }
82 }
83
84 let mut http_request = JsonRequest::new(endpoint, request);
85 http_request.headers = request_headers;
86
87 let response = post_json::<_, ChatResponse>(&self.inner.http, http_request).await?;
88
89 parse_response(
90 response.value,
91 response.headers,
92 request_body_value,
93 warnings,
94 )
95 }
96
97 async fn do_stream(&self, options: CallOptions) -> Result<StreamResult, ProviderError> {
98 let (mut request, warnings) = build_request(&self.model_id, &options)?;
99 request.stream = Some(true);
100 let request_body_value = serde_json::to_value(&request).ok();
101
102 let mut request_headers = self.inner.headers.clone();
103 if let Some(headers) = &options.headers {
104 for (name, value) in headers {
105 request_headers.insert(name.clone(), value.clone());
106 }
107 }
108
109 let mut http_request = JsonRequest::new(self.endpoint(), request);
110 http_request.headers = request_headers;
111
112 let stream_response = post_for_stream(&self.inner.http, http_request).await?;
113 let stream_headers = stream_response.headers.clone();
114
115 let byte_stream = response_byte_stream(stream_response.response);
116 let event_stream = sse_json_stream::<ChatChunk>(byte_stream);
117 let state = StreamState::with_generate_id(warnings, self.inner.generate_id.clone());
118 let parts = build_part_stream(state, event_stream);
119
120 Ok(StreamResult {
121 stream: Box::pin(parts),
122 request: Some(llmsdk_provider::shared::RequestInfo {
123 body: request_body_value,
124 }),
125 response: Some(llmsdk_provider::language_model::StreamResponse {
126 headers: Some(headers_to_provider(stream_headers)),
127 }),
128 })
129 }
130}
131
132fn headers_to_provider(
133 raw: std::collections::HashMap<String, String>,
134) -> llmsdk_provider::shared::Headers {
135 raw.into_iter().map(|(k, v)| (k, Some(v))).collect()
136}
137
/// Turns the decoded SSE event stream into provider stream parts.
///
/// The output is produced in this order:
/// 1. the frames returned by `state.start_frames()`;
/// 2. for every event, the parts returned by `state.on_chunk` (decoded chunk) or
///    `state.on_parse_error` (event whose payload failed to decode);
/// 3. the frames returned by `state.flush()` once the event stream is exhausted.
///
/// An `Err` from the event stream is forwarded as a single `Err` item and ends the output
/// immediately; `flush` is not run in that case.
fn build_part_stream<S>(
    mut state: StreamState,
    events: S,
) -> impl futures::Stream<Item = Result<llmsdk_provider::language_model::StreamPart, ProviderError>> + Send
where
    S: futures::Stream<Item = Result<SseEvent<ChatChunk>, ProviderError>> + Send + 'static,
{
    async_stream::stream! {
        for part in state.start_frames() {
            yield Ok(part);
        }

        // Pinned on the heap so `next()` can be polled without requiring `S: Unpin`.
        let mut events = Box::pin(events);
        while let Some(event) = futures::StreamExt::next(&mut events).await {
            match event {
                Ok(SseEvent::Data(chunk)) => {
                    for part in state.on_chunk(chunk) {
                        yield Ok(part);
                    }
                }
                // A malformed event does not stop the loop; the state decides what to emit.
                Ok(SseEvent::ParseError { raw, message }) => {
                    for part in state.on_parse_error(&raw, &message) {
                        yield Ok(part);
                    }
                }
                // Transport-level failure: surface it and stop without flushing.
                Err(e) => {
                    yield Err(e);
                    return;
                }
            }
        }

        for part in state.flush() {
            yield Ok(part);
        }
    }
}
175
176fn build_request(
178 model_id: &str,
179 options: &CallOptions,
180) -> Result<(ChatRequest, Vec<Warning>), ProviderError> {
181 let mistral_opts = parse_mistral_options(options.provider_options.as_ref());
182 let mut warnings: Vec<Warning> = Vec::new();
183
184 for (val, name) in [
186 (options.top_k.is_some(), "topK"),
187 (options.frequency_penalty.is_some(), "frequencyPenalty"),
188 (options.presence_penalty.is_some(), "presencePenalty"),
189 ] {
190 if val {
191 warnings.push(Warning::Unsupported {
192 feature: name.to_owned(),
193 details: Some(format!("Mistral chat completions does not accept {name}")),
194 });
195 }
196 }
197
198 let reasoning_effort =
199 resolve_reasoning_effort(model_id, &mistral_opts, options.reasoning, &mut warnings);
200
201 let (mut messages, msg_warnings) = convert_prompt(&options.prompt)?;
202 warnings.extend(msg_warnings);
203
204 if matches!(
209 options.response_format.as_ref(),
210 Some(ResponseFormat::Json { schema: None, .. })
211 ) {
212 inject_json_instruction(&mut messages);
213 }
214
215 let prepared = prepare_tools(
216 options.tools.as_deref().unwrap_or(&[]),
217 options.tool_choice.as_ref(),
218 );
219 warnings.extend(prepared.warnings);
220
221 let response_format = options
222 .response_format
223 .as_ref()
224 .and_then(|fmt| convert_response_format(fmt, &mistral_opts));
225
226 let parallel_tool_calls = if prepared.tools.is_some() {
227 mistral_opts.parallel_tool_calls
228 } else {
229 None
230 };
231
232 let request = ChatRequest {
233 model: model_id.to_owned(),
234 messages,
235 stream: None,
236 safe_prompt: mistral_opts.safe_prompt,
237 max_tokens: options.max_output_tokens,
238 temperature: options.temperature,
239 top_p: options.top_p,
240 stop: options.stop_sequences.clone(),
241 random_seed: options.seed,
242 reasoning_effort,
243 response_format,
244 document_image_limit: mistral_opts.document_image_limit,
245 document_page_limit: mistral_opts.document_page_limit,
246 tools: prepared.tools,
247 tool_choice: prepared.tool_choice,
248 parallel_tool_calls,
249 };
250
251 Ok((request, warnings))
252}
253
254fn inject_json_instruction(messages: &mut Vec<super::wire::WireMessage>) {
258 const SUFFIX: &str = "You MUST answer with JSON.";
259 match messages.first_mut() {
260 Some(super::wire::WireMessage::System { content }) => {
261 if content.is_empty() {
262 SUFFIX.clone_into(content);
263 } else {
264 content.push('\n');
265 content.push_str(SUFFIX);
266 }
267 }
268 _ => {
269 messages.insert(
270 0,
271 super::wire::WireMessage::System {
272 content: SUFFIX.to_owned(),
273 },
274 );
275 }
276 }
277}
278
279fn convert_response_format(
280 fmt: &ResponseFormat,
281 mistral: &MistralChatOptions,
282) -> Option<WireResponseFormat> {
283 match fmt {
284 ResponseFormat::Text => None,
285 ResponseFormat::Json {
286 schema,
287 name,
288 description,
289 } => {
290 let structured_outputs = mistral.structured_outputs.unwrap_or(true);
291 let strict_json_schema = mistral.strict_json_schema.unwrap_or(false);
292 Some(match schema {
293 Some(schema) if structured_outputs => WireResponseFormat::JsonSchema {
294 json_schema: WireJsonSchema {
295 name: name.clone().unwrap_or_else(|| "response".to_owned()),
296 schema: serde_json::to_value(schema).unwrap_or(serde_json::Value::Null),
297 strict: strict_json_schema,
298 description: description.clone(),
299 },
300 },
301 _ => WireResponseFormat::JsonObject,
302 })
303 }
304 }
305}
306
307fn resolve_reasoning_effort(
308 model_id: &str,
309 mistral: &MistralChatOptions,
310 top_level: Option<ReasoningEffort>,
311 warnings: &mut Vec<Warning>,
312) -> Option<String> {
313 let supports = supports_reasoning_effort(model_id);
314
315 if !supports {
316 if top_level.is_some() && !matches!(top_level, Some(ReasoningEffort::ProviderDefault)) {
317 warnings.push(Warning::Unsupported {
318 feature: "reasoning".to_owned(),
319 details: Some("This model does not support reasoning configuration.".to_owned()),
320 });
321 }
322 return None;
323 }
324
325 if let Some(effort) = &mistral.reasoning_effort {
326 return Some(effort.clone());
327 }
328 match top_level? {
329 ReasoningEffort::ProviderDefault => None,
330 ReasoningEffort::None => Some("none".to_owned()),
331 ReasoningEffort::Minimal
332 | ReasoningEffort::Low
333 | ReasoningEffort::Medium
334 | ReasoningEffort::High
335 | ReasoningEffort::Xhigh => Some("high".to_owned()),
336 }
337}
338
/// Returns `true` when `model_id` is one of the model ids for which a `reasoning_effort`
/// value is sent. Matching is exact: no prefix, suffix or alias resolution is performed.
fn supports_reasoning_effort(model_id: &str) -> bool {
    const REASONING_MODEL_IDS: [&str; 4] = [
        "mistral-small-latest",
        "mistral-small-2603",
        "mistral-medium-3",
        "mistral-medium-3.5",
    ];
    REASONING_MODEL_IDS.contains(&model_id)
}
346
#[cfg(test)]
mod tests {
    //! Unit tests for request construction. They inspect the `ChatRequest` and warnings
    //! returned by `build_request` (and its JSON serialisation) without any HTTP traffic.

    use super::*;
    use llmsdk_provider::language_model::TextPart;
    use llmsdk_provider::language_model::{FunctionTool, Message, Tool, ToolChoice, UserPart};
    use serde_json::json;

    /// Minimal call options: a single user message "hi", everything else defaulted.
    fn opts() -> CallOptions {
        CallOptions {
            prompt: vec![Message::User {
                content: vec![UserPart::Text(TextPart {
                    text: "hi".into(),
                    provider_options: None,
                })],
                provider_options: None,
            }],
            ..Default::default()
        }
    }

    // Each of topK / frequencyPenalty / presencePenalty yields exactly one warning.
    #[test]
    fn warns_on_topk_frequency_presence() {
        let mut o = opts();
        o.top_k = Some(5);
        o.frequency_penalty = Some(0.1);
        o.presence_penalty = Some(0.1);
        let (_, warnings) = build_request("mistral-small-latest", &o).unwrap();
        assert_eq!(warnings.len(), 3);
    }

    // Stop sequences are forwarded as `stop` and are not reported as unsupported.
    #[test]
    fn stop_sequences_pass_through_without_warning() {
        let mut o = opts();
        o.stop_sequences = Some(vec!["END".into()]);
        let (req, warnings) = build_request("mistral-small-latest", &o).unwrap();
        assert_eq!(req.stop, Some(vec!["END".into()]));
        assert!(warnings.iter().all(
            |w| !matches!(w, Warning::Unsupported { feature, .. } if feature == "stopSequences")
        ));
    }

    // The generic `seed` is serialised under Mistral's `random_seed` key only.
    #[test]
    fn seed_serializes_as_random_seed() {
        let mut o = opts();
        o.seed = Some(42);
        let (req, _) = build_request("mistral-small-latest", &o).unwrap();
        assert_eq!(req.random_seed, Some(42));
        let body = serde_json::to_value(&req).unwrap();
        assert_eq!(body["random_seed"], 42);
        assert!(body.get("seed").is_none());
    }

    #[test]
    fn max_output_tokens_serializes_as_max_tokens() {
        let mut o = opts();
        o.max_output_tokens = Some(123);
        let (req, _) = build_request("mistral-small-latest", &o).unwrap();
        assert_eq!(req.max_tokens, Some(123));
    }

    // Provider options live under the "mistral" key; `safePrompt` maps to `safe_prompt`.
    #[test]
    fn safe_prompt_provider_option_pass_through() {
        let mut o = opts();
        let mut po = llmsdk_provider::shared::ProviderOptions::new();
        po.insert(
            "mistral".into(),
            json!({"safePrompt": true}).as_object().cloned().unwrap(),
        );
        o.provider_options = Some(po);
        let (req, _) = build_request("mistral-small-latest", &o).unwrap();
        assert_eq!(req.safe_prompt, Some(true));
    }

    // A model outside the reasoning allow-list drops the setting and warns.
    #[test]
    fn unsupported_reasoning_warns_for_non_reasoning_model() {
        let mut o = opts();
        o.reasoning = Some(ReasoningEffort::High);
        let (req, warnings) = build_request("mistral-large-latest", &o).unwrap();
        assert!(req.reasoning_effort.is_none());
        assert!(warnings.iter().any(|w| matches!(
            w,
            Warning::Unsupported { feature, .. } if feature == "reasoning"
        )));
    }

    // Any non-"none" effort level is sent as "high".
    #[test]
    fn reasoning_effort_coerces_to_high_for_supported_model() {
        let mut o = opts();
        o.reasoning = Some(ReasoningEffort::Low);
        let (req, _) = build_request("mistral-small-latest", &o).unwrap();
        assert_eq!(req.reasoning_effort.as_deref(), Some("high"));
    }

    #[test]
    fn reasoning_effort_none_passes_through() {
        let mut o = opts();
        o.reasoning = Some(ReasoningEffort::None);
        let (req, _) = build_request("mistral-medium-3.5", &o).unwrap();
        assert_eq!(req.reasoning_effort.as_deref(), Some("none"));
    }

    // The provider option `reasoningEffort` overrides the top-level `reasoning` setting.
    #[test]
    fn provider_options_reasoning_effort_wins() {
        let mut o = opts();
        o.reasoning = Some(ReasoningEffort::Low);
        let mut po = llmsdk_provider::shared::ProviderOptions::new();
        po.insert(
            "mistral".into(),
            json!({"reasoningEffort": "none"})
                .as_object()
                .cloned()
                .unwrap(),
        );
        o.provider_options = Some(po);
        let (req, _) = build_request("mistral-small-latest", &o).unwrap();
        assert_eq!(req.reasoning_effort.as_deref(), Some("none"));
    }

    // `parallelToolCalls` is omitted without tools and forwarded once a tool is present.
    #[test]
    fn parallel_tool_calls_only_when_tools_present() {
        let mut o = opts();
        let mut po = llmsdk_provider::shared::ProviderOptions::new();
        po.insert(
            "mistral".into(),
            json!({"parallelToolCalls": false})
                .as_object()
                .cloned()
                .unwrap(),
        );
        o.provider_options = Some(po.clone());
        let (req, _) = build_request("mistral-small-latest", &o).unwrap();
        assert_eq!(req.parallel_tool_calls, None);

        o.tools = Some(vec![Tool::Function(FunctionTool {
            name: "weather".into(),
            description: None,
            input_schema: serde_json::from_value(json!({"type":"object"})).unwrap(),
            input_examples: None,
            strict: None,
            provider_options: None,
        })]);
        let (req, _) = build_request("mistral-small-latest", &o).unwrap();
        assert_eq!(req.parallel_tool_calls, Some(false));
    }

    // `ToolChoice::Required` is serialised as Mistral's "any".
    #[test]
    fn function_tool_pass_through_with_tool_choice_required() {
        let mut o = opts();
        o.tools = Some(vec![Tool::Function(FunctionTool {
            name: "weather".into(),
            description: Some("get weather".into()),
            input_schema: serde_json::from_value(
                json!({"type":"object","properties":{"c":{"type":"string"}}}),
            )
            .unwrap(),
            input_examples: None,
            strict: None,
            provider_options: None,
        })]);
        o.tool_choice = Some(ToolChoice::Required);
        let (req, _) = build_request("mistral-small-latest", &o).unwrap();
        assert!(req.tools.is_some());
        let choice = serde_json::to_value(req.tool_choice.unwrap()).unwrap();
        assert_eq!(choice, json!("any"));
    }

    // A JSON request without a schema becomes `json_object`.
    #[test]
    fn json_response_format_object_default() {
        let mut o = opts();
        o.response_format = Some(ResponseFormat::Json {
            schema: None,
            name: None,
            description: None,
        });
        let (req, _) = build_request("mistral-small-latest", &o).unwrap();
        let body = serde_json::to_value(req.response_format).unwrap();
        assert_eq!(body["type"], "json_object");
    }

    // With a schema and default options: `json_schema`, name/description kept, not strict.
    #[test]
    fn json_response_format_schema_when_structured_outputs() {
        let mut o = opts();
        o.response_format = Some(ResponseFormat::Json {
            schema: Some(serde_json::from_value(json!({"type":"object"})).unwrap()),
            name: Some("MySchema".into()),
            description: Some("a schema".into()),
        });
        let (req, _) = build_request("mistral-small-latest", &o).unwrap();
        let body = serde_json::to_value(req.response_format).unwrap();
        assert_eq!(body["type"], "json_schema");
        assert_eq!(body["json_schema"]["name"], "MySchema");
        assert_eq!(body["json_schema"]["description"], "a schema");
        assert_eq!(body["json_schema"]["strict"], false);
    }

    // The `strictJsonSchema` provider option sets `strict: true` on the schema.
    #[test]
    fn json_response_format_strict_pass_through() {
        let mut o = opts();
        o.response_format = Some(ResponseFormat::Json {
            schema: Some(serde_json::from_value(json!({"type":"object"})).unwrap()),
            name: None,
            description: None,
        });
        let mut po = llmsdk_provider::shared::ProviderOptions::new();
        po.insert(
            "mistral".into(),
            json!({"strictJsonSchema": true})
                .as_object()
                .cloned()
                .unwrap(),
        );
        o.provider_options = Some(po);
        let (req, _) = build_request("mistral-small-latest", &o).unwrap();
        let body = serde_json::to_value(req.response_format).unwrap();
        assert_eq!(body["json_schema"]["strict"], true);
    }
}