Skip to main content

flake_edit/
cache.rs

1use std::collections::HashMap;
2use std::path::PathBuf;
3use std::sync::OnceLock;
4
5use directories::ProjectDirs;
6use serde::{Deserialize, Serialize};
7
/// File name of the JSON cache inside the cache directory.
static CACHE_FILE_NAME: &str = "flake_edit.json";
9
10fn cache_dir() -> &'static PathBuf {
11    static CACHE_DIR: OnceLock<PathBuf> = OnceLock::new();
12    CACHE_DIR.get_or_init(|| {
13        let project_dir = ProjectDirs::from("com", "a-kenji", "flake-edit").unwrap();
14        project_dir.data_dir().to_path_buf()
15    })
16}
17
18fn cache_file() -> &'static PathBuf {
19    static CACHE_FILE: OnceLock<PathBuf> = OnceLock::new();
20    CACHE_FILE.get_or_init(|| cache_dir().join(CACHE_FILE_NAME))
21}
22
/// One remembered `(id, uri)` pair together with its usage counter.
///
/// Serialized as part of [`Cache`]; the field names are part of the on-disk
/// JSON format.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
struct CacheEntry {
    /// Flake input id the URI was seen under.
    id: String,
    /// The flake URI itself.
    uri: String,
    /// Usage counter: starts at 0 on first insertion and is bumped each time
    /// the same pair is passed to `Cache::add_entry` again.
    hit: u32,
}
29
/// Compose the map key for a cache entry by joining `id` and `uri` with a
/// single `.`, i.e. `{id}.{uri}`.
///
/// Because the URI is part of the key, one input id can own several entries
/// (for example both a `github:` and a `path:` URI).
fn entry_key(id: &str, uri: &str) -> String {
    [id, uri].join(".")
}
35
/// Persistent store of previously seen flake URIs.
///
/// Powers shell-completion suggestions, ranked by hit count. The whole
/// structure is serialized to and from JSON via serde.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
pub struct Cache {
    /// Cached entries, keyed by the `{id}.{uri}` string built by `entry_key`.
    entries: HashMap<String, CacheEntry>,
}
43
44impl Cache {
45    /// Write the cache to its on-disk location, creating the parent directory
46    /// if needed.
47    pub fn commit(&self) -> std::io::Result<()> {
48        let cache_dir = cache_dir();
49        if !cache_dir.exists() {
50            std::fs::create_dir_all(cache_dir)?;
51        }
52        let cache_file_location = cache_file();
53        let cache_file = std::fs::File::create(cache_file_location)?;
54        serde_json::to_writer(cache_file, self)
55            .map_err(|e| std::io::Error::other(e.to_string()))?;
56        Ok(())
57    }
58
59    /// Load the cache from the default location, or return an empty cache on
60    /// any failure.
61    pub fn load() -> Self {
62        Self::from_path(cache_file())
63    }
64
65    /// Load the cache from `path`, or return an empty cache on any failure.
66    pub fn from_path(path: &std::path::Path) -> Self {
67        Self::try_from_path(path).unwrap_or_else(|e| {
68            tracing::warn!("Could not read cache file {:?}: {}", path, e);
69            Self::default()
70        })
71    }
72
73    /// Load the cache from `path`, surfacing read or parse errors.
74    ///
75    /// # Errors
76    ///
77    /// Returns [`std::io::Error`] if `path` cannot be opened or its JSON
78    /// payload cannot be deserialized.
79    pub fn try_from_path(path: &std::path::Path) -> std::io::Result<Self> {
80        let file = std::fs::File::open(path)?;
81        serde_json::from_reader(file).map_err(|e| std::io::Error::other(e.to_string()))
82    }
83
84    /// Insert or bump the hit count of the `(id, uri)` entry.
85    pub fn add_entry(&mut self, id: String, uri: String) {
86        let key = entry_key(&id, &uri);
87        match self.entries.get_mut(&key) {
88            Some(entry) => entry.hit += 1,
89            None => {
90                let entry = CacheEntry { id, uri, hit: 0 };
91                self.entries.insert(key, entry);
92            }
93        }
94    }
95
96    /// All cached URIs sorted by descending hit count.
97    pub fn list_uris(&self) -> Vec<String> {
98        let mut entries: Vec<_> = self.entries.values().collect();
99        entries.sort_by_key(|b| std::cmp::Reverse(b.hit));
100        entries.iter().map(|e| e.uri.clone()).collect()
101    }
102
103    /// Cached URIs for `id` sorted by descending hit count.
104    ///
105    /// Useful for the `change` workflow, which suggests URIs that have been
106    /// used for the same input id (e.g. both a remote `github:` and a local
107    /// `path:` URI for testing).
108    pub fn list_uris_for_id(&self, id: &str) -> Vec<String> {
109        let mut entries: Vec<_> = self.entries.values().filter(|e| e.id == id).collect();
110        entries.sort_by_key(|b| std::cmp::Reverse(b.hit));
111        entries.iter().map(|e| e.uri.clone()).collect()
112    }
113
114    /// Insert any `(id, uri)` pairs not already present, without bumping hit
115    /// counts on existing entries.
116    ///
117    /// Use this when populating the cache as a side effect of any command
118    /// that reads inputs (`list`, `change`, `update`, ...), not only `add`.
119    pub fn populate_from_inputs<'a>(&mut self, inputs: impl Iterator<Item = (&'a str, &'a str)>) {
120        for (id, uri) in inputs {
121            let key = entry_key(id, uri);
122            self.entries.entry(key).or_insert_with(|| CacheEntry {
123                id: id.to_string(),
124                uri: uri.to_string(),
125                hit: 0,
126            });
127        }
128    }
129}
130
131/// Load the on-disk cache, add any new `(id, uri)` pairs, and commit.
132///
133/// Best-effort: I/O failures are logged, not propagated. A `no_cache` of
134/// `true` makes the call a no-op.
135pub fn populate_cache_from_inputs<'a>(
136    inputs: impl Iterator<Item = (&'a str, &'a str)>,
137    no_cache: bool,
138) {
139    if no_cache {
140        return;
141    }
142
143    let mut cache = Cache::load();
144    let initial_len = cache.entries.len();
145    cache.populate_from_inputs(inputs);
146
147    if cache.entries.len() > initial_len
148        && let Err(e) = cache.commit()
149    {
150        tracing::debug!("Could not write to cache: {}", e);
151    }
152}
153
154/// Convenience wrapper over [`populate_cache_from_inputs`] for the result of
155/// [`crate::edit::FlakeEdit::list`]. A `no_cache` of `true` makes the call a
156/// no-op.
157pub fn populate_cache_from_input_map(inputs: &crate::edit::InputMap, no_cache: bool) {
158    populate_cache_from_inputs(
159        inputs.iter().map(|(id, input)| (id.as_str(), input.url())),
160        no_cache,
161    );
162}
163
/// Flake URI type prefixes offered by completion.
///
/// Each entry is the literal prefix, including its trailing `:` or `://`.
// NOTE(review): confirm `tarball:` is a scheme Nix accepts; tarball inputs
// may need to be spelled `tarball+https://` instead — verify against the
// Nix flake reference documentation.
pub const DEFAULT_URI_TYPES: [&str; 14] = [
    "github:",
    "gitlab:",
    "sourcehut:",
    "git+https://",
    "git+ssh://",
    "git+http://",
    "git+file://",
    "git://",
    "path:",
    "file://",
    "tarball:",
    "https://",
    "http://",
    "flake:",
];
181
/// Where to read and write the URI completion cache.
///
/// Defaults to [`CacheConfig::Default`].
#[derive(Debug, Clone, Default)]
pub enum CacheConfig {
    /// Default XDG location (`~/.local/share/flake-edit/`).
    // NOTE(review): that path is the Linux layout; presumably other
    // platforms resolve to their own per-user data directory — confirm.
    #[default]
    Default,
    /// Disable caching entirely (`--no-cache`).
    None,
    /// Read and write at a custom path (`--cache`, or tests).
    Custom(std::path::PathBuf),
}
193
/// Unit tests for the in-memory behavior of [`Cache`]; nothing here touches
/// the file system.
#[cfg(test)]
mod tests {
    use super::*;

