use std::collections::HashMap;
use std::path::Path;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::mpsc::{Receiver, Sender};
use tracing::{debug, error, info, warn};
use crate::indexer::semantic::{
EmbeddingInput, SemanticIndexer, message_id_from_db, saturating_u32_from_i64,
};
use crate::search::canonicalize::{canonicalize_for_embedding, content_hash};
use crate::search::fastembed_embedder::FastEmbedder;
use crate::search::vector_index::{
VectorIndex, parse_semantic_doc_id, role_code_from_str, vector_index_path,
};
use crate::storage::sqlite::FrankenStorage;
/// Model name that selects the hash embedder; also the fallback for the fast pass and for single-pass jobs.
const HASH_EMBEDDER_MODEL: &str = "hash";
/// Default quality-pass model for two-tier jobs, and the canonical name the MiniLM aliases resolve to.
const DEFAULT_SEMANTIC_MODEL: &str = "minilm";
/// Parameters for one embedding job submitted to the worker.
#[derive(Debug, Clone)]
pub struct EmbeddingJobConfig {
/// Path of the database opened with `FrankenStorage::open`; also passed as the key when registering job rows.
pub db_path: String,
/// Base path handed to `vector_index_path` and the indexer when locating or saving vector indexes.
pub index_path: String,
/// When `true` the job runs two passes (fast, then quality); otherwise a single pass.
pub two_tier: bool,
/// Model name for the fast pass; `None` falls back to `"hash"`.
pub fast_model: Option<String>,
/// Model name for the quality pass; `None` falls back to `"minilm"` in two-tier mode.
/// In single-pass mode this is preferred over `fast_model` when both are set.
pub quality_model: Option<String>,
}
impl EmbeddingJobConfig {
    /// Model name for the first (fast) pass of a two-tier job.
    ///
    /// Uses the configured `fast_model`, or the hash embedder when none is set.
    fn fast_pass_model(&self) -> String {
        match &self.fast_model {
            Some(name) => name.clone(),
            None => HASH_EMBEDDER_MODEL.to_string(),
        }
    }

    /// Model name for the second (quality) pass of a two-tier job.
    ///
    /// Uses the configured `quality_model`, or the default semantic model when none is set.
    fn quality_pass_model(&self) -> String {
        match &self.quality_model {
            Some(name) => name.clone(),
            None => DEFAULT_SEMANTIC_MODEL.to_string(),
        }
    }

