use crate::{core::UserPermissions, metrics::RoleSystemMetrics};
use chrono::{DateTime, Utc};
use dashmap::DashMap;
use std::collections::HashSet;
use std::sync::Arc;
/// Label attached to a cache entry so that groups of entries can be invalidated together.
///
/// Every tag in an entry's set is registered in the cache's reverse index; invalidating a tag
/// evicts all entries registered under it.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum CacheTag {
    /// Entry belongs to the subject with this id.
    Subject(String),
    /// Entry's result involved the role with this name.
    Role(String),
    /// Entry concerns this resource type.
    ResourceType(String),
    /// Entry concerns this action.
    Action(String),
    /// Entry was computed with request context, so any context change may affect it.
    ContextDependent,
}
/// One cached permission result together with the tags it is indexed under.
#[derive(Clone, Debug)]
pub struct CacheEntry {
    /// The cached permission data.
    pub permissions: UserPermissions,
    /// Tags under which this entry is registered for invalidation.
    pub tags: HashSet<CacheTag>,
    /// When this entry was constructed.
    ///
    /// Not consulted by `is_expired`, which delegates to the permissions themselves.
    pub created_at: DateTime<Utc>,
}
impl CacheEntry {
    /// Builds an entry from `permissions` and its invalidation `tags`.
    ///
    /// `created_at` is stamped with the current UTC time.
    pub fn new(permissions: UserPermissions, tags: HashSet<CacheTag>) -> Self {
        let created_at = Utc::now();
        Self {
            created_at,
            permissions,
            tags,
        }
    }

    /// Returns `true` when this entry is registered under `tag`.
    pub fn has_tag(&self, tag: &CacheTag) -> bool {
        HashSet::contains(&self.tags, tag)
    }

    /// Returns whether the cached permissions are older than `ttl_seconds`.
    ///
    /// Delegates entirely to `UserPermissions::is_expired`; `created_at` plays no part.
    /// NOTE(review): presumably based on `UserPermissions::last_updated` (the tests age that
    /// field to force expiry). Confirm against `crate::core`.
    pub fn is_expired(&self, ttl_seconds: u64) -> bool {
        let permissions = &self.permissions;
        permissions.is_expired(ttl_seconds)
    }
}
/// Concurrent permission cache with tag-based invalidation.
///
/// Entries live in `cache`; `tag_index` is a reverse index from each tag to the keys of the
/// entries carrying it, so that `invalidate_by_tag` can evict a group without scanning.
#[derive(Debug)]
pub struct CacheManager {
    // Cached entries by key. NOTE(review): tests use (subject id, "action:resource") as the
    // key layout — confirm that callers follow the same convention.
    cache: DashMap<(String, String), CacheEntry>,
    // Reverse index: tag -> keys of every cached entry registered under that tag.
    tag_index: DashMap<CacheTag, HashSet<(String, String)>>,
    // Sink for cache hit/miss counters.
    metrics: Arc<RoleSystemMetrics>,
}
impl CacheManager {
    /// Creates an empty cache that reports hits and misses to `metrics`.
    pub fn new(metrics: Arc<RoleSystemMetrics>) -> Self {
        Self {
            cache: DashMap::new(),
            tag_index: DashMap::new(),
            metrics,
        }
    }

    /// Caches `permissions` under `key` and registers `key` under every tag in `tags`.
    ///
    /// If `key` was already cached, the previous entry is replaced. Tags that only the
    /// previous entry carried are detached from `key` first. Otherwise the index would keep
    /// stale references, which has two effects:
    /// - they are never cleaned up, because removal only walks the *current* entry's tags;
    /// - invalidating one of those old tags would evict the unrelated new entry.
    pub fn insert(
        &self,
        key: (String, String),
        permissions: UserPermissions,
        tags: HashSet<CacheTag>,
    ) {
        let entry = CacheEntry::new(permissions, tags.clone());
        if let Some(previous) = self.cache.insert(key.clone(), entry) {
            self.detach_from_tags(&key, previous.tags.difference(&tags));
        }
        for tag in tags {
            self.tag_index.entry(tag).or_default().insert(key.clone());
        }
    }

    /// Returns a clone of the permissions cached under `key`, unless the entry is missing or
    /// expired for `ttl_seconds`.
    ///
    /// A fresh entry records a cache hit. Anything else records a cache miss. An entry found
    /// to be expired is evicted along with its tag-index registrations.
    pub fn get(&self, key: &(String, String), ttl_seconds: u64) -> Option<UserPermissions> {
        // The read guard from `get` lives only until the end of this statement, so it is
        // released before the eviction below takes a write lock on the same shard.
        let found_expired = match self.cache.get(key) {
            Some(entry) if !entry.is_expired(ttl_seconds) => {
                self.metrics.record_cache_hit();
                return Some(entry.permissions.clone());
            }
            Some(_) => true,
            None => false,
        };
        if found_expired {
            self.remove_expired_entry(key, ttl_seconds);
        }
        self.metrics.record_cache_miss();
        None
    }

