Skip to main content

plexus_engine/
independent_consumer.rs

1use std::cmp::Ordering;
2
3use plexus_serde::{CmpOp, ExpandDir, Expr, Op, Plan, SortDir};
4
5use crate::{ExecutionError, Graph, Node, PlanEngine, QueryResult, Relationship, Row, Value};
6
/// The rows materialised by a single plan operator.
type RowSet = Vec<Row>;
8
9struct ExpandSpec<'a> {
10    src_col: u32,
11    types: &'a [String],
12    dir: ExpandDir,
13    legal_src_labels: &'a [String],
14    legal_dst_labels: &'a [String],
15    optional: bool,
16}
17
18/// Minimal independent `PlanEngine` implementation used for the Phase 3
19/// interoperability proof example. This is intentionally narrow and only
20/// supports the read operators exercised by the checked-in proof corpus subset.
21#[derive(Debug, Clone)]
22pub struct IndependentConsumerEngine {
23    graph: Graph,
24}
25
26impl IndependentConsumerEngine {
27    pub fn new(graph: Graph) -> Self {
28        Self { graph }
29    }
30}
31
32impl PlanEngine for IndependentConsumerEngine {
33    type Error = ExecutionError;
34
35    fn execute_plan(&mut self, plan: &Plan) -> Result<QueryResult, Self::Error> {
36        let mut outputs = Vec::<RowSet>::with_capacity(plan.ops.len());
37        for op in &plan.ops {
38            let rows = match op {
39                Op::ScanNodes {
40                    labels,
41                    must_labels,
42                    forbidden_labels,
43                    ..
44                } => self.scan_nodes(labels, must_labels, forbidden_labels),
45                Op::Expand {
46                    input,
47                    src_col,
48                    types,
49                    dir,
50                    legal_src_labels,
51                    legal_dst_labels,
52                    ..
53                } => self.expand(
54                    get_output(&outputs, *input)?,
55                    ExpandSpec {
56                        src_col: *src_col,
57                        types,
58                        dir: *dir,
59                        legal_src_labels,
60                        legal_dst_labels,
61                        optional: false,
62                    },
63                )?,
64                Op::OptionalExpand {
65                    input,
66                    src_col,
67                    types,
68                    dir,
69                    legal_src_labels,
70                    legal_dst_labels,
71                    ..
72                } => self.expand(
73                    get_output(&outputs, *input)?,
74                    ExpandSpec {
75                        src_col: *src_col,
76                        types,
77                        dir: *dir,
78                        legal_src_labels,
79                        legal_dst_labels,
80                        optional: true,
81                    },
82                )?,
83                Op::Filter { input, predicate } => {
84                    self.filter(get_output(&outputs, *input)?, predicate)?
85                }
86                Op::Project { input, exprs, .. } => {
87                    self.project(get_output(&outputs, *input)?, exprs)?
88                }
89                Op::Sort { input, keys, dirs } => {
90                    self.sort(get_output(&outputs, *input)?, keys, dirs)?
91                }
92                Op::Return { input } => get_output(&outputs, *input)?.clone(),
93                _ => {
94                    return Err(ExecutionError::UnsupportedOp(
95                        "independent consumer proof subset",
96                    ))
97                }
98            };
99            outputs.push(rows);
100        }
101
102        let Some(rows) = outputs.get(plan.root_op as usize) else {
103            return Err(ExecutionError::InvalidRootOp(plan.root_op));
104        };
105        Ok(QueryResult { rows: rows.clone() })
106    }
107}
108
109pub fn proof_fixture_graph() -> Graph {
110    let node = |id: u64, labels: &[&str], props: &[(&str, Value)]| Node {
111        id,
112        labels: labels.iter().map(|label| (*label).to_string()).collect(),
113        props: props
114            .iter()
115            .map(|(key, value)| ((*key).to_string(), value.clone()))
116            .collect(),
117    };
118    let rel = |id: u64, src: u64, dst: u64, typ: &str| Relationship {
119        id,
120        src,
121        dst,
122        typ: typ.to_string(),
123        props: Default::default(),
124    };
125
126    Graph {
127        nodes: vec![
128            node(
129                1,
130                &["Person"],
131                &[
132                    ("name", Value::String("Alice".to_string())),
133                    ("age", Value::Int(30)),
134                ],
135            ),
136            node(
137                2,
138                &["Person"],
139                &[
140                    ("name", Value::String("Bob".to_string())),
141                    ("age", Value::Int(40)),
142                ],
143            ),
144            node(
145                3,
146                &["Company"],
147                &[("name", Value::String("Acme".to_string()))],
148            ),
149        ],
150        rels: vec![
151            rel(10, 1, 2, "KNOWS"),
152            rel(11, 2, 1, "KNOWS"),
153            rel(12, 2, 3, "WORKS_AT"),
154        ],
155    }
156}
157
158impl IndependentConsumerEngine {
159    fn scan_nodes(
160        &self,
161        labels: &[String],
162        must_labels: &[String],
163        forbidden_labels: &[String],
164    ) -> RowSet {
165        self.graph
166            .nodes
167            .iter()
168            .filter(|node| {
169                labels.iter().all(|label| node.labels.contains(label))
170                    && must_labels.iter().all(|label| node.labels.contains(label))
171                    && forbidden_labels
172                        .iter()
173                        .all(|label| !node.labels.contains(label))
174            })
175            .map(|node| vec![Value::NodeRef(node.id)])
176            .collect()
177    }
178
179    fn expand(&self, input: &[Row], spec: ExpandSpec<'_>) -> Result<RowSet, ExecutionError> {
180        let mut out = Vec::new();
181        for row in input {
182            let Some(value) = row.get(spec.src_col as usize) else {
183                return Err(ExecutionError::ColumnOutOfBounds {
184                    idx: spec.src_col as usize,
185                    len: row.len(),
186                });
187            };
188            let Value::NodeRef(src_id) = value else {
189                return Err(ExecutionError::ExpectedNodeRef {
190                    idx: spec.src_col as usize,
191                });
192            };
193            let src_node = self
194                .graph
195                .node_by_id(*src_id)
196                .ok_or(ExecutionError::UnknownNode(*src_id))?;
197            if !labels_match(src_node, spec.legal_src_labels) {
198                continue;
199            }
200
201            let mut matched = false;
202            for rel in &self.graph.rels {
203                if !spec.types.is_empty() && !spec.types.iter().any(|typ| typ == &rel.typ) {
204                    continue;
205                }
206                if let Some(dst_id) = relation_endpoint(rel, *src_id, spec.dir) {
207                    let dst_node = self
208                        .graph
209                        .node_by_id(dst_id)
210                        .ok_or(ExecutionError::UnknownNode(dst_id))?;
211                    if !labels_match(dst_node, spec.legal_dst_labels) {
212                        continue;
213                    }
214                    let mut next = row.clone();
215                    next.push(Value::RelRef(rel.id));
216                    next.push(Value::NodeRef(dst_id));
217                    out.push(next);
218                    matched = true;
219                }
220            }
221
222            if spec.optional && !matched {
223                let mut next = row.clone();
224                next.push(Value::Null);
225                next.push(Value::Null);
226                out.push(next);
227            }
228        }
229        Ok(out)
230    }
231
232    fn filter(&self, input: &[Row], predicate: &Expr) -> Result<RowSet, ExecutionError> {
233        let mut out = Vec::new();
234        for row in input {
235            if matches!(self.eval_expr(row, predicate)?, Value::Bool(true)) {
236                out.push(row.clone());
237            }
238        }
239        Ok(out)
240    }
241
242    fn project(&self, input: &[Row], exprs: &[Expr]) -> Result<RowSet, ExecutionError> {
243        input
244            .iter()
245            .map(|row| {
246                exprs
247                    .iter()
248                    .map(|expr| self.eval_expr(row, expr))
249                    .collect::<Result<Row, _>>()
250            })
251            .collect()
252    }
253
254    fn sort(
255        &self,
256        input: &[Row],
257        keys: &[u32],
258        dirs: &[SortDir],
259    ) -> Result<RowSet, ExecutionError> {
260        if keys.len() != dirs.len() {
261            return Err(ExecutionError::SortArityMismatch {
262                keys: keys.len(),
263                dirs: dirs.len(),
264            });
265        }
266        let mut out = input.to_vec();
267        out.sort_by(|lhs, rhs| compare_rows(lhs, rhs, keys, dirs));
268        Ok(out)
269    }
270
271    fn eval_expr(&self, row: &Row, expr: &Expr) -> Result<Value, ExecutionError> {
272        match expr {
273            Expr::ColRef { idx } => {
274                row.get(*idx as usize)
275                    .cloned()
276                    .ok_or(ExecutionError::ColumnOutOfBounds {
277                        idx: *idx as usize,
278                        len: row.len(),
279                    })
280            }
281            Expr::PropAccess { col, prop } => {
282                let Some(value) = row.get(*col as usize) else {
283                    return Err(ExecutionError::ColumnOutOfBounds {
284                        idx: *col as usize,
285                        len: row.len(),
286                    });
287                };
288                self.property_access(value, prop)
289            }
290            Expr::IntLiteral(value) => Ok(Value::Int(*value)),
291            Expr::FloatLiteral(value) => Ok(Value::Float(*value)),
292            Expr::BoolLiteral(value) => Ok(Value::Bool(*value)),
293            Expr::StringLiteral(value) => Ok(Value::String(value.clone())),
294            Expr::NullLiteral => Ok(Value::Null),
295            Expr::Cmp { op, lhs, rhs } => {
296                let lhs = self.eval_expr(row, lhs)?;
297                let rhs = self.eval_expr(row, rhs)?;
298                Ok(compare_expr_values(*op, lhs, rhs))
299            }
300            _ => Err(ExecutionError::UnsupportedExpr(
301                "independent consumer proof subset",
302            )),
303        }
304    }
305
306    fn property_access(&self, value: &Value, prop: &str) -> Result<Value, ExecutionError> {
307        match value {
308            Value::Null => Ok(Value::Null),
309            Value::NodeRef(id) => Ok(self
310                .graph
311                .node_by_id(*id)
312                .ok_or(ExecutionError::UnknownNode(*id))?
313                .props
314                .get(prop)
315                .cloned()
316                .unwrap_or(Value::Null)),
317            Value::RelRef(id) => Ok(self
318                .graph
319                .rel_by_id(*id)
320                .ok_or(ExecutionError::UnknownRel(*id))?
321                .props
322                .get(prop)
323                .cloned()
324                .unwrap_or(Value::Null)),
325            Value::Map(entries) => Ok(entries.get(prop).cloned().unwrap_or(Value::Null)),
326            _ => Ok(Value::Null),
327        }
328    }
329}
330
331fn get_output(outputs: &[RowSet], idx: u32) -> Result<&RowSet, ExecutionError> {
332    outputs
333        .get(idx as usize)
334        .ok_or(ExecutionError::MissingOpOutput(idx))
335}
336
337fn labels_match(node: &Node, required: &[String]) -> bool {
338    required.is_empty() || required.iter().all(|label| node.labels.contains(label))
339}
340
341fn relation_endpoint(rel: &Relationship, src_id: u64, dir: ExpandDir) -> Option<u64> {
342    match dir {
343        ExpandDir::Out if rel.src == src_id => Some(rel.dst),
344        ExpandDir::In if rel.dst == src_id => Some(rel.src),
345        ExpandDir::Both if rel.src == src_id => Some(rel.dst),
346        ExpandDir::Both if rel.dst == src_id => Some(rel.src),
347        _ => None,
348    }
349}
350
351fn compare_rows(lhs: &Row, rhs: &Row, keys: &[u32], dirs: &[SortDir]) -> Ordering {
352    for (key, dir) in keys.iter().zip(dirs) {
353        let lhs_value = lhs.get(*key as usize).unwrap_or(&Value::Null);
354        let rhs_value = rhs.get(*key as usize).unwrap_or(&Value::Null);
355        let ordering = compare_values(lhs_value, rhs_value);
356        if ordering != Ordering::Equal {
357            return match dir {
358                SortDir::Asc => ordering,
359                SortDir::Desc => ordering.reverse(),
360            };
361        }
362    }
363    Ordering::Equal
364}
365
366fn compare_expr_values(op: CmpOp, lhs: Value, rhs: Value) -> Value {
367    if matches!(lhs, Value::Null) || matches!(rhs, Value::Null) {
368        return Value::Null;
369    }
370
371    let ordering = compare_values(&lhs, &rhs);
372    let result = match op {
373        CmpOp::Eq => lhs == rhs,
374        CmpOp::Ne => lhs != rhs,
375        CmpOp::Lt => ordering == Ordering::Less,
376        CmpOp::Gt => ordering == Ordering::Greater,
377        CmpOp::Le => ordering != Ordering::Greater,
378        CmpOp::Ge => ordering != Ordering::Less,
379    };
380    Value::Bool(result)
381}
382
383fn compare_values(lhs: &Value, rhs: &Value) -> Ordering {
384    match (lhs, rhs) {
385        (Value::Null, Value::Null) => Ordering::Equal,
386        (Value::Null, _) => Ordering::Less,
387        (_, Value::Null) => Ordering::Greater,
388        (Value::Bool(lhs), Value::Bool(rhs)) => lhs.cmp(rhs),
389        (Value::Int(lhs), Value::Int(rhs)) => lhs.cmp(rhs),
390        (Value::Float(lhs), Value::Float(rhs)) => lhs.partial_cmp(rhs).unwrap_or(Ordering::Equal),
391        (Value::Int(lhs), Value::Float(rhs)) => {
392            (*lhs as f64).partial_cmp(rhs).unwrap_or(Ordering::Equal)
393        }
394        (Value::Float(lhs), Value::Int(rhs)) => {
395            lhs.partial_cmp(&(*rhs as f64)).unwrap_or(Ordering::Equal)
396        }
397        (Value::String(lhs), Value::String(rhs)) => lhs.cmp(rhs),
398        (Value::NodeRef(lhs), Value::NodeRef(rhs)) => lhs.cmp(rhs),
399        (Value::RelRef(lhs), Value::RelRef(rhs)) => lhs.cmp(rhs),
400        _ => value_rank(lhs).cmp(&value_rank(rhs)),
401    }
402}
403
404fn value_rank(value: &Value) -> u8 {
405    match value {
406        Value::Null => 0,
407        Value::Bool(_) => 1,
408        Value::Int(_) | Value::Float(_) => 2,
409        Value::String(_) => 3,
410        Value::NodeRef(_) => 4,
411        Value::RelRef(_) => 5,
412        Value::List(_) => 6,
413        Value::Map(_) => 7,
414    }
415}