use crate::error::{Result, ZiporaError};
use crate::memory::SecureMemoryPool;
use std::collections::HashMap;
use std::sync::Arc;
/// Strategy controlling which cached states are evicted when the cache is full.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CacheStrategy {
    /// Evict states with the lowest ids first (oldest-allocated first).
    BreadthFirst,
    /// Evict states with the highest ids first (newest-allocated first).
    DepthFirst,
    /// Evict non-terminal states before terminal ones.
    CacheFriendly,
}
/// Configuration for the FSA state cache.
#[derive(Debug, Clone)]
pub struct FsaCacheConfig {
    /// Maximum number of states held before eviction kicks in.
    pub max_states: usize,
    /// Eviction strategy applied when the cache is full.
    pub strategy: CacheStrategy,
    /// Zero-path compression flag (not consulted in this view — TODO confirm usage).
    pub compressed_paths: bool,
    /// Request huge pages for backing memory (flag only in this view — TODO confirm usage).
    pub use_hugepages: bool,
    /// Memory budget in bytes; 0 disables the secure memory pool.
    pub max_memory_bytes: usize,
}
impl Default for FsaCacheConfig {
fn default() -> Self {
Self {
max_states: 1_000_000,
strategy: CacheStrategy::CacheFriendly,
compressed_paths: true,
use_hugepages: false,
max_memory_bytes: 256 * 1024 * 1024, }
}
}
impl FsaCacheConfig {
    /// Preset for small workloads: 10k states, 4 MiB budget.
    pub fn small() -> Self {
        Self {
            max_states: 10_000,
            max_memory_bytes: 4 << 20,
            ..Self::default()
        }
    }

    /// Preset for large workloads: 10M states, 1 GiB budget, huge pages enabled.
    pub fn large() -> Self {
        Self {
            max_states: 10_000_000,
            use_hugepages: true,
            max_memory_bytes: 1 << 30,
            ..Self::default()
        }
    }

    /// Preset minimizing memory: 100k states, 16 MiB, depth-first eviction.
    pub fn memory_efficient() -> Self {
        Self {
            max_states: 100_000,
            strategy: CacheStrategy::DepthFirst,
            compressed_paths: true,
            max_memory_bytes: 16 << 20,
            ..Self::default()
        }
    }
}
/// Packed FSA state, 8 bytes total.
///
/// `parent_and_flags` layout: bit 31 = terminal flag, bit 30 = free flag,
/// bits 0..=23 = parent state id (bits 24..=29 unused).
#[derive(Debug, Clone, Copy)]
#[repr(C)]
pub struct CachedState {
    /// Base index of this state's child array.
    pub child_base: u32,
    /// Parent id (low 24 bits) plus terminal/free flags (high bits).
    pub parent_and_flags: u32,
}

impl CachedState {
    /// Packs `parent` (truncated to 24 bits) with the terminal/free flag bits.
    pub fn new(child_base: u32, parent: u32, is_terminal: bool, is_free: bool) -> Self {
        let mut packed = parent & 0x00FF_FFFF;
        if is_terminal {
            packed |= 0x8000_0000;
        }
        if is_free {
            packed |= 0x4000_0000;
        }
        CachedState {
            child_base,
            parent_and_flags: packed,
        }
    }

    /// Parent state id (24-bit).
    pub fn parent(&self) -> u32 {
        self.parent_and_flags & 0x00FF_FFFF
    }

    /// True when bit 31 (terminal flag) is set.
    #[inline]
    pub fn is_terminal(&self) -> bool {
        self.parent_and_flags & 0x8000_0000 != 0
    }

    /// True when bit 30 (free flag) is set.
    #[inline]
    pub fn is_free(&self) -> bool {
        self.parent_and_flags & 0x4000_0000 != 0
    }

    /// Sets the free flag.
    pub fn mark_free(&mut self) {
        self.parent_and_flags |= 0x4000_0000;
    }

