use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant};
use serde_json::Value;
use super::policy::ApprovalScope;
/// Entry cap assigned to `max_entries` by every constructor in this file.
///
/// The cap is checked separately against each tool's per-argument map and
/// against the global (all-arguments) map when an approval is recorded.
const DEFAULT_MAX_ENTRIES: usize = 10_000;
/// In-memory record of tool approvals granted so far.
///
/// Two kinds of approval are tracked, each behind its own `RwLock`:
/// approvals bound to one tool *and* one specific argument key, and
/// approvals that cover a tool regardless of its arguments.
#[derive(Debug)]
pub struct SessionApprovalCache {
    /// Tool name -> (argument key -> instant the approval was recorded).
    approvals: Arc<RwLock<HashMap<String, HashMap<String, Instant>>>>,
    /// Tool name -> instant an "any arguments" approval was recorded.
    global_approvals: Arc<RwLock<HashMap<String, Instant>>>,
    /// How long a recorded approval stays valid; `None` means it never expires.
    cache_ttl: Option<Duration>,
    /// Size cap consulted when recording, applied per map (see `record_approval`).
    max_entries: usize,
    /// When `true`, argument keys are stored as a short digest instead of raw JSON.
    hash_args: bool,
}
impl SessionApprovalCache {
    /// Creates an empty cache: no TTL (entries never expire), the default
    /// entry cap, and raw JSON argument keys.
    pub fn new() -> Self {
        Self {
            approvals: Arc::new(RwLock::new(HashMap::new())),
            global_approvals: Arc::new(RwLock::new(HashMap::new())),
            cache_ttl: None,
            max_entries: DEFAULT_MAX_ENTRIES,
            hash_args: false,
        }
    }

    /// Creates an empty cache whose entries stop counting as approved once
    /// `ttl` has elapsed since they were recorded.
    pub fn with_ttl(ttl: Duration) -> Self {
        let mut cache = Self::new();
        cache.cache_ttl = Some(ttl);
        cache
    }

    /// Creates an empty cache that stores a digest of the serialized
    /// arguments instead of the serialized arguments themselves.
    pub fn with_hash_args() -> Self {
        let mut cache = Self::new();
        cache.hash_args = true;
        cache
    }

    /// Switches argument-key hashing on or off.
    ///
    /// Keys already stored are not rewritten, so per-argument approvals
    /// recorded under the previous setting will no longer match lookups.
    pub fn set_hash_args(&mut self, enabled: bool) {
        self.hash_args = enabled;
    }

    /// Builder-style setter for the entry cap.
    ///
    /// The cap applies separately to each tool's per-argument map and to the
    /// global map. A cap of `0` disables caching: nothing is recorded.
    pub fn with_max_entries(mut self, max: usize) -> Self {
        self.max_entries = max;
        self
    }

    /// Returns `true` when `tool_name` holds an unexpired approval that covers
    /// `args`: either an all-arguments approval for the tool, or a
    /// per-argument approval whose key matches `args` exactly.
    ///
    /// # Panics
    ///
    /// Panics if either internal lock is poisoned.
    pub fn is_approved(&self, tool_name: &str, args: &Value) -> bool {
        // Scoped so the global read guard is released before the per-tool
        // lock is taken.
        let globally_approved = {
            let global = self.global_approvals.read().unwrap();
            global
                .get(tool_name)
                .is_some_and(|recorded_at| self.is_entry_valid(recorded_at))
        };
        if globally_approved {
            return true;
        }

        let key = self.args_key(args);
        let approvals = self.approvals.read().unwrap();
        approvals
            .get(tool_name)
            .and_then(|entries| entries.get(&key))
            .is_some_and(|recorded_at| self.is_entry_valid(recorded_at))
    }

    /// Records an approval for `tool_name` according to `scope`.
    ///
    /// * `Once` stores nothing.
    /// * `Session` stores an approval for this exact argument key.
    /// * `SessionAllTools` stores an approval for the tool that matches any
    ///   arguments.
    ///
    /// When the target map is full and the key is new, the entry with the
    /// oldest timestamp is evicted first. Re-recording an existing key only
    /// refreshes its timestamp and never evicts another entry. With a cap of
    /// `0` nothing is stored.
    ///
    /// # Panics
    ///
    /// Panics if the lock guarding the target map is poisoned.
    pub fn record_approval(&self, tool_name: &str, args: &Value, scope: ApprovalScope) {
        // A zero cap leaves no room for any entry, so skip locking entirely.
        if self.max_entries == 0 {
            return;
        }

        match scope {
            ApprovalScope::Once => {}
            ApprovalScope::Session => {
                let key = self.args_key(args);
                let mut approvals = self.approvals.write().unwrap();
                let entries = approvals.entry(tool_name.to_string()).or_default();
                Self::insert_bounded(entries, key, self.max_entries);
            }
            ApprovalScope::SessionAllTools => {
                let mut global = self.global_approvals.write().unwrap();
                Self::insert_bounded(&mut global, tool_name.to_string(), self.max_entries);
            }
        }
    }

