punch_memory/
substrate.rs1use std::collections::HashMap;
2use std::path::Path;
3use std::sync::{Arc, Mutex as StdMutex};
4
5use rusqlite::Connection;
6use tokio::sync::Mutex;
7use tracing::info;
8
9use punch_types::PunchResult;
10
11use crate::embeddings::{BuiltInEmbedder, Embedder, EmbeddingStore};
12use crate::migrations;
13
14pub struct MemorySubstrate {
20 pub(crate) conn: Mutex<Connection>,
21 embedding_store: Option<StdMutex<EmbeddingStore>>,
23}
24
25impl MemorySubstrate {
26 pub fn new(path: &Path) -> PunchResult<Self> {
28 let conn = Connection::open(path).map_err(|e| {
29 punch_types::PunchError::Memory(format!("failed to open database: {e}"))
30 })?;
31
32 conn.execute_batch("PRAGMA journal_mode = WAL; PRAGMA foreign_keys = ON;")
34 .map_err(|e| punch_types::PunchError::Memory(format!("failed to set pragmas: {e}")))?;
35
36 migrations::migrate(&conn)?;
37
38 info!(path = %path.display(), "memory substrate initialized");
39
40 Ok(Self {
41 conn: Mutex::new(conn),
42 embedding_store: None,
43 })
44 }
45
46 pub async fn conn(&self) -> tokio::sync::MutexGuard<'_, Connection> {
51 self.conn.lock().await
52 }
53
54 pub fn in_memory() -> PunchResult<Self> {
56 let conn = Connection::open_in_memory().map_err(|e| {
57 punch_types::PunchError::Memory(format!("failed to open in-memory database: {e}"))
58 })?;
59
60 conn.execute_batch("PRAGMA foreign_keys = ON;")
61 .map_err(|e| punch_types::PunchError::Memory(format!("failed to set pragmas: {e}")))?;
62
63 migrations::migrate(&conn)?;
64
65 Ok(Self {
66 conn: Mutex::new(conn),
67 embedding_store: None,
68 })
69 }
70
71 pub fn with_embedding_store(
76 mut self,
77 conn: Arc<StdMutex<Connection>>,
78 embedder: Box<dyn Embedder>,
79 ) -> PunchResult<Self> {
80 let store = EmbeddingStore::new(conn, embedder)?;
81 self.embedding_store = Some(StdMutex::new(store));
82 Ok(self)
83 }
84
85 pub fn with_builtin_embeddings(mut self) -> PunchResult<Self> {
88 let conn = Connection::open_in_memory().map_err(|e| {
89 punch_types::PunchError::Memory(format!("failed to open embedding db: {e}"))
90 })?;
91 let arc = Arc::new(StdMutex::new(conn));
92 let embedder = BuiltInEmbedder::new();
93 let store = EmbeddingStore::new(arc, Box::new(embedder))?;
94 self.embedding_store = Some(StdMutex::new(store));
95 Ok(self)
96 }
97
98 pub fn has_embedding_store(&self) -> bool {
100 self.embedding_store.is_some()
101 }
102
103 pub fn embed_and_store(
105 &self,
106 text: &str,
107 metadata: HashMap<String, String>,
108 ) -> PunchResult<Option<String>> {
109 if let Some(ref store_mutex) = self.embedding_store {
110 let store = store_mutex
111 .lock()
112 .map_err(|e| punch_types::PunchError::Memory(format!("lock failed: {e}")))?;
113 let id = store.store(text, metadata)?;
114 Ok(Some(id))
115 } else {
116 Ok(None)
117 }
118 }
119
120 pub fn semantic_search(
123 &self,
124 query: &str,
125 k: usize,
126 ) -> PunchResult<Option<Vec<(f32, crate::embeddings::Embedding)>>> {
127 if let Some(ref store_mutex) = self.embedding_store {
128 let store = store_mutex
129 .lock()
130 .map_err(|e| punch_types::PunchError::Memory(format!("lock failed: {e}")))?;
131 let results = store.search(query, k)?;
132 Ok(Some(results))
133 } else {
134 Ok(None)
135 }
136 }
137}
138
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an in-memory substrate with no embedding store attached.
    fn plain_substrate() -> MemorySubstrate {
        MemorySubstrate::in_memory().unwrap()
    }

    /// Builds an in-memory substrate with the built-in embedding store attached.
    fn embedding_substrate() -> MemorySubstrate {
        plain_substrate().with_builtin_embeddings().unwrap()
    }

    // Construction of an in-memory substrate succeeds.
    #[test]
    fn test_in_memory_creation() {
        let created = MemorySubstrate::in_memory();
        assert!(created.is_ok());
    }

    // A freshly built substrate reports no embedding store.
    #[test]
    fn test_no_embedding_store_by_default() {
        let memory = plain_substrate();
        assert!(!memory.has_embedding_store());
    }

    // Attaching the built-in embeddings makes the store visible.
    #[test]
    fn test_with_builtin_embeddings() {
        let memory = embedding_substrate();
        assert!(memory.has_embedding_store());
    }

    // Storing without an embedding store is a successful no-op.
    #[test]
    fn test_embed_and_store_without_store() {
        let memory = plain_substrate();
        let stored = memory.embed_and_store("hello", HashMap::new()).unwrap();
        assert!(stored.is_none(), "no embedding store means None");
    }

    // Searching without an embedding store is a successful no-op.
    #[test]
    fn test_semantic_search_without_store() {
        let memory = plain_substrate();
        let hits = memory.semantic_search("hello", 5).unwrap();
        assert!(hits.is_none(), "no embedding store means None");
    }

    // Storing with the built-in store yields an identifier.
    #[test]
    fn test_embed_and_store_with_builtin() {
        let memory = embedding_substrate();
        let stored = memory
            .embed_and_store("test text", HashMap::new())
            .unwrap();
        assert!(stored.is_some());
    }

    // Searching with the built-in store yields a result set.
    #[test]
    fn test_semantic_search_with_builtin() {
        let memory = embedding_substrate();
        memory
            .embed_and_store("hello world", HashMap::new())
            .unwrap();
        let hits = memory.semantic_search("hello", 5).unwrap();
        assert!(hits.is_some());
    }

    // The main connection is reachable and the migrations created tables.
    #[tokio::test]
    async fn test_conn_access() {
        let memory = plain_substrate();
        let guard = memory.conn().await;
        let table_count: i64 = guard
            .query_row(
                "SELECT COUNT(*) FROM sqlite_master WHERE type='table'",
                [],
                |row| row.get(0),
            )
            .unwrap();
        assert!(table_count > 0, "should have tables from migrations");
    }
}