use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use crate::models::{Model, ModelConfig, StreamEventStream};
use crate::types::content::{Message, Role, SystemContentBlock};
use crate::types::errors::StrandsError;
use crate::types::streaming::{StopReason, StreamEvent};
use crate::types::tools::{ToolChoice, ToolSpec};
/// Provider-specific settings for [`MistralModel`].
///
/// Construct with [`MistralConfig::new`] and chain the `with_*` builders.
/// Optional fields left as `None` are not sent to the API. When `api_key` is
/// `None`, the key is read from the `MISTRAL_API_KEY` environment variable at
/// request time.
///
/// `Debug` is implemented by hand so the API key is never printed in logs.
#[derive(Clone, Default)]
pub struct MistralConfig {
    /// Mistral model identifier, e.g. `"mistral-large-latest"`.
    pub model_id: String,
    /// Upper bound on generated tokens (`max_tokens`).
    pub max_tokens: Option<u32>,
    /// Sampling temperature (`temperature`).
    pub temperature: Option<f32>,
    /// Nucleus-sampling probability mass (`top_p`).
    pub top_p: Option<f32>,
    /// API key; when `None`, `MISTRAL_API_KEY` is used instead.
    pub api_key: Option<String>,
}

impl std::fmt::Debug for MistralConfig {
    /// Formats like a derived `Debug`, but replaces a configured API key with
    /// `<redacted>` so secrets do not leak through logging.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("MistralConfig")
            .field("model_id", &self.model_id)
            .field("max_tokens", &self.max_tokens)
            .field("temperature", &self.temperature)
            .field("top_p", &self.top_p)
            .field("api_key", &self.api_key.as_ref().map(|_| "<redacted>"))
            .finish()
    }
}

impl MistralConfig {
    /// Creates a config for `model_id` with every optional setting unset.
    #[must_use]
    pub fn new(model_id: impl Into<String>) -> Self {
        Self {
            model_id: model_id.into(),
            ..Default::default()
        }
    }

    /// Sets an explicit API key, taking precedence over `MISTRAL_API_KEY`.
    #[must_use]
    pub fn with_api_key(mut self, api_key: impl Into<String>) -> Self {
        self.api_key = Some(api_key.into());
        self
    }

