//! `bones_search/semantic/embed.rs`
//!
//! Embedding pipeline and semantic index synchronization: computes content
//! hashes, runs embedding inference, and keeps the `item_embeddings` table
//! in step with the event projection.
1use crate::semantic::model::SemanticModel;
2use anyhow::{Context, Result, bail};
3use bones_core::model::item::WorkItemFields;
4use rusqlite::{Connection, OptionalExtension, params};
5use sha2::{Digest, Sha256};
6use std::collections::{HashMap, HashSet};
7
// Row id of the singleton cursor row in `semantic_meta` (schema enforces id = 1).
const SEMANTIC_META_ID: i64 = 1;

/// Summary of semantic index synchronization work.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct SyncStats {
    // Number of items whose embeddings were (re)computed and stored.
    pub embedded: usize,
    // Number of stale embedding rows deleted for items no longer live.
    pub removed: usize,
}
16
/// Manages embedding computation and semantic index storage.
pub struct EmbeddingPipeline<'a> {
    // Model used for embedding inference.
    model: &'a SemanticModel,
    // Database holding the `item_embeddings` table.
    db: &'a Connection,
    // Expected embedding vector length, cached from `model.dimensions()`.
    embedding_dim: usize,
    // Backend identifier mixed into content hashes (see `content_hash_hex`).
    backend_id: &'static str,
}
24
25impl<'a> EmbeddingPipeline<'a> {
26    /// Construct a pipeline and ensure semantic tables exist.
27    ///
28    /// # Errors
29    ///
30    /// Returns an error if the database schema creation fails.
31    pub fn new(model: &'a SemanticModel, db: &'a Connection) -> Result<Self> {
32        ensure_embedding_schema(db)?;
33        Ok(Self {
34            model,
35            db,
36            embedding_dim: model.dimensions(),
37            backend_id: model.backend_id(),
38        })
39    }
40
41    /// Embed a single item and upsert its vector if content changed.
42    ///
43    /// # Errors
44    ///
45    /// Returns an error if inference or the database upsert fails.
46    pub fn embed_item(&self, item: &WorkItemFields) -> Result<bool> {
47        let content = item_content(item);
48        let content_hash = content_hash_hex(&content, self.backend_id);
49
50        if has_same_hash(self.db, &item.id, &content_hash)? {
51            return Ok(false);
52        }
53
54        let embedding = self
55            .model
56            .embed(&content)
57            .with_context(|| format!("embedding inference failed for item {}", item.id))?;
58
59        upsert_embedding(
60            self.db,
61            &item.id,
62            &content_hash,
63            &embedding,
64            self.embedding_dim,
65        )
66    }
67
68    /// Batch-embed multiple items.
69    ///
70    /// # Errors
71    ///
72    /// Returns an error if batch inference or any database upsert fails.
73    pub fn embed_all(&self, items: &[WorkItemFields]) -> Result<usize> {
74        let mut pending = Vec::new();
75
76        for item in items {
77            let content = item_content(item);
78            let content_hash = content_hash_hex(&content, self.backend_id);
79            if has_same_hash(self.db, &item.id, &content_hash)? {
80                continue;
81            }
82            pending.push((item.id.clone(), content_hash, content));
83        }
84
85        if pending.is_empty() {
86            return Ok(0);
87        }
88
89        let texts: Vec<&str> = pending.iter().map(|(_, _, text)| text.as_str()).collect();
90        let embeddings = self
91            .model
92            .embed_batch(&texts)
93            .context("batch embedding inference failed")?;
94
95        if embeddings.len() != pending.len() {
96            bail!(
97                "embedding batch length mismatch: expected {}, got {}",
98                pending.len(),
99                embeddings.len()
100            );
101        }
102
103        for ((item_id, hash, _), embedding) in pending.iter().zip(embeddings) {
104            upsert_embedding(self.db, item_id, hash, &embedding, self.embedding_dim)?;
105        }
106
107        Ok(pending.len())
108    }
109}
110
111/// Ensure semantic embeddings are synchronized with the current projection.
112///
113/// This is safe to call before every semantic search request: when no new
114/// events were projected, it returns quickly without recomputing embeddings.
115///
116/// # Errors
117///
118/// Returns an error if the database query or embedding inference fails.
119pub fn sync_projection_embeddings(db: &Connection, model: &SemanticModel) -> Result<SyncStats> {
120    ensure_embedding_schema(db)?;
121
122    let projection_cursor = projection_cursor(db)?;
123    let indexed_cursor = semantic_cursor(db)?;
124    let active_items = active_item_count(db)?;
125    let embedded_items = embedding_count(db)?;
126    if should_skip_sync(
127        &indexed_cursor,
128        &projection_cursor,
129        active_items,
130        embedded_items,
131    ) {
132        return Ok(SyncStats::default());
133    }
134
135    let backend_id = model.backend_id();
136    let embedding_dim = model.dimensions();
137
138    let items = load_items_for_embedding(db, backend_id)?;
139    let live_ids: HashSet<String> = items.iter().map(|(id, _, _)| id.clone()).collect();
140    let existing_hashes = load_existing_hashes(db)?;
141
142    let mut pending = Vec::new();
143    for (item_id, content_hash, content) in &items {
144        if existing_hashes.get(item_id) == Some(content_hash) {
145            continue;
146        }
147        pending.push((item_id.clone(), content_hash.clone(), content.clone()));
148    }
149
150    let embedded = if pending.is_empty() {
151        0
152    } else {
153        let texts: Vec<&str> = pending
154            .iter()
155            .map(|(_, _, content)| content.as_str())
156            .collect();
157        let embeddings = model
158            .embed_batch(&texts)
159            .context("semantic index sync failed during embedding inference")?;
160
161        if embeddings.len() != pending.len() {
162            bail!(
163                "semantic index sync embedding count mismatch: expected {}, got {}",
164                pending.len(),
165                embeddings.len()
166            );
167        }
168
169        for ((item_id, content_hash, _), embedding) in pending.iter().zip(embeddings.iter()) {
170            upsert_embedding(db, item_id, content_hash, embedding, embedding_dim)?;
171        }
172        pending.len()
173    };
174
175    let removed = remove_stale_embeddings(db, &live_ids)?;
176    set_semantic_cursor(db, projection_cursor.0, projection_cursor.1.as_deref())?;
177
178    Ok(SyncStats { embedded, removed })
179}
180
181/// Ensure semantic index tables exist without running embedding inference.
182///
183/// This is useful for maintenance flows that want predictable schema state
184/// (for example after a projection rebuild) while deferring embedding work
185/// until a semantic query is actually executed.
186///
187/// # Errors
188///
189/// Returns an error if the database schema creation fails.
190pub fn ensure_semantic_index_schema(db: &Connection) -> Result<()> {
191    ensure_embedding_schema(db)
192}
193
// Create the semantic index tables and seed the singleton cursor row.
// Idempotent: `IF NOT EXISTS` / `INSERT OR IGNORE` make repeated calls safe.
fn ensure_embedding_schema(db: &Connection) -> Result<()> {
    db.execute_batch(
        "
        CREATE TABLE IF NOT EXISTS item_embeddings (
            item_id TEXT PRIMARY KEY,
            content_hash TEXT NOT NULL,
            embedding_json TEXT NOT NULL
        );

        CREATE TABLE IF NOT EXISTS semantic_meta (
            id INTEGER PRIMARY KEY CHECK (id = 1),
            last_event_offset INTEGER NOT NULL DEFAULT 0,
            last_event_hash TEXT
        );

        INSERT OR IGNORE INTO semantic_meta (id, last_event_offset, last_event_hash)
        VALUES (1, 0, NULL);
        ",
    )
    .context("failed to create semantic index tables")?;

    Ok(())
}
217
218fn projection_cursor(db: &Connection) -> Result<(i64, Option<String>)> {
219    db.query_row(
220        "SELECT last_event_offset, last_event_hash FROM projection_meta WHERE id = 1",
221        [],
222        |row| Ok((row.get::<_, i64>(0)?, row.get::<_, Option<String>>(1)?)),
223    )
224    .context("failed to read projection cursor for semantic sync")
225}
226
227fn semantic_cursor(db: &Connection) -> Result<(i64, Option<String>)> {
228    db.query_row(
229        "SELECT last_event_offset, last_event_hash FROM semantic_meta WHERE id = ?1",
230        params![SEMANTIC_META_ID],
231        |row| Ok((row.get::<_, i64>(0)?, row.get::<_, Option<String>>(1)?)),
232    )
233    .context("failed to read semantic index cursor")
234}
235
236fn active_item_count(db: &Connection) -> Result<usize> {
237    let count: i64 = db
238        .query_row(
239            "SELECT COUNT(*) FROM items WHERE is_deleted = 0",
240            [],
241            |row| row.get(0),
242        )
243        .context("failed to count active items for semantic sync")?;
244    Ok(usize::try_from(count).unwrap_or(0))
245}
246
247fn embedding_count(db: &Connection) -> Result<usize> {
248    let count: i64 = db
249        .query_row("SELECT COUNT(*) FROM item_embeddings", [], |row| row.get(0))
250        .context("failed to count semantic embeddings")?;
251    Ok(usize::try_from(count).unwrap_or(0))
252}
253
// Decide whether embedding sync can be skipped entirely.
//
// Sync is skippable only when the semantic cursor matches the projection
// cursor AND the embedding count equals the active-item count (the count
// check guards against a partially populated index at the same cursor).
fn should_skip_sync(
    indexed_cursor: &(i64, Option<String>),
    projection_cursor: &(i64, Option<String>),
    active_items: usize,
    embedded_items: usize,
) -> bool {
    let cursors_match = indexed_cursor == projection_cursor;
    let counts_match = active_items == embedded_items;
    cursors_match && counts_match
}
262
263fn set_semantic_cursor(db: &Connection, offset: i64, hash: Option<&str>) -> Result<()> {
264    db.execute(
265        "UPDATE semantic_meta
266         SET last_event_offset = ?1, last_event_hash = ?2
267         WHERE id = ?3",
268        params![offset, hash, SEMANTIC_META_ID],
269    )
270    .context("failed to update semantic index cursor")?;
271
272    Ok(())
273}
274
275fn load_items_for_embedding(
276    db: &Connection,
277    backend_id: &str,
278) -> Result<Vec<(String, String, String)>> {
279    let mut stmt = db
280        .prepare(
281            "SELECT item_id, title, description
282             FROM items
283             WHERE is_deleted = 0",
284        )
285        .context("failed to prepare item query for semantic sync")?;
286
287    let rows = stmt
288        .query_map([], |row| {
289            let item_id = row.get::<_, String>(0)?;
290            let title = row.get::<_, String>(1)?;
291            let description = row.get::<_, Option<String>>(2)?;
292            Ok((item_id, title, description))
293        })
294        .context("failed to execute item query for semantic sync")?;
295
296    let mut items = Vec::new();
297    for row in rows {
298        let (item_id, title, description) =
299            row.context("failed to read item row for semantic sync")?;
300        let content = content_from_title_description(&title, description.as_deref());
301        let content_hash = content_hash_hex(&content, backend_id);
302        items.push((item_id, content_hash, content));
303    }
304
305    Ok(items)
306}
307
308fn load_existing_hashes(db: &Connection) -> Result<HashMap<String, String>> {
309    let mut stmt = db
310        .prepare("SELECT item_id, content_hash FROM item_embeddings")
311        .context("failed to prepare semantic hash query")?;
312    let rows = stmt
313        .query_map([], |row| {
314            Ok((row.get::<_, String>(0)?, row.get::<_, String>(1)?))
315        })
316        .context("failed to query semantic hash table")?;
317
318    let mut out = HashMap::new();
319    for row in rows {
320        let (item_id, hash) = row.context("failed to read semantic hash row")?;
321        out.insert(item_id, hash);
322    }
323    Ok(out)
324}
325
326fn remove_stale_embeddings(db: &Connection, live_ids: &HashSet<String>) -> Result<usize> {
327    let mut stmt = db
328        .prepare("SELECT item_id FROM item_embeddings")
329        .context("failed to prepare stale semantic row query")?;
330    let rows = stmt
331        .query_map([], |row| row.get::<_, String>(0))
332        .context("failed to query semantic rows for stale cleanup")?;
333
334    let mut stale = Vec::new();
335    for row in rows {
336        let item_id = row.context("failed to read semantic row id")?;
337        if !live_ids.contains(&item_id) {
338            stale.push(item_id);
339        }
340    }
341
342    for item_id in &stale {
343        db.execute(
344            "DELETE FROM item_embeddings WHERE item_id = ?1",
345            params![item_id],
346        )
347        .with_context(|| format!("failed to delete stale semantic row for {item_id}"))?;
348    }
349
350    Ok(stale.len())
351}
352
353fn has_same_hash(db: &Connection, item_id: &str, content_hash: &str) -> Result<bool> {
354    let existing = db
355        .query_row(
356            "SELECT content_hash FROM item_embeddings WHERE item_id = ?1",
357            params![item_id],
358            |row| row.get::<_, String>(0),
359        )
360        .optional()
361        .with_context(|| format!("failed to query content hash for item {item_id}"))?;
362
363    Ok(existing.as_deref() == Some(content_hash))
364}
365
366fn upsert_embedding(
367    db: &Connection,
368    item_id: &str,
369    content_hash: &str,
370    embedding: &[f32],
371    expected_dim: usize,
372) -> Result<bool> {
373    if embedding.len() != expected_dim {
374        bail!(
375            "invalid embedding dimension for item {item_id}: expected {expected_dim}, got {}",
376            embedding.len()
377        );
378    }
379
380    let existing_hash = db
381        .query_row(
382            "SELECT content_hash FROM item_embeddings WHERE item_id = ?1",
383            params![item_id],
384            |row| row.get::<_, String>(0),
385        )
386        .optional()
387        .with_context(|| format!("failed to lookup semantic row for item {item_id}"))?;
388
389    if existing_hash.as_deref() == Some(content_hash) {
390        return Ok(false);
391    }
392
393    let encoded_vector = encode_embedding_json(embedding);
394    db.execute(
395        "INSERT INTO item_embeddings (item_id, content_hash, embedding_json)
396         VALUES (?1, ?2, ?3)
397         ON CONFLICT(item_id)
398         DO UPDATE SET content_hash = excluded.content_hash,
399                       embedding_json = excluded.embedding_json",
400        params![item_id, content_hash, encoded_vector],
401    )
402    .with_context(|| format!("failed to upsert semantic embedding for item {item_id}"))?;
403
404    Ok(true)
405}
406
407fn item_content(item: &WorkItemFields) -> String {
408    content_from_title_description(&item.title, item.description.as_deref())
409}
410
// Join a trimmed title with a trimmed, non-empty description; otherwise
// return just the trimmed title.
fn content_from_title_description(title: &str, description: Option<&str>) -> String {
    let title = title.trim();
    match description.map(str::trim).filter(|d| !d.is_empty()) {
        Some(desc) => format!("{title} {desc}"),
        None => title.to_owned(),
    }
}
419
420fn content_hash_hex(content: &str, backend_id: &str) -> String {
421    let mut hasher = Sha256::new();
422    hasher.update(backend_id.as_bytes());
423    hasher.update(b":");
424    hasher.update(content.as_bytes());
425    format!("{:x}", hasher.finalize())
426}
427
// Serialize a vector as a compact JSON array literal, e.g. "[0.5,-2,3.25]".
//
// NOTE(review): non-finite floats (NaN/inf) would render as non-JSON tokens;
// this assumes embedding values are always finite — confirm upstream.
fn encode_embedding_json(embedding: &[f32]) -> String {
    let parts: Vec<String> = embedding.iter().map(|value| value.to_string()).collect();
    format!("[{}]", parts.join(","))
}
439
#[cfg(test)]
mod tests {
    use super::*;

