Skip to main content

fsqlite_func/
aggregate.rs

1//! Aggregate function trait with type-erased state adapter.
2//!
3//! Aggregate functions accumulate a result across multiple rows (e.g.
4//! `SUM`, `COUNT`, `AVG`). Each GROUP BY group gets its own state.
5//!
6//! # Type Erasure
7//!
8//! The [`FunctionRegistry`](crate::FunctionRegistry) stores aggregates as
9//! `Arc<dyn AggregateFunction<State = Box<dyn Any + Send>>>`. Concrete
10//! implementations use [`AggregateAdapter`] to wrap their typed state.
11#![allow(clippy::unnecessary_literal_bound)]
12
13use std::any::Any;
14
15use fsqlite_error::Result;
16use fsqlite_types::SqliteValue;
17
18/// An aggregate SQL function (e.g. `SUM`, `COUNT`, `AVG`).
19///
20/// This trait is **open** (user-implementable). Extension authors implement
21/// this trait to register custom aggregate functions.
22///
23/// # State Lifecycle
24///
25/// 1. [`initial_state`](Self::initial_state) creates a fresh accumulator.
26/// 2. [`step`](Self::step) is called once per row.
27/// 3. [`finalize`](Self::finalize) consumes the state and returns the result.
28///
29/// # Send + Sync
30///
31/// The function object itself is shared across threads via `Arc`. The
32/// `State` type must be `Send` so it can be moved between threads.
33pub trait AggregateFunction: Send + Sync {
34    /// The per-group accumulator type.
35    type State: Send;
36
37    /// Create a fresh accumulator (zero/identity state).
38    fn initial_state(&self) -> Self::State;
39
40    /// Process one row, updating the accumulator.
41    fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()>;
42
43    /// Consume the accumulator and produce the final result.
44    fn finalize(&self, state: Self::State) -> Result<SqliteValue>;
45
46    /// The number of arguments this function accepts (`-1` = variadic).
47    fn num_args(&self) -> i32;
48
49    /// Minimum accepted SQL argument count for variadic functions.
50    ///
51    /// Fixed-arity functions default to their exact arity. Variadic functions
52    /// default to accepting zero arguments unless an implementation tightens
53    /// the bound to match SQLite's function surface.
54    fn min_args(&self) -> i32 {
55        self.num_args().max(0)
56    }
57
58    /// Maximum accepted SQL argument count, or `None` for unbounded variadic
59    /// functions.
60    fn max_args(&self) -> Option<i32> {
61        (self.num_args() >= 0).then(|| self.num_args())
62    }
63
64    /// Return whether this function accepts `num_args` SQL arguments.
65    fn accepts_arg_count(&self, num_args: i32) -> bool {
66        num_args >= self.min_args() && self.max_args().is_none_or(|max| num_args <= max)
67    }
68
69    /// The function name, used in error messages and EXPLAIN output.
70    fn name(&self) -> &str;
71}
72
/// Type-erased adapter that wraps a concrete [`AggregateFunction`] so the
/// registry can store heterogeneous aggregates behind a single trait object.
///
/// The adapter implements `AggregateFunction<State = Box<dyn Any + Send>>`,
/// boxing the concrete state on creation and downcasting on step/finalize.
#[derive(Debug, Clone)]
pub struct AggregateAdapter<F> {
    inner: F,
}

impl<F> AggregateAdapter<F> {
    /// Wrap a concrete aggregate function for type-erased storage.
    #[must_use]
    pub const fn new(inner: F) -> Self {
        Self { inner }
    }

    /// Borrow the wrapped, concretely-typed aggregate function.
    #[must_use]
    pub const fn inner(&self) -> &F {
        &self.inner
    }

