// polars_plan/dsl/function_expr/cat.rs

1use super::*;
2use crate::map;
3
/// Categorical-specific expression functions (rendered as `cat.<name>`).
///
/// The variants gated behind the `strings` feature are evaluated once on the categories of a
/// categorical column and the per-category result is broadcast back to every row.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Clone, PartialEq, Debug, Eq, Hash)]
pub enum CategoricalFunction {
    /// Return the categories of the column as a `String` column.
    GetCategories,
    /// Length of each value in bytes (`UInt32` output).
    #[cfg(feature = "strings")]
    LenBytes,
    /// Length of each value in characters (`UInt32` output).
    #[cfg(feature = "strings")]
    LenChars,
    /// Whether each value starts with the given prefix (`Boolean` output).
    #[cfg(feature = "strings")]
    StartsWith(String),
    /// Whether each value ends with the given suffix (`Boolean` output).
    #[cfg(feature = "strings")]
    EndsWith(String),
    /// Substring of each value as `(offset, length)`; a `None` length means "to the end".
    #[cfg(feature = "strings")]
    Slice(i64, Option<usize>),
}
19
20impl CategoricalFunction {
21    pub(super) fn get_field(&self, mapper: FieldsMapper) -> PolarsResult<Field> {
22        use CategoricalFunction::*;
23        match self {
24            GetCategories => mapper.with_dtype(DataType::String),
25            #[cfg(feature = "strings")]
26            LenBytes => mapper.with_dtype(DataType::UInt32),
27            #[cfg(feature = "strings")]
28            LenChars => mapper.with_dtype(DataType::UInt32),
29            #[cfg(feature = "strings")]
30            StartsWith(_) => mapper.with_dtype(DataType::Boolean),
31            #[cfg(feature = "strings")]
32            EndsWith(_) => mapper.with_dtype(DataType::Boolean),
33            #[cfg(feature = "strings")]
34            Slice(_, _) => mapper.with_dtype(DataType::String),
35        }
36    }
37
38    pub fn function_options(&self) -> FunctionOptions {
39        use CategoricalFunction as C;
40        match self {
41            C::GetCategories => FunctionOptions::groupwise(),
42            #[cfg(feature = "strings")]
43            C::LenBytes | C::LenChars | C::StartsWith(_) | C::EndsWith(_) | C::Slice(_, _) => {
44                FunctionOptions::elementwise()
45            },
46        }
47    }
48}
49
50impl Display for CategoricalFunction {
51    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
52        use CategoricalFunction::*;
53        let s = match self {
54            GetCategories => "get_categories",
55            #[cfg(feature = "strings")]
56            LenBytes => "len_bytes",
57            #[cfg(feature = "strings")]
58            LenChars => "len_chars",
59            #[cfg(feature = "strings")]
60            StartsWith(_) => "starts_with",
61            #[cfg(feature = "strings")]
62            EndsWith(_) => "ends_with",
63            #[cfg(feature = "strings")]
64            Slice(_, _) => "slice",
65        };
66        write!(f, "cat.{s}")
67    }
68}
69
70impl From<CategoricalFunction> for SpecialEq<Arc<dyn ColumnsUdf>> {
71    fn from(func: CategoricalFunction) -> Self {
72        use CategoricalFunction::*;
73        match func {
74            GetCategories => map!(get_categories),
75            #[cfg(feature = "strings")]
76            LenBytes => map!(len_bytes),
77            #[cfg(feature = "strings")]
78            LenChars => map!(len_chars),
79            #[cfg(feature = "strings")]
80            StartsWith(prefix) => map!(starts_with, prefix.as_str()),
81            #[cfg(feature = "strings")]
82            EndsWith(suffix) => map!(ends_with, suffix.as_str()),
83            #[cfg(feature = "strings")]
84            Slice(offset, length) => map!(slice, offset, length),
85        }
86    }
87}
88
89impl From<CategoricalFunction> for FunctionExpr {
90    fn from(func: CategoricalFunction) -> Self {
91        FunctionExpr::Categorical(func)
92    }
93}
94
95fn get_categories(s: &Column) -> PolarsResult<Column> {
96    // categorical check
97    let ca = s.categorical()?;
98    let rev_map = ca.get_rev_map();
99    let arr = rev_map.get_categories().clone().boxed();
100    Series::try_from((ca.name().clone(), arr)).map(Column::from)
101}
102
103// Determine mapping between categories and underlying physical. For local, this is just 0..n.
104// For global, this is the global indexes.
105fn _get_cat_phys_map(ca: &CategoricalChunked) -> (StringChunked, Series) {
106    let (categories, phys) = match &**ca.get_rev_map() {
107        RevMapping::Local(c, _) => (c, ca.physical().cast(&IDX_DTYPE).unwrap()),
108        RevMapping::Global(physical_map, c, _) => {
109            // Map physical to its local representation for use with take() later.
110            let phys = ca
111                .physical()
112                .apply(|opt_v| opt_v.map(|v| *physical_map.get(&v).unwrap()));
113            let out = phys.cast(&IDX_DTYPE).unwrap();
114            (c, out)
115        },
116    };
117    let categories = StringChunked::with_chunk(ca.name().clone(), categories.clone());
118    (categories, phys)
119}
120
121/// Fast path: apply a string function to the categories of a categorical column and broadcast the
122/// result back to the array.
123// fn apply_to_cats<F, T>(ca: &CategoricalChunked, mut op: F) -> PolarsResult<Column>
124fn apply_to_cats<F, T>(c: &Column, mut op: F) -> PolarsResult<Column>
125where
126    F: FnMut(&StringChunked) -> ChunkedArray<T>,
127    ChunkedArray<T>: IntoSeries,
128    T: PolarsDataType<HasViews = FalseT, IsStruct = FalseT, IsNested = FalseT>,
129{
130    let ca = c.categorical()?;
131    let (categories, phys) = _get_cat_phys_map(ca);
132    let result = op(&categories);
133    // SAFETY: physical idx array is valid.
134    let out = unsafe { result.take_unchecked(phys.idx().unwrap()) };
135    Ok(out.into_column())
136}
137
138/// Fast path: apply a binary function to the categories of a categorical column and broadcast the
139/// result back to the array.
140fn apply_to_cats_binary<F, T>(c: &Column, mut op: F) -> PolarsResult<Column>
141where
142    F: FnMut(&BinaryChunked) -> ChunkedArray<T>,
143    ChunkedArray<T>: IntoSeries,
144    T: PolarsDataType<HasViews = FalseT, IsStruct = FalseT, IsNested = FalseT>,
145{
146    let ca = c.categorical()?;
147    let (categories, phys) = _get_cat_phys_map(ca);
148    let result = op(&categories.as_binary());
149    // SAFETY: physical idx array is valid.
150    let out = unsafe { result.take_unchecked(phys.idx().unwrap()) };
151    Ok(out.into_column())
152}
153
/// Byte length of every value, computed once per category.
#[cfg(feature = "strings")]
fn len_bytes(c: &Column) -> PolarsResult<Column> {
    let per_category = |cats: &StringChunked| cats.str_len_bytes();
    apply_to_cats(c, per_category)
}
158
/// Character length of every value, computed once per category.
#[cfg(feature = "strings")]
fn len_chars(c: &Column) -> PolarsResult<Column> {
    let per_category = |cats: &StringChunked| cats.str_len_chars();
    apply_to_cats(c, per_category)
}
163
/// Whether every value starts with `prefix`, compared byte-wise on the categories.
#[cfg(feature = "strings")]
fn starts_with(c: &Column, prefix: &str) -> PolarsResult<Column> {
    let needle = prefix.as_bytes();
    apply_to_cats_binary(c, |cats| cats.starts_with(needle))
}
168
/// Whether every value ends with `suffix`, compared byte-wise on the categories.
#[cfg(feature = "strings")]
fn ends_with(c: &Column, suffix: &str) -> PolarsResult<Column> {
    let needle = suffix.as_bytes();
    apply_to_cats_binary(c, |cats| cats.ends_with(needle))
}
173
/// Substring of every value, computed once per category and broadcast back to the rows.
///
/// A `None` length means "until the end of the value". Offset/length resolution is delegated to
/// `substring_ternary_offsets_value`.
///
/// This cannot reuse `apply_to_cats`, because that helper requires an output type without views.
/// The sliced categories are a string (view) array, produced here via `apply_views`.
///
/// # Errors
/// Returns an error if `c` is not a categorical column.
#[cfg(feature = "strings")]
fn slice(c: &Column, offset: i64, length: Option<usize>) -> PolarsResult<Column> {
    // `usize` -> `u64` is lossless on all supported (<= 64-bit) targets.
    let length = length.unwrap_or(usize::MAX) as u64;
    let ca = c.categorical()?;
    let (categories, phys) = _get_cat_phys_map(ca);

    // SAFETY: each new view is built by `update_view` from the original `view`/`val`, narrowed to
    // the `start..end` range computed for that same value. It therefore keeps referring to the
    // existing buffers of `categories`.
    // NOTE(review): relies on `substring_ternary_offsets_value` returning offsets within `val`;
    // confirm against its implementation.
    let result = unsafe {
        categories.apply_views(|view, val| {
            let (start, end) = substring_ternary_offsets_value(val, offset, length);
            update_view(view, start, end, val)
        })
    };
    // Guard the invariant the unchecked gather below depends on.
    debug_assert_eq!(result.len(), categories.len());
    // SAFETY: every non-null index in `phys` is a position inside `categories`, and `result` has
    // one entry per category (checked above in debug builds), so all indices are in bounds.
    let out = unsafe { result.take_unchecked(phys.idx().unwrap()) };
    Ok(out.into_column())
}