polars_core/chunked_array/ops/aggregate/
mod.rs

1//! Implementations of the ChunkAgg trait.
2mod quantile;
3mod var;
4
5use arrow::types::NativeType;
6use num_traits::{Float, One, ToPrimitive, Zero};
7use polars_compute::float_sum;
8use polars_compute::min_max::MinMaxKernel;
9use polars_compute::sum::{wrapping_sum_arr, WrappingSum};
10use polars_utils::min_max::MinMax;
11use polars_utils::sync::SyncPtr;
12pub use quantile::*;
13pub use var::*;
14
15use super::float_sorted_arg_max::{
16    float_arg_max_sorted_ascending, float_arg_max_sorted_descending,
17};
18use crate::chunked_array::ChunkedArray;
19use crate::datatypes::{BooleanChunked, PolarsNumericType};
20use crate::prelude::*;
21use crate::series::IsSorted;
22
/// Aggregations that reduce a whole array to a single [`Scalar`]. Those can be used in
/// broadcasting operations.
///
/// Every method has a default body that panics through `unimplemented!()`, so an implementor
/// only overrides the reductions that make sense for its data type.
pub trait ChunkAggSeries {
    /// Get the sum of the [`ChunkedArray`] as a [`Scalar`].
    ///
    /// # Panics
    /// The default implementation always panics.
    fn sum_reduce(&self) -> Scalar {
        unimplemented!()
    }
    /// Get the max of the [`ChunkedArray`] as a [`Scalar`].
    ///
    /// # Panics
    /// The default implementation always panics.
    fn max_reduce(&self) -> Scalar {
        unimplemented!()
    }
    /// Get the min of the [`ChunkedArray`] as a [`Scalar`].
    ///
    /// # Panics
    /// The default implementation always panics.
    fn min_reduce(&self) -> Scalar {
        unimplemented!()
    }
    /// Get the product of the [`ChunkedArray`] as a [`Scalar`].
    ///
    /// # Panics
    /// The default implementation always panics.
    fn prod_reduce(&self) -> Scalar {
        unimplemented!()
    }
}
42
43fn sum<T>(array: &PrimitiveArray<T>) -> T
44where
45    T: NumericNative + NativeType + WrappingSum,
46{
47    if array.null_count() == array.len() {
48        return T::default();
49    }
50
51    if T::is_float() {
52        unsafe {
53            if T::is_f32() {
54                let f32_arr =
55                    std::mem::transmute::<&PrimitiveArray<T>, &PrimitiveArray<f32>>(array);
56                let sum = float_sum::sum_arr_as_f32(f32_arr);
57                std::mem::transmute_copy::<f32, T>(&sum)
58            } else if T::is_f64() {
59                let f64_arr =
60                    std::mem::transmute::<&PrimitiveArray<T>, &PrimitiveArray<f64>>(array);
61                let sum = float_sum::sum_arr_as_f64(f64_arr);
62                std::mem::transmute_copy::<f64, T>(&sum)
63            } else {
64                unreachable!("only supported float types are f32 and f64");
65            }
66        }
67    } else {
68        wrapping_sum_arr(array)
69    }
70}
71
72impl<T> ChunkAgg<T::Native> for ChunkedArray<T>
73where
74    T: PolarsNumericType,
75    T::Native: WrappingSum,
76    PrimitiveArray<T::Native>: for<'a> MinMaxKernel<Scalar<'a> = T::Native>,
77{
78    fn sum(&self) -> Option<T::Native> {
79        Some(
80            self.downcast_iter()
81                .map(sum)
82                .fold(T::Native::zero(), |acc, v| acc + v),
83        )
84    }
85
86    fn _sum_as_f64(&self) -> f64 {
87        self.downcast_iter().map(float_sum::sum_arr_as_f64).sum()
88    }
89
90    fn min(&self) -> Option<T::Native> {
91        if self.null_count() == self.len() {
92            return None;
93        }
94
95        // There is at least one non-null value.
96
97        let result = match self.is_sorted_flag() {
98            IsSorted::Ascending => {
99                let idx = self.first_non_null().unwrap();
100                unsafe { self.get_unchecked(idx) }
101            },
102            IsSorted::Descending => {
103                let idx = self.last_non_null().unwrap();
104                unsafe { self.get_unchecked(idx) }
105            },
106            IsSorted::Not => self
107                .downcast_iter()
108                .filter_map(MinMaxKernel::min_ignore_nan_kernel)
109                .reduce(MinMax::min_ignore_nan),
110        };
111
112        result
113    }
114
115    fn max(&self) -> Option<T::Native> {
116        if self.null_count() == self.len() {
117            return None;
118        }
119        // There is at least one non-null value.
120
121        let result = match self.is_sorted_flag() {
122            IsSorted::Ascending => {
123                let idx = if T::get_dtype().is_float() {
124                    float_arg_max_sorted_ascending(self)
125                } else {
126                    self.last_non_null().unwrap()
127                };
128
129                unsafe { self.get_unchecked(idx) }
130            },
131            IsSorted::Descending => {
132                let idx = if T::get_dtype().is_float() {
133                    float_arg_max_sorted_descending(self)
134                } else {
135                    self.first_non_null().unwrap()
136                };
137
138                unsafe { self.get_unchecked(idx) }
139            },
140            IsSorted::Not => self
141                .downcast_iter()
142                .filter_map(MinMaxKernel::max_ignore_nan_kernel)
143                .reduce(MinMax::max_ignore_nan),
144        };
145
146        result
147    }
148
149    fn min_max(&self) -> Option<(T::Native, T::Native)> {
150        if self.null_count() == self.len() {
151            return None;
152        }
153        // There is at least one non-null value.
154
155        let result = match self.is_sorted_flag() {
156            IsSorted::Ascending => {
157                let min = unsafe { self.get_unchecked(self.first_non_null().unwrap()) };
158                let max = {
159                    let idx = if T::get_dtype().is_float() {
160                        float_arg_max_sorted_ascending(self)
161                    } else {
162                        self.last_non_null().unwrap()
163                    };
164
165                    unsafe { self.get_unchecked(idx) }
166                };
167                min.zip(max)
168            },
169            IsSorted::Descending => {
170                let min = unsafe { self.get_unchecked(self.last_non_null().unwrap()) };
171                let max = {
172                    let idx = if T::get_dtype().is_float() {
173                        float_arg_max_sorted_descending(self)
174                    } else {
175                        self.first_non_null().unwrap()
176                    };
177
178                    unsafe { self.get_unchecked(idx) }
179                };
180
181                min.zip(max)
182            },
183            IsSorted::Not => self
184                .downcast_iter()
185                .filter_map(MinMaxKernel::min_max_ignore_nan_kernel)
186                .reduce(|(min1, max1), (min2, max2)| {
187                    (
188                        MinMax::min_ignore_nan(min1, min2),
189                        MinMax::max_ignore_nan(max1, max2),
190                    )
191                }),
192        };
193
194        result
195    }
196
197    fn mean(&self) -> Option<f64> {
198        let count = self.len() - self.null_count();
199        if count == 0 {
200            return None;
201        }
202        Some(self._sum_as_f64() / count as f64)
203    }
204}
205
206/// Booleans are cast to 1 or 0.
207impl BooleanChunked {
208    pub fn sum(&self) -> Option<IdxSize> {
209        Some(if self.is_empty() {
210            0
211        } else {
212            self.downcast_iter()
213                .map(|arr| match arr.validity() {
214                    Some(validity) => {
215                        (arr.len() - (validity & arr.values()).unset_bits()) as IdxSize
216                    },
217                    None => (arr.len() - arr.values().unset_bits()) as IdxSize,
218                })
219                .sum()
220        })
221    }
222
223    pub fn min(&self) -> Option<bool> {
224        let nc = self.null_count();
225        let len = self.len();
226        if self.is_empty() || nc == len {
227            return None;
228        }
229        if nc == 0 {
230            if self.all() {
231                Some(true)
232            } else {
233                Some(false)
234            }
235        } else {
236            // we can unwrap as we already checked empty and all null above
237            if (self.sum().unwrap() + nc as IdxSize) == len as IdxSize {
238                Some(true)
239            } else {
240                Some(false)
241            }
242        }
243    }
244
245    pub fn max(&self) -> Option<bool> {
246        if self.is_empty() || self.null_count() == self.len() {
247            return None;
248        }
249        if self.any() {
250            Some(true)
251        } else {
252            Some(false)
253        }
254    }
255    pub fn mean(&self) -> Option<f64> {
256        if self.is_empty() || self.null_count() == self.len() {
257            return None;
258        }
259        self.sum()
260            .map(|sum| sum as f64 / (self.len() - self.null_count()) as f64)
261    }
262}
263
264// Needs the same trait bounds as the implementation of ChunkedArray<T> of dyn Series.
265impl<T> ChunkAggSeries for ChunkedArray<T>
266where
267    T: PolarsNumericType,
268    T::Native: WrappingSum,
269    PrimitiveArray<T::Native>: for<'a> MinMaxKernel<Scalar<'a> = T::Native>,
270    ChunkedArray<T>: IntoSeries,
271{
272    fn sum_reduce(&self) -> Scalar {
273        let v: Option<T::Native> = self.sum();
274        Scalar::new(T::get_dtype(), v.into())
275    }
276
277    fn max_reduce(&self) -> Scalar {
278        let v = ChunkAgg::max(self);
279        Scalar::new(T::get_dtype(), v.into())
280    }
281
282    fn min_reduce(&self) -> Scalar {
283        let v = ChunkAgg::min(self);
284        Scalar::new(T::get_dtype(), v.into())
285    }
286
287    fn prod_reduce(&self) -> Scalar {
288        let mut prod = T::Native::one();
289
290        for arr in self.downcast_iter() {
291            for v in arr.into_iter().flatten() {
292                prod = prod * *v
293            }
294        }
295        Scalar::new(T::get_dtype(), prod.into())
296    }
297}
298
299impl<T> VarAggSeries for ChunkedArray<T>
300where
301    T: PolarsIntegerType,
302    ChunkedArray<T>: ChunkVar,
303{
304    fn var_reduce(&self, ddof: u8) -> Scalar {
305        let v = self.var(ddof);
306        Scalar::new(DataType::Float64, v.into())
307    }
308
309    fn std_reduce(&self, ddof: u8) -> Scalar {
310        let v = self.std(ddof);
311        Scalar::new(DataType::Float64, v.into())
312    }
313}
314
315impl VarAggSeries for Float32Chunked {
316    fn var_reduce(&self, ddof: u8) -> Scalar {
317        let v = self.var(ddof).map(|v| v as f32);
318        Scalar::new(DataType::Float32, v.into())
319    }
320
321    fn std_reduce(&self, ddof: u8) -> Scalar {
322        let v = self.std(ddof).map(|v| v as f32);
323        Scalar::new(DataType::Float32, v.into())
324    }
325}
326
327impl VarAggSeries for Float64Chunked {
328    fn var_reduce(&self, ddof: u8) -> Scalar {
329        let v = self.var(ddof);
330        Scalar::new(DataType::Float64, v.into())
331    }
332
333    fn std_reduce(&self, ddof: u8) -> Scalar {
334        let v = self.std(ddof);
335        Scalar::new(DataType::Float64, v.into())
336    }
337}
338
339impl<T> QuantileAggSeries for ChunkedArray<T>
340where
341    T: PolarsIntegerType,
342    T::Native: Ord + WrappingSum,
343{
344    fn quantile_reduce(&self, quantile: f64, method: QuantileMethod) -> PolarsResult<Scalar> {
345        let v = self.quantile(quantile, method)?;
346        Ok(Scalar::new(DataType::Float64, v.into()))
347    }
348
349    fn median_reduce(&self) -> Scalar {
350        let v = self.median();
351        Scalar::new(DataType::Float64, v.into())
352    }
353}
354
355impl QuantileAggSeries for Float32Chunked {
356    fn quantile_reduce(&self, quantile: f64, method: QuantileMethod) -> PolarsResult<Scalar> {
357        let v = self.quantile(quantile, method)?;
358        Ok(Scalar::new(DataType::Float32, v.into()))
359    }
360
361    fn median_reduce(&self) -> Scalar {
362        let v = self.median();
363        Scalar::new(DataType::Float32, v.into())
364    }
365}
366
367impl QuantileAggSeries for Float64Chunked {
368    fn quantile_reduce(&self, quantile: f64, method: QuantileMethod) -> PolarsResult<Scalar> {
369        let v = self.quantile(quantile, method)?;
370        Ok(Scalar::new(DataType::Float64, v.into()))
371    }
372
373    fn median_reduce(&self) -> Scalar {
374        let v = self.median();
375        Scalar::new(DataType::Float64, v.into())
376    }
377}
378
379impl ChunkAggSeries for BooleanChunked {
380    fn sum_reduce(&self) -> Scalar {
381        let v = self.sum();
382        Scalar::new(IDX_DTYPE, v.into())
383    }
384    fn max_reduce(&self) -> Scalar {
385        let v = self.max();
386        Scalar::new(DataType::Boolean, v.into())
387    }
388    fn min_reduce(&self) -> Scalar {
389        let v = self.min();
390        Scalar::new(DataType::Boolean, v.into())
391    }
392}
393
394impl StringChunked {
395    pub(crate) fn max_str(&self) -> Option<&str> {
396        if self.is_empty() {
397            return None;
398        }
399        match self.is_sorted_flag() {
400            IsSorted::Ascending => {
401                self.last_non_null().and_then(|idx| {
402                    // SAFETY: last_non_null returns in bound index
403                    unsafe { self.get_unchecked(idx) }
404                })
405            },
406            IsSorted::Descending => {
407                self.first_non_null().and_then(|idx| {
408                    // SAFETY: first_non_null returns in bound index
409                    unsafe { self.get_unchecked(idx) }
410                })
411            },
412            IsSorted::Not => self
413                .downcast_iter()
414                .filter_map(MinMaxKernel::max_ignore_nan_kernel)
415                .reduce(MinMax::max_ignore_nan),
416        }
417    }
418    pub(crate) fn min_str(&self) -> Option<&str> {
419        if self.is_empty() {
420            return None;
421        }
422        match self.is_sorted_flag() {
423            IsSorted::Ascending => {
424                self.first_non_null().and_then(|idx| {
425                    // SAFETY: first_non_null returns in bound index
426                    unsafe { self.get_unchecked(idx) }
427                })
428            },
429            IsSorted::Descending => {
430                self.last_non_null().and_then(|idx| {
431                    // SAFETY: last_non_null returns in bound index
432                    unsafe { self.get_unchecked(idx) }
433                })
434            },
435            IsSorted::Not => self
436                .downcast_iter()
437                .filter_map(MinMaxKernel::min_ignore_nan_kernel)
438                .reduce(MinMax::min_ignore_nan),
439        }
440    }
441}
442
443impl ChunkAggSeries for StringChunked {
444    fn max_reduce(&self) -> Scalar {
445        let av: AnyValue = self.max_str().into();
446        Scalar::new(DataType::String, av.into_static())
447    }
448    fn min_reduce(&self) -> Scalar {
449        let av: AnyValue = self.min_str().into();
450        Scalar::new(DataType::String, av.into_static())
451    }
452}
453
#[cfg(feature = "dtype-categorical")]
impl CategoricalChunked {
    /// Smallest category, or `None` when the array is empty or entirely null.
    ///
    /// With lexical ordering the category strings are compared; otherwise the minimum
    /// physical index is looked up in the reverse mapping.
    fn min_categorical(&self) -> Option<&str> {
        if self.is_empty() || self.null_count() == self.len() {
            return None;
        }

        if !self.uses_lexical_ordering() {
            // SAFETY:
            // Indices are in bounds
            return self
                .physical()
                .min()
                .map(|el| unsafe { self.get_rev_map().get_unchecked(el) });
        }

        // Fast path where all categories are used
        if self._can_fast_unique() {
            return self.get_rev_map().get_categories().min_ignore_nan_kernel();
        }

        let rev_map = self.get_rev_map();
        // SAFETY:
        // Indices are in bounds
        self.physical()
            .iter()
            .flatten()
            .map(|el: u32| unsafe { rev_map.get_unchecked(el) })
            .min()
    }

