Skip to main content

hermes_bot/
session.rs

1//! Persistent session storage for tracking Claude Code sessions.
2//!
3//! The [`SessionStore`] maintains a mapping from Slack threads to Claude Code
4//! sessions, persisted in a SQLite database. This allows sessions to survive
5//! restarts and enables the bot to resume conversations.
6//!
7//! # Features
8//!
9//! - Thread-safe concurrent access (Mutex<Connection>)
10//! - Per-row updates (no full-file rewrites)
11//! - Indexed lookups on session_id
12//! - WAL mode for concurrent readers
13//! - Automatic migration from legacy sessions.json
14
15use crate::config::AgentKind;
16use crate::error::{HermesError, Result};
17use chrono::{DateTime, Duration, Utc};
18use rusqlite::{Connection, params};
19use serde::{Deserialize, Serialize};
20use std::collections::HashMap;
21use std::path::PathBuf;
22use std::sync::{Arc, Mutex};
23use tracing::error;
24
25fn serialize_as_active<S: serde::Serializer>(s: S) -> std::result::Result<S::Ok, S::Error> {
26    s.serialize_str("active")
27}
28
/// Lifecycle state of a stored session.
///
/// Serde (used for the legacy JSON session file) reads and writes the
/// lowercase variant names, except for the legacy `Stopped` variant described
/// below. The database `status` column is encoded with
/// `SessionStatus::as_str` / `SessionStatus::from_str`, not with serde.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum SessionStatus {
    /// Default state; unrecognised `status` column text also decodes to this.
    Active,
    /// The session failed; `SessionStore::active_sessions` filters out rows
    /// stored with this status.
    Error,
    /// Legacy: old sessions may have "stopped" on disk. Deserializes from
    /// "stopped" but serializes back as "active" so it can round-trip.
    #[serde(
        rename(deserialize = "stopped"),
        serialize_with = "serialize_as_active"
    )]
    Stopped,
}
42
43impl SessionStatus {
44    fn as_str(&self) -> &'static str {
45        match self {
46            SessionStatus::Active | SessionStatus::Stopped => "active",
47            SessionStatus::Error => "error",
48        }
49    }
50
51    fn from_str(s: &str) -> Self {
52        match s {
53            "error" => SessionStatus::Error,
54            "stopped" => SessionStatus::Stopped,
55            _ => SessionStatus::Active,
56        }
57    }
58}
59
/// A single persisted session.
///
/// This is one row of the `sessions` table. It is also one value of the
/// legacy JSON session map that `SessionStore::new` imports on startup.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionInfo {
    /// The agent's session ID (used for --resume).
    pub session_id: String,
    /// Repo name from config.
    pub repo: String,
    /// Absolute path to the repo.
    pub repo_path: PathBuf,
    /// Which agent backend.
    ///
    /// NOTE(review): `SessionStore` currently writes `'claude'` for every row
    /// and reads every row back as `AgentKind::Claude`. This field therefore
    /// does not round-trip through the database.
    pub agent_kind: AgentKind,
    /// Slack channel ID.
    pub channel_id: String,
    /// Thread timestamp (Slack's thread_ts) — identifies the thread.
    pub thread_ts: String,
    /// Creation time; stored in the database as RFC 3339 text.
    pub created_at: DateTime<Utc>,
    /// Last activity time; stored as RFC 3339 text.
    /// `SessionStore::prune_expired` deletes sessions whose value is older
    /// than the TTL.
    pub last_active: DateTime<Utc>,
    /// Current status; `SessionStore::active_sessions` skips `Error` sessions.
    pub status: SessionStatus,
    /// Turn counter for this session.
    /// Presumably incremented by callers via `SessionStore::update` — TODO confirm.
    pub total_turns: u32,
    /// Model used for this session (for display; CLI remembers on --resume).
    /// `#[serde(default)]` lets legacy JSON entries without this key load as `None`.
    #[serde(default)]
    pub model: Option<String>,
}
82
/// Thread-safe persistent session store backed by SQLite.
///
/// Tracks which Slack threads correspond to which agent sessions. The database
/// is opened in WAL mode. Changes are issued as per-row SQL statements rather
/// than rewriting a whole file.
///
/// # Thread Safety
///
/// The single [`Connection`] lives behind an `Arc<Mutex<_>>`. Cloning a
/// `SessionStore` is therefore cheap, and all clones share one connection.
/// The async methods do their SQLite work inside
/// `tokio::task::spawn_blocking` and take the sync `Mutex` there, so no lock
/// is held across an `.await`. `SessionStore::new` is synchronous and runs on
/// the calling thread.
#[derive(Clone)]
pub struct SessionStore {
    conn: Arc<Mutex<Connection>>,
}
97
/// DDL applied every time a store is opened.
///
/// Both statements use `IF NOT EXISTS`, so running this against an existing
/// database is a no-op. `thread_ts` is the primary key, giving one session per
/// Slack thread. The secondary index serves lookups by agent `session_id`.
///
/// NOTE: timestamps are RFC 3339 text, and `prune_expired` compares
/// `last_active` as a string. That is only correct while every value is
/// written in the same `to_rfc3339()` UTC format.
const SCHEMA: &str = "
CREATE TABLE IF NOT EXISTS sessions (
    thread_ts    TEXT PRIMARY KEY,
    session_id   TEXT NOT NULL,
    repo         TEXT NOT NULL,
    repo_path    TEXT NOT NULL,
    agent_kind   TEXT NOT NULL DEFAULT 'claude',
    channel_id   TEXT NOT NULL,
    created_at   TEXT NOT NULL,
    last_active  TEXT NOT NULL,
    status       TEXT NOT NULL DEFAULT 'active',
    total_turns  INTEGER NOT NULL DEFAULT 0,
    model        TEXT
);
CREATE INDEX IF NOT EXISTS idx_sessions_session_id ON sessions(session_id);
";
114
115impl SessionStore {
116    /// Creates a new session store, opening or creating the SQLite database.
117    ///
118    /// If a legacy `sessions.json` file exists next to the database path,
119    /// it will be migrated automatically.
120    ///
121    /// # Arguments
122    ///
123    /// * `path` - Path to the SQLite database file
124    ///
125    /// # Returns
126    ///
127    /// A new `SessionStore` instance.
128    pub fn new(path: PathBuf) -> Self {
129        let conn = Connection::open(&path).unwrap_or_else(|e| {
130            panic!("Failed to open SQLite database '{}': {}", path.display(), e);
131        });
132
133        conn.execute_batch("PRAGMA journal_mode=WAL; PRAGMA busy_timeout=5000;")
134            .unwrap_or_else(|e| {
135                panic!("Failed to set SQLite pragmas: {}", e);
136            });
137
138        conn.execute_batch(SCHEMA).unwrap_or_else(|e| {
139            panic!("Failed to create sessions schema: {}", e);
140        });
141
142        let store = Self {
143            conn: Arc::new(Mutex::new(conn)),
144        };
145
146        // Attempt JSON migration
147        store.migrate_from_json(&path);
148
149        store
150    }
151
152    /// Migrate sessions from a legacy JSON file if one exists.
153    fn migrate_from_json(&self, db_path: &std::path::Path) {
154        // Look for sessions.json in the same directory as the DB file
155        let json_path = db_path.with_extension("json");
156        // Also check if the original path was .json (shouldn't happen post-migration,
157        // but handle the edge case of a path like "sessions.json" being passed)
158        let candidates = [json_path];
159
160        for candidate in &candidates {
161            if !candidate.exists() {
162                continue;
163            }
164
165            let contents = match std::fs::read_to_string(candidate) {
166                Ok(c) => c,
167                Err(e) => {
168                    tracing::warn!(
169                        "Found legacy session file '{}' but failed to read it: {}",
170                        candidate.display(),
171                        e
172                    );
173                    continue;
174                }
175            };
176
177            let sessions: HashMap<String, SessionInfo> = match serde_json::from_str(&contents) {
178                Ok(s) => s,
179                Err(e) => {
180                    tracing::warn!(
181                        "Found legacy session file '{}' but failed to parse it: {}",
182                        candidate.display(),
183                        e
184                    );
185                    continue;
186                }
187            };
188
189            if sessions.is_empty() {
190                // Remove empty JSON file
191                let backup = candidate.with_extension("json.bak");
192                if let Err(e) = std::fs::rename(candidate, &backup) {
193                    tracing::warn!("Failed to rename empty legacy file: {}", e);
194                }
195                continue;
196            }
197
198            let conn = self.conn.lock().unwrap();
199            let result = (|| -> std::result::Result<usize, rusqlite::Error> {
200                let tx = conn.unchecked_transaction()?;
201                let mut count = 0;
202                for (thread_ts, session) in &sessions {
203                    tx.execute(
204                        "INSERT OR IGNORE INTO sessions (thread_ts, session_id, repo, repo_path, agent_kind, channel_id, created_at, last_active, status, total_turns, model) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11)",
205                        params![
206                            thread_ts,
207                            session.session_id,
208                            session.repo,
209                            session.repo_path.to_string_lossy().to_string(),
210                            "claude",
211                            session.channel_id,
212                            session.created_at.to_rfc3339(),
213                            session.last_active.to_rfc3339(),
214                            session.status.as_str(),
215                            session.total_turns,
216                            session.model,
217                        ],
218                    )?;
219                    count += 1;
220                }
221                tx.commit()?;
222                Ok(count)
223            })();
224
225            drop(conn);
226
227            match result {
228                Ok(count) => {
229                    tracing::info!(
230                        "Migrated {} session(s) from '{}' to SQLite",
231                        count,
232                        candidate.display()
233                    );
234                    let backup = candidate.with_extension("json.bak");
235                    if let Err(e) = std::fs::rename(candidate, &backup) {
236                        tracing::warn!(
237                            "Failed to rename '{}' to '{}': {}",
238                            candidate.display(),
239                            backup.display(),
240                            e
241                        );
242                    }
243                }
244                Err(e) => {
245                    tracing::warn!(
246                        "Failed to migrate sessions from '{}': {} (continuing without migration)",
247                        candidate.display(),
248                        e
249                    );
250                }
251            }
252        }
253    }
254
255    /// Inserts a new session into the database.
256    #[must_use = "session insert errors mean the session won't persist across restarts"]
257    pub async fn insert(&self, session: SessionInfo) -> Result<()> {
258        let conn = self.conn.clone();
259        tokio::task::spawn_blocking(move || {
260            let conn = conn.lock().unwrap();
261            conn.execute(
262                "INSERT OR REPLACE INTO sessions (thread_ts, session_id, repo, repo_path, agent_kind, channel_id, created_at, last_active, status, total_turns, model) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11)",
263                params![
264                    session.thread_ts,
265                    session.session_id,
266                    session.repo,
267                    session.repo_path.to_string_lossy().to_string(),
268                    "claude",
269                    session.channel_id,
270                    session.created_at.to_rfc3339(),
271                    session.last_active.to_rfc3339(),
272                    session.status.as_str(),
273                    session.total_turns,
274                    session.model,
275                ],
276            )?;
277            Ok(())
278        })
279        .await
280        .unwrap()
281    }
282
283    /// Retrieves a session by its Slack thread timestamp.
284    pub async fn get_by_thread(&self, thread_ts: &str) -> Option<SessionInfo> {
285        let conn = self.conn.clone();
286        let thread_ts = thread_ts.to_string();
287        tokio::task::spawn_blocking(move || {
288            let conn = conn.lock().unwrap();
289            row_to_session(&conn, &thread_ts)
290        })
291        .await
292        .unwrap()
293    }
294
295    /// Updates a session in-place within a transaction.
296    ///
297    /// # Arguments
298    ///
299    /// * `thread_ts` - Thread timestamp of the session to update
300    /// * `f` - Closure that modifies the session
301    ///
302    /// # Errors
303    ///
304    /// Returns `SessionNotFound` if the thread doesn't exist.
305    #[must_use = "session update errors mean changes won't persist to disk"]
306    pub async fn update<F>(&self, thread_ts: &str, f: F) -> Result<()>
307    where
308        F: FnOnce(&mut SessionInfo) + Send + 'static,
309    {
310        let conn = self.conn.clone();
311        let thread_ts = thread_ts.to_string();
312        tokio::task::spawn_blocking(move || {
313            let conn = conn.lock().unwrap();
314            let mut session = row_to_session(&conn, &thread_ts)
315                .ok_or_else(|| HermesError::SessionNotFound(thread_ts.clone()))?;
316            f(&mut session);
317            conn.execute(
318                "UPDATE sessions SET session_id=?1, repo=?2, repo_path=?3, agent_kind=?4, channel_id=?5, created_at=?6, last_active=?7, status=?8, total_turns=?9, model=?10 WHERE thread_ts=?11",
319                params![
320                    session.session_id,
321                    session.repo,
322                    session.repo_path.to_string_lossy().to_string(),
323                    "claude",
324                    session.channel_id,
325                    session.created_at.to_rfc3339(),
326                    session.last_active.to_rfc3339(),
327                    session.status.as_str(),
328                    session.total_turns,
329                    session.model,
330                    thread_ts,
331                ],
332            )?;
333            Ok(())
334        })
335        .await
336        .unwrap()
337    }
338
339    pub async fn active_sessions(&self) -> Vec<SessionInfo> {
340        let conn = self.conn.clone();
341        tokio::task::spawn_blocking(move || {
342            let conn = conn.lock().unwrap();
343            let mut stmt = conn
344                .prepare("SELECT thread_ts, session_id, repo, repo_path, agent_kind, channel_id, created_at, last_active, status, total_turns, model FROM sessions WHERE status != 'error'")
345                .unwrap();
346            stmt.query_map([], row_mapper)
347                .unwrap()
348                .filter_map(|r| r.ok())
349                .collect()
350        })
351        .await
352        .unwrap()
353    }
354
355    /// Checks if any session has the given agent session ID (indexed lookup).
356    pub async fn has_session_id(&self, session_id: &str) -> bool {
357        let conn = self.conn.clone();
358        let session_id = session_id.to_string();
359        tokio::task::spawn_blocking(move || {
360            let conn = conn.lock().unwrap();
361            let exists: bool = conn
362                .query_row(
363                    "SELECT EXISTS(SELECT 1 FROM sessions WHERE session_id = ?1 LIMIT 1)",
364                    params![session_id],
365                    |row| row.get(0),
366                )
367                .unwrap_or(false);
368            exists
369        })
370        .await
371        .unwrap()
372    }
373
374    /// Remove sessions whose channel_id doesn't match the current channel for their repo.
375    pub async fn prune_stale_channels(&self, repo_channels: &HashMap<String, String>) {
376        let conn = self.conn.clone();
377        let repo_channels = repo_channels.clone();
378        let result = tokio::task::spawn_blocking(move || {
379            let conn = conn.lock().unwrap();
380            // Read all sessions, determine which to delete
381            let mut stmt = conn
382                .prepare("SELECT thread_ts, repo, channel_id FROM sessions")
383                .unwrap();
384            let stale: Vec<String> = stmt
385                .query_map([], |row| {
386                    Ok((
387                        row.get::<_, String>(0)?,
388                        row.get::<_, String>(1)?,
389                        row.get::<_, String>(2)?,
390                    ))
391                })
392                .unwrap()
393                .filter_map(|r| r.ok())
394                .filter(|(_, repo, channel_id)| match repo_channels.get(repo) {
395                    Some(current_channel) => channel_id != current_channel,
396                    None => true, // Repo no longer configured
397                })
398                .map(|(thread_ts, _, _)| thread_ts)
399                .collect();
400
401            if stale.is_empty() {
402                return 0usize;
403            }
404
405            let count = stale.len();
406            for thread_ts in &stale {
407                if let Err(e) = conn.execute(
408                    "DELETE FROM sessions WHERE thread_ts = ?1",
409                    params![thread_ts],
410                ) {
411                    error!("Failed to delete stale session '{}': {}", thread_ts, e);
412                }
413            }
414            count
415        })
416        .await
417        .unwrap();
418
419        if result > 0 {
420            tracing::info!("Pruned {} stale session(s) from previous run", result);
421        }
422    }
423
424    /// Remove sessions whose last_active is older than the TTL.
425    pub async fn prune_expired(&self, ttl_days: i64) {
426        let conn = self.conn.clone();
427        let result = tokio::task::spawn_blocking(move || {
428            let cutoff = Utc::now() - Duration::days(ttl_days);
429            let cutoff_str = cutoff.to_rfc3339();
430            let conn = conn.lock().unwrap();
431            conn.execute(
432                "DELETE FROM sessions WHERE last_active < ?1",
433                params![cutoff_str],
434            )
435        })
436        .await
437        .unwrap();
438
439        match result {
440            Ok(count) if count > 0 => {
441                tracing::info!(
442                    "Pruned {} expired session(s) (older than {} days)",
443                    count,
444                    ttl_days
445                );
446            }
447            Err(e) => {
448                error!("Failed to prune expired sessions: {}", e);
449            }
450            _ => {}
451        }
452    }
453}
454
455/// Read a single session row by thread_ts.
456fn row_to_session(conn: &Connection, thread_ts: &str) -> Option<SessionInfo> {
457    conn.query_row(
458        "SELECT thread_ts, session_id, repo, repo_path, agent_kind, channel_id, created_at, last_active, status, total_turns, model FROM sessions WHERE thread_ts = ?1",
459        params![thread_ts],
460        row_mapper,
461    )
462    .ok()
463}
464
465/// Map a row to SessionInfo.
466fn row_mapper(row: &rusqlite::Row) -> rusqlite::Result<SessionInfo> {
467    let thread_ts: String = row.get(0)?;
468    let session_id: String = row.get(1)?;
469    let repo: String = row.get(2)?;
470    let repo_path: String = row.get(3)?;
471    let _agent_kind: String = row.get(4)?;
472    let channel_id: String = row.get(5)?;
473    let created_at: String = row.get(6)?;
474    let last_active: String = row.get(7)?;
475    let status: String = row.get(8)?;
476    let total_turns: u32 = row.get(9)?;
477    let model: Option<String> = row.get(10)?;
478
479    Ok(SessionInfo {
480        session_id,
481        repo,
482        repo_path: PathBuf::from(repo_path),
483        agent_kind: AgentKind::Claude,
484        channel_id,
485        thread_ts,
486        created_at: DateTime::parse_from_rfc3339(&created_at)
487            .map(|dt| dt.with_timezone(&Utc))
488            .unwrap_or_else(|_| Utc::now()),
489        last_active: DateTime::parse_from_rfc3339(&last_active)
490            .map(|dt| dt.with_timezone(&Utc))
491            .unwrap_or_else(|_| Utc::now()),
492        status: SessionStatus::from_str(&status),
493        total_turns,
494        model,
495    })
496}
497
#[cfg(test)]
mod tests {
    use super::*;

