use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::time::{Duration, Instant};
use std::sync::{Arc, RwLock};
use crate::core::{Expression, Number, BinaryOperator, UnaryOperator};
use crate::api::CacheConfig;
/// A cached value plus the bookkeeping used for TTL expiry and priority-based eviction.
#[derive(Debug, Clone)]
pub struct CacheEntry<T> {
    /// The cached result.
    pub value: T,
    /// When the entry was created; `is_expired` measures the TTL from this instant.
    pub created_at: Instant,
    /// When the entry was created or last touched via `access`.
    pub last_accessed: Instant,
    /// Creation counts as the first access, so this starts at 1.
    pub access_count: u64,
    /// Caller-supplied relative cost of recomputing `value`; weights `priority`.
    pub compute_cost: u32,
}
impl<T> CacheEntry<T> {
    /// Wraps `value` in a fresh entry stamped with the current time.
    pub fn new(value: T, compute_cost: u32) -> Self {
        let stamp = Instant::now();
        Self {
            value,
            created_at: stamp,
            last_accessed: stamp,
            access_count: 1,
            compute_cost,
        }
    }

    /// Records a read: refreshes `last_accessed` and bumps `access_count`.
    pub fn access(&mut self) {
        self.last_accessed = Instant::now();
        self.access_count += 1;
    }

    /// Returns `true` once strictly more than `ttl` has elapsed since creation.
    pub fn is_expired(&self, ttl: Duration) -> bool {
        ttl < self.created_at.elapsed()
    }

