datafusion_statrs/distribution/
binomial.rs1use datafusion::error::DataFusionError;
37use datafusion::execution::FunctionRegistry;
38use datafusion::logical_expr::ScalarUDF;
39use statrs::distribution::Binomial;
40
41use crate::utils::discrete2u1f::Discrete2U1F;
42use crate::utils::evaluator2u1f::{CdfEvaluator2U1F, LnPmfEvaluator2U1F, PmfEvaluator2U1F, SfEvaluator2U1F};
43
44type Pmf = Discrete2U1F<PmfEvaluator2U1F<Binomial>>;
45
46pub fn pmf() -> ScalarUDF {
48 ScalarUDF::from(Pmf::new("binomial_pmf"))
49}
50
51type LnPmf = Discrete2U1F<LnPmfEvaluator2U1F<Binomial>>;
52
53pub fn ln_pmf() -> ScalarUDF {
55 ScalarUDF::from(LnPmf::new("binomial_ln_pmf"))
56}
57
58type Cdf = Discrete2U1F<CdfEvaluator2U1F<Binomial>>;
59
60pub fn cdf() -> ScalarUDF {
62 ScalarUDF::from(Cdf::new("binomial_cdf"))
63}
64
65type Sf = Discrete2U1F<SfEvaluator2U1F<Binomial>>;
66
67pub fn sf() -> ScalarUDF {
69 ScalarUDF::from(Sf::new("binomial_sf"))
70}
71
72pub fn register(registry: &mut dyn FunctionRegistry) -> Result<(), DataFusionError> {
74 crate::utils::register::register(registry, vec![pmf(), ln_pmf(), cdf(), sf()])
75}
76
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use assert_eq_float::assert_eq_float;
    use datafusion::{
        arrow::{
            array::{Float64Array, RecordBatch, UInt64Array},
            datatypes::{DataType, Field, Schema, SchemaRef},
        },
        common::cast::as_float64_array,
        error::DataFusionError,
        logical_expr::ScalarUDF,
        prelude::{SessionContext, col},
    };
    use statrs::distribution::BinomialError;

    use super::*;

    /// Schema of the test table: nullable `x: UInt64`, `n: UInt64`, `p: Float64`.
    fn get_schema() -> SchemaRef {
        SchemaRef::new(Schema::new(vec![
            Field::new("x", DataType::UInt64, true),
            Field::new("n", DataType::UInt64, true),
            Field::new("p", DataType::Float64, true),
        ]))
    }

    /// Builds a record batch from `(x, n, p)` rows; `None` becomes a NULL cell.
    fn make_records(rows: Vec<(Option<u64>, Option<u64>, Option<f64>)>) -> RecordBatch {
        let xs: Vec<Option<u64>> = rows.iter().map(|row| row.0).collect();
        let ns: Vec<Option<u64>> = rows.iter().map(|row| row.1).collect();
        let ps: Vec<Option<f64>> = rows.iter().map(|row| row.2).collect();

        RecordBatch::try_new(
            get_schema(),
            vec![
                Arc::new(UInt64Array::from(xs)),
                Arc::new(UInt64Array::from(ns)),
                Arc::new(Float64Array::from(ps)),
            ],
        )
        .unwrap()
    }

    /// Fixture shared by the success tests: two fully populated rows followed
    /// by one row with a NULL `x` and one row with a NULL `n`.
    fn rows_with_nulls() -> RecordBatch {
        make_records(vec![
            (Some(0), Some(3), Some(0.25)),
            (Some(1), Some(3), Some(0.25)),
            (None, Some(3), Some(0.25)),
            (Some(0), None, Some(0.25)),
        ])
    }

    /// Registers `recs` as table `tbl`, selects `udf(x, n, p) AS q` from it and
    /// collects the result.
    ///
    /// Setup failures panic; only the outcome of `collect` is returned so that
    /// callers can inspect execution errors.
    async fn eval_udf(
        udf: &ScalarUDF,
        recs: RecordBatch,
    ) -> Result<Vec<RecordBatch>, DataFusionError> {
        let ctx = SessionContext::new();
        ctx.register_batch("tbl", recs).unwrap();
        let df = ctx.table("tbl").await.unwrap();
        df.select(vec![
            udf.call(vec![col("x"), col("n"), col("p")]).alias("q"),
        ])
        .unwrap()
        .collect()
        .await
    }

    /// Asserts that `batches` is a single batch with one column and
    /// `expected_rows` rows, and returns that column as a `Float64Array`.
    fn single_f64_column(batches: &[RecordBatch], expected_rows: usize) -> &Float64Array {
        assert_eq!(batches.len(), 1);
        assert_eq!(batches[0].num_columns(), 1);
        assert_eq!(batches[0].num_rows(), expected_rows);
        as_float64_array(batches[0].column(0)).unwrap()
    }

    #[tokio::test]
    async fn binomial_pmf_success() {
        let res = eval_udf(&pmf(), rows_with_nulls()).await.unwrap();

        let res_col = single_f64_column(&res, 4);
        assert_eq_float!(res_col.value(0), 0.421875);
        assert_eq_float!(res_col.value(1), 0.421875);
        // Rows with a NULL argument are expected to evaluate to NaN.
        assert!(res_col.value(2).is_nan());
        assert!(res_col.value(3).is_nan());
    }

    #[tokio::test]
    async fn binomial_pmf_failure_1() {
        // p = 1.25 is outside [0, 1] and must be rejected.
        let recs = make_records(vec![(Some(0), Some(3), Some(1.25))]);

        match eval_udf(&pmf(), recs).await {
            Err(DataFusionError::External(e)) => {
                let be = e
                    .downcast::<BinomialError>()
                    .expect("external error should be a BinomialError");
                assert_eq!(*be, BinomialError::ProbabilityInvalid);
            }
            other => panic!("unexpected result: {other:?}"),
        }
    }

    #[tokio::test]
    async fn binomial_ln_pmf_success() {
        let mut ctx = SessionContext::new();
        register(&mut ctx).unwrap();
        let res = ctx
            .sql("SELECT binomial_ln_pmf(CAST(2 AS BIGINT UNSIGNED), CAST(10 AS BIGINT UNSIGNED), 0.5)")
            .await
            .unwrap()
            .collect()
            .await
            .unwrap();

        let res_col = single_f64_column(&res, 1);
        assert_eq_float!(res_col.value(0), -3.1248093158291335);
    }

    #[tokio::test]
    async fn binomial_cdf_success() {
        let res = eval_udf(&cdf(), rows_with_nulls()).await.unwrap();

        let res_col = single_f64_column(&res, 4);
        assert_eq_float!(res_col.value(0), 0.421875, 3e-15);
        assert_eq_float!(res_col.value(1), 0.84375);
        // Rows with a NULL argument are expected to evaluate to NaN.
        assert!(res_col.value(2).is_nan());
        assert!(res_col.value(3).is_nan());
    }

    #[tokio::test]
    async fn binomial_sf_success() {
        let res = eval_udf(&sf(), rows_with_nulls()).await.unwrap();

        let res_col = single_f64_column(&res, 4);
        assert_eq_float!(res_col.value(0), 0.578125, 3e-15);
        assert_eq_float!(res_col.value(1), 0.15625, 4e-15);
        // Rows with a NULL argument are expected to evaluate to NaN.
        assert!(res_col.value(2).is_nan());
        assert!(res_col.value(3).is_nan());
    }
}