use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use std::time::Instant;
use super::types::ContextFingerprint;
/// A single fixed-size page of `f32` values plus its bookkeeping flags.
#[derive(Debug, Clone)]
pub struct CachePage {
    /// Identifier given to the page when it was constructed.
    pub page_id: u64,
    /// Backing storage for the page contents.
    pub data: Vec<f32>,
    /// Number of `f32` slots the page was constructed with.
    pub page_size: usize,
    /// Whether the page is currently handed out.
    pub is_allocated: bool,
    /// Number of outstanding references to the page.
    pub ref_count: u32,
    /// Instant of the most recent [`CachePage::touch`] (or construction).
    pub last_accessed: Instant,
    /// Set by [`CachePage::write`], reset by [`CachePage::clear`].
    pub is_dirty: bool,
}

impl CachePage {
    /// Creates an unallocated, zero-filled page holding `page_size` values.
    #[must_use]
    pub fn new(page_id: u64, page_size: usize) -> Self {
        Self {
            page_id,
            data: vec![0.0; page_size],
            page_size,
            is_allocated: false,
            ref_count: 0,
            last_accessed: Instant::now(),
            is_dirty: false,
        }
    }

    /// Creates an allocated page that takes ownership of `data`.
    ///
    /// The page size is `data.len()` and the reference count starts at 1.
    #[must_use]
    pub fn with_data(page_id: u64, data: Vec<f32>) -> Self {
        let page_size = data.len();
        Self {
            page_id,
            data,
            page_size,
            is_allocated: true,
            ref_count: 1,
            last_accessed: Instant::now(),
            is_dirty: false,
        }
    }

    /// Records "now" as the last access time.
    pub fn touch(&mut self) {
        self.last_accessed = Instant::now();
    }

    /// Increments the reference count, saturating at `u32::MAX`.
    pub fn add_ref(&mut self) {
        self.ref_count = self.ref_count.saturating_add(1);
    }

    /// Decrements the reference count (saturating at zero) and returns `true`
    /// when no references remain.
    pub fn release(&mut self) -> bool {
        self.ref_count = self.ref_count.saturating_sub(1);
        self.ref_count == 0
    }

    /// Size of the stored values in bytes.
    #[must_use]
    pub fn memory_size(&self) -> usize {
        self.data.len() * std::mem::size_of::<f32>()
    }

    /// Zeroes the contents and resets the page to the unallocated state.
    pub fn clear(&mut self) {
        self.data.fill(0.0);
        self.is_allocated = false;
        self.ref_count = 0;
        self.is_dirty = false;
    }

    /// Copies `data` into the page starting at `offset`, marks the page dirty
    /// and refreshes its access time.
    ///
    /// # Panics
    ///
    /// Panics if `offset + data.len()` overflows or exceeds `page_size`.
    pub fn write(&mut self, offset: usize, data: &[f32]) {
        // `checked_add` keeps the bounds check meaningful even when the sum
        // would wrap around in a release build.
        let end = offset
            .checked_add(data.len())
            .filter(|&end| end <= self.page_size);
        assert!(
            end.is_some(),
            "Data would overflow page: offset={offset}, data_len={}, page_size={}",
            data.len(),
            self.page_size
        );
        self.data[offset..offset + data.len()].copy_from_slice(data);
        self.is_dirty = true;
        self.touch();
    }

    /// Returns up to `len` values starting at `offset`.
    ///
    /// The requested range is clamped to the stored data, so a range that
    /// runs past the end yields a shorter slice and an `offset` beyond the
    /// end yields an empty slice. This method never panics.
    #[must_use]
    pub fn read(&self, offset: usize, len: usize) -> &[f32] {
        let stored = self.data.len();
        // Clamp both bounds: clamping only the end would leave `start > end`
        // for an out-of-range offset and make the slice expression panic.
        let start = offset.min(stored);
        let end = offset.saturating_add(len).min(stored);
        &self.data[start..end]
    }
}
/// Bounded pool of equally sized [`CachePage`]s with a free list for reuse.
#[derive(Debug)]
pub struct PageTable {
    /// Every page the table owns, allocated or free, keyed by page id.
    pages: HashMap<u64, CachePage>,
    /// Ids of owned pages that are currently free and can be handed out again.
    free_pages: Vec<u64>,
    /// Number of `f32` values per page.
    page_size: usize,
    /// Upper bound on the number of pages the table may own.
    max_pages: usize,
    /// Id that will be given to the next newly created page.
    next_page_id: u64,
    /// Number of pages currently allocated.
    allocated_count: usize,
}

