use std::hash::Hash;
use std::sync::Mutex;

use polars_core::utils::get_numeric_upcast_supertype_lossless;
use polars_io::cloud::CloudOptions;
#[cfg(feature = "csv")]
use polars_io::csv::read::CsvReadOptions;
#[cfg(feature = "ipc")]
use polars_io::ipc::IpcScanOptions;
#[cfg(feature = "parquet")]
use polars_io::parquet::metadata::FileMetadataRef;
#[cfg(feature = "parquet")]
use polars_io::parquet::read::ParquetOptions;
use polars_io::{HiveOptions, RowIndex};
use polars_utils::slice_enum::Slice;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use strum_macros::IntoStaticStr;

use super::*;

#[cfg(feature = "python")]
pub mod python_dataset;
#[cfg(feature = "python")]
pub use python_dataset::{DATASET_PROVIDER_VTABLE, PythonDatasetProviderVTable};

bitflags::bitflags! {
    /// Capability flags describing what a [`FileScan`] source can do.
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    pub struct ScanFlags : u32 {
        /// The reader can apply the predicate filter itself.
        const SPECIALIZED_PREDICATE_FILTER = 0x01;
    }
}

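/// Describes the concrete reader (and its options) behind a scan node.
///
/// A minimal sketch of constructing one variant (assumes the `parquet` feature
/// and that `ParquetOptions` provides a `Default`; illustrative only):
///
/// ```ignore
/// let scan = FileScan::Parquet {
///     options: ParquetOptions::default(),
///     // Metadata is optional; attaching it avoids re-reading the footer.
///     metadata: None,
/// };
/// ```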
#[derive(Clone, Debug, IntoStaticStr)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum FileScan {
    #[cfg(feature = "csv")]
    Csv { options: CsvReadOptions },

    #[cfg(feature = "json")]
    NDJson { options: NDJsonReadOptions },

    #[cfg(feature = "parquet")]
    Parquet {
        options: ParquetOptions,
        #[cfg_attr(feature = "serde", serde(skip))]
        metadata: Option<FileMetadataRef>,
    },

    #[cfg(feature = "ipc")]
    Ipc {
        options: IpcScanOptions,
        #[cfg_attr(feature = "serde", serde(skip))]
        metadata: Option<Arc<arrow::io::ipc::read::FileMetadata>>,
    },

    #[cfg(feature = "python")]
    PythonDataset {
        dataset_object: Arc<python_dataset::PythonDatasetProvider>,

        #[cfg_attr(feature = "serde", serde(skip, default))]
        cached_ir: Arc<Mutex<Option<ExpandedDataset>>>,
    },

    #[cfg_attr(feature = "serde", serde(skip))]
    Anonymous {
        options: Arc<AnonymousScanOptions>,
        function: Arc<dyn AnonymousScan>,
    },
}

impl FileScan {
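    /// Capability flags of this scan source.
    ///
    /// A sketch of how a caller might branch on them (hypothetical caller code):
    ///
    /// ```ignore
    /// if scan.flags().contains(ScanFlags::SPECIALIZED_PREDICATE_FILTER) {
    ///     // The reader handles the predicate itself; no separate filter node needed.
    /// }
    /// ```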
    pub fn flags(&self) -> ScanFlags {
        match self {
            #[cfg(feature = "csv")]
            Self::Csv { .. } => ScanFlags::empty(),
            #[cfg(feature = "ipc")]
            Self::Ipc { .. } => ScanFlags::empty(),
            #[cfg(feature = "parquet")]
            Self::Parquet { .. } => ScanFlags::SPECIALIZED_PREDICATE_FILTER,
            #[cfg(feature = "json")]
            Self::NDJson { .. } => ScanFlags::empty(),
            #[allow(unreachable_patterns)]
            _ => ScanFlags::empty(),
        }
    }

    /// Whether the column projection must be kept in sorted order for this reader.
    pub(crate) fn sort_projection(&self, _has_row_index: bool) -> bool {
        match self {
            #[cfg(feature = "csv")]
            Self::Csv { .. } => true,
            #[cfg(feature = "ipc")]
            Self::Ipc { .. } => _has_row_index,
            #[cfg(feature = "parquet")]
            Self::Parquet { .. } => false,
            #[allow(unreachable_patterns)]
            _ => false,
        }
    }

    /// Whether this scan is supported by the streaming engine.
    pub fn streamable(&self) -> bool {
        match self {
            #[cfg(feature = "csv")]
            Self::Csv { .. } => true,
            #[cfg(feature = "ipc")]
            Self::Ipc { .. } => false,
            #[cfg(feature = "parquet")]
            Self::Parquet { .. } => true,
            #[cfg(feature = "json")]
            Self::NDJson { .. } => false,
            #[allow(unreachable_patterns)]
            _ => false,
        }
    }
}

/// What to do when a column of the target schema is missing from a file.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum MissingColumnsPolicy {
    /// Raise a schema mismatch error.
    #[default]
    Raise,
    /// Insert the missing column, filled with nulls.
    Insert,
}

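/// Per-dtype rules for resolving differences between a file's schema and the
/// target schema of the scan.
///
/// A sketch of building a more permissive policy on top of the strict default
/// (illustrative only; uses struct-update syntax over `ERROR_ON_MISMATCH`):
///
/// ```ignore
/// let policy = CastColumnsPolicy {
///     integer_upcast: true,
///     float_upcast: true,
///     ..CastColumnsPolicy::ERROR_ON_MISMATCH
/// };
/// ```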
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct CastColumnsPolicy {
    /// Allow lossless upcasting of integers to the target integer type.
    pub integer_upcast: bool,

    /// Allow Float32 -> Float64.
    pub float_upcast: bool,
    /// Allow Float64 -> Float32.
    pub float_downcast: bool,

    /// Allow downcasting from nanosecond precision to the target time unit.
    pub datetime_nanoseconds_downcast: bool,

    /// Allow converting between datetime time zones.
    pub datetime_convert_timezone: bool,

    pub missing_struct_fields: MissingColumnsPolicy,
    pub extra_struct_fields: ExtraColumnsPolicy,
}

impl CastColumnsPolicy {
    /// Strict policy: any dtype mismatch is an error.
    pub const ERROR_ON_MISMATCH: Self = Self {
        integer_upcast: false,
        float_upcast: false,
        float_downcast: false,
        datetime_nanoseconds_downcast: false,
        datetime_convert_timezone: false,
        missing_struct_fields: MissingColumnsPolicy::Raise,
        extra_struct_fields: ExtraColumnsPolicy::Raise,
    };
}

impl Default for CastColumnsPolicy {
    fn default() -> Self {
        Self::ERROR_ON_MISMATCH
    }
}

/// What to do when a file contains columns that are not in the target schema.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum ExtraColumnsPolicy {
    /// Raise a schema mismatch error.
    #[default]
    Raise,
    /// Ignore the extra columns.
    Ignore,
}

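/// Scan arguments shared by all file types: cloud and hive configuration,
/// projection, row index, slice, and the schema-mismatch policies defined above.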
#[derive(Debug, Clone, PartialEq, Eq, Default, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct UnifiedScanArgs {
    pub schema: Option<SchemaRef>,
    pub cloud_options: Option<CloudOptions>,
    pub hive_options: HiveOptions,

    pub rechunk: bool,
    pub cache: bool,
    pub glob: bool,

    pub projection: Option<Arc<[PlSmallStr]>>,
    pub row_index: Option<RowIndex>,
    pub pre_slice: Option<Slice>,

    pub cast_columns_policy: CastColumnsPolicy,
    pub missing_columns_policy: MissingColumnsPolicy,
    pub include_file_paths: Option<PlSmallStr>,
}

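/// Manual `PartialEq` / `Eq` / `Hash` for [`FileScan`].
///
/// `Arc`-backed fields (metadata, dataset providers, anonymous scan functions)
/// cannot be compared by value, so they are reduced to pointer identity via a
/// borrowing wrapper enum.
///
/// A sketch of the effect (hypothetical usage, not part of this module):
///
/// ```ignore
/// use std::collections::HashSet;
///
/// let mut seen: HashSet<FileScan> = HashSet::new();
/// seen.insert(scan.clone());
/// // A clone shares the same Arc allocations, so it hashes/compares equal.
/// assert!(seen.contains(&scan));
/// ```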
mod _file_scan_eq_hash {
    use std::hash::{Hash, Hasher};
    use std::sync::Arc;

    use super::FileScan;

    impl PartialEq for FileScan {
        fn eq(&self, other: &Self) -> bool {
            FileScanEqHashWrap::from(self) == FileScanEqHashWrap::from(other)
        }
    }

    impl Eq for FileScan {}

    impl Hash for FileScan {
        fn hash<H: Hasher>(&self, state: &mut H) {
            FileScanEqHashWrap::from(self).hash(state)
        }
    }

    /// Wrapper that derives `PartialEq` / `Hash` by borrowing the comparable
    /// parts of [`FileScan`] and reducing non-comparable `Arc` fields to their
    /// pointer addresses.
    #[derive(PartialEq, Hash)]
    pub enum FileScanEqHashWrap<'a> {
        #[cfg(feature = "csv")]
        Csv {
            options: &'a polars_io::csv::read::CsvReadOptions,
        },

        #[cfg(feature = "json")]
        NDJson {
            options: &'a crate::prelude::NDJsonReadOptions,
        },

        #[cfg(feature = "parquet")]
        Parquet {
            options: &'a polars_io::prelude::ParquetOptions,
            metadata: Option<usize>,
        },

        #[cfg(feature = "ipc")]
        Ipc {
            options: &'a polars_io::prelude::IpcScanOptions,
            metadata: Option<usize>,
        },

        #[cfg(feature = "python")]
        PythonDataset {
            dataset_object: usize,
            cached_ir: usize,
        },

        Anonymous {
            options: &'a crate::dsl::AnonymousScanOptions,
            function: usize,
        },

        #[expect(unused)]
        Phantom(&'a ()),
    }

    impl<'a> From<&'a FileScan> for FileScanEqHashWrap<'a> {
        fn from(value: &'a FileScan) -> Self {
            match value {
                #[cfg(feature = "csv")]
                FileScan::Csv { options } => FileScanEqHashWrap::Csv { options },

                #[cfg(feature = "json")]
                FileScan::NDJson { options } => FileScanEqHashWrap::NDJson { options },

                #[cfg(feature = "parquet")]
                FileScan::Parquet { options, metadata } => FileScanEqHashWrap::Parquet {
                    options,
                    metadata: metadata.as_ref().map(arc_as_ptr),
                },

                #[cfg(feature = "ipc")]
                FileScan::Ipc { options, metadata } => FileScanEqHashWrap::Ipc {
                    options,
                    metadata: metadata.as_ref().map(arc_as_ptr),
                },

                #[cfg(feature = "python")]
                FileScan::PythonDataset {
                    dataset_object,
                    cached_ir,
                } => FileScanEqHashWrap::PythonDataset {
                    dataset_object: arc_as_ptr(dataset_object),
                    cached_ir: arc_as_ptr(cached_ir),
                },

                FileScan::Anonymous { options, function } => FileScanEqHashWrap::Anonymous {
                    options,
                    function: arc_as_ptr(function),
                },
            }
        }
    }

    fn arc_as_ptr<T: ?Sized>(arc: &Arc<T>) -> usize {
        Arc::as_ptr(arc) as *const () as usize
    }
}

impl CastColumnsPolicy {
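    /// Decide whether a column read as `incoming_dtype` must be cast to
    /// `target_dtype`.
    ///
    /// Returns `Ok(false)` if the dtypes already match, `Ok(true)` if a cast is
    /// needed and permitted by this policy, and a `SchemaMismatch` error
    /// otherwise.
    ///
    /// A minimal sketch of how a caller might use it (hypothetical caller code):
    ///
    /// ```ignore
    /// let policy = CastColumnsPolicy::default(); // ERROR_ON_MISMATCH
    /// match policy.should_cast_column("a", &target_dtype, &incoming_dtype)? {
    ///     true => { /* insert a cast for column "a" */ },
    ///     false => { /* dtypes line up, read as-is */ },
    /// }
    /// ```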
    pub fn should_cast_column(
        &self,
        column_name: &str,
        target_dtype: &DataType,
        incoming_dtype: &DataType,
    ) -> PolarsResult<bool> {
        let mismatch_err = |hint: &str| {
            let hint_spacing = if hint.is_empty() { "" } else { ", " };

            polars_bail!(
                SchemaMismatch:
                "data type mismatch for column {}: incoming: {:?} != target: {:?}{}{}",
                column_name,
                incoming_dtype,
                target_dtype,
                hint_spacing,
                hint,
            )
        };

        // Nested dtypes are compared recursively; structs additionally check
        // their fields against the missing/extra field policies.
        #[cfg(feature = "dtype-struct")]
        if let DataType::Struct(target_fields) = target_dtype {
            let DataType::Struct(incoming_fields) = incoming_dtype else {
                return mismatch_err("");
            };

            let incoming_fields_schema = PlHashMap::from_iter(
                incoming_fields
                    .iter()
                    .enumerate()
                    .map(|(i, fld)| (fld.name.as_str(), (i, &fld.dtype))),
            );

            let mut should_cast = incoming_fields.len() != target_fields.len();

            for (target_idx, target_field) in target_fields.iter().enumerate() {
                let Some((incoming_idx, incoming_field_dtype)) =
                    incoming_fields_schema.get(target_field.name().as_str())
                else {
                    match self.missing_struct_fields {
                        MissingColumnsPolicy::Raise => {
                            return mismatch_err(&format!(
                                "encountered missing struct field: {}, \
                                hint: pass cast_options=pl.ScanCastOptions(missing_struct_fields='insert')",
                                target_field.name(),
                            ));
                        },
                        MissingColumnsPolicy::Insert => {
                            should_cast = true;
                            continue;
                        },
                    };
                };

                // A differing field position also requires a cast.
                should_cast |= *incoming_idx != target_idx;

                should_cast |= self.should_cast_column(
                    column_name,
                    &target_field.dtype,
                    incoming_field_dtype,
                )?;
            }

            // Reuse the map allocation to index the target fields and detect
            // extra fields on the incoming side.
            let mut target_fields_schema = incoming_fields_schema;
            target_fields_schema.clear();

            target_fields_schema.extend(
                target_fields
                    .iter()
                    .enumerate()
                    .map(|(i, fld)| (fld.name.as_str(), (i, &fld.dtype))),
            );

            for fld in incoming_fields {
                if !target_fields_schema.contains_key(fld.name.as_str()) {
                    match self.extra_struct_fields {
                        ExtraColumnsPolicy::Ignore => {
                            should_cast = true;
                            break;
                        },
                        ExtraColumnsPolicy::Raise => {
                            return mismatch_err(&format!(
                                "encountered extra struct field: {}, \
                                hint: pass cast_options=pl.ScanCastOptions(extra_struct_fields='ignore')",
                                &fld.name,
                            ));
                        },
                    }
                }
            }

            return Ok(should_cast);
        }

        if let DataType::List(target_inner) = target_dtype {
            let DataType::List(incoming_inner) = incoming_dtype else {
                return mismatch_err("");
            };

            return self.should_cast_column(column_name, target_inner, incoming_inner);
        }

        #[cfg(feature = "dtype-array")]
        if let DataType::Array(target_inner, target_width) = target_dtype {
            let DataType::Array(incoming_inner, incoming_width) = incoming_dtype else {
                return mismatch_err("");
            };

            if incoming_width != target_width {
                return mismatch_err("");
            }

            return self.should_cast_column(column_name, target_inner, incoming_inner);
        }

        // Only non-nested dtypes should reach this point.
        debug_assert!(!target_dtype.is_nested());

        if target_dtype == incoming_dtype {
            return Ok(false);
        }

        if target_dtype.is_integer() && incoming_dtype.is_integer() {
            if !self.integer_upcast {
                return mismatch_err(
                    "hint: pass cast_options=pl.ScanCastOptions(integer_cast='upcast')",
                );
            }

            return match get_numeric_upcast_supertype_lossless(incoming_dtype, target_dtype) {
                Some(ref v) if v == target_dtype => Ok(true),
                _ => mismatch_err("incoming dtype cannot safely cast to target dtype"),
            };
        }

        if target_dtype.is_float() && incoming_dtype.is_float() {
            return match (target_dtype, incoming_dtype) {
                (DataType::Float64, DataType::Float32) => {
                    if self.float_upcast {
                        Ok(true)
                    } else {
                        mismatch_err(
                            "hint: pass cast_options=pl.ScanCastOptions(float_cast='upcast')",
                        )
                    }
                },

                (DataType::Float32, DataType::Float64) => {
                    if self.float_downcast {
                        Ok(true)
                    } else {
                        mismatch_err(
                            "hint: pass cast_options=pl.ScanCastOptions(float_cast='downcast')",
                        )
                    }
                },

                // Equal float dtypes were already handled by the equality check above.
                _ => unreachable!(),
            };
        }

        if let (
            DataType::Datetime(target_unit, target_zone),
            DataType::Datetime(incoming_unit, incoming_zone),
        ) = (target_dtype, incoming_dtype)
        {
            if !self.datetime_convert_timezone
                && !TimeZone::eq_none_as_utc(incoming_zone.as_ref(), target_zone.as_ref())
            {
                return mismatch_err(
                    "hint: pass cast_options=pl.ScanCastOptions(datetime_cast='convert-timezone')",
                );
            }

            if target_unit != incoming_unit {
                return if let TimeUnit::Nanoseconds = incoming_unit {
                    if self.datetime_nanoseconds_downcast {
                        Ok(true)
                    } else {
                        mismatch_err(
                            "hint: pass cast_options=pl.ScanCastOptions(datetime_cast='nanosecond-downcast')",
                        )
                    }
                } else {
                    mismatch_err("")
                };
            }

            return Ok(true);
        }

        mismatch_err("")
    }
}