use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::time::Instant;
use crate::provider::{LlmError, LlmProvider, LlmRequest, LlmResponse};
/// JSON body for `POST {base_url}/v1/chat/completions`, built by `complete`.
///
/// Field names are serialized verbatim as JSON keys, so they must match the
/// keys the chat-completions endpoint expects.
#[derive(Debug, Serialize)]
struct MistralRequest {
    /// Model identifier, copied from the provider's configured model.
    model: String,
    /// Conversation turns, in the order they are sent.
    messages: Vec<Message>,
    /// Sampling temperature, passed through from `LlmRequest::temperature`.
    temperature: f32,
    /// Generation length cap, passed through from `LlmRequest::max_tokens`.
    max_tokens: u32,
}
/// One chat turn inside a `MistralRequest`.
#[derive(Debug, Serialize)]
struct Message {
    /// Speaker role string; `complete` sends `"system"` followed by `"user"`.
    role: String,
    /// Text of the turn.
    content: String,
}
/// Subset of the chat-completions response body that this provider reads.
///
/// Any other keys in the JSON are ignored during deserialization (serde's
/// default behavior, since `deny_unknown_fields` is not set).
#[derive(Debug, Deserialize)]
struct MistralResponse {
    /// Candidate completions; `complete` only looks at the first one.
    choices: Vec<Choice>,
    /// Model name reported by the API, returned as `LlmResponse::model`.
    model: String,
    /// Token accounting; `Option` so a response without `usage` still parses.
    usage: Option<Usage>,
}
/// One entry of `MistralResponse::choices`.
#[derive(Debug, Deserialize)]
struct Choice {
    /// The generated message for this choice.
    message: MessageContent,
}
/// Message payload inside a `Choice`; only its text is kept.
#[derive(Debug, Deserialize)]
struct MessageContent {
    /// Generated text. Declared as a non-optional `String`, so a response whose
    /// `content` is `null` or absent fails to deserialize; `complete` maps that
    /// failure to `LlmError::InvalidResponse`.
    // NOTE(review): if the API can ever return non-string content here, every
    // such response will be rejected as invalid — confirm that is intended.
    content: String,
}
/// Token usage block of the response.
#[derive(Debug, Deserialize)]
struct Usage {
    /// Total token count; surfaced as `LlmResponse::tokens_used`.
    total_tokens: u32,
}
/// [`LlmProvider`] backed by Mistral's HTTP chat-completions API.
///
/// Construct with [`MistralProvider::new`] or one of the model shorthands
/// (e.g. [`MistralProvider::small`]); target a different API root with
/// [`MistralProvider::with_base_url`].
pub struct MistralProvider {
    /// Bearer token sent with every request. Redacted from `Debug` output.
    api_key: String,
    /// Model identifier sent in each chat-completions request.
    model: String,
    /// HTTP client reused for every request made by this provider.
    client: reqwest::Client,
    /// API root that endpoint paths such as `/v1/models` are appended to.
    base_url: String,
}

// Written by hand instead of derived so that `{:?}` (logs, panic messages,
// `dbg!`) never prints the API key.
impl std::fmt::Debug for MistralProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("MistralProvider")
            .field("api_key", &"<redacted>")
            .field("model", &self.model)
            .field("client", &self.client)
            .field("base_url", &self.base_url)
            .finish()
    }
}
impl MistralProvider {
    /// API root used by [`MistralProvider::new`] and every shorthand constructor.
    const DEFAULT_BASE_URL: &'static str = "https://api.mistral.ai";

    /// Creates a provider that calls `model`, authenticating with `api_key`,
    /// against the public Mistral API root.
    pub fn new(api_key: &str, model: &str) -> Self {
        Self {
            api_key: api_key.to_string(),
            model: model.to_string(),
            client: reqwest::Client::new(),
            base_url: Self::DEFAULT_BASE_URL.to_string(),
        }
    }

    /// Provider for `mistral-large-latest`.
    pub fn large(api_key: &str) -> Self {
        Self::new(api_key, "mistral-large-latest")
    }

    /// Provider for `mistral-medium-latest`.
    pub fn medium(api_key: &str) -> Self {
        Self::new(api_key, "mistral-medium-latest")
    }

    /// Provider for `mistral-small-latest`.
    pub fn small(api_key: &str) -> Self {
        Self::new(api_key, "mistral-small-latest")
    }

    /// Provider for `codestral-latest`.
    pub fn codestral(api_key: &str) -> Self {
        Self::new(api_key, "codestral-latest")
    }

    /// Provider for `devstral-small-latest`.
    pub fn devstral(api_key: &str) -> Self {
        Self::new(api_key, "devstral-small-latest")
    }

    /// Provider for `ministral-8b-latest`.
    pub fn ministral_8b(api_key: &str) -> Self {
        Self::new(api_key, "ministral-8b-latest")
    }

    /// Provider for `ministral-3b-latest`.
    pub fn ministral_3b(api_key: &str) -> Self {
        Self::new(api_key, "ministral-3b-latest")
    }

