cedar_policy_core/
extensions.rs

1/*
2 * Copyright Cedar Contributors
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      https://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17//! This module contains all of the standard Cedar extensions.
18
19#[cfg(feature = "ipaddr")]
20pub mod ipaddr;
21
22#[cfg(feature = "decimal")]
23pub mod decimal;
24pub mod partial_evaluation;
25
26use crate::ast::{Extension, ExtensionFunction, Name};
27use crate::entities::SchemaType;
28use miette::Diagnostic;
29use thiserror::Error;
30
31lazy_static::lazy_static! {
32    static ref ALL_AVAILABLE_EXTENSIONS: Vec<Extension> = vec![
33        #[cfg(feature = "ipaddr")]
34        ipaddr::extension(),
35        #[cfg(feature = "decimal")]
36        decimal::extension(),
37        #[cfg(feature = "partial-eval")]
38        partial_evaluation::extension(),
39    ];
40}
41
42/// Holds data on all the Extensions which are active for a given evaluation.
43///
44/// Clone is cheap for this type.
45#[derive(Debug, Clone, Copy)]
46pub struct Extensions<'a> {
47    /// the actual extensions
48    extensions: &'a [Extension],
49}
50
51impl Extensions<'static> {
52    /// Get a new `Extensions` containing data on all the available extensions.
53    pub fn all_available() -> Extensions<'static> {
54        Extensions {
55            extensions: &ALL_AVAILABLE_EXTENSIONS,
56        }
57    }
58
59    /// Get a new `Extensions` with no extensions enabled.
60    pub fn none() -> Extensions<'static> {
61        Extensions { extensions: &[] }
62    }
63}
64
65impl<'a> Extensions<'a> {
66    /// Get a new `Extensions` with these specific extensions enabled.
67    pub fn specific_extensions(extensions: &'a [Extension]) -> Extensions<'a> {
68        Extensions { extensions }
69    }
70
71    /// Get the names of all active extensions.
72    pub fn ext_names(&self) -> impl Iterator<Item = &Name> {
73        self.extensions.iter().map(|ext| ext.name())
74    }
75
76    /// Get the extension function with the given name, from these extensions.
77    ///
78    /// Returns an error if the function is not defined by any extension, or if
79    /// it is defined multiple times.
80    pub fn func(&self, name: &Name) -> Result<&ExtensionFunction> {
81        // NOTE: in the future, we could build a single HashMap of function
82        // name to ExtensionFunction, combining all extension functions
83        // into one map, to make this lookup faster.
84        let extension_funcs: Vec<&ExtensionFunction> = self
85            .extensions
86            .iter()
87            .filter_map(|ext| ext.get_func(name))
88            .collect();
89        match extension_funcs.first() {
90            None => Err(ExtensionFunctionLookupError::FuncDoesNotExist { name: name.clone() }),
91            Some(first) if extension_funcs.len() == 1 => Ok(first),
92            _ => Err(ExtensionFunctionLookupError::FuncMultiplyDefined {
93                name: name.clone(),
94                num_defs: extension_funcs.len(),
95            }),
96        }
97    }
98
99    /// Iterate over all extension functions defined by all of these extensions.
100    ///
101    /// No guarantee that this list won't have duplicates or repeated names.
102    pub(crate) fn all_funcs(&self) -> impl Iterator<Item = &'a ExtensionFunction> {
103        self.extensions.iter().flat_map(|ext| ext.funcs())
104    }
105
106    /// Lookup a single-argument constructor by its return type and argument type.
107    /// This will ignore polymorphic functions (that accept multiple argument types).
108    ///
109    /// `Ok(None)` means no constructor has that signature.
110    /// `Err` is returned in the case that multiple constructors have that signature.
111    pub(crate) fn lookup_single_arg_constructor(
112        &self,
113        return_type: &SchemaType,
114        arg_type: &SchemaType,
115    ) -> Result<Option<&ExtensionFunction>> {
116        let matches = self
117            .all_funcs()
118            .filter(|f| {
119                f.is_constructor()
120                    && f.return_type() == Some(return_type)
121                    && f.arg_types().first().map(Option::as_ref) == Some(Some(arg_type))
122            })
123            .collect::<Vec<_>>();
124        match matches.first() {
125            None => Ok(None),
126            Some(first) if matches.len() == 1 => Ok(Some(first)),
127            _ => Err(
128                ExtensionFunctionLookupError::MultipleConstructorsSameSignature {
129                    return_type: Box::new(return_type.clone()),
130                    arg_type: Box::new(arg_type.clone()),
131                },
132            ),
133        }
134    }
135}
136
137/// Errors thrown when looking up an extension function in [`Extensions`].
138#[derive(Debug, PartialEq, Eq, Clone, Diagnostic, Error)]
139pub enum ExtensionFunctionLookupError {
140    /// Tried to call a function that doesn't exist
141    #[error("extension function `{name}` does not exist")]
142    FuncDoesNotExist {
143        /// Name of the function that doesn't exist
144        name: Name,
145    },
146
147    /// Attempted to typecheck an expression that had no type
148    #[error("extension function `{name}` has no type")]
149    HasNoType {
150        /// Name of the function that returns no type
151        name: Name,
152    },
153
154    /// Tried to call a function but it was defined multiple times (e.g., by
155    /// multiple different extensions)
156    #[error("extension function `{name}` is defined {num_defs} times")]
157    FuncMultiplyDefined {
158        /// Name of the function that is multiply defined
159        name: Name,
160        /// How many times that function is defined
161        num_defs: usize,
162    },
163
164    /// Two extension constructors (in the same or different extensions) had
165    /// exactly the same type signature.  This is currently not allowed.
166    #[error(
167        "multiple extension constructors have the same type signature {arg_type} -> {return_type}"
168    )]
169    MultipleConstructorsSameSignature {
170        /// return type of the shared constructor signature
171        return_type: Box<SchemaType>,
172        /// argument type of the shared constructor signature
173        arg_type: Box<SchemaType>,
174    },
175}
176
177/// Type alias for convenience
178pub type Result<T> = std::result::Result<T, ExtensionFunctionLookupError>;
179
#[cfg(test)]
pub(crate) mod test {
    use super::*;
    use std::collections::HashSet;

    /// Every extension function name across all available extensions must
    /// be unique.
    #[test]
    fn no_common_extension_function_names() {
        // Our expr display must search for callstyle given a name, so
        // no names can be used for both callstyles.

        // Test that names are all unique for ease of use.
        // This overconstrains our current requirements, but shouldn't change
        // until we identify a strong need.
        let mut seen = HashSet::new();
        for ext in Extensions::all_available().extensions {
            for func in ext.funcs() {
                let name = func.name();
                // `insert` returns false when the name was already present,
                // letting the failure message point at the offending name.
                assert!(
                    seen.insert(name),
                    "extension function name `{name}` is defined more than once"
                );
            }
        }
    }
}