use crate::error::{Error, Result};
use crate::types::{Model, ModelCost, ModelLimit};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::Path;
/// Endpoint serving the full models.dev provider/model catalog as JSON.
pub const MODELS_DEV_URL: &str = "https://models.dev/api.json";
/// Cache file for the downloaded catalog, relative to the user's home directory.
pub const DEFAULT_CACHE_PATH: &str = ".stakpak/cache/models.json";
/// A model provider (e.g. an LLM vendor) as exposed to the rest of the app.
///
/// Produced from the raw models.dev payload by `convert_raw_providers`, or
/// read back from the local cache file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProviderInfo {
    /// Provider identifier — the key of this entry in the models.dev top-level map.
    pub id: String,
    /// Human-readable provider name.
    pub name: String,
    /// Environment variables that configure this provider (e.g. API keys);
    /// a provider is "configured" when any of these is set non-empty.
    #[serde(default)]
    pub env: Vec<String>,
    /// Provider API base URL, when published.
    #[serde(default)]
    pub api: Option<String>,
    /// Compatible models for this provider, keyed by model id.
    pub models: HashMap<String, Model>,
}
/// Wire-format model entry exactly as it appears in the models.dev JSON.
///
/// Converted to the public `Model` by `convert_raw_model`; entries failing
/// `is_model_compatible` are dropped during conversion.
#[derive(Debug, Clone, Deserialize)]
struct RawModel {
    id: String,
    name: String,
    /// Whether the model supports extended reasoning (false when absent).
    #[serde(default)]
    reasoning: bool,
    /// Whether the model supports tool/function calling (false when absent).
    #[serde(default)]
    tool_call: bool,
    /// Published pricing; `None` when models.dev lists no cost.
    #[serde(default)]
    cost: Option<RawCost>,
    /// Token limits; zeroed via `Default` when not provided.
    #[serde(default)]
    limit: RawLimit,
    /// Lifecycle status when present — "deprecated" and "alpha" are filtered out.
    #[serde(default)]
    status: Option<String>,
    #[serde(default)]
    release_date: Option<String>,
}
/// Wire-format pricing block from models.dev.
///
/// NOTE(review): units are not stated in this file — presumably USD per
/// million tokens as published by models.dev; confirm against consumers.
#[derive(Debug, Clone, Deserialize, Default)]
struct RawCost {
    /// Input (prompt) token price; 0.0 when absent.
    #[serde(default)]
    input: f64,
    /// Output (completion) token price; 0.0 when absent.
    #[serde(default)]
    output: f64,
    /// Prompt-cache read price, when the provider supports caching.
    #[serde(default)]
    cache_read: Option<f64>,
    /// Prompt-cache write price, when the provider supports caching.
    #[serde(default)]
    cache_write: Option<f64>,
}
/// Wire-format token limits; missing fields default to 0.
#[derive(Debug, Clone, Deserialize, Default)]
struct RawLimit {
    /// Maximum context window size in tokens (0 when unspecified).
    #[serde(default)]
    context: u64,
    /// Maximum output size in tokens (0 when unspecified).
    #[serde(default)]
    output: u64,
}
/// Wire-format provider entry from models.dev.
///
/// The provider id is the key in the enclosing map, so it is not repeated
/// here; `convert_raw_providers` injects it into the resulting `ProviderInfo`.
#[derive(Debug, Clone, Deserialize)]
struct RawProvider {
    name: String,
    /// Environment variables that configure this provider.
    #[serde(default)]
    env: Vec<String>,
    /// Provider API base URL, when published.
    #[serde(default)]
    api: Option<String>,
    models: HashMap<String, RawModel>,
}
/// On-disk cache wrapper: the converted provider map plus a fetch timestamp.
#[derive(Debug, Clone, Deserialize)]
struct CacheFile {
    /// When the catalog was fetched — unread here (presumably Unix seconds,
    /// kept for future freshness checks — NOTE(review): confirm with writer).
    #[allow(dead_code)]
    fetched_at: u64,
    providers: HashMap<String, ProviderInfo>,
}
/// Downloads the full provider catalog from models.dev and converts it into
/// the internal `ProviderInfo` representation.
///
/// # Errors
/// Returns `Error::NetworkError` if the HTTP client cannot be built, the
/// request fails, the server replies with a non-success status, or the body
/// cannot be decoded as the expected JSON shape.
pub async fn fetch_models_dev() -> Result<HashMap<String, ProviderInfo>> {
    // Cap the whole request at 10s so an unreachable endpoint cannot hang us.
    let http = reqwest::Client::builder()
        .timeout(std::time::Duration::from_secs(10))
        .build()
        .map_err(|e| Error::NetworkError(format!("Failed to create HTTP client: {}", e)))?;

    let resp = http
        .get(MODELS_DEV_URL)
        .header("User-Agent", "stakpak-cli")
        .send()
        .await
        .map_err(|e| Error::NetworkError(format!("Failed to fetch models.dev: {}", e)))?;

    let status = resp.status();
    if !status.is_success() {
        return Err(Error::NetworkError(format!(
            "models.dev returned status {}",
            status
        )));
    }

    let raw: HashMap<String, RawProvider> = resp
        .json()
        .await
        .map_err(|e| Error::NetworkError(format!("Failed to parse models.dev response: {}", e)))?;

    Ok(convert_raw_providers(raw))
}
pub fn parse_models_dev(json: &str) -> Result<HashMap<String, ProviderInfo>> {
let raw: HashMap<String, RawProvider> = serde_json::from_str(json)
.map_err(|e| Error::ConfigError(format!("Failed to parse models JSON: {}", e)))?;
Ok(convert_raw_providers(raw))
}
/// Loads the cached models for `provider_id` from the default cache location
/// (`DEFAULT_CACHE_PATH` under the user's home directory).
pub fn load_models_for_provider(provider_id: &str) -> Result<Vec<Model>> {
    // Falls back to a relative path when no home directory can be determined.
    let home = dirs::home_dir().unwrap_or_default();
    load_models_for_provider_from_path(provider_id, &home.join(DEFAULT_CACHE_PATH))
}
/// Loads models for a single provider from the cache file at `path`.
///
/// An unknown `provider_id` yields an empty vector rather than an error; only
/// failures to read or parse the cache file surface as `Err`.
pub fn load_models_for_provider_from_path(provider_id: &str, path: &Path) -> Result<Vec<Model>> {
    let providers = load_cache_file(path)?;
    match providers.get(provider_id) {
        Some(provider) => Ok(provider.models.values().cloned().collect()),
        None => Ok(Vec::new()),
    }
}
/// Returns every model belonging to a configured provider (one whose env var
/// is set), reading the default cache file under the home directory.
pub fn load_available_models() -> Result<Vec<Model>> {
    let home = dirs::home_dir().unwrap_or_default();
    let providers = load_cache_file(&home.join(DEFAULT_CACHE_PATH))?;
    Ok(get_available_models(&providers))
}
/// Keeps only the providers for which at least one declared environment
/// variable is set to a non-empty value.
pub fn filter_configured_providers(
    providers: &HashMap<String, ProviderInfo>,
) -> HashMap<String, ProviderInfo> {
    let mut configured = HashMap::new();
    for (id, provider) in providers {
        if provider.env.iter().any(|var| is_env_set(var)) {
            configured.insert(id.clone(), provider.clone());
        }
    }
    configured
}
/// Flattens the models of every configured provider into a single list.
pub fn get_available_models(providers: &HashMap<String, ProviderInfo>) -> Vec<Model> {
    let mut models = Vec::new();
    for provider in filter_configured_providers(providers).values() {
        models.extend(provider.models.values().cloned());
    }
    models
}
fn load_cache_file(path: &Path) -> Result<HashMap<String, ProviderInfo>> {
let content = std::fs::read_to_string(path).map_err(|e| {
Error::ConfigError(format!(
"Failed to read models cache at {}: {}",
path.display(),
e
))
})?;
if let Ok(cache) = serde_json::from_str::<CacheFile>(&content) {
return Ok(cache.providers);
}
parse_models_dev(&content)
}
/// Converts the wire-format provider map into `ProviderInfo` entries, keeping
/// only models that pass `is_model_compatible` and stamping each provider's
/// map key into its `id` field.
fn convert_raw_providers(raw: HashMap<String, RawProvider>) -> HashMap<String, ProviderInfo> {
    let mut out = HashMap::new();
    for (id, provider) in raw {
        let mut models = HashMap::new();
        for (model_id, model) in provider.models {
            if is_model_compatible(&model) {
                models.insert(model_id, convert_raw_model(model, &id));
            }
        }
        out.insert(
            id.clone(),
            ProviderInfo {
                id,
                name: provider.name,
                env: provider.env,
                api: provider.api,
                models,
            },
        );
    }
    out
}
/// A model is usable when it supports tool calls, is not marked "deprecated"
/// or "alpha", and is not an embedding model.
fn is_model_compatible(model: &RawModel) -> bool {
    if !model.tool_call {
        return false;
    }
    if matches!(model.status.as_deref(), Some("deprecated") | Some("alpha")) {
        return false;
    }
    !is_embedding_model(&model.id)
}
/// Case-insensitive check for embedding models by id substring ("embed").
fn is_embedding_model(model_id: &str) -> bool {
    model_id.to_lowercase().contains("embed")
}
/// Maps a wire-format model into the public `Model`, tagging it with the
/// owning provider's id.
fn convert_raw_model(raw: RawModel, provider_id: &str) -> Model {
    let cost = raw.cost.map(|c| ModelCost {
        input: c.input,
        output: c.output,
        cache_read: c.cache_read,
        cache_write: c.cache_write,
    });
    let limit = ModelLimit {
        context: raw.limit.context,
        output: raw.limit.output,
    };
    Model {
        id: raw.id,
        name: raw.name,
        provider: provider_id.to_string(),
        reasoning: raw.reasoning,
        cost,
        limit,
        release_date: raw.release_date,
    }
}
/// True when `var` exists in the environment with a non-empty value.
fn is_env_set(var: &str) -> bool {
    matches!(std::env::var(var), Ok(value) if !value.is_empty())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Happy path: a full provider entry parses into `ProviderInfo` with
    /// cost data carried over.
    #[test]
    fn test_parse_models_dev() {
        let json = r#"{
"anthropic": {
"name": "Anthropic",
"env": ["ANTHROPIC_API_KEY"],
"api": "https://api.anthropic.com/v1",
"models": {
"claude-sonnet-4": {
"id": "claude-sonnet-4",
"name": "Claude Sonnet 4",
"reasoning": true,
"tool_call": true,
"cost": { "input": 3.0, "output": 15.0, "cache_read": 0.3 },
"limit": { "context": 200000, "output": 64000 }
}
}
}
}"#;
        let providers = parse_models_dev(json).unwrap();
        assert!(providers.contains_key("anthropic"));
        let anthropic = &providers["anthropic"];
        assert_eq!(anthropic.name, "Anthropic");
        assert_eq!(anthropic.env, vec!["ANTHROPIC_API_KEY"]);
        let model = &anthropic.models["claude-sonnet-4"];
        assert_eq!(model.name, "Claude Sonnet 4");
        assert!(model.reasoning);
        assert_eq!(model.cost.as_ref().unwrap().input, 3.0);
    }

    /// Filtering: deprecated, alpha, tool-less, and embedding models are all
    /// dropped — only the plain tool-calling model survives.
    #[test]
    fn test_model_filtering() {
        let json = r#"{
"test": {
"name": "Test Provider",
"env": [],
"models": {
"good": { "id": "good", "name": "Good Model", "tool_call": true, "limit": {} },
"deprecated": { "id": "deprecated", "name": "Old", "tool_call": true, "status": "deprecated", "limit": {} },
"alpha": { "id": "alpha", "name": "Experimental", "tool_call": true, "status": "alpha", "limit": {} },
"no_tools": { "id": "no_tools", "name": "No Tools", "tool_call": false, "limit": {} },
"text-embedding-3": { "id": "text-embedding-3", "name": "Embedding", "tool_call": true, "limit": {} }
}
}
}"#;
        let providers = parse_models_dev(json).unwrap();
        let test = &providers["test"];
        assert_eq!(test.models.len(), 1);
        assert!(test.models.contains_key("good"));
    }

    /// The embedding check is substring-based and case-insensitive.
    #[test]
    fn test_embedding_model_detection() {
        assert!(is_embedding_model("text-embedding-3-large"));
        assert!(is_embedding_model("openai/text-embedding-ada-002"));
        assert!(is_embedding_model("gemini-embedding-001"));
        assert!(is_embedding_model("EMBED-something"));
        assert!(!is_embedding_model("gpt-4"));
        assert!(!is_embedding_model("claude-sonnet-4"));
    }
}