    /// Largest category, or `None` when the array is empty or entirely null.
    ///
    /// With lexical ordering the category strings are compared; otherwise the maximum
    /// physical index is looked up in the reverse mapping.
    fn max_categorical(&self) -> Option<&str> {
        if self.is_empty() || self.null_count() == self.len() {
            return None;
        }

        if !self.uses_lexical_ordering() {
            // SAFETY:
            // Indices are in bounds
            return self
                .physical()
                .max()
                .map(|el| unsafe { self.get_rev_map().get_unchecked(el) });
        }

        // Fast path where all categories are used
        if self._can_fast_unique() {
            return self.get_rev_map().get_categories().max_ignore_nan_kernel();
        }

        let rev_map = self.get_rev_map();
        // SAFETY:
        // Indices are in bounds
        self.physical()
            .iter()
            .flatten()
            .map(|el: u32| unsafe { rev_map.get_unchecked(el) })
            .max()
    }
}
512
#[cfg(feature = "dtype-categorical")]
impl ChunkAggSeries for CategoricalChunked {
    /// Minimum as a scalar: an enum value of this dtype for `Enum`, a `String` for
    /// `Categorical`.
    fn min_reduce(&self) -> Scalar {
        match self.dtype() {
            DataType::Enum(r, _) => {
                let Some(v) = self.physical().min() else {
                    return Scalar::new(self.dtype().clone(), AnyValue::Null);
                };
                let rev_map = r.as_ref().unwrap();
                let RevMapping::Local(arr, _) = &**rev_map else {
                    unreachable!()
                };
                Scalar::new(
                    self.dtype().clone(),
                    AnyValue::EnumOwned(
                        v,
                        rev_map.clone(),
                        SyncPtr::from_const(arr as *const _),
                    ),
                )
            },
            DataType::Categorical(_, _) => {
                let smallest: AnyValue = self.min_categorical().into();
                Scalar::new(DataType::String, smallest.into_static())
            },
            _ => unreachable!(),
        }
    }

