//! iridium-db 0.2.0
//!
//! A high-performance vector-graph hybrid storage and indexing engine.
use std::time::Instant;

use crate::features::runtime::api::types::PhysicalOp;
use crate::features::storage::api as storage_api;

use super::rerank::{process_chunk as process_rerank_chunk, rerank_candidates, CandidateRow};
use super::vector::{build_query_vector, compare_score, extract_latest_vector};
use super::{
    merge_rows_deterministic, ExecuteParams, ExplainError, ExplainPlan, Result, Row, RowStream,
};
use crate::features::runtime::api::types::ExecutableScalarPredicate;

/// Executes `plan` against `handle` using the default thread-core request.
///
/// Convenience entry point; see [`execute_with_request`] for the scoped
/// variant.
pub fn execute(
    plan: &ExplainPlan,
    params: &ExecuteParams,
    handle: &mut storage_api::StorageHandle,
) -> Result<RowStream> {
    let default_request = storage_api::ThreadCoreRequest::default();
    execute_with_request(plan, params, handle, &default_request)
}

/// Executes `plan` against `handle`, restricting the scan to the nodes the
/// given thread-core `request` owns.
///
/// Thin public facade over `execute_internal`.
pub fn execute_with_request(
    plan: &ExplainPlan,
    params: &ExecuteParams,
    handle: &mut storage_api::StorageHandle,
    request: &storage_api::ThreadCoreRequest,
) -> Result<RowStream> {
    execute_internal(plan, params, handle, request)
}

/// One shard target for a fan-out execution: the thread-core request that
/// selects the node ids this shard owns, plus a mutable borrow of the
/// shard's storage handle.
pub struct FanoutShardExecution<'a> {
    /// Ownership filter applied when building scan ids for this shard.
    pub request: storage_api::ThreadCoreRequest,
    /// Storage handle all of this shard's reads go through.
    pub handle: &'a mut storage_api::StorageHandle,
}

