Skip to main content

ravenclaws/
persistence.rs

1//! # Conversation Persistence (SQLite backend)
2//!
3//! Provides SQLite-backed storage for conversation history so agents survive
4//! pod restarts without losing context. Supports configurable retention policies
5//! (time-based, count-based, token-budget-based).
6//!
7//! ## Architecture
8//!
9//! - `ConversationStore` — manages a SQLite database with sessions and messages tables
10//! - Sessions are identified by a session ID (UUID or user-provided)
11//! - Messages are stored with role, content, timestamp, and token count
12//! - Retention policies are applied on read (not on write) for simplicity
13//!
14//! ## Usage
15//!
16//! ```rust,no_run
17//! use ravenclaws::persistence::ConversationStore;
18//!
19//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
20//! let store = ConversationStore::open(":memory:")?;
21//! store.create_session("session-1", "You are a helpful assistant.")?;
22//! store.add_message("session-1", "user", "Hello!", None)?;
23//! let history = store.get_history("session-1", None)?;
24//! # Ok(())
25//! # }
26//! ```
27
28use rusqlite::{params, Connection, Result as SqlResult};
29use serde::{Deserialize, Serialize};
30use std::path::Path;
31use std::time::{Duration, SystemTime, UNIX_EPOCH};
32
33/// A stored conversation message
34#[derive(Debug, Clone, Serialize, Deserialize)]
35pub struct StoredMessage {
36    /// Message role (system, user, assistant, tool)
37    pub role: String,
38    /// Message content
39    pub content: String,
40    /// Unix timestamp when the message was created
41    pub created_at: u64,
42    /// Optional token count for budget tracking
43    pub token_count: Option<u64>,
44}
45
46/// A stored conversation session
47#[derive(Debug, Clone, Serialize, Deserialize)]
48pub struct StoredSession {
49    /// Unique session identifier
50    pub session_id: String,
51    /// System prompt used for this session
52    pub system_prompt: String,
53    /// Unix timestamp when the session was created
54    pub created_at: u64,
55    /// Unix timestamp of the last activity
56    pub updated_at: u64,
57    /// Total token count across all messages
58    pub total_tokens: u64,
59    /// Number of messages in the session
60    pub message_count: u64,
61}
62
63/// Retention policy for pruning old conversations
64#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
65pub enum RetentionPolicy {
66    /// Keep messages newer than this duration
67    TimeBased(Duration),
68    /// Keep at most this many messages (oldest removed first)
69    CountBased(usize),
70    /// Keep messages until total tokens exceed this budget
71    TokenBudget(u64),
72    /// No retention limit
73    Unlimited,
74}
75
76impl RetentionPolicy {
77    /// Apply this policy to a list of messages, returning the pruned list
78    pub fn apply(&self, messages: &mut Vec<StoredMessage>) {
79        match self {
80            RetentionPolicy::TimeBased(duration) => {
81                let cutoff = SystemTime::now()
82                    .duration_since(UNIX_EPOCH)
83                    .unwrap_or_default()
84                    .as_secs()
85                    - duration.as_secs();
86                messages.retain(|m| m.created_at >= cutoff);
87            }
88            RetentionPolicy::CountBased(max) => {
89                if messages.len() > *max {
90                    // Keep the most recent `max` messages
91                    let keep = messages.split_off(messages.len() - max);
92                    *messages = keep;
93                }
94            }
95            RetentionPolicy::TokenBudget(budget) => {
96                let mut total: u64 = 0;
97                // Keep messages from newest to oldest until budget is exceeded
98                messages.reverse();
99                messages.retain(|m| {
100                    let tokens = m.token_count.unwrap_or(0);
101                    if total + tokens <= *budget {
102                        total += tokens;
103                        true
104                    } else {
105                        false
106                    }
107                });
108                messages.reverse();
109            }
110            RetentionPolicy::Unlimited => {
111                // No pruning
112            }
113        }
114    }
115}
116
117/// SQLite-backed conversation store
118#[derive(Debug)]
119pub struct ConversationStore {
120    conn: Connection,
121}
122
123impl ConversationStore {
124    /// Open or create a SQLite database at the given path.
125    /// Use `:memory:` for an in-memory database (useful for testing).
126    pub fn open<P: AsRef<Path>>(path: P) -> SqlResult<Self> {
127        let conn = Connection::open(path)?;
128        let store = Self { conn };
129        store.initialize_tables()?;
130        Ok(store)
131    }
132
133    /// Initialize the database schema
134    fn initialize_tables(&self) -> SqlResult<()> {
135        self.conn.execute_batch(
136            "
137            CREATE TABLE IF NOT EXISTS sessions (
138                session_id   TEXT PRIMARY KEY,
139                system_prompt TEXT NOT NULL DEFAULT '',
140                created_at   INTEGER NOT NULL,
141                updated_at   INTEGER NOT NULL,
142                total_tokens INTEGER NOT NULL DEFAULT 0,
143                message_count INTEGER NOT NULL DEFAULT 0
144            );
145
146            CREATE TABLE IF NOT EXISTS messages (
147                id          INTEGER PRIMARY KEY AUTOINCREMENT,
148                session_id  TEXT NOT NULL,
149                role        TEXT NOT NULL,
150                content     TEXT NOT NULL,
151                created_at  INTEGER NOT NULL,
152                token_count INTEGER DEFAULT NULL,
153                FOREIGN KEY (session_id) REFERENCES sessions(session_id) ON DELETE CASCADE
154            );
155
156            CREATE INDEX IF NOT EXISTS idx_messages_session_id ON messages(session_id);
157            CREATE INDEX IF NOT EXISTS idx_messages_created_at ON messages(created_at);
158            CREATE INDEX IF NOT EXISTS idx_sessions_updated_at ON sessions(updated_at);
159            ",
160        )?;
161        Ok(())
162    }
163
164    /// Create a new conversation session
165    pub fn create_session(&self, session_id: &str, system_prompt: &str) -> SqlResult<()> {
166        let now = SystemTime::now()
167            .duration_since(UNIX_EPOCH)
168            .unwrap_or_default()
169            .as_secs();
170        self.conn.execute(
171            "INSERT OR IGNORE INTO sessions (session_id, system_prompt, created_at, updated_at)
172             VALUES (?1, ?2, ?3, ?3)",
173            params![session_id, system_prompt, now],
174        )?;
175        Ok(())
176    }
177
178    /// Delete a session and all its messages
179    pub fn delete_session(&self, session_id: &str) -> SqlResult<()> {
180        self.conn.execute(
181            "DELETE FROM messages WHERE session_id = ?1",
182            params![session_id],
183        )?;
184        self.conn.execute(
185            "DELETE FROM sessions WHERE session_id = ?1",
186            params![session_id],
187        )?;
188        Ok(())
189    }
190
191    /// List all sessions, ordered by most recently updated first
192    pub fn list_sessions(&self) -> SqlResult<Vec<StoredSession>> {
193        let mut stmt = self.conn.prepare(
194            "SELECT session_id, system_prompt, created_at, updated_at, total_tokens, message_count
195             FROM sessions ORDER BY updated_at DESC",
196        )?;
197        let sessions = stmt
198            .query_map([], |row| {
199                Ok(StoredSession {
200                    session_id: row.get(0)?,
201                    system_prompt: row.get(1)?,
202                    created_at: row.get(2)?,
203                    updated_at: row.get(3)?,
204                    total_tokens: row.get(4)?,
205                    message_count: row.get(5)?,
206                })
207            })?
208            .collect::<SqlResult<Vec<_>>>()?;
209        Ok(sessions)
210    }
211
212    /// Add a message to a session
213    pub fn add_message(
214        &self,
215        session_id: &str,
216        role: &str,
217        content: &str,
218        token_count: Option<u64>,
219    ) -> SqlResult<()> {
220        let now = SystemTime::now()
221            .duration_since(UNIX_EPOCH)
222            .unwrap_or_default()
223            .as_secs();
224
225        // Insert the message
226        self.conn.execute(
227            "INSERT INTO messages (session_id, role, content, created_at, token_count)
228             VALUES (?1, ?2, ?3, ?4, ?5)",
229            params![session_id, role, content, now, token_count],
230        )?;
231
232        // Update session metadata
233        self.conn.execute(
234            "UPDATE sessions SET
235                updated_at = ?1,
236                total_tokens = total_tokens + ?2,
237                message_count = message_count + 1
238             WHERE session_id = ?3",
239            params![now, token_count.unwrap_or(0), session_id],
240        )?;
241
242        Ok(())
243    }
244
245    /// Get message history for a session, optionally applying a retention policy
246    pub fn get_history(
247        &self,
248        session_id: &str,
249        policy: Option<RetentionPolicy>,
250    ) -> SqlResult<Vec<StoredMessage>> {
251        let mut stmt = self.conn.prepare(
252            "SELECT role, content, created_at, token_count
253             FROM messages WHERE session_id = ?1
254             ORDER BY created_at ASC",
255        )?;
256
257        let mut messages: Vec<StoredMessage> = stmt
258            .query_map(params![session_id], |row| {
259                Ok(StoredMessage {
260                    role: row.get(0)?,
261                    content: row.get(1)?,
262                    created_at: row.get(2)?,
263                    token_count: row.get(3)?,
264                })
265            })?
266            .collect::<SqlResult<Vec<_>>>()?;
267
268        // Apply retention policy if specified
269        if let Some(policy) = policy {
270            policy.apply(&mut messages);
271        }
272
273        Ok(messages)
274    }
275
276    /// Get the number of messages in a session
277    pub fn message_count(&self, session_id: &str) -> SqlResult<u64> {
278        let count: u64 = self
279            .conn
280            .query_row(
281                "SELECT COUNT(*) FROM messages WHERE session_id = ?1",
282                params![session_id],
283                |row| row.get(0),
284            )
285            .unwrap_or(0);
286        Ok(count)
287    }
288
289    /// Get the total token count for a session
290    pub fn total_tokens(&self, session_id: &str) -> SqlResult<u64> {
291        let total: u64 = self
292            .conn
293            .query_row(
294                "SELECT COALESCE(SUM(token_count), 0) FROM messages WHERE session_id = ?1",
295                params![session_id],
296                |row| row.get(0),
297            )
298            .unwrap_or(0);
299        Ok(total)
300    }
301
302    /// Prune old sessions based on a retention policy applied to session age
303    pub fn prune_sessions(&self, max_age: Duration) -> SqlResult<u64> {
304        let cutoff = SystemTime::now()
305            .duration_since(UNIX_EPOCH)
306            .unwrap_or_default()
307            .as_secs()
308            - max_age.as_secs();
309
310        // Find sessions to delete
311        let sessions: Vec<String> = self
312            .conn
313            .prepare("SELECT session_id FROM sessions WHERE updated_at < ?1")?
314            .query_map(params![cutoff], |row| row.get(0))?
315            .collect::<SqlResult<Vec<_>>>()?;
316
317        let count = sessions.len() as u64;
318        for session_id in &sessions {
319            self.delete_session(session_id)?;
320        }
321
322        Ok(count)
323    }
324
325    /// Convert stored messages to `ChatMessage` format for the LLM
326    pub fn to_chat_messages(
327        &self,
328        session_id: &str,
329        policy: Option<RetentionPolicy>,
330    ) -> SqlResult<Vec<crate::llm::ChatMessage>> {
331        let stored = self.get_history(session_id, policy)?;
332        Ok(stored
333            .into_iter()
334            .map(|m| crate::llm::ChatMessage {
335                role: m.role,
336                content: m.content,
337                content_parts: None,
338            })
339            .collect())
340    }
341
342    /// Import messages from a `ConversationMemory` into a session
343    pub fn import_memory(
344        &self,
345        session_id: &str,
346        memory: &crate::agent::ConversationMemory,
347        system_prompt: &str,
348    ) -> SqlResult<()> {
349        self.create_session(session_id, system_prompt)?;
350
351        for msg in memory.history() {
352            self.add_message(session_id, &msg.role, &msg.content, None)?;
353        }
354
355        Ok(())
356    }
357}
358
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    fn create_test_store() -> ConversationStore {
        ConversationStore::open(":memory:").expect("Failed to create in-memory store")
    }

    /// Build a `user` message with the given content, timestamp and token count.
    fn msg(content: &str, created_at: u64, token_count: Option<u64>) -> StoredMessage {
        StoredMessage {
            role: "user".into(),
            content: content.into(),
            created_at,
            token_count,
        }
    }

    #[test]
    fn test_create_and_list_sessions() {
        let store = create_test_store();
        store.create_session("test-1", "You are helpful.").unwrap();
        store.create_session("test-2", "You are a poet.").unwrap();

        // Both sessions are usually created within the same second; pin distinct
        // timestamps so the expected order does not depend on tie-breaking.
        store
            .conn
            .execute(
                "UPDATE sessions SET updated_at = 100 WHERE session_id = 'test-1'",
                [],
            )
            .unwrap();
        store
            .conn
            .execute(
                "UPDATE sessions SET updated_at = 200 WHERE session_id = 'test-2'",
                [],
            )
            .unwrap();

        let sessions = store.list_sessions().unwrap();
        assert_eq!(sessions.len(), 2);
        assert_eq!(sessions[0].session_id, "test-2"); // most recent first
        assert_eq!(sessions[1].session_id, "test-1");
    }

    #[test]
    fn test_add_and_get_messages() {
        let store = create_test_store();
        store
            .create_session("session-1", "You are helpful.")
            .unwrap();
        store
            .add_message("session-1", "user", "Hello!", Some(5))
            .unwrap();
        store
            .add_message("session-1", "assistant", "Hi there!", Some(10))
            .unwrap();

        let history = store.get_history("session-1", None).unwrap();
        assert_eq!(history.len(), 2);
        assert_eq!(history[0].role, "user");
        assert_eq!(history[0].content, "Hello!");
        assert_eq!(history[0].token_count, Some(5));
        assert_eq!(history[1].role, "assistant");
        assert_eq!(history[1].content, "Hi there!");
        assert_eq!(history[1].token_count, Some(10));
    }

    #[test]
    fn test_message_count_and_tokens() {
        let store = create_test_store();
        store
            .create_session("session-1", "You are helpful.")
            .unwrap();
        store
            .add_message("session-1", "user", "Hello!", Some(5))
            .unwrap();
        store
            .add_message("session-1", "assistant", "Hi!", Some(3))
            .unwrap();

        assert_eq!(store.message_count("session-1").unwrap(), 2);
        assert_eq!(store.total_tokens("session-1").unwrap(), 8);
    }

    #[test]
    fn test_delete_session() {
        let store = create_test_store();
        store
            .create_session("session-1", "You are helpful.")
            .unwrap();
        store
            .add_message("session-1", "user", "Hello!", None)
            .unwrap();

        store.delete_session("session-1").unwrap();
        let sessions = store.list_sessions().unwrap();
        assert_eq!(sessions.len(), 0);
        assert_eq!(store.message_count("session-1").unwrap(), 0);
    }

    #[test]
    fn test_retention_policy_time_based() {
        let mut messages = vec![msg("old", 1000, None), msg("new", u64::MAX, None)];

        // Keep messages newer than 1 hour
        let policy = RetentionPolicy::TimeBased(Duration::from_secs(3600));
        policy.apply(&mut messages);

        // Only the "new" message (with far-future timestamp) should remain
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].content, "new");
    }

    #[test]
    fn test_retention_policy_count_based() {
        let mut messages: Vec<StoredMessage> = (0..10)
            .map(|i| msg(&format!("msg-{}", i), i as u64, None))
            .collect();

        let policy = RetentionPolicy::CountBased(3);
        policy.apply(&mut messages);

        assert_eq!(messages.len(), 3);
        assert_eq!(messages[0].content, "msg-7");
        assert_eq!(messages[2].content, "msg-9");
    }

    #[test]
    fn test_retention_policy_count_based_edges() {
        let mut messages = vec![msg("a", 1, None), msg("b", 2, None), msg("c", 3, None)];

        // A limit larger than the history keeps everything.
        RetentionPolicy::CountBased(10).apply(&mut messages);
        assert_eq!(messages.len(), 3);

        // A limit of zero discards everything.
        RetentionPolicy::CountBased(0).apply(&mut messages);
        assert!(messages.is_empty());
    }

    #[test]
    fn test_retention_policy_token_budget() {
        let mut messages = vec![
            msg("a", 1, Some(100)),
            msg("b", 2, Some(50)),
            msg("c", 3, Some(30)),
        ];

        // Budget of 80 tokens — should keep newest messages up to 80 tokens
        let policy = RetentionPolicy::TokenBudget(80);
        policy.apply(&mut messages);

        // From newest: c(30) + b(50) = 80, a(100) exceeds budget
        assert_eq!(messages.len(), 2);
        assert_eq!(messages[0].content, "b");
        assert_eq!(messages[1].content, "c");
    }

    #[test]
    fn test_retention_policy_token_budget_everything_fits() {
        let mut messages = vec![msg("a", 1, None), msg("b", 2, Some(5)), msg("c", 3, Some(5))];

        // Missing token counts count as 0, so 0 + 5 + 5 fits a budget of 10.
        RetentionPolicy::TokenBudget(10).apply(&mut messages);
        assert_eq!(messages.len(), 3);
        assert_eq!(messages[0].content, "a");
    }

    #[test]
    fn test_retention_policy_unlimited() {
        let mut messages = vec![msg("a", 1, None), msg("b", 2, None)];

        let policy = RetentionPolicy::Unlimited;
        policy.apply(&mut messages);
        assert_eq!(messages.len(), 2);
    }

    #[test]
    fn test_prune_sessions() {
        let store = create_test_store();
        store.create_session("old-session", "Old.").unwrap();
        store.create_session("new-session", "New.").unwrap();

        // Manually set old session's updated_at to the past
        let past = 1000; // year 1970
        store
            .conn
            .execute(
                "UPDATE sessions SET updated_at = ?1 WHERE session_id = 'old-session'",
                params![past],
            )
            .unwrap();

        let pruned = store.prune_sessions(Duration::from_secs(3600)).unwrap();
        assert_eq!(pruned, 1);

        let sessions = store.list_sessions().unwrap();
        assert_eq!(sessions.len(), 1);
        assert_eq!(sessions[0].session_id, "new-session");
    }

    #[test]
    fn test_to_chat_messages() {
        let store = create_test_store();
        store.create_session("s1", "System prompt.").unwrap();
        store
            .add_message("s1", "system", "System prompt.", None)
            .unwrap();
        store.add_message("s1", "user", "Hello!", None).unwrap();

        let chat_msgs = store.to_chat_messages("s1", None).unwrap();
        assert_eq!(chat_msgs.len(), 2);
        assert_eq!(chat_msgs[0].role, "system");
        assert_eq!(chat_msgs[1].content, "Hello!");
    }

    #[test]
    fn test_import_memory() {
        let store = create_test_store();
        let mut memory = crate::agent::ConversationMemory::new("System prompt.", 0);
        memory.add_user_message("Hello!");
        memory.add_assistant_message("Hi there!");

        store
            .import_memory("imported-session", &memory, "System prompt.")
            .unwrap();

        let history = store.get_history("imported-session", None).unwrap();
        assert_eq!(history.len(), 3); // system + user + assistant
        assert_eq!(history[0].content, "System prompt.");
        assert_eq!(history[1].content, "Hello!");
        assert_eq!(history[2].content, "Hi there!");
    }

    #[test]
    fn test_session_metadata_updates() {
        let store = create_test_store();
        store.create_session("s1", "Helpful assistant.").unwrap();

        store.add_message("s1", "user", "Hi", Some(3)).unwrap();
        store
            .add_message("s1", "assistant", "Hello!", Some(5))
            .unwrap();

        let sessions = store.list_sessions().unwrap();
        assert_eq!(sessions.len(), 1);
        assert_eq!(sessions[0].message_count, 2);
        assert_eq!(sessions[0].total_tokens, 8);
    }

    #[test]
    fn test_nonexistent_session_returns_empty() {
        let store = create_test_store();
        let history = store.get_history("nonexistent", None).unwrap();
        assert!(history.is_empty());
        assert_eq!(store.message_count("nonexistent").unwrap(), 0);
        assert_eq!(store.total_tokens("nonexistent").unwrap(), 0);
    }
}