iridium-db 0.2.0

A high-performance vector-graph hybrid storage and indexing engine
use std::collections::HashSet;

use plexus_serde::{CmpOp, ExpandDir, Expr, Op, Plan, SortDir};

use crate::features::runtime::api::types::ExecutableScalarPredicate;

use super::super::super::{ExplainError, Result};
use super::super::validate::op_name;
use super::aggregate::{parse_supported_row_field, validate_supported_aggregate};
use super::types::{
    make_explain_plan, AdapterExecutionPlan, ExpandKind, ExpandSpec, ExplainPlanRequest, PostSort,
};

/// Translates a read-only serialized [`Plan`] into an [`AdapterExecutionPlan`].
///
/// The plan must be a single linear pipeline whose root op is `Return`. The walk starts at the
/// root and follows each op's `input` until it reaches `ScanNodes`. Every supported op on the
/// way contributes to the adapter plan:
///
/// - `Limit` (only with `skip == 0`) becomes `post_limit`. Stacked limits collapse to the
///   smallest count.
/// - `Project` is validated and recorded as `has_project`.
/// - `Sort` becomes `post_sort` (see [`validate_sort`]).
/// - `Expand` / `OptionalExpand` / `SemiExpand` / `ExpandVarLen` become `expand`.
/// - `Filter` becomes a single [`ExecutableScalarPredicate`].
/// - `Unwind` becomes the literal `unwind_values`.
/// - `PathConstruct` sets `path_construct`.
/// - `Aggregate` becomes `aggregate`.
///
/// # Errors
///
/// - [`ExplainError::SerializedPlanMalformed`] is returned for any of these:
///   - the root op is not `Return`;
///   - an op index is out of range;
///   - the pipeline contains a cycle;
///   - a `Limit` count is negative;
///   - `Sort` keys and dirs have different lengths.
/// - [`ExplainError::UnsupportedSerializedOperator`] is returned for any op, or op shape, that
///   the adapter cannot execute. This includes:
///   - a `Filter` or `Unwind` placed above an expand op;
///   - more than one `Filter`, `Sort`, `Unwind`, `Aggregate` or expand op in the pipeline.
pub(super) fn translate_readonly_plan(plan: &Plan) -> Result<AdapterExecutionPlan> {
    let mut idx = plan.root_op as usize;
    let mut visited = HashSet::new();
    let mut op_chain = Vec::new();
    let mut limit: Option<u64> = None;
    let mut has_project = false;
    let mut scalar_predicate: Option<ExecutableScalarPredicate> = None;
    let mut has_filter = false;
    let mut expand: Option<ExpandSpec> = None;
    let mut aggregate = None;
    let mut unwind_values: Option<Vec<u64>> = None;
    let mut path_construct = false;
    let mut post_sort: Option<PostSort> = None;
    // Set when a Filter/Unwind is reached before any expand op. The walk runs from the root
    // downwards, so "before" means the op sits *above* the expand in the pipeline.
    let mut has_post_expand_filter = false;
    let mut has_post_expand_unwind = false;

    match plan.ops.get(idx) {
        Some(Op::Return { .. }) => {}
        Some(other) => {
            return Err(ExplainError::SerializedPlanMalformed(format!(
                "root op must be Return, found {}",
                op_name(other)
            )));
        }
        None => {
            return Err(ExplainError::SerializedPlanMalformed(format!(
                "invalid root op index {}",
                plan.root_op
            )));
        }
    }

    loop {
        if !visited.insert(idx) {
            return Err(ExplainError::SerializedPlanMalformed(
                "cycle detected while tracing root pipeline".to_string(),
            ));
        }
        // `idx` was bounds-checked either by the root lookup above or at the end of the
        // previous iteration.
        let op = &plan.ops[idx];
        op_chain.push(op_name(op).to_string());
        match op {
            Op::Return { input } => idx = *input as usize,
            Op::Limit { input, count, skip } => {
                if *skip != 0 {
                    return Err(ExplainError::UnsupportedSerializedOperator(
                        "Limit with skip != 0".to_string(),
                    ));
                }
                if *count < 0 {
                    return Err(ExplainError::SerializedPlanMalformed(
                        "Limit count must be non-negative".to_string(),
                    ));
                }
                let count = *count as u64;
                // Nested skip-free limits compose to the smallest bound. A plain overwrite
                // would keep only the innermost count and could return too many rows.
                limit = Some(limit.map_or(count, |outer| outer.min(count)));
                idx = *input as usize;
            }
            Op::Project { input, exprs, .. } => {
                validate_supported_project_exprs(exprs)?;
                has_project = true;
                idx = *input as usize;
            }
            Op::Sort { input, keys, dirs } => {
                reject_repeated_op(&post_sort, "Sort")?;
                post_sort = Some(validate_sort(keys, dirs)?);
                idx = *input as usize;
            }
            Op::Expand {
                input,
                src_col,
                types,
                dir,
                ..
            } => {
                validate_supported_expand_shape(*src_col, types, *dir, "Expand")?;
                reject_repeated_op(&expand, "Expand-family")?;
                expand = Some(ExpandSpec {
                    kind: ExpandKind::Expand,
                    dir: *dir,
                });
                idx = *input as usize;
            }
            Op::OptionalExpand {
                input,
                src_col,
                types,
                dir,
                ..
            } => {
                validate_supported_expand_shape(*src_col, types, *dir, "OptionalExpand")?;
                reject_repeated_op(&expand, "Expand-family")?;
                expand = Some(ExpandSpec {
                    kind: ExpandKind::OptionalExpand,
                    dir: *dir,
                });
                idx = *input as usize;
            }
            Op::SemiExpand {
                input,
                src_col,
                types,
                dir,
                ..
            } => {
                validate_supported_expand_shape(*src_col, types, *dir, "SemiExpand")?;
                reject_repeated_op(&expand, "Expand-family")?;
                expand = Some(ExpandSpec {
                    kind: ExpandKind::SemiExpand,
                    dir: *dir,
                });
                idx = *input as usize;
            }
            Op::ExpandVarLen {
                input,
                src_col,
                types,
                dir,
                min_hops,
                max_hops,
                ..
            } => {
                validate_supported_expand_shape(*src_col, types, *dir, "ExpandVarLen")?;
                if *min_hops < 1 || *max_hops < *min_hops {
                    return Err(ExplainError::UnsupportedSerializedOperator(
                        "ExpandVarLen requires 1 <= min_hops <= max_hops".to_string(),
                    ));
                }
                if *max_hops > 8 {
                    return Err(ExplainError::UnsupportedSerializedOperator(
                        "ExpandVarLen currently supports max_hops <= 8".to_string(),
                    ));
                }
                reject_repeated_op(&expand, "Expand-family")?;
                expand = Some(ExpandSpec {
                    kind: ExpandKind::ExpandVarLen {
                        min_hops: *min_hops,
                        max_hops: *max_hops,
                    },
                    dir: *dir,
                });
                idx = *input as usize;
            }
            Op::Filter { input, predicate } => {
                if expand.is_none() {
                    has_post_expand_filter = true;
                }
                reject_repeated_op(&scalar_predicate, "Filter")?;
                has_filter = true;
                scalar_predicate = Some(extract_scalar_predicate(predicate)?);
                idx = *input as usize;
            }
            Op::Unwind {
                input, list_expr, ..
            } => {
                if expand.is_none() {
                    has_post_expand_unwind = true;
                }
                reject_repeated_op(&unwind_values, "Unwind")?;
                unwind_values = Some(extract_unwind_values(list_expr)?);
                idx = *input as usize;
            }
            Op::PathConstruct { input, .. } => {
                path_construct = true;
                idx = *input as usize;
            }
            Op::Aggregate {
                input, keys, aggs, ..
            } => {
                reject_repeated_op(&aggregate, "Aggregate")?;
                aggregate = Some(validate_supported_aggregate(keys, aggs)?);
                idx = *input as usize;
            }
            Op::ScanNodes { .. } => break,
            other => {
                return Err(ExplainError::UnsupportedSerializedOperator(
                    op_name(other).to_string(),
                ));
            }
        }
        if idx >= plan.ops.len() {
            return Err(ExplainError::SerializedPlanMalformed(format!(
                "invalid op reference index {}",
                idx
            )));
        }
    }

    // The chain was collected root-first; report it leaf-first (scan -> ... -> Return).
    op_chain.reverse();
    if expand.is_some() && has_post_expand_filter {
        return Err(ExplainError::UnsupportedSerializedOperator(
            "Filter above Expand (post-expand filtering) is not supported yet".to_string(),
        ));
    }
    if expand.is_some() && has_post_expand_unwind {
        return Err(ExplainError::UnsupportedSerializedOperator(
            "Unwind above Expand (post-expand unwind) is not supported yet".to_string(),
        ));
    }

    Ok(AdapterExecutionPlan {
        explain_plan: make_explain_plan(ExplainPlanRequest {
            logical_ops: op_chain,
            has_project,
            has_filter,
            limit,
            post_sort,
            expand,
            unwind_values: unwind_values.clone(),
            aggregate,
            scalar_predicate,
        }),
        expand,
        aggregate,
        unwind_values,
        path_construct,
        post_sort,
        // Saturate instead of truncating on targets where usize is narrower than u64.
        post_limit: limit.map(|value| usize::try_from(value).unwrap_or(usize::MAX)),
    })
}

