1use crate::api::{
2 ChatApi, ChatCompletionChunk, ChoiceDelta, MessageDelta, ModelInfo, ModelPermission,
3};
4use crate::metrics::{Metrics, NoopMetrics};
5use crate::transport::{DynHttpTransportRef, HttpTransport};
6use crate::types::{
7 AiLibError, ChatCompletionRequest, ChatCompletionResponse, Choice, Message, Role, Usage, UsageStatus,
8};
9use futures::stream::Stream;
10use futures::StreamExt;
11use std::collections::HashMap;
12#[cfg(feature = "unified_transport")]
13use std::time::Duration;
14use std::sync::Arc;
15use tokio::sync::mpsc;
16use tokio_stream::wrappers::UnboundedReceiverStream;
17
/// Adapter for the Mistral AI chat API (OpenAI-compatible JSON endpoints).
///
/// Holds the HTTP transport, an optional API key (sent as a `Bearer` token
/// when present), the service base URL, and a metrics sink.
pub struct MistralAdapter {
    // Injectable via the `with_transport*` constructors; defaults are built
    // by `build_default_transport`.
    #[allow(dead_code)]
    transport: DynHttpTransportRef,
    // From `MISTRAL_API_KEY` unless overridden; requests go out without an
    // `Authorization` header when `None`.
    api_key: Option<String>,
    // e.g. "https://api.mistral.ai"; endpoint paths are appended to this.
    base_url: String,
    // Defaults to `NoopMetrics` in the plain constructors.
    metrics: Arc<dyn Metrics>,
}
29
impl MistralAdapter {
    /// Timeout (seconds) for the default transport: reads
    /// `AI_HTTP_TIMEOUT_SECS`, falling back to 30 when unset or unparsable.
    #[allow(dead_code)]
    fn build_default_timeout_secs() -> u64 {
        std::env::var("AI_HTTP_TIMEOUT_SECS")
            .ok()
            .and_then(|s| s.parse::<u64>().ok())
            .unwrap_or(30)
    }

    /// Build the default HTTP transport.
    ///
    /// With the `unified_transport` feature this uses the shared reqwest
    /// client plus the configured timeout; otherwise a plain
    /// `HttpTransport::new()`.
    fn build_default_transport() -> Result<DynHttpTransportRef, AiLibError> {
        #[cfg(feature = "unified_transport")]
        {
            let timeout = Duration::from_secs(Self::build_default_timeout_secs());
            let client = crate::transport::client_factory::build_shared_client()
                .map_err(|e| AiLibError::NetworkError(format!("Failed to build http client: {}", e)))?;
            let t = HttpTransport::with_reqwest_client(client, timeout);
            return Ok(t.boxed());
        }
        #[cfg(not(feature = "unified_transport"))]
        {
            let t = HttpTransport::new();
            return Ok(t.boxed());
        }
    }

    /// Create an adapter configured from the environment:
    /// `MISTRAL_API_KEY` (optional) and `MISTRAL_BASE_URL`
    /// (defaults to `https://api.mistral.ai`).
    pub fn new() -> Result<Self, AiLibError> {
        let api_key = std::env::var("MISTRAL_API_KEY").ok();
        let base_url = std::env::var("MISTRAL_BASE_URL")
            .unwrap_or_else(|_| "https://api.mistral.ai".to_string());
        let boxed = Self::build_default_transport()?;
        Ok(Self {
            transport: boxed,
            api_key,
            base_url,
            metrics: Arc::new(NoopMetrics::new()),
        })
    }

    /// Like [`new`](Self::new) but with explicit overrides. A `None`
    /// `base_url` falls back to `MISTRAL_BASE_URL`, then the public endpoint.
    pub fn new_with_overrides(
        api_key: Option<String>,
        base_url: Option<String>,
    ) -> Result<Self, AiLibError> {
        let boxed = Self::build_default_transport()?;
        Ok(Self {
            transport: boxed,
            api_key,
            base_url: base_url.unwrap_or_else(|| {
                std::env::var("MISTRAL_BASE_URL")
                    .unwrap_or_else(|_| "https://api.mistral.ai".to_string())
            }),
            metrics: Arc::new(NoopMetrics::new()),
        })
    }

