cedar_policy_core/ast/
extension.rs

1/*
2 * Copyright Cedar Contributors
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      https://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17use crate::ast::*;
18use crate::entities::SchemaType;
19use crate::evaluator;
20use std::any::Any;
21use std::collections::HashMap;
22use std::fmt::{Debug, Display};
23use std::panic::{RefUnwindSafe, UnwindSafe};
24use std::sync::Arc;
25
26/// Cedar extension.
27///
28/// An extension can define new types and functions on those types. (Currently,
29/// there's nothing preventing an extension from defining new functions on
30/// built-in types, either, although we haven't discussed whether we want to
31/// allow this long-term.)
32pub struct Extension {
33    /// Name of the extension
34    name: Name,
35    /// Extension functions. These are legal to call in Cedar expressions.
36    functions: HashMap<Name, ExtensionFunction>,
37}
38
39impl Extension {
40    /// Create a new `Extension` with the given name and extension functions
41    pub fn new(name: Name, functions: impl IntoIterator<Item = ExtensionFunction>) -> Self {
42        Self {
43            name,
44            functions: functions.into_iter().map(|f| (f.name.clone(), f)).collect(),
45        }
46    }
47
48    /// Get the name of the extension
49    pub fn name(&self) -> &Name {
50        &self.name
51    }
52
53    /// Look up a function by name, or return `None` if the extension doesn't
54    /// provide a function with that name
55    pub fn get_func(&self, name: &Name) -> Option<&ExtensionFunction> {
56        self.functions.get(name)
57    }
58
59    /// Iterate over the functions
60    pub fn funcs(&self) -> impl Iterator<Item = &ExtensionFunction> {
61        self.functions.values()
62    }
63
64    /// Iterate over the extension types that can be produced by any functions
65    /// in this extension
66    pub fn ext_types(&self) -> impl Iterator<Item = &Name> + '_ {
67        self.funcs().flat_map(|func| func.ext_types())
68    }
69}
70
71impl std::fmt::Debug for Extension {
72    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
73        write!(f, "<extension {}>", self.name())
74    }
75}
76
77/// The output of an extension call, either a value or an unknown
78#[derive(Debug, Clone)]
79pub enum ExtensionOutputValue {
80    /// A concrete value from an extension call
81    Known(Value),
82    /// An unknown returned from an extension call
83    Unknown(Unknown),
84}
85
86impl<T> From<T> for ExtensionOutputValue
87where
88    T: Into<Value>,
89{
90    fn from(v: T) -> Self {
91        ExtensionOutputValue::Known(v.into())
92    }
93}
94
/// How a function call is written in Cedar syntax
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum CallStyle {
    /// Free-function syntax, e.g. `foo(a, b)`
    FunctionStyle,
    /// Method syntax on the first argument, e.g. `a.foo(b)`
    MethodStyle,
}
104
105// Note: we could use currying to make this a little nicer
106
107/// Trait object that implements the extension function call.
108pub type ExtensionFunctionObject =
109    Box<dyn Fn(&[Value]) -> evaluator::Result<ExtensionOutputValue> + Sync + Send + 'static>;
110
111/// Extension function. These can be called by the given `name` in Ceder
112/// expressions.
113pub struct ExtensionFunction {
114    /// Name of the function
115    name: Name,
116    /// Which `CallStyle` should be used when calling this function
117    style: CallStyle,
118    /// The actual function, which takes an `&[Value]` and returns a `Value`,
119    /// or an evaluation error
120    func: ExtensionFunctionObject,
121    /// The return type of this function, as a `SchemaType`. We require that
122    /// this be constant -- any given extension function must always return a
123    /// value of this `SchemaType`.
124    ///
125    /// `return_type` is `None` if and only if this function represents an
126    /// "unknown" value for partial evaluation. Such a function may only return
127    /// a fully unknown residual and may never return a value.
128    return_type: Option<SchemaType>,
129    /// The argument types that this function expects, as `SchemaType`s.
130    arg_types: Vec<SchemaType>,
131}
132
133impl ExtensionFunction {
134    /// Create a new `ExtensionFunction` taking any number of arguments
135    fn new(
136        name: Name,
137        style: CallStyle,
138        func: ExtensionFunctionObject,
139        return_type: Option<SchemaType>,
140        arg_types: Vec<SchemaType>,
141    ) -> Self {
142        Self {
143            name,
144            func,
145            style,
146            return_type,
147            arg_types,
148        }
149    }
150
151    /// Create a new `ExtensionFunction` taking no arguments
152    pub fn nullary(
153        name: Name,
154        style: CallStyle,
155        func: Box<dyn Fn() -> evaluator::Result<ExtensionOutputValue> + Sync + Send + 'static>,
156        return_type: SchemaType,
157    ) -> Self {
158        Self::new(
159            name.clone(),
160            style,
161            Box::new(move |args: &[Value]| {
162                if args.is_empty() {
163                    func()
164                } else {
165                    Err(evaluator::EvaluationError::wrong_num_arguments(
166                        name.clone(),
167                        0,
168                        args.len(),
169                        None, // evaluator will add the source location later
170                    ))
171                }
172            }),
173            Some(return_type),
174            vec![],
175        )
176    }
177
178    /// Create a new `ExtensionFunction` to represent a function which is an
179    /// "unknown" in partial evaluation. Please don't use this for anything else.
180    pub fn partial_eval_unknown(
181        name: Name,
182        style: CallStyle,
183        func: Box<dyn Fn(Value) -> evaluator::Result<ExtensionOutputValue> + Sync + Send + 'static>,
184        arg_type: SchemaType,
185    ) -> Self {
186        Self::new(
187            name.clone(),
188            style,
189            Box::new(move |args: &[Value]| match args.first() {
190                Some(arg) => func(arg.clone()),
191                None => Err(evaluator::EvaluationError::wrong_num_arguments(
192                    name.clone(),
193                    1,
194                    args.len(),
195                    None, // evaluator will add the source location later
196                )),
197            }),
198            None,
199            vec![arg_type],
200        )
201    }
202
203    /// Create a new `ExtensionFunction` taking one argument
204    pub fn unary(
205        name: Name,
206        style: CallStyle,
207        func: Box<dyn Fn(Value) -> evaluator::Result<ExtensionOutputValue> + Sync + Send + 'static>,
208        return_type: SchemaType,
209        arg_type: SchemaType,
210    ) -> Self {
211        Self::new(
212            name.clone(),
213            style,
214            Box::new(move |args: &[Value]| match &args {
215                &[arg] => func(arg.clone()),
216                _ => Err(evaluator::EvaluationError::wrong_num_arguments(
217                    name.clone(),
218                    1,
219                    args.len(),
220                    None, // evaluator will add the source location later
221                )),
222            }),
223            Some(return_type),
224            vec![arg_type],
225        )
226    }
227
228    /// Create a new `ExtensionFunction` taking two arguments
229    pub fn binary(
230        name: Name,
231        style: CallStyle,
232        func: Box<
233            dyn Fn(Value, Value) -> evaluator::Result<ExtensionOutputValue> + Sync + Send + 'static,
234        >,
235        return_type: SchemaType,
236        arg_types: (SchemaType, SchemaType),
237    ) -> Self {
238        Self::new(
239            name.clone(),
240            style,
241            Box::new(move |args: &[Value]| match &args {
242                &[first, second] => func(first.clone(), second.clone()),
243                _ => Err(evaluator::EvaluationError::wrong_num_arguments(
244                    name.clone(),
245                    2,
246                    args.len(),
247                    None, // evaluator will add the source location later
248                )),
249            }),
250            Some(return_type),
251            vec![arg_types.0, arg_types.1],
252        )
253    }
254
255    /// Create a new `ExtensionFunction` taking three arguments
256    pub fn ternary(
257        name: Name,
258        style: CallStyle,
259        func: Box<
260            dyn Fn(Value, Value, Value) -> evaluator::Result<ExtensionOutputValue>
261                + Sync
262                + Send
263                + 'static,
264        >,
265        return_type: SchemaType,
266        arg_types: (SchemaType, SchemaType, SchemaType),
267    ) -> Self {
268        Self::new(
269            name.clone(),
270            style,
271            Box::new(move |args: &[Value]| match &args {
272                &[first, second, third] => func(first.clone(), second.clone(), third.clone()),
273                _ => Err(evaluator::EvaluationError::wrong_num_arguments(
274                    name.clone(),
275                    3,
276                    args.len(),
277                    None, // evaluator will add the source location later
278                )),
279            }),
280            Some(return_type),
281            vec![arg_types.0, arg_types.1, arg_types.2],
282        )
283    }
284
285    /// Get the `Name` of the `ExtensionFunction`
286    pub fn name(&self) -> &Name {
287        &self.name
288    }
289
290    /// Get the `CallStyle` of the `ExtensionFunction`
291    pub fn style(&self) -> CallStyle {
292        self.style
293    }
294
295    /// Get the return type of the `ExtensionFunction`
296    /// `None` is returned exactly when this function represents an "unknown"
297    /// for partial evaluation.
298    pub fn return_type(&self) -> Option<&SchemaType> {
299        self.return_type.as_ref()
300    }
301
302    /// Get the argument types of the `ExtensionFunction`.
303    pub fn arg_types(&self) -> &[SchemaType] {
304        &self.arg_types
305    }
306
307    /// Returns `true` if this function is considered a "constructor".
308    ///
309    /// Currently, the only impact of this is that non-constructors are not
310    /// accessible in the JSON format (entities/json.rs).
311    pub fn is_constructor(&self) -> bool {
312        // return type is an extension type
313        matches!(self.return_type(), Some(SchemaType::Extension { .. }))
314        // no argument is an extension type
315        && !self.arg_types().iter().any(|ty| matches!(ty, SchemaType::Extension { .. }))
316    }
317
318    /// Call the `ExtensionFunction` with the given args
319    pub fn call(&self, args: &[Value]) -> evaluator::Result<PartialValue> {
320        match (self.func)(args)? {
321            ExtensionOutputValue::Known(v) => Ok(PartialValue::Value(v)),
322            ExtensionOutputValue::Unknown(u) => Ok(PartialValue::Residual(Expr::unknown(u))),
323        }
324    }
325
326    /// Iterate over the extension types that could be produced by this
327    /// function, if any
328    pub fn ext_types(&self) -> impl Iterator<Item = &Name> + '_ {
329        self.return_type
330            .iter()
331            .flat_map(|ret_ty| ret_ty.contained_ext_types())
332    }
333}
334
335impl std::fmt::Debug for ExtensionFunction {
336    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
337        write!(f, "<extension function {}>", self.name())
338    }
339}
340
341/// Extension value.
342///
343/// Anything implementing this trait can be used as a first-class value in
344/// Cedar. For instance, the `ipaddr` extension uses this mechanism
345/// to implement IPAddr as a Cedar first-class value.
346pub trait ExtensionValue: Debug + Display + Send + Sync + UnwindSafe + RefUnwindSafe {
347    /// Get the name of the type of this value.
348    ///
349    /// Cedar has nominal typing, so two values have the same type iff they
350    /// return the same typename here.
351    fn typename(&self) -> Name;
352}
353
354impl<V: ExtensionValue> StaticallyTyped for V {
355    fn type_of(&self) -> Type {
356        Type::Extension {
357            name: self.typename(),
358        }
359    }
360}
361
362#[derive(Debug, Clone)]
363/// Object container for extension values, also stores the constructor-and-args
364/// that can reproduce the value (important for converting the value back to
365/// `RestrictedExpr` for instance)
366pub struct ExtensionValueWithArgs {
367    value: Arc<dyn InternalExtensionValue>,
368    pub(crate) constructor: Name,
369    /// Args are stored in `RestrictedExpr` form, just because that's most
370    /// convenient for reconstructing a `RestrictedExpr` that reproduces this
371    /// extension value
372    pub(crate) args: Vec<RestrictedExpr>,
373}
374
375impl ExtensionValueWithArgs {
376    /// Create a new `ExtensionValueWithArgs`
377    pub fn new(
378        value: Arc<dyn InternalExtensionValue + Send + Sync>,
379        constructor: Name,
380        args: Vec<RestrictedExpr>,
381    ) -> Self {
382        Self {
383            value,
384            constructor,
385            args,
386        }
387    }
388
389    /// Get the internal value
390    pub fn value(&self) -> &(dyn InternalExtensionValue) {
391        self.value.as_ref()
392    }
393
394    /// Get the typename of this extension value
395    pub fn typename(&self) -> Name {
396        self.value.typename()
397    }
398
399    /// Get the constructor and args that can reproduce this value
400    pub fn constructor_and_args(&self) -> (&Name, &[RestrictedExpr]) {
401        (&self.constructor, &self.args)
402    }
403}
404
405impl From<ExtensionValueWithArgs> for Expr {
406    fn from(val: ExtensionValueWithArgs) -> Self {
407        ExprBuilder::new().call_extension_fn(val.constructor, val.args.into_iter().map(Into::into))
408    }
409}
410
411impl StaticallyTyped for ExtensionValueWithArgs {
412    fn type_of(&self) -> Type {
413        self.value.type_of()
414    }
415}
416
417impl Display for ExtensionValueWithArgs {
418    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
419        write!(f, "{}", self.value)
420    }
421}
422
423impl PartialEq for ExtensionValueWithArgs {
424    fn eq(&self, other: &Self) -> bool {
425        // Values that are equal are equal regardless of which arguments made them
426        self.value.as_ref() == other.value.as_ref()
427    }
428}
429
430impl Eq for ExtensionValueWithArgs {}
431
432impl PartialOrd for ExtensionValueWithArgs {
433    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
434        Some(self.cmp(other))
435    }
436}
437
438impl Ord for ExtensionValueWithArgs {
439    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
440        self.value.cmp(&other.value)
441    }
442}
443
444/// Extensions provide a type implementing `ExtensionValue`, `Eq`, and `Ord`.
445/// We automatically implement `InternalExtensionValue` for that type (with the
446/// impl below).  Internally, we use `dyn InternalExtensionValue` instead of
447/// `dyn ExtensionValue`.
448///
449/// You might wonder why we don't just have `ExtensionValue: Eq + Ord` and use
450/// `dyn ExtensionValue` everywhere.  The answer is that the Rust compiler
451/// doesn't let you because of
452/// [object safety](https://doc.rust-lang.org/reference/items/traits.html#object-safety).
453/// So instead we have this workaround where we define our own `equals_extvalue`
454/// method that compares not against `&Self` but against `&dyn InternalExtensionValue`,
455/// and likewise for `cmp_extvalue`.
456pub trait InternalExtensionValue: ExtensionValue {
457    /// convert to an `Any`
458    fn as_any(&self) -> &dyn Any;
459    /// this will be the basis for `PartialEq` on `InternalExtensionValue`; but
460    /// note the `&dyn` (normal `PartialEq` doesn't have the `dyn`)
461    fn equals_extvalue(&self, other: &dyn InternalExtensionValue) -> bool;
462    /// this will be the basis for `Ord` on `InternalExtensionValue`; but note
463    /// the `&dyn` (normal `Ord` doesn't have the `dyn`)
464    fn cmp_extvalue(&self, other: &dyn InternalExtensionValue) -> std::cmp::Ordering;
465}
466
467impl<V: 'static + Eq + Ord + ExtensionValue + Send + Sync> InternalExtensionValue for V {
468    fn as_any(&self) -> &dyn Any {
469        self
470    }
471
472    fn equals_extvalue(&self, other: &dyn InternalExtensionValue) -> bool {
473        other
474            .as_any()
475            .downcast_ref::<V>()
476            .map(|v| self == v)
477            .unwrap_or(false) // if the downcast failed, values are different types, so equality is false
478    }
479
480    fn cmp_extvalue(&self, other: &dyn InternalExtensionValue) -> std::cmp::Ordering {
481        other
482            .as_any()
483            .downcast_ref::<V>()
484            .map(|v| self.cmp(v))
485            .unwrap_or_else(|| {
486                // downcast failed, so values are different types.
487                // we fall back on the total ordering on typenames.
488                self.typename().cmp(&other.typename())
489            })
490    }
491}
492
493impl PartialEq for dyn InternalExtensionValue {
494    fn eq(&self, other: &Self) -> bool {
495        self.equals_extvalue(other)
496    }
497}
498
499impl Eq for dyn InternalExtensionValue {}
500
501impl PartialOrd for dyn InternalExtensionValue {
502    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
503        Some(self.cmp(other))
504    }
505}
506
507impl Ord for dyn InternalExtensionValue {
508    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
509        self.cmp_extvalue(other)
510    }
511}
512
513impl StaticallyTyped for dyn InternalExtensionValue {
514    fn type_of(&self) -> Type {
515        Type::Extension {
516            name: self.typename(),
517        }
518    }
519}