1use serde::{Deserialize, Serialize};
2
3use crate::{Database, DbResultExt};
4use roboticus_core::Result;
5
6#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct EmbeddingEntry {
8 pub id: String,
9 pub source_table: String,
10 pub source_id: String,
11 pub content_preview: String,
12 pub embedding: Vec<f32>,
13 pub created_at: String,
14}
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct SearchResult {
18 pub source_table: String,
19 pub source_id: String,
20 pub content_preview: String,
21 pub similarity: f64,
22}
23
/// Computes the cosine similarity between two embedding vectors.
///
/// Components are widened to `f64` before accumulation, and the result is
/// clamped to `[-1.0, 1.0]` so that rounding error can never push it outside
/// the mathematical range of a cosine.
///
/// Returns `0.0` when:
/// - the slices differ in length or are empty,
/// - either vector has zero magnitude, or
/// - the computation is not finite (some component is `NaN` or infinite).
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f64 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }

    // Single pass: dot product and both squared norms, accumulated in order.
    let (dot, norm_a, norm_b) = a.iter().zip(b).fold(
        (0.0f64, 0.0f64, 0.0f64),
        |(dot, norm_a, norm_b), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, norm_a + x * x, norm_b + y * y)
        },
    );

    let denom = norm_a.sqrt() * norm_b.sqrt();
    if denom == 0.0 {
        return 0.0;
    }

    // NaN/infinite inputs poison the quotient; report "no similarity" rather
    // than handing callers a value that breaks ordering and comparisons.
    let similarity = dot / denom;
    if similarity.is_finite() {
        similarity.clamp(-1.0, 1.0)
    } else {
        0.0
    }
}
44
/// Serializes an embedding into a flat byte buffer suitable for a BLOB column.
///
/// Each component is written, in order, as its 4-byte little-endian
/// representation, so the result is exactly `4 * embedding.len()` bytes long.
pub fn embedding_to_blob(embedding: &[f32]) -> Vec<u8> {
    let mut blob = Vec::with_capacity(embedding.len() * std::mem::size_of::<f32>());
    blob.extend(
        embedding
            .iter()
            .flat_map(|component| component.to_le_bytes()),
    );
    blob
}
53
/// Decodes a byte buffer produced by [`embedding_to_blob`] back into floats.
///
/// The input is consumed in 4-byte little-endian groups. Trailing bytes that
/// do not fill a complete group are ignored.
pub fn blob_to_embedding(blob: &[u8]) -> Vec<f32> {
    let mut values = Vec::with_capacity(blob.len() / 4);
    for group in blob.chunks_exact(4) {
        let mut raw = [0u8; 4];
        raw.copy_from_slice(group);
        values.push(f32::from_le_bytes(raw));
    }
    values
}
60
61pub fn store_embedding(
63 db: &Database,
64 id: &str,
65 source_table: &str,
66 source_id: &str,
67 content_preview: &str,
68 embedding: &[f32],
69) -> Result<()> {
70 let blob = embedding_to_blob(embedding);
71 let dimensions = embedding.len() as i64;
72
73 let conn = db.conn();
74 conn.execute(
75 "INSERT OR REPLACE INTO embeddings \
76 (id, source_table, source_id, content_preview, embedding_blob, dimensions) \
77 VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
78 rusqlite::params![
79 id,
80 source_table,
81 source_id,
82 content_preview,
83 blob,
84 dimensions
85 ],
86 )
87 .db_err()?;
88
89 Ok(())
90}
91
92fn load_embedding_from_row(blob: Option<Vec<u8>>) -> Option<Vec<f32>> {
94 if let Some(b) = blob
95 && !b.is_empty()
96 {
97 return Some(blob_to_embedding(&b));
98 }
99 None
100}
101
102pub fn search_similar(
111 db: &Database,
112 query_embedding: &[f32],
113 limit: usize,
114 min_similarity: f64,
115) -> Result<Vec<SearchResult>> {
116 let conn = db.conn();
117 let mut stmt = conn
118 .prepare(
119 "SELECT source_table, source_id, content_preview, embedding_blob \
120 FROM embeddings LIMIT 10000",
121 )
122 .db_err()?;
123
124 let rows = stmt
125 .query_map([], |row| {
126 Ok((
127 row.get::<_, String>(0)?,
128 row.get::<_, String>(1)?,
129 row.get::<_, String>(2)?,
130 row.get::<_, Option<Vec<u8>>>(3)?,
131 ))
132 })
133 .db_err()?;
134
135 let mut results: Vec<SearchResult> = Vec::new();
136
137 for row in rows {
138 let (source_table, source_id, content_preview, blob) = row.db_err()?;
139
140 let embedding = match load_embedding_from_row(blob) {
141 Some(e) => e,
142 None => continue,
143 };
144
145 let similarity = cosine_similarity(query_embedding, &embedding);
146
147 if similarity >= min_similarity {
148 results.push(SearchResult {
149 source_table,
150 source_id,
151 content_preview,
152 similarity,
153 });
154 }
155 }
156
157 results.sort_by(|a, b| {
158 b.similarity
159 .partial_cmp(&a.similarity)
160 .unwrap_or(std::cmp::Ordering::Equal)
161 });
162 results.truncate(limit);
163
164 Ok(results)
165}
166
167pub fn hybrid_search(
168 db: &Database,
169 query_text: &str,
170 query_embedding: Option<&[f32]>,
171 limit: usize,
172 hybrid_weight: f64,
173) -> Result<Vec<SearchResult>> {
174 let mut fts_results: Vec<SearchResult> = Vec::new();
175
176 {
177 let conn = db.conn();
178 let safe_query = crate::memory::sanitize_fts_query(query_text);
179 let mut stmt = conn
180 .prepare("SELECT content, category FROM memory_fts WHERE memory_fts MATCH ?1 LIMIT ?2")
181 .db_err()?;
182
183 let rows = stmt
184 .query_map(rusqlite::params![safe_query, limit * 2], |row| {
185 Ok((row.get::<_, String>(0)?, row.get::<_, String>(1)?))
186 })
187 .db_err()?;
188
189 for (i, row) in rows.enumerate() {
190 let (content, category) = row.db_err()?;
191 let fts_score = 1.0 - (i as f64 * 0.05).min(0.9);
192 fts_results.push(SearchResult {
193 source_table: category,
194 source_id: String::new(),
195 content_preview: content.chars().take(200).collect(),
196 similarity: fts_score * (1.0 - hybrid_weight),
197 });
198 }
199 }
200
201 if let Some(embedding) = query_embedding {
202 let vec_results = search_similar(db, embedding, limit * 2, 0.0)?;
203 for mut r in vec_results {
204 r.similarity *= hybrid_weight;
205 fts_results.push(r);
206 }
207 }
208
209 fts_results.sort_by(|a, b| {
210 b.similarity
211 .partial_cmp(&a.similarity)
212 .unwrap_or(std::cmp::Ordering::Equal)
213 });
214 fts_results.truncate(limit);
215
216 Ok(fts_results)
217}
218
219pub fn cleanup_orphaned_embeddings(db: &Database) -> Result<usize> {
226 let conn = db.conn();
227 let deleted = conn
228 .execute(
229 "DELETE FROM embeddings WHERE NOT ( \
230 (source_table = 'working_memory' AND source_id IN (SELECT id FROM working_memory)) \
231 OR (source_table = 'episodic_memory' AND source_id IN (SELECT id FROM episodic_memory)) \
232 OR (source_table = 'semantic_memory' AND source_id IN (SELECT id FROM semantic_memory)) \
233 OR (source_table = 'procedural_memory' AND source_id IN (SELECT id FROM procedural_memory)) \
234 OR (source_table = 'relationship_memory' AND source_id IN (SELECT id FROM relationship_memory)) \
235 )",
236 [],
237 )
238 .db_err()?;
239 Ok(deleted)
240}
241
/// Test-only helper: number of rows currently in the `embeddings` table.
#[cfg(test)]
pub(crate) fn embedding_count(db: &Database) -> Result<usize> {
    let total = db
        .conn()
        .query_row("SELECT COUNT(*) FROM embeddings", [], |row| {
            row.get::<_, usize>(0)
        })
        .db_err()?;

    Ok(total)
}
250
#[cfg(test)]
mod tests {
    //! Unit tests for blob encoding, cosine similarity, vector search,
    //! hybrid search and orphan cleanup. Database tests run against a fresh
    //! in-memory database.

