polars_core/chunked_array/ops/aggregate/
mod.rs

1//! Implementations of the ChunkAgg trait.
2mod quantile;
3mod var;
4
5use arrow::types::NativeType;
6use num_traits::{Float, One, ToPrimitive, Zero};
7use polars_compute::float_sum;
8use polars_compute::min_max::MinMaxKernel;
9use polars_compute::rolling::QuantileMethod;
10use polars_compute::sum::{WrappingSum, wrapping_sum_arr};
11use polars_utils::min_max::MinMax;
12use polars_utils::sync::SyncPtr;
13pub use quantile::*;
14pub use var::*;
15
16use super::float_sorted_arg_max::{
17    float_arg_max_sorted_ascending, float_arg_max_sorted_descending,
18};
19use crate::chunked_array::ChunkedArray;
20use crate::datatypes::{BooleanChunked, PolarsNumericType};
21use crate::prelude::*;
22use crate::series::IsSorted;
23
24/// Aggregations that return [`Series`] of unit length. Those can be used in broadcasting operations.
25pub trait ChunkAggSeries {
26    /// Get the sum of the [`ChunkedArray`] as a new [`Series`] of length 1.
27    fn sum_reduce(&self) -> Scalar {
28        unimplemented!()
29    }
30    /// Get the max of the [`ChunkedArray`] as a new [`Series`] of length 1.
31    fn max_reduce(&self) -> Scalar {
32        unimplemented!()
33    }
34    /// Get the min of the [`ChunkedArray`] as a new [`Series`] of length 1.
35    fn min_reduce(&self) -> Scalar {
36        unimplemented!()
37    }
38    /// Get the product of the [`ChunkedArray`] as a new [`Series`] of length 1.
39    fn prod_reduce(&self) -> Scalar {
40        unimplemented!()
41    }
42}
43
44fn sum<T>(array: &PrimitiveArray<T>) -> T
45where
46    T: NumericNative + NativeType + WrappingSum,
47{
48    if array.null_count() == array.len() {
49        return T::default();
50    }
51
52    if T::is_float() {
53        unsafe {
54            if T::is_f32() {
55                let f32_arr =
56                    std::mem::transmute::<&PrimitiveArray<T>, &PrimitiveArray<f32>>(array);
57                let sum = float_sum::sum_arr_as_f32(f32_arr);
58                std::mem::transmute_copy::<f32, T>(&sum)
59            } else if T::is_f64() {
60                let f64_arr =
61                    std::mem::transmute::<&PrimitiveArray<T>, &PrimitiveArray<f64>>(array);
62                let sum = float_sum::sum_arr_as_f64(f64_arr);
63                std::mem::transmute_copy::<f64, T>(&sum)
64            } else {
65                unreachable!("only supported float types are f32 and f64");
66            }
67        }
68    } else {
69        wrapping_sum_arr(array)
70    }
71}
72
73impl<T> ChunkAgg<T::Native> for ChunkedArray<T>
74where
75    T: PolarsNumericType,
76    T::Native: WrappingSum,
77    PrimitiveArray<T::Native>: for<'a> MinMaxKernel<Scalar<'a> = T::Native>,
78{
79    fn sum(&self) -> Option<T::Native> {
80        Some(
81            self.downcast_iter()
82                .map(sum)
83                .fold(T::Native::zero(), |acc, v| acc + v),
84        )
85    }
86
87    fn _sum_as_f64(&self) -> f64 {
88        self.downcast_iter().map(float_sum::sum_arr_as_f64).sum()
89    }
90
91    fn min(&self) -> Option<T::Native> {
92        if self.null_count() == self.len() {
93            return None;
94        }
95
96        // There is at least one non-null value.
97
98        match self.is_sorted_flag() {
99            IsSorted::Ascending => {
100                let idx = self.first_non_null().unwrap();
101                unsafe { self.get_unchecked(idx) }
102            },
103            IsSorted::Descending => {
104                let idx = self.last_non_null().unwrap();
105                unsafe { self.get_unchecked(idx) }
106            },
107            IsSorted::Not => self
108                .downcast_iter()
109                .filter_map(MinMaxKernel::min_ignore_nan_kernel)
110                .reduce(MinMax::min_ignore_nan),
111        }
112    }
113
114    fn max(&self) -> Option<T::Native> {
115        if self.null_count() == self.len() {
116            return None;
117        }
118        // There is at least one non-null value.
119
120        match self.is_sorted_flag() {
121            IsSorted::Ascending => {
122                let idx = if T::get_static_dtype().is_float() {
123                    float_arg_max_sorted_ascending(self)
124                } else {
125                    self.last_non_null().unwrap()
126                };
127
128                unsafe { self.get_unchecked(idx) }
129            },
130            IsSorted::Descending => {
131                let idx = if T::get_static_dtype().is_float() {
132                    float_arg_max_sorted_descending(self)
133                } else {
134                    self.first_non_null().unwrap()
135                };
136
137                unsafe { self.get_unchecked(idx) }
138            },
139            IsSorted::Not => self
140                .downcast_iter()
141                .filter_map(MinMaxKernel::max_ignore_nan_kernel)
142                .reduce(MinMax::max_ignore_nan),
143        }
144    }
145
146    fn min_max(&self) -> Option<(T::Native, T::Native)> {
147        if self.null_count() == self.len() {
148            return None;
149        }
150        // There is at least one non-null value.
151
152        match self.is_sorted_flag() {
153            IsSorted::Ascending => {
154                let min = unsafe { self.get_unchecked(self.first_non_null().unwrap()) };
155                let max = {
156                    let idx = if T::get_static_dtype().is_float() {
157                        float_arg_max_sorted_ascending(self)
158                    } else {
159                        self.last_non_null().unwrap()
160                    };
161
162                    unsafe { self.get_unchecked(idx) }
163                };
164                min.zip(max)
165            },
166            IsSorted::Descending => {
167                let min = unsafe { self.get_unchecked(self.last_non_null().unwrap()) };
168                let max = {
169                    let idx = if T::get_static_dtype().is_float() {
170                        float_arg_max_sorted_descending(self)
171                    } else {
172                        self.first_non_null().unwrap()
173                    };
174
175                    unsafe { self.get_unchecked(idx) }
176                };
177
178                min.zip(max)
179            },
180            IsSorted::Not => self
181                .downcast_iter()
182                .filter_map(MinMaxKernel::min_max_ignore_nan_kernel)
183                .reduce(|(min1, max1), (min2, max2)| {
184                    (
185                        MinMax::min_ignore_nan(min1, min2),
186                        MinMax::max_ignore_nan(max1, max2),
187                    )
188                }),
189        }
190    }
191
192    fn mean(&self) -> Option<f64> {
193        let count = self.len() - self.null_count();
194        if count == 0 {
195            return None;
196        }
197        Some(self._sum_as_f64() / count as f64)
198    }
199}
200
201/// Booleans are cast to 1 or 0.
202impl BooleanChunked {
203    pub fn sum(&self) -> Option<IdxSize> {
204        Some(if self.is_empty() {
205            0
206        } else {
207            self.downcast_iter()
208                .map(|arr| match arr.validity() {
209                    Some(validity) => {
210                        (arr.len() - (validity & arr.values()).unset_bits()) as IdxSize
211                    },
212                    None => (arr.len() - arr.values().unset_bits()) as IdxSize,
213                })
214                .sum()
215        })
216    }
217
218    pub fn min(&self) -> Option<bool> {
219        let nc = self.null_count();
220        let len = self.len();
221        if self.is_empty() || nc == len {
222            return None;
223        }
224        if nc == 0 {
225            if self.all() { Some(true) } else { Some(false) }
226        } else {
227            // we can unwrap as we already checked empty and all null above
228            if (self.sum().unwrap() + nc as IdxSize) == len as IdxSize {
229                Some(true)
230            } else {
231                Some(false)
232            }
233        }
234    }
235
236    pub fn max(&self) -> Option<bool> {
237        if self.is_empty() || self.null_count() == self.len() {
238            return None;
239        }
240        if self.any() { Some(true) } else { Some(false) }
241    }
242    pub fn mean(&self) -> Option<f64> {
243        if self.is_empty() || self.null_count() == self.len() {
244            return None;
245        }
246        self.sum()
247            .map(|sum| sum as f64 / (self.len() - self.null_count()) as f64)
248    }
249}
250
251// Needs the same trait bounds as the implementation of ChunkedArray<T> of dyn Series.
252impl<T> ChunkAggSeries for ChunkedArray<T>
253where
254    T: PolarsNumericType,
255    T::Native: WrappingSum,
256    PrimitiveArray<T::Native>: for<'a> MinMaxKernel<Scalar<'a> = T::Native>,
257{
258    fn sum_reduce(&self) -> Scalar {
259        let v: Option<T::Native> = self.sum();
260        Scalar::new(T::get_static_dtype(), v.into())
261    }
262
263    fn max_reduce(&self) -> Scalar {
264        let v = ChunkAgg::max(self);
265        Scalar::new(T::get_static_dtype(), v.into())
266    }
267
268    fn min_reduce(&self) -> Scalar {
269        let v = ChunkAgg::min(self);
270        Scalar::new(T::get_static_dtype(), v.into())
271    }
272
273    fn prod_reduce(&self) -> Scalar {
274        let mut prod = T::Native::one();
275
276        for arr in self.downcast_iter() {
277            for v in arr.into_iter().flatten() {
278                prod = prod * *v
279            }
280        }
281        Scalar::new(T::get_static_dtype(), prod.into())
282    }
283}
284
285impl<T> VarAggSeries for ChunkedArray<T>
286where
287    T: PolarsIntegerType,
288    ChunkedArray<T>: ChunkVar,
289{
290    fn var_reduce(&self, ddof: u8) -> Scalar {
291        let v = self.var(ddof);
292        Scalar::new(DataType::Float64, v.into())
293    }
294
295    fn std_reduce(&self, ddof: u8) -> Scalar {
296        let v = self.std(ddof);
297        Scalar::new(DataType::Float64, v.into())
298    }
299}
300
301impl VarAggSeries for Float32Chunked {
302    fn var_reduce(&self, ddof: u8) -> Scalar {
303        let v = self.var(ddof).map(|v| v as f32);
304        Scalar::new(DataType::Float32, v.into())
305    }
306
307    fn std_reduce(&self, ddof: u8) -> Scalar {
308        let v = self.std(ddof).map(|v| v as f32);
309        Scalar::new(DataType::Float32, v.into())
310    }
311}
312
313impl VarAggSeries for Float64Chunked {
314    fn var_reduce(&self, ddof: u8) -> Scalar {
315        let v = self.var(ddof);
316        Scalar::new(DataType::Float64, v.into())
317    }
318
319    fn std_reduce(&self, ddof: u8) -> Scalar {
320        let v = self.std(ddof);
321        Scalar::new(DataType::Float64, v.into())
322    }
323}
324
325impl<T> QuantileAggSeries for ChunkedArray<T>
326where
327    T: PolarsIntegerType,
328    T::Native: Ord + WrappingSum,
329{
330    fn quantile_reduce(&self, quantile: f64, method: QuantileMethod) -> PolarsResult<Scalar> {
331        let v = self.quantile(quantile, method)?;
332        Ok(Scalar::new(DataType::Float64, v.into()))
333    }
334
335    fn median_reduce(&self) -> Scalar {
336        let v = self.median();
337        Scalar::new(DataType::Float64, v.into())
338    }
339}
340
341impl QuantileAggSeries for Float32Chunked {
342    fn quantile_reduce(&self, quantile: f64, method: QuantileMethod) -> PolarsResult<Scalar> {
343        let v = self.quantile(quantile, method)?;
344        Ok(Scalar::new(DataType::Float32, v.into()))
345    }
346
347    fn median_reduce(&self) -> Scalar {
348        let v = self.median();
349        Scalar::new(DataType::Float32, v.into())
350    }
351}
352
353impl QuantileAggSeries for Float64Chunked {
354    fn quantile_reduce(&self, quantile: f64, method: QuantileMethod) -> PolarsResult<Scalar> {
355        let v = self.quantile(quantile, method)?;
356        Ok(Scalar::new(DataType::Float64, v.into()))
357    }
358
359    fn median_reduce(&self) -> Scalar {
360        let v = self.median();
361        Scalar::new(DataType::Float64, v.into())
362    }
363}
364
365impl ChunkAggSeries for BooleanChunked {
366    fn sum_reduce(&self) -> Scalar {
367        let v = self.sum();
368        Scalar::new(IDX_DTYPE, v.into())
369    }
370    fn max_reduce(&self) -> Scalar {
371        let v = self.max();
372        Scalar::new(DataType::Boolean, v.into())
373    }
374    fn min_reduce(&self) -> Scalar {
375        let v = self.min();
376        Scalar::new(DataType::Boolean, v.into())
377    }
378}
379
380impl StringChunked {
381    pub(crate) fn max_str(&self) -> Option<&str> {
382        if self.is_empty() {
383            return None;
384        }
385        match self.is_sorted_flag() {
386            IsSorted::Ascending => {
387                self.last_non_null().and_then(|idx| {
388                    // SAFETY: last_non_null returns in bound index
389                    unsafe { self.get_unchecked(idx) }
390                })
391            },
392            IsSorted::Descending => {
393                self.first_non_null().and_then(|idx| {
394                    // SAFETY: first_non_null returns in bound index
395                    unsafe { self.get_unchecked(idx) }
396                })
397            },
398            IsSorted::Not => self
399                .downcast_iter()
400                .filter_map(MinMaxKernel::max_ignore_nan_kernel)
401                .reduce(MinMax::max_ignore_nan),
402        }
403    }
404    pub(crate) fn min_str(&self) -> Option<&str> {
405        if self.is_empty() {
406            return None;
407        }
408        match self.is_sorted_flag() {
409            IsSorted::Ascending => {
410                self.first_non_null().and_then(|idx| {
411                    // SAFETY: first_non_null returns in bound index
412                    unsafe { self.get_unchecked(idx) }
413                })
414            },
415            IsSorted::Descending => {
416                self.last_non_null().and_then(|idx| {
417                    // SAFETY: last_non_null returns in bound index
418                    unsafe { self.get_unchecked(idx) }
419                })
420            },
421            IsSorted::Not => self
422                .downcast_iter()
423                .filter_map(MinMaxKernel::min_ignore_nan_kernel)
424                .reduce(MinMax::min_ignore_nan),
425        }
426    }
427}
428
429impl ChunkAggSeries for StringChunked {
430    fn max_reduce(&self) -> Scalar {
431        let av: AnyValue = self.max_str().into();
432        Scalar::new(DataType::String, av.into_static())
433    }
434    fn min_reduce(&self) -> Scalar {
435        let av: AnyValue = self.min_str().into();
436        Scalar::new(DataType::String, av.into_static())
437    }
438}
439
#[cfg(feature = "dtype-categorical")]
impl CategoricalChunked {
    /// Physical id of the smallest category, or `None` when empty or all-null.
    ///
    /// With lexical ordering the category strings are compared and the winner
    /// is mapped back to its id through the rev map. Otherwise the physical
    /// ids themselves are compared.
    fn min_categorical(&self) -> Option<u32> {
        if self.is_empty() || self.null_count() == self.len() {
            return None;
        }
        if !self.uses_lexical_ordering() {
            return self.physical().min();
        }

        let rev_map = self.get_rev_map();
        let smallest = if self._can_fast_unique() {
            // Fast path where all categories are used.
            rev_map.get_categories().min_ignore_nan_kernel()
        } else {
            self.physical()
                .iter()
                .flatten()
                // SAFETY: physical ids are in bounds for this column's rev map.
                .map(|id| unsafe { rev_map.get_unchecked(id) })
                .min()
        };
        // At least one non-null value exists, so `smallest` is `Some`.
        rev_map.find(smallest.unwrap())
    }

