Skip to main content

limit_cli/
session.rs

1use crate::error::CliError;
2use bincode::{deserialize, serialize};
3use chrono::{DateTime, Utc};
4use limit_llm::Message;
5use rusqlite::{params, Connection};
6use serde::{Deserialize, Serialize};
7use std::fs;
8use std::path::PathBuf;
9use tracing::instrument;
10use uuid::Uuid;
11
/// Format version stamped into every `PersistedState` written by `save_session`.
/// `load_session` refuses files whose stored version is greater than this value.
const CURRENT_VERSION: u32 = 2;
13
/// Metadata describing one session, as read from a row of the `sessions` table.
#[derive(Debug, Clone)]
pub struct SessionInfo {
    /// Session identifier (the `id` primary-key column).
    pub id: String,
    /// Value of the `created_at` column, converted to UTC.
    #[allow(dead_code)]
    pub created_at: DateTime<Utc>,
    /// Value of the `last_accessed` column, converted to UTC.
    #[allow(dead_code)]
    pub last_accessed: DateTime<Utc>,
    /// Value of the `message_count` column.
    pub message_count: usize,
    /// Value of the `total_input_tokens` column.
    pub total_input_tokens: u64,
    /// Value of the `total_output_tokens` column.
    pub total_output_tokens: u64,
}
25
/// Top-level payload serialized into a session's `.bin` file.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct PersistedState {
    /// Format version the payload was written with (compared against `CURRENT_VERSION` on load).
    version: u32,
    /// The conversation, in the order it was handed to `save_session`.
    messages: Vec<PersistedMessage>,
}
31
/// Serializable mirror of `limit_llm::Message`; the `From` impls in this file copy the
/// fields across one-to-one in both directions.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct PersistedMessage {
    /// Who produced the message.
    role: PersistedRole,
    /// Message text, if any.
    content: Option<String>,
    /// Tool calls attached to the message, if any.
    tool_calls: Option<Vec<limit_llm::ToolCall>>,
    /// Tool-call identifier carried by the message, if any.
    tool_call_id: Option<String>,
}
39
/// Serializable mirror of `limit_llm::Role`.
///
/// NOTE(review): session files are written with bincode, which presumably encodes enum
/// variants by position — reordering or inserting variants here would then break
/// previously saved files. Confirm before changing the variant list.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
enum PersistedRole {
    User,
    Assistant,
    System,
    Tool,
}
47
48impl From<PersistedRole> for limit_llm::Role {
49    fn from(role: PersistedRole) -> Self {
50        match role {
51            PersistedRole::User => limit_llm::Role::User,
52            PersistedRole::Assistant => limit_llm::Role::Assistant,
53            PersistedRole::System => limit_llm::Role::System,
54            PersistedRole::Tool => limit_llm::Role::Tool,
55        }
56    }
57}
58
59impl From<limit_llm::Role> for PersistedRole {
60    fn from(role: limit_llm::Role) -> Self {
61        match role {
62            limit_llm::Role::User => PersistedRole::User,
63            limit_llm::Role::Assistant => PersistedRole::Assistant,
64            limit_llm::Role::System => PersistedRole::System,
65            limit_llm::Role::Tool => PersistedRole::Tool,
66        }
67    }
68}
69
70impl From<PersistedMessage> for Message {
71    fn from(msg: PersistedMessage) -> Self {
72        Message {
73            role: msg.role.into(),
74            content: msg.content,
75            tool_calls: msg.tool_calls,
76            tool_call_id: msg.tool_call_id,
77        }
78    }
79}
80
81impl From<Message> for PersistedMessage {
82    fn from(msg: Message) -> Self {
83        PersistedMessage {
84            role: msg.role.into(),
85            content: msg.content,
86            tool_calls: msg.tool_calls,
87            tool_call_id: msg.tool_call_id,
88        }
89    }
90}
91
/// Stores chat sessions: per-session metadata lives in a SQLite database and the message
/// history of each session lives in a `<session id>.bin` file.
pub struct SessionManager {
    /// Path of the SQLite database holding the `sessions` table.
    db_path: PathBuf,
    /// Directory holding the per-session `.bin` files.
    sessions_dir: PathBuf,
}
96
97impl SessionManager {
98    pub fn new() -> Result<Self, CliError> {
99        // Centralize all session data in ~/.limit/
100        let home_dir = dirs::home_dir()
101            .ok_or_else(|| CliError::ConfigError("Failed to get home directory".to_string()))?;
102        let limit_dir = home_dir.join(".limit");
103        fs::create_dir_all(&limit_dir).map_err(|e| {
104            CliError::ConfigError(format!("Failed to create .limit directory: {}", e))
105        })?;
106
107        let sessions_dir = limit_dir.join("sessions");
108        fs::create_dir_all(&sessions_dir).map_err(|e| {
109            CliError::ConfigError(format!("Failed to create sessions directory: {}", e))
110        })?;
111
112        let db_path = limit_dir.join("session.db");
113
114        Self::with_paths(db_path, sessions_dir)
115    }
116
117    /// Create SessionManager with custom paths (for testing)
118    pub fn with_paths(db_path: PathBuf, sessions_dir: PathBuf) -> Result<Self, CliError> {
119        fs::create_dir_all(&sessions_dir).map_err(|e| {
120            CliError::ConfigError(format!("Failed to create sessions directory: {}", e))
121        })?;
122
123        let session_manager = Self {
124            db_path,
125            sessions_dir,
126        };
127
128        session_manager.init_db()?;
129        Ok(session_manager)
130    }
131
132    pub fn init_db(&self) -> Result<(), CliError> {
133        let conn = Connection::open(&self.db_path)
134            .map_err(|e| CliError::ConfigError(format!("Failed to open database: {}", e)))?;
135
136        conn.execute(
137            "CREATE TABLE IF NOT EXISTS sessions (
138                id TEXT PRIMARY KEY,
139                created_at TEXT NOT NULL,
140                last_accessed TEXT NOT NULL,
141                message_count INTEGER NOT NULL,
142                total_input_tokens INTEGER NOT NULL DEFAULT 0,
143                total_output_tokens INTEGER NOT NULL DEFAULT 0
144            )",
145            [],
146        )
147        .map_err(|e| CliError::ConfigError(format!("Failed to create sessions table: {}", e)))?;
148
149        conn.execute(
150            "CREATE INDEX IF NOT EXISTS idx_last_accessed ON sessions(last_accessed DESC)",
151            [],
152        )
153        .map_err(|e| CliError::ConfigError(format!("Failed to create index: {}", e)))?;
154
155        Ok(())
156    }
157
158    fn get_connection(&self) -> Result<Connection, CliError> {
159        Connection::open(&self.db_path)
160            .map_err(|e| CliError::ConfigError(format!("Failed to open database: {}", e)))
161    }
162
163    pub fn create_new_session(&self) -> Result<String, CliError> {
164        let session_id = Uuid::new_v4().to_string();
165        let now = Utc::now().to_rfc3339();
166
167        let conn = self.get_connection()?;
168        conn.execute(
169            "INSERT INTO sessions (id, created_at, last_accessed, message_count, total_input_tokens, total_output_tokens) VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
170            params![&session_id, &now, &now, 0, 0, 0],
171        )
172        .map_err(|e| CliError::ConfigError(format!("Failed to create session: {}", e)))?;
173
174        Ok(session_id)
175    }
176
177    #[instrument(skip(self, messages))]
178    pub fn save_session(
179        &self,
180        session_id: &str,
181        messages: &[Message],
182        total_input_tokens: u64,
183        total_output_tokens: u64,
184    ) -> Result<(), CliError> {
185        let file_path = self.sessions_dir.join(format!("{}.bin", session_id));
186
187        fs::create_dir_all(&self.sessions_dir).map_err(|e| {
188            CliError::ConfigError(format!("Failed to create sessions directory: {}", e))
189        })?;
190
191        let persisted_messages: Vec<PersistedMessage> =
192            messages.iter().cloned().map(|m| m.into()).collect();
193
194        let state = PersistedState {
195            version: CURRENT_VERSION,
196            messages: persisted_messages,
197        };
198
199        let serialized = serialize(&state)
200            .map_err(|e| CliError::ConfigError(format!("Failed to serialize messages: {}", e)))?;
201
202        fs::write(&file_path, serialized)
203            .map_err(|e| CliError::ConfigError(format!("Failed to write session file: {}", e)))?;
204
205        let now = Utc::now().to_rfc3339();
206        let conn = self.get_connection()?;
207        conn.execute(
208            "UPDATE sessions SET last_accessed = ?1, message_count = ?2, total_input_tokens = ?3, total_output_tokens = ?4 WHERE id = ?5",
209            params![&now, messages.len() as i64, total_input_tokens as i64, total_output_tokens as i64, session_id],
210        )
211        .map_err(|e| CliError::ConfigError(format!("Failed to update session metadata: {}", e)))?;
212
213        Ok(())
214    }
215
216    #[instrument(skip(self))]
217    pub fn load_session(&self, session_id: &str) -> Result<Vec<Message>, CliError> {
218        let file_path = self.sessions_dir.join(format!("{}.bin", session_id));
219
220        let data = fs::read(&file_path)
221            .map_err(|e| CliError::ConfigError(format!("Failed to read session file: {}", e)))?;
222
223        let state: PersistedState = deserialize(&data)
224            .map_err(|e| CliError::ConfigError(format!("Failed to deserialize messages: {}", e)))?;
225
226        // Handle version migration if needed
227        if state.version > CURRENT_VERSION {
228            return Err(CliError::ConfigError(format!(
229                "Version mismatch: expected {}, found {}",
230                CURRENT_VERSION, state.version
231            )));
232        }
233
234        let now = Utc::now().to_rfc3339();
235        let conn = self.get_connection()?;
236        conn.execute(
237            "UPDATE sessions SET last_accessed = ?1 WHERE id = ?2",
238            params![&now, session_id],
239        )
240        .map_err(|e| CliError::ConfigError(format!("Failed to update last_accessed: {}", e)))?;
241
242        // Filter out system messages - z.ai API doesn't support system role
243        let messages: Vec<Message> = state
244            .messages
245            .into_iter()
246            .map(Message::from)
247            .filter(|m| m.role != limit_llm::Role::System)
248            .collect();
249
250        Ok(messages)
251    }
252
253    pub fn list_sessions(&self) -> Result<Vec<SessionInfo>, CliError> {
254        let conn = self.get_connection()?;
255
256        let mut stmt = conn
257            .prepare("SELECT id, created_at, last_accessed, message_count, total_input_tokens, total_output_tokens FROM sessions ORDER BY last_accessed DESC")
258            .map_err(|e| CliError::ConfigError(format!("Failed to prepare query: {}", e)))?;
259
260        let session_iter = stmt
261            .query_map([], |row| {
262                Ok(SessionInfo {
263                    id: row.get(0)?,
264                    created_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(1)?)
265                        .unwrap()
266                        .with_timezone(&Utc),
267                    last_accessed: DateTime::parse_from_rfc3339(&row.get::<_, String>(2)?)
268                        .unwrap()
269                        .with_timezone(&Utc),
270                    message_count: row.get::<_, i64>(3)? as usize,
271                    total_input_tokens: row.get::<_, i64>(4)? as u64,
272                    total_output_tokens: row.get::<_, i64>(5)? as u64,
273                })
274            })
275            .map_err(|e| CliError::ConfigError(format!("Failed to query sessions: {}", e)))?;
276
277        let mut sessions = Vec::new();
278        for session in session_iter {
279            sessions.push(
280                session.map_err(|e| {
281                    CliError::ConfigError(format!("Failed to parse session: {}", e))
282                })?,
283            );
284        }
285
286        Ok(sessions)
287    }
288
289    pub fn get_last_session(&self) -> Result<Option<SessionInfo>, CliError> {
290        let conn = self.get_connection()?;
291
292        let mut stmt = conn
293            .prepare("SELECT id, created_at, last_accessed, message_count, total_input_tokens, total_output_tokens FROM sessions ORDER BY last_accessed DESC LIMIT 1")
294            .map_err(|e| CliError::ConfigError(format!("Failed to prepare query: {}", e)))?;
295
296        let mut session_iter = stmt
297            .query_map([], |row| {
298                Ok(SessionInfo {
299                    id: row.get(0)?,
300                    created_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(1)?)
301                        .unwrap()
302                        .with_timezone(&Utc),
303                    last_accessed: DateTime::parse_from_rfc3339(&row.get::<_, String>(2)?)
304                        .unwrap()
305                        .with_timezone(&Utc),
306                    message_count: row.get::<_, i64>(3)? as usize,
307                    total_input_tokens: row.get::<_, i64>(4)? as u64,
308                    total_output_tokens: row.get::<_, i64>(5)? as u64,
309                })
310            })
311            .map_err(|e| CliError::ConfigError(format!("Failed to query last session: {}", e)))?;
312
313        match session_iter.next() {
314            Some(session) => Ok(Some(session.map_err(|e| {
315                CliError::ConfigError(format!("Failed to parse session: {}", e))
316            })?)),
317            None => Ok(None),
318        }
319    }
320
321    /// Update token counts for a session
322    #[allow(dead_code)]
323    pub fn update_session_tokens(
324        &self,
325        session_id: &str,
326        input_tokens: u64,
327        output_tokens: u64,
328    ) -> Result<(), CliError> {
329        let conn = self.get_connection()?;
330        conn.execute(
331            "UPDATE sessions SET total_input_tokens = total_input_tokens + ?1, total_output_tokens = total_output_tokens + ?2 WHERE id = ?3",
332            params![input_tokens as i64, output_tokens as i64, session_id],
333        )
334        .map_err(|e| CliError::ConfigError(format!("Failed to update session tokens: {}", e)))?;
335        Ok(())
336    }
337}
338
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Opens a manager over `<root>/session.db` and `<root>/sessions`, making sure the
    /// sessions directory exists and the schema has been initialised.
    fn open_manager(root: &std::path::Path) -> SessionManager {
        let sessions_dir = root.join("sessions");
        let db_path = root.join("session.db");

        fs::create_dir_all(&sessions_dir).unwrap();

        let manager = SessionManager {
            db_path,
            sessions_dir,
        };
        manager.init_db().unwrap();
        manager
    }

    /// Builds a plain text message with no tool-call data.
    fn text_message(role: limit_llm::Role, text: &str) -> Message {
        Message {
            role,
            content: Some(text.to_string()),
            tool_calls: None,
            tool_call_id: None,
        }
    }

    #[test]
    fn test_session_manager_new() {
        let scratch = tempdir().unwrap();
        let manager = open_manager(scratch.path());

        assert!(manager.db_path.exists());
    }

    #[test]
    fn test_create_new_session() {
        let scratch = tempdir().unwrap();
        let manager = open_manager(scratch.path());

        let session_id = manager.create_new_session().unwrap();
        assert!(!session_id.is_empty());

        let listed = manager.list_sessions().unwrap();
        assert_eq!(listed.len(), 1);
        assert_eq!(listed[0].id, session_id);
        assert_eq!(listed[0].message_count, 0);
    }

    #[test]
    fn test_save_and_load_session() {
        let scratch = tempdir().unwrap();
        let manager = open_manager(scratch.path());

        let session_id = manager.create_new_session().unwrap();

        let messages = vec![
            text_message(limit_llm::Role::User, "Hello"),
            text_message(limit_llm::Role::Assistant, "Hi there!"),
        ];

        manager.save_session(&session_id, &messages, 0, 0).unwrap();

        let loaded = manager.load_session(&session_id).unwrap();

        assert_eq!(loaded.len(), messages.len());
        assert_eq!(loaded[0].content, messages[0].content);
        assert_eq!(loaded[1].content, messages[1].content);

        let listed = manager.list_sessions().unwrap();
        assert_eq!(listed[0].message_count, 2);
    }

    #[test]
    fn test_list_sessions() {
        let scratch = tempdir().unwrap();
        let manager = open_manager(scratch.path());

        let older = manager.create_new_session().unwrap();
        // Keep the two timestamps apart so the ordering below is well defined.
        std::thread::sleep(std::time::Duration::from_millis(10));
        let newer = manager.create_new_session().unwrap();

        let listed = manager.list_sessions().unwrap();
        assert_eq!(listed.len(), 2);
        assert_eq!(listed[0].id, newer);
        assert_eq!(listed[1].id, older);
    }

    #[test]
    fn test_get_last_session() {
        let scratch = tempdir().unwrap();
        let manager = open_manager(scratch.path());

        let last = manager.get_last_session().unwrap();
        assert!(last.is_none());

        let _older = manager.create_new_session().unwrap();
        // Keep the two timestamps apart so "last" is well defined.
        std::thread::sleep(std::time::Duration::from_millis(10));
        let newer = manager.create_new_session().unwrap();

        let last = manager.get_last_session().unwrap();
        assert!(last.is_some());
        assert_eq!(last.unwrap().id, newer);
    }

    #[test]
    fn test_session_persistence_across_restarts() {
        let scratch = tempdir().unwrap();

        let first = open_manager(scratch.path());
        let session_id = first.create_new_session().unwrap();

        let messages = vec![text_message(limit_llm::Role::User, "Test message")];
        first.save_session(&session_id, &messages, 0, 0).unwrap();

        drop(first);

        // A second manager over the same paths stands in for a process restart.
        let second = open_manager(scratch.path());

        let loaded = second.load_session(&session_id).unwrap();
        assert_eq!(loaded.len(), 1);
        assert_eq!(loaded[0].content, Some("Test message".to_string()));
    }
}