use std::cmp::Ordering;
use plexus_serde::{CmpOp, ExpandDir, Expr, Op, Plan, SortDir};
use crate::{ExecutionError, Graph, Node, PlanEngine, QueryResult, Relationship, Row, Value};
type RowSet = Vec<Row>;
/// Parameters for a single expand step; built by `execute_plan` for both `Op::Expand` and
/// `Op::OptionalExpand` and consumed by `IndependentConsumerEngine::expand`.
struct ExpandSpec<'a> {
/// Index of the input-row column that must hold the source `Value::NodeRef`.
src_col: u32,
/// Relationship types to follow; an empty slice accepts every type.
types: &'a [String],
/// Which relationship endpoint(s) the source node may sit on.
dir: ExpandDir,
/// Labels the source node must all carry; an empty slice imposes no constraint.
legal_src_labels: &'a [String],
/// Labels the destination node must all carry; an empty slice imposes no constraint.
legal_dst_labels: &'a [String],
/// When `true`, a row that produced no match is emitted once with two trailing `Value::Null`s.
optional: bool,
}
/// Plan engine that evaluates a supported subset of plan operators against a `Graph` it owns.
#[derive(Debug, Clone)]
pub struct IndependentConsumerEngine {
/// Graph whose nodes and relationships every operator reads.
graph: Graph,
}
impl IndependentConsumerEngine {
pub fn new(graph: Graph) -> Self {
Self { graph }
}
}
impl PlanEngine for IndependentConsumerEngine {
    type Error = ExecutionError;

    /// Evaluates `plan.ops` in order and returns the rows produced by the op at `plan.root_op`.
    ///
    /// Each op's output is stored at the op's position in the plan, so an op can only read
    /// the output of an op that appears before it.
    ///
    /// # Errors
    ///
    /// * `ExecutionError::MissingOpOutput` when an op names an input index that has no
    ///   output yet.
    /// * `ExecutionError::UnsupportedOp` for any op outside the handled subset.
    /// * `ExecutionError::InvalidRootOp` when `plan.root_op` is not a valid op index.
    /// * Any error raised by the individual operators.
    fn execute_plan(&mut self, plan: &Plan) -> Result<QueryResult, Self::Error> {
        let mut outputs = Vec::<RowSet>::with_capacity(plan.ops.len());
        for op in &plan.ops {
            let rows = match op {
                Op::ScanNodes {
                    labels,
                    must_labels,
                    forbidden_labels,
                    ..
                } => self.scan_nodes(labels, must_labels, forbidden_labels),
                Op::Expand {
                    input,
                    src_col,
                    types,
                    dir,
                    legal_src_labels,
                    legal_dst_labels,
                    ..
                } => self.expand(
                    get_output(&outputs, *input)?,
                    ExpandSpec {
                        src_col: *src_col,
                        types,
                        dir: *dir,
                        legal_src_labels,
                        legal_dst_labels,
                        optional: false,
                    },
                )?,
                Op::OptionalExpand {
                    input,
                    src_col,
                    types,
                    dir,
                    legal_src_labels,
                    legal_dst_labels,
                    ..
                } => self.expand(
                    get_output(&outputs, *input)?,
                    ExpandSpec {
                        src_col: *src_col,
                        types,
                        dir: *dir,
                        legal_src_labels,
                        legal_dst_labels,
                        optional: true,
                    },
                )?,
                Op::Filter { input, predicate } => {
                    self.filter(get_output(&outputs, *input)?, predicate)?
                }
                Op::Project { input, exprs, .. } => {
                    self.project(get_output(&outputs, *input)?, exprs)?
                }
                Op::Sort { input, keys, dirs } => {
                    self.sort(get_output(&outputs, *input)?, keys, dirs)?
                }
                // `Return` forwards its input; the copy is needed because the input stays
                // addressable by later ops.
                Op::Return { input } => get_output(&outputs, *input)?.clone(),
                _ => {
                    return Err(ExecutionError::UnsupportedOp(
                        "independent consumer proof subset",
                    ))
                }
            };
            outputs.push(rows);
        }

        let root = plan.root_op as usize;
        if root >= outputs.len() {
            return Err(ExecutionError::InvalidRootOp(plan.root_op));
        }
        // `outputs` is dropped right after this, so move the root rows out instead of
        // cloning the whole result set. The bounds check above keeps `swap_remove` from
        // panicking.
        Ok(QueryResult {
            rows: outputs.swap_remove(root),
        })
    }
}
pub fn proof_fixture_graph() -> Graph {
let node = |id: u64, labels: &[&str], props: &[(&str, Value)]| Node {
id,
labels: labels.iter().map(|label| (*label).to_string()).collect(),
props: props
.iter()
.map(|(key, value)| ((*key).to_string(), value.clone()))
.collect(),
};
let rel = |id: u64, src: u64, dst: u64, typ: &str| Relationship {
id,
src,
dst,
typ: typ.to_string(),
props: Default::default(),
};
Graph {
nodes: vec![
node(
1,
&["Person"],
&[
("name", Value::String("Alice".to_string())),
("age", Value::Int(30)),
],
),
node(
2,
&["Person"],
&[
("name", Value::String("Bob".to_string())),
("age", Value::Int(40)),
],
),
node(
3,
&["Company"],
&[("name", Value::String("Acme".to_string()))],
),
],
rels: vec![
rel(10, 1, 2, "KNOWS"),
rel(11, 2, 1, "KNOWS"),
rel(12, 2, 3, "WORKS_AT"),
],
}
}
impl IndependentConsumerEngine {
    /// Emits one single-column row (`Value::NodeRef`) per graph node that passes the label
    /// filters: the node must carry every entry of `labels` and of `must_labels`, and none
    /// of `forbidden_labels`. Rows follow the order of `graph.nodes`.
    fn scan_nodes(
        &self,
        labels: &[String],
        must_labels: &[String],
        forbidden_labels: &[String],
    ) -> RowSet {
        let mut rows = Vec::new();
        for node in self.graph.nodes.iter() {
            let has_required = labels
                .iter()
                .chain(must_labels.iter())
                .all(|label| node.labels.contains(label));
            let has_forbidden = forbidden_labels
                .iter()
                .any(|label| node.labels.contains(label));
            if has_required && !has_forbidden {
                rows.push(vec![Value::NodeRef(node.id)]);
            }
        }
        rows
    }

