use anyhow::{Context, Result};
use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex, OnceLock};
use std::hash::BuildHasher;
use ahash::AHasher;
use super::initramfs;
/// Content-derived cache key identifying one initramfs base archive.
///
/// The wrapped `u64` is a fixed-seed ahash digest built by `BaseKey::new` /
/// `BaseKey::new_shell`; it is also fed to `initramfs::shm_segment_name` to
/// name the shared-memory segment that carries the archive between processes.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub(crate) struct BaseKey(pub(crate) u64);
/// Memoisation key for [`hash_file`]: a cached digest is reused only while
/// `path` still resolves to the same inode with unchanged size, mtime and
/// ctime.
///
/// `len` catches rewrites that change the file length within a single
/// timestamp tick on filesystems with coarse timestamps. `ctime` is set by the
/// kernel on every content or metadata change and, unlike `mtime`, cannot be
/// rewound with `utimensat(2)` (as `touch -r` or `cp -p` do), so an in-place
/// rewrite that restores the previous mtime still invalidates the entry.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct HashFileKey {
    path: PathBuf,
    dev: u64,
    ino: u64,
    len: u64,
    mtime_secs: i64,
    mtime_nsecs: i64,
    ctime_secs: i64,
    ctime_nsecs: i64,
}

/// Process-wide memo of file digests computed by [`hash_file`].
fn hash_file_cache() -> &'static Mutex<HashMap<HashFileKey, u64>> {
    static CACHE: OnceLock<Mutex<HashMap<HashFileKey, u64>>> = OnceLock::new();
    CACHE.get_or_init(|| Mutex::new(HashMap::new()))
}

/// Hash the full contents of `path` with a fixed-seed ahash hasher.
///
/// Fixed seeds make the digest reproducible across processes running the same
/// build, which matters because these digests feed [`BaseKey`] values that name
/// shared-memory segments. Results are memoised per process under a
/// [`HashFileKey`] derived from the file's metadata, so unchanged files are
/// hashed only once.
///
/// # Errors
/// Fails if `path` cannot be opened, stat'ed, or memory-mapped.
pub(crate) fn hash_file(path: &Path) -> Result<u64> {
    use std::fs::File;
    use std::os::unix::fs::MetadataExt;
    let file = File::open(path).with_context(|| format!("open for hash: {}", path.display()))?;
    // Stat through the open handle so the key describes the inode actually mapped below.
    let meta = file
        .metadata()
        .with_context(|| format!("stat for hash: {}", path.display()))?;
    let cache_key = HashFileKey {
        path: path.to_path_buf(),
        dev: meta.dev(),
        ino: meta.ino(),
        len: meta.size(),
        mtime_secs: meta.mtime(),
        mtime_nsecs: meta.mtime_nsec(),
        ctime_secs: meta.ctime(),
        ctime_nsecs: meta.ctime_nsec(),
    };
    if let Some(cached) = hash_file_cache().lock().unwrap().get(&cache_key).copied() {
        return Ok(cached);
    }
    // SAFETY: the mapping is read-only and dropped before this function returns.
    // memmap2 documents that modifying the underlying file while it is mapped is
    // undefined behaviour; callers are expected not to rewrite binaries while
    // they are being hashed.
    let mmap = unsafe {
        memmap2::Mmap::map(&file).with_context(|| format!("mmap for hash: {}", path.display()))?
    };
    // Feed the whole file in a single `write`: `Hasher` gives no guarantee that
    // splitting input across several writes yields the same digest.
    let mut hasher = ahash::RandomState::with_seeds(0, 0, 0, 0).build_hasher();
    hasher.write(&mmap);
    let digest = hasher.finish();
    hash_file_cache().lock().unwrap().insert(cache_key, digest);
    Ok(digest)
}
impl BaseKey {
    /// Derive the cache key for a standard initramfs base.
    ///
    /// The key covers the content digest of `payload` and of each optional
    /// binary, plus the resolved shared libraries of each. Absent optional
    /// binaries contribute a `0u8` tag and present ones a `1u8` tag, so the
    /// slot a binary occupies is part of the key.
    ///
    /// # Errors
    /// Fails if any supplied binary cannot be hashed by [`hash_file`].
    pub(crate) fn new(
        payload: &Path,
        scheduler: Option<&Path>,
        probe: Option<&Path>,
        worker: Option<&Path>,
    ) -> Result<Self> {
        let mut hasher = ahash::RandomState::with_seeds(0, 0, 0, 0).build_hasher();
        Self::hash_binaries(payload, scheduler, probe, worker, &mut hasher)?;
        Ok(BaseKey(hasher.finish()))
    }

