use std::collections::HashMap;
use std::fs::File;
use std::io::Write as IoWrite;
use std::path::{Path, PathBuf};
use std::sync::OnceLock;
use std::time::{SystemTime, UNIX_EPOCH};
use memmap2::Mmap;
use parking_lot::Mutex;
use rkyv::{Archive, Deserialize as RkyvDeserialize, Serialize as RkyvSerialize};
use std::os::unix::fs::MetadataExt;
/// Magic number stored in every shard header: the ASCII codes for
/// 'Z', 'R', 'S', 'C' written as a hex literal. `MmappedShard::header_ok`
/// rejects a mapped shard whose header carries a different value.
pub const SHARD_MAGIC: u32 = 0x5A525343;
/// Version of the on-disk shard layout. Shards recording a different value
/// are ignored by `get` (via `header_ok`) and replaced by a fresh shard in
/// `put`. NOTE(review): presumably bumped whenever `ScriptShard` or its
/// nested types change shape.
pub const SHARD_FORMAT_VERSION: u32 = 1;
/// Header written at the root of every shard file; used to decide whether a
/// shard was produced by a compatible build before any entry is trusted.
/// Field order is part of the archived format — do not reorder.
#[derive(Archive, RkyvDeserialize, RkyvSerialize, Debug, Clone)]
#[archive(check_bytes)]
pub struct ShardHeader {
/// Always `SHARD_MAGIC` for shards written by this module.
pub magic: u32,
/// `SHARD_FORMAT_VERSION` of the writer.
pub format_version: u32,
/// Crate version (`CARGO_PKG_VERSION`) of the writer.
pub zshrs_version: String,
/// `size_of::<usize>()` (in bytes) of the writer.
pub pointer_width: u32,
/// Unix seconds when the shard was created or last rewritten by `put`.
pub built_at_secs: u64,
}
/// One cached script: the caller-supplied blob plus the freshness data `get`
/// checks before returning it. Field order is part of the archived format.
#[derive(Archive, RkyvDeserialize, RkyvSerialize, Debug, Clone)]
#[archive(check_bytes)]
pub struct ScriptEntry {
/// Script file mtime, whole seconds (`MetadataExt::mtime`) at `put` time.
pub mtime_secs: i64,
/// Nanosecond part of the script mtime (`MetadataExt::mtime_nsec`).
pub mtime_nsecs: i64,
/// Mtime (seconds) of the running executable when the entry was written,
/// or 0 if it could not be determined. `get` rejects entries whose value is
/// older than the current executable's mtime.
pub binary_mtime_at_cache: i64,
/// Unix seconds when the entry was written; used for listing/sorting.
pub cached_at_secs: i64,
/// Opaque bytes passed to `put` and returned verbatim by `get`.
pub chunk_blob: Vec<u8>,
}
/// Root object of a shard file: a compatibility header plus all entries.
#[derive(Archive, RkyvDeserialize, RkyvSerialize, Debug, Clone)]
#[archive(check_bytes)]
pub struct ScriptShard {
/// Build/format identification checked before entries are used.
pub header: ShardHeader,
/// Entries keyed by the script path string given to `put` (the public
/// `try_*` helpers pass canonicalized paths).
pub entries: HashMap<String, ScriptEntry>,
}
/// A read-only memory map of a shard file together with a pointer to its
/// validated archived root, so reads avoid deserializing the whole shard.
pub struct MmappedShard {
/// Owns the mapping; `archived` points into it and must not outlive it.
_mmap: Mmap,
/// Archived root inside `_mmap`, produced by `check_archived_root` in `open`.
archived: *const ArchivedScriptShard,
}
// SAFETY: `archived` only ever points into the read-only mapping owned by
// `_mmap`, which moves together with this struct (moving `Mmap` does not move
// the mapped pages) and is itself `Send + Sync`. All methods take `&self` and
// only read through the pointer, so sending or sharing this value across
// threads cannot introduce data races.
unsafe impl Send for MmappedShard {}
unsafe impl Sync for MmappedShard {}
impl MmappedShard {
    /// Maps the shard at `path` read-only and validates it as an archived
    /// `ScriptShard`.
    ///
    /// Returns `None` if the file cannot be opened or mapped, or if rkyv's
    /// validation rejects the bytes. Header compatibility is not checked here;
    /// callers use `header_ok` for that.
    pub fn open(path: &Path) -> Option<Self> {
        let file = File::open(path).ok()?;
        // SAFETY: the map is read-only. Writers in this module replace the shard
        // via write-to-temp + rename, so the mapped inode is never rewritten in
        // place by them. NOTE(review): an external process truncating or editing
        // the file in place would break this assumption (SIGBUS / torn reads).
        let mapping = match unsafe { Mmap::map(&file) } {
            Ok(m) => m,
            Err(_) => return None,
        };
        let root = match rkyv::check_archived_root::<ScriptShard>(&mapping[..]) {
            Ok(validated) => validated as *const ArchivedScriptShard,
            Err(_) => return None,
        };
        Some(Self {
            _mmap: mapping,
            archived: root,
        })
    }

