1use super::eval::{as_bool, as_str, cmp_ordering, cmp_values};
2use crate::*;
3use plexus_serde::ArithOp;
4use plexus_serde::VectorMetric;
5use std::cmp::Ordering;
6use std::collections::BTreeMap;
7
8impl MockVectorEngine {
9 fn input_rows<'a>(
10 &self,
11 outputs: &'a [Option<RowSet>],
12 input: u32,
13 ) -> Result<&'a RowSet, ExecutionError> {
14 outputs
15 .get(input as usize)
16 .ok_or(ExecutionError::InvalidOpRef(input))?
17 .as_ref()
18 .ok_or(ExecutionError::MissingOpOutput(input))
19 }
20
21 fn eval_expr(&self, row: &Row, expr: &Expr) -> Result<Value, ExecutionError> {
22 Ok(match expr {
23 Expr::ColRef { idx } => {
24 row.get(*idx as usize)
25 .cloned()
26 .ok_or(ExecutionError::ColumnOutOfBounds {
27 idx: *idx as usize,
28 len: row.len(),
29 })?
30 }
31 Expr::PropAccess { col, prop } => {
32 let v = row
33 .get(*col as usize)
34 .ok_or(ExecutionError::ColumnOutOfBounds {
35 idx: *col as usize,
36 len: row.len(),
37 })?;
38 match v {
39 Value::NodeRef(id) => self
40 .base
41 .graph
42 .node_by_id(*id)
43 .ok_or(ExecutionError::UnknownNode(*id))?
44 .props
45 .get(prop)
46 .cloned()
47 .unwrap_or(Value::Null),
48 Value::RelRef(id) => self
49 .base
50 .graph
51 .rel_by_id(*id)
52 .ok_or(ExecutionError::UnknownRel(*id))?
53 .props
54 .get(prop)
55 .cloned()
56 .unwrap_or(Value::Null),
57 _ => Value::Null,
58 }
59 }
60 Expr::IntLiteral(v) => Value::Int(*v),
61 Expr::FloatLiteral(v) => Value::Float(*v),
62 Expr::BoolLiteral(v) => Value::Bool(*v),
63 Expr::StringLiteral(v) => Value::String(v.clone()),
64 Expr::NullLiteral => Value::Null,
65 Expr::Cmp { op, lhs, rhs } => {
66 let l = self.eval_expr(row, lhs)?;
67 let r = self.eval_expr(row, rhs)?;
68 Value::Bool(cmp_values(*op, &l, &r))
69 }
70 Expr::And { lhs, rhs } => {
71 let l = self.eval_expr(row, lhs)?;
72 let r = self.eval_expr(row, rhs)?;
73 Value::Bool(as_bool(&l) && as_bool(&r))
74 }
75 Expr::Or { lhs, rhs } => {
76 let l = self.eval_expr(row, lhs)?;
77 let r = self.eval_expr(row, rhs)?;
78 Value::Bool(as_bool(&l) || as_bool(&r))
79 }
80 Expr::Not { expr } => {
81 let x = self.eval_expr(row, expr)?;
82 Value::Bool(!as_bool(&x))
83 }
84 Expr::IsNull { expr } => {
85 let x = self.eval_expr(row, expr)?;
86 Value::Bool(matches!(x, Value::Null))
87 }
88 Expr::IsNotNull { expr } => {
89 let x = self.eval_expr(row, expr)?;
90 Value::Bool(!matches!(x, Value::Null))
91 }
92 Expr::StartsWith { expr, pattern } => {
93 let x = self.eval_expr(row, expr)?;
94 Value::Bool(as_str(&x).is_some_and(|s| s.starts_with(pattern)))
95 }
96 Expr::EndsWith { expr, pattern } => {
97 let x = self.eval_expr(row, expr)?;
98 Value::Bool(as_str(&x).is_some_and(|s| s.ends_with(pattern)))
99 }
100 Expr::Contains { expr, pattern } => {
101 let x = self.eval_expr(row, expr)?;
102 Value::Bool(as_str(&x).is_some_and(|s| s.contains(pattern)))
103 }
104 Expr::In { expr, items } => {
105 let needle = self.eval_expr(row, expr)?;
106 let mut found = false;
107 for item in items {
108 let v = self.eval_expr(row, item)?;
109 if v == needle {
110 found = true;
111 break;
112 }
113 }
114 Value::Bool(found)
115 }
116 Expr::ListLiteral { items } => {
117 let mut out = Vec::with_capacity(items.len());
118 for item in items {
119 out.push(self.eval_expr(row, item)?);
120 }
121 Value::List(out)
122 }
123 Expr::MapLiteral { entries } => {
124 let mut out = BTreeMap::new();
125 for (k, v) in entries {
126 out.insert(k.clone(), self.eval_expr(row, v)?);
127 }
128 Value::Map(out)
129 }
130 Expr::Exists { expr } => {
131 let x = self.eval_expr(row, expr)?;
132 Value::Bool(!matches!(x, Value::Null))
133 }
134 Expr::ListComprehension { .. } => {
135 return Err(ExecutionError::UnsupportedExpr("list comprehension"))
136 }
137 Expr::Agg { .. } => return Err(ExecutionError::ExpectedAggregateExpr),
138 Expr::Arith { op, lhs, rhs } => {
139 let l = self.eval_expr(row, lhs)?;
140 let r = self.eval_expr(row, rhs)?;
141 eval_arith(*op, &l, &r)?
142 }
143 Expr::Param { name, .. } => self
144 .base
145 .params
146 .get(name)
147 .cloned()
148 .ok_or_else(|| ExecutionError::UnboundParam(name.clone()))?,
149 Expr::Case { arms, else_expr } => {
150 let mut matched = None;
151 for (when_expr, then_expr) in arms {
152 let cond = self.eval_expr(row, when_expr)?;
153 if as_bool(&cond) {
154 matched = Some(self.eval_expr(row, then_expr)?);
155 break;
156 }
157 }
158 match matched {
159 Some(v) => v,
160 None => match else_expr {
161 Some(e) => self.eval_expr(row, e)?,
162 None => Value::Null,
163 },
164 }
165 }
166 Expr::VectorSimilarity { metric, lhs, rhs } => {
167 let lhs = self.eval_expr(row, lhs)?;
168 let rhs = self.eval_expr(row, rhs)?;
169 Value::Float(vector_similarity(*metric, &lhs, &rhs)?)
170 }
171 })
172 }
173
174 fn eval_agg(&self, rows: &[Row], expr: &Expr) -> Result<Value, ExecutionError> {
175 let Expr::Agg { fn_, expr } = expr else {
176 return Err(ExecutionError::ExpectedAggregateExpr);
177 };
178
179 match fn_ {
180 AggFn::CountStar => Ok(Value::Int(rows.len() as i64)),
181 AggFn::Count => {
182 let mut cnt = 0i64;
183 for row in rows {
184 let Some(e) = expr else {
185 continue;
186 };
187 let v = self.eval_expr(row, e)?;
188 if !matches!(v, Value::Null) {
189 cnt += 1;
190 }
191 }
192 Ok(Value::Int(cnt))
193 }
194 AggFn::Sum => {
195 let mut saw_float = false;
196 let mut sum_i = 0i64;
197 let mut sum_f = 0.0f64;
198 for row in rows {
199 let Some(e) = expr else {
200 continue;
201 };
202 let v = self.eval_expr(row, e)?;
203 match v {
204 Value::Int(x) => {
205 sum_i += x;
206 sum_f += x as f64;
207 }
208 Value::Float(x) => {
209 saw_float = true;
210 sum_f += x;
211 }
212 Value::Null => {}
213 _ => return Err(ExecutionError::ExpectedNumeric),
214 }
215 }
216 if saw_float {
217 Ok(Value::Float(sum_f))
218 } else {
219 Ok(Value::Int(sum_i))
220 }
221 }
222 AggFn::Avg => {
223 let mut sum = 0.0f64;
224 let mut cnt = 0usize;
225 for row in rows {
226 let Some(e) = expr else {
227 continue;
228 };
229 let v = self.eval_expr(row, e)?;
230 match v {
231 Value::Int(x) => {
232 sum += x as f64;
233 cnt += 1;
234 }
235 Value::Float(x) => {
236 sum += x;
237 cnt += 1;
238 }
239 Value::Null => {}
240 _ => return Err(ExecutionError::ExpectedNumeric),
241 }
242 }
243 if cnt == 0 {
244 Ok(Value::Null)
245 } else {
246 Ok(Value::Float(sum / cnt as f64))
247 }
248 }
249 AggFn::Min => reduce_min_max_vector(self, rows, expr.as_deref(), true),
250 AggFn::Max => reduce_min_max_vector(self, rows, expr.as_deref(), false),
251 AggFn::Collect => {
252 let mut out = Vec::with_capacity(rows.len());
253 for row in rows {
254 let Some(e) = expr else {
255 continue;
256 };
257 out.push(self.eval_expr(row, e)?);
258 }
259 Ok(Value::List(out))
260 }
261 }
262 }
263
264 fn execute_filter_rows(
265 &self,
266 input_rows: &[Row],
267 predicate: &Expr,
268 ) -> Result<RowSet, ExecutionError> {
269 let mut out = Vec::new();
270 for row in input_rows {
271 if as_bool(&self.eval_expr(row, predicate)?) {
272 out.push(row.clone());
273 }
274 }
275 Ok(out)
276 }
277
278 fn execute_project_rows(
279 &self,
280 input_rows: &[Row],
281 exprs: &[Expr],
282 ) -> Result<RowSet, ExecutionError> {
283 let mut out = Vec::with_capacity(input_rows.len());
284 for row in input_rows {
285 let mut new_row = Vec::with_capacity(exprs.len());
286 for e in exprs {
287 new_row.push(self.eval_expr(row, e)?);
288 }
289 out.push(new_row);
290 }
291 Ok(out)
292 }
293
294 fn execute_unwind(&self, input: &[Row], list_expr: &Expr) -> Result<RowSet, ExecutionError> {
295 let mut out = Vec::new();
296 for row in input {
297 let value = self.eval_expr(row, list_expr)?;
298 match value {
299 Value::List(items) => {
300 for item in items {
301 let mut next = row.clone();
302 next.push(item);
303 out.push(next);
304 }
305 }
306 Value::Null => {}
307 scalar => {
308 let mut next = row.clone();
309 next.push(scalar);
310 out.push(next);
311 }
312 }
313 }
314 Ok(out)
315 }
316
317 fn execute_aggregate_rows(
318 &self,
319 input_rows: &[Row],
320 keys: &[u32],
321 aggs: &[Expr],
322 ) -> Result<RowSet, ExecutionError> {
323 let mut groups: Vec<(Vec<Value>, Vec<Row>)> = Vec::new();
324 for row in input_rows {
325 let key_vals: Vec<Value> = keys
326 .iter()
327 .map(|k| {
328 row.get(*k as usize)
329 .cloned()
330 .ok_or(ExecutionError::ColumnOutOfBounds {
331 idx: *k as usize,
332 len: row.len(),
333 })
334 })
335 .collect::<Result<Vec<_>, _>>()?;
336 if let Some((_, g_rows)) = groups.iter_mut().find(|(k, _)| *k == key_vals) {
337 g_rows.push(row.clone());
338 } else {
339 groups.push((key_vals, vec![row.clone()]));
340 }
341 }
342
343 let mut out = Vec::new();
344 for (key_vals, g_rows) in groups {
345 let mut out_row = key_vals;
346 for a in aggs {
347 out_row.push(self.eval_agg(&g_rows, a)?);
348 }
349 out.push(out_row);
350 }
351 Ok(out)
352 }
353
354 fn execute_vector_scan(
355 &self,
356 input_rows: &[Row],
357 collection: &str,
358 query_vector: &Expr,
359 metric: VectorMetric,
360 top_k: u32,
361 ) -> Result<RowSet, ExecutionError> {
362 let Some(entries) = self.collections.get(collection) else {
363 return Ok(Vec::new());
364 };
365
366 let mut out = Vec::new();
367 for row in input_rows {
368 let query = self.eval_expr(row, query_vector)?;
369 let mut scored = entries
370 .iter()
371 .enumerate()
372 .map(|(idx, entry)| {
373 let score = vector_similarity(
374 metric,
375 &query,
376 &Value::List(to_value_list(&entry.embedding)),
377 )?;
378 Ok::<_, ExecutionError>((idx, entry.node_id, score))
379 })
380 .collect::<Result<Vec<_>, _>>()?;
381
382 scored.sort_by(|(lhs_idx, _, lhs_score), (rhs_idx, _, rhs_score)| {
383 let ord = match metric {
384 VectorMetric::L2 => lhs_score.partial_cmp(rhs_score).unwrap_or(Ordering::Equal),
385 VectorMetric::Cosine | VectorMetric::DotProduct => {
386 rhs_score.partial_cmp(lhs_score).unwrap_or(Ordering::Equal)
387 }
388 };
389 if ord == Ordering::Equal {
390 lhs_idx.cmp(rhs_idx)
391 } else {
392 ord
393 }
394 });
395
396 for (_, node_id, score) in scored.into_iter().take(top_k as usize) {
397 out.push(vec![Value::NodeRef(node_id), Value::Float(score)]);
398 }
399 }
400 Ok(out)
401 }
402
403 fn execute_rerank(
404 &self,
405 input_rows: &[Row],
406 score_expr: &Expr,
407 top_k: u32,
408 ) -> Result<RowSet, ExecutionError> {
409 let mut scored = input_rows
410 .iter()
411 .enumerate()
412 .map(|(idx, row)| {
413 let score = self.eval_expr(row, score_expr)?;
414 Ok::<_, ExecutionError>((idx, row.clone(), score))
415 })
416 .collect::<Result<Vec<_>, _>>()?;
417
418 scored.sort_by(|(lhs_idx, _, lhs_score), (rhs_idx, _, rhs_score)| {
419 let ord = cmp_ordering(lhs_score, rhs_score)
420 .unwrap_or(Ordering::Equal)
421 .reverse();
422 if ord == Ordering::Equal {
423 lhs_idx.cmp(rhs_idx)
424 } else {
425 ord
426 }
427 });
428
429 Ok(scored
430 .into_iter()
431 .take(top_k as usize)
432 .map(|(_, row, _)| row)
433 .collect())
434 }
435}
436
437impl PlanEngine for MockVectorEngine {
438 type Error = ExecutionError;
439
440 fn execute_plan(&mut self, plan: &Plan) -> Result<QueryResult, Self::Error> {
441 let mut seen_ref: Option<&str> = None;
442 for op in &plan.ops {
443 let graph_ref = match op {
444 Op::ScanNodes { graph_ref, .. }
445 | Op::Expand { graph_ref, .. }
446 | Op::OptionalExpand { graph_ref, .. }
447 | Op::ExpandVarLen { graph_ref, .. } => graph_ref.as_deref(),
448 _ => None,
449 };
450 if let Some(r) = graph_ref.map(str::trim).filter(|s| !s.is_empty()) {
451 match seen_ref {
452 None => seen_ref = Some(r),
453 Some(prev) if prev != r => return Err(ExecutionError::MultiGraphUnsupported),
454 _ => {}
455 }
456 }
457 }
458
459 let mut outputs: Vec<Option<RowSet>> = vec![None; plan.ops.len()];
460 for (idx, op) in plan.ops.iter().enumerate() {
461 let rows = match op {
462 Op::ScanNodes {
463 labels,
464 must_labels,
465 forbidden_labels,
466 ..
467 } => self
468 .base
469 .execute_scan_nodes(labels, must_labels, forbidden_labels),
470 Op::ScanRels {
471 types,
472 src_labels,
473 dst_labels,
474 ..
475 } => self.base.execute_scan_rels(types, src_labels, dst_labels),
476 Op::Expand {
477 input,
478 src_col,
479 types,
480 dir,
481 legal_src_labels,
482 legal_dst_labels,
483 ..
484 } => self.base.execute_expand(
485 self.input_rows(&outputs, *input)?,
486 *src_col,
487 types,
488 *dir,
489 legal_src_labels,
490 legal_dst_labels,
491 )?,
492 Op::OptionalExpand {
493 input,
494 src_col,
495 types,
496 dir,
497 legal_src_labels,
498 legal_dst_labels,
499 ..
500 } => self.base.execute_optional_expand(
501 self.input_rows(&outputs, *input)?,
502 *src_col,
503 types,
504 *dir,
505 legal_src_labels,
506 legal_dst_labels,
507 )?,
508 Op::SemiExpand {
509 input,
510 src_col,
511 types,
512 dir,
513 legal_src_labels,
514 legal_dst_labels,
515 ..
516 } => self.base.execute_semi_expand(
517 self.input_rows(&outputs, *input)?,
518 *src_col,
519 types,
520 *dir,
521 legal_src_labels,
522 legal_dst_labels,
523 )?,
524 Op::ExpandVarLen {
525 input,
526 src_col,
527 types,
528 dir,
529 min_hops,
530 max_hops,
531 ..
532 } => self.base.execute_expand_var_len(
533 self.input_rows(&outputs, *input)?,
534 *src_col,
535 types,
536 *dir,
537 *min_hops,
538 *max_hops,
539 )?,
540 Op::Filter { input, predicate } => {
541 self.execute_filter_rows(self.input_rows(&outputs, *input)?, predicate)?
542 }
543 Op::BlockMarker { input, .. } => self.input_rows(&outputs, *input)?.clone(),
544 Op::Project { input, exprs, .. } => {
545 self.execute_project_rows(self.input_rows(&outputs, *input)?, exprs)?
546 }
547 Op::Aggregate {
548 input, keys, aggs, ..
549 } => self.execute_aggregate_rows(self.input_rows(&outputs, *input)?, keys, aggs)?,
550 Op::Sort { input, keys, dirs } => {
551 self.base
552 .execute_sort_rows(self.input_rows(&outputs, *input)?, keys, dirs)?
553 }
554 Op::Limit { input, count, skip, .. } => {
555 self.base
556 .execute_limit_rows(self.input_rows(&outputs, *input)?, *count, *skip)
557 }
558 Op::Return { input } => self.input_rows(&outputs, *input)?.clone(),
559 Op::Unwind {
560 input, list_expr, ..
561 } => self.execute_unwind(self.input_rows(&outputs, *input)?, list_expr)?,
562 Op::PathConstruct {
563 input, rel_cols, ..
564 } => self
565 .base
566 .execute_path_construct(self.input_rows(&outputs, *input)?, rel_cols)?,
567 Op::Union { lhs, rhs, all, .. } => self.base.execute_union_rows(
568 self.input_rows(&outputs, *lhs)?,
569 self.input_rows(&outputs, *rhs)?,
570 *all,
571 ),
572 Op::VectorScan {
573 input,
574 collection,
575 query_vector,
576 metric,
577 top_k,
578 ..
579 } => self.execute_vector_scan(
580 self.input_rows(&outputs, *input)?,
581 collection,
582 query_vector,
583 *metric,
584 *top_k,
585 )?,
586 Op::Rerank {
587 input,
588 score_expr,
589 top_k,
590 ..
591 } => self.execute_rerank(self.input_rows(&outputs, *input)?, score_expr, *top_k)?,
592 Op::CreateNode { .. }
593 | Op::CreateRel { .. }
594 | Op::Merge { .. }
595 | Op::Delete { .. }
596 | Op::SetProperty { .. }
597 | Op::RemoveProperty { .. } => {
598 return Err(ExecutionError::UnsupportedOp("dml in mock vector engine"));
599 }
600 Op::ConstRow => vec![vec![]],
601 };
602 outputs[idx] = Some(rows);
603 }
604
605 let root_rows = outputs
606 .get(plan.root_op as usize)
607 .ok_or(ExecutionError::InvalidRootOp(plan.root_op))?
608 .clone()
609 .ok_or(ExecutionError::InvalidRootOp(plan.root_op))?;
610 Ok(QueryResult {
611 rows: root_rows,
612 continuation: None,
613 })
614 }
615}
616
617fn reduce_min_max_vector(
618 engine: &MockVectorEngine,
619 rows: &[Row],
620 expr: Option<&Expr>,
621 is_min: bool,
622) -> Result<Value, ExecutionError> {
623 let Some(e) = expr else {
624 return Ok(Value::Null);
625 };
626 let mut best: Option<Value> = None;
627 for row in rows {
628 let v = engine.eval_expr(row, e)?;
629 if matches!(v, Value::Null) {
630 continue;
631 }
632 match &best {
633 None => best = Some(v),
634 Some(b) => {
635 if let Some(ord) = cmp_ordering(&v, b) {
636 if (is_min && ord == Ordering::Less) || (!is_min && ord == Ordering::Greater) {
637 best = Some(v);
638 }
639 }
640 }
641 }
642 }
643 Ok(best.unwrap_or(Value::Null))
644}
645
646fn to_numeric_vec(v: &Value) -> Result<Vec<f64>, ExecutionError> {
647 match v {
648 Value::List(items) => items
649 .iter()
650 .map(|item| match item {
651 Value::Int(x) => Ok(*x as f64),
652 Value::Float(x) => Ok(*x),
653 _ => Err(ExecutionError::ExpectedNumeric),
654 })
655 .collect(),
656 _ => Err(ExecutionError::ExpectedNumeric),
657 }
658}
659
660fn to_value_list(values: &[f64]) -> Vec<Value> {
661 values.iter().copied().map(Value::Float).collect()
662}
663
664fn vector_similarity(
665 metric: VectorMetric,
666 lhs: &Value,
667 rhs: &Value,
668) -> Result<f64, ExecutionError> {
669 let lhs = to_numeric_vec(lhs)?;
670 let rhs = to_numeric_vec(rhs)?;
671 if lhs.len() != rhs.len() {
672 return Err(ExecutionError::ExpectedNumeric);
673 }
674 Ok(match metric {
675 VectorMetric::DotProduct => lhs.iter().zip(&rhs).map(|(a, b)| a * b).sum(),
676 VectorMetric::L2 => lhs
677 .iter()
678 .zip(&rhs)
679 .map(|(a, b)| {
680 let d = a - b;
681 d * d
682 })
683 .sum::<f64>()
684 .sqrt(),
685 VectorMetric::Cosine => {
686 let dot: f64 = lhs.iter().zip(&rhs).map(|(a, b)| a * b).sum();
687 let lhs_norm: f64 = lhs.iter().map(|x| x * x).sum::<f64>().sqrt();
688 let rhs_norm: f64 = rhs.iter().map(|x| x * x).sum::<f64>().sqrt();
689 if lhs_norm == 0.0 || rhs_norm == 0.0 {
690 0.0
691 } else {
692 dot / (lhs_norm * rhs_norm)
693 }
694 }
695 })
696}
697
698fn eval_arith(op: ArithOp, lhs: &Value, rhs: &Value) -> Result<Value, ExecutionError> {
699 use ArithOp::{Add, Div, Mul, Sub};
700 match (lhs, rhs) {
701 (Value::Int(a), Value::Int(b)) => match op {
702 Add => Ok(Value::Int(a + b)),
703 Sub => Ok(Value::Int(a - b)),
704 Mul => Ok(Value::Int(a * b)),
705 Div => Ok(Value::Float(*a as f64 / *b as f64)),
706 },
707 (Value::Int(a), Value::Float(b)) => Ok(Value::Float(eval_arith_f64(op, *a as f64, *b))),
708 (Value::Float(a), Value::Int(b)) => Ok(Value::Float(eval_arith_f64(op, *a, *b as f64))),
709 (Value::Float(a), Value::Float(b)) => Ok(Value::Float(eval_arith_f64(op, *a, *b))),
710 _ => Err(ExecutionError::ExpectedNumeric),
711 }
712}
713
714fn eval_arith_f64(op: ArithOp, lhs: f64, rhs: f64) -> f64 {
715 use ArithOp::{Add, Div, Mul, Sub};
716 match op {
717 Add => lhs + rhs,
718 Sub => lhs - rhs,
719 Mul => lhs * rhs,
720 Div => lhs / rhs,
721 }
722}