    /// Derive the cache key for a shell-mode initramfs base.
    ///
    /// Like [`BaseKey::new`], but domain-separated by a `"ktstr-shell"`
    /// prefix and additionally covering the `busybox` flag and every
    /// `(archive_path, host_path)` pair in `include_files`. Include files are
    /// sorted by archive path first (stable sort), so the key does not depend on
    /// the order in which distinct archive paths are listed.
    ///
    /// # Errors
    /// Fails if any binary or included host file cannot be hashed by [`hash_file`].
    pub(crate) fn new_shell(
        payload: &Path,
        scheduler: Option<&Path>,
        probe: Option<&Path>,
        worker: Option<&Path>,
        include_files: &[(String, PathBuf)],
        busybox: bool,
    ) -> Result<Self> {
        let mut hasher = ahash::RandomState::with_seeds(0, 0, 0, 0).build_hasher();
        "ktstr-shell".hash(&mut hasher);
        busybox.hash(&mut hasher);
        Self::hash_binaries(payload, scheduler, probe, worker, &mut hasher)?;
        let mut sorted: Vec<(&str, &Path)> = include_files
            .iter()
            .map(|(a, p)| (a.as_str(), p.as_path()))
            .collect();
        sorted.sort_by_key(|(a, _)| *a);
        // Length prefix keeps the include list unambiguous with respect to
        // whatever is hashed after it.
        sorted.len().hash(&mut hasher);
        for (archive_path, host_path) in &sorted {
            archive_path.hash(&mut hasher);
            Self::hash_binary(host_path, &mut hasher)?;
        }
        Ok(BaseKey(hasher.finish()))
    }

    /// Hash the payload and the three optional companion binaries, in the
    /// fixed order payload, scheduler, probe, worker.
    fn hash_binaries(
        payload: &Path,
        scheduler: Option<&Path>,
        probe: Option<&Path>,
        worker: Option<&Path>,
        hasher: &mut AHasher,
    ) -> Result<()> {
        Self::hash_binary(payload, hasher)?;
        for slot in [scheduler, probe, worker] {
            Self::hash_optional_binary(slot, hasher)?;
        }
        Ok(())
    }

    /// Hash a presence tag (`1u8`/`0u8`) followed, when present, by the binary itself.
    fn hash_optional_binary(binary: Option<&Path>, hasher: &mut AHasher) -> Result<()> {
        match binary {
            Some(path) => {
                1u8.hash(hasher);
                Self::hash_binary(path, hasher)
            }
            None => {
                0u8.hash(hasher);
                Ok(())
            }
        }
    }

    /// Hash one binary's content digest followed by its shared libraries.
    fn hash_binary(binary: &Path, hasher: &mut AHasher) -> Result<()> {
        hash_file(binary)?.hash(hasher);
        Self::hash_shared_libs(binary, hasher);
        Ok(())
    }

