drasi_core/evaluation/functions/mod.rs

1// Copyright 2024 The Drasi Authors.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::{
16    collections::HashMap,
17    fmt::Debug,
18    sync::{Arc, RwLock},
19};
20
21use async_trait::async_trait;
22
23use drasi_query_ast::ast;
24use drasi_query_cypher::CypherConfiguration;
25
26use crate::evaluation::variable_value::VariableValue;
27use crate::interface::ResultIndex;
28
29use self::{
30    aggregation::{Accumulator, RegisterAggregationFunctions},
31    context_mutators::RegisterContextMutatorFunctions,
32    cypher_scalar::RegisterCypherScalarFunctions,
33    drasi::RegisterDrasiFunctions,
34    list::RegisterListFunctions,
35    metadata::RegisterMetadataFunctions,
36    numeric::RegisterNumericFunctions,
37    temporal_duration::RegisterTemporalDurationFunctions,
38    temporal_instant::RegisterTemporalInstantFunctions,
39    text::RegisterTextFunctions,
40};
41
42use super::{ExpressionEvaluationContext, FunctionError};
43
44pub mod aggregation;
45pub mod context_mutators;
46pub mod cypher_scalar;
47pub mod drasi;
48pub mod future;
49pub mod list;
50pub mod metadata;
51pub mod numeric;
52pub mod past;
53pub mod temporal_duration;
54pub mod temporal_instant;
55pub mod text;
56
57pub enum Function {
58    Scalar(Arc<dyn ScalarFunction>),
59    LazyScalar(Arc<dyn LazyScalarFunction>),
60    Aggregating(Arc<dyn AggregatingFunction>),
61    ContextMutator(Arc<dyn ContextMutatorFunction>),
62}
63
64#[async_trait]
65pub trait ScalarFunction: Send + Sync {
66    async fn call(
67        &self,
68        context: &ExpressionEvaluationContext,
69        expression: &ast::FunctionExpression,
70        args: Vec<VariableValue>,
71    ) -> Result<VariableValue, FunctionError>;
72}
73
74#[async_trait]
75pub trait LazyScalarFunction: Send + Sync {
76    async fn call(
77        &self,
78        context: &ExpressionEvaluationContext,
79        expression: &ast::FunctionExpression,
80        args: &Vec<ast::Expression>,
81    ) -> Result<VariableValue, FunctionError>;
82}
83
84#[async_trait]
85pub trait AggregatingFunction: Debug + Send + Sync {
86    fn initialize_accumulator(
87        &self,
88        context: &ExpressionEvaluationContext,
89        expression: &ast::FunctionExpression,
90        grouping_keys: &Vec<VariableValue>,
91        index: Arc<dyn ResultIndex>,
92    ) -> Accumulator; //todo: switch `dyn ResultIndex` to `dyn LazySortedSetStore` after trait upcasting is stable
93    async fn apply(
94        &self,
95        context: &ExpressionEvaluationContext,
96        args: Vec<VariableValue>,
97        accumulator: &mut Accumulator,
98    ) -> Result<VariableValue, FunctionError>;
99    async fn revert(
100        &self,
101        context: &ExpressionEvaluationContext,
102        args: Vec<VariableValue>,
103        accumulator: &mut Accumulator,
104    ) -> Result<VariableValue, FunctionError>;
105    async fn snapshot(
106        &self,
107        context: &ExpressionEvaluationContext,
108        args: Vec<VariableValue>,
109        accumulator: &Accumulator,
110    ) -> Result<VariableValue, FunctionError>;
111    fn accumulator_is_lazy(&self) -> bool;
112}
113
114#[async_trait]
115pub trait ContextMutatorFunction: Send + Sync {
116    async fn call<'a>(
117        &self,
118        context: &ExpressionEvaluationContext<'a>,
119        expression: &ast::FunctionExpression,
120    ) -> Result<ExpressionEvaluationContext<'a>, FunctionError>;
121}
122
123pub struct FunctionRegistry {
124    functions: Arc<RwLock<HashMap<String, Arc<Function>>>>,
125}
126
127impl Default for FunctionRegistry {
128    fn default() -> Self {
129        Self::new()
130    }
131}
132
133impl FunctionRegistry {
134    pub fn new() -> FunctionRegistry {
135        let result = FunctionRegistry {
136            functions: Arc::new(RwLock::new(HashMap::new())),
137        };
138
139        result.register_text_functions();
140        result.register_metadata_functions();
141        result.register_numeric_functions();
142        result.register_aggregation_functions();
143        result.register_temporal_instant_functions();
144        result.register_temporal_duration_functions();
145        result.register_list_functions();
146        result.register_drasi_functions();
147        result.register_scalar_functions();
148        result.register_context_mutators();
149
150        result
151    }
152
153    #[allow(clippy::unwrap_used)]
154    pub fn register_function(&self, name: &str, function: Function) {
155        let mut lock = self.functions.write().unwrap();
156        lock.insert(name.to_string(), Arc::new(function));
157    }
158
159    #[allow(clippy::unwrap_used)]
160    pub fn get_function(&self, name: &str) -> Option<Arc<Function>> {
161        let lock = self.functions.read().unwrap();
162        lock.get(name).cloned()
163    }
164}
165
166impl CypherConfiguration for FunctionRegistry {
167    #[allow(clippy::unwrap_used)]
168    fn get_aggregating_function_names(&self) -> std::collections::HashSet<String> {
169        let lock = self.functions.read().unwrap();
170        lock.iter()
171            .filter_map(|(name, function)| match function.as_ref() {
172                Function::Aggregating(_) => Some(name.clone()),
173                _ => None,
174            })
175            .collect()
176    }
177}