    /// Maximum as a scalar: an enum value of this dtype for `Enum`, a `String` for
    /// `Categorical`.
    fn max_reduce(&self) -> Scalar {
        match self.dtype() {
            DataType::Enum(r, _) => {
                let Some(v) = self.physical().max() else {
                    return Scalar::new(self.dtype().clone(), AnyValue::Null);
                };
                let rev_map = r.as_ref().unwrap();
                let RevMapping::Local(arr, _) = &**rev_map else {
                    unreachable!()
                };
                Scalar::new(
                    self.dtype().clone(),
                    AnyValue::EnumOwned(
                        v,
                        rev_map.clone(),
                        SyncPtr::from_const(arr as *const _),
                    ),
                )
            },
            DataType::Categorical(_, _) => {
                let largest: AnyValue = self.max_categorical().into();
                Scalar::new(DataType::String, largest.into_static())
            },
            _ => unreachable!(),
        }
    }
}
566
567impl BinaryChunked {
568    pub fn max_binary(&self) -> Option<&[u8]> {
569        if self.is_empty() {
570            return None;
571        }
572        match self.is_sorted_flag() {
573            IsSorted::Ascending => {
574                self.last_non_null().and_then(|idx| {
575                    // SAFETY: last_non_null returns in bound index.
576                    unsafe { self.get_unchecked(idx) }
577                })
578            },
579            IsSorted::Descending => {
580                self.first_non_null().and_then(|idx| {
581                    // SAFETY: first_non_null returns in bound index.
582                    unsafe { self.get_unchecked(idx) }
583                })
584            },
585            IsSorted::Not => self
586                .downcast_iter()
587                .filter_map(MinMaxKernel::max_ignore_nan_kernel)
588                .reduce(MinMax::max_ignore_nan),
589        }
590    }
591
592    pub fn min_binary(&self) -> Option<&[u8]> {
593        if self.is_empty() {
594            return None;
595        }
596        match self.is_sorted_flag() {
597            IsSorted::Ascending => {
598                self.first_non_null().and_then(|idx| {
599                    // SAFETY: first_non_null returns in bound index.
600                    unsafe { self.get_unchecked(idx) }
601                })
602            },
603            IsSorted::Descending => {
604                self.last_non_null().and_then(|idx| {
605                    // SAFETY: last_non_null returns in bound index.
606                    unsafe { self.get_unchecked(idx) }
607                })
608            },
609            IsSorted::Not => self
610                .downcast_iter()
611                .filter_map(MinMaxKernel::min_ignore_nan_kernel)
612                .reduce(MinMax::min_ignore_nan),
613        }
614    }
615}
616
617impl ChunkAggSeries for BinaryChunked {
618    fn sum_reduce(&self) -> Scalar {
619        unimplemented!()
620    }
621    fn max_reduce(&self) -> Scalar {
622        let av: AnyValue = self.max_binary().into();
623        Scalar::new(self.dtype().clone(), av.into_static())
624    }
625    fn min_reduce(&self) -> Scalar {
626        let av: AnyValue = self.min_binary().into();
627        Scalar::new(self.dtype().clone(), av.into_static())
628    }
629}
630
// Object arrays override none of the reductions, so every `ChunkAggSeries` method falls back
// to the trait default, which panics with `unimplemented!()`.
#[cfg(feature = "object")]
impl<T: PolarsObject> ChunkAggSeries for ObjectChunked<T> {}
633
#[cfg(test)]
mod test {
    use crate::prelude::*;

