use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Wire protocol identifier attached to a model configuration.
///
/// Serde names are produced by `rename_all = "snake_case"`, which splits `OpenAi`
/// into `open_ai` (for example `"open_ai_completions"`). The `Display` impl for this
/// type prints the `openai_*` spelling instead, so those spellings are additionally
/// accepted when deserializing. Serialized output is unchanged.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ApiProtocol {
    AnthropicMessages,
    // Alias matches the string printed by `Display`.
    #[serde(alias = "openai_completions")]
    OpenAiCompletions,
    // Alias matches the string printed by `Display`.
    #[serde(alias = "openai_responses")]
    OpenAiResponses,
    // Alias matches the string printed by `Display`.
    #[serde(alias = "azure_openai_responses")]
    AzureOpenAiResponses,
    GoogleGenerativeAi,
    GoogleVertex,
    BedrockConverseStream,
}
impl std::fmt::Display for ApiProtocol {
    /// Writes the protocol's lowercase, underscore-separated label.
    ///
    /// The label is written verbatim; width and padding flags on the formatter
    /// are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Pick the label first so there is a single write call below.
        let label: &'static str = match self {
            Self::AnthropicMessages => "anthropic_messages",
            Self::OpenAiCompletions => "openai_completions",
            Self::OpenAiResponses => "openai_responses",
            Self::AzureOpenAiResponses => "azure_openai_responses",
            Self::GoogleGenerativeAi => "google_generative_ai",
            Self::GoogleVertex => "google_vertex",
            Self::BedrockConverseStream => "bedrock_converse_stream",
        };
        f.write_str(label)
    }
}
/// Cost figures attached to a model configuration.
///
/// `Default` is derived and yields `0.0` for every field, which is exactly what the
/// previous hand-written impl returned.
// NOTE(review): field names suggest a price per one million tokens; the currency and
// the consumer of these values are not visible here — confirm against the caller.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CostConfig {
    /// Input cost figure. Required when deserializing.
    pub input_per_million: f64,
    /// Output cost figure. Required when deserializing.
    pub output_per_million: f64,
    /// Cache-read cost figure. May be omitted when deserializing (falls back to `0.0`).
    #[serde(default)]
    pub cache_read_per_million: f64,
    /// Cache-write cost figure. May be omitted when deserializing (falls back to `0.0`).
    #[serde(default)]
    pub cache_write_per_million: f64,
}
/// Two-way selector: `MaxTokens` (the default) or `MaxCompletionTokens`.
///
/// Serialized by serde as `"max_tokens"` and `"max_completion_tokens"`.
// NOTE(review): presumably names the request field that carries the output-token
// limit for an OpenAI-compatible endpoint — confirm against the request builder.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum MaxTokensField {
#[default]
MaxTokens,
MaxCompletionTokens,
}
/// Three-way selector: `OpenAi` (the default), `Xai`, or `Qwen`.
///
/// Serialized by serde as `"open_ai"`, `"xai"` and `"qwen"`.
// NOTE(review): presumably chooses how reasoning/"thinking" options are encoded for a
// given provider — the encoding itself is not visible in this file; confirm at use site.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum ThinkingFormat {
#[default]
OpenAi,
Xai,
Qwen,
}
/// Compatibility profile for an endpoint addressed through the OpenAI-style API.
///
/// When deserializing, only the fields marked `#[serde(default)]` may be omitted;
/// every other field is required. Ready-made profiles are provided by the associated
/// functions on this type (`openai()`, `xai()`, `ollama()`, ...).
// NOTE(review): the per-field meanings below are inferred from the field names only;
// the code that consumes these flags is not visible here — confirm before relying on them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OpenAiCompat {
// Required on deserialization.
pub supports_store: bool,
// Required on deserialization.
pub supports_developer_role: bool,
// Required on deserialization.
pub supports_reasoning_effort: bool,
// Optional on deserialization; a missing value becomes `false`.
#[serde(default)]
pub supports_thinking_control: bool,
// Required on deserialization.
pub supports_usage_in_streaming: bool,
// Required on deserialization; see `MaxTokensField` for the accepted strings.
pub max_tokens_field: MaxTokensField,
// Required on deserialization.
pub requires_tool_result_name: bool,
// Optional on deserialization; a missing value becomes `false`.
#[serde(default)]
pub requires_assistant_after_tool_result: bool,
// Required on deserialization; see `ThinkingFormat` for the accepted strings.
pub thinking_format: ThinkingFormat,
}
impl Default for OpenAiCompat {
    /// Baseline profile that the named presets start from.
    ///
    /// Only `supports_usage_in_streaming` is enabled; every other flag is `false`,
    /// and both selectors use their first variant.
    fn default() -> Self {
        Self {
            // The single flag that is on in the baseline.
            supports_usage_in_streaming: true,

            // Selectors.
            max_tokens_field: MaxTokensField::MaxTokens,
            thinking_format: ThinkingFormat::OpenAi,

            // `supports_*` flags that are off in the baseline.
            supports_store: false,
            supports_developer_role: false,
            supports_reasoning_effort: false,
            supports_thinking_control: false,

            // `requires_*` flags, all off in the baseline.
            requires_tool_result_name: false,
            requires_assistant_after_tool_result: false,
        }
    }
}
impl OpenAiCompat {
    /// Profile used by [`ModelConfig::openai`]: store, developer role and reasoning
    /// effort enabled, with the `MaxCompletionTokens` selector.
    pub fn openai() -> Self {
        let base = Self::default();
        Self {
            max_tokens_field: MaxTokensField::MaxCompletionTokens,
            supports_usage_in_streaming: true,
            supports_reasoning_effort: true,
            supports_developer_role: true,
            supports_store: true,
            ..base
        }
    }

