cedar_policy_core/ast/
extension.rs

1/*
2 * Copyright Cedar Contributors
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      https://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17use crate::ast::*;
18use crate::entities::SchemaType;
19use crate::evaluator;
20use std::any::Any;
21use std::collections::{BTreeSet, HashMap};
22use std::fmt::Debug;
23use std::panic::{RefUnwindSafe, UnwindSafe};
24use std::sync::Arc;
25
26/// Cedar extension.
27///
28/// An extension can define new types and functions on those types. (Currently,
29/// there's nothing preventing an extension from defining new functions on
30/// built-in types, either, although we haven't discussed whether we want to
31/// allow this long-term.)
32pub struct Extension {
33    /// Name of the extension
34    name: Name,
35    /// Extension functions. These are legal to call in Cedar expressions.
36    functions: HashMap<Name, ExtensionFunction>,
37    /// Types with operator overloading
38    types_with_operator_overloading: BTreeSet<Name>,
39}
40
41impl Extension {
42    /// Create a new `Extension` with the given name and extension functions
43    pub fn new(
44        name: Name,
45        functions: impl IntoIterator<Item = ExtensionFunction>,
46        types_with_operator_overloading: impl IntoIterator<Item = Name>,
47    ) -> Self {
48        Self {
49            name,
50            functions: functions.into_iter().map(|f| (f.name.clone(), f)).collect(),
51            types_with_operator_overloading: types_with_operator_overloading.into_iter().collect(),
52        }
53    }
54
55    /// Get the name of the extension
56    pub fn name(&self) -> &Name {
57        &self.name
58    }
59
60    /// Look up a function by name, or return `None` if the extension doesn't
61    /// provide a function with that name
62    pub fn get_func(&self, name: &Name) -> Option<&ExtensionFunction> {
63        self.functions.get(name)
64    }
65
66    /// Iterate over the functions
67    pub fn funcs(&self) -> impl Iterator<Item = &ExtensionFunction> {
68        self.functions.values()
69    }
70
71    /// Iterate over the extension types that can be produced by any functions
72    /// in this extension
73    pub fn ext_types(&self) -> impl Iterator<Item = &Name> + '_ {
74        self.funcs().flat_map(|func| func.ext_types())
75    }
76
77    /// Iterate over extension types with operator overloading
78    pub fn types_with_operator_overloading(&self) -> impl Iterator<Item = &Name> + '_ {
79        self.types_with_operator_overloading.iter()
80    }
81}
82
83impl std::fmt::Debug for Extension {
84    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
85        write!(f, "<extension {}>", self.name())
86    }
87}
88
89/// The output of an extension call, either a value or an unknown
90#[derive(Debug, Clone)]
91pub enum ExtensionOutputValue {
92    /// A concrete value from an extension call
93    Known(Value),
94    /// An unknown returned from an extension call
95    Unknown(Unknown),
96}
97
98impl<T> From<T> for ExtensionOutputValue
99where
100    T: Into<Value>,
101{
102    fn from(v: T) -> Self {
103        ExtensionOutputValue::Known(v.into())
104    }
105}
106
/// How a function call is written in Cedar source.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
pub enum CallStyle {
    /// Called like a free function, e.g. `foo(a, b)`.
    FunctionStyle,
    /// Called on a receiver, e.g. `a.foo(b)`.
    MethodStyle,
}
116
// Note: currying could make these signatures a little nicer.

// Expands to the boxed, thread-safe closure type of an extension function
// whose parameters have the given types. For example,
// `extension_function_object!(&Value, &Value)` expands to
// `Box<dyn Fn(&Value, &Value) -> evaluator::Result<ExtensionOutputValue> + Sync + Send + 'static>`.
macro_rules! extension_function_object {
    ( $( $arg_ty:ty ), * ) => {
        Box<
            dyn Fn($($arg_ty,)*) -> evaluator::Result<ExtensionOutputValue>
                + Sync
                + Send
                + 'static
        >
    }
}
124
125/// Trait object that implements the extension function call accepting any number of arguments.
126pub type ExtensionFunctionObject = extension_function_object!(&[Value]);
127/// Trait object that implements the extension function call accepting exactly 0 arguments
128pub type NullaryExtensionFunctionObject = extension_function_object!();
129/// Trait object that implements the extension function call accepting exactly 1 arguments
130pub type UnaryExtensionFunctionObject = extension_function_object!(&Value);
131/// Trait object that implements the extension function call accepting exactly 2 arguments
132pub type BinaryExtensionFunctionObject = extension_function_object!(&Value, &Value);
133/// Trait object that implements the extension function call accepting exactly 3 arguments
134pub type TernaryExtensionFunctionObject = extension_function_object!(&Value, &Value, &Value);
135/// Trait object that implements the extension function call that takes one argument, followed by a variadic number of arguments.
136pub type VariadicExtensionFunctionObject = extension_function_object!(&Value, &[Value]);
137
138/// Extension function. These can be called by the given `name` in Ceder
139/// expressions.
140pub struct ExtensionFunction {
141    /// Name of the function
142    name: Name,
143    /// Which `CallStyle` should be used when calling this function
144    style: CallStyle,
145    /// The actual function, which takes an `&[Value]` and returns a `Value`,
146    /// or an evaluation error
147    func: ExtensionFunctionObject,
148    /// The return type of this function, as a `SchemaType`. We require that
149    /// this be constant -- any given extension function must always return a
150    /// value of this `SchemaType`.
151    ///
152    /// `return_type` is `None` if and only if this function represents an
153    /// "unknown" value for partial evaluation. Such a function may only return
154    /// a fully unknown residual and may never return a value.
155    return_type: Option<SchemaType>,
156    /// The argument types that this function expects, as `SchemaType`s.
157    arg_types: Vec<SchemaType>,
158    /// Whether this is a variadic function or not. If it is a variadic function it can accept 1 or more arguments
159    /// of the last argument type.
160    is_variadic: bool,
161}
162
163impl ExtensionFunction {
164    /// Create a new `ExtensionFunction` taking any number of arguments
165    fn new(
166        name: Name,
167        style: CallStyle,
168        func: ExtensionFunctionObject,
169        return_type: Option<SchemaType>,
170        arg_types: Vec<SchemaType>,
171        is_variadic: bool,
172    ) -> Self {
173        Self {
174            name,
175            style,
176            func,
177            return_type,
178            arg_types,
179            is_variadic,
180        }
181    }
182
183    /// Create a new `ExtensionFunction` taking no arguments
184    pub fn nullary(
185        name: Name,
186        style: CallStyle,
187        func: NullaryExtensionFunctionObject,
188        return_type: SchemaType,
189    ) -> Self {
190        Self::new(
191            name.clone(),
192            style,
193            Box::new(move |args: &[Value]| {
194                if args.is_empty() {
195                    func()
196                } else {
197                    Err(evaluator::EvaluationError::wrong_num_arguments(
198                        name.clone(),
199                        0,
200                        args.len(),
201                        None, // evaluator will add the source location later
202                    ))
203                }
204            }),
205            Some(return_type),
206            vec![],
207            false,
208        )
209    }
210
211    /// Create a new `ExtensionFunction` to represent a function which is an
212    /// "unknown" in partial evaluation. Please don't use this for anything else.
213    pub fn partial_eval_unknown(
214        name: Name,
215        style: CallStyle,
216        func: UnaryExtensionFunctionObject,
217        arg_type: SchemaType,
218    ) -> Self {
219        Self::new(
220            name.clone(),
221            style,
222            Box::new(move |args: &[Value]| match args.first() {
223                Some(arg) => func(arg),
224                None => Err(evaluator::EvaluationError::wrong_num_arguments(
225                    name.clone(),
226                    1,
227                    args.len(),
228                    None, // evaluator will add the source location later
229                )),
230            }),
231            None,
232            vec![arg_type],
233            false,
234        )
235    }
236
237    /// Create a new `ExtensionFunction` taking one argument
238    #[allow(clippy::type_complexity)]
239    pub fn unary(
240        name: Name,
241        style: CallStyle,
242        func: UnaryExtensionFunctionObject,
243        return_type: SchemaType,
244        arg_type: SchemaType,
245    ) -> Self {
246        Self::new(
247            name.clone(),
248            style,
249            Box::new(move |args: &[Value]| match &args {
250                &[arg] => func(arg),
251                _ => Err(evaluator::EvaluationError::wrong_num_arguments(
252                    name.clone(),
253                    1,
254                    args.len(),
255                    None, // evaluator will add the source location later
256                )),
257            }),
258            Some(return_type),
259            vec![arg_type],
260            false,
261        )
262    }
263
264    /// Create a new `ExtensionFunction` taking two arguments
265    #[allow(clippy::type_complexity)]
266    pub fn binary(
267        name: Name,
268        style: CallStyle,
269        func: BinaryExtensionFunctionObject,
270        return_type: SchemaType,
271        arg_types: (SchemaType, SchemaType),
272    ) -> Self {
273        Self::new(
274            name.clone(),
275            style,
276            Box::new(move |args: &[Value]| match &args {
277                &[first, second] => func(first, second),
278                _ => Err(evaluator::EvaluationError::wrong_num_arguments(
279                    name.clone(),
280                    2,
281                    args.len(),
282                    None, // evaluator will add the source location later
283                )),
284            }),
285            Some(return_type),
286            vec![arg_types.0, arg_types.1],
287            false,
288        )
289    }
290
291    /// Create a new `ExtensionFunction` taking three arguments
292    #[allow(clippy::type_complexity)]
293    pub fn ternary(
294        name: Name,
295        style: CallStyle,
296        func: TernaryExtensionFunctionObject,
297        return_type: SchemaType,
298        arg_types: (SchemaType, SchemaType, SchemaType),
299    ) -> Self {
300        Self::new(
301            name.clone(),
302            style,
303            Box::new(move |args: &[Value]| match &args {
304                &[first, second, third] => func(first, second, third),
305                _ => Err(evaluator::EvaluationError::wrong_num_arguments(
306                    name.clone(),
307                    3,
308                    args.len(),
309                    None, // evaluator will add the source location later
310                )),
311            }),
312            Some(return_type),
313            vec![arg_types.0, arg_types.1, arg_types.2],
314            false,
315        )
316    }
317
318    /// Create a new variadic `ExtensionFunction` taking two or more argument.
319    #[allow(clippy::type_complexity)]
320    pub fn variadic(
321        name: Name,
322        style: CallStyle,
323        func: VariadicExtensionFunctionObject,
324        return_type: SchemaType,
325        arg_types: (SchemaType, SchemaType),
326    ) -> Self {
327        Self::new(
328            name.clone(),
329            style,
330            Box::new(move |args: &[Value]| match &args {
331                #[cfg(feature = "variadic-is-in-range")]
332                &[first, rest @ ..] => func(first, rest),
333                #[cfg(not(feature = "variadic-is-in-range"))]
334                &[first, second] => func(first, std::slice::from_ref(second)),
335                _ => Err(evaluator::EvaluationError::wrong_num_arguments(
336                    name.clone(),
337                    2,
338                    args.len(),
339                    None, // evaluator will add the source location later
340                )),
341            }),
342            Some(return_type),
343            vec![arg_types.0, arg_types.1],
344            #[cfg(feature = "variadic-is-in-range")]
345            true,
346            #[cfg(not(feature = "variadic-is-in-range"))]
347            false,
348        )
349    }
350
351    /// Get the `Name` of the `ExtensionFunction`
352    pub fn name(&self) -> &Name {
353        &self.name
354    }
355
356    /// Get the `CallStyle` of the `ExtensionFunction`
357    pub fn style(&self) -> CallStyle {
358        self.style
359    }
360
361    /// Get the return type of the `ExtensionFunction`
362    /// `None` is returned exactly when this function represents an "unknown"
363    /// for partial evaluation.
364    pub fn return_type(&self) -> Option<&SchemaType> {
365        self.return_type.as_ref()
366    }
367
368    /// Get the argument types of the `ExtensionFunction`.
369    pub fn arg_types(&self) -> &[SchemaType] {
370        &self.arg_types
371    }
372
373    /// Whether this is a variadic function.
374    pub fn is_variadic(&self) -> bool {
375        self.is_variadic
376    }
377
378    /// Returns `true` if this function is considered a single argument
379    /// constructor.
380    ///
381    /// Only functions satisfying this predicate can have their names implicit
382    /// during schema-based entity parsing
383    pub fn is_single_arg_constructor(&self) -> bool {
384        // return type is an extension type
385        matches!(self.return_type(), Some(SchemaType::Extension { .. }))
386        // the only argument is a string
387        && matches!(self.arg_types(), [SchemaType::String])
388    }
389
390    /// Call the `ExtensionFunction` with the given args
391    pub fn call(&self, args: &[Value]) -> evaluator::Result<PartialValue> {
392        match (self.func)(args)? {
393            ExtensionOutputValue::Known(v) => Ok(PartialValue::Value(v)),
394            ExtensionOutputValue::Unknown(u) => Ok(PartialValue::Residual(Expr::unknown(u))),
395        }
396    }
397
398    /// Iterate over the extension types that could be produced by this
399    /// function, if any
400    pub fn ext_types(&self) -> impl Iterator<Item = &Name> + '_ {
401        self.return_type
402            .iter()
403            .flat_map(|ret_ty| ret_ty.contained_ext_types())
404    }
405}
406
407impl std::fmt::Debug for ExtensionFunction {
408    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
409        write!(f, "<extension function {}>", self.name())
410    }
411}
412
413/// Extension value.
414///
415/// Anything implementing this trait can be used as a first-class value in
416/// Cedar. For instance, the `ipaddr` extension uses this mechanism
417/// to implement IPAddr as a Cedar first-class value.
418pub trait ExtensionValue: Debug + Send + Sync + UnwindSafe + RefUnwindSafe {
419    /// Get the name of the type of this value.
420    ///
421    /// Cedar has nominal typing, so two values have the same type iff they
422    /// return the same typename here.
423    fn typename(&self) -> Name;
424
425    /// If it supports operator overloading
426    fn supports_operator_overloading(&self) -> bool;
427}
428
429impl<V: ExtensionValue> StaticallyTyped for V {
430    fn type_of(&self) -> Type {
431        Type::Extension {
432            name: self.typename(),
433        }
434    }
435}
436
437#[derive(Debug, Clone)]
438/// Object container for extension values
439/// An extension value must be representable by a [`RestrictedExpr`]
440/// Specifically, it will be a function call `func` on `args`
441/// Note that `func` may not be the constructor. A counterexample is that a
442/// `datetime` is represented by an `offset` method call.
443/// Nevertheless, an invariant is that `eval(<func>(<args>)) == value`
444pub struct RepresentableExtensionValue {
445    pub(crate) func: Name,
446    pub(crate) args: Vec<RestrictedExpr>,
447    pub(crate) value: Arc<dyn InternalExtensionValue>,
448}
449
450impl RepresentableExtensionValue {
451    /// Create a new [`RepresentableExtensionValue`]
452    pub fn new(
453        value: Arc<dyn InternalExtensionValue + Send + Sync>,
454        func: Name,
455        args: Vec<RestrictedExpr>,
456    ) -> Self {
457        Self { func, args, value }
458    }
459
460    /// Get the internal value
461    pub fn value(&self) -> &dyn InternalExtensionValue {
462        self.value.as_ref()
463    }
464
465    /// Get the typename of this extension value
466    pub fn typename(&self) -> Name {
467        self.value.typename()
468    }
469
470    /// If this value supports operator overloading
471    pub(crate) fn supports_operator_overloading(&self) -> bool {
472        self.value.supports_operator_overloading()
473    }
474}
475
476impl From<RepresentableExtensionValue> for RestrictedExpr {
477    fn from(val: RepresentableExtensionValue) -> Self {
478        RestrictedExpr::call_extension_fn(val.func, val.args)
479    }
480}
481
482impl StaticallyTyped for RepresentableExtensionValue {
483    fn type_of(&self) -> Type {
484        self.value.type_of()
485    }
486}
487
488impl PartialEq for RepresentableExtensionValue {
489    fn eq(&self, other: &Self) -> bool {
490        // Values that are equal are equal regardless of which arguments made them
491        self.value.as_ref() == other.value.as_ref()
492    }
493}
494
495impl Eq for RepresentableExtensionValue {}
496
497impl PartialOrd for RepresentableExtensionValue {
498    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
499        Some(self.cmp(other))
500    }
501}
502
503impl Ord for RepresentableExtensionValue {
504    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
505        self.value.cmp(&other.value)
506    }
507}
508
509/// Extensions provide a type implementing `ExtensionValue`, `Eq`, and `Ord`.
510/// We automatically implement `InternalExtensionValue` for that type (with the
511/// impl below).  Internally, we use `dyn InternalExtensionValue` instead of
512/// `dyn ExtensionValue`.
513///
514/// You might wonder why we don't just have `ExtensionValue: Eq + Ord` and use
515/// `dyn ExtensionValue` everywhere.  The answer is that the Rust compiler
516/// doesn't let you because of
517/// [object safety](https://doc.rust-lang.org/reference/items/traits.html#object-safety).
518/// So instead we have this workaround where we define our own `equals_extvalue`
519/// method that compares not against `&Self` but against `&dyn InternalExtensionValue`,
520/// and likewise for `cmp_extvalue`.
521pub trait InternalExtensionValue: ExtensionValue {
522    /// convert to an `Any`
523    fn as_any(&self) -> &dyn Any;
524    /// this will be the basis for `PartialEq` on `InternalExtensionValue`; but
525    /// note the `&dyn` (normal `PartialEq` doesn't have the `dyn`)
526    fn equals_extvalue(&self, other: &dyn InternalExtensionValue) -> bool;
527    /// this will be the basis for `Ord` on `InternalExtensionValue`; but note
528    /// the `&dyn` (normal `Ord` doesn't have the `dyn`)
529    fn cmp_extvalue(&self, other: &dyn InternalExtensionValue) -> std::cmp::Ordering;
530}
531
532impl<V: 'static + Eq + Ord + ExtensionValue + Send + Sync + Clone> InternalExtensionValue for V {
533    fn as_any(&self) -> &dyn Any {
534        self
535    }
536
537    fn equals_extvalue(&self, other: &dyn InternalExtensionValue) -> bool {
538        other
539            .as_any()
540            .downcast_ref::<V>()
541            .map(|v| self == v)
542            .unwrap_or(false) // if the downcast failed, values are different types, so equality is false
543    }
544
545    fn cmp_extvalue(&self, other: &dyn InternalExtensionValue) -> std::cmp::Ordering {
546        other
547            .as_any()
548            .downcast_ref::<V>()
549            .map(|v| self.cmp(v))
550            .unwrap_or_else(|| {
551                // downcast failed, so values are different types.
552                // we fall back on the total ordering on typenames.
553                self.typename().cmp(&other.typename())
554            })
555    }
556}
557
558impl PartialEq for dyn InternalExtensionValue {
559    fn eq(&self, other: &Self) -> bool {
560        self.equals_extvalue(other)
561    }
562}
563
564impl Eq for dyn InternalExtensionValue {}
565
566impl PartialOrd for dyn InternalExtensionValue {
567    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
568        Some(self.cmp(other))
569    }
570}
571
572impl Ord for dyn InternalExtensionValue {
573    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
574        self.cmp_extvalue(other)
575    }
576}
577
578impl StaticallyTyped for dyn InternalExtensionValue {
579    fn type_of(&self) -> Type {
580        Type::Extension {
581            name: self.typename(),
582        }
583    }
584}