1use rusqlite::{params, Connection, Result as SqlResult};
29use serde::{Deserialize, Serialize};
30use std::path::Path;
31use std::time::{Duration, SystemTime, UNIX_EPOCH};
32
33#[derive(Debug, Clone, Serialize, Deserialize)]
35pub struct StoredMessage {
36 pub role: String,
38 pub content: String,
40 pub created_at: u64,
42 pub token_count: Option<u64>,
44}
45
46#[derive(Debug, Clone, Serialize, Deserialize)]
48pub struct StoredSession {
49 pub session_id: String,
51 pub system_prompt: String,
53 pub created_at: u64,
55 pub updated_at: u64,
57 pub total_tokens: u64,
59 pub message_count: u64,
61}
62
63#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
65pub enum RetentionPolicy {
66 TimeBased(Duration),
68 CountBased(usize),
70 TokenBudget(u64),
72 Unlimited,
74}
75
76impl RetentionPolicy {
77 pub fn apply(&self, messages: &mut Vec<StoredMessage>) {
79 match self {
80 RetentionPolicy::TimeBased(duration) => {
81 let cutoff = SystemTime::now()
82 .duration_since(UNIX_EPOCH)
83 .unwrap_or_default()
84 .as_secs()
85 - duration.as_secs();
86 messages.retain(|m| m.created_at >= cutoff);
87 }
88 RetentionPolicy::CountBased(max) => {
89 if messages.len() > *max {
90 let keep = messages.split_off(messages.len() - max);
92 *messages = keep;
93 }
94 }
95 RetentionPolicy::TokenBudget(budget) => {
96 let mut total: u64 = 0;
97 messages.reverse();
99 messages.retain(|m| {
100 let tokens = m.token_count.unwrap_or(0);
101 if total + tokens <= *budget {
102 total += tokens;
103 true
104 } else {
105 false
106 }
107 });
108 messages.reverse();
109 }
110 RetentionPolicy::Unlimited => {
111 }
113 }
114 }
115}
116
117#[derive(Debug)]
119pub struct ConversationStore {
120 conn: Connection,
121}
122
123impl ConversationStore {
124 pub fn open<P: AsRef<Path>>(path: P) -> SqlResult<Self> {
127 let conn = Connection::open(path)?;
128 let store = Self { conn };
129 store.initialize_tables()?;
130 Ok(store)
131 }
132
133 fn initialize_tables(&self) -> SqlResult<()> {
135 self.conn.execute_batch(
136 "
137 CREATE TABLE IF NOT EXISTS sessions (
138 session_id TEXT PRIMARY KEY,
139 system_prompt TEXT NOT NULL DEFAULT '',
140 created_at INTEGER NOT NULL,
141 updated_at INTEGER NOT NULL,
142 total_tokens INTEGER NOT NULL DEFAULT 0,
143 message_count INTEGER NOT NULL DEFAULT 0
144 );
145
146 CREATE TABLE IF NOT EXISTS messages (
147 id INTEGER PRIMARY KEY AUTOINCREMENT,
148 session_id TEXT NOT NULL,
149 role TEXT NOT NULL,
150 content TEXT NOT NULL,
151 created_at INTEGER NOT NULL,
152 token_count INTEGER DEFAULT NULL,
153 FOREIGN KEY (session_id) REFERENCES sessions(session_id) ON DELETE CASCADE
154 );
155
156 CREATE INDEX IF NOT EXISTS idx_messages_session_id ON messages(session_id);
157 CREATE INDEX IF NOT EXISTS idx_messages_created_at ON messages(created_at);
158 CREATE INDEX IF NOT EXISTS idx_sessions_updated_at ON sessions(updated_at);
159 ",
160 )?;
161 Ok(())
162 }
163
164 pub fn create_session(&self, session_id: &str, system_prompt: &str) -> SqlResult<()> {
166 let now = SystemTime::now()
167 .duration_since(UNIX_EPOCH)
168 .unwrap_or_default()
169 .as_secs();
170 self.conn.execute(
171 "INSERT OR IGNORE INTO sessions (session_id, system_prompt, created_at, updated_at)
172 VALUES (?1, ?2, ?3, ?3)",
173 params![session_id, system_prompt, now],
174 )?;
175 Ok(())
176 }
177
178 pub fn delete_session(&self, session_id: &str) -> SqlResult<()> {
180 self.conn.execute(
181 "DELETE FROM messages WHERE session_id = ?1",
182 params![session_id],
183 )?;
184 self.conn.execute(
185 "DELETE FROM sessions WHERE session_id = ?1",
186 params![session_id],
187 )?;
188 Ok(())
189 }
190
191 pub fn list_sessions(&self) -> SqlResult<Vec<StoredSession>> {
193 let mut stmt = self.conn.prepare(
194 "SELECT session_id, system_prompt, created_at, updated_at, total_tokens, message_count
195 FROM sessions ORDER BY updated_at DESC",
196 )?;
197 let sessions = stmt
198 .query_map([], |row| {
199 Ok(StoredSession {
200 session_id: row.get(0)?,
201 system_prompt: row.get(1)?,
202 created_at: row.get(2)?,
203 updated_at: row.get(3)?,
204 total_tokens: row.get(4)?,
205 message_count: row.get(5)?,
206 })
207 })?
208 .collect::<SqlResult<Vec<_>>>()?;
209 Ok(sessions)
210 }
211
212 pub fn add_message(
214 &self,
215 session_id: &str,
216 role: &str,
217 content: &str,
218 token_count: Option<u64>,
219 ) -> SqlResult<()> {
220 let now = SystemTime::now()
221 .duration_since(UNIX_EPOCH)
222 .unwrap_or_default()
223 .as_secs();
224
225 self.conn.execute(
227 "INSERT INTO messages (session_id, role, content, created_at, token_count)
228 VALUES (?1, ?2, ?3, ?4, ?5)",
229 params![session_id, role, content, now, token_count],
230 )?;
231
232 self.conn.execute(
234 "UPDATE sessions SET
235 updated_at = ?1,
236 total_tokens = total_tokens + ?2,
237 message_count = message_count + 1
238 WHERE session_id = ?3",
239 params![now, token_count.unwrap_or(0), session_id],
240 )?;
241
242 Ok(())
243 }
244
245 pub fn get_history(
247 &self,
248 session_id: &str,
249 policy: Option<RetentionPolicy>,
250 ) -> SqlResult<Vec<StoredMessage>> {
251 let mut stmt = self.conn.prepare(
252 "SELECT role, content, created_at, token_count
253 FROM messages WHERE session_id = ?1
254 ORDER BY created_at ASC",
255 )?;
256
257 let mut messages: Vec<StoredMessage> = stmt
258 .query_map(params![session_id], |row| {
259 Ok(StoredMessage {
260 role: row.get(0)?,
261 content: row.get(1)?,
262 created_at: row.get(2)?,
263 token_count: row.get(3)?,
264 })
265 })?
266 .collect::<SqlResult<Vec<_>>>()?;
267
268 if let Some(policy) = policy {
270 policy.apply(&mut messages);
271 }
272
273 Ok(messages)
274 }
275
276 pub fn message_count(&self, session_id: &str) -> SqlResult<u64> {
278 let count: u64 = self
279 .conn
280 .query_row(
281 "SELECT COUNT(*) FROM messages WHERE session_id = ?1",
282 params![session_id],
283 |row| row.get(0),
284 )
285 .unwrap_or(0);
286 Ok(count)
287 }
288
289 pub fn total_tokens(&self, session_id: &str) -> SqlResult<u64> {
291 let total: u64 = self
292 .conn
293 .query_row(
294 "SELECT COALESCE(SUM(token_count), 0) FROM messages WHERE session_id = ?1",
295 params![session_id],
296 |row| row.get(0),
297 )
298 .unwrap_or(0);
299 Ok(total)
300 }
301
302 pub fn prune_sessions(&self, max_age: Duration) -> SqlResult<u64> {
304 let cutoff = SystemTime::now()
305 .duration_since(UNIX_EPOCH)
306 .unwrap_or_default()
307 .as_secs()
308 - max_age.as_secs();
309
310 let sessions: Vec<String> = self
312 .conn
313 .prepare("SELECT session_id FROM sessions WHERE updated_at < ?1")?
314 .query_map(params![cutoff], |row| row.get(0))?
315 .collect::<SqlResult<Vec<_>>>()?;
316
317 let count = sessions.len() as u64;
318 for session_id in &sessions {
319 self.delete_session(session_id)?;
320 }
321
322 Ok(count)
323 }
324
325 pub fn to_chat_messages(
327 &self,
328 session_id: &str,
329 policy: Option<RetentionPolicy>,
330 ) -> SqlResult<Vec<crate::llm::ChatMessage>> {
331 let stored = self.get_history(session_id, policy)?;
332 Ok(stored
333 .into_iter()
334 .map(|m| crate::llm::ChatMessage {
335 role: m.role,
336 content: m.content,
337 content_parts: None,
338 })
339 .collect())
340 }
341
342 pub fn import_memory(
344 &self,
345 session_id: &str,
346 memory: &crate::agent::ConversationMemory,
347 system_prompt: &str,
348 ) -> SqlResult<()> {
349 self.create_session(session_id, system_prompt)?;
350
351 for msg in memory.history() {
352 self.add_message(session_id, &msg.role, &msg.content, None)?;
353 }
354
355 Ok(())
356 }
357}
358
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Opens a fresh, isolated in-memory store for one test.
    fn create_test_store() -> ConversationStore {
        ConversationStore::open(":memory:").expect("Failed to create in-memory store")
    }

    /// Overwrites a session's `updated_at` so ordering and pruning tests do not
    /// depend on how many wall-clock seconds elapse during the test.
    fn set_updated_at(store: &ConversationStore, session_id: &str, updated_at: u64) {
        store
            .conn
            .execute(
                "UPDATE sessions SET updated_at = ?1 WHERE session_id = ?2",
                params![updated_at, session_id],
            )
            .unwrap();
    }

    #[test]
    fn test_create_and_list_sessions() {
        let store = create_test_store();
        store.create_session("test-1", "You are helpful.").unwrap();
        store.create_session("test-2", "You are a poet.").unwrap();

        // Both sessions are normally created within the same second, so pin
        // distinct timestamps to make the DESC ordering under test deterministic.
        set_updated_at(&store, "test-1", 1_000);
        set_updated_at(&store, "test-2", 2_000);

        let sessions = store.list_sessions().unwrap();
        assert_eq!(sessions.len(), 2);
        assert_eq!(sessions[0].session_id, "test-2");
        assert_eq!(sessions[1].session_id, "test-1");
    }

    #[test]
    fn test_add_and_get_messages() {
        let store = create_test_store();
        store
            .create_session("session-1", "You are helpful.")
            .unwrap();
        store
            .add_message("session-1", "user", "Hello!", Some(5))
            .unwrap();
        store
            .add_message("session-1", "assistant", "Hi there!", Some(10))
            .unwrap();

        let history = store.get_history("session-1", None).unwrap();
        assert_eq!(history.len(), 2);
        assert_eq!(history[0].role, "user");
        assert_eq!(history[0].content, "Hello!");
        assert_eq!(history[0].token_count, Some(5));
        assert_eq!(history[1].role, "assistant");
        assert_eq!(history[1].content, "Hi there!");
        assert_eq!(history[1].token_count, Some(10));
    }

    #[test]
    fn test_message_count_and_tokens() {
        let store = create_test_store();
        store
            .create_session("session-1", "You are helpful.")
            .unwrap();
        store
            .add_message("session-1", "user", "Hello!", Some(5))
            .unwrap();
        store
            .add_message("session-1", "assistant", "Hi!", Some(3))
            .unwrap();

        assert_eq!(store.message_count("session-1").unwrap(), 2);
        assert_eq!(store.total_tokens("session-1").unwrap(), 8);
    }

    #[test]
    fn test_delete_session() {
        let store = create_test_store();
        store
            .create_session("session-1", "You are helpful.")
            .unwrap();
        store
            .add_message("session-1", "user", "Hello!", None)
            .unwrap();

        store.delete_session("session-1").unwrap();
        let sessions = store.list_sessions().unwrap();
        assert_eq!(sessions.len(), 0);
        assert_eq!(store.message_count("session-1").unwrap(), 0);
    }

    #[test]
    fn test_retention_policy_time_based() {
        let mut messages = vec![
            StoredMessage {
                role: "user".into(),
                content: "old".into(),
                created_at: 1000,
                token_count: None,
            },
            StoredMessage {
                role: "user".into(),
                content: "new".into(),
                created_at: u64::MAX,
                token_count: None,
            },
        ];

        let policy = RetentionPolicy::TimeBased(Duration::from_secs(3600));
        policy.apply(&mut messages);

        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].content, "new");
    }

    #[test]
    fn test_retention_policy_count_based() {
        let mut messages: Vec<StoredMessage> = (0..10)
            .map(|i| StoredMessage {
                role: "user".into(),
                content: format!("msg-{}", i),
                created_at: i as u64,
                token_count: None,
            })
            .collect();

        let policy = RetentionPolicy::CountBased(3);
        policy.apply(&mut messages);

        assert_eq!(messages.len(), 3);
        assert_eq!(messages[0].content, "msg-7");
        assert_eq!(messages[2].content, "msg-9");
    }

    #[test]
    fn test_retention_policy_token_budget() {
        let mut messages = vec![
            StoredMessage {
                role: "user".into(),
                content: "a".into(),
                created_at: 1,
                token_count: Some(100),
            },
            StoredMessage {
                role: "user".into(),
                content: "b".into(),
                created_at: 2,
                token_count: Some(50),
            },
            StoredMessage {
                role: "user".into(),
                content: "c".into(),
                created_at: 3,
                token_count: Some(30),
            },
        ];

        let policy = RetentionPolicy::TokenBudget(80);
        policy.apply(&mut messages);

        assert_eq!(messages.len(), 2);
        assert_eq!(messages[0].content, "b");
        assert_eq!(messages[1].content, "c");
    }

    #[test]
    fn test_retention_policy_unlimited() {
        let mut messages = vec![
            StoredMessage {
                role: "user".into(),
                content: "a".into(),
                created_at: 1,
                token_count: None,
            },
            StoredMessage {
                role: "user".into(),
                content: "b".into(),
                created_at: 2,
                token_count: None,
            },
        ];

        let policy = RetentionPolicy::Unlimited;
        policy.apply(&mut messages);
        assert_eq!(messages.len(), 2);
    }

    #[test]
    fn test_prune_sessions() {
        let store = create_test_store();
        store.create_session("old-session", "Old.").unwrap();
        store.create_session("new-session", "New.").unwrap();

        // Backdate one session far beyond the one-hour pruning window.
        set_updated_at(&store, "old-session", 1_000);

        let pruned = store.prune_sessions(Duration::from_secs(3600)).unwrap();
        assert_eq!(pruned, 1);

        let sessions = store.list_sessions().unwrap();
        assert_eq!(sessions.len(), 1);
        assert_eq!(sessions[0].session_id, "new-session");
    }

    #[test]
    fn test_to_chat_messages() {
        let store = create_test_store();
        store.create_session("s1", "System prompt.").unwrap();
        store
            .add_message("s1", "system", "System prompt.", None)
            .unwrap();
        store.add_message("s1", "user", "Hello!", None).unwrap();

        let chat_msgs = store.to_chat_messages("s1", None).unwrap();
        assert_eq!(chat_msgs.len(), 2);
        assert_eq!(chat_msgs[0].role, "system");
        assert_eq!(chat_msgs[1].content, "Hello!");
    }

    #[test]
    fn test_import_memory() {
        let store = create_test_store();
        let mut memory = crate::agent::ConversationMemory::new("System prompt.", 0);
        memory.add_user_message("Hello!");
        memory.add_assistant_message("Hi there!");

        store
            .import_memory("imported-session", &memory, "System prompt.")
            .unwrap();

        let history = store.get_history("imported-session", None).unwrap();
        assert_eq!(history.len(), 3);
        assert_eq!(history[0].content, "System prompt.");
        assert_eq!(history[1].content, "Hello!");
        assert_eq!(history[2].content, "Hi there!");
    }

    #[test]
    fn test_session_metadata_updates() {
        let store = create_test_store();
        store.create_session("s1", "Helpful assistant.").unwrap();

        store.add_message("s1", "user", "Hi", Some(3)).unwrap();
        store
            .add_message("s1", "assistant", "Hello!", Some(5))
            .unwrap();

        let sessions = store.list_sessions().unwrap();
        assert_eq!(sessions.len(), 1);
        assert_eq!(sessions[0].message_count, 2);
        assert_eq!(sessions[0].total_tokens, 8);
    }

    #[test]
    fn test_nonexistent_session_returns_empty() {
        let store = create_test_store();
        let history = store.get_history("nonexistent", None).unwrap();
        assert!(history.is_empty());
        assert_eq!(store.message_count("nonexistent").unwrap(), 0);
        assert_eq!(store.total_tokens("nonexistent").unwrap(), 0);
    }
}