Skip to main content

axon/pem/
backend.rs

1//! [`PersistenceBackend`] — async trait for cognitive-state persistence.
2//!
3//! Shipped impls:
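//! - [`InMemoryBackend`]: dev + test, single-process, no background TTL
//!   eviction — expired entries are rejected on restore and removed only
//!   by an explicit `evict_expired` sweep (tests snapshot + restore within
//!   a unit test's lifetime).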
6//!
7//! Production impl lives in `axon_enterprise::cognitive_states` —
8//! Postgres + envelope-encrypted rows, worker-driven eviction that
9//! cryptoshreds envelope keys on expiry.
10
11use std::collections::HashMap;
12use std::sync::Mutex;
13
14use async_trait::async_trait;
15use chrono::{DateTime, Duration as ChronoDuration, Utc};
16
17use crate::pem::state::CognitiveState;
18
19/// Errors every backend speaks.
20#[derive(Debug)]
21pub enum PersistenceError {
22    NotFound {
23        session_id: String,
24    },
25    Expired {
26        session_id: String,
27        expired_at: DateTime<Utc>,
28    },
29    Backend(String),
30}
31
32impl std::fmt::Display for PersistenceError {
33    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
34        match self {
35            Self::NotFound { session_id } => {
36                write!(f, "cognitive state not found: {session_id:?}")
37            }
38            Self::Expired {
39                session_id,
40                expired_at,
41            } => write!(
42                f,
43                "cognitive state for {session_id:?} expired at {expired_at}"
44            ),
45            Self::Backend(m) => write!(f, "backend: {m}"),
46        }
47    }
48}
49
50impl std::error::Error for PersistenceError {}
51
52/// Minimal interface every backend implements. Adopters who need
53/// richer querying (list by tenant, filter by subject) extend in
54/// their own trait — this core surface is the contract Axon
55/// itself depends on.
56#[async_trait]
57pub trait PersistenceBackend: Send + Sync {
58    /// Persist the state under `session_id`. `ttl` is advisory; the
59    /// backend is free to honour it via a scheduled eviction job
60    /// (the Postgres impl) or a best-effort timer (in-memory).
61    async fn persist(
62        &self,
63        session_id: &str,
64        state: &CognitiveState,
65        ttl: ChronoDuration,
66    ) -> Result<(), PersistenceError>;
67
68    /// Fetch a previously persisted state. Returns
69    /// [`PersistenceError::NotFound`] when no record exists and
70    /// [`PersistenceError::Expired`] when one exists but its TTL
71    /// lapsed.
72    async fn restore(
73        &self,
74        session_id: &str,
75    ) -> Result<CognitiveState, PersistenceError>;
76
77    /// Irreversibly delete the state. Idempotent — no-op when the
78    /// session has no stored state.
79    async fn evict(
80        &self,
81        session_id: &str,
82    ) -> Result<(), PersistenceError>;
83
84    /// Evict every state whose TTL lapsed at or before `before`.
85    /// Returns the count of rows removed for observability. Called
86    /// periodically by the eviction worker in production; in-memory
87    /// impl sweeps on-demand.
88    async fn evict_expired(
89        &self,
90        before: DateTime<Utc>,
91    ) -> Result<u64, PersistenceError>;
92}
93
94// ── InMemoryBackend ─────────────────────────────────────────────────
95
96struct InMemoryEntry {
97    state: CognitiveState,
98    expires_at: DateTime<Utc>,
99}
100
101/// Single-process, Mutex-guarded. Rejects stale fetches so the
102/// same semantics as the Postgres impl hold in tests.
103#[derive(Debug, Default)]
104pub struct InMemoryBackend {
105    inner: Mutex<HashMap<String, StoredEntry>>,
106}
107
108// Private struct so callers don't reach past the trait.
109struct StoredEntry {
110    state_bytes: Vec<u8>,
111    expires_at: DateTime<Utc>,
112}
113
114impl std::fmt::Debug for StoredEntry {
115    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
116        f.debug_struct("StoredEntry")
117            .field("state_bytes_len", &self.state_bytes.len())
118            .field("expires_at", &self.expires_at)
119            .finish()
120    }
121}
122
123impl InMemoryBackend {
124    pub fn new() -> Self {
125        Default::default()
126    }
127
128    pub fn len(&self) -> usize {
129        self.inner.lock().expect("poisoned").len()
130    }
131
132    pub fn is_empty(&self) -> bool {
133        self.len() == 0
134    }
135}
136
137#[async_trait]
138impl PersistenceBackend for InMemoryBackend {
139    async fn persist(
140        &self,
141        session_id: &str,
142        state: &CognitiveState,
143        ttl: ChronoDuration,
144    ) -> Result<(), PersistenceError> {
145        let expires_at = Utc::now() + ttl;
146        let bytes = state.encode();
147        let mut guard = self.inner.lock().expect("poisoned");
148        guard.insert(
149            session_id.to_string(),
150            StoredEntry {
151                state_bytes: bytes,
152                expires_at,
153            },
154        );
155        Ok(())
156    }
157
158    async fn restore(
159        &self,
160        session_id: &str,
161    ) -> Result<CognitiveState, PersistenceError> {
162        let guard = self.inner.lock().expect("poisoned");
163        let entry = guard
164            .get(session_id)
165            .ok_or(PersistenceError::NotFound {
166                session_id: session_id.to_string(),
167            })?;
168        if entry.expires_at <= Utc::now() {
169            return Err(PersistenceError::Expired {
170                session_id: session_id.to_string(),
171                expired_at: entry.expires_at,
172            });
173        }
174        CognitiveState::decode(&entry.state_bytes).map_err(|e| {
175            PersistenceError::Backend(format!(
176                "decode failed for {session_id:?}: {e}"
177            ))
178        })
179    }
180
181    async fn evict(
182        &self,
183        session_id: &str,
184    ) -> Result<(), PersistenceError> {
185        let mut guard = self.inner.lock().expect("poisoned");
186        guard.remove(session_id);
187        Ok(())
188    }
189
190    async fn evict_expired(
191        &self,
192        before: DateTime<Utc>,
193    ) -> Result<u64, PersistenceError> {
194        let mut guard = self.inner.lock().expect("poisoned");
195        let expired: Vec<String> = guard
196            .iter()
197            .filter(|(_, e)| e.expires_at <= before)
198            .map(|(k, _)| k.clone())
199            .collect();
200        let count = expired.len() as u64;
201        for k in expired {
202            guard.remove(&k);
203        }
204        Ok(count)
205    }
206}
207
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pem::state::{CognitiveState, FixedPoint};
    use chrono::Duration;

    /// Builds the shared fixture state (session id `sess-1`).
    fn make_state() -> CognitiveState {
        let mut s = CognitiveState::new("sess-1", "alpha", "flow-1");
        s.density_matrix = vec![FixedPoint::vec_from_f64(&[0.1, 0.9])];
        s
    }

    #[tokio::test]
    async fn persist_then_restore_roundtrip() {
        let b = InMemoryBackend::new();
        let state = make_state();
        b.persist(&state.session_id, &state, Duration::minutes(15))
            .await
            .unwrap();
        let restored = b.restore(&state.session_id).await.unwrap();
        assert_eq!(restored, state);
    }

    #[tokio::test]
    async fn restore_unknown_session_returns_not_found() {
        let b = InMemoryBackend::new();
        let err = b.restore("missing").await.unwrap_err();
        // `matches!` only yields a bool — without `assert!` the check is a no-op.
        assert!(
            matches!(err, PersistenceError::NotFound { .. }),
            "expected NotFound, got {err:?}"
        );
    }

    #[tokio::test]
    async fn restore_expired_session_returns_expired() {
        let b = InMemoryBackend::new();
        let state = make_state();
        // Persist with negative TTL so the entry is already stale.
        b.persist(&state.session_id, &state, Duration::seconds(-1))
            .await
            .unwrap();
        let err = b.restore(&state.session_id).await.unwrap_err();
        assert!(
            matches!(err, PersistenceError::Expired { .. }),
            "expected Expired, got {err:?}"
        );
    }

    #[tokio::test]
    async fn evict_is_idempotent() {
        let b = InMemoryBackend::new();
        b.evict("nothing-here").await.unwrap();

        let state = make_state();
        b.persist(&state.session_id, &state, Duration::minutes(5))
            .await
            .unwrap();
        b.evict(&state.session_id).await.unwrap();
        b.evict(&state.session_id).await.unwrap();
        let err = b.restore(&state.session_id).await.unwrap_err();
        assert!(
            matches!(err, PersistenceError::NotFound { .. }),
            "expected NotFound after evict, got {err:?}"
        );
    }

    #[tokio::test]
    async fn evict_expired_removes_only_stale_rows() {
        let b = InMemoryBackend::new();
        let mut stale = make_state();
        stale.session_id = "stale".into();
        let mut fresh = make_state();
        fresh.session_id = "fresh".into();

        b.persist(&stale.session_id, &stale, Duration::seconds(-10))
            .await
            .unwrap();
        b.persist(&fresh.session_id, &fresh, Duration::minutes(15))
            .await
            .unwrap();

        let removed = b.evict_expired(Utc::now()).await.unwrap();
        assert_eq!(removed, 1);

        // Fresh still there, stale gone.
        b.restore(&fresh.session_id).await.unwrap();
        let err = b.restore(&stale.session_id).await.unwrap_err();
        assert!(
            matches!(err, PersistenceError::NotFound { .. }),
            "expected NotFound for swept entry, got {err:?}"
        );
    }

    #[tokio::test]
    async fn len_tracks_live_entries() {
        let b = InMemoryBackend::new();
        assert!(b.is_empty());
        let s = make_state();
        b.persist(&s.session_id, &s, Duration::minutes(5)).await.unwrap();
        assert_eq!(b.len(), 1);
        b.evict(&s.session_id).await.unwrap();
        assert_eq!(b.len(), 0);
    }
}