1use async_trait::async_trait;
8use bytes::Bytes;
9use futures::{Stream, StreamExt};
10use reqwest::Client;
11use serde::Deserialize;
12use serde_json::Value as JsonValue;
13use std::pin::Pin;
14use std::sync::Arc;
15
16use crate::{
17 Api, AssistantMessage, ContentBlock, Context, Model, Provider, ProviderError, ProviderEvent,
18 StopReason, StreamOptions, Usage,
19};
20
21use super::shared_client;
22
/// Provider that sends chat-completion requests to a Mistral-style HTTP endpoint.
#[derive(Clone)]
pub struct MistralProvider {
    /// HTTP client; both constructors obtain it from `shared_client()`.
    client: &'static Client,
    /// Key stored on the provider itself. `stream` prefers a key supplied through
    /// `StreamOptions` and only falls back to this one.
    api_key: Option<String>,
}
29
/// Base URL used by `stream` when the model does not specify its own `base_url`.
const MISTRAL_API_URL: &str = "https://api.mistral.ai/v1";
32
/// Length that `normalize_tool_call_id` truncates or pads over-long tool-call ids to.
// NOTE(review): presumably the id length the Mistral API insists on — confirm against the API docs.
const MISTRAL_TOOL_CALL_ID_LENGTH: usize = 9;
35
36impl MistralProvider {
37 pub fn new() -> Self {
41 Self {
42 client: shared_client(),
43 api_key: None,
44 }
45 }
46
47 #[cfg(test)]
49 pub fn with_api_key(api_key: impl Into<String>) -> Self {
50 Self {
51 client: shared_client(),
52 api_key: Some(api_key.into()),
53 }
54 }
55
56 fn normalize_tool_call_id(id: &str) -> String {
62 if id.len() <= MISTRAL_TOOL_CALL_ID_LENGTH {
64 return id.to_string();
65 }
66
67 let normalized: String = id
69 .chars()
70 .filter(|c| c.is_alphanumeric())
71 .take(MISTRAL_TOOL_CALL_ID_LENGTH)
72 .collect();
73
74 if normalized.len() < MISTRAL_TOOL_CALL_ID_LENGTH {
76 format!(
77 "{}{}",
78 normalized,
79 "0".repeat(MISTRAL_TOOL_CALL_ID_LENGTH - normalized.len())
80 )
81 } else {
82 normalized
83 }
84 }
85}
86
87impl Default for MistralProvider {
88 fn default() -> Self {
89 Self::new()
90 }
91}
92
93#[async_trait]
94impl Provider for MistralProvider {
95 async fn stream(
96 &self,
97 model: &Model,
98 context: &Context,
99 options: Option<StreamOptions>,
100 ) -> Result<Pin<Box<dyn Stream<Item = ProviderEvent> + Send>>, ProviderError> {
101 let options = options.unwrap_or_default();
102
103 let base_url = if model.base_url.is_empty() {
105 MISTRAL_API_URL.to_string()
106 } else {
107 model.base_url.trim_end_matches('/').to_string()
108 };
109 let url = format!("{}/chat/completions", base_url);
110
111 let api_key = options
113 .api_key
114 .as_ref()
115 .or(self.api_key.as_ref())
116 .ok_or(ProviderError::MissingApiKey)?;
117
118 let messages = build_messages(context)?;
120
121 let mut body = serde_json::json!({
123 "model": model.id,
124 "messages": messages,
125 "stream": true,
126 });
127
128 if let Some(temp) = options.temperature {
130 body["temperature"] = serde_json::json!(temp);
131 }
132
133 if let Some(max) = options.max_tokens {
134 body["max_tokens"] = serde_json::json!(max);
135 }
136
137 if !context.tools.is_empty() {
139 body["tools"] = build_tools(&context.tools)?;
140 }
141
142 let mut headers = reqwest::header::HeaderMap::new();
144 headers.insert(
145 reqwest::header::AUTHORIZATION,
146 format!("Bearer {}", api_key)
147 .parse()
148 .expect("valid bearer header"),
149 );
150 headers.insert(
151 reqwest::header::CONTENT_TYPE,
152 "application/json".parse().expect("valid header value"),
153 );
154
155 for (k, v) in &options.headers {
156 if let (Ok(name), Ok(value)) = (
157 k.parse::<reqwest::header::HeaderName>(),
158 v.parse::<reqwest::header::HeaderValue>(),
159 ) {
160 headers.insert(name, value);
161 }
162 }
163
164 let response = self
166 .client
167 .post(&url)
168 .headers(headers)
169 .json(&body)
170 .send()
171 .await
172 .map_err(ProviderError::RequestFailed)?;
173
174 if !response.status().is_success() {
175 let status = response.status();
176 let body: String = response.text().await.unwrap_or_default();
177 return Err(ProviderError::HttpError(status.as_u16(), body));
178 }
179
180 let provider_name = model.provider.clone();
182 let model_id = model.id.clone();
183
184 let stream = response.bytes_stream().flat_map(
185 move |chunk: Result<Bytes, reqwest::Error>| match chunk {
186 Ok(bytes) => {
187 let text = String::from_utf8_lossy(&bytes).to_string();
188 futures::stream::iter(parse_sse_events(&text, &provider_name, &model_id))
189 }
190 Err(e) => futures::stream::iter(vec![ProviderEvent::Error {
191 reason: StopReason::Error,
192 error: create_error_message(&e.to_string(), &provider_name, &model_id),
193 }]),
194 },
195 );
196
197 Ok(Box::pin(stream))
198 }
199
200 fn name(&self) -> &str {
201 "mistral"
202 }
203}
204
205fn build_messages(context: &Context) -> Result<Vec<JsonValue>, ProviderError> {
207 let mut messages = Vec::new();
208
209 if let Some(ref prompt) = context.system_prompt {
211 messages.push(serde_json::json!({
212 "role": "system",
213 "content": prompt,
214 }));
215 }
216
217 for msg in &context.messages {
219 match msg {
220 crate::Message::User(u) => {
221 let content: String = match &u.content {
222 crate::MessageContent::Text(s) => s.clone(),
223 crate::MessageContent::Blocks(blocks) => blocks_to_content(blocks)?.to_string(),
224 };
225 messages.push(serde_json::json!({
226 "role": "user",
227 "content": content,
228 }));
229 }
230 crate::Message::Assistant(a) => {
231 let content = blocks_to_content(&a.content)?.to_string();
232 messages.push(serde_json::json!({
233 "role": "assistant",
234 "content": content,
235 }));
236 }
237 crate::Message::ToolResult(t) => {
238 let content = blocks_to_content(&t.content)?.to_string();
239 let normalized_id = MistralProvider::normalize_tool_call_id(&t.tool_call_id);
241 messages.push(serde_json::json!({
242 "role": "tool",
243 "tool_call_id": normalized_id,
244 "tool_name": t.tool_name,
245 "content": content,
246 }));
247 }
248 }
249 }
250
251 Ok(messages)
252}
253
254fn blocks_to_content(blocks: &[ContentBlock]) -> Result<JsonValue, ProviderError> {
256 if blocks.len() == 1 {
257 if let Some(text) = blocks[0].as_text() {
258 return Ok(JsonValue::String(text.to_string()));
259 }
260 }
261
262 let items: Result<Vec<_>, _> = blocks
263 .iter()
264 .map(|block| match block {
265 ContentBlock::Text(t) => Ok(serde_json::json!({
266 "type": "text",
267 "text": t.text,
268 })),
269 ContentBlock::ToolCall(tc) => {
270 let normalized_id = MistralProvider::normalize_tool_call_id(&tc.id);
272 Ok(serde_json::json!({
273 "type": "function",
274 "id": normalized_id,
275 "function": {
276 "name": tc.name,
277 "arguments": tc.arguments.to_string(),
278 },
279 }))
280 }
281 ContentBlock::Thinking(th) => Ok(serde_json::json!({
282 "type": "thinking",
283 "thinking": th.thinking,
284 })),
285 ContentBlock::Image(img) => Ok(serde_json::json!({
286 "type": "image_url",
287 "image_url": {
288 "url": format!("data:{};base64,{}", img.mime_type, img.data),
289 },
290 })),
291 ContentBlock::Unknown(_) => Err(ProviderError::InvalidResponse(
292 "Unknown content block type".into(),
293 )),
294 })
295 .collect();
296
297 Ok(serde_json::json!(items?))
298}
299
300fn build_tools(tools: &[crate::Tool]) -> Result<JsonValue, ProviderError> {
302 let items: Vec<_> = tools
303 .iter()
304 .map(|tool| {
305 serde_json::json!({
306 "type": "function",
307 "function": {
308 "name": tool.name,
309 "description": tool.description,
310 "parameters": tool.parameters,
311 },
312 })
313 })
314 .collect();
315
316 Ok(serde_json::json!(items))
317}
318
319fn parse_sse_events(text: &str, provider: &str, model_id: &str) -> Vec<ProviderEvent> {
325 let mut events = Vec::new();
326 let mut partial_message = AssistantMessage::new(Api::OpenAiCompletions, provider, model_id);
327
328 let estimated_events = text.split('\n').filter(|l| l.starts_with("data: ")).count();
330 events.reserve(estimated_events);
331
332 let mut accumulated_usage = Usage::default();
333
334 for line in text.split('\n') {
335 let line = line.trim_end_matches('\r');
336 if line.is_empty() {
337 continue;
338 }
339
340 if !line.starts_with("data: ") {
341 continue;
342 }
343
344 let data = &line[6..];
345
346 if data == "[DONE]" {
347 break;
348 }
349
350 if data.is_empty() {
351 continue;
352 }
353
354 let chunk = match serde_json::from_str::<SSEChunk>(data) {
355 Ok(c) => c,
356 Err(_) => continue,
357 };
358
359 for choice in &chunk.choices {
360 if let Some(chunk_usage) = &chunk.usage {
362 accumulated_usage.input = chunk_usage.prompt_tokens;
363 accumulated_usage.output = chunk_usage.completion_tokens;
364 accumulated_usage.cache_read = chunk_usage
365 .prompt_tokens_details
366 .as_ref()
367 .map(|d| d.cached_tokens)
368 .unwrap_or(0);
369 accumulated_usage.total_tokens = chunk_usage.total_tokens;
370 }
371
372 if let Some(delta) = &choice.delta {
373 if let Some(content) = &delta.content {
374 let last_text_idx = partial_message
377 .content
378 .iter()
379 .rposition(|b| matches!(b, ContentBlock::Text(_)));
380 if let Some(idx) = last_text_idx {
381 if let ContentBlock::Text(t) = &mut partial_message.content[idx] {
382 t.text.push_str(content);
383 }
384 } else {
385 partial_message
386 .content
387 .push(ContentBlock::Text(crate::TextContent::new(content.clone())));
388 }
389 events.push(ProviderEvent::TextDelta {
390 content_index: choice.index,
391 delta: content.clone(),
392 partial: Arc::new(partial_message.clone()),
393 });
394 }
395
396 if let Some(tool_calls) = &delta.tool_calls {
397 for tc in tool_calls {
398 if let Some(ref id) = tc.id {
400 let _normalized_id = MistralProvider::normalize_tool_call_id(id);
401 if let Some(func) = &tc.function {
402 events.push(ProviderEvent::ToolCallDelta {
403 content_index: choice.index,
404 delta: func.arguments.clone().unwrap_or_default(),
405 partial: Arc::new(partial_message.clone()),
406 });
407 }
408 } else if let Some(func) = &tc.function {
409 events.push(ProviderEvent::ToolCallDelta {
411 content_index: choice.index,
412 delta: func.arguments.clone().unwrap_or_default(),
413 partial: Arc::new(partial_message.clone()),
414 });
415 }
416 }
417 }
418 }
419
420 if choice.finish_reason.is_some() {
421 let reason = match choice.finish_reason.as_deref() {
422 Some("stop") => StopReason::Stop,
423 Some("length") => StopReason::Length,
424 Some("tool_calls") => StopReason::ToolUse,
425 _ => StopReason::Stop,
426 };
427
428 let mut done_msg = partial_message.clone();
429 done_msg.usage = accumulated_usage.clone();
430 events.push(ProviderEvent::Done {
431 reason,
432 message: done_msg,
433 });
434 }
435 }
436 }
437
438 events
439}
440
441fn create_error_message(msg: &str, provider: &str, model_id: &str) -> AssistantMessage {
443 let mut message = AssistantMessage::new(Api::OpenAiCompletions, provider, model_id);
444 message.stop_reason = StopReason::Error;
445 message.error_message = Some(msg.to_string());
446 message
447}
448
449#[derive(Debug, Deserialize)]
451struct SSEChunk {
453 _id: Option<String>,
454 #[serde(rename = "model")]
455 _model: Option<String>,
456 choices: Vec<Choice>,
457 usage: Option<UsageInfo>,
458}
459
/// One entry of a chunk's `choices` array.
#[derive(Debug, Deserialize)]
struct Choice {
    /// Index of the choice; the parser forwards it as the `content_index` of emitted events.
    index: usize,
    /// Incremental content for this choice, if any.
    delta: Option<Delta>,
    /// Present on the chunk that ends the choice; triggers the `Done` event in the parser.
    finish_reason: Option<String>,
}
466
/// Incremental payload of a choice.
#[derive(Debug, Deserialize)]
struct Delta {
    /// Text fragment to append to the message.
    content: Option<String>,
    /// Tool-call fragments carried by this delta.
    tool_calls: Option<Vec<ToolCallDelta>>,
}
472
473#[derive(Debug, Deserialize)]
474struct ToolCallDelta {
476 _index: Option<usize>,
477 id: Option<String>,
478 #[serde(rename = "type")]
479 _type_: Option<String>,
480 function: Option<FunctionDelta>,
481}
482
483#[derive(Debug, Deserialize)]
484struct FunctionDelta {
486 _name: Option<String>,
487 arguments: Option<String>,
488}
489
490#[derive(Debug, Deserialize, Clone)]
491struct UsageInfo {
492 prompt_tokens: usize,
493 completion_tokens: usize,
494 total_tokens: usize,
495 #[serde(rename = "prompt_tokens_details")]
496 prompt_tokens_details: Option<PromptTokensDetails>,
497}
498
499#[derive(Debug, Deserialize, Clone)]
500struct PromptTokensDetails {
501 #[serde(rename = "cached_tokens")]
502 cached_tokens: usize,
503}
504
/// Unit tests for id normalisation, provider construction, SSE parsing and request building.
#[cfg(test)]
mod tests {
    use super::*;