    /// Re-adding a pair bumps its hit count and moves it to the front.
    #[test]
    fn test_cache_add_and_list() {
        let mut cache = Cache::default();
        cache.add_entry("nixpkgs".into(), "github:NixOS/nixpkgs".into());
        cache.add_entry(
            "home-manager".into(),
            "github:nix-community/home-manager".into(),
        );
        cache.add_entry("nixpkgs".into(), "github:NixOS/nixpkgs".into()); // Increment hit

        let uris = cache.list_uris();
        assert_eq!(uris.len(), 2);
        // nixpkgs should be first due to higher hit count
        assert_eq!(uris[0], "github:NixOS/nixpkgs");
    }

    /// Several URIs can be stored for one id, and unrelated ids are excluded.
    #[test]
    fn test_list_uris_for_id() {
        let mut cache = Cache::default();
        // Add multiple URIs for the same ID (simulating local/remote toggle workflow)
        cache.add_entry("treefmt-nix".into(), "github:numtide/treefmt-nix".into());
        cache.add_entry(
            "treefmt-nix".into(),
            "path:/home/user/dev/treefmt-nix".into(),
        );
        // Add unrelated entry
        cache.add_entry("nixpkgs".into(), "github:NixOS/nixpkgs".into());
        // Increment hit on the github one
        cache.add_entry("treefmt-nix".into(), "github:numtide/treefmt-nix".into());

        let uris = cache.list_uris_for_id("treefmt-nix");
        assert_eq!(uris.len(), 2);
        // github should be first due to higher hit count
        assert_eq!(uris[0], "github:numtide/treefmt-nix");
        assert_eq!(uris[1], "path:/home/user/dev/treefmt-nix");

        // Should not include nixpkgs
        assert!(!uris.contains(&"github:NixOS/nixpkgs".to_string()));
    }

