Skip to main content

zsh/extensions/
autoload_cache.rs

//! rkyv-backed bytecode cache for autoload functions.
//!
//! Single-file shard at `~/.cache/zshrs/autoloads.rkyv`. Keyed by function name
//! (not file path) — autoload bytecode is identified by the resolved function
//! name, regardless of which fpath dir or .zwc archive it came from.
//!
//! Storage layout (rkyv archived):
//!   AutoloadShard {
//!     header: { magic, format_version, zshrs_version, pointer_width, built_at_secs },
//!     entries: HashMap<function_name, AutoloadEntry>,
//!   }
//!   AutoloadEntry { binary_mtime_at_cache, cached_at_secs, chunk_blob: `Vec<u8>` }
//!
//! Inner `chunk_blob` is bincode-encoded `fusevm::Chunk` (same constraint as the
//! [`script_cache`](crate::script_cache) module — `fusevm::Chunk` is upstream and only derives serde).
//!
//! Invalidation:
//!   - Header mismatch (magic, format version, zshrs version or pointer width)
//!     ⇒ the whole shard is ignored on read and discarded on the next write.
//!   - zshrs binary mtime newer than `binary_mtime_at_cache` ⇒ entry stale
//!     (any zshrs rebuild silently invalidates the whole shard).
//!   - There is no per-source-file mtime check here. Autoload bodies live in
//!     fpath dirs / .zwc archives and the existing `compsys::cache::autoloads`
//!     SQLite row tracks the source file/offset/size. Rebuild logic relies on
//!     `compinit` rewriting the whole rkyv shard at recompute time (see
//!     `AutoloadCache::replace_all` / `try_replace_all` — used by the compinit
//!     bulk-prewarm path).
//!
//! Bulk-write: compinit prewarms 16k+ autoload bytecodes in one go. Per-batch
//! shard rewrites (the SQLite-era pattern) would re-serialize 16k entries
//! 160 times. Instead, the caller accumulates all `(name, blob)` pairs in
//! memory and hands them to `try_replace_all` (full rebuild) or
//! `try_merge_in` (backfill of missing entries); each writes the shard
//! exactly once. The single-add `try_save_one` path remains for the
//! cold-start case where the interactive shell compiles one autoload at a time.
//!
//! The on-disk shape mirrors [`ScriptShard`](crate::script_cache::ScriptShard) — same header,
//! same magic-version-pointer_width discipline, same atomic-rename writes.
35
36use std::collections::HashMap;
37use std::fs::File;
38use std::io::Write as IoWrite;
39use std::path::{Path, PathBuf};
40use std::sync::OnceLock;
41use std::time::{SystemTime, UNIX_EPOCH};
42
43use memmap2::Mmap;
44use parking_lot::Mutex;
45use rkyv::{Archive, Deserialize as RkyvDeserialize, Serialize as RkyvSerialize};
46use std::os::unix::fs::MetadataExt;
47
/// Shard magic number. Its hex digits spell "ZRAL" when read from the most
/// significant byte (`SHARD_MAGIC.to_be_bytes() == *b"ZRAL"`). Inside the
/// archive it is stored in rkyv's archived `u32` byte order, so the raw bytes
/// on disk need not read "ZRAL" (on a little-endian layout they are "LARZ").
pub const SHARD_MAGIC: u32 = 0x5A52414C;
/// Version of the archived [`AutoloadShard`] layout. Shards carrying any other
/// value are rejected by readers and replaced by writers.
pub const SHARD_FORMAT_VERSION: u32 = 1;
51
52#[derive(Archive, RkyvDeserialize, RkyvSerialize, Debug, Clone)]
53#[archive(check_bytes)]
54pub struct ShardHeader {
55    pub magic: u32,
56    pub format_version: u32,
57    pub zshrs_version: String,
58    pub pointer_width: u32,
59    pub built_at_secs: u64,
60}
61
62#[derive(Archive, RkyvDeserialize, RkyvSerialize, Debug, Clone)]
63#[archive(check_bytes)]
64pub struct AutoloadEntry {
65    pub binary_mtime_at_cache: i64,
66    pub cached_at_secs: i64,
67    pub chunk_blob: Vec<u8>,
68}
69
70#[derive(Archive, RkyvDeserialize, RkyvSerialize, Debug, Clone)]
71#[archive(check_bytes)]
72pub struct AutoloadShard {
73    pub header: ShardHeader,
74    pub entries: HashMap<String, AutoloadEntry>,
75}
76
77pub struct MmappedShard {
78    _mmap: Mmap,
79    archived: *const ArchivedAutoloadShard,
80}
81
82unsafe impl Send for MmappedShard {}
83unsafe impl Sync for MmappedShard {}
84
85impl MmappedShard {
86    pub fn open(path: &Path) -> Option<Self> {
87        let file = File::open(path).ok()?;
88        let mmap = unsafe { Mmap::map(&file).ok()? };
89        let archived = rkyv::check_archived_root::<AutoloadShard>(&mmap[..]).ok()?;
90        let archived_ptr = archived as *const ArchivedAutoloadShard;
91        Some(Self {
92            _mmap: mmap,
93            archived: archived_ptr,
94        })
95    }
96
97    fn shard(&self) -> &ArchivedAutoloadShard {
98        unsafe { &*self.archived }
99    }
100
101    fn header_ok(&self) -> bool {
102        let h = &self.shard().header;
103        let magic: u32 = h.magic.into();
104        let fv: u32 = h.format_version.into();
105        let pw: u32 = h.pointer_width.into();
106        magic == SHARD_MAGIC
107            && fv == SHARD_FORMAT_VERSION
108            && pw as usize == std::mem::size_of::<usize>()
109            && h.zshrs_version.as_str() == env!("CARGO_PKG_VERSION")
110    }
111
112    fn lookup(&self, name: &str) -> Option<&ArchivedAutoloadEntry> {
113        self.shard().entries.get(name)
114    }
115}
116
117pub struct AutoloadCache {
118    path: PathBuf,
119    lock_path: PathBuf,
120    mmap: Mutex<Option<MmappedShard>>,
121}
122
123impl AutoloadCache {
124    pub fn open(path: &Path) -> std::io::Result<Self> {
125        if let Some(parent) = path.parent() {
126            std::fs::create_dir_all(parent)?;
127        }
128        let parent = path.parent().unwrap_or_else(|| Path::new("/tmp"));
129        let lock_path = parent.join(format!(
130            "{}.lock",
131            path.file_name()
132                .and_then(|s| s.to_str())
133                .unwrap_or("autoloads.rkyv")
134        ));
135        Ok(Self {
136            path: path.to_path_buf(),
137            lock_path,
138            mmap: Mutex::new(None),
139        })
140    }
141
142    fn ensure_mmap(&self) {
143        let mut guard = self.mmap.lock();
144        if guard.is_none() {
145            *guard = MmappedShard::open(&self.path);
146        }
147    }
148
149    fn invalidate_mmap(&self) {
150        let mut guard = self.mmap.lock();
151        *guard = None;
152    }
153
154    pub fn get(&self, name: &str) -> Option<Vec<u8>> {
155        self.ensure_mmap();
156        let guard = self.mmap.lock();
157        let shard = guard.as_ref()?;
158        if !shard.header_ok() {
159            return None;
160        }
161        let entry = shard.lookup(name)?;
162        if let Some(bin_mtime) = current_binary_mtime_secs() {
163            let cached_bin_mtime: i64 = entry.binary_mtime_at_cache.into();
164            if cached_bin_mtime < bin_mtime {
165                return None;
166            }
167        }
168        Some(entry.chunk_blob.as_slice().to_vec())
169    }
170
171    /// Single-write: read shard, insert one entry, write shard. Used by the
172    /// cold-start path when a function is autoloaded before compinit
173    /// pre-warm completes.
174    pub fn put_one(&self, name: &str, chunk_blob: Vec<u8>) -> Result<(), String> {
175        let _lock = match acquire_lock(&self.lock_path) {
176            Some(l) => l,
177            None => return Ok(()),
178        };
179        let mut shard = match read_owned_shard(&self.path) {
180            Some(s)
181                if s.header.zshrs_version == env!("CARGO_PKG_VERSION")
182                    && s.header.pointer_width as usize == std::mem::size_of::<usize>()
183                    && s.header.format_version == SHARD_FORMAT_VERSION =>
184            {
185                s
186            }
187            _ => fresh_shard(),
188        };
189        let bin_mtime = current_binary_mtime_secs().unwrap_or(0);
190        shard.entries.insert(
191            name.to_string(),
192            AutoloadEntry {
193                binary_mtime_at_cache: bin_mtime,
194                cached_at_secs: now_secs(),
195                chunk_blob,
196            },
197        );
198        shard.header.built_at_secs = now_secs() as u64;
199        write_shard_atomic(&self.path, &shard)?;
200        self.invalidate_mmap();
201        Ok(())
202    }
203
204    /// Merge `entries` into the existing shard, inserting/replacing each one.
205    /// Used by compinit's BACKFILL path — when an existing shard is missing
206    /// some entries (e.g. binary mtime bump invalidated a subset), the
207    /// caller computes the missing names + chunks and merges them in
208    /// without touching unrelated entries. Single read + single write,
209    /// even for 16k entries.
210    pub fn merge_in(&self, entries: HashMap<String, Vec<u8>>) -> Result<(), String> {
211        if entries.is_empty() {
212            return Ok(());
213        }
214        let _lock = match acquire_lock(&self.lock_path) {
215            Some(l) => l,
216            None => return Ok(()),
217        };
218        let mut shard = match read_owned_shard(&self.path) {
219            Some(s)
220                if s.header.zshrs_version == env!("CARGO_PKG_VERSION")
221                    && s.header.pointer_width as usize == std::mem::size_of::<usize>()
222                    && s.header.format_version == SHARD_FORMAT_VERSION =>
223            {
224                s
225            }
226            _ => fresh_shard(),
227        };
228        let bin_mtime = current_binary_mtime_secs().unwrap_or(0);
229        let now = now_secs();
230        for (name, chunk_blob) in entries {
231            shard.entries.insert(
232                name,
233                AutoloadEntry {
234                    binary_mtime_at_cache: bin_mtime,
235                    cached_at_secs: now,
236                    chunk_blob,
237                },
238            );
239        }
240        shard.header.built_at_secs = now as u64;
241        write_shard_atomic(&self.path, &shard)?;
242        self.invalidate_mmap();
243        Ok(())
244    }
245
246    /// Replace the entire shard with the given entries. Used by compinit's
247    /// bulk pre-warm — accumulate all (name, chunk_blob) pairs, then commit
248    /// once. Avoids re-serializing 16k entries on every batch flush.
249    pub fn replace_all(&self, entries: HashMap<String, Vec<u8>>) -> Result<(), String> {
250        let _lock = match acquire_lock(&self.lock_path) {
251            Some(l) => l,
252            None => return Ok(()),
253        };
254        let bin_mtime = current_binary_mtime_secs().unwrap_or(0);
255        let now = now_secs();
256        let mut shard = fresh_shard();
257        for (name, chunk_blob) in entries {
258            shard.entries.insert(
259                name,
260                AutoloadEntry {
261                    binary_mtime_at_cache: bin_mtime,
262                    cached_at_secs: now,
263                    chunk_blob,
264                },
265            );
266        }
267        write_shard_atomic(&self.path, &shard)?;
268        self.invalidate_mmap();
269        Ok(())
270    }
271
272    pub fn entry_count(&self) -> usize {
273        self.ensure_mmap();
274        let guard = self.mmap.lock();
275        guard.as_ref().map(|s| s.shard().entries.len()).unwrap_or(0)
276    }
277
278    /// Set of cached function names — caller can subtract this from "all
279    /// known autoload names" to compute the missing-bytecode set without a
280    /// SQL JOIN.
281    pub fn cached_names(&self) -> std::collections::HashSet<String> {
282        self.ensure_mmap();
283        let guard = self.mmap.lock();
284        let Some(shard) = guard.as_ref() else {
285            return std::collections::HashSet::new();
286        };
287        shard
288            .shard()
289            .entries
290            .keys()
291            .map(|k| k.as_str().to_string())
292            .collect()
293    }
294
295    pub fn stats(&self) -> (i64, i64) {
296        self.ensure_mmap();
297        let guard = self.mmap.lock();
298        let Some(shard) = guard.as_ref() else {
299            return (0, 0);
300        };
301        let count = shard.shard().entries.len() as i64;
302        let bytes: i64 = shard
303            .shard()
304            .entries
305            .values()
306            .map(|e| e.chunk_blob.len() as i64)
307            .sum();
308        (count, bytes)
309    }
310
311    pub fn clear(&self) -> std::io::Result<()> {
312        let _lock = acquire_lock(&self.lock_path);
313        let res = match std::fs::remove_file(&self.path) {
314            Ok(()) => Ok(()),
315            Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(()),
316            Err(e) => Err(e),
317        };
318        self.invalidate_mmap();
319        res
320    }
321}
322
323fn acquire_lock(path: &Path) -> Option<nix::fcntl::Flock<File>> {
324    let f = File::options()
325        .read(true)
326        .write(true)
327        .create(true)
328        .truncate(false)
329        .open(path)
330        .ok()?;
331    nix::fcntl::Flock::lock(f, nix::fcntl::FlockArg::LockExclusive).ok()
332}
333
334fn fresh_shard() -> AutoloadShard {
335    AutoloadShard {
336        header: ShardHeader {
337            magic: SHARD_MAGIC,
338            format_version: SHARD_FORMAT_VERSION,
339            zshrs_version: env!("CARGO_PKG_VERSION").to_string(),
340            pointer_width: std::mem::size_of::<usize>() as u32,
341            built_at_secs: now_secs() as u64,
342        },
343        entries: HashMap::new(),
344    }
345}
346
347fn read_owned_shard(path: &Path) -> Option<AutoloadShard> {
348    let bytes = std::fs::read(path).ok()?;
349    let archived = rkyv::check_archived_root::<AutoloadShard>(&bytes[..]).ok()?;
350    archived.deserialize(&mut rkyv::Infallible).ok()
351}
352
353fn write_shard_atomic(path: &Path, shard: &AutoloadShard) -> Result<(), String> {
354    let bytes = rkyv::to_bytes::<_, 4096>(shard)
355        .map_err(|e| format!("rkyv serialize: {}", e))?;
356    let parent = path.parent().expect("cache path has parent");
357    let _ = std::fs::create_dir_all(parent);
358    let pid = std::process::id();
359    let nanos = SystemTime::now()
360        .duration_since(UNIX_EPOCH)
361        .map(|d| d.as_nanos())
362        .unwrap_or(0);
363    let tmp_path = parent.join(format!(
364        "{}.tmp.{}.{}",
365        path.file_name()
366            .and_then(|s| s.to_str())
367            .unwrap_or("autoloads.rkyv"),
368        pid,
369        nanos
370    ));
371    {
372        let mut f = File::create(&tmp_path).map_err(|e| e.to_string())?;
373        f.write_all(&bytes).map_err(|e| e.to_string())?;
374        f.sync_all().map_err(|e| e.to_string())?;
375    }
376    std::fs::rename(&tmp_path, path).map_err(|e| e.to_string())?;
377    Ok(())
378}
379
/// Current Unix time in whole seconds, or 0 if the clock is before the epoch.
fn now_secs() -> i64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs() as i64,
        Err(_) => 0,
    }
}
386
/// `(seconds, nanoseconds)` modification time of `path`, following symlinks;
/// `None` if the metadata cannot be read.
fn file_mtime(path: &Path) -> Option<(i64, i64)> {
    std::fs::metadata(path)
        .ok()
        .map(|meta| (meta.mtime(), meta.mtime_nsec()))
}
391
392fn current_binary_mtime_secs() -> Option<i64> {
393    static BIN_MTIME: OnceLock<Option<i64>> = OnceLock::new();
394    *BIN_MTIME.get_or_init(|| {
395        let exe = std::env::current_exe().ok()?;
396        let (secs, _) = file_mtime(&exe)?;
397        Some(secs)
398    })
399}
400
401pub fn default_cache_path() -> PathBuf {
402    dirs::home_dir()
403        .unwrap_or_else(|| PathBuf::from("/tmp"))
404        .join(".cache/zshrs/autoloads.rkyv")
405}
406
/// Whether the autoload cache is enabled. It is disabled only when
/// `ZSHRS_CACHE` is exactly `0`, `false` or `no`; unset or any other value enables it.
pub fn cache_enabled() -> bool {
    match std::env::var("ZSHRS_CACHE") {
        Ok(value) => !matches!(value.as_str(), "0" | "false" | "no"),
        Err(_) => true,
    }
}
413
414pub static CACHE: once_cell::sync::Lazy<Option<AutoloadCache>> =
415    once_cell::sync::Lazy::new(|| {
416        if !cache_enabled() {
417            return None;
418        }
419        AutoloadCache::open(&default_cache_path()).ok()
420    });
421
422pub fn try_load(name: &str) -> Option<Vec<u8>> {
423    let cache = CACHE.as_ref()?;
424    cache.get(name)
425}
426
427pub fn try_save_one(name: &str, chunk_blob: &[u8]) -> Result<(), String> {
428    let Some(cache) = CACHE.as_ref() else {
429        return Ok(());
430    };
431    cache.put_one(name, chunk_blob.to_vec())
432}
433
434/// Replace the entire autoload shard with the given entries. Use this from
435/// compinit's bulk pre-warm path — accumulates all `(name, chunk_blob)` in
436/// the `entries` HashMap and writes the shard exactly once.
437pub fn try_replace_all(entries: HashMap<String, Vec<u8>>) -> Result<(), String> {
438    let Some(cache) = CACHE.as_ref() else {
439        return Ok(());
440    };
441    cache.replace_all(entries)
442}
443
444/// Merge new entries into the existing shard. Use this from the compinit
445/// BACKFILL path (existing shard has most entries, just adding the missing
446/// ones).
447pub fn try_merge_in(entries: HashMap<String, Vec<u8>>) -> Result<(), String> {
448    let Some(cache) = CACHE.as_ref() else {
449        return Ok(());
450    };
451    cache.merge_in(entries)
452}
453
454pub fn cached_names() -> std::collections::HashSet<String> {
455    CACHE
456        .as_ref()
457        .map(|c| c.cached_names())
458        .unwrap_or_default()
459}
460
461pub fn entry_count() -> usize {
462    CACHE.as_ref().map(|c| c.entry_count()).unwrap_or(0)
463}
464
465pub fn stats() -> Option<(i64, i64)> {
466    CACHE.as_ref().map(|c| c.stats())
467}
468
469pub fn clear() -> bool {
470    CACHE.as_ref().map(|c| c.clear().is_ok()).unwrap_or(false)
471}
472
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    #[test]
    fn round_trip_one() {
        let dir = tempdir().unwrap();
        let cache_path = dir.path().join("autoloads.rkyv");
        let cache = AutoloadCache::open(&cache_path).unwrap();
        cache.put_one("foo", vec![1, 2, 3]).unwrap();
        assert_eq!(cache.get("foo"), Some(vec![1, 2, 3]));
        assert_eq!(cache.entry_count(), 1);
    }

    #[test]
    fn replace_all_overwrites() {
        let dir = tempdir().unwrap();
        let cache_path = dir.path().join("autoloads.rkyv");
        let cache = AutoloadCache::open(&cache_path).unwrap();
        cache.put_one("a", vec![10]).unwrap();
        cache.put_one("b", vec![20]).unwrap();
        assert_eq!(cache.entry_count(), 2);

        let mut new_entries = HashMap::new();
        new_entries.insert("c".to_string(), vec![30]);
        new_entries.insert("d".to_string(), vec![40]);
        cache.replace_all(new_entries).unwrap();

        assert_eq!(cache.entry_count(), 2);
        assert!(cache.get("a").is_none());
        assert!(cache.get("b").is_none());
        assert_eq!(cache.get("c"), Some(vec![30]));
        assert_eq!(cache.get("d"), Some(vec![40]));
    }

    #[test]
    fn merge_in_keeps_existing_entries() {
        let dir = tempdir().unwrap();
        let cache_path = dir.path().join("autoloads.rkyv");
        let cache = AutoloadCache::open(&cache_path).unwrap();
        cache.put_one("keep", vec![1]).unwrap();
        cache.put_one("replace", vec![2]).unwrap();

        let mut extra = HashMap::new();
        extra.insert("replace".to_string(), vec![22]);
        extra.insert("added".to_string(), vec![3]);
        cache.merge_in(extra).unwrap();

        assert_eq!(cache.entry_count(), 3);
        assert_eq!(cache.get("keep"), Some(vec![1]));
        assert_eq!(cache.get("replace"), Some(vec![22]));
        assert_eq!(cache.get("added"), Some(vec![3]));
    }

    #[test]
    fn merge_in_empty_does_not_create_shard() {
        let dir = tempdir().unwrap();
        let cache_path = dir.path().join("autoloads.rkyv");
        let cache = AutoloadCache::open(&cache_path).unwrap();
        cache.merge_in(HashMap::new()).unwrap();
        assert!(!cache_path.exists());
        assert_eq!(cache.entry_count(), 0);
    }

    #[test]
    fn incompatible_version_is_ignored_and_replaced() {
        let dir = tempdir().unwrap();
        let cache_path = dir.path().join("autoloads.rkyv");
        let mut stale = fresh_shard();
        stale.header.zshrs_version = "0.0.0-not-this-build".to_string();
        stale.entries.insert(
            "old".to_string(),
            AutoloadEntry {
                // Never stale by mtime, so only the version check can reject it.
                binary_mtime_at_cache: i64::MAX,
                cached_at_secs: 0,
                chunk_blob: vec![9],
            },
        );
        write_shard_atomic(&cache_path, &stale).unwrap();

        let cache = AutoloadCache::open(&cache_path).unwrap();
        assert!(cache.get("old").is_none());

        cache.put_one("new", vec![7]).unwrap();
        assert_eq!(cache.entry_count(), 1);
        assert_eq!(cache.get("new"), Some(vec![7]));
    }

    #[test]
    fn cached_names_returns_keys() {
        let dir = tempdir().unwrap();
        let cache_path = dir.path().join("autoloads.rkyv");
        let cache = AutoloadCache::open(&cache_path).unwrap();
        cache.put_one("alpha", vec![1]).unwrap();
        cache.put_one("beta", vec![2]).unwrap();
        let names = cache.cached_names();
        assert!(names.contains("alpha"));
        assert!(names.contains("beta"));
        assert_eq!(names.len(), 2);
    }

    #[test]
    fn corrupt_shard_returns_none() {
        let dir = tempdir().unwrap();
        let cache_path = dir.path().join("autoloads.rkyv");
        std::fs::write(&cache_path, b"garbage").unwrap();
        let cache = AutoloadCache::open(&cache_path).unwrap();
        assert!(cache.get("anything").is_none());
        assert_eq!(cache.entry_count(), 0);
    }

    #[test]
    fn clear_removes_file() {
        let dir = tempdir().unwrap();
        let cache_path = dir.path().join("autoloads.rkyv");
        let cache = AutoloadCache::open(&cache_path).unwrap();
        cache.put_one("x", vec![1]).unwrap();
        assert!(cache_path.exists());
        cache.clear().unwrap();
        assert!(!cache_path.exists());
    }

    #[test]
    fn clear_missing_file_is_ok() {
        let dir = tempdir().unwrap();
        let cache_path = dir.path().join("autoloads.rkyv");
        let cache = AutoloadCache::open(&cache_path).unwrap();
        assert!(!cache_path.exists());
        cache.clear().unwrap();
        assert!(cache.get("x").is_none());
    }
}