Skip to main content

roboticus_db/
embeddings.rs

1use serde::{Deserialize, Serialize};
2
3use crate::{Database, DbResultExt};
4use roboticus_core::Result;
5
6#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct EmbeddingEntry {
8    pub id: String,
9    pub source_table: String,
10    pub source_id: String,
11    pub content_preview: String,
12    pub embedding: Vec<f32>,
13    pub created_at: String,
14}
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct SearchResult {
18    pub source_table: String,
19    pub source_id: String,
20    pub content_preview: String,
21    pub similarity: f64,
22}
23
/// Cosine similarity of two vectors, accumulated in `f64`.
///
/// Returns `0.0` when the slices differ in length, when they are empty, or
/// when either vector has zero magnitude. Otherwise the quotient is clamped to
/// `[-1.0, 1.0]`, so floating-point rounding can never report a value slightly
/// outside the mathematical range (which would, for example, turn `acos` into
/// NaN for callers). Non-finite components (NaN or infinity) propagate as NaN.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f64 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }

    // Single pass over both slices: dot product and both squared norms.
    let (dot, norm_a, norm_b) = a.iter().zip(b).fold(
        (0.0f64, 0.0f64, 0.0f64),
        |(dot, norm_a, norm_b), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, norm_a + x * x, norm_b + y * y)
        },
    );

    let denom = norm_a.sqrt() * norm_b.sqrt();
    if denom == 0.0 {
        0.0
    } else {
        (dot / denom).clamp(-1.0, 1.0)
    }
}
44
/// Encode an embedding as raw bytes.
///
/// Each `f32` is written as its four little-endian bytes, in order, with no
/// header or length prefix, so the output is exactly `4 * embedding.len()`
/// bytes long.
pub fn embedding_to_blob(embedding: &[f32]) -> Vec<u8> {
    let mut blob = Vec::with_capacity(std::mem::size_of_val(embedding));
    blob.extend(embedding.iter().flat_map(|component| component.to_le_bytes()));
    blob
}
53
/// Decode a byte blob into `f32` values.
///
/// The blob is read as consecutive 4-byte little-endian floats. Any trailing
/// bytes that do not form a complete 4-byte group are ignored, so a blob
/// shorter than four bytes decodes to an empty vector.
pub fn blob_to_embedding(blob: &[u8]) -> Vec<f32> {
    const WIDTH: usize = std::mem::size_of::<f32>();

    let mut values = Vec::with_capacity(blob.len() / WIDTH);
    for word in blob.chunks_exact(WIDTH) {
        let mut raw = [0u8; WIDTH];
        raw.copy_from_slice(word);
        values.push(f32::from_le_bytes(raw));
    }
    values
}
60
61/// Store an embedding using binary BLOB format.
62pub fn store_embedding(
63    db: &Database,
64    id: &str,
65    source_table: &str,
66    source_id: &str,
67    content_preview: &str,
68    embedding: &[f32],
69) -> Result<()> {
70    let blob = embedding_to_blob(embedding);
71    let dimensions = embedding.len() as i64;
72
73    let conn = db.conn();
74    conn.execute(
75        "INSERT OR REPLACE INTO embeddings \
76         (id, source_table, source_id, content_preview, embedding_blob, dimensions) \
77         VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
78        rusqlite::params![
79            id,
80            source_table,
81            source_id,
82            content_preview,
83            blob,
84            dimensions
85        ],
86    )
87    .db_err()?;
88
89    Ok(())
90}
91
92/// Load an embedding from a BLOB column value.
93fn load_embedding_from_row(blob: Option<Vec<u8>>) -> Option<Vec<f32>> {
94    if let Some(b) = blob
95        && !b.is_empty()
96    {
97        return Some(blob_to_embedding(&b));
98    }
99    None
100}
101
102/// Brute-force cosine similarity search over all stored embeddings.
103///
104/// **Complexity**: O(N) where N is the number of stored embeddings. Every row is loaded
105/// into memory and compared. For production workloads with large embedding tables,
106/// use `AnnIndex` (approximate nearest neighbor) instead.
107///
108/// A `LIMIT 10000` cap is applied to the SQL query to prevent unbounded memory usage
109/// while the AnnIndex integration is pending.
110pub fn search_similar(
111    db: &Database,
112    query_embedding: &[f32],
113    limit: usize,
114    min_similarity: f64,
115) -> Result<Vec<SearchResult>> {
116    let conn = db.conn();
117    let mut stmt = conn
118        .prepare(
119            "SELECT source_table, source_id, content_preview, embedding_blob \
120             FROM embeddings LIMIT 10000",
121        )
122        .db_err()?;
123
124    let rows = stmt
125        .query_map([], |row| {
126            Ok((
127                row.get::<_, String>(0)?,
128                row.get::<_, String>(1)?,
129                row.get::<_, String>(2)?,
130                row.get::<_, Option<Vec<u8>>>(3)?,
131            ))
132        })
133        .db_err()?;
134
135    let mut results: Vec<SearchResult> = Vec::new();
136
137    for row in rows {
138        let (source_table, source_id, content_preview, blob) = row.db_err()?;
139
140        let embedding = match load_embedding_from_row(blob) {
141            Some(e) => e,
142            None => continue,
143        };
144
145        let similarity = cosine_similarity(query_embedding, &embedding);
146
147        if similarity >= min_similarity {
148            results.push(SearchResult {
149                source_table,
150                source_id,
151                content_preview,
152                similarity,
153            });
154        }
155    }
156
157    results.sort_by(|a, b| {
158        b.similarity
159            .partial_cmp(&a.similarity)
160            .unwrap_or(std::cmp::Ordering::Equal)
161    });
162    results.truncate(limit);
163
164    Ok(results)
165}
166
167pub fn hybrid_search(
168    db: &Database,
169    query_text: &str,
170    query_embedding: Option<&[f32]>,
171    limit: usize,
172    hybrid_weight: f64,
173) -> Result<Vec<SearchResult>> {
174    let mut fts_results: Vec<SearchResult> = Vec::new();
175
176    {
177        let conn = db.conn();
178        let safe_query = crate::memory::sanitize_fts_query(query_text);
179        let mut stmt = conn
180            .prepare("SELECT content, category FROM memory_fts WHERE memory_fts MATCH ?1 LIMIT ?2")
181            .db_err()?;
182
183        let rows = stmt
184            .query_map(rusqlite::params![safe_query, limit * 2], |row| {
185                Ok((row.get::<_, String>(0)?, row.get::<_, String>(1)?))
186            })
187            .db_err()?;
188
189        for (i, row) in rows.enumerate() {
190            let (content, category) = row.db_err()?;
191            let fts_score = 1.0 - (i as f64 * 0.05).min(0.9);
192            fts_results.push(SearchResult {
193                source_table: category,
194                source_id: String::new(),
195                content_preview: content.chars().take(200).collect(),
196                similarity: fts_score * (1.0 - hybrid_weight),
197            });
198        }
199    }
200
201    if let Some(embedding) = query_embedding {
202        let vec_results = search_similar(db, embedding, limit * 2, 0.0)?;
203        for mut r in vec_results {
204            r.similarity *= hybrid_weight;
205            fts_results.push(r);
206        }
207    }
208
209    fts_results.sort_by(|a, b| {
210        b.similarity
211            .partial_cmp(&a.similarity)
212            .unwrap_or(std::cmp::Ordering::Equal)
213    });
214    fts_results.truncate(limit);
215
216    Ok(fts_results)
217}
218
219/// Delete embeddings whose `source_table + source_id` no longer reference an
220/// existing row in the parent memory table.
221///
222/// Checks `working_memory`, `episodic_memory`, `semantic_memory`,
223/// `procedural_memory`, and `relationship_memory`.  Returns the number of
224/// orphaned rows removed.
225pub fn cleanup_orphaned_embeddings(db: &Database) -> Result<usize> {
226    let conn = db.conn();
227    let deleted = conn
228        .execute(
229            "DELETE FROM embeddings WHERE NOT ( \
230               (source_table = 'working_memory'      AND source_id IN (SELECT id FROM working_memory)) \
231            OR (source_table = 'episodic_memory'      AND source_id IN (SELECT id FROM episodic_memory)) \
232            OR (source_table = 'semantic_memory'      AND source_id IN (SELECT id FROM semantic_memory)) \
233            OR (source_table = 'procedural_memory'    AND source_id IN (SELECT id FROM procedural_memory)) \
234            OR (source_table = 'relationship_memory'  AND source_id IN (SELECT id FROM relationship_memory)) \
235            )",
236            [],
237        )
238        .db_err()?;
239    Ok(deleted)
240}
241
/// Test-only helper: total number of rows in the `embeddings` table.
#[cfg(test)]
pub(crate) fn embedding_count(db: &Database) -> Result<usize> {
    let total = db
        .conn()
        .query_row("SELECT COUNT(*) FROM embeddings", [], |row| {
            row.get::<_, usize>(0)
        })
        .db_err()?;
    Ok(total)
}
250
#[cfg(test)]
mod tests {
    // Unit tests for blob encoding, cosine similarity, vector and hybrid
    // search, and orphan cleanup. Every database test runs against a fresh
    // in-memory database.
    use super::*;

