Skip to main content

aico/
model_registry.rs

1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3use std::env;
4use std::fs;
5use std::path::PathBuf;
6use std::time::Duration;
7use time::OffsetDateTime;
8
9use crate::fs::atomic_write_json;
10
/// Number of whole days a downloaded registry is treated as fresh (two weeks) before
/// `ensure_cache` tries to download it again.
const CACHE_TTL_DAYS: i64 = 2 * 7;
12
/// Returns the URL of LiteLLM's model price / context-window table.
///
/// The `AICO_LITELLM_URL` environment variable overrides the default. Any lookup error
/// (unset or not valid Unicode) falls back to the GitHub raw URL.
fn get_litellm_url() -> String {
    const DEFAULT_URL: &str =
        "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json";
    match env::var("AICO_LITELLM_URL") {
        Ok(url) => url,
        Err(_) => DEFAULT_URL.to_owned(),
    }
}
19
/// Returns the OpenRouter model-listing endpoint.
///
/// The `AICO_OPENROUTER_URL` environment variable overrides the default. Any lookup error
/// falls back to the public API URL.
fn get_openrouter_url() -> String {
    let Ok(url) = env::var("AICO_OPENROUTER_URL") else {
        return "https://openrouter.ai/api/v1/models".into();
    };
    url
}
24
25#[derive(Debug, Serialize, Deserialize)]
26struct ModelRegistry {
27    #[serde(with = "time::serde::rfc3339")]
28    last_fetched: OffsetDateTime,
29    models: HashMap<String, ModelInfo>,
30}
31
32#[derive(Deserialize)]
33struct OpenRouterPricing {
34    prompt: String,
35    completion: String,
36}
37
38#[derive(Deserialize)]
39struct OpenRouterItem {
40    id: String,
41    context_length: u32,
42    pricing: OpenRouterPricing,
43}
44
45#[derive(Deserialize)]
46struct OpenRouterResponse {
47    data: Vec<OpenRouterItem>,
48}
49
50#[derive(Debug, Clone, Serialize, Deserialize)]
51pub struct ModelInfo {
52    pub max_input_tokens: Option<u32>,
53    pub input_cost_per_token: Option<f64>,
54    pub output_cost_per_token: Option<f64>,
55}
56
57fn get_cache_path() -> PathBuf {
58    crate::utils::get_app_cache_dir().join("models.json")
59}
60
61static REGISTRY_CACHE: std::sync::OnceLock<ModelRegistry> = std::sync::OnceLock::new();
62
63pub async fn get_model_info(model_id: &str) -> Option<ModelInfo> {
64    if let Some(registry) = REGISTRY_CACHE.get() {
65        return get_info_from_registry(model_id, registry);
66    }
67
68    let path = get_cache_path();
69    if let Some(registry) = ensure_cache(&path).await {
70        let _ = REGISTRY_CACHE.set(registry);
71    }
72
73    if let Some(registry) = REGISTRY_CACHE.get() {
74        return get_info_from_registry(model_id, registry);
75    }
76    None
77}
78
79async fn ensure_cache(path: &PathBuf) -> Option<ModelRegistry> {
80    let mut should_fetch = false;
81    let existing: Option<ModelRegistry> = if path.exists() {
82        crate::fs::read_json::<ModelRegistry>(path)
83            .ok()
84            .inspect(|reg| {
85                if (OffsetDateTime::now_utc() - reg.last_fetched).whole_days() < CACHE_TTL_DAYS {
86                    return;
87                }
88                should_fetch = true;
89            })
90    } else {
91        should_fetch = true;
92        None
93    };
94
95    if should_fetch {
96        let _ = update_registry(path.clone()).await;
97        // Re-read after potential update
98        fs::read_to_string(path)
99            .ok()
100            .and_then(|c| serde_json::from_str(&c).ok())
101    } else {
102        existing
103    }
104}
105
106async fn update_registry(path: PathBuf) -> Result<(), Box<dyn std::error::Error>> {
107    crate::utils::setup_crypto_provider();
108
109    let client = reqwest::Client::builder()
110        .timeout(Duration::from_secs(3))
111        .build()?;
112    let mut all_models: HashMap<String, ModelInfo> = HashMap::new();
113
114    if let Ok(resp) = client.get(get_litellm_url()).send().await
115        && let Ok(lite) = resp.json::<HashMap<String, ModelInfo>>().await
116    {
117        all_models.extend(lite);
118    }
119    if let Ok(resp) = client.get(get_openrouter_url()).send().await
120        && let Ok(or) = resp.json::<OpenRouterResponse>().await
121    {
122        for item in or.data {
123            all_models.insert(
124                item.id,
125                ModelInfo {
126                    max_input_tokens: Some(item.context_length),
127                    input_cost_per_token: item.pricing.prompt.parse().ok(),
128                    output_cost_per_token: item.pricing.completion.parse().ok(),
129                },
130            );
131        }
132    }
133
134    if all_models.is_empty() {
135        return Ok(());
136    }
137
138    let registry = ModelRegistry {
139        last_fetched: OffsetDateTime::now_utc(),
140        models: all_models,
141    };
142
143    if let Some(parent) = path.parent() {
144        fs::create_dir_all(parent)?;
145    }
146    atomic_write_json(&path, &registry)?;
147
148    Ok(())
149}
150
151fn get_info_from_registry(model_id: &str, registry: &ModelRegistry) -> Option<ModelInfo> {
152    // Pre-process: Strip any flags (everything after first +)
153    let base_model = model_id.split('+').next().unwrap_or(model_id);
154
155    // Helper to check a specific key
156    let check_key = |key: &str| -> Option<ModelInfo> {
157        // 1. Exact match
158        if let Some(info) = registry.models.get(key) {
159            return Some(info.clone());
160        }
161        // 2. Fallback: Strip modifiers like :online (openai/gpt-4o:online -> openai/gpt-4o)
162        if let Some((simple, _)) = key.split_once(':')
163            && let Some(info) = registry.models.get(simple)
164        {
165            return Some(info.clone());
166        }
167        None
168    };
169
170    // 1. Try full base model (e.g. "openai/gpt-4o:online")
171    if let Some(info) = check_key(base_model) {
172        return Some(info);
173    }
174
175    // 2. Strip Provider Prefix (openai/gpt-4 -> gpt-4)
176    if let Some((_, stripped)) = base_model.split_once('/') {
177        if let Some(info) = check_key(stripped) {
178            return Some(info);
179        }
180
181        // 3. Strip Vendor (google/gemini -> gemini)
182        if let Some((_, bare)) = stripped.split_once('/')
183            && let Some(info) = check_key(bare)
184        {
185            return Some(info);
186        }
187    }
188
189    None
190}
191
192pub fn get_model_info_at(model_id: &str, path: PathBuf) -> Option<ModelInfo> {
193    if !path.exists() {
194        return None;
195    }
196
197    let registry: ModelRegistry = crate::fs::read_json(&path).ok()?;
198    get_info_from_registry(model_id, &registry)
199}