    /// Inserts `key` stamped with the current instant, evicting the oldest
    /// entry first only when `key` is new and the map is already at
    /// `max_entries`.
    fn insert_bounded(entries: &mut HashMap<String, Instant>, key: String, max_entries: usize) {
        // Refreshing an existing key does not grow the map, so it must not
        // push out an unrelated approval.
        if !entries.contains_key(&key) && entries.len() >= max_entries {
            let oldest_key = entries
                .iter()
                .min_by_key(|(_, recorded_at)| **recorded_at)
                .map(|(existing_key, _)| existing_key.clone());
            if let Some(oldest_key) = oldest_key {
                entries.remove(&oldest_key);
            }
        }
        entries.insert(key, Instant::now());
    }

    /// Removes every approval (per-argument and all-arguments) held for
    /// `tool_name`.
    ///
    /// The two maps are updated one after the other, each under its own
    /// write lock.
    ///
    /// # Panics
    ///
    /// Panics if either internal lock is poisoned.
    pub fn revoke(&self, tool_name: &str) {
        {
            let mut approvals = self.approvals.write().unwrap();
            approvals.remove(tool_name);
        }
        {
            let mut global = self.global_approvals.write().unwrap();
            global.remove(tool_name);
        }
    }

    /// Removes every recorded approval for every tool.
    ///
    /// # Panics
    ///
    /// Panics if either internal lock is poisoned.
    pub fn clear(&self) {
        self.approvals.write().unwrap().clear();
        self.global_approvals.write().unwrap().clear();
    }

    /// Drops entries whose age has reached the TTL and returns how many were
    /// removed. Tools left with no per-argument entries are dropped as well
    /// (they do not add to the count). Without a TTL this is a no-op that
    /// returns `0`.
    ///
    /// # Panics
    ///
    /// Panics if either internal lock is poisoned.
    pub fn cleanup_expired(&self) -> usize {
        let Some(ttl) = self.cache_ttl else {
            return 0;
        };

        // `now` is sampled before the locks are taken, so an entry recorded
        // concurrently may be stamped later than `now`; the saturating
        // subtraction treats such an entry as having age zero.
        let now = Instant::now();
        let mut removed = 0;

        {
            let mut approvals = self.approvals.write().unwrap();
            for entries in approvals.values_mut() {
                let before = entries.len();
                entries.retain(|_, recorded_at| now.saturating_duration_since(*recorded_at) < ttl);
                removed += before - entries.len();
            }
            approvals.retain(|_, entries| !entries.is_empty());
        }

        {
            let mut global = self.global_approvals.write().unwrap();
            let before = global.len();
            global.retain(|_, recorded_at| now.saturating_duration_since(*recorded_at) < ttl);
            removed += before - global.len();
        }

        removed
    }

    /// Returns current entry counts together with the configured TTL.
    ///
    /// Counts include entries that have expired but have not yet been
    /// removed by [`cleanup_expired`](Self::cleanup_expired).
    ///
    /// # Panics
    ///
    /// Panics if either internal lock is poisoned.
    pub fn stats(&self) -> CacheStats {
        let approvals = self.approvals.read().unwrap();
        let global = self.global_approvals.read().unwrap();

        CacheStats {
            per_tool_entries: approvals.values().map(HashMap::len).sum(),
            global_entries: global.len(),
            tools_cached: approvals.len(),
            ttl: self.cache_ttl,
        }
    }

    /// Whether an entry recorded at `recorded_at` is still within the TTL.
    /// Always `true` when no TTL is configured.
    fn is_entry_valid(&self, recorded_at: &Instant) -> bool {
        match self.cache_ttl {
            Some(ttl) => recorded_at.elapsed() < ttl,
            None => true,
        }
    }

    /// Builds the map key for `args`: its JSON serialization, or a digest of
    /// that serialization when `hash_args` is set.
    ///
    /// Two values match only if they serialize to the same string.
    /// NOTE(review): object key order in that string depends on how
    /// `serde_json` is configured for this crate - confirm it is stable.
    fn args_key(&self, args: &Value) -> String {
        let json = serde_json::to_string(args).unwrap_or_else(|_| format!("{args}"));
        if self.hash_args {
            Self::sha256_hex(&json)
        } else {
            json
        }
    }

