use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use moka::sync::Cache;
use crate::engine::decision::PolicyDecision;
/// Lookup key for [`DecisionCache`].
///
/// Equality and hashing are derived over all three fields, so a cached decision
/// is only reused when the agent id, policy epoch and action fingerprint all match.
/// Note that distinct actions can share one `action_hash` fingerprint.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct CacheKey {
/// Raw 16-byte identifier of the agent the decision applies to.
pub agent_id: [u8; 16],
/// Policy epoch the decision was evaluated under. NOTE(review): presumably bumped
/// whenever the policy set changes so that older decisions stop matching — confirm.
pub policy_epoch: u64,
/// 64-bit fingerprint of the action, as computed by `CacheKey::new`.
pub action_hash: u64,
}
impl CacheKey {
    /// Builds the key for `action` taken by `agent_id` under `policy_epoch`.
    ///
    /// The action itself is not retained; only its 64-bit fingerprint from
    /// `action_discriminant` is stored in `action_hash`.
    pub fn new(agent_id: &[u8; 16], policy_epoch: u64, action: &aa_core::GovernanceAction) -> Self {
        let agent_id = *agent_id;
        let action_hash = action_discriminant(action);
        Self {
            agent_id,
            policy_epoch,
            action_hash,
        }
    }
}
fn action_discriminant(action: &aa_core::GovernanceAction) -> u64 {
use ahash::AHasher;
use std::hash::{Hash, Hasher};
let mut h = AHasher::default();
match action {
aa_core::GovernanceAction::ToolCall { name, .. } => {
"tool".hash(&mut h);
name.hash(&mut h);
}
aa_core::GovernanceAction::ToolResult { tool_name, .. } => {
"tool_result".hash(&mut h);
tool_name.hash(&mut h);
}
aa_core::GovernanceAction::NetworkRequest { url, method } => {
"net".hash(&mut h);
url.hash(&mut h);
method.hash(&mut h);
}
aa_core::GovernanceAction::FileAccess { path, mode } => {
"file".hash(&mut h);
path.hash(&mut h);
format!("{mode:?}").hash(&mut h);
}
aa_core::GovernanceAction::ProcessExec { command } => {
"exec".hash(&mut h);
command.hash(&mut h);
}
aa_core::GovernanceAction::SendMessage {
source_team_id,
target_team_id,
channel_id,
} => {
"msg".hash(&mut h);
source_team_id.hash(&mut h);
target_team_id.hash(&mut h);
channel_id.hash(&mut h);
}
}
h.finish()
}
/// Cache of [`PolicyDecision`]s keyed by [`CacheKey`], with hit/miss counters.
///
/// The derived `Clone` copies the `Arc` handles, so clones report the same
/// hit/miss totals. Per moka's documentation, cloning a `moka::sync::Cache`
/// likewise shares the underlying storage rather than copying entries.
#[derive(Clone)]
pub struct DecisionCache {
// Underlying moka cache; capacity and TTL are configured at construction.
inner: Cache<CacheKey, PolicyDecision>,
// Number of `get` calls that found a cached decision.
hits: Arc<AtomicU64>,
// Number of `get` calls that found nothing.
misses: Arc<AtomicU64>,
}
impl DecisionCache {
    /// Time-to-live applied to every entry by [`DecisionCache::new`].
    pub const DEFAULT_TTL: std::time::Duration = std::time::Duration::from_secs(60);

    /// Creates a cache holding at most `capacity` entries, each expiring
    /// [`Self::DEFAULT_TTL`] (60 s) after insertion.
    pub fn new(capacity: u64) -> Self {
        Self::with_ttl(capacity, Self::DEFAULT_TTL)
    }