    /// Clears the free flag.
    pub fn mark_used(&mut self) {
        self.parent_and_flags &= !0x4000_0000;
    }
}
/// Compressed representation of a linear ("zero") path in the FSA.
#[derive(Debug, Clone)]
pub struct ZeroPathData {
    /// Concatenated bytes of all segments.
    pub segments: Vec<u8>,
    /// Length of each segment, in insertion order (each segment <= 255 bytes).
    pub lengths: Vec<u8>,
    /// Sum of all segment lengths.
    pub total_length: u16,
}
impl ZeroPathData {
    /// Creates an empty path.
    pub fn new() -> Self {
        Self {
            segments: Vec::new(),
            lengths: Vec::new(),
            total_length: 0,
        }
    }

    /// Appends one path segment (at most 255 bytes).
    ///
    /// # Errors
    /// Returns an error if the segment exceeds 255 bytes, or if the total
    /// path length would overflow the `u16` accumulator (previously this
    /// overflow panicked in debug builds and silently wrapped in release,
    /// corrupting `compression_ratio`).
    pub fn add_segment(&mut self, data: &[u8]) -> Result<()> {
        if data.len() > 255 {
            return Err(ZiporaError::invalid_data("Path segment too long"));
        }
        // Checked add: reject overflow instead of wrapping/panicking.
        let new_total = self
            .total_length
            .checked_add(data.len() as u16)
            .ok_or_else(|| ZiporaError::invalid_data("Total path length overflow"))?;
        self.segments.extend_from_slice(data);
        self.lengths.push(data.len() as u8);
        self.total_length = new_total;
        Ok(())
    }

    /// Returns the full concatenated path bytes.
    pub fn get_full_path(&self) -> Vec<u8> {
        self.segments.clone()
    }

    /// Ratio of stored size (segments + length table) to logical path
    /// length; 0.0 for an empty path.
    pub fn compression_ratio(&self) -> f64 {
        if self.total_length == 0 {
            return 0.0;
        }
        let compressed_size = self.segments.len() + self.lengths.len();
        compressed_size as f64 / self.total_length as f64
    }
}
/// Runtime statistics for the FSA cache.
#[derive(Debug, Clone, Default)]
pub struct FsaCacheStats {
    /// Number of successful lookups.
    pub hits: u64,
    /// Number of failed lookups.
    pub misses: u64,
    /// Current number of cached states.
    pub cached_states: usize,
    /// Estimated memory footprint in bytes.
    pub memory_usage: usize,
    /// Mean compression ratio over all zero paths.
    pub avg_compression_ratio: f64,
    /// Total number of evicted states.
    pub evictions: u64,
}

impl FsaCacheStats {
    /// Fraction of lookups that hit; 0.0 when no lookups were recorded.
    pub fn hit_ratio(&self) -> f64 {
        match self.hits + self.misses {
            0 => 0.0,
            total => self.hits as f64 / total as f64,
        }
    }

    /// Cached states per byte of memory; 0.0 when no memory is used.
    pub fn memory_efficiency(&self) -> f64 {
        match self.memory_usage {
            0 => 0.0,
            bytes => self.cached_states as f64 / bytes as f64,
        }
    }
}
/// Cache of FSA states with configurable eviction and zero-path compression.
pub struct FsaCache {
    /// Active configuration.
    config: FsaCacheConfig,
    /// state id -> cached state.
    states: HashMap<u32, CachedState>,
    /// state id -> compressed linear-path data.
    zero_paths: HashMap<u32, ZeroPathData>,
    /// Ids of removed states, available for reuse.
    free_list: Vec<u32>,
    /// Next fresh state id (0 is never handed out).
    next_state_id: u32,
    /// Running statistics.
    stats: FsaCacheStats,
    /// Optional secure backing pool (None when max_memory_bytes == 0).
    /// NOTE(review): stored but not otherwise used in this view — confirm.
    memory_pool: Option<Arc<SecureMemoryPool>>,
}
impl FsaCache {
    /// Creates a cache with the default configuration.
    pub fn new() -> Result<Self> {
        Self::with_config(FsaCacheConfig::default())
    }

