1use async_trait::async_trait;
2use rusqlite::{Connection, params};
3use std::sync::Mutex;
4use tracing::debug;
5use crate::{
6 error::MemoryError,
7 store::{MemoryEntry, MemoryStore},
8};
9
10pub struct SqliteMemoryStore {
12 conn: Mutex<Connection>,
13}
14
15impl SqliteMemoryStore {
16 pub fn open(path: &str) -> Result<Self, MemoryError> {
17 let conn = Connection::open(path)?;
18 let store = Self { conn: Mutex::new(conn) };
19 store.init_schema()?;
20 Ok(store)
21 }
22
23 pub fn in_memory() -> Result<Self, MemoryError> {
24 let conn = Connection::open_in_memory()?;
25 let store = Self { conn: Mutex::new(conn) };
26 store.init_schema()?;
27 Ok(store)
28 }
29
30 fn init_schema(&self) -> Result<(), MemoryError> {
31 let conn = self.conn.lock().unwrap();
32 conn.execute_batch(
33 "CREATE TABLE IF NOT EXISTS memory (
34 id TEXT PRIMARY KEY,
35 key TEXT NOT NULL UNIQUE,
36 value TEXT NOT NULL,
37 tags TEXT NOT NULL DEFAULT '[]',
38 created_at TEXT NOT NULL,
39 updated_at TEXT NOT NULL
40 );
41 CREATE INDEX IF NOT EXISTS idx_memory_key ON memory(key);",
42 )?;
43 Ok(())
44 }
45
46 fn row_to_entry(
47 id: String,
48 key: String,
49 value_str: String,
50 tags_str: String,
51 created_at_str: String,
52 updated_at_str: String,
53 ) -> Result<MemoryEntry, MemoryError> {
54 Ok(MemoryEntry {
55 id,
56 key,
57 value: serde_json::from_str(&value_str)?,
58 tags: serde_json::from_str(&tags_str)?,
59 created_at: created_at_str.parse().unwrap_or_else(|_| chrono::Utc::now()),
60 updated_at: updated_at_str.parse().unwrap_or_else(|_| chrono::Utc::now()),
61 })
62 }
63}
64
65#[async_trait]
66impl MemoryStore for SqliteMemoryStore {
67 async fn set(&self, key: &str, value: serde_json::Value) -> Result<(), MemoryError> {
68 let conn = self.conn.lock().unwrap();
69 let now = chrono::Utc::now().to_rfc3339();
70 let id = uuid::Uuid::new_v4().to_string();
71 let value_str = serde_json::to_string(&value)?;
72 conn.execute(
73 "INSERT INTO memory (id, key, value, tags, created_at, updated_at)
74 VALUES (?1, ?2, ?3, '[]', ?4, ?4)
75 ON CONFLICT(key) DO UPDATE SET value = excluded.value, updated_at = excluded.updated_at",
76 params![id, key, value_str, now],
77 )?;
78 debug!(key, "SqliteMemoryStore::set");
79 Ok(())
80 }
81
82 async fn get(&self, key: &str) -> Result<Option<MemoryEntry>, MemoryError> {
83 let conn = self.conn.lock().unwrap();
84 let mut stmt = conn.prepare(
85 "SELECT id, key, value, tags, created_at, updated_at FROM memory WHERE key = ?1",
86 )?;
87 let mut rows = stmt.query(params![key])?;
88 if let Some(row) = rows.next()? {
89 let entry = Self::row_to_entry(
90 row.get(0)?, row.get(1)?, row.get(2)?,
91 row.get(3)?, row.get(4)?, row.get(5)?,
92 )?;
93 Ok(Some(entry))
94 } else {
95 Ok(None)
96 }
97 }
98
99 async fn delete(&self, key: &str) -> Result<(), MemoryError> {
100 let conn = self.conn.lock().unwrap();
101 conn.execute("DELETE FROM memory WHERE key = ?1", params![key])?;
102 Ok(())
103 }
104
105 async fn list_keys(&self, prefix: Option<&str>) -> Result<Vec<String>, MemoryError> {
106 let conn = self.conn.lock().unwrap();
107 let (sql, pattern): (&str, String) = match prefix {
108 Some(p) => ("SELECT key FROM memory WHERE key LIKE ?1 ORDER BY key", format!("{p}%")),
109 None => ("SELECT key FROM memory ORDER BY key", String::new()),
110 };
111 let mut stmt = conn.prepare(sql)?;
112 let keys = if prefix.is_some() {
113 stmt.query_map(params![pattern], |row| row.get(0))?
114 .collect::<Result<Vec<String>, _>>()?
115 } else {
116 stmt.query_map([], |row| row.get(0))?
117 .collect::<Result<Vec<String>, _>>()?
118 };
119 Ok(keys)
120 }
121
122 async fn search_by_tag(&self, tag: &str) -> Result<Vec<MemoryEntry>, MemoryError> {
123 let conn = self.conn.lock().unwrap();
124 let mut stmt = conn.prepare(
125 "SELECT id, key, value, tags, created_at, updated_at FROM memory WHERE tags LIKE ?1",
126 )?;
127 let pattern = format!("%\"{tag}\"%" );
128 let entries = stmt
129 .query_map(params![pattern], |row| {
130 Ok((
131 row.get::<_, String>(0)?,
132 row.get::<_, String>(1)?,
133 row.get::<_, String>(2)?,
134 row.get::<_, String>(3)?,
135 row.get::<_, String>(4)?,
136 row.get::<_, String>(5)?,
137 ))
138 })?
139 .filter_map(|r| r.ok())
140 .filter_map(|(id, key, val, tags, ca, ua)| {
141 Self::row_to_entry(id, key, val, tags, ca, ua).ok()
142 })
143 .collect();
144 Ok(entries)
145 }
146
147 async fn set_tagged(
148 &self,
149 key: &str,
150 value: serde_json::Value,
151 tags: Vec<String>,
152 ) -> Result<(), MemoryError> {
153 let conn = self.conn.lock().unwrap();
154 let now = chrono::Utc::now().to_rfc3339();
155 let id = uuid::Uuid::new_v4().to_string();
156 let value_str = serde_json::to_string(&value)?;
157 let tags_str = serde_json::to_string(&tags)?;
158 conn.execute(
159 "INSERT INTO memory (id, key, value, tags, created_at, updated_at)
160 VALUES (?1, ?2, ?3, ?4, ?5, ?5)
161 ON CONFLICT(key) DO UPDATE SET value = excluded.value, tags = excluded.tags, updated_at = excluded.updated_at",
162 params![id, key, value_str, tags_str, now],
163 )?;
164 Ok(())
165 }
166}