use std::collections::HashSet;
use plexus_serde::{AggFn, Expr};
use super::super::super::{ExplainError, Result, Row};
use super::types::{AggregateField, AggregateMetric, AggregateSpec, AggregateState};
/// Checks that a serialized aggregate has a supported shape and converts it
/// into an [`AggregateSpec`].
///
/// Accepted shapes:
/// - `keys == [0]` and no aggregate expressions -> `GroupByNodeIdDistinct`
/// - `keys == [0]` and exactly one aggregate expression -> `GroupByNodeIdMetric`
/// - `keys == []` and exactly one aggregate expression -> `GlobalMetric`
///
/// # Errors
/// Returns `ExplainError::UnsupportedSerializedOperator` when `keys` is
/// neither `[]` nor `[0]`, when the number of aggregate expressions is not
/// supported for the given keys, or when the single expression cannot be
/// parsed as a supported metric.
pub(super) fn validate_supported_aggregate(keys: &[u32], aggs: &[Expr]) -> Result<AggregateSpec> {
    // The key list is validated first so that an unsupported key list is
    // reported before any problem with the aggregate expressions.
    let grouped_by_node = match keys {
        [] => false,
        [0] => true,
        _ => {
            return Err(ExplainError::UnsupportedSerializedOperator(
                "Aggregate currently supports only keys=[] or keys=[0]".to_string(),
            ));
        }
    };

    match aggs {
        // Grouping without any metric only deduplicates rows.
        [] if grouped_by_node => Ok(AggregateSpec::GroupByNodeIdDistinct),
        [only] => {
            let metric = parse_supported_aggregate_metric(only)?;
            Ok(if grouped_by_node {
                AggregateSpec::GroupByNodeIdMetric(metric)
            } else {
                AggregateSpec::GlobalMetric(metric)
            })
        }
        _ => Err(ExplainError::UnsupportedSerializedOperator(
            "Aggregate currently supports exactly one aggregate expression".to_string(),
        )),
    }
}
fn parse_supported_aggregate_metric(expr: &Expr) -> Result<AggregateMetric> {
let Expr::Agg { fn_, expr } = expr else {
return Err(ExplainError::UnsupportedSerializedOperator(
"Aggregate expression must be Agg".to_string(),
));
};
match fn_ {
AggFn::CountStar => {
if expr.is_some() {
return Err(ExplainError::UnsupportedSerializedOperator(
"COUNT(*) must not have an argument expression".to_string(),
));
}
Ok(AggregateMetric::CountStar)
}
AggFn::Count => Ok(AggregateMetric::CountField(parse_aggregate_field_expr(
expr.as_deref(),
)?)),
AggFn::Sum => Ok(AggregateMetric::SumField(parse_aggregate_field_expr(
expr.as_deref(),
)?)),
AggFn::Avg => Ok(AggregateMetric::AvgField(parse_aggregate_field_expr(
expr.as_deref(),
)?)),
AggFn::Min => Ok(AggregateMetric::MinField(parse_aggregate_field_expr(
expr.as_deref(),
)?)),
AggFn::Max => Ok(AggregateMetric::MaxField(parse_aggregate_field_expr(
expr.as_deref(),
)?)),
AggFn::Collect => Err(ExplainError::UnsupportedSerializedOperator(
"Aggregate Collect is not supported".to_string(),
)),
}
}
/// Resolves the argument of a field-based aggregate to an [`AggregateField`].
///
/// The argument must be present and must be `Expr::PropAccess` with `col`
/// equal to `0`; the property name is then resolved by
/// `parse_supported_row_field`.
///
/// # Errors
/// Returns `ExplainError::UnsupportedSerializedOperator` when the argument is
/// missing or not a `PropAccess`, when `col` is not `0`, or when the property
/// name is not a supported row field.
fn parse_aggregate_field_expr(expr: Option<&Expr>) -> Result<AggregateField> {
    match expr {
        Some(Expr::PropAccess { col: 0, prop }) => parse_supported_row_field(prop),
        // A property access on any other column.
        Some(Expr::PropAccess { .. }) => Err(ExplainError::UnsupportedSerializedOperator(
            "Aggregate metric currently supports only PropAccess col=0".to_string(),
        )),
        // No argument at all, or an argument that is not a property access.
        _ => Err(ExplainError::UnsupportedSerializedOperator(
            "Aggregate metric currently requires PropAccess(col=0, field) argument".to_string(),
        )),
    }
}
/// Maps a row property name to its [`AggregateField`].
///
/// Names are matched exactly (case-sensitive).
///
/// # Errors
/// Returns `ExplainError::UnsupportedSerializedOperator` naming the property
/// when it is not one of the supported row fields.
pub(super) fn parse_supported_row_field(prop: &str) -> Result<AggregateField> {
    // Supported property names paired with the field each one selects.
    const ROW_FIELDS: [(&str, AggregateField); 6] = [
        ("node_id", AggregateField::NodeId),
        ("adjacency_degree", AggregateField::AdjacencyDegree),
        ("delta_count", AggregateField::DeltaCount),
        ("has_full", AggregateField::HasFull),
        ("score", AggregateField::Score),
        ("aggregate_value", AggregateField::AggregateValue),
    ];

    ROW_FIELDS
        .iter()
        .find(|(name, _)| *name == prop)
        .map(|&(_, field)| field)
        .ok_or_else(|| {
            ExplainError::UnsupportedSerializedOperator(format!(
                "unsupported row field '{}'",
                prop
            ))
        })
}
/// Executes a validated aggregate over `rows` and returns the resulting rows.
///
/// - `GroupByNodeIdDistinct`: keeps the first row seen for each `node_id`,
///   preserving input order.
/// - `GroupByNodeIdMetric`: returns one row per `node_id`, in ascending
///   `node_id` order. Each output row is the first input row of its group with
///   `aggregate_value` replaced by the finalized metric.
/// - `GlobalMetric`: returns exactly one synthetic row (even for empty input)
///   whose `aggregate_value` holds the finalized metric; its other fields are
///   zero / `false` / `None`.
pub(super) fn apply_aggregate(rows: Vec<Row>, aggregate: AggregateSpec) -> Vec<Row> {
    match aggregate {
        AggregateSpec::GroupByNodeIdDistinct => {
            let mut seen = HashSet::new();
            // `insert` returns true only the first time an id is observed.
            rows.into_iter()
                .filter(|row| seen.insert(row.node_id))
                .collect()
        }
        AggregateSpec::GroupByNodeIdMetric(metric) => {
            // One ordered map carries both the representative (first) row and
            // the running state for each group; iteration is by ascending id.
            let mut groups = std::collections::BTreeMap::<u64, (Row, AggregateState)>::new();
            for row in rows {
                // Read the metric input before the row is possibly moved into the map.
                let sample = aggregate_metric_input(&row, metric);
                let (_, state) = groups
                    .entry(row.node_id)
                    .or_insert_with(|| (row, AggregateState::default()));
                state.apply(metric, sample);
            }
            groups
                .into_values()
                .map(|(mut representative, state)| {
                    representative.aggregate_value = finalize_aggregate(metric, &state);
                    representative
                })
                .collect()
        }
        AggregateSpec::GlobalMetric(metric) => {
            let totals = rows
                .iter()
                .fold(AggregateState::default(), |mut running, row| {
                    running.apply(metric, aggregate_metric_input(row, metric));
                    running
                });
            let summary = Row {
                node_id: 0,
                has_full: false,
                delta_count: 0,
                adjacency_degree: 0,
                score: None,
                aggregate_value: finalize_aggregate(metric, &totals),
            };
            vec![summary]
        }
    }
}
impl AggregateState {
    /// Folds one input row into the running aggregate state.
    ///
    /// Every call counts as one row seen. When `value` is `Some`, it also
    /// counts as one value seen and updates the running sum, minimum and
    /// maximum. For `AggregateMetric::CountStar` the value count is kept equal
    /// to the row count.
    pub(super) fn apply(&mut self, metric: AggregateMetric, value: Option<f64>) {
        self.rows_seen += 1;

        if let Some(sample) = value {
            self.values_seen += 1;
            self.sum += sample;
            self.min = Some(match self.min {
                Some(lowest) => lowest.min(sample),
                None => sample,
            });
            self.max = Some(match self.max {
                Some(highest) => highest.max(sample),
                None => sample,
            });
        }

        if let AggregateMetric::CountStar = metric {
            self.values_seen = self.rows_seen;
        }
    }
}
/// Produces the value that `row` contributes to `metric`.
///
/// `CountStar` always contributes `Some(1.0)`; every field-based metric
/// contributes whatever `aggregate_field_value` yields for its field, which
/// may be `None`.
fn aggregate_metric_input(row: &Row, metric: AggregateMetric) -> Option<f64> {
    let field = match metric {
        AggregateMetric::CountStar => return Some(1.0),
        AggregateMetric::CountField(field)
        | AggregateMetric::SumField(field)
        | AggregateMetric::AvgField(field)
        | AggregateMetric::MinField(field)
        | AggregateMetric::MaxField(field) => field,
    };
    aggregate_field_value(row, field)
}
/// Reads `field` from `row` as an `f64`.
///
/// `Score` and `AggregateValue` are returned as stored and may be `None`.
/// The remaining fields are always present: numeric fields are cast to `f64`
/// and `has_full` becomes `1.0` for `true` and `0.0` for `false`.
fn aggregate_field_value(row: &Row, field: AggregateField) -> Option<f64> {
    let present = match field {
        // Optional fields are handed back untouched.
        AggregateField::Score => return row.score,
        AggregateField::AggregateValue => return row.aggregate_value,
        AggregateField::NodeId => row.node_id as f64,
        AggregateField::AdjacencyDegree => row.adjacency_degree as f64,
        AggregateField::DeltaCount => row.delta_count as f64,
        AggregateField::HasFull => f64::from(u8::from(row.has_full)),
    };
    Some(present)
}
/// Turns an accumulated [`AggregateState`] into the final value for `metric`.
///
/// Counts are always `Some`. Sum and average are `None` when no values were
/// seen. Minimum and maximum are returned as stored in the state.
fn finalize_aggregate(metric: AggregateMetric, state: &AggregateState) -> Option<f64> {
    let has_values = state.values_seen != 0;
    match metric {
        AggregateMetric::CountStar => Some(state.rows_seen as f64),
        AggregateMetric::CountField(_) => Some(state.values_seen as f64),
        AggregateMetric::SumField(_) => has_values.then_some(state.sum),
        AggregateMetric::AvgField(_) => {
            has_values.then(|| state.sum / state.values_seen as f64)
        }
        AggregateMetric::MinField(_) => state.min,
        AggregateMetric::MaxField(_) => state.max,
    }
}