    /// Returns a 16-hex-digit digest of `input`.
    ///
    /// Despite the name this is **not** SHA-256: the input is fed through
    /// `std`'s `DefaultHasher` and the resulting `u64` is formatted as hex.
    /// NOTE(review): a 64-bit non-cryptographic digest means two different
    /// argument payloads could in principle share a key; switch to a real
    /// cryptographic hash if that matters for approvals.
    fn sha256_hex(input: &str) -> String {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};

        let mut hasher = DefaultHasher::new();
        input.hash(&mut hasher);
        let hash = hasher.finish();
        format!("{hash:016x}")
    }
}
impl Default for SessionApprovalCache {
fn default() -> Self {
Self::new()
}
}
impl Clone for SessionApprovalCache {
    /// Produces an independent copy: the contents of both maps are duplicated
    /// into fresh `Arc<RwLock<..>>` containers, so later changes to either
    /// cache are not seen by the other.
    ///
    /// # Panics
    ///
    /// Panics if either internal lock is poisoned.
    fn clone(&self) -> Self {
        // Copy the per-tool map first, then the global one; both read guards
        // stay alive until the function returns.
        let per_tool_guard = self.approvals.read().unwrap();
        let per_tool_copy = per_tool_guard.clone();
        let global_guard = self.global_approvals.read().unwrap();
        let global_copy = global_guard.clone();

        Self {
            approvals: Arc::new(RwLock::new(per_tool_copy)),
            global_approvals: Arc::new(RwLock::new(global_copy)),
            cache_ttl: self.cache_ttl,
            max_entries: self.max_entries,
            hash_args: self.hash_args,
        }
    }
}
/// Point-in-time counters describing a `SessionApprovalCache`.
#[derive(Debug, Clone)]
pub struct CacheStats {
    /// Total number of per-argument approvals, summed over all tools.
    pub per_tool_entries: usize,
    /// Number of tools holding an all-arguments approval.
    pub global_entries: usize,
    /// Number of tools that have a per-argument approval map.
    pub tools_cached: usize,
    /// Configured time-to-live, or `None` when entries never expire.
    pub ttl: Option<Duration>,
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// A fresh cache approves nothing.
    #[test]
    fn test_empty_cache_not_approved() {
        let fresh = SessionApprovalCache::new();
        let approved = fresh.is_approved("tool", &json!({"a": 1}));
        assert!(!approved);
    }

    /// `Session` scope matches only the exact arguments that were approved.
    #[test]
    fn test_session_scope_caches_exact_args() {
        let cache = SessionApprovalCache::new();
        let approved_args = json!({"path": "/tmp/test"});
        cache.record_approval("Read", &approved_args, ApprovalScope::Session);

        let same_args = json!({"path": "/tmp/test"});
        let other_args = json!({"path": "/tmp/other"});
        assert!(cache.is_approved("Read", &same_args));
        assert!(!cache.is_approved("Read", &other_args));
    }

    /// `SessionAllTools` scope covers any arguments, but only for that tool.
    #[test]
    fn test_session_all_tools_scope() {
        let cache = SessionApprovalCache::new();
        cache.record_approval("Bash", &json!({}), ApprovalScope::SessionAllTools);

        assert!(cache.is_approved("Bash", &json!({"cmd": "ls"})));
        assert!(cache.is_approved("Bash", &json!({"cmd": "rm -rf /"})));
        assert!(!cache.is_approved("Read", &json!({})));
    }

    /// `Once` scope leaves nothing behind in the cache.
    #[test]
    fn test_once_scope_no_cache() {
        let cache = SessionApprovalCache::new();
        let args = json!({"a": 1});
        cache.record_approval("tool", &args, ApprovalScope::Once);
        assert!(!cache.is_approved("tool", &json!({"a": 1})));
    }

    /// Revoking one tool leaves the other tool's approval intact.
    #[test]
    fn test_revoke() {
        let cache = SessionApprovalCache::new();
        cache.record_approval("tool1", &json!({}), ApprovalScope::SessionAllTools);
        cache.record_approval("tool2", &json!({"x": 1}), ApprovalScope::Session);

        cache.revoke("tool1");

        assert!(!cache.is_approved("tool1", &json!({})));
        assert!(cache.is_approved("tool2", &json!({"x": 1})));
    }

    /// `clear` wipes both kinds of approval.
    #[test]
    fn test_clear() {
        let cache = SessionApprovalCache::new();
        cache.record_approval("tool1", &json!({}), ApprovalScope::SessionAllTools);
        cache.record_approval("tool2", &json!({}), ApprovalScope::Session);

        cache.clear();

        assert!(!cache.is_approved("tool1", &json!({})));
        assert!(!cache.is_approved("tool2", &json!({})));
    }

