iridium-db 0.2.0

A high-performance vector-graph hybrid storage and indexing engine
use std::collections::HashSet;
use std::time::Instant;

use crate::features::storage::api as storage_api;
use plexus_serde::{Op, Plan};

use super::super::super::{ExecuteParams, ExplainError, Result, Row, RowStream};
use super::aggregate::{apply_aggregate, validate_supported_aggregate};
use super::execute_deserialized_plan;
use super::plan::validate_sort;
use super::types::{sort_rows, DirectUnion, UnionChain, UnionWrapper};

/// Walks the operator chain from the plan root down to a `Union`, if one is reachable.
///
/// Starting at `plan.root_op` (which must be an `Op::Return`), this follows each
/// operator's `input` through `Return`, `Sort`, `Limit`, `Aggregate` and
/// `PathConstruct` operators, recording a [`UnionWrapper`] for every one of them
/// except `Return`. Wrappers are pushed in root-to-union order (outermost first).
///
/// Returns `Ok(Some(chain))` when an `Op::Union` is reached and `Ok(None)` when any
/// other kind of operator is met first.
///
/// # Errors
///
/// - `ExplainError::SerializedPlanMalformed` if the root op is not a `Return`, an
///   `input` index is out of bounds, a `Limit` count is negative, or the chain loops
///   back on itself.
/// - `ExplainError::UnsupportedSerializedOperator` if a `Limit` has a non-zero `skip`.
/// - Any error returned by `validate_sort` or `validate_supported_aggregate`.
pub(super) fn union_chain_from_root(plan: &Plan) -> Result<Option<UnionChain>> {
    let root = plan.root_op as usize;
    let Some(Op::Return { .. }) = plan.ops.get(root) else {
        return Err(ExplainError::SerializedPlanMalformed(
            "root op must be Return".to_string(),
        ));
    };
    let mut idx = root;
    let mut wrappers = Vec::new();
    // Every step visits exactly one op, so an acyclic chain needs at most
    // `plan.ops.len()` steps. Bounding the walk keeps a malformed plan whose
    // `input` indices form a cycle from spinning forever (and from growing
    // `wrappers` without limit).
    for _ in 0..plan.ops.len() {
        let Some(op) = plan.ops.get(idx) else {
            return Err(ExplainError::SerializedPlanMalformed(format!(
                "invalid root chain op index {}",
                idx
            )));
        };
        match op {
            Op::Return { input } => idx = *input as usize,
            Op::Sort { input, keys, dirs } => {
                wrappers.push(UnionWrapper::Sort(validate_sort(keys, dirs)?));
                idx = *input as usize;
            }
            Op::Limit { input, count, skip } => {
                if *skip != 0 {
                    return Err(ExplainError::UnsupportedSerializedOperator(
                        "Limit with skip != 0".to_string(),
                    ));
                }
                if *count < 0 {
                    return Err(ExplainError::SerializedPlanMalformed(
                        "Limit count must be non-negative".to_string(),
                    ));
                }
                wrappers.push(UnionWrapper::Limit(*count as usize));
                idx = *input as usize;
            }
            Op::Aggregate {
                input, keys, aggs, ..
            } => {
                wrappers.push(UnionWrapper::Aggregate(validate_supported_aggregate(
                    keys, aggs,
                )?));
                idx = *input as usize;
            }
            Op::PathConstruct { input, .. } => {
                // Recorded so the chain shape is preserved; carries no payload.
                wrappers.push(UnionWrapper::PathConstruct);
                idx = *input as usize;
            }
            Op::Union { lhs, rhs, all, .. } => {
                return Ok(Some(UnionChain {
                    union: DirectUnion {
                        lhs: *lhs,
                        rhs: *rhs,
                        all: *all,
                    },
                    wrappers,
                }));
            }
            // Any other operator means the root chain does not end in a union.
            _ => return Ok(None),
        }
        // Reject a dangling `input` reference before the next step dereferences it.
        if idx >= plan.ops.len() {
            return Err(ExplainError::SerializedPlanMalformed(format!(
                "invalid root chain reference index {}",
                idx
            )));
        }
    }
    // More steps than ops without terminating: some op was visited twice.
    Err(ExplainError::SerializedPlanMalformed(
        "root chain contains a cycle".to_string(),
    ))
}

