#![allow(clippy::unnecessary_literal_bound)]
use std::any::Any;
use fsqlite_error::Result;
use fsqlite_types::SqliteValue;
/// An aggregate function that folds rows of arguments into a per-group
/// accumulator and then turns that accumulator into a single [`SqliteValue`].
///
/// Implementors must be `Send + Sync`, and their accumulator type must be
/// `Send`.
pub trait AggregateFunction: Send + Sync {
    /// Accumulator carried between [`Self::step`] calls and consumed by
    /// [`Self::finalize`].
    type State: Send;

    /// Produces a fresh, empty accumulator.
    fn initial_state(&self) -> Self::State;

    /// Folds one set of arguments into `state`.
    ///
    /// # Errors
    ///
    /// Implementation-defined; the error is propagated to the caller unchanged.
    fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()>;

    /// Consumes the accumulator and produces the aggregate's result value.
    ///
    /// # Errors
    ///
    /// Implementation-defined; the error is propagated to the caller unchanged.
    fn finalize(&self, state: Self::State) -> Result<SqliteValue>;

    /// Declared argument count.
    ///
    /// The default [`Self::min_args`] / [`Self::max_args`] interpret a
    /// negative value as "any number of arguments".
    fn num_args(&self) -> i32;

    /// Smallest accepted argument count.
    ///
    /// Defaults to the declared count, clamped up to zero for negative
    /// (variadic) declarations.
    fn min_args(&self) -> i32 {
        let declared = self.num_args();
        if declared < 0 {
            0
        } else {
            declared
        }
    }

    /// Largest accepted argument count, or `None` when unbounded.
    ///
    /// Defaults to `Some(declared)` for a non-negative declaration and `None`
    /// for a negative (variadic) one.
    fn max_args(&self) -> Option<i32> {
        let declared = self.num_args();
        if declared >= 0 {
            Some(declared)
        } else {
            None
        }
    }

    /// Returns `true` when `num_args` lies within
    /// `[min_args(), max_args()]`, where an absent upper bound means
    /// "unbounded".
    fn accepts_arg_count(&self, num_args: i32) -> bool {
        if num_args < self.min_args() {
            return false;
        }
        match self.max_args() {
            Some(upper) => num_args <= upper,
            None => true,
        }
    }

    /// Identifier of this aggregate (e.g. `"sum"`); presumably the name it is
    /// registered/invoked under — confirm against the registry.
    fn name(&self) -> &str;
}
/// Wrapper around an [`AggregateFunction`] whose `AggregateFunction` impl
/// boxes the inner state as `Box<dyn Any + Send>`, so aggregates with
/// different concrete state types can share one trait-object type.
#[derive(Debug, Clone)]
pub struct AggregateAdapter<F> {
    inner: F,
}

impl<F> AggregateAdapter<F> {
    /// Wraps `inner`.
    #[must_use]
    pub const fn new(inner: F) -> Self {
        Self { inner }
    }

    /// Returns a shared reference to the wrapped aggregate.
    #[must_use]
    pub const fn inner(&self) -> &F {
        &self.inner
    }

    /// Consumes the adapter and returns the wrapped aggregate.
    #[must_use]
    pub fn into_inner(self) -> F {
        self.inner
    }
}
/// Type-erasing forwarding impl: every call is delegated to the wrapped
/// aggregate, with its concrete state boxed as `Box<dyn Any + Send>`.
impl<F> AggregateFunction for AggregateAdapter<F>
where
    F: AggregateFunction,
    // `Any` (and therefore the downcasts below) only works for `'static` types.
    F::State: 'static,
{
    type State = Box<dyn Any + Send>;

    fn initial_state(&self) -> Self::State {
        Box::new(self.inner.initial_state())
    }

    /// Downcasts `state` to the inner state type and forwards to the wrapped
    /// aggregate's `step`.
    ///
    /// # Panics
    ///
    /// Panics if `state` does not hold an `F::State`, i.e. it was not created
    /// by an adapter around the same aggregate type.
    fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
        let concrete = state
            .downcast_mut::<F::State>()
            .unwrap_or_else(|| state_type_mismatch::<F::State>(self.inner.name()));
        self.inner.step(concrete, args)
    }

    /// Unboxes `state` to the inner state type and forwards to the wrapped
    /// aggregate's `finalize`.
    ///
    /// # Panics
    ///
    /// Panics if `state` does not hold an `F::State`, i.e. it was not created
    /// by an adapter around the same aggregate type.
    fn finalize(&self, state: Self::State) -> Result<SqliteValue> {
        let concrete = *state
            .downcast::<F::State>()
            .unwrap_or_else(|_| state_type_mismatch::<F::State>(self.inner.name()));
        self.inner.finalize(concrete)
    }

    fn num_args(&self) -> i32 {
        self.inner.num_args()
    }

    fn min_args(&self) -> i32 {
        self.inner.min_args()
    }

    fn max_args(&self) -> Option<i32> {
        self.inner.max_args()
    }

    // Forwarded explicitly so an inner aggregate that overrides
    // `accepts_arg_count` keeps its custom rule after being wrapped; the trait
    // default would recompute it from `min_args`/`max_args` only.
    fn accepts_arg_count(&self, num_args: i32) -> bool {
        self.inner.accepts_arg_count(num_args)
    }

    fn name(&self) -> &str {
        self.inner.name()
    }
}

