Skip to main content

engram_storage/
memory.rs

1use rusqlite::params;
2use serde::{Deserialize, Serialize};
3
4use crate::database::Database;
5use crate::error::StorageError;
6
7#[derive(Debug, Clone, Serialize, Deserialize)]
8pub struct Memory {
9    pub id: String,
10    pub memory_type: String,
11    pub context: String,
12    pub action: String,
13    pub result: String,
14    pub score: f32,
15    pub embedding_context: Option<Vec<u8>>,
16    pub embedding_action: Option<Vec<u8>>,
17    pub embedding_result: Option<Vec<u8>>,
18    pub indexed: bool,
19    pub tags: Option<String>,
20    pub project: Option<String>,
21    pub parent_id: Option<String>,
22    pub source_ids: Option<String>,
23    pub insight_type: Option<String>,
24    pub created_at: String,
25    pub updated_at: String,
26    pub used_count: i64,
27    pub last_used_at: Option<String>,
28    pub superseded_by: Option<String>,
29}
30
31pub fn row_to_memory(row: &rusqlite::Row) -> rusqlite::Result<Memory> {
32    Ok(Memory {
33        id: row.get("id")?,
34        memory_type: row.get("memory_type")?,
35        context: row.get("context")?,
36        action: row.get("action")?,
37        result: row.get("result")?,
38        score: row.get("score")?,
39        embedding_context: row.get("embedding_context")?,
40        embedding_action: row.get("embedding_action")?,
41        embedding_result: row.get("embedding_result")?,
42        indexed: row.get("indexed")?,
43        tags: row.get("tags")?,
44        project: row.get("project")?,
45        parent_id: row.get("parent_id")?,
46        source_ids: row.get("source_ids")?,
47        insight_type: row.get("insight_type")?,
48        created_at: row.get("created_at")?,
49        updated_at: row.get("updated_at")?,
50        used_count: row.get("used_count")?,
51        last_used_at: row.get("last_used_at")?,
52        superseded_by: row.get("superseded_by")?,
53    })
54}
55
/// Insert statement for a full [`Memory`] row.
///
/// Positional placeholders `?1`..`?20` follow the column list exactly, so the
/// bound parameters must be supplied in the same order as the `Memory` fields.
const INSERT_SQL: &str = r#"
    INSERT INTO memories (
        id, memory_type, context, action, result, score,
        embedding_context, embedding_action, embedding_result,
        indexed, tags, project, parent_id, source_ids, insight_type,
        created_at, updated_at, used_count, last_used_at, superseded_by
    ) VALUES (
        ?1, ?2, ?3, ?4, ?5, ?6,
        ?7, ?8, ?9,
        ?10, ?11, ?12, ?13, ?14, ?15,
        ?16, ?17, ?18, ?19, ?20
    )