    fn make_session(session_id: &str, thread_ts: &str, repo: &str) -> SessionInfo {
        SessionInfo {
            session_id: session_id.to_string(),
            repo: repo.to_string(),
            repo_path: PathBuf::from("/tmp"),
            agent_kind: AgentKind::Claude,
            channel_id: "C123".to_string(),
            thread_ts: thread_ts.to_string(),
            created_at: Utc::now(),
            last_active: Utc::now(),
            status: SessionStatus::Active,
            total_turns: 0,
            model: None,
        }
    }

    /// Process-local counter that gives every test its own database file.
    fn unique_id() -> u64 {
        use std::sync::atomic::{AtomicU64, Ordering};
        static COUNTER: AtomicU64 = AtomicU64::new(0);
        COUNTER.fetch_add(1, Ordering::Relaxed)
    }

    /// Builds a temp-dir database path that no other test will use.
    ///
    /// The process id keeps concurrent test processes (e.g. two checkouts)
    /// from sharing files; the counter keeps tests in one process apart.
    /// Leftovers from an earlier run that panicked before cleanup are removed
    /// first. That includes a stray `.json`, which `SessionStore::new` would
    /// otherwise migrate.
    fn unique_db_path(prefix: &str) -> PathBuf {
        let path = std::env::temp_dir().join(format!(
            "{}_{}_{}.db",
            prefix,
            std::process::id(),
            unique_id()
        ));
        cleanup_db(&path);
        path
    }