    /// Evicts every entry registered under `tag` and removes `tag` from the index.
    pub fn invalidate_by_tag(&self, tag: &CacheTag) {
        // Take the whole key set out of the index in one step, rather than reading it and
        // relying on `remove_entry` to shrink it. This also discards keys whose entry is
        // already gone, so dead references cannot pile up under `tag`.
        if let Some((_, keys)) = self.tag_index.remove(tag) {
            for key in &keys {
                self.remove_entry(key);
            }
        }
    }

    /// Evicts every entry tagged with `CacheTag::Subject(subject_id)`.
    pub fn invalidate_subject(&self, subject_id: &str) {
        self.invalidate_by_tag(&CacheTag::Subject(subject_id.to_string()));
    }

    /// Evicts every entry tagged with `CacheTag::Role(role_name)`.
    pub fn invalidate_role(&self, role_name: &str) {
        self.invalidate_by_tag(&CacheTag::Role(role_name.to_string()));
    }

    /// Evicts every entry tagged with `CacheTag::ResourceType(resource_type)`.
    pub fn invalidate_resource_type(&self, resource_type: &str) {
        self.invalidate_by_tag(&CacheTag::ResourceType(resource_type.to_string()));
    }

    /// Evicts every entry tagged with `CacheTag::ContextDependent`.
    pub fn invalidate_context_dependent(&self) {
        self.invalidate_by_tag(&CacheTag::ContextDependent);
    }

    /// Evicts every entry that is expired for `ttl_seconds`.
    pub fn cleanup_expired(&self, ttl_seconds: u64) {
        // Collect first so no shard guard from `iter` is held while entries are removed.
        let expired_keys: Vec<(String, String)> = self
            .cache
            .iter()
            .filter(|entry| entry.value().is_expired(ttl_seconds))
            .map(|entry| entry.key().clone())
            .collect();
        for key in &expired_keys {
            self.remove_expired_entry(key, ttl_seconds);
        }
    }

    /// Returns a snapshot of the entry count, the distinct tag count, and the average number
    /// of tags per entry (0.0 for an empty cache).
    ///
    /// Under concurrent writes the figures are read separately and may not be mutually
    /// consistent.
    pub fn stats(&self) -> CacheStats {
        let total_entries = self.cache.len();
        let tag_count = self.tag_index.len();
        let total_tags: usize = self
            .cache
            .iter()
            .map(|entry| entry.value().tags.len())
            .sum();
        let avg_tags_per_entry = if total_entries > 0 {
            total_tags as f64 / total_entries as f64
        } else {
            0.0
        };
        CacheStats {
            total_entries,
            tag_count,
            avg_tags_per_entry,
        }
    }

    /// Removes every entry and every tag registration.
    ///
    /// NOTE(review): the two maps are cleared one after the other, not atomically. An `insert`
    /// racing with `clear` may leave index registrations without a matching entry.
    pub fn clear(&self) {
        self.cache.clear();
        self.tag_index.clear();
    }

    /// Builds the standard tag set for a permission check.
    ///
    /// The set always contains the subject, action and resource-type tags, plus one
    /// `CacheTag::Role` per entry of `roles`. It also contains `CacheTag::ContextDependent`
    /// when `has_context` is true.
    pub fn generate_tags(
        subject_id: &str,
        action: &str,
        resource_type: &str,
        roles: &[String],
        has_context: bool,
    ) -> HashSet<CacheTag> {
        let mut tags = HashSet::with_capacity(roles.len() + 4);
        tags.insert(CacheTag::Subject(subject_id.to_string()));
        tags.insert(CacheTag::Action(action.to_string()));
        tags.insert(CacheTag::ResourceType(resource_type.to_string()));
        tags.extend(roles.iter().cloned().map(CacheTag::Role));
        if has_context {
            tags.insert(CacheTag::ContextDependent);
        }
        tags
    }

    /// Unconditionally evicts `key` and unregisters it from the tags its entry carried.
    fn remove_entry(&self, key: &(String, String)) {
        if let Some((_, entry)) = self.cache.remove(key) {
            self.detach_from_tags(key, &entry.tags);
        }
    }

    /// Evicts `key` only if its entry is still expired for `ttl_seconds`.
    ///
    /// The expiry check is repeated under the shard's write lock. An entry that another thread
    /// refreshed after the caller's own check is therefore kept.
    fn remove_expired_entry(&self, key: &(String, String), ttl_seconds: u64) {
        if let Some((_, entry)) = self
            .cache
            .remove_if(key, |_, entry| entry.is_expired(ttl_seconds))
        {
            self.detach_from_tags(key, &entry.tags);
        }
    }

