1use std::collections::HashMap;
2use std::path::PathBuf;
3use std::sync::OnceLock;
4
5use directories::ProjectDirs;
6use serde::{Deserialize, Serialize};
7
/// File name of the JSON cache inside the cache directory.
const CACHE_FILE_NAME: &str = "flake_edit.json";
9
10fn cache_dir() -> &'static PathBuf {
11 static CACHE_DIR: OnceLock<PathBuf> = OnceLock::new();
12 CACHE_DIR.get_or_init(|| {
13 let project_dir = ProjectDirs::from("com", "a-kenji", "flake-edit").unwrap();
14 project_dir.data_dir().to_path_buf()
15 })
16}
17
18fn cache_file() -> &'static PathBuf {
19 static CACHE_FILE: OnceLock<PathBuf> = OnceLock::new();
20 CACHE_FILE.get_or_init(|| cache_dir().join(CACHE_FILE_NAME))
21}
22
/// One cached `(id, uri)` pair together with how often it has been re-used.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
struct CacheEntry {
    /// Flake input name the URI was recorded for (e.g. `nixpkgs`).
    id: String,
    /// Flake reference URI (e.g. `github:NixOS/nixpkgs`).
    uri: String,
    /// Usage counter: new entries start at 0 and `Cache::add_entry` bumps it on
    /// every repeat. Higher counts are listed first by `Cache::list_uris`.
    hit: u32,
}
29
/// Build the `HashMap` key for an `(id, uri)` pair as `"{id}.{uri}"`.
///
/// NOTE: the `.` separator is not escaped, so an id that itself contains a
/// `.` can collide with another pair (e.g. `("a.b", "c")` and `("a", "b.c")`).
/// The format is kept as-is because existing on-disk caches use these keys.
fn entry_key(id: &str, uri: &str) -> String {
    let mut key = String::with_capacity(id.len() + 1 + uri.len());
    key.push_str(id);
    key.push('.');
    key.push_str(uri);
    key
}
35
/// Record of flake input URIs together with usage counts, persisted as JSON.
///
/// Entries are keyed by `entry_key(id, uri)`, so the same URI can appear once
/// for every input id it was used with.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
pub struct Cache {
    entries: HashMap<String, CacheEntry>,
}
43
44impl Cache {
45 pub fn commit(&self) -> std::io::Result<()> {
48 let cache_dir = cache_dir();
49 if !cache_dir.exists() {
50 std::fs::create_dir_all(cache_dir)?;
51 }
52 let cache_file_location = cache_file();
53 let cache_file = std::fs::File::create(cache_file_location)?;
54 serde_json::to_writer(cache_file, self)
55 .map_err(|e| std::io::Error::other(e.to_string()))?;
56 Ok(())
57 }
58
59 pub fn load() -> Self {
62 Self::from_path(cache_file())
63 }
64
65 pub fn from_path(path: &std::path::Path) -> Self {
67 Self::try_from_path(path).unwrap_or_else(|e| {
68 tracing::warn!("Could not read cache file {:?}: {}", path, e);
69 Self::default()
70 })
71 }
72
73 pub fn try_from_path(path: &std::path::Path) -> std::io::Result<Self> {
80 let file = std::fs::File::open(path)?;
81 serde_json::from_reader(file).map_err(|e| std::io::Error::other(e.to_string()))
82 }
83
84 pub fn add_entry(&mut self, id: String, uri: String) {
86 let key = entry_key(&id, &uri);
87 match self.entries.get_mut(&key) {
88 Some(entry) => entry.hit += 1,
89 None => {
90 let entry = CacheEntry { id, uri, hit: 0 };
91 self.entries.insert(key, entry);
92 }
93 }
94 }
95
96 pub fn list_uris(&self) -> Vec<String> {
98 let mut entries: Vec<_> = self.entries.values().collect();
99 entries.sort_by_key(|b| std::cmp::Reverse(b.hit));
100 entries.iter().map(|e| e.uri.clone()).collect()
101 }
102
103 pub fn list_uris_for_id(&self, id: &str) -> Vec<String> {
109 let mut entries: Vec<_> = self.entries.values().filter(|e| e.id == id).collect();
110 entries.sort_by_key(|b| std::cmp::Reverse(b.hit));
111 entries.iter().map(|e| e.uri.clone()).collect()
112 }
113
114 pub fn populate_from_inputs<'a>(&mut self, inputs: impl Iterator<Item = (&'a str, &'a str)>) {
120 for (id, uri) in inputs {
121 let key = entry_key(id, uri);
122 self.entries.entry(key).or_insert_with(|| CacheEntry {
123 id: id.to_string(),
124 uri: uri.to_string(),
125 hit: 0,
126 });
127 }
128 }
129}
130
131pub fn populate_cache_from_inputs<'a>(
136 inputs: impl Iterator<Item = (&'a str, &'a str)>,
137 no_cache: bool,
138) {
139 if no_cache {
140 return;
141 }
142
143 let mut cache = Cache::load();
144 let initial_len = cache.entries.len();
145 cache.populate_from_inputs(inputs);
146
147 if cache.entries.len() > initial_len
148 && let Err(e) = cache.commit()
149 {
150 tracing::debug!("Could not write to cache: {}", e);
151 }
152}
153
154pub fn populate_cache_from_input_map(inputs: &crate::edit::InputMap, no_cache: bool) {
158 populate_cache_from_inputs(
159 inputs.iter().map(|(id, input)| (id.as_str(), input.url())),
160 no_cache,
161 );
162}
163
/// Built-in list of flake-reference URI scheme prefixes.
///
/// NOTE(review): presumably offered as default completion or suggestion
/// candidates alongside cached URIs. Usage is not visible here, so confirm
/// with callers.
///
/// Entries keep the order written here. The more specific `git+…` prefixes
/// appear before the bare `http://` / `https://` / `file://` ones, in case a
/// caller relies on first-match ordering.
pub const DEFAULT_URI_TYPES: [&str; 14] = [
    "github:",
    "gitlab:",
    "sourcehut:",
    "git+https://",
    "git+ssh://",
    "git+http://",
    "git+file://",
    "git://",
    "path:",
    "file://",
    "tarball:",
    "https://",
    "http://",
    "flake:",
];
181
/// Selects where, or whether, the URI cache is used.
///
/// NOTE(review): the code consuming this enum is not visible here, so the
/// variant descriptions below are inferred from their names. Confirm them
/// against the call sites.
#[derive(Debug, Clone, Default)]
pub enum CacheConfig {
    /// The `Default` value. Presumably selects the standard cache file location.
    #[default]
    Default,
    /// Presumably disables reading and writing the cache.
    None,
    /// Presumably a caller-supplied cache file path used instead of the default.
    Custom(std::path::PathBuf),
}
193
#[cfg(test)]
mod tests {
    use super::*;