    /// Baseline profile with the `Xai` thinking format.
    pub fn xai() -> Self {
        let base = Self::default();
        Self {
            thinking_format: ThinkingFormat::Xai,
            supports_usage_in_streaming: true,
            ..base
        }
    }

    /// Baseline profile with streaming usage stated explicitly.
    pub fn groq() -> Self {
        let base = Self::default();
        Self {
            supports_usage_in_streaming: true,
            ..base
        }
    }

    /// The unmodified baseline profile.
    pub fn cerebras() -> Self {
        Default::default()
    }

    /// Baseline profile with the `MaxCompletionTokens` selector.
    pub fn openrouter() -> Self {
        let base = Self::default();
        Self {
            max_tokens_field: MaxTokensField::MaxCompletionTokens,
            supports_usage_in_streaming: true,
            ..base
        }
    }

    /// Baseline profile with streaming usage and the `MaxTokens` selector stated
    /// explicitly.
    pub fn mistral() -> Self {
        let base = Self::default();
        Self {
            max_tokens_field: MaxTokensField::MaxTokens,
            supports_usage_in_streaming: true,
            ..base
        }
    }

    /// Baseline profile plus reasoning effort and thinking control.
    pub fn deepseek() -> Self {
        let base = Self::default();
        Self {
            max_tokens_field: MaxTokensField::MaxTokens,
            supports_usage_in_streaming: true,
            supports_thinking_control: true,
            supports_reasoning_effort: true,
            ..base
        }
    }

    /// Baseline profile with streaming usage stated explicitly.
    pub fn zai() -> Self {
        let base = Self::default();
        Self {
            supports_usage_in_streaming: true,
            ..base
        }
    }

    /// Baseline profile with streaming usage stated explicitly.
    pub fn minimax() -> Self {
        let base = Self::default();
        Self {
            supports_usage_in_streaming: true,
            ..base
        }
    }

    /// Baseline profile with the `Qwen` thinking format.
    pub fn qwen() -> Self {
        let base = Self::default();
        Self {
            thinking_format: ThinkingFormat::Qwen,
            max_tokens_field: MaxTokensField::MaxTokens,
            supports_usage_in_streaming: true,
            ..base
        }
    }

