use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::time::{SystemTime, UNIX_EPOCH};
use memmap2::Mmap;
use rkyv::{Archive, Deserialize, Serialize};
use super::{paths::CachePaths, DaemonError, Result};
/// Magic number that `Shard::new` stores in `ShardHeader::magic`.
/// Read as big-endian bytes, the value spells the ASCII string `ZSHS`.
pub const SHARD_MAGIC: u32 = 0x5A53_4853;
/// Format version that `Shard::new` stores in `ShardHeader::format_version`.
pub const SHARD_FORMAT_VERSION: u32 = 1;
/// Metadata stored alongside the entries of a [`Shard`].
///
/// The struct derives the rkyv `Archive`/`Serialize`/`Deserialize` traits with
/// `check_bytes`, so it can be validated when read back from disk.
#[derive(Archive, Deserialize, Serialize, Clone, Debug)]
#[archive(check_bytes)]
pub struct ShardHeader {
/// Set to [`SHARD_MAGIC`] by `Shard::new`.
pub magic: u32,
/// Set to [`SHARD_FORMAT_VERSION`] by `Shard::new`.
pub format_version: u32,
/// Generation number supplied by the caller of `Shard::new`.
pub generation: u64,
/// Wall-clock time at which `Shard::new` ran, in nanoseconds since the UNIX
/// epoch (0 if the system clock reads earlier than the epoch).
pub built_at_ns: u64,
/// Short name of the shard; it forms the suffix of the on-disk file name.
pub slug: String,
/// Source root the shard was built for; its hash forms the prefix of the
/// on-disk file name.
pub source_root: String,
/// Number of entries, refreshed by `Shard::insert`. It is NOT refreshed when
/// the public `entries` map is mutated directly.
pub entry_count: u32,
}
/// In-memory (owned) form of a shard: a header plus a name-to-bytes map.
///
/// This is the type handed to `write_shard`; the read side works on the
/// archived form (`ArchivedShard`) instead.
#[derive(Archive, Deserialize, Serialize, Clone, Debug)]
#[archive(check_bytes)]
pub struct Shard {
/// Descriptive metadata; see [`ShardHeader`].
pub header: ShardHeader,
/// Payload bytes keyed by name. The bytes are stored verbatim and are not
/// interpreted by this module (presumably compiled bytecode keyed by a
/// fully-qualified function name, judging by `Shard::insert`'s parameter
/// names; confirm with the producer).
pub entries: HashMap<String, Vec<u8>>,
}
impl Shard {
    /// Creates an empty shard for `slug` / `source_root` at the given `generation`.
    ///
    /// The header is stamped with [`SHARD_MAGIC`], [`SHARD_FORMAT_VERSION`] and the
    /// current wall-clock time; `entry_count` starts at zero.
    pub fn new(slug: impl Into<String>, source_root: impl Into<String>, generation: u64) -> Self {
        let header = ShardHeader {
            magic: SHARD_MAGIC,
            format_version: SHARD_FORMAT_VERSION,
            generation,
            built_at_ns: now_ns(),
            slug: slug.into(),
            source_root: source_root.into(),
            entry_count: 0,
        };
        Self {
            header,
            entries: HashMap::new(),
        }
    }

    /// Stores `bytecode` under `fq_name`, replacing any previous value for that
    /// name, and refreshes `header.entry_count` from the map size.
    pub fn insert(&mut self, fq_name: impl Into<String>, bytecode: Vec<u8>) {
        let key = fq_name.into();
        self.entries.insert(key, bytecode);
        // The header mirrors the map size; the cast truncates above u32::MAX entries.
        let count = self.entries.len();
        self.header.entry_count = count as u32;
    }

    /// Number of entries currently held in the map.
    pub fn len(&self) -> usize {
        let Self { entries, .. } = self;
        entries.len()
    }

