1use crate::semantic::model::SemanticModel;
2use anyhow::{Context, Result, bail};
3use bones_core::model::item::WorkItemFields;
4use rusqlite::{Connection, OptionalExtension, params};
5use sha2::{Digest, Sha256};
6use std::collections::{HashMap, HashSet};
7
8const EMBEDDING_DIM: usize = 384;
9const SEMANTIC_META_ID: i64 = 1;
10
/// Counters returned by [`sync_projection_embeddings`] summarizing how much
/// work a sync pass actually performed.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct SyncStats {
    // Number of items whose embeddings were (re)computed and written.
    pub embedded: usize,
    // Number of stale embedding rows deleted for no-longer-live items.
    pub removed: usize,
}
17
/// Incrementally embeds work items into the `item_embeddings` table,
/// skipping any item whose stored content hash is already up to date.
pub struct EmbeddingPipeline<'a> {
    // Model used for both single and batch embedding inference.
    model: &'a SemanticModel,
    // SQLite connection holding the embedding and cursor tables.
    db: &'a Connection,
}
23
24impl<'a> EmbeddingPipeline<'a> {
25 pub fn new(model: &'a SemanticModel, db: &'a Connection) -> Result<Self> {
31 ensure_embedding_schema(db)?;
32 Ok(Self { model, db })
33 }
34
35 pub fn embed_item(&self, item: &WorkItemFields) -> Result<bool> {
41 let content = item_content(item);
42 let content_hash = content_hash_hex(&content);
43
44 if has_same_hash(self.db, &item.id, &content_hash)? {
45 return Ok(false);
46 }
47
48 let embedding = self
49 .model
50 .embed(&content)
51 .with_context(|| format!("embedding inference failed for item {}", item.id))?;
52
53 upsert_embedding(self.db, &item.id, &content_hash, &embedding)
54 }
55
56 pub fn embed_all(&self, items: &[WorkItemFields]) -> Result<usize> {
62 let mut pending = Vec::new();
63
64 for item in items {
65 let content = item_content(item);
66 let content_hash = content_hash_hex(&content);
67 if has_same_hash(self.db, &item.id, &content_hash)? {
68 continue;
69 }
70 pending.push((item.id.clone(), content_hash, content));
71 }
72
73 if pending.is_empty() {
74 return Ok(0);
75 }
76
77 let texts: Vec<&str> = pending.iter().map(|(_, _, text)| text.as_str()).collect();
78 let embeddings = self
79 .model
80 .embed_batch(&texts)
81 .context("batch embedding inference failed")?;
82
83 if embeddings.len() != pending.len() {
84 bail!(
85 "embedding batch length mismatch: expected {}, got {}",
86 pending.len(),
87 embeddings.len()
88 );
89 }
90
91 for ((item_id, hash, _), embedding) in pending.iter().zip(embeddings) {
92 upsert_embedding(self.db, item_id, hash, &embedding)?;
93 }
94
95 Ok(pending.len())
96 }
97}
98
99pub fn sync_projection_embeddings(db: &Connection, model: &SemanticModel) -> Result<SyncStats> {
108 ensure_embedding_schema(db)?;
109
110 let projection_cursor = projection_cursor(db)?;
111 let indexed_cursor = semantic_cursor(db)?;
112 let active_items = active_item_count(db)?;
113 let embedded_items = embedding_count(db)?;
114 if should_skip_sync(
115 &indexed_cursor,
116 &projection_cursor,
117 active_items,
118 embedded_items,
119 ) {
120 return Ok(SyncStats::default());
121 }
122
123 let items = load_items_for_embedding(db)?;
124 let live_ids: HashSet<String> = items.iter().map(|(id, _, _)| id.clone()).collect();
125 let existing_hashes = load_existing_hashes(db)?;
126
127 let mut pending = Vec::new();
128 for (item_id, content_hash, content) in &items {
129 if existing_hashes.get(item_id) == Some(content_hash) {
130 continue;
131 }
132 pending.push((item_id.clone(), content_hash.clone(), content.clone()));
133 }
134
135 let embedded = if pending.is_empty() {
136 0
137 } else {
138 let texts: Vec<&str> = pending
139 .iter()
140 .map(|(_, _, content)| content.as_str())
141 .collect();
142 let embeddings = model
143 .embed_batch(&texts)
144 .context("semantic index sync failed during embedding inference")?;
145
146 if embeddings.len() != pending.len() {
147 bail!(
148 "semantic index sync embedding count mismatch: expected {}, got {}",
149 pending.len(),
150 embeddings.len()
151 );
152 }
153
154 for ((item_id, content_hash, _), embedding) in pending.iter().zip(embeddings.iter()) {
155 upsert_embedding(db, item_id, content_hash, embedding)?;
156 }
157 pending.len()
158 };
159
160 let removed = remove_stale_embeddings(db, &live_ids)?;
161 set_semantic_cursor(db, projection_cursor.0, projection_cursor.1.as_deref())?;
162
163 Ok(SyncStats { embedded, removed })
164}
165
/// Public entry point for creating the semantic index tables; delegates to
/// the same schema helper the pipeline uses internally.
pub fn ensure_semantic_index_schema(db: &Connection) -> Result<()> {
    ensure_embedding_schema(db)
}
178
/// Creates the `item_embeddings` and `semantic_meta` tables if missing and
/// seeds the singleton cursor row. Idempotent; safe to call on every open.
fn ensure_embedding_schema(db: &Connection) -> Result<()> {
    db.execute_batch(
        "
        CREATE TABLE IF NOT EXISTS item_embeddings (
            item_id TEXT PRIMARY KEY,
            content_hash TEXT NOT NULL,
            embedding_json TEXT NOT NULL
        );

        CREATE TABLE IF NOT EXISTS semantic_meta (
            id INTEGER PRIMARY KEY CHECK (id = 1),
            last_event_offset INTEGER NOT NULL DEFAULT 0,
            last_event_hash TEXT
        );

        INSERT OR IGNORE INTO semantic_meta (id, last_event_offset, last_event_hash)
        VALUES (1, 0, NULL);
        ",
    )
    .context("failed to create semantic index tables")?;

    Ok(())
}
202
203fn projection_cursor(db: &Connection) -> Result<(i64, Option<String>)> {
204 db.query_row(
205 "SELECT last_event_offset, last_event_hash FROM projection_meta WHERE id = 1",
206 [],
207 |row| Ok((row.get::<_, i64>(0)?, row.get::<_, Option<String>>(1)?)),
208 )
209 .context("failed to read projection cursor for semantic sync")
210}
211
212fn semantic_cursor(db: &Connection) -> Result<(i64, Option<String>)> {
213 db.query_row(
214 "SELECT last_event_offset, last_event_hash FROM semantic_meta WHERE id = ?1",
215 params![SEMANTIC_META_ID],
216 |row| Ok((row.get::<_, i64>(0)?, row.get::<_, Option<String>>(1)?)),
217 )
218 .context("failed to read semantic index cursor")
219}
220
221fn active_item_count(db: &Connection) -> Result<usize> {
222 let count: i64 = db
223 .query_row(
224 "SELECT COUNT(*) FROM items WHERE is_deleted = 0",
225 [],
226 |row| row.get(0),
227 )
228 .context("failed to count active items for semantic sync")?;
229 Ok(usize::try_from(count).unwrap_or(0))
230}
231
232fn embedding_count(db: &Connection) -> Result<usize> {
233 let count: i64 = db
234 .query_row("SELECT COUNT(*) FROM item_embeddings", [], |row| row.get(0))
235 .context("failed to count semantic embeddings")?;
236 Ok(usize::try_from(count).unwrap_or(0))
237}
238
/// A sync pass can be skipped only when the semantic cursor already matches
/// the projection cursor AND every active item has an embedding row (the
/// count check guards against a previously interrupted run).
fn should_skip_sync(
    indexed_cursor: &(i64, Option<String>),
    projection_cursor: &(i64, Option<String>),
    active_items: usize,
    embedded_items: usize,
) -> bool {
    let cursors_match = indexed_cursor == projection_cursor;
    let counts_match = active_items == embedded_items;
    cursors_match && counts_match
}
247
248fn set_semantic_cursor(db: &Connection, offset: i64, hash: Option<&str>) -> Result<()> {
249 db.execute(
250 "UPDATE semantic_meta
251 SET last_event_offset = ?1, last_event_hash = ?2
252 WHERE id = ?3",
253 params![offset, hash, SEMANTIC_META_ID],
254 )
255 .context("failed to update semantic index cursor")?;
256
257 Ok(())
258}
259
260fn load_items_for_embedding(db: &Connection) -> Result<Vec<(String, String, String)>> {
261 let mut stmt = db
262 .prepare(
263 "SELECT item_id, title, description
264 FROM items
265 WHERE is_deleted = 0",
266 )
267 .context("failed to prepare item query for semantic sync")?;
268
269 let rows = stmt
270 .query_map([], |row| {
271 let item_id = row.get::<_, String>(0)?;
272 let title = row.get::<_, String>(1)?;
273 let description = row.get::<_, Option<String>>(2)?;
274 Ok((item_id, title, description))
275 })
276 .context("failed to execute item query for semantic sync")?;
277
278 let mut items = Vec::new();
279 for row in rows {
280 let (item_id, title, description) =
281 row.context("failed to read item row for semantic sync")?;
282 let content = content_from_title_description(&title, description.as_deref());
283 let content_hash = content_hash_hex(&content);
284 items.push((item_id, content_hash, content));
285 }
286
287 Ok(items)
288}
289
290fn load_existing_hashes(db: &Connection) -> Result<HashMap<String, String>> {
291 let mut stmt = db
292 .prepare("SELECT item_id, content_hash FROM item_embeddings")
293 .context("failed to prepare semantic hash query")?;
294 let rows = stmt
295 .query_map([], |row| {
296 Ok((row.get::<_, String>(0)?, row.get::<_, String>(1)?))
297 })
298 .context("failed to query semantic hash table")?;
299
300 let mut out = HashMap::new();
301 for row in rows {
302 let (item_id, hash) = row.context("failed to read semantic hash row")?;
303 out.insert(item_id, hash);
304 }
305 Ok(out)
306}
307
308fn remove_stale_embeddings(db: &Connection, live_ids: &HashSet<String>) -> Result<usize> {
309 let mut stmt = db
310 .prepare("SELECT item_id FROM item_embeddings")
311 .context("failed to prepare stale semantic row query")?;
312 let rows = stmt
313 .query_map([], |row| row.get::<_, String>(0))
314 .context("failed to query semantic rows for stale cleanup")?;
315
316 let mut stale = Vec::new();
317 for row in rows {
318 let item_id = row.context("failed to read semantic row id")?;
319 if !live_ids.contains(&item_id) {
320 stale.push(item_id);
321 }
322 }
323
324 for item_id in &stale {
325 db.execute(
326 "DELETE FROM item_embeddings WHERE item_id = ?1",
327 params![item_id],
328 )
329 .with_context(|| format!("failed to delete stale semantic row for {item_id}"))?;
330 }
331
332 Ok(stale.len())
333}
334
335fn has_same_hash(db: &Connection, item_id: &str, content_hash: &str) -> Result<bool> {
336 let existing = db
337 .query_row(
338 "SELECT content_hash FROM item_embeddings WHERE item_id = ?1",
339 params![item_id],
340 |row| row.get::<_, String>(0),
341 )
342 .optional()
343 .with_context(|| format!("failed to query content hash for item {item_id}"))?;
344
345 Ok(existing.as_deref() == Some(content_hash))
346}
347
348fn upsert_embedding(
349 db: &Connection,
350 item_id: &str,
351 content_hash: &str,
352 embedding: &[f32],
353) -> Result<bool> {
354 if embedding.len() != EMBEDDING_DIM {
355 bail!(
356 "invalid embedding dimension for item {item_id}: expected {EMBEDDING_DIM}, got {}",
357 embedding.len()
358 );
359 }
360
361 let existing_hash = db
362 .query_row(
363 "SELECT content_hash FROM item_embeddings WHERE item_id = ?1",
364 params![item_id],
365 |row| row.get::<_, String>(0),
366 )
367 .optional()
368 .with_context(|| format!("failed to lookup semantic row for item {item_id}"))?;
369
370 if existing_hash.as_deref() == Some(content_hash) {
371 return Ok(false);
372 }
373
374 let encoded_vector = encode_embedding_json(embedding);
375 db.execute(
376 "INSERT INTO item_embeddings (item_id, content_hash, embedding_json)
377 VALUES (?1, ?2, ?3)
378 ON CONFLICT(item_id)
379 DO UPDATE SET content_hash = excluded.content_hash,
380 embedding_json = excluded.embedding_json",
381 params![item_id, content_hash, encoded_vector],
382 )
383 .with_context(|| format!("failed to upsert semantic embedding for item {item_id}"))?;
384
385 Ok(true)
386}
387
/// Builds the canonical text fed to the embedding model for a work item.
fn item_content(item: &WorkItemFields) -> String {
    content_from_title_description(&item.title, item.description.as_deref())
}
391
/// Joins a trimmed title with a trimmed, non-blank description; falls back
/// to the trimmed title alone when the description is missing or blank.
fn content_from_title_description(title: &str, description: Option<&str>) -> String {
    let title = title.trim();
    match description.map(str::trim).filter(|text| !text.is_empty()) {
        Some(body) => format!("{title} {body}"),
        None => title.to_owned(),
    }
}
400
401fn content_hash_hex(content: &str) -> String {
402 let mut hasher = Sha256::new();
403 hasher.update(content.as_bytes());
404 format!("{:x}", hasher.finalize())
405}
406
/// Serializes an embedding as a compact JSON array literal, e.g. `[0.25,1]`,
/// using `f32`'s standard `Display` formatting for each component.
fn encode_embedding_json(embedding: &[f32]) -> String {
    let parts: Vec<String> = embedding.iter().map(|value| value.to_string()).collect();
    format!("[{}]", parts.join(","))
}
418
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates the minimal projection-side tables these unit tests need
    /// (`items`, `projection_meta`) plus the embedding schema itself.
    fn seed_schema_for_unit_tests(db: &Connection) -> Result<()> {
        db.execute_batch(
            "
            CREATE TABLE items (
                item_id TEXT PRIMARY KEY,
                title TEXT NOT NULL,
                description TEXT,
                is_deleted INTEGER NOT NULL DEFAULT 0,
                updated_at_us INTEGER NOT NULL DEFAULT 0
            );

            CREATE TABLE projection_meta (
                id INTEGER PRIMARY KEY,
                last_event_offset INTEGER NOT NULL,
                last_event_hash TEXT
            );

            INSERT INTO projection_meta (id, last_event_offset, last_event_hash)
            VALUES (1, 0, NULL);
            ",
        )?;

        ensure_embedding_schema(db)?;
        Ok(())
    }

    /// A correctly-dimensioned dummy vector for exercising upserts.
    fn sample_embedding() -> Vec<f32> {
        vec![0.25_f32; EMBEDDING_DIM]
    }

    #[test]
    fn content_hash_changes_with_content() {
        // Different inputs must yield different hex digests.
        let left = content_hash_hex("alpha");
        let right = content_hash_hex("beta");
        assert_ne!(left, right);
    }

    #[test]
    fn item_content_concatenates_title_and_description() {
        let item = WorkItemFields {
            title: "Title".to_string(),
            description: Some("Description".to_string()),
            ..WorkItemFields::default()
        };

        assert_eq!(item_content(&item), "Title Description");
    }

    #[test]
    fn upsert_embedding_skips_when_hash_matches() -> Result<()> {
        let db = Connection::open_in_memory()?;
        seed_schema_for_unit_tests(&db)?;

        let item_id = "bn-abc";
        let hash = content_hash_hex("same-content");
        let embedding = sample_embedding();

        // Second call with an identical hash should be a no-op returning false.
        let inserted = upsert_embedding(&db, item_id, &hash, &embedding)?;
        let skipped = upsert_embedding(&db, item_id, &hash, &embedding)?;

        assert!(inserted);
        assert!(!skipped);

        // PRIMARY KEY + ON CONFLICT guarantees a single row per item.
        let count: i64 =
            db.query_row("SELECT COUNT(*) FROM item_embeddings", [], |row| row.get(0))?;
        assert_eq!(count, 1);

        Ok(())
    }

    #[test]
    fn upsert_embedding_updates_hash_when_content_changes() -> Result<()> {
        let db = Connection::open_in_memory()?;
        seed_schema_for_unit_tests(&db)?;

        let item_id = "bn-def";
        let first_hash = content_hash_hex("old");
        let second_hash = content_hash_hex("new");

        // A changed hash must overwrite the stored row and report a write.
        upsert_embedding(&db, item_id, &first_hash, &sample_embedding())?;
        let written = upsert_embedding(&db, item_id, &second_hash, &sample_embedding())?;

        assert!(written);

        let stored_hash: String = db.query_row(
            "SELECT content_hash FROM item_embeddings WHERE item_id = ?1",
            params![item_id],
            |row| row.get(0),
        )?;
        assert_eq!(stored_hash, second_hash);

        Ok(())
    }

    #[test]
    fn sync_projection_embeddings_short_circuits_when_cursor_matches() -> Result<()> {
        let db = Connection::open_in_memory()?;
        seed_schema_for_unit_tests(&db)?;

        // Align both cursors (and both counts: zero items, zero embeddings)
        // so the sync should do no work at all.
        db.execute(
            "UPDATE semantic_meta SET last_event_offset = 7, last_event_hash = 'h7' WHERE id = 1",
            [],
        )?;
        db.execute(
            "UPDATE projection_meta SET last_event_offset = 7, last_event_hash = 'h7' WHERE id = 1",
            [],
        )?;

        // Model loading may fail in environments without model assets; the
        // short-circuit assertion only runs when a model is available.
        let model = SemanticModel::load();
        if let Ok(model) = model {
            let stats = sync_projection_embeddings(&db, &model)?;
            assert_eq!(stats, SyncStats::default());
        }

        Ok(())
    }

    #[test]
    fn should_skip_sync_requires_cardinality_match() {
        let cursor = (7, Some("h7".to_string()));
        assert!(should_skip_sync(&cursor, &cursor, 0, 0));
        assert!(should_skip_sync(&cursor, &cursor, 3, 3));
        assert!(!should_skip_sync(&cursor, &cursor, 3, 0));
        assert!(!should_skip_sync(&cursor, &cursor, 0, 2));
        assert!(!should_skip_sync(
            &cursor,
            &(8, Some("h8".to_string())),
            3,
            3
        ));
    }
}