impl PageTable {
    /// Creates an empty table for pages of `page_size` values, owning at most
    /// `max_pages` pages. Storage for `max_pages` bookkeeping slots is
    /// reserved up front.
    #[must_use]
    pub fn new(page_size: usize, max_pages: usize) -> Self {
        Self {
            pages: HashMap::with_capacity(max_pages),
            free_pages: Vec::with_capacity(max_pages),
            page_size,
            max_pages,
            next_page_id: 0,
            allocated_count: 0,
        }
    }

    /// Number of `f32` values per page.
    #[must_use]
    pub fn page_size(&self) -> usize {
        self.page_size
    }

    /// Maximum number of pages the table may own.
    #[must_use]
    pub fn max_pages(&self) -> usize {
        self.max_pages
    }

    /// Number of pages currently allocated.
    #[must_use]
    pub fn allocated_count(&self) -> usize {
        self.allocated_count
    }

    /// Number of pages that can still be handed out: entries on the free list
    /// plus pages that have not been created yet.
    #[must_use]
    pub fn free_count(&self) -> usize {
        self.free_pages.len() + (self.max_pages.saturating_sub(self.pages.len()))
    }

    /// Total size in bytes of the data held by allocated pages.
    #[must_use]
    pub fn memory_usage(&self) -> usize {
        self.pages
            .values()
            .filter(|p| p.is_allocated)
            .map(CachePage::memory_size)
            .sum()
    }

    /// Pops an id from the free list and returns it with its page.
    ///
    /// Returns `None` when the free list is empty or the popped id no longer
    /// refers to an owned page (the stale id is discarded).
    fn take_free_page(&mut self) -> Option<(u64, &mut CachePage)> {
        let page_id = self.free_pages.pop()?;
        self.pages.get_mut(&page_id).map(|page| (page_id, page))
    }

    /// Allocates a page, reusing a free one when possible.
    ///
    /// Returns the page id, or `None` when the table already owns
    /// `max_pages` pages and none of them is free.
    pub fn allocate_page(&mut self) -> Option<u64> {
        if let Some((page_id, page)) = self.take_free_page() {
            page.is_allocated = true;
            page.ref_count = 1;
            page.touch();
            self.allocated_count += 1;
            return Some(page_id);
        }
        if self.pages.len() >= self.max_pages {
            return None;
        }
        let page_id = self.next_page_id;
        self.next_page_id += 1;
        let mut page = CachePage::new(page_id, self.page_size);
        page.is_allocated = true;
        page.ref_count = 1;
        self.pages.insert(page_id, page);
        self.allocated_count += 1;
        Some(page_id)
    }

    /// Allocates a page whose contents are `data`.
    ///
    /// Returns `None` when `data.len()` differs from the table's page size or
    /// when no page can be reused or created.
    pub fn allocate_page_with_data(&mut self, data: Vec<f32>) -> Option<u64> {
        if data.len() != self.page_size {
            return None;
        }
        if self.pages.len() >= self.max_pages && self.free_pages.is_empty() {
            return None;
        }
        if let Some((page_id, page)) = self.take_free_page() {
            page.data = data;
            page.is_allocated = true;
            page.ref_count = 1;
            // NOTE(review): a recycled page is marked dirty here, whereas a
            // freshly created one (`CachePage::with_data`) starts clean.
            // Confirm which state callers expect before unifying.
            page.is_dirty = true;
            page.touch();
            self.allocated_count += 1;
            return Some(page_id);
        }
        let page_id = self.next_page_id;
        self.next_page_id += 1;
        self.pages
            .insert(page_id, CachePage::with_data(page_id, data));
        self.allocated_count += 1;
        Some(page_id)
    }

