#![allow(
clippy::unnecessary_literal_bound,
clippy::too_many_lines,
clippy::cast_possible_truncation,
clippy::cast_possible_wrap,
clippy::cast_precision_loss,
clippy::match_same_arms,
clippy::items_after_statements,
clippy::float_cmp,
clippy::cast_sign_loss,
clippy::suboptimal_flops
)]
use fsqlite_error::{FrankenError, Result};
use fsqlite_types::SqliteValue;
use crate::{AggregateFunction, FunctionRegistry};
/// Adds `value` into a compensated running sum.
///
/// `sum` holds the rounded running total, while `compensation` collects the
/// low-order bits that each floating-point addition rounded away. The better
/// estimate of the exact total is therefore `sum + compensation`.
///
/// The magnitude comparison selects which operand's lost bits to recover
/// (Neumaier's refinement of Kahan summation). This keeps the correction
/// valid even when `value` is larger in magnitude than the running sum.
#[inline]
fn kahan_add(sum: &mut f64, compensation: &mut f64, value: f64) {
    let before = *sum;
    let after = before + value;
    *compensation += if before.abs() > value.abs() {
        (before - after) + value
    } else {
        (value - after) + before
    };
    *sum = after;
}
/// Running state for `avg()`: a compensated sum of the non-NULL inputs seen
/// so far, together with how many there were.
pub struct AvgState {
    /// Rounded running sum of the inputs.
    sum: f64,
    /// Accumulated rounding error of `sum` (maintained by `kahan_add`).
    compensation: f64,
    /// Number of non-NULL inputs.
    count: i64,
}

/// `avg(X)`: arithmetic mean of the non-NULL inputs as a REAL, or NULL when
/// the group has no non-NULL inputs.
pub struct AvgFunc;

impl AggregateFunction for AvgFunc {
    type State = AvgState;

    fn initial_state(&self) -> Self::State {
        AvgState {
            sum: 0.0,
            compensation: 0.0,
            count: 0,
        }
    }

    fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
        if !args[0].is_null() {
            kahan_add(&mut state.sum, &mut state.compensation, args[0].to_float());
            state.count += 1;
        }
        Ok(())
    }

    fn finalize(&self, state: Self::State) -> Result<SqliteValue> {
        if state.count == 0 {
            return Ok(SqliteValue::Null);
        }
        // If the running sum reached ±inf (overflow or an infinite input),
        // the compensation term degenerates to ±inf/NaN, and adding it would
        // turn an infinite total into NaN. Matching SQLite, drop a
        // non-finite error term and keep the rounded sum.
        let total = if state.compensation.is_finite() {
            state.sum + state.compensation
        } else {
            state.sum
        };
        Ok(SqliteValue::Float(total / state.count as f64))
    }

    fn num_args(&self) -> i32 {
        1
    }

    fn name(&self) -> &str {
        "avg"
    }
}
/// `count(*)`: counts every `step` call, regardless of argument values.
pub struct CountStarFunc;

impl AggregateFunction for CountStarFunc {
    type State = i64;

    fn initial_state(&self) -> Self::State {
        Self::State::default()
    }

    fn step(&self, state: &mut Self::State, _args: &[SqliteValue]) -> Result<()> {
        *state += 1;
        Ok(())
    }

    fn finalize(&self, state: Self::State) -> Result<SqliteValue> {
        let rows = state;
        Ok(SqliteValue::Integer(rows))
    }

    fn num_args(&self) -> i32 {
        0
    }

    fn name(&self) -> &str {
        "count"
    }
}
/// `count(X)`: number of rows whose argument is not NULL.
pub struct CountFunc;

impl AggregateFunction for CountFunc {
    type State = i64;

    fn initial_state(&self) -> Self::State {
        Self::State::default()
    }

    fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
        // Adds 1 for a non-NULL argument and 0 otherwise.
        *state += i64::from(!args[0].is_null());
        Ok(())
    }

    fn finalize(&self, state: Self::State) -> Result<SqliteValue> {
        let non_null_rows = state;
        Ok(SqliteValue::Integer(non_null_rows))
    }

    fn num_args(&self) -> i32 {
        1
    }

    fn name(&self) -> &str {
        "count"
    }
}
/// Accumulator for `group_concat()`; also reused by the `string_agg` alias.
pub struct GroupConcatState {
    /// Concatenated output built so far.
    result: String,
    /// Whether at least one non-NULL value has been appended. Until then no
    /// separator is emitted, and finalize yields NULL.
    has_value: bool,
}

/// `group_concat(X)` / `group_concat(X, Y)`: joins the non-NULL `X` values.
///
/// The separator is `Y` taken from the row being appended. It defaults to
/// `,` when only one argument is given, and nothing is inserted when that
/// row's `Y` is NULL.
pub struct GroupConcatFunc;

