#![allow(clippy::unnecessary_literal_bound)]
use std::any::Any;
use fsqlite_error::Result;
use fsqlite_types::SqliteValue;
/// A user-defined aggregate that can be evaluated over a sliding window frame.
///
/// Rows are folded into the accumulator with [`step`](Self::step). The effect of a
/// row is undone with [`inverse`](Self::inverse). [`value`](Self::value) reads the
/// current result through a shared borrow, so it can be called repeatedly.
/// [`finalize`](Self::finalize) consumes the accumulator to produce a result.
///
/// Implementors must be `Send + Sync`, and their accumulator type must be `Send`.
pub trait WindowFunction: Send + Sync {
    /// Accumulator threaded through every callback for one window.
    type State: Send;

    /// Returns the starting accumulator for a new window.
    fn initial_state(&self) -> Self::State;

    /// Folds one row's arguments into `state`.
    ///
    /// # Errors
    /// Propagates any error reported by the implementation.
    fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()>;

    /// Reverses the effect of an earlier [`step`](Self::step) with the same `args`.
    ///
    /// # Errors
    /// Propagates any error reported by the implementation.
    fn inverse(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()>;

    /// Reports the current result without consuming `state`.
    ///
    /// # Errors
    /// Propagates any error reported by the implementation.
    fn value(&self, state: &Self::State) -> Result<SqliteValue>;

    /// Consumes `state` and produces the final result.
    ///
    /// # Errors
    /// Propagates any error reported by the implementation.
    fn finalize(&self, state: Self::State) -> Result<SqliteValue>;

    /// Declared argument count.
    ///
    /// A negative value means the arity is unbounded. In that case the default
    /// [`min_args`](Self::min_args) is `0` and [`max_args`](Self::max_args) is `None`.
    fn num_args(&self) -> i32;

    /// Smallest accepted argument count.
    ///
    /// The default is `num_args()`, or `0` when `num_args()` is negative.
    fn min_args(&self) -> i32 {
        let declared = self.num_args();
        if declared < 0 {
            0
        } else {
            declared
        }
    }

    /// Largest accepted argument count, or `None` when there is no upper bound.
    ///
    /// The default is `Some(num_args())` for a non-negative declaration, else `None`.
    fn max_args(&self) -> Option<i32> {
        let declared = self.num_args();
        if declared >= 0 {
            Some(declared)
        } else {
            None
        }
    }

    /// Returns `true` when a call with `num_args` arguments is acceptable.
    ///
    /// The default checks `num_args` against [`min_args`](Self::min_args) and
    /// [`max_args`](Self::max_args).
    fn accepts_arg_count(&self, num_args: i32) -> bool {
        // The lower bound is checked first, so `max_args` is not consulted when it fails.
        if num_args < self.min_args() {
            return false;
        }
        match self.max_args() {
            Some(max) => num_args <= max,
            None => true,
        }
    }

    /// Returns the function's name.
    fn name(&self) -> &str;
}
/// Wraps a [`WindowFunction`] and erases its concrete `State` type behind
/// `Box<dyn Any + Send>`.
///
/// Callers can then handle adapters over different inner functions through one
/// uniform state type.
#[derive(Debug, Clone)]
pub struct WindowAdapter<F> {
    inner: F,
}

impl<F> WindowAdapter<F> {
    /// Wraps `inner` in a type-erasing adapter.
    #[must_use]
    pub const fn new(inner: F) -> Self {
        Self { inner }
    }

    /// Borrows the wrapped function.
    #[must_use]
    pub const fn inner(&self) -> &F {
        &self.inner
    }