    /// Eviction score; higher means more worth keeping.
    ///
    /// `access_count * compute_cost / (seconds_since_last_access + 1)`: frequently used,
    /// expensive, recently touched entries rank highest. The `+ 1` keeps the divisor >= 1.
    pub fn priority(&self) -> f64 {
        let idle_secs = self.last_accessed.elapsed().as_secs_f64();
        let weight = self.access_count as f64 * f64::from(self.compute_cost);
        weight / (1.0 + idle_secs)
    }
}
/// Key for the fast tier, whose cached values are plain `i64`s.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum FastCacheKey {
/// A binary operation on two `i64` operands, presumably `(lhs, rhs, operator)`;
/// e.g. this module's tests cache `BinaryOp(2, 3, BinaryOperator::Add)` as `5`.
BinaryOp(i64, i64, BinaryOperator),
/// A unary operation on one `i64` operand.
UnaryOp(i64, UnaryOperator),
/// A named function applied to `i64` arguments (name semantics are defined by callers).
Function(String, Vec<i64>),
}
/// Key for the exact-arithmetic tier: an operation applied to one or two `Number`s.
///
/// `PartialEq`, `Eq` and `Hash` are written by hand rather than derived; the `Hash`
/// impl renders most operands through `to_string()`/`Debug` instead of hashing
/// `Number` directly.
#[derive(Debug, Clone)]
pub struct ExactCacheKey {
/// First (or only) operand.
pub operand1: Number,
/// Second operand; presumably `None` for unary operations (confirm with callers).
pub operand2: Option<Number>,
/// Operation name, e.g. `"add"` as used in this module's tests.
pub operation: String,
}
impl PartialEq for ExactCacheKey {
    /// Keys are equal when both operands and the operation name match, comparing the
    /// operands with `Number`'s own `PartialEq`.
    fn eq(&self, other: &Self) -> bool {
        let Self {
            operand1,
            operand2,
            operation,
        } = self;
        operand1 == &other.operand1 && operand2 == &other.operand2 && operation == &other.operation
    }
}
// NOTE(review): `Eq` is asserted by hand. If `Number` compares floats with IEEE `==`,
// a key holding NaN is not equal to itself, so such keys can be stored but never hit.
impl Eq for ExactCacheKey {}
impl Hash for ExactCacheKey {
fn hash<H: Hasher>(&self, state: &mut H) {
match &self.operand1 {
Number::Integer(i) => {
0u8.hash(state);
i.to_string().hash(state);
}
Number::Rational(r) => {
1u8.hash(state);
r.to_string().hash(state);
}
Number::Real(r) => {
2u8.hash(state);
r.to_string().hash(state);
}
Number::Complex { real, imaginary } => {
3u8.hash(state);
real.hash(state);
imaginary.hash(state);
}
Number::Symbolic(expr) => {
4u8.hash(state);
format!("{:?}", expr).hash(state);
}
Number::Float(f) => {
5u8.hash(state);
f.to_bits().hash(state);
}
Number::Constant(c) => {
6u8.hash(state);
format!("{:?}", c).hash(state);
}
}
if let Some(ref op2) = self.operand2 {
match op2 {
Number::Integer(i) => {
0u8.hash(state);
i.to_string().hash(state);
}
Number::Rational(r) => {
1u8.hash(state);
r.to_string().hash(state);
}
Number::Real(r) => {
2u8.hash(state);
r.to_string().hash(state);
}
Number::Complex { real, imaginary } => {
3u8.hash(state);
real.hash(state);
imaginary.hash(state);
}
Number::Symbolic(expr) => {
4u8.hash(state);
format!("{:?}", expr).hash(state);
}
Number::Float(f) => {
5u8.hash(state);
f.to_bits().hash(state);
}
Number::Constant(c) => {
6u8.hash(state);
format!("{:?}", c).hash(state);
}
}
}
self.operation.hash(state);
}
}
/// Key for the symbolic tier: an operation applied to an `Expression`.
///
/// `PartialEq`, `Eq` and `Hash` are written by hand; the `Hash` impl goes through the
/// expression's `Debug` rendering.
#[derive(Debug, Clone)]
pub struct SymbolicCacheKey {
/// Input expression.
pub expression: Expression,
/// Operation name, e.g. `"simplify"` as used in this module's tests.
pub operation: String,
/// Optional variable name; presumably the variable an operation such as
/// differentiation is taken with respect to (confirm with callers).
pub variable: Option<String>,
}
impl PartialEq for SymbolicCacheKey {
    /// Keys are equal when expression, operation and variable all match.
    fn eq(&self, other: &Self) -> bool {
        let Self {
            expression,
            operation,
            variable,
        } = other;
        self.expression == *expression && self.operation == *operation && self.variable == *variable
    }
}
// NOTE(review): `Eq` is asserted by hand, presumably because `Expression` does not
// implement `Eq` itself (confirm). `HashMap` lookups rely on the equality above being
// reflexive.
impl Eq for SymbolicCacheKey {}
impl Hash for SymbolicCacheKey {
    /// Hashes the expression's `Debug` rendering, then the operation and variable.
    ///
    /// This agrees with `eq` only if structurally equal expressions always render
    /// identically under `Debug` (true for a derived `Debug`; confirm for `Expression`).
    fn hash<H: Hasher>(&self, state: &mut H) {
        let Self {
            expression,
            operation,
            variable,
        } = self;
        let rendered = format!("{:?}", expression);
        rendered.hash(state);
        operation.hash(state);
        variable.hash(state);
    }
}
/// Three-tier memoisation cache for computation results.
///
/// Each tier is an independent `HashMap` behind its own `RwLock`, and statistics live
/// behind a separate lock. Lookups take the tier's write lock because a hit updates the
/// entry's access bookkeeping.
#[derive(Debug)]
pub struct ComputeCache {
// Results of operations keyed on machine integers.
fast_cache: Arc<RwLock<HashMap<FastCacheKey, CacheEntry<i64>>>>,
// Exact-arithmetic results keyed on `Number` operands.
exact_cache: Arc<RwLock<HashMap<ExactCacheKey, CacheEntry<Number>>>>,
// Symbolic results keyed on expression, operation and optional variable.
symbolic_cache: Arc<RwLock<HashMap<SymbolicCacheKey, CacheEntry<Expression>>>>,
// Enable flag, per-tier capacities and optional TTL.
config: CacheConfig,
// Hit/miss/cleanup counters shared by all tiers.
stats: Arc<RwLock<CacheStats>>,
}
/// Hit/miss counters for each cache tier plus a nominal estimate of time saved.
#[derive(Debug, Default, Clone)]
pub struct CacheStats {
    /// Fast-tier lookups that returned a value.
    pub fast_hits: u64,
    /// Fast-tier lookups that found nothing or an expired entry.
    pub fast_misses: u64,
    /// Exact-tier lookups that returned a value.
    pub exact_hits: u64,
    /// Exact-tier lookups that found nothing or an expired entry.
    pub exact_misses: u64,
    /// Symbolic-tier lookups that returned a value.
    pub symbolic_hits: u64,
    /// Symbolic-tier lookups that found nothing or an expired entry.
    pub symbolic_misses: u64,
    /// Number of eviction / expiry / clear passes performed.
    pub cleanup_count: u64,
    /// Sum of fixed per-tier estimates credited on each hit (not a measurement).
    pub total_time_saved: Duration,
}
impl CacheStats {
    /// `hits / total`, defined as `0.0` when there were no requests.
    fn ratio(hits: u64, total: u64) -> f64 {
        if total == 0 {
            return 0.0;
        }
        hits as f64 / total as f64
    }

