#![allow(clippy::unnecessary_literal_bound)]
use fsqlite_error::Result;
use fsqlite_types::SqliteValue;
/// A scalar SQL function: takes a slice of argument values and produces one value.
///
/// The `Send + Sync` supertraits allow a single implementation to be shared
/// across threads, e.g. behind an `Arc<dyn ScalarFunction>`.
///
/// Arity is described by [`num_args`](ScalarFunction::num_args). A negative
/// value marks the function as variadic. The default
/// [`min_args`](ScalarFunction::min_args), [`max_args`](ScalarFunction::max_args)
/// and [`accepts_arg_count`](ScalarFunction::accepts_arg_count) are derived
/// from it, and implementors may override them for ranged arity.
pub trait ScalarFunction: Send + Sync {
    /// Evaluates the function on `args` and returns the resulting value.
    ///
    /// This trait does not check `args.len()` before calling `invoke`.
    /// NOTE(review): presumably callers validate the count with
    /// [`accepts_arg_count`](ScalarFunction::accepts_arg_count) first — confirm.
    ///
    /// # Errors
    ///
    /// Returns `Err` when the implementation cannot produce a value for these
    /// arguments.
    fn invoke(&self, args: &[SqliteValue]) -> Result<SqliteValue>;

    /// Reports whether the function is deterministic.
    ///
    /// The default implementation returns `true`. Override it to return
    /// `false` for functions whose result may vary between calls with
    /// identical arguments.
    fn is_deterministic(&self) -> bool {
        true
    }

    /// Declared arity. A negative value (conventionally `-1`) means variadic.
    fn num_args(&self) -> i32;

    /// Smallest accepted argument count.
    ///
    /// The default is the declared arity, or `0` when the function is variadic.
    fn min_args(&self) -> i32 {
        let declared = self.num_args();
        if declared < 0 {
            0
        } else {
            declared
        }
    }

    /// Largest accepted argument count, or `None` when there is no upper bound.
    ///
    /// The default is `Some(num_args())` for fixed arity and `None` for a
    /// negative (variadic) arity.
    fn max_args(&self) -> Option<i32> {
        match self.num_args() {
            declared if declared < 0 => None,
            declared => Some(declared),
        }
    }

    /// Returns `true` if a call with `num_args` arguments falls within
    /// `min_args()..=max_args()`. A missing maximum means "no upper bound".
    fn accepts_arg_count(&self, num_args: i32) -> bool {
        // Lower bound is checked first; the upper bound is only consulted if it passes.
        if num_args < self.min_args() {
            return false;
        }
        match self.max_args() {
            Some(upper) => num_args <= upper,
            None => true,
        }
    }

    /// The function's name (the fixtures in this file return e.g. `"add_one"`).
    fn name(&self) -> &str;
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use fsqlite_error::FrankenError;

    use super::*;

    /// Unary, deterministic fixture that adds one to its argument.
    ///
    /// Integers and floats are incremented. NULL passes through. Text is parsed
    /// as `i64` (unparseable text counts as `0`) and then incremented. Any blob
    /// yields `1`. It indexes `args[0]` directly, so it relies on callers
    /// honouring the declared arity.
    struct AddOne;

    impl ScalarFunction for AddOne {
        fn invoke(&self, args: &[SqliteValue]) -> Result<SqliteValue> {
            match &args[0] {
                SqliteValue::Integer(i) => Ok(SqliteValue::Integer(i + 1)),
                SqliteValue::Float(f) => Ok(SqliteValue::Float(f + 1.0)),
                SqliteValue::Null => Ok(SqliteValue::Null),
                SqliteValue::Text(s) => {
                    let n: i64 = s.parse().unwrap_or(0);
                    Ok(SqliteValue::Integer(n + 1))
                }
                SqliteValue::Blob(_) => Ok(SqliteValue::Integer(1)),
            }
        }

        fn num_args(&self) -> i32 {
            1
        }

        fn name(&self) -> &str {
            "add_one"
        }
    }

    /// Nullary fixture that overrides `is_deterministic` to return `false`.
    struct NonDeterministic;

    impl ScalarFunction for NonDeterministic {
        fn invoke(&self, _args: &[SqliteValue]) -> Result<SqliteValue> {
            Ok(SqliteValue::Integer(42))
        }

        fn is_deterministic(&self) -> bool {
            false
        }

        fn num_args(&self) -> i32 {
            0
        }

        fn name(&self) -> &str {
            "random_ish"
        }
    }

    /// Variadic fixture (`num_args() == -1`) that concatenates the text form of
    /// every argument.
    struct Concat;

    impl ScalarFunction for Concat {
        fn invoke(&self, args: &[SqliteValue]) -> Result<SqliteValue> {
            let mut result = String::new();
            for arg in args {
                result.push_str(&arg.to_text());
            }
            Ok(SqliteValue::Text(result.into()))
        }

        fn num_args(&self) -> i32 {
            -1
        }

        fn name(&self) -> &str {
            "concat"
        }
    }

