use std::sync::{Arc, Mutex};
use std::thread::JoinHandle;
use crate::enrichment_worker::{
EmbeddingBatcher, EnrichmentWorkerConfig, EnrichmentWorkerHandle, EnrichmentWorkerStats,
TaskResult,
};
use crate::error::Result;
use crate::extract_budgeted::ExtractionBudget;
use crate::types::{EnrichmentState, EnrichmentTask, FrameId, FrameStatus, VecEmbedder};
use crate::vec::VecIndexBuilder;
use super::Memvid;
pub struct EnrichmentHandle {
pub handle: EnrichmentWorkerHandle,
thread: Option<JoinHandle<()>>,
}
impl EnrichmentHandle {
#[must_use]
pub fn stop_and_wait(mut self) -> EnrichmentWorkerStats {
self.handle.stop();
if let Some(thread) = self.thread.take() {
let _ = thread.join();
}
self.handle.stats()
}
#[must_use]
pub fn is_running(&self) -> bool {
self.handle.is_running()
}
#[must_use]
pub fn stats(&self) -> EnrichmentWorkerStats {
self.handle.stats()
}
pub fn stop(&self) {
self.handle.stop();
}
}
pub fn start_enrichment_worker(
memvid: Arc<Mutex<Memvid>>,
config: Option<EnrichmentWorkerConfig>,
) -> EnrichmentHandle {
let config = config.unwrap_or_default();
let handle = EnrichmentWorkerHandle::new();
let worker_handle = handle.clone_handle();
let memvid_clone = Arc::clone(&memvid);
let config_clone = config.clone();
let thread = std::thread::spawn(move || {
crate::enrichment_worker::run_worker_loop(
&worker_handle,
&config_clone,
|| {
let mv = memvid_clone.lock().ok()?;
mv.next_enrichment_task()
},
|task| {
let mut mv = match memvid_clone.lock() {
Ok(mv) => mv,
Err(_) => {
return TaskResult {
frame_id: task.frame_id,
re_extracted: false,
embeddings_generated: 0,
elapsed_ms: 0,
error: Some("Failed to acquire lock".to_string()),
};
}
};
mv.process_enrichment_task(task)
},
|frame_id| {
if let Ok(mut mv) = memvid_clone.lock() {
mv.complete_enrichment_task(frame_id);
}
},
|| {
if let Ok(mut mv) = memvid_clone.lock() {
if let Err(err) = mv.commit() {
tracing::warn!(?err, "enrichment checkpoint commit failed");
}
}
},
);
});
EnrichmentHandle {
handle,
thread: Some(thread),
}
}
/// Spawns a background thread that enriches every queued frame and embeds its text.
///
/// The thread locks `memvid` once, runs [`Memvid::process_enrichment_with_embeddings`]
/// with `embedder` and `config.embedding_batch_size`, then calls `commit()`. Frame and
/// embedding counts are recorded on the returned handle; lock, processing, and commit
/// failures are logged and counted via `inc_errors()`.
///
/// `config` defaults to `EnrichmentWorkerConfig::default()` when `None`; only its
/// `embedding_batch_size` is used here.
pub fn start_enrichment_worker_with_embeddings<E>(
    memvid: Arc<Mutex<Memvid>>,
    embedder: E,
    config: Option<EnrichmentWorkerConfig>,
) -> EnrichmentHandle
where
    E: VecEmbedder + Send + 'static,
{
    let config = config.unwrap_or_default();
    let handle = EnrichmentWorkerHandle::new();
    let worker_handle = handle.clone_handle();
    let batch_size = config.embedding_batch_size;
    // Flag the worker as running *before* spawning, so a caller polling `is_running()`
    // right after this function returns cannot see `false` before the thread starts.
    worker_handle.set_running(true);
    let thread = std::thread::spawn(move || {
        tracing::info!("enrichment worker with embeddings started");
        // NOTE(review): the lock is held for the entire enrichment + embedding run, so
        // other users of `memvid` block until it finishes.
        match memvid.lock() {
            Ok(mut mv) => match mv.process_enrichment_with_embeddings(embedder, batch_size) {
                Ok((frames, embeddings)) => {
                    for _ in 0..frames {
                        worker_handle.inc_frames_processed();
                    }
                    worker_handle.inc_embeddings(u64::try_from(embeddings).unwrap_or(u64::MAX));
                    if let Err(err) = mv.commit() {
                        tracing::warn!(?err, "final commit failed");
                        worker_handle.inc_errors();
                    }
                }
                Err(err) => {
                    tracing::error!(?err, "enrichment with embeddings failed");
                    worker_handle.inc_errors();
                }
            },
            Err(err) => {
                tracing::error!(?err, "failed to acquire lock for enrichment");
                worker_handle.inc_errors();
            }
        }
        worker_handle.set_running(false);
        // Take one snapshot so both logged counters come from the same moment.
        let stats = worker_handle.stats();
        tracing::info!(
            frames_processed = stats.frames_processed,
            embeddings_generated = stats.embeddings_generated,
            "enrichment worker with embeddings stopped"
        );
    });
    EnrichmentHandle {
        handle,
        thread: Some(thread),
    }
}
impl Memvid {
    /// Number of tasks currently in the enrichment queue.
    #[must_use]
    pub fn enrichment_queue_len(&self) -> usize {
        self.toc.enrichment_queue.len()
    }