    fn temp_store() -> (SessionStore, PathBuf) {
        let path = unique_db_path("hermes_test");
        let store = SessionStore::new(path.clone());
        (store, path)
    }

    /// Best-effort removal of the database, its WAL/SHM side files, and any
    /// legacy JSON file (or its `.bak`) that shares the same stem.
    fn cleanup_db(path: &std::path::Path) {
        for file in [
            path.to_path_buf(),
            path.with_extension("db-wal"),
            path.with_extension("db-shm"),
            path.with_extension("json"),
            path.with_extension("json.bak"),
        ] {
            let _ = std::fs::remove_file(file);
        }
    }

    #[tokio::test]
    async fn test_insert_and_get() {
        let (store, path) = temp_store();
        let session = make_session("s1", "t1", "repo1");
        store.insert(session.clone()).await.unwrap();

        let retrieved = store.get_by_thread("t1").await.unwrap();
        assert_eq!(retrieved.session_id, "s1");
        assert_eq!(retrieved.repo, "repo1");

        assert!(store.get_by_thread("nonexistent").await.is_none());

        cleanup_db(&path);
    }

    #[tokio::test]
    async fn test_insert_replaces_existing_thread() {
        let (store, path) = temp_store();
        store
            .insert(make_session("s1", "t1", "repo1"))
            .await
            .unwrap();
        store
            .insert(make_session("s2", "t1", "repo1"))
            .await
            .unwrap();

        let retrieved = store.get_by_thread("t1").await.unwrap();
        assert_eq!(retrieved.session_id, "s2");
        assert!(!store.has_session_id("s1").await);

        cleanup_db(&path);
    }

