cedar_policy_core/
extensions.rs

/*
 * Copyright Cedar Contributors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

//! This module contains all of the standard Cedar extensions.

19#[cfg(feature = "ipaddr")]
20pub mod ipaddr;
21
22#[cfg(feature = "decimal")]
23pub mod decimal;
24pub mod partial_evaluation;
25
26use std::collections::HashMap;
27
28use crate::ast::{Extension, ExtensionFunction, Name};
29use crate::entities::SchemaType;
30use crate::parser::Loc;
31use miette::Diagnostic;
32use thiserror::Error;
33
34use self::extension_function_lookup_errors::FuncDoesNotExistError;
35use self::extension_initialization_errors::{
36    FuncMultiplyDefinedError, MultipleConstructorsSameSignatureError,
37};
38
39lazy_static::lazy_static! {
40    static ref ALL_AVAILABLE_EXTENSION_OBJECTS: Vec<Extension> = vec![
41        #[cfg(feature = "ipaddr")]
42        ipaddr::extension(),
43        #[cfg(feature = "decimal")]
44        decimal::extension(),
45        #[cfg(feature = "partial-eval")]
46        partial_evaluation::extension(),
47    ];
48
49    static ref ALL_AVAILABLE_EXTENSIONS : Extensions<'static> = Extensions::build_all_available();
50
51    static ref EXTENSIONS_NONE : Extensions<'static> = Extensions {
52        extensions: &[],
53        functions: HashMap::new(),
54        single_arg_constructors: HashMap::new(),
55    };
56}
57
58/// Holds data on all the Extensions which are active for a given evaluation.
59///
60/// This structure is intentionally not `Clone` because we can use it entirely
61/// by reference.
62#[derive(Debug)]
63pub struct Extensions<'a> {
64    /// the actual extensions
65    extensions: &'a [Extension],
66    /// All extension functions, collected from every extension used to
67    /// construct this object.  Built ahead of time so that we know during
68    /// extension function lookup that at most one extension function exists
69    /// for a name. This should also make the lookup more efficient.
70    functions: HashMap<&'a Name, &'a ExtensionFunction>,
71    /// All single argument extension function constructors, indexed by their
72    /// return type. Built ahead of time so that we know each constructor has
73    /// a unique return type.
74    single_arg_constructors: HashMap<&'a SchemaType, &'a ExtensionFunction>,
75}
76
77impl Extensions<'static> {
78    /// Get a new `Extensions` containing data on all the available extensions.
79    fn build_all_available() -> Extensions<'static> {
80        // PANIC SAFETY: Builtin extensions define functions/constructors only once. Also tested by many different test cases.
81        #[allow(clippy::expect_used)]
82        Self::specific_extensions(&ALL_AVAILABLE_EXTENSION_OBJECTS)
83            .expect("Default extensions should never error on initialization")
84    }
85
86    /// An [`Extensions`] object with static lifetime contain all available extensions.
87    pub fn all_available() -> &'static Extensions<'static> {
88        &ALL_AVAILABLE_EXTENSIONS
89    }
90
91    /// Get a new `Extensions` with no extensions enabled.
92    pub fn none() -> &'static Extensions<'static> {
93        &EXTENSIONS_NONE
94    }
95}
96
97impl<'a> Extensions<'a> {
98    /// Get a new `Extensions` with these specific extensions enabled.
99    pub fn specific_extensions(
100        extensions: &'a [Extension],
101    ) -> std::result::Result<Extensions<'a>, ExtensionInitializationError> {
102        // Build functions map, ensuring that no functions share the same name.
103        let functions = util::collect_no_duplicates(
104            extensions
105                .iter()
106                .flat_map(|e| e.funcs())
107                .map(|f| (f.name(), f)),
108        )
109        .map_err(|name| FuncMultiplyDefinedError { name: name.clone() })?;
110
111        // Build the constructor map, ensuring that no constructors share a return type
112        let single_arg_constructors = util::collect_no_duplicates(
113            extensions
114                .iter()
115                .flat_map(|e| e.funcs())
116                .filter(|f| f.is_constructor() && f.arg_types().len() == 1)
117                .filter_map(|f| f.return_type().map(|return_type| (return_type, f))),
118        )
119        .map_err(|return_type| MultipleConstructorsSameSignatureError {
120            return_type: Box::new(return_type.clone()),
121        })?;
122
123        Ok(Extensions {
124            extensions,
125            functions,
126            single_arg_constructors,
127        })
128    }
129
130    /// Get the names of all active extensions.
131    pub fn ext_names(&self) -> impl Iterator<Item = &Name> {
132        self.extensions.iter().map(|ext| ext.name())
133    }
134
135    /// Get all extension type names declared by active extensions.
136    ///
137    /// (More specifically, all extension type names such that any function in
138    /// an active extension could produce a value of that extension type.)
139    pub fn ext_types(&self) -> impl Iterator<Item = &Name> {
140        self.extensions.iter().flat_map(|ext| ext.ext_types())
141    }
142
143    /// Get the extension function with the given name, from these extensions.
144    ///
145    /// Returns an error if the function is not defined by any extension
146    pub fn func(
147        &self,
148        name: &Name,
149    ) -> std::result::Result<&ExtensionFunction, ExtensionFunctionLookupError> {
150        self.functions.get(name).copied().ok_or_else(|| {
151            FuncDoesNotExistError {
152                name: name.clone(),
153                source_loc: name.loc().cloned(),
154            }
155            .into()
156        })
157    }
158
159    /// Iterate over all extension functions defined by all of these extensions.
160    ///
161    /// No guarantee that this list won't have duplicates or repeated names.
162    pub(crate) fn all_funcs(&self) -> impl Iterator<Item = &'a ExtensionFunction> {
163        self.extensions.iter().flat_map(|ext| ext.funcs())
164    }
165
166    /// Lookup a single-argument constructor by its return type and argument type.
167    ///
168    /// `None` means no constructor has that signature.
169    pub(crate) fn lookup_single_arg_constructor(
170        &self,
171        return_type: &SchemaType,
172    ) -> Option<&ExtensionFunction> {
173        self.single_arg_constructors.get(return_type).copied()
174    }
175}
176
177/// Errors occurring while initializing extensions. There are internal errors, so
178/// this enum should not become part of the public API unless we publicly expose
179/// user-defined extension function.
180#[derive(Diagnostic, Debug, PartialEq, Eq, Clone, Error)]
181pub enum ExtensionInitializationError {
182    /// An extension function was defined by multiple extensions.
183    #[error(transparent)]
184    #[diagnostic(transparent)]
185    FuncMultiplyDefined(#[from] extension_initialization_errors::FuncMultiplyDefinedError),
186
187    /// Two extension constructors (in the same or different extensions) had
188    /// exactly the same type signature.  This is currently not allowed.
189    #[error(transparent)]
190    #[diagnostic(transparent)]
191    MultipleConstructorsSameSignature(
192        #[from] extension_initialization_errors::MultipleConstructorsSameSignatureError,
193    ),
194}
195
196/// Error subtypes for [`ExtensionInitializationError`]
197mod extension_initialization_errors {
198    use crate::{ast::Name, entities::SchemaType};
199    use miette::Diagnostic;
200    use thiserror::Error;
201
202    /// An extension function was defined by multiple extensions.
203    #[derive(Diagnostic, Debug, PartialEq, Eq, Clone, Error)]
204    #[error("extension function `{name}` is defined multiple times")]
205    pub struct FuncMultiplyDefinedError {
206        /// Name of the function that was multiply defined
207        pub(crate) name: Name,
208    }
209
210    /// Two extension constructors (in the same or different extensions) exist
211    /// for one extension type.  This is currently not allowed.
212    #[derive(Diagnostic, Debug, PartialEq, Eq, Clone, Error)]
213    #[error("multiple extension constructors for the same extension type {return_type}")]
214    pub struct MultipleConstructorsSameSignatureError {
215        /// return type of the shared constructor signature
216        pub(crate) return_type: Box<SchemaType>,
217    }
218}
219
220/// Errors thrown when looking up an extension function in [`Extensions`].
221//
222// CAUTION: this type is publicly exported in `cedar-policy`.
223// Don't make fields `pub`, don't make breaking changes, and use caution
224// when adding public methods.
225#[derive(Debug, PartialEq, Eq, Clone, Diagnostic, Error)]
226pub enum ExtensionFunctionLookupError {
227    /// Tried to call a function that doesn't exist
228    #[error(transparent)]
229    #[diagnostic(transparent)]
230    FuncDoesNotExist(#[from] extension_function_lookup_errors::FuncDoesNotExistError),
231}
232
233impl ExtensionFunctionLookupError {
234    pub(crate) fn source_loc(&self) -> Option<&Loc> {
235        match self {
236            Self::FuncDoesNotExist(e) => e.source_loc.as_ref(),
237        }
238    }
239
240    pub(crate) fn with_maybe_source_loc(self, source_loc: Option<Loc>) -> Self {
241        match self {
242            Self::FuncDoesNotExist(e) => {
243                Self::FuncDoesNotExist(extension_function_lookup_errors::FuncDoesNotExistError {
244                    source_loc,
245                    ..e
246                })
247            }
248        }
249    }
250}
251
252/// Error subtypes for [`ExtensionFunctionLookupError`]
253pub mod extension_function_lookup_errors {
254    use crate::ast::Name;
255    use crate::parser::Loc;
256    use miette::Diagnostic;
257    use thiserror::Error;
258
259    /// Tried to call a function that doesn't exist
260    //
261    // CAUTION: this type is publicly exported in `cedar-policy`.
262    // Don't make fields `pub`, don't make breaking changes, and use caution
263    // when adding public methods.
264    #[derive(Debug, PartialEq, Eq, Clone, Error)]
265    #[error("extension function `{name}` does not exist")]
266    pub struct FuncDoesNotExistError {
267        /// Name of the function that doesn't exist
268        pub(crate) name: Name,
269        /// Source location
270        pub(crate) source_loc: Option<Loc>,
271    }
272
273    impl Diagnostic for FuncDoesNotExistError {
274        impl_diagnostic_from_source_loc_opt_field!(source_loc);
275    }
276}
277
278/// Type alias for convenience
279pub type Result<T> = std::result::Result<T, ExtensionFunctionLookupError>;
280
/// Utilities shared with the `cedar-policy-validator` extensions module.
pub mod util {
    use std::collections::{hash_map::Entry, HashMap};

    /// Collect `(key, value)` pairs into a `HashMap`, rejecting duplicate keys.
    ///
    /// Pairs are consumed in iterator order.
    ///
    /// # Errors
    ///
    /// On the first key that is already present, iteration stops. The `Err`
    /// holds a clone of the key already stored in the map, i.e. the key from
    /// the earlier pair, not the new one.
    pub fn collect_no_duplicates<K, V>(
        i: impl Iterator<Item = (K, V)>,
    ) -> std::result::Result<HashMap<K, V>, K>
    where
        K: Clone + std::hash::Hash + Eq,
    {
        let mut pairs = i;
        // Reserve room for at least as many entries as the iterator promises.
        let (min_len, _) = pairs.size_hint();
        let mut collected = HashMap::with_capacity(min_len);
        pairs.try_for_each(|(key, value)| match collected.entry(key) {
            Entry::Vacant(slot) => {
                slot.insert(value);
                Ok(())
            }
            // `entry` consumed the incoming key, so report the stored one.
            Entry::Occupied(existing) => Err(existing.key().clone()),
        })?;
        Ok(collected)
    }
}
308
#[cfg(test)]
pub mod test {
    use super::*;
    use std::collections::HashSet;

    #[test]
    fn no_common_extension_function_names() {
        // Our expr display must search for callstyle given a name, so
        // no names can be used for both callstyles

        // Test that names are all unique for ease of use.
        // This overconstrains our current requirements, but shouldn't change
        // until we identify a strong need.
        let mut distinct_names = HashSet::new();
        let mut total_names = 0_usize;
        for func in Extensions::all_available()
            .extensions
            .iter()
            .flat_map(|ext| ext.funcs())
        {
            total_names += 1;
            distinct_names.insert(func.name());
        }
        assert_eq!(total_names, distinct_names.len());
    }
}