claude_agent/session/
persistence_redis.rs

1//! Redis persistence backend for sessions.
2
3use async_trait::async_trait;
4use redis::AsyncCommands;
5use std::sync::Arc;
6use std::time::Duration;
7
8use super::persistence::Persistence;
9use super::state::{Session, SessionId};
10use super::types::{QueueItem, SummarySnapshot};
11use super::{SessionError, SessionResult};
12use uuid::Uuid;
13
/// Tunables for [`RedisPersistence`]: key namespacing, expiry and timeouts.
#[derive(Clone, Debug)]
pub struct RedisConfig {
    /// String prepended to every key the backend writes or scans.
    pub key_prefix: String,
    /// Expiry used for saved sessions that do not carry their own `ttl_secs`;
    /// `None` stores such sessions without expiry.
    pub default_ttl: Option<Duration>,
    /// Upper bound on establishing a Redis connection.
    pub connection_timeout: Duration,
    /// Presumably meant as a per-command reply timeout.
    /// NOTE(review): nothing in this file reads this field, so it currently has no effect.
    pub response_timeout: Duration,
}

impl Default for RedisConfig {
    /// Prefix `claude:session:`, one-week default TTL, 10 s connect timeout and
    /// 30 s response timeout.
    fn default() -> Self {
        const ONE_WEEK_SECS: u64 = 7 * 24 * 60 * 60;

        Self {
            key_prefix: String::from("claude:session:"),
            default_ttl: Some(Duration::from_secs(ONE_WEEK_SECS)),
            connection_timeout: Duration::from_secs(10),
            response_timeout: Duration::from_secs(30),
        }
    }
}

impl RedisConfig {
    /// Returns this config with `key_prefix` replaced by `prefix`.
    pub fn with_prefix(self, prefix: impl Into<String>) -> Self {
        Self {
            key_prefix: prefix.into(),
            ..self
        }
    }

    /// Returns this config with `ttl` as the default session expiry.
    pub fn with_ttl(self, ttl: Duration) -> Self {
        Self {
            default_ttl: Some(ttl),
            ..self
        }
    }

