1use std::collections::HashMap;
12use std::sync::Mutex;
13
14use async_trait::async_trait;
15use chrono::{DateTime, Duration as ChronoDuration, Utc};
16
17use crate::pem::state::CognitiveState;
18
19#[derive(Debug)]
21pub enum PersistenceError {
22 NotFound {
23 session_id: String,
24 },
25 Expired {
26 session_id: String,
27 expired_at: DateTime<Utc>,
28 },
29 Backend(String),
30}
31
32impl std::fmt::Display for PersistenceError {
33 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
34 match self {
35 Self::NotFound { session_id } => {
36 write!(f, "cognitive state not found: {session_id:?}")
37 }
38 Self::Expired {
39 session_id,
40 expired_at,
41 } => write!(
42 f,
43 "cognitive state for {session_id:?} expired at {expired_at}"
44 ),
45 Self::Backend(m) => write!(f, "backend: {m}"),
46 }
47 }
48}
49
50impl std::error::Error for PersistenceError {}
51
52#[async_trait]
57pub trait PersistenceBackend: Send + Sync {
58 async fn persist(
62 &self,
63 session_id: &str,
64 state: &CognitiveState,
65 ttl: ChronoDuration,
66 ) -> Result<(), PersistenceError>;
67
68 async fn restore(
73 &self,
74 session_id: &str,
75 ) -> Result<CognitiveState, PersistenceError>;
76
77 async fn evict(
80 &self,
81 session_id: &str,
82 ) -> Result<(), PersistenceError>;
83
84 async fn evict_expired(
89 &self,
90 before: DateTime<Utc>,
91 ) -> Result<u64, PersistenceError>;
92}
93
/// A session state paired with its expiry instant, with the state held as a value rather
/// than as encoded bytes.
//
// NOTE(review): nothing visible in this file constructs or reads `InMemoryEntry`;
// `InMemoryBackend` stores `StoredEntry` values instead. This looks like a leftover —
// confirm it is unused before removing it.
struct InMemoryEntry {
    /// The state itself (not encoded).
    state: CognitiveState,
    /// Expiry instant; presumably the same meaning as `StoredEntry::expires_at` — unverified,
    /// since no visible code reads this field.
    expires_at: DateTime<Utc>,
}
100
/// In-process [`PersistenceBackend`] that keeps every session in a `HashMap` guarded by a
/// `std::sync::Mutex`.
///
/// `Default` produces a backend with an empty map.
#[derive(Debug, Default)]
pub struct InMemoryBackend {
    /// Session id -> stored entry (encoded state plus expiry instant).
    inner: Mutex<HashMap<String, StoredEntry>>,
}
107
108struct StoredEntry {
110 state_bytes: Vec<u8>,
111 expires_at: DateTime<Utc>,
112}
113
114impl std::fmt::Debug for StoredEntry {
115 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
116 f.debug_struct("StoredEntry")
117 .field("state_bytes_len", &self.state_bytes.len())
118 .field("expires_at", &self.expires_at)
119 .finish()
120 }
121}
122
123impl InMemoryBackend {
124 pub fn new() -> Self {
125 Default::default()
126 }
127
128 pub fn len(&self) -> usize {
129 self.inner.lock().expect("poisoned").len()
130 }
131
132 pub fn is_empty(&self) -> bool {
133 self.len() == 0
134 }
135}
136
137#[async_trait]
138impl PersistenceBackend for InMemoryBackend {
139 async fn persist(
140 &self,
141 session_id: &str,
142 state: &CognitiveState,
143 ttl: ChronoDuration,
144 ) -> Result<(), PersistenceError> {
145 let expires_at = Utc::now() + ttl;
146 let bytes = state.encode();
147 let mut guard = self.inner.lock().expect("poisoned");
148 guard.insert(
149 session_id.to_string(),
150 StoredEntry {
151 state_bytes: bytes,
152 expires_at,
153 },
154 );
155 Ok(())
156 }
157
158 async fn restore(
159 &self,
160 session_id: &str,
161 ) -> Result<CognitiveState, PersistenceError> {
162 let guard = self.inner.lock().expect("poisoned");
163 let entry = guard
164 .get(session_id)
165 .ok_or(PersistenceError::NotFound {
166 session_id: session_id.to_string(),
167 })?;
168 if entry.expires_at <= Utc::now() {
169 return Err(PersistenceError::Expired {
170 session_id: session_id.to_string(),
171 expired_at: entry.expires_at,
172 });
173 }
174 CognitiveState::decode(&entry.state_bytes).map_err(|e| {
175 PersistenceError::Backend(format!(
176 "decode failed for {session_id:?}: {e}"
177 ))
178 })
179 }
180
181 async fn evict(
182 &self,
183 session_id: &str,
184 ) -> Result<(), PersistenceError> {
185 let mut guard = self.inner.lock().expect("poisoned");
186 guard.remove(session_id);
187 Ok(())
188 }
189
190 async fn evict_expired(
191 &self,
192 before: DateTime<Utc>,
193 ) -> Result<u64, PersistenceError> {
194 let mut guard = self.inner.lock().expect("poisoned");
195 let expired: Vec<String> = guard
196 .iter()
197 .filter(|(_, e)| e.expires_at <= before)
198 .map(|(k, _)| k.clone())
199 .collect();
200 let count = expired.len() as u64;
201 for k in expired {
202 guard.remove(&k);
203 }
204 Ok(count)
205 }
206}
207
#[cfg(test)]
mod tests {
    //! Behavioural tests for `InMemoryBackend`.
    //!
    //! Error-variant checks go through `assert!(matches!(..))`: a bare `matches!` only
    //! evaluates to a `bool`, and discarding it would let a wrong variant pass silently.

