// datafusion_statrs/distribution/triangular.rs
use datafusion::error::DataFusionError;
39use datafusion::execution::FunctionRegistry;
40use datafusion::logical_expr::ScalarUDF;
41use statrs::distribution::Triangular;
42
43use crate::utils::continuous4f::Continuous4F;
44use crate::utils::evaluator4f::{CdfEvaluator4F, LnPdfEvaluator4F, PdfEvaluator4F, SfEvaluator4F};
45
46type Pdf = Continuous4F<PdfEvaluator4F<Triangular>>;
47
48pub fn pdf() -> ScalarUDF {
50 ScalarUDF::from(Pdf::new("triangular_pdf"))
51}
52
53type LnPdf = Continuous4F<LnPdfEvaluator4F<Triangular>>;
54
55pub fn ln_pdf() -> ScalarUDF {
57 ScalarUDF::from(LnPdf::new("triangular_ln_pdf"))
58}
59
60type Cdf = Continuous4F<CdfEvaluator4F<Triangular>>;
61
62pub fn cdf() -> ScalarUDF {
64 ScalarUDF::from(Cdf::new("triangular_cdf"))
65}
66
67type Sf = Continuous4F<SfEvaluator4F<Triangular>>;
68
69pub fn sf() -> ScalarUDF {
71 ScalarUDF::from(Sf::new("triangular_sf"))
72}
73
74pub fn register(registry: &mut dyn FunctionRegistry) -> Result<(), DataFusionError> {
76 crate::utils::register::register(registry, vec![pdf(), ln_pdf(), cdf(), sf()])
77}
78
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use assert_eq_float::assert_eq_float;
    use datafusion::{
        arrow::{
            array::{ArrayRef, Float64Array, RecordBatch},
            datatypes::{DataType, Field, Schema, SchemaRef},
        },
        common::cast::as_float64_array,
        error::DataFusionError,
        prelude::{SessionContext, col},
    };
    use statrs::distribution::TriangularError;

    use super::*;

    /// One input row in `(x, min, max, mode)` order; every field is nullable.
    type Row = (Option<f64>, Option<f64>, Option<f64>, Option<f64>);

    /// Schema of the test table: four nullable `Float64` columns named
    /// `x`, `min`, `max` and `mode`.
    fn get_schema() -> SchemaRef {
        SchemaRef::new(Schema::new(vec![
            Field::new("x", DataType::Float64, true),
            Field::new("min", DataType::Float64, true),
            Field::new("max", DataType::Float64, true),
            Field::new("mode", DataType::Float64, true),
        ]))
    }

    /// Collects nullable values into a `Float64Array` column.
    fn column(values: impl Iterator<Item = Option<f64>>) -> ArrayRef {
        Arc::new(Float64Array::from(values.collect::<Vec<_>>()))
    }

    /// Transposes row tuples into a `RecordBatch` matching [`get_schema`].
    fn make_records(rows: Vec<Row>) -> RecordBatch {
        RecordBatch::try_new(
            get_schema(),
            vec![
                column(rows.iter().map(|row| row.0)),
                column(rows.iter().map(|row| row.1)),
                column(rows.iter().map(|row| row.2)),
                column(rows.iter().map(|row| row.3)),
            ],
        )
        .unwrap()
    }

    /// Fixture shared by the pdf/cdf/sf success tests: two fully populated
    /// rows, one row with a null `x` and one row with a null `min`.
    fn sample_rows() -> Vec<Row> {
        vec![
            (Some(5.0), Some(3.0), Some(7.0), Some(4.0)),
            (Some(6.0), Some(3.0), Some(7.0), Some(4.0)),
            (None, Some(3.0), Some(7.0), Some(4.0)),
            (Some(6.0), None, Some(7.0), Some(4.0)),
        ]
    }

    /// Registers `recs` as table `tbl`, projects
    /// `udf(x, min, max, mode) AS q` and collects the result.
    ///
    /// Setup failures panic (they indicate a broken test); only the outcome
    /// of executing the query is returned so callers can inspect errors
    /// raised by the UDF itself.
    async fn evaluate(
        udf: &ScalarUDF,
        recs: RecordBatch,
    ) -> Result<Vec<RecordBatch>, DataFusionError> {
        let ctx = SessionContext::new();
        ctx.register_batch("tbl", recs).unwrap();
        let df = ctx.table("tbl").await.unwrap();
        df.select(vec![
            udf.call(vec![col("x"), col("min"), col("max"), col("mode")])
                .alias("q"),
        ])
        .unwrap()
        .collect()
        .await
    }

    /// Asserts that `res` is a single batch with one column and `rows` rows,
    /// and returns that column as a `Float64Array`.
    fn single_f64_column(res: &[RecordBatch], rows: usize) -> &Float64Array {
        assert_eq!(res.len(), 1);
        assert_eq!(res[0].num_columns(), 1);
        assert_eq!(res[0].num_rows(), rows);
        as_float64_array(res[0].column(0)).unwrap()
    }

    #[tokio::test]
    async fn triangular_pdf_success() {
        let udf = pdf();
        let res = evaluate(&udf, make_records(sample_rows())).await.unwrap();

        let res_col = single_f64_column(&res, 4);
        assert_eq_float!(res_col.value(0), 0.3333333333333333);
        assert_eq_float!(res_col.value(1), 0.16666666666666666);
        assert!(res_col.value(2).is_nan());
        assert!(res_col.value(3).is_nan());
    }

    #[tokio::test]
    async fn triangular_pdf_failure_1() {
        let udf = pdf();
        // mode = -1.25 lies outside [min, max] = [0, 1].
        let recs = make_records(vec![(Some(1.0), Some(0.0), Some(1.0), Some(-1.25))]);

        match evaluate(&udf, recs).await {
            Err(DataFusionError::External(e)) => {
                let be = e.downcast::<TriangularError>().unwrap();
                assert_eq!(*be, TriangularError::ModeOutOfRange);
            }
            other => panic!("unexpected result: {other:?}"),
        }
    }

    #[tokio::test]
    async fn triangular_ln_pdf_success() {
        let mut ctx = SessionContext::new();
        register(&mut ctx).unwrap();
        let res = ctx
            .sql("SELECT triangular_ln_pdf(3.14, 3.0, 7.0, 6.0)")
            .await
            .unwrap()
            .collect()
            .await
            .unwrap();

        let res_col = single_f64_column(&res, 1);
        assert_eq_float!(res_col.value(0), -3.757872325600887);
    }

    #[tokio::test]
    async fn triangular_cdf_success() {
        let udf = cdf();
        let res = evaluate(&udf, make_records(sample_rows())).await.unwrap();

        let res_col = single_f64_column(&res, 4);
        assert_eq_float!(res_col.value(0), 0.6666666666666667);
        assert_eq_float!(res_col.value(1), 0.9166666666666666);
        assert!(res_col.value(2).is_nan());
        assert!(res_col.value(3).is_nan());
    }

    #[tokio::test]
    async fn triangular_sf_success() {
        let udf = sf();
        let res = evaluate(&udf, make_records(sample_rows())).await.unwrap();

        let res_col = single_f64_column(&res, 4);
        assert_eq_float!(res_col.value(0), 0.3333333333333333);
        assert_eq_float!(res_col.value(1), 0.08333333333333333);
        assert!(res_col.value(2).is_nan());
        assert!(res_col.value(3).is_nan());
    }
}