use std::collections::HashSet;
use std::path::Path;
use std::time::SystemTime;
use anyhow::Result;
use super::ContextEntry;
/// Decides whether cached `ContextEntry` values are still fresh.
///
/// Holds a snapshot of git state (the HEAD commit and the paths git listed in
/// its status output) plus the working-tree root used to resolve relative file
/// paths when checking modification times.
#[derive(Debug)]
pub struct InvalidationChecker {
    /// Commit id HEAD pointed at when the snapshot was taken, if it resolved.
    head_commit: Option<String>,
    /// Paths (relative to the working tree) reported by git's status listing.
    dirty_files: HashSet<String>,
    /// Working-tree root; `None` when not backed by a non-bare repository.
    repo_root: Option<std::path::PathBuf>,
}
impl InvalidationChecker {
    /// Creates a checker with no repository attached.
    ///
    /// There is no HEAD commit and no dirty-file set, so [`is_valid`](Self::is_valid)
    /// relies only on TTL expiry and file mtimes (paths resolved as given).
    pub fn new() -> Self {
        Self {
            head_commit: None,
            dirty_files: HashSet::new(),
            repo_root: None,
        }
    }

    /// Builds a checker from the git repository containing `repo_path`.
    ///
    /// The repository is located with `git2::Repository::discover`, so
    /// `repo_path` may be any path inside it. HEAD and the changed-path set are
    /// captured once here; call [`refresh`](Self::refresh) to re-read them.
    ///
    /// # Errors
    ///
    /// Returns an error if no repository can be discovered from `repo_path`.
    pub fn from_git_repo(repo_path: impl AsRef<Path>) -> Result<Self> {
        let repo = git2::Repository::discover(repo_path.as_ref())?;
        let (head_commit, dirty_files) = Self::read_git_state(&repo);
        Ok(Self {
            head_commit,
            dirty_files,
            repo_root: repo.workdir().map(Path::to_path_buf),
        })
    }

    /// Re-reads HEAD and the changed-path set from the repository.
    ///
    /// Does nothing when there is no working-tree root (checker built with
    /// [`new`](Self::new), or from a bare repository).
    ///
    /// # Errors
    ///
    /// Returns an error if the repository at the stored root can no longer be
    /// opened; the previously captured state is then left untouched.
    pub fn refresh(&mut self) -> Result<()> {
        if let Some(root) = &self.repo_root {
            let repo = git2::Repository::open(root)?;
            let (head_commit, dirty_files) = Self::read_git_state(&repo);
            self.head_commit = head_commit;
            self.dirty_files = dirty_files;
        }
        Ok(())
    }

    /// Reads the HEAD commit id and the set of paths git reports as changed.
    ///
    /// Both reads are best-effort: a HEAD that cannot be resolved to a commit
    /// id yields `None`, and a failed status query yields an empty set.
    fn read_git_state(repo: &git2::Repository) -> (Option<String>, HashSet<String>) {
        let head_commit = repo
            .head()
            .ok()
            .and_then(|head| head.target())
            .map(|oid| oid.to_string());

        // `statuses(None)` would use libgit2's defaults, which also list
        // *ignored* files. Those are not under version control, so flagging
        // them dirty made their cache entries permanently invalid; leave them
        // to the mtime check. Untracked files are still reported.
        let mut opts = git2::StatusOptions::new();
        opts.include_untracked(true)
            .recurse_untracked_dirs(true)
            .include_ignored(false);

        let dirty_files = repo
            .statuses(Some(&mut opts))
            .map(|statuses| {
                statuses
                    .iter()
                    .filter_map(|entry| entry.path().map(str::to_owned))
                    .collect()
            })
            .unwrap_or_default();

        (head_commit, dirty_files)
    }

    /// Returns the HEAD commit id captured at construction or last refresh.
    pub fn head_commit(&self) -> Option<&str> {
        self.head_commit.as_deref()
    }

    /// Returns `true` if `path` appeared in git's status listing.
    ///
    /// `path` must match git's working-tree-relative spelling exactly.
    pub fn is_dirty(&self, path: &str) -> bool {
        self.dirty_files.contains(path)
    }

    /// Returns the file's modification time as whole seconds since the Unix epoch.
    ///
    /// When a working-tree root is known, `path` is joined onto it (an absolute
    /// `path` replaces the root). Returns `None` if the file cannot be stat'ed,
    /// the platform has no mtime, or the mtime predates the epoch.
    pub fn get_mtime(&self, path: &str) -> Option<i64> {
        let full_path = match &self.repo_root {
            Some(root) => root.join(path),
            None => std::path::PathBuf::from(path),
        };
        let modified = std::fs::metadata(&full_path).ok()?.modified().ok()?;
        let since_epoch = modified.duration_since(SystemTime::UNIX_EPOCH).ok()?;
        // u64 seconds only exceed i64::MAX ~292 billion years after 1970.
        Some(since_epoch.as_secs() as i64)
    }

    /// Reports whether a cached entry may still be used.
    ///
    /// An entry is invalid when any of these hold:
    /// - it has expired (`ContextEntry::is_expired`);
    /// - it is tied to a file that git lists as changed;
    /// - it recorded a file mtime and the file is now newer, or can no longer
    ///   be stat'ed (deleted, moved, unreadable);
    /// - it is not tied to a file, recorded a commit, and HEAD now differs.
    pub fn is_valid(&self, entry: &ContextEntry) -> bool {
        if entry.is_expired() {
            return false;
        }
        match &entry.file_path {
            Some(file_path) => {
                if self.is_dirty(file_path) {
                    return false;
                }
                if let Some(cached_mtime) = entry.file_mtime {
                    // A file that has vanished can no longer back the cached value.
                    match self.get_mtime(file_path) {
                        Some(current_mtime) if current_mtime <= cached_mtime => {}
                        _ => return false,
                    }
                }
            }
            None => {
                if let (Some(entry_commit), Some(current_commit)) =
                    (&entry.git_commit, &self.head_commit)
                {
                    if entry_commit != current_commit {
                        return false;
                    }
                }
            }
        }
        true
    }

    /// Returns the keys of all entries in `entries` that fail [`is_valid`](Self::is_valid).
    pub fn find_invalid(&self, entries: &[ContextEntry]) -> Vec<String> {
        entries
            .iter()
            .filter(|entry| !self.is_valid(entry))
            .map(|entry| entry.key.clone())
            .collect()
    }
}
impl Default for InvalidationChecker {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Entries are valid until their `expires_at` timestamp has passed.
    #[test]
    fn test_ttl_expiration() {
        let checker = InvalidationChecker::new();

        // An entry created just now is expected to be valid.
        let fresh = ContextEntry::new("test:key", json!({"data": "value"}));
        assert!(checker.is_valid(&fresh));

        // Back-date the expiry by an hour so the entry is past its TTL.
        let one_hour_ago = chrono::Utc::now() - chrono::Duration::hours(1);
        let mut stale = ContextEntry::new("test:expired", json!({}));
        stale.expires_at = Some(one_hour_ago);
        assert!(!checker.is_valid(&stale));
    }
}