Skip to main content

oxi_ai/providers/
mistral.rs

1//! Mistral AI provider implementation
2//!
3//! Mistral AI uses an OpenAI-compatible API with some minor differences:
4//! - Tool call IDs are 9 characters (vs longer UUIDs from OpenAI)
//! - The API key is supplied per request via `StreamOptions::api_key`; this module does not read `MISTRAL_API_KEY` itself
6
7use async_trait::async_trait;
8use bytes::Bytes;
9use futures::{Stream, StreamExt};
10use reqwest::Client;
11use serde::Deserialize;
12use serde_json::Value as JsonValue;
13use std::pin::Pin;
14use std::sync::Arc;
15
16use crate::{
17    Api, AssistantMessage, ContentBlock, Context, Model, Provider, ProviderError, ProviderEvent,
18    StopReason, StreamOptions, Usage,
19};
20
21use super::shared_client;
22
/// Mistral AI provider.
///
/// Talks to Mistral's chat-completions endpoint and streams the response
/// back as `ProviderEvent`s.
#[derive(Clone)]
pub struct MistralProvider {
    /// HTTP client obtained from `shared_client()`; used for every request.
    client: &'static Client,
    /// Fallback API key, consulted only when `StreamOptions::api_key` is unset.
    api_key: Option<String>,
}
29
/// Default Mistral API endpoint, used when the model's `base_url` is empty.
const MISTRAL_API_URL: &str = "https://api.mistral.ai/v1";
32
/// Length that over-long tool call IDs are normalized to before being sent
/// to Mistral (Mistral uses 9-character tool call IDs).
const MISTRAL_TOOL_CALL_ID_LENGTH: usize = 9;
35
36impl MistralProvider {
37    /// Create a new Mistral provider without an API key.
38    ///
39    /// API keys are resolved at request time via auth.json or StreamOptions.
40    pub fn new() -> Self {
41        Self {
42            client: shared_client(),
43            api_key: None,
44        }
45    }
46
47    /// Create a Mistral provider with a specific API key (test-only)
48    #[cfg(test)]
49    pub fn with_api_key(api_key: impl Into<String>) -> Self {
50        Self {
51            client: shared_client(),
52            api_key: Some(api_key.into()),
53        }
54    }
55
56    /// Normalize tool call ID to Mistral's expected format (9 characters)
57    ///
58    /// Mistral's API expects tool call IDs to be exactly 9 characters.
59    /// When processing tool results or assistant messages with tool calls,
60    /// we need to normalize longer IDs to fit this constraint.
61    fn normalize_tool_call_id(id: &str) -> String {
62        // If already 9 chars or shorter, return as-is
63        if id.len() <= MISTRAL_TOOL_CALL_ID_LENGTH {
64            return id.to_string();
65        }
66
67        // Take first 9 characters and ensure they're alphanumeric
68        let normalized: String = id
69            .chars()
70            .filter(|c| c.is_alphanumeric())
71            .take(MISTRAL_TOOL_CALL_ID_LENGTH)
72            .collect();
73
74        // If we don't have enough chars, pad with zeros
75        if normalized.len() < MISTRAL_TOOL_CALL_ID_LENGTH {
76            format!(
77                "{}{}",
78                normalized,
79                "0".repeat(MISTRAL_TOOL_CALL_ID_LENGTH - normalized.len())
80            )
81        } else {
82            normalized
83        }
84    }
85}
86
87impl Default for MistralProvider {
88    fn default() -> Self {
89        Self::new()
90    }
91}
92
93#[async_trait]
94impl Provider for MistralProvider {
95    async fn stream(
96        &self,
97        model: &Model,
98        context: &Context,
99        options: Option<StreamOptions>,
100    ) -> Result<Pin<Box<dyn Stream<Item = ProviderEvent> + Send>>, ProviderError> {
101        let options = options.unwrap_or_default();
102
103        // Build the request URL
104        let base_url = if model.base_url.is_empty() {
105            MISTRAL_API_URL.to_string()
106        } else {
107            model.base_url.trim_end_matches('/').to_string()
108        };
109        let url = format!("{}/chat/completions", base_url);
110
111        // Get API key
112        let api_key = options
113            .api_key
114            .as_ref()
115            .or(self.api_key.as_ref())
116            .ok_or(ProviderError::MissingApiKey)?;
117
118        // Build messages with normalized tool call IDs
119        let messages = build_messages(context)?;
120
121        // Build request body
122        let mut body = serde_json::json!({
123            "model": model.id,
124            "messages": messages,
125            "stream": true,
126        });
127
128        // Add optional parameters
129        if let Some(temp) = options.temperature {
130            body["temperature"] = serde_json::json!(temp);
131        }
132
133        if let Some(max) = options.max_tokens {
134            body["max_tokens"] = serde_json::json!(max);
135        }
136
137        // Add tools if present
138        if !context.tools.is_empty() {
139            body["tools"] = build_tools(&context.tools)?;
140        }
141
142        // Build headers
143        let mut headers = reqwest::header::HeaderMap::new();
144        headers.insert(
145            reqwest::header::AUTHORIZATION,
146            format!("Bearer {}", api_key)
147                .parse()
148                .expect("valid bearer header"),
149        );
150        headers.insert(
151            reqwest::header::CONTENT_TYPE,
152            "application/json".parse().expect("valid header value"),
153        );
154
155        for (k, v) in &options.headers {
156            if let (Ok(name), Ok(value)) = (
157                k.parse::<reqwest::header::HeaderName>(),
158                v.parse::<reqwest::header::HeaderValue>(),
159            ) {
160                headers.insert(name, value);
161            }
162        }
163
164        // Make request
165        let response = self
166            .client
167            .post(&url)
168            .headers(headers)
169            .json(&body)
170            .send()
171            .await
172            .map_err(ProviderError::RequestFailed)?;
173
174        if !response.status().is_success() {
175            let status = response.status();
176            let body: String = response.text().await.unwrap_or_default();
177            return Err(ProviderError::HttpError(status.as_u16(), body));
178        }
179
180        // Create event stream
181        let provider_name = model.provider.clone();
182        let model_id = model.id.clone();
183
184        let stream = response.bytes_stream().flat_map(
185            move |chunk: Result<Bytes, reqwest::Error>| match chunk {
186                Ok(bytes) => {
187                    let text = String::from_utf8_lossy(&bytes).to_string();
188                    futures::stream::iter(parse_sse_events(&text, &provider_name, &model_id))
189                }
190                Err(e) => futures::stream::iter(vec![ProviderEvent::Error {
191                    reason: StopReason::Error,
192                    error: create_error_message(&e.to_string(), &provider_name, &model_id),
193                }]),
194            },
195        );
196
197        Ok(Box::pin(stream))
198    }
199
200    fn name(&self) -> &str {
201        "mistral"
202    }
203}
204
205/// Build messages array from context
206fn build_messages(context: &Context) -> Result<Vec<JsonValue>, ProviderError> {
207    let mut messages = Vec::new();
208
209    // System prompt
210    if let Some(ref prompt) = context.system_prompt {
211        messages.push(serde_json::json!({
212            "role": "system",
213            "content": prompt,
214        }));
215    }
216
217    // Conversation messages
218    for msg in &context.messages {
219        match msg {
220            crate::Message::User(u) => {
221                let content: String = match &u.content {
222                    crate::MessageContent::Text(s) => s.clone(),
223                    crate::MessageContent::Blocks(blocks) => blocks_to_content(blocks)?.to_string(),
224                };
225                messages.push(serde_json::json!({
226                    "role": "user",
227                    "content": content,
228                }));
229            }
230            crate::Message::Assistant(a) => {
231                let content = blocks_to_content(&a.content)?.to_string();
232                messages.push(serde_json::json!({
233                    "role": "assistant",
234                    "content": content,
235                }));
236            }
237            crate::Message::ToolResult(t) => {
238                let content = blocks_to_content(&t.content)?.to_string();
239                // Normalize tool call ID for Mistral
240                let normalized_id = MistralProvider::normalize_tool_call_id(&t.tool_call_id);
241                messages.push(serde_json::json!({
242                    "role": "tool",
243                    "tool_call_id": normalized_id,
244                    "tool_name": t.tool_name,
245                    "content": content,
246                }));
247            }
248        }
249    }
250
251    Ok(messages)
252}
253
254/// Convert content blocks to a string representation
255fn blocks_to_content(blocks: &[ContentBlock]) -> Result<JsonValue, ProviderError> {
256    if blocks.len() == 1 {
257        if let Some(text) = blocks[0].as_text() {
258            return Ok(JsonValue::String(text.to_string()));
259        }
260    }
261
262    let items: Result<Vec<_>, _> = blocks
263        .iter()
264        .map(|block| match block {
265            ContentBlock::Text(t) => Ok(serde_json::json!({
266                "type": "text",
267                "text": t.text,
268            })),
269            ContentBlock::ToolCall(tc) => {
270                // Normalize tool call ID
271                let normalized_id = MistralProvider::normalize_tool_call_id(&tc.id);
272                Ok(serde_json::json!({
273                    "type": "function",
274                    "id": normalized_id,
275                    "function": {
276                        "name": tc.name,
277                        "arguments": tc.arguments.to_string(),
278                    },
279                }))
280            }
281            ContentBlock::Thinking(th) => Ok(serde_json::json!({
282                "type": "thinking",
283                "thinking": th.thinking,
284            })),
285            ContentBlock::Image(img) => Ok(serde_json::json!({
286                "type": "image_url",
287                "image_url": {
288                    "url": format!("data:{};base64,{}", img.mime_type, img.data),
289                },
290            })),
291            ContentBlock::Unknown(_) => Err(ProviderError::InvalidResponse(
292                "Unknown content block type".into(),
293            )),
294        })
295        .collect();
296
297    Ok(serde_json::json!(items?))
298}
299
300/// Build tools array
301fn build_tools(tools: &[crate::Tool]) -> Result<JsonValue, ProviderError> {
302    let items: Vec<_> = tools
303        .iter()
304        .map(|tool| {
305            serde_json::json!({
306                "type": "function",
307                "function": {
308                    "name": tool.name,
309                    "description": tool.description,
310                    "parameters": tool.parameters,
311                },
312            })
313        })
314        .collect();
315
316    Ok(serde_json::json!(items))
317}
318
319/// Parse SSE event stream from a byte buffer.
320///
321/// Mistral-specific handling:
322/// - Normalizes tool call IDs to 9 characters
323/// - Handles Mistral's streaming format (OpenAI-compatible)
324fn parse_sse_events(text: &str, provider: &str, model_id: &str) -> Vec<ProviderEvent> {
325    let mut events = Vec::new();
326    let mut partial_message = AssistantMessage::new(Api::OpenAiCompletions, provider, model_id);
327
328    // Pre-estimate capacity
329    let estimated_events = text.split('\n').filter(|l| l.starts_with("data: ")).count();
330    events.reserve(estimated_events);
331
332    let mut accumulated_usage = Usage::default();
333
334    for line in text.split('\n') {
335        let line = line.trim_end_matches('\r');
336        if line.is_empty() {
337            continue;
338        }
339
340        if !line.starts_with("data: ") {
341            continue;
342        }
343
344        let data = &line[6..];
345
346        if data == "[DONE]" {
347            break;
348        }
349
350        if data.is_empty() {
351            continue;
352        }
353
354        let chunk = match serde_json::from_str::<SSEChunk>(data) {
355            Ok(c) => c,
356            Err(_) => continue,
357        };
358
359        for choice in &chunk.choices {
360            // Accumulate usage first (before emitting Done)
361            if let Some(chunk_usage) = &chunk.usage {
362                accumulated_usage.input = chunk_usage.prompt_tokens;
363                accumulated_usage.output = chunk_usage.completion_tokens;
364                accumulated_usage.cache_read = chunk_usage
365                    .prompt_tokens_details
366                    .as_ref()
367                    .map(|d| d.cached_tokens)
368                    .unwrap_or(0);
369                accumulated_usage.total_tokens = chunk_usage.total_tokens;
370            }
371
372            if let Some(delta) = &choice.delta {
373                if let Some(content) = &delta.content {
374                    // pi-mono: accumulate into partial_message so the TUI can
375                    // diff against its snapshot tracker.
376                    let last_text_idx = partial_message
377                        .content
378                        .iter()
379                        .rposition(|b| matches!(b, ContentBlock::Text(_)));
380                    if let Some(idx) = last_text_idx {
381                        if let ContentBlock::Text(t) = &mut partial_message.content[idx] {
382                            t.text.push_str(content);
383                        }
384                    } else {
385                        partial_message
386                            .content
387                            .push(ContentBlock::Text(crate::TextContent::new(content.clone())));
388                    }
389                    events.push(ProviderEvent::TextDelta {
390                        content_index: choice.index,
391                        delta: content.clone(),
392                        partial: Arc::new(partial_message.clone()),
393                    });
394                }
395
396                if let Some(tool_calls) = &delta.tool_calls {
397                    for tc in tool_calls {
398                        // Normalize tool call ID if present
399                        if let Some(ref id) = tc.id {
400                            let _normalized_id = MistralProvider::normalize_tool_call_id(id);
401                            if let Some(func) = &tc.function {
402                                events.push(ProviderEvent::ToolCallDelta {
403                                    content_index: choice.index,
404                                    delta: func.arguments.clone().unwrap_or_default(),
405                                    partial: Arc::new(partial_message.clone()),
406                                });
407                            }
408                        } else if let Some(func) = &tc.function {
409                            // Send even without ID
410                            events.push(ProviderEvent::ToolCallDelta {
411                                content_index: choice.index,
412                                delta: func.arguments.clone().unwrap_or_default(),
413                                partial: Arc::new(partial_message.clone()),
414                            });
415                        }
416                    }
417                }
418            }
419
420            if choice.finish_reason.is_some() {
421                let reason = match choice.finish_reason.as_deref() {
422                    Some("stop") => StopReason::Stop,
423                    Some("length") => StopReason::Length,
424                    Some("tool_calls") => StopReason::ToolUse,
425                    _ => StopReason::Stop,
426                };
427
428                let mut done_msg = partial_message.clone();
429                done_msg.usage = accumulated_usage.clone();
430                events.push(ProviderEvent::Done {
431                    reason,
432                    message: done_msg,
433                });
434            }
435        }
436    }
437
438    events
439}
440
441/// Create error assistant message
442fn create_error_message(msg: &str, provider: &str, model_id: &str) -> AssistantMessage {
443    let mut message = AssistantMessage::new(Api::OpenAiCompletions, provider, model_id);
444    message.stop_reason = StopReason::Error;
445    message.error_message = Some(msg.to_string());
446    message
447}
448
449// SSE chunk structures (OpenAI-compatible)
450#[derive(Debug, Deserialize)]
451// serde deserialization structs
452struct SSEChunk {
453    _id: Option<String>,
454    #[serde(rename = "model")]
455    _model: Option<String>,
456    choices: Vec<Choice>,
457    usage: Option<UsageInfo>,
458}
459
/// One entry of a chunk's `choices` array.
#[derive(Debug, Deserialize)]
struct Choice {
    /// Index of the choice; `parse_sse_events` reuses it as the
    /// `content_index` of emitted events.
    index: usize,
    /// Incremental payload for this choice, if any.
    delta: Option<Delta>,
    /// Set on the chunk that ends the choice; its presence triggers a
    /// `ProviderEvent::Done`.
    finish_reason: Option<String>,
}
466
/// Incremental update inside a choice.
#[derive(Debug, Deserialize)]
struct Delta {
    /// Text fragment to append to the assistant's reply.
    content: Option<String>,
    /// Tool call fragments carried by this delta.
    tool_calls: Option<Vec<ToolCallDelta>>,
}
472
473#[derive(Debug, Deserialize)]
474// serde deserialization structs
475struct ToolCallDelta {
476    _index: Option<usize>,
477    id: Option<String>,
478    #[serde(rename = "type")]
479    _type_: Option<String>,
480    function: Option<FunctionDelta>,
481}
482
483#[derive(Debug, Deserialize)]
484// serde deserialization structs
485struct FunctionDelta {
486    _name: Option<String>,
487    arguments: Option<String>,
488}
489
490#[derive(Debug, Deserialize, Clone)]
491struct UsageInfo {
492    prompt_tokens: usize,
493    completion_tokens: usize,
494    total_tokens: usize,
495    #[serde(rename = "prompt_tokens_details")]
496    prompt_tokens_details: Option<PromptTokensDetails>,
497}
498
499#[derive(Debug, Deserialize, Clone)]
500struct PromptTokensDetails {
501    #[serde(rename = "cached_tokens")]
502    cached_tokens: usize,
503}
504
505// ============================================================================
506// Tests
507// ============================================================================
508
// Unit tests for the pure helpers of this module: tool call ID
// normalization, SSE parsing and request-body construction. No test
// performs network I/O.
#[cfg(test)]
mod tests {
    use super::*;