    /// Ids shorter than the target length (including the empty id) are returned unchanged.
    #[test]
    fn test_normalize_tool_call_id_short() {
        assert_eq!(MistralProvider::normalize_tool_call_id("short"), "short");
        assert_eq!(MistralProvider::normalize_tool_call_id("abc"), "abc");
        assert_eq!(MistralProvider::normalize_tool_call_id(""), "");
    }

    /// Ids that already have exactly the target length are returned unchanged.
    #[test]
    fn test_normalize_tool_call_id_exact_length() {
        assert_eq!(
            MistralProvider::normalize_tool_call_id("123456789"),
            "123456789"
        );
        assert_eq!(
            MistralProvider::normalize_tool_call_id("abcdefghi"),
            "abcdefghi"
        );
    }

    /// Over-long ids are reduced to nine alphanumeric characters.
    #[test]
    fn test_normalize_tool_call_id_long() {
        let long_uuid = "call_abc123def456ghi789";
        let result = MistralProvider::normalize_tool_call_id(long_uuid);
        assert_eq!(result.len(), 9);
        assert!(result.chars().all(|c| c.is_alphanumeric()));
    }

    /// Punctuation is dropped from over-long ids before truncation.
    #[test]
    fn test_normalize_tool_call_id_with_special_chars() {
        let id_with_special = "call-abc-123";
        let result = MistralProvider::normalize_tool_call_id(id_with_special);
        assert!(result.chars().all(|c| c.is_alphanumeric()));
        assert_eq!(result.len(), 9);
    }

