1use crate::{config::Config, provider::OpenAIClient};
2use anyhow::Result;
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::fs;
6use std::path::PathBuf;
7use std::time::{Duration, SystemTime, UNIX_EPOCH};
8
9#[derive(Debug, Serialize, Deserialize)]
10pub struct ModelsCache {
11 pub last_updated: u64, pub models: HashMap<String, Vec<String>>, #[serde(skip)]
15 pub cached_json: Option<String>,
16}
17
/// A single provider/model pair, as flattened out of a `ModelsCache`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CachedModel {
    /// Name of the provider that reported the model.
    pub provider: String,
    /// Model ID as reported by the provider.
    pub model: String,
}
23
24impl ModelsCache {
25 pub fn new() -> Self {
26 Self {
27 last_updated: 0,
28 models: HashMap::new(),
29 cached_json: None,
30 }
31 }
32
33 fn invalidate_cache(&mut self) {
34 self.cached_json = None;
35 }
36
37 fn get_cached_json(&mut self) -> Result<&str> {
38 if self.cached_json.is_none() {
39 self.cached_json = Some(serde_json::to_string_pretty(self)?);
40 }
41 Ok(self
42 .cached_json
43 .as_ref()
44 .ok_or_else(|| anyhow::anyhow!("Failed to get cached JSON - internal error"))?
45 .as_str())
46 }
47
48 pub fn load() -> Result<Self> {
49 let cache_path = Self::cache_file_path()?;
50
51 if cache_path.exists() {
52 let content = fs::read_to_string(&cache_path)?;
53 let cache: ModelsCache = serde_json::from_str(&content)?;
54 Ok(cache)
55 } else {
56 Ok(Self::new())
57 }
58 }
59
60 pub fn save(&mut self) -> Result<()> {
61 let cache_path = Self::cache_file_path()?;
62
63 if let Some(parent) = cache_path.parent() {
65 fs::create_dir_all(parent)?;
66 }
67
68 let content = self.get_cached_json()?;
70 fs::write(&cache_path, content)?;
71 Ok(())
72 }
73
74 pub fn is_expired(&self) -> bool {
75 let now = SystemTime::now()
76 .duration_since(UNIX_EPOCH)
77 .unwrap_or(Duration::from_secs(0))
78 .as_secs();
79
80 now - self.last_updated > 86400
82 }
83
84 pub fn needs_refresh(&self) -> bool {
85 self.models.is_empty() || self.is_expired()
86 }
87
88 pub async fn refresh(&mut self) -> Result<()> {
89 println!("Refreshing models cache...");
90
91 let config = Config::load()?;
92 let mut new_models = HashMap::new();
93 let mut successful_providers = 0;
94 let mut total_models = 0;
95
96 for (provider_name, provider_config) in &config.providers {
97 if provider_config.api_key.is_none() {
99 continue;
100 }
101
102 print!("Fetching models from {}... ", provider_name);
103
104 let api_key = provider_config.api_key.clone().ok_or_else(|| {
105 anyhow::anyhow!(
106 "API key is required but not found for provider {}",
107 provider_name
108 )
109 })?;
110
111 let client = OpenAIClient::new_with_headers(
112 provider_config.endpoint.clone(),
113 api_key,
114 provider_config.models_path.clone(),
115 provider_config.chat_path.clone(),
116 provider_config.headers.clone(),
117 );
118
119 match client.list_models().await {
120 Ok(models) => {
121 let model_names: Vec<String> = models.into_iter().map(|m| m.id).collect();
122 let count = model_names.len();
123 new_models.insert(provider_name.clone(), model_names);
124 successful_providers += 1;
125 total_models += count;
126 println!("✓ ({} models)", count);
127 }
128 Err(e) => {
129 println!("✗ ({})", e);
130 }
131 }
132 }
133
134 self.models = new_models;
135 self.last_updated = SystemTime::now()
136 .duration_since(UNIX_EPOCH)
137 .unwrap_or(Duration::from_secs(0))
138 .as_secs();
139
140 self.invalidate_cache();
142 self.save()?;
143
144 println!(
145 "\nCache updated: {} providers, {} total models",
146 successful_providers, total_models
147 );
148 Ok(())
149 }
150
151 pub fn get_all_models(&self) -> Vec<CachedModel> {
152 let mut all_models = Vec::new();
153
154 for (provider, models) in &self.models {
155 for model in models {
156 all_models.push(CachedModel {
157 provider: provider.clone(),
158 model: model.clone(),
159 });
160 }
161 }
162
163 all_models.sort_by(|a, b| a.provider.cmp(&b.provider).then(a.model.cmp(&b.model)));
165
166 all_models
167 }
168
169 fn cache_file_path() -> Result<PathBuf> {
170 let config_dir =
171 dirs::config_dir().ok_or_else(|| anyhow::anyhow!("Could not find config directory"))?;
172
173 Ok(config_dir.join("lc").join("models_cache.json"))
174 }
175}