    /// Creates a cache holding at most `capacity` entries, each expiring `ttl`
    /// after insertion.
    ///
    /// The TTL bounds how long a cached decision can survive a policy change
    /// that is neither reflected in `CacheKey::policy_epoch` nor followed by an
    /// explicit invalidation.
    ///
    /// # Panics
    ///
    /// moka's builder documents a panic for TTLs longer than 1000 years.
    pub fn with_ttl(capacity: u64, ttl: std::time::Duration) -> Self {
        let inner = Cache::builder()
            .max_capacity(capacity)
            .time_to_live(ttl)
            .build();
        Self {
            inner,
            hits: Arc::new(AtomicU64::new(0)),
            misses: Arc::new(AtomicU64::new(0)),
        }
    }

    /// Looks up a cached decision for `key`.
    ///
    /// Each call increments exactly one local counter (hit or miss), plus the
    /// matching `policy_decision_cache_hits_total` /
    /// `policy_decision_cache_misses_total` metric.
    pub fn get(&self, key: &CacheKey) -> Option<PolicyDecision> {
        let result = self.inner.get(key);
        if result.is_some() {
            self.hits.fetch_add(1, Ordering::Relaxed);
            metrics::counter!("policy_decision_cache_hits_total").increment(1);
        } else {
            self.misses.fetch_add(1, Ordering::Relaxed);
            metrics::counter!("policy_decision_cache_misses_total").increment(1);
        }
        result
    }

    /// Stores `value` under `key`, replacing any existing entry.
    pub fn insert(&self, key: CacheKey, value: PolicyDecision) {
        self.inner.insert(key, value);
    }

    /// Discards every cached decision.
    ///
    /// Per moka's documentation, entries inserted before this call are no longer
    /// returned by `get`, even if they have not yet been physically evicted.
    pub fn invalidate_all(&self) {
        self.inner.invalidate_all();
    }

    /// Discards cached decisions for `_agent_id`.
    ///
    /// The current implementation ignores `_agent_id` and evicts EVERY entry.
    /// That is conservative — no stale decision is ever served — but it also
    /// drops the cached decisions of all other agents.
    pub fn invalidate_for_agent(&self, _agent_id: &[u8; 16]) {
        self.inner.invalidate_all();
    }

    /// Number of `get` calls on this cache (or its clones) that returned a decision.
    pub fn cache_hits(&self) -> u64 {
        self.hits.load(Ordering::Relaxed)
    }

    /// Number of `get` calls on this cache (or its clones) that returned `None`.
    pub fn cache_misses(&self) -> u64 {
        self.misses.load(Ordering::Relaxed)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use aa_core::GovernanceAction;

    /// Builds a `ToolCall` action for `name` with empty arguments.
    fn tool_action(name: &str) -> GovernanceAction {
        GovernanceAction::ToolCall {
            name: name.to_string(),
            args: String::new(),
        }
    }

    /// Key for an agent whose 16-byte id is `fill` repeated.
    fn key_for(fill: u8, epoch: u64, action: &GovernanceAction) -> CacheKey {
        CacheKey::new(&[fill; 16], epoch, action)
    }

    #[test]
    fn cache_hit_after_insert() {
        let cache = DecisionCache::new(128);
        let key = key_for(1, 1, &tool_action("bash"));
        cache.insert(key.clone(), PolicyDecision::Allow);

        let found = cache.get(&key);

        assert_eq!(found, Some(PolicyDecision::Allow));
        assert_eq!(cache.cache_hits(), 1);
        assert_eq!(cache.cache_misses(), 0);
    }

    #[test]
    fn cache_miss_is_counted() {
        let cache = DecisionCache::new(128);
        let key = key_for(2, 1, &tool_action("deploy"));

        let found = cache.get(&key);

        assert_eq!(found, None);
        assert_eq!(cache.cache_misses(), 1);
        assert_eq!(cache.cache_hits(), 0);
    }

    #[test]
    fn different_tool_names_produce_different_keys() {
        let [bash, deploy] = ["bash", "deploy"].map(|tool| key_for(1, 1, &tool_action(tool)));
        assert_ne!(bash, deploy);
    }

    #[test]
    fn different_epochs_produce_different_keys() {
        let shared = tool_action("bash");
        let [first, second] = [1u64, 2].map(|epoch| key_for(1, epoch, &shared));
        assert_ne!(first, second);
    }
}