    /// Returns `true` when the shard holds no entries.
    pub fn is_empty(&self) -> bool {
        let Self { entries, .. } = self;
        entries.is_empty()
    }
}
/// Returns the first four bytes of the SHA-256 digest of `source_root`,
/// rendered as eight lowercase hexadecimal characters.
pub fn hash8(source_root: &str) -> String {
    use sha2::{Digest, Sha256};
    use std::fmt::Write;

    let digest = Sha256::digest(source_root.as_bytes());
    let mut hex = String::with_capacity(8);
    for byte in digest.iter().take(4) {
        // Formatting into a `String` cannot fail, so the result is ignored.
        let _ = write!(hex, "{byte:02x}");
    }
    hex
}
/// Builds the on-disk file name of a shard: `<hash8(source_root)>-<slug>.rkyv`.
pub fn shard_filename(source_root: &str, slug: &str) -> String {
    let prefix = hash8(source_root);
    format!("{prefix}-{slug}.rkyv")
}
/// Full path of the shard file for `source_root` / `slug` inside `paths.images`.
pub fn shard_path(paths: &CachePaths, source_root: &str, slug: &str) -> PathBuf {
    let file_name = shard_filename(source_root, slug);
    paths.images.join(file_name)
}
pub fn shard_lock_path(paths: &CachePaths, source_root: &str, slug: &str) -> PathBuf {
paths
.images
.join(format!("{}-{}.rkyv.lock", hash8(source_root), slug))
}
/// Serializes `shard` with rkyv and publishes it at its canonical path under
/// `paths.images`, returning that path.
///
/// The bytes are first written and fsynced to a uniquely named sibling file
/// (`<shard file name>.tmp.<pid>.<nanos>`), which is then renamed over the
/// final path so readers never observe a partially written shard.
///
/// Permissions are tightened on the temporary file *before* the rename, so the
/// final path is never visible with looser permissions. If any step after the
/// temporary file is created fails, the temporary file is removed (best effort)
/// instead of being left for a later sweep.
///
/// Note: the containing directory is not fsynced, so the rename itself is not
/// guaranteed to survive a crash.
///
/// # Errors
///
/// Returns an error if serialization fails or if any filesystem step (create,
/// write, sync, permission change, rename) fails.
pub fn write_shard(paths: &CachePaths, shard: &Shard) -> Result<PathBuf> {
    let final_path = shard_path(paths, &shard.header.source_root, &shard.header.slug);
    let pid = std::process::id();
    let nanos = now_ns();
    let tmp_path = paths.images.join(format!(
        "{}.tmp.{}.{}",
        shard_filename(&shard.header.source_root, &shard.header.slug),
        pid,
        nanos
    ));

    // Serialize before touching the filesystem so a failure leaves nothing behind.
    let bytes = rkyv::to_bytes::<_, 4096>(shard)
        .map_err(|e| DaemonError::other(format!("rkyv serialize: {e}")))?;

    let published = (|| -> Result<()> {
        use std::io::Write;
        {
            let mut f = std::fs::File::create(&tmp_path)?;
            f.write_all(&bytes)?;
            f.sync_all()?;
        }
        // Restrict the file while it is still private; rename keeps its mode.
        super::paths::ensure_file_600(&tmp_path)?;
        std::fs::rename(&tmp_path, &final_path)?;
        Ok(())
    })();

    if let Err(err) = published {
        // Best effort: the original error is what the caller needs to see.
        let _ = std::fs::remove_file(&tmp_path);
        return Err(err);
    }

    tracing::info!(
        slug = %shard.header.slug,
        generation = shard.header.generation,
        entries = shard.header.entry_count,
        bytes = bytes.len(),
        path = %final_path.display(),
        "shard written"
    );
    Ok(final_path)
}
/// A shard file mapped into memory and accessed in its archived form, without
/// deserializing it into an owned [`Shard`].
pub struct MmappedShard {
// Owns the mapping; never read directly, it only keeps the bytes that
// `archived` points into alive for the lifetime of this value.
_mmap: Mmap,
// Path the shard was opened from (a copy of the argument given to `open`).
path: PathBuf,
// Pointer to the archived root inside `_mmap`, obtained from rkyv's
// validation in `open`. A raw pointer is used because the struct cannot
// hold a reference into its own field.
archived: *const ArchivedShard,
}
/// Compact debug output: the path plus a few header fields, not the payload.
impl std::fmt::Debug for MmappedShard {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("MmappedShard");
        out.field("path", &self.path);
        out.field("entries", &self.entry_count());
        out.field("generation", &self.generation());
        out.field("slug", &self.slug());
        out.finish()
    }
}
// SAFETY: the raw `archived` pointer is the only field that blocks the auto
// traits. It points into the mapping owned by `_mmap` in the same value, so it
// travels with its backing storage, and this module only ever reads through
// it (no method hands out `&mut` access or uses interior mutability).
// NOTE(review): this additionally assumes the mapped file is not modified in
// place while mapped; `write_shard` in this file replaces shards via rename
// rather than rewriting them — confirm no other writer truncates or overwrites
// shard files.
unsafe impl Send for MmappedShard {}
unsafe impl Sync for MmappedShard {}
impl MmappedShard {
    /// Maps the file at `path` and validates it as an archived [`Shard`].
    ///
    /// Validation is rkyv's structural check only; the header's `magic` and
    /// `format_version` fields are not compared against the module constants
    /// here.
    ///
    /// # Errors
    ///
    /// Returns an error if the file cannot be opened or mapped, or if rkyv
    /// rejects its contents (message prefixed with `shard validation failed`).
    pub fn open(path: &Path) -> Result<Self> {
        let file = std::fs::File::open(path)?;
        // SAFETY: mapping a file is only sound while nobody modifies it in
        // place. NOTE(review): relies on shard files being replaced by rename
        // (as `write_shard` does) rather than rewritten — confirm.
        let mmap = unsafe { Mmap::map(&file) }?;
        // The reference is turned into a raw pointer right away so the borrow
        // of `mmap` ends before `mmap` is moved into the struct. Moving the
        // `Mmap` handle does not move the mapped bytes.
        let archived: *const ArchivedShard = rkyv::check_archived_root::<Shard>(&mmap[..])
            .map_err(|e| DaemonError::other(format!("shard validation failed: {e}")))?;
        Ok(Self {
            _mmap: mmap,
            path: path.to_path_buf(),
            archived,
        })
    }

