Skip to main content

edgechain_memory/
sqlite.rs

1use async_trait::async_trait;
2use rusqlite::{Connection, params};
3use std::sync::Mutex;
4use tracing::debug;
5use crate::{
6    error::MemoryError,
7    store::{MemoryEntry, MemoryStore},
8};
9
10/// Persistent long-term memory backed by SQLite.
11pub struct SqliteMemoryStore {
12    conn: Mutex<Connection>,
13}
14
15impl SqliteMemoryStore {
16    pub fn open(path: &str) -> Result<Self, MemoryError> {
17        let conn = Connection::open(path)?;
18        let store = Self { conn: Mutex::new(conn) };
19        store.init_schema()?;
20        Ok(store)
21    }
22
23    pub fn in_memory() -> Result<Self, MemoryError> {
24        let conn = Connection::open_in_memory()?;
25        let store = Self { conn: Mutex::new(conn) };
26        store.init_schema()?;
27        Ok(store)
28    }
29
30    fn init_schema(&self) -> Result<(), MemoryError> {
31        let conn = self.conn.lock().unwrap();
32        conn.execute_batch(
33            "CREATE TABLE IF NOT EXISTS memory (
34                id          TEXT PRIMARY KEY,
35                key         TEXT NOT NULL UNIQUE,
36                value       TEXT NOT NULL,
37                tags        TEXT NOT NULL DEFAULT '[]',
38                created_at  TEXT NOT NULL,
39                updated_at  TEXT NOT NULL
40            );
41            CREATE INDEX IF NOT EXISTS idx_memory_key ON memory(key);",
42        )?;
43        Ok(())
44    }
45
46    fn row_to_entry(
47        id: String,
48        key: String,
49        value_str: String,
50        tags_str: String,
51        created_at_str: String,
52        updated_at_str: String,
53    ) -> Result<MemoryEntry, MemoryError> {
54        Ok(MemoryEntry {
55            id,
56            key,
57            value: serde_json::from_str(&value_str)?,
58            tags: serde_json::from_str(&tags_str)?,
59            created_at: created_at_str.parse().unwrap_or_else(|_| chrono::Utc::now()),
60            updated_at: updated_at_str.parse().unwrap_or_else(|_| chrono::Utc::now()),
61        })
62    }
63}
64
65#[async_trait]
66impl MemoryStore for SqliteMemoryStore {
67    async fn set(&self, key: &str, value: serde_json::Value) -> Result<(), MemoryError> {
68        let conn = self.conn.lock().unwrap();
69        let now = chrono::Utc::now().to_rfc3339();
70        let id = uuid::Uuid::new_v4().to_string();
71        let value_str = serde_json::to_string(&value)?;
72        conn.execute(
73            "INSERT INTO memory (id, key, value, tags, created_at, updated_at)
74             VALUES (?1, ?2, ?3, '[]', ?4, ?4)
75             ON CONFLICT(key) DO UPDATE SET value = excluded.value, updated_at = excluded.updated_at",
76            params![id, key, value_str, now],
77        )?;
78        debug!(key, "SqliteMemoryStore::set");
79        Ok(())
80    }
81
82    async fn get(&self, key: &str) -> Result<Option<MemoryEntry>, MemoryError> {
83        let conn = self.conn.lock().unwrap();
84        let mut stmt = conn.prepare(
85            "SELECT id, key, value, tags, created_at, updated_at FROM memory WHERE key = ?1",
86        )?;
87        let mut rows = stmt.query(params![key])?;
88        if let Some(row) = rows.next()? {
89            let entry = Self::row_to_entry(
90                row.get(0)?, row.get(1)?, row.get(2)?,
91                row.get(3)?, row.get(4)?, row.get(5)?,
92            )?;
93            Ok(Some(entry))
94        } else {
95            Ok(None)
96        }
97    }
98
99    async fn delete(&self, key: &str) -> Result<(), MemoryError> {
100        let conn = self.conn.lock().unwrap();
101        conn.execute("DELETE FROM memory WHERE key = ?1", params![key])?;
102        Ok(())
103    }
104
105    async fn list_keys(&self, prefix: Option<&str>) -> Result<Vec<String>, MemoryError> {
106        let conn = self.conn.lock().unwrap();
107        let (sql, pattern): (&str, String) = match prefix {
108            Some(p) => ("SELECT key FROM memory WHERE key LIKE ?1 ORDER BY key", format!("{p}%")),
109            None => ("SELECT key FROM memory ORDER BY key", String::new()),
110        };
111        let mut stmt = conn.prepare(sql)?;
112        let keys = if prefix.is_some() {
113            stmt.query_map(params![pattern], |row| row.get(0))?
114                .collect::<Result<Vec<String>, _>>()?
115        } else {
116            stmt.query_map([], |row| row.get(0))?
117                .collect::<Result<Vec<String>, _>>()?
118        };
119        Ok(keys)
120    }
121
122    async fn search_by_tag(&self, tag: &str) -> Result<Vec<MemoryEntry>, MemoryError> {
123        let conn = self.conn.lock().unwrap();
124        let mut stmt = conn.prepare(
125            "SELECT id, key, value, tags, created_at, updated_at FROM memory WHERE tags LIKE ?1",
126        )?;
127        let pattern = format!("%\"{tag}\"%" );
128        let entries = stmt
129            .query_map(params![pattern], |row| {
130                Ok((
131                    row.get::<_, String>(0)?,
132                    row.get::<_, String>(1)?,
133                    row.get::<_, String>(2)?,
134                    row.get::<_, String>(3)?,
135                    row.get::<_, String>(4)?,
136                    row.get::<_, String>(5)?,
137                ))
138            })?
139            .filter_map(|r| r.ok())
140            .filter_map(|(id, key, val, tags, ca, ua)| {
141                Self::row_to_entry(id, key, val, tags, ca, ua).ok()
142            })
143            .collect();
144        Ok(entries)
145    }
146
147    async fn set_tagged(
148        &self,
149        key: &str,
150        value: serde_json::Value,
151        tags: Vec<String>,
152    ) -> Result<(), MemoryError> {
153        let conn = self.conn.lock().unwrap();
154        let now = chrono::Utc::now().to_rfc3339();
155        let id = uuid::Uuid::new_v4().to_string();
156        let value_str = serde_json::to_string(&value)?;
157        let tags_str = serde_json::to_string(&tags)?;
158        conn.execute(
159            "INSERT INTO memory (id, key, value, tags, created_at, updated_at)
160             VALUES (?1, ?2, ?3, ?4, ?5, ?5)
161             ON CONFLICT(key) DO UPDATE SET value = excluded.value, tags = excluded.tags, updated_at = excluded.updated_at",
162            params![id, key, value_str, tags_str, now],
163        )?;
164        Ok(())
165    }
166}