use std::collections::BTreeSet;
use crate::ast::{Expr, Name};
use crate::validator::types::Type;
pub struct ExtensionSchema {
name: Name,
function_types: Vec<ExtensionFunctionType>,
types_with_operator_overloading: BTreeSet<Name>,
}
impl std::fmt::Debug for ExtensionSchema {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "<extension schema {}>", self.name())
}
}
impl ExtensionSchema {
pub fn new(
name: Name,
function_types: impl IntoIterator<Item = ExtensionFunctionType>,
types_with_operator_overloading: impl IntoIterator<Item = Name>,
) -> Self {
Self {
name,
function_types: function_types.into_iter().collect(),
types_with_operator_overloading: types_with_operator_overloading.into_iter().collect(),
}
}
pub fn name(&self) -> &Name {
&self.name
}
pub fn function_types(&self) -> impl Iterator<Item = &ExtensionFunctionType> {
self.function_types.iter()
}
pub fn types_with_operator_overloading(&self) -> impl Iterator<Item = &Name> {
self.types_with_operator_overloading.iter()
}
}
/// Boxed callback that inspects the argument expressions of an extension
/// function call.
///
/// It returns `Ok(())` to accept the arguments or `Err(message)` to reject
/// them. The `Send + Sync + 'static` bounds allow the boxed closure to be
/// stored long-term and shared across threads.
pub(crate) type ArgumentCheckFn =
    Box<dyn Fn(&[Expr]) -> Result<(), String> + Send + Sync + 'static>;
/// Signature of one function provided by an extension, as used by the validator.
pub struct ExtensionFunctionType {
/// Name of the function.
name: Name,
/// Declared types of the function's arguments.
argument_types: Vec<Type>,
/// Declared type of the function's result.
return_type: Type,
/// Optional extra check on the argument expressions; when `None`, no
/// additional check is performed and argument lists are accepted.
check_arguments: Option<ArgumentCheckFn>,
/// Whether the function is variadic.
/// NOTE(review): how variadic arity relates to `argument_types` is not
/// visible in this file — confirm against the validator's call checking.
is_variadic: bool,
}
impl ExtensionFunctionType {
    /// Creates a function signature from its parts.
    ///
    /// * `name` - the function's name
    /// * `argument_types` - the declared argument types
    /// * `return_type` - the declared return type
    /// * `check_arguments` - optional extra check used by [`Self::check_arguments`]
    /// * `is_variadic` - whether the function is variadic
    pub fn new(
        name: Name,
        argument_types: Vec<Type>,
        return_type: Type,
        check_arguments: Option<ArgumentCheckFn>,
        is_variadic: bool,
    ) -> Self {
        Self {
            is_variadic,
            check_arguments,
            return_type,
            argument_types,
            name,
        }
    }

    /// The function's name.
    pub fn name(&self) -> &Name {
        let Self { name, .. } = self;
        name
    }

    /// The declared argument types.
    pub fn argument_types(&self) -> &Vec<Type> {
        let Self { argument_types, .. } = self;
        argument_types
    }

    /// The declared return type.
    pub fn return_type(&self) -> &Type {
        let Self { return_type, .. } = self;
        return_type
    }

    /// Runs the registered argument check, if any, against `args`.
    ///
    /// # Errors
    ///
    /// Propagates the error message produced by the registered check. When
    /// no check was supplied, this always returns `Ok(())`.
    pub fn check_arguments(&self, args: &[Expr]) -> Result<(), String> {
        match self.check_arguments.as_ref() {
            Some(checker) => checker(args),
            None => Ok(()),
        }
    }

    /// Whether an extra argument check was supplied at construction.
    pub fn has_argument_check(&self) -> bool {
        matches!(self.check_arguments, Some(_))
    }

    /// Whether the function is variadic.
    pub fn is_variadic(&self) -> bool {
        let Self { is_variadic, .. } = self;
        *is_variadic
    }
}
impl std::fmt::Debug for ExtensionFunctionType {
    // Compact form: only the function name is rendered, as
    // `<extension function type NAME>`. Argument/return types and the
    // check callback are not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = &self.name;
        f.write_fmt(format_args!("<extension function type {}>", name))
    }
}