cedar_policy_core/ast/
extension.rs

1/*
2 * Copyright Cedar Contributors
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      https://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17use crate::ast::*;
18use crate::entities::SchemaType;
19use crate::evaluator;
20use std::any::Any;
21use std::collections::HashMap;
22use std::fmt::{Debug, Display};
23use std::panic::{RefUnwindSafe, UnwindSafe};
24use std::sync::Arc;
25
26/// Cedar extension.
27///
28/// An extension can define new types and functions on those types. (Currently,
29/// there's nothing preventing an extension from defining new functions on
30/// built-in types, either, although we haven't discussed whether we want to
31/// allow this long-term.)
32pub struct Extension {
33    /// Name of the extension
34    name: Name,
35    /// Extension functions. These are legal to call in Cedar expressions.
36    functions: HashMap<Name, ExtensionFunction>,
37}
38
39impl Extension {
40    /// Create a new `Extension` with the given name and extension functions
41    pub fn new(name: Name, functions: impl IntoIterator<Item = ExtensionFunction>) -> Self {
42        Self {
43            name,
44            functions: functions.into_iter().map(|f| (f.name.clone(), f)).collect(),
45        }
46    }
47
48    /// Get the name of the extension
49    pub fn name(&self) -> &Name {
50        &self.name
51    }
52
53    /// Look up a function by name, or return `None` if the extension doesn't
54    /// provide a function with that name
55    pub fn get_func(&self, name: &Name) -> Option<&ExtensionFunction> {
56        self.functions.get(name)
57    }
58
59    /// Get an iterator over the function names
60    pub fn funcs(&self) -> impl Iterator<Item = &ExtensionFunction> {
61        self.functions.values()
62    }
63}
64
65impl std::fmt::Debug for Extension {
66    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
67        write!(f, "<extension {}>", self.name())
68    }
69}
70
71/// The output of an extension call, either a value or an unknown
72#[derive(Debug, Clone)]
73pub enum ExtensionOutputValue {
74    /// A concrete value from an extension call
75    Known(Value),
76    /// An unknown returned from an extension call
77    Unknown(Unknown),
78}
79
80impl<T> From<T> for ExtensionOutputValue
81where
82    T: Into<Value>,
83{
84    fn from(v: T) -> Self {
85        ExtensionOutputValue::Known(v.into())
86    }
87}
88
/// The syntactic form used to invoke an extension function.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
pub enum CallStyle {
    /// Called like a free function, e.g. `foo(a, b)`
    FunctionStyle,
    /// Called like a method, e.g. `a.foo(b)`
    MethodStyle,
}
98
99// Note: we could use currying to make this a little nicer
100
101/// Trait object that implements the extension function call.
102pub type ExtensionFunctionObject =
103    Box<dyn Fn(&[Value]) -> evaluator::Result<ExtensionOutputValue> + Sync + Send + 'static>;
104
105/// Extension function. These can be called by the given `name` in Ceder
106/// expressions.
107pub struct ExtensionFunction {
108    /// Name of the function
109    name: Name,
110    /// Which `CallStyle` should be used when calling this function
111    style: CallStyle,
112    /// The actual function, which takes an `&[Value]` and returns a `Value`,
113    /// or an evaluation error
114    func: ExtensionFunctionObject,
115    /// The return type of this function, as a `SchemaType`. We require that
116    /// this be constant -- any given extension function must always return a
117    /// value of this `SchemaType`.
118    /// If `return_type` is `None`, the function may never return a value.
119    /// (ie: it functions as the `Never` type)
120    return_type: Option<SchemaType>,
121    /// The argument types that this function expects, as `SchemaType`s. If any
122    /// given argument type is not constant (function works with multiple
123    /// `SchemaType`s) then this will be `None` for that argument.
124    arg_types: Vec<Option<SchemaType>>,
125}
126
127impl ExtensionFunction {
128    /// Create a new `ExtensionFunction` taking any number of arguments
129    fn new(
130        name: Name,
131        style: CallStyle,
132        func: ExtensionFunctionObject,
133        return_type: Option<SchemaType>,
134        arg_types: Vec<Option<SchemaType>>,
135    ) -> Self {
136        Self {
137            name,
138            func,
139            style,
140            return_type,
141            arg_types,
142        }
143    }
144
145    /// Create a new `ExtensionFunction` taking no arguments
146    pub fn nullary(
147        name: Name,
148        style: CallStyle,
149        func: Box<dyn Fn() -> evaluator::Result<ExtensionOutputValue> + Sync + Send + 'static>,
150        return_type: SchemaType,
151    ) -> Self {
152        Self::new(
153            name.clone(),
154            style,
155            Box::new(move |args: &[Value]| {
156                if args.is_empty() {
157                    func()
158                } else {
159                    Err(evaluator::EvaluationError::wrong_num_arguments(
160                        name.clone(),
161                        0,
162                        args.len(),
163                        None, // evaluator will add the source location later
164                    ))
165                }
166            }),
167            Some(return_type),
168            vec![],
169        )
170    }
171
172    /// Create a new `ExtensionFunction` taking one argument, that never returns a value
173    pub fn unary_never(
174        name: Name,
175        style: CallStyle,
176        func: Box<dyn Fn(Value) -> evaluator::Result<ExtensionOutputValue> + Sync + Send + 'static>,
177        arg_type: Option<SchemaType>,
178    ) -> Self {
179        Self::new(
180            name.clone(),
181            style,
182            Box::new(move |args: &[Value]| match args.first() {
183                Some(arg) => func(arg.clone()),
184                None => Err(evaluator::EvaluationError::wrong_num_arguments(
185                    name.clone(),
186                    1,
187                    args.len(),
188                    None, // evaluator will add the source location later
189                )),
190            }),
191            None,
192            vec![arg_type],
193        )
194    }
195
196    /// Create a new `ExtensionFunction` taking one argument
197    pub fn unary(
198        name: Name,
199        style: CallStyle,
200        func: Box<dyn Fn(Value) -> evaluator::Result<ExtensionOutputValue> + Sync + Send + 'static>,
201        return_type: SchemaType,
202        arg_type: Option<SchemaType>,
203    ) -> Self {
204        Self::new(
205            name.clone(),
206            style,
207            Box::new(move |args: &[Value]| match &args {
208                &[arg] => func(arg.clone()),
209                _ => Err(evaluator::EvaluationError::wrong_num_arguments(
210                    name.clone(),
211                    1,
212                    args.len(),
213                    None, // evaluator will add the source location later
214                )),
215            }),
216            Some(return_type),
217            vec![arg_type],
218        )
219    }
220
221    /// Create a new `ExtensionFunction` taking two arguments
222    pub fn binary(
223        name: Name,
224        style: CallStyle,
225        func: Box<
226            dyn Fn(Value, Value) -> evaluator::Result<ExtensionOutputValue> + Sync + Send + 'static,
227        >,
228        return_type: SchemaType,
229        arg_types: (Option<SchemaType>, Option<SchemaType>),
230    ) -> Self {
231        Self::new(
232            name.clone(),
233            style,
234            Box::new(move |args: &[Value]| match &args {
235                &[first, second] => func(first.clone(), second.clone()),
236                _ => Err(evaluator::EvaluationError::wrong_num_arguments(
237                    name.clone(),
238                    2,
239                    args.len(),
240                    None, // evaluator will add the source location later
241                )),
242            }),
243            Some(return_type),
244            vec![arg_types.0, arg_types.1],
245        )
246    }
247
248    /// Create a new `ExtensionFunction` taking three arguments
249    pub fn ternary(
250        name: Name,
251        style: CallStyle,
252        func: Box<
253            dyn Fn(Value, Value, Value) -> evaluator::Result<ExtensionOutputValue>
254                + Sync
255                + Send
256                + 'static,
257        >,
258        return_type: SchemaType,
259        arg_types: (Option<SchemaType>, Option<SchemaType>, Option<SchemaType>),
260    ) -> Self {
261        Self::new(
262            name.clone(),
263            style,
264            Box::new(move |args: &[Value]| match &args {
265                &[first, second, third] => func(first.clone(), second.clone(), third.clone()),
266                _ => Err(evaluator::EvaluationError::wrong_num_arguments(
267                    name.clone(),
268                    3,
269                    args.len(),
270                    None, // evaluator will add the source location later
271                )),
272            }),
273            Some(return_type),
274            vec![arg_types.0, arg_types.1, arg_types.2],
275        )
276    }
277
278    /// Get the `Name` of the `ExtensionFunction`
279    pub fn name(&self) -> &Name {
280        &self.name
281    }
282
283    /// Get the `CallStyle` of the `ExtensionFunction`
284    pub fn style(&self) -> CallStyle {
285        self.style
286    }
287
288    /// Get the return type of the `ExtensionFunction`
289    /// `None` represents the `Never` type.
290    pub fn return_type(&self) -> Option<&SchemaType> {
291        self.return_type.as_ref()
292    }
293
294    /// Get the argument types of the `ExtensionFunction`.
295    ///
296    /// If any given argument type is not constant (function works with multiple
297    /// `SchemaType`s) then this will be `None` for that argument.
298    pub fn arg_types(&self) -> &[Option<SchemaType>] {
299        &self.arg_types
300    }
301
302    /// Returns `true` if this function is considered a "constructor".
303    ///
304    /// Currently, the only impact of this is that non-constructors are not
305    /// accessible in the JSON format (entities/json.rs).
306    pub fn is_constructor(&self) -> bool {
307        // return type is an extension type
308        matches!(self.return_type(), Some(SchemaType::Extension { .. }))
309        // all arg types are `Some()`
310        && self.arg_types().iter().all(Option::is_some)
311        // no argument is an extension type
312        && !self.arg_types().iter().any(|ty| matches!(ty, Some(SchemaType::Extension { .. })))
313    }
314
315    /// Call the `ExtensionFunction` with the given args
316    pub fn call(&self, args: &[Value]) -> evaluator::Result<PartialValue> {
317        match (self.func)(args)? {
318            ExtensionOutputValue::Known(v) => Ok(PartialValue::Value(v)),
319            ExtensionOutputValue::Unknown(u) => Ok(PartialValue::Residual(Expr::unknown(u))),
320        }
321    }
322}
323
324impl std::fmt::Debug for ExtensionFunction {
325    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
326        write!(f, "<extension function {}>", self.name())
327    }
328}
329
330/// Extension value.
331///
332/// Anything implementing this trait can be used as a first-class value in
333/// Cedar. For instance, the `ipaddr` extension uses this mechanism
334/// to implement IPAddr as a Cedar first-class value.
335pub trait ExtensionValue: Debug + Display + Send + Sync + UnwindSafe + RefUnwindSafe {
336    /// Get the name of the type of this value.
337    ///
338    /// Cedar has nominal typing, so two values have the same type iff they
339    /// return the same typename here.
340    fn typename(&self) -> Name;
341}
342
343impl<V: ExtensionValue> StaticallyTyped for V {
344    fn type_of(&self) -> Type {
345        Type::Extension {
346            name: self.typename(),
347        }
348    }
349}
350
351#[derive(Debug, Clone)]
352/// Object container for extension values, also stores the constructor-and-args
353/// that can reproduce the value (important for converting the value back to
354/// `RestrictedExpr` for instance)
355pub struct ExtensionValueWithArgs {
356    value: Arc<dyn InternalExtensionValue>,
357    pub(crate) constructor: Name,
358    /// Args are stored in `RestrictedExpr` form, just because that's most
359    /// convenient for reconstructing a `RestrictedExpr` that reproduces this
360    /// extension value
361    pub(crate) args: Vec<RestrictedExpr>,
362}
363
364impl ExtensionValueWithArgs {
365    /// Create a new `ExtensionValueWithArgs`
366    pub fn new(
367        value: Arc<dyn InternalExtensionValue + Send + Sync>,
368        constructor: Name,
369        args: Vec<RestrictedExpr>,
370    ) -> Self {
371        Self {
372            value,
373            constructor,
374            args,
375        }
376    }
377
378    /// Get the internal value
379    pub fn value(&self) -> &(dyn InternalExtensionValue) {
380        self.value.as_ref()
381    }
382
383    /// Get the typename of this extension value
384    pub fn typename(&self) -> Name {
385        self.value.typename()
386    }
387
388    /// Get the constructor and args that can reproduce this value
389    pub fn constructor_and_args(&self) -> (&Name, &[RestrictedExpr]) {
390        (&self.constructor, &self.args)
391    }
392}
393
394impl From<ExtensionValueWithArgs> for Expr {
395    fn from(val: ExtensionValueWithArgs) -> Self {
396        ExprBuilder::new().call_extension_fn(val.constructor, val.args.into_iter().map(Into::into))
397    }
398}
399
400impl StaticallyTyped for ExtensionValueWithArgs {
401    fn type_of(&self) -> Type {
402        self.value.type_of()
403    }
404}
405
406impl Display for ExtensionValueWithArgs {
407    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
408        write!(f, "{}", self.value)
409    }
410}
411
412impl PartialEq for ExtensionValueWithArgs {
413    fn eq(&self, other: &Self) -> bool {
414        // Values that are equal are equal regardless of which arguments made them
415        self.value.as_ref() == other.value.as_ref()
416    }
417}
418
419impl Eq for ExtensionValueWithArgs {}
420
421impl PartialOrd for ExtensionValueWithArgs {
422    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
423        Some(self.cmp(other))
424    }
425}
426
427impl Ord for ExtensionValueWithArgs {
428    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
429        self.value.cmp(&other.value)
430    }
431}
432
433/// Extensions provide a type implementing `ExtensionValue`, `Eq`, and `Ord`.
434/// We automatically implement `InternalExtensionValue` for that type (with the
435/// impl below).  Internally, we use `dyn InternalExtensionValue` instead of
436/// `dyn ExtensionValue`.
437///
438/// You might wonder why we don't just have `ExtensionValue: Eq + Ord` and use
439/// `dyn ExtensionValue` everywhere.  The answer is that the Rust compiler
440/// doesn't let you because of
441/// [object safety](https://doc.rust-lang.org/reference/items/traits.html#object-safety).
442/// So instead we have this workaround where we define our own `equals_extvalue`
443/// method that compares not against `&Self` but against `&dyn InternalExtensionValue`,
444/// and likewise for `cmp_extvalue`.
445pub trait InternalExtensionValue: ExtensionValue {
446    /// convert to an `Any`
447    fn as_any(&self) -> &dyn Any;
448    /// this will be the basis for `PartialEq` on `InternalExtensionValue`; but
449    /// note the `&dyn` (normal `PartialEq` doesn't have the `dyn`)
450    fn equals_extvalue(&self, other: &dyn InternalExtensionValue) -> bool;
451    /// this will be the basis for `Ord` on `InternalExtensionValue`; but note
452    /// the `&dyn` (normal `Ord` doesn't have the `dyn`)
453    fn cmp_extvalue(&self, other: &dyn InternalExtensionValue) -> std::cmp::Ordering;
454}
455
456impl<V: 'static + Eq + Ord + ExtensionValue + Send + Sync> InternalExtensionValue for V {
457    fn as_any(&self) -> &dyn Any {
458        self
459    }
460
461    fn equals_extvalue(&self, other: &dyn InternalExtensionValue) -> bool {
462        other
463            .as_any()
464            .downcast_ref::<V>()
465            .map(|v| self == v)
466            .unwrap_or(false) // if the downcast failed, values are different types, so equality is false
467    }
468
469    fn cmp_extvalue(&self, other: &dyn InternalExtensionValue) -> std::cmp::Ordering {
470        other
471            .as_any()
472            .downcast_ref::<V>()
473            .map(|v| self.cmp(v))
474            .unwrap_or_else(|| {
475                // downcast failed, so values are different types.
476                // we fall back on the total ordering on typenames.
477                self.typename().cmp(&other.typename())
478            })
479    }
480}
481
482impl PartialEq for dyn InternalExtensionValue {
483    fn eq(&self, other: &Self) -> bool {
484        self.equals_extvalue(other)
485    }
486}
487
488impl Eq for dyn InternalExtensionValue {}
489
490impl PartialOrd for dyn InternalExtensionValue {
491    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
492        Some(self.cmp(other))
493    }
494}
495
496impl Ord for dyn InternalExtensionValue {
497    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
498        self.cmp_extvalue(other)
499    }
500}
501
502impl StaticallyTyped for dyn InternalExtensionValue {
503    fn type_of(&self) -> Type {
504        Type::Extension {
505            name: self.typename(),
506        }
507    }
508}