// Extension submodules. Each of the first three is compiled only when the
// Cargo feature of the same name is enabled.
#[cfg(feature = "ipaddr")]
pub mod ipaddr;
#[cfg(feature = "decimal")]
pub mod decimal;
#[cfg(feature = "datetime")]
pub mod datetime;
// Always compiled; note that its extension is only added to the list of
// available extensions when the `partial-eval` feature is enabled.
pub mod partial_evaluation;
use std::collections::{HashMap, HashSet};
use std::sync::LazyLock;
use crate::ast::{CallStyle, Extension, ExtensionFunction, Name, UnreservedId};
use crate::entities::SchemaType;
use crate::extensions::extension_initialization_errors::MultipleConstructorsSameSignatureError;
use crate::fuzzy_match::fuzzy_search_limited;
use crate::parser::Loc;
use miette::Diagnostic;
use smol_str::{SmolStr, ToSmolStr};
use thiserror::Error;
use self::extension_function_lookup_errors::FuncDoesNotExistError;
use self::extension_initialization_errors::FuncMultiplyDefinedError;
/// The `Extension` objects available in this build.
///
/// Every entry is gated on a Cargo feature, so the vector only contains the
/// extensions whose feature is enabled (and is empty when none are). The
/// vector is built lazily, the first time the static is dereferenced.
static ALL_AVAILABLE_EXTENSION_OBJECTS: LazyLock<Vec<Extension>> = LazyLock::new(|| {
vec![
#[cfg(feature = "ipaddr")]
ipaddr::extension(),
#[cfg(feature = "decimal")]
decimal::extension(),
#[cfg(feature = "datetime")]
datetime::extension(),
#[cfg(feature = "partial-eval")]
partial_evaluation::extension(),
]
});
/// The `Extensions` value covering everything in
/// `ALL_AVAILABLE_EXTENSION_OBJECTS`; initialized lazily by
/// `Extensions::build_all_available`, which panics if initialization fails.
static ALL_AVAILABLE_EXTENSIONS: LazyLock<Extensions<'static>> =
LazyLock::new(Extensions::build_all_available);
/// An `Extensions` value with no extensions at all: the extension slice and
/// both lookup tables are empty.
static EXTENSIONS_NONE: LazyLock<Extensions<'static>> = LazyLock::new(|| Extensions {
extensions: &[],
functions: HashMap::new(),
single_arg_constructors: HashMap::new(),
});
/// Call-style lookup tables for the functions of all available extensions;
/// computed lazily by `ExtStyles::load`.
static EXTENSION_STYLES: LazyLock<ExtStyles<'static>> = LazyLock::new(ExtStyles::load);
/// A set of extensions together with lookup tables over the functions they
/// define.
///
/// Everything is borrowed (lifetime `'a`) from the slice of `Extension`s the
/// value was built from; the type does not derive `Clone`.
#[derive(Debug)]
pub struct Extensions<'a> {
/// The extensions this value was built from.
extensions: &'a [Extension],
/// Every function defined by `extensions`, keyed by function name.
/// Construction fails if two functions share a name, so each name maps to
/// exactly one function.
functions: HashMap<&'a Name, &'a ExtensionFunction>,
/// The functions for which `is_single_arg_constructor()` holds and that
/// have a return type, keyed by that return type. Construction fails if two
/// such functions share a return type.
single_arg_constructors: HashMap<&'a SchemaType, &'a ExtensionFunction>,
}
impl Extensions<'static> {
fn build_all_available() -> Extensions<'static> {
#[expect(
clippy::expect_used,
reason = "Builtin extensions define functions/constructors only once. Also tested by many different test cases."
)]
Self::specific_extensions(&ALL_AVAILABLE_EXTENSION_OBJECTS)
.expect("Default extensions should never error on initialization")
}
pub fn all_available() -> &'static Extensions<'static> {
&ALL_AVAILABLE_EXTENSIONS
}
pub fn none() -> &'static Extensions<'static> {
&EXTENSIONS_NONE
}
}
impl<'a> Extensions<'a> {
pub fn types_with_operator_overloading(&self) -> impl Iterator<Item = &Name> + '_ {
self.extensions
.iter()
.flat_map(|ext| ext.types_with_operator_overloading())
}
pub fn specific_extensions(
extensions: &'a [Extension],
) -> std::result::Result<Extensions<'a>, ExtensionInitializationError> {
let functions = util::collect_no_duplicates(
extensions
.iter()
.flat_map(|e| e.funcs())
.map(|f| (f.name(), f)),
)
.map_err(|name| FuncMultiplyDefinedError { name: name.clone() })?;
let single_arg_constructors = util::collect_no_duplicates(
extensions
.iter()
.flat_map(|e| e.funcs())
.filter(|f| f.is_single_arg_constructor())
.filter_map(|f| f.return_type().map(|return_type| (return_type, f))),
)
.map_err(|return_type| MultipleConstructorsSameSignatureError {
return_type: Box::new(return_type.clone()),
})?;
Ok(Extensions {
extensions,
functions,
single_arg_constructors,
})
}
pub fn ext_names(&self) -> impl Iterator<Item = &Name> {
self.extensions.iter().map(|ext| ext.name())
}
pub fn ext_types(&self) -> impl Iterator<Item = &Name> {
self.extensions.iter().flat_map(|ext| ext.ext_types())
}
pub fn func(
&self,
name: &Name,
) -> std::result::Result<&ExtensionFunction, ExtensionFunctionLookupError> {
self.functions.get(name).copied().ok_or_else(|| {
FuncDoesNotExistError {
name: name.clone(),
source_loc: name.loc().cloned(),
}
.into()
})
}
pub(crate) fn all_funcs(&self) -> impl Iterator<Item = &'a ExtensionFunction> {
self.extensions.iter().flat_map(|ext| ext.funcs())
}
pub(crate) fn lookup_single_arg_constructor(
&self,
return_type: &SchemaType,
) -> Option<&ExtensionFunction> {
self.single_arg_constructors.get(return_type).copied()
}
}
/// Errors that `Extensions::specific_extensions` can return while building
/// its lookup tables. Each variant forwards its message and diagnostic to the
/// wrapped error.
#[derive(Diagnostic, Debug, PartialEq, Eq, Clone, Error)]
pub enum ExtensionInitializationError {
/// Two extension functions share the same name.
#[error(transparent)]
#[diagnostic(transparent)]
FuncMultiplyDefined(#[from] extension_initialization_errors::FuncMultiplyDefinedError),
/// Two single-argument constructors share the same return type.
#[error(transparent)]
#[diagnostic(transparent)]
MultipleConstructorsSameSignature(
#[from] extension_initialization_errors::MultipleConstructorsSameSignatureError,
),
}
/// The error structs wrapped by the variants of
/// `ExtensionInitializationError`.
mod extension_initialization_errors {
use crate::{ast::Name, entities::SchemaType};
use miette::Diagnostic;
use thiserror::Error;
/// An extension function name was defined more than once.
#[derive(Diagnostic, Debug, PartialEq, Eq, Clone, Error)]
#[error("extension function `{name}` is defined multiple times")]
pub struct FuncMultiplyDefinedError {
/// The function name that was defined more than once.
pub(crate) name: Name,
}
/// More than one single-argument constructor returns the same type.
#[derive(Diagnostic, Debug, PartialEq, Eq, Clone, Error)]
#[error("multiple extension constructors for the same extension type {return_type}")]
pub struct MultipleConstructorsSameSignatureError {
/// The return type shared by the conflicting constructors.
pub(crate) return_type: Box<SchemaType>,
}
}
/// Errors returned when looking up an extension function by name (see
/// `Extensions::func`). The variant forwards its message and diagnostic to the
/// wrapped error.
#[derive(Debug, PartialEq, Eq, Clone, Diagnostic, Error)]
pub enum ExtensionFunctionLookupError {
/// No extension function with the requested name exists.
#[error(transparent)]
#[diagnostic(transparent)]
FuncDoesNotExist(#[from] extension_function_lookup_errors::FuncDoesNotExistError),
}
impl ExtensionFunctionLookupError {
pub(crate) fn source_loc(&self) -> Option<&Loc> {
match self {
Self::FuncDoesNotExist(e) => e.source_loc.as_ref(),
}
}
pub(crate) fn with_maybe_source_loc(self, source_loc: Option<Loc>) -> Self {
match self {
Self::FuncDoesNotExist(e) => {
Self::FuncDoesNotExist(extension_function_lookup_errors::FuncDoesNotExistError {
source_loc,
..e
})
}
}
}
}
/// The error structs wrapped by the variants of
/// `ExtensionFunctionLookupError`.
pub mod extension_function_lookup_errors {
use crate::ast::Name;
use crate::parser::Loc;
use miette::Diagnostic;
use thiserror::Error;
/// The requested extension function does not exist.
#[derive(Debug, PartialEq, Eq, Clone, Error)]
#[error("extension function `{name}` does not exist")]
pub struct FuncDoesNotExistError {
/// The name that was looked up.
pub(crate) name: Name,
/// Location associated with the failed lookup, when one is known.
pub(crate) source_loc: Option<Loc>,
}
// `Diagnostic` is implemented by hand rather than derived.
// NOTE(review): the macro is defined outside this file; presumably it fills in
// the `Diagnostic` methods from the optional `source_loc` field — confirm
// against the macro definition.
impl Diagnostic for FuncDoesNotExistError {
impl_diagnostic_from_source_loc_opt_field!(source_loc);
}
}
/// Shorthand for a `Result` whose error type is
/// `ExtensionFunctionLookupError`.
pub type Result<T> = std::result::Result<T, ExtensionFunctionLookupError>;
/// Lookup tables that classify the functions of the available extensions by
/// call style. Filled in by `ExtStyles::load`.
#[derive(Debug)]
pub(crate) struct ExtStyles<'a> {
/// Names of the functions whose style is `CallStyle::FunctionStyle`.
functions: HashSet<&'a Name>,
/// Basenames of the functions whose style is `CallStyle::MethodStyle`.
methods: HashSet<UnreservedId>,
/// String form of the name of every function, regardless of call style.
functions_and_methods_as_str: HashSet<SmolStr>,
}
impl ExtStyles<'static> {
fn load() -> ExtStyles<'static> {
let mut functions = HashSet::new();
let mut methods = HashSet::new();
let mut functions_and_methods_as_str = HashSet::new();
for func in crate::extensions::Extensions::all_available().all_funcs() {
functions_and_methods_as_str.insert(func.name().to_smolstr());
match func.style() {
CallStyle::FunctionStyle => {
functions.insert(func.name());
}
CallStyle::MethodStyle => {
debug_assert!(func.name().is_unqualified());
methods.insert(func.name().basename());
}
};
}
ExtStyles {
functions,
methods,
functions_and_methods_as_str,
}
}
pub(crate) fn is_method(id: &UnreservedId) -> bool {
EXTENSION_STYLES.methods.contains(id)
}
pub(crate) fn is_function(id: &Name) -> bool {
EXTENSION_STYLES.functions.contains(id)
}
pub(crate) fn is_known_extension_func_name(name: &Name) -> bool {
Self::is_function(name) || (name.0.path.is_empty() && Self::is_method(&name.basename()))
}
pub(crate) fn is_known_extension_func_str(s: &SmolStr) -> bool {
EXTENSION_STYLES.functions_and_methods_as_str.contains(s)
}
fn suggest<I, T>(key: &str, choices: I) -> Option<String>
where
I: IntoIterator<Item = T>,
T: ToString,
{
const SUGGESTION_EXTENSION_MAX_DISTANCE: usize = 3;
let choice_strings: Vec<String> = choices.into_iter().map(|c| c.to_string()).collect();
let suggestion = fuzzy_search_limited(
key,
choice_strings.as_slice(),
Some(SUGGESTION_EXTENSION_MAX_DISTANCE),
);
suggestion.map(|m| format!("did you mean `{m}`?"))
}
pub(crate) fn suggest_method(name: &UnreservedId) -> Option<String> {
Self::suggest(name.as_ref(), &EXTENSION_STYLES.methods)
}
pub(crate) fn suggest_function(name: &Name) -> Option<String> {
Self::suggest(&name.to_string(), &EXTENSION_STYLES.functions)
}
}
/// Small helpers used while building the extension lookup tables.
pub mod util {
    use std::collections::{hash_map::Entry, HashMap};

