use std::time::Duration;

use async_trait::async_trait;
use futures::stream::BoxStream;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use tracing::debug;

use crate::error::{LlmError, Result};
use crate::model_config::{
    ModelCapabilities, ModelCard, ModelType, ProviderConfig, ProviderType as ConfigProviderType,
};
use crate::providers::openai_compatible::OpenAICompatibleProvider;
use crate::traits::{
    ChatMessage, CompletionOptions, EmbeddingProvider, LLMProvider, LLMResponse, StreamChunk,
    ToolChoice, ToolDefinition,
};
const MISTRAL_BASE_URL: &str = "https://api.mistral.ai/v1";
const MISTRAL_DEFAULT_MODEL: &str = "mistral-small-latest";
const MISTRAL_DEFAULT_EMBEDDING_MODEL: &str = "mistral-embed";
const MISTRAL_DEFAULT_MAX_OUTPUT_TOKENS: usize = 4096;
const MISTRAL_FRONTIER_MAX_OUTPUT_TOKENS: usize = 16384;
const MISTRAL_EMBED_DIMENSION: usize = 1024;
const MISTRAL_EMBED_MAX_TOKENS: usize = 8192;
const MISTRAL_PROVIDER_NAME: &str = "mistral";
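/// Static catalog of Mistral chat models as
/// `(id, display name, context length, supports vision, supports function calling)`.
/// Values mirror Mistral's published model specs at the time of writing; the live
/// catalog is available via [`MistralProvider::list_models`].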
const MISTRAL_CHAT_MODELS: &[(&str, &str, usize, bool, bool)] = &[
    ("mistral-large-latest", "Mistral Large 3 (latest)", 262_144, true, true),
    ("mistral-large-2512", "Mistral Large 3 (2512)", 262_144, true, true),
    ("mistral-medium-latest", "Mistral Medium 3.1 (latest)", 131_072, true, true),
    ("mistral-medium-2508", "Mistral Medium 3.1 (2508)", 131_072, true, true),
    ("mistral-small-latest", "Mistral Small 4 (latest)", 262_144, true, true),
    ("mistral-small-2603", "Mistral Small 4 (2603)", 262_144, true, true),
    ("magistral-medium-latest", "Magistral Medium 1.2 (latest)", 131_072, false, true),
    ("magistral-small-latest", "Magistral Small 1.2 (latest)", 131_072, false, true),
    ("codestral-latest", "Codestral (latest)", 262_144, false, false),
    ("codestral-2508", "Codestral 2508", 262_144, false, false),
    ("devstral-latest", "Devstral 2 (latest)", 131_072, false, true),
    ("devstral-small-latest", "Devstral Small (latest)", 131_072, false, true),
    ("ministral-3b-latest", "Ministral 3 3B (latest)", 131_072, true, true),
    ("ministral-8b-latest", "Ministral 3 8B (latest)", 131_072, true, true),
    ("ministral-14b-latest", "Ministral 3 14B (latest)", 131_072, true, true),
    ("open-mistral-nemo", "Mistral Nemo 12B (open weights)", 131_072, false, true),
    ("mistral-small-2506", "Mistral Small 3.2 (2506, legacy)", 131_072, false, true),
    ("mistral-large-2411", "Mistral Large 2411 (deprecated)", 131_072, false, true),
    ("pixtral-large-2411", "Pixtral Large 2411 (deprecated)", 131_072, true, true),
    ("pixtral-12b-2409", "Pixtral 12B 2409 (deprecated)", 131_072, true, true),
    ("open-mistral-7b", "Mistral 7B (open weights, deprecated)", 32_768, false, false),
    ("open-mixtral-8x7b", "Mixtral 8x7B (open weights, deprecated)", 32_768, false, true),
    ("open-mixtral-8x22b", "Mixtral 8x22B (open weights, deprecated)", 65_536, false, true),
    ("codestral-2501", "Codestral 2501 (deprecated)", 262_144, false, false),
    ("mistral-small-2501", "Mistral Small 2501 (deprecated)", 32_768, false, true),
];
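/// Request body for `POST /v1/embeddings`; `encoding_format` is omitted from
/// the JSON when `None`.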
#[derive(Debug, Serialize)]
struct EmbeddingRequest<'a> {
model: &'a str,
input: &'a [String],
#[serde(skip_serializing_if = "Option::is_none")]
encoding_format: Option<&'a str>,
}
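/// Subset of the embeddings response that we consume; extra fields such as
/// `id`, `model`, and `usage` are ignored during deserialization.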
#[derive(Debug, Deserialize)]
struct EmbeddingResponse {
data: Vec<EmbeddingData>,
}
#[derive(Debug, Deserialize)]
struct EmbeddingData {
embedding: Vec<f32>,
index: usize,
}
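/// Response payload from `GET /v1/models`.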
#[derive(Debug, Deserialize)]
pub struct MistralModelsResponse {
pub data: Vec<MistralModelInfo>,
}
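/// Metadata for a single model as reported by the models endpoint. Everything
/// beyond `id` is optional because the API may omit fields per model.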
#[derive(Debug, Deserialize)]
pub struct MistralModelInfo {
pub id: String,
#[serde(default)]
pub created: Option<u64>,
#[serde(default)]
pub owned_by: Option<String>,
#[serde(default)]
pub description: Option<String>,
#[serde(default)]
pub max_context_length: Option<usize>,
#[serde(default)]
pub capabilities: Option<MistralModelCapabilities>,
}
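/// Capability flags reported per model by the models endpoint.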
#[derive(Debug, Deserialize)]
pub struct MistralModelCapabilities {
#[serde(default)]
pub completion_chat: bool,
#[serde(default)]
pub completion_fim: bool,
#[serde(default)]
pub function_calling: bool,
#[serde(default)]
pub fine_tuning: bool,
#[serde(default)]
pub vision: bool,
}
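/// LLM and embedding provider backed by the Mistral AI `/v1` API.
///
/// Chat, tool calling, and streaming are delegated to the wrapped
/// [`OpenAICompatibleProvider`]; embeddings and model listing call the
/// Mistral endpoints directly. A typical construction sketch (not run as a
/// doctest: it reads `MISTRAL_API_KEY` and performs network I/O):
///
/// ```ignore
/// let provider = MistralProvider::from_env()?.with_model("codestral-latest");
/// let response = provider.complete("fn main() {").await?;
/// ```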
#[derive(Debug)]
pub struct MistralProvider {
inner: OpenAICompatibleProvider,
model: String,
embedding_model: String,
base_url: String,
api_key: String,
client: Client,
}
impl MistralProvider {
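    /// Builds a provider from environment variables.
    ///
    /// `MISTRAL_API_KEY` is required and must be non-empty; `MISTRAL_MODEL`,
    /// `MISTRAL_EMBEDDING_MODEL`, and `MISTRAL_BASE_URL` are optional
    /// overrides for the built-in defaults.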
pub fn from_env() -> Result<Self> {
let api_key = std::env::var("MISTRAL_API_KEY").map_err(|_| {
LlmError::ConfigError(
"MISTRAL_API_KEY environment variable not set. \
Get your API key from https://console.mistral.ai"
.to_string(),
)
})?;
if api_key.is_empty() {
return Err(LlmError::ConfigError(
"MISTRAL_API_KEY is empty. Please set a valid API key.".to_string(),
));
}
let model =
std::env::var("MISTRAL_MODEL").unwrap_or_else(|_| MISTRAL_DEFAULT_MODEL.to_string());
let embedding_model = std::env::var("MISTRAL_EMBEDDING_MODEL")
.unwrap_or_else(|_| MISTRAL_DEFAULT_EMBEDDING_MODEL.to_string());
let base_url =
std::env::var("MISTRAL_BASE_URL").unwrap_or_else(|_| MISTRAL_BASE_URL.to_string());
Self::new(api_key, model, embedding_model, Some(base_url))
}
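    /// Builds a provider from a [`ProviderConfig`], resolving the API key from
    /// the environment variable named in `api_key_env` (which must be set).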
pub fn from_config(config: &ProviderConfig) -> Result<Self> {
let api_key = if let Some(env_var) = &config.api_key_env {
std::env::var(env_var).map_err(|_| {
LlmError::ConfigError(format!(
"API key environment variable '{}' not set for Mistral provider.",
env_var
))
})?
} else {
return Err(LlmError::ConfigError(
"Mistral provider requires api_key_env to be set.".to_string(),
));
};
let model = config
.default_llm_model
.clone()
.unwrap_or_else(|| MISTRAL_DEFAULT_MODEL.to_string());
let embedding_model = config
.default_embedding_model
.clone()
.unwrap_or_else(|| MISTRAL_DEFAULT_EMBEDDING_MODEL.to_string());
let base_url = config
.base_url
.clone()
.unwrap_or_else(|| MISTRAL_BASE_URL.to_string());
Self::new(api_key, model, embedding_model, Some(base_url))
}
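    /// Builds a provider from explicit values; no environment variables are
    /// read. `base_url` falls back to the public Mistral endpoint when `None`.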
pub fn new(
api_key: String,
model: String,
embedding_model: String,
base_url: Option<String>,
) -> Result<Self> {
let base_url = base_url.unwrap_or_else(|| MISTRAL_BASE_URL.to_string());
let config = Self::build_provider_config(&api_key, &model, &embedding_model, &base_url);
let inner = OpenAICompatibleProvider::from_config(config)?;
let client = Client::builder()
.timeout(Duration::from_secs(120))
.build()
.map_err(|e| LlmError::ConfigError(format!("Failed to build HTTP client: {}", e)))?;
debug!(
provider = MISTRAL_PROVIDER_NAME,
model = %model,
base_url = %base_url,
"Created Mistral provider"
);
Ok(Self {
inner,
model,
embedding_model,
base_url,
api_key,
client,
})
}
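    /// Replaces the chat model on both this wrapper and the inner provider.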
pub fn with_model(mut self, model: &str) -> Self {
self.model = model.to_string();
self.inner = self.inner.with_model(model);
self
}
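    /// Replaces the model used by [`EmbeddingProvider::embed`].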
pub fn with_embedding_model(mut self, model: &str) -> Self {
self.embedding_model = model.to_string();
self
}
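    /// Returns the context window for a known model id, falling back to a
    /// conservative 32_768 tokens for unknown ids.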
pub fn context_length(model: &str) -> usize {
MISTRAL_CHAT_MODELS
.iter()
.find(|(id, _, _, _, _)| *id == model)
.map(|(_, _, ctx, _, _)| *ctx)
            .unwrap_or(32_768)
}
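    /// Returns the static `(id, display name, context length)` catalog.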
pub fn available_models() -> Vec<(&'static str, &'static str, usize)> {
MISTRAL_CHAT_MODELS
.iter()
.map(|(id, name, ctx, _, _)| (*id, *name, *ctx))
.collect()
}
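    /// Fetches the live model catalog from `GET {base_url}/models`.
    ///
    /// Unlike [`Self::available_models`], this reflects whatever the API
    /// currently serves. Illustrative use (not run as a doctest: it needs a
    /// valid API key and network access):
    ///
    /// ```ignore
    /// for m in provider.list_models().await?.data {
    ///     println!("{}: {:?} tokens", m.id, m.max_context_length);
    /// }
    /// ```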
pub async fn list_models(&self) -> Result<MistralModelsResponse> {
let url = format!("{}/models", self.base_url.trim_end_matches('/'));
let response = self
.client
.get(&url)
.header("Authorization", format!("Bearer {}", self.api_key))
.header("Accept", "application/json")
.send()
.await
.map_err(|e| LlmError::NetworkError(format!("Failed to list Mistral models: {}", e)))?;
let status = response.status();
let body = response.text().await.map_err(|e| {
LlmError::NetworkError(format!("Failed to read model list response: {}", e))
})?;
if !status.is_success() {
return Err(LlmError::ApiError(format!(
"Mistral models list failed ({status}): {body}"
)));
}
serde_json::from_str(&body)
.map_err(|e| LlmError::ApiError(format!("Failed to parse models response: {e}")))
}
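    /// Assembles the [`ProviderConfig`] handed to the inner OpenAI-compatible
    /// provider: one [`ModelCard`] per entry in [`MISTRAL_CHAT_MODELS`]
    /// (frontier `large`/`medium` models get the higher output-token cap),
    /// plus a card for the embedding model.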
fn build_provider_config(
api_key: &str,
model: &str,
embedding_model: &str,
base_url: &str,
) -> ProviderConfig {
let models: Vec<ModelCard> = MISTRAL_CHAT_MODELS
.iter()
.map(|(id, display, ctx, vision, fc)| {
let max_output = if id.contains("large") || id.contains("medium") {
MISTRAL_FRONTIER_MAX_OUTPUT_TOKENS
} else {
MISTRAL_DEFAULT_MAX_OUTPUT_TOKENS
};
ModelCard {
name: id.to_string(),
display_name: display.to_string(),
model_type: ModelType::Llm,
capabilities: ModelCapabilities {
context_length: *ctx,
max_output_tokens: max_output,
supports_vision: *vision,
supports_function_calling: *fc,
supports_json_mode: true,
supports_streaming: true,
supports_system_message: true,
..Default::default()
},
..Default::default()
}
})
.chain(std::iter::once(ModelCard {
name: MISTRAL_DEFAULT_EMBEDDING_MODEL.to_string(),
display_name: "Mistral Embed".to_string(),
model_type: ModelType::Embedding,
capabilities: ModelCapabilities {
context_length: MISTRAL_EMBED_MAX_TOKENS,
embedding_dimension: MISTRAL_EMBED_DIMENSION,
..Default::default()
},
..Default::default()
}))
.collect();
ProviderConfig {
name: MISTRAL_PROVIDER_NAME.to_string(),
display_name: "Mistral AI".to_string(),
provider_type: ConfigProviderType::OpenAICompatible,
api_key: Some(api_key.to_string()),
api_key_env: Some("MISTRAL_API_KEY".to_string()),
base_url: Some(base_url.to_string()),
base_url_env: Some("MISTRAL_BASE_URL".to_string()),
default_llm_model: Some(model.to_string()),
default_embedding_model: Some(embedding_model.to_string()),
models,
enabled: true,
..Default::default()
}
}
}
#[async_trait]
impl LLMProvider for MistralProvider {
fn name(&self) -> &str {
MISTRAL_PROVIDER_NAME
}
fn model(&self) -> &str {
&self.model
}
fn max_context_length(&self) -> usize {
Self::context_length(&self.model)
}
async fn complete(&self, prompt: &str) -> Result<LLMResponse> {
self.inner.complete(prompt).await
}
async fn complete_with_options(
&self,
prompt: &str,
options: &CompletionOptions,
) -> Result<LLMResponse> {
self.inner.complete_with_options(prompt, options).await
}
async fn chat(
&self,
messages: &[ChatMessage],
options: Option<&CompletionOptions>,
) -> Result<LLMResponse> {
self.inner.chat(messages, options).await
}
async fn chat_with_tools(
&self,
messages: &[ChatMessage],
tools: &[ToolDefinition],
tool_choice: Option<ToolChoice>,
options: Option<&CompletionOptions>,
) -> Result<LLMResponse> {
self.inner
.chat_with_tools(messages, tools, tool_choice, options)
.await
}
async fn chat_with_tools_stream(
&self,
messages: &[ChatMessage],
tools: &[ToolDefinition],
tool_choice: Option<ToolChoice>,
options: Option<&CompletionOptions>,
) -> Result<BoxStream<'static, Result<StreamChunk>>> {
self.inner
.chat_with_tools_stream(messages, tools, tool_choice, options)
.await
}
async fn stream(&self, prompt: &str) -> Result<BoxStream<'static, Result<String>>> {
self.inner.stream(prompt).await
}
fn supports_streaming(&self) -> bool {
true
}
    fn supports_function_calling(&self) -> bool {
        MISTRAL_CHAT_MODELS
            .iter()
            .find(|(id, _, _, _, _)| *id == self.model.as_str())
            .map(|(_, _, _, _, fc)| *fc)
            .unwrap_or(true)
    }
fn supports_json_mode(&self) -> bool {
true
}
fn supports_tool_streaming(&self) -> bool {
self.inner.supports_tool_streaming()
}
}
#[async_trait]
impl EmbeddingProvider for MistralProvider {
fn name(&self) -> &str {
MISTRAL_PROVIDER_NAME
}
#[allow(clippy::misnamed_getters)]
fn model(&self) -> &str {
&self.embedding_model
}
fn dimension(&self) -> usize {
MISTRAL_EMBED_DIMENSION
}
fn max_tokens(&self) -> usize {
MISTRAL_EMBED_MAX_TOKENS
}
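    /// Embeds a batch of texts via `POST {base_url}/embeddings`.
    ///
    /// Returns one vector per input, in input order: response entries are
    /// re-sorted by `index`, and a count mismatch between inputs and outputs
    /// surfaces as an [`LlmError::ApiError`]. An empty input slice returns an
    /// empty result without a network call.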
async fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
if texts.is_empty() {
return Ok(Vec::new());
}
let url = format!("{}/embeddings", self.base_url.trim_end_matches('/'));
let request_body = EmbeddingRequest {
model: &self.embedding_model,
input: texts,
encoding_format: Some("float"),
};
        debug!(
            model = %self.embedding_model,
            count = texts.len(),
            "Mistral embed request"
        );
let response = self
.client
.post(&url)
.header("Authorization", format!("Bearer {}", self.api_key))
.header("Content-Type", "application/json")
.json(&request_body)
.send()
.await
.map_err(|e| {
LlmError::NetworkError(format!("Mistral embeddings request failed: {}", e))
})?;
let status = response.status();
let body = response.text().await.map_err(|e| {
LlmError::NetworkError(format!("Failed to read Mistral embeddings response: {}", e))
})?;
if !status.is_success() {
return Err(LlmError::ApiError(format!(
"Mistral embeddings API error ({status}): {body}"
)));
}
let embedding_response: EmbeddingResponse = serde_json::from_str(&body).map_err(|e| {
LlmError::ApiError(format!(
"Failed to parse Mistral embeddings response: {e} | body: {}",
&body[..body.len().min(500)]
))
})?;
let mut data = embedding_response.data;
data.sort_by_key(|d| d.index);
let embeddings: Vec<Vec<f32>> = data.into_iter().map(|d| d.embedding).collect();
if embeddings.len() != texts.len() {
return Err(LlmError::ApiError(format!(
"Mistral returned {} embeddings for {} inputs",
embeddings.len(),
texts.len()
)));
}
Ok(embeddings)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_available_models_not_empty() {
let models = MistralProvider::available_models();
assert!(!models.is_empty());
}
#[test]
fn test_available_models_contains_expected_ids() {
let models = MistralProvider::available_models();
let ids: Vec<&str> = models.iter().map(|(id, _, _)| *id).collect();
assert!(
ids.contains(&"mistral-large-latest"),
"Should contain mistral-large-latest"
);
assert!(
ids.contains(&"mistral-medium-latest"),
"Should contain mistral-medium-latest"
);
assert!(
ids.contains(&"mistral-small-latest"),
"Should contain mistral-small-latest"
);
assert!(
ids.contains(&"codestral-latest"),
"Should contain codestral-latest"
);
assert!(
ids.contains(&"devstral-latest"),
"Should contain devstral-latest"
);
assert!(
ids.contains(&"ministral-3b-latest"),
"Should contain ministral-3b-latest"
);
assert!(
ids.contains(&"ministral-8b-latest"),
"Should contain ministral-8b-latest"
);
assert!(
ids.contains(&"ministral-14b-latest"),
"Should contain ministral-14b-latest"
);
assert!(
ids.contains(&"magistral-medium-latest"),
"Should contain magistral-medium-latest"
);
assert!(
ids.contains(&"magistral-small-latest"),
"Should contain magistral-small-latest"
);
}
#[test]
fn test_context_length_known_models() {
assert_eq!(
MistralProvider::context_length("mistral-large-latest"),
262_144
);
assert_eq!(
MistralProvider::context_length("mistral-small-latest"),
262_144
);
assert_eq!(MistralProvider::context_length("codestral-latest"), 262_144);
assert_eq!(
MistralProvider::context_length("mistral-medium-latest"),
131_072
);
assert_eq!(
MistralProvider::context_length("open-mistral-nemo"),
131_072
);
assert_eq!(
MistralProvider::context_length("ministral-3b-latest"),
131_072
);
assert_eq!(
MistralProvider::context_length("ministral-8b-latest"),
131_072
);
assert_eq!(
MistralProvider::context_length("ministral-14b-latest"),
131_072
);
}
#[test]
fn test_context_length_unknown_model() {
assert_eq!(MistralProvider::context_length("unknown-model"), 32768);
}
#[test]
fn test_all_models_have_positive_context() {
for (id, _, ctx) in MistralProvider::available_models() {
assert!(ctx > 0, "Model {} must have positive context length", id);
}
}
#[test]
fn test_provider_name_constant() {
assert_eq!(MISTRAL_PROVIDER_NAME, "mistral");
}
#[test]
fn test_default_model_constant() {
assert_eq!(MISTRAL_DEFAULT_MODEL, "mistral-small-latest");
}
#[test]
fn test_default_embedding_model_constant() {
assert_eq!(MISTRAL_DEFAULT_EMBEDDING_MODEL, "mistral-embed");
}
#[test]
fn test_embed_dimension_constant() {
assert_eq!(MISTRAL_EMBED_DIMENSION, 1024);
}
#[test]
fn test_base_url_constant() {
assert_eq!(MISTRAL_BASE_URL, "https://api.mistral.ai/v1");
}
#[test]
fn test_build_provider_config_name() {
let cfg = MistralProvider::build_provider_config(
"key",
"mistral-small-latest",
"mistral-embed",
MISTRAL_BASE_URL,
);
assert_eq!(cfg.name, "mistral");
assert_eq!(cfg.display_name, "Mistral AI");
}
#[test]
fn test_build_provider_config_base_url() {
let cfg = MistralProvider::build_provider_config(
"key",
"mistral-small-latest",
"mistral-embed",
"https://custom.api/v1",
);
assert_eq!(cfg.base_url, Some("https://custom.api/v1".to_string()));
}
#[test]
fn test_build_provider_config_models_include_embedding() {
let cfg = MistralProvider::build_provider_config(
"key",
"mistral-small-latest",
"mistral-embed",
MISTRAL_BASE_URL,
);
let embedding_cards: Vec<_> = cfg
.models
.iter()
.filter(|m| m.model_type == ModelType::Embedding)
.collect();
assert_eq!(embedding_cards.len(), 1);
assert_eq!(embedding_cards[0].name, "mistral-embed");
assert_eq!(
embedding_cards[0].capabilities.embedding_dimension,
MISTRAL_EMBED_DIMENSION
);
}
#[test]
fn test_build_provider_config_default_models() {
let cfg = MistralProvider::build_provider_config(
"key",
"mistral-small-latest",
"mistral-embed",
MISTRAL_BASE_URL,
);
assert_eq!(
cfg.default_llm_model,
Some("mistral-small-latest".to_string())
);
assert_eq!(
cfg.default_embedding_model,
Some("mistral-embed".to_string())
);
}
#[test]
fn test_build_provider_config_api_key_env() {
let cfg = MistralProvider::build_provider_config(
"key",
"mistral-small-latest",
"mistral-embed",
MISTRAL_BASE_URL,
);
assert_eq!(cfg.api_key_env, Some("MISTRAL_API_KEY".to_string()));
}
#[test]
fn test_from_env_missing_api_key() {
std::env::remove_var("MISTRAL_API_KEY");
let result = MistralProvider::from_env();
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("MISTRAL_API_KEY"));
}
#[test]
fn test_embedding_dimension() {
let p = MistralProvider::new(
"test-key".to_string(),
MISTRAL_DEFAULT_MODEL.to_string(),
MISTRAL_DEFAULT_EMBEDDING_MODEL.to_string(),
None,
)
.unwrap();
assert_eq!(EmbeddingProvider::dimension(&p), MISTRAL_EMBED_DIMENSION);
assert_eq!(EmbeddingProvider::max_tokens(&p), MISTRAL_EMBED_MAX_TOKENS);
assert_eq!(EmbeddingProvider::name(&p), "mistral");
assert_eq!(EmbeddingProvider::model(&p), "mistral-embed");
}
#[test]
fn test_llm_provider_surface() {
let p = MistralProvider::new(
"test-key".to_string(),
"mistral-large-latest".to_string(),
MISTRAL_DEFAULT_EMBEDDING_MODEL.to_string(),
None,
)
.unwrap();
assert_eq!(LLMProvider::name(&p), "mistral");
assert_eq!(LLMProvider::model(&p), "mistral-large-latest");
        assert_eq!(p.max_context_length(), 262_144);
        assert!(p.supports_streaming());
assert!(p.supports_function_calling());
assert!(p.supports_json_mode());
}
#[test]
fn test_with_model_builder() {
let p = MistralProvider::new(
"test-key".to_string(),
MISTRAL_DEFAULT_MODEL.to_string(),
MISTRAL_DEFAULT_EMBEDDING_MODEL.to_string(),
None,
)
.unwrap()
.with_model("codestral-latest");
assert_eq!(p.model, "codestral-latest");
        assert_eq!(p.max_context_length(), 262_144);
}
#[test]
fn test_with_embedding_model_builder() {
let p = MistralProvider::new(
"test-key".to_string(),
MISTRAL_DEFAULT_MODEL.to_string(),
MISTRAL_DEFAULT_EMBEDDING_MODEL.to_string(),
None,
)
.unwrap()
.with_embedding_model("custom-embed");
assert_eq!(p.embedding_model, "custom-embed");
}
#[test]
fn test_embedding_request_serialization() {
let texts = vec!["hello world".to_string(), "foo bar".to_string()];
let req = EmbeddingRequest {
model: "mistral-embed",
input: &texts,
encoding_format: Some("float"),
};
let json = serde_json::to_value(&req).unwrap();
assert_eq!(json["model"], "mistral-embed");
assert_eq!(json["input"][0], "hello world");
assert_eq!(json["input"][1], "foo bar");
assert_eq!(json["encoding_format"], "float");
}
#[test]
fn test_embedding_request_serialization_no_encoding_format() {
let texts = vec!["hello".to_string()];
let req = EmbeddingRequest {
model: "mistral-embed",
input: &texts,
encoding_format: None,
};
let json = serde_json::to_value(&req).unwrap();
assert!(
json.get("encoding_format").is_none(),
"encoding_format should be absent when None"
);
}
#[test]
fn test_embedding_response_deserialization() {
let raw = r#"{
"id": "embd-1",
"object": "list",
"data": [
{"object": "embedding", "embedding": [0.1, 0.2, 0.3], "index": 0},
{"object": "embedding", "embedding": [0.4, 0.5, 0.6], "index": 1}
],
"model": "mistral-embed",
"usage": {"prompt_tokens": 10, "total_tokens": 10}
}"#;
let resp: EmbeddingResponse = serde_json::from_str(raw).unwrap();
assert_eq!(resp.data.len(), 2);
assert_eq!(resp.data[0].embedding, vec![0.1, 0.2, 0.3]);
assert_eq!(resp.data[1].index, 1);
}
#[test]
fn test_model_info_deserialization() {
let raw = r#"{
"data": [
{
"id": "mistral-small-latest",
"created": 1735689600,
"owned_by": "mistralai",
"description": "Mistral Small",
"max_context_length": 32768,
"capabilities": {
"completion_chat": true,
"completion_fim": false,
"function_calling": true,
"fine_tuning": false,
"vision": false
}
}
]
}"#;
let resp: MistralModelsResponse = serde_json::from_str(raw).unwrap();
assert_eq!(resp.data.len(), 1);
let m = &resp.data[0];
assert_eq!(m.id, "mistral-small-latest");
assert_eq!(m.max_context_length, Some(32768));
let caps = m.capabilities.as_ref().unwrap();
assert!(caps.function_calling);
assert!(!caps.vision);
}
#[test]
fn test_vision_model_capabilities() {
let cfg = MistralProvider::build_provider_config(
"key",
"mistral-large-latest",
"mistral-embed",
MISTRAL_BASE_URL,
);
let large = cfg
.models
.iter()
.find(|m| m.name == "mistral-large-latest")
.unwrap();
assert!(large.capabilities.supports_vision);
assert!(large.capabilities.supports_function_calling);
}
#[test]
fn test_ministral_models_in_catalog() {
let cfg = MistralProvider::build_provider_config(
"key",
"mistral-small-latest",
"mistral-embed",
MISTRAL_BASE_URL,
);
for id in &[
"ministral-3b-latest",
"ministral-8b-latest",
"ministral-14b-latest",
] {
let card = cfg.models.iter().find(|m| m.name == *id);
assert!(card.is_some(), "Missing model: {id}");
let card = card.unwrap();
assert_eq!(
card.capabilities.context_length, 131_072,
"Wrong context for {id}"
);
assert!(
card.capabilities.supports_vision,
"{id} should support vision"
);
assert!(
card.capabilities.supports_function_calling,
"{id} should support FC"
);
}
}
#[test]
fn test_reasoning_models_in_catalog() {
let cfg = MistralProvider::build_provider_config(
"key",
"magistral-medium-latest",
"mistral-embed",
MISTRAL_BASE_URL,
);
for id in &["magistral-medium-latest", "magistral-small-latest"] {
let card = cfg.models.iter().find(|m| m.name == *id);
assert!(card.is_some(), "Missing reasoning model: {id}");
let card = card.unwrap();
assert_eq!(
card.capabilities.context_length, 131_072,
"Wrong context for {id}"
);
assert!(
card.capabilities.supports_function_calling,
"{id} should support FC"
);
}
}
#[test]
fn test_frontier_models_have_256k_context() {
for id in &[
"mistral-large-latest",
"mistral-small-latest",
"codestral-latest",
] {
assert_eq!(
MistralProvider::context_length(id),
262_144,
"Expected 256K context for {id}"
);
}
}
#[test]
fn test_provider_config_sets_api_key_directly() {
let cfg = MistralProvider::build_provider_config(
"secret-api-key",
"mistral-small-latest",
"mistral-embed",
MISTRAL_BASE_URL,
);
assert_eq!(cfg.api_key.as_deref(), Some("secret-api-key"));
}
#[test]
fn test_new_does_not_require_env_var() {
std::env::remove_var("MISTRAL_API_KEY");
let result = MistralProvider::new(
"explicit-key".to_string(),
MISTRAL_DEFAULT_MODEL.to_string(),
MISTRAL_DEFAULT_EMBEDDING_MODEL.to_string(),
None,
);
assert!(
result.is_ok(),
"MistralProvider::new() should succeed without env var when key is provided directly"
);
}
#[test]
fn test_frontier_max_output_tokens() {
let cfg = MistralProvider::build_provider_config(
"key",
"mistral-large-latest",
"mistral-embed",
MISTRAL_BASE_URL,
);
let large = cfg
.models
.iter()
.find(|m| m.name == "mistral-large-latest")
.unwrap();
assert_eq!(
large.capabilities.max_output_tokens, MISTRAL_FRONTIER_MAX_OUTPUT_TOKENS,
"Frontier large model should have 16 384 output tokens"
);
let medium = cfg
.models
.iter()
.find(|m| m.name == "mistral-medium-latest")
.unwrap();
assert_eq!(
medium.capabilities.max_output_tokens, MISTRAL_FRONTIER_MAX_OUTPUT_TOKENS,
"Frontier medium model should have 16 384 output tokens"
);
}
}