    use super::*;

    /// Opens a fresh in-memory database for a single test.
    fn test_db() -> Database {
        Database::new(":memory:").unwrap()
    }

    // ── blob encoding ───────────────────────────────────────────────────

    #[test]
    fn blob_roundtrip() {
        let original = vec![1.0f32, -0.5, 0.0, 1.23456, f32::MIN, f32::MAX];
        let blob = embedding_to_blob(&original);
        let restored = blob_to_embedding(&blob);
        assert_eq!(original, restored);
    }

    #[test]
    fn blob_empty() {
        let blob = embedding_to_blob(&[]);
        assert!(blob.is_empty());
        let restored = blob_to_embedding(&blob);
        assert!(restored.is_empty());
    }

    #[test]
    fn blob_size_is_4x_floats() {
        let emb = vec![0.0f32; 768];
        let blob = embedding_to_blob(&emb);
        assert_eq!(blob.len(), 768 * 4);
    }

    // ── cosine similarity ───────────────────────────────────────────────

    #[test]
    fn cosine_identical_vectors() {
        let v = vec![1.0, 2.0, 3.0];
        let sim = cosine_similarity(&v, &v);
        assert!((sim - 1.0).abs() < 1e-6);
    }

    #[test]
    fn cosine_orthogonal_vectors() {
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        let sim = cosine_similarity(&a, &b);
        assert!(sim.abs() < 1e-6);
    }