    /// Baseline profile that additionally sets `requires_assistant_after_tool_result`.
    pub fn ollama() -> Self {
        let base = Self::default();
        Self {
            requires_assistant_after_tool_result: true,
            ..base
        }
    }
}
/// Configuration record describing one model and the endpoint that serves it.
///
/// When deserializing, `cost`, `headers` and `compat` may be omitted (they fall back to
/// `CostConfig::default()`, an empty map and `None` respectively); all other fields are
/// required. Ready-made records are provided by the associated functions on this type
/// (`anthropic()`, `openai()`, `ollama()`, ...).
// NOTE(review): the per-field meanings below are inferred from the field names and the
// preset constructors; the code that consumes this struct is not visible here — confirm.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelConfig {
// Model identifier; the presets store the caller-supplied id here.
pub id: String,
// Display name; `openai_compat` and `ollama` reuse `id` for it.
pub name: String,
// Protocol selector; see `ApiProtocol`.
pub api: ApiProtocol,
// Provider label such as "anthropic", "openai" or "local".
pub provider: String,
// Endpoint base URL.
pub base_url: String,
// Set to `true` only by the `qwen` and `deepseek` presets.
pub reasoning: bool,
// Presumably the context size in tokens — TODO confirm units.
pub context_window: u32,
// Presumably the output limit in tokens — TODO confirm units.
pub max_tokens: u32,
// Optional on deserialization; defaults to `CostConfig::default()`.
#[serde(default)]
pub cost: CostConfig,
// Optional on deserialization; defaults to an empty map.
#[serde(default)]
pub headers: HashMap<String, String>,
// Optional on deserialization; defaults to `None`. The `anthropic` and `google`
// presets leave it `None`, every other preset sets it.
#[serde(default)]
pub compat: Option<OpenAiCompat>,
}
impl ModelConfig {
    /// Shared field set for every preset that uses `ApiProtocol::OpenAiCompletions`.
    ///
    /// Fills in `reasoning: false`, a 128 000 context window, a 4096 `max_tokens`,
    /// default cost and empty headers. Presets that differ override those fields with
    /// struct-update syntax.
    fn completions_preset(
        id: String,
        name: String,
        provider: String,
        base_url: String,
        compat: OpenAiCompat,
    ) -> Self {
        Self {
            id,
            name,
            api: ApiProtocol::OpenAiCompletions,
            provider,
            base_url,
            reasoning: false,
            context_window: 128_000,
            max_tokens: 4096,
            cost: CostConfig::default(),
            headers: HashMap::new(),
            compat: Some(compat),
        }
    }

    /// Anthropic Messages preset: 200 000 context window, 8192 `max_tokens`,
    /// no compatibility profile.
    pub fn anthropic(id: impl Into<String>, name: impl Into<String>) -> Self {
        let (id, name) = (id.into(), name.into());
        Self {
            id,
            name,
            api: ApiProtocol::AnthropicMessages,
            provider: String::from("anthropic"),
            base_url: String::from("https://api.anthropic.com"),
            reasoning: false,
            context_window: 200_000,
            max_tokens: 8192,
            cost: CostConfig::default(),
            headers: HashMap::new(),
            compat: None,
        }
    }

    /// OpenAI preset using the `OpenAiCompat::openai()` profile.
    pub fn openai(id: impl Into<String>, name: impl Into<String>) -> Self {
        Self::completions_preset(
            id.into(),
            name.into(),
            "openai".into(),
            "https://api.openai.com/v1".into(),
            OpenAiCompat::openai(),
        )
    }

    /// Preset for a caller-supplied endpoint, named "Local Model", with the baseline
    /// compatibility profile.
    pub fn local(base_url: impl Into<String>, model_id: impl Into<String>) -> Self {
        let id = model_id.into();
        Self::completions_preset(
            id,
            "Local Model".into(),
            "local".into(),
            base_url.into(),
            OpenAiCompat::default(),
        )
    }

    /// Preset for a caller-supplied endpoint, provider label and compatibility
    /// profile. The model id doubles as the display name.
    pub fn openai_compat(
        base_url: impl Into<String>,
        model_id: impl Into<String>,
        provider: impl Into<String>,
        compat: OpenAiCompat,
    ) -> Self {
        let id = model_id.into();
        let name = id.clone();
        let provider = provider.into();
        Self::completions_preset(id, name, provider, base_url.into(), compat)
    }