    /// Mix in the path and content digest of every shared library
    /// `initramfs::resolve_shared_libs` finds for `binary`.
    ///
    /// Best-effort: if resolution fails nothing is hashed, and a library whose
    /// content cannot be hashed contributes only its path.
    fn hash_shared_libs(binary: &Path, hasher: &mut AHasher) {
        if let Ok(result) = initramfs::resolve_shared_libs(binary) {
            let mut entries: Vec<_> = result.found.iter().map(|(_, p)| p.clone()).collect();
            // Sort so the key is independent of resolution order.
            entries.sort();
            for p in &entries {
                p.as_os_str().as_encoded_bytes().hash(hasher);
                if let Ok(sample) = hash_file(p) {
                    sample.hash(hasher);
                }
            }
        }
    }
}
/// Process-wide memo of built initramfs bases, keyed by [`BaseKey`].
///
/// Created lazily on first access; values are `Arc`-shared so a hit hands
/// out the same buffer without copying it.
pub(crate) fn base_cache() -> &'static Mutex<HashMap<BaseKey, Arc<Vec<u8>>>> {
    type BaseCache = Mutex<HashMap<BaseKey, Arc<Vec<u8>>>>;
    static CACHE: OnceLock<BaseCache> = OnceLock::new();
    CACHE.get_or_init(Default::default)
}
/// Handle to an initramfs base archive returned by `get_or_build_base`.
///
/// The bytes live either in a shared-memory segment mapped into this process
/// or in a heap buffer owned by this process.
pub(crate) enum BaseRef {
    /// Segment obtained from `initramfs::shm_load_base`.
    Mapped(initramfs::MappedShm),
    /// Process-local buffer, typically also stored in `base_cache`.
    Owned(Arc<Vec<u8>>),
}
impl AsRef<[u8]> for BaseRef {
fn as_ref(&self) -> &[u8] {
match self {
BaseRef::Mapped(m) => m.as_ref(),
BaseRef::Owned(a) => a,
}
}
}
pub(crate) fn get_or_build_base(
payload: &Path,
extras: &[(&str, &Path)],
include_files: &[(&str, &Path)],
busybox: bool,
key: &BaseKey,
) -> Result<BaseRef> {
let cargo_test_mode = std::env::var("KTSTR_CARGO_TEST_MODE")
.map(|v| !v.is_empty())
.unwrap_or(false);
if let Some(arc) = base_cache().lock().unwrap().get(key).cloned() {
tracing::debug!("initramfs base cache hit (process)");
return Ok(BaseRef::Owned(arc));
}
if cargo_test_mode {
let t0 = std::time::Instant::now();
let data = initramfs::build_initramfs_base(payload, extras, include_files, busybox)?;
let arc = Arc::new(data);
tracing::debug!(
elapsed_us = t0.elapsed().as_micros(),
bytes = arc.len(),
"build_initramfs_base (cargo-test inline)",
);
base_cache()
.lock()
.unwrap()
.insert(key.clone(), arc.clone());
return Ok(BaseRef::Owned(arc));
}
static CLEANUP_ONCE: OnceLock<()> = OnceLock::new();
CLEANUP_ONCE.get_or_init(|| cleanup_stale_shm(key));
let seg_name = initramfs::shm_segment_name(key.0);
match shm_try_create_excl(&seg_name) {
ShmCreateResult::Winner(fd) => {
tracing::debug!("initramfs shm: builder (O_EXCL won)");
let t0 = std::time::Instant::now();
let data = initramfs::build_initramfs_base(payload, extras, include_files, busybox)?;
tracing::debug!(
elapsed_us = t0.elapsed().as_micros(),
bytes = data.len(),
"build_initramfs_base",
);
shm_write_and_release(fd, &data, &seg_name);
hold_shm_lock(&seg_name);
if let Some(mapped) = initramfs::shm_load_base(key.0) {
return Ok(BaseRef::Mapped(mapped));
}
let arc = Arc::new(data);
base_cache()
.lock()
.unwrap()
.insert(key.clone(), arc.clone());
return Ok(BaseRef::Owned(arc));
}
ShmCreateResult::Exists => {
tracing::debug!("initramfs shm: waiting for builder (EEXIST)");
if let Some(mapped) = initramfs::shm_load_base(key.0) {
tracing::debug!("initramfs base cache hit (shm, after wait)");
hold_shm_lock(&seg_name);
return Ok(BaseRef::Mapped(mapped));
}
}
ShmCreateResult::Error => {
if let Some(mapped) = initramfs::shm_load_base(key.0) {
tracing::debug!("initramfs base cache hit (shm)");
hold_shm_lock(&seg_name);
return Ok(BaseRef::Mapped(mapped));
}
}
}
let t0 = std::time::Instant::now();
let data = initramfs::build_initramfs_base(payload, extras, include_files, busybox)?;
let arc = Arc::new(data);
tracing::debug!(
elapsed_us = t0.elapsed().as_micros(),
bytes = arc.len(),
"build_initramfs_base (fallback)",
);
base_cache()
.lock()
.unwrap()
.insert(key.clone(), arc.clone());
if let Err(e) = initramfs::shm_store_base(key.0, &arc) {
tracing::warn!("shm_store_base: {e:#}");
}
Ok(BaseRef::Owned(arc))
}
/// Best-effort removal of ktstr SHM segments that belong to a key other than
/// `current` and that no process is using.
///
/// Scans `/dev/shm` for `ktstr-base-*`, `ktstr-lz4-*` and `ktstr-gz-*`
/// entries. A segment whose suffix differs from `current`'s suffix
/// (`<SHM_ARCH_TAG>-<key as 16 hex digits>`) is unlinked only when a
/// non-blocking exclusive `flock` on it succeeds. Every failure is ignored.
fn cleanup_stale_shm(current: &BaseKey) {
    let keep_suffix = format!("{}-{:016x}", initramfs::SHM_ARCH_TAG, current.0);
    let Ok(entries) = std::fs::read_dir("/dev/shm") else {
        return;
    };
    for entry in entries.flatten() {
        let file_name = entry.file_name();
        let Some(name) = file_name.to_str() else {
            continue;
        };
        let Some(suffix) = ["ktstr-base-", "ktstr-lz4-", "ktstr-gz-"]
            .into_iter()
            .find_map(|prefix| name.strip_prefix(prefix))
        else {
            continue;
        };
        if suffix != keep_suffix {
            unlink_stale_segment_if_unlocked(&format!("/{name}"));
        }
    }
}

