use std::f64::consts::LN_2;
use crate::interval_arithmetic::{apply_operator, Interval};
use crate::operator::Operator;
use crate::type_coercion::binary::binary_numeric_coercion;
use arrow::array::ArrowNativeTypeOp;
use arrow::datatypes::DataType;
use datafusion_common::rounding::alter_fp_rounding_mode;
use datafusion_common::{internal_err, not_impl_err, Result, ScalarValue};
/// A statistical distribution describing a quantity, expressed in one of the
/// forms defined in this file. Instances are created through the
/// `Distribution::new_*` constructors, which validate their parameters.
#[derive(Clone, Debug, PartialEq)]
pub enum Distribution {
/// Uniform distribution over an [`Interval`].
Uniform(UniformDistribution),
/// Exponential distribution described by a rate, an offset and a tail
/// direction.
Exponential(ExponentialDistribution),
/// Gaussian distribution described by a mean and a variance.
Gaussian(GaussianDistribution),
/// Bernoulli distribution described by a success probability `p`.
Bernoulli(BernoulliDistribution),
/// Distribution described only by summary statistics (mean, median,
/// variance) and a range.
Generic(GenericDistribution),
}
use Distribution::{Bernoulli, Exponential, Gaussian, Generic, Uniform};
impl Distribution {
    /// Creates a [`Uniform`] distribution spanning `interval`.
    ///
    /// # Errors
    /// Returns the error produced by `UniformDistribution::try_new`, which
    /// refuses boolean intervals.
    pub fn new_uniform(interval: Interval) -> Result<Self> {
        let dist = UniformDistribution::try_new(interval)?;
        Ok(Uniform(dist))
    }

    /// Creates an [`Exponential`] distribution from its `rate`, `offset` and
    /// tail direction.
    ///
    /// # Errors
    /// Returns the error produced by `ExponentialDistribution::try_new` when
    /// the parameters are invalid (mismatched types, nulls, non-positive rate).
    pub fn new_exponential(
        rate: ScalarValue,
        offset: ScalarValue,
        positive_tail: bool,
    ) -> Result<Self> {
        let dist = ExponentialDistribution::try_new(rate, offset, positive_tail)?;
        Ok(Exponential(dist))
    }

    /// Creates a [`Gaussian`] distribution from its `mean` and `variance`.
    ///
    /// # Errors
    /// Returns the error produced by `GaussianDistribution::try_new` when the
    /// parameters are invalid (mismatched types, null or negative variance).
    pub fn new_gaussian(mean: ScalarValue, variance: ScalarValue) -> Result<Self> {
        let dist = GaussianDistribution::try_new(mean, variance)?;
        Ok(Gaussian(dist))
    }

    /// Creates a [`Bernoulli`] distribution with success probability `p`.
    ///
    /// # Errors
    /// Returns the error produced by `BernoulliDistribution::try_new` when a
    /// non-null `p` lies outside `[0, 1]`.
    pub fn new_bernoulli(p: ScalarValue) -> Result<Self> {
        let dist = BernoulliDistribution::try_new(p)?;
        Ok(Bernoulli(dist))
    }

    /// Creates a [`Generic`] distribution from summary statistics and a range.
    ///
    /// # Errors
    /// Returns the error produced by `GenericDistribution::try_new` when the
    /// statistics are inconsistent with `range` or the variance is negative.
    pub fn new_generic(
        mean: ScalarValue,
        median: ScalarValue,
        variance: ScalarValue,
        range: Interval,
    ) -> Result<Self> {
        let dist = GenericDistribution::try_new(mean, median, variance, range)?;
        Ok(Generic(dist))
    }

    /// Creates a [`Generic`] distribution for which only the range is given;
    /// mean, median and variance are all set to the null value of the range's
    /// data type.
    pub fn new_from_interval(range: Interval) -> Result<Self> {
        let unknown = ScalarValue::try_from(range.data_type())?;
        let (mean, median, variance) = (unknown.clone(), unknown.clone(), unknown);
        Self::new_generic(mean, median, variance, range)
    }

    /// Returns the mean of the wrapped distribution.
    pub fn mean(&self) -> Result<ScalarValue> {
        match self {
            Uniform(dist) => dist.mean(),
            Exponential(dist) => dist.mean(),
            Gaussian(dist) => Ok(dist.mean().clone()),
            Bernoulli(dist) => Ok(dist.mean().clone()),
            Generic(dist) => Ok(dist.mean().clone()),
        }
    }

    /// Returns the median of the wrapped distribution.
    pub fn median(&self) -> Result<ScalarValue> {
        match self {
            Uniform(dist) => dist.median(),
            Exponential(dist) => dist.median(),
            Gaussian(dist) => Ok(dist.median().clone()),
            Bernoulli(dist) => dist.median(),
            Generic(dist) => Ok(dist.median().clone()),
        }
    }

    /// Returns the variance of the wrapped distribution.
    pub fn variance(&self) -> Result<ScalarValue> {
        match self {
            Uniform(dist) => dist.variance(),
            Exponential(dist) => dist.variance(),
            Gaussian(dist) => Ok(dist.variance().clone()),
            Bernoulli(dist) => dist.variance(),
            Generic(dist) => Ok(dist.variance().clone()),
        }
    }

    /// Returns the range of values the wrapped distribution can take.
    pub fn range(&self) -> Result<Interval> {
        let range = match self {
            Uniform(dist) => dist.range().clone(),
            Exponential(dist) => dist.range()?,
            Gaussian(dist) => dist.range()?,
            Bernoulli(dist) => dist.range(),
            Generic(dist) => dist.range().clone(),
        };
        Ok(range)
    }

    /// Returns the data type reported by the wrapped distribution.
    pub fn data_type(&self) -> DataType {
        match self {
            Uniform(dist) => dist.data_type(),
            Exponential(dist) => dist.data_type(),
            Gaussian(dist) => dist.data_type(),
            Bernoulli(dist) => dist.data_type(),
            Generic(dist) => dist.data_type(),
        }
    }

