Skip to main content

drasi_core/evaluation/functions/
mod.rs

1// Copyright 2024 The Drasi Authors.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::{
16    collections::{HashMap, HashSet},
17    fmt::Debug,
18    sync::{Arc, RwLock},
19};
20
21pub use async_trait::async_trait;
22
23pub use drasi_query_ast::api::QueryConfiguration;
24pub use drasi_query_ast::ast;
25
26use super::{ExpressionEvaluationContext, FunctionError};
27use crate::evaluation::variable_value::VariableValue;
28use crate::interface::ResultIndex;
29
30pub mod aggregation;
31pub mod context_mutators;
32pub mod cypher_scalar;
33pub mod drasi;
34pub mod future;
35pub mod list;
36pub mod metadata;
37pub mod numeric;
38pub mod past;
39pub mod temporal_duration;
40pub mod temporal_instant;
41pub mod text;
42pub mod trigonometric;
43
44pub use self::{
45    aggregation::*, context_mutators::*, cypher_scalar::*, drasi::*, list::*, metadata::*,
46    numeric::*, temporal_duration::*, temporal_instant::*, text::text::*, trigonometric::*,
47};
48pub use aggregation::Accumulator;
49
50pub enum Function {
51    Scalar(Arc<dyn ScalarFunction>),
52    LazyScalar(Arc<dyn LazyScalarFunction>),
53    Aggregating(Arc<dyn AggregatingFunction>),
54    ContextMutator(Arc<dyn ContextMutatorFunction>),
55}
56
57#[async_trait]
58pub trait ScalarFunction: Send + Sync {
59    async fn call(
60        &self,
61        context: &ExpressionEvaluationContext,
62        expression: &ast::FunctionExpression,
63        args: Vec<VariableValue>,
64    ) -> Result<VariableValue, FunctionError>;
65}
66
67#[async_trait]
68pub trait LazyScalarFunction: Send + Sync {
69    async fn call(
70        &self,
71        context: &ExpressionEvaluationContext,
72        expression: &ast::FunctionExpression,
73        args: &Vec<ast::Expression>,
74    ) -> Result<VariableValue, FunctionError>;
75}
76
77#[async_trait]
78pub trait AggregatingFunction: Debug + Send + Sync {
79    fn initialize_accumulator(
80        &self,
81        context: &ExpressionEvaluationContext,
82        expression: &ast::FunctionExpression,
83        grouping_keys: &Vec<VariableValue>,
84        index: Arc<dyn ResultIndex>,
85    ) -> Accumulator; //todo: switch `dyn ResultIndex` to `dyn LazySortedSetStore` after trait upcasting is stable
86    async fn apply(
87        &self,
88        context: &ExpressionEvaluationContext,
89        args: Vec<VariableValue>,
90        accumulator: &mut Accumulator,
91    ) -> Result<VariableValue, FunctionError>;
92    async fn revert(
93        &self,
94        context: &ExpressionEvaluationContext,
95        args: Vec<VariableValue>,
96        accumulator: &mut Accumulator,
97    ) -> Result<VariableValue, FunctionError>;
98    async fn snapshot(
99        &self,
100        context: &ExpressionEvaluationContext,
101        args: Vec<VariableValue>,
102        accumulator: &Accumulator,
103    ) -> Result<VariableValue, FunctionError>;
104    fn accumulator_is_lazy(&self) -> bool;
105}
106
107#[async_trait]
108pub trait ContextMutatorFunction: Send + Sync {
109    async fn call<'a>(
110        &self,
111        context: &ExpressionEvaluationContext<'a>,
112        expression: &ast::FunctionExpression,
113    ) -> Result<ExpressionEvaluationContext<'a>, FunctionError>;
114}
115
116pub struct FunctionRegistry {
117    functions: Arc<RwLock<HashMap<String, Arc<Function>>>>,
118}
119
120impl Default for FunctionRegistry {
121    fn default() -> Self {
122        Self::new()
123    }
124}
125
126impl FunctionRegistry {
127    pub fn new() -> FunctionRegistry {
128        FunctionRegistry {
129            functions: Arc::new(RwLock::new(HashMap::new())),
130        }
131    }
132
133    #[allow(clippy::unwrap_used)]
134    pub fn register_function(&self, name: &str, function: Function) {
135        let mut lock = self.functions.write().unwrap();
136        lock.insert(name.to_string(), Arc::new(function));
137    }
138
139    #[allow(clippy::unwrap_used)]
140    pub fn get_function(&self, name: &str) -> Option<Arc<Function>> {
141        let lock = self.functions.read().unwrap();
142        lock.get(name).cloned()
143    }
144
145    #[allow(clippy::unwrap_used)]
146    pub fn get_aggregating_function_names(&self) -> HashSet<String> {
147        let lock = self.functions.read().unwrap();
148        lock.iter()
149            .filter_map(|(name, function)| match function.as_ref() {
150                Function::Aggregating(_) => Some(name.clone()),
151                _ => None,
152            })
153            .collect()
154    }
155}
156
157impl QueryConfiguration for FunctionRegistry {
158    fn get_aggregating_function_names(&self) -> HashSet<String> {
159        self.get_aggregating_function_names()
160    }
161}