1use crate::error::CliError;
2use crate::session_tree::{SessionEntry, SessionTree, SessionTreeError};
3use chrono::{DateTime, Utc};
4use limit_llm::Message;
5use rusqlite::{params, Connection};
6use serde::{Deserialize, Serialize};
7use std::fs::{self, OpenOptions};
8use std::io::Write;
9use std::path::PathBuf;
10use tracing::instrument;
11use uuid::Uuid;
12
/// Format version written into every JSON session file by
/// [`SessionManager::save_session`]. [`SessionManager::load_session`] rejects
/// files whose version is greater than this value; equal or older versions
/// are accepted.
const CURRENT_VERSION: u32 = 2;
14
/// Metadata for one session, as stored in the `sessions` table of the
/// SQLite database.
#[derive(Debug, Clone)]
pub struct SessionInfo {
    /// Session identifier (a v4 UUID string when created by
    /// `SessionManager::create_new_session`).
    pub id: String,
    /// When the session row was first inserted.
    // NOTE(review): `dead_code` is allowed here presumably because nothing in
    // the crate reads this field yet — confirm before removing the attribute.
    #[allow(dead_code)]
    pub created_at: DateTime<Utc>,
    /// Last time the session was saved or loaded.
    // NOTE(review): same `dead_code` rationale as `created_at`.
    #[allow(dead_code)]
    pub last_accessed: DateTime<Utc>,
    /// Number of messages recorded at the last save.
    pub message_count: usize,
    /// Input-token total recorded for the session.
    pub total_input_tokens: u64,
    /// Output-token total recorded for the session.
    pub total_output_tokens: u64,
}
26
/// Top-level shape of a `<session_id>.json` file.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct PersistedState {
    /// Format version; compared against [`CURRENT_VERSION`] on load.
    version: u32,
    /// The conversation, in order.
    messages: Vec<PersistedMessage>,
}
32
/// On-disk form of a `limit_llm::Message`, mirroring it field-for-field (see
/// the `From` impls in this file). Having a local type keeps the file format
/// under this module's control rather than `limit_llm`'s serde derives.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct PersistedMessage {
    role: PersistedRole,
    content: Option<String>,
    tool_calls: Option<Vec<limit_llm::ToolCall>>,
    tool_call_id: Option<String>,
    cache_control: Option<limit_llm::CacheControl>,
}
41
/// On-disk form of `limit_llm::Role`.
///
/// With serde's default representation each unit variant is written as its
/// bare name (e.g. `"User"`), so renaming a variant changes the file format
/// and breaks loading of existing sessions.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
enum PersistedRole {
    User,
    Assistant,
    System,
    Tool,
}
49
50impl From<PersistedRole> for limit_llm::Role {
51 fn from(role: PersistedRole) -> Self {
52 match role {
53 PersistedRole::User => limit_llm::Role::User,
54 PersistedRole::Assistant => limit_llm::Role::Assistant,
55 PersistedRole::System => limit_llm::Role::System,
56 PersistedRole::Tool => limit_llm::Role::Tool,
57 }
58 }
59}
60
61impl From<limit_llm::Role> for PersistedRole {
62 fn from(role: limit_llm::Role) -> Self {
63 match role {
64 limit_llm::Role::User => PersistedRole::User,
65 limit_llm::Role::Assistant => PersistedRole::Assistant,
66 limit_llm::Role::System => PersistedRole::System,
67 limit_llm::Role::Tool => PersistedRole::Tool,
68 }
69 }
70}
71
72impl From<PersistedMessage> for Message {
73 fn from(msg: PersistedMessage) -> Self {
74 Message {
75 role: msg.role.into(),
76 content: msg.content,
77 tool_calls: msg.tool_calls,
78 tool_call_id: msg.tool_call_id,
79 cache_control: msg.cache_control,
80 }
81 }
82}
83
84impl From<Message> for PersistedMessage {
85 fn from(msg: Message) -> Self {
86 PersistedMessage {
87 role: msg.role.into(),
88 content: msg.content,
89 tool_calls: msg.tool_calls,
90 tool_call_id: msg.tool_call_id,
91 cache_control: msg.cache_control,
92 }
93 }
94}
95
/// Persists conversation sessions: metadata rows in a SQLite database plus one
/// file per session (`<id>.json` snapshots and `<id>.jsonl` session trees) in
/// a sessions directory.
pub struct SessionManager {
    /// Path of the SQLite database holding the `sessions` table.
    db_path: PathBuf,
    /// Directory containing the per-session `.json` / `.jsonl` files.
    sessions_dir: PathBuf,
}
100
101impl SessionManager {
102 pub fn new() -> Result<Self, CliError> {
103 let home_dir = dirs::home_dir()
105 .ok_or_else(|| CliError::ConfigError("Failed to get home directory".to_string()))?;
106 let limit_dir = home_dir.join(".limit");
107 fs::create_dir_all(&limit_dir).map_err(|e| {
108 CliError::ConfigError(format!("Failed to create .limit directory: {}", e))
109 })?;
110
111 let sessions_dir = limit_dir.join("sessions");
112 fs::create_dir_all(&sessions_dir).map_err(|e| {
113 CliError::ConfigError(format!("Failed to create sessions directory: {}", e))
114 })?;
115
116 let db_path = limit_dir.join("session.db");
117
118 Self::with_paths(db_path, sessions_dir)
119 }
120
121 pub fn with_paths(db_path: PathBuf, sessions_dir: PathBuf) -> Result<Self, CliError> {
123 fs::create_dir_all(&sessions_dir).map_err(|e| {
124 CliError::ConfigError(format!("Failed to create sessions directory: {}", e))
125 })?;
126
127 let session_manager = Self {
128 db_path,
129 sessions_dir,
130 };
131
132 session_manager.init_db()?;
133 Ok(session_manager)
134 }
135
136 pub fn init_db(&self) -> Result<(), CliError> {
137 let conn = Connection::open(&self.db_path)
138 .map_err(|e| CliError::ConfigError(format!("Failed to open database: {}", e)))?;
139
140 conn.execute(
141 "CREATE TABLE IF NOT EXISTS sessions (
142 id TEXT PRIMARY KEY,
143 created_at TEXT NOT NULL,
144 last_accessed TEXT NOT NULL,
145 message_count INTEGER NOT NULL,
146 total_input_tokens INTEGER NOT NULL DEFAULT 0,
147 total_output_tokens INTEGER NOT NULL DEFAULT 0
148 )",
149 [],
150 )
151 .map_err(|e| CliError::ConfigError(format!("Failed to create sessions table: {}", e)))?;
152
153 conn.execute(
154 "CREATE INDEX IF NOT EXISTS idx_last_accessed ON sessions(last_accessed DESC)",
155 [],
156 )
157 .map_err(|e| CliError::ConfigError(format!("Failed to create index: {}", e)))?;
158
159 Ok(())
160 }
161
162 fn get_connection(&self) -> Result<Connection, CliError> {
163 Connection::open(&self.db_path)
164 .map_err(|e| CliError::ConfigError(format!("Failed to open database: {}", e)))
165 }
166
167 pub fn create_new_session(&self) -> Result<String, CliError> {
168 let session_id = Uuid::new_v4().to_string();
169 let now = Utc::now().to_rfc3339();
170
171 let conn = self.get_connection()?;
172 conn.execute(
173 "INSERT INTO sessions (id, created_at, last_accessed, message_count, total_input_tokens, total_output_tokens) VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
174 params![&session_id, &now, &now, 0, 0, 0],
175 )
176 .map_err(|e| CliError::ConfigError(format!("Failed to create session: {}", e)))?;
177
178 Ok(session_id)
179 }
180
181 #[instrument(skip(self, messages))]
182 pub fn save_session(
183 &self,
184 session_id: &str,
185 messages: &[Message],
186 total_input_tokens: u64,
187 total_output_tokens: u64,
188 ) -> Result<(), CliError> {
189 let file_path = self.sessions_dir.join(format!("{}.json", session_id));
190
191 fs::create_dir_all(&self.sessions_dir).map_err(|e| {
192 CliError::ConfigError(format!("Failed to create sessions directory: {}", e))
193 })?;
194
195 let persisted_messages: Vec<PersistedMessage> =
196 messages.iter().cloned().map(|m| m.into()).collect();
197
198 let state = PersistedState {
199 version: CURRENT_VERSION,
200 messages: persisted_messages,
201 };
202
203 let serialized = serde_json::to_string_pretty(&state)
204 .map_err(|e| CliError::ConfigError(format!("Failed to serialize messages: {}", e)))?;
205
206 fs::write(&file_path, serialized)
207 .map_err(|e| CliError::ConfigError(format!("Failed to write session file: {}", e)))?;
208
209 let now = Utc::now().to_rfc3339();
210 let conn = self.get_connection()?;
211 conn.execute(
212 "UPDATE sessions SET last_accessed = ?1, message_count = ?2, total_input_tokens = ?3, total_output_tokens = ?4 WHERE id = ?5",
213 params![&now, messages.len() as i64, total_input_tokens as i64, total_output_tokens as i64, session_id],
214 )
215 .map_err(|e| CliError::ConfigError(format!("Failed to update session metadata: {}", e)))?;
216
217 Ok(())
218 }
219
220 #[instrument(skip(self))]
221 pub fn load_session(&self, session_id: &str) -> Result<Vec<Message>, CliError> {
222 let file_path = self.sessions_dir.join(format!("{}.json", session_id));
223
224 let data = match fs::read_to_string(&file_path) {
225 Ok(data) => data,
226 Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
227 return Err(CliError::ConfigError(format!(
228 "Session file not found: {}",
229 session_id
230 )))
231 }
232 Err(e) => {
233 return Err(CliError::ConfigError(format!(
234 "Failed to read session file: {}",
235 e
236 )))
237 }
238 };
239
240 let state: PersistedState = serde_json::from_str(&data)
241 .map_err(|e| CliError::ConfigError(format!("Failed to deserialize messages: {}", e)))?;
242
243 if state.version > CURRENT_VERSION {
245 return Err(CliError::ConfigError(format!(
246 "Version mismatch: expected {}, found {}",
247 CURRENT_VERSION, state.version
248 )));
249 }
250
251 let now = Utc::now().to_rfc3339();
252 let conn = self.get_connection()?;
253 conn.execute(
254 "UPDATE sessions SET last_accessed = ?1 WHERE id = ?2",
255 params![&now, session_id],
256 )
257 .map_err(|e| CliError::ConfigError(format!("Failed to update last_accessed: {}", e)))?;
258
259 let messages: Vec<Message> = state
261 .messages
262 .into_iter()
263 .map(Message::from)
264 .filter(|m| m.role != limit_llm::Role::System)
265 .collect();
266
267 Ok(messages)
268 }
269
270 pub fn list_sessions(&self) -> Result<Vec<SessionInfo>, CliError> {
271 let conn = self.get_connection()?;
272
273 let mut stmt = conn
274 .prepare("SELECT id, created_at, last_accessed, message_count, total_input_tokens, total_output_tokens FROM sessions ORDER BY last_accessed DESC")
275 .map_err(|e| CliError::ConfigError(format!("Failed to prepare query: {}", e)))?;
276
277 let session_iter = stmt
278 .query_map([], |row| {
279 Ok(SessionInfo {
280 id: row.get(0)?,
281 created_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(1)?)
282 .unwrap()
283 .with_timezone(&Utc),
284 last_accessed: DateTime::parse_from_rfc3339(&row.get::<_, String>(2)?)
285 .unwrap()
286 .with_timezone(&Utc),
287 message_count: row.get::<_, i64>(3)? as usize,
288 total_input_tokens: row.get::<_, i64>(4)? as u64,
289 total_output_tokens: row.get::<_, i64>(5)? as u64,
290 })
291 })
292 .map_err(|e| CliError::ConfigError(format!("Failed to query sessions: {}", e)))?;
293
294 let mut sessions = Vec::new();
295 for session in session_iter {
296 sessions.push(
297 session.map_err(|e| {
298 CliError::ConfigError(format!("Failed to parse session: {}", e))
299 })?,
300 );
301 }
302
303 Ok(sessions)
304 }
305
306 pub fn get_last_session(&self) -> Result<Option<SessionInfo>, CliError> {
307 let conn = self.get_connection()?;
308
309 let mut stmt = conn
310 .prepare("SELECT id, created_at, last_accessed, message_count, total_input_tokens, total_output_tokens FROM sessions ORDER BY last_accessed DESC LIMIT 1")
311 .map_err(|e| CliError::ConfigError(format!("Failed to prepare query: {}", e)))?;
312
313 let mut session_iter = stmt
314 .query_map([], |row| {
315 Ok(SessionInfo {
316 id: row.get(0)?,
317 created_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(1)?)
318 .unwrap()
319 .with_timezone(&Utc),
320 last_accessed: DateTime::parse_from_rfc3339(&row.get::<_, String>(2)?)
321 .unwrap()
322 .with_timezone(&Utc),
323 message_count: row.get::<_, i64>(3)? as usize,
324 total_input_tokens: row.get::<_, i64>(4)? as u64,
325 total_output_tokens: row.get::<_, i64>(5)? as u64,
326 })
327 })
328 .map_err(|e| CliError::ConfigError(format!("Failed to query last session: {}", e)))?;
329
330 match session_iter.next() {
331 Some(session) => Ok(Some(session.map_err(|e| {
332 CliError::ConfigError(format!("Failed to parse session: {}", e))
333 })?)),
334 None => Ok(None),
335 }
336 }
337
338 #[allow(dead_code)]
340 pub fn update_session_tokens(
341 &self,
342 session_id: &str,
343 input_tokens: u64,
344 output_tokens: u64,
345 ) -> Result<(), CliError> {
346 let conn = self.get_connection()?;
347 conn.execute(
348 "UPDATE sessions SET total_input_tokens = total_input_tokens + ?1, total_output_tokens = total_output_tokens + ?2 WHERE id = ?3",
349 params![input_tokens as i64, output_tokens as i64, session_id],
350 )
351 .map_err(|e| CliError::ConfigError(format!("Failed to update session tokens: {}", e)))?;
352 Ok(())
353 }
354
355 pub fn create_tree_session(
357 &self,
358 session_id: &str,
359 cwd: String,
360 ) -> Result<SessionTree, CliError> {
361 let tree = SessionTree::new(cwd);
362 let file_path = self.tree_session_path(session_id);
363
364 if let Some(parent) = file_path.parent() {
365 fs::create_dir_all(parent)?;
366 }
367
368 tree.save_to_file(&file_path)?;
369 Ok(tree)
370 }
371
372 pub fn load_tree_session(&self, session_id: &str) -> Result<SessionTree, CliError> {
374 let file_path = self.tree_session_path(session_id);
375 SessionTree::load_from_file(&file_path).map_err(CliError::from)
376 }
377
378 pub fn append_tree_entry(
380 &self,
381 session_id: &str,
382 entry: &SessionEntry,
383 ) -> Result<(), CliError> {
384 let file_path = self.tree_session_path(session_id);
385
386 if !file_path.exists() {
387 return Err(CliError::ConfigError(
388 "Tree session file not found".to_string(),
389 ));
390 }
391
392 let mut file = OpenOptions::new()
394 .create(true)
395 .append(true)
396 .open(&file_path)?;
397 writeln!(
398 file,
399 "{}",
400 serde_json::to_string(entry).map_err(SessionTreeError::from)?
401 )?;
402 file.flush()?;
403 Ok(())
404 }
405
406 pub fn save_tree_session(&self, session_id: &str, tree: &SessionTree) -> Result<(), CliError> {
408 let file_path = self.tree_session_path(session_id);
409
410 if let Some(parent) = file_path.parent() {
411 fs::create_dir_all(parent)?;
412 }
413
414 tree.save_to_file(&file_path)?;
415 Ok(())
416 }
417
418 pub fn has_tree_session(&self, session_id: &str) -> Result<bool, CliError> {
420 let path = self.tree_session_path(session_id);
421 Ok(path.exists())
422 }
423
424 pub fn migrate_to_tree(&self, session_id: &str) -> Result<SessionTree, CliError> {
426 use crate::session_tree::{generate_entry_id, SerializableMessage, SessionEntryType};
427
428 if self.has_tree_session(session_id)? {
430 let tree = self.load_tree_session(session_id)?;
431 let messages = self.load_session(session_id)?;
432
433 if tree.entries().len() == messages.len() && !messages.is_empty() {
435 return Ok(tree);
436 }
437 }
439
440 let messages = self.load_session(session_id)?;
442
443 if messages.is_empty() {
444 let cwd = dirs::home_dir()
445 .map(|p| p.to_string_lossy().to_string())
446 .unwrap_or_else(|| "/".to_string());
447 return self.create_tree_session(session_id, cwd);
448 }
449
450 let cwd = dirs::home_dir()
452 .map(|p| p.to_string_lossy().to_string())
453 .unwrap_or_else(|| "/".to_string());
454
455 let _tree = self.create_tree_session(session_id, cwd)?;
457
458 let mut parent_id: Option<String> = None;
460 for msg in &messages {
461 let entry_id = generate_entry_id();
462 let entry = SessionEntry {
463 id: entry_id.clone(),
464 parent_id: parent_id.clone(),
465 timestamp: Utc::now().to_rfc3339(),
466 entry_type: SessionEntryType::Message {
467 message: SerializableMessage::from(msg.clone()),
468 },
469 };
470
471 self.append_tree_entry(session_id, &entry)?;
472 parent_id = Some(entry_id);
473 }
474
475 let tree = self.load_tree_session(session_id)?;
477
478 let json_path = self.sessions_dir.join(format!("{}.json", session_id));
480 if json_path.exists() {
481 let _ = fs::remove_file(&json_path);
482 }
483
484 Ok(tree)
485 }
486
487 fn tree_session_path(&self, session_id: &str) -> PathBuf {
489 self.sessions_dir.join(format!("{}.jsonl", session_id))
490 }
491}
492
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::{tempdir, TempDir};

    /// Allocates a temporary directory and returns it with the database path
    /// and sessions directory inside it. Keep the returned `TempDir` bound for
    /// the whole test: dropping it deletes the directory.
    fn temp_layout() -> (TempDir, PathBuf, PathBuf) {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("session.db");
        let sessions_dir = dir.path().join("sessions");
        (dir, db_path, sessions_dir)
    }

    /// Builds a manager directly from its fields (bypassing `with_paths`):
    /// creates the sessions directory, then initialises the database.
    fn manager_at(db_path: PathBuf, sessions_dir: PathBuf) -> SessionManager {
        fs::create_dir_all(&sessions_dir).unwrap();
        let manager = SessionManager {
            db_path,
            sessions_dir,
        };
        manager.init_db().unwrap();
        manager
    }

    /// A plain text message with no tool data and no cache control.
    fn text_message(role: limit_llm::Role, text: &str) -> Message {
        Message {
            role,
            content: Some(text.to_string()),
            tool_calls: None,
            tool_call_id: None,
            cache_control: None,
        }
    }

    #[test]
    fn test_session_manager_new() {
        let (_dir, db_path, sessions_dir) = temp_layout();
        let manager = manager_at(db_path, sessions_dir);

        assert!(manager.db_path.exists());
    }

    #[test]
    fn test_create_new_session() {
        let (_dir, db_path, sessions_dir) = temp_layout();
        let manager = manager_at(db_path, sessions_dir);

        let session_id = manager.create_new_session().unwrap();
        assert!(!session_id.is_empty());

        let sessions = manager.list_sessions().unwrap();
        assert_eq!(sessions.len(), 1);
        assert_eq!(sessions[0].id, session_id);
        assert_eq!(sessions[0].message_count, 0);
    }

    #[test]
    fn test_save_and_load_session() {
        let (_dir, db_path, sessions_dir) = temp_layout();
        let manager = manager_at(db_path, sessions_dir);
        let session_id = manager.create_new_session().unwrap();

        let messages = vec![
            text_message(limit_llm::Role::User, "Hello"),
            text_message(limit_llm::Role::Assistant, "Hi there!"),
        ];
        manager.save_session(&session_id, &messages, 0, 0).unwrap();

        let loaded = manager.load_session(&session_id).unwrap();
        assert_eq!(loaded.len(), messages.len());
        for (got, expected) in loaded.iter().zip(&messages) {
            assert_eq!(got.content, expected.content);
        }

        let sessions = manager.list_sessions().unwrap();
        assert_eq!(sessions[0].message_count, 2);
    }

    #[test]
    fn test_list_sessions() {
        let (_dir, db_path, sessions_dir) = temp_layout();
        let manager = manager_at(db_path, sessions_dir);

        let older = manager.create_new_session().unwrap();
        // Ensure the two sessions get distinct `last_accessed` timestamps.
        std::thread::sleep(std::time::Duration::from_millis(10));
        let newer = manager.create_new_session().unwrap();

        let ids: Vec<String> = manager
            .list_sessions()
            .unwrap()
            .into_iter()
            .map(|s| s.id)
            .collect();
        assert_eq!(ids, vec![newer, older]);
    }

    #[test]
    fn test_get_last_session() {
        let (_dir, db_path, sessions_dir) = temp_layout();
        let manager = manager_at(db_path, sessions_dir);

        assert!(manager.get_last_session().unwrap().is_none());

        let _older = manager.create_new_session().unwrap();
        std::thread::sleep(std::time::Duration::from_millis(10));
        let newer = manager.create_new_session().unwrap();

        let last = manager.get_last_session().unwrap();
        assert!(last.is_some());
        assert_eq!(last.unwrap().id, newer);
    }

    #[test]
    fn test_session_persistence_across_restarts() {
        let (_dir, db_path, sessions_dir) = temp_layout();

        let first = manager_at(db_path.clone(), sessions_dir.clone());
        let session_id = first.create_new_session().unwrap();
        first
            .save_session(
                &session_id,
                &[text_message(limit_llm::Role::User, "Test message")],
                0,
                0,
            )
            .unwrap();
        drop(first);

        // A second manager over the same paths simulates a process restart.
        let second = manager_at(db_path, sessions_dir);
        let loaded = second.load_session(&session_id).unwrap();
        assert_eq!(loaded.len(), 1);
        assert_eq!(loaded[0].content, Some("Test message".to_string()));
    }

    #[test]
    fn test_tree_session_save_load() {
        use crate::session_tree::{generate_entry_id, SerializableMessage, SessionEntryType};

        let (_dir, db_path, sessions_dir) = temp_layout();
        let manager = SessionManager::with_paths(db_path, sessions_dir).unwrap();
        let session_id = manager.create_new_session().unwrap();

        manager
            .create_tree_session(&session_id, "/test".to_string())
            .unwrap();

        let entry = SessionEntry {
            id: generate_entry_id(),
            parent_id: None,
            timestamp: Utc::now().to_rfc3339(),
            entry_type: SessionEntryType::Message {
                message: SerializableMessage::from(text_message(limit_llm::Role::User, "Hello")),
            },
        };
        manager.append_tree_entry(&session_id, &entry).unwrap();

        let loaded = manager.load_tree_session(&session_id).unwrap();
        assert_eq!(loaded.entries().len(), 1);
    }

    #[test]
    fn test_migrate_json_to_jsonl() {
        let (_dir, db_path, sessions_dir) = temp_layout();
        let manager = SessionManager::with_paths(db_path, sessions_dir).unwrap();
        let session_id = manager.create_new_session().unwrap();

        let messages = vec![
            text_message(limit_llm::Role::User, "Hello"),
            text_message(limit_llm::Role::Assistant, "Hi!"),
        ];
        manager
            .save_session(&session_id, &messages, 100, 50)
            .unwrap();

        let tree = manager.migrate_to_tree(&session_id).unwrap();
        assert_eq!(tree.entries().len(), 2);

        let context = tree.build_context(tree.leaf_id()).unwrap();
        assert_eq!(context.len(), 2);
        assert_eq!(context[0].content, Some("Hello".to_string()));
        assert_eq!(context[1].content, Some("Hi!".to_string()));
    }
}