    /// Drops one reference to an allocated page; when it was the last one the
    /// page is cleared and put on the free list. Unknown or already free ids
    /// are ignored.
    pub fn free_page(&mut self, page_id: u64) {
        let Some(page) = self
            .pages
            .get_mut(&page_id)
            .filter(|page| page.is_allocated)
        else {
            return;
        };
        if page.release() {
            page.clear();
            self.free_pages.push(page_id);
            self.allocated_count = self.allocated_count.saturating_sub(1);
        }
    }

    /// Clears an allocated page and puts it on the free list regardless of
    /// its reference count. Unknown or already free ids are ignored.
    pub fn force_free_page(&mut self, page_id: u64) {
        let Some(page) = self
            .pages
            .get_mut(&page_id)
            .filter(|page| page.is_allocated)
        else {
            return;
        };
        page.clear();
        self.free_pages.push(page_id);
        self.allocated_count = self.allocated_count.saturating_sub(1);
    }

    /// Returns the page with `page_id` if it exists and is allocated.
    #[must_use]
    pub fn get_page(&self, page_id: u64) -> Option<&CachePage> {
        self.pages.get(&page_id).filter(|p| p.is_allocated)
    }

    /// Mutable variant of [`PageTable::get_page`].
    pub fn get_page_mut(&mut self, page_id: u64) -> Option<&mut CachePage> {
        self.pages.get_mut(&page_id).filter(|p| p.is_allocated)
    }

    /// Drops every unallocated page with no references and returns how many
    /// pages were removed.
    ///
    /// Both the page map and the free list are filtered in a single pass
    /// each, so the cost is linear in the number of pages.
    pub fn defragment(&mut self) -> usize {
        let before = self.pages.len();
        self.pages
            .retain(|_, page| page.is_allocated || page.ref_count != 0);
        // Keep only free-list ids that still refer to an owned page.
        let pages = &self.pages;
        self.free_pages.retain(|page_id| pages.contains_key(page_id));
        before - self.pages.len()
    }

    /// Force-frees the least recently accessed allocated page that has at
    /// most one reference and returns its id, or `None` when no page
    /// qualifies.
    pub fn evict_lru(&mut self) -> Option<u64> {
        let page_id = self
            .pages
            .iter()
            .filter(|(_, page)| page.is_allocated && page.ref_count <= 1)
            .min_by_key(|(_, page)| page.last_accessed)
            .map(|(id, _)| *id)?;
        self.force_free_page(page_id);
        Some(page_id)
    }

    /// Ids of all allocated pages, in unspecified order.
    #[must_use]
    pub fn allocated_page_ids(&self) -> Vec<u64> {
        self.pages
            .iter()
            .filter(|(_, page)| page.is_allocated)
            .map(|(id, _)| *id)
            .collect()
    }

    /// Adds a reference to an allocated page; unknown or free ids are ignored.
    ///
    /// Free pages are skipped on purpose: a reference taken on a free page
    /// could never be released through [`PageTable::free_page`] and would
    /// keep [`PageTable::defragment`] from reclaiming the page.
    pub fn add_page_ref(&mut self, page_id: u64) {
        if let Some(page) = self.get_page_mut(page_id) {
            page.add_ref();
        }
    }

    /// Clears every page and places all of them on the free list.
    pub fn clear(&mut self) {
        for page in self.pages.values_mut() {
            page.clear();
        }
        self.free_pages.clear();
        self.free_pages.extend(self.pages.keys().copied());
        self.allocated_count = 0;
    }
}
/// Bookkeeping record tying a context fingerprint to the pages that hold its
/// data, together with simple access statistics.
#[derive(Debug, Clone)]
pub struct PagedKVEntry {
    /// Fingerprint this entry was stored under.
    pub fingerprint: ContextFingerprint,
    /// Ids of the pages holding the entry's data, in order.
    pub page_ids: Vec<u64>,
    /// Logical length recorded for the entry by its creator.
    pub total_tokens: usize,
    /// Instant at which the entry was constructed.
    pub created_at: Instant,
    /// Instant of the most recent recorded access (construction time initially).
    pub last_accessed: Instant,
    /// Number of accesses recorded via [`PagedKVEntry::record_access`].
    pub access_count: u64,
}

