use std::sync::Arc;
use sha2::{Digest, Sha256};
use tokio::sync::RwLock;
use tokio_util::sync::CancellationToken;
use crate::context::semantic_index::SemanticIndex;
use crate::llm::embedding::{EmbeddingError, EmbeddingProvider};
use crate::storage::watchtower::embeddings;
use crate::storage::DbPool;
use super::super::scheduler::LoopScheduler;
pub const EMBEDDING_WORKER_INTERVAL_SECS: u64 = 5;
pub struct EmbeddingWorker {
pool: DbPool,
provider: Arc<dyn EmbeddingProvider>,
index: Arc<RwLock<SemanticIndex>>,
batch_size: u32,
account_id: String,
}
impl EmbeddingWorker {
pub fn new(
pool: DbPool,
provider: Arc<dyn EmbeddingProvider>,
index: Arc<RwLock<SemanticIndex>>,
batch_size: u32,
account_id: String,
) -> Self {
Self {
pool,
provider,
index,
batch_size,
account_id,
}
}
pub async fn run(&self, cancel: CancellationToken, scheduler: LoopScheduler) {
tracing::info!("Embedding worker started");
loop {
tokio::select! {
() = cancel.cancelled() => {
tracing::debug!("Embedding worker cancelled");
return;
}
() = scheduler.tick() => {}
}
if cancel.is_cancelled() {
return;
}
match self.process_dirty_batch().await {
Ok(count) => {
if count > 0 {
tracing::info!(
embedded = count,
account = %self.account_id,
"Embedding worker indexed chunks"
);
}
}
Err(EmbeddingError::RateLimited { retry_after_secs }) => {
tracing::warn!(
retry_after = retry_after_secs,
"Embedding provider rate limited, will retry"
);
}
Err(EmbeddingError::DimensionMismatch { expected, actual }) => {
tracing::error!(
expected,
actual,
"Embedding dimension mismatch — check provider config"
);
return;
}
Err(e) => {
tracing::warn!(error = %e, "Embedding worker batch failed, will retry");
}
}
}
}
async fn process_dirty_batch(&self) -> Result<u32, EmbeddingError> {
let dirty = embeddings::get_dirty_chunks_for(&self.pool, &self.account_id, self.batch_size)
.await
.map_err(|e| EmbeddingError::Internal(e.to_string()))?;
if dirty.is_empty() {
return Ok(0);
}
let texts: Vec<String> = dirty.iter().map(|c| c.chunk_text.clone()).collect();
let response = self.provider.embed(texts).await?;
if response.embeddings.len() != dirty.len() {
return Err(EmbeddingError::Internal(format!(
"provider returned {} embeddings for {} inputs",
response.embeddings.len(),
dirty.len()
)));
}
let mut count = 0u32;
for (chunk, embedding) in dirty.iter().zip(response.embeddings.iter()) {
let hash = {
let mut hasher = Sha256::new();
hasher.update(chunk.chunk_text.as_bytes());
hex::encode(hasher.finalize())
};
let bytes = embeddings::vec_to_bytes(embedding);
embeddings::upsert_chunk_embedding(
&self.pool,
chunk.chunk_id,
&self.account_id,
&bytes,
response.model.as_str(),
response.dimension as i64,
&hash,
1,
)
.await
.map_err(|e| EmbeddingError::Internal(e.to_string()))?;
if let Err(e) = self
.index
.write()
.await
.insert(chunk.chunk_id, embedding.clone())
{
tracing::warn!(
chunk_id = chunk.chunk_id,
error = %e,
"Failed to insert into in-memory index"
);
}
count += 1;
}
Ok(count)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::llm::embedding::{EmbeddingResponse, EmbeddingUsage};
use crate::storage::{init_test_db, watchtower};
use std::time::Duration;
struct MockProvider {
dimension: usize,
}
#[async_trait::async_trait]
impl EmbeddingProvider for MockProvider {
fn name(&self) -> &str {
"mock"
}
fn dimension(&self) -> usize {
self.dimension
}
fn model_id(&self) -> &str {
"mock-embed"
}
async fn embed(&self, inputs: Vec<String>) -> Result<EmbeddingResponse, EmbeddingError> {
let embeddings = inputs
.iter()
.map(|_| vec![0.1_f32; self.dimension])
.collect();
Ok(EmbeddingResponse {
embeddings,
model: "mock-embed".to_string(),
dimension: self.dimension,
usage: EmbeddingUsage { total_tokens: 5 },
})
}
async fn health_check(&self) -> Result<(), EmbeddingError> {
Ok(())
}
}
struct FailingProvider;
#[async_trait::async_trait]
impl EmbeddingProvider for FailingProvider {
fn name(&self) -> &str {
"failing"
}
fn dimension(&self) -> usize {
3
}
fn model_id(&self) -> &str {
"failing"
}
async fn embed(&self, _inputs: Vec<String>) -> Result<EmbeddingResponse, EmbeddingError> {
Err(EmbeddingError::Network("connection refused".to_string()))
}
async fn health_check(&self) -> Result<(), EmbeddingError> {
Err(EmbeddingError::Network("connection refused".to_string()))
}
}
async fn setup_with_chunk(pool: &DbPool) -> i64 {
let source_id = watchtower::insert_source_context(pool, "local_fs", "{}")
.await
.expect("insert source");
watchtower::upsert_content_node(
pool,
source_id,
"test.md",
"hash1",
Some("Test"),
"Content",
None,
None,
)
.await
.expect("upsert node");
watchtower::insert_chunk(pool, "default", 1, "# Test", "chunk text", "chash1", 0)
.await
.expect("insert chunk")
}
#[tokio::test]
async fn worker_processes_dirty_chunks() {
let pool = init_test_db().await.expect("init db");
let _chunk_id = setup_with_chunk(&pool).await;
let provider = Arc::new(MockProvider { dimension: 3 });
let index = Arc::new(RwLock::new(SemanticIndex::new(
3,
"mock-embed".to_string(),
100,
)));
let worker = EmbeddingWorker::new(
pool.clone(),
provider,
index.clone(),
10,
"default".to_string(),
);
let count = worker.process_dirty_batch().await.expect("should succeed");
assert_eq!(count, 1);
let rows = embeddings::get_all_embeddings_for(&pool, "default")
.await
.expect("get all");
assert_eq!(rows.len(), 1);
assert_eq!(index.read().await.len(), 1);
}
#[tokio::test]
async fn worker_handles_empty_dirty_set() {
let pool = init_test_db().await.expect("init db");
let provider = Arc::new(MockProvider { dimension: 3 });
let index = Arc::new(RwLock::new(SemanticIndex::new(
3,
"mock-embed".to_string(),
100,
)));
let worker = EmbeddingWorker::new(pool.clone(), provider, index, 10, "default".to_string());
let count = worker.process_dirty_batch().await.expect("should succeed");
assert_eq!(count, 0);
}
#[tokio::test]
async fn worker_handles_provider_error() {
let pool = init_test_db().await.expect("init db");
let _chunk_id = setup_with_chunk(&pool).await;
let provider = Arc::new(FailingProvider);
let index = Arc::new(RwLock::new(SemanticIndex::new(
3,
"failing".to_string(),
100,
)));
let worker = EmbeddingWorker::new(pool.clone(), provider, index, 10, "default".to_string());
let err = worker.process_dirty_batch().await.unwrap_err();
matches!(err, EmbeddingError::Network(_));
}
#[tokio::test]
async fn worker_idempotent_reprocessing() {
let pool = init_test_db().await.expect("init db");
let _chunk_id = setup_with_chunk(&pool).await;
let provider = Arc::new(MockProvider { dimension: 3 });
let index = Arc::new(RwLock::new(SemanticIndex::new(
3,
"mock-embed".to_string(),
100,
)));
let worker = EmbeddingWorker::new(
pool.clone(),
provider,
index.clone(),
10,
"default".to_string(),
);
worker.process_dirty_batch().await.expect("first");
let _count = worker.process_dirty_batch().await.expect("second");
let rows = embeddings::get_all_embeddings_for(&pool, "default")
.await
.expect("get all");
assert_eq!(rows.len(), 1);
assert_eq!(index.read().await.len(), 1);
}
struct MismatchCountProvider;
#[async_trait::async_trait]
impl EmbeddingProvider for MismatchCountProvider {
fn name(&self) -> &str {
"mismatch"
}
fn dimension(&self) -> usize {
3
}
fn model_id(&self) -> &str {
"mismatch"
}
async fn embed(&self, _inputs: Vec<String>) -> Result<EmbeddingResponse, EmbeddingError> {
Ok(EmbeddingResponse {
embeddings: vec![],
model: "mismatch".to_string(),
dimension: 3,
usage: EmbeddingUsage { total_tokens: 0 },
})
}
async fn health_check(&self) -> Result<(), EmbeddingError> {
Ok(())
}
}
#[tokio::test]
async fn worker_detects_embedding_count_mismatch() {
let pool = init_test_db().await.expect("init db");
let _chunk_id = setup_with_chunk(&pool).await;
let provider = Arc::new(MismatchCountProvider);
let index = Arc::new(RwLock::new(SemanticIndex::new(
3,
"mismatch".to_string(),
100,
)));
let worker = EmbeddingWorker::new(pool.clone(), provider, index, 10, "default".to_string());
let err = worker.process_dirty_batch().await.unwrap_err();
assert!(
matches!(err, EmbeddingError::Internal(ref msg) if msg.contains("0 embeddings for 1 inputs")),
"expected Internal mismatch error, got: {err}"
);
}
#[tokio::test]
async fn worker_respects_cancellation() {
let pool = init_test_db().await.expect("init db");
let provider = Arc::new(MockProvider { dimension: 3 });
let index = Arc::new(RwLock::new(SemanticIndex::new(
3,
"mock-embed".to_string(),
100,
)));
let worker = EmbeddingWorker::new(pool, provider, index, 10, "default".to_string());
let cancel = CancellationToken::new();
let scheduler = LoopScheduler::new(Duration::from_secs(60), Duration::ZERO, Duration::ZERO);
let cancel_clone = cancel.clone();
let handle = tokio::spawn(async move {
worker.run(cancel_clone, scheduler).await;
});
cancel.cancel();
tokio::time::timeout(Duration::from_secs(2), handle)
.await
.expect("should complete within timeout")
.expect("should not panic");
}
}