datafusion_expr_common/
statistics.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::f64::consts::LN_2;
19
20use crate::interval_arithmetic::{Interval, apply_operator};
21use crate::operator::Operator;
22use crate::type_coercion::binary::binary_numeric_coercion;
23
24use arrow::array::ArrowNativeTypeOp;
25use arrow::datatypes::DataType;
26use datafusion_common::rounding::alter_fp_rounding_mode;
27use datafusion_common::{
28    Result, ScalarValue, assert_eq_or_internal_err, assert_ne_or_internal_err,
29    assert_or_internal_err, internal_err, not_impl_err,
30};
31
32/// This object defines probabilistic distributions that encode uncertain
33/// information about a single, scalar value. Currently, we support five core
34/// statistical distributions. New variants will be added over time.
35///
36/// This object is the lowest-level object in the statistics hierarchy, and it
37/// is the main unit of calculus when evaluating expressions in a statistical
38/// context. Notions like column and table statistics are built on top of this
39/// object and the operations it supports.
40#[derive(Clone, Debug, PartialEq)]
41pub enum Distribution {
42    Uniform(UniformDistribution),
43    Exponential(ExponentialDistribution),
44    Gaussian(GaussianDistribution),
45    Bernoulli(BernoulliDistribution),
46    Generic(GenericDistribution),
47}
48
49use Distribution::{Bernoulli, Exponential, Gaussian, Generic, Uniform};
50
51impl Distribution {
52    /// Constructs a new [`Uniform`] distribution from the given [`Interval`].
53    pub fn new_uniform(interval: Interval) -> Result<Self> {
54        UniformDistribution::try_new(interval).map(Uniform)
55    }
56
57    /// Constructs a new [`Exponential`] distribution from the given rate/offset
58    /// pair, and validates the given parameters.
59    pub fn new_exponential(
60        rate: ScalarValue,
61        offset: ScalarValue,
62        positive_tail: bool,
63    ) -> Result<Self> {
64        ExponentialDistribution::try_new(rate, offset, positive_tail).map(Exponential)
65    }
66
67    /// Constructs a new [`Gaussian`] distribution from the given mean/variance
68    /// pair, and validates the given parameters.
69    pub fn new_gaussian(mean: ScalarValue, variance: ScalarValue) -> Result<Self> {
70        GaussianDistribution::try_new(mean, variance).map(Gaussian)
71    }
72
73    /// Constructs a new [`Bernoulli`] distribution from the given success
74    /// probability, and validates the given parameters.
75    pub fn new_bernoulli(p: ScalarValue) -> Result<Self> {
76        BernoulliDistribution::try_new(p).map(Bernoulli)
77    }
78
79    /// Constructs a new [`Generic`] distribution from the given mean, median,
80    /// variance, and range values after validating the given parameters.
81    pub fn new_generic(
82        mean: ScalarValue,
83        median: ScalarValue,
84        variance: ScalarValue,
85        range: Interval,
86    ) -> Result<Self> {
87        GenericDistribution::try_new(mean, median, variance, range).map(Generic)
88    }
89
90    /// Constructs a new [`Generic`] distribution from the given range. Other
91    /// parameters (mean, median and variance) are initialized with null values.
92    pub fn new_from_interval(range: Interval) -> Result<Self> {
93        let null = ScalarValue::try_from(range.data_type())?;
94        Distribution::new_generic(null.clone(), null.clone(), null, range)
95    }
96
97    /// Extracts the mean value of this uncertain quantity, depending on its
98    /// distribution:
99    /// - A [`Uniform`] distribution's interval determines its mean value, which
100    ///   is the arithmetic average of the interval endpoints.
101    /// - An [`Exponential`] distribution's mean is calculable by the formula
102    ///   `offset + 1 / λ`, where `λ` is the (non-negative) rate.
103    /// - A [`Gaussian`] distribution contains the mean explicitly.
104    /// - A [`Bernoulli`] distribution's mean is equal to its success probability `p`.
105    /// - A [`Generic`] distribution _may_ have it explicitly, or this information
106    ///   may be absent.
107    pub fn mean(&self) -> Result<ScalarValue> {
108        match &self {
109            Uniform(u) => u.mean(),
110            Exponential(e) => e.mean(),
111            Gaussian(g) => Ok(g.mean().clone()),
112            Bernoulli(b) => Ok(b.mean().clone()),
113            Generic(u) => Ok(u.mean().clone()),
114        }
115    }
116
117    /// Extracts the median value of this uncertain quantity, depending on its
118    /// distribution:
119    /// - A [`Uniform`] distribution's interval determines its median value, which
120    ///   is the arithmetic average of the interval endpoints.
121    /// - An [`Exponential`] distribution's median is calculable by the formula
122    ///   `offset + ln(2) / λ`, where `λ` is the (non-negative) rate.
123    /// - A [`Gaussian`] distribution's median is equal to its mean, which is
124    ///   specified explicitly.
125    /// - A [`Bernoulli`] distribution's median is `1` if `p > 0.5` and `0`
126    ///   otherwise, where `p` is the success probability.
127    /// - A [`Generic`] distribution _may_ have it explicitly, or this information
128    ///   may be absent.
129    pub fn median(&self) -> Result<ScalarValue> {
130        match &self {
131            Uniform(u) => u.median(),
132            Exponential(e) => e.median(),
133            Gaussian(g) => Ok(g.median().clone()),
134            Bernoulli(b) => b.median(),
135            Generic(u) => Ok(u.median().clone()),
136        }
137    }
138
139    /// Extracts the variance value of this uncertain quantity, depending on
140    /// its distribution:
141    /// - A [`Uniform`] distribution's interval determines its variance value, which
142    ///   is calculable by the formula `(upper - lower) ^ 2 / 12`.
143    /// - An [`Exponential`] distribution's variance is calculable by the formula
144    ///   `1 / (λ ^ 2)`, where `λ` is the (non-negative) rate.
145    /// - A [`Gaussian`] distribution's variance is specified explicitly.
146    /// - A [`Bernoulli`] distribution's median is given by the formula `p * (1 - p)`
147    ///   where `p` is the success probability.
148    /// - A [`Generic`] distribution _may_ have it explicitly, or this information
149    ///   may be absent.
150    pub fn variance(&self) -> Result<ScalarValue> {
151        match &self {
152            Uniform(u) => u.variance(),
153            Exponential(e) => e.variance(),
154            Gaussian(g) => Ok(g.variance.clone()),
155            Bernoulli(b) => b.variance(),
156            Generic(u) => Ok(u.variance.clone()),
157        }
158    }
159
160    /// Extracts the range of this uncertain quantity, depending on its
161    /// distribution:
162    /// - A [`Uniform`] distribution's range is simply its interval.
163    /// - An [`Exponential`] distribution's range is `[offset, +∞)`.
164    /// - A [`Gaussian`] distribution's range is unbounded.
165    /// - A [`Bernoulli`] distribution's range is [`Interval::TRUE_OR_FALSE`], if
166    ///   `p` is neither `0` nor `1`. Otherwise, it is [`Interval::FALSE`]
167    ///   and [`Interval::TRUE`], respectively.
168    /// - A [`Generic`] distribution is unbounded by default, but more information
169    ///   may be present.
170    pub fn range(&self) -> Result<Interval> {
171        match &self {
172            Uniform(u) => Ok(u.range().clone()),
173            Exponential(e) => e.range(),
174            Gaussian(g) => g.range(),
175            Bernoulli(b) => Ok(b.range()),
176            Generic(u) => Ok(u.range().clone()),
177        }
178    }
179
180    /// Returns the data type of the statistical parameters comprising this
181    /// distribution.
182    pub fn data_type(&self) -> DataType {
183        match &self {
184            Uniform(u) => u.data_type(),
185            Exponential(e) => e.data_type(),
186            Gaussian(g) => g.data_type(),
187            Bernoulli(b) => b.data_type(),
188            Generic(u) => u.data_type(),
189        }
190    }
191
192    pub fn target_type(args: &[&ScalarValue]) -> Result<DataType> {
193        let mut arg_types = args
194            .iter()
195            .filter(|&&arg| arg != &ScalarValue::Null)
196            .map(|&arg| arg.data_type());
197
198        let Some(dt) = arg_types.next().map_or_else(
199            || Some(DataType::Null),
200            |first| {
201                arg_types
202                    .try_fold(first, |target, arg| binary_numeric_coercion(&target, &arg))
203            },
204        ) else {
205            return internal_err!("Can only evaluate statistics for numeric types");
206        };
207        Ok(dt)
208    }
209}
210
211/// Uniform distribution, represented by its range. If the given range extends
212/// towards infinity, the distribution will be improper -- which is OK. For a
213/// more in-depth discussion, see:
214///
215/// <https://en.wikipedia.org/wiki/Continuous_uniform_distribution>
216/// <https://en.wikipedia.org/wiki/Prior_probability#Improper_priors>
217#[derive(Clone, Debug, PartialEq)]
218pub struct UniformDistribution {
219    interval: Interval,
220}
221
222/// Exponential distribution with an optional shift. The probability density
223/// function (PDF) is defined as follows:
224///
225/// For a positive tail (when `positive_tail` is `true`):
226///
227/// `f(x; λ, offset) = λ exp(-λ (x - offset))    for x ≥ offset`
228///
229/// For a negative tail (when `positive_tail` is `false`):
230///
231/// `f(x; λ, offset) = λ exp(-λ (offset - x))    for x ≤ offset`
232///
233///
234/// In both cases, the PDF is `0` outside the specified domain.
235///
236/// For more information, see:
237///
238/// <https://en.wikipedia.org/wiki/Exponential_distribution>
239#[derive(Clone, Debug, PartialEq)]
240pub struct ExponentialDistribution {
241    rate: ScalarValue,
242    offset: ScalarValue,
243    /// Indicates whether the exponential distribution has a positive tail;
244    /// i.e. it extends towards positive infinity.
245    positive_tail: bool,
246}
247
248/// Gaussian (normal) distribution, represented by its mean and variance.
249/// For a more in-depth discussion, see:
250///
251/// <https://en.wikipedia.org/wiki/Normal_distribution>
252#[derive(Clone, Debug, PartialEq)]
253pub struct GaussianDistribution {
254    mean: ScalarValue,
255    variance: ScalarValue,
256}
257
258/// Bernoulli distribution with success probability `p`. If `p` has a null value,
259/// the success probability is unknown. For a more in-depth discussion, see:
260///
261/// <https://en.wikipedia.org/wiki/Bernoulli_distribution>
262#[derive(Clone, Debug, PartialEq)]
263pub struct BernoulliDistribution {
264    p: ScalarValue,
265}
266
267/// A generic distribution whose functional form is not available, which is
268/// approximated via some summary statistics. For a more in-depth discussion, see:
269///
270/// <https://en.wikipedia.org/wiki/Summary_statistics>
271#[derive(Clone, Debug, PartialEq)]
272pub struct GenericDistribution {
273    mean: ScalarValue,
274    median: ScalarValue,
275    variance: ScalarValue,
276    range: Interval,
277}
278
279impl UniformDistribution {
280    fn try_new(interval: Interval) -> Result<Self> {
281        assert_ne_or_internal_err!(
282            interval.data_type(),
283            DataType::Boolean,
284            "Construction of a boolean `Uniform` distribution is prohibited, create a `Bernoulli` distribution instead."
285        );
286
287        Ok(Self { interval })
288    }
289
290    pub fn data_type(&self) -> DataType {
291        self.interval.data_type()
292    }
293
294    /// Computes the mean value of this distribution. In case of improper
295    /// distributions (i.e. when the range is unbounded), the function returns
296    /// a `NULL` `ScalarValue`.
297    pub fn mean(&self) -> Result<ScalarValue> {
298        // TODO: Should we ensure that this always returns a real number data type?
299        let dt = self.data_type();
300        let two = ScalarValue::from(2).cast_to(&dt)?;
301        let result = self
302            .interval
303            .lower()
304            .add_checked(self.interval.upper())?
305            .div(two);
306        debug_assert!(
307            !self.interval.is_unbounded() || result.as_ref().is_ok_and(|r| r.is_null())
308        );
309        result
310    }
311
312    pub fn median(&self) -> Result<ScalarValue> {
313        self.mean()
314    }
315
316    /// Computes the variance value of this distribution. In case of improper
317    /// distributions (i.e. when the range is unbounded), the function returns
318    /// a `NULL` `ScalarValue`.
319    pub fn variance(&self) -> Result<ScalarValue> {
320        // TODO: Should we ensure that this always returns a real number data type?
321        let width = self.interval.width()?;
322        let dt = width.data_type();
323        let twelve = ScalarValue::from(12).cast_to(&dt)?;
324        let result = width.mul_checked(&width)?.div(twelve);
325        debug_assert!(
326            !self.interval.is_unbounded() || result.as_ref().is_ok_and(|r| r.is_null())
327        );
328        result
329    }
330
331    pub fn range(&self) -> &Interval {
332        &self.interval
333    }
334}
335
336impl ExponentialDistribution {
337    fn try_new(
338        rate: ScalarValue,
339        offset: ScalarValue,
340        positive_tail: bool,
341    ) -> Result<Self> {
342        let dt = rate.data_type();
343        assert_eq_or_internal_err!(
344            offset.data_type(),
345            dt,
346            "Rate and offset must have the same data type"
347        );
348        assert_or_internal_err!(
349            !offset.is_null(),
350            "Offset of an `ExponentialDistribution` cannot be null"
351        );
352        assert_or_internal_err!(
353            !rate.is_null(),
354            "Rate of an `ExponentialDistribution` cannot be null"
355        );
356        let zero = ScalarValue::new_zero(&dt)?;
357        assert_or_internal_err!(
358            !rate.le(&zero),
359            "Rate of an `ExponentialDistribution` must be positive"
360        );
361        Ok(Self {
362            rate,
363            offset,
364            positive_tail,
365        })
366    }
367
368    pub fn data_type(&self) -> DataType {
369        self.rate.data_type()
370    }
371
372    pub fn rate(&self) -> &ScalarValue {
373        &self.rate
374    }
375
376    pub fn offset(&self) -> &ScalarValue {
377        &self.offset
378    }
379
380    pub fn positive_tail(&self) -> bool {
381        self.positive_tail
382    }
383
384    pub fn mean(&self) -> Result<ScalarValue> {
385        // TODO: Should we ensure that this always returns a real number data type?
386        let one = ScalarValue::new_one(&self.data_type())?;
387        let tail_mean = one.div(&self.rate)?;
388        if self.positive_tail {
389            self.offset.add_checked(tail_mean)
390        } else {
391            self.offset.sub_checked(tail_mean)
392        }
393    }
394
395    pub fn median(&self) -> Result<ScalarValue> {
396        // TODO: Should we ensure that this always returns a real number data type?
397        let ln_two = ScalarValue::from(LN_2).cast_to(&self.data_type())?;
398        let tail_median = ln_two.div(&self.rate)?;
399        if self.positive_tail {
400            self.offset.add_checked(tail_median)
401        } else {
402            self.offset.sub_checked(tail_median)
403        }
404    }
405
406    pub fn variance(&self) -> Result<ScalarValue> {
407        // TODO: Should we ensure that this always returns a real number data type?
408        let one = ScalarValue::new_one(&self.data_type())?;
409        let rate_squared = self.rate.mul_checked(&self.rate)?;
410        one.div(rate_squared)
411    }
412
413    pub fn range(&self) -> Result<Interval> {
414        let end = ScalarValue::try_from(&self.data_type())?;
415        if self.positive_tail {
416            Interval::try_new(self.offset.clone(), end)
417        } else {
418            Interval::try_new(end, self.offset.clone())
419        }
420    }
421}
422
423impl GaussianDistribution {
424    fn try_new(mean: ScalarValue, variance: ScalarValue) -> Result<Self> {
425        let dt = mean.data_type();
426        assert_eq_or_internal_err!(
427            variance.data_type(),
428            dt,
429            "Mean and variance must have the same data type"
430        );
431        assert_or_internal_err!(
432            !variance.is_null(),
433            "Variance of a `GaussianDistribution` cannot be null"
434        );
435        let zero = ScalarValue::new_zero(&dt)?;
436        assert_or_internal_err!(
437            !variance.lt(&zero),
438            "Variance of a `GaussianDistribution` must be positive"
439        );
440        Ok(Self { mean, variance })
441    }
442
443    pub fn data_type(&self) -> DataType {
444        self.mean.data_type()
445    }
446
447    pub fn mean(&self) -> &ScalarValue {
448        &self.mean
449    }
450
451    pub fn variance(&self) -> &ScalarValue {
452        &self.variance
453    }
454
455    pub fn median(&self) -> &ScalarValue {
456        self.mean()
457    }
458
459    pub fn range(&self) -> Result<Interval> {
460        Interval::make_unbounded(&self.data_type())
461    }
462}
463
464impl BernoulliDistribution {
465    fn try_new(p: ScalarValue) -> Result<Self> {
466        if p.is_null() {
467            return Ok(Self { p });
468        }
469        let dt = p.data_type();
470        let zero = ScalarValue::new_zero(&dt)?;
471        let one = ScalarValue::new_one(&dt)?;
472        assert_or_internal_err!(
473            p.ge(&zero) && p.le(&one),
474            "Success probability of a `BernoulliDistribution` must be in [0, 1]"
475        );
476        Ok(Self { p })
477    }
478
479    pub fn data_type(&self) -> DataType {
480        self.p.data_type()
481    }
482
483    pub fn p_value(&self) -> &ScalarValue {
484        &self.p
485    }
486
487    pub fn mean(&self) -> &ScalarValue {
488        &self.p
489    }
490
491    /// Computes the median value of this distribution. In case of an unknown
492    /// success probability, the function returns a `NULL` `ScalarValue`.
493    pub fn median(&self) -> Result<ScalarValue> {
494        let dt = self.data_type();
495        if self.p.is_null() {
496            ScalarValue::try_from(&dt)
497        } else {
498            let one = ScalarValue::new_one(&dt)?;
499            if one.sub_checked(&self.p)?.lt(&self.p) {
500                ScalarValue::new_one(&dt)
501            } else {
502                ScalarValue::new_zero(&dt)
503            }
504        }
505    }
506
507    /// Computes the variance value of this distribution. In case of an unknown
508    /// success probability, the function returns a `NULL` `ScalarValue`.
509    pub fn variance(&self) -> Result<ScalarValue> {
510        let dt = self.data_type();
511        let one = ScalarValue::new_one(&dt)?;
512        let result = one.sub_checked(&self.p)?.mul_checked(&self.p);
513        debug_assert!(!self.p.is_null() || result.as_ref().is_ok_and(|r| r.is_null()));
514        result
515    }
516
517    pub fn range(&self) -> Interval {
518        let dt = self.data_type();
519        // Unwraps are safe as the constructor guarantees that the data type
520        // supports zero and one values.
521        if ScalarValue::new_zero(&dt).unwrap().eq(&self.p) {
522            Interval::FALSE
523        } else if ScalarValue::new_one(&dt).unwrap().eq(&self.p) {
524            Interval::TRUE
525        } else {
526            Interval::TRUE_OR_FALSE
527        }
528    }
529}
530
531impl GenericDistribution {
532    fn try_new(
533        mean: ScalarValue,
534        median: ScalarValue,
535        variance: ScalarValue,
536        range: Interval,
537    ) -> Result<Self> {
538        assert_ne_or_internal_err!(
539            range.data_type(),
540            DataType::Boolean,
541            "Construction of a boolean `Generic` distribution is prohibited, create a `Bernoulli` distribution instead."
542        );
543
544        let validate_location = |m: &ScalarValue| -> Result<bool> {
545            // Checks whether the given location estimate is within the range.
546            if m.is_null() {
547                Ok(true)
548            } else {
549                range.contains_value(m)
550            }
551        };
552
553        let locations_valid = validate_location(&mean)? && validate_location(&median)?;
554        let variance_non_negative = if variance.is_null() {
555            true
556        } else {
557            let zero = ScalarValue::new_zero(&variance.data_type())?;
558            !variance.lt(&zero)
559        };
560        assert_or_internal_err!(
561            locations_valid && variance_non_negative,
562            "Tried to construct an invalid `GenericDistribution` instance"
563        );
564
565        Ok(Self {
566            mean,
567            median,
568            variance,
569            range,
570        })
571    }
572
573    pub fn data_type(&self) -> DataType {
574        self.mean.data_type()
575    }
576
577    pub fn mean(&self) -> &ScalarValue {
578        &self.mean
579    }
580
581    pub fn median(&self) -> &ScalarValue {
582        &self.median
583    }
584
585    pub fn variance(&self) -> &ScalarValue {
586        &self.variance
587    }
588
589    pub fn range(&self) -> &Interval {
590        &self.range
591    }
592}
593
594/// This function takes a logical operator and two Bernoulli distributions,
595/// and it returns a new Bernoulli distribution that represents the result of
596/// the operation. Currently, only `AND` and `OR` operations are supported.
597pub fn combine_bernoullis(
598    op: &Operator,
599    left: &BernoulliDistribution,
600    right: &BernoulliDistribution,
601) -> Result<BernoulliDistribution> {
602    let left_p = left.p_value();
603    let right_p = right.p_value();
604    match op {
605        Operator::And => match (left_p.is_null(), right_p.is_null()) {
606            (false, false) => {
607                BernoulliDistribution::try_new(left_p.mul_checked(right_p)?)
608            }
609            (false, true) if left_p.eq(&ScalarValue::new_zero(&left_p.data_type())?) => {
610                Ok(left.clone())
611            }
612            (true, false)
613                if right_p.eq(&ScalarValue::new_zero(&right_p.data_type())?) =>
614            {
615                Ok(right.clone())
616            }
617            _ => {
618                let dt = Distribution::target_type(&[left_p, right_p])?;
619                BernoulliDistribution::try_new(ScalarValue::try_from(&dt)?)
620            }
621        },
622        Operator::Or => match (left_p.is_null(), right_p.is_null()) {
623            (false, false) => {
624                let sum = left_p.add_checked(right_p)?;
625                let product = left_p.mul_checked(right_p)?;
626                let or_success = sum.sub_checked(product)?;
627                BernoulliDistribution::try_new(or_success)
628            }
629            (false, true) if left_p.eq(&ScalarValue::new_one(&left_p.data_type())?) => {
630                Ok(left.clone())
631            }
632            (true, false) if right_p.eq(&ScalarValue::new_one(&right_p.data_type())?) => {
633                Ok(right.clone())
634            }
635            _ => {
636                let dt = Distribution::target_type(&[left_p, right_p])?;
637                BernoulliDistribution::try_new(ScalarValue::try_from(&dt)?)
638            }
639        },
640        _ => {
641            not_impl_err!("Statistical evaluation only supports AND and OR operators")
642        }
643    }
644}
645
646/// Applies the given operation to the given Gaussian distributions. Currently,
647/// this function handles only addition and subtraction operations. If the
648/// result is not a Gaussian random variable, it returns `None`. For details,
649/// see:
650///
651/// <https://en.wikipedia.org/wiki/Sum_of_normally_distributed_random_variables>
652pub fn combine_gaussians(
653    op: &Operator,
654    left: &GaussianDistribution,
655    right: &GaussianDistribution,
656) -> Result<Option<GaussianDistribution>> {
657    match op {
658        Operator::Plus => GaussianDistribution::try_new(
659            left.mean().add_checked(right.mean())?,
660            left.variance().add_checked(right.variance())?,
661        )
662        .map(Some),
663        Operator::Minus => GaussianDistribution::try_new(
664            left.mean().sub_checked(right.mean())?,
665            left.variance().add_checked(right.variance())?,
666        )
667        .map(Some),
668        _ => Ok(None),
669    }
670}
671
672/// Creates a new `Bernoulli` distribution by computing the resulting probability.
673/// Expects `op` to be a comparison operator, with `left` and `right` having
674/// numeric distributions. The resulting distribution has the `Float64` data
675/// type.
676pub fn create_bernoulli_from_comparison(
677    op: &Operator,
678    left: &Distribution,
679    right: &Distribution,
680) -> Result<Distribution> {
681    match (left, right) {
682        (Uniform(left), Uniform(right)) => {
683            match op {
684                Operator::Eq | Operator::NotEq => {
685                    let (li, ri) = (left.range(), right.range());
686                    if let Some(intersection) = li.intersect(ri)? {
687                        // If the ranges are not disjoint, calculate the probability
688                        // of equality using cardinalities:
689                        if let (Some(lc), Some(rc), Some(ic)) = (
690                            li.cardinality(),
691                            ri.cardinality(),
692                            intersection.cardinality(),
693                        ) {
694                            // Avoid overflow by widening the type temporarily:
695                            let pairs = ((lc as u128) * (rc as u128)) as f64;
696                            let p = (ic as f64).div_checked(pairs)?;
697                            // Alternative approach that may be more stable:
698                            // let p = (ic as f64)
699                            //     .div_checked(lc as f64)?
700                            //     .div_checked(rc as f64)?;
701
702                            let mut p_value = ScalarValue::from(p);
703                            if op == &Operator::NotEq {
704                                let one = ScalarValue::from(1.0);
705                                p_value = alter_fp_rounding_mode::<false, _>(
706                                    &one,
707                                    &p_value,
708                                    |lhs, rhs| lhs.sub_checked(rhs),
709                                )?;
710                            };
711                            return Distribution::new_bernoulli(p_value);
712                        }
713                    } else if op == &Operator::Eq {
714                        // If the ranges are disjoint, probability of equality is 0.
715                        return Distribution::new_bernoulli(ScalarValue::from(0.0));
716                    } else {
717                        // If the ranges are disjoint, probability of not-equality is 1.
718                        return Distribution::new_bernoulli(ScalarValue::from(1.0));
719                    }
720                }
721                Operator::Lt | Operator::LtEq | Operator::Gt | Operator::GtEq => {
722                    // TODO: We can handle inequality operators and calculate a
723                    // `p` value instead of falling back to an unknown Bernoulli
724                    // distribution. Note that the strict and non-strict inequalities
725                    // may require slightly different logic in case of real vs.
726                    // integral data types.
727                }
728                _ => {}
729            }
730        }
731        (Gaussian(_), Gaussian(_)) => {
732            // TODO: We can handle Gaussian comparisons and calculate a `p` value
733            //       instead of falling back to an unknown Bernoulli distribution.
734        }
735        _ => {}
736    }
737    let (li, ri) = (left.range()?, right.range()?);
738    let range_evaluation = apply_operator(op, &li, &ri)?;
739    if range_evaluation.eq(&Interval::FALSE) {
740        Distribution::new_bernoulli(ScalarValue::from(0.0))
741    } else if range_evaluation.eq(&Interval::TRUE) {
742        Distribution::new_bernoulli(ScalarValue::from(1.0))
743    } else if range_evaluation.eq(&Interval::TRUE_OR_FALSE) {
744        Distribution::new_bernoulli(ScalarValue::try_from(&DataType::Float64)?)
745    } else {
746        internal_err!("This function must be called with a comparison operator")
747    }
748}
749
750/// Creates a new [`Generic`] distribution that represents the result of the
751/// given binary operation on two unknown quantities represented by their
752/// [`Distribution`] objects. The function computes the mean, median and
753/// variance if possible.
754pub fn new_generic_from_binary_op(
755    op: &Operator,
756    left: &Distribution,
757    right: &Distribution,
758) -> Result<Distribution> {
759    Distribution::new_generic(
760        compute_mean(op, left, right)?,
761        compute_median(op, left, right)?,
762        compute_variance(op, left, right)?,
763        apply_operator(op, &left.range()?, &right.range()?)?,
764    )
765}
766
767/// Computes the mean value for the result of the given binary operation on
768/// two unknown quantities represented by their [`Distribution`] objects.
769pub fn compute_mean(
770    op: &Operator,
771    left: &Distribution,
772    right: &Distribution,
773) -> Result<ScalarValue> {
774    let (left_mean, right_mean) = (left.mean()?, right.mean()?);
775
776    match op {
777        Operator::Plus => return left_mean.add_checked(right_mean),
778        Operator::Minus => return left_mean.sub_checked(right_mean),
779        // Note the independence assumption below:
780        Operator::Multiply => return left_mean.mul_checked(right_mean),
781        // TODO: We can calculate the mean for division when we support reciprocals,
782        // or know the distributions of the operands. For details, see:
783        //
784        // <https://en.wikipedia.org/wiki/Algebra_of_random_variables>
785        // <https://stats.stackexchange.com/questions/185683/distribution-of-ratio-between-two-independent-uniform-random-variables>
786        //
787        // Fall back to an unknown mean value for division:
788        Operator::Divide => {}
789        // Fall back to an unknown mean value for other cases:
790        _ => {}
791    }
792    let target_type = Distribution::target_type(&[&left_mean, &right_mean])?;
793    ScalarValue::try_from(target_type)
794}
795
796/// Computes the median value for the result of the given binary operation on
797/// two unknown quantities represented by its [`Distribution`] objects. Currently,
798/// the median is calculable only for addition and subtraction operations on:
799/// - [`Uniform`] and [`Uniform`] distributions, and
800/// - [`Gaussian`] and [`Gaussian`] distributions.
801pub fn compute_median(
802    op: &Operator,
803    left: &Distribution,
804    right: &Distribution,
805) -> Result<ScalarValue> {
806    match (left, right) {
807        (Uniform(lu), Uniform(ru)) => {
808            let (left_median, right_median) = (lu.median()?, ru.median()?);
809            // Under the independence assumption, the result is a symmetric
810            // triangular distribution, so we can simply add/subtract the
811            // median values:
812            match op {
813                Operator::Plus => return left_median.add_checked(right_median),
814                Operator::Minus => return left_median.sub_checked(right_median),
815                // Fall back to an unknown median value for other cases:
816                _ => {}
817            }
818        }
819        // Under the independence assumption, the result is another Gaussian
820        // distribution, so we can simply add/subtract the median values:
821        (Gaussian(lg), Gaussian(rg)) => match op {
822            Operator::Plus => return lg.mean().add_checked(rg.mean()),
823            Operator::Minus => return lg.mean().sub_checked(rg.mean()),
824            // Fall back to an unknown median value for other cases:
825            _ => {}
826        },
827        // Fall back to an unknown median value for other cases:
828        _ => {}
829    }
830
831    let (left_median, right_median) = (left.median()?, right.median()?);
832    let target_type = Distribution::target_type(&[&left_median, &right_median])?;
833    ScalarValue::try_from(target_type)
834}
835
836/// Computes the variance value for the result of the given binary operation on
837/// two unknown quantities represented by their [`Distribution`] objects.
838pub fn compute_variance(
839    op: &Operator,
840    left: &Distribution,
841    right: &Distribution,
842) -> Result<ScalarValue> {
843    let (left_variance, right_variance) = (left.variance()?, right.variance()?);
844
845    match op {
846        // Note the independence assumption below:
847        Operator::Plus => return left_variance.add_checked(right_variance),
848        // Note the independence assumption below:
849        Operator::Minus => return left_variance.add_checked(right_variance),
850        // Note the independence assumption below:
851        Operator::Multiply => {
852            // For more details, along with an explanation of the formula below, see:
853            //
854            // <https://en.wikipedia.org/wiki/Distribution_of_the_product_of_two_random_variables>
855            let (left_mean, right_mean) = (left.mean()?, right.mean()?);
856            let left_mean_sq = left_mean.mul_checked(&left_mean)?;
857            let right_mean_sq = right_mean.mul_checked(&right_mean)?;
858            let left_sos = left_variance.add_checked(&left_mean_sq)?;
859            let right_sos = right_variance.add_checked(&right_mean_sq)?;
860            let pos = left_mean_sq.mul_checked(right_mean_sq)?;
861            return left_sos.mul_checked(right_sos)?.sub_checked(pos);
862        }
863        // TODO: We can calculate the variance for division when we support reciprocals,
864        // or know the distributions of the operands. For details, see:
865        //
866        // <https://en.wikipedia.org/wiki/Algebra_of_random_variables>
867        // <https://stats.stackexchange.com/questions/185683/distribution-of-ratio-between-two-independent-uniform-random-variables>
868        //
869        // Fall back to an unknown variance value for division:
870        Operator::Divide => {}
871        // Fall back to an unknown variance value for other cases:
872        _ => {}
873    }
874    let target_type = Distribution::target_type(&[&left_variance, &right_variance])?;
875    ScalarValue::try_from(target_type)
876}
877
#[cfg(test)]
mod tests {
    use super::{
        BernoulliDistribution, Distribution, GaussianDistribution, UniformDistribution,
        combine_bernoullis, combine_gaussians, compute_mean, compute_median,
        compute_variance, create_bernoulli_from_comparison, new_generic_from_binary_op,
    };
    use crate::interval_arithmetic::{Interval, apply_operator};
    use crate::operator::Operator;

