Skip to main content

lc/models/
cache.rs

1use crate::{config::Config, provider::OpenAIClient};
2use anyhow::Result;
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::fs;
6use std::path::PathBuf;
7use std::time::{Duration, SystemTime, UNIX_EPOCH};
8
9#[derive(Debug, Serialize, Deserialize)]
10pub struct ModelsCache {
11    pub last_updated: u64,                    // Unix timestamp
12    pub models: HashMap<String, Vec<String>>, // provider -> models
13    // Cache the serialized JSON to avoid repeated serialization
14    #[serde(skip)]
15    pub cached_json: Option<String>,
16}
17
/// A single (provider, model) pair flattened out of a models cache.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct CachedModel {
    /// Name of the provider that reported the model.
    pub provider: String,
    /// Model identifier as reported by the provider.
    pub model: String,
}
23
24impl ModelsCache {
25    pub fn new() -> Self {
26        Self {
27            last_updated: 0,
28            models: HashMap::new(),
29            cached_json: None,
30        }
31    }
32
33    fn invalidate_cache(&mut self) {
34        self.cached_json = None;
35    }
36
37    fn get_cached_json(&mut self) -> Result<&str> {
38        if self.cached_json.is_none() {
39            self.cached_json = Some(serde_json::to_string_pretty(self)?);
40        }
41        Ok(self
42            .cached_json
43            .as_ref()
44            .ok_or_else(|| anyhow::anyhow!("Failed to get cached JSON - internal error"))?
45            .as_str())
46    }
47
48    pub fn load() -> Result<Self> {
49        let cache_path = Self::cache_file_path()?;
50
51        if cache_path.exists() {
52            let content = fs::read_to_string(&cache_path)?;
53            let cache: ModelsCache = serde_json::from_str(&content)?;
54            Ok(cache)
55        } else {
56            Ok(Self::new())
57        }
58    }
59
60    pub fn save(&mut self) -> Result<()> {
61        let cache_path = Self::cache_file_path()?;
62
63        // Ensure cache directory exists
64        if let Some(parent) = cache_path.parent() {
65            fs::create_dir_all(parent)?;
66        }
67
68        // Use cached JSON if available to avoid re-serialization
69        let content = self.get_cached_json()?;
70        fs::write(&cache_path, content)?;
71        Ok(())
72    }
73
74    pub fn is_expired(&self) -> bool {
75        let now = SystemTime::now()
76            .duration_since(UNIX_EPOCH)
77            .unwrap_or(Duration::from_secs(0))
78            .as_secs();
79
80        // Cache expires after 24 hours (86400 seconds)
81        now - self.last_updated > 86400
82    }
83
84    pub fn needs_refresh(&self) -> bool {
85        self.models.is_empty() || self.is_expired()
86    }
87
88    pub async fn refresh(&mut self) -> Result<()> {
89        println!("Refreshing models cache...");
90
91        let config = Config::load()?;
92        let mut new_models = HashMap::new();
93        let mut successful_providers = 0;
94        let mut total_models = 0;
95
96        for (provider_name, provider_config) in &config.providers {
97            // Skip providers without API keys
98            if provider_config.api_key.is_none() {
99                continue;
100            }
101
102            print!("Fetching models from {}... ", provider_name);
103
104            let api_key = provider_config.api_key.clone().ok_or_else(|| {
105                anyhow::anyhow!(
106                    "API key is required but not found for provider {}",
107                    provider_name
108                )
109            })?;
110
111            let client = OpenAIClient::new_with_headers(
112                provider_config.endpoint.clone(),
113                api_key,
114                provider_config.models_path.clone(),
115                provider_config.chat_path.clone(),
116                provider_config.headers.clone(),
117            );
118
119            match client.list_models().await {
120                Ok(models) => {
121                    let model_names: Vec<String> = models.into_iter().map(|m| m.id).collect();
122                    let count = model_names.len();
123                    new_models.insert(provider_name.clone(), model_names);
124                    successful_providers += 1;
125                    total_models += count;
126                    println!("✓ ({} models)", count);
127                }
128                Err(e) => {
129                    println!("✗ ({})", e);
130                }
131            }
132        }
133
134        self.models = new_models;
135        self.last_updated = SystemTime::now()
136            .duration_since(UNIX_EPOCH)
137            .unwrap_or(Duration::from_secs(0))
138            .as_secs();
139
140        // Invalidate cached JSON since data changed
141        self.invalidate_cache();
142        self.save()?;
143
144        println!(
145            "\nCache updated: {} providers, {} total models",
146            successful_providers, total_models
147        );
148        Ok(())
149    }
150
151    pub fn get_all_models(&self) -> Vec<CachedModel> {
152        let mut all_models = Vec::new();
153
154        for (provider, models) in &self.models {
155            for model in models {
156                all_models.push(CachedModel {
157                    provider: provider.clone(),
158                    model: model.clone(),
159                });
160            }
161        }
162
163        // Sort by provider, then by model
164        all_models.sort_by(|a, b| a.provider.cmp(&b.provider).then(a.model.cmp(&b.model)));
165
166        all_models
167    }
168
169    fn cache_file_path() -> Result<PathBuf> {
170        let config_dir =
171            dirs::config_dir().ok_or_else(|| anyhow::anyhow!("Could not find config directory"))?;
172
173        Ok(config_dir.join("lc").join("models_cache.json"))
174    }
175}