1use crate::ast::*;
18use crate::entities::SchemaType;
19use crate::evaluator;
20use std::any::Any;
21use std::collections::{BTreeSet, HashMap};
22use std::fmt::Debug;
23use std::panic::{RefUnwindSafe, UnwindSafe};
24use std::sync::Arc;
25
26pub struct Extension {
33 name: Name,
35 functions: HashMap<Name, ExtensionFunction>,
37 types_with_operator_overloading: BTreeSet<Name>,
39}
40
41impl Extension {
42 pub fn new(
44 name: Name,
45 functions: impl IntoIterator<Item = ExtensionFunction>,
46 types_with_operator_overloading: impl IntoIterator<Item = Name>,
47 ) -> Self {
48 Self {
49 name,
50 functions: functions.into_iter().map(|f| (f.name.clone(), f)).collect(),
51 types_with_operator_overloading: types_with_operator_overloading.into_iter().collect(),
52 }
53 }
54
55 pub fn name(&self) -> &Name {
57 &self.name
58 }
59
60 pub fn get_func(&self, name: &Name) -> Option<&ExtensionFunction> {
63 self.functions.get(name)
64 }
65
66 pub fn funcs(&self) -> impl Iterator<Item = &ExtensionFunction> {
68 self.functions.values()
69 }
70
71 pub fn ext_types(&self) -> impl Iterator<Item = &Name> + '_ {
74 self.funcs().flat_map(|func| func.ext_types())
75 }
76
77 pub fn types_with_operator_overloading(&self) -> impl Iterator<Item = &Name> + '_ {
79 self.types_with_operator_overloading.iter()
80 }
81}
82
83impl std::fmt::Debug for Extension {
84 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
85 write!(f, "<extension {}>", self.name())
86 }
87}
88
89#[derive(Debug, Clone)]
91pub enum ExtensionOutputValue {
92 Known(Value),
94 Unknown(Unknown),
96}
97
98impl<T> From<T> for ExtensionOutputValue
99where
100 T: Into<Value>,
101{
102 fn from(v: T) -> Self {
103 ExtensionOutputValue::Known(v.into())
104 }
105}
106
/// How an extension function is written at its call site.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
pub enum CallStyle {
    /// Called like a free function (presumably `f(x, y)`).
    FunctionStyle,
    /// Called like a method (presumably `x.f(y)`).
    MethodStyle,
}
116
// Expands to the boxed, thread-safe closure type used to implement an
// extension function whose parameters have the given types; e.g.
// `extension_function_object!(&Value)` is a boxed
// `Fn(&Value) -> evaluator::Result<ExtensionOutputValue>`.
macro_rules! extension_function_object {
    ($($arg:ty),*) => {
        Box<dyn Fn($($arg,)*) -> evaluator::Result<ExtensionOutputValue> + Send + Sync + 'static>
    };
}
124
125pub type ExtensionFunctionObject = extension_function_object!(&[Value]);
127pub type NullaryExtensionFunctionObject = extension_function_object!();
129pub type UnaryExtensionFunctionObject = extension_function_object!(&Value);
131pub type BinaryExtensionFunctionObject = extension_function_object!(&Value, &Value);
133pub type TernaryExtensionFunctionObject = extension_function_object!(&Value, &Value, &Value);
135pub type VariadicExtensionFunctionObject = extension_function_object!(&Value, &[Value]);
137
138pub struct ExtensionFunction {
141 name: Name,
143 style: CallStyle,
145 func: ExtensionFunctionObject,
148 return_type: Option<SchemaType>,
156 arg_types: Vec<SchemaType>,
158 is_variadic: bool,
161}
162
163impl ExtensionFunction {
164 fn new(
166 name: Name,
167 style: CallStyle,
168 func: ExtensionFunctionObject,
169 return_type: Option<SchemaType>,
170 arg_types: Vec<SchemaType>,
171 is_variadic: bool,
172 ) -> Self {
173 Self {
174 name,
175 style,
176 func,
177 return_type,
178 arg_types,
179 is_variadic,
180 }
181 }
182
183 pub fn nullary(
185 name: Name,
186 style: CallStyle,
187 func: NullaryExtensionFunctionObject,
188 return_type: SchemaType,
189 ) -> Self {
190 Self::new(
191 name.clone(),
192 style,
193 Box::new(move |args: &[Value]| {
194 if args.is_empty() {
195 func()
196 } else {
197 Err(evaluator::EvaluationError::wrong_num_arguments(
198 name.clone(),
199 0,
200 args.len(),
201 None, ))
203 }
204 }),
205 Some(return_type),
206 vec![],
207 false,
208 )
209 }
210
211 pub fn partial_eval_unknown(
214 name: Name,
215 style: CallStyle,
216 func: UnaryExtensionFunctionObject,
217 arg_type: SchemaType,
218 ) -> Self {
219 Self::new(
220 name.clone(),
221 style,
222 Box::new(move |args: &[Value]| match args.first() {
223 Some(arg) => func(arg),
224 None => Err(evaluator::EvaluationError::wrong_num_arguments(
225 name.clone(),
226 1,
227 args.len(),
228 None, )),
230 }),
231 None,
232 vec![arg_type],
233 false,
234 )
235 }
236
237 pub fn unary(
239 name: Name,
240 style: CallStyle,
241 func: UnaryExtensionFunctionObject,
242 return_type: SchemaType,
243 arg_type: SchemaType,
244 ) -> Self {
245 Self::new(
246 name.clone(),
247 style,
248 Box::new(move |args: &[Value]| match &args {
249 &[arg] => func(arg),
250 _ => Err(evaluator::EvaluationError::wrong_num_arguments(
251 name.clone(),
252 1,
253 args.len(),
254 None, )),
256 }),
257 Some(return_type),
258 vec![arg_type],
259 false,
260 )
261 }
262
263 pub fn binary(
265 name: Name,
266 style: CallStyle,
267 func: BinaryExtensionFunctionObject,
268 return_type: SchemaType,
269 arg_types: (SchemaType, SchemaType),
270 ) -> Self {
271 Self::new(
272 name.clone(),
273 style,
274 Box::new(move |args: &[Value]| match &args {
275 &[first, second] => func(first, second),
276 _ => Err(evaluator::EvaluationError::wrong_num_arguments(
277 name.clone(),
278 2,
279 args.len(),
280 None, )),
282 }),
283 Some(return_type),
284 vec![arg_types.0, arg_types.1],
285 false,
286 )
287 }
288
289 pub fn ternary(
291 name: Name,
292 style: CallStyle,
293 func: TernaryExtensionFunctionObject,
294 return_type: SchemaType,
295 arg_types: (SchemaType, SchemaType, SchemaType),
296 ) -> Self {
297 Self::new(
298 name.clone(),
299 style,
300 Box::new(move |args: &[Value]| match &args {
301 &[first, second, third] => func(first, second, third),
302 _ => Err(evaluator::EvaluationError::wrong_num_arguments(
303 name.clone(),
304 3,
305 args.len(),
306 None, )),
308 }),
309 Some(return_type),
310 vec![arg_types.0, arg_types.1, arg_types.2],
311 false,
312 )
313 }
314
315 pub fn variadic(
317 name: Name,
318 style: CallStyle,
319 func: VariadicExtensionFunctionObject,
320 return_type: SchemaType,
321 arg_types: (SchemaType, SchemaType),
322 ) -> Self {
323 Self::new(
324 name.clone(),
325 style,
326 Box::new(move |args: &[Value]| match &args {
327 #[cfg(feature = "variadic-is-in-range")]
328 &[first, rest @ ..] => func(first, rest),
329 #[cfg(not(feature = "variadic-is-in-range"))]
330 &[first, second] => func(first, std::slice::from_ref(second)),
331 _ => Err(evaluator::EvaluationError::wrong_num_arguments(
332 name.clone(),
333 2,
334 args.len(),
335 None, )),
337 }),
338 Some(return_type),
339 vec![arg_types.0, arg_types.1],
340 #[cfg(feature = "variadic-is-in-range")]
341 true,
342 #[cfg(not(feature = "variadic-is-in-range"))]
343 false,
344 )
345 }
346
347 pub fn name(&self) -> &Name {
349 &self.name
350 }
351
352 pub fn style(&self) -> CallStyle {
354 self.style
355 }
356
357 pub fn return_type(&self) -> Option<&SchemaType> {
361 self.return_type.as_ref()
362 }
363
364 pub fn arg_types(&self) -> &[SchemaType] {
366 &self.arg_types
367 }
368
369 pub fn is_variadic(&self) -> bool {
371 self.is_variadic
372 }
373
374 pub fn is_single_arg_constructor(&self) -> bool {
380 matches!(self.return_type(), Some(SchemaType::Extension { .. }))
382 && matches!(self.arg_types(), [SchemaType::String])
384 }
385
386 pub fn call(&self, args: &[Value]) -> evaluator::Result<PartialValue> {
388 match (self.func)(args)? {
389 ExtensionOutputValue::Known(v) => Ok(PartialValue::Value(v)),
390 ExtensionOutputValue::Unknown(u) => Ok(PartialValue::Residual(Expr::unknown(u))),
391 }
392 }
393
394 pub fn ext_types(&self) -> impl Iterator<Item = &Name> + '_ {
397 self.return_type
398 .iter()
399 .flat_map(|ret_ty| ret_ty.contained_ext_types())
400 }
401}
402
403impl std::fmt::Debug for ExtensionFunction {
404 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
405 write!(f, "<extension function {}>", self.name())
406 }
407}
408
409pub trait ExtensionValue: Debug + Send + Sync + UnwindSafe + RefUnwindSafe {
415 fn typename(&self) -> Name;
420
421 fn supports_operator_overloading(&self) -> bool;
423}
424
425impl<V: ExtensionValue> StaticallyTyped for V {
426 fn type_of(&self) -> Type {
427 Type::Extension {
428 name: self.typename(),
429 }
430 }
431}
432
433#[derive(Debug, Clone)]
434pub struct RepresentableExtensionValue {
441 pub(crate) func: Name,
442 pub(crate) args: Vec<RestrictedExpr>,
443 pub(crate) value: Arc<dyn InternalExtensionValue>,
444}
445
446impl RepresentableExtensionValue {
447 pub fn new(
449 value: Arc<dyn InternalExtensionValue + Send + Sync>,
450 func: Name,
451 args: Vec<RestrictedExpr>,
452 ) -> Self {
453 Self { func, args, value }
454 }
455
456 pub fn value(&self) -> &dyn InternalExtensionValue {
458 self.value.as_ref()
459 }
460
461 pub fn typename(&self) -> Name {
463 self.value.typename()
464 }
465
466 pub(crate) fn supports_operator_overloading(&self) -> bool {
468 self.value.supports_operator_overloading()
469 }
470}
471
472impl From<RepresentableExtensionValue> for RestrictedExpr {
473 fn from(val: RepresentableExtensionValue) -> Self {
474 RestrictedExpr::call_extension_fn(val.func, val.args)
475 }
476}
477
478impl StaticallyTyped for RepresentableExtensionValue {
479 fn type_of(&self) -> Type {
480 self.value.type_of()
481 }
482}
483
484impl PartialEq for RepresentableExtensionValue {
485 fn eq(&self, other: &Self) -> bool {
486 self.value.as_ref() == other.value.as_ref()
488 }
489}
490
491impl Eq for RepresentableExtensionValue {}
492
493impl PartialOrd for RepresentableExtensionValue {
494 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
495 Some(self.cmp(other))
496 }
497}
498
499impl Ord for RepresentableExtensionValue {
500 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
501 self.value.cmp(&other.value)
502 }
503}
504
505pub trait InternalExtensionValue: ExtensionValue {
518 fn as_any(&self) -> &dyn Any;
520 fn equals_extvalue(&self, other: &dyn InternalExtensionValue) -> bool;
523 fn cmp_extvalue(&self, other: &dyn InternalExtensionValue) -> std::cmp::Ordering;
526}
527
528impl<V: 'static + Eq + Ord + ExtensionValue + Send + Sync + Clone> InternalExtensionValue for V {
529 fn as_any(&self) -> &dyn Any {
530 self
531 }
532
533 fn equals_extvalue(&self, other: &dyn InternalExtensionValue) -> bool {
534 other
535 .as_any()
536 .downcast_ref::<V>()
537 .map(|v| self == v)
538 .unwrap_or(false) }
540
541 fn cmp_extvalue(&self, other: &dyn InternalExtensionValue) -> std::cmp::Ordering {
542 other
543 .as_any()
544 .downcast_ref::<V>()
545 .map(|v| self.cmp(v))
546 .unwrap_or_else(|| {
547 self.typename().cmp(&other.typename())
550 })
551 }
552}
553
554impl PartialEq for dyn InternalExtensionValue {
555 fn eq(&self, other: &Self) -> bool {
556 self.equals_extvalue(other)
557 }
558}
559
560impl Eq for dyn InternalExtensionValue {}
561
562impl PartialOrd for dyn InternalExtensionValue {
563 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
564 Some(self.cmp(other))
565 }
566}
567
568impl Ord for dyn InternalExtensionValue {
569 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
570 self.cmp_extvalue(other)
571 }
572}
573
574impl StaticallyTyped for dyn InternalExtensionValue {
575 fn type_of(&self) -> Type {
576 Type::Extension {
577 name: self.typename(),
578 }
579 }
580}