1use rusqlite::params;
2use serde::{Deserialize, Serialize};
3
4use crate::database::Database;
5use crate::error::StorageError;
6
/// One row of the `memories` table.
///
/// Field order mirrors the column order used by `INSERT_SQL`, and every field is read back
/// by name in `row_to_memory`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Memory {
    /// Primary lookup key (`WHERE id = ?` in every single-row query).
    pub id: String,
    /// Kind of memory. NOTE(review): the set of valid values is not defined in this file.
    pub memory_type: String,
    /// Free-text context of the memory.
    pub context: String,
    /// Free-text action of the memory.
    pub action: String,
    /// Free-text result of the memory.
    pub result: String,
    /// Numeric score, overwritten by `Database::set_memory_score`.
    pub score: f32,
    /// Raw embedding bytes for `context`, written by `Database::set_memory_embeddings`.
    /// NOTE(review): the byte encoding (e.g. packed f32s) is not defined here; confirm.
    pub embedding_context: Option<Vec<u8>>,
    /// Raw embedding bytes for `action`, written by `Database::set_memory_embeddings`.
    pub embedding_action: Option<Vec<u8>>,
    /// Raw embedding bytes for `result`, written by `Database::set_memory_embeddings`.
    pub embedding_result: Option<Vec<u8>>,
    /// Indexing flag, toggled by `Database::set_memory_indexed` and filtered on by
    /// `get_unindexed_memories` / `get_indexed_memory_ids`.
    pub indexed: bool,
    /// Optional tags. NOTE(review): presumably a serialized list; format not defined here.
    pub tags: Option<String>,
    /// Optional project name.
    pub project: Option<String>,
    /// Optional id of a parent memory. NOTE(review): presumably references `memories.id`.
    pub parent_id: Option<String>,
    /// Optional source memory ids. NOTE(review): presumably a serialized list; confirm format.
    pub source_ids: Option<String>,
    /// Optional insight category.
    pub insight_type: Option<String>,
    /// Creation timestamp, stored as text. NOTE(review): format is chosen by the caller.
    pub created_at: String,
    /// Last-update timestamp, stored as text.
    pub updated_at: String,
    /// Usage counter, incremented by one on each `Database::touch_memory`.
    pub used_count: i64,
    /// Timestamp passed to the most recent `Database::touch_memory`, if any.
    pub last_used_at: Option<String>,
    /// Id recorded by `Database::set_superseded_by`. Rows where this is set are excluded
    /// from `Database::list_all_memories`.
    pub superseded_by: Option<String>,
}
30
31pub fn row_to_memory(row: &rusqlite::Row) -> rusqlite::Result<Memory> {
32 Ok(Memory {
33 id: row.get("id")?,
34 memory_type: row.get("memory_type")?,
35 context: row.get("context")?,
36 action: row.get("action")?,
37 result: row.get("result")?,
38 score: row.get("score")?,
39 embedding_context: row.get("embedding_context")?,
40 embedding_action: row.get("embedding_action")?,
41 embedding_result: row.get("embedding_result")?,
42 indexed: row.get("indexed")?,
43 tags: row.get("tags")?,
44 project: row.get("project")?,
45 parent_id: row.get("parent_id")?,
46 source_ids: row.get("source_ids")?,
47 insight_type: row.get("insight_type")?,
48 created_at: row.get("created_at")?,
49 updated_at: row.get("updated_at")?,
50 used_count: row.get("used_count")?,
51 last_used_at: row.get("last_used_at")?,
52 superseded_by: row.get("superseded_by")?,
53 })
54}
55
/// Parameterized INSERT covering all twenty `memories` columns.
///
/// Placeholders `?1`..`?20` follow the column list order, which is also the field order of
/// [`Memory`]. Callers must bind values in exactly that order.
const INSERT_SQL: &str = r#"
 INSERT INTO memories (
 id, memory_type, context, action, result, score,
 embedding_context, embedding_action, embedding_result,
 indexed, tags, project, parent_id, source_ids, insight_type,
 created_at, updated_at, used_count, last_used_at, superseded_by
 ) VALUES (
 ?1, ?2, ?3, ?4, ?5, ?6,
 ?7, ?8, ?9,
 ?10, ?11, ?12, ?13, ?14, ?15,
 ?16, ?17, ?18, ?19, ?20
 )