    /// Extends each input row with a `(RelRef, NodeRef)` pair for every relationship that
    /// leaves the row's source node in the requested direction.
    ///
    /// The source node id is read from column `spec.src_col`. Every relationship in the
    /// graph is tested against `spec.types`, `spec.dir` and `spec.legal_dst_labels`. When
    /// `spec.optional` is set, a row without any match is emitted once, padded with two
    /// `Value::Null`s.
    ///
    /// A row whose source node fails `spec.legal_src_labels` is skipped entirely, even when
    /// `spec.optional` is set. NOTE(review): confirm that dropping such rows is the intended
    /// optional-expand behavior.
    ///
    /// # Errors
    ///
    /// * `ExecutionError::ColumnOutOfBounds` when a row has no column `spec.src_col`.
    /// * `ExecutionError::ExpectedNodeRef` when that column is not a `Value::NodeRef`.
    /// * `ExecutionError::UnknownNode` when a source or destination id is not in the graph.
    fn expand(&self, input: &[Row], spec: ExpandSpec<'_>) -> Result<RowSet, ExecutionError> {
        let src_idx = spec.src_col as usize;
        let mut expanded = Vec::new();

        for row in input {
            let src_id = match row.get(src_idx) {
                Some(Value::NodeRef(id)) => *id,
                Some(_) => return Err(ExecutionError::ExpectedNodeRef { idx: src_idx }),
                None => {
                    return Err(ExecutionError::ColumnOutOfBounds {
                        idx: src_idx,
                        len: row.len(),
                    })
                }
            };
            let Some(src_node) = self.graph.node_by_id(src_id) else {
                return Err(ExecutionError::UnknownNode(src_id));
            };
            if !labels_match(src_node, spec.legal_src_labels) {
                continue;
            }

            // Remember where this row's output starts so "no match" can be detected below.
            let first_match = expanded.len();
            for rel in &self.graph.rels {
                let type_accepted = spec.types.is_empty() || spec.types.contains(&rel.typ);
                if !type_accepted {
                    continue;
                }
                let Some(dst_id) = relation_endpoint(rel, src_id, spec.dir) else {
                    continue;
                };
                let Some(dst_node) = self.graph.node_by_id(dst_id) else {
                    return Err(ExecutionError::UnknownNode(dst_id));
                };
                if labels_match(dst_node, spec.legal_dst_labels) {
                    let next: Row = row
                        .iter()
                        .cloned()
                        .chain([Value::RelRef(rel.id), Value::NodeRef(dst_id)])
                        .collect();
                    expanded.push(next);
                }
            }

            if spec.optional && expanded.len() == first_match {
                let padded: Row = row
                    .iter()
                    .cloned()
                    .chain([Value::Null, Value::Null])
                    .collect();
                expanded.push(padded);
            }
        }

        Ok(expanded)
    }