    // Create the minimal `items` / `projection_meta` tables that the
    // production queries expect, then the semantic tables themselves.
    fn seed_schema_for_unit_tests(db: &Connection) -> Result<()> {
        db.execute_batch(
            "
            CREATE TABLE items (
                item_id TEXT PRIMARY KEY,
                title TEXT NOT NULL,
                description TEXT,
                is_deleted INTEGER NOT NULL DEFAULT 0,
                updated_at_us INTEGER NOT NULL DEFAULT 0
            );

            CREATE TABLE projection_meta (
                id INTEGER PRIMARY KEY,
                last_event_offset INTEGER NOT NULL,
                last_event_hash TEXT
            );

            INSERT INTO projection_meta (id, last_event_offset, last_event_hash)
            VALUES (1, 0, NULL);
            ",
        )?;

        ensure_embedding_schema(db)?;
        Ok(())
    }

    // Vector length used by the unit-test fixtures.
    const TEST_DIM: usize = 384;

    // Constant-valued vector of the fixture dimension.
    fn sample_embedding() -> Vec<f32> {
        vec![0.25_f32; TEST_DIM]
    }

    #[test]
    fn content_hash_changes_with_content() {
        let left = content_hash_hex("alpha", "test");
        let right = content_hash_hex("beta", "test");
        assert_ne!(left, right);
    }