    /// Fraction of all lookups, across every tier, that were hits.
    pub fn total_hit_rate(&self) -> f64 {
        let hits = self.fast_hits + self.exact_hits + self.symbolic_hits;
        let misses = self.fast_misses + self.exact_misses + self.symbolic_misses;
        Self::ratio(hits, hits + misses)
    }

    /// Fraction of fast-tier lookups that were hits.
    pub fn fast_hit_rate(&self) -> f64 {
        Self::ratio(self.fast_hits, self.fast_hits + self.fast_misses)
    }

    /// Fraction of exact-tier lookups that were hits.
    pub fn exact_hit_rate(&self) -> f64 {
        Self::ratio(self.exact_hits, self.exact_hits + self.exact_misses)
    }

    /// Fraction of symbolic-tier lookups that were hits.
    pub fn symbolic_hit_rate(&self) -> f64 {
        Self::ratio(self.symbolic_hits, self.symbolic_hits + self.symbolic_misses)
    }
}
impl ComputeCache {
    /// Creates a cache with three empty tiers and zeroed statistics.
    pub fn new(config: CacheConfig) -> Self {
        Self {
            fast_cache: Arc::new(RwLock::new(HashMap::new())),
            exact_cache: Arc::new(RwLock::new(HashMap::new())),
            symbolic_cache: Arc::new(RwLock::new(HashMap::new())),
            config,
            stats: Arc::new(RwLock::new(CacheStats::default())),
        }
    }

    /// Looks up a fast-tier result.
    ///
    /// Returns `None` without recording anything when caching is disabled. Otherwise a
    /// hit refreshes the entry's access bookkeeping, and an entry older than
    /// `config.cache_ttl` is removed and counted as a miss.
    ///
    /// # Panics
    ///
    /// Panics if the tier's lock is poisoned.
    pub fn get_fast(&self, key: &FastCacheKey) -> Option<i64> {
        if !self.config.enabled {
            return None;
        }
        let found = Self::lookup(&self.fast_cache, key, self.config.cache_ttl);
        if found.is_some() {
            self.record_fast_hit();
        } else {
            self.record_fast_miss();
        }
        found
    }

    /// Stores a fast-tier result that cost `compute_cost` (relative units) to compute.
    ///
    /// Inserting a *new* key into a full tier first evicts low-priority entries.
    /// Overwriting an existing key never triggers eviction. A tier configured with
    /// capacity `0` stores nothing.
    ///
    /// # Panics
    ///
    /// Panics if the tier's lock is poisoned.
    pub fn put_fast(&self, key: FastCacheKey, value: i64, compute_cost: u32) {
        let capacity = self.config.fast_cache_size;
        if !self.config.enabled || capacity == 0 {
            return;
        }
        let mut cache = self.fast_cache.write().unwrap();
        // Only a genuinely new key grows the map, so only then is room needed.
        if !cache.contains_key(&key) && cache.len() >= capacity {
            self.cleanup_fast_cache(&mut cache);
        }
        cache.insert(key, CacheEntry::new(value, compute_cost));
    }

    /// Looks up an exact-tier result; same semantics as `get_fast`.
    ///
    /// # Panics
    ///
    /// Panics if the tier's lock is poisoned.
    pub fn get_exact(&self, key: &ExactCacheKey) -> Option<Number> {
        if !self.config.enabled {
            return None;
        }
        let found = Self::lookup(&self.exact_cache, key, self.config.cache_ttl);
        if found.is_some() {
            self.record_exact_hit();
        } else {
            self.record_exact_miss();
        }
        found
    }

