use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::time::{Duration, Instant};
use crate::error::Result;
use crate::types::{EnrichmentTask, FrameId, VecEmbedder};
/// Tuning knobs for the background enrichment worker.
#[derive(Debug, Clone)]
pub struct EnrichmentWorkerConfig {
    /// Target number of texts per embedding batch.
    /// NOTE(review): not read by code in this file; presumably handed to `EmbeddingBatcher::new`.
    pub embedding_batch_size: usize,
    /// `run_worker_loop` invokes its checkpoint callback after this many completed tasks.
    pub checkpoint_interval: usize,
    /// Pause (ms) after each processed task; `run_worker_loop` waits ten times this long
    /// when the queue is empty.
    pub task_delay_ms: u64,
    /// Intended per-task time budget in milliseconds.
    /// NOTE(review): not enforced by any code in this file.
    pub max_task_time_ms: u64,
}

impl Default for EnrichmentWorkerConfig {
    /// Batches of 32 texts, a checkpoint every 100 tasks, a 50 ms inter-task pause
    /// and a 5 s task budget.
    fn default() -> Self {
        EnrichmentWorkerConfig {
            max_task_time_ms: 5_000,
            task_delay_ms: 50,
            checkpoint_interval: 100,
            embedding_batch_size: 32,
        }
    }
}
/// Point-in-time snapshot of a worker's counters, as returned by
/// [`EnrichmentWorkerHandle::stats`].
///
/// Each field is read independently, so the snapshot is not atomic across fields.
#[derive(Debug, Clone, Default)]
pub struct EnrichmentWorkerStats {
    /// Number of tasks processed, including ones that ended with an error.
    pub frames_processed: u64,
    /// Total embeddings reported by processed tasks.
    pub embeddings_generated: u64,
    /// Number of tasks whose skim text was replaced by a full re-extraction.
    pub re_extractions: u64,
    /// Number of tasks that reported an error.
    pub errors: u64,
    /// Queue depth; always 0 when produced by [`EnrichmentWorkerHandle::stats`],
    /// because the handle has no access to the task queue.
    pub queue_depth: usize,
    /// Whether the worker loop reported itself as running.
    pub is_running: bool,
}

/// Shared control and telemetry handle for an enrichment worker.
///
/// Every field is an `Arc`, so handles produced by
/// [`clone_handle`](Self::clone_handle) observe and control the same worker.
pub struct EnrichmentWorkerHandle {
    stop_signal: Arc<AtomicBool>,
    frames_processed: Arc<AtomicU64>,
    embeddings_generated: Arc<AtomicU64>,
    re_extractions: Arc<AtomicU64>,
    errors: Arc<AtomicU64>,
    is_running: Arc<AtomicBool>,
}

impl EnrichmentWorkerHandle {
    /// Creates a handle with all counters at zero, not running and not stopped.
    #[must_use]
    pub fn new() -> Self {
        Self {
            stop_signal: Arc::new(AtomicBool::new(false)),
            frames_processed: Arc::new(AtomicU64::new(0)),
            embeddings_generated: Arc::new(AtomicU64::new(0)),
            re_extractions: Arc::new(AtomicU64::new(0)),
            errors: Arc::new(AtomicU64::new(0)),
            is_running: Arc::new(AtomicBool::new(false)),
        }
    }

    /// Requests that the worker stop; observed by [`should_stop`](Self::should_stop).
    pub fn stop(&self) {
        self.stop_signal.store(true, Ordering::SeqCst);
    }

    /// Returns `true` once [`stop`](Self::stop) has been called on any shared handle.
    #[must_use]
    pub fn should_stop(&self) -> bool {
        self.stop_signal.load(Ordering::SeqCst)
    }

    /// Returns the running flag last set via `set_running`.
    #[must_use]
    pub fn is_running(&self) -> bool {
        self.is_running.load(Ordering::SeqCst)
    }

    /// Takes a snapshot of the counters.
    ///
    /// Counters are plain statistics and are read with `Relaxed` ordering.
    #[must_use]
    pub fn stats(&self) -> EnrichmentWorkerStats {
        EnrichmentWorkerStats {
            frames_processed: self.frames_processed.load(Ordering::Relaxed),
            embeddings_generated: self.embeddings_generated.load(Ordering::Relaxed),
            re_extractions: self.re_extractions.load(Ordering::Relaxed),
            errors: self.errors.load(Ordering::Relaxed),
            // The handle cannot see the task queue, so depth is not reported here.
            queue_depth: 0,
            // Read through `is_running()` so this agrees with the SeqCst flag accessors.
            is_running: self.is_running(),
        }
    }

