// plexus_engine/independent_consumer.rs

1use std::cmp::Ordering;
2
3use plexus_serde::{CmpOp, ExpandDir, Expr, Op, Plan, SortDir};
4
5use crate::{ExecutionError, Graph, Node, PlanEngine, QueryResult, Relationship, Row, Value};
6
7type RowSet = Vec<Row>;
8
9struct ExpandSpec<'a> {
10    src_col: u32,
11    types: &'a [String],
12    dir: ExpandDir,
13    legal_src_labels: &'a [String],
14    legal_dst_labels: &'a [String],
15    optional: bool,
16}
17
18/// Minimal independent `PlanEngine` implementation used for the Phase 3
19/// interoperability proof example. This is intentionally narrow and only
20/// supports the read operators exercised by the checked-in proof corpus subset.
21#[derive(Debug, Clone)]
22pub struct IndependentConsumerEngine {
23    graph: Graph,
24}
25
26impl IndependentConsumerEngine {
27    pub fn new(graph: Graph) -> Self {
28        Self { graph }
29    }
30}
31
32impl PlanEngine for IndependentConsumerEngine {
33    type Error = ExecutionError;
34
35    fn execute_plan(&mut self, plan: &Plan) -> Result<QueryResult, Self::Error> {
36        let mut outputs = Vec::<RowSet>::with_capacity(plan.ops.len());
37        for op in &plan.ops {
38            let rows = match op {
39                Op::ScanNodes {
40                    labels,
41                    must_labels,
42                    forbidden_labels,
43                    ..
44                } => self.scan_nodes(labels, must_labels, forbidden_labels),
45                Op::Expand {
46                    input,
47                    src_col,
48                    types,
49                    dir,
50                    legal_src_labels,
51                    legal_dst_labels,
52                    ..
53                } => self.expand(
54                    get_output(&outputs, *input)?,
55                    ExpandSpec {
56                        src_col: *src_col,
57                        types,
58                        dir: *dir,
59                        legal_src_labels,
60                        legal_dst_labels,
61                        optional: false,
62                    },
63                )?,
64                Op::OptionalExpand {
65                    input,
66                    src_col,
67                    types,
68                    dir,
69                    legal_src_labels,
70                    legal_dst_labels,
71                    ..
72                } => self.expand(
73                    get_output(&outputs, *input)?,
74                    ExpandSpec {
75                        src_col: *src_col,
76                        types,
77                        dir: *dir,
78                        legal_src_labels,
79                        legal_dst_labels,
80                        optional: true,
81                    },
82                )?,
83                Op::Filter { input, predicate } => {
84                    self.filter(get_output(&outputs, *input)?, predicate)?
85                }
86                Op::Project { input, exprs, .. } => {
87                    self.project(get_output(&outputs, *input)?, exprs)?
88                }
89                Op::Sort { input, keys, dirs } => {
90                    self.sort(get_output(&outputs, *input)?, keys, dirs)?
91                }
92                Op::Return { input } => get_output(&outputs, *input)?.clone(),
93                _ => {
94                    return Err(ExecutionError::UnsupportedOp(
95                        "independent consumer proof subset",
96                    ))
97                }
98            };
99            outputs.push(rows);
100        }
101
102        let Some(rows) = outputs.get(plan.root_op as usize) else {
103            return Err(ExecutionError::InvalidRootOp(plan.root_op));
104        };
105        Ok(QueryResult {
106            rows: rows.clone(),
107            continuation: None,
108        })
109    }
110}
111
112pub fn proof_fixture_graph() -> Graph {
113    let node = |id: u64, labels: &[&str], props: &[(&str, Value)]| Node {
114        id,
115        labels: labels.iter().map(|label| (*label).to_string()).collect(),
116        props: props
117            .iter()
118            .map(|(key, value)| ((*key).to_string(), value.clone()))
119            .collect(),
120    };
121    let rel = |id: u64, src: u64, dst: u64, typ: &str| Relationship {
122        id,
123        src,
124        dst,
125        typ: typ.to_string(),
126        props: Default::default(),
127    };
128
129    Graph {
130        nodes: vec![
131            node(
132                1,
133                &["Person"],
134                &[
135                    ("name", Value::String("Alice".to_string())),
136                    ("age", Value::Int(30)),
137                ],
138            ),
139            node(
140                2,
141                &["Person"],
142                &[
143                    ("name", Value::String("Bob".to_string())),
144                    ("age", Value::Int(40)),
145                ],
146            ),
147            node(
148                3,
149                &["Company"],
150                &[("name", Value::String("Acme".to_string()))],
151            ),
152        ],
153        rels: vec![
154            rel(10, 1, 2, "KNOWS"),
155            rel(11, 2, 1, "KNOWS"),
156            rel(12, 2, 3, "WORKS_AT"),
157        ],
158    }
159}
160
161impl IndependentConsumerEngine {
162    fn scan_nodes(
163        &self,
164        labels: &[String],
165        must_labels: &[String],
166        forbidden_labels: &[String],
167    ) -> RowSet {
168        self.graph
169            .nodes
170            .iter()
171            .filter(|node| {
172                labels.iter().all(|label| node.labels.contains(label))
173                    && must_labels.iter().all(|label| node.labels.contains(label))
174                    && forbidden_labels
175                        .iter()
176                        .all(|label| !node.labels.contains(label))
177            })
178            .map(|node| vec![Value::NodeRef(node.id)])
179            .collect()
180    }
181
182    fn expand(&self, input: &[Row], spec: ExpandSpec<'_>) -> Result<RowSet, ExecutionError> {
183        let mut out = Vec::new();
184        for row in input {
185            let Some(value) = row.get(spec.src_col as usize) else {
186                return Err(ExecutionError::ColumnOutOfBounds {
187                    idx: spec.src_col as usize,
188                    len: row.len(),
189                });
190            };
191            let Value::NodeRef(src_id) = value else {
192                return Err(ExecutionError::ExpectedNodeRef {
193                    idx: spec.src_col as usize,
194                });
195            };
196            let src_node = self
197                .graph
198                .node_by_id(*src_id)
199                .ok_or(ExecutionError::UnknownNode(*src_id))?;
200            if !labels_match(src_node, spec.legal_src_labels) {
201                continue;
202            }
203
204            let mut matched = false;
205            for rel in &self.graph.rels {
206                if !spec.types.is_empty() && !spec.types.iter().any(|typ| typ == &rel.typ) {
207                    continue;
208                }
209                if let Some(dst_id) = relation_endpoint(rel, *src_id, spec.dir) {
210                    let dst_node = self
211                        .graph
212                        .node_by_id(dst_id)
213                        .ok_or(ExecutionError::UnknownNode(dst_id))?;
214                    if !labels_match(dst_node, spec.legal_dst_labels) {
215                        continue;
216                    }
217                    let mut next = row.clone();
218                    next.push(Value::RelRef(rel.id));
219                    next.push(Value::NodeRef(dst_id));
220                    out.push(next);
221                    matched = true;
222                }
223            }
224
225            if spec.optional && !matched {
226                let mut next = row.clone();
227                next.push(Value::Null);
228                next.push(Value::Null);
229                out.push(next);
230            }
231        }
232        Ok(out)
233    }
234
235    fn filter(&self, input: &[Row], predicate: &Expr) -> Result<RowSet, ExecutionError> {
236        let mut out = Vec::new();
237        for row in input {
238            if matches!(self.eval_expr(row, predicate)?, Value::Bool(true)) {
239                out.push(row.clone());
240            }
241        }
242        Ok(out)
243    }
244
245    fn project(&self, input: &[Row], exprs: &[Expr]) -> Result<RowSet, ExecutionError> {
246        input
247            .iter()
248            .map(|row| {
249                exprs
250                    .iter()
251                    .map(|expr| self.eval_expr(row, expr))
252                    .collect::<Result<Row, _>>()
253            })
254            .collect()
255    }
256
257    fn sort(
258        &self,
259        input: &[Row],
260        keys: &[u32],
261        dirs: &[SortDir],
262    ) -> Result<RowSet, ExecutionError> {
263        if keys.len() != dirs.len() {
264            return Err(ExecutionError::SortArityMismatch {
265                keys: keys.len(),
266                dirs: dirs.len(),
267            });
268        }
269        let mut out = input.to_vec();
270        out.sort_by(|lhs, rhs| compare_rows(lhs, rhs, keys, dirs));
271        Ok(out)
272    }
273
274    fn eval_expr(&self, row: &Row, expr: &Expr) -> Result<Value, ExecutionError> {
275        match expr {
276            Expr::ColRef { idx } => {
277                row.get(*idx as usize)
278                    .cloned()
279                    .ok_or(ExecutionError::ColumnOutOfBounds {
280                        idx: *idx as usize,
281                        len: row.len(),
282                    })
283            }
284            Expr::PropAccess { col, prop } => {
285                let Some(value) = row.get(*col as usize) else {
286                    return Err(ExecutionError::ColumnOutOfBounds {
287                        idx: *col as usize,
288                        len: row.len(),
289                    });
290                };
291                self.property_access(value, prop)
292            }
293            Expr::IntLiteral(value) => Ok(Value::Int(*value)),
294            Expr::FloatLiteral(value) => Ok(Value::Float(*value)),
295            Expr::BoolLiteral(value) => Ok(Value::Bool(*value)),
296            Expr::StringLiteral(value) => Ok(Value::String(value.clone())),
297            Expr::NullLiteral => Ok(Value::Null),
298            Expr::Cmp { op, lhs, rhs } => {
299                let lhs = self.eval_expr(row, lhs)?;
300                let rhs = self.eval_expr(row, rhs)?;
301                Ok(compare_expr_values(*op, lhs, rhs))
302            }
303            _ => Err(ExecutionError::UnsupportedExpr(
304                "independent consumer proof subset",
305            )),
306        }
307    }
308
309    fn property_access(&self, value: &Value, prop: &str) -> Result<Value, ExecutionError> {
310        match value {
311            Value::Null => Ok(Value::Null),
312            Value::NodeRef(id) => Ok(self
313                .graph
314                .node_by_id(*id)
315                .ok_or(ExecutionError::UnknownNode(*id))?
316                .props
317                .get(prop)
318                .cloned()
319                .unwrap_or(Value::Null)),
320            Value::RelRef(id) => Ok(self
321                .graph
322                .rel_by_id(*id)
323                .ok_or(ExecutionError::UnknownRel(*id))?
324                .props
325                .get(prop)
326                .cloned()
327                .unwrap_or(Value::Null)),
328            Value::Map(entries) => Ok(entries.get(prop).cloned().unwrap_or(Value::Null)),
329            _ => Ok(Value::Null),
330        }
331    }
332}
333
334fn get_output(outputs: &[RowSet], idx: u32) -> Result<&RowSet, ExecutionError> {
335    outputs
336        .get(idx as usize)
337        .ok_or(ExecutionError::MissingOpOutput(idx))
338}
339
340fn labels_match(node: &Node, required: &[String]) -> bool {
341    required.is_empty() || required.iter().all(|label| node.labels.contains(label))
342}
343
344fn relation_endpoint(rel: &Relationship, src_id: u64, dir: ExpandDir) -> Option<u64> {
345    match dir {
346        ExpandDir::Out if rel.src == src_id => Some(rel.dst),
347        ExpandDir::In if rel.dst == src_id => Some(rel.src),
348        ExpandDir::Both if rel.src == src_id => Some(rel.dst),
349        ExpandDir::Both if rel.dst == src_id => Some(rel.src),
350        _ => None,
351    }
352}
353
354fn compare_rows(lhs: &Row, rhs: &Row, keys: &[u32], dirs: &[SortDir]) -> Ordering {
355    for (key, dir) in keys.iter().zip(dirs) {
356        let lhs_value = lhs.get(*key as usize).unwrap_or(&Value::Null);
357        let rhs_value = rhs.get(*key as usize).unwrap_or(&Value::Null);
358        let ordering = compare_values(lhs_value, rhs_value);
359        if ordering != Ordering::Equal {
360            return match dir {
361                SortDir::Asc => ordering,
362                SortDir::Desc => ordering.reverse(),
363            };
364        }
365    }
366    Ordering::Equal
367}
368
369fn compare_expr_values(op: CmpOp, lhs: Value, rhs: Value) -> Value {
370    if matches!(lhs, Value::Null) || matches!(rhs, Value::Null) {
371        return Value::Null;
372    }
373
374    let ordering = compare_values(&lhs, &rhs);
375    let result = match op {
376        CmpOp::Eq => lhs == rhs,
377        CmpOp::Ne => lhs != rhs,
378        CmpOp::Lt => ordering == Ordering::Less,
379        CmpOp::Gt => ordering == Ordering::Greater,
380        CmpOp::Le => ordering != Ordering::Greater,
381        CmpOp::Ge => ordering != Ordering::Less,
382    };
383    Value::Bool(result)
384}
385
386fn compare_values(lhs: &Value, rhs: &Value) -> Ordering {
387    match (lhs, rhs) {
388        (Value::Null, Value::Null) => Ordering::Equal,
389        (Value::Null, _) => Ordering::Less,
390        (_, Value::Null) => Ordering::Greater,
391        (Value::Bool(lhs), Value::Bool(rhs)) => lhs.cmp(rhs),
392        (Value::Int(lhs), Value::Int(rhs)) => lhs.cmp(rhs),
393        (Value::Float(lhs), Value::Float(rhs)) => lhs.partial_cmp(rhs).unwrap_or(Ordering::Equal),
394        (Value::Int(lhs), Value::Float(rhs)) => {
395            (*lhs as f64).partial_cmp(rhs).unwrap_or(Ordering::Equal)
396        }
397        (Value::Float(lhs), Value::Int(rhs)) => {
398            lhs.partial_cmp(&(*rhs as f64)).unwrap_or(Ordering::Equal)
399        }
400        (Value::String(lhs), Value::String(rhs)) => lhs.cmp(rhs),
401        (Value::NodeRef(lhs), Value::NodeRef(rhs)) => lhs.cmp(rhs),
402        (Value::RelRef(lhs), Value::RelRef(rhs)) => lhs.cmp(rhs),
403        _ => value_rank(lhs).cmp(&value_rank(rhs)),
404    }
405}
406
407fn value_rank(value: &Value) -> u8 {
408    match value {
409        Value::Null => 0,
410        Value::Bool(_) => 1,
411        Value::Int(_) | Value::Float(_) => 2,
412        Value::String(_) => 3,
413        Value::NodeRef(_) => 4,
414        Value::RelRef(_) => 5,
415        Value::List(_) => 6,
416        Value::Map(_) => 7,
417    }
418}