    /// Every quantile interpolation method exercised by the quantile tests below.
    const ALL_METHODS: [QuantileMethod; 6] = [
        QuantileMethod::Nearest,
        QuantileMethod::Lower,
        QuantileMethod::Higher,
        QuantileMethod::Midpoint,
        QuantileMethod::Linear,
        QuantileMethod::Equiprobable,
    ];

    #[test]
    fn test_var() {
        // Validated with numpy. Note that numpy uses ddof as an argument which
        // influences results. The default ddof=0, we chose ddof=1, which is
        // standard in statistics.
        let ca1 = Int32Chunked::new(PlSmallStr::EMPTY, &[5, 8, 9, 5, 0]);
        let ca2 = Int32Chunked::new(
            PlSmallStr::EMPTY,
            &[
                Some(5),
                None,
                Some(8),
                Some(9),
                None,
                Some(5),
                Some(0),
                None,
            ],
        );
        // Nulls must not influence the result, so both arrays give the same answer.
        for ca in &[ca1, ca2] {
            let out = ca.var(1);
            assert_eq!(out, Some(12.3));
            let out = ca.std(1).unwrap();
            assert!((3.5071355833500366 - out).abs() < 0.000000001);
        }
    }

    #[test]
    fn test_agg_float() {
        // The position of the NaN must not change the minimum.
        let ca1 = Float32Chunked::new(PlSmallStr::from_static("a"), &[1.0, f32::NAN]);
        let ca2 = Float32Chunked::new(PlSmallStr::from_static("b"), &[f32::NAN, 1.0]);
        assert_eq!(ca1.min(), ca2.min());
        let ca1 = Float64Chunked::new(PlSmallStr::from_static("a"), &[1.0, f64::NAN]);
        let ca2 = Float64Chunked::from_slice(PlSmallStr::from_static("b"), &[f64::NAN, 1.0]);
        assert_eq!(ca1.min(), ca2.min());
    }

