//! bones_search/semantic/embed.rs — embedding pipeline for the semantic search index.

1use crate::semantic::model::SemanticModel;
2use anyhow::{Context, Result, bail};
3use bones_core::model::item::WorkItemFields;
4use rusqlite::{Connection, OptionalExtension, params};
5use sha2::{Digest, Sha256};
6use std::collections::{HashMap, HashSet};
7
// Output dimensionality of the sentence-embedding model; vectors of any other
// length are rejected by `upsert_embedding`.
const EMBEDDING_DIM: usize = 384;
// `semantic_meta` is a single-row table; this is its fixed primary-key value.
const SEMANTIC_META_ID: i64 = 1;
10
/// Summary of semantic index synchronization work.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct SyncStats {
    /// Number of items whose embeddings were (re)computed and stored.
    pub embedded: usize,
    /// Number of stale embedding rows deleted for items no longer live.
    pub removed: usize,
}
17
/// Manages embedding computation and semantic index storage.
pub struct EmbeddingPipeline<'a> {
    /// Inference model used to turn item text into vectors.
    model: &'a SemanticModel,
    /// Database connection holding the `item_embeddings` table.
    db: &'a Connection,
}
23
24impl<'a> EmbeddingPipeline<'a> {
25    /// Construct a pipeline and ensure semantic tables exist.
26    ///
27    /// # Errors
28    ///
29    /// Returns an error if the database schema creation fails.
30    pub fn new(model: &'a SemanticModel, db: &'a Connection) -> Result<Self> {
31        ensure_embedding_schema(db)?;
32        Ok(Self { model, db })
33    }
34
35    /// Embed a single item and upsert its vector if content changed.
36    ///
37    /// # Errors
38    ///
39    /// Returns an error if inference or the database upsert fails.
40    pub fn embed_item(&self, item: &WorkItemFields) -> Result<bool> {
41        let content = item_content(item);
42        let content_hash = content_hash_hex(&content);
43
44        if has_same_hash(self.db, &item.id, &content_hash)? {
45            return Ok(false);
46        }
47
48        let embedding = self
49            .model
50            .embed(&content)
51            .with_context(|| format!("embedding inference failed for item {}", item.id))?;
52
53        upsert_embedding(self.db, &item.id, &content_hash, &embedding)
54    }
55
56    /// Batch-embed multiple items.
57    ///
58    /// # Errors
59    ///
60    /// Returns an error if batch inference or any database upsert fails.
61    pub fn embed_all(&self, items: &[WorkItemFields]) -> Result<usize> {
62        let mut pending = Vec::new();
63
64        for item in items {
65            let content = item_content(item);
66            let content_hash = content_hash_hex(&content);
67            if has_same_hash(self.db, &item.id, &content_hash)? {
68                continue;
69            }
70            pending.push((item.id.clone(), content_hash, content));
71        }
72
73        if pending.is_empty() {
74            return Ok(0);
75        }
76
77        let texts: Vec<&str> = pending.iter().map(|(_, _, text)| text.as_str()).collect();
78        let embeddings = self
79            .model
80            .embed_batch(&texts)
81            .context("batch embedding inference failed")?;
82
83        if embeddings.len() != pending.len() {
84            bail!(
85                "embedding batch length mismatch: expected {}, got {}",
86                pending.len(),
87                embeddings.len()
88            );
89        }
90
91        for ((item_id, hash, _), embedding) in pending.iter().zip(embeddings) {
92            upsert_embedding(self.db, item_id, hash, &embedding)?;
93        }
94
95        Ok(pending.len())
96    }
97}
98
99/// Ensure semantic embeddings are synchronized with the current projection.
100///
101/// This is safe to call before every semantic search request: when no new
102/// events were projected, it returns quickly without recomputing embeddings.
103///
104/// # Errors
105///
106/// Returns an error if the database query or embedding inference fails.
107pub fn sync_projection_embeddings(db: &Connection, model: &SemanticModel) -> Result<SyncStats> {
108    ensure_embedding_schema(db)?;
109
110    let projection_cursor = projection_cursor(db)?;
111    let indexed_cursor = semantic_cursor(db)?;
112    let active_items = active_item_count(db)?;
113    let embedded_items = embedding_count(db)?;
114    if should_skip_sync(
115        &indexed_cursor,
116        &projection_cursor,
117        active_items,
118        embedded_items,
119    ) {
120        return Ok(SyncStats::default());
121    }
122
123    let items = load_items_for_embedding(db)?;
124    let live_ids: HashSet<String> = items.iter().map(|(id, _, _)| id.clone()).collect();
125    let existing_hashes = load_existing_hashes(db)?;
126
127    let mut pending = Vec::new();
128    for (item_id, content_hash, content) in &items {
129        if existing_hashes.get(item_id) == Some(content_hash) {
130            continue;
131        }
132        pending.push((item_id.clone(), content_hash.clone(), content.clone()));
133    }
134
135    let embedded = if pending.is_empty() {
136        0
137    } else {
138        let texts: Vec<&str> = pending
139            .iter()
140            .map(|(_, _, content)| content.as_str())
141            .collect();
142        let embeddings = model
143            .embed_batch(&texts)
144            .context("semantic index sync failed during embedding inference")?;
145
146        if embeddings.len() != pending.len() {
147            bail!(
148                "semantic index sync embedding count mismatch: expected {}, got {}",
149                pending.len(),
150                embeddings.len()
151            );
152        }
153
154        for ((item_id, content_hash, _), embedding) in pending.iter().zip(embeddings.iter()) {
155            upsert_embedding(db, item_id, content_hash, embedding)?;
156        }
157        pending.len()
158    };
159
160    let removed = remove_stale_embeddings(db, &live_ids)?;
161    set_semantic_cursor(db, projection_cursor.0, projection_cursor.1.as_deref())?;
162
163    Ok(SyncStats { embedded, removed })
164}
165
/// Ensure semantic index tables exist without running embedding inference.
///
/// This is useful for maintenance flows that want predictable schema state
/// (for example after a projection rebuild) while deferring embedding work
/// until a semantic query is actually executed.
///
/// # Errors
///
/// Returns an error if the database schema creation fails.
pub fn ensure_semantic_index_schema(db: &Connection) -> Result<()> {
    // Thin public wrapper over the private schema helper used internally.
    ensure_embedding_schema(db)
}
178
179fn ensure_embedding_schema(db: &Connection) -> Result<()> {
180    db.execute_batch(
181        "
182        CREATE TABLE IF NOT EXISTS item_embeddings (
183            item_id TEXT PRIMARY KEY,
184            content_hash TEXT NOT NULL,
185            embedding_json TEXT NOT NULL
186        );
187
188        CREATE TABLE IF NOT EXISTS semantic_meta (
189            id INTEGER PRIMARY KEY CHECK (id = 1),
190            last_event_offset INTEGER NOT NULL DEFAULT 0,
191            last_event_hash TEXT
192        );
193
194        INSERT OR IGNORE INTO semantic_meta (id, last_event_offset, last_event_hash)
195        VALUES (1, 0, NULL);
196        ",
197    )
198    .context("failed to create semantic index tables")?;
199
200    Ok(())
201}
202
203fn projection_cursor(db: &Connection) -> Result<(i64, Option<String>)> {
204    db.query_row(
205        "SELECT last_event_offset, last_event_hash FROM projection_meta WHERE id = 1",
206        [],
207        |row| Ok((row.get::<_, i64>(0)?, row.get::<_, Option<String>>(1)?)),
208    )
209    .context("failed to read projection cursor for semantic sync")
210}
211
212fn semantic_cursor(db: &Connection) -> Result<(i64, Option<String>)> {
213    db.query_row(
214        "SELECT last_event_offset, last_event_hash FROM semantic_meta WHERE id = ?1",
215        params![SEMANTIC_META_ID],
216        |row| Ok((row.get::<_, i64>(0)?, row.get::<_, Option<String>>(1)?)),
217    )
218    .context("failed to read semantic index cursor")
219}
220
221fn active_item_count(db: &Connection) -> Result<usize> {
222    let count: i64 = db
223        .query_row(
224            "SELECT COUNT(*) FROM items WHERE is_deleted = 0",
225            [],
226            |row| row.get(0),
227        )
228        .context("failed to count active items for semantic sync")?;
229    Ok(usize::try_from(count).unwrap_or(0))
230}
231
232fn embedding_count(db: &Connection) -> Result<usize> {
233    let count: i64 = db
234        .query_row("SELECT COUNT(*) FROM item_embeddings", [], |row| row.get(0))
235        .context("failed to count semantic embeddings")?;
236    Ok(usize::try_from(count).unwrap_or(0))
237}
238
/// Decide whether embedding sync can be skipped entirely: the event cursor
/// must match AND the live item count must equal the embedded row count
/// (the cardinality check catches missed deletions).
fn should_skip_sync(
    indexed_cursor: &(i64, Option<String>),
    projection_cursor: &(i64, Option<String>),
    active_items: usize,
    embedded_items: usize,
) -> bool {
    let cursor_current = indexed_cursor == projection_cursor;
    let counts_agree = active_items == embedded_items;
    cursor_current && counts_agree
}
247
248fn set_semantic_cursor(db: &Connection, offset: i64, hash: Option<&str>) -> Result<()> {
249    db.execute(
250        "UPDATE semantic_meta
251         SET last_event_offset = ?1, last_event_hash = ?2
252         WHERE id = ?3",
253        params![offset, hash, SEMANTIC_META_ID],
254    )
255    .context("failed to update semantic index cursor")?;
256
257    Ok(())
258}
259
260fn load_items_for_embedding(db: &Connection) -> Result<Vec<(String, String, String)>> {
261    let mut stmt = db
262        .prepare(
263            "SELECT item_id, title, description
264             FROM items
265             WHERE is_deleted = 0",
266        )
267        .context("failed to prepare item query for semantic sync")?;
268
269    let rows = stmt
270        .query_map([], |row| {
271            let item_id = row.get::<_, String>(0)?;
272            let title = row.get::<_, String>(1)?;
273            let description = row.get::<_, Option<String>>(2)?;
274            Ok((item_id, title, description))
275        })
276        .context("failed to execute item query for semantic sync")?;
277
278    let mut items = Vec::new();
279    for row in rows {
280        let (item_id, title, description) =
281            row.context("failed to read item row for semantic sync")?;
282        let content = content_from_title_description(&title, description.as_deref());
283        let content_hash = content_hash_hex(&content);
284        items.push((item_id, content_hash, content));
285    }
286
287    Ok(items)
288}
289
290fn load_existing_hashes(db: &Connection) -> Result<HashMap<String, String>> {
291    let mut stmt = db
292        .prepare("SELECT item_id, content_hash FROM item_embeddings")
293        .context("failed to prepare semantic hash query")?;
294    let rows = stmt
295        .query_map([], |row| {
296            Ok((row.get::<_, String>(0)?, row.get::<_, String>(1)?))
297        })
298        .context("failed to query semantic hash table")?;
299
300    let mut out = HashMap::new();
301    for row in rows {
302        let (item_id, hash) = row.context("failed to read semantic hash row")?;
303        out.insert(item_id, hash);
304    }
305    Ok(out)
306}
307
308fn remove_stale_embeddings(db: &Connection, live_ids: &HashSet<String>) -> Result<usize> {
309    let mut stmt = db
310        .prepare("SELECT item_id FROM item_embeddings")
311        .context("failed to prepare stale semantic row query")?;
312    let rows = stmt
313        .query_map([], |row| row.get::<_, String>(0))
314        .context("failed to query semantic rows for stale cleanup")?;
315
316    let mut stale = Vec::new();
317    for row in rows {
318        let item_id = row.context("failed to read semantic row id")?;
319        if !live_ids.contains(&item_id) {
320            stale.push(item_id);
321        }
322    }
323
324    for item_id in &stale {
325        db.execute(
326            "DELETE FROM item_embeddings WHERE item_id = ?1",
327            params![item_id],
328        )
329        .with_context(|| format!("failed to delete stale semantic row for {item_id}"))?;
330    }
331
332    Ok(stale.len())
333}
334
335fn has_same_hash(db: &Connection, item_id: &str, content_hash: &str) -> Result<bool> {
336    let existing = db
337        .query_row(
338            "SELECT content_hash FROM item_embeddings WHERE item_id = ?1",
339            params![item_id],
340            |row| row.get::<_, String>(0),
341        )
342        .optional()
343        .with_context(|| format!("failed to query content hash for item {item_id}"))?;
344
345    Ok(existing.as_deref() == Some(content_hash))
346}
347
348fn upsert_embedding(
349    db: &Connection,
350    item_id: &str,
351    content_hash: &str,
352    embedding: &[f32],
353) -> Result<bool> {
354    if embedding.len() != EMBEDDING_DIM {
355        bail!(
356            "invalid embedding dimension for item {item_id}: expected {EMBEDDING_DIM}, got {}",
357            embedding.len()
358        );
359    }
360
361    let existing_hash = db
362        .query_row(
363            "SELECT content_hash FROM item_embeddings WHERE item_id = ?1",
364            params![item_id],
365            |row| row.get::<_, String>(0),
366        )
367        .optional()
368        .with_context(|| format!("failed to lookup semantic row for item {item_id}"))?;
369
370    if existing_hash.as_deref() == Some(content_hash) {
371        return Ok(false);
372    }
373
374    let encoded_vector = encode_embedding_json(embedding);
375    db.execute(
376        "INSERT INTO item_embeddings (item_id, content_hash, embedding_json)
377         VALUES (?1, ?2, ?3)
378         ON CONFLICT(item_id)
379         DO UPDATE SET content_hash = excluded.content_hash,
380                       embedding_json = excluded.embedding_json",
381        params![item_id, content_hash, encoded_vector],
382    )
383    .with_context(|| format!("failed to upsert semantic embedding for item {item_id}"))?;
384
385    Ok(true)
386}
387
/// Build the canonical embedding text for an item: trimmed title, optionally
/// followed by the trimmed description.
fn item_content(item: &WorkItemFields) -> String {
    content_from_title_description(&item.title, item.description.as_deref())
}
391
/// Join a trimmed title with a trimmed, non-empty description; a missing or
/// whitespace-only description yields the title alone.
fn content_from_title_description(title: &str, description: Option<&str>) -> String {
    let title = title.trim();
    match description.map(str::trim).filter(|desc| !desc.is_empty()) {
        Some(desc) => format!("{title} {desc}"),
        None => title.to_owned(),
    }
}
400
401fn content_hash_hex(content: &str) -> String {
402    let mut hasher = Sha256::new();
403    hasher.update(content.as_bytes());
404    format!("{:x}", hasher.finalize())
405}
406
/// Encode an embedding vector as a compact JSON array string.
///
/// The output is deterministic and matches `f32::to_string` for each
/// component (e.g. `[0.25,1,-0.5]`).
fn encode_embedding_json(embedding: &[f32]) -> String {
    // Pre-reserve a rough upper bound (~10 chars per component plus
    // separators and brackets) to avoid repeated reallocation while growing.
    let mut encoded = String::with_capacity(embedding.len() * 10 + 2);
    encoded.push('[');
    for (idx, value) in embedding.iter().enumerate() {
        if idx != 0 {
            encoded.push(',');
        }
        // NOTE(review): f32::to_string renders NaN/infinity as "NaN"/"inf",
        // which is not valid JSON; model output is assumed finite — confirm.
        encoded.push_str(&value.to_string());
    }
    encoded.push(']');
    encoded
}
418
#[cfg(test)]
mod tests {
    use super::*;