    /// The validated archived root of the shard.
    pub fn shard(&self) -> &ArchivedShard {
        // SAFETY: `archived` was derived in `open` from a reference into
        // `self._mmap`, which is owned by `self` and therefore outlives the
        // returned borrow.
        unsafe { &*self.archived }
    }

    /// The archived header.
    pub fn header(&self) -> &ArchivedShardHeader {
        let root = self.shard();
        &root.header
    }

    /// Generation number recorded in the header.
    pub fn generation(&self) -> u64 {
        self.header().generation.into()
    }

    /// Slug recorded in the header.
    pub fn slug(&self) -> &str {
        self.header().slug.as_str()
    }

    /// Entry count recorded in the header (not recomputed from the map).
    pub fn entry_count(&self) -> u32 {
        self.header().entry_count.into()
    }

    /// Looks up the bytes stored under `fq_name`, borrowing them from the mapping.
    pub fn get(&self, fq_name: &str) -> Option<&[u8]> {
        let bytes = self.shard().entries.get(fq_name)?;
        Some(bytes.as_slice())
    }

    /// Iterates over all entry names, in the archived map's iteration order.
    pub fn keys(&self) -> impl Iterator<Item = &str> {
        let entries = &self.shard().entries;
        entries.keys().map(|name| name.as_str())
    }