    /// Physical id of the largest category, or `None` when empty or all-null.
    ///
    /// With lexical ordering the category strings are compared and the winner
    /// is mapped back to its id through the rev map. Otherwise the physical
    /// ids themselves are compared.
    fn max_categorical(&self) -> Option<u32> {
        if self.is_empty() || self.null_count() == self.len() {
            return None;
        }
        if !self.uses_lexical_ordering() {
            return self.physical().max();
        }

        let rev_map = self.get_rev_map();
        let largest = if self._can_fast_unique() {
            // Fast path where all categories are used.
            rev_map.get_categories().max_ignore_nan_kernel()
        } else {
            self.physical()
                .iter()
                .flatten()
                // SAFETY: physical ids are in bounds for this column's rev map.
                .map(|id| unsafe { rev_map.get_unchecked(id) })
                .max()
        };
        // At least one non-null value exists, so `largest` is `Some`.
        rev_map.find(largest.unwrap())
    }
}
492
#[cfg(feature = "dtype-categorical")]
impl ChunkAggSeries for CategoricalChunked {
    /// Minimum of the column as a [`Scalar`] carrying this column's dtype.
    ///
    /// `Enum` columns compare their physical category ids directly.
    /// `Categorical` columns delegate to `min_categorical`, which honours
    /// lexical ordering. When there is no minimum the result is a `Null` scalar.
    fn min_reduce(&self) -> Scalar {
        match self.dtype() {
            DataType::Enum(r, _) => match self.physical().min() {
                None => Scalar::new(self.dtype().clone(), AnyValue::Null),
                Some(v) => {
                    // Only a local rev map is handled for `Enum`; any other
                    // mapping hits `unreachable!()`.
                    let RevMapping::Local(arr, _) = &**r.as_ref().unwrap() else {
                        unreachable!()
                    };
                    // NOTE(review): the raw pointer targets the categories array inside
                    // the rev map; the cloned `Arc` stored alongside presumably keeps that
                    // array alive for as long as the scalar exists — confirm.
                    Scalar::new(
                        self.dtype().clone(),
                        AnyValue::EnumOwned(
                            v,
                            r.as_ref().unwrap().clone(),
                            SyncPtr::from_const(arr as *const _),
                        ),
                    )
                },
            },
            DataType::Categorical(r, _) => match self.min_categorical() {
                None => Scalar::new(self.dtype().clone(), AnyValue::Null),
                Some(v) => {
                    let r = r.as_ref().unwrap();
                    // Both local and global rev maps expose a categories array.
                    let arr = match &**r {
                        RevMapping::Local(arr, _) => arr,
                        RevMapping::Global(_, arr, _) => arr,
                    };
                    // NOTE(review): same pointer/`Arc` pairing as the `Enum` arm above.
                    Scalar::new(
                        self.dtype().clone(),
                        AnyValue::CategoricalOwned(
                            v,
                            r.clone(),
                            SyncPtr::from_const(arr as *const _),
                        ),
                    )
                },
            },
            // A `CategoricalChunked` is only expected to carry one of the two dtypes above.
            _ => unreachable!(),
        }
    }
    /// Maximum of the column as a [`Scalar`] carrying this column's dtype.
    ///
    /// `Enum` columns compare their physical category ids directly.
    /// `Categorical` columns delegate to `max_categorical`, which honours
    /// lexical ordering. When there is no maximum the result is a `Null` scalar.
    fn max_reduce(&self) -> Scalar {
        match self.dtype() {
            DataType::Enum(r, _) => match self.physical().max() {
                None => Scalar::new(self.dtype().clone(), AnyValue::Null),
                Some(v) => {
                    // Only a local rev map is handled for `Enum`; any other
                    // mapping hits `unreachable!()`.
                    let RevMapping::Local(arr, _) = &**r.as_ref().unwrap() else {
                        unreachable!()
                    };
                    // NOTE(review): the raw pointer targets the categories array inside
                    // the rev map; the cloned `Arc` stored alongside presumably keeps that
                    // array alive for as long as the scalar exists — confirm.
                    Scalar::new(
                        self.dtype().clone(),
                        AnyValue::EnumOwned(
                            v,
                            r.as_ref().unwrap().clone(),
                            SyncPtr::from_const(arr as *const _),
                        ),
                    )
                },
            },
            DataType::Categorical(r, _) => match self.max_categorical() {
                None => Scalar::new(self.dtype().clone(), AnyValue::Null),
                Some(v) => {
                    let r = r.as_ref().unwrap();
                    // Both local and global rev maps expose a categories array.
                    let arr = match &**r {
                        RevMapping::Local(arr, _) => arr,
                        RevMapping::Global(_, arr, _) => arr,
                    };
                    // NOTE(review): same pointer/`Arc` pairing as the `Enum` arm above.
                    Scalar::new(
                        self.dtype().clone(),
                        AnyValue::CategoricalOwned(
                            v,
                            r.clone(),
                            SyncPtr::from_const(arr as *const _),
                        ),
                    )
                },
            },
            // A `CategoricalChunked` is only expected to carry one of the two dtypes above.
            _ => unreachable!(),
        }
    }
}
574
575impl BinaryChunked {
576    pub fn max_binary(&self) -> Option<&[u8]> {
577        if self.is_empty() {
578            return None;
579        }
580        match self.is_sorted_flag() {
581            IsSorted::Ascending => {
582                self.last_non_null().and_then(|idx| {
583                    // SAFETY: last_non_null returns in bound index.
584                    unsafe { self.get_unchecked(idx) }
585                })
586            },
587            IsSorted::Descending => {
588                self.first_non_null().and_then(|idx| {
589                    // SAFETY: first_non_null returns in bound index.
590                    unsafe { self.get_unchecked(idx) }
591                })
592            },
593            IsSorted::Not => self
594                .downcast_iter()
595                .filter_map(MinMaxKernel::max_ignore_nan_kernel)
596                .reduce(MinMax::max_ignore_nan),
597        }
598    }
599
600    pub fn min_binary(&self) -> Option<&[u8]> {
601        if self.is_empty() {
602            return None;
603        }
604        match self.is_sorted_flag() {
605            IsSorted::Ascending => {
606                self.first_non_null().and_then(|idx| {
607                    // SAFETY: first_non_null returns in bound index.
608                    unsafe { self.get_unchecked(idx) }
609                })
610            },
611            IsSorted::Descending => {
612                self.last_non_null().and_then(|idx| {
613                    // SAFETY: last_non_null returns in bound index.
614                    unsafe { self.get_unchecked(idx) }
615                })
616            },
617            IsSorted::Not => self
618                .downcast_iter()
619                .filter_map(MinMaxKernel::min_ignore_nan_kernel)
620                .reduce(MinMax::min_ignore_nan),
621        }
622    }
623}
624
625impl ChunkAggSeries for BinaryChunked {
626    fn sum_reduce(&self) -> Scalar {
627        unimplemented!()
628    }
629    fn max_reduce(&self) -> Scalar {
630        let av: AnyValue = self.max_binary().into();
631        Scalar::new(self.dtype().clone(), av.into_static())
632    }
633    fn min_reduce(&self) -> Scalar {
634        let av: AnyValue = self.min_binary().into();
635        Scalar::new(self.dtype().clone(), av.into_static())
636    }
637}
638
// Object columns override none of the `ChunkAggSeries` reductions, so every
// method falls back to the trait's default `unimplemented!()` body.
#[cfg(feature = "object")]
impl<T: PolarsObject> ChunkAggSeries for ObjectChunked<T> {}
641
#[cfg(test)]
mod test {
    use polars_compute::rolling::QuantileMethod;

