// selene-db-gql 1.3.0
//
// ISO/IEC 39075:2024 GQL parser, planner, optimizer, and executor for selene-db.
use rustc_hash::FxHashSet;
use selene_core::Value;

use crate::{
    Aggregate, GqlStatus, SourceSpan,
    runtime::{
        Binding, BindingTableSchema, DataExceptionSubclass, EvalCtx, ExecutorError,
        ExecutorWarning, evaluator, value_compare, value_key::RuntimeEqKey,
    },
};

mod numeric;

use self::numeric::{
    NumericSum, Welford, add_numeric, avg_to_value, count_to_value, data_exception_value,
    percentile_cont_to_value, percentile_disc_to_value, percentile_numeric_to_f64,
    percentile_value, stddev_pop_to_value, stddev_samp_to_value,
};

/// Accumulator for a single planned aggregate call.
///
/// Pairs the planned [`Aggregate`] with its running [`AggregateState`] and, for
/// `DISTINCT` aggregates, the set of argument values already fed into that state.
pub(super) struct AggregateSlot<'plan> {
    /// The planned aggregate call (function name, arguments, `star`/`distinct`
    /// flags, span, output name).
    aggregate: &'plan Aggregate,
    /// Running accumulator chosen by `classify` when the slot was created.
    state: AggregateState,
    /// Equality keys of values already accumulated; only consulted when
    /// `aggregate.distinct` is set.
    seen: FxHashSet<RuntimeEqKey>,
}

impl<'plan> AggregateSlot<'plan> {
    /// Creates an empty slot for `aggregate`.
    ///
    /// # Errors
    ///
    /// Propagates the error `classify` returns for an unsupported function name,
    /// arity, or `*` form.
    pub(super) fn new(aggregate: &'plan Aggregate) -> Result<Self, ExecutorError> {
        let function = classify(aggregate)?;
        let state = AggregateState::new(function);
        Ok(Self {
            aggregate,
            state,
            seen: FxHashSet::default(),
        })
    }

    /// Feeds one input row into the aggregate.
    ///
    /// `COUNT(*)` counts the row without evaluating any expression. For every
    /// other function:
    /// 1. the PERCENTILE fraction (second argument) is evaluated if it has not
    ///    been captured yet;
    /// 2. the main argument (first argument) is evaluated;
    /// 3. a NULL is dropped, and the null-eliminated warning is emitted through
    ///    `emit_warning_once`, for functions whose state skips NULLs;
    /// 4. for `DISTINCT` aggregates, a value equal to one already seen is dropped;
    /// 5. the surviving value is handed to the accumulator.
    ///
    /// # Errors
    ///
    /// Returns expression-evaluation errors, an `ImplementationDefined` error if a
    /// required argument is absent, and any error raised by the accumulator.
    pub(super) fn observe(
        &mut self,
        row: &Binding,
        schema: &BindingTableSchema,
        ctx: &EvalCtx<'_, '_, '_, '_>,
    ) -> Result<(), ExecutorError> {
        let span = self.aggregate.span;

        if let AggregateState::CountStar { .. } = self.state {
            return self.state.observe(None, span);
        }

        // Only the first row that reaches this point evaluates the fraction;
        // `set_percentile` keeps that first result for all later rows.
        if self.state.needs_percentile() {
            let Some(fraction_arg) = self.aggregate.args.get(1) else {
                return Err(ExecutorError::ImplementationDefined {
                    detail: "PERCENTILE independent argument missing",
                });
            };
            let fraction = evaluator::evaluate(&fraction_arg.expr, row, schema, ctx)?;
            let percentile = percentile_value(fraction, span)?;
            self.state.set_percentile(percentile);
        }

        let Some(input_arg) = self.aggregate.args.first() else {
            return Err(ExecutorError::ImplementationDefined {
                detail: "aggregate argument missing",
            });
        };
        let value = evaluator::evaluate(&input_arg.expr, row, schema, ctx)?;

        let is_null = matches!(value, Value::Null);
        if is_null && self.state.skips_null() {
            ctx.tx.emit_warning_once(ExecutorWarning {
                code: GqlStatus::NULL_VALUE_ELIMINATED_IN_SET_FUNCTION,
                message: "null value eliminated in set function".to_owned(),
                span,
            });
            return Ok(());
        }

        if self.aggregate.distinct {
            let first_occurrence = self
                .seen
                .insert(RuntimeEqKey::from_row(vec![value.clone()]));
            if !first_occurrence {
                return Ok(());
            }
        }

        self.state.observe(Some(value), span)
    }

