use serde::{Deserialize, Serialize};
/// A capability a model can advertise in its [`ModelSchema`] entry.
///
/// Serialized as a snake_case string, e.g. `ToolUse` <-> `"tool_use"`,
/// `MultiToolCall` <-> `"multi_tool_call"`, `SpeechToText` <-> `"speech_to_text"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ModelCapability {
    Generate,
    Embed,
    Rerank,
    Classify,
    Code,
    Reasoning,
    Summarize,
    ToolUse,
    MultiToolCall,
    Vision,
    VideoUnderstanding,
    AudioUnderstanding,
    Grounding,
    SpeechToText,
    TextToSpeech,
    ImageGeneration,
    VideoGeneration,
}
/// Describes where a model comes from and how it is reached.
///
/// Internally tagged: the variant name is written under a `"type"` key as a
/// snake_case string (`"local"`, `"remote_api"`, `"ollama"`, `"mlx"`,
/// `"vllm_mlx"`, `"apple_foundation_models"`, `"proprietary"`, `"delegated"`),
/// next to the variant's own fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ModelSource {
    /// A weights file identified by repo + filename, with a separate tokenizer repo.
    Local {
        hf_repo: String,
        hf_filename: String,
        tokenizer_repo: String,
    },
    /// An HTTP API endpoint speaking one of the [`ApiProtocol`]s.
    RemoteApi {
        endpoint: String,
        /// Name of the primary environment variable holding the API key.
        api_key_env: String,
        /// Additional environment variable names; empty when omitted.
        #[serde(default)]
        api_key_envs: Vec<String>,
        /// Optional API version string; `None` when omitted.
        #[serde(default)]
        api_version: Option<String>,
        protocol: ApiProtocol,
    },
    /// A model served by Ollama.
    Ollama {
        model_tag: String,
        /// Defaults to `http://localhost:11434` when omitted.
        #[serde(default = "default_ollama_host")]
        host: String,
    },
    /// An MLX model identified by repo, with an optional explicit weight file.
    Mlx {
        hf_repo: String,
        #[serde(default)]
        hf_weight_file: Option<String>,
    },
    /// A vLLM-MLX server endpoint and the model name it serves.
    VllmMlx {
        endpoint: String,
        model_name: String,
    },
    /// Apple Foundation Models, with an optional use-case string.
    AppleFoundationModels {
        #[serde(default)]
        use_case: Option<String>,
    },
    /// A provider-specific endpoint with custom auth and wire protocol.
    Proprietary {
        provider: String,
        endpoint: String,
        auth: ProprietaryAuth,
        protocol: ProprietaryProtocol,
    },
    /// A model whose execution is delegated elsewhere; `hint` is optional.
    Delegated {
        #[serde(default)]
        hint: Option<String>,
    },
}
/// Authentication scheme for a [`ModelSource::Proprietary`] endpoint.
///
/// Internally tagged under `"type"` with snake_case variant names. Note that
/// serde's snake_case rule turns `OAuth2Pkce` into `"o_auth2_pkce"`; the
/// other variants become `"api_key_env"` and `"bearer_token_env"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ProprietaryAuth {
    /// OAuth 2.0 authorization-code flow with PKCE.
    OAuth2Pkce {
        authority: String,
        client_id: String,
        scopes: Vec<String>,
    },
    /// API key read from the named environment variable.
    ApiKeyEnv { env_var: String },
    /// Bearer token read from the named environment variable.
    BearerTokenEnv { env_var: String },
}
/// Wire-level details for talking to a [`ModelSource::Proprietary`] endpoint.
///
/// Every field is optional when deserializing; omitted fields take the same
/// values as [`ProprietaryProtocol::default`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProprietaryProtocol {
    /// Request path for chat calls; defaults to `/chat`.
    #[serde(default = "default_chat_path")]
    pub chat_path: String,
    /// `Content-Type` of requests; defaults to `application/json`.
    #[serde(default = "default_content_type")]
    pub content_type: String,
    /// Whether streaming responses are used; defaults to `false`.
    #[serde(default)]
    pub streaming: bool,
    /// Extra HTTP headers as name/value pairs; defaults to empty.
    #[serde(default)]
    pub extra_headers: std::collections::HashMap<String, String>,
}
impl Default for ProprietaryProtocol {
    /// Builds the same values serde uses for omitted fields: `/chat`,
    /// `application/json`, non-streaming, and no extra headers.
    fn default() -> Self {
        let chat_path = default_chat_path();
        let content_type = default_content_type();
        Self {
            chat_path,
            content_type,
            streaming: bool::default(),
            extra_headers: Default::default(),
        }
    }
}
/// Serde default for `ProprietaryProtocol::chat_path`.
fn default_chat_path() -> String {
    String::from("/chat")
}
/// Serde default for `ProprietaryProtocol::content_type`.
fn default_content_type() -> String {
    String::from("application/json")
}
/// Serde default for the `host` field of `ModelSource::Ollama`.
fn default_ollama_host() -> String {
    String::from("http://localhost:11434")
}
/// Request/response format spoken by a [`ModelSource::RemoteApi`] endpoint.
///
/// Serialized in snake_case. serde splits at every capital, so the wire names
/// are `"open_ai_compat"`, `"open_ai_responses"`, `"anthropic"`, `"google"`
/// and `"azure_open_ai"` (not `"openai_..."`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ApiProtocol {
    OpenAiCompat,
    OpenAiResponses,
    Anthropic,
    Google,
    AzureOpenAi,
}
/// Optional performance figures for a model; each field is `None` when unknown.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct PerformanceEnvelope {
    /// Median latency in milliseconds.
    #[serde(default)]
    pub latency_p50_ms: Option<u64>,
    /// 99th-percentile latency in milliseconds.
    #[serde(default)]
    pub latency_p99_ms: Option<u64>,
    /// Throughput in tokens per second.
    #[serde(default)]
    pub tokens_per_second: Option<f64>,
}
/// A generation parameter a model may accept (see [`ModelSchema::supported_params`]).
///
/// Serialized in snake_case, e.g. `TopP` <-> `"top_p"`,
/// `MaxTokens` <-> `"max_tokens"`, `ExtendedThinking` <-> `"extended_thinking"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum GenerateParam {
    Temperature,
    TopP,
    TopK,
    MaxTokens,
    StopSequences,
    FrequencyPenalty,
    PresencePenalty,
    Seed,
    ResponseFormat,
    ExtendedThinking,
}
/// Returns the standard parameter set, in this order: temperature, top-p,
/// max tokens, stop sequences, frequency penalty, presence penalty, seed.
///
/// `TopK`, `ResponseFormat` and `ExtendedThinking` are not part of this set.
pub fn standard_params() -> Vec<GenerateParam> {
    let params = [
        GenerateParam::Temperature,
        GenerateParam::TopP,
        GenerateParam::MaxTokens,
        GenerateParam::StopSequences,
        GenerateParam::FrequencyPenalty,
        GenerateParam::PresencePenalty,
        GenerateParam::Seed,
    ];
    Vec::from(params)
}
/// Returns the reduced parameter set containing only max tokens and stop
/// sequences (no sampling controls).
pub fn reasoning_params() -> Vec<GenerateParam> {
    let params = [GenerateParam::MaxTokens, GenerateParam::StopSequences];
    Vec::from(params)
}
/// Pricing and footprint data for a model; each field is `None` when unknown.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CostModel {
    /// Price per million input tokens (currency not fixed here — presumably USD).
    #[serde(default)]
    pub input_per_mtok: Option<f64>,
    /// Price per million output tokens (same currency as `input_per_mtok`).
    #[serde(default)]
    pub output_per_mtok: Option<f64>,
    /// Model size in megabytes.
    #[serde(default)]
    pub size_mb: Option<u64>,
    /// Memory requirement in megabytes; see [`ModelSchema::ram_mb`] for the fallback.
    #[serde(default)]
    pub ram_mb: Option<u64>,
}
/// A published benchmark result; `name` and `score` are required, the
/// provenance fields are optional.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BenchmarkScore {
    pub name: String,
    pub score: f64,
    /// Evaluation harness used, if recorded.
    #[serde(default)]
    pub harness: Option<String>,
    /// Where the score was published, if recorded.
    #[serde(default)]
    pub source_url: Option<String>,
    /// When the score was measured, as a free-form string.
    #[serde(default)]
    pub measured_at: Option<String>,
}
/// A registry entry describing one model: identity, capabilities, limits,
/// cost/performance data, and how to reach it.
///
/// Required when deserializing: `id`, `name`, `provider`, `family`,
/// `capabilities`, `context_length` and `source`. Everything else has a default.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelSchema {
    pub id: String,
    pub name: String,
    pub provider: String,
    pub family: String,
    /// Empty string when omitted.
    #[serde(default)]
    pub version: String,
    pub capabilities: Vec<ModelCapability>,
    /// Maximum context window (unit not fixed here — presumably tokens).
    pub context_length: usize,
    /// Free-form parameter count such as `"4B"`; empty string when omitted.
    #[serde(default)]
    pub param_count: String,
    #[serde(default)]
    pub quantization: Option<String>,
    #[serde(default)]
    pub performance: PerformanceEnvelope,
    #[serde(default)]
    pub cost: CostModel,
    pub source: ModelSource,
    #[serde(default)]
    pub tags: Vec<String>,
    #[serde(default)]
    pub supported_params: Vec<GenerateParam>,
    #[serde(default)]
    pub public_benchmarks: Vec<BenchmarkScore>,
    /// Runtime-only flag: never serialized, and always `false` after deserialization.
    #[serde(skip)]
    pub available: bool,
}
impl ModelSchema {
    /// Returns `true` when `cap` is listed in [`ModelSchema::capabilities`].
    pub fn has_capability(&self, cap: ModelCapability) -> bool {
        self.capabilities.contains(&cap)
    }