    use arrow::datatypes::DataType;
    use datafusion_common::{HashSet, Result, ScalarValue};

    #[test]
    fn uniform_dist_is_valid_test() -> Result<()> {
        assert_eq!(
            Distribution::new_uniform(Interval::make_zero(&DataType::Int8)?)?,
            Distribution::Uniform(UniformDistribution {
                interval: Interval::make_zero(&DataType::Int8)?,
            })
        );

        assert!(Distribution::new_uniform(Interval::TRUE_OR_FALSE).is_err());
        Ok(())
    }

    #[test]
    fn exponential_dist_is_valid_test() {
        // This array collects test cases of the form (distribution, validity).
        let exponentials = vec![
            (
                Distribution::new_exponential(ScalarValue::Null, ScalarValue::Null, true),
                false,
            ),
            (
                Distribution::new_exponential(
                    ScalarValue::from(0_f32),
                    ScalarValue::from(1_f32),
                    true,
                ),
                false,
            ),
            (
                Distribution::new_exponential(
                    ScalarValue::from(100_f32),
                    ScalarValue::from(1_f32),
                    true,
                ),
                true,
            ),
            (
                Distribution::new_exponential(
                    ScalarValue::from(-100_f32),
                    ScalarValue::from(1_f32),
                    true,
                ),
                false,
            ),
        ];
        for case in exponentials {
            assert_eq!(case.0.is_ok(), case.1);
        }
    }

