use aa_storage::{AgentId, PolicyDocument, PolicyStore, Result, StorageError};
use async_trait::async_trait;
use deadpool_redis::Pool;
use crate::error::backend;
use crate::util::hex16;
/// Default policy-cache time-to-live, in seconds (five minutes).
///
/// NOTE(review): presumably the value callers pass as `ttl_secs` to
/// `RedisPolicyStore::cache_policy`; the use site is not visible here.
pub const DEFAULT_POLICY_CACHE_TTL_SECS: u64 = 5 * 60;
/// Policy store that keeps each agent's `PolicyDocument` in Redis as a
/// JSON string, keyed per agent.
///
/// Every Redis command issued by this store checks out a connection from the
/// wrapped deadpool pool.
#[derive(Clone)]
pub struct RedisPolicyStore {
    /// Pool that supplies the Redis connection for each command.
    pool: Pool,
}
impl RedisPolicyStore {
    /// Creates a store that issues all Redis commands through `pool`.
    pub fn new(pool: Pool) -> Self {
        Self { pool }
    }

    /// Stores `policy` as JSON under the agent's policy key with a
    /// time-to-live of `ttl_secs` seconds, replacing any existing entry.
    ///
    /// A `ttl_secs` of `0` means "expire immediately". Redis rejects
    /// `SET ... EX 0` with an "invalid expire time" error, so in that case any
    /// existing entry for the agent is deleted and nothing new is written.
    ///
    /// # Errors
    ///
    /// Returns `StorageError::Serialization` if `policy` cannot be encoded as
    /// JSON. Returns the error produced by `backend` if a pooled connection
    /// cannot be obtained or the Redis command fails.
    pub async fn cache_policy(
        &self,
        agent_id: &AgentId,
        policy: &PolicyDocument,
        ttl_secs: u64,
    ) -> Result<()> {
        // Encode first so a serialization failure never ties up a pooled connection.
        let payload =
            serde_json::to_string(policy).map_err(|e| StorageError::Serialization(e.to_string()))?;
        let key = policy_key(agent_id);
        let mut conn = self.pool.get().await.map_err(backend)?;

        if ttl_secs == 0 {
            // A zero-lifetime entry must not be readable. Drop whatever is
            // cached rather than sending the `EX 0` that Redis would reject.
            let _: () = redis::cmd("DEL")
                .arg(key)
                .query_async(&mut conn)
                .await
                .map_err(backend)?;
            return Ok(());
        }

        let _: () = redis::cmd("SET")
            .arg(key)
            .arg(payload)
            .arg("EX")
            .arg(ttl_secs)
            .query_async(&mut conn)
            .await
            .map_err(backend)?;
        Ok(())
    }
}
/// Builds the Redis key for an agent's cached policy.
///
/// The key is the literal prefix `aa:policy:` followed by `hex16` applied to
/// the agent id's bytes.
fn policy_key(agent_id: &AgentId) -> String {
    let encoded_id = hex16(agent_id.as_bytes());
    format!("aa:policy:{}", encoded_id)
}
#[async_trait]
impl PolicyStore for RedisPolicyStore {
    /// Fetches the cached policy for `agent_id` with `GET` and decodes it
    /// from JSON.
    ///
    /// # Errors
    ///
    /// - `StorageError::NotFound` when Redis has no value under the agent's key.
    /// - `StorageError::Serialization` when the stored string does not decode
    ///   into a `PolicyDocument`.
    /// - The error produced by `backend` when a pooled connection cannot be
    ///   obtained or the `GET` fails.
    async fn get_policy(&self, agent_id: &AgentId) -> Result<PolicyDocument> {
        let key = policy_key(agent_id);
        let mut conn = self.pool.get().await.map_err(backend)?;
        let cached: Option<String> = redis::cmd("GET")
            .arg(key)
            .query_async(&mut conn)
            .await
            .map_err(backend)?;

        match cached {
            Some(json) => {
                serde_json::from_str(&json).map_err(|err| StorageError::Serialization(err.to_string()))
            }
            None => Err(StorageError::NotFound(format!(
                "policy for agent {}",
                hex16(agent_id.as_bytes())
            ))),
        }
    }

    /// Removes the cached policy for `agent_id` with `DEL`.
    ///
    /// The reply (the number of keys removed) is discarded, so this succeeds
    /// whether or not an entry existed.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `backend` when a pooled connection cannot
    /// be obtained or the `DEL` fails.
    async fn invalidate(&self, agent_id: &AgentId) -> Result<()> {
        let key = policy_key(agent_id);
        let mut conn = self.pool.get().await.map_err(backend)?;
        let _: () = redis::cmd("DEL")
            .arg(key)
            .query_async(&mut conn)
            .await
            .map_err(backend)?;
        Ok(())
    }
}