    /// Creates a cache with the given configuration.
    ///
    /// # Errors
    /// Returns an error if the secure memory pool cannot be created.
    pub fn with_config(config: FsaCacheConfig) -> Result<Self> {
        // A zero memory budget disables the secure backing pool.
        let memory_pool = if config.max_memory_bytes > 0 {
            let pool_config = crate::memory::SecurePoolConfig::small_secure();
            Some(SecureMemoryPool::new(pool_config)?)
        } else {
            None
        };
        // Cap the initial table size so tiny configs don't over-allocate.
        let initial_capacity = std::cmp::min(config.max_states, 1000);
        Ok(Self {
            config,
            states: HashMap::with_capacity(initial_capacity),
            zero_paths: HashMap::new(),
            free_list: Vec::new(),
            next_state_id: 1, // id 0 is never handed out
            stats: FsaCacheStats::default(),
            memory_pool,
        })
    }

    /// Looks up a cached state by id.
    ///
    /// NOTE(review): hit/miss counters are never updated (this method takes
    /// `&self`), so `stats().hit_ratio()` is always 0 — confirm intended.
    pub fn get_state(&self, state_id: u32) -> Option<CachedState> {
        self.states.get(&state_id).copied()
    }

    /// Caches a new state, evicting older states first when the cache is
    /// full. Returns the id assigned to the new state.
    ///
    /// # Errors
    /// Propagates errors from eviction.
    pub fn cache_state(&mut self, parent_id: u32, child_base: u32, is_terminal: bool) -> Result<u32> {
        if self.states.len() >= self.config.max_states {
            self.evict_states()?;
        }
        // Reuse a recycled id when possible, otherwise mint a fresh one.
        let state_id = match self.free_list.pop() {
            Some(free_id) => free_id,
            None => {
                let id = self.next_state_id;
                self.next_state_id += 1;
                id
            }
        };
        let state = CachedState::new(child_base, parent_id, is_terminal, false);
        self.states.insert(state_id, state);
        self.stats.cached_states = self.states.len();
        self.stats.memory_usage = self.estimate_memory_usage();
        Ok(state_id)
    }

    /// Removes a state (and any zero path attached to it), recycling its
    /// id. Returns `true` if the state existed.
    pub fn remove_state(&mut self, state_id: u32) -> bool {
        if self.states.remove(&state_id).is_some() {
            self.zero_paths.remove(&state_id);
            self.free_list.push(state_id);
            self.stats.cached_states = self.states.len();
            self.stats.memory_usage = self.estimate_memory_usage();
            true
        } else {
            false
        }
    }

    /// Attaches compressed zero-path data to an existing state.
    ///
    /// # Errors
    /// Returns an error if the state is not in the cache.
    pub fn add_zero_path(&mut self, state_id: u32, path_data: ZeroPathData) -> Result<()> {
        if !self.states.contains_key(&state_id) {
            return Err(ZiporaError::invalid_data("State not found in cache"));
        }
        self.zero_paths.insert(state_id, path_data);
        self.update_compression_stats();
        Ok(())
    }

    /// Returns the zero-path data attached to `state_id`, if any.
    pub fn get_zero_path(&self, state_id: u32) -> Option<&ZeroPathData> {
        self.zero_paths.get(&state_id)
    }

    /// Empties the cache and resets ids and statistics.
    pub fn clear(&mut self) {
        self.states.clear();
        self.zero_paths.clear();
        self.free_list.clear();
        self.next_state_id = 1;
        self.stats = FsaCacheStats::default();
    }

    /// Snapshot of the current statistics.
    pub fn stats(&self) -> FsaCacheStats {
        self.stats.clone()
    }

    /// The active configuration.
    pub fn config(&self) -> &FsaCacheConfig {
        &self.config
    }

    /// True when the cache has reached its configured state limit.
    #[inline]
    pub fn is_full(&self) -> bool {
        self.states.len() >= self.config.max_states
    }