    #[test]
    fn cosine_opposite_vectors() {
        let a = vec![1.0, 0.0];
        let b = vec![-1.0, 0.0];
        let sim = cosine_similarity(&a, &b);
        assert!((sim - (-1.0)).abs() < 1e-6);
    }

    #[test]
    fn cosine_empty_vectors() {
        let sim = cosine_similarity(&[], &[]);
        assert_eq!(sim, 0.0);
    }

    #[test]
    fn cosine_mismatched_lengths() {
        let a = vec![1.0, 2.0];
        let b = vec![1.0];
        let sim = cosine_similarity(&a, &b);
        assert_eq!(sim, 0.0);
    }

    // ── store / search ──────────────────────────────────────────────────

    #[test]
    fn store_and_search() {
        let db = test_db();
        let emb1 = vec![1.0, 0.0, 0.0];
        let emb2 = vec![0.0, 1.0, 0.0];
        let emb3 = vec![0.9, 0.1, 0.0];

        store_embedding(&db, "e1", "episodic_memory", "ep1", "first entry", &emb1).unwrap();
        store_embedding(&db, "e2", "episodic_memory", "ep2", "second entry", &emb2).unwrap();
        store_embedding(&db, "e3", "semantic_memory", "s1", "third entry", &emb3).unwrap();

        let query = vec![1.0, 0.0, 0.0];
        let results = search_similar(&db, &query, 10, 0.5).unwrap();

        // The orthogonal vector (e2) falls below the 0.5 threshold.
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].source_id, "ep1");
        assert!((results[0].similarity - 1.0).abs() < 1e-6);
        assert!(results[1].similarity > 0.5);
    }

    #[test]
    fn store_replaces_existing() {
        let db = test_db();
        let emb1 = vec![1.0, 0.0];
        let emb2 = vec![0.0, 1.0];
        store_embedding(&db, "e1", "episodic_memory", "t1", "v1", &emb1).unwrap();
        store_embedding(&db, "e1", "episodic_memory", "t1", "v2", &emb2).unwrap();
        assert_eq!(embedding_count(&db).unwrap(), 1);
    }

    #[test]
    fn search_min_similarity_filter() {
        let db = test_db();
        store_embedding(&db, "e1", "episodic_memory", "1", "a", &[1.0, 0.0]).unwrap();
        store_embedding(&db, "e2", "episodic_memory", "2", "b", &[0.0, 1.0]).unwrap();

        let results = search_similar(&db, &[1.0, 0.0], 10, 0.99).unwrap();
        assert_eq!(results.len(), 1);
    }

    #[test]
    fn embedding_count_works() {
        let db = test_db();
        assert_eq!(embedding_count(&db).unwrap(), 0);
        store_embedding(&db, "e1", "episodic_memory", "1", "a", &[1.0]).unwrap();
        assert_eq!(embedding_count(&db).unwrap(), 1);
    }

    #[test]
    fn cosine_zero_vector() {
        let a = vec![0.0, 0.0];
        let b = vec![1.0, 0.0];
        assert_eq!(cosine_similarity(&a, &b), 0.0);
    }

    // ── hybrid search (vector side) ─────────────────────────────────────

    #[test]
    fn hybrid_search_vector_only() {
        let db = test_db();
        store_embedding(
            &db,
            "e1",
            "episodic_memory",
            "t1",
            "hello world",
            &[1.0, 0.0, 0.0],
        )
        .unwrap();
        store_embedding(
            &db,
            "e2",
            "episodic_memory",
            "t2",
            "goodbye",
            &[0.0, 1.0, 0.0],
        )
        .unwrap();

        let results =
            hybrid_search(&db, "zzzznonexistent", Some(&[1.0, 0.0, 0.0]), 10, 0.5).unwrap();
        assert!(!results.is_empty());
    }

    #[test]
    fn hybrid_search_empty_db() {
        let db = test_db();
        let results = hybrid_search(&db, "anything", Some(&[1.0, 0.0]), 10, 0.5).unwrap();
        assert!(results.is_empty());
    }

    #[test]
    fn hybrid_search_respects_limit() {
        let db = test_db();
        for i in 0..20 {
            store_embedding(
                &db,
                &format!("e{i}"),
                "episodic_memory",
                &format!("t{i}"),
                &format!("entry {i}"),
                &[1.0, 0.0],
            )
            .unwrap();
        }
        let results = hybrid_search(&db, "entry", Some(&[1.0, 0.0]), 5, 0.5).unwrap();
        assert!(results.len() <= 5);
    }

    #[test]
    fn hybrid_search_no_embedding() {
        let db = test_db();
        store_embedding(
            &db,
            "e1",
            "episodic_memory",
            "t1",
            "hello world",
            &[1.0, 0.0],
        )
        .unwrap();
        let results = hybrid_search(&db, "hello", None, 10, 0.5).unwrap();

        // Without a query embedding the vector path must not run, so any hit
        // can only come from FTS, which always reports an empty `source_id`.
        // A leaked vector hit would carry "t1" here.
        assert!(results.len() <= 10);
        assert!(results.iter().all(|r| r.source_id.is_empty()));
    }

    #[test]
    fn hybrid_search_sorted_by_similarity() {
        let db = test_db();
        store_embedding(
            &db,
            "e1",
            "episodic_memory",
            "t1",
            "first",
            &[1.0, 0.0, 0.0],
        )
        .unwrap();
        store_embedding(
            &db,
            "e2",
            "episodic_memory",
            "t2",
            "second",
            &[0.5, 0.5, 0.0],
        )
        .unwrap();
        store_embedding(
            &db,
            "e3",
            "episodic_memory",
            "t3",
            "third",
            &[0.0, 0.0, 1.0],
        )
        .unwrap();

        let results = hybrid_search(&db, "query", Some(&[1.0, 0.0, 0.0]), 10, 1.0).unwrap();
        for w in results.windows(2) {
            assert!(w[0].similarity >= w[1].similarity);
        }
    }

    // ── row decoding ────────────────────────────────────────────────────

    #[test]
    fn load_embedding_from_blob() {
        let emb = vec![1.0f32, 2.0, 3.0];
        let blob = embedding_to_blob(&emb);
        let loaded = load_embedding_from_row(Some(blob)).unwrap();
        assert_eq!(loaded, emb);
    }

    #[test]
    fn load_embedding_none_returns_none() {
        let loaded = load_embedding_from_row(None);
        assert!(loaded.is_none());
    }

    #[test]
    fn load_embedding_empty_blob_returns_none() {
        let loaded = load_embedding_from_row(Some(vec![]));
        assert!(loaded.is_none());
    }

    #[test]
    fn search_similar_skips_row_without_embedding() {
        let db = test_db();
        {
            // Insert a row with a NULL blob directly; `store_embedding`
            // always writes a blob, so it cannot produce this state.
            let conn = db.conn();
            conn.execute(
                "INSERT INTO embeddings (id, source_table, source_id, content_preview, embedding_blob, dimensions) \
                 VALUES ('e-no-emb', 'episodic_memory', 't1', 'no embedding here', NULL, 0)",
                [],
            ).unwrap();
        }
        store_embedding(
            &db,
            "e-real",
            "episodic_memory",
            "t2",
            "has embedding",
            &[1.0, 0.0],
        )
        .unwrap();

        let results = search_similar(&db, &[1.0, 0.0], 10, 0.0).unwrap();
        // Only the row that actually has a vector may be returned.
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].source_id, "t2");
    }

    // ── hybrid search (FTS side) ────────────────────────────────────────

    #[test]
    fn hybrid_search_fts_matches() {
        let db = test_db();
        crate::memory::store_working(&db, "sess", "note", "quantum computing breakthrough", 5)
            .unwrap();
        store_embedding(
            &db,
            "e1",
            "episodic_memory",
            "t1",
            "classical computing",
            &[0.0, 1.0],
        )
        .unwrap();

        let results = hybrid_search(&db, "quantum", Some(&[1.0, 0.0]), 10, 0.5).unwrap();
        assert!(
            !results.is_empty(),
            "hybrid search should find FTS match for 'quantum'"
        );
    }

    #[test]
    fn hybrid_search_fts_only_no_embedding() {
        let db = test_db();
        crate::memory::store_working(&db, "sess", "note", "unique identifier xyzzy", 5).unwrap();

        let results = hybrid_search(&db, "xyzzy", None, 10, 0.5).unwrap();
        assert!(
            !results.is_empty(),
            "hybrid search without embedding should find FTS results"
        );
    }

    #[test]
    fn hybrid_search_combined_scores() {
        let db = test_db();
        crate::memory::store_working(&db, "sess", "note", "machine learning algorithms", 5)
            .unwrap();
        store_embedding(
            &db,
            "e1",
            "episodic_memory",
            "t1",
            "machine learning",
            &[1.0, 0.0, 0.0],
        )
        .unwrap();

        let results = hybrid_search(&db, "machine", Some(&[1.0, 0.0, 0.0]), 10, 0.5).unwrap();
        assert!(!results.is_empty());
        // The merged list must come back in descending score order.
        for w in results.windows(2) {
            assert!(w[0].similarity >= w[1].similarity);
        }
    }

    // ── orphan cleanup ──────────────────────────────────────────────────

    #[test]
    fn cleanup_orphaned_embeddings_removes_dangling() {
        let db = test_db();
        crate::memory::store_working(&db, "s1", "note", "valid", 5).unwrap();
        // Look up the id the working-memory row actually received.
        let wm_id = {
            let conn = db.conn();
            conn.query_row("SELECT id FROM working_memory LIMIT 1", [], |r| {
                r.get::<_, String>(0)
            })
            .unwrap()
        };
        store_embedding(
            &db,
            "e-valid",
            "working_memory",
            &wm_id,
            "valid",
            &[1.0, 0.0],
        )
        .unwrap();

        // This one points at a working-memory id that does not exist.
        store_embedding(
            &db,
            "e-orphan",
            "working_memory",
            "no-such-id",
            "orphan",
            &[0.0, 1.0],
        )
        .unwrap();

        assert_eq!(embedding_count(&db).unwrap(), 2);
        let deleted = cleanup_orphaned_embeddings(&db).unwrap();
        assert_eq!(deleted, 1);
        assert_eq!(embedding_count(&db).unwrap(), 1);
    }

    #[test]
    fn cleanup_orphaned_embeddings_noop_when_clean() {
        let db = test_db();
        crate::memory::store_semantic(&db, "facts", "k1", "v1", 0.9).unwrap();
        let sem_id = {
            let conn = db.conn();
            conn.query_row("SELECT id FROM semantic_memory LIMIT 1", [], |r| {
                r.get::<_, String>(0)
            })
            .unwrap()
        };
        store_embedding(&db, "e1", "semantic_memory", &sem_id, "valid", &[1.0, 0.0]).unwrap();

        let deleted = cleanup_orphaned_embeddings(&db).unwrap();
        assert_eq!(deleted, 0);
    }
}