1use std::cmp::Ordering;
2
3use plexus_serde::{CmpOp, ExpandDir, Expr, Op, Plan, SortDir};
4
5use crate::{ExecutionError, Graph, Node, PlanEngine, QueryResult, Relationship, Row, Value};
6
/// The rows produced by one plan operator; each op's output is stored as one `RowSet`.
type RowSet = Vec<Row>;
8
9struct ExpandSpec<'a> {
10 src_col: u32,
11 types: &'a [String],
12 dir: ExpandDir,
13 legal_src_labels: &'a [String],
14 legal_dst_labels: &'a [String],
15 optional: bool,
16}
17
18#[derive(Debug, Clone)]
22pub struct IndependentConsumerEngine {
23 graph: Graph,
24}
25
26impl IndependentConsumerEngine {
27 pub fn new(graph: Graph) -> Self {
28 Self { graph }
29 }
30}
31
32impl PlanEngine for IndependentConsumerEngine {
33 type Error = ExecutionError;
34
35 fn execute_plan(&mut self, plan: &Plan) -> Result<QueryResult, Self::Error> {
36 let mut outputs = Vec::<RowSet>::with_capacity(plan.ops.len());
37 for op in &plan.ops {
38 let rows = match op {
39 Op::ScanNodes {
40 labels,
41 must_labels,
42 forbidden_labels,
43 ..
44 } => self.scan_nodes(labels, must_labels, forbidden_labels),
45 Op::Expand {
46 input,
47 src_col,
48 types,
49 dir,
50 legal_src_labels,
51 legal_dst_labels,
52 ..
53 } => self.expand(
54 get_output(&outputs, *input)?,
55 ExpandSpec {
56 src_col: *src_col,
57 types,
58 dir: *dir,
59 legal_src_labels,
60 legal_dst_labels,
61 optional: false,
62 },
63 )?,
64 Op::OptionalExpand {
65 input,
66 src_col,
67 types,
68 dir,
69 legal_src_labels,
70 legal_dst_labels,
71 ..
72 } => self.expand(
73 get_output(&outputs, *input)?,
74 ExpandSpec {
75 src_col: *src_col,
76 types,
77 dir: *dir,
78 legal_src_labels,
79 legal_dst_labels,
80 optional: true,
81 },
82 )?,
83 Op::Filter { input, predicate } => {
84 self.filter(get_output(&outputs, *input)?, predicate)?
85 }
86 Op::Project { input, exprs, .. } => {
87 self.project(get_output(&outputs, *input)?, exprs)?
88 }
89 Op::Sort { input, keys, dirs } => {
90 self.sort(get_output(&outputs, *input)?, keys, dirs)?
91 }
92 Op::Return { input } => get_output(&outputs, *input)?.clone(),
93 _ => {
94 return Err(ExecutionError::UnsupportedOp(
95 "independent consumer proof subset",
96 ))
97 }
98 };
99 outputs.push(rows);
100 }
101
102 let Some(rows) = outputs.get(plan.root_op as usize) else {
103 return Err(ExecutionError::InvalidRootOp(plan.root_op));
104 };
105 Ok(QueryResult {
106 rows: rows.clone(),
107 continuation: None,
108 })
109 }
110}
111
112pub fn proof_fixture_graph() -> Graph {
113 let node = |id: u64, labels: &[&str], props: &[(&str, Value)]| Node {
114 id,
115 labels: labels.iter().map(|label| (*label).to_string()).collect(),
116 props: props
117 .iter()
118 .map(|(key, value)| ((*key).to_string(), value.clone()))
119 .collect(),
120 };
121 let rel = |id: u64, src: u64, dst: u64, typ: &str| Relationship {
122 id,
123 src,
124 dst,
125 typ: typ.to_string(),
126 props: Default::default(),
127 };
128
129 Graph {
130 nodes: vec![
131 node(
132 1,
133 &["Person"],
134 &[
135 ("name", Value::String("Alice".to_string())),
136 ("age", Value::Int(30)),
137 ],
138 ),
139 node(
140 2,
141 &["Person"],
142 &[
143 ("name", Value::String("Bob".to_string())),
144 ("age", Value::Int(40)),
145 ],
146 ),
147 node(
148 3,
149 &["Company"],
150 &[("name", Value::String("Acme".to_string()))],
151 ),
152 ],
153 rels: vec![
154 rel(10, 1, 2, "KNOWS"),
155 rel(11, 2, 1, "KNOWS"),
156 rel(12, 2, 3, "WORKS_AT"),
157 ],
158 }
159}
160
161impl IndependentConsumerEngine {
162 fn scan_nodes(
163 &self,
164 labels: &[String],
165 must_labels: &[String],
166 forbidden_labels: &[String],
167 ) -> RowSet {
168 self.graph
169 .nodes
170 .iter()
171 .filter(|node| {
172 labels.iter().all(|label| node.labels.contains(label))
173 && must_labels.iter().all(|label| node.labels.contains(label))
174 && forbidden_labels
175 .iter()
176 .all(|label| !node.labels.contains(label))
177 })
178 .map(|node| vec![Value::NodeRef(node.id)])
179 .collect()
180 }
181
182 fn expand(&self, input: &[Row], spec: ExpandSpec<'_>) -> Result<RowSet, ExecutionError> {
183 let mut out = Vec::new();
184 for row in input {
185 let Some(value) = row.get(spec.src_col as usize) else {
186 return Err(ExecutionError::ColumnOutOfBounds {
187 idx: spec.src_col as usize,
188 len: row.len(),
189 });
190 };
191 let Value::NodeRef(src_id) = value else {
192 return Err(ExecutionError::ExpectedNodeRef {
193 idx: spec.src_col as usize,
194 });
195 };
196 let src_node = self
197 .graph
198 .node_by_id(*src_id)
199 .ok_or(ExecutionError::UnknownNode(*src_id))?;
200 if !labels_match(src_node, spec.legal_src_labels) {
201 continue;
202 }
203
204 let mut matched = false;
205 for rel in &self.graph.rels {
206 if !spec.types.is_empty() && !spec.types.iter().any(|typ| typ == &rel.typ) {
207 continue;
208 }
209 if let Some(dst_id) = relation_endpoint(rel, *src_id, spec.dir) {
210 let dst_node = self
211 .graph
212 .node_by_id(dst_id)
213 .ok_or(ExecutionError::UnknownNode(dst_id))?;
214 if !labels_match(dst_node, spec.legal_dst_labels) {
215 continue;
216 }
217 let mut next = row.clone();
218 next.push(Value::RelRef(rel.id));
219 next.push(Value::NodeRef(dst_id));
220 out.push(next);
221 matched = true;
222 }
223 }
224
225 if spec.optional && !matched {
226 let mut next = row.clone();
227 next.push(Value::Null);
228 next.push(Value::Null);
229 out.push(next);
230 }
231 }
232 Ok(out)
233 }
234
235 fn filter(&self, input: &[Row], predicate: &Expr) -> Result<RowSet, ExecutionError> {
236 let mut out = Vec::new();
237 for row in input {
238 if matches!(self.eval_expr(row, predicate)?, Value::Bool(true)) {
239 out.push(row.clone());
240 }
241 }
242 Ok(out)
243 }
244
245 fn project(&self, input: &[Row], exprs: &[Expr]) -> Result<RowSet, ExecutionError> {
246 input
247 .iter()
248 .map(|row| {
249 exprs
250 .iter()
251 .map(|expr| self.eval_expr(row, expr))
252 .collect::<Result<Row, _>>()
253 })
254 .collect()
255 }
256
257 fn sort(
258 &self,
259 input: &[Row],
260 keys: &[u32],
261 dirs: &[SortDir],
262 ) -> Result<RowSet, ExecutionError> {
263 if keys.len() != dirs.len() {
264 return Err(ExecutionError::SortArityMismatch {
265 keys: keys.len(),
266 dirs: dirs.len(),
267 });
268 }
269 let mut out = input.to_vec();
270 out.sort_by(|lhs, rhs| compare_rows(lhs, rhs, keys, dirs));
271 Ok(out)
272 }
273
274 fn eval_expr(&self, row: &Row, expr: &Expr) -> Result<Value, ExecutionError> {
275 match expr {
276 Expr::ColRef { idx } => {
277 row.get(*idx as usize)
278 .cloned()
279 .ok_or(ExecutionError::ColumnOutOfBounds {
280 idx: *idx as usize,
281 len: row.len(),
282 })
283 }
284 Expr::PropAccess { col, prop } => {
285 let Some(value) = row.get(*col as usize) else {
286 return Err(ExecutionError::ColumnOutOfBounds {
287 idx: *col as usize,
288 len: row.len(),
289 });
290 };
291 self.property_access(value, prop)
292 }
293 Expr::IntLiteral(value) => Ok(Value::Int(*value)),
294 Expr::FloatLiteral(value) => Ok(Value::Float(*value)),
295 Expr::BoolLiteral(value) => Ok(Value::Bool(*value)),
296 Expr::StringLiteral(value) => Ok(Value::String(value.clone())),
297 Expr::NullLiteral => Ok(Value::Null),
298 Expr::Cmp { op, lhs, rhs } => {
299 let lhs = self.eval_expr(row, lhs)?;
300 let rhs = self.eval_expr(row, rhs)?;
301 Ok(compare_expr_values(*op, lhs, rhs))
302 }
303 _ => Err(ExecutionError::UnsupportedExpr(
304 "independent consumer proof subset",
305 )),
306 }
307 }
308
309 fn property_access(&self, value: &Value, prop: &str) -> Result<Value, ExecutionError> {
310 match value {
311 Value::Null => Ok(Value::Null),
312 Value::NodeRef(id) => Ok(self
313 .graph
314 .node_by_id(*id)
315 .ok_or(ExecutionError::UnknownNode(*id))?
316 .props
317 .get(prop)
318 .cloned()
319 .unwrap_or(Value::Null)),
320 Value::RelRef(id) => Ok(self
321 .graph
322 .rel_by_id(*id)
323 .ok_or(ExecutionError::UnknownRel(*id))?
324 .props
325 .get(prop)
326 .cloned()
327 .unwrap_or(Value::Null)),
328 Value::Map(entries) => Ok(entries.get(prop).cloned().unwrap_or(Value::Null)),
329 _ => Ok(Value::Null),
330 }
331 }
332}
333
334fn get_output(outputs: &[RowSet], idx: u32) -> Result<&RowSet, ExecutionError> {
335 outputs
336 .get(idx as usize)
337 .ok_or(ExecutionError::MissingOpOutput(idx))
338}
339
340fn labels_match(node: &Node, required: &[String]) -> bool {
341 required.is_empty() || required.iter().all(|label| node.labels.contains(label))
342}
343
344fn relation_endpoint(rel: &Relationship, src_id: u64, dir: ExpandDir) -> Option<u64> {
345 match dir {
346 ExpandDir::Out if rel.src == src_id => Some(rel.dst),
347 ExpandDir::In if rel.dst == src_id => Some(rel.src),
348 ExpandDir::Both if rel.src == src_id => Some(rel.dst),
349 ExpandDir::Both if rel.dst == src_id => Some(rel.src),
350 _ => None,
351 }
352}
353
354fn compare_rows(lhs: &Row, rhs: &Row, keys: &[u32], dirs: &[SortDir]) -> Ordering {
355 for (key, dir) in keys.iter().zip(dirs) {
356 let lhs_value = lhs.get(*key as usize).unwrap_or(&Value::Null);
357 let rhs_value = rhs.get(*key as usize).unwrap_or(&Value::Null);
358 let ordering = compare_values(lhs_value, rhs_value);
359 if ordering != Ordering::Equal {
360 return match dir {
361 SortDir::Asc => ordering,
362 SortDir::Desc => ordering.reverse(),
363 };
364 }
365 }
366 Ordering::Equal
367}
368
369fn compare_expr_values(op: CmpOp, lhs: Value, rhs: Value) -> Value {
370 if matches!(lhs, Value::Null) || matches!(rhs, Value::Null) {
371 return Value::Null;
372 }
373
374 let ordering = compare_values(&lhs, &rhs);
375 let result = match op {
376 CmpOp::Eq => lhs == rhs,
377 CmpOp::Ne => lhs != rhs,
378 CmpOp::Lt => ordering == Ordering::Less,
379 CmpOp::Gt => ordering == Ordering::Greater,
380 CmpOp::Le => ordering != Ordering::Greater,
381 CmpOp::Ge => ordering != Ordering::Less,
382 };
383 Value::Bool(result)
384}
385
386fn compare_values(lhs: &Value, rhs: &Value) -> Ordering {
387 match (lhs, rhs) {
388 (Value::Null, Value::Null) => Ordering::Equal,
389 (Value::Null, _) => Ordering::Less,
390 (_, Value::Null) => Ordering::Greater,
391 (Value::Bool(lhs), Value::Bool(rhs)) => lhs.cmp(rhs),
392 (Value::Int(lhs), Value::Int(rhs)) => lhs.cmp(rhs),
393 (Value::Float(lhs), Value::Float(rhs)) => lhs.partial_cmp(rhs).unwrap_or(Ordering::Equal),
394 (Value::Int(lhs), Value::Float(rhs)) => {
395 (*lhs as f64).partial_cmp(rhs).unwrap_or(Ordering::Equal)
396 }
397 (Value::Float(lhs), Value::Int(rhs)) => {
398 lhs.partial_cmp(&(*rhs as f64)).unwrap_or(Ordering::Equal)
399 }
400 (Value::String(lhs), Value::String(rhs)) => lhs.cmp(rhs),
401 (Value::NodeRef(lhs), Value::NodeRef(rhs)) => lhs.cmp(rhs),
402 (Value::RelRef(lhs), Value::RelRef(rhs)) => lhs.cmp(rhs),
403 _ => value_rank(lhs).cmp(&value_rank(rhs)),
404 }
405}
406
407fn value_rank(value: &Value) -> u8 {
408 match value {
409 Value::Null => 0,
410 Value::Bool(_) => 1,
411 Value::Int(_) | Value::Float(_) => 2,
412 Value::String(_) => 3,
413 Value::NodeRef(_) => 4,
414 Value::RelRef(_) => 5,
415 Value::List(_) => 6,
416 Value::Map(_) => 7,
417 }
418}