1#![allow(clippy::unnecessary_literal_bound)]
8
9use std::any::Any;
10
11use fsqlite_error::Result;
12use fsqlite_types::SqliteValue;
13
14pub trait WindowFunction: Send + Sync {
29 type State: Send;
31
32 fn initial_state(&self) -> Self::State;
34
35 fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()>;
37
38 fn inverse(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()>;
44
45 fn value(&self, state: &Self::State) -> Result<SqliteValue>;
50
51 fn finalize(&self, state: Self::State) -> Result<SqliteValue>;
53
54 fn num_args(&self) -> i32;
56
57 fn min_args(&self) -> i32 {
63 self.num_args().max(0)
64 }
65
66 fn max_args(&self) -> Option<i32> {
69 (self.num_args() >= 0).then(|| self.num_args())
70 }
71
72 fn accepts_arg_count(&self, num_args: i32) -> bool {
74 num_args >= self.min_args() && self.max_args().is_none_or(|max| num_args <= max)
75 }
76
77 fn name(&self) -> &str;
79}
80
/// Wraps a [`WindowFunction`] so that its concrete `State` type is hidden
/// behind `Box<dyn Any + Send>`.
///
/// Every `WindowAdapter<F>` exposes the same `State` type regardless of `F`,
/// while still delegating all work to the wrapped function.
#[derive(Debug, Clone)]
pub struct WindowAdapter<F> {
    /// The wrapped, statically-typed window function.
    inner: F,
}

impl<F> WindowAdapter<F> {
    /// Wraps `inner` in a type-erasing adapter.
    #[must_use]
    pub const fn new(inner: F) -> Self {
        Self { inner }
    }
}
93
94impl<F> WindowFunction for WindowAdapter<F>
95where
96 F: WindowFunction,
97 F::State: 'static,
98{
99 type State = Box<dyn Any + Send>;
100
101 fn initial_state(&self) -> Self::State {
102 Box::new(self.inner.initial_state())
103 }
104
105 fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
106 let concrete = state
107 .downcast_mut::<F::State>()
108 .expect("window state type mismatch");
109 self.inner.step(concrete, args)
110 }
111
112 fn inverse(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
113 let concrete = state
114 .downcast_mut::<F::State>()
115 .expect("window state type mismatch");
116 self.inner.inverse(concrete, args)
117 }
118
119 fn value(&self, state: &Self::State) -> Result<SqliteValue> {
120 let concrete = state
121 .downcast_ref::<F::State>()
122 .expect("window state type mismatch");
123 self.inner.value(concrete)
124 }
125
126 fn finalize(&self, state: Self::State) -> Result<SqliteValue> {
127 let concrete = *state
128 .downcast::<F::State>()
129 .expect("window state type mismatch");
130 self.inner.finalize(concrete)
131 }
132
133 fn num_args(&self) -> i32 {
134 self.inner.num_args()
135 }
136
137 fn min_args(&self) -> i32 {
138 self.inner.min_args()
139 }
140
141 fn max_args(&self) -> Option<i32> {
142 self.inner.max_args()
143 }
144
145 fn name(&self) -> &str {
146 self.inner.name()
147 }
148}
149
#[cfg(test)]
mod tests {
    use super::*;

    /// Sums its single integer argument; `inverse` subtracts it again.
    struct WindowSum;

    impl WindowFunction for WindowSum {
        type State = i64;

        fn initial_state(&self) -> i64 {
            0
        }

        fn step(&self, state: &mut i64, args: &[SqliteValue]) -> Result<()> {
            *state += args[0].to_integer();
            Ok(())
        }

        fn inverse(&self, state: &mut i64, args: &[SqliteValue]) -> Result<()> {
            *state -= args[0].to_integer();
            Ok(())
        }

        fn value(&self, state: &i64) -> Result<SqliteValue> {
            Ok(SqliteValue::Integer(*state))
        }

        fn finalize(&self, state: i64) -> Result<SqliteValue> {
            Ok(SqliteValue::Integer(state))
        }

        fn num_args(&self) -> i32 {
            1
        }

        fn name(&self) -> &str {
            "window_sum"
        }
    }

    /// Counts `step` calls and accepts any number of arguments
    /// (`num_args() == -1`), to exercise the unbounded arity defaults.
    struct WindowCount;

    impl WindowFunction for WindowCount {
        type State = i64;

        fn initial_state(&self) -> i64 {
            0
        }

        fn step(&self, state: &mut i64, _args: &[SqliteValue]) -> Result<()> {
            *state += 1;
            Ok(())
        }