    /// Unregisters `key` from each of `tags` and drops any tag whose key set becomes empty.
    fn detach_from_tags<'a>(
        &self,
        key: &(String, String),
        tags: impl IntoIterator<Item = &'a CacheTag>,
    ) {
        for tag in tags {
            // The closure takes the write guard by value, so it is released on return.
            let emptied = self
                .tag_index
                .get_mut(tag)
                .map(|mut keys| {
                    keys.remove(key);
                    keys.is_empty()
                })
                .unwrap_or(false);
            if emptied {
                // Re-check under the write lock: a concurrent `insert` may have added a key to
                // this set after the guard above was released.
                self.tag_index.remove_if(tag, |_, keys| keys.is_empty());
            }
        }
    }
}
/// Point-in-time summary of a cache's size and tagging.
#[derive(Clone, Debug)]
pub struct CacheStats {
    /// Number of cached entries.
    pub total_entries: usize,
    /// Number of distinct tags present in the reverse index.
    pub tag_count: usize,
    /// Mean number of tags per cached entry; 0.0 when the cache is empty.
    pub avg_tags_per_entry: f64,
}
/// Cache-invalidation hooks for components that own a permission cache.
pub trait CacheInvalidation {
    /// Drops cached results belonging to the subject `subject_id`.
    fn invalidate_subject_cache(&self, subject_id: &str);

    /// Drops cached results that involved the role `role_name`.
    fn invalidate_role_cache(&self, role_name: &str);

    /// Drops cached results concerning `resource_type`.
    fn invalidate_resource_type_cache(&self, resource_type: &str);

    /// Purges cached results older than `ttl_seconds`.
    fn cleanup_expired_cache(&self, ttl_seconds: u64);
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::metrics::RoleSystemMetrics;
    use std::collections::HashMap;
    use std::sync::Arc;

    /// Builds an empty `CacheManager` backed by fresh metrics.
    fn fresh_manager() -> CacheManager {
        CacheManager::new(Arc::new(RoleSystemMetrics::new()))
    }

    /// Permissions with a single granted `"read"` action.
    fn read_granted() -> UserPermissions {
        let mut computed = HashMap::new();
        computed.insert("read".to_string(), crate::core::AccessResult::Granted);
        UserPermissions::new(computed)
    }

    /// Cache key for `user` reading documents, in the layout these tests use throughout.
    fn doc_read_key(user: &str) -> (String, String) {
        (user.to_string(), "read:documents".to_string())
    }

    #[test]
    fn test_cache_manager_basic_operations() {
        let manager = fresh_manager();
        let key = doc_read_key("user1");
        let expected = read_granted();
        let tags: HashSet<CacheTag> = vec![
            CacheTag::Subject("user1".to_string()),
            CacheTag::Action("read".to_string()),
            CacheTag::ResourceType("documents".to_string()),
        ]
        .into_iter()
        .collect();
        manager.insert(key.clone(), expected.clone(), tags);

        let cached = manager
            .get(&key, 300)
            .expect("a fresh entry must be served from the cache");
        assert_eq!(
            cached.computed_permissions.len(),
            expected.computed_permissions.len()
        );

        let stats = manager.stats();
        assert_eq!(stats.total_entries, 1);
        assert_eq!(stats.tag_count, 3);
    }

    #[test]
    fn test_cache_invalidation_by_tag() {
        let manager = fresh_manager();
        let empty = UserPermissions::new(HashMap::new());
        for user in ["user1", "user2"].iter() {
            let tags: HashSet<CacheTag> = vec![
                CacheTag::Subject(user.to_string()),
                CacheTag::ResourceType("documents".to_string()),
            ]
            .into_iter()
            .collect();
            manager.insert(doc_read_key(user), empty.clone(), tags);
        }
        assert_eq!(manager.stats().total_entries, 2);

        // Both entries share the resource-type tag, so one invalidation clears them all.
        manager.invalidate_resource_type("documents");
        assert_eq!(manager.stats().total_entries, 0);
    }

    #[test]
    fn test_cache_tag_generation() {
        let roles = vec!["reader".to_string(), "user".to_string()];
        let tags = CacheManager::generate_tags("user1", "read", "documents", &roles, true);

        let expected = [
            CacheTag::Subject("user1".to_string()),
            CacheTag::Action("read".to_string()),
            CacheTag::ResourceType("documents".to_string()),
            CacheTag::Role("reader".to_string()),
            CacheTag::Role("user".to_string()),
            CacheTag::ContextDependent,
        ];
        for tag in expected.iter() {
            assert!(tags.contains(tag), "missing tag {:?}", tag);
        }
    }

    #[test]
    fn test_expired_cache_cleanup() {
        let manager = fresh_manager();
        let mut stale = read_granted();
        // Aged past the 300-second TTL passed to `cleanup_expired` below.
        stale.last_updated = Utc::now() - chrono::Duration::seconds(400);
        manager.insert(doc_read_key("user1"), stale, HashSet::new());
        assert_eq!(manager.stats().total_entries, 1);

        manager.cleanup_expired(300);
        assert_eq!(manager.stats().total_entries, 0);
    }
}