1use smol_str::SmolStr;
18
19use crate::ast::*;
20use crate::entities::SchemaType;
21use crate::evaluator;
22use std::any::Any;
23use std::collections::HashMap;
24use std::fmt::{Debug, Display};
25use std::sync::Arc;
26
27pub struct Extension {
34 name: Name,
36 functions: HashMap<Name, ExtensionFunction>,
38}
39
40impl Extension {
41 pub fn new(name: Name, functions: impl IntoIterator<Item = ExtensionFunction>) -> Self {
43 Self {
44 name,
45 functions: functions.into_iter().map(|f| (f.name.clone(), f)).collect(),
46 }
47 }
48
49 pub fn name(&self) -> &Name {
51 &self.name
52 }
53
54 pub fn get_func(&self, name: &Name) -> Option<&ExtensionFunction> {
57 self.functions.get(name)
58 }
59
60 pub fn funcs(&self) -> impl Iterator<Item = &ExtensionFunction> {
62 self.functions.values()
63 }
64}
65
66impl std::fmt::Debug for Extension {
67 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
68 write!(f, "<extension {}>", self.name())
69 }
70}
71
72#[derive(Debug, Clone)]
74pub enum ExtensionOutputValue {
75 Concrete(Value),
77 Unknown(SmolStr),
79}
80
81impl<T> From<T> for ExtensionOutputValue
82where
83 T: Into<Value>,
84{
85 fn from(v: T) -> Self {
86 ExtensionOutputValue::Concrete(v.into())
87 }
88}
89
90#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
92#[cfg_attr(fuzzing, derive(arbitrary::Arbitrary))]
93pub enum CallStyle {
94 FunctionStyle,
96 MethodStyle,
98}
99
100pub type ExtensionFunctionObject =
104 Box<dyn Fn(&[Value]) -> evaluator::Result<ExtensionOutputValue> + Sync + Send + 'static>;
105
106pub struct ExtensionFunction {
109 name: Name,
111 style: CallStyle,
113 func: ExtensionFunctionObject,
116 return_type: Option<SchemaType>,
122 arg_types: Vec<Option<SchemaType>>,
126}
127
128impl ExtensionFunction {
129 fn new(
131 name: Name,
132 style: CallStyle,
133 func: ExtensionFunctionObject,
134 return_type: Option<SchemaType>,
135 arg_types: Vec<Option<SchemaType>>,
136 ) -> Self {
137 Self {
138 name,
139 func,
140 style,
141 return_type,
142 arg_types,
143 }
144 }
145
146 pub fn nullary(
148 name: Name,
149 style: CallStyle,
150 func: Box<dyn Fn() -> evaluator::Result<ExtensionOutputValue> + Sync + Send + 'static>,
151 return_type: SchemaType,
152 ) -> Self {
153 Self::new(
154 name.clone(),
155 style,
156 Box::new(move |args: &[Value]| {
157 if args.is_empty() {
158 func()
159 } else {
160 Err(evaluator::EvaluationError::WrongNumArguments {
161 op: ExtensionFunctionOp {
162 function_name: name.clone(),
163 },
164 expected: 0,
165 actual: args.len(),
166 })
167 }
168 }),
169 Some(return_type),
170 vec![],
171 )
172 }
173
174 pub fn unary_never(
176 name: Name,
177 style: CallStyle,
178 func: Box<dyn Fn(Value) -> evaluator::Result<ExtensionOutputValue> + Sync + Send + 'static>,
179 arg_type: Option<SchemaType>,
180 ) -> Self {
181 Self::new(
182 name.clone(),
183 style,
184 Box::new(move |args: &[Value]| {
185 if args.len() == 1 {
186 func(args[0].clone())
187 } else {
188 let op = ExtensionFunctionOp {
189 function_name: name.clone(),
190 };
191 Err(evaluator::EvaluationError::WrongNumArguments {
192 op,
193 expected: 1,
194 actual: args.len(),
195 })
196 }
197 }),
198 None,
199 vec![arg_type],
200 )
201 }
202
203 pub fn unary(
205 name: Name,
206 style: CallStyle,
207 func: Box<dyn Fn(Value) -> evaluator::Result<ExtensionOutputValue> + Sync + Send + 'static>,
208 return_type: SchemaType,
209 arg_type: Option<SchemaType>,
210 ) -> Self {
211 Self::new(
212 name.clone(),
213 style,
214 Box::new(move |args: &[Value]| {
215 if args.len() == 1 {
216 func(args[0].clone())
217 } else {
218 let op = ExtensionFunctionOp {
219 function_name: name.clone(),
220 };
221 Err(evaluator::EvaluationError::WrongNumArguments {
222 op,
223 expected: 1,
224 actual: args.len(),
225 })
226 }
227 }),
228 Some(return_type),
229 vec![arg_type],
230 )
231 }
232
233 pub fn binary(
235 name: Name,
236 style: CallStyle,
237 func: Box<
238 dyn Fn(Value, Value) -> evaluator::Result<ExtensionOutputValue> + Sync + Send + 'static,
239 >,
240 return_type: SchemaType,
241 arg_types: (Option<SchemaType>, Option<SchemaType>),
242 ) -> Self {
243 Self::new(
244 name.clone(),
245 style,
246 Box::new(move |args: &[Value]| {
247 if args.len() == 2 {
248 func(args[0].clone(), args[1].clone())
249 } else {
250 Err(evaluator::EvaluationError::WrongNumArguments {
251 op: ExtensionFunctionOp {
252 function_name: name.clone(),
253 },
254 expected: 2,
255 actual: args.len(),
256 })
257 }
258 }),
259 Some(return_type),
260 vec![arg_types.0, arg_types.1],
261 )
262 }
263
264 pub fn ternary(
266 name: Name,
267 style: CallStyle,
268 func: Box<
269 dyn Fn(Value, Value, Value) -> evaluator::Result<ExtensionOutputValue>
270 + Sync
271 + Send
272 + 'static,
273 >,
274 return_type: SchemaType,
275 arg_types: (Option<SchemaType>, Option<SchemaType>, Option<SchemaType>),
276 ) -> Self {
277 Self::new(
278 name.clone(),
279 style,
280 Box::new(move |args: &[Value]| {
281 if args.len() == 3 {
282 func(args[0].clone(), args[1].clone(), args[2].clone())
283 } else {
284 Err(evaluator::EvaluationError::WrongNumArguments {
285 op: ExtensionFunctionOp {
286 function_name: name.clone(),
287 },
288 expected: 3,
289 actual: args.len(),
290 })
291 }
292 }),
293 Some(return_type),
294 vec![arg_types.0, arg_types.1, arg_types.2],
295 )
296 }
297
298 pub fn name(&self) -> &Name {
300 &self.name
301 }
302
303 pub fn style(&self) -> CallStyle {
305 self.style
306 }
307
308 pub fn return_type(&self) -> Option<&SchemaType> {
311 self.return_type.as_ref()
312 }
313
314 pub fn arg_types(&self) -> &[Option<SchemaType>] {
319 &self.arg_types
320 }
321
322 pub fn is_constructor(&self) -> bool {
327 matches!(self.return_type(), Some(SchemaType::Extension { .. }))
329 && self.arg_types().iter().all(Option::is_some)
331 && !self.arg_types().iter().any(|ty| matches!(ty, Some(SchemaType::Extension { .. })))
333 }
334
335 pub fn call(&self, args: &[Value]) -> evaluator::Result<PartialValue> {
337 match (self.func)(args)? {
338 ExtensionOutputValue::Concrete(v) => Ok(PartialValue::Value(v)),
339 ExtensionOutputValue::Unknown(name) => Ok(PartialValue::Residual(Expr::unknown(name))),
340 }
341 }
342}
343
344impl std::fmt::Debug for ExtensionFunction {
345 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
346 write!(f, "<extension function {}>", self.name())
347 }
348}
349
350pub trait ExtensionValue: Debug + Display {
356 fn typename(&self) -> Name;
361}
362
363impl<V: ExtensionValue> StaticallyTyped for V {
364 fn type_of(&self) -> Type {
365 Type::Extension {
366 name: self.typename(),
367 }
368 }
369}
370
371#[derive(Debug, Clone)]
372pub struct ExtensionValueWithArgs {
375 value: Arc<dyn InternalExtensionValue>,
376 args: Vec<Expr>,
377 constructor: ExtensionFunctionOp,
378}
379
380impl ExtensionValueWithArgs {
381 pub fn value(&self) -> &dyn InternalExtensionValue {
383 self.value.as_ref()
384 }
385
386 pub fn typename(&self) -> Name {
388 self.value.typename()
389 }
390
391 pub fn new(
393 value: Arc<dyn InternalExtensionValue>,
394 args: Vec<Expr>,
395 constructor: ExtensionFunctionOp,
396 ) -> Self {
397 Self {
398 value,
399 args,
400 constructor,
401 }
402 }
403}
404
405impl From<ExtensionValueWithArgs> for Expr {
406 fn from(val: ExtensionValueWithArgs) -> Self {
407 ExprBuilder::new().call_extension_fn(val.constructor.function_name, val.args)
408 }
409}
410
411impl StaticallyTyped for ExtensionValueWithArgs {
412 fn type_of(&self) -> Type {
413 self.value.type_of()
414 }
415}
416
417impl Display for ExtensionValueWithArgs {
418 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
419 write!(f, "{}", self.value)
420 }
421}
422
423impl PartialEq for ExtensionValueWithArgs {
424 fn eq(&self, other: &Self) -> bool {
425 self.value.as_ref() == other.value.as_ref()
427 }
428}
429
430impl Eq for ExtensionValueWithArgs {}
431
432impl PartialOrd for ExtensionValueWithArgs {
433 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
434 self.value.partial_cmp(&other.value)
435 }
436}
437
438impl Ord for ExtensionValueWithArgs {
439 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
440 self.value.cmp(&other.value)
441 }
442}
443
444pub trait InternalExtensionValue: ExtensionValue {
457 fn as_any(&self) -> &dyn Any;
459 fn equals_extvalue(&self, other: &dyn InternalExtensionValue) -> bool;
462 fn cmp_extvalue(&self, other: &dyn InternalExtensionValue) -> std::cmp::Ordering;
465}
466
467impl<V: 'static + Eq + Ord + ExtensionValue> InternalExtensionValue for V {
468 fn as_any(&self) -> &dyn Any {
469 self
470 }
471
472 fn equals_extvalue(&self, other: &dyn InternalExtensionValue) -> bool {
473 other
474 .as_any()
475 .downcast_ref::<V>()
476 .map(|v| self == v)
477 .unwrap_or(false) }
479
480 fn cmp_extvalue(&self, other: &dyn InternalExtensionValue) -> std::cmp::Ordering {
481 other
482 .as_any()
483 .downcast_ref::<V>()
484 .map(|v| self.cmp(v))
485 .unwrap_or_else(|| {
486 self.typename().cmp(&other.typename())
489 })
490 }
491}
492
493impl PartialEq for dyn InternalExtensionValue {
494 fn eq(&self, other: &Self) -> bool {
495 self.equals_extvalue(other)
496 }
497}
498
499impl Eq for dyn InternalExtensionValue {}
500
501impl PartialOrd for dyn InternalExtensionValue {
502 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
503 Some(self.cmp(other))
504 }
505}
506
507impl Ord for dyn InternalExtensionValue {
508 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
509 self.cmp_extvalue(other)
510 }
511}
512
513impl StaticallyTyped for dyn InternalExtensionValue {
514 fn type_of(&self) -> Type {
515 Type::Extension {
516 name: self.typename(),
517 }
518 }
519}