1use crate::api::{ChatApi, ChatCompletionChunk, ChoiceDelta, MessageDelta, ModelInfo, ModelPermission};
2use crate::types::{ChatCompletionRequest, ChatCompletionResponse, AiLibError, Message, Role, Choice, Usage};
3use crate::transport::{HttpTransport, DynHttpTransportRef};
4use futures::stream::Stream;
5use std::collections::HashMap;
6use tokio_stream::wrappers::UnboundedReceiverStream;
7use tokio::sync::mpsc;
8use futures::StreamExt;
9
10pub struct MistralAdapter {
15 transport: DynHttpTransportRef,
16 api_key: Option<String>,
17 base_url: String,
18}
19
20impl MistralAdapter {
21 pub fn new() -> Result<Self, AiLibError> {
22 let api_key = std::env::var("MISTRAL_API_KEY").ok();
23 let base_url = std::env::var("MISTRAL_BASE_URL").unwrap_or_else(|_| "https://api.mistral.ai".to_string());
24 let boxed = HttpTransport::new().boxed();
25 Ok(Self { transport: boxed, api_key, base_url })
26 }
27
28 pub fn with_transport(transport: DynHttpTransportRef, api_key: Option<String>, base_url: String) -> Result<Self, AiLibError> {
30 Ok(Self { transport, api_key, base_url })
31 }
32
33 fn convert_request(&self, request: &ChatCompletionRequest) -> serde_json::Value {
34 let msgs: Vec<serde_json::Value> = request.messages.iter().map(|msg| {
35 serde_json::json!({
36 "role": match msg.role { Role::System => "system", Role::User => "user", Role::Assistant => "assistant" },
37 "content": msg.content
38 })
39 }).collect();
40
41 let mut body = serde_json::json!({ "model": request.model, "messages": msgs });
42 if let Some(temp) = request.temperature { body["temperature"] = serde_json::Value::Number(serde_json::Number::from_f64(temp.into()).unwrap()); }
43 if let Some(max_tokens) = request.max_tokens { body["max_tokens"] = serde_json::Value::Number(serde_json::Number::from(max_tokens)); }
44 body
45 }
46
47 fn parse_response(&self, response: serde_json::Value) -> Result<ChatCompletionResponse, AiLibError> {
48 let choices = response["choices"].as_array()
49 .unwrap_or(&vec![])
50 .iter()
51 .enumerate()
52 .map(|(index, choice)| {
53 let message = choice["message"].as_object().ok_or_else(|| AiLibError::ProviderError("Invalid choice format".to_string()))?;
54 let role = match message["role"].as_str().unwrap_or("user") { "system" => Role::System, "assistant" => Role::Assistant, _ => Role::User };
55 let content = message["content"].as_str().unwrap_or("").to_string();
56 Ok(Choice { index: index as u32, message: Message { role, content }, finish_reason: choice["finish_reason"].as_str().map(|s| s.to_string()) })
57 })
58 .collect::<Result<Vec<_>, AiLibError>>()?;
59
60 let usage = response["usage"].as_object().ok_or_else(|| AiLibError::ProviderError("Invalid response format: usage not found".to_string()))?;
61 let usage = Usage { prompt_tokens: usage["prompt_tokens"].as_u64().unwrap_or(0) as u32, completion_tokens: usage["completion_tokens"].as_u64().unwrap_or(0) as u32, total_tokens: usage["total_tokens"].as_u64().unwrap_or(0) as u32 };
62
63 Ok(ChatCompletionResponse { id: response["id"].as_str().unwrap_or_default().to_string(), object: response["object"].as_str().unwrap_or_default().to_string(), created: response["created"].as_u64().unwrap_or(0), model: response["model"].as_str().unwrap_or_default().to_string(), choices, usage })
64 }
65}
66
/// Locates the end of the first complete SSE event in `buffer`.
///
/// An event ends at a blank line, written either as `\n\n` or as
/// `\r\n\r\n`. Returns the byte offset just past that terminator, which is
/// the length of the prefix to drain. Returns `None` when no complete event
/// is buffered yet.
fn find_event_boundary(buffer: &[u8]) -> Option<usize> {
    (0..buffer.len()).find_map(|offset| {
        let tail = &buffer[offset..];
        if tail.starts_with(b"\n\n") {
            Some(offset + 2)
        } else if tail.starts_with(b"\r\n\r\n") {
            Some(offset + 4)
        } else {
            None
        }
    })
}
76
77fn parse_sse_event(event_text: &str) -> Option<Result<Option<ChatCompletionChunk>, AiLibError>> {
78 for line in event_text.lines() {
79 let line = line.trim();
80 if line.starts_with("data: ") {
81 let data = &line[6..];
82 if data == "[DONE]" { return Some(Ok(None)); }
83 return Some(parse_chunk_data(data));
84 }
85 }
86 None
87}
88
89fn parse_chunk_data(data: &str) -> Result<Option<ChatCompletionChunk>, AiLibError> {
90 let json: serde_json::Value = serde_json::from_str(data).map_err(|e| AiLibError::ProviderError(format!("JSON parse error: {}", e)))?;
91 let mut choices_vec: Vec<ChoiceDelta> = Vec::new();
92 if let Some(arr) = json["choices"].as_array() {
93 for (index, choice) in arr.iter().enumerate() {
94 let delta = &choice["delta"];
95 let role = delta.get("role").and_then(|v| v.as_str()).map(|r| match r { "assistant" => Role::Assistant, "user" => Role::User, "system" => Role::System, _ => Role::Assistant });
96 let content = delta.get("content").and_then(|v| v.as_str()).map(|s| s.to_string());
97 let md = MessageDelta { role, content };
98 let cd = ChoiceDelta { index: index as u32, delta: md, finish_reason: choice.get("finish_reason").and_then(|v| v.as_str()).map(|s| s.to_string()) };
99 choices_vec.push(cd);
100 }
101 }
102
103 Ok(Some(ChatCompletionChunk {
104 id: json["id"].as_str().unwrap_or_default().to_string(),
105 object: json["object"].as_str().unwrap_or("chat.completion.chunk").to_string(),
106 created: json["created"].as_u64().unwrap_or(0),
107 model: json["model"].as_str().unwrap_or_default().to_string(),
108 choices: choices_vec,
109 }))
110}
111
/// Splits `text` into pieces of at most `max_len` bytes, preferring to break
/// at the last space inside each window.
///
/// The space used as a break point is dropped. Cuts always fall on UTF-8
/// character boundaries, so multi-byte text never panics and is never
/// mangled. If `max_len` is smaller than the next character (including
/// `max_len == 0`), that one character is emitted on its own, so the loop
/// always makes progress. Empty input yields an empty vector.
fn split_text_into_chunks(text: &str, max_len: usize) -> Vec<String> {
    let mut chunks = Vec::new();
    let mut rest = text;
    while !rest.is_empty() {
        if rest.len() <= max_len {
            chunks.push(rest.to_string());
            break;
        }
        // Largest char boundary that does not exceed the byte budget.
        let mut end = max_len;
        while !rest.is_char_boundary(end) {
            end -= 1;
        }
        // Break at the last space in the window, unless it is the first byte
        // (which would produce an empty piece).
        let mut cut = match rest[..end].rfind(' ') {
            Some(pos) if pos > 0 => pos,
            _ => end,
        };
        if cut == 0 {
            // Budget cannot hold even one character: emit one whole character.
            cut = rest.chars().next().map_or(rest.len(), char::len_utf8);
        }
        chunks.push(rest[..cut].to_string());
        rest = &rest[cut..];
        if let Some(stripped) = rest.strip_prefix(' ') {
            rest = stripped;
        }
    }
    chunks
}
130
131#[async_trait::async_trait]
132impl ChatApi for MistralAdapter {
133 async fn chat_completion(&self, request: ChatCompletionRequest) -> Result<ChatCompletionResponse, AiLibError> {
134 let provider_request = self.convert_request(&request);
135 let url = format!("{}{}", self.base_url, "/v1/chat/completions");
136
137 let mut headers = HashMap::new();
138 headers.insert("Content-Type".to_string(), "application/json".to_string());
139 if let Some(key) = &self.api_key { headers.insert("Authorization".to_string(), format!("Bearer {}", key)); }
140
141 let response: serde_json::Value = self.transport.post_json(&url, Some(headers), provider_request).await?;
142 self.parse_response(response)
143 }
144
145 async fn chat_completion_stream(&self, request: ChatCompletionRequest) -> Result<Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>, AiLibError> {
146 let mut stream_request = self.convert_request(&request);
147 stream_request["stream"] = serde_json::Value::Bool(true);
148
149 let url = format!("{}{}", self.base_url, "/v1/chat/completions");
150
151 let mut client_builder = reqwest::Client::builder();
153 if let Ok(proxy_url) = std::env::var("AI_PROXY_URL") {
154 if let Ok(proxy) = reqwest::Proxy::all(&proxy_url) { client_builder = client_builder.proxy(proxy); }
155 }
156 let client = client_builder.build().map_err(|e| AiLibError::ProviderError(format!("Client error: {}", e)))?;
157
158 let mut headers = HashMap::new();
159 headers.insert("Accept".to_string(), "text/event-stream".to_string());
160 if let Some(key) = &self.api_key { headers.insert("Authorization".to_string(), format!("Bearer {}", key)); }
161
162 let response = client.post(&url).json(&stream_request);
163 let mut req = response;
164 for (k, v) in headers.clone() { req = req.header(k, v); }
165
166 let send_result = req.send().await;
167 match send_result {
168 Ok(response) => {
169 if response.status().is_success() {
170 let (tx, rx) = mpsc::unbounded_channel();
171 tokio::spawn(async move {
172 let mut buffer = Vec::new();
173 let mut stream = response.bytes_stream();
174 while let Some(item) = stream.next().await {
175 match item {
176 Ok(bytes) => {
177 buffer.extend_from_slice(&bytes);
178 while let Some(boundary) = find_event_boundary(&buffer) {
179 let event_bytes = buffer.drain(..boundary).collect::<Vec<_>>();
180 if let Ok(event_text) = std::str::from_utf8(&event_bytes) {
181 if let Some(parsed) = parse_sse_event(event_text) {
182 match parsed {
183 Ok(Some(chunk)) => { if tx.send(Ok(chunk)).is_err() { return; } }
184 Ok(None) => return,
185 Err(e) => { let _ = tx.send(Err(e)); return; }
186 }
187 }
188 }
189 }
190 }
191 Err(e) => { let _ = tx.send(Err(AiLibError::ProviderError(format!("Stream error: {}", e)))); break; }
192 }
193 }
194 });
195 let stream = UnboundedReceiverStream::new(rx);
196 return Ok(Box::new(Box::pin(stream)));
197 }
198 }
199 Err(_) => {}
200 }
201
202 let finished = self.chat_completion(request).await?;
204 let text = finished.choices.get(0).map(|c| c.message.content.clone()).unwrap_or_default();
205 let (tx, rx) = mpsc::unbounded_channel();
206 tokio::spawn(async move {
207 let chunks = split_text_into_chunks(&text, 80);
208 for chunk in chunks {
209 let delta = ChoiceDelta { index: 0, delta: MessageDelta { role: Some(Role::Assistant), content: Some(chunk.clone()) }, finish_reason: None };
210 let chunk_obj = ChatCompletionChunk { id: "simulated".to_string(), object: "chat.completion.chunk".to_string(), created: 0, model: finished.model.clone(), choices: vec![delta] };
211 if tx.send(Ok(chunk_obj)).is_err() { return; }
212 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
213 }
214 });
215 let stream = UnboundedReceiverStream::new(rx);
216 Ok(Box::new(Box::pin(stream)))
217 }
218
219 async fn list_models(&self) -> Result<Vec<String>, AiLibError> {
220 let url = format!("{}/v1/models", self.base_url);
221 let mut headers = HashMap::new();
222 if let Some(key) = &self.api_key { headers.insert("Authorization".to_string(), format!("Bearer {}", key)); }
223 let response: serde_json::Value = self.transport.get_json(&url, Some(headers)).await?;
224 Ok(response["data"].as_array().unwrap_or(&vec![]).iter().filter_map(|m| m["id"].as_str().map(|s| s.to_string())).collect())
225 }
226
227 async fn get_model_info(&self, model_id: &str) -> Result<crate::api::ModelInfo, AiLibError> {
228 Ok(ModelInfo { id: model_id.to_string(), object: "model".to_string(), created: 0, owned_by: "mistral".to_string(), permission: vec![ModelPermission { id: "default".to_string(), object: "model_permission".to_string(), created: 0, allow_create_engine: false, allow_sampling: true, allow_logprobs: false, allow_search_indices: false, allow_view: true, allow_fine_tuning: false, organization: "*".to_string(), group: None, is_blocking: false }] })
229 }
230}