// cedar_policy_core/extensions.rs
/*
 * Copyright 2022-2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
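//! This module contains all of the standard Cedar extensions.
//!
//! The `ipaddr` and `decimal` extensions are only compiled in when the Cargo
//! features of the same name are enabled; `partial_evaluation` is always
//! included.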
18
19#[cfg(feature = "ipaddr")]
20pub mod ipaddr;
21
22#[cfg(feature = "decimal")]
23pub mod decimal;
24pub mod partial_evaluation;
25
26use crate::ast::{Extension, ExtensionFunction, Name};
27use crate::entities::SchemaType;
28use thiserror::Error;
29
30lazy_static::lazy_static! {
31 static ref ALL_AVAILABLE_EXTENSIONS: Vec<Extension> = vec![
32 #[cfg(feature = "ipaddr")]
33 ipaddr::extension(),
34 #[cfg(feature = "decimal")]
35 decimal::extension(),
36 partial_evaluation::extension(),
37 ];
38}
39
40/// Holds data on all the Extensions which are active for a given evaluation.
41///
42/// Clone is cheap for this type.
43#[derive(Debug, Clone)]
44pub struct Extensions<'a> {
45 /// the actual extensions
46 extensions: &'a [Extension],
47}
48
49impl Extensions<'static> {
50 /// Get a new `Extensions` containing data on all the available extensions.
51 pub fn all_available() -> Extensions<'static> {
52 Extensions {
53 extensions: &ALL_AVAILABLE_EXTENSIONS,
54 }
55 }
56
57 /// Get a new `Extensions` with no extensions enabled.
58 pub fn none() -> Extensions<'static> {
59 Extensions { extensions: &[] }
60 }
61}
62
63impl<'a> Extensions<'a> {
64 /// Get a new `Extensions` with these specific extensions enabled.
65 pub fn specific_extensions(extensions: &'a [Extension]) -> Extensions<'a> {
66 Extensions { extensions }
67 }
68
69 /// Get the names of all active extensions.
70 pub fn ext_names(&self) -> impl Iterator<Item = &Name> {
71 self.extensions.iter().map(|ext| ext.name())
72 }
73
74 /// Get the extension function with the given name, from these extensions.
75 ///
76 /// Returns an error if the function is not defined by any extension, or if
77 /// it is defined multiple times.
78 pub fn func(&self, name: &Name) -> Result<&ExtensionFunction> {
79 // NOTE: in the future, we could build a single HashMap of function
80 // name to ExtensionFunction, combining all extension functions
81 // into one map, to make this lookup faster.
82 let extension_funcs: Vec<&ExtensionFunction> = self
83 .extensions
84 .iter()
85 .filter_map(|ext| ext.get_func(name))
86 .collect();
87 match extension_funcs.len() {
88 0 => Err(ExtensionsError::FuncDoesNotExist { name: name.clone() }),
89 1 => Ok(extension_funcs[0]),
90 num_defs => Err(ExtensionsError::FuncMultiplyDefined {
91 name: name.clone(),
92 num_defs,
93 }),
94 }
95 }
96
97 /// Iterate over all extension functions defined by all of these extensions.
98 ///
99 /// No guarantee that this list won't have duplicates or repeated names.
100 pub(crate) fn all_funcs(&self) -> impl Iterator<Item = &'a ExtensionFunction> {
101 self.extensions.iter().flat_map(|ext| ext.funcs())
102 }
103
104 /// Lookup a single-argument constructor by its return type and argument type.
105 /// This will ignore polymorphic functions (that accept multiple argument types).
106 ///
107 /// `Ok(None)` means no constructor has that signature.
108 /// `Err` is returned in the case that multiple constructors have that signature.
109 pub(crate) fn lookup_single_arg_constructor(
110 &self,
111 return_type: &SchemaType,
112 arg_type: &SchemaType,
113 ) -> Result<Option<&ExtensionFunction>> {
114 let matches = self
115 .all_funcs()
116 .filter(|f| {
117 f.is_constructor()
118 && f.return_type() == Some(return_type)
119 && f.arg_types().len() == 1
120 && f.arg_types()[0].as_ref() == Some(arg_type)
121 })
122 .collect::<Vec<_>>();
123 match matches.len() {
124 0 => Ok(None),
125 1 => Ok(Some(matches[0])),
126 _ => Err(ExtensionsError::MultipleConstructorsSameSignature {
127 return_type: Box::new(return_type.clone()),
128 arg_type: Box::new(arg_type.clone()),
129 }),
130 }
131 }
132}
133
134/// Errors thrown during operations on `Extensions`.
135#[derive(Debug, PartialEq, Clone, Error)]
136pub enum ExtensionsError {
137 /// Tried to call a function that doesn't exist
138 #[error("function does not exist: {name}")]
139 FuncDoesNotExist {
140 /// Name of the function that doesn't exist
141 name: Name,
142 },
143
144 #[error("The extension function called has no type: {name}")]
145 /// Attempted to typecheck an expression that had no type
146 HasNoType {
147 /// Name of the function that returns no type
148 name: Name,
149 },
150
151 /// Tried to call a function but it was defined multiple times (eg, by
152 /// multiple different extensions)
153 #[error("function is defined {num_defs} times: {name}")]
154 FuncMultiplyDefined {
155 /// Name of the function that is multiply defined
156 name: Name,
157 /// How many times that function is defined
158 num_defs: usize,
159 },
160
161 /// Two extension constructors (in the same or different extensions) had
162 /// exactly the same type signature. This is currently not allowed.
163 #[error(
164 "multiple extension constructors with the same type signature {arg_type} -> {return_type}"
165 )]
166 MultipleConstructorsSameSignature {
167 /// return type of the shared constructor signature
168 return_type: Box<SchemaType>,
169 /// argument type of the shared constructor signature
170 arg_type: Box<SchemaType>,
171 },
172}
173
174/// Type alias for convenience
175pub type Result<T> = std::result::Result<T, ExtensionsError>;
176
#[cfg(test)]
pub mod test {
    use super::*;
    use std::collections::HashSet;

    /// Every extension function name must be unique across all available
    /// extensions.
    #[test]
    fn no_common_extension_function_names() {
        // Expression display looks up a function's call style by name alone,
        // so one name must never be used for both call styles.
        //
        // We go further and require every name to be distinct. That is
        // stricter than today's needs, but we keep it until there is a strong
        // reason to relax it.
        let every_name: Vec<_> = Extensions::all_available()
            .all_funcs()
            .map(|func| func.name().clone())
            .collect();
        let distinct: HashSet<_> = every_name.iter().collect();
        assert_eq!(every_name.len(), distinct.len());
    }
}