lc/models/
unified_cache.rs

1use crate::{
2    config::Config,
3    debug_log, // Import debug_log macro
4    model_metadata::{extract_models_from_provider, ModelMetadata},
5    provider::Provider,
6};
7use anyhow::Result;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::path::PathBuf;
11use std::sync::{Arc, RwLock};
12use std::time::{Duration, SystemTime, UNIX_EPOCH};
13use tokio::fs;
14
/// Provider model data persisted to disk and mirrored in the in-memory cache.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CachedProviderData {
    pub last_updated: u64,          // Unix timestamp (seconds) when the data was fetched
    pub raw_response: String,       // Raw JSON response from provider
    pub models: Vec<ModelMetadata>, // Extracted metadata
    // Cache the serialized JSON to avoid repeated serialization.
    // serde(skip) keeps the memoized string out of the on-disk format.
    #[serde(skip)]
    pub cached_json: Option<String>,
}
24
25impl CachedProviderData {
26    fn new(raw_response: String, models: Vec<ModelMetadata>) -> Self {
27        let now = std::time::SystemTime::now()
28            .duration_since(std::time::UNIX_EPOCH)
29            .unwrap_or(std::time::Duration::from_secs(0))
30            .as_secs();
31
32        Self {
33            last_updated: now,
34            raw_response,
35            models,
36            cached_json: None,
37        }
38    }
39
40    fn get_cached_json(&mut self) -> Result<&str> {
41        if self.cached_json.is_none() {
42            self.cached_json = Some(serde_json::to_string_pretty(self)?);
43        }
44        self.cached_json
45            .as_ref()
46            .map(|s| s.as_str())
47            .ok_or_else(|| anyhow::anyhow!("Failed to get cached JSON"))
48    }
49}
50
// In-memory cache entry with TTL: pairs the cached payload with its
// absolute expiry time so freshness checks are a single comparison.
#[derive(Debug, Clone)]
struct MemoryCacheEntry {
    data: CachedProviderData, // cached provider payload
    expires_at: u64,          // Unix timestamp (seconds) after which the entry is stale
}
57
58impl MemoryCacheEntry {
59    fn new(data: CachedProviderData, ttl_seconds: u64) -> Self {
60        let now = SystemTime::now()
61            .duration_since(UNIX_EPOCH)
62            .unwrap_or(Duration::from_secs(0))
63            .as_secs();
64
65        Self {
66            data,
67            expires_at: now + ttl_seconds,
68        }
69    }
70
71    fn is_expired(&self) -> bool {
72        let now = SystemTime::now()
73            .duration_since(UNIX_EPOCH)
74            .unwrap_or(Duration::from_secs(0))
75            .as_secs();
76
77        now >= self.expires_at
78    }
79}
80
// Global in-memory cache with efficient invalidation.
// Maps provider name -> TTL-stamped entry; an RwLock lets concurrent
// readers proceed without blocking each other.
lazy_static::lazy_static! {
    static ref MEMORY_CACHE: Arc<RwLock<HashMap<String, MemoryCacheEntry>>> =
        Arc::new(RwLock::new(HashMap::new()));
}
86
87pub struct UnifiedCache;
88
89impl UnifiedCache {
90    /// Cache TTL in seconds (24 hours)
91    const CACHE_TTL: u64 = 86400;
92
93    /// Get the models directory path (cross-platform)
94    pub fn models_dir() -> Result<PathBuf> {
95        let config_dir =
96            dirs::config_dir().ok_or_else(|| anyhow::anyhow!("Could not find config directory"))?;
97
98        Ok(config_dir.join("lc").join("models"))
99    }
100
101    /// Get the cache file path for a specific provider
102    pub fn provider_cache_path(provider: &str) -> Result<PathBuf> {
103        let models_dir = Self::models_dir()?;
104        Ok(models_dir.join(format!("{}.json", provider)))
105    }
106
107    /// Check in-memory cache first, then file cache
108    pub async fn is_cache_fresh(provider: &str) -> Result<bool> {
109        debug_log!("Checking cache freshness for provider '{}'", provider);
110
111        // Check in-memory cache first
112        if let Ok(cache) = MEMORY_CACHE.read() {
113            if let Some(entry) = cache.get(provider) {
114                if !entry.is_expired() {
115                    debug_log!("Found fresh in-memory cache for provider '{}'", provider);
116                    return Ok(true);
117                } else {
118                    debug_log!("In-memory cache expired for provider '{}'", provider);
119                }
120            }
121        }
122
123        // Fall back to file cache
124        let cache_path = Self::provider_cache_path(provider)?;
125
126        if !cache_path.exists() {
127            debug_log!("Cache file does not exist for provider '{}'", provider);
128            return Ok(false);
129        }
130
131        // Use async file I/O to avoid blocking
132        let content = fs::read_to_string(&cache_path).await?;
133        let cached_data: CachedProviderData = serde_json::from_str(&content)?;
134
135        let now = SystemTime::now()
136            .duration_since(UNIX_EPOCH)
137            .unwrap_or(Duration::from_secs(0))
138            .as_secs();
139
140        let age_seconds = now - cached_data.last_updated;
141        let is_fresh = age_seconds < Self::CACHE_TTL;
142
143        debug_log!(
144            "File cache for provider '{}' is {} seconds old, fresh: {}",
145            provider,
146            age_seconds,
147            is_fresh
148        );
149
150        // If file cache is fresh, populate in-memory cache
151        if is_fresh {
152            Self::populate_memory_cache(provider, cached_data);
153        }
154
155        Ok(is_fresh)
156    }
157
158    /// Populate in-memory cache with data
159    fn populate_memory_cache(provider: &str, data: CachedProviderData) {
160        if let Ok(mut cache) = MEMORY_CACHE.write() {
161            let entry = MemoryCacheEntry::new(data, Self::CACHE_TTL);
162            cache.insert(provider.to_string(), entry);
163            debug_log!("Populated in-memory cache for provider '{}'", provider);
164        }
165    }
166
167    /// Invalidate cache for a specific provider
168    pub fn invalidate_provider_cache(provider: &str) {
169        if let Ok(mut cache) = MEMORY_CACHE.write() {
170            cache.remove(provider);
171            debug_log!("Invalidated in-memory cache for provider '{}'", provider);
172        }
173    }
174
175    /// Clear all in-memory cache
176    #[allow(dead_code)]
177    pub fn clear_memory_cache() {
178        if let Ok(mut cache) = MEMORY_CACHE.write() {
179            cache.clear();
180            debug_log!("Cleared all in-memory cache");
181        }
182    }
183
184    /// Get cache age in human-readable format (e.g., "5 mins ago", "2 hrs ago")
185    pub async fn get_cache_age_display(provider: &str) -> Result<String> {
186        // Check in-memory cache first
187        if let Ok(cache) = MEMORY_CACHE.read() {
188            if let Some(entry) = cache.get(provider) {
189                let now = SystemTime::now()
190                    .duration_since(UNIX_EPOCH)
191                    .unwrap_or(Duration::from_secs(0))
192                    .as_secs();
193
194                let age_seconds = now - entry.data.last_updated;
195                return Ok(Self::format_age(age_seconds));
196            }
197        }
198
199        // Fall back to file cache
200        let cache_path = Self::provider_cache_path(provider)?;
201
202        if !cache_path.exists() {
203            return Ok("No cache".to_string());
204        }
205
206        let content = fs::read_to_string(&cache_path).await?;
207        let cached_data: CachedProviderData = serde_json::from_str(&content)?;
208
209        let now = SystemTime::now()
210            .duration_since(UNIX_EPOCH)
211            .unwrap_or(Duration::from_secs(0))
212            .as_secs();
213
214        let age_seconds = now - cached_data.last_updated;
215        Ok(Self::format_age(age_seconds))
216    }
217
218    /// Format age in seconds to human-readable string
219    fn format_age(age_seconds: u64) -> String {
220        if age_seconds < 60 {
221            format!("{} secs ago", age_seconds)
222        } else if age_seconds < 3600 {
223            let minutes = age_seconds / 60;
224            format!("{} min{} ago", minutes, if minutes == 1 { "" } else { "s" })
225        } else if age_seconds < 86400 {
226            let hours = age_seconds / 3600;
227            format!("{} hr{} ago", hours, if hours == 1 { "" } else { "s" })
228        } else {
229            let days = age_seconds / 86400;
230            format!("{} day{} ago", days, if days == 1 { "" } else { "s" })
231        }
232    }
233
234    /// Load cached models for a provider (async with in-memory cache)
235    pub async fn load_provider_models(provider: &str) -> Result<Vec<ModelMetadata>> {
236        debug_log!("Loading cached models for provider '{}'", provider);
237
238        // Check in-memory cache first
239        if let Ok(cache) = MEMORY_CACHE.read() {
240            if let Some(entry) = cache.get(provider) {
241                if !entry.is_expired() {
242                    debug_log!(
243                        "Loaded {} models from in-memory cache for provider '{}'",
244                        entry.data.models.len(),
245                        provider
246                    );
247                    return Ok(entry.data.models.clone());
248                } else {
249                    debug_log!("In-memory cache expired for provider '{}'", provider);
250                }
251            }
252        }
253
254        // Fall back to file cache
255        let cache_path = Self::provider_cache_path(provider)?;
256
257        if !cache_path.exists() {
258            debug_log!("No cache file found for provider '{}'", provider);
259            return Ok(Vec::new());
260        }
261
262        let content = fs::read_to_string(&cache_path).await?;
263        let cached_data: CachedProviderData = serde_json::from_str(&content)?;
264
265        debug_log!(
266            "Loaded {} models from file cache for provider '{}'",
267            cached_data.models.len(),
268            provider
269        );
270
271        // Populate in-memory cache if data is fresh
272        let now = SystemTime::now()
273            .duration_since(UNIX_EPOCH)
274            .unwrap_or(Duration::from_secs(0))
275            .as_secs();
276
277        if now - cached_data.last_updated < Self::CACHE_TTL {
278            Self::populate_memory_cache(provider, cached_data.clone());
279        }
280
281        Ok(cached_data.models)
282    }
283
284    /// Fetch and cache models for a provider
285    pub async fn fetch_and_cache_provider_models(
286        provider: &str,
287        force_refresh: bool,
288    ) -> Result<Vec<ModelMetadata>> {
289        debug_log!(
290            "Fetching models for provider '{}', force_refresh: {}",
291            provider,
292            force_refresh
293        );
294
295        // Check if we need to refresh
296        if !force_refresh && Self::is_cache_fresh(provider).await? {
297            debug_log!(
298                "Using cached models for provider '{}' (cache is fresh)",
299                provider
300            );
301            return Self::load_provider_models(provider).await;
302        }
303
304        debug_log!(
305            "Cache is stale or refresh forced, fetching fresh models for provider '{}'",
306            provider
307        );
308        println!("Fetching models from provider '{}'...", provider);
309
310        // Invalidate existing cache
311        Self::invalidate_provider_cache(provider);
312
313        // Load config and create client
314        let config = Config::load()?;
315        // Load provider with authentication (API key, headers, tokens) from centralized keys
316        let provider_config = config.get_provider_with_auth(provider)?;
317
318        debug_log!(
319            "Creating authenticated client for provider '{}' with endpoint: {}",
320            provider,
321            provider_config.endpoint
322        );
323
324        let mut config_mut = config.clone();
325        let client = crate::chat::create_authenticated_client(&mut config_mut, provider).await?;
326
327        // Save config if tokens were updated
328        if config_mut.get_cached_token(provider) != config.get_cached_token(provider) {
329            debug_log!(
330                "Tokens were updated for provider '{}', saving config",
331                provider
332            );
333            config_mut.save()?;
334        }
335
336        // Build the models URL
337        let models_url = format!(
338            "{}{}",
339            provider_config.endpoint, provider_config.models_path
340        );
341        debug_log!("Fetching models from URL: {}", models_url);
342
343        // Fetch raw response using the client's list_models method
344        debug_log!(
345            "Making API request to fetch models from provider '{}'",
346            provider
347        );
348
349        // Make the actual API request to fetch models
350        let models_list = client.list_models().await?;
351
352        // Create a JSON response that matches the OpenAI models format
353        // This is what we'll cache as the "raw response"
354        let models_json = serde_json::json!({
355            "object": "list",
356            "data": models_list.iter().map(|m| {
357                serde_json::json!({
358                    "id": m.id,
359                    "object": m.object,
360                    "providers": m.providers.iter().map(|p| {
361                        serde_json::json!({
362                            "provider": p.provider,
363                            "status": p.status,
364                            "supports_tools": p.supports_tools,
365                            "supports_structured_output": p.supports_structured_output
366                        })
367                    }).collect::<Vec<_>>()
368                })
369            }).collect::<Vec<_>>()
370        });
371
372        let raw_response = serde_json::to_string_pretty(&models_json)?;
373
374        debug_log!(
375            "Received raw response from provider '{}' ({} bytes)",
376            provider,
377            raw_response.len()
378        );
379
380        // Debug log the full response when -d flag is used
381        debug_log!(
382            "Full response from provider '{}': {}",
383            provider,
384            raw_response
385        );
386
387        // Extract metadata using the new generic approach
388        debug_log!(
389            "Extracting metadata from response for provider '{}'",
390            provider
391        );
392
393        // Create a Provider object for the extractor
394        let provider_obj = Provider {
395            provider: provider.to_string(),
396            status: "active".to_string(),
397            supports_tools: false,
398            supports_structured_output: false,
399        };
400
401        let models = extract_models_from_provider(&provider_obj, &raw_response)?;
402
403        debug_log!(
404            "Extracted {} models from provider '{}'",
405            models.len(),
406            provider
407        );
408
409        // Cache the data (both in-memory and file)
410        debug_log!("Saving cache data for provider '{}'", provider);
411        Self::save_provider_cache(provider, &raw_response, &models).await?;
412
413        Ok(models)
414    }
415
416    /// Save provider data to cache (async with in-memory caching)
417    async fn save_provider_cache(
418        provider: &str,
419        raw_response: &str,
420        models: &[ModelMetadata],
421    ) -> Result<()> {
422        let cache_path = Self::provider_cache_path(provider)?;
423
424        debug_log!(
425            "Saving cache for provider '{}' to: {}",
426            provider,
427            cache_path.display()
428        );
429
430        // Create cached data
431        let cached_data = CachedProviderData::new(raw_response.to_string(), models.to_vec());
432
433        // Update in-memory cache first (fastest access)
434        Self::populate_memory_cache(provider, cached_data.clone());
435
436        // Ensure cache directory exists
437        if let Some(parent) = cache_path.parent() {
438            debug_log!("Creating cache directory: {}", parent.display());
439            fs::create_dir_all(parent).await?;
440        }
441
442        // Use async file I/O to avoid blocking
443        let mut cached_data_mut = cached_data;
444        let content = cached_data_mut.get_cached_json()?;
445        debug_log!(
446            "Writing {} bytes to cache file for provider '{}'",
447            content.len(),
448            provider
449        );
450        fs::write(&cache_path, content).await?;
451
452        debug_log!(
453            "Successfully saved cache for provider '{}' with {} models",
454            provider,
455            models.len()
456        );
457
458        Ok(())
459    }
460
461    /// Load all cached models from all providers (async with in-memory cache)
462    pub async fn load_all_cached_models() -> Result<Vec<ModelMetadata>> {
463        let models_dir = Self::models_dir()?;
464        let mut all_models = Vec::new();
465
466        if !models_dir.exists() {
467            return Ok(all_models);
468        }
469
470        let mut entries = fs::read_dir(&models_dir).await?;
471
472        while let Some(entry) = entries.next_entry().await? {
473            let path = entry.path();
474
475            if let Some(extension) = path.extension() {
476                if extension == "json" {
477                    if let Some(provider_name) = path.file_stem().and_then(|s| s.to_str()) {
478                        match Self::load_provider_models(provider_name).await {
479                            Ok(mut models) => {
480                                all_models.append(&mut models);
481                            }
482                            Err(e) => {
483                                eprintln!(
484                                    "Warning: Failed to load cached models for {}: {}",
485                                    provider_name, e
486                                );
487                            }
488                        }
489                    }
490                }
491            }
492        }
493
494        // Sort by provider, then by model name
495        all_models.sort_by(|a, b| a.provider.cmp(&b.provider).then(a.id.cmp(&b.id)));
496
497        Ok(all_models)
498    }
499
500    /// Refresh all providers' caches
501    pub async fn refresh_all_providers() -> Result<()> {
502        let config = Config::load()?;
503        let mut successful_providers = 0;
504        let mut total_models = 0;
505
506        println!("Refreshing models cache for all providers...");
507
508        for (provider_name, _provider_config) in &config.providers {
509            // Skip providers that have neither API key nor custom headers (after loading centralized auth)
510            let pc_auth = match config.get_provider_with_auth(provider_name) {
511                Ok(cfg) => cfg,
512                Err(_) => continue,
513            };
514            if pc_auth.api_key.is_none() && pc_auth.headers.is_empty() {
515                continue;
516            }
517
518            match Self::fetch_and_cache_provider_models(provider_name, true).await {
519                Ok(models) => {
520                    let count = models.len();
521                    successful_providers += 1;
522                    total_models += count;
523                    println!("✓ {} ({} models)", provider_name, count);
524                }
525                Err(e) => {
526                    println!("✗ {} ({})", provider_name, e);
527                }
528            }
529        }
530
531        println!(
532            "\nCache updated: {} providers, {} total models",
533            successful_providers, total_models
534        );
535        Ok(())
536    }
537}