use super::{ScalarFunction, TableFunction};
use arrow_schema::{Field, Fields};
use std::collections::HashMap;
/// One registered function overload: its name, the argument fields it accepts,
/// its return field, and the implementation to dispatch to.
pub struct FunctionSignature {
/// Name under which this overload is grouped in the registry.
pub name: String,
/// Declared argument fields. During lookup only each field's data type and
/// metadata are compared; field names and nullability are not.
pub arg_types: Fields,
/// When `true`, a call may pass more arguments than `arg_types` declares: the
/// declared fields must match the leading arguments, and the extra trailing
/// arguments are not type-checked.
pub variadic: bool,
/// Declared return field; compared on data type and metadata only.
pub return_type: Field,
/// The implementation behind this overload (scalar or table).
pub function: FunctionKind,
}
/// The implementation wrapped by a `FunctionSignature`.
pub enum FunctionKind {
/// A scalar function implementation.
Scalar(ScalarFunction),
/// A table function implementation.
Table(TableFunction),
}
impl FunctionKind {
    /// Returns `true` when this kind wraps a scalar function.
    pub fn is_scalar(&self) -> bool {
        match self {
            Self::Scalar(_) => true,
            Self::Table(_) => false,
        }
    }

    /// Returns `true` when this kind wraps a table function.
    pub fn is_table(&self) -> bool {
        match self {
            Self::Scalar(_) => false,
            Self::Table(_) => true,
        }
    }

    /// Returns a copy of the scalar implementation, or `None` for a table function.
    pub fn as_scalar(&self) -> Option<ScalarFunction> {
        // The payload is `Copy`, so binding through `*self` copies it out.
        match *self {
            Self::Scalar(function) => Some(function),
            Self::Table(_) => None,
        }
    }

    /// Returns a copy of the table implementation, or `None` for a scalar function.
    pub fn as_table(&self) -> Option<TableFunction> {
        match *self {
            Self::Scalar(_) => None,
            Self::Table(function) => Some(function),
        }
    }
}
impl FunctionSignature {
    /// Reports whether a call with the given argument and return fields can be
    /// dispatched to this signature.
    ///
    /// Fields are compared on data type and metadata only. The declared
    /// argument fields must match the leading call arguments pairwise. A
    /// non-variadic signature also requires the argument counts to be equal.
    /// A variadic signature accepts any number of extra trailing arguments and
    /// does not check their types.
    fn matches(&self, arg_types: &[Field], return_type: &Field) -> bool {
        /// Compares the parts of a field that take part in overload resolution.
        fn same_type(expected: &Field, actual: &Field) -> bool {
            expected.data_type() == actual.data_type()
                && expected.metadata() == actual.metadata()
        }

        let declared = self.arg_types.len();
        let arity_ok = if self.variadic {
            arg_types.len() >= declared
        } else {
            arg_types.len() == declared
        };

        same_type(&self.return_type, return_type)
            && arity_ok
            && self
                .arg_types
                .iter()
                .zip(arg_types)
                .all(|(expected, actual)| same_type(expected, actual))
    }
}
/// Link-time collection of signature constructors. `REGISTRY` calls each entry
/// once, on first access, to build the lookup table.
// NOTE(review): entries are presumably contributed elsewhere via
// `#[linkme::distributed_slice(SIGNATURES)]` — confirm against the registering modules.
#[doc(hidden)]
#[linkme::distributed_slice]
pub static SIGNATURES: [fn() -> FunctionSignature];
/// Global function registry, built lazily on first access.
///
/// Every constructor in `SIGNATURES` is invoked once. The resulting signatures
/// are grouped by name, and within each name they keep the slice's iteration order.
pub static REGISTRY: std::sync::LazyLock<FunctionRegistry> = std::sync::LazyLock::new(|| {
    let signatures = SIGNATURES.into_iter().map(|ctor| ctor()).fold(
        HashMap::<String, Vec<FunctionSignature>>::new(),
        |mut by_name, signature| {
            by_name
                .entry(signature.name.clone())
                .or_default()
                .push(signature);
            by_name
        },
    );
    FunctionRegistry { signatures }
});
/// Lookup table of function signatures, grouped by function name.
#[derive(Default)]
pub struct FunctionRegistry {
/// Overloads registered under each name, in the order they were inserted.
signatures: HashMap<String, Vec<FunctionSignature>>,
}
impl FunctionRegistry {
    /// Finds the first overload named `name` whose argument and return fields
    /// match the call.
    ///
    /// Returns `None` when the name is unknown or when no overload matches.
    pub fn get(
        &self,
        name: &str,
        arg_types: &[Field],
        return_type: &Field,
    ) -> Option<&FunctionSignature> {
        self.signatures.get(name).and_then(|overloads| {
            overloads
                .iter()
                .find(|candidate| candidate.matches(arg_types, return_type))
        })
    }

    /// Iterates over every registered signature across all names.
    ///
    /// Names are visited in `HashMap` iteration order, which is unspecified.
    /// Overloads within a name are visited in insertion order.
    pub fn iter(&self) -> impl Iterator<Item = &FunctionSignature> {
        self.signatures
            .values()
            .flat_map(|overloads| overloads.iter())
    }
}