    /// Stores an exact-tier result; same eviction semantics as `put_fast`.
    ///
    /// # Panics
    ///
    /// Panics if the tier's lock is poisoned.
    pub fn put_exact(&self, key: ExactCacheKey, value: Number, compute_cost: u32) {
        let capacity = self.config.exact_cache_size;
        if !self.config.enabled || capacity == 0 {
            return;
        }
        let mut cache = self.exact_cache.write().unwrap();
        if !cache.contains_key(&key) && cache.len() >= capacity {
            self.cleanup_exact_cache(&mut cache);
        }
        cache.insert(key, CacheEntry::new(value, compute_cost));
    }

    /// Looks up a symbolic-tier result; same semantics as `get_fast`.
    ///
    /// # Panics
    ///
    /// Panics if the tier's lock is poisoned.
    pub fn get_symbolic(&self, key: &SymbolicCacheKey) -> Option<Expression> {
        if !self.config.enabled {
            return None;
        }
        let found = Self::lookup(&self.symbolic_cache, key, self.config.cache_ttl);
        if found.is_some() {
            self.record_symbolic_hit();
        } else {
            self.record_symbolic_miss();
        }
        found
    }

    /// Stores a symbolic-tier result; same eviction semantics as `put_fast`.
    ///
    /// # Panics
    ///
    /// Panics if the tier's lock is poisoned.
    pub fn put_symbolic(&self, key: SymbolicCacheKey, value: Expression, compute_cost: u32) {
        let capacity = self.config.symbolic_cache_size;
        if !self.config.enabled || capacity == 0 {
            return;
        }
        let mut cache = self.symbolic_cache.write().unwrap();
        if !cache.contains_key(&key) && cache.len() >= capacity {
            self.cleanup_symbolic_cache(&mut cache);
        }
        cache.insert(key, CacheEntry::new(value, compute_cost));
    }

    /// Fetches and touches the entry for `key`, returning a clone of its value.
    ///
    /// An entry older than `ttl` is removed and treated as absent. The tier lock is
    /// released before returning, so callers record statistics without holding it.
    fn lookup<K, V>(
        cache: &RwLock<HashMap<K, CacheEntry<V>>>,
        key: &K,
        ttl: Option<Duration>,
    ) -> Option<V>
    where
        K: Eq + Hash,
        V: Clone,
    {
        let mut map = cache.write().unwrap();
        let entry = map.get_mut(key)?;
        if let Some(ttl) = ttl {
            if entry.is_expired(ttl) {
                map.remove(key);
                return None;
            }
        }
        entry.access();
        Some(entry.value.clone())
    }

    /// Shrinks the fast tier to three quarters of its configured capacity.
    fn cleanup_fast_cache(&self, cache: &mut HashMap<FastCacheKey, CacheEntry<i64>>) {
        self.evict_to_target(cache, self.config.fast_cache_size);
    }

    /// Shrinks the exact tier to three quarters of its configured capacity.
    fn cleanup_exact_cache(&self, cache: &mut HashMap<ExactCacheKey, CacheEntry<Number>>) {
        self.evict_to_target(cache, self.config.exact_cache_size);
    }

    /// Shrinks the symbolic tier to three quarters of its configured capacity.
    fn cleanup_symbolic_cache(&self, cache: &mut HashMap<SymbolicCacheKey, CacheEntry<Expression>>) {
        self.evict_to_target(cache, self.config.symbolic_cache_size);
    }

    /// Shrinks `cache` to at most `capacity * 3 / 4` entries.
    ///
    /// Entries already older than `config.cache_ttl` go first, because lookups would
    /// only report them as misses. Any remaining excess is removed lowest
    /// `CacheEntry::priority` first. Records one cleanup whenever shrinking was needed.
    fn evict_to_target<K, V>(&self, cache: &mut HashMap<K, CacheEntry<V>>, capacity: usize)
    where
        K: Clone + Eq + Hash,
    {
        let target_size = capacity * 3 / 4;
        if cache.len() <= target_size {
            return;
        }
        if let Some(ttl) = self.config.cache_ttl {
            cache.retain(|_, entry| !entry.is_expired(ttl));
        }
        if cache.len() > target_size {
            // Keys are cloned so the map can be mutated after ranking.
            let mut ranked: Vec<(K, f64)> = cache
                .iter()
                .map(|(key, entry)| (key.clone(), entry.priority()))
                .collect();
            ranked.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
            let excess = cache.len() - target_size;
            for (key, _) in ranked.into_iter().take(excess) {
                cache.remove(&key);
            }
        }
        self.record_cleanup();
    }