    /// Sets the maximum number of tokens to generate.
    #[must_use]
    pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
        self.max_tokens = Some(max_tokens);
        self
    }

    /// Sets the sampling temperature.
    #[must_use]
    pub fn with_temperature(mut self, temperature: f32) -> Self {
        self.temperature = Some(temperature);
        self
    }

    /// Sets the nucleus-sampling `top_p` value.
    #[must_use]
    pub fn with_top_p(mut self, top_p: f32) -> Self {
        self.top_p = Some(top_p);
        self
    }
}
/// JSON body for `POST {BASE_URL}/chat/completions`, serialized with reqwest's `.json(...)`.
///
/// `Option` fields marked `skip_serializing_if` are left out of the JSON entirely when `None`,
/// so the API applies its own defaults for them.
#[derive(Debug, Serialize)]
struct MistralRequest {
/// Model identifier, taken from `ModelConfig::model_id`.
model: String,
/// Conversation in Mistral's message format (system prompt first, if any).
messages: Vec<MistralMessage>,
/// Maximum tokens to generate; omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
max_tokens: Option<u32>,
/// Sampling temperature; omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
temperature: Option<f32>,
/// Nucleus-sampling `top_p`; omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
top_p: Option<f32>,
/// Tool definitions offered to the model; omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
tools: Option<Vec<MistralTool>>,
/// Requests a server-sent-events streaming response (`stream` always sets this to `true`).
stream: bool,
}
/// One chat message in Mistral's wire format.
///
/// `convert_messages` emits the roles `"system"`, `"user"`, `"assistant"` and `"tool"`.
#[derive(Debug, Serialize, Deserialize)]
struct MistralMessage {
/// Message role string.
role: String,
/// Plain-text content of the message.
content: String,
/// Tool invocations carried by this message; only set when the source message
/// contained tool-use blocks, omitted from the JSON otherwise.
#[serde(skip_serializing_if = "Option::is_none")]
tool_calls: Option<Vec<MistralToolCall>>,
/// For `"tool"` messages, the id of the tool call this result answers; omitted otherwise.
#[serde(skip_serializing_if = "Option::is_none")]
tool_call_id: Option<String>,
}
/// A tool invocation attached to a message.
#[derive(Debug, Serialize, Deserialize)]
struct MistralToolCall {
/// Call id; later `"tool"` messages reference it through `tool_call_id`.
id: String,
/// Serialized as `"type"`; this file always sets it to `"function"`.
#[serde(rename = "type")]
call_type: String,
/// Name and JSON-encoded arguments of the function being called.
function: MistralFunction,
}
/// Function target of a [`MistralToolCall`].
#[derive(Debug, Serialize, Deserialize)]
struct MistralFunction {
/// Tool name.
name: String,
/// Tool input encoded as a JSON string (produced with `serde_json::to_string`).
arguments: String,
}
/// A tool definition offered to the model in `MistralRequest::tools`.
#[derive(Debug, Serialize)]
struct MistralTool {
/// Serialized as `"type"`; `convert_tools` always sets it to `"function"`.
#[serde(rename = "type")]
tool_type: String,
/// Name, description and parameter schema of the function.
function: MistralFunctionDef,
}
/// Function metadata inside a [`MistralTool`].
#[derive(Debug, Serialize)]
struct MistralFunctionDef {
/// Tool name the model uses when calling it.
name: String,
/// Human-readable description shown to the model.
description: String,
/// The tool's `input_schema` converted to a JSON value (presumably a JSON Schema
/// object — confirm against `ToolSpec`). `convert_tools` yields `null` if conversion fails.
parameters: serde_json::Value,
}
/// [`Model`] implementation backed by Mistral's chat-completions HTTP API.
pub struct MistralModel {
/// Generic settings used to build each request; returned by `Model::config`
/// and replaced wholesale by `Model::update_config`.
config: ModelConfig,
/// Provider config passed to `new`; consulted for the API key at request time.
mistral_config: MistralConfig,
/// HTTP client shared by every request this model sends.
client: reqwest::Client,
}
impl MistralModel {
    /// Base URL of the Mistral REST API; endpoint paths are appended to it.
    const BASE_URL: &'static str = "https://api.mistral.ai/v1";

    /// Creates a Mistral-backed model from `config`.
    ///
    /// The generic [`ModelConfig`] is seeded with the model id and the
    /// `max_tokens` / `temperature` / `top_p` settings from `config`; every
    /// other `ModelConfig` field takes its `Default` value. `config` itself is
    /// kept so the API key can be resolved when a request is made.
    pub fn new(config: MistralConfig) -> Self {
        let model_config = ModelConfig {
            model_id: config.model_id.clone(),
            max_tokens: config.max_tokens,
            temperature: config.temperature,
            top_p: config.top_p,
            ..Default::default()
        };
        Self {
            config: model_config,
            mistral_config: config,
            client: reqwest::Client::new(),
        }
    }

    /// Resolves the API key to send as a bearer token.
    ///
    /// An explicitly configured key takes precedence; otherwise the
    /// `MISTRAL_API_KEY` environment variable is read. Surrounding whitespace
    /// is stripped, and blank values count as "not configured".
    ///
    /// # Errors
    ///
    /// Returns [`StrandsError::ConfigurationError`] when neither source yields
    /// a non-blank key.
    fn api_key(&self) -> Result<String, StrandsError> {
        // A blank key (e.g. `MISTRAL_API_KEY=` exported empty) would otherwise be
        // sent as `Bearer ` and surface only as an opaque 401 from the API.
        fn non_blank(key: String) -> Option<String> {
            let trimmed = key.trim();
            (!trimmed.is_empty()).then(|| trimmed.to_string())
        }

        self.mistral_config
            .api_key
            .clone()
            .and_then(non_blank)
            .or_else(|| std::env::var("MISTRAL_API_KEY").ok().and_then(non_blank))
            .ok_or_else(|| StrandsError::ConfigurationError {
                message: "Mistral API key not configured. Set MISTRAL_API_KEY or provide api_key"
                    .into(),
            })
    }