    /// Create the minimal `items` / `projection_meta` schema plus the
    /// semantic tables so unit tests can run against an in-memory database.
    fn seed_schema_for_unit_tests(db: &Connection) -> Result<()> {
        db.execute_batch(
            "
            CREATE TABLE items (
                item_id TEXT PRIMARY KEY,
                title TEXT NOT NULL,
                description TEXT,
                is_deleted INTEGER NOT NULL DEFAULT 0,
                updated_at_us INTEGER NOT NULL DEFAULT 0
            );

            CREATE TABLE projection_meta (
                id INTEGER PRIMARY KEY,
                last_event_offset INTEGER NOT NULL,
                last_event_hash TEXT
            );

            INSERT INTO projection_meta (id, last_event_offset, last_event_hash)
            VALUES (1, 0, NULL);
            ",
        )?;

        ensure_embedding_schema(db)?;
        Ok(())
    }

    /// A constant vector with the dimensionality `upsert_embedding` requires.
    fn sample_embedding() -> Vec<f32> {
        vec![0.25_f32; EMBEDDING_DIM]
    }

    // Different inputs must produce different SHA-256 hex digests.
    #[test]
    fn content_hash_changes_with_content() {
        let left = content_hash_hex("alpha");
        let right = content_hash_hex("beta");
        assert_ne!(left, right);
    }

