fsqlite_func/
aggregate.rs1#![allow(clippy::unnecessary_literal_bound)]
12
13use std::any::Any;
14
15use fsqlite_error::Result;
16use fsqlite_types::SqliteValue;
17
18pub trait AggregateFunction: Send + Sync {
34 type State: Send;
36
37 fn initial_state(&self) -> Self::State;
39
40 fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()>;
42
43 fn finalize(&self, state: Self::State) -> Result<SqliteValue>;
45
46 fn num_args(&self) -> i32;
48
49 fn name(&self) -> &str;
51}
52
/// Type-erasing wrapper around an [`AggregateFunction`].
///
/// The adapter re-exposes the wrapped aggregate with `State = Box<dyn Any + Send>`.
/// Aggregates with different concrete accumulator types can therefore all be used as
/// `dyn AggregateFunction<State = Box<dyn Any + Send>>`.
///
/// `Debug` and `Clone` are derived and are available whenever `F` implements them.
#[derive(Debug, Clone)]
pub struct AggregateAdapter<F> {
    /// The wrapped aggregate that performs the actual computation.
    inner: F,
}

impl<F> AggregateAdapter<F> {
    /// Wraps `inner` so it can be driven through the type-erased state interface.
    ///
    /// This is a `const fn`, so adapters can be built in `const`/`static` contexts.
    #[must_use]
    pub const fn new(inner: F) -> Self {
        Self { inner }
    }
}
68
69impl<F> AggregateFunction for AggregateAdapter<F>
70where
71 F: AggregateFunction,
72 F::State: 'static,
73{
74 type State = Box<dyn Any + Send>;
75
76 fn initial_state(&self) -> Self::State {
77 Box::new(self.inner.initial_state())
78 }
79
80 fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
81 let concrete = state
82 .downcast_mut::<F::State>()
83 .expect("aggregate state type mismatch");
84 self.inner.step(concrete, args)
85 }
86
87 fn finalize(&self, state: Self::State) -> Result<SqliteValue> {
88 let concrete = *state
89 .downcast::<F::State>()
90 .expect("aggregate state type mismatch");
91 self.inner.finalize(concrete)
92 }
93
94 fn num_args(&self) -> i32 {
95 self.inner.num_args()
96 }
97
98 fn name(&self) -> &str {
99 self.inner.name()
100 }
101}
102
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;

    /// Minimal integer `sum()` used to exercise the trait and the adapter.
    struct SumAgg;

    impl AggregateFunction for SumAgg {
        type State = i64;

        fn initial_state(&self) -> i64 {
            0
        }

        fn step(&self, state: &mut i64, args: &[SqliteValue]) -> Result<()> {
            // Every test below passes exactly one argument, so indexing is safe here.
            *state += args[0].to_integer();
            Ok(())
        }

        fn finalize(&self, state: i64) -> Result<SqliteValue> {
            Ok(SqliteValue::Integer(state))
        }

        fn num_args(&self) -> i32 {
            1
        }

        fn name(&self) -> &str {
            "sum"
        }
    }

    #[test]
    fn test_aggregate_initial_state() {
        let agg = SumAgg;
        assert_eq!(agg.initial_state(), 0);
    }

    #[test]
    fn test_aggregate_step_and_finalize() {
        let agg = SumAgg;
        let mut state = agg.initial_state();

        agg.step(&mut state, &[SqliteValue::Integer(10)]).unwrap();
        agg.step(&mut state, &[SqliteValue::Integer(20)]).unwrap();
        agg.step(&mut state, &[SqliteValue::Integer(12)]).unwrap();

        let result = agg.finalize(state).unwrap();
        assert_eq!(result, SqliteValue::Integer(42));
    }

    #[test]
    fn test_aggregate_type_erasure_adapter() {
        let adapted: AggregateAdapter<SumAgg> = AggregateAdapter::new(SumAgg);
        let erased: Arc<dyn AggregateFunction<State = Box<dyn Any + Send>>> = Arc::new(adapted);

        let mut state = erased.initial_state();
        erased
            .step(&mut state, &[SqliteValue::Integer(10)])
            .unwrap();
        erased
            .step(&mut state, &[SqliteValue::Integer(32)])
            .unwrap();

        let result = erased.finalize(state).unwrap();
        assert_eq!(result, SqliteValue::Integer(42));

        // A second handle to the same erased aggregate still forwards metadata.
        let e2 = Arc::clone(&erased);
        assert_eq!(e2.name(), "sum");
        assert_eq!(e2.num_args(), 1);
    }

    /// The adapter must refuse a boxed state whose concrete type is not `SumAgg::State`.
    #[test]
    #[should_panic(expected = "aggregate state type mismatch")]
    fn test_adapter_step_rejects_foreign_state() {
        let adapted = AggregateAdapter::new(SumAgg);
        let mut foreign: Box<dyn Any + Send> = Box::new("not an i64");
        let _ = adapted.step(&mut foreign, &[SqliteValue::Integer(1)]);
    }

    /// Same invariant as above, checked on the consuming `finalize` path.
    #[test]
    #[should_panic(expected = "aggregate state type mismatch")]
    fn test_adapter_finalize_rejects_foreign_state() {
        let adapted = AggregateAdapter::new(SumAgg);
        let foreign: Box<dyn Any + Send> = Box::new(1.5_f64);
        let _ = adapted.finalize(foreign);
    }
}