    #[test]
    fn gaussian_dist_is_valid_test() {
        // This array collects test cases of the form (distribution, validity).
        let gaussians = vec![
            (
                Distribution::new_gaussian(ScalarValue::Null, ScalarValue::Null),
                false,
            ),
            (
                Distribution::new_gaussian(
                    ScalarValue::from(0_f32),
                    ScalarValue::from(0_f32),
                ),
                true,
            ),
            (
                Distribution::new_gaussian(
                    ScalarValue::from(0_f32),
                    ScalarValue::from(0.5_f32),
                ),
                true,
            ),
            (
                Distribution::new_gaussian(
                    ScalarValue::from(0_f32),
                    ScalarValue::from(-0.5_f32),
                ),
                false,
            ),
        ];
        for case in gaussians {
            assert_eq!(case.0.is_ok(), case.1);
        }
    }

    #[test]
    fn bernoulli_dist_is_valid_test() {
        // This array collects test cases of the form (distribution, validity).
        let bernoullis = vec![
            (Distribution::new_bernoulli(ScalarValue::Null), true),
            (Distribution::new_bernoulli(ScalarValue::from(0.)), true),
            (Distribution::new_bernoulli(ScalarValue::from(0.25)), true),
            (Distribution::new_bernoulli(ScalarValue::from(1.)), true),
            (Distribution::new_bernoulli(ScalarValue::from(11.)), false),
            (Distribution::new_bernoulli(ScalarValue::from(-11.)), false),
            (Distribution::new_bernoulli(ScalarValue::from(0_i64)), true),
            (Distribution::new_bernoulli(ScalarValue::from(1_i64)), true),
            (
                Distribution::new_bernoulli(ScalarValue::from(11_i64)),
                false,
            ),
            (
                Distribution::new_bernoulli(ScalarValue::from(-11_i64)),
                false,
            ),
        ];
        for case in bernoullis {
            assert_eq!(case.0.is_ok(), case.1);
        }
    }

