// plexus-engine 0.3.4
//
// Engine integration traits for consuming Plexus plans.
// Documentation
use crate::*;
use plexus_serde::ArithOp;
use std::collections::BTreeMap;

impl InMemoryEngine {
    pub(crate) fn eval_expr(&self, row: &Row, expr: &Expr) -> Result<Value, ExecutionError> {
        Ok(match expr {
            Expr::ColRef { idx } => {
                row.get(*idx as usize)
                    .cloned()
                    .ok_or(ExecutionError::ColumnOutOfBounds {
                        idx: *idx as usize,
                        len: row.len(),
                    })?
            }
            Expr::PropAccess { col, prop } => {
                let v = row
                    .get(*col as usize)
                    .ok_or(ExecutionError::ColumnOutOfBounds {
                        idx: *col as usize,
                        len: row.len(),
                    })?;
                match v {
                    Value::NodeRef(id) => self
                        .graph
                        .node_by_id(*id)
                        .ok_or(ExecutionError::UnknownNode(*id))?
                        .props
                        .get(prop)
                        .cloned()
                        .unwrap_or(Value::Null),
                    Value::RelRef(id) => self
                        .graph
                        .rel_by_id(*id)
                        .ok_or(ExecutionError::UnknownRel(*id))?
                        .props
                        .get(prop)
                        .cloned()
                        .unwrap_or(Value::Null),
                    _ => Value::Null,
                }
            }
            Expr::IntLiteral(v) => Value::Int(*v),
            Expr::FloatLiteral(v) => Value::Float(*v),
            Expr::BoolLiteral(v) => Value::Bool(*v),
            Expr::StringLiteral(v) => Value::String(v.clone()),
            Expr::NullLiteral => Value::Null,
            Expr::Cmp { op, lhs, rhs } => {
                let l = self.eval_expr(row, lhs)?;
                let r = self.eval_expr(row, rhs)?;
                Value::Bool(cmp_values(*op, &l, &r))
            }
            Expr::And { lhs, rhs } => {
                let l = self.eval_expr(row, lhs)?;
                let r = self.eval_expr(row, rhs)?;
                Value::Bool(as_bool(&l) && as_bool(&r))
            }
            Expr::Or { lhs, rhs } => {
                let l = self.eval_expr(row, lhs)?;
                let r = self.eval_expr(row, rhs)?;
                Value::Bool(as_bool(&l) || as_bool(&r))
            }
            Expr::Not { expr } => {
                let x = self.eval_expr(row, expr)?;
                Value::Bool(!as_bool(&x))
            }
            Expr::IsNull { expr } => {
                let x = self.eval_expr(row, expr)?;
                Value::Bool(matches!(x, Value::Null))
            }
            Expr::IsNotNull { expr } => {
                let x = self.eval_expr(row, expr)?;
                Value::Bool(!matches!(x, Value::Null))
            }
            Expr::StartsWith { expr, pattern } => {
                let x = self.eval_expr(row, expr)?;
                Value::Bool(as_str(&x).is_some_and(|s| s.starts_with(pattern)))
            }
            Expr::EndsWith { expr, pattern } => {
                let x = self.eval_expr(row, expr)?;
                Value::Bool(as_str(&x).is_some_and(|s| s.ends_with(pattern)))
            }
            Expr::Contains { expr, pattern } => {
                let x = self.eval_expr(row, expr)?;
                Value::Bool(as_str(&x).is_some_and(|s| s.contains(pattern)))
            }
            Expr::In { expr, items } => {
                let needle = self.eval_expr(row, expr)?;
                let mut found = false;
                for item in items {
                    let v = self.eval_expr(row, item)?;
                    if v == needle {
                        found = true;
                        break;
                    }
                }
                Value::Bool(found)
            }
            Expr::ListLiteral { items } => {
                let mut out = Vec::with_capacity(items.len());
                for item in items {
                    out.push(self.eval_expr(row, item)?);
                }
                Value::List(out)
            }
            Expr::MapLiteral { entries } => {
                let mut out = BTreeMap::new();
                for (k, v) in entries {
                    out.insert(k.clone(), self.eval_expr(row, v)?);
                }
                Value::Map(out)
            }
            Expr::Exists { expr } => {
                let x = self.eval_expr(row, expr)?;
                Value::Bool(!matches!(x, Value::Null))
            }
            Expr::ListComprehension { .. } => {
                return Err(ExecutionError::UnsupportedExpr("list comprehension"))
            }
            Expr::Agg { .. } => return Err(ExecutionError::ExpectedAggregateExpr),
            Expr::Arith { op, lhs, rhs } => {
                let l = self.eval_expr(row, lhs)?;
                let r = self.eval_expr(row, rhs)?;
                eval_arith(*op, &l, &r)?
            }
            Expr::Param { name, .. } => self
                .params
                .get(name)
                .cloned()
                .ok_or_else(|| ExecutionError::UnboundParam(name.clone()))?,
            Expr::Case { arms, else_expr } => {
                let mut matched = None;
                for (when_expr, then_expr) in arms {
                    let cond = self.eval_expr(row, when_expr)?;
                    if as_bool(&cond) {
                        matched = Some(self.eval_expr(row, then_expr)?);
                        break;
                    }
                }
                match matched {
                    Some(v) => v,
                    None => match else_expr {
                        Some(e) => self.eval_expr(row, e)?,
                        None => Value::Null,
                    },
                }
            }
            Expr::VectorSimilarity { .. } => {
                return Err(ExecutionError::UnsupportedExpr("vector similarity"))
            }
        })
    }

