1use std::path::Path;
10use std::sync::{Arc, Mutex};
11
12use rusqlite::Connection;
13use thiserror::Error;
14use tracing::info;
15use uuid::Uuid;
16
17#[cfg(feature = "encryption")]
18use crate::encryption::Encryptor;
19
20mod migrations;
21
22#[cfg(test)]
23mod tests;
24
25#[derive(Debug, Error)]
27pub enum SqliteError {
28 #[error("SQLite error: {0}")]
29 Rusqlite(#[from] rusqlite::Error),
30
31 #[error("Lock poisoned")]
32 LockPoisoned,
33
34 #[error("Migration failed: {0}")]
35 Migration(String),
36}
37
38#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
40pub struct ExportedFact {
41 pub id: String,
42 pub namespace: String,
43 pub category: String,
44 pub subject: String,
45 pub predicate: String,
46 pub object: String,
47 pub confidence: f64,
48 pub source_episode_id: Option<String>,
49}
50
51#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
53pub struct ExportedEpisode {
54 pub id: String,
55 pub session_id: String,
56 pub session_channel: String,
57 #[serde(default = "default_namespace")]
58 pub namespace: String,
59 pub role: String,
60 pub content: String,
61 pub timestamp: String,
62 pub importance: f64,
63 pub reinforcement_count: i32,
64}
65
/// Namespace assigned to a deserialized `ExportedEpisode` that carries no
/// `namespace` field.
fn default_namespace() -> String {
    String::from("personal")
}
69
/// One row of the `notification_outbox` table.
#[derive(Clone, Debug)]
pub struct Notification {
    /// Row identifier (a UUID string when created via `insert_notification`).
    pub id: String,
    /// Notification text.
    pub content: String,
    /// Delivery priority; pending notifications are returned highest first.
    pub priority: i32,
    /// Value of the `triggered_by` column.
    pub triggered_by: String,
    /// Value of the `created_at` column.
    pub created_at: String,
    /// Delivery time, or `None` while the notification is still pending.
    pub delivered_at: Option<String>,
    /// Value of the nullable `channel` column.
    pub channel: Option<String>,
}
81
82#[derive(Clone)]
91pub struct SqlitePool {
92 conn: Arc<Mutex<Connection>>,
93 #[cfg(feature = "encryption")]
94 encryptor: Option<Arc<Encryptor>>,
95}
96
/// One row of the `scheduled_intents` table.
#[derive(Clone, Debug)]
pub struct ScheduledIntent {
    /// Row identifier (a UUID string when created via
    /// `insert_scheduled_intent`).
    pub id: String,
    /// Value of the `description` column.
    pub description: String,
    /// Value of the nullable `cron` column.
    pub cron: Option<String>,
    /// Value of the `namespace` column.
    pub namespace: String,
    /// Value of the `created_at` column.
    pub created_at: String,
    /// Lifecycle status; this file uses `scheduled` and `cancelled`.
    pub status: String,
    /// Value of the nullable `metadata` column.
    pub metadata: Option<String>,
}
108
109impl SqlitePool {
110 pub fn open(path: &Path) -> Result<Self, SqliteError> {
117 if let Some(parent) = path.parent() {
119 std::fs::create_dir_all(parent).map_err(|e| {
120 SqliteError::Migration(format!("Cannot create directory {}: {e}", parent.display()))
121 })?;
122 }
123
124 let conn = Connection::open(path)?;
125
126 conn.execute_batch(
128 "
129 PRAGMA journal_mode = WAL;
130 PRAGMA synchronous = NORMAL;
131 PRAGMA foreign_keys = ON;
132 PRAGMA busy_timeout = 5000;
133 PRAGMA cache_size = -8000;
134 ",
135 )?;
136
137 let pool = Self {
138 conn: Arc::new(Mutex::new(conn)),
139 #[cfg(feature = "encryption")]
140 encryptor: None,
141 };
142
143 pool.migrate()?;
145
146 info!("SQLite database opened at {}", path.display());
147 Ok(pool)
148 }
149
150 pub fn open_memory() -> Result<Self, SqliteError> {
152 let conn = Connection::open_in_memory()?;
153 conn.execute_batch(
154 "
155 PRAGMA journal_mode = WAL;
156 PRAGMA foreign_keys = ON;
157 ",
158 )?;
159
160 let pool = Self {
161 conn: Arc::new(Mutex::new(conn)),
162 #[cfg(feature = "encryption")]
163 encryptor: None,
164 };
165
166 pool.migrate()?;
167 Ok(pool)
168 }
169
170 pub fn with_conn<F, T>(&self, f: F) -> Result<T, SqliteError>
172 where
173 F: FnOnce(&Connection) -> Result<T, SqliteError>,
174 {
175 let conn = self.conn.lock().map_err(|_| SqliteError::LockPoisoned)?;
176 f(&conn)
177 }
178
/// Builder-style setter: returns the pool with `enc` installed as its
/// content encryptor. Only available with the `encryption` feature.
#[cfg(feature = "encryption")]
pub fn with_encryptor(mut self, enc: Encryptor) -> Self {
    let shared = Arc::new(enc);
    self.encryptor = Some(shared);
    self
}
188
189 pub fn is_encrypted(&self) -> bool {
191 #[cfg(feature = "encryption")]
192 {
193 self.encryptor.is_some()
194 }
195 #[cfg(not(feature = "encryption"))]
196 {
197 false
198 }
199 }
200
201 pub fn encrypt_content(&self, plaintext: &str) -> String {
203 #[cfg(feature = "encryption")]
204 {
205 if let Some(enc) = &self.encryptor {
206 return enc
207 .encrypt_string(plaintext)
208 .unwrap_or_else(|_| plaintext.to_string());
209 }
210 }
211 plaintext.to_string()
212 }
213
214 pub fn decrypt_content(&self, maybe_ciphertext: &str) -> String {
219 #[cfg(feature = "encryption")]
220 {
221 if let Some(enc) = &self.encryptor {
222 return enc
223 .decrypt_string(maybe_ciphertext)
224 .unwrap_or_else(|_| maybe_ciphertext.to_string());
225 }
226 }
227 maybe_ciphertext.to_string()
228 }
229
230 pub fn try_decrypt_content(&self, maybe_ciphertext: &str) -> Option<String> {
236 #[cfg(feature = "encryption")]
237 {
238 if let Some(enc) = &self.encryptor {
239 return enc.decrypt_string(maybe_ciphertext).ok();
240 }
241 }
242 Some(maybe_ciphertext.to_string())
243 }
244
245 pub fn wal_checkpoint(&self) -> Result<(), SqliteError> {
251 self.with_conn(|conn| {
252 conn.execute_batch("PRAGMA wal_checkpoint(TRUNCATE);")?;
253 Ok(())
254 })
255 }
256
257 pub fn insert_scheduled_intent(
259 &self,
260 description: &str,
261 cron: Option<&str>,
262 namespace: &str,
263 metadata: Option<&str>,
264 ) -> Result<String, SqliteError> {
265 let id = Uuid::new_v4().to_string();
266 self.with_conn(|conn| {
267 conn.execute(
268 "INSERT INTO scheduled_intents (id, description, cron, namespace, metadata)
269 VALUES (?1, ?2, ?3, ?4, ?5)",
270 rusqlite::params![id, description, cron, namespace, metadata],
271 )?;
272 Ok(())
273 })?;
274 Ok(id)
275 }
276
277 pub fn list_scheduled_intents(
279 &self,
280 namespace: Option<&str>,
281 ) -> Result<Vec<ScheduledIntent>, SqliteError> {
282 self.with_conn(|conn| {
283 let mut intents = Vec::new();
284 if let Some(ns) = namespace {
285 let mut stmt = conn.prepare(
286 "SELECT id, description, cron, namespace, created_at, status, metadata
287 FROM scheduled_intents
288 WHERE namespace = ?1 OR namespace LIKE ?2
289 ORDER BY created_at DESC",
290 )?;
291 let prefix = format!("{}/%", ns);
292 let rows = stmt.query_map([ns, &prefix], |row| {
293 Ok(ScheduledIntent {
294 id: row.get(0)?,
295 description: row.get(1)?,
296 cron: row.get(2)?,
297 namespace: row.get(3)?,
298 created_at: row.get(4)?,
299 status: row.get(5)?,
300 metadata: row.get(6)?,
301 })
302 })?;
303 for row in rows {
304 intents.push(row?);
305 }
306 } else {
307 let mut stmt = conn.prepare(
308 "SELECT id, description, cron, namespace, created_at, status, metadata
309 FROM scheduled_intents
310 ORDER BY created_at DESC",
311 )?;
312 let rows = stmt.query_map([], |row| {
313 Ok(ScheduledIntent {
314 id: row.get(0)?,
315 description: row.get(1)?,
316 cron: row.get(2)?,
317 namespace: row.get(3)?,
318 created_at: row.get(4)?,
319 status: row.get(5)?,
320 metadata: row.get(6)?,
321 })
322 })?;
323 for row in rows {
324 intents.push(row?);
325 }
326 }
327 Ok(intents)
328 })
329 }
330
331 pub fn update_scheduled_intent_status(
333 &self,
334 id: &str,
335 status: &str,
336 ) -> Result<bool, SqliteError> {
337 self.with_conn(|conn| {
338 let affected = conn.execute(
339 "UPDATE scheduled_intents SET status = ?2 WHERE id = ?1",
340 rusqlite::params![id, status],
341 )?;
342 Ok(affected > 0)
343 })
344 }
345
346 pub fn cancel_scheduled_intent(&self, id: &str) -> Result<bool, SqliteError> {
348 self.update_scheduled_intent_status(id, "cancelled")
349 }
350
351 pub fn due_scheduled_intents(&self) -> Result<Vec<ScheduledIntent>, SqliteError> {
353 self.with_conn(|conn| {
354 let mut stmt = conn.prepare(
355 "SELECT id, description, cron, namespace, created_at, status, metadata
356 FROM scheduled_intents
357 WHERE status = 'scheduled'
358 ORDER BY created_at ASC",
359 )?;
360 let rows = stmt.query_map([], |row| {
361 Ok(ScheduledIntent {
362 id: row.get(0)?,
363 description: row.get(1)?,
364 cron: row.get(2)?,
365 namespace: row.get(3)?,
366 created_at: row.get(4)?,
367 status: row.get(5)?,
368 metadata: row.get(6)?,
369 })
370 })?;
371 Ok(rows.filter_map(|r| r.ok()).collect())
372 })
373 }
374
375 pub fn insert_notification(
377 &self,
378 content: &str,
379 priority: i32,
380 triggered_by: &str,
381 channel: Option<&str>,
382 ) -> Result<String, SqliteError> {
383 let id = Uuid::new_v4().to_string();
384 self.with_conn(|conn| {
385 conn.execute(
386 "INSERT INTO notification_outbox (id, content, priority, triggered_by, channel)
387 VALUES (?1, ?2, ?3, ?4, ?5)",
388 rusqlite::params![id, content, priority, triggered_by, channel],
389 )?;
390 Ok(())
391 })?;
392 Ok(id)
393 }
394
395 pub fn pending_notifications(&self, limit: usize) -> Result<Vec<Notification>, SqliteError> {
397 self.with_conn(|conn| {
398 let mut stmt = conn.prepare(
399 "SELECT id, content, priority, triggered_by, created_at, delivered_at, channel
400 FROM notification_outbox
401 WHERE delivered_at IS NULL
402 ORDER BY priority DESC, created_at ASC
403 LIMIT ?1",
404 )?;
405 let rows = stmt
406 .query_map([limit as i64], |row| {
407 Ok(Notification {
408 id: row.get(0)?,
409 content: row.get(1)?,
410 priority: row.get(2)?,
411 triggered_by: row.get(3)?,
412 created_at: row.get(4)?,
413 delivered_at: row.get(5)?,
414 channel: row.get(6)?,
415 })
416 })?
417 .collect::<Result<Vec<_>, _>>()?;
418 Ok(rows)
419 })
420 }
421
422 pub fn mark_notification_delivered(&self, id: &str) -> Result<bool, SqliteError> {
424 self.with_conn(|conn| {
425 let affected = conn.execute(
426 "UPDATE notification_outbox SET delivered_at = datetime('now') WHERE id = ?1 AND delivered_at IS NULL",
427 [id],
428 )?;
429 Ok(affected > 0)
430 })
431 }
432
433 pub fn mark_notifications_delivered(&self, ids: &[String]) -> Result<usize, SqliteError> {
436 if ids.is_empty() {
437 return Ok(0);
438 }
439 let placeholders: String = ids.iter().map(|_| "?").collect::<Vec<_>>().join(",");
440 let sql = format!(
441 "UPDATE notification_outbox SET delivered_at = datetime('now') \
442 WHERE delivered_at IS NULL AND id IN ({placeholders})"
443 );
444 self.with_conn(|conn| {
445 let params: Vec<&dyn rusqlite::types::ToSql> = ids
446 .iter()
447 .map(|id| id as &dyn rusqlite::types::ToSql)
448 .collect();
449 let affected = conn.execute(&sql, params.as_slice())?;
450 Ok(affected)
451 })
452 }
453
454 pub fn prune_notifications(&self, max_age_days: u32) -> Result<usize, SqliteError> {
456 self.with_conn(|conn| {
457 let deleted = conn.execute(
458 "DELETE FROM notification_outbox
459 WHERE (delivered_at IS NOT NULL AND created_at < datetime('now', ?1))
460 OR created_at < datetime('now', ?1)",
461 [format!("-{max_age_days} days")],
462 )?;
463 Ok(deleted)
464 })
465 }
466
467 pub fn export_all_facts(&self) -> Result<Vec<ExportedFact>, SqliteError> {
471 self.with_conn(|conn| {
472 let mut stmt = conn.prepare(
473 "SELECT id, namespace, category, subject, predicate, object,
474 confidence, source_episode_id
475 FROM semantic_facts
476 ORDER BY id ASC",
477 )?;
478 let rows = stmt
479 .query_map([], |row| {
480 Ok(ExportedFact {
481 id: row.get(0)?,
482 namespace: row.get(1)?,
483 category: row.get(2)?,
484 subject: row.get(3)?,
485 predicate: row.get(4)?,
486 object: row.get(5)?,
487 confidence: row.get(6)?,
488 source_episode_id: row.get(7)?,
489 })
490 })?
491 .collect::<Result<Vec<_>, _>>()?;
492 Ok(rows)
493 })
494 }
495
496 pub fn export_all_episodes(&self) -> Result<Vec<ExportedEpisode>, SqliteError> {
498 self.with_conn(|conn| {
499 let mut stmt = conn.prepare(
500 "SELECT e.id, e.session_id, COALESCE(s.channel, 'cli'),
501 e.namespace, e.role, e.content, e.timestamp,
502 e.importance, e.reinforcement_count
503 FROM episodes e
504 LEFT JOIN sessions s ON s.id = e.session_id
505 ORDER BY e.timestamp ASC",
506 )?;
507 let rows = stmt
508 .query_map([], |row| {
509 Ok(ExportedEpisode {
510 id: row.get(0)?,
511 session_id: row.get(1)?,
512 session_channel: row.get(2)?,
513 namespace: row.get(3)?,
514 role: row.get(4)?,
515 content: row.get(5)?,
516 timestamp: row.get(6)?,
517 importance: row.get(7)?,
518 reinforcement_count: row.get(8)?,
519 })
520 })?
521 .collect::<Result<Vec<_>, _>>()?;
522 Ok(rows)
523 })
524 }
525
526 pub fn import_facts(&self, facts: &[ExportedFact]) -> Result<(usize, Vec<usize>), SqliteError> {
528 self.with_conn(|conn| {
529 let mut imported = 0usize;
530 let mut new_indices = Vec::new();
531 for (idx, f) in facts.iter().enumerate() {
532 let n = conn.execute(
533 "INSERT INTO semantic_facts
534 (id, namespace, category, subject, predicate, object,
535 confidence, source_episode_id)
536 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)
537 ON CONFLICT(id) DO NOTHING",
538 rusqlite::params![
539 f.id,
540 f.namespace,
541 f.category,
542 f.subject,
543 f.predicate,
544 f.object,
545 f.confidence,
546 f.source_episode_id
547 ],
548 )?;
549 if n > 0 {
550 new_indices.push(idx);
551 }
552 imported += n;
553 }
554 Ok((imported, new_indices))
555 })
556 }
557
558 pub fn import_episodes(&self, episodes: &[ExportedEpisode]) -> Result<usize, SqliteError> {
560 self.with_conn(|conn| {
561 let mut sessions: std::collections::HashMap<String, String> =
563 std::collections::HashMap::new();
564 for ep in episodes {
565 sessions
566 .entry(ep.session_id.clone())
567 .or_insert_with(|| ep.session_channel.clone());
568 }
569 for (sid, channel) in &sessions {
570 conn.execute(
571 "INSERT INTO sessions (id, channel) VALUES (?1, ?2)
572 ON CONFLICT(id) DO NOTHING",
573 rusqlite::params![sid, channel],
574 )?;
575 }
576
577 let mut imported = 0usize;
578 for e in episodes {
579 let n = conn.execute(
580 "INSERT INTO episodes
581 (id, session_id, namespace, role, content, timestamp,
582 importance, reinforcement_count)
583 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)
584 ON CONFLICT(id) DO NOTHING",
585 rusqlite::params![
586 e.id,
587 e.session_id,
588 e.namespace,
589 e.role,
590 e.content,
591 e.timestamp,
592 e.importance,
593 e.reinforcement_count
594 ],
595 )?;
596 imported += n;
597 }
598 Ok(imported)
599 })
600 }
601
602 pub fn table_stats(&self) -> Result<Vec<(String, i64)>, SqliteError> {
604 self.with_conn(|conn| {
605 let mut stats = Vec::new();
606 for table in &[
609 "sessions",
610 "episodes",
611 "semantic_facts",
612 "episode_promotions",
613 "scheduled_intents",
614 "notification_outbox",
615 "user_profile",
616 "procedures",
617 "audit_log",
618 ] {
619 let sql = match *table {
620 "sessions" => "SELECT COUNT(*) FROM sessions",
621 "episodes" => "SELECT COUNT(*) FROM episodes",
622 "semantic_facts" => "SELECT COUNT(*) FROM semantic_facts",
623 "episode_promotions" => "SELECT COUNT(*) FROM episode_promotions",
624 "scheduled_intents" => "SELECT COUNT(*) FROM scheduled_intents",
625 "notification_outbox" => "SELECT COUNT(*) FROM notification_outbox",
626 "user_profile" => "SELECT COUNT(*) FROM user_profile",
627 "procedures" => "SELECT COUNT(*) FROM procedures",
628 "audit_log" => "SELECT COUNT(*) FROM audit_log",
629 _ => continue,
630 };
631 let count: i64 = conn.query_row(sql, [], |row| row.get(0)).unwrap_or(0);
632 stats.push((table.to_string(), count));
633 }
634
635 Ok(stats)
636 })
637 }
638}