    // Title and description are joined with a single space.
    #[test]
    fn item_content_concatenates_title_and_description() {
        let item = WorkItemFields {
            title: "Title".to_string(),
            description: Some("Description".to_string()),
            ..WorkItemFields::default()
        };

        assert_eq!(item_content(&item), "Title Description");
    }

    // A second upsert with the same hash must be a no-op that returns false.
    #[test]
    fn upsert_embedding_skips_when_hash_matches() -> Result<()> {
        let db = Connection::open_in_memory()?;
        seed_schema_for_unit_tests(&db)?;

        let item_id = "bn-abc";
        let hash = content_hash_hex("same-content");
        let embedding = sample_embedding();

        let inserted = upsert_embedding(&db, item_id, &hash, &embedding)?;
        let skipped = upsert_embedding(&db, item_id, &hash, &embedding)?;

        assert!(inserted);
        assert!(!skipped);

        // The primary-key upsert must leave exactly one row for the item.
        let count: i64 =
            db.query_row("SELECT COUNT(*) FROM item_embeddings", [], |row| row.get(0))?;
        assert_eq!(count, 1);

        Ok(())
    }

    // A changed hash must overwrite the stored row via ON CONFLICT DO UPDATE.
    #[test]
    fn upsert_embedding_updates_hash_when_content_changes() -> Result<()> {
        let db = Connection::open_in_memory()?;
        seed_schema_for_unit_tests(&db)?;

        let item_id = "bn-def";
        let first_hash = content_hash_hex("old");
        let second_hash = content_hash_hex("new");

        upsert_embedding(&db, item_id, &first_hash, &sample_embedding())?;
        let written = upsert_embedding(&db, item_id, &second_hash, &sample_embedding())?;

        assert!(written);

        let stored_hash: String = db.query_row(
            "SELECT content_hash FROM item_embeddings WHERE item_id = ?1",
            params![item_id],
            |row| row.get(0),
        )?;
        assert_eq!(stored_hash, second_hash);

        Ok(())
    }