    /// Rough memory footprint: state-table entries, zero-path payloads, and
    /// the free list. HashMap bucket overhead is not counted.
    fn estimate_memory_usage(&self) -> usize {
        let states_size = self.states.len() * std::mem::size_of::<(u32, CachedState)>();
        let zero_paths_size: usize = self
            .zero_paths
            .values()
            .map(|zp| zp.segments.len() + zp.lengths.len() + std::mem::size_of::<ZeroPathData>())
            .sum();
        let free_list_size = self.free_list.len() * std::mem::size_of::<u32>();
        states_size + zero_paths_size + free_list_size
    }

    /// Evicts roughly 10% of capacity (at least one state) using the
    /// configured strategy.
    fn evict_states(&mut self) -> Result<()> {
        let evict_count = std::cmp::max(1, self.config.max_states / 10);
        match self.config.strategy {
            CacheStrategy::BreadthFirst => self.evict_breadth_first(evict_count),
            CacheStrategy::DepthFirst => self.evict_depth_first(evict_count),
            CacheStrategy::CacheFriendly => self.evict_cache_friendly(evict_count),
        }
    }

    /// Shared eviction tail: removes the given ids, recycles them, and
    /// updates statistics.
    ///
    /// Fix: `evictions` is charged with the number of states actually
    /// removed; the previous code added the full requested count even when
    /// fewer states existed. Size stats are also refreshed here so they
    /// stay accurate immediately after eviction.
    fn finish_eviction(&mut self, to_remove: Vec<u32>) {
        let removed = to_remove.len() as u64;
        for state_id in to_remove {
            self.states.remove(&state_id);
            self.zero_paths.remove(&state_id);
            self.free_list.push(state_id);
        }
        self.stats.evictions += removed;
        self.stats.cached_states = self.states.len();
        self.stats.memory_usage = self.estimate_memory_usage();
    }

    /// Evicts the lowest-numbered state ids first.
    fn evict_breadth_first(&mut self, count: usize) -> Result<()> {
        let mut ids: Vec<u32> = self.states.keys().copied().collect();
        ids.sort_unstable();
        ids.truncate(count);
        self.finish_eviction(ids);
        Ok(())
    }

    /// Evicts the highest-numbered state ids first.
    fn evict_depth_first(&mut self, count: usize) -> Result<()> {
        let mut ids: Vec<u32> = self.states.keys().copied().collect();
        ids.sort_unstable_by(|a, b| b.cmp(a));
        ids.truncate(count);
        self.finish_eviction(ids);
        Ok(())
    }

    /// Evicts non-terminal states first (lowest ids within each group),
    /// falling back to terminal states only when necessary.
    fn evict_cache_friendly(&mut self, count: usize) -> Result<()> {
        let mut terminal_states = Vec::new();
        let mut non_terminal_states = Vec::new();
        for (&state_id, &state) in &self.states {
            if state.is_terminal() {
                terminal_states.push(state_id);
            } else {
                non_terminal_states.push(state_id);
            }
        }
        non_terminal_states.sort_unstable();
        terminal_states.sort_unstable();
        let mut to_remove = Vec::new();
        let from_non_terminal = std::cmp::min(count, non_terminal_states.len());
        to_remove.extend_from_slice(&non_terminal_states[..from_non_terminal]);
        let remaining = count.saturating_sub(from_non_terminal);
        if remaining > 0 {
            let from_terminal = std::cmp::min(remaining, terminal_states.len());
            to_remove.extend_from_slice(&terminal_states[..from_terminal]);
        }
        self.finish_eviction(to_remove);
        Ok(())
    }