    /// Unwrap the adapter, returning the concretely-typed aggregate function.
    #[must_use]
    pub fn into_inner(self) -> F {
        self.inner
    }
}
88
89impl<F> AggregateFunction for AggregateAdapter<F>
90where
91    F: AggregateFunction,
92    F::State: 'static,
93{
94    type State = Box<dyn Any + Send>;
95
96    fn initial_state(&self) -> Self::State {
97        Box::new(self.inner.initial_state())
98    }
99
100    fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
101        let concrete = state
102            .downcast_mut::<F::State>()
103            .expect("aggregate state type mismatch");
104        self.inner.step(concrete, args)
105    }
106
107    fn finalize(&self, state: Self::State) -> Result<SqliteValue> {
108        let concrete = *state
109            .downcast::<F::State>()
110            .expect("aggregate state type mismatch");
111        self.inner.finalize(concrete)
112    }
113
114    fn num_args(&self) -> i32 {
115        self.inner.num_args()
116    }
117
118    fn min_args(&self) -> i32 {
119        self.inner.min_args()
120    }
121
122    fn max_args(&self) -> Option<i32> {
123        self.inner.max_args()
124    }
125
126    fn name(&self) -> &str {
127        self.inner.name()
128    }
129}
130
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::thread;

    use super::*;

    /// Type-erased handle, as stored by the function registry.
    type ErasedAgg = Arc<dyn AggregateFunction<State = Box<dyn Any + Send>>>;

    // -- Mock: Sum aggregate (fixed arity 1) --

    struct SumAgg;

    impl AggregateFunction for SumAgg {
        type State = i64;

        fn initial_state(&self) -> i64 {
            0
        }

        fn step(&self, state: &mut i64, args: &[SqliteValue]) -> Result<()> {
            *state += args[0].to_integer();
            Ok(())
        }

        fn finalize(&self, state: i64) -> Result<SqliteValue> {
            Ok(SqliteValue::Integer(state))
        }

        fn num_args(&self) -> i32 {
            1
        }

        fn name(&self) -> &str {
            "sum"
        }
    }

    // -- Mock: variadic row counter that requires at least one argument --

    struct CountAtLeastOne;

    impl AggregateFunction for CountAtLeastOne {
        type State = i64;

        fn initial_state(&self) -> i64 {
            0
        }

        fn step(&self, state: &mut i64, _args: &[SqliteValue]) -> Result<()> {
            *state += 1;
            Ok(())
        }

        fn finalize(&self, state: i64) -> Result<SqliteValue> {
            Ok(SqliteValue::Integer(state))
        }

        fn num_args(&self) -> i32 {
            -1
        }

        fn min_args(&self) -> i32 {
            1
        }

        fn name(&self) -> &str {
            "count_at_least_one"
        }
    }

    #[test]
    fn test_aggregate_initial_state() {
        let agg = SumAgg;
        assert_eq!(agg.initial_state(), 0);
    }

    #[test]
    fn test_aggregate_step_and_finalize() {
        let agg = SumAgg;
        let mut state = agg.initial_state();

        agg.step(&mut state, &[SqliteValue::Integer(10)]).unwrap();
        agg.step(&mut state, &[SqliteValue::Integer(20)]).unwrap();
        agg.step(&mut state, &[SqliteValue::Integer(12)]).unwrap();

        let result = agg.finalize(state).unwrap();
        assert_eq!(result, SqliteValue::Integer(42));
    }

    #[test]
    fn test_aggregate_type_erasure_adapter() {
        let adapted: AggregateAdapter<SumAgg> = AggregateAdapter::new(SumAgg);
        let erased: ErasedAgg = Arc::new(adapted);

        let mut state = erased.initial_state();
        erased
            .step(&mut state, &[SqliteValue::Integer(10)])
            .unwrap();
        erased
            .step(&mut state, &[SqliteValue::Integer(32)])
            .unwrap();

        let result = erased.finalize(state).unwrap();
        assert_eq!(result, SqliteValue::Integer(42));

        // Cloned handles refer to the same underlying function object.
        let e2 = Arc::clone(&erased);
        assert_eq!(e2.name(), "sum");
    }

    #[test]
    fn test_fixed_arity_defaults() {
        let agg = SumAgg;
        assert_eq!(agg.min_args(), 1);
        assert_eq!(agg.max_args(), Some(1));
        assert!(agg.accepts_arg_count(1));
        assert!(!agg.accepts_arg_count(0));
        assert!(!agg.accepts_arg_count(2));
    }

    #[test]
    fn test_variadic_arity_bounds() {
        let agg = CountAtLeastOne;
        assert_eq!(agg.min_args(), 1);
        assert_eq!(agg.max_args(), None);
        assert!(!agg.accepts_arg_count(0));
        assert!(agg.accepts_arg_count(1));
        assert!(agg.accepts_arg_count(50));
    }

    #[test]
    fn test_adapter_forwards_arity() {
        let erased: ErasedAgg = Arc::new(AggregateAdapter::new(CountAtLeastOne));
        assert_eq!(erased.num_args(), -1);
        assert_eq!(erased.min_args(), 1);
        assert_eq!(erased.max_args(), None);
        assert!(!erased.accepts_arg_count(0));
        assert!(erased.accepts_arg_count(3));
    }

    #[test]
    #[should_panic(expected = "aggregate state type mismatch")]
    fn test_adapter_rejects_foreign_state() {
        let erased: ErasedAgg = Arc::new(AggregateAdapter::new(SumAgg));
        let mut foreign: Box<dyn Any + Send> = Box::new("not an i64");
        let _ = erased.step(&mut foreign, &[SqliteValue::Integer(1)]);
    }

    #[test]
    fn test_erased_state_moves_between_threads() {
        let erased: ErasedAgg = Arc::new(AggregateAdapter::new(SumAgg));
        let worker = Arc::clone(&erased);

        // Start accumulating on this thread, continue on another, finish here.
        let mut state = erased.initial_state();
        erased
            .step(&mut state, &[SqliteValue::Integer(10)])
            .unwrap();
        let state = thread::spawn(move || {
            worker
                .step(&mut state, &[SqliteValue::Integer(32)])
                .unwrap();
            state
        })
        .join()
        .expect("worker thread panicked");

        assert_eq!(erased.finalize(state).unwrap(), SqliteValue::Integer(42));
    }
}