// rab/agent/session_storage.rs
//! Low-level session persistence abstraction — Pi-compatible `SessionStorage`.
//!
//! Pi architecture:
//!
//! ```text
//! SessionStorage (trait) ← InMemorySessionStorage / JsonlSessionStorage
//! Session (struct)       ← wraps SessionStorage, provides high-level API
//! AgentHarness           ← owns Session, drives agent loop
//! ```
//!
//! This module provides the trait and both implementations.
//! The `Session` struct lives in `session.rs`.
10
11use crate::agent::session::{
12    LeafEntry, SessionEntry, SessionHeader, append_entry_to_file, generate_entry_id,
13    load_session_from_file,
14};
15use std::path::{Path, PathBuf};
16
// ── SessionMetadata ────────────────────────────────────────────────

/// Descriptive information about a session, taken from its header line.
///
/// Mirrors Pi's session metadata object: identity, creation time, working
/// directory, and — for persisted or forked sessions — where the data lives.
#[derive(Debug, Clone)]
pub struct SessionMetadata {
    /// Session identifier (the header's `id`).
    pub id: String,
    /// Creation timestamp as recorded in the header's `timestamp`.
    pub created_at: String,
    /// Working directory recorded for the session.
    pub cwd: String,
    /// File path on disk when the session is persisted (set by `JsonlSessionStorage`).
    pub path: Option<PathBuf>,
    /// Path of the session this one was forked from, if any.
    pub parent_session_path: Option<String>,
}
31
32// ── SessionStorage trait ───────────────────────────────────────────
33
34/// Low-level CRUD abstraction for session persistence.
35///
36/// Pi-compatible: provides leaf management, label tracking, path queries,
37/// and entry CRUD. `Session` builds on this for the high-level API.
38pub trait SessionStorage: Send {
39    /// Return header-derived metadata.
40    fn metadata(&self) -> SessionMetadata;
41
42    /// Get the current leaf entry ID (the last non-leaf entry, resolved through leaf entries).
43    /// Returns `None` if no entries exist.
44    fn get_leaf_id(&self) -> Option<String>;
45
46    /// Persist a leaf entry that records the active session-tree leaf.
47    /// `None` means reset to no leaf.
48    fn set_leaf_id(&mut self, leaf_id: Option<&str>) -> Result<(), String>;
49
50    /// Generate a unique 8-character hex entry ID, collision-checked.
51    fn create_entry_id(&self) -> String;
52
53    /// Append a fully-constructed entry. Updates in-memory state and persists to disk.
54    fn append_entry(&mut self, entry: SessionEntry) -> Result<(), String>;
55
56    /// Look up an entry by ID.
57    fn get_entry(&self, id: &str) -> Option<SessionEntry>;
58
59    /// Find all entries of the given `type` string.
60    fn find_entries(&self, type_name: &str) -> Vec<SessionEntry>;
61
62    /// Get the human-readable label for an entry, if any.
63    fn get_label(&self, id: &str) -> Option<String>;
64
65    /// Walk from `leaf_id` (or current leaf, if None) to root, returning entries in path order.
66    fn get_path_to_root(&self, leaf_id: Option<&str>) -> Result<Vec<SessionEntry>, String>;
67
68    /// Return all entries in insertion order.
69    fn get_entries(&self) -> Vec<SessionEntry>;
70
71    /// The file path on disk, if this storage is file-backed.
72    fn path(&self) -> Option<&Path>;
73}
74
75// ── Helpers shared by both implementations ─────────────────────────
76
77/// Given an entry, return the effective leaf ID after it.
78/// For `Leaf` entries, returns `targetId`; for all others, returns `entry.id`.
79fn leaf_id_after_entry(entry: &SessionEntry) -> Option<String> {
80    match entry {
81        SessionEntry::Leaf(e) => e.target_id.clone(),
82        _ => Some(entry.id().to_string()),
83    }
84}
85
86/// Update the label cache from an entry (call after every append).
87fn update_label_cache(
88    labels_by_id: &mut std::collections::HashMap<String, String>,
89    entry: &SessionEntry,
90) {
91    if let SessionEntry::Label(e) = entry {
92        if let Some(label) = &e.label {
93            let trimmed = label.trim();
94            if trimmed.is_empty() {
95                labels_by_id.remove(&e.target_id);
96            } else {
97                labels_by_id.insert(e.target_id.clone(), trimmed.to_string());
98            }
99        } else {
100            labels_by_id.remove(&e.target_id);
101        }
102    }
103}
104
105/// Build a label cache from a slice of entries.
106fn build_labels_by_id(entries: &[SessionEntry]) -> std::collections::HashMap<String, String> {
107    let mut labels = std::collections::HashMap::new();
108    for entry in entries {
109        update_label_cache(&mut labels, entry);
110    }
111    labels
112}
113
114// ── InMemorySessionStorage ─────────────────────────────────────────
115
116/// Fully in-memory storage — no file I/O.
117/// Pi-compatible: owns all state (entries, labels, leaf).
118pub struct InMemorySessionStorage {
119    metadata: SessionMetadata,
120    entries: Vec<SessionEntry>,
121    by_id: std::collections::HashMap<String, SessionEntry>,
122    labels_by_id: std::collections::HashMap<String, String>,
123    leaf_id: Option<String>,
124}
125
126impl InMemorySessionStorage {
127    /// Create empty storage with explicit metadata.
128    pub fn new(metadata: SessionMetadata) -> Self {
129        Self {
130            metadata,
131            entries: Vec::new(),
132            by_id: std::collections::HashMap::new(),
133            labels_by_id: std::collections::HashMap::new(),
134            leaf_id: None,
135        }
136    }
137}
138
139impl SessionStorage for InMemorySessionStorage {
140    fn metadata(&self) -> SessionMetadata {
141        self.metadata.clone()
142    }
143
144    fn get_leaf_id(&self) -> Option<String> {
145        self.leaf_id.clone()
146    }
147
148    fn set_leaf_id(&mut self, leaf_id: Option<&str>) -> Result<(), String> {
149        if let Some(id) = leaf_id
150            && !self.by_id.contains_key(id)
151        {
152            return Err(format!("Entry {} not found", id));
153        }
154        let entry = SessionEntry::Leaf(LeafEntry {
155            id: self.create_entry_id(),
156            parent_id: self.leaf_id.clone(),
157            timestamp: chrono::Utc::now().to_rfc3339(),
158            target_id: leaf_id.map(|s| s.to_string()),
159        });
160        self.leaf_id = leaf_id.map(|s| s.to_string());
161        self.entries.push(entry.clone());
162        self.by_id.insert(entry.id().to_string(), entry);
163        Ok(())
164    }
165
166    fn create_entry_id(&self) -> String {
167        generate_entry_id(&self.by_id)
168    }
169
170    fn append_entry(&mut self, entry: SessionEntry) -> Result<(), String> {
171        let id = entry.id().to_string();
172        self.by_id.insert(id.clone(), entry);
173        self.entries
174            .push(self.by_id.get(&id).expect("just inserted").clone());
175        self.leaf_id = leaf_id_after_entry(self.by_id.get(&id).expect("just inserted"));
176        update_label_cache(
177            &mut self.labels_by_id,
178            self.by_id.get(&id).expect("just inserted"),
179        );
180        Ok(())
181    }
182
183    fn get_entry(&self, id: &str) -> Option<SessionEntry> {
184        self.by_id.get(id).cloned()
185    }
186
187    fn find_entries(&self, type_name: &str) -> Vec<SessionEntry> {
188        self.entries
189            .iter()
190            .filter(|e| entry_type_name(e) == type_name)
191            .cloned()
192            .collect()
193    }
194
195    fn get_label(&self, id: &str) -> Option<String> {
196        self.labels_by_id.get(id).cloned()
197    }
198
199    fn get_path_to_root(&self, leaf_id: Option<&str>) -> Result<Vec<SessionEntry>, String> {
200        let start_id = leaf_id.or(self.leaf_id.as_deref());
201        if start_id.is_none() {
202            return Ok(vec![]);
203        }
204        let sid = start_id.unwrap();
205        let mut path: Vec<SessionEntry> = Vec::new();
206        let mut current = self.by_id.get(sid);
207        if current.is_none() {
208            return Err(format!("Entry {} not found", sid));
209        }
210        while let Some(entry) = current {
211            path.push(entry.clone());
212            match entry.parent_id() {
213                Some(pid) => {
214                    current = self.by_id.get(pid);
215                }
216                None => break,
217            }
218        }
219        path.reverse();
220        Ok(path)
221    }
222
223    fn get_entries(&self) -> Vec<SessionEntry> {
224        self.entries.clone()
225    }
226
227    fn path(&self) -> Option<&Path> {
228        None
229    }
230}
231
232// ── JsonlSessionStorage ────────────────────────────────────────────
233
234/// File-backed storage: holds full state in memory and persists to a JSONL file.
235/// Pi-compatible: loads from file on creation, appends on every write.
236pub struct JsonlSessionStorage {
237    metadata: SessionMetadata,
238    file_path: PathBuf,
239    entries: Vec<SessionEntry>,
240    by_id: std::collections::HashMap<String, SessionEntry>,
241    labels_by_id: std::collections::HashMap<String, String>,
242    leaf_id: Option<String>,
243}
244
245impl JsonlSessionStorage {
246    /// Create a new session at the given path. Writes the header.
247    pub fn create(
248        file_path: PathBuf,
249        cwd: &str,
250        session_id: &str,
251        parent_session_path: Option<String>,
252    ) -> Result<Self, String> {
253        let created_at = chrono::Utc::now().to_rfc3339();
254        let header = SessionHeader {
255            type_: "session".to_string(),
256            version: Some(crate::agent::session::CURRENT_SESSION_VERSION),
257            id: session_id.to_string(),
258            timestamp: created_at.clone(),
259            cwd: cwd.to_string(),
260            parent_session: parent_session_path.clone(),
261        };
262
263        // Ensure parent directory exists
264        if let Some(parent) = file_path.parent() {
265            std::fs::create_dir_all(parent)
266                .map_err(|e| format!("Failed to create session directory: {}", e))?;
267        }
268
269        // Write header
270        let header_json = serde_json::to_string(&header)
271            .map_err(|e| format!("Failed to serialize header: {}", e))?;
272        std::fs::write(&file_path, header_json + "\n")
273            .map_err(|e| format!("Failed to write session file: {}", e))?;
274
275        let metadata = SessionMetadata {
276            id: session_id.to_string(),
277            created_at,
278            cwd: cwd.to_string(),
279            path: Some(file_path.clone()),
280            parent_session_path,
281        };
282
283        Ok(Self {
284            metadata,
285            file_path,
286            entries: Vec::new(),
287            by_id: std::collections::HashMap::new(),
288            labels_by_id: std::collections::HashMap::new(),
289            leaf_id: None,
290        })
291    }
292
293    /// Open an existing session file. Loads all entries into memory.
294    pub fn open(file_path: PathBuf) -> Result<Self, String> {
295        let (header, entries) = load_session_from_file(&file_path);
296        let header = header
297            .ok_or_else(|| format!("Invalid or missing session header: {}", file_path.display()))?;
298
299        let metadata = SessionMetadata {
300            id: header.id.clone(),
301            created_at: header.timestamp.clone(),
302            cwd: header.cwd,
303            path: Some(file_path.clone()),
304            parent_session_path: header.parent_session,
305        };
306
307        let by_id: std::collections::HashMap<_, _> = entries
308            .iter()
309            .map(|e| (e.id().to_string(), e.clone()))
310            .collect();
311        let labels_by_id = build_labels_by_id(&entries);
312        let leaf_id = entries.last().and_then(leaf_id_after_entry);
313
314        Ok(Self {
315            metadata,
316            file_path,
317            entries,
318            by_id,
319            labels_by_id,
320            leaf_id,
321        })
322    }
323
324    /// Append a line to the file.
325    fn append_to_file(&self, entry: &SessionEntry) -> Result<(), String> {
326        append_entry_to_file(&self.file_path, entry)
327            .map_err(|e| format!("Failed to append session entry: {}", e))
328    }
329}
330
331impl SessionStorage for JsonlSessionStorage {
332    fn metadata(&self) -> SessionMetadata {
333        self.metadata.clone()
334    }
335
336    fn get_leaf_id(&self) -> Option<String> {
337        self.leaf_id.clone()
338    }
339
340    fn set_leaf_id(&mut self, leaf_id: Option<&str>) -> Result<(), String> {
341        if let Some(id) = leaf_id
342            && !self.by_id.contains_key(id)
343        {
344            return Err(format!("Entry {} not found", id));
345        }
346        let entry = SessionEntry::Leaf(LeafEntry {
347            id: self.create_entry_id(),
348            parent_id: self.leaf_id.clone(),
349            timestamp: chrono::Utc::now().to_rfc3339(),
350            target_id: leaf_id.map(|s| s.to_string()),
351        });
352        self.append_to_file(&entry)?;
353        self.leaf_id = leaf_id.map(|s| s.to_string());
354        self.entries.push(entry.clone());
355        self.by_id.insert(entry.id().to_string(), entry);
356        Ok(())
357    }
358
359    fn create_entry_id(&self) -> String {
360        generate_entry_id(&self.by_id)
361    }
362
363    fn append_entry(&mut self, entry: SessionEntry) -> Result<(), String> {
364        self.append_to_file(&entry)?;
365        let id = entry.id().to_string();
366        self.by_id.insert(id.clone(), entry);
367        self.entries
368            .push(self.by_id.get(&id).expect("just inserted").clone());
369        self.leaf_id = leaf_id_after_entry(self.by_id.get(&id).expect("just inserted"));
370        update_label_cache(
371            &mut self.labels_by_id,
372            self.by_id.get(&id).expect("just inserted"),
373        );
374        Ok(())
375    }
376
377    fn get_entry(&self, id: &str) -> Option<SessionEntry> {
378        self.by_id.get(id).cloned()
379    }
380
381    fn find_entries(&self, type_name: &str) -> Vec<SessionEntry> {
382        self.entries
383            .iter()
384            .filter(|e| entry_type_name(e) == type_name)
385            .cloned()
386            .collect()
387    }
388
389    fn get_label(&self, id: &str) -> Option<String> {
390        self.labels_by_id.get(id).cloned()
391    }
392
393    fn get_path_to_root(&self, leaf_id: Option<&str>) -> Result<Vec<SessionEntry>, String> {
394        let start_id = leaf_id.or(self.leaf_id.as_deref());
395        if start_id.is_none() {
396            return Ok(vec![]);
397        }
398        let sid = start_id.unwrap();
399        let mut path: Vec<SessionEntry> = Vec::new();
400        let mut current = self.by_id.get(sid);
401        if current.is_none() {
402            return Err(format!("Entry {} not found", sid));
403        }
404        while let Some(entry) = current {
405            path.push(entry.clone());
406            match entry.parent_id() {
407                Some(pid) => {
408                    current = self.by_id.get(pid);
409                }
410                None => break,
411            }
412        }
413        path.reverse();
414        Ok(path)
415    }
416
417    fn get_entries(&self) -> Vec<SessionEntry> {
418        self.entries.clone()
419    }
420
421    fn path(&self) -> Option<&Path> {
422        Some(&self.file_path)
423    }
424}
425
426// ── Helper: entry type name ────────────────────────────────────────
427
428/// Return the type string for a SessionEntry (pi-compatible).
429fn entry_type_name(entry: &SessionEntry) -> &'static str {
430    match entry {
431        SessionEntry::Message(_) => "message",
432        SessionEntry::ThinkingLevelChange(_) => "thinking_level_change",
433        SessionEntry::ModelChange(_) => "model_change",
434        SessionEntry::ActiveToolsChange(_) => "active_tools_change",
435        SessionEntry::Compaction(_) => "compaction",
436        SessionEntry::BranchSummary(_) => "branch_summary",
437        SessionEntry::SessionInfo(_) => "session_info",
438        SessionEntry::Label(_) => "label",
439        SessionEntry::Custom(_) => "custom",
440        SessionEntry::CustomMessage(_) => "custom_message",
441        SessionEntry::Leaf(_) => "leaf",
442    }
443}
444
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::session::{LabelEntry, MessageEntry, ThinkingLevelChangeEntry};
    use crate::agent::types::user_message;
    use tempfile::TempDir;

    fn make_session_meta(id: &str) -> SessionMetadata {
        SessionMetadata {
            id: id.to_string(),
            created_at: chrono::Utc::now().to_rfc3339(),
            cwd: "/tmp/test".to_string(),
            path: None,
            parent_session_path: None,
        }
    }

    fn make_msg_entry(id: &str, parent: Option<&str>, text: &str) -> SessionEntry {
        SessionEntry::Message(MessageEntry {
            id: id.to_string(),
            parent_id: parent.map(|s| s.to_string()),
            timestamp: chrono::Utc::now().to_rfc3339(),
            message: user_message(text),
        })
    }

    /// Build a `Label` entry that sets (or, with `None`, clears) the label on `target`.
    fn make_label_entry(
        id: &str,
        parent: &str,
        target: &str,
        label: Option<&str>,
    ) -> SessionEntry {
        SessionEntry::Label(LabelEntry {
            id: id.to_string(),
            parent_id: Some(parent.to_string()),
            timestamp: chrono::Utc::now().to_rfc3339(),
            target_id: target.to_string(),
            label: label.map(|s| s.to_string()),
        })
    }

    // ── InMemorySessionStorage tests ──────────────────────────────────

    #[test]
    fn test_in_memory_empty() {
        let meta = make_session_meta("test");
        let storage = InMemorySessionStorage::new(meta.clone());
        assert_eq!(storage.metadata().id, "test");
        assert!(storage.get_leaf_id().is_none());
        assert!(storage.get_entries().is_empty());
        assert!(storage.path().is_none());
    }

    #[test]
    fn test_in_memory_append_and_get() {
        let mut storage = InMemorySessionStorage::new(make_session_meta("s1"));
        let e = make_msg_entry("m1", None, "hello");
        storage.append_entry(e).unwrap();
        assert_eq!(storage.get_leaf_id(), Some("m1".to_string()));
        assert_eq!(storage.get_entry("m1").unwrap().id(), "m1");
        assert_eq!(storage.get_entries().len(), 1);
    }

    #[test]
    fn test_in_memory_path_to_root() {
        let mut storage = InMemorySessionStorage::new(make_session_meta("s1"));
        storage
            .append_entry(make_msg_entry("m1", None, "first"))
            .unwrap();
        storage
            .append_entry(make_msg_entry("m2", Some("m1"), "second"))
            .unwrap();
        storage
            .append_entry(make_msg_entry("m3", Some("m2"), "third"))
            .unwrap();

        let path = storage.get_path_to_root(Some("m3")).unwrap();
        assert_eq!(path.len(), 3);
        assert_eq!(path[0].id(), "m1");
        assert_eq!(path[2].id(), "m3");
    }

    #[test]
    fn test_in_memory_path_to_root_without_start() {
        let storage = InMemorySessionStorage::new(make_session_meta("s1"));
        // No explicit start and no active leaf: empty path, not an error.
        assert!(storage.get_path_to_root(None).unwrap().is_empty());
        // An explicit but unknown start is an error.
        assert!(storage.get_path_to_root(Some("missing")).is_err());
    }

    #[test]
    fn test_in_memory_labels() {
        let mut storage = InMemorySessionStorage::new(make_session_meta("s1"));
        storage
            .append_entry(make_msg_entry("m1", None, "first"))
            .unwrap();

        // Add label
        storage
            .append_entry(make_label_entry("l1", "m1", "m1", Some("important")))
            .unwrap();
        assert_eq!(storage.get_label("m1"), Some("important".to_string()));

        // Remove label
        storage
            .append_entry(make_label_entry("l2", "l1", "m1", None))
            .unwrap();
        assert_eq!(storage.get_label("m1"), None);
    }

    #[test]
    fn test_in_memory_labels_are_trimmed_and_blank_removes() {
        let mut storage = InMemorySessionStorage::new(make_session_meta("s1"));
        storage
            .append_entry(make_msg_entry("m1", None, "first"))
            .unwrap();

        storage
            .append_entry(make_label_entry("l1", "m1", "m1", Some("  pinned  ")))
            .unwrap();
        assert_eq!(storage.get_label("m1"), Some("pinned".to_string()));

        // A whitespace-only label behaves like clearing the label.
        storage
            .append_entry(make_label_entry("l2", "l1", "m1", Some("   ")))
            .unwrap();
        assert_eq!(storage.get_label("m1"), None);
    }

    #[test]
    fn test_in_memory_set_leaf_id() {
        let mut storage = InMemorySessionStorage::new(make_session_meta("s1"));
        storage
            .append_entry(make_msg_entry("m1", None, "first"))
            .unwrap();
        storage
            .append_entry(make_msg_entry("m2", Some("m1"), "second"))
            .unwrap();

        // Set leaf to m1 (branching)
        storage.set_leaf_id(Some("m1")).unwrap();
        // The leaf entry points to m1
        assert_eq!(storage.get_leaf_id(), Some("m1".to_string()));

        // Verify leaf entry was appended
        let entries = storage.get_entries();
        assert_eq!(entries.len(), 3);
        assert_eq!(entries[2].id().len(), 8); // leaf entry has auto-generated id
        assert!(matches!(entries[2], SessionEntry::Leaf(_)));
    }

    #[test]
    fn test_in_memory_set_leaf_id_unknown_entry_leaves_state_untouched() {
        let mut storage = InMemorySessionStorage::new(make_session_meta("s1"));
        storage
            .append_entry(make_msg_entry("m1", None, "first"))
            .unwrap();

        assert!(storage.set_leaf_id(Some("nope")).is_err());
        assert_eq!(storage.get_leaf_id(), Some("m1".to_string()));
        assert_eq!(storage.get_entries().len(), 1);
    }

    #[test]
    fn test_in_memory_set_leaf_id_none_clears_leaf() {
        let mut storage = InMemorySessionStorage::new(make_session_meta("s1"));
        storage
            .append_entry(make_msg_entry("m1", None, "first"))
            .unwrap();

        storage.set_leaf_id(None).unwrap();
        assert_eq!(storage.get_leaf_id(), None);
        assert!(storage.get_path_to_root(None).unwrap().is_empty());
        // The reset is still recorded as a leaf entry.
        assert_eq!(storage.get_entries().len(), 2);
        assert_eq!(storage.find_entries("leaf").len(), 1);
    }

    #[test]
    fn test_in_memory_find_entries() {
        let mut storage = InMemorySessionStorage::new(make_session_meta("s1"));
        storage
            .append_entry(make_msg_entry("m1", None, "first"))
            .unwrap();
        let tl = SessionEntry::ThinkingLevelChange(ThinkingLevelChangeEntry {
            id: "tc1".to_string(),
            parent_id: Some("m1".to_string()),
            timestamp: chrono::Utc::now().to_rfc3339(),
            thinking_level: "high".to_string(),
        });
        storage.append_entry(tl).unwrap();
        storage
            .append_entry(make_msg_entry("m2", Some("tc1"), "second"))
            .unwrap();

        let msgs = storage.find_entries("message");
        assert_eq!(msgs.len(), 2);
        let tls = storage.find_entries("thinking_level_change");
        assert_eq!(tls.len(), 1);
    }

    // ── JsonlSessionStorage tests ────────────────────────────────────

    #[test]
    fn test_jsonl_create_and_append() {
        let tmp = TempDir::new().unwrap();
        let path = tmp.path().join("session.jsonl");

        let mut storage =
            JsonlSessionStorage::create(path.clone(), "/tmp/test", "s1", None).unwrap();
        assert_eq!(storage.metadata().id, "s1");
        assert!(path.exists());

        storage
            .append_entry(make_msg_entry("m1", None, "hello"))
            .unwrap();
        assert_eq!(storage.get_entries().len(), 1);
        assert_eq!(storage.get_leaf_id(), Some("m1".to_string()));

        // Verify persistence by opening again
        let loaded = JsonlSessionStorage::open(path).unwrap();
        assert_eq!(loaded.get_entries().len(), 1);
        assert_eq!(loaded.get_entry("m1").unwrap().id(), "m1");
    }

    #[test]
    fn test_jsonl_open_and_traverse() {
        let tmp = TempDir::new().unwrap();
        let path = tmp.path().join("session.jsonl");

        let mut storage =
            JsonlSessionStorage::create(path.clone(), "/tmp/test", "s1", None).unwrap();
        storage
            .append_entry(make_msg_entry("m1", None, "first"))
            .unwrap();
        storage
            .append_entry(make_msg_entry("m2", Some("m1"), "second"))
            .unwrap();
        drop(storage);

        let loaded = JsonlSessionStorage::open(path).unwrap();
        let path_to = loaded.get_path_to_root(Some("m2")).unwrap();
        assert_eq!(path_to.len(), 2);
    }

    #[test]
    fn test_jsonl_reopen_restores_leaf_labels_and_metadata() {
        let tmp = TempDir::new().unwrap();
        let path = tmp.path().join("session.jsonl");

        let mut storage =
            JsonlSessionStorage::create(path.clone(), "/tmp/test", "s1", None).unwrap();
        storage
            .append_entry(make_msg_entry("m1", None, "first"))
            .unwrap();
        storage
            .append_entry(make_msg_entry("m2", Some("m1"), "second"))
            .unwrap();
        storage
            .append_entry(make_label_entry("l1", "m2", "m1", Some("keep")))
            .unwrap();
        storage.set_leaf_id(Some("m1")).unwrap();
        drop(storage);

        let loaded = JsonlSessionStorage::open(path.clone()).unwrap();
        assert_eq!(loaded.metadata().id, "s1");
        assert_eq!(loaded.metadata().cwd, "/tmp/test");
        assert_eq!(loaded.path(), Some(path.as_path()));
        assert_eq!(loaded.get_entries().len(), 4);
        assert_eq!(loaded.find_entries("leaf").len(), 1);
        assert_eq!(loaded.get_leaf_id(), Some("m1".to_string()));
        assert_eq!(loaded.get_label("m1"), Some("keep".to_string()));
    }
}