1use crate::error::{Error, Result};
2use crate::message::{Content, ContentPart, Message};
3use crate::provider::HTTPProvider;
4use crate::{Chat, LlmToolInfo, ModelInfo};
5use reqwest::{Method, Request, Url};
6use serde::{Deserialize, Serialize};
7use std::env;
8use tracing::{debug, error, info, instrument, trace, warn};
9
/// Connection settings for the Mistral chat-completions API.
#[derive(Clone)]
pub struct MistralConfig {
    /// Secret sent as `Authorization: Bearer <api_key>` on every request.
    pub api_key: String,
    /// API root; `/chat/completions` is appended when a request is built.
    pub base_url: String,
}

// Hand-written instead of derived so the API key never appears in `{:?}` output
// (tracing fields, logs, panic messages). `MistralProvider`'s derived `Debug` relies on this.
impl std::fmt::Debug for MistralConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let api_key = if self.api_key.is_empty() {
            "<empty>"
        } else {
            "<redacted>"
        };
        f.debug_struct("MistralConfig")
            .field("api_key", &api_key)
            .field("base_url", &self.base_url)
            .finish()
    }
}

impl Default for MistralConfig {
    /// Reads the key from the `MISTRAL_API_KEY` environment variable (empty string when it is
    /// unset or not valid Unicode) and targets the public endpoint `https://api.mistral.ai/v1`.
    fn default() -> Self {
        Self {
            api_key: env::var("MISTRAL_API_KEY").unwrap_or_default(),
            base_url: "https://api.mistral.ai/v1".to_string(),
        }
    }
}
27
28#[derive(Debug, Clone)]
30pub struct MistralProvider {
31 config: MistralConfig,
33}
34
35impl MistralProvider {
36 #[instrument(level = "debug")]
48 pub fn new() -> Self {
49 info!("Creating new MistralProvider with default configuration");
50 let config = MistralConfig::default();
51 debug!("API key set: {}", !config.api_key.is_empty());
52 debug!("Base URL: {}", config.base_url);
53
54 Self { config }
55 }
56
57 #[instrument(skip(config), level = "debug")]
72 pub fn with_config(config: MistralConfig) -> Self {
73 info!("Creating new MistralProvider with custom configuration");
74 debug!("API key set: {}", !config.api_key.is_empty());
75 debug!("Base URL: {}", config.base_url);
76
77 Self { config }
78 }
79}
80
81impl Default for MistralProvider {
82 fn default() -> Self {
83 Self::new()
84 }
85}
86
/// Implemented by model descriptors that can be sent to Mistral.
///
/// The returned id is placed verbatim in the request's `model` field.
pub trait MistralModelInfo {
    /// Identifier of this model as understood by the Mistral API.
    fn mistral_model_id(&self) -> String;
}
91
92impl<M: ModelInfo + MistralModelInfo> HTTPProvider<M> for MistralProvider {
93 fn accept(&self, chat: Chat<M>) -> Result<Request> {
94 info!("Creating request for Mistral model: {:?}", chat.model);
95 debug!("Messages in chat history: {}", chat.history.len());
96
97 let url_str = format!("{}/chat/completions", self.config.base_url);
98 debug!("Parsing URL: {}", url_str);
99 let url = match Url::parse(&url_str) {
100 Ok(url) => {
101 debug!("URL parsed successfully: {}", url);
102 url
103 }
104 Err(e) => {
105 error!("Failed to parse URL '{}': {}", url_str, e);
106 return Err(e.into());
107 }
108 };
109
110 let mut request = Request::new(Method::POST, url);
111 debug!("Created request: {} {}", request.method(), request.url());
112
113 debug!("Setting request headers");
115
116 let auth_header = match format!("Bearer {}", self.config.api_key).parse() {
118 Ok(header) => header,
119 Err(e) => {
120 error!("Invalid API key format: {}", e);
121 return Err(Error::Authentication("Invalid API key format".into()));
122 }
123 };
124
125 let content_type_header = match "application/json".parse() {
126 Ok(header) => header,
127 Err(e) => {
128 error!("Failed to set content type: {}", e);
129 return Err(Error::Other("Failed to set content type".into()));
130 }
131 };
132
133 request.headers_mut().insert("Authorization", auth_header);
134 request
135 .headers_mut()
136 .insert("Content-Type", content_type_header);
137
138 trace!("Request headers set: {:#?}", request.headers());
139
140 debug!("Creating request payload");
142 let payload = match self.create_request_payload(&chat) {
143 Ok(payload) => {
144 debug!("Request payload created successfully");
145 trace!("Model: {}", payload.model);
146 trace!("Max tokens: {:?}", payload.max_tokens);
147 trace!("Number of messages: {}", payload.messages.len());
148 payload
149 }
150 Err(e) => {
151 error!("Failed to create request payload: {}", e);
152 return Err(e);
153 }
154 };
155
156 debug!("Serializing request payload");
158 let body_bytes = match serde_json::to_vec(&payload) {
159 Ok(bytes) => {
160 debug!("Payload serialized successfully ({} bytes)", bytes.len());
161 bytes
162 }
163 Err(e) => {
164 error!("Failed to serialize payload: {}", e);
165 return Err(Error::Serialization(e));
166 }
167 };
168
169 *request.body_mut() = Some(body_bytes.into());
170 info!("Request created successfully");
171
172 Ok(request)
173 }
174
175 fn parse(&self, raw_response_text: String) -> Result<Message> {
176 info!("Parsing response from Mistral API");
177 trace!("Raw response: {}", raw_response_text);
178
179 if let Ok(error_response) = serde_json::from_str::<MistralErrorResponse>(&raw_response_text)
181 {
182 if error_response.error.is_some() {
183 let error = error_response.error.unwrap();
184 error!("Mistral API returned an error: {}", error.message);
185 return Err(Error::ProviderUnavailable(error.message));
186 }
187 }
188
189 debug!("Deserializing response JSON");
191 let mistral_response = match serde_json::from_str::<MistralResponse>(&raw_response_text) {
192 Ok(response) => {
193 debug!("Response deserialized successfully");
194 debug!("Response id: {}", response.id);
195 debug!("Response model: {}", response.model);
196 if !response.choices.is_empty() {
197 debug!("Number of choices: {}", response.choices.len());
198 debug!(
199 "First choice finish reason: {:?}",
200 response.choices[0].finish_reason
201 );
202 }
203 if let Some(usage) = &response.usage {
204 debug!(
205 "Token usage - prompt: {}, completion: {}, total: {}",
206 usage.prompt_tokens, usage.completion_tokens, usage.total_tokens
207 );
208 }
209 response
210 }
211 Err(e) => {
212 error!("Failed to deserialize response: {}", e);
213 error!("Raw response: {}", raw_response_text);
214 return Err(Error::Serialization(e));
215 }
216 };
217
218 debug!("Converting Mistral response to Message");
220 let message = Message::from(&mistral_response);
221
222 info!("Response parsed successfully");
223 trace!("Response message processed");
224
225 Ok(message)
226 }
227}
228
229impl MistralProvider {
230 #[instrument(skip(self, chat), level = "debug")]
235 fn create_request_payload<M: ModelInfo + MistralModelInfo>(
236 &self,
237 chat: &Chat<M>,
238 ) -> Result<MistralRequest> {
239 info!("Creating request payload for chat with Mistral model");
240 debug!("System prompt length: {}", chat.system_prompt.len());
241 debug!("Messages in history: {}", chat.history.len());
242 debug!("Max output tokens: {}", chat.max_output_tokens);
243
244 let model_id = chat.model.mistral_model_id();
245 debug!("Using model ID: {}", model_id);
246
247 debug!("Converting messages to Mistral format");
249 let mut messages: Vec<MistralMessage> = Vec::new();
250
251 if !chat.system_prompt.is_empty() {
253 debug!("Adding system prompt");
254 messages.push(MistralMessage {
255 role: "system".to_string(),
256 content: chat.system_prompt.clone(),
257 name: None,
258 tool_calls: None,
259 tool_call_id: None,
260 });
261 }
262
263 for msg in &chat.history {
265 debug!("Converting message with role: {}", msg.role_str());
266 messages.push(MistralMessage::from(msg));
267 }
268
269 debug!("Converted {} messages for the request", messages.len());
270
271 let tools = chat
273 .tools
274 .as_ref()
275 .map(|tools| tools.iter().map(MistralTool::from).collect());
276
277 let tool_choice = if tools.is_some() {
279 Some("auto".to_string())
280 } else {
281 None
282 };
283
284 debug!("Creating MistralRequest");
286 let request = MistralRequest {
287 model: model_id,
288 messages,
289 temperature: None,
290 top_p: None,
291 max_tokens: Some(chat.max_output_tokens),
292 stream: None,
293 random_seed: None,
294 safe_prompt: None,
295 tools,
296 tool_choice,
297 };
298
299 info!("Request payload created successfully");
300 Ok(request)
301 }
302}
303
304#[derive(Debug, Clone, Serialize, Deserialize)]
306pub(crate) struct MistralMessage {
307 pub role: String,
309 pub content: String,
311 #[serde(skip_serializing_if = "Option::is_none")]
313 pub name: Option<String>,
314 #[serde(skip_serializing_if = "Option::is_none")]
316 pub tool_calls: Option<Vec<MistralToolCall>>,
317 #[serde(skip_serializing_if = "Option::is_none")]
319 pub tool_call_id: Option<String>,
320}
321
322#[derive(Debug, Serialize, Deserialize)]
324pub(crate) struct MistralFunction {
325 pub name: String,
327 pub description: String,
329 pub parameters: serde_json::Value,
331}
332
333impl From<&LlmToolInfo> for MistralTool {
334 fn from(value: &LlmToolInfo) -> Self {
335 MistralTool {
336 tool_type: "function".to_string(),
337 function: MistralFunction {
338 name: value.name.clone(),
339 description: value.description.clone(),
340 parameters: value.parameters.clone(),
341 },
342 }
343 }
344}
345
346#[derive(Debug, Serialize, Deserialize)]
348pub(crate) struct MistralTool {
349 #[serde(rename = "type")]
351 pub tool_type: String,
352 pub function: MistralFunction,
354}
355
356#[derive(Debug, Clone, Serialize, Deserialize)]
358pub(crate) struct MistralFunctionCall {
359 pub name: String,
361 pub arguments: String,
363}
364
365#[derive(Debug, Clone, Serialize, Deserialize)]
367pub(crate) struct MistralToolCall {
368 pub id: String,
370 pub function: MistralFunctionCall,
372}
373
374#[derive(Debug, Serialize, Deserialize)]
376pub(crate) struct MistralRequest {
377 pub model: String,
379 pub messages: Vec<MistralMessage>,
381 #[serde(skip_serializing_if = "Option::is_none")]
383 pub temperature: Option<f32>,
384 #[serde(skip_serializing_if = "Option::is_none")]
386 pub top_p: Option<f32>,
387 #[serde(skip_serializing_if = "Option::is_none")]
389 pub max_tokens: Option<usize>,
390 #[serde(skip_serializing_if = "Option::is_none")]
392 pub stream: Option<bool>,
393 #[serde(skip_serializing_if = "Option::is_none")]
395 pub random_seed: Option<u64>,
396 #[serde(skip_serializing_if = "Option::is_none")]
398 pub safe_prompt: Option<bool>,
399 #[serde(skip_serializing_if = "Option::is_none")]
401 pub tools: Option<Vec<MistralTool>>,
402 #[serde(skip_serializing_if = "Option::is_none")]
404 pub tool_choice: Option<String>,
405}
406
407#[derive(Debug, Serialize, Deserialize)]
409pub(crate) struct MistralResponse {
410 pub id: String,
412 pub object: String,
414 pub created: u64,
416 pub model: String,
418 pub choices: Vec<MistralChoice>,
420 pub usage: Option<MistralUsage>,
422}
423
424#[derive(Debug, Serialize, Deserialize)]
426pub(crate) struct MistralChoice {
427 pub index: usize,
429 pub message: MistralMessage,
431 pub finish_reason: Option<String>,
433}
434
435#[derive(Debug, Serialize, Deserialize)]
437pub(crate) struct MistralUsage {
438 pub prompt_tokens: u32,
440 pub completion_tokens: u32,
442 pub total_tokens: u32,
444}
445
446#[derive(Debug, Serialize, Deserialize)]
448pub(crate) struct MistralErrorResponse {
449 pub error: Option<MistralError>,
451}
452
453#[derive(Debug, Serialize, Deserialize)]
455pub(crate) struct MistralError {
456 pub message: String,
458 #[serde(rename = "type")]
460 pub error_type: String,
461 #[serde(skip_serializing_if = "Option::is_none")]
463 pub code: Option<String>,
464}
465
466impl From<&Message> for MistralMessage {
468 fn from(msg: &Message) -> Self {
469 let role = match msg {
470 Message::System { .. } => "system",
471 Message::User { .. } => "user",
472 Message::Assistant { .. } => "assistant",
473 Message::Tool { .. } => "tool",
474 }
475 .to_string();
476
477 let (content, name, tool_calls, tool_call_id) = match msg {
478 Message::System { content, .. } => (content.clone(), None, None, None),
479 Message::User { content, name, .. } => {
480 let content_str = match content {
481 Content::Text(text) => text.clone(),
482 Content::Parts(parts) => {
483 parts
486 .iter()
487 .filter_map(|part| match part {
488 ContentPart::Text { text } => Some(text.clone()),
489 _ => None,
490 })
491 .collect::<Vec<String>>()
492 .join("\n")
493 }
494 };
495 (content_str, name.clone(), None, None)
496 }
497 Message::Assistant {
498 content,
499 tool_calls,
500 ..
501 } => {
502 let content_str = match content {
503 Some(Content::Text(text)) => text.clone(),
504 Some(Content::Parts(parts)) => {
505 parts
507 .iter()
508 .filter_map(|part| match part {
509 ContentPart::Text { text } => Some(text.clone()),
510 _ => None,
511 })
512 .collect::<Vec<String>>()
513 .join("\n")
514 }
515 None => String::new(),
516 };
517
518 let mistral_tool_calls = if !tool_calls.is_empty() {
520 let mut calls = Vec::with_capacity(tool_calls.len());
521
522 for tc in tool_calls {
523 calls.push(MistralToolCall {
524 id: tc.id.clone(),
525 function: MistralFunctionCall {
526 name: tc.function.name.clone(),
527 arguments: tc.function.arguments.clone(),
528 },
529 });
530 }
531
532 Some(calls)
533 } else {
534 None
535 };
536
537 (content_str, None, mistral_tool_calls, None)
538 }
539 Message::Tool {
540 tool_call_id,
541 content,
542 ..
543 } => (content.clone(), None, None, Some(tool_call_id.clone())),
544 };
545
546 MistralMessage {
547 role,
548 content,
549 name,
550 tool_calls,
551 tool_call_id,
552 }
553 }
554}
555
556impl From<&MistralResponse> for Message {
558 fn from(response: &MistralResponse) -> Self {
559 if response.choices.is_empty() {
561 return Message::assistant("No response generated");
562 }
563
564 let choice = &response.choices[0];
565 let message = &choice.message;
566
567 let mut msg = match message.role.as_str() {
569 "assistant" => {
570 let content = Some(Content::Text(message.content.clone()));
571
572 if let Some(mistral_tool_calls) = &message.tool_calls {
574 if !mistral_tool_calls.is_empty() {
575 let mut tool_calls = Vec::with_capacity(mistral_tool_calls.len());
576
577 for call in mistral_tool_calls {
578 let tool_call = crate::message::ToolCall {
579 id: call.id.clone(),
580 tool_type: "function".to_string(),
581 function: crate::message::Function {
582 name: call.function.name.clone(),
583 arguments: call.function.arguments.clone(),
584 },
585 };
586 tool_calls.push(tool_call);
587 }
588
589 Message::Assistant {
590 content,
591 tool_calls,
592 metadata: Default::default(),
593 }
594 } else {
595 if let Some(Content::Text(text)) = content {
597 Message::assistant(text)
598 } else {
599 Message::Assistant {
600 content,
601 tool_calls: Vec::new(),
602 metadata: Default::default(),
603 }
604 }
605 }
606 } else {
607 if let Some(Content::Text(text)) = content {
609 Message::assistant(text)
610 } else {
611 Message::Assistant {
612 content,
613 tool_calls: Vec::new(),
614 metadata: Default::default(),
615 }
616 }
617 }
618 }
619 "user" => {
620 if let Some(name) = &message.name {
621 Message::user_with_name(name, message.content.clone())
622 } else {
623 Message::user(message.content.clone())
624 }
625 }
626 "system" => Message::system(message.content.clone()),
627 "tool" => {
628 if let Some(tool_call_id) = &message.tool_call_id {
629 Message::tool(tool_call_id, message.content.clone())
630 } else {
631 Message::user(message.content.clone())
633 }
634 }
635 _ => Message::user(message.content.clone()), };
637
638 if let Some(usage) = &response.usage {
640 msg = msg.with_metadata(
641 "prompt_tokens",
642 serde_json::Value::Number(usage.prompt_tokens.into()),
643 );
644 msg = msg.with_metadata(
645 "completion_tokens",
646 serde_json::Value::Number(usage.completion_tokens.into()),
647 );
648 msg = msg.with_metadata(
649 "total_tokens",
650 serde_json::Value::Number(usage.total_tokens.into()),
651 );
652 }
653
654 msg
655 }
656}
657
#[cfg(test)]
mod tests {
    use super::*;

    /// Each basic role converts to the matching Mistral role with its text unchanged.
    #[test]
    fn test_message_conversion() {
        let cases = [
            (Message::user("Hello, world!"), "user", "Hello, world!"),
            (
                Message::system("You are a helpful assistant."),
                "system",
                "You are a helpful assistant.",
            ),
            (
                Message::assistant("I can help with that."),
                "assistant",
                "I can help with that.",
            ),
        ];

        for (source, expected_role, expected_content) in cases {
            let converted = MistralMessage::from(&source);
            assert_eq!(converted.role, expected_role);
            assert_eq!(converted.content, expected_content);
        }
    }

    /// A nested `error` object is decoded together with its `type` and `code`.
    #[test]
    fn test_error_response_parsing() {
        let error_json = r#"{
            "error": {
                "message": "The model does not exist",
                "type": "invalid_request_error",
                "code": "model_not_found"
            }
        }"#;

        let envelope: MistralErrorResponse =
            serde_json::from_str(error_json).expect("error JSON should deserialize");
        let api_error = envelope.error.expect("error object should be present");
        assert_eq!(api_error.error_type, "invalid_request_error");
        assert_eq!(api_error.code.as_deref(), Some("model_not_found"));
    }
}