// ai_lib/provider/mistral.rs

1use crate::api::{
2    ChatApi, ChatCompletionChunk, ChoiceDelta, MessageDelta, ModelInfo, ModelPermission,
3};
4use crate::metrics::{Metrics, NoopMetrics};
5use crate::transport::{DynHttpTransportRef, HttpTransport};
6use crate::types::{
7    AiLibError, ChatCompletionRequest, ChatCompletionResponse, Choice, Message, Role, Usage, UsageStatus,
8};
9use futures::stream::Stream;
10use futures::StreamExt;
11use std::collections::HashMap;
12#[cfg(feature = "unified_transport")]
13use std::time::Duration;
14use std::sync::Arc;
15use tokio::sync::mpsc;
16use tokio_stream::wrappers::UnboundedReceiverStream;
17
/// Mistral adapter (conservative HTTP implementation).
///
/// Note: Mistral provides an official Rust SDK (https://github.com/ivangabriele/mistralai-client-rs).
/// We keep this implementation HTTP-based for now and can swap to the SDK later.
pub struct MistralAdapter {
    /// Object-safe HTTP transport used for all requests; injectable for tests/SDKs.
    #[allow(dead_code)] // Kept for backward compatibility, now using direct reqwest
    transport: DynHttpTransportRef,
    /// Bearer token (read from `MISTRAL_API_KEY` by `new()`); requests are sent
    /// without an `Authorization` header when `None`.
    api_key: Option<String>,
    /// API root, e.g. `https://api.mistral.ai`; overridable via `MISTRAL_BASE_URL`.
    base_url: String,
    /// Metrics sink; defaults to `NoopMetrics` unless injected via
    /// `with_transport_and_metrics`.
    metrics: Arc<dyn Metrics>,
}
29
30impl MistralAdapter {
31    #[allow(dead_code)]
32    fn build_default_timeout_secs() -> u64 {
33        std::env::var("AI_HTTP_TIMEOUT_SECS")
34            .ok()
35            .and_then(|s| s.parse::<u64>().ok())
36            .unwrap_or(30)
37    }
38
39    fn build_default_transport() -> Result<DynHttpTransportRef, AiLibError> {
40        #[cfg(feature = "unified_transport")]
41        {
42            let timeout = Duration::from_secs(Self::build_default_timeout_secs());
43            let client = crate::transport::client_factory::build_shared_client()
44                .map_err(|e| AiLibError::NetworkError(format!("Failed to build http client: {}", e)))?;
45            let t = HttpTransport::with_reqwest_client(client, timeout);
46            return Ok(t.boxed());
47        }
48        #[cfg(not(feature = "unified_transport"))]
49        {
50            let t = HttpTransport::new();
51            return Ok(t.boxed());
52        }
53    }
54
55    pub fn new() -> Result<Self, AiLibError> {
56        let api_key = std::env::var("MISTRAL_API_KEY").ok();
57        let base_url = std::env::var("MISTRAL_BASE_URL")
58            .unwrap_or_else(|_| "https://api.mistral.ai".to_string());
59        let boxed = Self::build_default_transport()?;
60        Ok(Self {
61            transport: boxed,
62            api_key,
63            base_url,
64            metrics: Arc::new(NoopMetrics::new()),
65        })
66    }
67
68    /// Explicit API key / base_url overrides.
69    pub fn new_with_overrides(
70        api_key: Option<String>,
71        base_url: Option<String>,
72    ) -> Result<Self, AiLibError> {
73        let boxed = Self::build_default_transport()?;
74        Ok(Self {
75            transport: boxed,
76            api_key,
77            base_url: base_url.unwrap_or_else(|| {
78                std::env::var("MISTRAL_BASE_URL")
79                    .unwrap_or_else(|_| "https://api.mistral.ai".to_string())
80            }),
81            metrics: Arc::new(NoopMetrics::new()),
82        })
83    }
84
85    /// Construct using an injected object-safe transport reference (for testing/SDKs)
86    pub fn with_transport(
87        transport: DynHttpTransportRef,
88        api_key: Option<String>,
89        base_url: String,
90    ) -> Result<Self, AiLibError> {
91        Ok(Self {
92            transport,
93            api_key,
94            base_url,
95            metrics: Arc::new(NoopMetrics::new()),
96        })
97    }
98
99    /// Construct with an injected transport and metrics implementation
100    pub fn with_transport_and_metrics(
101        transport: DynHttpTransportRef,
102        api_key: Option<String>,
103        base_url: String,
104        metrics: Arc<dyn Metrics>,
105    ) -> Result<Self, AiLibError> {
106        Ok(Self {
107            transport,
108            api_key,
109            base_url,
110            metrics,
111        })
112    }
113
114    fn convert_request(&self, request: &ChatCompletionRequest) -> serde_json::Value {
115        let msgs: Vec<serde_json::Value> = request.messages.iter().map(|msg| {
116            serde_json::json!({
117                "role": match msg.role { Role::System => "system", Role::User => "user", Role::Assistant => "assistant" },
118                "content": msg.content.as_text()
119            })
120        }).collect();
121
122        let mut body = serde_json::json!({ "model": request.model, "messages": msgs });
123        if let Some(temp) = request.temperature {
124            body["temperature"] =
125                serde_json::Value::Number(serde_json::Number::from_f64(temp.into()).unwrap());
126        }
127        if let Some(max_tokens) = request.max_tokens {
128            body["max_tokens"] = serde_json::Value::Number(serde_json::Number::from(max_tokens));
129        }
130
131        // Function calling (OpenAI-compatible schema supported by Mistral chat/completions)
132        if let Some(funcs) = &request.functions {
133            let mapped: Vec<serde_json::Value> = funcs
134                .iter()
135                .map(|t| {
136                    serde_json::json!({
137                        "name": t.name,
138                        "description": t.description,
139                        "parameters": t.parameters.clone().unwrap_or(serde_json::json!({}))
140                    })
141                })
142                .collect();
143            body["functions"] = serde_json::Value::Array(mapped);
144        }
145        if let Some(policy) = &request.function_call {
146            match policy {
147                crate::types::FunctionCallPolicy::Auto(name) => {
148                    if name == "auto" {
149                        body["function_call"] = serde_json::Value::String("auto".to_string());
150                    } else {
151                        body["function_call"] = serde_json::json!({"name": name});
152                    }
153                }
154                crate::types::FunctionCallPolicy::None => {
155                    body["function_call"] = serde_json::Value::String("none".to_string());
156                }
157            }
158        }
159
160        body
161    }
162
163    fn parse_response(
164        &self,
165        response: serde_json::Value,
166    ) -> Result<ChatCompletionResponse, AiLibError> {
167        let choices = response["choices"]
168            .as_array()
169            .unwrap_or(&vec![])
170            .iter()
171            .enumerate()
172            .map(|(index, choice)| {
173                let message = choice["message"].as_object().ok_or_else(|| {
174                    AiLibError::ProviderError("Invalid choice format".to_string())
175                })?;
176                let role = match message["role"].as_str().unwrap_or("user") {
177                    "system" => Role::System,
178                    "assistant" => Role::Assistant,
179                    _ => Role::User,
180                };
181                let content = message["content"].as_str().unwrap_or("").to_string();
182
183                // try to parse function_call if present
184                let mut function_call: Option<crate::types::function_call::FunctionCall> = None;
185                if let Some(fc_val) = message.get("function_call") {
186                    if let Ok(fc) = serde_json::from_value::<
187                        crate::types::function_call::FunctionCall,
188                    >(fc_val.clone())
189                    {
190                        function_call = Some(fc);
191                    } else if let Some(name) = fc_val
192                        .get("name")
193                        .and_then(|v| v.as_str())
194                        .map(|s| s.to_string())
195                    {
196                        let args = fc_val.get("arguments").and_then(|a| {
197                            if a.is_string() {
198                                serde_json::from_str::<serde_json::Value>(a.as_str().unwrap()).ok()
199                            } else {
200                                Some(a.clone())
201                            }
202                        });
203                        function_call = Some(crate::types::function_call::FunctionCall {
204                            name,
205                            arguments: args,
206                        });
207                    }
208                } else if let Some(tool_calls) = message.get("tool_calls").and_then(|v| v.as_array()) {
209                    if let Some(first) = tool_calls.first() {
210                        if let Some(func) = first.get("function").or_else(|| first.get("function_call")) {
211                            if let Some(name) = func.get("name").and_then(|v| v.as_str()) {
212                                let mut args_opt = func.get("arguments").cloned();
213                                if let Some(args_val) = &args_opt {
214                                    if args_val.is_string() {
215                                        if let Some(s) = args_val.as_str() {
216                                            if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(s) {
217                                                args_opt = Some(parsed);
218                                            }
219                                        }
220                                    }
221                                }
222                                function_call = Some(crate::types::function_call::FunctionCall {
223                                    name: name.to_string(),
224                                    arguments: args_opt,
225                                });
226                            }
227                        }
228                    }
229                }
230
231                Ok(Choice {
232                    index: index as u32,
233                    message: Message {
234                        role,
235                        content: crate::types::common::Content::Text(content),
236                        function_call,
237                    },
238                    finish_reason: choice["finish_reason"].as_str().map(|s| s.to_string()),
239                })
240            })
241            .collect::<Result<Vec<_>, AiLibError>>()?;
242
243        let usage = response["usage"].as_object().ok_or_else(|| {
244            AiLibError::ProviderError("Invalid response format: usage not found".to_string())
245        })?;
246        let usage = Usage {
247            prompt_tokens: usage["prompt_tokens"].as_u64().unwrap_or(0) as u32,
248            completion_tokens: usage["completion_tokens"].as_u64().unwrap_or(0) as u32,
249            total_tokens: usage["total_tokens"].as_u64().unwrap_or(0) as u32,
250        };
251
252        Ok(ChatCompletionResponse {
253            id: response["id"].as_str().unwrap_or_default().to_string(),
254            object: response["object"].as_str().unwrap_or_default().to_string(),
255            created: response["created"].as_u64().unwrap_or(0),
256            model: response["model"].as_str().unwrap_or_default().to_string(),
257            choices,
258            usage,
259            usage_status: UsageStatus::Finalized, // Mistral provides accurate usage data
260        })
261    }
262}
263
#[cfg(not(feature = "unified_sse"))]
/// Find the end offset of the first complete SSE event in `buffer`.
///
/// An event ends at a blank line: either `\n\n` or `\r\n\r\n`. Returns the
/// index just past the delimiter, or `None` if no full event is buffered yet.
fn find_event_boundary(buffer: &[u8]) -> Option<usize> {
    for (pos, pair) in buffer.windows(2).enumerate() {
        if pair == b"\n\n" {
            return Some(pos + 2);
        }
        if buffer[pos..].starts_with(b"\r\n\r\n") {
            return Some(pos + 4);
        }
    }
    None
}
283
284#[cfg(not(feature = "unified_sse"))]
285fn parse_sse_event(event_text: &str) -> Option<Result<Option<ChatCompletionChunk>, AiLibError>> {
286    for line in event_text.lines() {
287        let line = line.trim();
288        if let Some(stripped) = line.strip_prefix("data: ") {
289            let data = stripped;
290            if data == "[DONE]" {
291                return Some(Ok(None));
292            }
293            return Some(parse_chunk_data(data));
294        }
295    }
296    None
297}
298
299#[cfg(not(feature = "unified_sse"))]
300fn parse_chunk_data(data: &str) -> Result<Option<ChatCompletionChunk>, AiLibError> {
301    let json: serde_json::Value = serde_json::from_str(data)
302        .map_err(|e| AiLibError::ProviderError(format!("JSON parse error: {}", e)))?;
303    let mut choices_vec: Vec<ChoiceDelta> = Vec::new();
304    if let Some(arr) = json["choices"].as_array() {
305        for (index, choice) in arr.iter().enumerate() {
306            let delta = &choice["delta"];
307            let role = delta.get("role").and_then(|v| v.as_str()).map(|r| match r {
308                "assistant" => Role::Assistant,
309                "user" => Role::User,
310                "system" => Role::System,
311                _ => Role::Assistant,
312            });
313            let content = delta
314                .get("content")
315                .and_then(|v| v.as_str())
316                .map(|s| s.to_string());
317            let md = MessageDelta { role, content };
318            let cd = ChoiceDelta {
319                index: index as u32,
320                delta: md,
321                finish_reason: choice
322                    .get("finish_reason")
323                    .and_then(|v| v.as_str())
324                    .map(|s| s.to_string()),
325            };
326            choices_vec.push(cd);
327        }
328    }
329
330    Ok(Some(ChatCompletionChunk {
331        id: json["id"].as_str().unwrap_or_default().to_string(),
332        object: json["object"]
333            .as_str()
334            .unwrap_or("chat.completion.chunk")
335            .to_string(),
336        created: json["created"].as_u64().unwrap_or(0),
337        model: json["model"].as_str().unwrap_or_default().to_string(),
338        choices: choices_vec,
339    }))
340}
341
/// Split `text` into chunks of at most `max_len` bytes, preferring to break at
/// the last space within the window; a space at a break point is consumed.
///
/// BUGFIX: the previous version sliced `text` at `start + max_len` regardless of
/// UTF-8 char boundaries, which panics on multi-byte text (`&str` slicing) and
/// could emit U+FFFD replacement chars via `from_utf8_lossy`. It also looped
/// forever when `max_len == 0`. Chunk ends are now snapped back to a char
/// boundary, and a chunk always advances by at least one whole character.
fn split_text_into_chunks(text: &str, max_len: usize) -> Vec<String> {
    let mut chunks = Vec::new();
    let mut start = 0;
    while start < text.len() {
        // Candidate end: at most `max_len` bytes ahead, snapped back to a char boundary.
        let mut end = std::cmp::min(start + max_len, text.len());
        while end > start && !text.is_char_boundary(end) {
            end -= 1;
        }
        if end == start {
            // A single character wider than max_len (or max_len == 0):
            // take the whole character so the loop always makes progress.
            end = start
                + text[start..]
                    .chars()
                    .next()
                    .map(char::len_utf8)
                    .unwrap_or(text.len() - start);
        }
        // Prefer breaking at the last space inside the window (unless the
        // window already reaches the end of the text).
        let mut cut = end;
        if end < text.len() {
            if let Some(pos) = text[start..end].rfind(' ') {
                cut = start + pos;
            }
        }
        if cut == start {
            cut = end;
        }
        chunks.push(text[start..cut].to_string());
        start = cut;
        // Skip the space we broke on so chunks don't start with it.
        if text.as_bytes().get(start) == Some(&b' ') {
            start += 1;
        }
    }
    chunks
}
366
#[async_trait::async_trait]
impl ChatApi for MistralAdapter {
    /// Non-streaming chat completion against `{base_url}/v1/chat/completions`.
    ///
    /// Increments the `mistral.requests` counter and times the request via
    /// `mistral.request_duration_ms`, then converts and parses through the
    /// adapter's request/response mappers.
    async fn chat_completion(
        &self,
        request: ChatCompletionRequest,
    ) -> Result<ChatCompletionResponse, AiLibError> {
        self.metrics.incr_counter("mistral.requests", 1).await;
        let timer = self
            .metrics
            .start_timer("mistral.request_duration_ms")
            .await;

        let url = format!("{}{}", self.base_url, "/v1/chat/completions");
        let provider_request = self.convert_request(&request);
        let mut headers = HashMap::new();
        headers.insert("Content-Type".to_string(), "application/json".to_string());
        // No Authorization header is sent when the adapter has no API key.
        if let Some(key) = &self.api_key {
            headers.insert("Authorization".to_string(), format!("Bearer {}", key));
        }
        // NOTE(review): on transport error the `?` below returns before
        // `t.stop()`, so failed requests are excluded from the latency metric —
        // confirm that is intentional.
        let response_json = self
            .transport
            .post_json(&url, Some(headers), provider_request)
            .await?;
        if let Some(t) = timer {
            t.stop();
        }
        self.parse_response(response_json)
    }