    #[test]
    fn test_median() {
        let ca = UInt32Chunked::new(
            PlSmallStr::from_static("a"),
            &[Some(2), Some(1), None, Some(3), Some(5), None, Some(4)],
        );
        assert_eq!(ca.median(), Some(3.0));
        let ca = UInt32Chunked::new(
            PlSmallStr::from_static("a"),
            &[
                None,
                Some(7),
                Some(6),
                Some(2),
                Some(1),
                None,
                Some(3),
                Some(5),
                None,
                Some(4),
            ],
        );
        assert_eq!(ca.median(), Some(4.0));

        let ca = Float32Chunked::from_slice(
            PlSmallStr::EMPTY,
            &[
                0.166189, 0.166559, 0.168517, 0.169393, 0.175272, 0.233167, 0.238787, 0.266562,
                0.26903, 0.285792, 0.292801, 0.293429, 0.301706, 0.308534, 0.331489, 0.346095,
                0.367644, 0.369939, 0.372074, 0.41014, 0.415789, 0.421781, 0.427725, 0.465363,
                0.500208, 2.621727, 2.803311, 3.868526,
            ],
        );
        assert!((ca.median().unwrap() - 0.3200115).abs() < 0.0001)
    }

    #[test]
    fn test_mean() {
        let ca = Float32Chunked::new(PlSmallStr::EMPTY, &[Some(1.0), Some(2.0), None]);
        assert_eq!(ca.mean().unwrap(), 1.5);
        assert_eq!(
            ca.into_series()
                .mean_reduce()
                .value()
                .extract::<f32>()
                .unwrap(),
            1.5
        );
        // all null values case
        let ca = Float32Chunked::full_null(PlSmallStr::EMPTY, 3);
        assert_eq!(ca.mean(), None);
        assert_eq!(
            ca.into_series().mean_reduce().value().extract::<f32>(),
            None
        );
    }

