//! iridium-db 0.2.0
//!
//! A high-performance vector-graph hybrid storage and indexing engine.
use super::vector::{build_query_vector, compare_score, score_embeddings_batch};
use super::ExecutablePredicate;
use super::Row;

/// A row paired with its optional embedding, staged for vector reranking.
#[derive(Debug, Clone)]
pub(super) struct CandidateRow {
    // The row emitted (with its score set) if it passes the predicate.
    pub row: Row,
    // Embedding used for scoring. Candidates with no vector, or with an
    // empty (zero-length) vector, never receive a score and are dropped
    // during reranking.
    pub vector: Option<crate::features::storage::api::StructuredVector>,
}

/// Scores and filters `candidates` against `pred`, optionally fanning the
/// work out over worker threads in morsel-sized batches.
///
/// Returns `(matching rows in original candidate order, batch count,
/// worker count used)`, or the first error produced by any batch.
pub(super) fn rerank_candidates(
    candidates: Vec<CandidateRow>,
    pred: &ExecutablePredicate,
    morsel_size: usize,
    requested_workers: usize,
) -> Result<(Vec<Row>, u64, usize), String> {
    // Clamp a zero morsel size to 1: `slice::chunks(0)` would panic, and a
    // one-row morsel is the closest well-defined interpretation.
    let morsel = morsel_size.max(1);
    let candidate_count = candidates.len();

    // Partition the owned candidates into indexed morsels by moving them,
    // instead of `chunks().to_vec()` which cloned every row.
    let mut chunks: Vec<(usize, Vec<CandidateRow>)> = Vec::new();
    let mut pending = candidates.into_iter().peekable();
    while pending.peek().is_some() {
        let chunk: Vec<CandidateRow> = pending.by_ref().take(morsel).collect();
        chunks.push((chunks.len(), chunk));
    }
    if chunks.is_empty() {
        return Ok((Vec::new(), 0, 1));
    }

    let workers =
        resolve_workers_adaptive(requested_workers, chunks.len(), candidate_count, morsel);
    let batches = chunks.len() as u64;

    // Serial fast path: no channel or thread overhead.
    if workers <= 1 {
        let mut rows = Vec::new();
        for (_, chunk) in chunks {
            rows.extend(process_chunk(chunk, pred)?);
        }
        return Ok((rows, batches, 1));
    }

    // Round-robin the owned morsels onto workers up front so each thread
    // takes ownership of its share. The original `skip/cloned/step_by`
    // cloned every chunk (including the ones step_by discarded).
    let mut assignments: Vec<Vec<(usize, Vec<CandidateRow>)>> =
        (0..workers).map(|_| Vec::new()).collect();
    for entry in chunks {
        let slot = entry.0 % workers;
        assignments[slot].push(entry);
    }

    let (tx, rx) = std::sync::mpsc::channel::<(usize, Result<Vec<Row>, String>)>();
    std::thread::scope(|scope| {
        for assigned in assignments {
            let tx = tx.clone();
            let pred = pred.clone();
            scope.spawn(move || {
                for (idx, chunk) in assigned {
                    let out = process_chunk(chunk, &pred);
                    // Receiver outlives the scope, so send cannot fail here;
                    // ignore the result defensively.
                    let _ = tx.send((idx, out));
                }
            });
        }
    });
    // All scoped threads have joined; dropping the last sender lets the
    // receive loop below terminate after draining the buffered results.
    drop(tx);

    let mut collected: Vec<(usize, Vec<Row>)> = Vec::new();
    for (idx, result) in rx {
        collected.push((idx, result?));
    }
    // Restore original candidate order across interleaved worker outputs.
    collected.sort_by_key(|(idx, _)| *idx);
    let mut rows = Vec::new();
    for (_, chunk_rows) in collected {
        rows.extend(chunk_rows);
    }
    Ok((rows, batches, workers))
}

/// Reranks one morsel of candidates against `pred`, returning the rows
/// whose embedding score satisfies the predicate.
pub(super) fn process_chunk(
    chunk: Vec<CandidateRow>,
    pred: &ExecutablePredicate,
) -> Result<Vec<Row>, String> {
    // An empty morsel trivially produces no rows.
    if chunk.is_empty() {
        Ok(Vec::new())
    } else {
        rerank_chunk_with_embeddings(chunk, pred)
    }
}

