use std::collections::{HashMap, HashSet};
use std::hash::{Hash, Hasher};
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant};
/// Identifies a cached query result by query type plus a normalized list of
/// parameters.
///
/// The builder methods keep `params` sorted by name and then by value, so two
/// keys built from the same parameters in any order compare equal and hash
/// identically. A 64-bit digest of the contents is cached in `hash` and
/// recomputed by every builder method; writing to the public fields directly
/// does not refresh it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CacheKey {
    /// Name of the query this key belongs to.
    pub query_type: String,
    /// `(name, value)` pairs describing the query arguments.
    pub params: Vec<(String, String)>,
    /// Cached digest of `query_type` and `params`, maintained by `rehash`.
    hash: u64,
}

impl CacheKey {
    /// Creates a key for `query_type` with no parameters.
    pub fn new(query_type: impl Into<String>) -> Self {
        let mut key = Self {
            query_type: query_type.into(),
            params: Vec::new(),
            hash: 0,
        };
        key.rehash();
        key
    }

    /// Adds one parameter and returns the re-normalized key.
    pub fn param(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
        self.params.push((name.into(), value.into()));
        self.normalize();
        self
    }

    /// Adds every `(name, value)` pair from `params` and returns the
    /// re-normalized key.
    pub fn params(mut self, params: impl IntoIterator<Item = (String, String)>) -> Self {
        self.params.extend(params);
        self.normalize();
        self
    }

    /// Restores the canonical parameter order and refreshes the digest.
    fn normalize(&mut self) {
        // Order by the whole pair, not just the name: with a name-only sort,
        // repeated names kept their insertion order, so logically identical
        // keys could compare unequal. Equal pairs are indistinguishable, which
        // makes the unstable sort safe here.
        self.params.sort_unstable();
        self.rehash();
    }

    /// Recomputes the cached digest from the query type and every parameter.
    fn rehash(&mut self) {
        use std::collections::hash_map::DefaultHasher;

        let mut hasher = DefaultHasher::new();
        self.query_type.hash(&mut hasher);
        for (name, value) in &self.params {
            name.hash(&mut hasher);
            value.hash(&mut hasher);
        }
        self.hash = hasher.finish();
    }
}

impl Hash for CacheKey {
    /// Feeds only the cached digest to `state` instead of re-hashing the
    /// strings on every lookup.
    fn hash<H: Hasher>(&self, state: &mut H) {
        state.write_u64(self.hash);
    }
}
/// Per-entry caching rules, assembled with the chained builder methods below.
#[derive(Debug, Clone)]
pub struct CachePolicy {
    /// How long an entry stays valid.
    pub ttl: Duration,
    /// Names of the dependencies the cached result was derived from.
    pub dependencies: HashSet<String>,
    /// Weight used when ranking entries for eviction.
    pub priority: u8,
    /// When `true`, the ttl is measured from the last access rather than from
    /// creation.
    pub sliding: bool,
}

impl Default for CachePolicy {
    /// A 300-second fixed ttl, no dependencies and priority 50.
    fn default() -> Self {
        CachePolicy {
            sliding: false,
            priority: 50,
            dependencies: HashSet::new(),
            ttl: Duration::from_secs(300),
        }
    }
}

impl CachePolicy {
    /// Replaces the time-to-live.
    pub fn ttl(self, ttl: Duration) -> Self {
        Self { ttl, ..self }
    }

    /// Adds every name in `deps` to the dependency set; duplicates collapse.
    pub fn depends_on(mut self, deps: &[&str]) -> Self {
        self.dependencies
            .extend(deps.iter().map(|name| (*name).to_string()));
        self
    }

    /// Replaces the eviction priority.
    pub fn priority(self, priority: u8) -> Self {
        Self { priority, ..self }
    }

    /// Switches the policy to sliding expiration.
    pub fn sliding(self) -> Self {
        Self {
            sliding: true,
            ..self
        }
    }
}
/// A cached payload together with the bookkeeping used for expiry and
/// eviction.
struct CacheEntry {
    /// The cached bytes.
    data: Vec<u8>,
    /// Length of `data` in bytes, captured at construction.
    size: usize,
    /// When the entry was created.
    created_at: Instant,
    /// When the entry was created or last touched.
    last_accessed: Instant,
    /// Number of times `touch` has been called.
    access_count: AtomicU64,
    /// Rules governing expiry and eviction weight.
    policy: CachePolicy,
}

impl CacheEntry {
    /// Wraps `data` in a fresh, never-accessed entry governed by `policy`.
    fn new(data: Vec<u8>, policy: CachePolicy) -> Self {
        let size = data.len();
        let now = Instant::now();
        Self {
            data,
            size,
            created_at: now,
            last_accessed: now,
            access_count: AtomicU64::new(0),
            policy,
        }
    }

