use crate::llm::retry;
use crate::llm::traits::AiProvider;
use crate::llm::types::{
ChatCompletionParams, Message, ProviderExchange, ProviderResponse, ThinkingBlock, TokenUsage,
ToolCall,
};
use crate::llm::utils::{normalize_model_name, starts_with_ignore_ascii_case};
use anyhow::Result;
use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::env;
/// Pricing table: (model-name substring, input $/1M tokens, output $/1M tokens).
/// Matching is substring-based, so more specific names must come first
/// (e.g. "MiniMax-M2.1-lightning" before "MiniMax-M2.1" before "MiniMax-M2").
const PRICING: &[(&str, f64, f64)] = &[
    ("MiniMax-M2.1-lightning", 0.30, 2.40),
    ("MiniMax-M2.1", 0.30, 1.20),
    ("MiniMax-M2", 0.30, 1.20),
];

/// Cache-write price in USD per 1M tokens (flat across all models here).
const CACHE_CREATION_PRICE_PER_MTOK: f64 = 0.375;
/// Cache-read price in USD per 1M tokens (flat across all models here).
const CACHE_READ_PRICE_PER_MTOK: f64 = 0.03;

/// Token counts broken down by cache category, used for cost accounting.
struct CacheTokenUsage {
    /// Input tokens billed at the regular per-model input rate (not cached).
    regular_input_tokens: u64,
    /// Tokens written to the prompt cache (billed at the cache-write rate).
    cache_creation_tokens: u64,
    /// Tokens served from the prompt cache (billed at the cache-read rate).
    cache_read_tokens: u64,
    /// Tokens generated by the model (billed at the per-model output rate).
    output_tokens: u64,
}

/// Computes the USD cost for `usage` against the first PRICING entry whose
/// name is a substring of `model`. Returns `None` for unknown models.
fn calculate_cost_with_cache(model: &str, usage: CacheTokenUsage) -> Option<f64> {
    for (pricing_model, input_price, output_price) in PRICING {
        if model.contains(pricing_model) {
            let regular_input_cost =
                (usage.regular_input_tokens as f64 / 1_000_000.0) * input_price;
            let cache_creation_cost =
                (usage.cache_creation_tokens as f64 / 1_000_000.0) * CACHE_CREATION_PRICE_PER_MTOK;
            let cache_read_cost =
                (usage.cache_read_tokens as f64 / 1_000_000.0) * CACHE_READ_PRICE_PER_MTOK;
            let output_cost = (usage.output_tokens as f64 / 1_000_000.0) * output_price;
            return Some(regular_input_cost + cache_creation_cost + cache_read_cost + output_cost);
        }
    }
    None
}

/// Computes the USD cost for a completion, splitting `input_tokens` into
/// regular vs. cached portions. `input_tokens` is the API's total input
/// count, which includes the cache-creation and cache-read tokens.
///
/// Returns `None` when `model` matches no PRICING entry.
fn calculate_minimax_cost(
    model: &str,
    input_tokens: u32,
    output_tokens: u32,
    cache_creation_tokens: u32,
    cache_read_tokens: u32,
) -> Option<f64> {
    // saturating_add: the plain `+` here could overflow u32 (a panic in debug
    // builds); saturate both the sum and the subtraction instead.
    let regular_input_tokens =
        input_tokens.saturating_sub(cache_creation_tokens.saturating_add(cache_read_tokens));
    let usage = CacheTokenUsage {
        regular_input_tokens: regular_input_tokens as u64,
        cache_creation_tokens: cache_creation_tokens as u64,
        cache_read_tokens: cache_read_tokens as u64,
        output_tokens: output_tokens as u64,
    };
    calculate_cost_with_cache(model, usage)
}
/// Provider for MiniMax models, spoken over their Anthropic-compatible API.
#[derive(Debug, Clone, Default)]
pub struct MinimaxProvider;

impl MinimaxProvider {
    /// Creates a new provider instance; equivalent to `Default::default()`.
    pub fn new() -> Self {
        MinimaxProvider
    }
}
/// Environment variable holding the MiniMax API key.
const MINIMAX_API_KEY_ENV: &str = "MINIMAX_API_KEY";
/// Optional environment variable overriding the API endpoint URL.
const MINIMAX_API_URL_ENV: &str = "MINIMAX_API_URL";
/// Default endpoint: MiniMax's Anthropic-compatible messages API.
const MINIMAX_API_URL: &str = "https://api.minimax.io/anthropic/v1/messages";
#[async_trait]
impl AiProvider for MinimaxProvider {
    /// Provider identifier used in exchange logs and provider lookups.
    fn name(&self) -> &str {
        "minimax"
    }

