drasi_core/evaluation/functions/aggregation/
count.rs

1// Copyright 2024 The Drasi Authors.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::{fmt::Debug, sync::Arc};
16
17use crate::{
18    evaluation::{FunctionError, FunctionEvaluationError},
19    interface::ResultIndex,
20};
21
22use async_trait::async_trait;
23
24use drasi_query_ast::ast;
25
26use crate::evaluation::{
27    variable_value::integer::Integer, variable_value::VariableValue, ExpressionEvaluationContext,
28};
29
30use super::{super::AggregatingFunction, Accumulator, ValueAccumulator};
31
/// Aggregating function that counts the non-null values of its single argument.
///
/// The struct itself holds no state: the running total for each group lives in
/// the accumulator (`ValueAccumulator::Count`) created by `initialize_accumulator`.
pub struct Count {}
33
34#[async_trait]
35impl AggregatingFunction for Count {
36    fn initialize_accumulator(
37        &self,
38        _context: &ExpressionEvaluationContext,
39        _expression: &ast::FunctionExpression,
40        _grouping_keys: &Vec<VariableValue>,
41        _index: Arc<dyn ResultIndex>,
42    ) -> Accumulator {
43        Accumulator::Value(ValueAccumulator::Count { value: 0 })
44    }
45
46    fn accumulator_is_lazy(&self) -> bool {
47        false
48    }
49
50    async fn apply(
51        &self,
52        _context: &ExpressionEvaluationContext,
53        args: Vec<VariableValue>,
54        accumulator: &mut Accumulator,
55    ) -> Result<VariableValue, FunctionError> {
56        if args.len() != 1 {
57            return Err(FunctionError {
58                function_name: "Count".to_string(),
59                error: FunctionEvaluationError::InvalidArgumentCount,
60            });
61        }
62
63        let value = match accumulator {
64            Accumulator::Value(super::ValueAccumulator::Count { value }) => value,
65            _ => {
66                return Err(FunctionError {
67                    function_name: "Count".to_string(),
68                    error: FunctionEvaluationError::CorruptData,
69                })
70            }
71        };
72
73        match &args[0] {
74            VariableValue::Null => Ok(VariableValue::Integer(Integer::from(*value))),
75            _ => {
76                *value += 1;
77                Ok(VariableValue::Integer(Integer::from(*value)))
78            }
79        }
80    }
81
82    async fn revert(
83        &self,
84        _context: &ExpressionEvaluationContext,
85        args: Vec<VariableValue>,
86        accumulator: &mut Accumulator,
87    ) -> Result<VariableValue, FunctionError> {
88        if args.len() != 1 {
89            return Err(FunctionError {
90                function_name: "Count".to_string(),
91                error: FunctionEvaluationError::InvalidArgumentCount,
92            });
93        }
94        let value =
95            if let Accumulator::Value(super::ValueAccumulator::Count { value }) = accumulator {
96                value
97            } else {
98                return Err(FunctionError {
99                    function_name: "Count".to_string(),
100                    error: FunctionEvaluationError::CorruptData,
101                });
102            };
103
104        match &args[0] {
105            VariableValue::Null => Ok(VariableValue::Integer(Integer::from(*value))),
106            _ => {
107                *value -= 1;
108                Ok(VariableValue::Integer(Integer::from(*value)))
109            }
110        }
111    }
112
113    async fn snapshot(
114        &self,
115        _context: &ExpressionEvaluationContext,
116        _args: Vec<VariableValue>,
117        accumulator: &Accumulator,
118    ) -> Result<VariableValue, FunctionError> {
119        let value =
120            if let Accumulator::Value(super::ValueAccumulator::Count { value }) = accumulator {
121                value
122            } else {
123                return Err(FunctionError {
124                    function_name: "Count".to_string(),
125                    error: FunctionEvaluationError::CorruptData,
126                });
127            };
128        Ok(VariableValue::Integer(Integer::from(*value)))
129    }
130}
131
132impl Debug for Count {
133    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
134        write!(f, "Count")
135    }
136}