"#;
69
70impl Database {
71 pub fn insert_memory(&self, memory: &Memory) -> Result<(), StorageError> {
72 self.connection()
73 .execute(
74 INSERT_SQL,
75 params![
76 memory.id,
77 memory.memory_type,
78 memory.context,
79 memory.action,
80 memory.result,
81 memory.score,
82 memory.embedding_context,
83 memory.embedding_action,
84 memory.embedding_result,
85 memory.indexed,
86 memory.tags,
87 memory.project,
88 memory.parent_id,
89 memory.source_ids,
90 memory.insight_type,
91 memory.created_at,
92 memory.updated_at,
93 memory.used_count,
94 memory.last_used_at,
95 memory.superseded_by,
96 ],
97 )
98 .map_err(|error| match error {
99 rusqlite::Error::SqliteFailure(sql_error, _)
100 if sql_error.extended_code == rusqlite::ffi::SQLITE_CONSTRAINT_PRIMARYKEY
101 || sql_error.extended_code == rusqlite::ffi::SQLITE_CONSTRAINT_UNIQUE =>
102 {
103 StorageError::DuplicateKey(format!("memory id={}", memory.id))
104 }
105 other => StorageError::Sqlite(other),
106 })?;
107 Ok(())
108 }
109
110 pub fn get_memory(&self, id: &str) -> Result<Memory, StorageError> {
111 self.connection()
112 .query_row(
113 "SELECT * FROM memories WHERE id = ?1",
114 params![id],
115 row_to_memory,
116 )
117 .map_err(|error| match error {
118 rusqlite::Error::QueryReturnedNoRows => {
119 StorageError::NotFound(format!("memory id={id}"))
120 }
121 other => StorageError::Sqlite(other),
122 })
123 }
124
125 pub fn set_memory_indexed(&self, id: &str, indexed: bool) -> Result<(), StorageError> {
126 let affected = self.connection().execute(
127 "UPDATE memories SET indexed = ?1 WHERE id = ?2",
128 params![indexed, id],
129 )?;
130 if affected == 0 {
131 return Err(StorageError::NotFound(format!("memory id={id}")));
132 }
133 Ok(())
134 }
135
136 pub fn set_memory_embeddings(
137 &self,
138 id: &str,
139 embedding_context: &[u8],
140 embedding_action: &[u8],
141 embedding_result: &[u8],
142 ) -> Result<(), StorageError> {
143 let affected = self.connection().execute(
144 "UPDATE memories
145 SET embedding_context = ?1,
146 embedding_action = ?2,
147 embedding_result = ?3
148 WHERE id = ?4",
149 params![embedding_context, embedding_action, embedding_result, id],
150 )?;
151 if affected == 0 {
152 return Err(StorageError::NotFound(format!("memory id={id}")));
153 }
154 Ok(())
155 }
156
157 pub fn set_memory_score(&self, id: &str, score: f32) -> Result<(), StorageError> {
158 let affected = self.connection().execute(
159 "UPDATE memories SET score = ?1 WHERE id = ?2",
160 params![score, id],
161 )?;
162 if affected == 0 {
163 return Err(StorageError::NotFound(format!("memory id={id}")));
164 }
165 Ok(())
166 }
167
168 pub fn touch_memory(&self, id: &str, timestamp: &str) -> Result<(), StorageError> {
169 let affected = self.connection().execute(
170 "UPDATE memories SET used_count = used_count + 1, last_used_at = ?1 WHERE id = ?2",
171 params![timestamp, id],
172 )?;
173 if affected == 0 {
174 return Err(StorageError::NotFound(format!("memory id={id}")));
175 }
176 Ok(())
177 }
178
179 pub fn set_superseded_by(&self, id: &str, superseded_by: &str) -> Result<(), StorageError> {
180 let affected = self.connection().execute(
181 "UPDATE memories SET superseded_by = ?1 WHERE id = ?2",
182 params![superseded_by, id],
183 )?;
184 if affected == 0 {
185 return Err(StorageError::NotFound(format!("memory id={id}")));
186 }
187 Ok(())
188 }
189
190 pub fn delete_memory(&self, id: &str) -> Result<(), StorageError> {
191 let affected = self
192 .connection()
193 .execute("DELETE FROM memories WHERE id = ?1", params![id])?;
194 if affected == 0 {
195 return Err(StorageError::NotFound(format!("memory id={id}")));
196 }
197 Ok(())
198 }
199
200 pub fn bulk_insert_memories(&self, memories: &[Memory]) -> Result<usize, StorageError> {
201 let transaction = self.connection().unchecked_transaction()?;
202 let mut statement = transaction.prepare(INSERT_SQL)?;
203 let mut count = 0;
204 for memory in memories {
205 statement.execute(params![
206 memory.id,
207 memory.memory_type,
208 memory.context,
209 memory.action,
210 memory.result,
211 memory.score,
212 memory.embedding_context,
213 memory.embedding_action,
214 memory.embedding_result,
215 memory.indexed,
216 memory.tags,
217 memory.project,
218 memory.parent_id,
219 memory.source_ids,
220 memory.insight_type,
221 memory.created_at,
222 memory.updated_at,
223 memory.used_count,
224 memory.last_used_at,
225 memory.superseded_by,
226 ])?;
227 count += 1;
228 }
229 drop(statement);
230 transaction.commit()?;
231 Ok(count)
232 }
233
234 pub fn list_all_memories(&self) -> Result<Vec<Memory>, StorageError> {
235 let mut statement = self
236 .connection()
237 .prepare("SELECT * FROM memories WHERE superseded_by IS NULL")?;
238 let rows = statement.query_map([], row_to_memory)?;
239 let mut results = Vec::new();
240 for row in rows {
241 results.push(row?);
242 }
243 Ok(results)
244 }
245
246 pub fn get_unindexed_memories(&self, limit: usize) -> Result<Vec<Memory>, StorageError> {
247 let mut statement = self
248 .connection()
249 .prepare("SELECT * FROM memories WHERE indexed = FALSE LIMIT ?1")?;
250 let rows = statement.query_map(params![limit as i64], row_to_memory)?;
251 let mut results = Vec::new();
252 for row in rows {
253 results.push(row?);
254 }
255 Ok(results)
256 }
257
258 pub fn get_indexed_memory_ids(&self) -> Result<Vec<String>, StorageError> {
259 let mut statement = self
260 .connection()
261 .prepare("SELECT id FROM memories WHERE indexed = TRUE")?;
262 let rows = statement.query_map([], |row| row.get::<_, String>(0))?;
263 let mut results = Vec::new();
264 for row in rows {
265 results.push(row?);
266 }
267 Ok(results)
268 }
269}