/// Unlink the SHM segment `shm_name` if an exclusive `flock` can be taken
/// without blocking.
///
/// After locking, the name is reopened and compared by `(st_dev, st_ino)`, and
/// the segment is unlinked only if the name still refers to the locked inode.
/// The lock is released before returning.
fn unlink_stale_segment_if_unlocked(shm_name: &str) {
    let open_read_only = || {
        rustix::shm::open(
            shm_name,
            rustix::shm::OFlags::RDONLY,
            rustix::fs::Mode::empty(),
        )
    };
    let Ok(locked) = open_read_only() else {
        return;
    };
    if rustix::fs::flock(&locked, rustix::fs::FlockOperation::NonBlockingLockExclusive).is_err() {
        return;
    }
    let reopened = open_read_only();
    if let Ok(reopened_fd) = &reopened {
        let same_inode = match (rustix::fs::fstat(&locked), rustix::fs::fstat(reopened_fd)) {
            (Ok(a), Ok(b)) => a.st_dev == b.st_dev && a.st_ino == b.st_ino,
            _ => false,
        };
        if same_inode {
            let _ = rustix::shm::unlink(shm_name);
        }
    }
    let _ = rustix::fs::flock(&locked, rustix::fs::FlockOperation::Unlock);
}
static HELD_SHM_LOCKS: Mutex<Vec<rustix::fd::OwnedFd>> = Mutex::new(Vec::new());
fn hold_shm_lock(shm_name: &str) {
for name in [
shm_name.to_string(),
shm_name.replace("ktstr-base-", "ktstr-lz4-"),
] {
if let Ok(fd) = rustix::shm::open(
name.as_str(),
rustix::shm::OFlags::RDONLY,
rustix::fs::Mode::empty(),
) && rustix::fs::flock(&fd, rustix::fs::FlockOperation::NonBlockingLockShared).is_ok()
{
HELD_SHM_LOCKS.lock().unwrap().push(fd);
}
}
}
/// Outcome of an exclusive SHM segment creation attempt (`shm_try_create_excl`).
pub(crate) enum ShmCreateResult {
    /// This caller created the segment and holds an exclusive `flock` on the
    /// returned descriptor; it is expected to populate it and release the lock.
    Winner(std::os::fd::OwnedFd),
    /// Creation failed with `EEXIST`: the segment already exists, presumably
    /// being built or already published by another process.
    Exists,
    /// Any other failure to create or lock the segment.
    Error,
}
/// Atomically create the POSIX SHM segment `name` (mode `0o644`) and take an
/// exclusive `flock` on it.
///
/// Returns:
/// - `Winner(fd)` when this call created the segment via `O_CREAT | O_EXCL`
///   and holds an exclusive lock on `fd`.
/// - `Exists` when `shm_open` failed with `EEXIST`.
/// - `Error` for any other failure. If the segment was created but could not
///   be locked, it is unlinked first. Otherwise the empty segment would stay
///   behind and make every later caller see `Exists`.
pub(crate) fn shm_try_create_excl(name: &str) -> ShmCreateResult {
    let fd = match rustix::shm::open(
        name,
        rustix::shm::OFlags::CREATE | rustix::shm::OFlags::EXCL | rustix::shm::OFlags::RDWR,
        rustix::fs::Mode::from_raw_mode(0o644),
    ) {
        Ok(fd) => fd,
        Err(e) if e == rustix::io::Errno::EXIST => return ShmCreateResult::Exists,
        Err(_) => return ShmCreateResult::Error,
    };
    // Blocking lock: another process may already have opened the new segment
    // and taken a shared lock; we wait for it rather than fail.
    if rustix::fs::flock(&fd, rustix::fs::FlockOperation::LockExclusive).is_err() {
        // We own the name (O_EXCL succeeded) but cannot act as builder; remove
        // it so the next caller can create and populate it.
        let _ = rustix::shm::unlink(name);
        return ShmCreateResult::Error;
    }
    ShmCreateResult::Winner(fd)
}
/// Fill the freshly created SHM segment behind `fd` with `data`, then release
/// the exclusive `flock` taken by `shm_try_create_excl`.
///
/// Best-effort: if the segment cannot be sized or mapped, it is unlinked by
/// `seg_name`, so later opens by name do not find a segment that was never
/// filled. No error is reported to the caller.
///
/// NOTE(review): with an empty `data`, `mmap` is called with length 0, which
/// POSIX specifies fails with `EINVAL`, so the segment ends up truncated and
/// unlinked.
pub(crate) fn shm_write_and_release(fd: std::os::fd::OwnedFd, data: &[u8], seg_name: &str) {
    use std::os::fd::AsRawFd;
    // Borrowed raw descriptor; `fd` keeps ownership until this function returns.
    let raw = fd.as_raw_fd();
    // SAFETY: `raw` stays valid for the whole block because `fd` outlives it.
    // The segment is sized to `data.len()` before mapping exactly that many
    // bytes. The mmap result is checked against `MAP_FAILED` before use. The new
    // mapping cannot overlap `data`, and it is unmapped with the same length.
    unsafe {
        // Size the segment to the payload. The early return skips the explicit
        // unlock below; dropping `fd` closes it, which releases the flock.
        if libc::ftruncate(raw, data.len() as libc::off_t) != 0 {
            let _ = rustix::shm::unlink(seg_name);
            return;
        }
        // Write-only shared mapping: stores land directly in the segment.
        let ptr = libc::mmap(
            std::ptr::null_mut(),
            data.len(),
            libc::PROT_WRITE,
            libc::MAP_SHARED,
            raw,
            0,
        );
        if ptr == libc::MAP_FAILED {
            // Could not map: shrink back to empty (result ignored) and drop the name.
            libc::ftruncate(raw, 0);
            let _ = rustix::shm::unlink(seg_name);
        } else {
            std::ptr::copy_nonoverlapping(data.as_ptr(), ptr as *mut u8, data.len());
            libc::munmap(ptr, data.len());
        }
    }
    // Release the builder's exclusive lock so waiting readers can proceed.
    let _ = rustix::fs::flock(&fd, rustix::fs::FlockOperation::Unlock);
}
#[cfg(test)]
mod tests {
    //! Unit tests for content hashing, process-local memoisation and the SHM
    //! create/publish protocol.
    use super::*;