    /// IDs shorter than the limit are returned untouched.
    #[test]
    fn test_normalize_tool_call_id_short() {
        // Already short IDs should pass through unchanged
        assert_eq!(MistralProvider::normalize_tool_call_id("short"), "short");
        assert_eq!(MistralProvider::normalize_tool_call_id("abc"), "abc");
        assert_eq!(MistralProvider::normalize_tool_call_id(""), "");
    }

    /// IDs of exactly the limit are returned untouched.
    #[test]
    fn test_normalize_tool_call_id_exact_length() {
        // Exactly 9 characters should pass through
        assert_eq!(
            MistralProvider::normalize_tool_call_id("123456789"),
            "123456789"
        );
        assert_eq!(
            MistralProvider::normalize_tool_call_id("abcdefghi"),
            "abcdefghi"
        );
    }

    /// Over-long IDs are cut down to 9 alphanumeric characters.
    #[test]
    fn test_normalize_tool_call_id_long() {
        // Longer IDs should be truncated to 9 chars, keeping only alphanumeric
        let long_uuid = "call_abc123def456ghi789";
        let result = MistralProvider::normalize_tool_call_id(long_uuid);
        assert_eq!(result.len(), 9);
        assert!(result.chars().all(|c| c.is_alphanumeric()));
    }