    /// Streaming chat completion.
    ///
    /// Attempts a real SSE stream via `post_stream`; complete events are parsed
    /// on a spawned task and forwarded over an unbounded channel. If the stream
    /// cannot be opened, falls back to a single `chat_completion` call whose
    /// first choice's text is re-emitted as simulated chunks.
    async fn chat_completion_stream(
        &self,
        request: ChatCompletionRequest,
    ) -> Result<
        Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
        AiLibError,
    > {
        let mut stream_request = self.convert_request(&request);
        stream_request["stream"] = serde_json::Value::Bool(true);

        let url = format!("{}{}", self.base_url, "/v1/chat/completions");

        let mut headers = HashMap::new();
        headers.insert("Accept".to_string(), "text/event-stream".to_string());
        // NOTE(review): unlike `chat_completion`, no Content-Type header is set
        // here — presumably `post_stream` supplies it; verify in the transport.
        if let Some(key) = &self.api_key {
            headers.insert("Authorization".to_string(), format!("Bearer {}", key));
        }
        // Any `post_stream` error is swallowed here and triggers the simulated
        // fallback after this block.
        if let Ok(mut byte_stream) = self
            .transport
            .post_stream(&url, Some(headers.clone()), stream_request.clone())
            .await
        {
            let (tx, rx) = mpsc::unbounded_channel();
            tokio::spawn(async move {
                // Raw bytes accumulate here until at least one full SSE event
                // (terminated by a blank line) is buffered.
                let mut buffer = Vec::new();
                while let Some(item) = byte_stream.next().await {
                    match item {
                        Ok(bytes) => {
                            buffer.extend_from_slice(&bytes);
                            #[cfg(feature = "unified_sse")]
                            {
                                // Drain and parse every complete event currently buffered.
                                while let Some(boundary) =
                                    crate::sse::parser::find_event_boundary(&buffer)
                                {
                                    let event_bytes = buffer.drain(..boundary).collect::<Vec<_>>();
                                    if let Ok(event_text) = std::str::from_utf8(&event_bytes) {
                                        if let Some(parsed) =
                                            crate::sse::parser::parse_sse_event(event_text)
                                        {
                                            match parsed {
                                                Ok(Some(chunk)) => {
                                                    // Receiver dropped: stop the task.
                                                    if tx.send(Ok(chunk)).is_err() {
                                                        return;
                                                    }
                                                }
                                                // `[DONE]` sentinel: end of stream.
                                                Ok(None) => return,
                                                Err(e) => {
                                                    let _ = tx.send(Err(e));
                                                    return;
                                                }
                                            }
                                        }
                                    }
                                }
                            }
                            #[cfg(not(feature = "unified_sse"))]
                            {
                                // Same drain/parse loop using this file's local SSE helpers.
                                while let Some(boundary) = find_event_boundary(&buffer) {
                                    let event_bytes = buffer.drain(..boundary).collect::<Vec<_>>();
                                    if let Ok(event_text) = std::str::from_utf8(&event_bytes) {
                                        if let Some(parsed) = parse_sse_event(event_text) {
                                            match parsed {
                                                Ok(Some(chunk)) => {
                                                    if tx.send(Ok(chunk)).is_err() {
                                                        return;
                                                    }
                                                }
                                                Ok(None) => return,
                                                Err(e) => {
                                                    let _ = tx.send(Err(e));
                                                    return;
                                                }
                                            }
                                        }
                                    }
                                }
                            }
                        }
                        Err(e) => {
                            // Forward the transport error to the consumer and stop reading.
                            let _ = tx.send(Err(AiLibError::ProviderError(format!(
                                "Stream error: {}",
                                e
                            ))));
                            break;
                        }
                    }
                }
            });
            let stream = UnboundedReceiverStream::new(rx);
            return Ok(Box::new(Box::pin(stream)));
        }

        // fallback: call chat_completion and stream chunks
        let finished = self.chat_completion(request).await?;
        let text = finished
            .choices
            .first()
            .map(|c| c.message.content.as_text())
            .unwrap_or_default();
        let (tx, rx) = mpsc::unbounded_channel();
        tokio::spawn(async move {
            // Re-emit the finished text as ~80-byte deltas spaced 50ms apart to
            // approximate a live stream for callers.
            let chunks = split_text_into_chunks(&text, 80);
            for chunk in chunks {
                let delta = ChoiceDelta {
                    index: 0,
                    delta: MessageDelta {
                        role: Some(Role::Assistant),
                        content: Some(chunk.clone()),
                    },
                    finish_reason: None,
                };
                let chunk_obj = ChatCompletionChunk {
                    id: "simulated".to_string(),
                    object: "chat.completion.chunk".to_string(),
                    created: 0,
                    model: finished.model.clone(),
                    choices: vec![delta],
                };
                if tx.send(Ok(chunk_obj)).is_err() {
                    return;
                }
                tokio::time::sleep(std::time::Duration::from_millis(50)).await;
            }
        });
        let stream = UnboundedReceiverStream::new(rx);
        Ok(Box::new(Box::pin(stream)))
    }