    #[test]
    fn test_quantile_all_null() {
        let test_f32 = Float32Chunked::from_slice_options(PlSmallStr::EMPTY, &[None, None, None]);
        let test_i32 = Int32Chunked::from_slice_options(PlSmallStr::EMPTY, &[None, None, None]);
        let test_f64 = Float64Chunked::from_slice_options(PlSmallStr::EMPTY, &[None, None, None]);
        let test_i64 = Int64Chunked::from_slice_options(PlSmallStr::EMPTY, &[None, None, None]);

        for method in ALL_METHODS {
            assert_eq!(test_f32.quantile(0.9, method).unwrap(), None);
            assert_eq!(test_i32.quantile(0.9, method).unwrap(), None);
            assert_eq!(test_f64.quantile(0.9, method).unwrap(), None);
            assert_eq!(test_i64.quantile(0.9, method).unwrap(), None);
        }
    }

    #[test]
    fn test_quantile_single_value() {
        let test_f32 = Float32Chunked::from_slice_options(PlSmallStr::EMPTY, &[Some(1.0)]);
        let test_i32 = Int32Chunked::from_slice_options(PlSmallStr::EMPTY, &[Some(1)]);
        let test_f64 = Float64Chunked::from_slice_options(PlSmallStr::EMPTY, &[Some(1.0)]);
        let test_i64 = Int64Chunked::from_slice_options(PlSmallStr::EMPTY, &[Some(1)]);

        for method in ALL_METHODS {
            assert_eq!(test_f32.quantile(0.5, method).unwrap(), Some(1.0));
            assert_eq!(test_i32.quantile(0.5, method).unwrap(), Some(1.0));
            assert_eq!(test_f64.quantile(0.5, method).unwrap(), Some(1.0));
            assert_eq!(test_i64.quantile(0.5, method).unwrap(), Some(1.0));
        }
    }

