use std::collections::HashMap;
use std::sync::RwLock;
use serde::Serialize;
/// Result of deciding whether an operation may proceed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ApprovalDecision {
    /// The operation is allowed.
    Approved,
    /// The operation is allowed; presumably the approval is meant to persist for the
    /// rest of the session (NOTE(review): confirm against callers).
    ApprovedForSession,
    /// The operation is rejected.
    Denied,
}

impl ApprovalDecision {
    /// Reports whether this decision lets the operation go ahead.
    ///
    /// Both approving variants yield `true`; only [`Denied`](Self::Denied) yields `false`.
    pub fn is_approved(self) -> bool {
        match self {
            Self::Approved | Self::ApprovedForSession => true,
            Self::Denied => false,
        }
    }
}
/// Thread-safe cache of approval decisions keyed by arbitrary serializable values.
///
/// Keys are identified by their `serde_json` serialization, so any two values that
/// serialize to the same JSON text (e.g. `"x"` and `String::from("x")`, or a tuple and
/// an equal-length array of the same elements) share a single entry. A key that fails
/// to serialize is never stored and never matches.
#[derive(Debug, Default)]
pub struct ApprovalStore {
    /// Serialized key -> most recently recorded decision for that key.
    cache: RwLock<HashMap<String, ApprovalDecision>>,
}

impl ApprovalStore {
    /// Creates an empty store.
    pub fn new() -> Self {
        Self {
            cache: RwLock::new(HashMap::new()),
        }
    }

    /// Acquires the read lock, recovering the map if a previous writer panicked.
    ///
    /// No caller-supplied code runs while the write lock is held (keys are serialized
    /// before locking), and every locked write is a plain `insert`/`extend`/`clear` of
    /// `Copy` values, so the map is still coherent after poisoning. Recovering keeps the
    /// cache working instead of silently disabling it for the rest of the process.
    fn read_cache(
        &self,
    ) -> std::sync::RwLockReadGuard<'_, HashMap<String, ApprovalDecision>> {
        self.cache
            .read()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Acquires the write lock, recovering the map if a previous writer panicked
    /// (see [`Self::read_cache`] for why recovery is sound here).
    fn write_cache(
        &self,
    ) -> std::sync::RwLockWriteGuard<'_, HashMap<String, ApprovalDecision>> {
        self.cache
            .write()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Returns the decision recorded for `key`.
    ///
    /// Returns `None` if nothing is recorded or if `key` cannot be serialized.
    pub fn get<K: Serialize>(&self, key: &K) -> Option<ApprovalDecision> {
        let serialized = serde_json::to_string(key).ok()?;
        let cache = self.read_cache();
        cache.get(&serialized).copied()
    }

    /// Records `decision` for `key`, replacing any previous decision.
    ///
    /// A key that cannot be serialized is silently ignored.
    pub fn put<K: Serialize>(&self, key: K, decision: ApprovalDecision) {
        if let Ok(serialized) = serde_json::to_string(&key) {
            self.write_cache().insert(serialized, decision);
        }
    }

    /// Checks whether every key in `keys` has an approving decision recorded.
    ///
    /// Returns:
    /// - `Some(ApprovedForSession)` if all keys are approved and at least one of them
    ///   was recorded as `ApprovedForSession`;
    /// - `Some(Approved)` if all keys are recorded as `Approved`;
    /// - `None` if `keys` is empty, any key is missing, denied, or fails to serialize.
    ///
    /// An empty slice is a miss rather than a vacuous approval: the cache has nothing
    /// it can vouch for, so the caller must decide.
    pub fn check_all<K: Serialize>(&self, keys: &[K]) -> Option<ApprovalDecision> {
        if keys.is_empty() {
            return None;
        }
        let cache = self.read_cache();
        let mut combined = ApprovalDecision::Approved;
        for key in keys {
            let serialized = serde_json::to_string(key).ok()?;
            match cache.get(&serialized)? {
                ApprovalDecision::ApprovedForSession => {
                    combined = ApprovalDecision::ApprovedForSession;
                }
                ApprovalDecision::Approved => {}
                ApprovalDecision::Denied => return None,
            }
        }
        Some(combined)
    }

    /// Records the same `decision` for every key in `keys`.
    ///
    /// Keys that cannot be serialized are skipped; the rest are still recorded.
    pub fn put_all<K: Serialize>(&self, keys: &[K], decision: ApprovalDecision) {
        // Serialize before locking so caller-supplied `Serialize` impls never run (and
        // can never panic) while the write lock is held.
        let serialized: Vec<String> = keys
            .iter()
            .filter_map(|key| serde_json::to_string(key).ok())
            .collect();
        self.write_cache()
            .extend(serialized.into_iter().map(|key| (key, decision)));
    }

