use std::collections::hash_map::DefaultHasher;
use std::fs;
use std::hash::{Hash, Hasher};
use std::path::{Path, PathBuf};
use serde::{Deserialize, Serialize};
use worktrunk::git::Repository;
use worktrunk::path::sanitize_for_filename;
use super::PrStatus;
/// On-disk cached CI status for a single branch, persisted as JSON under the
/// repository's worktrunk cache directory (`<wt_dir>/cache/ci-status`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct CachedCiStatus {
// Last observed PR/CI status; `None` when no status was available.
pub status: Option<PrStatus>,
// Unix timestamp (seconds) at which the status was fetched.
pub checked_at: u64,
// Commit SHA of the branch head the status was fetched for.
pub head: String,
// Branch name this entry belongs to (also encoded in the filename).
pub branch: String,
}
impl CachedCiStatus {
// Minimum cache lifetime in seconds.
const TTL_BASE_SECS: u64 = 30;
// Width of the per-repo jitter window added on top of the base TTL,
// so multiple repositories don't all refresh on the same tick.
const TTL_JITTER_SECS: u64 = 30;
/// Deterministic per-repository TTL in
/// `[TTL_BASE_SECS, TTL_BASE_SECS + TTL_JITTER_SECS)` seconds.
///
/// The jitter is derived from a hash of the repo root path, so a given
/// repository always gets the same TTL (stable at least within one std
/// version — `DefaultHasher`'s algorithm is unspecified across releases).
pub(crate) fn ttl_for_repo(repo_root: &Path) -> u64 {
let mut hasher = DefaultHasher::new();
repo_root.as_os_str().hash(&mut hasher);
// `finish() % JITTER` is uniform-ish in [0, TTL_JITTER_SECS).
Self::TTL_BASE_SECS + hasher.finish() % Self::TTL_JITTER_SECS
}
/// An entry is valid while the branch head is unchanged and the entry is
/// younger than the repo's TTL. `saturating_sub` guards against a
/// `checked_at` in the future (clock skew) underflowing.
pub(super) fn is_valid(&self, current_head: &str, now_secs: u64, repo_root: &Path) -> bool {
let ttl = Self::ttl_for_repo(repo_root);
self.head == current_head && now_secs.saturating_sub(self.checked_at) < ttl
}
/// Directory holding all per-branch cache files for `repo`.
fn cache_dir(repo: &Repository) -> PathBuf {
repo.wt_dir().join("cache").join("ci-status")
}
/// Cache file path for `branch`; the branch name is sanitized so it is
/// safe to use as a filename.
fn cache_file(repo: &Repository, branch: &str) -> PathBuf {
let dir = Self::cache_dir(repo);
let safe_branch = sanitize_for_filename(branch);
dir.join(format!("{safe_branch}.json"))
}
/// Best-effort read of the cached status for `branch`; returns `None` on
/// a missing, unreadable, or unparsable file.
pub(super) fn read(repo: &Repository, branch: &str) -> Option<Self> {
let path = Self::cache_file(repo, branch);
let json = fs::read_to_string(&path).ok()?;
serde_json::from_str(&json).ok()
}
/// Best-effort atomic write: serialize to a `.json.tmp` sibling, then
/// rename over the real file. All failures are logged at debug level and
/// otherwise swallowed — the cache is purely an optimization.
pub(super) fn write(&self, repo: &Repository, branch: &str) {
let path = Self::cache_file(repo, branch);
if let Some(parent) = path.parent()
&& let Err(e) = fs::create_dir_all(parent)
{
log::debug!("Failed to create cache dir for {}: {}", branch, e);
return;
}
let Ok(json) = serde_json::to_string(self) else {
log::debug!("Failed to serialize CI cache for {}", branch);
return;
};
// `with_extension` replaces the trailing "json", yielding "<branch>.json.tmp".
let temp_path = path.with_extension("json.tmp");
if let Err(e) = fs::write(&temp_path, &json) {
log::debug!("Failed to write CI cache temp file for {}: {}", branch, e);
return;
}
// Windows rename fails when the destination exists, so remove it first.
#[cfg(windows)]
let _ = fs::remove_file(&path);
if let Err(e) = fs::rename(&temp_path, &path) {
log::debug!("Failed to rename CI cache file for {}: {}", branch, e);
let _ = fs::remove_file(&temp_path);
}
}
/// All cached entries for `repo` as `(branch, entry)` pairs. Unreadable
/// or unparsable files are skipped; a missing cache directory yields an
/// empty vec. Temp files (`.tmp`) are excluded by the extension filter.
pub(crate) fn list_all(repo: &Repository) -> Vec<(String, Self)> {
let Ok(entries) = fs::read_dir(Self::cache_dir(repo)) else {
return Vec::new();
};
entries
.filter_map(|entry| {
let path = entry.ok()?.path();
if path.extension()?.to_str()? != "json" {
return None;
}
let json = fs::read_to_string(&path).ok()?;
let cached: Self = serde_json::from_str(&json).ok()?;
Some((cached.branch.clone(), cached))
})
.collect()
}
/// Removes every cache file for `repo`, returning the number of entries
/// cleared. Also sweeps orphaned `.tmp` files left behind when a `write`
/// was interrupted between the temp write and the rename; those are
/// removed but not counted, so the return value stays "entries cleared".
pub(crate) fn clear_all(repo: &Repository) -> usize {
let Ok(entries) = fs::read_dir(Self::cache_dir(repo)) else {
return 0;
};
let mut cleared = 0;
for entry in entries.flatten() {
let path = entry.path();
match path.extension() {
Some(ext) if ext == "json" => {
if fs::remove_file(&path).is_ok() {
cleared += 1;
}
}
// Orphaned temp file from a crashed write: clean up, don't count.
Some(ext) if ext == "tmp" => {
let _ = fs::remove_file(&path);
}
_ => {}
}
}
cleared
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_ttl_jitter_range_and_determinism() {
// Every repo path must produce a TTL inside the jitter window.
for repo in [
"/tmp/repo1",
"/tmp/repo2",
"/workspace/project",
"/home/user/code",
] {
let ttl = CachedCiStatus::ttl_for_repo(Path::new(repo));
assert!(
(30..60).contains(&ttl),
"TTL {} for path {} should be in [30, 60)",
ttl,
repo
);
}
// The TTL is a pure function of the path: recomputing must agree.
let fixed = Path::new("/some/consistent/path");
assert_eq!(
CachedCiStatus::ttl_for_repo(fixed),
CachedCiStatus::ttl_for_repo(fixed),
"Same path should produce same TTL"
);
// Distinct paths should spread across many distinct TTL values.
let unique: std::collections::HashSet<u64> = (0..20)
.map(|i| format!("/repo/path{}", i))
.map(|p| CachedCiStatus::ttl_for_repo(Path::new(&p)))
.collect();
assert!(
unique.len() >= 10,
"Expected diverse TTLs across paths, got {} unique values",
unique.len()
);
}
}