use super::vector::{build_query_vector, compare_score, score_embeddings_batch};
use super::ExecutablePredicate;
use super::Row;
/// A candidate row paired with its (optional) stored embedding, staged for
/// predicate-based reranking.
#[derive(Debug, Clone)]
pub(super) struct CandidateRow {
// The row under consideration; its `score` field is populated when the
// candidate passes the rerank predicate.
pub row: Row,
// Embedding fetched from storage. `None` (or an empty vector) means the
// candidate never receives a score and is dropped during reranking.
pub vector: Option<crate::features::storage::api::StructuredVector>,
}
/// Reranks `candidates` against `pred`, optionally fanning out across worker
/// threads in morsels of `morsel_size` rows.
///
/// Returns `(matching rows in original candidate order, number of morsels
/// processed, number of workers actually used)`. Errors from any morsel
/// propagate as `Err`.
pub(super) fn rerank_candidates(
candidates: Vec<CandidateRow>,
pred: &ExecutablePredicate,
morsel_size: usize,
requested_workers: usize,
) -> Result<(Vec<Row>, u64, usize), String> {
// `slice::chunks` panics on a chunk size of 0 — clamp to at least 1.
let morsel = morsel_size.max(1);
if candidates.is_empty() {
return Ok((Vec::new(), 0, 1));
}
let candidate_count = candidates.len();
let chunks: Vec<(usize, Vec<CandidateRow>)> = candidates
.chunks(morsel)
.enumerate()
.map(|(idx, chunk)| (idx, chunk.to_vec()))
.collect();
let workers = resolve_workers_adaptive(requested_workers, chunks.len(), candidate_count, morsel);
let batches = chunks.len() as u64;
if workers <= 1 {
// Serial fast path: process morsels in order, no threads or channels.
let mut rows = Vec::new();
for (_, chunk) in chunks {
rows.extend(process_chunk(chunk, pred)?);
}
return Ok((rows, batches, 1));
}
// Round-robin partition that MOVES each chunk to exactly one worker,
// instead of cloning the full chunk list per worker (the previous
// skip/step_by striping also cloned chunks it then discarded).
let mut assignments: Vec<Vec<(usize, Vec<CandidateRow>)>> =
(0..workers).map(|_| Vec::new()).collect();
for (idx, chunk) in chunks {
assignments[idx % workers].push((idx, chunk));
}
let (tx, rx) = std::sync::mpsc::channel::<(usize, Result<Vec<Row>, String>)>();
// Scoped threads: the scope blocks until every worker finishes, so all
// results are buffered in the channel by the time we read from `rx`.
std::thread::scope(|scope| {
for assigned in assignments {
let tx = tx.clone();
let pred = pred.clone();
scope.spawn(move || {
for (idx, chunk) in assigned {
// A closed receiver only happens on early error return;
// ignore the send failure in that case.
let _ = tx.send((idx, process_chunk(chunk, &pred)));
}
});
}
});
// Drop the original sender so the receiver iteration terminates once the
// buffered results are drained.
drop(tx);
let mut collected: Vec<(usize, Vec<Row>)> = Vec::new();
for (idx, result) in rx {
collected.push((idx, result?));
}
// Chunk indices are unique, so an unstable sort is safe and faster.
collected.sort_unstable_by_key(|(idx, _)| *idx);
let mut rows = Vec::new();
for (_, chunk_rows) in collected {
rows.extend(chunk_rows);
}
Ok((rows, batches, workers))
}
/// Processes one morsel of candidates: empty morsels short-circuit, everything
/// else is handed to the embedding-based rerank path.
pub(super) fn process_chunk(
chunk: Vec<CandidateRow>,
pred: &ExecutablePredicate,
) -> Result<Vec<Row>, String> {
match chunk.len() {
0 => Ok(Vec::new()),
_ => rerank_chunk_with_embeddings(chunk, pred),
}
}
/// Scores every candidate's embedding against the predicate's query vector and
/// returns the rows whose score satisfies `pred.operator` / `pred.threshold`,
/// in original chunk order.
///
/// Candidates with no vector (or an empty one) receive no score and are
/// filtered out. Embeddings are grouped by dimensionality so each group is
/// scored in one batched call.
fn rerank_chunk_with_embeddings(
chunk: Vec<CandidateRow>,
pred: &ExecutablePredicate,
) -> Result<Vec<Row>, String> {
// Per-candidate score slot, indexed by position in `chunk`.
let mut scores = vec![None; chunk.len()];
// dim -> (chunk indices in that group, flattened embeddings).
// BTreeMap gives a deterministic processing order over dims.
let mut grouped: std::collections::BTreeMap<usize, (Vec<usize>, Vec<f32>)> =
std::collections::BTreeMap::new();
for (idx, candidate) in chunk.iter().enumerate() {
if let Some(vector) = candidate.vector.as_ref() {
let dim = vector.values.len();
if dim == 0 {
continue;
}
let (indices, flat) = grouped
.entry(dim)
.or_insert_with(|| (Vec::new(), Vec::new()));
indices.push(idx);
flat.extend_from_slice(&vector.values);
}
}
for (dim, (indices, flat_embeddings)) in grouped {
// Each dim appears exactly once as a BTreeMap key, so the query vector
// is built once per dim — no cross-iteration cache is needed (the old
// HashMap cache here could never hit).
let query = build_query_vector(pred, dim)?;
let batch_scores = score_embeddings_batch(pred.metric, &query, &flat_embeddings, dim);
for (local_idx, score) in batch_scores.into_iter().enumerate() {
if let Some(chunk_idx) = indices.get(local_idx).copied() {
scores[chunk_idx] = Some(score);
}
}
}
let mut out = Vec::new();
for (idx, mut candidate) in chunk.into_iter().enumerate() {
// Unscored candidates (no vector / empty vector) are dropped.
let Some(score) = scores[idx] else {
continue;
};
if compare_score(score, &pred.operator, pred.threshold) {
candidate.row.score = Some(score);
out.push(candidate.row);
}
}
Ok(out)
}
/// Resolves the worker count: the env kill switch forces serial mode, an
/// explicit request (> 0) is honored verbatim, and 0 means "auto" — use the
/// machine's available parallelism (falling back to 1 if it can't be queried).
fn resolve_workers(requested_workers: usize) -> usize {
if !internal_parallelism_enabled() {
return 1;
}
match requested_workers {
0 => std::thread::available_parallelism().map_or(1, |n| n.get()),
explicit => explicit,
}
}
/// Adapts the resolved worker count to the workload size.
///
/// In auto mode (`requested_workers == 0`), small workloads — at most four
/// morsels' worth of candidates — stay serial so thread startup doesn't
/// outweigh the work. Otherwise the count is capped by the number of chunks
/// and floored at 1.
fn resolve_workers_adaptive(
requested_workers: usize,
chunk_count: usize,
candidate_count: usize,
morsel_size: usize,
) -> usize {
let auto_mode = requested_workers == 0;
if auto_mode && candidate_count <= morsel_size.saturating_mul(4) {
return 1;
}
resolve_workers(requested_workers)
.min(chunk_count.max(1))
.max(1)
}
/// Reports whether intra-operator parallelism is allowed, based on env vars.
///
/// `IR_THREAD_PER_CORE` set to `1`/`true`/`on` disables it (the thread-per-core
/// model owns the cores); `IR_INTERNAL_PARALLELISM` set to `0`/`false`/`off`
/// is an explicit opt-out. Values are trimmed and case-insensitive.
fn internal_parallelism_enabled() -> bool {
// Normalized (trimmed, lowercased) value of an env var, if it is set.
let normalized =
|key: &str| std::env::var(key).ok().map(|v| v.trim().to_ascii_lowercase());
if let Some(mode) = normalized("IR_THREAD_PER_CORE") {
if matches!(mode.as_str(), "1" | "true" | "on") {
return false;
}
}
if let Some(flag) = normalized("IR_INTERNAL_PARALLELISM") {
if matches!(flag.as_str(), "0" | "false" | "off") {
return false;
}
}
true
}