1use crate::api::{
2 ChatApi, ChatCompletionChunk, ChoiceDelta, MessageDelta, ModelInfo, ModelPermission,
3};
4use crate::metrics::{Metrics, NoopMetrics};
5use crate::transport::{DynHttpTransportRef, HttpTransport};
6use crate::types::{
7 AiLibError, ChatCompletionRequest, ChatCompletionResponse, Choice, Message, Role, Usage, UsageStatus,
8};
9use futures::stream::Stream;
10use futures::StreamExt;
11use std::collections::HashMap;
12#[cfg(feature = "unified_transport")]
13use std::time::Duration;
14use std::sync::Arc;
15use tokio::sync::mpsc;
16use tokio_stream::wrappers::UnboundedReceiverStream;
17
/// Adapter for the Mistral chat-completions API (OpenAI-compatible wire format).
///
/// Configuration comes from the environment (`MISTRAL_API_KEY`,
/// `MISTRAL_BASE_URL`) or explicit constructor overrides; metrics default to
/// a no-op sink unless injected via `with_transport_and_metrics`.
pub struct MistralAdapter {
    // NOTE(review): `dead_code` allow looks feature-related — the field is used
    // by the `ChatApi` impl below; confirm whether the allow is still needed.
    #[allow(dead_code)] transport: DynHttpTransportRef,
    // Bearer token; requests are sent without an Authorization header when None.
    api_key: Option<String>,
    // Base URL without trailing slash, e.g. `https://api.mistral.ai`.
    base_url: String,
    // Metrics sink; `NoopMetrics` by default.
    metrics: Arc<dyn Metrics>,
}
29
30impl MistralAdapter {
31 fn build_default_timeout_secs() -> u64 {
32 std::env::var("AI_HTTP_TIMEOUT_SECS")
33 .ok()
34 .and_then(|s| s.parse::<u64>().ok())
35 .unwrap_or(30)
36 }
37
38 fn build_default_transport() -> Result<DynHttpTransportRef, AiLibError> {
39 #[cfg(feature = "unified_transport")]
40 {
41 let timeout = Duration::from_secs(Self::build_default_timeout_secs());
42 let client = crate::transport::client_factory::build_shared_client()
43 .map_err(|e| AiLibError::NetworkError(format!("Failed to build http client: {}", e)))?;
44 let t = HttpTransport::with_reqwest_client(client, timeout);
45 return Ok(t.boxed());
46 }
47 #[cfg(not(feature = "unified_transport"))]
48 {
49 let t = HttpTransport::new();
50 return Ok(t.boxed());
51 }
52 }
53
54 pub fn new() -> Result<Self, AiLibError> {
55 let api_key = std::env::var("MISTRAL_API_KEY").ok();
56 let base_url = std::env::var("MISTRAL_BASE_URL")
57 .unwrap_or_else(|_| "https://api.mistral.ai".to_string());
58 let boxed = Self::build_default_transport()?;
59 Ok(Self {
60 transport: boxed,
61 api_key,
62 base_url,
63 metrics: Arc::new(NoopMetrics::new()),
64 })
65 }
66
67 pub fn new_with_overrides(
69 api_key: Option<String>,
70 base_url: Option<String>,
71 ) -> Result<Self, AiLibError> {
72 let boxed = Self::build_default_transport()?;
73 Ok(Self {
74 transport: boxed,
75 api_key,
76 base_url: base_url.unwrap_or_else(|| {
77 std::env::var("MISTRAL_BASE_URL")
78 .unwrap_or_else(|_| "https://api.mistral.ai".to_string())
79 }),
80 metrics: Arc::new(NoopMetrics::new()),
81 })
82 }
83
84 pub fn with_transport(
86 transport: DynHttpTransportRef,
87 api_key: Option<String>,
88 base_url: String,
89 ) -> Result<Self, AiLibError> {
90 Ok(Self {
91 transport,
92 api_key,
93 base_url,
94 metrics: Arc::new(NoopMetrics::new()),
95 })
96 }
97
98 pub fn with_transport_and_metrics(
100 transport: DynHttpTransportRef,
101 api_key: Option<String>,
102 base_url: String,
103 metrics: Arc<dyn Metrics>,
104 ) -> Result<Self, AiLibError> {
105 Ok(Self {
106 transport,
107 api_key,
108 base_url,
109 metrics,
110 })
111 }
112
113 fn convert_request(&self, request: &ChatCompletionRequest) -> serde_json::Value {
114 let msgs: Vec<serde_json::Value> = request.messages.iter().map(|msg| {
115 serde_json::json!({
116 "role": match msg.role { Role::System => "system", Role::User => "user", Role::Assistant => "assistant" },
117 "content": msg.content.as_text()
118 })
119 }).collect();
120
121 let mut body = serde_json::json!({ "model": request.model, "messages": msgs });
122 if let Some(temp) = request.temperature {
123 body["temperature"] =
124 serde_json::Value::Number(serde_json::Number::from_f64(temp.into()).unwrap());
125 }
126 if let Some(max_tokens) = request.max_tokens {
127 body["max_tokens"] = serde_json::Value::Number(serde_json::Number::from(max_tokens));
128 }
129
130 if let Some(funcs) = &request.functions {
132 let mapped: Vec<serde_json::Value> = funcs
133 .iter()
134 .map(|t| {
135 serde_json::json!({
136 "name": t.name,
137 "description": t.description,
138 "parameters": t.parameters.clone().unwrap_or(serde_json::json!({}))
139 })
140 })
141 .collect();
142 body["functions"] = serde_json::Value::Array(mapped);
143 }
144 if let Some(policy) = &request.function_call {
145 match policy {
146 crate::types::FunctionCallPolicy::Auto(name) => {
147 if name == "auto" {
148 body["function_call"] = serde_json::Value::String("auto".to_string());
149 } else {
150 body["function_call"] = serde_json::json!({"name": name});
151 }
152 }
153 crate::types::FunctionCallPolicy::None => {
154 body["function_call"] = serde_json::Value::String("none".to_string());
155 }
156 }
157 }
158
159 body
160 }
161
162 fn parse_response(
163 &self,
164 response: serde_json::Value,
165 ) -> Result<ChatCompletionResponse, AiLibError> {
166 let choices = response["choices"]
167 .as_array()
168 .unwrap_or(&vec![])
169 .iter()
170 .enumerate()
171 .map(|(index, choice)| {
172 let message = choice["message"].as_object().ok_or_else(|| {
173 AiLibError::ProviderError("Invalid choice format".to_string())
174 })?;
175 let role = match message["role"].as_str().unwrap_or("user") {
176 "system" => Role::System,
177 "assistant" => Role::Assistant,
178 _ => Role::User,
179 };
180 let content = message["content"].as_str().unwrap_or("").to_string();
181
182 let mut function_call: Option<crate::types::function_call::FunctionCall> = None;
184 if let Some(fc_val) = message.get("function_call") {
185 if let Ok(fc) = serde_json::from_value::<
186 crate::types::function_call::FunctionCall,
187 >(fc_val.clone())
188 {
189 function_call = Some(fc);
190 } else if let Some(name) = fc_val
191 .get("name")
192 .and_then(|v| v.as_str())
193 .map(|s| s.to_string())
194 {
195 let args = fc_val.get("arguments").and_then(|a| {
196 if a.is_string() {
197 serde_json::from_str::<serde_json::Value>(a.as_str().unwrap()).ok()
198 } else {
199 Some(a.clone())
200 }
201 });
202 function_call = Some(crate::types::function_call::FunctionCall {
203 name,
204 arguments: args,
205 });
206 }
207 } else if let Some(tool_calls) = message.get("tool_calls").and_then(|v| v.as_array()) {
208 if let Some(first) = tool_calls.first() {
209 if let Some(func) = first.get("function").or_else(|| first.get("function_call")) {
210 if let Some(name) = func.get("name").and_then(|v| v.as_str()) {
211 let mut args_opt = func.get("arguments").cloned();
212 if let Some(args_val) = &args_opt {
213 if args_val.is_string() {
214 if let Some(s) = args_val.as_str() {
215 if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(s) {
216 args_opt = Some(parsed);
217 }
218 }
219 }
220 }
221 function_call = Some(crate::types::function_call::FunctionCall {
222 name: name.to_string(),
223 arguments: args_opt,
224 });
225 }
226 }
227 }
228 }
229
230 Ok(Choice {
231 index: index as u32,
232 message: Message {
233 role,
234 content: crate::types::common::Content::Text(content),
235 function_call,
236 },
237 finish_reason: choice["finish_reason"].as_str().map(|s| s.to_string()),
238 })
239 })
240 .collect::<Result<Vec<_>, AiLibError>>()?;
241
242 let usage = response["usage"].as_object().ok_or_else(|| {
243 AiLibError::ProviderError("Invalid response format: usage not found".to_string())
244 })?;
245 let usage = Usage {
246 prompt_tokens: usage["prompt_tokens"].as_u64().unwrap_or(0) as u32,
247 completion_tokens: usage["completion_tokens"].as_u64().unwrap_or(0) as u32,
248 total_tokens: usage["total_tokens"].as_u64().unwrap_or(0) as u32,
249 };
250
251 Ok(ChatCompletionResponse {
252 id: response["id"].as_str().unwrap_or_default().to_string(),
253 object: response["object"].as_str().unwrap_or_default().to_string(),
254 created: response["created"].as_u64().unwrap_or(0),
255 model: response["model"].as_str().unwrap_or_default().to_string(),
256 choices,
257 usage,
258 usage_status: UsageStatus::Finalized, })
260 }
261}
262
#[cfg(not(feature = "unified_sse"))]
/// Locate the end of the next complete SSE event in `buffer`.
///
/// Returns the byte index just past the first blank-line separator — either a
/// bare `\n\n` or a CRLF `\r\n\r\n` — or `None` while no full event is
/// buffered yet.
fn find_event_boundary(buffer: &[u8]) -> Option<usize> {
    for (i, pair) in buffer.windows(2).enumerate() {
        if pair == b"\n\n" {
            return Some(i + 2);
        }
        if buffer.len() >= i + 4 && &buffer[i..i + 4] == b"\r\n\r\n" {
            return Some(i + 4);
        }
    }
    None
}
282
283#[cfg(not(feature = "unified_sse"))]
284fn parse_sse_event(event_text: &str) -> Option<Result<Option<ChatCompletionChunk>, AiLibError>> {
285 for line in event_text.lines() {
286 let line = line.trim();
287 if let Some(stripped) = line.strip_prefix("data: ") {
288 let data = stripped;
289 if data == "[DONE]" {
290 return Some(Ok(None));
291 }
292 return Some(parse_chunk_data(data));
293 }
294 }
295 None
296}
297
298#[cfg(not(feature = "unified_sse"))]
299fn parse_chunk_data(data: &str) -> Result<Option<ChatCompletionChunk>, AiLibError> {
300 let json: serde_json::Value = serde_json::from_str(data)
301 .map_err(|e| AiLibError::ProviderError(format!("JSON parse error: {}", e)))?;
302 let mut choices_vec: Vec<ChoiceDelta> = Vec::new();
303 if let Some(arr) = json["choices"].as_array() {
304 for (index, choice) in arr.iter().enumerate() {
305 let delta = &choice["delta"];
306 let role = delta.get("role").and_then(|v| v.as_str()).map(|r| match r {
307 "assistant" => Role::Assistant,
308 "user" => Role::User,
309 "system" => Role::System,
310 _ => Role::Assistant,
311 });
312 let content = delta
313 .get("content")
314 .and_then(|v| v.as_str())
315 .map(|s| s.to_string());
316 let md = MessageDelta { role, content };
317 let cd = ChoiceDelta {
318 index: index as u32,
319 delta: md,
320 finish_reason: choice
321 .get("finish_reason")
322 .and_then(|v| v.as_str())
323 .map(|s| s.to_string()),
324 };
325 choices_vec.push(cd);
326 }
327 }
328
329 Ok(Some(ChatCompletionChunk {
330 id: json["id"].as_str().unwrap_or_default().to_string(),
331 object: json["object"]
332 .as_str()
333 .unwrap_or("chat.completion.chunk")
334 .to_string(),
335 created: json["created"].as_u64().unwrap_or(0),
336 model: json["model"].as_str().unwrap_or_default().to_string(),
337 choices: choices_vec,
338 }))
339}
340
/// Split `text` into chunks of roughly at most `max_len` bytes, preferring to
/// break at the last space inside the window. Used to simulate streaming from
/// a non-streaming completion.
///
/// Fixes over the previous version:
/// - `&text[start..end]` used to slice at an arbitrary byte offset and would
///   panic when `end` fell inside a multi-byte UTF-8 character; boundaries are
///   now snapped forward to the next char boundary (a chunk may exceed
///   `max_len` by at most 3 bytes).
/// - `max_len == 0` used to loop forever; it is now clamped to 1.
fn split_text_into_chunks(text: &str, max_len: usize) -> Vec<String> {
    let max_len = max_len.max(1);
    let bytes = text.as_bytes();
    let mut chunks = Vec::new();
    let mut start = 0;
    while start < bytes.len() {
        let mut end = std::cmp::min(start + max_len, bytes.len());
        // Never cut a multi-byte character in half.
        while end < bytes.len() && !text.is_char_boundary(end) {
            end += 1;
        }
        let mut cut = end;
        if end < bytes.len() {
            // Prefer breaking at the last space in the window; space is ASCII,
            // so searching raw bytes is boundary-safe.
            if let Some(pos) = bytes[start..end].iter().rposition(|&b| b == b' ') {
                cut = start + pos;
            }
        }
        if cut == start {
            // Window is one unbroken word (or starts with the space): hard cut.
            cut = end;
        }
        chunks.push(text[start..cut].to_string());
        start = cut;
        // Skip the space we broke on so the next chunk doesn't start with it.
        if start < bytes.len() && bytes[start] == b' ' {
            start += 1;
        }
    }
    chunks
}
365
366#[async_trait::async_trait]
367impl ChatApi for MistralAdapter {
368 async fn chat_completion(
369 &self,
370 request: ChatCompletionRequest,
371 ) -> Result<ChatCompletionResponse, AiLibError> {
372 self.metrics.incr_counter("mistral.requests", 1).await;
373 let timer = self
374 .metrics
375 .start_timer("mistral.request_duration_ms")
376 .await;
377
378 let url = format!("{}{}", self.base_url, "/v1/chat/completions");
379 let provider_request = self.convert_request(&request);
380 let mut headers = HashMap::new();
381 headers.insert("Content-Type".to_string(), "application/json".to_string());
382 if let Some(key) = &self.api_key {
383 headers.insert("Authorization".to_string(), format!("Bearer {}", key));
384 }
385 let response_json = self
386 .transport
387 .post_json(&url, Some(headers), provider_request)
388 .await?;
389 if let Some(t) = timer {
390 t.stop();
391 }
392 self.parse_response(response_json)
393 }
394
395 async fn chat_completion_stream(
396 &self,
397 request: ChatCompletionRequest,
398 ) -> Result<
399 Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
400 AiLibError,
401 > {
402 let mut stream_request = self.convert_request(&request);
403 stream_request["stream"] = serde_json::Value::Bool(true);
404
405 let url = format!("{}{}", self.base_url, "/v1/chat/completions");
406
407 let mut headers = HashMap::new();
408 headers.insert("Accept".to_string(), "text/event-stream".to_string());
409 if let Some(key) = &self.api_key {
410 headers.insert("Authorization".to_string(), format!("Bearer {}", key));
411 }
412 if let Ok(mut byte_stream) = self
413 .transport
414 .post_stream(&url, Some(headers.clone()), stream_request.clone())
415 .await
416 {
417 let (tx, rx) = mpsc::unbounded_channel();
418 tokio::spawn(async move {
419 let mut buffer = Vec::new();
420 while let Some(item) = byte_stream.next().await {
421 match item {
422 Ok(bytes) => {
423 buffer.extend_from_slice(&bytes);
424 #[cfg(feature = "unified_sse")]
425 {
426 while let Some(boundary) =
427 crate::sse::parser::find_event_boundary(&buffer)
428 {
429 let event_bytes = buffer.drain(..boundary).collect::<Vec<_>>();
430 if let Ok(event_text) = std::str::from_utf8(&event_bytes) {
431 if let Some(parsed) =
432 crate::sse::parser::parse_sse_event(event_text)
433 {
434 match parsed {
435 Ok(Some(chunk)) => {
436 if tx.send(Ok(chunk)).is_err() {
437 return;
438 }
439 }
440 Ok(None) => return,
441 Err(e) => {
442 let _ = tx.send(Err(e));
443 return;
444 }
445 }
446 }
447 }
448 }
449 }
450 #[cfg(not(feature = "unified_sse"))]
451 {
452 while let Some(boundary) = find_event_boundary(&buffer) {
453 let event_bytes = buffer.drain(..boundary).collect::<Vec<_>>();
454 if let Ok(event_text) = std::str::from_utf8(&event_bytes) {
455 if let Some(parsed) = parse_sse_event(event_text) {
456 match parsed {
457 Ok(Some(chunk)) => {
458 if tx.send(Ok(chunk)).is_err() {
459 return;
460 }
461 }
462 Ok(None) => return,
463 Err(e) => {
464 let _ = tx.send(Err(e));
465 return;
466 }
467 }
468 }
469 }
470 }
471 }
472 }
473 Err(e) => {
474 let _ = tx.send(Err(AiLibError::ProviderError(format!(
475 "Stream error: {}",
476 e
477 ))));
478 break;
479 }
480 }
481 }
482 });
483 let stream = UnboundedReceiverStream::new(rx);
484 return Ok(Box::new(Box::pin(stream)));
485 }
486
487 let finished = self.chat_completion(request).await?;
489 let text = finished
490 .choices
491 .first()
492 .map(|c| c.message.content.as_text())
493 .unwrap_or_default();
494 let (tx, rx) = mpsc::unbounded_channel();
495 tokio::spawn(async move {
496 let chunks = split_text_into_chunks(&text, 80);
497 for chunk in chunks {
498 let delta = ChoiceDelta {
499 index: 0,
500 delta: MessageDelta {
501 role: Some(Role::Assistant),
502 content: Some(chunk.clone()),
503 },
504 finish_reason: None,
505 };
506 let chunk_obj = ChatCompletionChunk {
507 id: "simulated".to_string(),
508 object: "chat.completion.chunk".to_string(),
509 created: 0,
510 model: finished.model.clone(),
511 choices: vec![delta],
512 };
513 if tx.send(Ok(chunk_obj)).is_err() {
514 return;
515 }
516 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
517 }
518 });
519 let stream = UnboundedReceiverStream::new(rx);
520 Ok(Box::new(Box::pin(stream)))
521 }
522
523 async fn list_models(&self) -> Result<Vec<String>, AiLibError> {
524 let url = format!("{}/v1/models", self.base_url);
526 let mut headers = HashMap::new();
527 if let Some(key) = &self.api_key {
528 headers.insert("Authorization".to_string(), format!("Bearer {}", key));
529 }
530 let response = self.transport.get_json(&url, Some(headers)).await?;
531 Ok(response["data"]
532 .as_array()
533 .unwrap_or(&vec![])
534 .iter()
535 .filter_map(|m| m["id"].as_str().map(|s| s.to_string()))
536 .collect())
537 }
538
539 async fn get_model_info(&self, model_id: &str) -> Result<crate::api::ModelInfo, AiLibError> {
540 Ok(ModelInfo {
541 id: model_id.to_string(),
542 object: "model".to_string(),
543 created: 0,
544 owned_by: "mistral".to_string(),
545 permission: vec![ModelPermission {
546 id: "default".to_string(),
547 object: "model_permission".to_string(),
548 created: 0,
549 allow_create_engine: false,
550 allow_sampling: true,
551 allow_logprobs: false,
552 allow_search_indices: false,
553 allow_view: true,
554 allow_fine_tuning: false,
555 organization: "*".to_string(),
556 group: None,
557 is_blocking: false,
558 }],
559 })
560 }
561}