    #[test]
    fn test_quantile_min_max() {
        let test_f32 = Float32Chunked::from_slice_options(
            PlSmallStr::EMPTY,
            &[None, Some(1f32), Some(5f32), Some(1f32)],
        );
        let test_i32 = Int32Chunked::from_slice_options(
            PlSmallStr::EMPTY,
            &[None, Some(1i32), Some(5i32), Some(1i32)],
        );
        let test_f64 = Float64Chunked::from_slice_options(
            PlSmallStr::EMPTY,
            &[None, Some(1f64), Some(5f64), Some(1f64)],
        );
        let test_i64 = Int64Chunked::from_slice_options(
            PlSmallStr::EMPTY,
            &[None, Some(1i64), Some(5i64), Some(1i64)],
        );

        // The 0- and 1-quantiles must coincide with min and max for every method.
        for method in ALL_METHODS {
            assert_eq!(test_f32.quantile(0.0, method).unwrap(), test_f32.min());
            assert_eq!(test_f32.quantile(1.0, method).unwrap(), test_f32.max());

            assert_eq!(
                test_i32.quantile(0.0, method).unwrap().unwrap(),
                test_i32.min().unwrap() as f64
            );
            assert_eq!(
                test_i32.quantile(1.0, method).unwrap().unwrap(),
                test_i32.max().unwrap() as f64
            );

            assert_eq!(test_f64.quantile(0.0, method).unwrap(), test_f64.min());
            assert_eq!(test_f64.quantile(1.0, method).unwrap(), test_f64.max());
            assert_eq!(test_f64.quantile(0.5, method).unwrap(), test_f64.median());

            assert_eq!(
                test_i64.quantile(0.0, method).unwrap().unwrap(),
                test_i64.min().unwrap() as f64
            );
            assert_eq!(
                test_i64.quantile(1.0, method).unwrap().unwrap(),
                test_i64.max().unwrap() as f64
            );
        }
    }

