// polars_plan/plans/options.rs

1use bitflags::bitflags;
2use polars_core::prelude::*;
3use polars_core::utils::SuperTypeOptions;
4#[cfg(feature = "serde")]
5use serde::{Deserialize, Serialize};
6
7use crate::plans::PlSmallStr;
8
9#[derive(Clone, Debug, Eq, PartialEq, Hash)]
10#[cfg_attr(feature = "ir_serde", derive(Serialize, Deserialize))]
11pub struct DistinctOptionsIR {
12    /// Subset of columns that will be taken into account.
13    pub subset: Option<Arc<[PlSmallStr]>>,
14    /// This will maintain the order of the input.
15    /// Note that this is more expensive.
16    /// `maintain_order` is not supported in the streaming
17    /// engine.
18    pub maintain_order: bool,
19    /// Which rows to keep.
20    pub keep_strategy: UniqueKeepStrategy,
21    /// Take only a slice of the result
22    pub slice: Option<(i64, usize)>,
23}
24
/// A `bool` that starts out `true`.
///
/// The wrapped value is private, so code outside this module cannot build an arbitrary
/// `UnsafeBool`; [`Default`] yields `true`. In this file the value is only set to `false`
/// inside the `unsafe fn` `FunctionOptions::no_check_lengths`, because clearing it turns
/// off a safety check.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub struct UnsafeBool(bool);

impl Default for UnsafeBool {
    /// Returns the enabled state, `UnsafeBool(true)`.
    fn default() -> Self {
        Self(true)
    }
}
35
/// JSON schema for [`FunctionFlags`].
///
/// The flags are described as a string with the custom `bitflags` format. A `bitflags`
/// extension records every named flag together with its bit pattern, so that any change
/// to the flag set also changes the generated schema.
#[cfg(feature = "dsl-schema")]
impl schemars::JsonSchema for FunctionFlags {
    fn schema_name() -> String {
        String::from("FunctionFlags")
    }

