language_barrier_core/provider/mistral.rs

1use crate::error::{Error, Result};
2use crate::message::{Content, ContentPart, Message};
3use crate::provider::HTTPProvider;
4use crate::{Chat, LlmToolInfo, Mistral};
5use reqwest::{Method, Request, Url};
6use serde::{Deserialize, Serialize};
7use std::env;
8use tracing::{debug, error, info, instrument, trace, warn};
9
/// Configuration for the Mistral provider
///
/// `Debug` is implemented by hand so that formatting a config (or a provider that
/// contains one) never prints the API key: it is shown as `"<redacted>"` when set
/// and `"<unset>"` when empty.
#[derive(Clone)]
pub struct MistralConfig {
    /// API key for authentication (sent as a bearer token)
    pub api_key: String,
    /// Base URL for the API
    pub base_url: String,
}

impl std::fmt::Debug for MistralConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only reveal whether a key is configured, never the secret itself.
        let api_key = if self.api_key.is_empty() {
            "<unset>"
        } else {
            "<redacted>"
        };
        f.debug_struct("MistralConfig")
            .field("api_key", &api_key)
            .field("base_url", &self.base_url)
            .finish()
    }
}
18
19impl Default for MistralConfig {
20    fn default() -> Self {
21        Self {
22            api_key: env::var("MISTRAL_API_KEY").unwrap_or_default(),
23            base_url: "https://api.mistral.ai/v1".to_string(),
24        }
25    }
26}
27
28/// Implementation of the Mistral provider
29#[derive(Debug, Clone)]
30pub struct MistralProvider {
31    /// Configuration for the provider
32    config: MistralConfig,
33}
34
35impl MistralProvider {
36    /// Creates a new MistralProvider with default configuration
37    ///
38    /// This method will use the MISTRAL_API_KEY environment variable for authentication.
39    ///
40    /// # Examples
41    ///
42    /// ```
43    /// use language_barrier_core::provider::mistral::MistralProvider;
44    ///
45    /// let provider = MistralProvider::new();
46    /// ```
47    #[instrument(level = "debug")]
48    pub fn new() -> Self {
49        info!("Creating new MistralProvider with default configuration");
50        let config = MistralConfig::default();
51        debug!("API key set: {}", !config.api_key.is_empty());
52        debug!("Base URL: {}", config.base_url);
53
54        Self { config }
55    }
56
57    /// Creates a new MistralProvider with custom configuration
58    ///
59    /// # Examples
60    ///
61    /// ```
62    /// use language_barrier_core::provider::mistral::{MistralProvider, MistralConfig};
63    ///
64    /// let config = MistralConfig {
65    ///     api_key: "your-api-key".to_string(),
66    ///     base_url: "https://api.mistral.ai/v1".to_string(),
67    /// };
68    ///
69    /// let provider = MistralProvider::with_config(config);
70    /// ```
71    #[instrument(skip(config), level = "debug")]
72    pub fn with_config(config: MistralConfig) -> Self {
73        info!("Creating new MistralProvider with custom configuration");
74        debug!("API key set: {}", !config.api_key.is_empty());
75        debug!("Base URL: {}", config.base_url);
76
77        Self { config }
78    }
79}
80
81impl Default for MistralProvider {
82    fn default() -> Self {
83        Self::new()
84    }
85}
86
/// Trait to get Mistral-specific model IDs
///
/// `create_request_payload` calls this on the `Mistral` model and sends the
/// returned string as the request's `model` field.
pub trait MistralModelInfo {
    /// Returns the model identifier to send to the Mistral API.
    fn mistral_model_id(&self) -> String;
}
91
92impl HTTPProvider<Mistral> for MistralProvider {
93    fn accept(&self, model: Mistral, chat: &Chat) -> Result<Request> {
94        info!("Creating request for Mistral model: {:?}", model);
95        debug!("Messages in chat history: {}", chat.history.len());
96
97        let url_str = format!("{}/chat/completions", self.config.base_url);
98        debug!("Parsing URL: {}", url_str);
99        let url = match Url::parse(&url_str) {
100            Ok(url) => {
101                debug!("URL parsed successfully: {}", url);
102                url
103            }
104            Err(e) => {
105                error!("Failed to parse URL '{}': {}", url_str, e);
106                return Err(e.into());
107            }
108        };
109
110        let mut request = Request::new(Method::POST, url);
111        debug!("Created request: {} {}", request.method(), request.url());
112
113        // Set headers
114        debug!("Setting request headers");
115
116        // API key as bearer token
117        let auth_header = match format!("Bearer {}", self.config.api_key).parse() {
118            Ok(header) => header,
119            Err(e) => {
120                error!("Invalid API key format: {}", e);
121                return Err(Error::Authentication("Invalid API key format".into()));
122            }
123        };
124
125        let content_type_header = match "application/json".parse() {
126            Ok(header) => header,
127            Err(e) => {
128                error!("Failed to set content type: {}", e);
129                return Err(Error::Other("Failed to set content type".into()));
130            }
131        };
132
133        request.headers_mut().insert("Authorization", auth_header);
134        request
135            .headers_mut()
136            .insert("Content-Type", content_type_header);
137
138        trace!("Request headers set: {:#?}", request.headers());
139
140        // Create the request payload
141        debug!("Creating request payload");
142        let payload = match self.create_request_payload(model, chat) {
143            Ok(payload) => {
144                debug!("Request payload created successfully");
145                trace!("Model: {}", payload.model);
146                trace!("Max tokens: {:?}", payload.max_tokens);
147                trace!("Number of messages: {}", payload.messages.len());
148                payload
149            }
150            Err(e) => {
151                error!("Failed to create request payload: {}", e);
152                return Err(e);
153            }
154        };
155
156        // Set the request body
157        debug!("Serializing request payload");
158        let body_bytes = match serde_json::to_vec(&payload) {
159            Ok(bytes) => {
160                debug!("Payload serialized successfully ({} bytes)", bytes.len());
161                bytes
162            }
163            Err(e) => {
164                error!("Failed to serialize payload: {}", e);
165                return Err(Error::Serialization(e));
166            }
167        };
168
169        *request.body_mut() = Some(body_bytes.into());
170        info!("Request created successfully");
171
172        Ok(request)
173    }
174
175    fn parse(&self, raw_response_text: String) -> Result<Message> {
176        info!("Parsing response from Mistral API");
177        trace!("Raw response: {}", raw_response_text);
178
179        // First try to parse as an error response
180        if let Ok(error_response) = serde_json::from_str::<MistralErrorResponse>(&raw_response_text)
181        {
182            if error_response.error.is_some() {
183                let error = error_response.error.unwrap();
184                error!("Mistral API returned an error: {}", error.message);
185                return Err(Error::ProviderUnavailable(error.message));
186            }
187        }
188
189        // If not an error, parse as a successful response
190        debug!("Deserializing response JSON");
191        let mistral_response = match serde_json::from_str::<MistralResponse>(&raw_response_text) {
192            Ok(response) => {
193                debug!("Response deserialized successfully");
194                debug!("Response id: {}", response.id);
195                debug!("Response model: {}", response.model);
196                if !response.choices.is_empty() {
197                    debug!("Number of choices: {}", response.choices.len());
198                    debug!(
199                        "First choice finish reason: {:?}",
200                        response.choices[0].finish_reason
201                    );
202                }
203                if let Some(usage) = &response.usage {
204                    debug!(
205                        "Token usage - prompt: {}, completion: {}, total: {}",
206                        usage.prompt_tokens, usage.completion_tokens, usage.total_tokens
207                    );
208                }
209                response
210            }
211            Err(e) => {
212                error!("Failed to deserialize response: {}", e);
213                error!("Raw response: {}", raw_response_text);
214                return Err(Error::Serialization(e));
215            }
216        };
217
218        // Convert to our message format
219        debug!("Converting Mistral response to Message");
220        let message = Message::from(&mistral_response);
221
222        info!("Response parsed successfully");
223        trace!("Response message processed");
224
225        Ok(message)
226    }
227}
228
229impl MistralProvider {
230    /// Creates a request payload from a Chat object
231    ///
232    /// This method converts the Chat's messages and settings into a Mistral-specific
233    /// format for the API request.
234    #[instrument(skip(self, chat), level = "debug")]
235    fn create_request_payload(&self, model: Mistral, chat: &Chat) -> Result<MistralRequest> {
236        info!("Creating request payload for chat with Mistral model");
237        debug!("System prompt length: {}", chat.system_prompt.len());
238        debug!("Messages in history: {}", chat.history.len());
239        debug!("Max output tokens: {}", chat.max_output_tokens);
240
241        let model_id = model.mistral_model_id();
242        debug!("Using model ID: {}", model_id);
243
244        // Convert all messages including system prompt
245        debug!("Converting messages to Mistral format");
246        let mut messages: Vec<MistralMessage> = Vec::new();
247
248        // Add system prompt if present
249        if !chat.system_prompt.is_empty() {
250            debug!("Adding system prompt");
251            messages.push(MistralMessage {
252                role: "system".to_string(),
253                content: chat.system_prompt.clone(),
254                name: None,
255                tool_calls: None,
256                tool_call_id: None,
257            });
258        }
259
260        // Add conversation history
261        for msg in &chat.history {
262            debug!("Converting message with role: {}", msg.role_str());
263            messages.push(MistralMessage::from(msg));
264        }
265
266        debug!("Converted {} messages for the request", messages.len());
267
268        // Add tools if present
269        let tools = chat
270            .tools
271            .as_ref()
272            .map(|tools| tools.iter().map(MistralTool::from).collect());
273
274        // Create the tool choice setting
275        let tool_choice = if let Some(choice) = &chat.tool_choice {
276            // Use the explicitly configured choice
277            match choice {
278                crate::tool::ToolChoice::Auto => Some(serde_json::json!("auto")),
279                // Mistral uses "required" for what we call "Any" (following OpenAI's convention)
280                crate::tool::ToolChoice::Any => Some(serde_json::json!("required")),
281                crate::tool::ToolChoice::None => Some(serde_json::json!("none")),
282                crate::tool::ToolChoice::Specific(name) => {
283                    // For specific tool, we need to create an object with type and function properties
284                    Some(serde_json::json!({
285                        "type": "function",
286                        "function": {
287                            "name": name
288                        }
289                    }))
290                }
291            }
292        } else if tools.is_some() {
293            // Default to auto if tools are present but no choice specified
294            Some(serde_json::json!("auto"))
295        } else {
296            None
297        };
298
299        // Create the request
300        debug!("Creating MistralRequest");
301        let request = MistralRequest {
302            model: model_id,
303            messages,
304            temperature: None,
305            top_p: None,
306            max_tokens: Some(chat.max_output_tokens),
307            stream: None,
308            random_seed: None,
309            safe_prompt: None,
310            tools,
311            tool_choice,
312        };
313
314        info!("Request payload created successfully");
315        Ok(request)
316    }
317}
318
319/// Represents a message in the Mistral API format
320#[derive(Debug, Clone, Serialize, Deserialize)]
321pub(crate) struct MistralMessage {
322    /// The role of the message sender (system, user, assistant, etc.)
323    pub role: String,
324    /// The content of the message
325    pub content: String,
326    /// The name of the function
327    #[serde(skip_serializing_if = "Option::is_none")]
328    pub name: Option<String>,
329    /// Tool calls
330    #[serde(skip_serializing_if = "Option::is_none")]
331    pub tool_calls: Option<Vec<MistralToolCall>>,
332    /// Tool call ID
333    #[serde(skip_serializing_if = "Option::is_none")]
334    pub tool_call_id: Option<String>,
335}
336
/// Represents a tool function in the Mistral API format
///
/// Serialized as the `function` object inside a [`MistralTool`]; populated from
/// `LlmToolInfo` by the `From<&LlmToolInfo>` conversion.
#[derive(Debug, Serialize, Deserialize)]
pub(crate) struct MistralFunction {
    /// The name of the function
    pub name: String,
    /// The description of the function
    pub description: String,
    /// The parameters schema as a JSON value (copied from `LlmToolInfo::parameters`)
    pub parameters: serde_json::Value,
}
347
348impl From<&LlmToolInfo> for MistralTool {
349    fn from(value: &LlmToolInfo) -> Self {
350        MistralTool {
351            tool_type: "function".to_string(),
352            function: MistralFunction {
353                name: value.name.clone(),
354                description: value.description.clone(),
355                parameters: value.parameters.clone(),
356            },
357        }
358    }
359}
360
/// Represents a tool in the Mistral API format
///
/// The `From<&LlmToolInfo>` conversion always sets `tool_type` to `"function"`.
#[derive(Debug, Serialize, Deserialize)]
pub(crate) struct MistralTool {
    /// The type of the tool (currently always "function"); serialized as the JSON key `type`
    #[serde(rename = "type")]
    pub tool_type: String,
    /// The function definition
    pub function: MistralFunction,
}
370
/// Represents a function call in the Mistral API format
///
/// Mapped one-to-one to and from `crate::message::Function`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct MistralFunctionCall {
    /// The name of the function
    pub name: String,
    /// The arguments, kept as an opaque string (presumably JSON-encoded; not parsed here)
    pub arguments: String,
}
379
/// Represents a tool call in the Mistral API format
///
/// Only `id` and `function` are modelled. Unknown fields in a response are ignored
/// on deserialization (no `deny_unknown_fields`).
/// NOTE(review): confirm the API does not require a `type` field on tool calls
/// sent back in assistant messages.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct MistralToolCall {
    /// The ID of the tool call
    pub id: String,
    /// The function call
    pub function: MistralFunctionCall,
}
388
/// Represents a request to the Mistral API
///
/// Built by `create_request_payload` and serialized as the JSON request body.
/// Every optional field set to `None` is omitted from the serialized JSON.
#[derive(Debug, Serialize, Deserialize)]
pub(crate) struct MistralRequest {
    /// The model to use
    pub model: String,
    /// The messages to send
    pub messages: Vec<MistralMessage>,
    /// Temperature (randomness)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// Top-p sampling
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    /// Maximum number of tokens to generate
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<usize>,
    /// Stream mode
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
    /// Random seed
    #[serde(skip_serializing_if = "Option::is_none")]
    pub random_seed: Option<u64>,
    /// Safe prompt
    #[serde(skip_serializing_if = "Option::is_none")]
    pub safe_prompt: Option<bool>,
    /// Tools available to the model
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<MistralTool>>,
    /// Tool choice strategy: a string ("auto", "none", "required") or a
    /// `{type, function: {name}}` object selecting a specific tool
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<serde_json::Value>,
}
421
/// Represents a response from the Mistral API
///
/// All fields except `usage` must be present for deserialization to succeed. Only
/// the first entry of `choices` is used when converting to a `Message`.
#[derive(Debug, Serialize, Deserialize)]
pub(crate) struct MistralResponse {
    /// Response ID
    pub id: String,
    /// Object type
    pub object: String,
    /// Creation timestamp
    pub created: u64,
    /// Model used
    pub model: String,
    /// Choices generated
    pub choices: Vec<MistralChoice>,
    /// Usage statistics (`None` when absent from the response)
    pub usage: Option<MistralUsage>,
}
438
/// Represents a choice in a Mistral response
#[derive(Debug, Serialize, Deserialize)]
pub(crate) struct MistralChoice {
    /// The index of the choice
    pub index: usize,
    /// The message generated
    pub message: MistralMessage,
    /// The reason generation stopped (only logged by `parse`, not stored on the `Message`)
    pub finish_reason: Option<String>,
}
449
/// Represents usage statistics in a Mistral response
///
/// Copied onto the converted `Message` as the `prompt_tokens`, `completion_tokens`
/// and `total_tokens` metadata entries.
#[derive(Debug, Serialize, Deserialize)]
pub(crate) struct MistralUsage {
    /// Number of tokens in the prompt
    pub prompt_tokens: u32,
    /// Number of tokens in the completion
    pub completion_tokens: u32,
    /// Total number of tokens
    pub total_tokens: u32,
}
460
/// Represents an error response from the Mistral API
///
/// `parse` tries this shape first. A body without an `error` key still deserializes,
/// with `error: None`, and is then treated as a normal response.
/// NOTE(review): confirm this nested shape matches the error bodies Mistral returns.
#[derive(Debug, Serialize, Deserialize)]
pub(crate) struct MistralErrorResponse {
    /// The error details
    pub error: Option<MistralError>,
}
467
/// Represents an error from the Mistral API
#[derive(Debug, Serialize, Deserialize)]
pub(crate) struct MistralError {
    /// The error message (surfaced to callers via `Error::ProviderUnavailable`)
    pub message: String,
    /// The error type; read from / written to the JSON key `type`
    #[serde(rename = "type")]
    pub error_type: String,
    /// The error code (omitted when serializing if `None`)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub code: Option<String>,
}
480
481/// Convert from our Message to Mistral's message format
482impl From<&Message> for MistralMessage {
483    fn from(msg: &Message) -> Self {
484        let role = match msg {
485            Message::System { .. } => "system",
486            Message::User { .. } => "user",
487            Message::Assistant { .. } => "assistant",
488            Message::Tool { .. } => "tool",
489        }
490        .to_string();
491
492        let (content, name, tool_calls, tool_call_id) = match msg {
493            Message::System { content, .. } => (content.clone(), None, None, None),
494            Message::User { content, name, .. } => {
495                let content_str = match content {
496                    Content::Text(text) => text.clone(),
497                    Content::Parts(parts) => {
498                        // For now, we just concatenate all text parts
499                        // A more complete implementation would handle multimodal content
500                        parts
501                            .iter()
502                            .filter_map(|part| match part {
503                                ContentPart::Text { text } => Some(text.clone()),
504                                _ => None,
505                            })
506                            .collect::<Vec<String>>()
507                            .join("\n")
508                    }
509                };
510                (content_str, name.clone(), None, None)
511            }
512            Message::Assistant {
513                content,
514                tool_calls,
515                ..
516            } => {
517                let content_str = match content {
518                    Some(Content::Text(text)) => text.clone(),
519                    Some(Content::Parts(parts)) => {
520                        // Concatenate text parts
521                        parts
522                            .iter()
523                            .filter_map(|part| match part {
524                                ContentPart::Text { text } => Some(text.clone()),
525                                _ => None,
526                            })
527                            .collect::<Vec<String>>()
528                            .join("\n")
529                    }
530                    None => String::new(),
531                };
532
533                // Convert tool calls if present
534                let mistral_tool_calls = if !tool_calls.is_empty() {
535                    let mut calls = Vec::with_capacity(tool_calls.len());
536
537                    for tc in tool_calls {
538                        calls.push(MistralToolCall {
539                            id: tc.id.clone(),
540                            function: MistralFunctionCall {
541                                name: tc.function.name.clone(),
542                                arguments: tc.function.arguments.clone(),
543                            },
544                        });
545                    }
546
547                    Some(calls)
548                } else {
549                    None
550                };
551
552                (content_str, None, mistral_tool_calls, None)
553            }
554            Message::Tool {
555                tool_call_id,
556                content,
557                ..
558            } => (content.clone(), None, None, Some(tool_call_id.clone())),
559        };
560
561        MistralMessage {
562            role,
563            content,
564            name,
565            tool_calls,
566            tool_call_id,
567        }
568    }
569}
570
571/// Convert from Mistral's response to our message format
572impl From<&MistralResponse> for Message {
573    fn from(response: &MistralResponse) -> Self {
574        // Get the first choice (there should be at least one)
575        if response.choices.is_empty() {
576            return Message::assistant("No response generated");
577        }
578
579        let choice = &response.choices[0];
580        let message = &choice.message;
581
582        // Create appropriate Message variant based on role
583        let mut msg = match message.role.as_str() {
584            "assistant" => {
585                let content = Some(Content::Text(message.content.clone()));
586
587                // Convert tool calls if present
588                if let Some(mistral_tool_calls) = &message.tool_calls {
589                    if !mistral_tool_calls.is_empty() {
590                        let mut tool_calls = Vec::with_capacity(mistral_tool_calls.len());
591
592                        for call in mistral_tool_calls {
593                            let tool_call = crate::message::ToolCall {
594                                id: call.id.clone(),
595                                tool_type: "function".to_string(),
596                                function: crate::message::Function {
597                                    name: call.function.name.clone(),
598                                    arguments: call.function.arguments.clone(),
599                                },
600                            };
601                            tool_calls.push(tool_call);
602                        }
603
604                        Message::Assistant {
605                            content,
606                            tool_calls,
607                            metadata: Default::default(),
608                        }
609                    } else {
610                        // No tool calls, just content
611                        if let Some(Content::Text(text)) = content {
612                            Message::assistant(text)
613                        } else {
614                            Message::Assistant {
615                                content,
616                                tool_calls: Vec::new(),
617                                metadata: Default::default(),
618                            }
619                        }
620                    }
621                } else {
622                    // No tool calls
623                    if let Some(Content::Text(text)) = content {
624                        Message::assistant(text)
625                    } else {
626                        Message::Assistant {
627                            content,
628                            tool_calls: Vec::new(),
629                            metadata: Default::default(),
630                        }
631                    }
632                }
633            }
634            "user" => {
635                if let Some(name) = &message.name {
636                    Message::user_with_name(name, message.content.clone())
637                } else {
638                    Message::user(message.content.clone())
639                }
640            }
641            "system" => Message::system(message.content.clone()),
642            "tool" => {
643                if let Some(tool_call_id) = &message.tool_call_id {
644                    Message::tool(tool_call_id, message.content.clone())
645                } else {
646                    // This shouldn't happen, but fall back to user message
647                    Message::user(message.content.clone())
648                }
649            }
650            _ => Message::user(message.content.clone()), // Default to user for unknown roles
651        };
652
653        // Add token usage information to metadata if available
654        if let Some(usage) = &response.usage {
655            msg = msg.with_metadata(
656                "prompt_tokens",
657                serde_json::Value::Number(usage.prompt_tokens.into()),
658            );
659            msg = msg.with_metadata(
660                "completion_tokens",
661                serde_json::Value::Number(usage.completion_tokens.into()),
662            );
663            msg = msg.with_metadata(
664                "total_tokens",
665                serde_json::Value::Number(usage.total_tokens.into()),
666            );
667        }
668
669        msg
670    }
671}
672
#[cfg(test)]
mod tests {
    use super::*;

    /// Each message kind maps to the matching Mistral role with its text intact.
    #[test]
    fn test_message_conversion() {
        let cases = vec![
            (Message::user("Hello, world!"), "user", "Hello, world!"),
            (
                Message::system("You are a helpful assistant."),
                "system",
                "You are a helpful assistant.",
            ),
            (
                Message::assistant("I can help with that."),
                "assistant",
                "I can help with that.",
            ),
        ];

        for (message, expected_role, expected_content) in cases {
            let converted = MistralMessage::from(&message);
            assert_eq!(converted.role, expected_role);
            assert_eq!(converted.content, expected_content);
        }
    }

    /// A nested error body deserializes with its type and code.
    #[test]
    fn test_error_response_parsing() {
        let error_json = r#"{
            "error": {
                "message": "The model does not exist",
                "type": "invalid_request_error",
                "code": "model_not_found"
            }
        }"#;

        let parsed: MistralErrorResponse = serde_json::from_str(error_json).unwrap();
        let details = parsed.error.expect("error payload should be present");
        assert_eq!(details.error_type, "invalid_request_error");
        assert_eq!(details.code.as_deref(), Some("model_not_found"));
    }
}