    /// Unwraps the adapter and returns the wrapped function.
    #[must_use]
    pub fn into_inner(self) -> F {
        self.inner
    }
}
/// Type-erasing [`WindowFunction`] implementation that forwards every callback
/// to the wrapped function.
///
/// # Panics
/// [`step`](WindowFunction::step), [`inverse`](WindowFunction::inverse),
/// [`value`](WindowFunction::value) and [`finalize`](WindowFunction::finalize)
/// panic with `"window state type mismatch"` if the boxed state does not hold an
/// `F::State`. For example, this happens when the state was produced by a
/// different function's `initial_state`.
impl<F> WindowFunction for WindowAdapter<F>
where
    F: WindowFunction,
    // `Any` requires `'static`, so only non-borrowing state can be erased.
    F::State: 'static,
{
    /// Boxed `F::State`, as created by [`initial_state`](WindowFunction::initial_state).
    type State = Box<dyn Any + Send>;

    fn initial_state(&self) -> Self::State {
        Box::new(self.inner.initial_state())
    }

    fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
        // A mismatch means the caller paired this adapter with foreign state.
        // That is a programming error, not a recoverable runtime condition.
        let concrete = state
            .downcast_mut::<F::State>()
            .expect("window state type mismatch");
        self.inner.step(concrete, args)
    }

    fn inverse(&self, state: &mut Self::State, args: &[SqliteValue]) -> Result<()> {
        let concrete = state
            .downcast_mut::<F::State>()
            .expect("window state type mismatch");
        self.inner.inverse(concrete, args)
    }

    fn value(&self, state: &Self::State) -> Result<SqliteValue> {
        let concrete = state
            .downcast_ref::<F::State>()
            .expect("window state type mismatch");
        self.inner.value(concrete)
    }

    fn finalize(&self, state: Self::State) -> Result<SqliteValue> {
        // `Box::downcast` takes ownership, so the inner state is moved out, not cloned.
        let concrete = *state
            .downcast::<F::State>()
            .expect("window state type mismatch");
        self.inner.finalize(concrete)
    }

    fn num_args(&self) -> i32 {
        self.inner.num_args()
    }

    fn min_args(&self) -> i32 {
        self.inner.min_args()
    }

    fn max_args(&self) -> Option<i32> {
        self.inner.max_args()
    }

    fn accepts_arg_count(&self, num_args: i32) -> bool {
        // Forward explicitly. The trait default would recompute the answer from
        // min_args/max_args and silently ignore a custom override on `F`.
        self.inner.accepts_arg_count(num_args)
    }

