drasi_core/evaluation/functions/aggregation/
linear_gradient.rs1use std::{fmt::Debug, sync::Arc};
16
17use crate::{
18 evaluation::{FunctionError, FunctionEvaluationError},
19 interface::ResultIndex,
20};
21
22use async_trait::async_trait;
23
24use drasi_query_ast::ast;
25
26use crate::evaluation::{
27 variable_value::float::Float, variable_value::VariableValue, ExpressionEvaluationContext,
28};
29
30use chrono::{NaiveTime, Timelike};
31
32use super::{super::AggregatingFunction, Accumulator, ValueAccumulator};
33
/// The `linearGradient(x, y)` aggregating function.
///
/// The type itself is stateless; every running quantity lives in the
/// `Accumulator` handed to each call.
pub struct LinearGradient {
    // Intentionally empty.
}
35
36#[async_trait]
37impl AggregatingFunction for LinearGradient {
38 fn initialize_accumulator(
39 &self,
40 _context: &ExpressionEvaluationContext,
41 _expression: &ast::FunctionExpression,
42 _grouping_keys: &Vec<VariableValue>,
43 _index: Arc<dyn ResultIndex>,
44 ) -> Accumulator {
45 Accumulator::Value(ValueAccumulator::LinearGradient {
46 count: 0,
47 mean_x: 0.0,
48 mean_y: 0.0,
49 m2: 0.0,
50 cov: 0.0,
51 })
52 }
53
54 fn accumulator_is_lazy(&self) -> bool {
55 false
56 }
57
58 async fn apply(
59 &self,
60 _context: &ExpressionEvaluationContext,
61 args: Vec<VariableValue>,
62 accumulator: &mut Accumulator,
63 ) -> Result<VariableValue, FunctionError> {
64 if args.len() != 2 {
65 return Err(FunctionError {
66 function_name: "linearGradient".to_string(),
67 error: FunctionEvaluationError::InvalidArgumentCount,
68 });
69 }
70
71 let (count, mean_x, mean_y, m2, cov) = match accumulator {
72 Accumulator::Value(ValueAccumulator::LinearGradient {
73 count,
74 mean_x,
75 mean_y,
76 m2,
77 cov,
78 }) => (count, mean_x, mean_y, m2, cov),
79 _ => {
80 return Err(FunctionError {
81 function_name: "LinearGradient".to_string(),
82 error: FunctionEvaluationError::CorruptData,
83 })
84 }
85 };
86
87 if let VariableValue::Null = args[0] {
88 return Ok(VariableValue::Null);
89 }
90
91 if let VariableValue::Null = args[1] {
92 return Ok(VariableValue::Null);
93 }
94
95 let x = extract_parameter(&args[0], 0)?;
96 let y = extract_parameter(&args[1], 1)?;
97
98 *count += 1;
99 let delta_x = x - *mean_x;
100 let delta_y = y - *mean_y;
101 *mean_x += delta_x / *count as f64;
102 *mean_y += delta_y / *count as f64;
103 let delta2 = x - *mean_x;
104 *m2 += delta_x * delta2;
105 *cov += delta_x * (y - *mean_y);
106
107 let result = covariance(*cov, *count) / variance(*m2, *count);
108
109 if result.is_nan() {
110 return Ok(VariableValue::Null);
111 }
112
113 Ok(VariableValue::Float(
114 Float::from_f64(result).unwrap_or_default(),
115 ))
116 }
117
118 async fn revert(
119 &self,
120 _context: &ExpressionEvaluationContext,
121 args: Vec<VariableValue>,
122 accumulator: &mut Accumulator,
123 ) -> Result<VariableValue, FunctionError> {
124 if args.len() != 2 {
125 return Err(FunctionError {
126 function_name: "linearGradient".to_string(),
127 error: FunctionEvaluationError::InvalidArgumentCount,
128 });
129 }
130
131 let (count, mean_x, mean_y, m2, cov) = match accumulator {
132 Accumulator::Value(ValueAccumulator::LinearGradient {
133 count,
134 mean_x,
135 mean_y,
136 m2,
137 cov,
138 }) => (count, mean_x, mean_y, m2, cov),
139 _ => {
140 return Err(FunctionError {
141 function_name: "LinearGradient".to_string(),
142 error: FunctionEvaluationError::CorruptData,
143 });
144 }
145 };
146
147 if let VariableValue::Null = args[0] {
148 return Ok(VariableValue::Null);
149 }
150
151 if let VariableValue::Null = args[1] {
152 return Ok(VariableValue::Null);
153 }
154
155 let x = extract_parameter(&args[0], 0)?;
156 let y = extract_parameter(&args[1], 1)?;
157
158 *count -= 1;
159
160 if *count == 0 {
161 *mean_x = 0.0;
162 *mean_y = 0.0;
163 *m2 = 0.0;
164 *cov = 0.0;
165 return Ok(VariableValue::Null);
166 }
167
168 let delta_x = x - *mean_x;
169 let delta_y = y - *mean_y;
170 *mean_x -= delta_x / *count as f64;
171 *mean_y -= delta_y / *count as f64;
172 let delta2 = x - *mean_x;
173 *m2 -= delta_x * delta2;
174 *cov -= delta_x * (y - *mean_y);
175
176 let result = covariance(*cov, *count) / variance(*m2, *count);
177
178 if result.is_nan() {
179 return Ok(VariableValue::Null);
180 }
181
182 Ok(VariableValue::Float(
183 Float::from_f64(result).unwrap_or_default(),
184 ))
185 }
186
187 async fn snapshot(
188 &self,
189 _context: &ExpressionEvaluationContext,
190 args: Vec<VariableValue>,
191 accumulator: &Accumulator,
192 ) -> Result<VariableValue, FunctionError> {
193 if args.len() != 2 {
194 return Err(FunctionError {
195 function_name: "linearGradient".to_string(),
196 error: FunctionEvaluationError::InvalidArgumentCount,
197 });
198 }
199
200 let (count, _mean_x, _mean_y, m2, cov) = match accumulator {
201 Accumulator::Value(ValueAccumulator::LinearGradient {
202 count,
203 mean_x,
204 mean_y,
205 m2,
206 cov,
207 }) => (count, mean_x, mean_y, m2, cov),
208 _ => {
209 return Err(FunctionError {
210 function_name: "LinearGradient".to_string(),
211 error: FunctionEvaluationError::CorruptData,
212 });
213 }
214 };
215
216 if *count == 0 {
217 return Ok(VariableValue::Null);
218 }
219
220 let result = covariance(*cov, *count) / variance(*m2, *count);
221
222 if result.is_nan() {
223 return Ok(VariableValue::Null);
224 }
225
226 Ok(VariableValue::Float(
227 Float::from_f64(result).unwrap_or_default(),
228 ))
229 }
230}
231
232fn extract_parameter(p: &VariableValue, index: u64) -> Result<f64, FunctionError> {
233 let result = match p {
234 VariableValue::Float(n) => match n.as_f64() {
235 Some(n) => n,
236 None => {
237 return Err(FunctionError {
238 function_name: "LinearGradient".to_string(),
239 error: FunctionEvaluationError::OverflowError,
240 })
241 }
242 },
243 VariableValue::Integer(n) => match n.as_i64() {
244 Some(n) => n as f64,
245 None => {
246 return Err(FunctionError {
247 function_name: "LinearGradient".to_string(),
248 error: FunctionEvaluationError::OverflowError,
249 })
250 }
251 },
252 VariableValue::Duration(d) => d.duration().num_milliseconds() as f64,
253 VariableValue::LocalDateTime(l) => l.and_utc().timestamp_millis() as f64,
254 VariableValue::ZonedDateTime(z) => z.datetime().timestamp_millis() as f64,
255 VariableValue::Date(d) => d.and_time(NaiveTime::MIN).and_utc().timestamp_millis() as f64,
256 VariableValue::LocalTime(l) => l.num_seconds_from_midnight() as f64,
257 VariableValue::ZonedTime(z) => z.time().num_seconds_from_midnight() as f64,
258 _ => {
259 return Err(FunctionError {
260 function_name: "LinearGradient".to_string(),
261 error: FunctionEvaluationError::InvalidArgument(index as usize),
262 })
263 }
264 };
265
266 Ok(result)
267}
268
/// Sample variance (divisor `count - 1`) from `m2`, the running sum of
/// squared deviations from the mean.
///
/// Returns `0.0` when fewer than two samples are present, where the sample
/// variance is undefined.
fn variance(m2: f64, count: i64) -> f64 {
    match count {
        n if n >= 2 => m2 / (n - 1) as f64,
        _ => 0.0,
    }
}
275
/// Sample covariance (divisor `count - 1`) from `cov`, the running sum of
/// co-deviations of the two variables from their means.
///
/// Returns `0.0` when fewer than two samples are present.
fn covariance(cov: f64, count: i64) -> f64 {
    if count >= 2 {
        cov / (count - 1) as f64
    } else {
        0.0
    }
}
282
283impl Debug for LinearGradient {
284 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
285 write!(f, "LinearGradient")
286 }
287}