cedar_policy_core/
extensions.rs

1/*
2 * Copyright Cedar Contributors
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      https://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17//! This module contains all of the standard Cedar extensions.
18
19#[cfg(feature = "ipaddr")]
20pub mod ipaddr;
21
22#[cfg(feature = "decimal")]
23pub mod decimal;
24
25#[cfg(feature = "datetime")]
26pub mod datetime;
27pub mod partial_evaluation;
28
29use std::collections::HashMap;
30use std::sync::LazyLock;
31
32use crate::ast::{Extension, ExtensionFunction, Name};
33use crate::entities::SchemaType;
34use crate::extensions::extension_initialization_errors::MultipleConstructorsSameSignatureError;
35use crate::parser::Loc;
36use miette::Diagnostic;
37use thiserror::Error;
38
39use self::extension_function_lookup_errors::FuncDoesNotExistError;
40use self::extension_initialization_errors::FuncMultiplyDefinedError;
41
42static ALL_AVAILABLE_EXTENSION_OBJECTS: LazyLock<Vec<Extension>> = LazyLock::new(|| {
43    vec![
44        #[cfg(feature = "ipaddr")]
45        ipaddr::extension(),
46        #[cfg(feature = "decimal")]
47        decimal::extension(),
48        #[cfg(feature = "datetime")]
49        datetime::extension(),
50        #[cfg(feature = "partial-eval")]
51        partial_evaluation::extension(),
52    ]
53});
54
55static ALL_AVAILABLE_EXTENSIONS: LazyLock<Extensions<'static>> =
56    LazyLock::new(Extensions::build_all_available);
57
58static EXTENSIONS_NONE: LazyLock<Extensions<'static>> = LazyLock::new(|| Extensions {
59    extensions: &[],
60    functions: HashMap::new(),
61    single_arg_constructors: HashMap::new(),
62});
63
64/// Holds data on all the Extensions which are active for a given evaluation.
65///
66/// This structure is intentionally not `Clone` because we can use it entirely
67/// by reference.
68#[derive(Debug)]
69pub struct Extensions<'a> {
70    /// the actual extensions
71    extensions: &'a [Extension],
72    /// All extension functions, collected from every extension used to
73    /// construct this object.  Built ahead of time so that we know during
74    /// extension function lookup that at most one extension function exists
75    /// for a name. This should also make the lookup more efficient.
76    functions: HashMap<&'a Name, &'a ExtensionFunction>,
77    /// All single argument extension function constructors, indexed by their
78    /// return type. Built ahead of time so that we know each constructor has
79    /// a unique return type.
80    single_arg_constructors: HashMap<&'a SchemaType, &'a ExtensionFunction>,
81}
82
83impl Extensions<'static> {
84    /// Get a new `Extensions` containing data on all the available extensions.
85    fn build_all_available() -> Extensions<'static> {
86        // PANIC SAFETY: Builtin extensions define functions/constructors only once. Also tested by many different test cases.
87        #[allow(clippy::expect_used)]
88        Self::specific_extensions(&ALL_AVAILABLE_EXTENSION_OBJECTS)
89            .expect("Default extensions should never error on initialization")
90    }
91
92    /// An [`Extensions`] object with static lifetime contain all available extensions.
93    pub fn all_available() -> &'static Extensions<'static> {
94        &ALL_AVAILABLE_EXTENSIONS
95    }
96
97    /// Get a new `Extensions` with no extensions enabled.
98    pub fn none() -> &'static Extensions<'static> {
99        &EXTENSIONS_NONE
100    }
101}
102
103impl<'a> Extensions<'a> {
104    /// Obtain the non-empty vector of types supporting operator overloading
105    pub fn types_with_operator_overloading(&self) -> impl Iterator<Item = &Name> + '_ {
106        self.extensions
107            .iter()
108            .flat_map(|ext| ext.types_with_operator_overloading())
109    }
110    /// Get a new `Extensions` with these specific extensions enabled.
111    pub fn specific_extensions(
112        extensions: &'a [Extension],
113    ) -> std::result::Result<Extensions<'a>, ExtensionInitializationError> {
114        // Build functions map, ensuring that no functions share the same name.
115        let functions = util::collect_no_duplicates(
116            extensions
117                .iter()
118                .flat_map(|e| e.funcs())
119                .map(|f| (f.name(), f)),
120        )
121        .map_err(|name| FuncMultiplyDefinedError { name: name.clone() })?;
122
123        // Build the constructor map, ensuring that no constructors share a return type
124        let single_arg_constructors = util::collect_no_duplicates(
125            extensions
126                .iter()
127                .flat_map(|e| e.funcs())
128                .filter(|f| f.is_single_arg_constructor())
129                .filter_map(|f| f.return_type().map(|return_type| (return_type, f))),
130        )
131        .map_err(|return_type| MultipleConstructorsSameSignatureError {
132            return_type: Box::new(return_type.clone()),
133        })?;
134
135        Ok(Extensions {
136            extensions,
137            functions,
138            single_arg_constructors,
139        })
140    }
141
142    /// Get the names of all active extensions.
143    pub fn ext_names(&self) -> impl Iterator<Item = &Name> {
144        self.extensions.iter().map(|ext| ext.name())
145    }
146
147    /// Get all extension type names declared by active extensions.
148    ///
149    /// (More specifically, all extension type names such that any function in
150    /// an active extension could produce a value of that extension type.)
151    pub fn ext_types(&self) -> impl Iterator<Item = &Name> {
152        self.extensions.iter().flat_map(|ext| ext.ext_types())
153    }
154
155    /// Get the extension function with the given name, from these extensions.
156    ///
157    /// Returns an error if the function is not defined by any extension
158    pub fn func(
159        &self,
160        name: &Name,
161    ) -> std::result::Result<&ExtensionFunction, ExtensionFunctionLookupError> {
162        self.functions.get(name).copied().ok_or_else(|| {
163            FuncDoesNotExistError {
164                name: name.clone(),
165                source_loc: name.loc().cloned(),
166            }
167            .into()
168        })
169    }
170
171    /// Iterate over all extension functions defined by all of these extensions.
172    ///
173    /// No guarantee that this list won't have duplicates or repeated names.
174    pub(crate) fn all_funcs(&self) -> impl Iterator<Item = &'a ExtensionFunction> {
175        self.extensions.iter().flat_map(|ext| ext.funcs())
176    }
177
178    /// Lookup a single-argument constructor by its return type
179    pub(crate) fn lookup_single_arg_constructor(
180        &self,
181        return_type: &SchemaType,
182    ) -> Option<&ExtensionFunction> {
183        self.single_arg_constructors.get(return_type).copied()
184    }
185}
186
187/// Errors occurring while initializing extensions. There are internal errors, so
188/// this enum should not become part of the public API unless we publicly expose
189/// user-defined extension function.
190#[derive(Diagnostic, Debug, PartialEq, Eq, Clone, Error)]
191pub enum ExtensionInitializationError {
192    /// An extension function was defined by multiple extensions.
193    #[error(transparent)]
194    #[diagnostic(transparent)]
195    FuncMultiplyDefined(#[from] extension_initialization_errors::FuncMultiplyDefinedError),
196
197    /// Two extension constructors (in the same or different extensions) had
198    /// exactly the same type signature.  This is currently not allowed.
199    #[error(transparent)]
200    #[diagnostic(transparent)]
201    MultipleConstructorsSameSignature(
202        #[from] extension_initialization_errors::MultipleConstructorsSameSignatureError,
203    ),
204}
205
206/// Error subtypes for [`ExtensionInitializationError`]
207mod extension_initialization_errors {
208    use crate::{ast::Name, entities::SchemaType};
209    use miette::Diagnostic;
210    use thiserror::Error;
211
212    /// An extension function was defined by multiple extensions.
213    #[derive(Diagnostic, Debug, PartialEq, Eq, Clone, Error)]
214    #[error("extension function `{name}` is defined multiple times")]
215    pub struct FuncMultiplyDefinedError {
216        /// Name of the function that was multiply defined
217        pub(crate) name: Name,
218    }
219
220    /// Two extension constructors (in the same or different extensions) exist
221    /// for one extension type.  This is currently not allowed.
222    #[derive(Diagnostic, Debug, PartialEq, Eq, Clone, Error)]
223    #[error("multiple extension constructors for the same extension type {return_type}")]
224    pub struct MultipleConstructorsSameSignatureError {
225        /// return type of the shared constructor signature
226        pub(crate) return_type: Box<SchemaType>,
227    }
228}
229
230/// Errors thrown when looking up an extension function in [`Extensions`].
231//
232// CAUTION: this type is publicly exported in `cedar-policy`.
233// Don't make fields `pub`, don't make breaking changes, and use caution
234// when adding public methods.
235#[derive(Debug, PartialEq, Eq, Clone, Diagnostic, Error)]
236pub enum ExtensionFunctionLookupError {
237    /// Tried to call a function that doesn't exist
238    #[error(transparent)]
239    #[diagnostic(transparent)]
240    FuncDoesNotExist(#[from] extension_function_lookup_errors::FuncDoesNotExistError),
241}
242
243impl ExtensionFunctionLookupError {
244    pub(crate) fn source_loc(&self) -> Option<&Loc> {
245        match self {
246            Self::FuncDoesNotExist(e) => e.source_loc.as_ref(),
247        }
248    }
249
250    pub(crate) fn with_maybe_source_loc(self, source_loc: Option<Loc>) -> Self {
251        match self {
252            Self::FuncDoesNotExist(e) => {
253                Self::FuncDoesNotExist(extension_function_lookup_errors::FuncDoesNotExistError {
254                    source_loc,
255                    ..e
256                })
257            }
258        }
259    }
260}
261
262/// Error subtypes for [`ExtensionFunctionLookupError`]
263pub mod extension_function_lookup_errors {
264    use crate::ast::Name;
265    use crate::parser::Loc;
266    use miette::Diagnostic;
267    use thiserror::Error;
268
269    /// Tried to call a function that doesn't exist
270    //
271    // CAUTION: this type is publicly exported in `cedar-policy`.
272    // Don't make fields `pub`, don't make breaking changes, and use caution
273    // when adding public methods.
274    #[derive(Debug, PartialEq, Eq, Clone, Error)]
275    #[error("extension function `{name}` does not exist")]
276    pub struct FuncDoesNotExistError {
277        /// Name of the function that doesn't exist
278        pub(crate) name: Name,
279        /// Source location
280        pub(crate) source_loc: Option<Loc>,
281    }
282
283    impl Diagnostic for FuncDoesNotExistError {
284        impl_diagnostic_from_source_loc_opt_field!(source_loc);
285    }
286}
287
288/// Type alias for convenience
289pub type Result<T> = std::result::Result<T, ExtensionFunctionLookupError>;
290
/// Utilities shared with the `cedar-policy-validator` extensions module.
pub mod util {
    use std::collections::{hash_map::Entry, HashMap};

    /// Utility to build a `HashMap` of key value pairs from an iterator,
    /// returning an `Err` result if there are any duplicate keys in the
    /// iterator.
    ///
    /// The map is pre-sized from the iterator's lower `size_hint` bound.
    ///
    /// # Errors
    ///
    /// On the first key that repeats an earlier one, returns `Err` holding
    /// the key stored for the earlier occurrence (the two compare equal).
    /// Items after the duplicate are not consumed.
    pub fn collect_no_duplicates<K, V>(
        i: impl Iterator<Item = (K, V)>,
    ) -> std::result::Result<HashMap<K, V>, K>
    where
        K: std::hash::Hash + Eq,
    {
        let mut map = HashMap::with_capacity(i.size_hint().0);
        for (k, v) in i {
            match map.entry(k) {
                // Move the stored key out of the map rather than cloning it,
                // so `K` need not be `Clone`; the map is discarded on this
                // path anyway.
                Entry::Occupied(occupied) => return Err(occupied.remove_entry().0),
                Entry::Vacant(vacant) => {
                    vacant.insert(v);
                }
            }
        }
        Ok(map)
    }
}
318
#[cfg(test)]
mod test {
    use super::*;
    use std::collections::HashSet;

    #[test]
    fn no_common_extension_function_names() {
        // Expression display looks up a function's call style by its name,
        // so no name may be shared between call styles.
        //
        // Requiring every name to be unique overconstrains that need, but it
        // keeps things simple and should stay until there is a strong reason
        // to relax it.
        let mut seen = HashSet::new();
        for func in Extensions::all_available().all_funcs() {
            let name = func.name();
            assert!(
                seen.insert(name),
                "extension function name `{name}` is used more than once"
            );
        }
    }
}