fsqlite_func/
aggregate.rs1#![allow(clippy::unnecessary_literal_bound)]
12
13use std::any::Any;
14
15use fsqlite_error::Result;
16use fsqlite_types::SqliteValue;
17
18pub trait AggregateFunction: Send + Sync {
34 type State: Send;
36
37 fn initial_state(&self) -> Self::State;
39
40 fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()>;
42
43 fn finalize(&self, state: Self::State) -> Result<SqliteValue>;
45
46 fn num_args(&self) -> i32;
48
49 fn min_args(&self) -> i32 {
55 self.num_args().max(0)
56 }
57
58 fn max_args(&self) -> Option<i32> {
61 (self.num_args() >= 0).then(|| self.num_args())
62 }
63
64 fn accepts_arg_count(&self, num_args: i32) -> bool {
66 num_args >= self.min_args() && self.max_args().is_none_or(|max| num_args <= max)
67 }
68
69 fn name(&self) -> &str;
71}
72
/// Wrapper that adapts a concrete aggregate function `F` so its `State` type
/// can be erased (see its `AggregateFunction` impl).
///
/// Derived traits are conditional on `F`: the adapter is `Debug`/`Clone`/
/// `Copy`/`Default` exactly when `F` is.
#[derive(Debug, Clone, Copy, Default)]
pub struct AggregateAdapter<F> {
    /// The wrapped, statically-typed aggregate implementation.
    inner: F,
}

impl<F> AggregateAdapter<F> {
    /// Wraps `inner` without changing its behavior.
    #[must_use]
    pub const fn new(inner: F) -> Self {
        Self { inner }
    }

    /// Borrows the wrapped aggregate implementation.
    #[must_use]
    pub const fn inner(&self) -> &F {
        &self.inner
    }

    /// Unwraps the adapter, returning the wrapped aggregate implementation.
    #[must_use]
    pub fn into_inner(self) -> F {
        self.inner
    }
}
88
89impl<F> AggregateFunction for AggregateAdapter<F>
90where
91 F: AggregateFunction,
92 F::State: 'static,
93{
94 type State = Box<dyn Any + Send>;
95
96 fn initial_state(&self) -> Self::State {
97 Box::new(self.inner.initial_state())
98 }
99
100 fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
101 let concrete = state
102 .downcast_mut::<F::State>()
103 .expect("aggregate state type mismatch");
104 self.inner.step(concrete, args)
105 }
106
107 fn finalize(&self, state: Self::State) -> Result<SqliteValue> {
108 let concrete = *state
109 .downcast::<F::State>()
110 .expect("aggregate state type mismatch");
111 self.inner.finalize(concrete)
112 }
113
114 fn num_args(&self) -> i32 {
115 self.inner.num_args()
116 }
117
118 fn min_args(&self) -> i32 {
119 self.inner.min_args()
120 }
121
122 fn max_args(&self) -> Option<i32> {
123 self.inner.max_args()
124 }
125
126 fn name(&self) -> &str {
127 self.inner.name()
128 }
129}
130
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;

    /// Fixed-arity (`num_args() == 1`) test aggregate that sums the integer
    /// value of its single argument.
    struct SumAgg;

    impl AggregateFunction for SumAgg {
        type State = i64;

        fn initial_state(&self) -> i64 {
            0
        }

        fn step(&self, state: &mut i64, args: &[SqliteValue]) -> Result<()> {
            *state += args[0].to_integer();
            Ok(())
        }

        fn finalize(&self, state: i64) -> Result<SqliteValue> {
            Ok(SqliteValue::Integer(state))
        }

        fn num_args(&self) -> i32 {
            1
        }

        fn name(&self) -> &str {
            "sum"
        }
    }

    /// Variadic (`num_args() == -1`) test aggregate that counts `step` calls
    /// regardless of how many arguments each call receives.
    struct CountAgg;

    impl AggregateFunction for CountAgg {
        type State = i64;

        fn initial_state(&self) -> i64 {
            0
        }

        fn step(&self, state: &mut i64, _args: &[SqliteValue]) -> Result<()> {
            *state += 1;
            Ok(())
        }

        fn finalize(&self, state: i64) -> Result<SqliteValue> {
            Ok(SqliteValue::Integer(state))
        }

        fn num_args(&self) -> i32 {
            -1
        }

        fn name(&self) -> &str {
            "count"
        }
    }

    #[test]
    fn test_aggregate_initial_state() {
        let agg = SumAgg;
        assert_eq!(agg.initial_state(), 0);
    }

    #[test]
    fn test_aggregate_step_and_finalize() {
        let agg = SumAgg;
        let mut state = agg.initial_state();

        agg.step(&mut state, &[SqliteValue::Integer(10)]).unwrap();
        agg.step(&mut state, &[SqliteValue::Integer(20)]).unwrap();
        agg.step(&mut state, &[SqliteValue::Integer(12)]).unwrap();

        let result = agg.finalize(state).unwrap();
        assert_eq!(result, SqliteValue::Integer(42));
    }

    #[test]
    fn test_aggregate_type_erasure_adapter() {
        let adapted: AggregateAdapter<SumAgg> = AggregateAdapter::new(SumAgg);
        let erased: Arc<dyn AggregateFunction<State = Box<dyn Any + Send>>> = Arc::new(adapted);

        let mut state = erased.initial_state();
        erased
            .step(&mut state, &[SqliteValue::Integer(10)])
            .unwrap();
        erased
            .step(&mut state, &[SqliteValue::Integer(32)])
            .unwrap();

        let result = erased.finalize(state).unwrap();
        assert_eq!(result, SqliteValue::Integer(42));

        let e2 = Arc::clone(&erased);
        assert_eq!(e2.name(), "sum");
    }

    #[test]
    fn test_fixed_arity_arg_count_defaults() {
        let agg = SumAgg;
        assert_eq!(agg.min_args(), 1);
        assert_eq!(agg.max_args(), Some(1));
        assert!(agg.accepts_arg_count(1));
        assert!(!agg.accepts_arg_count(0));
        assert!(!agg.accepts_arg_count(2));
    }

    #[test]
    fn test_variadic_arg_count_defaults() {
        let agg = CountAgg;
        assert_eq!(agg.min_args(), 0);
        assert_eq!(agg.max_args(), None);
        assert!(agg.accepts_arg_count(0));
        assert!(agg.accepts_arg_count(100));
        assert!(!agg.accepts_arg_count(-1));
    }

    #[test]
    fn test_adapter_forwards_arity_and_name() {
        let adapted = AggregateAdapter::new(CountAgg);
        assert_eq!(adapted.num_args(), -1);
        assert_eq!(adapted.min_args(), 0);
        assert_eq!(adapted.max_args(), None);
        assert!(adapted.accepts_arg_count(3));
        assert!(!adapted.accepts_arg_count(-1));
        assert_eq!(adapted.name(), "count");

        let mut state = adapted.initial_state();
        adapted.step(&mut state, &[]).unwrap();
        adapted
            .step(
                &mut state,
                &[SqliteValue::Integer(1), SqliteValue::Integer(2)],
            )
            .unwrap();
        assert_eq!(adapted.finalize(state).unwrap(), SqliteValue::Integer(2));
    }
}