Skip to main content

lash_core/
model_info.rs

1use std::collections::HashMap;
2use std::path::PathBuf;
3use std::sync::{Arc, RwLock};
4use std::time::{Duration, SystemTime};
5
6use async_trait::async_trait;
7use reqwest as model_catalog_http;
8
/// Address of the models.dev catalog, served as one JSON document.
const MODELS_DEV_URL: &str = "https://models.dev/api.json";
/// One hour. Exported for callers; presumably meant as the `max_age`
/// argument of `CachedModelCatalog::refresh_if_stale` (confirm at call sites).
pub const DEFAULT_REFRESH_INTERVAL: Duration = Duration::from_secs(3_600);
11
12#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
13pub struct ModelInfo {
14    pub context_window: u64,
15    #[serde(default, skip_serializing_if = "Option::is_none")]
16    pub max_input_tokens: Option<u64>,
17    #[serde(default, skip_serializing_if = "Option::is_none")]
18    pub max_output_tokens: Option<u64>,
19}
20
21#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
22pub struct ResolvedModelSpec {
23    pub configured_model: String,
24    pub resolved_model: String,
25    pub catalog_model_id: String,
26    pub info: ModelInfo,
27}
28
29impl ResolvedModelSpec {
30    pub fn context_window(&self) -> u64 {
31        self.info.context_window
32    }
33}
34
35#[derive(Clone, Debug, Default, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
36pub struct ModelCatalog {
37    entries: HashMap<String, ModelInfo>,
38}
39
40impl ModelCatalog {
41    pub fn get(&self, model_id: &str) -> Option<&ModelInfo> {
42        self.entries.get(model_id)
43    }
44
45    pub fn into_entries(self) -> HashMap<String, ModelInfo> {
46        self.entries
47    }
48
49    pub fn is_empty(&self) -> bool {
50        self.entries.is_empty()
51    }
52
53    pub fn from_models_dev_json(raw: &str) -> Result<Self, String> {
54        let providers = serde_json::from_str::<serde_json::Value>(raw)
55            .map_err(|err| format!("failed to parse models catalog JSON: {err}"))?;
56        let mut entries = HashMap::new();
57        let Some(obj) = providers.as_object() else {
58            return Err("models catalog root is not an object".to_string());
59        };
60        for (provider, provider_info) in obj {
61            let Some(models) = provider_info.get("models").and_then(|m| m.as_object()) else {
62                continue;
63            };
64            for (model_id, info) in models {
65                let Some(context_window) = info
66                    .get("limit")
67                    .and_then(|l| l.get("context"))
68                    .and_then(|c| c.as_u64())
69                else {
70                    continue;
71                };
72                let max_input_tokens = info
73                    .get("limit")
74                    .and_then(|l| l.get("input"))
75                    .and_then(|c| c.as_u64());
76                let max_output_tokens = info
77                    .get("limit")
78                    .and_then(|l| l.get("output"))
79                    .and_then(|c| c.as_u64());
80                entries.insert(
81                    format!("{provider}/{model_id}"),
82                    ModelInfo {
83                        context_window,
84                        max_input_tokens,
85                        max_output_tokens,
86                    },
87                );
88            }
89        }
90        Ok(Self { entries })
91    }
92}
93
/// Catalog store that keeps the raw payload and its save time in memory.
///
/// Both fields sit behind `Arc`, so clones of a store share the same state.
#[derive(Clone, Debug, Default)]
pub struct MemoryModelCatalogStore {
    raw: Arc<RwLock<Option<String>>>,
    modified_at: Arc<RwLock<Option<SystemTime>>>,
}

