1use std::sync::Arc;
10use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
11use std::time::{Duration, Instant};
12
13use crate::error::Result;
14use crate::types::{EnrichmentTask, FrameId, VecEmbedder};
15
/// Tuning knobs for the background enrichment worker.
///
/// The type is plain data, so it derives `PartialEq`/`Eq` to allow configs to
/// be compared directly (for example in tests or when detecting changes).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EnrichmentWorkerConfig {
    /// Number of texts to embed per batch.
    ///
    /// NOTE(review): nothing in this file reads this field; presumably callers
    /// pass it to `EmbeddingBatcher::new` — confirm against the call sites.
    pub embedding_batch_size: usize,
    /// Number of completed tasks between two calls to the checkpoint callback
    /// in `run_worker_loop`.
    pub checkpoint_interval: usize,
    /// Pause after each processed task, in milliseconds. When no task is
    /// available the worker loop sleeps ten times this long before polling
    /// again.
    pub task_delay_ms: u64,
    /// Per-task time budget in milliseconds.
    ///
    /// NOTE(review): not enforced anywhere in this file — confirm whether a
    /// caller applies it.
    pub max_task_time_ms: u64,
}
28
29impl Default for EnrichmentWorkerConfig {
30 fn default() -> Self {
31 Self {
32 embedding_batch_size: 32,
33 checkpoint_interval: 100,
34 task_delay_ms: 50,
35 max_task_time_ms: 5000,
36 }
37 }
38}
39
/// Point-in-time snapshot of the enrichment worker's counters.
///
/// Derives `PartialEq`/`Eq` so that two snapshots can be compared directly.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct EnrichmentWorkerStats {
    /// Number of tasks the worker loop has processed, whether or not they
    /// ended with an error.
    pub frames_processed: u64,
    /// Total number of embeddings reported by processed tasks.
    pub embeddings_generated: u64,
    /// Number of tasks whose result was flagged as re-extracted.
    pub re_extractions: u64,
    /// Number of tasks that finished with an error.
    pub errors: u64,
    /// Number of tasks still waiting in the queue.
    ///
    /// NOTE(review): `EnrichmentWorkerHandle::stats` always reports 0 here;
    /// presumably a caller that owns the queue fills this in — confirm.
    pub queue_depth: usize,
    /// Whether the worker loop was running when the snapshot was taken.
    pub is_running: bool,
}
56
/// Shared control and statistics handle for an enrichment worker.
///
/// Every field is reference-counted, so handles produced by `clone_handle`
/// observe and update the same flags and counters.
#[derive(Debug)]
pub struct EnrichmentWorkerHandle {
    /// Set to `true` to ask the worker loop to exit.
    stop_signal: Arc<AtomicBool>,
    /// Count of tasks processed so far.
    frames_processed: Arc<AtomicU64>,
    /// Count of embeddings reported by processed tasks.
    embeddings_generated: Arc<AtomicU64>,
    /// Count of tasks that were re-extracted.
    re_extractions: Arc<AtomicU64>,
    /// Count of tasks that finished with an error.
    errors: Arc<AtomicU64>,
    /// `true` while the worker loop is executing.
    is_running: Arc<AtomicBool>,
}
72
73impl EnrichmentWorkerHandle {
74 #[must_use]
76 pub fn new() -> Self {
77 Self {
78 stop_signal: Arc::new(AtomicBool::new(false)),
79 frames_processed: Arc::new(AtomicU64::new(0)),
80 embeddings_generated: Arc::new(AtomicU64::new(0)),
81 re_extractions: Arc::new(AtomicU64::new(0)),
82 errors: Arc::new(AtomicU64::new(0)),
83 is_running: Arc::new(AtomicBool::new(false)),
84 }
85 }
86
87 pub fn stop(&self) {
89 self.stop_signal.store(true, Ordering::SeqCst);
90 }
91
92 #[must_use]
94 pub fn should_stop(&self) -> bool {
95 self.stop_signal.load(Ordering::SeqCst)
96 }
97
98 #[must_use]
100 pub fn is_running(&self) -> bool {
101 self.is_running.load(Ordering::SeqCst)
102 }
103
104 #[must_use]
106 pub fn stats(&self) -> EnrichmentWorkerStats {
107 EnrichmentWorkerStats {
108 frames_processed: self.frames_processed.load(Ordering::Relaxed),
109 embeddings_generated: self.embeddings_generated.load(Ordering::Relaxed),
110 re_extractions: self.re_extractions.load(Ordering::Relaxed),
111 errors: self.errors.load(Ordering::Relaxed),
112 queue_depth: 0, is_running: self.is_running.load(Ordering::Relaxed),
114 }
115 }
116
117 pub(crate) fn inc_frames_processed(&self) {
119 self.frames_processed.fetch_add(1, Ordering::Relaxed);
120 }
121
122 pub(crate) fn inc_embeddings(&self, count: u64) {
124 self.embeddings_generated
125 .fetch_add(count, Ordering::Relaxed);
126 }
127
128 pub(crate) fn inc_re_extractions(&self) {
130 self.re_extractions.fetch_add(1, Ordering::Relaxed);
131 }
132
133 pub(crate) fn inc_errors(&self) {
135 self.errors.fetch_add(1, Ordering::Relaxed);
136 }
137
138 pub(crate) fn set_running(&self, running: bool) {
140 self.is_running.store(running, Ordering::SeqCst);
141 }
142
143 #[must_use]
145 pub fn clone_handle(&self) -> Self {
146 Self {
147 stop_signal: Arc::clone(&self.stop_signal),
148 frames_processed: Arc::clone(&self.frames_processed),
149 embeddings_generated: Arc::clone(&self.embeddings_generated),
150 re_extractions: Arc::clone(&self.re_extractions),
151 errors: Arc::clone(&self.errors),
152 is_running: Arc::clone(&self.is_running),
153 }
154 }
155}
156
157impl Default for EnrichmentWorkerHandle {
158 fn default() -> Self {
159 Self::new()
160 }
161}
162
163#[derive(Debug)]
165pub struct TaskResult {
166 pub frame_id: FrameId,
168 pub re_extracted: bool,
170 pub embeddings_generated: usize,
172 pub elapsed_ms: u64,
174 pub error: Option<String>,
176}
177
178pub struct EmbeddingBatcher<E: VecEmbedder> {
183 embedder: E,
185 batch_size: usize,
187 pending_texts: Vec<(FrameId, String)>,
189 ready_embeddings: Vec<(FrameId, Vec<f32>)>,
191}
192
193impl<E: VecEmbedder> EmbeddingBatcher<E> {
194 pub fn new(embedder: E, batch_size: usize) -> Self {
196 Self {
197 embedder,
198 batch_size: batch_size.max(1),
199 pending_texts: Vec::new(),
200 ready_embeddings: Vec::new(),
201 }
202 }
203
204 pub fn add(&mut self, frame_id: FrameId, text: String) {
206 self.pending_texts.push((frame_id, text));
207 }
208
209 pub fn pending_count(&self) -> usize {
211 self.pending_texts.len()
212 }
213
214 pub fn ready_count(&self) -> usize {
216 self.ready_embeddings.len()
217 }
218
219 pub fn should_flush(&self) -> bool {
221 self.pending_texts.len() >= self.batch_size
222 }
223
224 pub fn flush(&mut self) -> Result<usize> {
228 if self.pending_texts.is_empty() {
229 return Ok(0);
230 }
231
232 let pending: Vec<_> = std::mem::take(&mut self.pending_texts);
234 let count = pending.len();
235
236 let texts: Vec<&str> = pending.iter().map(|(_, text)| text.as_str()).collect();
238
239 let embeddings = self.embedder.embed_chunks(&texts)?;
241
242 for ((frame_id, _), embedding) in pending.into_iter().zip(embeddings.into_iter()) {
244 self.ready_embeddings.push((frame_id, embedding));
245 }
246
247 Ok(count)
248 }
249
250 pub fn take_embeddings(&mut self) -> Vec<(FrameId, Vec<f32>)> {
252 std::mem::take(&mut self.ready_embeddings)
253 }
254
255 pub fn dimension(&self) -> usize {
257 self.embedder.embedding_dimension()
258 }
259}
260
261pub struct EnrichmentProcessor {
263 pub config: EnrichmentWorkerConfig,
265}
266
267impl EnrichmentProcessor {
268 #[must_use]
270 pub fn new(config: EnrichmentWorkerConfig) -> Self {
271 Self { config }
272 }
273
274 pub fn process_task<F, E, R>(
288 &self,
289 task: &EnrichmentTask,
290 read_frame: F,
291 extract_full: E,
292 update_index: R,
293 ) -> TaskResult
294 where
295 F: FnOnce(FrameId) -> Option<(String, bool, bool)>, E: FnOnce(FrameId) -> Result<String>, R: FnOnce(FrameId, &str) -> Result<()>, {
299 let start = Instant::now();
300 let mut result = TaskResult {
301 frame_id: task.frame_id,
302 re_extracted: false,
303 embeddings_generated: 0,
304 elapsed_ms: 0,
305 error: None,
306 };
307
308 let (text, is_skim, _needs_embedding) = if let Some(data) = read_frame(task.frame_id) {
310 data
311 } else {
312 result.error = Some("Frame not found".to_string());
313 result.elapsed_ms = start.elapsed().as_millis().try_into().unwrap_or(u64::MAX);
314 return result;
315 };
316
317 let final_text = if is_skim {
319 match extract_full(task.frame_id) {
320 Ok(full_text) => {
321 result.re_extracted = true;
322 full_text
323 }
324 Err(err) => {
325 tracing::warn!(
326 frame_id = task.frame_id,
327 ?err,
328 "re-extraction failed, using skim text"
329 );
330 text
331 }
332 }
333 } else {
334 text
335 };
336
337 if let Err(err) = update_index(task.frame_id, &final_text) {
339 result.error = Some(format!("Index update failed: {err}"));
340 }
341
342 result.elapsed_ms = start.elapsed().as_millis().try_into().unwrap_or(u64::MAX);
343 result
344 }
345}
346
347pub fn run_worker_loop<G, P, M, C>(
360 handle: &EnrichmentWorkerHandle,
361 config: &EnrichmentWorkerConfig,
362 mut get_next_task: G,
363 mut process_task: P,
364 mut mark_complete: M,
365 mut checkpoint: C,
366) where
367 G: FnMut() -> Option<EnrichmentTask>,
368 P: FnMut(&EnrichmentTask) -> TaskResult,
369 M: FnMut(FrameId),
370 C: FnMut(),
371{
372 handle.set_running(true);
373 tracing::info!("enrichment worker started");
374
375 let mut tasks_since_checkpoint = 0;
376
377 while !handle.should_stop() {
378 let task = if let Some(task) = get_next_task() {
380 task
381 } else {
382 std::thread::sleep(Duration::from_millis(config.task_delay_ms * 10));
384 continue;
385 };
386
387 let result = process_task(&task);
389
390 handle.inc_frames_processed();
392 if result.re_extracted {
393 handle.inc_re_extractions();
394 }
395 if result.embeddings_generated > 0 {
396 handle.inc_embeddings(result.embeddings_generated as u64);
397 }
398 if result.error.is_some() {
399 handle.inc_errors();
400 tracing::warn!(
401 frame_id = task.frame_id,
402 error = ?result.error,
403 "enrichment task failed"
404 );
405 } else {
406 tracing::debug!(
407 frame_id = task.frame_id,
408 re_extracted = result.re_extracted,
409 embeddings = result.embeddings_generated,
410 elapsed_ms = result.elapsed_ms,
411 "enrichment task complete"
412 );
413 }
414
415 mark_complete(task.frame_id);
417 tasks_since_checkpoint += 1;
418
419 if tasks_since_checkpoint >= config.checkpoint_interval {
421 checkpoint();
422 tasks_since_checkpoint = 0;
423 }
424
425 std::thread::sleep(Duration::from_millis(config.task_delay_ms));
427 }
428
429 if tasks_since_checkpoint > 0 {
431 checkpoint();
432 }
433
434 handle.set_running(false);
435 tracing::info!(
436 frames_processed = handle.frames_processed.load(Ordering::Relaxed),
437 "enrichment worker stopped"
438 );
439}
440
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic embedder stub: the vector depends only on the text
    /// length and the configured dimension.
    struct MockEmbedder {
        dimension: usize,
    }

    impl MockEmbedder {
        fn new(dimension: usize) -> Self {
            MockEmbedder { dimension }
        }
    }

    impl crate::types::VecEmbedder for MockEmbedder {
        fn embed_query(&self, text: &str) -> Result<Vec<f32>> {
            let seed = text.len() as f32;
            let mut vector = Vec::with_capacity(self.dimension);
            for i in 0..self.dimension {
                vector.push((seed + i as f32) * 0.1);
            }
            Ok(vector)
        }

        fn embedding_dimension(&self) -> usize {
            self.dimension
        }
    }

    /// Builds a task for `frame_id` with all bookkeeping fields zeroed.
    fn task_for(frame_id: FrameId) -> EnrichmentTask {
        EnrichmentTask {
            frame_id,
            created_at: 0,
            chunks_done: 0,
            chunks_total: 0,
        }
    }

    #[test]
    fn test_embedding_batcher_basic() {
        let mut batcher = EmbeddingBatcher::new(MockEmbedder::new(4), 2);

        // Freshly created: nothing queued, nothing ready.
        assert_eq!(batcher.pending_count(), 0);
        assert_eq!(batcher.ready_count(), 0);
        assert!(!batcher.should_flush());

        // One text is below the batch size of two.
        batcher.add(1, "hello".to_string());
        assert_eq!(batcher.pending_count(), 1);
        assert!(!batcher.should_flush());

        // The second text reaches the threshold.
        batcher.add(2, "world".to_string());
        assert_eq!(batcher.pending_count(), 2);
        assert!(batcher.should_flush());

        let flushed = batcher.flush().expect("flush should succeed");
        assert_eq!(flushed, 2);
        assert_eq!(batcher.pending_count(), 0);
        assert_eq!(batcher.ready_count(), 2);

        let embeddings = batcher.take_embeddings();
        assert_eq!(embeddings.len(), 2);
        for ((frame_id, vector), expected_id) in embeddings.iter().zip([1, 2]) {
            assert_eq!(*frame_id, expected_id);
            assert_eq!(vector.len(), 4);
        }

        // Taking the embeddings empties the ready list.
        assert_eq!(batcher.ready_count(), 0);
    }

    #[test]
    fn test_embedding_batcher_dimension() {
        let batcher = EmbeddingBatcher::new(MockEmbedder::new(128), 32);
        assert_eq!(batcher.dimension(), 128);
    }

    #[test]
    fn test_embedding_batcher_flush_empty() {
        let mut batcher = EmbeddingBatcher::new(MockEmbedder::new(4), 2);

        let flushed = batcher.flush().expect("flush should succeed");
        assert_eq!(flushed, 0);
    }

    #[test]
    fn test_worker_handle() {
        let handle = EnrichmentWorkerHandle::new();
        assert!(!handle.is_running());
        assert!(!handle.should_stop());

        handle.set_running(true);
        assert!(handle.is_running());

        handle.stop();
        assert!(handle.should_stop());

        handle.inc_frames_processed();
        handle.inc_embeddings(10);
        handle.inc_re_extractions();
        handle.inc_errors();

        let EnrichmentWorkerStats {
            frames_processed,
            embeddings_generated,
            re_extractions,
            errors,
            ..
        } = handle.stats();
        assert_eq!(frames_processed, 1);
        assert_eq!(embeddings_generated, 10);
        assert_eq!(re_extractions, 1);
        assert_eq!(errors, 1);
    }

    #[test]
    fn test_processor() {
        let processor = EnrichmentProcessor::new(EnrichmentWorkerConfig::default());
        let task = task_for(1);

        // Not a skim frame, so the full-extraction callback must be unused.
        let result = processor.process_task(
            &task,
            |_| Some(("test content".to_string(), false, false)),
            |_| Ok("full content".to_string()),
            |_, _| Ok(()),
        );

        assert_eq!(result.frame_id, 1);
        assert!(!result.re_extracted);
        assert!(result.error.is_none());
    }

    #[test]
    fn test_processor_with_skim() {
        let processor = EnrichmentProcessor::new(EnrichmentWorkerConfig::default());
        let task = task_for(2);

        // Skim frame: the index must receive the re-extracted text.
        let result = processor.process_task(
            &task,
            |_| Some(("skim content".to_string(), true, false)),
            |_| Ok("full extracted content".to_string()),
            |_, text| {
                assert_eq!(text, "full extracted content");
                Ok(())
            },
        );

        assert_eq!(result.frame_id, 2);
        assert!(result.re_extracted);
        assert!(result.error.is_none());
    }
}