    /// Fresh in-memory database for a single test.
    fn test_db() -> Database {
        Database::new(":memory:").unwrap()
    }

    // ── Blob encoding ────────────────────────────────────────

    // Encoding then decoding must reproduce the values exactly, extremes included.
    #[test]
    fn blob_roundtrip() {
        let original = vec![1.0f32, -0.5, 0.0, 1.23456, f32::MIN, f32::MAX];
        let blob = embedding_to_blob(&original);
        let restored = blob_to_embedding(&blob);
        assert_eq!(original, restored);
    }

    // An empty embedding maps to an empty blob and back.
    #[test]
    fn blob_empty() {
        let blob = embedding_to_blob(&[]);
        assert!(blob.is_empty());
        let restored = blob_to_embedding(&blob);
        assert!(restored.is_empty());
    }

    // Each f32 occupies exactly four bytes in the blob.
    #[test]
    fn blob_size_is_4x_floats() {
        let emb = vec![0.0f32; 768];
        let blob = embedding_to_blob(&emb);
        assert_eq!(blob.len(), 768 * 4);
    }

    // ── Cosine similarity ────────────────────────────────────

    // A vector compared with itself scores 1.
    #[test]
    fn cosine_identical_vectors() {
        let v = vec![1.0, 2.0, 3.0];
        let sim = cosine_similarity(&v, &v);
        assert!((sim - 1.0).abs() < 1e-6);
    }