    use super::*;
    use crate::pem::state::{CognitiveState, FixedPoint};
    use chrono::Duration;

    /// Builds a small state keyed by session id `"sess-1"`.
    fn make_state() -> CognitiveState {
        let mut s = CognitiveState::new("sess-1", "alpha", "flow-1");
        s.density_matrix = vec![FixedPoint::vec_from_f64(&[0.1, 0.9])];
        s
    }

    #[tokio::test]
    async fn persist_then_restore_roundtrip() {
        let b = InMemoryBackend::new();
        let state = make_state();
        b.persist(&state.session_id, &state, Duration::minutes(15))
            .await
            .unwrap();
        let restored = b.restore(&state.session_id).await.unwrap();
        assert_eq!(restored, state);
    }

    #[tokio::test]
    async fn restore_unknown_session_returns_not_found() {
        let b = InMemoryBackend::new();
        let err = b.restore("missing").await.unwrap_err();
        assert!(
            matches!(err, PersistenceError::NotFound { .. }),
            "expected NotFound, got {:?}",
            err
        );
    }

    #[tokio::test]
    async fn restore_expired_session_returns_expired() {
        let b = InMemoryBackend::new();
        let state = make_state();
        // A negative TTL produces an entry that is already past its expiry.
        b.persist(&state.session_id, &state, Duration::seconds(-1))
            .await
            .unwrap();
        let err = b.restore(&state.session_id).await.unwrap_err();
        assert!(
            matches!(err, PersistenceError::Expired { .. }),
            "expected Expired, got {:?}",
            err
        );
    }

    #[tokio::test]
    async fn evict_is_idempotent() {
        let b = InMemoryBackend::new();
        b.evict("nothing-here").await.unwrap();

        let state = make_state();
        b.persist(&state.session_id, &state, Duration::minutes(5))
            .await
            .unwrap();
        b.evict(&state.session_id).await.unwrap();
        b.evict(&state.session_id).await.unwrap();
        let err = b.restore(&state.session_id).await.unwrap_err();
        assert!(
            matches!(err, PersistenceError::NotFound { .. }),
            "expected NotFound, got {:?}",
            err
        );
    }

    #[tokio::test]
    async fn evict_expired_removes_only_stale_rows() {
        let b = InMemoryBackend::new();
        let mut stale = make_state();
        stale.session_id = "stale".into();
        let mut fresh = make_state();
        fresh.session_id = "fresh".into();

        b.persist(&stale.session_id, &stale, Duration::seconds(-10))
            .await
            .unwrap();
        b.persist(&fresh.session_id, &fresh, Duration::minutes(15))
            .await
            .unwrap();

        let removed = b.evict_expired(Utc::now()).await.unwrap();
        assert_eq!(removed, 1);

        // The fresh row must still be restorable; the stale row must be gone entirely.
        b.restore(&fresh.session_id).await.unwrap();
        let err = b.restore(&stale.session_id).await.unwrap_err();
        assert!(
            matches!(err, PersistenceError::NotFound { .. }),
            "expected NotFound, got {:?}",
            err
        );
    }

    #[tokio::test]
    async fn len_tracks_live_entries() {
        let b = InMemoryBackend::new();
        assert!(b.is_empty());
        let s = make_state();
        b.persist(&s.session_id, &s, Duration::minutes(5)).await.unwrap();
        assert_eq!(b.len(), 1);
        b.evict(&s.session_id).await.unwrap();
        assert_eq!(b.len(), 0);
    }
}