Skip to main content

cedar_policy_core/
extensions.rs

1/*
2 * Copyright Cedar Contributors
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      https://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17//! This module contains all of the standard Cedar extensions.
18
19#[cfg(feature = "ipaddr")]
20pub mod ipaddr;
21
22#[cfg(feature = "decimal")]
23pub mod decimal;
24
25#[cfg(feature = "datetime")]
26pub mod datetime;
27pub mod partial_evaluation;
28
29use std::collections::HashMap;
30use std::sync::LazyLock;
31
32use crate::ast::{Extension, ExtensionFunction, Name};
33use crate::entities::SchemaType;
34use crate::extensions::extension_initialization_errors::MultipleConstructorsSameSignatureError;
35use crate::parser::Loc;
36use miette::Diagnostic;
37use thiserror::Error;
38
39use self::extension_function_lookup_errors::FuncDoesNotExistError;
40use self::extension_initialization_errors::FuncMultiplyDefinedError;
41
42static ALL_AVAILABLE_EXTENSION_OBJECTS: LazyLock<Vec<Extension>> = LazyLock::new(|| {
43    vec![
44        #[cfg(feature = "ipaddr")]
45        ipaddr::extension(),
46        #[cfg(feature = "decimal")]
47        decimal::extension(),
48        #[cfg(feature = "datetime")]
49        datetime::extension(),
50        #[cfg(feature = "partial-eval")]
51        partial_evaluation::extension(),
52    ]
53});
54
55static ALL_AVAILABLE_EXTENSIONS: LazyLock<Extensions<'static>> =
56    LazyLock::new(Extensions::build_all_available);
57
58static EXTENSIONS_NONE: LazyLock<Extensions<'static>> = LazyLock::new(|| Extensions {
59    extensions: &[],
60    functions: HashMap::new(),
61    single_arg_constructors: HashMap::new(),
62});
63
64/// Holds data on all the Extensions which are active for a given evaluation.
65///
66/// This structure is intentionally not `Clone` because we can use it entirely
67/// by reference.
68#[derive(Debug)]
69pub struct Extensions<'a> {
70    /// the actual extensions
71    extensions: &'a [Extension],
72    /// All extension functions, collected from every extension used to
73    /// construct this object.  Built ahead of time so that we know during
74    /// extension function lookup that at most one extension function exists
75    /// for a name. This should also make the lookup more efficient.
76    functions: HashMap<&'a Name, &'a ExtensionFunction>,
77    /// All single argument extension function constructors, indexed by their
78    /// return type. Built ahead of time so that we know each constructor has
79    /// a unique return type.
80    single_arg_constructors: HashMap<&'a SchemaType, &'a ExtensionFunction>,
81}
82
83impl Extensions<'static> {
84    /// Get a new `Extensions` containing data on all the available extensions.
85    fn build_all_available() -> Extensions<'static> {
86        #[expect(
87            clippy::expect_used,
88            reason = "Builtin extensions define functions/constructors only once. Also tested by many different test cases."
89        )]
90        Self::specific_extensions(&ALL_AVAILABLE_EXTENSION_OBJECTS)
91            .expect("Default extensions should never error on initialization")
92    }
93
94    /// An [`Extensions`] object with static lifetime contain all available extensions.
95    pub fn all_available() -> &'static Extensions<'static> {
96        &ALL_AVAILABLE_EXTENSIONS
97    }
98
99    /// Get a new `Extensions` with no extensions enabled.
100    pub fn none() -> &'static Extensions<'static> {
101        &EXTENSIONS_NONE
102    }
103}
104
105impl<'a> Extensions<'a> {
106    /// Obtain the non-empty vector of types supporting operator overloading
107    pub fn types_with_operator_overloading(&self) -> impl Iterator<Item = &Name> + '_ {
108        self.extensions
109            .iter()
110            .flat_map(|ext| ext.types_with_operator_overloading())
111    }
112    /// Get a new `Extensions` with these specific extensions enabled.
113    pub fn specific_extensions(
114        extensions: &'a [Extension],
115    ) -> std::result::Result<Extensions<'a>, ExtensionInitializationError> {
116        // Build functions map, ensuring that no functions share the same name.
117        let functions = util::collect_no_duplicates(
118            extensions
119                .iter()
120                .flat_map(|e| e.funcs())
121                .map(|f| (f.name(), f)),
122        )
123        .map_err(|name| FuncMultiplyDefinedError { name: name.clone() })?;
124
125        // Build the constructor map, ensuring that no constructors share a return type
126        let single_arg_constructors = util::collect_no_duplicates(
127            extensions
128                .iter()
129                .flat_map(|e| e.funcs())
130                .filter(|f| f.is_single_arg_constructor())
131                .filter_map(|f| f.return_type().map(|return_type| (return_type, f))),
132        )
133        .map_err(|return_type| MultipleConstructorsSameSignatureError {
134            return_type: Box::new(return_type.clone()),
135        })?;
136
137        Ok(Extensions {
138            extensions,
139            functions,
140            single_arg_constructors,
141        })
142    }
143
144    /// Get the names of all active extensions.
145    pub fn ext_names(&self) -> impl Iterator<Item = &Name> {
146        self.extensions.iter().map(|ext| ext.name())
147    }
148
149    /// Get all extension type names declared by active extensions.
150    ///
151    /// (More specifically, all extension type names such that any function in
152    /// an active extension could produce a value of that extension type.)
153    pub fn ext_types(&self) -> impl Iterator<Item = &Name> {
154        self.extensions.iter().flat_map(|ext| ext.ext_types())
155    }
156
157    /// Get the extension function with the given name, from these extensions.
158    ///
159    /// Returns an error if the function is not defined by any extension
160    pub fn func(
161        &self,
162        name: &Name,
163    ) -> std::result::Result<&ExtensionFunction, ExtensionFunctionLookupError> {
164        self.functions.get(name).copied().ok_or_else(|| {
165            FuncDoesNotExistError {
166                name: name.clone(),
167                source_loc: name.loc().cloned(),
168            }
169            .into()
170        })
171    }
172
173    /// Iterate over all extension functions defined by all of these extensions.
174    ///
175    /// No guarantee that this list won't have duplicates or repeated names.
176    pub(crate) fn all_funcs(&self) -> impl Iterator<Item = &'a ExtensionFunction> {
177        self.extensions.iter().flat_map(|ext| ext.funcs())
178    }
179
180    /// Lookup a single-argument constructor by its return type
181    pub(crate) fn lookup_single_arg_constructor(
182        &self,
183        return_type: &SchemaType,
184    ) -> Option<&ExtensionFunction> {
185        self.single_arg_constructors.get(return_type).copied()
186    }
187}
188
189/// Errors occurring while initializing extensions. There are internal errors, so
190/// this enum should not become part of the public API unless we publicly expose
191/// user-defined extension function.
192#[derive(Diagnostic, Debug, PartialEq, Eq, Clone, Error)]
193pub enum ExtensionInitializationError {
194    /// An extension function was defined by multiple extensions.
195    #[error(transparent)]
196    #[diagnostic(transparent)]
197    FuncMultiplyDefined(#[from] extension_initialization_errors::FuncMultiplyDefinedError),
198
199    /// Two extension constructors (in the same or different extensions) had
200    /// exactly the same type signature.  This is currently not allowed.
201    #[error(transparent)]
202    #[diagnostic(transparent)]
203    MultipleConstructorsSameSignature(
204        #[from] extension_initialization_errors::MultipleConstructorsSameSignatureError,
205    ),
206}
207
208/// Error subtypes for [`ExtensionInitializationError`]
209mod extension_initialization_errors {
210    use crate::{ast::Name, entities::SchemaType};
211    use miette::Diagnostic;
212    use thiserror::Error;
213
214    /// An extension function was defined by multiple extensions.
215    #[derive(Diagnostic, Debug, PartialEq, Eq, Clone, Error)]
216    #[error("extension function `{name}` is defined multiple times")]
217    pub struct FuncMultiplyDefinedError {
218        /// Name of the function that was multiply defined
219        pub(crate) name: Name,
220    }
221
222    /// Two extension constructors (in the same or different extensions) exist
223    /// for one extension type.  This is currently not allowed.
224    #[derive(Diagnostic, Debug, PartialEq, Eq, Clone, Error)]
225    #[error("multiple extension constructors for the same extension type {return_type}")]
226    pub struct MultipleConstructorsSameSignatureError {
227        /// return type of the shared constructor signature
228        pub(crate) return_type: Box<SchemaType>,
229    }
230}
231
232/// Errors thrown when looking up an extension function in [`Extensions`].
233//
234// CAUTION: this type is publicly exported in `cedar-policy`.
235// Don't make fields `pub`, don't make breaking changes, and use caution
236// when adding public methods.
237#[derive(Debug, PartialEq, Eq, Clone, Diagnostic, Error)]
238pub enum ExtensionFunctionLookupError {
239    /// Tried to call a function that doesn't exist
240    #[error(transparent)]
241    #[diagnostic(transparent)]
242    FuncDoesNotExist(#[from] extension_function_lookup_errors::FuncDoesNotExistError),
243}
244
245impl ExtensionFunctionLookupError {
246    pub(crate) fn source_loc(&self) -> Option<&Loc> {
247        match self {
248            Self::FuncDoesNotExist(e) => e.source_loc.as_ref(),
249        }
250    }
251
252    pub(crate) fn with_maybe_source_loc(self, source_loc: Option<Loc>) -> Self {
253        match self {
254            Self::FuncDoesNotExist(e) => {
255                Self::FuncDoesNotExist(extension_function_lookup_errors::FuncDoesNotExistError {
256                    source_loc,
257                    ..e
258                })
259            }
260        }
261    }
262}
263
264/// Error subtypes for [`ExtensionFunctionLookupError`]
265pub mod extension_function_lookup_errors {
266    use crate::ast::Name;
267    use crate::parser::Loc;
268    use miette::Diagnostic;
269    use thiserror::Error;
270
271    /// Tried to call a function that doesn't exist
272    //
273    // CAUTION: this type is publicly exported in `cedar-policy`.
274    // Don't make fields `pub`, don't make breaking changes, and use caution
275    // when adding public methods.
276    #[derive(Debug, PartialEq, Eq, Clone, Error)]
277    #[error("extension function `{name}` does not exist")]
278    pub struct FuncDoesNotExistError {
279        /// Name of the function that doesn't exist
280        pub(crate) name: Name,
281        /// Source location
282        pub(crate) source_loc: Option<Loc>,
283    }
284
285    impl Diagnostic for FuncDoesNotExistError {
286        impl_diagnostic_from_source_loc_opt_field!(source_loc);
287    }
288}
289
290/// Type alias for convenience
291pub type Result<T> = std::result::Result<T, ExtensionFunctionLookupError>;
292
/// Utilities shared with the `cedar-policy-validator` extensions module.
pub mod util {
    use std::collections::{hash_map::Entry, HashMap};

    /// Collects key/value pairs into a `HashMap`, refusing duplicate keys.
    ///
    /// Pairs are consumed in order. On the first key that is already present,
    /// iteration stops and a clone of that key is returned as the `Err` value.
    /// Otherwise every pair is inserted and the finished map is returned.
    pub fn collect_no_duplicates<K, V>(
        mut i: impl Iterator<Item = (K, V)>,
    ) -> std::result::Result<HashMap<K, V>, K>
    where
        K: Clone + std::hash::Hash + Eq,
    {
        // Reserve for at least the number of pairs the iterator promises.
        let initial_capacity = i.size_hint().0;
        i.try_fold(
            HashMap::with_capacity(initial_capacity),
            |mut acc: HashMap<K, V>, (key, value)| {
                match acc.entry(key) {
                    Entry::Occupied(existing) => return Err(existing.key().clone()),
                    Entry::Vacant(slot) => {
                        slot.insert(value);
                    }
                }
                Ok(acc)
            },
        )
    }
}
320
#[cfg(test)]
mod test {
    use super::*;
    use std::collections::HashSet;

    #[test]
    fn no_common_extension_function_names() {
        // Our expr display must search for callstyle given a name, so
        // no names can be used for both callstyles.

        // Test that names are all unique for ease of use.
        // This overconstrains our current requirements, but shouldn't change
        // until we identify a strong need.
        let mut seen = HashSet::new();
        for func in Extensions::all_available().all_funcs() {
            let name = func.name();
            // Name the offending function on failure instead of only
            // reporting a length mismatch.
            assert!(
                seen.insert(name),
                "extension function name `{name}` is defined more than once"
            );
        }
    }
}
338            .collect();
339        let dedup_names: HashSet<_> = all_names.iter().collect();
340        assert_eq!(all_names.len(), dedup_names.len());
341    }
342}