    /// Returns `true` once more than `policy.ttl` has elapsed, measured from
    /// the last access for sliding policies and from creation otherwise.
    fn is_expired(&self) -> bool {
        let reference = if self.policy.sliding {
            self.last_accessed
        } else {
            self.created_at
        };
        reference.elapsed() > self.policy.ttl
    }

    /// Records one access: bumps the counter and resets the access time.
    fn touch(&mut self) {
        self.access_count.fetch_add(1, Ordering::Relaxed);
        self.last_accessed = Instant::now();
    }

    /// Ranks the entry for eviction; a lower score means a better victim.
    ///
    /// The score is `(reads + 1) * priority / idle_seconds`. Entries accessed
    /// within the current second (`idle_seconds == 0`) get the weight
    /// multiplied by 1000 instead of a division by zero.
    fn eviction_score(&self) -> u64 {
        let reads = self.access_count.load(Ordering::Relaxed);
        let idle_secs = self.last_accessed.elapsed().as_secs();
        // Count the insertion itself as one use. With the raw read count every
        // never-read entry scored 0, so `priority` had no say among them.
        let weight = reads
            .saturating_add(1)
            .saturating_mul(u64::from(self.policy.priority));
        weight
            .checked_div(idle_secs)
            .unwrap_or_else(|| weight.saturating_mul(1000))
    }
}
/// Counters and gauges describing a result cache.
#[derive(Debug, Clone, Default)]
pub struct ResultCacheStats {
    /// Lookups answered from the cache.
    pub hits: u64,
    /// Lookups that found nothing usable.
    pub misses: u64,
    /// Entries removed to make room for new ones.
    pub evictions: u64,
    /// Number of entries currently stored.
    pub entry_count: usize,
    /// Bytes currently accounted to stored entries.
    pub memory_bytes: usize,
    /// Configured byte budget.
    pub max_memory_bytes: usize,
    /// Entries removed because their ttl ran out.
    pub expirations: u64,
    /// Entries removed by explicit invalidation.
    pub invalidations: u64,
}

impl ResultCacheStats {
    /// Fraction of lookups that were hits, or `0.0` before any lookup.
    pub fn hit_rate(&self) -> f64 {
        match self.hits + self.misses {
            0 => 0.0,
            lookups => self.hits as f64 / lookups as f64,
        }
    }

    /// Fraction of the byte budget in use, or `0.0` when the budget is zero.
    pub fn memory_utilization(&self) -> f64 {
        match self.max_memory_bytes {
            0 => 0.0,
            budget => self.memory_bytes as f64 / budget as f64,
        }
    }
}
/// In-memory cache of serialized query results.
///
/// Entries expire according to their policy, can be invalidated in bulk
/// through named dependencies, and are evicted lowest `eviction_score` first
/// whenever a new entry would push the stored payload bytes over the budget.
pub struct ResultCache {
    /// Stored entries by key.
    entries: HashMap<CacheKey, CacheEntry>,
    /// Reverse index: dependency name -> keys of the entries that declared it.
    dependency_index: HashMap<String, HashSet<CacheKey>>,
    /// Upper bound for `current_memory`.
    max_memory: usize,
    /// Running sum of `size` over all stored entries.
    current_memory: usize,
    /// Counters exposed through `stats`.
    stats: ResultCacheStats,
}

impl ResultCache {
    /// Creates an empty cache limited to `max_memory_bytes` of payload.
    pub fn new(max_memory_bytes: usize) -> Self {
        Self {
            entries: HashMap::new(),
            dependency_index: HashMap::new(),
            max_memory: max_memory_bytes,
            current_memory: 0,
            stats: ResultCacheStats {
                max_memory_bytes,
                ..Default::default()
            },
        }
    }

    /// Returns a copy of the payload stored under `key`.
    ///
    /// A live entry counts as a hit and is touched. An expired entry is
    /// removed and counted as both an expiration and a miss. An absent key is
    /// a miss.
    pub fn get(&mut self, key: &CacheKey) -> Option<Vec<u8>> {
        let expired = match self.entries.get(key) {
            Some(entry) => entry.is_expired(),
            None => {
                self.stats.misses += 1;
                return None;
            }
        };
        if expired {
            self.remove(key);
            self.stats.expirations += 1;
            self.stats.misses += 1;
            return None;
        }

        let entry = self.entries.get_mut(key)?;
        entry.touch();
        self.stats.hits += 1;
        Some(entry.data.clone())
    }

