1use crate::semantic::model::SemanticModel;
2use anyhow::{Context, Result, bail};
3use bones_core::model::item::WorkItemFields;
4use rusqlite::{Connection, OptionalExtension, params};
5use sha2::{Digest, Sha256};
6use std::collections::{HashMap, HashSet};
7
8const SEMANTIC_META_ID: i64 = 1;
9
/// Outcome of one semantic index sync pass.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct SyncStats {
    /// Number of items whose embeddings were (re)computed and written.
    pub embedded: usize,
    /// Number of stale embedding rows deleted for no-longer-live items.
    pub removed: usize,
}
16
/// Embeds individual work items into the `item_embeddings` table using a
/// borrowed semantic model and SQLite connection.
pub struct EmbeddingPipeline<'a> {
    model: &'a SemanticModel,
    db: &'a Connection,
    // Cached from the model at construction so per-item calls avoid
    // repeated lookups.
    embedding_dim: usize,
    backend_id: &'static str,
}
24
25impl<'a> EmbeddingPipeline<'a> {
26 pub fn new(model: &'a SemanticModel, db: &'a Connection) -> Result<Self> {
32 ensure_embedding_schema(db)?;
33 Ok(Self {
34 model,
35 db,
36 embedding_dim: model.dimensions(),
37 backend_id: model.backend_id(),
38 })
39 }
40
41 pub fn embed_item(&self, item: &WorkItemFields) -> Result<bool> {
47 let content = item_content(item);
48 let content_hash = content_hash_hex(&content, self.backend_id);
49
50 if has_same_hash(self.db, &item.id, &content_hash)? {
51 return Ok(false);
52 }
53
54 let embedding = self
55 .model
56 .embed(&content)
57 .with_context(|| format!("embedding inference failed for item {}", item.id))?;
58
59 upsert_embedding(
60 self.db,
61 &item.id,
62 &content_hash,
63 &embedding,
64 self.embedding_dim,
65 )
66 }
67
68 pub fn embed_all(&self, items: &[WorkItemFields]) -> Result<usize> {
74 let mut pending = Vec::new();
75
76 for item in items {
77 let content = item_content(item);
78 let content_hash = content_hash_hex(&content, self.backend_id);
79 if has_same_hash(self.db, &item.id, &content_hash)? {
80 continue;
81 }
82 pending.push((item.id.clone(), content_hash, content));
83 }
84
85 if pending.is_empty() {
86 return Ok(0);
87 }
88
89 let texts: Vec<&str> = pending.iter().map(|(_, _, text)| text.as_str()).collect();
90 let embeddings = self
91 .model
92 .embed_batch(&texts)
93 .context("batch embedding inference failed")?;
94
95 if embeddings.len() != pending.len() {
96 bail!(
97 "embedding batch length mismatch: expected {}, got {}",
98 pending.len(),
99 embeddings.len()
100 );
101 }
102
103 for ((item_id, hash, _), embedding) in pending.iter().zip(embeddings) {
104 upsert_embedding(self.db, item_id, hash, &embedding, self.embedding_dim)?;
105 }
106
107 Ok(pending.len())
108 }
109}
110
111pub fn sync_projection_embeddings(db: &Connection, model: &SemanticModel) -> Result<SyncStats> {
120 ensure_embedding_schema(db)?;
121
122 let projection_cursor = projection_cursor(db)?;
123 let indexed_cursor = semantic_cursor(db)?;
124 let active_items = active_item_count(db)?;
125 let embedded_items = embedding_count(db)?;
126 if should_skip_sync(
127 &indexed_cursor,
128 &projection_cursor,
129 active_items,
130 embedded_items,
131 ) {
132 return Ok(SyncStats::default());
133 }
134
135 let backend_id = model.backend_id();
136 let embedding_dim = model.dimensions();
137
138 let items = load_items_for_embedding(db, backend_id)?;
139 let live_ids: HashSet<String> = items.iter().map(|(id, _, _)| id.clone()).collect();
140 let existing_hashes = load_existing_hashes(db)?;
141
142 let mut pending = Vec::new();
143 for (item_id, content_hash, content) in &items {
144 if existing_hashes.get(item_id) == Some(content_hash) {
145 continue;
146 }
147 pending.push((item_id.clone(), content_hash.clone(), content.clone()));
148 }
149
150 let embedded = if pending.is_empty() {
151 0
152 } else {
153 let texts: Vec<&str> = pending
154 .iter()
155 .map(|(_, _, content)| content.as_str())
156 .collect();
157 let embeddings = model
158 .embed_batch(&texts)
159 .context("semantic index sync failed during embedding inference")?;
160
161 if embeddings.len() != pending.len() {
162 bail!(
163 "semantic index sync embedding count mismatch: expected {}, got {}",
164 pending.len(),
165 embeddings.len()
166 );
167 }
168
169 for ((item_id, content_hash, _), embedding) in pending.iter().zip(embeddings.iter()) {
170 upsert_embedding(db, item_id, content_hash, embedding, embedding_dim)?;
171 }
172 pending.len()
173 };
174
175 let removed = remove_stale_embeddings(db, &live_ids)?;
176 set_semantic_cursor(db, projection_cursor.0, projection_cursor.1.as_deref())?;
177
178 Ok(SyncStats { embedded, removed })
179}
180
/// Public entry point for creating the semantic index tables; delegates to
/// the private schema helper shared with `EmbeddingPipeline`.
pub fn ensure_semantic_index_schema(db: &Connection) -> Result<()> {
    ensure_embedding_schema(db)
}
193
/// Creates the `item_embeddings` and `semantic_meta` tables if missing and
/// seeds the singleton meta row (id = 1) at offset 0. Idempotent.
fn ensure_embedding_schema(db: &Connection) -> Result<()> {
    db.execute_batch(
        "
        CREATE TABLE IF NOT EXISTS item_embeddings (
            item_id TEXT PRIMARY KEY,
            content_hash TEXT NOT NULL,
            embedding_json TEXT NOT NULL
        );

        CREATE TABLE IF NOT EXISTS semantic_meta (
            id INTEGER PRIMARY KEY CHECK (id = 1),
            last_event_offset INTEGER NOT NULL DEFAULT 0,
            last_event_hash TEXT
        );

        INSERT OR IGNORE INTO semantic_meta (id, last_event_offset, last_event_hash)
        VALUES (1, 0, NULL);
        ",
    )
    .context("failed to create semantic index tables")?;

    Ok(())
}
217
218fn projection_cursor(db: &Connection) -> Result<(i64, Option<String>)> {
219 db.query_row(
220 "SELECT last_event_offset, last_event_hash FROM projection_meta WHERE id = 1",
221 [],
222 |row| Ok((row.get::<_, i64>(0)?, row.get::<_, Option<String>>(1)?)),
223 )
224 .context("failed to read projection cursor for semantic sync")
225}
226
227fn semantic_cursor(db: &Connection) -> Result<(i64, Option<String>)> {
228 db.query_row(
229 "SELECT last_event_offset, last_event_hash FROM semantic_meta WHERE id = ?1",
230 params![SEMANTIC_META_ID],
231 |row| Ok((row.get::<_, i64>(0)?, row.get::<_, Option<String>>(1)?)),
232 )
233 .context("failed to read semantic index cursor")
234}
235
236fn active_item_count(db: &Connection) -> Result<usize> {
237 let count: i64 = db
238 .query_row(
239 "SELECT COUNT(*) FROM items WHERE is_deleted = 0",
240 [],
241 |row| row.get(0),
242 )
243 .context("failed to count active items for semantic sync")?;
244 Ok(usize::try_from(count).unwrap_or(0))
245}
246
247fn embedding_count(db: &Connection) -> Result<usize> {
248 let count: i64 = db
249 .query_row("SELECT COUNT(*) FROM item_embeddings", [], |row| row.get(0))
250 .context("failed to count semantic embeddings")?;
251 Ok(usize::try_from(count).unwrap_or(0))
252}
253
/// Sync can be skipped only when the semantic cursor already matches the
/// projection cursor AND the row counts agree — the count check catches
/// manual deletions that left the cursor untouched.
fn should_skip_sync(
    indexed_cursor: &(i64, Option<String>),
    projection_cursor: &(i64, Option<String>),
    active_items: usize,
    embedded_items: usize,
) -> bool {
    let cursors_match = indexed_cursor == projection_cursor;
    let counts_match = active_items == embedded_items;
    cursors_match && counts_match
}
262
263fn set_semantic_cursor(db: &Connection, offset: i64, hash: Option<&str>) -> Result<()> {
264 db.execute(
265 "UPDATE semantic_meta
266 SET last_event_offset = ?1, last_event_hash = ?2
267 WHERE id = ?3",
268 params![offset, hash, SEMANTIC_META_ID],
269 )
270 .context("failed to update semantic index cursor")?;
271
272 Ok(())
273}
274
275fn load_items_for_embedding(
276 db: &Connection,
277 backend_id: &str,
278) -> Result<Vec<(String, String, String)>> {
279 let mut stmt = db
280 .prepare(
281 "SELECT item_id, title, description
282 FROM items
283 WHERE is_deleted = 0",
284 )
285 .context("failed to prepare item query for semantic sync")?;
286
287 let rows = stmt
288 .query_map([], |row| {
289 let item_id = row.get::<_, String>(0)?;
290 let title = row.get::<_, String>(1)?;
291 let description = row.get::<_, Option<String>>(2)?;
292 Ok((item_id, title, description))
293 })
294 .context("failed to execute item query for semantic sync")?;
295
296 let mut items = Vec::new();
297 for row in rows {
298 let (item_id, title, description) =
299 row.context("failed to read item row for semantic sync")?;
300 let content = content_from_title_description(&title, description.as_deref());
301 let content_hash = content_hash_hex(&content, backend_id);
302 items.push((item_id, content_hash, content));
303 }
304
305 Ok(items)
306}
307
308fn load_existing_hashes(db: &Connection) -> Result<HashMap<String, String>> {
309 let mut stmt = db
310 .prepare("SELECT item_id, content_hash FROM item_embeddings")
311 .context("failed to prepare semantic hash query")?;
312 let rows = stmt
313 .query_map([], |row| {
314 Ok((row.get::<_, String>(0)?, row.get::<_, String>(1)?))
315 })
316 .context("failed to query semantic hash table")?;
317
318 let mut out = HashMap::new();
319 for row in rows {
320 let (item_id, hash) = row.context("failed to read semantic hash row")?;
321 out.insert(item_id, hash);
322 }
323 Ok(out)
324}
325
326fn remove_stale_embeddings(db: &Connection, live_ids: &HashSet<String>) -> Result<usize> {
327 let mut stmt = db
328 .prepare("SELECT item_id FROM item_embeddings")
329 .context("failed to prepare stale semantic row query")?;
330 let rows = stmt
331 .query_map([], |row| row.get::<_, String>(0))
332 .context("failed to query semantic rows for stale cleanup")?;
333
334 let mut stale = Vec::new();
335 for row in rows {
336 let item_id = row.context("failed to read semantic row id")?;
337 if !live_ids.contains(&item_id) {
338 stale.push(item_id);
339 }
340 }
341
342 for item_id in &stale {
343 db.execute(
344 "DELETE FROM item_embeddings WHERE item_id = ?1",
345 params![item_id],
346 )
347 .with_context(|| format!("failed to delete stale semantic row for {item_id}"))?;
348 }
349
350 Ok(stale.len())
351}
352
353fn has_same_hash(db: &Connection, item_id: &str, content_hash: &str) -> Result<bool> {
354 let existing = db
355 .query_row(
356 "SELECT content_hash FROM item_embeddings WHERE item_id = ?1",
357 params![item_id],
358 |row| row.get::<_, String>(0),
359 )
360 .optional()
361 .with_context(|| format!("failed to query content hash for item {item_id}"))?;
362
363 Ok(existing.as_deref() == Some(content_hash))
364}
365
366fn upsert_embedding(
367 db: &Connection,
368 item_id: &str,
369 content_hash: &str,
370 embedding: &[f32],
371 expected_dim: usize,
372) -> Result<bool> {
373 if embedding.len() != expected_dim {
374 bail!(
375 "invalid embedding dimension for item {item_id}: expected {expected_dim}, got {}",
376 embedding.len()
377 );
378 }
379
380 let existing_hash = db
381 .query_row(
382 "SELECT content_hash FROM item_embeddings WHERE item_id = ?1",
383 params![item_id],
384 |row| row.get::<_, String>(0),
385 )
386 .optional()
387 .with_context(|| format!("failed to lookup semantic row for item {item_id}"))?;
388
389 if existing_hash.as_deref() == Some(content_hash) {
390 return Ok(false);
391 }
392
393 let encoded_vector = encode_embedding_json(embedding);
394 db.execute(
395 "INSERT INTO item_embeddings (item_id, content_hash, embedding_json)
396 VALUES (?1, ?2, ?3)
397 ON CONFLICT(item_id)
398 DO UPDATE SET content_hash = excluded.content_hash,
399 embedding_json = excluded.embedding_json",
400 params![item_id, content_hash, encoded_vector],
401 )
402 .with_context(|| format!("failed to upsert semantic embedding for item {item_id}"))?;
403
404 Ok(true)
405}
406
/// Builds the embedding input text for a work item from its title and
/// optional description.
fn item_content(item: &WorkItemFields) -> String {
    content_from_title_description(&item.title, item.description.as_deref())
}
410
/// Joins a trimmed title and trimmed description with a single space;
/// a missing or whitespace-only description yields just the trimmed title.
fn content_from_title_description(title: &str, description: Option<&str>) -> String {
    let title = title.trim();
    match description.map(str::trim) {
        Some(desc) if !desc.is_empty() => format!("{title} {desc}"),
        _ => title.to_owned(),
    }
}
419
420fn content_hash_hex(content: &str, backend_id: &str) -> String {
421 let mut hasher = Sha256::new();
422 hasher.update(backend_id.as_bytes());
423 hasher.update(b":");
424 hasher.update(content.as_bytes());
425 format!("{:x}", hasher.finalize())
426}
427
/// Serializes a vector as a compact JSON array (no spaces), using Rust's
/// default `f32` Display formatting for each component.
fn encode_embedding_json(embedding: &[f32]) -> String {
    let parts: Vec<String> = embedding.iter().map(|value| value.to_string()).collect();
    format!("[{}]", parts.join(","))
}
439
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates the minimal `items` + `projection_meta` schema the sync code
    /// expects, then layers the semantic tables on top via the real helper.
    fn seed_schema_for_unit_tests(db: &Connection) -> Result<()> {
        db.execute_batch(
            "
            CREATE TABLE items (
                item_id TEXT PRIMARY KEY,
                title TEXT NOT NULL,
                description TEXT,
                is_deleted INTEGER NOT NULL DEFAULT 0,
                updated_at_us INTEGER NOT NULL DEFAULT 0
            );

            CREATE TABLE projection_meta (
                id INTEGER PRIMARY KEY,
                last_event_offset INTEGER NOT NULL,
                last_event_hash TEXT
            );

            INSERT INTO projection_meta (id, last_event_offset, last_event_hash)
            VALUES (1, 0, NULL);
            ",
        )?;

        ensure_embedding_schema(db)?;
        Ok(())
    }

    // Dimension used by the fixtures; any consistent value works since the
    // tests below never run real inference.
    const TEST_DIM: usize = 384;

    fn sample_embedding() -> Vec<f32> {
        vec![0.25_f32; TEST_DIM]
    }

    #[test]
    fn content_hash_changes_with_content() {
        let left = content_hash_hex("alpha", "test");
        let right = content_hash_hex("beta", "test");
        assert_ne!(left, right);
    }

    #[test]
    fn content_hash_changes_with_backend() {
        // Same content must hash differently under a different backend id.
        let left = content_hash_hex("same", "ort");
        let right = content_hash_hex("same", "model2vec");
        assert_ne!(left, right);
    }

    #[test]
    fn item_content_concatenates_title_and_description() {
        let item = WorkItemFields {
            title: "Title".to_string(),
            description: Some("Description".to_string()),
            ..WorkItemFields::default()
        };

        assert_eq!(item_content(&item), "Title Description");
    }

    #[test]
    fn upsert_embedding_skips_when_hash_matches() -> Result<()> {
        let db = Connection::open_in_memory()?;
        seed_schema_for_unit_tests(&db)?;

        let item_id = "bn-abc";
        let hash = content_hash_hex("same-content", "test");
        let embedding = sample_embedding();

        // Second call with identical hash must be a no-op returning false.
        let inserted = upsert_embedding(&db, item_id, &hash, &embedding, TEST_DIM)?;
        let skipped = upsert_embedding(&db, item_id, &hash, &embedding, TEST_DIM)?;

        assert!(inserted);
        assert!(!skipped);

        let count: i64 =
            db.query_row("SELECT COUNT(*) FROM item_embeddings", [], |row| row.get(0))?;
        assert_eq!(count, 1);

        Ok(())
    }

    #[test]
    fn upsert_embedding_updates_hash_when_content_changes() -> Result<()> {
        let db = Connection::open_in_memory()?;
        seed_schema_for_unit_tests(&db)?;

        let item_id = "bn-def";
        let first_hash = content_hash_hex("old", "test");
        let second_hash = content_hash_hex("new", "test");

        upsert_embedding(&db, item_id, &first_hash, &sample_embedding(), TEST_DIM)?;
        let written = upsert_embedding(&db, item_id, &second_hash, &sample_embedding(), TEST_DIM)?;

        assert!(written);

        // The upsert must replace, not duplicate, the stored hash.
        let stored_hash: String = db.query_row(
            "SELECT content_hash FROM item_embeddings WHERE item_id = ?1",
            params![item_id],
            |row| row.get(0),
        )?;
        assert_eq!(stored_hash, second_hash);

        Ok(())
    }

    #[test]
    fn sync_projection_embeddings_short_circuits_when_cursor_matches() -> Result<()> {
        let db = Connection::open_in_memory()?;
        seed_schema_for_unit_tests(&db)?;

        // Align both cursors so the fast path triggers (0 items, 0 rows).
        db.execute(
            "UPDATE semantic_meta SET last_event_offset = 7, last_event_hash = 'h7' WHERE id = 1",
            [],
        )?;
        db.execute(
            "UPDATE projection_meta SET last_event_offset = 7, last_event_hash = 'h7' WHERE id = 1",
            [],
        )?;

        // NOTE(review): the assertion only runs when the model loads, so
        // this test is silently skipped on machines without model assets.
        let model = SemanticModel::load();
        if let Ok(model) = model {
            let stats = sync_projection_embeddings(&db, &model)?;
            assert_eq!(stats, SyncStats::default());
        }

        Ok(())
    }

    #[test]
    fn should_skip_sync_requires_cardinality_match() {
        let cursor = (7, Some("h7".to_string()));
        assert!(should_skip_sync(&cursor, &cursor, 0, 0));
        assert!(should_skip_sync(&cursor, &cursor, 3, 3));
        assert!(!should_skip_sync(&cursor, &cursor, 3, 0));
        assert!(!should_skip_sync(&cursor, &cursor, 0, 2));
        assert!(!should_skip_sync(
            &cursor,
            &(8, Some("h8".to_string())),
            3,
            3
        ));
    }
}