1use std::f64::consts::LN_2;
19
20use crate::interval_arithmetic::{Interval, apply_operator};
21use crate::operator::Operator;
22use crate::type_coercion::binary::binary_numeric_coercion;
23
24use arrow::array::ArrowNativeTypeOp;
25use arrow::datatypes::DataType;
26use datafusion_common::rounding::alter_fp_rounding_mode;
27use datafusion_common::{
28 Result, ScalarValue, assert_eq_or_internal_err, assert_ne_or_internal_err,
29 assert_or_internal_err, internal_err, not_impl_err,
30};
31
/// Statistical description of a quantity, expressed as one of several
/// supported distribution families.
///
/// Instances are meant to be built through the `new_*` constructors on this
/// type, which validate the parameters before wrapping them in a variant.
#[derive(Clone, Debug, PartialEq)]
pub enum Distribution {
    /// Uniform distribution over an interval.
    Uniform(UniformDistribution),
    /// Exponential distribution described by a rate, an offset and a tail
    /// direction.
    Exponential(ExponentialDistribution),
    /// Gaussian (normal) distribution described by a mean and a variance.
    Gaussian(GaussianDistribution),
    /// Bernoulli distribution described by a success probability.
    Bernoulli(BernoulliDistribution),
    /// Distribution of unknown family, described only by summary statistics
    /// (mean, median, variance) and a range.
    Generic(GenericDistribution),
}
48
49use Distribution::{Bernoulli, Exponential, Gaussian, Generic, Uniform};
50
51impl Distribution {
52 pub fn new_uniform(interval: Interval) -> Result<Self> {
54 UniformDistribution::try_new(interval).map(Uniform)
55 }
56
57 pub fn new_exponential(
60 rate: ScalarValue,
61 offset: ScalarValue,
62 positive_tail: bool,
63 ) -> Result<Self> {
64 ExponentialDistribution::try_new(rate, offset, positive_tail).map(Exponential)
65 }
66
67 pub fn new_gaussian(mean: ScalarValue, variance: ScalarValue) -> Result<Self> {
70 GaussianDistribution::try_new(mean, variance).map(Gaussian)
71 }
72
73 pub fn new_bernoulli(p: ScalarValue) -> Result<Self> {
76 BernoulliDistribution::try_new(p).map(Bernoulli)
77 }
78
79 pub fn new_generic(
82 mean: ScalarValue,
83 median: ScalarValue,
84 variance: ScalarValue,
85 range: Interval,
86 ) -> Result<Self> {
87 GenericDistribution::try_new(mean, median, variance, range).map(Generic)
88 }
89
90 pub fn new_from_interval(range: Interval) -> Result<Self> {
93 let null = ScalarValue::try_from(range.data_type())?;
94 Distribution::new_generic(null.clone(), null.clone(), null, range)
95 }
96
97 pub fn mean(&self) -> Result<ScalarValue> {
108 match &self {
109 Uniform(u) => u.mean(),
110 Exponential(e) => e.mean(),
111 Gaussian(g) => Ok(g.mean().clone()),
112 Bernoulli(b) => Ok(b.mean().clone()),
113 Generic(u) => Ok(u.mean().clone()),
114 }
115 }
116
117 pub fn median(&self) -> Result<ScalarValue> {
130 match &self {
131 Uniform(u) => u.median(),
132 Exponential(e) => e.median(),
133 Gaussian(g) => Ok(g.median().clone()),
134 Bernoulli(b) => b.median(),
135 Generic(u) => Ok(u.median().clone()),
136 }
137 }
138
139 pub fn variance(&self) -> Result<ScalarValue> {
151 match &self {
152 Uniform(u) => u.variance(),
153 Exponential(e) => e.variance(),
154 Gaussian(g) => Ok(g.variance.clone()),
155 Bernoulli(b) => b.variance(),
156 Generic(u) => Ok(u.variance.clone()),
157 }
158 }
159
160 pub fn range(&self) -> Result<Interval> {
171 match &self {
172 Uniform(u) => Ok(u.range().clone()),
173 Exponential(e) => e.range(),
174 Gaussian(g) => g.range(),
175 Bernoulli(b) => Ok(b.range()),
176 Generic(u) => Ok(u.range().clone()),
177 }
178 }
179
180 pub fn data_type(&self) -> DataType {
183 match &self {
184 Uniform(u) => u.data_type(),
185 Exponential(e) => e.data_type(),
186 Gaussian(g) => g.data_type(),
187 Bernoulli(b) => b.data_type(),
188 Generic(u) => u.data_type(),
189 }
190 }
191
192 pub fn target_type(args: &[&ScalarValue]) -> Result<DataType> {
193 let mut arg_types = args
194 .iter()
195 .filter(|&&arg| arg != &ScalarValue::Null)
196 .map(|&arg| arg.data_type());
197
198 let Some(dt) = arg_types.next().map_or_else(
199 || Some(DataType::Null),
200 |first| {
201 arg_types
202 .try_fold(first, |target, arg| binary_numeric_coercion(&target, &arg))
203 },
204 ) else {
205 return internal_err!("Can only evaluate statistics for numeric types");
206 };
207 Ok(dt)
208 }
209}
210
/// Uniform distribution, fully described by the interval it spans.
#[derive(Clone, Debug, PartialEq)]
pub struct UniformDistribution {
    /// Interval over which the distribution is defined.
    interval: Interval,
}
221
/// Exponential distribution described by a rate, an offset and the direction
/// in which its tail extends.
#[derive(Clone, Debug, PartialEq)]
pub struct ExponentialDistribution {
    /// Rate parameter of the distribution.
    rate: ScalarValue,
    /// Finite endpoint of the distribution's range.
    offset: ScalarValue,
    /// When `true`, the tail extends from `offset` towards larger values;
    /// otherwise it extends from `offset` towards smaller values.
    positive_tail: bool,
}
247
/// Gaussian (normal) distribution described by its mean and variance.
#[derive(Clone, Debug, PartialEq)]
pub struct GaussianDistribution {
    /// Mean of the distribution.
    mean: ScalarValue,
    /// Variance of the distribution.
    variance: ScalarValue,
}
257
/// Bernoulli distribution described by its success probability.
#[derive(Clone, Debug, PartialEq)]
pub struct BernoulliDistribution {
    /// Success probability; may be null.
    p: ScalarValue,
}
266
/// Distribution of unknown family, described only by summary statistics and
/// the range of values it can take.
#[derive(Clone, Debug, PartialEq)]
pub struct GenericDistribution {
    /// Mean of the distribution; may be null.
    mean: ScalarValue,
    /// Median of the distribution; may be null.
    median: ScalarValue,
    /// Variance of the distribution; may be null.
    variance: ScalarValue,
    /// Range of values the distribution can take.
    range: Interval,
}
278
279impl UniformDistribution {
280 fn try_new(interval: Interval) -> Result<Self> {
281 assert_ne_or_internal_err!(
282 interval.data_type(),
283 DataType::Boolean,
284 "Construction of a boolean `Uniform` distribution is prohibited, create a `Bernoulli` distribution instead."
285 );
286
287 Ok(Self { interval })
288 }
289
290 pub fn data_type(&self) -> DataType {
291 self.interval.data_type()
292 }
293
294 pub fn mean(&self) -> Result<ScalarValue> {
298 let dt = self.data_type();
300 let two = ScalarValue::from(2).cast_to(&dt)?;
301 let result = self
302 .interval
303 .lower()
304 .add_checked(self.interval.upper())?
305 .div(two);
306 debug_assert!(
307 !self.interval.is_unbounded() || result.as_ref().is_ok_and(|r| r.is_null())
308 );
309 result
310 }
311
312 pub fn median(&self) -> Result<ScalarValue> {
313 self.mean()
314 }
315
316 pub fn variance(&self) -> Result<ScalarValue> {
320 let width = self.interval.width()?;
322 let dt = width.data_type();
323 let twelve = ScalarValue::from(12).cast_to(&dt)?;
324 let result = width.mul_checked(&width)?.div(twelve);
325 debug_assert!(
326 !self.interval.is_unbounded() || result.as_ref().is_ok_and(|r| r.is_null())
327 );
328 result
329 }
330
331 pub fn range(&self) -> &Interval {
332 &self.interval
333 }
334}
335
336impl ExponentialDistribution {
337 fn try_new(
338 rate: ScalarValue,
339 offset: ScalarValue,
340 positive_tail: bool,
341 ) -> Result<Self> {
342 let dt = rate.data_type();
343 assert_eq_or_internal_err!(
344 offset.data_type(),
345 dt,
346 "Rate and offset must have the same data type"
347 );
348 assert_or_internal_err!(
349 !offset.is_null(),
350 "Offset of an `ExponentialDistribution` cannot be null"
351 );
352 assert_or_internal_err!(
353 !rate.is_null(),
354 "Rate of an `ExponentialDistribution` cannot be null"
355 );
356 let zero = ScalarValue::new_zero(&dt)?;
357 assert_or_internal_err!(
358 !rate.le(&zero),
359 "Rate of an `ExponentialDistribution` must be positive"
360 );
361 Ok(Self {
362 rate,
363 offset,
364 positive_tail,
365 })
366 }
367
368 pub fn data_type(&self) -> DataType {
369 self.rate.data_type()
370 }
371
372 pub fn rate(&self) -> &ScalarValue {
373 &self.rate
374 }
375
376 pub fn offset(&self) -> &ScalarValue {
377 &self.offset
378 }
379
380 pub fn positive_tail(&self) -> bool {
381 self.positive_tail
382 }
383
384 pub fn mean(&self) -> Result<ScalarValue> {
385 let one = ScalarValue::new_one(&self.data_type())?;
387 let tail_mean = one.div(&self.rate)?;
388 if self.positive_tail {
389 self.offset.add_checked(tail_mean)
390 } else {
391 self.offset.sub_checked(tail_mean)
392 }
393 }
394
395 pub fn median(&self) -> Result<ScalarValue> {
396 let ln_two = ScalarValue::from(LN_2).cast_to(&self.data_type())?;
398 let tail_median = ln_two.div(&self.rate)?;
399 if self.positive_tail {
400 self.offset.add_checked(tail_median)
401 } else {
402 self.offset.sub_checked(tail_median)
403 }
404 }
405
406 pub fn variance(&self) -> Result<ScalarValue> {
407 let one = ScalarValue::new_one(&self.data_type())?;
409 let rate_squared = self.rate.mul_checked(&self.rate)?;
410 one.div(rate_squared)
411 }
412
413 pub fn range(&self) -> Result<Interval> {
414 let end = ScalarValue::try_from(&self.data_type())?;
415 if self.positive_tail {
416 Interval::try_new(self.offset.clone(), end)
417 } else {
418 Interval::try_new(end, self.offset.clone())
419 }
420 }
421}
422
423impl GaussianDistribution {
424 fn try_new(mean: ScalarValue, variance: ScalarValue) -> Result<Self> {
425 let dt = mean.data_type();
426 assert_eq_or_internal_err!(
427 variance.data_type(),
428 dt,
429 "Mean and variance must have the same data type"
430 );
431 assert_or_internal_err!(
432 !variance.is_null(),
433 "Variance of a `GaussianDistribution` cannot be null"
434 );
435 let zero = ScalarValue::new_zero(&dt)?;
436 assert_or_internal_err!(
437 !variance.lt(&zero),
438 "Variance of a `GaussianDistribution` must be positive"
439 );
440 Ok(Self { mean, variance })
441 }
442
443 pub fn data_type(&self) -> DataType {
444 self.mean.data_type()
445 }
446
447 pub fn mean(&self) -> &ScalarValue {
448 &self.mean
449 }
450
451 pub fn variance(&self) -> &ScalarValue {
452 &self.variance
453 }
454
455 pub fn median(&self) -> &ScalarValue {
456 self.mean()
457 }
458
459 pub fn range(&self) -> Result<Interval> {
460 Interval::make_unbounded(&self.data_type())
461 }
462}
463
464impl BernoulliDistribution {
465 fn try_new(p: ScalarValue) -> Result<Self> {
466 if p.is_null() {
467 return Ok(Self { p });
468 }
469 let dt = p.data_type();
470 let zero = ScalarValue::new_zero(&dt)?;
471 let one = ScalarValue::new_one(&dt)?;
472 assert_or_internal_err!(
473 p.ge(&zero) && p.le(&one),
474 "Success probability of a `BernoulliDistribution` must be in [0, 1]"
475 );
476 Ok(Self { p })
477 }
478
479 pub fn data_type(&self) -> DataType {
480 self.p.data_type()
481 }
482
483 pub fn p_value(&self) -> &ScalarValue {
484 &self.p
485 }
486
487 pub fn mean(&self) -> &ScalarValue {
488 &self.p
489 }
490
491 pub fn median(&self) -> Result<ScalarValue> {
494 let dt = self.data_type();
495 if self.p.is_null() {
496 ScalarValue::try_from(&dt)
497 } else {
498 let one = ScalarValue::new_one(&dt)?;
499 if one.sub_checked(&self.p)?.lt(&self.p) {
500 ScalarValue::new_one(&dt)
501 } else {
502 ScalarValue::new_zero(&dt)
503 }
504 }
505 }
506
507 pub fn variance(&self) -> Result<ScalarValue> {
510 let dt = self.data_type();
511 let one = ScalarValue::new_one(&dt)?;
512 let result = one.sub_checked(&self.p)?.mul_checked(&self.p);
513 debug_assert!(!self.p.is_null() || result.as_ref().is_ok_and(|r| r.is_null()));
514 result
515 }
516
517 pub fn range(&self) -> Interval {
518 let dt = self.data_type();
519 if ScalarValue::new_zero(&dt).unwrap().eq(&self.p) {
522 Interval::FALSE
523 } else if ScalarValue::new_one(&dt).unwrap().eq(&self.p) {
524 Interval::TRUE
525 } else {
526 Interval::TRUE_OR_FALSE
527 }
528 }
529}
530
531impl GenericDistribution {
532 fn try_new(
533 mean: ScalarValue,
534 median: ScalarValue,
535 variance: ScalarValue,
536 range: Interval,
537 ) -> Result<Self> {
538 assert_ne_or_internal_err!(
539 range.data_type(),
540 DataType::Boolean,
541 "Construction of a boolean `Generic` distribution is prohibited, create a `Bernoulli` distribution instead."
542 );
543
544 let validate_location = |m: &ScalarValue| -> Result<bool> {
545 if m.is_null() {
547 Ok(true)
548 } else {
549 range.contains_value(m)
550 }
551 };
552
553 let locations_valid = validate_location(&mean)? && validate_location(&median)?;
554 let variance_non_negative = if variance.is_null() {
555 true
556 } else {
557 let zero = ScalarValue::new_zero(&variance.data_type())?;
558 !variance.lt(&zero)
559 };
560 assert_or_internal_err!(
561 locations_valid && variance_non_negative,
562 "Tried to construct an invalid `GenericDistribution` instance"
563 );
564
565 Ok(Self {
566 mean,
567 median,
568 variance,
569 range,
570 })
571 }
572
573 pub fn data_type(&self) -> DataType {
574 self.mean.data_type()
575 }
576
577 pub fn mean(&self) -> &ScalarValue {
578 &self.mean
579 }
580
581 pub fn median(&self) -> &ScalarValue {
582 &self.median
583 }
584
585 pub fn variance(&self) -> &ScalarValue {
586 &self.variance
587 }
588
589 pub fn range(&self) -> &Interval {
590 &self.range
591 }
592}
593
594pub fn combine_bernoullis(
598 op: &Operator,
599 left: &BernoulliDistribution,
600 right: &BernoulliDistribution,
601) -> Result<BernoulliDistribution> {
602 let left_p = left.p_value();
603 let right_p = right.p_value();
604 match op {
605 Operator::And => match (left_p.is_null(), right_p.is_null()) {
606 (false, false) => {
607 BernoulliDistribution::try_new(left_p.mul_checked(right_p)?)
608 }
609 (false, true) if left_p.eq(&ScalarValue::new_zero(&left_p.data_type())?) => {
610 Ok(left.clone())
611 }
612 (true, false)
613 if right_p.eq(&ScalarValue::new_zero(&right_p.data_type())?) =>
614 {
615 Ok(right.clone())
616 }
617 _ => {
618 let dt = Distribution::target_type(&[left_p, right_p])?;
619 BernoulliDistribution::try_new(ScalarValue::try_from(&dt)?)
620 }
621 },
622 Operator::Or => match (left_p.is_null(), right_p.is_null()) {
623 (false, false) => {
624 let sum = left_p.add_checked(right_p)?;
625 let product = left_p.mul_checked(right_p)?;
626 let or_success = sum.sub_checked(product)?;
627 BernoulliDistribution::try_new(or_success)
628 }
629 (false, true) if left_p.eq(&ScalarValue::new_one(&left_p.data_type())?) => {
630 Ok(left.clone())
631 }
632 (true, false) if right_p.eq(&ScalarValue::new_one(&right_p.data_type())?) => {
633 Ok(right.clone())
634 }
635 _ => {
636 let dt = Distribution::target_type(&[left_p, right_p])?;
637 BernoulliDistribution::try_new(ScalarValue::try_from(&dt)?)
638 }
639 },
640 _ => {
641 not_impl_err!("Statistical evaluation only supports AND and OR operators")
642 }
643 }
644}
645
646pub fn combine_gaussians(
653 op: &Operator,
654 left: &GaussianDistribution,
655 right: &GaussianDistribution,
656) -> Result<Option<GaussianDistribution>> {
657 match op {
658 Operator::Plus => GaussianDistribution::try_new(
659 left.mean().add_checked(right.mean())?,
660 left.variance().add_checked(right.variance())?,
661 )
662 .map(Some),
663 Operator::Minus => GaussianDistribution::try_new(
664 left.mean().sub_checked(right.mean())?,
665 left.variance().add_checked(right.variance())?,
666 )
667 .map(Some),
668 _ => Ok(None),
669 }
670}
671
672pub fn create_bernoulli_from_comparison(
677 op: &Operator,
678 left: &Distribution,
679 right: &Distribution,
680) -> Result<Distribution> {
681 match (left, right) {
682 (Uniform(left), Uniform(right)) => {
683 match op {
684 Operator::Eq | Operator::NotEq => {
685 let (li, ri) = (left.range(), right.range());
686 if let Some(intersection) = li.intersect(ri)? {
687 if let (Some(lc), Some(rc), Some(ic)) = (
690 li.cardinality(),
691 ri.cardinality(),
692 intersection.cardinality(),
693 ) {
694 let pairs = ((lc as u128) * (rc as u128)) as f64;
696 let p = (ic as f64).div_checked(pairs)?;
697 let mut p_value = ScalarValue::from(p);
703 if op == &Operator::NotEq {
704 let one = ScalarValue::from(1.0);
705 p_value = alter_fp_rounding_mode::<false, _>(
706 &one,
707 &p_value,
708 |lhs, rhs| lhs.sub_checked(rhs),
709 )?;
710 };
711 return Distribution::new_bernoulli(p_value);
712 }
713 } else if op == &Operator::Eq {
714 return Distribution::new_bernoulli(ScalarValue::from(0.0));
716 } else {
717 return Distribution::new_bernoulli(ScalarValue::from(1.0));
719 }
720 }
721 Operator::Lt | Operator::LtEq | Operator::Gt | Operator::GtEq => {
722 }
728 _ => {}
729 }
730 }
731 (Gaussian(_), Gaussian(_)) => {
732 }
735 _ => {}
736 }
737 let (li, ri) = (left.range()?, right.range()?);
738 let range_evaluation = apply_operator(op, &li, &ri)?;
739 if range_evaluation.eq(&Interval::FALSE) {
740 Distribution::new_bernoulli(ScalarValue::from(0.0))
741 } else if range_evaluation.eq(&Interval::TRUE) {
742 Distribution::new_bernoulli(ScalarValue::from(1.0))
743 } else if range_evaluation.eq(&Interval::TRUE_OR_FALSE) {
744 Distribution::new_bernoulli(ScalarValue::try_from(&DataType::Float64)?)
745 } else {
746 internal_err!("This function must be called with a comparison operator")
747 }
748}
749
750pub fn new_generic_from_binary_op(
755 op: &Operator,
756 left: &Distribution,
757 right: &Distribution,
758) -> Result<Distribution> {
759 Distribution::new_generic(
760 compute_mean(op, left, right)?,
761 compute_median(op, left, right)?,
762 compute_variance(op, left, right)?,
763 apply_operator(op, &left.range()?, &right.range()?)?,
764 )
765}
766
767pub fn compute_mean(
770 op: &Operator,
771 left: &Distribution,
772 right: &Distribution,
773) -> Result<ScalarValue> {
774 let (left_mean, right_mean) = (left.mean()?, right.mean()?);
775
776 match op {
777 Operator::Plus => return left_mean.add_checked(right_mean),
778 Operator::Minus => return left_mean.sub_checked(right_mean),
779 Operator::Multiply => return left_mean.mul_checked(right_mean),
781 Operator::Divide => {}
789 _ => {}
791 }
792 let target_type = Distribution::target_type(&[&left_mean, &right_mean])?;
793 ScalarValue::try_from(target_type)
794}
795
796pub fn compute_median(
802 op: &Operator,
803 left: &Distribution,
804 right: &Distribution,
805) -> Result<ScalarValue> {
806 match (left, right) {
807 (Uniform(lu), Uniform(ru)) => {
808 let (left_median, right_median) = (lu.median()?, ru.median()?);
809 match op {
813 Operator::Plus => return left_median.add_checked(right_median),
814 Operator::Minus => return left_median.sub_checked(right_median),
815 _ => {}
817 }
818 }
819 (Gaussian(lg), Gaussian(rg)) => match op {
822 Operator::Plus => return lg.mean().add_checked(rg.mean()),
823 Operator::Minus => return lg.mean().sub_checked(rg.mean()),
824 _ => {}
826 },
827 _ => {}
829 }
830
831 let (left_median, right_median) = (left.median()?, right.median()?);
832 let target_type = Distribution::target_type(&[&left_median, &right_median])?;
833 ScalarValue::try_from(target_type)
834}
835
836pub fn compute_variance(
839 op: &Operator,
840 left: &Distribution,
841 right: &Distribution,
842) -> Result<ScalarValue> {
843 let (left_variance, right_variance) = (left.variance()?, right.variance()?);
844
845 match op {
846 Operator::Plus => return left_variance.add_checked(right_variance),
848 Operator::Minus => return left_variance.add_checked(right_variance),
850 Operator::Multiply => {
852 let (left_mean, right_mean) = (left.mean()?, right.mean()?);
856 let left_mean_sq = left_mean.mul_checked(&left_mean)?;
857 let right_mean_sq = right_mean.mul_checked(&right_mean)?;
858 let left_sos = left_variance.add_checked(&left_mean_sq)?;
859 let right_sos = right_variance.add_checked(&right_mean_sq)?;
860 let pos = left_mean_sq.mul_checked(right_mean_sq)?;
861 return left_sos.mul_checked(right_sos)?.sub_checked(pos);
862 }
863 Operator::Divide => {}
871 _ => {}
873 }
874 let target_type = Distribution::target_type(&[&left_variance, &right_variance])?;
875 ScalarValue::try_from(target_type)
876}
877
878#[cfg(test)]
879mod tests {
880 use super::{
881 BernoulliDistribution, Distribution, GaussianDistribution, UniformDistribution,
882 combine_bernoullis, combine_gaussians, compute_mean, compute_median,
883 compute_variance, create_bernoulli_from_comparison, new_generic_from_binary_op,
884 };
885 use crate::interval_arithmetic::{Interval, apply_operator};
886 use crate::operator::Operator;
887
888 use arrow::datatypes::DataType;
889 use datafusion_common::{HashSet, Result, ScalarValue};
890
891 #[test]
892 fn uniform_dist_is_valid_test() -> Result<()> {
893 assert_eq!(
894 Distribution::new_uniform(Interval::make_zero(&DataType::Int8)?)?,
895 Distribution::Uniform(UniformDistribution {
896 interval: Interval::make_zero(&DataType::Int8)?,
897 })
898 );
899
900 assert!(Distribution::new_uniform(Interval::TRUE_OR_FALSE).is_err());
901 Ok(())
902 }
903
904 #[test]
905 fn exponential_dist_is_valid_test() {
906 let exponentials = vec![
908 (
909 Distribution::new_exponential(ScalarValue::Null, ScalarValue::Null, true),
910 false,
911 ),
912 (
913 Distribution::new_exponential(
914 ScalarValue::from(0_f32),
915 ScalarValue::from(1_f32),
916 true,
917 ),
918 false,
919 ),
920 (
921 Distribution::new_exponential(
922 ScalarValue::from(100_f32),
923 ScalarValue::from(1_f32),
924 true,
925 ),
926 true,
927 ),
928 (
929 Distribution::new_exponential(
930 ScalarValue::from(-100_f32),
931 ScalarValue::from(1_f32),
932 true,
933 ),
934 false,
935 ),
936 ];
937 for case in exponentials {
938 assert_eq!(case.0.is_ok(), case.1);
939 }
940 }
941
942 #[test]
943 fn gaussian_dist_is_valid_test() {
944 let gaussians = vec![
946 (
947 Distribution::new_gaussian(ScalarValue::Null, ScalarValue::Null),
948 false,
949 ),
950 (
951 Distribution::new_gaussian(
952 ScalarValue::from(0_f32),
953 ScalarValue::from(0_f32),
954 ),
955 true,
956 ),
957 (
958 Distribution::new_gaussian(
959 ScalarValue::from(0_f32),
960 ScalarValue::from(0.5_f32),
961 ),
962 true,
963 ),
964 (
965 Distribution::new_gaussian(
966 ScalarValue::from(0_f32),
967 ScalarValue::from(-0.5_f32),
968 ),
969 false,
970 ),
971 ];
972 for case in gaussians {
973 assert_eq!(case.0.is_ok(), case.1);
974 }
975 }
976
977 #[test]
978 fn bernoulli_dist_is_valid_test() {
979 let bernoullis = vec![
981 (Distribution::new_bernoulli(ScalarValue::Null), true),
982 (Distribution::new_bernoulli(ScalarValue::from(0.)), true),
983 (Distribution::new_bernoulli(ScalarValue::from(0.25)), true),
984 (Distribution::new_bernoulli(ScalarValue::from(1.)), true),
985 (Distribution::new_bernoulli(ScalarValue::from(11.)), false),
986 (Distribution::new_bernoulli(ScalarValue::from(-11.)), false),
987 (Distribution::new_bernoulli(ScalarValue::from(0_i64)), true),
988 (Distribution::new_bernoulli(ScalarValue::from(1_i64)), true),
989 (
990 Distribution::new_bernoulli(ScalarValue::from(11_i64)),
991 false,
992 ),
993 (
994 Distribution::new_bernoulli(ScalarValue::from(-11_i64)),
995 false,
996 ),
997 ];
998 for case in bernoullis {
999 assert_eq!(case.0.is_ok(), case.1);
1000 }
1001 }
1002
1003 #[test]
1004 fn generic_dist_is_valid_test() -> Result<()> {
1005 let generic_dists = vec![
1007 (
1009 Distribution::new_generic(
1010 ScalarValue::Null,
1011 ScalarValue::Null,
1012 ScalarValue::Null,
1013 Interval::TRUE_OR_FALSE,
1014 ),
1015 false,
1016 ),
1017 (
1018 Distribution::new_generic(
1019 ScalarValue::Null,
1020 ScalarValue::Null,
1021 ScalarValue::Null,
1022 Interval::make_zero(&DataType::Float32)?,
1023 ),
1024 true,
1025 ),
1026 (
1027 Distribution::new_generic(
1028 ScalarValue::from(0_f32),
1029 ScalarValue::Float32(None),
1030 ScalarValue::Float32(None),
1031 Interval::make_zero(&DataType::Float32)?,
1032 ),
1033 true,
1034 ),
1035 (
1036 Distribution::new_generic(
1037 ScalarValue::Float64(None),
1038 ScalarValue::from(0.),
1039 ScalarValue::Float64(None),
1040 Interval::make_zero(&DataType::Float32)?,
1041 ),
1042 true,
1043 ),
1044 (
1045 Distribution::new_generic(
1046 ScalarValue::from(-10_f32),
1047 ScalarValue::Float32(None),
1048 ScalarValue::Float32(None),
1049 Interval::make_zero(&DataType::Float32)?,
1050 ),
1051 false,
1052 ),
1053 (
1054 Distribution::new_generic(
1055 ScalarValue::Float32(None),
1056 ScalarValue::from(10_f32),
1057 ScalarValue::Float32(None),
1058 Interval::make_zero(&DataType::Float32)?,
1059 ),
1060 false,
1061 ),
1062 (
1063 Distribution::new_generic(
1064 ScalarValue::Null,
1065 ScalarValue::Null,
1066 ScalarValue::Null,
1067 Interval::make_zero(&DataType::Float32)?,
1068 ),
1069 true,
1070 ),
1071 (
1072 Distribution::new_generic(
1073 ScalarValue::from(0),
1074 ScalarValue::from(0),
1075 ScalarValue::Int32(None),
1076 Interval::make_zero(&DataType::Int32)?,
1077 ),
1078 true,
1079 ),
1080 (
1081 Distribution::new_generic(
1082 ScalarValue::from(0_f32),
1083 ScalarValue::from(0_f32),
1084 ScalarValue::Float32(None),
1085 Interval::make_zero(&DataType::Float32)?,
1086 ),
1087 true,
1088 ),
1089 (
1090 Distribution::new_generic(
1091 ScalarValue::from(50.),
1092 ScalarValue::from(50.),
1093 ScalarValue::Float64(None),
1094 Interval::make(Some(0.), Some(100.))?,
1095 ),
1096 true,
1097 ),
1098 (
1099 Distribution::new_generic(
1100 ScalarValue::from(50.),
1101 ScalarValue::from(50.),
1102 ScalarValue::Float64(None),
1103 Interval::make(Some(-100.), Some(0.))?,
1104 ),
1105 false,
1106 ),
1107 (
1108 Distribution::new_generic(
1109 ScalarValue::Float64(None),
1110 ScalarValue::Float64(None),
1111 ScalarValue::from(1.),
1112 Interval::make_zero(&DataType::Float64)?,
1113 ),
1114 true,
1115 ),
1116 (
1117 Distribution::new_generic(
1118 ScalarValue::Float64(None),
1119 ScalarValue::Float64(None),
1120 ScalarValue::from(-1.),
1121 Interval::make_zero(&DataType::Float64)?,
1122 ),
1123 false,
1124 ),
1125 ];
1126 for case in generic_dists {
1127 assert_eq!(case.0.is_ok(), case.1, "{:?}", case.0);
1128 }
1129
1130 Ok(())
1131 }
1132
1133 #[test]
1134 fn mean_extraction_test() -> Result<()> {
1135 let dists = vec![
1137 (
1138 Distribution::new_uniform(Interval::make_zero(&DataType::Int64)?),
1139 ScalarValue::from(0_i64),
1140 ),
1141 (
1142 Distribution::new_uniform(Interval::make_zero(&DataType::Float64)?),
1143 ScalarValue::from(0.),
1144 ),
1145 (
1146 Distribution::new_uniform(Interval::make(Some(1), Some(100))?),
1147 ScalarValue::from(50),
1148 ),
1149 (
1150 Distribution::new_uniform(Interval::make(Some(-100), Some(-1))?),
1151 ScalarValue::from(-50),
1152 ),
1153 (
1154 Distribution::new_uniform(Interval::make(Some(-100), Some(100))?),
1155 ScalarValue::from(0),
1156 ),
1157 (
1158 Distribution::new_exponential(
1159 ScalarValue::from(2.),
1160 ScalarValue::from(0.),
1161 true,
1162 ),
1163 ScalarValue::from(0.5),
1164 ),
1165 (
1166 Distribution::new_exponential(
1167 ScalarValue::from(2.),
1168 ScalarValue::from(1.),
1169 true,
1170 ),
1171 ScalarValue::from(1.5),
1172 ),
1173 (
1174 Distribution::new_gaussian(ScalarValue::from(0.), ScalarValue::from(1.)),
1175 ScalarValue::from(0.),
1176 ),
1177 (
1178 Distribution::new_gaussian(
1179 ScalarValue::from(-2.),
1180 ScalarValue::from(0.5),
1181 ),
1182 ScalarValue::from(-2.),
1183 ),
1184 (
1185 Distribution::new_bernoulli(ScalarValue::from(0.5)),
1186 ScalarValue::from(0.5),
1187 ),
1188 (
1189 Distribution::new_generic(
1190 ScalarValue::from(42.),
1191 ScalarValue::from(42.),
1192 ScalarValue::Float64(None),
1193 Interval::make(Some(25.), Some(50.))?,
1194 ),
1195 ScalarValue::from(42.),
1196 ),
1197 ];
1198
1199 for case in dists {
1200 assert_eq!(case.0?.mean()?, case.1);
1201 }
1202
1203 Ok(())
1204 }
1205
1206 #[test]
1207 fn median_extraction_test() -> Result<()> {
1208 let dists = vec![
1210 (
1211 Distribution::new_uniform(Interval::make_zero(&DataType::Int64)?),
1212 ScalarValue::from(0_i64),
1213 ),
1214 (
1215 Distribution::new_uniform(Interval::make(Some(25.), Some(75.))?),
1216 ScalarValue::from(50.),
1217 ),
1218 (
1219 Distribution::new_exponential(
1220 ScalarValue::from(2_f64.ln()),
1221 ScalarValue::from(0.),
1222 true,
1223 ),
1224 ScalarValue::from(1.),
1225 ),
1226 (
1227 Distribution::new_gaussian(ScalarValue::from(2.), ScalarValue::from(1.)),
1228 ScalarValue::from(2.),
1229 ),
1230 (
1231 Distribution::new_bernoulli(ScalarValue::from(0.25)),
1232 ScalarValue::from(0.),
1233 ),
1234 (
1235 Distribution::new_bernoulli(ScalarValue::from(0.75)),
1236 ScalarValue::from(1.),
1237 ),
1238 (
1239 Distribution::new_gaussian(ScalarValue::from(2.), ScalarValue::from(1.)),
1240 ScalarValue::from(2.),
1241 ),
1242 (
1243 Distribution::new_generic(
1244 ScalarValue::from(12.),
1245 ScalarValue::from(12.),
1246 ScalarValue::Float64(None),
1247 Interval::make(Some(0.), Some(25.))?,
1248 ),
1249 ScalarValue::from(12.),
1250 ),
1251 ];
1252
1253 for case in dists {
1254 assert_eq!(case.0?.median()?, case.1);
1255 }
1256
1257 Ok(())
1258 }
1259
1260 #[test]
1261 fn variance_extraction_test() -> Result<()> {
1262 let dists = vec![
1264 (
1265 Distribution::new_uniform(Interval::make(Some(0.), Some(12.))?),
1266 ScalarValue::from(12.),
1267 ),
1268 (
1269 Distribution::new_exponential(
1270 ScalarValue::from(10.),
1271 ScalarValue::from(0.),
1272 true,
1273 ),
1274 ScalarValue::from(0.01),
1275 ),
1276 (
1277 Distribution::new_gaussian(ScalarValue::from(0.), ScalarValue::from(1.)),
1278 ScalarValue::from(1.),
1279 ),
1280 (
1281 Distribution::new_bernoulli(ScalarValue::from(0.5)),
1282 ScalarValue::from(0.25),
1283 ),
1284 (
1285 Distribution::new_generic(
1286 ScalarValue::Float64(None),
1287 ScalarValue::Float64(None),
1288 ScalarValue::from(0.02),
1289 Interval::make_zero(&DataType::Float64)?,
1290 ),
1291 ScalarValue::from(0.02),
1292 ),
1293 ];
1294
1295 for case in dists {
1296 assert_eq!(case.0?.variance()?, case.1);
1297 }
1298
1299 Ok(())
1300 }
1301
1302 #[test]
1303 fn test_calculate_generic_properties_gauss_gauss() -> Result<()> {
1304 let dist_a =
1305 Distribution::new_gaussian(ScalarValue::from(10.), ScalarValue::from(0.0))?;
1306 let dist_b =
1307 Distribution::new_gaussian(ScalarValue::from(20.), ScalarValue::from(0.0))?;
1308
1309 let test_data = vec![
1310 (
1312 compute_mean(&Operator::Plus, &dist_a, &dist_b)?,
1313 ScalarValue::from(30.),
1314 ),
1315 (
1316 compute_mean(&Operator::Minus, &dist_a, &dist_b)?,
1317 ScalarValue::from(-10.),
1318 ),
1319 (
1321 compute_median(&Operator::Plus, &dist_a, &dist_b)?,
1322 ScalarValue::from(30.),
1323 ),
1324 (
1325 compute_median(&Operator::Minus, &dist_a, &dist_b)?,
1326 ScalarValue::from(-10.),
1327 ),
1328 ];
1329 for (actual, expected) in test_data {
1330 assert_eq!(actual, expected);
1331 }
1332
1333 Ok(())
1334 }
1335
1336 #[test]
1337 fn test_combine_bernoullis_and_op() -> Result<()> {
1338 let op = Operator::And;
1339 let left = BernoulliDistribution::try_new(ScalarValue::from(0.5))?;
1340 let right = BernoulliDistribution::try_new(ScalarValue::from(0.4))?;
1341 let left_null = BernoulliDistribution::try_new(ScalarValue::Null)?;
1342 let right_null = BernoulliDistribution::try_new(ScalarValue::Null)?;
1343
1344 assert_eq!(
1345 combine_bernoullis(&op, &left, &right)?.p_value(),
1346 &ScalarValue::from(0.5 * 0.4)
1347 );
1348 assert_eq!(
1349 combine_bernoullis(&op, &left_null, &right)?.p_value(),
1350 &ScalarValue::Float64(None)
1351 );
1352 assert_eq!(
1353 combine_bernoullis(&op, &left, &right_null)?.p_value(),
1354 &ScalarValue::Float64(None)
1355 );
1356 assert_eq!(
1357 combine_bernoullis(&op, &left_null, &left_null)?.p_value(),
1358 &ScalarValue::Null
1359 );
1360
1361 Ok(())
1362 }
1363
1364 #[test]
1365 fn test_combine_bernoullis_or_op() -> Result<()> {
1366 let op = Operator::Or;
1367 let left = BernoulliDistribution::try_new(ScalarValue::from(0.6))?;
1368 let right = BernoulliDistribution::try_new(ScalarValue::from(0.4))?;
1369 let left_null = BernoulliDistribution::try_new(ScalarValue::Null)?;
1370 let right_null = BernoulliDistribution::try_new(ScalarValue::Null)?;
1371
1372 assert_eq!(
1373 combine_bernoullis(&op, &left, &right)?.p_value(),
1374 &ScalarValue::from(0.6 + 0.4 - (0.6 * 0.4))
1375 );
1376 assert_eq!(
1377 combine_bernoullis(&op, &left_null, &right)?.p_value(),
1378 &ScalarValue::Float64(None)
1379 );
1380 assert_eq!(
1381 combine_bernoullis(&op, &left, &right_null)?.p_value(),
1382 &ScalarValue::Float64(None)
1383 );
1384 assert_eq!(
1385 combine_bernoullis(&op, &left_null, &left_null)?.p_value(),
1386 &ScalarValue::Null
1387 );
1388
1389 Ok(())
1390 }
1391
1392 #[test]
1393 fn test_combine_bernoullis_unsupported_ops() -> Result<()> {
1394 let mut operator_set = operator_set();
1395 operator_set.remove(&Operator::And);
1396 operator_set.remove(&Operator::Or);
1397
1398 let left = BernoulliDistribution::try_new(ScalarValue::from(0.6))?;
1399 let right = BernoulliDistribution::try_new(ScalarValue::from(0.4))?;
1400 for op in operator_set {
1401 assert!(
1402 combine_bernoullis(&op, &left, &right).is_err(),
1403 "Operator {op} should not be supported for Bernoulli distributions"
1404 );
1405 }
1406
1407 Ok(())
1408 }
1409
1410 #[test]
1411 fn test_combine_gaussians_addition() -> Result<()> {
1412 let op = Operator::Plus;
1413 let left = GaussianDistribution::try_new(
1414 ScalarValue::from(3.0),
1415 ScalarValue::from(2.0),
1416 )?;
1417 let right = GaussianDistribution::try_new(
1418 ScalarValue::from(4.0),
1419 ScalarValue::from(1.0),
1420 )?;
1421
1422 let result = combine_gaussians(&op, &left, &right)?.unwrap();
1423
1424 assert_eq!(result.mean(), &ScalarValue::from(7.0)); assert_eq!(result.variance(), &ScalarValue::from(3.0)); Ok(())
1427 }
1428
/// Checks that combining two Gaussians with `Minus` subtracts the means
/// while the variances are still added.
#[test]
fn test_combine_gaussians_subtraction() -> Result<()> {
    // Small constructor helper: (mean, variance) -> Gaussian.
    let gaussian = |mean: f64, variance: f64| {
        GaussianDistribution::try_new(ScalarValue::from(mean), ScalarValue::from(variance))
    };
    let lhs = gaussian(7.0, 2.0)?;
    let rhs = gaussian(4.0, 1.0)?;

    let difference = combine_gaussians(&Operator::Minus, &lhs, &rhs)?.unwrap();

    // Mean: 7.0 - 4.0
    assert_eq!(difference.mean(), &ScalarValue::from(3.0));
    // Variance: 2.0 + 1.0
    assert_eq!(difference.variance(), &ScalarValue::from(3.0));

    Ok(())
}
1448
/// Checks that `combine_gaussians` succeeds but yields `None` for every
/// operator in `operator_set()` other than `Plus` and `Minus`.
#[test]
fn test_combine_gaussians_unsupported_ops() -> Result<()> {
    let lhs =
        GaussianDistribution::try_new(ScalarValue::from(7.0), ScalarValue::from(2.0))?;
    let rhs =
        GaussianDistribution::try_new(ScalarValue::from(4.0), ScalarValue::from(1.0))?;

    // Skip the two arithmetic operators that do have a closed form.
    let unsupported = operator_set()
        .into_iter()
        .filter(|op| !matches!(op, Operator::Plus | Operator::Minus));

    for op in unsupported {
        let combined = combine_gaussians(&op, &lhs, &rhs)?;
        assert!(
            combined.is_none(),
            "Operator {op} should not be supported for Gaussian distributions"
        );
    }

    Ok(())
}
1472
/// Checks `compute_mean`, `compute_median` and `compute_variance` for a pair
/// of uniform distributions over `[0, 12]` and `[12, 36]`.
#[test]
fn test_calculate_generic_properties_uniform_uniform() -> Result<()> {
    let dist_a = Distribution::new_uniform(Interval::make(Some(0.), Some(12.))?)?;
    let dist_b = Distribution::new_uniform(Interval::make(Some(12.), Some(36.))?)?;

    let plus = Operator::Plus;
    let minus = Operator::Minus;
    let times = Operator::Multiply;

    // All statistics are computed up front; each entry pairs the computed
    // value with the expected number.
    let cases: [(ScalarValue, f64); 8] = [
        // Mean:
        (compute_mean(&plus, &dist_a, &dist_b)?, 30.),
        (compute_mean(&minus, &dist_a, &dist_b)?, -18.),
        (compute_mean(&times, &dist_a, &dist_b)?, 144.),
        // Median:
        (compute_median(&plus, &dist_a, &dist_b)?, 30.),
        (compute_median(&minus, &dist_a, &dist_b)?, -18.),
        // Variance:
        (compute_variance(&plus, &dist_a, &dist_b)?, 60.),
        (compute_variance(&minus, &dist_a, &dist_b)?, 60.),
        (compute_variance(&times, &dist_a, &dist_b)?, 9216.),
    ];

    for (actual, expected) in cases {
        assert_eq!(actual, ScalarValue::from(expected));
    }

    Ok(())
}
1527
/// Checks that, for every `Uniform`/`Generic` pairing of two distributions
/// over `[0, 12]`, the range of the derived distribution equals the result of
/// applying the operator directly to the two intervals.
#[test]
fn test_compute_range_where_present() -> Result<()> {
    let a = Interval::make(Some(0.), Some(12.0))?;
    let b = Interval::make(Some(0.), Some(12.0))?;
    let center = ScalarValue::from(6.0);

    // Generic distribution with known mean/median, unknown variance, and the
    // given range.
    let generic_over = |range: &Interval| {
        Distribution::new_generic(
            center.clone(),
            center.clone(),
            ScalarValue::Float64(None),
            range.clone(),
        )
    };

    let pairs = [
        (
            Distribution::new_uniform(a.clone())?,
            Distribution::new_uniform(b.clone())?,
        ),
        (generic_over(&a)?, Distribution::new_uniform(b.clone())?),
        (Distribution::new_uniform(a.clone())?, generic_over(&b)?),
        (generic_over(&a)?, generic_over(&b)?),
    ];

    let arithmetic = [
        Operator::Plus,
        Operator::Minus,
        Operator::Multiply,
        Operator::Divide,
    ];
    let comparisons = [
        Operator::Gt,
        Operator::GtEq,
        Operator::Lt,
        Operator::LtEq,
        Operator::Eq,
        Operator::NotEq,
    ];

    for (dist_a, dist_b) in pairs {
        // Arithmetic operators go through the generic binary-op constructor.
        for op in &arithmetic {
            let actual = new_generic_from_binary_op(op, &dist_a, &dist_b)?.range()?;
            let expected = apply_operator(op, &a, &b)?;
            assert_eq!(actual, expected, "Failed for {dist_a:?} {op} {dist_b:?}");
        }
        // Comparison operators go through the Bernoulli constructor.
        for op in &comparisons {
            let actual =
                create_bernoulli_from_comparison(op, &dist_a, &dist_b)?.range()?;
            let expected = apply_operator(op, &a, &b)?;
            assert_eq!(actual, expected, "Failed for {dist_a:?} {op} {dist_b:?}");
        }
    }

    Ok(())
}
1594
1595 fn operator_set() -> HashSet<Operator> {
1596 use super::Operator::*;
1597
1598 let all_ops = vec![
1599 And,
1600 Or,
1601 Eq,
1602 NotEq,
1603 Gt,
1604 GtEq,
1605 Lt,
1606 LtEq,
1607 Plus,
1608 Minus,
1609 Multiply,
1610 Divide,
1611 Modulo,
1612 IsDistinctFrom,
1613 IsNotDistinctFrom,
1614 RegexMatch,
1615 RegexIMatch,
1616 RegexNotMatch,
1617 RegexNotIMatch,
1618 LikeMatch,
1619 ILikeMatch,
1620 NotLikeMatch,
1621 NotILikeMatch,
1622 BitwiseAnd,
1623 BitwiseOr,
1624 BitwiseXor,
1625 BitwiseShiftRight,
1626 BitwiseShiftLeft,
1627 StringConcat,
1628 AtArrow,
1629 ArrowAt,
1630 ];
1631
1632 all_ops.into_iter().collect()
1633 }
1634}