use crate::ast::*;
use crate::entities::SchemaType;
use crate::evaluator;
use std::any::Any;
use std::collections::{BTreeSet, HashMap};
use std::fmt::Debug;
use std::panic::{RefUnwindSafe, UnwindSafe};
use std::sync::Arc;
/// A named collection of extension functions, together with the names of the
/// extension types this extension lists as supporting operator overloading.
pub struct Extension {
    /// The extension's own name.
    name: Name,
    /// The extension's functions, keyed by function name.
    functions: HashMap<Name, ExtensionFunction>,
    /// Names of extension types declared to support operator overloading.
    types_with_operator_overloading: BTreeSet<Name>,
}
impl Extension {
pub fn new(
name: Name,
functions: impl IntoIterator<Item = ExtensionFunction>,
types_with_operator_overloading: impl IntoIterator<Item = Name>,
) -> Self {
Self {
name,
functions: functions.into_iter().map(|f| (f.name.clone(), f)).collect(),
types_with_operator_overloading: types_with_operator_overloading.into_iter().collect(),
}
}
pub fn name(&self) -> &Name {
&self.name
}
pub fn get_func(&self, name: &Name) -> Option<&ExtensionFunction> {
self.functions.get(name)
}
pub fn funcs(&self) -> impl Iterator<Item = &ExtensionFunction> {
self.functions.values()
}
pub fn ext_types(&self) -> impl Iterator<Item = &Name> + '_ {
self.funcs().flat_map(|func| func.ext_types())
}
pub fn types_with_operator_overloading(&self) -> impl Iterator<Item = &Name> + '_ {
self.types_with_operator_overloading.iter()
}
}
/// Compact debug form: only the extension's name is shown.
impl std::fmt::Debug for Extension {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let ext_name = &self.name;
        write!(f, "<extension {}>", ext_name)
    }
}
/// What an extension function produces: either a concrete value, or an
/// unknown (which `ExtensionFunction::call` turns into a residual expression).
#[derive(Debug, Clone)]
pub enum ExtensionOutputValue {
    /// A fully evaluated value.
    Known(Value),
    /// An unknown; the call result becomes a residual.
    Unknown(Unknown),
}
/// Anything convertible into a `Value` becomes a `Known` output.
impl<T> From<T> for ExtensionOutputValue
where
    T: Into<Value>,
{
    fn from(v: T) -> Self {
        let known: Value = v.into();
        Self::Known(known)
    }
}
/// Syntactic style in which an extension function is called.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
pub enum CallStyle {
    /// Called in function style.
    FunctionStyle,
    /// Called in method style.
    MethodStyle,
}
/// Expands to the boxed, thread-safe, `'static` closure type that takes the
/// listed argument types and returns `evaluator::Result<ExtensionOutputValue>`.
macro_rules! extension_function_object {
    ($($arg_ty:ty),*) => {
        Box<
            dyn Fn($($arg_ty,)*) -> evaluator::Result<ExtensionOutputValue>
                + Sync
                + Send
                + 'static
        >
    };
}
/// Uniform implementation type: receives the whole argument slice.
pub type ExtensionFunctionObject = extension_function_object!(&[Value]);
/// Implementation type for a function of zero arguments.
pub type NullaryExtensionFunctionObject = extension_function_object!();
/// Implementation type for a function of one argument.
pub type UnaryExtensionFunctionObject = extension_function_object!(&Value);
/// Implementation type for a function of two arguments.
pub type BinaryExtensionFunctionObject = extension_function_object!(&Value, &Value);
/// Implementation type for a function of three arguments.
pub type TernaryExtensionFunctionObject = extension_function_object!(&Value, &Value, &Value);
/// Implementation type for a function taking a first argument plus a slice of
/// the remaining ones.
pub type VariadicExtensionFunctionObject = extension_function_object!(&Value, &[Value]);
/// A single extension function: its name, call style, implementation, and
/// declared signature.
pub struct ExtensionFunction {
    /// Function name.
    name: Name,
    /// Whether it is called function-style or method-style.
    style: CallStyle,
    /// Implementation; receives the full argument slice and checks arity itself.
    func: ExtensionFunctionObject,
    /// Declared return type; `None` for functions built via `partial_eval_unknown`.
    return_type: Option<SchemaType>,
    /// Declared argument types, in order.
    arg_types: Vec<SchemaType>,
    /// Whether the last entry of `arg_types` may repeat.
    is_variadic: bool,
}
impl ExtensionFunction {
fn new(
name: Name,
style: CallStyle,
func: ExtensionFunctionObject,
return_type: Option<SchemaType>,
arg_types: Vec<SchemaType>,
is_variadic: bool,
) -> Self {
Self {
name,
style,
func,
return_type,
arg_types,
is_variadic,
}
}
pub fn nullary(
name: Name,
style: CallStyle,
func: NullaryExtensionFunctionObject,
return_type: SchemaType,
) -> Self {
Self::new(
name.clone(),
style,
Box::new(move |args: &[Value]| {
if args.is_empty() {
func()
} else {
Err(evaluator::EvaluationError::wrong_num_arguments(
name.clone(),
0,
args.len(),
None, ))
}
}),
Some(return_type),
vec![],
false,
)
}
pub fn partial_eval_unknown(
name: Name,
style: CallStyle,
func: UnaryExtensionFunctionObject,
arg_type: SchemaType,
) -> Self {
Self::new(
name.clone(),
style,
Box::new(move |args: &[Value]| match args.first() {
Some(arg) => func(arg),
None => Err(evaluator::EvaluationError::wrong_num_arguments(
name.clone(),
1,
args.len(),
None, )),
}),
None,
vec![arg_type],
false,
)
}
pub fn unary(
name: Name,
style: CallStyle,
func: UnaryExtensionFunctionObject,
return_type: SchemaType,
arg_type: SchemaType,
) -> Self {
Self::new(
name.clone(),
style,
Box::new(move |args: &[Value]| match &args {
&[arg] => func(arg),
_ => Err(evaluator::EvaluationError::wrong_num_arguments(
name.clone(),
1,
args.len(),
None, )),
}),
Some(return_type),
vec![arg_type],
false,
)
}
pub fn binary(
name: Name,
style: CallStyle,
func: BinaryExtensionFunctionObject,
return_type: SchemaType,
arg_types: (SchemaType, SchemaType),
) -> Self {
Self::new(
name.clone(),
style,
Box::new(move |args: &[Value]| match &args {
&[first, second] => func(first, second),
_ => Err(evaluator::EvaluationError::wrong_num_arguments(
name.clone(),
2,
args.len(),
None, )),
}),
Some(return_type),
vec![arg_types.0, arg_types.1],
false,
)
}
pub fn ternary(
name: Name,
style: CallStyle,
func: TernaryExtensionFunctionObject,
return_type: SchemaType,
arg_types: (SchemaType, SchemaType, SchemaType),
) -> Self {
Self::new(
name.clone(),
style,
Box::new(move |args: &[Value]| match &args {
&[first, second, third] => func(first, second, third),
_ => Err(evaluator::EvaluationError::wrong_num_arguments(
name.clone(),
3,
args.len(),
None, )),
}),
Some(return_type),
vec![arg_types.0, arg_types.1, arg_types.2],
false,
)
}
pub fn variadic(
name: Name,
style: CallStyle,
func: VariadicExtensionFunctionObject,
return_type: SchemaType,
arg_types: (SchemaType, SchemaType),
) -> Self {
Self::new(
name.clone(),
style,
Box::new(move |args: &[Value]| match &args {
#[cfg(feature = "variadic-is-in-range")]
&[first, rest @ ..] => func(first, rest),
#[cfg(not(feature = "variadic-is-in-range"))]
&[first, second] => func(first, std::slice::from_ref(second)),
_ => Err(evaluator::EvaluationError::wrong_num_arguments(
name.clone(),
2,
args.len(),
None, )),
}),
Some(return_type),
vec![arg_types.0, arg_types.1],
#[cfg(feature = "variadic-is-in-range")]
true,
#[cfg(not(feature = "variadic-is-in-range"))]
false,
)
}
pub fn name(&self) -> &Name {
&self.name
}
pub fn style(&self) -> CallStyle {
self.style
}
pub fn return_type(&self) -> Option<&SchemaType> {
self.return_type.as_ref()
}
pub fn arg_types(&self) -> &[SchemaType] {
&self.arg_types
}
pub fn is_variadic(&self) -> bool {
self.is_variadic
}
pub fn is_single_arg_constructor(&self) -> bool {
matches!(self.return_type(), Some(SchemaType::Extension { .. }))
&& matches!(self.arg_types(), [SchemaType::String])
}
pub fn call(&self, args: &[Value]) -> evaluator::Result<PartialValue> {
match (self.func)(args)? {
ExtensionOutputValue::Known(v) => Ok(PartialValue::Value(v)),
ExtensionOutputValue::Unknown(u) => Ok(PartialValue::Residual(Expr::unknown(u))),
}
}
pub fn ext_types(&self) -> impl Iterator<Item = &Name> + '_ {
self.return_type
.iter()
.flat_map(|ret_ty| ret_ty.contained_ext_types())
}
}
/// Compact debug form: only the function's name is shown.
impl std::fmt::Debug for ExtensionFunction {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let fn_name = &self.name;
        write!(f, "<extension function {}>", fn_name)
    }
}
/// A value of some extension type. Implementors must be debuggable,
/// thread-safe (`Send + Sync`), and unwind-safe.
pub trait ExtensionValue: Debug + Send + Sync + UnwindSafe + RefUnwindSafe {
    /// Name of this value's extension type.
    fn typename(&self) -> Name;