    /// The first exclusive create wins; a second create of the same name must
    /// report `Exists`.
    #[test]
    fn shm_try_create_excl_winner_then_exists() {
        let name = format!(
            "/ktstr-test-shm-{}-{}",
            std::process::id(),
            std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap()
                .as_nanos(),
        );
        match shm_try_create_excl(&name) {
            ShmCreateResult::Winner(fd) => {
                match shm_try_create_excl(&name) {
                    ShmCreateResult::Exists => {}
                    ShmCreateResult::Winner(_other) => {
                        let _ = rustix::shm::unlink(name.as_str());
                        drop(fd);
                        panic!("second shm_try_create_excl must return Exists, not Winner");
                    }
                    ShmCreateResult::Error => {
                        let _ = rustix::shm::unlink(name.as_str());
                        drop(fd);
                        panic!("second shm_try_create_excl returned Error");
                    }
                }
                shm_write_and_release(fd, b"ok", &name);
                let _ = rustix::shm::unlink(name.as_str());
            }
            ShmCreateResult::Exists => {
                let _ = rustix::shm::unlink(name.as_str());
                panic!("test setup collision on shm name {name}");
            }
            ShmCreateResult::Error => {
                skip!("shm_open unavailable in this environment");
            }
        }
    }

    /// The published segment is sized to exactly the payload length.
    #[test]
    fn shm_write_and_release_publishes_data() {
        let name = format!(
            "/ktstr-test-shm-write-{}-{}",
            std::process::id(),
            std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap()
                .as_nanos(),
        );
        let fd = match shm_try_create_excl(&name) {
            ShmCreateResult::Winner(fd) => fd,
            _ => {
                skip!("shm_open unavailable");
            }
        };
        let payload = b"shm-unit-test-payload";
        shm_write_and_release(fd, payload, &name);
        let rfd = rustix::shm::open(
            name.as_str(),
            rustix::shm::OFlags::RDONLY,
            rustix::fs::Mode::empty(),
        )
        .expect("shm_open for read failed");
        let st = rustix::fs::fstat(&rfd).expect("fstat failed");
        assert_eq!(st.st_size as usize, payload.len());
        drop(rfd);
        let _ = rustix::shm::unlink(name.as_str());
    }