    /// Reports whether a non-expired entry exists for `key`, without touching
    /// the entry or any counter.
    pub fn contains(&self, key: &CacheKey) -> bool {
        self.entries
            .get(key)
            .map_or(false, |entry| !entry.is_expired())
    }

    /// Stores `data` under `key`, replacing any previous entry for that key.
    ///
    /// Other entries are evicted until the new payload fits the budget. A
    /// payload larger than the whole budget is not stored at all (the previous
    /// entry for `key`, if any, is still dropped because it is being
    /// replaced).
    pub fn insert(&mut self, key: CacheKey, data: Vec<u8>, policy: CachePolicy) {
        // Drop the value being replaced first so its bytes and dependency
        // links do not linger. `remove` is a no-op for an unknown key.
        self.remove(&key);

        let entry = CacheEntry::new(data, policy);
        // Budget checks use the same size that is charged to `current_memory`
        // below; previously the check added the struct overhead while the
        // counter did not, so the two disagreed.
        if entry.size > self.max_memory {
            // It can never fit: flushing every other entry would not help and
            // storing it anyway would break the limit.
            return;
        }
        while self.current_memory.saturating_add(entry.size) > self.max_memory
            && !self.entries.is_empty()
        {
            self.evict_one();
        }

        for dep in &entry.policy.dependencies {
            self.dependency_index
                .entry(dep.clone())
                .or_default()
                .insert(key.clone());
        }
        self.current_memory += entry.size;
        self.entries.insert(key, entry);
        self.sync_stats();
    }

    /// Removes the entry for `key`, returning its payload if it existed.
    ///
    /// The entry is also unlinked from every dependency it declared.
    pub fn remove(&mut self, key: &CacheKey) -> Option<Vec<u8>> {
        let entry = self.entries.remove(key)?;
        self.current_memory = self.current_memory.saturating_sub(entry.size);

        for dep in &entry.policy.dependencies {
            let emptied = match self.dependency_index.get_mut(dep) {
                Some(keys) => {
                    keys.remove(key);
                    keys.is_empty()
                }
                None => false,
            };
            // Do not keep empty sets around for dependencies nobody uses.
            if emptied {
                self.dependency_index.remove(dep);
            }
        }

        self.sync_stats();
        Some(entry.data)
    }

    /// Removes every entry that declared `dependency`, counting each as an
    /// invalidation.
    pub fn invalidate_by_dependency(&mut self, dependency: &str) {
        if let Some(keys) = self.dependency_index.remove(dependency) {
            for key in keys {
                // Go through `remove` so the entry is also unlinked from its
                // other dependencies; a leftover link could invalidate an
                // unrelated entry later stored under the same key.
                if self.remove(&key).is_some() {
                    self.stats.invalidations += 1;
                }
            }
        }
    }

    /// Removes every entry whose key satisfies `predicate`, counting each as
    /// an invalidation.
    pub fn invalidate_where<F>(&mut self, predicate: F)
    where
        F: Fn(&CacheKey) -> bool,
    {
        let doomed: Vec<CacheKey> = self
            .entries
            .keys()
            .filter(|key| predicate(key))
            .cloned()
            .collect();
        for key in doomed {
            self.remove(&key);
            self.stats.invalidations += 1;
        }
    }

    /// Removes every expired entry, counting each as an expiration.
    pub fn prune_expired(&mut self) {
        let expired: Vec<CacheKey> = self
            .entries
            .iter()
            .filter(|(_, entry)| entry.is_expired())
            .map(|(key, _)| key.clone())
            .collect();
        for key in expired {
            self.remove(&key);
            self.stats.expirations += 1;
        }
    }

    /// Drops all entries and dependency links. Hit, miss and removal counters
    /// are kept.
    pub fn clear(&mut self) {
        self.entries.clear();
        self.dependency_index.clear();
        self.current_memory = 0;
        self.sync_stats();
    }

    /// Returns the current statistics.
    pub fn stats(&self) -> &ResultCacheStats {
        &self.stats
    }

    /// Removes the entry with the lowest eviction score, if any.
    fn evict_one(&mut self) {
        let victim = self
            .entries
            .iter()
            .min_by_key(|(_, entry)| entry.eviction_score())
            .map(|(key, _)| key.clone());
        if let Some(key) = victim {
            self.remove(&key);
            self.stats.evictions += 1;
        }
    }

