1use crate::error::CliError;
2use crate::session_tree::{SessionEntry, SessionTree, SessionTreeError};
3use bincode::{deserialize, serialize};
4use chrono::{DateTime, Utc};
5use limit_llm::Message;
6use rusqlite::{params, Connection};
7use serde::{Deserialize, Serialize};
8use std::fs::{self, OpenOptions};
9use std::io::Write;
10use std::path::PathBuf;
11use tracing::instrument;
12use uuid::Uuid;
13
/// Format version stamped into every `.bin` session file written by
/// `SessionManager::save_session`; files with a newer version are rejected on load.
const CURRENT_VERSION: u32 = 2_u32;
15
16#[derive(Debug, Clone)]
17pub struct SessionInfo {
18 pub id: String,
19 #[allow(dead_code)]
20 pub created_at: DateTime<Utc>,
21 #[allow(dead_code)]
22 pub last_accessed: DateTime<Utc>,
23 pub message_count: usize,
24 pub total_input_tokens: u64,
25 pub total_output_tokens: u64,
26}
27
28#[derive(Debug, Clone, Serialize, Deserialize)]
29struct PersistedState {
30 version: u32,
31 messages: Vec<PersistedMessage>,
32}
33
34#[derive(Debug, Clone, Serialize, Deserialize)]
35struct PersistedMessage {
36 role: PersistedRole,
37 content: Option<String>,
38 tool_calls: Option<Vec<limit_llm::ToolCall>>,
39 tool_call_id: Option<String>,
40 cache_control: Option<limit_llm::CacheControl>,
41}
42
43#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
44enum PersistedRole {
45 User,
46 Assistant,
47 System,
48 Tool,
49}
50
51impl From<PersistedRole> for limit_llm::Role {
52 fn from(role: PersistedRole) -> Self {
53 match role {
54 PersistedRole::User => limit_llm::Role::User,
55 PersistedRole::Assistant => limit_llm::Role::Assistant,
56 PersistedRole::System => limit_llm::Role::System,
57 PersistedRole::Tool => limit_llm::Role::Tool,
58 }
59 }
60}
61
62impl From<limit_llm::Role> for PersistedRole {
63 fn from(role: limit_llm::Role) -> Self {
64 match role {
65 limit_llm::Role::User => PersistedRole::User,
66 limit_llm::Role::Assistant => PersistedRole::Assistant,
67 limit_llm::Role::System => PersistedRole::System,
68 limit_llm::Role::Tool => PersistedRole::Tool,
69 }
70 }
71}
72
73impl From<PersistedMessage> for Message {
74 fn from(msg: PersistedMessage) -> Self {
75 Message {
76 role: msg.role.into(),
77 content: msg.content,
78 tool_calls: msg.tool_calls,
79 tool_call_id: msg.tool_call_id,
80 cache_control: msg.cache_control,
81 }
82 }
83}
84
85impl From<Message> for PersistedMessage {
86 fn from(msg: Message) -> Self {
87 PersistedMessage {
88 role: msg.role.into(),
89 content: msg.content,
90 tool_calls: msg.tool_calls,
91 tool_call_id: msg.tool_call_id,
92 cache_control: msg.cache_control,
93 }
94 }
95}
96
/// Persists chat sessions.
///
/// Metadata is kept in a SQLite table; message payloads are kept as
/// `<id>.bin` (bincode) or `<id>.jsonl` (session tree) files.
pub struct SessionManager {
    /// Path of the SQLite metadata database.
    db_path: PathBuf,
    /// Directory holding the per-session payload files.
    sessions_dir: PathBuf,
}
101
102impl SessionManager {
103 pub fn new() -> Result<Self, CliError> {
104 let home_dir = dirs::home_dir()
106 .ok_or_else(|| CliError::ConfigError("Failed to get home directory".to_string()))?;
107 let limit_dir = home_dir.join(".limit");
108 fs::create_dir_all(&limit_dir).map_err(|e| {
109 CliError::ConfigError(format!("Failed to create .limit directory: {}", e))
110 })?;
111
112 let sessions_dir = limit_dir.join("sessions");
113 fs::create_dir_all(&sessions_dir).map_err(|e| {
114 CliError::ConfigError(format!("Failed to create sessions directory: {}", e))
115 })?;
116
117 let db_path = limit_dir.join("session.db");
118
119 Self::with_paths(db_path, sessions_dir)
120 }
121
122 pub fn with_paths(db_path: PathBuf, sessions_dir: PathBuf) -> Result<Self, CliError> {
124 fs::create_dir_all(&sessions_dir).map_err(|e| {
125 CliError::ConfigError(format!("Failed to create sessions directory: {}", e))
126 })?;
127
128 let session_manager = Self {
129 db_path,
130 sessions_dir,
131 };
132
133 session_manager.init_db()?;
134 Ok(session_manager)
135 }
136
137 pub fn init_db(&self) -> Result<(), CliError> {
138 let conn = Connection::open(&self.db_path)
139 .map_err(|e| CliError::ConfigError(format!("Failed to open database: {}", e)))?;
140
141 conn.execute(
142 "CREATE TABLE IF NOT EXISTS sessions (
143 id TEXT PRIMARY KEY,
144 created_at TEXT NOT NULL,
145 last_accessed TEXT NOT NULL,
146 message_count INTEGER NOT NULL,
147 total_input_tokens INTEGER NOT NULL DEFAULT 0,
148 total_output_tokens INTEGER NOT NULL DEFAULT 0
149 )",
150 [],
151 )
152 .map_err(|e| CliError::ConfigError(format!("Failed to create sessions table: {}", e)))?;
153
154 conn.execute(
155 "CREATE INDEX IF NOT EXISTS idx_last_accessed ON sessions(last_accessed DESC)",
156 [],
157 )
158 .map_err(|e| CliError::ConfigError(format!("Failed to create index: {}", e)))?;
159
160 Ok(())
161 }
162
163 fn get_connection(&self) -> Result<Connection, CliError> {
164 Connection::open(&self.db_path)
165 .map_err(|e| CliError::ConfigError(format!("Failed to open database: {}", e)))
166 }
167
168 pub fn create_new_session(&self) -> Result<String, CliError> {
169 let session_id = Uuid::new_v4().to_string();
170 let now = Utc::now().to_rfc3339();
171
172 let conn = self.get_connection()?;
173 conn.execute(
174 "INSERT INTO sessions (id, created_at, last_accessed, message_count, total_input_tokens, total_output_tokens) VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
175 params![&session_id, &now, &now, 0, 0, 0],
176 )
177 .map_err(|e| CliError::ConfigError(format!("Failed to create session: {}", e)))?;
178
179 Ok(session_id)
180 }
181
182 #[instrument(skip(self, messages))]
183 pub fn save_session(
184 &self,
185 session_id: &str,
186 messages: &[Message],
187 total_input_tokens: u64,
188 total_output_tokens: u64,
189 ) -> Result<(), CliError> {
190 let file_path = self.sessions_dir.join(format!("{}.bin", session_id));
191
192 fs::create_dir_all(&self.sessions_dir).map_err(|e| {
193 CliError::ConfigError(format!("Failed to create sessions directory: {}", e))
194 })?;
195
196 let persisted_messages: Vec<PersistedMessage> =
197 messages.iter().cloned().map(|m| m.into()).collect();
198
199 let state = PersistedState {
200 version: CURRENT_VERSION,
201 messages: persisted_messages,
202 };
203
204 let serialized = serialize(&state)
205 .map_err(|e| CliError::ConfigError(format!("Failed to serialize messages: {}", e)))?;
206
207 fs::write(&file_path, serialized)
208 .map_err(|e| CliError::ConfigError(format!("Failed to write session file: {}", e)))?;
209
210 let now = Utc::now().to_rfc3339();
211 let conn = self.get_connection()?;
212 conn.execute(
213 "UPDATE sessions SET last_accessed = ?1, message_count = ?2, total_input_tokens = ?3, total_output_tokens = ?4 WHERE id = ?5",
214 params![&now, messages.len() as i64, total_input_tokens as i64, total_output_tokens as i64, session_id],
215 )
216 .map_err(|e| CliError::ConfigError(format!("Failed to update session metadata: {}", e)))?;
217
218 Ok(())
219 }
220
221 #[instrument(skip(self))]
222 pub fn load_session(&self, session_id: &str) -> Result<Vec<Message>, CliError> {
223 let file_path = self.sessions_dir.join(format!("{}.bin", session_id));
224
225 let data = fs::read(&file_path)
226 .map_err(|e| CliError::ConfigError(format!("Failed to read session file: {}", e)))?;
227
228 let state: PersistedState = deserialize(&data)
229 .map_err(|e| CliError::ConfigError(format!("Failed to deserialize messages: {}", e)))?;
230
231 if state.version > CURRENT_VERSION {
233 return Err(CliError::ConfigError(format!(
234 "Version mismatch: expected {}, found {}",
235 CURRENT_VERSION, state.version
236 )));
237 }
238
239 let now = Utc::now().to_rfc3339();
240 let conn = self.get_connection()?;
241 conn.execute(
242 "UPDATE sessions SET last_accessed = ?1 WHERE id = ?2",
243 params![&now, session_id],
244 )
245 .map_err(|e| CliError::ConfigError(format!("Failed to update last_accessed: {}", e)))?;
246
247 let messages: Vec<Message> = state
249 .messages
250 .into_iter()
251 .map(Message::from)
252 .filter(|m| m.role != limit_llm::Role::System)
253 .collect();
254
255 Ok(messages)
256 }
257
258 pub fn list_sessions(&self) -> Result<Vec<SessionInfo>, CliError> {
259 let conn = self.get_connection()?;
260
261 let mut stmt = conn
262 .prepare("SELECT id, created_at, last_accessed, message_count, total_input_tokens, total_output_tokens FROM sessions ORDER BY last_accessed DESC")
263 .map_err(|e| CliError::ConfigError(format!("Failed to prepare query: {}", e)))?;
264
265 let session_iter = stmt
266 .query_map([], |row| {
267 Ok(SessionInfo {
268 id: row.get(0)?,
269 created_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(1)?)
270 .unwrap()
271 .with_timezone(&Utc),
272 last_accessed: DateTime::parse_from_rfc3339(&row.get::<_, String>(2)?)
273 .unwrap()
274 .with_timezone(&Utc),
275 message_count: row.get::<_, i64>(3)? as usize,
276 total_input_tokens: row.get::<_, i64>(4)? as u64,
277 total_output_tokens: row.get::<_, i64>(5)? as u64,
278 })
279 })
280 .map_err(|e| CliError::ConfigError(format!("Failed to query sessions: {}", e)))?;
281
282 let mut sessions = Vec::new();
283 for session in session_iter {
284 sessions.push(
285 session.map_err(|e| {
286 CliError::ConfigError(format!("Failed to parse session: {}", e))
287 })?,
288 );
289 }
290
291 Ok(sessions)
292 }
293
294 pub fn get_last_session(&self) -> Result<Option<SessionInfo>, CliError> {
295 let conn = self.get_connection()?;
296
297 let mut stmt = conn
298 .prepare("SELECT id, created_at, last_accessed, message_count, total_input_tokens, total_output_tokens FROM sessions ORDER BY last_accessed DESC LIMIT 1")
299 .map_err(|e| CliError::ConfigError(format!("Failed to prepare query: {}", e)))?;
300
301 let mut session_iter = stmt
302 .query_map([], |row| {
303 Ok(SessionInfo {
304 id: row.get(0)?,
305 created_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(1)?)
306 .unwrap()
307 .with_timezone(&Utc),
308 last_accessed: DateTime::parse_from_rfc3339(&row.get::<_, String>(2)?)
309 .unwrap()
310 .with_timezone(&Utc),
311 message_count: row.get::<_, i64>(3)? as usize,
312 total_input_tokens: row.get::<_, i64>(4)? as u64,
313 total_output_tokens: row.get::<_, i64>(5)? as u64,
314 })
315 })
316 .map_err(|e| CliError::ConfigError(format!("Failed to query last session: {}", e)))?;
317
318 match session_iter.next() {
319 Some(session) => Ok(Some(session.map_err(|e| {
320 CliError::ConfigError(format!("Failed to parse session: {}", e))
321 })?)),
322 None => Ok(None),
323 }
324 }
325
326 #[allow(dead_code)]
328 pub fn update_session_tokens(
329 &self,
330 session_id: &str,
331 input_tokens: u64,
332 output_tokens: u64,
333 ) -> Result<(), CliError> {
334 let conn = self.get_connection()?;
335 conn.execute(
336 "UPDATE sessions SET total_input_tokens = total_input_tokens + ?1, total_output_tokens = total_output_tokens + ?2 WHERE id = ?3",
337 params![input_tokens as i64, output_tokens as i64, session_id],
338 )
339 .map_err(|e| CliError::ConfigError(format!("Failed to update session tokens: {}", e)))?;
340 Ok(())
341 }
342
343 pub fn create_tree_session(
345 &self,
346 session_id: &str,
347 cwd: String,
348 ) -> Result<SessionTree, CliError> {
349 let tree = SessionTree::new(cwd);
350 let file_path = self.tree_session_path(session_id);
351
352 if let Some(parent) = file_path.parent() {
353 fs::create_dir_all(parent)?;
354 }
355
356 tree.save_to_file(&file_path)?;
357 Ok(tree)
358 }
359
360 pub fn load_tree_session(&self, session_id: &str) -> Result<SessionTree, CliError> {
362 let file_path = self.tree_session_path(session_id);
363 SessionTree::load_from_file(&file_path).map_err(CliError::from)
364 }
365
366 pub fn append_tree_entry(
368 &self,
369 session_id: &str,
370 entry: &SessionEntry,
371 ) -> Result<(), CliError> {
372 let file_path = self.tree_session_path(session_id);
373
374 if !file_path.exists() {
375 return Err(CliError::ConfigError(
376 "Tree session file not found".to_string(),
377 ));
378 }
379
380 let mut file = OpenOptions::new()
382 .create(true)
383 .append(true)
384 .open(&file_path)?;
385 writeln!(
386 file,
387 "{}",
388 serde_json::to_string(entry).map_err(SessionTreeError::from)?
389 )?;
390 file.flush()?;
391 Ok(())
392 }
393
394 pub fn save_tree_session(&self, session_id: &str, tree: &SessionTree) -> Result<(), CliError> {
396 let file_path = self.tree_session_path(session_id);
397
398 if let Some(parent) = file_path.parent() {
399 fs::create_dir_all(parent)?;
400 }
401
402 tree.save_to_file(&file_path)?;
403 Ok(())
404 }
405
406 pub fn has_tree_session(&self, session_id: &str) -> Result<bool, CliError> {
408 let path = self.tree_session_path(session_id);
409 Ok(path.exists())
410 }
411
412 pub fn migrate_to_tree(&self, session_id: &str) -> Result<SessionTree, CliError> {
414 use crate::session_tree::{generate_entry_id, SerializableMessage, SessionEntryType};
415
416 if self.has_tree_session(session_id)? {
418 let tree = self.load_tree_session(session_id)?;
419 let messages = self.load_session(session_id)?;
420
421 if tree.entries().len() == messages.len() && !messages.is_empty() {
423 return Ok(tree);
424 }
425 }
427
428 let messages = self.load_session(session_id)?;
430
431 if messages.is_empty() {
432 let cwd = dirs::home_dir()
433 .map(|p| p.to_string_lossy().to_string())
434 .unwrap_or_else(|| "/".to_string());
435 return self.create_tree_session(session_id, cwd);
436 }
437
438 let cwd = dirs::home_dir()
440 .map(|p| p.to_string_lossy().to_string())
441 .unwrap_or_else(|| "/".to_string());
442
443 let _tree = self.create_tree_session(session_id, cwd)?;
445
446 let mut parent_id: Option<String> = None;
448 for msg in &messages {
449 let entry_id = generate_entry_id();
450 let entry = SessionEntry {
451 id: entry_id.clone(),
452 parent_id: parent_id.clone(),
453 timestamp: Utc::now().to_rfc3339(),
454 entry_type: SessionEntryType::Message {
455 message: SerializableMessage::from(msg.clone()),
456 },
457 };
458
459 self.append_tree_entry(session_id, &entry)?;
460 parent_id = Some(entry_id);
461 }
462
463 let tree = self.load_tree_session(session_id)?;
465
466 let bin_path = self.sessions_dir.join(format!("{}.bin", session_id));
468 if bin_path.exists() {
469 let _ = fs::remove_file(&bin_path);
470 }
471
472 Ok(tree)
473 }
474
475 fn tree_session_path(&self, session_id: &str) -> PathBuf {
477 self.sessions_dir.join(format!("{}.jsonl", session_id))
478 }
479}
480
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::Path;
    use std::thread;
    use std::time::Duration;
    use tempfile::tempdir;

    /// Builds a manager by hand under `root`: pre-creates `sessions/`, fills
    /// the fields directly and initialises the schema.
    fn manual_manager(root: &Path) -> SessionManager {
        let sessions_dir = root.join("sessions");
        fs::create_dir_all(&sessions_dir).unwrap();

        let manager = SessionManager {
            db_path: root.join("session.db"),
            sessions_dir,
        };
        manager.init_db().unwrap();
        manager
    }

    /// Builds a manager under `root` through the `with_paths` constructor.
    fn constructed_manager(root: &Path) -> SessionManager {
        SessionManager::with_paths(root.join("session.db"), root.join("sessions")).unwrap()
    }

    /// A plain text message with no tool metadata.
    fn text(role: limit_llm::Role, body: &str) -> Message {
        Message {
            role,
            content: Some(body.to_string()),
            tool_calls: None,
            tool_call_id: None,
            cache_control: None,
        }
    }

    #[test]
    fn test_session_manager_new() {
        let dir = tempdir().unwrap();
        let manager = manual_manager(dir.path());

        assert!(manager.db_path.exists());
    }

    #[test]
    fn test_create_new_session() {
        let dir = tempdir().unwrap();
        let manager = manual_manager(dir.path());

        let session_id = manager.create_new_session().unwrap();
        assert!(!session_id.is_empty());

        let sessions = manager.list_sessions().unwrap();
        assert_eq!(sessions.len(), 1);
        let only = &sessions[0];
        assert_eq!(only.id, session_id);
        assert_eq!(only.message_count, 0);
    }

    #[test]
    fn test_save_and_load_session() {
        let dir = tempdir().unwrap();
        let manager = manual_manager(dir.path());
        let session_id = manager.create_new_session().unwrap();

        let messages = [
            text(limit_llm::Role::User, "Hello"),
            text(limit_llm::Role::Assistant, "Hi there!"),
        ];
        manager.save_session(&session_id, &messages, 0, 0).unwrap();

        let loaded = manager.load_session(&session_id).unwrap();
        assert_eq!(loaded.len(), messages.len());
        for (got, expected) in loaded.iter().zip(messages.iter()) {
            assert_eq!(got.content, expected.content);
        }

        assert_eq!(manager.list_sessions().unwrap()[0].message_count, 2);
    }

    #[test]
    fn test_list_sessions() {
        let dir = tempdir().unwrap();
        let manager = manual_manager(dir.path());

        let first = manager.create_new_session().unwrap();
        thread::sleep(Duration::from_millis(10));
        let second = manager.create_new_session().unwrap();

        let ids: Vec<String> = manager
            .list_sessions()
            .unwrap()
            .into_iter()
            .map(|info| info.id)
            .collect();
        // Most recently accessed first.
        assert_eq!(ids, vec![second, first]);
    }

    #[test]
    fn test_get_last_session() {
        let dir = tempdir().unwrap();
        let manager = manual_manager(dir.path());

        assert!(manager.get_last_session().unwrap().is_none());

        let _first = manager.create_new_session().unwrap();
        thread::sleep(Duration::from_millis(10));
        let second = manager.create_new_session().unwrap();

        let last = manager.get_last_session().unwrap();
        assert_eq!(last.map(|info| info.id), Some(second));
    }

    #[test]
    fn test_session_persistence_across_restarts() {
        let dir = tempdir().unwrap();

        let session_id = {
            let first_run = manual_manager(dir.path());
            let id = first_run.create_new_session().unwrap();
            first_run
                .save_session(&id, &[text(limit_llm::Role::User, "Test message")], 0, 0)
                .unwrap();
            id
        }; // `first_run` is dropped here, simulating a process restart.

        let second_run = manual_manager(dir.path());
        let loaded = second_run.load_session(&session_id).unwrap();
        assert_eq!(loaded.len(), 1);
        assert_eq!(loaded[0].content.as_deref(), Some("Test message"));
    }

    #[test]
    fn test_tree_session_save_load() {
        use crate::session_tree::{generate_entry_id, SerializableMessage, SessionEntryType};

        let dir = tempdir().unwrap();
        let manager = constructed_manager(dir.path());
        let session_id = manager.create_new_session().unwrap();

        manager
            .create_tree_session(&session_id, "/test".to_string())
            .unwrap();

        let entry = SessionEntry {
            id: generate_entry_id(),
            parent_id: None,
            timestamp: Utc::now().to_rfc3339(),
            entry_type: SessionEntryType::Message {
                message: SerializableMessage::from(text(limit_llm::Role::User, "Hello")),
            },
        };
        manager.append_tree_entry(&session_id, &entry).unwrap();

        let reloaded = manager.load_tree_session(&session_id).unwrap();
        assert_eq!(reloaded.entries().len(), 1);
    }

    #[test]
    fn test_migrate_bin_to_jsonl() {
        let dir = tempdir().unwrap();
        let manager = constructed_manager(dir.path());
        let session_id = manager.create_new_session().unwrap();

        let history = [
            text(limit_llm::Role::User, "Hello"),
            text(limit_llm::Role::Assistant, "Hi!"),
        ];
        manager.save_session(&session_id, &history, 100, 50).unwrap();

        let tree = manager.migrate_to_tree(&session_id).unwrap();
        assert_eq!(tree.entries().len(), 2);

        let context = tree.build_context(tree.leaf_id()).unwrap();
        assert_eq!(context.len(), 2);
        assert_eq!(context[0].content, Some("Hello".to_string()));
        assert_eq!(context[1].content, Some("Hi!".to_string()));
    }
}