    #[test]
    fn test_quantile() {
        let ca = UInt32Chunked::new(
            PlSmallStr::from_static("a"),
            &[Some(2), Some(1), None, Some(3), Some(5), None, Some(4)],
        );

        assert_eq!(
            ca.quantile(0.1, QuantileMethod::Nearest).unwrap(),
            Some(1.0)
        );
        assert_eq!(
            ca.quantile(0.9, QuantileMethod::Nearest).unwrap(),
            Some(5.0)
        );
        assert_eq!(
            ca.quantile(0.6, QuantileMethod::Nearest).unwrap(),
            Some(3.0)
        );

        assert_eq!(ca.quantile(0.1, QuantileMethod::Lower).unwrap(), Some(1.0));
        assert_eq!(ca.quantile(0.9, QuantileMethod::Lower).unwrap(), Some(4.0));
        assert_eq!(ca.quantile(0.6, QuantileMethod::Lower).unwrap(), Some(3.0));

        assert_eq!(ca.quantile(0.1, QuantileMethod::Higher).unwrap(), Some(2.0));
        assert_eq!(ca.quantile(0.9, QuantileMethod::Higher).unwrap(), Some(5.0));
        assert_eq!(ca.quantile(0.6, QuantileMethod::Higher).unwrap(), Some(4.0));

        assert_eq!(
            ca.quantile(0.1, QuantileMethod::Midpoint).unwrap(),
            Some(1.5)
        );
        assert_eq!(
            ca.quantile(0.9, QuantileMethod::Midpoint).unwrap(),
            Some(4.5)
        );
        assert_eq!(
            ca.quantile(0.6, QuantileMethod::Midpoint).unwrap(),
            Some(3.5)
        );

        assert_eq!(ca.quantile(0.1, QuantileMethod::Linear).unwrap(), Some(1.4));
        assert_eq!(ca.quantile(0.9, QuantileMethod::Linear).unwrap(), Some(4.6));
        assert!(
            (ca.quantile(0.6, QuantileMethod::Linear).unwrap().unwrap() - 3.4).abs() < 0.0000001
        );

        assert_eq!(
            ca.quantile(0.15, QuantileMethod::Equiprobable).unwrap(),
            Some(1.0)
        );
        assert_eq!(
            ca.quantile(0.25, QuantileMethod::Equiprobable).unwrap(),
            Some(2.0)
        );
        assert_eq!(
            ca.quantile(0.6, QuantileMethod::Equiprobable).unwrap(),
            Some(3.0)
        );

        let ca = UInt32Chunked::new(
            PlSmallStr::from_static("a"),
            &[
                None,
                Some(7),
                Some(6),
                Some(2),
                Some(1),
                None,
                Some(3),
                Some(5),
                None,
                Some(4),
            ],
        );

        assert_eq!(
            ca.quantile(0.1, QuantileMethod::Nearest).unwrap(),
            Some(2.0)
        );
        assert_eq!(
            ca.quantile(0.9, QuantileMethod::Nearest).unwrap(),
            Some(6.0)
        );
        assert_eq!(
            ca.quantile(0.6, QuantileMethod::Nearest).unwrap(),
            Some(5.0)
        );

        assert_eq!(ca.quantile(0.1, QuantileMethod::Lower).unwrap(), Some(1.0));
        assert_eq!(ca.quantile(0.9, QuantileMethod::Lower).unwrap(), Some(6.0));
        assert_eq!(ca.quantile(0.6, QuantileMethod::Lower).unwrap(), Some(4.0));

        assert_eq!(ca.quantile(0.1, QuantileMethod::Higher).unwrap(), Some(2.0));
        assert_eq!(ca.quantile(0.9, QuantileMethod::Higher).unwrap(), Some(7.0));
        assert_eq!(ca.quantile(0.6, QuantileMethod::Higher).unwrap(), Some(5.0));

        assert_eq!(
            ca.quantile(0.1, QuantileMethod::Midpoint).unwrap(),
            Some(1.5)
        );
        assert_eq!(
            ca.quantile(0.9, QuantileMethod::Midpoint).unwrap(),
            Some(6.5)
        );
        assert_eq!(
            ca.quantile(0.6, QuantileMethod::Midpoint).unwrap(),
            Some(4.5)
        );

        assert_eq!(ca.quantile(0.1, QuantileMethod::Linear).unwrap(), Some(1.6));
        assert_eq!(ca.quantile(0.9, QuantileMethod::Linear).unwrap(), Some(6.4));
        assert_eq!(ca.quantile(0.6, QuantileMethod::Linear).unwrap(), Some(4.6));

        assert_eq!(
            ca.quantile(0.14, QuantileMethod::Equiprobable).unwrap(),
            Some(1.0)
        );
        assert_eq!(
            ca.quantile(0.15, QuantileMethod::Equiprobable).unwrap(),
            Some(2.0)
        );
        assert_eq!(
            ca.quantile(0.6, QuantileMethod::Equiprobable).unwrap(),
            Some(5.0)
        );
    }
}