ai_lib/provider/
mistral.rs

1use crate::api::{
2    ChatCompletionChunk, ChatProvider, ChoiceDelta, MessageDelta, ModelInfo, ModelPermission,
3};
4use crate::metrics::{Metrics, NoopMetrics};
5use crate::transport::{DynHttpTransportRef, HttpTransport};
6use crate::types::{
7    AiLibError, ChatCompletionRequest, ChatCompletionResponse, Choice, Message, Role, Usage,
8    UsageStatus,
9};
10use futures::stream::Stream;
11use futures::StreamExt;
12use std::collections::HashMap;
13use std::sync::Arc;
14#[cfg(feature = "unified_transport")]
15use std::time::Duration;
16use tokio::sync::mpsc;
17use tokio_stream::wrappers::UnboundedReceiverStream;
18
19/// Mistral adapter (conservative HTTP implementation).
20///
21/// Note: Mistral provides an official Rust SDK (<https://github.com/ivangabriele/mistralai-client-rs>).
22/// We keep this implementation HTTP-based for now and can swap to the SDK later.
23pub struct MistralAdapter {
24    #[allow(dead_code)] // Kept for backward compatibility, now using direct reqwest
25    transport: DynHttpTransportRef,
26    api_key: Option<String>,
27    base_url: String,
28    metrics: Arc<dyn Metrics>,
29}
30
31impl MistralAdapter {
32    #[allow(dead_code)]
33    fn build_default_timeout_secs() -> u64 {
34        std::env::var("AI_HTTP_TIMEOUT_SECS")
35            .ok()
36            .and_then(|s| s.parse::<u64>().ok())
37            .unwrap_or(30)
38    }
39
40    fn build_default_transport() -> Result<DynHttpTransportRef, AiLibError> {
41        #[cfg(feature = "unified_transport")]
42        {
43            let timeout = Duration::from_secs(Self::build_default_timeout_secs());
44            let client = crate::transport::client_factory::build_shared_client().map_err(|e| {
45                AiLibError::NetworkError(format!("Failed to build http client: {}", e))
46            })?;
47            let t = HttpTransport::with_reqwest_client(client, timeout);
48            Ok(t.boxed())
49        }
50        #[cfg(not(feature = "unified_transport"))]
51        {
52            let t = HttpTransport::new();
53            return Ok(t.boxed());
54        }
55    }
56
57    pub fn new() -> Result<Self, AiLibError> {
58        let api_key = std::env::var("MISTRAL_API_KEY").ok();
59        let base_url = std::env::var("MISTRAL_BASE_URL")
60            .unwrap_or_else(|_| "https://api.mistral.ai".to_string());
61        let boxed = Self::build_default_transport()?;
62        Ok(Self {
63            transport: boxed,
64            api_key,
65            base_url,
66            metrics: Arc::new(NoopMetrics::new()),
67        })
68    }
69
70    /// Explicit API key / base_url overrides.
71    pub fn new_with_overrides(
72        api_key: Option<String>,
73        base_url: Option<String>,
74    ) -> Result<Self, AiLibError> {
75        let boxed = Self::build_default_transport()?;
76        Ok(Self {
77            transport: boxed,
78            api_key,
79            base_url: base_url.unwrap_or_else(|| {
80                std::env::var("MISTRAL_BASE_URL")
81                    .unwrap_or_else(|_| "https://api.mistral.ai".to_string())
82            }),
83            metrics: Arc::new(NoopMetrics::new()),
84        })
85    }
86
87    /// Construct using an injected object-safe transport reference (for testing/SDKs)
88    pub fn with_transport(
89        transport: DynHttpTransportRef,
90        api_key: Option<String>,
91        base_url: String,
92    ) -> Result<Self, AiLibError> {
93        Ok(Self {
94            transport,
95            api_key,
96            base_url,
97            metrics: Arc::new(NoopMetrics::new()),
98        })
99    }
100
101    /// Construct with an injected transport and metrics implementation
102    pub fn with_transport_and_metrics(
103        transport: DynHttpTransportRef,
104        api_key: Option<String>,
105        base_url: String,
106        metrics: Arc<dyn Metrics>,
107    ) -> Result<Self, AiLibError> {
108        Ok(Self {
109            transport,
110            api_key,
111            base_url,
112            metrics,
113        })
114    }
115
116    fn convert_request(&self, request: &ChatCompletionRequest) -> serde_json::Value {
117        let msgs: Vec<serde_json::Value> = request.messages.iter().map(|msg| {
118            serde_json::json!({
119                "role": match msg.role { Role::System => "system", Role::User => "user", Role::Assistant => "assistant" },
120                "content": msg.content.as_text()
121            })
122        }).collect();
123
124        let mut body = serde_json::json!({ "model": request.model, "messages": msgs });
125        if let Some(temp) = request.temperature {
126            body["temperature"] =
127                serde_json::Value::Number(serde_json::Number::from_f64(temp.into()).unwrap());
128        }
129        if let Some(max_tokens) = request.max_tokens {
130            body["max_tokens"] = serde_json::Value::Number(serde_json::Number::from(max_tokens));
131        }
132
133        // Function calling (OpenAI-compatible schema supported by Mistral chat/completions)
134        if let Some(funcs) = &request.functions {
135            let mapped: Vec<serde_json::Value> = funcs
136                .iter()
137                .map(|t| {
138                    serde_json::json!({
139                        "name": t.name,
140                        "description": t.description,
141                        "parameters": t.parameters.clone().unwrap_or(serde_json::json!({}))
142                    })
143                })
144                .collect();
145            body["functions"] = serde_json::Value::Array(mapped);
146        }
147        if let Some(policy) = &request.function_call {
148            match policy {
149                crate::types::FunctionCallPolicy::Auto(name) => {
150                    if name == "auto" {
151                        body["function_call"] = serde_json::Value::String("auto".to_string());
152                    } else {
153                        body["function_call"] = serde_json::json!({"name": name});
154                    }
155                }
156                crate::types::FunctionCallPolicy::None => {
157                    body["function_call"] = serde_json::Value::String("none".to_string());
158                }
159            }
160        }
161
162        request.apply_extensions(&mut body);
163
164        body
165    }
166
167    fn parse_response(
168        &self,
169        response: serde_json::Value,
170    ) -> Result<ChatCompletionResponse, AiLibError> {
171        let choices = response["choices"]
172            .as_array()
173            .unwrap_or(&vec![])
174            .iter()
175            .enumerate()
176            .map(|(index, choice)| {
177                let message = choice["message"].as_object().ok_or_else(|| {
178                    AiLibError::ProviderError("Invalid choice format".to_string())
179                })?;
180                let role = match message["role"].as_str().unwrap_or("user") {
181                    "system" => Role::System,
182                    "assistant" => Role::Assistant,
183                    _ => Role::User,
184                };
185                let content = message["content"].as_str().unwrap_or("").to_string();
186
187                // try to parse function_call if present
188                let mut function_call: Option<crate::types::function_call::FunctionCall> = None;
189                if let Some(fc_val) = message.get("function_call") {
190                    if let Ok(fc) = serde_json::from_value::<
191                        crate::types::function_call::FunctionCall,
192                    >(fc_val.clone())
193                    {
194                        function_call = Some(fc);
195                    } else if let Some(name) = fc_val
196                        .get("name")
197                        .and_then(|v| v.as_str())
198                        .map(|s| s.to_string())
199                    {
200                        let args = fc_val.get("arguments").and_then(|a| {
201                            if a.is_string() {
202                                serde_json::from_str::<serde_json::Value>(a.as_str().unwrap()).ok()
203                            } else {
204                                Some(a.clone())
205                            }
206                        });
207                        function_call = Some(crate::types::function_call::FunctionCall {
208                            name,
209                            arguments: args,
210                        });
211                    }
212                } else if let Some(tool_calls) =
213                    message.get("tool_calls").and_then(|v| v.as_array())
214                {
215                    if let Some(first) = tool_calls.first() {
216                        if let Some(func) =
217                            first.get("function").or_else(|| first.get("function_call"))
218                        {
219                            if let Some(name) = func.get("name").and_then(|v| v.as_str()) {
220                                let mut args_opt = func.get("arguments").cloned();
221                                if let Some(args_val) = &args_opt {
222                                    if args_val.is_string() {
223                                        if let Some(s) = args_val.as_str() {
224                                            if let Ok(parsed) =
225                                                serde_json::from_str::<serde_json::Value>(s)
226                                            {
227                                                args_opt = Some(parsed);
228                                            }
229                                        }
230                                    }
231                                }
232                                function_call = Some(crate::types::function_call::FunctionCall {
233                                    name: name.to_string(),
234                                    arguments: args_opt,
235                                });
236                            }
237                        }
238                    }
239                }
240
241                Ok(Choice {
242                    index: index as u32,
243                    message: Message {
244                        role,
245                        content: crate::types::common::Content::Text(content),
246                        function_call,
247                    },
248                    finish_reason: choice["finish_reason"].as_str().map(|s| s.to_string()),
249                })
250            })
251            .collect::<Result<Vec<_>, AiLibError>>()?;
252
253        let usage = response["usage"].as_object().ok_or_else(|| {
254            AiLibError::ProviderError("Invalid response format: usage not found".to_string())
255        })?;
256        let usage = Usage {
257            prompt_tokens: usage["prompt_tokens"].as_u64().unwrap_or(0) as u32,
258            completion_tokens: usage["completion_tokens"].as_u64().unwrap_or(0) as u32,
259            total_tokens: usage["total_tokens"].as_u64().unwrap_or(0) as u32,
260        };
261
262        Ok(ChatCompletionResponse {
263            id: response["id"].as_str().unwrap_or_default().to_string(),
264            object: response["object"].as_str().unwrap_or_default().to_string(),
265            created: response["created"].as_u64().unwrap_or(0),
266            model: response["model"].as_str().unwrap_or_default().to_string(),
267            choices,
268            usage,
269            usage_status: UsageStatus::Finalized, // Mistral provides accurate usage data
270        })
271    }
272}
273
/// Locate the end of the next complete SSE event in `buffer`.
///
/// Returns the index just past the first blank-line separator — either
/// `"\n\n"` or `"\r\n\r\n"` — or `None` when no full event has arrived yet.
#[cfg(not(feature = "unified_sse"))]
fn find_event_boundary(buffer: &[u8]) -> Option<usize> {
    for pos in 0..buffer.len() {
        let tail = &buffer[pos..];
        if tail.starts_with(b"\n\n") {
            return Some(pos + 2);
        }
        if tail.starts_with(b"\r\n\r\n") {
            return Some(pos + 4);
        }
    }
    None
}
293
294#[cfg(not(feature = "unified_sse"))]
295fn parse_sse_event(event_text: &str) -> Option<Result<Option<ChatCompletionChunk>, AiLibError>> {
296    for line in event_text.lines() {
297        let line = line.trim();
298        if let Some(stripped) = line.strip_prefix("data: ") {
299            let data = stripped;
300            if data == "[DONE]" {
301                return Some(Ok(None));
302            }
303            return Some(parse_chunk_data(data));
304        }
305    }
306    None
307}
308
309#[cfg(not(feature = "unified_sse"))]
310fn parse_chunk_data(data: &str) -> Result<Option<ChatCompletionChunk>, AiLibError> {
311    let json: serde_json::Value = serde_json::from_str(data)
312        .map_err(|e| AiLibError::ProviderError(format!("JSON parse error: {}", e)))?;
313    let mut choices_vec: Vec<ChoiceDelta> = Vec::new();
314    if let Some(arr) = json["choices"].as_array() {
315        for (index, choice) in arr.iter().enumerate() {
316            let delta = &choice["delta"];
317            let role = delta.get("role").and_then(|v| v.as_str()).map(|r| match r {
318                "assistant" => Role::Assistant,
319                "user" => Role::User,
320                "system" => Role::System,
321                _ => Role::Assistant,
322            });
323            let content = delta
324                .get("content")
325                .and_then(|v| v.as_str())
326                .map(|s| s.to_string());
327            let md = MessageDelta { role, content };
328            let cd = ChoiceDelta {
329                index: index as u32,
330                delta: md,
331                finish_reason: choice
332                    .get("finish_reason")
333                    .and_then(|v| v.as_str())
334                    .map(|s| s.to_string()),
335            };
336            choices_vec.push(cd);
337        }
338    }
339
340    Ok(Some(ChatCompletionChunk {
341        id: json["id"].as_str().unwrap_or_default().to_string(),
342        object: json["object"]
343            .as_str()
344            .unwrap_or("chat.completion.chunk")
345            .to_string(),
346        created: json["created"].as_u64().unwrap_or(0),
347        model: json["model"].as_str().unwrap_or_default().to_string(),
348        choices: choices_vec,
349    }))
350}
351
/// Split `text` into chunks of at most `max_len` bytes, preferring to break
/// at the last space inside each window; used to simulate streaming from a
/// non-streaming response.
///
/// Cut points are clamped to UTF-8 character boundaries, so multi-byte text
/// can never trigger a slice panic or lossy replacement characters — the
/// previous byte-indexed version panicked whenever `start + max_len` landed
/// inside a multi-byte character. A `max_len` of 0 degrades to one-character
/// chunks instead of looping forever.
fn split_text_into_chunks(text: &str, max_len: usize) -> Vec<String> {
    let mut chunks = Vec::new();
    let mut start = 0;
    while start < text.len() {
        // Tentative window end, pulled back onto a char boundary.
        let mut end = std::cmp::min(start + max_len, text.len());
        while end > start && !text.is_char_boundary(end) {
            end -= 1;
        }
        if end == start {
            // The next character alone exceeds max_len: emit the whole char.
            end = start
                + text[start..]
                    .chars()
                    .next()
                    .map(|c| c.len_utf8())
                    .unwrap_or(1);
        }
        // Prefer breaking at the last space in the window, unless the window
        // already reaches the end of the text.
        let mut cut = end;
        if end < text.len() {
            if let Some(pos) = text[start..end].rfind(' ') {
                cut = start + pos;
            }
        }
        if cut == start {
            cut = end;
        }
        chunks.push(text[start..cut].to_string());
        start = cut;
        // Skip the space we broke on so the next chunk doesn't start with it.
        if text.as_bytes().get(start) == Some(&b' ') {
            start += 1;
        }
    }
    chunks
}
376
#[async_trait::async_trait]
impl ChatProvider for MistralAdapter {
    /// Provider display name used by callers and metrics labels.
    fn name(&self) -> &str {
        "Mistral"
    }