    fn name(&self) -> &str {
        self.inner.name()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Running-sum fixture with a fixed arity of one argument.
    struct WindowSum;

    impl WindowFunction for WindowSum {
        type State = i64;

        fn initial_state(&self) -> i64 {
            0
        }

        fn step(&self, state: &mut i64, args: &[SqliteValue]) -> Result<()> {
            *state += args[0].to_integer();
            Ok(())
        }

        fn inverse(&self, state: &mut i64, args: &[SqliteValue]) -> Result<()> {
            *state -= args[0].to_integer();
            Ok(())
        }

        fn value(&self, state: &i64) -> Result<SqliteValue> {
            Ok(SqliteValue::Integer(*state))
        }

        fn finalize(&self, state: i64) -> Result<SqliteValue> {
            Ok(SqliteValue::Integer(state))
        }

        fn num_args(&self) -> i32 {
            1
        }

        fn name(&self) -> &str {
            "window_sum"
        }
    }

    /// Row-count fixture declared with a negative (unbounded) arity.
    struct WindowCount;

    impl WindowFunction for WindowCount {
        type State = i64;

        fn initial_state(&self) -> i64 {
            0
        }

        fn step(&self, state: &mut i64, _args: &[SqliteValue]) -> Result<()> {
            *state += 1;
            Ok(())
        }

        fn inverse(&self, state: &mut i64, _args: &[SqliteValue]) -> Result<()> {
            *state -= 1;
            Ok(())
        }

        fn value(&self, state: &i64) -> Result<SqliteValue> {
            Ok(SqliteValue::Integer(*state))
        }

        fn finalize(&self, state: i64) -> Result<SqliteValue> {
            Ok(SqliteValue::Integer(state))
        }

        fn num_args(&self) -> i32 {
            -1
        }

        fn name(&self) -> &str {
            "window_count"
        }
    }

    #[test]
    fn test_window_function_step_and_inverse() {
        let f = WindowSum;
        let mut state = f.initial_state();
        f.step(&mut state, &[SqliteValue::Integer(10)]).unwrap();
        f.step(&mut state, &[SqliteValue::Integer(20)]).unwrap();
        f.step(&mut state, &[SqliteValue::Integer(30)]).unwrap();
        assert_eq!(f.value(&state).unwrap(), SqliteValue::Integer(60));
        f.inverse(&mut state, &[SqliteValue::Integer(10)]).unwrap();
        f.step(&mut state, &[SqliteValue::Integer(40)]).unwrap();
        assert_eq!(f.value(&state).unwrap(), SqliteValue::Integer(90));
        f.inverse(&mut state, &[SqliteValue::Integer(20)]).unwrap();
        assert_eq!(f.value(&state).unwrap(), SqliteValue::Integer(70));
    }

    #[test]
    fn test_window_function_value_without_consuming() {
        let f = WindowSum;
        let mut state = f.initial_state();
        f.step(&mut state, &[SqliteValue::Integer(42)]).unwrap();
        assert_eq!(f.value(&state).unwrap(), SqliteValue::Integer(42));
        assert_eq!(f.value(&state).unwrap(), SqliteValue::Integer(42));
        assert_eq!(f.value(&state).unwrap(), SqliteValue::Integer(42));
        f.step(&mut state, &[SqliteValue::Integer(8)]).unwrap();
        assert_eq!(f.value(&state).unwrap(), SqliteValue::Integer(50));
    }

    #[test]
    fn test_window_function_finalize_consumes() {
        let f = WindowSum;
        let mut state = f.initial_state();
        f.step(&mut state, &[SqliteValue::Integer(10)]).unwrap();
        f.step(&mut state, &[SqliteValue::Integer(32)]).unwrap();
        let result = f.finalize(state).unwrap();
        assert_eq!(result, SqliteValue::Integer(42));
    }

    #[test]
    fn test_default_arity_for_fixed_arg_function() {
        let f = WindowSum;
        assert_eq!(f.min_args(), 1);
        assert_eq!(f.max_args(), Some(1));
        assert!(f.accepts_arg_count(1));
        assert!(!f.accepts_arg_count(0));
        assert!(!f.accepts_arg_count(2));
    }

    #[test]
    fn test_default_arity_for_unbounded_function() {
        let f = WindowCount;
        assert_eq!(f.min_args(), 0);
        assert_eq!(f.max_args(), None);
        assert!(f.accepts_arg_count(0));
        assert!(f.accepts_arg_count(7));
        assert!(!f.accepts_arg_count(-1));
    }

    #[test]
    fn test_window_adapter_round_trip() {
        let f = WindowAdapter::new(WindowSum);
        assert_eq!(f.name(), "window_sum");
        assert_eq!(f.num_args(), 1);
        assert_eq!(f.min_args(), 1);
        assert_eq!(f.max_args(), Some(1));
        let mut state = f.initial_state();
        f.step(&mut state, &[SqliteValue::Integer(5)]).unwrap();
        f.step(&mut state, &[SqliteValue::Integer(7)]).unwrap();
        f.inverse(&mut state, &[SqliteValue::Integer(5)]).unwrap();
        assert_eq!(f.value(&state).unwrap(), SqliteValue::Integer(7));
        assert_eq!(f.finalize(state).unwrap(), SqliteValue::Integer(7));
    }

    #[test]
    fn test_window_adapter_forwards_unbounded_arity() {
        let f = WindowAdapter::new(WindowCount);
        assert_eq!(f.name(), "window_count");
        assert_eq!(f.max_args(), None);
        assert!(f.accepts_arg_count(3));
    }

    #[test]
    #[should_panic(expected = "window state type mismatch")]
    fn test_window_adapter_rejects_foreign_state() {
        let f = WindowAdapter::new(WindowSum);
        // The state was not produced by `WindowSum::initial_state`, so downcasting must fail.
        let mut state: Box<dyn Any + Send> = Box::new("not an i64");
        let _ = f.step(&mut state, &[SqliteValue::Integer(1)]);
    }
}