    // Backend id is part of the hash input, so switching backends must
    // produce different hashes for identical content.
    #[test]
    fn content_hash_changes_with_backend() {
        let left = content_hash_hex("same", "ort");
        let right = content_hash_hex("same", "model2vec");
        assert_ne!(left, right);
    }

    #[test]
    fn item_content_concatenates_title_and_description() {
        let item = WorkItemFields {
            title: "Title".to_string(),
            description: Some("Description".to_string()),
            ..WorkItemFields::default()
        };

        assert_eq!(item_content(&item), "Title Description");
    }

    // A second upsert with the same hash must be a no-op returning false.
    #[test]
    fn upsert_embedding_skips_when_hash_matches() -> Result<()> {
        let db = Connection::open_in_memory()?;
        seed_schema_for_unit_tests(&db)?;

        let item_id = "bn-abc";
        let hash = content_hash_hex("same-content", "test");
        let embedding = sample_embedding();

        let inserted = upsert_embedding(&db, item_id, &hash, &embedding, TEST_DIM)?;
        let skipped = upsert_embedding(&db, item_id, &hash, &embedding, TEST_DIM)?;

        assert!(inserted);
        assert!(!skipped);

        let count: i64 =
            db.query_row("SELECT COUNT(*) FROM item_embeddings", [], |row| row.get(0))?;
        assert_eq!(count, 1);

        Ok(())
    }

