1use super::Database;
6use crate::error::Result;
7use chrono::Utc;
8use rusqlite::params;
9
/// Outcome of looking up a chunk embedding in the `chunk_embeddings` cache.
#[derive(Debug, Clone, PartialEq)]
pub enum CacheLookupResult {
    /// A cached embedding exists for the requested chunk hash and model.
    Hit(Vec<f32>),
    /// No cached embedding exists for the requested chunk hash and model.
    Miss,
    /// The dimensions registered for the model differ from the expected dimensions,
    /// so the cache was not consulted.
    ModelMismatch,
}
20
21impl Database {
22 pub fn ensure_vec_table(&self, _dimensions: usize) -> Result<()> {
24 self.conn.execute(
25 "CREATE TABLE IF NOT EXISTS embeddings (
26 hash_seq TEXT PRIMARY KEY,
27 embedding BLOB NOT NULL
28 )",
29 [],
30 )?;
31 Ok(())
33 }
34
35 pub fn insert_embedding(
37 &self,
38 hash: &str,
39 seq: u32,
40 pos: usize,
41 model: &str,
42 embedding: &[f32],
43 ) -> Result<()> {
44 let now = Utc::now().to_rfc3339();
45 let hash_seq = format!("{}_{}", hash, seq);
46 let embedding_bytes = embedding_to_bytes(embedding);
47
48 self.conn.execute("BEGIN IMMEDIATE", [])?;
49 let result = (|| {
50 self.conn.execute(
51 "INSERT OR REPLACE INTO content_vectors (hash, seq, pos, model, created_at)
52 VALUES (?1, ?2, ?3, ?4, ?5)",
53 params![hash, seq, pos, model, now],
54 )?;
55 self.conn.execute(
56 "INSERT OR REPLACE INTO embeddings (hash_seq, embedding) VALUES (?1, ?2)",
57 params![hash_seq, embedding_bytes],
58 )?;
59 Ok(())
60 })();
61
62 if result.is_ok() {
63 self.conn.execute("COMMIT", [])?;
64 } else {
65 let _ = self.conn.execute("ROLLBACK", []);
66 }
67 result
68 }
69
70 pub fn has_vector_index(&self) -> bool {
72 self.conn
73 .query_row("SELECT COUNT(*) FROM content_vectors", [], |row| {
74 row.get::<_, i64>(0)
75 })
76 .map(|count| count > 0)
77 .unwrap_or(false)
78 }
79
80 pub fn get_all_embeddings(&self) -> Result<Vec<(String, Vec<f32>)>> {
82 let mut stmt = self
83 .conn
84 .prepare("SELECT hash_seq, embedding FROM embeddings")?;
85
86 let results = stmt
87 .query_map([], |row| {
88 let hash_seq: String = row.get(0)?;
89 let embedding_bytes: Vec<u8> = row.get(1)?;
90 let embedding = bytes_to_embedding(&embedding_bytes);
91 Ok((hash_seq, embedding))
92 })?
93 .collect::<std::result::Result<Vec<_>, _>>()?;
94
95 Ok(results)
96 }
97
98 pub fn get_embeddings_for_collection(
100 &self,
101 collection: &str,
102 ) -> Result<Vec<(String, Vec<f32>)>> {
103 let mut stmt = self.conn.prepare(
104 "SELECT e.hash_seq, e.embedding
105 FROM embeddings e
106 JOIN content_vectors cv ON e.hash_seq = cv.hash || '_' || cv.seq
107 JOIN documents d ON d.hash = cv.hash AND d.active = 1
108 WHERE d.collection = ?1",
109 )?;
110
111 let results = stmt
112 .query_map(params![collection], |row| {
113 let hash_seq: String = row.get(0)?;
114 let embedding_bytes: Vec<u8> = row.get(1)?;
115 let embedding = bytes_to_embedding(&embedding_bytes);
116 Ok((hash_seq, embedding))
117 })?
118 .collect::<std::result::Result<Vec<_>, _>>()?;
119
120 Ok(results)
121 }
122
123 pub fn get_hashes_needing_embedding(&self) -> Result<Vec<(String, String)>> {
125 let mut stmt = self.conn.prepare(
126 "SELECT c.hash, c.doc FROM content c
127 JOIN documents d ON d.hash = c.hash AND d.active = 1
128 WHERE c.hash NOT IN (SELECT DISTINCT hash FROM content_vectors)",
129 )?;
130
131 let results = stmt
132 .query_map([], |row| Ok((row.get(0)?, row.get(1)?)))?
133 .collect::<std::result::Result<Vec<_>, _>>()?;
134
135 Ok(results)
136 }
137
138 pub fn count_hashes_needing_embedding(&self) -> Result<usize> {
140 let count: i64 = self.conn.query_row(
141 "SELECT COUNT(DISTINCT c.hash) FROM content c
142 JOIN documents d ON d.hash = c.hash AND d.active = 1
143 WHERE c.hash NOT IN (SELECT DISTINCT hash FROM content_vectors)",
144 [],
145 |row| row.get(0),
146 )?;
147 Ok(count as usize)
148 }
149
150 pub fn delete_embeddings(&self, hash: &str) -> Result<usize> {
152 let pattern = format!("{}_*", hash);
153
154 self.conn.execute("BEGIN IMMEDIATE", [])?;
155 let result = (|| {
156 self.conn
157 .execute("DELETE FROM content_vectors WHERE hash = ?1", params![hash])?;
158 let rows = self.conn.execute(
161 "DELETE FROM embeddings WHERE hash_seq GLOB ?1",
162 params![pattern],
163 )?;
164 Ok(rows)
165 })();
166
167 if result.is_ok() {
168 self.conn.execute("COMMIT", [])?;
169 } else {
170 let _ = self.conn.execute("ROLLBACK", []);
171 }
172 result
173 }
174
175 pub fn get_all_hashes_for_embedding(&self) -> Result<Vec<(String, String)>> {
177 let mut stmt = self.conn.prepare(
178 "SELECT c.hash, c.doc FROM content c
179 JOIN documents d ON d.hash = c.hash AND d.active = 1",
180 )?;
181
182 let results = stmt
183 .query_map([], |row| Ok((row.get(0)?, row.get(1)?)))?
184 .collect::<std::result::Result<Vec<_>, _>>()?;
185
186 Ok(results)
187 }
188
189 pub fn check_model_compatibility(&self, model: &str, expected_dims: usize) -> Result<bool> {
191 match self.get_model_dimensions(model)? {
192 Some(stored_dims) => Ok(stored_dims == expected_dims),
193 None => Ok(true), }
195 }
196
197 pub fn get_cached_embedding(
199 &self,
200 chunk_hash: &str,
201 model: &str,
202 expected_dims: usize,
203 ) -> Result<CacheLookupResult> {
204 if !self.check_model_compatibility(model, expected_dims)? {
205 return Ok(CacheLookupResult::ModelMismatch);
206 }
207 self.get_cached_embedding_fast(chunk_hash, model)
208 }
209
210 pub fn get_cached_embedding_fast(
212 &self,
213 chunk_hash: &str,
214 model: &str,
215 ) -> Result<CacheLookupResult> {
216 let result = self.conn.query_row(
217 "SELECT embedding FROM chunk_embeddings WHERE chunk_hash = ?1 AND model = ?2",
218 params![chunk_hash, model],
219 |row| {
220 let bytes: Vec<u8> = row.get(0)?;
221 Ok(bytes_to_embedding(&bytes))
222 },
223 );
224
225 match result {
226 Ok(embedding) => Ok(CacheLookupResult::Hit(embedding)),
227 Err(rusqlite::Error::QueryReturnedNoRows) => Ok(CacheLookupResult::Miss),
228 Err(e) => Err(e.into()),
229 }
230 }
231
232 pub fn insert_chunk_embedding(
234 &self,
235 doc_hash: &str,
236 seq: u32,
237 pos: usize,
238 chunk_hash: &str,
239 model: &str,
240 embedding: &[f32],
241 ) -> Result<()> {
242 let now = Utc::now().to_rfc3339();
243 let hash_seq = format!("{}_{}", doc_hash, seq);
244 let embedding_bytes = embedding_to_bytes(embedding);
245
246 self.conn.execute("BEGIN IMMEDIATE", [])?;
247 let result = (|| {
248 self.conn.execute(
249 "INSERT OR REPLACE INTO content_vectors (hash, seq, pos, model, chunk_hash, created_at)
250 VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
251 params![doc_hash, seq, pos, model, chunk_hash, now],
252 )?;
253 self.conn.execute(
254 "INSERT OR REPLACE INTO embeddings (hash_seq, embedding) VALUES (?1, ?2)",
255 params![hash_seq, embedding_bytes],
256 )?;
257 self.conn.execute(
258 "INSERT OR REPLACE INTO chunk_embeddings (chunk_hash, model, embedding, created_at)
259 VALUES (?1, ?2, ?3, ?4)",
260 params![chunk_hash, model, &embedding_bytes, now],
261 )?;
262 Ok(())
263 })();
264
265 if result.is_ok() {
266 self.conn.execute("COMMIT", [])?;
267 } else {
268 let _ = self.conn.execute("ROLLBACK", []);
269 }
270 result
271 }
272
273 pub fn get_chunk_hashes_for_doc(&self, doc_hash: &str) -> Result<Vec<(u32, String)>> {
275 let mut stmt = self.conn.prepare(
276 "SELECT seq, chunk_hash FROM content_vectors WHERE hash = ?1 AND chunk_hash IS NOT NULL"
277 )?;
278
279 let results = stmt
280 .query_map(params![doc_hash], |row| Ok((row.get(0)?, row.get(1)?)))?
281 .collect::<std::result::Result<Vec<_>, _>>()?;
282
283 Ok(results)
284 }
285
286 pub fn cleanup_orphaned_chunk_embeddings(&self) -> Result<usize> {
288 let count = self.conn.execute(
289 "DELETE FROM chunk_embeddings WHERE chunk_hash NOT IN (
290 SELECT DISTINCT chunk_hash FROM content_vectors WHERE chunk_hash IS NOT NULL
291 )",
292 [],
293 )?;
294 Ok(count)
295 }
296
297 pub fn register_model(&self, model: &str, dimensions: usize) -> Result<()> {
299 let now = Utc::now().to_rfc3339();
300
301 self.conn.execute(
302 "INSERT INTO model_metadata (model, dimensions, created_at, last_used_at)
303 VALUES (?1, ?2, ?3, ?3)
304 ON CONFLICT(model) DO UPDATE SET last_used_at = ?3",
305 params![model, dimensions as i64, now],
306 )?;
307
308 Ok(())
309 }
310
311 pub fn get_model_dimensions(&self, model: &str) -> Result<Option<usize>> {
313 let result = self.conn.query_row(
314 "SELECT dimensions FROM model_metadata WHERE model = ?1",
315 params![model],
316 |row| row.get::<_, i64>(0),
317 );
318
319 match result {
320 Ok(dims) => Ok(Some(dims as usize)),
321 Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
322 Err(e) => Err(e.into()),
323 }
324 }
325
326 pub fn count_cached_embeddings(&self, model: &str) -> Result<usize> {
328 let count: i64 = self.conn.query_row(
329 "SELECT COUNT(*) FROM chunk_embeddings WHERE model = ?1",
330 params![model],
331 |row| row.get(0),
332 )?;
333 Ok(count as usize)
334 }
335}
336
/// Serializes an embedding into a byte blob.
///
/// Each `f32` is written as its 4-byte little-endian representation, concatenated in
/// input order. The output is therefore exactly `4 * embedding.len()` bytes long.
pub fn embedding_to_bytes(embedding: &[f32]) -> Vec<u8> {
    let mut buffer = Vec::with_capacity(embedding.len() * 4);
    for value in embedding {
        buffer.extend_from_slice(&value.to_le_bytes());
    }
    buffer
}
341
/// Decodes a blob produced by `embedding_to_bytes` back into `f32` values.
///
/// Bytes are read in consecutive 4-byte little-endian groups. Trailing bytes that do
/// not form a complete group (blob length not a multiple of 4) are silently ignored.
pub fn bytes_to_embedding(bytes: &[u8]) -> Vec<f32> {
    let groups = bytes.chunks_exact(4);
    let mut values = Vec::with_capacity(groups.len());
    for group in groups {
        let word = [group[0], group[1], group[2], group[3]];
        values.push(f32::from_le_bytes(word));
    }
    values
}
349
/// Cosine similarity between `a` and `b`: `dot(a, b) / (|a| * |b|)`.
///
/// Returns `0.0` when the slices are empty, have different lengths, or when either
/// vector has zero magnitude, since no meaningful angle exists in those cases.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }

    let magnitude = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();
    let (mag_a, mag_b) = (magnitude(a), magnitude(b));
    if mag_a == 0.0 || mag_b == 0.0 {
        return 0.0;
    }

    let dot = a.iter().zip(b).map(|(x, y)| x * y).sum::<f32>();
    dot / (mag_a * mag_b)
}
366
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_embedding_roundtrip() {
        let original = vec![1.0f32, 2.0, 3.0, -1.5];
        let bytes = embedding_to_bytes(&original);
        // Four little-endian bytes per f32.
        assert_eq!(bytes.len(), original.len() * 4);
        let restored = bytes_to_embedding(&bytes);
        assert_eq!(original, restored);
    }

    #[test]
    fn test_embedding_roundtrip_empty() {
        assert!(embedding_to_bytes(&[]).is_empty());
        assert!(bytes_to_embedding(&[]).is_empty());
    }

    #[test]
    fn test_bytes_to_embedding_ignores_trailing_partial_value() {
        let mut bytes = embedding_to_bytes(&[0.25, -8.0]);
        bytes.extend_from_slice(&[0xAB, 0xCD]);
        assert_eq!(bytes_to_embedding(&bytes), vec![0.25, -8.0]);
    }

    #[test]
    fn test_cosine_similarity_identical() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        let sim = cosine_similarity(&a, &b);
        assert!((sim - 1.0).abs() < 0.0001);
    }

    #[test]
    fn test_cosine_similarity_orthogonal() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![0.0, 1.0, 0.0];
        let sim = cosine_similarity(&a, &b);
        assert!(sim.abs() < 0.0001);
    }

    #[test]
    fn test_cosine_similarity_opposite() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![-1.0, -2.0, -3.0];
        let sim = cosine_similarity(&a, &b);
        assert!((sim + 1.0).abs() < 0.0001);
    }

    #[test]
    fn test_cosine_similarity_degenerate_inputs_return_zero() {
        // Empty vectors, mismatched lengths, and zero-magnitude vectors have no defined angle.
        assert_eq!(cosine_similarity(&[], &[]), 0.0);
        assert_eq!(cosine_similarity(&[1.0, 2.0], &[1.0]), 0.0);
        assert_eq!(cosine_similarity(&[0.0, 0.0], &[1.0, 2.0]), 0.0);
    }
}