datafusion_statrs/distribution/
negative_binomial.rs1use datafusion::error::DataFusionError;
37use datafusion::execution::FunctionRegistry;
38use datafusion::logical_expr::ScalarUDF;
39use statrs::distribution::NegativeBinomial;
40
41use crate::utils::discrete1u2f::Discrete1U2F;
42use crate::utils::evaluator1u2f::{CdfEvaluator1U2F, LnPmfEvaluator1U2F, PmfEvaluator1U2F, SfEvaluator1U2F};
43
44type Pmf = Discrete1U2F<PmfEvaluator1U2F<NegativeBinomial>>;
45
46pub fn pmf() -> ScalarUDF {
48 ScalarUDF::from(Pmf::new("negative_binomial_pmf"))
49}
50
51type LnPmf = Discrete1U2F<LnPmfEvaluator1U2F<NegativeBinomial>>;
52
53pub fn ln_pmf() -> ScalarUDF {
55 ScalarUDF::from(LnPmf::new("negative_binomial_ln_pmf"))
56}
57
58type Cdf = Discrete1U2F<CdfEvaluator1U2F<NegativeBinomial>>;
59
60pub fn cdf() -> ScalarUDF {
62 ScalarUDF::from(Cdf::new("negative_binomial_cdf"))
63}
64
65type Sf = Discrete1U2F<SfEvaluator1U2F<NegativeBinomial>>;
66
67pub fn sf() -> ScalarUDF {
69 ScalarUDF::from(Sf::new("negative_binomial_sf"))
70}
71
72pub fn register(registry: &mut dyn FunctionRegistry) -> Result<(), DataFusionError> {
74 crate::utils::register::register(registry, vec![pmf(), ln_pmf(), cdf(), sf()])
75}
76
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use assert_eq_float::assert_eq_float;
    use datafusion::{
        arrow::{
            array::{Float64Array, RecordBatch, UInt64Array},
            datatypes::{DataType, Field, Schema, SchemaRef},
        },
        common::cast::as_float64_array,
        error::DataFusionError,
        prelude::{SessionContext, col},
    };
    use statrs::distribution::NegativeBinomialError;

    use super::*;

    /// Schema of the test table: the count `x` (UInt64) and the distribution
    /// parameters `r` and `p` (Float64), all nullable.
    fn get_schema() -> SchemaRef {
        SchemaRef::new(Schema::new(vec![
            Field::new("x", DataType::UInt64, true),
            Field::new("r", DataType::Float64, true),
            Field::new("p", DataType::Float64, true),
        ]))
    }

    /// Builds a single batch from `(x, r, p)` rows; `None` becomes a NULL cell.
    fn make_records(rows: Vec<(Option<u64>, Option<f64>, Option<f64>)>) -> RecordBatch {
        let mut xs = Vec::with_capacity(rows.len());
        let mut rs = Vec::with_capacity(rows.len());
        let mut ps = Vec::with_capacity(rows.len());
        for (x, r, p) in rows {
            xs.push(x);
            rs.push(r);
            ps.push(p);
        }

        RecordBatch::try_new(
            get_schema(),
            vec![
                Arc::new(UInt64Array::from(xs)),
                Arc::new(Float64Array::from(rs)),
                Arc::new(Float64Array::from(ps)),
            ],
        )
        .unwrap()
    }

    /// Rows shared by the pmf/cdf/sf success tests: two valid rows, then one
    /// row with a NULL `x` and one row with a NULL `r`.
    fn sample_records() -> RecordBatch {
        make_records(vec![
            (Some(1), Some(3.0), Some(0.25)),
            (Some(2), Some(3.0), Some(0.25)),
            (None, Some(3.0), Some(0.25)),
            (Some(1), None, Some(0.25)),
        ])
    }

    /// Registers `recs` as table `tbl` and evaluates `udf(x, r, p) AS q` over
    /// it. Returns the collected batches, or the error raised while executing
    /// the query. Table registration and planning failures panic.
    async fn evaluate(
        udf: ScalarUDF,
        recs: RecordBatch,
    ) -> Result<Vec<RecordBatch>, DataFusionError> {
        let ctx = SessionContext::new();
        ctx.register_batch("tbl", recs).unwrap();
        let df = ctx.table("tbl").await.unwrap();
        df.select(vec![udf.call(vec![col("x"), col("r"), col("p")]).alias("q")])
            .unwrap()
            .collect()
            .await
    }

    /// Asserts that `res` is a single batch holding exactly one Float64 column
    /// with `num_rows` rows, and returns that column.
    fn single_float_column(res: &[RecordBatch], num_rows: usize) -> &Float64Array {
        assert_eq!(res.len(), 1);
        assert_eq!(res[0].num_columns(), 1);
        assert_eq!(res[0].num_rows(), num_rows);
        as_float64_array(res[0].column(0)).unwrap()
    }

    #[tokio::test]
    async fn negative_binomial_pmf_success() {
        let res = evaluate(pmf(), sample_records()).await.unwrap();
        let res_col = single_float_column(&res, 4);
        assert_eq_float!(res_col.value(0), 0.035156249999999827);
        assert_eq_float!(res_col.value(1), 0.05273437499999992);
        // Rows with a NULL argument are expected to come back as NaN.
        assert!(res_col.value(2).is_nan());
        assert!(res_col.value(3).is_nan());
    }

    #[tokio::test]
    async fn negative_binomial_pmf_failure_1() {
        // p = 1.25 is not a valid probability, so building the distribution
        // must fail and surface as an external DataFusion error.
        let recs = make_records(vec![(Some(1), Some(0.0), Some(1.25))]);
        match evaluate(pmf(), recs).await {
            Err(DataFusionError::External(e)) => {
                let be = e.downcast::<NegativeBinomialError>().unwrap();
                assert_eq!(*be, NegativeBinomialError::PInvalid);
            }
            other => panic!("unexpected result: {other:?}"),
        }
    }

    #[tokio::test]
    async fn negative_binomial_ln_pmf_success() {
        let mut ctx = SessionContext::new();
        register(&mut ctx).unwrap();
        let res = ctx
            .sql("SELECT negative_binomial_ln_pmf(CAST(2 AS BIGINT UNSIGNED), 8.0, 0.11)")
            .await
            .unwrap()
            .collect()
            .await
            .unwrap();
        let res_col = single_float_column(&res, 1);
        assert_eq_float!(res_col.value(0), -14.307747999573525);
    }

    #[tokio::test]
    async fn negative_binomial_cdf_success() {
        let res = evaluate(cdf(), sample_records()).await.unwrap();
        let res_col = single_float_column(&res, 4);
        assert_eq_float!(res_col.value(0), 0.050781250000000056);
        assert_eq_float!(res_col.value(1), 0.10351562499999896);
        // Rows with a NULL argument are expected to come back as NaN.
        assert!(res_col.value(2).is_nan());
        assert!(res_col.value(3).is_nan());
    }

    #[tokio::test]
    async fn negative_binomial_sf_success() {
        let res = evaluate(sf(), sample_records()).await.unwrap();
        let res_col = single_float_column(&res, 4);
        assert_eq_float!(res_col.value(0), 0.94921875);
        assert_eq_float!(res_col.value(1), 0.896484375000001);
        // Rows with a NULL argument are expected to come back as NaN.
        assert!(res_col.value(2).is_nan());
        assert!(res_col.value(3).is_nan());
    }
}