use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use crate::registry::AcceleratorRegistry;
/// In-memory cache of an [`AcceleratorRegistry`] with a time-to-live.
///
/// `get` hands out the stored registry while it is younger than `ttl` and
/// otherwise calls `AcceleratorRegistry::detect()` and stores the result.
pub struct CachedRegistry {
/// Maximum age at which a stored registry is still returned by `get`.
ttl: Duration,
/// Cached registry and its timestamp, guarded by a mutex so the cache can
/// be used through `&self`.
inner: Mutex<CacheState>,
}
impl std::fmt::Debug for CachedRegistry {
    /// Formats the cache as `CachedRegistry { ttl, cached }`.
    ///
    /// `cached` is `true` only when the lock can be taken and a registry is
    /// currently stored; the registry itself is not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // A poisoned lock is reported as "nothing cached" instead of failing
        // the whole formatting call.
        let cached = match self.inner.lock() {
            Ok(state) => state.registry.is_some(),
            Err(_) => false,
        };

        let mut out = f.debug_struct("CachedRegistry");
        out.field("ttl", &self.ttl);
        out.field("cached", &cached);
        out.finish()
    }
}
/// Mutable contents of a registry cache; always accessed through a mutex.
struct CacheState {
/// The stored registry, or `None` when nothing is cached.
registry: Option<Arc<AcceleratorRegistry>>,
/// Instant from which the entry's age is measured against the TTL.
/// While this is `None`, `get` never treats the entry as fresh.
last_detect: Option<Instant>,
}
impl CachedRegistry {
pub fn new(ttl: Duration) -> Self {
Self {
ttl,
inner: Mutex::new(CacheState {
registry: None,
last_detect: None,
}),
}
}
pub fn get(&self) -> Arc<AcceleratorRegistry> {
{
let state = match self.inner.lock() {
Ok(guard) => guard,
Err(poisoned) => {
tracing::warn!("CachedRegistry lock was poisoned, invalidating cache");
let mut guard = poisoned.into_inner();
guard.registry = None;
guard.last_detect = None;
guard
}
};
if let Some(ref reg) = state.registry
&& let Some(last) = state.last_detect
&& Instant::now().duration_since(last) < self.ttl
{
return Arc::clone(reg);
}
}
let reg = Arc::new(AcceleratorRegistry::detect());
let mut state = match self.inner.lock() {
Ok(guard) => guard,
Err(poisoned) => poisoned.into_inner(),
};
state.registry = Some(Arc::clone(®));
state.last_detect = Some(Instant::now());
reg
}
pub fn invalidate(&self) {
let mut state = match self.inner.lock() {
Ok(guard) => guard,
Err(poisoned) => {
tracing::warn!("CachedRegistry lock was poisoned, invalidating cache");
let mut guard = poisoned.into_inner();
guard.registry = None;
guard.last_detect = None;
guard
}
};
state.registry = None;
state.last_detect = None;
}
pub fn ttl(&self) -> Duration {
self.ttl
}
}
/// Registry cache with two levels: an in-memory entry and a JSON file on disk.
///
/// `get` consults memory first, then the file at `cache_path`, and only then
/// calls `AcceleratorRegistry::detect()`.
pub struct DiskCachedRegistry {
/// Maximum age applied to both the in-memory entry and the on-disk file.
ttl: Duration,
/// Location of the on-disk cache file.
cache_path: std::path::PathBuf,
/// In-memory entry, guarded by a mutex so the cache can be used through
/// `&self`.
memory: Mutex<CacheState>,
}
impl std::fmt::Debug for DiskCachedRegistry {
    /// Formats the cache as `DiskCachedRegistry { ttl, cache_path }`.
    ///
    /// The in-memory state is not printed, so formatting never takes the lock.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            ttl, cache_path, ..
        } = self;

        let mut out = f.debug_struct("DiskCachedRegistry");
        out.field("ttl", ttl);
        out.field("cache_path", cache_path);
        out.finish()
    }
}
impl DiskCachedRegistry {
pub fn new(ttl: Duration) -> Self {
let cache_path = Self::default_cache_path();
Self {
ttl,
cache_path,
memory: Mutex::new(CacheState {
registry: None,
last_detect: None,
}),
}
}
pub fn with_path(ttl: Duration, path: std::path::PathBuf) -> Self {
Self {
ttl,
cache_path: path,
memory: Mutex::new(CacheState {
registry: None,
last_detect: None,
}),
}
}
pub fn get(&self) -> Arc<AcceleratorRegistry> {
let mut state = match self.memory.lock() {
Ok(guard) => guard,
Err(poisoned) => {
tracing::warn!("DiskCachedRegistry lock was poisoned, invalidating cache");
let mut guard = poisoned.into_inner();
guard.registry = None;
guard.last_detect = None;
guard
}
};
if let Some(ref reg) = state.registry
&& let Some(last) = state.last_detect
&& Instant::now().duration_since(last) < self.ttl
{
return Arc::clone(reg);
}
if let Some(reg) = self.read_disk_cache() {
let arc = Arc::new(reg);
state.registry = Some(Arc::clone(&arc));
state.last_detect = Some(Instant::now());
return arc;
}
let reg = Arc::new(AcceleratorRegistry::detect());
state.registry = Some(Arc::clone(®));
state.last_detect = Some(Instant::now());
self.write_disk_cache(®);
reg
}
pub fn invalidate(&self) {
let mut state = match self.memory.lock() {
Ok(guard) => guard,
Err(poisoned) => poisoned.into_inner(),
};
state.registry = None;
state.last_detect = None;
let _ = std::fs::remove_file(&self.cache_path);
}
pub fn cache_path(&self) -> &std::path::Path {
&self.cache_path
}
fn default_cache_path() -> std::path::PathBuf {
let cache_dir = std::env::var("XDG_CACHE_HOME")
.ok()
.filter(|s| !s.is_empty())
.map(std::path::PathBuf::from)
.or_else(|| {
std::env::var("HOME")
.ok()
.map(|h| std::path::PathBuf::from(h).join(".cache"))
})
.unwrap_or_else(|| std::path::PathBuf::from("/tmp"));
cache_dir.join("ai-hwaccel").join("registry.json")
}
fn read_disk_cache(&self) -> Option<AcceleratorRegistry> {
let metadata = std::fs::metadata(&self.cache_path).ok()?;
let age = metadata.modified().ok()?.elapsed().unwrap_or(Duration::MAX);
if age > self.ttl {
return None;
}
let data = std::fs::read_to_string(&self.cache_path).ok()?;
let reg = AcceleratorRegistry::from_json(&data).ok()?;
tracing::debug!(
age_secs = age.as_secs_f64(),
path = %self.cache_path.display(),
"loaded registry from disk cache"
);
Some(reg)
}
fn write_disk_cache(&self, registry: &AcceleratorRegistry) {
if let Some(parent) = self.cache_path.parent() {
let _ = std::fs::create_dir_all(parent);
}
let json = match serde_json::to_string(registry) {
Ok(j) => j,
Err(e) => {
tracing::debug!(error = %e, "failed to serialize registry for disk cache");
return;
}
};
let tmp_path = self.cache_path.with_extension("tmp");
if std::fs::write(&tmp_path, &json).is_ok() {
if std::fs::rename(&tmp_path, &self.cache_path).is_ok() {
return;
}
let _ = std::fs::remove_file(&tmp_path);
}
if let Err(e) = std::fs::write(&self.cache_path, json) {
tracing::debug!(
error = %e,
path = %self.cache_path.display(),
"failed to write disk cache"
);
}
}
}