Skip to main content

limit_cli/
session.rs

1use crate::error::CliError;
2use bincode::{deserialize, serialize};
3use chrono::{DateTime, Utc};
4use limit_llm::Message;
5use rusqlite::{params, Connection};
6use serde::{Deserialize, Serialize};
7use std::fs;
8use std::path::PathBuf;
9use tracing::instrument;
10use uuid::Uuid;
11
/// Format version stamped into every session file written by `save_session`.
///
/// `load_session` refuses files whose stored version is greater than this value.
const CURRENT_VERSION: u32 = 2;
13
14#[derive(Debug, Clone)]
15pub struct SessionInfo {
16    pub id: String,
17    #[allow(dead_code)]
18    pub created_at: DateTime<Utc>,
19    #[allow(dead_code)]
20    pub last_accessed: DateTime<Utc>,
21    pub message_count: usize,
22    pub total_input_tokens: u64,
23    pub total_output_tokens: u64,
24}
25
26#[derive(Debug, Clone, Serialize, Deserialize)]
27struct PersistedState {
28    version: u32,
29    messages: Vec<PersistedMessage>,
30}
31
32#[derive(Debug, Clone, Serialize, Deserialize)]
33struct PersistedMessage {
34    role: PersistedRole,
35    content: Option<String>,
36    tool_calls: Option<Vec<limit_llm::ToolCall>>,
37    tool_call_id: Option<String>,
38}
39
40#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
41enum PersistedRole {
42    User,
43    Assistant,
44    System,
45    Tool,
46}
47
48impl From<PersistedRole> for limit_llm::Role {
49    fn from(role: PersistedRole) -> Self {
50        match role {
51            PersistedRole::User => limit_llm::Role::User,
52            PersistedRole::Assistant => limit_llm::Role::Assistant,
53            PersistedRole::System => limit_llm::Role::System,
54            PersistedRole::Tool => limit_llm::Role::Tool,
55        }
56    }
57}
58
59impl From<limit_llm::Role> for PersistedRole {
60    fn from(role: limit_llm::Role) -> Self {
61        match role {
62            limit_llm::Role::User => PersistedRole::User,
63            limit_llm::Role::Assistant => PersistedRole::Assistant,
64            limit_llm::Role::System => PersistedRole::System,
65            limit_llm::Role::Tool => PersistedRole::Tool,
66        }
67    }
68}
69
70impl From<PersistedMessage> for Message {
71    fn from(msg: PersistedMessage) -> Self {
72        Message {
73            role: msg.role.into(),
74            content: msg.content,
75            tool_calls: msg.tool_calls,
76            tool_call_id: msg.tool_call_id,
77        }
78    }
79}
80
81impl From<Message> for PersistedMessage {
82    fn from(msg: Message) -> Self {
83        PersistedMessage {
84            role: msg.role.into(),
85            content: msg.content,
86            tool_calls: msg.tool_calls,
87            tool_call_id: msg.tool_call_id,
88        }
89    }
90}
91
/// Persists chat sessions.
///
/// Message history lives in one bincode file per session under `sessions_dir`.
/// Per-session metadata (timestamps, message count, token totals) lives in a
/// SQLite database at `db_path`.
pub struct SessionManager {
    /// SQLite database holding the `sessions` metadata table.
    db_path: PathBuf,
    /// Directory containing the `<session_id>.bin` message files.
    sessions_dir: PathBuf,
}
96
97impl SessionManager {
98    pub fn new() -> Result<Self, CliError> {
99        let limit_dir = PathBuf::from(".limit");
100        fs::create_dir_all(&limit_dir).map_err(|e| {
101            CliError::ConfigError(format!("Failed to create .limit directory: {}", e))
102        })?;
103
104        let sessions_dir = limit_dir.join("sessions");
105        fs::create_dir_all(&sessions_dir).map_err(|e| {
106            CliError::ConfigError(format!("Failed to create sessions directory: {}", e))
107        })?;
108
109        let db_path = limit_dir.join("session.db");
110        let session_manager = Self {
111            db_path,
112            sessions_dir,
113        };
114
115        session_manager.init_db()?;
116        Ok(session_manager)
117    }
118
119    pub fn init_db(&self) -> Result<(), CliError> {
120        let conn = Connection::open(&self.db_path)
121            .map_err(|e| CliError::ConfigError(format!("Failed to open database: {}", e)))?;
122
123        conn.execute(
124            "CREATE TABLE IF NOT EXISTS sessions (
125                id TEXT PRIMARY KEY,
126                created_at TEXT NOT NULL,
127                last_accessed TEXT NOT NULL,
128                message_count INTEGER NOT NULL,
129                total_input_tokens INTEGER NOT NULL DEFAULT 0,
130                total_output_tokens INTEGER NOT NULL DEFAULT 0
131            )",
132            [],
133        )
134        .map_err(|e| CliError::ConfigError(format!("Failed to create sessions table: {}", e)))?;
135
136        conn.execute(
137            "CREATE INDEX IF NOT EXISTS idx_last_accessed ON sessions(last_accessed DESC)",
138            [],
139        )
140        .map_err(|e| CliError::ConfigError(format!("Failed to create index: {}", e)))?;
141
142        Ok(())
143    }
144
145    fn get_connection(&self) -> Result<Connection, CliError> {
146        Connection::open(&self.db_path)
147            .map_err(|e| CliError::ConfigError(format!("Failed to open database: {}", e)))
148    }
149
150    pub fn create_new_session(&self) -> Result<String, CliError> {
151        let session_id = Uuid::new_v4().to_string();
152        let now = Utc::now().to_rfc3339();
153
154        let conn = self.get_connection()?;
155        conn.execute(
156            "INSERT INTO sessions (id, created_at, last_accessed, message_count, total_input_tokens, total_output_tokens) VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
157            params![&session_id, &now, &now, 0, 0, 0],
158        )
159        .map_err(|e| CliError::ConfigError(format!("Failed to create session: {}", e)))?;
160
161        Ok(session_id)
162    }
163
164    #[instrument(skip(self, messages))]
165    pub fn save_session(
166        &self,
167        session_id: &str,
168        messages: &[Message],
169        total_input_tokens: u64,
170        total_output_tokens: u64,
171    ) -> Result<(), CliError> {
172        let file_path = self.sessions_dir.join(format!("{}.bin", session_id));
173
174        fs::create_dir_all(&self.sessions_dir).map_err(|e| {
175            CliError::ConfigError(format!("Failed to create sessions directory: {}", e))
176        })?;
177
178        let persisted_messages: Vec<PersistedMessage> =
179            messages.iter().cloned().map(|m| m.into()).collect();
180
181        let state = PersistedState {
182            version: CURRENT_VERSION,
183            messages: persisted_messages,
184        };
185
186        let serialized = serialize(&state)
187            .map_err(|e| CliError::ConfigError(format!("Failed to serialize messages: {}", e)))?;
188
189        fs::write(&file_path, serialized)
190            .map_err(|e| CliError::ConfigError(format!("Failed to write session file: {}", e)))?;
191
192        let now = Utc::now().to_rfc3339();
193        let conn = self.get_connection()?;
194        conn.execute(
195            "UPDATE sessions SET last_accessed = ?1, message_count = ?2, total_input_tokens = ?3, total_output_tokens = ?4 WHERE id = ?5",
196            params![&now, messages.len() as i64, total_input_tokens as i64, total_output_tokens as i64, session_id],
197        )
198        .map_err(|e| CliError::ConfigError(format!("Failed to update session metadata: {}", e)))?;
199
200        Ok(())
201    }
202
203    #[instrument(skip(self))]
204    pub fn load_session(&self, session_id: &str) -> Result<Vec<Message>, CliError> {
205        let file_path = self.sessions_dir.join(format!("{}.bin", session_id));
206
207        let data = fs::read(&file_path)
208            .map_err(|e| CliError::ConfigError(format!("Failed to read session file: {}", e)))?;
209
210        let state: PersistedState = deserialize(&data)
211            .map_err(|e| CliError::ConfigError(format!("Failed to deserialize messages: {}", e)))?;
212
213        // Handle version migration if needed
214        if state.version > CURRENT_VERSION {
215            return Err(CliError::ConfigError(format!(
216                "Version mismatch: expected {}, found {}",
217                CURRENT_VERSION, state.version
218            )));
219        }
220
221        let now = Utc::now().to_rfc3339();
222        let conn = self.get_connection()?;
223        conn.execute(
224            "UPDATE sessions SET last_accessed = ?1 WHERE id = ?2",
225            params![&now, session_id],
226        )
227        .map_err(|e| CliError::ConfigError(format!("Failed to update last_accessed: {}", e)))?;
228
229        // Filter out system messages - z.ai API doesn't support system role
230        let messages: Vec<Message> = state
231            .messages
232            .into_iter()
233            .map(Message::from)
234            .filter(|m| m.role != limit_llm::Role::System)
235            .collect();
236
237        Ok(messages)
238    }
239
240    pub fn list_sessions(&self) -> Result<Vec<SessionInfo>, CliError> {
241        let conn = self.get_connection()?;
242
243        let mut stmt = conn
244            .prepare("SELECT id, created_at, last_accessed, message_count, total_input_tokens, total_output_tokens FROM sessions ORDER BY last_accessed DESC")
245            .map_err(|e| CliError::ConfigError(format!("Failed to prepare query: {}", e)))?;
246
247        let session_iter = stmt
248            .query_map([], |row| {
249                Ok(SessionInfo {
250                    id: row.get(0)?,
251                    created_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(1)?)
252                        .unwrap()
253                        .with_timezone(&Utc),
254                    last_accessed: DateTime::parse_from_rfc3339(&row.get::<_, String>(2)?)
255                        .unwrap()
256                        .with_timezone(&Utc),
257                    message_count: row.get::<_, i64>(3)? as usize,
258                    total_input_tokens: row.get::<_, i64>(4)? as u64,
259                    total_output_tokens: row.get::<_, i64>(5)? as u64,
260                })
261            })
262            .map_err(|e| CliError::ConfigError(format!("Failed to query sessions: {}", e)))?;
263
264        let mut sessions = Vec::new();
265        for session in session_iter {
266            sessions.push(
267                session.map_err(|e| {
268                    CliError::ConfigError(format!("Failed to parse session: {}", e))
269                })?,
270            );
271        }
272
273        Ok(sessions)
274    }
275
276    pub fn get_last_session(&self) -> Result<Option<SessionInfo>, CliError> {
277        let conn = self.get_connection()?;
278
279        let mut stmt = conn
280            .prepare("SELECT id, created_at, last_accessed, message_count, total_input_tokens, total_output_tokens FROM sessions ORDER BY last_accessed DESC LIMIT 1")
281            .map_err(|e| CliError::ConfigError(format!("Failed to prepare query: {}", e)))?;
282
283        let mut session_iter = stmt
284            .query_map([], |row| {
285                Ok(SessionInfo {
286                    id: row.get(0)?,
287                    created_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(1)?)
288                        .unwrap()
289                        .with_timezone(&Utc),
290                    last_accessed: DateTime::parse_from_rfc3339(&row.get::<_, String>(2)?)
291                        .unwrap()
292                        .with_timezone(&Utc),
293                    message_count: row.get::<_, i64>(3)? as usize,
294                    total_input_tokens: row.get::<_, i64>(4)? as u64,
295                    total_output_tokens: row.get::<_, i64>(5)? as u64,
296                })
297            })
298            .map_err(|e| CliError::ConfigError(format!("Failed to query last session: {}", e)))?;
299
300        match session_iter.next() {
301            Some(session) => Ok(Some(session.map_err(|e| {
302                CliError::ConfigError(format!("Failed to parse session: {}", e))
303            })?)),
304            None => Ok(None),
305        }
306    }
307
308    /// Update token counts for a session
309    #[allow(dead_code)]
310    pub fn update_session_tokens(
311        &self,
312        session_id: &str,
313        input_tokens: u64,
314        output_tokens: u64,
315    ) -> Result<(), CliError> {
316        let conn = self.get_connection()?;
317        conn.execute(
318            "UPDATE sessions SET total_input_tokens = total_input_tokens + ?1, total_output_tokens = total_output_tokens + ?2 WHERE id = ?3",
319            params![input_tokens as i64, output_tokens as i64, session_id],
320        )
321        .map_err(|e| CliError::ConfigError(format!("Failed to update session tokens: {}", e)))?;
322        Ok(())
323    }
324}
325
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Builds a `SessionManager` rooted at `root` and initialises its schema.
    ///
    /// Creates `root/sessions` and places the database at `root/session.db`.
    fn manager_at(root: &std::path::Path) -> SessionManager {
        let sessions_dir = root.join("sessions");
        fs::create_dir_all(&sessions_dir).unwrap();

        let manager = SessionManager {
            db_path: root.join("session.db"),
            sessions_dir,
        };
        manager.init_db().unwrap();
        manager
    }

    /// Plain-text message with no tool metadata.
    fn text_message(role: limit_llm::Role, text: &str) -> Message {
        Message {
            role,
            content: Some(text.to_string()),
            tool_calls: None,
            tool_call_id: None,
        }
    }

    #[test]
    fn test_session_manager_new() {
        let tmp = tempdir().unwrap();
        let manager = manager_at(tmp.path());

        assert!(manager.db_path.exists());
    }

    #[test]
    fn test_create_new_session() {
        let tmp = tempdir().unwrap();
        let manager = manager_at(tmp.path());

        let new_id = manager.create_new_session().unwrap();
        assert!(!new_id.is_empty());

        let listed = manager.list_sessions().unwrap();
        assert_eq!(listed.len(), 1);
        let only = &listed[0];
        assert_eq!(only.id, new_id);
        assert_eq!(only.message_count, 0);
    }

    #[test]
    fn test_save_and_load_session() {
        let tmp = tempdir().unwrap();
        let manager = manager_at(tmp.path());
        let id = manager.create_new_session().unwrap();

        let history = vec![
            text_message(limit_llm::Role::User, "Hello"),
            text_message(limit_llm::Role::Assistant, "Hi there!"),
        ];
        manager.save_session(&id, &history, 0, 0).unwrap();

        let restored = manager.load_session(&id).unwrap();
        assert_eq!(restored.len(), history.len());
        for (got, want) in restored.iter().zip(&history) {
            assert_eq!(got.content, want.content);
        }

        let listed = manager.list_sessions().unwrap();
        assert_eq!(listed[0].message_count, 2);
    }

    #[test]
    fn test_list_sessions() {
        let tmp = tempdir().unwrap();
        let manager = manager_at(tmp.path());

        let older = manager.create_new_session().unwrap();
        // Give the two sessions distinct last_accessed timestamps so ordering is stable.
        std::thread::sleep(std::time::Duration::from_millis(10));
        let newer = manager.create_new_session().unwrap();

        let ids: Vec<String> = manager
            .list_sessions()
            .unwrap()
            .into_iter()
            .map(|info| info.id)
            .collect();
        assert_eq!(ids, vec![newer, older]);
    }

    #[test]
    fn test_get_last_session() {
        let tmp = tempdir().unwrap();
        let manager = manager_at(tmp.path());

        assert!(manager.get_last_session().unwrap().is_none());

        let _older = manager.create_new_session().unwrap();
        std::thread::sleep(std::time::Duration::from_millis(10));
        let newest = manager.create_new_session().unwrap();

        let last = manager.get_last_session().unwrap();
        assert!(last.is_some());
        assert_eq!(last.unwrap().id, newest);
    }

    #[test]
    fn test_session_persistence_across_restarts() {
        let tmp = tempdir().unwrap();

        // The first manager goes out of scope at the end of this block,
        // standing in for a process restart.
        let saved_id = {
            let first = manager_at(tmp.path());
            let id = first.create_new_session().unwrap();
            let history = vec![text_message(limit_llm::Role::User, "Test message")];
            first.save_session(&id, &history, 0, 0).unwrap();
            id
        };

        let second = manager_at(tmp.path());
        let restored = second.load_session(&saved_id).unwrap();
        assert_eq!(restored.len(), 1);
        assert_eq!(restored[0].content, Some("Test message".to_string()));
    }
}