    use crate::prelude::*;

    /// Every quantile interpolation method; the quantile tests sweep all of them.
    const ALL_QUANTILE_METHODS: [QuantileMethod; 6] = [
        QuantileMethod::Nearest,
        QuantileMethod::Lower,
        QuantileMethod::Higher,
        QuantileMethod::Midpoint,
        QuantileMethod::Linear,
        QuantileMethod::Equiprobable,
    ];

    #[test]
    #[cfg(not(miri))]
    fn test_var() {
        // Validated with numpy. Note that numpy uses ddof as an argument which
        // influences results. The default ddof=0, we chose ddof=1, which is
        // standard in statistics.
        let ca1 = Int32Chunked::new(PlSmallStr::EMPTY, &[5, 8, 9, 5, 0]);
        let ca2 = Int32Chunked::new(
            PlSmallStr::EMPTY,
            &[
                Some(5),
                None,
                Some(8),
                Some(9),
                None,
                Some(5),
                Some(0),
                None,
            ],
        );
        for ca in &[ca1, ca2] {
            let out = ca.var(1);
            assert_eq!(out, Some(12.3));
            let out = ca.std(1).unwrap();
            assert!((3.5071355833500366 - out).abs() < 0.000000001);
        }
    }

    #[test]
    fn test_agg_float() {
        // `min` ignores NaN, so the position of the NaN must not change the result.
        let ca1 = Float32Chunked::new(PlSmallStr::from_static("a"), &[1.0, f32::NAN]);
        let ca2 = Float32Chunked::new(PlSmallStr::from_static("b"), &[f32::NAN, 1.0]);
        assert_eq!(ca1.min(), ca2.min());
        assert_eq!(ca1.min(), Some(1.0));
        let ca1 = Float64Chunked::new(PlSmallStr::from_static("a"), &[1.0, f64::NAN]);
        let ca2 = Float64Chunked::from_slice(PlSmallStr::from_static("b"), &[f64::NAN, 1.0]);
        assert_eq!(ca1.min(), ca2.min());
        assert_eq!(ca1.min(), Some(1.0));
    }

