// cedar_policy_core: extensions.rs

/*
 * Copyright Cedar Contributors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

//! This module contains all of the standard Cedar extensions.
18
19#[cfg(feature = "ipaddr")]
20pub mod ipaddr;
21
22#[cfg(feature = "decimal")]
23pub mod decimal;
24
25#[cfg(feature = "datetime")]
26pub mod datetime;
27pub mod partial_evaluation;
28
29use std::collections::{HashMap, HashSet};
30use std::sync::LazyLock;
31
32use crate::ast::{CallStyle, Extension, ExtensionFunction, Name, UnreservedId};
33use crate::entities::SchemaType;
34use crate::extensions::extension_initialization_errors::MultipleConstructorsSameSignatureError;
35use crate::fuzzy_match::fuzzy_search_limited;
36use crate::parser::Loc;
37use miette::Diagnostic;
38use smol_str::{SmolStr, ToSmolStr};
39use thiserror::Error;
40
41use self::extension_function_lookup_errors::FuncDoesNotExistError;
42use self::extension_initialization_errors::FuncMultiplyDefinedError;
43
44static ALL_AVAILABLE_EXTENSION_OBJECTS: LazyLock<Vec<Extension>> = LazyLock::new(|| {
45    vec![
46        #[cfg(feature = "ipaddr")]
47        ipaddr::extension(),
48        #[cfg(feature = "decimal")]
49        decimal::extension(),
50        #[cfg(feature = "datetime")]
51        datetime::extension(),
52        #[cfg(feature = "partial-eval")]
53        partial_evaluation::extension(),
54    ]
55});
56
57static ALL_AVAILABLE_EXTENSIONS: LazyLock<Extensions<'static>> =
58    LazyLock::new(Extensions::build_all_available);
59
60static EXTENSIONS_NONE: LazyLock<Extensions<'static>> = LazyLock::new(|| Extensions {
61    extensions: &[],
62    functions: HashMap::new(),
63    single_arg_constructors: HashMap::new(),
64});
65
66static EXTENSION_STYLES: LazyLock<ExtStyles<'static>> = LazyLock::new(ExtStyles::load);
67
68/// Holds data on all the Extensions which are active for a given evaluation.
69///
70/// This structure is intentionally not `Clone` because we can use it entirely
71/// by reference.
72#[derive(Debug)]
73pub struct Extensions<'a> {
74    /// the actual extensions
75    extensions: &'a [Extension],
76    /// All extension functions, collected from every extension used to
77    /// construct this object.  Built ahead of time so that we know during
78    /// extension function lookup that at most one extension function exists
79    /// for a name. This should also make the lookup more efficient.
80    functions: HashMap<&'a Name, &'a ExtensionFunction>,
81    /// All single argument extension function constructors, indexed by their
82    /// return type. Built ahead of time so that we know each constructor has
83    /// a unique return type.
84    single_arg_constructors: HashMap<&'a SchemaType, &'a ExtensionFunction>,
85}
86
87impl Extensions<'static> {
88    /// Get a new `Extensions` containing data on all the available extensions.
89    fn build_all_available() -> Extensions<'static> {
90        #[expect(
91            clippy::expect_used,
92            reason = "Builtin extensions define functions/constructors only once. Also tested by many different test cases."
93        )]
94        Self::specific_extensions(&ALL_AVAILABLE_EXTENSION_OBJECTS)
95            .expect("Default extensions should never error on initialization")
96    }
97
98    /// An [`Extensions`] object with static lifetime contain all available extensions.
99    pub fn all_available() -> &'static Extensions<'static> {
100        &ALL_AVAILABLE_EXTENSIONS
101    }
102
103    /// Get a new `Extensions` with no extensions enabled.
104    pub fn none() -> &'static Extensions<'static> {
105        &EXTENSIONS_NONE
106    }
107}
108
109impl<'a> Extensions<'a> {
110    /// Obtain the non-empty vector of types supporting operator overloading
111    pub fn types_with_operator_overloading(&self) -> impl Iterator<Item = &Name> + '_ {
112        self.extensions
113            .iter()
114            .flat_map(|ext| ext.types_with_operator_overloading())
115    }
116    /// Get a new `Extensions` with these specific extensions enabled.
117    pub fn specific_extensions(
118        extensions: &'a [Extension],
119    ) -> std::result::Result<Extensions<'a>, ExtensionInitializationError> {
120        // Build functions map, ensuring that no functions share the same name.
121        let functions = util::collect_no_duplicates(
122            extensions
123                .iter()
124                .flat_map(|e| e.funcs())
125                .map(|f| (f.name(), f)),
126        )
127        .map_err(|name| FuncMultiplyDefinedError { name: name.clone() })?;
128
129        // Build the constructor map, ensuring that no constructors share a return type
130        let single_arg_constructors = util::collect_no_duplicates(
131            extensions
132                .iter()
133                .flat_map(|e| e.funcs())
134                .filter(|f| f.is_single_arg_constructor())
135                .filter_map(|f| f.return_type().map(|return_type| (return_type, f))),
136        )
137        .map_err(|return_type| MultipleConstructorsSameSignatureError {
138            return_type: Box::new(return_type.clone()),
139        })?;
140
141        Ok(Extensions {
142            extensions,
143            functions,
144            single_arg_constructors,
145        })
146    }
147
148    /// Get the names of all active extensions.
149    pub fn ext_names(&self) -> impl Iterator<Item = &Name> {
150        self.extensions.iter().map(|ext| ext.name())
151    }
152
153    /// Get all extension type names declared by active extensions.
154    ///
155    /// (More specifically, all extension type names such that any function in
156    /// an active extension could produce a value of that extension type.)
157    pub fn ext_types(&self) -> impl Iterator<Item = &Name> {
158        self.extensions.iter().flat_map(|ext| ext.ext_types())
159    }
160
161    /// Get the extension function with the given name, from these extensions.
162    ///
163    /// Returns an error if the function is not defined by any extension
164    pub fn func(
165        &self,
166        name: &Name,
167    ) -> std::result::Result<&ExtensionFunction, ExtensionFunctionLookupError> {
168        self.functions.get(name).copied().ok_or_else(|| {
169            FuncDoesNotExistError {
170                name: name.clone(),
171                source_loc: name.loc().cloned(),
172            }
173            .into()
174        })
175    }
176
177    /// Iterate over all extension functions defined by all of these extensions.
178    ///
179    /// No guarantee that this list won't have duplicates or repeated names.
180    pub(crate) fn all_funcs(&self) -> impl Iterator<Item = &'a ExtensionFunction> {
181        self.extensions.iter().flat_map(|ext| ext.funcs())
182    }
183
184    /// Lookup a single-argument constructor by its return type
185    pub(crate) fn lookup_single_arg_constructor(
186        &self,
187        return_type: &SchemaType,
188    ) -> Option<&ExtensionFunction> {
189        self.single_arg_constructors.get(return_type).copied()
190    }
191}
192
193/// Errors occurring while initializing extensions. There are internal errors, so
194/// this enum should not become part of the public API unless we publicly expose
195/// user-defined extension function.
196#[derive(Diagnostic, Debug, PartialEq, Eq, Clone, Error)]
197pub enum ExtensionInitializationError {
198    /// An extension function was defined by multiple extensions.
199    #[error(transparent)]
200    #[diagnostic(transparent)]
201    FuncMultiplyDefined(#[from] extension_initialization_errors::FuncMultiplyDefinedError),
202
203    /// Two extension constructors (in the same or different extensions) had
204    /// exactly the same type signature.  This is currently not allowed.
205    #[error(transparent)]
206    #[diagnostic(transparent)]
207    MultipleConstructorsSameSignature(
208        #[from] extension_initialization_errors::MultipleConstructorsSameSignatureError,
209    ),
210}
211
212/// Error subtypes for [`ExtensionInitializationError`]
213mod extension_initialization_errors {
214    use crate::{ast::Name, entities::SchemaType};
215    use miette::Diagnostic;
216    use thiserror::Error;
217
218    /// An extension function was defined by multiple extensions.
219    #[derive(Diagnostic, Debug, PartialEq, Eq, Clone, Error)]
220    #[error("extension function `{name}` is defined multiple times")]
221    pub struct FuncMultiplyDefinedError {
222        /// Name of the function that was multiply defined
223        pub(crate) name: Name,
224    }
225
226    /// Two extension constructors (in the same or different extensions) exist
227    /// for one extension type.  This is currently not allowed.
228    #[derive(Diagnostic, Debug, PartialEq, Eq, Clone, Error)]
229    #[error("multiple extension constructors for the same extension type {return_type}")]
230    pub struct MultipleConstructorsSameSignatureError {
231        /// return type of the shared constructor signature
232        pub(crate) return_type: Box<SchemaType>,
233    }
234}
235
236/// Errors thrown when looking up an extension function in [`Extensions`].
237//
238// CAUTION: this type is publicly exported in `cedar-policy`.
239// Don't make fields `pub`, don't make breaking changes, and use caution
240// when adding public methods.
241#[derive(Debug, PartialEq, Eq, Clone, Diagnostic, Error)]
242pub enum ExtensionFunctionLookupError {
243    /// Tried to call a function that doesn't exist
244    #[error(transparent)]
245    #[diagnostic(transparent)]
246    FuncDoesNotExist(#[from] extension_function_lookup_errors::FuncDoesNotExistError),
247}
248
249impl ExtensionFunctionLookupError {
250    pub(crate) fn source_loc(&self) -> Option<&Loc> {
251        match self {
252            Self::FuncDoesNotExist(e) => e.source_loc.as_ref(),
253        }
254    }
255
256    pub(crate) fn with_maybe_source_loc(self, source_loc: Option<Loc>) -> Self {
257        match self {
258            Self::FuncDoesNotExist(e) => {
259                Self::FuncDoesNotExist(extension_function_lookup_errors::FuncDoesNotExistError {
260                    source_loc,
261                    ..e
262                })
263            }
264        }
265    }
266}
267
268/// Error subtypes for [`ExtensionFunctionLookupError`]
269pub mod extension_function_lookup_errors {
270    use crate::ast::Name;
271    use crate::parser::Loc;
272    use miette::Diagnostic;
273    use thiserror::Error;
274
275    /// Tried to call a function that doesn't exist
276    //
277    // CAUTION: this type is publicly exported in `cedar-policy`.
278    // Don't make fields `pub`, don't make breaking changes, and use caution
279    // when adding public methods.
280    #[derive(Debug, PartialEq, Eq, Clone, Error)]
281    #[error("extension function `{name}` does not exist")]
282    pub struct FuncDoesNotExistError {
283        /// Name of the function that doesn't exist
284        pub(crate) name: Name,
285        /// Source location
286        pub(crate) source_loc: Option<Loc>,
287    }
288
289    impl Diagnostic for FuncDoesNotExistError {
290        impl_diagnostic_from_source_loc_opt_field!(source_loc);
291    }
292}
293
294/// Type alias for convenience
295pub type Result<T> = std::result::Result<T, ExtensionFunctionLookupError>;
296
297/// Extension functions have different callstyles. This stores information about the expected
298/// callstyle for each function. Provided static methods can be used to check the expected syntax
299/// of a given function call.
300#[derive(Debug)]
301pub(crate) struct ExtStyles<'a> {
302    /// All extension function names (just functions, not methods), as `Name`s
303    functions: HashSet<&'a Name>,
304    /// All extension function methods. `UnreservedId` is appropriate because methods cannot be namespaced.
305    methods: HashSet<UnreservedId>,
306    /// All extension function and method names (both qualified and unqualified), in their string (`Display`) form
307    functions_and_methods_as_str: HashSet<SmolStr>,
308}
309
310impl ExtStyles<'static> {
311    fn load() -> ExtStyles<'static> {
312        let mut functions = HashSet::new();
313        let mut methods = HashSet::new();
314        let mut functions_and_methods_as_str = HashSet::new();
315        for func in crate::extensions::Extensions::all_available().all_funcs() {
316            functions_and_methods_as_str.insert(func.name().to_smolstr());
317            match func.style() {
318                CallStyle::FunctionStyle => {
319                    functions.insert(func.name());
320                }
321                CallStyle::MethodStyle => {
322                    debug_assert!(func.name().is_unqualified());
323                    methods.insert(func.name().basename());
324                }
325            };
326        }
327        ExtStyles {
328            functions,
329            methods,
330            functions_and_methods_as_str,
331        }
332    }
333
334    /// If this [`UnreservedId`] is a known method name
335    pub(crate) fn is_method(id: &UnreservedId) -> bool {
336        EXTENSION_STYLES.methods.contains(id)
337    }
338
339    /// If this [`Name`] is a known function name
340    pub(crate) fn is_function(id: &Name) -> bool {
341        EXTENSION_STYLES.functions.contains(id)
342    }
343
344    /// If this [`Name`] is a known extension function/method name or not
345    pub(crate) fn is_known_extension_func_name(name: &Name) -> bool {
346        Self::is_function(name) || (name.0.path.is_empty() && Self::is_method(&name.basename()))
347    }
348
349    /// If this [`SmolStr`] is a known extension function/method name or not. Works
350    /// with both qualified and unqualified `s`. (As of this writing, there are no
351    /// qualified extension function/method names, so qualified `s` always results
352    /// in `false`.)
353    pub(crate) fn is_known_extension_func_str(s: &SmolStr) -> bool {
354        EXTENSION_STYLES.functions_and_methods_as_str.contains(s)
355    }
356
357    fn suggest<I, T>(key: &str, choices: I) -> Option<String>
358    where
359        I: IntoIterator<Item = T>,
360        T: ToString,
361    {
362        const SUGGESTION_EXTENSION_MAX_DISTANCE: usize = 3;
363        let choice_strings: Vec<String> = choices.into_iter().map(|c| c.to_string()).collect();
364        let suggestion = fuzzy_search_limited(
365            key,
366            choice_strings.as_slice(),
367            Some(SUGGESTION_EXTENSION_MAX_DISTANCE),
368        );
369        suggestion.map(|m| format!("did you mean `{m}`?"))
370    }
371
372    /// When a method call was expected, suggest a method name matching the provided name
373    pub(crate) fn suggest_method(name: &UnreservedId) -> Option<String> {
374        Self::suggest(name.as_ref(), &EXTENSION_STYLES.methods)
375    }
376
377    /// When a function call was expected, suggest a function name matching the provided name
378    pub(crate) fn suggest_function(name: &Name) -> Option<String> {
379        Self::suggest(&name.to_string(), &EXTENSION_STYLES.functions)
380    }
381}
382
/// Utilities shared with the `cedar-policy-validator` extensions module.
pub mod util {
    use std::collections::{hash_map::Entry, HashMap};

    /// Collect `(key, value)` pairs into a `HashMap`, rejecting duplicate keys.
    ///
    /// # Errors
    ///
    /// As soon as a key is seen for the second time, returns `Err` holding a
    /// clone of the key already stored in the map. No further pairs are consumed.
    pub fn collect_no_duplicates<K, V>(
        mut i: impl Iterator<Item = (K, V)>,
    ) -> std::result::Result<HashMap<K, V>, K>
    where
        K: Clone + std::hash::Hash + Eq,
    {
        let (lower_bound, _) = i.size_hint();
        i.try_fold(
            HashMap::<K, V>::with_capacity(lower_bound),
            |mut seen, (key, value)| match seen.entry(key) {
                Entry::Vacant(slot) => {
                    slot.insert(value);
                    Ok(seen)
                }
                Entry::Occupied(first) => Err(first.key().clone()),
            },
        )
    }
}
410
#[cfg(test)]
mod test {
    use super::*;
    use std::collections::HashSet;

    #[test]
    fn no_common_extension_function_names() {
        // Expr display looks up a function's call style by its name, so no
        // name may be used for both call styles.
        //
        // We check the stronger property that every extension function name is
        // unique across all available extensions. This overconstrains our
        // current requirements, but shouldn't change until we identify a
        // strong need.
        let mut seen: HashSet<&Name> = HashSet::new();
        for ext in Extensions::all_available().extensions {
            for func in ext.funcs() {
                assert!(
                    seen.insert(func.name()),
                    "extension function `{}` is defined more than once",
                    func.name()
                );
            }
        }
    }
}