1use bytepunch::{Decompressor, Dictionary};
6use crate::spool::SpoolReader;
7use rusqlite::{params, Connection, OptionalExtension};
8use std::path::{Path, PathBuf};
9use std::sync::{Arc, Mutex};
10
/// A lightweight search hit: a document identifier paired with its score.
#[derive(Debug, Clone, PartialEq)]
pub struct SearchResult {
    /// Identifier of the matching document.
    pub id: String,
    /// Similarity score reported for the document.
    pub score: f32,
}
17
/// Where a stored document lives and where it came from, as persisted in the
/// `documents` table.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DocumentRef {
    /// Unique document identifier (primary key of the `documents` table).
    pub id: String,
    /// Path to the file holding the document (e.g. a `.cml.bp` file or a spool file).
    pub file_path: String,
    /// Free-form label describing the document's origin.
    pub source: String,
    /// Optional caller-supplied metadata string (JSON in the tests — TODO confirm convention).
    pub metadata: Option<String>,
    /// Offset of the document's record inside a spool file, if it is stored in one
    /// (presumably in bytes — confirm against the spool reader).
    pub spool_offset: Option<u64>,
    /// Length of the document's record inside a spool file, if it is stored in one.
    pub spool_length: Option<u32>,
}

/// A search hit: the full [`DocumentRef`] plus its cosine-similarity score.
#[derive(Debug, Clone, PartialEq)]
pub struct FileSearchResult {
    /// The matching document.
    pub doc_ref: DocumentRef,
    /// Cosine similarity to the query; higher is more similar.
    pub score: f32,
}
37
38pub struct PersistentVectorStore {
40 conn: Arc<Mutex<Connection>>,
41 dimension: usize,
42}
43
44impl PersistentVectorStore {
45 pub fn new<P: AsRef<Path>>(path: P, dimension: usize) -> rusqlite::Result<Self> {
47 let conn = Connection::open(path)?;
48
49 let store = Self {
50 conn: Arc::new(Mutex::new(conn)),
51 dimension,
52 };
53
54 store.initialize_schema()?;
55 Ok(store)
56 }
57
58 fn initialize_schema(&self) -> rusqlite::Result<()> {
60 let conn = self.conn.lock().unwrap();
61
62 conn.execute(
64 "CREATE TABLE IF NOT EXISTS documents (
65 id TEXT PRIMARY KEY,
66 file_path TEXT NOT NULL,
67 source TEXT NOT NULL,
68 metadata TEXT,
69 spool_offset INTEGER,
70 spool_length INTEGER
71 )",
72 [],
73 )?;
74
75 conn.execute(
77 "CREATE TABLE IF NOT EXISTS embeddings (
78 doc_id TEXT PRIMARY KEY,
79 vector BLOB NOT NULL,
80 FOREIGN KEY (doc_id) REFERENCES documents(id) ON DELETE CASCADE
81 )",
82 [],
83 )?;
84
85 Ok(())
89 }
90
91 pub fn add_document(
93 &self,
94 id: &str,
95 file_path: &str,
96 source: &str,
97 metadata: Option<&str>,
98 embedding: &[f32],
99 ) -> rusqlite::Result<()> {
100 self.add_document_with_spool(id, file_path, source, metadata, None, None, embedding)
101 }
102
103 pub fn add_document_with_spool(
105 &self,
106 id: &str,
107 file_path: &str,
108 source: &str,
109 metadata: Option<&str>,
110 spool_offset: Option<u64>,
111 spool_length: Option<u32>,
112 embedding: &[f32],
113 ) -> rusqlite::Result<()> {
114 if embedding.len() != self.dimension {
115 return Err(rusqlite::Error::InvalidParameterName(
116 format!("Expected {} dims, got {}", self.dimension, embedding.len()).into(),
117 ));
118 }
119
120 let conn = self.conn.lock().unwrap();
121
122 conn.execute(
124 "INSERT OR REPLACE INTO documents (id, file_path, source, metadata, spool_offset, spool_length) VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
125 params![
126 id,
127 file_path,
128 source,
129 metadata,
130 spool_offset.map(|o| o as i64),
131 spool_length.map(|l| l as i64),
132 ],
133 )?;
134
135 let vector_bytes = embedding
137 .iter()
138 .flat_map(|f| f.to_le_bytes())
139 .collect::<Vec<u8>>();
140
141 conn.execute(
142 "INSERT OR REPLACE INTO embeddings (doc_id, vector) VALUES (?1, ?2)",
143 params![id, &vector_bytes],
144 )?;
145
146 Ok(())
147 }
148
149 pub fn search(
152 &self,
153 query_embedding: &[f32],
154 top_k: usize,
155 ) -> rusqlite::Result<Vec<FileSearchResult>> {
156 if query_embedding.len() != self.dimension {
157 return Err(rusqlite::Error::InvalidParameterName(
158 format!(
159 "Expected {} dims, got {}",
160 self.dimension,
161 query_embedding.len()
162 )
163 .into(),
164 ));
165 }
166
167 let conn = self.conn.lock().unwrap();
168
169 let mut stmt = conn.prepare(
171 "SELECT d.id, d.file_path, d.source, d.metadata, d.spool_offset, d.spool_length, e.vector
172 FROM documents d
173 JOIN embeddings e ON d.id = e.doc_id"
174 )?;
175
176 let mut results: Vec<(DocumentRef, f32)> = Vec::new();
177
178 let rows = stmt.query_map([], |row| {
179 let id: String = row.get(0)?;
180 let file_path: String = row.get(1)?;
181 let source: String = row.get(2)?;
182 let metadata: Option<String> = row.get(3)?;
183 let spool_offset: Option<i64> = row.get(4)?;
184 let spool_length: Option<i64> = row.get(5)?;
185 let vector_bytes: Vec<u8> = row.get(6)?;
186
187 let embedding: Vec<f32> = vector_bytes
189 .chunks_exact(4)
190 .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
191 .collect();
192
193 let similarity = cosine_similarity(query_embedding, &embedding);
195
196 Ok((
197 DocumentRef {
198 id,
199 file_path,
200 source,
201 metadata,
202 spool_offset: spool_offset.map(|o| o as u64),
203 spool_length: spool_length.map(|l| l as u32),
204 },
205 similarity,
206 ))
207 })?;
208
209 for row in rows {
210 let (doc_ref, score) = row?;
211 results.push((doc_ref, score));
212 }
213
214 results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
216
217 Ok(results
219 .into_iter()
220 .take(top_k)
221 .map(|(doc_ref, score)| FileSearchResult { doc_ref, score })
222 .collect())
223 }
224
225 pub fn count(&self) -> rusqlite::Result<usize> {
227 let conn = self.conn.lock().unwrap();
228 let count: i64 = conn.query_row("SELECT COUNT(*) FROM documents", [], |row| row.get(0))?;
229 Ok(count as usize)
230 }
231}
232
/// Cosine similarity of `a` and `b`: their dot product divided by the product
/// of their Euclidean norms.
///
/// Returns `0.0` when either vector has zero norm, which includes empty slices.
/// The dot product only spans the shorter slice, while each norm spans its whole
/// slice, so callers are expected to pass vectors of equal length.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let euclidean_norm = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();

    let (len_a, len_b) = (euclidean_norm(a), euclidean_norm(b));
    if len_a == 0.0 || len_b == 0.0 {
        return 0.0;
    }

    let dot = a.iter().zip(b).map(|(x, y)| x * y).sum::<f32>();
    dot / (len_a * len_b)
}
245
#[cfg(test)]
mod tests {
    use super::*;

