datafusion_expr_common/
statistics.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::f64::consts::LN_2;
19
20use crate::interval_arithmetic::{apply_operator, Interval};
21use crate::operator::Operator;
22use crate::type_coercion::binary::binary_numeric_coercion;
23
24use arrow::array::ArrowNativeTypeOp;
25use arrow::datatypes::DataType;
26use datafusion_common::rounding::alter_fp_rounding_mode;
27use datafusion_common::{internal_err, not_impl_err, Result, ScalarValue};
28
/// This object defines probabilistic distributions that encode uncertain
/// information about a single, scalar value. Currently, we support five core
/// statistical distributions. New variants will be added over time.
///
/// This object is the lowest-level object in the statistics hierarchy, and it
/// is the main unit of calculus when evaluating expressions in a statistical
/// context. Notions like column and table statistics are built on top of this
/// object and the operations it supports.
#[derive(Clone, Debug, PartialEq)]
pub enum Distribution {
    /// A uniform distribution over an interval; see [`UniformDistribution`].
    Uniform(UniformDistribution),
    /// A (possibly shifted) exponential distribution; see
    /// [`ExponentialDistribution`].
    Exponential(ExponentialDistribution),
    /// A Gaussian (normal) distribution; see [`GaussianDistribution`].
    Gaussian(GaussianDistribution),
    /// A Bernoulli distribution over a boolean outcome; see
    /// [`BernoulliDistribution`].
    Bernoulli(BernoulliDistribution),
    /// A distribution known only through summary statistics; see
    /// [`GenericDistribution`].
    Generic(GenericDistribution),
}
45
46use Distribution::{Bernoulli, Exponential, Gaussian, Generic, Uniform};
47
48impl Distribution {
49    /// Constructs a new [`Uniform`] distribution from the given [`Interval`].
50    pub fn new_uniform(interval: Interval) -> Result<Self> {
51        UniformDistribution::try_new(interval).map(Uniform)
52    }
53
54    /// Constructs a new [`Exponential`] distribution from the given rate/offset
55    /// pair, and validates the given parameters.
56    pub fn new_exponential(
57        rate: ScalarValue,
58        offset: ScalarValue,
59        positive_tail: bool,
60    ) -> Result<Self> {
61        ExponentialDistribution::try_new(rate, offset, positive_tail).map(Exponential)
62    }
63
64    /// Constructs a new [`Gaussian`] distribution from the given mean/variance
65    /// pair, and validates the given parameters.
66    pub fn new_gaussian(mean: ScalarValue, variance: ScalarValue) -> Result<Self> {
67        GaussianDistribution::try_new(mean, variance).map(Gaussian)
68    }
69
70    /// Constructs a new [`Bernoulli`] distribution from the given success
71    /// probability, and validates the given parameters.
72    pub fn new_bernoulli(p: ScalarValue) -> Result<Self> {
73        BernoulliDistribution::try_new(p).map(Bernoulli)
74    }
75
76    /// Constructs a new [`Generic`] distribution from the given mean, median,
77    /// variance, and range values after validating the given parameters.
78    pub fn new_generic(
79        mean: ScalarValue,
80        median: ScalarValue,
81        variance: ScalarValue,
82        range: Interval,
83    ) -> Result<Self> {
84        GenericDistribution::try_new(mean, median, variance, range).map(Generic)
85    }
86
87    /// Constructs a new [`Generic`] distribution from the given range. Other
88    /// parameters (mean, median and variance) are initialized with null values.
89    pub fn new_from_interval(range: Interval) -> Result<Self> {
90        let null = ScalarValue::try_from(range.data_type())?;
91        Distribution::new_generic(null.clone(), null.clone(), null, range)
92    }
93
94    /// Extracts the mean value of this uncertain quantity, depending on its
95    /// distribution:
96    /// - A [`Uniform`] distribution's interval determines its mean value, which
97    ///   is the arithmetic average of the interval endpoints.
98    /// - An [`Exponential`] distribution's mean is calculable by the formula
99    ///   `offset + 1 / λ`, where `λ` is the (non-negative) rate.
100    /// - A [`Gaussian`] distribution contains the mean explicitly.
101    /// - A [`Bernoulli`] distribution's mean is equal to its success probability `p`.
102    /// - A [`Generic`] distribution _may_ have it explicitly, or this information
103    ///   may be absent.
104    pub fn mean(&self) -> Result<ScalarValue> {
105        match &self {
106            Uniform(u) => u.mean(),
107            Exponential(e) => e.mean(),
108            Gaussian(g) => Ok(g.mean().clone()),
109            Bernoulli(b) => Ok(b.mean().clone()),
110            Generic(u) => Ok(u.mean().clone()),
111        }
112    }
113
114    /// Extracts the median value of this uncertain quantity, depending on its
115    /// distribution:
116    /// - A [`Uniform`] distribution's interval determines its median value, which
117    ///   is the arithmetic average of the interval endpoints.
118    /// - An [`Exponential`] distribution's median is calculable by the formula
119    ///   `offset + ln(2) / λ`, where `λ` is the (non-negative) rate.
120    /// - A [`Gaussian`] distribution's median is equal to its mean, which is
121    ///   specified explicitly.
122    /// - A [`Bernoulli`] distribution's median is `1` if `p > 0.5` and `0`
123    ///   otherwise, where `p` is the success probability.
124    /// - A [`Generic`] distribution _may_ have it explicitly, or this information
125    ///   may be absent.
126    pub fn median(&self) -> Result<ScalarValue> {
127        match &self {
128            Uniform(u) => u.median(),
129            Exponential(e) => e.median(),
130            Gaussian(g) => Ok(g.median().clone()),
131            Bernoulli(b) => b.median(),
132            Generic(u) => Ok(u.median().clone()),
133        }
134    }
135
136    /// Extracts the variance value of this uncertain quantity, depending on
137    /// its distribution:
138    /// - A [`Uniform`] distribution's interval determines its variance value, which
139    ///   is calculable by the formula `(upper - lower) ^ 2 / 12`.
140    /// - An [`Exponential`] distribution's variance is calculable by the formula
141    ///   `1 / (λ ^ 2)`, where `λ` is the (non-negative) rate.
142    /// - A [`Gaussian`] distribution's variance is specified explicitly.
143    /// - A [`Bernoulli`] distribution's median is given by the formula `p * (1 - p)`
144    ///   where `p` is the success probability.
145    /// - A [`Generic`] distribution _may_ have it explicitly, or this information
146    ///   may be absent.
147    pub fn variance(&self) -> Result<ScalarValue> {
148        match &self {
149            Uniform(u) => u.variance(),
150            Exponential(e) => e.variance(),
151            Gaussian(g) => Ok(g.variance.clone()),
152            Bernoulli(b) => b.variance(),
153            Generic(u) => Ok(u.variance.clone()),
154        }
155    }
156
157    /// Extracts the range of this uncertain quantity, depending on its
158    /// distribution:
159    /// - A [`Uniform`] distribution's range is simply its interval.
160    /// - An [`Exponential`] distribution's range is `[offset, +∞)`.
161    /// - A [`Gaussian`] distribution's range is unbounded.
162    /// - A [`Bernoulli`] distribution's range is [`Interval::UNCERTAIN`], if
163    ///   `p` is neither `0` nor `1`. Otherwise, it is [`Interval::CERTAINLY_FALSE`]
164    ///   and [`Interval::CERTAINLY_TRUE`], respectively.
165    /// - A [`Generic`] distribution is unbounded by default, but more information
166    ///   may be present.
167    pub fn range(&self) -> Result<Interval> {
168        match &self {
169            Uniform(u) => Ok(u.range().clone()),
170            Exponential(e) => e.range(),
171            Gaussian(g) => g.range(),
172            Bernoulli(b) => Ok(b.range()),
173            Generic(u) => Ok(u.range().clone()),
174        }
175    }
176
177    /// Returns the data type of the statistical parameters comprising this
178    /// distribution.
179    pub fn data_type(&self) -> DataType {
180        match &self {
181            Uniform(u) => u.data_type(),
182            Exponential(e) => e.data_type(),
183            Gaussian(g) => g.data_type(),
184            Bernoulli(b) => b.data_type(),
185            Generic(u) => u.data_type(),
186        }
187    }
188
189    pub fn target_type(args: &[&ScalarValue]) -> Result<DataType> {
190        let mut arg_types = args
191            .iter()
192            .filter(|&&arg| (arg != &ScalarValue::Null))
193            .map(|&arg| arg.data_type());
194
195        let Some(dt) = arg_types.next().map_or_else(
196            || Some(DataType::Null),
197            |first| {
198                arg_types
199                    .try_fold(first, |target, arg| binary_numeric_coercion(&target, &arg))
200            },
201        ) else {
202            return internal_err!("Can only evaluate statistics for numeric types");
203        };
204        Ok(dt)
205    }
206}
207
/// Uniform distribution, represented by its range. If the given range extends
/// towards infinity, the distribution will be improper -- which is OK. For a
/// more in-depth discussion, see:
///
/// <https://en.wikipedia.org/wiki/Continuous_uniform_distribution>
/// <https://en.wikipedia.org/wiki/Prior_probability#Improper_priors>
#[derive(Clone, Debug, PartialEq)]
pub struct UniformDistribution {
    /// The range over which the distribution is defined. The constructor
    /// rejects intervals with a boolean data type.
    interval: Interval,
}
218
/// Exponential distribution with an optional shift. The probability density
/// function (PDF) is defined as follows:
///
/// For a positive tail (when `positive_tail` is `true`):
///
/// `f(x; λ, offset) = λ exp(-λ (x - offset))    for x ≥ offset`
///
/// For a negative tail (when `positive_tail` is `false`):
///
/// `f(x; λ, offset) = λ exp(-λ (offset - x))    for x ≤ offset`
///
///
/// In both cases, the PDF is `0` outside the specified domain.
///
/// For more information, see:
///
/// <https://en.wikipedia.org/wiki/Exponential_distribution>
#[derive(Clone, Debug, PartialEq)]
pub struct ExponentialDistribution {
    /// The rate parameter `λ`. The constructor requires it to be non-null
    /// and strictly greater than zero.
    rate: ScalarValue,
    /// The shift of the distribution, i.e. the finite endpoint of its domain.
    /// The constructor requires it to be non-null and to have the same data
    /// type as `rate`.
    offset: ScalarValue,
    /// Indicates whether the exponential distribution has a positive tail;
    /// i.e. it extends towards positive infinity.
    positive_tail: bool,
}
244
/// Gaussian (normal) distribution, represented by its mean and variance.
/// For a more in-depth discussion, see:
///
/// <https://en.wikipedia.org/wiki/Normal_distribution>
#[derive(Clone, Debug, PartialEq)]
pub struct GaussianDistribution {
    /// The mean of the distribution, which is also its median.
    mean: ScalarValue,
    /// The variance of the distribution. The constructor requires it to be
    /// non-null, not less than zero, and of the same data type as `mean`.
    variance: ScalarValue,
}
254
/// Bernoulli distribution with success probability `p`. If `p` has a null value,
/// the success probability is unknown. For a more in-depth discussion, see:
///
/// <https://en.wikipedia.org/wiki/Bernoulli_distribution>
#[derive(Clone, Debug, PartialEq)]
pub struct BernoulliDistribution {
    /// The success probability. When non-null, the constructor requires it
    /// to lie in `[0, 1]`; a null value is accepted without further checks.
    p: ScalarValue,
}
263
/// A generic distribution whose functional form is not available, which is
/// approximated via some summary statistics. For a more in-depth discussion, see:
///
/// <https://en.wikipedia.org/wiki/Summary_statistics>
#[derive(Clone, Debug, PartialEq)]
pub struct GenericDistribution {
    /// The mean; may be null when unknown. When non-null, the constructor
    /// requires it to lie within `range`.
    mean: ScalarValue,
    /// The median; may be null when unknown. When non-null, the constructor
    /// requires it to lie within `range`.
    median: ScalarValue,
    /// The variance; may be null when unknown. When non-null, the constructor
    /// requires it to be not less than zero.
    variance: ScalarValue,
    /// The range of values the quantity can take. The constructor rejects
    /// ranges with a boolean data type.
    range: Interval,
}
275
276impl UniformDistribution {
277    fn try_new(interval: Interval) -> Result<Self> {
278        if interval.data_type().eq(&DataType::Boolean) {
279            return internal_err!(
280                "Construction of a boolean `Uniform` distribution is prohibited, create a `Bernoulli` distribution instead."
281            );
282        }
283
284        Ok(Self { interval })
285    }
286
287    pub fn data_type(&self) -> DataType {
288        self.interval.data_type()
289    }
290
291    /// Computes the mean value of this distribution. In case of improper
292    /// distributions (i.e. when the range is unbounded), the function returns
293    /// a `NULL` `ScalarValue`.
294    pub fn mean(&self) -> Result<ScalarValue> {
295        // TODO: Should we ensure that this always returns a real number data type?
296        let dt = self.data_type();
297        let two = ScalarValue::from(2).cast_to(&dt)?;
298        let result = self
299            .interval
300            .lower()
301            .add_checked(self.interval.upper())?
302            .div(two);
303        debug_assert!(
304            !self.interval.is_unbounded() || result.as_ref().is_ok_and(|r| r.is_null())
305        );
306        result
307    }
308
309    pub fn median(&self) -> Result<ScalarValue> {
310        self.mean()
311    }
312
313    /// Computes the variance value of this distribution. In case of improper
314    /// distributions (i.e. when the range is unbounded), the function returns
315    /// a `NULL` `ScalarValue`.
316    pub fn variance(&self) -> Result<ScalarValue> {
317        // TODO: Should we ensure that this always returns a real number data type?
318        let width = self.interval.width()?;
319        let dt = width.data_type();
320        let twelve = ScalarValue::from(12).cast_to(&dt)?;
321        let result = width.mul_checked(&width)?.div(twelve);
322        debug_assert!(
323            !self.interval.is_unbounded() || result.as_ref().is_ok_and(|r| r.is_null())
324        );
325        result
326    }
327
328    pub fn range(&self) -> &Interval {
329        &self.interval
330    }
331}
332
333impl ExponentialDistribution {
334    fn try_new(
335        rate: ScalarValue,
336        offset: ScalarValue,
337        positive_tail: bool,
338    ) -> Result<Self> {
339        let dt = rate.data_type();
340        if offset.data_type() != dt {
341            internal_err!("Rate and offset must have the same data type")
342        } else if offset.is_null() {
343            internal_err!("Offset of an `ExponentialDistribution` cannot be null")
344        } else if rate.is_null() {
345            internal_err!("Rate of an `ExponentialDistribution` cannot be null")
346        } else if rate.le(&ScalarValue::new_zero(&dt)?) {
347            internal_err!("Rate of an `ExponentialDistribution` must be positive")
348        } else {
349            Ok(Self {
350                rate,
351                offset,
352                positive_tail,
353            })
354        }
355    }
356
357    pub fn data_type(&self) -> DataType {
358        self.rate.data_type()
359    }
360
361    pub fn rate(&self) -> &ScalarValue {
362        &self.rate
363    }
364
365    pub fn offset(&self) -> &ScalarValue {
366        &self.offset
367    }
368
369    pub fn positive_tail(&self) -> bool {
370        self.positive_tail
371    }
372
373    pub fn mean(&self) -> Result<ScalarValue> {
374        // TODO: Should we ensure that this always returns a real number data type?
375        let one = ScalarValue::new_one(&self.data_type())?;
376        let tail_mean = one.div(&self.rate)?;
377        if self.positive_tail {
378            self.offset.add_checked(tail_mean)
379        } else {
380            self.offset.sub_checked(tail_mean)
381        }
382    }
383
384    pub fn median(&self) -> Result<ScalarValue> {
385        // TODO: Should we ensure that this always returns a real number data type?
386        let ln_two = ScalarValue::from(LN_2).cast_to(&self.data_type())?;
387        let tail_median = ln_two.div(&self.rate)?;
388        if self.positive_tail {
389            self.offset.add_checked(tail_median)
390        } else {
391            self.offset.sub_checked(tail_median)
392        }
393    }
394
395    pub fn variance(&self) -> Result<ScalarValue> {
396        // TODO: Should we ensure that this always returns a real number data type?
397        let one = ScalarValue::new_one(&self.data_type())?;
398        let rate_squared = self.rate.mul_checked(&self.rate)?;
399        one.div(rate_squared)
400    }
401
402    pub fn range(&self) -> Result<Interval> {
403        let end = ScalarValue::try_from(&self.data_type())?;
404        if self.positive_tail {
405            Interval::try_new(self.offset.clone(), end)
406        } else {
407            Interval::try_new(end, self.offset.clone())
408        }
409    }
410}
411
412impl GaussianDistribution {
413    fn try_new(mean: ScalarValue, variance: ScalarValue) -> Result<Self> {
414        let dt = mean.data_type();
415        if variance.data_type() != dt {
416            internal_err!("Mean and variance must have the same data type")
417        } else if variance.is_null() {
418            internal_err!("Variance of a `GaussianDistribution` cannot be null")
419        } else if variance.lt(&ScalarValue::new_zero(&dt)?) {
420            internal_err!("Variance of a `GaussianDistribution` must be positive")
421        } else {
422            Ok(Self { mean, variance })
423        }
424    }
425
426    pub fn data_type(&self) -> DataType {
427        self.mean.data_type()
428    }
429
430    pub fn mean(&self) -> &ScalarValue {
431        &self.mean
432    }
433
434    pub fn variance(&self) -> &ScalarValue {
435        &self.variance
436    }
437
438    pub fn median(&self) -> &ScalarValue {
439        self.mean()
440    }
441
442    pub fn range(&self) -> Result<Interval> {
443        Interval::make_unbounded(&self.data_type())
444    }
445}
446
447impl BernoulliDistribution {
448    fn try_new(p: ScalarValue) -> Result<Self> {
449        if p.is_null() {
450            Ok(Self { p })
451        } else {
452            let dt = p.data_type();
453            let zero = ScalarValue::new_zero(&dt)?;
454            let one = ScalarValue::new_one(&dt)?;
455            if p.ge(&zero) && p.le(&one) {
456                Ok(Self { p })
457            } else {
458                internal_err!(
459                    "Success probability of a `BernoulliDistribution` must be in [0, 1]"
460                )
461            }
462        }
463    }
464
465    pub fn data_type(&self) -> DataType {
466        self.p.data_type()
467    }
468
469    pub fn p_value(&self) -> &ScalarValue {
470        &self.p
471    }
472
473    pub fn mean(&self) -> &ScalarValue {
474        &self.p
475    }
476
477    /// Computes the median value of this distribution. In case of an unknown
478    /// success probability, the function returns a `NULL` `ScalarValue`.
479    pub fn median(&self) -> Result<ScalarValue> {
480        let dt = self.data_type();
481        if self.p.is_null() {
482            ScalarValue::try_from(&dt)
483        } else {
484            let one = ScalarValue::new_one(&dt)?;
485            if one.sub_checked(&self.p)?.lt(&self.p) {
486                ScalarValue::new_one(&dt)
487            } else {
488                ScalarValue::new_zero(&dt)
489            }
490        }
491    }
492
493    /// Computes the variance value of this distribution. In case of an unknown
494    /// success probability, the function returns a `NULL` `ScalarValue`.
495    pub fn variance(&self) -> Result<ScalarValue> {
496        let dt = self.data_type();
497        let one = ScalarValue::new_one(&dt)?;
498        let result = one.sub_checked(&self.p)?.mul_checked(&self.p);
499        debug_assert!(!self.p.is_null() || result.as_ref().is_ok_and(|r| r.is_null()));
500        result
501    }
502
503    pub fn range(&self) -> Interval {
504        let dt = self.data_type();
505        // Unwraps are safe as the constructor guarantees that the data type
506        // supports zero and one values.
507        if ScalarValue::new_zero(&dt).unwrap().eq(&self.p) {
508            Interval::CERTAINLY_FALSE
509        } else if ScalarValue::new_one(&dt).unwrap().eq(&self.p) {
510            Interval::CERTAINLY_TRUE
511        } else {
512            Interval::UNCERTAIN
513        }
514    }
515}
516
517impl GenericDistribution {
518    fn try_new(
519        mean: ScalarValue,
520        median: ScalarValue,
521        variance: ScalarValue,
522        range: Interval,
523    ) -> Result<Self> {
524        if range.data_type().eq(&DataType::Boolean) {
525            return internal_err!(
526                "Construction of a boolean `Generic` distribution is prohibited, create a `Bernoulli` distribution instead."
527            );
528        }
529
530        let validate_location = |m: &ScalarValue| -> Result<bool> {
531            // Checks whether the given location estimate is within the range.
532            if m.is_null() {
533                Ok(true)
534            } else {
535                range.contains_value(m)
536            }
537        };
538
539        if !validate_location(&mean)?
540            || !validate_location(&median)?
541            || (!variance.is_null()
542                && variance.lt(&ScalarValue::new_zero(&variance.data_type())?))
543        {
544            internal_err!("Tried to construct an invalid `GenericDistribution` instance")
545        } else {
546            Ok(Self {
547                mean,
548                median,
549                variance,
550                range,
551            })
552        }
553    }
554
555    pub fn data_type(&self) -> DataType {
556        self.mean.data_type()
557    }
558
559    pub fn mean(&self) -> &ScalarValue {
560        &self.mean
561    }
562
563    pub fn median(&self) -> &ScalarValue {
564        &self.median
565    }
566
567    pub fn variance(&self) -> &ScalarValue {
568        &self.variance
569    }
570
571    pub fn range(&self) -> &Interval {
572        &self.range
573    }
574}
575
576/// This function takes a logical operator and two Bernoulli distributions,
577/// and it returns a new Bernoulli distribution that represents the result of
578/// the operation. Currently, only `AND` and `OR` operations are supported.
579pub fn combine_bernoullis(
580    op: &Operator,
581    left: &BernoulliDistribution,
582    right: &BernoulliDistribution,
583) -> Result<BernoulliDistribution> {
584    let left_p = left.p_value();
585    let right_p = right.p_value();
586    match op {
587        Operator::And => match (left_p.is_null(), right_p.is_null()) {
588            (false, false) => {
589                BernoulliDistribution::try_new(left_p.mul_checked(right_p)?)
590            }
591            (false, true) if left_p.eq(&ScalarValue::new_zero(&left_p.data_type())?) => {
592                Ok(left.clone())
593            }
594            (true, false)
595                if right_p.eq(&ScalarValue::new_zero(&right_p.data_type())?) =>
596            {
597                Ok(right.clone())
598            }
599            _ => {
600                let dt = Distribution::target_type(&[left_p, right_p])?;
601                BernoulliDistribution::try_new(ScalarValue::try_from(&dt)?)
602            }
603        },
604        Operator::Or => match (left_p.is_null(), right_p.is_null()) {
605            (false, false) => {
606                let sum = left_p.add_checked(right_p)?;
607                let product = left_p.mul_checked(right_p)?;
608                let or_success = sum.sub_checked(product)?;
609                BernoulliDistribution::try_new(or_success)
610            }
611            (false, true) if left_p.eq(&ScalarValue::new_one(&left_p.data_type())?) => {
612                Ok(left.clone())
613            }
614            (true, false) if right_p.eq(&ScalarValue::new_one(&right_p.data_type())?) => {
615                Ok(right.clone())
616            }
617            _ => {
618                let dt = Distribution::target_type(&[left_p, right_p])?;
619                BernoulliDistribution::try_new(ScalarValue::try_from(&dt)?)
620            }
621        },
622        _ => {
623            not_impl_err!("Statistical evaluation only supports AND and OR operators")
624        }
625    }
626}
627
628/// Applies the given operation to the given Gaussian distributions. Currently,
629/// this function handles only addition and subtraction operations. If the
630/// result is not a Gaussian random variable, it returns `None`. For details,
631/// see:
632///
633/// <https://en.wikipedia.org/wiki/Sum_of_normally_distributed_random_variables>
634pub fn combine_gaussians(
635    op: &Operator,
636    left: &GaussianDistribution,
637    right: &GaussianDistribution,
638) -> Result<Option<GaussianDistribution>> {
639    match op {
640        Operator::Plus => GaussianDistribution::try_new(
641            left.mean().add_checked(right.mean())?,
642            left.variance().add_checked(right.variance())?,
643        )
644        .map(Some),
645        Operator::Minus => GaussianDistribution::try_new(
646            left.mean().sub_checked(right.mean())?,
647            left.variance().add_checked(right.variance())?,
648        )
649        .map(Some),
650        _ => Ok(None),
651    }
652}
653
/// Creates a new `Bernoulli` distribution by computing the resulting probability.
/// Expects `op` to be a comparison operator, with `left` and `right` having
/// numeric distributions. The resulting distribution has the `Float64` data
/// type.
///
/// The function first tries a distribution-specific calculation (currently
/// only equality/inequality of two `Uniform` distributions whose
/// cardinalities are known). Every case that is not handled there falls
/// through to a range-based evaluation at the end of the function.
///
/// # Errors
///
/// Returns an internal error if evaluating `op` on the ranges does not yield
/// one of the three boolean intervals, and propagates any error coming from
/// the interval or scalar arithmetic involved.
pub fn create_bernoulli_from_comparison(
    op: &Operator,
    left: &Distribution,
    right: &Distribution,
) -> Result<Distribution> {
    match (left, right) {
        (Uniform(left), Uniform(right)) => {
            match op {
                Operator::Eq | Operator::NotEq => {
                    let (li, ri) = (left.range(), right.range());
                    if let Some(intersection) = li.intersect(ri)? {
                        // If the ranges are not disjoint, calculate the probability
                        // of equality using cardinalities:
                        if let (Some(lc), Some(rc), Some(ic)) = (
                            li.cardinality(),
                            ri.cardinality(),
                            intersection.cardinality(),
                        ) {
                            // Probability of equality is the number of values in
                            // the intersection divided by the number of
                            // (left, right) value pairs.
                            // Avoid overflow by widening the type temporarily:
                            let pairs = ((lc as u128) * (rc as u128)) as f64;
                            let p = (ic as f64).div_checked(pairs)?;
                            // Alternative approach that may be more stable:
                            // let p = (ic as f64)
                            //     .div_checked(lc as f64)?
                            //     .div_checked(rc as f64)?;

                            let mut p_value = ScalarValue::from(p);
                            if op == &Operator::NotEq {
                                // Probability of inequality is `1 - p`.
                                // NOTE(review): the `false` parameter presumably
                                // selects the rounding direction of the
                                // subtraction -- confirm against
                                // `alter_fp_rounding_mode`.
                                let one = ScalarValue::from(1.0);
                                p_value = alter_fp_rounding_mode::<false, _>(
                                    &one,
                                    &p_value,
                                    |lhs, rhs| lhs.sub_checked(rhs),
                                )?;
                            };
                            return Distribution::new_bernoulli(p_value);
                        }
                        // If any cardinality is unavailable, fall through to
                        // the range-based evaluation below.
                    } else if op == &Operator::Eq {
                        // If the ranges are disjoint, probability of equality is 0.
                        return Distribution::new_bernoulli(ScalarValue::from(0.0));
                    } else {
                        // If the ranges are disjoint, probability of not-equality is 1.
                        return Distribution::new_bernoulli(ScalarValue::from(1.0));
                    }
                }
                Operator::Lt | Operator::LtEq | Operator::Gt | Operator::GtEq => {
                    // TODO: We can handle inequality operators and calculate a
                    // `p` value instead of falling back to an unknown Bernoulli
                    // distribution. Note that the strict and non-strict inequalities
                    // may require slightly different logic in case of real vs.
                    // integral data types.
                }
                _ => {}
            }
        }
        (Gaussian(_), Gaussian(_)) => {
            // TODO: We can handle Gaussian comparisons and calculate a `p` value
            //       instead of falling back to an unknown Bernoulli distribution.
        }
        _ => {}
    }
    // Range-based fallback: evaluate the operator on the ranges of both sides
    // and map the resulting boolean interval to a success probability of 0, 1,
    // or unknown (null).
    let (li, ri) = (left.range()?, right.range()?);
    let range_evaluation = apply_operator(op, &li, &ri)?;
    if range_evaluation.eq(&Interval::CERTAINLY_FALSE) {
        Distribution::new_bernoulli(ScalarValue::from(0.0))
    } else if range_evaluation.eq(&Interval::CERTAINLY_TRUE) {
        Distribution::new_bernoulli(ScalarValue::from(1.0))
    } else if range_evaluation.eq(&Interval::UNCERTAIN) {
        Distribution::new_bernoulli(ScalarValue::try_from(&DataType::Float64)?)
    } else {
        internal_err!("This function must be called with a comparison operator")
    }
}
731
732/// Creates a new [`Generic`] distribution that represents the result of the
733/// given binary operation on two unknown quantities represented by their
734/// [`Distribution`] objects. The function computes the mean, median and
735/// variance if possible.
736pub fn new_generic_from_binary_op(
737    op: &Operator,
738    left: &Distribution,
739    right: &Distribution,
740) -> Result<Distribution> {
741    Distribution::new_generic(
742        compute_mean(op, left, right)?,
743        compute_median(op, left, right)?,
744        compute_variance(op, left, right)?,
745        apply_operator(op, &left.range()?, &right.range()?)?,
746    )
747}
748
749/// Computes the mean value for the result of the given binary operation on
750/// two unknown quantities represented by their [`Distribution`] objects.
751pub fn compute_mean(
752    op: &Operator,
753    left: &Distribution,
754    right: &Distribution,
755) -> Result<ScalarValue> {
756    let (left_mean, right_mean) = (left.mean()?, right.mean()?);
757
758    match op {
759        Operator::Plus => return left_mean.add_checked(right_mean),
760        Operator::Minus => return left_mean.sub_checked(right_mean),
761        // Note the independence assumption below:
762        Operator::Multiply => return left_mean.mul_checked(right_mean),
763        // TODO: We can calculate the mean for division when we support reciprocals,
764        // or know the distributions of the operands. For details, see:
765        //
766        // <https://en.wikipedia.org/wiki/Algebra_of_random_variables>
767        // <https://stats.stackexchange.com/questions/185683/distribution-of-ratio-between-two-independent-uniform-random-variables>
768        //
769        // Fall back to an unknown mean value for division:
770        Operator::Divide => {}
771        // Fall back to an unknown mean value for other cases:
772        _ => {}
773    }
774    let target_type = Distribution::target_type(&[&left_mean, &right_mean])?;
775    ScalarValue::try_from(target_type)
776}
777
778/// Computes the median value for the result of the given binary operation on
779/// two unknown quantities represented by its [`Distribution`] objects. Currently,
780/// the median is calculable only for addition and subtraction operations on:
781/// - [`Uniform`] and [`Uniform`] distributions, and
782/// - [`Gaussian`] and [`Gaussian`] distributions.
783pub fn compute_median(
784    op: &Operator,
785    left: &Distribution,
786    right: &Distribution,
787) -> Result<ScalarValue> {
788    match (left, right) {
789        (Uniform(lu), Uniform(ru)) => {
790            let (left_median, right_median) = (lu.median()?, ru.median()?);
791            // Under the independence assumption, the result is a symmetric
792            // triangular distribution, so we can simply add/subtract the
793            // median values:
794            match op {
795                Operator::Plus => return left_median.add_checked(right_median),
796                Operator::Minus => return left_median.sub_checked(right_median),
797                // Fall back to an unknown median value for other cases:
798                _ => {}
799            }
800        }
801        // Under the independence assumption, the result is another Gaussian
802        // distribution, so we can simply add/subtract the median values:
803        (Gaussian(lg), Gaussian(rg)) => match op {
804            Operator::Plus => return lg.mean().add_checked(rg.mean()),
805            Operator::Minus => return lg.mean().sub_checked(rg.mean()),
806            // Fall back to an unknown median value for other cases:
807            _ => {}
808        },
809        // Fall back to an unknown median value for other cases:
810        _ => {}
811    }
812
813    let (left_median, right_median) = (left.median()?, right.median()?);
814    let target_type = Distribution::target_type(&[&left_median, &right_median])?;
815    ScalarValue::try_from(target_type)
816}
817
818/// Computes the variance value for the result of the given binary operation on
819/// two unknown quantities represented by their [`Distribution`] objects.
820pub fn compute_variance(
821    op: &Operator,
822    left: &Distribution,
823    right: &Distribution,
824) -> Result<ScalarValue> {
825    let (left_variance, right_variance) = (left.variance()?, right.variance()?);
826
827    match op {
828        // Note the independence assumption below:
829        Operator::Plus => return left_variance.add_checked(right_variance),
830        // Note the independence assumption below:
831        Operator::Minus => return left_variance.add_checked(right_variance),
832        // Note the independence assumption below:
833        Operator::Multiply => {
834            // For more details, along with an explanation of the formula below, see:
835            //
836            // <https://en.wikipedia.org/wiki/Distribution_of_the_product_of_two_random_variables>
837            let (left_mean, right_mean) = (left.mean()?, right.mean()?);
838            let left_mean_sq = left_mean.mul_checked(&left_mean)?;
839            let right_mean_sq = right_mean.mul_checked(&right_mean)?;
840            let left_sos = left_variance.add_checked(&left_mean_sq)?;
841            let right_sos = right_variance.add_checked(&right_mean_sq)?;
842            let pos = left_mean_sq.mul_checked(right_mean_sq)?;
843            return left_sos.mul_checked(right_sos)?.sub_checked(pos);
844        }
845        // TODO: We can calculate the variance for division when we support reciprocals,
846        // or know the distributions of the operands. For details, see:
847        //
848        // <https://en.wikipedia.org/wiki/Algebra_of_random_variables>
849        // <https://stats.stackexchange.com/questions/185683/distribution-of-ratio-between-two-independent-uniform-random-variables>
850        //
851        // Fall back to an unknown variance value for division:
852        Operator::Divide => {}
853        // Fall back to an unknown variance value for other cases:
854        _ => {}
855    }
856    let target_type = Distribution::target_type(&[&left_variance, &right_variance])?;
857    ScalarValue::try_from(target_type)
858}
859
860#[cfg(test)]
861mod tests {
862    use super::{
863        combine_bernoullis, combine_gaussians, compute_mean, compute_median,
864        compute_variance, create_bernoulli_from_comparison, new_generic_from_binary_op,
865        BernoulliDistribution, Distribution, GaussianDistribution, UniformDistribution,
866    };
867    use crate::interval_arithmetic::{apply_operator, Interval};
868    use crate::operator::Operator;
869
870    use arrow::datatypes::DataType;
871    use datafusion_common::{HashSet, Result, ScalarValue};
872
873    #[test]
874    fn uniform_dist_is_valid_test() -> Result<()> {
875        assert_eq!(
876            Distribution::new_uniform(Interval::make_zero(&DataType::Int8)?)?,
877            Distribution::Uniform(UniformDistribution {
878                interval: Interval::make_zero(&DataType::Int8)?,
879            })
880        );
881
882        assert!(Distribution::new_uniform(Interval::UNCERTAIN).is_err());
883        Ok(())
884    }
885
886    #[test]
887    fn exponential_dist_is_valid_test() {
888        // This array collects test cases of the form (distribution, validity).
889        let exponentials = vec![
890            (
891                Distribution::new_exponential(ScalarValue::Null, ScalarValue::Null, true),
892                false,
893            ),
894            (
895                Distribution::new_exponential(
896                    ScalarValue::from(0_f32),
897                    ScalarValue::from(1_f32),
898                    true,
899                ),
900                false,
901            ),
902            (
903                Distribution::new_exponential(
904                    ScalarValue::from(100_f32),
905                    ScalarValue::from(1_f32),
906                    true,
907                ),
908                true,
909            ),
910            (
911                Distribution::new_exponential(
912                    ScalarValue::from(-100_f32),
913                    ScalarValue::from(1_f32),
914                    true,
915                ),
916                false,
917            ),
918        ];
919        for case in exponentials {
920            assert_eq!(case.0.is_ok(), case.1);
921        }
922    }
923
924    #[test]
925    fn gaussian_dist_is_valid_test() {
926        // This array collects test cases of the form (distribution, validity).
927        let gaussians = vec![
928            (
929                Distribution::new_gaussian(ScalarValue::Null, ScalarValue::Null),
930                false,
931            ),
932            (
933                Distribution::new_gaussian(
934                    ScalarValue::from(0_f32),
935                    ScalarValue::from(0_f32),
936                ),
937                true,
938            ),
939            (
940                Distribution::new_gaussian(
941                    ScalarValue::from(0_f32),
942                    ScalarValue::from(0.5_f32),
943                ),
944                true,
945            ),
946            (
947                Distribution::new_gaussian(
948                    ScalarValue::from(0_f32),
949                    ScalarValue::from(-0.5_f32),
950                ),
951                false,
952            ),
953        ];
954        for case in gaussians {
955            assert_eq!(case.0.is_ok(), case.1);
956        }
957    }
958
959    #[test]
960    fn bernoulli_dist_is_valid_test() {
961        // This array collects test cases of the form (distribution, validity).
962        let bernoullis = vec![
963            (Distribution::new_bernoulli(ScalarValue::Null), true),
964            (Distribution::new_bernoulli(ScalarValue::from(0.)), true),
965            (Distribution::new_bernoulli(ScalarValue::from(0.25)), true),
966            (Distribution::new_bernoulli(ScalarValue::from(1.)), true),
967            (Distribution::new_bernoulli(ScalarValue::from(11.)), false),
968            (Distribution::new_bernoulli(ScalarValue::from(-11.)), false),
969            (Distribution::new_bernoulli(ScalarValue::from(0_i64)), true),
970            (Distribution::new_bernoulli(ScalarValue::from(1_i64)), true),
971            (
972                Distribution::new_bernoulli(ScalarValue::from(11_i64)),
973                false,
974            ),
975            (
976                Distribution::new_bernoulli(ScalarValue::from(-11_i64)),
977                false,
978            ),
979        ];
980        for case in bernoullis {
981            assert_eq!(case.0.is_ok(), case.1);
982        }
983    }
984
985    #[test]
986    fn generic_dist_is_valid_test() -> Result<()> {
987        // This array collects test cases of the form (distribution, validity).
988        let generic_dists = vec![
989            // Using a boolean range to construct a Generic distribution is prohibited.
990            (
991                Distribution::new_generic(
992                    ScalarValue::Null,
993                    ScalarValue::Null,
994                    ScalarValue::Null,
995                    Interval::UNCERTAIN,
996                ),
997                false,
998            ),
999            (
1000                Distribution::new_generic(
1001                    ScalarValue::Null,
1002                    ScalarValue::Null,
1003                    ScalarValue::Null,
1004                    Interval::make_zero(&DataType::Float32)?,
1005                ),
1006                true,
1007            ),
1008            (
1009                Distribution::new_generic(
1010                    ScalarValue::from(0_f32),
1011                    ScalarValue::Float32(None),
1012                    ScalarValue::Float32(None),
1013                    Interval::make_zero(&DataType::Float32)?,
1014                ),
1015                true,
1016            ),
1017            (
1018                Distribution::new_generic(
1019                    ScalarValue::Float64(None),
1020                    ScalarValue::from(0.),
1021                    ScalarValue::Float64(None),
1022                    Interval::make_zero(&DataType::Float32)?,
1023                ),
1024                true,
1025            ),
1026            (
1027                Distribution::new_generic(
1028                    ScalarValue::from(-10_f32),
1029                    ScalarValue::Float32(None),
1030                    ScalarValue::Float32(None),
1031                    Interval::make_zero(&DataType::Float32)?,
1032                ),
1033                false,
1034            ),
1035            (
1036                Distribution::new_generic(
1037                    ScalarValue::Float32(None),
1038                    ScalarValue::from(10_f32),
1039                    ScalarValue::Float32(None),
1040                    Interval::make_zero(&DataType::Float32)?,
1041                ),
1042                false,
1043            ),
1044            (
1045                Distribution::new_generic(
1046                    ScalarValue::Null,
1047                    ScalarValue::Null,
1048                    ScalarValue::Null,
1049                    Interval::make_zero(&DataType::Float32)?,
1050                ),
1051                true,
1052            ),
1053            (
1054                Distribution::new_generic(
1055                    ScalarValue::from(0),
1056                    ScalarValue::from(0),
1057                    ScalarValue::Int32(None),
1058                    Interval::make_zero(&DataType::Int32)?,
1059                ),
1060                true,
1061            ),
1062            (
1063                Distribution::new_generic(
1064                    ScalarValue::from(0_f32),
1065                    ScalarValue::from(0_f32),
1066                    ScalarValue::Float32(None),
1067                    Interval::make_zero(&DataType::Float32)?,
1068                ),
1069                true,
1070            ),
1071            (
1072                Distribution::new_generic(
1073                    ScalarValue::from(50.),
1074                    ScalarValue::from(50.),
1075                    ScalarValue::Float64(None),
1076                    Interval::make(Some(0.), Some(100.))?,
1077                ),
1078                true,
1079            ),
1080            (
1081                Distribution::new_generic(
1082                    ScalarValue::from(50.),
1083                    ScalarValue::from(50.),
1084                    ScalarValue::Float64(None),
1085                    Interval::make(Some(-100.), Some(0.))?,
1086                ),
1087                false,
1088            ),
1089            (
1090                Distribution::new_generic(
1091                    ScalarValue::Float64(None),
1092                    ScalarValue::Float64(None),
1093                    ScalarValue::from(1.),
1094                    Interval::make_zero(&DataType::Float64)?,
1095                ),
1096                true,
1097            ),
1098            (
1099                Distribution::new_generic(
1100                    ScalarValue::Float64(None),
1101                    ScalarValue::Float64(None),
1102                    ScalarValue::from(-1.),
1103                    Interval::make_zero(&DataType::Float64)?,
1104                ),
1105                false,
1106            ),
1107        ];
1108        for case in generic_dists {
1109            assert_eq!(case.0.is_ok(), case.1, "{:?}", case.0);
1110        }
1111
1112        Ok(())
1113    }
1114
1115    #[test]
1116    fn mean_extraction_test() -> Result<()> {
1117        // This array collects test cases of the form (distribution, mean value).
1118        let dists = vec![
1119            (
1120                Distribution::new_uniform(Interval::make_zero(&DataType::Int64)?),
1121                ScalarValue::from(0_i64),
1122            ),
1123            (
1124                Distribution::new_uniform(Interval::make_zero(&DataType::Float64)?),
1125                ScalarValue::from(0.),
1126            ),
1127            (
1128                Distribution::new_uniform(Interval::make(Some(1), Some(100))?),
1129                ScalarValue::from(50),
1130            ),
1131            (
1132                Distribution::new_uniform(Interval::make(Some(-100), Some(-1))?),
1133                ScalarValue::from(-50),
1134            ),
1135            (
1136                Distribution::new_uniform(Interval::make(Some(-100), Some(100))?),
1137                ScalarValue::from(0),
1138            ),
1139            (
1140                Distribution::new_exponential(
1141                    ScalarValue::from(2.),
1142                    ScalarValue::from(0.),
1143                    true,
1144                ),
1145                ScalarValue::from(0.5),
1146            ),
1147            (
1148                Distribution::new_exponential(
1149                    ScalarValue::from(2.),
1150                    ScalarValue::from(1.),
1151                    true,
1152                ),
1153                ScalarValue::from(1.5),
1154            ),
1155            (
1156                Distribution::new_gaussian(ScalarValue::from(0.), ScalarValue::from(1.)),
1157                ScalarValue::from(0.),
1158            ),
1159            (
1160                Distribution::new_gaussian(
1161                    ScalarValue::from(-2.),
1162                    ScalarValue::from(0.5),
1163                ),
1164                ScalarValue::from(-2.),
1165            ),
1166            (
1167                Distribution::new_bernoulli(ScalarValue::from(0.5)),
1168                ScalarValue::from(0.5),
1169            ),
1170            (
1171                Distribution::new_generic(
1172                    ScalarValue::from(42.),
1173                    ScalarValue::from(42.),
1174                    ScalarValue::Float64(None),
1175                    Interval::make(Some(25.), Some(50.))?,
1176                ),
1177                ScalarValue::from(42.),
1178            ),
1179        ];
1180
1181        for case in dists {
1182            assert_eq!(case.0?.mean()?, case.1);
1183        }
1184
1185        Ok(())
1186    }
1187
1188    #[test]
1189    fn median_extraction_test() -> Result<()> {
1190        // This array collects test cases of the form (distribution, median value).
1191        let dists = vec![
1192            (
1193                Distribution::new_uniform(Interval::make_zero(&DataType::Int64)?),
1194                ScalarValue::from(0_i64),
1195            ),
1196            (
1197                Distribution::new_uniform(Interval::make(Some(25.), Some(75.))?),
1198                ScalarValue::from(50.),
1199            ),
1200            (
1201                Distribution::new_exponential(
1202                    ScalarValue::from(2_f64.ln()),
1203                    ScalarValue::from(0.),
1204                    true,
1205                ),
1206                ScalarValue::from(1.),
1207            ),
1208            (
1209                Distribution::new_gaussian(ScalarValue::from(2.), ScalarValue::from(1.)),
1210                ScalarValue::from(2.),
1211            ),
1212            (
1213                Distribution::new_bernoulli(ScalarValue::from(0.25)),
1214                ScalarValue::from(0.),
1215            ),
1216            (
1217                Distribution::new_bernoulli(ScalarValue::from(0.75)),
1218                ScalarValue::from(1.),
1219            ),
1220            (
1221                Distribution::new_gaussian(ScalarValue::from(2.), ScalarValue::from(1.)),
1222                ScalarValue::from(2.),
1223            ),
1224            (
1225                Distribution::new_generic(
1226                    ScalarValue::from(12.),
1227                    ScalarValue::from(12.),
1228                    ScalarValue::Float64(None),
1229                    Interval::make(Some(0.), Some(25.))?,
1230                ),
1231                ScalarValue::from(12.),
1232            ),
1233        ];
1234
1235        for case in dists {
1236            assert_eq!(case.0?.median()?, case.1);
1237        }
1238
1239        Ok(())
1240    }
1241
1242    #[test]
1243    fn variance_extraction_test() -> Result<()> {
1244        // This array collects test cases of the form (distribution, variance value).
1245        let dists = vec![
1246            (
1247                Distribution::new_uniform(Interval::make(Some(0.), Some(12.))?),
1248                ScalarValue::from(12.),
1249            ),
1250            (
1251                Distribution::new_exponential(
1252                    ScalarValue::from(10.),
1253                    ScalarValue::from(0.),
1254                    true,
1255                ),
1256                ScalarValue::from(0.01),
1257            ),
1258            (
1259                Distribution::new_gaussian(ScalarValue::from(0.), ScalarValue::from(1.)),
1260                ScalarValue::from(1.),
1261            ),
1262            (
1263                Distribution::new_bernoulli(ScalarValue::from(0.5)),
1264                ScalarValue::from(0.25),
1265            ),
1266            (
1267                Distribution::new_generic(
1268                    ScalarValue::Float64(None),
1269                    ScalarValue::Float64(None),
1270                    ScalarValue::from(0.02),
1271                    Interval::make_zero(&DataType::Float64)?,
1272                ),
1273                ScalarValue::from(0.02),
1274            ),
1275        ];
1276
1277        for case in dists {
1278            assert_eq!(case.0?.variance()?, case.1);
1279        }
1280
1281        Ok(())
1282    }
1283
1284    #[test]
1285    fn test_calculate_generic_properties_gauss_gauss() -> Result<()> {
1286        let dist_a =
1287            Distribution::new_gaussian(ScalarValue::from(10.), ScalarValue::from(0.0))?;
1288        let dist_b =
1289            Distribution::new_gaussian(ScalarValue::from(20.), ScalarValue::from(0.0))?;
1290
1291        let test_data = vec![
1292            // Mean:
1293            (
1294                compute_mean(&Operator::Plus, &dist_a, &dist_b)?,
1295                ScalarValue::from(30.),
1296            ),
1297            (
1298                compute_mean(&Operator::Minus, &dist_a, &dist_b)?,
1299                ScalarValue::from(-10.),
1300            ),
1301            // Median:
1302            (
1303                compute_median(&Operator::Plus, &dist_a, &dist_b)?,
1304                ScalarValue::from(30.),
1305            ),
1306            (
1307                compute_median(&Operator::Minus, &dist_a, &dist_b)?,
1308                ScalarValue::from(-10.),
1309            ),
1310        ];
1311        for (actual, expected) in test_data {
1312            assert_eq!(actual, expected);
1313        }
1314
1315        Ok(())
1316    }
1317
1318    #[test]
1319    fn test_combine_bernoullis_and_op() -> Result<()> {
1320        let op = Operator::And;
1321        let left = BernoulliDistribution::try_new(ScalarValue::from(0.5))?;
1322        let right = BernoulliDistribution::try_new(ScalarValue::from(0.4))?;
1323        let left_null = BernoulliDistribution::try_new(ScalarValue::Null)?;
1324        let right_null = BernoulliDistribution::try_new(ScalarValue::Null)?;
1325
1326        assert_eq!(
1327            combine_bernoullis(&op, &left, &right)?.p_value(),
1328            &ScalarValue::from(0.5 * 0.4)
1329        );
1330        assert_eq!(
1331            combine_bernoullis(&op, &left_null, &right)?.p_value(),
1332            &ScalarValue::Float64(None)
1333        );
1334        assert_eq!(
1335            combine_bernoullis(&op, &left, &right_null)?.p_value(),
1336            &ScalarValue::Float64(None)
1337        );
1338        assert_eq!(
1339            combine_bernoullis(&op, &left_null, &left_null)?.p_value(),
1340            &ScalarValue::Null
1341        );
1342
1343        Ok(())
1344    }
1345
1346    #[test]
1347    fn test_combine_bernoullis_or_op() -> Result<()> {
1348        let op = Operator::Or;
1349        let left = BernoulliDistribution::try_new(ScalarValue::from(0.6))?;
1350        let right = BernoulliDistribution::try_new(ScalarValue::from(0.4))?;
1351        let left_null = BernoulliDistribution::try_new(ScalarValue::Null)?;
1352        let right_null = BernoulliDistribution::try_new(ScalarValue::Null)?;
1353
1354        assert_eq!(
1355            combine_bernoullis(&op, &left, &right)?.p_value(),
1356            &ScalarValue::from(0.6 + 0.4 - (0.6 * 0.4))
1357        );
1358        assert_eq!(
1359            combine_bernoullis(&op, &left_null, &right)?.p_value(),
1360            &ScalarValue::Float64(None)
1361        );
1362        assert_eq!(
1363            combine_bernoullis(&op, &left, &right_null)?.p_value(),
1364            &ScalarValue::Float64(None)
1365        );
1366        assert_eq!(
1367            combine_bernoullis(&op, &left_null, &left_null)?.p_value(),
1368            &ScalarValue::Null
1369        );
1370
1371        Ok(())
1372    }
1373
1374    #[test]
1375    fn test_combine_bernoullis_unsupported_ops() -> Result<()> {
1376        let mut operator_set = operator_set();
1377        operator_set.remove(&Operator::And);
1378        operator_set.remove(&Operator::Or);
1379
1380        let left = BernoulliDistribution::try_new(ScalarValue::from(0.6))?;
1381        let right = BernoulliDistribution::try_new(ScalarValue::from(0.4))?;
1382        for op in operator_set {
1383            assert!(
1384                combine_bernoullis(&op, &left, &right).is_err(),
1385                "Operator {op} should not be supported for Bernoulli distributions"
1386            );
1387        }
1388
1389        Ok(())
1390    }
1391
1392    #[test]
1393    fn test_combine_gaussians_addition() -> Result<()> {
1394        let op = Operator::Plus;
1395        let left = GaussianDistribution::try_new(
1396            ScalarValue::from(3.0),
1397            ScalarValue::from(2.0),
1398        )?;
1399        let right = GaussianDistribution::try_new(
1400            ScalarValue::from(4.0),
1401            ScalarValue::from(1.0),
1402        )?;
1403
1404        let result = combine_gaussians(&op, &left, &right)?.unwrap();
1405
1406        assert_eq!(result.mean(), &ScalarValue::from(7.0)); // 3.0 + 4.0
1407        assert_eq!(result.variance(), &ScalarValue::from(3.0)); // 2.0 + 1.0
1408        Ok(())
1409    }
1410
1411    #[test]
1412    fn test_combine_gaussians_subtraction() -> Result<()> {
1413        let op = Operator::Minus;
1414        let left = GaussianDistribution::try_new(
1415            ScalarValue::from(7.0),
1416            ScalarValue::from(2.0),
1417        )?;
1418        let right = GaussianDistribution::try_new(
1419            ScalarValue::from(4.0),
1420            ScalarValue::from(1.0),
1421        )?;
1422
1423        let result = combine_gaussians(&op, &left, &right)?.unwrap();
1424
1425        assert_eq!(result.mean(), &ScalarValue::from(3.0)); // 7.0 - 4.0
1426        assert_eq!(result.variance(), &ScalarValue::from(3.0)); // 2.0 + 1.0
1427
1428        Ok(())
1429    }
1430
1431    #[test]
1432    fn test_combine_gaussians_unsupported_ops() -> Result<()> {
1433        let mut operator_set = operator_set();
1434        operator_set.remove(&Operator::Plus);
1435        operator_set.remove(&Operator::Minus);
1436
1437        let left = GaussianDistribution::try_new(
1438            ScalarValue::from(7.0),
1439            ScalarValue::from(2.0),
1440        )?;
1441        let right = GaussianDistribution::try_new(
1442            ScalarValue::from(4.0),
1443            ScalarValue::from(1.0),
1444        )?;
1445        for op in operator_set {
1446            assert!(
1447                combine_gaussians(&op, &left, &right)?.is_none(),
1448                "Operator {op} should not be supported for Gaussian distributions"
1449            );
1450        }
1451
1452        Ok(())
1453    }
1454
1455    // Expected test results were calculated in Wolfram Mathematica, by using:
1456    //
1457    // *METHOD_NAME*[TransformedDistribution[
1458    //  x *op* y,
1459    //  {x ~ *DISTRIBUTION_X*[..], y ~ *DISTRIBUTION_Y*[..]}
1460    // ]]
1461    #[test]
1462    fn test_calculate_generic_properties_uniform_uniform() -> Result<()> {
1463        let dist_a = Distribution::new_uniform(Interval::make(Some(0.), Some(12.))?)?;
1464        let dist_b = Distribution::new_uniform(Interval::make(Some(12.), Some(36.))?)?;
1465
1466        let test_data = vec![
1467            // Mean:
1468            (
1469                compute_mean(&Operator::Plus, &dist_a, &dist_b)?,
1470                ScalarValue::from(30.),
1471            ),
1472            (
1473                compute_mean(&Operator::Minus, &dist_a, &dist_b)?,
1474                ScalarValue::from(-18.),
1475            ),
1476            (
1477                compute_mean(&Operator::Multiply, &dist_a, &dist_b)?,
1478                ScalarValue::from(144.),
1479            ),
1480            // Median:
1481            (
1482                compute_median(&Operator::Plus, &dist_a, &dist_b)?,
1483                ScalarValue::from(30.),
1484            ),
1485            (
1486                compute_median(&Operator::Minus, &dist_a, &dist_b)?,
1487                ScalarValue::from(-18.),
1488            ),
1489            // Variance:
1490            (
1491                compute_variance(&Operator::Plus, &dist_a, &dist_b)?,
1492                ScalarValue::from(60.),
1493            ),
1494            (
1495                compute_variance(&Operator::Minus, &dist_a, &dist_b)?,
1496                ScalarValue::from(60.),
1497            ),
1498            (
1499                compute_variance(&Operator::Multiply, &dist_a, &dist_b)?,
1500                ScalarValue::from(9216.),
1501            ),
1502        ];
1503        for (actual, expected) in test_data {
1504            assert_eq!(actual, expected);
1505        }
1506
1507        Ok(())
1508    }
1509
1510    /// Test for `Uniform`-`Uniform`, `Uniform`-`Generic`, `Generic`-`Uniform`,
1511    /// `Generic`-`Generic` pairs, where range is always present.
1512    #[test]
1513    fn test_compute_range_where_present() -> Result<()> {
1514        let a = &Interval::make(Some(0.), Some(12.0))?;
1515        let b = &Interval::make(Some(0.), Some(12.0))?;
1516        let mean = ScalarValue::from(6.0);
1517        for (dist_a, dist_b) in [
1518            (
1519                Distribution::new_uniform(a.clone())?,
1520                Distribution::new_uniform(b.clone())?,
1521            ),
1522            (
1523                Distribution::new_generic(
1524                    mean.clone(),
1525                    mean.clone(),
1526                    ScalarValue::Float64(None),
1527                    a.clone(),
1528                )?,
1529                Distribution::new_uniform(b.clone())?,
1530            ),
1531            (
1532                Distribution::new_uniform(a.clone())?,
1533                Distribution::new_generic(
1534                    mean.clone(),
1535                    mean.clone(),
1536                    ScalarValue::Float64(None),
1537                    b.clone(),
1538                )?,
1539            ),
1540            (
1541                Distribution::new_generic(
1542                    mean.clone(),
1543                    mean.clone(),
1544                    ScalarValue::Float64(None),
1545                    a.clone(),
1546                )?,
1547                Distribution::new_generic(
1548                    mean.clone(),
1549                    mean.clone(),
1550                    ScalarValue::Float64(None),
1551                    b.clone(),
1552                )?,
1553            ),
1554        ] {
1555            use super::Operator::{
1556                Divide, Eq, Gt, GtEq, Lt, LtEq, Minus, Multiply, NotEq, Plus,
1557            };
1558            for op in [Plus, Minus, Multiply, Divide] {
1559                assert_eq!(
1560                    new_generic_from_binary_op(&op, &dist_a, &dist_b)?.range()?,
1561                    apply_operator(&op, a, b)?,
1562                    "Failed for {:?} {op} {:?}",
1563                    dist_a,
1564                    dist_b
1565                );
1566            }
1567            for op in [Gt, GtEq, Lt, LtEq, Eq, NotEq] {
1568                assert_eq!(
1569                    create_bernoulli_from_comparison(&op, &dist_a, &dist_b)?.range()?,
1570                    apply_operator(&op, a, b)?,
1571                    "Failed for {:?} {op} {:?}",
1572                    dist_a,
1573                    dist_b
1574                );
1575            }
1576        }
1577
1578        Ok(())
1579    }
1580
1581    fn operator_set() -> HashSet<Operator> {
1582        use super::Operator::*;
1583
1584        let all_ops = vec![
1585            And,
1586            Or,
1587            Eq,
1588            NotEq,
1589            Gt,
1590            GtEq,
1591            Lt,
1592            LtEq,
1593            Plus,
1594            Minus,
1595            Multiply,
1596            Divide,
1597            Modulo,
1598            IsDistinctFrom,
1599            IsNotDistinctFrom,
1600            RegexMatch,
1601            RegexIMatch,
1602            RegexNotMatch,
1603            RegexNotIMatch,
1604            LikeMatch,
1605            ILikeMatch,
1606            NotLikeMatch,
1607            NotILikeMatch,
1608            BitwiseAnd,
1609            BitwiseOr,
1610            BitwiseXor,
1611            BitwiseShiftRight,
1612            BitwiseShiftLeft,
1613            StringConcat,
1614            AtArrow,
1615            ArrowAt,
1616        ];
1617
1618        all_ops.into_iter().collect()
1619    }
1620}