    /// Converts Strands messages into Mistral chat-completion messages.
    ///
    /// * A `system_prompt`, if given, becomes a leading `"system"` message.
    /// * A message containing tool-use blocks becomes one message with its own
    ///   role, its text content, and those calls in `tool_calls`.
    /// * Otherwise, a message containing tool results becomes one `"tool"`
    ///   message per result, linked by `tool_call_id`. Each result's text
    ///   parts are concatenated without a separator.
    /// * Any other message becomes a single text message with its role.
    fn convert_messages(
        &self,
        messages: &[Message],
        system_prompt: Option<&str>,
    ) -> Vec<MistralMessage> {
        let mut result = Vec::with_capacity(messages.len() + 1);
        if let Some(prompt) = system_prompt {
            result.push(MistralMessage {
                role: "system".to_string(),
                content: prompt.to_string(),
                tool_calls: None,
                tool_call_id: None,
            });
        }

        for msg in messages {
            let role = match msg.role {
                Role::User => "user",
                Role::Assistant => "assistant",
            };
            let content = msg.text_content();
            let calls: Vec<MistralToolCall> = msg
                .content
                .iter()
                .filter_map(|b| b.tool_use.as_ref())
                .map(|tu| MistralToolCall {
                    id: tu.tool_use_id.clone(),
                    call_type: "function".to_string(),
                    function: MistralFunction {
                        name: tu.name.clone(),
                        // `arguments` must hold a JSON document. If encoding fails,
                        // send an empty object rather than an empty (invalid) string.
                        arguments: serde_json::to_string(&tu.input)
                            .unwrap_or_else(|_| "{}".to_string()),
                    },
                })
                .collect();

            if !calls.is_empty() {
                result.push(MistralMessage {
                    role: role.to_string(),
                    content,
                    tool_calls: Some(calls),
                    tool_call_id: None,
                });
            } else if msg.has_tool_result() {
                // NOTE(review): any plain text blocks in a tool-result message are not
                // forwarded; only the tool results are sent.
                for tr in msg.content.iter().filter_map(|b| b.tool_result.as_ref()) {
                    let content_text = tr
                        .content
                        .iter()
                        .filter_map(|c| c.text.as_ref())
                        .cloned()
                        .collect::<Vec<_>>()
                        .join("");
                    result.push(MistralMessage {
                        role: "tool".to_string(),
                        content: content_text,
                        tool_calls: None,
                        tool_call_id: Some(tr.tool_use_id.clone()),
                    });
                }
            } else {
                result.push(MistralMessage {
                    role: role.to_string(),
                    content,
                    tool_calls: None,
                    tool_call_id: None,
                });
            }
        }
        result
    }

    /// Converts Strands tool specs into Mistral `"function"` tool definitions.
    ///
    /// If a spec's `input_schema` cannot be converted to JSON, `parameters`
    /// falls back to `serde_json::Value::default()` (JSON `null`).
    fn convert_tools(&self, tool_specs: &[ToolSpec]) -> Vec<MistralTool> {
        tool_specs
            .iter()
            .map(|spec| MistralTool {
                tool_type: "function".to_string(),
                function: MistralFunctionDef {
                    name: spec.name.clone(),
                    description: spec.description.clone(),
                    parameters: serde_json::to_value(&spec.input_schema).unwrap_or_default(),
                },
            })
            .collect()
    }
}
#[async_trait]
impl Model for MistralModel {
    /// Returns the generic configuration used to build requests.
    fn config(&self) -> &ModelConfig {
        &self.config
    }