impl MemoryModelCatalogStore {
    /// Creates a store pre-seeded with `raw`.
    ///
    /// The modification time starts out unset, even when `raw` is `Some`.
    pub fn new(raw: Option<String>) -> Self {
        let raw = Arc::new(RwLock::new(raw));
        Self {
            raw,
            modified_at: Arc::default(),
        }
    }
}
108
109impl ModelCatalogStore for MemoryModelCatalogStore {
110    fn load(&self) -> Result<Option<String>, String> {
111        self.raw
112            .read()
113            .map(|raw| raw.clone())
114            .map_err(|_| "model catalog memory store lock poisoned".to_string())
115    }
116
117    fn save(&self, raw: &str) -> Result<(), String> {
118        self.raw
119            .write()
120            .map_err(|_| "model catalog memory store lock poisoned".to_string())
121            .map(|mut slot| *slot = Some(raw.to_string()))?;
122        self.modified_at
123            .write()
124            .map_err(|_| "model catalog memory store lock poisoned".to_string())
125            .map(|mut slot| *slot = Some(SystemTime::now()))
126    }
127
128    fn modified_at(&self) -> Result<Option<SystemTime>, String> {
129        self.modified_at
130            .read()
131            .map(|value| *value)
132            .map_err(|_| "model catalog memory store lock poisoned".to_string())
133    }
134}
135
136#[async_trait]
137pub trait ModelCatalogSource: Send + Sync {
138    async fn fetch(&self) -> Result<String, String>;
139}
140
/// Persistence layer for the raw catalog document.
///
/// Every method reports failures as a human-readable message.
pub trait ModelCatalogStore: Sync + Send {
    /// Reads the stored document; `Ok(None)` means nothing is stored.
    fn load(&self) -> Result<Option<String>, String>;
    /// Stores `raw` as the current document.
    fn save(&self, raw: &str) -> Result<(), String>;
    /// Reports when the stored document last changed, if that is known.
    fn modified_at(&self) -> Result<Option<SystemTime>, String>;
}
146
/// Catalog store backed by a single file on disk.
#[derive(Clone, Debug)]
pub struct FileModelCatalogStore {
    path: PathBuf,
}

impl FileModelCatalogStore {
    /// Creates a store that reads from and writes to `path`.
    ///
    /// The file system is not touched until one of the store methods runs.
    pub fn new(path: impl Into<PathBuf>) -> Self {
        let path = path.into();
        Self { path }
    }

