Skip to main content

enact_memory/
sqlite.rs

1use crate::traits::{Memory, MemoryCategory, MemoryEntry};
2use async_trait::async_trait;
3use chrono::Utc;
4use rusqlite::{params, Connection};
5use std::path::{Path, PathBuf};
6
/// [`Memory`] backend that persists entries in a single SQLite database file.
///
/// Only the database path is held; every operation opens its own short-lived
/// connection to that file.
#[derive(Debug)]
pub struct SqliteMemory {
    /// Location of the database file (`<workspace>/memory/brain.db` when the
    /// value is built through `SqliteMemory::new`).
    db_path: PathBuf,
}
10
11impl SqliteMemory {
12    pub fn new(workspace_dir: &Path) -> anyhow::Result<Self> {
13        let db_path = workspace_dir.join("memory").join("brain.db");
14        if let Some(parent) = db_path.parent() {
15            std::fs::create_dir_all(parent)?;
16        }
17        let this = Self { db_path };
18        this.init_schema()?;
19        Ok(this)
20    }
21
22    fn open(&self) -> anyhow::Result<Connection> {
23        Ok(Connection::open(&self.db_path)?)
24    }
25
26    fn init_schema(&self) -> anyhow::Result<()> {
27        let conn = self.open()?;
28        conn.execute_batch(
29            "
30            CREATE TABLE IF NOT EXISTS memory_entries (
31                id TEXT PRIMARY KEY,
32                key TEXT NOT NULL UNIQUE,
33                content TEXT NOT NULL,
34                category TEXT NOT NULL,
35                timestamp TEXT NOT NULL,
36                session_id TEXT
37            );
38            CREATE INDEX IF NOT EXISTS idx_memory_key ON memory_entries(key);
39            CREATE INDEX IF NOT EXISTS idx_memory_category ON memory_entries(category);
40            CREATE INDEX IF NOT EXISTS idx_memory_session_id ON memory_entries(session_id);
41            ",
42        )?;
43        Ok(())
44    }
45
46    fn to_category(raw: &str) -> MemoryCategory {
47        match raw {
48            "core" => MemoryCategory::Core,
49            "daily" => MemoryCategory::Daily,
50            "conversation" => MemoryCategory::Conversation,
51            other => MemoryCategory::Custom(other.to_string()),
52        }
53    }
54}
55
56#[async_trait]
57impl Memory for SqliteMemory {
58    fn name(&self) -> &str {
59        "sqlite"
60    }
61
62    async fn store(
63        &self,
64        key: &str,
65        content: &str,
66        category: MemoryCategory,
67        session_id: Option<&str>,
68    ) -> anyhow::Result<()> {
69        let db_path = self.db_path.clone();
70        let key = key.to_string();
71        let content = content.to_string();
72        let category = category.to_string();
73        let session_id = session_id.map(|s| s.to_string());
74
75        tokio::task::spawn_blocking(move || -> anyhow::Result<()> {
76            let conn = Connection::open(db_path)?;
77            conn.execute(
78                "
79                INSERT INTO memory_entries (id, key, content, category, timestamp, session_id)
80                VALUES (?1, ?2, ?3, ?4, ?5, ?6)
81                ON CONFLICT(key)
82                DO UPDATE SET
83                    content=excluded.content,
84                    category=excluded.category,
85                    timestamp=excluded.timestamp,
86                    session_id=excluded.session_id
87                ",
88                params![
89                    uuid::Uuid::new_v4().to_string(),
90                    key,
91                    content,
92                    category,
93                    Utc::now().to_rfc3339(),
94                    session_id,
95                ],
96            )?;
97            Ok(())
98        })
99        .await??;
100
101        Ok(())
102    }
103
104    async fn recall(
105        &self,
106        query: &str,
107        limit: usize,
108        session_id: Option<&str>,
109    ) -> anyhow::Result<Vec<MemoryEntry>> {
110        let db_path = self.db_path.clone();
111        let query = query.to_string();
112        let session = session_id.map(|s| s.to_string());
113
114        tokio::task::spawn_blocking(move || -> anyhow::Result<Vec<MemoryEntry>> {
115            let conn = Connection::open(db_path)?;
116            let mut out = Vec::new();
117            let like = format!("%{}%", query);
118
119            if let Some(session_id) = session {
120                let mut stmt = conn.prepare(
121                    "
122                    SELECT id, key, content, category, timestamp, session_id
123                    FROM memory_entries
124                    WHERE (key LIKE ?1 OR content LIKE ?1)
125                      AND session_id = ?2
126                    ORDER BY timestamp DESC
127                    LIMIT ?3
128                    ",
129                )?;
130                let rows = stmt.query_map(params![like, session_id, limit as i64], |row| {
131                    let category_raw: String = row.get(3)?;
132                    Ok(MemoryEntry {
133                        id: row.get(0)?,
134                        key: row.get(1)?,
135                        content: row.get(2)?,
136                        category: Self::to_category(&category_raw),
137                        timestamp: row.get(4)?,
138                        session_id: row.get(5)?,
139                        score: None,
140                    })
141                })?;
142
143                for row in rows {
144                    out.push(row?);
145                }
146            } else {
147                let mut stmt = conn.prepare(
148                    "
149                    SELECT id, key, content, category, timestamp, session_id
150                    FROM memory_entries
151                    WHERE key LIKE ?1 OR content LIKE ?1
152                    ORDER BY timestamp DESC
153                    LIMIT ?2
154                    ",
155                )?;
156                let rows = stmt.query_map(params![like, limit as i64], |row| {
157                    let category_raw: String = row.get(3)?;
158                    Ok(MemoryEntry {
159                        id: row.get(0)?,
160                        key: row.get(1)?,
161                        content: row.get(2)?,
162                        category: Self::to_category(&category_raw),
163                        timestamp: row.get(4)?,
164                        session_id: row.get(5)?,
165                        score: None,
166                    })
167                })?;
168
169                for row in rows {
170                    out.push(row?);
171                }
172            }
173
174            Ok(out)
175        })
176        .await?
177    }
178
179    async fn get(&self, key: &str) -> anyhow::Result<Option<MemoryEntry>> {
180        let db_path = self.db_path.clone();
181        let key = key.to_string();
182
183        tokio::task::spawn_blocking(move || -> anyhow::Result<Option<MemoryEntry>> {
184            let conn = Connection::open(db_path)?;
185            let mut stmt = conn.prepare(
186                "SELECT id, key, content, category, timestamp, session_id FROM memory_entries WHERE key = ?1",
187            )?;
188            let mut rows = stmt.query(params![key])?;
189            if let Some(row) = rows.next()? {
190                let category_raw: String = row.get(3)?;
191                return Ok(Some(MemoryEntry {
192                    id: row.get(0)?,
193                    key: row.get(1)?,
194                    content: row.get(2)?,
195                    category: Self::to_category(&category_raw),
196                    timestamp: row.get(4)?,
197                    session_id: row.get(5)?,
198                    score: None,
199                }));
200            }
201            Ok(None)
202        })
203        .await?
204    }
205
206    async fn list(
207        &self,
208        category: Option<&MemoryCategory>,
209        session_id: Option<&str>,
210    ) -> anyhow::Result<Vec<MemoryEntry>> {
211        let db_path = self.db_path.clone();
212        let category = category.map(|c| c.to_string());
213        let session = session_id.map(|s| s.to_string());
214
215        tokio::task::spawn_blocking(move || -> anyhow::Result<Vec<MemoryEntry>> {
216            let conn = Connection::open(db_path)?;
217            let mut out = Vec::new();
218
219            let sql = match (category.is_some(), session.is_some()) {
220                (false, false) => "SELECT id, key, content, category, timestamp, session_id FROM memory_entries ORDER BY timestamp DESC",
221                (true, false) => "SELECT id, key, content, category, timestamp, session_id FROM memory_entries WHERE category = ?1 ORDER BY timestamp DESC",
222                (false, true) => "SELECT id, key, content, category, timestamp, session_id FROM memory_entries WHERE session_id = ?1 ORDER BY timestamp DESC",
223                (true, true) => "SELECT id, key, content, category, timestamp, session_id FROM memory_entries WHERE category = ?1 AND session_id = ?2 ORDER BY timestamp DESC",
224            };
225
226            let mut stmt = conn.prepare(sql)?;
227            let mut rows = match (category, session) {
228                (None, None) => stmt.query(params![] )?,
229                (Some(c), None) => stmt.query(params![c])?,
230                (None, Some(s)) => stmt.query(params![s])?,
231                (Some(c), Some(s)) => stmt.query(params![c, s])?,
232            };
233
234            while let Some(row) = rows.next()? {
235                let category_raw: String = row.get(3)?;
236                out.push(MemoryEntry {
237                    id: row.get(0)?,
238                    key: row.get(1)?,
239                    content: row.get(2)?,
240                    category: Self::to_category(&category_raw),
241                    timestamp: row.get(4)?,
242                    session_id: row.get(5)?,
243                    score: None,
244                });
245            }
246
247            Ok(out)
248        })
249        .await?
250    }
251
252    async fn forget(&self, key: &str) -> anyhow::Result<bool> {
253        let db_path = self.db_path.clone();
254        let key = key.to_string();
255        tokio::task::spawn_blocking(move || -> anyhow::Result<bool> {
256            let conn = Connection::open(db_path)?;
257            let affected =
258                conn.execute("DELETE FROM memory_entries WHERE key = ?1", params![key])?;
259            Ok(affected > 0)
260        })
261        .await?
262    }
263
264    async fn count(&self) -> anyhow::Result<usize> {
265        let db_path = self.db_path.clone();
266        tokio::task::spawn_blocking(move || -> anyhow::Result<usize> {
267            let conn = Connection::open(db_path)?;
268            let count: i64 =
269                conn.query_row("SELECT COUNT(*) FROM memory_entries", params![], |r| {
270                    r.get(0)
271                })?;
272            Ok(count.max(0) as usize)
273        })
274        .await?
275    }
276
277    async fn health_check(&self) -> bool {
278        self.open().is_ok()
279    }
280}