polars_plan/dsl/builder_dsl.rs

use std::sync::Arc;

use polars_core::prelude::*;
#[cfg(feature = "csv")]
use polars_io::csv::read::CsvReadOptions;
#[cfg(feature = "ipc")]
use polars_io::ipc::IpcScanOptions;
#[cfg(feature = "parquet")]
use polars_io::parquet::read::ParquetOptions;

#[cfg(feature = "python")]
use crate::dsl::python_dsl::PythonFunction;
use crate::prelude::*;

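/// A thin builder around `DslPlan` with chainable methods for assembling
/// logical plan nodes.
///
/// A minimal usage sketch (assuming a `DslPlan` named `plan` and the usual
/// `col`/`lit` helpers from the prelude):
///
/// ```ignore
/// let filtered: DslPlan = DslBuilder::from(plan)
///     .filter(col("a").gt(lit(0)))
///     .build();
/// ```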
pub struct DslBuilder(pub DslPlan);

impl From<DslPlan> for DslBuilder {
    fn from(lp: DslPlan) -> Self {
        DslBuilder(lp)
    }
}

impl DslBuilder {
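    /// Scan an anonymous source through a user-provided `AnonymousScan` implementation.
    /// Errors if `unified_scan_args` does not carry a schema.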
    pub fn anonymous_scan(
        function: Arc<dyn AnonymousScan>,
        options: AnonymousScanOptions,
        unified_scan_args: UnifiedScanArgs,
    ) -> PolarsResult<Self> {
        let schema = unified_scan_args.schema.clone().ok_or_else(|| {
            polars_err!(
                ComputeError:
                "anonymous scan requires schema to be specified in unified_scan_args"
            )
        })?;

        Ok(DslPlan::Scan {
            sources: ScanSources::default(),
            file_info: Some(FileInfo {
                schema: schema.clone(),
                reader_schema: Some(either::Either::Right(schema)),
                ..Default::default()
            }),
            unified_scan_args: Box::new(unified_scan_args),
            scan_type: Box::new(FileScan::Anonymous {
                function,
                options: Arc::new(options),
            }),
            cached_ir: Default::default(),
        }
        .into())
    }

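    /// Scan one or more Parquet sources.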
53    #[cfg(feature = "parquet")]
54    #[allow(clippy::too_many_arguments)]
55    pub fn scan_parquet(
56        sources: ScanSources,
57        options: ParquetOptions,
58        unified_scan_args: UnifiedScanArgs,
59    ) -> PolarsResult<Self> {
60        Ok(DslPlan::Scan {
61            sources,
62            file_info: None,
63            unified_scan_args: Box::new(unified_scan_args),
64            scan_type: Box::new(FileScan::Parquet {
65                options,
66                metadata: None,
67            }),
68            cached_ir: Default::default(),
69        }
70        .into())
71    }
72
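    /// Scan one or more Arrow IPC sources.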
73    #[cfg(feature = "ipc")]
74    #[allow(clippy::too_many_arguments)]
75    pub fn scan_ipc(
76        sources: ScanSources,
77        options: IpcScanOptions,
78        unified_scan_args: UnifiedScanArgs,
79    ) -> PolarsResult<Self> {
80        Ok(DslPlan::Scan {
81            sources,
82            file_info: None,
83            unified_scan_args: Box::new(unified_scan_args),
84            scan_type: Box::new(FileScan::Ipc {
85                options,
86                metadata: None,
87            }),
88            cached_ir: Default::default(),
89        }
90        .into())
91    }
92
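    /// Scan one or more CSV sources.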
    #[allow(clippy::too_many_arguments)]
    #[cfg(feature = "csv")]
    pub fn scan_csv(
        sources: ScanSources,
        options: CsvReadOptions,
        unified_scan_args: UnifiedScanArgs,
    ) -> PolarsResult<Self> {
        Ok(DslPlan::Scan {
            sources,
            file_info: None,
            unified_scan_args: Box::new(unified_scan_args),
            scan_type: Box::new(FileScan::Csv { options }),
            cached_ir: Default::default(),
        }
        .into())
    }

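    /// Scan a dataset backed by a Python object.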
110    #[cfg(feature = "python")]
111    pub fn scan_python_dataset(
112        dataset_object: polars_utils::python_function::PythonObject,
113    ) -> DslBuilder {
114        use super::python_dataset::PythonDatasetProvider;
115
116        DslPlan::Scan {
117            sources: ScanSources::default(),
118            file_info: None,
119            unified_scan_args: Default::default(),
120            scan_type: Box::new(FileScan::PythonDataset {
121                dataset_object: Arc::new(PythonDatasetProvider::new(dataset_object)),
122                cached_ir: Default::default(),
123            }),
124            cached_ir: Default::default(),
125        }
126        .into()
127    }
128
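    /// Cache this plan so downstream nodes can reuse its result; the cache id is
    /// derived from the address of the wrapped plan.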
    pub fn cache(self) -> Self {
        let input = Arc::new(self.0);
        let id = input.as_ref() as *const DslPlan as usize;
        DslPlan::Cache { input, id }.into()
    }

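    /// Drop the selected columns; `strict` controls how selectors that match nothing are handled.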
    pub fn drop(self, to_drop: Vec<Selector>, strict: bool) -> Self {
        self.map_private(DslFunction::Drop(DropFunction { to_drop, strict }))
    }

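    /// Select the given expressions (a projection).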
    pub fn project(self, exprs: Vec<Expr>, options: ProjectionOptions) -> Self {
        DslPlan::Select {
            expr: exprs,
            input: Arc::new(self.0),
            options,
        }
        .into()
    }

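    /// Fill null values in all columns with `fill_value`.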
    pub fn fill_null(self, fill_value: Expr) -> Self {
        self.project(
            vec![all().fill_null(fill_value)],
            ProjectionOptions {
                duplicate_check: false,
                ..Default::default()
            },
        )
    }

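    /// Drop rows with NaN values in `subset`, or in all float columns when `subset` is `None`.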
    pub fn drop_nans(self, subset: Option<Vec<Expr>>) -> Self {
        if let Some(subset) = subset {
            self.filter(
                all_horizontal(
                    subset
                        .into_iter()
                        .map(|v| v.is_not_nan())
                        .collect::<Vec<_>>(),
                )
                .unwrap(),
            )
        } else {
            self.filter(
                // TODO: when Decimal supports NaN, include here
                all_horizontal([dtype_cols([DataType::Float32, DataType::Float64]).is_not_nan()])
                    .unwrap(),
            )
        }
    }

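    /// Drop rows with null values in `subset`, or in any column when `subset` is `None`.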
    pub fn drop_nulls(self, subset: Option<Vec<Expr>>) -> Self {
        if let Some(subset) = subset {
            self.filter(
                all_horizontal(
                    subset
                        .into_iter()
                        .map(|v| v.is_not_null())
                        .collect::<Vec<_>>(),
                )
                .unwrap(),
            )
        } else {
            self.filter(all_horizontal([all().is_not_null()]).unwrap())
        }
    }

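    /// Fill NaN values with `fill_value`.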
    pub fn fill_nan(self, fill_value: Expr) -> Self {
        self.map_private(DslFunction::FillNan(fill_value))
    }

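    /// Add or replace columns; a no-op when `exprs` is empty.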
    pub fn with_columns(self, exprs: Vec<Expr>, options: ProjectionOptions) -> Self {
        if exprs.is_empty() {
            return self;
        }

        DslPlan::HStack {
            input: Arc::new(self.0),
            exprs,
            options,
        }
        .into()
    }

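    /// Add external context plans whose columns can be referenced by expressions in this plan.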
    pub fn with_context(self, contexts: Vec<DslPlan>) -> Self {
        DslPlan::ExtContext {
            input: Arc::new(self.0),
            contexts,
        }
        .into()
    }

    /// Apply a filter predicate, keeping only the rows for which it evaluates to `true`.
    pub fn filter(self, predicate: Expr) -> Self {
        DslPlan::Filter {
            predicate,
            input: Arc::new(self.0),
        }
        .into()
    }

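    /// Group by `keys` and evaluate `aggs` per group; `apply` optionally runs a UDF over each group.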
    pub fn group_by<E: AsRef<[Expr]>>(
        self,
        keys: Vec<Expr>,
        aggs: E,
        apply: Option<(Arc<dyn DataFrameUdf>, SchemaRef)>,
        maintain_order: bool,
        #[cfg(feature = "dynamic_group_by")] dynamic_options: Option<DynamicGroupOptions>,
        #[cfg(feature = "dynamic_group_by")] rolling_options: Option<RollingGroupOptions>,
    ) -> Self {
        let aggs = aggs.as_ref().to_vec();
        let options = GroupbyOptions {
            #[cfg(feature = "dynamic_group_by")]
            dynamic: dynamic_options,
            #[cfg(feature = "dynamic_group_by")]
            rolling: rolling_options,
            slice: None,
        };

        DslPlan::GroupBy {
            input: Arc::new(self.0),
            keys,
            aggs,
            apply,
            maintain_order,
            options: Arc::new(options),
        }
        .into()
    }

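    /// Consume the builder and return the underlying `DslPlan`.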
    pub fn build(self) -> DslPlan {
        self.0
    }

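    /// Start a plan from an in-memory `DataFrame`.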
    pub fn from_existing_df(df: DataFrame) -> Self {
        let schema = df.schema().clone();
        DslPlan::DataFrameScan {
            df: Arc::new(df),
            schema,
        }
        .into()
    }

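    /// Sort by the given expressions according to `sort_options`.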
    pub fn sort(self, by_column: Vec<Expr>, sort_options: SortMultipleOptions) -> Self {
        DslPlan::Sort {
            input: Arc::new(self.0),
            by_column,
            slice: None,
            sort_options,
        }
        .into()
    }

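    /// Explode the selected columns; `allow_empty` permits a selection that matches nothing.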
    pub fn explode(self, columns: Vec<Selector>, allow_empty: bool) -> Self {
        DslPlan::MapFunction {
            input: Arc::new(self.0),
            function: DslFunction::Explode {
                columns,
                allow_empty,
            },
        }
        .into()
    }

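    /// Unpivot (melt) from wide to long format.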
291    #[cfg(feature = "pivot")]
292    pub fn unpivot(self, args: UnpivotArgsDSL) -> Self {
293        DslPlan::MapFunction {
294            input: Arc::new(self.0),
295            function: DslFunction::Unpivot { args },
296        }
297        .into()
298    }
299
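    /// Add a row index column named `name`, optionally starting at `offset`.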
    pub fn row_index(self, name: PlSmallStr, offset: Option<IdxSize>) -> Self {
        DslPlan::MapFunction {
            input: Arc::new(self.0),
            function: DslFunction::RowIndex { name, offset },
        }
        .into()
    }

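    /// Drop duplicate rows according to `options`.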
    pub fn distinct(self, options: DistinctOptionsDSL) -> Self {
        DslPlan::Distinct {
            input: Arc::new(self.0),
            options,
        }
        .into()
    }

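    /// Take `len` rows starting at row `offset`.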
    pub fn slice(self, offset: i64, len: IdxSize) -> Self {
        DslPlan::Slice {
            input: Arc::new(self.0),
            offset,
            len,
        }
        .into()
    }

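    /// Join this plan with `other` on the given key expressions.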
    pub fn join(
        self,
        other: DslPlan,
        left_on: Vec<Expr>,
        right_on: Vec<Expr>,
        options: Arc<JoinOptions>,
    ) -> Self {
        DslPlan::Join {
            input_left: Arc::new(self.0),
            input_right: Arc::new(other),
            left_on,
            right_on,
            predicates: Default::default(),
            options,
        }
        .into()
    }

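    /// Low-level helper: wrap this plan in a `MapFunction` node running `function`.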
    pub fn map_private(self, function: DslFunction) -> Self {
        DslPlan::MapFunction {
            input: Arc::new(self.0),
            function,
        }
        .into()
    }

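    /// Apply an opaque Python UDF; the `optimizations` flags control which pushdowns
    /// and streaming the node allows.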
350    #[cfg(feature = "python")]
351    pub fn map_python(
352        self,
353        function: PythonFunction,
354        optimizations: AllowedOptimizations,
355        schema: Option<SchemaRef>,
356        validate_output: bool,
357    ) -> Self {
358        DslPlan::MapFunction {
359            input: Arc::new(self.0),
360            function: DslFunction::OpaquePython(OpaquePythonUdf {
361                function,
362                schema,
363                predicate_pd: optimizations.contains(OptFlags::PREDICATE_PUSHDOWN),
364                projection_pd: optimizations.contains(OptFlags::PROJECTION_PUSHDOWN),
365                streamable: optimizations.contains(OptFlags::STREAMING),
366                validate_output,
367            }),
368        }
369        .into()
370    }
371
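    /// Apply an opaque Rust `DataFrameUdf`; `name` is the string shown when the plan is formatted.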
    pub fn map<F>(
        self,
        function: F,
        optimizations: AllowedOptimizations,
        schema: Option<Arc<dyn UdfSchema>>,
        name: PlSmallStr,
    ) -> Self
    where
        F: DataFrameUdf + 'static,
    {
        let function = Arc::new(function);

        DslPlan::MapFunction {
            input: Arc::new(self.0),
            function: DslFunction::FunctionIR(FunctionIR::Opaque {
                function,
                schema,
                predicate_pd: optimizations.contains(OptFlags::PREDICATE_PUSHDOWN),
                projection_pd: optimizations.contains(OptFlags::PROJECTION_PUSHDOWN),
                streamable: optimizations.contains(OptFlags::STREAMING),
                fmt_str: name,
            }),
        }
        .into()
    }
}