use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
#[cfg(feature = "rustfft-backend")]
use rustfft::FftPlanner;
/// Cache key identifying an FFT plan by transform length and direction.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct PlanKey {
    /// Number of points in the transform.
    size: usize,
    /// `true` for a forward transform, `false` for an inverse one.
    forward: bool,
}
/// Cache entry for the rustfft backend: a shared plan plus recency and usage metadata.
#[cfg(feature = "rustfft-backend")]
#[derive(Clone)]
struct CachedPlan {
    /// The reusable plan handed out to callers (cloning the `Arc` shares it).
    plan: Arc<dyn rustfft::Fft<f64>>,
    /// When this entry was inserted or last returned from a lookup.
    last_used: Instant,
    /// Number of lookups this entry has satisfied, counting the one that created it.
    usage_count: usize,
}
/// Cache entry for the oxifft backend.
///
/// No plan object is stored; the entry only tracks recency and usage for the
/// `PlanKey` it is stored under.
#[cfg(all(feature = "oxifft", not(feature = "rustfft-backend")))]
#[derive(Clone)]
struct CachedPlan {
    /// Transform length; duplicates the `size` of the entry's `PlanKey`.
    // Written on insert but never read in this file; allowed so the duplicated
    // metadata does not produce "field is never read" warnings.
    #[allow(dead_code)]
    size: usize,
    /// Transform direction (`true` = forward); duplicates the entry's `PlanKey`.
    // Same reason as `size`: written but never read in this file.
    #[allow(dead_code)]
    forward: bool,
    /// When this entry was inserted or last looked up.
    last_used: Instant,
    /// Number of lookups this entry has satisfied, counting the one that created it.
    usage_count: usize,
}
/// Mutex-guarded cache of FFT plans keyed by transform size and direction, with
/// hit/miss bookkeeping.
///
/// Entries are subject to a capacity limit (`max_entries`) and an idle-age limit
/// (`max_age`).
pub struct PlanCache {
    /// Cached entries, keyed by (size, direction).
    cache: Arc<Mutex<HashMap<PlanKey, CachedPlan>>>,
    /// Capacity limit consulted before inserting a new entry.
    max_entries: usize,
    /// Idle time after which an entry is treated as stale.
    max_age: Duration,
    /// When `false`, lookups bypass the cache.
    enabled: Arc<Mutex<bool>>,
    /// Number of lookups served from the cache.
    hit_count: Arc<Mutex<u64>>,
    /// Number of lookups that were not served from the cache.
    miss_count: Arc<Mutex<u64>>,
}
impl PlanCache {
pub fn new() -> Self {
Self {
cache: Arc::new(Mutex::new(HashMap::new())),
max_entries: 128,
max_age: Duration::from_secs(3600), enabled: Arc::new(Mutex::new(true)),
hit_count: Arc::new(Mutex::new(0)),
miss_count: Arc::new(Mutex::new(0)),
}
}
pub fn with_config(max_entries: usize, max_age: Duration) -> Self {
Self {
cache: Arc::new(Mutex::new(HashMap::new())),
max_entries,
max_age,
enabled: Arc::new(Mutex::new(true)),
hit_count: Arc::new(Mutex::new(0)),
miss_count: Arc::new(Mutex::new(0)),
}
}
pub fn set_enabled(&self, enabled: bool) {
*self.enabled.lock().expect("Operation failed") = enabled;
}
pub fn is_enabled(&self) -> bool {
*self.enabled.lock().expect("Operation failed")
}
pub fn clear(&self) {
if let Ok(mut cache) = self.cache.lock() {
cache.clear();
}
}
pub fn get_stats(&self) -> CacheStats {
let hit_count = *self.hit_count.lock().expect("Operation failed");
let miss_count = *self.miss_count.lock().expect("Operation failed");
let total_requests = hit_count + miss_count;
let hit_rate = if total_requests > 0 {
hit_count as f64 / total_requests as f64
} else {
0.0
};
let size = self.cache.lock().map(|c| c.len()).unwrap_or(0);
CacheStats {
hit_count,
miss_count,
hit_rate,
size,
max_size: self.max_entries,
}
}
#[cfg(feature = "rustfft-backend")]
pub fn get_or_create_plan(
&self,
size: usize,
forward: bool,
planner: &mut FftPlanner<f64>,
) -> Arc<dyn rustfft::Fft<f64>> {
if !*self.enabled.lock().expect("Operation failed") {
return if forward {
planner.plan_fft_forward(size)
} else {
planner.plan_fft_inverse(size)
};
}
let key = PlanKey { size, forward };
if let Ok(mut cache) = self.cache.lock() {
if let Some(cached) = cache.get_mut(&key) {
if cached.last_used.elapsed() <= self.max_age {
cached.last_used = Instant::now();
cached.usage_count += 1;
*self.hit_count.lock().expect("Operation failed") += 1;
return cached.plan.clone();
} else {
cache.remove(&key);
}
}
}
*self.miss_count.lock().expect("Operation failed") += 1;
let plan: Arc<dyn rustfft::Fft<f64>> = if forward {
planner.plan_fft_forward(size)
} else {
planner.plan_fft_inverse(size)
};
if let Ok(mut cache) = self.cache.lock() {
if cache.len() >= self.max_entries {
self.evict_old_entries(&mut cache);
}
cache.insert(
key,
CachedPlan {
plan: plan.clone(),
last_used: Instant::now(),
usage_count: 1,
},
);
}
plan
}
#[cfg(all(feature = "oxifft", not(feature = "rustfft-backend")))]
pub fn track_plan_usage(&self, size: usize, forward: bool) {
if !*self.enabled.lock().expect("Operation failed") {
return;
}
let key = PlanKey { size, forward };
if let Ok(mut cache) = self.cache.lock() {
if let Some(cached) = cache.get_mut(&key) {
if cached.last_used.elapsed() <= self.max_age {
cached.last_used = Instant::now();
cached.usage_count += 1;
*self.hit_count.lock().expect("Operation failed") += 1;
return;
} else {
cache.remove(&key);
}
}
}
*self.miss_count.lock().expect("Operation failed") += 1;
if let Ok(mut cache) = self.cache.lock() {
if cache.len() >= self.max_entries {
self.evict_old_entries(&mut cache);
}
cache.insert(
key,
CachedPlan {
size,
forward,
last_used: Instant::now(),
usage_count: 1,
},
);
}
}
fn evict_old_entries(&self, cache: &mut HashMap<PlanKey, CachedPlan>) {
cache.retain(|_, v| v.last_used.elapsed() <= self.max_age);
while cache.len() >= self.max_entries {
if let Some((key_to_remove_, _)) = cache
.iter()
.min_by_key(|(_, v)| (v.last_used, v.usage_count))
.map(|(k, v)| (k.clone(), v.clone()))
{
cache.remove(&key_to_remove_);
} else {
break;
}
}
}
#[cfg(feature = "rustfft-backend")]
pub fn precompute_common_sizes(&self, sizes: &[usize], planner: &mut FftPlanner<f64>) {
for &size in sizes {
self.get_or_create_plan(size, true, planner);
self.get_or_create_plan(size, false, planner);
}
}
#[cfg(all(feature = "oxifft", not(feature = "rustfft-backend")))]
pub fn precompute_common_sizes(&self, sizes: &[usize]) {
for &size in sizes {
self.track_plan_usage(size, true);
self.track_plan_usage(size, false);
}
}
}
impl Default for PlanCache {
fn default() -> Self {
Self::new()
}
}
/// Point-in-time snapshot of a plan cache's counters and occupancy.
#[derive(Debug, Clone)]
pub struct CacheStats {
    /// Lookups served from the cache.
    pub hit_count: u64,
    /// Lookups that were not served from the cache.
    pub miss_count: u64,
    /// Fraction of lookups that were hits (`0.0` when there have been no lookups).
    pub hit_rate: f64,
    /// Number of entries currently stored.
    pub size: usize,
    /// Configured capacity of the cache.
    pub max_size: usize,
}

