use crate::{config::Config, provider::OpenAIClient};
use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fs;
use std::path::PathBuf;
use std::time::{Duration, SystemTime, UNIX_EPOCH};

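/// On-disk cache of model names per provider, plus the Unix timestamp of the last refresh.
/// The memoized JSON string in `cached_json` is skipped during (de)serialization.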
#[derive(Debug, Serialize, Deserialize)]
pub struct ModelsCache {
    pub last_updated: u64,
    pub models: HashMap<String, Vec<String>>,
    #[serde(skip)]
    pub cached_json: Option<String>,
}

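/// A single provider/model pair flattened out of the cache.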
#[derive(Debug)]
pub struct CachedModel {
    pub provider: String,
    pub model: String,
}

impl ModelsCache {
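    /// Creates an empty cache that has never been updated.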
    pub fn new() -> Self {
        Self {
            last_updated: 0,
            models: HashMap::new(),
            cached_json: None,
        }
    }

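    /// Drops the memoized JSON so the next `save` re-serializes the current state.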
    fn invalidate_cache(&mut self) {
        self.cached_json = None;
    }

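    /// Serializes the cache to pretty-printed JSON, memoizing the result until invalidated.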
    fn get_cached_json(&mut self) -> Result<&str> {
        if self.cached_json.is_none() {
            self.cached_json = Some(serde_json::to_string_pretty(self)?);
        }
        Ok(self.cached_json.as_ref().unwrap())
    }

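    /// Loads the cache from disk, or returns an empty cache if no cache file exists yet.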
    pub fn load() -> Result<Self> {
        let cache_path = Self::cache_file_path()?;

        if cache_path.exists() {
            let content = fs::read_to_string(&cache_path)?;
            let cache: ModelsCache = serde_json::from_str(&content)?;
            Ok(cache)
        } else {
            Ok(Self::new())
        }
    }

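    /// Writes the cache to disk as pretty-printed JSON, creating the parent directory if needed.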
    pub fn save(&mut self) -> Result<()> {
        let cache_path = Self::cache_file_path()?;

        if let Some(parent) = cache_path.parent() {
            fs::create_dir_all(parent)?;
        }

        let content = self.get_cached_json()?;
        fs::write(&cache_path, content)?;
        Ok(())
    }

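    /// Returns true if the last refresh happened more than 24 hours ago.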
    pub fn is_expired(&self) -> bool {
        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or(Duration::from_secs(0))
            .as_secs();

        // Saturating subtraction avoids an underflow panic if the stored
        // timestamp is somehow ahead of the current clock.
        now.saturating_sub(self.last_updated) > 86400
    }

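    /// Returns true if the cache has no models yet or has expired.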
    pub fn needs_refresh(&self) -> bool {
        self.models.is_empty() || self.is_expired()
    }

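    /// Fetches the model list from every configured provider that has an API key,
    /// replaces the in-memory cache, and persists the result to disk.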
    pub async fn refresh(&mut self) -> Result<()> {
        use std::io::Write;

        println!("Refreshing models cache...");

        let config = Config::load()?;
        let mut new_models = HashMap::new();
        let mut successful_providers = 0;
        let mut total_models = 0;

        for (provider_name, provider_config) in &config.providers {
            // Skip providers that have no API key configured.
            if provider_config.api_key.is_none() {
                continue;
            }

            print!("Fetching models from {}... ", provider_name);
            // Flush so the progress message shows up before the request completes;
            // `print!` without a newline otherwise sits in the stdout buffer.
            let _ = std::io::stdout().flush();

            let client = OpenAIClient::new_with_headers(
                provider_config.endpoint.clone(),
                provider_config.api_key.clone().unwrap(),
                provider_config.models_path.clone(),
                provider_config.chat_path.clone(),
                provider_config.headers.clone(),
            );

            match client.list_models().await {
                Ok(models) => {
                    let model_names: Vec<String> = models.into_iter().map(|m| m.id).collect();
                    let count = model_names.len();
                    new_models.insert(provider_name.clone(), model_names);
                    successful_providers += 1;
                    total_models += count;
                    println!("✓ ({} models)", count);
                }
                Err(e) => {
                    println!("✗ ({})", e);
                }
            }
        }

        self.models = new_models;
        self.last_updated = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or(Duration::from_secs(0))
            .as_secs();

        self.invalidate_cache();
        self.save()?;

        println!(
            "\nCache updated: {} providers, {} total models",
            successful_providers, total_models
        );
        Ok(())
    }

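    /// Returns every cached model as provider/model pairs, sorted by provider and then model name.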
    pub fn get_all_models(&self) -> Vec<CachedModel> {
        let mut all_models = Vec::new();

        for (provider, models) in &self.models {
            for model in models {
                all_models.push(CachedModel {
                    provider: provider.clone(),
                    model: model.clone(),
                });
            }
        }

        all_models.sort_by(|a, b| a.provider.cmp(&b.provider).then(a.model.cmp(&b.model)));

        all_models
    }

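    /// Path of the cache file inside the platform config directory, i.e. `<config_dir>/lc/models_cache.json`.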
    fn cache_file_path() -> Result<PathBuf> {
        let config_dir =
            dirs::config_dir().ok_or_else(|| anyhow::anyhow!("Could not find config directory"))?;

        Ok(config_dir.join("lc").join("models_cache.json"))
    }
}