oxify_connect_llm/mistral.rs

//! Mistral AI LLM provider

use crate::{
    EmbeddingProvider, EmbeddingRequest, EmbeddingResponse, EmbeddingUsage, LlmChunk, LlmError,
    LlmProvider, LlmRequest, LlmResponse, LlmStream, Result, StreamUsage, StreamingLlmProvider,
    Usage,
};
use async_trait::async_trait;
use futures::stream::StreamExt;
use serde::{Deserialize, Serialize};
use std::time::Duration;

/// Mistral AI provider
pub struct MistralProvider {
    api_key: String,
    model: String,
    client: reqwest::Client,
    base_url: String,
}

#[derive(Serialize)]
struct MistralRequest {
    model: String,
    messages: Vec<MistralMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
}
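// Illustrative serialized form of `MistralRequest` (values are made up;
// `None` fields are dropped by the `skip_serializing_if` attributes above):
//
//   {"model":"mistral-small-latest","messages":[{"role":"user","content":"Hi"}],"temperature":0.7}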

#[derive(Serialize, Deserialize)]
struct MistralMessage {
    role: String,
    content: String,
}

#[derive(Deserialize)]
struct MistralResponse {
    choices: Vec<MistralChoice>,
    usage: MistralUsage,
    model: String,
}

#[derive(Deserialize)]
struct MistralChoice {
    message: MistralMessage,
}

#[derive(Deserialize)]
struct MistralUsage {
    prompt_tokens: u32,
    completion_tokens: u32,
    total_tokens: u32,
}

impl MistralProvider {
    /// Create a new Mistral provider
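    ///
    /// A minimal construction sketch. The model name is illustrative, and the
    /// import path assumes this module is reachable as
    /// `oxify_connect_llm::mistral`; adjust to the crate's actual re-exports.
    ///
    /// ```ignore
    /// use oxify_connect_llm::mistral::MistralProvider;
    ///
    /// let provider = MistralProvider::new(
    ///     std::env::var("MISTRAL_API_KEY").expect("MISTRAL_API_KEY not set"),
    ///     "mistral-small-latest".to_string(),
    /// );
    /// ```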
    pub fn new(api_key: String, model: String) -> Self {
        Self {
            api_key,
            model,
            client: reqwest::Client::new(),
            base_url: "https://api.mistral.ai/v1".to_string(),
        }
    }

    /// Create a provider specifically for embeddings
    pub fn for_embeddings(api_key: String) -> Self {
        Self::new(api_key, "mistral-embed".to_string())
    }

    /// Set custom base URL
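    ///
    /// A usage sketch for pointing the provider at a compatible proxy or
    /// gateway (the URL below is a placeholder):
    ///
    /// ```ignore
    /// let provider = MistralProvider::new(api_key, "mistral-small-latest".to_string())
    ///     .with_base_url("https://gateway.example.com/v1".to_string());
    /// ```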
    pub fn with_base_url(mut self, base_url: String) -> Self {
        self.base_url = base_url;
        self
    }
}

#[async_trait]
impl LlmProvider for MistralProvider {
    async fn complete(&self, request: LlmRequest) -> Result<LlmResponse> {
        let mut messages = Vec::new();

        // Add system message if provided
        if let Some(system_prompt) = &request.system_prompt {
            messages.push(MistralMessage {
                role: "system".to_string(),
                content: system_prompt.clone(),
            });
        }

        // Add user message
        messages.push(MistralMessage {
            role: "user".to_string(),
            content: request.prompt.clone(),
        });

        let mistral_request = MistralRequest {
            model: self.model.clone(),
            messages,
            temperature: request.temperature,
            max_tokens: request.max_tokens,
        };

        let response = self
            .client
            .post(format!("{}/chat/completions", self.base_url))
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&mistral_request)
            .send()
            .await?;

        let status = response.status();

        if status == 429 {
            // Extract Retry-After header if present
            let retry_after = response
                .headers()
                .get("retry-after")
                .and_then(|v| v.to_str().ok())
                .and_then(|s| s.parse::<u64>().ok())
                .map(Duration::from_secs);

            return Err(LlmError::RateLimited(retry_after));
        }

        let body = response.text().await?;

        if !status.is_success() {
            return Err(LlmError::ApiError(format!("HTTP {}: {}", status, body)));
        }

        let mistral_response: MistralResponse =
            serde_json::from_str(&body).map_err(|e| LlmError::SerializationError(e.to_string()))?;

        if mistral_response.choices.is_empty() {
            return Err(LlmError::ApiError("No choices in response".to_string()));
        }

        Ok(LlmResponse {
            content: mistral_response.choices[0].message.content.clone(),
            model: mistral_response.model,
            usage: Some(Usage {
                prompt_tokens: mistral_response.usage.prompt_tokens,
                completion_tokens: mistral_response.usage.completion_tokens,
                total_tokens: mistral_response.usage.total_tokens,
            }),
            tool_calls: Vec::new(),
        })
    }
}

// ===== Mistral Streaming Implementation =====
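//
// The streaming endpoint replies with Server-Sent Events. Each event handled by
// the parser below looks roughly like this (illustrative payload; only the
// fields modeled by `MistralStreamChunk` are shown), terminated by a `[DONE]`
// sentinel:
//
//   data: {"model":"mistral-small-latest","choices":[{"delta":{"content":"Hello"},"finish_reason":null}]}
//   data: [DONE]
//
// Note: the parser works on one network chunk at a time, so it assumes each
// `data:` line arrives whole within a single chunk and emits at most one parsed
// `LlmChunk` per network read.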

#[derive(Deserialize)]
struct MistralStreamChunk {
    choices: Vec<MistralStreamChoice>,
    #[serde(default)]
    usage: Option<MistralUsage>,
    model: String,
}

#[derive(Deserialize)]
struct MistralStreamChoice {
    delta: MistralDelta,
    finish_reason: Option<String>,
}

#[derive(Deserialize)]
struct MistralDelta {
    #[serde(default)]
    content: Option<String>,
}

#[async_trait]
impl StreamingLlmProvider for MistralProvider {
    async fn complete_stream(&self, request: LlmRequest) -> Result<LlmStream> {
        let mut messages = Vec::new();

        if let Some(system_prompt) = &request.system_prompt {
            messages.push(MistralMessage {
                role: "system".to_string(),
                content: system_prompt.clone(),
            });
        }

        messages.push(MistralMessage {
            role: "user".to_string(),
            content: request.prompt.clone(),
        });

        let mistral_request = serde_json::json!({
            "model": self.model,
            "messages": messages,
            "temperature": request.temperature,
            "max_tokens": request.max_tokens,
            "stream": true
        });

        let response = self
            .client
            .post(format!("{}/chat/completions", self.base_url))
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&mistral_request)
            .send()
            .await?;

        let status = response.status();
        if status == 429 {
            // Extract Retry-After header if present
            let retry_after = response
                .headers()
                .get("retry-after")
                .and_then(|v| v.to_str().ok())
                .and_then(|s| s.parse::<u64>().ok())
                .map(Duration::from_secs);

            return Err(LlmError::RateLimited(retry_after));
        }

        if !status.is_success() {
            let body = response.text().await?;
            return Err(LlmError::ApiError(format!("HTTP {}: {}", status, body)));
        }

        let stream = response.bytes_stream();

        let parsed_stream = stream.filter_map(|chunk_result| async move {
            match chunk_result {
                Ok(bytes) => {
                    let text = String::from_utf8_lossy(&bytes);
                    for line in text.lines() {
                        if let Some(data) = line.strip_prefix("data: ") {
                            if data == "[DONE]" {
                                return Some(Ok(LlmChunk {
                                    content: String::new(),
                                    done: true,
                                    model: None,
                                    usage: None,
                                }));
                            }

                            if let Ok(chunk) = serde_json::from_str::<MistralStreamChunk>(data) {
                                if let Some(choice) = chunk.choices.first() {
                                    let is_done = choice.finish_reason.is_some();
                                    let content = choice.delta.content.clone().unwrap_or_default();

                                    let usage = chunk.usage.as_ref().map(|u| StreamUsage {
                                        prompt_tokens: Some(u.prompt_tokens),
                                        completion_tokens: Some(u.completion_tokens),
                                        total_tokens: Some(u.total_tokens),
                                    });

                                    if !content.is_empty() || is_done {
                                        return Some(Ok(LlmChunk {
                                            content,
                                            done: is_done,
                                            model: if is_done { Some(chunk.model) } else { None },
                                            usage,
                                        }));
                                    }
                                }
                            }
                        }
                    }
                    None
                }
                Err(e) => Some(Err(LlmError::NetworkError(e))),
            }
        });

        Ok(Box::pin(parsed_stream))
    }
}
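
// Consuming the stream returned by `complete_stream` (sketch; assumes an async
// runtime such as Tokio, `futures::StreamExt` in scope, and `provider`/`request`
// built as shown above):
//
//   let mut stream = provider.complete_stream(request).await?;
//   while let Some(chunk) = stream.next().await {
//       let chunk = chunk?;
//       print!("{}", chunk.content);
//       if chunk.done {
//           break;
//       }
//   }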

// ===== Mistral Embeddings Implementation =====
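//
// Request body sent to POST {base_url}/embeddings (shape mirrors
// `MistralEmbeddingRequest` below; values are illustrative):
//
//   {"input":["first text","second text"],"model":"mistral-embed"}
//
// The response carries one embedding vector per input, tagged with its `index`,
// which `embed` uses to restore the original order.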

#[derive(Serialize)]
struct MistralEmbeddingRequest {
    input: Vec<String>,
    model: String,
}

#[derive(Deserialize)]
struct MistralEmbeddingResponse {
    data: Vec<MistralEmbeddingData>,
    model: String,
    usage: MistralEmbeddingUsage,
}

#[derive(Deserialize)]
struct MistralEmbeddingData {
    embedding: Vec<f32>,
    index: usize,
}

#[derive(Deserialize)]
struct MistralEmbeddingUsage {
    prompt_tokens: u32,
    total_tokens: u32,
}

#[async_trait]
impl EmbeddingProvider for MistralProvider {
    async fn embed(&self, request: EmbeddingRequest) -> Result<EmbeddingResponse> {
        let model = request.model.unwrap_or_else(|| self.model.clone());

        let mistral_request = MistralEmbeddingRequest {
            input: request.texts,
            model,
        };

        let response = self
            .client
            .post(format!("{}/embeddings", self.base_url))
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&mistral_request)
            .send()
            .await?;

        let status = response.status();

        if status == 429 {
            // Extract Retry-After header if present
            let retry_after = response
                .headers()
                .get("retry-after")
                .and_then(|v| v.to_str().ok())
                .and_then(|s| s.parse::<u64>().ok())
                .map(Duration::from_secs);

            return Err(LlmError::RateLimited(retry_after));
        }

        let body = response.text().await?;

        if !status.is_success() {
            return Err(LlmError::ApiError(format!("HTTP {}: {}", status, body)));
        }

        let mistral_response: MistralEmbeddingResponse =
            serde_json::from_str(&body).map_err(|e| LlmError::SerializationError(e.to_string()))?;

        // Sort by index to ensure correct order
        let mut data = mistral_response.data;
        data.sort_by_key(|d| d.index);

        Ok(EmbeddingResponse {
            embeddings: data.into_iter().map(|d| d.embedding).collect(),
            model: mistral_response.model,
            usage: Some(EmbeddingUsage {
                prompt_tokens: mistral_response.usage.prompt_tokens,
                total_tokens: mistral_response.usage.total_tokens,
            }),
        })
    }
}