use std::collections::HashMap;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::Duration;
use aa_core::policy::EnforcementMode;
use aa_core::storage::{AgentId, PolicyDocument, PolicyStore, Result, StorageError};
use async_trait::async_trait;
/// Builds a minimal [`PolicyDocument`] fixture named `"sample"` at `version`.
///
/// The returned document has an empty rule list and the default
/// [`EnforcementMode`], which makes it a convenient placeholder in tests.
#[must_use]
pub fn sample_policy(version: u32) -> PolicyDocument {
    let name = String::from("sample");
    let rules = Vec::new();
    let enforcement_mode = EnforcementMode::default();
    PolicyDocument {
        version,
        name,
        rules,
        enforcement_mode,
    }
}
/// In-memory [`PolicyStore`] test double.
///
/// Policies are keyed by the raw bytes of an [`AgentId`]. The store records how
/// many lookups it has served and can optionally sleep before each lookup to
/// simulate a slow backend.
#[derive(Default)]
pub struct MemoryPolicyStore {
    /// Optional artificial latency awaited at the start of every lookup.
    delay: Option<Duration>,
    /// Number of lookups served so far; read back through `call_count`.
    calls: AtomicUsize,
    /// Stored documents, keyed by `*AgentId::as_bytes()`.
    policies: HashMap<[u8; 16], PolicyDocument>,
}
impl MemoryPolicyStore {
    /// Creates an empty store: no policies, no delay, and a call count of zero.
    #[must_use]
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Creates a store pre-populated with `policy` for `agent_id`.
    #[must_use]
    pub fn with_policy(agent_id: AgentId, policy: PolicyDocument) -> Self {
        let mut seeded = Self::default();
        seeded.insert(agent_id, policy);
        seeded
    }

    /// Stores `policy` for `agent_id`, silently replacing any previous entry.
    pub fn insert(&mut self, agent_id: AgentId, policy: PolicyDocument) {
        let key = *agent_id.as_bytes();
        self.policies.insert(key, policy);
    }

    /// Builder-style setter that makes every lookup sleep for `delay` first.
    #[must_use]
    pub fn with_delay(self, delay: Duration) -> Self {
        Self {
            delay: Some(delay),
            ..self
        }
    }

    /// Returns how many times `get_policy` has been invoked on this store.
    #[must_use]
    pub fn call_count(&self) -> usize {
        let counter = &self.calls;
        counter.load(Ordering::SeqCst)
    }
}
#[async_trait]
impl PolicyStore for MemoryPolicyStore {
    /// Returns a clone of the policy stored for `agent_id`.
    ///
    /// The call counter is bumped before the optional delay is awaited, so a
    /// call that is still sleeping is already included in `call_count`.
    ///
    /// # Errors
    ///
    /// Returns [`StorageError::NotFound`] carrying the `Debug`-formatted id
    /// bytes when nothing was inserted for `agent_id`.
    async fn get_policy(&self, agent_id: &AgentId) -> Result<PolicyDocument> {
        self.calls.fetch_add(1, Ordering::SeqCst);
        if let Some(pause) = self.delay {
            tokio::time::sleep(pause).await;
        }
        let key = agent_id.as_bytes();
        let Some(policy) = self.policies.get(key) else {
            return Err(StorageError::NotFound(format!("{key:?}")));
        };
        Ok(policy.clone())
    }

    /// Does nothing and always succeeds; stored policies are left untouched
    /// and the call counter is not changed.
    async fn invalidate(&self, _agent_id: &AgentId) -> Result<()> {
        Ok(())
    }
}