1use crate::ast::*;
18use crate::entities::SchemaType;
19use crate::evaluator;
20use std::any::Any;
21use std::collections::{BTreeSet, HashMap};
22use std::fmt::Debug;
23use std::panic::{RefUnwindSafe, UnwindSafe};
24use std::sync::Arc;
25
26pub struct Extension {
33 name: Name,
35 functions: HashMap<Name, ExtensionFunction>,
37 types_with_operator_overloading: BTreeSet<Name>,
39}
40
41impl Extension {
42 pub fn new(
44 name: Name,
45 functions: impl IntoIterator<Item = ExtensionFunction>,
46 types_with_operator_overloading: impl IntoIterator<Item = Name>,
47 ) -> Self {
48 Self {
49 name,
50 functions: functions.into_iter().map(|f| (f.name.clone(), f)).collect(),
51 types_with_operator_overloading: types_with_operator_overloading.into_iter().collect(),
52 }
53 }
54
55 pub fn name(&self) -> &Name {
57 &self.name
58 }
59
60 pub fn get_func(&self, name: &Name) -> Option<&ExtensionFunction> {
63 self.functions.get(name)
64 }
65
66 pub fn funcs(&self) -> impl Iterator<Item = &ExtensionFunction> {
68 self.functions.values()
69 }
70
71 pub fn ext_types(&self) -> impl Iterator<Item = &Name> + '_ {
74 self.funcs().flat_map(|func| func.ext_types())
75 }
76
77 pub fn types_with_operator_overloading(&self) -> impl Iterator<Item = &Name> + '_ {
79 self.types_with_operator_overloading.iter()
80 }
81}
82
83impl std::fmt::Debug for Extension {
84 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
85 write!(f, "<extension {}>", self.name())
86 }
87}
88
89#[derive(Debug, Clone)]
91pub enum ExtensionOutputValue {
92 Known(Value),
94 Unknown(Unknown),
96}
97
98impl<T> From<T> for ExtensionOutputValue
99where
100 T: Into<Value>,
101{
102 fn from(v: T) -> Self {
103 ExtensionOutputValue::Known(v.into())
104 }
105}
106
/// The syntactic form in which an extension function is invoked.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
pub enum CallStyle {
    /// Presumably plain function-call syntax, e.g. `f(a, b)`.
    FunctionStyle,
    /// Presumably method-call syntax on the first argument, e.g. `a.f(b)`.
    MethodStyle,
}
116
/// Expands to the boxed, thread-safe closure type taking the listed argument
/// types and returning `evaluator::Result<ExtensionOutputValue>`.
macro_rules! extension_function_object {
    ($($tys:ty),*) => {
        Box<dyn Fn($($tys,)*) -> evaluator::Result<ExtensionOutputValue> + Sync + Send + 'static>
    };
}
124
125pub type ExtensionFunctionObject = extension_function_object!(&[Value]);
127pub type NullaryExtensionFunctionObject = extension_function_object!();
129pub type UnaryExtensionFunctionObject = extension_function_object!(&Value);
131pub type BinaryExtensionFunctionObject = extension_function_object!(&Value, &Value);
133pub type TernaryExtensionFunctionObject = extension_function_object!(&Value, &Value, &Value);
135pub type VariadicExtensionFunctionObject = extension_function_object!(&Value, &[Value]);
137
138pub struct ExtensionFunction {
141 name: Name,
143 style: CallStyle,
145 func: ExtensionFunctionObject,
148 return_type: Option<SchemaType>,
156 arg_types: Vec<SchemaType>,
158 is_variadic: bool,
161}
162
163impl ExtensionFunction {
164 fn new(
166 name: Name,
167 style: CallStyle,
168 func: ExtensionFunctionObject,
169 return_type: Option<SchemaType>,
170 arg_types: Vec<SchemaType>,
171 is_variadic: bool,
172 ) -> Self {
173 Self {
174 name,
175 style,
176 func,
177 return_type,
178 arg_types,
179 is_variadic,
180 }
181 }
182
183 pub fn nullary(
185 name: Name,
186 style: CallStyle,
187 func: NullaryExtensionFunctionObject,
188 return_type: SchemaType,
189 ) -> Self {
190 Self::new(
191 name.clone(),
192 style,
193 Box::new(move |args: &[Value]| {
194 if args.is_empty() {
195 func()
196 } else {
197 Err(evaluator::EvaluationError::wrong_num_arguments(
198 name.clone(),
199 0,
200 args.len(),
201 None, ))
203 }
204 }),
205 Some(return_type),
206 vec![],
207 false,
208 )
209 }
210
211 pub fn partial_eval_unknown(
214 name: Name,
215 style: CallStyle,
216 func: UnaryExtensionFunctionObject,
217 arg_type: SchemaType,
218 ) -> Self {
219 Self::new(
220 name.clone(),
221 style,
222 Box::new(move |args: &[Value]| match args.first() {
223 Some(arg) => func(arg),
224 None => Err(evaluator::EvaluationError::wrong_num_arguments(
225 name.clone(),
226 1,
227 args.len(),
228 None, )),
230 }),
231 None,
232 vec![arg_type],
233 false,
234 )
235 }
236
237 #[allow(clippy::type_complexity)]
239 pub fn unary(
240 name: Name,
241 style: CallStyle,
242 func: UnaryExtensionFunctionObject,
243 return_type: SchemaType,
244 arg_type: SchemaType,
245 ) -> Self {
246 Self::new(
247 name.clone(),
248 style,
249 Box::new(move |args: &[Value]| match &args {
250 &[arg] => func(arg),
251 _ => Err(evaluator::EvaluationError::wrong_num_arguments(
252 name.clone(),
253 1,
254 args.len(),
255 None, )),
257 }),
258 Some(return_type),
259 vec![arg_type],
260 false,
261 )
262 }
263
264 #[allow(clippy::type_complexity)]
266 pub fn binary(
267 name: Name,
268 style: CallStyle,
269 func: BinaryExtensionFunctionObject,
270 return_type: SchemaType,
271 arg_types: (SchemaType, SchemaType),
272 ) -> Self {
273 Self::new(
274 name.clone(),
275 style,
276 Box::new(move |args: &[Value]| match &args {
277 &[first, second] => func(first, second),
278 _ => Err(evaluator::EvaluationError::wrong_num_arguments(
279 name.clone(),
280 2,
281 args.len(),
282 None, )),
284 }),
285 Some(return_type),
286 vec![arg_types.0, arg_types.1],
287 false,
288 )
289 }
290
291 #[allow(clippy::type_complexity)]
293 pub fn ternary(
294 name: Name,
295 style: CallStyle,
296 func: TernaryExtensionFunctionObject,
297 return_type: SchemaType,
298 arg_types: (SchemaType, SchemaType, SchemaType),
299 ) -> Self {
300 Self::new(
301 name.clone(),
302 style,
303 Box::new(move |args: &[Value]| match &args {
304 &[first, second, third] => func(first, second, third),
305 _ => Err(evaluator::EvaluationError::wrong_num_arguments(
306 name.clone(),
307 3,
308 args.len(),
309 None, )),
311 }),
312 Some(return_type),
313 vec![arg_types.0, arg_types.1, arg_types.2],
314 false,
315 )
316 }
317
318 #[allow(clippy::type_complexity)]
320 pub fn variadic(
321 name: Name,
322 style: CallStyle,
323 func: VariadicExtensionFunctionObject,
324 return_type: SchemaType,
325 arg_types: (SchemaType, SchemaType),
326 ) -> Self {
327 Self::new(
328 name.clone(),
329 style,
330 Box::new(move |args: &[Value]| match &args {
331 #[cfg(feature = "variadic-is-in-range")]
332 &[first, rest @ ..] => func(first, rest),
333 #[cfg(not(feature = "variadic-is-in-range"))]
334 &[first, second] => func(first, std::slice::from_ref(second)),
335 _ => Err(evaluator::EvaluationError::wrong_num_arguments(
336 name.clone(),
337 2,
338 args.len(),
339 None, )),
341 }),
342 Some(return_type),
343 vec![arg_types.0, arg_types.1],
344 #[cfg(feature = "variadic-is-in-range")]
345 true,
346 #[cfg(not(feature = "variadic-is-in-range"))]
347 false,
348 )
349 }
350
351 pub fn name(&self) -> &Name {
353 &self.name
354 }
355
356 pub fn style(&self) -> CallStyle {
358 self.style
359 }
360
361 pub fn return_type(&self) -> Option<&SchemaType> {
365 self.return_type.as_ref()
366 }
367
368 pub fn arg_types(&self) -> &[SchemaType] {
370 &self.arg_types
371 }
372
373 pub fn is_variadic(&self) -> bool {
375 self.is_variadic
376 }
377
378 pub fn is_single_arg_constructor(&self) -> bool {
384 matches!(self.return_type(), Some(SchemaType::Extension { .. }))
386 && matches!(self.arg_types(), [SchemaType::String])
388 }
389
390 pub fn call(&self, args: &[Value]) -> evaluator::Result<PartialValue> {
392 match (self.func)(args)? {
393 ExtensionOutputValue::Known(v) => Ok(PartialValue::Value(v)),
394 ExtensionOutputValue::Unknown(u) => Ok(PartialValue::Residual(Expr::unknown(u))),
395 }
396 }
397
398 pub fn ext_types(&self) -> impl Iterator<Item = &Name> + '_ {
401 self.return_type
402 .iter()
403 .flat_map(|ret_ty| ret_ty.contained_ext_types())
404 }
405}
406
407impl std::fmt::Debug for ExtensionFunction {
408 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
409 write!(f, "<extension function {}>", self.name())
410 }
411}
412
413pub trait ExtensionValue: Debug + Send + Sync + UnwindSafe + RefUnwindSafe {
419 fn typename(&self) -> Name;
424
425 fn supports_operator_overloading(&self) -> bool;
427}
428
429impl<V: ExtensionValue> StaticallyTyped for V {
430 fn type_of(&self) -> Type {
431 Type::Extension {
432 name: self.typename(),
433 }
434 }
435}
436
437#[derive(Debug, Clone)]
438pub struct RepresentableExtensionValue {
445 pub(crate) func: Name,
446 pub(crate) args: Vec<RestrictedExpr>,
447 pub(crate) value: Arc<dyn InternalExtensionValue>,
448}
449
450impl RepresentableExtensionValue {
451 pub fn new(
453 value: Arc<dyn InternalExtensionValue + Send + Sync>,
454 func: Name,
455 args: Vec<RestrictedExpr>,
456 ) -> Self {
457 Self { func, args, value }
458 }
459
460 pub fn value(&self) -> &dyn InternalExtensionValue {
462 self.value.as_ref()
463 }
464
465 pub fn typename(&self) -> Name {
467 self.value.typename()
468 }
469
470 pub(crate) fn supports_operator_overloading(&self) -> bool {
472 self.value.supports_operator_overloading()
473 }
474}
475
476impl From<RepresentableExtensionValue> for RestrictedExpr {
477 fn from(val: RepresentableExtensionValue) -> Self {
478 RestrictedExpr::call_extension_fn(val.func, val.args)
479 }
480}
481
482impl StaticallyTyped for RepresentableExtensionValue {
483 fn type_of(&self) -> Type {
484 self.value.type_of()
485 }
486}
487
488impl PartialEq for RepresentableExtensionValue {
489 fn eq(&self, other: &Self) -> bool {
490 self.value.as_ref() == other.value.as_ref()
492 }
493}
494
495impl Eq for RepresentableExtensionValue {}
496
497impl PartialOrd for RepresentableExtensionValue {
498 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
499 Some(self.cmp(other))
500 }
501}
502
503impl Ord for RepresentableExtensionValue {
504 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
505 self.value.cmp(&other.value)
506 }
507}
508
509pub trait InternalExtensionValue: ExtensionValue {
522 fn as_any(&self) -> &dyn Any;
524 fn equals_extvalue(&self, other: &dyn InternalExtensionValue) -> bool;
527 fn cmp_extvalue(&self, other: &dyn InternalExtensionValue) -> std::cmp::Ordering;
530}
531
532impl<V: 'static + Eq + Ord + ExtensionValue + Send + Sync + Clone> InternalExtensionValue for V {
533 fn as_any(&self) -> &dyn Any {
534 self
535 }
536
537 fn equals_extvalue(&self, other: &dyn InternalExtensionValue) -> bool {
538 other
539 .as_any()
540 .downcast_ref::<V>()
541 .map(|v| self == v)
542 .unwrap_or(false) }
544
545 fn cmp_extvalue(&self, other: &dyn InternalExtensionValue) -> std::cmp::Ordering {
546 other
547 .as_any()
548 .downcast_ref::<V>()
549 .map(|v| self.cmp(v))
550 .unwrap_or_else(|| {
551 self.typename().cmp(&other.typename())
554 })
555 }
556}
557
558impl PartialEq for dyn InternalExtensionValue {
559 fn eq(&self, other: &Self) -> bool {
560 self.equals_extvalue(other)
561 }
562}
563
564impl Eq for dyn InternalExtensionValue {}
565
566impl PartialOrd for dyn InternalExtensionValue {
567 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
568 Some(self.cmp(other))
569 }
570}
571
572impl Ord for dyn InternalExtensionValue {
573 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
574 self.cmp_extvalue(other)
575 }
576}
577
578impl StaticallyTyped for dyn InternalExtensionValue {
579 fn type_of(&self) -> Type {
580 Type::Extension {
581 name: self.typename(),
582 }
583 }
584}