    /// Increments the processed-task counter by one.
    pub(crate) fn inc_frames_processed(&self) {
        self.frames_processed.fetch_add(1, Ordering::Relaxed);
    }

    /// Adds `count` to the generated-embeddings counter.
    pub(crate) fn inc_embeddings(&self, count: u64) {
        self.embeddings_generated.fetch_add(count, Ordering::Relaxed);
    }

    /// Increments the re-extraction counter by one.
    pub(crate) fn inc_re_extractions(&self) {
        self.re_extractions.fetch_add(1, Ordering::Relaxed);
    }

    /// Increments the error counter by one.
    pub(crate) fn inc_errors(&self) {
        self.errors.fetch_add(1, Ordering::Relaxed);
    }

    /// Sets the running flag reported by [`is_running`](Self::is_running).
    pub(crate) fn set_running(&self, running: bool) {
        self.is_running.store(running, Ordering::SeqCst);
    }

    /// Returns another handle sharing the same stop flag, counters and running flag.
    #[must_use]
    pub fn clone_handle(&self) -> Self {
        Self {
            stop_signal: Arc::clone(&self.stop_signal),
            frames_processed: Arc::clone(&self.frames_processed),
            embeddings_generated: Arc::clone(&self.embeddings_generated),
            re_extractions: Arc::clone(&self.re_extractions),
            errors: Arc::clone(&self.errors),
            is_running: Arc::clone(&self.is_running),
        }
    }
}

impl Default for EnrichmentWorkerHandle {
    /// Equivalent to [`EnrichmentWorkerHandle::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Outcome of processing a single enrichment task.
///
/// Produced by `EnrichmentProcessor::process_task` and consumed by
/// `run_worker_loop`, which folds it into the worker handle's counters.
#[derive(Debug)]
pub struct TaskResult {
/// Frame the task referred to.
pub frame_id: FrameId,
/// `true` when the frame held skim text and full re-extraction succeeded.
pub re_extracted: bool,
/// Embeddings produced for this frame; `run_worker_loop` adds it to the handle's counter.
/// NOTE(review): `EnrichmentProcessor::process_task` always leaves this at 0.
pub embeddings_generated: usize,
/// Time spent on the task in milliseconds, measured with `Instant` (saturates at `u64::MAX`).
pub elapsed_ms: u64,
/// Failure description, or `None` if the task succeeded.
pub error: Option<String>,
}
/// Accumulates `(frame, text)` pairs and embeds them in batches through a [`VecEmbedder`].
///
/// Queue texts with [`add`](Self::add), convert them with [`flush`](Self::flush), and
/// collect the vectors with [`take_embeddings`](Self::take_embeddings).
pub struct EmbeddingBatcher<E: VecEmbedder> {
    embedder: E,
    /// Threshold used by `should_flush`; always at least 1.
    batch_size: usize,
    /// Texts waiting to be embedded, in insertion order.
    pending_texts: Vec<(FrameId, String)>,
    /// Embeddings produced by `flush` that have not been taken yet.
    ready_embeddings: Vec<(FrameId, Vec<f32>)>,
}

impl<E: VecEmbedder> EmbeddingBatcher<E> {
    /// Creates a batcher around `embedder`; a `batch_size` of 0 is treated as 1.
    pub fn new(embedder: E, batch_size: usize) -> Self {
        Self {
            embedder,
            batch_size: batch_size.max(1),
            pending_texts: Vec::new(),
            ready_embeddings: Vec::new(),
        }
    }

    /// Queues `text` for embedding on behalf of `frame_id`.
    pub fn add(&mut self, frame_id: FrameId, text: String) {
        self.pending_texts.push((frame_id, text));
    }

    /// Number of texts waiting for the next `flush`.
    pub fn pending_count(&self) -> usize {
        self.pending_texts.len()
    }

    /// Number of embeddings produced but not yet taken.
    pub fn ready_count(&self) -> usize {
        self.ready_embeddings.len()
    }

    /// Returns `true` once at least `batch_size` texts are pending.
    pub fn should_flush(&self) -> bool {
        self.pending_texts.len() >= self.batch_size
    }

