use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::env;
use std::fs;
use std::path::{Path, PathBuf};
use std::time::Duration;

use crate::fs::atomic_write_json;
10
/// Maximum age, in whole days (compared via `num_days()`), of the on-disk model
/// registry before a refresh from the network is attempted.
const CACHE_TTL_DAYS: i64 = 14;
12
/// URL of the LiteLLM model price / context-window JSON document.
///
/// Overridable through the `AICO_LITELLM_URL` environment variable; a missing or
/// non-UTF-8 value falls back to the upstream GitHub raw file.
fn get_litellm_url() -> String {
    const DEFAULT_URL: &str =
        "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json";

    match env::var("AICO_LITELLM_URL") {
        Ok(url) => url,
        Err(_) => DEFAULT_URL.to_owned(),
    }
}
19
/// URL of the OpenRouter model-list endpoint.
///
/// Overridable through the `AICO_OPENROUTER_URL` environment variable; a missing
/// or non-UTF-8 value falls back to the public API endpoint.
fn get_openrouter_url() -> String {
    let Ok(url) = env::var("AICO_OPENROUTER_URL") else {
        return String::from("https://openrouter.ai/api/v1/models");
    };
    url
}
24
25#[derive(Debug, Serialize, Deserialize)]
26struct ModelRegistry {
27 last_fetched: String,
28 models: HashMap<String, ModelInfo>,
29}
30
31#[derive(Deserialize)]
32struct OpenRouterPricing {
33 prompt: String,
34 completion: String,
35}
36
37#[derive(Deserialize)]
38struct OpenRouterItem {
39 id: String,
40 context_length: u32,
41 pricing: OpenRouterPricing,
42}
43
44#[derive(Deserialize)]
45struct OpenRouterResponse {
46 data: Vec<OpenRouterItem>,
47}
48
49#[derive(Debug, Clone, Serialize, Deserialize)]
50pub struct ModelInfo {
51 pub max_input_tokens: Option<u32>,
52 pub input_cost_per_token: Option<f64>,
53 pub output_cost_per_token: Option<f64>,
54}
55
/// Location of the on-disk model registry cache.
///
/// Resolution order:
/// 1. `$AICO_CACHE_DIR/models.json` when `AICO_CACHE_DIR` is set and non-empty.
/// 2. `$XDG_CACHE_HOME/aico/models.json` when `XDG_CACHE_HOME` is an absolute path.
/// 3. `$HOME/.cache/aico/models.json` when `HOME` is set and non-empty.
/// 4. `./aico/models.json` as a last resort.
fn get_cache_path() -> PathBuf {
    // `var_os` rather than `var`: a cache directory is a path, and paths need
    // not be valid UTF-8.
    if let Some(custom) = env::var_os("AICO_CACHE_DIR").filter(|v| !v.is_empty()) {
        return PathBuf::from(custom).join("models.json");
    }

    // The XDG Base Directory spec treats an empty or relative XDG_CACHE_HOME as
    // invalid; ignoring it avoids silently caching relative to the cwd.
    let xdg = env::var_os("XDG_CACHE_HOME")
        .map(PathBuf::from)
        .filter(|p| p.is_absolute());
    let home_cache = || {
        env::var_os("HOME")
            .filter(|v| !v.is_empty())
            .map(|h| PathBuf::from(h).join(".cache"))
    };

    xdg.or_else(home_cache)
        .unwrap_or_else(|| PathBuf::from("."))
        .join("aico")
        .join("models.json")
}
68
69static REGISTRY_CACHE: std::sync::OnceLock<ModelRegistry> = std::sync::OnceLock::new();
70
71pub async fn get_model_info(model_id: &str) -> Option<ModelInfo> {
72 if let Some(registry) = REGISTRY_CACHE.get() {
73 return get_info_from_registry(model_id, registry);
74 }
75
76 let path = get_cache_path();
77 if let Some(registry) = ensure_cache(&path).await {
78 let _ = REGISTRY_CACHE.set(registry);
79 }
80
81 if let Some(registry) = REGISTRY_CACHE.get() {
82 return get_info_from_registry(model_id, registry);
83 }
84 None
85}
86
87async fn ensure_cache(path: &PathBuf) -> Option<ModelRegistry> {
88 let mut should_fetch = false;
89 let existing: Option<ModelRegistry> = if path.exists() {
90 fs::read_to_string(path).ok().and_then(|c| {
91 let reg: Option<ModelRegistry> = serde_json::from_str(&c).ok();
92 if let Some(ref r) = reg
93 && let Ok(dt) = DateTime::parse_from_rfc3339(&r.last_fetched)
94 && (Utc::now() - dt.with_timezone(&Utc)).num_days() < CACHE_TTL_DAYS
95 {
96 return reg;
97 }
98 should_fetch = true;
99 reg
100 })
101 } else {
102 should_fetch = true;
103 None
104 };
105
106 if should_fetch {
107 let _ = update_registry(path.clone()).await;
108 fs::read_to_string(path)
110 .ok()
111 .and_then(|c| serde_json::from_str(&c).ok())
112 } else {
113 existing
114 }
115}
116
117async fn update_registry(path: PathBuf) -> Result<(), Box<dyn std::error::Error>> {
118 crate::utils::setup_crypto_provider();
119
120 let client = reqwest::Client::builder()
121 .timeout(Duration::from_secs(3))
122 .build()?;
123 let mut all_models: HashMap<String, ModelInfo> = HashMap::new();
124
125 if let Ok(resp) = client.get(get_litellm_url()).send().await
126 && let Ok(lite) = resp.json::<HashMap<String, ModelInfo>>().await
127 {
128 all_models.extend(lite);
129 }
130 if let Ok(resp) = client.get(get_openrouter_url()).send().await
131 && let Ok(or) = resp.json::<OpenRouterResponse>().await
132 {
133 for item in or.data {
134 all_models.insert(
135 item.id,
136 ModelInfo {
137 max_input_tokens: Some(item.context_length),
138 input_cost_per_token: item.pricing.prompt.parse().ok(),
139 output_cost_per_token: item.pricing.completion.parse().ok(),
140 },
141 );
142 }
143 }
144
145 if all_models.is_empty() {
146 return Ok(());
147 }
148
149 let registry = ModelRegistry {
150 last_fetched: Utc::now().to_rfc3339(),
151 models: all_models,
152 };
153
154 if let Some(parent) = path.parent() {
155 fs::create_dir_all(parent)?;
156 }
157 atomic_write_json(&path, ®istry)?;
158
159 Ok(())
160}
161
162fn get_info_from_registry(model_id: &str, registry: &ModelRegistry) -> Option<ModelInfo> {
163 let base_model = model_id.split('+').next().unwrap_or(model_id);
165
166 let check_key = |key: &str| -> Option<ModelInfo> {
168 if let Some(info) = registry.models.get(key) {
170 return Some(info.clone());
171 }
172 if let Some((simple, _)) = key.split_once(':')
174 && let Some(info) = registry.models.get(simple)
175 {
176 return Some(info.clone());
177 }
178 None
179 };
180
181 if let Some(info) = check_key(base_model) {
183 return Some(info);
184 }
185
186 if let Some((_, stripped)) = base_model.split_once('/') {
188 if let Some(info) = check_key(stripped) {
189 return Some(info);
190 }
191
192 if let Some((_, bare)) = stripped.split_once('/')
194 && let Some(info) = check_key(bare)
195 {
196 return Some(info);
197 }
198 }
199
200 None
201}
202
203pub fn get_model_info_at(model_id: &str, path: PathBuf) -> Option<ModelInfo> {
204 if !path.exists() {
205 return None;
206 }
207
208 let content = fs::read_to_string(path).ok()?;
209 let registry: ModelRegistry = serde_json::from_str(&content).ok()?;
210 get_info_from_registry(model_id, ®istry)
211}