1use async_trait::async_trait;
4use redis::AsyncCommands;
5use std::sync::Arc;
6use std::time::Duration;
7
8use super::persistence::Persistence;
9use super::state::{Session, SessionId};
10use super::types::{QueueItem, SummarySnapshot};
11use super::{SessionError, SessionResult};
12use uuid::Uuid;
13
/// Tunable settings for the Redis-backed session store.
#[derive(Clone, Debug)]
pub struct RedisConfig {
    /// Text prepended to every Redis key this backend builds.
    pub key_prefix: String,
    /// Expiry applied to a saved session when the session does not carry its
    /// own TTL; `None` stores such sessions without an expiry.
    pub default_ttl: Option<Duration>,
    /// Upper bound on how long establishing a connection may take.
    pub connection_timeout: Duration,
    /// Upper bound intended for waiting on a command's response.
    pub response_timeout: Duration,
}
21
22impl Default for RedisConfig {
23 fn default() -> Self {
24 Self {
25 key_prefix: "claude:session:".to_string(),
26 default_ttl: Some(Duration::from_secs(86400 * 7)),
27 connection_timeout: Duration::from_secs(10),
28 response_timeout: Duration::from_secs(30),
29 }
30 }
31}
32
33impl RedisConfig {
34 pub fn with_prefix(mut self, prefix: impl Into<String>) -> Self {
35 self.key_prefix = prefix.into();
36 self
37 }
38
39 pub fn with_ttl(mut self, ttl: Duration) -> Self {
40 self.default_ttl = Some(ttl);
41 self
42 }
43
44 pub fn without_ttl(mut self) -> Self {
45 self.default_ttl = None;
46 self
47 }
48}
49
50pub struct RedisPersistence {
51 client: Arc<redis::Client>,
52 config: RedisConfig,
53}
54
55impl RedisPersistence {
56 pub fn new(redis_url: &str) -> Result<Self, redis::RedisError> {
57 Self::from_config(redis_url, RedisConfig::default())
58 }
59
60 pub fn from_config(redis_url: &str, config: RedisConfig) -> Result<Self, redis::RedisError> {
61 let client = redis::Client::open(redis_url)?;
62 Ok(Self {
63 client: Arc::new(client),
64 config,
65 })
66 }
67
68 pub fn with_prefix(mut self, prefix: impl Into<String>) -> Self {
69 self.config.key_prefix = prefix.into();
70 self
71 }
72
73 pub fn with_ttl(mut self, ttl: Duration) -> Self {
74 self.config.default_ttl = Some(ttl);
75 self
76 }
77
78 pub fn without_ttl(mut self) -> Self {
79 self.config.default_ttl = None;
80 self
81 }
82
83 fn session_key(&self, id: &SessionId) -> String {
84 format!("{}{}", self.config.key_prefix, id)
85 }
86
87 fn tenant_key(&self, tenant_id: &str) -> String {
88 format!("{}tenant:{}", self.config.key_prefix, tenant_id)
89 }
90
91 fn children_key(&self, parent_id: &SessionId) -> String {
92 format!("{}children:{}", self.config.key_prefix, parent_id)
93 }
94
95 fn summaries_key(&self, session_id: &SessionId) -> String {
96 format!("{}summaries:{}", self.config.key_prefix, session_id)
97 }
98
99 fn queue_key(&self, session_id: &SessionId) -> String {
100 format!("{}queue:{}", self.config.key_prefix, session_id)
101 }
102
103 async fn get_connection(&self) -> SessionResult<redis::aio::MultiplexedConnection> {
104 tokio::time::timeout(
105 self.config.connection_timeout,
106 self.client.get_multiplexed_async_connection(),
107 )
108 .await
109 .map_err(|_| SessionError::Storage {
110 message: "Redis connection timeout".into(),
111 })?
112 .map_err(|e| SessionError::Storage {
113 message: e.to_string(),
114 })
115 }
116}
117
118#[async_trait]
119impl Persistence for RedisPersistence {
120 fn name(&self) -> &str {
121 "redis"
122 }
123
124 async fn save(&self, session: &Session) -> SessionResult<()> {
125 let mut conn = self.get_connection().await?;
126 let key = self.session_key(&session.id);
127 let data = serde_json::to_string(session).map_err(SessionError::Serialization)?;
128
129 let ttl_secs = session
130 .config
131 .ttl_secs
132 .or_else(|| self.config.default_ttl.map(|d| d.as_secs()));
133
134 match ttl_secs {
135 Some(ttl) => {
136 conn.set_ex::<_, _, ()>(&key, &data, ttl)
137 .await
138 .map_err(|e| SessionError::Storage {
139 message: e.to_string(),
140 })?;
141 }
142 None => {
143 conn.set::<_, _, ()>(&key, &data)
144 .await
145 .map_err(|e| SessionError::Storage {
146 message: e.to_string(),
147 })?;
148 }
149 }
150
151 if let Some(ref tenant_id) = session.tenant_id {
152 conn.sadd::<_, _, ()>(&self.tenant_key(tenant_id), session.id.to_string())
153 .await
154 .map_err(|e| SessionError::Storage {
155 message: e.to_string(),
156 })?;
157 }
158
159 if let Some(parent_id) = session.parent_id {
160 conn.sadd::<_, _, ()>(&self.children_key(&parent_id), session.id.to_string())
161 .await
162 .map_err(|e| SessionError::Storage {
163 message: e.to_string(),
164 })?;
165 }
166
167 Ok(())
168 }
169
170 async fn load(&self, id: &SessionId) -> SessionResult<Option<Session>> {
171 let mut conn = self.get_connection().await?;
172 let key = self.session_key(id);
173
174 let data: Option<String> = conn.get(&key).await.map_err(|e| SessionError::Storage {
175 message: e.to_string(),
176 })?;
177
178 match data {
179 Some(json) => {
180 let session: Session =
181 serde_json::from_str(&json).map_err(SessionError::Serialization)?;
182 Ok(Some(session))
183 }
184 None => Ok(None),
185 }
186 }
187
188 async fn delete(&self, id: &SessionId) -> SessionResult<bool> {
189 let mut conn = self.get_connection().await?;
190 let key = self.session_key(id);
191
192 if let Some(session) = self.load(id).await?
193 && let Some(ref tenant_id) = session.tenant_id
194 {
195 conn.srem::<_, _, ()>(&self.tenant_key(tenant_id), id.to_string())
196 .await
197 .map_err(|e| SessionError::Storage {
198 message: e.to_string(),
199 })?;
200 }
201
202 conn.del::<_, ()>(&self.summaries_key(id))
203 .await
204 .map_err(|e| SessionError::Storage {
205 message: e.to_string(),
206 })?;
207 conn.del::<_, ()>(&self.queue_key(id))
208 .await
209 .map_err(|e| SessionError::Storage {
210 message: e.to_string(),
211 })?;
212
213 let deleted: i32 = conn.del(&key).await.map_err(|e| SessionError::Storage {
214 message: e.to_string(),
215 })?;
216
217 Ok(deleted > 0)
218 }
219
220 async fn list(&self, tenant_id: Option<&str>) -> SessionResult<Vec<SessionId>> {
221 let mut conn = self.get_connection().await?;
222
223 match tenant_id {
224 Some(tid) => {
225 let ids: Vec<String> = conn.smembers(self.tenant_key(tid)).await.map_err(|e| {
226 SessionError::Storage {
227 message: e.to_string(),
228 }
229 })?;
230 Ok(ids.into_iter().map(SessionId::from).collect())
231 }
232 None => {
233 let pattern = format!("{}*", self.config.key_prefix);
234 let mut cursor: u64 = 0;
235 let mut all_ids = Vec::new();
236
237 loop {
238 let (next_cursor, keys): (u64, Vec<String>) = redis::cmd("SCAN")
239 .arg(cursor)
240 .arg("MATCH")
241 .arg(&pattern)
242 .arg("COUNT")
243 .arg(100)
244 .query_async(&mut conn)
245 .await
246 .map_err(|e| SessionError::Storage {
247 message: e.to_string(),
248 })?;
249
250 for key in keys {
251 if let Some(id) = key.strip_prefix(&self.config.key_prefix)
252 && !id.contains(':')
253 {
254 all_ids.push(SessionId::from(id));
255 }
256 }
257
258 cursor = next_cursor;
259 if cursor == 0 {
260 break;
261 }
262 }
263
264 Ok(all_ids)
265 }
266 }
267 }
268
269 async fn list_children(&self, parent_id: &SessionId) -> SessionResult<Vec<SessionId>> {
270 let mut conn = self.get_connection().await?;
271 let ids: Vec<String> = conn
272 .smembers(self.children_key(parent_id))
273 .await
274 .map_err(|e| SessionError::Storage {
275 message: e.to_string(),
276 })?;
277 Ok(ids.into_iter().map(SessionId::from).collect())
278 }
279
280 async fn add_summary(&self, snapshot: SummarySnapshot) -> SessionResult<()> {
281 let mut conn = self.get_connection().await?;
282 let key = self.summaries_key(&snapshot.session_id);
283 let data = serde_json::to_string(&snapshot).map_err(SessionError::Serialization)?;
284
285 conn.rpush::<_, _, ()>(&key, &data)
286 .await
287 .map_err(|e| SessionError::Storage {
288 message: e.to_string(),
289 })?;
290
291 Ok(())
292 }
293
294 async fn get_summaries(&self, session_id: &SessionId) -> SessionResult<Vec<SummarySnapshot>> {
295 let mut conn = self.get_connection().await?;
296 let key = self.summaries_key(session_id);
297
298 let items: Vec<String> =
299 conn.lrange(&key, 0, -1)
300 .await
301 .map_err(|e| SessionError::Storage {
302 message: e.to_string(),
303 })?;
304
305 items
306 .into_iter()
307 .map(|json| serde_json::from_str(&json).map_err(SessionError::Serialization))
308 .collect()
309 }
310
311 async fn enqueue(
312 &self,
313 session_id: &SessionId,
314 content: String,
315 priority: i32,
316 ) -> SessionResult<QueueItem> {
317 let mut conn = self.get_connection().await?;
318 let key = self.queue_key(session_id);
319 let item = QueueItem::enqueue(*session_id, &content).with_priority(priority);
320 let data = serde_json::to_string(&item).map_err(SessionError::Serialization)?;
321
322 conn.zadd::<_, _, _, ()>(&key, &data, -(priority as f64))
323 .await
324 .map_err(|e| SessionError::Storage {
325 message: e.to_string(),
326 })?;
327
328 Ok(item)
329 }
330
331 async fn dequeue(&self, session_id: &SessionId) -> SessionResult<Option<QueueItem>> {
332 let mut conn = self.get_connection().await?;
333 let key = self.queue_key(session_id);
334
335 let items: Vec<String> =
336 conn.zpopmin(&key, 1)
337 .await
338 .map_err(|e| SessionError::Storage {
339 message: e.to_string(),
340 })?;
341
342 if items.is_empty() {
343 return Ok(None);
344 }
345
346 let json = &items[0];
347 let mut item: QueueItem =
348 serde_json::from_str(json).map_err(SessionError::Serialization)?;
349 item.start_processing();
350 Ok(Some(item))
351 }
352
353 async fn cancel_queued(&self, item_id: Uuid) -> SessionResult<bool> {
354 let mut conn = self.get_connection().await?;
355 let pattern = format!("{}queue:*", self.config.key_prefix);
356
357 let mut cursor: u64 = 0;
358 loop {
359 let (next_cursor, keys): (u64, Vec<String>) = redis::cmd("SCAN")
360 .arg(cursor)
361 .arg("MATCH")
362 .arg(&pattern)
363 .arg("COUNT")
364 .arg(100)
365 .query_async(&mut conn)
366 .await
367 .map_err(|e| SessionError::Storage {
368 message: e.to_string(),
369 })?;
370
371 for key in keys {
372 let items: Vec<String> =
373 conn.zrange(&key, 0, -1)
374 .await
375 .map_err(|e| SessionError::Storage {
376 message: e.to_string(),
377 })?;
378
379 for json in items {
380 if let Ok(item) = serde_json::from_str::<QueueItem>(&json)
381 && item.id == item_id
382 {
383 let removed: i32 =
384 conn.zrem(&key, &json)
385 .await
386 .map_err(|e| SessionError::Storage {
387 message: e.to_string(),
388 })?;
389 return Ok(removed > 0);
390 }
391 }
392 }
393
394 cursor = next_cursor;
395 if cursor == 0 {
396 break;
397 }
398 }
399
400 Ok(false)
401 }
402
403 async fn pending_queue(&self, session_id: &SessionId) -> SessionResult<Vec<QueueItem>> {
404 let mut conn = self.get_connection().await?;
405 let key = self.queue_key(session_id);
406
407 let items: Vec<String> =
408 conn.zrange(&key, 0, -1)
409 .await
410 .map_err(|e| SessionError::Storage {
411 message: e.to_string(),
412 })?;
413
414 items
415 .into_iter()
416 .map(|json| serde_json::from_str(&json).map_err(SessionError::Serialization))
417 .collect()
418 }
419
420 async fn cleanup_expired(&self) -> SessionResult<usize> {
421 Ok(0)
422 }
423}