    /// Accepts any model whose name starts with "minimax-m2", case-insensitively.
    fn supports_model(&self, model: &str) -> bool {
        starts_with_ignore_ascii_case(model, "minimax-m2")
    }

    /// Reads the API key from the `MINIMAX_API_KEY` environment variable.
    ///
    /// # Errors
    /// Returns an error when the variable is not set.
    fn get_api_key(&self) -> Result<String> {
        env::var(MINIMAX_API_KEY_ENV)
            .map_err(|_| anyhow::anyhow!("MINIMAX_API_KEY not found in environment"))
    }

    /// Prompt caching is supported for all models handled by this provider.
    fn supports_caching(&self, _model: &str) -> bool {
        true
    }

    /// Vision input is not supported by this provider.
    fn supports_vision(&self, _model: &str) -> bool {
        false
    }

    /// JSON / JSON-schema response formats are supported.
    fn supports_structured_output(&self, _model: &str) -> bool {
        true
    }

    /// Maximum input context: 1M tokens for the M2 family, 128k otherwise.
    fn get_max_input_tokens(&self, model: &str) -> usize {
        let model_lower = normalize_model_name(model);
        // "minimax-m2" is a substring of every M2 variant (including
        // "minimax-m2.1"), so a single check covers the whole family.
        if model_lower.contains("minimax-m2") {
            1_000_000
        } else {
            128_000
        }
    }