    /// Drops every entry older than `config.cache_ttl` from all tiers.
    ///
    /// Does nothing (and records nothing) when no TTL is configured. Each tier lock is
    /// held only while that tier is swept.
    pub fn cleanup_expired(&self) {
        if let Some(ttl) = self.config.cache_ttl {
            {
                let mut cache = self.fast_cache.write().unwrap();
                cache.retain(|_, entry| !entry.is_expired(ttl));
            }
            {
                let mut cache = self.exact_cache.write().unwrap();
                cache.retain(|_, entry| !entry.is_expired(ttl));
            }
            {
                let mut cache = self.symbolic_cache.write().unwrap();
                cache.retain(|_, entry| !entry.is_expired(ttl));
            }
            self.record_cleanup();
        }
    }

    /// Empties all three tiers and records one cleanup. Statistics are kept.
    pub fn clear_all(&self) {
        self.fast_cache.write().unwrap().clear();
        self.exact_cache.write().unwrap().clear();
        self.symbolic_cache.write().unwrap().clear();
        self.record_cleanup();
    }

    /// Returns a snapshot of the current statistics.
    pub fn get_stats(&self) -> CacheStats {
        self.stats.read().unwrap().clone()
    }

    /// Returns the current entry count and configured capacity of every tier.
    pub fn get_usage_info(&self) -> CacheUsageInfo {
        let fast_size = self.fast_cache.read().unwrap().len();
        let exact_size = self.exact_cache.read().unwrap().len();
        let symbolic_size = self.symbolic_cache.read().unwrap().len();
        CacheUsageInfo {
            fast_cache_usage: fast_size,
            fast_cache_capacity: self.config.fast_cache_size,
            exact_cache_usage: exact_size,
            exact_cache_capacity: self.config.exact_cache_size,
            symbolic_cache_usage: symbolic_size,
            symbolic_cache_capacity: self.config.symbolic_cache_size,
        }
    }

    // The recorders below are best-effort: a poisoned stats lock is silently skipped.
    // The durations credited on hits are fixed nominal estimates per tier, not measured.

    fn record_fast_hit(&self) {
        if let Ok(mut stats) = self.stats.write() {
            stats.fast_hits += 1;
            stats.total_time_saved += Duration::from_micros(1);
        }
    }

    fn record_fast_miss(&self) {
        if let Ok(mut stats) = self.stats.write() {
            stats.fast_misses += 1;
        }
    }

    fn record_exact_hit(&self) {
        if let Ok(mut stats) = self.stats.write() {
            stats.exact_hits += 1;
            stats.total_time_saved += Duration::from_micros(100);
        }
    }

    fn record_exact_miss(&self) {
        if let Ok(mut stats) = self.stats.write() {
            stats.exact_misses += 1;
        }
    }

    fn record_symbolic_hit(&self) {
        if let Ok(mut stats) = self.stats.write() {
            stats.symbolic_hits += 1;
            stats.total_time_saved += Duration::from_millis(1);
        }
    }

    fn record_symbolic_miss(&self) {
        if let Ok(mut stats) = self.stats.write() {
            stats.symbolic_misses += 1;
        }
    }

    fn record_cleanup(&self) {
        if let Ok(mut stats) = self.stats.write() {
            stats.cleanup_count += 1;
        }
    }
}
/// Snapshot of how many entries each cache tier holds versus its configured capacity.
#[derive(Debug, Clone)]
pub struct CacheUsageInfo {
    /// Entries currently in the fast tier.
    pub fast_cache_usage: usize,
    /// Configured fast-tier capacity.
    pub fast_cache_capacity: usize,
    /// Entries currently in the exact tier.
    pub exact_cache_usage: usize,
    /// Configured exact-tier capacity.
    pub exact_cache_capacity: usize,
    /// Entries currently in the symbolic tier.
    pub symbolic_cache_usage: usize,
    /// Configured symbolic-tier capacity.
    pub symbolic_cache_capacity: usize,
}
impl CacheUsageInfo {
    /// `used / capacity`, defined as `0.0` for a zero-capacity tier.
    fn fraction(used: usize, capacity: usize) -> f64 {
        match capacity {
            0 => 0.0,
            cap => used as f64 / cap as f64,
        }
    }

