1use crate::traits::{Memory, MemoryCategory, MemoryEntry};
2use async_trait::async_trait;
3use chrono::Utc;
4use rusqlite::{params, Connection};
5use std::path::{Path, PathBuf};
6
/// Persistent memory backend that stores entries in a single SQLite database file.
///
/// The struct holds only the path to the database; a connection is opened
/// separately for each operation.
#[derive(Debug, Clone)]
pub struct SqliteMemory {
    /// Location of the SQLite database file (`<workspace>/memory/brain.db` when
    /// built via `SqliteMemory::new`).
    db_path: PathBuf,
}
10
11impl SqliteMemory {
12 pub fn new(workspace_dir: &Path) -> anyhow::Result<Self> {
13 let db_path = workspace_dir.join("memory").join("brain.db");
14 if let Some(parent) = db_path.parent() {
15 std::fs::create_dir_all(parent)?;
16 }
17 let this = Self { db_path };
18 this.init_schema()?;
19 Ok(this)
20 }
21
22 fn open(&self) -> anyhow::Result<Connection> {
23 Ok(Connection::open(&self.db_path)?)
24 }
25
26 fn init_schema(&self) -> anyhow::Result<()> {
27 let conn = self.open()?;
28 conn.execute_batch(
29 "
30 CREATE TABLE IF NOT EXISTS memory_entries (
31 id TEXT PRIMARY KEY,
32 key TEXT NOT NULL UNIQUE,
33 content TEXT NOT NULL,
34 category TEXT NOT NULL,
35 timestamp TEXT NOT NULL,
36 session_id TEXT
37 );
38 CREATE INDEX IF NOT EXISTS idx_memory_key ON memory_entries(key);
39 CREATE INDEX IF NOT EXISTS idx_memory_category ON memory_entries(category);
40 CREATE INDEX IF NOT EXISTS idx_memory_session_id ON memory_entries(session_id);
41 ",
42 )?;
43 Ok(())
44 }
45
46 fn to_category(raw: &str) -> MemoryCategory {
47 match raw {
48 "core" => MemoryCategory::Core,
49 "daily" => MemoryCategory::Daily,
50 "conversation" => MemoryCategory::Conversation,
51 other => MemoryCategory::Custom(other.to_string()),
52 }
53 }
54}
55
56#[async_trait]
57impl Memory for SqliteMemory {
58 fn name(&self) -> &str {
59 "sqlite"
60 }
61
62 async fn store(
63 &self,
64 key: &str,
65 content: &str,
66 category: MemoryCategory,
67 session_id: Option<&str>,
68 ) -> anyhow::Result<()> {
69 let db_path = self.db_path.clone();
70 let key = key.to_string();
71 let content = content.to_string();
72 let category = category.to_string();
73 let session_id = session_id.map(|s| s.to_string());
74
75 tokio::task::spawn_blocking(move || -> anyhow::Result<()> {
76 let conn = Connection::open(db_path)?;
77 conn.execute(
78 "
79 INSERT INTO memory_entries (id, key, content, category, timestamp, session_id)
80 VALUES (?1, ?2, ?3, ?4, ?5, ?6)
81 ON CONFLICT(key)
82 DO UPDATE SET
83 content=excluded.content,
84 category=excluded.category,
85 timestamp=excluded.timestamp,
86 session_id=excluded.session_id
87 ",
88 params![
89 uuid::Uuid::new_v4().to_string(),
90 key,
91 content,
92 category,
93 Utc::now().to_rfc3339(),
94 session_id,
95 ],
96 )?;
97 Ok(())
98 })
99 .await??;
100
101 Ok(())
102 }
103
104 async fn recall(
105 &self,
106 query: &str,
107 limit: usize,
108 session_id: Option<&str>,
109 ) -> anyhow::Result<Vec<MemoryEntry>> {
110 let db_path = self.db_path.clone();
111 let query = query.to_string();
112 let session = session_id.map(|s| s.to_string());
113
114 tokio::task::spawn_blocking(move || -> anyhow::Result<Vec<MemoryEntry>> {
115 let conn = Connection::open(db_path)?;
116 let mut out = Vec::new();
117 let like = format!("%{}%", query);
118
119 if let Some(session_id) = session {
120 let mut stmt = conn.prepare(
121 "
122 SELECT id, key, content, category, timestamp, session_id
123 FROM memory_entries
124 WHERE (key LIKE ?1 OR content LIKE ?1)
125 AND session_id = ?2
126 ORDER BY timestamp DESC
127 LIMIT ?3
128 ",
129 )?;
130 let rows = stmt.query_map(params![like, session_id, limit as i64], |row| {
131 let category_raw: String = row.get(3)?;
132 Ok(MemoryEntry {
133 id: row.get(0)?,
134 key: row.get(1)?,
135 content: row.get(2)?,
136 category: Self::to_category(&category_raw),
137 timestamp: row.get(4)?,
138 session_id: row.get(5)?,
139 score: None,
140 })
141 })?;
142
143 for row in rows {
144 out.push(row?);
145 }
146 } else {
147 let mut stmt = conn.prepare(
148 "
149 SELECT id, key, content, category, timestamp, session_id
150 FROM memory_entries
151 WHERE key LIKE ?1 OR content LIKE ?1
152 ORDER BY timestamp DESC
153 LIMIT ?2
154 ",
155 )?;
156 let rows = stmt.query_map(params![like, limit as i64], |row| {
157 let category_raw: String = row.get(3)?;
158 Ok(MemoryEntry {
159 id: row.get(0)?,
160 key: row.get(1)?,
161 content: row.get(2)?,
162 category: Self::to_category(&category_raw),
163 timestamp: row.get(4)?,
164 session_id: row.get(5)?,
165 score: None,
166 })
167 })?;
168
169 for row in rows {
170 out.push(row?);
171 }
172 }
173
174 Ok(out)
175 })
176 .await?
177 }
178
179 async fn get(&self, key: &str) -> anyhow::Result<Option<MemoryEntry>> {
180 let db_path = self.db_path.clone();
181 let key = key.to_string();
182
183 tokio::task::spawn_blocking(move || -> anyhow::Result<Option<MemoryEntry>> {
184 let conn = Connection::open(db_path)?;
185 let mut stmt = conn.prepare(
186 "SELECT id, key, content, category, timestamp, session_id FROM memory_entries WHERE key = ?1",
187 )?;
188 let mut rows = stmt.query(params![key])?;
189 if let Some(row) = rows.next()? {
190 let category_raw: String = row.get(3)?;
191 return Ok(Some(MemoryEntry {
192 id: row.get(0)?,
193 key: row.get(1)?,
194 content: row.get(2)?,
195 category: Self::to_category(&category_raw),
196 timestamp: row.get(4)?,
197 session_id: row.get(5)?,
198 score: None,
199 }));
200 }
201 Ok(None)
202 })
203 .await?
204 }
205
206 async fn list(
207 &self,
208 category: Option<&MemoryCategory>,
209 session_id: Option<&str>,
210 ) -> anyhow::Result<Vec<MemoryEntry>> {
211 let db_path = self.db_path.clone();
212 let category = category.map(|c| c.to_string());
213 let session = session_id.map(|s| s.to_string());
214
215 tokio::task::spawn_blocking(move || -> anyhow::Result<Vec<MemoryEntry>> {
216 let conn = Connection::open(db_path)?;
217 let mut out = Vec::new();
218
219 let sql = match (category.is_some(), session.is_some()) {
220 (false, false) => "SELECT id, key, content, category, timestamp, session_id FROM memory_entries ORDER BY timestamp DESC",
221 (true, false) => "SELECT id, key, content, category, timestamp, session_id FROM memory_entries WHERE category = ?1 ORDER BY timestamp DESC",
222 (false, true) => "SELECT id, key, content, category, timestamp, session_id FROM memory_entries WHERE session_id = ?1 ORDER BY timestamp DESC",
223 (true, true) => "SELECT id, key, content, category, timestamp, session_id FROM memory_entries WHERE category = ?1 AND session_id = ?2 ORDER BY timestamp DESC",
224 };
225
226 let mut stmt = conn.prepare(sql)?;
227 let mut rows = match (category, session) {
228 (None, None) => stmt.query(params![] )?,
229 (Some(c), None) => stmt.query(params![c])?,
230 (None, Some(s)) => stmt.query(params![s])?,
231 (Some(c), Some(s)) => stmt.query(params![c, s])?,
232 };
233
234 while let Some(row) = rows.next()? {
235 let category_raw: String = row.get(3)?;
236 out.push(MemoryEntry {
237 id: row.get(0)?,
238 key: row.get(1)?,
239 content: row.get(2)?,
240 category: Self::to_category(&category_raw),
241 timestamp: row.get(4)?,
242 session_id: row.get(5)?,
243 score: None,
244 });
245 }
246
247 Ok(out)
248 })
249 .await?
250 }
251
252 async fn forget(&self, key: &str) -> anyhow::Result<bool> {
253 let db_path = self.db_path.clone();
254 let key = key.to_string();
255 tokio::task::spawn_blocking(move || -> anyhow::Result<bool> {
256 let conn = Connection::open(db_path)?;
257 let affected =
258 conn.execute("DELETE FROM memory_entries WHERE key = ?1", params![key])?;
259 Ok(affected > 0)
260 })
261 .await?
262 }
263
264 async fn count(&self) -> anyhow::Result<usize> {
265 let db_path = self.db_path.clone();
266 tokio::task::spawn_blocking(move || -> anyhow::Result<usize> {
267 let conn = Connection::open(db_path)?;
268 let count: i64 =
269 conn.query_row("SELECT COUNT(*) FROM memory_entries", params![], |r| {
270 r.get(0)
271 })?;
272 Ok(count.max(0) as usize)
273 })
274 .await?
275 }
276
277 async fn health_check(&self) -> bool {
278 self.open().is_ok()
279 }
280}