use std::collections::HashSet;
use plexus_serde::{CmpOp, ExpandDir, Expr, Op, Plan, SortDir};
use crate::features::runtime::api::types::ExecutableScalarPredicate;
use super::super::super::{ExplainError, Result};
use super::super::validate::op_name;
use super::aggregate::{parse_supported_row_field, validate_supported_aggregate};
use super::types::{
make_explain_plan, AdapterExecutionPlan, ExpandKind, ExpandSpec, ExplainPlanRequest, PostSort,
};
/// Translates a serialized read-only [`Plan`] into an [`AdapterExecutionPlan`].
///
/// The root op (`plan.root_op`) must be a `Return`. From there the single-input
/// pipeline is walked downwards through each op's `input` until a `ScanNodes`
/// op is reached. Each supported op fills one piece of the adapter plan.
/// `op_chain` is collected root-first and reversed at the end, so it lists ops
/// from the scan up to the return.
///
/// # Errors
///
/// * [`ExplainError::SerializedPlanMalformed`] in any of these cases:
///   - the root index is out of range or the root op is not `Return`;
///   - an input index is out of range;
///   - a cycle is found;
///   - a `Limit` count is negative.
/// * [`ExplainError::UnsupportedSerializedOperator`] in any of these cases:
///   - an op or op shape the adapter cannot execute;
///   - a second occurrence of an op the adapter plan can only hold once
///     (Limit, Sort, any expand, Filter, Unwind, Aggregate);
///   - a Filter or Unwind placed above an expand.
pub(super) fn translate_readonly_plan(plan: &Plan) -> Result<AdapterExecutionPlan> {
    let mut idx = plan.root_op as usize;
    let mut visited = HashSet::new();
    let mut op_chain = Vec::new();
    let mut limit: Option<u64> = None;
    let mut has_project = false;
    let mut scalar_predicate: Option<ExecutableScalarPredicate> = None;
    let mut has_filter = false;
    let mut expand: Option<ExpandSpec> = None;
    let mut aggregate = None;
    let mut unwind_values: Option<Vec<u64>> = None;
    let mut path_construct = false;
    let mut post_sort: Option<PostSort> = None;
    // The walk goes from the root downwards. A Filter/Unwind seen while no
    // expand has been seen yet therefore sits *above* any expand in the dataflow.
    let mut has_post_expand_filter = false;
    let mut has_post_expand_unwind = false;

    match plan.ops.get(idx) {
        Some(Op::Return { .. }) => {}
        Some(other) => {
            return Err(ExplainError::SerializedPlanMalformed(format!(
                "root op must be Return, found {}",
                op_name(other)
            )));
        }
        None => {
            return Err(ExplainError::SerializedPlanMalformed(format!(
                "invalid root op index {}",
                plan.root_op
            )));
        }
    }

    loop {
        if !visited.insert(idx) {
            return Err(ExplainError::SerializedPlanMalformed(
                "cycle detected while tracing root pipeline".to_string(),
            ));
        }
        // In bounds: checked for the root above, and for every input at the loop tail.
        let op = &plan.ops[idx];
        op_chain.push(op_name(op).to_string());
        match op {
            Op::Return { input } => idx = *input as usize,
            Op::Limit { input, count, skip } => {
                reject_repeated_operator(limit.is_some(), "Limit")?;
                if *skip != 0 {
                    return Err(ExplainError::UnsupportedSerializedOperator(
                        "Limit with skip != 0".to_string(),
                    ));
                }
                if *count < 0 {
                    return Err(ExplainError::SerializedPlanMalformed(
                        "Limit count must be non-negative".to_string(),
                    ));
                }
                limit = Some(*count as u64);
                idx = *input as usize;
            }
            Op::Project { input, exprs, .. } => {
                validate_supported_project_exprs(exprs)?;
                has_project = true;
                idx = *input as usize;
            }
            Op::Sort { input, keys, dirs } => {
                reject_repeated_operator(post_sort.is_some(), "Sort")?;
                post_sort = Some(validate_sort(keys, dirs)?);
                idx = *input as usize;
            }
            Op::Expand {
                input,
                src_col,
                types,
                dir,
                ..
            } => {
                reject_repeated_operator(expand.is_some(), "expand")?;
                validate_supported_expand_shape(*src_col, types, *dir, "Expand")?;
                expand = Some(ExpandSpec {
                    kind: ExpandKind::Expand,
                    dir: *dir,
                });
                idx = *input as usize;
            }
            Op::OptionalExpand {
                input,
                src_col,
                types,
                dir,
                ..
            } => {
                reject_repeated_operator(expand.is_some(), "expand")?;
                validate_supported_expand_shape(*src_col, types, *dir, "OptionalExpand")?;
                expand = Some(ExpandSpec {
                    kind: ExpandKind::OptionalExpand,
                    dir: *dir,
                });
                idx = *input as usize;
            }
            Op::SemiExpand {
                input,
                src_col,
                types,
                dir,
                ..
            } => {
                reject_repeated_operator(expand.is_some(), "expand")?;
                validate_supported_expand_shape(*src_col, types, *dir, "SemiExpand")?;
                expand = Some(ExpandSpec {
                    kind: ExpandKind::SemiExpand,
                    dir: *dir,
                });
                idx = *input as usize;
            }
            Op::ExpandVarLen {
                input,
                src_col,
                types,
                dir,
                min_hops,
                max_hops,
                ..
            } => {
                reject_repeated_operator(expand.is_some(), "expand")?;
                validate_supported_expand_shape(*src_col, types, *dir, "ExpandVarLen")?;
                if *min_hops < 1 || *max_hops < *min_hops {
                    return Err(ExplainError::UnsupportedSerializedOperator(
                        "ExpandVarLen requires 1 <= min_hops <= max_hops".to_string(),
                    ));
                }
                if *max_hops > 8 {
                    return Err(ExplainError::UnsupportedSerializedOperator(
                        "ExpandVarLen currently supports max_hops <= 8".to_string(),
                    ));
                }
                expand = Some(ExpandSpec {
                    kind: ExpandKind::ExpandVarLen {
                        min_hops: *min_hops,
                        max_hops: *max_hops,
                    },
                    dir: *dir,
                });
                idx = *input as usize;
            }
            Op::Filter { input, predicate } => {
                // Only one scalar predicate can be carried; a second Filter would
                // otherwise silently replace the first one.
                reject_repeated_operator(has_filter, "Filter")?;
                if expand.is_none() {
                    has_post_expand_filter = true;
                }
                has_filter = true;
                scalar_predicate = Some(extract_scalar_predicate(predicate)?);
                idx = *input as usize;
            }
            Op::Unwind {
                input, list_expr, ..
            } => {
                reject_repeated_operator(unwind_values.is_some(), "Unwind")?;
                if expand.is_none() {
                    has_post_expand_unwind = true;
                }
                unwind_values = Some(extract_unwind_values(list_expr)?);
                idx = *input as usize;
            }
            Op::PathConstruct { input, .. } => {
                path_construct = true;
                idx = *input as usize;
            }
            Op::Aggregate {
                input, keys, aggs, ..
            } => {
                reject_repeated_operator(aggregate.is_some(), "Aggregate")?;
                aggregate = Some(validate_supported_aggregate(keys, aggs)?);
                idx = *input as usize;
            }
            Op::ScanNodes { .. } => break,
            other => {
                return Err(ExplainError::UnsupportedSerializedOperator(
                    op_name(other).to_string(),
                ));
            }
        }
        if idx >= plan.ops.len() {
            return Err(ExplainError::SerializedPlanMalformed(format!(
                "invalid op reference index {}",
                idx
            )));
        }
    }
    op_chain.reverse();

    if expand.is_some() && has_post_expand_filter {
        return Err(ExplainError::UnsupportedSerializedOperator(
            "Filter above Expand (post-expand filtering) is not supported yet".to_string(),
        ));
    }
    if expand.is_some() && has_post_expand_unwind {
        return Err(ExplainError::UnsupportedSerializedOperator(
            "Unwind above Expand (post-expand unwind) is not supported yet".to_string(),
        ));
    }

    Ok(AdapterExecutionPlan {
        explain_plan: make_explain_plan(ExplainPlanRequest {
            logical_ops: op_chain,
            has_project,
            has_filter,
            limit,
            post_sort,
            expand,
            unwind_values: unwind_values.clone(),
            aggregate,
            scalar_predicate,
        }),
        expand,
        aggregate,
        unwind_values,
        path_construct,
        post_sort,
        // Saturate instead of truncating on targets where usize is narrower than u64.
        post_limit: limit.map(|value| usize::try_from(value).unwrap_or(usize::MAX)),
    })
}

