Skip to main content

fsqlite_func/
aggregate.rs

//! Aggregate function trait with a type-erased state adapter.
//!
//! Aggregate functions accumulate a result across multiple rows (e.g.
//! `SUM`, `COUNT`, `AVG`). Each `GROUP BY` group gets its own state.
//!
//! # Type Erasure
//!
//! The [`FunctionRegistry`](crate::FunctionRegistry) stores aggregates as
//! `Arc<dyn AggregateFunction<State = Box<dyn Any + Send>>>`. Concrete
//! implementations use [`AggregateAdapter`] to wrap their typed state.
// Silences `clippy::unnecessary_literal_bound` for this module. Presumably
// needed for `name()` impls that return a string literal through the `&str`
// signature of `AggregateFunction::name` (e.g. the `SumAgg` test mock).
// NOTE(review): confirm which items trigger it; consider scoping the allow to
// `mod tests` if only the mocks do.
#![allow(clippy::unnecessary_literal_bound)]
12
13use std::any::Any;
14
15use fsqlite_error::Result;
16use fsqlite_types::SqliteValue;
17
18/// An aggregate SQL function (e.g. `SUM`, `COUNT`, `AVG`).
19///
20/// This trait is **open** (user-implementable). Extension authors implement
21/// this trait to register custom aggregate functions.
22///
23/// # State Lifecycle
24///
25/// 1. [`initial_state`](Self::initial_state) creates a fresh accumulator.
26/// 2. [`step`](Self::step) is called once per row.
27/// 3. [`finalize`](Self::finalize) consumes the state and returns the result.
28///
29/// # Send + Sync
30///
31/// The function object itself is shared across threads via `Arc`. The
32/// `State` type must be `Send` so it can be moved between threads.
33pub trait AggregateFunction: Send + Sync {
34    /// The per-group accumulator type.
35    type State: Send;
36
37    /// Create a fresh accumulator (zero/identity state).
38    fn initial_state(&self) -> Self::State;
39
40    /// Process one row, updating the accumulator.
41    fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()>;
42
43    /// Consume the accumulator and produce the final result.
44    fn finalize(&self, state: Self::State) -> Result<SqliteValue>;
45
46    /// The number of arguments this function accepts (`-1` = variadic).
47    fn num_args(&self) -> i32;
48
49    /// The function name, used in error messages and EXPLAIN output.
50    fn name(&self) -> &str;
51}
52
/// Type-erased adapter that wraps a concrete [`AggregateFunction`] so the
/// registry can store heterogeneous aggregates behind a single trait object.
///
/// The adapter implements `AggregateFunction<State = Box<dyn Any + Send>>`,
/// boxing the concrete state on creation and downcasting on step/finalize.
///
/// `Debug`, `Clone` and `Default` are derived and are available whenever the
/// wrapped function type implements them.
#[derive(Debug, Clone, Default)]
pub struct AggregateAdapter<F> {
    inner: F,
}

impl<F> AggregateAdapter<F> {
    /// Wrap a concrete aggregate function for type-erased storage.
    #[must_use]
    pub const fn new(inner: F) -> Self {
        Self { inner }
    }

    /// Borrow the wrapped, concrete aggregate function.
    #[must_use]
    pub const fn inner(&self) -> &F {
        &self.inner
    }