impl PagedKVEntry {
    /// Builds an entry stamped with the current time and zero recorded accesses.
    #[must_use]
    pub fn new(fingerprint: ContextFingerprint, page_ids: Vec<u64>, total_tokens: usize) -> Self {
        // One timestamp is shared so creation and last-access start out equal.
        let created_at = Instant::now();
        Self {
            fingerprint,
            page_ids,
            total_tokens,
            created_at,
            last_accessed: created_at,
            access_count: 0,
        }
    }

    /// Notes an access: refreshes the last-access time and bumps the counter.
    pub fn record_access(&mut self) {
        let now = Instant::now();
        self.last_accessed = now;
        self.access_count += 1;
    }

    /// Number of pages referenced by this entry.
    #[must_use]
    pub fn page_count(&self) -> usize {
        let Self { page_ids, .. } = self;
        page_ids.len()
    }

    /// Time elapsed since the entry was created.
    #[must_use]
    pub fn age(&self) -> std::time::Duration {
        Instant::now().duration_since(self.created_at)
    }

    /// Time elapsed since the last recorded access.
    #[must_use]
    pub fn time_since_access(&self) -> std::time::Duration {
        Instant::now().duration_since(self.last_accessed)
    }
}
/// Cache that stores `f32` buffers split across fixed-size pages.
///
/// Entries and the page table live behind separate `RwLock`s. Whenever both
/// are needed they are acquired in the order entries → page table, so the
/// methods cannot deadlock against each other.
#[derive(Debug)]
pub struct PagedCache {
    /// Pool of pages backing all entries.
    page_table: Arc<RwLock<PageTable>>,
    /// Stored entries keyed by `fingerprint.hash`.
    entries: Arc<RwLock<HashMap<u64, PagedKVEntry>>>,
    /// Number of `f32` values per page.
    page_size: usize,
}

impl PagedCache {
    /// Creates an empty cache with at most `max_pages` pages of `page_size`
    /// values each.
    #[must_use]
    pub fn new(page_size: usize, max_pages: usize) -> Self {
        Self {
            page_table: Arc::new(RwLock::new(PageTable::new(page_size, max_pages))),
            entries: Arc::new(RwLock::new(HashMap::new())),
            page_size,
        }
    }

    /// Number of `f32` values per page.
    #[must_use]
    pub fn page_size(&self) -> usize {
        self.page_size
    }

    /// Stores `data` under `fingerprint`, replacing any entry with the same
    /// hash.
    ///
    /// The data is split into page-sized chunks; the last page is
    /// zero-padded. Returns `None` without changing the cache when the page
    /// size is zero or when not enough pages are available; pages allocated
    /// before the failure are returned to the page table.
    ///
    /// # Panics
    ///
    /// Panics if either internal lock is poisoned.
    #[must_use]
    pub fn put(&self, fingerprint: &ContextFingerprint, data: &[f32]) -> Option<()> {
        if self.page_size == 0 {
            // Zero-sized pages cannot hold anything, and chunking by zero
            // would panic.
            return None;
        }
        // Hold both locks for the whole operation so a concurrent `clear`
        // cannot recycle the new pages before the entry is published.
        let mut entries = self.entries.write().expect("lock poisoned");
        let mut page_table = self.page_table.write().expect("lock poisoned");

        let mut page_ids = Vec::with_capacity(data.len().div_ceil(self.page_size));
        for chunk in data.chunks(self.page_size) {
            let mut page_data = vec![0.0; self.page_size];
            page_data[..chunk.len()].copy_from_slice(chunk);
            match page_table.allocate_page_with_data(page_data) {
                Some(page_id) => page_ids.push(page_id),
                None => {
                    // Roll back so a failed `put` does not leak pages.
                    for page_id in page_ids {
                        page_table.free_page(page_id);
                    }
                    return None;
                }
            }
        }

        let entry = PagedKVEntry::new(fingerprint.clone(), page_ids, data.len());
        if let Some(old_entry) = entries.insert(fingerprint.hash, entry) {
            for page_id in old_entry.page_ids {
                page_table.free_page(page_id);
            }
        }
        Some(())
    }