impl std::fmt::Display for CacheStats {
    /// Renders e.g. `Cache Stats: 3 hits, 1 misses (75.0% hit rate), 2/128 entries`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            hit_count,
            miss_count,
            hit_rate,
            size,
            max_size,
        } = self;
        let percent = hit_rate * 100.0;
        write!(
            f,
            "Cache Stats: {hit_count} hits, {miss_count} misses ({percent:.1}% hit rate), {size}/{max_size} entries"
        )
    }
}
/// Process-wide plan cache, created lazily by `get_global_cache` or explicitly by
/// `init_global_cache`.
static GLOBAL_PLAN_CACHE: std::sync::OnceLock<PlanCache> = std::sync::OnceLock::new();

/// Returns the process-wide plan cache.
///
/// If `init_global_cache` has not run yet, this creates the cache with default settings.
// Not called from this file; allowed so builds that never use it stay warning-free.
#[allow(dead_code)]
pub fn get_global_cache() -> &'static PlanCache {
    GLOBAL_PLAN_CACHE.get_or_init(PlanCache::default)
}
/// Installs the process-wide plan cache with custom limits.
///
/// # Errors
///
/// Returns `Err("Global plan cache already initialized")` if the global cache already
/// exists. This includes the case where `get_global_cache` ran first and created it with
/// default settings.
// Not called from this file; allowed so builds that never use it stay warning-free.
#[allow(dead_code)]
pub fn init_global_cache(max_entries: usize, max_age: Duration) -> Result<(), &'static str> {
    let configured = PlanCache::with_config(max_entries, max_age);
    match GLOBAL_PLAN_CACHE.set(configured) {
        Ok(()) => Ok(()),
        Err(_rejected) => Err("Global plan cache already initialized"),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[cfg(feature = "rustfft-backend")]
    #[test]
    fn test_plan_cache_basic_rustfft() {
        let cache = PlanCache::new();
        let mut planner = FftPlanner::new();
        // The first request misses and fills the cache; the repeat is served as a hit.
        for _ in 0..2 {
            let _plan = cache.get_or_create_plan(128, true, &mut planner);
        }
        let stats = cache.get_stats();
        assert_eq!(stats.hit_count, 1);
        assert_eq!(stats.miss_count, 1);
    }

    #[cfg(feature = "rustfft-backend")]
    #[test]
    fn test_cache_eviction_rustfft() {
        let cache = PlanCache::with_config(2, Duration::from_secs(3600));
        let mut planner = FftPlanner::new();
        // Three distinct keys into a two-slot cache force one eviction.
        for size in [64, 128, 256] {
            cache.get_or_create_plan(size, true, &mut planner);
        }
        assert_eq!(cache.get_stats().size, 2);
    }

    #[cfg(feature = "rustfft-backend")]
    #[test]
    fn test_cache_disabled_rustfft() {
        let cache = PlanCache::new();
        cache.set_enabled(false);
        let mut planner = FftPlanner::new();
        // With caching switched off, lookups bypass the cache and the counters stay at zero.
        for _ in 0..2 {
            cache.get_or_create_plan(128, true, &mut planner);
        }
        let stats = cache.get_stats();
        assert_eq!(stats.hit_count, 0);
        assert_eq!(stats.miss_count, 0);
    }

    #[cfg(all(feature = "oxifft", not(feature = "rustfft-backend")))]
    #[test]
    fn test_plan_cache_basic_oxifft() {
        let cache = PlanCache::new();
        // The first lookup misses and registers the key; the repeat is a hit.
        for _ in 0..2 {
            cache.track_plan_usage(128, true);
        }
        let stats = cache.get_stats();
        assert_eq!(stats.hit_count, 1);
        assert_eq!(stats.miss_count, 1);
    }

    #[cfg(all(feature = "oxifft", not(feature = "rustfft-backend")))]
    #[test]
    fn test_cache_eviction_oxifft() {
        let cache = PlanCache::with_config(2, Duration::from_secs(3600));
        // Three distinct keys into a two-slot cache force one eviction.
        for size in [64, 128, 256] {
            cache.track_plan_usage(size, true);
        }
        assert_eq!(cache.get_stats().size, 2);
    }

    #[cfg(all(feature = "oxifft", not(feature = "rustfft-backend")))]
    #[test]
    fn test_cache_disabled_oxifft() {
        let cache = PlanCache::new();
        cache.set_enabled(false);
        // With caching switched off, lookups are ignored and the counters stay at zero.
        for _ in 0..2 {
            cache.track_plan_usage(128, true);
        }
        let stats = cache.get_stats();
        assert_eq!(stats.hit_count, 0);
        assert_eq!(stats.miss_count, 0);
    }
}