use std::{fmt::Debug, sync::Arc};
use crate::{
evaluation::{FunctionError, FunctionEvaluationError},
interface::ResultIndex,
};
use async_trait::async_trait;
use drasi_query_ast::ast;
use crate::evaluation::{
variable_value::float::Float, variable_value::VariableValue, ExpressionEvaluationContext,
};
use chrono::{NaiveTime, Timelike};
use super::{super::AggregatingFunction, Accumulator, ValueAccumulator};
/// Aggregating function computing the least-squares gradient (slope) of `y`
/// over `x` across all `(x, y)` pairs accumulated for a group.
///
/// Stateless: all running statistics live in the accumulator.
pub struct LinearGradient;
#[async_trait]
impl AggregatingFunction for LinearGradient {
fn initialize_accumulator(
&self,
_context: &ExpressionEvaluationContext,
_expression: &ast::FunctionExpression,
_grouping_keys: &Vec<VariableValue>,
_index: Arc<dyn ResultIndex>,
) -> Accumulator {
Accumulator::Value(ValueAccumulator::LinearGradient {
count: 0,
mean_x: 0.0,
mean_y: 0.0,
m2: 0.0,
cov: 0.0,
})
}
fn accumulator_is_lazy(&self) -> bool {
false
}
async fn apply(
&self,
_context: &ExpressionEvaluationContext,
args: Vec<VariableValue>,
accumulator: &mut Accumulator,
) -> Result<VariableValue, FunctionError> {
if args.len() != 2 {
return Err(FunctionError {
function_name: "linearGradient".to_string(),
error: FunctionEvaluationError::InvalidArgumentCount,
});
}
let (count, mean_x, mean_y, m2, cov) = match accumulator {
Accumulator::Value(ValueAccumulator::LinearGradient {
count,
mean_x,
mean_y,
m2,
cov,
}) => (count, mean_x, mean_y, m2, cov),
_ => {
return Err(FunctionError {
function_name: "LinearGradient".to_string(),
error: FunctionEvaluationError::CorruptData,
})
}
};
if let VariableValue::Null = args[0] {
return Ok(VariableValue::Null);
}
if let VariableValue::Null = args[1] {
return Ok(VariableValue::Null);
}
let x = extract_parameter(&args[0], 0)?;
let y = extract_parameter(&args[1], 1)?;
*count += 1;
let delta_x = x - *mean_x;
let delta_y = y - *mean_y;
*mean_x += delta_x / *count as f64;
*mean_y += delta_y / *count as f64;
let delta2 = x - *mean_x;
*m2 += delta_x * delta2;
*cov += delta_x * (y - *mean_y);
let result = covariance(*cov, *count) / variance(*m2, *count);
if result.is_nan() {
return Ok(VariableValue::Null);
}
Ok(VariableValue::Float(
Float::from_f64(result).unwrap_or_default(),
))
}
async fn revert(
&self,
_context: &ExpressionEvaluationContext,
args: Vec<VariableValue>,
accumulator: &mut Accumulator,
) -> Result<VariableValue, FunctionError> {
if args.len() != 2 {
return Err(FunctionError {
function_name: "linearGradient".to_string(),
error: FunctionEvaluationError::InvalidArgumentCount,
});
}
let (count, mean_x, mean_y, m2, cov) = match accumulator {
Accumulator::Value(ValueAccumulator::LinearGradient {
count,
mean_x,
mean_y,
m2,
cov,
}) => (count, mean_x, mean_y, m2, cov),
_ => {
return Err(FunctionError {
function_name: "LinearGradient".to_string(),
error: FunctionEvaluationError::CorruptData,
});
}
};
if let VariableValue::Null = args[0] {
return Ok(VariableValue::Null);
}
if let VariableValue::Null = args[1] {
return Ok(VariableValue::Null);
}
let x = extract_parameter(&args[0], 0)?;
let y = extract_parameter(&args[1], 1)?;
*count -= 1;
if *count == 0 {
*mean_x = 0.0;
*mean_y = 0.0;
*m2 = 0.0;
*cov = 0.0;
return Ok(VariableValue::Null);
}
let delta_x = x - *mean_x;
let delta_y = y - *mean_y;
*mean_x -= delta_x / *count as f64;
*mean_y -= delta_y / *count as f64;
let delta2 = x - *mean_x;
*m2 -= delta_x * delta2;
*cov -= delta_x * (y - *mean_y);
let result = covariance(*cov, *count) / variance(*m2, *count);
if result.is_nan() {
return Ok(VariableValue::Null);
}
Ok(VariableValue::Float(
Float::from_f64(result).unwrap_or_default(),
))
}
async fn snapshot(
&self,
_context: &ExpressionEvaluationContext,
args: Vec<VariableValue>,
accumulator: &Accumulator,
) -> Result<VariableValue, FunctionError> {
if args.len() != 2 {
return Err(FunctionError {
function_name: "linearGradient".to_string(),
error: FunctionEvaluationError::InvalidArgumentCount,
});
}
let (count, _mean_x, _mean_y, m2, cov) = match accumulator {
Accumulator::Value(ValueAccumulator::LinearGradient {
count,
mean_x,
mean_y,
m2,
cov,
}) => (count, mean_x, mean_y, m2, cov),
_ => {
return Err(FunctionError {
function_name: "LinearGradient".to_string(),
error: FunctionEvaluationError::CorruptData,
});
}
};
if *count == 0 {
return Ok(VariableValue::Null);
}
let result = covariance(*cov, *count) / variance(*m2, *count);
if result.is_nan() {
return Ok(VariableValue::Null);
}
Ok(VariableValue::Float(
Float::from_f64(result).unwrap_or_default(),
))
}
}
/// Converts one `linearGradient` argument into an `f64` coordinate.
///
/// Mappings:
/// - Integers and floats map to their numeric value.
/// - Durations map to milliseconds.
/// - Local date-times map to milliseconds since the epoch, taken as UTC.
/// - Zoned date-times map to their timestamp in milliseconds.
/// - Dates map to milliseconds since the epoch at midnight UTC.
/// - Local and zoned times map to whole seconds since midnight.
///
/// NOTE(review): times use seconds while the other temporal types use
/// milliseconds — confirm this difference is intended.
///
/// # Errors
/// - `OverflowError` if a float or integer cannot be represented as `f64`/`i64`.
/// - `InvalidArgument(index)` for any other value type.
fn extract_parameter(p: &VariableValue, index: u64) -> Result<f64, FunctionError> {
    let overflow = || FunctionError {
        function_name: "LinearGradient".to_string(),
        error: FunctionEvaluationError::OverflowError,
    };
    match p {
        VariableValue::Float(float) => float.as_f64().ok_or_else(overflow),
        VariableValue::Integer(int) => int.as_i64().map(|v| v as f64).ok_or_else(overflow),
        VariableValue::Duration(dur) => Ok(dur.duration().num_milliseconds() as f64),
        VariableValue::LocalDateTime(ldt) => Ok(ldt.and_utc().timestamp_millis() as f64),
        VariableValue::ZonedDateTime(zdt) => Ok(zdt.datetime().timestamp_millis() as f64),
        VariableValue::Date(date) => {
            let midnight = date.and_time(NaiveTime::MIN);
            Ok(midnight.and_utc().timestamp_millis() as f64)
        }
        VariableValue::LocalTime(time) => Ok(time.num_seconds_from_midnight() as f64),
        VariableValue::ZonedTime(ztime) => Ok(ztime.time().num_seconds_from_midnight() as f64),
        _ => Err(FunctionError {
            function_name: "LinearGradient".to_string(),
            error: FunctionEvaluationError::InvalidArgument(index as usize),
        }),
    }
}
/// Sample (n - 1) variance from `m2`, the running sum of squared deviations
/// from the mean over `count` samples.
///
/// Returns `0.0` when fewer than two samples are available.
fn variance(m2: f64, count: i64) -> f64 {
    match count {
        n if n >= 2 => m2 / (n - 1) as f64,
        _ => 0.0,
    }
}
/// Sample (n - 1) covariance from `cov`, the running sum of co-deviation
/// products over `count` samples.
///
/// Returns `0.0` when fewer than two samples are available.
fn covariance(cov: f64, count: i64) -> f64 {
    if count >= 2 {
        cov / (count - 1) as f64
    } else {
        0.0
    }
}
impl Debug for LinearGradient {
    /// Renders as the bare type name, `LinearGradient`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("LinearGradient")
    }
}