Skip to main content

aa_storage_redis/
session.rs

1//! [`SessionStore`] backed by a Redis hash per session.
2
3use aa_storage::{AgentId, Result, SessionId, SessionRecord, SessionStore, StorageError};
4use async_trait::async_trait;
5use deadpool_redis::Pool;
6
7use crate::error::backend;
8use crate::util::hex16;
9
/// Lifetime, in seconds, of a stored session record.
///
/// [`save`](SessionStore::save) re-arms this expiry with Redis `EXPIRE` every
/// time it runs. A session therefore stays alive as long as it is re-saved at
/// least once per window (one hour), while an abandoned session is reclaimed
/// by Redis on its own.
pub const SESSION_TTL_SECS: u64 = 60 * 60;
16
17/// Redis-backed [`SessionStore`].
18///
19/// Each record is a hash at `aa:session:<session_id>` holding the raw
20/// `agent_id` bytes and `started_at_ns`. See the [crate](crate) docs for the
21/// full key layout and TTL semantics. Cheap to [`Clone`] — clones share the
22/// underlying [`Pool`].
23#[derive(Clone)]
24pub struct RedisSessionStore {
25    pool: Pool,
26}
27
28impl RedisSessionStore {
29    /// Create a store over an existing connection pool.
30    pub fn new(pool: Pool) -> Self {
31        Self { pool }
32    }
33}
34
35fn session_key(id: &SessionId) -> String {
36    format!("aa:session:{}", hex16(id.as_bytes()))
37}
38
39#[async_trait]
40impl SessionStore for RedisSessionStore {
41    async fn save(&self, session: SessionRecord) -> Result<()> {
42        let mut conn = self.pool.get().await.map_err(backend)?;
43        let key = session_key(&session.session_id);
44        let _: () = redis::cmd("HSET")
45            .arg(&key)
46            .arg("agent_id")
47            .arg(&session.agent_id.as_bytes()[..])
48            .arg("started_at_ns")
49            .arg(session.started_at_ns)
50            .query_async(&mut conn)
51            .await
52            .map_err(backend)?;
53        let _: () = redis::cmd("EXPIRE")
54            .arg(&key)
55            .arg(SESSION_TTL_SECS)
56            .query_async(&mut conn)
57            .await
58            .map_err(backend)?;
59        Ok(())
60    }
61
62    async fn load(&self, session_id: &SessionId) -> Result<SessionRecord> {
63        let mut conn = self.pool.get().await.map_err(backend)?;
64        let key = session_key(session_id);
65        let (agent_bytes, started_at_ns): (Option<Vec<u8>>, Option<u64>) = redis::cmd("HMGET")
66            .arg(&key)
67            .arg("agent_id")
68            .arg("started_at_ns")
69            .query_async(&mut conn)
70            .await
71            .map_err(backend)?;
72        let agent_bytes = agent_bytes.ok_or_else(|| StorageError::NotFound(format!("session {key}")))?;
73        let started_at_ns = started_at_ns.ok_or_else(|| StorageError::NotFound(format!("session {key}")))?;
74        let agent_id: [u8; 16] = agent_bytes
75            .try_into()
76            .map_err(|_| StorageError::Serialization("session agent_id is not 16 bytes".to_owned()))?;
77        Ok(SessionRecord {
78            session_id: *session_id,
79            agent_id: AgentId::from_bytes(agent_id),
80            started_at_ns,
81        })
82    }
83
84    async fn delete(&self, session_id: &SessionId) -> Result<()> {
85        let mut conn = self.pool.get().await.map_err(backend)?;
86        let _: () = redis::cmd("DEL")
87            .arg(session_key(session_id))
88            .query_async(&mut conn)
89            .await
90            .map_err(backend)?;
91        Ok(())
92    }
93}