    /// List model ids from `{base_url}/v1/models`.
    ///
    /// Entries without a string `id` are silently skipped; a missing `data`
    /// array yields an empty list.
    async fn list_models(&self) -> Result<Vec<String>, AiLibError> {
        // Mistral models endpoint
        let url = format!("{}/v1/models", self.base_url);
        let mut headers = HashMap::new();
        if let Some(key) = &self.api_key {
            headers.insert("Authorization".to_string(), format!("Bearer {}", key));
        }
        let response = self.transport.get_json(&url, Some(headers)).await?;
        Ok(response["data"]
            .as_array()
            .unwrap_or(&vec![])
            .iter()
            .filter_map(|m| m["id"].as_str().map(|s| s.to_string()))
            .collect())
    }

    /// Return a locally synthesized `ModelInfo` for `model_id`.
    ///
    /// No network call is made: `created` is 0 and the permission entry is a
    /// hard-coded default, not data from the Mistral API.
    async fn get_model_info(&self, model_id: &str) -> Result<crate::api::ModelInfo, AiLibError> {
        Ok(ModelInfo {
            id: model_id.to_string(),
            object: "model".to_string(),
            created: 0,
            owned_by: "mistral".to_string(),
            permission: vec![ModelPermission {
                id: "default".to_string(),
                object: "model_permission".to_string(),
                created: 0,
                allow_create_engine: false,
                allow_sampling: true,
                allow_logprobs: false,
                allow_search_indices: false,
                allow_view: true,
                allow_fine_tuning: false,
                organization: "*".to_string(),
                group: None,
                is_blocking: false,
            }],
        })
    }
}