    /// Separators in over-long IDs are stripped before truncation.
    #[test]
    fn test_normalize_tool_call_id_with_special_chars() {
        // IDs with special characters should have them removed
        let id_with_special = "call-abc-123";
        let result = MistralProvider::normalize_tool_call_id(id_with_special);
        assert!(result.chars().all(|c| c.is_alphanumeric()));
        assert_eq!(result.len(), 9);
    }

    /// Short IDs keep their separators; long ones are normalized.
    #[test]
    fn test_normalize_tool_call_id_padding() {
        // IDs shorter than 9 chars are preserved as-is
        // (only longer IDs get truncated/padded)
        let short_id = "a-b-c";
        let result = MistralProvider::normalize_tool_call_id(short_id);
        assert_eq!(result, "a-b-c");

        // Long IDs with special chars get normalized
        let long_with_special = "call-abc-def-ghi-jkl";
        let result = MistralProvider::normalize_tool_call_id(long_with_special);
        assert_eq!(result.len(), 9);
        assert!(result.chars().all(|c| c.is_alphanumeric()));
    }

    /// `name()` reports the fixed provider identifier.
    #[test]
    fn test_provider_name() {
        let provider = MistralProvider::new();
        assert_eq!(provider.name(), "mistral");
    }

    /// `Default` builds a usable provider.
    #[test]
    fn test_provider_default() {
        let provider = MistralProvider::default();
        assert_eq!(provider.name(), "mistral");
    }