/// Appends the textual form of `value` to `result`.
///
/// Text values are copied straight from their borrowed form. Any other value
/// goes through `to_text()`.
#[inline]
fn push_group_concat_text(result: &mut String, value: &SqliteValue) {
    match value.as_text_str() {
        Some(text) => result.push_str(text),
        None => result.push_str(&value.to_text()),
    }
}

impl AggregateFunction for GroupConcatFunc {
    type State = GroupConcatState;

    fn initial_state(&self) -> Self::State {
        GroupConcatState {
            result: String::new(),
            has_value: false,
        }
    }

    fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
        let value = &args[0];
        if value.is_null() {
            return Ok(());
        }
        // A separator only ever goes *between* values, never before the first.
        if state.has_value {
            match args.get(1) {
                None => state.result.push(','),
                Some(separator) if separator.is_null() => {}
                Some(separator) => push_group_concat_text(&mut state.result, separator),
            }
        }
        push_group_concat_text(&mut state.result, value);
        state.has_value = true;
        Ok(())
    }

    fn finalize(&self, state: Self::State) -> Result<SqliteValue> {
        let GroupConcatState { result, has_value } = state;
        Ok(if has_value {
            SqliteValue::Text(result.into())
        } else {
            SqliteValue::Null
        })
    }

    fn num_args(&self) -> i32 {
        -1
    }

    fn min_args(&self) -> i32 {
        1
    }

    fn max_args(&self) -> Option<i32> {
        Some(2)
    }

    fn name(&self) -> &str {
        "group_concat"
    }
}
/// `max(X)` (aggregate form): the largest non-NULL input according to
/// `SqliteValue`'s `PartialOrd`, or NULL when every input was NULL.
///
/// A new value replaces the current best only when it compares strictly
/// greater, so among equal values the earliest one is kept.
pub struct AggMaxFunc;

impl AggregateFunction for AggMaxFunc {
    type State = Option<SqliteValue>;

    fn initial_state(&self) -> Self::State {
        None
    }

    fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
        let candidate = &args[0];
        if candidate.is_null() {
            return Ok(());
        }
        let should_replace = match state.as_ref() {
            Some(best) => candidate > best,
            None => true,
        };
        if should_replace {
            *state = Some(candidate.clone());
        }
        Ok(())
    }

    fn finalize(&self, state: Self::State) -> Result<SqliteValue> {
        Ok(state.unwrap_or(SqliteValue::Null))
    }

    fn num_args(&self) -> i32 {
        1
    }

    fn name(&self) -> &str {
        "max"
    }
}
/// `min(X)` (aggregate form): the smallest non-NULL input according to
/// `SqliteValue`'s `PartialOrd`, or NULL when every input was NULL.
///
/// A new value replaces the current best only when it compares strictly
/// less, so among equal values the earliest one is kept.
pub struct AggMinFunc;

impl AggregateFunction for AggMinFunc {
    type State = Option<SqliteValue>;

    fn initial_state(&self) -> Self::State {
        None
    }

    fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
        let candidate = &args[0];
        if candidate.is_null() {
            return Ok(());
        }
        let should_replace = match state.as_ref() {
            Some(best) => candidate < best,
            None => true,
        };
        if should_replace {
            *state = Some(candidate.clone());
        }
        Ok(())
    }

    fn finalize(&self, state: Self::State) -> Result<SqliteValue> {
        Ok(state.unwrap_or(SqliteValue::Null))
    }

    fn num_args(&self) -> i32 {
        1
    }

    fn name(&self) -> &str {
        "min"
    }
}
/// Running state for `sum()`.
///
/// Integer inputs are added exactly into `int_sum` while every input has
/// been an integer and no overflow has occurred. In parallel, every numeric
/// input is folded into a compensated floating-point sum. That sum becomes
/// the result as soon as a non-integer numeric input is seen.
pub struct SumState {
    /// Exact integer total (meaningful only while `all_integer && !overflowed`).
    int_sum: i64,
    /// Rounded floating-point total of all numeric inputs.
    float_sum: f64,
    /// Accumulated rounding error of `float_sum` (maintained by `kahan_add`).
    float_compensation: f64,
    /// `true` until a REAL (non-integer) numeric input is seen.
    all_integer: bool,
    /// `true` once any input converted to a non-NULL numeric value.
    has_values: bool,
    /// Set when `int_sum` could not absorb an integer input without overflow.
    overflowed: bool,
}

/// `sum(X)` returns one of four results:
///
/// - an INTEGER for all-integer input;
/// - a REAL once any REAL input is seen;
/// - NULL when no non-NULL input was seen;
/// - `FrankenError::IntegerOverflow` when an all-integer total does not fit
///   in `i64`.
pub struct SumFunc;

impl AggregateFunction for SumFunc {
    type State = SumState;

    fn initial_state(&self) -> Self::State {
        SumState {
            int_sum: 0,
            float_sum: 0.0,
            float_compensation: 0.0,
            all_integer: true,
            has_values: false,
            overflowed: false,
        }
    }

    fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
        // `i64 as f64` rounds once |x| exceeds 2^53, and compensation cannot
        // recover bits lost before the addition. Integers at or beyond this
        // magnitude are therefore split into a multiple of `SPLIT_UNIT`
        // (exactly representable, since it needs at most 49 significant bits)
        // plus a small remainder. Each part is added separately.
        const SPLIT_THRESHOLD: i64 = 1 << 52;
        const SPLIT_UNIT: i64 = 16_384;

        let value = args[0].to_sum_numeric_value();
        if value.is_null() {
            return Ok(());
        }
        state.has_values = true;
        match value {
            SqliteValue::Integer(i) => {
                if state.all_integer && !state.overflowed {
                    match state.int_sum.checked_add(i) {
                        Some(s) => state.int_sum = s,
                        None => state.overflowed = true,
                    }
                }
                if i <= -SPLIT_THRESHOLD || i >= SPLIT_THRESHOLD {
                    // `%` truncates toward zero, so `high` lies between `i`
                    // and 0 and the subtraction cannot overflow.
                    let low = i % SPLIT_UNIT;
                    let high = i - low;
                    kahan_add(
                        &mut state.float_sum,
                        &mut state.float_compensation,
                        high as f64,
                    );
                    kahan_add(
                        &mut state.float_sum,
                        &mut state.float_compensation,
                        low as f64,
                    );
                } else {
                    kahan_add(
                        &mut state.float_sum,
                        &mut state.float_compensation,
                        i as f64,
                    );
                }
            }
            SqliteValue::Float(f) => {
                state.all_integer = false;
                kahan_add(&mut state.float_sum, &mut state.float_compensation, f);
            }
            SqliteValue::Null | SqliteValue::Text(_) | SqliteValue::Blob(_) => {}
        }
        Ok(())
    }

    fn finalize(&self, state: Self::State) -> Result<SqliteValue> {
        if !state.has_values {
            return Ok(SqliteValue::Null);
        }
        if state.all_integer {
            return if state.overflowed {
                Err(FrankenError::IntegerOverflow)
            } else {
                Ok(SqliteValue::Integer(state.int_sum))
            };
        }
        // A float total that reached ±inf leaves the compensation at ±inf or
        // NaN. Adding it would turn ±inf into NaN, so, as in SQLite, a
        // non-finite error term is dropped.
        let total = if state.float_compensation.is_finite() {
            state.float_sum + state.float_compensation
        } else {
            state.float_sum
        };
        Ok(SqliteValue::Float(total))
    }

    fn num_args(&self) -> i32 {
        1
    }

    fn name(&self) -> &str {
        "sum"
    }
}
/// `total(X)`: compensated floating-point sum of the non-NULL inputs.
///
/// It always returns a REAL (0.0 for an empty group) and never reports an
/// error.
pub struct TotalFunc;

/// Running compensated sum for `total()`.
pub struct TotalState {
    /// Rounded running sum of the inputs.
    sum: f64,
    /// Accumulated rounding error of `sum` (maintained by `kahan_add`).
    compensation: f64,
}

impl AggregateFunction for TotalFunc {
    type State = TotalState;

    fn initial_state(&self) -> Self::State {
        TotalState {
            sum: 0.0,
            compensation: 0.0,
        }
    }

    fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
        if !args[0].is_null() {
            kahan_add(&mut state.sum, &mut state.compensation, args[0].to_float());
        }
        Ok(())
    }

    fn finalize(&self, state: Self::State) -> Result<SqliteValue> {
        // Once the sum reaches ±inf, the compensation term becomes ±inf or
        // NaN. Adding it would yield NaN instead of the infinite total, so a
        // non-finite error term is dropped (matching SQLite).
        let total = if state.compensation.is_finite() {
            state.sum + state.compensation
        } else {
            state.sum
        };
        Ok(SqliteValue::Float(total))
    }

    fn num_args(&self) -> i32 {
        1
    }

    fn name(&self) -> &str {
        "total"
    }
}
/// `median(X)`: linearly interpolated 50th percentile of the non-NULL inputs,
/// or NULL for an empty group.
pub struct MedianFunc;

impl AggregateFunction for MedianFunc {
    type State = Vec<f64>;

    fn initial_state(&self) -> Self::State {
        Vec::new()
    }

    fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
        if !args[0].is_null() {
            state.push(args[0].to_float());
        }
        Ok(())
    }

    fn finalize(&self, mut state: Self::State) -> Result<SqliteValue> {
        if state.is_empty() {
            return Ok(SqliteValue::Null);
        }
        // `f64::total_cmp` is a total order even when NaN is present. A
        // `partial_cmp(..).unwrap_or(Equal)` comparator is not, which leaves
        // the sorted order unspecified, and std's sort is documented as
        // allowed to panic on such comparators.
        state.sort_unstable_by(f64::total_cmp);
        Ok(SqliteValue::Float(percentile_cont_impl(&state, 0.5)))
    }

    fn num_args(&self) -> i32 {
        1
    }

    fn name(&self) -> &str {
        "median"
    }
}
/// State shared by `percentile`, `percentile_cont` and `percentile_disc`.
pub struct PercentileState {
    /// Non-NULL inputs collected so far, converted with `to_float()`.
    values: Vec<f64>,
    /// First non-NULL percentile argument seen in the group, if any.
    p: Option<f64>,
}