    #[test]
    fn generic_dist_is_valid_test() -> Result<()> {
        // This array collects test cases of the form (distribution, validity).
        let generic_dists = vec![
            // Using a boolean range to construct a Generic distribution is prohibited.
            (
                Distribution::new_generic(
                    ScalarValue::Null,
                    ScalarValue::Null,
                    ScalarValue::Null,
                    Interval::TRUE_OR_FALSE,
                ),
                false,
            ),
            (
                Distribution::new_generic(
                    ScalarValue::Null,
                    ScalarValue::Null,
                    ScalarValue::Null,
                    Interval::make_zero(&DataType::Float32)?,
                ),
                true,
            ),
            (
                Distribution::new_generic(
                    ScalarValue::from(0_f32),
                    ScalarValue::Float32(None),
                    ScalarValue::Float32(None),
                    Interval::make_zero(&DataType::Float32)?,
                ),
                true,
            ),
            (
                Distribution::new_generic(
                    ScalarValue::Float64(None),
                    ScalarValue::from(0.),
                    ScalarValue::Float64(None),
                    Interval::make_zero(&DataType::Float32)?,
                ),
                true,
            ),
            // A mean outside the range is invalid:
            (
                Distribution::new_generic(
                    ScalarValue::from(-10_f32),
                    ScalarValue::Float32(None),
                    ScalarValue::Float32(None),
                    Interval::make_zero(&DataType::Float32)?,
                ),
                false,
            ),
            // A median outside the range is invalid:
            (
                Distribution::new_generic(
                    ScalarValue::Float32(None),
                    ScalarValue::from(10_f32),
                    ScalarValue::Float32(None),
                    Interval::make_zero(&DataType::Float32)?,
                ),
                false,
            ),
            (
                Distribution::new_generic(
                    ScalarValue::from(0),
                    ScalarValue::from(0),
                    ScalarValue::Int32(None),
                    Interval::make_zero(&DataType::Int32)?,
                ),
                true,
            ),
            (
                Distribution::new_generic(
                    ScalarValue::from(0_f32),
                    ScalarValue::from(0_f32),
                    ScalarValue::Float32(None),
                    Interval::make_zero(&DataType::Float32)?,
                ),
                true,
            ),
            (
                Distribution::new_generic(
                    ScalarValue::from(50.),
                    ScalarValue::from(50.),
                    ScalarValue::Float64(None),
                    Interval::make(Some(0.), Some(100.))?,
                ),
                true,
            ),
            (
                Distribution::new_generic(
                    ScalarValue::from(50.),
                    ScalarValue::from(50.),
                    ScalarValue::Float64(None),
                    Interval::make(Some(-100.), Some(0.))?,
                ),
                false,
            ),
            (
                Distribution::new_generic(
                    ScalarValue::Float64(None),
                    ScalarValue::Float64(None),
                    ScalarValue::from(1.),
                    Interval::make_zero(&DataType::Float64)?,
                ),
                true,
            ),
            // A negative variance is invalid:
            (
                Distribution::new_generic(
                    ScalarValue::Float64(None),
                    ScalarValue::Float64(None),
                    ScalarValue::from(-1.),
                    Interval::make_zero(&DataType::Float64)?,
                ),
                false,
            ),
        ];
        for case in generic_dists {
            assert_eq!(case.0.is_ok(), case.1, "{:?}", case.0);
        }

        Ok(())
    }