    /// Returns this config with no default session expiry.
    pub fn without_ttl(self) -> Self {
        Self {
            default_ttl: None,
            ..self
        }
    }
}
49
50pub struct RedisPersistence {
51    client: Arc<redis::Client>,
52    config: RedisConfig,
53}
54
55impl RedisPersistence {
56    pub fn new(redis_url: &str) -> Result<Self, redis::RedisError> {
57        Self::from_config(redis_url, RedisConfig::default())
58    }
59
60    pub fn from_config(redis_url: &str, config: RedisConfig) -> Result<Self, redis::RedisError> {
61        let client = redis::Client::open(redis_url)?;
62        Ok(Self {
63            client: Arc::new(client),
64            config,
65        })
66    }
67
68    pub fn with_prefix(mut self, prefix: impl Into<String>) -> Self {
69        self.config.key_prefix = prefix.into();
70        self
71    }
72
73    pub fn with_ttl(mut self, ttl: Duration) -> Self {
74        self.config.default_ttl = Some(ttl);
75        self
76    }
77
78    pub fn without_ttl(mut self) -> Self {
79        self.config.default_ttl = None;
80        self
81    }
82
83    fn session_key(&self, id: &SessionId) -> String {
84        format!("{}{}", self.config.key_prefix, id)
85    }
86
87    fn tenant_key(&self, tenant_id: &str) -> String {
88        format!("{}tenant:{}", self.config.key_prefix, tenant_id)
89    }
90
91    fn children_key(&self, parent_id: &SessionId) -> String {
92        format!("{}children:{}", self.config.key_prefix, parent_id)
93    }
94
95    fn summaries_key(&self, session_id: &SessionId) -> String {
96        format!("{}summaries:{}", self.config.key_prefix, session_id)
97    }
98
99    fn queue_key(&self, session_id: &SessionId) -> String {
100        format!("{}queue:{}", self.config.key_prefix, session_id)
101    }
102
103    async fn get_connection(&self) -> SessionResult<redis::aio::MultiplexedConnection> {
104        tokio::time::timeout(
105            self.config.connection_timeout,
106            self.client.get_multiplexed_async_connection(),
107        )
108        .await
109        .map_err(|_| SessionError::Storage {
110            message: "Redis connection timeout".into(),
111        })?
112        .map_err(|e| SessionError::Storage {
113            message: e.to_string(),
114        })
115    }
116}
117
118#[async_trait]
119impl Persistence for RedisPersistence {
120    fn name(&self) -> &str {
121        "redis"
122    }
123
124    async fn save(&self, session: &Session) -> SessionResult<()> {
125        let mut conn = self.get_connection().await?;
126        let key = self.session_key(&session.id);
127        let data = serde_json::to_string(session).map_err(SessionError::Serialization)?;
128
129        let ttl_secs = session
130            .config
131            .ttl_secs
132            .or_else(|| self.config.default_ttl.map(|d| d.as_secs()));
133
134        match ttl_secs {
135            Some(ttl) => {
136                conn.set_ex::<_, _, ()>(&key, &data, ttl)
137                    .await
138                    .map_err(|e| SessionError::Storage {
139                        message: e.to_string(),
140                    })?;
141            }
142            None => {
143                conn.set::<_, _, ()>(&key, &data)
144                    .await
145                    .map_err(|e| SessionError::Storage {
146                        message: e.to_string(),
147                    })?;
148            }
149        }
150
151        if let Some(ref tenant_id) = session.tenant_id {
152            conn.sadd::<_, _, ()>(&self.tenant_key(tenant_id), session.id.to_string())
153                .await
154                .map_err(|e| SessionError::Storage {
155                    message: e.to_string(),
156                })?;
157        }
158
159        if let Some(parent_id) = session.parent_id {
160            conn.sadd::<_, _, ()>(&self.children_key(&parent_id), session.id.to_string())
161                .await
162                .map_err(|e| SessionError::Storage {
163                    message: e.to_string(),
164                })?;
165        }
166
167        Ok(())
168    }
169
170    async fn load(&self, id: &SessionId) -> SessionResult<Option<Session>> {
171        let mut conn = self.get_connection().await?;
172        let key = self.session_key(id);
173
174        let data: Option<String> = conn.get(&key).await.map_err(|e| SessionError::Storage {
175            message: e.to_string(),
176        })?;
177
178        match data {
179            Some(json) => {
180                let session: Session =
181                    serde_json::from_str(&json).map_err(SessionError::Serialization)?;
182                Ok(Some(session))
183            }
184            None => Ok(None),
185        }
186    }
187
188    async fn delete(&self, id: &SessionId) -> SessionResult<bool> {
189        let mut conn = self.get_connection().await?;
190        let key = self.session_key(id);
191
192        if let Some(session) = self.load(id).await?
193            && let Some(ref tenant_id) = session.tenant_id
194        {
195            conn.srem::<_, _, ()>(&self.tenant_key(tenant_id), id.to_string())
196                .await
197                .map_err(|e| SessionError::Storage {
198                    message: e.to_string(),
199                })?;
200        }
201
202        conn.del::<_, ()>(&self.summaries_key(id))
203            .await
204            .map_err(|e| SessionError::Storage {
205                message: e.to_string(),
206            })?;
207        conn.del::<_, ()>(&self.queue_key(id))
208            .await
209            .map_err(|e| SessionError::Storage {
210                message: e.to_string(),
211            })?;
212
213        let deleted: i32 = conn.del(&key).await.map_err(|e| SessionError::Storage {
214            message: e.to_string(),
215        })?;
216
217        Ok(deleted > 0)
218    }
219
220    async fn list(&self, tenant_id: Option<&str>) -> SessionResult<Vec<SessionId>> {
221        let mut conn = self.get_connection().await?;
222
223        match tenant_id {
224            Some(tid) => {
225                let ids: Vec<String> = conn.smembers(self.tenant_key(tid)).await.map_err(|e| {
226                    SessionError::Storage {
227                        message: e.to_string(),
228                    }
229                })?;
230                Ok(ids.into_iter().map(SessionId::from).collect())
231            }
232            None => {
233                let pattern = format!("{}*", self.config.key_prefix);
234                let mut cursor: u64 = 0;
235                let mut all_ids = Vec::new();
236
237                loop {
238                    let (next_cursor, keys): (u64, Vec<String>) = redis::cmd("SCAN")
239                        .arg(cursor)
240                        .arg("MATCH")
241                        .arg(&pattern)
242                        .arg("COUNT")
243                        .arg(100)
244                        .query_async(&mut conn)
245                        .await
246                        .map_err(|e| SessionError::Storage {
247                            message: e.to_string(),
248                        })?;
249
250                    for key in keys {
251                        if let Some(id) = key.strip_prefix(&self.config.key_prefix)
252                            && !id.contains(':')
253                        {
254                            all_ids.push(SessionId::from(id));
255                        }
256                    }
257
258                    cursor = next_cursor;
259                    if cursor == 0 {
260                        break;
261                    }
262                }
263
264                Ok(all_ids)
265            }
266        }
267    }
268
269    async fn list_children(&self, parent_id: &SessionId) -> SessionResult<Vec<SessionId>> {
270        let mut conn = self.get_connection().await?;
271        let ids: Vec<String> = conn
272            .smembers(self.children_key(parent_id))
273            .await
274            .map_err(|e| SessionError::Storage {
275                message: e.to_string(),
276            })?;
277        Ok(ids.into_iter().map(SessionId::from).collect())
278    }
279
280    async fn add_summary(&self, snapshot: SummarySnapshot) -> SessionResult<()> {
281        let mut conn = self.get_connection().await?;
282        let key = self.summaries_key(&snapshot.session_id);
283        let data = serde_json::to_string(&snapshot).map_err(SessionError::Serialization)?;
284
285        conn.rpush::<_, _, ()>(&key, &data)
286            .await
287            .map_err(|e| SessionError::Storage {
288                message: e.to_string(),
289            })?;
290
291        Ok(())
292    }
293
294    async fn get_summaries(&self, session_id: &SessionId) -> SessionResult<Vec<SummarySnapshot>> {
295        let mut conn = self.get_connection().await?;
296        let key = self.summaries_key(session_id);
297
298        let items: Vec<String> =
299            conn.lrange(&key, 0, -1)
300                .await
301                .map_err(|e| SessionError::Storage {
302                    message: e.to_string(),
303                })?;
304
305        items
306            .into_iter()
307            .map(|json| serde_json::from_str(&json).map_err(SessionError::Serialization))
308            .collect()
309    }
310
311    async fn enqueue(
312        &self,
313        session_id: &SessionId,
314        content: String,
315        priority: i32,
316    ) -> SessionResult<QueueItem> {
317        let mut conn = self.get_connection().await?;
318        let key = self.queue_key(session_id);
319        let item = QueueItem::enqueue(*session_id, &content).with_priority(priority);
320        let data = serde_json::to_string(&item).map_err(SessionError::Serialization)?;
321
322        conn.zadd::<_, _, _, ()>(&key, &data, -(priority as f64))
323            .await
324            .map_err(|e| SessionError::Storage {
325                message: e.to_string(),
326            })?;
327
328        Ok(item)
329    }
330
331    async fn dequeue(&self, session_id: &SessionId) -> SessionResult<Option<QueueItem>> {
332        let mut conn = self.get_connection().await?;
333        let key = self.queue_key(session_id);
334
335        let items: Vec<String> =
336            conn.zpopmin(&key, 1)
337                .await
338                .map_err(|e| SessionError::Storage {
339                    message: e.to_string(),
340                })?;
341
342        if items.is_empty() {
343            return Ok(None);
344        }
345
346        let json = &items[0];
347        let mut item: QueueItem =
348            serde_json::from_str(json).map_err(SessionError::Serialization)?;
349        item.start_processing();
350        Ok(Some(item))
351    }
352
353    async fn cancel_queued(&self, item_id: Uuid) -> SessionResult<bool> {
354        let mut conn = self.get_connection().await?;
355        let pattern = format!("{}queue:*", self.config.key_prefix);
356
357        let mut cursor: u64 = 0;
358        loop {
359            let (next_cursor, keys): (u64, Vec<String>) = redis::cmd("SCAN")
360                .arg(cursor)
361                .arg("MATCH")
362                .arg(&pattern)
363                .arg("COUNT")
364                .arg(100)
365                .query_async(&mut conn)
366                .await
367                .map_err(|e| SessionError::Storage {
368                    message: e.to_string(),
369                })?;
370
371            for key in keys {
372                let items: Vec<String> =
373                    conn.zrange(&key, 0, -1)
374                        .await
375                        .map_err(|e| SessionError::Storage {
376                            message: e.to_string(),
377                        })?;
378
379                for json in items {
380                    if let Ok(item) = serde_json::from_str::<QueueItem>(&json)
381                        && item.id == item_id
382                    {
383                        let removed: i32 =
384                            conn.zrem(&key, &json)
385                                .await
386                                .map_err(|e| SessionError::Storage {
387                                    message: e.to_string(),
388                                })?;
389                        return Ok(removed > 0);
390                    }
391                }
392            }
393
394            cursor = next_cursor;
395            if cursor == 0 {
396                break;
397            }
398        }
399
400        Ok(false)
401    }
402
403    async fn pending_queue(&self, session_id: &SessionId) -> SessionResult<Vec<QueueItem>> {
404        let mut conn = self.get_connection().await?;
405        let key = self.queue_key(session_id);
406
407        let items: Vec<String> =
408            conn.zrange(&key, 0, -1)
409                .await
410                .map_err(|e| SessionError::Storage {
411                    message: e.to_string(),
412                })?;
413
414        items
415            .into_iter()
416            .map(|json| serde_json::from_str(&json).map_err(SessionError::Serialization))
417            .collect()
418    }
419
420    async fn cleanup_expired(&self) -> SessionResult<usize> {
421        Ok(0)
422    }
423}