use std::collections::HashMap;
use std::fs::File;
use std::io::Write as IoWrite;
use std::path::{Path, PathBuf};
use std::sync::OnceLock;
use std::time::{SystemTime, UNIX_EPOCH};
use memmap2::Mmap;
use parking_lot::Mutex;
use rkyv::{Archive, Deserialize as RkyvDeserialize, Serialize as RkyvSerialize};
use std::os::unix::fs::MetadataExt;
/// Magic number expected in `ShardHeader::magic`; read big-endian, its bytes
/// 0x5A 0x52 0x41 0x4C are the ASCII string "ZRAL". Readers reject any shard whose
/// header carries a different value.
pub const SHARD_MAGIC: u32 = 0x5A52414C;
/// Version of the on-disk shard layout. Shards recording a different value are ignored
/// by readers and replaced by a fresh, empty shard on the next write.
pub const SHARD_FORMAT_VERSION: u32 = 1;
/// Metadata stored in every shard, used to decide whether a shard written earlier is
/// compatible with the running build.
///
/// NOTE: field order is part of the archived on-disk layout; reordering fields requires
/// bumping [`SHARD_FORMAT_VERSION`].
#[derive(Archive, RkyvDeserialize, RkyvSerialize, Debug, Clone)]
#[archive(check_bytes)]
pub struct ShardHeader {
/// Must equal [`SHARD_MAGIC`] for the shard to be read.
pub magic: u32,
/// Must equal [`SHARD_FORMAT_VERSION`] for the shard to be read or extended.
pub format_version: u32,
/// `CARGO_PKG_VERSION` of the build that wrote the shard; must match the running build.
pub zshrs_version: String,
/// `size_of::<usize>()` (in bytes) of the writing build; must match the running build.
pub pointer_width: u32,
/// Seconds since the UNIX epoch when the shard was created or last rewritten
/// (0 if the system clock read earlier than the epoch).
pub built_at_secs: u64,
}
/// One cached autoload, stored by name in [`AutoloadShard::entries`].
///
/// NOTE: field order is part of the archived on-disk layout.
#[derive(Archive, RkyvDeserialize, RkyvSerialize, Debug, Clone)]
#[archive(check_bytes)]
pub struct AutoloadEntry {
/// Modification time (whole seconds) of the executable that wrote this entry, or 0 if it
/// could not be determined. Lookups treat the entry as stale when this is older than the
/// running executable's mtime.
pub binary_mtime_at_cache: i64,
/// Seconds since the UNIX epoch when the entry was inserted.
pub cached_at_secs: i64,
/// Opaque payload supplied by the caller and returned verbatim on lookup
/// (presumably a serialized compiled chunk — the format is owned by the caller; confirm).
pub chunk_blob: Vec<u8>,
}
/// Root object of a shard file: a compatibility header plus every cached entry.
///
/// NOTE: field order is part of the archived on-disk layout.
#[derive(Archive, RkyvDeserialize, RkyvSerialize, Debug, Clone)]
#[archive(check_bytes)]
pub struct AutoloadShard {
/// Compatibility metadata, checked before any entry is served.
pub header: ShardHeader,
/// Entries keyed by the autoload name passed to the cache's put/merge/get calls.
pub entries: HashMap<String, AutoloadEntry>,
}
/// A read-only memory mapping of a shard file whose rkyv archive was validated on open.
///
/// Only the mapping itself is stored; the archived root is re-derived from the mapped bytes
/// on each access. That keeps the type free of self-referential raw pointers, so it is
/// `Send`/`Sync` exactly when [`Mmap`] is, without hand-written `unsafe impl`s.
pub struct MmappedShard {
    mmap: Mmap,
}

impl MmappedShard {
    /// Maps `path` and validates the bytes as an archived [`AutoloadShard`].
    ///
    /// Returns `None` if the file cannot be opened or mapped, or if rkyv validation fails
    /// (for example a truncated or garbage file).
    pub fn open(path: &Path) -> Option<Self> {
        let file = File::open(path).ok()?;
        // SAFETY: the mapping is read-only. This module only ever replaces the shard file via
        // `rename`, so the inode mapped here is not rewritten in place by our own writers;
        // external in-place modification of the file is not guarded against.
        let mmap = unsafe { Mmap::map(&file).ok()? };
        rkyv::check_archived_root::<AutoloadShard>(&mmap[..]).ok()?;
        Some(Self { mmap })
    }

