cedar_policy_core/ast/
extension.rs

/*
 * Copyright 2022-2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
17use smol_str::SmolStr;
18
19use crate::ast::*;
20use crate::entities::SchemaType;
21use crate::evaluator;
22use std::any::Any;
23use std::collections::HashMap;
24use std::fmt::{Debug, Display};
25use std::sync::Arc;
26
27/// Cedar extension.
28///
29/// An extension can define new types and functions on those types. (Currently,
30/// there's nothing preventing an extension from defining new functions on
31/// built-in types, either, although we haven't discussed whether we want to
32/// allow this long-term.)
33pub struct Extension {
34    /// Name of the extension
35    name: Name,
36    /// Extension functions. These are legal to call in Cedar expressions.
37    functions: HashMap<Name, ExtensionFunction>,
38}
39
40impl Extension {
41    /// Create a new `Extension` with the given name and extension functions
42    pub fn new(name: Name, functions: impl IntoIterator<Item = ExtensionFunction>) -> Self {
43        Self {
44            name,
45            functions: functions.into_iter().map(|f| (f.name.clone(), f)).collect(),
46        }
47    }
48
49    /// Get the name of the extension
50    pub fn name(&self) -> &Name {
51        &self.name
52    }
53
54    /// Look up a function by name, or return `None` if the extension doesn't
55    /// provide a function with that name
56    pub fn get_func(&self, name: &Name) -> Option<&ExtensionFunction> {
57        self.functions.get(name)
58    }
59
60    /// Get an iterator over the function names
61    pub fn funcs(&self) -> impl Iterator<Item = &ExtensionFunction> {
62        self.functions.values()
63    }
64}
65
66impl std::fmt::Debug for Extension {
67    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
68        write!(f, "<extension {}>", self.name())
69    }
70}
71
72/// The output of an extension call, either a value or an unknown
73#[derive(Debug, Clone)]
74pub enum ExtensionOutputValue {
75    /// A concrete value from an extension call
76    Concrete(Value),
77    /// An unknown returned from an extension call
78    Unknown(SmolStr),
79}
80
81impl<T> From<T> for ExtensionOutputValue
82where
83    T: Into<Value>,
84{
85    fn from(v: T) -> Self {
86        ExtensionOutputValue::Concrete(v.into())
87    }
88}
89
/// Syntactic form used to invoke an extension function.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
#[cfg_attr(fuzzing, derive(arbitrary::Arbitrary))]
pub enum CallStyle {
    /// Invoked like a free function, e.g. `foo(a, b)`
    FunctionStyle,
    /// Invoked like a method on the first argument, e.g. `a.foo(b)`
    MethodStyle,
}
99
// Note: we could use currying to make this a little nicer
101
102/// Trait object that implements the extension function call.
103pub type ExtensionFunctionObject =
104    Box<dyn Fn(&[Value]) -> evaluator::Result<ExtensionOutputValue> + Sync + Send + 'static>;
105
106/// Extension function. These can be called by the given `name` in Ceder
107/// expressions.
108pub struct ExtensionFunction {
109    /// Name of the function
110    name: Name,
111    /// Which `CallStyle` should be used when calling this function
112    style: CallStyle,
113    /// The actual function, which takes an `&[Value]` and returns a `Value`,
114    /// or an evaluation error
115    func: ExtensionFunctionObject,
116    /// The return type of this function, as a `SchemaType`. We require that
117    /// this be constant -- any given extension function must always return a
118    /// value of this `SchemaType`.
119    /// If `return_type` is `None`, the function may never return a value.
120    /// (ie: it functions as the `Never` type)
121    return_type: Option<SchemaType>,
122    /// The argument types that this function expects, as `SchemaType`s. If any
123    /// given argument type is not constant (function works with multiple
124    /// `SchemaType`s) then this will be `None` for that argument.
125    arg_types: Vec<Option<SchemaType>>,
126}
127
128impl ExtensionFunction {
129    /// Create a new `ExtensionFunction` taking any number of arguments
130    fn new(
131        name: Name,
132        style: CallStyle,
133        func: ExtensionFunctionObject,
134        return_type: Option<SchemaType>,
135        arg_types: Vec<Option<SchemaType>>,
136    ) -> Self {
137        Self {
138            name,
139            func,
140            style,
141            return_type,
142            arg_types,
143        }
144    }
145
146    /// Create a new `ExtensionFunction` taking no arguments
147    pub fn nullary(
148        name: Name,
149        style: CallStyle,
150        func: Box<dyn Fn() -> evaluator::Result<ExtensionOutputValue> + Sync + Send + 'static>,
151        return_type: SchemaType,
152    ) -> Self {
153        Self::new(
154            name.clone(),
155            style,
156            Box::new(move |args: &[Value]| {
157                if args.is_empty() {
158                    func()
159                } else {
160                    Err(evaluator::EvaluationError::WrongNumArguments {
161                        op: ExtensionFunctionOp {
162                            function_name: name.clone(),
163                        },
164                        expected: 0,
165                        actual: args.len(),
166                    })
167                }
168            }),
169            Some(return_type),
170            vec![],
171        )
172    }
173
174    /// Create a new `ExtensionFunction` taking one argument, that never returns a value
175    pub fn unary_never(
176        name: Name,
177        style: CallStyle,
178        func: Box<dyn Fn(Value) -> evaluator::Result<ExtensionOutputValue> + Sync + Send + 'static>,
179        arg_type: Option<SchemaType>,
180    ) -> Self {
181        Self::new(
182            name.clone(),
183            style,
184            Box::new(move |args: &[Value]| {
185                if args.len() == 1 {
186                    func(args[0].clone())
187                } else {
188                    let op = ExtensionFunctionOp {
189                        function_name: name.clone(),
190                    };
191                    Err(evaluator::EvaluationError::WrongNumArguments {
192                        op,
193                        expected: 1,
194                        actual: args.len(),
195                    })
196                }
197            }),
198            None,
199            vec![arg_type],
200        )
201    }
202
203    /// Create a new `ExtensionFunction` taking one argument
204    pub fn unary(
205        name: Name,
206        style: CallStyle,
207        func: Box<dyn Fn(Value) -> evaluator::Result<ExtensionOutputValue> + Sync + Send + 'static>,
208        return_type: SchemaType,
209        arg_type: Option<SchemaType>,
210    ) -> Self {
211        Self::new(
212            name.clone(),
213            style,
214            Box::new(move |args: &[Value]| {
215                if args.len() == 1 {
216                    func(args[0].clone())
217                } else {
218                    let op = ExtensionFunctionOp {
219                        function_name: name.clone(),
220                    };
221                    Err(evaluator::EvaluationError::WrongNumArguments {
222                        op,
223                        expected: 1,
224                        actual: args.len(),
225                    })
226                }
227            }),
228            Some(return_type),
229            vec![arg_type],
230        )
231    }
232
233    /// Create a new `ExtensionFunction` taking two arguments
234    pub fn binary(
235        name: Name,
236        style: CallStyle,
237        func: Box<
238            dyn Fn(Value, Value) -> evaluator::Result<ExtensionOutputValue> + Sync + Send + 'static,
239        >,
240        return_type: SchemaType,
241        arg_types: (Option<SchemaType>, Option<SchemaType>),
242    ) -> Self {
243        Self::new(
244            name.clone(),
245            style,
246            Box::new(move |args: &[Value]| {
247                if args.len() == 2 {
248                    func(args[0].clone(), args[1].clone())
249                } else {
250                    Err(evaluator::EvaluationError::WrongNumArguments {
251                        op: ExtensionFunctionOp {
252                            function_name: name.clone(),
253                        },
254                        expected: 2,
255                        actual: args.len(),
256                    })
257                }
258            }),
259            Some(return_type),
260            vec![arg_types.0, arg_types.1],
261        )
262    }
263
264    /// Create a new `ExtensionFunction` taking three arguments
265    pub fn ternary(
266        name: Name,
267        style: CallStyle,
268        func: Box<
269            dyn Fn(Value, Value, Value) -> evaluator::Result<ExtensionOutputValue>
270                + Sync
271                + Send
272                + 'static,
273        >,
274        return_type: SchemaType,
275        arg_types: (Option<SchemaType>, Option<SchemaType>, Option<SchemaType>),
276    ) -> Self {
277        Self::new(
278            name.clone(),
279            style,
280            Box::new(move |args: &[Value]| {
281                if args.len() == 3 {
282                    func(args[0].clone(), args[1].clone(), args[2].clone())
283                } else {
284                    Err(evaluator::EvaluationError::WrongNumArguments {
285                        op: ExtensionFunctionOp {
286                            function_name: name.clone(),
287                        },
288                        expected: 3,
289                        actual: args.len(),
290                    })
291                }
292            }),
293            Some(return_type),
294            vec![arg_types.0, arg_types.1, arg_types.2],
295        )
296    }
297
298    /// Get the `Name` of the `ExtensionFunction`
299    pub fn name(&self) -> &Name {
300        &self.name
301    }
302
303    /// Get the `CallStyle` of the `ExtensionFunction`
304    pub fn style(&self) -> CallStyle {
305        self.style
306    }
307
308    /// Get the return type of the `ExtensionFunction`
309    /// `None` represents the `Never` type.
310    pub fn return_type(&self) -> Option<&SchemaType> {
311        self.return_type.as_ref()
312    }
313
314    /// Get the argument types of the `ExtensionFunction`.
315    ///
316    /// If any given argument type is not constant (function works with multiple
317    /// `SchemaType`s) then this will be `None` for that argument.
318    pub fn arg_types(&self) -> &[Option<SchemaType>] {
319        &self.arg_types
320    }
321
322    /// Returns `true` if this function is considered a "constructor".
323    ///
324    /// Currently, the only impact of this is that non-constructors are not
325    /// accessible in the JSON format (entities/json.rs).
326    pub fn is_constructor(&self) -> bool {
327        // return type is an extension type
328        matches!(self.return_type(), Some(SchemaType::Extension { .. }))
329        // all arg types are `Some()`
330        && self.arg_types().iter().all(Option::is_some)
331        // no argument is an extension type
332        && !self.arg_types().iter().any(|ty| matches!(ty, Some(SchemaType::Extension { .. })))
333    }
334
335    /// Call the `ExtensionFunction` with the given args
336    pub fn call(&self, args: &[Value]) -> evaluator::Result<PartialValue> {
337        match (self.func)(args)? {
338            ExtensionOutputValue::Concrete(v) => Ok(PartialValue::Value(v)),
339            ExtensionOutputValue::Unknown(name) => Ok(PartialValue::Residual(Expr::unknown(name))),
340        }
341    }
342}
343
344impl std::fmt::Debug for ExtensionFunction {
345    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
346        write!(f, "<extension function {}>", self.name())
347    }
348}
349
350/// Extension value.
351///
352/// Anything implementing this trait can be used as a first-class value in
353/// Cedar. For instance, the `ipaddr` extension uses this mechanism
354/// to implement IPAddr as a Cedar first-class value.
355pub trait ExtensionValue: Debug + Display {
356    /// Get the name of the type of this value.
357    ///
358    /// Cedar has nominal typing, so two values have the same type iff they
359    /// return the same typename here.
360    fn typename(&self) -> Name;
361}
362
363impl<V: ExtensionValue> StaticallyTyped for V {
364    fn type_of(&self) -> Type {
365        Type::Extension {
366            name: self.typename(),
367        }
368    }
369}
370
371#[derive(Debug, Clone)]
372/// Object container for extension values, also stores the fully reduced AST
373/// for the arguments
374pub struct ExtensionValueWithArgs {
375    value: Arc<dyn InternalExtensionValue>,
376    args: Vec<Expr>,
377    constructor: ExtensionFunctionOp,
378}
379
380impl ExtensionValueWithArgs {
381    /// Get the internal value
382    pub fn value(&self) -> &dyn InternalExtensionValue {
383        self.value.as_ref()
384    }
385
386    /// Get the typename of this extension value
387    pub fn typename(&self) -> Name {
388        self.value.typename()
389    }
390
391    /// Constructor
392    pub fn new(
393        value: Arc<dyn InternalExtensionValue>,
394        args: Vec<Expr>,
395        constructor: ExtensionFunctionOp,
396    ) -> Self {
397        Self {
398            value,
399            args,
400            constructor,
401        }
402    }
403}
404
405impl From<ExtensionValueWithArgs> for Expr {
406    fn from(val: ExtensionValueWithArgs) -> Self {
407        ExprBuilder::new().call_extension_fn(val.constructor.function_name, val.args)
408    }
409}
410
411impl StaticallyTyped for ExtensionValueWithArgs {
412    fn type_of(&self) -> Type {
413        self.value.type_of()
414    }
415}
416
417impl Display for ExtensionValueWithArgs {
418    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
419        write!(f, "{}", self.value)
420    }
421}
422
423impl PartialEq for ExtensionValueWithArgs {
424    fn eq(&self, other: &Self) -> bool {
425        // Values that are equal are equal regardless of which arguments made them
426        self.value.as_ref() == other.value.as_ref()
427    }
428}
429
430impl Eq for ExtensionValueWithArgs {}
431
432impl PartialOrd for ExtensionValueWithArgs {
433    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
434        self.value.partial_cmp(&other.value)
435    }
436}
437
438impl Ord for ExtensionValueWithArgs {
439    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
440        self.value.cmp(&other.value)
441    }
442}
443
444/// Extensions provide a type implementing `ExtensionValue`, `Eq`, and `Ord`.
445/// We automatically implement `InternalExtensionValue` for that type (with the
446/// impl below).  Internally, we use `dyn InternalExtensionValue` instead of
447/// `dyn ExtensionValue`.
448///
449/// You might wonder why we don't just have `ExtensionValue: Eq + Ord` and use
450/// `dyn ExtensionValue` everywhere.  The answer is that the Rust compiler
451/// doesn't let you because of
452/// [object safety](https://doc.rust-lang.org/reference/items/traits.html#object-safety).
453/// So instead we have this workaround where we define our own `equals_extvalue`
454/// method that compares not against `&Self` but against `&dyn InternalExtensionValue`,
455/// and likewise for `cmp_extvalue`.
456pub trait InternalExtensionValue: ExtensionValue {
457    /// convert to an `Any`
458    fn as_any(&self) -> &dyn Any;
459    /// this will be the basis for `PartialEq` on `InternalExtensionValue`; but
460    /// note the `&dyn` (normal `PartialEq` doesn't have the `dyn`)
461    fn equals_extvalue(&self, other: &dyn InternalExtensionValue) -> bool;
462    /// this will be the basis for `Ord` on `InternalExtensionValue`; but note
463    /// the `&dyn` (normal `Ord` doesn't have the `dyn`)
464    fn cmp_extvalue(&self, other: &dyn InternalExtensionValue) -> std::cmp::Ordering;
465}
466
467impl<V: 'static + Eq + Ord + ExtensionValue> InternalExtensionValue for V {
468    fn as_any(&self) -> &dyn Any {
469        self
470    }
471
472    fn equals_extvalue(&self, other: &dyn InternalExtensionValue) -> bool {
473        other
474            .as_any()
475            .downcast_ref::<V>()
476            .map(|v| self == v)
477            .unwrap_or(false) // if the downcast failed, values are different types, so equality is false
478    }
479
480    fn cmp_extvalue(&self, other: &dyn InternalExtensionValue) -> std::cmp::Ordering {
481        other
482            .as_any()
483            .downcast_ref::<V>()
484            .map(|v| self.cmp(v))
485            .unwrap_or_else(|| {
486                // downcast failed, so values are different types.
487                // we fall back on the total ordering on typenames.
488                self.typename().cmp(&other.typename())
489            })
490    }
491}
492
493impl PartialEq for dyn InternalExtensionValue {
494    fn eq(&self, other: &Self) -> bool {
495        self.equals_extvalue(other)
496    }
497}
498
499impl Eq for dyn InternalExtensionValue {}
500
501impl PartialOrd for dyn InternalExtensionValue {
502    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
503        Some(self.cmp(other))
504    }
505}
506
507impl Ord for dyn InternalExtensionValue {
508    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
509        self.cmp_extvalue(other)
510    }
511}
512
513impl StaticallyTyped for dyn InternalExtensionValue {
514    fn type_of(&self) -> Type {
515        Type::Extension {
516            name: self.typename(),
517        }
518    }
519}