Skip to main content

plexus_engine/capabilities/
validation.rs

1use std::collections::BTreeSet;
2
3use plexus_serde::{Expr, Op, Plan};
4
5use crate::capabilities::types::{
6    CapabilityError, EngineCapabilities, ExprKind, OpKind, RequiredCapabilities,
7};
8use crate::capabilities::wire::check_version_compat;
9
10pub fn required_capabilities(plan: &Plan) -> RequiredCapabilities {
11    let mut required_ops = BTreeSet::new();
12    let mut required_exprs = BTreeSet::new();
13
14    for op in &plan.ops {
15        collect_op_features(op, &mut required_ops, &mut required_exprs);
16    }
17
18    RequiredCapabilities {
19        plan_version: (&plan.version).into(),
20        required_ops,
21        required_exprs,
22    }
23}
24
25pub fn validate_plan_against_capabilities(
26    plan: &Plan,
27    capabilities: &EngineCapabilities,
28) -> Result<(), CapabilityError> {
29    check_version_compat(&plan.version, capabilities.version_range)?;
30    validate_graph_ref_support(plan, capabilities)?;
31
32    let required = required_capabilities(plan);
33    let missing_ops: Vec<_> = required
34        .required_ops
35        .difference(&capabilities.supported_ops)
36        .copied()
37        .collect();
38    let missing_exprs: Vec<_> = required
39        .required_exprs
40        .difference(&capabilities.supported_exprs)
41        .copied()
42        .collect();
43
44    if missing_ops.is_empty() && missing_exprs.is_empty() {
45        return Ok(());
46    }
47
48    Err(CapabilityError::MissingFeatureSupport {
49        missing_ops,
50        missing_exprs,
51    })
52}
53
54fn validate_graph_ref_support(
55    plan: &Plan,
56    capabilities: &EngineCapabilities,
57) -> Result<(), CapabilityError> {
58    let mut refs = BTreeSet::<String>::new();
59    for op in &plan.ops {
60        if let Some(graph_ref) = op_graph_ref(op) {
61            refs.insert(graph_ref.to_string());
62        }
63    }
64
65    if refs.is_empty() {
66        return Ok(());
67    }
68
69    if !capabilities.supports_graph_ref {
70        return Err(CapabilityError::GraphRefUnsupported);
71    }
72
73    // Graph parameter variables ($g) require explicit engine support.
74    if refs.iter().any(|r| r.starts_with('$')) && !capabilities.supports_graph_params {
75        return Err(CapabilityError::GraphParamUnsupported);
76    }
77
78    if refs.len() > 1 && !capabilities.supports_multi_graph {
79        return Err(CapabilityError::MultiGraphUnsupported);
80    }
81
82    Ok(())
83}
84
85fn op_graph_ref(op: &Op) -> Option<&str> {
86    let maybe = match op {
87        Op::ScanNodes { graph_ref, .. }
88        | Op::Expand { graph_ref, .. }
89        | Op::OptionalExpand { graph_ref, .. }
90        | Op::ExpandVarLen { graph_ref, .. } => graph_ref.as_deref(),
91        _ => None,
92    };
93    maybe.map(str::trim).filter(|s| !s.is_empty())
94}
95
96fn collect_op_features(
97    op: &Op,
98    required_ops: &mut BTreeSet<OpKind>,
99    required_exprs: &mut BTreeSet<ExprKind>,
100) {
101    match op {
102        Op::ScanNodes { .. } => {
103            required_ops.insert(OpKind::ScanNodes);
104        }
105        Op::ScanRels { .. } => {
106            required_ops.insert(OpKind::ScanRels);
107        }
108        Op::Expand { .. } => {
109            required_ops.insert(OpKind::Expand);
110        }
111        Op::OptionalExpand { .. } => {
112            required_ops.insert(OpKind::OptionalExpand);
113        }
114        Op::SemiExpand { .. } => {
115            required_ops.insert(OpKind::SemiExpand);
116        }
117        Op::ExpandVarLen { .. } => {
118            required_ops.insert(OpKind::ExpandVarLen);
119        }
120        Op::Filter { predicate, .. } => {
121            required_ops.insert(OpKind::Filter);
122            collect_expr_features(predicate, required_exprs);
123        }
124        Op::BlockMarker { .. } => {
125            required_ops.insert(OpKind::BlockMarker);
126        }
127        Op::Project { exprs, .. } => {
128            required_ops.insert(OpKind::Project);
129            for expr in exprs {
130                collect_expr_features(expr, required_exprs);
131            }
132        }
133        Op::Aggregate { aggs, .. } => {
134            required_ops.insert(OpKind::Aggregate);
135            for agg in aggs {
136                collect_expr_features(agg, required_exprs);
137            }
138        }
139        Op::Sort { .. } => {
140            required_ops.insert(OpKind::Sort);
141        }
142        Op::Limit { .. } => {
143            required_ops.insert(OpKind::Limit);
144        }
145        Op::Unwind { list_expr, .. } => {
146            required_ops.insert(OpKind::Unwind);
147            collect_expr_features(list_expr, required_exprs);
148        }
149        Op::PathConstruct { .. } => {
150            required_ops.insert(OpKind::PathConstruct);
151        }
152        Op::Union { .. } => {
153            required_ops.insert(OpKind::Union);
154        }
155        Op::CreateNode { props, .. } => {
156            required_ops.insert(OpKind::CreateNode);
157            collect_expr_features(props, required_exprs);
158        }
159        Op::CreateRel { props, .. } => {
160            required_ops.insert(OpKind::CreateRel);
161            collect_expr_features(props, required_exprs);
162        }
163        Op::Merge {
164            pattern,
165            on_create_props,
166            on_match_props,
167            ..
168        } => {
169            required_ops.insert(OpKind::Merge);
170            collect_expr_features(pattern, required_exprs);
171            collect_expr_features(on_create_props, required_exprs);
172            collect_expr_features(on_match_props, required_exprs);
173        }
174        Op::Delete { .. } => {
175            required_ops.insert(OpKind::Delete);
176        }
177        Op::SetProperty { value_expr, .. } => {
178            required_ops.insert(OpKind::SetProperty);
179            collect_expr_features(value_expr, required_exprs);
180        }
181        Op::RemoveProperty { .. } => {
182            required_ops.insert(OpKind::RemoveProperty);
183        }
184        Op::VectorScan { query_vector, .. } => {
185            required_ops.insert(OpKind::VectorScan);
186            collect_expr_features(query_vector, required_exprs);
187        }
188        Op::Rerank { score_expr, .. } => {
189            required_ops.insert(OpKind::Rerank);
190            collect_expr_features(score_expr, required_exprs);
191        }
192        Op::Return { .. } => {
193            required_ops.insert(OpKind::Return);
194        }
195        Op::ConstRow => {
196            required_ops.insert(OpKind::ConstRow);
197        }
198    }
199}
200
201fn collect_expr_features(expr: &Expr, required_exprs: &mut BTreeSet<ExprKind>) {
202    match expr {
203        Expr::ColRef { .. } => {
204            required_exprs.insert(ExprKind::ColRef);
205        }
206        Expr::PropAccess { .. } => {
207            required_exprs.insert(ExprKind::PropAccess);
208        }
209        Expr::IntLiteral(_) => {
210            required_exprs.insert(ExprKind::IntLiteral);
211        }
212        Expr::FloatLiteral(_) => {
213            required_exprs.insert(ExprKind::FloatLiteral);
214        }
215        Expr::BoolLiteral(_) => {
216            required_exprs.insert(ExprKind::BoolLiteral);
217        }
218        Expr::StringLiteral(_) => {
219            required_exprs.insert(ExprKind::StringLiteral);
220        }
221        Expr::NullLiteral => {
222            required_exprs.insert(ExprKind::NullLiteral);
223        }
224        Expr::Cmp { lhs, rhs, .. } => {
225            required_exprs.insert(ExprKind::Cmp);
226            collect_expr_features(lhs, required_exprs);
227            collect_expr_features(rhs, required_exprs);
228        }
229        Expr::And { lhs, rhs } => {
230            required_exprs.insert(ExprKind::And);
231            collect_expr_features(lhs, required_exprs);
232            collect_expr_features(rhs, required_exprs);
233        }
234        Expr::Or { lhs, rhs } => {
235            required_exprs.insert(ExprKind::Or);
236            collect_expr_features(lhs, required_exprs);
237            collect_expr_features(rhs, required_exprs);
238        }
239        Expr::Not { expr } => {
240            required_exprs.insert(ExprKind::Not);
241            collect_expr_features(expr, required_exprs);
242        }
243        Expr::IsNull { expr } => {
244            required_exprs.insert(ExprKind::IsNull);
245            collect_expr_features(expr, required_exprs);
246        }
247        Expr::IsNotNull { expr } => {
248            required_exprs.insert(ExprKind::IsNotNull);
249            collect_expr_features(expr, required_exprs);
250        }
251        Expr::StartsWith { expr, .. } => {
252            required_exprs.insert(ExprKind::StartsWith);
253            collect_expr_features(expr, required_exprs);
254        }
255        Expr::EndsWith { expr, .. } => {
256            required_exprs.insert(ExprKind::EndsWith);
257            collect_expr_features(expr, required_exprs);
258        }
259        Expr::Contains { expr, .. } => {
260            required_exprs.insert(ExprKind::Contains);
261            collect_expr_features(expr, required_exprs);
262        }
263        Expr::In { expr, items } => {
264            required_exprs.insert(ExprKind::In);
265            collect_expr_features(expr, required_exprs);
266            for item in items {
267                collect_expr_features(item, required_exprs);
268            }
269        }
270        Expr::ListLiteral { items } => {
271            required_exprs.insert(ExprKind::ListLiteral);
272            for item in items {
273                collect_expr_features(item, required_exprs);
274            }
275        }
276        Expr::MapLiteral { entries } => {
277            required_exprs.insert(ExprKind::MapLiteral);
278            for (_, value) in entries {
279                collect_expr_features(value, required_exprs);
280            }
281        }
282        Expr::Exists { expr } => {
283            required_exprs.insert(ExprKind::Exists);
284            collect_expr_features(expr, required_exprs);
285        }
286        Expr::ListComprehension {
287            list,
288            predicate,
289            map,
290            ..
291        } => {
292            required_exprs.insert(ExprKind::ListComprehension);
293            collect_expr_features(list, required_exprs);
294            if let Some(pred) = predicate {
295                collect_expr_features(pred, required_exprs);
296            }
297            collect_expr_features(map, required_exprs);
298        }
299        Expr::Agg { expr, .. } => {
300            required_exprs.insert(ExprKind::Agg);
301            if let Some(inner) = expr {
302                collect_expr_features(inner, required_exprs);
303            }
304        }
305        Expr::Arith { lhs, rhs, .. } => {
306            required_exprs.insert(ExprKind::Arith);
307            collect_expr_features(lhs, required_exprs);
308            collect_expr_features(rhs, required_exprs);
309        }
310        Expr::Param { .. } => {
311            required_exprs.insert(ExprKind::Param);
312        }
313        Expr::Case { arms, else_expr } => {
314            required_exprs.insert(ExprKind::Case);
315            for (when_expr, then_expr) in arms {
316                collect_expr_features(when_expr, required_exprs);
317                collect_expr_features(then_expr, required_exprs);
318            }
319            if let Some(e) = else_expr {
320                collect_expr_features(e, required_exprs);
321            }
322        }
323        Expr::VectorSimilarity { lhs, rhs, .. } => {
324            required_exprs.insert(ExprKind::VectorSimilarity);
325            collect_expr_features(lhs, required_exprs);
326            collect_expr_features(rhs, required_exprs);
327        }
328    }
329}