reinhardt_auth/sessions/backends/
cache.rs1use async_trait::async_trait;
33use chrono::{DateTime, Utc};
34use reinhardt_utils::cache::{Cache, InMemoryCache};
35use serde::{Deserialize, Serialize};
36use std::sync::Arc;
37use thiserror::Error;
38
39use crate::sessions::cleanup::{CleanupableBackend, SessionMetadata};
40
41#[non_exhaustive]
43#[derive(Debug, Clone, PartialEq, Eq, Error)]
44pub enum SessionError {
45 #[error("Cache error: {0}")]
47 CacheError(String),
48 #[error("Serialization error: {0}")]
50 SerializationError(String),
51 #[error("Session has expired due to inactivity")]
53 SessionExpired,
54}
55
56#[async_trait]
58pub trait SessionBackend: Send + Sync + Clone {
59 async fn load<T>(&self, session_key: &str) -> Result<Option<T>, SessionError>
61 where
62 T: for<'de> Deserialize<'de> + Serialize + Send + Sync;
63
64 async fn save<T>(
66 &self,
67 session_key: &str,
68 data: &T,
69 ttl: Option<u64>,
70 ) -> Result<(), SessionError>
71 where
72 T: Serialize + Send + Sync;
73
74 async fn delete(&self, session_key: &str) -> Result<(), SessionError>;
76
77 async fn exists(&self, session_key: &str) -> Result<bool, SessionError>;
79}
80
81#[derive(Clone)]
109pub struct InMemorySessionBackend {
110 cache: Arc<InMemoryCache>,
111}
112
113impl InMemorySessionBackend {
114 pub fn new() -> Self {
116 Self {
117 cache: Arc::new(InMemoryCache::new()),
118 }
119 }
120}
121
122impl Default for InMemorySessionBackend {
123 fn default() -> Self {
124 Self::new()
125 }
126}
127
128#[async_trait]
129impl SessionBackend for InMemorySessionBackend {
130 async fn load<T>(&self, session_key: &str) -> Result<Option<T>, SessionError>
131 where
132 T: for<'de> Deserialize<'de> + Serialize + Send + Sync,
133 {
134 self.cache
135 .get(session_key)
136 .await
137 .map_err(|e| SessionError::CacheError(e.to_string()))
138 }
139
140 async fn save<T>(
141 &self,
142 session_key: &str,
143 data: &T,
144 ttl: Option<u64>,
145 ) -> Result<(), SessionError>
146 where
147 T: Serialize + Send + Sync,
148 {
149 let duration = ttl.map(std::time::Duration::from_secs);
150 self.cache
151 .set(session_key, data, duration)
152 .await
153 .map_err(|e| SessionError::CacheError(e.to_string()))
154 }
155
156 async fn delete(&self, session_key: &str) -> Result<(), SessionError> {
157 self.cache
158 .delete(session_key)
159 .await
160 .map_err(|e| SessionError::CacheError(e.to_string()))
161 }
162
163 async fn exists(&self, session_key: &str) -> Result<bool, SessionError> {
164 self.cache
165 .has_key(session_key)
166 .await
167 .map_err(|e| SessionError::CacheError(e.to_string()))
168 }
169}
170
171#[async_trait]
172impl CleanupableBackend for InMemorySessionBackend {
173 async fn get_all_keys(&self) -> Result<Vec<String>, SessionError> {
178 Ok(self.cache.list_keys().await)
179 }
180
181 async fn get_metadata(
186 &self,
187 session_key: &str,
188 ) -> Result<Option<SessionMetadata>, SessionError> {
189 match self.cache.inspect_entry_with_timestamps(session_key).await {
190 Ok(Some((created, accessed))) => Ok(Some(SessionMetadata {
191 created_at: DateTime::<Utc>::from(created),
192 last_accessed: accessed.map(DateTime::<Utc>::from),
193 })),
194 Ok(None) => Ok(None),
195 Err(e) => Err(SessionError::CacheError(e.to_string())),
196 }
197 }
198}
199
200#[derive(Clone)]
232pub struct CacheSessionBackend<C: Cache + Clone> {
233 cache: Arc<C>,
234}
235
236impl<C: Cache + Clone> CacheSessionBackend<C> {
237 pub fn new(cache: Arc<C>) -> Self {
239 Self { cache }
240 }
241}
242
243#[async_trait]
244impl<C: Cache + Clone + 'static> SessionBackend for CacheSessionBackend<C> {
245 async fn load<T>(&self, session_key: &str) -> Result<Option<T>, SessionError>
246 where
247 T: for<'de> Deserialize<'de> + Serialize + Send + Sync,
248 {
249 self.cache
250 .get(session_key)
251 .await
252 .map_err(|e| SessionError::CacheError(e.to_string()))
253 }
254
255 async fn save<T>(
256 &self,
257 session_key: &str,
258 data: &T,
259 ttl: Option<u64>,
260 ) -> Result<(), SessionError>
261 where
262 T: Serialize + Send + Sync,
263 {
264 let duration = ttl.map(std::time::Duration::from_secs);
265 self.cache
266 .set(session_key, data, duration)
267 .await
268 .map_err(|e| SessionError::CacheError(e.to_string()))
269 }
270
271 async fn delete(&self, session_key: &str) -> Result<(), SessionError> {
272 self.cache
273 .delete(session_key)
274 .await
275 .map_err(|e| SessionError::CacheError(e.to_string()))
276 }
277
278 async fn exists(&self, session_key: &str) -> Result<bool, SessionError> {
279 self.cache
280 .has_key(session_key)
281 .await
282 .map_err(|e| SessionError::CacheError(e.to_string()))
283 }
284}
285
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;
    use serde_json::json;
    use std::collections::HashMap;

    /// Loosely-typed session payload used throughout these tests.
    type JsonValue = serde_json::Value;

    #[rstest]
    #[tokio::test]
    async fn test_in_memory_save_and_load_roundtrip() {
        // Arrange
        let backend = InMemorySessionBackend::new();
        let payload: HashMap<String, JsonValue> = HashMap::from([
            ("user_id".to_string(), json!(42)),
            ("username".to_string(), json!("alice")),
        ]);

        // Act
        backend.save("sess_1", &payload, Some(3600)).await.unwrap();
        let restored: Option<HashMap<String, JsonValue>> = backend.load("sess_1").await.unwrap();

        // Assert
        let restored = restored.unwrap();
        assert_eq!(restored["user_id"], json!(42));
        assert_eq!(restored["username"], json!("alice"));
    }

    #[rstest]
    #[tokio::test]
    async fn test_in_memory_load_nonexistent_key() {
        // Arrange
        let backend = InMemorySessionBackend::new();

        // Act
        let missing: Option<JsonValue> = backend.load("nonexistent").await.unwrap();

        // Assert
        assert!(missing.is_none());
    }

    #[rstest]
    #[tokio::test]
    async fn test_in_memory_delete_removes_session() {
        // Arrange
        let backend = InMemorySessionBackend::new();
        let payload = json!({"key": "value"});
        backend.save("sess_del", &payload, Some(3600)).await.unwrap();

        // Act
        backend.delete("sess_del").await.unwrap();
        let after_delete: Option<JsonValue> = backend.load("sess_del").await.unwrap();

        // Assert
        assert!(after_delete.is_none());
    }

    #[rstest]
    #[tokio::test]
    async fn test_in_memory_exists_reflects_state() {
        let backend = InMemorySessionBackend::new();
        let key = "sess_ex";
        let payload = json!({"active": true});

        // Absent before the first save.
        assert!(!backend.exists(key).await.unwrap());

        // Present once saved.
        backend.save(key, &payload, Some(3600)).await.unwrap();
        assert!(backend.exists(key).await.unwrap());

        // Absent again after deletion.
        backend.delete(key).await.unwrap();
        assert!(!backend.exists(key).await.unwrap());
    }

    #[rstest]
    #[tokio::test]
    async fn test_in_memory_save_overwrites_existing() {
        let backend = InMemorySessionBackend::new();

        // Two successive writes to the same key; the second must win.
        for version in [1, 2] {
            let payload = json!({ "version": version });
            backend.save("sess_ow", &payload, Some(3600)).await.unwrap();
        }
        let latest: Option<JsonValue> = backend.load("sess_ow").await.unwrap();

        assert_eq!(latest.unwrap()["version"], 2);
    }

    #[rstest]
    #[tokio::test]
    async fn test_in_memory_save_with_ttl() {
        // Arrange
        let backend = InMemorySessionBackend::new();
        let payload = json!({"ttl_test": true});

        // Act
        backend.save("sess_ttl", &payload, Some(60)).await.unwrap();
        let stored: Option<JsonValue> = backend.load("sess_ttl").await.unwrap();

        // Assert
        assert_eq!(stored.unwrap()["ttl_test"], true);
    }

    #[rstest]
    #[tokio::test]
    async fn test_cache_backend_wrapper_save_and_load() {
        // Arrange
        let backend = CacheSessionBackend::new(Arc::new(InMemoryCache::new()));
        let payload = json!({"wrapped": "value", "count": 99});

        // Act
        backend
            .save("wrapped_sess", &payload, Some(3600))
            .await
            .unwrap();
        let restored: Option<JsonValue> = backend.load("wrapped_sess").await.unwrap();

        // Assert
        let restored = restored.unwrap();
        assert_eq!(restored["wrapped"], "value");
        assert_eq!(restored["count"], 99);
    }

    #[rstest]
    #[tokio::test]
    async fn test_cache_backend_wrapper_delete_and_exists() {
        // Arrange
        let backend = CacheSessionBackend::new(Arc::new(InMemoryCache::new()));
        let payload = json!({"item": "to_delete"});
        backend.save("wrap_del", &payload, Some(3600)).await.unwrap();
        assert!(backend.exists("wrap_del").await.unwrap());

        // Act
        backend.delete("wrap_del").await.unwrap();

        // Assert
        assert!(!backend.exists("wrap_del").await.unwrap());
        let gone: Option<JsonValue> = backend.load("wrap_del").await.unwrap();
        assert!(gone.is_none());
    }
}