    /// Path this shard was opened from.
    pub fn path(&self) -> &Path {
        self.path.as_path()
    }
}
/// Deletes leftover temporary shard files from `paths.images`.
///
/// A directory entry is a candidate when its name contains `.tmp.`; it is
/// removed when its modification time is at least `max_age` in the past (a
/// modification time in the future counts as age zero). Returns the number of
/// files this call removed; a missing directory yields `Ok(0)`.
///
/// Temporary files are short-lived by design: a concurrent writer may rename
/// one into place, or another sweeper may delete it, between the directory
/// listing and the calls below. Such entries are treated as already gone
/// rather than failing the whole sweep.
///
/// # Errors
///
/// Returns an error if the directory cannot be listed, or if inspecting or
/// removing an entry fails for any reason other than the entry having vanished.
pub fn sweep_tmp_files(paths: &CachePaths, max_age: std::time::Duration) -> Result<usize> {
    use std::io::ErrorKind;

    let mut removed = 0usize;
    let now = SystemTime::now();
    if !paths.images.exists() {
        return Ok(0);
    }
    for entry in std::fs::read_dir(&paths.images)? {
        let entry = entry?;
        let name = entry.file_name();
        let s = name.to_string_lossy();
        if !s.contains(".tmp.") {
            continue;
        }
        let modified = match entry.metadata().and_then(|meta| meta.modified()) {
            Ok(modified) => modified,
            // Renamed or removed since the directory was listed.
            Err(e) if e.kind() == ErrorKind::NotFound => continue,
            Err(e) => return Err(e.into()),
        };
        if now.duration_since(modified).unwrap_or_default() < max_age {
            continue;
        }
        match std::fs::remove_file(entry.path()) {
            Ok(()) => {
                removed += 1;
                tracing::warn!(file = %s, "removed orphaned tmp shard");
            }
            // Someone else got there first; nothing left to clean up.
            Err(e) if e.kind() == ErrorKind::NotFound => {}
            Err(e) => return Err(e.into()),
        }
    }
    Ok(removed)
}
/// Lists the shard files in `paths.images`, sorted by path.
///
/// An entry counts as a shard when its name ends with `.rkyv` and does not
/// contain `.tmp.`, which excludes temporary and `.lock` files. A missing
/// directory yields an empty list.
///
/// # Errors
///
/// Returns an error if the directory or one of its entries cannot be read.
pub fn list_shards(paths: &CachePaths) -> Result<Vec<PathBuf>> {
    if !paths.images.exists() {
        return Ok(Vec::new());
    }
    let mut shards = std::fs::read_dir(&paths.images)?
        .filter_map(|entry| match entry {
            Ok(entry) => {
                let raw_name = entry.file_name();
                let name = raw_name.to_string_lossy();
                let is_shard = name.ends_with(".rkyv") && !name.contains(".tmp.");
                is_shard.then(|| Ok(entry.path()))
            }
            // Keep the error so `collect` stops at the first failure.
            Err(err) => Some(Err(err)),
        })
        .collect::<std::io::Result<Vec<PathBuf>>>()?;
    shards.sort();
    Ok(shards)
}
/// Current wall-clock time in nanoseconds since the UNIX epoch, or 0 when the
/// system clock reads earlier than the epoch.
fn now_ns() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        // `as_nanos` is u128; the cast keeps the low 64 bits.
        Ok(since_epoch) => since_epoch.as_nanos() as u64,
        Err(_) => 0,
    }
}
// Unit tests for shard naming, the write/open round trip, tmp-file sweeping
// and shard listing. Every filesystem test works inside its own temp dir.
#[cfg(test)]
mod tests {
use super::*;
use tempfile::TempDir;
// Creates a temp dir with a `zshrs` cache root inside it. The `TempDir` is
// returned too so the directory stays alive for the duration of the test.
fn fresh() -> (TempDir, CachePaths) {
let tmp = TempDir::new().unwrap();
let paths = CachePaths::with_root(tmp.path().join("zshrs"));
paths.ensure_dirs().unwrap();
(tmp, paths)
}
// Same input gives the same 8-character hex string.
#[test]
fn hash8_is_deterministic() {
let h1 = hash8("/Users/wizard/.zpwr");
let h2 = hash8("/Users/wizard/.zpwr");
assert_eq!(h1, h2);
assert_eq!(h1.len(), 8);
assert!(h1.chars().all(|c| c.is_ascii_hexdigit()));
}
// Inputs differing by one trailing character hash differently.
#[test]
fn hash8_distinct_for_distinct_inputs() {
let h1 = hash8("/Users/wizard/.zpwr");
let h2 = hash8("/Users/wizard/.zpwrr");
assert_ne!(h1, h2);
}
// File name is `<8 hex chars>-<slug>.rkyv`.
#[test]
fn shard_filename_format() {
let f = shard_filename("/some/path", "zpwr");
assert!(f.ends_with("-zpwr.rkyv"));
assert_eq!(f.split('-').next().unwrap().len(), 8);
}
// Writes a three-entry shard, checks the file mode, then reads it back
// through the mmap path and compares header fields and payload bytes.
#[test]
fn write_then_read_roundtrip() {
let (_tmp, paths) = fresh();
let mut shard = Shard::new("test", "/Users/wizard/test", 1);
shard.insert("_git", b"\x01\x02\x03 git bytecode".to_vec());
shard.insert("_docker", b"\xaa\xbb\xcc docker bytecode".to_vec());
shard.insert("_kubectl", b"\xff\xee\xdd kubectl bytecode".to_vec());
let path = write_shard(&paths, &shard).unwrap();
assert!(path.exists());
let mode = std::fs::metadata(&path).unwrap().permissions();
// NOTE: `PermissionsExt` is Unix-only, so this test module does not
// build on non-Unix targets.
use std::os::unix::fs::PermissionsExt;
assert_eq!(mode.mode() & 0o777, 0o600);
let read = MmappedShard::open(&path).unwrap();
assert_eq!(read.entry_count(), 3);
assert_eq!(read.slug(), "test");
assert_eq!(read.generation(), 1);
assert_eq!(read.get("_git"), Some(&b"\x01\x02\x03 git bytecode"[..]));
assert_eq!(
read.get("_kubectl"),
Some(&b"\xff\xee\xdd kubectl bytecode"[..])
);
assert_eq!(read.get("_nonexistent"), None);
}
// A second write for the same slug/source root replaces the first shard.
#[test]
fn write_overwrite_via_atomic_rename() {
let (_tmp, paths) = fresh();
let mut shard1 = Shard::new("test", "/Users/wizard/test", 1);
shard1.insert("_git", b"v1 bytecode".to_vec());
write_shard(&paths, &shard1).unwrap();
let mut shard2 = Shard::new("test", "/Users/wizard/test", 2);
shard2.insert("_git", b"v2 bytecode".to_vec());
shard2.insert("_docker", b"v2 docker".to_vec());
let path = write_shard(&paths, &shard2).unwrap();
let read = MmappedShard::open(&path).unwrap();
assert_eq!(read.generation(), 2);
assert_eq!(read.entry_count(), 2);
assert_eq!(read.get("_git"), Some(&b"v2 bytecode"[..]));
}
// A tmp file whose mtime is forced far into the past gets swept.
#[test]
fn sweep_removes_old_tmp_files() {
let (_tmp, paths) = fresh();
let orphan = paths.images.join("00000000-test.rkyv.tmp.99999.123");
std::fs::write(&orphan, b"orphan").unwrap();
let past = filetime::FileTime::from_unix_time(1, 0);
filetime::set_file_mtime(&orphan, past).unwrap();
let removed = sweep_tmp_files(&paths, std::time::Duration::from_secs(60)).unwrap();
assert_eq!(removed, 1);
assert!(!orphan.exists());
}
// A freshly written tmp file is younger than `max_age` and is kept.
#[test]
fn sweep_skips_recent_tmp_files() {
let (_tmp, paths) = fresh();
let recent = paths.images.join("00000000-test.rkyv.tmp.99999.456");
std::fs::write(&recent, b"recent").unwrap();
let removed = sweep_tmp_files(&paths, std::time::Duration::from_secs(60)).unwrap();
assert_eq!(removed, 0);
assert!(recent.exists());
}
// Only the two plain `.rkyv` files are listed; tmp and lock files are not.
#[test]
fn list_shards_filters_tmp_and_lock() {
let (_tmp, paths) = fresh();
std::fs::write(paths.images.join("aaaaaaaa-foo.rkyv"), b"x").unwrap();
std::fs::write(paths.images.join("bbbbbbbb-bar.rkyv"), b"x").unwrap();
std::fs::write(paths.images.join("cccccccc-baz.rkyv.tmp.1.2"), b"x").unwrap();
std::fs::write(paths.images.join("dddddddd-zip.rkyv.lock"), b"x").unwrap();
let listed = list_shards(&paths).unwrap();
assert_eq!(listed.len(), 2);
assert!(listed.iter().all(|p| p.extension().unwrap() == "rkyv"));
assert!(listed
.iter()
.all(|p| !p.to_string_lossy().contains(".tmp.")));
}
// A shard with no entries still serializes and validates.
#[test]
fn empty_shard_roundtrip() {
let (_tmp, paths) = fresh();
let shard = Shard::new("empty", "/some/root", 1);
let path = write_shard(&paths, &shard).unwrap();
let read = MmappedShard::open(&path).unwrap();
assert_eq!(read.entry_count(), 0);
assert!(read.shard().entries.is_empty());
}
// Arbitrary bytes are rejected by validation in `open`, not accepted.
#[test]
fn corrupt_file_rejected_on_open() {
let (_tmp, paths) = fresh();
let bogus = paths.images.join("zzzzzzzz-bogus.rkyv");
std::fs::write(&bogus, b"this is not a valid rkyv archive").unwrap();
let err = MmappedShard::open(&bogus).unwrap_err();
assert!(format!("{}", err).contains("validation failed"));
}
}