    // Perpendicular vectors score 0.
    #[test]
    fn cosine_orthogonal_vectors() {
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        let sim = cosine_similarity(&a, &b);
        assert!(sim.abs() < 1e-6);
    }

    // Vectors pointing in opposite directions score -1.
    #[test]
    fn cosine_opposite_vectors() {
        let a = vec![1.0, 0.0];
        let b = vec![-1.0, 0.0];
        let sim = cosine_similarity(&a, &b);
        assert!((sim - (-1.0)).abs() < 1e-6);
    }

    // Empty input is defined as similarity 0.
    #[test]
    fn cosine_empty_vectors() {
        let sim = cosine_similarity(&[], &[]);
        assert_eq!(sim, 0.0);
    }

    // Vectors of different lengths are defined as similarity 0.
    #[test]
    fn cosine_mismatched_lengths() {
        let a = vec![1.0, 2.0];
        let b = vec![1.0];
        let sim = cosine_similarity(&a, &b);
        assert_eq!(sim, 0.0);
    }

    // ── Store and vector search ──────────────────────────────

    // Only the two vectors close to the query clear the 0.5 threshold, best first.
    #[test]
    fn store_and_search() {
        let db = test_db();
        let emb1 = vec![1.0, 0.0, 0.0];
        let emb2 = vec![0.0, 1.0, 0.0];
        let emb3 = vec![0.9, 0.1, 0.0];

        store_embedding(&db, "e1", "episodic_memory", "ep1", "first entry", &emb1).unwrap();
        store_embedding(&db, "e2", "episodic_memory", "ep2", "second entry", &emb2).unwrap();
        store_embedding(&db, "e3", "semantic_memory", "s1", "third entry", &emb3).unwrap();

        let query = vec![1.0, 0.0, 0.0];
        let results = search_similar(&db, &query, 10, 0.5).unwrap();

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].source_id, "ep1");
        assert!((results[0].similarity - 1.0).abs() < 1e-6);
        assert!(results[1].similarity > 0.5);
    }

    // Storing under an existing id replaces the row instead of adding one.
    #[test]
    fn store_replaces_existing() {
        let db = test_db();
        let emb1 = vec![1.0, 0.0];
        let emb2 = vec![0.0, 1.0];
        store_embedding(&db, "e1", "episodic_memory", "t1", "v1", &emb1).unwrap();
        store_embedding(&db, "e1", "episodic_memory", "t1", "v2", &emb2).unwrap();
        assert_eq!(embedding_count(&db).unwrap(), 1);
    }

    // Results below `min_similarity` are filtered out.
    #[test]
    fn search_min_similarity_filter() {
        let db = test_db();
        store_embedding(&db, "e1", "episodic_memory", "1", "a", &[1.0, 0.0]).unwrap();
        store_embedding(&db, "e2", "episodic_memory", "2", "b", &[0.0, 1.0]).unwrap();

        let results = search_similar(&db, &[1.0, 0.0], 10, 0.99).unwrap();
        assert_eq!(results.len(), 1);
    }

    // The test helper reflects inserts.
    #[test]
    fn embedding_count_works() {
        let db = test_db();
        assert_eq!(embedding_count(&db).unwrap(), 0);
        store_embedding(&db, "e1", "episodic_memory", "1", "a", &[1.0]).unwrap();
        assert_eq!(embedding_count(&db).unwrap(), 1);
    }

    // A zero-magnitude vector is defined as similarity 0 (no division by zero).
    #[test]
    fn cosine_zero_vector() {
        let a = vec![0.0, 0.0];
        let b = vec![1.0, 0.0];
        assert_eq!(cosine_similarity(&a, &b), 0.0);
    }

    // ── Hybrid search ────────────────────────────────────────

    // With a query text that matches nothing, vector hits alone still produce results.
    #[test]
    fn hybrid_search_vector_only() {
        let db = test_db();
        store_embedding(
            &db,
            "e1",
            "episodic_memory",
            "t1",
            "hello world",
            &[1.0, 0.0, 0.0],
        )
        .unwrap();
        store_embedding(
            &db,
            "e2",
            "episodic_memory",
            "t2",
            "goodbye",
            &[0.0, 1.0, 0.0],
        )
        .unwrap();

        let results =
            hybrid_search(&db, "zzzznonexistent", Some(&[1.0, 0.0, 0.0]), 10, 0.5).unwrap();
        assert!(!results.is_empty());
    }

    // Nothing stored means nothing found.
    #[test]
    fn hybrid_search_empty_db() {
        let db = test_db();
        let results = hybrid_search(&db, "anything", Some(&[1.0, 0.0]), 10, 0.5).unwrap();
        assert!(results.is_empty());
    }

    // The merged result list is capped at `limit`.
    #[test]
    fn hybrid_search_respects_limit() {
        let db = test_db();
        for i in 0..20 {
            store_embedding(
                &db,
                &format!("e{i}"),
                "episodic_memory",
                &format!("t{i}"),
                &format!("entry {i}"),
                &[1.0, 0.0],
            )
            .unwrap();
        }
        let results = hybrid_search(&db, "entry", Some(&[1.0, 0.0]), 5, 0.5).unwrap();
        assert!(results.len() <= 5);
    }

    // Without a query embedding the vector stage must not contribute anything.
    #[test]
    fn hybrid_search_no_embedding() {
        let db = test_db();
        store_embedding(
            &db,
            "e1",
            "episodic_memory",
            "t1",
            "hello world",
            &[1.0, 0.0],
        )
        .unwrap();
        let results = hybrid_search(&db, "hello", None, 10, 0.5).unwrap();
        assert!(results.len() <= 10);
        // Vector hits carry the stored `source_id` ("t1" here); full-text hits
        // are built with an empty one. Only the latter may appear.
        assert!(results.iter().all(|r| r.source_id.is_empty()));
    }

    // Results come back in descending score order.
    #[test]
    fn hybrid_search_sorted_by_similarity() {
        let db = test_db();
        store_embedding(
            &db,
            "e1",
            "episodic_memory",
            "t1",
            "first",
            &[1.0, 0.0, 0.0],
        )
        .unwrap();
        store_embedding(
            &db,
            "e2",
            "episodic_memory",
            "t2",
            "second",
            &[0.5, 0.5, 0.0],
        )
        .unwrap();
        store_embedding(
            &db,
            "e3",
            "episodic_memory",
            "t3",
            "third",
            &[0.0, 0.0, 1.0],
        )
        .unwrap();

        let results = hybrid_search(&db, "query", Some(&[1.0, 0.0, 0.0]), 10, 1.0).unwrap();
        for w in results.windows(2) {
            assert!(w[0].similarity >= w[1].similarity);
        }
    }

    // ── Row decoding ─────────────────────────────────────────

    // A well-formed blob decodes to the original vector.
    #[test]
    fn load_embedding_from_blob() {
        let emb = vec![1.0f32, 2.0, 3.0];
        let blob = embedding_to_blob(&emb);
        let loaded = load_embedding_from_row(Some(blob)).unwrap();
        assert_eq!(loaded, emb);
    }

    // A NULL column yields no embedding.
    #[test]
    fn load_embedding_none_returns_none() {
        let loaded = load_embedding_from_row(None);
        assert!(loaded.is_none());
    }

    // An empty blob yields no embedding.
    #[test]
    fn load_embedding_empty_blob_returns_none() {
        let loaded = load_embedding_from_row(Some(vec![]));
        assert!(loaded.is_none());
    }

    // Rows with a NULL blob are skipped rather than failing the search.
    #[test]
    fn search_similar_skips_row_without_embedding() {
        let db = test_db();
        // Insert a row with no embedding data (NULL blob)
        {
            let conn = db.conn();
            conn.execute(
                "INSERT INTO embeddings (id, source_table, source_id, content_preview, embedding_blob, dimensions) \
                 VALUES ('e-no-emb', 'episodic_memory', 't1', 'no embedding here', NULL, 0)",
                [],
            ).unwrap();
        }
        // Also insert one with a real embedding
        store_embedding(
            &db,
            "e-real",
            "episodic_memory",
            "t2",
            "has embedding",
            &[1.0, 0.0],
        )
        .unwrap();

        let results = search_similar(&db, &[1.0, 0.0], 10, 0.0).unwrap();
        // Should only find the one with a real embedding
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].source_id, "t2");
    }

    // ── Hybrid search with full-text data ────────────────────

    // A full-text match contributes results alongside the vector stage.
    #[test]
    fn hybrid_search_fts_matches() {
        let db = test_db();
        // Store data in FTS-indexed tables (working_memory populates memory_fts)
        crate::memory::store_working(&db, "sess", "note", "quantum computing breakthrough", 5)
            .unwrap();
        store_embedding(
            &db,
            "e1",
            "episodic_memory",
            "t1",
            "classical computing",
            &[0.0, 1.0],
        )
        .unwrap();

        // Search with FTS query that should match the working memory entry
        let results = hybrid_search(&db, "quantum", Some(&[1.0, 0.0]), 10, 0.5).unwrap();
        assert!(
            !results.is_empty(),
            "hybrid search should find FTS match for 'quantum'"
        );
    }

    // Full-text search alone is enough to return results.
    #[test]
    fn hybrid_search_fts_only_no_embedding() {
        let db = test_db();
        crate::memory::store_working(&db, "sess", "note", "unique identifier xyzzy", 5).unwrap();

        // Search with only FTS (no embedding provided), weight doesn't matter much
        let results = hybrid_search(&db, "xyzzy", None, 10, 0.5).unwrap();
        // FTS results get weighted by (1 - hybrid_weight), so they should appear
        assert!(
            !results.is_empty(),
            "hybrid search without embedding should find FTS results"
        );
    }

    // Both stages contribute and the merged list stays sorted.
    #[test]
    fn hybrid_search_combined_scores() {
        let db = test_db();
        crate::memory::store_working(&db, "sess", "note", "machine learning algorithms", 5)
            .unwrap();
        store_embedding(
            &db,
            "e1",
            "episodic_memory",
            "t1",
            "machine learning",
            &[1.0, 0.0, 0.0],
        )
        .unwrap();

        let results = hybrid_search(&db, "machine", Some(&[1.0, 0.0, 0.0]), 10, 0.5).unwrap();
        // Should have results from both FTS and vector search
        assert!(!results.is_empty());
        // Results should be sorted by similarity desc
        for w in results.windows(2) {
            assert!(w[0].similarity >= w[1].similarity);
        }
    }

    // ── Orphan cleanup tests ─────────────────────────────────

    // An embedding pointing at a missing source row is removed; a valid one survives.
    #[test]
    fn cleanup_orphaned_embeddings_removes_dangling() {
        let db = test_db();
        // Create a working_memory entry and its embedding (should survive).
        crate::memory::store_working(&db, "s1", "note", "valid", 5).unwrap();
        let wm_id = {
            let conn = db.conn();
            conn.query_row("SELECT id FROM working_memory LIMIT 1", [], |r| {
                r.get::<_, String>(0)
            })
            .unwrap()
        };
        store_embedding(
            &db,
            "e-valid",
            "working_memory",
            &wm_id,
            "valid",
            &[1.0, 0.0],
        )
        .unwrap();

        // Create an orphaned embedding pointing at a non-existent source.
        store_embedding(
            &db,
            "e-orphan",
            "working_memory",
            "no-such-id",
            "orphan",
            &[0.0, 1.0],
        )
        .unwrap();

        assert_eq!(embedding_count(&db).unwrap(), 2);
        let deleted = cleanup_orphaned_embeddings(&db).unwrap();
        assert_eq!(deleted, 1);
        assert_eq!(embedding_count(&db).unwrap(), 1);
    }

    // Cleanup deletes nothing when every embedding has a live source row.
    #[test]
    fn cleanup_orphaned_embeddings_noop_when_clean() {
        let db = test_db();
        crate::memory::store_semantic(&db, "facts", "k1", "v1", 0.9).unwrap();
        let sem_id = {
            let conn = db.conn();
            conn.query_row("SELECT id FROM semantic_memory LIMIT 1", [], |r| {
                r.get::<_, String>(0)
            })
            .unwrap()
        };
        store_embedding(&db, "e1", "semantic_memory", &sem_id, "valid", &[1.0, 0.0]).unwrap();

        let deleted = cleanup_orphaned_embeddings(&db).unwrap();
        assert_eq!(deleted, 0);
    }
}