    /// Whether this value's type supports operator overloading.
    fn supports_operator_overloading(&self) -> bool;
}
/// Every (sized) extension value has the extension type named by `typename`.
impl<V: ExtensionValue> StaticallyTyped for V {
    fn type_of(&self) -> Type {
        let name = self.typename();
        Type::Extension { name }
    }
}
/// An extension value paired with the extension-function call (`func` applied
/// to `args`) that represents it as a restricted expression.
#[derive(Debug, Clone)]
pub struct RepresentableExtensionValue {
    /// Name of the function used to represent the value.
    pub(crate) func: Name,
    /// Arguments passed to `func` in the representation.
    pub(crate) args: Vec<RestrictedExpr>,
    /// The underlying value; equality and ordering are based on this alone.
    pub(crate) value: Arc<dyn InternalExtensionValue>,
}
impl RepresentableExtensionValue {
    /// Wraps `value` together with the call (`func` applied to `args`) that
    /// represents it.
    pub fn new(
        value: Arc<dyn InternalExtensionValue + Send + Sync>,
        func: Name,
        args: Vec<RestrictedExpr>,
    ) -> Self {
        Self {
            value,
            func,
            args,
        }
    }

    /// Borrows the underlying extension value.
    pub fn value(&self) -> &dyn InternalExtensionValue {
        &*self.value
    }