    /// Embeds every pending text with a single `embed_chunks` call and moves the
    /// resulting vectors to the ready list.
    ///
    /// All pending texts are sent, even if there are more than `batch_size`.
    /// Returns the number of embeddings added to the ready list. This equals the
    /// number of pending texts when the embedder returns one vector per text. If it
    /// returns fewer, the unmatched texts stay pending rather than being lost. Surplus
    /// vectors are ignored.
    ///
    /// # Errors
    ///
    /// Propagates any error from `embed_chunks`. In that case the pending texts are
    /// kept, so the flush can be retried.
    pub fn flush(&mut self) -> Result<usize> {
        if self.pending_texts.is_empty() {
            return Ok(0);
        }
        // Embed from a borrowed view: nothing leaves `pending_texts` until the
        // embedder has succeeded, so a failure does not discard queued work.
        let embeddings = {
            let texts: Vec<&str> = self
                .pending_texts
                .iter()
                .map(|(_, text)| text.as_str())
                .collect();
            self.embedder.embed_chunks(&texts)?
        };
        let mut pending = std::mem::take(&mut self.pending_texts).into_iter();
        let ready_before = self.ready_embeddings.len();
        // `embeddings` drives the zip: `Iterator::zip` stops at the first `None` from its
        // first iterator without polling the second. `pending` is therefore only advanced
        // for texts that actually received a vector.
        self.ready_embeddings.extend(
            embeddings
                .into_iter()
                .zip(pending.by_ref())
                .map(|(embedding, (frame_id, _))| (frame_id, embedding)),
        );
        // Texts the embedder returned no vector for are retried on the next flush.
        self.pending_texts.extend(pending);
        Ok(self.ready_embeddings.len() - ready_before)
    }

    /// Removes and returns all ready embeddings, leaving the ready list empty.
    pub fn take_embeddings(&mut self) -> Vec<(FrameId, Vec<f32>)> {
        std::mem::take(&mut self.ready_embeddings)
    }