    /// Determines the common data type of `args` by folding
    /// `binary_numeric_coercion` over the data types of all arguments that
    /// are not the untyped `ScalarValue::Null`.
    ///
    /// Returns `DataType::Null` when no argument remains after filtering.
    ///
    /// # Errors
    /// Returns an internal error when some pair of types cannot be coerced.
    pub fn target_type(args: &[&ScalarValue]) -> Result<DataType> {
        let mut typed_args = args
            .iter()
            .filter(|&&arg| arg != &ScalarValue::Null)
            .map(|&arg| arg.data_type());

        // Without any typed argument there is nothing to coerce.
        let Some(first) = typed_args.next() else {
            return Ok(DataType::Null);
        };

        let coerced = typed_args
            .try_fold(first, |target, arg| binary_numeric_coercion(&target, &arg));
        match coerced {
            Some(dt) => Ok(dt),
            None => internal_err!("Can only evaluate statistics for numeric types"),
        }
    }
}
/// Uniform distribution over an interval. The constructor rejects boolean
/// intervals and directs callers to the Bernoulli distribution instead.
#[derive(Clone, Debug, PartialEq)]
pub struct UniformDistribution {
// Interval over which the distribution is defined; also returned as its range.
interval: Interval,
}
/// Exponential distribution with an offset. The constructor requires `rate`
/// and `offset` to share a data type, to be non-null, and `rate` to be
/// strictly greater than zero.
#[derive(Clone, Debug, PartialEq)]
pub struct ExponentialDistribution {
// Rate parameter; the mean lies `1 / rate` away from the offset.
rate: ScalarValue,
// Finite end point of the range.
offset: ScalarValue,
// When `true` the range starts at `offset` and has a null upper bound;
// when `false` the range has a null lower bound and ends at `offset`.
positive_tail: bool,
}
/// Gaussian distribution. The constructor requires `mean` and `variance` to
/// share a data type and `variance` to be non-null and not less than zero.
#[derive(Clone, Debug, PartialEq)]
pub struct GaussianDistribution {
// Mean of the distribution; also reported as its median.
mean: ScalarValue,
// Variance of the distribution.
variance: ScalarValue,
}
/// Bernoulli distribution described by a success probability. The
/// constructor accepts a null probability, or a value within `[0, 1]`.
#[derive(Clone, Debug, PartialEq)]
pub struct BernoulliDistribution {
// Success probability; also reported as the mean. May be null.
p: ScalarValue,
}
/// Distribution described only by summary statistics and a range. The
/// constructor rejects boolean ranges, a non-null mean or median outside the
/// range, and a non-null negative variance.
#[derive(Clone, Debug, PartialEq)]
pub struct GenericDistribution {
// Mean value; may be null. Its data type is what `data_type()` reports.
mean: ScalarValue,
// Median value; may be null.
median: ScalarValue,
// Variance; may be null.
variance: ScalarValue,
// Range of values the distribution can take.
range: Interval,
}
impl UniformDistribution {
    /// Creates a uniform distribution over `interval`.
    ///
    /// # Errors
    /// Returns an internal error for boolean intervals.
    fn try_new(interval: Interval) -> Result<Self> {
        match interval.data_type() {
            DataType::Boolean => internal_err!(
                "Construction of a boolean `Uniform` distribution is prohibited, create a `Bernoulli` distribution instead."
            ),
            _ => Ok(Self { interval }),
        }
    }

    /// Returns the data type of the underlying interval.
    pub fn data_type(&self) -> DataType {
        self.interval.data_type()
    }

    /// Computes the mean as `(lower + upper) / 2`, with the divisor cast to
    /// the interval's data type. In debug builds, asserts that an unbounded
    /// interval yields a null result.
    pub fn mean(&self) -> Result<ScalarValue> {
        let divisor = ScalarValue::from(2).cast_to(&self.data_type())?;
        let (lower, upper) = (self.interval.lower(), self.interval.upper());
        let midpoint = lower.add_checked(upper)?.div(divisor);
        debug_assert!(
            !self.interval.is_unbounded()
                || midpoint.as_ref().is_ok_and(|r| r.is_null())
        );
        midpoint
    }

    /// Returns the median, which equals the mean for this distribution.
    pub fn median(&self) -> Result<ScalarValue> {
        self.mean()
    }

    /// Computes the variance as `width² / 12`, with the divisor cast to the
    /// width's data type. In debug builds, asserts that an unbounded interval
    /// yields a null result.
    pub fn variance(&self) -> Result<ScalarValue> {
        let width = self.interval.width()?;
        let divisor = ScalarValue::from(12).cast_to(&width.data_type())?;
        let squared_width = width.mul_checked(&width)?;
        let variance = squared_width.div(divisor);
        debug_assert!(
            !self.interval.is_unbounded()
                || variance.as_ref().is_ok_and(|r| r.is_null())
        );
        variance
    }

    /// Returns the interval over which the distribution is defined.
    pub fn range(&self) -> &Interval {
        &self.interval
    }
}
impl ExponentialDistribution {
    /// Creates an exponential distribution after validating its parameters.
    ///
    /// # Errors
    /// Returns an internal error when `rate` and `offset` have different data
    /// types, when either of them is null, or when `rate` is not strictly
    /// greater than zero. The checks run in that order.
    fn try_new(
        rate: ScalarValue,
        offset: ScalarValue,
        positive_tail: bool,
    ) -> Result<Self> {
        let dt = rate.data_type();
        if offset.data_type() != dt {
            return internal_err!("Rate and offset must have the same data type");
        }
        if offset.is_null() {
            return internal_err!(
                "Offset of an `ExponentialDistribution` cannot be null"
            );
        }
        if rate.is_null() {
            return internal_err!("Rate of an `ExponentialDistribution` cannot be null");
        }
        if rate.le(&ScalarValue::new_zero(&dt)?) {
            return internal_err!(
                "Rate of an `ExponentialDistribution` must be positive"
            );
        }
        Ok(Self {
            rate,
            offset,
            positive_tail,
        })
    }

    /// Returns the data type of the rate parameter.
    pub fn data_type(&self) -> DataType {
        self.rate.data_type()
    }

    /// Returns the rate parameter.
    pub fn rate(&self) -> &ScalarValue {
        &self.rate
    }

    /// Returns the offset parameter.
    pub fn offset(&self) -> &ScalarValue {
        &self.offset
    }

    /// Returns whether the tail extends from the offset towards larger values.
    pub fn positive_tail(&self) -> bool {
        self.positive_tail
    }

    /// Computes the mean: `offset + 1 / rate` for a positive tail and
    /// `offset - 1 / rate` otherwise.
    pub fn mean(&self) -> Result<ScalarValue> {
        let unit = ScalarValue::new_one(&self.data_type())?;
        let distance = unit.div(&self.rate)?;
        self.shift_from_offset(distance)
    }

    /// Computes the median: `offset + ln(2) / rate` for a positive tail and
    /// `offset - ln(2) / rate` otherwise.
    pub fn median(&self) -> Result<ScalarValue> {
        let ln_two = ScalarValue::from(LN_2).cast_to(&self.data_type())?;
        let distance = ln_two.div(&self.rate)?;
        self.shift_from_offset(distance)
    }

    /// Computes the variance as `1 / rate²`.
    pub fn variance(&self) -> Result<ScalarValue> {
        let unit = ScalarValue::new_one(&self.data_type())?;
        let squared_rate = self.rate.mul_checked(&self.rate)?;
        unit.div(squared_rate)
    }