    /// Returns a copy of the data stored under `fingerprint` and records the
    /// access on the entry.
    ///
    /// Returns `None` when there is no such entry or when one of its pages is
    /// no longer allocated.
    ///
    /// # Panics
    ///
    /// Panics if either internal lock is poisoned.
    #[must_use]
    pub fn get(&self, fingerprint: &ContextFingerprint) -> Option<Vec<f32>> {
        let mut entries = self.entries.write().expect("lock poisoned");
        let entry = entries.get_mut(&fingerprint.hash)?;
        entry.record_access();
        let page_table = self.page_table.read().expect("lock poisoned");
        let mut result = Vec::with_capacity(entry.total_tokens);
        // Copy only the values that belong to the entry so the padding of
        // the final page is never materialised.
        let mut remaining = entry.total_tokens;
        for &page_id in &entry.page_ids {
            let page = page_table.get_page(page_id)?;
            let take = remaining.min(page.data.len());
            result.extend_from_slice(&page.data[..take]);
            remaining -= take;
        }
        Some(result)
    }

    /// Removes the entry stored under `fingerprint`, releases its pages and
    /// returns the removed entry.
    ///
    /// # Panics
    ///
    /// Panics if either internal lock is poisoned.
    #[must_use]
    pub fn remove(&self, fingerprint: &ContextFingerprint) -> Option<PagedKVEntry> {
        let mut entries = self.entries.write().expect("lock poisoned");
        let entry = entries.remove(&fingerprint.hash)?;
        let mut page_table = self.page_table.write().expect("lock poisoned");
        for page_id in &entry.page_ids {
            page_table.free_page(*page_id);
        }
        Some(entry)
    }

    /// Whether an entry is stored under `fingerprint`'s hash.
    ///
    /// # Panics
    ///
    /// Panics if the entries lock is poisoned.
    #[must_use]
    pub fn contains(&self, fingerprint: &ContextFingerprint) -> bool {
        let entries = self.entries.read().expect("lock poisoned");
        entries.contains_key(&fingerprint.hash)
    }

    /// Number of stored entries.
    ///
    /// # Panics
    ///
    /// Panics if the entries lock is poisoned.
    #[must_use]
    pub fn entry_count(&self) -> usize {
        let entries = self.entries.read().expect("lock poisoned");
        entries.len()
    }

    /// Bytes held by allocated pages, as reported by the page table.
    ///
    /// # Panics
    ///
    /// Panics if the page-table lock is poisoned.
    #[must_use]
    pub fn memory_usage(&self) -> usize {
        let page_table = self.page_table.read().expect("lock poisoned");
        page_table.memory_usage()
    }

    /// Removes every entry and resets the page table.
    ///
    /// # Panics
    ///
    /// Panics if either internal lock is poisoned.
    pub fn clear(&self) {
        let mut entries = self.entries.write().expect("lock poisoned");
        let mut page_table = self.page_table.write().expect("lock poisoned");
        entries.clear();
        page_table.clear();
    }

    /// Asks the page table to drop its unused pages and returns how many were
    /// removed.
    ///
    /// # Panics
    ///
    /// Panics if the page-table lock is poisoned.
    #[must_use]
    pub fn defragment(&self) -> usize {
        let mut page_table = self.page_table.write().expect("lock poisoned");
        page_table.defragment()
    }

