Skip to main content

limit_cli/
session_tree.rs

1use limit_llm::{CacheControl, Message, Role, ToolCall};
2use rand::Rng;
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::fs::{File, OpenOptions};
6use std::io::{BufRead, BufReader, BufWriter, Write};
7use std::path::Path;
8
/// Unique 8-character hex ID for session entries.
///
/// A plain `String` alias; [`generate_entry_id`] produces values of this form, but
/// nothing in the type enforces the format (tests and the session header use other IDs).
pub type EntryId = String;
11
/// A single entry in the session tree.
///
/// Each entry is one JSONL line on disk. `id`, `parent_id` and `timestamp` are
/// written as top-level fields; because `entry_type` is `#[serde(flatten)]`, the
/// variant's `"type"` tag and its fields appear at that same top level.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionEntry {
    /// 8-char hex ID
    pub id: EntryId,
    /// Parent entry ID (None for root)
    pub parent_id: Option<EntryId>,
    /// ISO 8601 timestamp (the session header is stamped with RFC 3339 UTC time on save)
    pub timestamp: String,
    /// The entry content
    #[serde(flatten)]
    pub entry_type: SessionEntryType,
}
25
/// Types of entries in the session tree.
///
/// Serialized as an internally tagged enum: a `"type"` field carries the
/// snake_case variant name (`session`, `message`, `compaction`, `branch_summary`)
/// alongside the variant's own fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum SessionEntryType {
    /// Session metadata (first entry in file)
    /// `version` is the header format version; `cwd` the session's working directory.
    Session { version: u32, cwd: String },
    /// User/assistant/tool message
    Message { message: SerializableMessage },
    /// Compaction summary (replaces old messages)
    /// `summary` is injected into the context as a user message; `first_kept_id`
    /// presumably names the oldest entry kept verbatim — confirm against the writer.
    Compaction {
        summary: String,
        first_kept_id: EntryId,
    },
    /// Branch context when switching branches
    /// NOTE(review): `from_id` is assumed to be the entry branched away from; these
    /// entries currently contribute no messages to `SessionTree::build_context`.
    BranchSummary { from_id: EntryId, summary: String },
}
42
/// Message that can be serialized to JSON.
///
/// Mirror of `limit_llm::Message` with the role stored as a lowercase string
/// (`"user"`, `"assistant"`, `"system"`, `"tool"` when produced by the `From`
/// impl). `None` fields are omitted from the serialized output.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializableMessage {
    pub role: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCall>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cache_control: Option<CacheControl>,
}
56
57impl From<Message> for SerializableMessage {
58    fn from(msg: Message) -> Self {
59        Self {
60            role: match msg.role {
61                Role::User => "user".to_string(),
62                Role::Assistant => "assistant".to_string(),
63                Role::System => "system".to_string(),
64                Role::Tool => "tool".to_string(),
65            },
66            content: msg.content,
67            tool_calls: msg.tool_calls,
68            tool_call_id: msg.tool_call_id,
69            cache_control: msg.cache_control,
70        }
71    }
72}
73
74impl From<SerializableMessage> for Message {
75    fn from(msg: SerializableMessage) -> Self {
76        Self {
77            role: match msg.role.as_str() {
78                "user" => Role::User,
79                "assistant" => Role::Assistant,
80                "system" => Role::System,
81                "tool" => Role::Tool,
82                _ => Role::User,
83            },
84            content: msg.content,
85            tool_calls: msg.tool_calls,
86            tool_call_id: msg.tool_call_id,
87            cache_control: msg.cache_control,
88        }
89    }
90}
91
92/// Generate a random 8-char hex ID
93pub fn generate_entry_id() -> EntryId {
94    let mut rng = rand::thread_rng();
95    format!("{:08x}", rng.gen::<u32>())
96}
97
98/// In-memory tree structure for session entries
99pub struct SessionTree {
100    /// All entries indexed by ID
101    entries: HashMap<EntryId, SessionEntry>,
102    /// Current leaf entry ID
103    leaf_id: EntryId,
104    /// Session metadata
105    session_id: String,
106    /// Working directory when session was created
107    cwd: String,
108}
109
110impl SessionTree {
111    /// Create a new empty session tree
112    pub fn new(cwd: String) -> Self {
113        let session_id = uuid::Uuid::new_v4().to_string();
114        Self {
115            entries: HashMap::new(),
116            leaf_id: String::new(),
117            session_id,
118            cwd,
119        }
120    }
121
122    /// Load from existing entries
123    pub fn from_entries(
124        entries: Vec<SessionEntry>,
125        session_id: String,
126        cwd: String,
127    ) -> Result<Self, SessionTreeError> {
128        let mut by_id: HashMap<EntryId, SessionEntry> = HashMap::new();
129        let mut leaf_id = String::new();
130
131        for entry in entries {
132            leaf_id = entry.id.clone();
133            by_id.insert(entry.id.clone(), entry);
134        }
135
136        Ok(Self {
137            entries: by_id,
138            leaf_id,
139            session_id,
140            cwd,
141        })
142    }
143
144    /// Append a new entry as child of current leaf
145    pub fn append(&mut self, entry: SessionEntry) -> Result<(), SessionTreeError> {
146        let id = entry.id.clone();
147
148        if self.entries.is_empty() {
149            if entry.parent_id.is_some() {
150                return Err(SessionTreeError::InvalidParent {
151                    expected: "none (first entry)".to_string(),
152                    got: entry.parent_id.clone(),
153                });
154            }
155        } else if entry.parent_id.as_ref() != Some(&self.leaf_id) {
156            return Err(SessionTreeError::InvalidParent {
157                expected: self.leaf_id.clone(),
158                got: entry.parent_id.clone(),
159            });
160        }
161
162        self.entries.insert(id.clone(), entry);
163        self.leaf_id = id;
164        Ok(())
165    }
166
167    /// Build message context from leaf to root
168    pub fn build_context(&self, leaf_id: &str) -> Result<Vec<Message>, SessionTreeError> {
169        let mut path = Vec::new();
170        let mut current_id = Some(leaf_id.to_string());
171
172        while let Some(id) = current_id {
173            let entry = self
174                .entries
175                .get(&id)
176                .ok_or(SessionTreeError::EntryNotFound(id))?;
177
178            // Handle compaction: stop at first_kept_id
179            if let SessionEntryType::Compaction { first_kept_id, .. } = &entry.entry_type {
180                current_id = Some(first_kept_id.clone());
181                path.push(entry.clone());
182                continue;
183            }
184
185            current_id = entry.parent_id.clone();
186            path.push(entry.clone());
187        }
188
189        path.reverse();
190
191        let messages: Vec<Message> = path
192            .into_iter()
193            .filter_map(|entry| match entry.entry_type {
194                SessionEntryType::Message { message } => Some(Message::from(message)),
195                SessionEntryType::Compaction { summary, .. } => Some(Message {
196                    role: Role::User,
197                    content: Some(format!("<summary>\n{}\n</summary>", summary)),
198                    tool_calls: None,
199                    tool_call_id: None,
200                    cache_control: None,
201                }),
202                SessionEntryType::Session { .. } => None,
203                SessionEntryType::BranchSummary { .. } => None,
204            })
205            .collect();
206
207        Ok(messages)
208    }
209
210    /// Create a branch from a specific entry
211    pub fn branch_from(&mut self, entry_id: &str) -> Result<EntryId, SessionTreeError> {
212        if !self.entries.contains_key(entry_id) {
213            return Err(SessionTreeError::EntryNotFound(entry_id.to_string()));
214        }
215
216        self.leaf_id = entry_id.to_string();
217        Ok(entry_id.to_string())
218    }
219
220    /// Get current leaf ID
221    pub fn leaf_id(&self) -> &str {
222        &self.leaf_id
223    }
224
225    /// Get all entries
226    pub fn entries(&self) -> Vec<&SessionEntry> {
227        self.entries.values().collect()
228    }
229
230    /// Get session ID
231    pub fn session_id(&self) -> &str {
232        &self.session_id
233    }
234
235    /// Save tree to JSONL file
236    pub fn save_to_file(&self, path: &Path) -> Result<(), SessionTreeError> {
237        let file = File::create(path)?;
238        let mut writer = BufWriter::new(file);
239
240        let header = SessionEntry {
241            id: self.session_id.clone(),
242            parent_id: None,
243            timestamp: chrono::Utc::now().to_rfc3339(),
244            entry_type: SessionEntryType::Session {
245                version: 1,
246                cwd: self.cwd.clone(),
247            },
248        };
249        writeln!(writer, "{}", serde_json::to_string(&header)?)?;
250
251        let sorted = self.sort_entries()?;
252        for entry in sorted {
253            writeln!(writer, "{}", serde_json::to_string(&entry)?)?;
254        }
255
256        writer.flush()?;
257        Ok(())
258    }
259
260    /// Load tree from JSONL file
261    pub fn load_from_file(path: &Path) -> Result<Self, SessionTreeError> {
262        let file = File::open(path)?;
263        let reader = BufReader::new(file);
264
265        let mut entries = Vec::new();
266        let mut session_id = String::new();
267        let mut cwd = String::new();
268
269        for line in reader.lines() {
270            let line: String = line?;
271            if line.trim().is_empty() {
272                continue;
273            }
274
275            let entry: SessionEntry = serde_json::from_str(&line)?;
276
277            if let SessionEntryType::Session { version: _, cwd: c } = &entry.entry_type {
278                session_id = entry.id.clone();
279                cwd = c.clone();
280            } else {
281                entries.push(entry);
282            }
283        }
284
285        Self::from_entries(entries, session_id, cwd)
286    }
287
288    /// Append a single entry to file (for incremental saves)
289    pub fn append_to_file(
290        &self,
291        path: &Path,
292        entry: &SessionEntry,
293    ) -> Result<(), SessionTreeError> {
294        let mut file = OpenOptions::new().create(true).append(true).open(path)?;
295
296        writeln!(file, "{}", serde_json::to_string(entry)?)?;
297        Ok(())
298    }
299
300    fn sort_entries(&self) -> Result<Vec<SessionEntry>, SessionTreeError> {
301        let mut sorted = Vec::new();
302        let mut visited: std::collections::HashSet<EntryId> = std::collections::HashSet::new();
303
304        let roots: Vec<_> = self
305            .entries
306            .values()
307            .filter(|e| e.parent_id.is_none())
308            .collect();
309
310        for root in roots {
311            self.sort_dfs(root, &mut sorted, &mut visited)?;
312        }
313
314        Ok(sorted)
315    }
316
317    fn sort_dfs(
318        &self,
319        entry: &SessionEntry,
320        sorted: &mut Vec<SessionEntry>,
321        visited: &mut std::collections::HashSet<EntryId>,
322    ) -> Result<(), SessionTreeError> {
323        if visited.contains(&entry.id) {
324            return Ok(());
325        }
326
327        visited.insert(entry.id.clone());
328        sorted.push(entry.clone());
329
330        for child in self.entries.values() {
331            if child.parent_id.as_ref() == Some(&entry.id) {
332                self.sort_dfs(child, sorted, visited)?;
333            }
334        }
335
336        Ok(())
337    }
338}
339
340#[derive(Debug, thiserror::Error)]
341pub enum SessionTreeError {
342    #[error("Entry not found: {0}")]
343    EntryNotFound(String),
344    #[error("Invalid parent: expected {expected:?}, got {got:?}")]
345    InvalidParent {
346        expected: String,
347        got: Option<String>,
348    },
349    #[error("IO error: {0}")]
350    IoError(#[from] std::io::Error),
351    #[error("JSON error: {0}")]
352    JsonError(#[from] serde_json::Error),
353}
354
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a user-role message entry with the given ID, parent and content.
    fn create_test_entry(id: &str, parent_id: Option<&str>, content: &str) -> SessionEntry {
        SessionEntry {
            id: id.to_string(),
            parent_id: parent_id.map(|s| s.to_string()),
            timestamp: "2024-01-01T00:00:00Z".to_string(),
            entry_type: SessionEntryType::Message {
                message: SerializableMessage::from(Message {
                    role: Role::User,
                    content: Some(content.to_string()),
                    tool_calls: None,
                    tool_call_id: None,
                    cache_control: None,
                }),
            },
        }
    }

    #[test]
    fn test_session_entry_serialization() {
        let entry = SessionEntry {
            id: "a1b2c3d4".to_string(),
            parent_id: None,
            timestamp: "2024-01-01T00:00:00Z".to_string(),
            entry_type: SessionEntryType::Message {
                message: SerializableMessage {
                    role: "user".to_string(),
                    content: Some("Hello".to_string()),
                    tool_calls: None,
                    tool_call_id: None,
                    cache_control: None,
                },
            },
        };

        let json = serde_json::to_string(&entry).unwrap();
        assert!(json.contains("\"id\":\"a1b2c3d4\""));
        assert!(json.contains("\"type\":\"message\""));

        let parsed: SessionEntry = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.id, entry.id);
    }

    #[test]
    fn test_build_context_linear() {
        let mut tree = SessionTree::new("/test".to_string());

        let msg1 = SessionEntry {
            id: "a1b2c3d4".to_string(),
            parent_id: None,
            timestamp: "2024-01-01T00:00:00Z".to_string(),
            entry_type: SessionEntryType::Message {
                message: SerializableMessage::from(Message {
                    role: Role::User,
                    content: Some("Hello".to_string()),
                    tool_calls: None,
                    tool_call_id: None,
                    cache_control: None,
                }),
            },
        };

        let msg2 = SessionEntry {
            id: "b2c3d4e5".to_string(),
            parent_id: Some("a1b2c3d4".to_string()),
            timestamp: "2024-01-01T00:01:00Z".to_string(),
            entry_type: SessionEntryType::Message {
                message: SerializableMessage::from(Message {
                    role: Role::Assistant,
                    content: Some("Hi!".to_string()),
                    tool_calls: None,
                    tool_call_id: None,
                    cache_control: None,
                }),
            },
        };

        tree.append(msg1).unwrap();
        tree.append(msg2).unwrap();

        let messages = tree.build_context("b2c3d4e5").unwrap();
        assert_eq!(messages.len(), 2);
        assert_eq!(messages[0].content, Some("Hello".to_string()));
        assert_eq!(messages[1].content, Some("Hi!".to_string()));
    }

    #[test]
    fn test_build_context_with_branching() {
        let mut tree = SessionTree::new("/test".to_string());

        // Root -> A -> B
        //         \
        //          -> C (branch from A)
        let root = create_test_entry("root", None, "root content");
        let a = create_test_entry("a", Some("root"), "a content");
        let b = create_test_entry("b", Some("a"), "b content");
        let c = create_test_entry("c", Some("a"), "c content");

        // Build main path: root -> A -> B
        tree.append(root).unwrap();
        tree.append(a).unwrap();
        tree.append(b).unwrap();

        // Create branch from A to C
        tree.branch_from("a").unwrap();
        tree.append(c).unwrap();

        // Build context from B: root -> A -> B
        let context_b = tree.build_context("b").unwrap();
        assert_eq!(context_b.len(), 3);

        // Build context from C: root -> A -> C
        let context_c = tree.build_context("c").unwrap();
        assert_eq!(context_c.len(), 3);
        assert_eq!(context_c[2].content, Some("c content".to_string()));
    }

    #[test]
    fn test_build_context_unknown_leaf() {
        let tree = SessionTree::new("/test".to_string());
        assert!(matches!(
            tree.build_context("ghost"),
            Err(SessionTreeError::EntryNotFound(_))
        ));
    }

    #[test]
    fn test_append_rejects_parent_mismatch() {
        let mut tree = SessionTree::new("/test".to_string());

        // The first entry must be a root.
        let orphan = create_test_entry("a", Some("missing"), "orphan");
        assert!(matches!(
            tree.append(orphan),
            Err(SessionTreeError::InvalidParent { .. })
        ));

        tree.append(create_test_entry("root", None, "root")).unwrap();

        // Later entries must hang off the current leaf.
        let stray = create_test_entry("b", Some("elsewhere"), "stray");
        match tree.append(stray) {
            Err(SessionTreeError::InvalidParent { expected, got }) => {
                assert_eq!(expected, "root");
                assert_eq!(got, Some("elsewhere".to_string()));
            }
            other => panic!("expected InvalidParent, got {:?}", other),
        }

        // A rejected append must leave the tree untouched.
        assert_eq!(tree.leaf_id(), "root");
        assert_eq!(tree.entries().len(), 1);
    }

    #[test]
    fn test_branch_from_unknown_entry() {
        let mut tree = SessionTree::new("/test".to_string());
        tree.append(create_test_entry("root", None, "root")).unwrap();

        assert!(matches!(
            tree.branch_from("nope"),
            Err(SessionTreeError::EntryNotFound(ref id)) if id == "nope"
        ));
        assert_eq!(tree.leaf_id(), "root");
    }

    #[test]
    fn test_unknown_role_falls_back_to_user() {
        let message = Message::from(SerializableMessage {
            role: "narrator".to_string(),
            content: Some("x".to_string()),
            tool_calls: None,
            tool_call_id: None,
            cache_control: None,
        });
        assert!(matches!(message.role, Role::User));
        assert_eq!(message.content, Some("x".to_string()));
    }

    #[test]
    fn test_jsonl_roundtrip() {
        let mut tree = SessionTree::new("/test".to_string());

        let entry1 = create_test_entry("a1b2c3d4", None, "first");
        let entry2 = create_test_entry("b2c3d4e5", Some("a1b2c3d4"), "second");

        tree.append(entry1).unwrap();
        tree.append(entry2).unwrap();

        let file = tempfile::NamedTempFile::new().unwrap();
        tree.save_to_file(file.path()).unwrap();

        let loaded = SessionTree::load_from_file(file.path()).unwrap();

        assert_eq!(loaded.session_id(), tree.session_id());
        assert_eq!(loaded.leaf_id(), "b2c3d4e5");
        assert_eq!(loaded.entries().len(), 2);

        let context = loaded.build_context("b2c3d4e5").unwrap();
        assert_eq!(context.len(), 2);
    }

    #[test]
    fn test_append_to_file_roundtrip() {
        let tree = SessionTree::new("/test".to_string());
        let file = tempfile::NamedTempFile::new().unwrap();

        let first = create_test_entry("a1b2c3d4", None, "first");
        let second = create_test_entry("b2c3d4e5", Some("a1b2c3d4"), "second");
        tree.append_to_file(file.path(), &first).unwrap();
        tree.append_to_file(file.path(), &second).unwrap();

        let loaded = SessionTree::load_from_file(file.path()).unwrap();
        assert_eq!(loaded.leaf_id(), "b2c3d4e5");

        let context = loaded.build_context("b2c3d4e5").unwrap();
        assert_eq!(context.len(), 2);
        assert_eq!(context[1].content, Some("second".to_string()));
    }

    #[test]
    fn test_jsonl_format() {
        let mut tree = SessionTree::new("/test".to_string());
        tree.append(create_test_entry("a1b2c3d4", None, "test"))
            .unwrap();

        let file = tempfile::NamedTempFile::new().unwrap();
        tree.save_to_file(file.path()).unwrap();

        let content = std::fs::read_to_string(file.path()).unwrap();
        let lines: Vec<&str> = content.lines().filter(|line| !line.is_empty()).collect();
        assert_eq!(lines.len(), 2, "expected header plus one entry");

        for line in &lines {
            serde_json::from_str::<serde_json::Value>(line).expect("Line should be valid JSON");
        }

        // The first line is the session header carrying the working directory.
        let header: serde_json::Value = serde_json::from_str(lines[0]).unwrap();
        assert_eq!(header["type"], "session");
        assert_eq!(header["cwd"], "/test");
    }
}