    /// Non-streaming chat completion.
    ///
    /// Increments `mistral.requests`, times the call, POSTs the converted
    /// body to `{base_url}/v1/chat/completions` (adding a Bearer header when
    /// an API key is configured), then parses the JSON reply.
    async fn chat(
        &self,
        request: ChatCompletionRequest,
    ) -> Result<ChatCompletionResponse, AiLibError> {
        self.metrics.incr_counter("mistral.requests", 1).await;
        let timer = self
            .metrics
            .start_timer("mistral.request_duration_ms")
            .await;

        let url = format!("{}{}", self.base_url, "/v1/chat/completions");
        let provider_request = self.convert_request(&request);
        let mut headers = HashMap::new();
        headers.insert("Content-Type".to_string(), "application/json".to_string());
        // No API key configured: request goes out without Authorization.
        if let Some(key) = &self.api_key {
            headers.insert("Authorization".to_string(), format!("Bearer {}", key));
        }
        let response_json = self
            .transport
            .post_json(&url, Some(headers), provider_request)
            .await
            .map_err(|e| e.with_context(&format!("Mistral chat request to {}", url)))?;
        // The timer is stopped only on success; the `?` above returns early
        // on transport errors and skips it.
        if let Some(t) = timer {
            t.stop();
        }
        self.parse_response(response_json)
    }

