use super::super::super::ast::{is_aggregate_expression, Expression, ReturnClause};
use super::super::super::result::{Bindings, ResultRow};
use super::super::helpers::return_item_column_name;
use super::super::CypherExecutor;
use super::RowStream;
use crate::datatypes::values::Value;
use petgraph::graph::NodeIndex;
use rustc_hash::FxHashMap;
use std::collections::{HashMap, HashSet};
/// One component of a surrogate group key built for each upstream row.
///
/// Keying on a node index lets rows be bucketed without reading the grouped property
/// on every row; such parts are turned into real values once per group afterwards.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
enum GroupKeyPart {
    /// `var.prop` where `var` is bound to this node; the property is resolved later.
    NodeProp(NodeIndex),
    /// A key value that was evaluated directly for the row.
    Resolved(Value),
}
/// How one grouping expression is converted into a [`GroupKeyPart`] for each row.
enum GroupExprStrategy {
    /// `variable.prop`: key on the index of the node bound to `variable` when there is
    /// one (deferring the property read); otherwise fall back to evaluation.
    NodeProp { variable: String },
    /// Evaluate the expression for every row.
    Eval,
}
impl GroupExprStrategy {
    /// Chooses the strategy for `expr`: property accesses get
    /// [`GroupExprStrategy::NodeProp`], anything else [`GroupExprStrategy::Eval`].
    fn for_expr(expr: &Expression) -> Self {
        match expr {
            Expression::PropertyAccess { variable, .. } => Self::NodeProp {
                variable: variable.clone(),
            },
            _ => Self::Eval,
        }
    }
}
/// The aggregate functions handled by this module.
#[derive(Copy, Clone)]
enum AggKind {
    /// `count(*)`: counts rows.
    CountStar,
    /// `count(expr)`: counts non-null values.
    Count,
    /// `sum(expr)`.
    Sum,
    /// `avg(expr)` (also accepted as `mean` / `average`).
    Avg,
    /// `min(expr)`.
    Min,
    /// `max(expr)`.
    Max,
}
/// A compiled aggregate `RETURN` item.
pub(crate) struct AggSpec {
    /// Which aggregate function to compute.
    kind: AggKind,
    /// The argument expression; `None` for `count(*)`.
    arg: Option<Expression>,
    /// `true` for `DISTINCT` aggregates.
    distinct: bool,
    /// Name of the argument when it is a bare variable, so `DISTINCT` can dedupe on
    /// node identity when that variable is bound to a node.
    arg_is_node_var: Option<String>,
    /// Edge counterpart of `arg_is_node_var`; `compile_agg` currently never sets it.
    arg_is_edge_var: Option<String>,
}
struct AggState {
count: i64,
sum: f64,
sum_was_int: bool,
sum_seen_value: bool,
min: Option<Value>,
max: Option<Value>,
distinct_nodes: Option<HashSet<usize>>,
distinct_edges: Option<HashSet<usize>>,
distinct_values: Option<HashSet<Value>>,
}
impl AggState {
fn new(spec: &AggSpec) -> Self {
let (distinct_nodes, distinct_edges, distinct_values) = if spec.distinct {
(
Some(HashSet::new()),
Some(HashSet::new()),
Some(HashSet::new()),
)
} else {
(None, None, None)
};
AggState {
count: 0,
sum: 0.0,
sum_was_int: true,
sum_seen_value: false,
min: None,
max: None,
distinct_nodes,
distinct_edges,
distinct_values,
}
}
fn record(&mut self, value: Option<Value>, spec: &AggSpec) {
let val = match value {
Some(v) if !matches!(v, Value::Null) => v,
None if matches!(spec.kind, AggKind::CountStar) => {
self.count += 1;
return;
}
_ => return,
};
match spec.kind {
AggKind::CountStar | AggKind::Count => {
self.count += 1;
}
AggKind::Sum | AggKind::Avg => {
if let Some(f) = value_to_f64(&val) {
self.sum += f;
self.count += 1;
self.sum_seen_value = true;
if !matches!(val, Value::Int64(_)) {
self.sum_was_int = false;
}
}
}
AggKind::Min => {
self.min = Some(match self.min.take() {
None => val,
Some(current) => {
if cmp_lt(&val, ¤t) {
val
} else {
current
}
}
});
}
AggKind::Max => {
self.max = Some(match self.max.take() {
None => val,
Some(current) => {
if cmp_gt(&val, ¤t) {
val
} else {
current
}
}
});
}
}
}
fn merge(&mut self, other: AggState) {
self.count += other.count;
self.sum += other.sum;
if other.sum_seen_value {
self.sum_seen_value = true;
if !other.sum_was_int {
self.sum_was_int = false;
}
}
self.min = combine(self.min.take(), other.min, false);
self.max = combine(self.max.take(), other.max, true);
if let (Some(a), Some(b)) = (self.distinct_nodes.as_mut(), other.distinct_nodes) {
a.extend(b);
}
if let (Some(a), Some(b)) = (self.distinct_edges.as_mut(), other.distinct_edges) {
a.extend(b);
}
if let (Some(a), Some(b)) = (self.distinct_values.as_mut(), other.distinct_values) {
a.extend(b);
}
}
fn finalize(&self, spec: &AggSpec) -> Value {
match spec.kind {
AggKind::CountStar => Value::Int64(self.count),
AggKind::Count => {
if spec.distinct {
let n = self.distinct_nodes.as_ref().map(|s| s.len()).unwrap_or(0)
+ self.distinct_edges.as_ref().map(|s| s.len()).unwrap_or(0)
+ self.distinct_values.as_ref().map(|s| s.len()).unwrap_or(0);
Value::Int64(n as i64)
} else {
Value::Int64(self.count)
}
}
AggKind::Sum => {
if !self.sum_seen_value {
Value::Int64(0)
} else if self.sum_was_int && self.sum.fract() == 0.0 {
Value::Int64(self.sum as i64)
} else {
Value::Float64(self.sum)
}
}
AggKind::Avg => {
if self.count == 0 {
Value::Null
} else {
Value::Float64(self.sum / self.count as f64)
}
}
AggKind::Min => self.min.clone().unwrap_or(Value::Null),
AggKind::Max => self.max.clone().unwrap_or(Value::Null),
}
}
}
/// Per-group accumulator: one [`AggState`] per compiled aggregate, plus the node
/// bindings carried onto the group's output row.
struct GroupAcc {
    /// Aggregate states, index-aligned with the `specs` passed to [`GroupAcc::new`].
    states: Vec<AggState>,
    /// Node bindings of the grouping variables. `apply` fills these from the first row
    /// that created the group.
    first_node_bindings: Bindings<NodeIndex>,
}
impl GroupAcc {
    /// Builds an accumulator holding a fresh [`AggState`] for every spec in `specs`.
    fn new(specs: &[AggSpec]) -> Self {
        let mut states = Vec::with_capacity(specs.len());
        states.extend(specs.iter().map(AggState::new));
        Self {
            states,
            first_node_bindings: Bindings::new(),
        }
    }
}
/// Why [`try_compile_specs`] rejected a `RETURN` clause.
///
/// Presumably callers fall back to a general aggregation path on these (confirm at
/// call sites).
#[derive(Debug)]
pub enum AggregateBail {
    /// An aggregate item that is not a supported function call (name or arity).
    UnsupportedAggregate,
    /// A non-aggregate item other than a bare variable or a property access.
    UnsupportedItem,
}
/// `(grouping item indices, aggregate item indices, one spec per aggregate index)`.
pub(crate) type CompiledSpecs = (Vec<usize>, Vec<usize>, Vec<AggSpec>);
/// Splits a `RETURN` clause into grouping keys and compiled aggregates.
///
/// Returns `(group_indices, agg_indices, specs)`:
/// - `group_indices`: item positions of the grouping keys.
/// - `agg_indices`: item positions of the aggregates.
/// - `specs`: one [`AggSpec`] per aggregate, parallel to `agg_indices`.
///
/// # Errors
/// - [`AggregateBail::UnsupportedItem`] if a non-aggregate item is anything other than a
///   bare variable or property access.
/// - [`AggregateBail::UnsupportedAggregate`] if an aggregate cannot be compiled.
pub fn try_compile_specs(return_clause: &ReturnClause) -> Result<CompiledSpecs, AggregateBail> {
    let mut group_slots = Vec::new();
    let mut agg_slots = Vec::new();
    let mut agg_specs = Vec::new();
    for (slot, item) in return_clause.items.iter().enumerate() {
        let expr = &item.expression;
        if is_aggregate_expression(expr) {
            agg_specs.push(compile_agg(expr)?);
            agg_slots.push(slot);
            continue;
        }
        if !matches!(
            expr,
            Expression::Variable(_) | Expression::PropertyAccess { .. }
        ) {
            return Err(AggregateBail::UnsupportedItem);
        }
        group_slots.push(slot);
    }
    Ok((group_slots, agg_slots, agg_specs))
}
/// Compiles one aggregate `RETURN` expression into an [`AggSpec`].
///
/// Supported calls, each optionally `DISTINCT`:
/// - `count(*)` and `count(x)`
/// - `sum(x)`
/// - `avg(x)`, also spelled `mean` or `average`
/// - `min(x)` and `max(x)`
///
/// Function names are matched exactly (case-sensitively). A bare-variable argument is
/// recorded in `arg_is_node_var`.
///
/// # Errors
/// [`AggregateBail::UnsupportedAggregate`] in any of these cases:
/// - `expr` is not a function call;
/// - the function name is not listed above;
/// - a call other than `count(*)` does not have exactly one argument.
fn compile_agg(expr: &Expression) -> Result<AggSpec, AggregateBail> {
    let Expression::FunctionCall {
        name,
        args,
        distinct,
    } = expr
    else {
        return Err(AggregateBail::UnsupportedAggregate);
    };
    let star_only = args.len() == 1 && matches!(args[0], Expression::Star);
    let kind = match name.as_str() {
        "count" if star_only => AggKind::CountStar,
        "count" => AggKind::Count,
        "sum" => AggKind::Sum,
        "avg" | "mean" | "average" => AggKind::Avg,
        "min" => AggKind::Min,
        "max" => AggKind::Max,
        _ => return Err(AggregateBail::UnsupportedAggregate),
    };
    let arg = match (kind, &args[..]) {
        (AggKind::CountStar, _) => None,
        (_, [only]) => Some(only.clone()),
        _ => return Err(AggregateBail::UnsupportedAggregate),
    };
    let arg_is_node_var = match &arg {
        Some(Expression::Variable(var)) => Some(var.clone()),
        _ => None,
    };
    Ok(AggSpec {
        kind,
        arg,
        distinct: *distinct,
        arg_is_node_var,
        arg_is_edge_var: None,
    })
}
/// Runs the group-and-aggregate step for a `RETURN` clause.
///
/// `group_indices`, `agg_indices` and `specs` are the three parts produced by
/// [`try_compile_specs`] for the same `return_clause`. `upstream` is consumed fully, and
/// the aggregated rows are returned as an in-memory [`RowStream`].
///
/// Grouping happens in two passes:
/// 1. Each row gets a surrogate key. A `var.prop` grouping item whose `var` is bound to a
///    node is keyed by that node's index; every other grouping item is evaluated.
///    Evaluation errors become NULL keys.
/// 2. Node-index key parts are resolved through
///    `executor.resolve_node_prop_for_group` once per distinct (node, slot) pair.
///    Surrogate groups whose resolved keys coincide are then merged.
///
/// With no grouping items and no input rows, a single row of empty-aggregate values is
/// emitted. `RETURN DISTINCT` removes duplicate output rows at the end.
///
/// # Errors
/// Propagates upstream row errors and errors from `executor.check_deadline`.
pub fn apply<'q>(
    executor: &'q CypherExecutor<'q>,
    upstream: RowStream<'q>,
    return_clause: &ReturnClause,
    group_indices: &[usize],
    agg_indices: &[usize],
    specs: &[AggSpec],
) -> Result<RowStream<'q>, String> {
    // Output column names in RETURN-item order. Built once and indexed by item position
    // below, instead of being recomputed for every projected row.
    let columns: Vec<String> = return_clause
        .items
        .iter()
        .map(return_item_column_name)
        .collect();
    let folded_group_exprs: Vec<Expression> = group_indices
        .iter()
        .map(|&i| executor.fold_constants_expr(&return_clause.items[i].expression))
        .collect();
    let strategies: Vec<GroupExprStrategy> = folded_group_exprs
        .iter()
        .map(GroupExprStrategy::for_expr)
        .collect();
    let folded_args: Vec<Option<Expression>> = specs
        .iter()
        .map(|s| s.arg.as_ref().map(|e| executor.fold_constants_expr(e)))
        .collect();
    // Variable named by each grouping item, if any. Its node binding from a group's
    // first row is copied onto that group's output row.
    let group_var_names: Vec<Option<String>> = group_indices
        .iter()
        .map(|&i| match &return_clause.items[i].expression {
            Expression::Variable(v) => Some(v.clone()),
            Expression::PropertyAccess { variable, .. } => Some(variable.clone()),
            _ => None,
        })
        .collect();

    // Pass 1: bucket rows by surrogate key and feed each bucket's aggregates.
    let mut surrogate_groups: Vec<(Vec<GroupKeyPart>, GroupAcc)> = Vec::new();
    let mut surrogate_index: FxHashMap<Vec<GroupKeyPart>, usize> = FxHashMap::default();
    let mut row_count = 0u64;
    for row_result in upstream {
        let row = row_result?;
        row_count += 1;
        // Amortise the deadline check over batches of rows.
        if row_count.is_multiple_of(4096) {
            executor.check_deadline()?;
        }
        let key_parts: Vec<GroupKeyPart> = strategies
            .iter()
            .zip(folded_group_exprs.iter())
            .map(|(strategy, expr)| {
                // Key `var.prop` on the bound node itself. The property is read once per
                // distinct node in pass 2 instead of once per row.
                if let GroupExprStrategy::NodeProp { variable } = strategy {
                    if let Some(&idx) = row.node_bindings.get(variable) {
                        return GroupKeyPart::NodeProp(idx);
                    }
                }
                GroupKeyPart::Resolved(
                    executor
                        .evaluate_expression(expr, &row)
                        .unwrap_or(Value::Null),
                )
            })
            .collect();
        let group_idx = match surrogate_index.get(&key_parts) {
            Some(&idx) => idx,
            None => {
                let idx = surrogate_groups.len();
                surrogate_index.insert(key_parts.clone(), idx);
                let mut acc = GroupAcc::new(specs);
                for var_name in group_var_names.iter().flatten() {
                    if let Some(&node_idx) = row.node_bindings.get(var_name) {
                        acc.first_node_bindings.insert(var_name.clone(), node_idx);
                    }
                }
                surrogate_groups.push((key_parts, acc));
                idx
            }
        };
        let acc = &mut surrogate_groups[group_idx].1;
        for ((state, spec), arg) in acc.states.iter_mut().zip(specs).zip(&folded_args) {
            update_agg_state(state, spec, arg.as_ref(), &row, executor);
        }
    }

    // Pass 2a: resolve each distinct (node, grouping slot) key part once.
    let mut resolved_node_props: HashMap<(NodeIndex, usize), Value> = HashMap::new();
    for (key_parts, _) in &surrogate_groups {
        for (slot, part) in key_parts.iter().enumerate() {
            if let GroupKeyPart::NodeProp(idx) = part {
                resolved_node_props.entry((*idx, slot)).or_insert_with(|| {
                    executor.resolve_node_prop_for_group(*idx, &folded_group_exprs[slot])
                });
            }
        }
    }

    // Pass 2b: replace surrogate keys with resolved values. Surrogate groups whose
    // resolved keys coincide (e.g. two nodes sharing a property value) are merged.
    let mut groups: Vec<(Vec<Value>, GroupAcc)> = Vec::new();
    let mut group_index_map: FxHashMap<Vec<Value>, usize> = FxHashMap::default();
    for (key_parts, acc) in surrogate_groups {
        let resolved: Vec<Value> = key_parts
            .iter()
            .enumerate()
            .map(|(slot, part)| match part {
                GroupKeyPart::NodeProp(idx) => resolved_node_props
                    .get(&(*idx, slot))
                    .cloned()
                    .unwrap_or(Value::Null),
                GroupKeyPart::Resolved(v) => v.clone(),
            })
            .collect();
        match group_index_map.get(&resolved) {
            Some(&idx) => {
                // Empty placeholder (no aggregate states, no DISTINCT sets) while the
                // real accumulator is moved out and merged.
                let existing = std::mem::replace(&mut groups[idx].1, GroupAcc::new(&[]));
                groups[idx].1 = merge_group_accs(existing, acc);
            }
            None => {
                let idx = groups.len();
                group_index_map.insert(resolved.clone(), idx);
                groups.push((resolved, acc));
            }
        }
    }

    // Project one output row per group.
    let mut output_rows: Vec<ResultRow> = Vec::with_capacity(groups.len());
    for (resolved_keys, acc) in &groups {
        let mut projected = Bindings::with_capacity(return_clause.items.len());
        for (&item_idx, key_value) in group_indices.iter().zip(resolved_keys) {
            projected.insert(columns[item_idx].clone(), key_value.clone());
        }
        for ((&item_idx, spec), state) in agg_indices.iter().zip(specs).zip(&acc.states) {
            projected.insert(columns[item_idx].clone(), state.finalize(spec));
        }
        let mut row = ResultRow::from_projected(projected);
        for (k, v) in acc.first_node_bindings.iter() {
            row.node_bindings.insert(k.clone(), *v);
        }
        output_rows.push(row);
    }
    // A global aggregate over no rows still yields one row, e.g. count(*) = 0.
    if output_rows.is_empty() && group_indices.is_empty() {
        let mut projected = Bindings::with_capacity(return_clause.items.len());
        for (&item_idx, spec) in agg_indices.iter().zip(specs) {
            projected.insert(columns[item_idx].clone(), AggState::new(spec).finalize(spec));
        }
        output_rows.push(ResultRow::from_projected(projected));
    }
    if return_clause.distinct {
        let mut seen = HashSet::new();
        output_rows.retain(|row| {
            let key: Vec<Value> = columns
                .iter()
                .map(|c| row.projected.get(c).cloned().unwrap_or(Value::Null))
                .collect();
            seen.insert(key)
        });
    }
    Ok(RowStream::from_vec(output_rows, columns))
}
/// Feeds one upstream row into `state`, the accumulator for aggregate `spec`.
///
/// `folded_arg` is `spec.arg` after constant folding. NULL inputs are skipped by every
/// kind except `count(*)`, which counts the row unconditionally. Expressions that fail
/// to evaluate are treated as NULL.
///
/// For `DISTINCT`, each value is folded in only the first time it is seen:
/// - `count(DISTINCT v)`, where `v` is bound to a node (or edge), dedupes on graph
///   identity without evaluating `v`.
/// - Every other `DISTINCT` aggregate dedupes on the evaluated value. It therefore
///   aggregates the same values as its non-`DISTINCT` form, minus repeats.
fn update_agg_state(
    state: &mut AggState,
    spec: &AggSpec,
    folded_arg: Option<&Expression>,
    row: &ResultRow,
    executor: &CypherExecutor<'_>,
) {
    if matches!(spec.kind, AggKind::CountStar) {
        state.record(None, spec);
        return;
    }
    let Some(expr) = folded_arg else {
        return;
    };
    if !spec.distinct {
        let val = executor
            .evaluate_expression(expr, row)
            .unwrap_or(Value::Null);
        state.record(Some(val), spec);
        return;
    }
    // Identity shortcut: only meaningful for counting. Previously other kinds took this
    // branch and recorded nothing, so e.g. `max(DISTINCT n)` silently became NULL.
    if matches!(spec.kind, AggKind::Count) {
        if let Some(var_name) = &spec.arg_is_node_var {
            if let Some(&idx) = row.node_bindings.get(var_name) {
                let seen = state.distinct_nodes.get_or_insert_with(HashSet::new);
                if seen.insert(idx.index()) {
                    state.record(Some(Value::Boolean(true)), spec);
                }
                return;
            }
        }
        if let Some(var_name) = &spec.arg_is_edge_var {
            if let Some(eb) = row.edge_bindings.get(var_name) {
                let seen = state.distinct_edges.get_or_insert_with(HashSet::new);
                if seen.insert(eb.edge_index.index()) {
                    state.record(Some(Value::Boolean(true)), spec);
                }
                return;
            }
        }
    }
    let val = executor
        .evaluate_expression(expr, row)
        .unwrap_or(Value::Null);
    if matches!(val, Value::Null) {
        return;
    }
    let seen = state.distinct_values.get_or_insert_with(HashSet::new);
    if seen.insert(val.clone()) {
        state.record(Some(val), spec);
    }
}
fn merge_group_accs(mut a: GroupAcc, b: GroupAcc) -> GroupAcc {
debug_assert_eq!(a.states.len(), b.states.len());
let mut merged_states = Vec::with_capacity(a.states.len());
for (sa, sb) in a.states.drain(..).zip(b.states) {
let mut sa = sa;
sa.merge(sb);
merged_states.push(sa);
}
a.states = merged_states;
GroupAcc {
states: a.states,
first_node_bindings: a.first_node_bindings,
}
}
/// Numeric view of a value for `sum`/`avg`.
///
/// Integers, floats and unique ids convert directly; booleans map `true` to 1.0 and
/// `false` to 0.0. Any other variant yields `None`.
fn value_to_f64(v: &Value) -> Option<f64> {
    let numeric = match v {
        Value::Float64(f) => *f,
        Value::Int64(i) => *i as f64,
        Value::UniqueId(u) => *u as f64,
        Value::Boolean(true) => 1.0,
        Value::Boolean(false) => 0.0,
        _ => return None,
    };
    Some(numeric)
}
/// `true` iff `compare_values` orders `a` strictly before `b`; a `None` result is `false`.
fn cmp_lt(a: &Value, b: &Value) -> bool {
    crate::graph::core::filtering::compare_values(a, b) == Some(std::cmp::Ordering::Less)
}
/// `true` iff `compare_values` orders `a` strictly after `b`; a `None` result is `false`.
fn cmp_gt(a: &Value, b: &Value) -> bool {
    crate::graph::core::filtering::compare_values(a, b) == Some(std::cmp::Ordering::Greater)
}
fn combine(a: Option<Value>, b: Option<Value>, want_max: bool) -> Option<Value> {
match (a, b) {
(None, x) | (x, None) => x,
(Some(a), Some(b)) => Some(match (want_max, cmp_lt(&a, &b)) {
(true, true) => b,
(true, false) => a,
(false, true) => a,
(false, false) => b,
}),
}
}