    /// Recomputes the mean compression ratio across all zero paths.
    fn update_compression_stats(&mut self) {
        if self.zero_paths.is_empty() {
            self.stats.avg_compression_ratio = 0.0;
            return;
        }
        let total_ratio: f64 = self
            .zero_paths
            .values()
            .map(|zp| zp.compression_ratio())
            .sum();
        self.stats.avg_compression_ratio = total_ratio / self.zero_paths.len() as f64;
    }
}
impl Default for FsaCache {
    /// Panics if cache construction fails (e.g. under severe memory pressure).
    fn default() -> Self {
        match Self::new() {
            Ok(cache) => cache,
            Err(e) => panic!(
                "FsaCache creation failed in Default: {}. This indicates severe memory pressure.",
                e
            ),
        }
    }
}
/// Sentinel for "no state". NOTE(review): appears unused in this file — confirm before removing.
const NIL_STATE: u32 = u32::MAX;
/// Dense, vector-backed state cache giving O(1) indexed access by state id.
pub struct FastStateCache {
    /// One slot per state id; free-flagged entries mark empty slots.
    states: Vec<CachedState>,
    /// Number of occupied (non-free) slots.
    num_used: usize,
}
impl FastStateCache {
    /// Creates a cache with `capacity` pre-allocated empty slots.
    pub fn new(capacity: usize) -> Self {
        Self {
            states: vec![Self::empty_slot(); capacity],
            num_used: 0,
        }
    }

    /// The free-flagged placeholder stored in unoccupied slots.
    #[inline]
    fn empty_slot() -> CachedState {
        CachedState::new(0, 0, false, true)
    }

    /// Returns the state at `state_id` if the slot exists and is occupied.
    #[inline]
    pub fn get_state(&self, state_id: u32) -> Option<CachedState> {
        self.states
            .get(state_id as usize)
            .copied()
            .filter(|s| !s.is_free())
    }

    /// Stores `state` at `state_id`, growing the slot vector as needed.
    ///
    /// NOTE(review): inserting a state whose free flag is set still bumps
    /// `num_used` while `has_state` reports false — confirm callers never
    /// insert free-flagged states.
    #[inline]
    pub fn set_state(&mut self, state_id: u32, state: CachedState) {
        let idx = state_id as usize;
        if idx >= self.states.len() {
            self.states.resize(idx + 1, Self::empty_slot());
        }
        if self.states[idx].is_free() {
            self.num_used += 1;
        }
        self.states[idx] = state;
    }

    /// True when the slot exists and holds a non-free state.
    #[inline]
    pub fn has_state(&self, state_id: u32) -> bool {
        self.get_state(state_id).is_some()
    }

    /// Number of occupied slots.
    #[inline]
    pub fn len(&self) -> usize {
        self.num_used
    }

