1#![allow(clippy::unnecessary_literal_bound)]
8
9use std::any::Any;
10
11use fsqlite_error::Result;
12use fsqlite_types::SqliteValue;
13
14pub trait WindowFunction: Send + Sync {
29 type State: Send;
31
32 fn initial_state(&self) -> Self::State;
34
35 fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()>;
37
38 fn inverse(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()>;
44
45 fn value(&self, state: &Self::State) -> Result<SqliteValue>;
50
51 fn finalize(&self, state: Self::State) -> Result<SqliteValue>;
53
54 fn num_args(&self) -> i32;
56
57 fn name(&self) -> &str;
59}
60
/// Wrapper around a window function `F`.
///
/// Its `WindowFunction` implementation exposes `F` with a type-erased state
/// (`Box<dyn Any + Send>`), so functions with different concrete state types
/// can share one state type.
#[derive(Debug, Clone)]
pub struct WindowAdapter<F> {
    /// The wrapped window function that every call is delegated to.
    inner: F,
}

impl<F> WindowAdapter<F> {
    /// Wraps `inner` without any other work, so it is usable in `const` contexts.
    #[must_use]
    pub const fn new(inner: F) -> Self {
        Self { inner }
    }

    /// Borrows the wrapped function.
    #[must_use]
    pub const fn inner(&self) -> &F {
        &self.inner
    }

    /// Unwraps the adapter and returns the wrapped function.
    #[must_use]
    pub fn into_inner(self) -> F {
        self.inner
    }
}
73
74impl<F> WindowFunction for WindowAdapter<F>
75where
76 F: WindowFunction,
77 F::State: 'static,
78{
79 type State = Box<dyn Any + Send>;
80
81 fn initial_state(&self) -> Self::State {
82 Box::new(self.inner.initial_state())
83 }
84
85 fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
86 let concrete = state
87 .downcast_mut::<F::State>()
88 .expect("window state type mismatch");
89 self.inner.step(concrete, args)
90 }
91
92 fn inverse(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
93 let concrete = state
94 .downcast_mut::<F::State>()
95 .expect("window state type mismatch");
96 self.inner.inverse(concrete, args)
97 }
98
99 fn value(&self, state: &Self::State) -> Result<SqliteValue> {
100 let concrete = state
101 .downcast_ref::<F::State>()
102 .expect("window state type mismatch");
103 self.inner.value(concrete)
104 }
105
106 fn finalize(&self, state: Self::State) -> Result<SqliteValue> {
107 let concrete = *state
108 .downcast::<F::State>()
109 .expect("window state type mismatch");
110 self.inner.finalize(concrete)
111 }
112
113 fn num_args(&self) -> i32 {
114 self.inner.num_args()
115 }
116
117 fn name(&self) -> &str {
118 self.inner.name()
119 }
120}
121
#[cfg(test)]
mod tests {
    //! Unit tests for the `WindowFunction` lifecycle and the type-erasing
    //! `WindowAdapter`.
    use super::*;

    /// Minimal sliding-window SUM over integer arguments, used as a fixture.
    struct WindowSum;

    impl WindowFunction for WindowSum {
        type State = i64;

        fn initial_state(&self) -> i64 {
            0
        }

        fn step(&self, state: &mut i64, args: &[SqliteValue]) -> Result<()> {
            *state += args[0].to_integer();
            Ok(())
        }

        fn inverse(&self, state: &mut i64, args: &[SqliteValue]) -> Result<()> {
            *state -= args[0].to_integer();
            Ok(())
        }

        fn value(&self, state: &i64) -> Result<SqliteValue> {
            Ok(SqliteValue::Integer(*state))
        }

        fn finalize(&self, state: i64) -> Result<SqliteValue> {
            Ok(SqliteValue::Integer(state))
        }

        fn num_args(&self) -> i32 {
            1
        }

        fn name(&self) -> &str {
            "window_sum"
        }
    }

    #[test]
    fn test_window_function_step_and_inverse() {
        let f = WindowSum;
        let mut state = f.initial_state();

        f.step(&mut state, &[SqliteValue::Integer(10)]).unwrap();
        f.step(&mut state, &[SqliteValue::Integer(20)]).unwrap();
        f.step(&mut state, &[SqliteValue::Integer(30)]).unwrap();
        assert_eq!(f.value(&state).unwrap(), SqliteValue::Integer(60));

        // Slide the frame: drop the oldest row, add a new one.
        f.inverse(&mut state, &[SqliteValue::Integer(10)]).unwrap();
        f.step(&mut state, &[SqliteValue::Integer(40)]).unwrap();
        assert_eq!(f.value(&state).unwrap(), SqliteValue::Integer(90));

        f.inverse(&mut state, &[SqliteValue::Integer(20)]).unwrap();
        assert_eq!(f.value(&state).unwrap(), SqliteValue::Integer(70));
    }

    #[test]
    fn test_window_function_value_without_consuming() {
        let f = WindowSum;
        let mut state = f.initial_state();

        f.step(&mut state, &[SqliteValue::Integer(42)]).unwrap();

        // `value` borrows the state, so repeated reads are stable.
        assert_eq!(f.value(&state).unwrap(), SqliteValue::Integer(42));
        assert_eq!(f.value(&state).unwrap(), SqliteValue::Integer(42));
        assert_eq!(f.value(&state).unwrap(), SqliteValue::Integer(42));

        f.step(&mut state, &[SqliteValue::Integer(8)]).unwrap();
        assert_eq!(f.value(&state).unwrap(), SqliteValue::Integer(50));
    }

    #[test]
    fn test_window_function_finalize_consumes() {
        let f = WindowSum;
        let mut state = f.initial_state();

        f.step(&mut state, &[SqliteValue::Integer(10)]).unwrap();
        f.step(&mut state, &[SqliteValue::Integer(32)]).unwrap();

        let result = f.finalize(state).unwrap();
        assert_eq!(result, SqliteValue::Integer(42));
    }

    #[test]
    fn test_window_adapter_round_trips_erased_state() {
        let f = WindowAdapter::new(WindowSum);
        let mut state = f.initial_state();

        f.step(&mut state, &[SqliteValue::Integer(5)]).unwrap();
        f.step(&mut state, &[SqliteValue::Integer(7)]).unwrap();
        assert_eq!(f.value(&state).unwrap(), SqliteValue::Integer(12));

        f.inverse(&mut state, &[SqliteValue::Integer(5)]).unwrap();
        assert_eq!(f.value(&state).unwrap(), SqliteValue::Integer(7));

        // Metadata is forwarded unchanged from the wrapped function.
        assert_eq!(f.num_args(), 1);
        assert_eq!(f.name(), "window_sum");

        assert_eq!(f.finalize(state).unwrap(), SqliteValue::Integer(7));
    }

    #[test]
    #[should_panic(expected = "window state type mismatch")]
    fn test_window_adapter_rejects_foreign_state() {
        let f = WindowAdapter::new(WindowSum);
        // A boxed state that was not produced by `f.initial_state()`.
        let mut state: Box<dyn Any + Send> = Box::new(String::from("not an i64"));
        let _ = f.step(&mut state, &[SqliteValue::Integer(1)]);
    }
}