    /// Consumes the slot and returns its output row fragment (a single value).
    ///
    /// # Errors
    ///
    /// Propagates any error raised while finalizing the accumulator.
    pub(super) fn finalize_values(self) -> Result<Vec<Value>, ExecutorError> {
        self.state
            .finalize(self.aggregate.span)
            .map(|value| vec![value])
    }
}

/// Names of the output columns this aggregate contributes: exactly one, the
/// aggregate's `output_name`.
pub(super) fn output_names(aggregate: &Aggregate) -> Vec<selene_core::DbString> {
    std::iter::once(aggregate.output_name.clone()).collect()
}

/// Aggregate function kinds recognized by `classify`.
#[derive(Clone, Copy)]
enum AggregateFn {
    /// `count(expr)`.
    Count,
    /// `count(*)` without `DISTINCT`.
    CountStar,
    /// `sum(expr)`.
    Sum,
    /// `avg(expr)`.
    Avg,
    /// `stddev_pop(expr)`.
    StddevPop,
    /// `stddev_samp(expr)`.
    StddevSamp,
    /// `min(expr)`.
    Min,
    /// `max(expr)`.
    Max,
    /// `collect_list(expr)`.
    Collect,
    /// `percentile_cont(expr, fraction)`.
    PercentileCont,
    /// `percentile_disc(expr, fraction)`.
    PercentileDisc,
}

/// Running accumulator for one aggregate function.
///
/// Each variant carries only what its function needs between rows; `finalize`
/// turns it into the output [`Value`].
enum AggregateState {
    /// `count(expr)`: number of values handed to `observe`.
    Count {
        count: u64,
    },
    /// `count(*)`: number of rows handed to `observe`.
    CountStar {
        count: u64,
    },
    /// `sum(expr)`: running total, `None` until the first value arrives.
    Sum {
        sum: Option<NumericSum>,
    },
    /// `avg(expr)`: running total plus the number of values summed.
    Avg {
        sum: Option<NumericSum>,
        count: u64,
    },
    /// `stddev_pop(expr)`: streaming statistics accumulated in a `Welford`.
    StddevPop {
        stats: Welford,
    },
    /// `stddev_samp(expr)`: streaming statistics accumulated in a `Welford`.
    StddevSamp {
        stats: Welford,
    },
    /// `min(expr)`: smallest value so far, `None` until the first value arrives.
    Min {
        value: Option<Value>,
    },
    /// `max(expr)`: largest value so far, `None` until the first value arrives.
    Max {
        value: Option<Value>,
    },
    /// `collect_list(expr)`: every observed value, in observation order.
    Collect {
        values: Vec<Value>,
    },
    /// `percentile_cont`: inputs converted with `percentile_numeric_to_f64`.
    ///
    /// `percentile` is `None` until the fraction argument has been evaluated and
    /// then holds `percentile_value`'s result (presumably an inner `None` means
    /// the fraction was NULL — confirm in `numeric`).
    PercentileCont {
        values: Vec<f64>,
        percentile: Option<Option<f64>>,
    },
    /// `percentile_disc`: original input values, each checked to be numeric.
    ///
    /// `percentile` follows the same encoding as in `PercentileCont`.
    PercentileDisc {
        values: Vec<Value>,
        percentile: Option<Option<f64>>,
    },
}