/// `percentile(Y, P)`: linearly interpolated percentile with `P` on a 0–100
/// scale.
///
/// `P` defaults to 50 when every `P` was NULL. The result is NULL when every
/// `Y` was NULL.
pub struct PercentileFunc;

impl AggregateFunction for PercentileFunc {
    type State = PercentileState;

    fn initial_state(&self) -> Self::State {
        PercentileState {
            values: Vec::new(),
            p: None,
        }
    }

    fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
        if !args[0].is_null() {
            state.values.push(args[0].to_float());
        }
        // Only the first non-NULL P is used; later ones are ignored.
        if state.p.is_none() && args.len() > 1 && !args[1].is_null() {
            state.p = Some(args[1].to_float());
        }
        Ok(())
    }

    fn finalize(&self, mut state: Self::State) -> Result<SqliteValue> {
        if state.values.is_empty() {
            return Ok(SqliteValue::Null);
        }
        let p = state.p.unwrap_or(50.0);
        // Total order, so NaN inputs cannot make the sort unspecified or panic.
        state.values.sort_unstable_by(f64::total_cmp);
        let result = percentile_cont_impl(&state.values, p / 100.0);
        Ok(SqliteValue::Float(result))
    }

    fn num_args(&self) -> i32 {
        2
    }

    fn name(&self) -> &str {
        "percentile"
    }
}
/// `percentile_cont(Y, P)`: linearly interpolated percentile with `P` as a
/// fraction in 0–1.
///
/// `P` defaults to 0.5 when every `P` was NULL. The result is NULL when
/// every `Y` was NULL.
pub struct PercentileContFunc;

impl AggregateFunction for PercentileContFunc {
    type State = PercentileState;

    fn initial_state(&self) -> Self::State {
        PercentileState {
            values: Vec::new(),
            p: None,
        }
    }

    fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
        if !args[0].is_null() {
            state.values.push(args[0].to_float());
        }
        // Only the first non-NULL P is used; later ones are ignored.
        if state.p.is_none() && args.len() > 1 && !args[1].is_null() {
            state.p = Some(args[1].to_float());
        }
        Ok(())
    }

    fn finalize(&self, mut state: Self::State) -> Result<SqliteValue> {
        if state.values.is_empty() {
            return Ok(SqliteValue::Null);
        }
        let p = state.p.unwrap_or(0.5);
        // Total order, so NaN inputs cannot make the sort unspecified or panic.
        state.values.sort_unstable_by(f64::total_cmp);
        let result = percentile_cont_impl(&state.values, p);
        Ok(SqliteValue::Float(result))
    }

    fn num_args(&self) -> i32 {
        2
    }

    fn name(&self) -> &str {
        "percentile_cont"
    }
}
/// `percentile_disc(Y, P)`: the percentile chosen from the actual inputs,
/// without interpolation.
///
/// `P` is a fraction in 0–1. It defaults to 0.5 when every `P` was NULL or
/// `P` is NaN, and is clamped to [0, 1] otherwise. The result is NULL when
/// every `Y` was NULL.
pub struct PercentileDiscFunc;

impl AggregateFunction for PercentileDiscFunc {
    type State = PercentileState;

    fn initial_state(&self) -> Self::State {
        PercentileState {
            values: Vec::new(),
            p: None,
        }
    }

    fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
        if !args[0].is_null() {
            state.values.push(args[0].to_float());
        }
        // Only the first non-NULL P is used; later ones are ignored.
        if state.p.is_none() && args.len() > 1 && !args[1].is_null() {
            state.p = Some(args[1].to_float());
        }
        Ok(())
    }

    fn finalize(&self, mut state: Self::State) -> Result<SqliteValue> {
        if state.values.is_empty() {
            return Ok(SqliteValue::Null);
        }
        let p = state.p.unwrap_or(0.5);
        let p = if p.is_nan() { 0.5 } else { p.clamp(0.0, 1.0) };
        // Total order, so NaN inputs cannot make the sort unspecified or panic.
        state.values.sort_unstable_by(f64::total_cmp);
        let n = state.values.len();
        // Index of the first value whose cumulative share (idx + 1) / n
        // reaches p, kept within 0..n.
        let idx = ((p * n as f64).ceil() as usize)
            .saturating_sub(1)
            .min(n - 1);
        Ok(SqliteValue::Float(state.values[idx]))
    }

    fn num_args(&self) -> i32 {
        2
    }

    fn name(&self) -> &str {
        "percentile_disc"
    }
}
/// Linearly interpolated percentile of an ascending-sorted, non-empty slice.
///
/// `p` is a fraction: NaN is treated as 0.5, and other values are clamped to
/// [0, 1]. The result interpolates between the two elements around rank
/// `p * (len - 1)`. Callers must not pass an empty slice; it panics.
fn percentile_cont_impl(sorted: &[f64], p: f64) -> f64 {
    let count = sorted.len();
    if count == 1 {
        return sorted[0];
    }
    let fraction = match p {
        x if x.is_nan() => 0.5,
        x => x.clamp(0.0, 1.0),
    };
    let position = fraction * (count - 1) as f64;
    let lo = position.floor() as usize;
    let hi = position.ceil() as usize;
    if lo != hi {
        let weight = position - lo as f64;
        sorted[lo] * (1.0 - weight) + sorted[hi] * weight
    } else {
        sorted[lo]
    }
}
/// Registers every aggregate built-in defined in this module with `registry`.
///
/// This includes `string_agg`, a fixed two-argument alias of `group_concat`.
pub fn register_aggregate_builtins(registry: &mut FunctionRegistry) {
    /// `string_agg(X, Y)`: two-argument alias that delegates all work to
    /// `GroupConcatFunc`.
    struct StringAggFunc;

    impl AggregateFunction for StringAggFunc {
        type State = GroupConcatState;

        fn initial_state(&self) -> Self::State {
            GroupConcatFunc.initial_state()
        }

        fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
            GroupConcatFunc.step(state, args)
        }

        fn finalize(&self, state: Self::State) -> Result<SqliteValue> {
            GroupConcatFunc.finalize(state)
        }

        fn num_args(&self) -> i32 {
            2
        }

        fn name(&self) -> &str {
            "string_agg"
        }
    }

    // Registration order is kept unchanged in case the registry's lookup
    // depends on it (e.g. the two "count" overloads).
    registry.register_aggregate(AvgFunc);
    registry.register_aggregate(CountStarFunc);
    registry.register_aggregate(CountFunc);
    registry.register_aggregate(GroupConcatFunc);
    registry.register_aggregate(AggMaxFunc);
    registry.register_aggregate(AggMinFunc);
    registry.register_aggregate(SumFunc);
    registry.register_aggregate(TotalFunc);
    registry.register_aggregate(MedianFunc);
    registry.register_aggregate(PercentileFunc);
    registry.register_aggregate(PercentileContFunc);
    registry.register_aggregate(PercentileDiscFunc);
    registry.register_aggregate(StringAggFunc);
}
// Unit tests for the aggregate built-ins in this module. Most tests drive an
// aggregate directly through initial_state/step/finalize; the registry tests
// go through FunctionRegistry lookups instead.
#[cfg(test)]
mod tests {
use super::*;
// Absolute tolerance used by assert_float_eq.
const EPS: f64 = 1e-12;
// Shorthand constructors for SqliteValue test inputs.
fn int(v: i64) -> SqliteValue {
SqliteValue::Integer(v)
}
fn float(v: f64) -> SqliteValue {
SqliteValue::Float(v)
}
fn null() -> SqliteValue {
SqliteValue::Null
}
fn text(s: &str) -> SqliteValue {
SqliteValue::Text(s.into())
}
// Asserts that `result` is a Float within EPS of `expected`. For any other
// variant the `matches!` assertion below is always false, so it fails with a
// message that shows the actual value.
fn assert_float_eq(result: &SqliteValue, expected: f64) {
match result {
SqliteValue::Float(v) => {
assert!((v - expected).abs() < EPS, "expected {expected}, got {v}");
}
other => {
assert!(
matches!(other, SqliteValue::Float(_)),
"expected Float({expected}), got {other:?}"
);
}
}
}
// Runs `func` over `rows`, calling step once per row with that row as the
// single argument, then finalizes. Any step/finalize error panics.
fn run_agg<F: AggregateFunction>(func: &F, rows: &[SqliteValue]) -> SqliteValue {
let mut state = func.initial_state();
for row in rows {
func.step(&mut state, std::slice::from_ref(row)).unwrap();
}
func.finalize(state).unwrap()
}
// Two-argument variant of run_agg: each pair is passed as (arg0, arg1).
fn run_agg2<F: AggregateFunction>(
func: &F,
rows: &[(SqliteValue, SqliteValue)],
) -> SqliteValue {
let mut state = func.initial_state();
for (a, b) in rows {
func.step(&mut state, &[a.clone(), b.clone()]).unwrap();
}
func.finalize(state).unwrap()
}
#[test]
fn test_avg_basic() {
let r = run_agg(&AvgFunc, &[int(1), int(2), int(3), int(4), int(5)]);
assert_float_eq(&r, 3.0);
}
#[test]
fn test_avg_with_nulls() {
let r = run_agg(&AvgFunc, &[int(1), null(), int(3)]);
assert_float_eq(&r, 2.0);
}
#[test]
fn test_avg_empty() {
let r = run_agg(&AvgFunc, &[]);
assert_eq!(r, SqliteValue::Null);
}
#[test]
fn test_avg_returns_real() {
let r = run_agg(&AvgFunc, &[int(2), int(4)]);
assert!(matches!(r, SqliteValue::Float(_)));
}
// count(*) is stepped with an empty argument list and counts every call.
#[test]
fn test_count_star() {
let mut state = CountStarFunc.initial_state();
CountStarFunc.step(&mut state, &[]).unwrap(); CountStarFunc.step(&mut state, &[]).unwrap(); CountStarFunc.step(&mut state, &[]).unwrap(); let r = CountStarFunc.finalize(state).unwrap();
assert_eq!(r, int(3));
}
#[test]
fn test_count_column() {
let r = run_agg(&CountFunc, &[int(1), null(), int(3), null(), int(5)]);
assert_eq!(r, int(3));
}
#[test]
fn test_count_empty() {
let r = run_agg(&CountFunc, &[]);
assert_eq!(r, int(0));
}
#[test]
fn test_group_concat_basic() {
let r = run_agg(&GroupConcatFunc, &[text("a"), text("b"), text("c")]);
assert_eq!(r, SqliteValue::Text("a,b,c".into()));
}
#[test]
fn test_group_concat_custom_sep() {
let rows = vec![
(text("a"), text("; ")),
(text("b"), text("; ")),
(text("c"), text("; ")),
];
let r = run_agg2(&GroupConcatFunc, &rows);
assert_eq!(r, SqliteValue::Text("a; b; c".into()));
}
#[test]
fn test_group_concat_null_skipped() {
let r = run_agg(&GroupConcatFunc, &[text("a"), null(), text("c")]);
assert_eq!(r, SqliteValue::Text("a,c".into()));
}
#[test]
fn test_group_concat_empty() {
let r = run_agg(&GroupConcatFunc, &[]);
assert_eq!(r, SqliteValue::Null);
}
// The separator is taken from the row being appended, so the first row's
// "-" is never emitted.
#[test]
fn test_group_concat_varying_separator() {
let rows = vec![
(text("a"), text("-")),
(text("b"), text("+")),
(text("c"), text("*")),
];
let r = run_agg2(&GroupConcatFunc, &rows);
assert_eq!(r, SqliteValue::Text("a+b*c".into()));
}
#[test]
fn test_group_concat_single_value() {
let r = run_agg(&GroupConcatFunc, &[text("only")]);
assert_eq!(r, SqliteValue::Text("only".into()));
}
#[test]
fn test_group_concat_integer_values_coerced_to_text() {
let r = run_agg(&GroupConcatFunc, &[int(1), int(2), int(3)]);
assert_eq!(r, SqliteValue::Text("1,2,3".into()));
}
// Timing harness only (ignored by default): prints the best of REPEATS runs
// and asserts nothing.
#[test]
#[ignore = "perf-only benchmark"]
fn perf_group_concat_text_rows() {
use std::hint::black_box;
use std::time::Instant;
const ROWS: usize = 200_000;
const REPEATS: usize = 5;
let rows: Vec<SqliteValue> = (0..ROWS).map(|_| text("payload")).collect();
let mut best_ns = u128::MAX;
let mut result_len = 0usize;
for _ in 0..REPEATS {
let started = Instant::now();
let result = black_box(run_agg(&GroupConcatFunc, black_box(rows.as_slice())));
let elapsed_ns = started.elapsed().as_nanos();
if elapsed_ns < best_ns {
best_ns = elapsed_ns;
}
result_len = match result {
SqliteValue::Text(text) => text.len(),
SqliteValue::Null
| SqliteValue::Integer(_)
| SqliteValue::Float(_)
| SqliteValue::Blob(_) => 0,
};
}
println!(
"group_concat_text_rows rows={ROWS} repeats={REPEATS} best_ns={best_ns} result_len={result_len}"
);
}
#[test]
fn test_max_aggregate() {
let r = run_agg(&AggMaxFunc, &[int(3), int(7), int(1), int(5)]);
assert_eq!(r, int(7));
}
#[test]
fn test_max_aggregate_null_skipped() {
let r = run_agg(&AggMaxFunc, &[int(3), null(), int(7), null()]);
assert_eq!(r, int(7));
}
#[test]
fn test_max_aggregate_empty() {
let r = run_agg(&AggMaxFunc, &[]);
assert_eq!(r, SqliteValue::Null);
}
#[test]
fn test_min_aggregate() {
let r = run_agg(&AggMinFunc, &[int(3), int(7), int(1), int(5)]);
assert_eq!(r, int(1));
}
#[test]
fn test_min_aggregate_null_skipped() {
let r = run_agg(&AggMinFunc, &[int(3), null(), int(1), null()]);
assert_eq!(r, int(1));
}
#[test]
fn test_min_aggregate_empty() {
let r = run_agg(&AggMinFunc, &[]);
assert_eq!(r, SqliteValue::Null);
}
#[test]
fn test_sum_integers() {
let r = run_agg(&SumFunc, &[int(1), int(2), int(3)]);
assert_eq!(r, int(6));
}
#[test]
fn test_sum_reals() {
let r = run_agg(&SumFunc, &[float(1.5), float(2.5)]);
assert_float_eq(&r, 4.0);
}
#[test]
fn test_sum_empty_null() {
let r = run_agg(&SumFunc, &[]);
assert_eq!(r, SqliteValue::Null);
}
// An all-integer sum that leaves the i64 range is reported as an error.
#[test]
fn test_sum_overflow_error() {
let mut state = SumFunc.initial_state();
SumFunc.step(&mut state, &[int(i64::MAX)]).unwrap();
SumFunc.step(&mut state, &[int(1)]).unwrap();
let err = SumFunc.finalize(state);
assert!(err.is_err(), "sum should raise overflow error");
}
// A REAL input after the overflow makes the result floating point, so no
// error is raised. The expected literal rounds to 2^63 as an f64.
#[test]
fn test_sum_later_real_value_clears_integer_overflow_error() {
let r = run_agg(&SumFunc, &[int(i64::MAX), int(1), float(0.5)]);
assert_float_eq(&r, 9_223_372_036_854_776_000.0);
}
// Integer-looking text follows the same overflow rules as integer inputs.
#[test]
fn test_sum_integer_text_preserves_overflow_error() -> Result<()> {
let mut state = SumFunc.initial_state();
SumFunc.step(&mut state, &[text("9223372036854775807")])?;
SumFunc.step(&mut state, &[text("1")])?;
let err = SumFunc.finalize(state);
assert!(err.is_err(), "integer-text sum should raise overflow");
Ok(())
}
#[test]
fn test_sum_integer_text_later_real_clears_overflow_error() {
let r = run_agg(
&SumFunc,
&[text("9223372036854775807"), text("1"), text("0.5")],
);
assert_float_eq(&r, 9_223_372_036_854_776_000.0);
}
// Per the expected value, "123abc" contributes 123 and forces a REAL result.
#[test]
fn test_sum_prefix_text_uses_real_accumulator() {
let r = run_agg(&SumFunc, &[text("123abc"), int(1)]);
assert_float_eq(&r, 124.0);
}
// U+00A0 is not treated as whitespace. Per the expected values, a leading
// NBSP makes the text contribute 0, while with a trailing NBSP the text
// still contributes 123.
#[test]
fn test_sum_unicode_whitespace_text_uses_sqlite_ascii_space_rules() {
let leading = run_agg(&SumFunc, &[text("\u{00a0}123"), int(1)]);
assert_float_eq(&leading, 1.0);
let trailing = run_agg(&SumFunc, &[text("123\u{00a0}"), int(1)]);
assert_float_eq(&trailing, 124.0);
}
#[test]
fn test_sum_null_skipped() {
let r = run_agg(&SumFunc, &[int(1), null(), int(3)]);
assert_eq!(r, int(4));
}
#[test]
fn test_total_basic() {
let r = run_agg(&TotalFunc, &[int(1), int(2), int(3)]);
assert_float_eq(&r, 6.0);
}
// Unlike sum(), total() returns 0.0 (not NULL) for an empty group.
#[test]
fn test_total_empty_zero() {
let r = run_agg(&TotalFunc, &[]);
assert_float_eq(&r, 0.0);
}
#[test]
fn test_total_no_overflow() {
let r = run_agg(&TotalFunc, &[int(i64::MAX), int(i64::MAX)]);
assert!(matches!(r, SqliteValue::Float(_)));
}
#[test]
fn test_median_basic() {
let r = run_agg(&MedianFunc, &[int(1), int(2), int(3), int(4), int(5)]);
assert_float_eq(&r, 3.0);
}
#[test]
fn test_median_even() {
let r = run_agg(&MedianFunc, &[int(1), int(2), int(3), int(4)]);
assert_float_eq(&r, 2.5);
}
#[test]
fn test_median_null_skipped() {
let r = run_agg(&MedianFunc, &[int(1), null(), int(3)]);
assert_float_eq(&r, 2.0);
}
#[test]
fn test_median_empty() {
let r = run_agg(&MedianFunc, &[]);
assert_eq!(r, SqliteValue::Null);
}
// percentile() takes P on a 0-100 scale, while percentile_cont and
// percentile_disc below take a 0-1 fraction.
#[test]
fn test_percentile_50() {
let rows: Vec<(SqliteValue, SqliteValue)> = vec![
(int(1), float(50.0)),
(int(2), float(50.0)),
(int(3), float(50.0)),
(int(4), float(50.0)),
(int(5), float(50.0)),
];
let r = run_agg2(&PercentileFunc, &rows);
assert_float_eq(&r, 3.0);
}
#[test]
fn test_percentile_0() {
let rows: Vec<(SqliteValue, SqliteValue)> = vec![
(int(10), float(0.0)),
(int(20), float(0.0)),
(int(30), float(0.0)),
];
let r = run_agg2(&PercentileFunc, &rows);
assert_float_eq(&r, 10.0);
}
#[test]
fn test_percentile_100() {
let rows: Vec<(SqliteValue, SqliteValue)> = vec![
(int(10), float(100.0)),
(int(20), float(100.0)),
(int(30), float(100.0)),
];
let r = run_agg2(&PercentileFunc, &rows);
assert_float_eq(&r, 30.0);
}
#[test]
fn test_percentile_cont_basic() {
let rows: Vec<(SqliteValue, SqliteValue)> = vec![
(int(1), float(0.5)),
(int(2), float(0.5)),
(int(3), float(0.5)),
(int(4), float(0.5)),
(int(5), float(0.5)),
];
let r = run_agg2(&PercentileContFunc, &rows);
assert_float_eq(&r, 3.0);
}
// The percentile_disc tests only check that the result is one of the
// inputs; the exact element chosen is not pinned here.
#[test]
fn test_percentile_disc_basic() {
let rows: Vec<(SqliteValue, SqliteValue)> = vec![
(int(1), float(0.5)),
(int(2), float(0.5)),
(int(3), float(0.5)),
(int(4), float(0.5)),
(int(5), float(0.5)),
];
let r = run_agg2(&PercentileDiscFunc, &rows);
match r {
SqliteValue::Float(v) => {
assert!(
[1.0, 2.0, 3.0, 4.0, 5.0].contains(&v),
"expected actual value, got {v}"
);
}
other => {
assert!(
matches!(other, SqliteValue::Float(_)),
"expected Float, got {other:?}"
);
}
}
}
#[test]
fn test_percentile_disc_no_interpolation() {
let rows: Vec<(SqliteValue, SqliteValue)> = vec![
(int(10), float(0.5)),
(int(20), float(0.5)),
(int(30), float(0.5)),
(int(40), float(0.5)),
];
let r = run_agg2(&PercentileDiscFunc, &rows);
match r {
SqliteValue::Float(v) => {
assert!(
[10.0, 20.0, 30.0, 40.0].contains(&v),
"disc must not interpolate: got {v}"
);
}
other => {
assert!(
matches!(other, SqliteValue::Float(_)),
"expected Float, got {other:?}"
);
}
}
}
#[test]
fn test_string_agg_alias() {
let mut reg = FunctionRegistry::new();
register_aggregate_builtins(&mut reg);
let sa = reg
.find_aggregate("string_agg", 2)
.expect("string_agg registered");
let mut state = sa.initial_state();
sa.step(&mut state, &[text("a"), text(",")]).unwrap();
sa.step(&mut state, &[text("b"), text(",")]).unwrap();
let r = sa.finalize(state).unwrap();
assert_eq!(r, SqliteValue::Text("a,b".into()));
}
// group_concat is variadic (1 or 2 args). The zero-argument lookup is
// expected to succeed, but the returned aggregate's finalize must report the
// arity error.
#[test]
fn test_register_aggregate_builtins_all_present() {
let mut reg = FunctionRegistry::new();
register_aggregate_builtins(&mut reg);
let expected = [
("avg", 1),
("count", 0), ("count", 1), ("max", 1),
("min", 1),
("sum", 1),
("total", 1),
("median", 1),
("percentile", 2),
("percentile_cont", 2),
("percentile_disc", 2),
("string_agg", 2),
];
for (name, arity) in expected {
assert!(
reg.find_aggregate(name, arity).is_some(),
"aggregate '{name}/{arity}' not registered"
);
}
assert!(reg.find_aggregate("group_concat", 1).is_some());
assert!(reg.find_aggregate("group_concat", 2).is_some());
let group_concat_zero = reg.find_aggregate("group_concat", 0).unwrap();
let err = group_concat_zero
.finalize(group_concat_zero.initial_state())
.expect_err("group_concat() should reject zero arguments");
assert!(
matches!(&err, FrankenError::FunctionError(message)
if message == "wrong number of arguments to function group_concat()"),
"unexpected error: {err:?}"
);
}
// End-to-end: look aggregates up through the registry and drive them there.
#[test]
fn test_e2e_registry_invoke_aggregates() {
let mut reg = FunctionRegistry::new();
register_aggregate_builtins(&mut reg);
let avg = reg.find_aggregate("avg", 1).unwrap();
let mut state = avg.initial_state();
avg.step(&mut state, &[int(10)]).unwrap();
avg.step(&mut state, &[int(20)]).unwrap();
avg.step(&mut state, &[int(30)]).unwrap();
let r = avg.finalize(state).unwrap();
assert_float_eq(&r, 20.0);
let sum = reg.find_aggregate("sum", 1).unwrap();
let mut state = sum.initial_state();
sum.step(&mut state, &[int(1)]).unwrap();
sum.step(&mut state, &[int(2)]).unwrap();
sum.step(&mut state, &[int(3)]).unwrap();
let r = sum.finalize(state).unwrap();
assert_eq!(r, int(6));
}
}