use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::path::{Path, PathBuf};
use serde::{Deserialize, Serialize};
use worktrunk::cache;
use worktrunk::git::Repository;
use worktrunk::path::sanitize_for_filename;
use super::PrStatus;
/// Cache kind passed to `cache::cache_dir` for the per-branch [`CachedCiStatus`] entries.
const KIND: &str = "ci-status";
/// The largest PR number recorded for a repository, persisted as a single JSON file.
///
/// The stored value only ever increases via [`MaxPrNumber::ratchet`].
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub(crate) struct MaxPrNumber {
    /// Highest PR number seen so far.
    pub number: u64,
}

impl MaxPrNumber {
    /// Cache kind (passed to `cache::cache_dir`) for this entry.
    const KIND: &str = "pr-number";
    /// File name of the single JSON entry inside that cache directory.
    const KEY: &str = "max.json";

    /// Location of the JSON file holding the stored maximum.
    fn path(repo: &Repository) -> PathBuf {
        let dir = cache::cache_dir(repo, Self::KIND);
        dir.join(Self::KEY)
    }

    /// Returns the stored maximum, or `None` when `cache::read_json` yields nothing
    /// (e.g. no entry has been written yet).
    pub(crate) fn read(repo: &Repository) -> Option<u64> {
        let stored: Self = cache::read_json(&Self::path(repo))?;
        Some(stored.number)
    }

    /// Stores `number` only if nothing is stored yet or it exceeds the stored value.
    ///
    /// The read and the write are separate calls, so concurrent ratchets can interleave.
    pub(super) fn ratchet(repo: &Repository, number: u64) {
        let is_new_high = match Self::read(repo) {
            Some(stored) => number > stored,
            None => true,
        };
        if is_new_high {
            cache::write_json(&Self::path(repo), &Self { number });
        }
    }

    /// Removes the stored maximum. Returns `1` if a file was removed, `0` otherwise.
    ///
    /// # Errors
    /// Propagates any error reported by `cache::clear_one`.
    pub(crate) fn clear(repo: &Repository) -> anyhow::Result<usize> {
        let removed = cache::clear_one(&Self::path(repo))?;
        Ok(usize::from(removed))
    }
}
/// A cached CI/PR status lookup for one branch, stored as one JSON file per branch.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct CachedCiStatus {
    /// The cached status. `None` is stored as-is too (presumably "nothing found" — confirm).
    pub status: Option<PrStatus>,
    /// When the status was fetched, on the same seconds clock as the `now_secs` argument of
    /// [`CachedCiStatus::is_valid`] (presumably Unix epoch seconds — confirm at call sites).
    pub checked_at: u64,
    /// HEAD the status was fetched for; the entry is only valid while it matches.
    pub head: String,
    /// Branch this entry belongs to (also the tie-breaker when sorting in `list_all`).
    pub branch: String,
}

impl CachedCiStatus {
    /// Minimum cache lifetime, in seconds.
    const TTL_BASE_SECS: u64 = 30;
    /// Width of the per-repo jitter window added to the base TTL, in seconds.
    const TTL_JITTER_SECS: u64 = 30;

    /// Cache TTL in seconds for the repository rooted at `repo_root`.
    ///
    /// Computed as `TTL_BASE_SECS + hash(repo_root) % TTL_JITTER_SECS`, so the result lies in
    /// `[30, 60)`. Deriving the jitter from the path presumably keeps many repos from
    /// expiring (and refetching) at the same instant.
    ///
    /// The value is deterministic for a given build. However, std leaves `DefaultHasher`'s
    /// algorithm unspecified across Rust releases, so a repo's TTL may shift after a
    /// toolchain upgrade. That only moves the expiry within the jitter window.
    pub(crate) fn ttl_for_repo(repo_root: &Path) -> u64 {
        let mut hasher = DefaultHasher::new();
        repo_root.as_os_str().hash(&mut hasher);
        let hash = hasher.finish();
        let jitter = hash % Self::TTL_JITTER_SECS;
        Self::TTL_BASE_SECS + jitter
    }

    /// Whether this entry can still be served.
    ///
    /// It must have been recorded for `current_head`, and its age (`now_secs - checked_at`)
    /// must be below the repo's TTL. An entry whose `checked_at` is later than `now_secs` is
    /// treated as stale.
    pub(super) fn is_valid(&self, current_head: &str, now_secs: u64, repo_root: &Path) -> bool {
        if self.head != current_head {
            return false;
        }
        let ttl = Self::ttl_for_repo(repo_root);
        // A future `checked_at` means the clock moved backwards (or the entry was written
        // under a skewed clock). `saturating_sub` would report an age of 0 and keep the entry
        // "fresh" until the clock caught up, possibly for a long time. Treat it as stale.
        now_secs
            .checked_sub(self.checked_at)
            .is_some_and(|age| age < ttl)
    }