/// Fails if `slot` already holds a value for an operator kind that may appear only once per
/// pipeline.
///
/// The translated plan keeps one value per operator kind. Accepting a second occurrence would
/// silently drop the first one's semantics (for example an extra `Filter` predicate or an
/// extra expansion hop) and so produce wrong results.
fn reject_repeated_op<T>(slot: &Option<T>, op_kind: &str) -> Result<()> {
    if slot.is_some() {
        return Err(ExplainError::UnsupportedSerializedOperator(format!(
            "multiple {} operators in one pipeline are not supported yet",
            op_kind
        )));
    }
    Ok(())
}

/// Validates a `Sort` op's keys and directions and maps them to a [`PostSort`].
///
/// Only a single key is supported, and it must be column index 0 (the node id). The sort
/// direction picks between ascending and descending node-id order.
///
/// # Errors
///
/// - [`ExplainError::SerializedPlanMalformed`] if `keys` and `dirs` differ in length.
/// - [`ExplainError::UnsupportedSerializedOperator`] if there is not exactly one key, or the
///   key is not index 0.
pub(super) fn validate_sort(keys: &[u32], dirs: &[SortDir]) -> Result<PostSort> {
    let (key_count, dir_count) = (keys.len(), dirs.len());
    if key_count != dir_count {
        return Err(ExplainError::SerializedPlanMalformed(format!(
            "Sort keys/dirs length mismatch ({} vs {})",
            key_count, dir_count
        )));
    }

    // Lengths are equal here, so exactly-one-key also means exactly-one-direction.
    let (key, dir) = match (keys, dirs) {
        ([key], [dir]) => (*key, dir),
        _ => {
            return Err(ExplainError::UnsupportedSerializedOperator(
                "Sort currently supports exactly one key".to_string(),
            ))
        }
    };

    if key != 0 {
        return Err(ExplainError::UnsupportedSerializedOperator(
            "Sort currently supports only key index 0 (node id)".to_string(),
        ));
    }

    let post_sort = match dir {
        SortDir::Asc => PostSort::NodeIdAsc,
        SortDir::Desc => PostSort::NodeIdDesc,
    };
    Ok(post_sort)
}