/// Rejects a second occurrence of an operator the adapter plan can hold only once.
///
/// Without this check, the deeper op would silently overwrite the earlier one's
/// slot and drop its semantics.
fn reject_repeated_operator(already_present: bool, operator: &str) -> Result<()> {
    if already_present {
        return Err(ExplainError::UnsupportedSerializedOperator(format!(
            "multiple {} operators in one pipeline are not supported yet",
            operator
        )));
    }
    Ok(())
}
/// Validates a `Sort` op and maps it onto the single supported post-sort.
///
/// Only one sort key is supported, and it must be column index 0 (the node id).
///
/// # Errors
///
/// * [`ExplainError::SerializedPlanMalformed`] when `keys` and `dirs` differ
///   in length.
/// * [`ExplainError::UnsupportedSerializedOperator`] when there is not exactly
///   one key, or when that key is not column 0.
pub(super) fn validate_sort(keys: &[u32], dirs: &[SortDir]) -> Result<PostSort> {
    let (key_count, dir_count) = (keys.len(), dirs.len());
    if key_count != dir_count {
        return Err(ExplainError::SerializedPlanMalformed(format!(
            "Sort keys/dirs length mismatch ({} vs {})",
            key_count, dir_count
        )));
    }
    // The lengths are equal here, so either both slices hold exactly one
    // element or neither does.
    let ([key], [dir]) = (keys, dirs) else {
        return Err(ExplainError::UnsupportedSerializedOperator(
            "Sort currently supports exactly one key".to_string(),
        ));
    };
    if *key != 0 {
        return Err(ExplainError::UnsupportedSerializedOperator(
            "Sort currently supports only key index 0 (node id)".to_string(),
        ));
    }
    Ok(match dir {
        SortDir::Asc => PostSort::NodeIdAsc,
        SortDir::Desc => PostSort::NodeIdDesc,
    })
}
/// Checks that a `Project` expression list has a shape the adapter can execute.
///
/// An empty list is accepted. Otherwise the first expression must be
/// `ColRef(idx=0)`. Each trailing expression must be one of:
/// - `ColRef(idx=0)`;
/// - `PropAccess` on column 0 whose field `parse_supported_row_field` accepts;
/// - an int, float, bool, string, or null literal.
fn validate_supported_project_exprs(exprs: &[Expr]) -> Result<()> {
    let Some((first, rest)) = exprs.split_first() else {
        return Ok(());
    };
    if !matches!(first, Expr::ColRef { idx } if *idx == 0) {
        return Err(ExplainError::UnsupportedSerializedOperator(
            "Project currently requires first expression ColRef(idx=0)".to_string(),
        ));
    }
    rest.iter().try_for_each(|expr| -> Result<()> {
        match expr {
            Expr::ColRef { idx } if *idx == 0 => {}
            Expr::PropAccess { col, prop } if *col == 0 => {
                // Only validation matters here; the parsed field is discarded.
                parse_supported_row_field(prop)?;
            }
            Expr::IntLiteral(_)
            | Expr::FloatLiteral(_)
            | Expr::BoolLiteral(_)
            | Expr::StringLiteral(_)
            | Expr::NullLiteral => {}
            _ => {
                return Err(ExplainError::UnsupportedSerializedOperator(
                    "Project currently supports only ColRef(idx=0), PropAccess(col=0,<supported-field>), and literal trailing expressions".to_string(),
                ))
            }
        }
        Ok(())
    })
}
/// Checks the expand parameters shared by all expand-family operators.
///
/// The supported shape is: source column 0, no relationship-type filter, and
/// the `Out` direction. `op_name` is used only to label the error message.
fn validate_supported_expand_shape(
    src_col: u32,
    types: &[String],
    dir: ExpandDir,
    op_name: &str,
) -> Result<()> {
    let message = if src_col != 0 {
        format!("{} currently supports only src_col=0", op_name)
    } else if !types.is_empty() {
        format!(
            "{} currently supports only empty relationship type filter",
            op_name
        )
    } else if dir != ExpandDir::Out {
        format!("{} currently supports only Out direction", op_name)
    } else {
        return Ok(());
    };
    Err(ExplainError::UnsupportedSerializedOperator(message))
}
/// Extracts the values of an `Unwind` list expression.
///
/// The expression must be a `ListLiteral` whose items are all non-negative
/// `IntLiteral`s. The values are returned in list order as `u64`.
///
/// # Errors
///
/// [`ExplainError::UnsupportedSerializedOperator`] if the expression is not a
/// list literal, or if any item is not a non-negative integer literal.
pub(super) fn extract_unwind_values(expr: &Expr) -> Result<Vec<u64>> {
    match expr {
        Expr::ListLiteral { items } => items
            .iter()
            .map(|item| match item {
                // The guard guarantees the cast to u64 is lossless.
                Expr::IntLiteral(value) if *value >= 0 => Ok(*value as u64),
                _ => Err(ExplainError::UnsupportedSerializedOperator(
                    "Unwind list currently supports only non-negative IntLiteral items".to_string(),
                )),
            })
            .collect(),
        _ => Err(ExplainError::UnsupportedSerializedOperator(
            "Unwind currently supports only ListLiteral expressions".to_string(),
        )),
    }
}
/// Converts a `Filter` predicate into an [`ExecutableScalarPredicate`].
///
/// The predicate must have the form `Cmp(PropAccess(..), literal)`. The
/// property name becomes `field`, the comparison operator is rendered as a
/// string, and the literal is converted with `literal_as_f64`.
// NOTE(review): the PropAccess column index is ignored here. Confirm whether
// it should be restricted to col 0, as Project validation does.
fn extract_scalar_predicate(predicate: &Expr) -> Result<ExecutableScalarPredicate> {
    let (op, lhs, rhs) = match predicate {
        Expr::Cmp { op, lhs, rhs } => (*op, lhs, rhs),
        _ => {
            return Err(ExplainError::UnsupportedSerializedOperator(
                "Filter currently requires Cmp predicate".to_string(),
            ))
        }
    };
    let prop = match lhs.as_ref() {
        Expr::PropAccess { prop, .. } => prop,
        _ => {
            return Err(ExplainError::UnsupportedSerializedOperator(
                "Filter currently requires PropAccess on lhs".to_string(),
            ))
        }
    };
    let value = literal_as_f64(rhs)?;
    let operator = cmp_op_as_str(op).to_string();
    Ok(ExecutableScalarPredicate {
        field: prop.clone(),
        operator,
        value,
    })
}
/// Reads a numeric-like literal as `f64`.
///
/// Integers are widened with `as f64`, floats pass through unchanged, and
/// booleans map to `1.0` / `0.0`. Any other expression is rejected.
fn literal_as_f64(expr: &Expr) -> Result<f64> {
    let number = match *expr {
        Expr::IntLiteral(int_value) => int_value as f64,
        Expr::FloatLiteral(float_value) => float_value,
        Expr::BoolLiteral(flag) => f64::from(u8::from(flag)),
        _ => {
            return Err(ExplainError::UnsupportedSerializedOperator(
                "Filter rhs must be IntLiteral, FloatLiteral, or BoolLiteral".to_string(),
            ))
        }
    };
    Ok(number)
}
/// Renders a comparison operator as its textual symbol.
fn cmp_op_as_str(op: CmpOp) -> &'static str {
    match op {
        // Equality family.
        CmpOp::Eq => "=",
        CmpOp::Ne => "!=",
        // Less-than family.
        CmpOp::Lt => "<",
        CmpOp::Le => "<=",
        // Greater-than family.
        CmpOp::Gt => ">",
        CmpOp::Ge => ">=",
    }
}