    /// Returns `true` when the enrichment queue holds at least one task.
    #[must_use]
    pub fn has_pending_enrichment(&self) -> bool {
        !self.toc.enrichment_queue.is_empty()
    }

    /// Returns a clone of the first queued enrichment task without removing it.
    ///
    /// Callers dequeue it with [`Memvid::complete_enrichment_task`] when done.
    #[must_use]
    pub fn next_enrichment_task(&self) -> Option<EnrichmentTask> {
        self.toc.enrichment_queue.tasks.first().cloned()
    }

    /// Removes `frame_id` from the enrichment queue and marks the in-memory state dirty.
    pub fn complete_enrichment_task(&mut self, frame_id: FrameId) {
        self.toc.enrichment_queue.remove(frame_id);
        self.dirty = true;
    }

    /// Reads what enrichment needs to know about an active frame.
    ///
    /// Returns `(search_text, is_skim, needs_embedding)`:
    /// - `search_text`: the stored search text, or `""` if none,
    /// - `is_skim`: `true` when `extra_metadata["skim"] == "true"`,
    /// - `needs_embedding`: `true` while the frame is still `EnrichmentState::Searchable`.
    ///
    /// Returns `None` if there is no *active* frame with `frame_id`.
    #[must_use]
    pub fn read_frame_for_enrichment(&self, frame_id: FrameId) -> Option<(String, bool, bool)> {
        let frame = self
            .toc
            .frames
            .iter()
            .find(|f| f.id == frame_id && f.status == FrameStatus::Active)?;
        let search_text = frame.search_text.clone().unwrap_or_default();
        let is_skim = frame
            .extra_metadata
            .get("skim")
            .is_some_and(|v| v == "true");
        let needs_embedding = frame.enrichment_state == EnrichmentState::Searchable;
        Some((search_text, is_skim, needs_embedding))
    }

    /// Re-extracts the full text of an active frame from its stored payload.
    ///
    /// Extraction runs with an unlimited [`ExtractionBudget`], using the frame's MIME type
    /// and URI as hints. If extraction itself fails, the failure is logged and the frame's
    /// stored `search_text` (or `""`) is returned as `Ok`. Callers therefore cannot tell a
    /// fallback apart from a successful extraction.
    ///
    /// # Errors
    ///
    /// Returns [`crate::MemvidError::FrameNotFound`] when no active frame has `frame_id`,
    /// and propagates any error from reading the frame payload.
    pub fn extract_full_text(&mut self, frame_id: FrameId) -> Result<String> {
        let frame = self
            .toc
            .frames
            .iter()
            .find(|f| f.id == frame_id && f.status == FrameStatus::Active)
            .cloned()
            .ok_or(crate::MemvidError::FrameNotFound { frame_id })?;
        let payload = self.read_frame_payload_bytes(&frame)?;
        let mime_hint = frame.metadata.as_ref().and_then(|m| m.mime.as_deref());
        let uri_hint = frame.uri.as_deref();
        let budget = ExtractionBudget::unlimited();
        match crate::extract_budgeted::extract_with_budget(&payload, mime_hint, uri_hint, budget) {
            Ok(result) => Ok(result.text),
            Err(err) => {
                // Keep enrichment going with the stored text, but don't hide the failure.
                tracing::warn!(
                    frame_id,
                    ?err,
                    "full-text extraction failed, falling back to stored search text"
                );
                Ok(frame.search_text.clone().unwrap_or_default())
            }
        }
    }

