1use super::{Memory, MemoryEntry};
7use crate::llm::types::Message;
8use async_trait::async_trait;
9use rusqlite::{params, Connection, OptionalExtension};
10use std::path::PathBuf;
11use std::sync::{Arc, Mutex};
12
13pub struct SqliteStore {
32 db_path: PathBuf,
33 connection: Arc<Mutex<Connection>>,
34}
35
36impl SqliteStore {
37 pub async fn new<P: Into<PathBuf>>(db_path: P) -> Result<Self, String> {
54 let db_path = db_path.into();
55 let connection = Connection::open(&db_path)
56 .map_err(|e| format!("Failed to open database: {}", e))?;
57
58 connection
60 .execute(
61 "CREATE TABLE IF NOT EXISTS memory_entries (
62 id TEXT PRIMARY KEY,
63 agent_id TEXT NOT NULL,
64 task_id TEXT NOT NULL,
65 messages TEXT NOT NULL,
66 created_at INTEGER NOT NULL
67 )",
68 [],
69 )
70 .map_err(|e| format!("Failed to create table: {}", e))?;
71
72 connection
74 .execute(
75 "CREATE INDEX IF NOT EXISTS idx_agent_id ON memory_entries(agent_id, created_at DESC)",
76 [],
77 )
78 .map_err(|e| format!("Failed to create index: {}", e))?;
79
80 Ok(Self {
81 db_path,
82 connection: Arc::new(Mutex::new(connection)),
83 })
84 }
85
86 pub fn db_path(&self) -> &PathBuf {
88 &self.db_path
89 }
90}
91
92#[async_trait]
93impl Memory for SqliteStore {
94 async fn store(&self, entry: MemoryEntry) -> Result<String, String> {
95 let entry_id = entry.id.clone();
96 let connection = Arc::clone(&self.connection);
97
98 tokio::task::spawn_blocking(move || {
99 let messages_json = serde_json::to_string(&entry.messages)
100 .map_err(|e| format!("Failed to serialize messages: {}", e))?;
101
102 let conn = connection.lock().unwrap();
103 conn.execute(
104 "INSERT INTO memory_entries (id, agent_id, task_id, messages, created_at)
105 VALUES (?1, ?2, ?3, ?4, ?5)",
106 params![
107 entry.id,
108 entry.agent_id,
109 entry.task_id,
110 messages_json,
111 entry.created_at as i64
112 ],
113 )
114 .map_err(|e| format!("Failed to store entry: {}", e))?;
115
116 Ok(entry_id)
117 })
118 .await
119 .map_err(|e| format!("Task join error: {}", e))?
120 }
121
122 async fn get(&self, id: &str) -> Result<Option<MemoryEntry>, String> {
123 let id = id.to_string();
124 let connection = Arc::clone(&self.connection);
125
126 tokio::task::spawn_blocking(move || {
127 let conn = connection.lock().unwrap();
128 let mut stmt = conn
129 .prepare("SELECT id, agent_id, task_id, messages, created_at FROM memory_entries WHERE id = ?1")
130 .map_err(|e| format!("Failed to prepare statement: {}", e))?;
131
132 let result = stmt
133 .query_row(params![id], |row| {
134 let messages_json: String = row.get(3)?;
135 let messages: Vec<Message> = serde_json::from_str(&messages_json)
136 .map_err(|e| rusqlite::Error::ToSqlConversionFailure(Box::new(e)))?;
137
138 Ok(MemoryEntry {
139 id: row.get(0)?,
140 agent_id: row.get(1)?,
141 task_id: row.get(2)?,
142 messages,
143 created_at: row.get::<_, i64>(4)? as u64,
144 })
145 })
146 .optional()
147 .map_err(|e| format!("Failed to query entry: {}", e))?;
148
149 Ok(result)
150 })
151 .await
152 .map_err(|e| format!("Task join error: {}", e))?
153 }
154
155 async fn get_agent_history(&self, agent_id: &str) -> Result<Vec<MemoryEntry>, String> {
156 let agent_id = agent_id.to_string();
157 let connection = Arc::clone(&self.connection);
158
159 tokio::task::spawn_blocking(move || {
160 let conn = connection.lock().unwrap();
161 let mut stmt = conn
162 .prepare(
163 "SELECT id, agent_id, task_id, messages, created_at
164 FROM memory_entries
165 WHERE agent_id = ?1
166 ORDER BY created_at DESC",
167 )
168 .map_err(|e| format!("Failed to prepare statement: {}", e))?;
169
170 let entries = stmt
171 .query_map(params![agent_id], |row| {
172 let messages_json: String = row.get(3)?;
173 let messages: Vec<Message> = serde_json::from_str(&messages_json)
174 .map_err(|e| rusqlite::Error::ToSqlConversionFailure(Box::new(e)))?;
175
176 Ok(MemoryEntry {
177 id: row.get(0)?,
178 agent_id: row.get(1)?,
179 task_id: row.get(2)?,
180 messages,
181 created_at: row.get::<_, i64>(4)? as u64,
182 })
183 })
184 .map_err(|e| format!("Failed to query entries: {}", e))?
185 .collect::<Result<Vec<_>, _>>()
186 .map_err(|e| format!("Failed to collect entries: {}", e))?;
187
188 Ok(entries)
189 })
190 .await
191 .map_err(|e| format!("Task join error: {}", e))?
192 }
193
194 async fn get_recent(&self, agent_id: &str, limit: usize) -> Result<Vec<MemoryEntry>, String> {
195 let agent_id = agent_id.to_string();
196 let connection = Arc::clone(&self.connection);
197
198 tokio::task::spawn_blocking(move || {
199 let conn = connection.lock().unwrap();
200 let mut stmt = conn
201 .prepare(
202 "SELECT id, agent_id, task_id, messages, created_at
203 FROM memory_entries
204 WHERE agent_id = ?1
205 ORDER BY created_at DESC
206 LIMIT ?2",
207 )
208 .map_err(|e| format!("Failed to prepare statement: {}", e))?;
209
210 let entries = stmt
211 .query_map(params![agent_id, limit as i64], |row| {
212 let messages_json: String = row.get(3)?;
213 let messages: Vec<Message> = serde_json::from_str(&messages_json)
214 .map_err(|e| rusqlite::Error::ToSqlConversionFailure(Box::new(e)))?;
215
216 Ok(MemoryEntry {
217 id: row.get(0)?,
218 agent_id: row.get(1)?,
219 task_id: row.get(2)?,
220 messages,
221 created_at: row.get::<_, i64>(4)? as u64,
222 })
223 })
224 .map_err(|e| format!("Failed to query entries: {}", e))?
225 .collect::<Result<Vec<_>, _>>()
226 .map_err(|e| format!("Failed to collect entries: {}", e))?;
227
228 Ok(entries)
229 })
230 .await
231 .map_err(|e| format!("Task join error: {}", e))?
232 }
233
234 async fn search(&self, agent_id: &str, query: &str) -> Result<Vec<MemoryEntry>, String> {
235 let agent_id = agent_id.to_string();
236 let query = query.to_string();
237 let connection = Arc::clone(&self.connection);
238
239 tokio::task::spawn_blocking(move || {
240 let conn = connection.lock().unwrap();
241 let mut stmt = conn
242 .prepare(
243 "SELECT id, agent_id, task_id, messages, created_at
244 FROM memory_entries
245 WHERE agent_id = ?1 AND messages LIKE ?2
246 ORDER BY created_at DESC",
247 )
248 .map_err(|e| format!("Failed to prepare statement: {}", e))?;
249
250 let search_pattern = format!("%{}%", query);
251 let entries = stmt
252 .query_map(params![agent_id, search_pattern], |row| {
253 let messages_json: String = row.get(3)?;
254 let messages: Vec<Message> = serde_json::from_str(&messages_json)
255 .map_err(|e| rusqlite::Error::ToSqlConversionFailure(Box::new(e)))?;
256
257 Ok(MemoryEntry {
258 id: row.get(0)?,
259 agent_id: row.get(1)?,
260 task_id: row.get(2)?,
261 messages,
262 created_at: row.get::<_, i64>(4)? as u64,
263 })
264 })
265 .map_err(|e| format!("Failed to query entries: {}", e))?
266 .collect::<Result<Vec<_>, _>>()
267 .map_err(|e| format!("Failed to collect entries: {}", e))?;
268
269 Ok(entries)
270 })
271 .await
272 .map_err(|e| format!("Task join error: {}", e))?
273 }
274
275 async fn clear_agent_memory(&self, agent_id: &str) -> Result<(), String> {
276 let agent_id = agent_id.to_string();
277 let connection = Arc::clone(&self.connection);
278
279 tokio::task::spawn_blocking(move || {
280 let conn = connection.lock().unwrap();
281 conn.execute(
282 "DELETE FROM memory_entries WHERE agent_id = ?1",
283 params![agent_id],
284 )
285 .map_err(|e| format!("Failed to clear agent memory: {}", e))?;
286
287 Ok(())
288 })
289 .await
290 .map_err(|e| format!("Task join error: {}", e))?
291 }
292}