    #[test]
    fn mean_extraction_test() -> Result<()> {
        // This array collects test cases of the form (distribution, mean value).
        let dists = vec![
            (
                Distribution::new_uniform(Interval::make_zero(&DataType::Int64)?),
                ScalarValue::from(0_i64),
            ),
            (
                Distribution::new_uniform(Interval::make_zero(&DataType::Float64)?),
                ScalarValue::from(0.),
            ),
            (
                Distribution::new_uniform(Interval::make(Some(1), Some(100))?),
                ScalarValue::from(50),
            ),
            (
                Distribution::new_uniform(Interval::make(Some(-100), Some(-1))?),
                ScalarValue::from(-50),
            ),
            (
                Distribution::new_uniform(Interval::make(Some(-100), Some(100))?),
                ScalarValue::from(0),
            ),
            (
                Distribution::new_exponential(
                    ScalarValue::from(2.),
                    ScalarValue::from(0.),
                    true,
                ),
                ScalarValue::from(0.5),
            ),
            (
                Distribution::new_exponential(
                    ScalarValue::from(2.),
                    ScalarValue::from(1.),
                    true,
                ),
                ScalarValue::from(1.5),
            ),
            (
                Distribution::new_gaussian(ScalarValue::from(0.), ScalarValue::from(1.)),
                ScalarValue::from(0.),
            ),
            (
                Distribution::new_gaussian(
                    ScalarValue::from(-2.),
                    ScalarValue::from(0.5),
                ),
                ScalarValue::from(-2.),
            ),
            (
                Distribution::new_bernoulli(ScalarValue::from(0.5)),
                ScalarValue::from(0.5),
            ),
            (
                Distribution::new_generic(
                    ScalarValue::from(42.),
                    ScalarValue::from(42.),
                    ScalarValue::Float64(None),
                    Interval::make(Some(25.), Some(50.))?,
                ),
                ScalarValue::from(42.),
            ),
        ];

        for case in dists {
            assert_eq!(case.0?.mean()?, case.1);
        }

        Ok(())
    }