    /// Repeated entries are de-duplicated by `(id, uri)`, and the re-used one
    /// ranks first.
    #[test]
    fn test_cache_add_and_list() {
        let mut cache = Cache::default();
        cache.add_entry("nixpkgs".into(), "github:NixOS/nixpkgs".into());
        cache.add_entry(
            "home-manager".into(),
            "github:nix-community/home-manager".into(),
        );
        // Second add of the same pair bumps its hit count from 0 to 1.
        cache.add_entry("nixpkgs".into(), "github:NixOS/nixpkgs".into());
        let uris = cache.list_uris();
        assert_eq!(uris.len(), 2);
        assert_eq!(uris[0], "github:NixOS/nixpkgs");
    }

    /// Filtering by id returns only that input's URIs, most-used first.
    #[test]
    fn test_list_uris_for_id() {
        let mut cache = Cache::default();
        cache.add_entry("treefmt-nix".into(), "github:numtide/treefmt-nix".into());
        cache.add_entry(
            "treefmt-nix".into(),
            "path:/home/user/dev/treefmt-nix".into(),
        );
        cache.add_entry("nixpkgs".into(), "github:NixOS/nixpkgs".into());
        // Re-use the github URI so it outranks the local path (hit 1 vs 0).
        cache.add_entry("treefmt-nix".into(), "github:numtide/treefmt-nix".into());

        let uris = cache.list_uris_for_id("treefmt-nix");
        assert_eq!(uris.len(), 2);
        assert_eq!(uris[0], "github:numtide/treefmt-nix");
        assert_eq!(uris[1], "path:/home/user/dev/treefmt-nix");

        // Entries for other ids must not leak into the filtered list.
        assert!(!uris.contains(&"github:NixOS/nixpkgs".to_string()));
    }

    /// An unknown id yields no URIs.
    #[test]
    fn test_list_uris_for_id_empty() {
        let cache = Cache::default();
        let uris = cache.list_uris_for_id("nonexistent");
        assert!(uris.is_empty());
    }

    /// Populating adds unseen pairs without disturbing existing rankings.
    #[test]
    fn test_populate_from_inputs() {
        let mut cache = Cache::default();

        // nixpkgs ends up with hit count 1.
        cache.add_entry("nixpkgs".into(), "github:NixOS/nixpkgs".into());
        cache.add_entry("nixpkgs".into(), "github:NixOS/nixpkgs".into());
        let inputs = vec![
            ("nixpkgs", "github:NixOS/nixpkgs"),
            ("flake-utils", "github:numtide/flake-utils"),
            ("home-manager", "github:nix-community/home-manager"),
        ];
        cache.populate_from_inputs(inputs.into_iter());

        let uris = cache.list_uris();
        // nixpkgs was already present, so only two new entries were added.
        assert_eq!(uris.len(), 3);

        assert_eq!(uris[0], "github:NixOS/nixpkgs");

        assert!(uris.contains(&"github:numtide/flake-utils".to_string()));
        assert!(uris.contains(&"github:nix-community/home-manager".to_string()));
    }

    /// Populating with an already-cached pair leaves its hit count unchanged.
    #[test]
    fn test_populate_does_not_increment_hits() {
        let mut cache = Cache::default();

        cache.add_entry("nixpkgs".into(), "github:NixOS/nixpkgs".into());
        cache.add_entry("nixpkgs".into(), "github:NixOS/nixpkgs".into());
        let inputs = vec![("nixpkgs", "github:NixOS/nixpkgs")];
        cache.populate_from_inputs(inputs.into_iter());

        // Key format is `"{id}.{uri}"`, as produced by `entry_key`.
        let entry = cache.entries.get("nixpkgs.github:NixOS/nixpkgs").unwrap();
        assert_eq!(entry.hit, 1);
    }
}
289}