1#![allow(deprecated)]
27
28use std::f64::consts::LN_2;
29
30use crate::interval_arithmetic::{Interval, apply_operator};
31use crate::operator::Operator;
32use crate::type_coercion::binary::binary_numeric_coercion;
33
34use arrow::array::ArrowNativeTypeOp;
35use arrow::datatypes::DataType;
36use datafusion_common::rounding::alter_fp_rounding_mode;
37use datafusion_common::{
38 Result, ScalarValue, assert_eq_or_internal_err, assert_ne_or_internal_err,
39 assert_or_internal_err, internal_err, not_impl_err,
40};
41
/// A statistical description of a quantity, modeled as one of several
/// distribution families.
///
/// - `Uniform`: uniform over an [`Interval`].
/// - `Exponential`: exponential with a rate, an offset and a tail direction.
/// - `Gaussian`: normal with a mean and a variance.
/// - `Bernoulli`: a boolean outcome with success probability `p`. Boolean
///   quantities must use this variant, because the `Uniform` and `Generic`
///   constructors reject boolean ranges.
/// - `Generic`: an otherwise unknown distribution summarized by a (possibly
///   unknown, i.e. NULL) mean, median and variance plus a range.
///
/// The fields of the wrapped structs are private, so values are built through
/// the validating `Distribution::new_*` constructors.
#[deprecated(
    since = "54.0.0",
    note = "Part of the unused Statistics V2 framework; see https://github.com/apache/datafusion/pull/22071"
)]
#[derive(Clone, Debug, PartialEq)]
pub enum Distribution {
    /// Uniform distribution over an interval.
    Uniform(UniformDistribution),
    /// Exponential distribution anchored at an offset.
    Exponential(ExponentialDistribution),
    /// Gaussian (normal) distribution.
    Gaussian(GaussianDistribution),
    /// Bernoulli distribution of a boolean outcome.
    Bernoulli(BernoulliDistribution),
    /// Distribution known only through summary statistics and a range.
    Generic(GenericDistribution),
}
62
63use Distribution::{Bernoulli, Exponential, Gaussian, Generic, Uniform};
64
65impl Distribution {
66 pub fn new_uniform(interval: Interval) -> Result<Self> {
68 UniformDistribution::try_new(interval).map(Uniform)
69 }
70
71 pub fn new_exponential(
74 rate: ScalarValue,
75 offset: ScalarValue,
76 positive_tail: bool,
77 ) -> Result<Self> {
78 ExponentialDistribution::try_new(rate, offset, positive_tail).map(Exponential)
79 }
80
81 pub fn new_gaussian(mean: ScalarValue, variance: ScalarValue) -> Result<Self> {
84 GaussianDistribution::try_new(mean, variance).map(Gaussian)
85 }
86
87 pub fn new_bernoulli(p: ScalarValue) -> Result<Self> {
90 BernoulliDistribution::try_new(p).map(Bernoulli)
91 }
92
93 pub fn new_generic(
96 mean: ScalarValue,
97 median: ScalarValue,
98 variance: ScalarValue,
99 range: Interval,
100 ) -> Result<Self> {
101 GenericDistribution::try_new(mean, median, variance, range).map(Generic)
102 }
103
104 pub fn new_from_interval(range: Interval) -> Result<Self> {
107 let null = ScalarValue::try_from(range.data_type())?;
108 Distribution::new_generic(null.clone(), null.clone(), null, range)
109 }
110
111 pub fn mean(&self) -> Result<ScalarValue> {
122 match &self {
123 Uniform(u) => u.mean(),
124 Exponential(e) => e.mean(),
125 Gaussian(g) => Ok(g.mean().clone()),
126 Bernoulli(b) => Ok(b.mean().clone()),
127 Generic(u) => Ok(u.mean().clone()),
128 }
129 }
130
131 pub fn median(&self) -> Result<ScalarValue> {
144 match &self {
145 Uniform(u) => u.median(),
146 Exponential(e) => e.median(),
147 Gaussian(g) => Ok(g.median().clone()),
148 Bernoulli(b) => b.median(),
149 Generic(u) => Ok(u.median().clone()),
150 }
151 }
152
153 pub fn variance(&self) -> Result<ScalarValue> {
165 match &self {
166 Uniform(u) => u.variance(),
167 Exponential(e) => e.variance(),
168 Gaussian(g) => Ok(g.variance.clone()),
169 Bernoulli(b) => b.variance(),
170 Generic(u) => Ok(u.variance.clone()),
171 }
172 }
173
174 pub fn range(&self) -> Result<Interval> {
185 match &self {
186 Uniform(u) => Ok(u.range().clone()),
187 Exponential(e) => e.range(),
188 Gaussian(g) => g.range(),
189 Bernoulli(b) => Ok(b.range()),
190 Generic(u) => Ok(u.range().clone()),
191 }
192 }
193
194 pub fn data_type(&self) -> DataType {
197 match &self {
198 Uniform(u) => u.data_type(),
199 Exponential(e) => e.data_type(),
200 Gaussian(g) => g.data_type(),
201 Bernoulli(b) => b.data_type(),
202 Generic(u) => u.data_type(),
203 }
204 }
205
206 pub fn target_type(args: &[&ScalarValue]) -> Result<DataType> {
207 let mut arg_types = args
208 .iter()
209 .filter(|&&arg| arg != &ScalarValue::Null)
210 .map(|&arg| arg.data_type());
211
212 let Some(dt) = arg_types.next().map_or_else(
213 || Some(DataType::Null),
214 |first| {
215 arg_types
216 .try_fold(first, |target, arg| binary_numeric_coercion(&target, &arg))
217 },
218 ) else {
219 return internal_err!("Can only evaluate statistics for numeric types");
220 };
221 Ok(dt)
222 }
223}
224
/// A uniform distribution over a non-boolean [`Interval`].
///
/// Boolean intervals are rejected at construction time; boolean quantities
/// are modeled with [`BernoulliDistribution`] instead.
#[deprecated(
    since = "54.0.0",
    note = "Part of the unused Statistics V2 framework; see https://github.com/apache/datafusion/pull/22071"
)]
#[derive(Clone, Debug, PartialEq)]
pub struct UniformDistribution {
    /// The support of the distribution; every value in it is equally likely.
    interval: Interval,
}
239
/// An exponential distribution anchored at `offset`, whose tail extends
/// either upward or downward from the offset.
///
/// Its statistics are derived from the parameters:
/// - mean `offset ± 1/rate`;
/// - median `offset ± ln(2)/rate`;
/// - variance `1/rate²`.
///
/// The sign follows the tail direction.
#[deprecated(
    since = "54.0.0",
    note = "Part of the unused Statistics V2 framework; see https://github.com/apache/datafusion/pull/22071"
)]
#[derive(Clone, Debug, PartialEq)]
pub struct ExponentialDistribution {
    /// Rate parameter; validated to be non-NULL and strictly positive.
    rate: ScalarValue,
    /// Anchor point of the distribution; validated to be non-NULL and of the
    /// same data type as `rate`.
    offset: ScalarValue,
    /// `true` if the range is `[offset, ∞)`, `false` if it is `(-∞, offset]`.
    /// The unbounded end is represented by a NULL endpoint.
    positive_tail: bool,
}
269
/// A Gaussian (normal) distribution described by its mean and variance.
///
/// Its median equals its mean, and its range is unbounded.
#[deprecated(
    since = "54.0.0",
    note = "Part of the unused Statistics V2 framework; see https://github.com/apache/datafusion/pull/22071"
)]
#[derive(Clone, Debug, PartialEq)]
pub struct GaussianDistribution {
    /// Mean (and median) of the distribution.
    mean: ScalarValue,
    /// Variance; validated to be non-NULL, non-negative and of the same data
    /// type as `mean`.
    variance: ScalarValue,
}
283
/// A Bernoulli distribution describing a boolean outcome.
#[deprecated(
    since = "54.0.0",
    note = "Part of the unused Statistics V2 framework; see https://github.com/apache/datafusion/pull/22071"
)]
#[derive(Clone, Debug, PartialEq)]
pub struct BernoulliDistribution {
    /// Success probability. When non-NULL it is validated to lie in `[0, 1]`;
    /// a NULL (typed or untyped) means the probability is unknown.
    p: ScalarValue,
}
296
/// A distribution of unknown shape, summarized by optional statistics and a
/// range of possible values.
///
/// Any of `mean`, `median` and `variance` may be NULL to indicate that the
/// statistic is unknown.
#[deprecated(
    since = "54.0.0",
    note = "Part of the unused Statistics V2 framework; see https://github.com/apache/datafusion/pull/22071"
)]
#[derive(Clone, Debug, PartialEq)]
pub struct GenericDistribution {
    /// Mean; when non-NULL it must lie inside `range`.
    mean: ScalarValue,
    /// Median; when non-NULL it must lie inside `range`.
    median: ScalarValue,
    /// Variance; when non-NULL it must not be negative.
    variance: ScalarValue,
    /// Interval of possible values; must not be boolean.
    range: Interval,
}
312
313impl UniformDistribution {
314 fn try_new(interval: Interval) -> Result<Self> {
315 assert_ne_or_internal_err!(
316 interval.data_type(),
317 DataType::Boolean,
318 "Construction of a boolean `Uniform` distribution is prohibited, create a `Bernoulli` distribution instead."
319 );
320
321 Ok(Self { interval })
322 }
323
324 pub fn data_type(&self) -> DataType {
325 self.interval.data_type()
326 }
327
328 pub fn mean(&self) -> Result<ScalarValue> {
332 let dt = self.data_type();
334 let two = ScalarValue::from(2).cast_to(&dt)?;
335 let result = self
336 .interval
337 .lower()
338 .add_checked(self.interval.upper())?
339 .div(two);
340 debug_assert!(
341 !self.interval.is_unbounded() || result.as_ref().is_ok_and(|r| r.is_null())
342 );
343 result
344 }
345
346 pub fn median(&self) -> Result<ScalarValue> {
347 self.mean()
348 }
349
350 pub fn variance(&self) -> Result<ScalarValue> {
354 let width = self.interval.width()?;
356 let dt = width.data_type();
357 let twelve = ScalarValue::from(12).cast_to(&dt)?;
358 let result = width.mul_checked(&width)?.div(twelve);
359 debug_assert!(
360 !self.interval.is_unbounded() || result.as_ref().is_ok_and(|r| r.is_null())
361 );
362 result
363 }
364
365 pub fn range(&self) -> &Interval {
366 &self.interval
367 }
368}
369
370impl ExponentialDistribution {
371 fn try_new(
372 rate: ScalarValue,
373 offset: ScalarValue,
374 positive_tail: bool,
375 ) -> Result<Self> {
376 let dt = rate.data_type();
377 assert_eq_or_internal_err!(
378 offset.data_type(),
379 dt,
380 "Rate and offset must have the same data type"
381 );
382 assert_or_internal_err!(
383 !offset.is_null(),
384 "Offset of an `ExponentialDistribution` cannot be null"
385 );
386 assert_or_internal_err!(
387 !rate.is_null(),
388 "Rate of an `ExponentialDistribution` cannot be null"
389 );
390 let zero = ScalarValue::new_zero(&dt)?;
391 assert_or_internal_err!(
392 !rate.le(&zero),
393 "Rate of an `ExponentialDistribution` must be positive"
394 );
395 Ok(Self {
396 rate,
397 offset,
398 positive_tail,
399 })
400 }
401
402 pub fn data_type(&self) -> DataType {
403 self.rate.data_type()
404 }
405
406 pub fn rate(&self) -> &ScalarValue {
407 &self.rate
408 }
409
410 pub fn offset(&self) -> &ScalarValue {
411 &self.offset
412 }
413
414 pub fn positive_tail(&self) -> bool {
415 self.positive_tail
416 }
417
418 pub fn mean(&self) -> Result<ScalarValue> {
419 let one = ScalarValue::new_one(&self.data_type())?;
421 let tail_mean = one.div(&self.rate)?;
422 if self.positive_tail {
423 self.offset.add_checked(tail_mean)
424 } else {
425 self.offset.sub_checked(tail_mean)
426 }
427 }
428
429 pub fn median(&self) -> Result<ScalarValue> {
430 let ln_two = ScalarValue::from(LN_2).cast_to(&self.data_type())?;
432 let tail_median = ln_two.div(&self.rate)?;
433 if self.positive_tail {
434 self.offset.add_checked(tail_median)
435 } else {
436 self.offset.sub_checked(tail_median)
437 }
438 }
439
440 pub fn variance(&self) -> Result<ScalarValue> {
441 let one = ScalarValue::new_one(&self.data_type())?;
443 let rate_squared = self.rate.mul_checked(&self.rate)?;
444 one.div(rate_squared)
445 }
446
447 pub fn range(&self) -> Result<Interval> {
448 let end = ScalarValue::try_from(&self.data_type())?;
449 if self.positive_tail {
450 Interval::try_new(self.offset.clone(), end)
451 } else {
452 Interval::try_new(end, self.offset.clone())
453 }
454 }
455}
456
457impl GaussianDistribution {
458 fn try_new(mean: ScalarValue, variance: ScalarValue) -> Result<Self> {
459 let dt = mean.data_type();
460 assert_eq_or_internal_err!(
461 variance.data_type(),
462 dt,
463 "Mean and variance must have the same data type"
464 );
465 assert_or_internal_err!(
466 !variance.is_null(),
467 "Variance of a `GaussianDistribution` cannot be null"
468 );
469 let zero = ScalarValue::new_zero(&dt)?;
470 assert_or_internal_err!(
471 !variance.lt(&zero),
472 "Variance of a `GaussianDistribution` must be positive"
473 );
474 Ok(Self { mean, variance })
475 }
476
477 pub fn data_type(&self) -> DataType {
478 self.mean.data_type()
479 }
480
481 pub fn mean(&self) -> &ScalarValue {
482 &self.mean
483 }
484
485 pub fn variance(&self) -> &ScalarValue {
486 &self.variance
487 }
488
489 pub fn median(&self) -> &ScalarValue {
490 self.mean()
491 }
492
493 pub fn range(&self) -> Result<Interval> {
494 Interval::make_unbounded(&self.data_type())
495 }
496}
497
498impl BernoulliDistribution {
499 fn try_new(p: ScalarValue) -> Result<Self> {
500 if p.is_null() {
501 return Ok(Self { p });
502 }
503 let dt = p.data_type();
504 let zero = ScalarValue::new_zero(&dt)?;
505 let one = ScalarValue::new_one(&dt)?;
506 assert_or_internal_err!(
507 p.ge(&zero) && p.le(&one),
508 "Success probability of a `BernoulliDistribution` must be in [0, 1]"
509 );
510 Ok(Self { p })
511 }
512
513 pub fn data_type(&self) -> DataType {
514 self.p.data_type()
515 }
516
517 pub fn p_value(&self) -> &ScalarValue {
518 &self.p
519 }
520
521 pub fn mean(&self) -> &ScalarValue {
522 &self.p
523 }
524
525 pub fn median(&self) -> Result<ScalarValue> {
528 let dt = self.data_type();
529 if self.p.is_null() {
530 ScalarValue::try_from(&dt)
531 } else {
532 let one = ScalarValue::new_one(&dt)?;
533 if one.sub_checked(&self.p)?.lt(&self.p) {
534 ScalarValue::new_one(&dt)
535 } else {
536 ScalarValue::new_zero(&dt)
537 }
538 }
539 }
540
541 pub fn variance(&self) -> Result<ScalarValue> {
544 let dt = self.data_type();
545 let one = ScalarValue::new_one(&dt)?;
546 let result = one.sub_checked(&self.p)?.mul_checked(&self.p);
547 debug_assert!(!self.p.is_null() || result.as_ref().is_ok_and(|r| r.is_null()));
548 result
549 }
550
551 pub fn range(&self) -> Interval {
552 let dt = self.data_type();
553 if ScalarValue::new_zero(&dt).unwrap().eq(&self.p) {
556 Interval::FALSE
557 } else if ScalarValue::new_one(&dt).unwrap().eq(&self.p) {
558 Interval::TRUE
559 } else {
560 Interval::TRUE_OR_FALSE
561 }
562 }
563}
564
565impl GenericDistribution {
566 fn try_new(
567 mean: ScalarValue,
568 median: ScalarValue,
569 variance: ScalarValue,
570 range: Interval,
571 ) -> Result<Self> {
572 assert_ne_or_internal_err!(
573 range.data_type(),
574 DataType::Boolean,
575 "Construction of a boolean `Generic` distribution is prohibited, create a `Bernoulli` distribution instead."
576 );
577
578 let validate_location = |m: &ScalarValue| -> Result<bool> {
579 if m.is_null() {
581 Ok(true)
582 } else {
583 range.contains_value(m)
584 }
585 };
586
587 let locations_valid = validate_location(&mean)? && validate_location(&median)?;
588 let variance_non_negative = if variance.is_null() {
589 true
590 } else {
591 let zero = ScalarValue::new_zero(&variance.data_type())?;
592 !variance.lt(&zero)
593 };
594 assert_or_internal_err!(
595 locations_valid && variance_non_negative,
596 "Tried to construct an invalid `GenericDistribution` instance"
597 );
598
599 Ok(Self {
600 mean,
601 median,
602 variance,
603 range,
604 })
605 }
606
607 pub fn data_type(&self) -> DataType {
608 self.mean.data_type()
609 }
610
611 pub fn mean(&self) -> &ScalarValue {
612 &self.mean
613 }
614
615 pub fn median(&self) -> &ScalarValue {
616 &self.median
617 }
618
619 pub fn variance(&self) -> &ScalarValue {
620 &self.variance
621 }
622
623 pub fn range(&self) -> &Interval {
624 &self.range
625 }
626}
627
628#[deprecated(
632 since = "54.0.0",
633 note = "Part of the unused Statistics V2 framework; see https://github.com/apache/datafusion/pull/22071"
634)]
635pub fn combine_bernoullis(
636 op: &Operator,
637 left: &BernoulliDistribution,
638 right: &BernoulliDistribution,
639) -> Result<BernoulliDistribution> {
640 let left_p = left.p_value();
641 let right_p = right.p_value();
642 match op {
643 Operator::And => match (left_p.is_null(), right_p.is_null()) {
644 (false, false) => {
645 BernoulliDistribution::try_new(left_p.mul_checked(right_p)?)
646 }
647 (false, true) if left_p.eq(&ScalarValue::new_zero(&left_p.data_type())?) => {
648 Ok(left.clone())
649 }
650 (true, false)
651 if right_p.eq(&ScalarValue::new_zero(&right_p.data_type())?) =>
652 {
653 Ok(right.clone())
654 }
655 _ => {
656 let dt = Distribution::target_type(&[left_p, right_p])?;
657 BernoulliDistribution::try_new(ScalarValue::try_from(&dt)?)
658 }
659 },
660 Operator::Or => match (left_p.is_null(), right_p.is_null()) {
661 (false, false) => {
662 let sum = left_p.add_checked(right_p)?;
663 let product = left_p.mul_checked(right_p)?;
664 let or_success = sum.sub_checked(product)?;
665 BernoulliDistribution::try_new(or_success)
666 }
667 (false, true) if left_p.eq(&ScalarValue::new_one(&left_p.data_type())?) => {
668 Ok(left.clone())
669 }
670 (true, false) if right_p.eq(&ScalarValue::new_one(&right_p.data_type())?) => {
671 Ok(right.clone())
672 }
673 _ => {
674 let dt = Distribution::target_type(&[left_p, right_p])?;
675 BernoulliDistribution::try_new(ScalarValue::try_from(&dt)?)
676 }
677 },
678 _ => {
679 not_impl_err!("Statistical evaluation only supports AND and OR operators")
680 }
681 }
682}
683
684#[deprecated(
691 since = "54.0.0",
692 note = "Part of the unused Statistics V2 framework; see https://github.com/apache/datafusion/pull/22071"
693)]
694pub fn combine_gaussians(
695 op: &Operator,
696 left: &GaussianDistribution,
697 right: &GaussianDistribution,
698) -> Result<Option<GaussianDistribution>> {
699 match op {
700 Operator::Plus => GaussianDistribution::try_new(
701 left.mean().add_checked(right.mean())?,
702 left.variance().add_checked(right.variance())?,
703 )
704 .map(Some),
705 Operator::Minus => GaussianDistribution::try_new(
706 left.mean().sub_checked(right.mean())?,
707 left.variance().add_checked(right.variance())?,
708 )
709 .map(Some),
710 _ => Ok(None),
711 }
712}
713
/// Builds a Bernoulli distribution describing the outcome of `left op right`,
/// where `op` is expected to be a comparison operator.
///
/// For two uniform operands under `=` / `!=`, the success probability is
/// derived from interval cardinalities: `P(equal) = |L ∩ R| / (|L| * |R|)`.
/// Disjoint ranges give `0` for `=` and `1` for `!=`.
///
/// All other cases fall back to evaluating `op` over the operands' ranges:
/// - a certainly-false result gives `p = 0`;
/// - a certainly-true result gives `p = 1`;
/// - an uncertain result gives an unknown `p` (a NULL `Float64`).
///
/// The fallback also covers uniform `=`/`!=` when some cardinality is
/// unavailable.
///
/// # Errors
/// Returns an internal error when the range evaluation is not a boolean
/// interval, which means `op` was not a comparison operator. Interval and
/// arithmetic errors are propagated.
#[deprecated(
    since = "54.0.0",
    note = "Part of the unused Statistics V2 framework; see https://github.com/apache/datafusion/pull/22071"
)]
pub fn create_bernoulli_from_comparison(
    op: &Operator,
    left: &Distribution,
    right: &Distribution,
) -> Result<Distribution> {
    match (left, right) {
        (Uniform(left), Uniform(right)) => {
            match op {
                Operator::Eq | Operator::NotEq => {
                    let (li, ri) = (left.range(), right.range());
                    if let Some(intersection) = li.intersect(ri)? {
                        // Overlapping ranges: compute P(equal) from the
                        // cardinalities when all three are known; otherwise
                        // fall through to the range-based evaluation below.
                        if let (Some(lc), Some(rc), Some(ic)) = (
                            li.cardinality(),
                            ri.cardinality(),
                            intersection.cardinality(),
                        ) {
                            // Widen before multiplying so the product of the
                            // two cardinalities (presumably u64) cannot
                            // overflow.
                            let pairs = ((lc as u128) * (rc as u128)) as f64;
                            let p = (ic as f64).div_checked(pairs)?;
                            let mut p_value = ScalarValue::from(p);
                            if op == &Operator::NotEq {
                                // P(not equal) = 1 - P(equal), evaluated under
                                // the `false` (non-upper, presumably downward)
                                // rounding mode of `alter_fp_rounding_mode` —
                                // confirm against that helper.
                                let one = ScalarValue::from(1.0);
                                p_value = alter_fp_rounding_mode::<false, _>(
                                    &one,
                                    &p_value,
                                    |lhs, rhs| lhs.sub_checked(rhs),
                                )?;
                            };
                            return Distribution::new_bernoulli(p_value);
                        }
                    } else if op == &Operator::Eq {
                        // Disjoint ranges can never be equal.
                        return Distribution::new_bernoulli(ScalarValue::from(0.0));
                    } else {
                        // Disjoint ranges are always unequal.
                        return Distribution::new_bernoulli(ScalarValue::from(1.0));
                    }
                }
                Operator::Lt | Operator::LtEq | Operator::Gt | Operator::GtEq => {
                    // TODO: Inequalities could get an exact `p` instead of the
                    // range-based fallback; strict vs. non-strict comparisons
                    // may need different handling for real vs. integral types.
                }
                // Other operators use the range-based fallback.
                _ => {}
            }
        }
        (Gaussian(_), Gaussian(_)) => {
            // TODO: Gaussian comparisons could get an exact `p` instead of the
            // range-based fallback.
        }
        // Other distribution pairs use the range-based fallback.
        _ => {}
    }
    // Fallback: evaluate the comparison over the operand ranges.
    let (li, ri) = (left.range()?, right.range()?);
    let range_evaluation = apply_operator(op, &li, &ri)?;
    if range_evaluation.eq(&Interval::FALSE) {
        Distribution::new_bernoulli(ScalarValue::from(0.0))
    } else if range_evaluation.eq(&Interval::TRUE) {
        Distribution::new_bernoulli(ScalarValue::from(1.0))
    } else if range_evaluation.eq(&Interval::TRUE_OR_FALSE) {
        // Outcome is possible either way: success probability is unknown.
        Distribution::new_bernoulli(ScalarValue::try_from(&DataType::Float64)?)
    } else {
        internal_err!("This function must be called with a comparison operator")
    }
}
795
796#[deprecated(
801 since = "54.0.0",
802 note = "Part of the unused Statistics V2 framework; see https://github.com/apache/datafusion/pull/22071"
803)]
804pub fn new_generic_from_binary_op(
805 op: &Operator,
806 left: &Distribution,
807 right: &Distribution,
808) -> Result<Distribution> {
809 Distribution::new_generic(
810 compute_mean(op, left, right)?,
811 compute_median(op, left, right)?,
812 compute_variance(op, left, right)?,
813 apply_operator(op, &left.range()?, &right.range()?)?,
814 )
815}
816
817#[deprecated(
820 since = "54.0.0",
821 note = "Part of the unused Statistics V2 framework; see https://github.com/apache/datafusion/pull/22071"
822)]
823pub fn compute_mean(
824 op: &Operator,
825 left: &Distribution,
826 right: &Distribution,
827) -> Result<ScalarValue> {
828 let (left_mean, right_mean) = (left.mean()?, right.mean()?);
829
830 match op {
831 Operator::Plus => return left_mean.add_checked(right_mean),
832 Operator::Minus => return left_mean.sub_checked(right_mean),
833 Operator::Multiply => return left_mean.mul_checked(right_mean),
835 Operator::Divide => {}
843 _ => {}
845 }
846 let target_type = Distribution::target_type(&[&left_mean, &right_mean])?;
847 ScalarValue::try_from(target_type)
848}
849
850#[deprecated(
856 since = "54.0.0",
857 note = "Part of the unused Statistics V2 framework; see https://github.com/apache/datafusion/pull/22071"
858)]
859pub fn compute_median(
860 op: &Operator,
861 left: &Distribution,
862 right: &Distribution,
863) -> Result<ScalarValue> {
864 match (left, right) {
865 (Uniform(lu), Uniform(ru)) => {
866 let (left_median, right_median) = (lu.median()?, ru.median()?);
867 match op {
871 Operator::Plus => return left_median.add_checked(right_median),
872 Operator::Minus => return left_median.sub_checked(right_median),
873 _ => {}
875 }
876 }
877 (Gaussian(lg), Gaussian(rg)) => match op {
880 Operator::Plus => return lg.mean().add_checked(rg.mean()),
881 Operator::Minus => return lg.mean().sub_checked(rg.mean()),
882 _ => {}
884 },
885 _ => {}
887 }
888
889 let (left_median, right_median) = (left.median()?, right.median()?);
890 let target_type = Distribution::target_type(&[&left_median, &right_median])?;
891 ScalarValue::try_from(target_type)
892}
893
894#[deprecated(
897 since = "54.0.0",
898 note = "Part of the unused Statistics V2 framework; see https://github.com/apache/datafusion/pull/22071"
899)]
900pub fn compute_variance(
901 op: &Operator,
902 left: &Distribution,
903 right: &Distribution,
904) -> Result<ScalarValue> {
905 let (left_variance, right_variance) = (left.variance()?, right.variance()?);
906
907 match op {
908 Operator::Plus => return left_variance.add_checked(right_variance),
910 Operator::Minus => return left_variance.add_checked(right_variance),
912 Operator::Multiply => {
914 let (left_mean, right_mean) = (left.mean()?, right.mean()?);
918 let left_mean_sq = left_mean.mul_checked(&left_mean)?;
919 let right_mean_sq = right_mean.mul_checked(&right_mean)?;
920 let left_sos = left_variance.add_checked(&left_mean_sq)?;
921 let right_sos = right_variance.add_checked(&right_mean_sq)?;
922 let pos = left_mean_sq.mul_checked(right_mean_sq)?;
923 return left_sos.mul_checked(right_sos)?.sub_checked(pos);
924 }
925 Operator::Divide => {}
933 _ => {}
935 }
936 let target_type = Distribution::target_type(&[&left_variance, &right_variance])?;
937 ScalarValue::try_from(target_type)
938}
939
940#[cfg(test)]
941mod tests {
942 use super::{
943 BernoulliDistribution, Distribution, GaussianDistribution, UniformDistribution,
944 combine_bernoullis, combine_gaussians, compute_mean, compute_median,
945 compute_variance, create_bernoulli_from_comparison, new_generic_from_binary_op,
946 };
947 use crate::interval_arithmetic::{Interval, apply_operator};
948 use crate::operator::Operator;
949
950 use arrow::datatypes::DataType;
951 use datafusion_common::{HashSet, Result, ScalarValue};
952
953 #[test]
954 fn uniform_dist_is_valid_test() -> Result<()> {
955 assert_eq!(
956 Distribution::new_uniform(Interval::make_zero(&DataType::Int8)?)?,
957 Distribution::Uniform(UniformDistribution {
958 interval: Interval::make_zero(&DataType::Int8)?,
959 })
960 );
961
962 assert!(Distribution::new_uniform(Interval::TRUE_OR_FALSE).is_err());
963 Ok(())
964 }
965
966 #[test]
967 fn exponential_dist_is_valid_test() {
968 let exponentials = vec![
970 (
971 Distribution::new_exponential(ScalarValue::Null, ScalarValue::Null, true),
972 false,
973 ),
974 (
975 Distribution::new_exponential(
976 ScalarValue::from(0_f32),
977 ScalarValue::from(1_f32),
978 true,
979 ),
980 false,
981 ),
982 (
983 Distribution::new_exponential(
984 ScalarValue::from(100_f32),
985 ScalarValue::from(1_f32),
986 true,
987 ),
988 true,
989 ),
990 (
991 Distribution::new_exponential(
992 ScalarValue::from(-100_f32),
993 ScalarValue::from(1_f32),
994 true,
995 ),
996 false,
997 ),
998 ];
999 for case in exponentials {
1000 assert_eq!(case.0.is_ok(), case.1);
1001 }
1002 }
1003
1004 #[test]
1005 fn gaussian_dist_is_valid_test() {
1006 let gaussians = vec![
1008 (
1009 Distribution::new_gaussian(ScalarValue::Null, ScalarValue::Null),
1010 false,
1011 ),
1012 (
1013 Distribution::new_gaussian(
1014 ScalarValue::from(0_f32),
1015 ScalarValue::from(0_f32),
1016 ),
1017 true,
1018 ),
1019 (
1020 Distribution::new_gaussian(
1021 ScalarValue::from(0_f32),
1022 ScalarValue::from(0.5_f32),
1023 ),
1024 true,
1025 ),
1026 (
1027 Distribution::new_gaussian(
1028 ScalarValue::from(0_f32),
1029 ScalarValue::from(-0.5_f32),
1030 ),
1031 false,
1032 ),
1033 ];
1034 for case in gaussians {
1035 assert_eq!(case.0.is_ok(), case.1);
1036 }
1037 }
1038
1039 #[test]
1040 fn bernoulli_dist_is_valid_test() {
1041 let bernoullis = vec![
1043 (Distribution::new_bernoulli(ScalarValue::Null), true),
1044 (Distribution::new_bernoulli(ScalarValue::from(0.)), true),
1045 (Distribution::new_bernoulli(ScalarValue::from(0.25)), true),
1046 (Distribution::new_bernoulli(ScalarValue::from(1.)), true),
1047 (Distribution::new_bernoulli(ScalarValue::from(11.)), false),
1048 (Distribution::new_bernoulli(ScalarValue::from(-11.)), false),
1049 (Distribution::new_bernoulli(ScalarValue::from(0_i64)), true),
1050 (Distribution::new_bernoulli(ScalarValue::from(1_i64)), true),
1051 (
1052 Distribution::new_bernoulli(ScalarValue::from(11_i64)),
1053 false,
1054 ),
1055 (
1056 Distribution::new_bernoulli(ScalarValue::from(-11_i64)),
1057 false,
1058 ),
1059 ];
1060 for case in bernoullis {
1061 assert_eq!(case.0.is_ok(), case.1);
1062 }
1063 }
1064
1065 #[test]
1066 fn generic_dist_is_valid_test() -> Result<()> {
1067 let generic_dists = vec![
1069 (
1071 Distribution::new_generic(
1072 ScalarValue::Null,
1073 ScalarValue::Null,
1074 ScalarValue::Null,
1075 Interval::TRUE_OR_FALSE,
1076 ),
1077 false,
1078 ),
1079 (
1080 Distribution::new_generic(
1081 ScalarValue::Null,
1082 ScalarValue::Null,
1083 ScalarValue::Null,
1084 Interval::make_zero(&DataType::Float32)?,
1085 ),
1086 true,
1087 ),
1088 (
1089 Distribution::new_generic(
1090 ScalarValue::from(0_f32),
1091 ScalarValue::Float32(None),
1092 ScalarValue::Float32(None),
1093 Interval::make_zero(&DataType::Float32)?,
1094 ),
1095 true,
1096 ),
1097 (
1098 Distribution::new_generic(
1099 ScalarValue::Float64(None),
1100 ScalarValue::from(0.),
1101 ScalarValue::Float64(None),
1102 Interval::make_zero(&DataType::Float32)?,
1103 ),
1104 true,
1105 ),
1106 (
1107 Distribution::new_generic(
1108 ScalarValue::from(-10_f32),
1109 ScalarValue::Float32(None),
1110 ScalarValue::Float32(None),
1111 Interval::make_zero(&DataType::Float32)?,
1112 ),
1113 false,
1114 ),
1115 (
1116 Distribution::new_generic(
1117 ScalarValue::Float32(None),
1118 ScalarValue::from(10_f32),
1119 ScalarValue::Float32(None),
1120 Interval::make_zero(&DataType::Float32)?,
1121 ),
1122 false,
1123 ),
1124 (
1125 Distribution::new_generic(
1126 ScalarValue::Null,
1127 ScalarValue::Null,
1128 ScalarValue::Null,
1129 Interval::make_zero(&DataType::Float32)?,
1130 ),
1131 true,
1132 ),
1133 (
1134 Distribution::new_generic(
1135 ScalarValue::from(0),
1136 ScalarValue::from(0),
1137 ScalarValue::Int32(None),
1138 Interval::make_zero(&DataType::Int32)?,
1139 ),
1140 true,
1141 ),
1142 (
1143 Distribution::new_generic(
1144 ScalarValue::from(0_f32),
1145 ScalarValue::from(0_f32),
1146 ScalarValue::Float32(None),
1147 Interval::make_zero(&DataType::Float32)?,
1148 ),
1149 true,
1150 ),
1151 (
1152 Distribution::new_generic(
1153 ScalarValue::from(50.),
1154 ScalarValue::from(50.),
1155 ScalarValue::Float64(None),
1156 Interval::make(Some(0.), Some(100.))?,
1157 ),
1158 true,
1159 ),
1160 (
1161 Distribution::new_generic(
1162 ScalarValue::from(50.),
1163 ScalarValue::from(50.),
1164 ScalarValue::Float64(None),
1165 Interval::make(Some(-100.), Some(0.))?,
1166 ),
1167 false,
1168 ),
1169 (
1170 Distribution::new_generic(
1171 ScalarValue::Float64(None),
1172 ScalarValue::Float64(None),
1173 ScalarValue::from(1.),
1174 Interval::make_zero(&DataType::Float64)?,
1175 ),
1176 true,
1177 ),
1178 (
1179 Distribution::new_generic(
1180 ScalarValue::Float64(None),
1181 ScalarValue::Float64(None),
1182 ScalarValue::from(-1.),
1183 Interval::make_zero(&DataType::Float64)?,
1184 ),
1185 false,
1186 ),
1187 ];
1188 for case in generic_dists {
1189 assert_eq!(case.0.is_ok(), case.1, "{:?}", case.0);
1190 }
1191
1192 Ok(())
1193 }
1194
1195 #[test]
1196 fn mean_extraction_test() -> Result<()> {
1197 let dists = vec![
1199 (
1200 Distribution::new_uniform(Interval::make_zero(&DataType::Int64)?),
1201 ScalarValue::from(0_i64),
1202 ),
1203 (
1204 Distribution::new_uniform(Interval::make_zero(&DataType::Float64)?),
1205 ScalarValue::from(0.),
1206 ),
1207 (
1208 Distribution::new_uniform(Interval::make(Some(1), Some(100))?),
1209 ScalarValue::from(50),
1210 ),
1211 (
1212 Distribution::new_uniform(Interval::make(Some(-100), Some(-1))?),
1213 ScalarValue::from(-50),
1214 ),
1215 (
1216 Distribution::new_uniform(Interval::make(Some(-100), Some(100))?),
1217 ScalarValue::from(0),
1218 ),
1219 (
1220 Distribution::new_exponential(
1221 ScalarValue::from(2.),
1222 ScalarValue::from(0.),
1223 true,
1224 ),
1225 ScalarValue::from(0.5),
1226 ),
1227 (
1228 Distribution::new_exponential(
1229 ScalarValue::from(2.),
1230 ScalarValue::from(1.),
1231 true,
1232 ),
1233 ScalarValue::from(1.5),
1234 ),
1235 (
1236 Distribution::new_gaussian(ScalarValue::from(0.), ScalarValue::from(1.)),
1237 ScalarValue::from(0.),
1238 ),
1239 (
1240 Distribution::new_gaussian(
1241 ScalarValue::from(-2.),
1242 ScalarValue::from(0.5),
1243 ),
1244 ScalarValue::from(-2.),
1245 ),
1246 (
1247 Distribution::new_bernoulli(ScalarValue::from(0.5)),
1248 ScalarValue::from(0.5),
1249 ),
1250 (
1251 Distribution::new_generic(
1252 ScalarValue::from(42.),
1253 ScalarValue::from(42.),
1254 ScalarValue::Float64(None),
1255 Interval::make(Some(25.), Some(50.))?,
1256 ),
1257 ScalarValue::from(42.),
1258 ),
1259 ];
1260
1261 for case in dists {
1262 assert_eq!(case.0?.mean()?, case.1);
1263 }
1264
1265 Ok(())
1266 }
1267
1268 #[test]
1269 fn median_extraction_test() -> Result<()> {
1270 let dists = vec![
1272 (
1273 Distribution::new_uniform(Interval::make_zero(&DataType::Int64)?),
1274 ScalarValue::from(0_i64),
1275 ),
1276 (
1277 Distribution::new_uniform(Interval::make(Some(25.), Some(75.))?),
1278 ScalarValue::from(50.),
1279 ),
1280 (
1281 Distribution::new_exponential(
1282 ScalarValue::from(2_f64.ln()),
1283 ScalarValue::from(0.),
1284 true,
1285 ),
1286 ScalarValue::from(1.),
1287 ),
1288 (
1289 Distribution::new_gaussian(ScalarValue::from(2.), ScalarValue::from(1.)),
1290 ScalarValue::from(2.),
1291 ),
1292 (
1293 Distribution::new_bernoulli(ScalarValue::from(0.25)),
1294 ScalarValue::from(0.),
1295 ),
1296 (
1297 Distribution::new_bernoulli(ScalarValue::from(0.75)),
1298 ScalarValue::from(1.),
1299 ),
1300 (
1301 Distribution::new_gaussian(ScalarValue::from(2.), ScalarValue::from(1.)),
1302 ScalarValue::from(2.),
1303 ),
1304 (
1305 Distribution::new_generic(
1306 ScalarValue::from(12.),
1307 ScalarValue::from(12.),
1308 ScalarValue::Float64(None),
1309 Interval::make(Some(0.), Some(25.))?,
1310 ),
1311 ScalarValue::from(12.),
1312 ),
1313 ];
1314
1315 for case in dists {
1316 assert_eq!(case.0?.median()?, case.1);
1317 }
1318
1319 Ok(())
1320 }
1321
1322 #[test]
1323 fn variance_extraction_test() -> Result<()> {
1324 let dists = vec![
1326 (
1327 Distribution::new_uniform(Interval::make(Some(0.), Some(12.))?),
1328 ScalarValue::from(12.),
1329 ),
1330 (
1331 Distribution::new_exponential(
1332 ScalarValue::from(10.),
1333 ScalarValue::from(0.),
1334 true,
1335 ),
1336 ScalarValue::from(0.01),
1337 ),
1338 (
1339 Distribution::new_gaussian(ScalarValue::from(0.), ScalarValue::from(1.)),
1340 ScalarValue::from(1.),
1341 ),
1342 (
1343 Distribution::new_bernoulli(ScalarValue::from(0.5)),
1344 ScalarValue::from(0.25),
1345 ),
1346 (
1347 Distribution::new_generic(
1348 ScalarValue::Float64(None),
1349 ScalarValue::Float64(None),
1350 ScalarValue::from(0.02),
1351 Interval::make_zero(&DataType::Float64)?,
1352 ),
1353 ScalarValue::from(0.02),
1354 ),
1355 ];
1356
1357 for case in dists {
1358 assert_eq!(case.0?.variance()?, case.1);
1359 }
1360
1361 Ok(())
1362 }
1363
1364 #[test]
1365 fn test_calculate_generic_properties_gauss_gauss() -> Result<()> {
1366 let dist_a =
1367 Distribution::new_gaussian(ScalarValue::from(10.), ScalarValue::from(0.0))?;
1368 let dist_b =
1369 Distribution::new_gaussian(ScalarValue::from(20.), ScalarValue::from(0.0))?;
1370
1371 let test_data = vec![
1372 (
1374 compute_mean(&Operator::Plus, &dist_a, &dist_b)?,
1375 ScalarValue::from(30.),
1376 ),
1377 (
1378 compute_mean(&Operator::Minus, &dist_a, &dist_b)?,
1379 ScalarValue::from(-10.),
1380 ),
1381 (
1383 compute_median(&Operator::Plus, &dist_a, &dist_b)?,
1384 ScalarValue::from(30.),
1385 ),
1386 (
1387 compute_median(&Operator::Minus, &dist_a, &dist_b)?,
1388 ScalarValue::from(-10.),
1389 ),
1390 ];
1391 for (actual, expected) in test_data {
1392 assert_eq!(actual, expected);
1393 }
1394
1395 Ok(())
1396 }
1397
1398 #[test]
1399 fn test_combine_bernoullis_and_op() -> Result<()> {
1400 let op = Operator::And;
1401 let left = BernoulliDistribution::try_new(ScalarValue::from(0.5))?;
1402 let right = BernoulliDistribution::try_new(ScalarValue::from(0.4))?;
1403 let left_null = BernoulliDistribution::try_new(ScalarValue::Null)?;
1404 let right_null = BernoulliDistribution::try_new(ScalarValue::Null)?;
1405
1406 assert_eq!(
1407 combine_bernoullis(&op, &left, &right)?.p_value(),
1408 &ScalarValue::from(0.5 * 0.4)
1409 );
1410 assert_eq!(
1411 combine_bernoullis(&op, &left_null, &right)?.p_value(),
1412 &ScalarValue::Float64(None)
1413 );
1414 assert_eq!(
1415 combine_bernoullis(&op, &left, &right_null)?.p_value(),
1416 &ScalarValue::Float64(None)
1417 );
1418 assert_eq!(
1419 combine_bernoullis(&op, &left_null, &left_null)?.p_value(),
1420 &ScalarValue::Null
1421 );
1422
1423 Ok(())
1424 }
1425
/// `OR` of two Bernoulli distributions: the expected probability is
/// `p + q - p * q`, and an unknown (null) probability on either side makes the
/// result unknown.
#[test]
fn test_combine_bernoullis_or_op() -> Result<()> {
    let op = Operator::Or;
    let left = BernoulliDistribution::try_new(ScalarValue::from(0.6))?;
    let right = BernoulliDistribution::try_new(ScalarValue::from(0.4))?;
    let left_null = BernoulliDistribution::try_new(ScalarValue::Null)?;
    let right_null = BernoulliDistribution::try_new(ScalarValue::Null)?;

    // Both probabilities known: expect p + q - p * q.
    assert_eq!(
        combine_bernoullis(&op, &left, &right)?.p_value(),
        &ScalarValue::from(0.6 + 0.4 - (0.6 * 0.4))
    );
    // Exactly one side unknown: expect a Float64-typed null.
    assert_eq!(
        combine_bernoullis(&op, &left_null, &right)?.p_value(),
        &ScalarValue::Float64(None)
    );
    assert_eq!(
        combine_bernoullis(&op, &left, &right_null)?.p_value(),
        &ScalarValue::Float64(None)
    );
    // Both sides unknown: expect an untyped null.
    assert_eq!(
        combine_bernoullis(&op, &left_null, &right_null)?.p_value(),
        &ScalarValue::Null
    );

    Ok(())
}
1453
#[test]
fn test_combine_bernoullis_unsupported_ops() -> Result<()> {
    let lhs = BernoulliDistribution::try_new(ScalarValue::from(0.6))?;
    let rhs = BernoulliDistribution::try_new(ScalarValue::from(0.4))?;

    // Only `AND` and `OR` are accepted for Bernoulli operands; every other
    // operator must make `combine_bernoullis` return an error.
    let unsupported = operator_set()
        .into_iter()
        .filter(|op| !matches!(op, Operator::And | Operator::Or));
    for op in unsupported {
        assert!(
            combine_bernoullis(&op, &lhs, &rhs).is_err(),
            "Operator {op} should not be supported for Bernoulli distributions"
        );
    }

    Ok(())
}
1471
/// Adding two Gaussians: the test expects both the means and the variances
/// to add.
#[test]
fn test_combine_gaussians_addition() -> Result<()> {
    let op = Operator::Plus;
    let left = GaussianDistribution::try_new(
        ScalarValue::from(3.0),
        ScalarValue::from(2.0),
    )?;
    let right = GaussianDistribution::try_new(
        ScalarValue::from(4.0),
        ScalarValue::from(1.0),
    )?;

    let result = combine_gaussians(&op, &left, &right)?
        .expect("adding two Gaussians should yield a Gaussian");

    // mean: 3 + 4, variance: 2 + 1
    assert_eq!(result.mean(), &ScalarValue::from(7.0));
    assert_eq!(result.variance(), &ScalarValue::from(3.0));

    Ok(())
}
1490
/// Subtracting two Gaussians: the test expects the means to subtract while
/// the variances still add.
#[test]
fn test_combine_gaussians_subtraction() -> Result<()> {
    let op = Operator::Minus;
    let left = GaussianDistribution::try_new(
        ScalarValue::from(7.0),
        ScalarValue::from(2.0),
    )?;
    let right = GaussianDistribution::try_new(
        ScalarValue::from(4.0),
        ScalarValue::from(1.0),
    )?;

    let result = combine_gaussians(&op, &left, &right)?
        .expect("subtracting two Gaussians should yield a Gaussian");

    // mean: 7 - 4, variance: 2 + 1
    assert_eq!(result.mean(), &ScalarValue::from(3.0));
    assert_eq!(result.variance(), &ScalarValue::from(3.0));

    Ok(())
}
1510
#[test]
fn test_combine_gaussians_unsupported_ops() -> Result<()> {
    let lhs = GaussianDistribution::try_new(
        ScalarValue::from(7.0),
        ScalarValue::from(2.0),
    )?;
    let rhs = GaussianDistribution::try_new(
        ScalarValue::from(4.0),
        ScalarValue::from(1.0),
    )?;

    // Apart from `+` and `-`, `combine_gaussians` is expected to succeed but
    // return `None` (no closed-form Gaussian result).
    let unsupported = operator_set()
        .into_iter()
        .filter(|op| !matches!(op, Operator::Plus | Operator::Minus));
    for op in unsupported {
        assert!(
            combine_gaussians(&op, &lhs, &rhs)?.is_none(),
            "Operator {op} should not be supported for Gaussian distributions"
        );
    }

    Ok(())
}
1534
#[test]
fn test_calculate_generic_properties_uniform_uniform() -> Result<()> {
    // Expected values follow from U(0, 12) (mean 6, variance 12) and
    // U(12, 36) (mean 24, variance 48), combined as independent variables.
    let lhs = Distribution::new_uniform(Interval::make(Some(0.), Some(12.))?)?;
    let rhs = Distribution::new_uniform(Interval::make(Some(12.), Some(36.))?)?;

    // All statistics are computed before any assertion runs.
    let expectations: [(ScalarValue, f64); 8] = [
        // Means.
        (compute_mean(&Operator::Plus, &lhs, &rhs)?, 30.),
        (compute_mean(&Operator::Minus, &lhs, &rhs)?, -18.),
        (compute_mean(&Operator::Multiply, &lhs, &rhs)?, 144.),
        // Medians.
        (compute_median(&Operator::Plus, &lhs, &rhs)?, 30.),
        (compute_median(&Operator::Minus, &lhs, &rhs)?, -18.),
        // Variances.
        (compute_variance(&Operator::Plus, &lhs, &rhs)?, 60.),
        (compute_variance(&Operator::Minus, &lhs, &rhs)?, 60.),
        (compute_variance(&Operator::Multiply, &lhs, &rhs)?, 9216.),
    ];
    for (computed, wanted) in expectations {
        assert_eq!(computed, ScalarValue::from(wanted));
    }

    Ok(())
}
1589
/// For every pairing of uniform / generic inputs over known ranges, the range
/// of the derived distribution must equal the interval-arithmetic result of
/// applying the same operator to the input ranges.
#[test]
fn test_compute_range_where_present() -> Result<()> {
    let a = &Interval::make(Some(0.), Some(12.0))?;
    let b = &Interval::make(Some(0.), Some(12.0))?;
    let mean = ScalarValue::from(6.0);

    // Generic distribution over `range` with known mean/median and an
    // unknown (null) variance.
    let generic = |range: &Interval| {
        Distribution::new_generic(
            mean.clone(),
            mean.clone(),
            ScalarValue::Float64(None),
            range.clone(),
        )
    };
    let uniform = |range: &Interval| Distribution::new_uniform(range.clone());

    // Operator lists are built once, with qualified paths, rather than via a
    // `use` inside the loop body (which also shadowed the prelude `Eq`).
    let arithmetic_ops = [
        Operator::Plus,
        Operator::Minus,
        Operator::Multiply,
        Operator::Divide,
    ];
    let comparison_ops = [
        Operator::Gt,
        Operator::GtEq,
        Operator::Lt,
        Operator::LtEq,
        Operator::Eq,
        Operator::NotEq,
    ];

    for (dist_a, dist_b) in [
        (uniform(a)?, uniform(b)?),
        (generic(a)?, uniform(b)?),
        (uniform(a)?, generic(b)?),
        (generic(a)?, generic(b)?),
    ] {
        for op in &arithmetic_ops {
            assert_eq!(
                new_generic_from_binary_op(op, &dist_a, &dist_b)?.range()?,
                apply_operator(op, a, b)?,
                "Failed for {dist_a:?} {op} {dist_b:?}"
            );
        }
        for op in &comparison_ops {
            assert_eq!(
                create_bernoulli_from_comparison(op, &dist_a, &dist_b)?.range()?,
                apply_operator(op, a, b)?,
                "Failed for {dist_a:?} {op} {dist_b:?}"
            );
        }
    }

    Ok(())
}
1656
1657 fn operator_set() -> HashSet<Operator> {
1658 use super::Operator::*;
1659
1660 let all_ops = vec![
1661 And,
1662 Or,
1663 Eq,
1664 NotEq,
1665 Gt,
1666 GtEq,
1667 Lt,
1668 LtEq,
1669 Plus,
1670 Minus,
1671 Multiply,
1672 Divide,
1673 Modulo,
1674 IsDistinctFrom,
1675 IsNotDistinctFrom,
1676 RegexMatch,
1677 RegexIMatch,
1678 RegexNotMatch,
1679 RegexNotIMatch,
1680 LikeMatch,
1681 ILikeMatch,
1682 NotLikeMatch,
1683 NotILikeMatch,
1684 BitwiseAnd,
1685 BitwiseOr,
1686 BitwiseXor,
1687 BitwiseShiftRight,
1688 BitwiseShiftLeft,
1689 StringConcat,
1690 AtArrow,
1691 ArrowAt,
1692 ];
1693
1694 all_ops.into_iter().collect()
1695 }
1696}