Skip to main content

engram_storage/
memory.rs

1use rusqlite::params;
2use serde::{Deserialize, Serialize};
3
4use crate::database::Database;
5use crate::error::StorageError;
6
7#[derive(Debug, Clone, Serialize, Deserialize)]
8pub struct Memory {
9    pub id: String,
10    pub memory_type: String,
11    pub context: String,
12    pub action: String,
13    pub result: String,
14    pub score: f32,
15    pub embedding_context: Option<Vec<u8>>,
16    pub embedding_action: Option<Vec<u8>>,
17    pub embedding_result: Option<Vec<u8>>,
18    pub indexed: bool,
19    pub tags: Option<String>,
20    pub project: Option<String>,
21    pub parent_id: Option<String>,
22    pub source_ids: Option<String>,
23    pub insight_type: Option<String>,
24    pub created_at: String,
25    pub updated_at: String,
26    pub used_count: i64,
27    pub last_used_at: Option<String>,
28    pub superseded_by: Option<String>,
29}
30
31pub fn row_to_memory(row: &rusqlite::Row) -> rusqlite::Result<Memory> {
32    Ok(Memory {
33        id: row.get("id")?,
34        memory_type: row.get("memory_type")?,
35        context: row.get("context")?,
36        action: row.get("action")?,
37        result: row.get("result")?,
38        score: row.get("score")?,
39        embedding_context: row.get("embedding_context")?,
40        embedding_action: row.get("embedding_action")?,
41        embedding_result: row.get("embedding_result")?,
42        indexed: row.get("indexed")?,
43        tags: row.get("tags")?,
44        project: row.get("project")?,
45        parent_id: row.get("parent_id")?,
46        source_ids: row.get("source_ids")?,
47        insight_type: row.get("insight_type")?,
48        created_at: row.get("created_at")?,
49        updated_at: row.get("updated_at")?,
50        used_count: row.get("used_count")?,
51        last_used_at: row.get("last_used_at")?,
52        superseded_by: row.get("superseded_by")?,
53    })
54}
55
/// Parameterised `INSERT` statement that writes all twenty columns of a `memories` row.
///
/// The placeholders `?1`..`?20` are positional and follow the column list
/// above them, so any parameter list bound to this statement must supply its
/// values in exactly that column order.
const INSERT_SQL: &str = r#"
    INSERT INTO memories (
        id, memory_type, context, action, result, score,
        embedding_context, embedding_action, embedding_result,
        indexed, tags, project, parent_id, source_ids, insight_type,
        created_at, updated_at, used_count, last_used_at, superseded_by
    ) VALUES (
        ?1, ?2, ?3, ?4, ?5, ?6,
        ?7, ?8, ?9,
        ?10, ?11, ?12, ?13, ?14, ?15,
        ?16, ?17, ?18, ?19, ?20
    )
