1use limit_llm::{CacheControl, Message, Role, ToolCall};
2use rand::RngExt;
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::fs::{File, OpenOptions};
6use std::io::{BufRead, BufReader, BufWriter, Write};
7use std::path::Path;
8
/// Identifier of a [`SessionEntry`] within a session.
///
/// Any string is accepted; ids produced by [`generate_entry_id`] are exactly 8 lowercase
/// hexadecimal digits.
pub type EntryId = String;
11
/// One node of the session tree; serialized as a single JSON object (one line of a session file).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionEntry {
    /// Id of this entry, unique within the session.
    pub id: EntryId,
    /// Id of the parent entry, or `None` for a root entry.
    pub parent_id: Option<EntryId>,
    /// Creation time as a string. The header written by `SessionTree::save_to_file` uses
    /// RFC 3339 (UTC); NOTE(review): the format of other entries is not validated here.
    pub timestamp: String,
    /// Entry payload. Flattened, so its `"type"` tag and fields appear alongside `id`,
    /// `parent_id` and `timestamp` in the same JSON object.
    #[serde(flatten)]
    pub entry_type: SessionEntryType,
}
25
/// Payload of a [`SessionEntry`].
///
/// Serialized as an internally tagged enum: the snake_case variant name is stored under the
/// `"type"` key (`"session"`, `"message"`, `"compaction"`, `"branch_summary"`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum SessionEntryType {
    /// Session header: file-format `version` and the working directory `cwd`.
    /// `SessionTree::save_to_file` writes one as the first line; loading treats it as
    /// metadata rather than a tree node.
    Session { version: u32, cwd: String },
    /// A single conversation message.
    Message { message: SerializableMessage },
    /// A context compaction. `summary` is rendered into built contexts as a user message
    /// wrapped in `<summary>` tags; `first_kept_id` names an earlier entry (presumably the
    /// first one kept verbatim — see `SessionTree::build_context`).
    Compaction {
        summary: String,
        first_kept_id: EntryId,
    },
    /// Summary attached to a branch; `from_id` presumably names the entry the branch left
    /// from. Contributes no message to contexts built by `SessionTree::build_context`.
    BranchSummary { from_id: EntryId, summary: String },
}
42
/// On-disk form of a `limit_llm::Message`.
///
/// The role is stored as a lowercase string (`"user"`, `"assistant"`, `"system"`, `"tool"`);
/// optional fields that are `None` are omitted from the serialized JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializableMessage {
    /// Lowercase role name; unknown names convert back to `Role::User`.
    pub role: String,
    /// Text content of the message, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
    /// Tool calls requested by the message, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCall>>,
    /// Id of the tool call this message answers, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
    /// Prompt-cache hint carried through from the original message.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cache_control: Option<CacheControl>,
}
56
57impl From<Message> for SerializableMessage {
58 fn from(msg: Message) -> Self {
59 Self {
60 role: match msg.role {
61 Role::User => "user".to_string(),
62 Role::Assistant => "assistant".to_string(),
63 Role::System => "system".to_string(),
64 Role::Tool => "tool".to_string(),
65 },
66 content: msg.content,
67 tool_calls: msg.tool_calls,
68 tool_call_id: msg.tool_call_id,
69 cache_control: msg.cache_control,
70 }
71 }
72}
73
74impl From<SerializableMessage> for Message {
75 fn from(msg: SerializableMessage) -> Self {
76 Self {
77 role: match msg.role.as_str() {
78 "user" => Role::User,
79 "assistant" => Role::Assistant,
80 "system" => Role::System,
81 "tool" => Role::Tool,
82 _ => Role::User,
83 },
84 content: msg.content,
85 tool_calls: msg.tool_calls,
86 tool_call_id: msg.tool_call_id,
87 cache_control: msg.cache_control,
88 }
89 }
90}
91
92pub fn generate_entry_id() -> EntryId {
94 let mut rng = rand::rng();
95 format!("{:08x}", rng.random::<u32>())
96}
97
/// In-memory session history: a tree of [`SessionEntry`] nodes linked by `parent_id`.
///
/// New entries are appended under the current leaf; [`SessionTree::branch_from`] moves the
/// leaf to any existing entry so that later appends start a new branch there.
pub struct SessionTree {
    /// All entries, keyed by their id.
    entries: HashMap<EntryId, SessionEntry>,
    /// Id of the entry the next append must attach to; the empty string for a fresh tree.
    leaf_id: EntryId,
    /// Session identifier, written as the id of the `Session` header line when saving.
    session_id: String,
    /// Working directory recorded in the session header.
    cwd: String,
}
109
110impl SessionTree {
111 pub fn new(cwd: String) -> Self {
113 let session_id = uuid::Uuid::new_v4().to_string();
114 Self {
115 entries: HashMap::new(),
116 leaf_id: String::new(),
117 session_id,
118 cwd,
119 }
120 }
121
122 pub fn from_entries(
124 entries: Vec<SessionEntry>,
125 session_id: String,
126 cwd: String,
127 ) -> Result<Self, SessionTreeError> {
128 let mut by_id: HashMap<EntryId, SessionEntry> = HashMap::new();
129 let mut leaf_id = String::new();
130
131 for entry in entries {
132 leaf_id = entry.id.clone();
133 by_id.insert(entry.id.clone(), entry);
134 }
135
136 Ok(Self {
137 entries: by_id,
138 leaf_id,
139 session_id,
140 cwd,
141 })
142 }
143
144 pub fn append(&mut self, entry: SessionEntry) -> Result<(), SessionTreeError> {
146 let id = entry.id.clone();
147
148 if self.entries.is_empty() {
149 if entry.parent_id.is_some() {
150 return Err(SessionTreeError::InvalidParent {
151 expected: "none (first entry)".to_string(),
152 got: entry.parent_id.clone(),
153 });
154 }
155 } else if entry.parent_id.as_ref() != Some(&self.leaf_id) {
156 return Err(SessionTreeError::InvalidParent {
157 expected: self.leaf_id.clone(),
158 got: entry.parent_id.clone(),
159 });
160 }
161
162 self.entries.insert(id.clone(), entry);
163 self.leaf_id = id;
164 Ok(())
165 }
166
167 pub fn build_context(&self, leaf_id: &str) -> Result<Vec<Message>, SessionTreeError> {
169 let mut path = Vec::new();
170 let mut current_id = Some(leaf_id.to_string());
171
172 while let Some(id) = current_id {
173 let entry = self
174 .entries
175 .get(&id)
176 .ok_or(SessionTreeError::EntryNotFound(id))?;
177
178 if let SessionEntryType::Compaction { first_kept_id, .. } = &entry.entry_type {
180 current_id = Some(first_kept_id.clone());
181 path.push(entry.clone());
182 continue;
183 }
184
185 current_id = entry.parent_id.clone();
186 path.push(entry.clone());
187 }
188
189 path.reverse();
190
191 let messages: Vec<Message> = path
192 .into_iter()
193 .filter_map(|entry| match entry.entry_type {
194 SessionEntryType::Message { message } => Some(Message::from(message)),
195 SessionEntryType::Compaction { summary, .. } => Some(Message {
196 role: Role::User,
197 content: Some(format!("<summary>\n{}\n</summary>", summary)),
198 tool_calls: None,
199 tool_call_id: None,
200 cache_control: None,
201 }),
202 SessionEntryType::Session { .. } => None,
203 SessionEntryType::BranchSummary { .. } => None,
204 })
205 .collect();
206
207 Ok(messages)
208 }
209
210 pub fn branch_from(&mut self, entry_id: &str) -> Result<EntryId, SessionTreeError> {
212 if !self.entries.contains_key(entry_id) {
213 return Err(SessionTreeError::EntryNotFound(entry_id.to_string()));
214 }
215
216 self.leaf_id = entry_id.to_string();
217 Ok(entry_id.to_string())
218 }
219
220 pub fn leaf_id(&self) -> &str {
222 &self.leaf_id
223 }
224
225 pub fn entries(&self) -> Vec<&SessionEntry> {
227 self.entries.values().collect()
228 }
229
230 pub fn session_id(&self) -> &str {
232 &self.session_id
233 }
234
235 pub fn save_to_file(&self, path: &Path) -> Result<(), SessionTreeError> {
237 let file = File::create(path)?;
238 let mut writer = BufWriter::new(file);
239
240 let header = SessionEntry {
241 id: self.session_id.clone(),
242 parent_id: None,
243 timestamp: chrono::Utc::now().to_rfc3339(),
244 entry_type: SessionEntryType::Session {
245 version: 1,
246 cwd: self.cwd.clone(),
247 },
248 };
249 writeln!(writer, "{}", serde_json::to_string(&header)?)?;
250
251 let sorted = self.sort_entries()?;
252 for entry in sorted {
253 writeln!(writer, "{}", serde_json::to_string(&entry)?)?;
254 }
255
256 writer.flush()?;
257 Ok(())
258 }
259
260 pub fn load_from_file(path: &Path) -> Result<Self, SessionTreeError> {
262 let file = File::open(path)?;
263 let reader = BufReader::new(file);
264
265 let mut entries = Vec::new();
266 let mut session_id = String::new();
267 let mut cwd = String::new();
268
269 for line in reader.lines() {
270 let line: String = line?;
271 if line.trim().is_empty() {
272 continue;
273 }
274
275 let entry: SessionEntry = serde_json::from_str(&line)?;
276
277 if let SessionEntryType::Session { version: _, cwd: c } = &entry.entry_type {
278 session_id = entry.id.clone();
279 cwd = c.clone();
280 } else {
281 entries.push(entry);
282 }
283 }
284
285 Self::from_entries(entries, session_id, cwd)
286 }
287
288 pub fn append_to_file(
290 &self,
291 path: &Path,
292 entry: &SessionEntry,
293 ) -> Result<(), SessionTreeError> {
294 let mut file = OpenOptions::new().create(true).append(true).open(path)?;
295
296 writeln!(file, "{}", serde_json::to_string(entry)?)?;
297 Ok(())
298 }
299
300 fn sort_entries(&self) -> Result<Vec<SessionEntry>, SessionTreeError> {
301 let mut sorted = Vec::new();
302 let mut visited: std::collections::HashSet<EntryId> = std::collections::HashSet::new();
303
304 let roots: Vec<_> = self
305 .entries
306 .values()
307 .filter(|e| e.parent_id.is_none())
308 .collect();
309
310 for root in roots {
311 self.sort_dfs(root, &mut sorted, &mut visited)?;
312 }
313
314 Ok(sorted)
315 }
316
317 fn sort_dfs(
318 &self,
319 entry: &SessionEntry,
320 sorted: &mut Vec<SessionEntry>,
321 visited: &mut std::collections::HashSet<EntryId>,
322 ) -> Result<(), SessionTreeError> {
323 if visited.contains(&entry.id) {
324 return Ok(());
325 }
326
327 visited.insert(entry.id.clone());
328 sorted.push(entry.clone());
329
330 for child in self.entries.values() {
331 if child.parent_id.as_ref() == Some(&entry.id) {
332 self.sort_dfs(child, sorted, visited)?;
333 }
334 }
335
336 Ok(())
337 }
338}
339
/// Errors produced by [`SessionTree`] operations.
#[derive(Debug, thiserror::Error)]
pub enum SessionTreeError {
    /// A referenced entry id could not be resolved.
    #[error("Entry not found: {0}")]
    EntryNotFound(String),
    /// An entry's parent link was not what the operation required.
    #[error("Invalid parent: expected {expected:?}, got {got:?}")]
    InvalidParent {
        /// Description of what was required (e.g. the current leaf id).
        expected: String,
        /// The parent id actually encountered, if any.
        got: Option<String>,
    },
    /// A file operation failed; converted automatically from `std::io::Error` by `?`.
    #[error("IO error: {0}")]
    IoError(#[from] std::io::Error),
    /// JSON serialization or parsing failed; converted automatically from `serde_json::Error`.
    #[error("JSON error: {0}")]
    JsonError(#[from] serde_json::Error),
}
354
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a text message entry with an explicit role and timestamp.
    fn message_entry(
        id: &str,
        parent_id: Option<&str>,
        timestamp: &str,
        role: Role,
        text: &str,
    ) -> SessionEntry {
        SessionEntry {
            id: id.to_owned(),
            parent_id: parent_id.map(str::to_owned),
            timestamp: timestamp.to_owned(),
            entry_type: SessionEntryType::Message {
                message: Message {
                    role,
                    content: Some(text.to_owned()),
                    tool_calls: None,
                    tool_call_id: None,
                    cache_control: None,
                }
                .into(),
            },
        }
    }

    /// Shorthand for a user message entry with a fixed timestamp.
    fn create_test_entry(id: &str, parent_id: Option<&str>, content: &str) -> SessionEntry {
        message_entry(id, parent_id, "2024-01-01T00:00:00Z", Role::User, content)
    }

    #[test]
    fn test_session_entry_serialization() {
        let original = message_entry("a1b2c3d4", None, "2024-01-01T00:00:00Z", Role::User, "Hello");

        let json = serde_json::to_string(&original).unwrap();
        for fragment in ["\"id\":\"a1b2c3d4\"", "\"type\":\"message\""] {
            assert!(json.contains(fragment));
        }

        let round_tripped: SessionEntry = serde_json::from_str(&json).unwrap();
        assert_eq!(round_tripped.id, original.id);
    }

    #[test]
    fn test_build_context_linear() {
        let mut tree = SessionTree::new("/test".to_string());
        tree.append(message_entry(
            "a1b2c3d4",
            None,
            "2024-01-01T00:00:00Z",
            Role::User,
            "Hello",
        ))
        .unwrap();
        tree.append(message_entry(
            "b2c3d4e5",
            Some("a1b2c3d4"),
            "2024-01-01T00:01:00Z",
            Role::Assistant,
            "Hi!",
        ))
        .unwrap();

        let messages = tree.build_context("b2c3d4e5").unwrap();
        let contents: Vec<Option<&str>> = messages.iter().map(|m| m.content.as_deref()).collect();
        assert_eq!(contents, [Some("Hello"), Some("Hi!")]);
    }

    #[test]
    fn test_build_context_with_branching() {
        let mut tree = SessionTree::new("/test".to_string());
        for entry in [
            create_test_entry("root", None, "root content"),
            create_test_entry("a", Some("root"), "a content"),
            create_test_entry("b", Some("a"), "b content"),
        ] {
            tree.append(entry).unwrap();
        }

        // Rewind to "a" and grow a sibling of "b".
        tree.branch_from("a").unwrap();
        tree.append(create_test_entry("c", Some("a"), "c content"))
            .unwrap();

        assert_eq!(tree.build_context("b").unwrap().len(), 3);

        let context_c = tree.build_context("c").unwrap();
        assert_eq!(context_c.len(), 3);
        assert_eq!(context_c[2].content.as_deref(), Some("c content"));
    }

    #[test]
    fn test_jsonl_roundtrip() {
        let mut tree = SessionTree::new("/test".to_string());
        tree.append(create_test_entry("a1b2c3d4", None, "first"))
            .unwrap();
        tree.append(create_test_entry("b2c3d4e5", Some("a1b2c3d4"), "second"))
            .unwrap();

        let file = tempfile::NamedTempFile::new().unwrap();
        tree.save_to_file(file.path()).unwrap();
        let loaded = SessionTree::load_from_file(file.path()).unwrap();

        assert_eq!(loaded.leaf_id(), "b2c3d4e5");
        assert_eq!(loaded.entries().len(), 2);
        assert_eq!(loaded.build_context("b2c3d4e5").unwrap().len(), 2);
    }

    #[test]
    fn test_jsonl_format() {
        let mut tree = SessionTree::new("/test".to_string());
        tree.append(create_test_entry("a1b2c3d4", None, "test"))
            .unwrap();

        let file = tempfile::NamedTempFile::new().unwrap();
        tree.save_to_file(file.path()).unwrap();

        let content = std::fs::read_to_string(file.path()).unwrap();
        content
            .lines()
            .filter(|line| !line.is_empty())
            .for_each(|line| {
                serde_json::from_str::<serde_json::Value>(line).expect("Line should be valid JSON");
            });
    }
}