iridium-db 0.2.0

A high-performance vector-graph hybrid storage and indexing engine
use std::collections::HashSet;

use plexus_serde::{AggFn, Expr};

use super::super::super::{ExplainError, Result, Row};
use super::types::{AggregateField, AggregateMetric, AggregateSpec, AggregateState};

/// Validates the `keys`/`aggs` pair of a serialized Aggregate operator and
/// lowers it into an [`AggregateSpec`].
///
/// Accepted shapes:
/// - `keys = [0]` with no aggregate expressions → [`AggregateSpec::GroupByNodeIdDistinct`];
/// - `keys = [0]` with exactly one aggregate → [`AggregateSpec::GroupByNodeIdMetric`];
/// - `keys = []` with exactly one aggregate → [`AggregateSpec::GlobalMetric`].
///
/// # Errors
/// Returns [`ExplainError::UnsupportedSerializedOperator`] when `keys` is
/// anything other than `[]` or `[0]`, when the number of aggregate
/// expressions is not one (outside the distinct case above), or when the
/// single aggregate expression is rejected by `parse_supported_aggregate_metric`.
pub(super) fn validate_supported_aggregate(keys: &[u32], aggs: &[Expr]) -> Result<AggregateSpec> {
    // Only "no grouping" or "group by key 0" is executable; key 0 maps onto
    // the node-id grouped spec variants.
    let group_by_node = match keys {
        [] => false,
        [0] => true,
        _ => {
            return Err(ExplainError::UnsupportedSerializedOperator(
                "Aggregate currently supports only keys=[] or keys=[0]".to_string(),
            ));
        }
    };

    match aggs {
        // Grouping with no aggregate expressions degenerates to DISTINCT node_id.
        [] if group_by_node => Ok(AggregateSpec::GroupByNodeIdDistinct),
        [only] => {
            let metric = parse_supported_aggregate_metric(only)?;
            Ok(if group_by_node {
                AggregateSpec::GroupByNodeIdMetric(metric)
            } else {
                AggregateSpec::GlobalMetric(metric)
            })
        }
        _ => Err(ExplainError::UnsupportedSerializedOperator(
            "Aggregate currently supports exactly one aggregate expression".to_string(),
        )),
    }
}

fn parse_supported_aggregate_metric(expr: &Expr) -> Result<AggregateMetric> {
    let Expr::Agg { fn_, expr } = expr else {
        return Err(ExplainError::UnsupportedSerializedOperator(
            "Aggregate expression must be Agg".to_string(),
        ));
    };

    match fn_ {
        AggFn::CountStar => {
            if expr.is_some() {
                return Err(ExplainError::UnsupportedSerializedOperator(
                    "COUNT(*) must not have an argument expression".to_string(),
                ));
            }
            Ok(AggregateMetric::CountStar)
        }
        AggFn::Count => Ok(AggregateMetric::CountField(parse_aggregate_field_expr(
            expr.as_deref(),
        )?)),
        AggFn::Sum => Ok(AggregateMetric::SumField(parse_aggregate_field_expr(
            expr.as_deref(),
        )?)),
        AggFn::Avg => Ok(AggregateMetric::AvgField(parse_aggregate_field_expr(
            expr.as_deref(),
        )?)),
        AggFn::Min => Ok(AggregateMetric::MinField(parse_aggregate_field_expr(
            expr.as_deref(),
        )?)),
        AggFn::Max => Ok(AggregateMetric::MaxField(parse_aggregate_field_expr(
            expr.as_deref(),
        )?)),
        AggFn::Collect => Err(ExplainError::UnsupportedSerializedOperator(
            "Aggregate Collect is not supported".to_string(),
        )),
    }
}

/// Resolves the argument of a field aggregate into the [`AggregateField`] it reads.
///
/// Only `PropAccess` on column 0 is accepted. The property name is then
/// mapped by `parse_supported_row_field`.
///
/// # Errors
/// Returns [`ExplainError::UnsupportedSerializedOperator`] when the argument
/// is missing or is not a `PropAccess`, when it targets a column other than 0,
/// or when the property name is not a known row field.
fn parse_aggregate_field_expr(expr: Option<&Expr>) -> Result<AggregateField> {
    match expr {
        Some(Expr::PropAccess { col: 0, prop }) => parse_supported_row_field(prop),
        Some(Expr::PropAccess { .. }) => Err(ExplainError::UnsupportedSerializedOperator(
            "Aggregate metric currently supports only PropAccess col=0".to_string(),
        )),
        _ => Err(ExplainError::UnsupportedSerializedOperator(
            "Aggregate metric currently requires PropAccess(col=0, field) argument".to_string(),
        )),
    }
}

/// Maps a serialized row-field name onto its [`AggregateField`].
///
/// Recognised names: `node_id`, `adjacency_degree`, `delta_count`,
/// `has_full`, `score`, `aggregate_value`.
///
/// # Errors
/// Returns [`ExplainError::UnsupportedSerializedOperator`], quoting `prop`,
/// for any other name.
pub(super) fn parse_supported_row_field(prop: &str) -> Result<AggregateField> {
    let field = match prop {
        "node_id" => AggregateField::NodeId,
        "adjacency_degree" => AggregateField::AdjacencyDegree,
        "delta_count" => AggregateField::DeltaCount,
        "has_full" => AggregateField::HasFull,
        "score" => AggregateField::Score,
        "aggregate_value" => AggregateField::AggregateValue,
        unknown => {
            return Err(ExplainError::UnsupportedSerializedOperator(format!(
                "unsupported row field '{unknown}'"
            )));
        }
    };
    Ok(field)
}

