datafusion_statrs/distribution/
negative_binomial.rs

//! Module containing functions for the Negative Binomial Distribution.
2//! 
//! Implemented by [`statrs::distribution::NegativeBinomial`].
4//! 
5//! The [Negative Binomial Distribution](https://en.wikipedia.org/wiki/Negative_binomial_distribution) has two
6//! parameters:
7//! 
8//! r: 0 < r  
9//! p: 0 ≤ p ≤ 1
10//! 
11//! Usage:
12//! 
13//! `negative_binomial_pmf(x, r, p)`  
14//! `negative_binomial_ln_pmf(x, r, p)`  
15//! `negative_binomial_cdf(x, r, p)`  
16//! `negative_binomial_sf(x, r, p)`
17//! 
18//! with
19//! 
//!   `x`: 0 ≤ x `UInt64`/`BIGINT UNSIGNED`,  
21//!   `r`: 0 < r `Float64`/`DOUBLE`,  
22//!   `p`: [0, 1] `Float64`/`DOUBLE`
23//! 
24//! Examples
25//! ```
26//! #[tokio::main(flavor = "current_thread")]
27//! async fn main() -> std::io::Result<()> {
28//!     let mut ctx = datafusion::prelude::SessionContext::new();
29//!     datafusion_statrs::distribution::negative_binomial::register(&mut ctx)?;
30//!     ctx.sql("SELECT negative_binomial_cdf(CAST(2 AS BIGINT UNSIGNED), 5.0, 0.2)").await?
31//!        .show().await?;
32//!     Ok(())
33//! }
34//! ```
35
36use datafusion::error::DataFusionError;
37use datafusion::execution::FunctionRegistry;
38use datafusion::logical_expr::ScalarUDF;
39use statrs::distribution::NegativeBinomial;
40
41use crate::utils::discrete1u2f::Discrete1U2F;
42use crate::utils::evaluator1u2f::{CdfEvaluator1U2F, LnPmfEvaluator1U2F, PmfEvaluator1U2F, SfEvaluator1U2F};
43
44type Pmf = Discrete1U2F<PmfEvaluator1U2F<NegativeBinomial>>;
45
46/// ScalarUDF for the Negative Binomial PMF
47pub fn pmf() -> ScalarUDF {
48    ScalarUDF::from(Pmf::new("negative_binomial_pmf"))
49}
50
51type LnPmf = Discrete1U2F<LnPmfEvaluator1U2F<NegativeBinomial>>;
52
53/// ScalarUDF for the Negative Binomial log PMF
54pub fn ln_pmf() -> ScalarUDF {
55    ScalarUDF::from(LnPmf::new("negative_binomial_ln_pmf"))
56}
57
58type Cdf = Discrete1U2F<CdfEvaluator1U2F<NegativeBinomial>>;
59
60/// ScalarUDF for the Negative Binomial CDF
61pub fn cdf() -> ScalarUDF {
62    ScalarUDF::from(Cdf::new("negative_binomial_cdf"))
63}
64
65type Sf = Discrete1U2F<SfEvaluator1U2F<NegativeBinomial>>;
66
67/// ScalarUDF for the Negative Binomial SF
68pub fn sf() -> ScalarUDF {
69    ScalarUDF::from(Sf::new("negative_binomial_sf"))
70}
71
72/// Register the functions for the Negative Binomial Distribution
73pub fn register(registry: &mut dyn FunctionRegistry) -> Result<(), DataFusionError> {
74    crate::utils::register::register(registry, vec![pmf(), ln_pmf(), cdf(), sf()])
75}
76
77#[cfg(test)]
78mod tests {
79    use std::sync::Arc;
80
81    use assert_eq_float::assert_eq_float;
82    use datafusion::{
83        arrow::{
84            array::{Float64Array, RecordBatch, UInt64Array},
85            datatypes::{DataType, Field, Schema, SchemaRef},
86        },
87        common::cast::as_float64_array,
88        error::DataFusionError,
89        prelude::{SessionContext, col},
90    };
91    use statrs::distribution::NegativeBinomialError;
92
93    use super::*;
94
95    fn get_schema() -> SchemaRef {
96        SchemaRef::new(Schema::new(vec![
97            Field::new("x", DataType::UInt64, true),
98            Field::new("r", DataType::Float64, true),
99            Field::new("p", DataType::Float64, true),
100        ]))
101    }
102
103    fn make_records(rows: Vec<(Option<u64>, Option<f64>, Option<f64>)>) -> RecordBatch {
104        let mut xs = Vec::new();
105        let mut ss = Vec::new();
106        let mut rs = Vec::new();
107        for row in rows {
108            xs.push(row.0);
109            ss.push(row.1);
110            rs.push(row.2);
111        }
112
113        RecordBatch::try_new(
114            get_schema(),
115            vec![
116                Arc::new(UInt64Array::from(xs)),
117                Arc::new(Float64Array::from(ss)),
118                Arc::new(Float64Array::from(rs)),
119            ],
120        )
121        .unwrap()
122    }
123
124    #[tokio::test]
125    async fn negative_binomial_pmf_success() {
126        let pmf = pmf();
127
128        let recs = make_records(vec![
129            (Some(1), Some(3.0), Some(0.25)),
130            (Some(2), Some(3.0), Some(0.25)),
131            (None, Some(3.0), Some(0.25)),
132            (Some(1), None, Some(0.25)),
133        ]);
134
135        let ctx = SessionContext::new();
136        ctx.register_batch("tbl", recs).unwrap();
137        let df = ctx.table("tbl").await.unwrap();
138        let res = df
139            .select(vec![
140                (pmf.call(vec![col("x"), col("r"), col("p")])).alias("q"),
141            ])
142            .unwrap()
143            .collect()
144            .await
145            .unwrap();
146
147        assert_eq!(res.len(), 1);
148        assert_eq!(res[0].num_columns(), 1);
149        assert_eq!(res[0].num_rows(), 4);
150        let res_col = as_float64_array(res[0].column(0)).unwrap();
151        assert_eq_float!(res_col.value(0), 0.035156249999999827);
152        assert_eq_float!(res_col.value(1), 0.05273437499999992);
153        assert!(res_col.value(2).is_nan());
154        assert!(res_col.value(3).is_nan());
155    }
156
157    #[tokio::test]
158    async fn negative_binomial_pmf_failure_1() {
159        let pmf = pmf();
160
161        let recs = make_records(vec![(Some(1), Some(0.0), Some(1.25))]);
162
163        let ctx = SessionContext::new();
164        ctx.register_batch("tbl", recs).unwrap();
165        let df = ctx.table("tbl").await.unwrap();
166        let res = df
167            .select(vec![
168                (pmf.call(vec![col("x"), col("r"), col("p")])).alias("q"),
169            ])
170            .unwrap()
171            .collect()
172            .await;
173        match res {
174            Err(DataFusionError::External(e)) => {
175                let be = e.downcast::<NegativeBinomialError>().unwrap();
176                assert_eq!(*be.as_ref(), NegativeBinomialError::PInvalid);
177            }
178            _ => {
179                println!("unexpected result: {:?}", res);
180                assert!(false);
181            }
182        }
183    }
184
185    #[tokio::test]
186    async fn negative_binomial_ln_pdf_success() {
187        let mut ctx = SessionContext::new();
188        register(&mut ctx).unwrap();
189        let res = ctx
190            .sql("SELECT negative_binomial_ln_pmf(CAST(2 AS BIGINT UNSIGNED), 8.0, 0.11)")
191            .await
192            .unwrap()
193            .collect()
194            .await
195            .unwrap();
196        assert_eq!(res.len(), 1);
197        assert_eq!(res[0].num_columns(), 1);
198        assert_eq!(res[0].num_rows(), 1);
199        let res_col = as_float64_array(res[0].column(0)).unwrap();
200        assert_eq_float!(res_col.value(0), -14.307747999573525);
201    }
202
203    #[tokio::test]
204    async fn negative_binomial_cdf_success() {
205        let pmf = cdf();
206
207        let recs = make_records(vec![
208            (Some(1), Some(3.0), Some(0.25)),
209            (Some(2), Some(3.0), Some(0.25)),
210            (None, Some(3.0), Some(0.25)),
211            (Some(1), None, Some(0.25)),
212        ]);
213
214        let ctx = SessionContext::new();
215        ctx.register_batch("tbl", recs).unwrap();
216        let df = ctx.table("tbl").await.unwrap();
217        let res = df
218            .select(vec![
219                (pmf.call(vec![col("x"), col("r"), col("p")])).alias("q"),
220            ])
221            .unwrap()
222            .collect()
223            .await
224            .unwrap();
225
226        assert_eq!(res.len(), 1);
227        assert_eq!(res[0].num_columns(), 1);
228        assert_eq!(res[0].num_rows(), 4);
229        let res_col = as_float64_array(res[0].column(0)).unwrap();
230        assert_eq_float!(res_col.value(0), 0.050781250000000056);
231        assert_eq_float!(res_col.value(1), 0.10351562499999896);
232        assert!(res_col.value(2).is_nan());
233        assert!(res_col.value(3).is_nan());
234    }
235
236    #[tokio::test]
237    async fn negative_binomial_sf_success() {
238        let pmf = sf();
239
240        let recs = make_records(vec![
241            (Some(1), Some(3.0), Some(0.25)),
242            (Some(2), Some(3.0), Some(0.25)),
243            (None, Some(3.0), Some(0.25)),
244            (Some(1), None, Some(0.25)),
245        ]);
246
247        let ctx = SessionContext::new();
248        ctx.register_batch("tbl", recs).unwrap();
249        let df = ctx.table("tbl").await.unwrap();
250        let res = df
251            .select(vec![
252                (pmf.call(vec![col("x"), col("r"), col("p")])).alias("q"),
253            ])
254            .unwrap()
255            .collect()
256            .await
257            .unwrap();
258
259        assert_eq!(res.len(), 1);
260        assert_eq!(res[0].num_columns(), 1);
261        assert_eq!(res[0].num_rows(), 4);
262        let res_col = as_float64_array(res[0].column(0)).unwrap();
263        assert_eq_float!(res_col.value(0), 0.94921875);
264        assert_eq_float!(res_col.value(1), 0.896484375000001);
265        assert!(res_col.value(2).is_nan());
266        assert!(res_col.value(3).is_nan());
267    }
268}