halldyll_memory_model/storage/
redis.rs

1//! Redis cache layer
2
3use crate::core::{MemoryError, MemoryResult};
4use redis::aio::ConnectionManager;
5use redis::AsyncCommands;
6use serde::{de::DeserializeOwned, Serialize};
7
/// Lifetime of generic cached values: one hour.
const DEFAULT_TTL_SECONDS: u64 = 60 * 60;

/// Lifetime of cached embedding vectors: one day.
const EMBEDDING_TTL_SECONDS: u64 = 24 * 60 * 60;

/// Lifetime of session entries: thirty minutes.
const SESSION_TTL_SECONDS: u64 = 30 * 60;
16
17/// Redis caching operations
18pub struct RedisCache {
19    client: ConnectionManager,
20    key_prefix: String,
21}
22
23impl RedisCache {
24    /// Create new Redis cache
25    pub fn new(client: ConnectionManager) -> Self {
26        Self {
27            client,
28            key_prefix: "halldyll:".to_string(),
29        }
30    }
31
32    /// Create new Redis cache with custom prefix
33    pub fn with_prefix(client: ConnectionManager, prefix: impl Into<String>) -> Self {
34        Self {
35            client,
36            key_prefix: prefix.into(),
37        }
38    }
39
40    /// Build a full key with prefix
41    fn build_key(&self, key: &str) -> String {
42        format!("{}{}", self.key_prefix, key)
43    }
44
45    /// Cache an embedding vector
46    pub async fn cache_embedding(&self, key: &str, embedding: &[f32]) -> MemoryResult<()> {
47        let full_key = self.build_key(&format!("emb:{}", key));
48        let json = serde_json::to_string(embedding)
49            .map_err(|e| MemoryError::Serialization(e))?;
50
51        let mut conn = self.client.clone();
52        redis::cmd("SETEX")
53            .arg(&full_key)
54            .arg(EMBEDDING_TTL_SECONDS)
55            .arg(&json)
56            .query_async::<()>(&mut conn)
57            .await
58            .map_err(|e| MemoryError::Redis(e))?;
59
60        Ok(())
61    }
62
63    /// Get cached embedding
64    pub async fn get_embedding(&self, key: &str) -> MemoryResult<Option<Vec<f32>>> {
65        let full_key = self.build_key(&format!("emb:{}", key));
66        let mut conn = self.client.clone();
67
68        let result: Option<String> = conn.get(&full_key).await
69            .map_err(|e| MemoryError::Redis(e))?;
70
71        match result {
72            Some(json) => {
73                let embedding: Vec<f32> = serde_json::from_str(&json)
74                    .map_err(|e| MemoryError::Serialization(e))?;
75                Ok(Some(embedding))
76            }
77            None => Ok(None),
78        }
79    }
80
81    /// Cache a serializable value with default TTL
82    pub async fn set<T: Serialize>(&self, key: &str, value: &T) -> MemoryResult<()> {
83        self.set_with_ttl(key, value, DEFAULT_TTL_SECONDS).await
84    }
85
86    /// Cache a serializable value with custom TTL
87    pub async fn set_with_ttl<T: Serialize>(
88        &self,
89        key: &str,
90        value: &T,
91        ttl_seconds: u64,
92    ) -> MemoryResult<()> {
93        let full_key = self.build_key(key);
94        let json = serde_json::to_string(value)
95            .map_err(|e| MemoryError::Serialization(e))?;
96
97        let mut conn = self.client.clone();
98        redis::cmd("SETEX")
99            .arg(&full_key)
100            .arg(ttl_seconds)
101            .arg(&json)
102            .query_async::<()>(&mut conn)
103            .await
104            .map_err(|e| MemoryError::Redis(e))?;
105
106        Ok(())
107    }
108
109    /// Get a cached value
110    pub async fn get<T: DeserializeOwned>(&self, key: &str) -> MemoryResult<Option<T>> {
111        let full_key = self.build_key(key);
112        let mut conn = self.client.clone();
113
114        let result: Option<String> = conn.get(&full_key).await
115            .map_err(|e| MemoryError::Redis(e))?;
116
117        match result {
118            Some(json) => {
119                let value: T = serde_json::from_str(&json)
120                    .map_err(|e| MemoryError::Serialization(e))?;
121                Ok(Some(value))
122            }
123            None => Ok(None),
124        }
125    }
126
127    /// Invalidate cache key
128    pub async fn invalidate(&self, key: &str) -> MemoryResult<()> {
129        let full_key = self.build_key(key);
130        let mut conn = self.client.clone();
131
132        conn.del::<_, ()>(&full_key).await
133            .map_err(|e| MemoryError::Redis(e))?;
134
135        Ok(())
136    }
137
138    /// Invalidate all keys matching a pattern
139    pub async fn invalidate_pattern(&self, pattern: &str) -> MemoryResult<u64> {
140        let full_pattern = self.build_key(pattern);
141        let mut conn = self.client.clone();
142
143        // Get matching keys
144        let keys: Vec<String> = redis::cmd("KEYS")
145            .arg(&full_pattern)
146            .query_async(&mut conn)
147            .await
148            .map_err(|e| MemoryError::Redis(e))?;
149
150        if keys.is_empty() {
151            return Ok(0);
152        }
153
154        // Delete all matching keys
155        let count: u64 = conn.del(&keys).await
156            .map_err(|e| MemoryError::Redis(e))?;
157
158        Ok(count)
159    }
160
161    /// Check if a key exists
162    pub async fn exists(&self, key: &str) -> MemoryResult<bool> {
163        let full_key = self.build_key(key);
164        let mut conn = self.client.clone();
165
166        let exists: bool = conn.exists(&full_key).await
167            .map_err(|e| MemoryError::Redis(e))?;
168
169        Ok(exists)
170    }
171
172    /// Set expiration on an existing key
173    pub async fn expire(&self, key: &str, ttl_seconds: u64) -> MemoryResult<bool> {
174        let full_key = self.build_key(key);
175        let mut conn = self.client.clone();
176
177        let result: bool = conn.expire(&full_key, ttl_seconds as i64).await
178            .map_err(|e| MemoryError::Redis(e))?;
179
180        Ok(result)
181    }
182
183    /// Get TTL of a key
184    pub async fn ttl(&self, key: &str) -> MemoryResult<i64> {
185        let full_key = self.build_key(key);
186        let mut conn = self.client.clone();
187
188        let ttl: i64 = conn.ttl(&full_key).await
189            .map_err(|e| MemoryError::Redis(e))?;
190
191        Ok(ttl)
192    }
193
194    /// Increment a counter
195    pub async fn incr(&self, key: &str) -> MemoryResult<i64> {
196        let full_key = self.build_key(key);
197        let mut conn = self.client.clone();
198
199        let value: i64 = conn.incr(&full_key, 1).await
200            .map_err(|e| MemoryError::Redis(e))?;
201
202        Ok(value)
203    }
204
205    /// Store session data
206    pub async fn set_session<T: Serialize>(&self, session_id: &str, data: &T) -> MemoryResult<()> {
207        let key = format!("session:{}", session_id);
208        self.set_with_ttl(&key, data, SESSION_TTL_SECONDS).await
209    }
210
211    /// Get session data
212    pub async fn get_session<T: DeserializeOwned>(&self, session_id: &str) -> MemoryResult<Option<T>> {
213        let key = format!("session:{}", session_id);
214        self.get(&key).await
215    }
216
217    /// Extend session TTL
218    pub async fn extend_session(&self, session_id: &str) -> MemoryResult<bool> {
219        let key = format!("session:{}", session_id);
220        self.expire(&key, SESSION_TTL_SECONDS).await
221    }
222
223    /// Delete session
224    pub async fn delete_session(&self, session_id: &str) -> MemoryResult<()> {
225        let key = format!("session:{}", session_id);
226        self.invalidate(&key).await
227    }
228
229    /// Cache user memory list
230    pub async fn cache_user_memories(
231        &self,
232        user_id: &str,
233        memory_type: &str,
234        memory_ids: &[String],
235    ) -> MemoryResult<()> {
236        let key = format!("user:{}:memories:{}", user_id, memory_type);
237        // Convert slice to Vec for serialization
238        let ids_vec: Vec<String> = memory_ids.to_vec();
239        self.set(&key, &ids_vec).await
240    }
241
242    /// Get cached user memory list
243    pub async fn get_user_memories(
244        &self,
245        user_id: &str,
246        memory_type: &str,
247    ) -> MemoryResult<Option<Vec<String>>> {
248        let key = format!("user:{}:memories:{}", user_id, memory_type);
249        self.get(&key).await
250    }
251
252    /// Invalidate user memory cache
253    pub async fn invalidate_user_cache(&self, user_id: &str) -> MemoryResult<u64> {
254        let pattern = format!("user:{}:*", user_id);
255        self.invalidate_pattern(&pattern).await
256    }
257
258    /// Ping Redis to check connectivity
259    pub async fn ping(&self) -> MemoryResult<bool> {
260        let mut conn = self.client.clone();
261
262        let result: String = redis::cmd("PING")
263            .query_async(&mut conn)
264            .await
265            .map_err(|e| MemoryError::Redis(e))?;
266
267        Ok(result == "PONG")
268    }
269
270    /// Get Redis info
271    pub async fn info(&self) -> MemoryResult<String> {
272        let mut conn = self.client.clone();
273
274        let info: String = redis::cmd("INFO")
275            .query_async(&mut conn)
276            .await
277            .map_err(|e| MemoryError::Redis(e))?;
278
279        Ok(info)
280    }
281}
282
#[cfg(test)]
mod tests {
    // These tests only check the key-naming convention used by `RedisCache`
    // (prefix prepended verbatim to a relative key) by rebuilding the strings
    // locally. They do not need a Redis server and are not ignored. Any future
    // test that talks to a live Redis instance should be marked `#[ignore]` so it
    // is skipped by default.

    /// Namespace prefix applied by `RedisCache::new`.
    const PREFIX: &str = "halldyll:";

    #[test]
    fn test_build_key() {
        // The prefix is prepended with no separator.
        let full_key = format!("{}{}", PREFIX, "test");
        assert_eq!(full_key, "halldyll:test");
    }

    #[test]
    fn test_embedding_key_format() {
        // Embeddings live under the `emb:` sub-namespace.
        let full_key = format!("{}emb:{}", PREFIX, "my_text");
        assert_eq!(full_key, "halldyll:emb:my_text");
    }
}
302        assert_eq!(full_key, "halldyll:emb:my_text");
303    }
304}