aico/
model_registry.rs

1use chrono::{DateTime, Utc};
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::env;
5use std::fs;
6use std::path::PathBuf;
7use std::time::Duration;
8
9use crate::fs::atomic_write_json;
10
/// Maximum age, in whole days, of the on-disk registry before `ensure_cache` tries to refetch it.
const CACHE_TTL_DAYS: i64 = 14;
12
/// URL of LiteLLM's model price / context-window catalogue.
///
/// Can be overridden with the `AICO_LITELLM_URL` environment variable (e.g. for tests or mirrors).
fn get_litellm_url() -> String {
    const DEFAULT_URL: &str =
        "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json";
    match env::var("AICO_LITELLM_URL") {
        Ok(url) => url,
        Err(_) => DEFAULT_URL.to_owned(),
    }
}
19
/// URL of OpenRouter's model listing endpoint.
///
/// Can be overridden with the `AICO_OPENROUTER_URL` environment variable.
fn get_openrouter_url() -> String {
    env::var("AICO_OPENROUTER_URL")
        .ok()
        .unwrap_or_else(|| String::from("https://openrouter.ai/api/v1/models"))
}
24
/// On-disk cache format: the merged model catalogue together with the time it was fetched.
#[derive(Debug, Serialize, Deserialize)]
struct ModelRegistry {
    /// RFC 3339 timestamp of the fetch; its age is compared against `CACHE_TTL_DAYS`.
    last_fetched: String,
    /// Model metadata keyed by model id, merged from the LiteLLM and OpenRouter sources.
    models: HashMap<String, ModelInfo>,
}
30
/// Pricing block of an OpenRouter model entry.
///
/// Prices arrive as strings and are parsed to `f64` when converted into `ModelInfo`
/// (unparsable values become `None`).
#[derive(Deserialize)]
struct OpenRouterPricing {
    /// Mapped to `ModelInfo::input_cost_per_token`.
    prompt: String,
    /// Mapped to `ModelInfo::output_cost_per_token`.
    completion: String,
}
36
/// One model in OpenRouter's listing. Only the fields used here are declared; serde ignores
/// the rest of the object.
#[derive(Deserialize)]
struct OpenRouterItem {
    /// OpenRouter model id; used as the registry key.
    id: String,
    /// Mapped to `ModelInfo::max_input_tokens`.
    context_length: u32,
    pricing: OpenRouterPricing,
}
43
/// Top-level envelope of the OpenRouter models response: the models live under `data`.
#[derive(Deserialize)]
struct OpenRouterResponse {
    data: Vec<OpenRouterItem>,
}
48
/// Context-window and pricing metadata for a single model.
///
/// The field names match LiteLLM's catalogue so its entries deserialize directly; any field
/// missing from the source JSON becomes `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelInfo {
    /// Maximum number of input tokens, if known.
    pub max_input_tokens: Option<u32>,
    /// Cost per input token, if known (currency as reported by the source; presumably USD — confirm).
    pub input_cost_per_token: Option<f64>,
    /// Cost per output token, if known (same currency caveat as above).
    pub output_cost_per_token: Option<f64>,
}
55
/// Returns the location of the cached model registry (`models.json`).
///
/// Resolution order:
/// 1. `$AICO_CACHE_DIR/models.json` when `AICO_CACHE_DIR` is set and non-empty.
/// 2. `$XDG_CACHE_HOME/aico/models.json` when `XDG_CACHE_HOME` is an absolute path. Per the
///    XDG Base Directory spec, an empty or relative value is ignored.
/// 3. `$HOME/.cache/aico/models.json` when `HOME` is set and non-empty.
/// 4. `./aico/models.json` (relative to the current directory) as a last resort.
fn get_cache_path() -> PathBuf {
    // Empty variables count as unset: `PathBuf::from("")` would silently resolve to the CWD.
    // `var_os` keeps non-UTF-8 paths usable instead of discarding them.
    let non_empty_dir = |name: &str| {
        env::var_os(name)
            .filter(|value| !value.is_empty())
            .map(PathBuf::from)
    };

    if let Some(custom) = non_empty_dir("AICO_CACHE_DIR") {
        return custom.join("models.json");
    }

    let base = non_empty_dir("XDG_CACHE_HOME")
        .filter(|dir| dir.is_absolute())
        .or_else(|| non_empty_dir("HOME").map(|home| home.join(".cache")))
        .unwrap_or_else(|| PathBuf::from("."));
    base.join("aico").join("models.json")
}
68
/// Process-wide registry, filled at most once by `get_model_info`; once set it is never
/// refreshed for the lifetime of the process.
static REGISTRY_CACHE: std::sync::OnceLock<ModelRegistry> = std::sync::OnceLock::new();
70
71pub async fn get_model_info(model_id: &str) -> Option<ModelInfo> {
72    if let Some(registry) = REGISTRY_CACHE.get() {
73        return get_info_from_registry(model_id, registry);
74    }
75
76    let path = get_cache_path();
77    if let Some(registry) = ensure_cache(&path).await {
78        let _ = REGISTRY_CACHE.set(registry);
79    }
80
81    if let Some(registry) = REGISTRY_CACHE.get() {
82        return get_info_from_registry(model_id, registry);
83    }
84    None
85}
86
87async fn ensure_cache(path: &PathBuf) -> Option<ModelRegistry> {
88    let mut should_fetch = false;
89    let existing: Option<ModelRegistry> = if path.exists() {
90        fs::read_to_string(path).ok().and_then(|c| {
91            let reg: Option<ModelRegistry> = serde_json::from_str(&c).ok();
92            if let Some(ref r) = reg
93                && let Ok(dt) = DateTime::parse_from_rfc3339(&r.last_fetched)
94                && (Utc::now() - dt.with_timezone(&Utc)).num_days() < CACHE_TTL_DAYS
95            {
96                return reg;
97            }
98            should_fetch = true;
99            reg
100        })
101    } else {
102        should_fetch = true;
103        None
104    };
105
106    if should_fetch {
107        let _ = update_registry(path.clone()).await;
108        // Re-read after potential update
109        fs::read_to_string(path)
110            .ok()
111            .and_then(|c| serde_json::from_str(&c).ok())
112    } else {
113        existing
114    }
115}
116
117async fn update_registry(path: PathBuf) -> Result<(), Box<dyn std::error::Error>> {
118    crate::utils::setup_crypto_provider();
119
120    let client = reqwest::Client::builder()
121        .timeout(Duration::from_secs(3))
122        .build()?;
123    let mut all_models: HashMap<String, ModelInfo> = HashMap::new();
124
125    if let Ok(resp) = client.get(get_litellm_url()).send().await
126        && let Ok(lite) = resp.json::<HashMap<String, ModelInfo>>().await
127    {
128        all_models.extend(lite);
129    }
130    if let Ok(resp) = client.get(get_openrouter_url()).send().await
131        && let Ok(or) = resp.json::<OpenRouterResponse>().await
132    {
133        for item in or.data {
134            all_models.insert(
135                item.id,
136                ModelInfo {
137                    max_input_tokens: Some(item.context_length),
138                    input_cost_per_token: item.pricing.prompt.parse().ok(),
139                    output_cost_per_token: item.pricing.completion.parse().ok(),
140                },
141            );
142        }
143    }
144
145    if all_models.is_empty() {
146        return Ok(());
147    }
148
149    let registry = ModelRegistry {
150        last_fetched: Utc::now().to_rfc3339(),
151        models: all_models,
152    };
153
154    if let Some(parent) = path.parent() {
155        fs::create_dir_all(parent)?;
156    }
157    atomic_write_json(&path, &registry)?;
158
159    Ok(())
160}
161
162fn get_info_from_registry(model_id: &str, registry: &ModelRegistry) -> Option<ModelInfo> {
163    // Pre-process: Strip any flags (everything after first +)
164    let base_model = model_id.split('+').next().unwrap_or(model_id);
165
166    // Helper to check a specific key
167    let check_key = |key: &str| -> Option<ModelInfo> {
168        // 1. Exact match
169        if let Some(info) = registry.models.get(key) {
170            return Some(info.clone());
171        }
172        // 2. Fallback: Strip modifiers like :online (openai/gpt-4o:online -> openai/gpt-4o)
173        if let Some((simple, _)) = key.split_once(':')
174            && let Some(info) = registry.models.get(simple)
175        {
176            return Some(info.clone());
177        }
178        None
179    };
180
181    // 1. Try full base model (e.g. "openai/gpt-4o:online")
182    if let Some(info) = check_key(base_model) {
183        return Some(info);
184    }
185
186    // 2. Strip Provider Prefix (openai/gpt-4 -> gpt-4)
187    if let Some((_, stripped)) = base_model.split_once('/') {
188        if let Some(info) = check_key(stripped) {
189            return Some(info);
190        }
191
192        // 3. Strip Vendor (google/gemini -> gemini)
193        if let Some((_, bare)) = stripped.split_once('/')
194            && let Some(info) = check_key(bare)
195        {
196            return Some(info);
197        }
198    }
199
200    None
201}
202
203pub fn get_model_info_at(model_id: &str, path: PathBuf) -> Option<ModelInfo> {
204    if !path.exists() {
205        return None;
206    }
207
208    let content = fs::read_to_string(path).ok()?;
209    let registry: ModelRegistry = serde_json::from_str(&content).ok()?;
210    get_info_from_registry(model_id, &registry)
211}