    /// Short ids keep their punctuation (pass-through); over-long ones come out as nine
    /// alphanumerics.
    #[test]
    fn test_normalize_tool_call_id_padding() {
        // Five bytes long, so it takes the pass-through path and is not filtered.
        let short_id = "a-b-c";
        let result = MistralProvider::normalize_tool_call_id(short_id);
        assert_eq!(result, "a-b-c");

        let long_with_special = "call-abc-def-ghi-jkl";
        let result = MistralProvider::normalize_tool_call_id(long_with_special);
        assert_eq!(result.len(), 9);
        assert!(result.chars().all(|c| c.is_alphanumeric()));
    }

    /// `new()` yields a provider named "mistral".
    #[test]
    fn test_provider_name() {
        let provider = MistralProvider::new();
        assert_eq!(provider.name(), "mistral");
    }

    /// `default()` yields a provider named "mistral".
    #[test]
    fn test_provider_default() {
        let provider = MistralProvider::default();
        assert_eq!(provider.name(), "mistral");
    }

    /// `with_api_key()` yields a provider named "mistral".
    #[test]
    fn test_provider_with_api_key() {
        let provider = MistralProvider::with_api_key("test-key-123");
        assert_eq!(provider.name(), "mistral");
    }

    /// A chunk with text content produces a `TextDelta` carrying that text.
    #[test]
    fn test_parse_sse_text_delta() {
        let sse_data = r#"data: {"id":"chatcmpl-123","choices":[{"index":0,"delta":{"content":"Hello"},"finish_reason":null}]}"#;
        let events = parse_sse_events(sse_data, "mistral", "mistral-small");

        assert!(!events.is_empty());
        match &events[0] {
            ProviderEvent::TextDelta { delta, .. } => {
                assert_eq!(delta, "Hello");
            }
            _ => panic!("Expected TextDelta event"),
        }
    }

