drasi_core/evaluation/functions/
mod.rs1use std::{
16 collections::HashMap,
17 fmt::Debug,
18 sync::{Arc, RwLock},
19};
20
21use async_trait::async_trait;
22
23use drasi_query_ast::ast;
24use drasi_query_cypher::CypherConfiguration;
25
26use crate::evaluation::variable_value::VariableValue;
27use crate::interface::ResultIndex;
28
29use self::{
30 aggregation::{Accumulator, RegisterAggregationFunctions},
31 context_mutators::RegisterContextMutatorFunctions,
32 cypher_scalar::RegisterCypherScalarFunctions,
33 drasi::RegisterDrasiFunctions,
34 list::RegisterListFunctions,
35 metadata::RegisterMetadataFunctions,
36 numeric::RegisterNumericFunctions,
37 temporal_duration::RegisterTemporalDurationFunctions,
38 temporal_instant::RegisterTemporalInstantFunctions,
39 text::RegisterTextFunctions,
40};
41
42use super::{ExpressionEvaluationContext, FunctionError};
43
44pub mod aggregation;
45pub mod context_mutators;
46pub mod cypher_scalar;
47pub mod drasi;
48pub mod future;
49pub mod list;
50pub mod metadata;
51pub mod numeric;
52pub mod past;
53pub mod temporal_duration;
54pub mod temporal_instant;
55pub mod text;
56
57pub enum Function {
58 Scalar(Arc<dyn ScalarFunction>),
59 LazyScalar(Arc<dyn LazyScalarFunction>),
60 Aggregating(Arc<dyn AggregatingFunction>),
61 ContextMutator(Arc<dyn ContextMutatorFunction>),
62}
63
64#[async_trait]
65pub trait ScalarFunction: Send + Sync {
66 async fn call(
67 &self,
68 context: &ExpressionEvaluationContext,
69 expression: &ast::FunctionExpression,
70 args: Vec<VariableValue>,
71 ) -> Result<VariableValue, FunctionError>;
72}
73
74#[async_trait]
75pub trait LazyScalarFunction: Send + Sync {
76 async fn call(
77 &self,
78 context: &ExpressionEvaluationContext,
79 expression: &ast::FunctionExpression,
80 args: &Vec<ast::Expression>,
81 ) -> Result<VariableValue, FunctionError>;
82}
83
84#[async_trait]
85pub trait AggregatingFunction: Debug + Send + Sync {
86 fn initialize_accumulator(
87 &self,
88 context: &ExpressionEvaluationContext,
89 expression: &ast::FunctionExpression,
90 grouping_keys: &Vec<VariableValue>,
91 index: Arc<dyn ResultIndex>,
92 ) -> Accumulator; async fn apply(
94 &self,
95 context: &ExpressionEvaluationContext,
96 args: Vec<VariableValue>,
97 accumulator: &mut Accumulator,
98 ) -> Result<VariableValue, FunctionError>;
99 async fn revert(
100 &self,
101 context: &ExpressionEvaluationContext,
102 args: Vec<VariableValue>,
103 accumulator: &mut Accumulator,
104 ) -> Result<VariableValue, FunctionError>;
105 async fn snapshot(
106 &self,
107 context: &ExpressionEvaluationContext,
108 args: Vec<VariableValue>,
109 accumulator: &Accumulator,
110 ) -> Result<VariableValue, FunctionError>;
111 fn accumulator_is_lazy(&self) -> bool;
112}
113
114#[async_trait]
115pub trait ContextMutatorFunction: Send + Sync {
116 async fn call<'a>(
117 &self,
118 context: &ExpressionEvaluationContext<'a>,
119 expression: &ast::FunctionExpression,
120 ) -> Result<ExpressionEvaluationContext<'a>, FunctionError>;
121}
122
123pub struct FunctionRegistry {
124 functions: Arc<RwLock<HashMap<String, Arc<Function>>>>,
125}
126
127impl Default for FunctionRegistry {
128 fn default() -> Self {
129 Self::new()
130 }
131}
132
133impl FunctionRegistry {
134 pub fn new() -> FunctionRegistry {
135 let result = FunctionRegistry {
136 functions: Arc::new(RwLock::new(HashMap::new())),
137 };
138
139 result.register_text_functions();
140 result.register_metadata_functions();
141 result.register_numeric_functions();
142 result.register_aggregation_functions();
143 result.register_temporal_instant_functions();
144 result.register_temporal_duration_functions();
145 result.register_list_functions();
146 result.register_drasi_functions();
147 result.register_scalar_functions();
148 result.register_context_mutators();
149
150 result
151 }
152
153 #[allow(clippy::unwrap_used)]
154 pub fn register_function(&self, name: &str, function: Function) {
155 let mut lock = self.functions.write().unwrap();
156 lock.insert(name.to_string(), Arc::new(function));
157 }
158
159 #[allow(clippy::unwrap_used)]
160 pub fn get_function(&self, name: &str) -> Option<Arc<Function>> {
161 let lock = self.functions.read().unwrap();
162 lock.get(name).cloned()
163 }
164}
165
166impl CypherConfiguration for FunctionRegistry {
167 #[allow(clippy::unwrap_used)]
168 fn get_aggregating_function_names(&self) -> std::collections::HashSet<String> {
169 let lock = self.functions.read().unwrap();
170 lock.iter()
171 .filter_map(|(name, function)| match function.as_ref() {
172 Function::Aggregating(_) => Some(name.clone()),
173 _ => None,
174 })
175 .collect()
176 }
177}