Skip to main content

datafusion_expr_common/
statistics.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Probabilistic distributions for expression-level statistics (unused).
19//!
20//! Note: All public items in this module are **deprecated** as of `54.0.0`.
21//!
22//! See <https://github.com/apache/datafusion/pull/22071> for details.
23
24// The whole module is deprecated; suppress warnings from intra-module uses
25// of the deprecated types so the module continues to compile.
26#![allow(deprecated)]
27
28use std::f64::consts::LN_2;
29
30use crate::interval_arithmetic::{Interval, apply_operator};
31use crate::operator::Operator;
32use crate::type_coercion::binary::binary_numeric_coercion;
33
34use arrow::array::ArrowNativeTypeOp;
35use arrow::datatypes::DataType;
36use datafusion_common::rounding::alter_fp_rounding_mode;
37use datafusion_common::{
38    Result, ScalarValue, assert_eq_or_internal_err, assert_ne_or_internal_err,
39    assert_or_internal_err, internal_err, not_impl_err,
40};
41
/// This object defines probabilistic distributions that encode uncertain
/// information about a single, scalar value. Currently, we support five core
/// statistical distributions. New variants will be added over time.
///
/// This object is the lowest-level object in the statistics hierarchy, and it
/// is the main unit of calculus when evaluating expressions in a statistical
/// context. Notions like column and table statistics are built on top of this
/// object and the operations it supports.
#[deprecated(
    since = "54.0.0",
    note = "Part of the unused Statistics V2 framework; see https://github.com/apache/datafusion/pull/22071"
)]
#[derive(Clone, Debug, PartialEq)]
pub enum Distribution {
    /// A uniform distribution over an interval.
    Uniform(UniformDistribution),
    /// A (possibly shifted and/or mirrored) exponential distribution.
    Exponential(ExponentialDistribution),
    /// A Gaussian (normal) distribution given by its mean and variance.
    Gaussian(GaussianDistribution),
    /// A Bernoulli distribution given by its success probability.
    Bernoulli(BernoulliDistribution),
    /// A distribution of unknown functional form, summarized by its mean,
    /// median, variance and range.
    Generic(GenericDistribution),
}
62
63use Distribution::{Bernoulli, Exponential, Gaussian, Generic, Uniform};
64
65impl Distribution {
66    /// Constructs a new [`Uniform`] distribution from the given [`Interval`].
67    pub fn new_uniform(interval: Interval) -> Result<Self> {
68        UniformDistribution::try_new(interval).map(Uniform)
69    }
70
71    /// Constructs a new [`Exponential`] distribution from the given rate/offset
72    /// pair, and validates the given parameters.
73    pub fn new_exponential(
74        rate: ScalarValue,
75        offset: ScalarValue,
76        positive_tail: bool,
77    ) -> Result<Self> {
78        ExponentialDistribution::try_new(rate, offset, positive_tail).map(Exponential)
79    }
80
81    /// Constructs a new [`Gaussian`] distribution from the given mean/variance
82    /// pair, and validates the given parameters.
83    pub fn new_gaussian(mean: ScalarValue, variance: ScalarValue) -> Result<Self> {
84        GaussianDistribution::try_new(mean, variance).map(Gaussian)
85    }
86
87    /// Constructs a new [`Bernoulli`] distribution from the given success
88    /// probability, and validates the given parameters.
89    pub fn new_bernoulli(p: ScalarValue) -> Result<Self> {
90        BernoulliDistribution::try_new(p).map(Bernoulli)
91    }
92
93    /// Constructs a new [`Generic`] distribution from the given mean, median,
94    /// variance, and range values after validating the given parameters.
95    pub fn new_generic(
96        mean: ScalarValue,
97        median: ScalarValue,
98        variance: ScalarValue,
99        range: Interval,
100    ) -> Result<Self> {
101        GenericDistribution::try_new(mean, median, variance, range).map(Generic)
102    }
103
104    /// Constructs a new [`Generic`] distribution from the given range. Other
105    /// parameters (mean, median and variance) are initialized with null values.
106    pub fn new_from_interval(range: Interval) -> Result<Self> {
107        let null = ScalarValue::try_from(range.data_type())?;
108        Distribution::new_generic(null.clone(), null.clone(), null, range)
109    }
110
111    /// Extracts the mean value of this uncertain quantity, depending on its
112    /// distribution:
113    /// - A [`Uniform`] distribution's interval determines its mean value, which
114    ///   is the arithmetic average of the interval endpoints.
115    /// - An [`Exponential`] distribution's mean is calculable by the formula
116    ///   `offset + 1 / λ`, where `λ` is the (non-negative) rate.
117    /// - A [`Gaussian`] distribution contains the mean explicitly.
118    /// - A [`Bernoulli`] distribution's mean is equal to its success probability `p`.
119    /// - A [`Generic`] distribution _may_ have it explicitly, or this information
120    ///   may be absent.
121    pub fn mean(&self) -> Result<ScalarValue> {
122        match &self {
123            Uniform(u) => u.mean(),
124            Exponential(e) => e.mean(),
125            Gaussian(g) => Ok(g.mean().clone()),
126            Bernoulli(b) => Ok(b.mean().clone()),
127            Generic(u) => Ok(u.mean().clone()),
128        }
129    }
130
131    /// Extracts the median value of this uncertain quantity, depending on its
132    /// distribution:
133    /// - A [`Uniform`] distribution's interval determines its median value, which
134    ///   is the arithmetic average of the interval endpoints.
135    /// - An [`Exponential`] distribution's median is calculable by the formula
136    ///   `offset + ln(2) / λ`, where `λ` is the (non-negative) rate.
137    /// - A [`Gaussian`] distribution's median is equal to its mean, which is
138    ///   specified explicitly.
139    /// - A [`Bernoulli`] distribution's median is `1` if `p > 0.5` and `0`
140    ///   otherwise, where `p` is the success probability.
141    /// - A [`Generic`] distribution _may_ have it explicitly, or this information
142    ///   may be absent.
143    pub fn median(&self) -> Result<ScalarValue> {
144        match &self {
145            Uniform(u) => u.median(),
146            Exponential(e) => e.median(),
147            Gaussian(g) => Ok(g.median().clone()),
148            Bernoulli(b) => b.median(),
149            Generic(u) => Ok(u.median().clone()),
150        }
151    }
152
153    /// Extracts the variance value of this uncertain quantity, depending on
154    /// its distribution:
155    /// - A [`Uniform`] distribution's interval determines its variance value, which
156    ///   is calculable by the formula `(upper - lower) ^ 2 / 12`.
157    /// - An [`Exponential`] distribution's variance is calculable by the formula
158    ///   `1 / (λ ^ 2)`, where `λ` is the (non-negative) rate.
159    /// - A [`Gaussian`] distribution's variance is specified explicitly.
160    /// - A [`Bernoulli`] distribution's median is given by the formula `p * (1 - p)`
161    ///   where `p` is the success probability.
162    /// - A [`Generic`] distribution _may_ have it explicitly, or this information
163    ///   may be absent.
164    pub fn variance(&self) -> Result<ScalarValue> {
165        match &self {
166            Uniform(u) => u.variance(),
167            Exponential(e) => e.variance(),
168            Gaussian(g) => Ok(g.variance.clone()),
169            Bernoulli(b) => b.variance(),
170            Generic(u) => Ok(u.variance.clone()),
171        }
172    }
173
174    /// Extracts the range of this uncertain quantity, depending on its
175    /// distribution:
176    /// - A [`Uniform`] distribution's range is simply its interval.
177    /// - An [`Exponential`] distribution's range is `[offset, +∞)`.
178    /// - A [`Gaussian`] distribution's range is unbounded.
179    /// - A [`Bernoulli`] distribution's range is [`Interval::TRUE_OR_FALSE`], if
180    ///   `p` is neither `0` nor `1`. Otherwise, it is [`Interval::FALSE`]
181    ///   and [`Interval::TRUE`], respectively.
182    /// - A [`Generic`] distribution is unbounded by default, but more information
183    ///   may be present.
184    pub fn range(&self) -> Result<Interval> {
185        match &self {
186            Uniform(u) => Ok(u.range().clone()),
187            Exponential(e) => e.range(),
188            Gaussian(g) => g.range(),
189            Bernoulli(b) => Ok(b.range()),
190            Generic(u) => Ok(u.range().clone()),
191        }
192    }
193
194    /// Returns the data type of the statistical parameters comprising this
195    /// distribution.
196    pub fn data_type(&self) -> DataType {
197        match &self {
198            Uniform(u) => u.data_type(),
199            Exponential(e) => e.data_type(),
200            Gaussian(g) => g.data_type(),
201            Bernoulli(b) => b.data_type(),
202            Generic(u) => u.data_type(),
203        }
204    }
205
206    pub fn target_type(args: &[&ScalarValue]) -> Result<DataType> {
207        let mut arg_types = args
208            .iter()
209            .filter(|&&arg| arg != &ScalarValue::Null)
210            .map(|&arg| arg.data_type());
211
212        let Some(dt) = arg_types.next().map_or_else(
213            || Some(DataType::Null),
214            |first| {
215                arg_types
216                    .try_fold(first, |target, arg| binary_numeric_coercion(&target, &arg))
217            },
218        ) else {
219            return internal_err!("Can only evaluate statistics for numeric types");
220        };
221        Ok(dt)
222    }
223}
224
/// Uniform distribution, represented by its range. If the given range extends
/// towards infinity, the distribution will be improper -- which is OK. For a
/// more in-depth discussion, see:
///
/// <https://en.wikipedia.org/wiki/Continuous_uniform_distribution>
/// <https://en.wikipedia.org/wiki/Prior_probability#Improper_priors>
#[deprecated(
    since = "54.0.0",
    note = "Part of the unused Statistics V2 framework; see https://github.com/apache/datafusion/pull/22071"
)]
#[derive(Clone, Debug, PartialEq)]
pub struct UniformDistribution {
    /// The range of the distribution. The constructor rejects intervals with
    /// a boolean data type.
    interval: Interval,
}
239
/// Exponential distribution with an optional shift. The probability density
/// function (PDF) is defined as follows:
///
/// For a positive tail (when `positive_tail` is `true`):
///
/// `f(x; λ, offset) = λ exp(-λ (x - offset))    for x ≥ offset`
///
/// For a negative tail (when `positive_tail` is `false`):
///
/// `f(x; λ, offset) = λ exp(-λ (offset - x))    for x ≤ offset`
///
///
/// In both cases, the PDF is `0` outside the specified domain.
///
/// For more information, see:
///
/// <https://en.wikipedia.org/wiki/Exponential_distribution>
#[deprecated(
    since = "54.0.0",
    note = "Part of the unused Statistics V2 framework; see https://github.com/apache/datafusion/pull/22071"
)]
#[derive(Clone, Debug, PartialEq)]
pub struct ExponentialDistribution {
    /// The rate parameter `λ`. The constructor requires it to be non-null and
    /// strictly positive.
    rate: ScalarValue,
    /// The point at which the distribution starts (positive tail) or ends
    /// (negative tail). The constructor requires it to be non-null and to have
    /// the same data type as `rate`.
    offset: ScalarValue,
    /// Indicates whether the exponential distribution has a positive tail;
    /// i.e. it extends towards positive infinity.
    positive_tail: bool,
}
269
/// Gaussian (normal) distribution, represented by its mean and variance.
/// For a more in-depth discussion, see:
///
/// <https://en.wikipedia.org/wiki/Normal_distribution>
#[deprecated(
    since = "54.0.0",
    note = "Part of the unused Statistics V2 framework; see https://github.com/apache/datafusion/pull/22071"
)]
#[derive(Clone, Debug, PartialEq)]
pub struct GaussianDistribution {
    /// The mean of the distribution, which is also its median.
    mean: ScalarValue,
    /// The variance of the distribution. The constructor requires it to be
    /// non-null, not less than zero, and of the same data type as `mean`.
    variance: ScalarValue,
}
283
/// Bernoulli distribution with success probability `p`. If `p` has a null value,
/// the success probability is unknown. For a more in-depth discussion, see:
///
/// <https://en.wikipedia.org/wiki/Bernoulli_distribution>
#[deprecated(
    since = "54.0.0",
    note = "Part of the unused Statistics V2 framework; see https://github.com/apache/datafusion/pull/22071"
)]
#[derive(Clone, Debug, PartialEq)]
pub struct BernoulliDistribution {
    /// The success probability. The constructor accepts either a null value
    /// (unknown probability) or a value in `[0, 1]`.
    p: ScalarValue,
}
296
/// A generic distribution whose functional form is not available, which is
/// approximated via some summary statistics. For a more in-depth discussion, see:
///
/// <https://en.wikipedia.org/wiki/Summary_statistics>
#[deprecated(
    since = "54.0.0",
    note = "Part of the unused Statistics V2 framework; see https://github.com/apache/datafusion/pull/22071"
)]
#[derive(Clone, Debug, PartialEq)]
pub struct GenericDistribution {
    /// The mean, or a null value if unknown. When non-null, the constructor
    /// requires it to lie within `range`.
    mean: ScalarValue,
    /// The median, or a null value if unknown. When non-null, the constructor
    /// requires it to lie within `range`.
    median: ScalarValue,
    /// The variance, or a null value if unknown. When non-null, the
    /// constructor requires it to be not less than zero.
    variance: ScalarValue,
    /// The range of values the quantity can take. The constructor rejects
    /// intervals with a boolean data type.
    range: Interval,
}
312
313impl UniformDistribution {
314    fn try_new(interval: Interval) -> Result<Self> {
315        assert_ne_or_internal_err!(
316            interval.data_type(),
317            DataType::Boolean,
318            "Construction of a boolean `Uniform` distribution is prohibited, create a `Bernoulli` distribution instead."
319        );
320
321        Ok(Self { interval })
322    }
323
324    pub fn data_type(&self) -> DataType {
325        self.interval.data_type()
326    }
327
328    /// Computes the mean value of this distribution. In case of improper
329    /// distributions (i.e. when the range is unbounded), the function returns
330    /// a `NULL` `ScalarValue`.
331    pub fn mean(&self) -> Result<ScalarValue> {
332        // TODO: Should we ensure that this always returns a real number data type?
333        let dt = self.data_type();
334        let two = ScalarValue::from(2).cast_to(&dt)?;
335        let result = self
336            .interval
337            .lower()
338            .add_checked(self.interval.upper())?
339            .div(two);
340        debug_assert!(
341            !self.interval.is_unbounded() || result.as_ref().is_ok_and(|r| r.is_null())
342        );
343        result
344    }
345
346    pub fn median(&self) -> Result<ScalarValue> {
347        self.mean()
348    }
349
350    /// Computes the variance value of this distribution. In case of improper
351    /// distributions (i.e. when the range is unbounded), the function returns
352    /// a `NULL` `ScalarValue`.
353    pub fn variance(&self) -> Result<ScalarValue> {
354        // TODO: Should we ensure that this always returns a real number data type?
355        let width = self.interval.width()?;
356        let dt = width.data_type();
357        let twelve = ScalarValue::from(12).cast_to(&dt)?;
358        let result = width.mul_checked(&width)?.div(twelve);
359        debug_assert!(
360            !self.interval.is_unbounded() || result.as_ref().is_ok_and(|r| r.is_null())
361        );
362        result
363    }
364
365    pub fn range(&self) -> &Interval {
366        &self.interval
367    }
368}
369
370impl ExponentialDistribution {
371    fn try_new(
372        rate: ScalarValue,
373        offset: ScalarValue,
374        positive_tail: bool,
375    ) -> Result<Self> {
376        let dt = rate.data_type();
377        assert_eq_or_internal_err!(
378            offset.data_type(),
379            dt,
380            "Rate and offset must have the same data type"
381        );
382        assert_or_internal_err!(
383            !offset.is_null(),
384            "Offset of an `ExponentialDistribution` cannot be null"
385        );
386        assert_or_internal_err!(
387            !rate.is_null(),
388            "Rate of an `ExponentialDistribution` cannot be null"
389        );
390        let zero = ScalarValue::new_zero(&dt)?;
391        assert_or_internal_err!(
392            !rate.le(&zero),
393            "Rate of an `ExponentialDistribution` must be positive"
394        );
395        Ok(Self {
396            rate,
397            offset,
398            positive_tail,
399        })
400    }
401
402    pub fn data_type(&self) -> DataType {
403        self.rate.data_type()
404    }
405
406    pub fn rate(&self) -> &ScalarValue {
407        &self.rate
408    }
409
410    pub fn offset(&self) -> &ScalarValue {
411        &self.offset
412    }
413
414    pub fn positive_tail(&self) -> bool {
415        self.positive_tail
416    }
417
418    pub fn mean(&self) -> Result<ScalarValue> {
419        // TODO: Should we ensure that this always returns a real number data type?
420        let one = ScalarValue::new_one(&self.data_type())?;
421        let tail_mean = one.div(&self.rate)?;
422        if self.positive_tail {
423            self.offset.add_checked(tail_mean)
424        } else {
425            self.offset.sub_checked(tail_mean)
426        }
427    }
428
429    pub fn median(&self) -> Result<ScalarValue> {
430        // TODO: Should we ensure that this always returns a real number data type?
431        let ln_two = ScalarValue::from(LN_2).cast_to(&self.data_type())?;
432        let tail_median = ln_two.div(&self.rate)?;
433        if self.positive_tail {
434            self.offset.add_checked(tail_median)
435        } else {
436            self.offset.sub_checked(tail_median)
437        }
438    }
439
440    pub fn variance(&self) -> Result<ScalarValue> {
441        // TODO: Should we ensure that this always returns a real number data type?
442        let one = ScalarValue::new_one(&self.data_type())?;
443        let rate_squared = self.rate.mul_checked(&self.rate)?;
444        one.div(rate_squared)
445    }
446
447    pub fn range(&self) -> Result<Interval> {
448        let end = ScalarValue::try_from(&self.data_type())?;
449        if self.positive_tail {
450            Interval::try_new(self.offset.clone(), end)
451        } else {
452            Interval::try_new(end, self.offset.clone())
453        }
454    }
455}
456
457impl GaussianDistribution {
458    fn try_new(mean: ScalarValue, variance: ScalarValue) -> Result<Self> {
459        let dt = mean.data_type();
460        assert_eq_or_internal_err!(
461            variance.data_type(),
462            dt,
463            "Mean and variance must have the same data type"
464        );
465        assert_or_internal_err!(
466            !variance.is_null(),
467            "Variance of a `GaussianDistribution` cannot be null"
468        );
469        let zero = ScalarValue::new_zero(&dt)?;
470        assert_or_internal_err!(
471            !variance.lt(&zero),
472            "Variance of a `GaussianDistribution` must be positive"
473        );
474        Ok(Self { mean, variance })
475    }
476
477    pub fn data_type(&self) -> DataType {
478        self.mean.data_type()
479    }
480
481    pub fn mean(&self) -> &ScalarValue {
482        &self.mean
483    }
484
485    pub fn variance(&self) -> &ScalarValue {
486        &self.variance
487    }
488
489    pub fn median(&self) -> &ScalarValue {
490        self.mean()
491    }
492
493    pub fn range(&self) -> Result<Interval> {
494        Interval::make_unbounded(&self.data_type())
495    }
496}
497
498impl BernoulliDistribution {
499    fn try_new(p: ScalarValue) -> Result<Self> {
500        if p.is_null() {
501            return Ok(Self { p });
502        }
503        let dt = p.data_type();
504        let zero = ScalarValue::new_zero(&dt)?;
505        let one = ScalarValue::new_one(&dt)?;
506        assert_or_internal_err!(
507            p.ge(&zero) && p.le(&one),
508            "Success probability of a `BernoulliDistribution` must be in [0, 1]"
509        );
510        Ok(Self { p })
511    }
512
513    pub fn data_type(&self) -> DataType {
514        self.p.data_type()
515    }
516
517    pub fn p_value(&self) -> &ScalarValue {
518        &self.p
519    }
520
521    pub fn mean(&self) -> &ScalarValue {
522        &self.p
523    }
524
525    /// Computes the median value of this distribution. In case of an unknown
526    /// success probability, the function returns a `NULL` `ScalarValue`.
527    pub fn median(&self) -> Result<ScalarValue> {
528        let dt = self.data_type();
529        if self.p.is_null() {
530            ScalarValue::try_from(&dt)
531        } else {
532            let one = ScalarValue::new_one(&dt)?;
533            if one.sub_checked(&self.p)?.lt(&self.p) {
534                ScalarValue::new_one(&dt)
535            } else {
536                ScalarValue::new_zero(&dt)
537            }
538        }
539    }
540
541    /// Computes the variance value of this distribution. In case of an unknown
542    /// success probability, the function returns a `NULL` `ScalarValue`.
543    pub fn variance(&self) -> Result<ScalarValue> {
544        let dt = self.data_type();
545        let one = ScalarValue::new_one(&dt)?;
546        let result = one.sub_checked(&self.p)?.mul_checked(&self.p);
547        debug_assert!(!self.p.is_null() || result.as_ref().is_ok_and(|r| r.is_null()));
548        result
549    }
550
551    pub fn range(&self) -> Interval {
552        let dt = self.data_type();
553        // Unwraps are safe as the constructor guarantees that the data type
554        // supports zero and one values.
555        if ScalarValue::new_zero(&dt).unwrap().eq(&self.p) {
556            Interval::FALSE
557        } else if ScalarValue::new_one(&dt).unwrap().eq(&self.p) {
558            Interval::TRUE
559        } else {
560            Interval::TRUE_OR_FALSE
561        }
562    }
563}
564
565impl GenericDistribution {
566    fn try_new(
567        mean: ScalarValue,
568        median: ScalarValue,
569        variance: ScalarValue,
570        range: Interval,
571    ) -> Result<Self> {
572        assert_ne_or_internal_err!(
573            range.data_type(),
574            DataType::Boolean,
575            "Construction of a boolean `Generic` distribution is prohibited, create a `Bernoulli` distribution instead."
576        );
577
578        let validate_location = |m: &ScalarValue| -> Result<bool> {
579            // Checks whether the given location estimate is within the range.
580            if m.is_null() {
581                Ok(true)
582            } else {
583                range.contains_value(m)
584            }
585        };
586
587        let locations_valid = validate_location(&mean)? && validate_location(&median)?;
588        let variance_non_negative = if variance.is_null() {
589            true
590        } else {
591            let zero = ScalarValue::new_zero(&variance.data_type())?;
592            !variance.lt(&zero)
593        };
594        assert_or_internal_err!(
595            locations_valid && variance_non_negative,
596            "Tried to construct an invalid `GenericDistribution` instance"
597        );
598
599        Ok(Self {
600            mean,
601            median,
602            variance,
603            range,
604        })
605    }
606
607    pub fn data_type(&self) -> DataType {
608        self.mean.data_type()
609    }
610
611    pub fn mean(&self) -> &ScalarValue {
612        &self.mean
613    }
614
615    pub fn median(&self) -> &ScalarValue {
616        &self.median
617    }
618
619    pub fn variance(&self) -> &ScalarValue {
620        &self.variance
621    }
622
623    pub fn range(&self) -> &Interval {
624        &self.range
625    }
626}
627
628/// This function takes a logical operator and two Bernoulli distributions,
629/// and it returns a new Bernoulli distribution that represents the result of
630/// the operation. Currently, only `AND` and `OR` operations are supported.
631#[deprecated(
632    since = "54.0.0",
633    note = "Part of the unused Statistics V2 framework; see https://github.com/apache/datafusion/pull/22071"
634)]
635pub fn combine_bernoullis(
636    op: &Operator,
637    left: &BernoulliDistribution,
638    right: &BernoulliDistribution,
639) -> Result<BernoulliDistribution> {
640    let left_p = left.p_value();
641    let right_p = right.p_value();
642    match op {
643        Operator::And => match (left_p.is_null(), right_p.is_null()) {
644            (false, false) => {
645                BernoulliDistribution::try_new(left_p.mul_checked(right_p)?)
646            }
647            (false, true) if left_p.eq(&ScalarValue::new_zero(&left_p.data_type())?) => {
648                Ok(left.clone())
649            }
650            (true, false)
651                if right_p.eq(&ScalarValue::new_zero(&right_p.data_type())?) =>
652            {
653                Ok(right.clone())
654            }
655            _ => {
656                let dt = Distribution::target_type(&[left_p, right_p])?;
657                BernoulliDistribution::try_new(ScalarValue::try_from(&dt)?)
658            }
659        },
660        Operator::Or => match (left_p.is_null(), right_p.is_null()) {
661            (false, false) => {
662                let sum = left_p.add_checked(right_p)?;
663                let product = left_p.mul_checked(right_p)?;
664                let or_success = sum.sub_checked(product)?;
665                BernoulliDistribution::try_new(or_success)
666            }
667            (false, true) if left_p.eq(&ScalarValue::new_one(&left_p.data_type())?) => {
668                Ok(left.clone())
669            }
670            (true, false) if right_p.eq(&ScalarValue::new_one(&right_p.data_type())?) => {
671                Ok(right.clone())
672            }
673            _ => {
674                let dt = Distribution::target_type(&[left_p, right_p])?;
675                BernoulliDistribution::try_new(ScalarValue::try_from(&dt)?)
676            }
677        },
678        _ => {
679            not_impl_err!("Statistical evaluation only supports AND and OR operators")
680        }
681    }
682}
683
684/// Applies the given operation to the given Gaussian distributions. Currently,
685/// this function handles only addition and subtraction operations. If the
686/// result is not a Gaussian random variable, it returns `None`. For details,
687/// see:
688///
689/// <https://en.wikipedia.org/wiki/Sum_of_normally_distributed_random_variables>
690#[deprecated(
691    since = "54.0.0",
692    note = "Part of the unused Statistics V2 framework; see https://github.com/apache/datafusion/pull/22071"
693)]
694pub fn combine_gaussians(
695    op: &Operator,
696    left: &GaussianDistribution,
697    right: &GaussianDistribution,
698) -> Result<Option<GaussianDistribution>> {
699    match op {
700        Operator::Plus => GaussianDistribution::try_new(
701            left.mean().add_checked(right.mean())?,
702            left.variance().add_checked(right.variance())?,
703        )
704        .map(Some),
705        Operator::Minus => GaussianDistribution::try_new(
706            left.mean().sub_checked(right.mean())?,
707            left.variance().add_checked(right.variance())?,
708        )
709        .map(Some),
710        _ => Ok(None),
711    }
712}
713
714/// Creates a new `Bernoulli` distribution by computing the resulting probability.
715/// Expects `op` to be a comparison operator, with `left` and `right` having
716/// numeric distributions. The resulting distribution has the `Float64` data
717/// type.
718#[deprecated(
719    since = "54.0.0",
720    note = "Part of the unused Statistics V2 framework; see https://github.com/apache/datafusion/pull/22071"
721)]
722pub fn create_bernoulli_from_comparison(
723    op: &Operator,
724    left: &Distribution,
725    right: &Distribution,
726) -> Result<Distribution> {
727    match (left, right) {
728        (Uniform(left), Uniform(right)) => {
729            match op {
730                Operator::Eq | Operator::NotEq => {
731                    let (li, ri) = (left.range(), right.range());
732                    if let Some(intersection) = li.intersect(ri)? {
733                        // If the ranges are not disjoint, calculate the probability
734                        // of equality using cardinalities:
735                        if let (Some(lc), Some(rc), Some(ic)) = (
736                            li.cardinality(),
737                            ri.cardinality(),
738                            intersection.cardinality(),
739                        ) {
740                            // Avoid overflow by widening the type temporarily:
741                            let pairs = ((lc as u128) * (rc as u128)) as f64;
742                            let p = (ic as f64).div_checked(pairs)?;
743                            // Alternative approach that may be more stable:
744                            // let p = (ic as f64)
745                            //     .div_checked(lc as f64)?
746                            //     .div_checked(rc as f64)?;
747
748                            let mut p_value = ScalarValue::from(p);
749                            if op == &Operator::NotEq {
750                                let one = ScalarValue::from(1.0);
751                                p_value = alter_fp_rounding_mode::<false, _>(
752                                    &one,
753                                    &p_value,
754                                    |lhs, rhs| lhs.sub_checked(rhs),
755                                )?;
756                            };
757                            return Distribution::new_bernoulli(p_value);
758                        }
759                    } else if op == &Operator::Eq {
760                        // If the ranges are disjoint, probability of equality is 0.
761                        return Distribution::new_bernoulli(ScalarValue::from(0.0));
762                    } else {
763                        // If the ranges are disjoint, probability of not-equality is 1.
764                        return Distribution::new_bernoulli(ScalarValue::from(1.0));
765                    }
766                }
767                Operator::Lt | Operator::LtEq | Operator::Gt | Operator::GtEq => {
768                    // TODO: We can handle inequality operators and calculate a
769                    // `p` value instead of falling back to an unknown Bernoulli
770                    // distribution. Note that the strict and non-strict inequalities
771                    // may require slightly different logic in case of real vs.
772                    // integral data types.
773                }
774                _ => {}
775            }
776        }
777        (Gaussian(_), Gaussian(_)) => {
778            // TODO: We can handle Gaussian comparisons and calculate a `p` value
779            //       instead of falling back to an unknown Bernoulli distribution.
780        }
781        _ => {}
782    }
783    let (li, ri) = (left.range()?, right.range()?);
784    let range_evaluation = apply_operator(op, &li, &ri)?;
785    if range_evaluation.eq(&Interval::FALSE) {
786        Distribution::new_bernoulli(ScalarValue::from(0.0))
787    } else if range_evaluation.eq(&Interval::TRUE) {
788        Distribution::new_bernoulli(ScalarValue::from(1.0))
789    } else if range_evaluation.eq(&Interval::TRUE_OR_FALSE) {
790        Distribution::new_bernoulli(ScalarValue::try_from(&DataType::Float64)?)
791    } else {
792        internal_err!("This function must be called with a comparison operator")
793    }
794}
795
796/// Creates a new [`Generic`] distribution that represents the result of the
797/// given binary operation on two unknown quantities represented by their
798/// [`Distribution`] objects. The function computes the mean, median and
799/// variance if possible.
800#[deprecated(
801    since = "54.0.0",
802    note = "Part of the unused Statistics V2 framework; see https://github.com/apache/datafusion/pull/22071"
803)]
804pub fn new_generic_from_binary_op(
805    op: &Operator,
806    left: &Distribution,
807    right: &Distribution,
808) -> Result<Distribution> {
809    Distribution::new_generic(
810        compute_mean(op, left, right)?,
811        compute_median(op, left, right)?,
812        compute_variance(op, left, right)?,
813        apply_operator(op, &left.range()?, &right.range()?)?,
814    )
815}
816
817/// Computes the mean value for the result of the given binary operation on
818/// two unknown quantities represented by their [`Distribution`] objects.
819#[deprecated(
820    since = "54.0.0",
821    note = "Part of the unused Statistics V2 framework; see https://github.com/apache/datafusion/pull/22071"
822)]
823pub fn compute_mean(
824    op: &Operator,
825    left: &Distribution,
826    right: &Distribution,
827) -> Result<ScalarValue> {
828    let (left_mean, right_mean) = (left.mean()?, right.mean()?);
829
830    match op {
831        Operator::Plus => return left_mean.add_checked(right_mean),
832        Operator::Minus => return left_mean.sub_checked(right_mean),
833        // Note the independence assumption below:
834        Operator::Multiply => return left_mean.mul_checked(right_mean),
835        // TODO: We can calculate the mean for division when we support reciprocals,
836        // or know the distributions of the operands. For details, see:
837        //
838        // <https://en.wikipedia.org/wiki/Algebra_of_random_variables>
839        // <https://stats.stackexchange.com/questions/185683/distribution-of-ratio-between-two-independent-uniform-random-variables>
840        //
841        // Fall back to an unknown mean value for division:
842        Operator::Divide => {}
843        // Fall back to an unknown mean value for other cases:
844        _ => {}
845    }
846    let target_type = Distribution::target_type(&[&left_mean, &right_mean])?;
847    ScalarValue::try_from(target_type)
848}
849
850/// Computes the median value for the result of the given binary operation on
851/// two unknown quantities represented by its [`Distribution`] objects. Currently,
852/// the median is calculable only for addition and subtraction operations on:
853/// - [`Uniform`] and [`Uniform`] distributions, and
854/// - [`Gaussian`] and [`Gaussian`] distributions.
855#[deprecated(
856    since = "54.0.0",
857    note = "Part of the unused Statistics V2 framework; see https://github.com/apache/datafusion/pull/22071"
858)]
859pub fn compute_median(
860    op: &Operator,
861    left: &Distribution,
862    right: &Distribution,
863) -> Result<ScalarValue> {
864    match (left, right) {
865        (Uniform(lu), Uniform(ru)) => {
866            let (left_median, right_median) = (lu.median()?, ru.median()?);
867            // Under the independence assumption, the result is a symmetric
868            // triangular distribution, so we can simply add/subtract the
869            // median values:
870            match op {
871                Operator::Plus => return left_median.add_checked(right_median),
872                Operator::Minus => return left_median.sub_checked(right_median),
873                // Fall back to an unknown median value for other cases:
874                _ => {}
875            }
876        }
877        // Under the independence assumption, the result is another Gaussian
878        // distribution, so we can simply add/subtract the median values:
879        (Gaussian(lg), Gaussian(rg)) => match op {
880            Operator::Plus => return lg.mean().add_checked(rg.mean()),
881            Operator::Minus => return lg.mean().sub_checked(rg.mean()),
882            // Fall back to an unknown median value for other cases:
883            _ => {}
884        },
885        // Fall back to an unknown median value for other cases:
886        _ => {}
887    }
888
889    let (left_median, right_median) = (left.median()?, right.median()?);
890    let target_type = Distribution::target_type(&[&left_median, &right_median])?;
891    ScalarValue::try_from(target_type)
892}
893
894/// Computes the variance value for the result of the given binary operation on
895/// two unknown quantities represented by their [`Distribution`] objects.
896#[deprecated(
897    since = "54.0.0",
898    note = "Part of the unused Statistics V2 framework; see https://github.com/apache/datafusion/pull/22071"
899)]
900pub fn compute_variance(
901    op: &Operator,
902    left: &Distribution,
903    right: &Distribution,
904) -> Result<ScalarValue> {
905    let (left_variance, right_variance) = (left.variance()?, right.variance()?);
906
907    match op {
908        // Note the independence assumption below:
909        Operator::Plus => return left_variance.add_checked(right_variance),
910        // Note the independence assumption below:
911        Operator::Minus => return left_variance.add_checked(right_variance),
912        // Note the independence assumption below:
913        Operator::Multiply => {
914            // For more details, along with an explanation of the formula below, see:
915            //
916            // <https://en.wikipedia.org/wiki/Distribution_of_the_product_of_two_random_variables>
917            let (left_mean, right_mean) = (left.mean()?, right.mean()?);
918            let left_mean_sq = left_mean.mul_checked(&left_mean)?;
919            let right_mean_sq = right_mean.mul_checked(&right_mean)?;
920            let left_sos = left_variance.add_checked(&left_mean_sq)?;
921            let right_sos = right_variance.add_checked(&right_mean_sq)?;
922            let pos = left_mean_sq.mul_checked(right_mean_sq)?;
923            return left_sos.mul_checked(right_sos)?.sub_checked(pos);
924        }
925        // TODO: We can calculate the variance for division when we support reciprocals,
926        // or know the distributions of the operands. For details, see:
927        //
928        // <https://en.wikipedia.org/wiki/Algebra_of_random_variables>
929        // <https://stats.stackexchange.com/questions/185683/distribution-of-ratio-between-two-independent-uniform-random-variables>
930        //
931        // Fall back to an unknown variance value for division:
932        Operator::Divide => {}
933        // Fall back to an unknown variance value for other cases:
934        _ => {}
935    }
936    let target_type = Distribution::target_type(&[&left_variance, &right_variance])?;
937    ScalarValue::try_from(target_type)
938}
939
940#[cfg(test)]
941mod tests {
942    use super::{
943        BernoulliDistribution, Distribution, GaussianDistribution, UniformDistribution,
944        combine_bernoullis, combine_gaussians, compute_mean, compute_median,
945        compute_variance, create_bernoulli_from_comparison, new_generic_from_binary_op,
946    };
947    use crate::interval_arithmetic::{Interval, apply_operator};
948    use crate::operator::Operator;
949
950    use arrow::datatypes::DataType;
951    use datafusion_common::{HashSet, Result, ScalarValue};
952
953    #[test]
954    fn uniform_dist_is_valid_test() -> Result<()> {
955        assert_eq!(
956            Distribution::new_uniform(Interval::make_zero(&DataType::Int8)?)?,
957            Distribution::Uniform(UniformDistribution {
958                interval: Interval::make_zero(&DataType::Int8)?,
959            })
960        );
961
962        assert!(Distribution::new_uniform(Interval::TRUE_OR_FALSE).is_err());
963        Ok(())
964    }
965
966    #[test]
967    fn exponential_dist_is_valid_test() {
968        // This array collects test cases of the form (distribution, validity).
969        let exponentials = vec![
970            (
971                Distribution::new_exponential(ScalarValue::Null, ScalarValue::Null, true),
972                false,
973            ),
974            (
975                Distribution::new_exponential(
976                    ScalarValue::from(0_f32),
977                    ScalarValue::from(1_f32),
978                    true,
979                ),
980                false,
981            ),
982            (
983                Distribution::new_exponential(
984                    ScalarValue::from(100_f32),
985                    ScalarValue::from(1_f32),
986                    true,
987                ),
988                true,
989            ),
990            (
991                Distribution::new_exponential(
992                    ScalarValue::from(-100_f32),
993                    ScalarValue::from(1_f32),
994                    true,
995                ),
996                false,
997            ),
998        ];
999        for case in exponentials {
1000            assert_eq!(case.0.is_ok(), case.1);
1001        }
1002    }
1003
1004    #[test]
1005    fn gaussian_dist_is_valid_test() {
1006        // This array collects test cases of the form (distribution, validity).
1007        let gaussians = vec![
1008            (
1009                Distribution::new_gaussian(ScalarValue::Null, ScalarValue::Null),
1010                false,
1011            ),
1012            (
1013                Distribution::new_gaussian(
1014                    ScalarValue::from(0_f32),
1015                    ScalarValue::from(0_f32),
1016                ),
1017                true,
1018            ),
1019            (
1020                Distribution::new_gaussian(
1021                    ScalarValue::from(0_f32),
1022                    ScalarValue::from(0.5_f32),
1023                ),
1024                true,
1025            ),
1026            (
1027                Distribution::new_gaussian(
1028                    ScalarValue::from(0_f32),
1029                    ScalarValue::from(-0.5_f32),
1030                ),
1031                false,
1032            ),
1033        ];
1034        for case in gaussians {
1035            assert_eq!(case.0.is_ok(), case.1);
1036        }
1037    }
1038
1039    #[test]
1040    fn bernoulli_dist_is_valid_test() {
1041        // This array collects test cases of the form (distribution, validity).
1042        let bernoullis = vec![
1043            (Distribution::new_bernoulli(ScalarValue::Null), true),
1044            (Distribution::new_bernoulli(ScalarValue::from(0.)), true),
1045            (Distribution::new_bernoulli(ScalarValue::from(0.25)), true),
1046            (Distribution::new_bernoulli(ScalarValue::from(1.)), true),
1047            (Distribution::new_bernoulli(ScalarValue::from(11.)), false),
1048            (Distribution::new_bernoulli(ScalarValue::from(-11.)), false),
1049            (Distribution::new_bernoulli(ScalarValue::from(0_i64)), true),
1050            (Distribution::new_bernoulli(ScalarValue::from(1_i64)), true),
1051            (
1052                Distribution::new_bernoulli(ScalarValue::from(11_i64)),
1053                false,
1054            ),
1055            (
1056                Distribution::new_bernoulli(ScalarValue::from(-11_i64)),
1057                false,
1058            ),
1059        ];
1060        for case in bernoullis {
1061            assert_eq!(case.0.is_ok(), case.1);
1062        }
1063    }
1064
1065    #[test]
1066    fn generic_dist_is_valid_test() -> Result<()> {
1067        // This array collects test cases of the form (distribution, validity).
1068        let generic_dists = vec![
1069            // Using a boolean range to construct a Generic distribution is prohibited.
1070            (
1071                Distribution::new_generic(
1072                    ScalarValue::Null,
1073                    ScalarValue::Null,
1074                    ScalarValue::Null,
1075                    Interval::TRUE_OR_FALSE,
1076                ),
1077                false,
1078            ),
1079            (
1080                Distribution::new_generic(
1081                    ScalarValue::Null,
1082                    ScalarValue::Null,
1083                    ScalarValue::Null,
1084                    Interval::make_zero(&DataType::Float32)?,
1085                ),
1086                true,
1087            ),
1088            (
1089                Distribution::new_generic(
1090                    ScalarValue::from(0_f32),
1091                    ScalarValue::Float32(None),
1092                    ScalarValue::Float32(None),
1093                    Interval::make_zero(&DataType::Float32)?,
1094                ),
1095                true,
1096            ),
1097            (
1098                Distribution::new_generic(
1099                    ScalarValue::Float64(None),
1100                    ScalarValue::from(0.),
1101                    ScalarValue::Float64(None),
1102                    Interval::make_zero(&DataType::Float32)?,
1103                ),
1104                true,
1105            ),
1106            (
1107                Distribution::new_generic(
1108                    ScalarValue::from(-10_f32),
1109                    ScalarValue::Float32(None),
1110                    ScalarValue::Float32(None),
1111                    Interval::make_zero(&DataType::Float32)?,
1112                ),
1113                false,
1114            ),
1115            (
1116                Distribution::new_generic(
1117                    ScalarValue::Float32(None),
1118                    ScalarValue::from(10_f32),
1119                    ScalarValue::Float32(None),
1120                    Interval::make_zero(&DataType::Float32)?,
1121                ),
1122                false,
1123            ),
1124            (
1125                Distribution::new_generic(
1126                    ScalarValue::Null,
1127                    ScalarValue::Null,
1128                    ScalarValue::Null,
1129                    Interval::make_zero(&DataType::Float32)?,
1130                ),
1131                true,
1132            ),
1133            (
1134                Distribution::new_generic(
1135                    ScalarValue::from(0),
1136                    ScalarValue::from(0),
1137                    ScalarValue::Int32(None),
1138                    Interval::make_zero(&DataType::Int32)?,
1139                ),
1140                true,
1141            ),
1142            (
1143                Distribution::new_generic(
1144                    ScalarValue::from(0_f32),
1145                    ScalarValue::from(0_f32),
1146                    ScalarValue::Float32(None),
1147                    Interval::make_zero(&DataType::Float32)?,
1148                ),
1149                true,
1150            ),
1151            (
1152                Distribution::new_generic(
1153                    ScalarValue::from(50.),
1154                    ScalarValue::from(50.),
1155                    ScalarValue::Float64(None),
1156                    Interval::make(Some(0.), Some(100.))?,
1157                ),
1158                true,
1159            ),
1160            (
1161                Distribution::new_generic(
1162                    ScalarValue::from(50.),
1163                    ScalarValue::from(50.),
1164                    ScalarValue::Float64(None),
1165                    Interval::make(Some(-100.), Some(0.))?,
1166                ),
1167                false,
1168            ),
1169            (
1170                Distribution::new_generic(
1171                    ScalarValue::Float64(None),
1172                    ScalarValue::Float64(None),
1173                    ScalarValue::from(1.),
1174                    Interval::make_zero(&DataType::Float64)?,
1175                ),
1176                true,
1177            ),
1178            (
1179                Distribution::new_generic(
1180                    ScalarValue::Float64(None),
1181                    ScalarValue::Float64(None),
1182                    ScalarValue::from(-1.),
1183                    Interval::make_zero(&DataType::Float64)?,
1184                ),
1185                false,
1186            ),
1187        ];
1188        for case in generic_dists {
1189            assert_eq!(case.0.is_ok(), case.1, "{:?}", case.0);
1190        }
1191
1192        Ok(())
1193    }
1194
1195    #[test]
1196    fn mean_extraction_test() -> Result<()> {
1197        // This array collects test cases of the form (distribution, mean value).
1198        let dists = vec![
1199            (
1200                Distribution::new_uniform(Interval::make_zero(&DataType::Int64)?),
1201                ScalarValue::from(0_i64),
1202            ),
1203            (
1204                Distribution::new_uniform(Interval::make_zero(&DataType::Float64)?),
1205                ScalarValue::from(0.),
1206            ),
1207            (
1208                Distribution::new_uniform(Interval::make(Some(1), Some(100))?),
1209                ScalarValue::from(50),
1210            ),
1211            (
1212                Distribution::new_uniform(Interval::make(Some(-100), Some(-1))?),
1213                ScalarValue::from(-50),
1214            ),
1215            (
1216                Distribution::new_uniform(Interval::make(Some(-100), Some(100))?),
1217                ScalarValue::from(0),
1218            ),
1219            (
1220                Distribution::new_exponential(
1221                    ScalarValue::from(2.),
1222                    ScalarValue::from(0.),
1223                    true,
1224                ),
1225                ScalarValue::from(0.5),
1226            ),
1227            (
1228                Distribution::new_exponential(
1229                    ScalarValue::from(2.),
1230                    ScalarValue::from(1.),
1231                    true,
1232                ),
1233                ScalarValue::from(1.5),
1234            ),
1235            (
1236                Distribution::new_gaussian(ScalarValue::from(0.), ScalarValue::from(1.)),
1237                ScalarValue::from(0.),
1238            ),
1239            (
1240                Distribution::new_gaussian(
1241                    ScalarValue::from(-2.),
1242                    ScalarValue::from(0.5),
1243                ),
1244                ScalarValue::from(-2.),
1245            ),
1246            (
1247                Distribution::new_bernoulli(ScalarValue::from(0.5)),
1248                ScalarValue::from(0.5),
1249            ),
1250            (
1251                Distribution::new_generic(
1252                    ScalarValue::from(42.),
1253                    ScalarValue::from(42.),
1254                    ScalarValue::Float64(None),
1255                    Interval::make(Some(25.), Some(50.))?,
1256                ),
1257                ScalarValue::from(42.),
1258            ),
1259        ];
1260
1261        for case in dists {
1262            assert_eq!(case.0?.mean()?, case.1);
1263        }
1264
1265        Ok(())
1266    }
1267
1268    #[test]
1269    fn median_extraction_test() -> Result<()> {
1270        // This array collects test cases of the form (distribution, median value).
1271        let dists = vec![
1272            (
1273                Distribution::new_uniform(Interval::make_zero(&DataType::Int64)?),
1274                ScalarValue::from(0_i64),
1275            ),
1276            (
1277                Distribution::new_uniform(Interval::make(Some(25.), Some(75.))?),
1278                ScalarValue::from(50.),
1279            ),
1280            (
1281                Distribution::new_exponential(
1282                    ScalarValue::from(2_f64.ln()),
1283                    ScalarValue::from(0.),
1284                    true,
1285                ),
1286                ScalarValue::from(1.),
1287            ),
1288            (
1289                Distribution::new_gaussian(ScalarValue::from(2.), ScalarValue::from(1.)),
1290                ScalarValue::from(2.),
1291            ),
1292            (
1293                Distribution::new_bernoulli(ScalarValue::from(0.25)),
1294                ScalarValue::from(0.),
1295            ),
1296            (
1297                Distribution::new_bernoulli(ScalarValue::from(0.75)),
1298                ScalarValue::from(1.),
1299            ),
1300            (
1301                Distribution::new_gaussian(ScalarValue::from(2.), ScalarValue::from(1.)),
1302                ScalarValue::from(2.),
1303            ),
1304            (
1305                Distribution::new_generic(
1306                    ScalarValue::from(12.),
1307                    ScalarValue::from(12.),
1308                    ScalarValue::Float64(None),
1309                    Interval::make(Some(0.), Some(25.))?,
1310                ),
1311                ScalarValue::from(12.),
1312            ),
1313        ];
1314
1315        for case in dists {
1316            assert_eq!(case.0?.median()?, case.1);
1317        }
1318
1319        Ok(())
1320    }
1321
1322    #[test]
1323    fn variance_extraction_test() -> Result<()> {
1324        // This array collects test cases of the form (distribution, variance value).
1325        let dists = vec![
1326            (
1327                Distribution::new_uniform(Interval::make(Some(0.), Some(12.))?),
1328                ScalarValue::from(12.),
1329            ),
1330            (
1331                Distribution::new_exponential(
1332                    ScalarValue::from(10.),
1333                    ScalarValue::from(0.),
1334                    true,
1335                ),
1336                ScalarValue::from(0.01),
1337            ),
1338            (
1339                Distribution::new_gaussian(ScalarValue::from(0.), ScalarValue::from(1.)),
1340                ScalarValue::from(1.),
1341            ),
1342            (
1343                Distribution::new_bernoulli(ScalarValue::from(0.5)),
1344                ScalarValue::from(0.25),
1345            ),
1346            (
1347                Distribution::new_generic(
1348                    ScalarValue::Float64(None),
1349                    ScalarValue::Float64(None),
1350                    ScalarValue::from(0.02),
1351                    Interval::make_zero(&DataType::Float64)?,
1352                ),
1353                ScalarValue::from(0.02),
1354            ),
1355        ];
1356
1357        for case in dists {
1358            assert_eq!(case.0?.variance()?, case.1);
1359        }
1360
1361        Ok(())
1362    }
1363
1364    #[test]
1365    fn test_calculate_generic_properties_gauss_gauss() -> Result<()> {
1366        let dist_a =
1367            Distribution::new_gaussian(ScalarValue::from(10.), ScalarValue::from(0.0))?;
1368        let dist_b =
1369            Distribution::new_gaussian(ScalarValue::from(20.), ScalarValue::from(0.0))?;
1370
1371        let test_data = vec![
1372            // Mean:
1373            (
1374                compute_mean(&Operator::Plus, &dist_a, &dist_b)?,
1375                ScalarValue::from(30.),
1376            ),
1377            (
1378                compute_mean(&Operator::Minus, &dist_a, &dist_b)?,
1379                ScalarValue::from(-10.),
1380            ),
1381            // Median:
1382            (
1383                compute_median(&Operator::Plus, &dist_a, &dist_b)?,
1384                ScalarValue::from(30.),
1385            ),
1386            (
1387                compute_median(&Operator::Minus, &dist_a, &dist_b)?,
1388                ScalarValue::from(-10.),
1389            ),
1390        ];
1391        for (actual, expected) in test_data {
1392            assert_eq!(actual, expected);
1393        }
1394
1395        Ok(())
1396    }
1397
1398    #[test]
1399    fn test_combine_bernoullis_and_op() -> Result<()> {
1400        let op = Operator::And;
1401        let left = BernoulliDistribution::try_new(ScalarValue::from(0.5))?;
1402        let right = BernoulliDistribution::try_new(ScalarValue::from(0.4))?;
1403        let left_null = BernoulliDistribution::try_new(ScalarValue::Null)?;
1404        let right_null = BernoulliDistribution::try_new(ScalarValue::Null)?;
1405
1406        assert_eq!(
1407            combine_bernoullis(&op, &left, &right)?.p_value(),
1408            &ScalarValue::from(0.5 * 0.4)
1409        );
1410        assert_eq!(
1411            combine_bernoullis(&op, &left_null, &right)?.p_value(),
1412            &ScalarValue::Float64(None)
1413        );
1414        assert_eq!(
1415            combine_bernoullis(&op, &left, &right_null)?.p_value(),
1416            &ScalarValue::Float64(None)
1417        );
1418        assert_eq!(
1419            combine_bernoullis(&op, &left_null, &left_null)?.p_value(),
1420            &ScalarValue::Null
1421        );
1422
1423        Ok(())
1424    }
1425
1426    #[test]
1427    fn test_combine_bernoullis_or_op() -> Result<()> {
1428        let op = Operator::Or;
1429        let left = BernoulliDistribution::try_new(ScalarValue::from(0.6))?;
1430        let right = BernoulliDistribution::try_new(ScalarValue::from(0.4))?;
1431        let left_null = BernoulliDistribution::try_new(ScalarValue::Null)?;
1432        let right_null = BernoulliDistribution::try_new(ScalarValue::Null)?;
1433
1434        assert_eq!(
1435            combine_bernoullis(&op, &left, &right)?.p_value(),
1436            &ScalarValue::from(0.6 + 0.4 - (0.6 * 0.4))
1437        );
1438        assert_eq!(
1439            combine_bernoullis(&op, &left_null, &right)?.p_value(),
1440            &ScalarValue::Float64(None)
1441        );
1442        assert_eq!(
1443            combine_bernoullis(&op, &left, &right_null)?.p_value(),
1444            &ScalarValue::Float64(None)
1445        );
1446        assert_eq!(
1447            combine_bernoullis(&op, &left_null, &left_null)?.p_value(),
1448            &ScalarValue::Null
1449        );
1450
1451        Ok(())
1452    }
1453
1454    #[test]
1455    fn test_combine_bernoullis_unsupported_ops() -> Result<()> {
1456        let mut operator_set = operator_set();
1457        operator_set.remove(&Operator::And);
1458        operator_set.remove(&Operator::Or);
1459
1460        let left = BernoulliDistribution::try_new(ScalarValue::from(0.6))?;
1461        let right = BernoulliDistribution::try_new(ScalarValue::from(0.4))?;
1462        for op in operator_set {
1463            assert!(
1464                combine_bernoullis(&op, &left, &right).is_err(),
1465                "Operator {op} should not be supported for Bernoulli distributions"
1466            );
1467        }
1468
1469        Ok(())
1470    }
1471
1472    #[test]
1473    fn test_combine_gaussians_addition() -> Result<()> {
1474        let op = Operator::Plus;
1475        let left = GaussianDistribution::try_new(
1476            ScalarValue::from(3.0),
1477            ScalarValue::from(2.0),
1478        )?;
1479        let right = GaussianDistribution::try_new(
1480            ScalarValue::from(4.0),
1481            ScalarValue::from(1.0),
1482        )?;
1483
1484        let result = combine_gaussians(&op, &left, &right)?.unwrap();
1485
1486        assert_eq!(result.mean(), &ScalarValue::from(7.0)); // 3.0 + 4.0
1487        assert_eq!(result.variance(), &ScalarValue::from(3.0)); // 2.0 + 1.0
1488        Ok(())
1489    }
1490
1491    #[test]
1492    fn test_combine_gaussians_subtraction() -> Result<()> {
1493        let op = Operator::Minus;
1494        let left = GaussianDistribution::try_new(
1495            ScalarValue::from(7.0),
1496            ScalarValue::from(2.0),
1497        )?;
1498        let right = GaussianDistribution::try_new(
1499            ScalarValue::from(4.0),
1500            ScalarValue::from(1.0),
1501        )?;
1502
1503        let result = combine_gaussians(&op, &left, &right)?.unwrap();
1504
1505        assert_eq!(result.mean(), &ScalarValue::from(3.0)); // 7.0 - 4.0
1506        assert_eq!(result.variance(), &ScalarValue::from(3.0)); // 2.0 + 1.0
1507
1508        Ok(())
1509    }
1510
1511    #[test]
1512    fn test_combine_gaussians_unsupported_ops() -> Result<()> {
1513        let mut operator_set = operator_set();
1514        operator_set.remove(&Operator::Plus);
1515        operator_set.remove(&Operator::Minus);
1516
1517        let left = GaussianDistribution::try_new(
1518            ScalarValue::from(7.0),
1519            ScalarValue::from(2.0),
1520        )?;
1521        let right = GaussianDistribution::try_new(
1522            ScalarValue::from(4.0),
1523            ScalarValue::from(1.0),
1524        )?;
1525        for op in operator_set {
1526            assert!(
1527                combine_gaussians(&op, &left, &right)?.is_none(),
1528                "Operator {op} should not be supported for Gaussian distributions"
1529            );
1530        }
1531
1532        Ok(())
1533    }
1534
1535    // Expected test results were calculated in Wolfram Mathematica, by using:
1536    //
1537    // *METHOD_NAME*[TransformedDistribution[
1538    //  x *op* y,
1539    //  {x ~ *DISTRIBUTION_X*[..], y ~ *DISTRIBUTION_Y*[..]}
1540    // ]]
1541    #[test]
1542    fn test_calculate_generic_properties_uniform_uniform() -> Result<()> {
1543        let dist_a = Distribution::new_uniform(Interval::make(Some(0.), Some(12.))?)?;
1544        let dist_b = Distribution::new_uniform(Interval::make(Some(12.), Some(36.))?)?;
1545
1546        let test_data = vec![
1547            // Mean:
1548            (
1549                compute_mean(&Operator::Plus, &dist_a, &dist_b)?,
1550                ScalarValue::from(30.),
1551            ),
1552            (
1553                compute_mean(&Operator::Minus, &dist_a, &dist_b)?,
1554                ScalarValue::from(-18.),
1555            ),
1556            (
1557                compute_mean(&Operator::Multiply, &dist_a, &dist_b)?,
1558                ScalarValue::from(144.),
1559            ),
1560            // Median:
1561            (
1562                compute_median(&Operator::Plus, &dist_a, &dist_b)?,
1563                ScalarValue::from(30.),
1564            ),
1565            (
1566                compute_median(&Operator::Minus, &dist_a, &dist_b)?,
1567                ScalarValue::from(-18.),
1568            ),
1569            // Variance:
1570            (
1571                compute_variance(&Operator::Plus, &dist_a, &dist_b)?,
1572                ScalarValue::from(60.),
1573            ),
1574            (
1575                compute_variance(&Operator::Minus, &dist_a, &dist_b)?,
1576                ScalarValue::from(60.),
1577            ),
1578            (
1579                compute_variance(&Operator::Multiply, &dist_a, &dist_b)?,
1580                ScalarValue::from(9216.),
1581            ),
1582        ];
1583        for (actual, expected) in test_data {
1584            assert_eq!(actual, expected);
1585        }
1586
1587        Ok(())
1588    }
1589
1590    /// Test for `Uniform`-`Uniform`, `Uniform`-`Generic`, `Generic`-`Uniform`,
1591    /// `Generic`-`Generic` pairs, where range is always present.
1592    #[test]
1593    fn test_compute_range_where_present() -> Result<()> {
1594        let a = &Interval::make(Some(0.), Some(12.0))?;
1595        let b = &Interval::make(Some(0.), Some(12.0))?;
1596        let mean = ScalarValue::from(6.0);
1597        for (dist_a, dist_b) in [
1598            (
1599                Distribution::new_uniform(a.clone())?,
1600                Distribution::new_uniform(b.clone())?,
1601            ),
1602            (
1603                Distribution::new_generic(
1604                    mean.clone(),
1605                    mean.clone(),
1606                    ScalarValue::Float64(None),
1607                    a.clone(),
1608                )?,
1609                Distribution::new_uniform(b.clone())?,
1610            ),
1611            (
1612                Distribution::new_uniform(a.clone())?,
1613                Distribution::new_generic(
1614                    mean.clone(),
1615                    mean.clone(),
1616                    ScalarValue::Float64(None),
1617                    b.clone(),
1618                )?,
1619            ),
1620            (
1621                Distribution::new_generic(
1622                    mean.clone(),
1623                    mean.clone(),
1624                    ScalarValue::Float64(None),
1625                    a.clone(),
1626                )?,
1627                Distribution::new_generic(
1628                    mean.clone(),
1629                    mean.clone(),
1630                    ScalarValue::Float64(None),
1631                    b.clone(),
1632                )?,
1633            ),
1634        ] {
1635            use super::Operator::{
1636                Divide, Eq, Gt, GtEq, Lt, LtEq, Minus, Multiply, NotEq, Plus,
1637            };
1638            for op in [Plus, Minus, Multiply, Divide] {
1639                assert_eq!(
1640                    new_generic_from_binary_op(&op, &dist_a, &dist_b)?.range()?,
1641                    apply_operator(&op, a, b)?,
1642                    "Failed for {dist_a:?} {op} {dist_b:?}"
1643                );
1644            }
1645            for op in [Gt, GtEq, Lt, LtEq, Eq, NotEq] {
1646                assert_eq!(
1647                    create_bernoulli_from_comparison(&op, &dist_a, &dist_b)?.range()?,
1648                    apply_operator(&op, a, b)?,
1649                    "Failed for {dist_a:?} {op} {dist_b:?}"
1650                );
1651            }
1652        }
1653
1654        Ok(())
1655    }
1656
1657    fn operator_set() -> HashSet<Operator> {
1658        use super::Operator::*;
1659
1660        let all_ops = vec![
1661            And,
1662            Or,
1663            Eq,
1664            NotEq,
1665            Gt,
1666            GtEq,
1667            Lt,
1668            LtEq,
1669            Plus,
1670            Minus,
1671            Multiply,
1672            Divide,
1673            Modulo,
1674            IsDistinctFrom,
1675            IsNotDistinctFrom,
1676            RegexMatch,
1677            RegexIMatch,
1678            RegexNotMatch,
1679            RegexNotIMatch,
1680            LikeMatch,
1681            ILikeMatch,
1682            NotLikeMatch,
1683            NotILikeMatch,
1684            BitwiseAnd,
1685            BitwiseOr,
1686            BitwiseXor,
1687            BitwiseShiftRight,
1688            BitwiseShiftLeft,
1689            StringConcat,
1690            AtArrow,
1691            ArrowAt,
1692        ];
1693
1694        all_ops.into_iter().collect()
1695    }
1696}