polars_plan/dsl/function_expr/
binary.rs

1#[cfg(feature = "serde")]
2use serde::{Deserialize, Serialize};
3
4use super::*;
5use crate::{map, map_as_slice};
6
/// Operations of the `bin` expression namespace.
///
/// Each variant is rendered as `bin.<name>` by the `Display` impl in this
/// file and is lowered to a column UDF through
/// `From<BinaryFunction> for SpecialEq<Arc<dyn ColumnsUdf>>`.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum BinaryFunction {
    /// Displayed as `bin.contains`; resolves to a `Boolean` field.
    Contains,
    /// Displayed as `bin.starts_with`; resolves to a `Boolean` field.
    StartsWith,
    /// Displayed as `bin.ends_with`; resolves to a `Boolean` field.
    EndsWith,
    /// Displayed as `bin.hex_decode`; the flag is forwarded to the
    /// `hex_decode` kernel as its `strict` argument.
    #[cfg(feature = "binary_encoding")]
    HexDecode(bool),
    /// Displayed as `bin.hex_encode`; resolves to a `String` field.
    #[cfg(feature = "binary_encoding")]
    HexEncode,
    /// Displayed as `bin.base64_decode`; the flag is forwarded to the
    /// `base64_decode` kernel as its `strict` argument.
    #[cfg(feature = "binary_encoding")]
    Base64Decode(bool),
    /// Displayed as `bin.base64_encode`; resolves to a `String` field.
    #[cfg(feature = "binary_encoding")]
    Base64Encode,
    /// Displayed as `bin.size_bytes`; resolves to a `UInt32` field.
    Size,
    /// Displayed as `bin.from_buffer`; carries the output dtype and the
    /// `is_little_endian` flag handed to the `from_buffer` kernel.
    #[cfg(feature = "binary_encoding")]
    FromBuffer(DataType, bool),
}
25
26impl BinaryFunction {
27    pub(super) fn get_field(&self, mapper: FieldsMapper) -> PolarsResult<Field> {
28        use BinaryFunction::*;
29        match self {
30            Contains => mapper.with_dtype(DataType::Boolean),
31            EndsWith | StartsWith => mapper.with_dtype(DataType::Boolean),
32            #[cfg(feature = "binary_encoding")]
33            HexDecode(_) | Base64Decode(_) => mapper.with_same_dtype(),
34            #[cfg(feature = "binary_encoding")]
35            HexEncode | Base64Encode => mapper.with_dtype(DataType::String),
36            Size => mapper.with_dtype(DataType::UInt32),
37            #[cfg(feature = "binary_encoding")]
38            FromBuffer(dtype, _) => mapper.with_dtype(dtype.clone()),
39        }
40    }
41
42    pub fn function_options(&self) -> FunctionOptions {
43        use BinaryFunction as B;
44        match self {
45            B::Contains | B::StartsWith | B::EndsWith => {
46                FunctionOptions::elementwise().with_supertyping(Default::default())
47            },
48            B::Size => FunctionOptions::elementwise(),
49            #[cfg(feature = "binary_encoding")]
50            B::HexDecode(_)
51            | B::HexEncode
52            | B::Base64Decode(_)
53            | B::Base64Encode
54            | B::FromBuffer(_, _) => FunctionOptions::elementwise(),
55        }
56    }
57}
58
59impl Display for BinaryFunction {
60    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
61        use BinaryFunction::*;
62        let s = match self {
63            Contains => "contains",
64            StartsWith => "starts_with",
65            EndsWith => "ends_with",
66            #[cfg(feature = "binary_encoding")]
67            HexDecode(_) => "hex_decode",
68            #[cfg(feature = "binary_encoding")]
69            HexEncode => "hex_encode",
70            #[cfg(feature = "binary_encoding")]
71            Base64Decode(_) => "base64_decode",
72            #[cfg(feature = "binary_encoding")]
73            Base64Encode => "base64_encode",
74            Size => "size_bytes",
75            #[cfg(feature = "binary_encoding")]
76            FromBuffer(_, _) => "from_buffer",
77        };
78        write!(f, "bin.{s}")
79    }
80}
81
82impl From<BinaryFunction> for SpecialEq<Arc<dyn ColumnsUdf>> {
83    fn from(func: BinaryFunction) -> Self {
84        use BinaryFunction::*;
85        match func {
86            Contains => {
87                map_as_slice!(contains)
88            },
89            EndsWith => {
90                map_as_slice!(ends_with)
91            },
92            StartsWith => {
93                map_as_slice!(starts_with)
94            },
95            #[cfg(feature = "binary_encoding")]
96            HexDecode(strict) => map!(hex_decode, strict),
97            #[cfg(feature = "binary_encoding")]
98            HexEncode => map!(hex_encode),
99            #[cfg(feature = "binary_encoding")]
100            Base64Decode(strict) => map!(base64_decode, strict),
101            #[cfg(feature = "binary_encoding")]
102            Base64Encode => map!(base64_encode),
103            Size => map!(size_bytes),
104            #[cfg(feature = "binary_encoding")]
105            FromBuffer(dtype, is_little_endian) => map!(from_buffer, &dtype, is_little_endian),
106        }
107    }
108}
109
110pub(super) fn contains(s: &[Column]) -> PolarsResult<Column> {
111    let ca = s[0].binary()?;
112    let lit = s[1].binary()?;
113    Ok(ca
114        .contains_chunked(lit)?
115        .with_name(ca.name().clone())
116        .into_column())
117}
118
119pub(super) fn ends_with(s: &[Column]) -> PolarsResult<Column> {
120    let ca = s[0].binary()?;
121    let suffix = s[1].binary()?;
122
123    Ok(ca
124        .ends_with_chunked(suffix)?
125        .with_name(ca.name().clone())
126        .into_column())
127}
128
129pub(super) fn starts_with(s: &[Column]) -> PolarsResult<Column> {
130    let ca = s[0].binary()?;
131    let prefix = s[1].binary()?;
132
133    Ok(ca
134        .starts_with_chunked(prefix)?
135        .with_name(ca.name().clone())
136        .into_column())
137}
138
139pub(super) fn size_bytes(s: &Column) -> PolarsResult<Column> {
140    let ca = s.binary()?;
141    Ok(ca.size_bytes().into_column())
142}
143
/// Applies `hex_decode` to the binary view of `s`, forwarding `strict`
/// unchanged, and wraps the decoded values in a `Column`.
///
/// Errors from `Column::binary` or from the decode call are returned as-is.
#[cfg(feature = "binary_encoding")]
pub(super) fn hex_decode(s: &Column, strict: bool) -> PolarsResult<Column> {
    let decoded = s.binary()?.hex_decode(strict)?;
    Ok(decoded.into_column())
}
149
/// Applies `hex_encode` to the binary view of `s` and converts the result
/// into a `Column`. Returns the error from `Column::binary` if the view fails.
#[cfg(feature = "binary_encoding")]
pub(super) fn hex_encode(s: &Column) -> PolarsResult<Column> {
    let encoded = s.binary()?.hex_encode();
    Ok(encoded.into())
}
155
/// Applies `base64_decode` to the binary view of `s`, forwarding `strict`
/// unchanged, and wraps the decoded values in a `Column`.
///
/// Errors from `Column::binary` or from the decode call are returned as-is.
#[cfg(feature = "binary_encoding")]
pub(super) fn base64_decode(s: &Column, strict: bool) -> PolarsResult<Column> {
    let decoded = s.binary()?.base64_decode(strict)?;
    Ok(decoded.into_column())
}
161
/// Applies `base64_encode` to the binary view of `s` and converts the result
/// into a `Column`. Returns the error from `Column::binary` if the view fails.
#[cfg(feature = "binary_encoding")]
pub(super) fn base64_encode(s: &Column) -> PolarsResult<Column> {
    let encoded = s.binary()?.base64_encode();
    Ok(encoded.into())
}
167
/// Applies `from_buffer` to the binary view of `s`, passing `dtype` and
/// `is_little_endian` through unchanged, and converts the result into a
/// `Column`.
///
/// Errors from `Column::binary` or from the `from_buffer` call are returned
/// as-is.
#[cfg(feature = "binary_encoding")]
pub(super) fn from_buffer(
    s: &Column,
    dtype: &DataType,
    is_little_endian: bool,
) -> PolarsResult<Column> {
    let bytes = s.binary()?;
    let reinterpreted = bytes.from_buffer(dtype, is_little_endian)?;
    Ok(reinterpreted.into())
}
178
179impl From<BinaryFunction> for FunctionExpr {
180    fn from(b: BinaryFunction) -> Self {
181        FunctionExpr::BinaryExpr(b)
182    }
183}