use super::{ScalarFunction, TableFunction};
use arrow_schema::DataType;
use std::collections::HashMap;
/// One registered overload of a function: its name, the argument and return
/// types it accepts, and the implementation it resolves to.
pub struct FunctionSignature {
/// Function name; the registry groups signatures that share a name.
pub name: String,
/// Declared positional parameter types, checked pairwise against a call's
/// argument types.
pub arg_types: Vec<SigDataType>,
/// When `true`, a call may pass more arguments than `arg_types` declares;
/// those extra trailing arguments are not type-checked during matching.
pub variadic: bool,
/// Pattern that the caller's expected return type must satisfy.
pub return_type: SigDataType,
/// The implementation this signature dispatches to.
pub function: FunctionKind,
}
/// The implementation behind a signature: either a scalar function or a table
/// function.
pub enum FunctionKind {
/// Wraps a [`ScalarFunction`].
Scalar(ScalarFunction),
/// Wraps a [`TableFunction`].
Table(TableFunction),
}
impl FunctionKind {
    /// Returns `true` when this is the [`FunctionKind::Scalar`] variant.
    pub fn is_scalar(&self) -> bool {
        match self {
            Self::Scalar(_) => true,
            Self::Table(_) => false,
        }
    }

    /// Returns `true` when this is the [`FunctionKind::Table`] variant.
    pub fn is_table(&self) -> bool {
        match self {
            Self::Scalar(_) => false,
            Self::Table(_) => true,
        }
    }

    /// Returns a copy of the wrapped scalar function, or `None` if this is a
    /// table function.
    pub fn as_scalar(&self) -> Option<ScalarFunction> {
        match *self {
            Self::Scalar(scalar) => Some(scalar),
            Self::Table(_) => None,
        }
    }

    /// Returns a copy of the wrapped table function, or `None` if this is a
    /// scalar function.
    pub fn as_table(&self) -> Option<TableFunction> {
        match *self {
            Self::Scalar(_) => None,
            Self::Table(table) => Some(table),
        }
    }
}
/// Type pattern used by a [`FunctionSignature`] for one argument slot or for
/// the return value.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum SigDataType {
/// Matches only a `DataType` equal (`==`) to the wrapped one.
Exact(DataType),
/// Matches every `DataType`.
Any,
}
impl FunctionSignature {
    /// Decides whether this signature can serve a call with the given argument
    /// types and expected return type.
    ///
    /// All of the following must hold:
    /// - the expected return type satisfies `self.return_type`;
    /// - arity: exactly `self.arg_types.len()` arguments, or at least that many
    ///   when `self.variadic` is set;
    /// - each declared parameter pattern accepts the argument at its position.
    ///   Extra trailing arguments permitted by `variadic` are not type-checked.
    fn matches(&self, arg_types: &[DataType], return_type: &DataType) -> bool {
        let declared = self.arg_types.len();
        let supplied = arg_types.len();
        let arity_ok = if self.variadic {
            supplied >= declared
        } else {
            supplied == declared
        };
        self.return_type.matches(return_type)
            && arity_ok
            && self
                .arg_types
                .iter()
                .zip(arg_types)
                .all(|(expected, actual)| expected.matches(actual))
    }
}
impl SigDataType {
    /// Reports whether `data_type` is acceptable for this slot.
    ///
    /// [`SigDataType::Any`] accepts every type; [`SigDataType::Exact`] accepts
    /// only a type equal to the one it wraps.
    fn matches(&self, data_type: &DataType) -> bool {
        if let Self::Exact(expected) = self {
            expected == data_type
        } else {
            true
        }
    }
}
impl From<DataType> for SigDataType {
fn from(dt: DataType) -> Self {
Self::Exact(dt)
}
}
/// Link-time collection of signature constructors. Building `REGISTRY` calls
/// each constructor once and groups the results by name.
/// NOTE(review): entries are presumably contributed elsewhere via
/// `#[linkme::distributed_slice(SIGNATURES)]` — confirm.
#[doc(hidden)]
#[linkme::distributed_slice]
pub static SIGNATURES: [fn() -> FunctionSignature];
lazy_static::lazy_static! {
    /// Process-wide function registry, built on first access.
    ///
    /// Every constructor in [`SIGNATURES`] is called once, and the resulting
    /// signatures are grouped by name. Within a name, signatures keep the order
    /// in which `SIGNATURES` yields them. That order decides which overload
    /// `FunctionRegistry::get` returns when several of them match.
    /// NOTE(review): linkme slice element order is presumably link-order
    /// dependent, so confirm that overloads never overlap.
    pub static ref REGISTRY: FunctionRegistry = {
        let mut by_name = HashMap::<String, Vec<FunctionSignature>>::new();
        for signature in SIGNATURES.into_iter().map(|make| make()) {
            by_name
                .entry(signature.name.clone())
                .or_default()
                .push(signature);
        }
        FunctionRegistry { signatures: by_name }
    };
}
/// Lookup table of function signatures grouped by function name.
/// `Default` produces an empty registry.
#[derive(Default)]
pub struct FunctionRegistry {
/// Overloads keyed by function name; each `Vec` keeps insertion order.
signatures: HashMap<String, Vec<FunctionSignature>>,
}
impl FunctionRegistry {
    /// Looks up the overload of `name` that accepts `arg_types` and satisfies
    /// `return_type`.
    ///
    /// Candidates are tried in their stored order, and the first match wins.
    /// Returns `None` when no signature is registered under `name` or when none
    /// of the registered ones match.
    pub fn get(
        &self,
        name: &str,
        arg_types: &[DataType],
        return_type: &DataType,
    ) -> Option<&FunctionSignature> {
        self.signatures.get(name).and_then(|overloads| {
            overloads
                .iter()
                .find(|candidate| candidate.matches(arg_types, return_type))
        })
    }

    /// Iterates over every registered signature, across all names.
    ///
    /// Signatures for one name come out in their stored order. The order in
    /// which names are visited follows `HashMap` iteration and is unspecified.
    pub fn iter(&self) -> impl Iterator<Item = &FunctionSignature> {
        self.signatures
            .values()
            .flat_map(|overloads| overloads.iter())
    }
}