use super::{Cache, CacheEntryMetadata, CacheStats};
use crate::RragResult;
use std::collections::{HashMap, VecDeque};
use std::hash::Hash;
use std::time::{Duration, SystemTime};
/// Size-bounded cache that evicts the key at the back of `access_order`.
///
/// Recency is refreshed by `put` only: `Cache::get` is implemented with
/// `&self` in this file, so a read cannot reorder `access_order`.
pub struct LRUCache<K, V>
where
K: Hash + Eq + Clone + Send + Sync + 'static,
V: Clone + Send + Sync + 'static,
{
/// Cached values, keyed by cache key.
storage: HashMap<K, CacheNode<V>>,
/// Keys from most recently written (front) to next eviction victim (back).
access_order: VecDeque<K>,
/// Entry count at which `put` evicts before inserting a new key.
max_size: usize,
/// Counters returned (cloned) by `stats()`; this file only updates `evictions`.
stats: CacheStats,
/// Marker for `K` and `V`; both type parameters also appear in the fields above.
_phantom: std::marker::PhantomData<(K, V)>,
}
/// Size-bounded cache that evicts a key from the lowest-frequency bucket.
///
/// A key's frequency is bumped by `update_frequency`, which this file calls
/// from `put` only; `Cache::get` takes `&self` and does not count reads.
pub struct LFUCache<K, V>
where
K: Hash + Eq + Clone + Send + Sync + 'static,
V: Clone + Send + Sync + 'static,
{
/// Cached values, keyed by cache key.
storage: HashMap<K, CacheNode<V>>,
/// Current frequency counter of every tracked key.
frequencies: HashMap<K, u64>,
/// Keys grouped by their current frequency; eviction pops from the end of a bucket.
frequency_buckets: HashMap<u64, Vec<K>>,
/// Frequency whose bucket `evict_lfu` looks at for a victim.
min_frequency: u64,
/// Entry count at which `put` evicts before inserting a new key.
max_size: usize,
/// Counters returned (cloned) by `stats()`; this file only updates `evictions`.
stats: CacheStats,
}
/// Cache whose entries carry an absolute expiry time.
///
/// Expired entries are hidden from `get`, `contains` and `size`, and are
/// physically dropped by `cleanup_expired`, which `put` triggers.
pub struct TTLCache<K, V>
where
K: Hash + Eq + Clone + Send + Sync + 'static,
V: Clone + Send + Sync + 'static,
{
/// Value together with the instant at which it stops being visible.
storage: HashMap<K, (V, SystemTime)>,
/// Lifetime added to the current time when `put` stores an entry.
default_ttl: Duration,
/// Minimum time between two sweeps of expired entries (`new` sets 60 seconds).
cleanup_interval: Duration,
/// When the last sweep ran (initialised to the construction time).
last_cleanup: SystemTime,
/// Counters returned (cloned) by `stats()`; sweeps add to `evictions`.
stats: CacheStats,
}
/// State container for an ARC-style cache.
///
/// NOTE(review): no `impl` for this type is visible in this file. The field
/// names match the T1/T2/B1/B2/p state of the Adaptive Replacement Cache
/// algorithm; the per-field notes below are that presumed reading — confirm
/// against the implementation.
pub struct ARCCache<K, V>
where
K: Hash + Eq + Clone + Send + Sync + 'static,
V: Clone + Send + Sync + 'static,
{
/// Presumably entries seen once recently (ARC "T1") — TODO confirm.
t1: HashMap<K, V>,
/// Presumably entries seen at least twice (ARC "T2") — TODO confirm.
t2: HashMap<K, V>,
/// Presumably ghost keys evicted from `t1`; only keys are kept (unit values).
b1: HashMap<K, ()>,
/// Presumably ghost keys evicted from `t2`; only keys are kept (unit values).
b2: HashMap<K, ()>,
/// Presumably the recency order of `t1` — TODO confirm which end is evicted.
t1_lru: VecDeque<K>,
/// Presumably the recency order of `t2` — TODO confirm which end is evicted.
t2_lru: VecDeque<K>,
/// Presumably the recency order of `b1`.
b1_lru: VecDeque<K>,
/// Presumably the recency order of `b2`.
b2_lru: VecDeque<K>,
/// Presumably ARC's adaptive target size for `t1` — TODO confirm units/range.
p: f32,
/// Presumably the capacity limit, as in the other caches in this file.
max_size: usize,
/// Statistics counters (same `CacheStats` type as the other caches).
stats: CacheStats,
}
/// State container for a cache that tracks embeddings and access patterns per key.
///
/// NOTE(review): no `impl` for this type is visible in this file, so how the
/// fields interact (eviction rule, grouping rule) cannot be stated from here;
/// the per-field notes below are inferred from names and types only.
pub struct SemanticAwareCache<K, V>
where
K: Hash + Eq + Clone + Send + Sync + 'static,
V: Clone + Send + Sync + 'static,
{
/// Cached values, keyed by cache key.
storage: HashMap<K, CacheNode<V>>,
/// Keys grouped under a `u64` id — presumably a similarity-cluster id; TODO confirm.
similarity_groups: HashMap<u64, Vec<K>>,
/// One `f32` vector per key — presumably the key's embedding; dimension not shown here.
embeddings: HashMap<K, Vec<f32>>,
/// Access history per key (see `AccessPattern`).
access_patterns: HashMap<K, AccessPattern>,
/// Presumably the capacity limit, as in the other caches in this file.
max_size: usize,
/// Presumably the minimum similarity for two keys to share a group — metric/range TODO confirm.
similarity_threshold: f32,
/// Statistics counters (same `CacheStats` type as the other caches).
stats: CacheStats,
}
/// A cached value together with its bookkeeping data.
#[derive(Debug, Clone)]
pub struct CacheNode<V> {
/// The cached value; `get` hands out clones of it.
pub value: V,
/// Entry metadata; the `put` implementations in this file create it with
/// `CacheEntryMetadata::new()`.
pub metadata: CacheEntryMetadata,
/// Size recorded for the entry. The `put` implementations in this file store
/// `std::mem::size_of::<V>()`, i.e. the inline size of `V`, which does not
/// include heap data the value owns.
pub size_bytes: usize,
}
/// Rolling access history for a single key.
#[derive(Debug, Clone)]
pub struct AccessPattern {
/// Total number of `record_access` calls.
pub count: u64,
/// Most recent access timestamps, oldest first; `record_access` keeps at most 10.
pub recent_accesses: VecDeque<SystemTime>,
/// Mean gap between consecutive timestamps in `recent_accesses`
/// (zero until at least two accesses were recorded).
pub avg_interval: Duration,
/// Trend classification; starts as `Unknown`.
pub trend: AccessTrend,
}
/// Coarse classification of how a key's access rate is developing.
///
/// Only `Unknown` (initial value) and `Stable` are assigned anywhere in this file.
#[derive(Debug, Clone, Copy)]
pub enum AccessTrend {
/// Presumably: accesses are becoming more frequent (never assigned in this file).
Increasing,
/// Presumably: accesses are becoming less frequent (never assigned in this file).
Decreasing,
/// Set by `AccessPattern::update_metrics` once at least four accesses are recorded.
Stable,
/// Initial state set by `AccessPattern::new`.
Unknown,
}
/// A key paired with its usage frequency and last access time.
///
/// Its `Ord` impl in this file compares `frequency` first, then `last_access`.
/// NOTE(review): nothing in the visible part of this file constructs this type.
#[derive(Debug, Clone, PartialEq, Eq)]
struct FrequencyEntry<K>
where
K: Ord,
{
/// The cache key this entry describes.
key: K,
/// How often the key was used.
frequency: u64,
/// When the key was last used.
last_access: SystemTime,
}
impl<K, V> LRUCache<K, V>
where
    K: Hash + Eq + Clone + Send + Sync + 'static,
    V: Clone + Send + Sync + 'static,
{
    /// Builds an empty cache limited to `max_size` entries.
    ///
    /// Both the map and the recency queue are preallocated for `max_size`
    /// elements.
    pub fn new(max_size: usize) -> Self {
        let storage = HashMap::with_capacity(max_size);
        let access_order = VecDeque::with_capacity(max_size);
        Self {
            storage,
            access_order,
            max_size,
            stats: CacheStats::default(),
            _phantom: std::marker::PhantomData,
        }
    }

    /// Marks `key` as the most recently used one.
    ///
    /// The first queued occurrence of `key` (if any) is dropped, then a clone
    /// of `key` is placed at the front. The lookup is a linear scan.
    fn update_lru(&mut self, key: &K) {
        let queued_at = self.access_order.iter().position(|queued| queued == key);
        if let Some(index) = queued_at {
            self.access_order.remove(index);
        }
        self.access_order.push_front(key.clone());
    }

    /// Drops the entry at the back of the recency queue.
    ///
    /// Returns the evicted key, or `None` when the queue is empty. Every
    /// successful call increments `stats.evictions`.
    fn evict_lru(&mut self) -> Option<K> {
        let victim = self.access_order.pop_back()?;
        self.storage.remove(&victim);
        self.stats.evictions += 1;
        Some(victim)
    }
}
impl<K, V> Cache<K, V> for LRUCache<K, V>
where
    K: Hash + Eq + Clone + Send + Sync + 'static,
    V: Clone + Send + Sync + 'static,
{
    /// Returns a clone of the value stored under `key`, or `None`.
    ///
    /// This method only has `&self`, so a read neither refreshes the key's
    /// recency nor updates any statistics.
    fn get(&self, key: &K) -> Option<V> {
        self.storage.get(key).map(|node| node.value.clone())
    }

    /// Stores `value` under `key` and marks the key as most recently used.
    ///
    /// Overwriting an existing key never evicts. Inserting a new key into a
    /// full cache first evicts the least recently written key. A cache built
    /// with `max_size == 0` stores nothing.
    ///
    /// Always returns `Ok(())`.
    fn put(&mut self, key: K, value: V) -> RragResult<()> {
        // A zero-capacity cache must stay empty; without this guard the
        // eviction below finds nothing to drop and the insert would push the
        // cache over its bound.
        if self.max_size == 0 {
            return Ok(());
        }

        let node = CacheNode {
            value,
            metadata: CacheEntryMetadata::new(),
            // Inline size of `V` only; heap data owned by the value is not counted.
            size_bytes: std::mem::size_of::<V>(),
        };

        let is_new_key = !self.storage.contains_key(&key);
        if is_new_key && self.storage.len() >= self.max_size {
            self.evict_lru();
        }

        self.update_lru(&key);
        self.storage.insert(key, node);
        Ok(())
    }

    /// Removes `key` and returns its value, or `None` if it was not cached.
    fn remove(&mut self, key: &K) -> Option<V> {
        let node = self.storage.remove(key)?;
        if let Some(index) = self.access_order.iter().position(|queued| queued == key) {
            self.access_order.remove(index);
        }
        Some(node.value)
    }

    /// Reports whether `key` is currently cached.
    fn contains(&self, key: &K) -> bool {
        self.storage.contains_key(key)
    }

    /// Drops every entry and resets the statistics to their defaults.
    fn clear(&mut self) {
        self.storage.clear();
        self.access_order.clear();
        self.stats = CacheStats::default();
    }

    /// Number of cached entries.
    fn size(&self) -> usize {
        self.storage.len()
    }

    /// Snapshot (clone) of the statistics counters.
    fn stats(&self) -> CacheStats {
        self.stats.clone()
    }
}
impl<K, V> LFUCache<K, V>
where
    K: Hash + Eq + Clone + Send + Sync + 'static,
    V: Clone + Send + Sync + 'static,
{
    /// Builds an empty cache limited to `max_size` entries.
    pub fn new(max_size: usize) -> Self {
        Self {
            storage: HashMap::with_capacity(max_size),
            frequencies: HashMap::with_capacity(max_size),
            frequency_buckets: HashMap::new(),
            min_frequency: 1,
            max_size,
            stats: CacheStats::default(),
        }
    }

    /// Raises the frequency of `key` by one and moves it to the matching bucket.
    ///
    /// A key that is not tracked yet enters at frequency 1.
    fn update_frequency(&mut self, key: &K) {
        let old_freq = self.frequencies.get(key).copied().unwrap_or(0);
        let new_freq = old_freq + 1;
        self.frequencies.insert(key.clone(), new_freq);

        if old_freq == 0 {
            // A brand-new key is by definition the least frequently used one.
            // Without this reset, an earlier overwrite that advanced
            // `min_frequency` would make the next eviction pick a hotter key.
            self.min_frequency = 1;
        } else {
            let emptied = match self.frequency_buckets.get_mut(&old_freq) {
                Some(bucket) => {
                    bucket.retain(|queued| queued != key);
                    bucket.is_empty()
                }
                None => false,
            };
            if emptied {
                // Drop empty buckets so the map does not grow with every
                // frequency value that was ever reached.
                self.frequency_buckets.remove(&old_freq);
                if old_freq == self.min_frequency {
                    // The key just moved to `new_freq`, so that bucket is populated.
                    self.min_frequency = new_freq;
                }
            }
        }

        self.frequency_buckets
            .entry(new_freq)
            .or_default()
            .push(key.clone());
    }

    /// Evicts one key with the lowest frequency and returns it.
    ///
    /// Among keys of equal frequency the one most recently added to the
    /// bucket is chosen. Returns `None` only when no tracked key is left.
    /// Every successful call increments `stats.evictions`.
    fn evict_lfu(&mut self) -> Option<K> {
        loop {
            // `min_frequency` is a hint; if its bucket is gone or empty, fall
            // back to the smallest populated bucket instead of giving up.
            let hint_is_live = self
                .frequency_buckets
                .get(&self.min_frequency)
                .map_or(false, |bucket| !bucket.is_empty());
            if !hint_is_live {
                self.min_frequency = self
                    .frequency_buckets
                    .iter()
                    .filter(|(_, bucket)| !bucket.is_empty())
                    .map(|(freq, _)| *freq)
                    .min()?;
            }

            let freq = self.min_frequency;
            let (candidate, emptied) = {
                let bucket = self.frequency_buckets.get_mut(&freq)?;
                let candidate = bucket.pop()?;
                (candidate, bucket.is_empty())
            };
            if emptied {
                self.frequency_buckets.remove(&freq);
            }

            // A bucket may still list a key that is no longer tracked at this
            // frequency; such leftovers are skipped and not counted.
            if self.frequencies.get(&candidate) != Some(&freq) {
                continue;
            }
            self.frequencies.remove(&candidate);
            if self.storage.remove(&candidate).is_some() {
                self.stats.evictions += 1;
                return Some(candidate);
            }
        }
    }
}
impl<K, V> Cache<K, V> for LFUCache<K, V>
where
    K: Hash + Eq + Clone + Send + Sync + 'static,
    V: Clone + Send + Sync + 'static,
{
    /// Returns a clone of the value stored under `key`, or `None`.
    ///
    /// This method only has `&self`, so a read does not raise the key's
    /// frequency; only `put` does.
    fn get(&self, key: &K) -> Option<V> {
        self.storage.get(key).map(|node| node.value.clone())
    }

    /// Stores `value` under `key` and raises the key's frequency by one.
    ///
    /// Overwriting an existing key never evicts. Inserting a new key into a
    /// full cache first evicts a least-frequently-used key. A cache built
    /// with `max_size == 0` stores nothing.
    ///
    /// Always returns `Ok(())`.
    fn put(&mut self, key: K, value: V) -> RragResult<()> {
        // A zero-capacity cache must stay empty rather than hold one entry.
        if self.max_size == 0 {
            return Ok(());
        }

        let node = CacheNode {
            value,
            metadata: CacheEntryMetadata::new(),
            // Inline size of `V` only; heap data owned by the value is not counted.
            size_bytes: std::mem::size_of::<V>(),
        };

        let is_new_key = !self.storage.contains_key(&key);
        if is_new_key && self.storage.len() >= self.max_size {
            self.evict_lfu();
        }

        self.update_frequency(&key);
        self.storage.insert(key, node);
        Ok(())
    }

    /// Removes `key` and returns its value, or `None` if it was not cached.
    ///
    /// The key is also purged from its frequency bucket, so a later eviction
    /// cannot pick a key that no longer exists.
    fn remove(&mut self, key: &K) -> Option<V> {
        let node = self.storage.remove(key)?;

        if let Some(freq) = self.frequencies.remove(key) {
            let emptied = match self.frequency_buckets.get_mut(&freq) {
                Some(bucket) => {
                    bucket.retain(|queued| queued != key);
                    bucket.is_empty()
                }
                None => false,
            };
            if emptied {
                self.frequency_buckets.remove(&freq);
            }
        }

        // A slot is free now, so the next eviction can only happen after a
        // new key has been inserted, and new keys always enter at frequency 1.
        self.min_frequency = 1;

        Some(node.value)
    }

    /// Reports whether `key` is currently cached.
    fn contains(&self, key: &K) -> bool {
        self.storage.contains_key(key)
    }

    /// Drops every entry and resets frequencies and statistics.
    fn clear(&mut self) {
        self.storage.clear();
        self.frequencies.clear();
        self.frequency_buckets.clear();
        self.min_frequency = 1;
        self.stats = CacheStats::default();
    }

    /// Number of cached entries.
    fn size(&self) -> usize {
        self.storage.len()
    }

    /// Snapshot (clone) of the statistics counters.
    fn stats(&self) -> CacheStats {
        self.stats.clone()
    }
}
impl<K, V> TTLCache<K, V>
where
    K: Hash + Eq + Clone + Send + Sync + 'static,
    V: Clone + Send + Sync + 'static,
{
    /// Builds an empty cache whose entries live for `default_ttl`.
    ///
    /// Sweeps of expired entries are spaced at least 60 seconds apart, and
    /// the construction time counts as the moment of the last sweep.
    pub fn new(default_ttl: Duration) -> Self {
        let sweep_spacing = Duration::from_secs(60);
        Self {
            storage: HashMap::new(),
            default_ttl,
            cleanup_interval: sweep_spacing,
            last_cleanup: SystemTime::now(),
            stats: CacheStats::default(),
        }
    }

    /// Drops every expired entry, at most once per `cleanup_interval`.
    ///
    /// Does nothing when the previous sweep is more recent than
    /// `cleanup_interval` (a clock that moved backwards counts as zero
    /// elapsed time). Each dropped entry adds one to `stats.evictions`.
    fn cleanup_expired(&mut self) {
        let now = SystemTime::now();
        let since_last_sweep = now.duration_since(self.last_cleanup).unwrap_or_default();
        if since_last_sweep >= self.cleanup_interval {
            let entries_before = self.storage.len();
            self.storage.retain(|_, entry| now < entry.1);
            let purged = entries_before - self.storage.len();
            self.stats.evictions += purged as u64;
            self.last_cleanup = now;
        }
    }
}
impl<K, V> Cache<K, V> for TTLCache<K, V>
where
    K: Hash + Eq + Clone + Send + Sync + 'static,
    V: Clone + Send + Sync + 'static,
{
    /// Returns a clone of the value stored under `key` if it has not expired.
    fn get(&self, key: &K) -> Option<V> {
        self.storage
            .get(key)
            .filter(|(_, expiry)| SystemTime::now() < *expiry)
            .map(|(value, _)| value.clone())
    }

    /// Stores `value` under `key`, expiring `default_ttl` from now.
    ///
    /// Afterwards a sweep of expired entries is attempted (it is rate limited
    /// by `cleanup_expired`). Always returns `Ok(())`.
    ///
    /// # Panics
    ///
    /// Panics if the current time plus `default_ttl` cannot be represented as
    /// a `SystemTime`.
    fn put(&mut self, key: K, value: V) -> RragResult<()> {
        let expiry = SystemTime::now() + self.default_ttl;
        self.storage.insert(key, (value, expiry));
        self.cleanup_expired();
        Ok(())
    }

    /// Removes `key` and returns its value if it had not expired.
    ///
    /// An expired entry is still dropped from storage, but `None` is returned
    /// for it, matching what `get`, `contains` and `size` report.
    fn remove(&mut self, key: &K) -> Option<V> {
        let (value, expiry) = self.storage.remove(key)?;
        if SystemTime::now() < expiry {
            Some(value)
        } else {
            None
        }
    }

    /// Reports whether `key` is cached and not expired.
    fn contains(&self, key: &K) -> bool {
        self.storage
            .get(key)
            .map_or(false, |(_, expiry)| SystemTime::now() < *expiry)
    }

    /// Drops every entry and resets the statistics to their defaults.
    fn clear(&mut self) {
        self.storage.clear();
        self.stats = CacheStats::default();
    }

    /// Number of entries that have not expired (a linear scan).
    fn size(&self) -> usize {
        let now = SystemTime::now();
        self.storage
            .values()
            .filter(|(_, expiry)| now < *expiry)
            .count()
    }

    /// Snapshot (clone) of the statistics counters.
    fn stats(&self) -> CacheStats {
        self.stats.clone()
    }
}
/// Delegates to [`Ord::cmp`], so the partial and total orders always agree.
impl<K> PartialOrd for FrequencyEntry<K>
where
    K: Ord,
{
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

/// Orders entries by `frequency`, then `last_access`, then `key`.
impl<K> Ord for FrequencyEntry<K>
where
    K: Ord,
{
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.frequency
            .cmp(&other.frequency)
            .then_with(|| self.last_access.cmp(&other.last_access))
            // The derived `Eq` also compares `key`; without this tie-break two
            // unequal entries could compare as `Equal`, which breaks the
            // `Ord` contract and lets ordered collections merge distinct keys.
            .then_with(|| self.key.cmp(&other.key))
    }
}
impl AccessPattern {
    /// How many of the most recent access timestamps are kept.
    const MAX_RECENT_ACCESSES: usize = 10;