    /// Returns `true` for the source kinds classified as local:
    /// [`ModelSource::Local`], [`ModelSource::Mlx`], [`ModelSource::VllmMlx`]
    /// and [`ModelSource::AppleFoundationModels`].
    ///
    /// `Ollama` and `Delegated` sources return `false` here and from
    /// [`ModelSchema::is_remote`].
    pub fn is_local(&self) -> bool {
        matches!(
            self.source,
            ModelSource::Local { .. }
                | ModelSource::Mlx { .. }
                | ModelSource::VllmMlx { .. }
                | ModelSource::AppleFoundationModels { .. }
        )
    }

    /// Returns `true` when the source is [`ModelSource::Delegated`].
    pub fn is_delegated(&self) -> bool {
        matches!(self.source, ModelSource::Delegated { .. })
    }

    /// Returns `true` when the source is [`ModelSource::Mlx`].
    pub fn is_mlx(&self) -> bool {
        matches!(self.source, ModelSource::Mlx { .. })
    }

    /// Returns `true` when the source is [`ModelSource::AppleFoundationModels`].
    pub fn is_foundation_models(&self) -> bool {
        matches!(self.source, ModelSource::AppleFoundationModels { .. })
    }

    /// Returns `true` when the source is [`ModelSource::VllmMlx`].
    pub fn is_vllm_mlx(&self) -> bool {
        matches!(self.source, ModelSource::VllmMlx { .. })
    }

