Skip to main content

punch_memory/
substrate.rs

1use std::collections::HashMap;
2use std::path::Path;
3use std::sync::{Arc, Mutex as StdMutex};
4
5use rusqlite::Connection;
6use tokio::sync::Mutex;
7use tracing::info;
8
9use punch_types::PunchResult;
10
11use crate::embeddings::{EmbeddingStore, BuiltInEmbedder, Embedder};
12use crate::migrations;
13
14/// The core persistence handle for Punch.
15///
16/// Wraps a SQLite [`Connection`] behind a [`tokio::sync::Mutex`] so it can be
17/// shared across async tasks without blocking the executor. Optionally includes
18/// an [`EmbeddingStore`] for semantic search over stored memories.
19pub struct MemorySubstrate {
20    pub(crate) conn: Mutex<Connection>,
21    /// Optional embedding store for semantic recall.
22    embedding_store: Option<StdMutex<EmbeddingStore>>,
23}
24
25impl MemorySubstrate {
26    /// Open (or create) a SQLite database at `path` and run pending migrations.
27    pub fn new(path: &Path) -> PunchResult<Self> {
28        let conn = Connection::open(path).map_err(|e| {
29            punch_types::PunchError::Memory(format!("failed to open database: {e}"))
30        })?;
31
32        // Enable WAL mode for better concurrent-read performance.
33        conn.execute_batch("PRAGMA journal_mode = WAL; PRAGMA foreign_keys = ON;")
34            .map_err(|e| punch_types::PunchError::Memory(format!("failed to set pragmas: {e}")))?;
35
36        migrations::migrate(&conn)?;
37
38        info!(path = %path.display(), "memory substrate initialized");
39
40        Ok(Self {
41            conn: Mutex::new(conn),
42            embedding_store: None,
43        })
44    }
45
46    /// Get a lock on the underlying database connection.
47    ///
48    /// This is intended for advanced queries that don't have a dedicated method.
49    /// Prefer using the higher-level methods on `MemorySubstrate` when possible.
50    pub async fn conn(&self) -> tokio::sync::MutexGuard<'_, Connection> {
51        self.conn.lock().await
52    }
53
54    /// Create an in-memory substrate (useful for testing).
55    pub fn in_memory() -> PunchResult<Self> {
56        let conn = Connection::open_in_memory().map_err(|e| {
57            punch_types::PunchError::Memory(format!("failed to open in-memory database: {e}"))
58        })?;
59
60        conn.execute_batch("PRAGMA foreign_keys = ON;")
61            .map_err(|e| punch_types::PunchError::Memory(format!("failed to set pragmas: {e}")))?;
62
63        migrations::migrate(&conn)?;
64
65        Ok(Self {
66            conn: Mutex::new(conn),
67            embedding_store: None,
68        })
69    }
70
71    /// Attach an embedding store with the given embedder for semantic recall.
72    ///
73    /// The embedding store shares a *separate* SQLite connection (via
74    /// `std::sync::Mutex`) since it operates synchronously.
75    pub fn with_embedding_store(
76        mut self,
77        conn: Arc<StdMutex<Connection>>,
78        embedder: Box<dyn Embedder>,
79    ) -> PunchResult<Self> {
80        let store = EmbeddingStore::new(conn, embedder)?;
81        self.embedding_store = Some(StdMutex::new(store));
82        Ok(self)
83    }
84
85    /// Attach a default built-in (TF-IDF) embedding store using an in-memory
86    /// SQLite connection. Useful for testing and offline operation.
87    pub fn with_builtin_embeddings(mut self) -> PunchResult<Self> {
88        let conn = Connection::open_in_memory().map_err(|e| {
89            punch_types::PunchError::Memory(format!(
90                "failed to open embedding db: {e}"
91            ))
92        })?;
93        let arc = Arc::new(StdMutex::new(conn));
94        let embedder = BuiltInEmbedder::new();
95        let store = EmbeddingStore::new(arc, Box::new(embedder))?;
96        self.embedding_store = Some(StdMutex::new(store));
97        Ok(self)
98    }
99
100    /// Returns whether an embedding store is attached.
101    pub fn has_embedding_store(&self) -> bool {
102        self.embedding_store.is_some()
103    }
104
105    /// Store a text embedding (if the embedding store is attached).
106    pub fn embed_and_store(
107        &self,
108        text: &str,
109        metadata: HashMap<String, String>,
110    ) -> PunchResult<Option<String>> {
111        if let Some(ref store_mutex) = self.embedding_store {
112            let store = store_mutex
113                .lock()
114                .map_err(|e| punch_types::PunchError::Memory(format!("lock failed: {e}")))?;
115            let id = store.store(text, metadata)?;
116            Ok(Some(id))
117        } else {
118            Ok(None)
119        }
120    }
121
122    /// Perform semantic search over stored embeddings. Falls back to `None`
123    /// if no embedding store is attached.
124    pub fn semantic_search(
125        &self,
126        query: &str,
127        k: usize,
128    ) -> PunchResult<Option<Vec<(f32, crate::embeddings::Embedding)>>> {
129        if let Some(ref store_mutex) = self.embedding_store {
130            let store = store_mutex
131                .lock()
132                .map_err(|e| punch_types::PunchError::Memory(format!("lock failed: {e}")))?;
133            let results = store.search(query, k)?;
134            Ok(Some(results))
135        } else {
136            Ok(None)
137        }
138    }
139}
140
#[cfg(test)]
mod tests {
    use super::*;