    // With matching cursors and zero items on both sides, the sync fast path
    // must return the default (all-zero) stats without touching the model.
    #[test]
    fn sync_projection_embeddings_short_circuits_when_cursor_matches() -> Result<()> {
        let db = Connection::open_in_memory()?;
        seed_schema_for_unit_tests(&db)?;

        db.execute(
            "UPDATE semantic_meta SET last_event_offset = 7, last_event_hash = 'h7' WHERE id = 1",
            [],
        )?;
        db.execute(
            "UPDATE projection_meta SET last_event_offset = 7, last_event_hash = 'h7' WHERE id = 1",
            [],
        )?;

        // Model weights may be unavailable on some machines; the assertion is
        // conditional on a successful load rather than failing the test.
        let model = SemanticModel::load();
        if let Ok(model) = model {
            let stats = sync_projection_embeddings(&db, &model)?;
            assert_eq!(stats, SyncStats::default());
        }

        Ok(())
    }

    // Skipping requires BOTH the cursor match and the item/embedding
    // cardinality match; either mismatch forces a full sync.
    #[test]
    fn should_skip_sync_requires_cardinality_match() {
        let cursor = (7, Some("h7".to_string()));
        assert!(should_skip_sync(&cursor, &cursor, 0, 0));
        assert!(should_skip_sync(&cursor, &cursor, 3, 3));
        assert!(!should_skip_sync(&cursor, &cursor, 3, 0));
        assert!(!should_skip_sync(&cursor, &cursor, 0, 2));
        assert!(!should_skip_sync(
            &cursor,
            &(8, Some("h8".to_string())),
            3,
            3
        ));
    }
}