/// Scores every candidate embedding in `chunk` against the predicate's
/// query vector and keeps the rows whose score passes the threshold.
///
/// Candidates with no vector, or an empty vector, never receive a score
/// and are silently dropped. Errors from `build_query_vector` propagate.
fn rerank_chunk_with_embeddings(
    chunk: Vec<CandidateRow>,
    pred: &ExecutablePredicate,
) -> Result<Vec<Row>, String> {
    // scores[i] is the score for chunk[i], or None if it had no usable vector.
    let mut scores = vec![None; chunk.len()];
    // Group embeddings by dimensionality so each group can be scored with a
    // single flat batch call. BTreeMap keeps dim iteration deterministic.
    let mut grouped: std::collections::BTreeMap<usize, (Vec<usize>, Vec<f32>)> =
        std::collections::BTreeMap::new();

    for (idx, candidate) in chunk.iter().enumerate() {
        if let Some(vector) = candidate.vector.as_ref() {
            let dim = vector.values.len();
            if dim == 0 {
                // Zero-dimensional vectors cannot be scored.
                continue;
            }
            let (indices, flat) = grouped
                .entry(dim)
                .or_insert_with(|| (Vec::new(), Vec::new()));
            indices.push(idx);
            flat.extend_from_slice(&vector.values);
        }
    }

    for (dim, (indices, flat_embeddings)) in grouped {
        // Each dim key occurs exactly once in the map, so the query vector is
        // built exactly once per dimensionality — the former per-dim cache
        // (`query_vectors`) could never hit and only added clones.
        let query = build_query_vector(pred, dim)?;
        let batch_scores = score_embeddings_batch(pred.metric, &query, &flat_embeddings, dim);
        for (local_idx, score) in batch_scores.into_iter().enumerate() {
            // `get` guards against the scorer returning more scores than
            // there were embeddings.
            if let Some(chunk_idx) = indices.get(local_idx).copied() {
                scores[chunk_idx] = Some(score);
            }
        }
    }

    let mut out = Vec::new();
    for (idx, mut candidate) in chunk.into_iter().enumerate() {
        let Some(score) = scores[idx] else {
            continue;
        };
        if compare_score(score, &pred.operator, pred.threshold) {
            candidate.row.score = Some(score);
            out.push(candidate.row);
        }
    }
    Ok(out)
}

/// Resolves the worker count: 1 when internal parallelism is disabled via
/// environment, the explicit request when nonzero, otherwise the machine's
/// available parallelism (falling back to 1 if it cannot be queried).
fn resolve_workers(requested_workers: usize) -> usize {
    if !internal_parallelism_enabled() {
        return 1;
    }
    match requested_workers {
        0 => std::thread::available_parallelism()
            .map(std::num::NonZeroUsize::get)
            .unwrap_or(1),
        explicit => explicit,
    }
}

/// Adapts the resolved worker count to the workload size: small inputs
/// (at most four morsels' worth of candidates) run serially unless a worker
/// count was explicitly requested, and workers never exceed the chunk count.
fn resolve_workers_adaptive(
    requested_workers: usize,
    chunk_count: usize,
    candidate_count: usize,
    morsel_size: usize,
) -> usize {
    let cap = chunk_count.max(1);
    let base = resolve_workers(requested_workers).min(cap);
    let small_input = candidate_count <= morsel_size.saturating_mul(4);
    if requested_workers == 0 && small_input {
        // Thread startup would dominate this little work — stay serial.
        1
    } else {
        base.max(1)
    }
}

/// Returns whether internal (intra-query) parallelism is permitted.
///
/// Disabled when `IR_THREAD_PER_CORE` is truthy (an external scheduler owns
/// the cores) or when `IR_INTERNAL_PARALLELISM` is explicitly falsy; both
/// values are trimmed and compared case-insensitively. Enabled otherwise.
fn internal_parallelism_enabled() -> bool {
    let truthy =
        |raw: String| matches!(raw.trim().to_ascii_lowercase().as_str(), "1" | "true" | "on");
    let falsy =
        |raw: String| matches!(raw.trim().to_ascii_lowercase().as_str(), "0" | "false" | "off");

    if std::env::var("IR_THREAD_PER_CORE").map(truthy).unwrap_or(false) {
        return false;
    }
    if std::env::var("IR_INTERNAL_PARALLELISM").map(falsy).unwrap_or(false) {
        return false;
    }
    true
}