1use std::f64::consts::LN_2;
19
20use crate::interval_arithmetic::{apply_operator, Interval};
21use crate::operator::Operator;
22use crate::type_coercion::binary::binary_numeric_coercion;
23
24use arrow::array::ArrowNativeTypeOp;
25use arrow::datatypes::DataType;
26use datafusion_common::rounding::alter_fp_rounding_mode;
27use datafusion_common::{internal_err, not_impl_err, Result, ScalarValue};
28
29#[derive(Clone, Debug, PartialEq)]
38pub enum Distribution {
39 Uniform(UniformDistribution),
40 Exponential(ExponentialDistribution),
41 Gaussian(GaussianDistribution),
42 Bernoulli(BernoulliDistribution),
43 Generic(GenericDistribution),
44}
45
46use Distribution::{Bernoulli, Exponential, Gaussian, Generic, Uniform};
47
48impl Distribution {
49 pub fn new_uniform(interval: Interval) -> Result<Self> {
51 UniformDistribution::try_new(interval).map(Uniform)
52 }
53
54 pub fn new_exponential(
57 rate: ScalarValue,
58 offset: ScalarValue,
59 positive_tail: bool,
60 ) -> Result<Self> {
61 ExponentialDistribution::try_new(rate, offset, positive_tail).map(Exponential)
62 }
63
64 pub fn new_gaussian(mean: ScalarValue, variance: ScalarValue) -> Result<Self> {
67 GaussianDistribution::try_new(mean, variance).map(Gaussian)
68 }
69
70 pub fn new_bernoulli(p: ScalarValue) -> Result<Self> {
73 BernoulliDistribution::try_new(p).map(Bernoulli)
74 }
75
76 pub fn new_generic(
79 mean: ScalarValue,
80 median: ScalarValue,
81 variance: ScalarValue,
82 range: Interval,
83 ) -> Result<Self> {
84 GenericDistribution::try_new(mean, median, variance, range).map(Generic)
85 }
86
87 pub fn new_from_interval(range: Interval) -> Result<Self> {
90 let null = ScalarValue::try_from(range.data_type())?;
91 Distribution::new_generic(null.clone(), null.clone(), null, range)
92 }
93
94 pub fn mean(&self) -> Result<ScalarValue> {
105 match &self {
106 Uniform(u) => u.mean(),
107 Exponential(e) => e.mean(),
108 Gaussian(g) => Ok(g.mean().clone()),
109 Bernoulli(b) => Ok(b.mean().clone()),
110 Generic(u) => Ok(u.mean().clone()),
111 }
112 }
113
114 pub fn median(&self) -> Result<ScalarValue> {
127 match &self {
128 Uniform(u) => u.median(),
129 Exponential(e) => e.median(),
130 Gaussian(g) => Ok(g.median().clone()),
131 Bernoulli(b) => b.median(),
132 Generic(u) => Ok(u.median().clone()),
133 }
134 }
135
136 pub fn variance(&self) -> Result<ScalarValue> {
148 match &self {
149 Uniform(u) => u.variance(),
150 Exponential(e) => e.variance(),
151 Gaussian(g) => Ok(g.variance.clone()),
152 Bernoulli(b) => b.variance(),
153 Generic(u) => Ok(u.variance.clone()),
154 }
155 }
156
157 pub fn range(&self) -> Result<Interval> {
168 match &self {
169 Uniform(u) => Ok(u.range().clone()),
170 Exponential(e) => e.range(),
171 Gaussian(g) => g.range(),
172 Bernoulli(b) => Ok(b.range()),
173 Generic(u) => Ok(u.range().clone()),
174 }
175 }
176
177 pub fn data_type(&self) -> DataType {
180 match &self {
181 Uniform(u) => u.data_type(),
182 Exponential(e) => e.data_type(),
183 Gaussian(g) => g.data_type(),
184 Bernoulli(b) => b.data_type(),
185 Generic(u) => u.data_type(),
186 }
187 }
188
189 pub fn target_type(args: &[&ScalarValue]) -> Result<DataType> {
190 let mut arg_types = args
191 .iter()
192 .filter(|&&arg| (arg != &ScalarValue::Null))
193 .map(|&arg| arg.data_type());
194
195 let Some(dt) = arg_types.next().map_or_else(
196 || Some(DataType::Null),
197 |first| {
198 arg_types
199 .try_fold(first, |target, arg| binary_numeric_coercion(&target, &arg))
200 },
201 ) else {
202 return internal_err!("Can only evaluate statistics for numeric types");
203 };
204 Ok(dt)
205 }
206}
207
208#[derive(Clone, Debug, PartialEq)]
215pub struct UniformDistribution {
216 interval: Interval,
217}
218
219#[derive(Clone, Debug, PartialEq)]
237pub struct ExponentialDistribution {
238 rate: ScalarValue,
239 offset: ScalarValue,
240 positive_tail: bool,
243}
244
245#[derive(Clone, Debug, PartialEq)]
250pub struct GaussianDistribution {
251 mean: ScalarValue,
252 variance: ScalarValue,
253}
254
255#[derive(Clone, Debug, PartialEq)]
260pub struct BernoulliDistribution {
261 p: ScalarValue,
262}
263
264#[derive(Clone, Debug, PartialEq)]
269pub struct GenericDistribution {
270 mean: ScalarValue,
271 median: ScalarValue,
272 variance: ScalarValue,
273 range: Interval,
274}
275
276impl UniformDistribution {
277 fn try_new(interval: Interval) -> Result<Self> {
278 if interval.data_type().eq(&DataType::Boolean) {
279 return internal_err!(
280 "Construction of a boolean `Uniform` distribution is prohibited, create a `Bernoulli` distribution instead."
281 );
282 }
283
284 Ok(Self { interval })
285 }
286
287 pub fn data_type(&self) -> DataType {
288 self.interval.data_type()
289 }
290
291 pub fn mean(&self) -> Result<ScalarValue> {
295 let dt = self.data_type();
297 let two = ScalarValue::from(2).cast_to(&dt)?;
298 let result = self
299 .interval
300 .lower()
301 .add_checked(self.interval.upper())?
302 .div(two);
303 debug_assert!(
304 !self.interval.is_unbounded() || result.as_ref().is_ok_and(|r| r.is_null())
305 );
306 result
307 }
308
309 pub fn median(&self) -> Result<ScalarValue> {
310 self.mean()
311 }
312
313 pub fn variance(&self) -> Result<ScalarValue> {
317 let width = self.interval.width()?;
319 let dt = width.data_type();
320 let twelve = ScalarValue::from(12).cast_to(&dt)?;
321 let result = width.mul_checked(&width)?.div(twelve);
322 debug_assert!(
323 !self.interval.is_unbounded() || result.as_ref().is_ok_and(|r| r.is_null())
324 );
325 result
326 }
327
328 pub fn range(&self) -> &Interval {
329 &self.interval
330 }
331}
332
333impl ExponentialDistribution {
334 fn try_new(
335 rate: ScalarValue,
336 offset: ScalarValue,
337 positive_tail: bool,
338 ) -> Result<Self> {
339 let dt = rate.data_type();
340 if offset.data_type() != dt {
341 internal_err!("Rate and offset must have the same data type")
342 } else if offset.is_null() {
343 internal_err!("Offset of an `ExponentialDistribution` cannot be null")
344 } else if rate.is_null() {
345 internal_err!("Rate of an `ExponentialDistribution` cannot be null")
346 } else if rate.le(&ScalarValue::new_zero(&dt)?) {
347 internal_err!("Rate of an `ExponentialDistribution` must be positive")
348 } else {
349 Ok(Self {
350 rate,
351 offset,
352 positive_tail,
353 })
354 }
355 }
356
357 pub fn data_type(&self) -> DataType {
358 self.rate.data_type()
359 }
360
361 pub fn rate(&self) -> &ScalarValue {
362 &self.rate
363 }
364
365 pub fn offset(&self) -> &ScalarValue {
366 &self.offset
367 }
368
369 pub fn positive_tail(&self) -> bool {
370 self.positive_tail
371 }
372
373 pub fn mean(&self) -> Result<ScalarValue> {
374 let one = ScalarValue::new_one(&self.data_type())?;
376 let tail_mean = one.div(&self.rate)?;
377 if self.positive_tail {
378 self.offset.add_checked(tail_mean)
379 } else {
380 self.offset.sub_checked(tail_mean)
381 }
382 }
383
384 pub fn median(&self) -> Result<ScalarValue> {
385 let ln_two = ScalarValue::from(LN_2).cast_to(&self.data_type())?;
387 let tail_median = ln_two.div(&self.rate)?;
388 if self.positive_tail {
389 self.offset.add_checked(tail_median)
390 } else {
391 self.offset.sub_checked(tail_median)
392 }
393 }
394
395 pub fn variance(&self) -> Result<ScalarValue> {
396 let one = ScalarValue::new_one(&self.data_type())?;
398 let rate_squared = self.rate.mul_checked(&self.rate)?;
399 one.div(rate_squared)
400 }
401
402 pub fn range(&self) -> Result<Interval> {
403 let end = ScalarValue::try_from(&self.data_type())?;
404 if self.positive_tail {
405 Interval::try_new(self.offset.clone(), end)
406 } else {
407 Interval::try_new(end, self.offset.clone())
408 }
409 }
410}
411
412impl GaussianDistribution {
413 fn try_new(mean: ScalarValue, variance: ScalarValue) -> Result<Self> {
414 let dt = mean.data_type();
415 if variance.data_type() != dt {
416 internal_err!("Mean and variance must have the same data type")
417 } else if variance.is_null() {
418 internal_err!("Variance of a `GaussianDistribution` cannot be null")
419 } else if variance.lt(&ScalarValue::new_zero(&dt)?) {
420 internal_err!("Variance of a `GaussianDistribution` must be positive")
421 } else {
422 Ok(Self { mean, variance })
423 }
424 }
425
426 pub fn data_type(&self) -> DataType {
427 self.mean.data_type()
428 }
429
430 pub fn mean(&self) -> &ScalarValue {
431 &self.mean
432 }
433
434 pub fn variance(&self) -> &ScalarValue {
435 &self.variance
436 }
437
438 pub fn median(&self) -> &ScalarValue {
439 self.mean()
440 }
441
442 pub fn range(&self) -> Result<Interval> {
443 Interval::make_unbounded(&self.data_type())
444 }
445}
446
447impl BernoulliDistribution {
448 fn try_new(p: ScalarValue) -> Result<Self> {
449 if p.is_null() {
450 Ok(Self { p })
451 } else {
452 let dt = p.data_type();
453 let zero = ScalarValue::new_zero(&dt)?;
454 let one = ScalarValue::new_one(&dt)?;
455 if p.ge(&zero) && p.le(&one) {
456 Ok(Self { p })
457 } else {
458 internal_err!(
459 "Success probability of a `BernoulliDistribution` must be in [0, 1]"
460 )
461 }
462 }
463 }
464
465 pub fn data_type(&self) -> DataType {
466 self.p.data_type()
467 }
468
469 pub fn p_value(&self) -> &ScalarValue {
470 &self.p
471 }
472
473 pub fn mean(&self) -> &ScalarValue {
474 &self.p
475 }
476
477 pub fn median(&self) -> Result<ScalarValue> {
480 let dt = self.data_type();
481 if self.p.is_null() {
482 ScalarValue::try_from(&dt)
483 } else {
484 let one = ScalarValue::new_one(&dt)?;
485 if one.sub_checked(&self.p)?.lt(&self.p) {
486 ScalarValue::new_one(&dt)
487 } else {
488 ScalarValue::new_zero(&dt)
489 }
490 }
491 }
492
493 pub fn variance(&self) -> Result<ScalarValue> {
496 let dt = self.data_type();
497 let one = ScalarValue::new_one(&dt)?;
498 let result = one.sub_checked(&self.p)?.mul_checked(&self.p);
499 debug_assert!(!self.p.is_null() || result.as_ref().is_ok_and(|r| r.is_null()));
500 result
501 }
502
503 pub fn range(&self) -> Interval {
504 let dt = self.data_type();
505 if ScalarValue::new_zero(&dt).unwrap().eq(&self.p) {
508 Interval::CERTAINLY_FALSE
509 } else if ScalarValue::new_one(&dt).unwrap().eq(&self.p) {
510 Interval::CERTAINLY_TRUE
511 } else {
512 Interval::UNCERTAIN
513 }
514 }
515}
516
517impl GenericDistribution {
518 fn try_new(
519 mean: ScalarValue,
520 median: ScalarValue,
521 variance: ScalarValue,
522 range: Interval,
523 ) -> Result<Self> {
524 if range.data_type().eq(&DataType::Boolean) {
525 return internal_err!(
526 "Construction of a boolean `Generic` distribution is prohibited, create a `Bernoulli` distribution instead."
527 );
528 }
529
530 let validate_location = |m: &ScalarValue| -> Result<bool> {
531 if m.is_null() {
533 Ok(true)
534 } else {
535 range.contains_value(m)
536 }
537 };
538
539 if !validate_location(&mean)?
540 || !validate_location(&median)?
541 || (!variance.is_null()
542 && variance.lt(&ScalarValue::new_zero(&variance.data_type())?))
543 {
544 internal_err!("Tried to construct an invalid `GenericDistribution` instance")
545 } else {
546 Ok(Self {
547 mean,
548 median,
549 variance,
550 range,
551 })
552 }
553 }
554
555 pub fn data_type(&self) -> DataType {
556 self.mean.data_type()
557 }
558
559 pub fn mean(&self) -> &ScalarValue {
560 &self.mean
561 }
562
563 pub fn median(&self) -> &ScalarValue {
564 &self.median
565 }
566
567 pub fn variance(&self) -> &ScalarValue {
568 &self.variance
569 }
570
571 pub fn range(&self) -> &Interval {
572 &self.range
573 }
574}
575
576pub fn combine_bernoullis(
580 op: &Operator,
581 left: &BernoulliDistribution,
582 right: &BernoulliDistribution,
583) -> Result<BernoulliDistribution> {
584 let left_p = left.p_value();
585 let right_p = right.p_value();
586 match op {
587 Operator::And => match (left_p.is_null(), right_p.is_null()) {
588 (false, false) => {
589 BernoulliDistribution::try_new(left_p.mul_checked(right_p)?)
590 }
591 (false, true) if left_p.eq(&ScalarValue::new_zero(&left_p.data_type())?) => {
592 Ok(left.clone())
593 }
594 (true, false)
595 if right_p.eq(&ScalarValue::new_zero(&right_p.data_type())?) =>
596 {
597 Ok(right.clone())
598 }
599 _ => {
600 let dt = Distribution::target_type(&[left_p, right_p])?;
601 BernoulliDistribution::try_new(ScalarValue::try_from(&dt)?)
602 }
603 },
604 Operator::Or => match (left_p.is_null(), right_p.is_null()) {
605 (false, false) => {
606 let sum = left_p.add_checked(right_p)?;
607 let product = left_p.mul_checked(right_p)?;
608 let or_success = sum.sub_checked(product)?;
609 BernoulliDistribution::try_new(or_success)
610 }
611 (false, true) if left_p.eq(&ScalarValue::new_one(&left_p.data_type())?) => {
612 Ok(left.clone())
613 }
614 (true, false) if right_p.eq(&ScalarValue::new_one(&right_p.data_type())?) => {
615 Ok(right.clone())
616 }
617 _ => {
618 let dt = Distribution::target_type(&[left_p, right_p])?;
619 BernoulliDistribution::try_new(ScalarValue::try_from(&dt)?)
620 }
621 },
622 _ => {
623 not_impl_err!("Statistical evaluation only supports AND and OR operators")
624 }
625 }
626}
627
628pub fn combine_gaussians(
635 op: &Operator,
636 left: &GaussianDistribution,
637 right: &GaussianDistribution,
638) -> Result<Option<GaussianDistribution>> {
639 match op {
640 Operator::Plus => GaussianDistribution::try_new(
641 left.mean().add_checked(right.mean())?,
642 left.variance().add_checked(right.variance())?,
643 )
644 .map(Some),
645 Operator::Minus => GaussianDistribution::try_new(
646 left.mean().sub_checked(right.mean())?,
647 left.variance().add_checked(right.variance())?,
648 )
649 .map(Some),
650 _ => Ok(None),
651 }
652}
653
654pub fn create_bernoulli_from_comparison(
659 op: &Operator,
660 left: &Distribution,
661 right: &Distribution,
662) -> Result<Distribution> {
663 match (left, right) {
664 (Uniform(left), Uniform(right)) => {
665 match op {
666 Operator::Eq | Operator::NotEq => {
667 let (li, ri) = (left.range(), right.range());
668 if let Some(intersection) = li.intersect(ri)? {
669 if let (Some(lc), Some(rc), Some(ic)) = (
672 li.cardinality(),
673 ri.cardinality(),
674 intersection.cardinality(),
675 ) {
676 let pairs = ((lc as u128) * (rc as u128)) as f64;
678 let p = (ic as f64).div_checked(pairs)?;
679 let mut p_value = ScalarValue::from(p);
685 if op == &Operator::NotEq {
686 let one = ScalarValue::from(1.0);
687 p_value = alter_fp_rounding_mode::<false, _>(
688 &one,
689 &p_value,
690 |lhs, rhs| lhs.sub_checked(rhs),
691 )?;
692 };
693 return Distribution::new_bernoulli(p_value);
694 }
695 } else if op == &Operator::Eq {
696 return Distribution::new_bernoulli(ScalarValue::from(0.0));
698 } else {
699 return Distribution::new_bernoulli(ScalarValue::from(1.0));
701 }
702 }
703 Operator::Lt | Operator::LtEq | Operator::Gt | Operator::GtEq => {
704 }
710 _ => {}
711 }
712 }
713 (Gaussian(_), Gaussian(_)) => {
714 }
717 _ => {}
718 }
719 let (li, ri) = (left.range()?, right.range()?);
720 let range_evaluation = apply_operator(op, &li, &ri)?;
721 if range_evaluation.eq(&Interval::CERTAINLY_FALSE) {
722 Distribution::new_bernoulli(ScalarValue::from(0.0))
723 } else if range_evaluation.eq(&Interval::CERTAINLY_TRUE) {
724 Distribution::new_bernoulli(ScalarValue::from(1.0))
725 } else if range_evaluation.eq(&Interval::UNCERTAIN) {
726 Distribution::new_bernoulli(ScalarValue::try_from(&DataType::Float64)?)
727 } else {
728 internal_err!("This function must be called with a comparison operator")
729 }
730}
731
732pub fn new_generic_from_binary_op(
737 op: &Operator,
738 left: &Distribution,
739 right: &Distribution,
740) -> Result<Distribution> {
741 Distribution::new_generic(
742 compute_mean(op, left, right)?,
743 compute_median(op, left, right)?,
744 compute_variance(op, left, right)?,
745 apply_operator(op, &left.range()?, &right.range()?)?,
746 )
747}
748
749pub fn compute_mean(
752 op: &Operator,
753 left: &Distribution,
754 right: &Distribution,
755) -> Result<ScalarValue> {
756 let (left_mean, right_mean) = (left.mean()?, right.mean()?);
757
758 match op {
759 Operator::Plus => return left_mean.add_checked(right_mean),
760 Operator::Minus => return left_mean.sub_checked(right_mean),
761 Operator::Multiply => return left_mean.mul_checked(right_mean),
763 Operator::Divide => {}
771 _ => {}
773 }
774 let target_type = Distribution::target_type(&[&left_mean, &right_mean])?;
775 ScalarValue::try_from(target_type)
776}
777
778pub fn compute_median(
784 op: &Operator,
785 left: &Distribution,
786 right: &Distribution,
787) -> Result<ScalarValue> {
788 match (left, right) {
789 (Uniform(lu), Uniform(ru)) => {
790 let (left_median, right_median) = (lu.median()?, ru.median()?);
791 match op {
795 Operator::Plus => return left_median.add_checked(right_median),
796 Operator::Minus => return left_median.sub_checked(right_median),
797 _ => {}
799 }
800 }
801 (Gaussian(lg), Gaussian(rg)) => match op {
804 Operator::Plus => return lg.mean().add_checked(rg.mean()),
805 Operator::Minus => return lg.mean().sub_checked(rg.mean()),
806 _ => {}
808 },
809 _ => {}
811 }
812
813 let (left_median, right_median) = (left.median()?, right.median()?);
814 let target_type = Distribution::target_type(&[&left_median, &right_median])?;
815 ScalarValue::try_from(target_type)
816}
817
818pub fn compute_variance(
821 op: &Operator,
822 left: &Distribution,
823 right: &Distribution,
824) -> Result<ScalarValue> {
825 let (left_variance, right_variance) = (left.variance()?, right.variance()?);
826
827 match op {
828 Operator::Plus => return left_variance.add_checked(right_variance),
830 Operator::Minus => return left_variance.add_checked(right_variance),
832 Operator::Multiply => {
834 let (left_mean, right_mean) = (left.mean()?, right.mean()?);
838 let left_mean_sq = left_mean.mul_checked(&left_mean)?;
839 let right_mean_sq = right_mean.mul_checked(&right_mean)?;
840 let left_sos = left_variance.add_checked(&left_mean_sq)?;
841 let right_sos = right_variance.add_checked(&right_mean_sq)?;
842 let pos = left_mean_sq.mul_checked(right_mean_sq)?;
843 return left_sos.mul_checked(right_sos)?.sub_checked(pos);
844 }
845 Operator::Divide => {}
853 _ => {}
855 }
856 let target_type = Distribution::target_type(&[&left_variance, &right_variance])?;
857 ScalarValue::try_from(target_type)
858}
859
860#[cfg(test)]
861mod tests {
862 use super::{
863 combine_bernoullis, combine_gaussians, compute_mean, compute_median,
864 compute_variance, create_bernoulli_from_comparison, new_generic_from_binary_op,
865 BernoulliDistribution, Distribution, GaussianDistribution, UniformDistribution,
866 };
867 use crate::interval_arithmetic::{apply_operator, Interval};
868 use crate::operator::Operator;
869
870 use arrow::datatypes::DataType;
871 use datafusion_common::{HashSet, Result, ScalarValue};
872
873 #[test]
874 fn uniform_dist_is_valid_test() -> Result<()> {
875 assert_eq!(
876 Distribution::new_uniform(Interval::make_zero(&DataType::Int8)?)?,
877 Distribution::Uniform(UniformDistribution {
878 interval: Interval::make_zero(&DataType::Int8)?,
879 })
880 );
881
882 assert!(Distribution::new_uniform(Interval::UNCERTAIN).is_err());
883 Ok(())
884 }
885
886 #[test]
887 fn exponential_dist_is_valid_test() {
888 let exponentials = vec![
890 (
891 Distribution::new_exponential(ScalarValue::Null, ScalarValue::Null, true),
892 false,
893 ),
894 (
895 Distribution::new_exponential(
896 ScalarValue::from(0_f32),
897 ScalarValue::from(1_f32),
898 true,
899 ),
900 false,
901 ),
902 (
903 Distribution::new_exponential(
904 ScalarValue::from(100_f32),
905 ScalarValue::from(1_f32),
906 true,
907 ),
908 true,
909 ),
910 (
911 Distribution::new_exponential(
912 ScalarValue::from(-100_f32),
913 ScalarValue::from(1_f32),
914 true,
915 ),
916 false,
917 ),
918 ];
919 for case in exponentials {
920 assert_eq!(case.0.is_ok(), case.1);
921 }
922 }
923
924 #[test]
925 fn gaussian_dist_is_valid_test() {
926 let gaussians = vec![
928 (
929 Distribution::new_gaussian(ScalarValue::Null, ScalarValue::Null),
930 false,
931 ),
932 (
933 Distribution::new_gaussian(
934 ScalarValue::from(0_f32),
935 ScalarValue::from(0_f32),
936 ),
937 true,
938 ),
939 (
940 Distribution::new_gaussian(
941 ScalarValue::from(0_f32),
942 ScalarValue::from(0.5_f32),
943 ),
944 true,
945 ),
946 (
947 Distribution::new_gaussian(
948 ScalarValue::from(0_f32),
949 ScalarValue::from(-0.5_f32),
950 ),
951 false,
952 ),
953 ];
954 for case in gaussians {
955 assert_eq!(case.0.is_ok(), case.1);
956 }
957 }
958
959 #[test]
960 fn bernoulli_dist_is_valid_test() {
961 let bernoullis = vec![
963 (Distribution::new_bernoulli(ScalarValue::Null), true),
964 (Distribution::new_bernoulli(ScalarValue::from(0.)), true),
965 (Distribution::new_bernoulli(ScalarValue::from(0.25)), true),
966 (Distribution::new_bernoulli(ScalarValue::from(1.)), true),
967 (Distribution::new_bernoulli(ScalarValue::from(11.)), false),
968 (Distribution::new_bernoulli(ScalarValue::from(-11.)), false),
969 (Distribution::new_bernoulli(ScalarValue::from(0_i64)), true),
970 (Distribution::new_bernoulli(ScalarValue::from(1_i64)), true),
971 (
972 Distribution::new_bernoulli(ScalarValue::from(11_i64)),
973 false,
974 ),
975 (
976 Distribution::new_bernoulli(ScalarValue::from(-11_i64)),
977 false,
978 ),
979 ];
980 for case in bernoullis {
981 assert_eq!(case.0.is_ok(), case.1);
982 }
983 }
984
985 #[test]
986 fn generic_dist_is_valid_test() -> Result<()> {
987 let generic_dists = vec![
989 (
991 Distribution::new_generic(
992 ScalarValue::Null,
993 ScalarValue::Null,
994 ScalarValue::Null,
995 Interval::UNCERTAIN,
996 ),
997 false,
998 ),
999 (
1000 Distribution::new_generic(
1001 ScalarValue::Null,
1002 ScalarValue::Null,
1003 ScalarValue::Null,
1004 Interval::make_zero(&DataType::Float32)?,
1005 ),
1006 true,
1007 ),
1008 (
1009 Distribution::new_generic(
1010 ScalarValue::from(0_f32),
1011 ScalarValue::Float32(None),
1012 ScalarValue::Float32(None),
1013 Interval::make_zero(&DataType::Float32)?,
1014 ),
1015 true,
1016 ),
1017 (
1018 Distribution::new_generic(
1019 ScalarValue::Float64(None),
1020 ScalarValue::from(0.),
1021 ScalarValue::Float64(None),
1022 Interval::make_zero(&DataType::Float32)?,
1023 ),
1024 true,
1025 ),
1026 (
1027 Distribution::new_generic(
1028 ScalarValue::from(-10_f32),
1029 ScalarValue::Float32(None),
1030 ScalarValue::Float32(None),
1031 Interval::make_zero(&DataType::Float32)?,
1032 ),
1033 false,
1034 ),
1035 (
1036 Distribution::new_generic(
1037 ScalarValue::Float32(None),
1038 ScalarValue::from(10_f32),
1039 ScalarValue::Float32(None),
1040 Interval::make_zero(&DataType::Float32)?,
1041 ),
1042 false,
1043 ),
1044 (
1045 Distribution::new_generic(
1046 ScalarValue::Null,
1047 ScalarValue::Null,
1048 ScalarValue::Null,
1049 Interval::make_zero(&DataType::Float32)?,
1050 ),
1051 true,
1052 ),
1053 (
1054 Distribution::new_generic(
1055 ScalarValue::from(0),
1056 ScalarValue::from(0),
1057 ScalarValue::Int32(None),
1058 Interval::make_zero(&DataType::Int32)?,
1059 ),
1060 true,
1061 ),
1062 (
1063 Distribution::new_generic(
1064 ScalarValue::from(0_f32),
1065 ScalarValue::from(0_f32),
1066 ScalarValue::Float32(None),
1067 Interval::make_zero(&DataType::Float32)?,
1068 ),
1069 true,
1070 ),
1071 (
1072 Distribution::new_generic(
1073 ScalarValue::from(50.),
1074 ScalarValue::from(50.),
1075 ScalarValue::Float64(None),
1076 Interval::make(Some(0.), Some(100.))?,
1077 ),
1078 true,
1079 ),
1080 (
1081 Distribution::new_generic(
1082 ScalarValue::from(50.),
1083 ScalarValue::from(50.),
1084 ScalarValue::Float64(None),
1085 Interval::make(Some(-100.), Some(0.))?,
1086 ),
1087 false,
1088 ),
1089 (
1090 Distribution::new_generic(
1091 ScalarValue::Float64(None),
1092 ScalarValue::Float64(None),
1093 ScalarValue::from(1.),
1094 Interval::make_zero(&DataType::Float64)?,
1095 ),
1096 true,
1097 ),
1098 (
1099 Distribution::new_generic(
1100 ScalarValue::Float64(None),
1101 ScalarValue::Float64(None),
1102 ScalarValue::from(-1.),
1103 Interval::make_zero(&DataType::Float64)?,
1104 ),
1105 false,
1106 ),
1107 ];
1108 for case in generic_dists {
1109 assert_eq!(case.0.is_ok(), case.1, "{:?}", case.0);
1110 }
1111
1112 Ok(())
1113 }
1114
1115 #[test]
1116 fn mean_extraction_test() -> Result<()> {
1117 let dists = vec![
1119 (
1120 Distribution::new_uniform(Interval::make_zero(&DataType::Int64)?),
1121 ScalarValue::from(0_i64),
1122 ),
1123 (
1124 Distribution::new_uniform(Interval::make_zero(&DataType::Float64)?),
1125 ScalarValue::from(0.),
1126 ),
1127 (
1128 Distribution::new_uniform(Interval::make(Some(1), Some(100))?),
1129 ScalarValue::from(50),
1130 ),
1131 (
1132 Distribution::new_uniform(Interval::make(Some(-100), Some(-1))?),
1133 ScalarValue::from(-50),
1134 ),
1135 (
1136 Distribution::new_uniform(Interval::make(Some(-100), Some(100))?),
1137 ScalarValue::from(0),
1138 ),
1139 (
1140 Distribution::new_exponential(
1141 ScalarValue::from(2.),
1142 ScalarValue::from(0.),
1143 true,
1144 ),
1145 ScalarValue::from(0.5),
1146 ),
1147 (
1148 Distribution::new_exponential(
1149 ScalarValue::from(2.),
1150 ScalarValue::from(1.),
1151 true,
1152 ),
1153 ScalarValue::from(1.5),
1154 ),
1155 (
1156 Distribution::new_gaussian(ScalarValue::from(0.), ScalarValue::from(1.)),
1157 ScalarValue::from(0.),
1158 ),
1159 (
1160 Distribution::new_gaussian(
1161 ScalarValue::from(-2.),
1162 ScalarValue::from(0.5),
1163 ),
1164 ScalarValue::from(-2.),
1165 ),
1166 (
1167 Distribution::new_bernoulli(ScalarValue::from(0.5)),
1168 ScalarValue::from(0.5),
1169 ),
1170 (
1171 Distribution::new_generic(
1172 ScalarValue::from(42.),
1173 ScalarValue::from(42.),
1174 ScalarValue::Float64(None),
1175 Interval::make(Some(25.), Some(50.))?,
1176 ),
1177 ScalarValue::from(42.),
1178 ),
1179 ];
1180
1181 for case in dists {
1182 assert_eq!(case.0?.mean()?, case.1);
1183 }
1184
1185 Ok(())
1186 }
1187
1188 #[test]
1189 fn median_extraction_test() -> Result<()> {
1190 let dists = vec![
1192 (
1193 Distribution::new_uniform(Interval::make_zero(&DataType::Int64)?),
1194 ScalarValue::from(0_i64),
1195 ),
1196 (
1197 Distribution::new_uniform(Interval::make(Some(25.), Some(75.))?),
1198 ScalarValue::from(50.),
1199 ),
1200 (
1201 Distribution::new_exponential(
1202 ScalarValue::from(2_f64.ln()),
1203 ScalarValue::from(0.),
1204 true,
1205 ),
1206 ScalarValue::from(1.),
1207 ),
1208 (
1209 Distribution::new_gaussian(ScalarValue::from(2.), ScalarValue::from(1.)),
1210 ScalarValue::from(2.),
1211 ),
1212 (
1213 Distribution::new_bernoulli(ScalarValue::from(0.25)),
1214 ScalarValue::from(0.),
1215 ),
1216 (
1217 Distribution::new_bernoulli(ScalarValue::from(0.75)),
1218 ScalarValue::from(1.),
1219 ),
1220 (
1221 Distribution::new_gaussian(ScalarValue::from(2.), ScalarValue::from(1.)),
1222 ScalarValue::from(2.),
1223 ),
1224 (
1225 Distribution::new_generic(
1226 ScalarValue::from(12.),
1227 ScalarValue::from(12.),
1228 ScalarValue::Float64(None),
1229 Interval::make(Some(0.), Some(25.))?,
1230 ),
1231 ScalarValue::from(12.),
1232 ),
1233 ];
1234
1235 for case in dists {
1236 assert_eq!(case.0?.median()?, case.1);
1237 }
1238
1239 Ok(())
1240 }
1241
1242 #[test]
1243 fn variance_extraction_test() -> Result<()> {
1244 let dists = vec![
1246 (
1247 Distribution::new_uniform(Interval::make(Some(0.), Some(12.))?),
1248 ScalarValue::from(12.),
1249 ),
1250 (
1251 Distribution::new_exponential(
1252 ScalarValue::from(10.),
1253 ScalarValue::from(0.),
1254 true,
1255 ),
1256 ScalarValue::from(0.01),
1257 ),
1258 (
1259 Distribution::new_gaussian(ScalarValue::from(0.), ScalarValue::from(1.)),
1260 ScalarValue::from(1.),
1261 ),
1262 (
1263 Distribution::new_bernoulli(ScalarValue::from(0.5)),
1264 ScalarValue::from(0.25),
1265 ),
1266 (
1267 Distribution::new_generic(
1268 ScalarValue::Float64(None),
1269 ScalarValue::Float64(None),
1270 ScalarValue::from(0.02),
1271 Interval::make_zero(&DataType::Float64)?,
1272 ),
1273 ScalarValue::from(0.02),
1274 ),
1275 ];
1276
1277 for case in dists {
1278 assert_eq!(case.0?.variance()?, case.1);
1279 }
1280
1281 Ok(())
1282 }
1283
1284 #[test]
1285 fn test_calculate_generic_properties_gauss_gauss() -> Result<()> {
1286 let dist_a =
1287 Distribution::new_gaussian(ScalarValue::from(10.), ScalarValue::from(0.0))?;
1288 let dist_b =
1289 Distribution::new_gaussian(ScalarValue::from(20.), ScalarValue::from(0.0))?;
1290
1291 let test_data = vec![
1292 (
1294 compute_mean(&Operator::Plus, &dist_a, &dist_b)?,
1295 ScalarValue::from(30.),
1296 ),
1297 (
1298 compute_mean(&Operator::Minus, &dist_a, &dist_b)?,
1299 ScalarValue::from(-10.),
1300 ),
1301 (
1303 compute_median(&Operator::Plus, &dist_a, &dist_b)?,
1304 ScalarValue::from(30.),
1305 ),
1306 (
1307 compute_median(&Operator::Minus, &dist_a, &dist_b)?,
1308 ScalarValue::from(-10.),
1309 ),
1310 ];
1311 for (actual, expected) in test_data {
1312 assert_eq!(actual, expected);
1313 }
1314
1315 Ok(())
1316 }
1317
1318 #[test]
1319 fn test_combine_bernoullis_and_op() -> Result<()> {
1320 let op = Operator::And;
1321 let left = BernoulliDistribution::try_new(ScalarValue::from(0.5))?;
1322 let right = BernoulliDistribution::try_new(ScalarValue::from(0.4))?;
1323 let left_null = BernoulliDistribution::try_new(ScalarValue::Null)?;
1324 let right_null = BernoulliDistribution::try_new(ScalarValue::Null)?;
1325
1326 assert_eq!(
1327 combine_bernoullis(&op, &left, &right)?.p_value(),
1328 &ScalarValue::from(0.5 * 0.4)
1329 );
1330 assert_eq!(
1331 combine_bernoullis(&op, &left_null, &right)?.p_value(),
1332 &ScalarValue::Float64(None)
1333 );
1334 assert_eq!(
1335 combine_bernoullis(&op, &left, &right_null)?.p_value(),
1336 &ScalarValue::Float64(None)
1337 );
1338 assert_eq!(
1339 combine_bernoullis(&op, &left_null, &left_null)?.p_value(),
1340 &ScalarValue::Null
1341 );
1342
1343 Ok(())
1344 }
1345
1346 #[test]
1347 fn test_combine_bernoullis_or_op() -> Result<()> {
1348 let op = Operator::Or;
1349 let left = BernoulliDistribution::try_new(ScalarValue::from(0.6))?;
1350 let right = BernoulliDistribution::try_new(ScalarValue::from(0.4))?;
1351 let left_null = BernoulliDistribution::try_new(ScalarValue::Null)?;
1352 let right_null = BernoulliDistribution::try_new(ScalarValue::Null)?;
1353
1354 assert_eq!(
1355 combine_bernoullis(&op, &left, &right)?.p_value(),
1356 &ScalarValue::from(0.6 + 0.4 - (0.6 * 0.4))
1357 );
1358 assert_eq!(
1359 combine_bernoullis(&op, &left_null, &right)?.p_value(),
1360 &ScalarValue::Float64(None)
1361 );
1362 assert_eq!(
1363 combine_bernoullis(&op, &left, &right_null)?.p_value(),
1364 &ScalarValue::Float64(None)
1365 );
1366 assert_eq!(
1367 combine_bernoullis(&op, &left_null, &left_null)?.p_value(),
1368 &ScalarValue::Null
1369 );
1370
1371 Ok(())
1372 }
1373
1374 #[test]
1375 fn test_combine_bernoullis_unsupported_ops() -> Result<()> {
1376 let mut operator_set = operator_set();
1377 operator_set.remove(&Operator::And);
1378 operator_set.remove(&Operator::Or);
1379
1380 let left = BernoulliDistribution::try_new(ScalarValue::from(0.6))?;
1381 let right = BernoulliDistribution::try_new(ScalarValue::from(0.4))?;
1382 for op in operator_set {
1383 assert!(
1384 combine_bernoullis(&op, &left, &right).is_err(),
1385 "Operator {op} should not be supported for Bernoulli distributions"
1386 );
1387 }
1388
1389 Ok(())
1390 }
1391
1392 #[test]
1393 fn test_combine_gaussians_addition() -> Result<()> {
1394 let op = Operator::Plus;
1395 let left = GaussianDistribution::try_new(
1396 ScalarValue::from(3.0),
1397 ScalarValue::from(2.0),
1398 )?;
1399 let right = GaussianDistribution::try_new(
1400 ScalarValue::from(4.0),
1401 ScalarValue::from(1.0),
1402 )?;
1403
1404 let result = combine_gaussians(&op, &left, &right)?.unwrap();
1405
1406 assert_eq!(result.mean(), &ScalarValue::from(7.0)); assert_eq!(result.variance(), &ScalarValue::from(3.0)); Ok(())
1409 }
1410
#[test]
fn test_combine_gaussians_subtraction() -> Result<()> {
    // Operands are built as (mean, variance): (7, 2) and (4, 1).
    let lhs = GaussianDistribution::try_new(
        ScalarValue::from(7.0),
        ScalarValue::from(2.0),
    )?;
    let rhs = GaussianDistribution::try_new(
        ScalarValue::from(4.0),
        ScalarValue::from(1.0),
    )?;

    let difference = combine_gaussians(&Operator::Minus, &lhs, &rhs)?.unwrap();

    // The combined distribution is expected to have mean 3 and variance 3.
    let expected_mean = ScalarValue::from(3.0);
    let expected_variance = ScalarValue::from(3.0);
    assert_eq!(difference.mean(), &expected_mean);
    assert_eq!(difference.variance(), &expected_variance);

    Ok(())
}
1430
#[test]
fn test_combine_gaussians_unsupported_ops() -> Result<()> {
    // `combine_gaussians` must succeed with `None` for every operator in
    // `operator_set()` other than `Plus` and `Minus`.
    let unsupported = operator_set()
        .into_iter()
        .filter(|op| !matches!(op, Operator::Plus | Operator::Minus));

    let lhs = GaussianDistribution::try_new(
        ScalarValue::from(7.0),
        ScalarValue::from(2.0),
    )?;
    let rhs = GaussianDistribution::try_new(
        ScalarValue::from(4.0),
        ScalarValue::from(1.0),
    )?;

    for op in unsupported {
        let combined = combine_gaussians(&op, &lhs, &rhs)?;
        assert!(
            combined.is_none(),
            "Operator {op} should not be supported for Gaussian distributions"
        );
    }

    Ok(())
}
1454
#[test]
fn test_calculate_generic_properties_uniform_uniform() -> Result<()> {
    let lhs = Distribution::new_uniform(Interval::make(Some(0.), Some(12.))?)?;
    let rhs = Distribution::new_uniform(Interval::make(Some(12.), Some(36.))?)?;

    // Every statistic is computed up front (so any error surfaces before the
    // first assertion) and paired with the value it is expected to equal.
    let cases: [(ScalarValue, f64); 8] = [
        // Means.
        (compute_mean(&Operator::Plus, &lhs, &rhs)?, 30.),
        (compute_mean(&Operator::Minus, &lhs, &rhs)?, -18.),
        (compute_mean(&Operator::Multiply, &lhs, &rhs)?, 144.),
        // Medians.
        (compute_median(&Operator::Plus, &lhs, &rhs)?, 30.),
        (compute_median(&Operator::Minus, &lhs, &rhs)?, -18.),
        // Variances.
        (compute_variance(&Operator::Plus, &lhs, &rhs)?, 60.),
        (compute_variance(&Operator::Minus, &lhs, &rhs)?, 60.),
        (compute_variance(&Operator::Multiply, &lhs, &rhs)?, 9216.),
    ];

    for (actual, expected) in cases {
        assert_eq!(actual, ScalarValue::from(expected));
    }

    Ok(())
}
1509
#[test]
fn test_compute_range_where_present() -> Result<()> {
    use super::Operator::{
        Divide, Eq, Gt, GtEq, Lt, LtEq, Minus, Multiply, NotEq, Plus,
    };

    let a = Interval::make(Some(0.), Some(12.0))?;
    let b = Interval::make(Some(0.), Some(12.0))?;
    let mean = ScalarValue::from(6.0);

    // Builders for the two kinds of operand exercised below.
    let uniform = |range: &Interval| Distribution::new_uniform(range.clone());
    let generic = |range: &Interval| {
        Distribution::new_generic(
            mean.clone(),
            mean.clone(),
            ScalarValue::Float64(None),
            range.clone(),
        )
    };

    // Every uniform/generic pairing over the ranges `a` and `b`.
    let operand_pairs = [
        (uniform(&a)?, uniform(&b)?),
        (generic(&a)?, uniform(&b)?),
        (uniform(&a)?, generic(&b)?),
        (generic(&a)?, generic(&b)?),
    ];

    for (dist_a, dist_b) in operand_pairs {
        // Arithmetic: the resulting range must equal the interval obtained by
        // applying the operator directly to `a` and `b`.
        for op in [Plus, Minus, Multiply, Divide] {
            let actual = new_generic_from_binary_op(&op, &dist_a, &dist_b)?.range()?;
            let expected = apply_operator(&op, &a, &b)?;
            assert_eq!(
                actual, expected,
                "Failed for {:?} {op} {:?}",
                dist_a, dist_b
            );
        }
        // Comparisons: same check against the Bernoulli result's range.
        for op in [Gt, GtEq, Lt, LtEq, Eq, NotEq] {
            let actual =
                create_bernoulli_from_comparison(&op, &dist_a, &dist_b)?.range()?;
            let expected = apply_operator(&op, &a, &b)?;
            assert_eq!(
                actual, expected,
                "Failed for {:?} {op} {:?}",
                dist_a, dist_b
            );
        }
    }

    Ok(())
}
1580
1581 fn operator_set() -> HashSet<Operator> {
1582 use super::Operator::*;
1583
1584 let all_ops = vec![
1585 And,
1586 Or,
1587 Eq,
1588 NotEq,
1589 Gt,
1590 GtEq,
1591 Lt,
1592 LtEq,
1593 Plus,
1594 Minus,
1595 Multiply,
1596 Divide,
1597 Modulo,
1598 IsDistinctFrom,
1599 IsNotDistinctFrom,
1600 RegexMatch,
1601 RegexIMatch,
1602 RegexNotMatch,
1603 RegexNotIMatch,
1604 LikeMatch,
1605 ILikeMatch,
1606 NotLikeMatch,
1607 NotILikeMatch,
1608 BitwiseAnd,
1609 BitwiseOr,
1610 BitwiseXor,
1611 BitwiseShiftRight,
1612 BitwiseShiftLeft,
1613 StringConcat,
1614 AtArrow,
1615 ArrowAt,
1616 ];
1617
1618 all_ops.into_iter().collect()
1619 }
1620}