cedar_policy_core/
extensions.rs1#[cfg(feature = "ipaddr")]
20pub mod ipaddr;
21
22#[cfg(feature = "decimal")]
23pub mod decimal;
24pub mod partial_evaluation;
25
26use std::collections::HashMap;
27
28use crate::ast::{Extension, ExtensionFunction, Name};
29use crate::entities::SchemaType;
30use crate::parser::Loc;
31use miette::Diagnostic;
32use thiserror::Error;
33
34use self::extension_function_lookup_errors::FuncDoesNotExistError;
35use self::extension_initialization_errors::{
36 FuncMultiplyDefinedError, MultipleConstructorsSameSignatureError,
37};
38
39lazy_static::lazy_static! {
40 static ref ALL_AVAILABLE_EXTENSION_OBJECTS: Vec<Extension> = vec![
41 #[cfg(feature = "ipaddr")]
42 ipaddr::extension(),
43 #[cfg(feature = "decimal")]
44 decimal::extension(),
45 #[cfg(feature = "partial-eval")]
46 partial_evaluation::extension(),
47 ];
48
49 static ref ALL_AVAILABLE_EXTENSIONS : Extensions<'static> = Extensions::build_all_available();
50
51 static ref EXTENSIONS_NONE : Extensions<'static> = Extensions {
52 extensions: &[],
53 functions: HashMap::new(),
54 single_arg_constructors: HashMap::new(),
55 };
56}
57
#[derive(Debug)]
/// A set of extensions whose functions are indexed for lookup.
///
/// The set borrows its extension objects for `'a`. The process-wide sets returned by
/// `Extensions::all_available()` and `Extensions::none()` use `'static`.
pub struct Extensions<'a> {
    /// The extension objects in this set, in the order they were supplied.
    extensions: &'a [Extension],
    /// Every function defined by any extension in `extensions`, keyed by function name.
    /// `specific_extensions` rejects duplicate names, so each name maps to one function.
    functions: HashMap<&'a Name, &'a ExtensionFunction>,
    /// Constructor functions that take exactly one argument, keyed by their return type.
    /// Constructors whose `return_type()` is `None` are not included.
    /// `specific_extensions` rejects two such constructors with the same return type.
    single_arg_constructors: HashMap<&'a SchemaType, &'a ExtensionFunction>,
}
76
77impl Extensions<'static> {
78 fn build_all_available() -> Extensions<'static> {
80 #[allow(clippy::expect_used)]
82 Self::specific_extensions(&ALL_AVAILABLE_EXTENSION_OBJECTS)
83 .expect("Default extensions should never error on initialization")
84 }
85
86 pub fn all_available() -> &'static Extensions<'static> {
88 &ALL_AVAILABLE_EXTENSIONS
89 }
90
91 pub fn none() -> &'static Extensions<'static> {
93 &EXTENSIONS_NONE
94 }
95}
96
97impl<'a> Extensions<'a> {
98 pub fn specific_extensions(
100 extensions: &'a [Extension],
101 ) -> std::result::Result<Extensions<'a>, ExtensionInitializationError> {
102 let functions = util::collect_no_duplicates(
104 extensions
105 .iter()
106 .flat_map(|e| e.funcs())
107 .map(|f| (f.name(), f)),
108 )
109 .map_err(|name| FuncMultiplyDefinedError { name: name.clone() })?;
110
111 let single_arg_constructors = util::collect_no_duplicates(
113 extensions
114 .iter()
115 .flat_map(|e| e.funcs())
116 .filter(|f| f.is_constructor() && f.arg_types().len() == 1)
117 .filter_map(|f| f.return_type().map(|return_type| (return_type, f))),
118 )
119 .map_err(|return_type| MultipleConstructorsSameSignatureError {
120 return_type: Box::new(return_type.clone()),
121 })?;
122
123 Ok(Extensions {
124 extensions,
125 functions,
126 single_arg_constructors,
127 })
128 }
129
130 pub fn ext_names(&self) -> impl Iterator<Item = &Name> {
132 self.extensions.iter().map(|ext| ext.name())
133 }
134
135 pub fn ext_types(&self) -> impl Iterator<Item = &Name> {
140 self.extensions.iter().flat_map(|ext| ext.ext_types())
141 }
142
143 pub fn func(
147 &self,
148 name: &Name,
149 ) -> std::result::Result<&ExtensionFunction, ExtensionFunctionLookupError> {
150 self.functions.get(name).copied().ok_or_else(|| {
151 FuncDoesNotExistError {
152 name: name.clone(),
153 source_loc: name.loc().cloned(),
154 }
155 .into()
156 })
157 }
158
159 pub(crate) fn all_funcs(&self) -> impl Iterator<Item = &'a ExtensionFunction> {
163 self.extensions.iter().flat_map(|ext| ext.funcs())
164 }
165
166 pub(crate) fn lookup_single_arg_constructor(
170 &self,
171 return_type: &SchemaType,
172 ) -> Option<&ExtensionFunction> {
173 self.single_arg_constructors.get(return_type).copied()
174 }
175}
176
#[derive(Diagnostic, Debug, PartialEq, Eq, Clone, Error)]
/// Errors returned by `Extensions::specific_extensions` when the supplied
/// extensions cannot be combined into one set.
pub enum ExtensionInitializationError {
    /// Two functions across the supplied extensions have the same name.
    #[error(transparent)]
    #[diagnostic(transparent)]
    FuncMultiplyDefined(#[from] extension_initialization_errors::FuncMultiplyDefinedError),

    /// Two constructors taking exactly one argument, across the supplied extensions,
    /// have the same return type.
    #[error(transparent)]
    #[diagnostic(transparent)]
    MultipleConstructorsSameSignature(
        #[from] extension_initialization_errors::MultipleConstructorsSameSignatureError,
    ),
}
195
/// Payload types for the variants of `ExtensionInitializationError`.
mod extension_initialization_errors {
    use crate::{ast::Name, entities::SchemaType};
    use miette::Diagnostic;
    use thiserror::Error;