    /// Consume the adapter and return the wrapped aggregate function.
    #[must_use]
    pub fn into_inner(self) -> F {
        self.inner
    }
}
68
69impl<F> AggregateFunction for AggregateAdapter<F>
70where
71    F: AggregateFunction,
72    F::State: 'static,
73{
74    type State = Box<dyn Any + Send>;
75
76    fn initial_state(&self) -> Self::State {
77        Box::new(self.inner.initial_state())
78    }
79
80    fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
81        let concrete = state
82            .downcast_mut::<F::State>()
83            .expect("aggregate state type mismatch");
84        self.inner.step(concrete, args)
85    }
86
87    fn finalize(&self, state: Self::State) -> Result<SqliteValue> {
88        let concrete = *state
89            .downcast::<F::State>()
90            .expect("aggregate state type mismatch");
91        self.inner.finalize(concrete)
92    }
93
94    fn num_args(&self) -> i32 {
95        self.inner.num_args()
96    }
97
98    fn name(&self) -> &str {
99        self.inner.name()
100    }
101}
102
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::thread;

    use super::*;

    /// Erased aggregate handle, in the shape the function registry stores.
    type ErasedAgg = Arc<dyn AggregateFunction<State = Box<dyn Any + Send>>>;

    // -- Mock: Sum aggregate (i64 state) --

    struct SumAgg;

    impl AggregateFunction for SumAgg {
        type State = i64;

        fn initial_state(&self) -> i64 {
            0
        }

        fn step(&self, state: &mut i64, args: &[SqliteValue]) -> Result<()> {
            *state += args[0].to_integer();
            Ok(())
        }

        fn finalize(&self, state: i64) -> Result<SqliteValue> {
            Ok(SqliteValue::Integer(state))
        }

        fn num_args(&self) -> i32 {
            1
        }

        fn name(&self) -> &str {
            "sum"
        }
    }

    // -- Mock: Count aggregate (u32 state, deliberately unlike SumAgg's) --

    struct CountAgg;

    impl AggregateFunction for CountAgg {
        type State = u32;

        fn initial_state(&self) -> u32 {
            0
        }

        fn step(&self, state: &mut u32, _args: &[SqliteValue]) -> Result<()> {
            *state += 1;
            Ok(())
        }

        fn finalize(&self, state: u32) -> Result<SqliteValue> {
            Ok(SqliteValue::Integer(i64::from(state)))
        }

        fn num_args(&self) -> i32 {
            -1
        }

        fn name(&self) -> &str {
            "count"
        }
    }

    #[test]
    fn test_aggregate_initial_state() {
        let agg = SumAgg;
        assert_eq!(agg.initial_state(), 0);
    }

    #[test]
    fn test_aggregate_step_and_finalize() {
        let agg = SumAgg;
        let mut state = agg.initial_state();

        agg.step(&mut state, &[SqliteValue::Integer(10)]).unwrap();
        agg.step(&mut state, &[SqliteValue::Integer(20)]).unwrap();
        agg.step(&mut state, &[SqliteValue::Integer(12)]).unwrap();

        let result = agg.finalize(state).unwrap();
        assert_eq!(result, SqliteValue::Integer(42));
    }

    #[test]
    fn test_aggregate_type_erasure_adapter() {
        let adapted: AggregateAdapter<SumAgg> = AggregateAdapter::new(SumAgg);
        let erased: ErasedAgg = Arc::new(adapted);

        let mut state = erased.initial_state();
        erased
            .step(&mut state, &[SqliteValue::Integer(10)])
            .unwrap();
        erased
            .step(&mut state, &[SqliteValue::Integer(32)])
            .unwrap();

        let result = erased.finalize(state).unwrap();
        assert_eq!(result, SqliteValue::Integer(42));

        // Cloned handles refer to the same function object; metadata is
        // forwarded from the wrapped aggregate.
        let e2 = Arc::clone(&erased);
        assert_eq!(e2.name(), "sum");
        assert_eq!(e2.num_args(), 1);
    }

    #[test]
    fn test_aggregate_adapter_finalize_without_rows() {
        // An untouched state (e.g. an empty group) finalizes to the identity.
        let erased: ErasedAgg = Arc::new(AggregateAdapter::new(SumAgg));
        let state = erased.initial_state();
        assert_eq!(erased.finalize(state).unwrap(), SqliteValue::Integer(0));
    }

    #[test]
    fn test_aggregate_adapter_states_are_independent() {
        // Each group owns its own state, even when steps are interleaved.
        let erased: ErasedAgg = Arc::new(AggregateAdapter::new(SumAgg));
        let mut group_a = erased.initial_state();
        let mut group_b = erased.initial_state();

        erased.step(&mut group_a, &[SqliteValue::Integer(1)]).unwrap();
        erased
            .step(&mut group_b, &[SqliteValue::Integer(100)])
            .unwrap();
        erased.step(&mut group_a, &[SqliteValue::Integer(2)]).unwrap();

        assert_eq!(erased.finalize(group_a).unwrap(), SqliteValue::Integer(3));
        assert_eq!(
            erased.finalize(group_b).unwrap(),
            SqliteValue::Integer(100)
        );
    }

    #[test]
    fn test_aggregate_adapter_shared_across_threads() {
        let erased: ErasedAgg = Arc::new(AggregateAdapter::new(SumAgg));

        // Start accumulating on another thread, then hand the state back.
        let worker = Arc::clone(&erased);
        let mut state = thread::spawn(move || {
            let mut state = worker.initial_state();
            worker
                .step(&mut state, &[SqliteValue::Integer(40)])
                .unwrap();
            state
        })
        .join()
        .unwrap();

        erased.step(&mut state, &[SqliteValue::Integer(2)]).unwrap();
        assert_eq!(erased.finalize(state).unwrap(), SqliteValue::Integer(42));
    }

    #[test]
    #[should_panic(expected = "aggregate state type mismatch")]
    fn test_aggregate_adapter_rejects_foreign_state() {
        let sum: ErasedAgg = Arc::new(AggregateAdapter::new(SumAgg));
        let count: ErasedAgg = Arc::new(AggregateAdapter::new(CountAgg));

        // A u32 state from COUNT must not be accepted by SUM's adapter.
        let mut foreign = count.initial_state();
        let _ = sum.step(&mut foreign, &[SqliteValue::Integer(1)]);
    }
}