    /// Sends a chat completion request to the MiniMax messages API.
    ///
    /// System messages are stripped from the message list (see
    /// `convert_messages`) and sent via the dedicated `system` field —
    /// as an Anthropic-style block array with `cache_control` when the
    /// system message is flagged as cached, or as a plain string otherwise.
    /// Tools are sorted by name so the request body is deterministic.
    ///
    /// # Errors
    /// Fails when the API key is missing, when `temperature` is outside
    /// MiniMax's accepted range (0.0, 1.0], or when the HTTP request or
    /// response parsing fails.
    async fn chat_completion(&self, params: ChatCompletionParams) -> Result<ProviderResponse> {
        let api_key = self.get_api_key()?;
        let minimax_messages = convert_messages(&params.messages);
        // First system message wins; fall back to a generic instruction.
        let system_message = params
            .messages
            .iter()
            .find(|m| m.role == "system")
            .map(|m| m.content.clone())
            .unwrap_or_else(|| "You are a helpful assistant.".to_string());
        let system_cached = params
            .messages
            .iter()
            .any(|m| m.role == "system" && m.cached);
        // MiniMax rejects temperature of 0 or above 1; fail fast locally.
        if params.temperature <= 0.0 || params.temperature > 1.0 {
            return Err(anyhow::anyhow!(
                "MiniMax requires temperature in range (0.0, 1.0], got {}",
                params.temperature
            ));
        }
        let mut request_body = serde_json::json!({
            "model": params.model,
            "messages": minimax_messages,
            "temperature": params.temperature,
            "top_p": params.top_p,
        });
        if params.max_tokens > 0 {
            request_body["max_tokens"] = serde_json::json!(params.max_tokens);
        }
        if system_cached {
            // Block-array form is required to attach cache_control.
            request_body["system"] = serde_json::json!([{
                "type": "text",
                "text": system_message,
                "cache_control": {
                    "type": "ephemeral"
                }
            }]);
        } else {
            request_body["system"] = serde_json::json!(system_message);
        }
        if let Some(response_format) = &params.response_format {
            match &response_format.format {
                crate::llm::types::OutputFormat::Json => {
                    request_body["response_format"] = serde_json::json!({
                        "type": "json_object"
                    });
                }
                crate::llm::types::OutputFormat::JsonSchema => {
                    // A schema-less JsonSchema request is silently ignored.
                    if let Some(schema) = &response_format.schema {
                        let mut format_obj = serde_json::json!({
                            "type": "json_schema",
                            "json_schema": {
                                "schema": schema
                            }
                        });
                        if matches!(
                            response_format.mode,
                            crate::llm::types::ResponseMode::Strict
                        ) {
                            format_obj["json_schema"]["strict"] = serde_json::json!(true);
                        }
                        request_body["response_format"] = format_obj;
                    }
                }
            }
        }
        if let Some(tools) = &params.tools {
            if !tools.is_empty() {
                // Sort by name for a stable body (helps prompt-cache hits).
                let mut sorted_tools = tools.clone();
                sorted_tools.sort_by(|a, b| a.name.cmp(&b.name));
                let minimax_tools = sorted_tools
                    .iter()
                    .map(|f| {
                        let mut tool = serde_json::json!({
                            "name": f.name,
                            "description": f.description,
                            "input_schema": f.parameters
                        });
                        if let Some(ref cache_control) = f.cache_control {
                            tool["cache_control"] = cache_control.clone();
                        }
                        tool
                    })
                    .collect::<Vec<_>>();
                request_body["tools"] = serde_json::json!(minimax_tools);
            }
        }
        // Allow endpoint override via MINIMAX_API_URL (e.g. for proxies).
        let api_url = env::var(MINIMAX_API_URL_ENV).unwrap_or_else(|_| MINIMAX_API_URL.to_string());
        execute_minimax_request(
            api_key,
            api_url,
            request_body,
            params.max_retries,
            params.retry_timeout,
            params.cancellation_token.as_ref(),
        )
        .await
    }
}
/// A chat message in MiniMax's (Anthropic-style) request wire format.
#[derive(Serialize, Deserialize, Debug)]
struct MinimaxMessage {
    // Message role; tool results are sent under the "user" role.
    role: String,
    // Ordered content blocks making up the message.
    content: Vec<MinimaxContent>,
}
/// Request-side content block, tagged by "type" in the JSON wire format.
#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "type")]
enum MinimaxContent {
    /// Plain text, optionally marked for prompt caching.
    #[serde(rename = "text")]
    Text {
        text: String,
        // Ephemeral cache marker, e.g. {"type": "ephemeral"}; omitted when None.
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<serde_json::Value>,
    },
    /// Result of a prior tool invocation, linked by `tool_use_id`.
    #[serde(rename = "tool_result")]
    ToolResult {
        tool_use_id: String,
        content: String,
        // Ephemeral cache marker; omitted from JSON when None.
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<serde_json::Value>,
    },
    /// A tool invocation previously made by the assistant.
    #[serde(rename = "tool_use")]
    ToolUse {
        id: String,
        name: String,
        input: serde_json::Value,
    },
}
/// Top-level response body from the MiniMax messages API.
#[derive(Deserialize, Debug)]
struct MinimaxResponse {
    // Ordered content blocks (text, thinking, tool_use).
    content: Vec<MinimaxResponseContent>,
    // Token accounting reported by the API.
    usage: MinimaxUsage,
    // Stop reason string from the API, when present; defaults to None.
    #[serde(default)]
    stop_reason: Option<String>,
}
/// Response-side content block, tagged by "type" in the JSON wire format.
#[derive(Deserialize, Debug)]
#[serde(tag = "type")]
enum MinimaxResponseContent {
    /// Visible assistant text.
    #[serde(rename = "text")]
    Text { text: String },
    /// Model reasoning content, surfaced separately from visible text.
    #[serde(rename = "thinking")]
    Thinking { thinking: String },
    /// A tool invocation requested by the model.
    #[serde(rename = "tool_use")]
    ToolUse {
        id: String,
        name: String,
        input: serde_json::Value,
    },
}
/// Token usage as reported by the API.
#[derive(Deserialize, Debug)]
struct MinimaxUsage {
    input_tokens: u64,
    output_tokens: u64,
    // Tokens written to the prompt cache; None when the field is omitted.
    #[serde(default)]
    cache_creation_input_tokens: Option<u64>,
    // Tokens served from the prompt cache; None when the field is omitted.
    #[serde(default)]
    cache_read_input_tokens: Option<u64>,
}
/// Converts generic chat `Message`s into MiniMax's Anthropic-style format.
///
/// System messages are skipped (they travel via the request's top-level
/// `system` field), tool results become `tool_result` blocks under the
/// "user" role, and assistant messages carrying tool calls are expanded
/// into an optional text block followed by `tool_use` blocks.
fn convert_messages(messages: &[Message]) -> Vec<MinimaxMessage> {
    // Ephemeral cache marker for messages flagged as cached, else None.
    fn cache_marker(cached: bool) -> Option<serde_json::Value> {
        if cached {
            Some(serde_json::json!({"type": "ephemeral"}))
        } else {
            None
        }
    }

    let mut converted = Vec::new();
    for msg in messages {
        match msg.role.as_str() {
            // Handled separately via the top-level `system` field.
            "system" => {}
            "tool" => {
                // Tool results ride under the "user" role in this format.
                let tool_use_id = msg.tool_call_id.as_deref().unwrap_or("").to_string();
                converted.push(MinimaxMessage {
                    role: "user".to_string(),
                    content: vec![MinimaxContent::ToolResult {
                        tool_use_id,
                        content: msg.content.clone(),
                        cache_control: cache_marker(msg.cached),
                    }],
                });
            }
            "assistant" if msg.tool_calls.is_some() => {
                let mut blocks = Vec::new();
                // Only emit a text block when there is non-whitespace content.
                if !msg.content.trim().is_empty() {
                    blocks.push(MinimaxContent::Text {
                        text: msg.content.clone(),
                        cache_control: cache_marker(msg.cached),
                    });
                }
                if let Some(ref calls_json) = msg.tool_calls {
                    // Tool calls that fail to parse are silently dropped.
                    if let Ok(calls) = serde_json::from_value::<
                        Vec<crate::llm::tool_calls::GenericToolCall>,
                    >(calls_json.clone())
                    {
                        for call in calls {
                            blocks.push(MinimaxContent::ToolUse {
                                id: call.id,
                                name: call.name,
                                input: call.arguments,
                            });
                        }
                    }
                }
                converted.push(MinimaxMessage {
                    role: msg.role.clone(),
                    content: blocks,
                });
            }
            _ => {
                // Plain user/assistant message: a single text block.
                converted.push(MinimaxMessage {
                    role: msg.role.clone(),
                    content: vec![MinimaxContent::Text {
                        text: msg.content.clone(),
                        cache_control: cache_marker(msg.cached),
                    }],
                });
            }
        }
    }
    converted
}
/// Executes the HTTP request against the MiniMax API and parses the response.
///
/// Retries with exponential backoff on transient failures, then maps the
/// Anthropic-style response (text / thinking / tool_use blocks plus usage)
/// into a `ProviderResponse` with token accounting, cost, and an exchange
/// record containing the raw request and response bodies.
///
/// # Errors
/// Fails on network errors, a non-success HTTP status (the body text is
/// included in the error), or a body that does not deserialize as
/// `MinimaxResponse`.
async fn execute_minimax_request(
    api_key: String,
    api_url: String,
    request_body: serde_json::Value,
    max_retries: u32,
    base_timeout: std::time::Duration,
    cancellation_token: Option<&tokio::sync::watch::Receiver<bool>>,
) -> Result<ProviderResponse> {
    let client = Client::new();
    let start_time = std::time::Instant::now();
    let response = retry::retry_with_exponential_backoff(
        || {
            // Clone captures so the factory closure can be invoked per retry.
            let client = client.clone();
            let api_key = api_key.clone();
            let api_url = api_url.clone();
            let request_body = request_body.clone();
            Box::pin(async move {
                client
                    .post(&api_url)
                    .header("Content-Type", "application/json")
                    .header("Authorization", format!("Bearer {}", api_key))
                    .header("anthropic-version", "2023-06-01")
                    .json(&request_body)
                    .send()
                    .await
            })
        },
        max_retries,
        base_timeout,
        cancellation_token,
    )
    .await?;
    // Measured wall time includes any retries performed above.
    let request_time_ms = start_time.elapsed().as_millis() as u64;
    if !response.status().is_success() {
        let status = response.status();
        let error_text = response.text().await.unwrap_or_default();
        return Err(anyhow::anyhow!(
            "MiniMax API error {}: {}",
            status,
            error_text
        ));
    }
    // Keep the raw text: it is parsed twice (typed below, raw JSON for logging).
    let response_text = response.text().await?;
    let minimax_response: MinimaxResponse = serde_json::from_str(&response_text)?;
    // Partition the content blocks by kind.
    let mut content_parts = Vec::new();
    let mut thinking_parts = Vec::new();
    let mut tool_calls = Vec::new();
    for content in minimax_response.content {
        match content {
            MinimaxResponseContent::Text { text } => {
                content_parts.push(text);
            }
            MinimaxResponseContent::Thinking { thinking } => {
                thinking_parts.push(thinking);
            }
            MinimaxResponseContent::ToolUse { id, name, input } => {
                tool_calls.push(ToolCall {
                    id: id.clone(),
                    name: name.clone(),
                    arguments: input,
                });
            }
        }
    }
    let final_content = content_parts.join("\n");
    // The API does not report reasoning tokens; estimate at ~4 chars/token.
    let (thinking, reasoning_tokens) = if thinking_parts.is_empty() {
        (None, 0)
    } else {
        let thinking_content = thinking_parts.join("\n\n");
        let estimated = (thinking_content.len() / 4) as u64;
        (
            Some(ThinkingBlock {
                content: thinking_content,
                tokens: estimated,
            }),
            estimated,
        )
    };
    let cached_tokens = minimax_response.usage.cache_read_input_tokens.unwrap_or(0);
    let cache_creation_tokens = minimax_response
        .usage
        .cache_creation_input_tokens
        .unwrap_or(0);
    // NOTE(review): token counts are narrowed u64 -> u32 for the cost math;
    // this assumes counts fit in u32 — confirm for very large contexts.
    let cost = calculate_minimax_cost(
        request_body["model"].as_str().unwrap_or(""),
        minimax_response.usage.input_tokens as u32,
        minimax_response.usage.output_tokens as u32,
        cache_creation_tokens as u32,
        cached_tokens as u32,
    );
    let usage = TokenUsage {
        prompt_tokens: minimax_response.usage.input_tokens,
        output_tokens: minimax_response.usage.output_tokens,
        reasoning_tokens, total_tokens: minimax_response.usage.input_tokens + minimax_response.usage.output_tokens,
        cached_tokens,
        cost,
        request_time_ms: Some(request_time_ms),
    };
    // Re-parse the raw body for the exchange log, attaching the parsed tool
    // calls under a synthetic "tool_calls" key for downstream consumers.
    let mut response_json: serde_json::Value = serde_json::from_str(&response_text)?;
    if !tool_calls.is_empty() {
        let generic_calls: Vec<crate::llm::tool_calls::GenericToolCall> = tool_calls
            .iter()
            .map(|tc| crate::llm::tool_calls::GenericToolCall {
                id: tc.id.clone(),
                name: tc.name.clone(),
                arguments: tc.arguments.clone(),
                meta: None, })
            .collect();
        response_json["tool_calls"] = serde_json::to_value(&generic_calls).unwrap_or_default();
    }
    let exchange = ProviderExchange::new(request_body, response_json, Some(usage), "minimax");
    // Opportunistically parse JSON-looking content as structured output.
    let structured_output =
        if final_content.trim().starts_with('{') || final_content.trim().starts_with('[') {
            serde_json::from_str(&final_content).ok()
        } else {
            None
        };
    Ok(ProviderResponse {
        content: final_content,
        thinking, exchange,
        tool_calls: if tool_calls.is_empty() {
            None
        } else {
            Some(tool_calls)
        },
        finish_reason: minimax_response.stop_reason,
        structured_output,
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    // Exact-prefix model matching: M2-family names accepted, others rejected.
    #[test]
    fn test_model_support() {
        let provider = MinimaxProvider::new();
        assert!(provider.supports_model("MiniMax-M2.1"));
        assert!(provider.supports_model("MiniMax-M2.1-lightning"));
        assert!(provider.supports_model("MiniMax-M2"));
        assert!(!provider.supports_model("gpt-4"));
        assert!(!provider.supports_model("claude-3"));
    }
    // Matching must ignore ASCII case in any mixture.
    #[test]
    fn test_model_support_case_insensitive() {
        let provider = MinimaxProvider::new();
        assert!(provider.supports_model("minimax-m2.1"));
        assert!(provider.supports_model("minimax-m2.1-lightning"));
        assert!(provider.supports_model("minimax-m2"));
        assert!(provider.supports_model("MINIMAX-M2.1"));
        assert!(provider.supports_model("MINIMAX-M2"));
        assert!(provider.supports_model("Minimax-M2.1"));
        assert!(provider.supports_model("MINIMAX-m2.1"));
    }
    // Cost math against the PRICING table, with and without cached tokens.
    #[test]
    fn test_cost_calculation() {
        // 1M in @ $0.30 + 1M out @ $1.20, no caching.
        let cost = calculate_minimax_cost("MiniMax-M2.1", 1_000_000, 1_000_000, 0, 0);
        assert_eq!(cost, Some(1.5));
        // Half the input billed at the cache-write rate.
        let cost = calculate_minimax_cost("MiniMax-M2.1", 1_000_000, 1_000_000, 500_000, 0);
        assert_eq!(cost, Some(1.5375));
        // Half the input billed at the cache-read rate.
        let cost = calculate_minimax_cost("MiniMax-M2.1", 1_000_000, 1_000_000, 0, 500_000);
        assert_eq!(cost, Some(1.365));
        // Lightning variant has a higher output price ($2.40/1M).
        let cost = calculate_minimax_cost("MiniMax-M2.1-lightning", 1_000_000, 1_000_000, 0, 0);
        assert!((cost.unwrap() - 2.7).abs() < 0.0001);
    }
    // Capability flags: caching and structured output yes, vision no.
    #[test]
    fn test_provider_capabilities() {
        let provider = MinimaxProvider::new();
        assert!(provider.supports_caching("MiniMax-M2.1"));
        assert!(!provider.supports_vision("MiniMax-M2.1"));
        assert!(provider.supports_structured_output("MiniMax-M2.1"));
    }
}