    /// True when no slot is occupied.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.num_used == 0
    }

    /// Bytes used by the slot vector (length, not capacity).
    #[inline]
    pub fn mem_size(&self) -> usize {
        self.states.len() * std::mem::size_of::<CachedState>()
    }

    /// Marks every slot free without shrinking the vector.
    pub fn clear(&mut self) {
        self.states.fill(Self::empty_slot());
        self.num_used = 0;
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cached_state_creation() {
        let state = CachedState::new(100, 50, true, false);
        assert_eq!(state.child_base, 100);
        assert_eq!(state.parent(), 50);
        assert!(state.is_terminal());
        assert!(!state.is_free());
    }

    #[test]
    fn test_cached_state_flags() {
        let mut state = CachedState::new(100, 50, false, false);
        assert!(!state.is_terminal());
        assert!(!state.is_free());
        state.mark_free();
        assert!(state.is_free());
        state.mark_used();
        assert!(!state.is_free());
    }

    #[test]
    fn test_zero_path_data() {
        let mut zp = ZeroPathData::new();
        assert_eq!(zp.total_length, 0);
        zp.add_segment(b"hello").unwrap();
        zp.add_segment(b"world").unwrap();
        assert_eq!(zp.total_length, 10);
        assert_eq!(zp.get_full_path(), b"helloworld");
        assert!(zp.compression_ratio() > 0.0);
    }

    #[test]
    fn test_fsa_cache_basic_operations() {
        let mut cache = FsaCache::new().unwrap();
        let state_id = cache.cache_state(0, 100, true).unwrap();
        assert!(state_id > 0);
        let state = cache.get_state(state_id).unwrap();
        assert_eq!(state.child_base, 100);
        assert_eq!(state.parent(), 0);
        assert!(state.is_terminal());
        assert!(cache.remove_state(state_id));
        assert!(cache.get_state(state_id).is_none());
    }

    #[test]
    fn test_fsa_cache_configurations() {
        let small_cache = FsaCache::with_config(FsaCacheConfig::small()).unwrap();
        assert_eq!(small_cache.config.max_states, 10_000);
        let large_cache = FsaCache::with_config(FsaCacheConfig::large()).unwrap();
        assert_eq!(large_cache.config.max_states, 10_000_000);
        assert!(large_cache.config.use_hugepages);
        let efficient_cache = FsaCache::with_config(FsaCacheConfig::memory_efficient()).unwrap();
        assert_eq!(efficient_cache.config.strategy, CacheStrategy::DepthFirst);
        assert!(efficient_cache.config.compressed_paths);
    }

    #[test]
    fn test_fsa_cache_eviction() {
        let config = FsaCacheConfig {
            max_states: 3,
            strategy: CacheStrategy::BreadthFirst,
            ..Default::default()
        };
        let mut cache = FsaCache::with_config(config).unwrap();
        // Ids are intentionally unused: underscore-prefixed to silence the
        // unused-variable warnings the previous `id1`..`id4` produced.
        let _id1 = cache.cache_state(0, 100, false).unwrap();
        let _id2 = cache.cache_state(0, 200, false).unwrap();
        let _id3 = cache.cache_state(0, 300, false).unwrap();
        assert_eq!(cache.stats().cached_states, 3);
        let _id4 = cache.cache_state(0, 400, false).unwrap();
        assert!(cache.stats().cached_states <= 3);
        assert!(cache.stats().evictions > 0);
    }

    #[test]
    fn test_zero_path_integration() {
        let mut cache = FsaCache::new().unwrap();
        let state_id = cache.cache_state(0, 100, false).unwrap();
        let mut zp = ZeroPathData::new();
        zp.add_segment(b"test").unwrap();
        cache.add_zero_path(state_id, zp).unwrap();
        let retrieved_zp = cache.get_zero_path(state_id).unwrap();
        assert_eq!(retrieved_zp.get_full_path(), b"test");
    }

    #[test]
    fn test_cache_statistics() {
        let mut cache = FsaCache::new().unwrap();
        let stats = cache.stats();
        assert_eq!(stats.cached_states, 0);
        assert_eq!(stats.memory_usage, 0);
        cache.cache_state(0, 100, true).unwrap();
        cache.cache_state(0, 200, false).unwrap();
        let stats = cache.stats();
        assert_eq!(stats.cached_states, 2);
        assert!(stats.memory_usage > 0);
    }

    #[test]
    fn test_fast_state_cache_basic() {
        let mut cache = FastStateCache::new(100);
        assert!(cache.is_empty());
        assert_eq!(cache.len(), 0);
        let state = CachedState::new(42, 0, true, false);
        cache.set_state(5, state);
        assert_eq!(cache.len(), 1);
        assert!(cache.has_state(5));
        assert!(!cache.has_state(6));
        let retrieved = cache.get_state(5).unwrap();
        assert_eq!(retrieved.child_base, 42);
        assert!(retrieved.is_terminal());
    }

    #[test]
    fn test_fast_state_cache_grow() {
        let mut cache = FastStateCache::new(10);
        let state = CachedState::new(99, 0, false, false);
        cache.set_state(50, state);
        assert!(cache.has_state(50));
        assert_eq!(cache.len(), 1);
    }

    #[test]
    fn test_fast_state_cache_clear() {
        let mut cache = FastStateCache::new(100);
        cache.set_state(1, CachedState::new(10, 0, false, false));
        cache.set_state(2, CachedState::new(20, 0, false, false));
        assert_eq!(cache.len(), 2);
        cache.clear();
        assert_eq!(cache.len(), 0);
        assert!(!cache.has_state(1));
        assert!(!cache.has_state(2));
    }

    #[test]
    fn test_fast_state_cache_mem_size() {
        let cache = FastStateCache::new(100);
        assert_eq!(cache.mem_size(), 100 * std::mem::size_of::<CachedState>());
    }
}