    pub(crate) fn eval_agg(&self, rows: &[Row], expr: &Expr) -> Result<Value, ExecutionError> {
        let Expr::Agg { fn_, expr } = expr else {
            return Err(ExecutionError::ExpectedAggregateExpr);
        };

        match fn_ {
            AggFn::CountStar => Ok(Value::Int(rows.len() as i64)),
            AggFn::Count => {
                let mut cnt = 0i64;
                for row in rows {
                    let Some(e) = expr else {
                        continue;
                    };
                    let v = self.eval_expr(row, e)?;
                    if !matches!(v, Value::Null) {
                        cnt += 1;
                    }
                }
                Ok(Value::Int(cnt))
            }
            AggFn::Sum => {
                let mut saw_float = false;
                let mut sum_i = 0i64;
                let mut sum_f = 0.0f64;
                for row in rows {
                    let Some(e) = expr else {
                        continue;
                    };
                    let v = self.eval_expr(row, e)?;
                    match v {
                        Value::Int(x) => {
                            sum_i += x;
                            sum_f += x as f64;
                        }
                        Value::Float(x) => {
                            saw_float = true;
                            sum_f += x;
                        }
                        Value::Null => {}
                        _ => return Err(ExecutionError::ExpectedNumeric),
                    }
                }
                if saw_float {
                    Ok(Value::Float(sum_f))
                } else {
                    Ok(Value::Int(sum_i))
                }
            }
            AggFn::Avg => {
                let mut sum = 0.0f64;
                let mut cnt = 0usize;
                for row in rows {
                    let Some(e) = expr else {
                        continue;
                    };
                    let v = self.eval_expr(row, e)?;
                    match v {
                        Value::Int(x) => {
                            sum += x as f64;
                            cnt += 1;
                        }
                        Value::Float(x) => {
                            sum += x;
                            cnt += 1;
                        }
                        Value::Null => {}
                        _ => return Err(ExecutionError::ExpectedNumeric),
                    }
                }
                if cnt == 0 {
                    Ok(Value::Null)
                } else {
                    Ok(Value::Float(sum / cnt as f64))
                }
            }
            AggFn::Min => reduce_min_max(self, rows, expr.as_deref(), true),
            AggFn::Max => reduce_min_max(self, rows, expr.as_deref(), false),
            AggFn::Collect => {
                let mut out = Vec::with_capacity(rows.len());
                for row in rows {
                    let Some(e) = expr else {
                        continue;
                    };
                    out.push(self.eval_expr(row, e)?);
                }
                Ok(Value::List(out))
            }
        }
    }
}

/// Interprets a value as a predicate outcome.
///
/// Only `Value::Bool(true)` counts as true; `Bool(false)`, `Null` and every
/// non-boolean value are treated as false.
pub(crate) fn as_bool(v: &Value) -> bool {
    if let Value::Bool(flag) = v {
        *flag
    } else {
        false
    }
}

/// Borrows the text held by a `Value::String`; any other variant yields `None`.
pub(crate) fn as_str(v: &Value) -> Option<&str> {
    let Value::String(text) = v else {
        return None;
    };
    Some(text.as_str())
}

/// Applies a comparison operator to two values.
///
/// `Eq`/`Ne` use `Value`'s `PartialEq`. The ordering operators use `cmp_ordering`.
/// Pairs it cannot order, such as mismatched types or NaN, compare as `false`.
/// Because `cmp_ordering` ranks `Null` below every other value, `null < x` is true
/// for any non-null `x`.
///
/// NOTE(review): if `Value`'s `PartialEq` is derived, `Int(1) == Float(1.0)` is false
/// while `Le`/`Ge` on the same pair are true — confirm this asymmetry is intended.
pub(crate) fn cmp_values(op: CmpOp, lhs: &Value, rhs: &Value) -> bool {
    // The ordering is only computed for the operators that need it.
    let ordered = |accept: fn(Ordering) -> bool| cmp_ordering(lhs, rhs).is_some_and(accept);
    match op {
        CmpOp::Eq => lhs == rhs,
        CmpOp::Ne => lhs != rhs,
        CmpOp::Lt => ordered(Ordering::is_lt),
        CmpOp::Gt => ordered(Ordering::is_gt),
        CmpOp::Le => ordered(Ordering::is_le),
        CmpOp::Ge => ordered(Ordering::is_ge),
    }
}