    /// Returns `true` for [`ModelSource::RemoteApi`] and [`ModelSource::Proprietary`].
    pub fn is_remote(&self) -> bool {
        matches!(
            self.source,
            ModelSource::RemoteApi { .. } | ModelSource::Proprietary { .. }
        )
    }

    /// Lists the environment-variable names that may hold credentials for this model.
    ///
    /// The result depends on the source:
    /// - `RemoteApi`: `api_key_env` followed by `api_key_envs`, in order. If a name
    ///   appears more than once, only its first occurrence is kept.
    /// - `Proprietary` with `ApiKeyEnv` or `BearerTokenEnv` auth: that single variable.
    /// - Every other source, including `Proprietary` with OAuth2 PKCE auth: empty.
    pub fn all_api_key_envs(&self) -> Vec<String> {
        match &self.source {
            ModelSource::RemoteApi {
                api_key_env,
                api_key_envs,
                ..
            } => {
                let mut all: Vec<String> = Vec::with_capacity(1 + api_key_envs.len());
                for env in std::iter::once(api_key_env).chain(api_key_envs) {
                    // `api_key_envs` may repeat the primary name (or itself);
                    // report each variable once so callers don't probe it twice.
                    if !all.contains(env) {
                        all.push(env.clone());
                    }
                }
                all
            }
            ModelSource::Proprietary {
                auth:
                    ProprietaryAuth::ApiKeyEnv { env_var }
                    | ProprietaryAuth::BearerTokenEnv { env_var },
                ..
            } => vec![env_var.clone()],
            // Spelled out instead of `_` so a newly added source or auth variant
            // is a compile error here rather than silently "no credentials".
            ModelSource::Proprietary {
                auth: ProprietaryAuth::OAuth2Pkce { .. },
                ..
            }
            | ModelSource::Local { .. }
            | ModelSource::Ollama { .. }
            | ModelSource::Mlx { .. }
            | ModelSource::VllmMlx { .. }
            | ModelSource::AppleFoundationModels { .. }
            | ModelSource::Delegated { .. } => Vec::new(),
        }
    }

    /// Returns `cost.size_mb`, or `0` when it is not set.
    pub fn size_mb(&self) -> u64 {
        self.cost.size_mb.unwrap_or(0)
    }

