Skip to main content

memvid_core/
enrichment_worker.rs

//! Background enrichment worker for progressive ingestion.
//!
//! Processes frames in the enrichment queue on a background thread:
//! - Re-extracts full text for skim extractions
//! - Generates embeddings with batching and checkpointing
//! - Updates the Tantivy index with enriched content
//! - Marks frames as Enriched when complete
8
9use std::sync::Arc;
10use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
11use std::time::{Duration, Instant};
12
13use crate::error::Result;
14use crate::types::{EnrichmentTask, FrameId, VecEmbedder};
15
/// Configuration for the enrichment worker.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EnrichmentWorkerConfig {
    /// Batch size for embedding generation.
    ///
    /// NOTE(review): not read by the worker loop or processor in this file;
    /// presumably forwarded to an [`EmbeddingBatcher`] by the caller.
    pub embedding_batch_size: usize,
    /// Checkpoint interval: [`run_worker_loop`] invokes its `checkpoint`
    /// closure after every N *completed tasks* (not N embeddings). A value of
    /// 0 checkpoints after every task.
    pub checkpoint_interval: usize,
    /// Delay in milliseconds between processing tasks (to avoid blocking
    /// writers). While the queue is empty, [`run_worker_loop`] polls at
    /// roughly ten times this delay instead.
    pub task_delay_ms: u64,
    /// Intended maximum time in milliseconds to spend on a single task
    /// before yielding.
    ///
    /// NOTE(review): not currently enforced by any code in this file.
    pub max_task_time_ms: u64,
}

impl Default for EnrichmentWorkerConfig {
    fn default() -> Self {
        Self {
            embedding_batch_size: 32,
            checkpoint_interval: 100,
            task_delay_ms: 50,
            max_task_time_ms: 5000,
        }
    }
}
39
/// Point-in-time snapshot of enrichment worker statistics.
///
/// Produced by [`EnrichmentWorkerHandle::stats`].
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct EnrichmentWorkerStats {
    /// Total frames processed.
    pub frames_processed: u64,
    /// Total embeddings generated.
    pub embeddings_generated: u64,
    /// Total re-extractions performed.
    pub re_extractions: u64,
    /// Total errors encountered.
    pub errors: u64,
    /// Current queue depth. The handle has no view of the queue and always
    /// reports 0 here; callers are expected to fill this in.
    pub queue_depth: usize,
    /// Whether the worker is currently running.
    pub is_running: bool,
}

/// Handle for controlling and observing the background enrichment worker.
///
/// All state lives behind `Arc`s, so clones (via [`Clone`] or
/// [`EnrichmentWorkerHandle::clone_handle`]) share the same stop flag and
/// counters.
#[derive(Debug, Clone)]
pub struct EnrichmentWorkerHandle {
    /// Signal to stop the worker.
    stop_signal: Arc<AtomicBool>,
    /// Counter for frames processed.
    frames_processed: Arc<AtomicU64>,
    /// Counter for embeddings generated.
    embeddings_generated: Arc<AtomicU64>,
    /// Counter for re-extractions.
    re_extractions: Arc<AtomicU64>,
    /// Counter for errors.
    errors: Arc<AtomicU64>,
    /// Running state.
    is_running: Arc<AtomicBool>,
}

impl EnrichmentWorkerHandle {
    /// Create a new worker handle: not running, no stop requested, all
    /// counters zero.
    #[must_use]
    pub fn new() -> Self {
        Self {
            stop_signal: Arc::new(AtomicBool::new(false)),
            frames_processed: Arc::new(AtomicU64::new(0)),
            embeddings_generated: Arc::new(AtomicU64::new(0)),
            re_extractions: Arc::new(AtomicU64::new(0)),
            errors: Arc::new(AtomicU64::new(0)),
            is_running: Arc::new(AtomicBool::new(false)),
        }
    }

    /// Signal the worker to stop.
    ///
    /// Nothing ever clears this flag, so a handle that has been stopped
    /// cannot be reused to drive another worker run.
    pub fn stop(&self) {
        self.stop_signal.store(true, Ordering::SeqCst);
    }