    /// Construct with a caller-supplied transport (e.g. mocks or proxies);
    /// metrics default to `NoopMetrics`.
    pub fn with_transport(
        transport: DynHttpTransportRef,
        api_key: Option<String>,
        base_url: String,
    ) -> Result<Self, AiLibError> {
        Ok(Self {
            transport,
            api_key,
            base_url,
            metrics: Arc::new(NoopMetrics::new()),
        })
    }

    /// Construct with both a custom transport and a custom metrics sink.
    pub fn with_transport_and_metrics(
        transport: DynHttpTransportRef,
        api_key: Option<String>,
        base_url: String,
        metrics: Arc<dyn Metrics>,
    ) -> Result<Self, AiLibError> {
        Ok(Self {
            transport,
            api_key,
            base_url,
            metrics,
        })
    }

    /// Translate a unified `ChatCompletionRequest` into Mistral's JSON body.
    ///
    /// Maps message roles/content, optional `temperature` and `max_tokens`,
    /// and the function-calling fields (`functions` plus the `function_call`
    /// policy).
    fn convert_request(&self, request: &ChatCompletionRequest) -> serde_json::Value {
        // Messages are flattened to text content only.
        let msgs: Vec<serde_json::Value> = request.messages.iter().map(|msg| {
            serde_json::json!({
                "role": match msg.role { Role::System => "system", Role::User => "user", Role::Assistant => "assistant" },
                "content": msg.content.as_text()
            })
        }).collect();

        let mut body = serde_json::json!({ "model": request.model, "messages": msgs });
        if let Some(temp) = request.temperature {
            // NOTE(review): `from_f64` returns None for NaN/infinite values,
            // so this unwrap assumes temperature is always finite — confirm
            // upstream validation.
            body["temperature"] =
                serde_json::Value::Number(serde_json::Number::from_f64(temp.into()).unwrap());
        }
        if let Some(max_tokens) = request.max_tokens {
            body["max_tokens"] = serde_json::Value::Number(serde_json::Number::from(max_tokens));
        }

        // Function definitions: missing parameter schemas become `{}`.
        if let Some(funcs) = &request.functions {
            let mapped: Vec<serde_json::Value> = funcs
                .iter()
                .map(|t| {
                    serde_json::json!({
                        "name": t.name,
                        "description": t.description,
                        "parameters": t.parameters.clone().unwrap_or(serde_json::json!({}))
                    })
                })
                .collect();
            body["functions"] = serde_json::Value::Array(mapped);
        }
        if let Some(policy) = &request.function_call {
            match policy {
                // `Auto("auto")` means provider-chosen; any other name forces
                // a specific function.
                crate::types::FunctionCallPolicy::Auto(name) => {
                    if name == "auto" {
                        body["function_call"] = serde_json::Value::String("auto".to_string());
                    } else {
                        body["function_call"] = serde_json::json!({"name": name});
                    }
                }
                crate::types::FunctionCallPolicy::None => {
                    body["function_call"] = serde_json::Value::String("none".to_string());
                }
            }
        }

        body
    }