    #[test]
    fn test_median() {
        let ca = UInt32Chunked::new(
            PlSmallStr::from_static("a"),
            &[Some(2), Some(1), None, Some(3), Some(5), None, Some(4)],
        );
        assert_eq!(ca.median(), Some(3.0));
        let ca = UInt32Chunked::new(
            PlSmallStr::from_static("a"),
            &[
                None,
                Some(7),
                Some(6),
                Some(2),
                Some(1),
                None,
                Some(3),
                Some(5),
                None,
                Some(4),
            ],
        );
        assert_eq!(ca.median(), Some(4.0));

        let ca = Float32Chunked::from_slice(
            PlSmallStr::EMPTY,
            &[
                0.166189, 0.166559, 0.168517, 0.169393, 0.175272, 0.233167, 0.238787, 0.266562,
                0.26903, 0.285792, 0.292801, 0.293429, 0.301706, 0.308534, 0.331489, 0.346095,
                0.367644, 0.369939, 0.372074, 0.41014, 0.415789, 0.421781, 0.427725, 0.465363,
                0.500208, 2.621727, 2.803311, 3.868526,
            ],
        );
        assert!((ca.median().unwrap() - 0.3200115).abs() < 0.0001)
    }

    #[test]
    fn test_mean() {
        let ca = Float32Chunked::new(PlSmallStr::EMPTY, &[Some(1.0), Some(2.0), None]);
        assert_eq!(ca.mean().unwrap(), 1.5);
        assert_eq!(
            ca.into_series()
                .mean_reduce()
                .value()
                .extract::<f32>()
                .unwrap(),
            1.5
        );
        // all null values case
        let ca = Float32Chunked::full_null(PlSmallStr::EMPTY, 3);
        assert_eq!(ca.mean(), None);
        assert_eq!(
            ca.into_series().mean_reduce().value().extract::<f32>(),
            None
        );
    }