/// Executes both branches of a root-level union and merges their results.
///
/// Each branch is turned into a standalone plan via `branch_plan_with_return` and run
/// through `execute_deserialized_plan` (left branch first, then right). The right
/// branch's rows are appended to the left branch's rows. When `chain.union.all` is
/// false, rows are de-duplicated by `node_id`, keeping the first occurrence. The
/// chain's wrappers are then applied to the merged rows.
///
/// The returned `RowStream` sums the branches' `scanned_nodes`, `morsels_processed`
/// and `rerank_batches` (saturating), takes the larger `parallel_workers`, and
/// reports the wall-clock time of this whole call in `latency_micros`.
///
/// # Errors
///
/// Propagates any error from building or executing either branch plan.
pub(super) fn execute_union_chain(
    plan: &Plan,
    chain: &UnionChain,
    params: &ExecuteParams,
    handle: &mut storage_api::StorageHandle,
) -> Result<RowStream> {
    let timer = Instant::now();

    // Both branch plans are built before either one runs, so an invalid branch
    // index is reported without touching storage.
    let left_plan = branch_plan_with_return(plan, chain.union.lhs)?;
    let right_plan = branch_plan_with_return(plan, chain.union.rhs)?;
    let left = execute_deserialized_plan(&left_plan, params, handle)?;
    let right = execute_deserialized_plan(&right_plan, params, handle)?;

    // Combine the per-branch execution counters.
    let scanned_nodes = left.scanned_nodes.saturating_add(right.scanned_nodes);
    let morsels_processed = left
        .morsels_processed
        .saturating_add(right.morsels_processed);
    let rerank_batches = left.rerank_batches.saturating_add(right.rerank_batches);
    let parallel_workers = left.parallel_workers.max(right.parallel_workers);

    // Left rows first, right rows after.
    let mut rows = left.rows;
    rows.extend(right.rows);

    let distinct = !chain.union.all;
    if distinct {
        // `insert` returns false for an id already seen, which drops the later row.
        let mut seen_ids = HashSet::new();
        rows.retain(|candidate| seen_ids.insert(candidate.node_id));
    }

    apply_union_wrappers(&mut rows, &chain.wrappers);

    Ok(RowStream {
        rows,
        scanned_nodes,
        latency_micros: timer.elapsed().as_micros(),
        morsels_processed,
        rerank_batches,
        parallel_workers,
    })
}

/// Applies the recorded wrapper operators to `rows` in place.
///
/// `wrappers` is traversed back to front, so the wrapper that was recorded last is
/// applied first. `Sort` reorders the rows via `sort_rows`, `Limit` keeps at most the
/// given number of leading rows, `Aggregate` replaces the rows with the output of
/// `apply_aggregate`, and `PathConstruct` leaves the rows untouched.
fn apply_union_wrappers(rows: &mut Vec<Row>, wrappers: &[UnionWrapper]) {
    for step in wrappers.iter().rev() {
        match step {
            UnionWrapper::PathConstruct => {}
            UnionWrapper::Sort(spec) => sort_rows(rows, *spec),
            // `truncate` is a no-op when there are already `max_rows` rows or fewer.
            UnionWrapper::Limit(max_rows) => rows.truncate(*max_rows),
            UnionWrapper::Aggregate(spec) => {
                // Hand the current rows over by value and store the aggregated output.
                *rows = apply_aggregate(std::mem::take(rows), *spec);
            }
        }
    }
}

/// Builds a standalone plan whose root is a fresh `Return` over one union branch.
///
/// The new plan holds a copy of every op in `plan` plus an appended
/// `Op::Return { input: root_input }`, which becomes `root_op`. Existing op indices
/// are therefore unchanged. The `version` is cloned from `plan`.
///
/// # Errors
///
/// Returns `ExplainError::SerializedPlanMalformed` when `root_input` does not index
/// an op in `plan`.
fn branch_plan_with_return(plan: &Plan, root_input: u32) -> Result<Plan> {
    let op_count = plan.ops.len();
    if (root_input as usize) >= op_count {
        return Err(ExplainError::SerializedPlanMalformed(format!(
            "invalid union branch index {}",
            root_input
        )));
    }

    // The appended Return lands at index `op_count`, right after the copied ops.
    let ops = plan
        .ops
        .iter()
        .cloned()
        .chain(std::iter::once(Op::Return { input: root_input }))
        .collect();

    Ok(Plan {
        version: plan.version.clone(),
        ops,
        root_op: op_count as u32,
    })
}