1use crate::error::CliError;
2use bincode::{deserialize, serialize};
3use chrono::{DateTime, Utc};
4use limit_llm::Message;
5use rusqlite::{params, Connection};
6use serde::{Deserialize, Serialize};
7use std::fs;
8use std::path::PathBuf;
9use tracing::instrument;
10use uuid::Uuid;
11
/// Format version stamped into every session file written by `save_session`.
///
/// `load_session` refuses files whose version is greater than this value.
const CURRENT_VERSION: u32 = 2;
13
/// Metadata for one stored session, as read from the SQLite `sessions` table.
#[derive(Debug, Clone)]
pub struct SessionInfo {
    /// Session identifier (primary key); also the file stem of `<sessions_dir>/<id>.bin`.
    pub id: String,
    /// When the session row was created (stored as RFC 3339 text, parsed as UTC).
    // Allowed as dead code: presumably not read by any caller yet.
    #[allow(dead_code)]
    pub created_at: DateTime<Utc>,
    /// Last time the session was saved or loaded; sessions are listed newest-first by this.
    // Allowed as dead code: presumably not read by any caller yet.
    #[allow(dead_code)]
    pub last_accessed: DateTime<Utc>,
    /// Number of messages in the last saved snapshot (system messages included).
    pub message_count: usize,
    /// Cumulative input tokens recorded for this session.
    pub total_input_tokens: u64,
    /// Cumulative output tokens recorded for this session.
    pub total_output_tokens: u64,
}
25
/// Top-level payload that is bincode-serialized into `<sessions_dir>/<id>.bin`.
///
/// bincode encodes fields positionally (no field names), so the field order here is part of
/// the on-disk format: do not reorder or insert fields without bumping `CURRENT_VERSION`.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct PersistedState {
    /// Format version the file was written with.
    version: u32,
    /// Full conversation history, in order.
    messages: Vec<PersistedMessage>,
}
31
/// Serializable mirror of `limit_llm::Message`.
///
/// Presumably kept separate so the session file format does not silently change when the
/// library type does. Field order is part of the bincode format; do not reorder.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct PersistedMessage {
    role: PersistedRole,
    content: Option<String>,
    tool_calls: Option<Vec<limit_llm::ToolCall>>,
    tool_call_id: Option<String>,
}
39
/// On-disk copy of `limit_llm::Role`.
///
/// bincode stores enum variants by index, so variant order is part of the file format:
/// append new variants at the end only.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
enum PersistedRole {
    User,
    Assistant,
    System,
    Tool,
}
47
48impl From<PersistedRole> for limit_llm::Role {
49 fn from(role: PersistedRole) -> Self {
50 match role {
51 PersistedRole::User => limit_llm::Role::User,
52 PersistedRole::Assistant => limit_llm::Role::Assistant,
53 PersistedRole::System => limit_llm::Role::System,
54 PersistedRole::Tool => limit_llm::Role::Tool,
55 }
56 }
57}
58
59impl From<limit_llm::Role> for PersistedRole {
60 fn from(role: limit_llm::Role) -> Self {
61 match role {
62 limit_llm::Role::User => PersistedRole::User,
63 limit_llm::Role::Assistant => PersistedRole::Assistant,
64 limit_llm::Role::System => PersistedRole::System,
65 limit_llm::Role::Tool => PersistedRole::Tool,
66 }
67 }
68}
69
70impl From<PersistedMessage> for Message {
71 fn from(msg: PersistedMessage) -> Self {
72 Message {
73 role: msg.role.into(),
74 content: msg.content,
75 tool_calls: msg.tool_calls,
76 tool_call_id: msg.tool_call_id,
77 }
78 }
79}
80
81impl From<Message> for PersistedMessage {
82 fn from(msg: Message) -> Self {
83 PersistedMessage {
84 role: msg.role.into(),
85 content: msg.content,
86 tool_calls: msg.tool_calls,
87 tool_call_id: msg.tool_call_id,
88 }
89 }
90}
91
/// Persists chat sessions: metadata rows live in a SQLite database, and each session's
/// message history is a bincode file under `sessions_dir`.
///
/// Every operation opens its own SQLite connection; no connection is held between calls.
pub struct SessionManager {
    /// Path of the SQLite metadata database.
    db_path: PathBuf,
    /// Directory containing one `<session_id>.bin` file per session.
    sessions_dir: PathBuf,
}
96
97impl SessionManager {
98 pub fn new() -> Result<Self, CliError> {
99 let home_dir = dirs::home_dir()
101 .ok_or_else(|| CliError::ConfigError("Failed to get home directory".to_string()))?;
102 let limit_dir = home_dir.join(".limit");
103 fs::create_dir_all(&limit_dir).map_err(|e| {
104 CliError::ConfigError(format!("Failed to create .limit directory: {}", e))
105 })?;
106
107 let sessions_dir = limit_dir.join("sessions");
108 fs::create_dir_all(&sessions_dir).map_err(|e| {
109 CliError::ConfigError(format!("Failed to create sessions directory: {}", e))
110 })?;
111
112 let db_path = limit_dir.join("session.db");
113 let session_manager = Self {
114 db_path,
115 sessions_dir,
116 };
117
118 session_manager.init_db()?;
119 Ok(session_manager)
120 }
121
122 pub fn init_db(&self) -> Result<(), CliError> {
123 let conn = Connection::open(&self.db_path)
124 .map_err(|e| CliError::ConfigError(format!("Failed to open database: {}", e)))?;
125
126 conn.execute(
127 "CREATE TABLE IF NOT EXISTS sessions (
128 id TEXT PRIMARY KEY,
129 created_at TEXT NOT NULL,
130 last_accessed TEXT NOT NULL,
131 message_count INTEGER NOT NULL,
132 total_input_tokens INTEGER NOT NULL DEFAULT 0,
133 total_output_tokens INTEGER NOT NULL DEFAULT 0
134 )",
135 [],
136 )
137 .map_err(|e| CliError::ConfigError(format!("Failed to create sessions table: {}", e)))?;
138
139 conn.execute(
140 "CREATE INDEX IF NOT EXISTS idx_last_accessed ON sessions(last_accessed DESC)",
141 [],
142 )
143 .map_err(|e| CliError::ConfigError(format!("Failed to create index: {}", e)))?;
144
145 Ok(())
146 }
147
148 fn get_connection(&self) -> Result<Connection, CliError> {
149 Connection::open(&self.db_path)
150 .map_err(|e| CliError::ConfigError(format!("Failed to open database: {}", e)))
151 }
152
153 pub fn create_new_session(&self) -> Result<String, CliError> {
154 let session_id = Uuid::new_v4().to_string();
155 let now = Utc::now().to_rfc3339();
156
157 let conn = self.get_connection()?;
158 conn.execute(
159 "INSERT INTO sessions (id, created_at, last_accessed, message_count, total_input_tokens, total_output_tokens) VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
160 params![&session_id, &now, &now, 0, 0, 0],
161 )
162 .map_err(|e| CliError::ConfigError(format!("Failed to create session: {}", e)))?;
163
164 Ok(session_id)
165 }
166
167 #[instrument(skip(self, messages))]
168 pub fn save_session(
169 &self,
170 session_id: &str,
171 messages: &[Message],
172 total_input_tokens: u64,
173 total_output_tokens: u64,
174 ) -> Result<(), CliError> {
175 let file_path = self.sessions_dir.join(format!("{}.bin", session_id));
176
177 fs::create_dir_all(&self.sessions_dir).map_err(|e| {
178 CliError::ConfigError(format!("Failed to create sessions directory: {}", e))
179 })?;
180
181 let persisted_messages: Vec<PersistedMessage> =
182 messages.iter().cloned().map(|m| m.into()).collect();
183
184 let state = PersistedState {
185 version: CURRENT_VERSION,
186 messages: persisted_messages,
187 };
188
189 let serialized = serialize(&state)
190 .map_err(|e| CliError::ConfigError(format!("Failed to serialize messages: {}", e)))?;
191
192 fs::write(&file_path, serialized)
193 .map_err(|e| CliError::ConfigError(format!("Failed to write session file: {}", e)))?;
194
195 let now = Utc::now().to_rfc3339();
196 let conn = self.get_connection()?;
197 conn.execute(
198 "UPDATE sessions SET last_accessed = ?1, message_count = ?2, total_input_tokens = ?3, total_output_tokens = ?4 WHERE id = ?5",
199 params![&now, messages.len() as i64, total_input_tokens as i64, total_output_tokens as i64, session_id],
200 )
201 .map_err(|e| CliError::ConfigError(format!("Failed to update session metadata: {}", e)))?;
202
203 Ok(())
204 }
205
206 #[instrument(skip(self))]
207 pub fn load_session(&self, session_id: &str) -> Result<Vec<Message>, CliError> {
208 let file_path = self.sessions_dir.join(format!("{}.bin", session_id));
209
210 let data = fs::read(&file_path)
211 .map_err(|e| CliError::ConfigError(format!("Failed to read session file: {}", e)))?;
212
213 let state: PersistedState = deserialize(&data)
214 .map_err(|e| CliError::ConfigError(format!("Failed to deserialize messages: {}", e)))?;
215
216 if state.version > CURRENT_VERSION {
218 return Err(CliError::ConfigError(format!(
219 "Version mismatch: expected {}, found {}",
220 CURRENT_VERSION, state.version
221 )));
222 }
223
224 let now = Utc::now().to_rfc3339();
225 let conn = self.get_connection()?;
226 conn.execute(
227 "UPDATE sessions SET last_accessed = ?1 WHERE id = ?2",
228 params![&now, session_id],
229 )
230 .map_err(|e| CliError::ConfigError(format!("Failed to update last_accessed: {}", e)))?;
231
232 let messages: Vec<Message> = state
234 .messages
235 .into_iter()
236 .map(Message::from)
237 .filter(|m| m.role != limit_llm::Role::System)
238 .collect();
239
240 Ok(messages)
241 }
242
243 pub fn list_sessions(&self) -> Result<Vec<SessionInfo>, CliError> {
244 let conn = self.get_connection()?;
245
246 let mut stmt = conn
247 .prepare("SELECT id, created_at, last_accessed, message_count, total_input_tokens, total_output_tokens FROM sessions ORDER BY last_accessed DESC")
248 .map_err(|e| CliError::ConfigError(format!("Failed to prepare query: {}", e)))?;
249
250 let session_iter = stmt
251 .query_map([], |row| {
252 Ok(SessionInfo {
253 id: row.get(0)?,
254 created_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(1)?)
255 .unwrap()
256 .with_timezone(&Utc),
257 last_accessed: DateTime::parse_from_rfc3339(&row.get::<_, String>(2)?)
258 .unwrap()
259 .with_timezone(&Utc),
260 message_count: row.get::<_, i64>(3)? as usize,
261 total_input_tokens: row.get::<_, i64>(4)? as u64,
262 total_output_tokens: row.get::<_, i64>(5)? as u64,
263 })
264 })
265 .map_err(|e| CliError::ConfigError(format!("Failed to query sessions: {}", e)))?;
266
267 let mut sessions = Vec::new();
268 for session in session_iter {
269 sessions.push(
270 session.map_err(|e| {
271 CliError::ConfigError(format!("Failed to parse session: {}", e))
272 })?,
273 );
274 }
275
276 Ok(sessions)
277 }
278
279 pub fn get_last_session(&self) -> Result<Option<SessionInfo>, CliError> {
280 let conn = self.get_connection()?;
281
282 let mut stmt = conn
283 .prepare("SELECT id, created_at, last_accessed, message_count, total_input_tokens, total_output_tokens FROM sessions ORDER BY last_accessed DESC LIMIT 1")
284 .map_err(|e| CliError::ConfigError(format!("Failed to prepare query: {}", e)))?;
285
286 let mut session_iter = stmt
287 .query_map([], |row| {
288 Ok(SessionInfo {
289 id: row.get(0)?,
290 created_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(1)?)
291 .unwrap()
292 .with_timezone(&Utc),
293 last_accessed: DateTime::parse_from_rfc3339(&row.get::<_, String>(2)?)
294 .unwrap()
295 .with_timezone(&Utc),
296 message_count: row.get::<_, i64>(3)? as usize,
297 total_input_tokens: row.get::<_, i64>(4)? as u64,
298 total_output_tokens: row.get::<_, i64>(5)? as u64,
299 })
300 })
301 .map_err(|e| CliError::ConfigError(format!("Failed to query last session: {}", e)))?;
302
303 match session_iter.next() {
304 Some(session) => Ok(Some(session.map_err(|e| {
305 CliError::ConfigError(format!("Failed to parse session: {}", e))
306 })?)),
307 None => Ok(None),
308 }
309 }
310
311 #[allow(dead_code)]
313 pub fn update_session_tokens(
314 &self,
315 session_id: &str,
316 input_tokens: u64,
317 output_tokens: u64,
318 ) -> Result<(), CliError> {
319 let conn = self.get_connection()?;
320 conn.execute(
321 "UPDATE sessions SET total_input_tokens = total_input_tokens + ?1, total_output_tokens = total_output_tokens + ?2 WHERE id = ?3",
322 params![input_tokens as i64, output_tokens as i64, session_id],
323 )
324 .map_err(|e| CliError::ConfigError(format!("Failed to update session tokens: {}", e)))?;
325 Ok(())
326 }
327}
328
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::{tempdir, TempDir};

    /// Builds a manager whose database and sessions directory live under `root`.
    ///
    /// The sessions directory is created eagerly. The schema is not, so each test decides when
    /// to call `init_db`.
    fn manager_at(root: &std::path::Path) -> SessionManager {
        let sessions_dir = root.join("sessions");
        fs::create_dir_all(&sessions_dir).unwrap();
        SessionManager {
            db_path: root.join("session.db"),
            sessions_dir,
        }
    }

    /// Returns a temp directory together with an initialised manager inside it.
    ///
    /// Keep the `TempDir` guard alive while the manager is used: dropping it deletes the
    /// directory.
    fn fresh_manager() -> (TempDir, SessionManager) {
        let dir = tempdir().unwrap();
        let manager = manager_at(dir.path());
        manager.init_db().unwrap();
        (dir, manager)
    }

    /// Shorthand for a plain-text message with no tool-call data.
    fn text_message(role: limit_llm::Role, text: &str) -> Message {
        Message {
            role,
            content: Some(text.to_string()),
            tool_calls: None,
            tool_call_id: None,
        }
    }

    #[test]
    fn test_session_manager_new() {
        let dir = tempdir().unwrap();
        let manager = manager_at(dir.path());

        manager.init_db().unwrap();

        assert!(manager.db_path.exists());
    }

    #[test]
    fn test_create_new_session() {
        let (_dir, manager) = fresh_manager();

        let session_id = manager.create_new_session().unwrap();
        assert!(!session_id.is_empty());

        let sessions = manager.list_sessions().unwrap();
        assert_eq!(sessions.len(), 1);
        assert_eq!(sessions[0].id, session_id);
        assert_eq!(sessions[0].message_count, 0);
    }

    #[test]
    fn test_save_and_load_session() {
        let (_dir, manager) = fresh_manager();
        let session_id = manager.create_new_session().unwrap();
        let messages = [
            text_message(limit_llm::Role::User, "Hello"),
            text_message(limit_llm::Role::Assistant, "Hi there!"),
        ];

        manager.save_session(&session_id, &messages, 0, 0).unwrap();
        let loaded = manager.load_session(&session_id).unwrap();

        assert_eq!(loaded.len(), messages.len());
        for (got, want) in loaded.iter().zip(messages.iter()) {
            assert_eq!(got.content, want.content);
        }

        let sessions = manager.list_sessions().unwrap();
        assert_eq!(sessions[0].message_count, 2);
    }

    #[test]
    fn test_list_sessions() {
        let (_dir, manager) = fresh_manager();

        let older = manager.create_new_session().unwrap();
        // Give the second session a strictly later timestamp so the ordering is deterministic.
        std::thread::sleep(std::time::Duration::from_millis(10));
        let newer = manager.create_new_session().unwrap();

        let ids: Vec<String> = manager
            .list_sessions()
            .unwrap()
            .into_iter()
            .map(|info| info.id)
            .collect();
        assert_eq!(ids.len(), 2);
        assert_eq!(ids[0], newer);
        assert_eq!(ids[1], older);
    }

    #[test]
    fn test_get_last_session() {
        let (_dir, manager) = fresh_manager();

        assert!(manager.get_last_session().unwrap().is_none());

        let _older = manager.create_new_session().unwrap();
        // Give the second session a strictly later timestamp so the ordering is deterministic.
        std::thread::sleep(std::time::Duration::from_millis(10));
        let newer = manager.create_new_session().unwrap();

        let last = manager.get_last_session().unwrap();
        assert!(last.is_some());
        assert_eq!(last.unwrap().id, newer);
    }

    #[test]
    fn test_session_persistence_across_restarts() {
        let dir = tempdir().unwrap();

        // First "process": create and save, then drop the manager at the end of the scope.
        let session_id = {
            let first = manager_at(dir.path());
            first.init_db().unwrap();
            let id = first.create_new_session().unwrap();
            let history = [text_message(limit_llm::Role::User, "Test message")];
            first.save_session(&id, &history, 0, 0).unwrap();
            id
        };

        // Second "process": a brand-new manager over the same paths.
        let second = manager_at(dir.path());
        second.init_db().unwrap();

        let loaded = second.load_session(&session_id).unwrap();
        assert_eq!(loaded.len(), 1);
        assert_eq!(loaded[0].content, Some("Test message".to_string()));
    }
}