ai_lib/provider/
mistral.rs

1use crate::api::{
2    ChatApi, ChatCompletionChunk, ChoiceDelta, MessageDelta, ModelInfo, ModelPermission,
3};
4use crate::metrics::{Metrics, NoopMetrics};
5use crate::transport::{DynHttpTransportRef, HttpTransport};
6use crate::types::{
7    AiLibError, ChatCompletionRequest, ChatCompletionResponse, Choice, Message, Role, Usage, UsageStatus,
8};
9use futures::stream::Stream;
10use futures::StreamExt;
11use std::collections::HashMap;
12#[cfg(feature = "unified_transport")]
13use std::time::Duration;
14use std::sync::Arc;
15use tokio::sync::mpsc;
16use tokio_stream::wrappers::UnboundedReceiverStream;
17
/// Mistral adapter (conservative HTTP implementation).
///
/// Note: Mistral provides an official Rust SDK (https://github.com/ivangabriele/mistralai-client-rs).
/// We keep this implementation HTTP-based for now and can swap to the SDK later.
pub struct MistralAdapter {
    #[allow(dead_code)] // NOTE(review): transport IS used by the ChatApi impl (post_json/post_stream/get_json); this allow looks stale — confirm before removing
    transport: DynHttpTransportRef,
    /// Bearer token, read from `MISTRAL_API_KEY` by `new()`. When `None`,
    /// requests are sent without an `Authorization` header.
    api_key: Option<String>,
    /// API root (e.g. `https://api.mistral.ai`); endpoint paths such as
    /// `/v1/chat/completions` are appended to it.
    base_url: String,
    /// Metrics sink; defaults to `NoopMetrics` unless injected via
    /// `with_transport_and_metrics`.
    metrics: Arc<dyn Metrics>,
}
29
30impl MistralAdapter {
31    fn build_default_timeout_secs() -> u64 {
32        std::env::var("AI_HTTP_TIMEOUT_SECS")
33            .ok()
34            .and_then(|s| s.parse::<u64>().ok())
35            .unwrap_or(30)
36    }
37
38    fn build_default_transport() -> Result<DynHttpTransportRef, AiLibError> {
39        #[cfg(feature = "unified_transport")]
40        {
41            let timeout = Duration::from_secs(Self::build_default_timeout_secs());
42            let client = crate::transport::client_factory::build_shared_client()
43                .map_err(|e| AiLibError::NetworkError(format!("Failed to build http client: {}", e)))?;
44            let t = HttpTransport::with_reqwest_client(client, timeout);
45            return Ok(t.boxed());
46        }
47        #[cfg(not(feature = "unified_transport"))]
48        {
49            let t = HttpTransport::new();
50            return Ok(t.boxed());
51        }
52    }
53
54    pub fn new() -> Result<Self, AiLibError> {
55        let api_key = std::env::var("MISTRAL_API_KEY").ok();
56        let base_url = std::env::var("MISTRAL_BASE_URL")
57            .unwrap_or_else(|_| "https://api.mistral.ai".to_string());
58        let boxed = Self::build_default_transport()?;
59        Ok(Self {
60            transport: boxed,
61            api_key,
62            base_url,
63            metrics: Arc::new(NoopMetrics::new()),
64        })
65    }
66
67    /// Explicit API key / base_url overrides.
68    pub fn new_with_overrides(
69        api_key: Option<String>,
70        base_url: Option<String>,
71    ) -> Result<Self, AiLibError> {
72        let boxed = Self::build_default_transport()?;
73        Ok(Self {
74            transport: boxed,
75            api_key,
76            base_url: base_url.unwrap_or_else(|| {
77                std::env::var("MISTRAL_BASE_URL")
78                    .unwrap_or_else(|_| "https://api.mistral.ai".to_string())
79            }),
80            metrics: Arc::new(NoopMetrics::new()),
81        })
82    }
83
84    /// Construct using an injected object-safe transport reference (for testing/SDKs)
85    pub fn with_transport(
86        transport: DynHttpTransportRef,
87        api_key: Option<String>,
88        base_url: String,
89    ) -> Result<Self, AiLibError> {
90        Ok(Self {
91            transport,
92            api_key,
93            base_url,
94            metrics: Arc::new(NoopMetrics::new()),
95        })
96    }
97
98    /// Construct with an injected transport and metrics implementation
99    pub fn with_transport_and_metrics(
100        transport: DynHttpTransportRef,
101        api_key: Option<String>,
102        base_url: String,
103        metrics: Arc<dyn Metrics>,
104    ) -> Result<Self, AiLibError> {
105        Ok(Self {
106            transport,
107            api_key,
108            base_url,
109            metrics,
110        })
111    }
112
113    fn convert_request(&self, request: &ChatCompletionRequest) -> serde_json::Value {
114        let msgs: Vec<serde_json::Value> = request.messages.iter().map(|msg| {
115            serde_json::json!({
116                "role": match msg.role { Role::System => "system", Role::User => "user", Role::Assistant => "assistant" },
117                "content": msg.content.as_text()
118            })
119        }).collect();
120
121        let mut body = serde_json::json!({ "model": request.model, "messages": msgs });
122        if let Some(temp) = request.temperature {
123            body["temperature"] =
124                serde_json::Value::Number(serde_json::Number::from_f64(temp.into()).unwrap());
125        }
126        if let Some(max_tokens) = request.max_tokens {
127            body["max_tokens"] = serde_json::Value::Number(serde_json::Number::from(max_tokens));
128        }
129
130        // Function calling (OpenAI-compatible schema supported by Mistral chat/completions)
131        if let Some(funcs) = &request.functions {
132            let mapped: Vec<serde_json::Value> = funcs
133                .iter()
134                .map(|t| {
135                    serde_json::json!({
136                        "name": t.name,
137                        "description": t.description,
138                        "parameters": t.parameters.clone().unwrap_or(serde_json::json!({}))
139                    })
140                })
141                .collect();
142            body["functions"] = serde_json::Value::Array(mapped);
143        }
144        if let Some(policy) = &request.function_call {
145            match policy {
146                crate::types::FunctionCallPolicy::Auto(name) => {
147                    if name == "auto" {
148                        body["function_call"] = serde_json::Value::String("auto".to_string());
149                    } else {
150                        body["function_call"] = serde_json::json!({"name": name});
151                    }
152                }
153                crate::types::FunctionCallPolicy::None => {
154                    body["function_call"] = serde_json::Value::String("none".to_string());
155                }
156            }
157        }
158
159        body
160    }
161
162    fn parse_response(
163        &self,
164        response: serde_json::Value,
165    ) -> Result<ChatCompletionResponse, AiLibError> {
166        let choices = response["choices"]
167            .as_array()
168            .unwrap_or(&vec![])
169            .iter()
170            .enumerate()
171            .map(|(index, choice)| {
172                let message = choice["message"].as_object().ok_or_else(|| {
173                    AiLibError::ProviderError("Invalid choice format".to_string())
174                })?;
175                let role = match message["role"].as_str().unwrap_or("user") {
176                    "system" => Role::System,
177                    "assistant" => Role::Assistant,
178                    _ => Role::User,
179                };
180                let content = message["content"].as_str().unwrap_or("").to_string();
181
182                // try to parse function_call if present
183                let mut function_call: Option<crate::types::function_call::FunctionCall> = None;
184                if let Some(fc_val) = message.get("function_call") {
185                    if let Ok(fc) = serde_json::from_value::<
186                        crate::types::function_call::FunctionCall,
187                    >(fc_val.clone())
188                    {
189                        function_call = Some(fc);
190                    } else if let Some(name) = fc_val
191                        .get("name")
192                        .and_then(|v| v.as_str())
193                        .map(|s| s.to_string())
194                    {
195                        let args = fc_val.get("arguments").and_then(|a| {
196                            if a.is_string() {
197                                serde_json::from_str::<serde_json::Value>(a.as_str().unwrap()).ok()
198                            } else {
199                                Some(a.clone())
200                            }
201                        });
202                        function_call = Some(crate::types::function_call::FunctionCall {
203                            name,
204                            arguments: args,
205                        });
206                    }
207                } else if let Some(tool_calls) = message.get("tool_calls").and_then(|v| v.as_array()) {
208                    if let Some(first) = tool_calls.first() {
209                        if let Some(func) = first.get("function").or_else(|| first.get("function_call")) {
210                            if let Some(name) = func.get("name").and_then(|v| v.as_str()) {
211                                let mut args_opt = func.get("arguments").cloned();
212                                if let Some(args_val) = &args_opt {
213                                    if args_val.is_string() {
214                                        if let Some(s) = args_val.as_str() {
215                                            if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(s) {
216                                                args_opt = Some(parsed);
217                                            }
218                                        }
219                                    }
220                                }
221                                function_call = Some(crate::types::function_call::FunctionCall {
222                                    name: name.to_string(),
223                                    arguments: args_opt,
224                                });
225                            }
226                        }
227                    }
228                }
229
230                Ok(Choice {
231                    index: index as u32,
232                    message: Message {
233                        role,
234                        content: crate::types::common::Content::Text(content),
235                        function_call,
236                    },
237                    finish_reason: choice["finish_reason"].as_str().map(|s| s.to_string()),
238                })
239            })
240            .collect::<Result<Vec<_>, AiLibError>>()?;
241
242        let usage = response["usage"].as_object().ok_or_else(|| {
243            AiLibError::ProviderError("Invalid response format: usage not found".to_string())
244        })?;
245        let usage = Usage {
246            prompt_tokens: usage["prompt_tokens"].as_u64().unwrap_or(0) as u32,
247            completion_tokens: usage["completion_tokens"].as_u64().unwrap_or(0) as u32,
248            total_tokens: usage["total_tokens"].as_u64().unwrap_or(0) as u32,
249        };
250
251        Ok(ChatCompletionResponse {
252            id: response["id"].as_str().unwrap_or_default().to_string(),
253            object: response["object"].as_str().unwrap_or_default().to_string(),
254            created: response["created"].as_u64().unwrap_or(0),
255            model: response["model"].as_str().unwrap_or_default().to_string(),
256            choices,
257            usage,
258            usage_status: UsageStatus::Finalized, // Mistral provides accurate usage data
259        })
260    }
261}
262
#[cfg(not(feature = "unified_sse"))]
/// Find the end of the first complete SSE event in `buffer`.
///
/// Events are terminated by a blank line — either "\n\n" or "\r\n\r\n".
/// Returns the index just past the terminator, or `None` when no complete
/// event is buffered yet.
fn find_event_boundary(buffer: &[u8]) -> Option<usize> {
    for pos in 0..buffer.len() {
        let rest = &buffer[pos..];
        if rest.starts_with(b"\r\n\r\n") {
            return Some(pos + 4);
        }
        if rest.starts_with(b"\n\n") {
            return Some(pos + 2);
        }
    }
    None
}
282
283#[cfg(not(feature = "unified_sse"))]
284fn parse_sse_event(event_text: &str) -> Option<Result<Option<ChatCompletionChunk>, AiLibError>> {
285    for line in event_text.lines() {
286        let line = line.trim();
287        if let Some(stripped) = line.strip_prefix("data: ") {
288            let data = stripped;
289            if data == "[DONE]" {
290                return Some(Ok(None));
291            }
292            return Some(parse_chunk_data(data));
293        }
294    }
295    None
296}
297
298#[cfg(not(feature = "unified_sse"))]
299fn parse_chunk_data(data: &str) -> Result<Option<ChatCompletionChunk>, AiLibError> {
300    let json: serde_json::Value = serde_json::from_str(data)
301        .map_err(|e| AiLibError::ProviderError(format!("JSON parse error: {}", e)))?;
302    let mut choices_vec: Vec<ChoiceDelta> = Vec::new();
303    if let Some(arr) = json["choices"].as_array() {
304        for (index, choice) in arr.iter().enumerate() {
305            let delta = &choice["delta"];
306            let role = delta.get("role").and_then(|v| v.as_str()).map(|r| match r {
307                "assistant" => Role::Assistant,
308                "user" => Role::User,
309                "system" => Role::System,
310                _ => Role::Assistant,
311            });
312            let content = delta
313                .get("content")
314                .and_then(|v| v.as_str())
315                .map(|s| s.to_string());
316            let md = MessageDelta { role, content };
317            let cd = ChoiceDelta {
318                index: index as u32,
319                delta: md,
320                finish_reason: choice
321                    .get("finish_reason")
322                    .and_then(|v| v.as_str())
323                    .map(|s| s.to_string()),
324            };
325            choices_vec.push(cd);
326        }
327    }
328
329    Ok(Some(ChatCompletionChunk {
330        id: json["id"].as_str().unwrap_or_default().to_string(),
331        object: json["object"]
332            .as_str()
333            .unwrap_or("chat.completion.chunk")
334            .to_string(),
335        created: json["created"].as_u64().unwrap_or(0),
336        model: json["model"].as_str().unwrap_or_default().to_string(),
337        choices: choices_vec,
338    }))
339}
340
/// Split `text` into chunks of at most `max_len` bytes, preferring to break
/// at the last space inside each window (the space itself is dropped).
///
/// Used to simulate streaming when a real provider stream is unavailable.
/// Cut points are snapped to UTF-8 char boundaries, so slicing can never
/// panic on multi-byte characters (the previous version sliced the `&str`
/// at raw byte offsets and panicked mid-codepoint). A `max_len` of 0 returns
/// the whole text as one chunk instead of looping forever.
fn split_text_into_chunks(text: &str, max_len: usize) -> Vec<String> {
    if text.is_empty() {
        return Vec::new();
    }
    if max_len == 0 {
        // Guard: a zero budget would never advance `start`.
        return vec![text.to_string()];
    }
    let bytes = text.as_bytes();
    let mut chunks = Vec::new();
    let mut start = 0;
    while start < bytes.len() {
        let mut end = std::cmp::min(start + max_len, bytes.len());
        // Snap back to a char boundary so `text[start..end]` is always valid.
        while end > start && !text.is_char_boundary(end) {
            end -= 1;
        }
        if end == start {
            // A single char is wider than max_len: advance to its end instead.
            end = start + max_len;
            while end < bytes.len() && !text.is_char_boundary(end) {
                end += 1;
            }
            end = std::cmp::min(end, bytes.len());
        }
        let mut cut = end;
        if end < bytes.len() {
            // Prefer breaking at the last space inside the window.
            if let Some(pos) = text[start..end].rfind(' ') {
                cut = start + pos;
            }
        }
        if cut == start {
            cut = end;
        }
        chunks.push(text[start..cut].to_string());
        start = cut;
        // Skip the space we broke on so the next chunk doesn't lead with it.
        if start < bytes.len() && bytes[start] == b' ' {
            start += 1;
        }
    }
    chunks
}
365
366#[async_trait::async_trait]
367impl ChatApi for MistralAdapter {
368    async fn chat_completion(
369        &self,
370        request: ChatCompletionRequest,
371    ) -> Result<ChatCompletionResponse, AiLibError> {
372        self.metrics.incr_counter("mistral.requests", 1).await;
373        let timer = self
374            .metrics
375            .start_timer("mistral.request_duration_ms")
376            .await;
377
378        let url = format!("{}{}", self.base_url, "/v1/chat/completions");
379        let provider_request = self.convert_request(&request);
380        let mut headers = HashMap::new();
381        headers.insert("Content-Type".to_string(), "application/json".to_string());
382        if let Some(key) = &self.api_key {
383            headers.insert("Authorization".to_string(), format!("Bearer {}", key));
384        }
385        let response_json = self
386            .transport
387            .post_json(&url, Some(headers), provider_request)
388            .await?;
389        if let Some(t) = timer {
390            t.stop();
391        }
392        self.parse_response(response_json)
393    }
394
395    async fn chat_completion_stream(
396        &self,
397        request: ChatCompletionRequest,
398    ) -> Result<
399        Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
400        AiLibError,
401    > {
402        let mut stream_request = self.convert_request(&request);
403        stream_request["stream"] = serde_json::Value::Bool(true);
404
405        let url = format!("{}{}", self.base_url, "/v1/chat/completions");
406
407        let mut headers = HashMap::new();
408        headers.insert("Accept".to_string(), "text/event-stream".to_string());
409        if let Some(key) = &self.api_key {
410            headers.insert("Authorization".to_string(), format!("Bearer {}", key));
411        }
412        if let Ok(mut byte_stream) = self
413            .transport
414            .post_stream(&url, Some(headers.clone()), stream_request.clone())
415            .await
416        {
417            let (tx, rx) = mpsc::unbounded_channel();
418            tokio::spawn(async move {
419                let mut buffer = Vec::new();
420                while let Some(item) = byte_stream.next().await {
421                    match item {
422                        Ok(bytes) => {
423                            buffer.extend_from_slice(&bytes);
424                            #[cfg(feature = "unified_sse")]
425                            {
426                                while let Some(boundary) =
427                                    crate::sse::parser::find_event_boundary(&buffer)
428                                {
429                                    let event_bytes = buffer.drain(..boundary).collect::<Vec<_>>();
430                                    if let Ok(event_text) = std::str::from_utf8(&event_bytes) {
431                                        if let Some(parsed) =
432                                            crate::sse::parser::parse_sse_event(event_text)
433                                        {
434                                            match parsed {
435                                                Ok(Some(chunk)) => {
436                                                    if tx.send(Ok(chunk)).is_err() {
437                                                        return;
438                                                    }
439                                                }
440                                                Ok(None) => return,
441                                                Err(e) => {
442                                                    let _ = tx.send(Err(e));
443                                                    return;
444                                                }
445                                            }
446                                        }
447                                    }
448                                }
449                            }
450                            #[cfg(not(feature = "unified_sse"))]
451                            {
452                                while let Some(boundary) = find_event_boundary(&buffer) {
453                                    let event_bytes = buffer.drain(..boundary).collect::<Vec<_>>();
454                                    if let Ok(event_text) = std::str::from_utf8(&event_bytes) {
455                                        if let Some(parsed) = parse_sse_event(event_text) {
456                                            match parsed {
457                                                Ok(Some(chunk)) => {
458                                                    if tx.send(Ok(chunk)).is_err() {
459                                                        return;
460                                                    }
461                                                }
462                                                Ok(None) => return,
463                                                Err(e) => {
464                                                    let _ = tx.send(Err(e));
465                                                    return;
466                                                }
467                                            }
468                                        }
469                                    }
470                                }
471                            }
472                        }
473                        Err(e) => {
474                            let _ = tx.send(Err(AiLibError::ProviderError(format!(
475                                "Stream error: {}",
476                                e
477                            ))));
478                            break;
479                        }
480                    }
481                }
482            });
483            let stream = UnboundedReceiverStream::new(rx);
484            return Ok(Box::new(Box::pin(stream)));
485        }
486
487        // fallback: call chat_completion and stream chunks
488        let finished = self.chat_completion(request).await?;
489        let text = finished
490            .choices
491            .first()
492            .map(|c| c.message.content.as_text())
493            .unwrap_or_default();
494        let (tx, rx) = mpsc::unbounded_channel();
495        tokio::spawn(async move {
496            let chunks = split_text_into_chunks(&text, 80);
497            for chunk in chunks {
498                let delta = ChoiceDelta {
499                    index: 0,
500                    delta: MessageDelta {
501                        role: Some(Role::Assistant),
502                        content: Some(chunk.clone()),
503                    },
504                    finish_reason: None,
505                };
506                let chunk_obj = ChatCompletionChunk {
507                    id: "simulated".to_string(),
508                    object: "chat.completion.chunk".to_string(),
509                    created: 0,
510                    model: finished.model.clone(),
511                    choices: vec![delta],
512                };
513                if tx.send(Ok(chunk_obj)).is_err() {
514                    return;
515                }
516                tokio::time::sleep(std::time::Duration::from_millis(50)).await;
517            }
518        });
519        let stream = UnboundedReceiverStream::new(rx);
520        Ok(Box::new(Box::pin(stream)))
521    }
522
523    async fn list_models(&self) -> Result<Vec<String>, AiLibError> {
524        // Mistral models endpoint
525        let url = format!("{}/v1/models", self.base_url);
526        let mut headers = HashMap::new();
527        if let Some(key) = &self.api_key {
528            headers.insert("Authorization".to_string(), format!("Bearer {}", key));
529        }
530        let response = self.transport.get_json(&url, Some(headers)).await?;
531        Ok(response["data"]
532            .as_array()
533            .unwrap_or(&vec![])
534            .iter()
535            .filter_map(|m| m["id"].as_str().map(|s| s.to_string()))
536            .collect())
537    }
538
539    async fn get_model_info(&self, model_id: &str) -> Result<crate::api::ModelInfo, AiLibError> {
540        Ok(ModelInfo {
541            id: model_id.to_string(),
542            object: "model".to_string(),
543            created: 0,
544            owned_by: "mistral".to_string(),
545            permission: vec![ModelPermission {
546                id: "default".to_string(),
547                object: "model_permission".to_string(),
548                created: 0,
549                allow_create_engine: false,
550                allow_sampling: true,
551                allow_logprobs: false,
552                allow_search_indices: false,
553                allow_view: true,
554                allow_fine_tuning: false,
555                organization: "*".to_string(),
556                group: None,
557                is_blocking: false,
558            }],
559        })
560    }
561}