/// Executes a validated [`AggregateSpec`] over `rows`.
///
/// - `GroupByNodeIdDistinct`: keeps the first row seen for each `node_id`,
///   preserving input order.
/// - `GroupByNodeIdMetric`: folds the metric per `node_id` and emits one row
///   per group in ascending `node_id` order. Each output row is the group's
///   first input row with `aggregate_value` overwritten by the result.
/// - `GlobalMetric`: folds every row into one synthetic row (`node_id = 0`,
///   other fields zeroed/empty) carrying the result in `aggregate_value`.
///   That row is emitted even when `rows` is empty.
pub(super) fn apply_aggregate(rows: Vec<Row>, aggregate: AggregateSpec) -> Vec<Row> {
    match aggregate {
        AggregateSpec::GroupByNodeIdDistinct => {
            let mut seen = HashSet::new();
            rows.into_iter()
                .filter(|row| seen.insert(row.node_id))
                .collect()
        }
        AggregateSpec::GroupByNodeIdMetric(metric) => {
            // Keyed by node_id in a BTreeMap so groups come out in ascending
            // node_id order. Each entry pairs the running state with the first
            // row seen for that node, which becomes the output row.
            let mut groups = std::collections::BTreeMap::<u64, (AggregateState, Row)>::new();
            for row in rows {
                let input = aggregate_metric_input(&row, metric);
                let (state, _) = groups
                    .entry(row.node_id)
                    .or_insert_with(|| (AggregateState::default(), row));
                state.apply(metric, input);
            }
            groups
                .into_values()
                .map(|(state, mut row)| {
                    row.aggregate_value = finalize_aggregate(metric, &state);
                    row
                })
                .collect()
        }
        AggregateSpec::GlobalMetric(metric) => {
            let state = rows.iter().fold(AggregateState::default(), |mut acc, row| {
                acc.apply(metric, aggregate_metric_input(row, metric));
                acc
            });
            vec![Row {
                node_id: 0,
                has_full: false,
                delta_count: 0,
                adjacency_degree: 0,
                score: None,
                aggregate_value: finalize_aggregate(metric, &state),
            }]
        }
    }
}

impl AggregateState {
    /// Folds one row's metric input into the running totals.
    ///
    /// `rows_seen` grows on every call. A present `value` also bumps
    /// `values_seen`, adds to `sum`, and tightens `min`/`max`; `None` leaves
    /// those four untouched. For `COUNT(*)` every row counts, so
    /// `values_seen` is forced to match `rows_seen`.
    pub(super) fn apply(&mut self, metric: AggregateMetric, value: Option<f64>) {
        self.rows_seen += 1;
        if let Some(sample) = value {
            self.values_seen += 1;
            self.sum += sample;
            self.min = Some(match self.min {
                Some(lowest) => lowest.min(sample),
                None => sample,
            });
            self.max = Some(match self.max {
                Some(highest) => highest.max(sample),
                None => sample,
            });
        }
        if let AggregateMetric::CountStar = metric {
            self.values_seen = self.rows_seen;
        }
    }
}

/// Returns the numeric input that `row` contributes to `metric`.
///
/// `COUNT(*)` contributes a constant `1.0` per row. Every field-based metric
/// reads its referenced field through `aggregate_field_value`, which yields
/// `None` when that field is absent on the row.
fn aggregate_metric_input(row: &Row, metric: AggregateMetric) -> Option<f64> {
    let field = match metric {
        AggregateMetric::CountStar => return Some(1.0),
        AggregateMetric::CountField(f)
        | AggregateMetric::SumField(f)
        | AggregateMetric::AvgField(f)
        | AggregateMetric::MinField(f)
        | AggregateMetric::MaxField(f) => f,
    };
    aggregate_field_value(row, field)
}

/// Reads `field` from `row` as an `f64`.
///
/// `score` and `aggregate_value` are passed through as-is and may be `None`.
/// The integer fields are always present and converted with `as f64`.
/// NOTE: `node_id as f64` is exact only up to 2^53. `has_full` maps to
/// `1.0`/`0.0`.
fn aggregate_field_value(row: &Row, field: AggregateField) -> Option<f64> {
    let numeric = match field {
        AggregateField::Score => return row.score,
        AggregateField::AggregateValue => return row.aggregate_value,
        AggregateField::NodeId => row.node_id as f64,
        AggregateField::AdjacencyDegree => row.adjacency_degree as f64,
        AggregateField::DeltaCount => row.delta_count as f64,
        AggregateField::HasFull => f64::from(u8::from(row.has_full)),
    };
    Some(numeric)
}

/// Turns accumulated `state` into the final value of `metric`.
///
/// Counts always produce a value (`0.0` for no input). `SUM` and `AVG`
/// produce `None` when no present value was seen. `MIN`/`MAX` return the
/// tracked extremum, which is `None` when nothing was seen.
fn finalize_aggregate(metric: AggregateMetric, state: &AggregateState) -> Option<f64> {
    let has_values = state.values_seen != 0;
    match metric {
        AggregateMetric::CountStar => Some(state.rows_seen as f64),
        AggregateMetric::CountField(_) => Some(state.values_seen as f64),
        AggregateMetric::SumField(_) => has_values.then_some(state.sum),
        AggregateMetric::AvgField(_) => {
            has_values.then(|| state.sum / state.values_seen as f64)
        }
        AggregateMetric::MinField(_) => state.min,
        AggregateMetric::MaxField(_) => state.max,
    }
}