Skip to main content

rain_engine_core/
state_cache.rs

1use crate::{SessionSnapshot, StateProjectionCache};
2use async_trait::async_trait;
3use redis::AsyncCommands;
4use std::collections::HashMap;
5use std::sync::{Arc, RwLock};
6
7#[derive(Clone, Default)]
8pub struct InMemoryStateCache {
9    cache: Arc<RwLock<HashMap<String, SessionSnapshot>>>,
10}
11
12impl InMemoryStateCache {
13    pub fn new() -> Self {
14        Self::default()
15    }
16}
17
18#[async_trait]
19impl StateProjectionCache for InMemoryStateCache {
20    async fn get_projection(&self, session_id: &str) -> Result<Option<SessionSnapshot>, String> {
21        let lock = self.cache.read().map_err(|e| e.to_string())?;
22        Ok(lock.get(session_id).cloned())
23    }
24
25    async fn set_projection(
26        &self,
27        session_id: &str,
28        snapshot: SessionSnapshot,
29    ) -> Result<(), String> {
30        let mut lock = self.cache.write().map_err(|e| e.to_string())?;
31        lock.insert(session_id.to_string(), snapshot);
32        Ok(())
33    }
34
35    async fn invalidate(&self, session_id: &str) -> Result<(), String> {
36        let mut lock = self.cache.write().map_err(|e| e.to_string())?;
37        lock.remove(session_id);
38        Ok(())
39    }
40}
41
42pub struct ValkeyStateCache {
43    client: redis::Client,
44    prefix: String,
45}
46
47impl ValkeyStateCache {
48    pub fn new(url: &str, prefix: &str) -> Result<Self, String> {
49        let client = redis::Client::open(url).map_err(|e| e.to_string())?;
50        Ok(Self {
51            client,
52            prefix: prefix.to_string(),
53        })
54    }
55
56    fn key(&self, session_id: &str) -> String {
57        format!("{}:state:{}", self.prefix, session_id)
58    }
59}
60
61#[async_trait]
62impl StateProjectionCache for ValkeyStateCache {
63    async fn get_projection(&self, session_id: &str) -> Result<Option<SessionSnapshot>, String> {
64        let mut conn = self
65            .client
66            .get_multiplexed_async_connection()
67            .await
68            .map_err(|e| e.to_string())?;
69        let key = self.key(session_id);
70        let val: Option<Vec<u8>> = conn.get(&key).await.map_err(|e| e.to_string())?;
71
72        match val {
73            Some(bytes) => {
74                let snapshot = serde_json::from_slice(&bytes)
75                    .map_err(|e| format!("De-serialize error: {}", e))?;
76                Ok(Some(snapshot))
77            }
78            None => Ok(None),
79        }
80    }
81
82    async fn set_projection(
83        &self,
84        session_id: &str,
85        snapshot: SessionSnapshot,
86    ) -> Result<(), String> {
87        let mut conn = self
88            .client
89            .get_multiplexed_async_connection()
90            .await
91            .map_err(|e| e.to_string())?;
92        let key = self.key(session_id);
93        let bytes = serde_json::to_vec(&snapshot).map_err(|e| e.to_string())?;
94
95        // Expire snapshots after 24 hours to prevent cache ballooning
96        let _: () = conn
97            .set_ex(&key, bytes, 60 * 60 * 24)
98            .await
99            .map_err(|e| e.to_string())?;
100        Ok(())
101    }
102
103    async fn invalidate(&self, session_id: &str) -> Result<(), String> {
104        let mut conn = self
105            .client
106            .get_multiplexed_async_connection()
107            .await
108            .map_err(|e| e.to_string())?;
109        let key = self.key(session_id);
110        let _: () = conn.del(&key).await.map_err(|e| e.to_string())?;
111        Ok(())
112    }
113}