impl AggregateState {
    /// Returns the empty accumulator for `function`.
    fn new(function: AggregateFn) -> Self {
        match function {
            AggregateFn::Count => Self::Count { count: 0 },
            AggregateFn::CountStar => Self::CountStar { count: 0 },
            AggregateFn::Sum => Self::Sum { sum: None },
            AggregateFn::Avg => Self::Avg {
                sum: None,
                count: 0,
            },
            AggregateFn::StddevPop => Self::StddevPop {
                stats: Welford::default(),
            },
            AggregateFn::StddevSamp => Self::StddevSamp {
                stats: Welford::default(),
            },
            AggregateFn::Min => Self::Min { value: None },
            AggregateFn::Max => Self::Max { value: None },
            AggregateFn::Collect => Self::Collect { values: Vec::new() },
            AggregateFn::PercentileCont => Self::PercentileCont {
                values: Vec::new(),
                percentile: None,
            },
            AggregateFn::PercentileDisc => Self::PercentileDisc {
                values: Vec::new(),
                percentile: None,
            },
        }
    }

    /// Whether NULL inputs are dropped before reaching `observe`.
    ///
    /// Only `count(*)` and `collect_list` keep them.
    fn skips_null(&self) -> bool {
        let keeps_nulls = matches!(self, Self::CountStar { .. } | Self::Collect { .. });
        !keeps_nulls
    }

    /// Whether this is a PERCENTILE state whose fraction has not been captured yet.
    fn needs_percentile(&self) -> bool {
        match self {
            Self::PercentileCont { percentile, .. } | Self::PercentileDisc { percentile, .. } => {
                percentile.is_none()
            }
            _ => false,
        }
    }

    /// Records the PERCENTILE fraction the first time it is supplied.
    ///
    /// Later calls, and calls on non-PERCENTILE states, are ignored.
    fn set_percentile(&mut self, next: Option<f64>) {
        if let Self::PercentileCont { percentile, .. } | Self::PercentileDisc { percentile, .. } =
            self
        {
            if percentile.is_none() {
                *percentile = Some(next);
            }
        }
    }

    /// Unwraps the argument value that every value-consuming state requires.
    fn require_value(value: Option<Value>) -> Result<Value, ExecutorError> {
        value.ok_or(ExecutorError::ImplementationDefined {
            detail: "aggregate value missing",
        })
    }

    /// Folds one input into the accumulator.
    ///
    /// The COUNT states ignore `value`. Every other state requires `Some(value)`.
    ///
    /// # Errors
    ///
    /// `ImplementationDefined` when a required value is `None`. Also propagates
    /// errors from the numeric helpers and from `update_min_max`.
    fn observe(&mut self, value: Option<Value>, span: SourceSpan) -> Result<(), ExecutorError> {
        match self {
            Self::Count { count } | Self::CountStar { count } => {
                *count = count.saturating_add(1);
            }
            Self::Sum { sum } => {
                let value = Self::require_value(value)?;
                *sum = Some(add_numeric(sum.take(), value, span)?);
            }
            Self::Avg { sum, count } => {
                let value = Self::require_value(value)?;
                *sum = Some(add_numeric(sum.take(), value, span)?);
                *count = count.saturating_add(1);
            }
            Self::StddevPop { stats } | Self::StddevSamp { stats } => {
                stats.observe(Self::require_value(value)?, span)?;
            }
            Self::Min { value: current } => update_min_max(current, value, span, true)?,
            Self::Max { value: current } => update_min_max(current, value, span, false)?,
            Self::Collect { values } => values.push(Self::require_value(value)?),
            Self::PercentileCont { values, .. } => {
                let value = Self::require_value(value)?;
                values.push(percentile_numeric_to_f64(&value, span)?);
            }
            Self::PercentileDisc { values, .. } => {
                let value = Self::require_value(value)?;
                // Validates that the input is numeric; the original value is what gets kept.
                percentile_numeric_to_f64(&value, span)?;
                values.push(value);
            }
        }
        Ok(())
    }

