datafusion_statrs/distribution/
triangular.rs

//! Module containing functions for the Triangular Distribution.
2//! 
3//! Implemented by [`statrs::distribution::Triangular`].
4//! 
5//! The [Triangular Distribution](https://en.wikipedia.org/wiki/Triangular_distribution) has three
6//! parameters:
7//! 
//! - `a`: the minimum, a ∈ R (real numbers)
//! - `b`: the maximum, a < b
//! - `c`: the mode, a ≤ c ≤ b
11//! 
12//! Usage:
13//! 
14//! `triangular_pdf(x, a, b, c)`  
15//! `triangular_ln_pdf(x, a, b, c)`  
16//! `triangular_cdf(x, a, b, c)`  
17//! `triangular_sf(x, a, b, c)`
18//! 
19//! with
20//! 
21//!   `x`: [a, b] `Float64`/`DOUBLE`,  
22//!   `a`: (-∞, +∞) `Float64`/`DOUBLE`,  
23//!   `b`: (a, +∞) `Float64`/`DOUBLE`,  
24//!   `c`: [a, b] `Float64`/`DOUBLE`
25//! 
26//! Examples
27//! ```
28//! #[tokio::main(flavor = "current_thread")]
29//! async fn main() -> std::io::Result<()> {
30//!     let mut ctx = datafusion::prelude::SessionContext::new();
31//!     datafusion_statrs::distribution::triangular::register(&mut ctx)?;
32//!     ctx.sql("SELECT triangular_pdf(1.1, 1.0, 3.0, 2.5)").await?
33//!        .show().await?;
34//!     Ok(())
35//! }
36//! ```
37
38use datafusion::error::DataFusionError;
39use datafusion::execution::FunctionRegistry;
40use datafusion::logical_expr::ScalarUDF;
41use statrs::distribution::Triangular;
42
43use crate::utils::continuous4f::Continuous4F;
44use crate::utils::evaluator4f::{CdfEvaluator4F, LnPdfEvaluator4F, PdfEvaluator4F, SfEvaluator4F};
45
46type Pdf = Continuous4F<PdfEvaluator4F<Triangular>>;
47
48/// ScalarUDF for the Triangular PDF
49pub fn pdf() -> ScalarUDF {
50    ScalarUDF::from(Pdf::new("triangular_pdf"))
51}
52
53type LnPdf = Continuous4F<LnPdfEvaluator4F<Triangular>>;
54
55/// ScalarUDF for the Triangular log PDF
56pub fn ln_pdf() -> ScalarUDF {
57    ScalarUDF::from(LnPdf::new("triangular_ln_pdf"))
58}
59
60type Cdf = Continuous4F<CdfEvaluator4F<Triangular>>;
61
62/// ScalarUDF for the Triangular PDF
63pub fn cdf() -> ScalarUDF {
64    ScalarUDF::from(Cdf::new("triangular_cdf"))
65}
66
67type Sf = Continuous4F<SfEvaluator4F<Triangular>>;
68
69/// ScalarUDF for the Triangular PDF
70pub fn sf() -> ScalarUDF {
71    ScalarUDF::from(Sf::new("triangular_sf"))
72}
73
74/// Register the functions for the Triangular Distribution
75pub fn register(registry: &mut dyn FunctionRegistry) -> Result<(), DataFusionError> {
76    crate::utils::register::register(registry, vec![pdf(), ln_pdf(), cdf(), sf()])
77}
78
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use assert_eq_float::assert_eq_float;
    use datafusion::{
        arrow::{
            array::{Float64Array, RecordBatch},
            datatypes::{DataType, Field, Schema, SchemaRef},
        },
        common::cast::as_float64_array,
        error::DataFusionError,
        prelude::{SessionContext, col},
    };
    use statrs::distribution::TriangularError;

    use super::*;

    /// One input row `(x, min, max, mode)`; `None` becomes a NULL cell.
    type Row = (Option<f64>, Option<f64>, Option<f64>, Option<f64>);

    /// Schema of the test table: nullable `Float64` columns `x`, `min`, `max`,
    /// `mode`, in the argument order the triangular UDFs are called with.
    fn get_schema() -> SchemaRef {
        SchemaRef::new(Schema::new(vec![
            Field::new("x", DataType::Float64, true),
            Field::new("min", DataType::Float64, true),
            Field::new("max", DataType::Float64, true),
            Field::new("mode", DataType::Float64, true),
        ]))
    }

    /// Builds a single `RecordBatch` (schema from [`get_schema`]) with one row
    /// per tuple in `rows`.
    fn make_records(rows: Vec<Row>) -> RecordBatch {
        let mut xs = Vec::with_capacity(rows.len());
        let mut mns = Vec::with_capacity(rows.len());
        let mut mxs = Vec::with_capacity(rows.len());
        let mut mds = Vec::with_capacity(rows.len());
        for (x, mn, mx, md) in rows {
            xs.push(x);
            mns.push(mn);
            mxs.push(mx);
            mds.push(md);
        }

        RecordBatch::try_new(
            get_schema(),
            vec![
                Arc::new(Float64Array::from(xs)),
                Arc::new(Float64Array::from(mns)),
                Arc::new(Float64Array::from(mxs)),
                Arc::new(Float64Array::from(mds)),
            ],
        )
        .unwrap()
    }

    /// Rows shared by the PDF/CDF/SF tests. The first two are points of
    /// Triangular(min = 3, max = 7, mode = 4). The last two have a NULL `x`
    /// and a NULL `min`, respectively.
    fn sample_records() -> RecordBatch {
        make_records(vec![
            (Some(5.0), Some(3.0), Some(7.0), Some(4.0)),
            (Some(6.0), Some(3.0), Some(7.0), Some(4.0)),
            (None, Some(3.0), Some(7.0), Some(4.0)),
            (Some(6.0), None, Some(7.0), Some(4.0)),
        ])
    }

    /// Evaluates `udf(x, min, max, mode)` over `recs` as column `q` and
    /// collects the resulting batches.
    async fn evaluate(
        udf: &ScalarUDF,
        recs: RecordBatch,
    ) -> Result<Vec<RecordBatch>, DataFusionError> {
        let ctx = SessionContext::new();
        ctx.register_batch("tbl", recs)?;
        ctx.table("tbl")
            .await?
            .select(vec![
                udf.call(vec![col("x"), col("min"), col("max"), col("mode")])
                    .alias("q"),
            ])?
            .collect()
            .await
    }

    /// Checks a result computed from [`sample_records`]:
    /// - one batch with one column and four rows;
    /// - `first` and `second` for the two valid rows;
    /// - NaN when reading back the two rows that had a NULL input.
    fn assert_sample_results(res: &[RecordBatch], first: f64, second: f64) {
        assert_eq!(res.len(), 1);
        assert_eq!(res[0].num_columns(), 1);
        assert_eq!(res[0].num_rows(), 4);
        let res_col = as_float64_array(res[0].column(0)).unwrap();
        assert_eq_float!(res_col.value(0), first);
        assert_eq_float!(res_col.value(1), second);
        assert!(res_col.value(2).is_nan());
        assert!(res_col.value(3).is_nan());
    }

    #[tokio::test]
    async fn triangular_pdf_success() {
        let res = evaluate(&pdf(), sample_records()).await.unwrap();
        assert_sample_results(&res, 0.3333333333333333, 0.16666666666666666);
    }

    #[tokio::test]
    async fn triangular_pdf_failure_1() {
        // mode (-1.25) lies outside [min, max] = [0, 1].
        let recs = make_records(vec![(Some(1.0), Some(0.0), Some(1.0), Some(-1.25))]);
        match evaluate(&pdf(), recs).await {
            Err(DataFusionError::External(e)) => {
                let be = e.downcast::<TriangularError>().unwrap();
                assert_eq!(*be, TriangularError::ModeOutOfRange);
            }
            other => panic!("unexpected result: {other:?}"),
        }
    }

    #[tokio::test]
    async fn triangular_ln_pdf_success() {
        let mut ctx = SessionContext::new();
        register(&mut ctx).unwrap();
        let res = ctx
            .sql("SELECT triangular_ln_pdf(3.14, 3.0, 7.0, 6.0)")
            .await
            .unwrap()
            .collect()
            .await
            .unwrap();
        assert_eq!(res.len(), 1);
        assert_eq!(res[0].num_columns(), 1);
        assert_eq!(res[0].num_rows(), 1);
        let res_col = as_float64_array(res[0].column(0)).unwrap();
        assert_eq_float!(res_col.value(0), -3.757872325600887);
    }

    #[tokio::test]
    async fn triangular_cdf_success() {
        let res = evaluate(&cdf(), sample_records()).await.unwrap();
        assert_sample_results(&res, 0.6666666666666667, 0.9166666666666666);
    }

    #[tokio::test]
    async fn triangular_sf_success() {
        let res = evaluate(&sf(), sample_records()).await.unwrap();
        assert_sample_results(&res, 0.3333333333333333, 0.08333333333333333);
    }
}