// limit_cli/session_tree.rs

use limit_llm::{CacheControl, Message, Role, ToolCall};
use rand::RngExt;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::fs::{File, OpenOptions};
use std::io::{BufRead, BufReader, BufWriter, Write};
use std::path::Path;

/// Identifier of a session entry.
///
/// IDs produced by [`generate_entry_id`] are 8 lowercase hex characters, but the
/// alias itself is an unconstrained `String` (the JSONL header line written by
/// `SessionTree::save_to_file` stores the session ID in this field).
pub type EntryId = String;

/// A single node in the session tree, persisted as one JSONL line.
///
/// `entry_type` is `#[serde(flatten)]`ed, so its `"type"` tag and variant fields
/// sit at the top level of the JSON object next to `id`, `parent_id` and
/// `timestamp`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionEntry {
    /// Entry ID (8-char hex when created via `generate_entry_id`).
    pub id: EntryId,
    /// ID of the parent entry; `None` for a root entry.
    pub parent_id: Option<EntryId>,
    /// ISO 8601 timestamp string (format is not validated here).
    pub timestamp: String,
    /// Variant-specific payload.
    #[serde(flatten)]
    pub entry_type: SessionEntryType,
}

/// Kinds of session entries.
///
/// Serialized as an internally tagged enum: the variant name appears in a
/// `"type"` field in snake_case (`"session"`, `"message"`, `"compaction"`,
/// `"branch_summary"`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum SessionEntryType {
    /// Session metadata header (first line of a saved file); `load_from_file`
    /// reads it for the session ID / cwd and does not insert it into the tree.
    Session { version: u32, cwd: String },
    /// A user/assistant/system/tool message.
    Message { message: SerializableMessage },
    /// Compaction summary that replaces older messages on the branch;
    /// `first_kept_id` names the first entry retained after compaction.
    Compaction {
        summary: String,
        first_kept_id: EntryId,
    },
    /// Branch context recorded when switching branches (`from_id` presumably
    /// the entry being branched away from — confirm with the writer). Not
    /// turned into a message by `SessionTree::build_context`.
    BranchSummary { from_id: EntryId, summary: String },
}

/// JSON-friendly mirror of [`Message`] with the role stored as a string.
///
/// Every `None` field is omitted from the serialized output. Content is kept as
/// plain text (the `From<Message>` conversion flattens it with `to_text`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializableMessage {
    /// Role name: `"user"`, `"assistant"`, `"system"` or `"tool"` when produced
    /// from a [`Message`].
    pub role: String,
    /// Plain-text message content.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
    /// Tool invocations requested by the message, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCall>>,
    /// ID of the tool call this message answers, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
    /// Provider cache-control hint carried over from the [`Message`].
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cache_control: Option<CacheControl>,
}

57impl From<Message> for SerializableMessage {
58    fn from(msg: Message) -> Self {
59        Self {
60            role: match msg.role {
61                Role::User => "user".to_string(),
62                Role::Assistant => "assistant".to_string(),
63                Role::System => "system".to_string(),
64                Role::Tool => "tool".to_string(),
65            },
66            content: msg.content.map(|c| c.to_text()),
67            tool_calls: msg.tool_calls,
68            tool_call_id: msg.tool_call_id,
69            cache_control: msg.cache_control,
70        }
71    }
72}
73
74impl From<SerializableMessage> for Message {
75    fn from(msg: SerializableMessage) -> Self {
76        Self {
77            role: match msg.role.as_str() {
78                "user" => Role::User,
79                "assistant" => Role::Assistant,
80                "system" => Role::System,
81                "tool" => Role::Tool,
82                _ => Role::User,
83            },
84            content: msg.content.map(limit_llm::MessageContent::text),
85            tool_calls: msg.tool_calls,
86            tool_call_id: msg.tool_call_id,
87            cache_control: msg.cache_control,
88        }
89    }
90}
91
92/// Generate a random 8-char hex ID
93pub fn generate_entry_id() -> EntryId {
94    let mut rng = rand::rng();
95    format!("{:08x}", rng.random::<u32>())
96}
97
/// In-memory session tree: every entry keyed by ID, plus a cursor on the
/// current leaf that new entries are appended under.
pub struct SessionTree {
    /// All entries, keyed by their ID.
    entries: HashMap<EntryId, SessionEntry>,
    /// ID of the current leaf; empty while the tree has no entries.
    leaf_id: EntryId,
    /// Session identifier, written as the `id` of the JSONL header line.
    session_id: String,
    /// Working directory recorded for the session (stored in the header line).
    cwd: String,
}

