Skip to main content

drasi_core/evaluation/functions/aggregation/
linear_gradient.rs

1// Copyright 2024 The Drasi Authors.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::{fmt::Debug, sync::Arc};
16
17use crate::{
18    evaluation::{FunctionError, FunctionEvaluationError},
19    interface::ResultIndex,
20};
21
22use async_trait::async_trait;
23
24use drasi_query_ast::ast;
25
26use crate::evaluation::{
27    variable_value::float::Float, variable_value::VariableValue, ExpressionEvaluationContext,
28};
29
30use chrono::{NaiveTime, Timelike};
31
32use super::{super::AggregatingFunction, Accumulator, ValueAccumulator};
33
/// Aggregating function that reports the slope of a set of `(x, y)` pairs, computed as the
/// covariance of x and y divided by the variance of x.
///
/// The type itself holds no data; all running state is kept in the
/// `ValueAccumulator::LinearGradient` accumulator created by `initialize_accumulator`.
pub struct LinearGradient {}
35
36#[async_trait]
37impl AggregatingFunction for LinearGradient {
38    fn initialize_accumulator(
39        &self,
40        _context: &ExpressionEvaluationContext,
41        _expression: &ast::FunctionExpression,
42        _grouping_keys: &Vec<VariableValue>,
43        _index: Arc<dyn ResultIndex>,
44    ) -> Accumulator {
45        Accumulator::Value(ValueAccumulator::LinearGradient {
46            count: 0,
47            mean_x: 0.0,
48            mean_y: 0.0,
49            m2: 0.0,
50            cov: 0.0,
51        })
52    }
53
54    fn accumulator_is_lazy(&self) -> bool {
55        false
56    }
57
58    async fn apply(
59        &self,
60        _context: &ExpressionEvaluationContext,
61        args: Vec<VariableValue>,
62        accumulator: &mut Accumulator,
63    ) -> Result<VariableValue, FunctionError> {
64        if args.len() != 2 {
65            return Err(FunctionError {
66                function_name: "linearGradient".to_string(),
67                error: FunctionEvaluationError::InvalidArgumentCount,
68            });
69        }
70
71        let (count, mean_x, mean_y, m2, cov) = match accumulator {
72            Accumulator::Value(ValueAccumulator::LinearGradient {
73                count,
74                mean_x,
75                mean_y,
76                m2,
77                cov,
78            }) => (count, mean_x, mean_y, m2, cov),
79            _ => {
80                return Err(FunctionError {
81                    function_name: "LinearGradient".to_string(),
82                    error: FunctionEvaluationError::CorruptData,
83                })
84            }
85        };
86
87        if let VariableValue::Null = args[0] {
88            return Ok(VariableValue::Null);
89        }
90
91        if let VariableValue::Null = args[1] {
92            return Ok(VariableValue::Null);
93        }
94
95        let x = extract_parameter(&args[0], 0)?;
96        let y = extract_parameter(&args[1], 1)?;
97
98        *count += 1;
99        let delta_x = x - *mean_x;
100        let delta_y = y - *mean_y;
101        *mean_x += delta_x / *count as f64;
102        *mean_y += delta_y / *count as f64;
103        let delta2 = x - *mean_x;
104        *m2 += delta_x * delta2;
105        *cov += delta_x * (y - *mean_y);
106
107        let result = covariance(*cov, *count) / variance(*m2, *count);
108
109        if result.is_nan() {
110            return Ok(VariableValue::Null);
111        }
112
113        Ok(VariableValue::Float(
114            Float::from_f64(result).unwrap_or_default(),
115        ))
116    }
117
118    async fn revert(
119        &self,
120        _context: &ExpressionEvaluationContext,
121        args: Vec<VariableValue>,
122        accumulator: &mut Accumulator,
123    ) -> Result<VariableValue, FunctionError> {
124        if args.len() != 2 {
125            return Err(FunctionError {
126                function_name: "linearGradient".to_string(),
127                error: FunctionEvaluationError::InvalidArgumentCount,
128            });
129        }
130
131        let (count, mean_x, mean_y, m2, cov) = match accumulator {
132            Accumulator::Value(ValueAccumulator::LinearGradient {
133                count,
134                mean_x,
135                mean_y,
136                m2,
137                cov,
138            }) => (count, mean_x, mean_y, m2, cov),
139            _ => {
140                return Err(FunctionError {
141                    function_name: "LinearGradient".to_string(),
142                    error: FunctionEvaluationError::CorruptData,
143                });
144            }
145        };
146
147        if let VariableValue::Null = args[0] {
148            return Ok(VariableValue::Null);
149        }
150
151        if let VariableValue::Null = args[1] {
152            return Ok(VariableValue::Null);
153        }
154
155        let x = extract_parameter(&args[0], 0)?;
156        let y = extract_parameter(&args[1], 1)?;
157
158        *count -= 1;
159
160        if *count == 0 {
161            *mean_x = 0.0;
162            *mean_y = 0.0;
163            *m2 = 0.0;
164            *cov = 0.0;
165            return Ok(VariableValue::Null);
166        }
167
168        let delta_x = x - *mean_x;
169        let delta_y = y - *mean_y;
170        *mean_x -= delta_x / *count as f64;
171        *mean_y -= delta_y / *count as f64;
172        let delta2 = x - *mean_x;
173        *m2 -= delta_x * delta2;
174        *cov -= delta_x * (y - *mean_y);
175
176        let result = covariance(*cov, *count) / variance(*m2, *count);
177
178        if result.is_nan() {
179            return Ok(VariableValue::Null);
180        }
181
182        Ok(VariableValue::Float(
183            Float::from_f64(result).unwrap_or_default(),
184        ))
185    }
186
187    async fn snapshot(
188        &self,
189        _context: &ExpressionEvaluationContext,
190        args: Vec<VariableValue>,
191        accumulator: &Accumulator,
192    ) -> Result<VariableValue, FunctionError> {
193        if args.len() != 2 {
194            return Err(FunctionError {
195                function_name: "linearGradient".to_string(),
196                error: FunctionEvaluationError::InvalidArgumentCount,
197            });
198        }
199
200        let (count, _mean_x, _mean_y, m2, cov) = match accumulator {
201            Accumulator::Value(ValueAccumulator::LinearGradient {
202                count,
203                mean_x,
204                mean_y,
205                m2,
206                cov,
207            }) => (count, mean_x, mean_y, m2, cov),
208            _ => {
209                return Err(FunctionError {
210                    function_name: "LinearGradient".to_string(),
211                    error: FunctionEvaluationError::CorruptData,
212                });
213            }
214        };
215
216        if *count == 0 {
217            return Ok(VariableValue::Null);
218        }
219
220        let result = covariance(*cov, *count) / variance(*m2, *count);
221
222        if result.is_nan() {
223            return Ok(VariableValue::Null);
224        }
225
226        Ok(VariableValue::Float(
227            Float::from_f64(result).unwrap_or_default(),
228        ))
229    }
230}
231
232fn extract_parameter(p: &VariableValue, index: u64) -> Result<f64, FunctionError> {
233    let result = match p {
234        VariableValue::Float(n) => match n.as_f64() {
235            Some(n) => n,
236            None => {
237                return Err(FunctionError {
238                    function_name: "LinearGradient".to_string(),
239                    error: FunctionEvaluationError::OverflowError,
240                })
241            }
242        },
243        VariableValue::Integer(n) => match n.as_i64() {
244            Some(n) => n as f64,
245            None => {
246                return Err(FunctionError {
247                    function_name: "LinearGradient".to_string(),
248                    error: FunctionEvaluationError::OverflowError,
249                })
250            }
251        },
252        VariableValue::Duration(d) => d.duration().num_milliseconds() as f64,
253        VariableValue::LocalDateTime(l) => l.and_utc().timestamp_millis() as f64,
254        VariableValue::ZonedDateTime(z) => z.datetime().timestamp_millis() as f64,
255        VariableValue::Date(d) => d.and_time(NaiveTime::MIN).and_utc().timestamp_millis() as f64,
256        VariableValue::LocalTime(l) => l.num_seconds_from_midnight() as f64,
257        VariableValue::ZonedTime(z) => z.time().num_seconds_from_midnight() as f64,
258        _ => {
259            return Err(FunctionError {
260                function_name: "LinearGradient".to_string(),
261                error: FunctionEvaluationError::InvalidArgument(index as usize),
262            })
263        }
264    };
265
266    Ok(result)
267}
268
/// Sample variance derived from the running second moment `m2` over `count` points.
///
/// Divides by `count - 1`; with fewer than two points the result is defined as `0.0`.
fn variance(m2: f64, count: i64) -> f64 {
    match count {
        n if n >= 2 => m2 / (n - 1) as f64,
        _ => 0.0,
    }
}
275
/// Sample covariance derived from the running co-moment `cov` over `count` points.
///
/// Divides by `count - 1`; with fewer than two points the result is defined as `0.0`.
fn covariance(cov: f64, count: i64) -> f64 {
    if count >= 2 {
        let degrees_of_freedom = (count - 1) as f64;
        cov / degrees_of_freedom
    } else {
        0.0
    }
}
282
283impl Debug for LinearGradient {
284    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
285        write!(f, "LinearGradient")
286    }
287}