    /// Replaces an active frame's document in the Tantivy index with `text`.
    ///
    /// Deletes the frame's existing document, adds it back with `text`, soft-commits,
    /// and sets `tantivy_dirty`. Does nothing when no Tantivy index is open.
    ///
    /// # Errors
    ///
    /// Returns [`crate::MemvidError::FrameNotFound`] when no active frame has `frame_id`,
    /// and propagates errors from the index delete/add/soft-commit calls.
    #[cfg(feature = "lex")]
    pub fn update_tantivy_for_enrichment(&mut self, frame_id: FrameId, text: &str) -> Result<()> {
        let Some(tantivy) = self.tantivy.as_mut() else {
            return Ok(());
        };
        let frame = self
            .toc
            .frames
            .iter()
            .find(|f| f.id == frame_id && f.status == FrameStatus::Active)
            .ok_or(crate::MemvidError::FrameNotFound { frame_id })?
            .clone();
        tantivy.delete_frame(frame_id)?;
        tantivy.add_frame(&frame, text)?;
        tantivy.soft_commit()?;
        self.tantivy_dirty = true;
        Ok(())
    }

    /// No-op when the `lex` feature is disabled: there is no lexical index to update.
    ///
    /// # Errors
    ///
    /// Never returns an error.
    #[cfg(not(feature = "lex"))]
    pub fn update_tantivy_for_enrichment(&mut self, _frame_id: FrameId, _text: &str) -> Result<()> {
        Ok(())
    }

    /// Sets an active frame's enrichment state to `Enriched` and marks the state dirty.
    ///
    /// Does nothing if there is no active frame with `frame_id`.
    pub fn mark_frame_enriched(&mut self, frame_id: FrameId) {
        if let Some(frame) = self
            .toc
            .frames
            .iter_mut()
            .find(|f| f.id == frame_id && f.status == FrameStatus::Active)
        {
            frame.enrichment_state = EnrichmentState::Enriched;
            self.dirty = true;
        }
    }

    /// Enriches one queued frame without generating embeddings.
    ///
    /// Skim frames are re-extracted in full (falling back to the skim text on error). The
    /// lexical index is refreshed with the resulting text and the frame is marked enriched.
    /// The task is *not* dequeued here; see [`Memvid::complete_enrichment_task`].
    ///
    /// Failures are reported through `TaskResult::error`:
    /// - `"Frame not found"` when the frame is missing or inactive,
    /// - `"Index update failed: ..."` when the Tantivy update fails.
    ///
    /// `elapsed_ms` is measured from after the frame lookup.
    pub fn process_enrichment_task(&mut self, task: &EnrichmentTask) -> TaskResult {
        let Some((search_text, is_skim, _needs_embedding)) =
            self.read_frame_for_enrichment(task.frame_id)
        else {
            return TaskResult {
                frame_id: task.frame_id,
                re_extracted: false,
                embeddings_generated: 0,
                elapsed_ms: 0,
                error: Some("Frame not found".to_string()),
            };
        };
        let start = std::time::Instant::now();
        let mut result = TaskResult {
            frame_id: task.frame_id,
            re_extracted: false,
            embeddings_generated: 0,
            elapsed_ms: 0,
            error: None,
        };
        let final_text = if is_skim {
            match self.extract_full_text(task.frame_id) {
                Ok(full_text) => {
                    result.re_extracted = true;
                    full_text
                }
                Err(err) => {
                    tracing::warn!(
                        frame_id = task.frame_id,
                        ?err,
                        "re-extraction failed, using skim text"
                    );
                    search_text
                }
            }
        } else {
            search_text
        };
        if let Err(err) = self.update_tantivy_for_enrichment(task.frame_id, &final_text) {
            result.error = Some(format!("Index update failed: {err}"));
        }
        // NOTE(review): the frame is marked enriched even when the index update above
        // failed — presumably to avoid retrying a failing frame forever; confirm.
        self.mark_frame_enriched(task.frame_id);
        result.elapsed_ms = start.elapsed().as_millis().try_into().unwrap_or(u64::MAX);
        result
    }

    /// Processes and dequeues every enrichment task synchronously (no embeddings).
    ///
    /// Each task is dequeued whether or not it succeeded; failures are logged. Returns
    /// the number of tasks handled.
    pub fn process_all_enrichment(&mut self) -> usize {
        let mut processed = 0;
        while let Some(task) = self.next_enrichment_task() {
            let result = self.process_enrichment_task(&task);
            self.complete_enrichment_task(task.frame_id);
            if result.error.is_some() {
                tracing::warn!(
                    frame_id = task.frame_id,
                    error = ?result.error,
                    "enrichment task failed"
                );
            } else {
                tracing::debug!(
                    frame_id = task.frame_id,
                    re_extracted = result.re_extracted,
                    elapsed_ms = result.elapsed_ms,
                    "enrichment task complete"
                );
            }
            processed += 1;
        }
        processed
    }

