1use std::collections::HashMap;
7use std::path::Path;
8use std::sync::Arc;
9use std::sync::atomic::{AtomicBool, Ordering};
10use std::sync::mpsc::{Receiver, Sender};
11
12use tracing::{debug, error, info, warn};
13
14use crate::indexer::semantic::{
15 EmbeddingInput, SemanticIndexer, message_id_from_db, saturating_u32_from_i64,
16};
17use crate::search::canonicalize::{canonicalize_for_embedding, content_hash};
18use crate::search::fastembed_embedder::FastEmbedder;
19use crate::search::vector_index::{
20 VectorIndex, parse_semantic_doc_id, role_code_from_str, vector_index_path,
21};
22use crate::storage::sqlite::FrankenStorage;
23
/// Model name that selects the hash embedder; also the fallback for the fast
/// and single-pass model when none is configured.
const HASH_EMBEDDER_MODEL: &str = "hash";
/// Model name used for the quality pass when no quality model is configured.
const DEFAULT_SEMANTIC_MODEL: &str = "minilm";

/// Parameters describing one embedding job handed to the worker.
#[derive(Debug, Clone)]
pub struct EmbeddingJobConfig {
    /// Path of the database opened to fetch messages and track job state.
    pub db_path: String,
    /// Path passed to the vector-index helpers when loading and saving indexes.
    pub index_path: String,
    /// When `true`, a fast pass followed by a quality pass is scheduled;
    /// otherwise a single pass is scheduled.
    pub two_tier: bool,
    /// Optional model name for the fast pass.
    pub fast_model: Option<String>,
    /// Optional model name for the quality pass.
    pub quality_model: Option<String>,
}

impl EmbeddingJobConfig {
    /// Model name for the fast pass: `fast_model`, or the hash model if unset.
    fn fast_pass_model(&self) -> String {
        self.fast_model
            .as_deref()
            .unwrap_or(HASH_EMBEDDER_MODEL)
            .to_owned()
    }

    /// Model name for the quality pass: `quality_model`, or the default
    /// semantic model if unset.
    fn quality_pass_model(&self) -> String {
        self.quality_model
            .as_deref()
            .unwrap_or(DEFAULT_SEMANTIC_MODEL)
            .to_owned()
    }