    /// An unknown id yields an empty list rather than an error.
    #[test]
    fn test_list_uris_for_id_empty() {
        let cache = Cache::default();
        let uris = cache.list_uris_for_id("nonexistent");
        assert!(uris.is_empty());
    }

    /// Populating adds unknown pairs and keeps existing ones ranked first.
    #[test]
    fn test_populate_from_inputs() {
        let mut cache = Cache::default();

        // Add some initial entries
        cache.add_entry("nixpkgs".into(), "github:NixOS/nixpkgs".into());
        cache.add_entry("nixpkgs".into(), "github:NixOS/nixpkgs".into()); // hit = 1

        // Populate with inputs (simulating what happens when running a command)
        let inputs = vec![
            ("nixpkgs", "github:NixOS/nixpkgs"),           // Already exists
            ("flake-utils", "github:numtide/flake-utils"), // New
            ("home-manager", "github:nix-community/home-manager"), // New
        ];
        cache.populate_from_inputs(inputs.into_iter());

        // Should have 3 entries total
        let uris = cache.list_uris();
        assert_eq!(uris.len(), 3);

        // nixpkgs should still be first (hit=1, others hit=0)
        assert_eq!(uris[0], "github:NixOS/nixpkgs");

        // New entries should exist
        assert!(uris.contains(&"github:numtide/flake-utils".to_string()));
        assert!(uris.contains(&"github:nix-community/home-manager".to_string()));
    }

    /// Populating with an already-known pair must leave its hit count alone.
    #[test]
    fn test_populate_does_not_increment_hits() {
        let mut cache = Cache::default();

        // Add entry with hit count
        cache.add_entry("nixpkgs".into(), "github:NixOS/nixpkgs".into());
        cache.add_entry("nixpkgs".into(), "github:NixOS/nixpkgs".into()); // hit = 1

        // Populate with same entry
        let inputs = vec![("nixpkgs", "github:NixOS/nixpkgs")];
        cache.populate_from_inputs(inputs.into_iter());

        // Look the entry up through `entry_key` so the test does not depend
        // on the exact key format.
        let key = entry_key("nixpkgs", "github:NixOS/nixpkgs");
        let entry = cache
            .entries
            .get(&key)
            .expect("entry added above must be present");
        // Hit count should still be 1, not 2
        assert_eq!(entry.hit, 1);
    }
}
289}