    /// Parse a non-streaming chat-completions JSON response into the unified
    /// `ChatCompletionResponse`.
    ///
    /// Errors with `ProviderError` when a choice has no `message` object or
    /// when `usage` is missing; most other missing fields default silently.
    fn parse_response(
        &self,
        response: serde_json::Value,
    ) -> Result<ChatCompletionResponse, AiLibError> {
        let choices = response["choices"]
            .as_array()
            .unwrap_or(&vec![])
            .iter()
            .enumerate()
            .map(|(index, choice)| {
                let message = choice["message"].as_object().ok_or_else(|| {
                    AiLibError::ProviderError("Invalid choice format".to_string())
                })?;
                // Unknown roles default to User.
                let role = match message["role"].as_str().unwrap_or("user") {
                    "system" => Role::System,
                    "assistant" => Role::Assistant,
                    _ => Role::User,
                };
                let content = message["content"].as_str().unwrap_or("").to_string();

                // Function-call extraction, two shapes supported:
                // 1) a legacy `function_call` object (deserialized directly,
                //    or rebuilt by hand when `arguments` arrives as a JSON
                //    string instead of an object);
                // 2) the first entry of an OpenAI-style `tool_calls` array.
                let mut function_call: Option<crate::types::function_call::FunctionCall> = None;
                if let Some(fc_val) = message.get("function_call") {
                    if let Ok(fc) = serde_json::from_value::<
                        crate::types::function_call::FunctionCall,
                    >(fc_val.clone())
                    {
                        function_call = Some(fc);
                    } else if let Some(name) = fc_val
                        .get("name")
                        .and_then(|v| v.as_str())
                        .map(|s| s.to_string())
                    {
                        // `arguments` may be a JSON-encoded string; decode it
                        // to a value when possible.
                        let args = fc_val.get("arguments").and_then(|a| {
                            if a.is_string() {
                                serde_json::from_str::<serde_json::Value>(a.as_str().unwrap()).ok()
                            } else {
                                Some(a.clone())
                            }
                        });
                        function_call = Some(crate::types::function_call::FunctionCall {
                            name,
                            arguments: args,
                        });
                    }
                } else if let Some(tool_calls) = message.get("tool_calls").and_then(|v| v.as_array()) {
                    // Only the first tool call is surfaced.
                    if let Some(first) = tool_calls.first() {
                        if let Some(func) = first.get("function").or_else(|| first.get("function_call")) {
                            if let Some(name) = func.get("name").and_then(|v| v.as_str()) {
                                let mut args_opt = func.get("arguments").cloned();
                                if let Some(args_val) = &args_opt {
                                    if args_val.is_string() {
                                        if let Some(s) = args_val.as_str() {
                                            if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(s) {
                                                args_opt = Some(parsed);
                                            }
                                        }
                                    }
                                }
                                function_call = Some(crate::types::function_call::FunctionCall {
                                    name: name.to_string(),
                                    arguments: args_opt,
                                });
                            }
                        }
                    }
                }

                Ok(Choice {
                    index: index as u32,
                    message: Message {
                        role,
                        content: crate::types::common::Content::Text(content),
                        function_call,
                    },
                    finish_reason: choice["finish_reason"].as_str().map(|s| s.to_string()),
                })
            })
            .collect::<Result<Vec<_>, AiLibError>>()?;

        // `usage` is required; individual counters default to 0.
        let usage = response["usage"].as_object().ok_or_else(|| {
            AiLibError::ProviderError("Invalid response format: usage not found".to_string())
        })?;
        let usage = Usage {
            prompt_tokens: usage["prompt_tokens"].as_u64().unwrap_or(0) as u32,
            completion_tokens: usage["completion_tokens"].as_u64().unwrap_or(0) as u32,
            total_tokens: usage["total_tokens"].as_u64().unwrap_or(0) as u32,
        };

        Ok(ChatCompletionResponse {
            id: response["id"].as_str().unwrap_or_default().to_string(),
            object: response["object"].as_str().unwrap_or_default().to_string(),
            created: response["created"].as_u64().unwrap_or(0),
            model: response["model"].as_str().unwrap_or_default().to_string(),
            choices,
            usage,
            usage_status: UsageStatus::Finalized,
        })
    }
}
263
#[cfg(not(feature = "unified_sse"))]
/// Locate the end of the next complete SSE event in `buffer`.
///
/// Returns the byte index just past the first blank-line separator
/// (`\n\n` or `\r\n\r\n`), or `None` when no full event is buffered yet.
fn find_event_boundary(buffer: &[u8]) -> Option<usize> {
    (0..buffer.len()).find_map(|idx| {
        let rest = &buffer[idx..];
        if rest.starts_with(b"\n\n") {
            Some(idx + 2)
        } else if rest.starts_with(b"\r\n\r\n") {
            Some(idx + 4)
        } else {
            None
        }
    })
}
283
284#[cfg(not(feature = "unified_sse"))]
285fn parse_sse_event(event_text: &str) -> Option<Result<Option<ChatCompletionChunk>, AiLibError>> {
286 for line in event_text.lines() {
287 let line = line.trim();
288 if let Some(stripped) = line.strip_prefix("data: ") {
289 let data = stripped;
290 if data == "[DONE]" {
291 return Some(Ok(None));
292 }
293 return Some(parse_chunk_data(data));
294 }
295 }
296 None
297}
298
299#[cfg(not(feature = "unified_sse"))]
300fn parse_chunk_data(data: &str) -> Result<Option<ChatCompletionChunk>, AiLibError> {
301 let json: serde_json::Value = serde_json::from_str(data)
302 .map_err(|e| AiLibError::ProviderError(format!("JSON parse error: {}", e)))?;
303 let mut choices_vec: Vec<ChoiceDelta> = Vec::new();
304 if let Some(arr) = json["choices"].as_array() {
305 for (index, choice) in arr.iter().enumerate() {
306 let delta = &choice["delta"];
307 let role = delta.get("role").and_then(|v| v.as_str()).map(|r| match r {
308 "assistant" => Role::Assistant,
309 "user" => Role::User,
310 "system" => Role::System,
311 _ => Role::Assistant,
312 });
313 let content = delta
314 .get("content")
315 .and_then(|v| v.as_str())
316 .map(|s| s.to_string());
317 let md = MessageDelta { role, content };
318 let cd = ChoiceDelta {
319 index: index as u32,
320 delta: md,
321 finish_reason: choice
322 .get("finish_reason")
323 .and_then(|v| v.as_str())
324 .map(|s| s.to_string()),
325 };
326 choices_vec.push(cd);
327 }
328 }
329
330 Ok(Some(ChatCompletionChunk {
331 id: json["id"].as_str().unwrap_or_default().to_string(),
332 object: json["object"]
333 .as_str()
334 .unwrap_or("chat.completion.chunk")
335 .to_string(),
336 created: json["created"].as_u64().unwrap_or(0),
337 model: json["model"].as_str().unwrap_or_default().to_string(),
338 choices: choices_vec,
339 }))
340}
341
/// Split `text` into chunks of at most `max_len` bytes, preferring to break
/// at the last space inside the window so words stay intact. The space used
/// as a break point is consumed (chunks never start with it).
///
/// Used by the simulated-streaming fallback to replay a finished completion
/// as a sequence of small deltas.
///
/// Fixes over the previous version:
/// - never panics on multi-byte UTF-8: the old code indexed the `&str` at an
///   arbitrary byte offset (`text[start..end]`), which panics when `end`
///   falls inside a multi-byte character; cut points are now snapped to char
///   boundaries, so chunks also never contain replacement characters;
/// - `max_len == 0` is clamped to 1 instead of looping forever.
fn split_text_into_chunks(text: &str, max_len: usize) -> Vec<String> {
    let max_len = max_len.max(1); // guarantee forward progress
    let bytes = text.as_bytes();
    let mut chunks = Vec::new();
    let mut start = 0;
    while start < bytes.len() {
        let end = usize::min(start + max_len, bytes.len());
        let mut cut = end;
        // Prefer breaking at the last space inside the window (byte search:
        // ASCII space can never occur inside a multi-byte UTF-8 sequence).
        if end < bytes.len() {
            if let Some(pos) = bytes[start..end].iter().rposition(|&b| b == b' ') {
                cut = start + pos;
            }
        }
        if cut == start {
            cut = end; // window starts with a space: take the whole window
        }
        // Snap backwards to a char boundary; if that collapses the chunk,
        // extend forward instead so the loop always advances.
        while cut > start && !text.is_char_boundary(cut) {
            cut -= 1;
        }
        if cut == start {
            cut = end;
            while cut < bytes.len() && !text.is_char_boundary(cut) {
                cut += 1;
            }
        }
        // `start` and `cut` are both char boundaries, so slicing is safe.
        chunks.push(text[start..cut].to_string());
        start = cut;
        // Skip the space we broke on.
        if start < bytes.len() && bytes[start] == b' ' {
            start += 1;
        }
    }
    chunks
}
366
#[async_trait::async_trait]
impl ChatApi for MistralAdapter {
    /// Non-streaming chat completion against `POST {base_url}/v1/chat/completions`.
    ///
    /// Increments the `mistral.requests` counter, times the round trip via
    /// `mistral.request_duration_ms`, attaches `Authorization: Bearer <key>`
    /// when an API key is configured, and converts the provider JSON into the
    /// unified response type.
    async fn chat_completion(
        &self,
        request: ChatCompletionRequest,
    ) -> Result<ChatCompletionResponse, AiLibError> {
        self.metrics.incr_counter("mistral.requests", 1).await;
        let timer = self
            .metrics
            .start_timer("mistral.request_duration_ms")
            .await;

        let url = format!("{}{}", self.base_url, "/v1/chat/completions");
        let provider_request = self.convert_request(&request);
        let mut headers = HashMap::new();
        headers.insert("Content-Type".to_string(), "application/json".to_string());
        if let Some(key) = &self.api_key {
            headers.insert("Authorization".to_string(), format!("Bearer {}", key));
        }
        let response_json = self
            .transport
            .post_json(&url, Some(headers), provider_request)
            .await?;
        // NOTE: the timer is only stopped on the success path; `post_json`
        // errors propagate without recording a duration.
        if let Some(t) = timer {
            t.stop();
        }
        self.parse_response(response_json)
    }