    /// Summarises enrichment progress over active frames.
    ///
    /// `searchable_only` is `total_frames - enriched_frames` (saturating);
    /// `pending_frames` is the current enrichment queue length.
    #[must_use]
    pub fn enrichment_stats(&self) -> EnrichmentStats {
        // Single pass: count active frames and, among them, the enriched ones.
        let (total_frames, enriched_frames) = self
            .toc
            .frames
            .iter()
            .filter(|f| f.status == FrameStatus::Active)
            .fold((0, 0), |(total, enriched), f| {
                (
                    total + 1,
                    enriched + usize::from(f.enrichment_state == EnrichmentState::Enriched),
                )
            });
        let pending_frames = self.enrichment_queue_len();
        EnrichmentStats {
            total_frames,
            enriched_frames,
            pending_frames,
            searchable_only: total_frames.saturating_sub(enriched_frames),
        }
    }

    /// Adds or replaces vectors for the given frames by rebuilding the vector index.
    ///
    /// Existing vectors for frames not mentioned in `embeddings` are carried over. Frames
    /// in `embeddings` get the new vectors. The whole index is rebuilt on every call, so
    /// the cost grows with the total vector count. On success, the TOC vector manifest is
    /// replaced (with `bytes_offset: 0` and no compression), and `dirty` and
    /// `vec_enabled` are set.
    ///
    /// Returns the number of input embeddings. Returns 0 for empty input, or when the
    /// rebuilt index has no vectors (in which case nothing is changed).
    ///
    /// # Errors
    ///
    /// Propagates errors from finishing the index builder or decoding the new index.
    pub fn add_embeddings(&mut self, embeddings: Vec<(FrameId, Vec<f32>)>) -> Result<usize> {
        if embeddings.is_empty() {
            return Ok(0);
        }
        let count = embeddings.len();
        let mut builder = VecIndexBuilder::new();
        if let Some(vec_index) = &self.vec_index {
            // Set of frames being (re-)embedded: an O(1) lookup per existing vector instead
            // of rescanning `embeddings` each time (previously O(existing * new)).
            let replaced: std::collections::HashSet<&FrameId> =
                embeddings.iter().map(|(frame_id, _)| frame_id).collect();
            for (frame_id, embedding) in vec_index.entries() {
                if !replaced.contains(&frame_id) {
                    builder.add_document(frame_id, embedding.to_vec());
                }
            }
        }
        for (frame_id, embedding) in embeddings {
            builder.add_document(frame_id, embedding);
        }
        let artifact = builder.finish()?;
        if artifact.vector_count == 0 {
            return Ok(0);
        }
        let new_index = crate::vec::VecIndex::decode(&artifact.bytes)?;
        self.vec_index = Some(new_index);
        self.toc.indexes.vec = Some(crate::types::VecIndexManifest {
            vector_count: artifact.vector_count,
            dimension: artifact.dimension,
            // NOTE(review): offset is left at 0 here; presumably fixed up when the index
            // is written out on commit — confirm.
            bytes_offset: 0,
            bytes_length: artifact.bytes.len() as u64,
            checksum: artifact.checksum,
            compression_mode: crate::types::VectorCompression::None,
            model: self.vec_model.clone(),
        });
        self.dirty = true;
        self.vec_enabled = true;
        tracing::debug!(
            count,
            total_vectors = artifact.vector_count,
            dimension = artifact.dimension,
            "added embeddings to vector index"
        );
        Ok(count)
    }