    #[test]
    fn median_extraction_test() -> Result<()> {
        // This array collects test cases of the form (distribution, median value).
        let dists = vec![
            (
                Distribution::new_uniform(Interval::make_zero(&DataType::Int64)?),
                ScalarValue::from(0_i64),
            ),
            (
                Distribution::new_uniform(Interval::make(Some(25.), Some(75.))?),
                ScalarValue::from(50.),
            ),
            (
                Distribution::new_exponential(
                    ScalarValue::from(2_f64.ln()),
                    ScalarValue::from(0.),
                    true,
                ),
                ScalarValue::from(1.),
            ),
            (
                Distribution::new_gaussian(ScalarValue::from(2.), ScalarValue::from(1.)),
                ScalarValue::from(2.),
            ),
            (
                Distribution::new_bernoulli(ScalarValue::from(0.25)),
                ScalarValue::from(0.),
            ),
            (
                Distribution::new_bernoulli(ScalarValue::from(0.75)),
                ScalarValue::from(1.),
            ),
            (
                Distribution::new_generic(
                    ScalarValue::from(12.),
                    ScalarValue::from(12.),
                    ScalarValue::Float64(None),
                    Interval::make(Some(0.), Some(25.))?,
                ),
                ScalarValue::from(12.),
            ),
        ];

        for case in dists {
            assert_eq!(case.0?.median()?, case.1);
        }

        Ok(())
    }