    /// Location of the cache file.
    pub fn path(&self) -> &PathBuf {
        let Self { path } = self;
        path
    }
}
161
162impl ModelCatalogStore for FileModelCatalogStore {
163    fn load(&self) -> Result<Option<String>, String> {
164        match std::fs::read_to_string(&self.path) {
165            Ok(raw) => Ok(Some(raw)),
166            Err(err) if err.kind() == std::io::ErrorKind::NotFound => Ok(None),
167            Err(err) => Err(format!("failed to read model catalog cache: {err}")),
168        }
169    }
170
171    fn save(&self, raw: &str) -> Result<(), String> {
172        if let Some(parent) = self.path.parent() {
173            std::fs::create_dir_all(parent)
174                .map_err(|err| format!("failed to create model catalog cache directory: {err}"))?;
175        }
176        std::fs::write(&self.path, raw)
177            .map_err(|err| format!("failed to write model catalog cache: {err}"))
178    }
179
180    fn modified_at(&self) -> Result<Option<SystemTime>, String> {
181        let metadata = match std::fs::metadata(&self.path) {
182            Ok(metadata) => metadata,
183            Err(err) if err.kind() == std::io::ErrorKind::NotFound => return Ok(None),
184            Err(err) => return Err(format!("failed to stat model catalog cache: {err}")),
185        };
186        metadata
187            .modified()
188            .map(Some)
189            .map_err(|err| format!("failed to read model catalog cache mtime: {err}"))
190    }
191}
192
193#[derive(Clone, Debug)]
194pub struct ModelsDevHttpSource {
195    url: String,
196    user_agent: String,
197    timeout: Duration,
198}
199
200impl ModelsDevHttpSource {
201    pub fn new(url: impl Into<String>, user_agent: impl Into<String>, timeout: Duration) -> Self {
202        Self {
203            url: url.into(),
204            user_agent: user_agent.into(),
205            timeout,
206        }
207    }
208
209    pub fn default_models_dev() -> Self {
210        Self::new(
211            MODELS_DEV_URL,
212            format!("lash/{}", crate::VERSION),
213            Duration::from_secs(10),
214        )
215    }
216}
217
218#[async_trait]
219impl ModelCatalogSource for ModelsDevHttpSource {
220    async fn fetch(&self) -> Result<String, String> {
221        let client = model_catalog_http::ClientBuilder::new()
222            .timeout(self.timeout)
223            .user_agent(self.user_agent.clone())
224            .build()
225            .map_err(|err| format!("failed to build model catalog client: {err}"))?;
226        let response = client
227            .get(&self.url)
228            .send()
229            .await
230            .map_err(|err| format!("failed to fetch model catalog: {err}"))?;
231        let response = response
232            .error_for_status()
233            .map_err(|err| format!("model catalog source returned an error: {err}"))?;
234        response
235            .text()
236            .await
237            .map_err(|err| format!("failed to read model catalog response: {err}"))
238    }
239}
240
241#[derive(Clone)]
242pub struct CachedModelCatalog {
243    catalog: Arc<RwLock<ModelCatalog>>,
244    store: Arc<dyn ModelCatalogStore>,
245    source: Option<Arc<dyn ModelCatalogSource>>,
246}
247
248impl CachedModelCatalog {
249    pub fn new(
250        store: Arc<dyn ModelCatalogStore>,
251        source: Option<Arc<dyn ModelCatalogSource>>,
252        bundled_snapshot: &'static str,
253    ) -> Result<Self, String> {
254        let catalog = if let Some(raw) = store.load()?
255            && let Ok(parsed) = ModelCatalog::from_models_dev_json(&raw)
256            && !parsed.entries.is_empty()
257        {
258            parsed
259        } else {
260            ModelCatalog::from_models_dev_json(bundled_snapshot).unwrap_or_default()
261        };
262        Ok(Self {
263            catalog: Arc::new(RwLock::new(catalog)),
264            store,
265            source,
266        })
267    }
268
269    pub fn models_dev(
270        store: Arc<dyn ModelCatalogStore>,
271        source: Option<Arc<dyn ModelCatalogSource>>,
272    ) -> Result<Self, String> {
273        Self::new(store, source, bundled_models_dev_snapshot())
274    }
275
276    pub fn snapshot(&self) -> ModelCatalog {
277        self.catalog
278            .read()
279            .map(|catalog| catalog.clone())
280            .unwrap_or_default()
281    }
282
283    pub fn get(&self, model_id: &str) -> Option<ModelInfo> {
284        self.catalog
285            .read()
286            .ok()
287            .and_then(|catalog| catalog.get(model_id).cloned())
288    }
289
290    pub async fn refresh_if_stale(&self, max_age: Duration) -> Result<bool, String> {
291        if self.cache_is_fresh(max_age)? {
292            return Ok(false);
293        }
294        let Some(source) = self.source.as_ref() else {
295            return Ok(false);
296        };
297        let raw = source.fetch().await?;
298        let parsed = ModelCatalog::from_models_dev_json(&raw)?;
299        self.store.save(&raw)?;
300        let mut guard = self
301            .catalog
302            .write()
303            .map_err(|_| "model catalog lock poisoned".to_string())?;
304        *guard = parsed;
305        Ok(true)
306    }
307
308    fn cache_is_fresh(&self, max_age: Duration) -> Result<bool, String> {
309        let Some(modified) = self.store.modified_at()? else {
310            return false_result();
311        };
312        let Ok(age) = SystemTime::now().duration_since(modified) else {
313            return false_result();
314        };
315        Ok(age <= max_age)
316    }
317}
318
/// Produces the successful "not fresh" answer used by the freshness check.
fn false_result() -> Result<bool, String> {
    let fresh = false;
    Ok(fresh)
}
322
323pub fn bundled_models_dev_snapshot() -> &'static str {
324    include_str!(concat!(env!("OUT_DIR"), "/models_snapshot.json"))
325}
326
#[cfg(test)]
mod tests {
    //! Unit tests for catalog parsing and for how `CachedModelCatalog::new`
    //! chooses between the stored document and the bundled snapshot.
    //!
    //! The production `MemoryModelCatalogStore` serves as the store, so no
    //! separate test double has to be kept in sync with the trait.

    use super::{CachedModelCatalog, MemoryModelCatalogStore, ModelCatalog};
    use std::sync::Arc;