/// Runs `plan` on every shard in `shards` and merges the per-shard outputs
/// into a single deterministic row stream.
///
/// The node/morsel/rerank counters are summed across shards with saturating
/// arithmetic, while `parallel_workers` reports the maximum any single shard
/// used.
///
/// # Errors
/// Returns `ExplainError::InvalidPlan` when `shards` is empty, or propagates
/// the first per-shard execution failure.
pub fn execute_fanout(
    plan: &ExplainPlan,
    params: &ExecuteParams,
    shards: &mut [FanoutShardExecution<'_>],
) -> Result<RowStream> {
    if shards.is_empty() {
        return Err(ExplainError::InvalidPlan(
            "fanout requires at least one shard execution target".to_string(),
        ));
    }

    let timer = Instant::now();
    let mut all_rows = Vec::new();
    let mut total_scanned = 0_u64;
    let mut total_morsels = 0_u64;
    let mut total_rerank_batches = 0_u64;
    let mut max_workers = 1_usize;

    for target in shards.iter_mut() {
        let stream = execute_with_request(plan, params, target.handle, &target.request)?;
        total_scanned = total_scanned.saturating_add(stream.scanned_nodes);
        total_morsels = total_morsels.saturating_add(stream.morsels_processed);
        total_rerank_batches = total_rerank_batches.saturating_add(stream.rerank_batches);
        max_workers = max_workers.max(stream.parallel_workers);
        all_rows.extend(stream.rows);
    }

    // An absent plan limit means "keep everything".
    let row_cap = plan.limit.unwrap_or(u64::MAX) as usize;
    let merged = merge_rows_deterministic(all_rows, row_cap);

    Ok(RowStream {
        rows: merged,
        scanned_nodes: total_scanned,
        latency_micros: timer.elapsed().as_micros(),
        morsels_processed: total_morsels,
        rerank_batches: total_rerank_batches,
        parallel_workers: max_workers,
    })
}

/// Core single-handle execution of an `ExplainPlan`.
///
/// Strategy selection:
/// 1. **HNSW shortcut** — a cosine `VectorScan` with populated HNSW graphs
///    that passes `ann_gate_allows_hnsw` skips the node scan and delegates
///    to the index.
/// 2. **Vector rerank** — any other plan with a vector predicate scans the
///    owned node ids and reranks candidates; when a limit is set the rerank
///    is chunked so the scan can exit early.
/// 3. **Summary scan** — plans without a vector predicate emit lightweight
///    row summaries, optionally filtered by a scalar predicate.
///
/// # Errors
/// Returns `ExplainError::InvalidPlan` for an empty scan range, a zero
/// morsel size, storage read failures, or query-vector/rerank failures.
fn execute_internal(
    plan: &ExplainPlan,
    params: &ExecuteParams,
    handle: &mut storage_api::StorageHandle,
    request: &storage_api::ThreadCoreRequest,
) -> Result<RowStream> {
    if params.scan_end_exclusive <= params.scan_start {
        return Err(ExplainError::InvalidPlan(
            "scan_end_exclusive must be greater than scan_start".to_string(),
        ));
    }
    if params.morsel_size == 0 {
        return Err(ExplainError::InvalidPlan(
            "morsel_size must be greater than 0".to_string(),
        ));
    }

    let start = Instant::now();
    let mut scanned = 0_u64;
    // An absent limit means "return everything".
    let limit = plan.limit.unwrap_or(u64::MAX) as usize;
    let mut rows;
    let mut rerank_batches = 0_u64;
    let mut parallel_workers = 1_usize;
    let morsels_processed;

    if let Some(pred) = &plan.predicate {
        // HNSW short-circuit: when the graph is populated and this is a VectorScan,
        // skip the full node scan and delegate to the HNSW index.
        let requested_query_dim = pred.inline_vector.as_ref().map(|v| v.len());
        if plan.physical_ops.contains(&PhysicalOp::VectorScan)
            && pred.metric == storage_api::VectorMetric::Cosine
            && !handle.hnsw_graphs.is_empty()
            && ann_gate_allows_hnsw(handle, limit, params.morsel_size)
        {
            if let Some(space_id) =
                storage_api::ann_space_for_query(handle, pred.metric, requested_query_dim)
            {
                // Space 0 is the default space: its dimension is inferred from
                // the graph itself; named spaces read it from the manifest.
                let dim = if space_id == 0 {
                    handle
                        .hnsw_graphs
                        .get(&space_id)
                        .and_then(|graph| graph.infer_dim())
                        .unwrap_or(128)
                } else {
                    handle
                        .manifest
                        .vector_space(space_id)
                        .map(|space| space.dimension as usize)
                        .unwrap_or(128)
                };
                let query = build_query_vector(pred, dim).map_err(ExplainError::InvalidPlan)?;
                // Unlimited queries ask the index for everything it holds.
                let k = if limit == usize::MAX {
                    handle
                        .hnsw_graphs
                        .get(&space_id)
                        .map(|graph| graph.len())
                        .unwrap_or(0)
                } else {
                    limit
                };
                let hnsw_results = if space_id == 0 {
                    storage_api::hnsw_search(handle, &query, k)
                } else {
                    storage_api::hnsw_search_in_space(handle, space_id, &query, k)
                };
                let rows: Vec<Row> = hnsw_results
                    .into_iter()
                    .filter(|(_, sim)| compare_score(*sim, &pred.operator, pred.threshold))
                    .map(|(node_id, sim)| Row {
                        node_id,
                        has_full: false,
                        delta_count: 0,
                        adjacency_degree: 0,
                        score: Some(sim),
                        aggregate_value: None,
                    })
                    .collect();
                // NOTE(review): scanned_nodes here reports surviving results,
                // not nodes visited inside the index — the index API exposes
                // no visit counter.
                let scanned = rows.len() as u64;
                return Ok(RowStream {
                    rows,
                    scanned_nodes: scanned,
                    latency_micros: start.elapsed().as_micros(),
                    morsels_processed: 1,
                    rerank_batches: 0,
                    parallel_workers: 1,
                });
            }
        }

        // Materialize the scan id list only once the HNSW shortcut has been
        // ruled out: the shortcut never consults it, and building the id
        // vector for a wide scan range is not free.
        let scan_ids = build_scan_node_ids(plan, params, handle, request);
        let early_exit_enabled = limit != usize::MAX;
        if !early_exit_enabled {
            // Unlimited query: gather every live candidate, then rerank in
            // one (possibly parallel) pass.
            let mut candidates = Vec::new();
            for node_id in scan_ids {
                scanned += 1;
                let logical = storage_api::get_logical_node_for_request(handle, node_id, request)
                    .map_err(|e| {
                    ExplainError::InvalidPlan(format!("storage read failed: {:?}", e))
                })?;
                // Nodes with neither a full record nor deltas carry nothing
                // to rank.
                if logical.full.is_none() && logical.deltas.is_empty() {
                    continue;
                }
                candidates.push(CandidateRow {
                    row: Row {
                        node_id,
                        has_full: logical.full.is_some(),
                        delta_count: logical.deltas.len(),
                        adjacency_degree: logical.adjacency().len(),
                        score: None,
                        aggregate_value: None,
                    },
                    vector: extract_latest_vector(handle, &logical, pred.metric)
                        .map_err(ExplainError::InvalidPlan)?,
                });
            }
            (rows, rerank_batches, parallel_workers) = rerank_candidates(
                candidates,
                pred,
                params.morsel_size,
                params.parallel_workers,
            )
            .map_err(ExplainError::InvalidPlan)?;
        } else {
            // Limited query: rerank one morsel-sized chunk at a time so the
            // scan can stop as soon as `limit` rows have survived.
            rows = Vec::new();
            let mut chunk = Vec::with_capacity(params.morsel_size);
            let mut processed_chunks = 0_u64;
            for node_id in scan_ids {
                if rows.len() >= limit {
                    break;
                }
                scanned += 1;
                let logical = storage_api::get_logical_node_for_request(handle, node_id, request)
                    .map_err(|e| {
                    ExplainError::InvalidPlan(format!("storage read failed: {:?}", e))
                })?;
                if logical.full.is_none() && logical.deltas.is_empty() {
                    continue;
                }
                chunk.push(CandidateRow {
                    row: Row {
                        node_id,
                        has_full: logical.full.is_some(),
                        delta_count: logical.deltas.len(),
                        adjacency_degree: logical.adjacency().len(),
                        score: None,
                        aggregate_value: None,
                    },
                    vector: extract_latest_vector(handle, &logical, pred.metric)
                        .map_err(ExplainError::InvalidPlan)?,
                });
                if chunk.len() >= params.morsel_size {
                    processed_chunks += 1;
                    rows.extend(
                        process_rerank_chunk(std::mem::take(&mut chunk), pred)
                            .map_err(ExplainError::InvalidPlan)?,
                    );
                }
            }
            // Flush the final partial chunk unless the limit is already met.
            if !chunk.is_empty() && rows.len() < limit {
                processed_chunks += 1;
                rows.extend(process_rerank_chunk(chunk, pred).map_err(ExplainError::InvalidPlan)?);
            }
            rerank_batches = processed_chunks;
            parallel_workers = 1;
        }
        // A chunk may overshoot the limit; trim the excess.
        if rows.len() > limit {
            rows.truncate(limit);
        }
        morsels_processed = if scanned == 0 {
            0
        } else {
            scanned.div_ceil(params.morsel_size as u64)
        };
    } else {
        // No vector predicate: emit lightweight row summaries.
        let scan_ids = build_scan_node_ids(plan, params, handle, request);
        rows = Vec::new();
        // A scalar predicate filters after collection, so early exit on the
        // raw (pre-filter) row count would drop matching rows.
        let has_scalar = plan.scalar_predicate.is_some();
        for node_id in scan_ids {
            if !has_scalar && rows.len() >= limit {
                break;
            }
            scanned += 1;
            let summary = storage_api::get_node_row_summary_for_request(handle, node_id, request)
                .map_err(|e| {
                ExplainError::InvalidPlan(format!("storage read failed: {:?}", e))
            })?;
            let Some(summary) = summary else {
                continue;
            };
            rows.push(Row {
                node_id,
                has_full: summary.has_full,
                delta_count: summary.delta_count,
                adjacency_degree: summary.adjacency_degree,
                score: None,
                aggregate_value: None,
            });
        }
        if let Some(scalar_pred) = &plan.scalar_predicate {
            rows = apply_scalar_filter(std::mem::take(&mut rows), scalar_pred);
            if rows.len() > limit {
                rows.truncate(limit);
            }
        }
        morsels_processed = if scanned == 0 {
            0
        } else {
            scanned.div_ceil(params.morsel_size as u64)
        };
    }

    Ok(RowStream {
        rows,
        scanned_nodes: scanned,
        latency_micros: start.elapsed().as_micros(),
        morsels_processed,
        rerank_batches,
        parallel_workers,
    })
}

/// Decides whether the HNSW shortcut may serve a query.
///
/// Unbounded queries are always allowed, as are handles with anything other
/// than exactly one graph. With a single graph, a query whose limit is both
/// small in absolute terms (<= 128) and fits within two morsels is cheap
/// enough to scan directly, so HNSW is denied for it.
fn ann_gate_allows_hnsw(
    handle: &storage_api::StorageHandle,
    limit: usize,
    morsel_size: usize,
) -> bool {
    if limit == usize::MAX || handle.hnsw_graphs.len() != 1 {
        return true;
    }
    let cheap_to_scan = limit <= 128 && limit <= morsel_size.saturating_mul(2);
    !cheap_to_scan
}

/// Keeps only the rows whose `pred.field` value satisfies the predicate.
///
/// Unknown field names match nothing; absent `score`/`aggregate_value`
/// values are treated as 0.0 before comparison. Row order is preserved.
fn apply_scalar_filter(rows: Vec<Row>, pred: &ExecutableScalarPredicate) -> Vec<Row> {
    let mut kept = rows;
    kept.retain(|row| {
        let lhs = match pred.field.as_str() {
            "adjacency_degree" => Some(row.adjacency_degree as f64),
            "delta_count" => Some(row.delta_count as f64),
            "has_full" => Some(if row.has_full { 1.0 } else { 0.0 }),
            "score" => Some(row.score.unwrap_or(0.0)),
            "aggregate_value" => Some(row.aggregate_value.unwrap_or(0.0)),
            _ => None,
        };
        match lhs {
            Some(value) => compare_score(value, &pred.operator, pred.value),
            None => false,
        }
    });
    kept
}

/// Produces the ordered list of node ids this execution will visit.
///
/// A bitmap predicate delegates id selection to the bitmap index
/// (range-bounded and limit-capped); otherwise every id in the scan range
/// owned by `request` is included, in ascending order.
fn build_scan_node_ids(
    plan: &ExplainPlan,
    params: &ExecuteParams,
    handle: &storage_api::StorageHandle,
    request: &storage_api::ThreadCoreRequest,
) -> Vec<u64> {
    match &plan.bitmap_predicate {
        Some(bitmap) => storage_api::bitmap_postings_in_range_limit_for_request(
            handle,
            &bitmap.index_name,
            &bitmap.value_key,
            params.scan_start,
            params.scan_end_exclusive,
            plan.limit.map(|value| value as usize),
            request,
        ),
        None => {
            let mut owned_ids = Vec::new();
            for node_id in params.scan_start..params.scan_end_exclusive {
                if storage_api::request_owns_node(request, node_id) {
                    owned_ids.push(node_id);
                }
            }
            owned_ids
        }
    }
}