    #[tokio::test]
    async fn test_update() {
        let (store, path) = temp_store();
        store
            .insert(make_session("s1", "t1", "repo1"))
            .await
            .unwrap();

        store
            .update("t1", |s| {
                s.total_turns = 5;
                s.status = SessionStatus::Error;
            })
            .await
            .unwrap();

        let retrieved = store.get_by_thread("t1").await.unwrap();
        assert_eq!(retrieved.total_turns, 5);
        assert_eq!(retrieved.status, SessionStatus::Error);

        cleanup_db(&path);
    }

    #[tokio::test]
    async fn test_update_nonexistent_returns_error() {
        let (store, path) = temp_store();
        let result = store.update("nonexistent", |_| {}).await;
        assert!(result.is_err());

        cleanup_db(&path);
    }

    #[tokio::test]
    async fn test_active_sessions() {
        let (store, path) = temp_store();
        store
            .insert(make_session("s1", "t1", "repo1"))
            .await
            .unwrap();

        let mut errored = make_session("s2", "t2", "repo1");
        errored.status = SessionStatus::Error;
        store.insert(errored).await.unwrap();

        let active = store.active_sessions().await;
        assert_eq!(active.len(), 1);
        assert_eq!(active[0].session_id, "s1");

        cleanup_db(&path);
    }

