drasi_core/evaluation/functions/aggregation/sum.rs

1// Copyright 2024 The Drasi Authors.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::{fmt::Debug, sync::Arc};
16
17use crate::{
18    evaluation::{FunctionError, FunctionEvaluationError},
19    interface::ResultIndex,
20};
21
22use async_trait::async_trait;
23
24use drasi_query_ast::ast;
25
26use crate::evaluation::{
27    variable_value::duration::Duration, variable_value::float::Float,
28    variable_value::VariableValue, ExpressionEvaluationContext,
29};
30
31use super::{super::AggregatingFunction, Accumulator, ValueAccumulator};
32use chrono::Duration as ChronoDuration;
33
/// The `sum` aggregating function.
///
/// The type carries no state of its own: the running total for each group
/// lives in the `Accumulator` produced by `initialize_accumulator`, so one
/// `Sum` value can serve any number of groups.
#[derive(Clone)]
pub struct Sum {}
36
37#[async_trait]
38impl AggregatingFunction for Sum {
39    fn initialize_accumulator(
40        &self,
41        _context: &ExpressionEvaluationContext,
42        _expression: &ast::FunctionExpression,
43        _grouping_keys: &Vec<VariableValue>,
44        _index: Arc<dyn ResultIndex>,
45    ) -> Accumulator {
46        Accumulator::Value(ValueAccumulator::Sum { value: 0.0 })
47    }
48
49    fn accumulator_is_lazy(&self) -> bool {
50        false
51    }
52
53    async fn apply(
54        &self,
55        _context: &ExpressionEvaluationContext,
56        args: Vec<VariableValue>,
57        accumulator: &mut Accumulator,
58    ) -> Result<VariableValue, FunctionError> {
59        if args.len() != 1 {
60            return Err(FunctionError {
61                function_name: "Sum".to_string(),
62                error: FunctionEvaluationError::InvalidArgumentCount,
63            });
64        }
65
66        let accumulator = match accumulator {
67            Accumulator::Value(super::ValueAccumulator::Sum { value }) => value,
68            _ => {
69                return Err(FunctionError {
70                    function_name: "Sum".to_string(),
71                    error: FunctionEvaluationError::CorruptData,
72                });
73            }
74        };
75
76        match &args[0] {
77            VariableValue::Float(n) => {
78                *accumulator += match n.as_f64() {
79                    Some(n) => n,
80                    None => {
81                        return Err(FunctionError {
82                            function_name: "Sum".to_string(),
83                            error: FunctionEvaluationError::OverflowError,
84                        })
85                    }
86                };
87                Ok(VariableValue::Float(match Float::from_f64(*accumulator) {
88                    Some(n) => n,
89                    None => {
90                        return Err(FunctionError {
91                            function_name: "Sum".to_string(),
92                            error: FunctionEvaluationError::OverflowError,
93                        })
94                    }
95                }))
96            }
97            VariableValue::Integer(n) => {
98                *accumulator += match n.as_i64() {
99                    Some(n) => n as f64,
100                    None => {
101                        return Err(FunctionError {
102                            function_name: "Sum".to_string(),
103                            error: FunctionEvaluationError::OverflowError,
104                        })
105                    }
106                };
107                Ok(VariableValue::Float(match Float::from_f64(*accumulator) {
108                    Some(n) => n,
109                    None => {
110                        return Err(FunctionError {
111                            function_name: "Sum".to_string(),
112                            error: FunctionEvaluationError::OverflowError,
113                        })
114                    }
115                }))
116            }
117            VariableValue::Duration(d) => {
118                *accumulator += d.duration().num_milliseconds() as f64;
119                Ok(VariableValue::Duration(Duration::new(
120                    ChronoDuration::milliseconds(*accumulator as i64),
121                    0,
122                    0,
123                )))
124            }
125            VariableValue::Null => Ok(VariableValue::Float(match Float::from_f64(*accumulator) {
126                Some(n) => n,
127                None => {
128                    return Err(FunctionError {
129                        function_name: "Sum".to_string(),
130                        error: FunctionEvaluationError::OverflowError,
131                    })
132                }
133            })),
134            _ => Err(FunctionError {
135                function_name: "Sum".to_string(),
136                error: FunctionEvaluationError::InvalidArgument(0),
137            }),
138        }
139    }
140
141    async fn revert(
142        &self,
143        _context: &ExpressionEvaluationContext,
144        args: Vec<VariableValue>,
145        accumulator: &mut Accumulator,
146    ) -> Result<VariableValue, FunctionError> {
147        if args.len() != 1 {
148            return Err(FunctionError {
149                function_name: "Sum".to_string(),
150                error: FunctionEvaluationError::InvalidArgumentCount,
151            });
152        }
153        let accumulator = match accumulator {
154            Accumulator::Value(super::ValueAccumulator::Sum { value }) => value,
155            _ => {
156                return Err(FunctionError {
157                    function_name: "Sum".to_string(),
158                    error: FunctionEvaluationError::CorruptData,
159                })
160            }
161        };
162
163        match &args[0] {
164            VariableValue::Float(n) => {
165                *accumulator -= match n.as_f64() {
166                    Some(n) => n,
167                    None => {
168                        return Err(FunctionError {
169                            function_name: "Sum".to_string(),
170                            error: FunctionEvaluationError::OverflowError,
171                        })
172                    }
173                };
174                Ok(VariableValue::Float(match Float::from_f64(*accumulator) {
175                    Some(n) => n,
176                    None => {
177                        return Err(FunctionError {
178                            function_name: "Sum".to_string(),
179                            error: FunctionEvaluationError::OverflowError,
180                        })
181                    }
182                }))
183            }
184            VariableValue::Integer(n) => {
185                *accumulator -= match n.as_i64() {
186                    Some(n) => n as f64,
187                    None => {
188                        return Err(FunctionError {
189                            function_name: "Sum".to_string(),
190                            error: FunctionEvaluationError::OverflowError,
191                        })
192                    }
193                };
194                Ok(VariableValue::Float(match Float::from_f64(*accumulator) {
195                    Some(n) => n,
196                    None => {
197                        return Err(FunctionError {
198                            function_name: "Sum".to_string(),
199                            error: FunctionEvaluationError::OverflowError,
200                        })
201                    }
202                }))
203            }
204            VariableValue::Duration(d) => {
205                *accumulator -= d.duration().num_milliseconds() as f64;
206                Ok(VariableValue::Duration(Duration::new(
207                    ChronoDuration::milliseconds(*accumulator as i64),
208                    0,
209                    0,
210                )))
211            }
212            VariableValue::Null => Ok(VariableValue::Float(match Float::from_f64(*accumulator) {
213                Some(n) => n,
214                None => {
215                    return Err(FunctionError {
216                        function_name: "Sum".to_string(),
217                        error: FunctionEvaluationError::OverflowError,
218                    })
219                }
220            })),
221            _ => Err(FunctionError {
222                function_name: "Sum".to_string(),
223                error: FunctionEvaluationError::InvalidArgument(0),
224            }),
225        }
226    }
227
228    async fn snapshot(
229        &self,
230        _context: &ExpressionEvaluationContext,
231        args: Vec<VariableValue>,
232        accumulator: &Accumulator,
233    ) -> Result<VariableValue, FunctionError> {
234        if args.len() != 1 {
235            return Err(FunctionError {
236                function_name: "Sum".to_string(),
237                error: FunctionEvaluationError::InvalidArgumentCount,
238            });
239        }
240        let accumulator_value = match accumulator {
241            Accumulator::Value(super::ValueAccumulator::Sum { value }) => value,
242            _ => {
243                return Err(FunctionError {
244                    function_name: "Sum".to_string(),
245                    error: FunctionEvaluationError::CorruptData,
246                });
247            }
248        };
249
250        match &args[0] {
251            VariableValue::Float(_) => Ok(VariableValue::Float(
252                match Float::from_f64(*accumulator_value) {
253                    Some(n) => n,
254                    None => {
255                        return Err(FunctionError {
256                            function_name: "Sum".to_string(),
257                            error: FunctionEvaluationError::OverflowError,
258                        })
259                    }
260                },
261            )),
262            VariableValue::Integer(_) => Ok(VariableValue::Float(
263                match Float::from_f64(*accumulator_value) {
264                    Some(n) => n,
265                    None => {
266                        return Err(FunctionError {
267                            function_name: "Sum".to_string(),
268                            error: FunctionEvaluationError::OverflowError,
269                        })
270                    }
271                },
272            )),
273            VariableValue::Duration(_) => Ok(VariableValue::Duration(Duration::new(
274                ChronoDuration::milliseconds(*accumulator_value as i64),
275                0,
276                0,
277            ))),
278            VariableValue::Null => Ok(VariableValue::Float(
279                match Float::from_f64(*accumulator_value) {
280                    Some(n) => n,
281                    None => {
282                        return Err(FunctionError {
283                            function_name: "Sum".to_string(),
284                            error: FunctionEvaluationError::OverflowError,
285                        })
286                    }
287                },
288            )),
289            _ => Err(FunctionError {
290                function_name: "Sum".to_string(),
291                error: FunctionEvaluationError::InvalidArgument(0),
292            }),
293        }
294    }
295}
296
297impl Debug for Sum {
298    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
299        write!(f, "Sum")
300    }
301}