    /// Fill ratio of the fast tier.
    pub fn fast_cache_usage_rate(&self) -> f64 {
        Self::fraction(self.fast_cache_usage, self.fast_cache_capacity)
    }

    /// Fill ratio of the exact tier.
    pub fn exact_cache_usage_rate(&self) -> f64 {
        Self::fraction(self.exact_cache_usage, self.exact_cache_capacity)
    }

    /// Fill ratio of the symbolic tier.
    pub fn symbolic_cache_usage_rate(&self) -> f64 {
        Self::fraction(self.symbolic_cache_usage, self.symbolic_cache_capacity)
    }

    /// Combined fill ratio: all entries over all capacity.
    pub fn total_usage_rate(&self) -> f64 {
        let used: usize = [
            self.fast_cache_usage,
            self.exact_cache_usage,
            self.symbolic_cache_usage,
        ]
        .iter()
        .sum();
        let capacity: usize = [
            self.fast_cache_capacity,
            self.exact_cache_capacity,
            self.symbolic_cache_capacity,
        ]
        .iter()
        .sum();
        Self::fraction(used, capacity)
    }
}
/// Owns a `ComputeCache` and rate-limits TTL sweeps to at most one per interval.
#[derive(Debug)]
pub struct CacheManager {
    cache: ComputeCache,
    /// When the last sweep ran (construction time until the first sweep).
    last_cleanup: Instant,
    /// Minimum time between sweeps triggered by `periodic_cleanup`.
    cleanup_interval: Duration,
}
impl CacheManager {
    /// Sweep interval used until `set_cleanup_interval` is called: five minutes.
    const DEFAULT_CLEANUP_INTERVAL: Duration = Duration::from_secs(300);

    /// Creates a manager around a fresh `ComputeCache` built from `config`.
    pub fn new(config: CacheConfig) -> Self {
        Self {
            cache: ComputeCache::new(config),
            last_cleanup: Instant::now(),
            cleanup_interval: Self::DEFAULT_CLEANUP_INTERVAL,
        }
    }

    /// Borrows the managed cache.
    pub fn cache(&self) -> &ComputeCache {
        &self.cache
    }

    /// Sweeps expired entries if at least `cleanup_interval` has passed since the last
    /// sweep; otherwise does nothing.
    pub fn periodic_cleanup(&mut self) {
        if self.last_cleanup.elapsed() >= self.cleanup_interval {
            self.force_cleanup();
        }
    }

    /// Sweeps expired entries now and restarts the interval timer.
    pub fn force_cleanup(&mut self) {
        self.cache.cleanup_expired();
        self.last_cleanup = Instant::now();
    }