    /// Check if stop was requested.
    #[must_use]
    pub fn should_stop(&self) -> bool {
        self.stop_signal.load(Ordering::SeqCst)
    }

    /// Check if the worker is currently running.
    #[must_use]
    pub fn is_running(&self) -> bool {
        self.is_running.load(Ordering::SeqCst)
    }

    /// Get current statistics.
    ///
    /// Each counter is read independently, so the snapshot is not atomic
    /// across fields. `queue_depth` is always 0 and must be filled in by the
    /// caller.
    #[must_use]
    pub fn stats(&self) -> EnrichmentWorkerStats {
        EnrichmentWorkerStats {
            frames_processed: self.frames_processed.load(Ordering::Relaxed),
            embeddings_generated: self.embeddings_generated.load(Ordering::Relaxed),
            re_extractions: self.re_extractions.load(Ordering::Relaxed),
            errors: self.errors.load(Ordering::Relaxed),
            queue_depth: 0, // Will be updated by caller
            // Same ordering as `is_running()` so both views agree.
            is_running: self.is_running(),
        }
    }

    /// Increment frames processed counter.
    pub(crate) fn inc_frames_processed(&self) {
        self.frames_processed.fetch_add(1, Ordering::Relaxed);
    }

    /// Add `count` to the embeddings generated counter.
    pub(crate) fn inc_embeddings(&self, count: u64) {
        self.embeddings_generated.fetch_add(count, Ordering::Relaxed);
    }

    /// Increment re-extractions counter.
    pub(crate) fn inc_re_extractions(&self) {
        self.re_extractions.fetch_add(1, Ordering::Relaxed);
    }

    /// Increment errors counter.
    pub(crate) fn inc_errors(&self) {
        self.errors.fetch_add(1, Ordering::Relaxed);
    }

    /// Set running state.
    pub(crate) fn set_running(&self, running: bool) {
        self.is_running.store(running, Ordering::SeqCst);
    }

    /// Clone the handle for sharing with the worker thread.
    ///
    /// Equivalent to [`Clone::clone`]: the returned handle shares all state
    /// with `self`.
    #[must_use]
    pub fn clone_handle(&self) -> Self {
        self.clone()
    }
}

