Skip to main content

cedar_policy_core/ast/
extension.rs

1/*
2 * Copyright Cedar Contributors
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      https://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17use crate::ast::*;
18use crate::entities::SchemaType;
19use crate::evaluator;
20use std::any::Any;
21use std::collections::{BTreeSet, HashMap};
22use std::fmt::Debug;
23use std::panic::{RefUnwindSafe, UnwindSafe};
24use std::sync::Arc;
25
26/// Cedar extension.
27///
28/// An extension can define new types and functions on those types. (Currently,
29/// there's nothing preventing an extension from defining new functions on
30/// built-in types, either, although we haven't discussed whether we want to
31/// allow this long-term.)
32pub struct Extension {
33    /// Name of the extension
34    name: Name,
35    /// Extension functions. These are legal to call in Cedar expressions.
36    functions: HashMap<Name, ExtensionFunction>,
37    /// Types with operator overloading
38    types_with_operator_overloading: BTreeSet<Name>,
39}
40
41impl Extension {
42    /// Create a new `Extension` with the given name and extension functions
43    pub fn new(
44        name: Name,
45        functions: impl IntoIterator<Item = ExtensionFunction>,
46        types_with_operator_overloading: impl IntoIterator<Item = Name>,
47    ) -> Self {
48        Self {
49            name,
50            functions: functions.into_iter().map(|f| (f.name.clone(), f)).collect(),
51            types_with_operator_overloading: types_with_operator_overloading.into_iter().collect(),
52        }
53    }
54
55    /// Get the name of the extension
56    pub fn name(&self) -> &Name {
57        &self.name
58    }
59
60    /// Look up a function by name, or return `None` if the extension doesn't
61    /// provide a function with that name
62    pub fn get_func(&self, name: &Name) -> Option<&ExtensionFunction> {
63        self.functions.get(name)
64    }
65
66    /// Iterate over the functions
67    pub fn funcs(&self) -> impl Iterator<Item = &ExtensionFunction> {
68        self.functions.values()
69    }
70
71    /// Iterate over the extension types that can be produced by any functions
72    /// in this extension
73    pub fn ext_types(&self) -> impl Iterator<Item = &Name> + '_ {
74        self.funcs().flat_map(|func| func.ext_types())
75    }
76
77    /// Iterate over extension types with operator overloading
78    pub fn types_with_operator_overloading(&self) -> impl Iterator<Item = &Name> + '_ {
79        self.types_with_operator_overloading.iter()
80    }
81}
82
83impl std::fmt::Debug for Extension {
84    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
85        write!(f, "<extension {}>", self.name())
86    }
87}
88
89/// The output of an extension call, either a value or an unknown
90#[derive(Debug, Clone)]
91pub enum ExtensionOutputValue {
92    /// A concrete value from an extension call
93    Known(Value),
94    /// An unknown returned from an extension call
95    Unknown(Unknown),
96}
97
98impl<T> From<T> for ExtensionOutputValue
99where
100    T: Into<Value>,
101{
102    fn from(v: T) -> Self {
103        ExtensionOutputValue::Known(v.into())
104    }
105}
106
/// The syntactic form in which a function is invoked.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
pub enum CallStyle {
    /// Called like a free function, e.g. `foo(a, b)`
    FunctionStyle,
    /// Called like a method on its first argument, e.g. `a.foo(b)`
    MethodStyle,
}
116
// Note: currying could make these function-object types a little nicer.

/// Expands to the boxed closure type `Box<dyn Fn(ARGS...) -> R>`, where
/// `R` is `evaluator::Result<ExtensionOutputValue>`. The closure must be
/// `Send + Sync + 'static`. `ARGS` is the comma-separated list of argument
/// types passed to the macro.
macro_rules! extension_function_object {
    ( $( $arg:ty ),* ) => {
        Box<
            dyn Fn( $( $arg ),* ) -> evaluator::Result<ExtensionOutputValue>
                + Send
                + Sync
                + 'static
        >
    };
}
124
125/// Trait object that implements the extension function call accepting any number of arguments.
126pub type ExtensionFunctionObject = extension_function_object!(&[Value]);
127/// Trait object that implements the extension function call accepting exactly 0 arguments
128pub type NullaryExtensionFunctionObject = extension_function_object!();
129/// Trait object that implements the extension function call accepting exactly 1 arguments
130pub type UnaryExtensionFunctionObject = extension_function_object!(&Value);
131/// Trait object that implements the extension function call accepting exactly 2 arguments
132pub type BinaryExtensionFunctionObject = extension_function_object!(&Value, &Value);
133/// Trait object that implements the extension function call accepting exactly 3 arguments
134pub type TernaryExtensionFunctionObject = extension_function_object!(&Value, &Value, &Value);
135/// Trait object that implements the extension function call that takes one argument, followed by a variadic number of arguments.
136pub type VariadicExtensionFunctionObject = extension_function_object!(&Value, &[Value]);
137
138/// Extension function. These can be called by the given `name` in Ceder
139/// expressions.
140pub struct ExtensionFunction {
141    /// Name of the function
142    name: Name,
143    /// Which `CallStyle` should be used when calling this function
144    style: CallStyle,
145    /// The actual function, which takes an `&[Value]` and returns a `Value`,
146    /// or an evaluation error
147    func: ExtensionFunctionObject,
148    /// The return type of this function, as a `SchemaType`. We require that
149    /// this be constant -- any given extension function must always return a
150    /// value of this `SchemaType`.
151    ///
152    /// `return_type` is `None` if and only if this function represents an
153    /// "unknown" value for partial evaluation. Such a function may only return
154    /// a fully unknown residual and may never return a value.
155    return_type: Option<SchemaType>,
156    /// The argument types that this function expects, as `SchemaType`s.
157    arg_types: Vec<SchemaType>,
158    /// Whether this is a variadic function or not. If it is a variadic function it can accept 1 or more arguments
159    /// of the last argument type.
160    is_variadic: bool,
161}
162
163impl ExtensionFunction {
164    /// Create a new `ExtensionFunction` taking any number of arguments
165    fn new(
166        name: Name,
167        style: CallStyle,
168        func: ExtensionFunctionObject,
169        return_type: Option<SchemaType>,
170        arg_types: Vec<SchemaType>,
171        is_variadic: bool,
172    ) -> Self {
173        Self {
174            name,
175            style,
176            func,
177            return_type,
178            arg_types,
179            is_variadic,
180        }
181    }
182
183    /// Create a new `ExtensionFunction` taking no arguments
184    pub fn nullary(
185        name: Name,
186        style: CallStyle,
187        func: NullaryExtensionFunctionObject,
188        return_type: SchemaType,
189    ) -> Self {
190        Self::new(
191            name.clone(),
192            style,
193            Box::new(move |args: &[Value]| {
194                if args.is_empty() {
195                    func()
196                } else {
197                    Err(evaluator::EvaluationError::wrong_num_arguments(
198                        name.clone(),
199                        0,
200                        args.len(),
201                        None, // evaluator will add the source location later
202                    ))
203                }
204            }),
205            Some(return_type),
206            vec![],
207            false,
208        )
209    }
210
211    /// Create a new `ExtensionFunction` to represent a function which is an
212    /// "unknown" in partial evaluation. Please don't use this for anything else.
213    pub fn partial_eval_unknown(
214        name: Name,
215        style: CallStyle,
216        func: UnaryExtensionFunctionObject,
217        arg_type: SchemaType,
218    ) -> Self {
219        Self::new(
220            name.clone(),
221            style,
222            Box::new(move |args: &[Value]| match args.first() {
223                Some(arg) => func(arg),
224                None => Err(evaluator::EvaluationError::wrong_num_arguments(
225                    name.clone(),
226                    1,
227                    args.len(),
228                    None, // evaluator will add the source location later
229                )),
230            }),
231            None,
232            vec![arg_type],
233            false,
234        )
235    }
236
237    /// Create a new `ExtensionFunction` taking one argument
238    pub fn unary(
239        name: Name,
240        style: CallStyle,
241        func: UnaryExtensionFunctionObject,
242        return_type: SchemaType,
243        arg_type: SchemaType,
244    ) -> Self {
245        Self::new(
246            name.clone(),
247            style,
248            Box::new(move |args: &[Value]| match &args {
249                &[arg] => func(arg),
250                _ => Err(evaluator::EvaluationError::wrong_num_arguments(
251                    name.clone(),
252                    1,
253                    args.len(),
254                    None, // evaluator will add the source location later
255                )),
256            }),
257            Some(return_type),
258            vec![arg_type],
259            false,
260        )
261    }
262
263    /// Create a new `ExtensionFunction` taking two arguments
264    pub fn binary(
265        name: Name,
266        style: CallStyle,
267        func: BinaryExtensionFunctionObject,
268        return_type: SchemaType,
269        arg_types: (SchemaType, SchemaType),
270    ) -> Self {
271        Self::new(
272            name.clone(),
273            style,
274            Box::new(move |args: &[Value]| match &args {
275                &[first, second] => func(first, second),
276                _ => Err(evaluator::EvaluationError::wrong_num_arguments(
277                    name.clone(),
278                    2,
279                    args.len(),
280                    None, // evaluator will add the source location later
281                )),
282            }),
283            Some(return_type),
284            vec![arg_types.0, arg_types.1],
285            false,
286        )
287    }
288
289    /// Create a new `ExtensionFunction` taking three arguments
290    pub fn ternary(
291        name: Name,
292        style: CallStyle,
293        func: TernaryExtensionFunctionObject,
294        return_type: SchemaType,
295        arg_types: (SchemaType, SchemaType, SchemaType),
296    ) -> Self {
297        Self::new(
298            name.clone(),
299            style,
300            Box::new(move |args: &[Value]| match &args {
301                &[first, second, third] => func(first, second, third),
302                _ => Err(evaluator::EvaluationError::wrong_num_arguments(
303                    name.clone(),
304                    3,
305                    args.len(),
306                    None, // evaluator will add the source location later
307                )),
308            }),
309            Some(return_type),
310            vec![arg_types.0, arg_types.1, arg_types.2],
311            false,
312        )
313    }
314
315    /// Create a new variadic `ExtensionFunction` taking two or more argument.
316    pub fn variadic(
317        name: Name,
318        style: CallStyle,
319        func: VariadicExtensionFunctionObject,
320        return_type: SchemaType,
321        arg_types: (SchemaType, SchemaType),
322    ) -> Self {
323        Self::new(
324            name.clone(),
325            style,
326            Box::new(move |args: &[Value]| match &args {
327                #[cfg(feature = "variadic-is-in-range")]
328                &[first, rest @ ..] => func(first, rest),
329                #[cfg(not(feature = "variadic-is-in-range"))]
330                &[first, second] => func(first, std::slice::from_ref(second)),
331                _ => Err(evaluator::EvaluationError::wrong_num_arguments(
332                    name.clone(),
333                    2,
334                    args.len(),
335                    None, // evaluator will add the source location later
336                )),
337            }),
338            Some(return_type),
339            vec![arg_types.0, arg_types.1],
340            #[cfg(feature = "variadic-is-in-range")]
341            true,
342            #[cfg(not(feature = "variadic-is-in-range"))]
343            false,
344        )
345    }
346
347    /// Get the `Name` of the `ExtensionFunction`
348    pub fn name(&self) -> &Name {
349        &self.name
350    }
351
352    /// Get the `CallStyle` of the `ExtensionFunction`
353    pub fn style(&self) -> CallStyle {
354        self.style
355    }
356
357    /// Get the return type of the `ExtensionFunction`
358    /// `None` is returned exactly when this function represents an "unknown"
359    /// for partial evaluation.
360    pub fn return_type(&self) -> Option<&SchemaType> {
361        self.return_type.as_ref()
362    }
363
364    /// Get the argument types of the `ExtensionFunction`.
365    pub fn arg_types(&self) -> &[SchemaType] {
366        &self.arg_types
367    }
368
369    /// Whether this is a variadic function.
370    pub fn is_variadic(&self) -> bool {
371        self.is_variadic
372    }
373
374    /// Returns `true` if this function is considered a single argument
375    /// constructor.
376    ///
377    /// Only functions satisfying this predicate can have their names implicit
378    /// during schema-based entity parsing
379    pub fn is_single_arg_constructor(&self) -> bool {
380        // return type is an extension type
381        matches!(self.return_type(), Some(SchemaType::Extension { .. }))
382        // the only argument is a string
383        && matches!(self.arg_types(), [SchemaType::String])
384    }
385
386    /// Call the `ExtensionFunction` with the given args
387    pub fn call(&self, args: &[Value]) -> evaluator::Result<PartialValue> {
388        match (self.func)(args)? {
389            ExtensionOutputValue::Known(v) => Ok(PartialValue::Value(v)),
390            ExtensionOutputValue::Unknown(u) => Ok(PartialValue::Residual(Expr::unknown(u))),
391        }
392    }
393
394    /// Iterate over the extension types that could be produced by this
395    /// function, if any
396    pub fn ext_types(&self) -> impl Iterator<Item = &Name> + '_ {
397        self.return_type
398            .iter()
399            .flat_map(|ret_ty| ret_ty.contained_ext_types())
400    }
401}
402
403impl std::fmt::Debug for ExtensionFunction {
404    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
405        write!(f, "<extension function {}>", self.name())
406    }
407}
408
409/// Extension value.
410///
411/// Anything implementing this trait can be used as a first-class value in
412/// Cedar. For instance, the `ipaddr` extension uses this mechanism
413/// to implement IPAddr as a Cedar first-class value.
414pub trait ExtensionValue: Debug + Send + Sync + UnwindSafe + RefUnwindSafe {
415    /// Get the name of the type of this value.
416    ///
417    /// Cedar has nominal typing, so two values have the same type iff they
418    /// return the same typename here.
419    fn typename(&self) -> Name;
420
421    /// If it supports operator overloading
422    fn supports_operator_overloading(&self) -> bool;
423}
424
425impl<V: ExtensionValue> StaticallyTyped for V {
426    fn type_of(&self) -> Type {
427        Type::Extension {
428            name: self.typename(),
429        }
430    }
431}
432
433#[derive(Debug, Clone)]
434/// Object container for extension values
435/// An extension value must be representable by a [`RestrictedExpr`]
436/// Specifically, it will be a function call `func` on `args`
437/// Note that `func` may not be the constructor. A counterexample is that a
438/// `datetime` is represented by an `offset` method call.
439/// Nevertheless, an invariant is that `eval(<func>(<args>)) == value`
440pub struct RepresentableExtensionValue {
441    pub(crate) func: Name,
442    pub(crate) args: Vec<RestrictedExpr>,
443    pub(crate) value: Arc<dyn InternalExtensionValue>,
444}
445
446impl RepresentableExtensionValue {
447    /// Create a new [`RepresentableExtensionValue`]
448    pub fn new(
449        value: Arc<dyn InternalExtensionValue + Send + Sync>,
450        func: Name,
451        args: Vec<RestrictedExpr>,
452    ) -> Self {
453        Self { func, args, value }
454    }
455
456    /// Get the internal value
457    pub fn value(&self) -> &dyn InternalExtensionValue {
458        self.value.as_ref()
459    }
460
461    /// Get the typename of this extension value
462    pub fn typename(&self) -> Name {
463        self.value.typename()
464    }
465
466    /// If this value supports operator overloading
467    pub(crate) fn supports_operator_overloading(&self) -> bool {
468        self.value.supports_operator_overloading()
469    }
470}
471
472impl From<RepresentableExtensionValue> for RestrictedExpr {
473    fn from(val: RepresentableExtensionValue) -> Self {
474        RestrictedExpr::call_extension_fn(val.func, val.args)
475    }
476}
477
478impl StaticallyTyped for RepresentableExtensionValue {
479    fn type_of(&self) -> Type {
480        self.value.type_of()
481    }
482}
483
484impl PartialEq for RepresentableExtensionValue {
485    fn eq(&self, other: &Self) -> bool {
486        // Values that are equal are equal regardless of which arguments made them
487        self.value.as_ref() == other.value.as_ref()
488    }
489}
490
491impl Eq for RepresentableExtensionValue {}
492
493impl PartialOrd for RepresentableExtensionValue {
494    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
495        Some(self.cmp(other))
496    }
497}
498
499impl Ord for RepresentableExtensionValue {
500    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
501        self.value.cmp(&other.value)
502    }
503}
504
505/// Extensions provide a type implementing `ExtensionValue`, `Eq`, and `Ord`.
506/// We automatically implement `InternalExtensionValue` for that type (with the
507/// impl below).  Internally, we use `dyn InternalExtensionValue` instead of
508/// `dyn ExtensionValue`.
509///
510/// You might wonder why we don't just have `ExtensionValue: Eq + Ord` and use
511/// `dyn ExtensionValue` everywhere.  The answer is that the Rust compiler
512/// doesn't let you because of
513/// [object safety](https://doc.rust-lang.org/reference/items/traits.html#object-safety).
514/// So instead we have this workaround where we define our own `equals_extvalue`
515/// method that compares not against `&Self` but against `&dyn InternalExtensionValue`,
516/// and likewise for `cmp_extvalue`.
517pub trait InternalExtensionValue: ExtensionValue {
518    /// convert to an `Any`
519    fn as_any(&self) -> &dyn Any;
520    /// this will be the basis for `PartialEq` on `InternalExtensionValue`; but
521    /// note the `&dyn` (normal `PartialEq` doesn't have the `dyn`)
522    fn equals_extvalue(&self, other: &dyn InternalExtensionValue) -> bool;
523    /// this will be the basis for `Ord` on `InternalExtensionValue`; but note
524    /// the `&dyn` (normal `Ord` doesn't have the `dyn`)
525    fn cmp_extvalue(&self, other: &dyn InternalExtensionValue) -> std::cmp::Ordering;
526}
527
528impl<V: 'static + Eq + Ord + ExtensionValue + Send + Sync + Clone> InternalExtensionValue for V {
529    fn as_any(&self) -> &dyn Any {
530        self
531    }
532
533    fn equals_extvalue(&self, other: &dyn InternalExtensionValue) -> bool {
534        other
535            .as_any()
536            .downcast_ref::<V>()
537            .map(|v| self == v)
538            .unwrap_or(false) // if the downcast failed, values are different types, so equality is false
539    }
540
541    fn cmp_extvalue(&self, other: &dyn InternalExtensionValue) -> std::cmp::Ordering {
542        other
543            .as_any()
544            .downcast_ref::<V>()
545            .map(|v| self.cmp(v))
546            .unwrap_or_else(|| {
547                // downcast failed, so values are different types.
548                // we fall back on the total ordering on typenames.
549                self.typename().cmp(&other.typename())
550            })
551    }
552}
553
554impl PartialEq for dyn InternalExtensionValue {
555    fn eq(&self, other: &Self) -> bool {
556        self.equals_extvalue(other)
557    }
558}
559
560impl Eq for dyn InternalExtensionValue {}
561
562impl PartialOrd for dyn InternalExtensionValue {
563    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
564        Some(self.cmp(other))
565    }
566}
567
568impl Ord for dyn InternalExtensionValue {
569    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
570        self.cmp_extvalue(other)
571    }
572}
573
574impl StaticallyTyped for dyn InternalExtensionValue {
575    fn type_of(&self) -> Type {
576        Type::Extension {
577            name: self.typename(),
578        }
579    }
580}