    /// Changes the minimum time between sweeps performed by `periodic_cleanup`.
    pub fn set_cleanup_interval(&mut self, interval: Duration) {
        self.cleanup_interval = interval;
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::Number;
    use num_bigint::BigInt;

    #[test]
    fn test_cache_entry_creation() {
        let entry = CacheEntry::new(42i64, 10);
        assert_eq!(entry.value, 42);
        assert_eq!(entry.access_count, 1);
        assert_eq!(entry.compute_cost, 10);
    }

    #[test]
    fn test_cache_entry_access() {
        let mut entry = CacheEntry::new(42i64, 10);
        let initial_time = entry.last_accessed;
        std::thread::sleep(Duration::from_millis(1));
        entry.access();
        assert_eq!(entry.access_count, 2);
        assert!(entry.last_accessed > initial_time);
    }

    #[test]
    fn test_cache_entry_expiration() {
        let entry = CacheEntry::new(42i64, 10);
        assert!(!entry.is_expired(Duration::from_secs(1)));
        // On coarse monotonic clocks `elapsed()` can still read zero right after
        // creation, which would make the check below flaky; wait a little first.
        std::thread::sleep(Duration::from_millis(2));
        assert!(entry.is_expired(Duration::from_nanos(1)));
    }

    #[test]
    fn test_fast_cache_operations() {
        let config = CacheConfig::default();
        let cache = ComputeCache::new(config);
        let key = FastCacheKey::BinaryOp(2, 3, BinaryOperator::Add);
        assert!(cache.get_fast(&key).is_none());
        cache.put_fast(key.clone(), 5, 1);
        assert_eq!(cache.get_fast(&key), Some(5));
        let stats = cache.get_stats();
        assert_eq!(stats.fast_hits, 1);
        assert_eq!(stats.fast_misses, 1);
    }

    #[test]
    fn test_exact_cache_operations() {
        let config = CacheConfig::default();
        let cache = ComputeCache::new(config);
        let key = ExactCacheKey {
            operand1: Number::Integer(BigInt::from(123)),
            operand2: Some(Number::Integer(BigInt::from(456))),
            operation: "add".to_string(),
        };
        let value = Number::Integer(BigInt::from(579));
        assert!(cache.get_exact(&key).is_none());
        cache.put_exact(key.clone(), value.clone(), 5);
        assert_eq!(cache.get_exact(&key), Some(value));
        let stats = cache.get_stats();
        assert_eq!(stats.exact_hits, 1);
        assert_eq!(stats.exact_misses, 1);
    }

    #[test]
    fn test_symbolic_cache_operations() {
        let config = CacheConfig::default();
        let cache = ComputeCache::new(config);
        let key = SymbolicCacheKey {
            expression: Expression::variable("x"),
            operation: "simplify".to_string(),
            variable: None,
        };
        let value = Expression::variable("x");
        assert!(cache.get_symbolic(&key).is_none());
        cache.put_symbolic(key.clone(), value.clone(), 10);
        assert_eq!(cache.get_symbolic(&key), Some(value));
        let stats = cache.get_stats();
        assert_eq!(stats.symbolic_hits, 1);
        assert_eq!(stats.symbolic_misses, 1);
    }

    #[test]
    fn test_cache_size_limits() {
        let config = CacheConfig {
            enabled: true,
            fast_cache_size: 2,
            exact_cache_size: 2,
            symbolic_cache_size: 2,
            cache_ttl: None,
        };
        let cache = ComputeCache::new(config);
        cache.put_fast(FastCacheKey::BinaryOp(1, 2, BinaryOperator::Add), 3, 1);
        cache.put_fast(FastCacheKey::BinaryOp(2, 3, BinaryOperator::Add), 5, 1);
        cache.put_fast(FastCacheKey::BinaryOp(3, 4, BinaryOperator::Add), 7, 1);
        let usage = cache.get_usage_info();
        assert!(usage.fast_cache_usage <= 2);
    }

    #[test]
    fn test_cache_stats() {
        let config = CacheConfig::default();
        let cache = ComputeCache::new(config);
        let key = FastCacheKey::BinaryOp(1, 1, BinaryOperator::Add);
        cache.get_fast(&key);
        cache.put_fast(key.clone(), 2, 1);
        cache.get_fast(&key);
        cache.get_fast(&key);
        let stats = cache.get_stats();
        assert_eq!(stats.fast_hits, 2);
        assert_eq!(stats.fast_misses, 1);
        assert_eq!(stats.fast_hit_rate(), 2.0 / 3.0);
    }

    #[test]
    fn test_stats_rates_without_requests() {
        let stats = CacheStats::default();
        assert_eq!(stats.total_hit_rate(), 0.0);
        assert_eq!(stats.fast_hit_rate(), 0.0);
        assert_eq!(stats.exact_hit_rate(), 0.0);
        assert_eq!(stats.symbolic_hit_rate(), 0.0);
    }

    #[test]
    fn test_usage_info_rates() {
        let info = CacheUsageInfo {
            fast_cache_usage: 1,
            fast_cache_capacity: 4,
            exact_cache_usage: 0,
            exact_cache_capacity: 0,
            symbolic_cache_usage: 3,
            symbolic_cache_capacity: 6,
        };
        assert_eq!(info.fast_cache_usage_rate(), 0.25);
        assert_eq!(info.exact_cache_usage_rate(), 0.0);
        assert_eq!(info.symbolic_cache_usage_rate(), 0.5);
        assert_eq!(info.total_usage_rate(), 0.4);
    }

    #[test]
    fn test_cache_manager() {
        let config = CacheConfig::default();
        let mut manager = CacheManager::new(config);
        manager.set_cleanup_interval(Duration::from_millis(1));
        std::thread::sleep(Duration::from_millis(2));
        manager.periodic_cleanup();
        manager.force_cleanup();
    }
}