    /// Streaming chat completion.
    ///
    /// Preferred path: open an SSE stream (`"stream": true`) via
    /// `post_stream` and forward parsed `ChatCompletionChunk`s through an
    /// unbounded channel. Incoming bytes are accumulated in a buffer and
    /// split on blank-line event boundaries; which parser runs depends on the
    /// `unified_sse` feature. The background task exits on the `[DONE]`
    /// marker, a parse error, a transport error, or when the receiver is
    /// dropped (send fails).
    ///
    /// Fallback path: if the transport cannot open a stream, perform one
    /// blocking completion and replay its text as simulated ~80-char chunks
    /// with a 50 ms pause between them.
    async fn chat_completion_stream(
        &self,
        request: ChatCompletionRequest,
    ) -> Result<
        Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
        AiLibError,
    > {
        let mut stream_request = self.convert_request(&request);
        stream_request["stream"] = serde_json::Value::Bool(true);

        let url = format!("{}{}", self.base_url, "/v1/chat/completions");

        let mut headers = HashMap::new();
        headers.insert("Accept".to_string(), "text/event-stream".to_string());
        if let Some(key) = &self.api_key {
            headers.insert("Authorization".to_string(), format!("Bearer {}", key));
        }
        if let Ok(mut byte_stream) = self
            .transport
            .post_stream(&url, Some(headers.clone()), stream_request.clone())
            .await
        {
            let (tx, rx) = mpsc::unbounded_channel();
            tokio::spawn(async move {
                // Rolling byte buffer; complete events are drained from the
                // front as boundaries are found.
                let mut buffer = Vec::new();
                while let Some(item) = byte_stream.next().await {
                    match item {
                        Ok(bytes) => {
                            buffer.extend_from_slice(&bytes);
                            #[cfg(feature = "unified_sse")]
                            {
                                while let Some(boundary) =
                                    crate::sse::parser::find_event_boundary(&buffer)
                                {
                                    let event_bytes = buffer.drain(..boundary).collect::<Vec<_>>();
                                    // Non-UTF-8 events are silently skipped.
                                    if let Ok(event_text) = std::str::from_utf8(&event_bytes) {
                                        if let Some(parsed) =
                                            crate::sse::parser::parse_sse_event(event_text)
                                        {
                                            match parsed {
                                                Ok(Some(chunk)) => {
                                                    if tx.send(Ok(chunk)).is_err() {
                                                        return; // receiver dropped
                                                    }
                                                }
                                                // `[DONE]` marker: end of stream.
                                                Ok(None) => return,
                                                Err(e) => {
                                                    let _ = tx.send(Err(e));
                                                    return;
                                                }
                                            }
                                        }
                                    }
                                }
                            }
                            #[cfg(not(feature = "unified_sse"))]
                            {
                                // Same logic using the local fallback parser.
                                while let Some(boundary) = find_event_boundary(&buffer) {
                                    let event_bytes = buffer.drain(..boundary).collect::<Vec<_>>();
                                    if let Ok(event_text) = std::str::from_utf8(&event_bytes) {
                                        if let Some(parsed) = parse_sse_event(event_text) {
                                            match parsed {
                                                Ok(Some(chunk)) => {
                                                    if tx.send(Ok(chunk)).is_err() {
                                                        return;
                                                    }
                                                }
                                                Ok(None) => return,
                                                Err(e) => {
                                                    let _ = tx.send(Err(e));
                                                    return;
                                                }
                                            }
                                        }
                                    }
                                }
                            }
                        }
                        Err(e) => {
                            let _ = tx.send(Err(AiLibError::ProviderError(format!(
                                "Stream error: {}",
                                e
                            ))));
                            break;
                        }
                    }
                }
            });
            let stream = UnboundedReceiverStream::new(rx);
            return Ok(Box::new(Box::pin(stream)));
        }

        // Fallback: one blocking completion, replayed as simulated chunks
        // from the first choice's text.
        let finished = self.chat_completion(request).await?;
        let text = finished
            .choices
            .first()
            .map(|c| c.message.content.as_text())
            .unwrap_or_default();
        let (tx, rx) = mpsc::unbounded_channel();
        tokio::spawn(async move {
            let chunks = split_text_into_chunks(&text, 80);
            for chunk in chunks {
                let delta = ChoiceDelta {
                    index: 0,
                    delta: MessageDelta {
                        role: Some(Role::Assistant),
                        content: Some(chunk.clone()),
                    },
                    finish_reason: None,
                };
                let chunk_obj = ChatCompletionChunk {
                    id: "simulated".to_string(),
                    object: "chat.completion.chunk".to_string(),
                    created: 0,
                    model: finished.model.clone(),
                    choices: vec![delta],
                };
                if tx.send(Ok(chunk_obj)).is_err() {
                    return; // receiver dropped
                }
                // Pace the simulated stream.
                tokio::time::sleep(std::time::Duration::from_millis(50)).await;
            }
        });
        let stream = UnboundedReceiverStream::new(rx);
        Ok(Box::new(Box::pin(stream)))
    }