    /// Borrows the archived shard root.
    fn shard(&self) -> &ArchivedAutoloadShard {
        // SAFETY: `open` validated exactly these bytes with `check_archived_root`, and the
        // mapping cannot change through `&self`, so the unchecked root lookup (which uses the
        // same root position as the checked one) yields a valid archived value.
        unsafe { rkyv::archived_root::<AutoloadShard>(&self.mmap[..]) }
    }

    /// Whether the header matches this build. The magic, format version, pointer width and
    /// crate version must all agree.
    fn header_ok(&self) -> bool {
        let header = &self.shard().header;
        let magic: u32 = header.magic.into();
        let format_version: u32 = header.format_version.into();
        let pointer_width: u32 = header.pointer_width.into();
        magic == SHARD_MAGIC
            && format_version == SHARD_FORMAT_VERSION
            && pointer_width as usize == std::mem::size_of::<usize>()
            && header.zshrs_version.as_str() == env!("CARGO_PKG_VERSION")
    }

    /// Looks up the archived entry stored under `name`, if any.
    fn lookup(&self, name: &str) -> Option<&ArchivedAutoloadEntry> {
        self.shard().entries.get(name)
    }
}
/// Handle to one on-disk autoload shard.
///
/// Reads go through a lazily created, read-only memory mapping. Writes rebuild the whole
/// shard while holding an exclusive lock file and atomically replace the shard file.
pub struct AutoloadCache {
/// Location of the shard file.
path: PathBuf,
/// `<shard file name>.lock` beside the shard; writers and `clear` hold an exclusive
/// `flock` on it while they run.
lock_path: PathBuf,
/// Cached mapping of the shard: opened on the first read, reset to `None` after every
/// write or clear so the next read maps the replacement file.
mmap: Mutex<Option<MmappedShard>>,
}
impl AutoloadCache {
pub fn open(path: &Path) -> std::io::Result<Self> {
if let Some(parent) = path.parent() {
std::fs::create_dir_all(parent)?;
}
let parent = path.parent().unwrap_or_else(|| Path::new("/tmp"));
let lock_path = parent.join(format!(
"{}.lock",
path.file_name()
.and_then(|s| s.to_str())
.unwrap_or("autoloads.rkyv")
));
Ok(Self {
path: path.to_path_buf(),
lock_path,
mmap: Mutex::new(None),
})
}
fn ensure_mmap(&self) {
let mut guard = self.mmap.lock();
if guard.is_none() {
*guard = MmappedShard::open(&self.path);
}
}
fn invalidate_mmap(&self) {
let mut guard = self.mmap.lock();
*guard = None;
}
pub fn get(&self, name: &str) -> Option<Vec<u8>> {
self.ensure_mmap();
let guard = self.mmap.lock();
let shard = guard.as_ref()?;
if !shard.header_ok() {
return None;
}
let entry = shard.lookup(name)?;
if let Some(bin_mtime) = current_binary_mtime_secs() {
let cached_bin_mtime: i64 = entry.binary_mtime_at_cache.into();
if cached_bin_mtime < bin_mtime {
return None;
}
}
Some(entry.chunk_blob.as_slice().to_vec())
}
pub fn put_one(&self, name: &str, chunk_blob: Vec<u8>) -> Result<(), String> {
let _lock = match acquire_lock(&self.lock_path) {
Some(l) => l,
None => return Ok(()),
};
let mut shard = match read_owned_shard(&self.path) {
Some(s)
if s.header.zshrs_version == env!("CARGO_PKG_VERSION")
&& s.header.pointer_width as usize == std::mem::size_of::<usize>()
&& s.header.format_version == SHARD_FORMAT_VERSION =>
{
s
}
_ => fresh_shard(),
};
let bin_mtime = current_binary_mtime_secs().unwrap_or(0);
shard.entries.insert(
name.to_string(),
AutoloadEntry {
binary_mtime_at_cache: bin_mtime,
cached_at_secs: now_secs(),
chunk_blob,
},
);
shard.header.built_at_secs = now_secs() as u64;
write_shard_atomic(&self.path, &shard)?;
self.invalidate_mmap();
Ok(())
}
pub fn merge_in(&self, entries: HashMap<String, Vec<u8>>) -> Result<(), String> {
if entries.is_empty() {
return Ok(());
}
let _lock = match acquire_lock(&self.lock_path) {
Some(l) => l,
None => return Ok(()),
};
let mut shard = match read_owned_shard(&self.path) {
Some(s)
if s.header.zshrs_version == env!("CARGO_PKG_VERSION")
&& s.header.pointer_width as usize == std::mem::size_of::<usize>()
&& s.header.format_version == SHARD_FORMAT_VERSION =>
{
s
}
_ => fresh_shard(),
};
let bin_mtime = current_binary_mtime_secs().unwrap_or(0);
let now = now_secs();
for (name, chunk_blob) in entries {
shard.entries.insert(
name,
AutoloadEntry {
binary_mtime_at_cache: bin_mtime,
cached_at_secs: now,
chunk_blob,
},
);
}
shard.header.built_at_secs = now as u64;
write_shard_atomic(&self.path, &shard)?;
self.invalidate_mmap();
Ok(())
}
pub fn replace_all(&self, entries: HashMap<String, Vec<u8>>) -> Result<(), String> {
let _lock = match acquire_lock(&self.lock_path) {
Some(l) => l,
None => return Ok(()),
};
let bin_mtime = current_binary_mtime_secs().unwrap_or(0);
let now = now_secs();
let mut shard = fresh_shard();
for (name, chunk_blob) in entries {
shard.entries.insert(
name,
AutoloadEntry {
binary_mtime_at_cache: bin_mtime,
cached_at_secs: now,
chunk_blob,
},
);
}
write_shard_atomic(&self.path, &shard)?;
self.invalidate_mmap();
Ok(())
}
pub fn entry_count(&self) -> usize {
self.ensure_mmap();
let guard = self.mmap.lock();
guard.as_ref().map(|s| s.shard().entries.len()).unwrap_or(0)
}
pub fn cached_names(&self) -> std::collections::HashSet<String> {
self.ensure_mmap();
let guard = self.mmap.lock();
let Some(shard) = guard.as_ref() else {
return std::collections::HashSet::new();
};
shard
.shard()
.entries
.keys()
.map(|k| k.as_str().to_string())
.collect()
}
pub fn stats(&self) -> (i64, i64) {
self.ensure_mmap();
let guard = self.mmap.lock();
let Some(shard) = guard.as_ref() else {
return (0, 0);
};
let count = shard.shard().entries.len() as i64;
let bytes: i64 = shard
.shard()
.entries
.values()
.map(|e| e.chunk_blob.len() as i64)
.sum();
(count, bytes)
}
pub fn clear(&self) -> std::io::Result<()> {
let _lock = acquire_lock(&self.lock_path);
let res = match std::fs::remove_file(&self.path) {
Ok(()) => Ok(()),
Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(()),
Err(e) => Err(e),
};
self.invalidate_mmap();
res
}
}
/// Opens the lock file at `path` (creating it if needed, never truncating it) and blocks
/// until an exclusive `flock` is held on it.
///
/// Returns `None` if the file cannot be opened or locked. The lock is held for as long as
/// the returned value lives (nix's `Flock` unlocks when dropped).
fn acquire_lock(path: &Path) -> Option<nix::fcntl::Flock<File>> {
    let mut options = File::options();
    options.read(true).write(true).create(true).truncate(false);
    let lock_file = options.open(path).ok()?;
    match nix::fcntl::Flock::lock(lock_file, nix::fcntl::FlockArg::LockExclusive) {
        Ok(held) => Some(held),
        Err(_) => None,
    }
}
/// Builds an empty shard whose header describes the running build
/// (magic, format version, crate version, pointer width, current time).
fn fresh_shard() -> AutoloadShard {
    let header = ShardHeader {
        magic: SHARD_MAGIC,
        format_version: SHARD_FORMAT_VERSION,
        zshrs_version: String::from(env!("CARGO_PKG_VERSION")),
        pointer_width: std::mem::size_of::<usize>() as u32,
        built_at_secs: now_secs() as u64,
    };
    AutoloadShard {
        header,
        entries: HashMap::new(),
    }
}
/// Reads the shard at `path` into an owned, mutable [`AutoloadShard`].
///
/// The bytes are copied into an [`rkyv::AlignedVec`] before validation. rkyv checks that
/// the archive is suitably aligned in memory, but the plain `Vec<u8>` returned by `fs::read`
/// only guarantees byte alignment. Without the copy, an unluckily placed buffer would make a
/// valid shard look corrupt, and writers would then discard every existing entry.
///
/// Returns `None` if the file is missing, unreadable, or fails validation.
fn read_owned_shard(path: &Path) -> Option<AutoloadShard> {
    let raw = std::fs::read(path).ok()?;
    let mut bytes = rkyv::AlignedVec::with_capacity(raw.len());
    bytes.extend_from_slice(&raw);
    let archived = rkyv::check_archived_root::<AutoloadShard>(bytes.as_slice()).ok()?;
    archived.deserialize(&mut rkyv::Infallible).ok()
}
/// Serialises `shard` and atomically replaces the file at `path` with it.
///
/// The archive is written and fsynced to a uniquely named temp file
/// (`<name>.tmp.<pid>.<nanos>`) in the same directory, then renamed over `path`. Readers
/// therefore see either the old shard or the new one, never a partial file. On any failure
/// the temp file is removed rather than left behind in the cache directory.
///
/// # Errors
/// Returns a message in any of these cases:
/// - serialisation fails;
/// - `path` has no parent directory;
/// - creating, writing, syncing or renaming the file fails.
fn write_shard_atomic(path: &Path, shard: &AutoloadShard) -> Result<(), String> {
    let bytes = rkyv::to_bytes::<_, 4096>(shard)
        .map_err(|e| format!("rkyv serialize: {}", e))?;
    let parent = path
        .parent()
        .ok_or_else(|| format!("cache path {} has no parent directory", path.display()))?;
    // Best effort: if this fails, File::create below reports the underlying error.
    let _ = std::fs::create_dir_all(parent);
    let pid = std::process::id();
    let nanos = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|d| d.as_nanos())
        .unwrap_or(0);
    let tmp_path = parent.join(format!(
        "{}.tmp.{}.{}",
        path.file_name()
            .and_then(|s| s.to_str())
            .unwrap_or("autoloads.rkyv"),
        pid,
        nanos
    ));
    let write_tmp = || -> std::io::Result<()> {
        let mut f = File::create(&tmp_path)?;
        f.write_all(&bytes)?;
        f.sync_all()
    };
    if let Err(e) = write_tmp().and_then(|()| std::fs::rename(&tmp_path, path)) {
        // Don't leak temp files. Cleanup errors are ignored: the file may never have been created.
        let _ = std::fs::remove_file(&tmp_path);
        return Err(e.to_string());
    }
    Ok(())
}
/// Current wall-clock time in whole seconds since the UNIX epoch.
///
/// Returns 0 if the system clock reads earlier than the epoch.
fn now_secs() -> i64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs() as i64,
        Err(_) => 0,
    }
}
/// Modification time of `path` (following symlinks) as the pair
/// `(whole seconds since the UNIX epoch, nanosecond remainder)`.
///
/// Returns `None` if the metadata cannot be read.
fn file_mtime(path: &Path) -> Option<(i64, i64)> {
    std::fs::metadata(path)
        .ok()
        .map(|meta| (meta.mtime(), meta.mtime_nsec()))
}
/// Modification time, in whole seconds, of the running executable.
///
/// The value is resolved once per process and then reused. Returns `None` if the
/// executable path or its metadata is unavailable.
fn current_binary_mtime_secs() -> Option<i64> {
    static BINARY_MTIME: OnceLock<Option<i64>> = OnceLock::new();
    *BINARY_MTIME.get_or_init(|| {
        std::env::current_exe()
            .ok()
            .and_then(|exe| file_mtime(&exe))
            .map(|(secs, _nanos)| secs)
    })
}
/// Default shard location: `$HOME/.cache/zshrs/autoloads.rkyv`.
///
/// If no home directory can be determined, `/tmp/.cache/zshrs/autoloads.rkyv` is used.
pub fn default_cache_path() -> PathBuf {
    let base = match dirs::home_dir() {
        Some(home) => home,
        None => PathBuf::from("/tmp"),
    };
    base.join(".cache/zshrs/autoloads.rkyv")
}
/// Reports whether the on-disk autoload cache should be used.
///
/// Caching is on unless `ZSHRS_CACHE` is set to `0`, `false`, `no` or `off`. The comparison
/// ignores ASCII case and surrounding whitespace, so `FALSE`, `No` and ` off ` also disable
/// the cache. An unset or non-UTF-8 value leaves caching enabled.
pub fn cache_enabled() -> bool {
    match std::env::var("ZSHRS_CACHE") {
        Ok(raw) => {
            let value = raw.trim();
            !["0", "false", "no", "off"]
                .iter()
                .any(|off| value.eq_ignore_ascii_case(off))
        }
        Err(_) => true,
    }
}
/// Process-wide cache stored at [`default_cache_path`], created on first access.
///
/// It is `None` when caching is disabled via `ZSHRS_CACHE`, or when the cache directory
/// cannot be created.
pub static CACHE: once_cell::sync::Lazy<Option<AutoloadCache>> =
    once_cell::sync::Lazy::new(|| {
        cache_enabled()
            .then(|| AutoloadCache::open(&default_cache_path()).ok())
            .flatten()
    });
