1use std::collections::HashMap;
2use std::path::PathBuf;
3use std::sync::OnceLock;
4
5use directories::ProjectDirs;
6use serde::{Deserialize, Serialize};
7
/// File name of the JSON cache, placed inside the cache directory.
const CACHE_FILE_NAME: &str = "flake_edit.json";
9
10fn cache_dir() -> &'static PathBuf {
11 static CACHE_DIR: OnceLock<PathBuf> = OnceLock::new();
12 CACHE_DIR.get_or_init(|| {
13 let project_dir = ProjectDirs::from("com", "a-kenji", "flake-edit").unwrap();
14 project_dir.data_dir().to_path_buf()
15 })
16}
17
18fn cache_file() -> &'static PathBuf {
19 static CACHE_FILE: OnceLock<PathBuf> = OnceLock::new();
20 CACHE_FILE.get_or_init(|| cache_dir().join(CACHE_FILE_NAME))
21}
22
23#[derive(Debug, Default, Clone, Serialize, Deserialize)]
24struct CacheEntry {
25 id: String,
26 uri: String,
27 hit: u32,
28}
29
/// Builds the map key for an entry: `"<id>.<uri>"`.
///
/// This string is the key inside the serialized JSON cache, so the format must
/// stay stable to keep existing cache files usable.
// NOTE(review): an id containing `.` can collide with another (id, uri) pair,
// e.g. ("a.b", "c") vs ("a", "b.c"); kept as-is for on-disk compatibility.
fn entry_key(id: &str, uri: &str) -> String {
    let mut key = String::with_capacity(id.len() + 1 + uri.len());
    key.push_str(id);
    key.push('.');
    key.push_str(uri);
    key
}
37
38#[derive(Debug, Default, Clone, Serialize, Deserialize)]
42pub struct Cache {
43 entries: HashMap<String, CacheEntry>,
44}
45
46impl Cache {
47 pub fn commit(&self) -> std::io::Result<()> {
49 let cache_dir = cache_dir();
50 if !cache_dir.exists() {
51 std::fs::create_dir_all(cache_dir)?;
52 }
53 let cache_file_location = cache_file();
54 let cache_file = std::fs::File::create(cache_file_location)?;
55 serde_json::to_writer(cache_file, self)
56 .map_err(|e| std::io::Error::other(e.to_string()))?;
57 Ok(())
58 }
59
60 pub fn load() -> Self {
62 Self::from_path(cache_file())
63 }
64
65 pub fn from_path(path: &std::path::Path) -> Self {
67 Self::try_from_path(path).unwrap_or_else(|e| {
68 tracing::warn!("Could not read cache file {:?}: {}", path, e);
69 Self::default()
70 })
71 }
72
73 pub fn try_from_path(path: &std::path::Path) -> std::io::Result<Self> {
75 let file = std::fs::File::open(path)?;
76 serde_json::from_reader(file).map_err(|e| std::io::Error::other(e.to_string()))
77 }
78
79 pub fn add_entry(&mut self, id: String, uri: String) {
81 let key = entry_key(&id, &uri);
82 match self.entries.get_mut(&key) {
83 Some(entry) => entry.hit += 1,
84 None => {
85 let entry = CacheEntry { id, uri, hit: 0 };
86 self.entries.insert(key, entry);
87 }
88 }
89 }
90
91 pub fn list_uris(&self) -> Vec<String> {
93 let mut entries: Vec<_> = self.entries.values().collect();
94 entries.sort_by(|a, b| b.hit.cmp(&a.hit));
95 entries.iter().map(|e| e.uri.clone()).collect()
96 }
97
98 pub fn list_uris_for_id(&self, id: &str) -> Vec<String> {
104 let mut entries: Vec<_> = self.entries.values().filter(|e| e.id == id).collect();
105 entries.sort_by(|a, b| b.hit.cmp(&a.hit));
106 entries.iter().map(|e| e.uri.clone()).collect()
107 }
108
109 pub fn populate_from_inputs<'a>(&mut self, inputs: impl Iterator<Item = (&'a str, &'a str)>) {
116 for (id, uri) in inputs {
117 let key = entry_key(id, uri);
118 self.entries.entry(key).or_insert_with(|| CacheEntry {
120 id: id.to_string(),
121 uri: uri.to_string(),
122 hit: 0,
123 });
124 }
125 }
126}
127
128pub fn populate_cache_from_inputs<'a>(
137 inputs: impl Iterator<Item = (&'a str, &'a str)>,
138 no_cache: bool,
139) {
140 if no_cache {
141 return;
142 }
143
144 let mut cache = Cache::load();
145 let initial_len = cache.entries.len();
146 cache.populate_from_inputs(inputs);
147
148 if cache.entries.len() > initial_len
150 && let Err(e) = cache.commit()
151 {
152 tracing::debug!("Could not write to cache: {}", e);
153 }
154}
155
156pub fn populate_cache_from_input_map(inputs: &crate::edit::InputMap, no_cache: bool) {
163 populate_cache_from_inputs(
164 inputs
165 .iter()
166 .map(|(id, input)| (id.as_str(), input.url().trim_matches('"'))),
167 no_cache,
168 );
169}
170
/// Built-in list of flake-reference URI prefixes, e.g. `github:` or
/// `git+https://`. The order of the entries is significant and preserved.
pub const DEFAULT_URI_TYPES: [&'static str; 14] = [
    // Forge shorthands.
    "github:",
    "gitlab:",
    "sourcehut:",
    // Git transports.
    "git+https://",
    "git+ssh://",
    "git+http://",
    "git+file://",
    "git://",
    // Local paths and files.
    "path:",
    "file://",
    // Archives and plain HTTP(S).
    "tarball:",
    "https://",
    "http://",
    // Registry references.
    "flake:",
];
188
/// Selects which URI cache to use.
#[derive(Clone, Debug, Default)]
pub enum CacheConfig {
    /// The default cache; this is what `CacheConfig::default()` returns.
    #[default]
    Default,
    /// No cache.
    None,
    /// A cache stored at a caller-chosen path.
    Custom(std::path::PathBuf),
}
202
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for `Cache::add_entry` with borrowed strings.
    fn add(cache: &mut Cache, id: &str, uri: &str) {
        cache.add_entry(id.to_string(), uri.to_string());
    }

    #[test]
    fn test_cache_add_and_list() {
        let mut cache = Cache::default();
        add(&mut cache, "nixpkgs", "github:NixOS/nixpkgs");
        add(&mut cache, "home-manager", "github:nix-community/home-manager");
        // Adding nixpkgs a second time bumps its hit count above home-manager.
        add(&mut cache, "nixpkgs", "github:NixOS/nixpkgs");

        let uris = cache.list_uris();
        assert_eq!(uris.len(), 2);
        assert_eq!(uris[0], "github:NixOS/nixpkgs");
    }

    #[test]
    fn test_list_uris_for_id() {
        let mut cache = Cache::default();
        add(&mut cache, "treefmt-nix", "github:numtide/treefmt-nix");
        add(&mut cache, "treefmt-nix", "path:/home/user/dev/treefmt-nix");
        add(&mut cache, "nixpkgs", "github:NixOS/nixpkgs");
        // Second use of the GitHub URI ranks it above the local path.
        add(&mut cache, "treefmt-nix", "github:numtide/treefmt-nix");

        let uris = cache.list_uris_for_id("treefmt-nix");
        assert_eq!(uris.len(), 2);
        assert_eq!(uris[0], "github:numtide/treefmt-nix");
        assert_eq!(uris[1], "path:/home/user/dev/treefmt-nix");

        // URIs recorded under other ids must not leak in.
        assert!(uris.iter().all(|u| u != "github:NixOS/nixpkgs"));
    }

    #[test]
    fn test_list_uris_for_id_empty() {
        let cache = Cache::default();
        assert!(cache.list_uris_for_id("nonexistent").is_empty());
    }

    #[test]
    fn test_populate_from_inputs() {
        let mut cache = Cache::default();
        add(&mut cache, "nixpkgs", "github:NixOS/nixpkgs");
        add(&mut cache, "nixpkgs", "github:NixOS/nixpkgs");

        let inputs = [
            ("nixpkgs", "github:NixOS/nixpkgs"),
            ("flake-utils", "github:numtide/flake-utils"),
            ("home-manager", "github:nix-community/home-manager"),
        ];
        cache.populate_from_inputs(inputs.into_iter());

        let uris = cache.list_uris();
        assert_eq!(uris.len(), 3);
        // The pre-existing, already-hit entry keeps the top spot.
        assert_eq!(uris[0], "github:NixOS/nixpkgs");
        assert!(uris.iter().any(|u| u == "github:numtide/flake-utils"));
        assert!(uris.iter().any(|u| u == "github:nix-community/home-manager"));
    }

    #[test]
    fn test_populate_does_not_increment_hits() {
        let mut cache = Cache::default();
        add(&mut cache, "nixpkgs", "github:NixOS/nixpkgs");
        add(&mut cache, "nixpkgs", "github:NixOS/nixpkgs");

        cache.populate_from_inputs(std::iter::once(("nixpkgs", "github:NixOS/nixpkgs")));

        let entry = &cache.entries["nixpkgs.github:NixOS/nixpkgs"];
        assert_eq!(entry.hit, 1);
    }
}