    /// Minimal catalog with a single model whose context window is 42.
    const ONE_MODEL: &str =
        r#"{"anthropic":{"models":{"claude-opus-4-6":{"limit":{"context":42}}}}}"#;

    /// Minimal catalog with a single model whose context window is 7.
    const OTHER_MODEL: &str = r#"{"openai":{"models":{"gpt-4.1":{"limit":{"context":7}}}}}"#;

    /// Builds a cache over an in-memory store seeded with `stored`.
    fn cache_with(stored: Option<&str>, snapshot: &'static str) -> CachedModelCatalog {
        let store = Arc::new(MemoryModelCatalogStore::new(stored.map(str::to_string)));
        CachedModelCatalog::new(store, None, snapshot).expect("cached catalog")
    }

    #[test]
    fn parse_context_map_reads_provider_prefixed_limits() {
        let raw = r#"{
          "anthropic": {
            "models": {
              "claude-opus-4-6": { "limit": { "context": 123456, "output": 32000 } }
            }
          },
          "openai": {
            "models": {
              "gpt-4.1": { "limit": { "context": 1047576, "input": 900000, "output": 32768 } }
            }
          }
        }"#;
        let map = ModelCatalog::from_models_dev_json(raw).expect("parse context map");

        let anthropic = map.get("anthropic/claude-opus-4-6").expect("anthropic entry");
        assert_eq!(anthropic.context_window, 123456);
        assert_eq!(anthropic.max_input_tokens, None);
        assert_eq!(anthropic.max_output_tokens, Some(32000));

        let openai = map.get("openai/gpt-4.1").expect("openai entry");
        assert_eq!(openai.context_window, 1047576);
        assert_eq!(openai.max_input_tokens, Some(900000));
        assert_eq!(openai.max_output_tokens, Some(32768));
    }

    #[test]
    fn parse_skips_models_without_context_limit() {
        let raw = r#"{"p":{"models":{"no-context":{"limit":{"output":1}},"no-limit":{}}},"q":{}}"#;
        let map = ModelCatalog::from_models_dev_json(raw).expect("parse catalog");
        assert!(map.is_empty());
    }

    #[test]
    fn parse_rejects_invalid_documents() {
        assert!(ModelCatalog::from_models_dev_json("not json").is_err());
        assert!(ModelCatalog::from_models_dev_json("[]").is_err());
    }

    #[test]
    fn cached_catalog_uses_store_before_snapshot() {
        let cache = cache_with(Some(ONE_MODEL), "{}");
        assert_eq!(
            cache
                .get("anthropic/claude-opus-4-6")
                .map(|info| info.context_window),
            Some(42)
        );
    }

    #[test]
    fn cached_catalog_prefers_store_over_populated_snapshot() {
        let cache = cache_with(Some(ONE_MODEL), OTHER_MODEL);
        assert!(cache.get("anthropic/claude-opus-4-6").is_some());
        assert!(cache.get("openai/gpt-4.1").is_none());
    }

    #[test]
    fn cached_catalog_falls_back_to_snapshot_when_store_is_empty() {
        let cache = cache_with(None, ONE_MODEL);
        assert_eq!(
            cache
                .get("anthropic/claude-opus-4-6")
                .map(|info| info.context_window),
            Some(42)
        );
    }

    #[test]
    fn cached_catalog_falls_back_to_snapshot_when_store_is_unusable() {
        // Neither an unparseable document nor one without any models may
        // shadow the bundled snapshot.
        for stored in ["not json", "{}"] {
            let cache = cache_with(Some(stored), ONE_MODEL);
            assert_eq!(
                cache
                    .get("anthropic/claude-opus-4-6")
                    .map(|info| info.context_window),
                Some(42),
                "stored document: {stored}"
            );
        }
    }

    #[test]
    fn snapshot_returns_a_copy_of_the_current_catalog() {
        let cache = cache_with(Some(ONE_MODEL), "{}");
        let copy = cache.snapshot();
        assert!(!copy.is_empty());
        assert_eq!(
            copy.get("anthropic/claude-opus-4-6").cloned(),
            cache.get("anthropic/claude-opus-4-6")
        );
    }
}