"#;
69
70impl Database {
71    pub fn insert_memory(&self, memory: &Memory) -> Result<(), StorageError> {
72        self.connection()
73            .execute(
74                INSERT_SQL,
75                params![
76                    memory.id,
77                    memory.memory_type,
78                    memory.context,
79                    memory.action,
80                    memory.result,
81                    memory.score,
82                    memory.embedding_context,
83                    memory.embedding_action,
84                    memory.embedding_result,
85                    memory.indexed,
86                    memory.tags,
87                    memory.project,
88                    memory.parent_id,
89                    memory.source_ids,
90                    memory.insight_type,
91                    memory.created_at,
92                    memory.updated_at,
93                    memory.used_count,
94                    memory.last_used_at,
95                    memory.superseded_by,
96                ],
97            )
98            .map_err(|error| match error {
99                rusqlite::Error::SqliteFailure(sql_error, _)
100                    if sql_error.extended_code == rusqlite::ffi::SQLITE_CONSTRAINT_PRIMARYKEY
101                        || sql_error.extended_code == rusqlite::ffi::SQLITE_CONSTRAINT_UNIQUE =>
102                {
103                    StorageError::DuplicateKey(format!("memory id={}", memory.id))
104                }
105                other => StorageError::Sqlite(other),
106            })?;
107        Ok(())
108    }
109
110    pub fn get_memory(&self, id: &str) -> Result<Memory, StorageError> {
111        self.connection()
112            .query_row(
113                "SELECT * FROM memories WHERE id = ?1",
114                params![id],
115                row_to_memory,
116            )
117            .map_err(|error| match error {
118                rusqlite::Error::QueryReturnedNoRows => {
119                    StorageError::NotFound(format!("memory id={id}"))
120                }
121                other => StorageError::Sqlite(other),
122            })
123    }
124
125    pub fn set_memory_indexed(&self, id: &str, indexed: bool) -> Result<(), StorageError> {
126        let affected = self.connection().execute(
127            "UPDATE memories SET indexed = ?1 WHERE id = ?2",
128            params![indexed, id],
129        )?;
130        if affected == 0 {
131            return Err(StorageError::NotFound(format!("memory id={id}")));
132        }
133        Ok(())
134    }
135
136    pub fn set_memory_score(&self, id: &str, score: f32) -> Result<(), StorageError> {
137        let affected = self.connection().execute(
138            "UPDATE memories SET score = ?1 WHERE id = ?2",
139            params![score, id],
140        )?;
141        if affected == 0 {
142            return Err(StorageError::NotFound(format!("memory id={id}")));
143        }
144        Ok(())
145    }
146
147    pub fn touch_memory(&self, id: &str, timestamp: &str) -> Result<(), StorageError> {
148        let affected = self.connection().execute(
149            "UPDATE memories SET used_count = used_count + 1, last_used_at = ?1 WHERE id = ?2",
150            params![timestamp, id],
151        )?;
152        if affected == 0 {
153            return Err(StorageError::NotFound(format!("memory id={id}")));
154        }
155        Ok(())
156    }
157
158    pub fn set_superseded_by(&self, id: &str, superseded_by: &str) -> Result<(), StorageError> {
159        let affected = self.connection().execute(
160            "UPDATE memories SET superseded_by = ?1 WHERE id = ?2",
161            params![superseded_by, id],
162        )?;
163        if affected == 0 {
164            return Err(StorageError::NotFound(format!("memory id={id}")));
165        }
166        Ok(())
167    }
168
169    pub fn delete_memory(&self, id: &str) -> Result<(), StorageError> {
170        let affected = self
171            .connection()
172            .execute("DELETE FROM memories WHERE id = ?1", params![id])?;
173        if affected == 0 {
174            return Err(StorageError::NotFound(format!("memory id={id}")));
175        }
176        Ok(())
177    }
178
179    pub fn bulk_insert_memories(&self, memories: &[Memory]) -> Result<usize, StorageError> {
180        let transaction = self.connection().unchecked_transaction()?;
181        let mut statement = transaction.prepare(INSERT_SQL)?;
182        let mut count = 0;
183        for memory in memories {
184            statement.execute(params![
185                memory.id,
186                memory.memory_type,
187                memory.context,
188                memory.action,
189                memory.result,
190                memory.score,
191                memory.embedding_context,
192                memory.embedding_action,
193                memory.embedding_result,
194                memory.indexed,
195                memory.tags,
196                memory.project,
197                memory.parent_id,
198                memory.source_ids,
199                memory.insight_type,
200                memory.created_at,
201                memory.updated_at,
202                memory.used_count,
203                memory.last_used_at,
204                memory.superseded_by,
205            ])?;
206            count += 1;
207        }
208        drop(statement);
209        transaction.commit()?;
210        Ok(count)
211    }
212
213    pub fn list_all_memories(&self) -> Result<Vec<Memory>, StorageError> {
214        let mut statement = self
215            .connection()
216            .prepare("SELECT * FROM memories WHERE superseded_by IS NULL")?;
217        let rows = statement.query_map([], row_to_memory)?;
218        let mut results = Vec::new();
219        for row in rows {
220            results.push(row?);
221        }
222        Ok(results)
223    }
224
225    pub fn get_unindexed_memories(&self, limit: usize) -> Result<Vec<Memory>, StorageError> {
226        let mut statement = self
227            .connection()
228            .prepare("SELECT * FROM memories WHERE indexed = FALSE LIMIT ?1")?;
229        let rows = statement.query_map(params![limit as i64], row_to_memory)?;
230        let mut results = Vec::new();
231        for row in rows {
232            results.push(row?);
233        }
234        Ok(results)
235    }
236
237    pub fn get_indexed_memory_ids(&self) -> Result<Vec<String>, StorageError> {
238        let mut statement = self
239            .connection()
240            .prepare("SELECT id FROM memories WHERE indexed = TRUE")?;
241        let rows = statement.query_map([], |row| row.get::<_, String>(0))?;
242        let mut results = Vec::new();
243        for row in rows {
244            results.push(row?);
245        }
246        Ok(results)
247    }
248}