110impl SessionTree {
111    /// Create a new empty session tree
112    pub fn new(cwd: String) -> Self {
113        let session_id = uuid::Uuid::new_v4().to_string();
114        Self {
115            entries: HashMap::new(),
116            leaf_id: String::new(),
117            session_id,
118            cwd,
119        }
120    }
121
122    /// Load from existing entries
123    pub fn from_entries(
124        entries: Vec<SessionEntry>,
125        session_id: String,
126        cwd: String,
127    ) -> Result<Self, SessionTreeError> {
128        let mut by_id: HashMap<EntryId, SessionEntry> = HashMap::new();
129        let mut leaf_id = String::new();
130
131        for entry in entries {
132            leaf_id = entry.id.clone();
133            by_id.insert(entry.id.clone(), entry);
134        }
135
136        Ok(Self {
137            entries: by_id,
138            leaf_id,
139            session_id,
140            cwd,
141        })
142    }
143
144    /// Append a new entry as child of current leaf
145    pub fn append(&mut self, entry: SessionEntry) -> Result<(), SessionTreeError> {
146        let id = entry.id.clone();
147
148        if self.entries.is_empty() {
149            if entry.parent_id.is_some() {
150                return Err(SessionTreeError::InvalidParent {
151                    expected: "none (first entry)".to_string(),
152                    got: entry.parent_id.clone(),
153                });
154            }
155        } else if entry.parent_id.as_ref() != Some(&self.leaf_id) {
156            return Err(SessionTreeError::InvalidParent {
157                expected: self.leaf_id.clone(),
158                got: entry.parent_id.clone(),
159            });
160        }
161
162        self.entries.insert(id.clone(), entry);
163        self.leaf_id = id;
164        Ok(())
165    }
166
167    /// Build message context from leaf to root
168    pub fn build_context(&self, leaf_id: &str) -> Result<Vec<Message>, SessionTreeError> {
169        let mut path = Vec::new();
170        let mut current_id = Some(leaf_id.to_string());
171
172        while let Some(id) = current_id {
173            let entry = self
174                .entries
175                .get(&id)
176                .ok_or(SessionTreeError::EntryNotFound(id))?;
177
178            // Handle compaction: stop at first_kept_id
179            if let SessionEntryType::Compaction { first_kept_id, .. } = &entry.entry_type {
180                current_id = Some(first_kept_id.clone());
181                path.push(entry.clone());
182                continue;
183            }
184
185            current_id = entry.parent_id.clone();
186            path.push(entry.clone());
187        }
188
189        path.reverse();
190
191        let messages: Vec<Message> = path
192            .into_iter()
193            .filter_map(|entry| match entry.entry_type {
194                SessionEntryType::Message { message } => Some(Message::from(message)),
195                SessionEntryType::Compaction { summary, .. } => Some(Message {
196                    role: Role::User,
197                    content: Some(limit_llm::MessageContent::text(format!(
198                        "<summary>\n{}\n</summary>",
199                        summary
200                    ))),
201                    tool_calls: None,
202                    tool_call_id: None,
203                    cache_control: None,
204                }),
205                SessionEntryType::Session { .. } => None,
206                SessionEntryType::BranchSummary { .. } => None,
207            })
208            .collect();
209
210        Ok(messages)
211    }
212
213    /// Create a branch from a specific entry
214    pub fn branch_from(&mut self, entry_id: &str) -> Result<EntryId, SessionTreeError> {
215        if !self.entries.contains_key(entry_id) {
216            return Err(SessionTreeError::EntryNotFound(entry_id.to_string()));
217        }
218
219        self.leaf_id = entry_id.to_string();
220        Ok(entry_id.to_string())
221    }
222
223    /// Get current leaf ID
224    pub fn leaf_id(&self) -> &str {
225        &self.leaf_id
226    }
227
228    /// Get all entries
229    pub fn entries(&self) -> Vec<&SessionEntry> {
230        self.entries.values().collect()
231    }
232
233    /// Get session ID
234    pub fn session_id(&self) -> &str {
235        &self.session_id
236    }
237
238    /// Save tree to JSONL file
239    pub fn save_to_file(&self, path: &Path) -> Result<(), SessionTreeError> {
240        let file = File::create(path)?;
241        let mut writer = BufWriter::new(file);
242
243        let header = SessionEntry {
244            id: self.session_id.clone(),
245            parent_id: None,
246            timestamp: chrono::Utc::now().to_rfc3339(),
247            entry_type: SessionEntryType::Session {
248                version: 1,
249                cwd: self.cwd.clone(),
250            },
251        };
252        writeln!(writer, "{}", serde_json::to_string(&header)?)?;
253
254        let sorted = self.sort_entries()?;
255        for entry in sorted {
256            writeln!(writer, "{}", serde_json::to_string(&entry)?)?;
257        }
258
259        writer.flush()?;
260        Ok(())
261    }
262
263    /// Load tree from JSONL file
264    pub fn load_from_file(path: &Path) -> Result<Self, SessionTreeError> {
265        let file = File::open(path)?;
266        let reader = BufReader::new(file);
267
268        let mut entries = Vec::new();
269        let mut session_id = String::new();
270        let mut cwd = String::new();
271
272        for line in reader.lines() {
273            let line: String = line?;
274            if line.trim().is_empty() {
275                continue;
276            }
277
278            let entry: SessionEntry = serde_json::from_str(&line)?;
279
280            if let SessionEntryType::Session { version: _, cwd: c } = &entry.entry_type {
281                session_id = entry.id.clone();
282                cwd = c.clone();
283            } else {
284                entries.push(entry);
285            }
286        }
287
288        Self::from_entries(entries, session_id, cwd)
289    }
290
291    /// Append a single entry to file (for incremental saves)
292    pub fn append_to_file(
293        &self,
294        path: &Path,
295        entry: &SessionEntry,
296    ) -> Result<(), SessionTreeError> {
297        let mut file = OpenOptions::new().create(true).append(true).open(path)?;
298
299        writeln!(file, "{}", serde_json::to_string(entry)?)?;
300        Ok(())
301    }
302
303    fn sort_entries(&self) -> Result<Vec<SessionEntry>, SessionTreeError> {
304        let mut sorted = Vec::new();
305        let mut visited: std::collections::HashSet<EntryId> = std::collections::HashSet::new();
306
307        let roots: Vec<_> = self
308            .entries
309            .values()
310            .filter(|e| e.parent_id.is_none())
311            .collect();
312
313        for root in roots {
314            self.sort_dfs(root, &mut sorted, &mut visited)?;
315        }
316
317        Ok(sorted)
318    }
319
320    fn sort_dfs(
321        &self,
322        entry: &SessionEntry,
323        sorted: &mut Vec<SessionEntry>,
324        visited: &mut std::collections::HashSet<EntryId>,
325    ) -> Result<(), SessionTreeError> {
326        if visited.contains(&entry.id) {
327            return Ok(());
328        }
329
330        visited.insert(entry.id.clone());
331        sorted.push(entry.clone());
332
333        for child in self.entries.values() {
334            if child.parent_id.as_ref() == Some(&entry.id) {
335                self.sort_dfs(child, sorted, visited)?;
336            }
337        }
338
339        Ok(())
340    }
341}
342
/// Errors returned by [`SessionTree`] operations.
#[derive(Debug, thiserror::Error)]
pub enum SessionTreeError {
    /// The referenced entry ID is not present in the tree.
    #[error("Entry not found: {0}")]
    EntryNotFound(String),
    /// A `parent_id` did not satisfy the tree's parent rules.
    #[error("Invalid parent: expected {expected:?}, got {got:?}")]
    InvalidParent {
        expected: String,
        got: Option<String>,
    },
    /// Underlying file I/O failed.
    #[error("IO error: {0}")]
    IoError(#[from] std::io::Error),
    /// Serializing or deserializing an entry as JSON failed.
    #[error("JSON error: {0}")]
    JsonError(#[from] serde_json::Error),
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a plain-text message entry with an explicit role and timestamp.
    fn text_entry(
        id: &str,
        parent_id: Option<&str>,
        timestamp: &str,
        role: Role,
        text: &str,
    ) -> SessionEntry {
        SessionEntry {
            id: id.to_owned(),
            parent_id: parent_id.map(str::to_owned),
            timestamp: timestamp.to_owned(),
            entry_type: SessionEntryType::Message {
                message: SerializableMessage::from(Message {
                    role,
                    content: Some(limit_llm::MessageContent::text(text)),
                    tool_calls: None,
                    tool_call_id: None,
                    cache_control: None,
                }),
            },
        }
    }

    /// User-role text entry with a fixed timestamp.
    fn create_test_entry(id: &str, parent_id: Option<&str>, content: &str) -> SessionEntry {
        text_entry(id, parent_id, "2024-01-01T00:00:00Z", Role::User, content)
    }

    /// Plain text of every message, in order.
    fn texts(messages: &[Message]) -> Vec<String> {
        messages
            .iter()
            .map(|m| m.content.as_ref().unwrap().to_text())
            .collect()
    }

    #[test]
    fn test_session_entry_serialization() {
        let message = SerializableMessage {
            role: "user".to_string(),
            content: Some("Hello".to_string()),
            tool_calls: None,
            tool_call_id: None,
            cache_control: None,
        };
        let entry = SessionEntry {
            id: "a1b2c3d4".to_string(),
            parent_id: None,
            timestamp: "2024-01-01T00:00:00Z".to_string(),
            entry_type: SessionEntryType::Message { message },
        };

        let json = serde_json::to_string(&entry).unwrap();
        for needle in ["\"id\":\"a1b2c3d4\"", "\"type\":\"message\""] {
            assert!(json.contains(needle));
        }

        let parsed: SessionEntry = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.id, entry.id);
    }

    #[test]
    fn test_build_context_linear() {
        let mut tree = SessionTree::new("/test".to_string());

        let first = text_entry("a1b2c3d4", None, "2024-01-01T00:00:00Z", Role::User, "Hello");
        let second = text_entry(
            "b2c3d4e5",
            Some("a1b2c3d4"),
            "2024-01-01T00:01:00Z",
            Role::Assistant,
            "Hi!",
        );
        for entry in [first, second] {
            tree.append(entry).unwrap();
        }

        let messages = tree.build_context("b2c3d4e5").unwrap();
        assert_eq!(texts(&messages), ["Hello", "Hi!"]);
    }

    #[test]
    fn test_build_context_with_branching() {
        let mut tree = SessionTree::new("/test".to_string());

        // root -> a -> b
        //          \
        //           -> c (branch from a)
        for entry in [
            create_test_entry("root", None, "root content"),
            create_test_entry("a", Some("root"), "a content"),
            create_test_entry("b", Some("a"), "b content"),
        ] {
            tree.append(entry).unwrap();
        }

        tree.branch_from("a").unwrap();
        tree.append(create_test_entry("c", Some("a"), "c content"))
            .unwrap();

        // root -> a -> b
        assert_eq!(tree.build_context("b").unwrap().len(), 3);

        // root -> a -> c
        let branch_c = tree.build_context("c").unwrap();
        assert_eq!(branch_c.len(), 3);
        assert_eq!(branch_c[2].content.as_ref().unwrap().to_text(), "c content");
    }

    #[test]
    fn test_jsonl_roundtrip() {
        let mut tree = SessionTree::new("/test".to_string());
        tree.append(create_test_entry("a1b2c3d4", None, "first"))
            .unwrap();
        tree.append(create_test_entry("b2c3d4e5", Some("a1b2c3d4"), "second"))
            .unwrap();

        let file = tempfile::NamedTempFile::new().unwrap();
        tree.save_to_file(file.path()).unwrap();
        let loaded = SessionTree::load_from_file(file.path()).unwrap();

        assert_eq!(loaded.leaf_id(), "b2c3d4e5");
        assert_eq!(loaded.entries().len(), 2);
        assert_eq!(loaded.build_context("b2c3d4e5").unwrap().len(), 2);
    }

    #[test]
    fn test_jsonl_format() {
        let mut tree = SessionTree::new("/test".to_string());
        tree.append(create_test_entry("a1b2c3d4", None, "test"))
            .unwrap();

        let file = tempfile::NamedTempFile::new().unwrap();
        tree.save_to_file(file.path()).unwrap();

        let content = std::fs::read_to_string(file.path()).unwrap();
        content
            .lines()
            .filter(|line| !line.is_empty())
            .for_each(|line| {
                serde_json::from_str::<serde_json::Value>(line).expect("Line should be valid JSON");
            });
    }
}