use async_trait::async_trait;
use serde::{Deserialize, Serialize};

use crate::models::{Model, ModelConfig, StreamEventStream};
use crate::types::content::{Message, Role, SystemContentBlock};
use crate::types::errors::StrandsError;
use crate::types::streaming::{StopReason, StreamEvent};
use crate::types::tools::{ToolChoice, ToolSpec};

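/// Configuration for the Mistral model provider.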
#[derive(Debug, Clone, Default)]
pub struct MistralConfig {
    /// Mistral model identifier, e.g. "mistral-large-latest".
    pub model_id: String,
    /// Maximum number of tokens to generate.
    pub max_tokens: Option<u32>,
    /// Sampling temperature.
    pub temperature: Option<f32>,
    /// Nucleus sampling parameter.
    pub top_p: Option<f32>,
    /// API key; falls back to the MISTRAL_API_KEY environment variable when unset.
    pub api_key: Option<String>,
}

impl MistralConfig {
    pub fn new(model_id: impl Into<String>) -> Self {
        Self {
            model_id: model_id.into(),
            ..Default::default()
        }
    }

    pub fn with_api_key(mut self, api_key: impl Into<String>) -> Self {
        self.api_key = Some(api_key.into());
        self
    }

    pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
        self.max_tokens = Some(max_tokens);
        self
    }

    pub fn with_temperature(mut self, temperature: f32) -> Self {
        self.temperature = Some(temperature);
        self
    }
}

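/// Request body for Mistral's chat completions endpoint.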
#[derive(Debug, Serialize)]
struct MistralRequest {
    model: String,
    messages: Vec<MistralMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<MistralTool>>,
    stream: bool,
}

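/// A chat message in Mistral's wire format.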
#[derive(Debug, Serialize, Deserialize)]
struct MistralMessage {
    role: String,
    content: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_calls: Option<Vec<MistralToolCall>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_call_id: Option<String>,
}

#[derive(Debug, Serialize, Deserialize)]
struct MistralToolCall {
    id: String,
    #[serde(rename = "type")]
    call_type: String,
    function: MistralFunction,
}

#[derive(Debug, Serialize, Deserialize)]
struct MistralFunction {
    name: String,
    arguments: String,
}

#[derive(Debug, Serialize)]
struct MistralTool {
    #[serde(rename = "type")]
    tool_type: String,
    function: MistralFunctionDef,
}

#[derive(Debug, Serialize)]
struct MistralFunctionDef {
    name: String,
    description: String,
    parameters: serde_json::Value,
}

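/// Model provider backed by the Mistral chat completions API.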
pub struct MistralModel {
    config: ModelConfig,
    mistral_config: MistralConfig,
    client: reqwest::Client,
}

impl MistralModel {
    const BASE_URL: &'static str = "https://api.mistral.ai/v1";

    /// Creates a Mistral model provider, copying the generation settings into
    /// the shared [`ModelConfig`].
    pub fn new(config: MistralConfig) -> Self {
        let model_config = ModelConfig {
            model_id: config.model_id.clone(),
            max_tokens: config.max_tokens,
            temperature: config.temperature,
            top_p: config.top_p,
            ..Default::default()
        };

        Self {
            config: model_config,
            mistral_config: config,
            client: reqwest::Client::new(),
        }
    }

    /// Resolves the API key, preferring the explicitly configured value over
    /// the MISTRAL_API_KEY environment variable.
    fn api_key(&self) -> Result<String, StrandsError> {
        self.mistral_config
            .api_key
            .clone()
            .or_else(|| std::env::var("MISTRAL_API_KEY").ok())
            .ok_or_else(|| StrandsError::ConfigurationError {
                message: "Mistral API key not configured. Set MISTRAL_API_KEY or provide api_key".into(),
            })
    }

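    /// Converts framework messages (plus an optional system prompt) into
    /// Mistral's message format, mapping tool-use blocks to tool calls and
    /// tool results to "tool" role messages.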
    fn convert_messages(&self, messages: &[Message], system_prompt: Option<&str>) -> Vec<MistralMessage> {
        let mut result = Vec::new();

        if let Some(prompt) = system_prompt {
            result.push(MistralMessage {
                role: "system".to_string(),
                content: prompt.to_string(),
                tool_calls: None,
                tool_call_id: None,
            });
        }

        for msg in messages {
            let role = match msg.role {
                Role::User => "user",
                Role::Assistant => "assistant",
            };

            let content = msg.text_content();

            // Collect any tool-use blocks as Mistral tool calls.
            let tool_calls: Option<Vec<MistralToolCall>> = {
                let calls: Vec<_> = msg
                    .content
                    .iter()
                    .filter_map(|b| b.tool_use.as_ref())
                    .map(|tu| MistralToolCall {
                        id: tu.tool_use_id.clone(),
                        call_type: "function".to_string(),
                        function: MistralFunction {
                            name: tu.name.clone(),
                            arguments: serde_json::to_string(&tu.input).unwrap_or_default(),
                        },
                    })
                    .collect();

                if calls.is_empty() {
                    None
                } else {
                    Some(calls)
                }
            };

            if tool_calls.is_some() {
                // Assistant message that requested one or more tool calls.
                result.push(MistralMessage {
                    role: role.to_string(),
                    content,
                    tool_calls,
                    tool_call_id: None,
                });
            } else if msg.has_tool_result() {
                // Tool results become separate "tool" role messages, one per result.
                for block in &msg.content {
                    if let Some(tr) = &block.tool_result {
                        let content_text = tr
                            .content
                            .iter()
                            .filter_map(|c| c.text.as_ref())
                            .cloned()
                            .collect::<Vec<_>>()
                            .join("");

                        result.push(MistralMessage {
                            role: "tool".to_string(),
                            content: content_text,
                            tool_calls: None,
                            tool_call_id: Some(tr.tool_use_id.clone()),
                        });
                    }
                }
            } else {
                // Plain text message.
                result.push(MistralMessage {
                    role: role.to_string(),
                    content,
                    tool_calls: None,
                    tool_call_id: None,
                });
            }
        }

        result
    }

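    /// Converts tool specifications into Mistral's function-calling tool format.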
    fn convert_tools(&self, tool_specs: &[ToolSpec]) -> Vec<MistralTool> {
        tool_specs
            .iter()
            .map(|spec| MistralTool {
                tool_type: "function".to_string(),
                function: MistralFunctionDef {
                    name: spec.name.clone(),
                    description: spec.description.clone(),
                    parameters: serde_json::to_value(&spec.input_schema).unwrap_or_default(),
                },
            })
            .collect()
    }
}

#[async_trait]
impl Model for MistralModel {
    fn config(&self) -> &ModelConfig {
        &self.config
    }

    fn update_config(&mut self, config: ModelConfig) {
        self.config = config;
    }

    fn stream<'a>(
        &'a self,
        messages: &'a [Message],
        tool_specs: Option<&'a [ToolSpec]>,
        system_prompt: Option<&'a str>,
        _tool_choice: Option<ToolChoice>,
        _system_prompt_content: Option<&'a [SystemContentBlock]>,
    ) -> StreamEventStream<'a> {
        let messages = messages.to_vec();
        let tool_specs = tool_specs.map(|t| t.to_vec());
        let system_prompt = system_prompt.map(|s| s.to_string());

        Box::pin(async_stream::stream! {
            let api_key = match self.api_key() {
                Ok(key) => key,
                Err(e) => {
                    yield Err(e);
                    return;
                }
            };

            let mistral_messages = self.convert_messages(&messages, system_prompt.as_deref());
            let tools = tool_specs.as_ref().map(|specs| self.convert_tools(specs));

            let request = MistralRequest {
                model: self.config.model_id.clone(),
                messages: mistral_messages,
                max_tokens: self.config.max_tokens,
                temperature: self.config.temperature,
                top_p: self.config.top_p,
                tools,
                stream: true,
            };

            let url = format!("{}/chat/completions", Self::BASE_URL);

            let response = match self.client
                .post(&url)
                .header("Authorization", format!("Bearer {}", api_key))
                .header("Content-Type", "application/json")
                .json(&request)
                .send()
                .await
            {
                Ok(resp) => resp,
                Err(e) => {
                    yield Err(StrandsError::NetworkError(e.to_string()));
                    return;
                }
            };

            if !response.status().is_success() {
                let status = response.status();
                let body = response.text().await.unwrap_or_default();

                // Map HTTP 429 to the dedicated throttling error so callers can retry.
                if status.as_u16() == 429 {
                    yield Err(StrandsError::ModelThrottled {
                        message: "Mistral rate limit exceeded".into(),
                    });
                } else {
                    yield Err(StrandsError::ModelError {
                        message: format!("Mistral API error {}: {}", status, body),
                        source: None,
                    });
                }
                return;
            }

            yield Ok(StreamEvent::message_start(Role::Assistant));

            // The SSE response body is collected in full and then parsed line by
            // line; only text content deltas are forwarded as stream events.
            let body = match response.text().await {
                Ok(b) => b,
                Err(e) => {
                    yield Err(StrandsError::NetworkError(e.to_string()));
                    return;
                }
            };

            for line in body.lines() {
                if line.starts_with("data: ") {
                    let data = &line[6..];
                    // The stream terminates with a "[DONE]" sentinel.
                    if data == "[DONE]" {
                        break;
                    }

                    if let Ok(chunk) = serde_json::from_str::<serde_json::Value>(data) {
                        if let Some(choices) = chunk.get("choices").and_then(|c| c.as_array()) {
                            for choice in choices {
                                if let Some(delta) = choice.get("delta") {
                                    if let Some(content) = delta.get("content").and_then(|c| c.as_str()) {
                                        yield Ok(StreamEvent::text_delta(0, content));
                                    }
                                }
                            }
                        }
                    }
                }
            }

            yield Ok(StreamEvent::message_stop(StopReason::EndTurn));
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mistral_config() {
        let config = MistralConfig::new("mistral-large-latest")
            .with_api_key("test-key")
            .with_temperature(0.7);

        assert_eq!(config.model_id, "mistral-large-latest");
        assert_eq!(config.api_key, Some("test-key".to_string()));
        assert_eq!(config.temperature, Some(0.7));
    }
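
    // Construction-only sanity check; no network call is made. The model id
    // below is just an illustrative value.
    #[test]
    fn test_mistral_model_new_copies_config() {
        let model = MistralModel::new(
            MistralConfig::new("mistral-small-latest")
                .with_api_key("test-key")
                .with_max_tokens(256),
        );

        assert_eq!(model.config().model_id, "mistral-small-latest");
        assert_eq!(model.config().max_tokens, Some(256));

        let key = model.api_key().ok();
        assert_eq!(key.as_deref(), Some("test-key"));
    }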
}