    #[test]
    fn variance_extraction_test() -> Result<()> {
        // This array collects test cases of the form (distribution, variance value).
        let dists = vec![
            (
                Distribution::new_uniform(Interval::make(Some(0.), Some(12.))?),
                ScalarValue::from(12.),
            ),
            (
                Distribution::new_exponential(
                    ScalarValue::from(10.),
                    ScalarValue::from(0.),
                    true,
                ),
                ScalarValue::from(0.01),
            ),
            (
                Distribution::new_gaussian(ScalarValue::from(0.), ScalarValue::from(1.)),
                ScalarValue::from(1.),
            ),
            (
                Distribution::new_bernoulli(ScalarValue::from(0.5)),
                ScalarValue::from(0.25),
            ),
            (
                Distribution::new_generic(
                    ScalarValue::Float64(None),
                    ScalarValue::Float64(None),
                    ScalarValue::from(0.02),
                    Interval::make_zero(&DataType::Float64)?,
                ),
                ScalarValue::from(0.02),
            ),
        ];

        for case in dists {
            assert_eq!(case.0?.variance()?, case.1);
        }

        Ok(())
    }

    #[test]
    fn test_calculate_generic_properties_gauss_gauss() -> Result<()> {
        let dist_a =
            Distribution::new_gaussian(ScalarValue::from(10.), ScalarValue::from(0.0))?;
        let dist_b =
            Distribution::new_gaussian(ScalarValue::from(20.), ScalarValue::from(0.0))?;

        let test_data = vec![
            // Mean:
            (
                compute_mean(&Operator::Plus, &dist_a, &dist_b)?,
                ScalarValue::from(30.),
            ),
            (
                compute_mean(&Operator::Minus, &dist_a, &dist_b)?,
                ScalarValue::from(-10.),
            ),
            // Median:
            (
                compute_median(&Operator::Plus, &dist_a, &dist_b)?,
                ScalarValue::from(30.),
            ),
            (
                compute_median(&Operator::Minus, &dist_a, &dist_b)?,
                ScalarValue::from(-10.),
            ),
        ];
        for (actual, expected) in test_data {
            assert_eq!(actual, expected);
        }

        Ok(())
    }

    #[test]
    fn test_combine_bernoullis_and_op() -> Result<()> {
        let op = Operator::And;
        let left = BernoulliDistribution::try_new(ScalarValue::from(0.5))?;
        let right = BernoulliDistribution::try_new(ScalarValue::from(0.4))?;
        let left_null = BernoulliDistribution::try_new(ScalarValue::Null)?;
        let right_null = BernoulliDistribution::try_new(ScalarValue::Null)?;

        assert_eq!(
            combine_bernoullis(&op, &left, &right)?.p_value(),
            &ScalarValue::from(0.5 * 0.4)
        );
        assert_eq!(
            combine_bernoullis(&op, &left_null, &right)?.p_value(),
            &ScalarValue::Float64(None)
        );
        assert_eq!(
            combine_bernoullis(&op, &left, &right_null)?.p_value(),
            &ScalarValue::Float64(None)
        );
        assert_eq!(
            combine_bernoullis(&op, &left_null, &right_null)?.p_value(),
            &ScalarValue::Null
        );

        Ok(())
    }

    #[test]
    fn test_combine_bernoullis_or_op() -> Result<()> {
        let op = Operator::Or;
        let left = BernoulliDistribution::try_new(ScalarValue::from(0.6))?;
        let right = BernoulliDistribution::try_new(ScalarValue::from(0.4))?;
        let left_null = BernoulliDistribution::try_new(ScalarValue::Null)?;
        let right_null = BernoulliDistribution::try_new(ScalarValue::Null)?;

        assert_eq!(
            combine_bernoullis(&op, &left, &right)?.p_value(),
            &ScalarValue::from(0.6 + 0.4 - (0.6 * 0.4))
        );
        assert_eq!(
            combine_bernoullis(&op, &left_null, &right)?.p_value(),
            &ScalarValue::Float64(None)
        );
        assert_eq!(
            combine_bernoullis(&op, &left, &right_null)?.p_value(),
            &ScalarValue::Float64(None)
        );
        assert_eq!(
            combine_bernoullis(&op, &left_null, &right_null)?.p_value(),
            &ScalarValue::Null
        );

        Ok(())
    }

    #[test]
    fn test_combine_bernoullis_unsupported_ops() -> Result<()> {
        let mut operator_set = operator_set();
        operator_set.remove(&Operator::And);
        operator_set.remove(&Operator::Or);

        let left = BernoulliDistribution::try_new(ScalarValue::from(0.6))?;
        let right = BernoulliDistribution::try_new(ScalarValue::from(0.4))?;
        for op in operator_set {
            assert!(
                combine_bernoullis(&op, &left, &right).is_err(),
                "Operator {op} should not be supported for Bernoulli distributions"
            );
        }

        Ok(())
    }

    #[test]
    fn test_combine_gaussians_addition() -> Result<()> {
        let op = Operator::Plus;
        let left = GaussianDistribution::try_new(
            ScalarValue::from(3.0),
            ScalarValue::from(2.0),
        )?;
        let right = GaussianDistribution::try_new(
            ScalarValue::from(4.0),
            ScalarValue::from(1.0),
        )?;

        let result = combine_gaussians(&op, &left, &right)?.unwrap();

        assert_eq!(result.mean(), &ScalarValue::from(7.0)); // 3.0 + 4.0
        assert_eq!(result.variance(), &ScalarValue::from(3.0)); // 2.0 + 1.0
        Ok(())
    }

