1use crate::error::CliError;
2use bincode::{deserialize, serialize};
3use chrono::{DateTime, Utc};
4use limit_llm::Message;
5use rusqlite::{params, Connection};
6use serde::{Deserialize, Serialize};
7use std::fs;
8use std::path::PathBuf;
9use tracing::instrument;
10use uuid::Uuid;
11
12const CURRENT_VERSION: u32 = 2;
13
14#[derive(Debug, Clone)]
15pub struct SessionInfo {
16 pub id: String,
17 #[allow(dead_code)]
18 pub created_at: DateTime<Utc>,
19 #[allow(dead_code)]
20 pub last_accessed: DateTime<Utc>,
21 pub message_count: usize,
22 pub total_input_tokens: u64,
23 pub total_output_tokens: u64,
24}
25
/// On-disk envelope for one session's message history. `save_session`
/// bincode-encodes it into `<sessions_dir>/<id>.bin`, and `load_session`
/// decodes it from there.
///
/// FORMAT INVARIANT: bincode is not self-describing. Fields are written
/// positionally, in declaration order, with no names or tags. Any of the
/// following changes the file format and breaks previously saved sessions:
/// reordering fields, changing a field's type, or adding or removing a field.
/// Treat such a change as a format bump (`CURRENT_VERSION`).
#[derive(Debug, Clone, Serialize, Deserialize)]
struct PersistedState {
    /// Format version written by `save_session`. `load_session` rejects values
    /// newer than `CURRENT_VERSION`.
    version: u32,
    /// Messages in the order they were passed to `save_session`, including any
    /// system messages.
    messages: Vec<PersistedMessage>,
}
31
/// Serializable mirror of `limit_llm::Message`, decoupling the on-disk format
/// from the in-memory type.
///
/// FORMAT INVARIANT: like `PersistedState`, this is bincode-encoded
/// positionally. Field order and types are part of the file format.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct PersistedMessage {
    role: PersistedRole,
    content: Option<String>,
    // Embeds `limit_llm::ToolCall` directly, so that type's serde
    // representation is also part of this file format.
    // NOTE(review): consider a local mirror type if that crate evolves
    // independently of this one.
    tool_calls: Option<Vec<limit_llm::ToolCall>>,
    tool_call_id: Option<String>,
}
39
/// Serializable mirror of `limit_llm::Role`.
///
/// FORMAT INVARIANT: serde/bincode encode a unit variant by its declaration
/// index, not its name. Changing the order of these variants silently remaps
/// roles in existing session files. Only append new variants at the end;
/// never reorder or remove them.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
enum PersistedRole {
    User,
    Assistant,
    System,
    Tool,
}
47
48impl From<PersistedRole> for limit_llm::Role {
49 fn from(role: PersistedRole) -> Self {
50 match role {
51 PersistedRole::User => limit_llm::Role::User,
52 PersistedRole::Assistant => limit_llm::Role::Assistant,
53 PersistedRole::System => limit_llm::Role::System,
54 PersistedRole::Tool => limit_llm::Role::Tool,
55 }
56 }
57}
58
59impl From<limit_llm::Role> for PersistedRole {
60 fn from(role: limit_llm::Role) -> Self {
61 match role {
62 limit_llm::Role::User => PersistedRole::User,
63 limit_llm::Role::Assistant => PersistedRole::Assistant,
64 limit_llm::Role::System => PersistedRole::System,
65 limit_llm::Role::Tool => PersistedRole::Tool,
66 }
67 }
68}
69
70impl From<PersistedMessage> for Message {
71 fn from(msg: PersistedMessage) -> Self {
72 Message {
73 role: msg.role.into(),
74 content: msg.content,
75 tool_calls: msg.tool_calls,
76 tool_call_id: msg.tool_call_id,
77 }
78 }
79}
80
81impl From<Message> for PersistedMessage {
82 fn from(msg: Message) -> Self {
83 PersistedMessage {
84 role: msg.role.into(),
85 content: msg.content,
86 tool_calls: msg.tool_calls,
87 tool_call_id: msg.tool_call_id,
88 }
89 }
90}
91
/// Persists chat sessions in two places: a metadata row per session in a
/// SQLite database, and each session's message history in a bincode file.
pub struct SessionManager {
    /// Directory holding one `<session id>.bin` file per saved session.
    sessions_dir: PathBuf,
    /// SQLite database file containing the `sessions` metadata table.
    db_path: PathBuf,
}
96
97impl SessionManager {
98 pub fn new() -> Result<Self, CliError> {
99 let home_dir = dirs::home_dir()
101 .ok_or_else(|| CliError::ConfigError("Failed to get home directory".to_string()))?;
102 let limit_dir = home_dir.join(".limit");
103 fs::create_dir_all(&limit_dir).map_err(|e| {
104 CliError::ConfigError(format!("Failed to create .limit directory: {}", e))
105 })?;
106
107 let sessions_dir = limit_dir.join("sessions");
108 fs::create_dir_all(&sessions_dir).map_err(|e| {
109 CliError::ConfigError(format!("Failed to create sessions directory: {}", e))
110 })?;
111
112 let db_path = limit_dir.join("session.db");
113
114 Self::with_paths(db_path, sessions_dir)
115 }
116
117 pub fn with_paths(db_path: PathBuf, sessions_dir: PathBuf) -> Result<Self, CliError> {
119 fs::create_dir_all(&sessions_dir).map_err(|e| {
120 CliError::ConfigError(format!("Failed to create sessions directory: {}", e))
121 })?;
122
123 let session_manager = Self {
124 db_path,
125 sessions_dir,
126 };
127
128 session_manager.init_db()?;
129 Ok(session_manager)
130 }
131
132 pub fn init_db(&self) -> Result<(), CliError> {
133 let conn = Connection::open(&self.db_path)
134 .map_err(|e| CliError::ConfigError(format!("Failed to open database: {}", e)))?;
135
136 conn.execute(
137 "CREATE TABLE IF NOT EXISTS sessions (
138 id TEXT PRIMARY KEY,
139 created_at TEXT NOT NULL,
140 last_accessed TEXT NOT NULL,
141 message_count INTEGER NOT NULL,
142 total_input_tokens INTEGER NOT NULL DEFAULT 0,
143 total_output_tokens INTEGER NOT NULL DEFAULT 0
144 )",
145 [],
146 )
147 .map_err(|e| CliError::ConfigError(format!("Failed to create sessions table: {}", e)))?;
148
149 conn.execute(
150 "CREATE INDEX IF NOT EXISTS idx_last_accessed ON sessions(last_accessed DESC)",
151 [],
152 )
153 .map_err(|e| CliError::ConfigError(format!("Failed to create index: {}", e)))?;
154
155 Ok(())
156 }
157
158 fn get_connection(&self) -> Result<Connection, CliError> {
159 Connection::open(&self.db_path)
160 .map_err(|e| CliError::ConfigError(format!("Failed to open database: {}", e)))
161 }
162
163 pub fn create_new_session(&self) -> Result<String, CliError> {
164 let session_id = Uuid::new_v4().to_string();
165 let now = Utc::now().to_rfc3339();
166
167 let conn = self.get_connection()?;
168 conn.execute(
169 "INSERT INTO sessions (id, created_at, last_accessed, message_count, total_input_tokens, total_output_tokens) VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
170 params![&session_id, &now, &now, 0, 0, 0],
171 )
172 .map_err(|e| CliError::ConfigError(format!("Failed to create session: {}", e)))?;
173
174 Ok(session_id)
175 }
176
177 #[instrument(skip(self, messages))]
178 pub fn save_session(
179 &self,
180 session_id: &str,
181 messages: &[Message],
182 total_input_tokens: u64,
183 total_output_tokens: u64,
184 ) -> Result<(), CliError> {
185 let file_path = self.sessions_dir.join(format!("{}.bin", session_id));
186
187 fs::create_dir_all(&self.sessions_dir).map_err(|e| {
188 CliError::ConfigError(format!("Failed to create sessions directory: {}", e))
189 })?;
190
191 let persisted_messages: Vec<PersistedMessage> =
192 messages.iter().cloned().map(|m| m.into()).collect();
193
194 let state = PersistedState {
195 version: CURRENT_VERSION,
196 messages: persisted_messages,
197 };
198
199 let serialized = serialize(&state)
200 .map_err(|e| CliError::ConfigError(format!("Failed to serialize messages: {}", e)))?;
201
202 fs::write(&file_path, serialized)
203 .map_err(|e| CliError::ConfigError(format!("Failed to write session file: {}", e)))?;
204
205 let now = Utc::now().to_rfc3339();
206 let conn = self.get_connection()?;
207 conn.execute(
208 "UPDATE sessions SET last_accessed = ?1, message_count = ?2, total_input_tokens = ?3, total_output_tokens = ?4 WHERE id = ?5",
209 params![&now, messages.len() as i64, total_input_tokens as i64, total_output_tokens as i64, session_id],
210 )
211 .map_err(|e| CliError::ConfigError(format!("Failed to update session metadata: {}", e)))?;
212
213 Ok(())
214 }
215
216 #[instrument(skip(self))]
217 pub fn load_session(&self, session_id: &str) -> Result<Vec<Message>, CliError> {
218 let file_path = self.sessions_dir.join(format!("{}.bin", session_id));
219
220 let data = fs::read(&file_path)
221 .map_err(|e| CliError::ConfigError(format!("Failed to read session file: {}", e)))?;
222
223 let state: PersistedState = deserialize(&data)
224 .map_err(|e| CliError::ConfigError(format!("Failed to deserialize messages: {}", e)))?;
225
226 if state.version > CURRENT_VERSION {
228 return Err(CliError::ConfigError(format!(
229 "Version mismatch: expected {}, found {}",
230 CURRENT_VERSION, state.version
231 )));
232 }
233
234 let now = Utc::now().to_rfc3339();
235 let conn = self.get_connection()?;
236 conn.execute(
237 "UPDATE sessions SET last_accessed = ?1 WHERE id = ?2",
238 params![&now, session_id],
239 )
240 .map_err(|e| CliError::ConfigError(format!("Failed to update last_accessed: {}", e)))?;
241
242 let messages: Vec<Message> = state
244 .messages
245 .into_iter()
246 .map(Message::from)
247 .filter(|m| m.role != limit_llm::Role::System)
248 .collect();
249
250 Ok(messages)
251 }
252
253 pub fn list_sessions(&self) -> Result<Vec<SessionInfo>, CliError> {
254 let conn = self.get_connection()?;
255
256 let mut stmt = conn
257 .prepare("SELECT id, created_at, last_accessed, message_count, total_input_tokens, total_output_tokens FROM sessions ORDER BY last_accessed DESC")
258 .map_err(|e| CliError::ConfigError(format!("Failed to prepare query: {}", e)))?;
259
260 let session_iter = stmt
261 .query_map([], |row| {
262 Ok(SessionInfo {
263 id: row.get(0)?,
264 created_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(1)?)
265 .unwrap()
266 .with_timezone(&Utc),
267 last_accessed: DateTime::parse_from_rfc3339(&row.get::<_, String>(2)?)
268 .unwrap()
269 .with_timezone(&Utc),
270 message_count: row.get::<_, i64>(3)? as usize,
271 total_input_tokens: row.get::<_, i64>(4)? as u64,
272 total_output_tokens: row.get::<_, i64>(5)? as u64,
273 })
274 })
275 .map_err(|e| CliError::ConfigError(format!("Failed to query sessions: {}", e)))?;
276
277 let mut sessions = Vec::new();
278 for session in session_iter {
279 sessions.push(
280 session.map_err(|e| {
281 CliError::ConfigError(format!("Failed to parse session: {}", e))
282 })?,
283 );
284 }
285
286 Ok(sessions)
287 }
288
289 pub fn get_last_session(&self) -> Result<Option<SessionInfo>, CliError> {
290 let conn = self.get_connection()?;
291
292 let mut stmt = conn
293 .prepare("SELECT id, created_at, last_accessed, message_count, total_input_tokens, total_output_tokens FROM sessions ORDER BY last_accessed DESC LIMIT 1")
294 .map_err(|e| CliError::ConfigError(format!("Failed to prepare query: {}", e)))?;
295
296 let mut session_iter = stmt
297 .query_map([], |row| {
298 Ok(SessionInfo {
299 id: row.get(0)?,
300 created_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(1)?)
301 .unwrap()
302 .with_timezone(&Utc),
303 last_accessed: DateTime::parse_from_rfc3339(&row.get::<_, String>(2)?)
304 .unwrap()
305 .with_timezone(&Utc),
306 message_count: row.get::<_, i64>(3)? as usize,
307 total_input_tokens: row.get::<_, i64>(4)? as u64,
308 total_output_tokens: row.get::<_, i64>(5)? as u64,
309 })
310 })
311 .map_err(|e| CliError::ConfigError(format!("Failed to query last session: {}", e)))?;
312
313 match session_iter.next() {
314 Some(session) => Ok(Some(session.map_err(|e| {
315 CliError::ConfigError(format!("Failed to parse session: {}", e))
316 })?)),
317 None => Ok(None),
318 }
319 }
320
321 #[allow(dead_code)]
323 pub fn update_session_tokens(
324 &self,
325 session_id: &str,
326 input_tokens: u64,
327 output_tokens: u64,
328 ) -> Result<(), CliError> {
329 let conn = self.get_connection()?;
330 conn.execute(
331 "UPDATE sessions SET total_input_tokens = total_input_tokens + ?1, total_output_tokens = total_output_tokens + ?2 WHERE id = ?3",
332 params![input_tokens as i64, output_tokens as i64, session_id],
333 )
334 .map_err(|e| CliError::ConfigError(format!("Failed to update session tokens: {}", e)))?;
335 Ok(())
336 }
337}
338
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::{tempdir, TempDir};

    /// Builds a `SessionManager` over a fresh temporary directory and
    /// initialises its database.
    ///
    /// The `TempDir` guard is returned so the caller keeps it alive. Dropping
    /// it deletes the directory, including the database, while the manager
    /// is still in use.
    fn temp_manager() -> (TempDir, SessionManager) {
        let dir = tempdir().unwrap();
        let sessions_dir = dir.path().join("sessions");
        let db_path = dir.path().join("session.db");

        fs::create_dir_all(&sessions_dir).unwrap();

        let manager = SessionManager {
            db_path,
            sessions_dir,
        };
        manager.init_db().unwrap();

        (dir, manager)
    }

    /// A plain text message with no tool-call data.
    fn text_message(role: limit_llm::Role, text: &str) -> Message {
        Message {
            role,
            content: Some(text.to_string()),
            tool_calls: None,
            tool_call_id: None,
        }
    }

    #[test]
    fn test_session_manager_new() {
        let (_dir, manager) = temp_manager();

        assert!(manager.db_path.exists());
    }

    #[test]
    fn test_with_paths_creates_sessions_dir_and_db() {
        let dir = tempdir().unwrap();
        let sessions_dir = dir.path().join("nested").join("sessions");
        let db_path = dir.path().join("session.db");

        let manager = SessionManager::with_paths(db_path, sessions_dir.clone()).unwrap();

        assert!(sessions_dir.is_dir());
        assert!(manager.db_path.exists());
        assert!(manager.list_sessions().unwrap().is_empty());
    }

    #[test]
    fn test_create_new_session() {
        let (_dir, manager) = temp_manager();

        let session_id = manager.create_new_session().unwrap();
        assert!(!session_id.is_empty());

        let sessions = manager.list_sessions().unwrap();
        assert_eq!(sessions.len(), 1);
        assert_eq!(sessions[0].id, session_id);
        assert_eq!(sessions[0].message_count, 0);
    }

    #[test]
    fn test_save_and_load_session() {
        let (_dir, manager) = temp_manager();
        let session_id = manager.create_new_session().unwrap();

        let messages = vec![
            text_message(limit_llm::Role::User, "Hello"),
            text_message(limit_llm::Role::Assistant, "Hi there!"),
        ];

        manager.save_session(&session_id, &messages, 0, 0).unwrap();

        let loaded = manager.load_session(&session_id).unwrap();
        assert_eq!(loaded.len(), messages.len());
        assert_eq!(loaded[0].content, messages[0].content);
        assert_eq!(loaded[1].content, messages[1].content);

        let sessions = manager.list_sessions().unwrap();
        assert_eq!(sessions[0].message_count, 2);
    }

    #[test]
    fn test_save_records_token_totals() {
        let (_dir, manager) = temp_manager();
        let session_id = manager.create_new_session().unwrap();

        let messages = vec![text_message(limit_llm::Role::User, "Hello")];
        manager
            .save_session(&session_id, &messages, 120, 45)
            .unwrap();

        let sessions = manager.list_sessions().unwrap();
        assert_eq!(sessions[0].total_input_tokens, 120);
        assert_eq!(sessions[0].total_output_tokens, 45);
    }

    #[test]
    fn test_load_session_drops_system_messages() {
        let (_dir, manager) = temp_manager();
        let session_id = manager.create_new_session().unwrap();

        let messages = vec![
            text_message(limit_llm::Role::System, "You are helpful."),
            text_message(limit_llm::Role::User, "Hello"),
        ];
        manager.save_session(&session_id, &messages, 0, 0).unwrap();

        let loaded = manager.load_session(&session_id).unwrap();
        assert_eq!(loaded.len(), 1);
        assert!(loaded[0].role == limit_llm::Role::User);

        // The stored count reflects what was saved, not what load returns.
        assert_eq!(manager.list_sessions().unwrap()[0].message_count, 2);
    }

    #[test]
    fn test_update_session_tokens_accumulates() {
        let (_dir, manager) = temp_manager();
        let session_id = manager.create_new_session().unwrap();

        manager.update_session_tokens(&session_id, 5, 7).unwrap();
        manager.update_session_tokens(&session_id, 1, 1).unwrap();

        let sessions = manager.list_sessions().unwrap();
        assert_eq!(sessions[0].total_input_tokens, 6);
        assert_eq!(sessions[0].total_output_tokens, 8);
    }

    #[test]
    fn test_load_missing_session_fails() {
        let (_dir, manager) = temp_manager();

        let unknown = Uuid::new_v4().to_string();
        assert!(manager.load_session(&unknown).is_err());
    }

    #[test]
    fn test_list_sessions() {
        let (_dir, manager) = temp_manager();

        let session_id1 = manager.create_new_session().unwrap();
        // Ensure distinct `last_accessed` timestamps so the order is defined.
        std::thread::sleep(std::time::Duration::from_millis(10));
        let session_id2 = manager.create_new_session().unwrap();

        let sessions = manager.list_sessions().unwrap();
        assert_eq!(sessions.len(), 2);
        assert_eq!(sessions[0].id, session_id2);
        assert_eq!(sessions[1].id, session_id1);
    }

    #[test]
    fn test_get_last_session() {
        let (_dir, manager) = temp_manager();

        assert!(manager.get_last_session().unwrap().is_none());

        let _session_id1 = manager.create_new_session().unwrap();
        std::thread::sleep(std::time::Duration::from_millis(10));
        let session_id2 = manager.create_new_session().unwrap();

        let last = manager.get_last_session().unwrap();
        assert!(last.is_some());
        assert_eq!(last.unwrap().id, session_id2);
    }

    #[test]
    fn test_session_persistence_across_restarts() {
        let (_dir, manager1) = temp_manager();
        let db_path = manager1.db_path.clone();
        let sessions_dir = manager1.sessions_dir.clone();

        let session_id = manager1.create_new_session().unwrap();
        let messages = vec![text_message(limit_llm::Role::User, "Test message")];
        manager1.save_session(&session_id, &messages, 0, 0).unwrap();

        drop(manager1);

        let manager2 = SessionManager {
            db_path,
            sessions_dir,
        };
        manager2.init_db().unwrap();

        let loaded = manager2.load_session(&session_id).unwrap();
        assert_eq!(loaded.len(), 1);
        assert_eq!(loaded[0].content, Some("Test message".to_string()));
    }
}