    /// `GET {base_url}/v1/models` and return the listed model ids.
    /// A missing/invalid `data` array yields an empty list.
    async fn list_models(&self) -> Result<Vec<String>, AiLibError> {
        let url = format!("{}/v1/models", self.base_url);
        let mut headers = HashMap::new();
        if let Some(key) = &self.api_key {
            headers.insert("Authorization".to_string(), format!("Bearer {}", key));
        }
        let response = self.transport.get_json(&url, Some(headers)).await?;
        Ok(response["data"]
            .as_array()
            .unwrap_or(&vec![])
            .iter()
            .filter_map(|m| m["id"].as_str().map(|s| s.to_string()))
            .collect())
    }

    /// Static model info: no provider call is made; permissive default
    /// permissions are returned for any `model_id`.
    async fn get_model_info(&self, model_id: &str) -> Result<crate::api::ModelInfo, AiLibError> {
        Ok(ModelInfo {
            id: model_id.to_string(),
            object: "model".to_string(),
            created: 0,
            owned_by: "mistral".to_string(),
            permission: vec![ModelPermission {
                id: "default".to_string(),
                object: "model_permission".to_string(),
                created: 0,
                allow_create_engine: false,
                allow_sampling: true,
                allow_logprobs: false,
                allow_search_indices: false,
                allow_view: true,
                allow_fine_tuning: false,
                organization: "*".to_string(),
                group: None,
                is_blocking: false,
            }],
        })
    }
}