1use std::cmp::Ordering;
2
3use plexus_serde::{CmpOp, ExpandDir, Expr, Op, Plan, SortDir};
4
5use crate::{ExecutionError, Graph, Node, PlanEngine, QueryResult, Relationship, Row, Value};
6
7type RowSet = Vec<Row>;
8
9struct ExpandSpec<'a> {
10 src_col: u32,
11 types: &'a [String],
12 dir: ExpandDir,
13 legal_src_labels: &'a [String],
14 legal_dst_labels: &'a [String],
15 optional: bool,
16}
17
18#[derive(Debug, Clone)]
22pub struct IndependentConsumerEngine {
23 graph: Graph,
24}
25
26impl IndependentConsumerEngine {
27 pub fn new(graph: Graph) -> Self {
28 Self { graph }
29 }
30}
31
32impl PlanEngine for IndependentConsumerEngine {
33 type Error = ExecutionError;
34
35 fn execute_plan(&mut self, plan: &Plan) -> Result<QueryResult, Self::Error> {
36 let mut outputs = Vec::<RowSet>::with_capacity(plan.ops.len());
37 for op in &plan.ops {
38 let rows = match op {
39 Op::ScanNodes {
40 labels,
41 must_labels,
42 forbidden_labels,
43 ..
44 } => self.scan_nodes(labels, must_labels, forbidden_labels),
45 Op::Expand {
46 input,
47 src_col,
48 types,
49 dir,
50 legal_src_labels,
51 legal_dst_labels,
52 ..
53 } => self.expand(
54 get_output(&outputs, *input)?,
55 ExpandSpec {
56 src_col: *src_col,
57 types,
58 dir: *dir,
59 legal_src_labels,
60 legal_dst_labels,
61 optional: false,
62 },
63 )?,
64 Op::OptionalExpand {
65 input,
66 src_col,
67 types,
68 dir,
69 legal_src_labels,
70 legal_dst_labels,
71 ..
72 } => self.expand(
73 get_output(&outputs, *input)?,
74 ExpandSpec {
75 src_col: *src_col,
76 types,
77 dir: *dir,
78 legal_src_labels,
79 legal_dst_labels,
80 optional: true,
81 },
82 )?,
83 Op::Filter { input, predicate } => {
84 self.filter(get_output(&outputs, *input)?, predicate)?
85 }
86 Op::Project { input, exprs, .. } => {
87 self.project(get_output(&outputs, *input)?, exprs)?
88 }
89 Op::Sort { input, keys, dirs } => {
90 self.sort(get_output(&outputs, *input)?, keys, dirs)?
91 }
92 Op::Return { input } => get_output(&outputs, *input)?.clone(),
93 _ => {
94 return Err(ExecutionError::UnsupportedOp(
95 "independent consumer proof subset",
96 ))
97 }
98 };
99 outputs.push(rows);
100 }
101
102 let Some(rows) = outputs.get(plan.root_op as usize) else {
103 return Err(ExecutionError::InvalidRootOp(plan.root_op));
104 };
105 Ok(QueryResult { rows: rows.clone() })
106 }
107}
108
109pub fn proof_fixture_graph() -> Graph {
110 let node = |id: u64, labels: &[&str], props: &[(&str, Value)]| Node {
111 id,
112 labels: labels.iter().map(|label| (*label).to_string()).collect(),
113 props: props
114 .iter()
115 .map(|(key, value)| ((*key).to_string(), value.clone()))
116 .collect(),
117 };
118 let rel = |id: u64, src: u64, dst: u64, typ: &str| Relationship {
119 id,
120 src,
121 dst,
122 typ: typ.to_string(),
123 props: Default::default(),
124 };
125
126 Graph {
127 nodes: vec![
128 node(
129 1,
130 &["Person"],
131 &[
132 ("name", Value::String("Alice".to_string())),
133 ("age", Value::Int(30)),
134 ],
135 ),
136 node(
137 2,
138 &["Person"],
139 &[
140 ("name", Value::String("Bob".to_string())),
141 ("age", Value::Int(40)),
142 ],
143 ),
144 node(
145 3,
146 &["Company"],
147 &[("name", Value::String("Acme".to_string()))],
148 ),
149 ],
150 rels: vec![
151 rel(10, 1, 2, "KNOWS"),
152 rel(11, 2, 1, "KNOWS"),
153 rel(12, 2, 3, "WORKS_AT"),
154 ],
155 }
156}
157
158impl IndependentConsumerEngine {
159 fn scan_nodes(
160 &self,
161 labels: &[String],
162 must_labels: &[String],
163 forbidden_labels: &[String],
164 ) -> RowSet {
165 self.graph
166 .nodes
167 .iter()
168 .filter(|node| {
169 labels.iter().all(|label| node.labels.contains(label))
170 && must_labels.iter().all(|label| node.labels.contains(label))
171 && forbidden_labels
172 .iter()
173 .all(|label| !node.labels.contains(label))
174 })
175 .map(|node| vec![Value::NodeRef(node.id)])
176 .collect()
177 }
178
179 fn expand(&self, input: &[Row], spec: ExpandSpec<'_>) -> Result<RowSet, ExecutionError> {
180 let mut out = Vec::new();
181 for row in input {
182 let Some(value) = row.get(spec.src_col as usize) else {
183 return Err(ExecutionError::ColumnOutOfBounds {
184 idx: spec.src_col as usize,
185 len: row.len(),
186 });
187 };
188 let Value::NodeRef(src_id) = value else {
189 return Err(ExecutionError::ExpectedNodeRef {
190 idx: spec.src_col as usize,
191 });
192 };
193 let src_node = self
194 .graph
195 .node_by_id(*src_id)
196 .ok_or(ExecutionError::UnknownNode(*src_id))?;
197 if !labels_match(src_node, spec.legal_src_labels) {
198 continue;
199 }
200
201 let mut matched = false;
202 for rel in &self.graph.rels {
203 if !spec.types.is_empty() && !spec.types.iter().any(|typ| typ == &rel.typ) {
204 continue;
205 }
206 if let Some(dst_id) = relation_endpoint(rel, *src_id, spec.dir) {
207 let dst_node = self
208 .graph
209 .node_by_id(dst_id)
210 .ok_or(ExecutionError::UnknownNode(dst_id))?;
211 if !labels_match(dst_node, spec.legal_dst_labels) {
212 continue;
213 }
214 let mut next = row.clone();
215 next.push(Value::RelRef(rel.id));
216 next.push(Value::NodeRef(dst_id));
217 out.push(next);
218 matched = true;
219 }
220 }
221
222 if spec.optional && !matched {
223 let mut next = row.clone();
224 next.push(Value::Null);
225 next.push(Value::Null);
226 out.push(next);
227 }
228 }
229 Ok(out)
230 }
231
232 fn filter(&self, input: &[Row], predicate: &Expr) -> Result<RowSet, ExecutionError> {
233 let mut out = Vec::new();
234 for row in input {
235 if matches!(self.eval_expr(row, predicate)?, Value::Bool(true)) {
236 out.push(row.clone());
237 }
238 }
239 Ok(out)
240 }
241
242 fn project(&self, input: &[Row], exprs: &[Expr]) -> Result<RowSet, ExecutionError> {
243 input
244 .iter()
245 .map(|row| {
246 exprs
247 .iter()
248 .map(|expr| self.eval_expr(row, expr))
249 .collect::<Result<Row, _>>()
250 })
251 .collect()
252 }
253
254 fn sort(
255 &self,
256 input: &[Row],
257 keys: &[u32],
258 dirs: &[SortDir],
259 ) -> Result<RowSet, ExecutionError> {
260 if keys.len() != dirs.len() {
261 return Err(ExecutionError::SortArityMismatch {
262 keys: keys.len(),
263 dirs: dirs.len(),
264 });
265 }
266 let mut out = input.to_vec();
267 out.sort_by(|lhs, rhs| compare_rows(lhs, rhs, keys, dirs));
268 Ok(out)
269 }
270
271 fn eval_expr(&self, row: &Row, expr: &Expr) -> Result<Value, ExecutionError> {
272 match expr {
273 Expr::ColRef { idx } => {
274 row.get(*idx as usize)
275 .cloned()
276 .ok_or(ExecutionError::ColumnOutOfBounds {
277 idx: *idx as usize,
278 len: row.len(),
279 })
280 }
281 Expr::PropAccess { col, prop } => {
282 let Some(value) = row.get(*col as usize) else {
283 return Err(ExecutionError::ColumnOutOfBounds {
284 idx: *col as usize,
285 len: row.len(),
286 });
287 };
288 self.property_access(value, prop)
289 }
290 Expr::IntLiteral(value) => Ok(Value::Int(*value)),
291 Expr::FloatLiteral(value) => Ok(Value::Float(*value)),
292 Expr::BoolLiteral(value) => Ok(Value::Bool(*value)),
293 Expr::StringLiteral(value) => Ok(Value::String(value.clone())),
294 Expr::NullLiteral => Ok(Value::Null),
295 Expr::Cmp { op, lhs, rhs } => {
296 let lhs = self.eval_expr(row, lhs)?;
297 let rhs = self.eval_expr(row, rhs)?;
298 Ok(compare_expr_values(*op, lhs, rhs))
299 }
300 _ => Err(ExecutionError::UnsupportedExpr(
301 "independent consumer proof subset",
302 )),
303 }
304 }
305
306 fn property_access(&self, value: &Value, prop: &str) -> Result<Value, ExecutionError> {
307 match value {
308 Value::Null => Ok(Value::Null),
309 Value::NodeRef(id) => Ok(self
310 .graph
311 .node_by_id(*id)
312 .ok_or(ExecutionError::UnknownNode(*id))?
313 .props
314 .get(prop)
315 .cloned()
316 .unwrap_or(Value::Null)),
317 Value::RelRef(id) => Ok(self
318 .graph
319 .rel_by_id(*id)
320 .ok_or(ExecutionError::UnknownRel(*id))?
321 .props
322 .get(prop)
323 .cloned()
324 .unwrap_or(Value::Null)),
325 Value::Map(entries) => Ok(entries.get(prop).cloned().unwrap_or(Value::Null)),
326 _ => Ok(Value::Null),
327 }
328 }
329}
330
331fn get_output(outputs: &[RowSet], idx: u32) -> Result<&RowSet, ExecutionError> {
332 outputs
333 .get(idx as usize)
334 .ok_or(ExecutionError::MissingOpOutput(idx))
335}
336
337fn labels_match(node: &Node, required: &[String]) -> bool {
338 required.is_empty() || required.iter().all(|label| node.labels.contains(label))
339}
340
341fn relation_endpoint(rel: &Relationship, src_id: u64, dir: ExpandDir) -> Option<u64> {
342 match dir {
343 ExpandDir::Out if rel.src == src_id => Some(rel.dst),
344 ExpandDir::In if rel.dst == src_id => Some(rel.src),
345 ExpandDir::Both if rel.src == src_id => Some(rel.dst),
346 ExpandDir::Both if rel.dst == src_id => Some(rel.src),
347 _ => None,
348 }
349}
350
351fn compare_rows(lhs: &Row, rhs: &Row, keys: &[u32], dirs: &[SortDir]) -> Ordering {
352 for (key, dir) in keys.iter().zip(dirs) {
353 let lhs_value = lhs.get(*key as usize).unwrap_or(&Value::Null);
354 let rhs_value = rhs.get(*key as usize).unwrap_or(&Value::Null);
355 let ordering = compare_values(lhs_value, rhs_value);
356 if ordering != Ordering::Equal {
357 return match dir {
358 SortDir::Asc => ordering,
359 SortDir::Desc => ordering.reverse(),
360 };
361 }
362 }
363 Ordering::Equal
364}
365
366fn compare_expr_values(op: CmpOp, lhs: Value, rhs: Value) -> Value {
367 if matches!(lhs, Value::Null) || matches!(rhs, Value::Null) {
368 return Value::Null;
369 }
370
371 let ordering = compare_values(&lhs, &rhs);
372 let result = match op {
373 CmpOp::Eq => lhs == rhs,
374 CmpOp::Ne => lhs != rhs,
375 CmpOp::Lt => ordering == Ordering::Less,
376 CmpOp::Gt => ordering == Ordering::Greater,
377 CmpOp::Le => ordering != Ordering::Greater,
378 CmpOp::Ge => ordering != Ordering::Less,
379 };
380 Value::Bool(result)
381}
382
383fn compare_values(lhs: &Value, rhs: &Value) -> Ordering {
384 match (lhs, rhs) {
385 (Value::Null, Value::Null) => Ordering::Equal,
386 (Value::Null, _) => Ordering::Less,
387 (_, Value::Null) => Ordering::Greater,
388 (Value::Bool(lhs), Value::Bool(rhs)) => lhs.cmp(rhs),
389 (Value::Int(lhs), Value::Int(rhs)) => lhs.cmp(rhs),
390 (Value::Float(lhs), Value::Float(rhs)) => lhs.partial_cmp(rhs).unwrap_or(Ordering::Equal),
391 (Value::Int(lhs), Value::Float(rhs)) => {
392 (*lhs as f64).partial_cmp(rhs).unwrap_or(Ordering::Equal)
393 }
394 (Value::Float(lhs), Value::Int(rhs)) => {
395 lhs.partial_cmp(&(*rhs as f64)).unwrap_or(Ordering::Equal)
396 }
397 (Value::String(lhs), Value::String(rhs)) => lhs.cmp(rhs),
398 (Value::NodeRef(lhs), Value::NodeRef(rhs)) => lhs.cmp(rhs),
399 (Value::RelRef(lhs), Value::RelRef(rhs)) => lhs.cmp(rhs),
400 _ => value_rank(lhs).cmp(&value_rank(rhs)),
401 }
402}
403
404fn value_rank(value: &Value) -> u8 {
405 match value {
406 Value::Null => 0,
407 Value::Bool(_) => 1,
408 Value::Int(_) | Value::Float(_) => 2,
409 Value::String(_) => 3,
410 Value::NodeRef(_) => 4,
411 Value::RelRef(_) => 5,
412 Value::List(_) => 6,
413 Value::Map(_) => 7,
414 }
415}