    /// Preset for a caller-supplied endpoint with provider "ollama" and the
    /// `OpenAiCompat::ollama()` profile. The model id doubles as the display name.
    pub fn ollama(base_url: impl Into<String>, model_id: impl Into<String>) -> Self {
        let id = model_id.into();
        let name = id.clone();
        Self::completions_preset(
            id,
            name,
            "ollama".into(),
            base_url.into(),
            OpenAiCompat::ollama(),
        )
    }

    /// Z.ai preset using the `OpenAiCompat::zai()` profile.
    pub fn zai(id: impl Into<String>, name: impl Into<String>) -> Self {
        Self::completions_preset(
            id.into(),
            name.into(),
            "zai".into(),
            "https://api.z.ai/api/paas/v4".into(),
            OpenAiCompat::zai(),
        )
    }

    /// MiniMax preset: 1 000 000 context window, `OpenAiCompat::minimax()` profile.
    pub fn minimax(id: impl Into<String>, name: impl Into<String>) -> Self {
        Self {
            context_window: 1_000_000,
            ..Self::completions_preset(
                id.into(),
                name.into(),
                "minimax".into(),
                "https://api.minimaxi.chat/v1".into(),
                OpenAiCompat::minimax(),
            )
        }
    }

    /// Qwen preset: `reasoning` enabled, `OpenAiCompat::qwen()` profile.
    pub fn qwen(id: impl Into<String>, name: impl Into<String>) -> Self {
        Self {
            reasoning: true,
            ..Self::completions_preset(
                id.into(),
                name.into(),
                "qwen".into(),
                "https://dashscope-intl.aliyuncs.com/compatible-mode/v1".into(),
                OpenAiCompat::qwen(),
            )
        }
    }

    /// xAI preset: 131 072 context window, `OpenAiCompat::xai()` profile.
    pub fn xai(id: impl Into<String>, name: impl Into<String>) -> Self {
        Self {
            context_window: 131_072,
            ..Self::completions_preset(
                id.into(),
                name.into(),
                "xai".into(),
                "https://api.x.ai/v1".into(),
                OpenAiCompat::xai(),
            )
        }
    }

    /// Groq preset using the `OpenAiCompat::groq()` profile.
    pub fn groq(id: impl Into<String>, name: impl Into<String>) -> Self {
        Self::completions_preset(
            id.into(),
            name.into(),
            "groq".into(),
            "https://api.groq.com/openai/v1".into(),
            OpenAiCompat::groq(),
        )
    }

    /// DeepSeek preset: `reasoning` enabled, 1 000 000 context window,
    /// 384 000 `max_tokens`, `OpenAiCompat::deepseek()` profile.
    pub fn deepseek(id: impl Into<String>, name: impl Into<String>) -> Self {
        Self {
            reasoning: true,
            context_window: 1_000_000,
            max_tokens: 384_000,
            ..Self::completions_preset(
                id.into(),
                name.into(),
                "deepseek".into(),
                "https://api.deepseek.com".into(),
                OpenAiCompat::deepseek(),
            )
        }
    }

    /// Mistral preset using the `OpenAiCompat::mistral()` profile.
    pub fn mistral(id: impl Into<String>, name: impl Into<String>) -> Self {
        Self::completions_preset(
            id.into(),
            name.into(),
            "mistral".into(),
            "https://api.mistral.ai/v1".into(),
            OpenAiCompat::mistral(),
        )
    }