    /// Borrows the archived root for as long as `self` (and thus the map) lives.
    fn shard(&self) -> &ArchivedScriptShard {
        // SAFETY: `archived` was validated in `open` and points into `_mmap`,
        // which is owned by `self` and outlives the returned borrow.
        unsafe { &*self.archived }
    }

    /// Returns `true` when the header matches this build: magic, format
    /// version, pointer width and crate version must all agree.
    fn header_ok(&self) -> bool {
        let header = &self.shard().header;
        let magic: u32 = header.magic.into();
        if magic != SHARD_MAGIC {
            return false;
        }
        let format_version: u32 = header.format_version.into();
        if format_version != SHARD_FORMAT_VERSION {
            return false;
        }
        let width: u32 = header.pointer_width.into();
        if width as usize != std::mem::size_of::<usize>() {
            return false;
        }
        header.zshrs_version.as_str() == env!("CARGO_PKG_VERSION")
    }

    /// Finds the archived entry for `path`, if present.
    fn lookup(&self, path: &str) -> Option<&ArchivedScriptEntry> {
        let entries = &self.shard().entries;
        entries.get(path)
    }

    /// Number of entries stored in the shard.
    fn entry_count(&self) -> usize {
        let entries = &self.shard().entries;
        entries.len()
    }
}
/// On-disk script cache backed by a single rkyv shard file.
///
/// Reads go through a lazily created memory map. `put` and `evict_stale`
/// take an exclusive `flock` on `lock_path`, rewrite the whole shard and drop
/// the cached map; `clear` takes the same lock and deletes the shard.
pub struct ScriptCache {
/// Path of the shard file.
path: PathBuf,
/// Sidecar lock file used to serialize writers (across processes via flock).
lock_path: PathBuf,
/// Mapping of `path`; `None` until the first read and again after a write.
mmap: Mutex<Option<MmappedShard>>,
}
impl ScriptCache {
pub fn open(path: &Path) -> std::io::Result<Self> {
if let Some(parent) = path.parent() {
std::fs::create_dir_all(parent)?;
}
let parent = path.parent().unwrap_or_else(|| Path::new("/tmp"));
let lock_path = parent.join(format!(
"{}.lock",
path.file_name()
.and_then(|s| s.to_str())
.unwrap_or("scripts.rkyv")
));
Ok(Self {
path: path.to_path_buf(),
lock_path,
mmap: Mutex::new(None),
})
}
fn ensure_mmap(&self) {
let mut guard = self.mmap.lock();
if guard.is_none() {
*guard = MmappedShard::open(&self.path);
}
}
fn invalidate_mmap(&self) {
let mut guard = self.mmap.lock();
*guard = None;
}
pub fn get(&self, path: &str, mtime_secs: i64, mtime_nsecs: i64) -> Option<Vec<u8>> {
self.ensure_mmap();
let guard = self.mmap.lock();
let shard = guard.as_ref()?;
if !shard.header_ok() {
return None;
}
let entry = shard.lookup(path)?;
let entry_mtime_s: i64 = entry.mtime_secs.into();
let entry_mtime_ns: i64 = entry.mtime_nsecs.into();
if entry_mtime_s != mtime_secs || entry_mtime_ns != mtime_nsecs {
return None;
}
if let Some(bin_mtime) = current_binary_mtime_secs() {
let cached_bin_mtime: i64 = entry.binary_mtime_at_cache.into();
if cached_bin_mtime < bin_mtime {
return None;
}
}
Some(entry.chunk_blob.as_slice().to_vec())
}
pub fn put(
&self,
path: &str,
mtime_secs: i64,
mtime_nsecs: i64,
chunk_blob: Vec<u8>,
) -> Result<(), String> {
let _lock = match acquire_lock(&self.lock_path) {
Some(l) => l,
None => return Ok(()),
};
let mut shard = match read_owned_shard(&self.path) {
Some(s)
if s.header.zshrs_version == env!("CARGO_PKG_VERSION")
&& s.header.pointer_width as usize == std::mem::size_of::<usize>()
&& s.header.format_version == SHARD_FORMAT_VERSION =>
{
s
}
_ => fresh_shard(),
};
let bin_mtime = current_binary_mtime_secs().unwrap_or(0);
let entry = ScriptEntry {
mtime_secs,
mtime_nsecs,
binary_mtime_at_cache: bin_mtime,
cached_at_secs: now_secs(),
chunk_blob,
};
shard.entries.insert(path.to_string(), entry);
shard.header.built_at_secs = now_secs() as u64;
write_shard_atomic(&self.path, &shard)?;
self.invalidate_mmap();
Ok(())
}
pub fn stats(&self) -> (i64, i64) {
self.ensure_mmap();
let guard = self.mmap.lock();
let Some(shard) = guard.as_ref() else {
return (0, 0);
};
let count = shard.entry_count() as i64;
let bytes: i64 = shard
.shard()
.entries
.values()
.map(|e| e.chunk_blob.len() as i64)
.sum();
(count, bytes)
}
pub fn list_scripts(&self) -> Vec<(String, f64, String, String)> {
self.ensure_mmap();
let guard = self.mmap.lock();
let Some(shard) = guard.as_ref() else {
return Vec::new();
};
let v = shard.shard().header.zshrs_version.as_str().to_string();
let mut out: Vec<(String, f64, String, String, i64)> = shard
.shard()
.entries
.iter()
.map(|(k, e)| {
let chunk_kb = e.chunk_blob.len() as f64 / 1024.0;
let cached_at: i64 = e.cached_at_secs.into();
let ts = format_local_ts(cached_at);
(
k.as_str().to_string(),
chunk_kb,
v.clone(),
ts,
cached_at,
)
})
.collect();
out.sort_by_key(|x| std::cmp::Reverse(x.4));
out.into_iter()
.map(|(p, ck, ver, ts, _)| (p, ck, ver, ts))
.collect()
}
pub fn evict_stale(&self) -> usize {
let _lock = match acquire_lock(&self.lock_path) {
Some(l) => l,
None => return 0,
};
let mut shard = match read_owned_shard(&self.path) {
Some(s) => s,
None => return 0,
};
let before = shard.entries.len();
shard.entries.retain(|p, e| match file_mtime(Path::new(p)) {
Some((s, ns)) => s == e.mtime_secs && ns == e.mtime_nsecs,
None => false,
});
let evicted = before - shard.entries.len();
if evicted > 0 {
let _ = write_shard_atomic(&self.path, &shard);
self.invalidate_mmap();
}
evicted
}
pub fn clear(&self) -> std::io::Result<()> {
let _lock = acquire_lock(&self.lock_path);
let res = match std::fs::remove_file(&self.path) {
Ok(()) => Ok(()),
Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(()),
Err(e) => Err(e),
};
self.invalidate_mmap();
res
}
}
/// Opens (creating if needed, never truncating) `path` and takes a blocking
/// exclusive `flock` on it.
///
/// Returns `None` if the file cannot be opened or the lock call fails. The
/// lock is released when the returned nix `Flock` guard is dropped.
fn acquire_lock(path: &Path) -> Option<nix::fcntl::Flock<File>> {
    let mut options = File::options();
    options.read(true).write(true).create(true).truncate(false);
    let handle = options.open(path).ok()?;
    match nix::fcntl::Flock::lock(handle, nix::fcntl::FlockArg::LockExclusive) {
        Ok(guard) => Some(guard),
        Err(_) => None,
    }
}
/// Builds an empty shard whose header describes the current build.
fn fresh_shard() -> ScriptShard {
    let header = ShardHeader {
        magic: SHARD_MAGIC,
        format_version: SHARD_FORMAT_VERSION,
        zshrs_version: String::from(env!("CARGO_PKG_VERSION")),
        pointer_width: std::mem::size_of::<usize>() as u32,
        built_at_secs: now_secs() as u64,
    };
    ScriptShard {
        header,
        entries: HashMap::new(),
    }
}
/// Reads, validates and fully deserializes the shard at `path`.
///
/// Returns `None` if the file cannot be read, fails rkyv validation, or
/// cannot be deserialized. The header is not checked here.
fn read_owned_shard(path: &Path) -> Option<ScriptShard> {
    let raw = std::fs::read(path).ok()?;
    // `fs::read` returns a plain `Vec<u8>` with no alignment guarantee, but
    // rkyv requires the archived root to be aligned. On a misaligned buffer,
    // validation would fail and `put` would silently start a fresh shard.
    // Copy into rkyv's 16-byte-aligned buffer before validating.
    let mut aligned = rkyv::AlignedVec::with_capacity(raw.len());
    aligned.extend_from_slice(&raw);
    let archived = rkyv::check_archived_root::<ScriptShard>(&aligned[..]).ok()?;
    archived.deserialize(&mut rkyv::Infallible).ok()
}
/// Serializes `shard` and atomically replaces `path` with it.
///
/// The bytes go to a uniquely named temp file in the same directory, which
/// is fsynced and then renamed over `path`, so readers see either the old
/// file or the complete new one. On any failure the temp file is removed.
///
/// # Errors
/// Returns a message if:
/// - serialization fails;
/// - `path` has no parent directory;
/// - creating, writing, syncing or renaming the temp file fails.
fn write_shard_atomic(path: &Path, shard: &ScriptShard) -> Result<(), String> {
    let bytes = rkyv::to_bytes::<_, 4096>(shard)
        .map_err(|e| format!("rkyv serialize: {}", e))?;
    let parent = path
        .parent()
        .ok_or_else(|| format!("cache path has no parent: {}", path.display()))?;
    // Best effort: if this fails, creating the temp file reports the real error.
    let _ = std::fs::create_dir_all(parent);
    let nanos = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|d| d.as_nanos())
        .unwrap_or(0);
    // pid + nanos keep concurrent writers from sharing a temp file.
    let tmp_path = parent.join(format!(
        "{}.tmp.{}.{}",
        path.file_name()
            .and_then(|s| s.to_str())
            .unwrap_or("scripts.rkyv"),
        std::process::id(),
        nanos
    ));
    let outcome = write_tmp_synced(&tmp_path, &bytes)
        .and_then(|()| std::fs::rename(&tmp_path, path));
    if let Err(e) = outcome {
        // Don't leave orphaned temp files next to the shard.
        let _ = std::fs::remove_file(&tmp_path);
        return Err(e.to_string());
    }
    Ok(())
}

