polars_plan/dsl/file_scan/
mod.rs

1use std::hash::Hash;
2use std::sync::Mutex;
3
4use polars_core::utils::get_numeric_upcast_supertype_lossless;
5use polars_io::cloud::CloudOptions;
6#[cfg(feature = "csv")]
7use polars_io::csv::read::CsvReadOptions;
8#[cfg(feature = "ipc")]
9use polars_io::ipc::IpcScanOptions;
10#[cfg(feature = "parquet")]
11use polars_io::parquet::metadata::FileMetadataRef;
12#[cfg(feature = "parquet")]
13use polars_io::parquet::read::ParquetOptions;
14use polars_io::{HiveOptions, RowIndex};
15use polars_utils::slice_enum::Slice;
16#[cfg(feature = "serde")]
17use serde::{Deserialize, Serialize};
18use strum_macros::IntoStaticStr;
19
20use super::*;
21
22#[cfg(feature = "python")]
23pub mod python_dataset;
24#[cfg(feature = "python")]
25pub use python_dataset::{DATASET_PROVIDER_VTABLE, PythonDatasetProviderVTable};
26
bitflags::bitflags! {
    /// Capability flags advertised by a [`FileScan`] variant (see [`FileScan::flags`]).
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    pub struct ScanFlags : u32 {
        /// The scan implements its own specialized predicate filtering
        /// (only the Parquet scan sets this in this file).
        const SPECIALIZED_PREDICATE_FILTER = 0x01;
    }
}
33
/// Describes the file format (and format-specific options) of a scan source.
///
/// Available variants depend on the enabled crate features.
#[derive(Clone, Debug, IntoStaticStr)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
// TODO: Arc<> some of the options and the cloud options.
pub enum FileScan {
    /// Scan of CSV data.
    #[cfg(feature = "csv")]
    Csv { options: CsvReadOptions },

    /// Scan of newline-delimited JSON data.
    #[cfg(feature = "json")]
    NDJson { options: NDJsonReadOptions },

    /// Scan of Parquet data.
    #[cfg(feature = "parquet")]
    Parquet {
        options: ParquetOptions,
        // Cached file metadata; skipped by serde (deserializes to None).
        #[cfg_attr(feature = "serde", serde(skip))]
        metadata: Option<FileMetadataRef>,
    },

    /// Scan of Arrow IPC data.
    #[cfg(feature = "ipc")]
    Ipc {
        options: IpcScanOptions,
        // Cached file metadata; skipped by serde (deserializes to None).
        #[cfg_attr(feature = "serde", serde(skip))]
        metadata: Option<Arc<arrow::io::ipc::read::FileMetadata>>,
    },

    /// Scan backed by a Python dataset provider object.
    #[cfg(feature = "python")]
    PythonDataset {
        dataset_object: Arc<python_dataset::PythonDatasetProvider>,

        // Cached expanded IR; skipped by serde and re-created via Default.
        #[cfg_attr(feature = "serde", serde(skip, default))]
        cached_ir: Arc<Mutex<Option<ExpandedDataset>>>,
    },