    fn schema_id() -> std::borrow::Cow<'static, str> {
        let id: &'static str = concat!(module_path!(), "::", "FunctionFlags");
        std::borrow::Cow::Borrowed(id)
    }

    fn json_schema(_generator: &mut schemars::r#gen::SchemaGenerator) -> schemars::schema::Schema {
        use schemars::schema::{InstanceType, Schema, SchemaObject};
        use serde_json::{Map, Value};

        // Flag name -> bit pattern, in declaration order; used to detect schema changes.
        let mut name_to_bits = Map::new();
        for (name, flag) in Self::all().iter_names() {
            name_to_bits.insert(name.to_string(), Value::from(flag.bits()));
        }

        let mut extensions = schemars::Map::new();
        extensions.insert("bitflags".to_string(), Value::Object(name_to_bits));

        Schema::Object(SchemaObject {
            instance_type: Some(InstanceType::String.into()),
            format: Some(String::from("bitflags")),
            extensions,
            ..Default::default()
        })
    }
}
65
66bitflags!(
67        #[repr(transparent)]
68        #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
69        #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
70        pub struct FunctionFlags: u16 {
71            /// The physical expression may rename the output of this function.
72            /// If set to `false` the physical engine will ensure the left input
73            /// expression is the output name.
74            const ALLOW_RENAME = 1 << 0;
75            /// if set, then the `Series` passed to the function in the group_by operation
76            /// will ensure the name is set. This is an extra heap allocation per group.
77            const PASS_NAME_TO_APPLY = 1 << 1;
78            /// There can be two ways of expanding wildcards:
79            ///
80            /// Say the schema is 'a', 'b' and there is a function `f`. In this case, `f('*')` can expand
81            /// to:
82            /// 1. `f('a', 'b')`
83            /// 2. `f('a'), f('b')`
84            ///
85            /// Setting this to true, will lead to behavior 1.
86            ///
87            /// This also accounts for regex expansion.
88            const INPUT_WILDCARD_EXPANSION = 1 << 2;
89            /// Automatically explode on unit length if it ran as final aggregation.
90            ///
91            /// this is the case for aggregations like sum, min, covariance etc.
92            /// We need to know this because we cannot see the difference between
93            /// the following functions based on the output type and number of elements:
94            ///
95            /// x: {1, 2, 3}
96            ///
97            /// head_1(x) -> {1}
98            /// sum(x) -> {4}
99            ///
100            /// mutually exclusive with `RETURNS_SCALAR`
101            const RETURNS_SCALAR = 1 << 3;
102            /// This can happen with UDF's that use Polars within the UDF.
103            /// This can lead to recursively entering the engine and sometimes deadlocks.
104            /// This flag must be set to handle that.
105            const OPTIONAL_RE_ENTRANT = 1 << 4;
106            /// Whether this function allows no inputs.
107            const ALLOW_EMPTY_INPUTS = 1 << 5;
108
109            /// Given a function f and a column of values [v1, ..., vn]
110            /// f is row-separable i.f.f.
111            /// f([v1, ..., vn]) = concat(f(v1, ... vm), f(vm+1, ..., vn))
112            const ROW_SEPARABLE = 1 << 6;
113            /// Given a function f and a column of values [v1, ..., vn]
114            /// f is length preserving i.f.f. len(f([v1, ..., vn])) = n
115            ///
116            /// mutually exclusive with `RETURNS_SCALAR`
117            const LENGTH_PRESERVING = 1 << 7;
118            /// NULLs on the first input are propagated to the output.
119            const PRESERVES_NULL_FIRST_INPUT = 1 << 8;
120            /// NULLs on any input are propagated to the output.
121            const PRESERVES_NULL_ALL_INPUTS = 1 << 9;
122        }
123);
124
125impl FunctionFlags {
126    pub fn set_elementwise(&mut self) {
127        *self |= Self::ROW_SEPARABLE | Self::LENGTH_PRESERVING;
128    }
129
130    pub fn is_elementwise(self) -> bool {
131        self.contains(Self::ROW_SEPARABLE | Self::LENGTH_PRESERVING)
132    }
133
134    pub fn is_row_separable(self) -> bool {
135        self.contains(Self::ROW_SEPARABLE)
136    }
137
138    pub fn is_length_preserving(self) -> bool {
139        self.contains(Self::LENGTH_PRESERVING)
140    }
141
142    pub fn returns_scalar(self) -> bool {
143        self.contains(Self::RETURNS_SCALAR)
144    }
145}
146
147impl Default for FunctionFlags {
148    fn default() -> Self {
149        Self::from_bits_truncate(0)
150    }
151}
152
153#[derive(Clone, Copy, PartialEq, Eq, Debug, Hash)]
154pub enum CastingRules {
155    /// Whether information may be lost during cast. E.g. a float to int is considered lossy,
156    /// whereas int to int is considered lossless.
157    /// Overflowing is not considered in this flag, that's handled in `strict` casting
158    FirstArgLossless,
159    Supertype(SuperTypeOptions),
160}
161
162impl CastingRules {
163    pub fn cast_to_supertypes() -> CastingRules {
164        Self::Supertype(Default::default())
165    }
166}
167
168#[derive(Clone, Copy, PartialEq, Eq, Debug, Hash)]
169#[cfg_attr(any(feature = "serde"), derive(Serialize, Deserialize))]
170#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
171pub struct FunctionOptions {
172    // Validate the output of a `map`.
173    // this should always be true or we could OOB
174    pub check_lengths: UnsafeBool,
175    pub flags: FunctionFlags,
176
177    /// Options used when deciding how to cast the arguments of the function.
178    #[cfg_attr(any(feature = "serde", feature = "dsl-schema"), serde(skip))]
179    pub cast_options: Option<CastingRules>,
180}
181
182impl FunctionOptions {
183    #[cfg(feature = "fused")]
184    pub(crate) unsafe fn no_check_lengths(&mut self) {
185        self.check_lengths = UnsafeBool(false);
186    }
187    pub fn check_lengths(&self) -> bool {
188        self.check_lengths.0
189    }
190
191    pub fn set_elementwise(&mut self) {
192        self.flags.set_elementwise();
193    }
194
195    pub fn is_elementwise(&self) -> bool {
196        self.flags.is_elementwise()
197    }
198
199    pub fn is_length_preserving(&self) -> bool {
200        self.flags.contains(FunctionFlags::LENGTH_PRESERVING)
201    }
202
203    pub fn is_row_separable(&self) -> bool {
204        self.flags.is_row_separable()
205    }
206
207    pub fn returns_scalar(&self) -> bool {
208        self.flags.returns_scalar()
209    }
210
211    pub fn elementwise() -> FunctionOptions {
212        FunctionOptions {
213            ..Default::default()
214        }
215        .with_flags(|f| f | FunctionFlags::ROW_SEPARABLE | FunctionFlags::LENGTH_PRESERVING)
216    }
217
218    pub fn elementwise_with_infer() -> FunctionOptions {
219        Self::length_preserving()
220    }
221
222    pub fn row_separable() -> FunctionOptions {
223        FunctionOptions {
224            ..Default::default()
225        }
226        .with_flags(|f| f | FunctionFlags::ROW_SEPARABLE)
227    }
228
229    pub fn length_preserving() -> FunctionOptions {
230        FunctionOptions {
231            ..Default::default()
232        }
233        .with_flags(|f| f | FunctionFlags::LENGTH_PRESERVING)
234    }
235
236    pub fn groupwise() -> FunctionOptions {
237        FunctionOptions {
238            ..Default::default()
239        }
240    }
241
242    pub fn aggregation() -> FunctionOptions {
243        let mut options = Self::groupwise();
244        options.flags |= FunctionFlags::RETURNS_SCALAR;
245        options
246    }
247
248    pub fn with_supertyping(self, supertype_options: SuperTypeOptions) -> FunctionOptions {
249        self.with_casting_rules(CastingRules::Supertype(supertype_options))
250    }
251
252    pub fn with_casting_rules(mut self, casting_rules: CastingRules) -> FunctionOptions {
253        self.cast_options = Some(casting_rules);
254        self
255    }
256
257    pub fn with_flags(mut self, f: impl Fn(FunctionFlags) -> FunctionFlags) -> FunctionOptions {
258        self.flags = f(self.flags);
259        self
260    }
261}
262
263impl Default for FunctionOptions {
264    fn default() -> Self {
265        FunctionOptions {
266            check_lengths: UnsafeBool(true),
267            cast_options: Default::default(),
268            flags: Default::default(),
269        }
270    }
271}
272
/// Options of a projection. Every option defaults to `true`.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ProjectionOptions {
    /// Merged with logical AND in [`ProjectionOptions::merge_options`].
    pub run_parallel: bool,
    /// Merged with logical AND in [`ProjectionOptions::merge_options`].
    pub duplicate_check: bool,
    /// Should length-1 Series be broadcast to the length of the dataframe.
    /// Only used by the CSE optimizer. Merged with logical OR.
    pub should_broadcast: bool,
}

impl Default for ProjectionOptions {
    /// All options enabled.
    fn default() -> Self {
        let enabled = true;
        Self {
            run_parallel: enabled,
            duplicate_check: enabled,
            should_broadcast: enabled,
        }
    }
}

impl ProjectionOptions {
    /// Conservatively merge the options of two [`ProjectionOptions`].
    ///
    /// `run_parallel` and `duplicate_check` stay enabled only if both sides enable them;
    /// `should_broadcast` is enabled if either side enables it.
    pub fn merge_options(&self, other: &Self) -> Self {
        // Destructure so that adding a field forces this merge to be revisited.
        let Self {
            run_parallel,
            duplicate_check,
            should_broadcast,
        } = *self;
        Self {
            run_parallel: run_parallel && other.run_parallel,
            duplicate_check: duplicate_check && other.duplicate_check,
            should_broadcast: should_broadcast || other.should_broadcast,
        }
    }
}