ceylon_next/memory/
sqlite.rs

1//! SQLite-backed persistent memory implementation.
2//!
//! This module provides persistent storage using a SQLite database,
4//! suitable for production use with efficient indexed queries.
5
6use super::{Memory, MemoryEntry};
7use crate::llm::types::Message;
8use async_trait::async_trait;
9use rusqlite::{params, Connection, OptionalExtension};
10use std::path::PathBuf;
11use std::sync::{Arc, Mutex};
12
13/// A SQLite-backed implementation of the [`Memory`] trait.
14///
15/// This provides persistent storage of conversation history using SQLite.
16/// All conversations are stored in a local SQLite database file.
17///
18/// # Examples
19///
20/// ```rust,no_run
21/// use ceylon_next::memory::SqliteStore;
22/// use std::sync::Arc;
23///
24/// #[tokio::main]
25/// async fn main() {
26///     let memory = Arc::new(
27///         SqliteStore::new("agent_memory.db").await.unwrap()
28///     );
29/// }
30/// ```
31pub struct SqliteStore {
32    db_path: PathBuf,
33    connection: Arc<Mutex<Connection>>,
34}
35
36impl SqliteStore {
37    /// Creates a new SQLite store with the given database path.
38    ///
39    /// # Arguments
40    ///
41    /// * `db_path` - Path to the SQLite database file (will be created if it doesn't exist)
42    ///
43    /// # Examples
44    ///
45    /// ```rust,no_run
46    /// use ceylon_next::memory::SqliteStore;
47    ///
48    /// #[tokio::main]
49    /// async fn main() {
50    ///     let store = SqliteStore::new("memory.db").await.unwrap();
51    /// }
52    /// ```
53    pub async fn new<P: Into<PathBuf>>(db_path: P) -> Result<Self, String> {
54        let db_path = db_path.into();
55        let connection = Connection::open(&db_path)
56            .map_err(|e| format!("Failed to open database: {}", e))?;
57
58        // Create the memory table if it doesn't exist
59        connection
60            .execute(
61                "CREATE TABLE IF NOT EXISTS memory_entries (
62                    id TEXT PRIMARY KEY,
63                    agent_id TEXT NOT NULL,
64                    task_id TEXT NOT NULL,
65                    messages TEXT NOT NULL,
66                    created_at INTEGER NOT NULL
67                )",
68                [],
69            )
70            .map_err(|e| format!("Failed to create table: {}", e))?;
71
72        // Create index for faster agent lookups
73        connection
74            .execute(
75                "CREATE INDEX IF NOT EXISTS idx_agent_id ON memory_entries(agent_id, created_at DESC)",
76                [],
77            )
78            .map_err(|e| format!("Failed to create index: {}", e))?;
79
80        Ok(Self {
81            db_path,
82            connection: Arc::new(Mutex::new(connection)),
83        })
84    }
85
86    /// Returns the path to the database file.
87    pub fn db_path(&self) -> &PathBuf {
88        &self.db_path
89    }
90}
91
92#[async_trait]
93impl Memory for SqliteStore {
94    async fn store(&self, entry: MemoryEntry) -> Result<String, String> {
95        let entry_id = entry.id.clone();
96        let connection = Arc::clone(&self.connection);
97
98        tokio::task::spawn_blocking(move || {
99            let messages_json = serde_json::to_string(&entry.messages)
100                .map_err(|e| format!("Failed to serialize messages: {}", e))?;
101
102            let conn = connection.lock().unwrap();
103            conn.execute(
104                "INSERT INTO memory_entries (id, agent_id, task_id, messages, created_at)
105                 VALUES (?1, ?2, ?3, ?4, ?5)",
106                params![
107                    entry.id,
108                    entry.agent_id,
109                    entry.task_id,
110                    messages_json,
111                    entry.created_at as i64
112                ],
113            )
114            .map_err(|e| format!("Failed to store entry: {}", e))?;
115
116            Ok(entry_id)
117        })
118        .await
119        .map_err(|e| format!("Task join error: {}", e))?
120    }
121
122    async fn get(&self, id: &str) -> Result<Option<MemoryEntry>, String> {
123        let id = id.to_string();
124        let connection = Arc::clone(&self.connection);
125
126        tokio::task::spawn_blocking(move || {
127            let conn = connection.lock().unwrap();
128            let mut stmt = conn
129                .prepare("SELECT id, agent_id, task_id, messages, created_at FROM memory_entries WHERE id = ?1")
130                .map_err(|e| format!("Failed to prepare statement: {}", e))?;
131
132            let result = stmt
133                .query_row(params![id], |row| {
134                    let messages_json: String = row.get(3)?;
135                    let messages: Vec<Message> = serde_json::from_str(&messages_json)
136                        .map_err(|e| rusqlite::Error::ToSqlConversionFailure(Box::new(e)))?;
137
138                    Ok(MemoryEntry {
139                        id: row.get(0)?,
140                        agent_id: row.get(1)?,
141                        task_id: row.get(2)?,
142                        messages,
143                        created_at: row.get::<_, i64>(4)? as u64,
144                    })
145                })
146                .optional()
147                .map_err(|e| format!("Failed to query entry: {}", e))?;
148
149            Ok(result)
150        })
151        .await
152        .map_err(|e| format!("Task join error: {}", e))?
153    }
154
155    async fn get_agent_history(&self, agent_id: &str) -> Result<Vec<MemoryEntry>, String> {
156        let agent_id = agent_id.to_string();
157        let connection = Arc::clone(&self.connection);
158
159        tokio::task::spawn_blocking(move || {
160            let conn = connection.lock().unwrap();
161            let mut stmt = conn
162                .prepare(
163                    "SELECT id, agent_id, task_id, messages, created_at
164                     FROM memory_entries
165                     WHERE agent_id = ?1
166                     ORDER BY created_at DESC",
167                )
168                .map_err(|e| format!("Failed to prepare statement: {}", e))?;
169
170            let entries = stmt
171                .query_map(params![agent_id], |row| {
172                    let messages_json: String = row.get(3)?;
173                    let messages: Vec<Message> = serde_json::from_str(&messages_json)
174                        .map_err(|e| rusqlite::Error::ToSqlConversionFailure(Box::new(e)))?;
175
176                    Ok(MemoryEntry {
177                        id: row.get(0)?,
178                        agent_id: row.get(1)?,
179                        task_id: row.get(2)?,
180                        messages,
181                        created_at: row.get::<_, i64>(4)? as u64,
182                    })
183                })
184                .map_err(|e| format!("Failed to query entries: {}", e))?
185                .collect::<Result<Vec<_>, _>>()
186                .map_err(|e| format!("Failed to collect entries: {}", e))?;
187
188            Ok(entries)
189        })
190        .await
191        .map_err(|e| format!("Task join error: {}", e))?
192    }
193
194    async fn get_recent(&self, agent_id: &str, limit: usize) -> Result<Vec<MemoryEntry>, String> {
195        let agent_id = agent_id.to_string();
196        let connection = Arc::clone(&self.connection);
197
198        tokio::task::spawn_blocking(move || {
199            let conn = connection.lock().unwrap();
200            let mut stmt = conn
201                .prepare(
202                    "SELECT id, agent_id, task_id, messages, created_at
203                     FROM memory_entries
204                     WHERE agent_id = ?1
205                     ORDER BY created_at DESC
206                     LIMIT ?2",
207                )
208                .map_err(|e| format!("Failed to prepare statement: {}", e))?;
209
210            let entries = stmt
211                .query_map(params![agent_id, limit as i64], |row| {
212                    let messages_json: String = row.get(3)?;
213                    let messages: Vec<Message> = serde_json::from_str(&messages_json)
214                        .map_err(|e| rusqlite::Error::ToSqlConversionFailure(Box::new(e)))?;
215
216                    Ok(MemoryEntry {
217                        id: row.get(0)?,
218                        agent_id: row.get(1)?,
219                        task_id: row.get(2)?,
220                        messages,
221                        created_at: row.get::<_, i64>(4)? as u64,
222                    })
223                })
224                .map_err(|e| format!("Failed to query entries: {}", e))?
225                .collect::<Result<Vec<_>, _>>()
226                .map_err(|e| format!("Failed to collect entries: {}", e))?;
227
228            Ok(entries)
229        })
230        .await
231        .map_err(|e| format!("Task join error: {}", e))?
232    }
233
234    async fn search(&self, agent_id: &str, query: &str) -> Result<Vec<MemoryEntry>, String> {
235        let agent_id = agent_id.to_string();
236        let query = query.to_string();
237        let connection = Arc::clone(&self.connection);
238
239        tokio::task::spawn_blocking(move || {
240            let conn = connection.lock().unwrap();
241            let mut stmt = conn
242                .prepare(
243                    "SELECT id, agent_id, task_id, messages, created_at
244                     FROM memory_entries
245                     WHERE agent_id = ?1 AND messages LIKE ?2
246                     ORDER BY created_at DESC",
247                )
248                .map_err(|e| format!("Failed to prepare statement: {}", e))?;
249
250            let search_pattern = format!("%{}%", query);
251            let entries = stmt
252                .query_map(params![agent_id, search_pattern], |row| {
253                    let messages_json: String = row.get(3)?;
254                    let messages: Vec<Message> = serde_json::from_str(&messages_json)
255                        .map_err(|e| rusqlite::Error::ToSqlConversionFailure(Box::new(e)))?;
256
257                    Ok(MemoryEntry {
258                        id: row.get(0)?,
259                        agent_id: row.get(1)?,
260                        task_id: row.get(2)?,
261                        messages,
262                        created_at: row.get::<_, i64>(4)? as u64,
263                    })
264                })
265                .map_err(|e| format!("Failed to query entries: {}", e))?
266                .collect::<Result<Vec<_>, _>>()
267                .map_err(|e| format!("Failed to collect entries: {}", e))?;
268
269            Ok(entries)
270        })
271        .await
272        .map_err(|e| format!("Task join error: {}", e))?
273    }
274
275    async fn clear_agent_memory(&self, agent_id: &str) -> Result<(), String> {
276        let agent_id = agent_id.to_string();
277        let connection = Arc::clone(&self.connection);
278
279        tokio::task::spawn_blocking(move || {
280            let conn = connection.lock().unwrap();
281            conn.execute(
282                "DELETE FROM memory_entries WHERE agent_id = ?1",
283                params![agent_id],
284            )
285            .map_err(|e| format!("Failed to clear agent memory: {}", e))?;
286
287            Ok(())
288        })
289        .await
290        .map_err(|e| format!("Task join error: {}", e))?
291    }
292}