/// Creates `tmp`, writes all of `bytes` to it and fsyncs it before returning.
fn write_tmp_synced(tmp: &Path, bytes: &[u8]) -> std::io::Result<()> {
    let mut f = File::create(tmp)?;
    f.write_all(bytes)?;
    f.sync_all()
}
/// Current Unix time in whole seconds; 0 if the clock is before the epoch.
fn now_secs() -> i64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs() as i64,
        Err(_) => 0,
    }
}
fn format_local_ts(secs: i64) -> String {
let dt = chrono::DateTime::<chrono::Local>::from(
UNIX_EPOCH + std::time::Duration::from_secs(secs.max(0) as u64),
);
dt.format("%Y-%m-%d %H:%M:%S").to_string()
}
/// Returns `(mtime_secs, mtime_nsecs)` for `path` (following symlinks), or
/// `None` if its metadata cannot be read.
pub fn file_mtime(path: &Path) -> Option<(i64, i64)> {
    std::fs::metadata(path)
        .ok()
        .map(|md| (md.mtime(), md.mtime_nsec()))
}
/// Mtime (whole seconds) of the running executable.
///
/// Computed once per process and memoized, including a `None` result when
/// the executable path or its metadata is unavailable.
fn current_binary_mtime_secs() -> Option<i64> {
    static BIN_MTIME: OnceLock<Option<i64>> = OnceLock::new();
    let memo = BIN_MTIME.get_or_init(|| {
        std::env::current_exe()
            .ok()
            .and_then(|exe| file_mtime(&exe))
            .map(|(secs, _nsecs)| secs)
    });
    *memo
}
/// Default shard location: `~/.zshrs/scripts.rkyv`, with `/tmp` standing in
/// for the home directory when it cannot be determined.
pub fn default_cache_path() -> PathBuf {
    let base = match dirs::home_dir() {
        Some(home) => home,
        None => PathBuf::from("/tmp"),
    };
    base.join(".zshrs").join("scripts.rkyv")
}
/// Whether the script cache is on.
///
/// It is disabled only when `ZSHRS_CACHE` is exactly `0`, `false` or `no`.
/// Matching is case-sensitive; unset or non-Unicode values leave it enabled.
pub fn cache_enabled() -> bool {
    match std::env::var("ZSHRS_CACHE") {
        Ok(value) => !["0", "false", "no"].contains(&value.as_str()),
        Err(_) => true,
    }
}
/// Process-wide cache at `default_cache_path()`.
///
/// It is `None` when disabled via `ZSHRS_CACHE` or when the cache directory
/// cannot be created. It is initialized on first access.
pub static CACHE: once_cell::sync::Lazy<Option<ScriptCache>> = once_cell::sync::Lazy::new(|| {
    if cache_enabled() {
        ScriptCache::open(&default_cache_path()).ok()
    } else {
        None
    }
});
/// Looks up a cached blob for the script at `path`.
///
/// The path is canonicalized and its current mtime must match the cached
/// entry. Returns `None` in any of these cases:
/// - the cache is disabled;
/// - the path cannot be canonicalized or stat'ed;
/// - the path is not valid UTF-8;
/// - the entry is missing or stale.
pub fn try_load_bytes(path: &Path) -> Option<Vec<u8>> {
    let cache = CACHE.as_ref()?;
    let canonical = path.canonicalize().ok()?;
    // Keys must map back to the real file (`evict_stale` re-stats each key),
    // and lossy conversion can make distinct paths collide. Non-UTF-8 paths
    // are therefore treated as uncacheable.
    let key = canonical.to_str()?;
    let (mtime_s, mtime_ns) = file_mtime(&canonical)?;
    cache.get(key, mtime_s, mtime_ns)
}
/// Stores `chunk_blob` for the script at `path`, keyed by its canonical path
/// and current mtime.
///
/// Silently does nothing (returns `Ok(())`) in any of these cases:
/// - the cache is disabled;
/// - the path cannot be canonicalized or stat'ed;
/// - the path is not valid UTF-8 (see `try_load_bytes`).
///
/// NOTE: the mtime is sampled here, not when the caller read the script. If
/// the file changed in between, the blob is recorded against the newer mtime.
///
/// # Errors
/// Propagates the write error from `ScriptCache::put`.
pub fn try_save_bytes(path: &Path, chunk_blob: &[u8]) -> Result<(), String> {
    let Some(cache) = CACHE.as_ref() else {
        return Ok(());
    };
    let Ok(canonical) = path.canonicalize() else {
        return Ok(());
    };
    // Lossy keys would collide and never map back to the real file.
    let Some(key) = canonical.to_str() else {
        return Ok(());
    };
    let Some((mtime_s, mtime_ns)) = file_mtime(&canonical) else {
        return Ok(());
    };
    cache.put(key, mtime_s, mtime_ns, chunk_blob.to_vec())
}
/// `(entry_count, total_blob_bytes)` of the global cache, or `None` when
/// the cache is disabled.
pub fn stats() -> Option<(i64, i64)> {
    let cache = CACHE.as_ref()?;
    Some(cache.stats())
}
/// Evicts stale entries from the global cache; returns how many were
/// removed, or 0 when the cache is disabled.
pub fn evict_stale() -> usize {
    match CACHE.as_ref() {
        Some(cache) => cache.evict_stale(),
        None => 0,
    }
}
/// Deletes the global cache's shard file. Returns `true` on success and
/// `false` when the cache is disabled or the delete failed.
pub fn clear() -> bool {
    CACHE.as_ref().is_some_and(|cache| cache.clear().is_ok())
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Opens a cache whose shard lives at `<dir>/scripts.rkyv`.
    fn cache_in(dir: &Path) -> (ScriptCache, PathBuf) {
        let shard_path = dir.join("scripts.rkyv");
        let cache = ScriptCache::open(&shard_path).unwrap();
        (cache, shard_path)
    }

    /// Writes `body` to `<dir>/<name>`; returns (path string, mtime secs, mtime nsecs).
    fn make_script(dir: &Path, name: &str, body: &str) -> (String, i64, i64) {
        let script = dir.join(name);
        std::fs::write(&script, body).unwrap();
        let (secs, nsecs) = file_mtime(&script).unwrap();
        (script.to_string_lossy().into_owned(), secs, nsecs)
    }

    #[test]
    fn round_trip() {
        let dir = tempdir().unwrap();
        let (cache, _) = cache_in(dir.path());
        let (key, secs, nsecs) = make_script(dir.path(), "test.zsh", "echo hi");
        let blob = vec![1u8, 2, 3, 4, 5];
        cache.put(&key, secs, nsecs, blob.clone()).unwrap();
        assert_eq!(cache.get(&key, secs, nsecs).unwrap(), blob);
        let (count, _bytes) = cache.stats();
        assert_eq!(count, 1);
    }

    #[test]
    fn mtime_invalidation() {
        let dir = tempdir().unwrap();
        let (cache, _) = cache_in(dir.path());
        let (key, secs, nsecs) = make_script(dir.path(), "test.zsh", "echo hi");
        cache.put(&key, secs, nsecs, vec![9u8]).unwrap();
        assert!(cache.get(&key, secs + 1, nsecs).is_none());
    }

    /// A second `put` for a different script must keep the first entry.
    #[test]
    fn second_put_replaces_first() {
        let dir = tempdir().unwrap();
        let (cache, _) = cache_in(dir.path());
        let (a, a_s, a_ns) = make_script(dir.path(), "a.zsh", "1");
        let (b, b_s, b_ns) = make_script(dir.path(), "b.zsh", "2");
        cache.put(&a, a_s, a_ns, vec![1u8]).unwrap();
        cache.put(&b, b_s, b_ns, vec![2u8]).unwrap();
        let (count, _) = cache.stats();
        assert_eq!(count, 2);
        assert!(cache.get(&a, a_s, a_ns).is_some());
        assert!(cache.get(&b, b_s, b_ns).is_some());
    }

    #[test]
    fn corrupt_file_returns_no_mmap() {
        let dir = tempdir().unwrap();
        std::fs::write(
            dir.path().join("scripts.rkyv"),
            b"this is not a valid rkyv archive",
        )
        .unwrap();
        let (cache, _) = cache_in(dir.path());
        assert!(cache.get("/nope", 0, 0).is_none());
    }

    #[test]
    fn clear_removes_file() {
        let dir = tempdir().unwrap();
        let (cache, shard_path) = cache_in(dir.path());
        let (key, secs, nsecs) = make_script(dir.path(), "test.zsh", "echo hi");
        cache.put(&key, secs, nsecs, vec![7u8]).unwrap();
        assert!(shard_path.exists());
        cache.clear().unwrap();
        assert!(!shard_path.exists());
    }
}