    /// Replaces the generic configuration.
    ///
    /// Requests read the model id, `max_tokens`, `temperature` and `top_p`
    /// from this config. The API key is still resolved from the
    /// `MistralConfig` passed to `MistralModel::new`.
    fn update_config(&mut self, config: ModelConfig) {
        self.config = config;
    }

    /// Sends a streaming chat-completion request and yields Strands stream events.
    ///
    /// On success the stream yields, in order:
    /// 1. `message_start(Assistant)`;
    /// 2. one `text_delta(0, ..)` per `choices[].delta.content` string in the SSE body;
    /// 3. `message_stop(EndTurn)`.
    ///
    /// # Errors
    ///
    /// Yields a single `Err` and then ends:
    /// * `ConfigurationError` when no API key is configured;
    /// * `NetworkError` when the request or the body read fails;
    /// * `ModelThrottled` on HTTP 429;
    /// * `ModelError` on any other non-success status.
    ///
    /// NOTE(review) known gaps:
    /// * The whole response body is buffered before it is parsed, so deltas
    ///   arrive only after the response completes.
    /// * `delta.tool_calls` and `finish_reason` are ignored, so the stop reason
    ///   is always `EndTurn`.
    /// * `_tool_choice` and `_system_prompt_content` are not used.
    /// * Chunks that are not valid JSON are skipped silently.
    fn stream<'a>(
        &'a self,
        messages: &'a [Message],
        tool_specs: Option<&'a [ToolSpec]>,
        system_prompt: Option<&'a str>,
        _tool_choice: Option<ToolChoice>,
        _system_prompt_content: Option<&'a [SystemContentBlock]>,
    ) -> StreamEventStream<'a> {
        let messages = messages.to_vec();
        let tool_specs = tool_specs.map(|t| t.to_vec());
        let system_prompt = system_prompt.map(|s| s.to_string());
        Box::pin(async_stream::stream! {
            let api_key = match self.api_key() {
                Ok(key) => key,
                Err(e) => {
                    yield Err(e);
                    return;
                }
            };
            let mistral_messages = self.convert_messages(&messages, system_prompt.as_deref());
            // When no tools are offered, leave the `tools` field out entirely
            // instead of sending `"tools": []`.
            let tools = tool_specs
                .as_deref()
                .filter(|specs| !specs.is_empty())
                .map(|specs| self.convert_tools(specs));
            let request = MistralRequest {
                model: self.config.model_id.clone(),
                messages: mistral_messages,
                max_tokens: self.config.max_tokens,
                temperature: self.config.temperature,
                top_p: self.config.top_p,
                tools,
                stream: true,
            };
            let url = format!("{}/chat/completions", Self::BASE_URL);
            let response = match self.client
                .post(&url)
                .header("Authorization", format!("Bearer {}", api_key))
                .header("Content-Type", "application/json")
                .json(&request)
                .send()
                .await
            {
                Ok(resp) => resp,
                Err(e) => {
                    yield Err(StrandsError::NetworkError(e.to_string()));
                    return;
                }
            };
            if !response.status().is_success() {
                let status = response.status();
                let body = response.text().await.unwrap_or_default();
                if status.as_u16() == 429 {
                    yield Err(StrandsError::ModelThrottled {
                        message: "Mistral rate limit exceeded".into(),
                    });
                } else {
                    yield Err(StrandsError::ModelError {
                        message: format!("Mistral API error {}: {}", status, body),
                        source: None,
                    });
                }
                return;
            }
            yield Ok(StreamEvent::message_start(Role::Assistant));
            let body = match response.text().await {
                Ok(b) => b,
                Err(e) => {
                    yield Err(StrandsError::NetworkError(e.to_string()));
                    return;
                }
            };
            for line in body.lines() {
                let data = match sse_data(line) {
                    Some(data) => data,
                    None => continue,
                };
                if data == "[DONE]" {
                    break;
                }
                // Best effort: a chunk that is not valid JSON is skipped rather than
                // aborting the whole response.
                let chunk = match serde_json::from_str::<serde_json::Value>(data) {
                    Ok(chunk) => chunk,
                    Err(_) => continue,
                };
                let choices = chunk.get("choices").and_then(|c| c.as_array());
                for choice in choices.into_iter().flatten() {
                    if let Some(content) = choice
                        .get("delta")
                        .and_then(|d| d.get("content"))
                        .and_then(|c| c.as_str())
                    {
                        yield Ok(StreamEvent::text_delta(0, content));
                    }
                }
            }
            yield Ok(StreamEvent::message_stop(StopReason::EndTurn));
        })
    }
}