    /// Streaming chat completion.
    ///
    /// Preferred path: POST with `"stream": true` and parse the SSE byte
    /// stream event-by-event on a background task, forwarding chunks over an
    /// unbounded channel. If opening the byte stream fails, falls back to a
    /// blocking `chat()` call whose text is re-emitted as simulated chunks.
    async fn stream(
        &self,
        request: ChatCompletionRequest,
    ) -> Result<
        Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
        AiLibError,
    > {
        let mut stream_request = self.convert_request(&request);
        stream_request["stream"] = serde_json::Value::Bool(true);

        let url = format!("{}{}", self.base_url, "/v1/chat/completions");

        let mut headers = HashMap::new();
        headers.insert("Accept".to_string(), "text/event-stream".to_string());
        if let Some(key) = &self.api_key {
            headers.insert("Authorization".to_string(), format!("Bearer {}", key));
        }
        // NOTE(review): a `post_stream` error is silently discarded here and
        // triggers the non-streaming fallback below — confirm that losing the
        // original error is intended.
        if let Ok(mut byte_stream) = self
            .transport
            .post_stream(&url, Some(headers.clone()), stream_request.clone())
            .await
        {
            let (tx, rx) = mpsc::unbounded_channel();
            tokio::spawn(async move {
                // Accumulate raw bytes until a full SSE event (blank-line
                // boundary) is available, then parse and forward each event.
                let mut buffer = Vec::new();
                while let Some(item) = byte_stream.next().await {
                    match item {
                        Ok(bytes) => {
                            buffer.extend_from_slice(&bytes);
                            #[cfg(feature = "unified_sse")]
                            {
                                while let Some(boundary) =
                                    crate::sse::parser::find_event_boundary(&buffer)
                                {
                                    let event_bytes = buffer.drain(..boundary).collect::<Vec<_>>();
                                    if let Ok(event_text) = std::str::from_utf8(&event_bytes) {
                                        if let Some(parsed) =
                                            crate::sse::parser::parse_sse_event(event_text)
                                        {
                                            match parsed {
                                                Ok(Some(chunk)) => {
                                                    // Receiver dropped: stop producing.
                                                    if tx.send(Ok(chunk)).is_err() {
                                                        return;
                                                    }
                                                }
                                                // `[DONE]` sentinel: end of stream.
                                                Ok(None) => return,
                                                Err(e) => {
                                                    let _ = tx.send(Err(e));
                                                    return;
                                                }
                                            }
                                        }
                                    }
                                }
                            }
                            #[cfg(not(feature = "unified_sse"))]
                            {
                                // Same event loop as above, using this file's
                                // local SSE helpers instead of crate::sse.
                                while let Some(boundary) = find_event_boundary(&buffer) {
                                    let event_bytes = buffer.drain(..boundary).collect::<Vec<_>>();
                                    if let Ok(event_text) = std::str::from_utf8(&event_bytes) {
                                        if let Some(parsed) = parse_sse_event(event_text) {
                                            match parsed {
                                                Ok(Some(chunk)) => {
                                                    // Receiver dropped: stop producing.
                                                    if tx.send(Ok(chunk)).is_err() {
                                                        return;
                                                    }
                                                }
                                                // `[DONE]` sentinel: end of stream.
                                                Ok(None) => return,
                                                Err(e) => {
                                                    let _ = tx.send(Err(e));
                                                    return;
                                                }
                                            }
                                        }
                                    }
                                }
                            }
                        }
                        Err(e) => {
                            // Surface the transport failure to the consumer,
                            // then stop reading the byte stream.
                            let _ = tx.send(Err(AiLibError::ProviderError(format!(
                                "Stream error: {}",
                                e
                            ))));
                            break;
                        }
                    }
                }
            });
            let stream = UnboundedReceiverStream::new(rx);
            return Ok(Box::new(Box::pin(stream)));
        }

        // fallback: call chat and stream chunks
        let finished = self.chat(request).await?;
        let text = finished
            .choices
            .first()
            .map(|c| c.message.content.as_text())
            .unwrap_or_default();
        let (tx, rx) = mpsc::unbounded_channel();
        tokio::spawn(async move {
            // Re-emit the completed text as ~80-byte chunks, 50 ms apart, so
            // callers still observe incremental delivery.
            let chunks = split_text_into_chunks(&text, 80);
            for chunk in chunks {
                let delta = ChoiceDelta {
                    index: 0,
                    delta: MessageDelta {
                        role: Some(Role::Assistant),
                        content: Some(chunk.clone()),
                    },
                    finish_reason: None,
                };
                let chunk_obj = ChatCompletionChunk {
                    id: "simulated".to_string(),
                    object: "chat.completion.chunk".to_string(),
                    created: 0,
                    model: finished.model.clone(),
                    choices: vec![delta],
                };
                if tx.send(Ok(chunk_obj)).is_err() {
                    return;
                }
                tokio::time::sleep(std::time::Duration::from_millis(50)).await;
            }
        });
        let stream = UnboundedReceiverStream::new(rx);
        Ok(Box::new(Box::pin(stream)))
    }

