language_barrier_core/provider/mistral.rs

1use crate::error::{Error, Result};
2use crate::message::{Content, ContentPart, Message};
3use crate::provider::HTTPProvider;
4use crate::{Chat, LlmToolInfo, ModelInfo};
5use reqwest::{Method, Request, Url};
6use serde::{Deserialize, Serialize};
7use std::env;
8use tracing::{debug, error, info, instrument, trace, warn};
9
/// Configuration for the Mistral provider.
///
/// The `Debug` implementation never prints the API key. It shows
/// `"<redacted>"` when a key is set and `"<unset>"` when the key is empty, so
/// the config (and any type that derives `Debug` while embedding it) can be
/// logged without leaking credentials.
#[derive(Clone)]
pub struct MistralConfig {
    /// API key, sent as a bearer token in the `Authorization` header
    pub api_key: String,
    /// Base URL for the API; `/chat/completions` is appended to it
    pub base_url: String,
}

impl std::fmt::Debug for MistralConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only reveal whether a key is present, never its value.
        let api_key = if self.api_key.is_empty() {
            "<unset>"
        } else {
            "<redacted>"
        };
        f.debug_struct("MistralConfig")
            .field("api_key", &api_key)
            .field("base_url", &self.base_url)
            .finish()
    }
}
18
19impl Default for MistralConfig {
20    fn default() -> Self {
21        Self {
22            api_key: env::var("MISTRAL_API_KEY").unwrap_or_default(),
23            base_url: "https://api.mistral.ai/v1".to_string(),
24        }
25    }
26}
27
28/// Implementation of the Mistral provider
29#[derive(Debug, Clone)]
30pub struct MistralProvider {
31    /// Configuration for the provider
32    config: MistralConfig,
33}
34
35impl MistralProvider {
36    /// Creates a new MistralProvider with default configuration
37    ///
38    /// This method will use the MISTRAL_API_KEY environment variable for authentication.
39    ///
40    /// # Examples
41    ///
42    /// ```
43    /// use language_barrier_core::provider::mistral::MistralProvider;
44    ///
45    /// let provider = MistralProvider::new();
46    /// ```
47    #[instrument(level = "debug")]
48    pub fn new() -> Self {
49        info!("Creating new MistralProvider with default configuration");
50        let config = MistralConfig::default();
51        debug!("API key set: {}", !config.api_key.is_empty());
52        debug!("Base URL: {}", config.base_url);
53
54        Self { config }
55    }
56
57    /// Creates a new MistralProvider with custom configuration
58    ///
59    /// # Examples
60    ///
61    /// ```
62    /// use language_barrier_core::provider::mistral::{MistralProvider, MistralConfig};
63    ///
64    /// let config = MistralConfig {
65    ///     api_key: "your-api-key".to_string(),
66    ///     base_url: "https://api.mistral.ai/v1".to_string(),
67    /// };
68    ///
69    /// let provider = MistralProvider::with_config(config);
70    /// ```
71    #[instrument(skip(config), level = "debug")]
72    pub fn with_config(config: MistralConfig) -> Self {
73        info!("Creating new MistralProvider with custom configuration");
74        debug!("API key set: {}", !config.api_key.is_empty());
75        debug!("Base URL: {}", config.base_url);
76
77        Self { config }
78    }
79}
80
81impl Default for MistralProvider {
82    fn default() -> Self {
83        Self::new()
84    }
85}
86
/// Maps a model descriptor to the identifier the Mistral API expects.
///
/// The returned string is used verbatim as the `model` field of the request
/// payload built by [`MistralProvider`].
pub trait MistralModelInfo {
    /// Returns the Mistral model identifier for this model.
    fn mistral_model_id(&self) -> String;
}
91
92impl<M: ModelInfo + MistralModelInfo> HTTPProvider<M> for MistralProvider {
93    fn accept(&self, chat: Chat<M>) -> Result<Request> {
94        info!("Creating request for Mistral model: {:?}", chat.model);
95        debug!("Messages in chat history: {}", chat.history.len());
96
97        let url_str = format!("{}/chat/completions", self.config.base_url);
98        debug!("Parsing URL: {}", url_str);
99        let url = match Url::parse(&url_str) {
100            Ok(url) => {
101                debug!("URL parsed successfully: {}", url);
102                url
103            }
104            Err(e) => {
105                error!("Failed to parse URL '{}': {}", url_str, e);
106                return Err(e.into());
107            }
108        };
109
110        let mut request = Request::new(Method::POST, url);
111        debug!("Created request: {} {}", request.method(), request.url());
112
113        // Set headers
114        debug!("Setting request headers");
115
116        // API key as bearer token
117        let auth_header = match format!("Bearer {}", self.config.api_key).parse() {
118            Ok(header) => header,
119            Err(e) => {
120                error!("Invalid API key format: {}", e);
121                return Err(Error::Authentication("Invalid API key format".into()));
122            }
123        };
124
125        let content_type_header = match "application/json".parse() {
126            Ok(header) => header,
127            Err(e) => {
128                error!("Failed to set content type: {}", e);
129                return Err(Error::Other("Failed to set content type".into()));
130            }
131        };
132
133        request.headers_mut().insert("Authorization", auth_header);
134        request
135            .headers_mut()
136            .insert("Content-Type", content_type_header);
137
138        trace!("Request headers set: {:#?}", request.headers());
139
140        // Create the request payload
141        debug!("Creating request payload");
142        let payload = match self.create_request_payload(&chat) {
143            Ok(payload) => {
144                debug!("Request payload created successfully");
145                trace!("Model: {}", payload.model);
146                trace!("Max tokens: {:?}", payload.max_tokens);
147                trace!("Number of messages: {}", payload.messages.len());
148                payload
149            }
150            Err(e) => {
151                error!("Failed to create request payload: {}", e);
152                return Err(e);
153            }
154        };
155
156        // Set the request body
157        debug!("Serializing request payload");
158        let body_bytes = match serde_json::to_vec(&payload) {
159            Ok(bytes) => {
160                debug!("Payload serialized successfully ({} bytes)", bytes.len());
161                bytes
162            }
163            Err(e) => {
164                error!("Failed to serialize payload: {}", e);
165                return Err(Error::Serialization(e));
166            }
167        };
168
169        *request.body_mut() = Some(body_bytes.into());
170        info!("Request created successfully");
171
172        Ok(request)
173    }
174
175    fn parse(&self, raw_response_text: String) -> Result<Message> {
176        info!("Parsing response from Mistral API");
177        trace!("Raw response: {}", raw_response_text);
178
179        // First try to parse as an error response
180        if let Ok(error_response) = serde_json::from_str::<MistralErrorResponse>(&raw_response_text)
181        {
182            if error_response.error.is_some() {
183                let error = error_response.error.unwrap();
184                error!("Mistral API returned an error: {}", error.message);
185                return Err(Error::ProviderUnavailable(error.message));
186            }
187        }
188
189        // If not an error, parse as a successful response
190        debug!("Deserializing response JSON");
191        let mistral_response = match serde_json::from_str::<MistralResponse>(&raw_response_text) {
192            Ok(response) => {
193                debug!("Response deserialized successfully");
194                debug!("Response id: {}", response.id);
195                debug!("Response model: {}", response.model);
196                if !response.choices.is_empty() {
197                    debug!("Number of choices: {}", response.choices.len());
198                    debug!(
199                        "First choice finish reason: {:?}",
200                        response.choices[0].finish_reason
201                    );
202                }
203                if let Some(usage) = &response.usage {
204                    debug!(
205                        "Token usage - prompt: {}, completion: {}, total: {}",
206                        usage.prompt_tokens, usage.completion_tokens, usage.total_tokens
207                    );
208                }
209                response
210            }
211            Err(e) => {
212                error!("Failed to deserialize response: {}", e);
213                error!("Raw response: {}", raw_response_text);
214                return Err(Error::Serialization(e));
215            }
216        };
217
218        // Convert to our message format
219        debug!("Converting Mistral response to Message");
220        let message = Message::from(&mistral_response);
221
222        info!("Response parsed successfully");
223        trace!("Response message processed");
224
225        Ok(message)
226    }
227}
228
229impl MistralProvider {
230    /// Creates a request payload from a Chat object
231    ///
232    /// This method converts the Chat's messages and settings into a Mistral-specific
233    /// format for the API request.
234    #[instrument(skip(self, chat), level = "debug")]
235    fn create_request_payload<M: ModelInfo + MistralModelInfo>(
236        &self,
237        chat: &Chat<M>,
238    ) -> Result<MistralRequest> {
239        info!("Creating request payload for chat with Mistral model");
240        debug!("System prompt length: {}", chat.system_prompt.len());
241        debug!("Messages in history: {}", chat.history.len());
242        debug!("Max output tokens: {}", chat.max_output_tokens);
243
244        let model_id = chat.model.mistral_model_id();
245        debug!("Using model ID: {}", model_id);
246
247        // Convert all messages including system prompt
248        debug!("Converting messages to Mistral format");
249        let mut messages: Vec<MistralMessage> = Vec::new();
250
251        // Add system prompt if present
252        if !chat.system_prompt.is_empty() {
253            debug!("Adding system prompt");
254            messages.push(MistralMessage {
255                role: "system".to_string(),
256                content: chat.system_prompt.clone(),
257                name: None,
258                tool_calls: None,
259                tool_call_id: None,
260            });
261        }
262
263        // Add conversation history
264        for msg in &chat.history {
265            debug!("Converting message with role: {}", msg.role_str());
266            messages.push(MistralMessage::from(msg));
267        }
268
269        debug!("Converted {} messages for the request", messages.len());
270
271        // Add tools if present
272        let tools = chat
273            .tools
274            .as_ref()
275            .map(|tools| tools.iter().map(MistralTool::from).collect());
276
277        // Create the tool choice setting
278        let tool_choice = if tools.is_some() {
279            Some("auto".to_string())
280        } else {
281            None
282        };
283
284        // Create the request
285        debug!("Creating MistralRequest");
286        let request = MistralRequest {
287            model: model_id,
288            messages,
289            temperature: None,
290            top_p: None,
291            max_tokens: Some(chat.max_output_tokens),
292            stream: None,
293            random_seed: None,
294            safe_prompt: None,
295            tools,
296            tool_choice,
297        };
298
299        info!("Request payload created successfully");
300        Ok(request)
301    }
302}
303
304/// Represents a message in the Mistral API format
305#[derive(Debug, Clone, Serialize, Deserialize)]
306pub(crate) struct MistralMessage {
307    /// The role of the message sender (system, user, assistant, etc.)
308    pub role: String,
309    /// The content of the message
310    pub content: String,
311    /// The name of the function
312    #[serde(skip_serializing_if = "Option::is_none")]
313    pub name: Option<String>,
314    /// Tool calls
315    #[serde(skip_serializing_if = "Option::is_none")]
316    pub tool_calls: Option<Vec<MistralToolCall>>,
317    /// Tool call ID
318    #[serde(skip_serializing_if = "Option::is_none")]
319    pub tool_call_id: Option<String>,
320}
321
322/// Represents a tool function in the Mistral API format
323#[derive(Debug, Serialize, Deserialize)]
324pub(crate) struct MistralFunction {
325    /// The name of the function
326    pub name: String,
327    /// The description of the function
328    pub description: String,
329    /// The parameters schema as a JSON object
330    pub parameters: serde_json::Value,
331}
332
333impl From<&LlmToolInfo> for MistralTool {
334    fn from(value: &LlmToolInfo) -> Self {
335        MistralTool {
336            tool_type: "function".to_string(),
337            function: MistralFunction {
338                name: value.name.clone(),
339                description: value.description.clone(),
340                parameters: value.parameters.clone(),
341            },
342        }
343    }
344}
345
346/// Represents a tool in the Mistral API format
347#[derive(Debug, Serialize, Deserialize)]
348pub(crate) struct MistralTool {
349    /// The type of the tool (currently always "function")
350    #[serde(rename = "type")]
351    pub tool_type: String,
352    /// The function definition
353    pub function: MistralFunction,
354}
355
356/// Represents a function call in the Mistral API format
357#[derive(Debug, Clone, Serialize, Deserialize)]
358pub(crate) struct MistralFunctionCall {
359    /// The name of the function
360    pub name: String,
361    /// The arguments as a JSON string
362    pub arguments: String,
363}
364
365/// Represents a tool call in the Mistral API format
366#[derive(Debug, Clone, Serialize, Deserialize)]
367pub(crate) struct MistralToolCall {
368    /// The ID of the tool call
369    pub id: String,
370    /// The function call
371    pub function: MistralFunctionCall,
372}
373
374/// Represents a request to the Mistral API
375#[derive(Debug, Serialize, Deserialize)]
376pub(crate) struct MistralRequest {
377    /// The model to use
378    pub model: String,
379    /// The messages to send
380    pub messages: Vec<MistralMessage>,
381    /// Temperature (randomness)
382    #[serde(skip_serializing_if = "Option::is_none")]
383    pub temperature: Option<f32>,
384    /// Top-p sampling
385    #[serde(skip_serializing_if = "Option::is_none")]
386    pub top_p: Option<f32>,
387    /// Maximum number of tokens to generate
388    #[serde(skip_serializing_if = "Option::is_none")]
389    pub max_tokens: Option<usize>,
390    /// Stream mode
391    #[serde(skip_serializing_if = "Option::is_none")]
392    pub stream: Option<bool>,
393    /// Random seed
394    #[serde(skip_serializing_if = "Option::is_none")]
395    pub random_seed: Option<u64>,
396    /// Safe prompt
397    #[serde(skip_serializing_if = "Option::is_none")]
398    pub safe_prompt: Option<bool>,
399    /// Tools available to the model
400    #[serde(skip_serializing_if = "Option::is_none")]
401    pub tools: Option<Vec<MistralTool>>,
402    /// Tool choice strategy (auto, none, or a specific tool)
403    #[serde(skip_serializing_if = "Option::is_none")]
404    pub tool_choice: Option<String>,
405}
406
407/// Represents a response from the Mistral API
408#[derive(Debug, Serialize, Deserialize)]
409pub(crate) struct MistralResponse {
410    /// Response ID
411    pub id: String,
412    /// Object type
413    pub object: String,
414    /// Creation timestamp
415    pub created: u64,
416    /// Model used
417    pub model: String,
418    /// Choices generated
419    pub choices: Vec<MistralChoice>,
420    /// Usage statistics
421    pub usage: Option<MistralUsage>,
422}
423
424/// Represents a choice in a Mistral response
425#[derive(Debug, Serialize, Deserialize)]
426pub(crate) struct MistralChoice {
427    /// The index of the choice
428    pub index: usize,
429    /// The message generated
430    pub message: MistralMessage,
431    /// The reason generation stopped
432    pub finish_reason: Option<String>,
433}
434
435/// Represents usage statistics in a Mistral response
436#[derive(Debug, Serialize, Deserialize)]
437pub(crate) struct MistralUsage {
438    /// Number of tokens in the prompt
439    pub prompt_tokens: u32,
440    /// Number of tokens in the completion
441    pub completion_tokens: u32,
442    /// Total number of tokens
443    pub total_tokens: u32,
444}
445
446/// Represents an error response from the Mistral API
447#[derive(Debug, Serialize, Deserialize)]
448pub(crate) struct MistralErrorResponse {
449    /// The error details
450    pub error: Option<MistralError>,
451}
452
453/// Represents an error from the Mistral API
454#[derive(Debug, Serialize, Deserialize)]
455pub(crate) struct MistralError {
456    /// The error message
457    pub message: String,
458    /// The error type
459    #[serde(rename = "type")]
460    pub error_type: String,
461    /// The error code
462    #[serde(skip_serializing_if = "Option::is_none")]
463    pub code: Option<String>,
464}
465
466/// Convert from our Message to Mistral's message format
467impl From<&Message> for MistralMessage {
468    fn from(msg: &Message) -> Self {
469        let role = match msg {
470            Message::System { .. } => "system",
471            Message::User { .. } => "user",
472            Message::Assistant { .. } => "assistant",
473            Message::Tool { .. } => "tool",
474        }
475        .to_string();
476
477        let (content, name, tool_calls, tool_call_id) = match msg {
478            Message::System { content, .. } => (content.clone(), None, None, None),
479            Message::User { content, name, .. } => {
480                let content_str = match content {
481                    Content::Text(text) => text.clone(),
482                    Content::Parts(parts) => {
483                        // For now, we just concatenate all text parts
484                        // A more complete implementation would handle multimodal content
485                        parts
486                            .iter()
487                            .filter_map(|part| match part {
488                                ContentPart::Text { text } => Some(text.clone()),
489                                _ => None,
490                            })
491                            .collect::<Vec<String>>()
492                            .join("\n")
493                    }
494                };
495                (content_str, name.clone(), None, None)
496            }
497            Message::Assistant {
498                content,
499                tool_calls,
500                ..
501            } => {
502                let content_str = match content {
503                    Some(Content::Text(text)) => text.clone(),
504                    Some(Content::Parts(parts)) => {
505                        // Concatenate text parts
506                        parts
507                            .iter()
508                            .filter_map(|part| match part {
509                                ContentPart::Text { text } => Some(text.clone()),
510                                _ => None,
511                            })
512                            .collect::<Vec<String>>()
513                            .join("\n")
514                    }
515                    None => String::new(),
516                };
517
518                // Convert tool calls if present
519                let mistral_tool_calls = if !tool_calls.is_empty() {
520                    let mut calls = Vec::with_capacity(tool_calls.len());
521
522                    for tc in tool_calls {
523                        calls.push(MistralToolCall {
524                            id: tc.id.clone(),
525                            function: MistralFunctionCall {
526                                name: tc.function.name.clone(),
527                                arguments: tc.function.arguments.clone(),
528                            },
529                        });
530                    }
531
532                    Some(calls)
533                } else {
534                    None
535                };
536
537                (content_str, None, mistral_tool_calls, None)
538            }
539            Message::Tool {
540                tool_call_id,
541                content,
542                ..
543            } => (content.clone(), None, None, Some(tool_call_id.clone())),
544        };
545
546        MistralMessage {
547            role,
548            content,
549            name,
550            tool_calls,
551            tool_call_id,
552        }
553    }
554}
555
556/// Convert from Mistral's response to our message format
557impl From<&MistralResponse> for Message {
558    fn from(response: &MistralResponse) -> Self {
559        // Get the first choice (there should be at least one)
560        if response.choices.is_empty() {
561            return Message::assistant("No response generated");
562        }
563
564        let choice = &response.choices[0];
565        let message = &choice.message;
566
567        // Create appropriate Message variant based on role
568        let mut msg = match message.role.as_str() {
569            "assistant" => {
570                let content = Some(Content::Text(message.content.clone()));
571
572                // Convert tool calls if present
573                if let Some(mistral_tool_calls) = &message.tool_calls {
574                    if !mistral_tool_calls.is_empty() {
575                        let mut tool_calls = Vec::with_capacity(mistral_tool_calls.len());
576
577                        for call in mistral_tool_calls {
578                            let tool_call = crate::message::ToolCall {
579                                id: call.id.clone(),
580                                tool_type: "function".to_string(),
581                                function: crate::message::Function {
582                                    name: call.function.name.clone(),
583                                    arguments: call.function.arguments.clone(),
584                                },
585                            };
586                            tool_calls.push(tool_call);
587                        }
588
589                        Message::Assistant {
590                            content,
591                            tool_calls,
592                            metadata: Default::default(),
593                        }
594                    } else {
595                        // No tool calls, just content
596                        if let Some(Content::Text(text)) = content {
597                            Message::assistant(text)
598                        } else {
599                            Message::Assistant {
600                                content,
601                                tool_calls: Vec::new(),
602                                metadata: Default::default(),
603                            }
604                        }
605                    }
606                } else {
607                    // No tool calls
608                    if let Some(Content::Text(text)) = content {
609                        Message::assistant(text)
610                    } else {
611                        Message::Assistant {
612                            content,
613                            tool_calls: Vec::new(),
614                            metadata: Default::default(),
615                        }
616                    }
617                }
618            }
619            "user" => {
620                if let Some(name) = &message.name {
621                    Message::user_with_name(name, message.content.clone())
622                } else {
623                    Message::user(message.content.clone())
624                }
625            }
626            "system" => Message::system(message.content.clone()),
627            "tool" => {
628                if let Some(tool_call_id) = &message.tool_call_id {
629                    Message::tool(tool_call_id, message.content.clone())
630                } else {
631                    // This shouldn't happen, but fall back to user message
632                    Message::user(message.content.clone())
633                }
634            }
635            _ => Message::user(message.content.clone()), // Default to user for unknown roles
636        };
637
638        // Add token usage information to metadata if available
639        if let Some(usage) = &response.usage {
640            msg = msg.with_metadata(
641                "prompt_tokens",
642                serde_json::Value::Number(usage.prompt_tokens.into()),
643            );
644            msg = msg.with_metadata(
645                "completion_tokens",
646                serde_json::Value::Number(usage.completion_tokens.into()),
647            );
648            msg = msg.with_metadata(
649                "total_tokens",
650                serde_json::Value::Number(usage.total_tokens.into()),
651            );
652        }
653
654        msg
655    }
656}
657
#[cfg(test)]
mod tests {
    use super::*;

    /// Simple text messages map to the matching role with unchanged text.
    #[test]
    fn test_message_conversion() {
        let cases = [
            (Message::user("Hello, world!"), "user", "Hello, world!"),
            (
                Message::system("You are a helpful assistant."),
                "system",
                "You are a helpful assistant.",
            ),
            (
                Message::assistant("I can help with that."),
                "assistant",
                "I can help with that.",
            ),
        ];

        for (message, expected_role, expected_content) in cases.iter() {
            let converted = MistralMessage::from(message);
            assert_eq!(converted.role, *expected_role);
            assert_eq!(converted.content, *expected_content);
        }
    }

    /// An `{"error": {...}}` envelope exposes its type and code.
    #[test]
    fn test_error_response_parsing() {
        let error_json = r#"{
            "error": {
                "message": "The model does not exist",
                "type": "invalid_request_error",
                "code": "model_not_found"
            }
        }"#;

        let parsed: MistralErrorResponse =
            serde_json::from_str(error_json).expect("error payload should deserialize");
        let error = parsed.error.expect("error field should be present");
        assert_eq!(error.error_type, "invalid_request_error");
        assert_eq!(error.code.as_deref(), Some("model_not_found"));
    }
}