    /// Model name for a single-pass job.
    ///
    /// Preference order: `quality_model`, then `fast_model`, then the hash embedder.
    fn single_pass_model(&self) -> String {
        let configured = self.quality_model.as_ref().or(self.fast_model.as_ref());
        match configured {
            Some(name) => name.clone(),
            None => HASH_EMBEDDER_MODEL.to_string(),
        }
    }
}
/// Messages sent from an `EmbeddingWorkerHandle` to the worker loop.
#[derive(Debug)]
pub enum WorkerMessage {
/// Run an embedding job with the given configuration.
Submit(EmbeddingJobConfig),
/// Cancel embedding jobs recorded in the database at `db_path`.
Cancel {
/// Database whose job records should be cancelled.
db_path: String,
/// Model filter forwarded to `cancel_embedding_jobs`; presumably `None` means every model — TODO confirm.
model_id: Option<String>,
},
/// Stop the worker loop.
Shutdown,
}
/// Cloneable handle for sending work and control messages to an `EmbeddingWorker`.
#[derive(Clone)]
pub struct EmbeddingWorkerHandle {
/// Sending half of the channel the worker loop receives from.
sender: Sender<WorkerMessage>,
/// Cancellation flag shared with the worker: set by `cancel`, cleared by the worker when a job starts.
cancel_flag: Arc<AtomicBool>,
}
impl EmbeddingWorkerHandle {
pub fn submit(&self, config: EmbeddingJobConfig) -> Result<(), String> {
self.sender
.send(WorkerMessage::Submit(config))
.map_err(|e| format!("worker channel closed: {e}"))
}
pub fn cancel(&self, db_path: String, model_id: Option<String>) -> Result<(), String> {
self.cancel_flag.store(true, Ordering::SeqCst);
self.sender
.send(WorkerMessage::Cancel { db_path, model_id })
.map_err(|e| format!("worker channel closed: {e}"))
}
pub fn shutdown(&self) -> Result<(), String> {
self.sender
.send(WorkerMessage::Shutdown)
.map_err(|e| format!("worker channel closed: {e}"))
}
}
/// Receiving side of the embedding pipeline: consumes `WorkerMessage`s and runs jobs via `run`.
pub struct EmbeddingWorker {
/// Receiving half of the channel fed by `EmbeddingWorkerHandle`.
receiver: Receiver<WorkerMessage>,
/// Cancellation flag shared with the handles; checked between passes and before each message while a job runs.
cancel_flag: Arc<AtomicBool>,
}
/// Embedder selected for one pass, as resolved by `resolve_embedder_kind`.
#[derive(Debug, Clone, PartialEq, Eq)]
enum WorkerEmbedderKind {
/// The hash embedder; its existing index is looked up under the id `"fnv1a-384"`.
Hash,
/// A FastEmbed-backed semantic model.
FastEmbed {
/// Canonical model name (after alias normalization) passed to `SemanticIndexer::new`.
model_name: String,
/// Embedder id taken from `FastEmbedder::config_for`; used to locate the existing vector index.
embedder_id: String,
},
}
/// Maps a configured model name to the embedder the worker should use.
///
/// Returns `WorkerEmbedderKind::Hash` whenever `use_semantic` is `false`, or when
/// `model_name` is `"hash"` or `"fnv1a-384"` (ASCII case-insensitive). Otherwise the
/// name is matched case-insensitively against the known aliases of the supported
/// FastEmbed models and resolved to its canonical name plus the embedder id reported
/// by `FastEmbedder::config_for`.
///
/// # Errors
/// Fails when a semantic model name is outside the supported alias set, or when
/// `FastEmbedder` has no config for the canonical name.
fn resolve_embedder_kind(
    model_name: &str,
    use_semantic: bool,
) -> anyhow::Result<WorkerEmbedderKind> {
    let names_hash_embedder = [HASH_EMBEDDER_MODEL, "fnv1a-384"]
        .iter()
        .any(|alias| model_name.eq_ignore_ascii_case(alias));
    if names_hash_embedder || !use_semantic {
        return Ok(WorkerEmbedderKind::Hash);
    }

    // Collapse every accepted alias onto the canonical name FastEmbedder knows.
    let lowered = model_name.to_ascii_lowercase();
    let normalized_name = match lowered.as_str() {
        "fastembed" | "minilm" | "minilm-384" | "all-minilm-l6-v2" => DEFAULT_SEMANTIC_MODEL,
        "snowflake-arctic-s" | "snowflake-arctic-s-384" | "snowflake-arctic-embed-s" => {
            "snowflake-arctic-s"
        }
        "nomic-embed" | "nomic-embed-768" | "nomic-embed-text-v1.5" => "nomic-embed",
        _ => {
            return Err(anyhow::anyhow!(
                "unsupported semantic model '{model_name}' for daemon embedding worker; supported: minilm, snowflake-arctic-s, nomic-embed"
            ));
        }
    };

    let Some(config) = FastEmbedder::config_for(normalized_name) else {
        return Err(anyhow::anyhow!(
            "missing FastEmbedder config for registered model '{normalized_name}'"
        ));
    };
    Ok(WorkerEmbedderKind::FastEmbed {
        model_name: normalized_name.to_string(),
        embedder_id: config.embedder_id,
    })
}
/// Converts a `usize` to `i64`, clamping to `i64::MAX` when the value does not fit.
fn saturating_i64_from_usize(raw: usize) -> i64 {
    match i64::try_from(raw) {
        Ok(value) => value,
        Err(_) => i64::MAX,
    }
}
impl EmbeddingWorker {
/// Creates a worker together with the handle used to drive it.
///
/// The two halves are connected by a `std::sync::mpsc` channel and share one
/// cancellation flag, which starts out `false`.
pub fn new() -> (Self, EmbeddingWorkerHandle) {
    let cancel_flag = Arc::new(AtomicBool::new(false));
    let (sender, receiver) = std::sync::mpsc::channel();
    let handle = EmbeddingWorkerHandle {
        sender,
        cancel_flag: Arc::clone(&cancel_flag),
    };
    (
        Self {
            receiver,
            cancel_flag,
        },
        handle,
    )
}
/// Runs the worker loop until a `Shutdown` message arrives or every sender is dropped.
///
/// Messages are handled one at a time, in the order they were sent:
/// - `Submit` clears the cancel flag and runs the job; a job error is logged, not propagated.
/// - `Cancel` marks the jobs cancelled in the database; a failure there is logged as a warning.
/// - `Shutdown` leaves the loop.
///
/// NOTE(review): because the flag is cleared when a job is dequeued, a cancel issued
/// while that job was still waiting in the channel does not interrupt it; only the
/// database-side cancel runs afterwards.
pub fn run(self) {
    info!("Embedding worker started");
    // `iter()` blocks for the next message and ends once the channel hangs up.
    for message in self.receiver.iter() {
        match message {
            WorkerMessage::Shutdown => {
                info!("Embedding worker shutting down");
                break;
            }
            WorkerMessage::Cancel { db_path, model_id } => {
                info!(%db_path, ?model_id, "Processing cancel — flag already set by handle");
                let outcome = Self::cancel_in_db(&db_path, model_id.as_deref());
                if let Err(e) = outcome {
                    warn!(%db_path, error = %e, "Failed to cancel jobs in database");
                }
            }
            WorkerMessage::Submit(config) => {
                // Every job starts with a cleared flag so an earlier cancel does not leak into it.
                self.cancel_flag.store(false, Ordering::SeqCst);
                info!(db_path = %config.db_path, two_tier = config.two_tier, "Processing embedding job");
                let outcome = self.process_job(&config);
                if let Err(e) = outcome {
                    error!(db_path = %config.db_path, error = %e, "Embedding job failed");
                }
            }
        }
    }
    info!("Embedding worker stopped");
}
/// Opens the database at `db_path` and cancels its embedding jobs.
///
/// `model_id` is forwarded unchanged to `cancel_embedding_jobs`.
///
/// # Errors
/// Propagates failures from opening the database or from the cancel call.
fn cancel_in_db(db_path: &str, model_id: Option<&str>) -> anyhow::Result<()> {
    FrankenStorage::open(Path::new(db_path))?.cancel_embedding_jobs(db_path, model_id)?;
    Ok(())
}
/// Runs every embedding pass for one job.
///
/// Loads the messages to embed from the job's database once, then executes the passes
/// chosen by `build_passes` in order. Each pass is registered, started and finished as
/// its own job record in storage.
///
/// A pass that returns an error is recorded with `fail_embedding_job` and logged, and
/// the remaining passes still run. This includes a pass that stops because the cancel
/// flag was raised mid-scan. The flag is also checked before each pass; when set, the
/// function returns `Ok(())` without starting further passes.
///
/// # Errors
/// Propagates failures from opening the database, fetching messages, and the job
/// bookkeeping calls (`upsert`, `start`, `complete`, `fail`).
fn process_job(&self, config: &EmbeddingJobConfig) -> anyhow::Result<()> {
    let storage = FrankenStorage::open(Path::new(&config.db_path))?;
    let messages = storage.fetch_messages_for_embedding()?;
    let total_docs = saturating_i64_from_usize(messages.len());
    if total_docs == 0 {
        info!(db_path = %config.db_path, "No messages to embed");
        return Ok(());
    }
    info!(
        db_path = %config.db_path,
        total_docs,
        two_tier = config.two_tier,
        "Found messages to embed"
    );

    let index_path = Path::new(&config.index_path);
    let passes = self.build_passes(config);
    for (model_name, use_semantic) in &passes {
        if self.cancel_flag.load(Ordering::SeqCst) {
            info!("Embedding job cancelled");
            return Ok(());
        }

        let job_id = storage.upsert_embedding_job(&config.db_path, model_name, total_docs)?;
        storage.start_embedding_job(job_id)?;

        let outcome = self.generate_embeddings_and_save(
            &storage,
            &messages,
            model_name,
            *use_semantic,
            job_id,
            index_path,
        );
        if let Err(e) = outcome {
            // Record the failure but keep going: later passes are independent of this one.
            storage.fail_embedding_job(job_id, &format!("{e:#}"))?;
            warn!(model = model_name, error = %e, "Embedding pass failed");
            continue;
        }
        storage.complete_embedding_job(job_id)?;
        info!(model = model_name, "Embedding pass completed");
    }
    Ok(())
}
/// Lists the passes to run for `config` as `(model_name, use_semantic)` pairs.
///
/// Two-tier jobs get a non-semantic fast pass followed by a semantic quality pass.
/// Single-pass jobs get one pass, flagged semantic unless the model name is exactly
/// `"hash"`.
fn build_passes(&self, config: &EmbeddingJobConfig) -> Vec<(String, bool)> {
    if config.two_tier {
        return vec![
            (config.fast_pass_model(), false),
            (config.quality_pass_model(), true),
        ];
    }
    let model = config.single_pass_model();
    let is_semantic = model != HASH_EMBEDDER_MODEL;
    vec![(model, is_semantic)]
}
/// Embeds the new or changed messages for one pass and writes them to the vector index.
///
/// Steps:
/// 1. Resolve the embedder for `model_name` / `use_semantic`.
/// 2. Load the content hashes already present in that embedder's index.
/// 3. Scan `messages`, keeping only those that need embedding. Messages with empty
///    canonical content or an out-of-range id are dropped; messages whose stored hash
///    matches are counted as skipped.
/// 4. Embed the kept messages, then append to the existing index file or build a new one.
///
/// Progress is written to storage every 100 scanned messages and once more with the
/// full message count. Those writes are best-effort: a failed progress update never
/// aborts the pass.
///
/// # Parameters
/// - `storage`: open database used for progress updates.
/// - `messages`: every candidate message for the job.
/// - `model_name`, `use_semantic`: pass definition, as produced by `build_passes`.
/// - `job_id`: job record whose progress is updated.
/// - `index_path`: base path of the vector indexes.
///
/// # Errors
/// Returns `"job cancelled"` when the cancel flag is observed during the scan, and
/// propagates failures from embedder resolution, indexer construction, embedding, and
/// writing the index.
fn generate_embeddings_and_save(
    &self,
    storage: &FrankenStorage,
    messages: &[crate::storage::sqlite::MessageForEmbedding],
    model_name: &str,
    use_semantic: bool,
    job_id: i64,
    index_path: &Path,
) -> anyhow::Result<()> {
    let embedder_kind = resolve_embedder_kind(model_name, use_semantic)?;
    let existing_hashes = self.load_existing_hashes(index_path, &embedder_kind);
    let mut inputs: Vec<EmbeddingInput> = Vec::new();
    let mut skipped_count = 0usize;
    let mut completed = 0i64;
    for msg in messages {
        if self.cancel_flag.load(Ordering::SeqCst) {
            return Err(anyhow::anyhow!("job cancelled"));
        }

        // Decide what to do with this message. Every outcome falls through to the
        // shared bookkeeping below, so dropped and unchanged messages still advance
        // the reported progress.
        let prepared: Option<EmbeddingInput> = 'prepare: {
            let canonical = canonicalize_for_embedding(&msg.content);
            if canonical.is_empty() {
                break 'prepare None;
            }
            let hash = content_hash(&canonical);
            let Some(message_id) = message_id_from_db(msg.message_id) else {
                warn!(
                    raw_message_id = msg.message_id,
                    "Skipping message with out-of-range id during embedding"
                );
                break 'prepare None;
            };
            if let Some(existing_hash) = existing_hashes.get(&message_id)
                && *existing_hash == hash
            {
                // Same content as what is already indexed: nothing to re-embed.
                skipped_count += 1;
                break 'prepare None;
            }
            Some(EmbeddingInput {
                message_id,
                created_at_ms: msg.created_at.unwrap_or(0),
                agent_id: saturating_u32_from_i64(msg.agent_id),
                workspace_id: saturating_u32_from_i64(msg.workspace_id.unwrap_or(0)),
                source_id: msg.source_id_hash,
                role: role_code_from_str(&msg.role).unwrap_or(0),
                chunk_idx: 0,
                content: canonical,
            })
        };
        if let Some(input) = prepared {
            inputs.push(input);
        }

        completed += 1;
        if completed % 100 == 0 {
            // Best-effort: progress is advisory and must not fail the pass.
            let _ = storage.update_job_progress(job_id, completed);
            debug!(job_id, completed, "Embedding progress");
        }
    }

