datafusion_statrs/distribution/
normal.rs1use datafusion::error::DataFusionError;
37use datafusion::execution::FunctionRegistry;
38use datafusion::logical_expr::ScalarUDF;
39use statrs::distribution::Normal;
40
41use crate::utils::continuous3f::Continuous3F;
42use crate::utils::evaluator3f::{CdfEvaluator3F, LnPdfEvaluator3F, PdfEvaluator3F, SfEvaluator3F};
43
44type Pdf = Continuous3F<PdfEvaluator3F<Normal>>;
45
46pub fn pdf() -> ScalarUDF {
48 ScalarUDF::from(Pdf::new("normal_pdf"))
49}
50
51type LnPdf = Continuous3F<LnPdfEvaluator3F<Normal>>;
52
53pub fn ln_pdf() -> ScalarUDF {
55 ScalarUDF::from(LnPdf::new("normal_ln_pdf"))
56}
57
58type Cdf = Continuous3F<CdfEvaluator3F<Normal>>;
59
60pub fn cdf() -> ScalarUDF {
62 ScalarUDF::from(Cdf::new("normal_cdf"))
63}
64
65type Sf = Continuous3F<SfEvaluator3F<Normal>>;
66
67pub fn sf() -> ScalarUDF {
69 ScalarUDF::from(Sf::new("normal_sf"))
70}
71
72pub fn register(registry: &mut dyn FunctionRegistry) -> Result<(), DataFusionError> {
74 crate::utils::register::register(registry, vec![pdf(), ln_pdf(), cdf(), sf()])
75}
76
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use assert_eq_float::assert_eq_float;
    use datafusion::{
        arrow::{
            array::{Float64Array, RecordBatch},
            datatypes::{DataType, Field, Schema, SchemaRef},
        },
        common::cast::as_float64_array,
        error::DataFusionError,
        logical_expr::ScalarUDF,
        prelude::{SessionContext, col},
    };
    use statrs::distribution::NormalError;

    use super::*;

    /// One input row: the values for columns `x`, `s` and `r`, each nullable.
    type Row = (Option<f64>, Option<f64>, Option<f64>);

    /// Schema of the test table: three nullable `Float64` columns `x`, `s`, `r`.
    fn get_schema() -> SchemaRef {
        SchemaRef::new(Schema::new(vec![
            Field::new("x", DataType::Float64, true),
            Field::new("s", DataType::Float64, true),
            Field::new("r", DataType::Float64, true),
        ]))
    }

    /// Transposes `rows` into a three-column record batch matching `get_schema`.
    fn make_records(rows: Vec<Row>) -> RecordBatch {
        let xs: Vec<Option<f64>> = rows.iter().map(|row| row.0).collect();
        let ss: Vec<Option<f64>> = rows.iter().map(|row| row.1).collect();
        let rs: Vec<Option<f64>> = rows.iter().map(|row| row.2).collect();

        RecordBatch::try_new(
            get_schema(),
            vec![
                Arc::new(Float64Array::from(xs)),
                Arc::new(Float64Array::from(ss)),
                Arc::new(Float64Array::from(rs)),
            ],
        )
        .unwrap()
    }

    /// Rows shared by the success tests: two fully populated rows followed by
    /// one row with a null `x` and one row with a null `s`.
    fn sample_rows() -> Vec<Row> {
        vec![
            (Some(1.0), Some(3.0), Some(0.25)),
            (Some(2.), Some(3.0), Some(0.25)),
            (None, Some(3.0), Some(0.25)),
            (Some(1.0), None, Some(0.25)),
        ]
    }

    /// Registers `rows` as table `tbl` in a fresh session and evaluates
    /// `udf(x, s, r)` over it, aliased as `q`.
    ///
    /// Setup and planning failures panic; only the outcome of `collect` is
    /// returned, so callers can inspect execution errors.
    async fn evaluate(udf: ScalarUDF, rows: Vec<Row>) -> Result<Vec<RecordBatch>, DataFusionError> {
        let ctx = SessionContext::new();
        ctx.register_batch("tbl", make_records(rows)).unwrap();
        let df = ctx.table("tbl").await.unwrap();
        df.select(vec![
            (udf.call(vec![col("x"), col("s"), col("r")])).alias("q"),
        ])
        .unwrap()
        .collect()
        .await
    }

    /// Asserts that `batches` is exactly one batch with one column and
    /// `expected_rows` rows, and returns that column as a `Float64Array`.
    fn single_f64_column(batches: &[RecordBatch], expected_rows: usize) -> &Float64Array {
        assert_eq!(batches.len(), 1);
        assert_eq!(batches[0].num_columns(), 1);
        assert_eq!(batches[0].num_rows(), expected_rows);
        as_float64_array(batches[0].column(0)).unwrap()
    }

    #[tokio::test]
    async fn normal_pdf_success() {
        let res = evaluate(pdf(), sample_rows()).await.unwrap();

        let res_col = single_f64_column(&res, 4);
        assert_eq_float!(res_col.value(0), 2.0209084334147568e-14);
        assert_eq_float!(res_col.value(1), 0.0005353209030595414);
        // Rows with a null input are expected to read back as NaN.
        assert!(res_col.value(2).is_nan());
        assert!(res_col.value(3).is_nan());
    }

    #[tokio::test]
    async fn normal_pdf_failure_1() {
        // A negative third argument must surface statrs' own error type.
        let res = evaluate(pdf(), vec![(Some(1.0), Some(0.0), Some(-1.25))]).await;

        match res {
            Err(DataFusionError::External(e)) => {
                let be = e.downcast::<NormalError>().unwrap();
                assert_eq!(*be, NormalError::StandardDeviationInvalid);
            }
            other => panic!("unexpected result: {:?}", other),
        }
    }

    #[tokio::test]
    async fn normal_ln_pdf_success() {
        // Exercises the SQL path, so the UDFs go through `register`.
        let mut ctx = SessionContext::new();
        register(&mut ctx).unwrap();
        let res = ctx
            .sql("SELECT normal_ln_pdf(0.2, 5.0, 1.0)")
            .await
            .unwrap()
            .collect()
            .await
            .unwrap();

        let res_col = single_f64_column(&res, 1);
        assert_eq_float!(res_col.value(0), -12.438938533204672);
    }

    #[tokio::test]
    async fn normal_cdf_success() {
        let res = evaluate(cdf(), sample_rows()).await.unwrap();

        let res_col = single_f64_column(&res, 4);
        assert_eq_float!(res_col.value(0), 6.220960574599358e-16);
        assert_eq_float!(res_col.value(1), 3.167124183566376e-5);
        // Rows with a null input are expected to read back as NaN.
        assert!(res_col.value(2).is_nan());
        assert!(res_col.value(3).is_nan());
    }

    #[tokio::test]
    async fn normal_sf_success() {
        let res = evaluate(sf(), sample_rows()).await.unwrap();

        let res_col = single_f64_column(&res, 4);
        assert_eq_float!(res_col.value(0), 0.9999999999999993);
        assert_eq_float!(res_col.value(1), 0.9999683287581643);
        // Rows with a null input are expected to read back as NaN.
        assert!(res_col.value(2).is_nan());
        assert!(res_col.value(3).is_nan());
    }
}