1use limit_llm::{CacheControl, Message, Role, ToolCall};
2use rand::RngExt;
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::fs::{File, OpenOptions};
6use std::io::{BufRead, BufReader, BufWriter, Write};
7use std::path::Path;
8
/// Identifier of a [`SessionEntry`]; other entries refer to it through
/// `parent_id` and similar fields.
pub type EntryId = std::string::String;
11
/// One line of a session JSONL file: a single node of the conversation tree.
///
/// NOTE: the field order and serde attributes below define the on-disk JSON
/// layout; change them only together with a format/version bump.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionEntry {
    /// Identifier of this entry; other entries point at it via `parent_id`.
    pub id: EntryId,
    /// Id of the parent entry, or `None` for a root entry.
    pub parent_id: Option<EntryId>,
    /// Creation time as a string (the session header written by this module
    /// uses RFC 3339; other entries are whatever the caller supplies).
    pub timestamp: String,
    /// Variant-specific payload. Flattened, so its `"type"` tag and fields sit
    /// at the top level of the JSON object next to `id`/`parent_id`.
    #[serde(flatten)]
    pub entry_type: SessionEntryType,
}
25
/// Payload of a [`SessionEntry`].
///
/// Serialized as an internally tagged enum: a `"type"` field holds the
/// snake_case variant name (`"session"`, `"message"`, `"compaction"`,
/// `"branch_summary"`) and the variant's fields sit beside it.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum SessionEntryType {
    /// Session header line: file format version and the session's working directory.
    Session { version: u32, cwd: String },
    /// A single chat message.
    Message { message: SerializableMessage },
    /// A summary standing in for older history.
    Compaction {
        /// Summary text of the compacted history.
        summary: String,
        /// Id of the first entry kept verbatim (per the field name); how it is
        /// consumed is defined by `SessionTree::build_context`.
        first_kept_id: EntryId,
    },
    /// A summary attached to a branch point. NOTE(review): `from_id` presumably
    /// names the entry the summarized branch started from — confirm with the producer.
    BranchSummary { from_id: EntryId, summary: String },
}
42
/// Serde-friendly mirror of [`Message`] used when persisting sessions.
///
/// The role is stored as its lowercase name and the content as plain text;
/// fields that are `None` are omitted from the JSON entirely.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializableMessage {
    /// `"user"`, `"assistant"`, `"system"` or `"tool"`.
    pub role: String,
    /// Message text, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
    /// Tool calls requested by the message, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCall>>,
    /// Id of the tool call this message answers, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
    /// Prompt-cache hint carried over from the original message, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cache_control: Option<CacheControl>,
}
56
57impl From<Message> for SerializableMessage {
58 fn from(msg: Message) -> Self {
59 Self {
60 role: match msg.role {
61 Role::User => "user".to_string(),
62 Role::Assistant => "assistant".to_string(),
63 Role::System => "system".to_string(),
64 Role::Tool => "tool".to_string(),
65 },
66 content: msg.content.map(|c| c.to_text()),
67 tool_calls: msg.tool_calls,
68 tool_call_id: msg.tool_call_id,
69 cache_control: msg.cache_control,
70 }
71 }
72}
73
74impl From<SerializableMessage> for Message {
75 fn from(msg: SerializableMessage) -> Self {
76 Self {
77 role: match msg.role.as_str() {
78 "user" => Role::User,
79 "assistant" => Role::Assistant,
80 "system" => Role::System,
81 "tool" => Role::Tool,
82 _ => Role::User,
83 },
84 content: msg.content.map(limit_llm::MessageContent::text),
85 tool_calls: msg.tool_calls,
86 tool_call_id: msg.tool_call_id,
87 cache_control: msg.cache_control,
88 }
89 }
90}
91
92pub fn generate_entry_id() -> EntryId {
94 let mut rng = rand::rng();
95 format!("{:08x}", rng.random::<u32>())
96}
97
98pub struct SessionTree {
100 entries: HashMap<EntryId, SessionEntry>,
102 leaf_id: EntryId,
104 session_id: String,
106 cwd: String,
108}
109
110impl SessionTree {
111 pub fn new(cwd: String) -> Self {
113 let session_id = uuid::Uuid::new_v4().to_string();
114 Self {
115 entries: HashMap::new(),
116 leaf_id: String::new(),
117 session_id,
118 cwd,
119 }
120 }
121
122 pub fn from_entries(
124 entries: Vec<SessionEntry>,
125 session_id: String,
126 cwd: String,
127 ) -> Result<Self, SessionTreeError> {
128 let mut by_id: HashMap<EntryId, SessionEntry> = HashMap::new();
129 let mut leaf_id = String::new();
130
131 for entry in entries {
132 leaf_id = entry.id.clone();
133 by_id.insert(entry.id.clone(), entry);
134 }
135
136 Ok(Self {
137 entries: by_id,
138 leaf_id,
139 session_id,
140 cwd,
141 })
142 }
143
144 pub fn append(&mut self, entry: SessionEntry) -> Result<(), SessionTreeError> {
146 let id = entry.id.clone();
147
148 if self.entries.is_empty() {
149 if entry.parent_id.is_some() {
150 return Err(SessionTreeError::InvalidParent {
151 expected: "none (first entry)".to_string(),
152 got: entry.parent_id.clone(),
153 });
154 }
155 } else if entry.parent_id.as_ref() != Some(&self.leaf_id) {
156 return Err(SessionTreeError::InvalidParent {
157 expected: self.leaf_id.clone(),
158 got: entry.parent_id.clone(),
159 });
160 }
161
162 self.entries.insert(id.clone(), entry);
163 self.leaf_id = id;
164 Ok(())
165 }
166
167 pub fn build_context(&self, leaf_id: &str) -> Result<Vec<Message>, SessionTreeError> {
169 let mut path = Vec::new();
170 let mut current_id = Some(leaf_id.to_string());
171
172 while let Some(id) = current_id {
173 let entry = self
174 .entries
175 .get(&id)
176 .ok_or(SessionTreeError::EntryNotFound(id))?;
177
178 if let SessionEntryType::Compaction { first_kept_id, .. } = &entry.entry_type {
180 current_id = Some(first_kept_id.clone());
181 path.push(entry.clone());
182 continue;
183 }
184
185 current_id = entry.parent_id.clone();
186 path.push(entry.clone());
187 }
188
189 path.reverse();
190
191 let messages: Vec<Message> = path
192 .into_iter()
193 .filter_map(|entry| match entry.entry_type {
194 SessionEntryType::Message { message } => Some(Message::from(message)),
195 SessionEntryType::Compaction { summary, .. } => Some(Message {
196 role: Role::User,
197 content: Some(limit_llm::MessageContent::text(format!(
198 "<summary>\n{}\n</summary>",
199 summary
200 ))),
201 tool_calls: None,
202 tool_call_id: None,
203 cache_control: None,
204 }),
205 SessionEntryType::Session { .. } => None,
206 SessionEntryType::BranchSummary { .. } => None,
207 })
208 .collect();
209
210 Ok(messages)
211 }
212
213 pub fn branch_from(&mut self, entry_id: &str) -> Result<EntryId, SessionTreeError> {
215 if !self.entries.contains_key(entry_id) {
216 return Err(SessionTreeError::EntryNotFound(entry_id.to_string()));
217 }
218
219 self.leaf_id = entry_id.to_string();
220 Ok(entry_id.to_string())
221 }
222
223 pub fn leaf_id(&self) -> &str {
225 &self.leaf_id
226 }
227
228 pub fn entries(&self) -> Vec<&SessionEntry> {
230 self.entries.values().collect()
231 }
232
233 pub fn session_id(&self) -> &str {
235 &self.session_id
236 }
237
238 pub fn save_to_file(&self, path: &Path) -> Result<(), SessionTreeError> {
240 let file = File::create(path)?;
241 let mut writer = BufWriter::new(file);
242
243 let header = SessionEntry {
244 id: self.session_id.clone(),
245 parent_id: None,
246 timestamp: chrono::Utc::now().to_rfc3339(),
247 entry_type: SessionEntryType::Session {
248 version: 1,
249 cwd: self.cwd.clone(),
250 },
251 };
252 writeln!(writer, "{}", serde_json::to_string(&header)?)?;
253
254 let sorted = self.sort_entries()?;
255 for entry in sorted {
256 writeln!(writer, "{}", serde_json::to_string(&entry)?)?;
257 }
258
259 writer.flush()?;
260 Ok(())
261 }
262
263 pub fn load_from_file(path: &Path) -> Result<Self, SessionTreeError> {
265 let file = File::open(path)?;
266 let reader = BufReader::new(file);
267
268 let mut entries = Vec::new();
269 let mut session_id = String::new();
270 let mut cwd = String::new();
271
272 for line in reader.lines() {
273 let line: String = line?;
274 if line.trim().is_empty() {
275 continue;
276 }
277
278 let entry: SessionEntry = serde_json::from_str(&line)?;
279
280 if let SessionEntryType::Session { version: _, cwd: c } = &entry.entry_type {
281 session_id = entry.id.clone();
282 cwd = c.clone();
283 } else {
284 entries.push(entry);
285 }
286 }
287
288 Self::from_entries(entries, session_id, cwd)
289 }
290
291 pub fn append_to_file(
293 &self,
294 path: &Path,
295 entry: &SessionEntry,
296 ) -> Result<(), SessionTreeError> {
297 let mut file = OpenOptions::new().create(true).append(true).open(path)?;
298
299 writeln!(file, "{}", serde_json::to_string(entry)?)?;
300 Ok(())
301 }
302
303 fn sort_entries(&self) -> Result<Vec<SessionEntry>, SessionTreeError> {
304 let mut sorted = Vec::new();
305 let mut visited: std::collections::HashSet<EntryId> = std::collections::HashSet::new();
306
307 let roots: Vec<_> = self
308 .entries
309 .values()
310 .filter(|e| e.parent_id.is_none())
311 .collect();
312
313 for root in roots {
314 self.sort_dfs(root, &mut sorted, &mut visited)?;
315 }
316
317 Ok(sorted)
318 }
319
320 fn sort_dfs(
321 &self,
322 entry: &SessionEntry,
323 sorted: &mut Vec<SessionEntry>,
324 visited: &mut std::collections::HashSet<EntryId>,
325 ) -> Result<(), SessionTreeError> {
326 if visited.contains(&entry.id) {
327 return Ok(());
328 }
329
330 visited.insert(entry.id.clone());
331 sorted.push(entry.clone());
332
333 for child in self.entries.values() {
334 if child.parent_id.as_ref() == Some(&entry.id) {
335 self.sort_dfs(child, sorted, visited)?;
336 }
337 }
338
339 Ok(())
340 }
341}
342
343#[derive(Debug, thiserror::Error)]
344pub enum SessionTreeError {
345 #[error("Entry not found: {0}")]
346 EntryNotFound(String),
347 #[error("Invalid parent: expected {expected:?}, got {got:?}")]
348 InvalidParent {
349 expected: String,
350 got: Option<String>,
351 },
352 #[error("IO error: {0}")]
353 IoError(#[from] std::io::Error),
354 #[error("JSON error: {0}")]
355 JsonError(#[from] serde_json::Error),
356}
357
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a message entry with an explicit timestamp and role.
    fn entry_with(
        id: &str,
        parent_id: Option<&str>,
        timestamp: &str,
        role: Role,
        text: &str,
    ) -> SessionEntry {
        let message = Message {
            role,
            content: Some(limit_llm::MessageContent::text(text)),
            tool_calls: None,
            tool_call_id: None,
            cache_control: None,
        };
        SessionEntry {
            id: id.to_owned(),
            parent_id: parent_id.map(str::to_owned),
            timestamp: timestamp.to_owned(),
            entry_type: SessionEntryType::Message {
                message: SerializableMessage::from(message),
            },
        }
    }

    /// Builds a user message entry with a fixed timestamp.
    fn create_test_entry(id: &str, parent_id: Option<&str>, content: &str) -> SessionEntry {
        entry_with(id, parent_id, "2024-01-01T00:00:00Z", Role::User, content)
    }

    /// Plain-text content of a message; panics if it has none.
    fn text_of(message: &Message) -> String {
        message
            .content
            .as_ref()
            .expect("message has content")
            .to_text()
    }

    #[test]
    fn test_session_entry_serialization() {
        let original = SessionEntry {
            id: "a1b2c3d4".into(),
            parent_id: None,
            timestamp: "2024-01-01T00:00:00Z".into(),
            entry_type: SessionEntryType::Message {
                message: SerializableMessage {
                    role: "user".into(),
                    content: Some("Hello".into()),
                    tool_calls: None,
                    tool_call_id: None,
                    cache_control: None,
                },
            },
        };

        let encoded = serde_json::to_string(&original).unwrap();
        for fragment in ["\"id\":\"a1b2c3d4\"", "\"type\":\"message\""] {
            assert!(encoded.contains(fragment));
        }

        let decoded: SessionEntry = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded.id, original.id);
    }

    #[test]
    fn test_build_context_linear() {
        let mut tree = SessionTree::new("/test".to_string());
        tree.append(entry_with(
            "a1b2c3d4",
            None,
            "2024-01-01T00:00:00Z",
            Role::User,
            "Hello",
        ))
        .unwrap();
        tree.append(entry_with(
            "b2c3d4e5",
            Some("a1b2c3d4"),
            "2024-01-01T00:01:00Z",
            Role::Assistant,
            "Hi!",
        ))
        .unwrap();

        let texts: Vec<String> = tree
            .build_context("b2c3d4e5")
            .unwrap()
            .iter()
            .map(text_of)
            .collect();
        assert_eq!(texts, ["Hello", "Hi!"]);
    }

    #[test]
    fn test_build_context_with_branching() {
        let mut tree = SessionTree::new("/test".to_string());

        // root -> a -> b first, then branch at `a` and add the sibling `c`.
        for entry in [
            create_test_entry("root", None, "root content"),
            create_test_entry("a", Some("root"), "a content"),
            create_test_entry("b", Some("a"), "b content"),
        ] {
            tree.append(entry).unwrap();
        }
        tree.branch_from("a").unwrap();
        tree.append(create_test_entry("c", Some("a"), "c content"))
            .unwrap();

        assert_eq!(tree.build_context("b").unwrap().len(), 3);

        let on_c = tree.build_context("c").unwrap();
        assert_eq!(on_c.len(), 3);
        assert_eq!(text_of(&on_c[2]), "c content");
    }

    #[test]
    fn test_jsonl_roundtrip() {
        let mut tree = SessionTree::new("/test".to_string());
        tree.append(create_test_entry("a1b2c3d4", None, "first"))
            .unwrap();
        tree.append(create_test_entry("b2c3d4e5", Some("a1b2c3d4"), "second"))
            .unwrap();

        let file = tempfile::NamedTempFile::new().unwrap();
        tree.save_to_file(file.path()).unwrap();
        let reloaded = SessionTree::load_from_file(file.path()).unwrap();

        assert_eq!(reloaded.leaf_id(), "b2c3d4e5");
        assert_eq!(reloaded.entries().len(), 2);
        assert_eq!(reloaded.build_context("b2c3d4e5").unwrap().len(), 2);
    }

    #[test]
    fn test_jsonl_format() {
        let mut tree = SessionTree::new("/test".to_string());
        tree.append(create_test_entry("a1b2c3d4", None, "test"))
            .unwrap();

        let file = tempfile::NamedTempFile::new().unwrap();
        tree.save_to_file(file.path()).unwrap();

        let content = std::fs::read_to_string(file.path()).unwrap();
        content
            .lines()
            .filter(|line| !line.is_empty())
            .for_each(|line| {
                serde_json::from_str::<serde_json::Value>(line).expect("Line should be valid JSON");
            });
    }
}