    if inputs.is_empty() {
        let final_completed = saturating_i64_from_usize(messages.len());
        let _ = storage.update_job_progress(job_id, final_completed);
        info!(
            model = model_name,
            skipped = skipped_count,
            "No documents to embed - all unchanged"
        );
        return Ok(());
    }

    info!(
        model = model_name,
        input_count = inputs.len(),
        skipped = skipped_count,
        "Embedding documents"
    );
    let indexer = match embedder_kind {
        WorkerEmbedderKind::Hash => SemanticIndexer::new(HASH_EMBEDDER_MODEL, None)?,
        WorkerEmbedderKind::FastEmbed { ref model_name, .. } => {
            SemanticIndexer::new(model_name, Some(index_path))?
        }
    };
    let embedded = indexer.embed_messages(&inputs)?;
    let final_completed = saturating_i64_from_usize(messages.len());
    let _ = storage.update_job_progress(job_id, final_completed);

    // Append when this embedder already has an index file; otherwise create it.
    let save_path = vector_index_path(index_path, indexer.embedder_id());
    if save_path.exists() {
        let appended = indexer.append_to_index(embedded, index_path)?;
        info!(appended, "Appended to existing vector index");
    } else {
        let _index = indexer.build_and_save_index(embedded, index_path)?;
    }
    info!(
        model = model_name,
        path = %save_path.display(),
        count = inputs.len(),
        "Saved vector index"
    );
    Ok(())
}
/// Reads the content hashes already stored in the vector index for `embedder_kind`.
///
/// The returned map goes from message id to the content hash parsed out of each
/// record's document id. Records whose id cannot be read or parsed, or that carry no
/// hash, are left out. A missing index file yields an empty map; an index that fails
/// to open also yields an empty map and is logged as a warning.
fn load_existing_hashes(
    &self,
    index_path: &Path,
    embedder_kind: &WorkerEmbedderKind,
) -> HashMap<u64, [u8; 32]> {
    let embedder_id = match embedder_kind {
        WorkerEmbedderKind::FastEmbed { embedder_id, .. } => embedder_id.as_str(),
        WorkerEmbedderKind::Hash => "fnv1a-384",
    };
    let fsvi_path = vector_index_path(index_path, embedder_id);
    if !fsvi_path.exists() {
        return HashMap::new();
    }

    let index = match VectorIndex::open(&fsvi_path) {
        Ok(index) => index,
        Err(e) => {
            warn!(
                path = %fsvi_path.display(),
                error = %e,
                "Failed to load existing index for dedup"
            );
            return HashMap::new();
        }
    };

    let mut hashes = HashMap::new();
    for idx in 0..index.record_count() {
        let Ok(doc_id_str) = index.doc_id_at(idx) else {
            continue;
        };
        let Some(parsed) = parse_semantic_doc_id(doc_id_str) else {
            continue;
        };
        if let Some(hash) = parsed.content_hash {
            hashes.insert(parsed.message_id, hash);
        }
    }
    debug!(
        path = %fsvi_path.display(),
        count = hashes.len(),
        "Loaded existing hashes for dedup"
    );
    hashes
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a job config with empty paths so tests can focus on model selection.
    fn build_pass_config(
        two_tier: bool,
        fast_model: Option<&str>,
        quality_model: Option<&str>,
    ) -> EmbeddingJobConfig {
        EmbeddingJobConfig {
            db_path: String::new(),
            index_path: String::new(),
            two_tier,
            fast_model: fast_model.map(str::to_string),
            quality_model: quality_model.map(str::to_string),
        }
    }

    /// Shorthand for the expected `FastEmbed` variant.
    fn fast_embed_kind(model_name: &str, embedder_id: &str) -> WorkerEmbedderKind {
        WorkerEmbedderKind::FastEmbed {
            model_name: model_name.to_string(),
            embedder_id: embedder_id.to_string(),
        }
    }

    #[test]
    fn test_worker_handle_clone() {
        // Keep the worker alive so the channel stays open for both handles.
        let (_worker, handle) = EmbeddingWorker::new();
        let handle2 = handle.clone();
        assert!(handle.shutdown().is_ok());
        // The clone must feed the same open channel as the original.
        assert!(handle2.shutdown().is_ok());
    }

    #[test]
    fn test_job_config() {
        let config = EmbeddingJobConfig {
            db_path: "/tmp/test.db".to_string(),
            index_path: "/tmp/test_index".to_string(),
            two_tier: true,
            fast_model: Some("hash".to_string()),
            quality_model: Some("minilm".to_string()),
        };
        assert!(config.two_tier);
        assert_eq!(config.fast_model.as_deref(), Some("hash"));
        assert_eq!(config.quality_model.as_deref(), Some("minilm"));
    }

    #[test]
    fn test_build_passes_single() {
        let (worker, _handle) = EmbeddingWorker::new();
        let config = build_pass_config(false, None, Some("minilm"));
        let passes = worker.build_passes(&config);
        assert_eq!(passes.len(), 1);
        assert_eq!(passes[0].0, "minilm");
        assert!(passes[0].1);
    }

    #[test]
    fn test_build_passes_two_tier() {
        let (worker, _handle) = EmbeddingWorker::new();
        let config = build_pass_config(true, Some("hash"), Some("minilm"));
        let passes = worker.build_passes(&config);
        assert_eq!(passes.len(), 2);
        assert_eq!(passes[0].0, "hash");
        assert!(!passes[0].1);
        assert_eq!(passes[1].0, "minilm");
        assert!(passes[1].1);
    }

    #[test]
    fn test_build_passes_defaults() {
        let (worker, _handle) = EmbeddingWorker::new();
        let config = build_pass_config(false, None, None);
        let passes = worker.build_passes(&config);
        assert_eq!(passes.len(), 1);
        assert_eq!(passes[0].0, "hash");
        assert!(!passes[0].1);
    }

    #[test]
    fn test_message_id_from_db_rejects_negative_ids() {
        assert_eq!(message_id_from_db(-1), None);
        assert_eq!(message_id_from_db(0), Some(0));
        assert_eq!(message_id_from_db(42), Some(42));
    }

    #[test]
    fn test_saturating_u32_from_i64_clamps_bounds() {
        assert_eq!(saturating_u32_from_i64(-7), 0);
        assert_eq!(saturating_u32_from_i64(0), 0);
        assert_eq!(saturating_u32_from_i64(7), 7);
        assert_eq!(saturating_u32_from_i64(i64::from(u32::MAX) + 123), u32::MAX);
    }

    #[test]
    fn test_saturating_i64_from_usize_clamps_overflow() {
        assert_eq!(saturating_i64_from_usize(0), 0);
        assert_eq!(saturating_i64_from_usize(7), 7);
        // Written via `try_from` so the expectation also holds on targets where
        // `usize::MAX` fits in an `i64`.
        assert_eq!(
            saturating_i64_from_usize(usize::MAX),
            i64::try_from(usize::MAX).unwrap_or(i64::MAX)
        );
    }

    #[test]
    fn test_resolve_embedder_kind_hash_aliases() {
        assert_eq!(
            resolve_embedder_kind("hash", false).unwrap(),
            WorkerEmbedderKind::Hash
        );
        assert_eq!(
            resolve_embedder_kind("FNV1A-384", true).unwrap(),
            WorkerEmbedderKind::Hash
        );
    }

    #[test]
    fn test_resolve_embedder_kind_use_semantic_false_short_circuits_regardless_of_name() {
        for semantic_name in [
            "minilm",
            "minilm-384",
            "all-minilm-l6-v2",
            "fastembed",
            "snowflake-arctic-s",
            "snowflake-arctic-embed-s",
            "nomic-embed",
            "nomic-embed-text-v1.5",
            "MINILM",
        ] {
            assert_eq!(
                resolve_embedder_kind(semantic_name, false).unwrap(),
                WorkerEmbedderKind::Hash,
                "use_semantic=false MUST short-circuit to Hash regardless of model_name; \
                 regression on name {semantic_name:?} indicates the !use_semantic branch \
                 was bypassed"
            );
        }
    }

    #[test]
    fn test_resolve_embedder_kind_semantic_aliases() {
        assert_eq!(
            resolve_embedder_kind("minilm", true).unwrap(),
            fast_embed_kind("minilm", "minilm-384")
        );
        assert_eq!(
            resolve_embedder_kind("MINILM-384", true).unwrap(),
            fast_embed_kind("minilm", "minilm-384")
        );
        assert_eq!(
            resolve_embedder_kind("fastembed", true).unwrap(),
            fast_embed_kind("minilm", "minilm-384")
        );
    }

    #[test]
    fn test_resolve_embedder_kind_registered_fastembed_models() {
        assert_eq!(
            resolve_embedder_kind("snowflake-arctic-s", true).unwrap(),
            fast_embed_kind("snowflake-arctic-s", "snowflake-arctic-s-384")
        );
        assert_eq!(
            resolve_embedder_kind("nomic-embed-text-v1.5", true).unwrap(),
            fast_embed_kind("nomic-embed", "nomic-embed-768")
        );
    }

    #[test]
    fn test_resolve_embedder_kind_rejects_unknown_semantic_model() {
        let err = resolve_embedder_kind("e5-large", true).unwrap_err();
        let msg = format!("{err:#}");
        assert!(msg.contains("unsupported semantic model"));
    }
}