    /// Model name for a single-pass job: prefers `quality_model`, then
    /// `fast_model`, and finally falls back to the hash model.
    fn single_pass_model(&self) -> String {
        let chosen = match (self.quality_model.as_deref(), self.fast_model.as_deref()) {
            (Some(quality), _) => quality,
            (None, Some(fast)) => fast,
            (None, None) => HASH_EMBEDDER_MODEL,
        };
        chosen.to_owned()
    }
}
57
/// Messages delivered to the embedding worker over its channel.
#[derive(Debug)]
pub enum WorkerMessage {
    /// Run the embedding job described by the given configuration.
    Submit(EmbeddingJobConfig),
    /// Cancel embedding jobs recorded in a database.
    Cancel {
        /// Path of the database whose jobs should be cancelled.
        db_path: String,
        /// Optional model filter forwarded to storage.
        /// NOTE(review): `None` presumably targets every model — confirm
        /// against `cancel_embedding_jobs`.
        model_id: Option<String>,
    },
    /// Stop the worker's message loop.
    Shutdown,
}
71
/// Cloneable handle used to send messages to an `EmbeddingWorker`.
#[derive(Clone)]
pub struct EmbeddingWorkerHandle {
    /// Sending half of the worker's message channel.
    sender: Sender<WorkerMessage>,
    /// Cancellation flag shared with the worker created alongside this handle;
    /// `cancel` sets it before queuing the cancel message.
    cancel_flag: Arc<AtomicBool>,
}
80
81impl EmbeddingWorkerHandle {
82 pub fn submit(&self, config: EmbeddingJobConfig) -> Result<(), String> {
84 self.sender
85 .send(WorkerMessage::Submit(config))
86 .map_err(|e| format!("worker channel closed: {e}"))
87 }
88
89 pub fn cancel(&self, db_path: String, model_id: Option<String>) -> Result<(), String> {
94 self.cancel_flag.store(true, Ordering::SeqCst);
95 self.sender
96 .send(WorkerMessage::Cancel { db_path, model_id })
97 .map_err(|e| format!("worker channel closed: {e}"))
98 }
99
100 pub fn shutdown(&self) -> Result<(), String> {
102 self.sender
103 .send(WorkerMessage::Shutdown)
104 .map_err(|e| format!("worker channel closed: {e}"))
105 }
106}
107
/// Worker that consumes `WorkerMessage`s from a channel and runs embedding
/// jobs; create it together with its handle via `EmbeddingWorker::new`.
pub struct EmbeddingWorker {
    /// Receiving half of the message channel.
    receiver: Receiver<WorkerMessage>,
    /// Cancellation flag shared with the handle; polled while a job runs.
    cancel_flag: Arc<AtomicBool>,
}
113
/// Embedder selected for one embedding pass.
#[derive(Debug, Clone, PartialEq, Eq)]
enum WorkerEmbedderKind {
    /// The hash embedder.
    Hash,
    /// A FastEmbed-backed semantic model.
    FastEmbed {
        /// Normalized model name (aliases already resolved).
        model_name: String,
        /// Embedder identifier taken from the FastEmbedder config; used to
        /// locate the matching vector index file.
        embedder_id: String,
    },
}
122
123fn resolve_embedder_kind(
124 model_name: &str,
125 use_semantic: bool,
126) -> anyhow::Result<WorkerEmbedderKind> {
127 if !use_semantic
128 || model_name.eq_ignore_ascii_case(HASH_EMBEDDER_MODEL)
129 || model_name.eq_ignore_ascii_case("fnv1a-384")
130 {
131 return Ok(WorkerEmbedderKind::Hash);
132 }
133
134 let normalized_name = match model_name.to_ascii_lowercase().as_str() {
135 "fastembed" | "minilm" | "minilm-384" | "all-minilm-l6-v2" => DEFAULT_SEMANTIC_MODEL,
136 "snowflake-arctic-s" | "snowflake-arctic-s-384" | "snowflake-arctic-embed-s" => {
137 "snowflake-arctic-s"
138 }
139 "nomic-embed" | "nomic-embed-768" | "nomic-embed-text-v1.5" => "nomic-embed",
140 _ => {
141 anyhow::bail!(
142 "unsupported semantic model '{model_name}' for daemon embedding worker; supported: minilm, snowflake-arctic-s, nomic-embed"
143 );
144 }
145 };
146
147 let config = FastEmbedder::config_for(normalized_name).ok_or_else(|| {
148 anyhow::anyhow!("missing FastEmbedder config for registered model '{normalized_name}'")
149 })?;
150 Ok(WorkerEmbedderKind::FastEmbed {
151 model_name: normalized_name.to_string(),
152 embedder_id: config.embedder_id,
153 })
154}
155
/// Converts a `usize` to `i64`, clamping to `i64::MAX` when the value does
/// not fit.
fn saturating_i64_from_usize(raw: usize) -> i64 {
    match i64::try_from(raw) {
        Ok(value) => value,
        Err(_) => i64::MAX,
    }
}
159
160impl EmbeddingWorker {
161 pub fn new() -> (Self, EmbeddingWorkerHandle) {
163 let (sender, receiver) = std::sync::mpsc::channel();
164 let cancel_flag = Arc::new(AtomicBool::new(false));
165 let handle = EmbeddingWorkerHandle {
166 sender,
167 cancel_flag: Arc::clone(&cancel_flag),
168 };
169 let worker = Self {
170 receiver,
171 cancel_flag,
172 };
173 (worker, handle)
174 }
175
176 pub fn run(self) {
178 info!("Embedding worker started");
179 while let Ok(msg) = self.receiver.recv() {
180 match msg {
181 WorkerMessage::Submit(config) => {
182 self.cancel_flag.store(false, Ordering::SeqCst);
183 info!(db_path = %config.db_path, two_tier = config.two_tier, "Processing embedding job");
184 if let Err(e) = self.process_job(&config) {
185 error!(db_path = %config.db_path, error = %e, "Embedding job failed");
186 }
187 }
188 WorkerMessage::Cancel { db_path, model_id } => {
189 info!(%db_path, ?model_id, "Processing cancel — flag already set by handle");
192 if let Err(e) = Self::cancel_in_db(&db_path, model_id.as_deref()) {
194 warn!(%db_path, error = %e, "Failed to cancel jobs in database");
195 }
196 }
197 WorkerMessage::Shutdown => {
198 info!("Embedding worker shutting down");
199 break;
200 }
201 }
202 }
203 info!("Embedding worker stopped");
204 }
205
206 fn cancel_in_db(db_path: &str, model_id: Option<&str>) -> anyhow::Result<()> {
208 let storage = FrankenStorage::open(Path::new(db_path))?;
209 storage.cancel_embedding_jobs(db_path, model_id)?;
210 Ok(())
211 }
212
213 fn process_job(&self, config: &EmbeddingJobConfig) -> anyhow::Result<()> {
215 let db_path = Path::new(&config.db_path);
216 let index_path = Path::new(&config.index_path);
217
218 let storage = FrankenStorage::open(db_path)?;
220 let messages = storage.fetch_messages_for_embedding()?;
221 let total_docs = saturating_i64_from_usize(messages.len());
222
223 if total_docs == 0 {
224 info!(db_path = %config.db_path, "No messages to embed");
225 return Ok(());
226 }
227
228 info!(
229 db_path = %config.db_path,
230 total_docs,
231 two_tier = config.two_tier,
232 "Found messages to embed"
233 );
234
235 let passes = self.build_passes(config);
237
238 for (model_name, use_semantic) in &passes {
239 if self.cancel_flag.load(Ordering::SeqCst) {
240 info!("Embedding job cancelled");
241 return Ok(());
242 }
243
244 let job_id = storage.upsert_embedding_job(&config.db_path, model_name, total_docs)?;
245 storage.start_embedding_job(job_id)?;
246
247 match self.generate_embeddings_and_save(
248 &storage,
249 &messages,
250 model_name,
251 *use_semantic,
252 job_id,
253 index_path,
254 ) {
255 Ok(()) => {
256 storage.complete_embedding_job(job_id)?;
257 info!(model = model_name, "Embedding pass completed");
258 }
259 Err(e) => {
260 let err_msg = format!("{e:#}");
261 storage.fail_embedding_job(job_id, &err_msg)?;
262 warn!(model = model_name, error = %e, "Embedding pass failed");
263 }
264 }
265 }
266
267 Ok(())
268 }
269
270 fn build_passes(&self, config: &EmbeddingJobConfig) -> Vec<(String, bool)> {
272 let mut passes = Vec::new();
273
274 if config.two_tier {
275 let fast = config.fast_pass_model();
277 passes.push((fast, false));
278
279 let quality = config.quality_pass_model();
281 passes.push((quality, true));
282 } else {
283 let model = config.single_pass_model();
285 let is_semantic = model != HASH_EMBEDDER_MODEL;
286 passes.push((model, is_semantic));
287 }
288
289 passes
290 }
291
292 fn generate_embeddings_and_save(
294 &self,
295 storage: &FrankenStorage,
296 messages: &[crate::storage::sqlite::MessageForEmbedding],
297 model_name: &str,
298 use_semantic: bool,
299 job_id: i64,
300 index_path: &Path,
301 ) -> anyhow::Result<()> {
302 let embedder_kind = resolve_embedder_kind(model_name, use_semantic)?;
303
304 let existing_hashes = self.load_existing_hashes(index_path, &embedder_kind);
306
307 let mut inputs: Vec<EmbeddingInput> = Vec::new();
309 let mut skipped_count = 0usize;
310 let mut completed = 0i64;
311
312 for msg in messages {
313 if self.cancel_flag.load(Ordering::SeqCst) {
314 return Err(anyhow::anyhow!("job cancelled"));
315 }
316
317 let canonical = canonicalize_for_embedding(&msg.content);
318 if canonical.is_empty() {
319 completed += 1;
320 continue;
321 }
322
323 let hash = content_hash(&canonical);
324 let role = role_code_from_str(&msg.role).unwrap_or(0);
325
326 let Some(message_id) = message_id_from_db(msg.message_id) else {
328 warn!(
329 raw_message_id = msg.message_id,
330 "Skipping message with out-of-range id during embedding"
331 );
332 completed += 1;
333 continue;
334 };
335
336 if let Some(existing_hash) = existing_hashes.get(&message_id)
338 && *existing_hash == hash
339 {
340 skipped_count += 1;
341 completed += 1;
342 continue;
343 }
344
345 let agent_id = saturating_u32_from_i64(msg.agent_id);
347 let workspace_id = saturating_u32_from_i64(msg.workspace_id.unwrap_or(0));
348
349 inputs.push(EmbeddingInput {
350 message_id,
351 created_at_ms: msg.created_at.unwrap_or(0),
352 agent_id,
353 workspace_id,
354 source_id: msg.source_id_hash,
355 role,
356 chunk_idx: 0,
357 content: canonical,
358 });
359
360 completed += 1;
361 if completed % 100 == 0 {
362 let _ = storage.update_job_progress(job_id, completed);
363 debug!(job_id, completed, "Embedding progress");
364 }
365 }
366
367 if inputs.is_empty() {
368 let final_completed = saturating_i64_from_usize(messages.len());
369 let _ = storage.update_job_progress(job_id, final_completed);
370 info!(
371 model = model_name,
372 skipped = skipped_count,
373 "No documents to embed - all unchanged"
374 );
375 return Ok(());
376 }
377
378 info!(
379 model = model_name,
380 input_count = inputs.len(),
381 skipped = skipped_count,
382 "Embedding documents"
383 );
384
385 let indexer = match embedder_kind {
387 WorkerEmbedderKind::Hash => SemanticIndexer::new(HASH_EMBEDDER_MODEL, None)?,
388 WorkerEmbedderKind::FastEmbed { ref model_name, .. } => {
389 SemanticIndexer::new(model_name, Some(index_path))?
390 }
391 };
392
393 let embedded = indexer.embed_messages(&inputs)?;
395
396 let final_completed = saturating_i64_from_usize(messages.len());
398 let _ = storage.update_job_progress(job_id, final_completed);
399
400 let save_path = vector_index_path(index_path, indexer.embedder_id());
404 if save_path.exists() {
405 let appended = indexer.append_to_index(embedded, index_path)?;
406 info!(appended, "Appended to existing vector index");
407 } else {
408 let _index = indexer.build_and_save_index(embedded, index_path)?;
409 }
410
411 info!(
412 model = model_name,
413 path = %save_path.display(),
414 count = inputs.len(),
415 "Saved vector index"
416 );
417
418 Ok(())
419 }
420
421 fn load_existing_hashes(
423 &self,
424 index_path: &Path,
425 embedder_kind: &WorkerEmbedderKind,
426 ) -> HashMap<u64, [u8; 32]> {
427 let embedder_id = match embedder_kind {
428 WorkerEmbedderKind::Hash => "fnv1a-384",
429 WorkerEmbedderKind::FastEmbed { embedder_id, .. } => embedder_id.as_str(),
430 };
431
432 let fsvi_path = vector_index_path(index_path, embedder_id);
433
434 if !fsvi_path.exists() {
435 return HashMap::new();
436 }
437
438 match VectorIndex::open(&fsvi_path) {
439 Ok(index) => {
440 let mut hashes = HashMap::new();
441 for idx in 0..index.record_count() {
442 let doc_id_str = match index.doc_id_at(idx) {
443 Ok(doc_id) => doc_id,
444 Err(_) => continue,
445 };
446
447 if let Some(parsed) = parse_semantic_doc_id(doc_id_str)
448 && let Some(hash) = parsed.content_hash
449 {
450 hashes.insert(parsed.message_id, hash);
451 }
452 }
453 debug!(
454 path = %fsvi_path.display(),
455 count = hashes.len(),
456 "Loaded existing hashes for dedup"
457 );
458 hashes
459 }
460 Err(e) => {
461 warn!(
462 path = %fsvi_path.display(),
463 error = %e,
464 "Failed to load existing index for dedup"
465 );
466 HashMap::new()
467 }
468 }
469 }
470}
471
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a job config with empty paths and the given pass settings.
    fn build_pass_config(
        two_tier: bool,
        fast_model: Option<&str>,
        quality_model: Option<&str>,
    ) -> EmbeddingJobConfig {
        EmbeddingJobConfig {
            db_path: String::new(),
            index_path: String::new(),
            two_tier,
            fast_model: fast_model.map(String::from),
            quality_model: quality_model.map(String::from),
        }
    }

    /// Shorthand for the expected `FastEmbed` variant.
    fn fast_embed_kind(model_name: &str, embedder_id: &str) -> WorkerEmbedderKind {
        WorkerEmbedderKind::FastEmbed {
            model_name: model_name.to_owned(),
            embedder_id: embedder_id.to_owned(),
        }
    }

    /// Creates a throwaway worker and returns the passes it plans for `config`.
    fn passes_for(config: &EmbeddingJobConfig) -> Vec<(String, bool)> {
        let (worker, _handle) = EmbeddingWorker::new();
        worker.build_passes(config)
    }

    #[test]
    fn test_worker_handle_clone() {
        // The worker must stay alive: it owns the receiver, and sending on a
        // channel whose receiver was dropped fails.
        let (_worker, handle) = EmbeddingWorker::new();
        let cloned = handle.clone();
        assert!(handle.shutdown().is_ok());
        drop(cloned);
    }

    #[test]
    fn test_job_config() {
        let config = EmbeddingJobConfig {
            db_path: "/tmp/test.db".to_string(),
            index_path: "/tmp/test_index".to_string(),
            two_tier: true,
            fast_model: Some("hash".to_string()),
            quality_model: Some("minilm".to_string()),
        };
        assert!(config.two_tier);
        assert_eq!(config.fast_model.as_deref(), Some("hash"));
        assert_eq!(config.quality_model.as_deref(), Some("minilm"));
    }

    #[test]
    fn test_build_passes_single() {
        let passes = passes_for(&build_pass_config(false, None, Some("minilm")));
        assert_eq!(passes, vec![("minilm".to_string(), true)]);
    }

    #[test]
    fn test_build_passes_two_tier() {
        let passes = passes_for(&build_pass_config(true, Some("hash"), Some("minilm")));
        assert_eq!(
            passes,
            vec![("hash".to_string(), false), ("minilm".to_string(), true)]
        );
    }

    #[test]
    fn test_build_passes_defaults() {
        let passes = passes_for(&build_pass_config(false, None, None));
        assert_eq!(passes, vec![("hash".to_string(), false)]);
    }

    #[test]
    fn test_message_id_from_db_rejects_negative_ids() {
        for (raw, expected) in [(-1, None), (0, Some(0)), (42, Some(42))] {
            assert_eq!(message_id_from_db(raw), expected);
        }
    }

    #[test]
    fn test_saturating_u32_from_i64_clamps_bounds() {
        let cases = [
            (-7, 0),
            (0, 0),
            (7, 7),
            (i64::from(u32::MAX) + 123, u32::MAX),
        ];
        for (raw, expected) in cases {
            assert_eq!(saturating_u32_from_i64(raw), expected);
        }
    }

    #[test]
    fn test_saturating_i64_from_usize_clamps_overflow() {
        let cases = [
            (0, 0),
            (7, 7),
            (usize::MAX, i64::try_from(usize::MAX).unwrap_or(i64::MAX)),
        ];
        for (raw, expected) in cases {
            assert_eq!(saturating_i64_from_usize(raw), expected);
        }
    }

    #[test]
    fn test_resolve_embedder_kind_hash_aliases() {
        for (name, use_semantic) in [("hash", false), ("FNV1A-384", true)] {
            assert_eq!(
                resolve_embedder_kind(name, use_semantic).unwrap(),
                WorkerEmbedderKind::Hash
            );
        }
    }

    #[test]
    fn test_resolve_embedder_kind_use_semantic_false_short_circuits_regardless_of_name() {
        let semantic_names = [
            "minilm",
            "minilm-384",
            "all-minilm-l6-v2",
            "fastembed",
            "snowflake-arctic-s",
            "snowflake-arctic-embed-s",
            "nomic-embed",
            "nomic-embed-text-v1.5",
            "MINILM",
        ];
        for semantic_name in semantic_names {
            assert_eq!(
                resolve_embedder_kind(semantic_name, false).unwrap(),
                WorkerEmbedderKind::Hash,
                "use_semantic=false MUST short-circuit to Hash regardless of model_name; \
                 regression on name {semantic_name:?} indicates the !use_semantic branch \
                 was bypassed"
            );
        }
    }

    #[test]
    fn test_resolve_embedder_kind_semantic_aliases() {
        for alias in ["minilm", "MINILM-384", "fastembed"] {
            assert_eq!(
                resolve_embedder_kind(alias, true).unwrap(),
                fast_embed_kind("minilm", "minilm-384")
            );
        }
    }

    #[test]
    fn test_resolve_embedder_kind_registered_fastembed_models() {
        let cases = [
            ("snowflake-arctic-s", "snowflake-arctic-s", "snowflake-arctic-s-384"),
            ("nomic-embed-text-v1.5", "nomic-embed", "nomic-embed-768"),
        ];
        for (requested, model_name, embedder_id) in cases {
            assert_eq!(
                resolve_embedder_kind(requested, true).unwrap(),
                fast_embed_kind(model_name, embedder_id)
            );
        }
    }

    #[test]
    fn test_resolve_embedder_kind_rejects_unknown_semantic_model() {
        let err = resolve_embedder_kind("e5-large", true).unwrap_err();
        let rendered = format!("{err:#}");
        assert!(rendered.contains("unsupported semantic model"));
    }
}