1use crate::api::{
2 ChatCompletionChunk, ChatProvider, ChoiceDelta, MessageDelta, ModelInfo, ModelPermission,
3};
4use crate::metrics::{Metrics, NoopMetrics};
5use crate::transport::{DynHttpTransportRef, HttpTransport};
6use crate::types::{
7 AiLibError, ChatCompletionRequest, ChatCompletionResponse, Choice, Message, Role, Usage,
8 UsageStatus,
9};
10use futures::stream::Stream;
11use futures::StreamExt;
12use std::collections::HashMap;
13use std::sync::Arc;
14#[cfg(feature = "unified_transport")]
15use std::time::Duration;
16use tokio::sync::mpsc;
17use tokio_stream::wrappers::UnboundedReceiverStream;
18
/// Chat-completion adapter for the Mistral AI platform (OpenAI-compatible REST API).
pub struct MistralAdapter {
    /// HTTP transport used for all requests; injectable for testing.
    #[allow(dead_code)] transport: DynHttpTransportRef,
    /// Bearer token; requests are sent without `Authorization` when `None`.
    api_key: Option<String>,
    /// API root, e.g. `https://api.mistral.ai` (paths are appended verbatim).
    base_url: String,
    /// Metrics sink; defaults to `NoopMetrics`.
    metrics: Arc<dyn Metrics>,
}
30
31impl MistralAdapter {
32 #[allow(dead_code)]
33 fn build_default_timeout_secs() -> u64 {
34 std::env::var("AI_HTTP_TIMEOUT_SECS")
35 .ok()
36 .and_then(|s| s.parse::<u64>().ok())
37 .unwrap_or(30)
38 }
39
    /// Construct the default HTTP transport.
    ///
    /// With the `unified_transport` feature enabled, a shared reqwest client is
    /// built with a timeout from [`Self::build_default_timeout_secs`]; without
    /// it, a plain `HttpTransport::new()` is used. The two `cfg` branches are
    /// mutually exclusive, so exactly one runs.
    ///
    /// # Errors
    /// Returns `AiLibError::NetworkError` if the shared client cannot be built.
    fn build_default_transport() -> Result<DynHttpTransportRef, AiLibError> {
        #[cfg(feature = "unified_transport")]
        {
            let timeout = Duration::from_secs(Self::build_default_timeout_secs());
            let client = crate::transport::client_factory::build_shared_client().map_err(|e| {
                AiLibError::NetworkError(format!("Failed to build http client: {}", e))
            })?;
            let t = HttpTransport::with_reqwest_client(client, timeout);
            Ok(t.boxed())
        }
        #[cfg(not(feature = "unified_transport"))]
        {
            let t = HttpTransport::new();
            return Ok(t.boxed());
        }
    }
56
57 pub fn new() -> Result<Self, AiLibError> {
58 let api_key = std::env::var("MISTRAL_API_KEY").ok();
59 let base_url = std::env::var("MISTRAL_BASE_URL")
60 .unwrap_or_else(|_| "https://api.mistral.ai".to_string());
61 let boxed = Self::build_default_transport()?;
62 Ok(Self {
63 transport: boxed,
64 api_key,
65 base_url,
66 metrics: Arc::new(NoopMetrics::new()),
67 })
68 }
69
70 pub fn new_with_overrides(
72 api_key: Option<String>,
73 base_url: Option<String>,
74 ) -> Result<Self, AiLibError> {
75 let boxed = Self::build_default_transport()?;
76 Ok(Self {
77 transport: boxed,
78 api_key,
79 base_url: base_url.unwrap_or_else(|| {
80 std::env::var("MISTRAL_BASE_URL")
81 .unwrap_or_else(|_| "https://api.mistral.ai".to_string())
82 }),
83 metrics: Arc::new(NoopMetrics::new()),
84 })
85 }
86
87 pub fn with_transport(
89 transport: DynHttpTransportRef,
90 api_key: Option<String>,
91 base_url: String,
92 ) -> Result<Self, AiLibError> {
93 Ok(Self {
94 transport,
95 api_key,
96 base_url,
97 metrics: Arc::new(NoopMetrics::new()),
98 })
99 }
100
101 pub fn with_transport_and_metrics(
103 transport: DynHttpTransportRef,
104 api_key: Option<String>,
105 base_url: String,
106 metrics: Arc<dyn Metrics>,
107 ) -> Result<Self, AiLibError> {
108 Ok(Self {
109 transport,
110 api_key,
111 base_url,
112 metrics,
113 })
114 }
115
116 fn convert_request(&self, request: &ChatCompletionRequest) -> serde_json::Value {
117 let msgs: Vec<serde_json::Value> = request.messages.iter().map(|msg| {
118 serde_json::json!({
119 "role": match msg.role { Role::System => "system", Role::User => "user", Role::Assistant => "assistant" },
120 "content": msg.content.as_text()
121 })
122 }).collect();
123
124 let mut body = serde_json::json!({ "model": request.model, "messages": msgs });
125 if let Some(temp) = request.temperature {
126 body["temperature"] =
127 serde_json::Value::Number(serde_json::Number::from_f64(temp.into()).unwrap());
128 }
129 if let Some(max_tokens) = request.max_tokens {
130 body["max_tokens"] = serde_json::Value::Number(serde_json::Number::from(max_tokens));
131 }
132
133 if let Some(funcs) = &request.functions {
135 let mapped: Vec<serde_json::Value> = funcs
136 .iter()
137 .map(|t| {
138 serde_json::json!({
139 "name": t.name,
140 "description": t.description,
141 "parameters": t.parameters.clone().unwrap_or(serde_json::json!({}))
142 })
143 })
144 .collect();
145 body["functions"] = serde_json::Value::Array(mapped);
146 }
147 if let Some(policy) = &request.function_call {
148 match policy {
149 crate::types::FunctionCallPolicy::Auto(name) => {
150 if name == "auto" {
151 body["function_call"] = serde_json::Value::String("auto".to_string());
152 } else {
153 body["function_call"] = serde_json::json!({"name": name});
154 }
155 }
156 crate::types::FunctionCallPolicy::None => {
157 body["function_call"] = serde_json::Value::String("none".to_string());
158 }
159 }
160 }
161
162 request.apply_extensions(&mut body);
163
164 body
165 }
166
    /// Parse a non-streaming Mistral response into a `ChatCompletionResponse`.
    ///
    /// A missing or non-array `choices` field yields an empty choice list, but
    /// a `usage` object is required.
    ///
    /// # Errors
    /// `AiLibError::ProviderError` when a choice carries no `message` object or
    /// when `usage` is absent.
    fn parse_response(
        &self,
        response: serde_json::Value,
    ) -> Result<ChatCompletionResponse, AiLibError> {
        let choices = response["choices"]
            .as_array()
            .unwrap_or(&vec![])
            .iter()
            .enumerate()
            .map(|(index, choice)| {
                let message = choice["message"].as_object().ok_or_else(|| {
                    AiLibError::ProviderError("Invalid choice format".to_string())
                })?;
                // Unknown or missing roles default to User.
                let role = match message["role"].as_str().unwrap_or("user") {
                    "system" => Role::System,
                    "assistant" => Role::Assistant,
                    _ => Role::User,
                };
                let content = message["content"].as_str().unwrap_or("").to_string();

                // Prefer the legacy `function_call` field; fall back to the
                // first entry of `tool_calls` only when `function_call` is absent.
                let mut function_call: Option<crate::types::function_call::FunctionCall> = None;
                if let Some(fc_val) = message.get("function_call") {
                    // Strict path: direct deserialization into our FunctionCall type.
                    if let Ok(fc) = serde_json::from_value::<
                        crate::types::function_call::FunctionCall,
                    >(fc_val.clone())
                    {
                        function_call = Some(fc);
                    } else if let Some(name) = fc_val
                        .get("name")
                        .and_then(|v| v.as_str())
                        .map(|s| s.to_string())
                    {
                        // Lenient path: some providers send `arguments` as a
                        // JSON-encoded string instead of an object.
                        let args = fc_val.get("arguments").and_then(|a| {
                            if a.is_string() {
                                serde_json::from_str::<serde_json::Value>(a.as_str().unwrap()).ok()
                            } else {
                                Some(a.clone())
                            }
                        });
                        function_call = Some(crate::types::function_call::FunctionCall {
                            name,
                            arguments: args,
                        });
                    }
                } else if let Some(tool_calls) =
                    message.get("tool_calls").and_then(|v| v.as_array())
                {
                    // Only the first tool call is surfaced; any additional
                    // calls in the array are ignored.
                    if let Some(first) = tool_calls.first() {
                        if let Some(func) =
                            first.get("function").or_else(|| first.get("function_call"))
                        {
                            if let Some(name) = func.get("name").and_then(|v| v.as_str()) {
                                let mut args_opt = func.get("arguments").cloned();
                                // Decode string-encoded argument payloads when possible;
                                // on parse failure the raw string is kept as-is.
                                if let Some(args_val) = &args_opt {
                                    if args_val.is_string() {
                                        if let Some(s) = args_val.as_str() {
                                            if let Ok(parsed) =
                                                serde_json::from_str::<serde_json::Value>(s)
                                            {
                                                args_opt = Some(parsed);
                                            }
                                        }
                                    }
                                }
                                function_call = Some(crate::types::function_call::FunctionCall {
                                    name: name.to_string(),
                                    arguments: args_opt,
                                });
                            }
                        }
                    }
                }

                Ok(Choice {
                    index: index as u32,
                    message: Message {
                        role,
                        content: crate::types::common::Content::Text(content),
                        function_call,
                    },
                    finish_reason: choice["finish_reason"].as_str().map(|s| s.to_string()),
                })
            })
            .collect::<Result<Vec<_>, AiLibError>>()?;

        let usage = response["usage"].as_object().ok_or_else(|| {
            AiLibError::ProviderError("Invalid response format: usage not found".to_string())
        })?;
        // Missing token counts default to 0 rather than erroring.
        let usage = Usage {
            prompt_tokens: usage["prompt_tokens"].as_u64().unwrap_or(0) as u32,
            completion_tokens: usage["completion_tokens"].as_u64().unwrap_or(0) as u32,
            total_tokens: usage["total_tokens"].as_u64().unwrap_or(0) as u32,
        };

        Ok(ChatCompletionResponse {
            id: response["id"].as_str().unwrap_or_default().to_string(),
            object: response["object"].as_str().unwrap_or_default().to_string(),
            created: response["created"].as_u64().unwrap_or(0),
            model: response["model"].as_str().unwrap_or_default().to_string(),
            choices,
            usage,
            // Counts come straight from the provider response, so they are final.
            usage_status: UsageStatus::Finalized, })
    }
272}
273
/// Locate the end of the first complete SSE event in `buffer`.
///
/// Returns the index one past the event terminator — a blank line encoded as
/// either `\n\n` or `\r\n\r\n` — or `None` when no full event is buffered yet.
#[cfg(not(feature = "unified_sse"))]
fn find_event_boundary(buffer: &[u8]) -> Option<usize> {
    for (pos, pair) in buffer.windows(2).enumerate() {
        if pair == b"\n\n" {
            return Some(pos + 2);
        }
        if buffer[pos..].starts_with(b"\r\n\r\n") {
            return Some(pos + 4);
        }
    }
    None
}
293
294#[cfg(not(feature = "unified_sse"))]
295fn parse_sse_event(event_text: &str) -> Option<Result<Option<ChatCompletionChunk>, AiLibError>> {
296 for line in event_text.lines() {
297 let line = line.trim();
298 if let Some(stripped) = line.strip_prefix("data: ") {
299 let data = stripped;
300 if data == "[DONE]" {
301 return Some(Ok(None));
302 }
303 return Some(parse_chunk_data(data));
304 }
305 }
306 None
307}
308
309#[cfg(not(feature = "unified_sse"))]
310fn parse_chunk_data(data: &str) -> Result<Option<ChatCompletionChunk>, AiLibError> {
311 let json: serde_json::Value = serde_json::from_str(data)
312 .map_err(|e| AiLibError::ProviderError(format!("JSON parse error: {}", e)))?;
313 let mut choices_vec: Vec<ChoiceDelta> = Vec::new();
314 if let Some(arr) = json["choices"].as_array() {
315 for (index, choice) in arr.iter().enumerate() {
316 let delta = &choice["delta"];
317 let role = delta.get("role").and_then(|v| v.as_str()).map(|r| match r {
318 "assistant" => Role::Assistant,
319 "user" => Role::User,
320 "system" => Role::System,
321 _ => Role::Assistant,
322 });
323 let content = delta
324 .get("content")
325 .and_then(|v| v.as_str())
326 .map(|s| s.to_string());
327 let md = MessageDelta { role, content };
328 let cd = ChoiceDelta {
329 index: index as u32,
330 delta: md,
331 finish_reason: choice
332 .get("finish_reason")
333 .and_then(|v| v.as_str())
334 .map(|s| s.to_string()),
335 };
336 choices_vec.push(cd);
337 }
338 }
339
340 Ok(Some(ChatCompletionChunk {
341 id: json["id"].as_str().unwrap_or_default().to_string(),
342 object: json["object"]
343 .as_str()
344 .unwrap_or("chat.completion.chunk")
345 .to_string(),
346 created: json["created"].as_u64().unwrap_or(0),
347 model: json["model"].as_str().unwrap_or_default().to_string(),
348 choices: choices_vec,
349 }))
350}
351
/// Split `text` into chunks of at most `max_len` bytes, preferring to break at
/// the last space inside each window; a boundary space is consumed.
///
/// FIX: the previous version sliced `text` at arbitrary byte offsets
/// (`text[start..end]`), which panics when the offset lands inside a
/// multi-byte UTF-8 character, and looped forever when `max_len == 0`.
/// Cuts are now snapped to char boundaries and `max_len` is clamped to >= 1.
fn split_text_into_chunks(text: &str, max_len: usize) -> Vec<String> {
    let max_len = max_len.max(1); // guard: max_len == 0 previously spun forever
    let mut chunks = Vec::new();
    let mut start = 0;
    while start < text.len() {
        // Tentative end, pulled back to a char boundary so slicing never panics.
        let mut end = (start + max_len).min(text.len());
        while !text.is_char_boundary(end) {
            end -= 1;
        }
        // If a whole multi-byte char doesn't fit in the window, take it anyway
        // (slightly overlong chunk beats an empty one / an infinite loop).
        if end == start {
            end = (start + max_len).min(text.len());
            while !text.is_char_boundary(end) {
                end += 1;
            }
        }
        // Prefer breaking at the last space inside the window.
        let mut cut = end;
        if end < text.len() {
            if let Some(pos) = text[start..end].rfind(' ') {
                cut = start + pos;
            }
        }
        if cut == start {
            cut = end;
        }
        chunks.push(text[start..cut].to_string());
        start = cut;
        // Skip the space we broke on so chunks don't start with it.
        if text.as_bytes().get(start) == Some(&b' ') {
            start += 1;
        }
    }
    chunks
}
376
377#[async_trait::async_trait]
378impl ChatProvider for MistralAdapter {
    /// Human-readable provider name used in logs and diagnostics.
    fn name(&self) -> &str {
        "Mistral"
    }
382
    /// Send a blocking (non-streaming) chat completion request.
    ///
    /// Increments the `mistral.requests` counter and times the call via the
    /// configured metrics sink, POSTs the converted request to
    /// `{base_url}/v1/chat/completions`, then parses the JSON response.
    ///
    /// # Errors
    /// Propagates transport failures (with URL context) and response-shape
    /// errors from [`Self::parse_response`].
    async fn chat(
        &self,
        request: ChatCompletionRequest,
    ) -> Result<ChatCompletionResponse, AiLibError> {
        self.metrics.incr_counter("mistral.requests", 1).await;
        let timer = self
            .metrics
            .start_timer("mistral.request_duration_ms")
            .await;

        let url = format!("{}{}", self.base_url, "/v1/chat/completions");
        let provider_request = self.convert_request(&request);
        let mut headers = HashMap::new();
        headers.insert("Content-Type".to_string(), "application/json".to_string());
        // Requests go out unauthenticated when no API key is configured.
        if let Some(key) = &self.api_key {
            headers.insert("Authorization".to_string(), format!("Bearer {}", key));
        }
        let response_json = self
            .transport
            .post_json(&url, Some(headers), provider_request)
            .await
            .map_err(|e| e.with_context(&format!("Mistral chat request to {}", url)))?;
        // NOTE(review): on transport error the `?` above returns before
        // `t.stop()` runs — confirm whether the timer's drop path records or
        // discards the sample for failed requests.
        if let Some(t) = timer {
            t.stop();
        }
        self.parse_response(response_json)
    }
410
    /// Stream a chat completion as `ChatCompletionChunk` items.
    ///
    /// Re-sends the converted request with `"stream": true` and an
    /// `Accept: text/event-stream` header. Incoming bytes are accumulated in a
    /// buffer, split at SSE event boundaries and parsed — through
    /// `crate::sse::parser` when the `unified_sse` feature is enabled,
    /// otherwise through the local fallback parser in this file.
    ///
    /// If the transport cannot open a byte stream at all, degrades gracefully:
    /// performs a blocking `chat()` call and replays the response text as
    /// simulated chunks (up to 80 chars each, 50 ms apart).
    async fn stream(
        &self,
        request: ChatCompletionRequest,
    ) -> Result<
        Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
        AiLibError,
    > {
        let mut stream_request = self.convert_request(&request);
        stream_request["stream"] = serde_json::Value::Bool(true);

        let url = format!("{}{}", self.base_url, "/v1/chat/completions");

        let mut headers = HashMap::new();
        headers.insert("Accept".to_string(), "text/event-stream".to_string());
        if let Some(key) = &self.api_key {
            headers.insert("Authorization".to_string(), format!("Bearer {}", key));
        }
        if let Ok(mut byte_stream) = self
            .transport
            .post_stream(&url, Some(headers.clone()), stream_request.clone())
            .await
        {
            // Parsed chunks are forwarded through an unbounded channel; the
            // spawned task exits when the receiver is dropped, on the [DONE]
            // sentinel, or on the first error.
            let (tx, rx) = mpsc::unbounded_channel();
            tokio::spawn(async move {
                let mut buffer = Vec::new();
                while let Some(item) = byte_stream.next().await {
                    match item {
                        Ok(bytes) => {
                            buffer.extend_from_slice(&bytes);
                            #[cfg(feature = "unified_sse")]
                            {
                                // Drain every complete SSE event currently buffered.
                                while let Some(boundary) =
                                    crate::sse::parser::find_event_boundary(&buffer)
                                {
                                    let event_bytes = buffer.drain(..boundary).collect::<Vec<_>>();
                                    // Non-UTF-8 events are silently skipped.
                                    if let Ok(event_text) = std::str::from_utf8(&event_bytes) {
                                        if let Some(parsed) =
                                            crate::sse::parser::parse_sse_event(event_text)
                                        {
                                            match parsed {
                                                Ok(Some(chunk)) => {
                                                    if tx.send(Ok(chunk)).is_err() {
                                                        return;
                                                    }
                                                }
                                                // Ok(None) is the [DONE] sentinel.
                                                Ok(None) => return,
                                                Err(e) => {
                                                    let _ = tx.send(Err(e));
                                                    return;
                                                }
                                            }
                                        }
                                    }
                                }
                            }
                            #[cfg(not(feature = "unified_sse"))]
                            {
                                // Same drain loop, using the local fallback parser.
                                while let Some(boundary) = find_event_boundary(&buffer) {
                                    let event_bytes = buffer.drain(..boundary).collect::<Vec<_>>();
                                    if let Ok(event_text) = std::str::from_utf8(&event_bytes) {
                                        if let Some(parsed) = parse_sse_event(event_text) {
                                            match parsed {
                                                Ok(Some(chunk)) => {
                                                    if tx.send(Ok(chunk)).is_err() {
                                                        return;
                                                    }
                                                }
                                                Ok(None) => return,
                                                Err(e) => {
                                                    let _ = tx.send(Err(e));
                                                    return;
                                                }
                                            }
                                        }
                                    }
                                }
                            }
                        }
                        Err(e) => {
                            let _ = tx.send(Err(AiLibError::ProviderError(format!(
                                "Stream error: {}",
                                e
                            ))));
                            break;
                        }
                    }
                }
            });
            let stream = UnboundedReceiverStream::new(rx);
            return Ok(Box::new(Box::pin(stream)));
        }

        // Fallback: run a blocking completion and simulate streaming output.
        let finished = self.chat(request).await?;
        let text = finished
            .choices
            .first()
            .map(|c| c.message.content.as_text())
            .unwrap_or_default();
        let (tx, rx) = mpsc::unbounded_channel();
        tokio::spawn(async move {
            let chunks = split_text_into_chunks(&text, 80);
            for chunk in chunks {
                let delta = ChoiceDelta {
                    index: 0,
                    delta: MessageDelta {
                        role: Some(Role::Assistant),
                        content: Some(chunk.clone()),
                    },
                    finish_reason: None,
                };
                let chunk_obj = ChatCompletionChunk {
                    id: "simulated".to_string(),
                    object: "chat.completion.chunk".to_string(),
                    created: 0,
                    model: finished.model.clone(),
                    choices: vec![delta],
                };
                if tx.send(Ok(chunk_obj)).is_err() {
                    return;
                }
                // Small delay so consumers observe incremental delivery.
                tokio::time::sleep(std::time::Duration::from_millis(50)).await;
            }
        });
        let stream = UnboundedReceiverStream::new(rx);
        Ok(Box::new(Box::pin(stream)))
    }
538
539 async fn list_models(&self) -> Result<Vec<String>, AiLibError> {
540 let url = format!("{}/v1/models", self.base_url);
542 let mut headers = HashMap::new();
543 if let Some(key) = &self.api_key {
544 headers.insert("Authorization".to_string(), format!("Bearer {}", key));
545 }
546 let response = self.transport.get_json(&url, Some(headers)).await?;
547 Ok(response["data"]
548 .as_array()
549 .unwrap_or(&vec![])
550 .iter()
551 .filter_map(|m| m["id"].as_str().map(|s| s.to_string()))
552 .collect())
553 }
554
555 async fn get_model_info(&self, model_id: &str) -> Result<crate::api::ModelInfo, AiLibError> {
556 Ok(ModelInfo {
557 id: model_id.to_string(),
558 object: "model".to_string(),
559 created: 0,
560 owned_by: "mistral".to_string(),
561 permission: vec![ModelPermission {
562 id: "default".to_string(),
563 object: "model_permission".to_string(),
564 created: 0,
565 allow_create_engine: false,
566 allow_sampling: true,
567 allow_logprobs: false,
568 allow_search_indices: false,
569 allow_view: true,
570 allow_fine_tuning: false,
571 organization: "*".to_string(),
572 group: None,
573 is_blocking: false,
574 }],
575 })
576 }
577}