cedar_policy_core/
extensions.rs1#[cfg(feature = "ipaddr")]
20pub mod ipaddr;
21
22#[cfg(feature = "decimal")]
23pub mod decimal;
24
25#[cfg(feature = "datetime")]
26pub mod datetime;
27pub mod partial_evaluation;
28
29use std::collections::HashMap;
30use std::sync::LazyLock;
31
32use crate::ast::{Extension, ExtensionFunction, Name};
33use crate::entities::SchemaType;
34use crate::extensions::extension_initialization_errors::MultipleConstructorsSameSignatureError;
35use crate::parser::Loc;
36use miette::Diagnostic;
37use thiserror::Error;
38
39use self::extension_function_lookup_errors::FuncDoesNotExistError;
40use self::extension_initialization_errors::FuncMultiplyDefinedError;
41
42static ALL_AVAILABLE_EXTENSION_OBJECTS: LazyLock<Vec<Extension>> = LazyLock::new(|| {
43 vec![
44 #[cfg(feature = "ipaddr")]
45 ipaddr::extension(),
46 #[cfg(feature = "decimal")]
47 decimal::extension(),
48 #[cfg(feature = "datetime")]
49 datetime::extension(),
50 #[cfg(feature = "partial-eval")]
51 partial_evaluation::extension(),
52 ]
53});
54
55static ALL_AVAILABLE_EXTENSIONS: LazyLock<Extensions<'static>> =
56 LazyLock::new(Extensions::build_all_available);
57
58static EXTENSIONS_NONE: LazyLock<Extensions<'static>> = LazyLock::new(|| Extensions {
59 extensions: &[],
60 functions: HashMap::new(),
61 single_arg_constructors: HashMap::new(),
62});
63
/// A set of extensions together with lookup tables derived from them.
///
/// Instances are built by `Extensions::specific_extensions`, which rejects
/// duplicate function names and duplicate single-argument-constructor return types.
#[derive(Debug)]
pub struct Extensions<'a> {
    /// The extensions this registry was built from.
    extensions: &'a [Extension],
    /// Every function provided by `extensions`, keyed by the function's name.
    functions: HashMap<&'a Name, &'a ExtensionFunction>,
    /// Functions for which `is_single_arg_constructor()` holds and that have a
    /// return type, keyed by that return type.
    single_arg_constructors: HashMap<&'a SchemaType, &'a ExtensionFunction>,
}
82
83impl Extensions<'static> {
84 fn build_all_available() -> Extensions<'static> {
86 #[allow(clippy::expect_used)]
88 Self::specific_extensions(&ALL_AVAILABLE_EXTENSION_OBJECTS)
89 .expect("Default extensions should never error on initialization")
90 }
91
92 pub fn all_available() -> &'static Extensions<'static> {
94 &ALL_AVAILABLE_EXTENSIONS
95 }
96
97 pub fn none() -> &'static Extensions<'static> {
99 &EXTENSIONS_NONE
100 }
101}
102
103impl<'a> Extensions<'a> {
104 pub fn types_with_operator_overloading(&self) -> impl Iterator<Item = &Name> + '_ {
106 self.extensions
107 .iter()
108 .flat_map(|ext| ext.types_with_operator_overloading())
109 }
110 pub fn specific_extensions(
112 extensions: &'a [Extension],
113 ) -> std::result::Result<Extensions<'a>, ExtensionInitializationError> {
114 let functions = util::collect_no_duplicates(
116 extensions
117 .iter()
118 .flat_map(|e| e.funcs())
119 .map(|f| (f.name(), f)),
120 )
121 .map_err(|name| FuncMultiplyDefinedError { name: name.clone() })?;
122
123 let single_arg_constructors = util::collect_no_duplicates(
125 extensions
126 .iter()
127 .flat_map(|e| e.funcs())
128 .filter(|f| f.is_single_arg_constructor())
129 .filter_map(|f| f.return_type().map(|return_type| (return_type, f))),
130 )
131 .map_err(|return_type| MultipleConstructorsSameSignatureError {
132 return_type: Box::new(return_type.clone()),
133 })?;
134
135 Ok(Extensions {
136 extensions,
137 functions,
138 single_arg_constructors,
139 })
140 }
141
142 pub fn ext_names(&self) -> impl Iterator<Item = &Name> {
144 self.extensions.iter().map(|ext| ext.name())
145 }
146
147 pub fn ext_types(&self) -> impl Iterator<Item = &Name> {
152 self.extensions.iter().flat_map(|ext| ext.ext_types())
153 }
154
155 pub fn func(
159 &self,
160 name: &Name,
161 ) -> std::result::Result<&ExtensionFunction, ExtensionFunctionLookupError> {
162 self.functions.get(name).copied().ok_or_else(|| {
163 FuncDoesNotExistError {
164 name: name.clone(),
165 source_loc: name.loc().cloned(),
166 }
167 .into()
168 })
169 }
170
171 pub(crate) fn all_funcs(&self) -> impl Iterator<Item = &'a ExtensionFunction> {
175 self.extensions.iter().flat_map(|ext| ext.funcs())
176 }
177
178 pub(crate) fn lookup_single_arg_constructor(
180 &self,
181 return_type: &SchemaType,
182 ) -> Option<&ExtensionFunction> {
183 self.single_arg_constructors.get(return_type).copied()
184 }
185}
186
187#[derive(Diagnostic, Debug, PartialEq, Eq, Clone, Error)]
191pub enum ExtensionInitializationError {
192 #[error(transparent)]
194 #[diagnostic(transparent)]
195 FuncMultiplyDefined(#[from] extension_initialization_errors::FuncMultiplyDefinedError),
196
197 #[error(transparent)]
200 #[diagnostic(transparent)]
201 MultipleConstructorsSameSignature(
202 #[from] extension_initialization_errors::MultipleConstructorsSameSignatureError,
203 ),
204}
205
/// Error types produced while building an `Extensions` registry.
mod extension_initialization_errors {
    use crate::{ast::Name, entities::SchemaType};
    use miette::Diagnostic;
    use thiserror::Error;

    /// Two extension functions were registered under the same name.
    #[derive(Diagnostic, Debug, PartialEq, Eq, Clone, Error)]
    #[error("extension function `{name}` is defined multiple times")]
    pub struct FuncMultiplyDefinedError {
        /// The duplicated function name.
        pub(crate) name: Name,
    }

    /// Two single-argument extension constructors share the same return type.
    #[derive(Diagnostic, Debug, PartialEq, Eq, Clone, Error)]
    #[error("multiple extension constructors for the same extension type {return_type}")]
    pub struct MultipleConstructorsSameSignatureError {
        /// The return type shared by the conflicting constructors.
        pub(crate) return_type: Box<SchemaType>,
    }
}
229
230#[derive(Debug, PartialEq, Eq, Clone, Diagnostic, Error)]
236pub enum ExtensionFunctionLookupError {
237 #[error(transparent)]
239 #[diagnostic(transparent)]
240 FuncDoesNotExist(#[from] extension_function_lookup_errors::FuncDoesNotExistError),
241}
242
243impl ExtensionFunctionLookupError {
244 pub(crate) fn source_loc(&self) -> Option<&Loc> {
245 match self {
246 Self::FuncDoesNotExist(e) => e.source_loc.as_ref(),
247 }
248 }
249
250 pub(crate) fn with_maybe_source_loc(self, source_loc: Option<Loc>) -> Self {
251 match self {
252 Self::FuncDoesNotExist(e) => {
253 Self::FuncDoesNotExist(extension_function_lookup_errors::FuncDoesNotExistError {
254 source_loc,
255 ..e
256 })
257 }
258 }
259 }
260}
261
/// Error types produced when looking up an extension function.
pub mod extension_function_lookup_errors {
    use crate::ast::Name;
    use crate::parser::Loc;
    use miette::Diagnostic;
    use thiserror::Error;

    /// No extension function with the given name is registered.
    #[derive(Debug, PartialEq, Eq, Clone, Error)]
    #[error("extension function `{name}` does not exist")]
    pub struct FuncDoesNotExistError {
        /// Name of the function that was looked up.
        pub(crate) name: Name,
        /// Source location to report for this error, if known.
        pub(crate) source_loc: Option<Loc>,
    }

    // `Diagnostic` is implemented by hand rather than derived. The macro
    // (defined elsewhere) presumably builds the diagnostic's source info from
    // the optional `source_loc` field. NOTE(review): confirm against the macro.
    impl Diagnostic for FuncDoesNotExistError {
        impl_diagnostic_from_source_loc_opt_field!(source_loc);
    }
}
287
/// Shorthand result type whose error is [`ExtensionFunctionLookupError`].
pub type Result<T> = std::result::Result<T, ExtensionFunctionLookupError>;
290
/// Small collection helpers used when building the extension lookup tables.
pub mod util {
    use std::collections::{hash_map::Entry, HashMap};

    /// Collects `(key, value)` pairs into a `HashMap`, rejecting repeated keys.
    ///
    /// # Errors
    ///
    /// On the first key that has already been seen, stops consuming `i` and
    /// returns a clone of the stored key equal to it.
    pub fn collect_no_duplicates<K, V>(
        i: impl Iterator<Item = (K, V)>,
    ) -> std::result::Result<HashMap<K, V>, K>
    where
        K: Clone + std::hash::Hash + Eq,
    {
        // Pre-size from the iterator's lower bound to limit rehashing.
        let (expected_len, _) = i.size_hint();
        let mut table = HashMap::with_capacity(expected_len);
        for (key, value) in i {
            match table.entry(key) {
                Entry::Vacant(slot) => {
                    slot.insert(value);
                }
                Entry::Occupied(clash) => return Err(clash.key().clone()),
            }
        }
        Ok(table)
    }
}
318
#[cfg(test)]
mod test {
    use super::*;
    use std::collections::HashSet;

    /// No two available extensions may export functions with the same name.
    #[test]
    fn no_common_extension_function_names() {
        let mut distinct = HashSet::new();
        let mut total = 0usize;
        for extension in Extensions::all_available().extensions {
            for func in extension.funcs() {
                total += 1;
                distinct.insert(func.name().clone());
            }
        }
        assert_eq!(total, distinct.len());
    }
}