    // A changed hash must trigger an update (ON CONFLICT path), not a skip.
    #[test]
    fn upsert_embedding_updates_hash_when_content_changes() -> Result<()> {
        let db = Connection::open_in_memory()?;
        seed_schema_for_unit_tests(&db)?;

        let item_id = "bn-def";
        let first_hash = content_hash_hex("old", "test");
        let second_hash = content_hash_hex("new", "test");

        upsert_embedding(&db, item_id, &first_hash, &sample_embedding(), TEST_DIM)?;
        let written = upsert_embedding(&db, item_id, &second_hash, &sample_embedding(), TEST_DIM)?;

        assert!(written);

        let stored_hash: String = db.query_row(
            "SELECT content_hash FROM item_embeddings WHERE item_id = ?1",
            params![item_id],
            |row| row.get(0),
        )?;
        assert_eq!(stored_hash, second_hash);

        Ok(())
    }

    // With matching cursors and zero items/embeddings, sync should take the
    // fast path and report no work done.
    #[test]
    fn sync_projection_embeddings_short_circuits_when_cursor_matches() -> Result<()> {
        let db = Connection::open_in_memory()?;
        seed_schema_for_unit_tests(&db)?;

        db.execute(
            "UPDATE semantic_meta SET last_event_offset = 7, last_event_hash = 'h7' WHERE id = 1",
            [],
        )?;
        db.execute(
            "UPDATE projection_meta SET last_event_offset = 7, last_event_hash = 'h7' WHERE id = 1",
            [],
        )?;

        // NOTE(review): silently skips the assertion when the model is not
        // loadable in the test environment — confirm this is intentional.
        let model = SemanticModel::load();
        if let Ok(model) = model {
            let stats = sync_projection_embeddings(&db, &model)?;
            assert_eq!(stats, SyncStats::default());
        }

        Ok(())
    }

    // The skip decision requires both cursor equality and count equality.
    #[test]
    fn should_skip_sync_requires_cardinality_match() {
        let cursor = (7, Some("h7".to_string()));
        assert!(should_skip_sync(&cursor, &cursor, 0, 0));
        assert!(should_skip_sync(&cursor, &cursor, 3, 3));
        assert!(!should_skip_sync(&cursor, &cursor, 3, 0));
        assert!(!should_skip_sync(&cursor, &cursor, 0, 2));
        assert!(!should_skip_sync(
            &cursor,
            &(8, Some("h8".to_string())),
            3,
            3
        ));
    }
}