use std::collections::BTreeSet;
use plexus_serde::{Expr, Op, Plan};
use crate::capabilities::types::{
CapabilityError, EngineCapabilities, ExprKind, OpKind, RequiredCapabilities,
};
use crate::capabilities::wire::check_version_compat;
/// Computes the set of operator and expression kinds that `plan` relies on.
///
/// Every operator in `plan.ops` is inspected, including the expressions it
/// carries. The resulting sets are `BTreeSet`s, so they are deduplicated and
/// iterate in sorted order. The plan's version is converted into the
/// `plan_version` field via its `From<&_>` conversion.
pub fn required_capabilities(plan: &Plan) -> RequiredCapabilities {
    let mut ops = BTreeSet::new();
    let mut exprs = BTreeSet::new();
    plan.ops
        .iter()
        .for_each(|op| collect_op_features(op, &mut ops, &mut exprs));

    RequiredCapabilities {
        plan_version: (&plan.version).into(),
        required_ops: ops,
        required_exprs: exprs,
    }
}
pub fn validate_plan_against_capabilities(
plan: &Plan,
capabilities: &EngineCapabilities,
) -> Result<(), CapabilityError> {
check_version_compat(&plan.version, capabilities.version_range)?;
validate_graph_ref_support(plan, capabilities)?;
let required = required_capabilities(plan);
let missing_ops: Vec<_> = required
.required_ops
.difference(&capabilities.supported_ops)
.copied()
.collect();
let missing_exprs: Vec<_> = required
.required_exprs
.difference(&capabilities.supported_exprs)
.copied()
.collect();
if missing_ops.is_empty() && missing_exprs.is_empty() {
return Ok(());
}
Err(CapabilityError::MissingFeatureSupport {
missing_ops,
missing_exprs,
})
}
fn validate_graph_ref_support(
plan: &Plan,
capabilities: &EngineCapabilities,
) -> Result<(), CapabilityError> {
let mut refs = BTreeSet::<String>::new();
for op in &plan.ops {
if let Some(graph_ref) = op_graph_ref(op) {
refs.insert(graph_ref.to_string());
}
}
if refs.is_empty() {
return Ok(());
}
if !capabilities.supports_graph_ref {
return Err(CapabilityError::GraphRefUnsupported);
}
if refs.iter().any(|r| r.starts_with('$')) && !capabilities.supports_graph_params {
return Err(CapabilityError::GraphParamUnsupported);
}
if refs.len() > 1 && !capabilities.supports_multi_graph {
return Err(CapabilityError::MultiGraphUnsupported);
}
Ok(())
}
/// Returns the explicit graph reference attached to `op`, if there is one.
///
/// Only the `graph_ref` fields of `ScanNodes`, `Expand`, `OptionalExpand` and
/// `ExpandVarLen` are inspected. The value is trimmed, and a blank or
/// whitespace-only ref is treated as absent (`None`).
///
/// The match is deliberately exhaustive, with no `_` arm. Adding a new `Op`
/// variant then fails to compile here until someone decides whether it is
/// graph-scoped. A wildcard would silently exempt it from graph-ref capability
/// validation.
/// NOTE(review): confirm that `ScanRels` and `SemiExpand` really carry no graph ref.
fn op_graph_ref(op: &Op) -> Option<&str> {
    let raw = match op {
        Op::ScanNodes { graph_ref, .. }
        | Op::Expand { graph_ref, .. }
        | Op::OptionalExpand { graph_ref, .. }
        | Op::ExpandVarLen { graph_ref, .. } => graph_ref.as_deref(),
        Op::ScanRels { .. }
        | Op::SemiExpand { .. }
        | Op::Filter { .. }
        | Op::BlockMarker { .. }
        | Op::Project { .. }
        | Op::Aggregate { .. }
        | Op::Sort { .. }
        | Op::Limit { .. }
        | Op::Unwind { .. }
        | Op::PathConstruct { .. }
        | Op::Union { .. }
        | Op::CreateNode { .. }
        | Op::CreateRel { .. }
        | Op::Merge { .. }
        | Op::Delete { .. }
        | Op::SetProperty { .. }
        | Op::RemoveProperty { .. }
        | Op::VectorScan { .. }
        | Op::Rerank { .. }
        | Op::Return { .. }
        | Op::ConstRow => None,
    };
    raw.map(str::trim).filter(|s| !s.is_empty())
}
fn collect_op_features(
op: &Op,
required_ops: &mut BTreeSet<OpKind>,
required_exprs: &mut BTreeSet<ExprKind>,
) {
match op {
Op::ScanNodes { .. } => {
required_ops.insert(OpKind::ScanNodes);
}
Op::ScanRels { .. } => {
required_ops.insert(OpKind::ScanRels);
}
Op::Expand { .. } => {
required_ops.insert(OpKind::Expand);
}
Op::OptionalExpand { .. } => {
required_ops.insert(OpKind::OptionalExpand);
}
Op::SemiExpand { .. } => {
required_ops.insert(OpKind::SemiExpand);
}
Op::ExpandVarLen { .. } => {
required_ops.insert(OpKind::ExpandVarLen);
}
Op::Filter { predicate, .. } => {
required_ops.insert(OpKind::Filter);
collect_expr_features(predicate, required_exprs);
}
Op::BlockMarker { .. } => {
required_ops.insert(OpKind::BlockMarker);
}
Op::Project { exprs, .. } => {
required_ops.insert(OpKind::Project);
for expr in exprs {
collect_expr_features(expr, required_exprs);
}
}
Op::Aggregate { aggs, .. } => {
required_ops.insert(OpKind::Aggregate);
for agg in aggs {
collect_expr_features(agg, required_exprs);
}
}
Op::Sort { .. } => {
required_ops.insert(OpKind::Sort);
}
Op::Limit { .. } => {
required_ops.insert(OpKind::Limit);
}
Op::Unwind { list_expr, .. } => {
required_ops.insert(OpKind::Unwind);
collect_expr_features(list_expr, required_exprs);
}
Op::PathConstruct { .. } => {
required_ops.insert(OpKind::PathConstruct);
}
Op::Union { .. } => {
required_ops.insert(OpKind::Union);
}
Op::CreateNode { props, .. } => {
required_ops.insert(OpKind::CreateNode);
collect_expr_features(props, required_exprs);
}
Op::CreateRel { props, .. } => {
required_ops.insert(OpKind::CreateRel);
collect_expr_features(props, required_exprs);
}
Op::Merge {
pattern,
on_create_props,
on_match_props,
..
} => {
required_ops.insert(OpKind::Merge);
collect_expr_features(pattern, required_exprs);
collect_expr_features(on_create_props, required_exprs);
collect_expr_features(on_match_props, required_exprs);
}
Op::Delete { .. } => {
required_ops.insert(OpKind::Delete);
}
Op::SetProperty { value_expr, .. } => {
required_ops.insert(OpKind::SetProperty);
collect_expr_features(value_expr, required_exprs);
}
Op::RemoveProperty { .. } => {
required_ops.insert(OpKind::RemoveProperty);
}
Op::VectorScan { query_vector, .. } => {
required_ops.insert(OpKind::VectorScan);
collect_expr_features(query_vector, required_exprs);
}
Op::Rerank { score_expr, .. } => {
required_ops.insert(OpKind::Rerank);
collect_expr_features(score_expr, required_exprs);
}
Op::Return { .. } => {
required_ops.insert(OpKind::Return);
}
Op::ConstRow => {
required_ops.insert(OpKind::ConstRow);
}
}
}
fn collect_expr_features(expr: &Expr, required_exprs: &mut BTreeSet<ExprKind>) {
match expr {
Expr::ColRef { .. } => {
required_exprs.insert(ExprKind::ColRef);
}
Expr::PropAccess { .. } => {
required_exprs.insert(ExprKind::PropAccess);
}
Expr::IntLiteral(_) => {
required_exprs.insert(ExprKind::IntLiteral);
}
Expr::FloatLiteral(_) => {
required_exprs.insert(ExprKind::FloatLiteral);
}
Expr::BoolLiteral(_) => {
required_exprs.insert(ExprKind::BoolLiteral);
}
Expr::StringLiteral(_) => {
required_exprs.insert(ExprKind::StringLiteral);
}
Expr::NullLiteral => {
required_exprs.insert(ExprKind::NullLiteral);
}
Expr::Cmp { lhs, rhs, .. } => {
required_exprs.insert(ExprKind::Cmp);
collect_expr_features(lhs, required_exprs);
collect_expr_features(rhs, required_exprs);
}
Expr::And { lhs, rhs } => {
required_exprs.insert(ExprKind::And);
collect_expr_features(lhs, required_exprs);
collect_expr_features(rhs, required_exprs);
}
Expr::Or { lhs, rhs } => {
required_exprs.insert(ExprKind::Or);
collect_expr_features(lhs, required_exprs);
collect_expr_features(rhs, required_exprs);
}
Expr::Not { expr } => {
required_exprs.insert(ExprKind::Not);
collect_expr_features(expr, required_exprs);
}
Expr::IsNull { expr } => {
required_exprs.insert(ExprKind::IsNull);
collect_expr_features(expr, required_exprs);
}
Expr::IsNotNull { expr } => {
required_exprs.insert(ExprKind::IsNotNull);
collect_expr_features(expr, required_exprs);
}
Expr::StartsWith { expr, .. } => {
required_exprs.insert(ExprKind::StartsWith);
collect_expr_features(expr, required_exprs);
}
Expr::EndsWith { expr, .. } => {
required_exprs.insert(ExprKind::EndsWith);
collect_expr_features(expr, required_exprs);
}
Expr::Contains { expr, .. } => {
required_exprs.insert(ExprKind::Contains);
collect_expr_features(expr, required_exprs);
}
Expr::In { expr, items } => {
required_exprs.insert(ExprKind::In);
collect_expr_features(expr, required_exprs);
for item in items {
collect_expr_features(item, required_exprs);
}
}
Expr::ListLiteral { items } => {
required_exprs.insert(ExprKind::ListLiteral);
for item in items {
collect_expr_features(item, required_exprs);
}
}
Expr::MapLiteral { entries } => {
required_exprs.insert(ExprKind::MapLiteral);
for (_, value) in entries {
collect_expr_features(value, required_exprs);
}
}
Expr::Exists { expr } => {
required_exprs.insert(ExprKind::Exists);
collect_expr_features(expr, required_exprs);
}
Expr::ListComprehension {
list,
predicate,
map,
..
} => {
required_exprs.insert(ExprKind::ListComprehension);
collect_expr_features(list, required_exprs);
if let Some(pred) = predicate {
collect_expr_features(pred, required_exprs);
}
collect_expr_features(map, required_exprs);
}
Expr::Agg { expr, .. } => {
required_exprs.insert(ExprKind::Agg);
if let Some(inner) = expr {
collect_expr_features(inner, required_exprs);
}
}
Expr::Arith { lhs, rhs, .. } => {
required_exprs.insert(ExprKind::Arith);
collect_expr_features(lhs, required_exprs);
collect_expr_features(rhs, required_exprs);
}
Expr::Param { .. } => {
required_exprs.insert(ExprKind::Param);
}
Expr::Case { arms, else_expr } => {
required_exprs.insert(ExprKind::Case);
for (when_expr, then_expr) in arms {
collect_expr_features(when_expr, required_exprs);
collect_expr_features(then_expr, required_exprs);
}
if let Some(e) = else_expr {
collect_expr_features(e, required_exprs);
}
}
Expr::VectorSimilarity { lhs, rhs, .. } => {
required_exprs.insert(ExprKind::VectorSimilarity);
collect_expr_features(lhs, required_exprs);
collect_expr_features(rhs, required_exprs);
}
}
}