    /// Dimension of the vectors produced by the wrapped embedder.
    pub fn dimension(&self) -> usize {
        self.embedder.embedding_dimension()
    }
}
pub struct EnrichmentProcessor {
pub config: EnrichmentWorkerConfig,
}
impl EnrichmentProcessor {
#[must_use]
pub fn new(config: EnrichmentWorkerConfig) -> Self {
Self { config }
}
pub fn process_task<F, E, R>(
&self,
task: &EnrichmentTask,
read_frame: F,
extract_full: E,
update_index: R,
) -> TaskResult
where
F: FnOnce(FrameId) -> Option<(String, bool, bool)>, E: FnOnce(FrameId) -> Result<String>, R: FnOnce(FrameId, &str) -> Result<()>, {
let start = Instant::now();
let mut result = TaskResult {
frame_id: task.frame_id,
re_extracted: false,
embeddings_generated: 0,
elapsed_ms: 0,
error: None,
};
let (text, is_skim, _needs_embedding) = if let Some(data) = read_frame(task.frame_id) {
data
} else {
result.error = Some("Frame not found".to_string());
result.elapsed_ms = start.elapsed().as_millis().try_into().unwrap_or(u64::MAX);
return result;
};
let final_text = if is_skim {
match extract_full(task.frame_id) {
Ok(full_text) => {
result.re_extracted = true;
full_text
}
Err(err) => {
tracing::warn!(
frame_id = task.frame_id,
?err,
"re-extraction failed, using skim text"
);
text
}
}
} else {
text
};
if let Err(err) = update_index(task.frame_id, &final_text) {
result.error = Some(format!("Index update failed: {err}"));
}
result.elapsed_ms = start.elapsed().as_millis().try_into().unwrap_or(u64::MAX);
result
}
}
pub fn run_worker_loop<G, P, M, C>(
handle: &EnrichmentWorkerHandle,
config: &EnrichmentWorkerConfig,
mut get_next_task: G,
mut process_task: P,
mut mark_complete: M,
mut checkpoint: C,
) where
G: FnMut() -> Option<EnrichmentTask>,
P: FnMut(&EnrichmentTask) -> TaskResult,
M: FnMut(FrameId),
C: FnMut(),
{
handle.set_running(true);
tracing::info!("enrichment worker started");
let mut tasks_since_checkpoint = 0;
while !handle.should_stop() {
let task = if let Some(task) = get_next_task() {
task
} else {
std::thread::sleep(Duration::from_millis(config.task_delay_ms * 10));
continue;
};
let result = process_task(&task);
handle.inc_frames_processed();
if result.re_extracted {
handle.inc_re_extractions();
}
if result.embeddings_generated > 0 {
handle.inc_embeddings(result.embeddings_generated as u64);
}
if result.error.is_some() {
handle.inc_errors();
tracing::warn!(
frame_id = task.frame_id,
error = ?result.error,
"enrichment task failed"
);
} else {
tracing::debug!(
frame_id = task.frame_id,
re_extracted = result.re_extracted,
embeddings = result.embeddings_generated,
elapsed_ms = result.elapsed_ms,
"enrichment task complete"
);
}
mark_complete(task.frame_id);
tasks_since_checkpoint += 1;
if tasks_since_checkpoint >= config.checkpoint_interval {
checkpoint();
tasks_since_checkpoint = 0;
}
std::thread::sleep(Duration::from_millis(config.task_delay_ms));
}
if tasks_since_checkpoint > 0 {
checkpoint();
}
handle.set_running(false);
tracing::info!(
frames_processed = handle.frames_processed.load(Ordering::Relaxed),
"enrichment worker stopped"
);
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic embedder whose vectors depend only on text length.
    struct MockEmbedder {
        dimension: usize,
    }

    impl MockEmbedder {
        fn new(dimension: usize) -> Self {
            Self { dimension }
        }
    }

    impl crate::types::VecEmbedder for MockEmbedder {
        fn embed_query(&self, text: &str) -> Result<Vec<f32>> {
            let base = text.len() as f32;
            let vector: Vec<f32> = (0..self.dimension)
                .map(|offset| (base + offset as f32) * 0.1)
                .collect();
            Ok(vector)
        }

        fn embedding_dimension(&self) -> usize {
            self.dimension
        }
    }

    /// Builds a task for `frame_id` with all bookkeeping fields zeroed.
    fn task_for(frame_id: FrameId) -> EnrichmentTask {
        EnrichmentTask {
            frame_id,
            created_at: 0,
            chunks_done: 0,
            chunks_total: 0,
        }
    }

    #[test]
    fn test_embedding_batcher_basic() {
        let mut batcher = EmbeddingBatcher::new(MockEmbedder::new(4), 2);
        assert_eq!(batcher.pending_count(), 0);
        assert_eq!(batcher.ready_count(), 0);
        assert!(!batcher.should_flush());

        batcher.add(1, "hello".to_string());
        assert_eq!(batcher.pending_count(), 1);
        assert!(!batcher.should_flush());

        batcher.add(2, "world".to_string());
        assert_eq!(batcher.pending_count(), 2);
        assert!(batcher.should_flush());

        let flushed = batcher.flush().expect("flush should succeed");
        assert_eq!(flushed, 2);
        assert_eq!(batcher.pending_count(), 0);
        assert_eq!(batcher.ready_count(), 2);

        let taken = batcher.take_embeddings();
        assert_eq!(taken.len(), 2);
        assert_eq!(taken[0].0, 1);
        assert_eq!(taken[0].1.len(), 4);
        assert_eq!(taken[1].0, 2);
        assert_eq!(taken[1].1.len(), 4);
        assert_eq!(batcher.ready_count(), 0);
    }

    #[test]
    fn test_embedding_batcher_dimension() {
        let batcher = EmbeddingBatcher::new(MockEmbedder::new(128), 32);
        assert_eq!(batcher.dimension(), 128);
    }

    #[test]
    fn test_embedding_batcher_flush_empty() {
        let mut batcher = EmbeddingBatcher::new(MockEmbedder::new(4), 2);
        let flushed = batcher.flush().expect("flush should succeed");
        assert_eq!(flushed, 0);
    }

    #[test]
    fn test_worker_handle() {
        let handle = EnrichmentWorkerHandle::new();
        assert!(!handle.is_running());
        assert!(!handle.should_stop());

        handle.set_running(true);
        assert!(handle.is_running());

        handle.stop();
        assert!(handle.should_stop());

        handle.inc_frames_processed();
        handle.inc_embeddings(10);
        handle.inc_re_extractions();
        handle.inc_errors();

        let snapshot = handle.stats();
        assert_eq!(snapshot.frames_processed, 1);
        assert_eq!(snapshot.embeddings_generated, 10);
        assert_eq!(snapshot.re_extractions, 1);
        assert_eq!(snapshot.errors, 1);
    }

    #[test]
    fn test_processor() {
        let processor = EnrichmentProcessor::new(EnrichmentWorkerConfig::default());
        let outcome = processor.process_task(
            &task_for(1),
            |_| Some(("test content".to_string(), false, false)),
            |_| Ok("full content".to_string()),
            |_, _| Ok(()),
        );
        assert_eq!(outcome.frame_id, 1);
        assert!(!outcome.re_extracted);
        assert!(outcome.error.is_none());
    }

    #[test]
    fn test_processor_with_skim() {
        let processor = EnrichmentProcessor::new(EnrichmentWorkerConfig::default());
        let outcome = processor.process_task(
            &task_for(2),
            |_| Some(("skim content".to_string(), true, false)),
            |_| Ok("full extracted content".to_string()),
            |_, indexed_text| {
                assert_eq!(indexed_text, "full extracted content");
                Ok(())
            },
        );
        assert_eq!(outcome.frame_id, 2);
        assert!(outcome.re_extracted);
        assert!(outcome.error.is_none());
    }
}