polars_mem_engine/scan_predicate/
functions.rs

1use std::cell::LazyCell;
2use std::sync::Arc;
3
4use polars_core::config;
5use polars_core::error::PolarsResult;
6use polars_core::prelude::{IDX_DTYPE, IdxCa, InitHashMaps, PlHashMap, PlIndexMap, PlIndexSet};
7use polars_core::schema::Schema;
8use polars_error::polars_warn;
9use polars_expr::{ExpressionConversionState, create_physical_expr};
10use polars_io::predicates::ScanIOPredicate;
11use polars_plan::dsl::default_values::{
12    DefaultFieldValues, IcebergIdentityTransformedPartitionFields,
13};
14use polars_plan::dsl::deletion::DeletionFilesList;
15use polars_plan::dsl::{FileScanIR, Operator, ScanSources, TableStatistics, UnifiedScanArgs};
16use polars_plan::plans::expr_ir::{ExprIR, OutputName};
17use polars_plan::plans::hive::HivePartitionsDf;
18use polars_plan::plans::predicates::{aexpr_to_column_predicates, aexpr_to_skip_batch_predicate};
19use polars_plan::plans::{AExpr, Context, ExprIRDisplay, FileInfo, IR, MintermIter};
20use polars_plan::utils::aexpr_to_leaf_names_iter;
21use polars_utils::arena::{Arena, Node};
22use polars_utils::pl_str::PlSmallStr;
23use polars_utils::{IdxSize, format_pl_smallstr};
24
25use crate::scan_predicate::skip_files_mask::SkipFilesMask;
26use crate::scan_predicate::{PhysicalColumnPredicates, ScanPredicate};
27
28pub fn create_scan_predicate(
29    predicate: &ExprIR,
30    expr_arena: &mut Arena<AExpr>,
31    schema: &Arc<Schema>,
32    hive_schema: Option<&Schema>,
33    state: &mut ExpressionConversionState,
34    create_skip_batch_predicate: bool,
35    create_column_predicates: bool,
36) -> PolarsResult<ScanPredicate> {
37    let mut predicate = predicate.clone();
38
39    let mut hive_predicate = None;
40    let mut hive_predicate_is_full_predicate = false;
41
42    #[expect(clippy::never_loop)]
43    loop {
44        let Some(hive_schema) = hive_schema else {
45            break;
46        };
47
48        let mut hive_predicate_parts = vec![];
49        let mut non_hive_predicate_parts = vec![];
50
51        for predicate_part in MintermIter::new(predicate.node(), expr_arena) {
52            if aexpr_to_leaf_names_iter(predicate_part, expr_arena)
53                .all(|name| hive_schema.contains(&name))
54            {
55                hive_predicate_parts.push(predicate_part)
56            } else {
57                non_hive_predicate_parts.push(predicate_part)
58            }
59        }
60
61        if hive_predicate_parts.is_empty() {
62            break;
63        }
64
65        if non_hive_predicate_parts.is_empty() {
66            hive_predicate_is_full_predicate = true;
67            break;
68        }
69
70        {
71            let mut iter = hive_predicate_parts.into_iter();
72            let mut node = iter.next().unwrap();
73
74            for next_node in iter {
75                node = expr_arena.add(AExpr::BinaryExpr {
76                    left: node,
77                    op: Operator::And,
78                    right: next_node,
79                });
80            }
81
82            hive_predicate = Some(create_physical_expr(
83                &ExprIR::from_node(node, expr_arena),
84                Context::Default,
85                expr_arena,
86                schema,
87                state,
88            )?)
89        }
90
91        {
92            let mut iter = non_hive_predicate_parts.into_iter();
93            let mut node = iter.next().unwrap();
94
95            for next_node in iter {
96                node = expr_arena.add(AExpr::BinaryExpr {
97                    left: node,
98                    op: Operator::And,
99                    right: next_node,
100                });
101            }
102
103            predicate = ExprIR::from_node(node, expr_arena);
104        }
105
106        break;
107    }
108
109    let phys_predicate =
110        create_physical_expr(&predicate, Context::Default, expr_arena, schema, state)?;
111
112    if hive_predicate_is_full_predicate {
113        hive_predicate = Some(phys_predicate.clone());
114    }
115
116    let live_columns = Arc::new(PlIndexSet::from_iter(aexpr_to_leaf_names_iter(
117        predicate.node(),
118        expr_arena,
119    )));
120
121    let mut skip_batch_predicate = None;
122
123    if create_skip_batch_predicate {
124        if let Some(node) = aexpr_to_skip_batch_predicate(predicate.node(), expr_arena, schema) {
125            let expr = ExprIR::new(node, predicate.output_name_inner().clone());
126
127            if std::env::var("POLARS_OUTPUT_SKIP_BATCH_PRED").as_deref() == Ok("1") {
128                eprintln!("predicate: {}", predicate.display(expr_arena));
129                eprintln!("skip_batch_predicate: {}", expr.display(expr_arena));
130            }
131
132            let mut skip_batch_schema = Schema::with_capacity(1 + live_columns.len());
133
134            skip_batch_schema.insert(PlSmallStr::from_static("len"), IDX_DTYPE);
135            for (col, dtype) in schema.iter() {
136                if !live_columns.contains(col) {
137                    continue;
138                }
139
140                skip_batch_schema.insert(format_pl_smallstr!("{col}_min"), dtype.clone());
141                skip_batch_schema.insert(format_pl_smallstr!("{col}_max"), dtype.clone());
142                skip_batch_schema.insert(format_pl_smallstr!("{col}_nc"), IDX_DTYPE);
143            }
144
145            skip_batch_predicate = Some(create_physical_expr(
146                &expr,
147                Context::Default,
148                expr_arena,
149                &Arc::new(skip_batch_schema),
150                state,
151            )?);
152        }
153    }
154
155    let column_predicates = if create_column_predicates {
156        let column_predicates = aexpr_to_column_predicates(predicate.node(), expr_arena, schema);
157        if std::env::var("POLARS_OUTPUT_COLUMN_PREDS").as_deref() == Ok("1") {
158            eprintln!("column_predicates: {{");
159            eprintln!("  [");
160            for (pred, spec) in column_predicates.predicates.values() {
161                eprintln!(
162                    "    {} ({spec:?}),",
163                    ExprIRDisplay::display_node(*pred, expr_arena)
164                );
165            }
166            eprintln!("  ],");
167            eprintln!(
168                "  is_sumwise_complete: {}",
169                column_predicates.is_sumwise_complete
170            );
171            eprintln!("}}");
172        }
173        PhysicalColumnPredicates {
174            predicates: column_predicates
175                .predicates
176                .into_iter()
177                .map(|(n, (p, s))| {
178                    PolarsResult::Ok((
179                        n,
180                        (
181                            create_physical_expr(
182                                &ExprIR::new(p, OutputName::Alias(PlSmallStr::EMPTY)),
183                                Context::Default,
184                                expr_arena,
185                                schema,
186                                state,
187                            )?,
188                            s,
189                        ),
190                    ))
191                })
192                .collect::<PolarsResult<PlHashMap<_, _>>>()?,
193            is_sumwise_complete: column_predicates.is_sumwise_complete,
194        }
195    } else {
196        PhysicalColumnPredicates {
197            predicates: PlHashMap::default(),
198            is_sumwise_complete: false,
199        }
200    };
201
202    PolarsResult::Ok(ScanPredicate {
203        predicate: phys_predicate,
204        live_columns,
205        skip_batch_predicate,
206        column_predicates,
207        hive_predicate,
208        hive_predicate_is_full_predicate,
209    })
210}
211
212/// # Returns
213/// (skip_files_mask, predicate)
214pub fn initialize_scan_predicate<'a>(
215    predicate: Option<&'a ScanIOPredicate>,
216    hive_parts: Option<&HivePartitionsDf>,
217    table_statsitics: Option<&TableStatistics>,
218    verbose: bool,
219) -> PolarsResult<(Option<SkipFilesMask>, Option<&'a ScanIOPredicate>)> {
220    #[expect(clippy::never_loop)]
221    loop {
222        let Some(predicate) = predicate else {
223            break;
224        };
225
226        let expected_mask_len: usize;
227
228        let (skip_files_mask, send_predicate_to_readers) = if let Some(hive_parts) = hive_parts
229            && let Some(hive_predicate) = &predicate.hive_predicate
230        {
231            if verbose {
232                eprintln!(
233                    "initialize_scan_predicate: Source filter mask initialization via hive partitions"
234                );
235            }
236
237            expected_mask_len = hive_parts.df().height();
238
239            let inclusion_mask = hive_predicate
240                .evaluate_io(hive_parts.df())?
241                .bool()?
242                .rechunk()
243                .into_owned()
244                .downcast_into_iter()
245                .next()
246                .unwrap()
247                .values()
248                .clone();
249
250            (
251                SkipFilesMask::Inclusion(inclusion_mask),
252                !predicate.hive_predicate_is_full_predicate,
253            )
254        } else if let Some(table_statsitics) = table_statsitics
255            && let Some(skip_batch_predicate) = &predicate.skip_batch_predicate
256        {
257            if verbose {
258                eprintln!(
259                    "initialize_scan_predicate: Source filter mask initialization via table statistics"
260                );
261            }
262
263            expected_mask_len = table_statsitics.0.height();
264
265            let exclusion_mask = skip_batch_predicate.evaluate_with_stat_df(&table_statsitics.0)?;
266
267            (SkipFilesMask::Exclusion(exclusion_mask), true)
268        } else {
269            break;
270        };
271
272        if skip_files_mask.len() != expected_mask_len {
273            let msg = format!(
274                "WARNING: \
275                initialize_scan_predicate: \
276                filter mask length mismatch (length: {}, expected: {}). Files \
277                will not be skipped. This is a bug; please open an issue with \
278                a reproducible example if possible.",
279                skip_files_mask.len(),
280                expected_mask_len
281            );
282            polars_warn!(msg);
283            return Ok((None, Some(predicate)));
284        }
285
286        if verbose {
287            eprintln!(
288                "initialize_scan_predicate: Predicate pushdown allows skipping {} / {} files",
289                skip_files_mask.num_skipped_files(),
290                skip_files_mask.len()
291            );
292        }
293
294        return Ok((
295            Some(skip_files_mask),
296            send_predicate_to_readers.then_some(predicate),
297        ));
298    }
299
300    Ok((None, predicate))
301}
302
303/// Filters the list of files in an `IR::Scan` based on the contained predicate. This is possible
304/// if the predicate has components that refer to only the hive parts and there is no e.g.
305/// row index / slice.
306///
307/// This also applies the projection onto the hive parts.
308///
309/// # Panics
310/// Panics if `scan_ir_node` is not `IR::Scan`.
311pub fn apply_scan_predicate_to_scan_ir(
312    scan_ir_node: Node,
313    ir_arena: &mut Arena<IR>,
314    expr_arena: &mut Arena<AExpr>,
315) -> PolarsResult<()> {
316    let scan_ir_schema = IR::schema(ir_arena.get(scan_ir_node), ir_arena).into_owned();
317    let scan_ir = ir_arena.get_mut(scan_ir_node);
318
319    let IR::Scan {
320        sources,
321        hive_parts,
322        predicate,
323        predicate_file_skip_applied,
324        unified_scan_args,
325        file_info,
326        ..
327    } = scan_ir
328    else {
329        unreachable!()
330    };
331
332    if let Some(hive_parts) = hive_parts.as_mut() {
333        *hive_parts = hive_parts.filter_columns(&scan_ir_schema);
334    }
335
336    if unified_scan_args.has_row_index_or_slice() || predicate_file_skip_applied.is_some() {
337        return Ok(());
338    }
339
340    let Some(predicate) = predicate else {
341        return Ok(());
342    };
343
344    match sources {
345        // Files cannot be `gather()`ed.
346        ScanSources::Files(_) => return Ok(()),
347        ScanSources::Paths(_) | ScanSources::Buffers(_) => {},
348    }
349
350    let verbose = config::verbose();
351
352    let scan_predicate = create_scan_predicate(
353        predicate,
354        expr_arena,
355        &scan_ir_schema,
356        hive_parts.as_ref().map(|hp| hp.df().schema().as_ref()),
357        &mut ExpressionConversionState::new(true),
358        true,  // create_skip_batch_predicate
359        false, // create_column_predicates
360    )?
361    .to_io(None, file_info.schema.clone());
362
363    let (skip_files_mask, predicate_to_readers) = initialize_scan_predicate(
364        Some(&scan_predicate),
365        hive_parts.as_ref(),
366        unified_scan_args.table_statistics.as_ref(),
367        verbose,
368    )?;
369
370    if let Some(skip_files_mask) = skip_files_mask {
371        assert_eq!(skip_files_mask.len(), sources.len());
372
373        if verbose {
374            let s = if sources.len() == 1 { "" } else { "s" };
375            eprintln!(
376                "apply_scan_predicate_to_scan_ir: remove {} / {} file{s}",
377                skip_files_mask.num_skipped_files(),
378                sources.len()
379            );
380        }
381
382        let is_fully_applied = predicate_to_readers.is_none();
383        *predicate_file_skip_applied = Some(is_fully_applied);
384
385        if skip_files_mask.num_skipped_files() > 0 {
386            filter_scan_ir(scan_ir, skip_files_mask.non_skipped_files_idx_iter())
387        }
388    }
389
390    Ok(())
391}
392
393/// Filters the paths for a scan IR. This also involves performing selections on
394/// e.g. hive partitions, deletion files.
395///
396/// Note: `selected_path_indices` should be cheaply cloneable.
397///
398/// # Panics
399/// Panics if `scan_ir` is not `IR::Scan`.
400pub fn filter_scan_ir<I>(scan_ir: &mut IR, selected_path_indices: I)
401where
402    I: Iterator<Item = usize> + Clone,
403{
404    let IR::Scan {
405        sources,
406        file_info:
407            FileInfo {
408                schema: _,
409                reader_schema,
410                row_estimation,
411            },
412        hive_parts,
413        predicate: _,
414        predicate_file_skip_applied: _,
415        output_schema: _,
416        scan_type,
417        unified_scan_args,
418    } = scan_ir
419    else {
420        panic!("{:?}", scan_ir);
421    };
422
423    let size_hint = selected_path_indices.size_hint();
424
425    if size_hint.0 == sources.len()
426        && size_hint.1 == Some(sources.len())
427        && selected_path_indices
428            .clone()
429            .enumerate()
430            .all(|(i, x)| i == x)
431    {
432        return;
433    }
434
435    let UnifiedScanArgs {
436        schema: _,
437        cloud_options: _,
438        hive_options: _,
439        rechunk: _,
440        cache: _,
441        glob: _,
442        hidden_file_prefix: _,
443        projection: _,
444        column_mapping: _,
445        default_values,
446        // Ensure these are None.
447        row_index: None,
448        pre_slice: None,
449        cast_columns_policy: _,
450        missing_columns_policy: _,
451        extra_columns_policy: _,
452        include_file_paths: _,
453        table_statistics,
454        deletion_files,
455        row_count,
456    } = unified_scan_args.as_mut()
457    else {
458        panic!("{unified_scan_args:?}")
459    };
460
461    *row_count = None;
462
463    if selected_path_indices.clone().next() != Some(0) {
464        *reader_schema = None;
465
466        // Ensure the metadata is unset, otherwise it may incorrectly be used at
467        // scan. This is especially important for Parquet as it requires the
468        // correct `is_nullable` in the arrow field.
469        match scan_type.as_mut() {
470            #[cfg(feature = "parquet")]
471            FileScanIR::Parquet {
472                options: _,
473                metadata,
474            } => *metadata = None,
475
476            #[cfg(feature = "ipc")]
477            FileScanIR::Ipc {
478                options: _,
479                metadata,
480            } => *metadata = None,
481
482            #[cfg(feature = "csv")]
483            FileScanIR::Csv { options: _ } => {},
484
485            #[cfg(feature = "json")]
486            FileScanIR::NDJson { options: _ } => {},
487
488            #[cfg(feature = "python")]
489            FileScanIR::PythonDataset {
490                dataset_object: _,
491                cached_ir,
492            } => *cached_ir.lock().unwrap() = None,
493
494            #[cfg(feature = "scan_lines")]
495            FileScanIR::Lines { name: _ } => {},
496
497            FileScanIR::Anonymous {
498                options: _,
499                function: _,
500            } => {},
501        }
502    }
503
504    let selected_path_indices_idxsize = LazyCell::new(|| {
505        selected_path_indices
506            .clone()
507            .map(|i| IdxSize::try_from(i).unwrap())
508            .collect::<Vec<_>>()
509    });
510
511    *deletion_files = deletion_files.as_ref().and_then(|x| match x {
512        DeletionFilesList::IcebergPositionDelete(deletions) => {
513            let mut out = None;
514
515            for (out_idx, source_idx) in selected_path_indices.clone().enumerate() {
516                if let Some(v) = deletions.get(&source_idx) {
517                    out.get_or_insert_with(|| {
518                        PlIndexMap::with_capacity(selected_path_indices.size_hint().0 - out_idx)
519                    })
520                    .insert(out_idx, v.clone());
521                }
522            }
523
524            out.map(|x| DeletionFilesList::IcebergPositionDelete(Arc::new(x)))
525        },
526    });
527
528    *table_statistics = table_statistics.as_ref().map(|x| {
529        let df_height = IdxSize::try_from(x.0.height()).unwrap();
530
531        assert!(selected_path_indices_idxsize.iter().all(|x| *x < df_height));
532
533        TableStatistics(Arc::new(unsafe {
534            x.0.take_slice_unchecked(&selected_path_indices_idxsize)
535        }))
536    });
537
538    let original_sources_len = sources.len();
539    *sources = sources.gather(selected_path_indices.clone()).unwrap();
540    *row_estimation = (
541        None,
542        row_estimation
543            .1
544            .div_ceil(original_sources_len)
545            .saturating_mul(sources.len()),
546    );
547
548    *hive_parts = hive_parts.as_ref().map(|hp| {
549        let df = hp.df();
550        let df_height = IdxSize::try_from(df.height()).unwrap();
551
552        assert!(selected_path_indices_idxsize.iter().all(|x| *x < df_height));
553
554        // Safety: Asserted all < df.height() above.
555        unsafe { df.take_slice_unchecked(&selected_path_indices_idxsize) }.into()
556    });
557
558    *default_values = default_values.as_ref().map(|x| match x {
559        DefaultFieldValues::Iceberg(v) => {
560            let mut out = PlIndexMap::with_capacity(v.len());
561            let mut gather_indices = PlHashMap::with_capacity(v.len());
562
563            for (k, v) in v.iter() {
564                out.insert(
565                    *k,
566                    v.as_ref().map_err(Clone::clone).map(|partition_values| {
567                        if !gather_indices.contains_key(&partition_values.len()) {
568                            gather_indices.insert(
569                                partition_values.len(),
570                                selected_path_indices
571                                    .clone()
572                                    .map(|i| {
573                                        (i < partition_values.len())
574                                            .then(|| IdxSize::try_from(i).unwrap())
575                                    })
576                                    .collect::<IdxCa>(),
577                            );
578                        }
579
580                        unsafe {
581                            partition_values.take_unchecked(
582                                gather_indices.get(&partition_values.len()).unwrap(),
583                            )
584                        }
585                    }),
586                );
587            }
588
589            DefaultFieldValues::Iceberg(Arc::new(IcebergIdentityTransformedPartitionFields(out)))
590        },
591    });
592}