    /// User-provided anonymous scan function; the whole variant is not serializable.
    #[cfg_attr(feature = "serde", serde(skip))]
    Anonymous {
        options: Arc<AnonymousScanOptions>,
        function: Arc<dyn AnonymousScan>,
    },
}
72
73impl FileScan {
74    pub fn flags(&self) -> ScanFlags {
75        match self {
76            #[cfg(feature = "csv")]
77            Self::Csv { .. } => ScanFlags::empty(),
78            #[cfg(feature = "ipc")]
79            Self::Ipc { .. } => ScanFlags::empty(),
80            #[cfg(feature = "parquet")]
81            Self::Parquet { .. } => ScanFlags::SPECIALIZED_PREDICATE_FILTER,
82            #[cfg(feature = "json")]
83            Self::NDJson { .. } => ScanFlags::empty(),
84            #[allow(unreachable_patterns)]
85            _ => ScanFlags::empty(),
86        }
87    }
88
89    pub(crate) fn sort_projection(&self, _has_row_index: bool) -> bool {
90        match self {
91            #[cfg(feature = "csv")]
92            Self::Csv { .. } => true,
93            #[cfg(feature = "ipc")]
94            Self::Ipc { .. } => _has_row_index,
95            #[cfg(feature = "parquet")]
96            Self::Parquet { .. } => false,
97            #[allow(unreachable_patterns)]
98            _ => false,
99        }
100    }
101
102    pub fn streamable(&self) -> bool {
103        match self {
104            #[cfg(feature = "csv")]
105            Self::Csv { .. } => true,
106            #[cfg(feature = "ipc")]
107            Self::Ipc { .. } => false,
108            #[cfg(feature = "parquet")]
109            Self::Parquet { .. } => true,
110            #[cfg(feature = "json")]
111            Self::NDJson { .. } => false,
112            #[allow(unreachable_patterns)]
113            _ => false,
114        }
115    }
116}
117
/// How to handle columns that exist in the target schema but are missing
/// from the incoming data.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum MissingColumnsPolicy {
    /// Raise an error on a missing column (default).
    #[default]
    Raise,
    /// Inserts full-NULL columns for the missing ones.
    Insert,
}
126
/// Used by scans. Controls which dtype mismatches between the target schema
/// and incoming data may be resolved by casting.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct CastColumnsPolicy {
    /// Allow casting when target dtype is lossless supertype
    pub integer_upcast: bool,

    /// Allow Float32 -> Float64
    pub float_upcast: bool,
    /// Allow Float64 -> Float32
    pub float_downcast: bool,

    /// Allow datetime[ns] to be casted to any lower precision. Important for
    /// being able to read datasets written by spark.
    pub datetime_nanoseconds_downcast: bool,

    /// Allow casting between datetime timezones (see `should_cast_column`).
    pub datetime_convert_timezone: bool,

    /// Policy for struct fields present in the target dtype but missing from
    /// the incoming dtype.
    pub missing_struct_fields: MissingColumnsPolicy,
    /// Policy for struct fields present in the incoming dtype but absent from
    /// the target dtype.
    pub extra_struct_fields: ExtraColumnsPolicy,
}
149
150impl CastColumnsPolicy {
151    /// Configuration variant that defaults to raising on mismatch.
152    pub const ERROR_ON_MISMATCH: Self = Self {
153        integer_upcast: false,
154        float_upcast: false,
155        float_downcast: false,
156        datetime_nanoseconds_downcast: false,
157        datetime_convert_timezone: false,
158        missing_struct_fields: MissingColumnsPolicy::Raise,
159        extra_struct_fields: ExtraColumnsPolicy::Raise,
160    };
161}
162
163impl Default for CastColumnsPolicy {
164    fn default() -> Self {
165        Self::ERROR_ON_MISMATCH
166    }
167}
168
/// How to handle columns found in the incoming data but not present in the
/// target schema.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum ExtraColumnsPolicy {
    /// Error if there are extra columns outside the target schema.
    #[default]
    Raise,
    /// Accept extra columns (they are handled by casting — see `should_cast_column`).
    Ignore,
}
177
/// Scan arguments shared across different scan types.
#[derive(Debug, Clone, PartialEq, Eq, Default, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct UnifiedScanArgs {
    /// User-provided schema of the file. Will be inferred during IR conversion
    /// if None.
    pub schema: Option<SchemaRef>,
    /// Credentials / configuration for reading from cloud storage; None for local files.
    pub cloud_options: Option<CloudOptions>,
    /// Options controlling hive-partition handling.
    pub hive_options: HiveOptions,

    // NOTE(review): flags below follow the usual scan semantics (rechunk output,
    // cache result, expand glob patterns) — confirm against consumers.
    pub rechunk: bool,
    pub cache: bool,
    pub glob: bool,

    /// Optional subset of columns to read.
    pub projection: Option<Arc<[PlSmallStr]>>,
    /// Optionally prepend a row-index column.
    pub row_index: Option<RowIndex>,
    /// Slice applied before predicates
    pub pre_slice: Option<Slice>,

    /// Which dtype casts are permitted when incoming data mismatches `schema`.
    pub cast_columns_policy: CastColumnsPolicy,
    /// Behavior when a target column is missing from the data.
    pub missing_columns_policy: MissingColumnsPolicy,
    /// If set, include a column with this name containing the source file path.
    pub include_file_paths: Option<PlSmallStr>,
}
201
/// Manual impls of Eq/Hash, as some fields are `Arc<T>` where T does not have Eq/Hash. For these
/// fields we compare the pointer addresses instead.
mod _file_scan_eq_hash {
    use std::hash::{Hash, Hasher};
    use std::sync::Arc;