    #[test]
    fn test_quantile_all_null() {
        let test_f32 = Float32Chunked::from_slice_options(PlSmallStr::EMPTY, &[None, None, None]);
        let test_i32 = Int32Chunked::from_slice_options(PlSmallStr::EMPTY, &[None, None, None]);
        let test_f64 = Float64Chunked::from_slice_options(PlSmallStr::EMPTY, &[None, None, None]);
        let test_i64 = Int64Chunked::from_slice_options(PlSmallStr::EMPTY, &[None, None, None]);

        for method in ALL_QUANTILE_METHODS {
            assert_eq!(test_f32.quantile(0.9, method).unwrap(), None);
            assert_eq!(test_i32.quantile(0.9, method).unwrap(), None);
            assert_eq!(test_f64.quantile(0.9, method).unwrap(), None);
            assert_eq!(test_i64.quantile(0.9, method).unwrap(), None);
        }
    }

    #[test]
    fn test_quantile_single_value() {
        let test_f32 = Float32Chunked::from_slice_options(PlSmallStr::EMPTY, &[Some(1.0)]);
        let test_i32 = Int32Chunked::from_slice_options(PlSmallStr::EMPTY, &[Some(1)]);
        let test_f64 = Float64Chunked::from_slice_options(PlSmallStr::EMPTY, &[Some(1.0)]);
        let test_i64 = Int64Chunked::from_slice_options(PlSmallStr::EMPTY, &[Some(1)]);

        for method in ALL_QUANTILE_METHODS {
            assert_eq!(test_f32.quantile(0.5, method).unwrap(), Some(1.0));
            assert_eq!(test_i32.quantile(0.5, method).unwrap(), Some(1.0));
            assert_eq!(test_f64.quantile(0.5, method).unwrap(), Some(1.0));
            assert_eq!(test_i64.quantile(0.5, method).unwrap(), Some(1.0));
        }
    }

