1use crate::api::{
2 ChatApi, ChatCompletionChunk, ChoiceDelta, MessageDelta, ModelInfo, ModelPermission,
3};
4use crate::metrics::{Metrics, NoopMetrics};
5use crate::transport::{DynHttpTransportRef, HttpTransport};
6use crate::types::{
7 AiLibError, ChatCompletionRequest, ChatCompletionResponse, Choice, Message, Role, Usage,
8};
9use futures::stream::Stream;
10use futures::StreamExt;
11use std::collections::HashMap;
12use std::sync::Arc;
13use tokio::sync::mpsc;
14use tokio_stream::wrappers::UnboundedReceiverStream;
15
/// Adapter for the Mistral AI chat API (OpenAI-compatible endpoints).
pub struct MistralAdapter {
    /// HTTP transport used for non-streaming request/response calls.
    transport: DynHttpTransportRef,
    /// Optional API key; when present it is sent as a `Bearer` token.
    api_key: Option<String>,
    /// API base URL, e.g. `https://api.mistral.ai`.
    base_url: String,
    /// Metrics sink for counters/timers; defaults to `NoopMetrics`.
    metrics: Arc<dyn Metrics>,
}
26
27impl MistralAdapter {
28 pub fn new() -> Result<Self, AiLibError> {
29 let api_key = std::env::var("MISTRAL_API_KEY").ok();
30 let base_url = std::env::var("MISTRAL_BASE_URL")
31 .unwrap_or_else(|_| "https://api.mistral.ai".to_string());
32 let boxed = HttpTransport::new().boxed();
33 Ok(Self {
34 transport: boxed,
35 api_key,
36 base_url,
37 metrics: Arc::new(NoopMetrics::new()),
38 })
39 }
40
41 pub fn with_transport(
43 transport: DynHttpTransportRef,
44 api_key: Option<String>,
45 base_url: String,
46 ) -> Result<Self, AiLibError> {
47 Ok(Self {
48 transport,
49 api_key,
50 base_url,
51 metrics: Arc::new(NoopMetrics::new()),
52 })
53 }
54
55 pub fn with_transport_and_metrics(
57 transport: DynHttpTransportRef,
58 api_key: Option<String>,
59 base_url: String,
60 metrics: Arc<dyn Metrics>,
61 ) -> Result<Self, AiLibError> {
62 Ok(Self {
63 transport,
64 api_key,
65 base_url,
66 metrics,
67 })
68 }
69
70 fn convert_request(&self, request: &ChatCompletionRequest) -> serde_json::Value {
71 let msgs: Vec<serde_json::Value> = request.messages.iter().map(|msg| {
72 serde_json::json!({
73 "role": match msg.role { Role::System => "system", Role::User => "user", Role::Assistant => "assistant" },
74 "content": msg.content.as_text()
75 })
76 }).collect();
77
78 let mut body = serde_json::json!({ "model": request.model, "messages": msgs });
79 if let Some(temp) = request.temperature {
80 body["temperature"] =
81 serde_json::Value::Number(serde_json::Number::from_f64(temp.into()).unwrap());
82 }
83 if let Some(max_tokens) = request.max_tokens {
84 body["max_tokens"] = serde_json::Value::Number(serde_json::Number::from(max_tokens));
85 }
86 body
87 }
88
89 fn parse_response(
90 &self,
91 response: serde_json::Value,
92 ) -> Result<ChatCompletionResponse, AiLibError> {
93 let choices = response["choices"]
94 .as_array()
95 .unwrap_or(&vec![])
96 .iter()
97 .enumerate()
98 .map(|(index, choice)| {
99 let message = choice["message"].as_object().ok_or_else(|| {
100 AiLibError::ProviderError("Invalid choice format".to_string())
101 })?;
102 let role = match message["role"].as_str().unwrap_or("user") {
103 "system" => Role::System,
104 "assistant" => Role::Assistant,
105 _ => Role::User,
106 };
107 let content = message["content"].as_str().unwrap_or("").to_string();
108
109 let mut function_call: Option<crate::types::function_call::FunctionCall> = None;
111 if let Some(fc_val) = message.get("function_call") {
112 if let Ok(fc) = serde_json::from_value::<
113 crate::types::function_call::FunctionCall,
114 >(fc_val.clone())
115 {
116 function_call = Some(fc);
117 } else if let Some(name) = fc_val
118 .get("name")
119 .and_then(|v| v.as_str())
120 .map(|s| s.to_string())
121 {
122 let args = fc_val.get("arguments").and_then(|a| {
123 if a.is_string() {
124 serde_json::from_str::<serde_json::Value>(a.as_str().unwrap()).ok()
125 } else {
126 Some(a.clone())
127 }
128 });
129 function_call = Some(crate::types::function_call::FunctionCall {
130 name,
131 arguments: args,
132 });
133 }
134 }
135
136 Ok(Choice {
137 index: index as u32,
138 message: Message {
139 role,
140 content: crate::types::common::Content::Text(content),
141 function_call,
142 },
143 finish_reason: choice["finish_reason"].as_str().map(|s| s.to_string()),
144 })
145 })
146 .collect::<Result<Vec<_>, AiLibError>>()?;
147
148 let usage = response["usage"].as_object().ok_or_else(|| {
149 AiLibError::ProviderError("Invalid response format: usage not found".to_string())
150 })?;
151 let usage = Usage {
152 prompt_tokens: usage["prompt_tokens"].as_u64().unwrap_or(0) as u32,
153 completion_tokens: usage["completion_tokens"].as_u64().unwrap_or(0) as u32,
154 total_tokens: usage["total_tokens"].as_u64().unwrap_or(0) as u32,
155 };
156
157 Ok(ChatCompletionResponse {
158 id: response["id"].as_str().unwrap_or_default().to_string(),
159 object: response["object"].as_str().unwrap_or_default().to_string(),
160 created: response["created"].as_u64().unwrap_or(0),
161 model: response["model"].as_str().unwrap_or_default().to_string(),
162 choices,
163 usage,
164 })
165 }
166}
167
/// Locates the end of the first complete SSE event in `buffer`.
///
/// An event is terminated by a blank line — either `\n\n` or `\r\n\r\n`.
/// Returns the byte index just past the terminator, or `None` when no full
/// event has been buffered yet.
fn find_event_boundary(buffer: &[u8]) -> Option<usize> {
    for (pos, pair) in buffer.windows(2).enumerate() {
        if pair == b"\n\n" {
            return Some(pos + 2);
        }
        if buffer[pos..].starts_with(b"\r\n\r\n") {
            return Some(pos + 4);
        }
    }
    None
}
186
187fn parse_sse_event(event_text: &str) -> Option<Result<Option<ChatCompletionChunk>, AiLibError>> {
188 for line in event_text.lines() {
189 let line = line.trim();
190 if let Some(stripped) = line.strip_prefix("data: ") {
191 let data = stripped;
192 if data == "[DONE]" {
193 return Some(Ok(None));
194 }
195 return Some(parse_chunk_data(data));
196 }
197 }
198 None
199}
200
201fn parse_chunk_data(data: &str) -> Result<Option<ChatCompletionChunk>, AiLibError> {
202 let json: serde_json::Value = serde_json::from_str(data)
203 .map_err(|e| AiLibError::ProviderError(format!("JSON parse error: {}", e)))?;
204 let mut choices_vec: Vec<ChoiceDelta> = Vec::new();
205 if let Some(arr) = json["choices"].as_array() {
206 for (index, choice) in arr.iter().enumerate() {
207 let delta = &choice["delta"];
208 let role = delta.get("role").and_then(|v| v.as_str()).map(|r| match r {
209 "assistant" => Role::Assistant,
210 "user" => Role::User,
211 "system" => Role::System,
212 _ => Role::Assistant,
213 });
214 let content = delta
215 .get("content")
216 .and_then(|v| v.as_str())
217 .map(|s| s.to_string());
218 let md = MessageDelta { role, content };
219 let cd = ChoiceDelta {
220 index: index as u32,
221 delta: md,
222 finish_reason: choice
223 .get("finish_reason")
224 .and_then(|v| v.as_str())
225 .map(|s| s.to_string()),
226 };
227 choices_vec.push(cd);
228 }
229 }
230
231 Ok(Some(ChatCompletionChunk {
232 id: json["id"].as_str().unwrap_or_default().to_string(),
233 object: json["object"]
234 .as_str()
235 .unwrap_or("chat.completion.chunk")
236 .to_string(),
237 created: json["created"].as_u64().unwrap_or(0),
238 model: json["model"].as_str().unwrap_or_default().to_string(),
239 choices: choices_vec,
240 }))
241}
242
/// Splits `text` into chunks of at most `max_len` bytes, preferring to break
/// at a space so words stay intact where possible.
///
/// Chunk boundaries always fall on UTF-8 character boundaries, so multi-byte
/// characters are never torn apart (the previous byte-offset slicing could
/// panic mid-character or emit U+FFFD replacements). A space at a break point
/// is consumed rather than carried into the next chunk. `max_len == 0` is
/// treated as 1 to guarantee forward progress.
fn split_text_into_chunks(text: &str, max_len: usize) -> Vec<String> {
    // Guard: max_len == 0 would otherwise loop forever without advancing.
    let max_len = max_len.max(1);
    let mut chunks = Vec::new();
    let mut start = 0;
    while start < text.len() {
        // Candidate end: at most max_len bytes ahead, snapped back onto a
        // char boundary.
        let mut end = std::cmp::min(start + max_len, text.len());
        while !text.is_char_boundary(end) {
            end -= 1;
        }
        if end == start {
            // A single character is wider than max_len bytes: step forward
            // to the next boundary so we still make progress.
            end += 1;
            while !text.is_char_boundary(end) {
                end += 1;
            }
        }
        // Prefer to break at the last space inside the window; a leading
        // space (pos == 0) would make an empty chunk, so fall back to `end`.
        let mut cut = end;
        if end < text.len() {
            if let Some(pos) = text[start..end].rfind(' ') {
                cut = start + pos;
            }
        }
        if cut == start {
            cut = end;
        }
        chunks.push(text[start..cut].to_string());
        start = cut;
        // Skip the space we broke on so it doesn't lead the next chunk.
        if text[start..].starts_with(' ') {
            start += 1;
        }
    }
    chunks
}
267
268#[async_trait::async_trait]
269impl ChatApi for MistralAdapter {
270 async fn chat_completion(
271 &self,
272 request: ChatCompletionRequest,
273 ) -> Result<ChatCompletionResponse, AiLibError> {
274 self.metrics.incr_counter("mistral.requests", 1).await;
275 let timer = self.metrics.start_timer("mistral.request_duration_ms").await;
276
277 let provider_request = self.convert_request(&request);
278 let url = format!("{}{}", self.base_url, "/v1/chat/completions");
279
280 let mut headers = HashMap::new();
281 headers.insert("Content-Type".to_string(), "application/json".to_string());
282 if let Some(key) = &self.api_key {
283 headers.insert("Authorization".to_string(), format!("Bearer {}", key));
284 }
285
286 let response = match self
287 .transport
288 .post_json(&url, Some(headers), provider_request)
289 .await
290 {
291 Ok(v) => {
292 if let Some(t) = timer {
293 t.stop();
294 }
295 v
296 }
297 Err(e) => {
298 if let Some(t) = timer {
299 t.stop();
300 }
301 return Err(e);
302 }
303 };
304
305 self.parse_response(response)
306 }
307
308 async fn chat_completion_stream(
309 &self,
310 request: ChatCompletionRequest,
311 ) -> Result<
312 Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
313 AiLibError,
314 > {
315 let mut stream_request = self.convert_request(&request);
316 stream_request["stream"] = serde_json::Value::Bool(true);
317
318 let url = format!("{}{}", self.base_url, "/v1/chat/completions");
319
320 let mut client_builder = reqwest::Client::builder();
322 if let Ok(proxy_url) = std::env::var("AI_PROXY_URL") {
323 if let Ok(proxy) = reqwest::Proxy::all(&proxy_url) {
324 client_builder = client_builder.proxy(proxy);
325 }
326 }
327 let client = client_builder
328 .build()
329 .map_err(|e| AiLibError::ProviderError(format!("Client error: {}", e)))?;
330
331 let mut headers = HashMap::new();
332 headers.insert("Accept".to_string(), "text/event-stream".to_string());
333 if let Some(key) = &self.api_key {
334 headers.insert("Authorization".to_string(), format!("Bearer {}", key));
335 }
336
337 let response = client.post(&url).json(&stream_request);
338 let mut req = response;
339 for (k, v) in headers.clone() {
340 req = req.header(k, v);
341 }
342
343 let send_result = req.send().await;
344 if let Ok(response) = send_result {
345 if response.status().is_success() {
346 let (tx, rx) = mpsc::unbounded_channel();
347
348 tokio::spawn(async move {
349 let mut buffer = Vec::new();
350 let mut stream = response.bytes_stream();
351
352 while let Some(item) = stream.next().await {
353 match item {
354 Ok(bytes) => {
355 buffer.extend_from_slice(&bytes);
356
357 while let Some(boundary) = find_event_boundary(&buffer) {
359 let event_bytes = buffer.drain(..boundary).collect::<Vec<_>>();
360 if let Ok(event_text) = std::str::from_utf8(&event_bytes) {
361 if let Some(parsed) = parse_sse_event(event_text) {
362 match parsed {
363 Ok(Some(chunk)) => {
364 if tx.send(Ok(chunk)).is_err() {
365 return;
366 }
367 }
368 Ok(None) => return, Err(e) => {
370 let _ = tx.send(Err(e));
371 return;
372 }
373 }
374 }
375 }
376 }
377 }
378 Err(e) => {
379 let _ = tx.send(Err(AiLibError::ProviderError(format!(
380 "Stream error: {}",
381 e
382 ))));
383 break;
384 }
385 }
386 }
387 });
388 let stream = UnboundedReceiverStream::new(rx);
389 return Ok(Box::new(Box::pin(stream)));
390 }
391 }
392
393 let finished = self.chat_completion(request).await?;
395 let text = finished
396 .choices
397 .first()
398 .map(|c| c.message.content.as_text())
399 .unwrap_or_default();
400 let (tx, rx) = mpsc::unbounded_channel();
401 tokio::spawn(async move {
402 let chunks = split_text_into_chunks(&text, 80);
403 for chunk in chunks {
404 let delta = ChoiceDelta {
405 index: 0,
406 delta: MessageDelta {
407 role: Some(Role::Assistant),
408 content: Some(chunk.clone()),
409 },
410 finish_reason: None,
411 };
412 let chunk_obj = ChatCompletionChunk {
413 id: "simulated".to_string(),
414 object: "chat.completion.chunk".to_string(),
415 created: 0,
416 model: finished.model.clone(),
417 choices: vec![delta],
418 };
419 if tx.send(Ok(chunk_obj)).is_err() {
420 return;
421 }
422 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
423 }
424 });
425 let stream = UnboundedReceiverStream::new(rx);
426 Ok(Box::new(Box::pin(stream)))
427 }
428
429 async fn list_models(&self) -> Result<Vec<String>, AiLibError> {
430 let url = format!("{}/v1/models", self.base_url);
432 let mut headers = HashMap::new();
433 if let Some(key) = &self.api_key {
434 headers.insert("Authorization".to_string(), format!("Bearer {}", key));
435 }
436 let response: serde_json::Value = self.transport.get_json(&url, Some(headers)).await?;
437 Ok(response["data"]
438 .as_array()
439 .unwrap_or(&vec![])
440 .iter()
441 .filter_map(|m| m["id"].as_str().map(|s| s.to_string()))
442 .collect())
443 }
444
445 async fn get_model_info(&self, model_id: &str) -> Result<crate::api::ModelInfo, AiLibError> {
446 Ok(ModelInfo {
447 id: model_id.to_string(),
448 object: "model".to_string(),
449 created: 0,
450 owned_by: "mistral".to_string(),
451 permission: vec![ModelPermission {
452 id: "default".to_string(),
453 object: "model_permission".to_string(),
454 created: 0,
455 allow_create_engine: false,
456 allow_sampling: true,
457 allow_logprobs: false,
458 allow_search_indices: false,
459 allow_view: true,
460 allow_fine_tuning: false,
461 organization: "*".to_string(),
462 group: None,
463 is_blocking: false,
464 }],
465 })
466 }
467}