// punch_memory/substrate.rs

1use std::collections::HashMap;
2use std::path::Path;
3use std::sync::{Arc, Mutex as StdMutex};
4
5use rusqlite::Connection;
6use tokio::sync::Mutex;
7use tracing::info;
8
9use punch_types::PunchResult;
10
11use crate::embeddings::{BuiltInEmbedder, Embedder, EmbeddingStore};
12use crate::migrations;
13
14/// The core persistence handle for Punch.
15///
16/// Wraps a SQLite [`Connection`] behind a [`tokio::sync::Mutex`] so it can be
17/// shared across async tasks without blocking the executor. Optionally includes
18/// an [`EmbeddingStore`] for semantic search over stored memories.
19pub struct MemorySubstrate {
20    pub(crate) conn: Mutex<Connection>,
21    /// Optional embedding store for semantic recall.
22    embedding_store: Option<StdMutex<EmbeddingStore>>,
23}
24
25impl MemorySubstrate {
26    /// Open (or create) a SQLite database at `path` and run pending migrations.
27    pub fn new(path: &Path) -> PunchResult<Self> {
28        let conn = Connection::open(path).map_err(|e| {
29            punch_types::PunchError::Memory(format!("failed to open database: {e}"))
30        })?;
31
32        // Enable WAL mode for better concurrent-read performance.
33        conn.execute_batch("PRAGMA journal_mode = WAL; PRAGMA foreign_keys = ON;")
34            .map_err(|e| punch_types::PunchError::Memory(format!("failed to set pragmas: {e}")))?;
35
36        migrations::migrate(&conn)?;
37
38        info!(path = %path.display(), "memory substrate initialized");
39
40        Ok(Self {
41            conn: Mutex::new(conn),
42            embedding_store: None,
43        })
44    }
45
46    /// Get a lock on the underlying database connection.
47    ///
48    /// This is intended for advanced queries that don't have a dedicated method.
49    /// Prefer using the higher-level methods on `MemorySubstrate` when possible.
50    pub async fn conn(&self) -> tokio::sync::MutexGuard<'_, Connection> {
51        self.conn.lock().await
52    }
53
54    /// Create an in-memory substrate (useful for testing).
55    pub fn in_memory() -> PunchResult<Self> {
56        let conn = Connection::open_in_memory().map_err(|e| {
57            punch_types::PunchError::Memory(format!("failed to open in-memory database: {e}"))
58        })?;
59
60        conn.execute_batch("PRAGMA foreign_keys = ON;")
61            .map_err(|e| punch_types::PunchError::Memory(format!("failed to set pragmas: {e}")))?;
62
63        migrations::migrate(&conn)?;
64
65        Ok(Self {
66            conn: Mutex::new(conn),
67            embedding_store: None,
68        })
69    }
70
71    /// Attach an embedding store with the given embedder for semantic recall.
72    ///
73    /// The embedding store shares a *separate* SQLite connection (via
74    /// `std::sync::Mutex`) since it operates synchronously.
75    pub fn with_embedding_store(
76        mut self,
77        conn: Arc<StdMutex<Connection>>,
78        embedder: Box<dyn Embedder>,
79    ) -> PunchResult<Self> {
80        let store = EmbeddingStore::new(conn, embedder)?;
81        self.embedding_store = Some(StdMutex::new(store));
82        Ok(self)
83    }
84
85    /// Attach a default built-in (TF-IDF) embedding store using an in-memory
86    /// SQLite connection. Useful for testing and offline operation.
87    pub fn with_builtin_embeddings(mut self) -> PunchResult<Self> {
88        let conn = Connection::open_in_memory().map_err(|e| {
89            punch_types::PunchError::Memory(format!("failed to open embedding db: {e}"))
90        })?;
91        let arc = Arc::new(StdMutex::new(conn));
92        let embedder = BuiltInEmbedder::new();
93        let store = EmbeddingStore::new(arc, Box::new(embedder))?;
94        self.embedding_store = Some(StdMutex::new(store));
95        Ok(self)
96    }
97
98    /// Returns whether an embedding store is attached.
99    pub fn has_embedding_store(&self) -> bool {
100        self.embedding_store.is_some()
101    }
102
103    /// Store a text embedding (if the embedding store is attached).
104    pub fn embed_and_store(
105        &self,
106        text: &str,
107        metadata: HashMap<String, String>,
108    ) -> PunchResult<Option<String>> {
109        if let Some(ref store_mutex) = self.embedding_store {
110            let store = store_mutex
111                .lock()
112                .map_err(|e| punch_types::PunchError::Memory(format!("lock failed: {e}")))?;
113            let id = store.store(text, metadata)?;
114            Ok(Some(id))
115        } else {
116            Ok(None)
117        }
118    }
119
120    /// Perform semantic search over stored embeddings. Falls back to `None`
121    /// if no embedding store is attached.
122    pub fn semantic_search(
123        &self,
124        query: &str,
125        k: usize,
126    ) -> PunchResult<Option<Vec<(f32, crate::embeddings::Embedding)>>> {
127        if let Some(ref store_mutex) = self.embedding_store {
128            let store = store_mutex
129                .lock()
130                .map_err(|e| punch_types::PunchError::Memory(format!("lock failed: {e}")))?;
131            let results = store.search(query, k)?;
132            Ok(Some(results))
133        } else {
134            Ok(None)
135        }
136    }
137}
138
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture: in-memory substrate with no embedding store attached.
    fn bare_substrate() -> MemorySubstrate {
        MemorySubstrate::in_memory().unwrap()
    }

    /// Fixture: in-memory substrate with the built-in embedder attached.
    fn substrate_with_builtin() -> MemorySubstrate {
        bare_substrate().with_builtin_embeddings().unwrap()
    }

    #[test]
    fn test_in_memory_creation() {
        assert!(MemorySubstrate::in_memory().is_ok());
    }

    #[test]
    fn test_no_embedding_store_by_default() {
        assert!(!bare_substrate().has_embedding_store());
    }

    #[test]
    fn test_with_builtin_embeddings() {
        assert!(substrate_with_builtin().has_embedding_store());
    }

    #[test]
    fn test_embed_and_store_without_store() {
        let stored_id = bare_substrate()
            .embed_and_store("hello", HashMap::new())
            .unwrap();
        assert!(stored_id.is_none(), "no embedding store means None");
    }

    #[test]
    fn test_semantic_search_without_store() {
        let hits = bare_substrate().semantic_search("hello", 5).unwrap();
        assert!(hits.is_none(), "no embedding store means None");
    }

    #[test]
    fn test_embed_and_store_with_builtin() {
        let stored_id = substrate_with_builtin()
            .embed_and_store("test text", HashMap::new())
            .unwrap();
        assert!(stored_id.is_some());
    }

    #[test]
    fn test_semantic_search_with_builtin() {
        let substrate = substrate_with_builtin();
        substrate
            .embed_and_store("hello world", HashMap::new())
            .unwrap();

        let hits = substrate.semantic_search("hello", 5).unwrap();
        assert!(hits.is_some());
    }

    #[tokio::test]
    async fn test_conn_access() {
        let substrate = bare_substrate();
        let guard = substrate.conn().await;

        // Migrations should have created at least one table.
        let table_count: i64 = guard
            .query_row(
                "SELECT COUNT(*) FROM sqlite_master WHERE type='table'",
                [],
                |row| row.get(0),
            )
            .unwrap();
        assert!(table_count > 0, "should have tables from migrations");
    }
}
216}