Skip to main content

aa_storage_redis/
policy.rs

1//! [`PolicyStore`] read-through cache backed by Redis JSON values.
2
3use aa_storage::{AgentId, PolicyDocument, PolicyStore, Result, StorageError};
4use async_trait::async_trait;
5use deadpool_redis::Pool;
6
7use crate::error::backend;
8use crate::util::hex16;
9
/// Default lifetime of a cached policy entry, in seconds (five minutes).
///
/// Intended as the `ttl_secs` argument to [`RedisPolicyStore::cache_policy`]
/// for callers that have no policy-specific TTL of their own.
pub const DEFAULT_POLICY_CACHE_TTL_SECS: u64 = 5 * 60;
15
/// Read-through [`PolicyStore`] cache that keeps policies in Redis.
///
/// Each policy is stored as a JSON-serialized [`PolicyDocument`] under the key
/// `aa:policy:<id>`, where `<id>` is the `hex16` rendering of the agent id's
/// bytes.
///
/// * [`get_policy`](PolicyStore::get_policy) reads and deserializes that key,
///   returning [`NotFound`](aa_storage::StorageError::NotFound) on a cache
///   miss. Callers are expected to fall through to the authoritative store
///   and then repopulate the cache with [`cache_policy`](Self::cache_policy).
/// * [`invalidate`](PolicyStore::invalidate) deletes the cached key.
///
/// Cheap to [`Clone`] — clones share the underlying [`Pool`].
#[derive(Clone)]
pub struct RedisPolicyStore {
    // Connection pool; every operation checks out one connection from it.
    pool: Pool,
}
29
30impl RedisPolicyStore {
31    /// Create a store over an existing connection pool.
32    pub fn new(pool: Pool) -> Self {
33        Self { pool }
34    }
35
36    /// Populate the cache for `agent_id` with `policy`, expiring after
37    /// `ttl_secs` seconds (`SET ... EX`).
38    ///
39    /// This is the write half of the read-through cache: callers invoke it
40    /// after loading a policy from the authoritative store on a
41    /// [`get_policy`](PolicyStore::get_policy) miss. See
42    /// [`DEFAULT_POLICY_CACHE_TTL_SECS`] for the suggested default TTL.
43    pub async fn cache_policy(&self, agent_id: &AgentId, policy: &PolicyDocument, ttl_secs: u64) -> Result<()> {
44        let mut conn = self.pool.get().await.map_err(backend)?;
45        let payload = serde_json::to_string(policy).map_err(|e| StorageError::Serialization(e.to_string()))?;
46        let _: () = redis::cmd("SET")
47            .arg(policy_key(agent_id))
48            .arg(payload)
49            .arg("EX")
50            .arg(ttl_secs)
51            .query_async(&mut conn)
52            .await
53            .map_err(backend)?;
54        Ok(())
55    }
56}
57
58fn policy_key(agent_id: &AgentId) -> String {
59    format!("aa:policy:{}", hex16(agent_id.as_bytes()))
60}
61
62#[async_trait]
63impl PolicyStore for RedisPolicyStore {
64    async fn get_policy(&self, agent_id: &AgentId) -> Result<PolicyDocument> {
65        let mut conn = self.pool.get().await.map_err(backend)?;
66        let raw: Option<String> = redis::cmd("GET")
67            .arg(policy_key(agent_id))
68            .query_async(&mut conn)
69            .await
70            .map_err(backend)?;
71        let raw =
72            raw.ok_or_else(|| StorageError::NotFound(format!("policy for agent {}", hex16(agent_id.as_bytes()))))?;
73        serde_json::from_str(&raw).map_err(|e| StorageError::Serialization(e.to_string()))
74    }
75
76    async fn invalidate(&self, agent_id: &AgentId) -> Result<()> {
77        let mut conn = self.pool.get().await.map_err(backend)?;
78        let _: () = redis::cmd("DEL")
79            .arg(policy_key(agent_id))
80            .query_async(&mut conn)
81            .await
82            .map_err(backend)?;
83        Ok(())
84    }
85}