polars_plan/plans/aexpr/function_expr/
binary.rs

1use super::*;
2use crate::{map, map_as_slice};
3
/// Functions of the `bin` namespace as represented in the IR.
///
/// Each variant is rendered as `bin.<name>` by the `Display` impl in this file.
/// The variant order is kept stable on purpose: derived `Hash` and index-based
/// serialization formats depend on it.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "ir_serde", derive(serde::Serialize, serde::Deserialize))]
pub enum IRBinaryFunction {
    /// `bin.contains`; dispatched to `contains` and resolved to a `Boolean` field.
    Contains,
    /// `bin.starts_with`; dispatched to `starts_with` and resolved to a `Boolean` field.
    StartsWith,
    /// `bin.ends_with`; dispatched to `ends_with` and resolved to a `Boolean` field.
    EndsWith,
    /// `bin.hex_decode`; the flag is forwarded as the `strict` argument of `hex_decode`.
    #[cfg(feature = "binary_encoding")]
    HexDecode(bool),
    /// `bin.hex_encode`; resolved to a `String` field.
    #[cfg(feature = "binary_encoding")]
    HexEncode,
    /// `bin.base64_decode`; the flag is forwarded as the `strict` argument of
    /// `base64_decode`.
    #[cfg(feature = "binary_encoding")]
    Base64Decode(bool),
    /// `bin.base64_encode`; resolved to a `String` field.
    #[cfg(feature = "binary_encoding")]
    Base64Encode,
    /// `bin.size_bytes`; resolved to a `UInt32` field.
    Size,
    /// `bin.from_buffer`; carries the target dtype (also used as the output field
    /// dtype) and the `is_little_endian` flag forwarded to `from_buffer`.
    #[cfg(feature = "binary_encoding")]
    FromBuffer(DataType, bool),
}
22
23impl IRBinaryFunction {
24    pub(super) fn get_field(&self, mapper: FieldsMapper) -> PolarsResult<Field> {
25        use IRBinaryFunction::*;
26        match self {
27            Contains => mapper.with_dtype(DataType::Boolean),
28            EndsWith | StartsWith => mapper.with_dtype(DataType::Boolean),
29            #[cfg(feature = "binary_encoding")]
30            HexDecode(_) | Base64Decode(_) => mapper.with_same_dtype(),
31            #[cfg(feature = "binary_encoding")]
32            HexEncode | Base64Encode => mapper.with_dtype(DataType::String),
33            Size => mapper.with_dtype(DataType::UInt32),
34            #[cfg(feature = "binary_encoding")]
35            FromBuffer(dtype, _) => mapper.with_dtype(dtype.clone()),
36        }
37    }
38
39    pub fn function_options(&self) -> FunctionOptions {
40        use IRBinaryFunction as B;
41        match self {
42            B::Contains | B::StartsWith | B::EndsWith => {
43                FunctionOptions::elementwise().with_supertyping(Default::default())
44            },
45            B::Size => FunctionOptions::elementwise(),
46            #[cfg(feature = "binary_encoding")]
47            B::HexDecode(_)
48            | B::HexEncode
49            | B::Base64Decode(_)
50            | B::Base64Encode
51            | B::FromBuffer(_, _) => FunctionOptions::elementwise(),
52        }
53    }
54}
55
56impl Display for IRBinaryFunction {
57    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
58        use IRBinaryFunction::*;
59        let s = match self {
60            Contains => "contains",
61            StartsWith => "starts_with",
62            EndsWith => "ends_with",
63            #[cfg(feature = "binary_encoding")]
64            HexDecode(_) => "hex_decode",
65            #[cfg(feature = "binary_encoding")]
66            HexEncode => "hex_encode",
67            #[cfg(feature = "binary_encoding")]
68            Base64Decode(_) => "base64_decode",
69            #[cfg(feature = "binary_encoding")]
70            Base64Encode => "base64_encode",
71            Size => "size_bytes",
72            #[cfg(feature = "binary_encoding")]
73            FromBuffer(_, _) => "from_buffer",
74        };
75        write!(f, "bin.{s}")
76    }
77}
78
79impl From<IRBinaryFunction> for SpecialEq<Arc<dyn ColumnsUdf>> {
80    fn from(func: IRBinaryFunction) -> Self {
81        use IRBinaryFunction::*;
82        match func {
83            Contains => {
84                map_as_slice!(contains)
85            },
86            EndsWith => {
87                map_as_slice!(ends_with)
88            },
89            StartsWith => {
90                map_as_slice!(starts_with)
91            },
92            #[cfg(feature = "binary_encoding")]
93            HexDecode(strict) => map!(hex_decode, strict),
94            #[cfg(feature = "binary_encoding")]
95            HexEncode => map!(hex_encode),
96            #[cfg(feature = "binary_encoding")]
97            Base64Decode(strict) => map!(base64_decode, strict),
98            #[cfg(feature = "binary_encoding")]
99            Base64Encode => map!(base64_encode),
100            Size => map!(size_bytes),
101            #[cfg(feature = "binary_encoding")]
102            FromBuffer(dtype, is_little_endian) => map!(from_buffer, &dtype, is_little_endian),
103        }
104    }
105}
106
107pub(super) fn contains(s: &[Column]) -> PolarsResult<Column> {
108    let ca = s[0].binary()?;
109    let lit = s[1].binary()?;
110    Ok(ca
111        .contains_chunked(lit)?
112        .with_name(ca.name().clone())
113        .into_column())
114}
115
116pub(super) fn ends_with(s: &[Column]) -> PolarsResult<Column> {
117    let ca = s[0].binary()?;
118    let suffix = s[1].binary()?;
119
120    Ok(ca
121        .ends_with_chunked(suffix)?
122        .with_name(ca.name().clone())
123        .into_column())
124}
125
126pub(super) fn starts_with(s: &[Column]) -> PolarsResult<Column> {
127    let ca = s[0].binary()?;
128    let prefix = s[1].binary()?;
129
130    Ok(ca
131        .starts_with_chunked(prefix)?
132        .with_name(ca.name().clone())
133        .into_column())
134}
135
136pub(super) fn size_bytes(s: &Column) -> PolarsResult<Column> {
137    let ca = s.binary()?;
138    Ok(ca.size_bytes().into_column())
139}
140
/// Evaluates `bin.hex_decode` on the binary column `s`.
///
/// `strict` is forwarded unchanged to the underlying `hex_decode` call.
///
/// # Errors
/// Fails if `s` is not a binary column or if decoding reports an error.
#[cfg(feature = "binary_encoding")]
pub(super) fn hex_decode(s: &Column, strict: bool) -> PolarsResult<Column> {
    let decoded = s.binary()?.hex_decode(strict)?;
    Ok(decoded.into_column())
}
146
/// Evaluates `bin.hex_encode` on the binary column `s`.
///
/// # Errors
/// Fails if `s` is not a binary column.
#[cfg(feature = "binary_encoding")]
pub(super) fn hex_encode(s: &Column) -> PolarsResult<Column> {
    let encoded = s.binary()?.hex_encode();
    Ok(encoded.into())
}
152
/// Evaluates `bin.base64_decode` on the binary column `s`.
///
/// `strict` is forwarded unchanged to the underlying `base64_decode` call.
///
/// # Errors
/// Fails if `s` is not a binary column or if decoding reports an error.
#[cfg(feature = "binary_encoding")]
pub(super) fn base64_decode(s: &Column, strict: bool) -> PolarsResult<Column> {
    let decoded = s.binary()?.base64_decode(strict)?;
    Ok(decoded.into_column())
}
158
/// Evaluates `bin.base64_encode` on the binary column `s`.
///
/// # Errors
/// Fails if `s` is not a binary column.
#[cfg(feature = "binary_encoding")]
pub(super) fn base64_encode(s: &Column) -> PolarsResult<Column> {
    let encoded = s.binary()?.base64_encode();
    Ok(encoded.into())
}
164
/// Evaluates `bin.from_buffer` on the binary column `s`.
///
/// `dtype` and `is_little_endian` are forwarded unchanged to the underlying
/// `from_buffer` call.
///
/// # Errors
/// Fails if `s` is not a binary column or if the underlying call reports an error.
#[cfg(feature = "binary_encoding")]
pub(super) fn from_buffer(
    s: &Column,
    dtype: &DataType,
    is_little_endian: bool,
) -> PolarsResult<Column> {
    let reinterpreted = s.binary()?.from_buffer(dtype, is_little_endian)?;
    Ok(reinterpreted.into())
}
175
176impl From<IRBinaryFunction> for IRFunctionExpr {
177    fn from(b: IRBinaryFunction) -> Self {
178        IRFunctionExpr::BinaryExpr(b)
179    }
180}