    /// Returns the range: one end point is the offset, the other is the null
    /// value of the distribution's data type. Which side is null depends on
    /// the tail direction.
    pub fn range(&self) -> Result<Interval> {
        let open_end = ScalarValue::try_from(&self.data_type())?;
        let anchor = self.offset.clone();
        let (lower, upper) = if self.positive_tail {
            (anchor, open_end)
        } else {
            (open_end, anchor)
        };
        Interval::try_new(lower, upper)
    }

    /// Moves `distance` away from the offset in the direction of the tail.
    fn shift_from_offset(&self, distance: ScalarValue) -> Result<ScalarValue> {
        if self.positive_tail {
            self.offset.add_checked(distance)
        } else {
            self.offset.sub_checked(distance)
        }
    }
}
impl GaussianDistribution {
    /// Creates a Gaussian distribution after validating its parameters.
    ///
    /// # Errors
    /// Returns an internal error when `mean` and `variance` have different
    /// data types, or when `variance` is null or less than zero. A variance
    /// of exactly zero is accepted.
    fn try_new(mean: ScalarValue, variance: ScalarValue) -> Result<Self> {
        let dt = mean.data_type();
        if variance.data_type() != dt {
            return internal_err!("Mean and variance must have the same data type");
        }
        if variance.is_null() {
            return internal_err!("Variance of a `GaussianDistribution` cannot be null");
        }
        let zero = ScalarValue::new_zero(&dt)?;
        if variance.lt(&zero) {
            return internal_err!(
                "Variance of a `GaussianDistribution` must be positive"
            );
        }
        Ok(Self { mean, variance })
    }

    /// Returns the data type of the mean.
    pub fn data_type(&self) -> DataType {
        self.mean.data_type()
    }

    /// Returns the mean.
    pub fn mean(&self) -> &ScalarValue {
        &self.mean
    }

    /// Returns the variance.
    pub fn variance(&self) -> &ScalarValue {
        &self.variance
    }

    /// Returns the median, which equals the mean for this distribution.
    pub fn median(&self) -> &ScalarValue {
        &self.mean
    }

    /// Returns an unbounded interval of the distribution's data type.
    pub fn range(&self) -> Result<Interval> {
        let dt = self.data_type();
        Interval::make_unbounded(&dt)
    }
}
impl BernoulliDistribution {
    /// Creates a Bernoulli distribution with success probability `p`.
    ///
    /// A null `p` (typed or untyped) is accepted without further checks.
    ///
    /// # Errors
    /// Returns an internal error when a non-null `p` lies outside `[0, 1]`,
    /// and propagates the error raised when the data type of `p` has no zero
    /// or one value.
    fn try_new(p: ScalarValue) -> Result<Self> {
        if p.is_null() {
            return Ok(Self { p });
        }
        let dt = p.data_type();
        let zero = ScalarValue::new_zero(&dt)?;
        let one = ScalarValue::new_one(&dt)?;
        if p.ge(&zero) && p.le(&one) {
            Ok(Self { p })
        } else {
            internal_err!(
                "Success probability of a `BernoulliDistribution` must be in [0, 1]"
            )
        }
    }

    /// Returns the data type of the success probability.
    pub fn data_type(&self) -> DataType {
        self.p.data_type()
    }

    /// Returns the success probability.
    pub fn p_value(&self) -> &ScalarValue {
        &self.p
    }

    /// Returns the mean, which equals the success probability.
    pub fn mean(&self) -> &ScalarValue {
        &self.p
    }

    /// Computes the median: one when `1 - p < p`, zero otherwise. A null
    /// probability yields the null value of the distribution's data type.
    pub fn median(&self) -> Result<ScalarValue> {
        let dt = self.data_type();
        if self.p.is_null() {
            return ScalarValue::try_from(&dt);
        }
        let one = ScalarValue::new_one(&dt)?;
        if one.sub_checked(&self.p)?.lt(&self.p) {
            ScalarValue::new_one(&dt)
        } else {
            ScalarValue::new_zero(&dt)
        }
    }

    /// Computes the variance as `(1 - p) * p`. A null probability yields the
    /// null value of the distribution's data type, mirroring `median`.
    pub fn variance(&self) -> Result<ScalarValue> {
        let dt = self.data_type();
        if self.p.is_null() {
            // An untyped null has no "one" value to subtract from, so it is
            // answered directly instead of going through the arithmetic.
            return ScalarValue::try_from(&dt);
        }
        let one = ScalarValue::new_one(&dt)?;
        one.sub_checked(&self.p)?.mul_checked(&self.p)
    }

    /// Returns the boolean range of the outcome: certainly false when `p` is
    /// zero, certainly true when `p` is one, uncertain otherwise (including
    /// when `p` is null).
    pub fn range(&self) -> Interval {
        // The constructor accepts null probabilities without checking that
        // their data type has zero/one values (e.g. `ScalarValue::Null`), so
        // this case must be answered before building those constants.
        if self.p.is_null() {
            return Interval::UNCERTAIN;
        }
        let dt = self.data_type();
        // Unwraps are safe here: for a non-null `p` the constructor already
        // created zero and one values for this data type.
        if ScalarValue::new_zero(&dt).unwrap().eq(&self.p) {
            Interval::CERTAINLY_FALSE
        } else if ScalarValue::new_one(&dt).unwrap().eq(&self.p) {
            Interval::CERTAINLY_TRUE
        } else {
            Interval::UNCERTAIN
        }
    }
}
impl GenericDistribution {
    /// Creates a generic distribution after validating its parameters.
    ///
    /// # Errors
    /// Returns an internal error when `range` is boolean, when a non-null
    /// `mean` or `median` is not contained in `range`, or when a non-null
    /// `variance` is less than zero.
    fn try_new(
        mean: ScalarValue,
        median: ScalarValue,
        variance: ScalarValue,
        range: Interval,
    ) -> Result<Self> {
        if range.data_type() == DataType::Boolean {
            return internal_err!(
                "Construction of a boolean `Generic` distribution is prohibited, create a `Bernoulli` distribution instead."
            );
        }

        // A location estimate is acceptable when it is null or inside the range.
        let is_valid_location = |value: &ScalarValue| -> Result<bool> {
            Ok(value.is_null() || range.contains_value(value)?)
        };

        let is_valid = is_valid_location(&mean)?
            && is_valid_location(&median)?
            && !(!variance.is_null()
                && variance.lt(&ScalarValue::new_zero(&variance.data_type())?));

        if is_valid {
            Ok(Self {
                mean,
                median,
                variance,
                range,
            })
        } else {
            internal_err!("Tried to construct an invalid `GenericDistribution` instance")
        }
    }

    /// Returns the data type of the mean.
    pub fn data_type(&self) -> DataType {
        let Self { mean, .. } = self;
        mean.data_type()
    }