    #[test]
    fn test_quantile_min_max() {
        let test_f32 = Float32Chunked::from_slice_options(
            PlSmallStr::EMPTY,
            &[None, Some(1f32), Some(5f32), Some(1f32)],
        );
        let test_i32 = Int32Chunked::from_slice_options(
            PlSmallStr::EMPTY,
            &[None, Some(1i32), Some(5i32), Some(1i32)],
        );
        let test_f64 = Float64Chunked::from_slice_options(
            PlSmallStr::EMPTY,
            &[None, Some(1f64), Some(5f64), Some(1f64)],
        );
        let test_i64 = Int64Chunked::from_slice_options(
            PlSmallStr::EMPTY,
            &[None, Some(1i64), Some(5i64), Some(1i64)],
        );

        for method in ALL_QUANTILE_METHODS {
            assert_eq!(test_f32.quantile(0.0, method).unwrap(), test_f32.min());
            assert_eq!(test_f32.quantile(1.0, method).unwrap(), test_f32.max());

            assert_eq!(
                test_i32.quantile(0.0, method).unwrap().unwrap(),
                test_i32.min().unwrap() as f64
            );
            assert_eq!(
                test_i32.quantile(1.0, method).unwrap().unwrap(),
                test_i32.max().unwrap() as f64
            );

            assert_eq!(test_f64.quantile(0.0, method).unwrap(), test_f64.min());
            assert_eq!(test_f64.quantile(1.0, method).unwrap(), test_f64.max());
            assert_eq!(test_f64.quantile(0.5, method).unwrap(), test_f64.median());

            assert_eq!(
                test_i64.quantile(0.0, method).unwrap().unwrap(),
                test_i64.min().unwrap() as f64
            );
            assert_eq!(
                test_i64.quantile(1.0, method).unwrap().unwrap(),
                test_i64.max().unwrap() as f64
            );
        }
    }

