pub mod aggregate;
pub mod registry;
pub mod scalar;
pub mod tvf;
pub mod window;
use crate::core::{Error, Result, Value};
/// Broad category a registered function belongs to.
///
/// Each variant lines up with one of the submodules declared in this file
/// (`aggregate`, `scalar`, `window`, `tvf`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FunctionType {
    /// Implemented through the [`AggregateFunction`] trait.
    Aggregate,
    /// Implemented through the [`ScalarFunction`] trait.
    Scalar,
    /// Implemented through the [`WindowFunction`] trait.
    Window,
    /// NOTE(review): presumably table-valued functions from the `tvf` module —
    /// no trait for them is visible in this file; confirm.
    TableValued,
}
/// Coarse type tag used by [`FunctionSignature`] to describe a function's
/// return type and argument types.
///
/// NOTE(review): both `Timestamp` and `DateTime` exist; the distinction
/// between them is not visible in this file — confirm before relying on it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FunctionDataType {
    /// Presumably a wildcard that matches any value type — confirm with the
    /// type checker that consumes signatures.
    Any,
    Integer,
    Float,
    String,
    Boolean,
    Timestamp,
    Date,
    Time,
    DateTime,
    Json,
    /// Presumably used when the type cannot be determined — confirm.
    Unknown,
}
/// Declared shape of a function: its return type, argument types, and the
/// inclusive range of argument counts it accepts.
#[derive(Debug, Clone)]
pub struct FunctionSignature {
    /// Type the function produces.
    pub return_type: FunctionDataType,
    /// Declared argument types. For signatures built with
    /// `FunctionSignature::variadic` this holds a single entry, which
    /// presumably applies to every argument — confirm with callers.
    pub argument_types: Vec<FunctionDataType>,
    /// Smallest accepted argument count (inclusive).
    pub min_args: usize,
    /// Largest accepted argument count (inclusive); `usize::MAX` for variadic
    /// signatures.
    pub max_args: usize,
    /// `true` only for signatures built with `FunctionSignature::variadic`.
    /// Note: `validate_arg_count` checks only `min_args`/`max_args` and does
    /// not read this flag.
    pub is_variadic: bool,
}
impl FunctionSignature {
pub fn new(
return_type: FunctionDataType,
argument_types: Vec<FunctionDataType>,
min_args: usize,
max_args: usize,
) -> Self {
Self {
return_type,
argument_types,
min_args,
max_args,
is_variadic: false,
}
}
pub fn variadic(return_type: FunctionDataType, arg_type: FunctionDataType) -> Self {
Self {
return_type,
argument_types: vec![arg_type],
min_args: 1,
max_args: usize::MAX,
is_variadic: true,
}
}
pub fn validate_arg_count(&self, count: usize) -> Result<()> {
if count < self.min_args {
return Err(Error::invalid_argument(format!(
"expected at least {} arguments, got {}",
self.min_args, count
)));
}
if count > self.max_args {
return Err(Error::invalid_argument(format!(
"expected at most {} arguments, got {}",
self.max_args, count
)));
}
Ok(())
}
}
/// Descriptive metadata for a function: its name, category, description,
/// signature, and whether it is flagged as deterministic.
#[derive(Debug, Clone)]
pub struct FunctionInfo {
    /// Function name as provided to `FunctionInfo::new`.
    pub name: String,
    /// Category of the function.
    pub function_type: FunctionType,
    /// Free-form description text.
    pub description: String,
    /// Declared argument/return shape.
    pub signature: FunctionSignature,
    /// Set to `true` by `FunctionInfo::new` and cleared by
    /// `FunctionInfo::non_deterministic`.
    pub deterministic: bool,
}
impl FunctionInfo {
    /// Creates function metadata that is flagged as deterministic.
    ///
    /// Chain [`FunctionInfo::non_deterministic`] to clear that flag.
    pub fn new(
        name: impl Into<String>,
        function_type: FunctionType,
        description: impl Into<String>,
        signature: FunctionSignature,
    ) -> Self {
        Self {
            name: name.into(),
            function_type,
            description: description.into(),
            signature,
            deterministic: true,
        }
    }