/// Checks that a `Project` op's expressions use only the shapes the adapter can execute.
///
/// An empty expression list is accepted. Otherwise:
/// - the first expression must be `ColRef(idx=0)`;
/// - every later expression must be one of:
///   - `ColRef(idx=0)`;
///   - `PropAccess(col=0, prop)` where `prop` passes [`parse_supported_row_field`];
///   - an int, float, bool, string or null literal.
///
/// # Errors
///
/// Returns [`ExplainError::UnsupportedSerializedOperator`] for the first unsupported
/// expression. An error from [`parse_supported_row_field`] is propagated unchanged.
fn validate_supported_project_exprs(exprs: &[Expr]) -> Result<()> {
    let Some((head, tail)) = exprs.split_first() else {
        return Ok(());
    };

    if !matches!(head, Expr::ColRef { idx } if *idx == 0) {
        return Err(ExplainError::UnsupportedSerializedOperator(
            "Project currently requires first expression ColRef(idx=0)".to_string(),
        ));
    }

    for expr in tail {
        let supported = match expr {
            Expr::ColRef { idx } => *idx == 0,
            Expr::PropAccess { col, prop } if *col == 0 => {
                parse_supported_row_field(prop)?;
                true
            }
            Expr::IntLiteral(_)
            | Expr::FloatLiteral(_)
            | Expr::BoolLiteral(_)
            | Expr::StringLiteral(_)
            | Expr::NullLiteral => true,
            _ => false,
        };
        if !supported {
            return Err(ExplainError::UnsupportedSerializedOperator(
                "Project currently supports only ColRef(idx=0), PropAccess(col=0,<supported-field>), and literal trailing expressions".to_string(),
            ));
        }
    }
    Ok(())
}

/// Checks the shape shared by all expand ops (`Expand`, `OptionalExpand`, `SemiExpand`,
/// `ExpandVarLen`).
///
/// The adapter currently requires all of the following:
/// - source column 0;
/// - no relationship-type filter;
/// - direction `ExpandDir::Out`.
///
/// `op_name` is used only to label the error message.
///
/// # Errors
///
/// Returns [`ExplainError::UnsupportedSerializedOperator`] for the first violated
/// requirement, checked in the order listed above.
fn validate_supported_expand_shape(
    src_col: u32,
    types: &[String],
    dir: ExpandDir,
    op_name: &str,
) -> Result<()> {
    let unsupported = |requirement: &str| {
        ExplainError::UnsupportedSerializedOperator(format!(
            "{} currently supports only {}",
            op_name, requirement
        ))
    };

    if src_col != 0 {
        Err(unsupported("src_col=0"))
    } else if !types.is_empty() {
        Err(unsupported("empty relationship type filter"))
    } else if dir != ExpandDir::Out {
        Err(unsupported("Out direction"))
    } else {
        Ok(())
    }
}