/// Returns the value of a Server-Sent Events `data:` field line.
///
/// Returns `None` for any other line, such as comments, `event:`/`id:` fields,
/// or blank event separators. Per the SSE format, the space after the colon is
/// optional, and a single leading space is not part of the value.
fn sse_data(line: &str) -> Option<&str> {
    let value = line.strip_prefix("data:")?;
    Some(value.strip_prefix(' ').unwrap_or(value))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder methods set the fields they name.
    #[test]
    fn test_mistral_config() {
        let config = MistralConfig::new("mistral-large-latest")
            .with_api_key("test-key")
            .with_temperature(0.7);
        assert_eq!(config.model_id, "mistral-large-latest");
        assert_eq!(config.api_key, Some("test-key".to_string()));
        assert_eq!(config.temperature, Some(0.7));
    }

    /// `new` leaves every optional setting unset.
    #[test]
    fn test_mistral_config_defaults() {
        let config = MistralConfig::new("mistral-small-latest");
        assert_eq!(config.model_id, "mistral-small-latest");
        assert_eq!(config.api_key, None);
        assert_eq!(config.max_tokens, None);
        assert_eq!(config.temperature, None);
        assert_eq!(config.top_p, None);
    }

    #[test]
    fn test_mistral_config_max_tokens() {
        let config = MistralConfig::new("m").with_max_tokens(1024);
        assert_eq!(config.max_tokens, Some(1024));
    }

    /// `MistralModel::new` copies the request settings into the generic config.
    #[test]
    fn test_model_new_seeds_model_config() {
        let model = MistralModel::new(
            MistralConfig::new("mistral-small-latest")
                .with_max_tokens(256)
                .with_temperature(0.2),
        );
        assert_eq!(model.config.model_id, "mistral-small-latest");
        assert_eq!(model.config.max_tokens, Some(256));
        assert_eq!(model.config.temperature, Some(0.2));
    }

    /// An explicit key wins without consulting the environment.
    #[test]
    fn test_explicit_api_key_is_used() {
        let model = MistralModel::new(MistralConfig::new("m").with_api_key("test-key"));
        assert_eq!(model.api_key().ok(), Some("test-key".to_string()));
    }

    #[test]
    fn test_system_prompt_becomes_leading_system_message() {
        let model = MistralModel::new(MistralConfig::new("m"));
        let converted = model.convert_messages(&[], Some("be brief"));
        assert_eq!(converted.len(), 1);
        assert_eq!(converted[0].role, "system");
        assert_eq!(converted[0].content, "be brief");
        assert!(converted[0].tool_calls.is_none());
        assert!(converted[0].tool_call_id.is_none());
    }

    /// Unset optional request fields are left out of the JSON body.
    #[test]
    fn test_request_omits_unset_options() {
        let request = MistralRequest {
            model: "m".to_string(),
            messages: Vec::new(),
            max_tokens: None,
            temperature: None,
            top_p: None,
            tools: None,
            stream: true,
        };
        let json = serde_json::to_value(&request).unwrap();
        assert_eq!(
            json,
            serde_json::json!({ "model": "m", "messages": [], "stream": true })
        );
    }
}