    /// A temp-dir SQLite path that deletes the database and its WAL/SHM side
    /// files both on creation (stale leftovers) and on drop.
    struct TempDb(std::path::PathBuf);

    impl TempDb {
        fn new(name: &str) -> Self {
            let path = std::env::temp_dir().join(format!("{name}_{}.db", std::process::id()));
            let db = Self(path);
            db.cleanup();
            db
        }

        fn cleanup(&self) {
            for suffix in ["", "-wal", "-shm"] {
                let mut file = self.0.clone().into_os_string();
                file.push(suffix);
                // Best effort: the file may legitimately not exist.
                let _ = std::fs::remove_file(file);
            }
        }
    }

    impl Drop for TempDb {
        fn drop(&mut self) {
            self.cleanup();
        }
    }

    #[test]
    fn test_in_memory_creation() {
        let substrate = MemorySubstrate::in_memory();
        assert!(substrate.is_ok());
    }

    #[test]
    fn test_no_embedding_store_by_default() {
        let substrate = MemorySubstrate::in_memory().unwrap();
        assert!(!substrate.has_embedding_store());
    }

    #[test]
    fn test_with_builtin_embeddings() {
        let substrate = MemorySubstrate::in_memory()
            .unwrap()
            .with_builtin_embeddings()
            .unwrap();
        assert!(substrate.has_embedding_store());
    }

    #[test]
    fn test_embed_and_store_without_store() {
        let substrate = MemorySubstrate::in_memory().unwrap();
        let result = substrate.embed_and_store("hello", HashMap::new()).unwrap();
        assert!(result.is_none(), "no embedding store means None");
    }

    #[test]
    fn test_semantic_search_without_store() {
        let substrate = MemorySubstrate::in_memory().unwrap();
        let result = substrate.semantic_search("hello", 5).unwrap();
        assert!(result.is_none(), "no embedding store means None");
    }

    #[test]
    fn test_embed_and_store_with_builtin() {
        let substrate = MemorySubstrate::in_memory()
            .unwrap()
            .with_builtin_embeddings()
            .unwrap();
        let result = substrate.embed_and_store("test text", HashMap::new()).unwrap();
        assert!(result.is_some());
    }

    #[test]
    fn test_semantic_search_with_builtin() {
        let substrate = MemorySubstrate::in_memory()
            .unwrap()
            .with_builtin_embeddings()
            .unwrap();
        substrate.embed_and_store("hello world", HashMap::new()).unwrap();
        let results = substrate.semantic_search("hello", 5).unwrap();
        assert!(results.is_some());
    }

    #[tokio::test]
    async fn test_conn_access() {
        let substrate = MemorySubstrate::in_memory().unwrap();
        let conn = substrate.conn().await;
        let count: i64 = conn
            .query_row(
                "SELECT COUNT(*) FROM sqlite_master WHERE type='table'",
                [],
                |row| row.get(0),
            )
            .unwrap();
        assert!(count > 0, "should have tables from migrations");
    }

    #[tokio::test]
    async fn test_in_memory_enables_foreign_keys() {
        let substrate = MemorySubstrate::in_memory().unwrap();
        let conn = substrate.conn().await;
        let enabled: i64 = conn
            .query_row("PRAGMA foreign_keys", [], |row| row.get(0))
            .unwrap();
        assert_eq!(enabled, 1, "in_memory() should enable foreign-key enforcement");
    }

    #[tokio::test]
    async fn test_file_backed_uses_wal_and_migrates() {
        // Declared first so it is dropped last, after the connection closes.
        let db = TempDb::new("punch_substrate_wal_test");
        let substrate = MemorySubstrate::new(&db.0).unwrap();
        let conn = substrate.conn().await;

        let mode: String = conn
            .query_row("PRAGMA journal_mode", [], |row| row.get(0))
            .unwrap();
        assert!(mode.eq_ignore_ascii_case("wal"), "expected WAL, got {mode}");

        let enabled: i64 = conn
            .query_row("PRAGMA foreign_keys", [], |row| row.get(0))
            .unwrap();
        assert_eq!(enabled, 1, "new() should enable foreign-key enforcement");

        let count: i64 = conn
            .query_row(
                "SELECT COUNT(*) FROM sqlite_master WHERE type='table'",
                [],
                |row| row.get(0),
            )
            .unwrap();
        assert!(count > 0, "should have tables from migrations");
    }
}
214}