    use super::FileScan;

    impl PartialEq for FileScan {
        fn eq(&self, other: &Self) -> bool {
            // Convert both sides to the borrowed wrapper and use its derived PartialEq.
            FileScanEqHashWrap::from(self) == FileScanEqHashWrap::from(other)
        }
    }

    impl Eq for FileScan {}

    impl Hash for FileScan {
        fn hash<H: Hasher>(&self, state: &mut H) {
            // Hash via the wrapper so Eq and Hash stay consistent.
            FileScanEqHashWrap::from(self).hash(state)
        }
    }

    /// Borrowed mirror of [`FileScan`] on which `PartialEq`/`Hash` can be derived:
    /// non-comparable `Arc` fields are replaced by their pointer address.
    ///
    /// # Hash / Eq safety
    /// * All usizes originate from `Arc<>`s, and the lifetime of this enum is bound to that of the
    ///   input ref.
    #[derive(PartialEq, Hash)]
    pub enum FileScanEqHashWrap<'a> {
        #[cfg(feature = "csv")]
        Csv {
            options: &'a polars_io::csv::read::CsvReadOptions,
        },

        #[cfg(feature = "json")]
        NDJson {
            options: &'a crate::prelude::NDJsonReadOptions,
        },

        #[cfg(feature = "parquet")]
        Parquet {
            options: &'a polars_io::prelude::ParquetOptions,
            // Pointer identity of the metadata Arc (if any).
            metadata: Option<usize>,
        },

        #[cfg(feature = "ipc")]
        Ipc {
            options: &'a polars_io::prelude::IpcScanOptions,
            // Pointer identity of the metadata Arc (if any).
            metadata: Option<usize>,
        },

        #[cfg(feature = "python")]
        PythonDataset {
            // Pointer identities of the provider object and cached IR.
            dataset_object: usize,
            cached_ir: usize,
        },

        Anonymous {
            options: &'a crate::dsl::AnonymousScanOptions,
            // Pointer identity of the scan function.
            function: usize,
        },

        /// Variant to ensure the lifetime is used regardless of feature gate combination.
        #[expect(unused)]
        Phantom(&'a ()),
    }

