ai_lib/provider/
mistral.rs

1use crate::api::{
2    ChatApi, ChatCompletionChunk, ChoiceDelta, MessageDelta, ModelInfo, ModelPermission,
3};
4use crate::metrics::{Metrics, NoopMetrics};
5use crate::transport::{DynHttpTransportRef, HttpTransport};
6use crate::types::{
7    AiLibError, ChatCompletionRequest, ChatCompletionResponse, Choice, Message, Role, Usage,
8};
9use futures::stream::Stream;
10use futures::StreamExt;
11use std::collections::HashMap;
12use std::sync::Arc;
13use tokio::sync::mpsc;
14use tokio_stream::wrappers::UnboundedReceiverStream;
15
/// Mistral adapter (conservative HTTP implementation).
///
/// Note: Mistral provides an official Rust SDK (https://github.com/ivangabriele/mistralai-client-rs).
/// We keep this implementation HTTP-based for now and can swap to the SDK later.
pub struct MistralAdapter {
    // Object-safe HTTP transport used for non-streaming calls (POST/GET JSON).
    transport: DynHttpTransportRef,
    // Bearer token read from `MISTRAL_API_KEY`; requests go out without an
    // Authorization header when absent.
    api_key: Option<String>,
    // API root, e.g. `https://api.mistral.ai` (overridable via `MISTRAL_BASE_URL`).
    base_url: String,
    // Metrics sink; the plain constructors default this to `NoopMetrics`.
    metrics: Arc<dyn Metrics>,
}
26
27impl MistralAdapter {
28    pub fn new() -> Result<Self, AiLibError> {
29        let api_key = std::env::var("MISTRAL_API_KEY").ok();
30        let base_url = std::env::var("MISTRAL_BASE_URL")
31            .unwrap_or_else(|_| "https://api.mistral.ai".to_string());
32        let boxed = HttpTransport::new().boxed();
33        Ok(Self {
34            transport: boxed,
35            api_key,
36            base_url,
37            metrics: Arc::new(NoopMetrics::new()),
38        })
39    }
40
41    /// Construct using an injected object-safe transport reference (for testing/SDKs)
42    pub fn with_transport(
43        transport: DynHttpTransportRef,
44        api_key: Option<String>,
45        base_url: String,
46    ) -> Result<Self, AiLibError> {
47        Ok(Self {
48            transport,
49            api_key,
50            base_url,
51            metrics: Arc::new(NoopMetrics::new()),
52        })
53    }
54
55    /// Construct with an injected transport and metrics implementation
56    pub fn with_transport_and_metrics(
57        transport: DynHttpTransportRef,
58        api_key: Option<String>,
59        base_url: String,
60        metrics: Arc<dyn Metrics>,
61    ) -> Result<Self, AiLibError> {
62        Ok(Self {
63            transport,
64            api_key,
65            base_url,
66            metrics,
67        })
68    }
69
70    fn convert_request(&self, request: &ChatCompletionRequest) -> serde_json::Value {
71        let msgs: Vec<serde_json::Value> = request.messages.iter().map(|msg| {
72            serde_json::json!({
73                "role": match msg.role { Role::System => "system", Role::User => "user", Role::Assistant => "assistant" },
74                "content": msg.content.as_text()
75            })
76        }).collect();
77
78        let mut body = serde_json::json!({ "model": request.model, "messages": msgs });
79        if let Some(temp) = request.temperature {
80            body["temperature"] =
81                serde_json::Value::Number(serde_json::Number::from_f64(temp.into()).unwrap());
82        }
83        if let Some(max_tokens) = request.max_tokens {
84            body["max_tokens"] = serde_json::Value::Number(serde_json::Number::from(max_tokens));
85        }
86        body
87    }
88
89    fn parse_response(
90        &self,
91        response: serde_json::Value,
92    ) -> Result<ChatCompletionResponse, AiLibError> {
93        let choices = response["choices"]
94            .as_array()
95            .unwrap_or(&vec![])
96            .iter()
97            .enumerate()
98            .map(|(index, choice)| {
99                let message = choice["message"].as_object().ok_or_else(|| {
100                    AiLibError::ProviderError("Invalid choice format".to_string())
101                })?;
102                let role = match message["role"].as_str().unwrap_or("user") {
103                    "system" => Role::System,
104                    "assistant" => Role::Assistant,
105                    _ => Role::User,
106                };
107                let content = message["content"].as_str().unwrap_or("").to_string();
108
109                // try to parse function_call if present
110                let mut function_call: Option<crate::types::function_call::FunctionCall> = None;
111                if let Some(fc_val) = message.get("function_call") {
112                    if let Ok(fc) = serde_json::from_value::<
113                        crate::types::function_call::FunctionCall,
114                    >(fc_val.clone())
115                    {
116                        function_call = Some(fc);
117                    } else if let Some(name) = fc_val
118                        .get("name")
119                        .and_then(|v| v.as_str())
120                        .map(|s| s.to_string())
121                    {
122                        let args = fc_val.get("arguments").and_then(|a| {
123                            if a.is_string() {
124                                serde_json::from_str::<serde_json::Value>(a.as_str().unwrap()).ok()
125                            } else {
126                                Some(a.clone())
127                            }
128                        });
129                        function_call = Some(crate::types::function_call::FunctionCall {
130                            name,
131                            arguments: args,
132                        });
133                    }
134                }
135
136                Ok(Choice {
137                    index: index as u32,
138                    message: Message {
139                        role,
140                        content: crate::types::common::Content::Text(content),
141                        function_call,
142                    },
143                    finish_reason: choice["finish_reason"].as_str().map(|s| s.to_string()),
144                })
145            })
146            .collect::<Result<Vec<_>, AiLibError>>()?;
147
148        let usage = response["usage"].as_object().ok_or_else(|| {
149            AiLibError::ProviderError("Invalid response format: usage not found".to_string())
150        })?;
151        let usage = Usage {
152            prompt_tokens: usage["prompt_tokens"].as_u64().unwrap_or(0) as u32,
153            completion_tokens: usage["completion_tokens"].as_u64().unwrap_or(0) as u32,
154            total_tokens: usage["total_tokens"].as_u64().unwrap_or(0) as u32,
155        };
156
157        Ok(ChatCompletionResponse {
158            id: response["id"].as_str().unwrap_or_default().to_string(),
159            object: response["object"].as_str().unwrap_or_default().to_string(),
160            created: response["created"].as_u64().unwrap_or(0),
161            model: response["model"].as_str().unwrap_or_default().to_string(),
162            choices,
163            usage,
164        })
165    }
166}
167
/// Locate the end of the first complete SSE event in `buffer`.
///
/// An event is terminated by a blank line: either `\n\n` or `\r\n\r\n`.
/// Returns the byte offset just past the terminator, or `None` when no
/// complete event is buffered yet.
fn find_event_boundary(buffer: &[u8]) -> Option<usize> {
    for (pos, pair) in buffer.windows(2).enumerate() {
        if pair == b"\n\n" {
            return Some(pos + 2);
        }
        if buffer[pos..].starts_with(b"\r\n\r\n") {
            return Some(pos + 4);
        }
    }
    None
}
186
187fn parse_sse_event(event_text: &str) -> Option<Result<Option<ChatCompletionChunk>, AiLibError>> {
188    for line in event_text.lines() {
189        let line = line.trim();
190        if let Some(stripped) = line.strip_prefix("data: ") {
191            let data = stripped;
192            if data == "[DONE]" {
193                return Some(Ok(None));
194            }
195            return Some(parse_chunk_data(data));
196        }
197    }
198    None
199}
200
201fn parse_chunk_data(data: &str) -> Result<Option<ChatCompletionChunk>, AiLibError> {
202    let json: serde_json::Value = serde_json::from_str(data)
203        .map_err(|e| AiLibError::ProviderError(format!("JSON parse error: {}", e)))?;
204    let mut choices_vec: Vec<ChoiceDelta> = Vec::new();
205    if let Some(arr) = json["choices"].as_array() {
206        for (index, choice) in arr.iter().enumerate() {
207            let delta = &choice["delta"];
208            let role = delta.get("role").and_then(|v| v.as_str()).map(|r| match r {
209                "assistant" => Role::Assistant,
210                "user" => Role::User,
211                "system" => Role::System,
212                _ => Role::Assistant,
213            });
214            let content = delta
215                .get("content")
216                .and_then(|v| v.as_str())
217                .map(|s| s.to_string());
218            let md = MessageDelta { role, content };
219            let cd = ChoiceDelta {
220                index: index as u32,
221                delta: md,
222                finish_reason: choice
223                    .get("finish_reason")
224                    .and_then(|v| v.as_str())
225                    .map(|s| s.to_string()),
226            };
227            choices_vec.push(cd);
228        }
229    }
230
231    Ok(Some(ChatCompletionChunk {
232        id: json["id"].as_str().unwrap_or_default().to_string(),
233        object: json["object"]
234            .as_str()
235            .unwrap_or("chat.completion.chunk")
236            .to_string(),
237        created: json["created"].as_u64().unwrap_or(0),
238        model: json["model"].as_str().unwrap_or_default().to_string(),
239        choices: choices_vec,
240    }))
241}
242
/// Split `text` into chunks of at most `max_len` bytes, preferring to break
/// at the last space inside each window. A separating space at a cut point
/// is dropped.
///
/// All cuts land on UTF-8 character boundaries, so multi-byte characters are
/// never split or corrupted (the previous byte-index slicing could panic on
/// non-ASCII input). A `max_len` of 0 is treated as 1 so the loop always
/// makes progress.
fn split_text_into_chunks(text: &str, max_len: usize) -> Vec<String> {
    let max_len = max_len.max(1);
    let mut chunks = Vec::new();
    let mut start = 0;
    while start < text.len() {
        let mut end = std::cmp::min(start + max_len, text.len());
        // Back up to a char boundary so slicing cannot panic mid-character.
        while !text.is_char_boundary(end) {
            end -= 1;
        }
        if end <= start {
            // Window smaller than one character: take the whole next char.
            end = text[start..]
                .chars()
                .next()
                .map_or(text.len(), |c| start + c.len_utf8());
        }
        let mut cut = end;
        if end < text.len() {
            // Prefer breaking at the last space within the window.
            if let Some(pos) = text[start..end].rfind(' ') {
                cut = start + pos;
            }
        }
        if cut == start {
            cut = end;
        }
        chunks.push(text[start..cut].to_string());
        start = cut;
        // Skip the separating space the cut landed on.
        if start < text.len() && text.as_bytes()[start] == b' ' {
            start += 1;
        }
    }
    chunks
}
267
268#[async_trait::async_trait]
269impl ChatApi for MistralAdapter {
270    async fn chat_completion(
271        &self,
272        request: ChatCompletionRequest,
273    ) -> Result<ChatCompletionResponse, AiLibError> {
274    self.metrics.incr_counter("mistral.requests", 1).await;
275    let timer = self.metrics.start_timer("mistral.request_duration_ms").await;
276
277        let provider_request = self.convert_request(&request);
278        let url = format!("{}{}", self.base_url, "/v1/chat/completions");
279
280        let mut headers = HashMap::new();
281        headers.insert("Content-Type".to_string(), "application/json".to_string());
282        if let Some(key) = &self.api_key {
283            headers.insert("Authorization".to_string(), format!("Bearer {}", key));
284        }
285
286        let response = match self
287            .transport
288            .post_json(&url, Some(headers), provider_request)
289            .await
290        {
291            Ok(v) => {
292                if let Some(t) = timer {
293                    t.stop();
294                }
295                v
296            }
297            Err(e) => {
298                if let Some(t) = timer {
299                    t.stop();
300                }
301                return Err(e);
302            }
303        };
304
305        self.parse_response(response)
306    }
307
308    async fn chat_completion_stream(
309        &self,
310        request: ChatCompletionRequest,
311    ) -> Result<
312        Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
313        AiLibError,
314    > {
315        let mut stream_request = self.convert_request(&request);
316        stream_request["stream"] = serde_json::Value::Bool(true);
317
318        let url = format!("{}{}", self.base_url, "/v1/chat/completions");
319
320        // build client honoring proxy
321        let mut client_builder = reqwest::Client::builder();
322        if let Ok(proxy_url) = std::env::var("AI_PROXY_URL") {
323            if let Ok(proxy) = reqwest::Proxy::all(&proxy_url) {
324                client_builder = client_builder.proxy(proxy);
325            }
326        }
327        let client = client_builder
328            .build()
329            .map_err(|e| AiLibError::ProviderError(format!("Client error: {}", e)))?;
330
331        let mut headers = HashMap::new();
332        headers.insert("Accept".to_string(), "text/event-stream".to_string());
333        if let Some(key) = &self.api_key {
334            headers.insert("Authorization".to_string(), format!("Bearer {}", key));
335        }
336
337        let response = client.post(&url).json(&stream_request);
338        let mut req = response;
339        for (k, v) in headers.clone() {
340            req = req.header(k, v);
341        }
342
343        let send_result = req.send().await;
344        if let Ok(response) = send_result {
345            if response.status().is_success() {
346                let (tx, rx) = mpsc::unbounded_channel();
347
348                tokio::spawn(async move {
349                    let mut buffer = Vec::new();
350                    let mut stream = response.bytes_stream();
351
352                    while let Some(item) = stream.next().await {
353                        match item {
354                            Ok(bytes) => {
355                                buffer.extend_from_slice(&bytes);
356
357                                // process complete events separated by double newlines
358                                while let Some(boundary) = find_event_boundary(&buffer) {
359                                    let event_bytes = buffer.drain(..boundary).collect::<Vec<_>>();
360                                    if let Ok(event_text) = std::str::from_utf8(&event_bytes) {
361                                        if let Some(parsed) = parse_sse_event(event_text) {
362                                            match parsed {
363                                                Ok(Some(chunk)) => {
364                                                    if tx.send(Ok(chunk)).is_err() {
365                                                        return;
366                                                    }
367                                                }
368                                                Ok(None) => return, // [DONE]
369                                                Err(e) => {
370                                                    let _ = tx.send(Err(e));
371                                                    return;
372                                                }
373                                            }
374                                        }
375                                    }
376                                }
377                            }
378                            Err(e) => {
379                                let _ = tx.send(Err(AiLibError::ProviderError(format!(
380                                    "Stream error: {}",
381                                    e
382                                ))));
383                                break;
384                            }
385                        }
386                    }
387                });
388                let stream = UnboundedReceiverStream::new(rx);
389                return Ok(Box::new(Box::pin(stream)));
390            }
391        }
392
393        // fallback: call chat_completion and stream chunks
394        let finished = self.chat_completion(request).await?;
395        let text = finished
396            .choices
397            .first()
398            .map(|c| c.message.content.as_text())
399            .unwrap_or_default();
400        let (tx, rx) = mpsc::unbounded_channel();
401        tokio::spawn(async move {
402            let chunks = split_text_into_chunks(&text, 80);
403            for chunk in chunks {
404                let delta = ChoiceDelta {
405                    index: 0,
406                    delta: MessageDelta {
407                        role: Some(Role::Assistant),
408                        content: Some(chunk.clone()),
409                    },
410                    finish_reason: None,
411                };
412                let chunk_obj = ChatCompletionChunk {
413                    id: "simulated".to_string(),
414                    object: "chat.completion.chunk".to_string(),
415                    created: 0,
416                    model: finished.model.clone(),
417                    choices: vec![delta],
418                };
419                if tx.send(Ok(chunk_obj)).is_err() {
420                    return;
421                }
422                tokio::time::sleep(std::time::Duration::from_millis(50)).await;
423            }
424        });
425        let stream = UnboundedReceiverStream::new(rx);
426        Ok(Box::new(Box::pin(stream)))
427    }
428
429    async fn list_models(&self) -> Result<Vec<String>, AiLibError> {
430        // Mistral models endpoint
431        let url = format!("{}/v1/models", self.base_url);
432        let mut headers = HashMap::new();
433        if let Some(key) = &self.api_key {
434            headers.insert("Authorization".to_string(), format!("Bearer {}", key));
435        }
436        let response: serde_json::Value = self.transport.get_json(&url, Some(headers)).await?;
437        Ok(response["data"]
438            .as_array()
439            .unwrap_or(&vec![])
440            .iter()
441            .filter_map(|m| m["id"].as_str().map(|s| s.to_string()))
442            .collect())
443    }
444
445    async fn get_model_info(&self, model_id: &str) -> Result<crate::api::ModelInfo, AiLibError> {
446        Ok(ModelInfo {
447            id: model_id.to_string(),
448            object: "model".to_string(),
449            created: 0,
450            owned_by: "mistral".to_string(),
451            permission: vec![ModelPermission {
452                id: "default".to_string(),
453                object: "model_permission".to_string(),
454                created: 0,
455                allow_create_engine: false,
456                allow_sampling: true,
457                allow_logprobs: false,
458                allow_search_indices: false,
459                allow_view: true,
460                allow_fine_tuning: false,
461                organization: "*".to_string(),
462                group: None,
463                is_blocking: false,
464            }],
465        })
466    }
467}