    #[test]
    fn test_combine_gaussians_subtraction() -> Result<()> {
        let op = Operator::Minus;
        let left = GaussianDistribution::try_new(
            ScalarValue::from(7.0),
            ScalarValue::from(2.0),
        )?;
        let right = GaussianDistribution::try_new(
            ScalarValue::from(4.0),
            ScalarValue::from(1.0),
        )?;

        let result = combine_gaussians(&op, &left, &right)?.unwrap();

        assert_eq!(result.mean(), &ScalarValue::from(3.0)); // 7.0 - 4.0
        assert_eq!(result.variance(), &ScalarValue::from(3.0)); // 2.0 + 1.0

        Ok(())
    }

    #[test]
    fn test_combine_gaussians_unsupported_ops() -> Result<()> {
        let mut operator_set = operator_set();
        operator_set.remove(&Operator::Plus);
        operator_set.remove(&Operator::Minus);

        let left = GaussianDistribution::try_new(
            ScalarValue::from(7.0),
            ScalarValue::from(2.0),
        )?;
        let right = GaussianDistribution::try_new(
            ScalarValue::from(4.0),
            ScalarValue::from(1.0),
        )?;
        for op in operator_set {
            assert!(
                combine_gaussians(&op, &left, &right)?.is_none(),
                "Operator {op} should not be supported for Gaussian distributions"
            );
        }

        Ok(())
    }

    // Expected test results were calculated in Wolfram Mathematica, by using:
    //
    // *METHOD_NAME*[TransformedDistribution[
    //  x *op* y,
    //  {x ~ *DISTRIBUTION_X*[..], y ~ *DISTRIBUTION_Y*[..]}
    // ]]
    #[test]
    fn test_calculate_generic_properties_uniform_uniform() -> Result<()> {
        let dist_a = Distribution::new_uniform(Interval::make(Some(0.), Some(12.))?)?;
        let dist_b = Distribution::new_uniform(Interval::make(Some(12.), Some(36.))?)?;

        let test_data = vec![
            // Mean:
            (
                compute_mean(&Operator::Plus, &dist_a, &dist_b)?,
                ScalarValue::from(30.),
            ),
            (
                compute_mean(&Operator::Minus, &dist_a, &dist_b)?,
                ScalarValue::from(-18.),
            ),
            (
                compute_mean(&Operator::Multiply, &dist_a, &dist_b)?,
                ScalarValue::from(144.),
            ),
            // Median:
            (
                compute_median(&Operator::Plus, &dist_a, &dist_b)?,
                ScalarValue::from(30.),
            ),
            (
                compute_median(&Operator::Minus, &dist_a, &dist_b)?,
                ScalarValue::from(-18.),
            ),
            // Variance:
            (
                compute_variance(&Operator::Plus, &dist_a, &dist_b)?,
                ScalarValue::from(60.),
            ),
            (
                compute_variance(&Operator::Minus, &dist_a, &dist_b)?,
                ScalarValue::from(60.),
            ),
            (
                compute_variance(&Operator::Multiply, &dist_a, &dist_b)?,
                ScalarValue::from(9216.),
            ),
        ];
        for (actual, expected) in test_data {
            assert_eq!(actual, expected);
        }

        Ok(())
    }

    /// Statistics without a supported closed form for the given operator must
    /// fall back to an unknown (null) value instead of returning an error.
    #[test]
    fn test_calculate_generic_properties_unknown_fallbacks() -> Result<()> {
        let dist_a = Distribution::new_uniform(Interval::make(Some(0.), Some(12.))?)?;
        let dist_b = Distribution::new_uniform(Interval::make(Some(12.), Some(36.))?)?;

        assert!(compute_mean(&Operator::Divide, &dist_a, &dist_b)?.is_null());
        assert!(compute_median(&Operator::Multiply, &dist_a, &dist_b)?.is_null());
        assert!(compute_median(&Operator::Divide, &dist_a, &dist_b)?.is_null());
        assert!(compute_variance(&Operator::Divide, &dist_a, &dist_b)?.is_null());

        Ok(())
    }

    /// Test for `Uniform`-`Uniform`, `Uniform`-`Generic`, `Generic`-`Uniform`,
    /// `Generic`-`Generic` pairs, where range is always present.
    #[test]
    fn test_compute_range_where_present() -> Result<()> {
        let a = &Interval::make(Some(0.), Some(12.0))?;
        let b = &Interval::make(Some(0.), Some(12.0))?;
        let mean = ScalarValue::from(6.0);
        for (dist_a, dist_b) in [
            (
                Distribution::new_uniform(a.clone())?,
                Distribution::new_uniform(b.clone())?,
            ),
            (
                Distribution::new_generic(
                    mean.clone(),
                    mean.clone(),
                    ScalarValue::Float64(None),
                    a.clone(),
                )?,
                Distribution::new_uniform(b.clone())?,
            ),
            (
                Distribution::new_uniform(a.clone())?,
                Distribution::new_generic(
                    mean.clone(),
                    mean.clone(),
                    ScalarValue::Float64(None),
                    b.clone(),
                )?,
            ),
            (
                Distribution::new_generic(
                    mean.clone(),
                    mean.clone(),
                    ScalarValue::Float64(None),
                    a.clone(),
                )?,
                Distribution::new_generic(
                    mean.clone(),
                    mean.clone(),
                    ScalarValue::Float64(None),
                    b.clone(),
                )?,
            ),
        ] {
            use super::Operator::{
                Divide, Eq, Gt, GtEq, Lt, LtEq, Minus, Multiply, NotEq, Plus,
            };
            for op in [Plus, Minus, Multiply, Divide] {
                assert_eq!(
                    new_generic_from_binary_op(&op, &dist_a, &dist_b)?.range()?,
                    apply_operator(&op, a, b)?,
                    "Failed for {dist_a:?} {op} {dist_b:?}"
                );
            }
            for op in [Gt, GtEq, Lt, LtEq, Eq, NotEq] {
                assert_eq!(
                    create_bernoulli_from_comparison(&op, &dist_a, &dist_b)?.range()?,
                    apply_operator(&op, a, b)?,
                    "Failed for {dist_a:?} {op} {dist_b:?}"
                );
            }
        }

        Ok(())
    }

    /// Returns every binary operator exercised by the "unsupported operator"
    /// tests; callers remove the operators they expect to be supported.
    fn operator_set() -> HashSet<Operator> {
        use super::Operator::*;

        let all_ops = vec![
            And,
            Or,
            Eq,
            NotEq,
            Gt,
            GtEq,
            Lt,
            LtEq,
            Plus,
            Minus,
            Multiply,
            Divide,
            Modulo,
            IsDistinctFrom,
            IsNotDistinctFrom,
            RegexMatch,
            RegexIMatch,
            RegexNotMatch,
            RegexNotIMatch,
            LikeMatch,
            ILikeMatch,
            NotLikeMatch,
            NotILikeMatch,
            BitwiseAnd,
            BitwiseOr,
            BitwiseXor,
            BitwiseShiftRight,
            BitwiseShiftLeft,
            StringConcat,
            AtArrow,
            ArrowAt,
        ];

        all_ops.into_iter().collect()
    }
}