"#;
69
70impl Database {
71    pub fn insert_memory(&self, memory: &Memory) -> Result<(), StorageError> {
72        self.connection()
73            .execute(
74                INSERT_SQL,
75                params![
76                    memory.id,
77                    memory.memory_type,
78                    memory.context,
79                    memory.action,
80                    memory.result,
81                    memory.score,
82                    memory.embedding_context,
83                    memory.embedding_action,
84                    memory.embedding_result,
85                    memory.indexed,
86                    memory.tags,
87                    memory.project,
88                    memory.parent_id,
89                    memory.source_ids,
90                    memory.insight_type,
91                    memory.created_at,
92                    memory.updated_at,
93                    memory.used_count,
94                    memory.last_used_at,
95                    memory.superseded_by,
96                ],
97            )
98            .map_err(|error| match error {
99                rusqlite::Error::SqliteFailure(sql_error, _)
100                    if sql_error.extended_code == rusqlite::ffi::SQLITE_CONSTRAINT_PRIMARYKEY
101                        || sql_error.extended_code == rusqlite::ffi::SQLITE_CONSTRAINT_UNIQUE =>
102                {
103                    StorageError::DuplicateKey(format!("memory id={}", memory.id))
104                }
105                other => StorageError::Sqlite(other),
106            })?;
107        Ok(())
108    }
109
110    pub fn get_memory(&self, id: &str) -> Result<Memory, StorageError> {
111        self.connection()
112            .query_row(
113                "SELECT * FROM memories WHERE id = ?1",
114                params![id],
115                row_to_memory,
116            )
117            .map_err(|error| match error {
118                rusqlite::Error::QueryReturnedNoRows => {
119                    StorageError::NotFound(format!("memory id={id}"))
120                }
121                other => StorageError::Sqlite(other),
122            })
123    }
124
125    pub fn set_memory_indexed(&self, id: &str, indexed: bool) -> Result<(), StorageError> {
126        let affected = self.connection().execute(
127            "UPDATE memories SET indexed = ?1 WHERE id = ?2",
128            params![indexed, id],
129        )?;
130        if affected == 0 {
131            return Err(StorageError::NotFound(format!("memory id={id}")));
132        }
133        Ok(())
134    }
135
136    pub fn set_memory_embeddings(
137        &self,
138        id: &str,
139        embedding_context: &[u8],
140        embedding_action: &[u8],
141        embedding_result: &[u8],
142    ) -> Result<(), StorageError> {
143        let affected = self.connection().execute(
144            "UPDATE memories
145             SET embedding_context = ?1,
146                 embedding_action = ?2,
147                 embedding_result = ?3
148             WHERE id = ?4",
149            params![embedding_context, embedding_action, embedding_result, id],
150        )?;
151        if affected == 0 {
152            return Err(StorageError::NotFound(format!("memory id={id}")));
153        }
154        Ok(())
155    }
156
157    pub fn set_memory_score(&self, id: &str, score: f32) -> Result<(), StorageError> {
158        let affected = self.connection().execute(
159            "UPDATE memories SET score = ?1 WHERE id = ?2",
160            params![score, id],
161        )?;
162        if affected == 0 {
163            return Err(StorageError::NotFound(format!("memory id={id}")));
164        }
165        Ok(())
166    }
167
168    pub fn touch_memory(&self, id: &str, timestamp: &str) -> Result<(), StorageError> {
169        let affected = self.connection().execute(
170            "UPDATE memories SET used_count = used_count + 1, last_used_at = ?1 WHERE id = ?2",
171            params![timestamp, id],
172        )?;
173        if affected == 0 {
174            return Err(StorageError::NotFound(format!("memory id={id}")));
175        }
176        Ok(())
177    }
178
179    pub fn set_superseded_by(&self, id: &str, superseded_by: &str) -> Result<(), StorageError> {
180        let affected = self.connection().execute(
181            "UPDATE memories SET superseded_by = ?1 WHERE id = ?2",
182            params![superseded_by, id],
183        )?;
184        if affected == 0 {
185            return Err(StorageError::NotFound(format!("memory id={id}")));
186        }
187        Ok(())
188    }
189
190    pub fn delete_memory(&self, id: &str) -> Result<(), StorageError> {
191        let affected = self
192            .connection()
193            .execute("DELETE FROM memories WHERE id = ?1", params![id])?;
194        if affected == 0 {
195            return Err(StorageError::NotFound(format!("memory id={id}")));
196        }
197        Ok(())
198    }
199
200    pub fn bulk_insert_memories(&self, memories: &[Memory]) -> Result<usize, StorageError> {
201        let transaction = self.connection().unchecked_transaction()?;
202        let mut statement = transaction.prepare(INSERT_SQL)?;
203        let mut count = 0;
204        for memory in memories {
205            statement.execute(params![
206                memory.id,
207                memory.memory_type,
208                memory.context,
209                memory.action,
210                memory.result,
211                memory.score,
212                memory.embedding_context,
213                memory.embedding_action,
214                memory.embedding_result,
215                memory.indexed,
216                memory.tags,
217                memory.project,
218                memory.parent_id,
219                memory.source_ids,
220                memory.insight_type,
221                memory.created_at,
222                memory.updated_at,
223                memory.used_count,
224                memory.last_used_at,
225                memory.superseded_by,
226            ])?;
227            count += 1;
228        }
229        drop(statement);
230        transaction.commit()?;
231        Ok(count)
232    }
233
234    pub fn list_all_memories(&self) -> Result<Vec<Memory>, StorageError> {
235        let mut statement = self
236            .connection()
237            .prepare("SELECT * FROM memories WHERE superseded_by IS NULL")?;
238        let rows = statement.query_map([], row_to_memory)?;
239        let mut results = Vec::new();
240        for row in rows {
241            results.push(row?);
242        }
243        Ok(results)
244    }
245
246    pub fn get_unindexed_memories(&self, limit: usize) -> Result<Vec<Memory>, StorageError> {
247        let mut statement = self
248            .connection()
249            .prepare("SELECT * FROM memories WHERE indexed = FALSE LIMIT ?1")?;
250        let rows = statement.query_map(params![limit as i64], row_to_memory)?;
251        let mut results = Vec::new();
252        for row in rows {
253            results.push(row?);
254        }
255        Ok(results)
256    }
257
258    pub fn get_indexed_memory_ids(&self) -> Result<Vec<String>, StorageError> {
259        let mut statement = self
260            .connection()
261            .prepare("SELECT id FROM memories WHERE indexed = TRUE")?;
262        let rows = statement.query_map([], |row| row.get::<_, String>(0))?;
263        let mut results = Vec::new();
264        for row in rows {
265            results.push(row?);
266        }
267        Ok(results)
268    }
269}