    /// Removes every recorded decision.
    pub fn clear(&self) {
        self.write_cache().clear();
    }

    /// Number of distinct (serialized) keys currently recorded.
    pub fn len(&self) -> usize {
        self.read_cache().len()
    }

    /// Returns `true` when no decisions are recorded.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
/// Resolves an approval for `keys`, consulting `store` before invoking `fetch`.
///
/// If every key already has an approving decision in `store` (per
/// [`ApprovalStore::check_all`]), that cached decision is returned and `fetch` is never
/// called. Otherwise `fetch` is awaited. An approving result is recorded for every key
/// so later calls can reuse it. A `Denied` result is returned without being cached, so
/// the next call asks again.
///
/// An empty `keys` list never counts as a cache hit. With nothing to look up, the cache
/// cannot vouch for the request, so `fetch` is always consulted.
pub async fn with_cached_approval<K, F, Fut>(
    store: &ApprovalStore,
    keys: Vec<K>,
    fetch: F,
) -> ApprovalDecision
where
    K: Serialize,
    F: FnOnce() -> Fut,
    Fut: std::future::Future<Output = ApprovalDecision>,
{
    let cached = if keys.is_empty() {
        None
    } else {
        store.check_all(&keys)
    };
    if let Some(cached) = cached {
        tracing::trace!("ApprovalStore: cache hit for {} key(s)", keys.len());
        return cached;
    }
    let decision = fetch().await;
    if decision.is_approved() {
        store.put_all(&keys, decision);
    }
    decision
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    #[test]
    fn test_empty_store() {
        let store = ApprovalStore::new();
        assert!(store.is_empty());
        assert_eq!(store.len(), 0);
        assert!(store.get(&"anything").is_none());
    }

    #[test]
    fn test_is_approved_per_variant() {
        assert!(ApprovalDecision::Approved.is_approved());
        assert!(ApprovalDecision::ApprovedForSession.is_approved());
        assert!(!ApprovalDecision::Denied.is_approved());
    }

    #[test]
    fn test_put_and_get() {
        let store = ApprovalStore::new();
        store.put(&("Bash", "cargo test"), ApprovalDecision::Approved);
        let decision = store.get(&("Bash", "cargo test")).unwrap();
        assert_eq!(decision, ApprovalDecision::Approved);
        assert!(decision.is_approved());
    }

    #[test]
    fn test_keys_identified_by_json() {
        // Keys match on their serde_json text, not on their Rust type.
        let store = ApprovalStore::new();
        store.put(String::from("cmd"), ApprovalDecision::Approved);
        assert_eq!(store.get(&"cmd"), Some(ApprovalDecision::Approved));
        store.put(("Bash", "ls"), ApprovalDecision::Denied);
        assert_eq!(store.get(&["Bash", "ls"]), Some(ApprovalDecision::Denied));
        assert_eq!(store.len(), 2);
    }

    #[test]
    fn test_denied_not_approved() {
        let store = ApprovalStore::new();
        store.put(&"dangerous_op", ApprovalDecision::Denied);
        let decision = store.get(&"dangerous_op").unwrap();
        assert!(!decision.is_approved());
    }

    #[test]
    fn test_different_keys_independent() {
        let store = ApprovalStore::new();
        store.put(&("Bash", "echo hi"), ApprovalDecision::Approved);
        store.put(&("Bash", "rm -rf /"), ApprovalDecision::Denied);
        assert!(store.get(&("Bash", "echo hi")).unwrap().is_approved());
        assert!(!store.get(&("Bash", "rm -rf /")).unwrap().is_approved());
        assert!(store.get(&("Bash", "other")).is_none());
    }

    #[test]
    fn test_check_all_all_approved() {
        let store = ApprovalStore::new();
        store.put(&"/tmp/a.rs", ApprovalDecision::Approved);
        store.put(&"/tmp/b.rs", ApprovalDecision::ApprovedForSession);
        let keys = vec!["/tmp/a.rs", "/tmp/b.rs"];
        let decision = store.check_all(&keys).unwrap();
        assert_eq!(decision, ApprovalDecision::ApprovedForSession);
    }

    #[test]
    fn test_check_all_plain_approved() {
        // Without any session-scoped entry the combined result stays plain `Approved`.
        let store = ApprovalStore::new();
        store.put(&"/tmp/a.rs", ApprovalDecision::Approved);
        store.put(&"/tmp/b.rs", ApprovalDecision::Approved);
        let keys = vec!["/tmp/a.rs", "/tmp/b.rs"];
        assert_eq!(store.check_all(&keys), Some(ApprovalDecision::Approved));
    }

    #[test]
    fn test_check_all_missing_key() {
        let store = ApprovalStore::new();
        store.put(&"/tmp/a.rs", ApprovalDecision::Approved);
        let keys = vec!["/tmp/a.rs", "/tmp/c.rs"];
        assert!(store.check_all(&keys).is_none());
    }

    #[test]
    fn test_check_all_denied_key() {
        let store = ApprovalStore::new();
        store.put(&"/tmp/a.rs", ApprovalDecision::Approved);
        store.put(&"/tmp/b.rs", ApprovalDecision::Denied);
        let keys = vec!["/tmp/a.rs", "/tmp/b.rs"];
        assert!(store.check_all(&keys).is_none());
    }

    #[test]
    fn test_put_all() {
        let store = ApprovalStore::new();
        let keys = vec!["/tmp/a.rs", "/tmp/b.rs", "/tmp/c.rs"];
        store.put_all(&keys, ApprovalDecision::ApprovedForSession);
        assert_eq!(store.len(), 3);
        for key in &keys {
            let d = store.get(key).unwrap();
            assert_eq!(d, ApprovalDecision::ApprovedForSession);
        }
    }

    #[test]
    fn test_clear() {
        let store = ApprovalStore::new();
        store.put(&"a", ApprovalDecision::Approved);
        store.put(&"b", ApprovalDecision::Denied);
        assert_eq!(store.len(), 2);
        store.clear();
        assert!(store.is_empty());
        assert!(store.get(&"a").is_none());
    }

    #[test]
    fn test_overwrite() {
        let store = ApprovalStore::new();
        store.put(&"cmd", ApprovalDecision::Denied);
        assert!(!store.get(&"cmd").unwrap().is_approved());
        store.put(&"cmd", ApprovalDecision::ApprovedForSession);
        assert!(store.get(&"cmd").unwrap().is_approved());
    }

    #[tokio::test]
    async fn test_with_cached_approval_cache_hit() {
        let store = ApprovalStore::new();
        store.put(&"key1", ApprovalDecision::Approved);
        let decision = with_cached_approval(&store, vec!["key1"], || async {
            panic!("should not be called — cache hit");
        })
        .await;
        assert_eq!(decision, ApprovalDecision::Approved);
    }

    #[tokio::test]
    async fn test_with_cached_approval_cache_miss() {
        let store = ApprovalStore::new();
        let decision = with_cached_approval(&store, vec!["key2"], || async {
            ApprovalDecision::ApprovedForSession
        })
        .await;
        assert_eq!(decision, ApprovalDecision::ApprovedForSession);
        assert!(store.get(&"key2").unwrap().is_approved());
    }

    #[tokio::test]
    async fn test_with_cached_approval_denied_not_cached() {
        let store = ApprovalStore::new();
        let decision =
            with_cached_approval(&store, vec!["key3"], || async { ApprovalDecision::Denied }).await;
        assert_eq!(decision, ApprovalDecision::Denied);
        assert!(store.get(&"key3").is_none());
    }

    #[tokio::test]
    async fn test_with_cached_approval_fetches_once_when_approved() {
        // The first call populates the cache; the second must be served from it.
        let store = ApprovalStore::new();
        let calls = AtomicUsize::new(0);
        let calls_ref = &calls;
        for _ in 0..2 {
            let decision = with_cached_approval(&store, vec!["key4"], move || async move {
                calls_ref.fetch_add(1, Ordering::SeqCst);
                ApprovalDecision::ApprovedForSession
            })
            .await;
            assert_eq!(decision, ApprovalDecision::ApprovedForSession);
        }
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_with_cached_approval_refetches_after_denial() {
        // Denials are not cached, so every call must consult `fetch` again.
        let store = ApprovalStore::new();
        let calls = AtomicUsize::new(0);
        let calls_ref = &calls;
        for _ in 0..2 {
            let decision = with_cached_approval(&store, vec!["key5"], move || async move {
                calls_ref.fetch_add(1, Ordering::SeqCst);
                ApprovalDecision::Denied
            })
            .await;
            assert_eq!(decision, ApprovalDecision::Denied);
        }
        assert_eq!(calls.load(Ordering::SeqCst), 2);
        assert!(store.is_empty());
    }

    #[tokio::test]
    async fn test_with_cached_approval_multi_key() {
        let store = ApprovalStore::new();
        store.put(&"file_a", ApprovalDecision::Approved);
        let decision = with_cached_approval(&store, vec!["file_a", "file_b"], || async {
            ApprovalDecision::ApprovedForSession
        })
        .await;
        assert_eq!(decision, ApprovalDecision::ApprovedForSession);
        assert!(store.get(&"file_a").unwrap().is_approved());
        assert!(store.get(&"file_b").unwrap().is_approved());
    }
}