    impl<'a> From<&'a FileScan> for FileScanEqHashWrap<'a> {
        fn from(value: &'a FileScan) -> Self {
            match value {
                #[cfg(feature = "csv")]
                FileScan::Csv { options } => FileScanEqHashWrap::Csv { options },

                #[cfg(feature = "json")]
                FileScan::NDJson { options } => FileScanEqHashWrap::NDJson { options },

                #[cfg(feature = "parquet")]
                FileScan::Parquet { options, metadata } => FileScanEqHashWrap::Parquet {
                    options,
                    metadata: metadata.as_ref().map(arc_as_ptr),
                },

                #[cfg(feature = "ipc")]
                FileScan::Ipc { options, metadata } => FileScanEqHashWrap::Ipc {
                    options,
                    metadata: metadata.as_ref().map(arc_as_ptr),
                },

                #[cfg(feature = "python")]
                FileScan::PythonDataset {
                    dataset_object,
                    cached_ir,
                } => FileScanEqHashWrap::PythonDataset {
                    dataset_object: arc_as_ptr(dataset_object),
                    cached_ir: arc_as_ptr(cached_ir),
                },

                FileScan::Anonymous { options, function } => FileScanEqHashWrap::Anonymous {
                    options,
                    function: arc_as_ptr(function),
                },
            }
        }
    }

    /// Address of the Arc's payload, used as an identity for Eq/Hash.
    fn arc_as_ptr<T: ?Sized>(arc: &Arc<T>) -> usize {
        Arc::as_ptr(arc) as *const () as usize
    }
}
309
impl CastColumnsPolicy {
    /// Decides whether `incoming_dtype` may be cast to `target_dtype` under this
    /// policy, recursing through nested types (struct / list / array).
    ///
    /// # Returns
    /// * Ok(true): Cast needed to target dtype
    /// * Ok(false): No casting needed
    /// * Err(_): Forbidden by configuration, or incompatible types.
    pub fn should_cast_column(
        &self,
        column_name: &str,
        target_dtype: &DataType,
        incoming_dtype: &DataType,
    ) -> PolarsResult<bool> {
        // Builds the standard SchemaMismatch error, optionally appending a hint.
        let mismatch_err = |hint: &str| {
            let hint_spacing = if hint.is_empty() { "" } else { ", " };

            polars_bail!(
                SchemaMismatch:
                "data type mismatch for column {}: incoming: {:?} != target: {:?}{}{}",
                column_name,
                incoming_dtype,
                target_dtype,
                hint_spacing,
                hint,
            )
        };

        // We intercept the nested types first to prevent an expensive recursive eq - recursion
        // is instead done manually through this function.

        #[cfg(feature = "dtype-struct")]
        if let DataType::Struct(target_fields) = target_dtype {
            let DataType::Struct(incoming_fields) = incoming_dtype else {
                return mismatch_err("");
            };

            // Map: incoming field name -> (position, dtype).
            let incoming_fields_schema = PlHashMap::from_iter(
                incoming_fields
                    .iter()
                    .enumerate()
                    .map(|(i, fld)| (fld.name.as_str(), (i, &fld.dtype))),
            );

            // Differing field counts always imply a cast (missing or extra fields).
            let mut should_cast = incoming_fields.len() != target_fields.len();

            for (target_idx, target_field) in target_fields.iter().enumerate() {
                let Some((incoming_idx, incoming_field_dtype)) =
                    incoming_fields_schema.get(target_field.name().as_str())
                else {
                    // Target field not present in incoming struct.
                    match self.missing_struct_fields {
                        MissingColumnsPolicy::Raise => {
                            return mismatch_err(&format!(
                                "encountered missing struct field: {}, \
                                hint: pass cast_options=pl.ScanCastOptions(missing_struct_fields='insert')",
                                target_field.name(),
                            ));
                        },
                        MissingColumnsPolicy::Insert => {
                            should_cast = true;
                            // Must keep checking the rest of the fields.
                            continue;
                        },
                    };
                };

                // # Note
                // We also need to cast if the struct fields are out of order. Currently there is
                // no API parameter to control this - we always do this by default.
                should_cast |= *incoming_idx != target_idx;

                // Recurse into the field's dtype.
                should_cast |= self.should_cast_column(
                    column_name,
                    &target_field.dtype,
                    incoming_field_dtype,
                )?;
            }

            // Casting is also needed if there are extra fields, check them here.

            // Take and re-use hashmap
            let mut target_fields_schema = incoming_fields_schema;
            target_fields_schema.clear();

            target_fields_schema.extend(
                target_fields
                    .iter()
                    .enumerate()
                    .map(|(i, fld)| (fld.name.as_str(), (i, &fld.dtype))),
            );

            // Any incoming field not in the target schema is "extra".
            for fld in incoming_fields {
                if !target_fields_schema.contains_key(fld.name.as_str()) {
                    match self.extra_struct_fields {
                        ExtraColumnsPolicy::Ignore => {
                            // One extra field is enough to require a cast; stop scanning.
                            should_cast = true;
                            break;
                        },
                        ExtraColumnsPolicy::Raise => {
                            return mismatch_err(&format!(
                                "encountered extra struct field: {}, \
                                hint: pass cast_options=pl.ScanCastOptions(extra_struct_fields='ignore')",
                                &fld.name,
                            ));
                        },
                    }
                }
            }

            return Ok(should_cast);
        }

        // Lists: recurse into the inner dtype.
        if let DataType::List(target_inner) = target_dtype {
            let DataType::List(incoming_inner) = incoming_dtype else {
                return mismatch_err("");
            };

            return self.should_cast_column(column_name, target_inner, incoming_inner);
        }

        // Fixed-size arrays: widths must match exactly, then recurse into inner dtype.
        #[cfg(feature = "dtype-array")]
        if let DataType::Array(target_inner, target_width) = target_dtype {
            let DataType::Array(incoming_inner, incoming_width) = incoming_dtype else {
                return mismatch_err("");
            };

            if incoming_width != target_width {
                return mismatch_err("");
            }

            return self.should_cast_column(column_name, target_inner, incoming_inner);
        }

        // Eq here should be cheap as we have intercepted all nested types above.

        debug_assert!(!target_dtype.is_nested());

        if target_dtype == incoming_dtype {
            return Ok(false);
        }

        //
        // After this point the dtypes are mismatching.
        //

        // Integer -> integer: only lossless upcasts, and only if enabled.
        if target_dtype.is_integer() && incoming_dtype.is_integer() {
            if !self.integer_upcast {
                return mismatch_err(
                    "hint: pass cast_options=pl.ScanCastOptions(integer_cast='upcast')",
                );
            }

            return match get_numeric_upcast_supertype_lossless(incoming_dtype, target_dtype) {
                Some(ref v) if v == target_dtype => Ok(true),
                _ => mismatch_err("incoming dtype cannot safely cast to target dtype"),
            };
        }

        // Float <-> float: up- and downcasts gated by separate flags.
        if target_dtype.is_float() && incoming_dtype.is_float() {
            return match (target_dtype, incoming_dtype) {
                (DataType::Float64, DataType::Float32) => {
                    if self.float_upcast {
                        Ok(true)
                    } else {
                        mismatch_err(
                            "hint: pass cast_options=pl.ScanCastOptions(float_cast='upcast')",
                        )
                    }
                },

                (DataType::Float32, DataType::Float64) => {
                    if self.float_downcast {
                        Ok(true)
                    } else {
                        mismatch_err(
                            "hint: pass cast_options=pl.ScanCastOptions(float_cast='downcast')",
                        )
                    }
                },

                // Both floats but unequal implies exactly one of the pairs above.
                _ => unreachable!(),
            };
        }

        // Datetime <-> datetime: timezone and time-unit checked separately.
        if let (
            DataType::Datetime(target_unit, target_zone),
            DataType::Datetime(incoming_unit, incoming_zone),
        ) = (target_dtype, incoming_dtype)
        {
            // Check timezone (None and UTC are treated as equivalent).
            if !self.datetime_convert_timezone
                && !TimeZone::eq_none_as_utc(incoming_zone.as_ref(), target_zone.as_ref())
            {
                return mismatch_err(
                    "hint: pass cast_options=pl.ScanCastOptions(datetime_cast='convert-timezone')",
                );
            }

            // Check unit; only nanosecond -> lower precision is supported (gated).
            if target_unit != incoming_unit {
                return if let TimeUnit::Nanoseconds = incoming_unit {
                    if self.datetime_nanoseconds_downcast {
                        Ok(true)
                    } else {
                        mismatch_err(
                            "hint: pass cast_options=pl.ScanCastOptions(datetime_cast='nanosecond-downcast')",
                        )
                    }
                } else {
                    // Currently don't have parameter for controlling arbitrary time unit casting.
                    mismatch_err("")
                };
            }

            // Dtype differs and we are allowed to coerce
            return Ok(true);
        }

        // Any other combination of mismatching dtypes is an error.
        mismatch_err("")
    }
}
527}