    #[tokio::test]
    async fn test_has_session_id() {
        let (store, path) = temp_store();
        store
            .insert(make_session("s1", "t1", "repo1"))
            .await
            .unwrap();

        assert!(store.has_session_id("s1").await);
        assert!(!store.has_session_id("s999").await);

        cleanup_db(&path);
    }

    #[tokio::test]
    async fn test_persistence_survives_reload() {
        let (store, path) = temp_store();
        store
            .insert(make_session("s1", "t1", "repo1"))
            .await
            .unwrap();

        // Create a new store from the same file.
        let store2 = SessionStore::new(path.clone());
        let retrieved = store2.get_by_thread("t1").await.unwrap();
        assert_eq!(retrieved.session_id, "s1");

        cleanup_db(&path);
    }

    #[tokio::test]
    async fn test_prune_stale_channels() {
        let (store, path) = temp_store();
        store
            .insert(make_session("s1", "t1", "repo1"))
            .await
            .unwrap();

        let mut s2 = make_session("s2", "t2", "repo2");
        s2.channel_id = "C999".to_string();
        store.insert(s2).await.unwrap();

        // Only repo1 with C123 is current.
        let mut repo_channels = HashMap::new();
        repo_channels.insert("repo1".to_string(), "C123".to_string());

        store.prune_stale_channels(&repo_channels).await;

        assert!(store.get_by_thread("t1").await.is_some());
        assert!(store.get_by_thread("t2").await.is_none());

        cleanup_db(&path);
    }