    #[test]
    fn base_key_same_inputs_match() {
        let exe = crate::resolve_current_exe().unwrap();
        let k1 = BaseKey::new(&exe, None, None, None).unwrap();
        let k2 = BaseKey::new(&exe, None, None, None).unwrap();
        assert_eq!(k1, k2);
    }

    #[test]
    fn base_key_nonexistent_payload_fails() {
        let result = BaseKey::new(Path::new("/nonexistent/binary"), None, None, None);
        assert!(result.is_err());
    }

    #[test]
    fn base_key_different_content_differs() {
        let tmp =
            std::env::temp_dir().join(format!("ktstr-cache-content-test-{}", std::process::id()));
        std::fs::create_dir_all(&tmp).unwrap();
        // Use two distinct files instead of rewriting one in place. Two
        // back-to-back rewrites can land in the same coarse kernel timestamp
        // tick, leaving hash_file's stat-based memoisation key unchanged, so
        // the stale v1 digest would be returned. Invalidation on rewrite is
        // covered by hash_file_memoisation_invalidates_on_change.
        let bin_v1 = tmp.join("payload-v1");
        let bin_v2 = tmp.join("payload-v2");
        std::fs::write(&bin_v1, b"content_v1").unwrap();
        std::fs::write(&bin_v2, b"content_v2").unwrap();
        let k1 = BaseKey::new(&bin_v1, None, None, None).unwrap();
        let k2 = BaseKey::new(&bin_v2, None, None, None).unwrap();
        assert_ne!(
            k1, k2,
            "different file content should produce different key"
        );
        std::fs::remove_dir_all(&tmp).unwrap();
    }

    #[test]
    fn base_key_with_scheduler() {
        let exe = crate::resolve_current_exe().unwrap();
        let k1 = BaseKey::new(&exe, None, None, None).unwrap();
        let k2 = BaseKey::new(&exe, Some(&exe), None, None).unwrap();
        assert_ne!(k1, k2, "with vs without scheduler should differ");
    }

    /// hash_file must equal a single `write` of the whole content into a
    /// zero-seeded ahash hasher.
    #[test]
    fn hash_file_is_ahash_stable_golden() {
        let tmp =
            std::env::temp_dir().join(format!("ktstr-hash-golden-test-{}", std::process::id()));
        std::fs::create_dir_all(&tmp).unwrap();
        let f = tmp.join("known");
        std::fs::write(&f, b"ktstr cache key probe").unwrap();
        let observed = hash_file(&f).unwrap();
        let mut h = ahash::RandomState::with_seeds(0, 0, 0, 0).build_hasher();
        h.write(b"ktstr cache key probe");
        let expected = h.finish();
        assert_eq!(
            observed, expected,
            "hash_file must match ahash::RandomState::with_seeds(0, 0, 0, 0).build_hasher()"
        );
        std::fs::remove_dir_all(&tmp).unwrap();
    }

    #[test]
    fn hash_file_large_file() {
        let tmp =
            std::env::temp_dir().join(format!("ktstr-hash-sample-test-{}", std::process::id()));
        std::fs::create_dir_all(&tmp).unwrap();
        let f = tmp.join("big");
        let data: Vec<u8> = (0..16384).map(|i| (i % 256) as u8).collect();
        std::fs::write(&f, &data).unwrap();
        let h = hash_file(&f).unwrap();
        assert_eq!(h, hash_file(&f).unwrap());
        std::fs::remove_dir_all(&tmp).unwrap();
    }

