datafusion_statrs/distribution/
binomial.rs

//! Module containing functions for the Binomial Distribution.
2//! 
3//! Implemented by [`statrs::distribution::Binomial`].
4//! 
5//! The [Binomial Distribution](https://en.wikipedia.org/wiki/Binomial_distribution) has two
6//! parameters:
7//! 
//! n: n ∈ ℕ₀ (non-negative integers)  
//! p: 0 ≤ p ≤ 1
10//! 
11//! Usage:
12//! 
13//! `binomial_pmf(x, n, p)`  
14//! `binomial_ln_pmf(x, n, p)`  
15//! `binomial_cdf(x, n, p)`  
16//! `binomial_sf(x, n, p)`
17//! 
18//! with
19//! 
20//!   `x`: 0 ≤ x ≤ n `UInt64`/`BIGINT UNSIGNED`,  
21//!   `n`: 0 ≤ n `UInt64`/`BIGINT UNSIGNED`,  
22//!   `p`: [0, 1] `Float64`/`DOUBLE`
23//! 
//! # Examples
25//! ```
26//! #[tokio::main(flavor = "current_thread")]
27//! async fn main() -> std::io::Result<()> {
28//!     let mut ctx = datafusion::prelude::SessionContext::new();
29//!     datafusion_statrs::distribution::binomial::register(&mut ctx)?;
30//!     ctx.sql("SELECT binomial_cdf(CAST(2 AS BIGINT UNSIGNED), CAST(5 AS BIGINT UNSIGNED), 0.2)").await?
31//!        .show().await?;
32//!     Ok(())
33//! }
34//! ```
35
36use datafusion::error::DataFusionError;
37use datafusion::execution::FunctionRegistry;
38use datafusion::logical_expr::ScalarUDF;
39use statrs::distribution::Binomial;
40
41use crate::utils::discrete2u1f::Discrete2U1F;
42use crate::utils::evaluator2u1f::{CdfEvaluator2U1F, LnPmfEvaluator2U1F, PmfEvaluator2U1F, SfEvaluator2U1F};
43
44type Pmf = Discrete2U1F<PmfEvaluator2U1F<Binomial>>;
45
46/// ScalarUDF for the Binomial Distribution PMF
47pub fn pmf() -> ScalarUDF {
48    ScalarUDF::from(Pmf::new("binomial_pmf"))
49}
50
51type LnPmf = Discrete2U1F<LnPmfEvaluator2U1F<Binomial>>;
52
53/// ScalarUDF for the Binomial Distribution PMF
54pub fn ln_pmf() -> ScalarUDF {
55    ScalarUDF::from(LnPmf::new("binomial_ln_pmf"))
56}
57
58type Cdf = Discrete2U1F<CdfEvaluator2U1F<Binomial>>;
59
60/// ScalarUDF for the Binomial Distribution CDF
61pub fn cdf() -> ScalarUDF {
62    ScalarUDF::from(Cdf::new("binomial_cdf"))
63}
64
65type Sf = Discrete2U1F<SfEvaluator2U1F<Binomial>>;
66
67/// ScalarUDF for the Binomial Distribution SF
68pub fn sf() -> ScalarUDF {
69    ScalarUDF::from(Sf::new("binomial_sf"))
70}
71
72/// Register the functions for the Binomial Distribution
73pub fn register(registry: &mut dyn FunctionRegistry) -> Result<(), DataFusionError> {
74    crate::utils::register::register(registry, vec![pmf(), ln_pmf(), cdf(), sf()])
75}
76
// Tests for the Binomial UDFs, driven both through the DataFrame API and SQL.
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use assert_eq_float::assert_eq_float;
    use datafusion::{
        arrow::{
            array::{Float64Array, RecordBatch, UInt64Array},
            datatypes::{DataType, Field, Schema, SchemaRef},
        },
        common::cast::as_float64_array,
        error::DataFusionError,
        logical_expr::ScalarUDF,
        prelude::{SessionContext, col},
    };
    use statrs::distribution::BinomialError;

    use super::*;

    /// Schema of the test table: nullable `x` and `n` (`UInt64`) and `p` (`Float64`).
    fn get_schema() -> SchemaRef {
        SchemaRef::new(Schema::new(vec![
            Field::new("x", DataType::UInt64, true),
            Field::new("n", DataType::UInt64, true),
            Field::new("p", DataType::Float64, true),
        ]))
    }

    /// Builds a batch with one row per `(x, n, p)` tuple; `None` becomes a null cell.
    fn make_records(rows: Vec<(Option<u64>, Option<u64>, Option<f64>)>) -> RecordBatch {
        let xs: Vec<Option<u64>> = rows.iter().map(|row| row.0).collect();
        let ns: Vec<Option<u64>> = rows.iter().map(|row| row.1).collect();
        let ps: Vec<Option<f64>> = rows.iter().map(|row| row.2).collect();

        RecordBatch::try_new(
            get_schema(),
            vec![
                Arc::new(UInt64Array::from(xs)),
                Arc::new(UInt64Array::from(ns)),
                Arc::new(Float64Array::from(ps)),
            ],
        )
        .unwrap()
    }

    /// Fixture shared by the PMF/CDF/SF tests: `x` = 0 and 1 with `n` = 3 and
    /// `p` = 0.25, followed by one row with a null `x` and one with a null `n`.
    fn sample_records() -> RecordBatch {
        make_records(vec![
            (Some(0), Some(3), Some(0.25)),
            (Some(1), Some(3), Some(0.25)),
            (None, Some(3), Some(0.25)),
            (Some(0), None, Some(0.25)),
        ])
    }

    /// Registers `recs` as table `tbl`, evaluates `udf(x, n, p) AS q` over it and
    /// collects the resulting batches.
    async fn evaluate(
        udf: &ScalarUDF,
        recs: RecordBatch,
    ) -> Result<Vec<RecordBatch>, DataFusionError> {
        let ctx = SessionContext::new();
        ctx.register_batch("tbl", recs)?;
        ctx.table("tbl")
            .await?
            .select(vec![udf.call(vec![col("x"), col("n"), col("p")]).alias("q")])?
            .collect()
            .await
    }

    /// Asserts that `res` is a single batch holding one column of `rows` rows and
    /// returns that column as a `Float64Array`.
    fn single_float_column(res: &[RecordBatch], rows: usize) -> &Float64Array {
        assert_eq!(res.len(), 1);
        assert_eq!(res[0].num_columns(), 1);
        assert_eq!(res[0].num_rows(), rows);
        as_float64_array(res[0].column(0)).unwrap()
    }

    #[tokio::test]
    async fn binomial_pmf_success() {
        let res = evaluate(&pmf(), sample_records()).await.unwrap();

        let q = single_float_column(&res, 4);
        assert_eq_float!(q.value(0), 0.421875);
        assert_eq_float!(q.value(1), 0.421875);
        assert!(q.value(2).is_nan());
        assert!(q.value(3).is_nan());
    }

    #[tokio::test]
    async fn binomial_pmf_failure_1() {
        let recs = make_records(vec![(Some(0), Some(3), Some(1.25))]);

        let err = evaluate(&pmf(), recs)
            .await
            .expect_err("p = 1.25 lies outside [0, 1] and must be rejected");
        match err {
            DataFusionError::External(e) => {
                let be = e.downcast::<BinomialError>().unwrap();
                assert_eq!(*be, BinomialError::ProbabilityInvalid);
            }
            other => panic!("unexpected result: {other:?}"),
        }
    }

    #[tokio::test]
    async fn binomial_ln_pmf_success() {
        let mut ctx = SessionContext::new();
        register(&mut ctx).unwrap();
        let res = ctx
            .sql("SELECT binomial_ln_pmf(CAST(2 AS BIGINT UNSIGNED), CAST(10 AS BIGINT UNSIGNED), 0.5)")
            .await
            .unwrap()
            .collect()
            .await
            .unwrap();

        let q = single_float_column(&res, 1);
        // ln(C(10, 2) / 2^10) = ln(45 / 1024)
        assert_eq_float!(q.value(0), -3.1248093158291335);
    }

    #[tokio::test]
    async fn binomial_cdf_success() {
        let res = evaluate(&cdf(), sample_records()).await.unwrap();

        let q = single_float_column(&res, 4);
        assert_eq_float!(q.value(0), 0.421875, 3e-15);
        assert_eq_float!(q.value(1), 0.84375);
        assert!(q.value(2).is_nan());
        assert!(q.value(3).is_nan());
    }

    #[tokio::test]
    async fn binomial_sf_success() {
        let res = evaluate(&sf(), sample_records()).await.unwrap();

        let q = single_float_column(&res, 4);
        assert_eq_float!(q.value(0), 0.578125, 3e-15);
        assert_eq_float!(q.value(1), 0.15625, 4e-15);
        assert!(q.value(2).is_nan());
        assert!(q.value(3).is_nan());
    }
}