Skip to main content

dot/memory/
mod.rs

1pub mod extract;
2pub mod tools;
3
4use anyhow::{Context, Result, bail};
5use chrono::Utc;
6use rusqlite::{Connection, params};
7use std::fmt;
8use std::sync::{Arc, Mutex};
9use uuid::Uuid;
10
11use crate::db::schema;
12
/// Category of an archival memory, persisted in the `kind` column as its lowercase name.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MemoryKind {
    Fact,
    Preference,
    Decision,
    Project,
    Entity,
    Belief,
}

impl fmt::Display for MemoryKind {
    /// Writes the lowercase name (same text as [`MemoryKind::as_str`]).
    ///
    /// `pad` is used instead of `write_str` so width/alignment flags such as `{:<10}` are
    /// honoured, like they are for plain `&str`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.pad(self.as_str())
    }
}

impl MemoryKind {
    /// Returns the lowercase name used both for storage and for display.
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::Fact => "fact",
            Self::Preference => "preference",
            Self::Decision => "decision",
            Self::Project => "project",
            Self::Entity => "entity",
            Self::Belief => "belief",
        }
    }

    /// Parses a stored kind name (the inverse of [`MemoryKind::as_str`]).
    ///
    /// Parsing is lenient: `"fact"` and any unrecognised string map to
    /// [`MemoryKind::Fact`] rather than producing an error.
    pub fn parse(s: &str) -> Self {
        match s {
            "preference" => Self::Preference,
            "decision" => Self::Decision,
            "project" => Self::Project,
            "entity" => Self::Entity,
            "belief" => Self::Belief,
            // "fact" itself, plus anything unknown.
            _ => Self::Fact,
        }
    }
}
52
53#[derive(Debug, Clone)]
54pub struct Memory {
55    pub id: String,
56    pub content: String,
57    pub kind: MemoryKind,
58    pub importance: f32,
59    pub access_count: u32,
60    pub source_conversation_id: Option<String>,
61    pub superseded_by: Option<String>,
62    pub created_at: String,
63    pub updated_at: String,
64}
65
66#[derive(Debug, Clone)]
67pub struct ScoredMemory {
68    pub memory: Memory,
69    pub score: f64,
70}
71
/// A named core memory block from the `memory_blocks` table.
///
/// `MemoryStore` seeds empty `human` and `agent` blocks. Non-empty blocks are rendered under
/// `## Core` by `MemoryStore::inject_context`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MemoryBlock {
    /// Row id (UUID v4 string when created by the store).
    pub id: String,
    /// Block name used for lookup, e.g. `human` or `agent`.
    pub name: String,
    /// Free-form block text; may be empty.
    pub content: String,
    /// RFC 3339 timestamp of the last write.
    pub updated_at: String,
}
79
80pub struct MemoryStore {
81    conn: Arc<Mutex<Connection>>,
82}
83
84impl MemoryStore {
85    pub fn open() -> Result<Self> {
86        let path = crate::config::Config::db_path();
87        let conn = Connection::open(&path)
88            .with_context(|| format!("opening memory db at {}", path.display()))?;
89        conn.execute_batch("PRAGMA journal_mode=WAL;")
90            .context("enabling WAL mode")?;
91        let store = Self {
92            conn: Arc::new(Mutex::new(conn)),
93        };
94        store.init()?;
95        Ok(store)
96    }
97
98    fn init(&self) -> Result<()> {
99        let conn = self.conn.lock().unwrap();
100        conn.execute_batch(schema::CREATE_MEMORY_BLOCKS)
101            .context("creating memory_blocks table")?;
102        conn.execute_batch(schema::CREATE_MEMORIES)
103            .context("creating memories table")?;
104        conn.execute_batch(schema::CREATE_MEMORIES_FTS)
105            .context("creating memories_fts table")?;
106        conn.execute_batch(schema::CREATE_MEMORIES_TRIGGERS)
107            .context("creating memories triggers")?;
108
109        let count: i64 = conn
110            .query_row(
111                "SELECT COUNT(*) FROM memory_blocks WHERE name = 'human'",
112                [],
113                |r| r.get(0),
114            )
115            .unwrap_or(0);
116        if count == 0 {
117            let now = Utc::now().to_rfc3339();
118            conn.execute(
119                "INSERT INTO memory_blocks (id, name, content, updated_at) VALUES (?1, 'human', '', ?2)",
120                params![Uuid::new_v4().to_string(), now],
121            ).context("creating default human block")?;
122            conn.execute(
123                "INSERT INTO memory_blocks (id, name, content, updated_at) VALUES (?1, 'agent', '', ?2)",
124                params![Uuid::new_v4().to_string(), now],
125            ).context("creating default agent block")?;
126        }
127        Ok(())
128    }
129
130    pub fn get_block(&self, name: &str) -> Result<MemoryBlock> {
131        let conn = self.conn.lock().unwrap();
132        conn.query_row(
133            "SELECT id, name, content, updated_at FROM memory_blocks WHERE name = ?1",
134            params![name],
135            |row| {
136                Ok(MemoryBlock {
137                    id: row.get(0)?,
138                    name: row.get(1)?,
139                    content: row.get(2)?,
140                    updated_at: row.get(3)?,
141                })
142            },
143        )
144        .with_context(|| format!("getting memory block '{name}'"))
145    }
146
147    pub fn blocks(&self) -> Result<Vec<MemoryBlock>> {
148        let conn = self.conn.lock().unwrap();
149        let mut stmt = conn
150            .prepare("SELECT id, name, content, updated_at FROM memory_blocks ORDER BY name")
151            .context("preparing blocks query")?;
152        let rows = stmt
153            .query_map([], |row| {
154                Ok(MemoryBlock {
155                    id: row.get(0)?,
156                    name: row.get(1)?,
157                    content: row.get(2)?,
158                    updated_at: row.get(3)?,
159                })
160            })
161            .context("querying blocks")?;
162        let mut blocks = Vec::new();
163        for row in rows {
164            blocks.push(row.context("reading block row")?);
165        }
166        Ok(blocks)
167    }
168
169    pub fn update_block(&self, name: &str, content: &str) -> Result<()> {
170        let conn = self.conn.lock().unwrap();
171        let now = Utc::now().to_rfc3339();
172        let affected = conn
173            .execute(
174                "UPDATE memory_blocks SET content = ?1, updated_at = ?2 WHERE name = ?3",
175                params![content, now, name],
176            )
177            .context("updating memory block")?;
178        if affected == 0 {
179            bail!("no memory block named '{name}'");
180        }
181        Ok(())
182    }
183
184    pub fn add(
185        &self,
186        content: &str,
187        kind: &MemoryKind,
188        importance: f32,
189        source: Option<&str>,
190    ) -> Result<String> {
191        let id = Uuid::new_v4().to_string();
192        let now = Utc::now().to_rfc3339();
193        let conn = self.conn.lock().unwrap();
194        conn.execute(
195            "INSERT INTO memories (id, content, kind, importance, access_count, source_conversation_id, created_at, updated_at) \
196             VALUES (?1, ?2, ?3, ?4, 0, ?5, ?6, ?7)",
197            params![id, content, kind.as_str(), importance, source, now, now],
198        )
199        .context("adding memory")?;
200        tracing::debug!("added memory {id}: {}", &content[..content.len().min(60)]);
201        Ok(id)
202    }
203
204    pub fn update(&self, id: &str, content: &str, importance: f32) -> Result<()> {
205        let conn = self.conn.lock().unwrap();
206        let now = Utc::now().to_rfc3339();
207        let affected = conn
208            .execute(
209                "UPDATE memories SET content = ?1, importance = ?2, updated_at = ?3 WHERE id = ?4 AND superseded_by IS NULL",
210                params![content, importance, now, id],
211            )
212            .context("updating memory")?;
213        if affected == 0 {
214            bail!("memory '{id}' not found or superseded");
215        }
216        Ok(())
217    }
218
219    pub fn delete(&self, id: &str) -> Result<()> {
220        let conn = self.conn.lock().unwrap();
221        conn.execute("DELETE FROM memories WHERE id = ?1", params![id])
222            .context("deleting memory")?;
223        Ok(())
224    }
225
226    pub fn supersede(&self, old_id: &str, new_id: &str) -> Result<()> {
227        let conn = self.conn.lock().unwrap();
228        conn.execute(
229            "UPDATE memories SET superseded_by = ?1 WHERE id = ?2",
230            params![new_id, old_id],
231        )
232        .context("superseding memory")?;
233        Ok(())
234    }
235
236    pub fn search(&self, query: &str, limit: usize) -> Result<Vec<ScoredMemory>> {
237        let conn = self.conn.lock().unwrap();
238        let fts_query = query
239            .split_whitespace()
240            .map(|w| format!("\"{}\"", w.replace('"', "")))
241            .collect::<Vec<_>>()
242            .join(" OR ");
243        if fts_query.is_empty() {
244            return Ok(Vec::new());
245        }
246        let mut stmt = conn
247            .prepare(
248                "SELECT m.id, m.content, m.kind, m.importance, m.access_count,
249                        m.source_conversation_id, m.superseded_by, m.created_at, m.updated_at,
250                        (-bm25(memories_fts)) * 0.5
251                        + m.importance * (0.95 / (1.0 + (julianday('now') - julianday(m.updated_at)) / 7.0)) * 0.35
252                        + MIN(1.0, CAST(m.access_count AS REAL) / 10.0) * 0.15 AS score
253                 FROM memories m
254                 JOIN memories_fts ON memories_fts.rowid = m.rowid
255                 WHERE memories_fts MATCH ?1
256                   AND m.superseded_by IS NULL
257                 ORDER BY score DESC LIMIT ?2",
258            )
259            .context("preparing memory search")?;
260        let rows = stmt
261            .query_map(params![fts_query, limit as i64], |row| {
262                Ok(ScoredMemory {
263                    memory: Memory {
264                        id: row.get(0)?,
265                        content: row.get(1)?,
266                        kind: MemoryKind::parse(row.get::<_, String>(2)?.as_str()),
267                        importance: row.get(3)?,
268                        access_count: row.get::<_, i64>(4)? as u32,
269                        source_conversation_id: row.get(5)?,
270                        superseded_by: row.get(6)?,
271                        created_at: row.get(7)?,
272                        updated_at: row.get(8)?,
273                    },
274                    score: row.get(9)?,
275                })
276            })
277            .context("executing memory search")?;
278        let mut results = Vec::new();
279        for row in rows {
280            results.push(row.context("reading memory search row")?);
281        }
282        // bump access counts
283        for r in &results {
284            let _ = conn.execute(
285                "UPDATE memories SET access_count = access_count + 1 WHERE id = ?1",
286                params![r.memory.id],
287            );
288        }
289        Ok(results)
290    }
291
292    pub fn list(&self, kind: Option<&MemoryKind>, limit: usize) -> Result<Vec<Memory>> {
293        let conn = self.conn.lock().unwrap();
294        let (sql, kind_val);
295        if let Some(k) = kind {
296            kind_val = k.as_str().to_string();
297            sql = "SELECT id, content, kind, importance, access_count, source_conversation_id, superseded_by, created_at, updated_at \
298                   FROM memories WHERE kind = ?1 AND superseded_by IS NULL ORDER BY updated_at DESC LIMIT ?2";
299        } else {
300            kind_val = String::new();
301            sql = "SELECT id, content, kind, importance, access_count, source_conversation_id, superseded_by, created_at, updated_at \
302                   FROM memories WHERE superseded_by IS NULL ORDER BY updated_at DESC LIMIT ?2";
303        }
304        let mut stmt = conn.prepare(sql).context("preparing memory list")?;
305        let rows = if kind.is_some() {
306            stmt.query_map(params![kind_val, limit as i64], map_memory_row)
307                .context("listing memories")?
308        } else {
309            stmt.query_map(params![kind_val, limit as i64], map_memory_row)
310                .context("listing memories")?
311        };
312        let mut memories = Vec::new();
313        for row in rows {
314            memories.push(row.context("reading memory list row")?);
315        }
316        Ok(memories)
317    }
318
319    pub fn snapshot(&self, limit: usize) -> Result<Vec<Memory>> {
320        let conn = self.conn.lock().unwrap();
321        let mut stmt = conn
322            .prepare(
323                "SELECT id, content, kind, importance, access_count, source_conversation_id, superseded_by, created_at, updated_at \
324                 FROM memories WHERE superseded_by IS NULL ORDER BY importance DESC, updated_at DESC LIMIT ?1",
325            )
326            .context("preparing memory snapshot")?;
327        let rows = stmt
328            .query_map(params![limit as i64], map_memory_row)
329            .context("querying memory snapshot")?;
330        let mut memories = Vec::new();
331        for row in rows {
332            memories.push(row.context("reading snapshot row")?);
333        }
334        Ok(memories)
335    }
336
337    pub fn count(&self) -> Result<usize> {
338        let conn = self.conn.lock().unwrap();
339        let count: i64 = conn
340            .query_row(
341                "SELECT COUNT(*) FROM memories WHERE superseded_by IS NULL",
342                [],
343                |r| r.get(0),
344            )
345            .context("counting memories")?;
346        Ok(count as usize)
347    }
348
349    pub fn inject_context(&self, query: &str, count: usize) -> Result<String> {
350        let mut out = String::new();
351
352        // Core blocks
353        let blocks = self.blocks()?;
354        let has_core = blocks.iter().any(|b| !b.content.is_empty());
355        if has_core {
356            out.push_str("<memory>\n## Core\n");
357            for block in &blocks {
358                if !block.content.is_empty() {
359                    out.push_str(&format!("[{}]\n{}\n\n", block.name, block.content));
360                }
361            }
362        }
363
364        // Archival search
365        let results = self.search(query, count)?;
366        if !results.is_empty() {
367            if !has_core {
368                out.push_str("<memory>\n");
369            }
370            out.push_str("## Relevant Context\n");
371            for r in &results {
372                out.push_str(&format!(
373                    "[{}] {} ({:.2})\n",
374                    r.memory.kind, r.memory.content, r.memory.importance
375                ));
376            }
377        }
378
379        if !out.is_empty() {
380            out.push_str("</memory>");
381        }
382        Ok(out)
383    }
384}
385
386fn map_memory_row(row: &rusqlite::Row) -> rusqlite::Result<Memory> {
387    Ok(Memory {
388        id: row.get(0)?,
389        content: row.get(1)?,
390        kind: MemoryKind::parse(row.get::<_, String>(2)?.as_str()),
391        importance: row.get(3)?,
392        access_count: row.get::<_, i64>(4)? as u32,
393        source_conversation_id: row.get(5)?,
394        superseded_by: row.get(6)?,
395        created_at: row.get(7)?,
396        updated_at: row.get(8)?,
397    })
398}
399
400unsafe impl Send for MemoryStore {}
401unsafe impl Sync for MemoryStore {}