/// Orders two values, or returns `None` when they are not comparable.
///
/// Rules:
/// - `Null` equals `Null` and sorts before every non-null value.
/// - `Int` and `Float` compare numerically. An `Int` is converted with `as f64`, which
///   can lose precision beyond 2^53. NaN makes the pair incomparable.
/// - `Bool`, `String`, `NodeRef` and `RelRef` compare only against the same variant,
///   using the payload's own `Ord`.
/// - Any other pairing yields `None`.
pub(crate) fn cmp_ordering(lhs: &Value, rhs: &Value) -> Option<Ordering> {
    let lhs_null = matches!(lhs, Value::Null);
    let rhs_null = matches!(rhs, Value::Null);
    if lhs_null || rhs_null {
        let ord = match (lhs_null, rhs_null) {
            (true, true) => Ordering::Equal,
            (true, false) => Ordering::Less,
            _ => Ordering::Greater,
        };
        return Some(ord);
    }

    match (lhs, rhs) {
        (Value::Int(x), Value::Int(y)) => Some(x.cmp(y)),
        (Value::Float(x), Value::Float(y)) => x.partial_cmp(y),
        (Value::Int(x), Value::Float(y)) => (*x as f64).partial_cmp(y),
        (Value::Float(x), Value::Int(y)) => x.partial_cmp(&(*y as f64)),
        (Value::Bool(x), Value::Bool(y)) => Some(x.cmp(y)),
        (Value::String(x), Value::String(y)) => Some(x.cmp(y)),
        (Value::NodeRef(x), Value::NodeRef(y)) => Some(x.cmp(y)),
        (Value::RelRef(x), Value::RelRef(y)) => Some(x.cmp(y)),
        _ => None,
    }
}

/// Computes `min` (`is_min == true`) or `max` of `expr` over `rows`.
///
/// `Null` results are skipped. A candidate replaces the current best only when
/// `cmp_ordering` says it is strictly smaller (min) or strictly greater (max). Values
/// that cannot be ordered against the current best, such as a different type or NaN,
/// are ignored, so the first non-null value fixes the type that wins. Returns `Null`
/// when `expr` is `None` or every row evaluated to `Null`.
///
/// # Errors
///
/// Propagates the first error from evaluating `expr` on a row.
pub(super) fn reduce_min_max(
    engine: &InMemoryEngine,
    rows: &[Row],
    expr: Option<&Expr>,
    is_min: bool,
) -> Result<Value, ExecutionError> {
    let Some(e) = expr else {
        return Ok(Value::Null);
    };
    let wanted = if is_min {
        Ordering::Less
    } else {
        Ordering::Greater
    };

    let mut best: Option<Value> = None;
    for row in rows {
        let candidate = engine.eval_expr(row, e)?;
        if matches!(candidate, Value::Null) {
            continue;
        }
        let replace = match best.as_ref() {
            None => true,
            Some(current) => cmp_ordering(&candidate, current) == Some(wanted),
        };
        if replace {
            best = Some(candidate);
        }
    }
    Ok(best.unwrap_or(Value::Null))
}

/// Applies a binary arithmetic operator to two numeric values.
///
/// - `Int op Int`: `Add`/`Sub`/`Mul` stay integral and are overflow-checked. `Div`
///   always produces a `Float`, so `3 / 2` is `1.5` and division by zero yields
///   ±inf or NaN instead of panicking.
/// - Any pairing involving a `Float` is computed in `f64` via `eval_arith_f64`.
///
/// # Errors
///
/// - `ExpectedNumeric` if either operand is not an `Int` or `Float`.
/// - `UnsupportedExpr("integer overflow")` if an integer `Add`/`Sub`/`Mul` does not
///   fit in `i64`. Previously this panicked in debug builds and wrapped silently in
///   release builds.
fn eval_arith(op: ArithOp, lhs: &Value, rhs: &Value) -> Result<Value, ExecutionError> {
    use ArithOp::{Add, Div, Mul, Sub};
    match (lhs, rhs) {
        (Value::Int(a), Value::Int(b)) => {
            let (a, b) = (*a, *b);
            let result = match op {
                Add => a.checked_add(b),
                Sub => a.checked_sub(b),
                Mul => a.checked_mul(b),
                Div => return Ok(Value::Float(a as f64 / b as f64)),
            };
            result
                .map(Value::Int)
                .ok_or(ExecutionError::UnsupportedExpr("integer overflow"))
        }
        (Value::Int(a), Value::Float(b)) => Ok(Value::Float(eval_arith_f64(op, *a as f64, *b))),
        (Value::Float(a), Value::Int(b)) => Ok(Value::Float(eval_arith_f64(op, *a, *b as f64))),
        (Value::Float(a), Value::Float(b)) => Ok(Value::Float(eval_arith_f64(op, *a, *b))),
        _ => Err(ExecutionError::ExpectedNumeric),
    }
}

fn eval_arith_f64(op: ArithOp, lhs: f64, rhs: f64) -> f64 {
    use ArithOp::{Add, Div, Mul, Sub};
    match op {
        Add => lhs + rhs,
        Sub => lhs - rhs,
        Mul => lhs * rhs,
        Div => lhs / rhs,
    }
}