use std::collections::HashMap;
use std::fmt::Write;
use std::sync::{Arc, OnceLock};
use camino::Utf8PathBuf;
use sha2::{Digest, Sha256};
use tokio::sync::Mutex;
use crate::error::{Error, Result};
use crate::git;
use crate::paths::template_cache_dir;
/// Registry of per-URL async locks.
///
/// Each distinct URL string gets its own `Arc<Mutex<()>>`, created on first
/// request. Entries are only ever inserted, never removed, so the same URL
/// keeps mapping to the same lock for the lifetime of the registry.
#[derive(Default)]
struct CacheLocks {
    /// URL → lock. The synchronous mutex guarding the map is held only for a
    /// single lookup-or-insert and never across an `.await`.
    map: std::sync::Mutex<HashMap<String, Arc<Mutex<()>>>>,
}

impl CacheLocks {
    /// Returns the lock associated with `url`, creating it on first use.
    ///
    /// Repeated calls with the same `url` return clones of the same `Arc`.
    fn for_url(&self, url: &str) -> Arc<Mutex<()>> {
        // The critical section is one `entry().or_default()`, which cannot leave
        // the map in an inconsistent state. Recover from poisoning instead of
        // letting one panicked holder break every later caller.
        let mut map = self
            .map
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        Arc::clone(map.entry(url.to_string()).or_default())
    }
}
fn cache_locks() -> &'static CacheLocks {
static LOCKS: OnceLock<CacheLocks> = OnceLock::new();
LOCKS.get_or_init(CacheLocks::default)
}
/// Returns the async lock used to serialise cache work on `url`.
///
/// Every call with the same `url` yields a clone of the same `Arc`. Awaiting
/// `.lock()` on it therefore excludes other holders for that URL within this
/// process.
pub fn lock_for_url(url: &str) -> Arc<Mutex<()>> {
    let registry = cache_locks();
    registry.for_url(url)
}
/// On-disk cache of cloned template repositories.
#[derive(Debug)]
pub struct TemplateCache {
    /// Directory under which every per-source slot lives.
    pub root: Utf8PathBuf,
}
impl TemplateCache {
pub fn ensure() -> Result<Self> {
let root = template_cache_dir()?;
std::fs::create_dir_all(root.as_std_path())
.map_err(|e| Error::io_at(root.as_std_path(), e))?;
Ok(Self { root })
}
pub fn slot(&self, source: &str) -> Utf8PathBuf {
let mut h = Sha256::new();
h.update(source.as_bytes());
let bytes = h.finalize();
let mut hex = String::with_capacity(16);
for b in bytes.iter().take(8) {
let _ = write!(hex, "{b:02x}");
}
self.root.join(hex)
}
pub async fn fetch_or_clone(
&self,
source: &str,
rev_spec: Option<&str>,
) -> Result<(Utf8PathBuf, String)> {
let slot = self.slot(source);
let url_lock = lock_for_url(source);
let _guard = url_lock.lock().await;
let cached = slot.join(".git").is_dir();
if !cached {
if slot.exists() {
std::fs::remove_dir_all(slot.as_std_path())
.map_err(|e| Error::io_at(slot.as_std_path(), e))?;
}
if let Some(parent) = slot.parent() {
std::fs::create_dir_all(parent.as_std_path())
.map_err(|e| Error::io_at(parent.as_std_path(), e))?;
}
git::clone_at(source, slot.as_path()).await?;
} else {
if let Err(e) = git::fetch(slot.as_path()).await {
eprintln!("kata: warning: fetch failed for {source}: {e}; using cached refs");
}
}
let target = rev_spec.unwrap_or("origin/HEAD");
git::checkout(slot.as_path(), target).await?;
let head = git::current_head(slot.as_path()).await?;
Ok((slot, head))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Cache rooted at a fixed path. None of these tests touch the filesystem.
    fn test_cache() -> TemplateCache {
        TemplateCache {
            root: Utf8PathBuf::from("/tmp/kata-cache"),
        }
    }

    #[test]
    fn slot_is_stable_for_same_source() {
        let cache = test_cache();
        let a = cache.slot("github.com/yukimemi/pj-base");
        let b = cache.slot("github.com/yukimemi/pj-base");
        assert_eq!(a, b);
    }

    /// `slot` takes no revision argument, so every rev of a source shares one
    /// slot by construction. What is worth pinning is how the name is derived.
    #[test]
    fn slot_is_sha256_prefix_of_source() {
        let cache = test_cache();
        // First 8 bytes of the standard SHA-256 test vectors for "" and "abc".
        assert_eq!(cache.slot(""), cache.root.join("e3b0c44298fc1c14"));
        assert_eq!(cache.slot("abc"), cache.root.join("ba7816bf8f01cfea"));
    }

    #[test]
    fn slot_is_sixteen_hex_chars_directly_under_root() {
        let cache = test_cache();
        let slot = cache.slot("github.com/x/y");
        assert_eq!(slot.parent(), Some(cache.root.as_path()));
        let name = slot.file_name().expect("slot must have a final component");
        assert_eq!(name.len(), 16);
        assert!(
            name.chars().all(|c| matches!(c, '0'..='9' | 'a'..='f')),
            "unexpected slot name {name}"
        );
    }

    #[test]
    fn slot_differs_for_different_sources() {
        let cache = test_cache();
        let a = cache.slot("github.com/x/a");
        let b = cache.slot("github.com/x/b");
        assert_ne!(a, b);
    }

    #[test]
    fn lock_for_url_returns_same_mutex_for_same_url() {
        let a = lock_for_url("github.com/yukimemi/pj-base");
        let b = lock_for_url("github.com/yukimemi/pj-base");
        assert!(
            Arc::ptr_eq(&a, &b),
            "same URL must yield identical Arc<Mutex>"
        );
    }

    #[test]
    fn lock_for_url_returns_distinct_mutex_for_different_urls() {
        let a = lock_for_url("github.com/yukimemi/pj-base");
        let b = lock_for_url("github.com/yukimemi/pj-rust");
        assert!(
            !Arc::ptr_eq(&a, &b),
            "different URLs must yield distinct Arc<Mutex>"
        );
    }

    #[tokio::test]
    async fn lock_for_url_serialises_concurrent_holders() {
        use std::sync::atomic::{AtomicU32, Ordering};
        use tokio::task::JoinSet;
        use tokio::time::{Duration, sleep};

        let url = "github.com/test/concurrent";
        // Number of tasks currently inside the guarded section; must never exceed 1.
        let inside = Arc::new(AtomicU32::new(0));
        let mut set = JoinSet::new();
        for _ in 0..4 {
            let inside = Arc::clone(&inside);
            set.spawn(async move {
                let lock = lock_for_url(url);
                let _guard = lock.lock().await;
                let before = inside.fetch_add(1, Ordering::SeqCst);
                assert_eq!(before, 0, "another task was inside the guard");
                // Yield while holding the guard so an unserialised task would run here.
                sleep(Duration::from_millis(10)).await;
                inside.fetch_sub(1, Ordering::SeqCst);
            });
        }
        while let Some(res) = set.join_next().await {
            res.expect("task panicked — lock failed to serialise");
        }
    }
}