    /// Round-trips a plain document and finds it by its own embedding.
    #[test]
    fn test_persistent_store() {
        let store = PersistentVectorStore::new(":memory:", 384).unwrap();

        let embedding = vec![0.1f32; 384];
        store
            .add_document(
                "test-1",
                "docs/test-1.cml.bp",
                "test",
                Some("{\"type\": \"test\"}"),
                &embedding,
            )
            .unwrap();

        assert_eq!(store.count().unwrap(), 1);

        let results = store.search(&embedding, 5).unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].doc_ref.id, "test-1");
        assert_eq!(results[0].doc_ref.file_path, "docs/test-1.cml.bp");
        assert!(results[0].score > 0.99);
    }

    /// Spool offset and length survive the round trip through SQLite.
    #[test]
    fn test_persistent_store_with_spool() {
        let store = PersistentVectorStore::new(":memory:", 384).unwrap();

        let embedding = vec![0.2f32; 384];
        store
            .add_document_with_spool(
                "test-2",
                "docs/test.spool",
                "test",
                Some("{\"type\": \"spool\"}"),
                Some(1024),
                Some(512),
                &embedding,
            )
            .unwrap();

        assert_eq!(store.count().unwrap(), 1);

        let results = store.search(&embedding, 5).unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].doc_ref.id, "test-2");
        assert_eq!(results[0].doc_ref.spool_offset, Some(1024));
        assert_eq!(results[0].doc_ref.spool_length, Some(512));
    }

    /// Embeddings and queries of the wrong length are rejected without writing anything.
    #[test]
    fn test_rejects_wrong_dimension() {
        let store = PersistentVectorStore::new(":memory:", 4).unwrap();

        assert!(store
            .add_document("bad", "docs/bad.cml.bp", "test", None, &[1.0f32, 2.0])
            .is_err());
        assert!(store.search(&[1.0f32; 3], 5).is_err());
        assert_eq!(store.count().unwrap(), 0);
    }

    /// Results are ordered by descending similarity and cut off at `top_k`.
    #[test]
    fn test_search_ranks_by_similarity_and_respects_top_k() {
        let store = PersistentVectorStore::new(":memory:", 3).unwrap();
        store
            .add_document("x", "docs/x.cml.bp", "test", None, &[1.0f32, 0.0, 0.0])
            .unwrap();
        store
            .add_document("y", "docs/y.cml.bp", "test", None, &[0.0f32, 1.0, 0.0])
            .unwrap();
        store
            .add_document("xy", "docs/xy.cml.bp", "test", None, &[1.0f32, 1.0, 0.0])
            .unwrap();

        let results = store.search(&[1.0f32, 0.1, 0.0], 2).unwrap();
        let ids: Vec<&str> = results.iter().map(|r| r.doc_ref.id.as_str()).collect();
        assert_eq!(ids, ["x", "xy"]);
    }

    /// Re-adding an existing id replaces both the document row and its embedding.
    #[test]
    fn test_re_adding_document_replaces_it() {
        let store = PersistentVectorStore::new(":memory:", 2).unwrap();
        store
            .add_document("doc", "docs/old.cml.bp", "test", None, &[1.0f32, 0.0])
            .unwrap();
        store
            .add_document("doc", "docs/new.cml.bp", "test", None, &[0.0f32, 1.0])
            .unwrap();

        assert_eq!(store.count().unwrap(), 1);

        let results = store.search(&[0.0f32, 1.0], 5).unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].doc_ref.file_path, "docs/new.cml.bp");
        assert!(results[0].score > 0.99);
    }
}