Skip to main content

polars_core/chunked_array/ops/sort/
categorical.rs

1use num_traits::Zero;
2
3use super::*;
4
5impl<T: PolarsCategoricalType> CategoricalChunked<T> {
6    #[must_use]
7    pub fn sort_with(&self, options: SortOptions) -> CategoricalChunked<T> {
8        if !self.uses_lexical_ordering() {
9            let cats = self.physical().sort_with(options);
10            // SAFETY: we only reordered the indexes so we are still in bounds.
11            return unsafe {
12                CategoricalChunked::<T>::from_cats_and_dtype_unchecked(cats, self.dtype().clone())
13            };
14        }
15
16        let mut vals = self
17            .physical()
18            .into_iter()
19            .zip(self.iter_str())
20            .collect_trusted::<Vec<_>>();
21
22        sort_unstable_by_branch(vals.as_mut_slice(), options, |a, b| a.1.cmp(&b.1));
23
24        let mut cats = Vec::with_capacity(self.len());
25        let mut validity =
26            (self.null_count() > 0).then(|| BitmapBuilder::with_capacity(self.len()));
27
28        if self.null_count() > 0 && !options.nulls_last {
29            cats.resize(self.null_count(), T::Native::zero());
30            if let Some(validity) = &mut validity {
31                validity.extend_constant(self.null_count(), false);
32            }
33        }
34
35        let valid_slice = if options.descending {
36            &vals[..self.len() - self.null_count()]
37        } else {
38            &vals[self.null_count()..]
39        };
40        cats.extend(valid_slice.iter().map(|(idx, _v)| idx.unwrap()));
41        if let Some(validity) = &mut validity {
42            validity.extend_constant(self.len() - self.null_count(), true);
43        }
44
45        if self.null_count() > 0 && options.nulls_last {
46            cats.resize(self.len(), T::Native::zero());
47            if let Some(validity) = &mut validity {
48                validity.extend_constant(self.null_count(), false);
49            }
50        }
51
52        let arr = PrimitiveArray::from_vec(cats).with_validity(validity.map(|v| v.freeze()));
53        let cats = ChunkedArray::with_chunk(self.name().clone(), arr);
54
55        // SAFETY: we only reordered the indexes so we are still in bounds.
56        unsafe {
57            CategoricalChunked::<T>::from_cats_and_dtype_unchecked(cats, self.dtype().clone())
58        }
59    }
60
61    /// Returned a sorted `ChunkedArray`.
62    #[must_use]
63    pub fn sort(&self, descending: bool) -> CategoricalChunked<T> {
64        self.sort_with(SortOptions {
65            nulls_last: false,
66            descending,
67            multithreaded: true,
68            maintain_order: false,
69            limit: None,
70        })
71    }
72
73    /// Retrieve the indexes needed to sort this array.
74    pub fn arg_sort(&self, options: SortOptions) -> IdxCa {
75        if self.uses_lexical_ordering() {
76            let iters = [self.iter_str()];
77            arg_sort::arg_sort(
78                self.name().clone(),
79                iters,
80                options,
81                self.physical().null_count(),
82                self.len(),
83                IsSorted::Not,
84                false,
85            )
86        } else {
87            self.physical().arg_sort(options)
88        }
89    }
90
91    /// Retrieve the indices needed to sort this and the other arrays.
92    pub(crate) fn arg_sort_multiple(
93        &self,
94        by: &[Column],
95        options: &SortMultipleOptions,
96    ) -> PolarsResult<IdxCa> {
97        if self.uses_lexical_ordering() {
98            args_validate(self.physical(), by, &options.descending, "descending")?;
99            args_validate(self.physical(), by, &options.nulls_last, "nulls_last")?;
100            let mut count: IdxSize = 0;
101
102            // we use bytes to save a monomorphisized str impl
103            // as bytes already is used for binary and string sorting
104            let vals: Vec<_> = self
105                .iter_str()
106                .map(|v| {
107                    let i = count;
108                    count += 1;
109                    (i, v.map(|v| v.as_bytes()))
110                })
111                .collect_trusted();
112
113            arg_sort_multiple_impl(vals, by, options)
114        } else {
115            self.physical().arg_sort_multiple(by, options)
116        }
117    }
118}
119
#[cfg(test)]
mod test {
    use crate::prelude::*;

    /// Build a `u8`-backed categorical series (unnamed categories) from `values`.
    fn cat8_series(values: &[&str]) -> PolarsResult<Series> {
        let categories = Categories::new(
            PlSmallStr::EMPTY,
            PlSmallStr::EMPTY,
            CategoricalPhysical::U8,
        );
        Series::new(PlSmallStr::EMPTY, values).cast(&DataType::from_categories(categories))
    }

    /// Assert that `ca`, cast to strings, yields exactly `cmp` in order.
    fn assert_order(ca: &Categorical8Chunked, cmp: &[&str]) {
        let as_string = ca.cast(&DataType::String).unwrap();
        let actual: Vec<&str> = as_string.str().unwrap().into_no_null_iter().collect();
        assert_eq!(actual, cmp);
    }

    #[test]
    fn test_cat_lexical_sort() -> PolarsResult<()> {
        let series = cat8_series(&["c", "b", "a", "d"])?;
        let ca = series.cat8()?;

        // Sorting returns the values in string order.
        let sorted = ca.sort(false);
        assert_order(&sorted, &["a", "b", "c", "d"]);

        // The arg-sort must point at the same order.
        let indices = ca.arg_sort(SortOptions {
            descending: false,
            ..Default::default()
        });
        assert_eq!(
            indices.into_no_null_iter().collect::<Vec<_>>(),
            &[2, 1, 0, 3]
        );

        Ok(())
    }

    #[test]
    fn test_cat_lexical_sort_multiple() -> PolarsResult<()> {
        let series = cat8_series(&["c", "b", "a", "a"])?;

        let df = df![
            "cat" => &series,
            "vals" => [1, 1, 2, 2]
        ]?;

        // Sort ascending on the two given columns and check the resulting "cat" column.
        let check = |by: [&str; 2], expected: &[&str]| -> PolarsResult<()> {
            let sorted = df.sort(
                by,
                SortMultipleOptions::default().with_order_descending_multi([false, false]),
            )?;
            let cat = sorted.column("cat")?.as_materialized_series().cat8()?;
            assert_order(cat, expected);
            Ok(())
        };

        check(["cat", "vals"], &["a", "a", "b", "c"])?;
        check(["vals", "cat"], &["b", "c", "a", "a"])?;

        Ok(())
    }
}