// ai_lib/provider/mistral.rs

1use crate::api::{ChatApi, ChatCompletionChunk, ChoiceDelta, MessageDelta, ModelInfo, ModelPermission};
2use crate::types::{ChatCompletionRequest, ChatCompletionResponse, AiLibError, Message, Role, Choice, Usage};
3use crate::transport::{HttpTransport, DynHttpTransportRef};
4use futures::stream::Stream;
5use std::collections::HashMap;
6use tokio_stream::wrappers::UnboundedReceiverStream;
7use tokio::sync::mpsc;
8use futures::StreamExt;
9
/// Mistral adapter (conservative HTTP implementation).
///
/// Note: Mistral provides an official Rust SDK (https://github.com/ivangabriele/mistralai-client-rs).
/// We keep this implementation HTTP-based for now and can swap to the SDK later.
pub struct MistralAdapter {
    // Object-safe HTTP transport used for the non-streaming endpoints
    // (injectable via `with_transport` for tests/SDK swaps).
    transport: DynHttpTransportRef,
    // Bearer token read from `MISTRAL_API_KEY`; when `None`, requests are
    // sent without an `Authorization` header.
    api_key: Option<String>,
    // API root, e.g. `https://api.mistral.ai` (overridable via `MISTRAL_BASE_URL`).
    base_url: String,
}
19
20impl MistralAdapter {
21    pub fn new() -> Result<Self, AiLibError> {
22        let api_key = std::env::var("MISTRAL_API_KEY").ok();
23        let base_url = std::env::var("MISTRAL_BASE_URL").unwrap_or_else(|_| "https://api.mistral.ai".to_string());
24        let boxed = HttpTransport::new().boxed();
25        Ok(Self { transport: boxed, api_key, base_url })
26    }
27
28    /// Construct using an injected object-safe transport reference (for testing/SDKs)
29    pub fn with_transport(transport: DynHttpTransportRef, api_key: Option<String>, base_url: String) -> Result<Self, AiLibError> {
30        Ok(Self { transport, api_key, base_url })
31    }
32
33    fn convert_request(&self, request: &ChatCompletionRequest) -> serde_json::Value {
34        let msgs: Vec<serde_json::Value> = request.messages.iter().map(|msg| {
35            serde_json::json!({
36                "role": match msg.role { Role::System => "system", Role::User => "user", Role::Assistant => "assistant" },
37                "content": msg.content
38            })
39        }).collect();
40
41        let mut body = serde_json::json!({ "model": request.model, "messages": msgs });
42        if let Some(temp) = request.temperature { body["temperature"] = serde_json::Value::Number(serde_json::Number::from_f64(temp.into()).unwrap()); }
43        if let Some(max_tokens) = request.max_tokens { body["max_tokens"] = serde_json::Value::Number(serde_json::Number::from(max_tokens)); }
44        body
45    }
46
47    fn parse_response(&self, response: serde_json::Value) -> Result<ChatCompletionResponse, AiLibError> {
48        let choices = response["choices"].as_array()
49            .unwrap_or(&vec![])
50            .iter()
51            .enumerate()
52            .map(|(index, choice)| {
53                let message = choice["message"].as_object().ok_or_else(|| AiLibError::ProviderError("Invalid choice format".to_string()))?;
54                let role = match message["role"].as_str().unwrap_or("user") { "system" => Role::System, "assistant" => Role::Assistant, _ => Role::User };
55                let content = message["content"].as_str().unwrap_or("").to_string();
56                Ok(Choice { index: index as u32, message: Message { role, content }, finish_reason: choice["finish_reason"].as_str().map(|s| s.to_string()) })
57            })
58            .collect::<Result<Vec<_>, AiLibError>>()?;
59
60        let usage = response["usage"].as_object().ok_or_else(|| AiLibError::ProviderError("Invalid response format: usage not found".to_string()))?;
61        let usage = Usage { prompt_tokens: usage["prompt_tokens"].as_u64().unwrap_or(0) as u32, completion_tokens: usage["completion_tokens"].as_u64().unwrap_or(0) as u32, total_tokens: usage["total_tokens"].as_u64().unwrap_or(0) as u32 };
62
63        Ok(ChatCompletionResponse { id: response["id"].as_str().unwrap_or_default().to_string(), object: response["object"].as_str().unwrap_or_default().to_string(), created: response["created"].as_u64().unwrap_or(0), model: response["model"].as_str().unwrap_or_default().to_string(), choices, usage })
64    }
65}
66
/// Locate the end of the next complete SSE event in `buffer`.
///
/// Returns the byte offset just past the first blank line (`\n\n` or
/// `\r\n\r\n`) that terminates an event, or `None` when no complete event
/// has been buffered yet.
fn find_event_boundary(buffer: &[u8]) -> Option<usize> {
    (0..buffer.len().saturating_sub(1)).find_map(|i| {
        let rest = &buffer[i..];
        if rest.starts_with(b"\n\n") {
            Some(i + 2)
        } else if rest.starts_with(b"\r\n\r\n") {
            Some(i + 4)
        } else {
            None
        }
    })
}
76
77fn parse_sse_event(event_text: &str) -> Option<Result<Option<ChatCompletionChunk>, AiLibError>> {
78    for line in event_text.lines() {
79        let line = line.trim();
80        if line.starts_with("data: ") {
81            let data = &line[6..];
82            if data == "[DONE]" { return Some(Ok(None)); }
83            return Some(parse_chunk_data(data));
84        }
85    }
86    None
87}
88
89fn parse_chunk_data(data: &str) -> Result<Option<ChatCompletionChunk>, AiLibError> {
90    let json: serde_json::Value = serde_json::from_str(data).map_err(|e| AiLibError::ProviderError(format!("JSON parse error: {}", e)))?;
91    let mut choices_vec: Vec<ChoiceDelta> = Vec::new();
92    if let Some(arr) = json["choices"].as_array() {
93        for (index, choice) in arr.iter().enumerate() {
94            let delta = &choice["delta"];
95            let role = delta.get("role").and_then(|v| v.as_str()).map(|r| match r { "assistant" => Role::Assistant, "user" => Role::User, "system" => Role::System, _ => Role::Assistant });
96            let content = delta.get("content").and_then(|v| v.as_str()).map(|s| s.to_string());
97            let md = MessageDelta { role, content };
98            let cd = ChoiceDelta { index: index as u32, delta: md, finish_reason: choice.get("finish_reason").and_then(|v| v.as_str()).map(|s| s.to_string()) };
99            choices_vec.push(cd);
100        }
101    }
102
103    Ok(Some(ChatCompletionChunk {
104        id: json["id"].as_str().unwrap_or_default().to_string(),
105        object: json["object"].as_str().unwrap_or("chat.completion.chunk").to_string(),
106        created: json["created"].as_u64().unwrap_or(0),
107        model: json["model"].as_str().unwrap_or_default().to_string(),
108        choices: choices_vec,
109    }))
110}
111
/// Split `text` into chunks of at most `max_len` bytes, preferring to break
/// at the last space inside the window; the space broken on is dropped.
///
/// Fixes two defects in the byte-based original:
/// * slicing `text[start..end]` panicked when `end` landed inside a
///   multi-byte UTF-8 character (the old code computed `end` in raw bytes);
/// * `max_len == 0` looped forever pushing empty strings.
fn split_text_into_chunks(text: &str, max_len: usize) -> Vec<String> {
    let mut chunks = Vec::new();
    if max_len == 0 {
        // Degenerate request: return the whole text as one chunk instead of looping.
        if !text.is_empty() {
            chunks.push(text.to_string());
        }
        return chunks;
    }

    let mut start = 0;
    while start < text.len() {
        // Candidate end, snapped back onto a char boundary so slicing is safe.
        let mut end = std::cmp::min(start + max_len, text.len());
        while !text.is_char_boundary(end) {
            end -= 1;
        }
        if end == start {
            // A single character wider than max_len: emit it whole to guarantee progress.
            end = text[start..]
                .chars()
                .next()
                .map_or(text.len(), |c| start + c.len_utf8());
        }

        // Prefer breaking at the last space inside the window, unless it is
        // the very first byte (which would produce an empty chunk).
        let mut cut = end;
        if end < text.len() {
            if let Some(pos) = text[start..end].rfind(' ') {
                if pos > 0 {
                    cut = start + pos;
                }
            }
        }

        chunks.push(text[start..cut].to_string());
        start = cut;
        // Skip the space we broke on so the next chunk doesn't start with it.
        if text.as_bytes().get(start) == Some(&b' ') {
            start += 1;
        }
    }
    chunks
}
130
131#[async_trait::async_trait]
132impl ChatApi for MistralAdapter {
133    async fn chat_completion(&self, request: ChatCompletionRequest) -> Result<ChatCompletionResponse, AiLibError> {
134        let provider_request = self.convert_request(&request);
135        let url = format!("{}{}", self.base_url, "/v1/chat/completions");
136
137        let mut headers = HashMap::new();
138        headers.insert("Content-Type".to_string(), "application/json".to_string());
139        if let Some(key) = &self.api_key { headers.insert("Authorization".to_string(), format!("Bearer {}", key)); }
140
141    let response: serde_json::Value = self.transport.post_json(&url, Some(headers), provider_request).await?;
142        self.parse_response(response)
143    }
144
145    async fn chat_completion_stream(&self, request: ChatCompletionRequest) -> Result<Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>, AiLibError> {
146        let mut stream_request = self.convert_request(&request);
147        stream_request["stream"] = serde_json::Value::Bool(true);
148
149        let url = format!("{}{}", self.base_url, "/v1/chat/completions");
150
151        // build client honoring proxy
152        let mut client_builder = reqwest::Client::builder();
153        if let Ok(proxy_url) = std::env::var("AI_PROXY_URL") {
154            if let Ok(proxy) = reqwest::Proxy::all(&proxy_url) { client_builder = client_builder.proxy(proxy); }
155        }
156        let client = client_builder.build().map_err(|e| AiLibError::ProviderError(format!("Client error: {}", e)))?;
157
158        let mut headers = HashMap::new();
159        headers.insert("Accept".to_string(), "text/event-stream".to_string());
160        if let Some(key) = &self.api_key { headers.insert("Authorization".to_string(), format!("Bearer {}", key)); }
161
162        let response = client.post(&url).json(&stream_request);
163        let mut req = response;
164        for (k, v) in headers.clone() { req = req.header(k, v); }
165
166        let send_result = req.send().await;
167        match send_result {
168            Ok(response) => {
169                if response.status().is_success() {
170                    let (tx, rx) = mpsc::unbounded_channel();
171                    tokio::spawn(async move {
172                        let mut buffer = Vec::new();
173                        let mut stream = response.bytes_stream();
174                        while let Some(item) = stream.next().await {
175                            match item {
176                                Ok(bytes) => {
177                                    buffer.extend_from_slice(&bytes);
178                                    while let Some(boundary) = find_event_boundary(&buffer) {
179                                        let event_bytes = buffer.drain(..boundary).collect::<Vec<_>>();
180                                        if let Ok(event_text) = std::str::from_utf8(&event_bytes) {
181                                            if let Some(parsed) = parse_sse_event(event_text) {
182                                                match parsed {
183                                                    Ok(Some(chunk)) => { if tx.send(Ok(chunk)).is_err() { return; } }
184                                                    Ok(None) => return,
185                                                    Err(e) => { let _ = tx.send(Err(e)); return; }
186                                                }
187                                            }
188                                        }
189                                    }
190                                }
191                                Err(e) => { let _ = tx.send(Err(AiLibError::ProviderError(format!("Stream error: {}", e)))); break; }
192                            }
193                        }
194                    });
195                    let stream = UnboundedReceiverStream::new(rx);
196                    return Ok(Box::new(Box::pin(stream)));
197                }
198            }
199            Err(_) => {}
200        }
201
202        // fallback: call chat_completion and stream chunks
203        let finished = self.chat_completion(request).await?;
204        let text = finished.choices.get(0).map(|c| c.message.content.clone()).unwrap_or_default();
205        let (tx, rx) = mpsc::unbounded_channel();
206        tokio::spawn(async move {
207            let chunks = split_text_into_chunks(&text, 80);
208            for chunk in chunks {
209                let delta = ChoiceDelta { index: 0, delta: MessageDelta { role: Some(Role::Assistant), content: Some(chunk.clone()) }, finish_reason: None };
210                let chunk_obj = ChatCompletionChunk { id: "simulated".to_string(), object: "chat.completion.chunk".to_string(), created: 0, model: finished.model.clone(), choices: vec![delta] };
211                if tx.send(Ok(chunk_obj)).is_err() { return; }
212                tokio::time::sleep(std::time::Duration::from_millis(50)).await;
213            }
214        });
215        let stream = UnboundedReceiverStream::new(rx);
216        Ok(Box::new(Box::pin(stream)))
217    }
218
219    async fn list_models(&self) -> Result<Vec<String>, AiLibError> {
220        let url = format!("{}/v1/models", self.base_url);
221        let mut headers = HashMap::new();
222        if let Some(key) = &self.api_key { headers.insert("Authorization".to_string(), format!("Bearer {}", key)); }
223    let response: serde_json::Value = self.transport.get_json(&url, Some(headers)).await?;
224        Ok(response["data"].as_array().unwrap_or(&vec![]).iter().filter_map(|m| m["id"].as_str().map(|s| s.to_string())).collect())
225    }
226
227    async fn get_model_info(&self, model_id: &str) -> Result<crate::api::ModelInfo, AiLibError> {
228        Ok(ModelInfo { id: model_id.to_string(), object: "model".to_string(), created: 0, owned_by: "mistral".to_string(), permission: vec![ModelPermission { id: "default".to_string(), object: "model_permission".to_string(), created: 0, allow_create_engine: false, allow_sampling: true, allow_logprobs: false, allow_search_indices: false, allow_view: true, allow_fine_tuning: false, organization: "*".to_string(), group: None, is_blocking: false }] })
229    }
230}