    /// Returns the mean (possibly null).
    pub fn mean(&self) -> &ScalarValue {
        &self.mean
    }

    /// Returns the median (possibly null).
    pub fn median(&self) -> &ScalarValue {
        &self.median
    }

    /// Returns the variance (possibly null).
    pub fn variance(&self) -> &ScalarValue {
        &self.variance
    }

    /// Returns the range of the distribution.
    pub fn range(&self) -> &Interval {
        &self.range
    }
}
/// Combines two Bernoulli distributions under a logical `AND` or `OR`.
///
/// When both probabilities are non-null, the result is `p * q` for `AND` and
/// `p + q - p * q` for `OR`. When exactly one probability is null, the other
/// operand is returned if it alone decides the outcome (zero for `AND`, one
/// for `OR`). In every remaining case the result has a null probability whose
/// data type comes from [`Distribution::target_type`].
///
/// # Errors
/// Returns a "not implemented" error for any operator other than `AND`/`OR`.
pub fn combine_bernoullis(
    op: &Operator,
    left: &BernoulliDistribution,
    right: &BernoulliDistribution,
) -> Result<BernoulliDistribution> {
    let (left_p, right_p) = (left.p_value(), right.p_value());
    let left_known = !left_p.is_null();
    let right_known = !right_p.is_null();

    // Result with a null success probability of the coerced data type.
    let unknown = || -> Result<BernoulliDistribution> {
        let dt = Distribution::target_type(&[left_p, right_p])?;
        BernoulliDistribution::try_new(ScalarValue::try_from(&dt)?)
    };

    match op {
        Operator::And => {
            if left_known && right_known {
                let product = left_p.mul_checked(right_p)?;
                BernoulliDistribution::try_new(product)
            } else if left_known
                && *left_p == ScalarValue::new_zero(&left_p.data_type())?
            {
                Ok(left.clone())
            } else if right_known
                && *right_p == ScalarValue::new_zero(&right_p.data_type())?
            {
                Ok(right.clone())
            } else {
                unknown()
            }
        }
        Operator::Or => {
            if left_known && right_known {
                let sum = left_p.add_checked(right_p)?;
                let product = left_p.mul_checked(right_p)?;
                BernoulliDistribution::try_new(sum.sub_checked(product)?)
            } else if left_known
                && *left_p == ScalarValue::new_one(&left_p.data_type())?
            {
                Ok(left.clone())
            } else if right_known
                && *right_p == ScalarValue::new_one(&right_p.data_type())?
            {
                Ok(right.clone())
            } else {
                unknown()
            }
        }
        _ => {
            not_impl_err!("Statistical evaluation only supports AND and OR operators")
        }
    }
}
/// Combines two Gaussian distributions under addition or subtraction.
///
/// The resulting mean is the sum (or difference) of the means and the
/// resulting variance is the sum of the variances in both cases. Returns
/// `Ok(None)` for every other operator.
pub fn combine_gaussians(
    op: &Operator,
    left: &GaussianDistribution,
    right: &GaussianDistribution,
) -> Result<Option<GaussianDistribution>> {
    let mean = match op {
        Operator::Plus => left.mean().add_checked(right.mean())?,
        Operator::Minus => left.mean().sub_checked(right.mean())?,
        _ => return Ok(None),
    };
    let variance = left.variance().add_checked(right.variance())?;
    let combined = GaussianDistribution::try_new(mean, variance)?;
    Ok(Some(combined))
}
pub fn create_bernoulli_from_comparison(
op: &Operator,
left: &Distribution,
right: &Distribution,
) -> Result<Distribution> {
match (left, right) {
(Uniform(left), Uniform(right)) => {
match op {
Operator::Eq | Operator::NotEq => {
let (li, ri) = (left.range(), right.range());
if let Some(intersection) = li.intersect(ri)? {
if let (Some(lc), Some(rc), Some(ic)) = (
li.cardinality(),
ri.cardinality(),
intersection.cardinality(),
) {
let pairs = ((lc as u128) * (rc as u128)) as f64;
let p = (ic as f64).div_checked(pairs)?;
let mut p_value = ScalarValue::from(p);
if op == &Operator::NotEq {
let one = ScalarValue::from(1.0);
p_value = alter_fp_rounding_mode::<false, _>(
&one,
&p_value,
|lhs, rhs| lhs.sub_checked(rhs),
)?;
};
return Distribution::new_bernoulli(p_value);
}
} else if op == &Operator::Eq {
return Distribution::new_bernoulli(ScalarValue::from(0.0));
} else {
return Distribution::new_bernoulli(ScalarValue::from(1.0));
}
}
Operator::Lt | Operator::LtEq | Operator::Gt | Operator::GtEq => {
}
_ => {}
}
}
(Gaussian(_), Gaussian(_)) => {
}
_ => {}
}
let (li, ri) = (left.range()?, right.range()?);
let range_evaluation = apply_operator(op, &li, &ri)?;
if range_evaluation.eq(&Interval::CERTAINLY_FALSE) {
Distribution::new_bernoulli(ScalarValue::from(0.0))
} else if range_evaluation.eq(&Interval::CERTAINLY_TRUE) {
Distribution::new_bernoulli(ScalarValue::from(1.0))
} else if range_evaluation.eq(&Interval::UNCERTAIN) {
Distribution::new_bernoulli(ScalarValue::try_from(&DataType::Float64)?)
} else {
internal_err!("This function must be called with a comparison operator")
}
}
/// Creates a generic distribution for the result of applying `op` to `left`
/// and `right`, using [`compute_mean`], [`compute_median`] and
/// [`compute_variance`] for the statistics and `apply_operator` on the
/// operand ranges for the range.
pub fn new_generic_from_binary_op(
    op: &Operator,
    left: &Distribution,
    right: &Distribution,
) -> Result<Distribution> {
    let mean = compute_mean(op, left, right)?;
    let median = compute_median(op, left, right)?;
    let variance = compute_variance(op, left, right)?;
    let range = apply_operator(op, &left.range()?, &right.range()?)?;
    Distribution::new_generic(mean, median, variance, range)
}
/// Computes the mean of `left op right`.
///
/// Addition, subtraction and multiplication combine the operand means
/// directly. For any other operator the result is the null value of the
/// coerced type of the two means.
pub fn compute_mean(
    op: &Operator,
    left: &Distribution,
    right: &Distribution,
) -> Result<ScalarValue> {
    let lhs = left.mean()?;
    let rhs = right.mean()?;
    match op {
        Operator::Plus => lhs.add_checked(rhs),
        Operator::Minus => lhs.sub_checked(rhs),
        Operator::Multiply => lhs.mul_checked(rhs),
        _ => {
            // No formula available: report a null of the appropriate type.
            let target_type = Distribution::target_type(&[&lhs, &rhs])?;
            ScalarValue::try_from(target_type)
        }
    }
}
/// Computes the median of `left op right`.
///
/// A value is produced only for addition/subtraction of two uniform
/// distributions (medians are added/subtracted) or two Gaussian distributions
/// (means are added/subtracted). Otherwise the result is the null value of
/// the coerced type of the two operand medians.
pub fn compute_median(
    op: &Operator,
    left: &Distribution,
    right: &Distribution,
) -> Result<ScalarValue> {
    match (left, right) {
        (Uniform(lhs), Uniform(rhs)) => {
            let lhs_median = lhs.median()?;
            let rhs_median = rhs.median()?;
            if *op == Operator::Plus {
                return lhs_median.add_checked(rhs_median);
            }
            if *op == Operator::Minus {
                return lhs_median.sub_checked(rhs_median);
            }
        }
        (Gaussian(lhs), Gaussian(rhs)) => {
            if *op == Operator::Plus {
                return lhs.mean().add_checked(rhs.mean());
            }
            if *op == Operator::Minus {
                return lhs.mean().sub_checked(rhs.mean());
            }
        }
        _ => {}
    }

    // No formula available: report a null of the appropriate type.
    let lhs_median = left.median()?;
    let rhs_median = right.median()?;
    let target_type = Distribution::target_type(&[&lhs_median, &rhs_median])?;
    ScalarValue::try_from(target_type)
}
/// Computes the variance of `left op right`.
///
/// Addition and subtraction both add the operand variances. Multiplication
/// uses `(Var[X] + E[X]²) * (Var[Y] + E[Y]²) - E[X]² * E[Y]²`. For any other
/// operator the result is the null value of the coerced type of the two
/// operand variances.
pub fn compute_variance(
    op: &Operator,
    left: &Distribution,
    right: &Distribution,
) -> Result<ScalarValue> {
    let lhs_variance = left.variance()?;
    let rhs_variance = right.variance()?;
    match op {
        Operator::Plus | Operator::Minus => lhs_variance.add_checked(rhs_variance),
        Operator::Multiply => {
            let lhs_mean = left.mean()?;
            let rhs_mean = right.mean()?;
            let lhs_mean_sq = lhs_mean.mul_checked(&lhs_mean)?;
            let rhs_mean_sq = rhs_mean.mul_checked(&rhs_mean)?;
            // Second moments: Var[X] + E[X]².
            let lhs_second_moment = lhs_variance.add_checked(&lhs_mean_sq)?;
            let rhs_second_moment = rhs_variance.add_checked(&rhs_mean_sq)?;
            let squared_mean_product = lhs_mean_sq.mul_checked(rhs_mean_sq)?;
            lhs_second_moment
                .mul_checked(rhs_second_moment)?
                .sub_checked(squared_mean_product)
        }
        _ => {
            // No formula available: report a null of the appropriate type.
            let target_type =
                Distribution::target_type(&[&lhs_variance, &rhs_variance])?;
            ScalarValue::try_from(target_type)
        }
    }
}
#[cfg(test)]
mod tests {
use super::{
combine_bernoullis, combine_gaussians, compute_mean, compute_median,
compute_variance, create_bernoulli_from_comparison, new_generic_from_binary_op,
BernoulliDistribution, Distribution, GaussianDistribution, UniformDistribution,
};
use crate::interval_arithmetic::{apply_operator, Interval};
use crate::operator::Operator;
use arrow::datatypes::DataType;
use datafusion_common::{HashSet, Result, ScalarValue};
#[test]
fn uniform_dist_is_valid_test() -> Result<()> {
    // A numeric interval is accepted and stored unchanged.
    let built = Distribution::new_uniform(Interval::make_zero(&DataType::Int8)?)?;
    let expected = Distribution::Uniform(UniformDistribution {
        interval: Interval::make_zero(&DataType::Int8)?,
    });
    assert_eq!(built, expected);

    // Boolean intervals are refused.
    let boolean_attempt = Distribution::new_uniform(Interval::UNCERTAIN);
    assert!(boolean_attempt.is_err());

    Ok(())
}
// Checks which parameter combinations the exponential constructor accepts.
// Each entry pairs a construction attempt with whether it should succeed.
#[test]
fn exponential_dist_is_valid_test() {
let exponentials = vec![
// Null rate and offset are rejected.
(
Distribution::new_exponential(ScalarValue::Null, ScalarValue::Null, true),
false,
),
// A rate of zero is rejected.
(
Distribution::new_exponential(
ScalarValue::from(0_f32),
ScalarValue::from(1_f32),
true,
),
false,
),
// A strictly positive rate is accepted.
(
Distribution::new_exponential(
ScalarValue::from(100_f32),
ScalarValue::from(1_f32),
true,
),
true,
),
// A negative rate is rejected.
(
Distribution::new_exponential(
ScalarValue::from(-100_f32),
ScalarValue::from(1_f32),
true,
),
false,
),
];
for case in exponentials {
assert_eq!(case.0.is_ok(), case.1);
}
}
// Checks which parameter combinations the Gaussian constructor accepts.
// Each entry pairs a construction attempt with whether it should succeed.
#[test]
fn gaussian_dist_is_valid_test() {
let gaussians = vec![
// A null variance is rejected.
(
Distribution::new_gaussian(ScalarValue::Null, ScalarValue::Null),
false,
),
// A variance of zero is accepted.
(
Distribution::new_gaussian(
ScalarValue::from(0_f32),
ScalarValue::from(0_f32),
),
true,
),
// A positive variance is accepted.
(
Distribution::new_gaussian(
ScalarValue::from(0_f32),
ScalarValue::from(0.5_f32),
),
true,
),
// A negative variance is rejected.
(
Distribution::new_gaussian(
ScalarValue::from(0_f32),
ScalarValue::from(-0.5_f32),
),
false,
),
];
for case in gaussians {
assert_eq!(case.0.is_ok(), case.1);
}
}
// Checks which success probabilities the Bernoulli constructor accepts:
// null and values within [0, 1] succeed, values outside that range fail.
// Both floating-point and integer probabilities are exercised.
#[test]
fn bernoulli_dist_is_valid_test() {
let bernoullis = vec![
(Distribution::new_bernoulli(ScalarValue::Null), true),
(Distribution::new_bernoulli(ScalarValue::from(0.)), true),
(Distribution::new_bernoulli(ScalarValue::from(0.25)), true),
(Distribution::new_bernoulli(ScalarValue::from(1.)), true),
(Distribution::new_bernoulli(ScalarValue::from(11.)), false),
(Distribution::new_bernoulli(ScalarValue::from(-11.)), false),
(Distribution::new_bernoulli(ScalarValue::from(0_i64)), true),
(Distribution::new_bernoulli(ScalarValue::from(1_i64)), true),
(
Distribution::new_bernoulli(ScalarValue::from(11_i64)),
false,
),
(
Distribution::new_bernoulli(ScalarValue::from(-11_i64)),
false,
),
];
for case in bernoullis {
assert_eq!(case.0.is_ok(), case.1);
}
}
// Checks which parameter combinations the generic constructor accepts.
// Each entry pairs a construction attempt with whether it should succeed.
#[test]
fn generic_dist_is_valid_test() -> Result<()> {
let generic_dists = vec![
// A boolean range is rejected.
(
Distribution::new_generic(
ScalarValue::Null,
ScalarValue::Null,
ScalarValue::Null,
Interval::UNCERTAIN,
),
false,
),
// All statistics null over a numeric range is accepted.
(
Distribution::new_generic(
ScalarValue::Null,
ScalarValue::Null,
ScalarValue::Null,
Interval::make_zero(&DataType::Float32)?,
),
true,
),
// Mean inside the range.
(
Distribution::new_generic(
ScalarValue::from(0_f32),
ScalarValue::Float32(None),
ScalarValue::Float32(None),
Interval::make_zero(&DataType::Float32)?,
),
true,
),
// Float64 median of zero against a Float32 zero range.
(
Distribution::new_generic(
ScalarValue::Float64(None),
ScalarValue::from(0.),
ScalarValue::Float64(None),
Interval::make_zero(&DataType::Float32)?,
),
true,
),
// Mean outside the range is rejected.
(
Distribution::new_generic(
ScalarValue::from(-10_f32),
ScalarValue::Float32(None),
ScalarValue::Float32(None),
Interval::make_zero(&DataType::Float32)?,
),
false,
),
// Median outside the range is rejected.
(
Distribution::new_generic(
ScalarValue::Float32(None),
ScalarValue::from(10_f32),
ScalarValue::Float32(None),
Interval::make_zero(&DataType::Float32)?,
),
false,
),
// Same inputs as the second case above.
(
Distribution::new_generic(
ScalarValue::Null,
ScalarValue::Null,
ScalarValue::Null,
Interval::make_zero(&DataType::Float32)?,
),
true,
),
// Integer mean and median inside an integer range.
(
Distribution::new_generic(
ScalarValue::from(0),
ScalarValue::from(0),
ScalarValue::Int32(None),
Interval::make_zero(&DataType::Int32)?,
),
true,
),
// Float32 mean and median inside a Float32 range.
(
Distribution::new_generic(
ScalarValue::from(0_f32),
ScalarValue::from(0_f32),
ScalarValue::Float32(None),
Interval::make_zero(&DataType::Float32)?,
),
true,
),
// Mean and median inside [0, 100].
(
Distribution::new_generic(
ScalarValue::from(50.),
ScalarValue::from(50.),
ScalarValue::Float64(None),
Interval::make(Some(0.), Some(100.))?,
),
true,
),
// Mean and median outside [-100, 0] are rejected.
(
Distribution::new_generic(
ScalarValue::from(50.),
ScalarValue::from(50.),
ScalarValue::Float64(None),
Interval::make(Some(-100.), Some(0.))?,
),
false,
),
// A positive variance is accepted.
(
Distribution::new_generic(
ScalarValue::Float64(None),
ScalarValue::Float64(None),
ScalarValue::from(1.),
Interval::make_zero(&DataType::Float64)?,
),
true,
),
// A negative variance is rejected.
(
Distribution::new_generic(
ScalarValue::Float64(None),
ScalarValue::Float64(None),
ScalarValue::from(-1.),
Interval::make_zero(&DataType::Float64)?,
),
false,
),
];
for case in generic_dists {
// The construction result is printed on failure to identify the case.
assert_eq!(case.0.is_ok(), case.1, "{:?}", case.0);
}
Ok(())
}
// Checks `Distribution::mean` for every distribution kind. Each entry pairs
// a distribution with its expected mean.
#[test]
fn mean_extraction_test() -> Result<()> {
let dists = vec![
// Uniform: mean is (lower + upper) / 2.
(
Distribution::new_uniform(Interval::make_zero(&DataType::Int64)?),
ScalarValue::from(0_i64),
),
(
Distribution::new_uniform(Interval::make_zero(&DataType::Float64)?),
ScalarValue::from(0.),
),
// Integer bounds: (1 + 100) / 2 is expected to be 50.
(
Distribution::new_uniform(Interval::make(Some(1), Some(100))?),
ScalarValue::from(50),
),
// Integer bounds: (-100 + -1) / 2 is expected to be -50.
(
Distribution::new_uniform(Interval::make(Some(-100), Some(-1))?),
ScalarValue::from(-50),
),
(
Distribution::new_uniform(Interval::make(Some(-100), Some(100))?),
ScalarValue::from(0),
),
// Exponential with positive tail: mean is offset + 1 / rate.
(
Distribution::new_exponential(
ScalarValue::from(2.),
ScalarValue::from(0.),
true,
),
ScalarValue::from(0.5),
),
(
Distribution::new_exponential(
ScalarValue::from(2.),
ScalarValue::from(1.),
true,
),
ScalarValue::from(1.5),
),
// Gaussian: mean is the stored parameter.
(
Distribution::new_gaussian(ScalarValue::from(0.), ScalarValue::from(1.)),
ScalarValue::from(0.),
),
(
Distribution::new_gaussian(
ScalarValue::from(-2.),
ScalarValue::from(0.5),
),
ScalarValue::from(-2.),
),
// Bernoulli: mean is the success probability.
(
Distribution::new_bernoulli(ScalarValue::from(0.5)),
ScalarValue::from(0.5),
),
// Generic: mean is the stored value.
(
Distribution::new_generic(
ScalarValue::from(42.),
ScalarValue::from(42.),
ScalarValue::Float64(None),
Interval::make(Some(25.), Some(50.))?,
),
ScalarValue::from(42.),
),
];
for case in dists {
assert_eq!(case.0?.mean()?, case.1);
}
Ok(())
}
// Checks `Distribution::median` for every distribution kind. Each entry
// pairs a distribution with its expected median.
#[test]
fn median_extraction_test() -> Result<()> {
let dists = vec![
// Uniform: median equals the mean.
(
Distribution::new_uniform(Interval::make_zero(&DataType::Int64)?),
ScalarValue::from(0_i64),
),
(
Distribution::new_uniform(Interval::make(Some(25.), Some(75.))?),
ScalarValue::from(50.),
),
// Exponential: offset + ln(2) / rate, so a rate of ln(2) gives 1.
(
Distribution::new_exponential(
ScalarValue::from(2_f64.ln()),
ScalarValue::from(0.),
true,
),
ScalarValue::from(1.),
),
// Gaussian: median equals the mean.
(
Distribution::new_gaussian(ScalarValue::from(2.), ScalarValue::from(1.)),
ScalarValue::from(2.),
),
// Bernoulli: zero when 1 - p >= p, one when 1 - p < p.
(
Distribution::new_bernoulli(ScalarValue::from(0.25)),
ScalarValue::from(0.),
),
(
Distribution::new_bernoulli(ScalarValue::from(0.75)),
ScalarValue::from(1.),
),
// Same Gaussian inputs as the earlier Gaussian case.
(
Distribution::new_gaussian(ScalarValue::from(2.), ScalarValue::from(1.)),
ScalarValue::from(2.),
),
// Generic: median is the stored value.
(
Distribution::new_generic(
ScalarValue::from(12.),
ScalarValue::from(12.),
ScalarValue::Float64(None),
Interval::make(Some(0.), Some(25.))?,
),
ScalarValue::from(12.),
),
];
for case in dists {
assert_eq!(case.0?.median()?, case.1);
}
Ok(())
}
// Checks `Distribution::variance` for every distribution kind. Each entry
// pairs a distribution with its expected variance.
#[test]
fn variance_extraction_test() -> Result<()> {
let dists = vec![
// Uniform: width² / 12 = 12² / 12 = 12.
(
Distribution::new_uniform(Interval::make(Some(0.), Some(12.))?),
ScalarValue::from(12.),
),
// Exponential: 1 / rate² = 1 / 100.
(
Distribution::new_exponential(
ScalarValue::from(10.),
ScalarValue::from(0.),
true,
),
ScalarValue::from(0.01),
),
// Gaussian: variance is the stored parameter.
(
Distribution::new_gaussian(ScalarValue::from(0.), ScalarValue::from(1.)),
ScalarValue::from(1.),
),
// Bernoulli: (1 - p) * p = 0.25.
(
Distribution::new_bernoulli(ScalarValue::from(0.5)),
ScalarValue::from(0.25),
),
// Generic: variance is the stored value.
(
Distribution::new_generic(
ScalarValue::Float64(None),
ScalarValue::Float64(None),
ScalarValue::from(0.02),
Interval::make_zero(&DataType::Float64)?,
),
ScalarValue::from(0.02),
),
];
for case in dists {
assert_eq!(case.0?.variance()?, case.1);
}
Ok(())
}
#[test]
fn test_calculate_generic_properties_gauss_gauss() -> Result<()> {
    // Two zero-variance Gaussians with means 10 and 20.
    let gaussian_with_mean = |mean: f64| {
        Distribution::new_gaussian(ScalarValue::from(mean), ScalarValue::from(0.0))
    };
    let dist_a = gaussian_with_mean(10.)?;
    let dist_b = gaussian_with_mean(20.)?;

    let mean_sum = compute_mean(&Operator::Plus, &dist_a, &dist_b)?;
    let mean_diff = compute_mean(&Operator::Minus, &dist_a, &dist_b)?;
    let median_sum = compute_median(&Operator::Plus, &dist_a, &dist_b)?;
    let median_diff = compute_median(&Operator::Minus, &dist_a, &dist_b)?;

    assert_eq!(mean_sum, ScalarValue::from(30.));
    assert_eq!(mean_diff, ScalarValue::from(-10.));
    assert_eq!(median_sum, ScalarValue::from(30.));
    assert_eq!(median_diff, ScalarValue::from(-10.));
    Ok(())
}
#[test]
fn test_combine_bernoullis_and_op() -> Result<()> {
    let op = Operator::And;
    let left = BernoulliDistribution::try_new(ScalarValue::from(0.5))?;
    let right = BernoulliDistribution::try_new(ScalarValue::from(0.4))?;
    let left_null = BernoulliDistribution::try_new(ScalarValue::Null)?;
    let right_null = BernoulliDistribution::try_new(ScalarValue::Null)?;

    // Success probability of the combined distribution.
    let combined_p = |lhs: &BernoulliDistribution,
                      rhs: &BernoulliDistribution|
     -> Result<ScalarValue> {
        Ok(combine_bernoullis(&op, lhs, rhs)?.p_value().clone())
    };

    // Both known: probabilities multiply.
    assert_eq!(combined_p(&left, &right)?, ScalarValue::from(0.5 * 0.4));
    // One side unknown and the other non-zero: typed null result.
    assert_eq!(combined_p(&left_null, &right)?, ScalarValue::Float64(None));
    assert_eq!(combined_p(&left, &right_null)?, ScalarValue::Float64(None));
    // Both unknown: untyped null result.
    assert_eq!(combined_p(&left_null, &left_null)?, ScalarValue::Null);
    Ok(())
}
#[test]
fn test_combine_bernoullis_or_op() -> Result<()> {
    let op = Operator::Or;
    let left = BernoulliDistribution::try_new(ScalarValue::from(0.6))?;
    let right = BernoulliDistribution::try_new(ScalarValue::from(0.4))?;
    let left_null = BernoulliDistribution::try_new(ScalarValue::Null)?;
    let right_null = BernoulliDistribution::try_new(ScalarValue::Null)?;

    // Success probability of the combined distribution.
    let combined_p = |lhs: &BernoulliDistribution,
                      rhs: &BernoulliDistribution|
     -> Result<ScalarValue> {
        Ok(combine_bernoullis(&op, lhs, rhs)?.p_value().clone())
    };

    // Both known: p + q - p * q.
    assert_eq!(
        combined_p(&left, &right)?,
        ScalarValue::from(0.6 + 0.4 - (0.6 * 0.4))
    );
    // One side unknown and the other not equal to one: typed null result.
    assert_eq!(combined_p(&left_null, &right)?, ScalarValue::Float64(None));
    assert_eq!(combined_p(&left, &right_null)?, ScalarValue::Float64(None));
    // Both unknown: untyped null result.
    assert_eq!(combined_p(&left_null, &left_null)?, ScalarValue::Null);
    Ok(())
}
#[test]
fn test_combine_bernoullis_unsupported_ops() -> Result<()> {
    let left = BernoulliDistribution::try_new(ScalarValue::from(0.6))?;
    let right = BernoulliDistribution::try_new(ScalarValue::from(0.4))?;

    // Every operator except AND and OR must be refused.
    let unsupported = operator_set()
        .into_iter()
        .filter(|op| !matches!(op, Operator::And | Operator::Or));
    for op in unsupported {
        let outcome = combine_bernoullis(&op, &left, &right);
        assert!(
            outcome.is_err(),
            "Operator {op} should not be supported for Bernoulli distributions"
        );
    }
    Ok(())
}
#[test]
fn test_combine_gaussians_addition() -> Result<()> {
    let left =
        GaussianDistribution::try_new(ScalarValue::from(3.0), ScalarValue::from(2.0))?;
    let right =
        GaussianDistribution::try_new(ScalarValue::from(4.0), ScalarValue::from(1.0))?;

    let sum = combine_gaussians(&Operator::Plus, &left, &right)?.unwrap();

    // Means add: 3.0 + 4.0.
    assert_eq!(sum.mean(), &ScalarValue::from(7.0));
    // Variances add: 2.0 + 1.0.
    assert_eq!(sum.variance(), &ScalarValue::from(3.0));
    Ok(())
}
#[test]
fn test_combine_gaussians_subtraction() -> Result<()> {
    let left =
        GaussianDistribution::try_new(ScalarValue::from(7.0), ScalarValue::from(2.0))?;
    let right =
        GaussianDistribution::try_new(ScalarValue::from(4.0), ScalarValue::from(1.0))?;

    let difference = combine_gaussians(&Operator::Minus, &left, &right)?.unwrap();

    // Means subtract: 7.0 - 4.0.
    assert_eq!(difference.mean(), &ScalarValue::from(3.0));
    // Variances still add: 2.0 + 1.0.
    assert_eq!(difference.variance(), &ScalarValue::from(3.0));
    Ok(())
}
#[test]
fn test_combine_gaussians_unsupported_ops() -> Result<()> {
    let left =
        GaussianDistribution::try_new(ScalarValue::from(7.0), ScalarValue::from(2.0))?;
    let right =
        GaussianDistribution::try_new(ScalarValue::from(4.0), ScalarValue::from(1.0))?;

    // Every operator except addition and subtraction must yield `None`.
    let unsupported = operator_set()
        .into_iter()
        .filter(|op| !matches!(op, Operator::Plus | Operator::Minus));
    for op in unsupported {
        let combined = combine_gaussians(&op, &left, &right)?;
        assert!(
            combined.is_none(),
            "Operator {op} should not be supported for Gaussian distributions"
        );
    }
    Ok(())
}
// Checks mean, median and variance computations for binary operations on two
// uniform distributions: `a` over [0, 12] (mean 6, variance 12) and `b` over
// [12, 36] (mean 24, variance 48).
#[test]
fn test_calculate_generic_properties_uniform_uniform() -> Result<()> {
let dist_a = Distribution::new_uniform(Interval::make(Some(0.), Some(12.))?)?;
let dist_b = Distribution::new_uniform(Interval::make(Some(12.), Some(36.))?)?;
let test_data = vec![
// Means: 6 + 24, 6 - 24, 6 * 24.
(
compute_mean(&Operator::Plus, &dist_a, &dist_b)?,
ScalarValue::from(30.),
),
(
compute_mean(&Operator::Minus, &dist_a, &dist_b)?,
ScalarValue::from(-18.),
),
(
compute_mean(&Operator::Multiply, &dist_a, &dist_b)?,
ScalarValue::from(144.),
),
// Medians: 6 + 24, 6 - 24.
(
compute_median(&Operator::Plus, &dist_a, &dist_b)?,
ScalarValue::from(30.),
),
(
compute_median(&Operator::Minus, &dist_a, &dist_b)?,
ScalarValue::from(-18.),
),
// Variances add for both addition and subtraction: 12 + 48.
(
compute_variance(&Operator::Plus, &dist_a, &dist_b)?,
ScalarValue::from(60.),
),
(
compute_variance(&Operator::Minus, &dist_a, &dist_b)?,
ScalarValue::from(60.),
),
// Product: (12 + 36) * (48 + 576) - 36 * 576 = 9216.
(
compute_variance(&Operator::Multiply, &dist_a, &dist_b)?,
ScalarValue::from(9216.),
),
];
for (actual, expected) in test_data {
assert_eq!(actual, expected);
}
Ok(())
}
// Checks that the range of a derived distribution equals the result of
// applying the operator directly to the operand intervals, for every pairing
// of uniform and generic operands over [0, 12].
#[test]
fn test_compute_range_where_present() -> Result<()> {
let a = &Interval::make(Some(0.), Some(12.0))?;
let b = &Interval::make(Some(0.), Some(12.0))?;
let mean = ScalarValue::from(6.0);
for (dist_a, dist_b) in [
// uniform, uniform
(
Distribution::new_uniform(a.clone())?,
Distribution::new_uniform(b.clone())?,
),
// generic, uniform
(
Distribution::new_generic(
mean.clone(),
mean.clone(),
ScalarValue::Float64(None),
a.clone(),
)?,
Distribution::new_uniform(b.clone())?,
),
// uniform, generic
(
Distribution::new_uniform(a.clone())?,
Distribution::new_generic(
mean.clone(),
mean.clone(),
ScalarValue::Float64(None),
b.clone(),
)?,
),
// generic, generic
(
Distribution::new_generic(
mean.clone(),
mean.clone(),
ScalarValue::Float64(None),
a.clone(),
)?,
Distribution::new_generic(
mean.clone(),
mean.clone(),
ScalarValue::Float64(None),
b.clone(),
)?,
),
] {
use super::Operator::{
Divide, Eq, Gt, GtEq, Lt, LtEq, Minus, Multiply, NotEq, Plus,
};
// Arithmetic operators go through the generic-distribution builder.
for op in [Plus, Minus, Multiply, Divide] {
assert_eq!(
new_generic_from_binary_op(&op, &dist_a, &dist_b)?.range()?,
apply_operator(&op, a, b)?,
"Failed for {dist_a:?} {op} {dist_b:?}"
);
}
// Comparison operators go through the Bernoulli builder.
for op in [Gt, GtEq, Lt, LtEq, Eq, NotEq] {
assert_eq!(
create_bernoulli_from_comparison(&op, &dist_a, &dist_b)?.range()?,
apply_operator(&op, a, b)?,
"Failed for {dist_a:?} {op} {dist_b:?}"
);
}
}
Ok(())
}
/// Returns the set of operators used by the "unsupported operator" tests.
fn operator_set() -> HashSet<Operator> {
    use super::Operator::*;

    [
        // Logical
        And,
        Or,
        // Comparison
        Eq,
        NotEq,
        Gt,
        GtEq,
        Lt,
        LtEq,
        // Arithmetic
        Plus,
        Minus,
        Multiply,
        Divide,
        Modulo,
        // Distinctness
        IsDistinctFrom,
        IsNotDistinctFrom,
        // Pattern matching
        RegexMatch,
        RegexIMatch,
        RegexNotMatch,
        RegexNotIMatch,
        LikeMatch,
        ILikeMatch,
        NotLikeMatch,
        NotILikeMatch,
        // Bitwise
        BitwiseAnd,
        BitwiseOr,
        BitwiseXor,
        BitwiseShiftRight,
        BitwiseShiftLeft,
        // Miscellaneous
        StringConcat,
        AtArrow,
        ArrowAt,
    ]
    .into_iter()
    .collect()
}
}