Skip to main content

limit_cli/
session.rs

1use crate::error::CliError;
2use bincode::{deserialize, serialize};
3use chrono::{DateTime, Utc};
4use limit_llm::Message;
5use rusqlite::{params, Connection};
6use serde::{Deserialize, Serialize};
7use std::fs;
8use std::path::PathBuf;
9use tracing::instrument;
10use uuid::Uuid;
11
12const CURRENT_VERSION: u32 = 2;
13
14#[derive(Debug, Clone)]
15pub struct SessionInfo {
16    pub id: String,
17    #[allow(dead_code)]
18    pub created_at: DateTime<Utc>,
19    #[allow(dead_code)]
20    pub last_accessed: DateTime<Utc>,
21    pub message_count: usize,
22    pub total_input_tokens: u64,
23    pub total_output_tokens: u64,
24}
25
26#[derive(Debug, Clone, Serialize, Deserialize)]
27struct PersistedState {
28    version: u32,
29    messages: Vec<PersistedMessage>,
30}
31
32#[derive(Debug, Clone, Serialize, Deserialize)]
33struct PersistedMessage {
34    role: PersistedRole,
35    content: Option<String>,
36    tool_calls: Option<Vec<limit_llm::ToolCall>>,
37    tool_call_id: Option<String>,
38}
39
40#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
41enum PersistedRole {
42    User,
43    Assistant,
44    System,
45    Tool,
46}
47
48impl From<PersistedRole> for limit_llm::Role {
49    fn from(role: PersistedRole) -> Self {
50        match role {
51            PersistedRole::User => limit_llm::Role::User,
52            PersistedRole::Assistant => limit_llm::Role::Assistant,
53            PersistedRole::System => limit_llm::Role::System,
54            PersistedRole::Tool => limit_llm::Role::Tool,
55        }
56    }
57}
58
59impl From<limit_llm::Role> for PersistedRole {
60    fn from(role: limit_llm::Role) -> Self {
61        match role {
62            limit_llm::Role::User => PersistedRole::User,
63            limit_llm::Role::Assistant => PersistedRole::Assistant,
64            limit_llm::Role::System => PersistedRole::System,
65            limit_llm::Role::Tool => PersistedRole::Tool,
66        }
67    }
68}
69
70impl From<PersistedMessage> for Message {
71    fn from(msg: PersistedMessage) -> Self {
72        Message {
73            role: msg.role.into(),
74            content: msg.content,
75            tool_calls: msg.tool_calls,
76            tool_call_id: msg.tool_call_id,
77        }
78    }
79}
80
81impl From<Message> for PersistedMessage {
82    fn from(msg: Message) -> Self {
83        PersistedMessage {
84            role: msg.role.into(),
85            content: msg.content,
86            tool_calls: msg.tool_calls,
87            tool_call_id: msg.tool_call_id,
88        }
89    }
90}
91
92pub struct SessionManager {
93    db_path: PathBuf,
94    sessions_dir: PathBuf,
95}
96
97impl SessionManager {
98    pub fn new() -> Result<Self, CliError> {
99        // Centralize all session data in ~/.limit/
100        let home_dir = dirs::home_dir()
101            .ok_or_else(|| CliError::ConfigError("Failed to get home directory".to_string()))?;
102        let limit_dir = home_dir.join(".limit");
103        fs::create_dir_all(&limit_dir).map_err(|e| {
104            CliError::ConfigError(format!("Failed to create .limit directory: {}", e))
105        })?;
106
107        let sessions_dir = limit_dir.join("sessions");
108        fs::create_dir_all(&sessions_dir).map_err(|e| {
109            CliError::ConfigError(format!("Failed to create sessions directory: {}", e))
110        })?;
111
112        let db_path = limit_dir.join("session.db");
113        let session_manager = Self {
114            db_path,
115            sessions_dir,
116        };
117
118        session_manager.init_db()?;
119        Ok(session_manager)
120    }
121
122    pub fn init_db(&self) -> Result<(), CliError> {
123        let conn = Connection::open(&self.db_path)
124            .map_err(|e| CliError::ConfigError(format!("Failed to open database: {}", e)))?;
125
126        conn.execute(
127            "CREATE TABLE IF NOT EXISTS sessions (
128                id TEXT PRIMARY KEY,
129                created_at TEXT NOT NULL,
130                last_accessed TEXT NOT NULL,
131                message_count INTEGER NOT NULL,
132                total_input_tokens INTEGER NOT NULL DEFAULT 0,
133                total_output_tokens INTEGER NOT NULL DEFAULT 0
134            )",
135            [],
136        )
137        .map_err(|e| CliError::ConfigError(format!("Failed to create sessions table: {}", e)))?;
138
139        conn.execute(
140            "CREATE INDEX IF NOT EXISTS idx_last_accessed ON sessions(last_accessed DESC)",
141            [],
142        )
143        .map_err(|e| CliError::ConfigError(format!("Failed to create index: {}", e)))?;
144
145        Ok(())
146    }
147
148    fn get_connection(&self) -> Result<Connection, CliError> {
149        Connection::open(&self.db_path)
150            .map_err(|e| CliError::ConfigError(format!("Failed to open database: {}", e)))
151    }
152
153    pub fn create_new_session(&self) -> Result<String, CliError> {
154        let session_id = Uuid::new_v4().to_string();
155        let now = Utc::now().to_rfc3339();
156
157        let conn = self.get_connection()?;
158        conn.execute(
159            "INSERT INTO sessions (id, created_at, last_accessed, message_count, total_input_tokens, total_output_tokens) VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
160            params![&session_id, &now, &now, 0, 0, 0],
161        )
162        .map_err(|e| CliError::ConfigError(format!("Failed to create session: {}", e)))?;
163
164        Ok(session_id)
165    }
166
167    #[instrument(skip(self, messages))]
168    pub fn save_session(
169        &self,
170        session_id: &str,
171        messages: &[Message],
172        total_input_tokens: u64,
173        total_output_tokens: u64,
174    ) -> Result<(), CliError> {
175        let file_path = self.sessions_dir.join(format!("{}.bin", session_id));
176
177        fs::create_dir_all(&self.sessions_dir).map_err(|e| {
178            CliError::ConfigError(format!("Failed to create sessions directory: {}", e))
179        })?;
180
181        let persisted_messages: Vec<PersistedMessage> =
182            messages.iter().cloned().map(|m| m.into()).collect();
183
184        let state = PersistedState {
185            version: CURRENT_VERSION,
186            messages: persisted_messages,
187        };
188
189        let serialized = serialize(&state)
190            .map_err(|e| CliError::ConfigError(format!("Failed to serialize messages: {}", e)))?;
191
192        fs::write(&file_path, serialized)
193            .map_err(|e| CliError::ConfigError(format!("Failed to write session file: {}", e)))?;
194
195        let now = Utc::now().to_rfc3339();
196        let conn = self.get_connection()?;
197        conn.execute(
198            "UPDATE sessions SET last_accessed = ?1, message_count = ?2, total_input_tokens = ?3, total_output_tokens = ?4 WHERE id = ?5",
199            params![&now, messages.len() as i64, total_input_tokens as i64, total_output_tokens as i64, session_id],
200        )
201        .map_err(|e| CliError::ConfigError(format!("Failed to update session metadata: {}", e)))?;
202
203        Ok(())
204    }
205
206    #[instrument(skip(self))]
207    pub fn load_session(&self, session_id: &str) -> Result<Vec<Message>, CliError> {
208        let file_path = self.sessions_dir.join(format!("{}.bin", session_id));
209
210        let data = fs::read(&file_path)
211            .map_err(|e| CliError::ConfigError(format!("Failed to read session file: {}", e)))?;
212
213        let state: PersistedState = deserialize(&data)
214            .map_err(|e| CliError::ConfigError(format!("Failed to deserialize messages: {}", e)))?;
215
216        // Handle version migration if needed
217        if state.version > CURRENT_VERSION {
218            return Err(CliError::ConfigError(format!(
219                "Version mismatch: expected {}, found {}",
220                CURRENT_VERSION, state.version
221            )));
222        }
223
224        let now = Utc::now().to_rfc3339();
225        let conn = self.get_connection()?;
226        conn.execute(
227            "UPDATE sessions SET last_accessed = ?1 WHERE id = ?2",
228            params![&now, session_id],
229        )
230        .map_err(|e| CliError::ConfigError(format!("Failed to update last_accessed: {}", e)))?;
231
232        // Filter out system messages - z.ai API doesn't support system role
233        let messages: Vec<Message> = state
234            .messages
235            .into_iter()
236            .map(Message::from)
237            .filter(|m| m.role != limit_llm::Role::System)
238            .collect();
239
240        Ok(messages)
241    }
242
243    pub fn list_sessions(&self) -> Result<Vec<SessionInfo>, CliError> {
244        let conn = self.get_connection()?;
245
246        let mut stmt = conn
247            .prepare("SELECT id, created_at, last_accessed, message_count, total_input_tokens, total_output_tokens FROM sessions ORDER BY last_accessed DESC")
248            .map_err(|e| CliError::ConfigError(format!("Failed to prepare query: {}", e)))?;
249
250        let session_iter = stmt
251            .query_map([], |row| {
252                Ok(SessionInfo {
253                    id: row.get(0)?,
254                    created_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(1)?)
255                        .unwrap()
256                        .with_timezone(&Utc),
257                    last_accessed: DateTime::parse_from_rfc3339(&row.get::<_, String>(2)?)
258                        .unwrap()
259                        .with_timezone(&Utc),
260                    message_count: row.get::<_, i64>(3)? as usize,
261                    total_input_tokens: row.get::<_, i64>(4)? as u64,
262                    total_output_tokens: row.get::<_, i64>(5)? as u64,
263                })
264            })
265            .map_err(|e| CliError::ConfigError(format!("Failed to query sessions: {}", e)))?;
266
267        let mut sessions = Vec::new();
268        for session in session_iter {
269            sessions.push(
270                session.map_err(|e| {
271                    CliError::ConfigError(format!("Failed to parse session: {}", e))
272                })?,
273            );
274        }
275
276        Ok(sessions)
277    }
278
279    pub fn get_last_session(&self) -> Result<Option<SessionInfo>, CliError> {
280        let conn = self.get_connection()?;
281
282        let mut stmt = conn
283            .prepare("SELECT id, created_at, last_accessed, message_count, total_input_tokens, total_output_tokens FROM sessions ORDER BY last_accessed DESC LIMIT 1")
284            .map_err(|e| CliError::ConfigError(format!("Failed to prepare query: {}", e)))?;
285
286        let mut session_iter = stmt
287            .query_map([], |row| {
288                Ok(SessionInfo {
289                    id: row.get(0)?,
290                    created_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(1)?)
291                        .unwrap()
292                        .with_timezone(&Utc),
293                    last_accessed: DateTime::parse_from_rfc3339(&row.get::<_, String>(2)?)
294                        .unwrap()
295                        .with_timezone(&Utc),
296                    message_count: row.get::<_, i64>(3)? as usize,
297                    total_input_tokens: row.get::<_, i64>(4)? as u64,
298                    total_output_tokens: row.get::<_, i64>(5)? as u64,
299                })
300            })
301            .map_err(|e| CliError::ConfigError(format!("Failed to query last session: {}", e)))?;
302
303        match session_iter.next() {
304            Some(session) => Ok(Some(session.map_err(|e| {
305                CliError::ConfigError(format!("Failed to parse session: {}", e))
306            })?)),
307            None => Ok(None),
308        }
309    }
310
311    /// Update token counts for a session
312    #[allow(dead_code)]
313    pub fn update_session_tokens(
314        &self,
315        session_id: &str,
316        input_tokens: u64,
317        output_tokens: u64,
318    ) -> Result<(), CliError> {
319        let conn = self.get_connection()?;
320        conn.execute(
321            "UPDATE sessions SET total_input_tokens = total_input_tokens + ?1, total_output_tokens = total_output_tokens + ?2 WHERE id = ?3",
322            params![input_tokens as i64, output_tokens as i64, session_id],
323        )
324        .map_err(|e| CliError::ConfigError(format!("Failed to update session tokens: {}", e)))?;
325        Ok(())
326    }
327}
328
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::Path;
    use std::thread::sleep;
    use std::time::Duration;
    use tempfile::tempdir;

    /// Builds a `SessionManager` whose database and session files live under
    /// `root`, with the `sessions` directory created and the schema initialised.
    fn manager_at(root: &Path) -> SessionManager {
        let sessions_dir = root.join("sessions");
        fs::create_dir_all(&sessions_dir).unwrap();

        let manager = SessionManager {
            db_path: root.join("session.db"),
            sessions_dir,
        };
        manager.init_db().unwrap();
        manager
    }

    /// Plain text message with no tool-call data attached.
    fn text_message(role: limit_llm::Role, text: &str) -> Message {
        Message {
            role,
            content: Some(text.to_string()),
            tool_calls: None,
            tool_call_id: None,
        }
    }

    #[test]
    fn test_session_manager_new() {
        let dir = tempdir().unwrap();
        let manager = manager_at(dir.path());

        assert!(manager.db_path.exists());
    }

    #[test]
    fn test_create_new_session() {
        let dir = tempdir().unwrap();
        let manager = manager_at(dir.path());

        let id = manager.create_new_session().unwrap();
        assert!(!id.is_empty());

        let listed = manager.list_sessions().unwrap();
        assert_eq!(listed.len(), 1);
        assert_eq!(listed[0].id, id);
        assert_eq!(listed[0].message_count, 0);
    }

    #[test]
    fn test_save_and_load_session() {
        let dir = tempdir().unwrap();
        let manager = manager_at(dir.path());
        let id = manager.create_new_session().unwrap();

        let saved = vec![
            text_message(limit_llm::Role::User, "Hello"),
            text_message(limit_llm::Role::Assistant, "Hi there!"),
        ];
        manager.save_session(&id, &saved, 0, 0).unwrap();

        let restored = manager.load_session(&id).unwrap();
        assert_eq!(restored.len(), saved.len());
        for (got, want) in restored.iter().zip(saved.iter()) {
            assert_eq!(got.content, want.content);
        }

        let listed = manager.list_sessions().unwrap();
        assert_eq!(listed[0].message_count, 2);
    }

    #[test]
    fn test_list_sessions() {
        let dir = tempdir().unwrap();
        let manager = manager_at(dir.path());

        let older = manager.create_new_session().unwrap();
        sleep(Duration::from_millis(10));
        let newer = manager.create_new_session().unwrap();

        let ids: Vec<String> = manager
            .list_sessions()
            .unwrap()
            .into_iter()
            .map(|info| info.id)
            .collect();
        assert_eq!(ids, vec![newer, older]);
    }

    #[test]
    fn test_get_last_session() {
        let dir = tempdir().unwrap();
        let manager = manager_at(dir.path());

        assert!(manager.get_last_session().unwrap().is_none());

        let _older = manager.create_new_session().unwrap();
        sleep(Duration::from_millis(10));
        let newer = manager.create_new_session().unwrap();

        let last = manager
            .get_last_session()
            .unwrap()
            .expect("a session was just created");
        assert_eq!(last.id, newer);
    }

    #[test]
    fn test_session_persistence_across_restarts() {
        let dir = tempdir().unwrap();

        let id = {
            let first = manager_at(dir.path());
            let id = first.create_new_session().unwrap();
            let saved = vec![text_message(limit_llm::Role::User, "Test message")];
            first.save_session(&id, &saved, 0, 0).unwrap();
            id
        };

        let second = manager_at(dir.path());
        let restored = second.load_session(&id).unwrap();
        assert_eq!(restored.len(), 1);
        assert_eq!(restored[0].content, Some("Test message".to_string()));
    }
}