    /// Converts the accumulator into the aggregate's result value.
    ///
    /// # Errors
    ///
    /// Propagates errors from the numeric finalization helpers.
    fn finalize(self, span: SourceSpan) -> Result<Value, ExecutorError> {
        match self {
            Self::Count { count } | Self::CountStar { count } => count_to_value(count, span),
            // NOTE(review): SUM over zero non-null inputs yields integer 0 here, whereas
            // SQL yields NULL for an empty SUM. Confirm this is the intended GQL behavior.
            Self::Sum { sum } => Ok(match sum {
                Some(total) => NumericSum::into_value(total),
                None => Value::Int(0),
            }),
            Self::Avg { sum, count } => avg_to_value(sum, count, span),
            Self::StddevPop { stats } => stddev_pop_to_value(stats, span),
            Self::StddevSamp { stats } => stddev_samp_to_value(stats, span),
            Self::Min { value } | Self::Max { value } => Ok(value.unwrap_or(Value::Null)),
            Self::Collect { values } => Ok(Value::List(values)),
            Self::PercentileCont { values, percentile } => {
                percentile_cont_to_value(values, percentile, span)
            }
            Self::PercentileDisc { values, percentile } => {
                percentile_disc_to_value(values, percentile, span)
            }
        }
    }
}

/// Maps a planned aggregate call onto the accumulator kind that implements it.
///
/// Checks run in this order:
/// 1. the `*` form, which only accepts non-`DISTINCT` `count(*)`;
/// 2. PERCENTILE arity, which requires exactly two arguments;
/// 3. general arity, which requires exactly one argument;
/// 4. the function name.
///
/// # Errors
///
/// Returns `ExecutorError::ImplementationDefined` for any unsupported form.
fn classify(aggregate: &Aggregate) -> Result<AggregateFn, ExecutorError> {
    let name = aggregate.function.as_str();

    if aggregate.star {
        if name != "count" || aggregate.distinct {
            return Err(ExecutorError::ImplementationDefined {
                detail: "aggregate star form not implemented",
            });
        }
        return Ok(AggregateFn::CountStar);
    }

    let function = match (name, aggregate.args.len()) {
        ("percentile_cont", 2) => AggregateFn::PercentileCont,
        ("percentile_disc", 2) => AggregateFn::PercentileDisc,
        ("percentile_cont" | "percentile_disc", _) => {
            return Err(ExecutorError::ImplementationDefined {
                detail: "PERCENTILE aggregate arity not implemented",
            });
        }
        // The arity check runs before name lookup, so an unknown name with the
        // wrong arity reports the arity error.
        (_, arity) if arity != 1 => {
            return Err(ExecutorError::ImplementationDefined {
                detail: "aggregate arity not implemented",
            });
        }
        ("count", _) => AggregateFn::Count,
        ("sum", _) => AggregateFn::Sum,
        ("avg", _) => AggregateFn::Avg,
        ("stddev_pop", _) => AggregateFn::StddevPop,
        ("stddev_samp", _) => AggregateFn::StddevSamp,
        ("min", _) => AggregateFn::Min,
        ("max", _) => AggregateFn::Max,
        ("collect_list", _) => AggregateFn::Collect,
        _ => {
            return Err(ExecutorError::ImplementationDefined {
                detail: "aggregate function not implemented",
            });
        }
    };
    Ok(function)
}

/// Folds `next` into the running MIN (`keep_min == true`) or MAX extreme.
///
/// The first value observed becomes the extreme. A later value replaces it only
/// when it is strictly smaller (MIN) or strictly larger (MAX), so on ties the
/// earlier value is kept.
///
/// # Errors
///
/// * `ExecutorError::ImplementationDefined` if `next` is `None`.
/// * A `ValuesNotComparable` data exception when `compare_non_null` cannot
///   order `next` against the current extreme.
fn update_min_max(
    current: &mut Option<Value>,
    next: Option<Value>,
    span: SourceSpan,
    keep_min: bool,
) -> Result<(), ExecutorError> {
    let candidate = next.ok_or(ExecutorError::ImplementationDefined {
        detail: "aggregate value missing",
    })?;

    let Some(extreme) = current else {
        *current = Some(candidate);
        return Ok(());
    };

    let ordering = value_compare::compare_non_null(&candidate, extreme).ok_or_else(|| {
        data_exception_value(
            DataExceptionSubclass::ValuesNotComparable,
            "aggregate value is not order-comparable",
            span,
        )
    })?;

    let improves = if keep_min {
        ordering.is_lt()
    } else {
        ordering.is_gt()
    };
    if improves {
        *extreme = candidate;
    }
    Ok(())
}