impl Default for EnrichmentWorkerHandle {
    fn default() -> Self {
        Self::new()
    }
}
162
163/// Result of processing a single enrichment task.
164#[derive(Debug)]
165pub struct TaskResult {
166    /// Frame ID that was processed.
167    pub frame_id: FrameId,
168    /// Whether full re-extraction was performed.
169    pub re_extracted: bool,
170    /// Number of embeddings generated.
171    pub embeddings_generated: usize,
172    /// Time spent processing.
173    pub elapsed_ms: u64,
174    /// Error if processing failed.
175    pub error: Option<String>,
176}
177
178/// Batched embedding generator for efficient embedding creation.
179///
180/// Collects text chunks and generates embeddings in batches to minimize
181/// API calls and improve throughput.
182pub struct EmbeddingBatcher<E: VecEmbedder> {
183    /// The embedder to use for generating embeddings.
184    embedder: E,
185    /// Batch size for embedding generation.
186    batch_size: usize,
187    /// Pending texts to embed.
188    pending_texts: Vec<(FrameId, String)>,
189    /// Generated embeddings ready to store.
190    ready_embeddings: Vec<(FrameId, Vec<f32>)>,
191}
192
193impl<E: VecEmbedder> EmbeddingBatcher<E> {
194    /// Create a new embedding batcher.
195    pub fn new(embedder: E, batch_size: usize) -> Self {
196        Self {
197            embedder,
198            batch_size: batch_size.max(1),
199            pending_texts: Vec::new(),
200            ready_embeddings: Vec::new(),
201        }
202    }
203
204    /// Add a frame's text for embedding.
205    pub fn add(&mut self, frame_id: FrameId, text: String) {
206        self.pending_texts.push((frame_id, text));
207    }
208
209    /// Get the number of pending texts.
210    pub fn pending_count(&self) -> usize {
211        self.pending_texts.len()
212    }
213
214    /// Get the number of ready embeddings.
215    pub fn ready_count(&self) -> usize {
216        self.ready_embeddings.len()
217    }
218
219    /// Check if a batch is ready to process.
220    pub fn should_flush(&self) -> bool {
221        self.pending_texts.len() >= self.batch_size
222    }
223
224    /// Process pending texts and generate embeddings.
225    ///
226    /// Returns the number of embeddings generated.
227    pub fn flush(&mut self) -> Result<usize> {
228        if self.pending_texts.is_empty() {
229            return Ok(0);
230        }
231
232        // Take all pending texts
233        let pending: Vec<_> = std::mem::take(&mut self.pending_texts);
234        let count = pending.len();
235
236        // Extract texts for batch embedding
237        let texts: Vec<&str> = pending.iter().map(|(_, text)| text.as_str()).collect();
238
239        // Generate embeddings in batch
240        let embeddings = self.embedder.embed_chunks(&texts)?;
241
242        // Store results
243        for ((frame_id, _), embedding) in pending.into_iter().zip(embeddings.into_iter()) {
244            self.ready_embeddings.push((frame_id, embedding));
245        }
246
247        Ok(count)
248    }
249
250    /// Take all ready embeddings.
251    pub fn take_embeddings(&mut self) -> Vec<(FrameId, Vec<f32>)> {
252        std::mem::take(&mut self.ready_embeddings)
253    }
254
255    /// Get embedding dimension from the embedder.
256    pub fn dimension(&self) -> usize {
257        self.embedder.embedding_dimension()
258    }
259}
260
261/// Enrichment task processor (stateless, operates on Memvid instance).
262pub struct EnrichmentProcessor {
263    /// Worker configuration.
264    pub config: EnrichmentWorkerConfig,
265}
266
267impl EnrichmentProcessor {
268    /// Create a new enrichment processor.
269    #[must_use]
270    pub fn new(config: EnrichmentWorkerConfig) -> Self {
271        Self { config }
272    }
273
274    /// Process a single enrichment task.
275    ///
276    /// This method:
277    /// 1. Reads the frame from the memory
278    /// 2. If frame needs re-extraction (skim), performs full extraction
279    /// 3. If frame needs embeddings, generates them with batching
280    /// 4. Updates the Tantivy index
281    /// 5. Returns the result
282    ///
283    /// The caller is responsible for:
284    /// - Acquiring write lock on the memory
285    /// - Updating the enrichment queue
286    /// - Persisting changes
287    pub fn process_task<F, E, R>(
288        &self,
289        task: &EnrichmentTask,
290        read_frame: F,
291        extract_full: E,
292        update_index: R,
293    ) -> TaskResult
294    where
295        F: FnOnce(FrameId) -> Option<(String, bool, bool)>, // (text, is_skim, needs_embedding)
296        E: FnOnce(FrameId) -> Result<String>,               // Full extraction
297        R: FnOnce(FrameId, &str) -> Result<()>,             // Update index
298    {
299        let start = Instant::now();
300        let mut result = TaskResult {
301            frame_id: task.frame_id,
302            re_extracted: false,
303            embeddings_generated: 0,
304            elapsed_ms: 0,
305            error: None,
306        };
307
308        // Read current frame state
309        let (text, is_skim, _needs_embedding) = if let Some(data) = read_frame(task.frame_id) {
310            data
311        } else {
312            result.error = Some("Frame not found".to_string());
313            result.elapsed_ms = start.elapsed().as_millis().try_into().unwrap_or(u64::MAX);
314            return result;
315        };
316
317        // Re-extract if this was a skim
318        let final_text = if is_skim {
319            match extract_full(task.frame_id) {
320                Ok(full_text) => {
321                    result.re_extracted = true;
322                    full_text
323                }
324                Err(err) => {
325                    tracing::warn!(
326                        frame_id = task.frame_id,
327                        ?err,
328                        "re-extraction failed, using skim text"
329                    );
330                    text
331                }
332            }
333        } else {
334            text
335        };
336
337        // Update index with enriched content
338        if let Err(err) = update_index(task.frame_id, &final_text) {
339            result.error = Some(format!("Index update failed: {err}"));
340        }
341
342        result.elapsed_ms = start.elapsed().as_millis().try_into().unwrap_or(u64::MAX);
343        result
344    }
345}
346
347/// Run the enrichment worker loop.
348///
349/// This function should be called from a background thread.
350/// It processes tasks from the enrichment queue until stopped.
351///
352/// # Arguments
353/// * `handle` - Worker handle for control and statistics
354/// * `config` - Worker configuration
355/// * `get_next_task` - Closure to get the next task from the queue
356/// * `process_task` - Closure to process a single task
357/// * `mark_complete` - Closure to mark a task as complete
358/// * `checkpoint` - Closure to save progress
359pub fn run_worker_loop<G, P, M, C>(
360    handle: &EnrichmentWorkerHandle,
361    config: &EnrichmentWorkerConfig,
362    mut get_next_task: G,
363    mut process_task: P,
364    mut mark_complete: M,
365    mut checkpoint: C,
366) where
367    G: FnMut() -> Option<EnrichmentTask>,
368    P: FnMut(&EnrichmentTask) -> TaskResult,
369    M: FnMut(FrameId),
370    C: FnMut(),
371{
372    handle.set_running(true);
373    tracing::info!("enrichment worker started");
374
375    let mut tasks_since_checkpoint = 0;
376
377    while !handle.should_stop() {
378        // Get next task
379        let task = if let Some(task) = get_next_task() {
380            task
381        } else {
382            // Queue is empty, wait and check again
383            std::thread::sleep(Duration::from_millis(config.task_delay_ms * 10));
384            continue;
385        };
386
387        // Process the task
388        let result = process_task(&task);
389
390        // Update statistics
391        handle.inc_frames_processed();
392        if result.re_extracted {
393            handle.inc_re_extractions();
394        }
395        if result.embeddings_generated > 0 {
396            handle.inc_embeddings(result.embeddings_generated as u64);
397        }
398        if result.error.is_some() {
399            handle.inc_errors();
400            tracing::warn!(
401                frame_id = task.frame_id,
402                error = ?result.error,
403                "enrichment task failed"
404            );
405        } else {
406            tracing::debug!(
407                frame_id = task.frame_id,
408                re_extracted = result.re_extracted,
409                embeddings = result.embeddings_generated,
410                elapsed_ms = result.elapsed_ms,
411                "enrichment task complete"
412            );
413        }
414
415        // Mark task complete (remove from queue)
416        mark_complete(task.frame_id);
417        tasks_since_checkpoint += 1;
418
419        // Checkpoint periodically
420        if tasks_since_checkpoint >= config.checkpoint_interval {
421            checkpoint();
422            tasks_since_checkpoint = 0;
423        }
424
425        // Yield to other threads
426        std::thread::sleep(Duration::from_millis(config.task_delay_ms));
427    }
428
429    // Final checkpoint
430    if tasks_since_checkpoint > 0 {
431        checkpoint();
432    }
433
434    handle.set_running(false);
435    tracing::info!(
436        frames_processed = handle.frames_processed.load(Ordering::Relaxed),
437        "enrichment worker stopped"
438    );
439}
440
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic test embedder: each vector is derived from the input's length.
    struct MockEmbedder {
        dimension: usize,
    }

    impl MockEmbedder {
        fn new(dimension: usize) -> Self {
            Self { dimension }
        }
    }

    impl crate::types::VecEmbedder for MockEmbedder {
        fn embed_query(&self, text: &str) -> Result<Vec<f32>> {
            let base = text.len() as f32;
            let vector: Vec<f32> = (0..self.dimension)
                .map(|offset| (base + offset as f32) * 0.1)
                .collect();
            Ok(vector)
        }

        fn embedding_dimension(&self) -> usize {
            self.dimension
        }
    }

    /// Build a task for `frame_id` with no recorded chunk progress.
    fn task_for(frame_id: FrameId) -> EnrichmentTask {
        EnrichmentTask {
            frame_id,
            created_at: 0,
            chunks_done: 0,
            chunks_total: 0,
        }
    }

    #[test]
    fn test_embedding_batcher_basic() {
        let mut batcher = EmbeddingBatcher::new(MockEmbedder::new(4), 2);

        // A fresh batcher has nothing queued and nothing ready.
        assert_eq!(batcher.pending_count(), 0);
        assert_eq!(batcher.ready_count(), 0);
        assert!(!batcher.should_flush());

        // One item stays below the batch size of 2.
        batcher.add(1, "hello".to_string());
        assert_eq!(batcher.pending_count(), 1);
        assert!(!batcher.should_flush());

        // The second item fills the batch.
        batcher.add(2, "world".to_string());
        assert_eq!(batcher.pending_count(), 2);
        assert!(batcher.should_flush());

        let flushed = batcher.flush().expect("flush should succeed");
        assert_eq!(flushed, 2);
        assert_eq!(batcher.pending_count(), 0);
        assert_eq!(batcher.ready_count(), 2);

        let taken = batcher.take_embeddings();
        let ids: Vec<FrameId> = taken.iter().map(|(id, _)| *id).collect();
        assert_eq!(ids, vec![1, 2]);
        assert!(taken.iter().all(|(_, vector)| vector.len() == 4));

        // Taking drains the ready buffer.
        assert_eq!(batcher.ready_count(), 0);
    }

    #[test]
    fn test_embedding_batcher_dimension() {
        let batcher = EmbeddingBatcher::new(MockEmbedder::new(128), 32);
        assert_eq!(batcher.dimension(), 128);
    }

    #[test]
    fn test_embedding_batcher_flush_empty() {
        let mut batcher = EmbeddingBatcher::new(MockEmbedder::new(4), 2);
        let flushed = batcher.flush().expect("flush should succeed");
        assert_eq!(flushed, 0);
    }

    #[test]
    fn test_worker_handle() {
        let handle = EnrichmentWorkerHandle::new();
        assert!(!handle.is_running());
        assert!(!handle.should_stop());

        handle.set_running(true);
        assert!(handle.is_running());

        handle.stop();
        assert!(handle.should_stop());

        handle.inc_frames_processed();
        handle.inc_embeddings(10);
        handle.inc_re_extractions();
        handle.inc_errors();

        let snapshot = handle.stats();
        assert_eq!(snapshot.frames_processed, 1);
        assert_eq!(snapshot.embeddings_generated, 10);
        assert_eq!(snapshot.re_extractions, 1);
        assert_eq!(snapshot.errors, 1);
    }

    #[test]
    fn test_processor() {
        let processor = EnrichmentProcessor::new(EnrichmentWorkerConfig::default());
        let task = task_for(1);

        let outcome = processor.process_task(
            &task,
            |_| Some(("test content".to_string(), false, false)),
            |_| Ok("full content".to_string()),
            |_, _| Ok(()),
        );

        assert_eq!(outcome.frame_id, 1);
        // Not a skim, so no re-extraction.
        assert!(!outcome.re_extracted);
        assert!(outcome.error.is_none());
    }

    #[test]
    fn test_processor_with_skim() {
        let processor = EnrichmentProcessor::new(EnrichmentWorkerConfig::default());
        let task = task_for(2);

        let outcome = processor.process_task(
            &task,
            // is_skim = true forces a full extraction.
            |_| Some(("skim content".to_string(), true, false)),
            |_| Ok("full extracted content".to_string()),
            |_, indexed| {
                assert_eq!(indexed, "full extracted content");
                Ok(())
            },
        );

        assert_eq!(outcome.frame_id, 2);
        assert!(outcome.re_extracted);
        assert!(outcome.error.is_none());
    }
}