    /// Keeps the rows for which `predicate` evaluates to exactly `Value::Bool(true)`; any
    /// other value (including `Value::Null`) drops the row.
    ///
    /// # Errors
    ///
    /// Returns the first error raised while evaluating `predicate`.
    fn filter(&self, input: &[Row], predicate: &Expr) -> Result<RowSet, ExecutionError> {
        input
            .iter()
            .filter_map(|row| match self.eval_expr(row, predicate) {
                Ok(Value::Bool(true)) => Some(Ok(row.clone())),
                Ok(_) => None,
                Err(err) => Some(Err(err)),
            })
            .collect()
    }

    /// Replaces every input row with the values of `exprs` evaluated against that row, in
    /// expression order.
    ///
    /// # Errors
    ///
    /// Returns the first error raised while evaluating an expression.
    fn project(&self, input: &[Row], exprs: &[Expr]) -> Result<RowSet, ExecutionError> {
        let mut projected = Vec::with_capacity(input.len());
        for row in input {
            let mut columns = Vec::with_capacity(exprs.len());
            for expr in exprs {
                columns.push(self.eval_expr(row, expr)?);
            }
            projected.push(columns);
        }
        Ok(projected)
    }

    /// Returns a copy of `input` stably sorted by the columns in `keys`, where `dirs[i]`
    /// gives the direction for `keys[i]`.
    ///
    /// # Errors
    ///
    /// `ExecutionError::SortArityMismatch` when `keys` and `dirs` differ in length.
    fn sort(
        &self,
        input: &[Row],
        keys: &[u32],
        dirs: &[SortDir],
    ) -> Result<RowSet, ExecutionError> {
        let (key_count, dir_count) = (keys.len(), dirs.len());
        if key_count != dir_count {
            return Err(ExecutionError::SortArityMismatch {
                keys: key_count,
                dirs: dir_count,
            });
        }
        let mut sorted = Vec::from(input);
        sorted.sort_by(|a, b| compare_rows(a, b, keys, dirs));
        Ok(sorted)
    }

    /// Evaluates `expr` against `row`.
    ///
    /// Handles column references, property access, literals and comparisons.
    ///
    /// # Errors
    ///
    /// * `ExecutionError::ColumnOutOfBounds` when a referenced column is missing from `row`.
    /// * `ExecutionError::UnsupportedExpr` for any other expression kind.
    /// * Any error raised by `property_access` or by a nested expression.
    fn eval_expr(&self, row: &Row, expr: &Expr) -> Result<Value, ExecutionError> {
        // Bounds-checked column lookup shared by the two column-based expressions.
        let column = |at: usize| {
            row.get(at).ok_or(ExecutionError::ColumnOutOfBounds {
                idx: at,
                len: row.len(),
            })
        };

        match expr {
            Expr::NullLiteral => Ok(Value::Null),
            Expr::BoolLiteral(value) => Ok(Value::Bool(*value)),
            Expr::IntLiteral(value) => Ok(Value::Int(*value)),
            Expr::FloatLiteral(value) => Ok(Value::Float(*value)),
            Expr::StringLiteral(value) => Ok(Value::String(value.clone())),
            Expr::ColRef { idx } => column(*idx as usize).map(Value::clone),
            Expr::PropAccess { col, prop } => {
                let target = column(*col as usize)?;
                self.property_access(target, prop)
            }
            Expr::Cmp { op, lhs, rhs } => {
                let left = self.eval_expr(row, lhs)?;
                let right = self.eval_expr(row, rhs)?;
                Ok(compare_expr_values(*op, left, right))
            }
            _ => Err(ExecutionError::UnsupportedExpr(
                "independent consumer proof subset",
            )),
        }
    }