    #[tokio::test]
    async fn test_prune_expired() {
        let (store, path) = temp_store();

        // Recent session — should survive.
        store
            .insert(make_session("s1", "t1", "repo1"))
            .await
            .unwrap();

        // Old session — should be pruned.
        let mut old = make_session("s2", "t2", "repo1");
        old.last_active = Utc::now() - Duration::days(10);
        store.insert(old).await.unwrap();

        store.prune_expired(7).await;

        assert!(store.get_by_thread("t1").await.is_some());
        assert!(store.get_by_thread("t2").await.is_none());

        cleanup_db(&path);
    }

    #[tokio::test]
    async fn test_new_with_nonexistent_file() {
        let path = unique_db_path("hermes_test_nonexistent");
        let store = SessionStore::new(path.clone());

        assert!(store.active_sessions().await.is_empty());

        cleanup_db(&path);
    }

    #[tokio::test]
    async fn test_json_migration() {
        let db_path = unique_db_path("hermes_test_migrate");
        let json_path = db_path.with_extension("json");

        // Create a legacy JSON file
        let mut sessions = HashMap::new();
        sessions.insert("t1".to_string(), make_session("s1", "t1", "repo1"));
        sessions.insert("t2".to_string(), make_session("s2", "t2", "repo2"));
        let json = serde_json::to_string_pretty(&sessions).unwrap();
        std::fs::write(&json_path, &json).unwrap();

        // Open the store — should migrate
        let store = SessionStore::new(db_path.clone());

        // Verify sessions were migrated
        assert!(store.get_by_thread("t1").await.is_some());
        assert!(store.get_by_thread("t2").await.is_some());

        // Verify JSON file was renamed to .bak
        assert!(!json_path.exists());
        assert!(db_path.with_extension("json.bak").exists());

        cleanup_db(&db_path);
    }

    #[tokio::test]
    async fn test_empty_json_migration_moves_file_aside() {
        let db_path = unique_db_path("hermes_test_migrate_empty");
        let json_path = db_path.with_extension("json");
        std::fs::write(&json_path, "{}").unwrap();

        let store = SessionStore::new(db_path.clone());

        assert!(store.active_sessions().await.is_empty());
        assert!(!json_path.exists());
        assert!(db_path.with_extension("json.bak").exists());

        cleanup_db(&db_path);
    }

    #[test]
    fn test_legacy_stopped_status() {
        let parsed: SessionStatus = serde_json::from_str("\"stopped\"").unwrap();
        assert_eq!(parsed, SessionStatus::Stopped);
        assert_eq!(
            serde_json::to_string(&SessionStatus::Stopped).unwrap(),
            "\"active\""
        );
        assert_eq!(SessionStatus::from_str("stopped"), SessionStatus::Stopped);
        assert_eq!(SessionStatus::from_str("unknown"), SessionStatus::Active);
        assert_eq!(SessionStatus::Stopped.as_str(), "active");
    }
}