    /// The test-only constructor builds a usable provider.
    #[test]
    fn test_provider_with_api_key() {
        let provider = MistralProvider::with_api_key("test-key-123");
        assert_eq!(provider.name(), "mistral");
    }

    /// A content delta is surfaced as a `TextDelta` event.
    #[test]
    fn test_parse_sse_text_delta() {
        let sse_data = r#"data: {"id":"chatcmpl-123","choices":[{"index":0,"delta":{"content":"Hello"},"finish_reason":null}]}"#;
        let events = parse_sse_events(sse_data, "mistral", "mistral-small");

        assert!(!events.is_empty());
        match &events[0] {
            ProviderEvent::TextDelta { delta, .. } => {
                assert_eq!(delta, "Hello");
            }
            _ => panic!("Expected TextDelta event"),
        }
    }

    /// A finish reason yields a `Done` event carrying the chunk's usage.
    #[test]
    fn test_parse_sse_done_event() {
        let sse_data = r#"data: {"id":"chatcmpl-123","choices":[{"index":0,"delta":{},"finish_reason":"stop"}],"usage":{"prompt_tokens":10,"completion_tokens":5,"total_tokens":15}}"#;
        let events = parse_sse_events(sse_data, "mistral", "mistral-small");

        assert!(!events.is_empty());
        // Find the Done event
        let done_event = events
            .iter()
            .find(|e| matches!(e, ProviderEvent::Done { .. }));
        assert!(done_event.is_some());

        if let Some(ProviderEvent::Done { reason, message }) = done_event {
            assert_eq!(*reason, StopReason::Stop);
            assert_eq!(message.usage.input, 10);
            assert_eq!(message.usage.output, 5);
        }
    }

