1use crate::config::AgentKind;
16use crate::error::{HermesError, Result};
17use chrono::{DateTime, Duration, Utc};
18use rusqlite::{Connection, params};
19use serde::{Deserialize, Serialize};
20use std::collections::HashMap;
21use std::path::PathBuf;
22use std::sync::{Arc, Mutex};
23use tracing::error;
24
25fn serialize_as_active<S: serde::Serializer>(s: S) -> std::result::Result<S::Ok, S::Error> {
26 s.serialize_str("active")
27}
28
/// Lifecycle state of a session.
///
/// Serialized with serde in lowercase (`"active"`, `"error"`). `Stopped` deserializes
/// from `"stopped"` but always serializes as `"active"` (via `serialize_as_active`), and
/// `SessionStatus::as_str` likewise stores it in SQLite as `"active"`, so a `Stopped`
/// value never survives a round trip.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum SessionStatus {
    /// Normal state; also what `SessionStatus::from_str` yields for any unrecognised text.
    Active,
    /// Errored session; `SessionStore::active_sessions` filters these rows out.
    Error,
    /// NOTE(review): looks like a legacy/compatibility state — it can be read, but every
    /// write path collapses it into `"active"`. Confirm before giving it distinct meaning.
    #[serde(
        rename(deserialize = "stopped"),
        serialize_with = "serialize_as_active"
    )]
    Stopped,
}
42
43impl SessionStatus {
44 fn as_str(&self) -> &'static str {
45 match self {
46 SessionStatus::Active | SessionStatus::Stopped => "active",
47 SessionStatus::Error => "error",
48 }
49 }
50
51 fn from_str(s: &str) -> Self {
52 match s {
53 "error" => SessionStatus::Error,
54 "stopped" => SessionStatus::Stopped,
55 _ => SessionStatus::Active,
56 }
57 }
58}
59
/// A persisted agent session, keyed by the thread (`thread_ts`) it belongs to.
///
/// `SessionStore` keeps each one as a row of the `sessions` table; the legacy JSON
/// session file it migrates from is a map of these keyed by `thread_ts`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionInfo {
    /// Session identifier; indexed in SQLite and looked up by `SessionStore::has_session_id`.
    pub session_id: String,
    /// Repository name; stale-channel pruning looks it up in the configured repo → channel map.
    pub repo: String,
    /// Repository path; persisted via `to_string_lossy`, so non-UTF-8 bytes are replaced.
    pub repo_path: PathBuf,
    /// NOTE(review): `SessionStore` currently writes this column as `"claude"` and reloads
    /// it as `AgentKind::Claude`, so any other value is lost on persistence.
    pub agent_kind: AgentKind,
    /// Channel identifier; the session is pruned as stale when it differs from the
    /// channel configured for its repo.
    pub channel_id: String,
    /// Thread timestamp; primary key of the `sessions` table.
    pub thread_ts: String,
    /// Creation time, stored as RFC 3339 text.
    pub created_at: DateTime<Utc>,
    /// Last activity time, stored as RFC 3339 text; `SessionStore::prune_expired`
    /// deletes sessions whose value is older than the TTL.
    pub last_active: DateTime<Utc>,
    /// Lifecycle state (see `SessionStatus`).
    pub status: SessionStatus,
    /// Counter persisted in the `total_turns` column.
    pub total_turns: u32,
    /// Optional model string (nullable `model` column). `#[serde(default)]` lets legacy
    /// JSON without this field deserialize as `None`.
    #[serde(default)]
    pub model: Option<String>,
}
82
/// SQLite-backed persistence for `SessionInfo` records.
///
/// Clones share one connection through `Arc<Mutex<_>>`, so every database operation,
/// from any clone, is serialized through the same lock.
#[derive(Clone)]
pub struct SessionStore {
    /// The single shared connection; the async methods lock it inside
    /// `tokio::task::spawn_blocking` tasks.
    conn: Arc<Mutex<Connection>>,
}
97
/// DDL executed every time a `SessionStore` is opened; `IF NOT EXISTS` makes it idempotent.
///
/// `thread_ts` is the primary key, so there is at most one session per thread. The
/// index on `session_id` serves `SessionStore::has_session_id`. Timestamps are stored
/// as RFC 3339 text, which `prune_expired` compares as strings.
const SCHEMA: &str = "
CREATE TABLE IF NOT EXISTS sessions (
    thread_ts TEXT PRIMARY KEY,
    session_id TEXT NOT NULL,
    repo TEXT NOT NULL,
    repo_path TEXT NOT NULL,
    agent_kind TEXT NOT NULL DEFAULT 'claude',
    channel_id TEXT NOT NULL,
    created_at TEXT NOT NULL,
    last_active TEXT NOT NULL,
    status TEXT NOT NULL DEFAULT 'active',
    total_turns INTEGER NOT NULL DEFAULT 0,
    model TEXT
);
CREATE INDEX IF NOT EXISTS idx_sessions_session_id ON sessions(session_id);
";
114
115impl SessionStore {
116 pub fn new(path: PathBuf) -> Self {
129 let conn = Connection::open(&path).unwrap_or_else(|e| {
130 panic!("Failed to open SQLite database '{}': {}", path.display(), e);
131 });
132
133 conn.execute_batch("PRAGMA journal_mode=WAL; PRAGMA busy_timeout=5000;")
134 .unwrap_or_else(|e| {
135 panic!("Failed to set SQLite pragmas: {}", e);
136 });
137
138 conn.execute_batch(SCHEMA).unwrap_or_else(|e| {
139 panic!("Failed to create sessions schema: {}", e);
140 });
141
142 let store = Self {
143 conn: Arc::new(Mutex::new(conn)),
144 };
145
146 store.migrate_from_json(&path);
148
149 store
150 }
151
152 fn migrate_from_json(&self, db_path: &std::path::Path) {
154 let json_path = db_path.with_extension("json");
156 let candidates = [json_path];
159
160 for candidate in &candidates {
161 if !candidate.exists() {
162 continue;
163 }
164
165 let contents = match std::fs::read_to_string(candidate) {
166 Ok(c) => c,
167 Err(e) => {
168 tracing::warn!(
169 "Found legacy session file '{}' but failed to read it: {}",
170 candidate.display(),
171 e
172 );
173 continue;
174 }
175 };
176
177 let sessions: HashMap<String, SessionInfo> = match serde_json::from_str(&contents) {
178 Ok(s) => s,
179 Err(e) => {
180 tracing::warn!(
181 "Found legacy session file '{}' but failed to parse it: {}",
182 candidate.display(),
183 e
184 );
185 continue;
186 }
187 };
188
189 if sessions.is_empty() {
190 let backup = candidate.with_extension("json.bak");
192 if let Err(e) = std::fs::rename(candidate, &backup) {
193 tracing::warn!("Failed to rename empty legacy file: {}", e);
194 }
195 continue;
196 }
197
198 let conn = self.conn.lock().unwrap();
199 let result = (|| -> std::result::Result<usize, rusqlite::Error> {
200 let tx = conn.unchecked_transaction()?;
201 let mut count = 0;
202 for (thread_ts, session) in &sessions {
203 tx.execute(
204 "INSERT OR IGNORE INTO sessions (thread_ts, session_id, repo, repo_path, agent_kind, channel_id, created_at, last_active, status, total_turns, model) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11)",
205 params![
206 thread_ts,
207 session.session_id,
208 session.repo,
209 session.repo_path.to_string_lossy().to_string(),
210 "claude",
211 session.channel_id,
212 session.created_at.to_rfc3339(),
213 session.last_active.to_rfc3339(),
214 session.status.as_str(),
215 session.total_turns,
216 session.model,
217 ],
218 )?;
219 count += 1;
220 }
221 tx.commit()?;
222 Ok(count)
223 })();
224
225 drop(conn);
226
227 match result {
228 Ok(count) => {
229 tracing::info!(
230 "Migrated {} session(s) from '{}' to SQLite",
231 count,
232 candidate.display()
233 );
234 let backup = candidate.with_extension("json.bak");
235 if let Err(e) = std::fs::rename(candidate, &backup) {
236 tracing::warn!(
237 "Failed to rename '{}' to '{}': {}",
238 candidate.display(),
239 backup.display(),
240 e
241 );
242 }
243 }
244 Err(e) => {
245 tracing::warn!(
246 "Failed to migrate sessions from '{}': {} (continuing without migration)",
247 candidate.display(),
248 e
249 );
250 }
251 }
252 }
253 }
254
255 #[must_use = "session insert errors mean the session won't persist across restarts"]
257 pub async fn insert(&self, session: SessionInfo) -> Result<()> {
258 let conn = self.conn.clone();
259 tokio::task::spawn_blocking(move || {
260 let conn = conn.lock().unwrap();
261 conn.execute(
262 "INSERT OR REPLACE INTO sessions (thread_ts, session_id, repo, repo_path, agent_kind, channel_id, created_at, last_active, status, total_turns, model) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11)",
263 params![
264 session.thread_ts,
265 session.session_id,
266 session.repo,
267 session.repo_path.to_string_lossy().to_string(),
268 "claude",
269 session.channel_id,
270 session.created_at.to_rfc3339(),
271 session.last_active.to_rfc3339(),
272 session.status.as_str(),
273 session.total_turns,
274 session.model,
275 ],
276 )?;
277 Ok(())
278 })
279 .await
280 .unwrap()
281 }
282
283 pub async fn get_by_thread(&self, thread_ts: &str) -> Option<SessionInfo> {
285 let conn = self.conn.clone();
286 let thread_ts = thread_ts.to_string();
287 tokio::task::spawn_blocking(move || {
288 let conn = conn.lock().unwrap();
289 row_to_session(&conn, &thread_ts)
290 })
291 .await
292 .unwrap()
293 }
294
295 #[must_use = "session update errors mean changes won't persist to disk"]
306 pub async fn update<F>(&self, thread_ts: &str, f: F) -> Result<()>
307 where
308 F: FnOnce(&mut SessionInfo) + Send + 'static,
309 {
310 let conn = self.conn.clone();
311 let thread_ts = thread_ts.to_string();
312 tokio::task::spawn_blocking(move || {
313 let conn = conn.lock().unwrap();
314 let mut session = row_to_session(&conn, &thread_ts)
315 .ok_or_else(|| HermesError::SessionNotFound(thread_ts.clone()))?;
316 f(&mut session);
317 conn.execute(
318 "UPDATE sessions SET session_id=?1, repo=?2, repo_path=?3, agent_kind=?4, channel_id=?5, created_at=?6, last_active=?7, status=?8, total_turns=?9, model=?10 WHERE thread_ts=?11",
319 params![
320 session.session_id,
321 session.repo,
322 session.repo_path.to_string_lossy().to_string(),
323 "claude",
324 session.channel_id,
325 session.created_at.to_rfc3339(),
326 session.last_active.to_rfc3339(),
327 session.status.as_str(),
328 session.total_turns,
329 session.model,
330 thread_ts,
331 ],
332 )?;
333 Ok(())
334 })
335 .await
336 .unwrap()
337 }
338
339 pub async fn active_sessions(&self) -> Vec<SessionInfo> {
340 let conn = self.conn.clone();
341 tokio::task::spawn_blocking(move || {
342 let conn = conn.lock().unwrap();
343 let mut stmt = conn
344 .prepare("SELECT thread_ts, session_id, repo, repo_path, agent_kind, channel_id, created_at, last_active, status, total_turns, model FROM sessions WHERE status != 'error'")
345 .unwrap();
346 stmt.query_map([], row_mapper)
347 .unwrap()
348 .filter_map(|r| r.ok())
349 .collect()
350 })
351 .await
352 .unwrap()
353 }
354
355 pub async fn has_session_id(&self, session_id: &str) -> bool {
357 let conn = self.conn.clone();
358 let session_id = session_id.to_string();
359 tokio::task::spawn_blocking(move || {
360 let conn = conn.lock().unwrap();
361 let exists: bool = conn
362 .query_row(
363 "SELECT EXISTS(SELECT 1 FROM sessions WHERE session_id = ?1 LIMIT 1)",
364 params![session_id],
365 |row| row.get(0),
366 )
367 .unwrap_or(false);
368 exists
369 })
370 .await
371 .unwrap()
372 }
373
374 pub async fn prune_stale_channels(&self, repo_channels: &HashMap<String, String>) {
376 let conn = self.conn.clone();
377 let repo_channels = repo_channels.clone();
378 let result = tokio::task::spawn_blocking(move || {
379 let conn = conn.lock().unwrap();
380 let mut stmt = conn
382 .prepare("SELECT thread_ts, repo, channel_id FROM sessions")
383 .unwrap();
384 let stale: Vec<String> = stmt
385 .query_map([], |row| {
386 Ok((
387 row.get::<_, String>(0)?,
388 row.get::<_, String>(1)?,
389 row.get::<_, String>(2)?,
390 ))
391 })
392 .unwrap()
393 .filter_map(|r| r.ok())
394 .filter(|(_, repo, channel_id)| match repo_channels.get(repo) {
395 Some(current_channel) => channel_id != current_channel,
396 None => true, })
398 .map(|(thread_ts, _, _)| thread_ts)
399 .collect();
400
401 if stale.is_empty() {
402 return 0usize;
403 }
404
405 let count = stale.len();
406 for thread_ts in &stale {
407 if let Err(e) = conn.execute(
408 "DELETE FROM sessions WHERE thread_ts = ?1",
409 params![thread_ts],
410 ) {
411 error!("Failed to delete stale session '{}': {}", thread_ts, e);
412 }
413 }
414 count
415 })
416 .await
417 .unwrap();
418
419 if result > 0 {
420 tracing::info!("Pruned {} stale session(s) from previous run", result);
421 }
422 }
423
424 pub async fn prune_expired(&self, ttl_days: i64) {
426 let conn = self.conn.clone();
427 let result = tokio::task::spawn_blocking(move || {
428 let cutoff = Utc::now() - Duration::days(ttl_days);
429 let cutoff_str = cutoff.to_rfc3339();
430 let conn = conn.lock().unwrap();
431 conn.execute(
432 "DELETE FROM sessions WHERE last_active < ?1",
433 params![cutoff_str],
434 )
435 })
436 .await
437 .unwrap();
438
439 match result {
440 Ok(count) if count > 0 => {
441 tracing::info!(
442 "Pruned {} expired session(s) (older than {} days)",
443 count,
444 ttl_days
445 );
446 }
447 Err(e) => {
448 error!("Failed to prune expired sessions: {}", e);
449 }
450 _ => {}
451 }
452 }
453}
454
455fn row_to_session(conn: &Connection, thread_ts: &str) -> Option<SessionInfo> {
457 conn.query_row(
458 "SELECT thread_ts, session_id, repo, repo_path, agent_kind, channel_id, created_at, last_active, status, total_turns, model FROM sessions WHERE thread_ts = ?1",
459 params![thread_ts],
460 row_mapper,
461 )
462 .ok()
463}
464
465fn row_mapper(row: &rusqlite::Row) -> rusqlite::Result<SessionInfo> {
467 let thread_ts: String = row.get(0)?;
468 let session_id: String = row.get(1)?;
469 let repo: String = row.get(2)?;
470 let repo_path: String = row.get(3)?;
471 let _agent_kind: String = row.get(4)?;
472 let channel_id: String = row.get(5)?;
473 let created_at: String = row.get(6)?;
474 let last_active: String = row.get(7)?;
475 let status: String = row.get(8)?;
476 let total_turns: u32 = row.get(9)?;
477 let model: Option<String> = row.get(10)?;
478
479 Ok(SessionInfo {
480 session_id,
481 repo,
482 repo_path: PathBuf::from(repo_path),
483 agent_kind: AgentKind::Claude,
484 channel_id,
485 thread_ts,
486 created_at: DateTime::parse_from_rfc3339(&created_at)
487 .map(|dt| dt.with_timezone(&Utc))
488 .unwrap_or_else(|_| Utc::now()),
489 last_active: DateTime::parse_from_rfc3339(&last_active)
490 .map(|dt| dt.with_timezone(&Utc))
491 .unwrap_or_else(|_| Utc::now()),
492 status: SessionStatus::from_str(&status),
493 total_turns,
494 model,
495 })
496}
497
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::Path;

    /// Builds an active Claude session in channel `C123` with zero turns and no model.
    fn make_session(session_id: &str, thread_ts: &str, repo: &str) -> SessionInfo {
        SessionInfo {
            session_id: session_id.to_string(),
            repo: repo.to_string(),
            repo_path: PathBuf::from("/tmp"),
            agent_kind: AgentKind::Claude,
            channel_id: "C123".to_string(),
            thread_ts: thread_ts.to_string(),
            created_at: Utc::now(),
            last_active: Utc::now(),
            status: SessionStatus::Active,
            total_turns: 0,
            model: None,
        }
    }

    /// Process-local counter that distinguishes databases created by tests in this run.
    fn unique_id() -> u64 {
        use std::sync::atomic::{AtomicU64, Ordering};
        static COUNTER: AtomicU64 = AtomicU64::new(0);
        COUNTER.fetch_add(1, Ordering::Relaxed)
    }

    /// Returns a temp-dir database path unique across concurrently running test
    /// processes (via the PID) and across tests in this process (via `unique_id`).
    ///
    /// Leftover files from an earlier, possibly crashed, run are removed first, so a
    /// test never starts against stale rows.
    fn temp_db_path(prefix: &str) -> PathBuf {
        let path = std::env::temp_dir().join(format!(
            "{}_{}_{}.db",
            prefix,
            std::process::id(),
            unique_id()
        ));
        cleanup_db(&path);
        path
    }

    fn temp_store() -> (SessionStore, PathBuf) {
        let path = temp_db_path("hermes_test");
        let store = SessionStore::new(path.clone());
        (store, path)
    }

    /// Removes the database file and its WAL / shared-memory side files.
    fn cleanup_db(path: &Path) {
        let _ = std::fs::remove_file(path);
        let _ = std::fs::remove_file(path.with_extension("db-wal"));
        let _ = std::fs::remove_file(path.with_extension("db-shm"));
    }

    #[tokio::test]
    async fn test_insert_and_get() {
        let (store, path) = temp_store();
        let session = make_session("s1", "t1", "repo1");
        store.insert(session.clone()).await.unwrap();

        let retrieved = store.get_by_thread("t1").await.unwrap();
        assert_eq!(retrieved.session_id, "s1");
        assert_eq!(retrieved.repo, "repo1");

        assert!(store.get_by_thread("nonexistent").await.is_none());

        cleanup_db(&path);
    }

    #[tokio::test]
    async fn test_update() {
        let (store, path) = temp_store();
        store
            .insert(make_session("s1", "t1", "repo1"))
            .await
            .unwrap();

        store
            .update("t1", |s| {
                s.total_turns = 5;
                s.status = SessionStatus::Error;
            })
            .await
            .unwrap();

        let retrieved = store.get_by_thread("t1").await.unwrap();
        assert_eq!(retrieved.total_turns, 5);
        assert_eq!(retrieved.status, SessionStatus::Error);

        cleanup_db(&path);
    }

    #[tokio::test]
    async fn test_update_nonexistent_returns_error() {
        let (store, path) = temp_store();
        let result = store.update("nonexistent", |_| {}).await;
        assert!(result.is_err());

        cleanup_db(&path);
    }

    #[tokio::test]
    async fn test_stopped_status_is_persisted_as_active() {
        let (store, path) = temp_store();
        let mut stopped = make_session("s1", "t1", "repo1");
        stopped.status = SessionStatus::Stopped;
        store.insert(stopped).await.unwrap();

        let retrieved = store.get_by_thread("t1").await.unwrap();
        assert_eq!(retrieved.status, SessionStatus::Active);

        cleanup_db(&path);
    }

    #[tokio::test]
    async fn test_active_sessions() {
        let (store, path) = temp_store();
        store
            .insert(make_session("s1", "t1", "repo1"))
            .await
            .unwrap();

        let mut errored = make_session("s2", "t2", "repo1");
        errored.status = SessionStatus::Error;
        store.insert(errored).await.unwrap();

        let active = store.active_sessions().await;
        assert_eq!(active.len(), 1);
        assert_eq!(active[0].session_id, "s1");

        cleanup_db(&path);
    }

    #[tokio::test]
    async fn test_has_session_id() {
        let (store, path) = temp_store();
        store
            .insert(make_session("s1", "t1", "repo1"))
            .await
            .unwrap();

        assert!(store.has_session_id("s1").await);
        assert!(!store.has_session_id("s999").await);

        cleanup_db(&path);
    }

    #[tokio::test]
    async fn test_persistence_survives_reload() {
        let (store, path) = temp_store();
        store
            .insert(make_session("s1", "t1", "repo1"))
            .await
            .unwrap();

        // A second store on the same file must see rows written by the first.
        let store2 = SessionStore::new(path.clone());
        let retrieved = store2.get_by_thread("t1").await.unwrap();
        assert_eq!(retrieved.session_id, "s1");

        cleanup_db(&path);
    }

    #[tokio::test]
    async fn test_prune_stale_channels() {
        let (store, path) = temp_store();
        store
            .insert(make_session("s1", "t1", "repo1"))
            .await
            .unwrap();

        // repo2 is not configured at all -> stale.
        let mut s2 = make_session("s2", "t2", "repo2");
        s2.channel_id = "C999".to_string();
        store.insert(s2).await.unwrap();

        // repo1 is configured, but this thread lives in a different channel -> stale.
        let mut s3 = make_session("s3", "t3", "repo1");
        s3.channel_id = "C999".to_string();
        store.insert(s3).await.unwrap();

        let mut repo_channels = HashMap::new();
        repo_channels.insert("repo1".to_string(), "C123".to_string());

        store.prune_stale_channels(&repo_channels).await;

        assert!(store.get_by_thread("t1").await.is_some());
        assert!(store.get_by_thread("t2").await.is_none());
        assert!(store.get_by_thread("t3").await.is_none());

        cleanup_db(&path);
    }

    #[tokio::test]
    async fn test_prune_expired() {
        let (store, path) = temp_store();

        store
            .insert(make_session("s1", "t1", "repo1"))
            .await
            .unwrap();

        let mut old = make_session("s2", "t2", "repo1");
        old.last_active = Utc::now() - Duration::days(10);
        store.insert(old).await.unwrap();

        store.prune_expired(7).await;

        assert!(store.get_by_thread("t1").await.is_some());
        assert!(store.get_by_thread("t2").await.is_none());

        cleanup_db(&path);
    }

    #[tokio::test]
    async fn test_new_with_nonexistent_file() {
        let path = temp_db_path("hermes_test_nonexistent");
        let store = SessionStore::new(path.clone());

        assert!(store.active_sessions().await.is_empty());

        cleanup_db(&path);
    }

    #[tokio::test]
    async fn test_json_migration() {
        let db_path = temp_db_path("hermes_test_migrate");
        let json_path = db_path.with_extension("json");
        let backup_path = db_path.with_extension("json.bak");
        let _ = std::fs::remove_file(&backup_path);

        let mut sessions = HashMap::new();
        sessions.insert("t1".to_string(), make_session("s1", "t1", "repo1"));
        sessions.insert("t2".to_string(), make_session("s2", "t2", "repo2"));
        let json = serde_json::to_string_pretty(&sessions).unwrap();
        std::fs::write(&json_path, &json).unwrap();

        let store = SessionStore::new(db_path.clone());

        assert!(store.get_by_thread("t1").await.is_some());
        assert!(store.get_by_thread("t2").await.is_some());

        // The legacy file is moved aside so the migration does not run again.
        assert!(!json_path.exists());
        assert!(backup_path.exists());

        cleanup_db(&db_path);
        let _ = std::fs::remove_file(&backup_path);
    }

    #[tokio::test]
    async fn test_json_migration_keeps_existing_rows() {
        let (store, db_path) = temp_store();
        let json_path = db_path.with_extension("json");
        let backup_path = db_path.with_extension("json.bak");

        store
            .insert(make_session("live", "t1", "repo1"))
            .await
            .unwrap();

        let mut sessions = HashMap::new();
        sessions.insert("t1".to_string(), make_session("legacy", "t1", "repo1"));
        sessions.insert("t2".to_string(), make_session("s2", "t2", "repo1"));
        std::fs::write(&json_path, serde_json::to_string(&sessions).unwrap()).unwrap();

        // Reopening triggers the migration; the existing t1 row must win.
        let reopened = SessionStore::new(db_path.clone());
        assert_eq!(
            reopened.get_by_thread("t1").await.unwrap().session_id,
            "live"
        );
        assert!(reopened.get_by_thread("t2").await.is_some());

        cleanup_db(&db_path);
        let _ = std::fs::remove_file(&json_path);
        let _ = std::fs::remove_file(&backup_path);
    }
}