1use crate::ast::*;
18use crate::entities::SchemaType;
19use crate::evaluator;
20use std::any::Any;
21use std::collections::HashMap;
22use std::fmt::{Debug, Display};
23use std::panic::{RefUnwindSafe, UnwindSafe};
24use std::sync::Arc;
25
26pub struct Extension {
33 name: Name,
35 functions: HashMap<Name, ExtensionFunction>,
37}
38
39impl Extension {
40 pub fn new(name: Name, functions: impl IntoIterator<Item = ExtensionFunction>) -> Self {
42 Self {
43 name,
44 functions: functions.into_iter().map(|f| (f.name.clone(), f)).collect(),
45 }
46 }
47
48 pub fn name(&self) -> &Name {
50 &self.name
51 }
52
53 pub fn get_func(&self, name: &Name) -> Option<&ExtensionFunction> {
56 self.functions.get(name)
57 }
58
59 pub fn funcs(&self) -> impl Iterator<Item = &ExtensionFunction> {
61 self.functions.values()
62 }
63
64 pub fn ext_types(&self) -> impl Iterator<Item = &Name> + '_ {
67 self.funcs().flat_map(|func| func.ext_types())
68 }
69}
70
71impl std::fmt::Debug for Extension {
72 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
73 write!(f, "<extension {}>", self.name())
74 }
75}
76
77#[derive(Debug, Clone)]
79pub enum ExtensionOutputValue {
80 Known(Value),
82 Unknown(Unknown),
84}
85
86impl<T> From<T> for ExtensionOutputValue
87where
88 T: Into<Value>,
89{
90 fn from(v: T) -> Self {
91 ExtensionOutputValue::Known(v.into())
92 }
93}
94
/// How an extension function is written at the call site: function-call style
/// or method-call style, as the variant names indicate.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
pub enum CallStyle {
    /// Invoked in function-call style.
    FunctionStyle,
    /// Invoked in method-call style.
    MethodStyle,
}
104
105pub type ExtensionFunctionObject =
109 Box<dyn Fn(&[Value]) -> evaluator::Result<ExtensionOutputValue> + Sync + Send + 'static>;
110
111pub struct ExtensionFunction {
114 name: Name,
116 style: CallStyle,
118 func: ExtensionFunctionObject,
121 return_type: Option<SchemaType>,
129 arg_types: Vec<SchemaType>,
131}
132
133impl ExtensionFunction {
134 fn new(
136 name: Name,
137 style: CallStyle,
138 func: ExtensionFunctionObject,
139 return_type: Option<SchemaType>,
140 arg_types: Vec<SchemaType>,
141 ) -> Self {
142 Self {
143 name,
144 func,
145 style,
146 return_type,
147 arg_types,
148 }
149 }
150
151 pub fn nullary(
153 name: Name,
154 style: CallStyle,
155 func: Box<dyn Fn() -> evaluator::Result<ExtensionOutputValue> + Sync + Send + 'static>,
156 return_type: SchemaType,
157 ) -> Self {
158 Self::new(
159 name.clone(),
160 style,
161 Box::new(move |args: &[Value]| {
162 if args.is_empty() {
163 func()
164 } else {
165 Err(evaluator::EvaluationError::wrong_num_arguments(
166 name.clone(),
167 0,
168 args.len(),
169 None, ))
171 }
172 }),
173 Some(return_type),
174 vec![],
175 )
176 }
177
178 pub fn partial_eval_unknown(
181 name: Name,
182 style: CallStyle,
183 func: Box<dyn Fn(Value) -> evaluator::Result<ExtensionOutputValue> + Sync + Send + 'static>,
184 arg_type: SchemaType,
185 ) -> Self {
186 Self::new(
187 name.clone(),
188 style,
189 Box::new(move |args: &[Value]| match args.first() {
190 Some(arg) => func(arg.clone()),
191 None => Err(evaluator::EvaluationError::wrong_num_arguments(
192 name.clone(),
193 1,
194 args.len(),
195 None, )),
197 }),
198 None,
199 vec![arg_type],
200 )
201 }
202
203 pub fn unary(
205 name: Name,
206 style: CallStyle,
207 func: Box<dyn Fn(Value) -> evaluator::Result<ExtensionOutputValue> + Sync + Send + 'static>,
208 return_type: SchemaType,
209 arg_type: SchemaType,
210 ) -> Self {
211 Self::new(
212 name.clone(),
213 style,
214 Box::new(move |args: &[Value]| match &args {
215 &[arg] => func(arg.clone()),
216 _ => Err(evaluator::EvaluationError::wrong_num_arguments(
217 name.clone(),
218 1,
219 args.len(),
220 None, )),
222 }),
223 Some(return_type),
224 vec![arg_type],
225 )
226 }
227
228 pub fn binary(
230 name: Name,
231 style: CallStyle,
232 func: Box<
233 dyn Fn(Value, Value) -> evaluator::Result<ExtensionOutputValue> + Sync + Send + 'static,
234 >,
235 return_type: SchemaType,
236 arg_types: (SchemaType, SchemaType),
237 ) -> Self {
238 Self::new(
239 name.clone(),
240 style,
241 Box::new(move |args: &[Value]| match &args {
242 &[first, second] => func(first.clone(), second.clone()),
243 _ => Err(evaluator::EvaluationError::wrong_num_arguments(
244 name.clone(),
245 2,
246 args.len(),
247 None, )),
249 }),
250 Some(return_type),
251 vec![arg_types.0, arg_types.1],
252 )
253 }
254
255 pub fn ternary(
257 name: Name,
258 style: CallStyle,
259 func: Box<
260 dyn Fn(Value, Value, Value) -> evaluator::Result<ExtensionOutputValue>
261 + Sync
262 + Send
263 + 'static,
264 >,
265 return_type: SchemaType,
266 arg_types: (SchemaType, SchemaType, SchemaType),
267 ) -> Self {
268 Self::new(
269 name.clone(),
270 style,
271 Box::new(move |args: &[Value]| match &args {
272 &[first, second, third] => func(first.clone(), second.clone(), third.clone()),
273 _ => Err(evaluator::EvaluationError::wrong_num_arguments(
274 name.clone(),
275 3,
276 args.len(),
277 None, )),
279 }),
280 Some(return_type),
281 vec![arg_types.0, arg_types.1, arg_types.2],
282 )
283 }
284
285 pub fn name(&self) -> &Name {
287 &self.name
288 }
289
290 pub fn style(&self) -> CallStyle {
292 self.style
293 }
294
295 pub fn return_type(&self) -> Option<&SchemaType> {
299 self.return_type.as_ref()
300 }
301
302 pub fn arg_types(&self) -> &[SchemaType] {
304 &self.arg_types
305 }
306
307 pub fn is_constructor(&self) -> bool {
312 matches!(self.return_type(), Some(SchemaType::Extension { .. }))
314 && !self.arg_types().iter().any(|ty| matches!(ty, SchemaType::Extension { .. }))
316 }
317
318 pub fn call(&self, args: &[Value]) -> evaluator::Result<PartialValue> {
320 match (self.func)(args)? {
321 ExtensionOutputValue::Known(v) => Ok(PartialValue::Value(v)),
322 ExtensionOutputValue::Unknown(u) => Ok(PartialValue::Residual(Expr::unknown(u))),
323 }
324 }
325
326 pub fn ext_types(&self) -> impl Iterator<Item = &Name> + '_ {
329 self.return_type
330 .iter()
331 .flat_map(|ret_ty| ret_ty.contained_ext_types())
332 }
333}
334
335impl std::fmt::Debug for ExtensionFunction {
336 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
337 write!(f, "<extension function {}>", self.name())
338 }
339}
340
341pub trait ExtensionValue: Debug + Display + Send + Sync + UnwindSafe + RefUnwindSafe {
347 fn typename(&self) -> Name;
352}
353
354impl<V: ExtensionValue> StaticallyTyped for V {
355 fn type_of(&self) -> Type {
356 Type::Extension {
357 name: self.typename(),
358 }
359 }
360}
361
362#[derive(Debug, Clone)]
363pub struct ExtensionValueWithArgs {
367 value: Arc<dyn InternalExtensionValue>,
368 pub(crate) constructor: Name,
369 pub(crate) args: Vec<RestrictedExpr>,
373}
374
375impl ExtensionValueWithArgs {
376 pub fn new(
378 value: Arc<dyn InternalExtensionValue + Send + Sync>,
379 constructor: Name,
380 args: Vec<RestrictedExpr>,
381 ) -> Self {
382 Self {
383 value,
384 constructor,
385 args,
386 }
387 }
388
389 pub fn value(&self) -> &(dyn InternalExtensionValue) {
391 self.value.as_ref()
392 }
393
394 pub fn typename(&self) -> Name {
396 self.value.typename()
397 }
398
399 pub fn constructor_and_args(&self) -> (&Name, &[RestrictedExpr]) {
401 (&self.constructor, &self.args)
402 }
403}
404
405impl From<ExtensionValueWithArgs> for Expr {
406 fn from(val: ExtensionValueWithArgs) -> Self {
407 ExprBuilder::new().call_extension_fn(val.constructor, val.args.into_iter().map(Into::into))
408 }
409}
410
411impl StaticallyTyped for ExtensionValueWithArgs {
412 fn type_of(&self) -> Type {
413 self.value.type_of()
414 }
415}
416
417impl Display for ExtensionValueWithArgs {
418 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
419 write!(f, "{}", self.value)
420 }
421}
422
423impl PartialEq for ExtensionValueWithArgs {
424 fn eq(&self, other: &Self) -> bool {
425 self.value.as_ref() == other.value.as_ref()
427 }
428}
429
430impl Eq for ExtensionValueWithArgs {}
431
432impl PartialOrd for ExtensionValueWithArgs {
433 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
434 Some(self.cmp(other))
435 }
436}
437
438impl Ord for ExtensionValueWithArgs {
439 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
440 self.value.cmp(&other.value)
441 }
442}
443
444pub trait InternalExtensionValue: ExtensionValue {
457 fn as_any(&self) -> &dyn Any;
459 fn equals_extvalue(&self, other: &dyn InternalExtensionValue) -> bool;
462 fn cmp_extvalue(&self, other: &dyn InternalExtensionValue) -> std::cmp::Ordering;
465}
466
467impl<V: 'static + Eq + Ord + ExtensionValue + Send + Sync> InternalExtensionValue for V {
468 fn as_any(&self) -> &dyn Any {
469 self
470 }
471
472 fn equals_extvalue(&self, other: &dyn InternalExtensionValue) -> bool {
473 other
474 .as_any()
475 .downcast_ref::<V>()
476 .map(|v| self == v)
477 .unwrap_or(false) }
479
480 fn cmp_extvalue(&self, other: &dyn InternalExtensionValue) -> std::cmp::Ordering {
481 other
482 .as_any()
483 .downcast_ref::<V>()
484 .map(|v| self.cmp(v))
485 .unwrap_or_else(|| {
486 self.typename().cmp(&other.typename())
489 })
490 }
491}
492
493impl PartialEq for dyn InternalExtensionValue {
494 fn eq(&self, other: &Self) -> bool {
495 self.equals_extvalue(other)
496 }
497}
498
499impl Eq for dyn InternalExtensionValue {}
500
501impl PartialOrd for dyn InternalExtensionValue {
502 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
503 Some(self.cmp(other))
504 }
505}
506
507impl Ord for dyn InternalExtensionValue {
508 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
509 self.cmp_extvalue(other)
510 }
511}
512
513impl StaticallyTyped for dyn InternalExtensionValue {
514 fn type_of(&self) -> Type {
515 Type::Extension {
516 name: self.typename(),
517 }
518 }
519}