    /// Creates an empty pattern: no accesses, zero interval, `Unknown` trend.
    pub fn new() -> Self {
        Self {
            count: 0,
            recent_accesses: VecDeque::new(),
            avg_interval: Duration::from_secs(0),
            trend: AccessTrend::Unknown,
        }
    }

    /// Records an access at the current system time and refreshes the metrics.
    ///
    /// Only the newest `MAX_RECENT_ACCESSES` timestamps are retained; `count`
    /// keeps growing regardless.
    pub fn record_access(&mut self) {
        let now = SystemTime::now();
        self.count += 1;
        self.recent_accesses.push_back(now);
        if self.recent_accesses.len() > Self::MAX_RECENT_ACCESSES {
            self.recent_accesses.pop_front();
        }
        self.update_metrics();
    }

    /// Recomputes `avg_interval` and `trend` from `recent_accesses`.
    ///
    /// With fewer than two timestamps nothing changes. Pairs whose later
    /// timestamp precedes the earlier one (clock moved backwards) are skipped.
    fn update_metrics(&mut self) {
        if self.recent_accesses.len() < 2 {
            return;
        }

        let mut total_interval = Duration::from_secs(0);
        let mut interval_count: u32 = 0;
        // Walk consecutive pairs through the deque's iterator. A `VecDeque` is
        // a ring buffer, so `as_slices().0` holds only part of the data once
        // the buffer has wrapped and would silently drop intervals.
        let consecutive = self
            .recent_accesses
            .iter()
            .zip(self.recent_accesses.iter().skip(1));
        for (earlier, later) in consecutive {
            if let Ok(interval) = later.duration_since(*earlier) {
                total_interval += interval;
                interval_count += 1;
            }
        }
        if interval_count > 0 {
            self.avg_interval = total_interval / interval_count;
        }

        // Trend detection is not implemented: any history of four or more
        // accesses is reported as `Stable`.
        if self.recent_accesses.len() >= 4 {
            self.trend = AccessTrend::Stable;
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A full three-slot LRU cache still holds three entries after one more insert.
    #[test]
    fn test_lru_cache() {
        let mut lru = LRUCache::new(3);
        for &(name, number) in [("a", 1), ("b", 2), ("c", 3)].iter() {
            lru.put(String::from(name), number).unwrap();
        }
        assert_eq!(lru.size(), 3);

        let key_a = String::from("a");
        assert_eq!(lru.get(&key_a), Some(1));

        lru.put(String::from("d"), 4).unwrap();
        assert_eq!(lru.size(), 3);
    }

    /// Inserting a third key into a two-slot LFU cache drops "b" and keeps "a".
    #[test]
    fn test_lfu_cache() {
        let mut lfu = LFUCache::new(2);
        let key_a = String::from("a");
        let key_b = String::from("b");
        let key_c = String::from("c");

        lfu.put(key_a.clone(), 1).unwrap();
        lfu.put(key_b.clone(), 2).unwrap();

        // Read "a" twice before the cache overflows.
        for _ in 0..2 {
            lfu.get(&key_a);
        }

        lfu.put(key_c.clone(), 3).unwrap();

        assert_eq!(lfu.get(&key_a), Some(1));
        assert_eq!(lfu.get(&key_b), None);
        assert_eq!(lfu.get(&key_c), Some(3));
    }

    /// An entry is readable right after insertion and gone once its TTL has passed.
    #[test]
    fn test_ttl_cache() {
        let ttl = Duration::from_millis(100);
        let mut timed = TTLCache::new(ttl);
        let key = String::from("key");

        timed.put(key.clone(), String::from("value")).unwrap();
        assert_eq!(timed.get(&key), Some(String::from("value")));

        // Wait past the 100 ms TTL.
        std::thread::sleep(Duration::from_millis(150));
        assert_eq!(timed.get(&key), None);
    }

    /// `record_access` increments the access counter by one per call.
    #[test]
    fn test_access_pattern() {
        let mut history = AccessPattern::new();
        assert_eq!(history.count, 0);

        for expected in 1..=2u64 {
            history.record_access();
            assert_eq!(history.count, expected);
        }
    }
}