    /// Google Generative AI preset: 1 000 000 context window, 8192 `max_tokens`,
    /// no compatibility profile.
    pub fn google(id: impl Into<String>, name: impl Into<String>) -> Self {
        let (id, name) = (id.into(), name.into());
        Self {
            id,
            name,
            api: ApiProtocol::GoogleGenerativeAi,
            provider: String::from("google"),
            base_url: String::from("https://generativelanguage.googleapis.com"),
            reasoning: false,
            context_window: 1_000_000,
            max_tokens: 8192,
            cost: CostConfig::default(),
            headers: HashMap::new(),
            compat: None,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The Anthropic preset uses the Messages protocol and has no compat profile.
    #[test]
    fn test_model_config_anthropic() {
        let config = ModelConfig::anthropic("claude-sonnet-4-20250514", "Claude Sonnet 4");
        assert_eq!(config.api, ApiProtocol::AnthropicMessages);
        assert_eq!(config.provider, "anthropic");
        assert!(config.compat.is_none());
    }

    /// The OpenAI preset carries the `OpenAiCompat::openai()` profile.
    #[test]
    fn test_model_config_openai() {
        let config = ModelConfig::openai("gpt-4o", "GPT-4o");
        assert_eq!(config.api, ApiProtocol::OpenAiCompletions);
        let compat = config.compat.unwrap();
        assert!(compat.supports_store);
        assert!(compat.supports_developer_role);
        assert_eq!(compat.max_tokens_field, MaxTokensField::MaxCompletionTokens);
    }

    /// Spot-checks the distinguishing flags of each compat preset.
    #[test]
    fn test_openai_compat_variants() {
        let xai = OpenAiCompat::xai();
        assert_eq!(xai.thinking_format, ThinkingFormat::Xai);
        assert!(!xai.supports_store);

        let groq = OpenAiCompat::groq();
        assert!(groq.supports_usage_in_streaming);
        assert!(!groq.supports_store);

        let deepseek = OpenAiCompat::deepseek();
        assert_eq!(deepseek.max_tokens_field, MaxTokensField::MaxTokens);
        assert!(deepseek.supports_reasoning_effort);
        assert!(deepseek.supports_thinking_control);

        let zai = OpenAiCompat::zai();
        assert!(zai.supports_usage_in_streaming);
        assert!(!zai.supports_store);

        let minimax = OpenAiCompat::minimax();
        assert!(minimax.supports_usage_in_streaming);
        assert!(!minimax.supports_store);

        let ollama = OpenAiCompat::ollama();
        assert!(ollama.requires_assistant_after_tool_result);
        assert!(!ollama.requires_tool_result_name);

        let qwen = OpenAiCompat::qwen();
        assert_eq!(qwen.thinking_format, ThinkingFormat::Qwen);
        assert_eq!(qwen.max_tokens_field, MaxTokensField::MaxTokens);
        assert!(qwen.supports_usage_in_streaming);
        assert!(!qwen.supports_reasoning_effort);
        assert!(!qwen.supports_thinking_control);
    }

    /// A payload without `requires_assistant_after_tool_result` still deserializes,
    /// and the missing flag comes back as `false`.
    #[test]
    fn test_openai_compat_deserializes_without_assistant_after_tool_result_flag() {
        let compat: OpenAiCompat = serde_json::from_value(serde_json::json!({
            "supports_store": false,
            "supports_developer_role": false,
            "supports_reasoning_effort": false,
            "supports_thinking_control": false,
            "supports_usage_in_streaming": true,
            "max_tokens_field": "max_tokens",
            "requires_tool_result_name": false,
            "thinking_format": "open_ai"
        }))
        .unwrap();
        assert!(!compat.requires_assistant_after_tool_result);
    }

    /// The `local` preset must not pick up the Ollama-specific requirement.
    #[test]
    fn test_model_config_local_remains_neutral() {
        let config = ModelConfig::local("http://localhost:1234/v1", "local-model");
        assert_eq!(config.api, ApiProtocol::OpenAiCompletions);
        assert_eq!(config.provider, "local");
        assert_eq!(config.base_url, "http://localhost:1234/v1");
        let compat = config.compat.unwrap();
        assert!(!compat.requires_assistant_after_tool_result);
    }

    /// The `ollama` preset reuses the model id as name and sets the Ollama profile.
    #[test]
    fn test_model_config_ollama() {
        let config = ModelConfig::ollama("http://localhost:11434/v1", "llama3.1:8b");
        assert_eq!(config.api, ApiProtocol::OpenAiCompletions);
        assert_eq!(config.provider, "ollama");
        assert_eq!(config.id, "llama3.1:8b");
        assert_eq!(config.name, "llama3.1:8b");
        assert_eq!(config.base_url, "http://localhost:11434/v1");
        let compat = config.compat.unwrap();
        assert!(compat.requires_assistant_after_tool_result);
    }

    /// `openai_compat` stores the caller-supplied provider and profile unchanged.
    #[test]
    fn test_model_config_openai_compat() {
        let config = ModelConfig::openai_compat(
            "http://localhost:1234/v1",
            "qwen3-local",
            "qwen",
            OpenAiCompat::qwen(),
        );
        assert_eq!(config.api, ApiProtocol::OpenAiCompletions);
        assert_eq!(config.provider, "qwen");
        assert_eq!(config.id, "qwen3-local");
        assert_eq!(config.name, "qwen3-local");
        assert_eq!(config.base_url, "http://localhost:1234/v1");
        let compat = config.compat.unwrap();
        assert_eq!(compat.thinking_format, ThinkingFormat::Qwen);
    }

    /// The Qwen preset enables reasoning and uses the Qwen profile.
    #[test]
    fn test_model_config_qwen() {
        let config = ModelConfig::qwen("qwen3.6-plus", "Qwen 3.6 Plus");
        assert_eq!(config.api, ApiProtocol::OpenAiCompletions);
        assert_eq!(config.provider, "qwen");
        assert_eq!(
            config.base_url,
            "https://dashscope-intl.aliyuncs.com/compatible-mode/v1"
        );
        assert!(config.reasoning);
        let compat = config.compat.unwrap();
        assert_eq!(compat.thinking_format, ThinkingFormat::Qwen);
        assert_eq!(compat.max_tokens_field, MaxTokensField::MaxTokens);
    }

    /// The Z.ai preset points at the Z.ai base URL and carries a profile.
    #[test]
    fn test_model_config_zai() {
        let config = ModelConfig::zai("glm-4.7", "GLM 4.7");
        assert_eq!(config.api, ApiProtocol::OpenAiCompletions);
        assert_eq!(config.provider, "zai");
        assert_eq!(config.base_url, "https://api.z.ai/api/paas/v4");
        assert!(config.compat.is_some());
    }

    /// The MiniMax preset overrides the context window.
    #[test]
    fn test_model_config_minimax() {
        let config = ModelConfig::minimax("MiniMax-Text-01", "MiniMax Text 01");
        assert_eq!(config.api, ApiProtocol::OpenAiCompletions);
        assert_eq!(config.provider, "minimax");
        assert_eq!(config.base_url, "https://api.minimaxi.chat/v1");
        assert_eq!(config.context_window, 1_000_000);
        assert!(config.compat.is_some());
    }

    /// The DeepSeek preset overrides context window, max tokens and reasoning.
    #[test]
    fn test_model_config_deepseek() {
        let config = ModelConfig::deepseek("deepseek-v4-flash", "DeepSeek V4 Flash");
        assert_eq!(config.api, ApiProtocol::OpenAiCompletions);
        assert_eq!(config.provider, "deepseek");
        assert_eq!(config.base_url, "https://api.deepseek.com");
        assert_eq!(config.context_window, 1_000_000);
        assert_eq!(config.max_tokens, 384_000);
        assert!(config.reasoning);
        assert!(config.compat.is_some());
    }

    /// Every variant's `Display` string is pinned, so a new or renamed variant cannot
    /// change a label unnoticed.
    #[test]
    fn test_api_protocol_display() {
        assert_eq!(
            ApiProtocol::AnthropicMessages.to_string(),
            "anthropic_messages"
        );
        assert_eq!(
            ApiProtocol::OpenAiCompletions.to_string(),
            "openai_completions"
        );
        assert_eq!(ApiProtocol::OpenAiResponses.to_string(), "openai_responses");
        assert_eq!(
            ApiProtocol::AzureOpenAiResponses.to_string(),
            "azure_openai_responses"
        );
        assert_eq!(
            ApiProtocol::GoogleGenerativeAi.to_string(),
            "google_generative_ai"
        );
        assert_eq!(ApiProtocol::GoogleVertex.to_string(), "google_vertex");
        assert_eq!(
            ApiProtocol::BedrockConverseStream.to_string(),
            "bedrock_converse_stream"
        );
    }

    /// The default cost configuration is all zeros for the required fields.
    #[test]
    fn test_cost_config_default() {
        let cost = CostConfig::default();
        assert_eq!(cost.input_per_million, 0.0);
        assert_eq!(cost.output_per_million, 0.0);
    }
}