    /// The `[DONE]` marker alone produces no events.
    #[test]
    fn test_parse_sse_done_marker() {
        // Test that [DONE] is handled correctly
        let sse_data = "data: [DONE]\n";
        let events = parse_sse_events(sse_data, "mistral", "mistral-small");
        assert!(events.is_empty());
    }

    /// A tool call delta with a finish reason yields both event kinds.
    #[test]
    fn test_parse_sse_tool_call() {
        let sse_data = r#"data: {"id":"chatcmpl-123","choices":[{"index":0,"delta":{"tool_calls":[{"id":"call_abc","function":{"name":"get_weather","arguments":"{\"city\":\"NYC\"}"}}]},"finish_reason":"tool_calls"}]}"#;
        let events = parse_sse_events(sse_data, "mistral", "mistral-small");

        // Should have tool call delta and done event
        assert!(events.len() >= 2);
        let has_tool_call = events
            .iter()
            .any(|e| matches!(e, ProviderEvent::ToolCallDelta { .. }));
        assert!(has_tool_call);
    }

    /// A tool definition is wrapped as a `function` tool object.
    #[test]
    fn test_build_tools() {
        let tool = crate::Tool::new(
            "get_weather",
            "Get weather for a location",
            serde_json::json!({
                "type": "object",
                "properties": {
                    "city": { "type": "string", "description": "City name" }
                },
                "required": ["city"]
            }),
        );

        let result = build_tools(&[tool]).unwrap();
        let tools_array = result.as_array().unwrap();
        assert_eq!(tools_array.len(), 1);

        let first_tool = &tools_array[0];
        assert_eq!(first_tool["type"], "function");
        assert_eq!(first_tool["function"]["name"], "get_weather");
    }