    /// Looks up property `prop` on `value`.
    ///
    /// Node and relationship references are resolved through the graph and maps are indexed
    /// directly. A missing property, and any other kind of value (including `Value::Null`),
    /// yields `Value::Null`.
    ///
    /// # Errors
    ///
    /// `ExecutionError::UnknownNode` / `ExecutionError::UnknownRel` when the referenced id
    /// is not in the graph.
    fn property_access(&self, value: &Value, prop: &str) -> Result<Value, ExecutionError> {
        let found = match value {
            Value::NodeRef(id) => {
                let node = self
                    .graph
                    .node_by_id(*id)
                    .ok_or(ExecutionError::UnknownNode(*id))?;
                node.props.get(prop).cloned()
            }
            Value::RelRef(id) => {
                let rel = self
                    .graph
                    .rel_by_id(*id)
                    .ok_or(ExecutionError::UnknownRel(*id))?;
                rel.props.get(prop).cloned()
            }
            Value::Map(entries) => entries.get(prop).cloned(),
            _ => None,
        };
        Ok(found.unwrap_or(Value::Null))
    }
}
/// Returns the rows produced by the op at position `idx`.
///
/// # Errors
///
/// `ExecutionError::MissingOpOutput(idx)` when `outputs` has no entry at `idx`.
fn get_output(outputs: &[RowSet], idx: u32) -> Result<&RowSet, ExecutionError> {
    match outputs.get(idx as usize) {
        Some(rows) => Ok(rows),
        None => Err(ExecutionError::MissingOpOutput(idx)),
    }
}
/// Reports whether `node` carries every label in `required`; an empty `required` always
/// matches.
fn labels_match(node: &Node, required: &[String]) -> bool {
    for label in required {
        if !node.labels.contains(label) {
            return false;
        }
    }
    true
}
/// Returns the node on the far side of `rel` when `src_id` sits on an endpoint that `dir`
/// allows, and `None` otherwise.
///
/// For `ExpandDir::Both` the outgoing side is checked first, so a self-loop yields its `dst`
/// once.
fn relation_endpoint(rel: &Relationship, src_id: u64, dir: ExpandDir) -> Option<u64> {
    let follows_outgoing = matches!(dir, ExpandDir::Out | ExpandDir::Both);
    let follows_incoming = matches!(dir, ExpandDir::In | ExpandDir::Both);

    if follows_outgoing && rel.src == src_id {
        Some(rel.dst)
    } else if follows_incoming && rel.dst == src_id {
        Some(rel.src)
    } else {
        None
    }
}
/// Compares two rows column by column, using `keys[i]` with direction `dirs[i]`.
///
/// The first key whose values differ decides the result; rows that agree on every key
/// compare equal. A key that points past the end of a row is read as `Value::Null`.
fn compare_rows(lhs: &Row, rhs: &Row, keys: &[u32], dirs: &[SortDir]) -> Ordering {
    keys.iter()
        .zip(dirs)
        .map(|(key, dir)| {
            let column = *key as usize;
            let left = lhs.get(column).unwrap_or(&Value::Null);
            let right = rhs.get(column).unwrap_or(&Value::Null);
            match dir {
                SortDir::Asc => compare_values(left, right),
                SortDir::Desc => compare_values(left, right).reverse(),
            }
        })
        .find(|ordering| *ordering != Ordering::Equal)
        .unwrap_or(Ordering::Equal)
}
/// Evaluates the comparison `lhs op rhs` and returns it as a `Value`.
///
/// * If either operand is `Value::Null` the result is `Value::Null`.
/// * If either operand is a NaN float the operands are unordered and unequal: only
///   `CmpOp::Ne` yields `true`. Without this rule `<=` and `>=` would depend on how the
///   ordering helper happens to place NaN.
/// * An integer and a float are equal when they are numerically equal, so `=` / `<>` agree
///   with `<=` / `>=` (e.g. `1 = 1.0` is `true`). All other operand pairs use the values'
///   own equality.
/// * The ordered operators follow `compare_values`.
fn compare_expr_values(op: CmpOp, lhs: Value, rhs: Value) -> Value {
    if matches!(lhs, Value::Null) || matches!(rhs, Value::Null) {
        return Value::Null;
    }

    let is_nan = |value: &Value| matches!(value, Value::Float(number) if number.is_nan());
    if is_nan(&lhs) || is_nan(&rhs) {
        return Value::Bool(matches!(op, CmpOp::Ne));
    }

    let ordering = compare_values(&lhs, &rhs);
    let equal = match (&lhs, &rhs) {
        // Structural equality never matches across the two numeric variants, so compare
        // mixed pairs numerically instead.
        (Value::Int(_), Value::Float(_)) | (Value::Float(_), Value::Int(_)) => {
            ordering == Ordering::Equal
        }
        _ => lhs == rhs,
    };

    let result = match op {
        CmpOp::Eq => equal,
        CmpOp::Ne => !equal,
        CmpOp::Lt => ordering == Ordering::Less,
        CmpOp::Gt => ordering == Ordering::Greater,
        CmpOp::Le => ordering != Ordering::Greater,
        CmpOp::Ge => ordering != Ordering::Less,
    };
    Value::Bool(result)
}
/// Orders two values; used for sorting rows and for the ordered comparison operators.
///
/// * `Value::Null` sorts before every other value.
/// * Booleans, integers, strings, node references and relationship references compare by
///   their natural order within their own kind.
/// * Integers and floats compare numerically with each other (integers are converted to
///   `f64` for mixed pairs).
/// * NaN sorts after every other number and equal to another NaN. Treating NaN as equal to
///   everything would not be a consistent order, and `sort_by` requires one.
/// * Any other pairing falls back to comparing `value_rank`.
fn compare_values(lhs: &Value, rhs: &Value) -> Ordering {
    /// Float comparison that stays consistent when NaN is involved.
    fn cmp_floats(lhs: f64, rhs: f64) -> Ordering {
        // `partial_cmp` is `None` only when at least one side is NaN; in that case order by
        // "is NaN" so NaN lands after all ordinary numbers.
        lhs.partial_cmp(&rhs)
            .unwrap_or_else(|| lhs.is_nan().cmp(&rhs.is_nan()))
    }

    match (lhs, rhs) {
        (Value::Null, Value::Null) => Ordering::Equal,
        (Value::Null, _) => Ordering::Less,
        (_, Value::Null) => Ordering::Greater,
        (Value::Bool(lhs), Value::Bool(rhs)) => lhs.cmp(rhs),
        (Value::Int(lhs), Value::Int(rhs)) => lhs.cmp(rhs),
        (Value::Float(lhs), Value::Float(rhs)) => cmp_floats(*lhs, *rhs),
        (Value::Int(lhs), Value::Float(rhs)) => cmp_floats(*lhs as f64, *rhs),
        (Value::Float(lhs), Value::Int(rhs)) => cmp_floats(*lhs, *rhs as f64),
        (Value::String(lhs), Value::String(rhs)) => lhs.cmp(rhs),
        (Value::NodeRef(lhs), Value::NodeRef(rhs)) => lhs.cmp(rhs),
        (Value::RelRef(lhs), Value::RelRef(rhs)) => lhs.cmp(rhs),
        _ => value_rank(lhs).cmp(&value_rank(rhs)),
    }
}
/// Assigns each kind of value a fixed rank, used by `compare_values` to order values of
/// different kinds. Lower ranks sort first; integers and floats share rank 2.
fn value_rank(value: &Value) -> u8 {
    match value {
        Value::Map(_) => 7,
        Value::List(_) => 6,
        Value::RelRef(_) => 5,
        Value::NodeRef(_) => 4,
        Value::String(_) => 3,
        Value::Float(_) | Value::Int(_) => 2,
        Value::Bool(_) => 1,
        Value::Null => 0,
    }
}