halldyll_memory_model/storage/
redis.rs1use crate::core::{MemoryError, MemoryResult};
4use redis::aio::ConnectionManager;
5use redis::AsyncCommands;
6use serde::{de::DeserializeOwned, Serialize};
7
/// Expiry (seconds) used by `RedisCache::set` for generic values: one hour.
const DEFAULT_TTL_SECONDS: u64 = 60 * 60;

/// Expiry (seconds) for cached embedding vectors: twenty-four hours.
const EMBEDDING_TTL_SECONDS: u64 = 24 * 60 * 60;

/// Expiry (seconds) for session entries, also the value `extend_session` resets to:
/// thirty minutes.
const SESSION_TTL_SECONDS: u64 = 30 * 60;
16
17pub struct RedisCache {
19 client: ConnectionManager,
20 key_prefix: String,
21}
22
23impl RedisCache {
24 pub fn new(client: ConnectionManager) -> Self {
26 Self {
27 client,
28 key_prefix: "halldyll:".to_string(),
29 }
30 }
31
32 pub fn with_prefix(client: ConnectionManager, prefix: impl Into<String>) -> Self {
34 Self {
35 client,
36 key_prefix: prefix.into(),
37 }
38 }
39
40 fn build_key(&self, key: &str) -> String {
42 format!("{}{}", self.key_prefix, key)
43 }
44
45 pub async fn cache_embedding(&self, key: &str, embedding: &[f32]) -> MemoryResult<()> {
47 let full_key = self.build_key(&format!("emb:{}", key));
48 let json = serde_json::to_string(embedding)
49 .map_err(|e| MemoryError::Serialization(e))?;
50
51 let mut conn = self.client.clone();
52 redis::cmd("SETEX")
53 .arg(&full_key)
54 .arg(EMBEDDING_TTL_SECONDS)
55 .arg(&json)
56 .query_async::<()>(&mut conn)
57 .await
58 .map_err(|e| MemoryError::Redis(e))?;
59
60 Ok(())
61 }
62
63 pub async fn get_embedding(&self, key: &str) -> MemoryResult<Option<Vec<f32>>> {
65 let full_key = self.build_key(&format!("emb:{}", key));
66 let mut conn = self.client.clone();
67
68 let result: Option<String> = conn.get(&full_key).await
69 .map_err(|e| MemoryError::Redis(e))?;
70
71 match result {
72 Some(json) => {
73 let embedding: Vec<f32> = serde_json::from_str(&json)
74 .map_err(|e| MemoryError::Serialization(e))?;
75 Ok(Some(embedding))
76 }
77 None => Ok(None),
78 }
79 }
80
81 pub async fn set<T: Serialize>(&self, key: &str, value: &T) -> MemoryResult<()> {
83 self.set_with_ttl(key, value, DEFAULT_TTL_SECONDS).await
84 }
85
86 pub async fn set_with_ttl<T: Serialize>(
88 &self,
89 key: &str,
90 value: &T,
91 ttl_seconds: u64,
92 ) -> MemoryResult<()> {
93 let full_key = self.build_key(key);
94 let json = serde_json::to_string(value)
95 .map_err(|e| MemoryError::Serialization(e))?;
96
97 let mut conn = self.client.clone();
98 redis::cmd("SETEX")
99 .arg(&full_key)
100 .arg(ttl_seconds)
101 .arg(&json)
102 .query_async::<()>(&mut conn)
103 .await
104 .map_err(|e| MemoryError::Redis(e))?;
105
106 Ok(())
107 }
108
109 pub async fn get<T: DeserializeOwned>(&self, key: &str) -> MemoryResult<Option<T>> {
111 let full_key = self.build_key(key);
112 let mut conn = self.client.clone();
113
114 let result: Option<String> = conn.get(&full_key).await
115 .map_err(|e| MemoryError::Redis(e))?;
116
117 match result {
118 Some(json) => {
119 let value: T = serde_json::from_str(&json)
120 .map_err(|e| MemoryError::Serialization(e))?;
121 Ok(Some(value))
122 }
123 None => Ok(None),
124 }
125 }
126
127 pub async fn invalidate(&self, key: &str) -> MemoryResult<()> {
129 let full_key = self.build_key(key);
130 let mut conn = self.client.clone();
131
132 conn.del::<_, ()>(&full_key).await
133 .map_err(|e| MemoryError::Redis(e))?;
134
135 Ok(())
136 }
137
138 pub async fn invalidate_pattern(&self, pattern: &str) -> MemoryResult<u64> {
140 let full_pattern = self.build_key(pattern);
141 let mut conn = self.client.clone();
142
143 let keys: Vec<String> = redis::cmd("KEYS")
145 .arg(&full_pattern)
146 .query_async(&mut conn)
147 .await
148 .map_err(|e| MemoryError::Redis(e))?;
149
150 if keys.is_empty() {
151 return Ok(0);
152 }
153
154 let count: u64 = conn.del(&keys).await
156 .map_err(|e| MemoryError::Redis(e))?;
157
158 Ok(count)
159 }
160
161 pub async fn exists(&self, key: &str) -> MemoryResult<bool> {
163 let full_key = self.build_key(key);
164 let mut conn = self.client.clone();
165
166 let exists: bool = conn.exists(&full_key).await
167 .map_err(|e| MemoryError::Redis(e))?;
168
169 Ok(exists)
170 }
171
172 pub async fn expire(&self, key: &str, ttl_seconds: u64) -> MemoryResult<bool> {
174 let full_key = self.build_key(key);
175 let mut conn = self.client.clone();
176
177 let result: bool = conn.expire(&full_key, ttl_seconds as i64).await
178 .map_err(|e| MemoryError::Redis(e))?;
179
180 Ok(result)
181 }
182
183 pub async fn ttl(&self, key: &str) -> MemoryResult<i64> {
185 let full_key = self.build_key(key);
186 let mut conn = self.client.clone();
187
188 let ttl: i64 = conn.ttl(&full_key).await
189 .map_err(|e| MemoryError::Redis(e))?;
190
191 Ok(ttl)
192 }
193
194 pub async fn incr(&self, key: &str) -> MemoryResult<i64> {
196 let full_key = self.build_key(key);
197 let mut conn = self.client.clone();
198
199 let value: i64 = conn.incr(&full_key, 1).await
200 .map_err(|e| MemoryError::Redis(e))?;
201
202 Ok(value)
203 }
204
205 pub async fn set_session<T: Serialize>(&self, session_id: &str, data: &T) -> MemoryResult<()> {
207 let key = format!("session:{}", session_id);
208 self.set_with_ttl(&key, data, SESSION_TTL_SECONDS).await
209 }
210
211 pub async fn get_session<T: DeserializeOwned>(&self, session_id: &str) -> MemoryResult<Option<T>> {
213 let key = format!("session:{}", session_id);
214 self.get(&key).await
215 }
216
217 pub async fn extend_session(&self, session_id: &str) -> MemoryResult<bool> {
219 let key = format!("session:{}", session_id);
220 self.expire(&key, SESSION_TTL_SECONDS).await
221 }
222
223 pub async fn delete_session(&self, session_id: &str) -> MemoryResult<()> {
225 let key = format!("session:{}", session_id);
226 self.invalidate(&key).await
227 }
228
229 pub async fn cache_user_memories(
231 &self,
232 user_id: &str,
233 memory_type: &str,
234 memory_ids: &[String],
235 ) -> MemoryResult<()> {
236 let key = format!("user:{}:memories:{}", user_id, memory_type);
237 let ids_vec: Vec<String> = memory_ids.to_vec();
239 self.set(&key, &ids_vec).await
240 }
241
242 pub async fn get_user_memories(
244 &self,
245 user_id: &str,
246 memory_type: &str,
247 ) -> MemoryResult<Option<Vec<String>>> {
248 let key = format!("user:{}:memories:{}", user_id, memory_type);
249 self.get(&key).await
250 }
251
252 pub async fn invalidate_user_cache(&self, user_id: &str) -> MemoryResult<u64> {
254 let pattern = format!("user:{}:*", user_id);
255 self.invalidate_pattern(&pattern).await
256 }
257
258 pub async fn ping(&self) -> MemoryResult<bool> {
260 let mut conn = self.client.clone();
261
262 let result: String = redis::cmd("PING")
263 .query_async(&mut conn)
264 .await
265 .map_err(|e| MemoryError::Redis(e))?;
266
267 Ok(result == "PONG")
268 }
269
270 pub async fn info(&self) -> MemoryResult<String> {
272 let mut conn = self.client.clone();
273
274 let info: String = redis::cmd("INFO")
275 .query_async(&mut conn)
276 .await
277 .map_err(|e| MemoryError::Redis(e))?;
278
279 Ok(info)
280 }
281}
282
#[cfg(test)]
mod tests {
    // NOTE(review): these tests re-derive the key layout instead of calling
    // `RedisCache::build_key`, because constructing a `RedisCache` needs a live
    // `ConnectionManager`.

    /// A full key is the prefix immediately followed by the caller's key.
    #[test]
    fn test_build_key() {
        let prefix = "halldyll:";
        let joined = [prefix, "test"].concat();
        assert_eq!(joined, "halldyll:test");
    }

    /// Embedding keys put an `emb:` namespace between the prefix and the caller's key.
    #[test]
    fn test_embedding_key_format() {
        let prefix = "halldyll:";
        let raw = "my_text";
        let joined = format!("{prefix}emb:{raw}");
        assert_eq!(joined, "halldyll:emb:my_text");
    }
}