1use crate::error::{Error, Result};
2use crate::message::{Content, ContentPart, Message};
3use crate::provider::HTTPProvider;
4use crate::{Chat, LlmToolInfo, Mistral};
5use reqwest::{Method, Request, Url};
6use serde::{Deserialize, Serialize};
7use std::env;
8use tracing::{debug, error, info, instrument, trace, warn};
9
/// Connection settings for the Mistral chat-completions API.
///
/// `Debug` is implemented by hand (instead of derived) so that formatting a
/// config — or anything that contains one, such as a provider logged with
/// `{:?}` — never reveals the API key.
#[derive(Clone)]
pub struct MistralConfig {
    /// Bearer token sent in the `Authorization` header.
    pub api_key: String,
    /// Base URL of the API; the provider appends `/chat/completions` to it.
    pub base_url: String,
}

impl std::fmt::Debug for MistralConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only disclose whether a key is configured, never the key itself.
        let api_key = if self.api_key.is_empty() {
            "<unset>"
        } else {
            "<redacted>"
        };
        f.debug_struct("MistralConfig")
            .field("api_key", &api_key)
            .field("base_url", &self.base_url)
            .finish()
    }
}
18
19impl Default for MistralConfig {
20 fn default() -> Self {
21 Self {
22 api_key: env::var("MISTRAL_API_KEY").unwrap_or_default(),
23 base_url: "https://api.mistral.ai/v1".to_string(),
24 }
25 }
26}
27
28#[derive(Debug, Clone)]
30pub struct MistralProvider {
31 config: MistralConfig,
33}
34
35impl MistralProvider {
36 #[instrument(level = "debug")]
48 pub fn new() -> Self {
49 info!("Creating new MistralProvider with default configuration");
50 let config = MistralConfig::default();
51 debug!("API key set: {}", !config.api_key.is_empty());
52 debug!("Base URL: {}", config.base_url);
53
54 Self { config }
55 }
56
57 #[instrument(skip(config), level = "debug")]
72 pub fn with_config(config: MistralConfig) -> Self {
73 info!("Creating new MistralProvider with custom configuration");
74 debug!("API key set: {}", !config.api_key.is_empty());
75 debug!("Base URL: {}", config.base_url);
76
77 Self { config }
78 }
79}
80
81impl Default for MistralProvider {
82 fn default() -> Self {
83 Self::new()
84 }
85}
86
/// Maps a model selector to the identifier placed in the `model` field of a
/// Mistral chat-completions request.
pub trait MistralModelInfo {
    /// Returns the model identifier string sent to the API.
    fn mistral_model_id(&self) -> String;
}
91
92impl HTTPProvider<Mistral> for MistralProvider {
93 fn accept(&self, model: Mistral, chat: &Chat) -> Result<Request> {
94 info!("Creating request for Mistral model: {:?}", model);
95 debug!("Messages in chat history: {}", chat.history.len());
96
97 let url_str = format!("{}/chat/completions", self.config.base_url);
98 debug!("Parsing URL: {}", url_str);
99 let url = match Url::parse(&url_str) {
100 Ok(url) => {
101 debug!("URL parsed successfully: {}", url);
102 url
103 }
104 Err(e) => {
105 error!("Failed to parse URL '{}': {}", url_str, e);
106 return Err(e.into());
107 }
108 };
109
110 let mut request = Request::new(Method::POST, url);
111 debug!("Created request: {} {}", request.method(), request.url());
112
113 debug!("Setting request headers");
115
116 let auth_header = match format!("Bearer {}", self.config.api_key).parse() {
118 Ok(header) => header,
119 Err(e) => {
120 error!("Invalid API key format: {}", e);
121 return Err(Error::Authentication("Invalid API key format".into()));
122 }
123 };
124
125 let content_type_header = match "application/json".parse() {
126 Ok(header) => header,
127 Err(e) => {
128 error!("Failed to set content type: {}", e);
129 return Err(Error::Other("Failed to set content type".into()));
130 }
131 };
132
133 request.headers_mut().insert("Authorization", auth_header);
134 request
135 .headers_mut()
136 .insert("Content-Type", content_type_header);
137
138 trace!("Request headers set: {:#?}", request.headers());
139
140 debug!("Creating request payload");
142 let payload = match self.create_request_payload(model, chat) {
143 Ok(payload) => {
144 debug!("Request payload created successfully");
145 trace!("Model: {}", payload.model);
146 trace!("Max tokens: {:?}", payload.max_tokens);
147 trace!("Number of messages: {}", payload.messages.len());
148 payload
149 }
150 Err(e) => {
151 error!("Failed to create request payload: {}", e);
152 return Err(e);
153 }
154 };
155
156 debug!("Serializing request payload");
158 let body_bytes = match serde_json::to_vec(&payload) {
159 Ok(bytes) => {
160 debug!("Payload serialized successfully ({} bytes)", bytes.len());
161 bytes
162 }
163 Err(e) => {
164 error!("Failed to serialize payload: {}", e);
165 return Err(Error::Serialization(e));
166 }
167 };
168
169 *request.body_mut() = Some(body_bytes.into());
170 info!("Request created successfully");
171
172 Ok(request)
173 }
174
175 fn parse(&self, raw_response_text: String) -> Result<Message> {
176 info!("Parsing response from Mistral API");
177 trace!("Raw response: {}", raw_response_text);
178
179 if let Ok(error_response) = serde_json::from_str::<MistralErrorResponse>(&raw_response_text)
181 {
182 if error_response.error.is_some() {
183 let error = error_response.error.unwrap();
184 error!("Mistral API returned an error: {}", error.message);
185 return Err(Error::ProviderUnavailable(error.message));
186 }
187 }
188
189 debug!("Deserializing response JSON");
191 let mistral_response = match serde_json::from_str::<MistralResponse>(&raw_response_text) {
192 Ok(response) => {
193 debug!("Response deserialized successfully");
194 debug!("Response id: {}", response.id);
195 debug!("Response model: {}", response.model);
196 if !response.choices.is_empty() {
197 debug!("Number of choices: {}", response.choices.len());
198 debug!(
199 "First choice finish reason: {:?}",
200 response.choices[0].finish_reason
201 );
202 }
203 if let Some(usage) = &response.usage {
204 debug!(
205 "Token usage - prompt: {}, completion: {}, total: {}",
206 usage.prompt_tokens, usage.completion_tokens, usage.total_tokens
207 );
208 }
209 response
210 }
211 Err(e) => {
212 error!("Failed to deserialize response: {}", e);
213 error!("Raw response: {}", raw_response_text);
214 return Err(Error::Serialization(e));
215 }
216 };
217
218 debug!("Converting Mistral response to Message");
220 let message = Message::from(&mistral_response);
221
222 info!("Response parsed successfully");
223 trace!("Response message processed");
224
225 Ok(message)
226 }
227}
228
229impl MistralProvider {
230 #[instrument(skip(self, chat), level = "debug")]
235 fn create_request_payload(&self, model: Mistral, chat: &Chat) -> Result<MistralRequest> {
236 info!("Creating request payload for chat with Mistral model");
237 debug!("System prompt length: {}", chat.system_prompt.len());
238 debug!("Messages in history: {}", chat.history.len());
239 debug!("Max output tokens: {}", chat.max_output_tokens);
240
241 let model_id = model.mistral_model_id();
242 debug!("Using model ID: {}", model_id);
243
244 debug!("Converting messages to Mistral format");
246 let mut messages: Vec<MistralMessage> = Vec::new();
247
248 if !chat.system_prompt.is_empty() {
250 debug!("Adding system prompt");
251 messages.push(MistralMessage {
252 role: "system".to_string(),
253 content: chat.system_prompt.clone(),
254 name: None,
255 tool_calls: None,
256 tool_call_id: None,
257 });
258 }
259
260 for msg in &chat.history {
262 debug!("Converting message with role: {}", msg.role_str());
263 messages.push(MistralMessage::from(msg));
264 }
265
266 debug!("Converted {} messages for the request", messages.len());
267
268 let tools = chat
270 .tools
271 .as_ref()
272 .map(|tools| tools.iter().map(MistralTool::from).collect());
273
274 let tool_choice = if let Some(choice) = &chat.tool_choice {
276 match choice {
278 crate::tool::ToolChoice::Auto => Some(serde_json::json!("auto")),
279 crate::tool::ToolChoice::Any => Some(serde_json::json!("required")),
281 crate::tool::ToolChoice::None => Some(serde_json::json!("none")),
282 crate::tool::ToolChoice::Specific(name) => {
283 Some(serde_json::json!({
285 "type": "function",
286 "function": {
287 "name": name
288 }
289 }))
290 }
291 }
292 } else if tools.is_some() {
293 Some(serde_json::json!("auto"))
295 } else {
296 None
297 };
298
299 debug!("Creating MistralRequest");
301 let request = MistralRequest {
302 model: model_id,
303 messages,
304 temperature: None,
305 top_p: None,
306 max_tokens: Some(chat.max_output_tokens),
307 stream: None,
308 random_seed: None,
309 safe_prompt: None,
310 tools,
311 tool_choice,
312 };
313
314 info!("Request payload created successfully");
315 Ok(request)
316 }
317}
318
319#[derive(Debug, Clone, Serialize, Deserialize)]
321pub(crate) struct MistralMessage {
322 pub role: String,
324 pub content: String,
326 #[serde(skip_serializing_if = "Option::is_none")]
328 pub name: Option<String>,
329 #[serde(skip_serializing_if = "Option::is_none")]
331 pub tool_calls: Option<Vec<MistralToolCall>>,
332 #[serde(skip_serializing_if = "Option::is_none")]
334 pub tool_call_id: Option<String>,
335}
336
337#[derive(Debug, Serialize, Deserialize)]
339pub(crate) struct MistralFunction {
340 pub name: String,
342 pub description: String,
344 pub parameters: serde_json::Value,
346}
347
348impl From<&LlmToolInfo> for MistralTool {
349 fn from(value: &LlmToolInfo) -> Self {
350 MistralTool {
351 tool_type: "function".to_string(),
352 function: MistralFunction {
353 name: value.name.clone(),
354 description: value.description.clone(),
355 parameters: value.parameters.clone(),
356 },
357 }
358 }
359}
360
361#[derive(Debug, Serialize, Deserialize)]
363pub(crate) struct MistralTool {
364 #[serde(rename = "type")]
366 pub tool_type: String,
367 pub function: MistralFunction,
369}
370
371#[derive(Debug, Clone, Serialize, Deserialize)]
373pub(crate) struct MistralFunctionCall {
374 pub name: String,
376 pub arguments: String,
378}
379
380#[derive(Debug, Clone, Serialize, Deserialize)]
382pub(crate) struct MistralToolCall {
383 pub id: String,
385 pub function: MistralFunctionCall,
387}
388
389#[derive(Debug, Serialize, Deserialize)]
391pub(crate) struct MistralRequest {
392 pub model: String,
394 pub messages: Vec<MistralMessage>,
396 #[serde(skip_serializing_if = "Option::is_none")]
398 pub temperature: Option<f32>,
399 #[serde(skip_serializing_if = "Option::is_none")]
401 pub top_p: Option<f32>,
402 #[serde(skip_serializing_if = "Option::is_none")]
404 pub max_tokens: Option<usize>,
405 #[serde(skip_serializing_if = "Option::is_none")]
407 pub stream: Option<bool>,
408 #[serde(skip_serializing_if = "Option::is_none")]
410 pub random_seed: Option<u64>,
411 #[serde(skip_serializing_if = "Option::is_none")]
413 pub safe_prompt: Option<bool>,
414 #[serde(skip_serializing_if = "Option::is_none")]
416 pub tools: Option<Vec<MistralTool>>,
417 #[serde(skip_serializing_if = "Option::is_none")]
419 pub tool_choice: Option<serde_json::Value>,
420}
421
422#[derive(Debug, Serialize, Deserialize)]
424pub(crate) struct MistralResponse {
425 pub id: String,
427 pub object: String,
429 pub created: u64,
431 pub model: String,
433 pub choices: Vec<MistralChoice>,
435 pub usage: Option<MistralUsage>,
437}
438
439#[derive(Debug, Serialize, Deserialize)]
441pub(crate) struct MistralChoice {
442 pub index: usize,
444 pub message: MistralMessage,
446 pub finish_reason: Option<String>,
448}
449
450#[derive(Debug, Serialize, Deserialize)]
452pub(crate) struct MistralUsage {
453 pub prompt_tokens: u32,
455 pub completion_tokens: u32,
457 pub total_tokens: u32,
459}
460
461#[derive(Debug, Serialize, Deserialize)]
463pub(crate) struct MistralErrorResponse {
464 pub error: Option<MistralError>,
466}
467
468#[derive(Debug, Serialize, Deserialize)]
470pub(crate) struct MistralError {
471 pub message: String,
473 #[serde(rename = "type")]
475 pub error_type: String,
476 #[serde(skip_serializing_if = "Option::is_none")]
478 pub code: Option<String>,
479}
480
481impl From<&Message> for MistralMessage {
483 fn from(msg: &Message) -> Self {
484 let role = match msg {
485 Message::System { .. } => "system",
486 Message::User { .. } => "user",
487 Message::Assistant { .. } => "assistant",
488 Message::Tool { .. } => "tool",
489 }
490 .to_string();
491
492 let (content, name, tool_calls, tool_call_id) = match msg {
493 Message::System { content, .. } => (content.clone(), None, None, None),
494 Message::User { content, name, .. } => {
495 let content_str = match content {
496 Content::Text(text) => text.clone(),
497 Content::Parts(parts) => {
498 parts
501 .iter()
502 .filter_map(|part| match part {
503 ContentPart::Text { text } => Some(text.clone()),
504 _ => None,
505 })
506 .collect::<Vec<String>>()
507 .join("\n")
508 }
509 };
510 (content_str, name.clone(), None, None)
511 }
512 Message::Assistant {
513 content,
514 tool_calls,
515 ..
516 } => {
517 let content_str = match content {
518 Some(Content::Text(text)) => text.clone(),
519 Some(Content::Parts(parts)) => {
520 parts
522 .iter()
523 .filter_map(|part| match part {
524 ContentPart::Text { text } => Some(text.clone()),
525 _ => None,
526 })
527 .collect::<Vec<String>>()
528 .join("\n")
529 }
530 None => String::new(),
531 };
532
533 let mistral_tool_calls = if !tool_calls.is_empty() {
535 let mut calls = Vec::with_capacity(tool_calls.len());
536
537 for tc in tool_calls {
538 calls.push(MistralToolCall {
539 id: tc.id.clone(),
540 function: MistralFunctionCall {
541 name: tc.function.name.clone(),
542 arguments: tc.function.arguments.clone(),
543 },
544 });
545 }
546
547 Some(calls)
548 } else {
549 None
550 };
551
552 (content_str, None, mistral_tool_calls, None)
553 }
554 Message::Tool {
555 tool_call_id,
556 content,
557 ..
558 } => (content.clone(), None, None, Some(tool_call_id.clone())),
559 };
560
561 MistralMessage {
562 role,
563 content,
564 name,
565 tool_calls,
566 tool_call_id,
567 }
568 }
569}
570
571impl From<&MistralResponse> for Message {
573 fn from(response: &MistralResponse) -> Self {
574 if response.choices.is_empty() {
576 return Message::assistant("No response generated");
577 }
578
579 let choice = &response.choices[0];
580 let message = &choice.message;
581
582 let mut msg = match message.role.as_str() {
584 "assistant" => {
585 let content = Some(Content::Text(message.content.clone()));
586
587 if let Some(mistral_tool_calls) = &message.tool_calls {
589 if !mistral_tool_calls.is_empty() {
590 let mut tool_calls = Vec::with_capacity(mistral_tool_calls.len());
591
592 for call in mistral_tool_calls {
593 let tool_call = crate::message::ToolCall {
594 id: call.id.clone(),
595 tool_type: "function".to_string(),
596 function: crate::message::Function {
597 name: call.function.name.clone(),
598 arguments: call.function.arguments.clone(),
599 },
600 };
601 tool_calls.push(tool_call);
602 }
603
604 Message::Assistant {
605 content,
606 tool_calls,
607 metadata: Default::default(),
608 }
609 } else {
610 if let Some(Content::Text(text)) = content {
612 Message::assistant(text)
613 } else {
614 Message::Assistant {
615 content,
616 tool_calls: Vec::new(),
617 metadata: Default::default(),
618 }
619 }
620 }
621 } else {
622 if let Some(Content::Text(text)) = content {
624 Message::assistant(text)
625 } else {
626 Message::Assistant {
627 content,
628 tool_calls: Vec::new(),
629 metadata: Default::default(),
630 }
631 }
632 }
633 }
634 "user" => {
635 if let Some(name) = &message.name {
636 Message::user_with_name(name, message.content.clone())
637 } else {
638 Message::user(message.content.clone())
639 }
640 }
641 "system" => Message::system(message.content.clone()),
642 "tool" => {
643 if let Some(tool_call_id) = &message.tool_call_id {
644 Message::tool(tool_call_id, message.content.clone())
645 } else {
646 Message::user(message.content.clone())
648 }
649 }
650 _ => Message::user(message.content.clone()), };
652
653 if let Some(usage) = &response.usage {
655 msg = msg.with_metadata(
656 "prompt_tokens",
657 serde_json::Value::Number(usage.prompt_tokens.into()),
658 );
659 msg = msg.with_metadata(
660 "completion_tokens",
661 serde_json::Value::Number(usage.completion_tokens.into()),
662 );
663 msg = msg.with_metadata(
664 "total_tokens",
665 serde_json::Value::Number(usage.total_tokens.into()),
666 );
667 }
668
669 msg
670 }
671}
672
#[cfg(test)]
mod tests {
    use super::*;

    /// Text-only messages keep their role name and content verbatim.
    #[test]
    fn test_message_conversion() {
        let cases = [
            (Message::user("Hello, world!"), "user", "Hello, world!"),
            (
                Message::system("You are a helpful assistant."),
                "system",
                "You are a helpful assistant.",
            ),
            (
                Message::assistant("I can help with that."),
                "assistant",
                "I can help with that.",
            ),
        ];

        for (message, expected_role, expected_content) in &cases {
            let converted = MistralMessage::from(message);
            assert_eq!(converted.role, *expected_role);
            assert_eq!(converted.content, *expected_content);
        }
    }

    /// A nested `error` object deserializes with its type and code intact.
    #[test]
    fn test_error_response_parsing() {
        let error_json = r#"{
            "error": {
                "message": "The model does not exist",
                "type": "invalid_request_error",
                "code": "model_not_found"
            }
        }"#;

        let parsed: MistralErrorResponse =
            serde_json::from_str(error_json).expect("error body should deserialize");
        let error = parsed
            .error
            .expect("nested error object should be present");
        assert_eq!(error.error_type, "invalid_request_error");
        assert_eq!(error.code.as_deref(), Some("model_not_found"));
    }
}