    /// A chunk with `finish_reason` and usage produces a `Done` event with those numbers.
    #[test]
    fn test_parse_sse_done_event() {
        let sse_data = r#"data: {"id":"chatcmpl-123","choices":[{"index":0,"delta":{},"finish_reason":"stop"}],"usage":{"prompt_tokens":10,"completion_tokens":5,"total_tokens":15}}"#;
        let events = parse_sse_events(sse_data, "mistral", "mistral-small");

        assert!(!events.is_empty());
        let done_event = events
            .iter()
            .find(|e| matches!(e, ProviderEvent::Done { .. }));
        assert!(done_event.is_some());

        if let Some(ProviderEvent::Done { reason, message }) = done_event {
            assert_eq!(*reason, StopReason::Stop);
            assert_eq!(message.usage.input, 10);
            assert_eq!(message.usage.output, 5);
        }
    }

    /// The `[DONE]` marker alone produces no events.
    #[test]
    fn test_parse_sse_done_marker() {
        let sse_data = "data: [DONE]\n";
        let events = parse_sse_events(sse_data, "mistral", "mistral-small");
        assert!(events.is_empty());
    }

    /// A tool-call chunk that also finishes the choice yields a `ToolCallDelta` plus `Done`.
    #[test]
    fn test_parse_sse_tool_call() {
        let sse_data = r#"data: {"id":"chatcmpl-123","choices":[{"index":0,"delta":{"tool_calls":[{"id":"call_abc","function":{"name":"get_weather","arguments":"{\"city\":\"NYC\"}"}}]},"finish_reason":"tool_calls"}]}"#;
        let events = parse_sse_events(sse_data, "mistral", "mistral-small");

        assert!(events.len() >= 2);
        let has_tool_call = events
            .iter()
            .any(|e| matches!(e, ProviderEvent::ToolCallDelta { .. }));
        assert!(has_tool_call);
    }

