1use std::ops::Deref;
2use std::sync::Mutex;
3
4use arrow::datatypes::ArrowSchemaRef;
5use either::Either;
6use polars_core::prelude::*;
7use polars_utils::format_pl_smallstr;
8#[cfg(feature = "serde")]
9use serde::{Deserialize, Serialize};
10
11use crate::prelude::*;
12
13impl DslPlan {
14 pub fn compute_schema(&self) -> PolarsResult<SchemaRef> {
19 let mut lp_arena = Default::default();
20 let mut expr_arena = Default::default();
21 let node = to_alp(
22 self.clone(),
23 &mut expr_arena,
24 &mut lp_arena,
25 &mut OptFlags::schema_only(),
26 )?;
27
28 Ok(lp_arena.get(node).schema(&lp_arena).into_owned())
29 }
30}
31
/// Schema and row-count metadata kept alongside a scan.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct FileInfo {
    /// The schema; `update_schema_with_hive_schema` merges hive fields into it.
    pub schema: SchemaRef,
    /// Optional secondary schema, stored either as an Arrow schema or as a Polars schema.
    /// NOTE(review): presumably the schema handed to the file reader — confirm against callers.
    pub reader_schema: Option<Either<ArrowSchemaRef, SchemaRef>>,
    /// Row-count information as `(known size, estimated size)`.
    /// The `Default` impl uses `(None, usize::MAX)`.
    pub row_estimation: (Option<usize>, usize),
}
48
49impl Default for FileInfo {
51 fn default() -> Self {
52 FileInfo {
53 schema: Default::default(),
54 reader_schema: None,
55 row_estimation: (None, usize::MAX),
56 }
57 }
58}
59
60impl FileInfo {
61 pub fn new(
63 schema: SchemaRef,
64 reader_schema: Option<Either<ArrowSchemaRef, SchemaRef>>,
65 row_estimation: (Option<usize>, usize),
66 ) -> Self {
67 Self {
68 schema: schema.clone(),
69 reader_schema,
70 row_estimation,
71 }
72 }
73
74 pub fn update_schema_with_hive_schema(&mut self, hive_schema: SchemaRef) {
76 let schema = Arc::make_mut(&mut self.schema);
77
78 for field in hive_schema.iter_fields() {
79 if let Some(existing) = schema.get_mut(&field.name) {
80 *existing = field.dtype().clone();
81 } else {
82 schema
83 .insert_at_index(schema.len(), field.name, field.dtype.clone())
84 .unwrap();
85 }
86 }
87 }
88}
89
90#[cfg(feature = "streaming")]
/// Adjusts a `(known size, estimated size)` pair for the number of filters applied.
///
/// With no filters the input pair is returned unchanged. Otherwise the known
/// size is discarded and the estimate is scaled by `0.9` per filter
/// (`0.9 ^ filter_count`, computed in `f32`).
fn estimate_sizes(
    known_size: Option<usize>,
    estimated_size: usize,
    filter_count: usize,
) -> (Option<usize>, usize) {
    if filter_count == 0 {
        return (known_size, estimated_size);
    }

    // Any filter makes the exact row count unknown.
    let keep_ratio = 0.9f32.powf(filter_count as f32);
    let shrunk = estimated_size as f32 * keep_ratio;
    (None, shrunk as usize)
}
105
/// Walks the plan rooted at `root` and derives row-count estimates for it.
///
/// While walking, `Union` nodes get their `options.rows` updated and `Join`
/// nodes get `rows_left` / `rows_right` updated in `lp_arena`.
///
/// * `root` - node to start from.
/// * `lp_arena` - plan arena; `Union` and `Join` nodes are taken out, updated and put back.
/// * `expr_arena` - expression arena, used to inspect filter predicates.
/// * `_filter_count` - number of filters accumulated above `root`.
/// * `scratch` - reusable stack for node inputs; on return it has the same
///   length it had on entry.
///
/// Returns `(known_size, estimated_size, filter_count)`.
#[cfg(feature = "streaming")]
pub fn set_estimated_row_counts(
    root: Node,
    lp_arena: &mut Arena<IR>,
    expr_arena: &Arena<AExpr>,
    mut _filter_count: usize,
    scratch: &mut Vec<Node>,
) -> (Option<usize>, usize, usize) {
    use IR::*;

    /// Caps the known and the estimated size at the slice length. The slice offset is ignored.
    fn apply_slice(out: &mut (Option<usize>, usize, usize), slice: Option<(i64, usize)>) {
        if let Some((_, len)) = slice {
            out.0 = out.0.map(|known_size| std::cmp::min(len, known_size));
            out.1 = std::cmp::min(len, out.1);
        }
    }

    match lp_arena.get(root) {
        Filter { predicate, input } => {
            // One filter for the node itself plus one per binary expression in the predicate.
            _filter_count += expr_arena
                .iter(predicate.node())
                .filter(|(_, ae)| matches!(ae, AExpr::BinaryExpr { .. }))
                .count()
                + 1;
            set_estimated_row_counts(*input, lp_arena, expr_arena, _filter_count, scratch)
        },
        Slice { input, len, .. } => {
            let len = *len as usize;
            let mut out =
                set_estimated_row_counts(*input, lp_arena, expr_arena, _filter_count, scratch);
            apply_slice(&mut out, Some((0, len)));
            out
        },
        Union { .. } => {
            if let Union {
                inputs,
                mut options,
            } = lp_arena.take(root)
            {
                // Only the estimate is accumulated; the known size stays `None`.
                let mut sum_output = (None, 0usize);
                for input in &inputs {
                    let mut out =
                        set_estimated_row_counts(*input, lp_arena, expr_arena, 0, scratch);
                    if let Some((_offset, len)) = options.slice {
                        apply_slice(&mut out, Some((0, len)))
                    }
                    let out = estimate_sizes(out.0, out.1, out.2);
                    sum_output.1 = sum_output.1.saturating_add(out.1);
                }
                options.rows = sum_output;
                lp_arena.replace(root, Union { inputs, options });
                (sum_output.0, sum_output.1, 0)
            } else {
                unreachable!()
            }
        },
        Join { .. } => {
            if let Join {
                input_left,
                input_right,
                mut options,
                schema,
                left_on,
                right_on,
            } = lp_arena.take(root)
            {
                let mut_options = Arc::make_mut(&mut options);
                let (known_size, estimated_size, filter_count_left) =
                    set_estimated_row_counts(input_left, lp_arena, expr_arena, 0, scratch);
                mut_options.rows_left =
                    estimate_sizes(known_size, estimated_size, filter_count_left);
                let (known_size, estimated_size, filter_count_right) =
                    set_estimated_row_counts(input_right, lp_arena, expr_arena, 0, scratch);
                mut_options.rows_right =
                    estimate_sizes(known_size, estimated_size, filter_count_right);

                let mut out = match options.args.how {
                    JoinType::Left => {
                        let (known_size, estimated_size) = options.rows_left;
                        (known_size, estimated_size, filter_count_left)
                    },
                    JoinType::Cross | JoinType::Full => {
                        let (known_size_left, estimated_size_left) = options.rows_left;
                        let (known_size_right, estimated_size_right) = options.rows_right;
                        // The product is only "known" when both sides are known and it
                        // fits in a `usize`.
                        // NOTE(review): for `Full` the product is an upper bound rather
                        // than an exact count; kept as in the previous implementation.
                        let known_size = match (known_size_left, known_size_right) {
                            (Some(l), Some(r)) => l.checked_mul(r),
                            _ => None,
                        };
                        // Unknown inputs report `usize::MAX`, so saturate instead of
                        // overflowing. The filter count is 0 because both sides were
                        // already adjusted by `estimate_sizes`.
                        (
                            known_size,
                            estimated_size_left.saturating_mul(estimated_size_right),
                            0,
                        )
                    },
                    _ => {
                        // Report whichever side has the larger estimate.
                        let (known_size_left, estimated_size_left) = options.rows_left;
                        let (known_size_right, estimated_size_right) = options.rows_right;
                        if estimated_size_left > estimated_size_right {
                            (known_size_left, estimated_size_left, 0)
                        } else {
                            (known_size_right, estimated_size_right, 0)
                        }
                    },
                };
                apply_slice(&mut out, options.args.slice);
                lp_arena.replace(
                    root,
                    Join {
                        input_left,
                        input_right,
                        options,
                        schema,
                        left_on,
                        right_on,
                    },
                );
                out
            } else {
                unreachable!()
            }
        },
        DataFrameScan { df, .. } => {
            let len = df.height();
            (Some(len), len, _filter_count)
        },
        Scan { file_info, .. } => {
            let (known_size, estimated_size) = file_info.row_estimation;
            (known_size, estimated_size, _filter_count)
        },
        #[cfg(feature = "python")]
        PythonScan { .. } => {
            // No row information available: unknown size, maximal estimate.
            (None, usize::MAX, _filter_count)
        },
        lp => {
            // `scratch` is shared with callers further up the recursion. Only pop the
            // entries pushed for this node so inputs queued by an ancestor are not
            // consumed here with the wrong filter count / slice.
            let own_inputs_start = scratch.len();
            lp.copy_inputs(scratch);
            let mut sum_output = (None, 0usize, 0usize);
            while scratch.len() > own_inputs_start {
                let Some(input) = scratch.pop() else { break };
                let out =
                    set_estimated_row_counts(input, lp_arena, expr_arena, _filter_count, scratch);
                // Saturate: unknown inputs report `usize::MAX`.
                sum_output.1 = sum_output.1.saturating_add(out.1);
                sum_output.2 = sum_output.2.saturating_add(out.2);
                // Keep the first known size encountered.
                sum_output.0 = sum_output.0.or(out.0);
            }
            sum_output
        },
    }
}
255
256pub(crate) fn det_join_schema(
257 schema_left: &SchemaRef,
258 schema_right: &SchemaRef,
259 left_on: &[ExprIR],
260 right_on: &[ExprIR],
261 options: &JoinOptions,
262 expr_arena: &Arena<AExpr>,
263) -> PolarsResult<SchemaRef> {
264 match &options.args.how {
265 #[cfg(feature = "semi_anti_join")]
268 JoinType::Semi | JoinType::Anti => Ok(schema_left.clone()),
269 JoinType::Right if options.args.should_coalesce() => {
278 let mut join_on_left: PlHashSet<_> = PlHashSet::with_capacity(left_on.len());
280 for e in left_on {
281 let field = e.field(schema_left, Context::Default, expr_arena)?;
282 join_on_left.insert(field.name);
283 }
284
285 let mut join_on_right: PlHashSet<_> = PlHashSet::with_capacity(right_on.len());
286 for e in right_on {
287 let field = e.field(schema_right, Context::Default, expr_arena)?;
288 join_on_right.insert(field.name);
289 }
290
291 let mut suffixed = None;
293
294 let new_schema = Schema::with_capacity(schema_left.len() + schema_right.len())
295 .hstack(schema_left.iter().filter_map(|(name, dtype)| {
297 if join_on_left.contains(name) {
298 return None;
299 }
300
301 Some((name.clone(), dtype.clone()))
302 }))?
303 .hstack(schema_right.iter().map(|(name, dtype)| {
305 suffixed = None;
306
307 let in_left_schema = schema_left.contains(name.as_str());
308 let is_coalesced = join_on_left.contains(name.as_str());
309
310 if in_left_schema && !is_coalesced {
311 suffixed = Some(format_pl_smallstr!("{}{}", name, options.args.suffix()));
312 (suffixed.clone().unwrap(), dtype.clone())
313 } else {
314 (name.clone(), dtype.clone())
315 }
316 }))
317 .map_err(|e| {
318 if let Some(column) = suffixed {
319 join_suffix_duplicate_help_msg(&column)
320 } else {
321 e
322 }
323 })?;
324
325 Ok(Arc::new(new_schema))
326 },
327 _how => {
328 let mut new_schema = Schema::with_capacity(schema_left.len() + schema_right.len())
329 .hstack(schema_left.iter_fields())?;
330
331 let is_coalesced = options.args.should_coalesce();
332
333 let mut _asof_pre_added_rhs_keys: PlHashSet<PlSmallStr> = PlHashSet::new();
334
335 #[cfg(feature = "asof_join")]
340 if matches!(_how, JoinType::AsOf(_)) {
341 for (left_on, right_on) in left_on.iter().zip(right_on) {
342 let field_left = left_on.field(schema_left, Context::Default, expr_arena)?;
343 let field_right = right_on.field(schema_right, Context::Default, expr_arena)?;
344
345 if is_coalesced && field_left.name != field_right.name {
346 _asof_pre_added_rhs_keys.insert(field_right.name.clone());
347
348 if schema_left.contains(&field_right.name) {
349 new_schema.with_column(
350 _join_suffix_name(&field_right.name, options.args.suffix()),
351 field_right.dtype,
352 );
353 } else {
354 new_schema.with_column(field_right.name, field_right.dtype);
355 }
356 }
357 }
358 }
359
360 let mut join_on_right: PlHashSet<_> = PlHashSet::with_capacity(right_on.len());
361 for e in right_on {
362 let field = e.field(schema_right, Context::Default, expr_arena)?;
363 join_on_right.insert(field.name);
364 }
365
366 for (name, dtype) in schema_right.iter() {
367 #[cfg(feature = "asof_join")]
368 {
369 if let JoinType::AsOf(asof_options) = &options.args.how {
370 if _asof_pre_added_rhs_keys.contains(name) {
372 continue;
373 }
374
375 if asof_options
377 .right_by
378 .as_deref()
379 .is_some_and(|x| x.contains(name))
380 {
381 continue;
383 }
384 }
385 }
386
387 if join_on_right.contains(name.as_str()) && is_coalesced {
388 continue;
390 }
391
392 let mut suffixed = None;
394
395 let (name, dtype) = if schema_left.contains(name) {
396 suffixed = Some(format_pl_smallstr!("{}{}", name, options.args.suffix()));
397 (suffixed.clone().unwrap(), dtype.clone())
398 } else {
399 (name.clone(), dtype.clone())
400 };
401
402 new_schema.try_insert(name, dtype).map_err(|e| {
403 if let Some(column) = suffixed {
404 join_suffix_duplicate_help_msg(&column)
405 } else {
406 e
407 }
408 })?;
409 }
410
411 Ok(Arc::new(new_schema))
412 },
413 }
414}
415
/// Builds the `Duplicate` error reported when a join output column named
/// `column_name` already exists, including hints on how to resolve the clash.
fn join_suffix_duplicate_help_msg(column_name: &str) -> PolarsError {
    polars_err!(
        Duplicate:
        "\
column with name '{}' already exists

You may want to try:
- renaming the column prior to joining
- using the `suffix` parameter to specify a suffix different to the default one ('_right')",
        column_name
    )
}
428
/// A mutex-guarded slot holding an optional cached schema.
///
/// The inner value is a plain `Mutex` (not shared through an `Arc`), and the
/// manual `Clone` impl creates a new, independent mutex containing a clone of
/// the currently stored value.
#[derive(Default)]
pub struct CachedSchema(Mutex<Option<SchemaRef>>);
433
434impl AsRef<Mutex<Option<SchemaRef>>> for CachedSchema {
435 fn as_ref(&self) -> &Mutex<Option<SchemaRef>> {
436 &self.0
437 }
438}
439
440impl Deref for CachedSchema {
441 type Target = Mutex<Option<SchemaRef>>;
442
443 fn deref(&self) -> &Self::Target {
444 &self.0
445 }
446}
447
448impl Clone for CachedSchema {
449 fn clone(&self) -> Self {
450 let inner = self.0.lock().unwrap();
451 Self(Mutex::new(inner.clone()))
452 }
453}
454
455impl CachedSchema {
456 pub fn get(&self) -> Option<SchemaRef> {
457 self.0.lock().unwrap().clone()
458 }
459}