    /// Enriches every queued frame and embeds the text of frames that still need vectors.
    ///
    /// For each task snapshot taken at entry:
    /// 1. A frame that is missing or inactive has its stale task dequeued and is skipped.
    /// 2. Skim frames are re-extracted in full.
    /// 3. The lexical index is refreshed.
    /// 4. Non-blank text of `Searchable` frames is queued on an [`EmbeddingBatcher`] of
    ///    `batch_size`; full batches are embedded and added via [`Memvid::add_embeddings`].
    /// 5. The frame is marked enriched and its task dequeued.
    ///
    /// A final partial batch is flushed at the end. Embedding/index failures are logged,
    /// not returned.
    ///
    /// Returns `(frames_processed, embeddings_generated)`.
    ///
    /// # Errors
    ///
    /// Currently always returns `Ok`.
    pub fn process_enrichment_with_embeddings<E: VecEmbedder>(
        &mut self,
        embedder: E,
        batch_size: usize,
    ) -> Result<(usize, usize)> {
        let mut batcher = EmbeddingBatcher::new(embedder, batch_size);
        let mut frames_processed = 0;
        let mut embeddings_generated = 0;
        // Snapshot the queue: `self` is mutated (and tasks dequeued) while iterating.
        let tasks: Vec<_> = self.toc.enrichment_queue.tasks.clone();
        for task in tasks {
            let Some((search_text, is_skim, needs_embedding)) =
                self.read_frame_for_enrichment(task.frame_id)
            else {
                // The frame is gone or inactive. Dequeue its task, as
                // `process_all_enrichment` does, so the queue can actually drain instead
                // of keeping this stale entry forever.
                self.complete_enrichment_task(task.frame_id);
                continue;
            };
            let final_text = if is_skim {
                match self.extract_full_text(task.frame_id) {
                    Ok(full_text) => full_text,
                    Err(err) => {
                        tracing::warn!(frame_id = task.frame_id, ?err, "re-extraction failed");
                        search_text
                    }
                }
            } else {
                search_text
            };
            if let Err(err) = self.update_tantivy_for_enrichment(task.frame_id, &final_text) {
                tracing::warn!(frame_id = task.frame_id, ?err, "tantivy update failed");
            }
            if needs_embedding && !final_text.trim().is_empty() {
                batcher.add(task.frame_id, final_text);
                if batcher.should_flush() {
                    match batcher.flush() {
                        Ok(count) => {
                            embeddings_generated += count;
                            let ready = batcher.take_embeddings();
                            if !ready.is_empty() {
                                if let Err(err) = self.add_embeddings(ready) {
                                    tracing::warn!(?err, "failed to add embeddings");
                                }
                            }
                        }
                        Err(err) => {
                            tracing::warn!(?err, "batch embedding failed");
                        }
                    }
                }
            }
            // NOTE(review): the frame is marked enriched and dequeued before its embedding
            // batch is necessarily flushed. If a later flush fails, the frame is no longer
            // `Searchable` or queued, so it will not be retried.
            self.mark_frame_enriched(task.frame_id);
            self.complete_enrichment_task(task.frame_id);
            frames_processed += 1;
            // NOTE(review): `embeddings_generated` is the running total across all frames,
            // and the task was just dequeued, so this checkpoint may have no effect —
            // confirm the intended semantics of `update_checkpoint`.
            let chunks_done = u32::try_from(embeddings_generated).unwrap_or(u32::MAX);
            self.toc.enrichment_queue.update_checkpoint(
                task.frame_id,
                chunks_done,
                task.chunks_total.max(chunks_done),
            );
        }
        if batcher.pending_count() > 0 {
            match batcher.flush() {
                Ok(count) => {
                    embeddings_generated += count;
                    let ready = batcher.take_embeddings();
                    if !ready.is_empty() {
                        if let Err(err) = self.add_embeddings(ready) {
                            tracing::warn!(?err, "failed to add final embeddings");
                        }
                    }
                }
                Err(err) => {
                    tracing::warn!(?err, "final batch embedding failed");
                }
            }
        }
        tracing::info!(
            frames_processed,
            embeddings_generated,
            "enrichment with embeddings complete"
        );
        Ok((frames_processed, embeddings_generated))
    }

    /// Returns `true` when vectors are enabled and a vector index is loaded.
    #[must_use]
    pub fn has_embeddings(&self) -> bool {
        self.vec_enabled && self.vec_index.is_some()
    }

    /// Vector count recorded in the TOC's vector-index manifest (0 if there is none).
    ///
    /// Saturates at `usize::MAX` instead of truncating on narrow targets.
    #[must_use]
    pub fn vector_count(&self) -> usize {
        self.toc
            .indexes
            .vec
            .as_ref()
            .map_or(0, |m| usize::try_from(m.vector_count).unwrap_or(usize::MAX))
    }
}
/// Snapshot of enrichment progress, as produced by `Memvid::enrichment_stats`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct EnrichmentStats {
    /// Number of active frames.
    pub total_frames: usize,
    /// Active frames whose enrichment state is `Enriched`.
    pub enriched_frames: usize,
    /// Number of tasks waiting in the enrichment queue.
    pub pending_frames: usize,
    /// Active frames not yet enriched (`total_frames - enriched_frames`).
    pub searchable_only: usize,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A hand-built stats snapshot: enriched plus searchable-only frames should account
    /// for every frame.
    #[test]
    fn test_enrichment_stats_default() {
        let EnrichmentStats {
            total_frames,
            enriched_frames,
            searchable_only,
            ..
        } = EnrichmentStats {
            total_frames: 100,
            enriched_frames: 50,
            pending_frames: 10,
            searchable_only: 50,
        };
        let accounted = enriched_frames + searchable_only;
        assert_eq!(accounted, total_frames);
    }
}