    /// Evicts up to `count` entries, least recently accessed first, releasing
    /// their pages. Returns the number of entries evicted.
    ///
    /// # Panics
    ///
    /// Panics if either internal lock is poisoned.
    #[must_use]
    pub fn evict_lru_entries(&self, count: usize) -> usize {
        let mut entries = self.entries.write().expect("lock poisoned");
        let mut entry_times: Vec<_> = entries
            .iter()
            .map(|(hash, entry)| (*hash, entry.last_accessed))
            .collect();
        entry_times.sort_by_key(|&(_, last_accessed)| last_accessed);
        let mut page_table = self.page_table.write().expect("lock poisoned");
        let mut evicted = 0;
        for (hash, _) in entry_times.into_iter().take(count) {
            if let Some(entry) = entries.remove(&hash) {
                for page_id in entry.page_ids {
                    page_table.free_page(page_id);
                }
                evicted += 1;
            }
        }
        evicted
    }
}
impl Clone for PagedCache {
    /// Produces another handle to the same page table and entry map; only the
    /// reference counts of the shared `Arc`s change, no cached data is copied.
    fn clone(&self) -> Self {
        let Self {
            page_table,
            entries,
            page_size,
        } = self;
        Self {
            page_table: Arc::clone(page_table),
            entries: Arc::clone(entries),
            page_size: *page_size,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // ---- CachePage ------------------------------------------------------

    /// A fresh page is unallocated and carries the requested id and size.
    #[test]
    fn test_cache_page_new() {
        let page = CachePage::new(1, 256);
        assert_eq!(page.page_id, 1);
        assert_eq!(page.page_size, 256);
        assert_eq!(page.ref_count, 0);
        assert!(!page.is_allocated);
    }

    /// A page built from data owns it and starts allocated with one reference.
    #[test]
    fn test_cache_page_with_data() {
        let expected = vec![1.0, 2.0, 3.0];
        let page = CachePage::with_data(1, expected.clone());
        assert_eq!(page.page_id, 1);
        assert_eq!(page.data, expected);
        assert_eq!(page.ref_count, 1);
        assert!(page.is_allocated);
    }

    /// Values written at an offset can be read back and mark the page dirty.
    #[test]
    fn test_cache_page_write_read() {
        let mut page = CachePage::new(1, 10);
        let payload = [1.0, 2.0, 3.0];
        page.write(2, &payload);
        assert_eq!(page.read(2, 3), &payload);
        assert!(page.is_dirty);
    }

    /// `release` reports `true` only once the last reference is gone.
    #[test]
    fn test_cache_page_ref_counting() {
        let mut page = CachePage::new(1, 256);
        for _ in 0..2 {
            page.add_ref();
        }
        assert_eq!(page.ref_count, 2);

        let freed_after_first = page.release();
        assert!(!freed_after_first);
        assert_eq!(page.ref_count, 1);

        let freed_after_second = page.release();
        assert!(freed_after_second);
        assert_eq!(page.ref_count, 0);
    }

    /// `clear` zeroes the data and resets every flag.
    #[test]
    fn test_cache_page_clear() {
        let mut page = CachePage::with_data(1, vec![1.0, 2.0, 3.0]);
        page.is_dirty = true;

        page.clear();

        assert!(!page.is_allocated);
        assert_eq!(page.ref_count, 0);
        assert!(!page.is_dirty);
        assert_eq!(page.data, vec![0.0, 0.0, 0.0]);
    }

    // ---- PageTable ------------------------------------------------------

    /// A new table reports its configuration and no allocations.
    #[test]
    fn test_page_table_new() {
        let table = PageTable::new(256, 100);
        assert_eq!(table.page_size(), 256);
        assert_eq!(table.max_pages(), 100);
        assert_eq!(table.allocated_count(), 0);
    }

    /// Allocating yields a retrievable, allocated page.
    #[test]
    fn test_page_table_allocate() {
        let mut table = PageTable::new(256, 10);
        let page_id = table
            .allocate_page()
            .expect("allocation should succeed");
        assert_eq!(table.allocated_count(), 1);

        let page = table
            .get_page(page_id)
            .expect("allocated page should be retrievable");
        assert!(page.is_allocated);
    }

    /// A page allocated with data holds exactly that data.
    #[test]
    fn test_page_table_allocate_with_data() {
        let mut table = PageTable::new(4, 10);
        let payload = vec![1.0, 2.0, 3.0, 4.0];
        let page_id = table
            .allocate_page_with_data(payload.clone())
            .expect("allocation should succeed");
        let page = table.get_page(page_id).unwrap();
        assert_eq!(page.data, payload);
    }

    /// Freeing the only reference returns the page to the pool.
    #[test]
    fn test_page_table_free() {
        let mut table = PageTable::new(256, 10);
        let page_id = table.allocate_page().unwrap();
        assert_eq!(table.allocated_count(), 1);

        table.free_page(page_id);

        assert_eq!(table.allocated_count(), 0);
        assert_eq!(table.free_count(), 10);
    }

    /// A freed page id is handed out again by the next allocation.
    #[test]
    fn test_page_table_reuse_freed_pages() {
        let mut table = PageTable::new(256, 10);
        let first = table.allocate_page().unwrap();
        table.free_page(first);
        let second = table.allocate_page().unwrap();
        assert_eq!(first, second);
    }

    /// Allocation fails once `max_pages` pages are in use.
    #[test]
    fn test_page_table_max_pages_limit() {
        let mut table = PageTable::new(256, 2);
        let _first = table.allocate_page().unwrap();
        let _second = table.allocate_page().unwrap();
        assert!(table.allocate_page().is_none());
    }

    /// Defragmenting drops exactly the pages that were freed.
    #[test]
    fn test_page_table_defragment() {
        let mut table = PageTable::new(256, 10);
        let first = table.allocate_page().unwrap();
        let _kept = table.allocate_page().unwrap();
        let third = table.allocate_page().unwrap();
        for page_id in [first, third] {
            table.free_page(page_id);
        }
        assert_eq!(table.defragment(), 2);
    }

    /// Evicting frees one of the two allocated pages.
    #[test]
    fn test_page_table_evict_lru() {
        let mut table = PageTable::new(256, 10);
        let _older = table.allocate_page().unwrap();
        std::thread::sleep(std::time::Duration::from_millis(1));
        let _newer = table.allocate_page().unwrap();

        let evicted = table.evict_lru();

        assert!(evicted.is_some());
        assert_eq!(table.allocated_count(), 1);
    }

    /// Clearing the table leaves nothing allocated.
    #[test]
    fn test_page_table_clear() {
        let mut table = PageTable::new(256, 10);
        for _ in 0..2 {
            table.allocate_page().unwrap();
        }
        table.clear();
        assert_eq!(table.allocated_count(), 0);
    }

    // ---- PagedKVEntry ---------------------------------------------------

    /// The constructor stores its arguments and starts with no accesses.
    #[test]
    fn test_paged_kv_entry_new() {
        let fp = ContextFingerprint::new(123, 100, "test");
        let entry = PagedKVEntry::new(fp.clone(), vec![1, 2, 3], 300);
        assert_eq!(entry.fingerprint, fp);
        assert_eq!(entry.page_ids, vec![1, 2, 3]);
        assert_eq!(entry.total_tokens, 300);
        assert_eq!(entry.access_count, 0);
    }

    /// Recording an access bumps the counter and advances the timestamp.
    #[test]
    fn test_paged_kv_entry_record_access() {
        let fp = ContextFingerprint::new(123, 100, "test");
        let mut entry = PagedKVEntry::new(fp, vec![1], 100);
        let before = entry.last_accessed;
        std::thread::sleep(std::time::Duration::from_millis(1));

        entry.record_access();

        assert_eq!(entry.access_count, 1);
        assert!(entry.last_accessed > before);
    }

    // ---- PagedCache -----------------------------------------------------

    /// A new cache reports its page size and holds no entries.
    #[test]
    fn test_paged_cache_new() {
        let cache = PagedCache::new(256, 100);
        assert_eq!(cache.page_size(), 256);
        assert_eq!(cache.entry_count(), 0);
    }

    /// Data that does not fill its last page round-trips unchanged.
    #[test]
    fn test_paged_cache_put_get() {
        let cache = PagedCache::new(4, 100);
        let fp = ContextFingerprint::new(123, 10, "test");
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];

        assert!(cache.put(&fp, &data).is_some());

        let retrieved = cache.get(&fp).expect("stored entry should be retrievable");
        assert_eq!(retrieved, data);
    }

    /// `contains` flips to `true` once the fingerprint has been stored.
    #[test]
    fn test_paged_cache_contains() {
        let cache = PagedCache::new(256, 100);
        let fp = ContextFingerprint::new(123, 10, "test");
        let data = vec![1.0, 2.0];
        assert!(!cache.contains(&fp));
        let _ = cache.put(&fp, &data);
        assert!(cache.contains(&fp));
    }

    /// Removing an entry returns it and makes it unreachable.
    #[test]
    fn test_paged_cache_remove() {
        let cache = PagedCache::new(256, 100);
        let fp = ContextFingerprint::new(123, 10, "test");
        let data = vec![1.0, 2.0];
        let _ = cache.put(&fp, &data);
        assert!(cache.contains(&fp));

        assert!(cache.remove(&fp).is_some());

        assert!(!cache.contains(&fp));
    }

    /// Clearing drops every stored entry.
    #[test]
    #[allow(clippy::cast_sign_loss, clippy::cast_precision_loss)]
    fn test_paged_cache_clear() {
        let cache = PagedCache::new(256, 100);
        for i in 0..5_i32 {
            let fp = ContextFingerprint::new(i as u64, 10, format!("test{i}"));
            let data = [i as f32; 10];
            let _ = cache.put(&fp, &data);
        }
        assert_eq!(cache.entry_count(), 5);

        cache.clear();

        assert_eq!(cache.entry_count(), 0);
    }

    /// Evicting two of five entries leaves three behind.
    #[test]
    #[allow(clippy::cast_sign_loss, clippy::cast_precision_loss)]
    fn test_paged_cache_evict_lru() {
        let cache = PagedCache::new(4, 100);
        for i in 0..5_i32 {
            let fp = ContextFingerprint::new(i as u64, 10, format!("test{i}"));
            let data = [i as f32; 4];
            let _ = cache.put(&fp, &data);
            // Space the insertions out so the entries get distinct timestamps.
            std::thread::sleep(std::time::Duration::from_millis(1));
        }

        let evicted = cache.evict_lru_entries(2);

        assert_eq!(evicted, 2);
        assert_eq!(cache.entry_count(), 3);
    }

    /// Storing twice under one fingerprint keeps only the newer data.
    #[test]
    fn test_paged_cache_replace_existing() {
        let cache = PagedCache::new(4, 100);
        let fp = ContextFingerprint::new(123, 10, "test");
        let _ = cache.put(&fp, &[1.0, 2.0, 3.0, 4.0]);
        let _ = cache.put(&fp, &[5.0, 6.0, 7.0, 8.0]);

        let retrieved = cache.get(&fp).expect("should retrieve");

        assert_eq!(retrieved, vec![5.0, 6.0, 7.0, 8.0]);
        assert_eq!(cache.entry_count(), 1);
    }

    /// A cloned handle observes entries stored through the original.
    #[test]
    fn test_paged_cache_clone_shares_state() {
        let original = PagedCache::new(256, 100);
        let handle = original.clone();
        let fp = ContextFingerprint::new(123, 10, "test");

        let _ = original.put(&fp, &[1.0, 2.0]);

        assert!(handle.contains(&fp));
        assert_eq!(original.entry_count(), handle.entry_count());
    }

    /// Data spanning several pages is reassembled in order.
    #[test]
    #[allow(clippy::cast_precision_loss)]
    fn test_paged_cache_multi_page_entry() {
        let cache = PagedCache::new(4, 100);
        let fp = ContextFingerprint::new(123, 10, "test");
        let data: Vec<f32> = (0..10).map(|i| i as f32).collect();
        let _ = cache.put(&fp, &data);

        let retrieved = cache.get(&fp).expect("get failed");

        assert_eq!(retrieved, data);
    }
}