    /// Tool result messages carry a normalized tool call ID.
    #[test]
    fn test_build_messages_with_tool_result() {
        use crate::{ContentBlock, Message, TextContent, ToolResultMessage};

        let mut context = Context::new();
        context.add_message(Message::ToolResult(ToolResultMessage::new(
            "call_abc123456789",
            "get_weather",
            vec![ContentBlock::Text(TextContent::new("Sunny, 72°F"))],
        )));

        let messages = build_messages(&context).unwrap();
        assert_eq!(messages.len(), 1);

        // Verify tool call ID was normalized to 9 chars
        let msg = &messages[0];
        let tool_call_id = msg["tool_call_id"].as_str().unwrap();
        assert_eq!(tool_call_id.len(), 9);
    }

    /// A single text block collapses to a plain JSON string.
    #[test]
    fn test_blocks_to_content_single_text() {
        use crate::TextContent;
        let blocks = vec![ContentBlock::Text(TextContent::new("Hello world"))];
        let result = blocks_to_content(&blocks).unwrap();
        assert_eq!(result, serde_json::json!("Hello world"));
    }

    /// Several blocks become a JSON array with one entry per block.
    #[test]
    fn test_blocks_to_content_multiple() {
        use crate::TextContent;
        let blocks = vec![
            ContentBlock::Text(TextContent::new("Hello")),
            ContentBlock::Text(TextContent::new(" world")),
        ];
        let result = blocks_to_content(&blocks).unwrap();
        let arr = result.as_array().unwrap();
        assert_eq!(arr.len(), 2);
    }
}
701}