punch_memory/
substrate.rs1use std::collections::HashMap;
2use std::path::Path;
3use std::sync::{Arc, Mutex as StdMutex};
4
5use rusqlite::Connection;
6use tokio::sync::Mutex;
7use tracing::info;
8
9use punch_types::PunchResult;
10
11use crate::embeddings::{EmbeddingStore, BuiltInEmbedder, Embedder};
12use crate::migrations;
13
14pub struct MemorySubstrate {
20 pub(crate) conn: Mutex<Connection>,
21 embedding_store: Option<StdMutex<EmbeddingStore>>,
23}
24
25impl MemorySubstrate {
26 pub fn new(path: &Path) -> PunchResult<Self> {
28 let conn = Connection::open(path).map_err(|e| {
29 punch_types::PunchError::Memory(format!("failed to open database: {e}"))
30 })?;
31
32 conn.execute_batch("PRAGMA journal_mode = WAL; PRAGMA foreign_keys = ON;")
34 .map_err(|e| punch_types::PunchError::Memory(format!("failed to set pragmas: {e}")))?;
35
36 migrations::migrate(&conn)?;
37
38 info!(path = %path.display(), "memory substrate initialized");
39
40 Ok(Self {
41 conn: Mutex::new(conn),
42 embedding_store: None,
43 })
44 }
45
46 pub async fn conn(&self) -> tokio::sync::MutexGuard<'_, Connection> {
51 self.conn.lock().await
52 }
53
54 pub fn in_memory() -> PunchResult<Self> {
56 let conn = Connection::open_in_memory().map_err(|e| {
57 punch_types::PunchError::Memory(format!("failed to open in-memory database: {e}"))
58 })?;
59
60 conn.execute_batch("PRAGMA foreign_keys = ON;")
61 .map_err(|e| punch_types::PunchError::Memory(format!("failed to set pragmas: {e}")))?;
62
63 migrations::migrate(&conn)?;
64
65 Ok(Self {
66 conn: Mutex::new(conn),
67 embedding_store: None,
68 })
69 }
70
71 pub fn with_embedding_store(
76 mut self,
77 conn: Arc<StdMutex<Connection>>,
78 embedder: Box<dyn Embedder>,
79 ) -> PunchResult<Self> {
80 let store = EmbeddingStore::new(conn, embedder)?;
81 self.embedding_store = Some(StdMutex::new(store));
82 Ok(self)
83 }
84
85 pub fn with_builtin_embeddings(mut self) -> PunchResult<Self> {
88 let conn = Connection::open_in_memory().map_err(|e| {
89 punch_types::PunchError::Memory(format!(
90 "failed to open embedding db: {e}"
91 ))
92 })?;
93 let arc = Arc::new(StdMutex::new(conn));
94 let embedder = BuiltInEmbedder::new();
95 let store = EmbeddingStore::new(arc, Box::new(embedder))?;
96 self.embedding_store = Some(StdMutex::new(store));
97 Ok(self)
98 }
99
100 pub fn has_embedding_store(&self) -> bool {
102 self.embedding_store.is_some()
103 }
104
105 pub fn embed_and_store(
107 &self,
108 text: &str,
109 metadata: HashMap<String, String>,
110 ) -> PunchResult<Option<String>> {
111 if let Some(ref store_mutex) = self.embedding_store {
112 let store = store_mutex
113 .lock()
114 .map_err(|e| punch_types::PunchError::Memory(format!("lock failed: {e}")))?;
115 let id = store.store(text, metadata)?;
116 Ok(Some(id))
117 } else {
118 Ok(None)
119 }
120 }
121
122 pub fn semantic_search(
125 &self,
126 query: &str,
127 k: usize,
128 ) -> PunchResult<Option<Vec<(f32, crate::embeddings::Embedding)>>> {
129 if let Some(ref store_mutex) = self.embedding_store {
130 let store = store_mutex
131 .lock()
132 .map_err(|e| punch_types::PunchError::Memory(format!("lock failed: {e}")))?;
133 let results = store.search(query, k)?;
134 Ok(Some(results))
135 } else {
136 Ok(None)
137 }
138 }
139}
140
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an in-memory substrate with the built-in embedder attached.
    fn substrate_with_embeddings() -> MemorySubstrate {
        MemorySubstrate::in_memory()
            .unwrap()
            .with_builtin_embeddings()
            .unwrap()
    }

    /// Builds a bare in-memory substrate with no embedding store.
    fn bare_substrate() -> MemorySubstrate {
        MemorySubstrate::in_memory().unwrap()
    }

    #[test]
    fn test_in_memory_creation() {
        let created = MemorySubstrate::in_memory();
        assert!(created.is_ok());
    }

    #[test]
    fn test_no_embedding_store_by_default() {
        assert!(!bare_substrate().has_embedding_store());
    }

    #[test]
    fn test_with_builtin_embeddings() {
        assert!(substrate_with_embeddings().has_embedding_store());
    }

    #[test]
    fn test_embed_and_store_without_store() {
        let stored = bare_substrate()
            .embed_and_store("hello", HashMap::new())
            .unwrap();
        assert!(stored.is_none(), "no embedding store means None");
    }

    #[test]
    fn test_semantic_search_without_store() {
        let hits = bare_substrate().semantic_search("hello", 5).unwrap();
        assert!(hits.is_none(), "no embedding store means None");
    }

    #[test]
    fn test_embed_and_store_with_builtin() {
        let stored = substrate_with_embeddings()
            .embed_and_store("test text", HashMap::new())
            .unwrap();
        assert!(stored.is_some());
    }

    #[test]
    fn test_semantic_search_with_builtin() {
        let memory = substrate_with_embeddings();
        memory
            .embed_and_store("hello world", HashMap::new())
            .unwrap();
        let hits = memory.semantic_search("hello", 5).unwrap();
        assert!(hits.is_some());
    }

    #[tokio::test]
    async fn test_conn_access() {
        let memory = bare_substrate();
        let guard = memory.conn().await;
        let table_count = guard
            .query_row(
                "SELECT COUNT(*) FROM sqlite_master WHERE type='table'",
                [],
                |row| row.get::<_, i64>(0),
            )
            .unwrap();
        assert!(table_count > 0, "should have tables from migrations");
    }
}