use super::eval::{as_bool, as_str, cmp_ordering, cmp_values};
use crate::*;
use plexus_serde::ArithOp;
use plexus_serde::VectorMetric;
use std::cmp::Ordering;
use std::collections::BTreeMap;
impl MockVectorEngine {
/// Looks up the rows produced by the operator at index `input`.
///
/// # Errors
/// * `InvalidOpRef` when `input` is not a valid index into `outputs`.
/// * `MissingOpOutput` when the slot exists but holds no rows.
fn input_rows<'a>(
    &self,
    outputs: &'a [Option<RowSet>],
    input: u32,
) -> Result<&'a RowSet, ExecutionError> {
    match outputs.get(input as usize) {
        Some(Some(rows)) => Ok(rows),
        Some(None) => Err(ExecutionError::MissingOpOutput(input)),
        None => Err(ExecutionError::InvalidOpRef(input)),
    }
}
/// Evaluates a scalar expression against one row.
///
/// Notable behaviour, as implemented here:
/// * `And` / `Or` evaluate *both* operands before combining them, so an error in the
///   right-hand side is reported even when the left-hand side already decides the result.
/// * `PropAccess` on a value that is neither a node nor a relationship reference, or on a
///   missing property, yields `Value::Null`.
/// * `In` stops evaluating candidates at the first match.
/// * `Case` yields the first arm whose condition is truthy, then the `else` branch, then
///   `Value::Null`.
///
/// # Errors
/// * `ColumnOutOfBounds` for column references past the end of `row`.
/// * `UnknownNode` / `UnknownRel` when a reference cannot be resolved in the graph.
/// * `UnboundParam` when a parameter has no bound value.
/// * `UnsupportedExpr` for list comprehensions and `ExpectedAggregateExpr` for aggregates,
///   neither of which can be evaluated per row.
/// * Whatever `eval_arith` / `vector_similarity` report for their operands.
fn eval_expr(&self, row: &Row, expr: &Expr) -> Result<Value, ExecutionError> {
    let value = match expr {
        Expr::ColRef { idx } => {
            let pos = *idx as usize;
            match row.get(pos) {
                Some(cell) => cell.clone(),
                None => {
                    return Err(ExecutionError::ColumnOutOfBounds {
                        idx: pos,
                        len: row.len(),
                    })
                }
            }
        }
        Expr::PropAccess { col, prop } => {
            let pos = *col as usize;
            let Some(holder) = row.get(pos) else {
                return Err(ExecutionError::ColumnOutOfBounds {
                    idx: pos,
                    len: row.len(),
                });
            };
            match holder {
                Value::NodeRef(id) => {
                    let node = self
                        .base
                        .graph
                        .node_by_id(*id)
                        .ok_or(ExecutionError::UnknownNode(*id))?;
                    node.props.get(prop).cloned().unwrap_or(Value::Null)
                }
                Value::RelRef(id) => {
                    let rel = self
                        .base
                        .graph
                        .rel_by_id(*id)
                        .ok_or(ExecutionError::UnknownRel(*id))?;
                    rel.props.get(prop).cloned().unwrap_or(Value::Null)
                }
                _ => Value::Null,
            }
        }

        // Literals map one-to-one onto runtime values.
        Expr::IntLiteral(n) => Value::Int(*n),
        Expr::FloatLiteral(x) => Value::Float(*x),
        Expr::BoolLiteral(b) => Value::Bool(*b),
        Expr::StringLiteral(s) => Value::String(s.clone()),
        Expr::NullLiteral => Value::Null,

        Expr::Cmp { op, lhs, rhs } => {
            let left = self.eval_expr(row, lhs)?;
            let right = self.eval_expr(row, rhs)?;
            Value::Bool(cmp_values(*op, &left, &right))
        }
        Expr::And { lhs, rhs } => {
            // Both sides are evaluated up front; only the combination is lazy.
            let left = self.eval_expr(row, lhs)?;
            let right = self.eval_expr(row, rhs)?;
            Value::Bool(as_bool(&left) && as_bool(&right))
        }
        Expr::Or { lhs, rhs } => {
            let left = self.eval_expr(row, lhs)?;
            let right = self.eval_expr(row, rhs)?;
            Value::Bool(as_bool(&left) || as_bool(&right))
        }
        Expr::Not { expr } => {
            let inner = self.eval_expr(row, expr)?;
            Value::Bool(!as_bool(&inner))
        }
        Expr::IsNull { expr } => {
            let inner = self.eval_expr(row, expr)?;
            Value::Bool(matches!(inner, Value::Null))
        }
        Expr::IsNotNull { expr } => {
            let inner = self.eval_expr(row, expr)?;
            Value::Bool(!matches!(inner, Value::Null))
        }

        // String predicates are false for anything `as_str` does not accept.
        Expr::StartsWith { expr, pattern } => {
            let subject = self.eval_expr(row, expr)?;
            Value::Bool(matches!(as_str(&subject), Some(s) if s.starts_with(pattern)))
        }
        Expr::EndsWith { expr, pattern } => {
            let subject = self.eval_expr(row, expr)?;
            Value::Bool(matches!(as_str(&subject), Some(s) if s.ends_with(pattern)))
        }
        Expr::Contains { expr, pattern } => {
            let subject = self.eval_expr(row, expr)?;
            Value::Bool(matches!(as_str(&subject), Some(s) if s.contains(pattern)))
        }

        Expr::In { expr, items } => {
            let needle = self.eval_expr(row, expr)?;
            let mut hit = false;
            for candidate in items {
                if self.eval_expr(row, candidate)? == needle {
                    hit = true;
                    break;
                }
            }
            Value::Bool(hit)
        }
        Expr::ListLiteral { items } => {
            let elements = items
                .iter()
                .map(|item| self.eval_expr(row, item))
                .collect::<Result<Vec<_>, _>>()?;
            Value::List(elements)
        }
        Expr::MapLiteral { entries } => {
            let mut fields = BTreeMap::new();
            for (key, value_expr) in entries {
                let owned_key = key.clone();
                let value = self.eval_expr(row, value_expr)?;
                fields.insert(owned_key, value);
            }
            Value::Map(fields)
        }
        Expr::Exists { expr } => {
            let inner = self.eval_expr(row, expr)?;
            Value::Bool(!matches!(inner, Value::Null))
        }

        // Not evaluable on a single row.
        Expr::ListComprehension { .. } => {
            return Err(ExecutionError::UnsupportedExpr("list comprehension"));
        }
        Expr::Agg { .. } => return Err(ExecutionError::ExpectedAggregateExpr),

        Expr::Arith { op, lhs, rhs } => {
            let left = self.eval_expr(row, lhs)?;
            let right = self.eval_expr(row, rhs)?;
            eval_arith(*op, &left, &right)?
        }
        Expr::Param { name, .. } => match self.base.params.get(name) {
            Some(bound) => bound.clone(),
            None => return Err(ExecutionError::UnboundParam(name.clone())),
        },
        Expr::Case { arms, else_expr } => {
            let mut chosen = None;
            for (when_expr, then_expr) in arms {
                let condition = self.eval_expr(row, when_expr)?;
                if as_bool(&condition) {
                    chosen = Some(self.eval_expr(row, then_expr)?);
                    break;
                }
            }
            match (chosen, else_expr) {
                (Some(result), _) => result,
                (None, Some(fallback)) => self.eval_expr(row, fallback)?,
                (None, None) => Value::Null,
            }
        }
        Expr::VectorSimilarity { metric, lhs, rhs } => {
            let left = self.eval_expr(row, lhs)?;
            let right = self.eval_expr(row, rhs)?;
            Value::Float(vector_similarity(*metric, &left, &right)?)
        }
    };
    Ok(value)
}
/// Computes one aggregate expression over the rows of a single group.
///
/// `expr` must be an `Expr::Agg`. Its optional argument is evaluated once per row with
/// `eval_expr`. When the argument is absent, `Count` and `Sum` yield `Int(0)`, `Avg`,
/// `Min` and `Max` yield `Null`, and `Collect` yields an empty list.
///
/// * `CountStar` – number of rows.
/// * `Count` – number of rows whose argument is not `Null`.
/// * `Sum` – `Null`s are skipped. The result is `Int` while every input is an integer and
///   the exact total fits in `i64`; otherwise it is the floating-point total. Integer
///   totals are accumulated in `i128`, so the sum never panics or wraps on overflow.
/// * `Avg` – mean of the non-null numeric inputs as `Float`, or `Null` if there are none.
/// * `Min` / `Max` – delegated to `reduce_min_max_vector`.
/// * `Collect` – every evaluated value, in row order (nulls included).
///
/// # Errors
/// * `ExpectedAggregateExpr` if `expr` is not an aggregate.
/// * `ExpectedNumeric` if `Sum` / `Avg` meet a non-numeric, non-null value.
/// * Any error raised while evaluating the argument.
fn eval_agg(&self, rows: &[Row], expr: &Expr) -> Result<Value, ExecutionError> {
    let Expr::Agg { fn_, expr } = expr else {
        return Err(ExecutionError::ExpectedAggregateExpr);
    };
    // The argument does not depend on the row, so resolve it once up front.
    let arg = expr.as_deref();
    match fn_ {
        AggFn::CountStar => Ok(Value::Int(rows.len() as i64)),
        AggFn::Count => {
            let Some(e) = arg else {
                return Ok(Value::Int(0));
            };
            let mut cnt = 0i64;
            for row in rows {
                if !matches!(self.eval_expr(row, e)?, Value::Null) {
                    cnt += 1;
                }
            }
            Ok(Value::Int(cnt))
        }
        AggFn::Sum => {
            let Some(e) = arg else {
                return Ok(Value::Int(0));
            };
            let mut saw_float = false;
            // i128 cannot overflow here: that would take more than 2^64 i64 inputs.
            let mut sum_i = 0i128;
            let mut sum_f = 0.0f64;
            for row in rows {
                match self.eval_expr(row, e)? {
                    Value::Int(x) => {
                        sum_i += i128::from(x);
                        sum_f += x as f64;
                    }
                    Value::Float(x) => {
                        saw_float = true;
                        sum_f += x;
                    }
                    Value::Null => {}
                    _ => return Err(ExecutionError::ExpectedNumeric),
                }
            }
            if saw_float {
                return Ok(Value::Float(sum_f));
            }
            // Fully qualified so the call resolves regardless of the crate's edition.
            match <i64 as std::convert::TryFrom<i128>>::try_from(sum_i) {
                Ok(total) => Ok(Value::Int(total)),
                // The exact total does not fit in i64: report the float total instead.
                Err(_) => Ok(Value::Float(sum_f)),
            }
        }
        AggFn::Avg => {
            let Some(e) = arg else {
                return Ok(Value::Null);
            };
            let mut sum = 0.0f64;
            let mut cnt = 0usize;
            for row in rows {
                match self.eval_expr(row, e)? {
                    Value::Int(x) => {
                        sum += x as f64;
                        cnt += 1;
                    }
                    Value::Float(x) => {
                        sum += x;
                        cnt += 1;
                    }
                    Value::Null => {}
                    _ => return Err(ExecutionError::ExpectedNumeric),
                }
            }
            if cnt == 0 {
                Ok(Value::Null)
            } else {
                Ok(Value::Float(sum / cnt as f64))
            }
        }
        AggFn::Min => reduce_min_max_vector(self, rows, arg, true),
        AggFn::Max => reduce_min_max_vector(self, rows, arg, false),
        AggFn::Collect => {
            let Some(e) = arg else {
                return Ok(Value::List(Vec::new()));
            };
            let collected = rows
                .iter()
                .map(|row| self.eval_expr(row, e))
                .collect::<Result<Vec<_>, _>>()?;
            Ok(Value::List(collected))
        }
    }
}
/// Keeps (clones of) the rows for which `predicate` evaluates to a truthy value.
///
/// Row order is preserved. The first evaluation error aborts the whole filter.
fn execute_filter_rows(
    &self,
    input_rows: &[Row],
    predicate: &Expr,
) -> Result<RowSet, ExecutionError> {
    input_rows
        .iter()
        .filter_map(|row| match self.eval_expr(row, predicate) {
            Ok(verdict) if as_bool(&verdict) => Some(Ok(row.clone())),
            Ok(_) => None,
            Err(err) => Some(Err(err)),
        })
        .collect()
}
/// Builds one output row per input row by evaluating every expression in `exprs`, in order.
///
/// The first evaluation error aborts the projection.
fn execute_project_rows(
    &self,
    input_rows: &[Row],
    exprs: &[Expr],
) -> Result<RowSet, ExecutionError> {
    input_rows
        .iter()
        .map(|row| {
            exprs
                .iter()
                .map(|e| self.eval_expr(row, e))
                .collect::<Result<Vec<_>, ExecutionError>>()
        })
        .collect()
}
/// Expands each input row by the value of `list_expr`.
///
/// * A list produces one output row per element (the row plus that element appended).
/// * `Null` produces no output rows.
/// * Any other value is treated like a one-element list.
fn execute_unwind(&self, input: &[Row], list_expr: &Expr) -> Result<RowSet, ExecutionError> {
    let mut expanded: RowSet = Vec::new();
    for row in input {
        match self.eval_expr(row, list_expr)? {
            Value::Null => {}
            Value::List(items) => expanded.extend(items.into_iter().map(|item| {
                let mut widened = row.clone();
                widened.push(item);
                widened
            })),
            scalar => {
                let mut widened = row.clone();
                widened.push(scalar);
                expanded.push(widened);
            }
        }
    }
    Ok(expanded)
}
/// Groups rows by the columns listed in `keys` and evaluates `aggs` once per group.
///
/// Each output row is the group's key values followed by one value per aggregate. Groups
/// appear in order of first occurrence. Groups are found by a linear equality scan over
/// the key vectors.
///
/// NOTE(review): with no input rows this yields no output rows, even when `keys` is
/// empty — confirm that is the intended behaviour for global aggregates.
///
/// # Errors
/// `ColumnOutOfBounds` if a key column is missing from a row, plus anything `eval_agg`
/// reports.
fn execute_aggregate_rows(
    &self,
    input_rows: &[Row],
    keys: &[u32],
    aggs: &[Expr],
) -> Result<RowSet, ExecutionError> {
    let mut buckets: Vec<(Vec<Value>, Vec<Row>)> = Vec::new();
    for row in input_rows {
        let mut key_vals: Vec<Value> = Vec::with_capacity(keys.len());
        for &k in keys {
            let idx = k as usize;
            match row.get(idx) {
                Some(cell) => key_vals.push(cell.clone()),
                None => {
                    return Err(ExecutionError::ColumnOutOfBounds {
                        idx,
                        len: row.len(),
                    })
                }
            }
        }
        match buckets.iter().position(|(existing, _)| *existing == key_vals) {
            Some(slot) => buckets[slot].1.push(row.clone()),
            None => buckets.push((key_vals, vec![row.clone()])),
        }
    }

    buckets
        .into_iter()
        .map(|(mut out_row, members)| -> Result<Row, ExecutionError> {
            for agg in aggs {
                out_row.push(self.eval_agg(&members, agg)?);
            }
            Ok(out_row)
        })
        .collect()
}
/// Scores every entry of `collection` against a per-row query vector and emits the best
/// `top_k` matches for each input row.
///
/// For each input row, `query_vector` is evaluated and compared with every stored
/// embedding using `metric`. Each match becomes a two-column row `[NodeRef, Float(score)]`;
/// the input row's own columns are not carried over. An unknown collection yields no rows.
///
/// Ranking: ascending score for `L2`, descending for `Cosine` and `DotProduct`. Ties keep
/// the collection's entry order. NaN scores (for example from `inf - inf`) always rank
/// last. The previous comparator treated NaN as equal to everything, which is not a total
/// order and lets `sort_by` return an unspecified order or panic.
///
/// # Errors
/// Any error from evaluating `query_vector` or from `vector_similarity` (non-numeric
/// query, dimension mismatch).
fn execute_vector_scan(
    &self,
    input_rows: &[Row],
    collection: &str,
    query_vector: &Expr,
    metric: VectorMetric,
    top_k: u32,
) -> Result<RowSet, ExecutionError> {
    let Some(entries) = self.collections.get(collection) else {
        return Ok(Vec::new());
    };
    let limit = top_k as usize;
    let mut out: RowSet = Vec::new();
    for row in input_rows {
        let query = self.eval_expr(row, query_vector)?;
        let mut scored = entries
            .iter()
            .enumerate()
            .map(|(idx, entry)| {
                let score = vector_similarity(
                    metric,
                    &query,
                    &Value::List(to_value_list(&entry.embedding)),
                )?;
                Ok::<_, ExecutionError>((idx, entry.node_id, score))
            })
            .collect::<Result<Vec<_>, _>>()?;
        scored.sort_by(|(lhs_idx, _, lhs_score), (rhs_idx, _, rhs_score)| {
            let by_score = match (lhs_score.is_nan(), rhs_score.is_nan()) {
                (true, true) => Ordering::Equal,
                // NaN ranks after every real score, whatever the metric direction.
                (true, false) => Ordering::Greater,
                (false, true) => Ordering::Less,
                (false, false) => {
                    // Neither side is NaN, so `partial_cmp` always returns `Some`.
                    let ascending = lhs_score.partial_cmp(rhs_score).unwrap_or(Ordering::Equal);
                    match metric {
                        VectorMetric::L2 => ascending,
                        VectorMetric::Cosine | VectorMetric::DotProduct => ascending.reverse(),
                    }
                }
            };
            by_score.then_with(|| lhs_idx.cmp(rhs_idx))
        });
        out.extend(
            scored
                .into_iter()
                .take(limit)
                .map(|(_, node_id, score)| vec![Value::NodeRef(node_id), Value::Float(score)]),
        );
    }
    Ok(out)
}
/// Re-orders rows by `score_expr`, highest score first, and keeps the first `top_k`.
///
/// Scores are compared with `cmp_ordering`; pairs it cannot order are treated as equal,
/// and equal scores keep their original input order.
///
/// # Errors
/// Any error from evaluating `score_expr` on a row.
fn execute_rerank(
    &self,
    input_rows: &[Row],
    score_expr: &Expr,
    top_k: u32,
) -> Result<RowSet, ExecutionError> {
    let mut ranked = Vec::with_capacity(input_rows.len());
    for (position, row) in input_rows.iter().enumerate() {
        let score = self.eval_expr(row, score_expr)?;
        ranked.push((position, row.clone(), score));
    }

    ranked.sort_by(|a, b| {
        cmp_ordering(&a.2, &b.2)
            .unwrap_or(Ordering::Equal)
            .reverse()
            .then_with(|| a.0.cmp(&b.0))
    });

    let keep = top_k as usize;
    let mut out = Vec::with_capacity(ranked.len().min(keep));
    for (_, row, _) in ranked.into_iter().take(keep) {
        out.push(row);
    }
    Ok(out)
}
}
/// Plan execution for the mock vector engine.
///
/// Operators are executed one by one in the order they appear in `plan.ops`. Each
/// operator's rows are stored in a slot at the operator's index so that later operators
/// can read them through `input_rows`.
impl PlanEngine for MockVectorEngine {
type Error = ExecutionError;
/// Executes `plan` and returns the rows produced by its root operator.
///
/// # Errors
/// * `MultiGraphUnsupported` when scan/expand operators name different (non-empty,
///   trimmed) graph references.
/// * `InvalidOpRef` / `MissingOpOutput` when an operator reads an input that does not
///   exist or has not produced output yet.
/// * `UnsupportedOp` for any write (DML) operator.
/// * `InvalidRootOp` when `plan.root_op` is out of range or has no output.
/// * Any error surfaced by the individual operator implementations.
fn execute_plan(&mut self, plan: &Plan) -> Result<QueryResult, Self::Error> {
// Pass 1: reject plans that reference more than one distinct graph.
let mut seen_ref: Option<&str> = None;
for op in &plan.ops {
// Only these four operator kinds are inspected for a graph reference.
let graph_ref = match op {
Op::ScanNodes { graph_ref, .. }
| Op::Expand { graph_ref, .. }
| Op::OptionalExpand { graph_ref, .. }
| Op::ExpandVarLen { graph_ref, .. } => graph_ref.as_deref(),
_ => None,
};
// Blank or whitespace-only references are ignored.
if let Some(r) = graph_ref.map(str::trim).filter(|s| !s.is_empty()) {
match seen_ref {
None => seen_ref = Some(r),
Some(prev) if prev != r => return Err(ExecutionError::MultiGraphUnsupported),
_ => {}
}
}
}
// Pass 2: evaluate operators in plan order; slot `idx` holds the rows of `plan.ops[idx]`.
let mut outputs: Vec<Option<RowSet>> = vec![None; plan.ops.len()];
for (idx, op) in plan.ops.iter().enumerate() {
let rows = match op {
// Graph access and traversal operators are delegated to the base engine.
Op::ScanNodes {
labels,
must_labels,
forbidden_labels,
..
} => self
.base
.execute_scan_nodes(labels, must_labels, forbidden_labels),
Op::ScanRels {
types,
src_labels,
dst_labels,
..
} => self.base.execute_scan_rels(types, src_labels, dst_labels),
Op::Expand {
input,
src_col,
types,
dir,
legal_src_labels,
legal_dst_labels,
..
} => self.base.execute_expand(
self.input_rows(&outputs, *input)?,
*src_col,
types,
*dir,
legal_src_labels,
legal_dst_labels,
)?,
Op::OptionalExpand {
input,
src_col,
types,
dir,
legal_src_labels,
legal_dst_labels,
..
} => self.base.execute_optional_expand(
self.input_rows(&outputs, *input)?,
*src_col,
types,
*dir,
legal_src_labels,
legal_dst_labels,
)?,
Op::SemiExpand {
input,
src_col,
types,
dir,
legal_src_labels,
legal_dst_labels,
..
} => self.base.execute_semi_expand(
self.input_rows(&outputs, *input)?,
*src_col,
types,
*dir,
legal_src_labels,
legal_dst_labels,
)?,
Op::ExpandVarLen {
input,
src_col,
types,
dir,
min_hops,
max_hops,
..
} => self.base.execute_expand_var_len(
self.input_rows(&outputs, *input)?,
*src_col,
types,
*dir,
*min_hops,
*max_hops,
)?,
// Expression-driven operators use this engine's own evaluator.
Op::Filter { input, predicate } => {
self.execute_filter_rows(self.input_rows(&outputs, *input)?, predicate)?
}
// Block markers pass their input through unchanged.
Op::BlockMarker { input, .. } => self.input_rows(&outputs, *input)?.clone(),
Op::Project { input, exprs, .. } => {
self.execute_project_rows(self.input_rows(&outputs, *input)?, exprs)?
}
Op::Aggregate {
input, keys, aggs, ..
} => self.execute_aggregate_rows(self.input_rows(&outputs, *input)?, keys, aggs)?,
Op::Sort { input, keys, dirs } => {
self.base
.execute_sort_rows(self.input_rows(&outputs, *input)?, keys, dirs)?
}
Op::Limit { input, count, skip } => {
self.base
.execute_limit_rows(self.input_rows(&outputs, *input)?, *count, *skip)
}
// `Return` is a pass-through of its input rows.
Op::Return { input } => self.input_rows(&outputs, *input)?.clone(),
Op::Unwind {
input, list_expr, ..
} => self.execute_unwind(self.input_rows(&outputs, *input)?, list_expr)?,
Op::PathConstruct {
input, rel_cols, ..
} => self
.base
.execute_path_construct(self.input_rows(&outputs, *input)?, rel_cols)?,
Op::Union { lhs, rhs, all, .. } => self.base.execute_union_rows(
self.input_rows(&outputs, *lhs)?,
self.input_rows(&outputs, *rhs)?,
*all,
),
// Vector-specific operators implemented by this engine.
Op::VectorScan {
input,
collection,
query_vector,
metric,
top_k,
..
} => self.execute_vector_scan(
self.input_rows(&outputs, *input)?,
collection,
query_vector,
*metric,
*top_k,
)?,
Op::Rerank {
input,
score_expr,
top_k,
..
} => self.execute_rerank(self.input_rows(&outputs, *input)?, score_expr, *top_k)?,
// Write operators are rejected: this engine only executes read plans.
Op::CreateNode { .. }
| Op::CreateRel { .. }
| Op::Merge { .. }
| Op::Delete { .. }
| Op::SetProperty { .. }
| Op::RemoveProperty { .. } => {
return Err(ExecutionError::UnsupportedOp("dml in mock vector engine"));
}
// A constant row source: exactly one row with no columns.
Op::ConstRow => vec![vec![]],
};
outputs[idx] = Some(rows);
}
// The result is a copy of the root operator's rows; a missing slot or an empty slot
// are both reported as `InvalidRootOp`.
let root_rows = outputs
.get(plan.root_op as usize)
.ok_or(ExecutionError::InvalidRootOp(plan.root_op))?
.clone()
.ok_or(ExecutionError::InvalidRootOp(plan.root_op))?;
Ok(QueryResult { rows: root_rows })
}
}
/// Shared implementation of the `Min` (`is_min == true`) and `Max` aggregates.
///
/// Evaluates `expr` on every row, skips `Null`s and keeps the extreme value according to
/// `cmp_ordering`. A candidate that `cmp_ordering` cannot order against the current best is
/// ignored, and on ties the earliest value wins. Returns `Value::Null` when there is no
/// argument or no non-null value.
fn reduce_min_max_vector(
    engine: &MockVectorEngine,
    rows: &[Row],
    expr: Option<&Expr>,
    is_min: bool,
) -> Result<Value, ExecutionError> {
    let Some(e) = expr else {
        return Ok(Value::Null);
    };
    // The ordering a candidate must have, relative to the current best, to replace it.
    let wanted = if is_min {
        Ordering::Less
    } else {
        Ordering::Greater
    };
    let mut best: Option<Value> = None;
    for row in rows {
        let candidate = engine.eval_expr(row, e)?;
        if matches!(candidate, Value::Null) {
            continue;
        }
        let replaces = match &best {
            None => true,
            Some(current) => cmp_ordering(&candidate, current) == Some(wanted),
        };
        if replaces {
            best = Some(candidate);
        }
    }
    Ok(best.unwrap_or(Value::Null))
}
/// Converts a `Value::List` of integers and floats into a `Vec<f64>`.
///
/// # Errors
/// `ExpectedNumeric` if `v` is not a list or any element is neither `Int` nor `Float`.
fn to_numeric_vec(v: &Value) -> Result<Vec<f64>, ExecutionError> {
    let Value::List(items) = v else {
        return Err(ExecutionError::ExpectedNumeric);
    };
    let mut numbers = Vec::with_capacity(items.len());
    for item in items {
        let number = match item {
            Value::Int(x) => *x as f64,
            Value::Float(x) => *x,
            _ => return Err(ExecutionError::ExpectedNumeric),
        };
        numbers.push(number);
    }
    Ok(numbers)
}
/// Wraps each float in `Value::Float`, preserving order.
fn to_value_list(values: &[f64]) -> Vec<Value> {
    let mut wrapped = Vec::with_capacity(values.len());
    for &x in values {
        wrapped.push(Value::Float(x));
    }
    wrapped
}
/// Scores two numeric list values against each other using `metric`.
///
/// * `DotProduct` – sum of element-wise products.
/// * `L2` – Euclidean distance (smaller means closer).
/// * `Cosine` – dot product divided by the product of the norms, or `0.0` when either
///   vector has zero norm.
///
/// # Errors
/// `ExpectedNumeric` if either operand is not a list of numbers, or if the two vectors
/// differ in length.
fn vector_similarity(
    metric: VectorMetric,
    lhs: &Value,
    rhs: &Value,
) -> Result<f64, ExecutionError> {
    fn dot(xs: &[f64], ys: &[f64]) -> f64 {
        xs.iter().zip(ys).map(|(a, b)| a * b).sum()
    }
    fn norm(xs: &[f64]) -> f64 {
        xs.iter().map(|x| x * x).sum::<f64>().sqrt()
    }

    let left = to_numeric_vec(lhs)?;
    let right = to_numeric_vec(rhs)?;
    if left.len() != right.len() {
        return Err(ExecutionError::ExpectedNumeric);
    }

    let score = match metric {
        VectorMetric::DotProduct => dot(&left, &right),
        VectorMetric::L2 => {
            let squared: f64 = left
                .iter()
                .zip(&right)
                .map(|(a, b)| (a - b) * (a - b))
                .sum();
            squared.sqrt()
        }
        VectorMetric::Cosine => {
            let left_norm = norm(&left);
            let right_norm = norm(&right);
            if left_norm == 0.0 || right_norm == 0.0 {
                0.0
            } else {
                dot(&left, &right) / (left_norm * right_norm)
            }
        }
    };
    Ok(score)
}
/// Applies a binary arithmetic operator to two numeric values.
///
/// * `Int op Int` yields `Int` for `Add`, `Sub` and `Mul` when the exact result fits in
///   `i64`. If it would overflow, the operation is carried out in `f64` and a `Float` is
///   returned instead of panicking (debug builds) or silently wrapping (release builds).
/// * `Int / Int` always yields `Float`, so division by zero follows IEEE-754 rules.
/// * Any operation with at least one `Float` operand is computed in `f64`.
///
/// # Errors
/// `ExpectedNumeric` if either operand is not an `Int` or a `Float`.
fn eval_arith(op: ArithOp, lhs: &Value, rhs: &Value) -> Result<Value, ExecutionError> {
    use ArithOp::{Add, Div, Mul, Sub};
    match (lhs, rhs) {
        (Value::Int(a), Value::Int(b)) => {
            // `None` means "no exact integer result": either an overflow or a division.
            let exact = match op {
                Add => a.checked_add(*b),
                Sub => a.checked_sub(*b),
                Mul => a.checked_mul(*b),
                Div => None,
            };
            Ok(match exact {
                Some(v) => Value::Int(v),
                None => Value::Float(eval_arith_f64(op, *a as f64, *b as f64)),
            })
        }
        (Value::Int(a), Value::Float(b)) => Ok(Value::Float(eval_arith_f64(op, *a as f64, *b))),
        (Value::Float(a), Value::Int(b)) => Ok(Value::Float(eval_arith_f64(op, *a, *b as f64))),
        (Value::Float(a), Value::Float(b)) => Ok(Value::Float(eval_arith_f64(op, *a, *b))),
        _ => Err(ExecutionError::ExpectedNumeric),
    }
}
fn eval_arith_f64(op: ArithOp, lhs: f64, rhs: f64) -> f64 {
use ArithOp::{Add, Div, Mul, Sub};
match op {
Add => lhs + rhs,
Sub => lhs - rhs,
Mul => lhs * rhs,
Div => lhs / rhs,
}
}