    /// A single tool is rendered as one `"function"` entry carrying its name.
    #[test]
    fn test_build_tools() {
        let tool = crate::Tool::new(
            "get_weather",
            "Get weather for a location",
            serde_json::json!({
                "type": "object",
                "properties": {
                    "city": { "type": "string", "description": "City name" }
                },
                "required": ["city"]
            }),
        );

        let result = build_tools(&[tool]).unwrap();
        let tools_array = result.as_array().unwrap();
        assert_eq!(tools_array.len(), 1);

        let first_tool = &tools_array[0];
        assert_eq!(first_tool["type"], "function");
        assert_eq!(first_tool["function"]["name"], "get_weather");
    }

    /// The tool-call id of a tool result is normalised to nine characters in the request.
    #[test]
    fn test_build_messages_with_tool_result() {
        use crate::{ContentBlock, Message, TextContent, ToolResultMessage};

        let mut context = Context::new();
        context.add_message(Message::ToolResult(ToolResultMessage::new(
            "call_abc123456789",
            "get_weather",
            vec![ContentBlock::Text(TextContent::new("Sunny, 72°F"))],
        )));

        let messages = build_messages(&context).unwrap();
        assert_eq!(messages.len(), 1);

        let msg = &messages[0];
        let tool_call_id = msg["tool_call_id"].as_str().unwrap();
        assert_eq!(tool_call_id.len(), 9);
    }

    /// A lone text block collapses to a plain JSON string.
    #[test]
    fn test_blocks_to_content_single_text() {
        use crate::TextContent;
        let blocks = vec![ContentBlock::Text(TextContent::new("Hello world"))];
        let result = blocks_to_content(&blocks).unwrap();
        assert_eq!(result, serde_json::json!("Hello world"));
    }

    /// Several blocks become an array with one entry per block.
    #[test]
    fn test_blocks_to_content_multiple() {
        use crate::TextContent;
        let blocks = vec![
            ContentBlock::Text(TextContent::new("Hello")),
            ContentBlock::Text(TextContent::new(" world")),
        ];
        let result = blocks_to_content(&blocks).unwrap();
        let arr = result.as_array().unwrap();
        assert_eq!(arr.len(), 2);
    }
}