/// Panics with a message naming the aggregate and the state type it expected.
/// Used when a type-erased state handed to an [`AggregateAdapter`] holds
/// something other than the wrapped aggregate's state.
#[cold]
fn state_type_mismatch<S>(aggregate: &str) -> ! {
    panic!(
        "aggregate state type mismatch: `{aggregate}` expected state of type `{}`",
        std::any::type_name::<S>()
    )
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;

    // Fixed-arity (exactly one argument) integer sum.
    struct SumAgg;

    impl AggregateFunction for SumAgg {
        type State = i64;

        fn initial_state(&self) -> i64 {
            0
        }

        fn step(&self, state: &mut i64, args: &[SqliteValue]) -> Result<()> {
            *state += args[0].to_integer();
            Ok(())
        }

        fn finalize(&self, state: i64) -> Result<SqliteValue> {
            Ok(SqliteValue::Integer(state))
        }

        fn num_args(&self) -> i32 {
            1
        }

        fn name(&self) -> &str {
            "sum"
        }
    }

    // Variadic (negative `num_args`) row counter relying on the default
    // arity helpers.
    struct CountAgg;

    impl AggregateFunction for CountAgg {
        type State = i64;

        fn initial_state(&self) -> i64 {
            0
        }

        fn step(&self, state: &mut i64, _args: &[SqliteValue]) -> Result<()> {
            *state += 1;
            Ok(())
        }

        fn finalize(&self, state: i64) -> Result<SqliteValue> {
            Ok(SqliteValue::Integer(state))
        }

        fn num_args(&self) -> i32 {
            -1
        }

        fn name(&self) -> &str {
            "count"
        }
    }

    #[test]
    fn test_aggregate_initial_state() {
        let agg = SumAgg;
        assert_eq!(agg.initial_state(), 0);
    }

    #[test]
    fn test_aggregate_step_and_finalize() {
        let agg = SumAgg;
        let mut state = agg.initial_state();
        agg.step(&mut state, &[SqliteValue::Integer(10)]).unwrap();
        agg.step(&mut state, &[SqliteValue::Integer(20)]).unwrap();
        agg.step(&mut state, &[SqliteValue::Integer(12)]).unwrap();
        let result = agg.finalize(state).unwrap();
        assert_eq!(result, SqliteValue::Integer(42));
    }

    #[test]
    fn test_aggregate_type_erasure_adapter() {
        let adapted: AggregateAdapter<SumAgg> = AggregateAdapter::new(SumAgg);
        let erased: Arc<dyn AggregateFunction<State = Box<dyn Any + Send>>> = Arc::new(adapted);
        let mut state = erased.initial_state();
        erased
            .step(&mut state, &[SqliteValue::Integer(10)])
            .unwrap();
        erased
            .step(&mut state, &[SqliteValue::Integer(32)])
            .unwrap();
        let result = erased.finalize(state).unwrap();
        assert_eq!(result, SqliteValue::Integer(42));
        let e2 = Arc::clone(&erased);
        assert_eq!(e2.name(), "sum");
    }

    #[test]
    fn test_fixed_arity_defaults() {
        let agg = SumAgg;
        assert_eq!(agg.min_args(), 1);
        assert_eq!(agg.max_args(), Some(1));
        assert!(agg.accepts_arg_count(1));
        assert!(!agg.accepts_arg_count(0));
        assert!(!agg.accepts_arg_count(2));
    }

    #[test]
    fn test_variadic_arity_defaults() {
        let agg = CountAgg;
        assert_eq!(agg.min_args(), 0);
        assert_eq!(agg.max_args(), None);
        assert!(agg.accepts_arg_count(0));
        assert!(agg.accepts_arg_count(5));
        assert!(!agg.accepts_arg_count(-1));
    }

    #[test]
    fn test_adapter_forwards_metadata() {
        let adapted = AggregateAdapter::new(CountAgg);
        assert_eq!(adapted.num_args(), -1);
        assert_eq!(adapted.min_args(), 0);
        assert_eq!(adapted.max_args(), None);
        assert!(adapted.accepts_arg_count(3));
        assert_eq!(adapted.name(), "count");

        let mut state = adapted.initial_state();
        adapted.step(&mut state, &[]).unwrap();
        adapted.step(&mut state, &[SqliteValue::Integer(7)]).unwrap();
        assert_eq!(adapted.finalize(state).unwrap(), SqliteValue::Integer(2));
    }

    #[test]
    #[should_panic(expected = "aggregate state type mismatch")]
    fn test_adapter_rejects_foreign_state() {
        let adapted = AggregateAdapter::new(SumAgg);
        // A state that was not produced by `SumAgg`'s adapter.
        let mut foreign: Box<dyn Any + Send> = Box::new("not an i64");
        let _ = adapted.step(&mut foreign, &[SqliteValue::Integer(1)]);
    }
}