1use crate::error::CliError;
2use bincode::{deserialize, serialize};
3use chrono::{DateTime, Utc};
4use limit_llm::Message;
5use rusqlite::{params, Connection};
6use serde::{Deserialize, Serialize};
7use std::fs;
8use std::path::PathBuf;
9use tracing::instrument;
10use uuid::Uuid;
11
/// Format version stamped into every session file written by `SessionManager::save_session`.
///
/// `SessionManager::load_session` refuses files whose stored version is greater than this.
const CURRENT_VERSION: u32 = 2;
13
14#[derive(Debug, Clone)]
15pub struct SessionInfo {
16 pub id: String,
17 #[allow(dead_code)]
18 pub created_at: DateTime<Utc>,
19 #[allow(dead_code)]
20 pub last_accessed: DateTime<Utc>,
21 pub message_count: usize,
22 pub total_input_tokens: u64,
23 pub total_output_tokens: u64,
24}
25
/// Top-level payload of a `<session_id>.bin` file, encoded with bincode.
///
/// bincode writes struct fields positionally without names, so the field order here is
/// part of the on-disk format and must not be changed.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct PersistedState {
    // Format version; `load_session` rejects values greater than `CURRENT_VERSION`.
    version: u32,
    // Conversation history in the order it was saved.
    messages: Vec<PersistedMessage>,
}
31
/// Serialisable copy of a `limit_llm::Message`, stored inside [`PersistedState`].
///
/// Field order is part of the bincode on-disk format; do not reorder. The embedded
/// `limit_llm::ToolCall` is serialised through its own serde impl — presumably that
/// representation must also stay stable across versions (TODO confirm).
#[derive(Debug, Clone, Serialize, Deserialize)]
struct PersistedMessage {
    role: PersistedRole,
    content: Option<String>,
    tool_calls: Option<Vec<limit_llm::ToolCall>>,
    tool_call_id: Option<String>,
}
39
/// Serialisable mirror of `limit_llm::Role` used in session files.
///
/// bincode encodes enum variants by their index, so the variant order below is part of the
/// on-disk format: append new variants at the end, never reorder or remove existing ones.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
enum PersistedRole {
    User,
    Assistant,
    System,
    Tool,
}
47
48impl From<PersistedRole> for limit_llm::Role {
49 fn from(role: PersistedRole) -> Self {
50 match role {
51 PersistedRole::User => limit_llm::Role::User,
52 PersistedRole::Assistant => limit_llm::Role::Assistant,
53 PersistedRole::System => limit_llm::Role::System,
54 PersistedRole::Tool => limit_llm::Role::Tool,
55 }
56 }
57}
58
59impl From<limit_llm::Role> for PersistedRole {
60 fn from(role: limit_llm::Role) -> Self {
61 match role {
62 limit_llm::Role::User => PersistedRole::User,
63 limit_llm::Role::Assistant => PersistedRole::Assistant,
64 limit_llm::Role::System => PersistedRole::System,
65 limit_llm::Role::Tool => PersistedRole::Tool,
66 }
67 }
68}
69
70impl From<PersistedMessage> for Message {
71 fn from(msg: PersistedMessage) -> Self {
72 Message {
73 role: msg.role.into(),
74 content: msg.content,
75 tool_calls: msg.tool_calls,
76 tool_call_id: msg.tool_call_id,
77 }
78 }
79}
80
81impl From<Message> for PersistedMessage {
82 fn from(msg: Message) -> Self {
83 PersistedMessage {
84 role: msg.role.into(),
85 content: msg.content,
86 tool_calls: msg.tool_calls,
87 tool_call_id: msg.tool_call_id,
88 }
89 }
90}
91
/// Stores chat sessions as metadata rows in an SQLite database plus one bincode-encoded
/// message file per session.
pub struct SessionManager {
    /// SQLite database file that holds the `sessions` metadata table.
    db_path: PathBuf,
    /// Directory holding the per-session `<session_id>.bin` message files.
    sessions_dir: PathBuf,
}
96
97impl SessionManager {
98 pub fn new() -> Result<Self, CliError> {
99 let limit_dir = PathBuf::from(".limit");
100 fs::create_dir_all(&limit_dir).map_err(|e| {
101 CliError::ConfigError(format!("Failed to create .limit directory: {}", e))
102 })?;
103
104 let sessions_dir = limit_dir.join("sessions");
105 fs::create_dir_all(&sessions_dir).map_err(|e| {
106 CliError::ConfigError(format!("Failed to create sessions directory: {}", e))
107 })?;
108
109 let db_path = limit_dir.join("session.db");
110 let session_manager = Self {
111 db_path,
112 sessions_dir,
113 };
114
115 session_manager.init_db()?;
116 Ok(session_manager)
117 }
118
119 pub fn init_db(&self) -> Result<(), CliError> {
120 let conn = Connection::open(&self.db_path)
121 .map_err(|e| CliError::ConfigError(format!("Failed to open database: {}", e)))?;
122
123 conn.execute(
124 "CREATE TABLE IF NOT EXISTS sessions (
125 id TEXT PRIMARY KEY,
126 created_at TEXT NOT NULL,
127 last_accessed TEXT NOT NULL,
128 message_count INTEGER NOT NULL,
129 total_input_tokens INTEGER NOT NULL DEFAULT 0,
130 total_output_tokens INTEGER NOT NULL DEFAULT 0
131 )",
132 [],
133 )
134 .map_err(|e| CliError::ConfigError(format!("Failed to create sessions table: {}", e)))?;
135
136 conn.execute(
137 "CREATE INDEX IF NOT EXISTS idx_last_accessed ON sessions(last_accessed DESC)",
138 [],
139 )
140 .map_err(|e| CliError::ConfigError(format!("Failed to create index: {}", e)))?;
141
142 Ok(())
143 }
144
145 fn get_connection(&self) -> Result<Connection, CliError> {
146 Connection::open(&self.db_path)
147 .map_err(|e| CliError::ConfigError(format!("Failed to open database: {}", e)))
148 }
149
150 pub fn create_new_session(&self) -> Result<String, CliError> {
151 let session_id = Uuid::new_v4().to_string();
152 let now = Utc::now().to_rfc3339();
153
154 let conn = self.get_connection()?;
155 conn.execute(
156 "INSERT INTO sessions (id, created_at, last_accessed, message_count, total_input_tokens, total_output_tokens) VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
157 params![&session_id, &now, &now, 0, 0, 0],
158 )
159 .map_err(|e| CliError::ConfigError(format!("Failed to create session: {}", e)))?;
160
161 Ok(session_id)
162 }
163
164 #[instrument(skip(self, messages))]
165 pub fn save_session(
166 &self,
167 session_id: &str,
168 messages: &[Message],
169 total_input_tokens: u64,
170 total_output_tokens: u64,
171 ) -> Result<(), CliError> {
172 let file_path = self.sessions_dir.join(format!("{}.bin", session_id));
173
174 fs::create_dir_all(&self.sessions_dir).map_err(|e| {
175 CliError::ConfigError(format!("Failed to create sessions directory: {}", e))
176 })?;
177
178 let persisted_messages: Vec<PersistedMessage> =
179 messages.iter().cloned().map(|m| m.into()).collect();
180
181 let state = PersistedState {
182 version: CURRENT_VERSION,
183 messages: persisted_messages,
184 };
185
186 let serialized = serialize(&state)
187 .map_err(|e| CliError::ConfigError(format!("Failed to serialize messages: {}", e)))?;
188
189 fs::write(&file_path, serialized)
190 .map_err(|e| CliError::ConfigError(format!("Failed to write session file: {}", e)))?;
191
192 let now = Utc::now().to_rfc3339();
193 let conn = self.get_connection()?;
194 conn.execute(
195 "UPDATE sessions SET last_accessed = ?1, message_count = ?2, total_input_tokens = ?3, total_output_tokens = ?4 WHERE id = ?5",
196 params![&now, messages.len() as i64, total_input_tokens as i64, total_output_tokens as i64, session_id],
197 )
198 .map_err(|e| CliError::ConfigError(format!("Failed to update session metadata: {}", e)))?;
199
200 Ok(())
201 }
202
203 #[instrument(skip(self))]
204 pub fn load_session(&self, session_id: &str) -> Result<Vec<Message>, CliError> {
205 let file_path = self.sessions_dir.join(format!("{}.bin", session_id));
206
207 let data = fs::read(&file_path)
208 .map_err(|e| CliError::ConfigError(format!("Failed to read session file: {}", e)))?;
209
210 let state: PersistedState = deserialize(&data)
211 .map_err(|e| CliError::ConfigError(format!("Failed to deserialize messages: {}", e)))?;
212
213 if state.version > CURRENT_VERSION {
215 return Err(CliError::ConfigError(format!(
216 "Version mismatch: expected {}, found {}",
217 CURRENT_VERSION, state.version
218 )));
219 }
220
221 let now = Utc::now().to_rfc3339();
222 let conn = self.get_connection()?;
223 conn.execute(
224 "UPDATE sessions SET last_accessed = ?1 WHERE id = ?2",
225 params![&now, session_id],
226 )
227 .map_err(|e| CliError::ConfigError(format!("Failed to update last_accessed: {}", e)))?;
228
229 let messages: Vec<Message> = state
231 .messages
232 .into_iter()
233 .map(Message::from)
234 .filter(|m| m.role != limit_llm::Role::System)
235 .collect();
236
237 Ok(messages)
238 }
239
240 pub fn list_sessions(&self) -> Result<Vec<SessionInfo>, CliError> {
241 let conn = self.get_connection()?;
242
243 let mut stmt = conn
244 .prepare("SELECT id, created_at, last_accessed, message_count, total_input_tokens, total_output_tokens FROM sessions ORDER BY last_accessed DESC")
245 .map_err(|e| CliError::ConfigError(format!("Failed to prepare query: {}", e)))?;
246
247 let session_iter = stmt
248 .query_map([], |row| {
249 Ok(SessionInfo {
250 id: row.get(0)?,
251 created_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(1)?)
252 .unwrap()
253 .with_timezone(&Utc),
254 last_accessed: DateTime::parse_from_rfc3339(&row.get::<_, String>(2)?)
255 .unwrap()
256 .with_timezone(&Utc),
257 message_count: row.get::<_, i64>(3)? as usize,
258 total_input_tokens: row.get::<_, i64>(4)? as u64,
259 total_output_tokens: row.get::<_, i64>(5)? as u64,
260 })
261 })
262 .map_err(|e| CliError::ConfigError(format!("Failed to query sessions: {}", e)))?;
263
264 let mut sessions = Vec::new();
265 for session in session_iter {
266 sessions.push(
267 session.map_err(|e| {
268 CliError::ConfigError(format!("Failed to parse session: {}", e))
269 })?,
270 );
271 }
272
273 Ok(sessions)
274 }
275
276 pub fn get_last_session(&self) -> Result<Option<SessionInfo>, CliError> {
277 let conn = self.get_connection()?;
278
279 let mut stmt = conn
280 .prepare("SELECT id, created_at, last_accessed, message_count, total_input_tokens, total_output_tokens FROM sessions ORDER BY last_accessed DESC LIMIT 1")
281 .map_err(|e| CliError::ConfigError(format!("Failed to prepare query: {}", e)))?;
282
283 let mut session_iter = stmt
284 .query_map([], |row| {
285 Ok(SessionInfo {
286 id: row.get(0)?,
287 created_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(1)?)
288 .unwrap()
289 .with_timezone(&Utc),
290 last_accessed: DateTime::parse_from_rfc3339(&row.get::<_, String>(2)?)
291 .unwrap()
292 .with_timezone(&Utc),
293 message_count: row.get::<_, i64>(3)? as usize,
294 total_input_tokens: row.get::<_, i64>(4)? as u64,
295 total_output_tokens: row.get::<_, i64>(5)? as u64,
296 })
297 })
298 .map_err(|e| CliError::ConfigError(format!("Failed to query last session: {}", e)))?;
299
300 match session_iter.next() {
301 Some(session) => Ok(Some(session.map_err(|e| {
302 CliError::ConfigError(format!("Failed to parse session: {}", e))
303 })?)),
304 None => Ok(None),
305 }
306 }
307
308 #[allow(dead_code)]
310 pub fn update_session_tokens(
311 &self,
312 session_id: &str,
313 input_tokens: u64,
314 output_tokens: u64,
315 ) -> Result<(), CliError> {
316 let conn = self.get_connection()?;
317 conn.execute(
318 "UPDATE sessions SET total_input_tokens = total_input_tokens + ?1, total_output_tokens = total_output_tokens + ?2 WHERE id = ?3",
319 params![input_tokens as i64, output_tokens as i64, session_id],
320 )
321 .map_err(|e| CliError::ConfigError(format!("Failed to update session tokens: {}", e)))?;
322 Ok(())
323 }
324}
325
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Builds a manager whose database and session files live under `root`. The sessions
    /// directory is created and the schema is initialised.
    fn manager_in(root: &std::path::Path) -> SessionManager {
        let sessions_dir = root.join("sessions");
        fs::create_dir_all(&sessions_dir).unwrap();

        let manager = SessionManager {
            db_path: root.join("session.db"),
            sessions_dir,
        };
        manager.init_db().unwrap();
        manager
    }

    /// Plain-text message with no tool-call metadata.
    fn text_message(role: limit_llm::Role, text: &str) -> Message {
        Message {
            role,
            content: Some(text.to_string()),
            tool_calls: None,
            tool_call_id: None,
        }
    }

    #[test]
    fn test_session_manager_new() {
        let dir = tempdir().unwrap();
        let manager = manager_in(dir.path());

        assert!(manager.db_path.exists());
    }

    #[test]
    fn test_create_new_session() {
        let dir = tempdir().unwrap();
        let manager = manager_in(dir.path());

        let session_id = manager.create_new_session().unwrap();
        assert!(!session_id.is_empty());

        let sessions = manager.list_sessions().unwrap();
        assert_eq!(sessions.len(), 1);
        let only = &sessions[0];
        assert_eq!(only.id, session_id);
        assert_eq!(only.message_count, 0);
    }

    #[test]
    fn test_save_and_load_session() {
        let dir = tempdir().unwrap();
        let manager = manager_in(dir.path());
        let session_id = manager.create_new_session().unwrap();

        let messages = vec![
            text_message(limit_llm::Role::User, "Hello"),
            text_message(limit_llm::Role::Assistant, "Hi there!"),
        ];
        manager.save_session(&session_id, &messages, 0, 0).unwrap();

        let loaded = manager.load_session(&session_id).unwrap();
        assert_eq!(loaded.len(), messages.len());
        for (got, want) in loaded.iter().zip(&messages) {
            assert_eq!(got.content, want.content);
        }

        let sessions = manager.list_sessions().unwrap();
        assert_eq!(sessions[0].message_count, 2);
    }

    #[test]
    fn test_list_sessions() {
        let dir = tempdir().unwrap();
        let manager = manager_in(dir.path());

        let older = manager.create_new_session().unwrap();
        // Guarantees distinct `last_accessed` timestamps so the ordering is deterministic.
        std::thread::sleep(std::time::Duration::from_millis(10));
        let newer = manager.create_new_session().unwrap();

        let sessions = manager.list_sessions().unwrap();
        let ids: Vec<&str> = sessions.iter().map(|s| s.id.as_str()).collect();
        assert_eq!(ids, [newer.as_str(), older.as_str()]);
    }

    #[test]
    fn test_get_last_session() {
        let dir = tempdir().unwrap();
        let manager = manager_in(dir.path());

        assert!(manager.get_last_session().unwrap().is_none());

        let _older = manager.create_new_session().unwrap();
        std::thread::sleep(std::time::Duration::from_millis(10));
        let newer = manager.create_new_session().unwrap();

        let last = manager
            .get_last_session()
            .unwrap()
            .expect("a session was just created");
        assert_eq!(last.id, newer);
    }

    #[test]
    fn test_session_persistence_across_restarts() {
        let dir = tempdir().unwrap();

        let session_id = {
            let first = manager_in(dir.path());
            let id = first.create_new_session().unwrap();
            first
                .save_session(
                    &id,
                    &[text_message(limit_llm::Role::User, "Test message")],
                    0,
                    0,
                )
                .unwrap();
            id
        }; // `first` is dropped here, simulating the end of a process.

        let second = manager_in(dir.path());
        let loaded = second.load_session(&session_id).unwrap();
        assert_eq!(loaded.len(), 1);
        assert_eq!(loaded[0].content, Some("Test message".to_string()));
    }
}