    /// Collects `(key, value)` pairs into a `HashMap`, rejecting repeated keys.
    ///
    /// Pairs are consumed in iteration order. On the first pair whose key is
    /// already present, a clone of that key is returned as `Err` and no
    /// further pairs are pulled from the iterator. Otherwise the completed
    /// map is returned.
    ///
    /// The map is pre-sized from the lower bound of the iterator's size hint.
    pub fn collect_no_duplicates<K, V>(
        i: impl Iterator<Item = (K, V)>,
    ) -> std::result::Result<HashMap<K, V>, K>
    where
        K: Clone + std::hash::Hash + Eq,
    {
        let mut pairs = i;
        let mut collected = HashMap::with_capacity(pairs.size_hint().0);
        // `try_for_each` stops at the first `Err`, i.e. at the first repeat.
        pairs.try_for_each(|(key, value)| match collected.entry(key) {
            Entry::Vacant(slot) => {
                slot.insert(value);
                Ok(())
            }
            Entry::Occupied(existing) => Err(existing.key().clone()),
        })?;
        Ok(collected)
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use std::collections::HashSet;

    /// No two functions across all available extensions may share a name:
    /// the number of names must equal the number of distinct names.
    #[test]
    fn no_common_extension_function_names() {
        let mut all_names = Vec::new();
        for ext in Extensions::all_available().extensions {
            all_names.extend(ext.funcs().map(|f| f.name().clone()));
        }
        let dedup_names: HashSet<_> = all_names.iter().collect();
        assert_eq!(all_names.len(), dedup_names.len());
    }
}