polars_plan/plans/
schema.rs

1use std::ops::Deref;
2use std::sync::Mutex;
3
4use arrow::datatypes::ArrowSchemaRef;
5use either::Either;
6use polars_core::prelude::*;
7use polars_utils::format_pl_smallstr;
8#[cfg(feature = "serde")]
9use serde::{Deserialize, Serialize};
10
11use crate::prelude::*;
12
13impl DslPlan {
14    // Warning! This should not be used on the DSL internally.
15    // All schema resolving should be done during conversion to [`IR`].
16
17    /// Compute the schema. This requires conversion to [`IR`] and type-resolving.
18    pub fn compute_schema(&self) -> PolarsResult<SchemaRef> {
19        let mut lp_arena = Default::default();
20        let mut expr_arena = Default::default();
21        let node = to_alp(
22            self.clone(),
23            &mut expr_arena,
24            &mut lp_arena,
25            &mut OptFlags::schema_only(),
26        )?;
27
28        Ok(lp_arena.get(node).schema(&lp_arena).into_owned())
29    }
30}
31
/// Schema and row-count information attached to a file scan.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct FileInfo {
    /// Schema of the physical file.
    ///
    /// Notes:
    /// - Does not include logical columns like `include_file_path` and row index.
    /// - Always includes all hive columns.
    pub schema: SchemaRef,
    /// Stores the schema used for the reader, as the main schema can contain
    /// extra hive columns.
    ///
    /// Either an arrow schema or a polars schema; `None` when no reader schema is stored.
    pub reader_schema: Option<Either<ArrowSchemaRef, SchemaRef>>,
    /// Row count information as a `(known, estimated)` pair:
    /// - `.0`: the known size, if any.
    /// - `.1`: the estimated size (set to `usize::MAX` if unknown).
    pub row_estimation: (Option<usize>, usize),
}
48
49// Manual default because `row_estimation.1` needs to be `usize::MAX`.
50impl Default for FileInfo {
51    fn default() -> Self {
52        FileInfo {
53            schema: Default::default(),
54            reader_schema: None,
55            row_estimation: (None, usize::MAX),
56        }
57    }
58}
59
60impl FileInfo {
61    /// Constructs a new [`FileInfo`].
62    pub fn new(
63        schema: SchemaRef,
64        reader_schema: Option<Either<ArrowSchemaRef, SchemaRef>>,
65        row_estimation: (Option<usize>, usize),
66    ) -> Self {
67        Self {
68            schema: schema.clone(),
69            reader_schema,
70            row_estimation,
71        }
72    }
73
74    /// Merge the [`Schema`] of a [`HivePartitions`] with the schema of this [`FileInfo`].
75    pub fn update_schema_with_hive_schema(&mut self, hive_schema: SchemaRef) {
76        let schema = Arc::make_mut(&mut self.schema);
77
78        for field in hive_schema.iter_fields() {
79            if let Some(existing) = schema.get_mut(&field.name) {
80                *existing = field.dtype().clone();
81            } else {
82                schema
83                    .insert_at_index(schema.len(), field.name, field.dtype.clone())
84                    .unwrap();
85            }
86        }
87    }
88}
89
90#[cfg(feature = "streaming")]
/// Adjusts a `(known, estimated)` row count pair for the number of filters applied.
///
/// Without filters both values are passed through unchanged. With at least one filter
/// the known size is dropped and the estimate is scaled by `0.9` per filter.
fn estimate_sizes(
    known_size: Option<usize>,
    estimated_size: usize,
    filter_count: usize,
) -> (Option<usize>, usize) {
    if filter_count == 0 {
        return (known_size, estimated_size);
    }

    // Every filter is assumed to keep 90% of the rows.
    let selectivity = 0.9f32.powf(filter_count as f32);
    let scaled = (estimated_size as f32 * selectivity) as usize;
    (None, scaled)
}
105
/// Walks the plan rooted at `root` and derives row count information for it.
///
/// Returns `(known_size, estimated_size, filter_count)` where `filter_count` is the number
/// of filters accumulated on the path that have not yet been applied to the estimate
/// (see `estimate_sizes`).
///
/// Side effects on the plan:
/// - `Union` nodes get their `options.rows` set to the summed estimate of their inputs.
/// - `Join` nodes get `options.rows_left` / `options.rows_right` set to the estimates of
///   their inputs.
///
/// `scratch` is a reusable buffer for the inputs of nodes that have no dedicated arm.
#[cfg(feature = "streaming")]
pub fn set_estimated_row_counts(
    root: Node,
    lp_arena: &mut Arena<IR>,
    expr_arena: &Arena<AExpr>,
    mut _filter_count: usize,
    scratch: &mut Vec<Node>,
) -> (Option<usize>, usize, usize) {
    use IR::*;

    /// Caps the known and the estimated size at the slice length; the offset is ignored.
    fn apply_slice(out: &mut (Option<usize>, usize, usize), slice: Option<(i64, usize)>) {
        if let Some((_, len)) = slice {
            out.0 = out.0.map(|known_size| known_size.min(len));
            out.1 = out.1.min(len);
        }
    }

    match lp_arena.get(root) {
        Filter { predicate, input } => {
            // One for the filter itself, plus one per binary expression in the predicate.
            _filter_count += expr_arena
                .iter(predicate.node())
                .filter(|(_, ae)| matches!(ae, AExpr::BinaryExpr { .. }))
                .count()
                + 1;
            set_estimated_row_counts(*input, lp_arena, expr_arena, _filter_count, scratch)
        },
        Slice { input, len, .. } => {
            let len = *len as usize;
            let mut out =
                set_estimated_row_counts(*input, lp_arena, expr_arena, _filter_count, scratch);
            apply_slice(&mut out, Some((0, len)));
            out
        },
        Union { .. } => {
            // The node is taken out of the arena so its options can be updated and the
            // arena can be borrowed mutably while recursing; it is put back afterwards.
            if let Union {
                inputs,
                mut options,
            } = lp_arena.take(root)
            {
                let mut sum_output = (None, 0usize);
                for input in &inputs {
                    let mut out =
                        set_estimated_row_counts(*input, lp_arena, expr_arena, 0, scratch);
                    if let Some((_offset, len)) = options.slice {
                        apply_slice(&mut out, Some((0, len)))
                    }
                    // todo! deal with known as well
                    let out = estimate_sizes(out.0, out.1, out.2);
                    sum_output.1 = sum_output.1.saturating_add(out.1);
                }
                options.rows = sum_output;
                lp_arena.replace(root, Union { inputs, options });
                (sum_output.0, sum_output.1, 0)
            } else {
                unreachable!()
            }
        },
        Join { .. } => {
            // Same take/replace dance as for `Union`.
            if let Join {
                input_left,
                input_right,
                mut options,
                schema,
                left_on,
                right_on,
            } = lp_arena.take(root)
            {
                let mut_options = Arc::make_mut(&mut options);
                let (known_size, estimated_size, filter_count_left) =
                    set_estimated_row_counts(input_left, lp_arena, expr_arena, 0, scratch);
                mut_options.rows_left =
                    estimate_sizes(known_size, estimated_size, filter_count_left);
                let (known_size, estimated_size, filter_count_right) =
                    set_estimated_row_counts(input_right, lp_arena, expr_arena, 0, scratch);
                mut_options.rows_right =
                    estimate_sizes(known_size, estimated_size, filter_count_right);

                let mut out = match options.args.how {
                    JoinType::Left => {
                        let (known_size, estimated_size) = options.rows_left;
                        (known_size, estimated_size, filter_count_left)
                    },
                    JoinType::Cross | JoinType::Full => {
                        let (known_size_left, estimated_size_left) = options.rows_left;
                        let (known_size_right, estimated_size_right) = options.rows_right;
                        // Saturating arithmetic: an unknown estimate is `usize::MAX`, so a
                        // plain product would overflow.
                        match (known_size_left, known_size_right) {
                            (Some(l), Some(r)) => {
                                // The third slot is a filter count, not a size; the
                                // estimate is kept consistent with the known product.
                                let product = l.saturating_mul(r);
                                (Some(product), product, 0)
                            },
                            _ => (
                                None,
                                estimated_size_left.saturating_mul(estimated_size_right),
                                0,
                            ),
                        }
                    },
                    _ => {
                        // Other join types: take the numbers of the larger side.
                        let (known_size_left, estimated_size_left) = options.rows_left;
                        let (known_size_right, estimated_size_right) = options.rows_right;
                        if estimated_size_left > estimated_size_right {
                            (known_size_left, estimated_size_left, 0)
                        } else {
                            (known_size_right, estimated_size_right, 0)
                        }
                    },
                };
                apply_slice(&mut out, options.args.slice);
                lp_arena.replace(
                    root,
                    Join {
                        input_left,
                        input_right,
                        options,
                        schema,
                        left_on,
                        right_on,
                    },
                );
                out
            } else {
                unreachable!()
            }
        },
        DataFrameScan { df, .. } => {
            let len = df.height();
            (Some(len), len, _filter_count)
        },
        Scan { file_info, .. } => {
            let (known_size, estimated_size) = file_info.row_estimation;
            (known_size, estimated_size, _filter_count)
        },
        #[cfg(feature = "python")]
        PythonScan { .. } => {
            // TODO! get row estimation.
            (None, usize::MAX, _filter_count)
        },
        lp => {
            lp.copy_inputs(scratch);
            let mut sum_output = (None, 0usize, 0usize);
            while let Some(input) = scratch.pop() {
                let out =
                    set_estimated_row_counts(input, lp_arena, expr_arena, _filter_count, scratch);
                // Saturating, as in the `Union` arm: unknown estimates are `usize::MAX`.
                sum_output.1 = sum_output.1.saturating_add(out.1);
                sum_output.2 = sum_output.2.saturating_add(out.2);
                // Keep the first known size that is encountered.
                sum_output.0 = match sum_output.0 {
                    None => out.0,
                    p => p,
                };
            }
            sum_output
        },
    }
}
255
256pub(crate) fn det_join_schema(
257    schema_left: &SchemaRef,
258    schema_right: &SchemaRef,
259    left_on: &[ExprIR],
260    right_on: &[ExprIR],
261    options: &JoinOptions,
262    expr_arena: &Arena<AExpr>,
263) -> PolarsResult<SchemaRef> {
264    match &options.args.how {
265        // semi and anti joins are just filtering operations
266        // the schema will never change.
267        #[cfg(feature = "semi_anti_join")]
268        JoinType::Semi | JoinType::Anti => Ok(schema_left.clone()),
269        // Right-join with coalesce enabled will coalesce LHS columns into RHS columns (i.e. LHS columns
270        // are removed). This is the opposite of what a left join does so it has its own codepath.
271        //
272        // E.g. df(cols=[A, B]).right_join(df(cols=[A, B]), on=A, coalesce=True)
273        //
274        // will result in
275        //
276        // df(cols=[B, A, B_right])
277        JoinType::Right if options.args.should_coalesce() => {
278            // Get join names.
279            let mut join_on_left: PlHashSet<_> = PlHashSet::with_capacity(left_on.len());
280            for e in left_on {
281                let field = e.field(schema_left, Context::Default, expr_arena)?;
282                join_on_left.insert(field.name);
283            }
284
285            let mut join_on_right: PlHashSet<_> = PlHashSet::with_capacity(right_on.len());
286            for e in right_on {
287                let field = e.field(schema_right, Context::Default, expr_arena)?;
288                join_on_right.insert(field.name);
289            }
290
291            // For the error message
292            let mut suffixed = None;
293
294            let new_schema = Schema::with_capacity(schema_left.len() + schema_right.len())
295                // Columns from left, excluding those used as join keys
296                .hstack(schema_left.iter().filter_map(|(name, dtype)| {
297                    if join_on_left.contains(name) {
298                        return None;
299                    }
300
301                    Some((name.clone(), dtype.clone()))
302                }))?
303                // Columns from right
304                .hstack(schema_right.iter().map(|(name, dtype)| {
305                    suffixed = None;
306
307                    let in_left_schema = schema_left.contains(name.as_str());
308                    let is_coalesced = join_on_left.contains(name.as_str());
309
310                    if in_left_schema && !is_coalesced {
311                        suffixed = Some(format_pl_smallstr!("{}{}", name, options.args.suffix()));
312                        (suffixed.clone().unwrap(), dtype.clone())
313                    } else {
314                        (name.clone(), dtype.clone())
315                    }
316                }))
317                .map_err(|e| {
318                    if let Some(column) = suffixed {
319                        join_suffix_duplicate_help_msg(&column)
320                    } else {
321                        e
322                    }
323                })?;
324
325            Ok(Arc::new(new_schema))
326        },
327        _how => {
328            let mut new_schema = Schema::with_capacity(schema_left.len() + schema_right.len())
329                .hstack(schema_left.iter_fields())?;
330
331            let is_coalesced = options.args.should_coalesce();
332
333            let mut _asof_pre_added_rhs_keys: PlHashSet<PlSmallStr> = PlHashSet::new();
334
335            // Handles coalescing of asof-joins.
336            // Asof joins are not equi-joins
337            // so the columns that are joined on, may have different
338            // values so if the right has a different name, it is added to the schema
339            #[cfg(feature = "asof_join")]
340            if matches!(_how, JoinType::AsOf(_)) {
341                for (left_on, right_on) in left_on.iter().zip(right_on) {
342                    let field_left = left_on.field(schema_left, Context::Default, expr_arena)?;
343                    let field_right = right_on.field(schema_right, Context::Default, expr_arena)?;
344
345                    if is_coalesced && field_left.name != field_right.name {
346                        _asof_pre_added_rhs_keys.insert(field_right.name.clone());
347
348                        if schema_left.contains(&field_right.name) {
349                            new_schema.with_column(
350                                _join_suffix_name(&field_right.name, options.args.suffix()),
351                                field_right.dtype,
352                            );
353                        } else {
354                            new_schema.with_column(field_right.name, field_right.dtype);
355                        }
356                    }
357                }
358            }
359
360            let mut join_on_right: PlHashSet<_> = PlHashSet::with_capacity(right_on.len());
361            for e in right_on {
362                let field = e.field(schema_right, Context::Default, expr_arena)?;
363                join_on_right.insert(field.name);
364            }
365
366            for (name, dtype) in schema_right.iter() {
367                #[cfg(feature = "asof_join")]
368                {
369                    if let JoinType::AsOf(asof_options) = &options.args.how {
370                        // Asof adds keys earlier
371                        if _asof_pre_added_rhs_keys.contains(name) {
372                            continue;
373                        }
374
375                        // Asof join by columns are coalesced
376                        if asof_options
377                            .right_by
378                            .as_deref()
379                            .is_some_and(|x| x.contains(name))
380                        {
381                            // Do not add suffix. The column of the left table will be used
382                            continue;
383                        }
384                    }
385                }
386
387                if join_on_right.contains(name.as_str()) && is_coalesced {
388                    // Column will be coalesced into an already added LHS column.
389                    continue;
390                }
391
392                // For the error message.
393                let mut suffixed = None;
394
395                let (name, dtype) = if schema_left.contains(name) {
396                    suffixed = Some(format_pl_smallstr!("{}{}", name, options.args.suffix()));
397                    (suffixed.clone().unwrap(), dtype.clone())
398                } else {
399                    (name.clone(), dtype.clone())
400                };
401
402                new_schema.try_insert(name, dtype).map_err(|e| {
403                    if let Some(column) = suffixed {
404                        join_suffix_duplicate_help_msg(&column)
405                    } else {
406                        e
407                    }
408                })?;
409            }
410
411            Ok(Arc::new(new_schema))
412        },
413    }
414}
415
/// Builds the `Duplicate` error reported when a join output column named `column_name`
/// already exists, including hints on how to avoid the clash.
fn join_suffix_duplicate_help_msg(column_name: &str) -> PolarsError {
    polars_err!(
        Duplicate:
        "\
column with name '{}' already exists

You may want to try:
- renaming the column prior to joining
- using the `suffix` parameter to specify a suffix different to the default one ('_right')",
        column_name
    )
}
428
429// We don't use an `Arc<Mutex>` because caches should live in different query plans.
430// For that reason we have a specialized deep clone.
431#[derive(Default)]
432pub struct CachedSchema(Mutex<Option<SchemaRef>>);
433
434impl AsRef<Mutex<Option<SchemaRef>>> for CachedSchema {
435    fn as_ref(&self) -> &Mutex<Option<SchemaRef>> {
436        &self.0
437    }
438}
439
440impl Deref for CachedSchema {
441    type Target = Mutex<Option<SchemaRef>>;
442
443    fn deref(&self) -> &Self::Target {
444        &self.0
445    }
446}
447
448impl Clone for CachedSchema {
449    fn clone(&self) -> Self {
450        let inner = self.0.lock().unwrap();
451        Self(Mutex::new(inner.clone()))
452    }
453}
454
455impl CachedSchema {
456    pub fn get(&self) -> Option<SchemaRef> {
457        self.0.lock().unwrap().clone()
458    }
459}