datafusion_statrs/distribution/
normal.rs

//! Module containing functions for the Normal Distribution.
2//! 
3//! Implemented by [`statrs::distribution::Normal`].
4//! 
5//! The [Normal Distribution](https://en.wikipedia.org/wiki/Normal_distribution) has two
6//! parameters:
7//! 
//! μ: μ ∈ R (mean, any real number)  
//! σ: 0 < σ (standard deviation)
10//! 
11//! Usage:
12//! 
13//! `normal_pdf(x, μ, σ)`  
14//! `normal_ln_pdf(x, μ, σ)`  
15//! `normal_cdf(x, μ, σ)`  
16//! `normal_sf(x, μ, σ)`
17//! 
18//! with
19//! 
20//!   `x`: (-∞, +∞) `Float64`/`DOUBLE`,  
21//!   `μ`: (-∞, +∞) `Float64`/`DOUBLE`,  
22//!   `σ`: (0, +∞) `Float64`/`DOUBLE`
23//! 
24//! Examples
25//! ```
26//! #[tokio::main(flavor = "current_thread")]
27//! async fn main() -> std::io::Result<()> {
28//!     let mut ctx = datafusion::prelude::SessionContext::new();
29//!     datafusion_statrs::distribution::normal::register(&mut ctx)?;
30//!     ctx.sql("SELECT normal_pdf(1.1, 9.0, 1.0)").await?
31//!        .show().await?;
32//!     Ok(())
33//! }
34//! ```
35
36use datafusion::error::DataFusionError;
37use datafusion::execution::FunctionRegistry;
38use datafusion::logical_expr::ScalarUDF;
39use statrs::distribution::Normal;
40
41use crate::utils::continuous3f::Continuous3F;
42use crate::utils::evaluator3f::{CdfEvaluator3F, LnPdfEvaluator3F, PdfEvaluator3F, SfEvaluator3F};
43
44type Pdf = Continuous3F<PdfEvaluator3F<Normal>>;
45
46/// ScalarUDF for the Normal PDF
47pub fn pdf() -> ScalarUDF {
48    ScalarUDF::from(Pdf::new("normal_pdf"))
49}
50
51type LnPdf = Continuous3F<LnPdfEvaluator3F<Normal>>;
52
53/// ScalarUDF for the Normal log PDF
54pub fn ln_pdf() -> ScalarUDF {
55    ScalarUDF::from(LnPdf::new("normal_ln_pdf"))
56}
57
58type Cdf = Continuous3F<CdfEvaluator3F<Normal>>;
59
60/// ScalarUDF for the Normal CDF
61pub fn cdf() -> ScalarUDF {
62    ScalarUDF::from(Cdf::new("normal_cdf"))
63}
64
65type Sf = Continuous3F<SfEvaluator3F<Normal>>;
66
67/// ScalarUDF for the Normal SF
68pub fn sf() -> ScalarUDF {
69    ScalarUDF::from(Sf::new("normal_sf"))
70}
71
72/// Register the functions for the Normal Distribution
73pub fn register(registry: &mut dyn FunctionRegistry) -> Result<(), DataFusionError> {
74    crate::utils::register::register(registry, vec![pdf(), ln_pdf(), cdf(), sf()])
75}
76
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use assert_eq_float::assert_eq_float;
    use datafusion::{
        arrow::{
            array::{Float64Array, RecordBatch},
            datatypes::{DataType, Field, Schema, SchemaRef},
        },
        common::cast::as_float64_array,
        error::DataFusionError,
        prelude::{SessionContext, col},
    };
    use statrs::distribution::NormalError;

    use super::*;

    /// One test row: the values for the `x`, `s` and `r` columns.
    type Row = (Option<f64>, Option<f64>, Option<f64>);

    /// Schema of the test table: three nullable `Float64` columns.
    ///
    /// The tests pass them to the UDFs in the order `x`, `s`, `r`, i.e. as
    /// `x`, `μ` and `σ`.
    fn get_schema() -> SchemaRef {
        SchemaRef::new(Schema::new(vec![
            Field::new("x", DataType::Float64, true),
            Field::new("s", DataType::Float64, true),
            Field::new("r", DataType::Float64, true),
        ]))
    }

    /// Builds a single `RecordBatch` matching [`get_schema`] from row tuples.
    fn make_records(rows: Vec<Row>) -> RecordBatch {
        // Transpose the row-oriented input into one vector per column.
        let xs: Vec<Option<f64>> = rows.iter().map(|row| row.0).collect();
        let ss: Vec<Option<f64>> = rows.iter().map(|row| row.1).collect();
        let rs: Vec<Option<f64>> = rows.iter().map(|row| row.2).collect();

        RecordBatch::try_new(
            get_schema(),
            vec![
                Arc::new(Float64Array::from(xs)),
                Arc::new(Float64Array::from(ss)),
                Arc::new(Float64Array::from(rs)),
            ],
        )
        .unwrap()
    }

    /// Fixture shared by the success tests: two fully populated rows followed
    /// by one row with a missing `x` and one row with a missing `μ`.
    fn sample_records() -> RecordBatch {
        make_records(vec![
            (Some(1.0), Some(3.0), Some(0.25)),
            (Some(2.), Some(3.0), Some(0.25)),
            (None, Some(3.0), Some(0.25)),
            (Some(1.0), None, Some(0.25)),
        ])
    }

    /// Registers `recs` as table `tbl`, evaluates `udf(x, s, r)` aliased as `q`
    /// and returns the collected batches, or the error raised while collecting.
    async fn eval_udf(
        udf: ScalarUDF,
        recs: RecordBatch,
    ) -> Result<Vec<RecordBatch>, DataFusionError> {
        let ctx = SessionContext::new();
        ctx.register_batch("tbl", recs).unwrap();
        let df = ctx.table("tbl").await.unwrap();
        df.select(vec![
            (udf.call(vec![col("x"), col("s"), col("r")])).alias("q"),
        ])
        .unwrap()
        .collect()
        .await
    }

    #[tokio::test]
    async fn normal_pdf_success() {
        let res = eval_udf(pdf(), sample_records()).await.unwrap();

        assert_eq!(res.len(), 1);
        assert_eq!(res[0].num_columns(), 1);
        assert_eq!(res[0].num_rows(), 4);
        let res_col = as_float64_array(res[0].column(0)).unwrap();
        assert_eq_float!(res_col.value(0), 2.0209084334147568e-14);
        assert_eq_float!(res_col.value(1), 0.0005353209030595414);
        // Rows with a missing input are expected to read back as NaN.
        assert!(res_col.value(2).is_nan());
        assert!(res_col.value(3).is_nan());
    }

    #[tokio::test]
    async fn normal_pdf_failure_1() {
        // σ = -1.25 is not a valid standard deviation.
        let recs = make_records(vec![(Some(1.0), Some(0.0), Some(-1.25))]);

        match eval_udf(pdf(), recs).await {
            Err(DataFusionError::External(e)) => {
                let be = e.downcast::<NormalError>().unwrap();
                assert_eq!(*be.as_ref(), NormalError::StandardDeviationInvalid);
            }
            other => panic!("unexpected result: {other:?}"),
        }
    }

    #[tokio::test]
    async fn normal_ln_pdf_success() {
        let mut ctx = SessionContext::new();
        register(&mut ctx).unwrap();
        let res = ctx
            .sql("SELECT normal_ln_pdf(0.2, 5.0, 1.0)")
            .await
            .unwrap()
            .collect()
            .await
            .unwrap();
        assert_eq!(res.len(), 1);
        assert_eq!(res[0].num_columns(), 1);
        assert_eq!(res[0].num_rows(), 1);
        let res_col = as_float64_array(res[0].column(0)).unwrap();
        assert_eq_float!(res_col.value(0), -12.438938533204672);
    }

    #[tokio::test]
    async fn normal_cdf_success() {
        let res = eval_udf(cdf(), sample_records()).await.unwrap();

        assert_eq!(res.len(), 1);
        assert_eq!(res[0].num_columns(), 1);
        assert_eq!(res[0].num_rows(), 4);
        let res_col = as_float64_array(res[0].column(0)).unwrap();
        assert_eq_float!(res_col.value(0), 6.220960574599358e-16);
        assert_eq_float!(res_col.value(1), 3.167124183566376e-5);
        // Rows with a missing input are expected to read back as NaN.
        assert!(res_col.value(2).is_nan());
        assert!(res_col.value(3).is_nan());
    }

    #[tokio::test]
    async fn normal_sf_success() {
        let res = eval_udf(sf(), sample_records()).await.unwrap();

        assert_eq!(res.len(), 1);
        assert_eq!(res[0].num_columns(), 1);
        assert_eq!(res[0].num_rows(), 4);
        let res_col = as_float64_array(res[0].column(0)).unwrap();
        assert_eq_float!(res_col.value(0), 0.9999999999999993);
        assert_eq_float!(res_col.value(1), 0.9999683287581643);
        // Rows with a missing input are expected to read back as NaN.
        assert!(res_col.value(2).is_nan());
        assert!(res_col.value(3).is_nan());
    }
}