Skip to main content

plexus_engine/engine/
vector.rs

1use super::eval::{as_bool, as_str, cmp_ordering, cmp_values};
2use crate::*;
3use plexus_serde::ArithOp;
4use plexus_serde::VectorMetric;
5use std::cmp::Ordering;
6use std::collections::BTreeMap;
7
8impl MockVectorEngine {
9    fn input_rows<'a>(
10        &self,
11        outputs: &'a [Option<RowSet>],
12        input: u32,
13    ) -> Result<&'a RowSet, ExecutionError> {
14        outputs
15            .get(input as usize)
16            .ok_or(ExecutionError::InvalidOpRef(input))?
17            .as_ref()
18            .ok_or(ExecutionError::MissingOpOutput(input))
19    }
20
21    fn eval_expr(&self, row: &Row, expr: &Expr) -> Result<Value, ExecutionError> {
22        Ok(match expr {
23            Expr::ColRef { idx } => {
24                row.get(*idx as usize)
25                    .cloned()
26                    .ok_or(ExecutionError::ColumnOutOfBounds {
27                        idx: *idx as usize,
28                        len: row.len(),
29                    })?
30            }
31            Expr::PropAccess { col, prop } => {
32                let v = row
33                    .get(*col as usize)
34                    .ok_or(ExecutionError::ColumnOutOfBounds {
35                        idx: *col as usize,
36                        len: row.len(),
37                    })?;
38                match v {
39                    Value::NodeRef(id) => self
40                        .base
41                        .graph
42                        .node_by_id(*id)
43                        .ok_or(ExecutionError::UnknownNode(*id))?
44                        .props
45                        .get(prop)
46                        .cloned()
47                        .unwrap_or(Value::Null),
48                    Value::RelRef(id) => self
49                        .base
50                        .graph
51                        .rel_by_id(*id)
52                        .ok_or(ExecutionError::UnknownRel(*id))?
53                        .props
54                        .get(prop)
55                        .cloned()
56                        .unwrap_or(Value::Null),
57                    _ => Value::Null,
58                }
59            }
60            Expr::IntLiteral(v) => Value::Int(*v),
61            Expr::FloatLiteral(v) => Value::Float(*v),
62            Expr::BoolLiteral(v) => Value::Bool(*v),
63            Expr::StringLiteral(v) => Value::String(v.clone()),
64            Expr::NullLiteral => Value::Null,
65            Expr::Cmp { op, lhs, rhs } => {
66                let l = self.eval_expr(row, lhs)?;
67                let r = self.eval_expr(row, rhs)?;
68                Value::Bool(cmp_values(*op, &l, &r))
69            }
70            Expr::And { lhs, rhs } => {
71                let l = self.eval_expr(row, lhs)?;
72                let r = self.eval_expr(row, rhs)?;
73                Value::Bool(as_bool(&l) && as_bool(&r))
74            }
75            Expr::Or { lhs, rhs } => {
76                let l = self.eval_expr(row, lhs)?;
77                let r = self.eval_expr(row, rhs)?;
78                Value::Bool(as_bool(&l) || as_bool(&r))
79            }
80            Expr::Not { expr } => {
81                let x = self.eval_expr(row, expr)?;
82                Value::Bool(!as_bool(&x))
83            }
84            Expr::IsNull { expr } => {
85                let x = self.eval_expr(row, expr)?;
86                Value::Bool(matches!(x, Value::Null))
87            }
88            Expr::IsNotNull { expr } => {
89                let x = self.eval_expr(row, expr)?;
90                Value::Bool(!matches!(x, Value::Null))
91            }
92            Expr::StartsWith { expr, pattern } => {
93                let x = self.eval_expr(row, expr)?;
94                Value::Bool(as_str(&x).is_some_and(|s| s.starts_with(pattern)))
95            }
96            Expr::EndsWith { expr, pattern } => {
97                let x = self.eval_expr(row, expr)?;
98                Value::Bool(as_str(&x).is_some_and(|s| s.ends_with(pattern)))
99            }
100            Expr::Contains { expr, pattern } => {
101                let x = self.eval_expr(row, expr)?;
102                Value::Bool(as_str(&x).is_some_and(|s| s.contains(pattern)))
103            }
104            Expr::In { expr, items } => {
105                let needle = self.eval_expr(row, expr)?;
106                let mut found = false;
107                for item in items {
108                    let v = self.eval_expr(row, item)?;
109                    if v == needle {
110                        found = true;
111                        break;
112                    }
113                }
114                Value::Bool(found)
115            }
116            Expr::ListLiteral { items } => {
117                let mut out = Vec::with_capacity(items.len());
118                for item in items {
119                    out.push(self.eval_expr(row, item)?);
120                }
121                Value::List(out)
122            }
123            Expr::MapLiteral { entries } => {
124                let mut out = BTreeMap::new();
125                for (k, v) in entries {
126                    out.insert(k.clone(), self.eval_expr(row, v)?);
127                }
128                Value::Map(out)
129            }
130            Expr::Exists { expr } => {
131                let x = self.eval_expr(row, expr)?;
132                Value::Bool(!matches!(x, Value::Null))
133            }
134            Expr::ListComprehension { .. } => {
135                return Err(ExecutionError::UnsupportedExpr("list comprehension"))
136            }
137            Expr::Agg { .. } => return Err(ExecutionError::ExpectedAggregateExpr),
138            Expr::Arith { op, lhs, rhs } => {
139                let l = self.eval_expr(row, lhs)?;
140                let r = self.eval_expr(row, rhs)?;
141                eval_arith(*op, &l, &r)?
142            }
143            Expr::Param { name, .. } => self
144                .base
145                .params
146                .get(name)
147                .cloned()
148                .ok_or_else(|| ExecutionError::UnboundParam(name.clone()))?,
149            Expr::Case { arms, else_expr } => {
150                let mut matched = None;
151                for (when_expr, then_expr) in arms {
152                    let cond = self.eval_expr(row, when_expr)?;
153                    if as_bool(&cond) {
154                        matched = Some(self.eval_expr(row, then_expr)?);
155                        break;
156                    }
157                }
158                match matched {
159                    Some(v) => v,
160                    None => match else_expr {
161                        Some(e) => self.eval_expr(row, e)?,
162                        None => Value::Null,
163                    },
164                }
165            }
166            Expr::VectorSimilarity { metric, lhs, rhs } => {
167                let lhs = self.eval_expr(row, lhs)?;
168                let rhs = self.eval_expr(row, rhs)?;
169                Value::Float(vector_similarity(*metric, &lhs, &rhs)?)
170            }
171        })
172    }
173
174    fn eval_agg(&self, rows: &[Row], expr: &Expr) -> Result<Value, ExecutionError> {
175        let Expr::Agg { fn_, expr } = expr else {
176            return Err(ExecutionError::ExpectedAggregateExpr);
177        };
178
179        match fn_ {
180            AggFn::CountStar => Ok(Value::Int(rows.len() as i64)),
181            AggFn::Count => {
182                let mut cnt = 0i64;
183                for row in rows {
184                    let Some(e) = expr else {
185                        continue;
186                    };
187                    let v = self.eval_expr(row, e)?;
188                    if !matches!(v, Value::Null) {
189                        cnt += 1;
190                    }
191                }
192                Ok(Value::Int(cnt))
193            }
194            AggFn::Sum => {
195                let mut saw_float = false;
196                let mut sum_i = 0i64;
197                let mut sum_f = 0.0f64;
198                for row in rows {
199                    let Some(e) = expr else {
200                        continue;
201                    };
202                    let v = self.eval_expr(row, e)?;
203                    match v {
204                        Value::Int(x) => {
205                            sum_i += x;
206                            sum_f += x as f64;
207                        }
208                        Value::Float(x) => {
209                            saw_float = true;
210                            sum_f += x;
211                        }
212                        Value::Null => {}
213                        _ => return Err(ExecutionError::ExpectedNumeric),
214                    }
215                }
216                if saw_float {
217                    Ok(Value::Float(sum_f))
218                } else {
219                    Ok(Value::Int(sum_i))
220                }
221            }
222            AggFn::Avg => {
223                let mut sum = 0.0f64;
224                let mut cnt = 0usize;
225                for row in rows {
226                    let Some(e) = expr else {
227                        continue;
228                    };
229                    let v = self.eval_expr(row, e)?;
230                    match v {
231                        Value::Int(x) => {
232                            sum += x as f64;
233                            cnt += 1;
234                        }
235                        Value::Float(x) => {
236                            sum += x;
237                            cnt += 1;
238                        }
239                        Value::Null => {}
240                        _ => return Err(ExecutionError::ExpectedNumeric),
241                    }
242                }
243                if cnt == 0 {
244                    Ok(Value::Null)
245                } else {
246                    Ok(Value::Float(sum / cnt as f64))
247                }
248            }
249            AggFn::Min => reduce_min_max_vector(self, rows, expr.as_deref(), true),
250            AggFn::Max => reduce_min_max_vector(self, rows, expr.as_deref(), false),
251            AggFn::Collect => {
252                let mut out = Vec::with_capacity(rows.len());
253                for row in rows {
254                    let Some(e) = expr else {
255                        continue;
256                    };
257                    out.push(self.eval_expr(row, e)?);
258                }
259                Ok(Value::List(out))
260            }
261        }
262    }
263
264    fn execute_filter_rows(
265        &self,
266        input_rows: &[Row],
267        predicate: &Expr,
268    ) -> Result<RowSet, ExecutionError> {
269        let mut out = Vec::new();
270        for row in input_rows {
271            if as_bool(&self.eval_expr(row, predicate)?) {
272                out.push(row.clone());
273            }
274        }
275        Ok(out)
276    }
277
278    fn execute_project_rows(
279        &self,
280        input_rows: &[Row],
281        exprs: &[Expr],
282    ) -> Result<RowSet, ExecutionError> {
283        let mut out = Vec::with_capacity(input_rows.len());
284        for row in input_rows {
285            let mut new_row = Vec::with_capacity(exprs.len());
286            for e in exprs {
287                new_row.push(self.eval_expr(row, e)?);
288            }
289            out.push(new_row);
290        }
291        Ok(out)
292    }
293
294    fn execute_unwind(&self, input: &[Row], list_expr: &Expr) -> Result<RowSet, ExecutionError> {
295        let mut out = Vec::new();
296        for row in input {
297            let value = self.eval_expr(row, list_expr)?;
298            match value {
299                Value::List(items) => {
300                    for item in items {
301                        let mut next = row.clone();
302                        next.push(item);
303                        out.push(next);
304                    }
305                }
306                Value::Null => {}
307                scalar => {
308                    let mut next = row.clone();
309                    next.push(scalar);
310                    out.push(next);
311                }
312            }
313        }
314        Ok(out)
315    }
316
317    fn execute_aggregate_rows(
318        &self,
319        input_rows: &[Row],
320        keys: &[u32],
321        aggs: &[Expr],
322    ) -> Result<RowSet, ExecutionError> {
323        let mut groups: Vec<(Vec<Value>, Vec<Row>)> = Vec::new();
324        for row in input_rows {
325            let key_vals: Vec<Value> = keys
326                .iter()
327                .map(|k| {
328                    row.get(*k as usize)
329                        .cloned()
330                        .ok_or(ExecutionError::ColumnOutOfBounds {
331                            idx: *k as usize,
332                            len: row.len(),
333                        })
334                })
335                .collect::<Result<Vec<_>, _>>()?;
336            if let Some((_, g_rows)) = groups.iter_mut().find(|(k, _)| *k == key_vals) {
337                g_rows.push(row.clone());
338            } else {
339                groups.push((key_vals, vec![row.clone()]));
340            }
341        }
342
343        let mut out = Vec::new();
344        for (key_vals, g_rows) in groups {
345            let mut out_row = key_vals;
346            for a in aggs {
347                out_row.push(self.eval_agg(&g_rows, a)?);
348            }
349            out.push(out_row);
350        }
351        Ok(out)
352    }
353
354    fn execute_vector_scan(
355        &self,
356        input_rows: &[Row],
357        collection: &str,
358        query_vector: &Expr,
359        metric: VectorMetric,
360        top_k: u32,
361    ) -> Result<RowSet, ExecutionError> {
362        let Some(entries) = self.collections.get(collection) else {
363            return Ok(Vec::new());
364        };
365
366        let mut out = Vec::new();
367        for row in input_rows {
368            let query = self.eval_expr(row, query_vector)?;
369            let mut scored = entries
370                .iter()
371                .enumerate()
372                .map(|(idx, entry)| {
373                    let score = vector_similarity(
374                        metric,
375                        &query,
376                        &Value::List(to_value_list(&entry.embedding)),
377                    )?;
378                    Ok::<_, ExecutionError>((idx, entry.node_id, score))
379                })
380                .collect::<Result<Vec<_>, _>>()?;
381
382            scored.sort_by(|(lhs_idx, _, lhs_score), (rhs_idx, _, rhs_score)| {
383                let ord = match metric {
384                    VectorMetric::L2 => lhs_score.partial_cmp(rhs_score).unwrap_or(Ordering::Equal),
385                    VectorMetric::Cosine | VectorMetric::DotProduct => {
386                        rhs_score.partial_cmp(lhs_score).unwrap_or(Ordering::Equal)
387                    }
388                };
389                if ord == Ordering::Equal {
390                    lhs_idx.cmp(rhs_idx)
391                } else {
392                    ord
393                }
394            });
395
396            for (_, node_id, score) in scored.into_iter().take(top_k as usize) {
397                out.push(vec![Value::NodeRef(node_id), Value::Float(score)]);
398            }
399        }
400        Ok(out)
401    }
402
403    fn execute_rerank(
404        &self,
405        input_rows: &[Row],
406        score_expr: &Expr,
407        top_k: u32,
408    ) -> Result<RowSet, ExecutionError> {
409        let mut scored = input_rows
410            .iter()
411            .enumerate()
412            .map(|(idx, row)| {
413                let score = self.eval_expr(row, score_expr)?;
414                Ok::<_, ExecutionError>((idx, row.clone(), score))
415            })
416            .collect::<Result<Vec<_>, _>>()?;
417
418        scored.sort_by(|(lhs_idx, _, lhs_score), (rhs_idx, _, rhs_score)| {
419            let ord = cmp_ordering(lhs_score, rhs_score)
420                .unwrap_or(Ordering::Equal)
421                .reverse();
422            if ord == Ordering::Equal {
423                lhs_idx.cmp(rhs_idx)
424            } else {
425                ord
426            }
427        });
428
429        Ok(scored
430            .into_iter()
431            .take(top_k as usize)
432            .map(|(_, row, _)| row)
433            .collect())
434    }
435}
436
437impl PlanEngine for MockVectorEngine {
438    type Error = ExecutionError;
439
440    fn execute_plan(&mut self, plan: &Plan) -> Result<QueryResult, Self::Error> {
441        let mut seen_ref: Option<&str> = None;
442        for op in &plan.ops {
443            let graph_ref = match op {
444                Op::ScanNodes { graph_ref, .. }
445                | Op::Expand { graph_ref, .. }
446                | Op::OptionalExpand { graph_ref, .. }
447                | Op::ExpandVarLen { graph_ref, .. } => graph_ref.as_deref(),
448                _ => None,
449            };
450            if let Some(r) = graph_ref.map(str::trim).filter(|s| !s.is_empty()) {
451                match seen_ref {
452                    None => seen_ref = Some(r),
453                    Some(prev) if prev != r => return Err(ExecutionError::MultiGraphUnsupported),
454                    _ => {}
455                }
456            }
457        }
458
459        let mut outputs: Vec<Option<RowSet>> = vec![None; plan.ops.len()];
460        for (idx, op) in plan.ops.iter().enumerate() {
461            let rows = match op {
462                Op::ScanNodes {
463                    labels,
464                    must_labels,
465                    forbidden_labels,
466                    ..
467                } => self
468                    .base
469                    .execute_scan_nodes(labels, must_labels, forbidden_labels),
470                Op::ScanRels {
471                    types,
472                    src_labels,
473                    dst_labels,
474                    ..
475                } => self.base.execute_scan_rels(types, src_labels, dst_labels),
476                Op::Expand {
477                    input,
478                    src_col,
479                    types,
480                    dir,
481                    legal_src_labels,
482                    legal_dst_labels,
483                    ..
484                } => self.base.execute_expand(
485                    self.input_rows(&outputs, *input)?,
486                    *src_col,
487                    types,
488                    *dir,
489                    legal_src_labels,
490                    legal_dst_labels,
491                )?,
492                Op::OptionalExpand {
493                    input,
494                    src_col,
495                    types,
496                    dir,
497                    legal_src_labels,
498                    legal_dst_labels,
499                    ..
500                } => self.base.execute_optional_expand(
501                    self.input_rows(&outputs, *input)?,
502                    *src_col,
503                    types,
504                    *dir,
505                    legal_src_labels,
506                    legal_dst_labels,
507                )?,
508                Op::SemiExpand {
509                    input,
510                    src_col,
511                    types,
512                    dir,
513                    legal_src_labels,
514                    legal_dst_labels,
515                    ..
516                } => self.base.execute_semi_expand(
517                    self.input_rows(&outputs, *input)?,
518                    *src_col,
519                    types,
520                    *dir,
521                    legal_src_labels,
522                    legal_dst_labels,
523                )?,
524                Op::ExpandVarLen {
525                    input,
526                    src_col,
527                    types,
528                    dir,
529                    min_hops,
530                    max_hops,
531                    ..
532                } => self.base.execute_expand_var_len(
533                    self.input_rows(&outputs, *input)?,
534                    *src_col,
535                    types,
536                    *dir,
537                    *min_hops,
538                    *max_hops,
539                )?,
540                Op::Filter { input, predicate } => {
541                    self.execute_filter_rows(self.input_rows(&outputs, *input)?, predicate)?
542                }
543                Op::BlockMarker { input, .. } => self.input_rows(&outputs, *input)?.clone(),
544                Op::Project { input, exprs, .. } => {
545                    self.execute_project_rows(self.input_rows(&outputs, *input)?, exprs)?
546                }
547                Op::Aggregate {
548                    input, keys, aggs, ..
549                } => self.execute_aggregate_rows(self.input_rows(&outputs, *input)?, keys, aggs)?,
550                Op::Sort { input, keys, dirs } => {
551                    self.base
552                        .execute_sort_rows(self.input_rows(&outputs, *input)?, keys, dirs)?
553                }
554                Op::Limit { input, count, skip, .. } => {
555                    self.base
556                        .execute_limit_rows(self.input_rows(&outputs, *input)?, *count, *skip)
557                }
558                Op::Return { input } => self.input_rows(&outputs, *input)?.clone(),
559                Op::Unwind {
560                    input, list_expr, ..
561                } => self.execute_unwind(self.input_rows(&outputs, *input)?, list_expr)?,
562                Op::PathConstruct {
563                    input, rel_cols, ..
564                } => self
565                    .base
566                    .execute_path_construct(self.input_rows(&outputs, *input)?, rel_cols)?,
567                Op::Union { lhs, rhs, all, .. } => self.base.execute_union_rows(
568                    self.input_rows(&outputs, *lhs)?,
569                    self.input_rows(&outputs, *rhs)?,
570                    *all,
571                ),
572                Op::VectorScan {
573                    input,
574                    collection,
575                    query_vector,
576                    metric,
577                    top_k,
578                    ..
579                } => self.execute_vector_scan(
580                    self.input_rows(&outputs, *input)?,
581                    collection,
582                    query_vector,
583                    *metric,
584                    *top_k,
585                )?,
586                Op::Rerank {
587                    input,
588                    score_expr,
589                    top_k,
590                    ..
591                } => self.execute_rerank(self.input_rows(&outputs, *input)?, score_expr, *top_k)?,
592                Op::CreateNode { .. }
593                | Op::CreateRel { .. }
594                | Op::Merge { .. }
595                | Op::Delete { .. }
596                | Op::SetProperty { .. }
597                | Op::RemoveProperty { .. } => {
598                    return Err(ExecutionError::UnsupportedOp("dml in mock vector engine"));
599                }
600                Op::ConstRow => vec![vec![]],
601            };
602            outputs[idx] = Some(rows);
603        }
604
605        let root_rows = outputs
606            .get(plan.root_op as usize)
607            .ok_or(ExecutionError::InvalidRootOp(plan.root_op))?
608            .clone()
609            .ok_or(ExecutionError::InvalidRootOp(plan.root_op))?;
610        Ok(QueryResult {
611            rows: root_rows,
612            continuation: None,
613        })
614    }
615}
616
617fn reduce_min_max_vector(
618    engine: &MockVectorEngine,
619    rows: &[Row],
620    expr: Option<&Expr>,
621    is_min: bool,
622) -> Result<Value, ExecutionError> {
623    let Some(e) = expr else {
624        return Ok(Value::Null);
625    };
626    let mut best: Option<Value> = None;
627    for row in rows {
628        let v = engine.eval_expr(row, e)?;
629        if matches!(v, Value::Null) {
630            continue;
631        }
632        match &best {
633            None => best = Some(v),
634            Some(b) => {
635                if let Some(ord) = cmp_ordering(&v, b) {
636                    if (is_min && ord == Ordering::Less) || (!is_min && ord == Ordering::Greater) {
637                        best = Some(v);
638                    }
639                }
640            }
641        }
642    }
643    Ok(best.unwrap_or(Value::Null))
644}
645
646fn to_numeric_vec(v: &Value) -> Result<Vec<f64>, ExecutionError> {
647    match v {
648        Value::List(items) => items
649            .iter()
650            .map(|item| match item {
651                Value::Int(x) => Ok(*x as f64),
652                Value::Float(x) => Ok(*x),
653                _ => Err(ExecutionError::ExpectedNumeric),
654            })
655            .collect(),
656        _ => Err(ExecutionError::ExpectedNumeric),
657    }
658}
659
660fn to_value_list(values: &[f64]) -> Vec<Value> {
661    values.iter().copied().map(Value::Float).collect()
662}
663
664fn vector_similarity(
665    metric: VectorMetric,
666    lhs: &Value,
667    rhs: &Value,
668) -> Result<f64, ExecutionError> {
669    let lhs = to_numeric_vec(lhs)?;
670    let rhs = to_numeric_vec(rhs)?;
671    if lhs.len() != rhs.len() {
672        return Err(ExecutionError::ExpectedNumeric);
673    }
674    Ok(match metric {
675        VectorMetric::DotProduct => lhs.iter().zip(&rhs).map(|(a, b)| a * b).sum(),
676        VectorMetric::L2 => lhs
677            .iter()
678            .zip(&rhs)
679            .map(|(a, b)| {
680                let d = a - b;
681                d * d
682            })
683            .sum::<f64>()
684            .sqrt(),
685        VectorMetric::Cosine => {
686            let dot: f64 = lhs.iter().zip(&rhs).map(|(a, b)| a * b).sum();
687            let lhs_norm: f64 = lhs.iter().map(|x| x * x).sum::<f64>().sqrt();
688            let rhs_norm: f64 = rhs.iter().map(|x| x * x).sum::<f64>().sqrt();
689            if lhs_norm == 0.0 || rhs_norm == 0.0 {
690                0.0
691            } else {
692                dot / (lhs_norm * rhs_norm)
693            }
694        }
695    })
696}
697
698fn eval_arith(op: ArithOp, lhs: &Value, rhs: &Value) -> Result<Value, ExecutionError> {
699    use ArithOp::{Add, Div, Mul, Sub};
700    match (lhs, rhs) {
701        (Value::Int(a), Value::Int(b)) => match op {
702            Add => Ok(Value::Int(a + b)),
703            Sub => Ok(Value::Int(a - b)),
704            Mul => Ok(Value::Int(a * b)),
705            Div => Ok(Value::Float(*a as f64 / *b as f64)),
706        },
707        (Value::Int(a), Value::Float(b)) => Ok(Value::Float(eval_arith_f64(op, *a as f64, *b))),
708        (Value::Float(a), Value::Int(b)) => Ok(Value::Float(eval_arith_f64(op, *a, *b as f64))),
709        (Value::Float(a), Value::Float(b)) => Ok(Value::Float(eval_arith_f64(op, *a, *b))),
710        _ => Err(ExecutionError::ExpectedNumeric),
711    }
712}
713
714fn eval_arith_f64(op: ArithOp, lhs: f64, rhs: f64) -> f64 {
715    use ArithOp::{Add, Div, Mul, Sub};
716    match op {
717        Add => lhs + rhs,
718        Sub => lhs - rhs,
719        Mul => lhs * rhs,
720        Div => lhs / rhs,
721    }
722}