/// Looks up `name` in the process-wide cache.
///
/// Returns `None` if caching is disabled or there is no usable entry.
pub fn try_load(name: &str) -> Option<Vec<u8>> {
    CACHE.as_ref().and_then(|cache| cache.get(name))
}

/// Stores a copy of `chunk_blob` under `name` in the process-wide cache.
///
/// This is a no-op returning `Ok(())` when caching is disabled.
pub fn try_save_one(name: &str, chunk_blob: &[u8]) -> Result<(), String> {
    match CACHE.as_ref() {
        Some(cache) => cache.put_one(name, chunk_blob.to_vec()),
        None => Ok(()),
    }
}

/// Replaces the whole process-wide shard with `entries`.
///
/// This is a no-op returning `Ok(())` when caching is disabled.
pub fn try_replace_all(entries: HashMap<String, Vec<u8>>) -> Result<(), String> {
    match CACHE.as_ref() {
        Some(cache) => cache.replace_all(entries),
        None => Ok(()),
    }
}

/// Adds `entries` to the process-wide shard, keeping its other entries.
///
/// This is a no-op returning `Ok(())` when caching is disabled.
pub fn try_merge_in(entries: HashMap<String, Vec<u8>>) -> Result<(), String> {
    CACHE
        .as_ref()
        .map_or(Ok(()), |cache| cache.merge_in(entries))
}
pub fn cached_names() -> std::collections::HashSet<String> {
CACHE
.as_ref()
.map(|c| c.cached_names())
.unwrap_or_default()
}
pub fn entry_count() -> usize {
CACHE.as_ref().map(|c| c.entry_count()).unwrap_or(0)
}
pub fn stats() -> Option<(i64, i64)> {
CACHE.as_ref().map(|c| c.stats())
}
pub fn clear() -> bool {
CACHE.as_ref().map(|c| c.clear().is_ok()).unwrap_or(false)
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    #[test]
    fn round_trip_one() {
        let dir = tempdir().unwrap();
        let cache_path = dir.path().join("autoloads.rkyv");
        let cache = AutoloadCache::open(&cache_path).unwrap();
        cache.put_one("foo", vec![1, 2, 3]).unwrap();
        assert_eq!(cache.get("foo"), Some(vec![1, 2, 3]));
        assert_eq!(cache.entry_count(), 1);
    }

    #[test]
    fn replace_all_overwrites() {
        let dir = tempdir().unwrap();
        let cache_path = dir.path().join("autoloads.rkyv");
        let cache = AutoloadCache::open(&cache_path).unwrap();
        cache.put_one("a", vec![10]).unwrap();
        cache.put_one("b", vec![20]).unwrap();
        assert_eq!(cache.entry_count(), 2);
        let mut new_entries = HashMap::new();
        new_entries.insert("c".to_string(), vec![30]);
        new_entries.insert("d".to_string(), vec![40]);
        cache.replace_all(new_entries).unwrap();
        assert_eq!(cache.entry_count(), 2);
        assert!(cache.get("a").is_none());
        assert!(cache.get("b").is_none());
        assert_eq!(cache.get("c"), Some(vec![30]));
        assert_eq!(cache.get("d"), Some(vec![40]));
    }

    #[test]
    fn cached_names_returns_keys() {
        let dir = tempdir().unwrap();
        let cache_path = dir.path().join("autoloads.rkyv");
        let cache = AutoloadCache::open(&cache_path).unwrap();
        cache.put_one("alpha", vec![1]).unwrap();
        cache.put_one("beta", vec![2]).unwrap();
        let names = cache.cached_names();
        assert!(names.contains("alpha"));
        assert!(names.contains("beta"));
        assert_eq!(names.len(), 2);
    }

    #[test]
    fn corrupt_shard_returns_none() {
        let dir = tempdir().unwrap();
        let cache_path = dir.path().join("autoloads.rkyv");
        std::fs::write(&cache_path, b"garbage").unwrap();
        let cache = AutoloadCache::open(&cache_path).unwrap();
        assert!(cache.get("anything").is_none());
        assert_eq!(cache.entry_count(), 0);
    }

    #[test]
    fn clear_removes_file() {
        let dir = tempdir().unwrap();
        let cache_path = dir.path().join("autoloads.rkyv");
        let cache = AutoloadCache::open(&cache_path).unwrap();
        cache.put_one("x", vec![1]).unwrap();
        assert!(cache_path.exists());
        cache.clear().unwrap();
        assert!(!cache_path.exists());
    }

    /// `merge_in` overwrites colliding names and keeps every other existing entry.
    #[test]
    fn merge_in_keeps_existing_entries() {
        let dir = tempdir().unwrap();
        let cache_path = dir.path().join("autoloads.rkyv");
        let cache = AutoloadCache::open(&cache_path).unwrap();
        cache.put_one("keep", vec![1]).unwrap();
        cache.put_one("overwrite", vec![5]).unwrap();
        let mut batch = HashMap::new();
        batch.insert("new".to_string(), vec![2, 3]);
        batch.insert("overwrite".to_string(), vec![9]);
        cache.merge_in(batch).unwrap();
        assert_eq!(cache.entry_count(), 3);
        assert_eq!(cache.get("keep"), Some(vec![1]));
        assert_eq!(cache.get("overwrite"), Some(vec![9]));
        assert_eq!(cache.get("new"), Some(vec![2, 3]));
        assert_eq!(cache.stats(), (3, 4));
    }

    /// An empty batch must not create or touch the shard file.
    #[test]
    fn merge_in_empty_is_noop() {
        let dir = tempdir().unwrap();
        let cache_path = dir.path().join("autoloads.rkyv");
        let cache = AutoloadCache::open(&cache_path).unwrap();
        cache.merge_in(HashMap::new()).unwrap();
        assert!(!cache_path.exists());
        assert_eq!(cache.entry_count(), 0);
    }

    /// A shard written by a different build is never served and is discarded on the next write.
    #[test]
    fn incompatible_version_is_ignored() {
        let dir = tempdir().unwrap();
        let cache_path = dir.path().join("autoloads.rkyv");
        let mut stale = fresh_shard();
        stale.header.zshrs_version = "0.0.0-not-this-build".to_string();
        stale.entries.insert(
            "old".to_string(),
            AutoloadEntry {
                // Far-future mtime so only the version mismatch can cause rejection.
                binary_mtime_at_cache: i64::MAX,
                cached_at_secs: 0,
                chunk_blob: vec![7],
            },
        );
        write_shard_atomic(&cache_path, &stale).unwrap();
        let cache = AutoloadCache::open(&cache_path).unwrap();
        assert!(cache.get("old").is_none());
        cache.put_one("fresh", vec![1]).unwrap();
        assert_eq!(cache.entry_count(), 1);
        assert!(cache.get("old").is_none());
        assert_eq!(cache.get("fresh"), Some(vec![1]));
    }
}