    /// An extension function name was defined more than once.
    #[derive(Diagnostic, Debug, PartialEq, Eq, Clone, Error)]
    #[error("extension function `{name}` is defined multiple times")]
    pub struct FuncMultiplyDefinedError {
        /// The function name that was defined more than once.
        pub(crate) name: Name,
    }

    /// More than one single-argument extension constructor has the same return type.
    #[derive(Diagnostic, Debug, PartialEq, Eq, Clone, Error)]
    #[error("multiple extension constructors for the same extension type {return_type}")]
    pub struct MultipleConstructorsSameSignatureError {
        /// The return type shared by the conflicting constructors.
        // NOTE(review): boxed, presumably to keep the error enum small — confirm.
        pub(crate) return_type: Box<SchemaType>,
    }
}
219
#[derive(Debug, PartialEq, Eq, Clone, Diagnostic, Error)]
/// Errors returned when looking up an extension function by name.
pub enum ExtensionFunctionLookupError {
    /// No extension function with the requested name exists in the set.
    #[error(transparent)]
    #[diagnostic(transparent)]
    FuncDoesNotExist(#[from] extension_function_lookup_errors::FuncDoesNotExistError),
}
232
233impl ExtensionFunctionLookupError {
234 pub(crate) fn source_loc(&self) -> Option<&Loc> {
235 match self {
236 Self::FuncDoesNotExist(e) => e.source_loc.as_ref(),
237 }
238 }
239
240 pub(crate) fn with_maybe_source_loc(self, source_loc: Option<Loc>) -> Self {
241 match self {
242 Self::FuncDoesNotExist(e) => {
243 Self::FuncDoesNotExist(extension_function_lookup_errors::FuncDoesNotExistError {
244 source_loc,
245 ..e
246 })
247 }
248 }
249 }
250}
251
/// Payload types for the variants of `ExtensionFunctionLookupError`.
pub mod extension_function_lookup_errors {
    use crate::ast::Name;
    use crate::parser::Loc;
    use miette::Diagnostic;
    use thiserror::Error;

    /// No extension function with the given name exists.
    ///
    /// `Diagnostic` is implemented by hand below rather than derived, presumably so
    /// the optional `source_loc` can be reported as the diagnostic's location — confirm.
    #[derive(Debug, PartialEq, Eq, Clone, Error)]
    #[error("extension function `{name}` does not exist")]
    pub struct FuncDoesNotExistError {
        /// The function name that was looked up.
        pub(crate) name: Name,
        /// Source location associated with this error, if any.
        pub(crate) source_loc: Option<Loc>,
    }

    // `impl_diagnostic_from_source_loc_opt_field!` is a crate macro defined elsewhere;
    // it is given the `source_loc` field name.
    impl Diagnostic for FuncDoesNotExistError {
        impl_diagnostic_from_source_loc_opt_field!(source_loc);
    }
}
277
/// Result type whose error is an `ExtensionFunctionLookupError`.
pub type Result<T> = std::result::Result<T, ExtensionFunctionLookupError>;
280
pub mod util {
    use std::collections::{hash_map::Entry, HashMap};

    /// Collect `(key, value)` pairs into a `HashMap`, refusing duplicate keys.
    ///
    /// Pairs are consumed in order. Consumption stops at the first key that was
    /// already inserted.
    ///
    /// # Errors
    ///
    /// Returns `Err` with a clone of the key already stored in the map that equals
    /// the first repeated key.
    pub fn collect_no_duplicates<K, V>(
        i: impl Iterator<Item = (K, V)>,
    ) -> std::result::Result<HashMap<K, V>, K>
    where
        K: Clone + std::hash::Hash + Eq,
    {
        let mut pairs = i;
        let mut seen = HashMap::with_capacity(pairs.size_hint().0);
        pairs.try_for_each(|(key, value)| match seen.entry(key) {
            Entry::Occupied(dup) => Err(dup.key().clone()),
            Entry::Vacant(slot) => {
                slot.insert(value);
                Ok(())
            }
        })?;
        Ok(seen)
    }
}
308
#[cfg(test)]
pub mod test {
    use super::*;
    use std::collections::HashSet;

    /// Every function across all available extensions must have a distinct name.
    #[test]
    fn no_common_extension_function_names() {
        let names: Vec<_> = Extensions::all_available()
            .all_funcs()
            .map(|func| func.name())
            .collect();
        let distinct: HashSet<_> = names.iter().copied().collect();
        assert_eq!(names.len(), distinct.len());
    }
}
330}