    /// Name of the underlying value's extension type.
    pub fn typename(&self) -> Name {
        self.value().typename()
    }

    /// Whether the underlying value's type supports operator overloading.
    pub(crate) fn supports_operator_overloading(&self) -> bool {
        self.value().supports_operator_overloading()
    }
}
impl From<RepresentableExtensionValue> for RestrictedExpr {
fn from(val: RepresentableExtensionValue) -> Self {
RestrictedExpr::call_extension_fn(val.func, val.args)
}
}
impl StaticallyTyped for RepresentableExtensionValue {
fn type_of(&self) -> Type {
self.value.type_of()
}
}
/// Equality compares only the wrapped values; `func` and `args` are ignored.
impl PartialEq for RepresentableExtensionValue {
    fn eq(&self, other: &Self) -> bool {
        PartialEq::eq(&*self.value, &*other.value)
    }
}

impl Eq for RepresentableExtensionValue {}
/// Partial ordering is always total; it delegates to `Ord`.
impl PartialOrd for RepresentableExtensionValue {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(Ord::cmp(self, other))
    }
}

/// Ordering compares only the wrapped values; `func` and `args` are ignored.
impl Ord for RepresentableExtensionValue {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        Ord::cmp(&*self.value, &*other.value)
    }
}
/// Object-safe extension-value interface that supports dynamic downcasting,
/// equality, and ordering across trait objects.
pub trait InternalExtensionValue: ExtensionValue {
    /// Upcasts to `&dyn Any` so callers can downcast to the concrete type.
    fn as_any(&self) -> &dyn Any;

    /// Equality against another, possibly differently-typed, extension value.
    fn equals_extvalue(&self, other: &dyn InternalExtensionValue) -> bool;

    /// Total ordering against another, possibly differently-typed, extension value.
    fn cmp_extvalue(&self, other: &dyn InternalExtensionValue) -> std::cmp::Ordering;
}
/// Blanket implementation for concrete extension value types.
///
/// Values of the same concrete type `V` use `V`'s own `Eq`/`Ord`. Values of
/// different concrete types are never equal. They are ordered by typename,
/// with a `TypeId` tie-break so that `cmp` never returns `Equal` for values
/// that `equals_extvalue` considers unequal.
impl<V: 'static + Eq + Ord + ExtensionValue + Send + Sync + Clone> InternalExtensionValue for V {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn equals_extvalue(&self, other: &dyn InternalExtensionValue) -> bool {
        other
            .as_any()
            .downcast_ref::<V>()
            .map(|v| self == v)
            // values of different concrete types are never equal
            .unwrap_or(false)
    }

    fn cmp_extvalue(&self, other: &dyn InternalExtensionValue) -> std::cmp::Ordering {
        match other.as_any().downcast_ref::<V>() {
            Some(v) => self.cmp(v),
            // Different concrete types: order by typename first. If two distinct
            // Rust types report the same typename, break the tie by `TypeId`.
            // Otherwise `cmp` would return `Equal` for values that are not `==`,
            // violating the `Ord`/`Eq` contract (e.g. `BTreeSet` would dedupe
            // them). `TypeId` order is arbitrary but consistent within a run.
            None => self
                .typename()
                .cmp(&other.typename())
                .then_with(|| std::any::TypeId::of::<V>().cmp(&other.as_any().type_id())),
        }
    }
}
impl PartialEq for dyn InternalExtensionValue {
fn eq(&self, other: &Self) -> bool {
self.equals_extvalue(other)
}
}
impl Eq for dyn InternalExtensionValue {}
impl PartialOrd for dyn InternalExtensionValue {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.cmp(other))
}
}
impl Ord for dyn InternalExtensionValue {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
self.cmp_extvalue(other)
}
}
/// A trait-object extension value has the extension type named by `typename`.
impl StaticallyTyped for dyn InternalExtensionValue {
    fn type_of(&self) -> Type {
        let name = self.typename();
        Type::Extension { name }
    }
}