    /// List available model ids from `GET {base_url}/v1/models`
    /// (the `data[].id` fields of the response).
    async fn list_models(&self) -> Result<Vec<String>, AiLibError> {
        // Mistral models endpoint
        let url = format!("{}/v1/models", self.base_url);
        let mut headers = HashMap::new();
        if let Some(key) = &self.api_key {
            headers.insert("Authorization".to_string(), format!("Bearer {}", key));
        }
        let response = self.transport.get_json(&url, Some(headers)).await?;
        Ok(response["data"]
            .as_array()
            .unwrap_or(&vec![])
            .iter()
            .filter_map(|m| m["id"].as_str().map(|s| s.to_string()))
            .collect())
    }

    /// Synthesize static model info for `model_id`.
    ///
    /// No network call is made: a permissive default permission entry is
    /// returned with `owned_by: "mistral"`.
    async fn get_model_info(&self, model_id: &str) -> Result<crate::api::ModelInfo, AiLibError> {
        Ok(ModelInfo {
            id: model_id.to_string(),
            object: "model".to_string(),
            created: 0,
            owned_by: "mistral".to_string(),
            permission: vec![ModelPermission {
                id: "default".to_string(),
                object: "model_permission".to_string(),
                created: 0,
                allow_create_engine: false,
                allow_sampling: true,
                allow_logprobs: false,
                allow_search_indices: false,
                allow_view: true,
                allow_fine_tuning: false,
                organization: "*".to_string(),
                group: None,
                is_blocking: false,
            }],
        })
    }
}