/// Extracts the values of an `Unwind` list expression.
///
/// The expression must be a `ListLiteral` whose items are all non-negative `IntLiteral`s.
/// The items are returned as `u64` in list order.
///
/// # Errors
///
/// Returns [`ExplainError::UnsupportedSerializedOperator`] in either of these cases:
/// - the expression is not a `ListLiteral`;
/// - any item is not a non-negative integer literal. Evaluation stops at the first such item.
pub(super) fn extract_unwind_values(expr: &Expr) -> Result<Vec<u64>> {
    match expr {
        Expr::ListLiteral { items } => items
            .iter()
            .map(|item| match *item {
                Expr::IntLiteral(value) if value >= 0 => Ok(value as u64),
                _ => Err(ExplainError::UnsupportedSerializedOperator(
                    "Unwind list currently supports only non-negative IntLiteral items".to_string(),
                )),
            })
            .collect(),
        _ => Err(ExplainError::UnsupportedSerializedOperator(
            "Unwind currently supports only ListLiteral expressions".to_string(),
        )),
    }
}

/// Converts a `Filter` predicate into an [`ExecutableScalarPredicate`].
///
/// Only the shape `PropAccess(col=0, prop) <cmp> <literal>` is accepted:
/// - the property name becomes `field`;
/// - the comparison becomes its symbolic `operator` (see [`cmp_op_as_str`]);
/// - the rhs literal is converted to `f64` by [`literal_as_f64`].
///
/// # Errors
///
/// Returns [`ExplainError::UnsupportedSerializedOperator`] in any of these cases:
/// - the predicate is not a `Cmp`;
/// - the lhs is not a `PropAccess`;
/// - the `PropAccess` targets a column other than 0;
/// - the rhs is not an int, float or bool literal.
fn extract_scalar_predicate(predicate: &Expr) -> Result<ExecutableScalarPredicate> {
    let Expr::Cmp { op, lhs, rhs } = predicate else {
        return Err(ExplainError::UnsupportedSerializedOperator(
            "Filter currently requires Cmp predicate".to_string(),
        ));
    };
    let Expr::PropAccess { col, prop } = lhs.as_ref() else {
        return Err(ExplainError::UnsupportedSerializedOperator(
            "Filter currently requires PropAccess on lhs".to_string(),
        ));
    };
    // The executable predicate carries only a field name and no column index. It therefore
    // presumably applies to column 0, as `Project` already requires for `PropAccess`.
    // Accepting another column would silently evaluate the property against the wrong value.
    if *col != 0 {
        return Err(ExplainError::UnsupportedSerializedOperator(
            "Filter currently supports PropAccess only on col=0".to_string(),
        ));
    }
    let value = literal_as_f64(rhs)?;
    Ok(ExecutableScalarPredicate {
        field: prop.clone(),
        operator: cmp_op_as_str(*op).to_string(),
        value,
    })
}

/// Converts a numeric or boolean literal on a `Filter` rhs to `f64`.
///
/// - Integers are converted with `as f64`.
/// - Floats are returned unchanged.
/// - `true` becomes `1.0` and `false` becomes `0.0`.
///
/// # Errors
///
/// Returns [`ExplainError::UnsupportedSerializedOperator`] for any other expression.
fn literal_as_f64(expr: &Expr) -> Result<f64> {
    let numeric = match *expr {
        Expr::IntLiteral(int_value) => int_value as f64,
        Expr::FloatLiteral(float_value) => float_value,
        Expr::BoolLiteral(flag) => f64::from(u8::from(flag)),
        _ => {
            return Err(ExplainError::UnsupportedSerializedOperator(
                "Filter rhs must be IntLiteral, FloatLiteral, or BoolLiteral".to_string(),
            ))
        }
    };
    Ok(numeric)
}

/// Returns the symbolic spelling of a comparison operator.
///
/// This string is stored as the `operator` of an executable scalar predicate.
fn cmp_op_as_str(op: CmpOp) -> &'static str {
    let symbol = match op {
        // Equality pair.
        CmpOp::Eq => "=",
        CmpOp::Ne => "!=",
        // Strict ordering.
        CmpOp::Lt => "<",
        CmpOp::Gt => ">",
        // Inclusive ordering.
        CmpOp::Le => "<=",
        CmpOp::Ge => ">=",
    };
    symbol
}