// polars_plan/dsl/options/mod.rs

1use std::hash::Hash;
2#[cfg(feature = "json")]
3use std::num::NonZeroUsize;
4use std::str::FromStr;
5use std::sync::Arc;
6
7mod sink;
8
9use polars_core::error::PolarsResult;
10use polars_core::prelude::*;
11#[cfg(feature = "csv")]
12use polars_io::csv::write::CsvWriterOptions;
13#[cfg(feature = "ipc")]
14use polars_io::ipc::IpcWriterOptions;
15#[cfg(feature = "json")]
16use polars_io::json::JsonWriterOptions;
17#[cfg(feature = "parquet")]
18use polars_io::parquet::write::ParquetWriteOptions;
19#[cfg(feature = "iejoin")]
20use polars_ops::frame::IEJoinOptions;
21use polars_ops::frame::{CrossJoinFilter, CrossJoinOptions, JoinTypeOptions};
22use polars_ops::prelude::{JoinArgs, JoinType};
23#[cfg(feature = "dynamic_group_by")]
24use polars_time::DynamicGroupOptions;
25#[cfg(feature = "dynamic_group_by")]
26use polars_time::RollingGroupOptions;
27use polars_utils::IdxSize;
28use polars_utils::pl_str::PlSmallStr;
29#[cfg(feature = "serde")]
30use serde::{Deserialize, Serialize};
31pub use sink::*;
32use strum_macros::IntoStaticStr;
33
34use super::ExprIR;
35use crate::dsl::Selector;
36
#[derive(Copy, Clone, PartialEq, Debug, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
/// Options for rolling covariance / correlation computations.
pub struct RollingCovOptions {
    /// Size of the rolling window, in number of rows.
    pub window_size: IdxSize,
    /// Minimum number of observations in the window required to emit a value.
    pub min_periods: IdxSize,
    /// "Delta degrees of freedom" adjustment
    /// (presumably the conventional `N - ddof` divisor — confirm in the compute kernel).
    pub ddof: u8,
}
44
#[derive(Clone, PartialEq, Debug, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
/// Options controlling string-to-temporal parsing (strptime-style).
pub struct StrptimeOptions {
    /// Formatting string
    pub format: Option<PlSmallStr>,
    /// If set then polars will return an error if any date parsing fails
    pub strict: bool,
    /// If `false`, polars may parse matches that do not span the whole string,
    /// e.g. "foo-2021-01-01-bar" could match "2021-01-01"
    pub exact: bool,
    /// use a cache of unique, converted dates to apply the datetime conversion.
    pub cache: bool,
}
58
59impl Default for StrptimeOptions {
60    fn default() -> Self {
61        StrptimeOptions {
62            format: None,
63            strict: true,
64            exact: true,
65            cache: true,
66        }
67    }
68}
69
#[derive(Clone, PartialEq, Eq, IntoStaticStr, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[strum(serialize_all = "snake_case")]
/// Join-type specific options at the IR level.
pub enum JoinTypeOptionsIR {
    /// Inequality-join options.
    #[cfg(feature = "iejoin")]
    IEJoin(IEJoinOptions),
    #[cfg_attr(all(feature = "serde", not(feature = "ir_serde")), serde(skip))]
    // Fused cross join and filter (only in in-memory engine)
    Cross { predicate: ExprIR },
}
80
81impl Hash for JoinTypeOptionsIR {
82    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
83        use JoinTypeOptionsIR::*;
84        match self {
85            #[cfg(feature = "iejoin")]
86            IEJoin(opt) => opt.hash(state),
87            Cross { predicate } => predicate.node().hash(state),
88        }
89    }
90}
91
92impl JoinTypeOptionsIR {
93    pub fn compile<C: FnOnce(&ExprIR) -> PolarsResult<Arc<dyn CrossJoinFilter>>>(
94        self,
95        plan: C,
96    ) -> PolarsResult<JoinTypeOptions> {
97        use JoinTypeOptionsIR::*;
98        match self {
99            Cross { predicate } => {
100                let predicate = plan(&predicate)?;
101
102                Ok(JoinTypeOptions::Cross(CrossJoinOptions { predicate }))
103            },
104            #[cfg(feature = "iejoin")]
105            IEJoin(opt) => Ok(JoinTypeOptions::IEJoin(opt)),
106        }
107    }
108}
109
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
/// Options attached to a join node in the plan.
pub struct JoinOptions {
    /// Allow evaluating the join inputs in parallel.
    pub allow_parallel: bool,
    /// Force evaluating the join inputs in parallel.
    pub force_parallel: bool,
    /// The join arguments (type, keys, suffix, ...).
    pub args: JoinArgs,
    /// Join-type specific options, if any.
    pub options: Option<JoinTypeOptionsIR>,
    /// Proxy of the number of rows in both sides of the joins
    /// Holds `(Option<known_size>, estimated_size)`
    pub rows_left: (Option<usize>, usize),
    pub rows_right: (Option<usize>, usize),
}
122
123impl Default for JoinOptions {
124    fn default() -> Self {
125        JoinOptions {
126            allow_parallel: true,
127            force_parallel: false,
128            // Todo!: make default
129            args: JoinArgs::new(JoinType::Left),
130            options: Default::default(),
131            rows_left: (None, usize::MAX),
132            rows_right: (None, usize::MAX),
133        }
134    }
135}
136
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
/// The kind of window operation applied by a window expression.
pub enum WindowType {
    /// A plain `over(...)` window; how results map back to rows is
    /// determined by the [`WindowMapping`] strategy.
    // NOTE(review): the previous doc here was identical to the `Explode`
    // variant's doc on `WindowMapping` — it looked copy-pasted; verify intent.
    Over(WindowMapping),
    #[cfg(feature = "dynamic_group_by")]
    Rolling(RollingGroupOptions),
}
146
147impl From<WindowMapping> for WindowType {
148    fn from(value: WindowMapping) -> Self {
149        Self::Over(value)
150    }
151}
152
153impl Default for WindowType {
154    fn default() -> Self {
155        Self::Over(WindowMapping::default())
156    }
157}
158
#[derive(Copy, Clone, Debug, PartialEq, Eq, Default, Hash, IntoStaticStr)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[strum(serialize_all = "snake_case")]
/// How window-aggregation results are mapped back to the input rows.
pub enum WindowMapping {
    /// Map the group values to the position
    #[default]
    GroupsToRows,
    /// Explode the aggregated list and just do a hstack instead of a join
    /// this requires the groups to be sorted to make any sense
    Explode,
    /// Join the groups as 'List<group_dtype>' to the row positions.
    /// warning: this can be memory intensive
    Join,
}
173
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
/// A nested (container) data type kind.
pub enum NestedType {
    #[cfg(feature = "dtype-array")]
    Array,
    // List,
}
181
#[derive(Clone, Default, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
/// DSL-level arguments for `unpivot` (wide-to-long reshaping).
pub struct UnpivotArgsDSL {
    /// Columns whose values are unpivoted into rows.
    pub on: Vec<Selector>,
    /// Columns kept as identifier columns.
    pub index: Vec<Selector>,
    /// Name for the output variable column (presumably defaults elsewhere when `None`).
    pub variable_name: Option<PlSmallStr>,
    /// Name for the output value column (presumably defaults elsewhere when `None`).
    pub value_name: Option<PlSmallStr>,
}
190
#[derive(Clone, Debug, Copy, Eq, PartialEq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
/// Which execution engine to run a query on.
/// See the [`FromStr`] impl below for the accepted string names.
pub enum Engine {
    /// Pick an engine automatically ("auto").
    Auto,
    /// The legacy streaming engine ("old-streaming").
    OldStreaming,
    /// The streaming engine ("streaming").
    Streaming,
    /// The in-memory engine ("in-memory"; also "cpu" for backwards compatibility).
    InMemory,
    /// The GPU engine ("gpu").
    Gpu,
}
200
201impl FromStr for Engine {
202    type Err = String;
203
204    fn from_str(s: &str) -> Result<Self, Self::Err> {
205        match s {
206            // "cpu" for backwards compatibility
207            "auto" => Ok(Engine::Auto),
208            "cpu" | "in-memory" => Ok(Engine::InMemory),
209            "streaming" => Ok(Engine::Streaming),
210            "old-streaming" => Ok(Engine::OldStreaming),
211            "gpu" => Ok(Engine::Gpu),
212            v => Err(format!(
213                "`engine` must be one of {{'auto', 'in-memory', 'streaming', 'old-streaming', 'gpu'}}, got {v}",
214            )),
215        }
216    }
217}
218
219impl Engine {
220    pub fn into_static_str(self) -> &'static str {
221        match self {
222            Self::Auto => "auto",
223            Self::OldStreaming => "old-streaming",
224            Self::Streaming => "streaming",
225            Self::InMemory => "in-memory",
226            Self::Gpu => "gpu",
227        }
228    }
229}
230
#[derive(Clone, Debug, Copy, Default, Eq, PartialEq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
/// IR-level options of a union node (cf. [`UnionArgs`], the DSL-level input).
pub struct UnionOptions {
    /// Optional `(offset, len)` slice applied to the union result.
    pub slice: Option<(i64, usize)>,
    // known row_output, estimated row output
    pub rows: (Option<usize>, usize),
    /// Evaluate the inputs in parallel.
    pub parallel: bool,
    /// Whether this union comes from a partitioned dataset scan.
    pub from_partitioned_ds: bool,
    /// Set when an optimization pass flattened nested unions — TODO confirm.
    pub flattened_by_opt: bool,
    /// Rechunk the concatenated output.
    pub rechunk: bool,
    /// Preserve the order of the inputs in the output.
    pub maintain_order: bool,
}
243
#[derive(Clone, Debug, Copy, Default, Eq, PartialEq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
/// Options for horizontal concatenation.
pub struct HConcatOptions {
    /// Evaluate the inputs in parallel.
    pub parallel: bool,
}
249
#[derive(Clone, Debug, PartialEq, Eq, Default, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
/// Options attached to a group-by node in the plan.
pub struct GroupbyOptions {
    /// Set when this is a dynamic (`group_by_dynamic`) group-by.
    #[cfg(feature = "dynamic_group_by")]
    pub dynamic: Option<DynamicGroupOptions>,
    /// Set when this is a rolling (`rolling`) group-by.
    #[cfg(feature = "dynamic_group_by")]
    pub rolling: Option<RollingGroupOptions>,
    /// Take only a slice of the result
    pub slice: Option<(i64, usize)>,
}
260
261impl GroupbyOptions {
262    pub(crate) fn is_rolling(&self) -> bool {
263        #[cfg(feature = "dynamic_group_by")]
264        {
265            self.rolling.is_some()
266        }
267        #[cfg(not(feature = "dynamic_group_by"))]
268        {
269            false
270        }
271    }
272
273    pub(crate) fn is_dynamic(&self) -> bool {
274        #[cfg(feature = "dynamic_group_by")]
275        {
276            self.dynamic.is_some()
277        }
278        #[cfg(not(feature = "dynamic_group_by"))]
279        {
280            false
281        }
282    }
283}
284
#[derive(Clone, Debug, Eq, PartialEq, Default, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
/// DSL-level options for `unique`/`distinct`.
pub struct DistinctOptionsDSL {
    /// Subset of columns that will be taken into account.
    pub subset: Option<Vec<Selector>>,
    /// This will maintain the order of the input.
    /// Note that this is more expensive.
    /// `maintain_order` is not supported in the streaming
    /// engine.
    pub maintain_order: bool,
    /// Which rows to keep.
    pub keep_strategy: UniqueKeepStrategy,
}
298
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
/// Options for an opaque UDF node in the logical plan.
pub struct LogicalPlanUdfOptions {
    /// allow predicate pushdown optimizations
    pub predicate_pd: bool,
    /// allow projection pushdown optimizations
    pub projection_pd: bool,
    // used for formatting
    pub fmt_str: &'static str,
}
308
#[derive(Clone, PartialEq, Eq, Debug, Default, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
/// Options for an anonymous (user-provided) scan.
pub struct AnonymousScanOptions {
    /// Number of leading rows to skip.
    pub skip_rows: Option<usize>,
    /// Static label used when formatting/displaying the scan.
    pub fmt_str: &'static str,
}
315
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
/// Output file format, carrying the format-specific write options.
pub enum FileType {
    #[cfg(feature = "parquet")]
    Parquet(ParquetWriteOptions),
    #[cfg(feature = "ipc")]
    Ipc(IpcWriterOptions),
    #[cfg(feature = "csv")]
    Csv(CsvWriterOptions),
    #[cfg(feature = "json")]
    Json(JsonWriterOptions),
}
328
impl FileType {
    /// The canonical file extension for this output format.
    pub fn extension(&self) -> &'static str {
        match self {
            #[cfg(feature = "parquet")]
            Self::Parquet(_) => "parquet",
            #[cfg(feature = "ipc")]
            Self::Ipc(_) => "ipc",
            #[cfg(feature = "csv")]
            Self::Csv(_) => "csv",
            // NOTE: extension is "jsonl" (line-delimited JSON), not "json".
            #[cfg(feature = "json")]
            Self::Json(_) => "jsonl",

            // Reachable only when no file-format feature is enabled, in which
            // case the enum is uninhabited anyway.
            #[allow(unreachable_patterns)]
            _ => unreachable!("enable file type features"),
        }
    }
}
346
/// Arguments given to `concat`. Differs from `UnionOptions` as the latter is IR state.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct UnionArgs {
    /// Evaluate the inputs in parallel.
    pub parallel: bool,
    /// Rechunk the concatenated output.
    pub rechunk: bool,
    /// Cast columns to their common supertype where needed.
    pub to_supertypes: bool,
    /// Diagonal (column-union) concatenation.
    pub diagonal: bool,
    // If it is a union from a scan over multiple files.
    pub from_partitioned_ds: bool,
    /// Preserve the order of the inputs in the output.
    pub maintain_order: bool,
}
360
361impl Default for UnionArgs {
362    fn default() -> Self {
363        Self {
364            parallel: true,
365            rechunk: false,
366            to_supertypes: false,
367            diagonal: false,
368            from_partitioned_ds: false,
369            maintain_order: true,
370        }
371    }
372}
373
374impl From<UnionArgs> for UnionOptions {
375    fn from(args: UnionArgs) -> Self {
376        UnionOptions {
377            slice: None,
378            parallel: args.parallel,
379            rows: (None, 0),
380            from_partitioned_ds: args.from_partitioned_ds,
381            flattened_by_opt: false,
382            rechunk: args.rechunk,
383            maintain_order: args.maintain_order,
384        }
385    }
386}
387
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg(feature = "json")]
/// Options for reading newline-delimited JSON (NDJSON).
pub struct NDJsonReadOptions {
    /// Number of reader threads; `None` presumably lets the reader decide — TODO confirm.
    pub n_threads: Option<usize>,
    /// Number of rows used to infer the schema; `None` presumably means scan everything — TODO confirm.
    pub infer_schema_length: Option<NonZeroUsize>,
    /// Size of the chunks read at a time.
    pub chunk_size: NonZeroUsize,
    /// Trade performance for lower memory usage.
    pub low_memory: bool,
    /// Skip lines that fail to parse instead of erroring.
    pub ignore_errors: bool,
    /// Full schema to use instead of inferring one.
    pub schema: Option<SchemaRef>,
    /// Partial schema overriding inferred dtypes.
    pub schema_overwrite: Option<SchemaRef>,
}