1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3use std::env;
4use std::fs;
5use std::path::PathBuf;
6use std::time::Duration;
7use time::OffsetDateTime;
8
9use crate::fs::atomic_write_json;
10
/// Number of whole days a cached registry is considered fresh (two weeks).
/// Once the cache is this old or older, a refresh from the network is attempted.
const CACHE_TTL_DAYS: i64 = 2 * 7;
12
/// Returns the URL of the LiteLLM model price / context-window table.
///
/// The `AICO_LITELLM_URL` environment variable overrides the built-in default when it is
/// set to a valid Unicode string.
fn get_litellm_url() -> String {
    const DEFAULT_LITELLM_URL: &str =
        "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json";

    match env::var("AICO_LITELLM_URL") {
        Ok(url) => url,
        Err(_) => DEFAULT_LITELLM_URL.to_string(),
    }
}
19
/// Returns the OpenRouter model-listing endpoint.
///
/// The `AICO_OPENROUTER_URL` environment variable overrides the built-in default when it
/// is set to a valid Unicode string.
fn get_openrouter_url() -> String {
    const DEFAULT_OPENROUTER_URL: &str = "https://openrouter.ai/api/v1/models";

    env::var("AICO_OPENROUTER_URL")
        .ok()
        .unwrap_or_else(|| DEFAULT_OPENROUTER_URL.to_owned())
}
24
/// On-disk model registry: the merged model table plus the time it was fetched.
///
/// Written to the cache file by `update_registry` (via `atomic_write_json`) and read back
/// with `crate::fs::read_json`.
#[derive(Debug, Serialize, Deserialize)]
struct ModelRegistry {
    /// UTC time of the fetch that produced this snapshot, stored as an RFC 3339 string.
    /// Compared against `CACHE_TTL_DAYS` to decide whether a refresh is due.
    #[serde(with = "time::serde::rfc3339")]
    last_fetched: OffsetDateTime,
    /// Model metadata keyed by the model id used by the upstream source.
    models: HashMap<String, ModelInfo>,
}
31
/// Pricing block of an OpenRouter model entry.
///
/// Both prices are deserialized as strings and parsed into `f64` when converted to
/// [`ModelInfo`]; a value that fails to parse becomes `None`.
#[derive(Deserialize)]
struct OpenRouterPricing {
    /// Price of a prompt (input) token; becomes `ModelInfo::input_cost_per_token`.
    prompt: String,
    /// Price of a completion (output) token; becomes `ModelInfo::output_cost_per_token`.
    completion: String,
}
37
/// A single model entry from the OpenRouter model listing (only the fields used here).
#[derive(Deserialize)]
struct OpenRouterItem {
    /// Model id; used verbatim as the registry key.
    id: String,
    /// Context window size; becomes `ModelInfo::max_input_tokens`.
    context_length: u32,
    /// Per-token prompt/completion prices.
    pricing: OpenRouterPricing,
}
44
/// Top-level envelope of the OpenRouter model listing; only `data` is read and any other
/// fields are ignored by serde.
#[derive(Deserialize)]
struct OpenRouterResponse {
    /// The listed models.
    data: Vec<OpenRouterItem>,
}
49
/// Context-window and pricing metadata for a single model.
///
/// Every field is optional: an upstream source may omit a value, or supply one that cannot
/// be parsed, in which case the field is `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelInfo {
    /// Maximum input (context) size in tokens, as reported by the source.
    pub max_input_tokens: Option<u32>,
    /// Cost per input token, in whatever unit the source reports.
    pub input_cost_per_token: Option<f64>,
    /// Cost per output token, in whatever unit the source reports.
    pub output_cost_per_token: Option<f64>,
}
56
57fn get_cache_path() -> PathBuf {
58 crate::utils::get_app_cache_dir().join("models.json")
59}
60
61static REGISTRY_CACHE: std::sync::OnceLock<ModelRegistry> = std::sync::OnceLock::new();
62
63pub async fn get_model_info(model_id: &str) -> Option<ModelInfo> {
64 if let Some(registry) = REGISTRY_CACHE.get() {
65 return get_info_from_registry(model_id, registry);
66 }
67
68 let path = get_cache_path();
69 if let Some(registry) = ensure_cache(&path).await {
70 let _ = REGISTRY_CACHE.set(registry);
71 }
72
73 if let Some(registry) = REGISTRY_CACHE.get() {
74 return get_info_from_registry(model_id, registry);
75 }
76 None
77}
78
79async fn ensure_cache(path: &PathBuf) -> Option<ModelRegistry> {
80 let mut should_fetch = false;
81 let existing: Option<ModelRegistry> = if path.exists() {
82 crate::fs::read_json::<ModelRegistry>(path)
83 .ok()
84 .inspect(|reg| {
85 if (OffsetDateTime::now_utc() - reg.last_fetched).whole_days() < CACHE_TTL_DAYS {
86 return;
87 }
88 should_fetch = true;
89 })
90 } else {
91 should_fetch = true;
92 None
93 };
94
95 if should_fetch {
96 let _ = update_registry(path.clone()).await;
97 fs::read_to_string(path)
99 .ok()
100 .and_then(|c| serde_json::from_str(&c).ok())
101 } else {
102 existing
103 }
104}
105
106async fn update_registry(path: PathBuf) -> Result<(), Box<dyn std::error::Error>> {
107 crate::utils::setup_crypto_provider();
108
109 let client = reqwest::Client::builder()
110 .timeout(Duration::from_secs(3))
111 .build()?;
112 let mut all_models: HashMap<String, ModelInfo> = HashMap::new();
113
114 if let Ok(resp) = client.get(get_litellm_url()).send().await
115 && let Ok(lite) = resp.json::<HashMap<String, ModelInfo>>().await
116 {
117 all_models.extend(lite);
118 }
119 if let Ok(resp) = client.get(get_openrouter_url()).send().await
120 && let Ok(or) = resp.json::<OpenRouterResponse>().await
121 {
122 for item in or.data {
123 all_models.insert(
124 item.id,
125 ModelInfo {
126 max_input_tokens: Some(item.context_length),
127 input_cost_per_token: item.pricing.prompt.parse().ok(),
128 output_cost_per_token: item.pricing.completion.parse().ok(),
129 },
130 );
131 }
132 }
133
134 if all_models.is_empty() {
135 return Ok(());
136 }
137
138 let registry = ModelRegistry {
139 last_fetched: OffsetDateTime::now_utc(),
140 models: all_models,
141 };
142
143 if let Some(parent) = path.parent() {
144 fs::create_dir_all(parent)?;
145 }
146 atomic_write_json(&path, ®istry)?;
147
148 Ok(())
149}
150
151fn get_info_from_registry(model_id: &str, registry: &ModelRegistry) -> Option<ModelInfo> {
152 let base_model = model_id.split('+').next().unwrap_or(model_id);
154
155 let check_key = |key: &str| -> Option<ModelInfo> {
157 if let Some(info) = registry.models.get(key) {
159 return Some(info.clone());
160 }
161 if let Some((simple, _)) = key.split_once(':')
163 && let Some(info) = registry.models.get(simple)
164 {
165 return Some(info.clone());
166 }
167 None
168 };
169
170 if let Some(info) = check_key(base_model) {
172 return Some(info);
173 }
174
175 if let Some((_, stripped)) = base_model.split_once('/') {
177 if let Some(info) = check_key(stripped) {
178 return Some(info);
179 }
180
181 if let Some((_, bare)) = stripped.split_once('/')
183 && let Some(info) = check_key(bare)
184 {
185 return Some(info);
186 }
187 }
188
189 None
190}
191
192pub fn get_model_info_at(model_id: &str, path: PathBuf) -> Option<ModelInfo> {
193 if !path.exists() {
194 return None;
195 }
196
197 let registry: ModelRegistry = crate::fs::read_json(&path).ok()?;
198 get_info_from_registry(model_id, ®istry)
199}