    #[test]
    fn hash_file_memoisation_invalidates_on_change() {
        let tmp = std::env::temp_dir().join(format!("ktstr-hash-memo-test-{}", std::process::id()));
        std::fs::create_dir_all(&tmp).unwrap();
        let f = tmp.join("rev");
        std::fs::write(&f, b"revision-one").unwrap();
        let h1 = hash_file(&f).unwrap();
        // Sleep past any coarse timestamp granularity so the rewrite is
        // guaranteed to change the file's mtime.
        std::thread::sleep(std::time::Duration::from_millis(1100));
        std::fs::write(&f, b"revision-two-with-different-bytes").unwrap();
        let h2 = hash_file(&f).unwrap();
        assert_ne!(h1, h2, "mtime change must bypass the memoisation cache");
        std::fs::remove_dir_all(&tmp).unwrap();
    }

    #[test]
    fn base_cache_hit() {
        // Synthetic key. A key derived from the test binary is also planted by
        // get_or_build_base_cargo_test_mode_uses_process_local_cache; both
        // tests share this process-wide cache on parallel test threads and
        // would overwrite/remove each other's entry.
        let key = BaseKey(0xB45E_CAC4_E417_0001);
        let sentinel = Arc::new(vec![0xDE, 0xAD]);
        base_cache()
            .lock()
            .unwrap()
            .insert(key.clone(), sentinel.clone());
        let cached = base_cache().lock().unwrap().get(&key).cloned();
        assert!(cached.is_some());
        assert!(Arc::ptr_eq(&cached.unwrap(), &sentinel));
        base_cache().lock().unwrap().remove(&key);
    }

    #[test]
    fn shm_store_and_load_roundtrip() {
        let hash = 0xDEAD_BEEF_CAFE_1234u64;
        let data = vec![0x07u8, 0x07, 0x01];
        initramfs::shm_store_base(hash, &data).unwrap();
        let loaded = initramfs::shm_load_base(hash);
        assert!(loaded.is_some(), "shm_load_base should return Some");
        assert_eq!(loaded.unwrap().as_ref(), &data[..]);
        initramfs::shm_unlink_base(hash);
    }

    #[test]
    fn shm_different_hashes_independent() {
        let h1 = 0x1111_2222_3333_4444u64;
        let h2 = 0x5555_6666_7777_8888u64;
        let d1 = vec![0xAAu8; 16];
        let d2 = vec![0xBBu8; 32];
        initramfs::shm_store_base(h1, &d1).unwrap();
        initramfs::shm_store_base(h2, &d2).unwrap();
        assert_eq!(initramfs::shm_load_base(h1).unwrap().as_ref(), &d1[..]);
        assert_eq!(initramfs::shm_load_base(h2).unwrap().as_ref(), &d2[..]);
        initramfs::shm_unlink_base(h1);
        initramfs::shm_unlink_base(h2);
    }

    /// In cargo-test mode a planted process-local entry must be returned as-is,
    /// never rebuilt and never mapped from SHM.
    #[test]
    fn get_or_build_base_cargo_test_mode_uses_process_local_cache() {
        use crate::test_support::test_helpers::{EnvVarGuard, lock_env};
        let _lock = lock_env();
        let _env = EnvVarGuard::set("KTSTR_CARGO_TEST_MODE", "1");
        let exe = crate::resolve_current_exe().unwrap();
        let key = BaseKey::new(&exe, None, None, None).unwrap();
        let sentinel = Arc::new(vec![0xC0u8, 0xDE, 0x01, 0x07, 0x07, 0x01]);
        base_cache()
            .lock()
            .unwrap()
            .insert(key.clone(), sentinel.clone());
        let result = get_or_build_base(&exe, &[], &[], false, &key)
            .expect("cargo-test-mode must reuse process-local cache");
        match result {
            BaseRef::Owned(arc) => {
                assert!(
                    Arc::ptr_eq(&arc, &sentinel),
                    "cargo-test-mode hit on a planted process-local entry \
                     must return the SAME Arc — a regression that fell \
                     through into the inline-build path would produce a \
                     fresh Arc with the same contents but a different \
                     identity"
                );
            }
            BaseRef::Mapped(_) => {
                panic!(
                    "cargo-test-mode must NEVER mmap an SHM segment — \
                     bypass contract requires process-local-only memoisation"
                );
            }
        }
        base_cache().lock().unwrap().remove(&key);
    }
}