    /// Copies the entry count and memory gauge into the public stats.
    fn sync_stats(&mut self) {
        self.stats.entry_count = self.entries.len();
        self.stats.memory_bytes = self.current_memory;
    }
}
/// Describes a materialized view: its name, the query that produces it, the
/// tables it is derived from and when it should be refreshed.
#[derive(Debug, Clone)]
pub struct MaterializedViewDef {
    /// Unique name the view is registered and looked up under.
    pub name: String,
    /// Query text; stored only, never interpreted by the cache.
    pub query: String,
    /// Tables whose writes are reported to this view through `mark_stale`.
    pub dependencies: Vec<String>,
    /// When the view becomes stale or due for a refresh.
    pub refresh: RefreshPolicy,
}

/// Controls how writes and elapsed time affect a view.
#[derive(Debug, Clone)]
pub enum RefreshPolicy {
    /// Writes never mark the view stale.
    Manual,
    /// Every write to a dependency marks the view stale.
    OnChange,
    /// The view is reported by `due_for_refresh` once this much time has
    /// passed since its last refresh.
    Periodic(Duration),
    /// The view becomes stale once this many writes have been recorded since
    /// its last refresh.
    AfterWrites(usize),
}

/// A registered view together with its refresh bookkeeping.
struct MaterializedView {
    /// Bytes supplied by the most recent `refresh`; empty until then.
    data: Vec<u8>,
    /// The definition the view was registered with.
    def: MaterializedViewDef,
    /// Time of registration or of the last refresh.
    last_refresh: Instant,
    /// Writes reported through `mark_stale` since the last refresh.
    writes_since_refresh: usize,
    /// Whether the stored data must not be served.
    stale: bool,
}

/// Registry of materialized views with write-driven staleness tracking.
pub struct MaterializedViewCache {
    /// Registered views by name.
    views: HashMap<String, MaterializedView>,
    /// Reverse index: table name -> names of the views depending on it.
    dependency_index: HashMap<String, HashSet<String>>,
}

impl MaterializedViewCache {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self {
            views: HashMap::new(),
            dependency_index: HashMap::new(),
        }
    }

    /// Registers `def` as a new, stale view with no data.
    ///
    /// Registering a name again replaces the earlier view, including its
    /// dependency links.
    pub fn register(&mut self, def: MaterializedViewDef) {
        // Unlink the definition being replaced; otherwise writes to tables the
        // view no longer depends on would still mark it stale.
        self.remove(&def.name);

        for dep in &def.dependencies {
            self.dependency_index
                .entry(dep.clone())
                .or_default()
                .insert(def.name.clone());
        }

        let name = def.name.clone();
        let view = MaterializedView {
            data: Vec::new(),
            def,
            last_refresh: Instant::now(),
            writes_since_refresh: 0,
            stale: true,
        };
        self.views.insert(name, view);
    }

    /// Returns the view's data when it is registered, not stale and non-empty.
    // NOTE(review): a view refreshed with an empty payload yields `None` here
    // although `needs_refresh` reports `false` for it. Confirm that is
    // intended.
    pub fn get(&self, name: &str) -> Option<&[u8]> {
        self.views
            .get(name)
            .filter(|view| !view.stale && !view.data.is_empty())
            .map(|view| view.data.as_slice())
    }

    /// Reports whether the view is flagged stale; `false` for unknown names.
    pub fn needs_refresh(&self, name: &str) -> bool {
        self.views.get(name).map_or(false, |view| view.stale)
    }

    /// Stores fresh `data` for the view and resets its staleness bookkeeping.
    /// Unknown names are ignored.
    pub fn refresh(&mut self, name: &str, data: Vec<u8>) {
        if let Some(view) = self.views.get_mut(name) {
            view.data = data;
            view.last_refresh = Instant::now();
            view.writes_since_refresh = 0;
            view.stale = false;
        }
    }

    /// Records one write to `table` on every view that depends on it and
    /// flags the views whose policy says so.
    pub fn mark_stale(&mut self, table: &str) {
        let dependents = match self.dependency_index.get(table) {
            Some(names) => names,
            None => return,
        };
        // `dependency_index` and `views` are separate fields, so the index can
        // be read in place while views are updated; no copy of the set needed.
        for name in dependents {
            if let Some(view) = self.views.get_mut(name) {
                view.writes_since_refresh = view.writes_since_refresh.saturating_add(1);
                let flagged = match &view.def.refresh {
                    RefreshPolicy::OnChange => true,
                    RefreshPolicy::AfterWrites(threshold) => {
                        view.writes_since_refresh >= *threshold
                    }
                    RefreshPolicy::Manual | RefreshPolicy::Periodic(_) => false,
                };
                if flagged {
                    view.stale = true;
                }
            }
        }
    }

    /// Lists the periodic views whose interval has elapsed since their last
    /// refresh, in no particular order.
    pub fn due_for_refresh(&self) -> Vec<String> {
        self.views
            .values()
            .filter(|view| match &view.def.refresh {
                RefreshPolicy::Periodic(interval) => view.last_refresh.elapsed() >= *interval,
                RefreshPolicy::Manual
                | RefreshPolicy::OnChange
                | RefreshPolicy::AfterWrites(_) => false,
            })
            .map(|view| view.def.name.clone())
            .collect()
    }

    /// Unregisters the view and unlinks it from its dependencies. Unknown
    /// names are ignored.
    pub fn remove(&mut self, name: &str) {
        if let Some(view) = self.views.remove(name) {
            for dep in &view.def.dependencies {
                let emptied = match self.dependency_index.get_mut(dep) {
                    Some(names) => {
                        names.remove(name);
                        names.is_empty()
                    }
                    None => false,
                };
                // Do not keep empty sets around for tables nobody depends on.
                if emptied {
                    self.dependency_index.remove(dep);
                }
            }
        }
    }

    /// Returns the names of all registered views, in no particular order.
    pub fn list(&self) -> Vec<&str> {
        self.views.keys().map(String::as_str).collect()
    }
}