    #[test]
    fn test_quantile() {
        let ca = UInt32Chunked::new(
            PlSmallStr::from_static("a"),
            &[Some(2), Some(1), None, Some(3), Some(5), None, Some(4)],
        );

        assert_eq!(
            ca.quantile(0.1, QuantileMethod::Nearest).unwrap(),
            Some(1.0)
        );
        assert_eq!(
            ca.quantile(0.9, QuantileMethod::Nearest).unwrap(),
            Some(5.0)
        );
        assert_eq!(
            ca.quantile(0.6, QuantileMethod::Nearest).unwrap(),
            Some(3.0)
        );

        assert_eq!(ca.quantile(0.1, QuantileMethod::Lower).unwrap(), Some(1.0));
        assert_eq!(ca.quantile(0.9, QuantileMethod::Lower).unwrap(), Some(4.0));
        assert_eq!(ca.quantile(0.6, QuantileMethod::Lower).unwrap(), Some(3.0));

        assert_eq!(ca.quantile(0.1, QuantileMethod::Higher).unwrap(), Some(2.0));
        assert_eq!(ca.quantile(0.9, QuantileMethod::Higher).unwrap(), Some(5.0));
        assert_eq!(ca.quantile(0.6, QuantileMethod::Higher).unwrap(), Some(4.0));

        assert_eq!(
            ca.quantile(0.1, QuantileMethod::Midpoint).unwrap(),
            Some(1.5)
        );
        assert_eq!(
            ca.quantile(0.9, QuantileMethod::Midpoint).unwrap(),
            Some(4.5)
        );
        assert_eq!(
            ca.quantile(0.6, QuantileMethod::Midpoint).unwrap(),
            Some(3.5)
        );

        assert_eq!(ca.quantile(0.1, QuantileMethod::Linear).unwrap(), Some(1.4));
        assert_eq!(ca.quantile(0.9, QuantileMethod::Linear).unwrap(), Some(4.6));
        assert!(
            (ca.quantile(0.6, QuantileMethod::Linear).unwrap().unwrap() - 3.4).abs() < 0.0000001
        );

        assert_eq!(
            ca.quantile(0.15, QuantileMethod::Equiprobable).unwrap(),
            Some(1.0)
        );
        assert_eq!(
            ca.quantile(0.25, QuantileMethod::Equiprobable).unwrap(),
            Some(2.0)
        );
        assert_eq!(
            ca.quantile(0.6, QuantileMethod::Equiprobable).unwrap(),
            Some(3.0)
        );

        let ca = UInt32Chunked::new(
            PlSmallStr::from_static("a"),
            &[
                None,
                Some(7),
                Some(6),
                Some(2),
                Some(1),
                None,
                Some(3),
                Some(5),
                None,
                Some(4),
            ],
        );

        assert_eq!(
            ca.quantile(0.1, QuantileMethod::Nearest).unwrap(),
            Some(2.0)
        );
        assert_eq!(
            ca.quantile(0.9, QuantileMethod::Nearest).unwrap(),
            Some(6.0)
        );
        assert_eq!(
            ca.quantile(0.6, QuantileMethod::Nearest).unwrap(),
            Some(5.0)
        );

        assert_eq!(ca.quantile(0.1, QuantileMethod::Lower).unwrap(), Some(1.0));
        assert_eq!(ca.quantile(0.9, QuantileMethod::Lower).unwrap(), Some(6.0));
        assert_eq!(ca.quantile(0.6, QuantileMethod::Lower).unwrap(), Some(4.0));

        assert_eq!(ca.quantile(0.1, QuantileMethod::Higher).unwrap(), Some(2.0));
        assert_eq!(ca.quantile(0.9, QuantileMethod::Higher).unwrap(), Some(7.0));
        assert_eq!(ca.quantile(0.6, QuantileMethod::Higher).unwrap(), Some(5.0));

        assert_eq!(
            ca.quantile(0.1, QuantileMethod::Midpoint).unwrap(),
            Some(1.5)
        );
        assert_eq!(
            ca.quantile(0.9, QuantileMethod::Midpoint).unwrap(),
            Some(6.5)
        );
        assert_eq!(
            ca.quantile(0.6, QuantileMethod::Midpoint).unwrap(),
            Some(4.5)
        );

        assert_eq!(ca.quantile(0.1, QuantileMethod::Linear).unwrap(), Some(1.6));
        assert_eq!(ca.quantile(0.9, QuantileMethod::Linear).unwrap(), Some(6.4));
        assert_eq!(ca.quantile(0.6, QuantileMethod::Linear).unwrap(), Some(4.6));

        assert_eq!(
            ca.quantile(0.14, QuantileMethod::Equiprobable).unwrap(),
            Some(1.0)
        );
        assert_eq!(
            ca.quantile(0.15, QuantileMethod::Equiprobable).unwrap(),
            Some(2.0)
        );
        assert_eq!(
            ca.quantile(0.6, QuantileMethod::Equiprobable).unwrap(),
            Some(5.0)
        );
    }
}