drasi_core/evaluation/functions/aggregation/
avg.rs

// Copyright 2024 The Drasi Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::{fmt::Debug, sync::Arc};
16
17use crate::{
18    evaluation::{FunctionError, FunctionEvaluationError},
19    interface::ResultIndex,
20};
21
22use async_trait::async_trait;
23
24use drasi_query_ast::ast;
25
26use crate::evaluation::{
27    variable_value::duration::Duration, variable_value::float::Float,
28    variable_value::VariableValue, ExpressionEvaluationContext,
29};
30
31use chrono::Duration as ChronoDuration;
32
33use super::{super::AggregatingFunction, Accumulator, ValueAccumulator};
34
/// The `avg()` aggregating function: the mean of the numeric or duration values in a group.
///
/// Stateless; the running total lives in the `Accumulator` it hands out.
pub struct Avg;
36
37#[async_trait]
38impl AggregatingFunction for Avg {
39    fn initialize_accumulator(
40        &self,
41        _context: &ExpressionEvaluationContext,
42        _expression: &ast::FunctionExpression,
43        _grouping_keys: &Vec<VariableValue>,
44        _index: Arc<dyn ResultIndex>,
45    ) -> Accumulator {
46        Accumulator::Value(ValueAccumulator::Avg { sum: 0.0, count: 0 })
47    }
48
49    fn accumulator_is_lazy(&self) -> bool {
50        false
51    }
52
53    async fn apply(
54        &self,
55        _context: &ExpressionEvaluationContext,
56        args: Vec<VariableValue>,
57        accumulator: &mut Accumulator,
58    ) -> Result<VariableValue, FunctionError> {
59        if args.len() != 1 {
60            return Err(FunctionError {
61                function_name: "Avg".to_string(),
62                error: FunctionEvaluationError::InvalidArgumentCount,
63            });
64        }
65
66        let (sum, count) = match accumulator {
67            Accumulator::Value(ValueAccumulator::Avg { sum, count }) => (sum, count),
68            _ => {
69                return Err(FunctionError {
70                    function_name: "Avg".to_string(),
71                    error: FunctionEvaluationError::CorruptData,
72                })
73            }
74        };
75
76        match &args[0] {
77            VariableValue::Float(n) => {
78                *count += 1;
79                *sum += match n.as_f64() {
80                    Some(n) => n,
81                    None => {
82                        return Err(FunctionError {
83                            function_name: "Avg".to_string(),
84                            error: FunctionEvaluationError::OverflowError,
85                        })
86                    }
87                };
88                let avg = *sum / *count as f64;
89
90                Ok(VariableValue::Float(
91                    Float::from_f64(avg).unwrap_or_default(),
92                ))
93            }
94            VariableValue::Integer(n) => {
95                *count += 1;
96                *sum += match n.as_i64() {
97                    Some(n) => n as f64,
98                    None => {
99                        return Err(FunctionError {
100                            function_name: "Avg".to_string(),
101                            error: FunctionEvaluationError::OverflowError,
102                        })
103                    }
104                };
105                let avg = *sum / *count as f64;
106
107                Ok(VariableValue::Float(
108                    Float::from_f64(avg).unwrap_or_default(),
109                ))
110            }
111            // The average of two dates/times does not really make sense
112            // Only adding duration for now
113            VariableValue::Duration(d) => {
114                *count += 1;
115                *sum += d.duration().num_milliseconds() as f64;
116                let avg = *sum / *count as f64;
117
118                Ok(VariableValue::Duration(Duration::new(
119                    ChronoDuration::milliseconds(avg as i64),
120                    0,
121                    0,
122                )))
123            }
124            VariableValue::Null => {
125                let avg = *sum / *count as f64;
126                Ok(VariableValue::Float(
127                    Float::from_f64(avg).unwrap_or_default(),
128                ))
129            }
130            _ => Err(FunctionError {
131                function_name: "Avg".to_string(),
132                error: FunctionEvaluationError::InvalidArgument(0),
133            }),
134        }
135    }
136
137    async fn revert(
138        &self,
139        _context: &ExpressionEvaluationContext,
140        args: Vec<VariableValue>,
141        accumulator: &mut Accumulator,
142    ) -> Result<VariableValue, FunctionError> {
143        if args.len() != 1 {
144            return Err(FunctionError {
145                function_name: "Avg".to_string(),
146                error: FunctionEvaluationError::InvalidArgumentCount,
147            });
148        }
149        let (sum, count) = match accumulator {
150            Accumulator::Value(ValueAccumulator::Avg { sum, count }) => (sum, count),
151            _ => {
152                return Err(FunctionError {
153                    function_name: "Avg".to_string(),
154                    error: FunctionEvaluationError::CorruptData,
155                })
156            }
157        };
158
159        match &args[0] {
160            VariableValue::Float(n) => {
161                *count -= 1;
162                *sum -= match n.as_f64() {
163                    Some(n) => n,
164                    None => {
165                        return Err(FunctionError {
166                            function_name: "Avg".to_string(),
167                            error: FunctionEvaluationError::OverflowError,
168                        })
169                    }
170                };
171
172                if *count == 0 {
173                    return Ok(VariableValue::Float(
174                        Float::from_f64(0.0).unwrap_or_default(),
175                    ));
176                }
177
178                let avg = *sum / *count as f64;
179
180                Ok(VariableValue::Float(
181                    Float::from_f64(avg).unwrap_or_default(),
182                ))
183            }
184            VariableValue::Integer(n) => {
185                *count -= 1;
186                *sum -= match n.as_i64() {
187                    Some(n) => n as f64,
188                    None => {
189                        return Err(FunctionError {
190                            function_name: "Avg".to_string(),
191                            error: FunctionEvaluationError::OverflowError,
192                        })
193                    }
194                };
195
196                if *count == 0 {
197                    return Ok(VariableValue::Float(
198                        Float::from_f64(0.0).unwrap_or_default(),
199                    ));
200                }
201
202                let avg = *sum / *count as f64;
203
204                Ok(VariableValue::Float(
205                    Float::from_f64(avg).unwrap_or_default(),
206                ))
207            }
208            VariableValue::Duration(d) => {
209                *count -= 1;
210                *sum -= d.duration().num_milliseconds() as f64;
211
212                if *count == 0 {
213                    return Ok(VariableValue::Float(
214                        Float::from_f64(0.0).unwrap_or_default(),
215                    ));
216                }
217
218                let avg = *sum / *count as f64;
219
220                Ok(VariableValue::Duration(Duration::new(
221                    ChronoDuration::milliseconds(avg as i64),
222                    0,
223                    0,
224                )))
225            }
226            VariableValue::Null => {
227                let avg = *sum / *count as f64;
228                Ok(VariableValue::Float(
229                    Float::from_f64(avg).unwrap_or_default(),
230                ))
231            }
232            _ => Err(FunctionError {
233                function_name: "Avg".to_string(),
234                error: FunctionEvaluationError::InvalidArgument(0),
235            }),
236        }
237    }
238
239    async fn snapshot(
240        &self,
241        _context: &ExpressionEvaluationContext,
242        args: Vec<VariableValue>,
243        accumulator: &Accumulator,
244    ) -> Result<VariableValue, FunctionError> {
245        if args.len() != 1 {
246            return Err(FunctionError {
247                function_name: "Avg".to_string(),
248                error: FunctionEvaluationError::InvalidArgumentCount,
249            });
250        }
251        let (sum, count) = match accumulator {
252            Accumulator::Value(ValueAccumulator::Avg { sum, count }) => (sum, count),
253            _ => {
254                return Err(FunctionError {
255                    function_name: "Avg".to_string(),
256                    error: FunctionEvaluationError::CorruptData,
257                })
258            }
259        };
260
261        if *count == 0 {
262            return Ok(VariableValue::Float(
263                Float::from_f64(0.0).unwrap_or_default(),
264            ));
265        }
266
267        let avg = *sum / *count as f64;
268
269        match &args[0] {
270            VariableValue::Float(_) => Ok(VariableValue::Float(
271                Float::from_f64(avg).unwrap_or_default(),
272            )),
273            VariableValue::Integer(_) => Ok(VariableValue::Float(
274                Float::from_f64(avg).unwrap_or_default(),
275            )),
276            VariableValue::Duration(_) => Ok(VariableValue::Duration(Duration::new(
277                ChronoDuration::milliseconds(avg as i64),
278                0,
279                0,
280            ))),
281            _ => Err(FunctionError {
282                function_name: "Avg".to_string(),
283                error: FunctionEvaluationError::InvalidArgument(0),
284            }),
285        }
286    }
287}
288
289impl Debug for Avg {
290    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
291        write!(f, "Avg")
292    }
293}