        fn inverse(&self, state: &mut i64, _args: &[SqliteValue]) -> Result<()> {
            *state -= 1;
            Ok(())
        }

        fn value(&self, state: &i64) -> Result<SqliteValue> {
            Ok(SqliteValue::Integer(*state))
        }

        fn finalize(&self, state: i64) -> Result<SqliteValue> {
            Ok(SqliteValue::Integer(state))
        }

        fn num_args(&self) -> i32 {
            -1
        }

        fn name(&self) -> &str {
            "window_count"
        }
    }

    #[test]
    fn test_window_function_step_and_inverse() {
        let f = WindowSum;
        let mut state = f.initial_state();

        f.step(&mut state, &[SqliteValue::Integer(10)]).unwrap();
        f.step(&mut state, &[SqliteValue::Integer(20)]).unwrap();
        f.step(&mut state, &[SqliteValue::Integer(30)]).unwrap();
        assert_eq!(f.value(&state).unwrap(), SqliteValue::Integer(60));

        f.inverse(&mut state, &[SqliteValue::Integer(10)]).unwrap();
        f.step(&mut state, &[SqliteValue::Integer(40)]).unwrap();
        assert_eq!(f.value(&state).unwrap(), SqliteValue::Integer(90));

        f.inverse(&mut state, &[SqliteValue::Integer(20)]).unwrap();
        assert_eq!(f.value(&state).unwrap(), SqliteValue::Integer(70));
    }

    #[test]
    fn test_window_function_value_without_consuming() {
        let f = WindowSum;
        let mut state = f.initial_state();

        f.step(&mut state, &[SqliteValue::Integer(42)]).unwrap();

        // Repeated `value` calls must not disturb the accumulator.
        assert_eq!(f.value(&state).unwrap(), SqliteValue::Integer(42));
        assert_eq!(f.value(&state).unwrap(), SqliteValue::Integer(42));
        assert_eq!(f.value(&state).unwrap(), SqliteValue::Integer(42));

        f.step(&mut state, &[SqliteValue::Integer(8)]).unwrap();
        assert_eq!(f.value(&state).unwrap(), SqliteValue::Integer(50));
    }

    #[test]
    fn test_window_function_finalize_consumes() {
        let f = WindowSum;
        let mut state = f.initial_state();

        f.step(&mut state, &[SqliteValue::Integer(10)]).unwrap();
        f.step(&mut state, &[SqliteValue::Integer(32)]).unwrap();

        let result = f.finalize(state).unwrap();
        assert_eq!(result, SqliteValue::Integer(42));
    }

    #[test]
    fn test_default_arity_fixed() {
        let f = WindowSum;
        assert_eq!(f.min_args(), 1);
        assert_eq!(f.max_args(), Some(1));
        assert!(f.accepts_arg_count(1));
        assert!(!f.accepts_arg_count(0));
        assert!(!f.accepts_arg_count(2));
    }

    #[test]
    fn test_default_arity_unbounded() {
        let f = WindowCount;
        assert_eq!(f.min_args(), 0);
        assert_eq!(f.max_args(), None);
        assert!(f.accepts_arg_count(0));
        assert!(f.accepts_arg_count(100));
        assert!(!f.accepts_arg_count(-1));
    }

    #[test]
    fn test_window_adapter_delegates() {
        let f = WindowAdapter::new(WindowSum);
        let mut state = f.initial_state();

        f.step(&mut state, &[SqliteValue::Integer(10)]).unwrap();
        f.step(&mut state, &[SqliteValue::Integer(20)]).unwrap();
        f.inverse(&mut state, &[SqliteValue::Integer(10)]).unwrap();
        assert_eq!(f.value(&state).unwrap(), SqliteValue::Integer(20));
        assert_eq!(f.finalize(state).unwrap(), SqliteValue::Integer(20));

        assert_eq!(f.name(), "window_sum");
        assert_eq!(f.num_args(), 1);
        assert_eq!(f.min_args(), 1);
        assert_eq!(f.max_args(), Some(1));
        assert!(f.accepts_arg_count(1));
        assert!(!f.accepts_arg_count(2));
    }

    #[test]
    #[should_panic(expected = "window state type mismatch")]
    fn test_window_adapter_rejects_foreign_state() {
        let f = WindowAdapter::new(WindowSum);
        // A boxed `u8` is not the `i64` state that `WindowSum` expects.
        let mut foreign: Box<dyn Any + Send> = Box::new(0_u8);
        let _ = f.step(&mut foreign, &[SqliteValue::Integer(1)]);
    }
}