    /// Unary fixture that reports a domain error instead of overflowing on
    /// `abs(i64::MIN)`. Non-integer arguments are returned unchanged.
    struct SafeAbs;

    impl ScalarFunction for SafeAbs {
        fn invoke(&self, args: &[SqliteValue]) -> Result<SqliteValue> {
            match &args[0] {
                SqliteValue::Integer(i) => {
                    // `i64::MIN.abs()` has no representable result, so reject it explicitly.
                    if *i == i64::MIN {
                        return Err(FrankenError::function_error("abs(i64::MIN) would overflow"));
                    }
                    Ok(SqliteValue::Integer(i.abs()))
                }
                _ => Ok(args[0].clone()),
            }
        }

        fn num_args(&self) -> i32 {
            1
        }

        fn name(&self) -> &str {
            "safe_abs"
        }
    }

    /// Nullary fixture that always fails with `FrankenError::TooBig`.
    struct BigResult;

    impl ScalarFunction for BigResult {
        fn invoke(&self, _args: &[SqliteValue]) -> Result<SqliteValue> {
            Err(FrankenError::TooBig)
        }

        fn num_args(&self) -> i32 {
            0
        }

        fn name(&self) -> &str {
            "big_result"
        }
    }

    #[test]
    fn test_scalar_function_invoke_basic() {
        let f = AddOne;
        assert_eq!(
            f.invoke(&[SqliteValue::Integer(41)]).unwrap(),
            SqliteValue::Integer(42)
        );
        assert_eq!(
            f.invoke(&[SqliteValue::Float(1.5)]).unwrap(),
            SqliteValue::Float(2.5)
        );
        assert!(f.invoke(&[SqliteValue::Null]).unwrap().is_null());
        assert_eq!(
            f.invoke(&[SqliteValue::Text("99".into())]).unwrap(),
            SqliteValue::Integer(100)
        );
    }

    #[test]
    fn test_scalar_function_deterministic_flag() {
        let det = AddOne;
        assert!(det.is_deterministic());
        let non_det = NonDeterministic;
        assert!(!non_det.is_deterministic());
    }

    /// Exercises the default arity methods for fixed-arity functions. The
    /// variadic path is covered separately in `test_scalar_function_variadic`.
    #[test]
    fn test_scalar_function_fixed_arity() {
        let unary = AddOne;
        assert_eq!(unary.min_args(), 1);
        assert_eq!(unary.max_args(), Some(1));
        assert!(!unary.accepts_arg_count(0));
        assert!(unary.accepts_arg_count(1));
        assert!(!unary.accepts_arg_count(2));

        let nullary = NonDeterministic;
        assert_eq!(nullary.min_args(), 0);
        assert_eq!(nullary.max_args(), Some(0));
        assert!(nullary.accepts_arg_count(0));
        assert!(!nullary.accepts_arg_count(1));
    }

    #[test]
    fn test_scalar_function_variadic() {
        let f = Concat;
        assert_eq!(f.num_args(), -1);
        assert_eq!(f.min_args(), 0);
        assert_eq!(f.max_args(), None);
        assert!(f.accepts_arg_count(0));
        assert!(f.accepts_arg_count(3));
        // A negative count is below the clamped minimum of 0 and must be rejected.
        assert!(!f.accepts_arg_count(-1));
        assert_eq!(f.invoke(&[]).unwrap(), SqliteValue::Text("".into()));
        assert_eq!(
            f.invoke(&[SqliteValue::Text("hello".into())]).unwrap(),
            SqliteValue::Text("hello".into())
        );
        assert_eq!(
            f.invoke(&[
                SqliteValue::Text("a".into()),
                SqliteValue::Text("b".into()),
                SqliteValue::Text("c".into()),
            ])
            .unwrap(),
            SqliteValue::Text("abc".into())
        );
    }

    #[test]
    fn test_scalar_function_error_domain() {
        let f = SafeAbs;
        let err = f.invoke(&[SqliteValue::Integer(i64::MIN)]).unwrap_err();
        assert!(
            matches!(err, FrankenError::FunctionError(ref msg) if msg.contains("overflow")),
            "expected FunctionError, got {err:?}"
        );
    }

    #[test]
    fn test_scalar_function_too_big_error() {
        let f = BigResult;
        let err = f.invoke(&[]).unwrap_err();
        assert!(matches!(err, FrankenError::TooBig));
    }

    /// Checks that one function object can be shared through `Arc<dyn ScalarFunction>`
    /// and invoked from two threads, and that each call returns the right value.
    #[test]
    fn test_scalar_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<AddOne>();
        let f: Arc<dyn ScalarFunction> = Arc::new(AddOne);
        let f2 = Arc::clone(&f);
        let handle = std::thread::spawn(move || f2.invoke(&[SqliteValue::Integer(0)]));
        // Previously this result was discarded; verify both threads compute correctly.
        assert_eq!(
            f.invoke(&[SqliteValue::Integer(1)]).unwrap(),
            SqliteValue::Integer(2)
        );
        assert_eq!(handle.join().unwrap().unwrap(), SqliteValue::Integer(1));
    }
}