    /// Directory holding every per-branch entry for `repo`.
    fn cache_dir(repo: &Repository) -> PathBuf {
        cache::cache_dir(repo, KIND)
    }

    /// JSON file for `branch`, named from `sanitize_for_filename(branch)`.
    ///
    /// NOTE(review): if two branch names sanitize to the same string they share this file,
    /// and `read` does not compare the stored `branch` — confirm sanitization is injective.
    fn cache_file(repo: &Repository, branch: &str) -> PathBuf {
        let safe_branch = sanitize_for_filename(branch);
        Self::cache_dir(repo).join(format!("{safe_branch}.json"))
    }

    /// Loads the entry for `branch`. Returns `None` when `cache::read_json` yields nothing.
    pub(super) fn read(repo: &Repository, branch: &str) -> Option<Self> {
        cache::read_json(&Self::cache_file(repo, branch))
    }

    /// Persists this entry under `branch`. Returns `()`, so write failures are not reported.
    pub(super) fn write(&self, repo: &Repository, branch: &str) {
        cache::write_json(&Self::cache_file(repo, branch), self);
    }

    /// Every readable `*.json` entry in the cache directory.
    ///
    /// Results are sorted newest first (`checked_at` descending), with ties broken by
    /// `branch` ascending. An unreadable directory yields an empty list, and entries that
    /// `cache::read_json` cannot load are skipped.
    pub(crate) fn list_all(repo: &Repository) -> Vec<Self> {
        let dir = Self::cache_dir(repo);
        let Ok(entries) = std::fs::read_dir(&dir) else {
            return Vec::new();
        };
        let mut out: Vec<Self> = entries
            .filter_map(|entry| {
                let path = entry.ok()?.path();
                if path.extension()?.to_str()? != "json" {
                    return None;
                }
                cache::read_json(&path)
            })
            .collect();
        out.sort_by(|a, b| {
            b.checked_at
                .cmp(&a.checked_at)
                .then_with(|| a.branch.cmp(&b.branch))
        });
        out
    }

    /// Removes the entry for `branch`, forwarding `cache::clear_one`'s result
    /// (presumably `true` when a file was deleted).
    pub(crate) fn clear_one(repo: &Repository, branch: &str) -> anyhow::Result<bool> {
        cache::clear_one(&Self::cache_file(repo, branch))
    }

    /// Removes all JSON entries in the cache directory, forwarding the count reported by
    /// `cache::clear_json_files`.
    pub(crate) fn clear_all(repo: &Repository) -> anyhow::Result<usize> {
        cache::clear_json_files(&Self::cache_dir(repo))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use worktrunk::testing::TestRepo;

    /// The stored maximum only moves upward, and `clear` reports whether a file was removed.
    #[test]
    fn test_max_pr_number_ratchet() {
        let fixture = TestRepo::with_initial_commit();
        let repo = Repository::at(fixture.root_path()).unwrap();

        // Nothing has been stored yet.
        assert_eq!(MaxPrNumber::read(&repo), None);

        // (value passed to `ratchet`, value expected to be stored afterwards)
        let steps: [(u64, u64); 3] = [(42, 42), (7, 42), (100, 100)];
        for (candidate, expected) in steps {
            MaxPrNumber::ratchet(&repo, candidate);
            assert_eq!(MaxPrNumber::read(&repo), Some(expected));
        }

        // The first clear removes the file; the second finds nothing to remove.
        assert_eq!(MaxPrNumber::clear(&repo).unwrap(), 1);
        assert_eq!(MaxPrNumber::read(&repo), None);
        assert_eq!(MaxPrNumber::clear(&repo).unwrap(), 0);
    }

    /// TTLs stay inside the jitter window, are stable per path, and vary across paths.
    #[test]
    fn test_ttl_jitter_range_and_determinism() {
        for path in [
            "/tmp/repo1",
            "/tmp/repo2",
            "/workspace/project",
            "/home/user/code",
        ] {
            let ttl = CachedCiStatus::ttl_for_repo(Path::new(path));
            assert!(
                (30..60).contains(&ttl),
                "TTL {} for path {} should be in [30, 60)",
                ttl,
                path
            );
        }

        let stable = Path::new("/some/consistent/path");
        let (first, second) = (
            CachedCiStatus::ttl_for_repo(stable),
            CachedCiStatus::ttl_for_repo(stable),
        );
        assert_eq!(first, second, "Same path should produce same TTL");

        let distinct: std::collections::HashSet<u64> = (0..20)
            .map(|i| format!("/repo/path{}", i))
            .map(|p| CachedCiStatus::ttl_for_repo(Path::new(&p)))
            .collect();
        assert!(
            distinct.len() >= 10,
            "Expected diverse TTLs across paths, got {} unique values",
            distinct.len()
        );
    }
}