1use crate::ast::*;
18use crate::entities::SchemaType;
19use crate::evaluator;
20use std::any::Any;
21use std::collections::HashMap;
22use std::fmt::{Debug, Display};
23use std::panic::{RefUnwindSafe, UnwindSafe};
24use std::sync::Arc;
25
26pub struct Extension {
33 name: Name,
35 functions: HashMap<Name, ExtensionFunction>,
37}
38
39impl Extension {
40 pub fn new(name: Name, functions: impl IntoIterator<Item = ExtensionFunction>) -> Self {
42 Self {
43 name,
44 functions: functions.into_iter().map(|f| (f.name.clone(), f)).collect(),
45 }
46 }
47
48 pub fn name(&self) -> &Name {
50 &self.name
51 }
52
53 pub fn get_func(&self, name: &Name) -> Option<&ExtensionFunction> {
56 self.functions.get(name)
57 }
58
59 pub fn funcs(&self) -> impl Iterator<Item = &ExtensionFunction> {
61 self.functions.values()
62 }
63}
64
65impl std::fmt::Debug for Extension {
66 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
67 write!(f, "<extension {}>", self.name())
68 }
69}
70
71#[derive(Debug, Clone)]
73pub enum ExtensionOutputValue {
74 Known(Value),
76 Unknown(Unknown),
78}
79
80impl<T> From<T> for ExtensionOutputValue
81where
82 T: Into<Value>,
83{
84 fn from(v: T) -> Self {
85 ExtensionOutputValue::Known(v.into())
86 }
87}
88
/// The syntactic form in which an extension function is invoked.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
pub enum CallStyle {
    /// Invoked with function-call syntax.
    FunctionStyle,
    /// Invoked with method-call syntax.
    MethodStyle,
}
98
99pub type ExtensionFunctionObject =
103 Box<dyn Fn(&[Value]) -> evaluator::Result<ExtensionOutputValue> + Sync + Send + 'static>;
104
105pub struct ExtensionFunction {
108 name: Name,
110 style: CallStyle,
112 func: ExtensionFunctionObject,
115 return_type: Option<SchemaType>,
121 arg_types: Vec<Option<SchemaType>>,
125}
126
127impl ExtensionFunction {
128 fn new(
130 name: Name,
131 style: CallStyle,
132 func: ExtensionFunctionObject,
133 return_type: Option<SchemaType>,
134 arg_types: Vec<Option<SchemaType>>,
135 ) -> Self {
136 Self {
137 name,
138 func,
139 style,
140 return_type,
141 arg_types,
142 }
143 }
144
145 pub fn nullary(
147 name: Name,
148 style: CallStyle,
149 func: Box<dyn Fn() -> evaluator::Result<ExtensionOutputValue> + Sync + Send + 'static>,
150 return_type: SchemaType,
151 ) -> Self {
152 Self::new(
153 name.clone(),
154 style,
155 Box::new(move |args: &[Value]| {
156 if args.is_empty() {
157 func()
158 } else {
159 Err(evaluator::EvaluationError::wrong_num_arguments(
160 name.clone(),
161 0,
162 args.len(),
163 None, ))
165 }
166 }),
167 Some(return_type),
168 vec![],
169 )
170 }
171
172 pub fn unary_never(
174 name: Name,
175 style: CallStyle,
176 func: Box<dyn Fn(Value) -> evaluator::Result<ExtensionOutputValue> + Sync + Send + 'static>,
177 arg_type: Option<SchemaType>,
178 ) -> Self {
179 Self::new(
180 name.clone(),
181 style,
182 Box::new(move |args: &[Value]| match args.first() {
183 Some(arg) => func(arg.clone()),
184 None => Err(evaluator::EvaluationError::wrong_num_arguments(
185 name.clone(),
186 1,
187 args.len(),
188 None, )),
190 }),
191 None,
192 vec![arg_type],
193 )
194 }
195
196 pub fn unary(
198 name: Name,
199 style: CallStyle,
200 func: Box<dyn Fn(Value) -> evaluator::Result<ExtensionOutputValue> + Sync + Send + 'static>,
201 return_type: SchemaType,
202 arg_type: Option<SchemaType>,
203 ) -> Self {
204 Self::new(
205 name.clone(),
206 style,
207 Box::new(move |args: &[Value]| match &args {
208 &[arg] => func(arg.clone()),
209 _ => Err(evaluator::EvaluationError::wrong_num_arguments(
210 name.clone(),
211 1,
212 args.len(),
213 None, )),
215 }),
216 Some(return_type),
217 vec![arg_type],
218 )
219 }
220
221 pub fn binary(
223 name: Name,
224 style: CallStyle,
225 func: Box<
226 dyn Fn(Value, Value) -> evaluator::Result<ExtensionOutputValue> + Sync + Send + 'static,
227 >,
228 return_type: SchemaType,
229 arg_types: (Option<SchemaType>, Option<SchemaType>),
230 ) -> Self {
231 Self::new(
232 name.clone(),
233 style,
234 Box::new(move |args: &[Value]| match &args {
235 &[first, second] => func(first.clone(), second.clone()),
236 _ => Err(evaluator::EvaluationError::wrong_num_arguments(
237 name.clone(),
238 2,
239 args.len(),
240 None, )),
242 }),
243 Some(return_type),
244 vec![arg_types.0, arg_types.1],
245 )
246 }
247
248 pub fn ternary(
250 name: Name,
251 style: CallStyle,
252 func: Box<
253 dyn Fn(Value, Value, Value) -> evaluator::Result<ExtensionOutputValue>
254 + Sync
255 + Send
256 + 'static,
257 >,
258 return_type: SchemaType,
259 arg_types: (Option<SchemaType>, Option<SchemaType>, Option<SchemaType>),
260 ) -> Self {
261 Self::new(
262 name.clone(),
263 style,
264 Box::new(move |args: &[Value]| match &args {
265 &[first, second, third] => func(first.clone(), second.clone(), third.clone()),
266 _ => Err(evaluator::EvaluationError::wrong_num_arguments(
267 name.clone(),
268 3,
269 args.len(),
270 None, )),
272 }),
273 Some(return_type),
274 vec![arg_types.0, arg_types.1, arg_types.2],
275 )
276 }
277
278 pub fn name(&self) -> &Name {
280 &self.name
281 }
282
283 pub fn style(&self) -> CallStyle {
285 self.style
286 }
287
288 pub fn return_type(&self) -> Option<&SchemaType> {
291 self.return_type.as_ref()
292 }
293
294 pub fn arg_types(&self) -> &[Option<SchemaType>] {
299 &self.arg_types
300 }
301
302 pub fn is_constructor(&self) -> bool {
307 matches!(self.return_type(), Some(SchemaType::Extension { .. }))
309 && self.arg_types().iter().all(Option::is_some)
311 && !self.arg_types().iter().any(|ty| matches!(ty, Some(SchemaType::Extension { .. })))
313 }
314
315 pub fn call(&self, args: &[Value]) -> evaluator::Result<PartialValue> {
317 match (self.func)(args)? {
318 ExtensionOutputValue::Known(v) => Ok(PartialValue::Value(v)),
319 ExtensionOutputValue::Unknown(u) => Ok(PartialValue::Residual(Expr::unknown(u))),
320 }
321 }
322}
323
324impl std::fmt::Debug for ExtensionFunction {
325 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
326 write!(f, "<extension function {}>", self.name())
327 }
328}
329
330pub trait ExtensionValue: Debug + Display + Send + Sync + UnwindSafe + RefUnwindSafe {
336 fn typename(&self) -> Name;
341}
342
343impl<V: ExtensionValue> StaticallyTyped for V {
344 fn type_of(&self) -> Type {
345 Type::Extension {
346 name: self.typename(),
347 }
348 }
349}
350
351#[derive(Debug, Clone)]
352pub struct ExtensionValueWithArgs {
356 value: Arc<dyn InternalExtensionValue>,
357 pub(crate) constructor: Name,
358 pub(crate) args: Vec<RestrictedExpr>,
362}
363
364impl ExtensionValueWithArgs {
365 pub fn new(
367 value: Arc<dyn InternalExtensionValue + Send + Sync>,
368 constructor: Name,
369 args: Vec<RestrictedExpr>,
370 ) -> Self {
371 Self {
372 value,
373 constructor,
374 args,
375 }
376 }
377
378 pub fn value(&self) -> &(dyn InternalExtensionValue) {
380 self.value.as_ref()
381 }
382
383 pub fn typename(&self) -> Name {
385 self.value.typename()
386 }
387
388 pub fn constructor_and_args(&self) -> (&Name, &[RestrictedExpr]) {
390 (&self.constructor, &self.args)
391 }
392}
393
394impl From<ExtensionValueWithArgs> for Expr {
395 fn from(val: ExtensionValueWithArgs) -> Self {
396 ExprBuilder::new().call_extension_fn(val.constructor, val.args.into_iter().map(Into::into))
397 }
398}
399
400impl StaticallyTyped for ExtensionValueWithArgs {
401 fn type_of(&self) -> Type {
402 self.value.type_of()
403 }
404}
405
406impl Display for ExtensionValueWithArgs {
407 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
408 write!(f, "{}", self.value)
409 }
410}
411
412impl PartialEq for ExtensionValueWithArgs {
413 fn eq(&self, other: &Self) -> bool {
414 self.value.as_ref() == other.value.as_ref()
416 }
417}
418
419impl Eq for ExtensionValueWithArgs {}
420
421impl PartialOrd for ExtensionValueWithArgs {
422 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
423 Some(self.cmp(other))
424 }
425}
426
427impl Ord for ExtensionValueWithArgs {
428 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
429 self.value.cmp(&other.value)
430 }
431}
432
433pub trait InternalExtensionValue: ExtensionValue {
446 fn as_any(&self) -> &dyn Any;
448 fn equals_extvalue(&self, other: &dyn InternalExtensionValue) -> bool;
451 fn cmp_extvalue(&self, other: &dyn InternalExtensionValue) -> std::cmp::Ordering;
454}
455
456impl<V: 'static + Eq + Ord + ExtensionValue + Send + Sync> InternalExtensionValue for V {
457 fn as_any(&self) -> &dyn Any {
458 self
459 }
460
461 fn equals_extvalue(&self, other: &dyn InternalExtensionValue) -> bool {
462 other
463 .as_any()
464 .downcast_ref::<V>()
465 .map(|v| self == v)
466 .unwrap_or(false) }
468
469 fn cmp_extvalue(&self, other: &dyn InternalExtensionValue) -> std::cmp::Ordering {
470 other
471 .as_any()
472 .downcast_ref::<V>()
473 .map(|v| self.cmp(v))
474 .unwrap_or_else(|| {
475 self.typename().cmp(&other.typename())
478 })
479 }
480}
481
482impl PartialEq for dyn InternalExtensionValue {
483 fn eq(&self, other: &Self) -> bool {
484 self.equals_extvalue(other)
485 }
486}
487
488impl Eq for dyn InternalExtensionValue {}
489
490impl PartialOrd for dyn InternalExtensionValue {
491 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
492 Some(self.cmp(other))
493 }
494}
495
496impl Ord for dyn InternalExtensionValue {
497 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
498 self.cmp_extvalue(other)
499 }
500}
501
502impl StaticallyTyped for dyn InternalExtensionValue {
503 fn type_of(&self) -> Type {
504 Type::Extension {
505 name: self.typename(),
506 }
507 }
508}