    /// Builder-style setter that clears the `deterministic` flag.
    ///
    /// It consumes and returns `self`, so discarding the result throws the
    /// metadata away. `#[must_use]` turns that mistake into a warning.
    #[must_use]
    pub fn non_deterministic(mut self) -> Self {
        self.deterministic = false;
        self
    }

    /// Returns the function name.
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Returns the function category.
    pub fn function_type(&self) -> FunctionType {
        self.function_type
    }

    /// Returns the description text.
    pub fn description(&self) -> &str {
        &self.description
    }

    /// Returns the declared signature.
    pub fn signature(&self) -> &FunctionSignature {
        &self.signature
    }

    /// Returns whether this function is flagged as deterministic.
    ///
    /// The value is `true` unless [`FunctionInfo::non_deterministic`] was
    /// applied. This getter mirrors the accessors provided for the other
    /// fields.
    pub fn is_deterministic(&self) -> bool {
        self.deterministic
    }
}
/// An aggregate that folds a stream of input values into a single result.
///
/// Implementors must be `Send + Sync`.
pub trait AggregateFunction: Send + Sync {
    // ---- required methods ----

    /// Name of this aggregate.
    fn name(&self) -> &str;

    /// Metadata describing this aggregate.
    fn info(&self) -> FunctionInfo;

    /// Feeds one input value into the running state.
    ///
    /// NOTE(review): `distinct` presumably requests DISTINCT semantics —
    /// confirm against the implementations in `aggregate`.
    fn accumulate(&mut self, value: &Value, distinct: bool);

    /// Produces the aggregate value for everything accumulated so far.
    fn result(&self) -> Value;

    /// Clears the accumulated state.
    fn reset(&mut self);

    /// Returns a boxed clone of this aggregate as a trait object.
    fn clone_box(&self) -> Box<dyn AggregateFunction>;

    // ---- provided methods (overridable) ----

    /// Applies extra options; by default the options are ignored.
    fn configure(&mut self, _options: &[Value]) {}

    /// Receives per-key ordering directions; ignored by default.
    ///
    /// NOTE(review): whether `true` means ascending or descending is not
    /// visible here — confirm with callers.
    fn set_order_by(&mut self, _directions: Vec<bool>) {}

    /// Accumulates a value that carries sort keys.
    ///
    /// The default discards the sort keys and delegates to `accumulate`.
    /// Aggregates that care about ordering should override this method and
    /// `supports_order_by`.
    fn accumulate_with_sort_key(
        &mut self,
        value: &Value,
        _sort_keys: Vec<Value>,
        distinct: bool,
    ) {
        self.accumulate(value, distinct)
    }