    /// A clone keeps its entries after the source cache is cleared.
    #[test]
    fn test_clone_independent() {
        let source = SessionApprovalCache::new();
        source.record_approval("tool", &json!({}), ApprovalScope::SessionAllTools);

        let snapshot = source.clone();
        source.clear();

        assert!(!source.is_approved("tool", &json!({})));
        assert!(snapshot.is_approved("tool", &json!({})));
    }

    /// A per-argument entry is honoured while it is younger than the TTL.
    #[test]
    fn test_ttl_entry_valid_before_expiry() {
        let one_hour = Duration::from_secs(3600);
        let cache = SessionApprovalCache::with_ttl(one_hour);
        cache.record_approval("Read", &json!({"path": "/tmp/test"}), ApprovalScope::Session);
        assert!(cache.is_approved("Read", &json!({"path": "/tmp/test"})));
    }

    /// An all-arguments entry is honoured while it is younger than the TTL.
    #[test]
    fn test_ttl_global_entry_valid_before_expiry() {
        let one_hour = Duration::from_secs(3600);
        let cache = SessionApprovalCache::with_ttl(one_hour);
        cache.record_approval("Bash", &json!({}), ApprovalScope::SessionAllTools);
        assert!(cache.is_approved("Bash", &json!({"cmd": "ls"})));
    }

    /// The default cache reports no TTL and keeps entries valid.
    #[test]
    fn test_ttl_default_no_expiry() {
        let cache = SessionApprovalCache::new();
        cache.record_approval("tool", &json!({"a": 1}), ApprovalScope::Session);

        assert!(cache.is_approved("tool", &json!({"a": 1})));
        assert_eq!(cache.stats().ttl, None);
    }

    /// A per-argument entry older than the TTL is no longer approved.
    #[test]
    fn test_ttl_expired_entry_not_approved() {
        let cache = SessionApprovalCache::with_ttl(Duration::from_millis(1));
        cache.record_approval("tool", &json!({"a": 1}), ApprovalScope::Session);

        // Wait past the 1 ms TTL.
        std::thread::sleep(Duration::from_millis(5));

        assert!(!cache.is_approved("tool", &json!({"a": 1})));
    }

    /// An all-arguments entry older than the TTL is no longer approved.
    #[test]
    fn test_ttl_expired_global_entry_not_approved() {
        let cache = SessionApprovalCache::with_ttl(Duration::from_millis(1));
        cache.record_approval("Bash", &json!({}), ApprovalScope::SessionAllTools);

        // Wait past the 1 ms TTL.
        std::thread::sleep(Duration::from_millis(5));

        assert!(!cache.is_approved("Bash", &json!({"cmd": "ls"})));
    }

    /// Cleanup removes expired entries from both maps and reports the count.
    #[test]
    fn test_cleanup_expired() {
        let cache = SessionApprovalCache::with_ttl(Duration::from_millis(1));
        cache.record_approval("tool1", &json!({"a": 1}), ApprovalScope::Session);
        cache.record_approval("tool2", &json!({}), ApprovalScope::SessionAllTools);

        // Wait past the 1 ms TTL.
        std::thread::sleep(Duration::from_millis(5));

        let removed_count = cache.cleanup_expired();
        assert_eq!(removed_count, 2);
        assert_eq!(cache.stats().per_tool_entries, 0);
        assert_eq!(cache.stats().global_entries, 0);
    }

    /// Without a TTL, cleanup removes nothing.
    #[test]
    fn test_cleanup_no_ttl_is_noop() {
        let cache = SessionApprovalCache::new();
        cache.record_approval("tool", &json!({"a": 1}), ApprovalScope::Session);

        let removed_count = cache.cleanup_expired();

        assert_eq!(removed_count, 0);
        assert!(cache.is_approved("tool", &json!({"a": 1})));
    }

    /// Stats count per-argument entries, global entries, tools, and the TTL.
    #[test]
    fn test_stats() {
        let five_minutes = Duration::from_secs(300);
        let cache = SessionApprovalCache::with_ttl(five_minutes);
        cache.record_approval("tool1", &json!({"a": 1}), ApprovalScope::Session);
        cache.record_approval("tool1", &json!({"a": 2}), ApprovalScope::Session);
        cache.record_approval("tool2", &json!({}), ApprovalScope::SessionAllTools);

        let snapshot = cache.stats();

        assert_eq!(snapshot.per_tool_entries, 2);
        assert_eq!(snapshot.global_entries, 1);
        assert_eq!(snapshot.tools_cached, 1);
        assert_eq!(snapshot.ttl, Some(Duration::from_secs(300)));
    }
}