    /// Returns `cost.ram_mb`, falling back to [`ModelSchema::size_mb`] when unset.
    pub fn ram_mb(&self) -> u64 {
        self.cost.ram_mb.unwrap_or_else(|| self.size_mb())
    }

    /// Converts `cost.output_per_mtok` (price per million output tokens) to a
    /// price per thousand output tokens; returns `0.0` when no price is set.
    pub fn cost_per_1k_output(&self) -> f64 {
        self.cost
            .output_per_mtok
            .map_or(0.0, |per_mtok| per_mtok / 1000.0)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A GGUF-backed local model with size/RAM data but no pricing.
    fn sample_local() -> ModelSchema {
        ModelSchema {
            id: "qwen/qwen3-4b:q4_k_m".into(),
            name: "Qwen3-4B".into(),
            provider: "qwen".into(),
            family: "qwen3".into(),
            version: "1.0".into(),
            capabilities: vec![ModelCapability::Generate, ModelCapability::Code],
            context_length: 32768,
            param_count: "4B".into(),
            quantization: Some("Q4_K_M".into()),
            performance: PerformanceEnvelope {
                tokens_per_second: Some(45.0),
                ..Default::default()
            },
            cost: CostModel {
                size_mb: Some(2500),
                ram_mb: Some(2500),
                ..Default::default()
            },
            source: ModelSource::Local {
                hf_repo: "Qwen/Qwen3-4B-GGUF".into(),
                hf_filename: "Qwen3-4B-Q4_K_M.gguf".into(),
                tokenizer_repo: "Qwen/Qwen3-4B".into(),
            },
            tags: vec!["code".into(), "fast".into()],
            supported_params: vec![],
            public_benchmarks: vec![],
            available: false,
        }
    }

    /// A hosted API model with pricing and latency data but no size data.
    fn sample_remote() -> ModelSchema {
        ModelSchema {
            id: "anthropic/claude-sonnet-4-6:latest".into(),
            name: "Claude Sonnet 4.6".into(),
            provider: "anthropic".into(),
            family: "claude-4".into(),
            version: "latest".into(),
            capabilities: vec![
                ModelCapability::Generate,
                ModelCapability::Code,
                ModelCapability::Reasoning,
                ModelCapability::ToolUse,
                ModelCapability::Vision,
            ],
            context_length: 200000,
            param_count: String::new(),
            quantization: None,
            performance: PerformanceEnvelope {
                latency_p50_ms: Some(2000),
                latency_p99_ms: Some(8000),
                tokens_per_second: Some(80.0),
            },
            cost: CostModel {
                input_per_mtok: Some(3.0),
                output_per_mtok: Some(15.0),
                ..Default::default()
            },
            source: ModelSource::RemoteApi {
                endpoint: "https://api.anthropic.com/v1/messages".into(),
                api_key_env: "ANTHROPIC_API_KEY".into(),
                api_key_envs: vec![],
                api_version: Some("2023-06-01".into()),
                protocol: ApiProtocol::Anthropic,
            },
            tags: vec!["reasoning".into(), "tool_use".into()],
            supported_params: vec![],
            public_benchmarks: vec![],
            available: false,
        }
    }

    /// `sample_remote()` with its source swapped out, for source-specific tests.
    fn with_source(source: ModelSource) -> ModelSchema {
        ModelSchema {
            source,
            ..sample_remote()
        }
    }

    #[test]
    fn capabilities() {
        let m = sample_local();
        assert!(m.has_capability(ModelCapability::Code));
        assert!(!m.has_capability(ModelCapability::Vision));
    }

    #[test]
    fn local_vs_remote() {
        assert!(sample_local().is_local());
        assert!(!sample_local().is_remote());
        assert!(sample_remote().is_remote());
        assert!(!sample_remote().is_local());
    }

    #[test]
    fn source_predicates() {
        let mlx = with_source(ModelSource::Mlx {
            hf_repo: "mlx-community/Qwen3-4B-4bit".into(),
            hf_weight_file: None,
        });
        assert!(mlx.is_local() && mlx.is_mlx() && !mlx.is_remote());

        let vllm = with_source(ModelSource::VllmMlx {
            endpoint: "http://localhost:8000".into(),
            model_name: "qwen3".into(),
        });
        assert!(vllm.is_local() && vllm.is_vllm_mlx() && !vllm.is_mlx());

        let afm = with_source(ModelSource::AppleFoundationModels { use_case: None });
        assert!(afm.is_local() && afm.is_foundation_models());

        let delegated = with_source(ModelSource::Delegated { hint: None });
        assert!(delegated.is_delegated());
        assert!(!delegated.is_local());
        assert!(!delegated.is_remote());
    }

    #[test]
    fn api_key_envs() {
        assert_eq!(sample_remote().all_api_key_envs(), vec!["ANTHROPIC_API_KEY"]);
        assert!(sample_local().all_api_key_envs().is_empty());

        let multi = with_source(ModelSource::RemoteApi {
            endpoint: "https://api.example.com/v1/chat/completions".into(),
            api_key_env: "PRIMARY_KEY".into(),
            api_key_envs: vec!["SECONDARY_KEY".into(), "TERTIARY_KEY".into()],
            api_version: None,
            protocol: ApiProtocol::OpenAiCompat,
        });
        assert_eq!(
            multi.all_api_key_envs(),
            vec!["PRIMARY_KEY", "SECONDARY_KEY", "TERTIARY_KEY"]
        );

        let bearer = with_source(ModelSource::Proprietary {
            provider: "acme".into(),
            endpoint: "https://llm.acme.example".into(),
            auth: ProprietaryAuth::BearerTokenEnv {
                env_var: "ACME_TOKEN".into(),
            },
            protocol: ProprietaryProtocol::default(),
        });
        assert!(bearer.is_remote());
        assert_eq!(bearer.all_api_key_envs(), vec!["ACME_TOKEN"]);

        let oauth = with_source(ModelSource::Proprietary {
            provider: "acme".into(),
            endpoint: "https://llm.acme.example".into(),
            auth: ProprietaryAuth::OAuth2Pkce {
                authority: "https://login.acme.example".into(),
                client_id: "client".into(),
                scopes: vec!["llm.read".into()],
            },
            protocol: ProprietaryProtocol::default(),
        });
        assert!(oauth.all_api_key_envs().is_empty());
    }

    #[test]
    fn cost() {
        let local = sample_local();
        assert_eq!(local.cost_per_1k_output(), 0.0);
        // 15.0 per million output tokens is 0.015 per thousand.
        let remote = sample_remote();
        assert!((remote.cost_per_1k_output() - 0.015).abs() < 1e-12);
    }

    #[test]
    fn size_and_ram() {
        let local = sample_local();
        assert_eq!(local.size_mb(), 2500);
        assert_eq!(local.ram_mb(), 2500);

        let mut no_ram = sample_local();
        no_ram.cost = CostModel {
            size_mb: Some(1200),
            ram_mb: None,
            ..Default::default()
        };
        assert_eq!(no_ram.ram_mb(), 1200);

        let remote = sample_remote();
        assert_eq!(remote.size_mb(), 0);
        assert_eq!(remote.ram_mb(), 0);
    }

    #[test]
    fn serde_roundtrip() {
        let local = sample_local();
        let json = serde_json::to_string(&local).unwrap();
        let parsed: ModelSchema = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.id, local.id);
        assert_eq!(parsed.capabilities, local.capabilities);

        let remote = sample_remote();
        let json = serde_json::to_string(&remote).unwrap();
        let parsed: ModelSchema = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.id, remote.id);
        assert!(!parsed.available);
    }

    #[test]
    fn deserialize_applies_defaults() {
        let json = r#"{
            "id": "meta/llama3:8b",
            "name": "Llama 3 8B",
            "provider": "meta",
            "family": "llama3",
            "capabilities": ["generate", "tool_use"],
            "context_length": 8192,
            "source": {"type": "ollama", "model_tag": "llama3:8b"}
        }"#;
        let parsed: ModelSchema = serde_json::from_str(json).unwrap();
        assert_eq!(parsed.version, "");
        assert_eq!(parsed.param_count, "");
        assert!(parsed.quantization.is_none());
        assert!(parsed.tags.is_empty());
        assert!(parsed.supported_params.is_empty());
        assert!(parsed.public_benchmarks.is_empty());
        assert!(!parsed.available);
        assert_eq!(parsed.ram_mb(), 0);
        assert!(parsed.has_capability(ModelCapability::ToolUse));
        match &parsed.source {
            ModelSource::Ollama { model_tag, host } => {
                assert_eq!(model_tag, "llama3:8b");
                assert_eq!(host, "http://localhost:11434");
            }
            other => panic!("expected an ollama source, got {:?}", other),
        }
    }

    #[test]
    fn proprietary_protocol_defaults() {
        let from_empty: ProprietaryProtocol = serde_json::from_str("{}").unwrap();
        let built = ProprietaryProtocol::default();
        for protocol in [&from_empty, &built] {
            assert_eq!(protocol.chat_path, "/chat");
            assert_eq!(protocol.content_type, "application/json");
            assert!(!protocol.streaming);
            assert!(protocol.extra_headers.is_empty());
        }
    }
}