    /// Provider for `pixtral-large-latest`.
    pub fn pixtral(api_key: &str) -> Self {
        Self::new(api_key, "pixtral-large-latest")
    }

    /// Provider for `open-mistral-nemo`.
    pub fn nemo(api_key: &str) -> Self {
        Self::new(api_key, "open-mistral-nemo")
    }

    /// Replaces the API root that endpoint paths are appended to.
    ///
    /// Trailing `/` characters are stripped so that appending paths such as
    /// `/v1/chat/completions` never yields a `//` in the request URL.
    pub fn with_base_url(mut self, base_url: &str) -> Self {
        self.base_url = base_url.trim_end_matches('/').to_string();
        self
    }
}
#[async_trait]
impl LlmProvider for MistralProvider {
    fn name(&self) -> &str {
        "mistral"
    }

    /// Probes `GET {base_url}/v1/models` with the configured key.
    ///
    /// Returns `true` only when the endpoint answers with a 2xx status. A
    /// rejected key (401/403), a server error, a transport error, or a probe
    /// that exceeds the timeout all count as unavailable.
    async fn is_available(&self) -> bool {
        // Upper bound for the probe; the shared client has no default timeout,
        // so without this a hanging endpoint would block the caller forever.
        const PROBE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);

        self.client
            .get(format!("{}/v1/models", self.base_url))
            .bearer_auth(&self.api_key)
            .timeout(PROBE_TIMEOUT)
            .send()
            .await
            .map(|response| response.status().is_success())
            .unwrap_or(false)
    }

    /// Sends `request` as a system + user chat turn and returns the first choice.
    ///
    /// # Errors
    /// - `ConnectionFailed` if the HTTP request cannot be sent.
    /// - `RateLimited` on HTTP 429.
    /// - `RequestFailed` (status and body) on any other non-2xx status.
    /// - `InvalidResponse` if the body is not the expected JSON or contains no choices.
    async fn complete(&self, request: LlmRequest) -> Result<LlmResponse, LlmError> {
        let start = Instant::now();
        let url = format!("{}/v1/chat/completions", self.base_url);

        let mistral_request = MistralRequest {
            model: self.model.clone(),
            messages: vec![
                Message {
                    role: "system".to_string(),
                    content: request.system,
                },
                Message {
                    role: "user".to_string(),
                    content: request.prompt,
                },
            ],
            temperature: request.temperature,
            max_tokens: request.max_tokens,
        };

        let response = self
            .client
            .post(&url)
            .bearer_auth(&self.api_key)
            .json(&mistral_request)
            .send()
            .await
            .map_err(|e| LlmError::ConnectionFailed(e.to_string()))?;

        let status = response.status();
        if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
            return Err(LlmError::RateLimited);
        }
        if !status.is_success() {
            // Best effort: an unreadable error body must not hide the status code.
            let body = response.text().await.unwrap_or_default();
            return Err(LlmError::RequestFailed(format!(
                "Status: {}, Body: {}",
                status, body
            )));
        }

        let api_response: MistralResponse = response
            .json()
            .await
            .map_err(|e| LlmError::InvalidResponse(e.to_string()))?;

        // An empty `choices` array is a malformed reply; report it instead of
        // silently returning an empty completion.
        let choice = api_response
            .choices
            .into_iter()
            .next()
            .ok_or_else(|| {
                LlmError::InvalidResponse("response contained no choices".to_string())
            })?;

        Ok(LlmResponse {
            content: choice.message.content,
            model: api_response.model,
            tokens_used: api_response.usage.map(|u| u.total_tokens),
            // u128 -> u64: only truncates after ~584 million years of latency.
            latency_ms: start.elapsed().as_millis() as u64,
            trace_root: None,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Key for the live-API smoke tests. These tests are `#[ignore]`d, so they
    /// (and this lookup) only run when ignored tests are explicitly requested.
    fn api_key() -> String {
        std::env::var("MISTRAL_API_KEY").expect("MISTRAL_API_KEY not set")
    }

    #[tokio::test]
    #[ignore]
    async fn test_mistral() {
        let provider = MistralProvider::small(&api_key());
        if !provider.is_available().await {
            return;
        }
        let reply = provider.ask("Say hello in one word").await.unwrap();
        assert!(!reply.is_empty());
        println!("Mistral response: {}", reply);
    }

    #[tokio::test]
    #[ignore]
    async fn test_mistral_large() {
        let provider = MistralProvider::large(&api_key());
        if !provider.is_available().await {
            return;
        }
        let reply = provider.ask("What is 2+2?").await.unwrap();
        assert!(!reply.is_empty());
        println!("Mistral Large response: {}", reply);
    }

    #[tokio::test]
    #[ignore]
    async fn test_codestral() {
        let provider = MistralProvider::codestral(&api_key());
        if !provider.is_available().await {
            return;
        }
        let prompt = "Write a simple hello world function in Rust";
        let reply = provider.ask(prompt).await.unwrap();
        assert!(!reply.is_empty());
        println!("Codestral response: {}", reply);
    }
}