#[cfg(feature = "ipaddr")]
pub mod ipaddr;
#[cfg(feature = "decimal")]
pub mod decimal;
pub mod partial_evaluation;
use crate::ast::{Extension, ExtensionFunction, Name};
use crate::entities::SchemaType;
use miette::Diagnostic;
use thiserror::Error;
// Every extension compiled into this build. Each entry is included only when
// its corresponding cargo feature is enabled (see the `#[cfg(feature = ...)]`
// attributes), so with no features enabled this is an empty `Vec`.
// Built lazily on first access via `lazy_static!` so that
// `Extensions::all_available()` can hand out a `'static` slice of it.
lazy_static::lazy_static! {
static ref ALL_AVAILABLE_EXTENSIONS: Vec<Extension> = vec![
#[cfg(feature = "ipaddr")]
ipaddr::extension(),
#[cfg(feature = "decimal")]
decimal::extension(),
// NOTE(review): the `partial_evaluation` module itself is declared
// unconditionally at the top of the file; only its registration here is
// feature-gated.
#[cfg(feature = "partial-eval")]
partial_evaluation::extension(),
];
}
/// A borrowed set of extensions, against which extension-function lookups
/// (`func`, `lookup_single_arg_constructor`) are performed.
///
/// It only holds a slice reference, which is why it can derive `Copy` and is
/// cheap to pass by value.
#[derive(Debug, Clone, Copy)]
pub struct Extensions<'a> {
/// The extensions making up this set.
extensions: &'a [Extension],
}
impl Extensions<'static> {
pub fn all_available() -> Extensions<'static> {
Extensions {
extensions: &ALL_AVAILABLE_EXTENSIONS,
}
}
pub fn none() -> Extensions<'static> {
Extensions { extensions: &[] }
}
}
impl<'a> Extensions<'a> {
    /// Builds a set that exposes exactly the given extensions.
    pub fn specific_extensions(extensions: &'a [Extension]) -> Extensions<'a> {
        Extensions { extensions }
    }

    /// Iterates over the names of the extensions in this set, in slice order.
    ///
    /// The yielded names borrow from the underlying extension slice (`'a`),
    /// not from `self`, so they may outlive this `Copy` handle.
    pub fn ext_names(&self) -> impl Iterator<Item = &'a Name> {
        self.extensions.iter().map(|ext| ext.name())
    }

    /// Looks up the extension function called `name` across all extensions
    /// in this set.
    ///
    /// # Errors
    ///
    /// - [`ExtensionFunctionLookupError::FuncDoesNotExist`] if no extension
    ///   defines `name`.
    /// - [`ExtensionFunctionLookupError::FuncMultiplyDefined`] if more than one
    ///   extension defines `name`; `num_defs` is the total number of definitions.
    pub fn func(&self, name: &Name) -> Result<&'a ExtensionFunction> {
        // Pull definitions lazily: two `next()` calls distinguish
        // none / exactly one / several without allocating a Vec.
        let mut defs = self
            .extensions
            .iter()
            .filter_map(|ext| ext.get_func(name));
        match (defs.next(), defs.next()) {
            (None, _) => Err(ExtensionFunctionLookupError::FuncDoesNotExist { name: name.clone() }),
            (Some(func), None) => Ok(func),
            (Some(_), Some(_)) => Err(ExtensionFunctionLookupError::FuncMultiplyDefined {
                name: name.clone(),
                // Two definitions were already consumed above, plus the rest.
                num_defs: 2 + defs.count(),
            }),
        }
    }

    /// Iterates over every function of every extension in this set.
    pub(crate) fn all_funcs(&self) -> impl Iterator<Item = &'a ExtensionFunction> {
        self.extensions.iter().flat_map(|ext| ext.funcs())
    }

    /// Finds the constructor that takes exactly one argument of type
    /// `arg_type` and returns `return_type`.
    ///
    /// Returns `Ok(None)` when no constructor has that signature.
    ///
    /// # Errors
    ///
    /// [`ExtensionFunctionLookupError::MultipleConstructorsSameSignature`] if
    /// more than one constructor has that signature, since the choice would
    /// be ambiguous.
    pub(crate) fn lookup_single_arg_constructor(
        &self,
        return_type: &SchemaType,
        arg_type: &SchemaType,
    ) -> Result<Option<&'a ExtensionFunction>> {
        let mut candidates = self.all_funcs().filter(|f| {
            f.is_constructor()
                && f.return_type() == Some(return_type)
                // "Single-arg": a constructor with further parameters cannot be
                // invoked with just `arg_type`, so it must not match (nor make
                // a genuine single-arg constructor look ambiguous).
                && f.arg_types().len() == 1
                && f.arg_types().first().map(Option::as_ref) == Some(Some(arg_type))
        });
        match (candidates.next(), candidates.next()) {
            (None, _) => Ok(None),
            (Some(ctor), None) => Ok(Some(ctor)),
            (Some(_), Some(_)) => Err(
                ExtensionFunctionLookupError::MultipleConstructorsSameSignature {
                    return_type: Box::new(return_type.clone()),
                    arg_type: Box::new(arg_type.clone()),
                },
            ),
        }
    }
}
/// Errors that can occur when looking up an extension function in an
/// [`Extensions`] set.
#[derive(Debug, PartialEq, Eq, Clone, Diagnostic, Error)]
pub enum ExtensionFunctionLookupError {
/// No extension in the set defines a function with this name.
#[error("extension function `{name}` does not exist")]
FuncDoesNotExist {
/// The name that was looked up.
name: Name,
},
/// The named extension function has no type. Not constructed in this
/// section of the file; presumably raised by callers that need a type
/// signature (TODO confirm).
#[error("extension function `{name}` has no type")]
HasNoType {
/// The name of the untyped function.
name: Name,
},
/// More than one extension in the set defines a function with this name,
/// so the lookup is ambiguous.
#[error("extension function `{name}` is defined {num_defs} times")]
FuncMultiplyDefined {
/// The ambiguous name.
name: Name,
/// How many definitions were found in total.
num_defs: usize,
},
/// More than one constructor matched a single-argument constructor lookup
/// with this signature.
#[error(
"multiple extension constructors have the same type signature {arg_type} -> {return_type}"
)]
MultipleConstructorsSameSignature {
// NOTE(review): the types are boxed, presumably to keep this error enum
// (and every `Result` carrying it) small — confirm.
/// Return type of the ambiguous signature.
return_type: Box<SchemaType>,
/// Argument type of the ambiguous signature.
arg_type: Box<SchemaType>,
},
}
/// Shorthand for results of extension-function lookups.
pub type Result<T> = std::result::Result<T, ExtensionFunctionLookupError>;
#[cfg(test)]
pub(crate) mod test {
    use super::*;
    use std::collections::HashSet;

    /// Every extension function name must be unique across all compiled-in
    /// extensions. Otherwise `Extensions::func` would report
    /// `FuncMultiplyDefined` for that name.
    #[test]
    fn no_common_extension_function_names() {
        // Borrow names instead of cloning them, and fail on the first repeat
        // so the message identifies the offending function.
        let mut seen: HashSet<&Name> = HashSet::new();
        for func in Extensions::all_available().all_funcs() {
            let name = func.name();
            assert!(
                seen.insert(name),
                "extension function `{name}` is defined more than once"
            );
        }
    }
}