polars_plan/plans/
options.rs

1use bitflags::bitflags;
2use polars_core::prelude::*;
3use polars_core::utils::SuperTypeOptions;
4#[cfg(feature = "serde")]
5use serde::{Deserialize, Serialize};
6
7use crate::plans::PlSmallStr;
8
9#[derive(Clone, Debug, Eq, PartialEq, Hash)]
10#[cfg_attr(feature = "ir_serde", derive(Serialize, Deserialize))]
11pub struct DistinctOptionsIR {
12    /// Subset of columns that will be taken into account.
13    pub subset: Option<Arc<[PlSmallStr]>>,
14    /// This will maintain the order of the input.
15    /// Note that this is more expensive.
16    /// `maintain_order` is not supported in the streaming
17    /// engine.
18    pub maintain_order: bool,
19    /// Which rows to keep.
20    pub keep_strategy: UniqueKeepStrategy,
21    /// Take only a slice of the result
22    pub slice: Option<(i64, usize)>,
23}
24
/// A boolean wrapper whose value starts out as `true`.
///
/// The inner value is private. Within this file the only code that stores
/// `false` in it is the `unsafe` method `FunctionOptions::no_check_lengths`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct UnsafeBool(bool);

impl Default for UnsafeBool {
    /// Returns the wrapper holding `true`.
    fn default() -> Self {
        Self(true)
    }
}
34
35bitflags!(
36        #[repr(transparent)]
37        #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
38        #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
39        pub struct FunctionFlags: u16 {
40            /// Raise if use in group by
41            const ALLOW_GROUP_AWARE = 1 << 0;
42            /// The physical expression may rename the output of this function.
43            /// If set to `false` the physical engine will ensure the left input
44            /// expression is the output name.
45            const ALLOW_RENAME = 1 << 2;
46            /// if set, then the `Series` passed to the function in the group_by operation
47            /// will ensure the name is set. This is an extra heap allocation per group.
48            const PASS_NAME_TO_APPLY = 1 << 3;
49            /// There can be two ways of expanding wildcards:
50            ///
51            /// Say the schema is 'a', 'b' and there is a function `f`. In this case, `f('*')` can expand
52            /// to:
53            /// 1. `f('a', 'b')`
54            /// 2. `f('a'), f('b')`
55            ///
56            /// Setting this to true, will lead to behavior 1.
57            ///
58            /// This also accounts for regex expansion.
59            const INPUT_WILDCARD_EXPANSION = 1 << 4;
60            /// Automatically explode on unit length if it ran as final aggregation.
61            ///
62            /// this is the case for aggregations like sum, min, covariance etc.
63            /// We need to know this because we cannot see the difference between
64            /// the following functions based on the output type and number of elements:
65            ///
66            /// x: {1, 2, 3}
67            ///
68            /// head_1(x) -> {1}
69            /// sum(x) -> {4}
70            ///
71            /// mutually exclusive with `RETURNS_SCALAR`
72            const RETURNS_SCALAR = 1 << 5;
73            /// This can happen with UDF's that use Polars within the UDF.
74            /// This can lead to recursively entering the engine and sometimes deadlocks.
75            /// This flag must be set to handle that.
76            const OPTIONAL_RE_ENTRANT = 1 << 6;
77            /// Whether this function allows no inputs.
78            const ALLOW_EMPTY_INPUTS = 1 << 7;
79
80            /// Given a function f and a column of values [v1, ..., vn]
81            /// f is row-separable i.f.f.
82            /// f([v1, ..., vn]) = concat(f(v1, ... vm), f(vm+1, ..., vn))
83            const ROW_SEPARABLE = 1 << 8;
84            /// Given a function f and a column of values [v1, ..., vn]
85            /// f is length preserving i.f.f. len(f([v1, ..., vn])) = n
86            ///
87            /// mutually exclusive with `RETURNS_SCALAR`
88            const LENGTH_PRESERVING = 1 << 9;
89            /// Aggregate the values of the expression into a list before applying the function.
90            const APPLY_LIST = 1 << 10;
91        }
92);
93
94impl FunctionFlags {
95    pub fn set_elementwise(&mut self) {
96        *self |= Self::ROW_SEPARABLE | Self::LENGTH_PRESERVING;
97    }
98
99    pub fn is_elementwise(self) -> bool {
100        self.contains(Self::ROW_SEPARABLE | Self::LENGTH_PRESERVING)
101    }
102
103    pub fn returns_scalar(self) -> bool {
104        self.contains(Self::RETURNS_SCALAR)
105    }
106}
107
108impl Default for FunctionFlags {
109    fn default() -> Self {
110        Self::from_bits_truncate(0) | Self::ALLOW_GROUP_AWARE
111    }
112}
113
114#[derive(Clone, Copy, PartialEq, Eq, Debug, Hash)]
115pub enum CastingRules {
116    /// Whether information may be lost during cast. E.g. a float to int is considered lossy,
117    /// whereas int to int is considered lossless.
118    /// Overflowing is not considered in this flag, that's handled in `strict` casting
119    FirstArgLossless,
120    Supertype(SuperTypeOptions),
121}
122
123impl CastingRules {
124    pub fn cast_to_supertypes() -> CastingRules {
125        Self::Supertype(Default::default())
126    }
127}
128
129#[derive(Clone, Copy, PartialEq, Eq, Debug, Hash)]
130#[cfg_attr(any(feature = "serde"), derive(Serialize, Deserialize))]
131pub struct FunctionOptions {
132    // Validate the output of a `map`.
133    // this should always be true or we could OOB
134    pub check_lengths: UnsafeBool,
135    pub flags: FunctionFlags,
136
137    // used for formatting, (only for anonymous functions)
138    #[cfg_attr(feature = "serde", serde(skip))]
139    pub fmt_str: &'static str,
140    /// Options used when deciding how to cast the arguments of the function.
141    #[cfg_attr(feature = "serde", serde(skip))]
142    pub cast_options: Option<CastingRules>,
143}
144
145impl FunctionOptions {
146    #[cfg(feature = "fused")]
147    pub(crate) unsafe fn no_check_lengths(&mut self) {
148        self.check_lengths = UnsafeBool(false);
149    }
150    pub fn check_lengths(&self) -> bool {
151        self.check_lengths.0
152    }
153
154    pub fn set_elementwise(&mut self) {
155        self.flags.set_elementwise();
156    }
157
158    pub fn is_elementwise(&self) -> bool {
159        self.flags.is_elementwise()
160    }
161
162    pub fn is_length_preserving(&self) -> bool {
163        self.flags.contains(FunctionFlags::LENGTH_PRESERVING)
164    }
165
166    pub fn returns_scalar(&self) -> bool {
167        self.flags.returns_scalar()
168    }
169
170    pub fn elementwise() -> FunctionOptions {
171        FunctionOptions {
172            ..Default::default()
173        }
174        .with_flags(|f| f | FunctionFlags::ROW_SEPARABLE | FunctionFlags::LENGTH_PRESERVING)
175    }
176
177    pub fn elementwise_with_infer() -> FunctionOptions {
178        Self::length_preserving()
179    }
180
181    pub fn row_separable() -> FunctionOptions {
182        FunctionOptions {
183            ..Default::default()
184        }
185        .with_flags(|f| f | FunctionFlags::ROW_SEPARABLE)
186    }
187
188    pub fn length_preserving() -> FunctionOptions {
189        FunctionOptions {
190            ..Default::default()
191        }
192        .with_flags(|f| f | FunctionFlags::LENGTH_PRESERVING)
193    }
194
195    pub fn groupwise() -> FunctionOptions {
196        FunctionOptions {
197            ..Default::default()
198        }
199    }
200
201    pub fn aggregation() -> FunctionOptions {
202        let mut options = Self::groupwise();
203        options.flags |= FunctionFlags::RETURNS_SCALAR;
204        options
205    }
206
207    pub fn with_supertyping(self, supertype_options: SuperTypeOptions) -> FunctionOptions {
208        self.with_casting_rules(CastingRules::Supertype(supertype_options))
209    }
210
211    pub fn with_casting_rules(mut self, casting_rules: CastingRules) -> FunctionOptions {
212        self.cast_options = Some(casting_rules);
213        self
214    }
215
216    pub fn with_flags(mut self, f: impl Fn(FunctionFlags) -> FunctionFlags) -> FunctionOptions {
217        self.flags = f(self.flags);
218        self
219    }
220
221    pub fn with_fmt_str(mut self, fmt_str: &'static str) -> FunctionOptions {
222        self.fmt_str = fmt_str;
223        self
224    }
225}
226
227impl Default for FunctionOptions {
228    fn default() -> Self {
229        FunctionOptions {
230            check_lengths: UnsafeBool(true),
231            fmt_str: Default::default(),
232            cast_options: Default::default(),
233            flags: Default::default(),
234        }
235    }
236}
237
/// Switches that apply to a projection; every switch defaults to `true`.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ProjectionOptions {
    /// NOTE(review): undocumented in the original; presumably whether the
    /// projected expressions are evaluated in parallel — confirm.
    pub run_parallel: bool,
    /// NOTE(review): undocumented in the original; presumably whether
    /// duplicate output names are checked — confirm.
    pub duplicate_check: bool,
    /// Should length-1 Series be broadcast to the length of the dataframe.
    ///
    /// Only used by the CSE optimizer.
    pub should_broadcast: bool,
}

impl Default for ProjectionOptions {
    /// All three switches enabled.
    fn default() -> Self {
        let enabled = true;
        Self {
            run_parallel: enabled,
            duplicate_check: enabled,
            should_broadcast: enabled,
        }
    }
}

impl ProjectionOptions {
    /// Conservatively merge the options of two [`ProjectionOptions`].
    ///
    /// `run_parallel` and `duplicate_check` stay enabled only when both
    /// sides enable them; `should_broadcast` is enabled when either side
    /// enables it.
    pub fn merge_options(&self, other: &Self) -> Self {
        let run_parallel = self.run_parallel && other.run_parallel;
        let duplicate_check = self.duplicate_check && other.duplicate_check;
        let should_broadcast = self.should_broadcast || other.should_broadcast;

        Self {
            run_parallel,
            duplicate_check,
            should_broadcast,
        }
    }
}