Skip to main content

coding_agent_search/daemon/
worker.rs

//! Background embedding worker for the daemon.
//!
//! Processes embedding jobs on a dedicated thread using sync primitives.
//! Adapted from xf's async worker to cass's sync daemon architecture.
5
6use std::collections::HashMap;
7use std::path::Path;
8use std::sync::Arc;
9use std::sync::atomic::{AtomicBool, Ordering};
10use std::sync::mpsc::{Receiver, Sender};
11
12use tracing::{debug, error, info, warn};
13
14use crate::indexer::semantic::{
15    EmbeddingInput, SemanticIndexer, message_id_from_db, saturating_u32_from_i64,
16};
17use crate::search::canonicalize::{canonicalize_for_embedding, content_hash};
18use crate::search::fastembed_embedder::FastEmbedder;
19use crate::search::vector_index::{
20    VectorIndex, parse_semantic_doc_id, role_code_from_str, vector_index_path,
21};
22use crate::storage::sqlite::FrankenStorage;
23
/// Model name that resolves to the hash embedder; also the fallback when a
/// job names no fast model (and, for single-pass jobs, no quality model).
const HASH_EMBEDDER_MODEL: &str = "hash";
/// Canonical name of the semantic model used for the quality pass when the
/// job config does not specify one.
const DEFAULT_SEMANTIC_MODEL: &str = "minilm";
26
/// Parameters describing one embedding job handed to the worker.
///
/// Paths are carried as strings and converted to `Path`s when the job runs.
#[derive(Debug, Clone)]
pub struct EmbeddingJobConfig {
    /// Path of the database whose messages are embedded.
    pub db_path: String,
    /// Location passed to the vector-index helpers for loading and saving.
    pub index_path: String,
    /// `true` runs a fast pass followed by a quality pass; `false` runs one pass.
    pub two_tier: bool,
    /// Model name for the fast pass; `None` falls back to the hash embedder.
    pub fast_model: Option<String>,
    /// Model name for the quality pass; `None` falls back to a default.
    pub quality_model: Option<String>,
}
36
37impl EmbeddingJobConfig {
38    fn fast_pass_model(&self) -> String {
39        self.fast_model
40            .clone()
41            .unwrap_or_else(|| HASH_EMBEDDER_MODEL.to_string())
42    }
43
44    fn quality_pass_model(&self) -> String {
45        self.quality_model
46            .clone()
47            .unwrap_or_else(|| DEFAULT_SEMANTIC_MODEL.to_string())
48    }
49
50    fn single_pass_model(&self) -> String {
51        self.quality_model
52            .clone()
53            .or_else(|| self.fast_model.clone())
54            .unwrap_or_else(|| HASH_EMBEDDER_MODEL.to_string())
55    }
56}
57
58/// Messages sent to the background worker.
59#[derive(Debug)]
60pub enum WorkerMessage {
61    /// Submit a new embedding job.
62    Submit(EmbeddingJobConfig),
63    /// Cancel jobs for a db_path, optionally filtered by model_id.
64    Cancel {
65        db_path: String,
66        model_id: Option<String>,
67    },
68    /// Shut down the worker thread.
69    Shutdown,
70}
71
72/// Handle for sending messages to the background worker.
73#[derive(Clone)]
74pub struct EmbeddingWorkerHandle {
75    sender: Sender<WorkerMessage>,
76    /// Shared cancel flag — set directly from the handle so cancellation
77    /// takes effect even while `process_job` is running on the worker thread.
78    cancel_flag: Arc<AtomicBool>,
79}
80
81impl EmbeddingWorkerHandle {
82    /// Submit an embedding job to the worker.
83    pub fn submit(&self, config: EmbeddingJobConfig) -> Result<(), String> {
84        self.sender
85            .send(WorkerMessage::Submit(config))
86            .map_err(|e| format!("worker channel closed: {e}"))
87    }
88
89    /// Cancel embedding jobs for a db_path.
90    ///
91    /// Sets the cancel flag directly (so the running job sees it immediately)
92    /// AND sends a Cancel message for database-level cleanup.
93    pub fn cancel(&self, db_path: String, model_id: Option<String>) -> Result<(), String> {
94        self.cancel_flag.store(true, Ordering::SeqCst);
95        self.sender
96            .send(WorkerMessage::Cancel { db_path, model_id })
97            .map_err(|e| format!("worker channel closed: {e}"))
98    }
99
100    /// Request the worker to shut down.
101    pub fn shutdown(&self) -> Result<(), String> {
102        self.sender
103            .send(WorkerMessage::Shutdown)
104            .map_err(|e| format!("worker channel closed: {e}"))
105    }
106}
107
108/// Background embedding worker that processes jobs on a dedicated thread.
109pub struct EmbeddingWorker {
110    receiver: Receiver<WorkerMessage>,
111    cancel_flag: Arc<AtomicBool>,
112}
113
/// Embedder selected for one pass, as resolved from a model name.
#[derive(Debug, Clone, PartialEq, Eq)]
enum WorkerEmbedderKind {
    /// The hash embedder.
    Hash,
    /// A FastEmbed-backed model.
    FastEmbed {
        /// Canonical model name after alias normalization.
        model_name: String,
        /// Embedder id taken from the model's FastEmbedder config.
        embedder_id: String,
    },
}
122
123fn resolve_embedder_kind(
124    model_name: &str,
125    use_semantic: bool,
126) -> anyhow::Result<WorkerEmbedderKind> {
127    if !use_semantic
128        || model_name.eq_ignore_ascii_case(HASH_EMBEDDER_MODEL)
129        || model_name.eq_ignore_ascii_case("fnv1a-384")
130    {
131        return Ok(WorkerEmbedderKind::Hash);
132    }
133
134    let normalized_name = match model_name.to_ascii_lowercase().as_str() {
135        "fastembed" | "minilm" | "minilm-384" | "all-minilm-l6-v2" => DEFAULT_SEMANTIC_MODEL,
136        "snowflake-arctic-s" | "snowflake-arctic-s-384" | "snowflake-arctic-embed-s" => {
137            "snowflake-arctic-s"
138        }
139        "nomic-embed" | "nomic-embed-768" | "nomic-embed-text-v1.5" => "nomic-embed",
140        _ => {
141            anyhow::bail!(
142                "unsupported semantic model '{model_name}' for daemon embedding worker; supported: minilm, snowflake-arctic-s, nomic-embed"
143            );
144        }
145    };
146
147    let config = FastEmbedder::config_for(normalized_name).ok_or_else(|| {
148        anyhow::anyhow!("missing FastEmbedder config for registered model '{normalized_name}'")
149    })?;
150    Ok(WorkerEmbedderKind::FastEmbed {
151        model_name: normalized_name.to_string(),
152        embedder_id: config.embedder_id,
153    })
154}
155
/// Convert a `usize` to `i64`, clamping to `i64::MAX` when it does not fit.
fn saturating_i64_from_usize(raw: usize) -> i64 {
    match i64::try_from(raw) {
        Ok(value) => value,
        Err(_) => i64::MAX,
    }
}
159
160impl EmbeddingWorker {
161    /// Create a new worker and its handle.
162    pub fn new() -> (Self, EmbeddingWorkerHandle) {
163        let (sender, receiver) = std::sync::mpsc::channel();
164        let cancel_flag = Arc::new(AtomicBool::new(false));
165        let handle = EmbeddingWorkerHandle {
166            sender,
167            cancel_flag: Arc::clone(&cancel_flag),
168        };
169        let worker = Self {
170            receiver,
171            cancel_flag,
172        };
173        (worker, handle)
174    }
175
176    /// Run the worker loop (blocking). Call from a spawned thread.
177    pub fn run(self) {
178        info!("Embedding worker started");
179        while let Ok(msg) = self.receiver.recv() {
180            match msg {
181                WorkerMessage::Submit(config) => {
182                    self.cancel_flag.store(false, Ordering::SeqCst);
183                    info!(db_path = %config.db_path, two_tier = config.two_tier, "Processing embedding job");
184                    if let Err(e) = self.process_job(&config) {
185                        error!(db_path = %config.db_path, error = %e, "Embedding job failed");
186                    }
187                }
188                WorkerMessage::Cancel { db_path, model_id } => {
189                    // The cancel_flag is already set by the handle (so the running
190                    // job sees it immediately). This handler performs DB cleanup.
191                    info!(%db_path, ?model_id, "Processing cancel — flag already set by handle");
192                    // Cancel in the database
193                    if let Err(e) = Self::cancel_in_db(&db_path, model_id.as_deref()) {
194                        warn!(%db_path, error = %e, "Failed to cancel jobs in database");
195                    }
196                }
197                WorkerMessage::Shutdown => {
198                    info!("Embedding worker shutting down");
199                    break;
200                }
201            }
202        }
203        info!("Embedding worker stopped");
204    }
205
206    /// Cancel jobs in the database.
207    fn cancel_in_db(db_path: &str, model_id: Option<&str>) -> anyhow::Result<()> {
208        let storage = FrankenStorage::open(Path::new(db_path))?;
209        storage.cancel_embedding_jobs(db_path, model_id)?;
210        Ok(())
211    }
212
213    /// Process a single embedding job.
214    fn process_job(&self, config: &EmbeddingJobConfig) -> anyhow::Result<()> {
215        let db_path = Path::new(&config.db_path);
216        let index_path = Path::new(&config.index_path);
217
218        // Open storage and fetch messages
219        let storage = FrankenStorage::open(db_path)?;
220        let messages = storage.fetch_messages_for_embedding()?;
221        let total_docs = saturating_i64_from_usize(messages.len());
222
223        if total_docs == 0 {
224            info!(db_path = %config.db_path, "No messages to embed");
225            return Ok(());
226        }
227
228        info!(
229            db_path = %config.db_path,
230            total_docs,
231            two_tier = config.two_tier,
232            "Found messages to embed"
233        );
234
235        // Determine which passes to run
236        let passes = self.build_passes(config);
237
238        for (model_name, use_semantic) in &passes {
239            if self.cancel_flag.load(Ordering::SeqCst) {
240                info!("Embedding job cancelled");
241                return Ok(());
242            }
243
244            let job_id = storage.upsert_embedding_job(&config.db_path, model_name, total_docs)?;
245            storage.start_embedding_job(job_id)?;
246
247            match self.generate_embeddings_and_save(
248                &storage,
249                &messages,
250                model_name,
251                *use_semantic,
252                job_id,
253                index_path,
254            ) {
255                Ok(()) => {
256                    storage.complete_embedding_job(job_id)?;
257                    info!(model = model_name, "Embedding pass completed");
258                }
259                Err(e) => {
260                    let err_msg = format!("{e:#}");
261                    storage.fail_embedding_job(job_id, &err_msg)?;
262                    warn!(model = model_name, error = %e, "Embedding pass failed");
263                }
264            }
265        }
266
267        Ok(())
268    }
269
270    /// Determine the embedding passes to run based on config.
271    fn build_passes(&self, config: &EmbeddingJobConfig) -> Vec<(String, bool)> {
272        let mut passes = Vec::new();
273
274        if config.two_tier {
275            // Fast hash pass
276            let fast = config.fast_pass_model();
277            passes.push((fast, false));
278
279            // Quality semantic pass
280            let quality = config.quality_pass_model();
281            passes.push((quality, true));
282        } else {
283            // Single pass with best available
284            let model = config.single_pass_model();
285            let is_semantic = model != HASH_EMBEDDER_MODEL;
286            passes.push((model, is_semantic));
287        }
288
289        passes
290    }
291
292    /// Generate embeddings for messages and save the vector index.
293    fn generate_embeddings_and_save(
294        &self,
295        storage: &FrankenStorage,
296        messages: &[crate::storage::sqlite::MessageForEmbedding],
297        model_name: &str,
298        use_semantic: bool,
299        job_id: i64,
300        index_path: &Path,
301    ) -> anyhow::Result<()> {
302        let embedder_kind = resolve_embedder_kind(model_name, use_semantic)?;
303
304        // Load existing index to check for unchanged documents
305        let existing_hashes = self.load_existing_hashes(index_path, &embedder_kind);
306
307        // Prepare inputs, skipping unchanged documents
308        let mut inputs: Vec<EmbeddingInput> = Vec::new();
309        let mut skipped_count = 0usize;
310        let mut completed = 0i64;
311
312        for msg in messages {
313            if self.cancel_flag.load(Ordering::SeqCst) {
314                return Err(anyhow::anyhow!("job cancelled"));
315            }
316
317            let canonical = canonicalize_for_embedding(&msg.content);
318            if canonical.is_empty() {
319                completed += 1;
320                continue;
321            }
322
323            let hash = content_hash(&canonical);
324            let role = role_code_from_str(&msg.role).unwrap_or(0);
325
326            // Invalid/negative IDs indicate corrupted data; skip rather than collapsing to 0.
327            let Some(message_id) = message_id_from_db(msg.message_id) else {
328                warn!(
329                    raw_message_id = msg.message_id,
330                    "Skipping message with out-of-range id during embedding"
331                );
332                completed += 1;
333                continue;
334            };
335
336            // Check if this document is unchanged - skip re-embedding if hash matches
337            if let Some(existing_hash) = existing_hashes.get(&message_id)
338                && *existing_hash == hash
339            {
340                skipped_count += 1;
341                completed += 1;
342                continue;
343            }
344
345            // Clamp to a stable range instead of silently wrapping/failing.
346            let agent_id = saturating_u32_from_i64(msg.agent_id);
347            let workspace_id = saturating_u32_from_i64(msg.workspace_id.unwrap_or(0));
348
349            inputs.push(EmbeddingInput {
350                message_id,
351                created_at_ms: msg.created_at.unwrap_or(0),
352                agent_id,
353                workspace_id,
354                source_id: msg.source_id_hash,
355                role,
356                chunk_idx: 0,
357                content: canonical,
358            });
359
360            completed += 1;
361            if completed % 100 == 0 {
362                let _ = storage.update_job_progress(job_id, completed);
363                debug!(job_id, completed, "Embedding progress");
364            }
365        }
366
367        if inputs.is_empty() {
368            let final_completed = saturating_i64_from_usize(messages.len());
369            let _ = storage.update_job_progress(job_id, final_completed);
370            info!(
371                model = model_name,
372                skipped = skipped_count,
373                "No documents to embed - all unchanged"
374            );
375            return Ok(());
376        }
377
378        info!(
379            model = model_name,
380            input_count = inputs.len(),
381            skipped = skipped_count,
382            "Embedding documents"
383        );
384
385        // Create the appropriate embedder/indexer
386        let indexer = match embedder_kind {
387            WorkerEmbedderKind::Hash => SemanticIndexer::new(HASH_EMBEDDER_MODEL, None)?,
388            WorkerEmbedderKind::FastEmbed { ref model_name, .. } => {
389                SemanticIndexer::new(model_name, Some(index_path))?
390            }
391        };
392
393        // Embed messages
394        let embedded = indexer.embed_messages(&inputs)?;
395
396        // Update final progress
397        let final_completed = saturating_i64_from_usize(messages.len());
398        let _ = storage.update_job_progress(job_id, final_completed);
399
400        // Append to existing vector index, or create a new one if none exists.
401        // Using append_to_index preserves previously-indexed unchanged documents
402        // that were skipped by the dedup check above.
403        let save_path = vector_index_path(index_path, indexer.embedder_id());
404        if save_path.exists() {
405            let appended = indexer.append_to_index(embedded, index_path)?;
406            info!(appended, "Appended to existing vector index");
407        } else {
408            let _index = indexer.build_and_save_index(embedded, index_path)?;
409        }
410
411        info!(
412            model = model_name,
413            path = %save_path.display(),
414            count = inputs.len(),
415            "Saved vector index"
416        );
417
418        Ok(())
419    }
420
421    /// Load content hashes from an existing vector index for dedup.
422    fn load_existing_hashes(
423        &self,
424        index_path: &Path,
425        embedder_kind: &WorkerEmbedderKind,
426    ) -> HashMap<u64, [u8; 32]> {
427        let embedder_id = match embedder_kind {
428            WorkerEmbedderKind::Hash => "fnv1a-384",
429            WorkerEmbedderKind::FastEmbed { embedder_id, .. } => embedder_id.as_str(),
430        };
431
432        let fsvi_path = vector_index_path(index_path, embedder_id);
433
434        if !fsvi_path.exists() {
435            return HashMap::new();
436        }
437
438        match VectorIndex::open(&fsvi_path) {
439            Ok(index) => {
440                let mut hashes = HashMap::new();
441                for idx in 0..index.record_count() {
442                    let doc_id_str = match index.doc_id_at(idx) {
443                        Ok(doc_id) => doc_id,
444                        Err(_) => continue,
445                    };
446
447                    if let Some(parsed) = parse_semantic_doc_id(doc_id_str)
448                        && let Some(hash) = parsed.content_hash
449                    {
450                        hashes.insert(parsed.message_id, hash);
451                    }
452                }
453                debug!(
454                    path = %fsvi_path.display(),
455                    count = hashes.len(),
456                    "Loaded existing hashes for dedup"
457                );
458                hashes
459            }
460            Err(e) => {
461                warn!(
462                    path = %fsvi_path.display(),
463                    error = %e,
464                    "Failed to load existing index for dedup"
465                );
466                HashMap::new()
467            }
468        }
469    }
470}
471
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a job config with empty paths; only the pass-selection fields matter.
    fn build_pass_config(
        two_tier: bool,
        fast_model: Option<&str>,
        quality_model: Option<&str>,
    ) -> EmbeddingJobConfig {
        EmbeddingJobConfig {
            db_path: String::new(),
            index_path: String::new(),
            two_tier,
            fast_model: fast_model.map(str::to_string),
            quality_model: quality_model.map(str::to_string),
        }
    }

    /// Shorthand for the expected `FastEmbed` variant in resolver tests.
    fn fast_embed_kind(model_name: &str, embedder_id: &str) -> WorkerEmbedderKind {
        WorkerEmbedderKind::FastEmbed {
            model_name: model_name.to_string(),
            embedder_id: embedder_id.to_string(),
        }
    }

    #[test]
    fn test_worker_handle_clone() {
        // The worker is kept alive but never run, so the channel stays open
        // and every handle (original or clone) can enqueue messages.
        let (_worker, handle) = EmbeddingWorker::new();
        let handle2 = handle.clone();
        assert!(handle.shutdown().is_ok());
        assert!(handle2.shutdown().is_ok());
    }

    #[test]
    fn test_handle_send_fails_after_worker_dropped() {
        let (worker, handle) = EmbeddingWorker::new();
        // Dropping the worker drops the receiver, closing the channel.
        drop(worker);
        let err = handle.shutdown().unwrap_err();
        assert!(err.starts_with("worker channel closed"));
    }

    #[test]
    fn test_job_config() {
        let config = EmbeddingJobConfig {
            db_path: "/tmp/test.db".to_string(),
            index_path: "/tmp/test_index".to_string(),
            two_tier: true,
            fast_model: Some("hash".to_string()),
            quality_model: Some("minilm".to_string()),
        };
        assert!(config.two_tier);
        assert_eq!(config.fast_model.as_deref(), Some("hash"));
        assert_eq!(config.quality_model.as_deref(), Some("minilm"));
    }

    #[test]
    fn test_build_passes_single() {
        let (worker, _handle) = EmbeddingWorker::new();
        let config = build_pass_config(false, None, Some("minilm"));
        let passes = worker.build_passes(&config);
        assert_eq!(passes.len(), 1);
        assert_eq!(passes[0].0, "minilm");
        assert!(passes[0].1); // semantic
    }

    #[test]
    fn test_build_passes_two_tier() {
        let (worker, _handle) = EmbeddingWorker::new();
        let config = build_pass_config(true, Some("hash"), Some("minilm"));
        let passes = worker.build_passes(&config);
        assert_eq!(passes.len(), 2);
        assert_eq!(passes[0].0, "hash");
        assert!(!passes[0].1); // not semantic
        assert_eq!(passes[1].0, "minilm");
        assert!(passes[1].1); // semantic
    }

    #[test]
    fn test_build_passes_defaults() {
        let (worker, _handle) = EmbeddingWorker::new();
        let config = build_pass_config(false, None, None);
        let passes = worker.build_passes(&config);
        assert_eq!(passes.len(), 1);
        assert_eq!(passes[0].0, "hash");
        assert!(!passes[0].1); // hash is not semantic
    }

    #[test]
    fn test_message_id_from_db_rejects_negative_ids() {
        assert_eq!(message_id_from_db(-1), None);
        assert_eq!(message_id_from_db(0), Some(0));
        assert_eq!(message_id_from_db(42), Some(42));
    }

    #[test]
    fn test_saturating_u32_from_i64_clamps_bounds() {
        assert_eq!(saturating_u32_from_i64(-7), 0);
        assert_eq!(saturating_u32_from_i64(0), 0);
        assert_eq!(saturating_u32_from_i64(7), 7);
        assert_eq!(saturating_u32_from_i64(i64::from(u32::MAX) + 123), u32::MAX);
    }

    #[test]
    fn test_saturating_i64_from_usize_clamps_overflow() {
        assert_eq!(saturating_i64_from_usize(0), 0);
        assert_eq!(saturating_i64_from_usize(7), 7);
        assert_eq!(
            saturating_i64_from_usize(usize::MAX),
            i64::try_from(usize::MAX).unwrap_or(i64::MAX)
        );
    }

    #[test]
    fn test_resolve_embedder_kind_hash_aliases() {
        assert_eq!(
            resolve_embedder_kind("hash", false).unwrap(),
            WorkerEmbedderKind::Hash
        );
        assert_eq!(
            resolve_embedder_kind("FNV1A-384", true).unwrap(),
            WorkerEmbedderKind::Hash
        );
    }

    /// `coding_agent_session_search-am69y`: pin the override-by-flag
    /// short-circuit at the top of `resolve_embedder_kind`. The
    /// `test_resolve_embedder_kind_hash_aliases` companion above
    /// exercises ("hash", false), but "hash" matches BOTH the
    /// `!use_semantic` branch AND the `eq_ignore_ascii_case("hash")`
    /// branch — so a regression that broke only the `!use_semantic`
    /// short-circuit would still be rescued by the name match and
    /// silently pass. This test pins the flag-only contract by
    /// passing semantic model names with `use_semantic=false`: every
    /// registered FastEmbedder name MUST resolve to `Hash` purely
    /// because the flag is false, regardless of name.
    #[test]
    fn test_resolve_embedder_kind_use_semantic_false_short_circuits_regardless_of_name() {
        for semantic_name in [
            "minilm",
            "minilm-384",
            "all-minilm-l6-v2",
            "fastembed",
            "snowflake-arctic-s",
            "snowflake-arctic-embed-s",
            "nomic-embed",
            "nomic-embed-text-v1.5",
            "MINILM",
        ] {
            assert_eq!(
                resolve_embedder_kind(semantic_name, false).unwrap(),
                WorkerEmbedderKind::Hash,
                "use_semantic=false MUST short-circuit to Hash regardless of model_name; \
                 regression on name {semantic_name:?} indicates the !use_semantic branch \
                 was bypassed"
            );
        }
    }

    #[test]
    fn test_resolve_embedder_kind_semantic_aliases() {
        assert_eq!(
            resolve_embedder_kind("minilm", true).unwrap(),
            fast_embed_kind("minilm", "minilm-384")
        );
        assert_eq!(
            resolve_embedder_kind("MINILM-384", true).unwrap(),
            fast_embed_kind("minilm", "minilm-384")
        );
        assert_eq!(
            resolve_embedder_kind("fastembed", true).unwrap(),
            fast_embed_kind("minilm", "minilm-384")
        );
    }

    #[test]
    fn test_resolve_embedder_kind_registered_fastembed_models() {
        assert_eq!(
            resolve_embedder_kind("snowflake-arctic-s", true).unwrap(),
            fast_embed_kind("snowflake-arctic-s", "snowflake-arctic-s-384")
        );
        assert_eq!(
            resolve_embedder_kind("nomic-embed-text-v1.5", true).unwrap(),
            fast_embed_kind("nomic-embed", "nomic-embed-768")
        );
    }

    #[test]
    fn test_resolve_embedder_kind_rejects_unknown_semantic_model() {
        let err = resolve_embedder_kind("e5-large", true).unwrap_err();
        let msg = format!("{err:#}");
        assert!(msg.contains("unsupported semantic model"));
    }
}