// lc/models_cache.rs

1use crate::{config::Config, provider::OpenAIClient};
2use anyhow::Result;
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::fs;
6use std::path::PathBuf;
7use std::time::{Duration, SystemTime, UNIX_EPOCH};
8
9#[derive(Debug, Serialize, Deserialize)]
10pub struct ModelsCache {
11    pub last_updated: u64,                    // Unix timestamp
12    pub models: HashMap<String, Vec<String>>, // provider -> models
13    // Cache the serialized JSON to avoid repeated serialization
14    #[serde(skip)]
15    pub cached_json: Option<String>,
16}
17
/// One model entry flattened out of `ModelsCache::models`, paired with its provider.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct CachedModel {
    /// Provider name (a key of `ModelsCache::models`).
    pub provider: String,
    /// Model ID as listed under that provider.
    pub model: String,
}

24impl ModelsCache {
25    pub fn new() -> Self {
26        Self {
27            last_updated: 0,
28            models: HashMap::new(),
29            cached_json: None,
30        }
31    }
32
33    fn invalidate_cache(&mut self) {
34        self.cached_json = None;
35    }
36
37    fn get_cached_json(&mut self) -> Result<&str> {
38        if self.cached_json.is_none() {
39            self.cached_json = Some(serde_json::to_string_pretty(self)?);
40        }
41        Ok(self.cached_json.as_ref().unwrap())
42    }
43
44    pub fn load() -> Result<Self> {
45        let cache_path = Self::cache_file_path()?;
46
47        if cache_path.exists() {
48            let content = fs::read_to_string(&cache_path)?;
49            let cache: ModelsCache = serde_json::from_str(&content)?;
50            Ok(cache)
51        } else {
52            Ok(Self::new())
53        }
54    }
55
56    pub fn save(&mut self) -> Result<()> {
57        let cache_path = Self::cache_file_path()?;
58
59        // Ensure cache directory exists
60        if let Some(parent) = cache_path.parent() {
61            fs::create_dir_all(parent)?;
62        }
63
64        // Use cached JSON if available to avoid re-serialization
65        let content = self.get_cached_json()?;
66        fs::write(&cache_path, content)?;
67        Ok(())
68    }
69
70    pub fn is_expired(&self) -> bool {
71        let now = SystemTime::now()
72            .duration_since(UNIX_EPOCH)
73            .unwrap_or(Duration::from_secs(0))
74            .as_secs();
75
76        // Cache expires after 24 hours (86400 seconds)
77        now - self.last_updated > 86400
78    }
79
80    pub fn needs_refresh(&self) -> bool {
81        self.models.is_empty() || self.is_expired()
82    }
83
84    pub async fn refresh(&mut self) -> Result<()> {
85        println!("Refreshing models cache...");
86
87        let config = Config::load()?;
88        let mut new_models = HashMap::new();
89        let mut successful_providers = 0;
90        let mut total_models = 0;
91
92        for (provider_name, provider_config) in &config.providers {
93            // Skip providers without API keys
94            if provider_config.api_key.is_none() {
95                continue;
96            }
97
98            print!("Fetching models from {}... ", provider_name);
99
100            let client = OpenAIClient::new_with_headers(
101                provider_config.endpoint.clone(),
102                provider_config.api_key.clone().unwrap(),
103                provider_config.models_path.clone(),
104                provider_config.chat_path.clone(),
105                provider_config.headers.clone(),
106            );
107
108            match client.list_models().await {
109                Ok(models) => {
110                    let model_names: Vec<String> = models.into_iter().map(|m| m.id).collect();
111                    let count = model_names.len();
112                    new_models.insert(provider_name.clone(), model_names);
113                    successful_providers += 1;
114                    total_models += count;
115                    println!("✓ ({} models)", count);
116                }
117                Err(e) => {
118                    println!("✗ ({})", e);
119                }
120            }
121        }
122
123        self.models = new_models;
124        self.last_updated = SystemTime::now()
125            .duration_since(UNIX_EPOCH)
126            .unwrap_or(Duration::from_secs(0))
127            .as_secs();
128
129        // Invalidate cached JSON since data changed
130        self.invalidate_cache();
131        self.save()?;
132
133        println!(
134            "\nCache updated: {} providers, {} total models",
135            successful_providers, total_models
136        );
137        Ok(())
138    }
139
140    pub fn get_all_models(&self) -> Vec<CachedModel> {
141        let mut all_models = Vec::new();
142
143        for (provider, models) in &self.models {
144            for model in models {
145                all_models.push(CachedModel {
146                    provider: provider.clone(),
147                    model: model.clone(),
148                });
149            }
150        }
151
152        // Sort by provider, then by model
153        all_models.sort_by(|a, b| a.provider.cmp(&b.provider).then(a.model.cmp(&b.model)));
154
155        all_models
156    }
157
158    fn cache_file_path() -> Result<PathBuf> {
159        let config_dir =
160            dirs::config_dir().ok_or_else(|| anyhow::anyhow!("Could not find config directory"))?;
161
162        Ok(config_dir.join("lc").join("models_cache.json"))
163    }
164}