use crate::{
    EmbeddingProvider, EmbeddingRequest, EmbeddingResponse, EmbeddingUsage, LlmChunk, LlmError,
    LlmProvider, LlmRequest, LlmResponse, LlmStream, Result, StreamUsage, StreamingLlmProvider,
    Usage,
};
use async_trait::async_trait;
use futures::stream::StreamExt;
use serde::{Deserialize, Serialize};
use std::time::Duration;

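/// LLM provider backed by the Mistral AI HTTP API.
///
/// Implements [`LlmProvider`], [`StreamingLlmProvider`], and [`EmbeddingProvider`]
/// against the `/chat/completions` and `/embeddings` endpoints.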
pub struct MistralProvider {
    api_key: String,
    model: String,
    client: reqwest::Client,
    base_url: String,
}

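/// JSON body for Mistral's `/chat/completions` endpoint (non-streaming).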
#[derive(Serialize)]
struct MistralRequest {
    model: String,
    messages: Vec<MistralMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
}

#[derive(Serialize, Deserialize)]
struct MistralMessage {
    role: String,
    content: String,
}

#[derive(Deserialize)]
struct MistralResponse {
    choices: Vec<MistralChoice>,
    usage: MistralUsage,
    model: String,
}

#[derive(Deserialize)]
struct MistralChoice {
    message: MistralMessage,
}

#[derive(Deserialize)]
struct MistralUsage {
    prompt_tokens: u32,
    completion_tokens: u32,
    total_tokens: u32,
}

impl MistralProvider {
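    /// Creates a provider for the given API key and model, targeting the default
    /// `https://api.mistral.ai/v1` base URL.
    ///
    /// ```ignore
    /// // Sketch only: the model name below is illustrative.
    /// let provider = MistralProvider::new(
    ///     std::env::var("MISTRAL_API_KEY").unwrap(),
    ///     "mistral-small-latest".to_string(),
    /// );
    /// ```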
    pub fn new(api_key: String, model: String) -> Self {
        Self {
            api_key,
            model,
            client: reqwest::Client::new(),
            base_url: "https://api.mistral.ai/v1".to_string(),
        }
    }

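    /// Convenience constructor that defaults the model to `mistral-embed`,
    /// Mistral's embedding model.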
    pub fn for_embeddings(api_key: String) -> Self {
        Self::new(api_key, "mistral-embed".to_string())
    }

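    /// Overrides the API base URL, e.g. to point at a proxy or a mock server in tests.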
    pub fn with_base_url(mut self, base_url: String) -> Self {
        self.base_url = base_url;
        self
    }
}

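// Non-streaming chat completion: builds the message list (optional system prompt
// plus the user prompt), POSTs it to `/chat/completions`, and maps HTTP 429 to
// `LlmError::RateLimited` with any `retry-after` hint from the response headers.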
#[async_trait]
impl LlmProvider for MistralProvider {
    async fn complete(&self, request: LlmRequest) -> Result<LlmResponse> {
        let mut messages = Vec::new();

        if let Some(system_prompt) = &request.system_prompt {
            messages.push(MistralMessage {
                role: "system".to_string(),
                content: system_prompt.clone(),
            });
        }

        messages.push(MistralMessage {
            role: "user".to_string(),
            content: request.prompt.clone(),
        });

        let mistral_request = MistralRequest {
            model: self.model.clone(),
            messages,
            temperature: request.temperature,
            max_tokens: request.max_tokens,
        };

        let response = self
            .client
            .post(format!("{}/chat/completions", self.base_url))
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&mistral_request)
            .send()
            .await?;

        let status = response.status();

        if status == 429 {
            let retry_after = response
                .headers()
                .get("retry-after")
                .and_then(|v| v.to_str().ok())
                .and_then(|s| s.parse::<u64>().ok())
                .map(Duration::from_secs);

            return Err(LlmError::RateLimited(retry_after));
        }

        let body = response.text().await?;

        if !status.is_success() {
            return Err(LlmError::ApiError(format!("HTTP {}: {}", status, body)));
        }

        let mistral_response: MistralResponse =
            serde_json::from_str(&body).map_err(|e| LlmError::SerializationError(e.to_string()))?;

        if mistral_response.choices.is_empty() {
            return Err(LlmError::ApiError("No choices in response".to_string()));
        }

        Ok(LlmResponse {
            content: mistral_response.choices[0].message.content.clone(),
            model: mistral_response.model,
            usage: Some(Usage {
                prompt_tokens: mistral_response.usage.prompt_tokens,
                completion_tokens: mistral_response.usage.completion_tokens,
                total_tokens: mistral_response.usage.total_tokens,
            }),
            tool_calls: Vec::new(),
        })
    }
}

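/// One parsed server-sent-event payload from the streaming `/chat/completions` endpoint.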
#[derive(Deserialize)]
struct MistralStreamChunk {
    choices: Vec<MistralStreamChoice>,
    #[serde(default)]
    usage: Option<MistralUsage>,
    model: String,
}

#[derive(Deserialize)]
struct MistralStreamChoice {
    delta: MistralDelta,
    finish_reason: Option<String>,
}

#[derive(Deserialize)]
struct MistralDelta {
    #[serde(default)]
    content: Option<String>,
}

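// Streaming chat completion over SSE: each `data: ...` line is decoded into a
// `MistralStreamChunk` and forwarded as an `LlmChunk`; `data: [DONE]` marks the end
// of the stream. Note that this parser emits at most one chunk per network read and
// assumes `data:` lines are not split across reads.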
#[async_trait]
impl StreamingLlmProvider for MistralProvider {
    async fn complete_stream(&self, request: LlmRequest) -> Result<LlmStream> {
        let mut messages = Vec::new();

        if let Some(system_prompt) = &request.system_prompt {
            messages.push(MistralMessage {
                role: "system".to_string(),
                content: system_prompt.clone(),
            });
        }

        messages.push(MistralMessage {
            role: "user".to_string(),
            content: request.prompt.clone(),
        });

        let mistral_request = serde_json::json!({
            "model": self.model,
            "messages": messages,
            "temperature": request.temperature,
            "max_tokens": request.max_tokens,
            "stream": true
        });

        let response = self
            .client
            .post(format!("{}/chat/completions", self.base_url))
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&mistral_request)
            .send()
            .await?;

        let status = response.status();
        if status == 429 {
            let retry_after = response
                .headers()
                .get("retry-after")
                .and_then(|v| v.to_str().ok())
                .and_then(|s| s.parse::<u64>().ok())
                .map(Duration::from_secs);

            return Err(LlmError::RateLimited(retry_after));
        }

        if !status.is_success() {
            let body = response.text().await?;
            return Err(LlmError::ApiError(format!("HTTP {}: {}", status, body)));
        }

        let stream = response.bytes_stream();

        let parsed_stream = stream.filter_map(|chunk_result| async move {
            match chunk_result {
                Ok(bytes) => {
                    let text = String::from_utf8_lossy(&bytes);
                    for line in text.lines() {
                        if let Some(data) = line.strip_prefix("data: ") {
                            if data == "[DONE]" {
                                return Some(Ok(LlmChunk {
                                    content: String::new(),
                                    done: true,
                                    model: None,
                                    usage: None,
                                }));
                            }

                            if let Ok(chunk) = serde_json::from_str::<MistralStreamChunk>(data) {
                                if let Some(choice) = chunk.choices.first() {
                                    let is_done = choice.finish_reason.is_some();
                                    let content = choice.delta.content.clone().unwrap_or_default();

                                    let usage = chunk.usage.as_ref().map(|u| StreamUsage {
                                        prompt_tokens: Some(u.prompt_tokens),
                                        completion_tokens: Some(u.completion_tokens),
                                        total_tokens: Some(u.total_tokens),
                                    });

                                    if !content.is_empty() || is_done {
                                        return Some(Ok(LlmChunk {
                                            content,
                                            done: is_done,
                                            model: if is_done { Some(chunk.model) } else { None },
                                            usage,
                                        }));
                                    }
                                }
                            }
                        }
                    }
                    None
                }
                Err(e) => Some(Err(LlmError::NetworkError(e))),
            }
        });

        Ok(Box::pin(parsed_stream))
    }
}

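/// JSON body for Mistral's `/embeddings` endpoint.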
#[derive(Serialize)]
struct MistralEmbeddingRequest {
    input: Vec<String>,
    model: String,
}

#[derive(Deserialize)]
struct MistralEmbeddingResponse {
    data: Vec<MistralEmbeddingData>,
    model: String,
    usage: MistralEmbeddingUsage,
}

#[derive(Deserialize)]
struct MistralEmbeddingData {
    embedding: Vec<f32>,
    index: usize,
}

#[derive(Deserialize)]
struct MistralEmbeddingUsage {
    prompt_tokens: u32,
    total_tokens: u32,
}

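// Embeddings: POSTs the input texts to `/embeddings`, falling back to the provider's
// configured model when the request does not specify one, and re-sorts the returned
// vectors by `index` so they line up with the input order.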
#[async_trait]
impl EmbeddingProvider for MistralProvider {
    async fn embed(&self, request: EmbeddingRequest) -> Result<EmbeddingResponse> {
        let model = request.model.unwrap_or_else(|| self.model.clone());

        let mistral_request = MistralEmbeddingRequest {
            input: request.texts,
            model,
        };

        let response = self
            .client
            .post(format!("{}/embeddings", self.base_url))
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&mistral_request)
            .send()
            .await?;

        let status = response.status();

        if status == 429 {
            let retry_after = response
                .headers()
                .get("retry-after")
                .and_then(|v| v.to_str().ok())
                .and_then(|s| s.parse::<u64>().ok())
                .map(Duration::from_secs);

            return Err(LlmError::RateLimited(retry_after));
        }

        let body = response.text().await?;

        if !status.is_success() {
            return Err(LlmError::ApiError(format!("HTTP {}: {}", status, body)));
        }

        let mistral_response: MistralEmbeddingResponse =
            serde_json::from_str(&body).map_err(|e| LlmError::SerializationError(e.to_string()))?;

        let mut data = mistral_response.data;
        data.sort_by_key(|d| d.index);

        Ok(EmbeddingResponse {
            embeddings: data.into_iter().map(|d| d.embedding).collect(),
            model: mistral_response.model,
            usage: Some(EmbeddingUsage {
                prompt_tokens: mistral_response.usage.prompt_tokens,
                total_tokens: mistral_response.usage.total_tokens,
            }),
        })
    }
}