impl Default for MaterializedViewCache {
    /// Equivalent to [`MaterializedViewCache::new`].
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Keys built from the same parameters in a different order are equal and
    /// carry the same cached digest.
    #[test]
    fn test_cache_key_hashing() {
        let forward = CacheKey::new("attack_paths")
            .param("from", "host1")
            .param("to", "host2");
        let reversed = CacheKey::new("attack_paths")
            .param("to", "host2")
            .param("from", "host1");

        assert_eq!(forward, reversed);
        assert_eq!(forward.hash, reversed.hash);
    }

    /// A stored payload is returned unchanged and counted as one hit.
    #[test]
    fn test_result_cache_basic() {
        let payload = vec![1, 2, 3, 4, 5];
        let key = CacheKey::new("test_query").param("id", "123");
        let mut cache = ResultCache::new(1024 * 1024);

        cache.insert(key.clone(), payload.clone(), CachePolicy::default());

        assert_eq!(cache.get(&key), Some(payload));
        assert_eq!(cache.stats().hits, 1);
    }

    /// An entry read after its ttl has passed is gone and counted as expired.
    #[test]
    fn test_cache_expiration() {
        let short_lived = CachePolicy::default().ttl(Duration::from_millis(1));
        let key = CacheKey::new("test");
        let mut cache = ResultCache::new(1024 * 1024);

        cache.insert(key.clone(), vec![1, 2, 3], short_lived);
        // Wait well past the 1 ms ttl.
        std::thread::sleep(Duration::from_millis(10));

        assert!(cache.get(&key).is_none());
        assert_eq!(cache.stats().expirations, 1);
    }

    /// Invalidating a dependency removes the entries that declared it.
    #[test]
    fn test_dependency_invalidation() {
        let key = CacheKey::new("host_query");
        let mut cache = ResultCache::new(1024 * 1024);

        cache.insert(
            key.clone(),
            vec![1, 2, 3],
            CachePolicy::default().depends_on(&["hosts"]),
        );
        assert!(cache.contains(&key));

        cache.invalidate_by_dependency("hosts");

        assert!(!cache.contains(&key));
        assert_eq!(cache.stats().invalidations, 1);
    }

    /// Overfilling a small cache evicts entries and keeps the reported memory
    /// within the budget.
    #[test]
    fn test_memory_eviction() {
        let mut cache = ResultCache::new(100);

        for index in 0..10 {
            let key = CacheKey::new("query").param("i", index.to_string());
            cache.insert(key, vec![0u8; 20], CachePolicy::default());
        }

        let stats = cache.stats();
        assert!(stats.evictions > 0);
        assert!(stats.memory_bytes <= 100);
    }

    /// A view starts stale, serves data after a refresh, and goes stale again
    /// when one of its dependencies is written.
    #[test]
    fn test_materialized_view() {
        let definition = MaterializedViewDef {
            name: "active_hosts".to_string(),
            query: "SELECT * FROM hosts WHERE status = 'active'".to_string(),
            dependencies: vec!["hosts".to_string()],
            refresh: RefreshPolicy::OnChange,
        };
        let mut cache = MaterializedViewCache::new();
        cache.register(definition);
        assert!(cache.needs_refresh("active_hosts"));

        cache.refresh("active_hosts", vec![1, 2, 3]);
        assert!(!cache.needs_refresh("active_hosts"));
        assert_eq!(cache.get("active_hosts"), Some(&[1, 2, 3][..]));

        cache.mark_stale("hosts");
        assert!(cache.needs_refresh("active_hosts"));
    }
}