    /// Whether this aggregate makes use of sort keys; `false` by default.
    fn supports_order_by(&self) -> bool {
        false
    }
}
/// Plain function pointer that mutates a single [`Value`] in place.
///
/// Returned by `ScalarFunction::native_fn1`. NOTE(review): presumably a fast
/// path for one-argument scalar functions that bypasses `evaluate` — confirm
/// with the executor that calls it.
pub type NativeFn1 = fn(&mut Value);
/// A function evaluated on a slice of argument values, producing one value.
///
/// Implementors must be `Send + Sync`.
pub trait ScalarFunction: Send + Sync {
    /// Name of this function.
    fn name(&self) -> &str;
    /// Metadata describing this function.
    fn info(&self) -> FunctionInfo;
    /// Evaluates the function on `args`; failures are reported via `Result`.
    fn evaluate(&self, args: &[Value]) -> Result<Value>;
    /// Returns a boxed clone of this function as a trait object.
    fn clone_box(&self) -> Box<dyn ScalarFunction>;
    /// Optional in-place native implementation; `None` by default.
    ///
    /// NOTE(review): when `Some`, callers presumably may use it instead of
    /// `evaluate` for single-argument calls — confirm semantics match.
    fn native_fn1(&self) -> Option<NativeFn1> {
        None
    }
}
/// A function computed per row from the values of a partition.
///
/// Implementors must be `Send + Sync`.
pub trait WindowFunction: Send + Sync {
    /// Name of this function.
    fn name(&self) -> &str;
    /// Metadata describing this function.
    fn info(&self) -> FunctionInfo;
    /// Computes the result for `current_row`.
    ///
    /// NOTE(review): presumably `partition` holds the partition's values,
    /// `order_by` the corresponding ordering values, and `current_row` is an
    /// index into `partition` — confirm against the callers.
    fn process(&self, partition: &[Value], order_by: &[Value], current_row: usize)
        -> Result<Value>;
    /// Returns a boxed clone of this function as a trait object.
    fn clone_box(&self) -> Box<dyn WindowFunction>;
}
pub use aggregate::{
AvgFunction, CountFunction, FirstFunction, LastFunction, MaxFunction, MinFunction, SumFunction,
};
pub use registry::{global_registry, FunctionRegistry};
pub use scalar::{
AbsFunction, CastFunction, CeilingFunction, CoalesceFunction, CollateFunction, ConcatFunction,
DateTruncFunction, FloorFunction, IfNullFunction, LengthFunction, LowerFunction, NowFunction,
NullIfFunction, RoundFunction, SubstringFunction, TimeTruncFunction, UpperFunction,
VersionFunction,
};
pub use window::{
DenseRankFunction, LagFunction, LeadFunction, NtileFunction, RankFunction, RowNumberFunction,
};
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_function_signature_validation() {
        let sig =
            FunctionSignature::new(FunctionDataType::Integer, vec![FunctionDataType::Any], 1, 1);
        assert!(sig.validate_arg_count(1).is_ok());
        assert!(sig.validate_arg_count(0).is_err());
        assert!(sig.validate_arg_count(2).is_err());
    }

    /// Both ends of `min_args..=max_args` are inclusive.
    #[test]
    fn test_signature_range_bounds_are_inclusive() {
        let sig = FunctionSignature::new(
            FunctionDataType::Any,
            vec![FunctionDataType::Any, FunctionDataType::Any],
            1,
            2,
        );
        assert!(!sig.is_variadic);
        assert!(sig.validate_arg_count(0).is_err());
        assert!(sig.validate_arg_count(1).is_ok());
        assert!(sig.validate_arg_count(2).is_ok());
        assert!(sig.validate_arg_count(3).is_err());
    }

    #[test]
    fn test_variadic_signature() {
        let sig = FunctionSignature::variadic(FunctionDataType::String, FunctionDataType::Any);
        assert!(sig.is_variadic);
        assert!(sig.validate_arg_count(1).is_ok());
        assert!(sig.validate_arg_count(10).is_ok());
        assert!(sig.validate_arg_count(0).is_err());
    }

    /// Variadic signatures have no practical upper bound and keep a single
    /// declared argument type.
    #[test]
    fn test_variadic_signature_shape() {
        let sig = FunctionSignature::variadic(FunctionDataType::String, FunctionDataType::Integer);
        assert_eq!(sig.min_args, 1);
        assert_eq!(sig.max_args, usize::MAX);
        assert_eq!(sig.return_type, FunctionDataType::String);
        assert_eq!(sig.argument_types, vec![FunctionDataType::Integer]);
        assert!(sig.validate_arg_count(usize::MAX).is_ok());
    }

    #[test]
    fn test_function_info() {
        let info = FunctionInfo::new(
            "TEST",
            FunctionType::Scalar,
            "Test function",
            FunctionSignature::new(FunctionDataType::Integer, vec![], 0, 0),
        );
        assert_eq!(info.name, "TEST");
        assert_eq!(info.function_type, FunctionType::Scalar);
    }

    /// `new` marks functions deterministic; the builder clears the flag and
    /// leaves every other field (and accessor) intact.
    #[test]
    fn test_function_info_determinism_and_accessors() {
        let info = FunctionInfo::new(
            "RAND",
            FunctionType::Scalar,
            "Random value",
            FunctionSignature::new(FunctionDataType::Float, vec![], 0, 0),
        );
        assert!(info.deterministic);

        let info = info.non_deterministic();
        assert!(!info.deterministic);
        assert_eq!(info.name(), "RAND");
        assert_eq!(info.description(), "Random value");
        assert_eq!(info.function_type(), FunctionType::Scalar);
        assert_eq!(info.signature().return_type, FunctionDataType::Float);
        assert_eq!(info.signature().max_args, 0);
    }
}