1mod aggregations;
4pub mod joins;
5mod stats;
6mod transformations;
7
8pub(crate) use aggregations::disambiguate_agg_output_names;
9pub use aggregations::{CubeRollupData, GroupedData, PivotedGroupedData};
10pub use joins::{
11 JoinOptions, JoinType, expr_contains_only_join_key_equalities, join,
12 try_extract_join_eq_columns, try_extract_join_eq_columns_all,
13};
14pub use stats::DataFrameStat;
15pub(crate) use transformations::literal_value_to_serde_value;
16pub use transformations::{
17 DataFrameNa, SelectItem, filter, order_by, order_by_exprs, select, select_items,
18 select_with_exprs, with_column,
19};
20
21use crate::column::Column;
22use crate::error::{EngineError, polars_to_core_error};
23use crate::functions::SortOrder;
24use crate::schema::{StructType, StructTypePolarsExt};
25use crate::session::SparkSession;
26use crate::type_coercion::{coerce_for_pyspark_comparison, is_numeric_public};
27use polars::datatypes::TimeUnit;
28use polars::prelude::{
29 AnyValue, DataFrame as PlDataFrame, DataType, Expr, Field, IntoLazy, LazyFrame, NULL,
30 PlSmallStr, PolarsError, Schema, SchemaNamesAndDtypes, UnknownKind, col, lit,
31};
32use serde_json::Value as JsonValue;
33use std::collections::{HashMap, HashSet};
34use std::path::Path;
35use std::sync::Arc;
36
/// Case-sensitivity flag used by constructors that do not take an explicit
/// `case_sensitive` argument (`from_polars`, `from_lazy`, `empty`).
const DEFAULT_CASE_SENSITIVE: bool = false;
39
40fn pyspark_type_name(dtype: &DataType) -> String {
42 use polars::datatypes::DataType as PlDataType;
43 match dtype {
44 PlDataType::Int32 => "IntegerType".to_string(),
45 PlDataType::Int64 => "LongType".to_string(),
46 PlDataType::String => "StringType".to_string(),
47 PlDataType::Float32 | PlDataType::Float64 => "DoubleType".to_string(),
48 PlDataType::Boolean => "BooleanType".to_string(),
49 PlDataType::Date => "DateType".to_string(),
50 PlDataType::Datetime(_, _) => "TimestampType".to_string(),
51 PlDataType::List(inner) => format!("ArrayType({})", pyspark_type_name(inner)),
52 PlDataType::Struct(fields) => {
53 let parts: Vec<String> = fields
54 .iter()
55 .map(|f| format!("{}: {}", f.name(), pyspark_type_name(f.dtype())))
56 .collect();
57 format!("StructType([{}])", parts.join(", "))
58 }
59 _ => format!("{dtype:?}"),
60 }
61}
62
/// Backing storage for a [`DataFrame`]: either an already-materialized Polars frame or a
/// lazy query plan that is collected on demand.
// NOTE(review): `large_enum_variant` is presumably allowed because `LazyFrame` is much larger
// than `Arc<PlDataFrame>` — confirm that not boxing it is intentional.
#[allow(clippy::large_enum_variant)]
pub(crate) enum DataFrameInner {
    /// Materialized data held in an `Arc`, so `collect_inner` can hand it out with a
    /// reference-count bump instead of a copy.
    #[allow(dead_code)]
    Eager(Arc<PlDataFrame>),
    /// Deferred Polars query plan.
    Lazy(LazyFrame),
}
71
/// Spark-style DataFrame built on top of Polars.
///
/// Most constructors store a lazy plan (`DataFrameInner::Lazy`); `from_eager_with_options`
/// stores an already-collected frame.
pub struct DataFrame {
    /// The underlying Polars data (eager frame or lazy plan).
    pub(crate) inner: DataFrameInner,
    /// Whether column-name resolution (`resolve_column_name`, `check_ambiguous_unqualified`)
    /// compares names case-sensitively.
    pub(crate) case_sensitive: bool,
    /// Optional DataFrame alias, set by `alias()` and read by `get_alias()`.
    pub(crate) alias: Option<String>,
    /// Column names that `check_ambiguous_unqualified` rejects when referenced without a
    /// qualifier (presumably populated by joins — confirm with callers).
    pub(crate) ambiguous_columns: Option<HashSet<String>>,
}
85
/// A group-by key, given either as a column name or as a [`Column`] expression.
// NOTE(review): the consumers of this type are outside this chunk — confirm how each variant
// is interpreted.
#[derive(Clone)]
pub enum GroupBySpec {
    /// Key referenced by column name.
    Name(String),
    /// Key given as a (boxed) `Column`.
    Column(Box<Column>),
}
93
94impl DataFrame {
95 pub fn from_polars(df: PlDataFrame) -> Self {
98 let lf = df.lazy();
99 DataFrame {
100 inner: DataFrameInner::Lazy(lf),
101 case_sensitive: DEFAULT_CASE_SENSITIVE,
102 alias: None,
103 ambiguous_columns: None,
104 }
105 }
106
107 pub fn from_polars_with_options(df: PlDataFrame, case_sensitive: bool) -> Self {
110 let lf = df.lazy();
111 DataFrame {
112 inner: DataFrameInner::Lazy(lf),
113 case_sensitive,
114 alias: None,
115 ambiguous_columns: None,
116 }
117 }
118
119 pub(crate) fn from_eager_with_options(df: PlDataFrame, case_sensitive: bool) -> Self {
123 DataFrame {
124 inner: DataFrameInner::Eager(Arc::new(df)),
125 case_sensitive,
126 alias: None,
127 ambiguous_columns: None,
128 }
129 }
130
131 pub fn from_lazy(lf: LazyFrame) -> Self {
133 DataFrame {
134 inner: DataFrameInner::Lazy(lf),
135 case_sensitive: DEFAULT_CASE_SENSITIVE,
136 alias: None,
137 ambiguous_columns: None,
138 }
139 }
140
141 pub fn from_lazy_with_options(lf: LazyFrame, case_sensitive: bool) -> Self {
143 DataFrame {
144 inner: DataFrameInner::Lazy(lf),
145 case_sensitive,
146 alias: None,
147 ambiguous_columns: None,
148 }
149 }
150
151 pub(crate) fn from_lazy_with_options_and_ambiguous(
154 lf: LazyFrame,
155 case_sensitive: bool,
156 ambiguous_columns: Option<HashSet<String>>,
157 ) -> Self {
158 DataFrame {
159 inner: DataFrameInner::Lazy(lf),
160 case_sensitive,
161 alias: None,
162 ambiguous_columns,
163 }
164 }
165
166 pub(crate) fn with_case_insensitive_column_resolution(self) -> Self {
169 DataFrame {
170 inner: self.inner,
171 case_sensitive: false,
172 alias: self.alias,
173 ambiguous_columns: self.ambiguous_columns,
174 }
175 }
176
177 pub fn empty() -> Self {
179 DataFrame {
180 inner: DataFrameInner::Lazy(PlDataFrame::empty().lazy()),
181 case_sensitive: DEFAULT_CASE_SENSITIVE,
182 alias: None,
183 ambiguous_columns: None,
184 }
185 }
186
187 pub(crate) fn is_eager(&self) -> bool {
189 matches!(&self.inner, DataFrameInner::Eager(_))
190 }
191
192 pub(crate) fn lazy_frame(&self) -> LazyFrame {
194 match &self.inner {
195 DataFrameInner::Eager(df) => df.as_ref().clone().lazy(),
196 DataFrameInner::Lazy(lf) => lf.clone(),
197 }
198 }
199
200 pub(crate) fn collect_inner(&self) -> Result<Arc<PlDataFrame>, PolarsError> {
202 match &self.inner {
203 DataFrameInner::Eager(df) => Ok(df.clone()),
204 DataFrameInner::Lazy(lf) => Ok(Arc::new(lf.clone().collect()?)),
205 }
206 }
207
208 pub fn alias(&self, name: &str) -> Self {
211 let lf = self.lazy_frame();
212 DataFrame {
213 inner: DataFrameInner::Lazy(lf),
214 case_sensitive: self.case_sensitive,
215 alias: Some(name.to_string()),
216 ambiguous_columns: self.ambiguous_columns.clone(),
217 }
218 }
219
220 pub fn get_alias(&self) -> Option<String> {
222 self.alias.clone()
223 }
224
225 pub fn resolve_expr_column_names(&self, expr: Expr) -> Result<Expr, PolarsError> {
232 let df = self;
233 let mut alias_output_names: HashSet<String> = HashSet::new();
234 let _ = expr.clone().try_map_expr(|e| {
235 if let Expr::Alias(_, name) = &e {
236 alias_output_names.insert(name.as_str().to_string());
237 }
238 Ok(e)
239 })?;
240 if let Ok(out_name) = polars_plan::utils::expr_output_name(&expr) {
243 let out_str = out_name.as_str();
244 let matches_schema = self
245 .columns()
246 .map(|cols| cols.iter().any(|c| c.eq_ignore_ascii_case(out_str)))
247 .unwrap_or(false);
248 if !matches_schema {
249 alias_output_names.insert(out_str.to_string());
250 }
251 }
252 expr.try_map_expr(move |e| {
253 if let Expr::Alias(inner, name) = &e {
255 let new_inner = df.resolve_expr_column_names(inner.as_ref().clone())?;
256 return Ok(Expr::Alias(Arc::new(new_inner), name.clone()));
257 }
258 if let Expr::Cast {
260 ref expr,
261 ref dtype,
262 ref options,
263 } = e
264 {
265 let resolved_inner = df.resolve_expr_column_names(expr.as_ref().clone())?;
266 return Ok(Expr::Cast {
267 expr: Arc::new(resolved_inner),
268 dtype: dtype.clone(),
269 options: *options,
270 });
271 }
272 if let Expr::Column(name) = &e {
273 let name_str = name.as_str();
274 if !name_str.contains('.') && alias_output_names.contains(name_str) {
278 let matches_schema = df
279 .columns()
280 .map(|cols| cols.iter().any(|c| c.eq_ignore_ascii_case(name_str)))
281 .unwrap_or(false);
282 if !matches_schema {
283 return Ok(e);
284 }
285 }
286 if name_str.is_empty() {
288 return Ok(e);
289 }
290 if name_str.contains('.') {
292 let parts: Vec<&str> = name_str.split('.').collect();
293 let first = parts[0];
294 let rest = &parts[1..];
295 if rest.is_empty() {
296 return Err(PolarsError::ColumnNotFound(
297 format!(
298 "cannot resolve: Column '{}': trailing dot not allowed",
299 name_str
300 )
301 .into(),
302 ));
303 }
304 match df.resolve_column_name(first) {
306 Ok(resolved) => {
307 let mut expr = col(PlSmallStr::from(resolved.as_str()));
308 let mut current_dtype =
309 df.get_column_dtype(resolved.as_str()).ok_or_else(|| {
310 PolarsError::ColumnNotFound(
311 format!("cannot resolve: column '{}' not found", resolved)
312 .into(),
313 )
314 })?;
315 let mut context_name = resolved.to_string();
316 for field in rest {
317 let (resolved_field, field_dtype) = match df
318 .resolve_struct_field_from_type(
319 ¤t_dtype,
320 field,
321 &context_name,
322 ) {
323 Ok(t) => t,
324 Err(_) => {
325 return Ok(lit(NULL).alias(PlSmallStr::from(name_str)));
327 }
328 };
329 expr = expr.struct_().field_by_name(&resolved_field);
330 context_name = format!("{}.{}", context_name, resolved_field);
331 current_dtype = field_dtype;
332 }
333 return Ok(expr.alias(PlSmallStr::from(name_str)));
334 }
335 Err(_) => {
336 if let Ok(suffix_resolved) = df.resolve_column_name(name_str) {
338 return Ok(col(PlSmallStr::from(suffix_resolved.as_str()))
339 .alias(PlSmallStr::from(name_str)));
340 }
341 return Err(PolarsError::ColumnNotFound(
342 format!("cannot resolve: column '{}' not found", first).into(),
343 ));
344 }
345 }
346 }
347 let resolved = df.resolve_column_name(name_str)?;
348 return Ok(Expr::Column(PlSmallStr::from(resolved.as_str())));
349 }
350 if let Expr::Function {
352 input,
353 function:
354 polars::prelude::FunctionExpr::StructExpr(
355 polars::prelude::StructFunction::FieldByName(name),
356 ),
357 } = &e
358 {
359 if input.len() == 1 {
360 let input_expr = input[0].clone();
361 if let Some(input_dt) = df.get_expr_output_dtype(&input_expr) {
362 match df.resolve_struct_field_from_type(&input_dt, name.as_str(), "struct")
363 {
364 Ok((resolved_name, _)) => {
365 return Ok(input_expr.struct_().field_by_name(&resolved_name));
366 }
367 Err(_) => {
368 return Ok(lit(NULL));
370 }
371 }
372 }
373 return Ok(input_expr.struct_().field_by_name(name));
376 }
377 }
378 if let Expr::Function { input, function } = &e {
380 let resolved_inputs: Result<Vec<Expr>, _> = input
381 .iter()
382 .map(|arg| df.resolve_expr_column_names(arg.clone()))
383 .collect();
384 if let Ok(resolved) = resolved_inputs {
385 return Ok(Expr::Function {
386 input: resolved,
387 function: function.clone(),
388 });
389 }
390 }
391 if let Expr::Over {
393 function,
394 partition_by,
395 order_by,
396 mapping,
397 } = &e
398 {
399 let resolved_function = df.resolve_expr_column_names(function.as_ref().clone())?;
400 let resolved_partition_by: Result<Vec<Expr>, _> = partition_by
401 .iter()
402 .map(|p| df.resolve_expr_column_names(p.clone()))
403 .collect();
404 let resolved_partition_by = resolved_partition_by?;
405 let resolved_order_by = order_by.as_ref().map(|(ob, opts)| {
406 df.resolve_expr_column_names(ob.as_ref().clone())
407 .map(|r| (Arc::new(r), *opts))
408 });
409 let resolved_order_by = match resolved_order_by {
410 Some(Ok((r, opts))) => Some((r, opts)),
411 Some(Err(e)) => return Err(e),
412 None => None,
413 };
414 return Ok(Expr::Over {
415 function: Arc::new(resolved_function),
416 partition_by: resolved_partition_by,
417 order_by: resolved_order_by,
418 mapping: *mapping,
419 });
420 }
421 Ok(e)
422 })
423 }
424
    /// Rewrite comparisons and arithmetic in `expr` so operand types follow PySpark-style
    /// implicit coercion (via `coerce_for_pyspark_comparison` / `coerce_for_pyspark_arithmetic`).
    ///
    /// Works in three passes:
    /// 1. Root pass: one top-level `Alias` is peeled off (and re-applied afterwards), and
    ///    `And`/`Or` are recursed into. The root comparison is then coerced when it is
    ///    column-vs-numeric-literal, column-vs-string-literal (only for Date/Datetime or numeric
    ///    columns) or column-vs-column with differing dtypes. Several of these cases return
    ///    immediately, skipping passes 2 and 3.
    /// 2. Tree-wide pass over comparison `BinaryExpr`s whose operands are a bare column and a
    ///    numeric literal, or a Date/Datetime column and a string literal.
    /// 3. Tree-wide pass coercing arithmetic between a String-typed and a numeric-typed operand.
    ///
    /// # Errors
    /// `PolarsError::ComputeError` when column-vs-literal coercion fails. Failures of
    /// column-vs-column and arithmetic coercion are ignored and leave the expression unchanged.
    pub fn coerce_string_numeric_comparisons(&self, expr: Expr) -> Result<Expr, PolarsError> {
        use polars::prelude::{DataType, LiteralValue, Operator};
        use std::sync::Arc;

        // True for literals with a numeric dtype, including not-yet-materialized integer/float
        // literals (`DataType::Unknown(Int | Float)`).
        fn is_numeric_literal(expr: &Expr) -> bool {
            match expr {
                Expr::Literal(lv) => {
                    let dt = lv.get_datatype();
                    dt.is_numeric()
                        || matches!(
                            dt,
                            DataType::Unknown(UnknownKind::Int(_))
                                | DataType::Unknown(UnknownKind::Float)
                        )
                }
                _ => false,
            }
        }

        // Dtype used for a literal during coercion: unknown int/float literals are treated as
        // Float64; all others keep their reported dtype.
        fn literal_dtype(lv: &LiteralValue) -> DataType {
            let dt = lv.get_datatype();
            if matches!(
                dt,
                DataType::Unknown(UnknownKind::Int(_)) | DataType::Unknown(UnknownKind::Float)
            ) {
                DataType::Float64
            } else {
                dt
            }
        }

        // Peel one top-level alias so the comparison beneath it can be inspected.
        let (expr_to_coerce, alias_after) = match &expr {
            Expr::Alias(inner, name) => (inner.as_ref().clone(), Some(name.clone())),
            _ => (expr.clone(), None),
        };
        // Boolean connectives: coerce each side independently via recursion.
        let expr_to_coerce = match &expr_to_coerce {
            Expr::BinaryExpr { left, op, right } if matches!(op, Operator::And | Operator::Or) => {
                let left_c = self.coerce_string_numeric_comparisons((**left).clone())?;
                let right_c = self.coerce_string_numeric_comparisons((**right).clone())?;
                Expr::BinaryExpr {
                    left: Arc::new(left_c),
                    op: *op,
                    right: Arc::new(right_c),
                }
            }
            _ => expr_to_coerce,
        };
        // Re-apply the alias peeled off above, if there was one.
        fn wrap_expr_with_alias(
            expr: Expr,
            alias_name: Option<&polars::prelude::PlSmallStr>,
        ) -> Expr {
            match alias_name {
                Some(name) => Expr::Alias(Arc::new(expr), name.clone()),
                None => expr,
            }
        }
        // Pass 1: coerce the root comparison.
        let expr = {
            if let Expr::BinaryExpr { left, op, right } = &expr_to_coerce {
                // Look through aliases on the operands.
                let left_inner: &Expr = match left.as_ref() {
                    Expr::Alias(inner, _) => inner.as_ref(),
                    _ => left,
                };
                let right_inner: &Expr = match right.as_ref() {
                    Expr::Alias(inner, _) => inner.as_ref(),
                    _ => right,
                };
                let is_comparison_op = matches!(
                    op,
                    Operator::Eq
                        | Operator::NotEq
                        | Operator::Lt
                        | Operator::LtEq
                        | Operator::Gt
                        | Operator::GtEq
                );
                let left_is_col = matches!(left_inner, Expr::Column(_));
                let right_is_col = matches!(right_inner, Expr::Column(_));
                let left_is_numeric_lit =
                    matches!(left_inner, Expr::Literal(_)) && is_numeric_literal(left_inner);
                let right_is_numeric_lit =
                    matches!(right_inner, Expr::Literal(_)) && is_numeric_literal(right_inner);
                let left_is_string_lit = matches!(
                    left_inner,
                    Expr::Literal(lv) if lv.get_datatype() == DataType::String
                );
                let right_is_string_lit = matches!(
                    right_inner,
                    Expr::Literal(lv) if lv.get_datatype() == DataType::String
                );
                let root_is_col_vs_numeric = is_comparison_op
                    && ((left_is_col && right_is_numeric_lit)
                        || (right_is_col && left_is_numeric_lit));
                let root_is_col_vs_string = is_comparison_op
                    && ((left_is_col && right_is_string_lit)
                        || (right_is_col && left_is_string_lit));
                if root_is_col_vs_numeric {
                    let col_name = if left_is_col {
                        if let Expr::Column(n) = left_inner {
                            n.as_str()
                        } else {
                            unreachable!()
                        }
                    } else if let Expr::Column(n) = right_inner {
                        n.as_str()
                    } else {
                        unreachable!()
                    };
                    // A column whose dtype is unknown or non-numeric is treated as String.
                    let (new_left, new_right) = if left_is_col && right_is_numeric_lit {
                        let col_ty = self.get_column_dtype(col_name);
                        let lit_ty = match right_inner {
                            Expr::Literal(lv) => literal_dtype(lv),
                            _ => DataType::Float64,
                        };
                        let left_ty = col_ty.filter(is_numeric_public).unwrap_or(DataType::String);
                        coerce_for_pyspark_comparison(
                            left_inner.clone(),
                            right_inner.clone(),
                            &left_ty,
                            &lit_ty,
                            op,
                        )
                        .map_err(|e| PolarsError::ComputeError(e.to_string().into()))?
                    } else {
                        let col_ty = self.get_column_dtype(col_name);
                        let lit_ty = match left_inner {
                            Expr::Literal(lv) => literal_dtype(lv),
                            _ => DataType::Float64,
                        };
                        let right_ty = col_ty.filter(is_numeric_public).unwrap_or(DataType::String);
                        coerce_for_pyspark_comparison(
                            left_inner.clone(),
                            right_inner.clone(),
                            &lit_ty,
                            &right_ty,
                            op,
                        )
                        .map_err(|e| PolarsError::ComputeError(e.to_string().into()))?
                    };
                    Expr::BinaryExpr {
                        left: Arc::new(new_left),
                        op: *op,
                        right: Arc::new(new_right),
                    }
                } else if root_is_col_vs_string {
                    let col_name = if left_is_col {
                        if let Expr::Column(n) = left_inner {
                            n.as_str()
                        } else {
                            unreachable!()
                        }
                    } else if let Expr::Column(n) = right_inner {
                        n.as_str()
                    } else {
                        unreachable!()
                    };
                    if let Some(col_dtype) = self.get_column_dtype(col_name) {
                        // Date/Datetime column compared against a string literal.
                        if matches!(col_dtype, DataType::Date | DataType::Datetime(_, _)) {
                            let (left_ty, right_ty) = if left_is_col {
                                (col_dtype.clone(), DataType::String)
                            } else {
                                (DataType::String, col_dtype.clone())
                            };
                            let (new_left, new_right) = coerce_for_pyspark_comparison(
                                left_inner.clone(),
                                right_inner.clone(),
                                &left_ty,
                                &right_ty,
                                op,
                            )
                            .map_err(|e| PolarsError::ComputeError(e.to_string().into()))?;
                            let e = Expr::BinaryExpr {
                                left: Arc::new(new_left),
                                op: *op,
                                right: Arc::new(new_right),
                            };
                            // Early return: passes 2 and 3 are skipped.
                            return Ok(wrap_expr_with_alias(e, alias_after.as_ref()));
                        }
                        // Numeric column compared against a string literal.
                        if is_numeric_public(&col_dtype) {
                            let (left_ty, right_ty) = if left_is_col {
                                (col_dtype.clone(), DataType::String)
                            } else {
                                (DataType::String, col_dtype.clone())
                            };
                            let (new_left, new_right) = coerce_for_pyspark_comparison(
                                left_inner.clone(),
                                right_inner.clone(),
                                &left_ty,
                                &right_ty,
                                op,
                            )
                            .map_err(|e| PolarsError::ComputeError(e.to_string().into()))?;
                            let e = Expr::BinaryExpr {
                                left: Arc::new(new_left),
                                op: *op,
                                right: Arc::new(new_right),
                            };
                            return Ok(wrap_expr_with_alias(e, alias_after.as_ref()));
                        }
                    }
                    expr_to_coerce.clone()
                } else if is_comparison_op && left_is_col && right_is_col {
                    let left_name = if let Expr::Column(n) = left_inner {
                        n.as_str()
                    } else {
                        unreachable!()
                    };
                    let right_name = if let Expr::Column(n) = right_inner {
                        n.as_str()
                    } else {
                        unreachable!()
                    };
                    // Column vs column: only coerce when both dtypes are known and differ;
                    // coercion failure is ignored.
                    if let (Some(left_ty), Some(right_ty)) = (
                        self.get_column_dtype(left_name),
                        self.get_column_dtype(right_name),
                    ) {
                        if left_ty != right_ty {
                            if let Ok((new_left, new_right)) = coerce_for_pyspark_comparison(
                                left_inner.clone(),
                                right_inner.clone(),
                                &left_ty,
                                &right_ty,
                                op,
                            ) {
                                let e = Expr::BinaryExpr {
                                    left: Arc::new(new_left),
                                    op: *op,
                                    right: Arc::new(new_right),
                                };
                                return Ok(wrap_expr_with_alias(e, alias_after.as_ref()));
                            }
                        }
                    }
                    expr_to_coerce.clone()
                } else {
                    expr_to_coerce.clone()
                }
            } else {
                expr_to_coerce.clone()
            }
        };
        let expr = wrap_expr_with_alias(expr, alias_after.as_ref());

        // Pass 2: tree-wide coercion of comparisons between a bare column and a literal.
        let get_col_dtype = |name: &str| self.get_column_dtype(name);
        let expr = expr.try_map_expr(move |e| {
            if let Expr::BinaryExpr { left, op, right } = e {
                let is_comparison_op = matches!(
                    op,
                    Operator::Eq
                        | Operator::NotEq
                        | Operator::Lt
                        | Operator::LtEq
                        | Operator::Gt
                        | Operator::GtEq
                );
                if !is_comparison_op {
                    return Ok(Expr::BinaryExpr { left, op, right });
                }

                let left_is_col = matches!(&*left, Expr::Column(_));
                let right_is_col = matches!(&*right, Expr::Column(_));
                let left_is_lit = matches!(&*left, Expr::Literal(_));
                let right_is_lit = matches!(&*right, Expr::Literal(_));
                let left_is_string_lit =
                    matches!(&*left, Expr::Literal(lv) if lv.get_datatype() == DataType::String);
                let right_is_string_lit =
                    matches!(&*right, Expr::Literal(lv) if lv.get_datatype() == DataType::String);

                let left_is_numeric_lit = left_is_lit && is_numeric_literal(left.as_ref());
                let right_is_numeric_lit = right_is_lit && is_numeric_literal(right.as_ref());

                let (new_left, new_right) = if left_is_col && right_is_numeric_lit {
                    let col_ty = if let Expr::Column(n) = &*left {
                        get_col_dtype(n.as_str())
                    } else {
                        None
                    };
                    let lit_ty = match &*right {
                        Expr::Literal(lv) => literal_dtype(lv),
                        _ => DataType::Float64,
                    };
                    let left_ty = col_ty.filter(is_numeric_public).unwrap_or(DataType::String);
                    coerce_for_pyspark_comparison(
                        (*left).clone(),
                        (*right).clone(),
                        &left_ty,
                        &lit_ty,
                        &op,
                    )
                    .map_err(|e| PolarsError::ComputeError(e.to_string().into()))?
                } else if right_is_col && left_is_numeric_lit {
                    let col_ty = if let Expr::Column(n) = &*right {
                        get_col_dtype(n.as_str())
                    } else {
                        None
                    };
                    let lit_ty = match &*left {
                        Expr::Literal(lv) => literal_dtype(lv),
                        _ => DataType::Float64,
                    };
                    let right_ty = col_ty.filter(is_numeric_public).unwrap_or(DataType::String);
                    coerce_for_pyspark_comparison(
                        (*left).clone(),
                        (*right).clone(),
                        &lit_ty,
                        &right_ty,
                        &op,
                    )
                    .map_err(|e| PolarsError::ComputeError(e.to_string().into()))?
                } else if (left_is_col && right_is_string_lit)
                    || (right_is_col && left_is_string_lit)
                {
                    let col_name = if left_is_col {
                        if let Expr::Column(n) = &*left {
                            n.as_str()
                        } else {
                            unreachable!()
                        }
                    } else if let Expr::Column(n) = &*right {
                        n.as_str()
                    } else {
                        unreachable!()
                    };
                    // Unlike the root pass, only Date/Datetime columns are coerced here.
                    if let Some(col_dtype) = self.get_column_dtype(col_name) {
                        if matches!(col_dtype, DataType::Date | DataType::Datetime(_, _)) {
                            let (left_ty, right_ty) = if left_is_col {
                                (col_dtype.clone(), DataType::String)
                            } else {
                                (DataType::String, col_dtype.clone())
                            };
                            let (new_l, new_r) = coerce_for_pyspark_comparison(
                                (*left).clone(),
                                (*right).clone(),
                                &left_ty,
                                &right_ty,
                                &op,
                            )
                            .map_err(|e| PolarsError::ComputeError(e.to_string().into()))?;
                            return Ok(Expr::BinaryExpr {
                                left: Arc::new(new_l),
                                op,
                                right: Arc::new(new_r),
                            });
                        }
                    }
                    return Ok(Expr::BinaryExpr { left, op, right });
                } else {
                    return Ok(Expr::BinaryExpr { left, op, right });
                };

                Ok(Expr::BinaryExpr {
                    left: Arc::new(new_left),
                    op,
                    right: Arc::new(new_right),
                })
            } else {
                Ok(e)
            }
        })?;
        // Pass 3: tree-wide coercion of arithmetic mixing String and numeric operand types.
        let expr = expr.try_map_expr(move |e| {
            if let Expr::BinaryExpr {
                ref left,
                ref op,
                ref right,
            } = e
            {
                let is_arithmetic_op = matches!(
                    op,
                    Operator::Plus
                        | Operator::Minus
                        | Operator::Multiply
                        | Operator::TrueDivide
                        | Operator::FloorDivide
                        | Operator::RustDivide
                        | Operator::Modulus
                );
                if !is_arithmetic_op {
                    return Ok(e);
                }
                // Operand type: inferred from the expression, else a bare column's schema dtype,
                // else String.
                let left_ty = crate::type_coercion::infer_type_from_expr(left.as_ref())
                    .or_else(|| {
                        if let Expr::Column(n) = &**left {
                            self.get_column_dtype(n.as_str())
                        } else {
                            None
                        }
                    })
                    .unwrap_or(DataType::String);
                let right_ty = crate::type_coercion::infer_type_from_expr(right.as_ref())
                    .or_else(|| {
                        if let Expr::Column(n) = &**right {
                            self.get_column_dtype(n.as_str())
                        } else {
                            None
                        }
                    })
                    .unwrap_or(DataType::String);
                if (left_ty == DataType::String
                    && crate::type_coercion::is_numeric_public(&right_ty))
                    || (right_ty == DataType::String
                        && crate::type_coercion::is_numeric_public(&left_ty))
                {
                    // Coercion failure is ignored; the node is left unchanged.
                    if let Ok((new_left, new_right)) =
                        crate::type_coercion::coerce_for_pyspark_arithmetic(
                            (**left).clone(),
                            (**right).clone(),
                            &left_ty,
                            &right_ty,
                        )
                    {
                        return Ok(Expr::BinaryExpr {
                            left: Arc::new(new_left),
                            op: *op,
                            right: Arc::new(new_right),
                        });
                    }
                }
            }
            Ok(e)
        })?;
        Ok(expr)
    }
868
869 fn schema_or_collect(&self) -> Result<Arc<Schema>, PolarsError> {
871 match &self.inner {
872 DataFrameInner::Eager(df) => Ok(Arc::clone(df.schema())),
873 DataFrameInner::Lazy(lf) => Ok(lf.clone().collect_schema()?),
874 }
875 }
876
877 pub(crate) fn polars_schema(&self) -> Result<Arc<Schema>, PolarsError> {
879 self.schema_or_collect()
880 }
881
882 pub fn check_ambiguous_unqualified(&self, name: &str) -> Result<(), PolarsError> {
890 if name.contains('.') {
891 return Ok(());
892 }
893 if let Some(ref ambig) = self.ambiguous_columns {
894 let found = if self.case_sensitive {
895 ambig.contains(name)
896 } else {
897 let name_lower = name.to_lowercase();
898 ambig.iter().any(|a| a.to_lowercase() == name_lower)
899 };
900 if found {
901 return Err(PolarsError::ColumnNotFound(
902 format!("Reference `{}` is ambiguous. AMBIGUOUS_REFERENCE", name).into(),
903 ));
904 }
905 }
906 Ok(())
907 }
908
909 pub fn resolve_column_name(&self, name: &str) -> Result<String, PolarsError> {
914 let schema = self.schema_or_collect()?;
915 let names: Vec<String> = schema
916 .iter_names_and_dtypes()
917 .map(|(n, _)| n.to_string())
918 .collect();
919 if !name.contains('.') {
921 let matches: Vec<&String> = if self.case_sensitive {
922 names.iter().filter(|n| n.as_str() == name).collect()
923 } else {
924 let name_lower = name.to_lowercase();
925 names
926 .iter()
927 .filter(|n| n.to_lowercase() == name_lower)
928 .collect()
929 };
930 if matches.len() > 1 {
931 return Err(PolarsError::ColumnNotFound(
932 format!("Reference `{}` is ambiguous. AMBIGUOUS_REFERENCE", name).into(),
933 ));
934 }
935 }
936 if self.case_sensitive {
937 if names.iter().any(|n| n == name) {
938 return Ok(name.to_string());
939 }
940 } else {
941 if let Some(exact) = names.iter().find(|n| n.as_str() == name) {
946 return Ok(exact.clone());
947 }
948 let name_lower = name.to_lowercase();
949 for n in &names {
950 if n.to_lowercase() == name_lower {
951 return Ok(n.clone());
952 }
953 }
954 }
955 if let Some((_prefix, suffix)) = name.split_once('.') {
958 if !suffix.is_empty() {
959 let suffix_right = format!("{}_right", suffix);
960 let matches: Vec<&String> = if self.case_sensitive {
961 names
962 .iter()
963 .filter(|n| n.as_str() == suffix || n.as_str() == suffix_right.as_str())
964 .collect()
965 } else {
966 let suffix_lower = suffix.to_lowercase();
967 let suffix_right_lower = suffix_right.to_lowercase();
968 names
969 .iter()
970 .filter(|n| {
971 let nl = n.to_lowercase();
972 nl == suffix_lower || nl == suffix_right_lower
973 })
974 .collect()
975 };
976 if matches.len() == 1 {
977 return Ok(matches[0].clone());
978 }
979 if matches.len() >= 2 {
980 let right_match = matches.iter().find(|n| {
982 if self.case_sensitive {
983 n.ends_with("_right")
984 } else {
985 n.to_lowercase().ends_with("_right")
986 }
987 });
988 if let Some(m) = right_match {
989 return Ok((*m).clone());
990 }
991 }
992 }
993 }
994 let available = names.join(", ");
995 Err(PolarsError::ColumnNotFound(
996 format!(
997 "cannot resolve: column '{}' not found. Available columns: [{}]. Check spelling and case sensitivity (spark.sql.caseSensitive).",
998 name,
999 available
1000 )
1001 .into(),
1002 ))
1003 }
1004
1005 pub fn schema(&self) -> Result<StructType, PolarsError> {
1007 let s = self.schema_or_collect()?;
1008 Ok(StructType::from_polars_schema(&s))
1009 }
1010
1011 pub fn schema_engine(&self) -> Result<StructType, EngineError> {
1013 self.schema().map_err(polars_to_core_error)
1014 }
1015
1016 pub fn get_column_dtype(&self, name: &str) -> Option<DataType> {
1019 let resolved = self.resolve_column_name(name).ok()?;
1020 let pl_schema = self.schema_or_collect().ok()?;
1021 if let Some(dt) = pl_schema.get(resolved.as_str()).cloned().or_else(|| {
1022 pl_schema
1023 .iter_names_and_dtypes()
1024 .find(|(n, _)| {
1025 let s = n.to_string();
1026 s == resolved || s.eq_ignore_ascii_case(resolved.as_str())
1027 })
1028 .map(|(_, dt)| dt.clone())
1029 }) {
1030 return Some(dt);
1031 }
1032 self.schema()
1033 .ok()?
1034 .fields()
1035 .iter()
1036 .find(|f| f.name.eq_ignore_ascii_case(resolved.as_str()))
1037 .map(|f| crate::schema_conv::data_type_to_polars_type(&f.data_type))
1038 }
1039
1040 fn resolve_struct_field_from_type(
1042 &self,
1043 struct_dtype: &DataType,
1044 field_name: &str,
1045 context_name: &str,
1046 ) -> Result<(String, DataType), PolarsError> {
1047 let fields = match struct_dtype {
1048 DataType::Struct(f) => f,
1049 _ => {
1050 return Err(PolarsError::ColumnNotFound(
1051 format!(
1052 "cannot resolve: Expected struct for nested access '{}'; got non-struct type.",
1053 context_name
1054 )
1055 .into(),
1056 ));
1057 }
1058 };
1059 if let Some(f) = fields.iter().find(|f| f.name.as_str() == field_name) {
1061 return Ok((f.name.to_string(), f.dtype.clone()));
1062 }
1063 let field_lower = field_name.to_lowercase();
1065 for f in fields {
1066 if f.name.to_string().to_lowercase() == field_lower {
1067 return Ok((f.name.to_string(), f.dtype.clone()));
1068 }
1069 }
1070 let available: Vec<String> = fields.iter().map(|f| f.name.to_string()).collect();
1071 Err(PolarsError::ColumnNotFound(
1072 format!(
1073 "cannot resolve: Struct field '{}' not found in '{}'. Available: [{}].",
1074 field_name,
1075 context_name,
1076 available.join(", ")
1077 )
1078 .into(),
1079 ))
1080 }
1081
1082 pub fn resolve_struct_field_name(
1084 &self,
1085 struct_col_name: &str,
1086 field_name: &str,
1087 ) -> Result<String, PolarsError> {
1088 let dt = self.get_column_dtype(struct_col_name).ok_or_else(|| {
1089 PolarsError::ColumnNotFound(
1090 format!("cannot resolve: column '{}' not found", struct_col_name).into(),
1091 )
1092 })?;
1093 if !matches!(dt, DataType::Struct(_)) {
1094 return Err(PolarsError::ColumnNotFound(
1095 format!(
1096 "cannot resolve: Column '{}' is not a struct; cannot access field '{}'.",
1097 struct_col_name, field_name
1098 )
1099 .into(),
1100 ));
1101 }
1102 self.resolve_struct_field_from_type(&dt, field_name, struct_col_name)
1103 .map(|(name, _)| name)
1104 }
1105
1106 fn get_expr_output_dtype(&self, expr: &Expr) -> Option<DataType> {
1109 use polars::prelude::{FunctionExpr, StructFunction};
1110 match expr {
1111 Expr::Column(name) => self.get_column_dtype(name.as_str()),
1112 Expr::Function { input, function } => {
1113 if let FunctionExpr::StructExpr(StructFunction::FieldByName(name)) = function {
1114 if let Some(first) = input.first() {
1115 let input_dt = self.get_expr_output_dtype(first)?;
1116 let (_, field_dt) = self
1117 .resolve_struct_field_from_type(&input_dt, name.as_str(), "?")
1118 .ok()?;
1119 return Some(field_dt);
1120 }
1121 }
1122 None
1123 }
1124 _ => None,
1125 }
1126 }
1127
1128 pub fn get_column_data_type(&self, name: &str) -> Option<crate::schema::DataType> {
1131 let resolved = self.resolve_column_name(name).ok()?;
1132 let st = self.schema().ok()?;
1133 st.fields()
1134 .iter()
1135 .find(|f| f.name == resolved)
1136 .map(|f| f.data_type.clone())
1137 }
1138
1139 pub fn columns(&self) -> Result<Vec<String>, PolarsError> {
1141 let schema = self.schema_or_collect()?;
1142 Ok(schema
1143 .iter_names_and_dtypes()
1144 .map(|(n, _)| n.to_string())
1145 .collect())
1146 }
1147
1148 pub fn columns_engine(&self) -> Result<Vec<String>, EngineError> {
1150 self.columns().map_err(polars_to_core_error)
1151 }
1152
1153 pub fn count(&self) -> Result<usize, PolarsError> {
1155 Ok(self.collect_inner()?.height())
1156 }
1157
1158 pub fn count_engine(&self) -> Result<usize, EngineError> {
1160 self.count().map_err(polars_to_core_error)
1161 }
1162
1163 pub fn show(&self, n: Option<usize>) -> Result<(), PolarsError> {
1165 let n = n.unwrap_or(20);
1166 let df = self.collect_inner()?;
1167 println!("{}", df.head(Some(n)));
1168 Ok(())
1169 }
1170
1171 pub fn collect(&self) -> Result<Arc<PlDataFrame>, PolarsError> {
1173 self.collect_inner()
1174 }
1175
1176 pub fn collect_as_json_rows_engine(
1178 &self,
1179 ) -> Result<Vec<HashMap<String, JsonValue>>, EngineError> {
1180 self.collect_as_json_rows().map_err(polars_to_core_error)
1181 }
1182
1183 pub fn collect_as_json_rows(&self) -> Result<Vec<HashMap<String, JsonValue>>, PolarsError> {
1185 self.collect_as_json_rows_with_names()
1186 .map(|(_, rows, _)| rows)
1187 }
1188
    /// Collect the frame and convert every row to a JSON object keyed by column name.
    ///
    /// Returns `(names, rows, schema)`:
    /// - `names`: output column names, in plan-schema order;
    /// - `rows`: one `HashMap<column name, JSON value>` per collected row;
    /// - `schema`: a `StructType` built from the dtypes actually used for serialization
    ///   (which may differ from the plan schema, see the inline notes).
    ///
    /// For a lazy frame, the plan schema comes from `collect_schema()` and can disagree
    /// with the dtypes of the columns that `collect()` actually produces. The steps
    /// below reconcile the two before converting values.
    ///
    /// # Errors
    /// Propagates Polars errors from schema resolution, collection and per-cell access,
    /// and any error returned by `any_value_to_json`. Also returns `ComputeError` if a
    /// plan-schema column is missing from the collected frame.
    #[allow(clippy::type_complexity)]
    pub fn collect_as_json_rows_with_names(
        &self,
    ) -> Result<(Vec<String>, Vec<HashMap<String, JsonValue>>, StructType), PolarsError> {
        // Materialize the data and capture the schema the plan declares. For an eager
        // frame both come from the same frame; for a lazy frame the schema is resolved
        // separately from execution.
        let (collected, plan_schema) = match &self.inner {
            DataFrameInner::Eager(df) => (df.as_ref().clone(), df.schema().as_ref().clone()),
            DataFrameInner::Lazy(lf) => {
                let plan_schema = lf.clone().collect_schema()?.as_ref().clone();
                let pl_df = lf.clone().collect()?;
                (pl_df, plan_schema)
            }
        };
        let names_and_dtypes: Vec<(String, DataType)> = plan_schema
            .iter_names_and_dtypes()
            .map(|(n, d)| (n.to_string(), d.clone()))
            .collect();
        let names: Vec<String> = names_and_dtypes.iter().map(|(n, _)| n.clone()).collect();
        let plan_dtypes: Vec<DataType> = names_and_dtypes.iter().map(|(_, d)| d.clone()).collect();
        // Name-based shape heuristics. NOTE(review): these appear to target the outputs of
        // get_json_object (columns `a`, `nested`, `missing`) and json_tuple (exactly `c0`
        // and `c1`); they match on column names only, so unrelated frames with the same
        // names are affected too.
        let has_get_json_object_shape = names.iter().any(|n| n == "a")
            && names.iter().any(|n| n == "nested")
            && names.iter().any(|n| n == "missing");
        let has_json_tuple_shape =
            names.len() == 2 && names.iter().any(|n| n == "c0") && names.iter().any(|n| n == "c1");
        // Effective dtype per column: for the shapes above, an Int64 plan dtype is treated
        // as String. If the collected column is itself Int64, the next step restores
        // Int64, so this override only matters when the collected dtype differs.
        let effective_dtypes: Vec<DataType> = names
            .iter()
            .zip(plan_dtypes.iter())
            .map(|(name, dt)| {
                let force_string = dt == &DataType::Int64
                    && ((has_json_tuple_shape && (name.as_str() == "c0" || name.as_str() == "c1"))
                        || (has_get_json_object_shape
                            && (name.as_str() == "a"
                                || name.as_str() == "nested"
                                || name.as_str() == "missing")));
                if force_string {
                    DataType::String
                } else {
                    dt.clone()
                }
            })
            .collect();
        // Serialization dtype per column. When the effective dtype is String but the
        // collected data is Int64/Float64/Boolean, keep the collected dtype so values are
        // emitted as JSON numbers/booleans. Columns absent from the collected frame keep
        // their effective dtype.
        let serialization_dtypes: Vec<DataType> = names
            .iter()
            .enumerate()
            .map(|(col_idx, name)| {
                let idx = match collected.get_column_index(name.as_str()) {
                    Some(i) => i,
                    None => {
                        return effective_dtypes
                            .get(col_idx)
                            .cloned()
                            .unwrap_or(DataType::String);
                    }
                };
                let s = &collected.columns()[idx];
                let plan_dtype = effective_dtypes
                    .get(col_idx)
                    .unwrap_or_else(|| s.dtype())
                    .clone();
                if plan_dtype == DataType::String
                    && matches!(
                        s.dtype(),
                        DataType::Int64 | DataType::Float64 | DataType::Boolean
                    )
                {
                    s.dtype().clone()
                } else {
                    plan_dtype
                }
            })
            .collect();
        // The returned schema reflects the serialization dtypes, not the raw plan schema.
        let schema_override = Schema::from_iter(
            names
                .iter()
                .zip(serialization_dtypes.iter())
                .map(|(n, d)| Field::new(n.as_str().into(), d.clone())),
        );
        let schema = StructType::from_polars_schema(&schema_override);
        // Cast each collected column to its serialization dtype. A failed cast is
        // deliberately ignored (best effort): the uncast column is used, but it is still
        // paired with the target dtype for JSON conversion.
        let columns_cast: Vec<_> = names
            .iter()
            .enumerate()
            .map(|(col_idx, name)| {
                let idx = collected.get_column_index(name.as_str()).ok_or_else(|| {
                    PolarsError::ComputeError(
                        format!("collect_as_json_rows_with_names: column '{name}' not found")
                            .into(),
                    )
                })?;
                let s = &collected.columns()[idx];
                let dtype = serialization_dtypes
                    .get(col_idx)
                    .unwrap_or_else(|| s.dtype())
                    .clone();
                if dtype == *s.dtype() {
                    Ok((s.clone(), dtype))
                } else {
                    match s.cast(&dtype) {
                        Ok(casted) => Ok((casted, dtype)),
                        Err(_) => Ok((s.clone(), dtype)),
                    }
                }
            })
            .collect::<Result<Vec<(polars::prelude::Column, DataType)>, PolarsError>>()?;
        // Row-major conversion: one map per row, one `get(i)` per cell.
        let nrows = collected.height();
        let mut rows = Vec::with_capacity(nrows);
        for i in 0..nrows {
            let mut row = HashMap::with_capacity(names.len());
            for (col_idx, name) in names.iter().enumerate() {
                let (s, dtype) = columns_cast
                    .get(col_idx)
                    .ok_or_else(|| PolarsError::ComputeError("column index out of range".into()))?;
                let av = s.get(i)?;
                let jv = any_value_to_json(&av, dtype)?;
                row.insert(name.clone(), jv);
            }
            rows.push(row);
        }
        // Opt-in diagnostic: with SPARKLESS_DEBUG_UNION=1, print the effective dtype and
        // first value of a column named `key` to stderr.
        if std::env::var("SPARKLESS_DEBUG_UNION").as_deref() == Ok("1") {
            if let Some((key_idx, _)) = names.iter().enumerate().find(|(_, n)| n.as_str() == "key")
            {
                let key_dtype = effective_dtypes.get(key_idx);
                let first_key = rows.first().and_then(|r| r.get("key"));
                eprintln!(
                    "[union #1262 collect] key effective_dtype={:?} first_row key={:?}",
                    key_dtype, first_key
                );
            }
        }
        Ok((names, rows, schema))
    }
1329
1330 pub fn to_json_rows(&self) -> Result<String, EngineError> {
1333 let rows = self.collect_as_json_rows().map_err(polars_to_core_error)?;
1334 serde_json::to_string(&rows).map_err(Into::into)
1335 }
1336
1337 pub fn select_exprs(&self, exprs: Vec<Expr>) -> Result<DataFrame, PolarsError> {
1341 transformations::select_with_exprs(self, exprs, self.case_sensitive, false)
1342 }
1343
1344 pub fn select(&self, cols: Vec<&str>) -> Result<DataFrame, PolarsError> {
1349 let all_cols = self.columns()?;
1350 let expanded: Vec<String> = cols
1351 .iter()
1352 .flat_map(|c| {
1353 if *c == "*" {
1354 all_cols.clone()
1355 } else {
1356 vec![(*c).to_string()]
1357 }
1358 })
1359 .collect();
1360 let has_dots = expanded.iter().any(|c| c.contains('.'));
1361 if has_dots {
1362 let exprs: Vec<Expr> = expanded
1363 .iter()
1364 .map(|c| {
1365 let e = self.column_name_to_expr(c)?;
1366 let last_part = c.split('.').next_back().unwrap_or(c.as_str());
1367 Ok::<Expr, PolarsError>(e.alias(last_part))
1368 })
1369 .collect::<Result<Vec<_>, PolarsError>>()?;
1370 return self.select_exprs(exprs);
1372 }
1373 let mut exprs: Vec<Expr> = Vec::with_capacity(expanded.len());
1378 for requested in &expanded {
1379 let requested_str = requested.as_str();
1380 self.check_ambiguous_unqualified(requested_str)?;
1381 let requested_lower = requested.to_lowercase();
1382 let matches: Vec<String> = all_cols
1383 .iter()
1384 .filter(|c| c.to_lowercase() == requested_lower)
1385 .cloned()
1386 .collect();
1387 if matches.len() > 1 {
1388 use polars::prelude::coalesce as pl_coalesce;
1389 let parts: Vec<Expr> = matches.iter().map(|m| col(m.as_str())).collect();
1390 let coalesced = pl_coalesce(&parts);
1391 exprs.push(coalesced.alias(requested_str));
1392 continue;
1393 }
1394 let resolved = self.resolve_column_name(requested_str)?;
1395 exprs.push(col(resolved.as_str()).alias(requested_str));
1396 }
1397 self.select_exprs(exprs)
1398 }
1399
1400 fn column_name_to_expr(&self, name: &str) -> Result<Expr, PolarsError> {
1402 self.resolve_expr_column_names(Expr::Column(PlSmallStr::from(name)))
1403 }
1404
1405 pub fn select_engine(&self, cols: Vec<&str>) -> Result<DataFrame, EngineError> {
1407 self.select(cols).map_err(polars_to_core_error)
1408 }
1409
1410 pub fn select_items(&self, items: Vec<SelectItem<'_>>) -> Result<DataFrame, PolarsError> {
1412 transformations::select_items(self, items, self.case_sensitive)
1413 }
1414
1415 pub fn filter(&self, condition: Expr) -> Result<DataFrame, PolarsError> {
1417 transformations::filter(self, condition, self.case_sensitive)
1418 }
1419
1420 pub fn filter_engine(&self, condition: Expr) -> Result<DataFrame, EngineError> {
1422 self.filter(condition).map_err(polars_to_core_error)
1423 }
1424
1425 pub fn column(&self, name: &str) -> Result<Column, PolarsError> {
1428 let resolved = self.resolve_column_name(name)?;
1429 Ok(Column::new(resolved))
1430 }
1431
1432 pub fn with_column(&self, column_name: &str, col: &Column) -> Result<DataFrame, PolarsError> {
1435 transformations::with_column(self, column_name, col, self.case_sensitive)
1436 }
1437
1438 pub fn with_column_engine(
1440 &self,
1441 column_name: &str,
1442 col: &Column,
1443 ) -> Result<DataFrame, EngineError> {
1444 self.with_column(column_name, col)
1445 .map_err(polars_to_core_error)
1446 }
1447
1448 pub fn with_column_expr(
1450 &self,
1451 column_name: &str,
1452 expr: Expr,
1453 ) -> Result<DataFrame, PolarsError> {
1454 let col = Column::from_expr(expr, None);
1455 self.with_column(column_name, &col)
1456 }
1457
1458 pub fn group_by(&self, column_names: Vec<&str>) -> Result<GroupedData, PolarsError> {
1465 use polars::prelude::*;
1466 let lf = self.lazy_frame();
1467 let (lazy_grouped, grouping_cols) = if column_names.is_empty() {
1468 let tmp_name = "_gb_global";
1470 let lf_with_key = lf.clone().with_column(lit(1i32).alias(tmp_name));
1471 let grouped = lf_with_key.clone().group_by([col(tmp_name)]);
1472 (grouped, vec![tmp_name.to_string()])
1473 } else {
1474 let resolved: Vec<String> = column_names
1475 .iter()
1476 .map(|c| self.resolve_column_name(c))
1477 .collect::<Result<Vec<_>, _>>()?;
1478 let exprs: Vec<Expr> = resolved.iter().map(|name| col(name.as_str())).collect();
1479 (lf.clone().group_by(exprs), resolved)
1480 };
1481 Ok(GroupedData {
1482 lf,
1483 lazy_grouped,
1484 grouping_cols,
1485 case_sensitive: self.case_sensitive,
1486 })
1487 }
1488
1489 pub fn group_by_engine(&self, column_names: Vec<&str>) -> Result<GroupedData, EngineError> {
1491 self.group_by(column_names).map_err(polars_to_core_error)
1492 }
1493
1494 pub fn group_by_exprs(
1497 &self,
1498 exprs: Vec<Expr>,
1499 grouping_col_names: Vec<String>,
1500 ) -> Result<GroupedData, PolarsError> {
1501 use polars::prelude::*;
1502 if exprs.len() != grouping_col_names.len() {
1503 return Err(PolarsError::ComputeError(
1504 format!(
1505 "group_by_exprs: {} exprs but {} names",
1506 exprs.len(),
1507 grouping_col_names.len()
1508 )
1509 .into(),
1510 ));
1511 }
1512 let resolved: Vec<Expr> = exprs
1513 .into_iter()
1514 .map(|e| self.resolve_expr_column_names(e))
1515 .collect::<Result<Vec<_>, _>>()?;
1516 let lf = self.lazy_frame();
1517 let lazy_grouped = lf.clone().group_by(resolved);
1518 Ok(GroupedData {
1519 lf,
1520 lazy_grouped,
1521 grouping_cols: grouping_col_names,
1522 case_sensitive: self.case_sensitive,
1523 })
1524 }
1525
1526 pub fn group_by_specs(&self, specs: Vec<GroupBySpec>) -> Result<GroupedData, PolarsError> {
1528 use polars::prelude::*;
1529 if specs.is_empty() {
1533 return self.group_by(Vec::new());
1534 }
1535 let mut exprs = Vec::with_capacity(specs.len());
1536 let mut names = Vec::with_capacity(specs.len());
1537 for spec in specs {
1538 match spec {
1539 GroupBySpec::Name(s) => {
1540 let resolved = self.resolve_column_name(s.as_str())?;
1541 exprs.push(col(resolved.as_str()));
1542 names.push(resolved);
1543 }
1544 GroupBySpec::Column(c) => {
1545 let expr = (*c).into_expr();
1546 let out_name = polars_plan::utils::expr_output_name(&expr)
1547 .map(|s| s.to_string())
1548 .unwrap_or_else(|_| "_".to_string());
1549 exprs.push(expr);
1550 names.push(out_name);
1551 }
1552 }
1553 }
1554 self.group_by_exprs(exprs, names)
1555 }
1556
1557 pub fn cube(&self, column_names: Vec<&str>) -> Result<CubeRollupData, PolarsError> {
1559 let resolved: Vec<String> = column_names
1560 .iter()
1561 .map(|c| self.resolve_column_name(c))
1562 .collect::<Result<Vec<_>, _>>()?;
1563 Ok(CubeRollupData {
1564 lf: self.lazy_frame(),
1565 grouping_cols: resolved,
1566 case_sensitive: self.case_sensitive,
1567 is_cube: true,
1568 })
1569 }
1570
1571 pub fn rollup(&self, column_names: Vec<&str>) -> Result<CubeRollupData, PolarsError> {
1573 let resolved: Vec<String> = column_names
1574 .iter()
1575 .map(|c| self.resolve_column_name(c))
1576 .collect::<Result<Vec<_>, _>>()?;
1577 Ok(CubeRollupData {
1578 lf: self.lazy_frame(),
1579 grouping_cols: resolved,
1580 case_sensitive: self.case_sensitive,
1581 is_cube: false,
1582 })
1583 }
1584
1585 pub fn agg(&self, aggregations: Vec<Expr>) -> Result<DataFrame, PolarsError> {
1589 let resolved: Vec<Expr> = aggregations
1590 .into_iter()
1591 .map(|e| self.resolve_expr_column_names(e))
1592 .collect::<Result<Vec<_>, _>>()?;
1593 let disambiguated = disambiguate_agg_output_names(resolved);
1594 let pl_df = self.lazy_frame().select(disambiguated).collect()?;
1595 Ok(Self::from_polars_with_options(pl_df, self.case_sensitive))
1596 }
1597
1598 pub fn join(
1601 &self,
1602 other: &DataFrame,
1603 on: Vec<&str>,
1604 how: JoinType,
1605 ) -> Result<DataFrame, PolarsError> {
1606 let resolved: Vec<String> = on
1607 .iter()
1608 .map(|c| self.resolve_column_name(c))
1609 .collect::<Result<Vec<_>, _>>()?;
1610 let on_refs: Vec<&str> = resolved.iter().map(|s| s.as_str()).collect();
1611 join(
1612 self,
1613 other,
1614 on_refs.clone(),
1615 on_refs,
1616 how,
1617 JoinOptions {
1618 case_sensitive: self.case_sensitive,
1619 coalesce_same_name_keys: true, mark_join_keys_ambiguous: false,
1621 origin: crate::dataframe::joins::JoinOrigin::ColumnOn,
1622 },
1623 )
1624 }
1625
1626 pub fn join_with_keys(
1629 &self,
1630 other: &DataFrame,
1631 left_on: Vec<&str>,
1632 right_on: Vec<&str>,
1633 how: JoinType,
1634 only_key_equalities: bool,
1635 ) -> Result<DataFrame, PolarsError> {
1636 let left_resolved: Vec<String> = left_on
1637 .iter()
1638 .map(|c| self.resolve_column_name(c))
1639 .collect::<Result<Vec<_>, _>>()?;
1640 let right_resolved: Vec<String> = right_on
1641 .iter()
1642 .map(|c| other.resolve_column_name(c))
1643 .collect::<Result<Vec<_>, _>>()?;
1644 let left_refs: Vec<&str> = left_resolved.iter().map(|s| s.as_str()).collect();
1645 let right_refs: Vec<&str> = right_resolved.iter().map(|s| s.as_str()).collect();
1646 let same_named_keys = left_resolved.len() == right_resolved.len()
1654 && left_resolved
1655 .iter()
1656 .zip(right_resolved.iter())
1657 .all(|(a, b)| a.eq_ignore_ascii_case(b));
1658 let coalesce_same_name_keys = same_named_keys && !only_key_equalities;
1659 let mark_join_keys_ambiguous = same_named_keys && only_key_equalities;
1661 join(
1662 self,
1663 other,
1664 left_refs,
1665 right_refs,
1666 how,
1667 JoinOptions {
1668 case_sensitive: self.case_sensitive,
1669 coalesce_same_name_keys,
1670 mark_join_keys_ambiguous,
1671 origin: crate::dataframe::joins::JoinOrigin::Condition,
1672 },
1673 )
1674 }
1675
1676 pub fn order_by(
1681 &self,
1682 column_names: Vec<&str>,
1683 ascending: Vec<bool>,
1684 ) -> Result<DataFrame, PolarsError> {
1685 let resolved: Vec<String> = column_names
1686 .iter()
1687 .map(|c| self.resolve_column_name(c))
1688 .collect::<Result<Vec<_>, _>>()?;
1689 let refs: Vec<&str> = resolved.iter().map(|s| s.as_str()).collect();
1690 transformations::order_by(self, refs, ascending, self.case_sensitive)
1691 }
1692
1693 pub fn order_by_exprs(&self, sort_orders: Vec<SortOrder>) -> Result<DataFrame, PolarsError> {
1695 transformations::order_by_exprs(self, sort_orders, self.case_sensitive)
1696 }
1697
1698 pub fn union(&self, other: &DataFrame) -> Result<DataFrame, PolarsError> {
1700 transformations::union(self, other, self.case_sensitive)
1701 }
1702
1703 pub fn union_all(&self, other: &DataFrame) -> Result<DataFrame, PolarsError> {
1705 self.union(other)
1706 }
1707
1708 pub fn union_by_name(
1710 &self,
1711 other: &DataFrame,
1712 allow_missing_columns: bool,
1713 ) -> Result<DataFrame, PolarsError> {
1714 transformations::union_by_name(self, other, allow_missing_columns, self.case_sensitive)
1715 }
1716
1717 pub fn distinct(&self, subset: Option<Vec<&str>>) -> Result<DataFrame, PolarsError> {
1719 transformations::distinct(self, subset, self.case_sensitive)
1720 }
1721
1722 pub fn drop(&self, columns: Vec<&str>) -> Result<DataFrame, PolarsError> {
1724 transformations::drop(self, columns, self.case_sensitive)
1725 }
1726
1727 pub fn dropna(
1729 &self,
1730 subset: Option<Vec<&str>>,
1731 how: &str,
1732 thresh: Option<usize>,
1733 ) -> Result<DataFrame, PolarsError> {
1734 transformations::dropna(self, subset, how, thresh, self.case_sensitive)
1735 }
1736
1737 pub fn fillna(&self, value: Expr, subset: Option<Vec<&str>>) -> Result<DataFrame, PolarsError> {
1739 transformations::fillna(self, value, subset, self.case_sensitive)
1740 }
1741
1742 pub fn limit(&self, n: usize) -> Result<DataFrame, PolarsError> {
1744 transformations::limit(self, n, self.case_sensitive)
1745 }
1746
1747 pub fn limit_engine(&self, n: usize) -> Result<DataFrame, EngineError> {
1749 self.limit(n).map_err(polars_to_core_error)
1750 }
1751
1752 pub fn with_column_renamed(
1754 &self,
1755 old_name: &str,
1756 new_name: &str,
1757 ) -> Result<DataFrame, PolarsError> {
1758 transformations::with_column_renamed(self, old_name, new_name, self.case_sensitive)
1759 }
1760
1761 pub fn replace(
1763 &self,
1764 column_name: &str,
1765 old_value: Expr,
1766 new_value: Expr,
1767 ) -> Result<DataFrame, PolarsError> {
1768 transformations::replace(self, column_name, old_value, new_value, self.case_sensitive)
1769 }
1770
1771 pub fn cross_join(&self, other: &DataFrame) -> Result<DataFrame, PolarsError> {
1773 transformations::cross_join(self, other, self.case_sensitive)
1774 }
1775
1776 pub fn describe(&self) -> Result<DataFrame, PolarsError> {
1778 transformations::describe(self, self.case_sensitive)
1779 }
1780
1781 pub fn cache(&self) -> Result<DataFrame, PolarsError> {
1783 Ok(self.clone())
1784 }
1785
1786 pub fn persist(&self) -> Result<DataFrame, PolarsError> {
1788 Ok(self.clone())
1789 }
1790
1791 pub fn unpersist(&self) -> Result<DataFrame, PolarsError> {
1793 Ok(self.clone())
1794 }
1795
1796 pub fn subtract(&self, other: &DataFrame) -> Result<DataFrame, PolarsError> {
1798 transformations::subtract(self, other, self.case_sensitive)
1799 }
1800
1801 pub fn intersect(&self, other: &DataFrame) -> Result<DataFrame, PolarsError> {
1803 transformations::intersect(self, other, self.case_sensitive)
1804 }
1805
1806 pub fn sample(
1808 &self,
1809 with_replacement: bool,
1810 fraction: f64,
1811 seed: Option<u64>,
1812 ) -> Result<DataFrame, PolarsError> {
1813 transformations::sample(self, with_replacement, fraction, seed, self.case_sensitive)
1814 }
1815
1816 pub fn random_split(
1818 &self,
1819 weights: &[f64],
1820 seed: Option<u64>,
1821 ) -> Result<Vec<DataFrame>, PolarsError> {
1822 transformations::random_split(self, weights, seed, self.case_sensitive)
1823 }
1824
1825 pub fn sample_by(
1828 &self,
1829 col_name: &str,
1830 fractions: &[(Expr, f64)],
1831 seed: Option<u64>,
1832 ) -> Result<DataFrame, PolarsError> {
1833 transformations::sample_by(self, col_name, fractions, seed, self.case_sensitive)
1834 }
1835
1836 pub fn first(&self) -> Result<DataFrame, PolarsError> {
1838 transformations::first(self, self.case_sensitive)
1839 }
1840
1841 pub fn head(&self, n: usize) -> Result<DataFrame, PolarsError> {
1843 transformations::head(self, n, self.case_sensitive)
1844 }
1845
1846 pub fn take(&self, n: usize) -> Result<DataFrame, PolarsError> {
1848 transformations::take(self, n, self.case_sensitive)
1849 }
1850
1851 pub fn tail(&self, n: usize) -> Result<DataFrame, PolarsError> {
1853 transformations::tail(self, n, self.case_sensitive)
1854 }
1855
1856 pub fn is_empty(&self) -> bool {
1858 transformations::is_empty(self)
1859 }
1860
1861 pub fn to_df(&self, names: Vec<&str>) -> Result<DataFrame, PolarsError> {
1863 transformations::to_df(self, &names, self.case_sensitive)
1864 }
1865
1866 pub fn stat(&self) -> DataFrameStat<'_> {
1868 DataFrameStat { df: self }
1869 }
1870
1871 pub fn corr(&self) -> Result<DataFrame, PolarsError> {
1873 self.stat().corr_matrix()
1874 }
1875
1876 pub fn corr_cols(&self, col1: &str, col2: &str) -> Result<f64, PolarsError> {
1878 self.stat().corr(col1, col2)
1879 }
1880
1881 pub fn cov_cols(&self, col1: &str, col2: &str) -> Result<f64, PolarsError> {
1883 self.stat().cov(col1, col2)
1884 }
1885
1886 pub fn summary(&self) -> Result<DataFrame, PolarsError> {
1888 self.describe()
1889 }
1890
1891 pub fn to_json(&self) -> Result<Vec<String>, PolarsError> {
1893 transformations::to_json(self)
1894 }
1895
1896 pub fn explain(&self) -> String {
1898 transformations::explain(self)
1899 }
1900
1901 pub fn print_schema(&self) -> Result<String, PolarsError> {
1903 transformations::print_schema(self)
1904 }
1905
1906 pub fn checkpoint(&self) -> Result<DataFrame, PolarsError> {
1908 Ok(self.clone())
1909 }
1910
1911 pub fn local_checkpoint(&self) -> Result<DataFrame, PolarsError> {
1913 Ok(self.clone())
1914 }
1915
1916 pub fn repartition(&self, _num_partitions: usize) -> Result<DataFrame, PolarsError> {
1918 Ok(self.clone())
1919 }
1920
1921 pub fn repartition_by_range(
1923 &self,
1924 _num_partitions: usize,
1925 _cols: Vec<&str>,
1926 ) -> Result<DataFrame, PolarsError> {
1927 Ok(self.clone())
1928 }
1929
1930 pub fn dtypes(&self) -> Result<Vec<(String, String)>, PolarsError> {
1932 let schema = self.schema_or_collect()?;
1933 Ok(schema
1934 .iter_names_and_dtypes()
1935 .map(|(name, dtype)| (name.to_string(), pyspark_type_name(dtype)))
1936 .collect())
1937 }
1938
1939 pub fn sort_within_partitions(
1941 &self,
1942 _cols: &[crate::functions::SortOrder],
1943 ) -> Result<DataFrame, PolarsError> {
1944 Ok(self.clone())
1945 }
1946
1947 pub fn coalesce(&self, _num_partitions: usize) -> Result<DataFrame, PolarsError> {
1949 Ok(self.clone())
1950 }
1951
1952 pub fn hint(&self, _name: &str, _params: &[i32]) -> Result<DataFrame, PolarsError> {
1954 Ok(self.clone())
1955 }
1956
1957 pub fn is_local(&self) -> bool {
1959 true
1960 }
1961
1962 pub fn input_files(&self) -> Vec<String> {
1964 Vec::new()
1965 }
1966
1967 pub fn same_semantics(&self, _other: &DataFrame) -> bool {
1969 false
1970 }
1971
1972 pub fn semantic_hash(&self) -> u64 {
1974 0
1975 }
1976
1977 pub fn observe(&self, _name: &str, _expr: Expr) -> Result<DataFrame, PolarsError> {
1979 Ok(self.clone())
1980 }
1981
1982 pub fn with_watermark(
1984 &self,
1985 _event_time: &str,
1986 _delay: &str,
1987 ) -> Result<DataFrame, PolarsError> {
1988 Ok(self.clone())
1989 }
1990
1991 pub fn select_expr(&self, exprs: &[String]) -> Result<DataFrame, PolarsError> {
1993 transformations::select_expr(self, exprs, self.case_sensitive)
1994 }
1995
1996 #[cfg(feature = "sql")]
1998 pub fn select_expr_with_session(
1999 &self,
2000 session: &SparkSession,
2001 exprs: &[String],
2002 ) -> Result<DataFrame, PolarsError> {
2003 let parsed = crate::sql::parse_select_exprs(session, self, exprs)?;
2004 self.select_exprs(parsed)
2005 }
2006
2007 pub fn col_regex(&self, pattern: &str) -> Result<DataFrame, PolarsError> {
2009 transformations::col_regex(self, pattern, self.case_sensitive)
2010 }
2011
2012 pub fn with_columns(&self, exprs: &[(String, Column)]) -> Result<DataFrame, PolarsError> {
2014 transformations::with_columns(self, exprs, self.case_sensitive)
2015 }
2016
2017 pub fn with_columns_renamed(
2019 &self,
2020 renames: &[(String, String)],
2021 ) -> Result<DataFrame, PolarsError> {
2022 transformations::with_columns_renamed(self, renames, self.case_sensitive)
2023 }
2024
2025 pub fn na(&self) -> DataFrameNa<'_> {
2027 DataFrameNa { df: self }
2028 }
2029
2030 pub fn offset(&self, n: usize) -> Result<DataFrame, PolarsError> {
2032 transformations::offset(self, n, self.case_sensitive)
2033 }
2034
2035 pub fn transform<F>(&self, f: F) -> Result<DataFrame, PolarsError>
2037 where
2038 F: FnOnce(DataFrame) -> Result<DataFrame, PolarsError>,
2039 {
2040 transformations::transform(self, f)
2041 }
2042
2043 pub fn freq_items(&self, columns: &[&str], support: f64) -> Result<DataFrame, PolarsError> {
2045 transformations::freq_items(self, columns, support, self.case_sensitive)
2046 }
2047
2048 pub fn approx_quantile(
2050 &self,
2051 column: &str,
2052 probabilities: &[f64],
2053 ) -> Result<DataFrame, PolarsError> {
2054 transformations::approx_quantile(self, column, probabilities, self.case_sensitive)
2055 }
2056
2057 pub fn crosstab(&self, col1: &str, col2: &str) -> Result<DataFrame, PolarsError> {
2059 transformations::crosstab(self, col1, col2, self.case_sensitive)
2060 }
2061
2062 pub fn melt(&self, id_vars: &[&str], value_vars: &[&str]) -> Result<DataFrame, PolarsError> {
2064 transformations::melt(self, id_vars, value_vars, self.case_sensitive)
2065 }
2066
2067 pub fn unpivot(&self, ids: &[&str], values: &[&str]) -> Result<DataFrame, PolarsError> {
2069 transformations::melt(self, ids, values, self.case_sensitive)
2070 }
2071
2072 pub fn pivot(
2074 &self,
2075 _pivot_col: &str,
2076 _values: Option<Vec<&str>>,
2077 ) -> Result<DataFrame, PolarsError> {
2078 Err(PolarsError::InvalidOperation(
2079 "pivot is not yet implemented; use crosstab(col1, col2) for two-column cross-tabulation."
2080 .into(),
2081 ))
2082 }
2083
2084 pub fn except_all(&self, other: &DataFrame) -> Result<DataFrame, PolarsError> {
2086 transformations::except_all(self, other, self.case_sensitive)
2087 }
2088
2089 pub fn intersect_all(&self, other: &DataFrame) -> Result<DataFrame, PolarsError> {
2091 transformations::intersect_all(self, other, self.case_sensitive)
2092 }
2093
2094 #[cfg(feature = "delta")]
2098 pub fn write_delta(
2099 &self,
2100 path: impl AsRef<std::path::Path>,
2101 overwrite: bool,
2102 merge_schema: bool,
2103 ) -> Result<(), PolarsError> {
2104 crate::delta::write_delta(
2105 self.collect_inner()?.as_ref(),
2106 path,
2107 overwrite,
2108 merge_schema,
2109 )
2110 }
2111
2112 #[cfg(not(feature = "delta"))]
2114 pub fn write_delta(
2115 &self,
2116 _path: impl AsRef<std::path::Path>,
2117 _overwrite: bool,
2118 _merge_schema: bool,
2119 ) -> Result<(), PolarsError> {
2120 Err(PolarsError::InvalidOperation(
2121 "Delta Lake requires the 'delta' feature. Build with --features delta.".into(),
2122 ))
2123 }
2124
2125 pub fn save_as_delta_table(&self, session: &crate::session::SparkSession, name: &str) {
2127 session.register_table(name, self.clone());
2128 }
2129
2130 pub fn write(&self) -> DataFrameWriter<'_> {
2132 DataFrameWriter {
2133 df: self,
2134 mode: WriteMode::Overwrite,
2135 format: WriteFormat::Parquet,
2136 options: HashMap::new(),
2137 partition_by: Vec::new(),
2138 }
2139 }
2140}
2141
/// How [`DataFrameWriter`] treats existing output.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WriteMode {
    /// Replace existing output.
    Overwrite,
    /// Add to existing output.
    Append,
}
2148
/// Behavior of `save_as_table` when the target table already exists.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SaveMode {
    /// Fail if the table is registered or its warehouse directory exists.
    ErrorIfExists,
    /// Replace the existing table contents.
    Overwrite,
    /// Append rows to the existing table (creating it if absent).
    Append,
    /// Do nothing if the table already exists.
    Ignore,
}
2161
/// Output file format for [`DataFrameWriter`].
///
/// Derives the same comparison and debug traits as the other writer enums
/// (`WriteMode`, `SaveMode`) so it can be compared and logged consistently.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WriteFormat {
    /// Apache Parquet.
    Parquet,
    /// Comma-separated values.
    Csv,
    /// JSON.
    Json,
}
2169
2170fn align_to_merged_schema_inline(
2173 existing: &PlDataFrame,
2174 new_df: &PlDataFrame,
2175) -> Result<(PlDataFrame, PlDataFrame), PolarsError> {
2176 use polars::prelude::*;
2177 let existing_names: Vec<String> = existing
2178 .get_column_names()
2179 .iter()
2180 .map(|s| s.as_str().to_string())
2181 .collect();
2182 let new_names: Vec<String> = new_df
2183 .get_column_names()
2184 .iter()
2185 .map(|s| s.as_str().to_string())
2186 .collect();
2187 let existing_set: HashSet<&str> = existing_names.iter().map(String::as_str).collect();
2188 let mut merged: Vec<String> = existing_names.clone();
2189 for n in &new_names {
2190 if !existing_set.contains(n.as_str()) {
2191 merged.push(n.clone());
2192 }
2193 }
2194 let n_existing = existing.height();
2195 let n_new = new_df.height();
2196 let schema_existing = existing.schema();
2197 let schema_new = new_df.schema();
2198 let name_into = |n: &String| n.as_str().into();
2199 let mut cols_existing: Vec<polars::prelude::Column> = Vec::with_capacity(merged.len());
2200 let mut cols_new: Vec<polars::prelude::Column> = Vec::with_capacity(merged.len());
2201 for name in &merged {
2202 if let Some(dtype) = schema_existing.get(name) {
2203 if let Some(idx) = existing.get_column_index(name) {
2204 cols_existing.push(existing.columns()[idx].clone());
2205 } else {
2206 cols_existing.push(Series::full_null(name_into(name), n_existing, dtype).into());
2207 }
2208 } else if let Some(dtype) = schema_new.get(name) {
2209 cols_existing.push(Series::full_null(name_into(name), n_existing, dtype).into());
2210 } else {
2211 cols_existing
2212 .push(Series::full_null(name_into(name), n_existing, &DataType::String).into());
2213 }
2214 if let Some(dtype) = schema_new.get(name) {
2215 if let Some(idx) = new_df.get_column_index(name) {
2216 cols_new.push(new_df.columns()[idx].clone());
2217 } else {
2218 cols_new.push(Series::full_null(name_into(name), n_new, dtype).into());
2219 }
2220 } else if let Some(dtype) = schema_existing.get(name) {
2221 cols_new.push(Series::full_null(name_into(name), n_new, dtype).into());
2222 } else {
2223 cols_new.push(Series::full_null(name_into(name), n_new, &DataType::String).into());
2224 }
2225 }
2226 let aligned_existing = PlDataFrame::new_infer_height(cols_existing)?;
2227 let aligned_new = PlDataFrame::new_infer_height(cols_new)?;
2228 Ok((aligned_existing, aligned_new))
2229}
2230
/// Builder for writing a [`DataFrame`], obtained from `DataFrame::write`.
pub struct DataFrameWriter<'a> {
    // Frame to be written (borrowed for the writer's lifetime).
    df: &'a DataFrame,
    // Overwrite or append behavior for the output.
    mode: WriteMode,
    // Output file format.
    format: WriteFormat,
    // Free-form writer options; `save_as_table` reads `mergeSchema` from here.
    options: HashMap<String, String>,
    // Column names set via `partition_by`. NOTE(review): how they are applied is not
    // visible in this section.
    partition_by: Vec<String>,
}
2239
2240impl<'a> DataFrameWriter<'a> {
2241 pub fn mode(mut self, mode: WriteMode) -> Self {
2242 self.mode = mode;
2243 self
2244 }
2245
2246 pub fn format(mut self, format: WriteFormat) -> Self {
2247 self.format = format;
2248 self
2249 }
2250
2251 pub fn option(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
2253 self.options.insert(key.into(), value.into());
2254 self
2255 }
2256
2257 pub fn options(mut self, opts: impl IntoIterator<Item = (String, String)>) -> Self {
2259 for (k, v) in opts {
2260 self.options.insert(k, v);
2261 }
2262 self
2263 }
2264
2265 pub fn partition_by(mut self, cols: impl IntoIterator<Item = impl Into<String>>) -> Self {
2267 self.partition_by = cols.into_iter().map(|s| s.into()).collect();
2268 self
2269 }
2270
2271 pub fn save_as_table(
2274 &self,
2275 session: &SparkSession,
2276 name: &str,
2277 mode: SaveMode,
2278 ) -> Result<(), PolarsError> {
2279 let opts: Vec<(String, String)> = self
2280 .options
2281 .iter()
2282 .map(|(k, v)| (k.clone(), v.clone()))
2283 .collect();
2284 let options = if opts.is_empty() {
2285 None
2286 } else {
2287 Some(opts.as_slice())
2288 };
2289 self.save_as_table_impl(session, name, mode, options)
2290 }
2291
2292 pub fn save_as_table_with_options(
2294 &self,
2295 session: &SparkSession,
2296 name: &str,
2297 mode: SaveMode,
2298 options: &[(String, String)],
2299 ) -> Result<(), PolarsError> {
2300 self.save_as_table_impl(
2301 session,
2302 name,
2303 mode,
2304 if options.is_empty() {
2305 None
2306 } else {
2307 Some(options)
2308 },
2309 )
2310 }
2311
2312 #[cfg(any(
2314 feature = "jdbc",
2315 feature = "jdbc_mysql",
2316 feature = "jdbc_mariadb",
2317 feature = "jdbc_mssql",
2318 feature = "jdbc_oracle",
2319 feature = "jdbc_db2",
2320 feature = "sqlite"
2321 ))]
2322 pub fn jdbc(
2323 &self,
2324 url: &str,
2325 table: &str,
2326 properties: &[(String, String)],
2327 mode: SaveMode,
2328 ) -> Result<(), crate::error::EngineError> {
2329 use crate::jdbc::{JdbcOptions, write_jdbc_from_polars};
2330 use std::collections::HashMap;
2331
2332 let mut props_map = HashMap::new();
2333 for (k, v) in properties {
2334 props_map.insert(k.clone(), v.clone());
2335 }
2336 let opts = JdbcOptions::from_url_dbtable_and_properties(
2337 url.to_string(),
2338 table.to_string(),
2339 &props_map,
2340 )?;
2341 let pl_df = self
2342 .df
2343 .collect_inner()
2344 .map_err(crate::polars_to_core_error)?;
2345 write_jdbc_from_polars(pl_df.as_ref(), &opts, mode)
2346 }
2347
    /// Register this writer's frame as session table `name`, honoring `mode`.
    ///
    /// If the session has a warehouse directory, the table is also persisted to
    /// `<warehouse_dir>/<name>/data.parquet`. An existing warehouse directory counts as
    /// "table exists" for `ErrorIfExists` and `Ignore`.
    ///
    /// Options: `mergeSchema=true` (key and value compared ASCII case-insensitively)
    /// lets `Append` add columns that the existing table lacks.
    ///
    /// # Errors
    /// - `InvalidOperation` for `ErrorIfExists` when the table exists;
    /// - `InvalidOperation` for `Append` (without `mergeSchema`) when the new frame lacks
    ///   a column of the existing table;
    /// - `ComputeError` for warehouse directory/read failures, plus any Polars error
    ///   from collecting, stacking or writing.
    fn save_as_table_impl(
        &self,
        session: &SparkSession,
        name: &str,
        mode: SaveMode,
        options: Option<&[(String, String)]>,
    ) -> Result<(), PolarsError> {
        use polars::prelude::*;
        use std::fs;
        use std::path::Path;

        let merge_schema = options.is_some_and(|opts| {
            opts.iter().any(|(k, v)| {
                k.eq_ignore_ascii_case("mergeSchema") && v.eq_ignore_ascii_case("true")
            })
        });

        // `<warehouse_dir>/<name>` when the session has a warehouse; None otherwise.
        let warehouse_path = session.warehouse_dir().map(|w| Path::new(w).join(name));
        let warehouse_exists = warehouse_path.as_ref().is_some_and(|p| p.is_dir());

        // Write `df` to `<dir>/data.parquet`, creating `dir` if needed.
        fn persist_to_warehouse(
            df: &crate::dataframe::DataFrame,
            dir: &Path,
        ) -> Result<(), PolarsError> {
            use std::fs;
            fs::create_dir_all(dir).map_err(|e| {
                PolarsError::ComputeError(format!("saveAsTable: create dir: {e}").into())
            })?;
            let file_path = dir.join("data.parquet");
            df.write()
                .mode(crate::dataframe::WriteMode::Overwrite)
                .format(crate::dataframe::WriteFormat::Parquet)
                .save(&file_path)
        }

        let final_df = match mode {
            SaveMode::ErrorIfExists => {
                if session.saved_table_exists(name) || warehouse_exists {
                    return Err(PolarsError::InvalidOperation(
                        format!(
                            "Table or view '{name}' already exists. SaveMode is ErrorIfExists."
                        )
                        .into(),
                    ));
                }
                if let Some(ref p) = warehouse_path {
                    persist_to_warehouse(self.df, p)?;
                }
                self.df.clone()
            }
            SaveMode::Overwrite => {
                if let Some(ref p) = warehouse_path {
                    // Best-effort removal (errors ignored). The old data is deleted before
                    // the new write, so a failed write leaves nothing on disk.
                    let _ = fs::remove_dir_all(p);
                    persist_to_warehouse(self.df, p)?;
                }
                self.df.clone()
            }
            SaveMode::Append => {
                // Base data: the registered session table if any, otherwise the
                // warehouse copy (data.parquet, or the directory itself if that file is
                // absent), otherwise this is a first write.
                let existing_pl = if let Some(existing) = session.get_saved_table(name) {
                    existing.collect_inner()?.as_ref().clone()
                } else if let (Some(ref p), true) = (warehouse_path.as_ref(), warehouse_exists) {
                    let data_file = p.join("data.parquet");
                    let read_path = if data_file.is_file() {
                        data_file.as_path()
                    } else {
                        p.as_ref()
                    };
                    let pl_path =
                        polars::prelude::PlRefPath::try_from_path(read_path).map_err(|e| {
                            PolarsError::ComputeError(
                                format!("saveAsTable append: path: {e}").into(),
                            )
                        })?;
                    let lf = LazyFrame::scan_parquet(pl_path, ScanArgsParquet::default()).map_err(
                        |e| {
                            PolarsError::ComputeError(
                                format!("saveAsTable append: read warehouse: {e}").into(),
                            )
                        },
                    )?;
                    // Collected eagerly, so the data is in memory before the directory
                    // is rewritten below.
                    lf.collect().map_err(|e| {
                        PolarsError::ComputeError(
                            format!("saveAsTable append: collect: {e}").into(),
                        )
                    })?
                } else {
                    // First write: register, then persist. If persisting fails, the table
                    // stays registered and the error is returned.
                    session.register_table(name, self.df.clone());
                    if let Some(ref p) = warehouse_path {
                        persist_to_warehouse(self.df, p)?;
                    }
                    return Ok(());
                };
                let new_pl = self.df.collect_inner()?.as_ref().clone();
                let merged = if merge_schema {
                    // Union of both column sets, null-filled where a side lacks a column.
                    let (aligned_existing, aligned_new) =
                        align_to_merged_schema_inline(&existing_pl, &new_pl)?;
                    let mut out = aligned_existing;
                    out.vstack_mut(&aligned_new)?;
                    crate::dataframe::DataFrame::from_polars_with_options(
                        out,
                        self.df.case_sensitive,
                    )
                } else {
                    // Strict append: the new frame must contain every existing column
                    // (exact name match). It is reordered to the existing layout, and any
                    // extra columns are dropped by the `select`.
                    let existing_cols: Vec<&str> = existing_pl
                        .get_column_names()
                        .iter()
                        .map(|s| s.as_str())
                        .collect();
                    let new_cols = new_pl.get_column_names();
                    let missing: Vec<_> = existing_cols
                        .iter()
                        .filter(|c| !new_cols.iter().any(|n| n.as_str() == **c))
                        .collect();
                    if !missing.is_empty() {
                        return Err(PolarsError::InvalidOperation(
                            format!(
                                "saveAsTable append: new DataFrame missing columns: {:?}",
                                missing
                            )
                            .into(),
                        ));
                    }
                    let new_ordered = new_pl.select(existing_cols.iter().copied())?;
                    let mut combined = existing_pl;
                    combined.vstack_mut(&new_ordered)?;
                    crate::dataframe::DataFrame::from_polars_with_options(
                        combined,
                        self.df.case_sensitive,
                    )
                };
                // Rewrite the whole warehouse copy with the combined data.
                if let Some(ref p) = warehouse_path {
                    let _ = fs::remove_dir_all(p);
                    persist_to_warehouse(&merged, p)?;
                }
                merged
            }
            SaveMode::Ignore => {
                if session.saved_table_exists(name) || warehouse_exists {
                    return Ok(());
                }
                if let Some(ref p) = warehouse_path {
                    persist_to_warehouse(self.df, p)?;
                }
                self.df.clone()
            }
        };
        session.register_table(name, final_df);
        Ok(())
    }
2499
2500 pub fn parquet(&self, path: impl AsRef<std::path::Path>) -> Result<(), PolarsError> {
2502 DataFrameWriter {
2503 df: self.df,
2504 mode: self.mode,
2505 format: WriteFormat::Parquet,
2506 options: self.options.clone(),
2507 partition_by: self.partition_by.clone(),
2508 }
2509 .save(path)
2510 }
2511
2512 pub fn csv(&self, path: impl AsRef<std::path::Path>) -> Result<(), PolarsError> {
2514 DataFrameWriter {
2515 df: self.df,
2516 mode: self.mode,
2517 format: WriteFormat::Csv,
2518 options: self.options.clone(),
2519 partition_by: self.partition_by.clone(),
2520 }
2521 .save(path)
2522 }
2523
2524 pub fn json(&self, path: impl AsRef<std::path::Path>) -> Result<(), PolarsError> {
2526 DataFrameWriter {
2527 df: self.df,
2528 mode: self.mode,
2529 format: WriteFormat::Json,
2530 options: self.options.clone(),
2531 partition_by: self.partition_by.clone(),
2532 }
2533 .save(path)
2534 }
2535
2536 pub fn save(&self, path: impl AsRef<std::path::Path>) -> Result<(), PolarsError> {
2539 use polars::prelude::*;
2540 let path = path.as_ref();
2541 let to_write: PlDataFrame = match self.mode {
2542 WriteMode::Overwrite => self.df.collect_inner()?.as_ref().clone(),
2543 WriteMode::Append => {
2544 if self.partition_by.is_empty() {
2545 let existing: Option<PlDataFrame> = if path.exists() && path.is_file() {
2546 match self.format {
2547 WriteFormat::Parquet => polars::prelude::PlRefPath::try_from_path(path)
2548 .ok()
2549 .and_then(|pl_path| {
2550 LazyFrame::scan_parquet(pl_path, ScanArgsParquet::default())
2551 .and_then(|lf| lf.collect())
2552 .ok()
2553 }),
2554 WriteFormat::Csv => polars::prelude::PlRefPath::try_from_path(path)
2555 .ok()
2556 .and_then(|pl_path| {
2557 LazyCsvReader::new(pl_path)
2558 .with_has_header(true)
2559 .finish()
2560 .and_then(|lf| lf.collect())
2561 .ok()
2562 }),
2563 WriteFormat::Json => polars::prelude::PlRefPath::try_from_path(path)
2564 .ok()
2565 .and_then(|pl_path| {
2566 LazyJsonLineReader::new(pl_path)
2567 .finish()
2568 .and_then(|lf| lf.collect())
2569 .ok()
2570 }),
2571 }
2572 } else {
2573 None
2574 };
2575 match existing {
2576 Some(existing) => {
2577 let lfs: [LazyFrame; 2] = [
2578 existing.clone().lazy(),
2579 self.df.collect_inner()?.as_ref().clone().lazy(),
2580 ];
2581 concat(lfs, UnionArgs::default())?.collect()?
2582 }
2583 None => self.df.collect_inner()?.as_ref().clone(),
2584 }
2585 } else {
2586 self.df.collect_inner()?.as_ref().clone()
2587 }
2588 }
2589 };
2590
2591 if !self.partition_by.is_empty() {
2592 return self.save_partitioned(path, &to_write);
2593 }
2594
2595 match self.format {
2596 WriteFormat::Parquet => {
2597 let mut file = std::fs::File::create(path).map_err(|e| {
2598 PolarsError::ComputeError(format!("write parquet create: {e}").into())
2599 })?;
2600 let mut df_mut = to_write;
2601 ParquetWriter::new(&mut file)
2602 .finish(&mut df_mut)
2603 .map_err(|e| PolarsError::ComputeError(format!("write parquet: {e}").into()))?;
2604 }
2605 WriteFormat::Csv => {
2606 let has_header = self
2607 .options
2608 .get("header")
2609 .map(|v| v.eq_ignore_ascii_case("true") || v == "1")
2610 .unwrap_or(true);
2611 let delimiter = self
2612 .options
2613 .get("sep")
2614 .and_then(|s| s.bytes().next())
2615 .unwrap_or(b',');
2616 let mut file = std::fs::File::create(path).map_err(|e| {
2617 PolarsError::ComputeError(format!("write csv create: {e}").into())
2618 })?;
2619 CsvWriter::new(&mut file)
2620 .include_header(has_header)
2621 .with_separator(delimiter)
2622 .finish(&mut to_write.clone())
2623 .map_err(|e| PolarsError::ComputeError(format!("write csv: {e}").into()))?;
2624 }
2625 WriteFormat::Json => {
2626 let mut file = std::fs::File::create(path).map_err(|e| {
2627 PolarsError::ComputeError(format!("write json create: {e}").into())
2628 })?;
2629 JsonWriter::new(&mut file)
2630 .finish(&mut to_write.clone())
2631 .map_err(|e| PolarsError::ComputeError(format!("write json: {e}").into()))?;
2632 }
2633 }
2634 Ok(())
2635 }
2636
2637 fn save_partitioned(&self, path: &Path, to_write: &PlDataFrame) -> Result<(), PolarsError> {
2639 use polars::prelude::*;
2640 let resolved: Vec<String> = self
2641 .partition_by
2642 .iter()
2643 .map(|c| self.df.resolve_column_name(c))
2644 .collect::<Result<Vec<_>, _>>()?;
2645 let all_names = to_write.get_column_names();
2646 let data_cols: Vec<&str> = all_names
2647 .iter()
2648 .filter(|n| !resolved.iter().any(|r| r == n.as_str()))
2649 .map(|n| n.as_str())
2650 .collect();
2651
2652 let unique_keys = to_write
2653 .select(resolved.iter().map(|s| s.as_str()).collect::<Vec<_>>())?
2654 .unique::<Option<&[String]>, String>(
2655 None,
2656 polars::prelude::UniqueKeepStrategy::First,
2657 None,
2658 )?;
2659
2660 if self.mode == WriteMode::Overwrite && path.exists() {
2661 if path.is_dir() {
2662 std::fs::remove_dir_all(path).map_err(|e| {
2663 PolarsError::ComputeError(
2664 format!("write partitioned: remove_dir_all: {e}").into(),
2665 )
2666 })?;
2667 } else {
2668 std::fs::remove_file(path).map_err(|e| {
2669 PolarsError::ComputeError(format!("write partitioned: remove_file: {e}").into())
2670 })?;
2671 }
2672 }
2673 std::fs::create_dir_all(path).map_err(|e| {
2674 PolarsError::ComputeError(format!("write partitioned: create_dir_all: {e}").into())
2675 })?;
2676
2677 let ext = match self.format {
2678 WriteFormat::Parquet => "parquet",
2679 WriteFormat::Csv => "csv",
2680 WriteFormat::Json => "json",
2681 };
2682
2683 for row_idx in 0..unique_keys.height() {
2684 let row = unique_keys
2685 .get(row_idx)
2686 .ok_or_else(|| PolarsError::ComputeError("partition_row: get row".into()))?;
2687 let filter_expr = partition_row_to_filter_expr(&resolved, &row)?;
2688 let subset = to_write.clone().lazy().filter(filter_expr).collect()?;
2689 let subset = subset.select(data_cols.iter().copied())?;
2690 if subset.height() == 0 {
2691 continue;
2692 }
2693
2694 let part_path: std::path::PathBuf = resolved
2695 .iter()
2696 .zip(row.iter())
2697 .map(|(name, av)| format!("{}={}", name, format_partition_value(av)))
2698 .fold(path.to_path_buf(), |p, seg| p.join(seg));
2699 std::fs::create_dir_all(&part_path).map_err(|e| {
2700 PolarsError::ComputeError(
2701 format!("write partitioned: create_dir_all partition: {e}").into(),
2702 )
2703 })?;
2704
2705 let file_idx = if self.mode == WriteMode::Append {
2706 let suffix = format!(".{ext}");
2707 let max_n = std::fs::read_dir(&part_path)
2708 .map(|rd| {
2709 rd.filter_map(Result::ok)
2710 .filter_map(|e| {
2711 e.file_name().to_str().and_then(|s| {
2712 s.strip_prefix("part-")
2713 .and_then(|t| t.strip_suffix(&suffix))
2714 .and_then(|t| t.parse::<u32>().ok())
2715 })
2716 })
2717 .max()
2718 .unwrap_or(0)
2719 })
2720 .unwrap_or(0);
2721 max_n + 1
2722 } else {
2723 0
2724 };
2725 let filename = format!("part-{file_idx:05}.{ext}");
2726 let file_path = part_path.join(&filename);
2727
2728 match self.format {
2729 WriteFormat::Parquet => {
2730 let mut file = std::fs::File::create(&file_path).map_err(|e| {
2731 PolarsError::ComputeError(
2732 format!("write partitioned parquet create: {e}").into(),
2733 )
2734 })?;
2735 let mut df_mut = subset;
2736 ParquetWriter::new(&mut file)
2737 .finish(&mut df_mut)
2738 .map_err(|e| {
2739 PolarsError::ComputeError(
2740 format!("write partitioned parquet: {e}").into(),
2741 )
2742 })?;
2743 }
2744 WriteFormat::Csv => {
2745 let has_header = self
2746 .options
2747 .get("header")
2748 .map(|v| v.eq_ignore_ascii_case("true") || v == "1")
2749 .unwrap_or(true);
2750 let delimiter = self
2751 .options
2752 .get("sep")
2753 .and_then(|s| s.bytes().next())
2754 .unwrap_or(b',');
2755 let mut file = std::fs::File::create(&file_path).map_err(|e| {
2756 PolarsError::ComputeError(
2757 format!("write partitioned csv create: {e}").into(),
2758 )
2759 })?;
2760 CsvWriter::new(&mut file)
2761 .include_header(has_header)
2762 .with_separator(delimiter)
2763 .finish(&mut subset.clone())
2764 .map_err(|e| {
2765 PolarsError::ComputeError(format!("write partitioned csv: {e}").into())
2766 })?;
2767 }
2768 WriteFormat::Json => {
2769 let mut file = std::fs::File::create(&file_path).map_err(|e| {
2770 PolarsError::ComputeError(
2771 format!("write partitioned json create: {e}").into(),
2772 )
2773 })?;
2774 JsonWriter::new(&mut file)
2775 .finish(&mut subset.clone())
2776 .map_err(|e| {
2777 PolarsError::ComputeError(format!("write partitioned json: {e}").into())
2778 })?;
2779 }
2780 }
2781 }
2782 Ok(())
2783 }
2784}
2785
2786impl Clone for DataFrame {
2787 fn clone(&self) -> Self {
2788 DataFrame {
2789 inner: match &self.inner {
2790 DataFrameInner::Eager(df) => DataFrameInner::Eager(df.clone()),
2791 DataFrameInner::Lazy(lf) => DataFrameInner::Lazy(lf.clone()),
2792 },
2793 case_sensitive: self.case_sensitive,
2794 alias: self.alias.clone(),
2795 ambiguous_columns: self.ambiguous_columns.clone(),
2796 }
2797 }
2798}
2799
2800fn format_partition_value(av: &AnyValue<'_>) -> String {
2803 let s = match av {
2804 AnyValue::Null => "__HIVE_DEFAULT_PARTITION__".to_string(),
2805 AnyValue::Boolean(b) => b.to_string(),
2806 AnyValue::Int32(i) => i.to_string(),
2807 AnyValue::Int64(i) => i.to_string(),
2808 AnyValue::UInt32(u) => u.to_string(),
2809 AnyValue::UInt64(u) => u.to_string(),
2810 AnyValue::Float32(f) => f.to_string(),
2811 AnyValue::Float64(f) => f.to_string(),
2812 AnyValue::String(s) => s.to_string(),
2813 AnyValue::StringOwned(s) => s.as_str().to_string(),
2814 AnyValue::Date(d) => d.to_string(),
2815 _ => av.to_string(),
2816 };
2817 s.replace([std::path::MAIN_SEPARATOR, '/'], "_")
2819}
2820
2821fn partition_row_to_filter_expr(
2823 col_names: &[String],
2824 row: &[AnyValue<'_>],
2825) -> Result<Expr, PolarsError> {
2826 if col_names.len() != row.len() {
2827 return Err(PolarsError::ComputeError(
2828 format!(
2829 "partition_row_to_filter_expr: {} columns but {} row values",
2830 col_names.len(),
2831 row.len()
2832 )
2833 .into(),
2834 ));
2835 }
2836 let mut pred = None::<Expr>;
2837 for (name, av) in col_names.iter().zip(row.iter()) {
2838 let clause = match av {
2839 AnyValue::Null => col(name.as_str()).is_null(),
2840 AnyValue::Boolean(b) => col(name.as_str()).eq(lit(*b)),
2841 AnyValue::Int32(i) => col(name.as_str()).eq(lit(*i)),
2842 AnyValue::Int64(i) => col(name.as_str()).eq(lit(*i)),
2843 AnyValue::UInt32(u) => col(name.as_str()).eq(lit(*u)),
2844 AnyValue::UInt64(u) => col(name.as_str()).eq(lit(*u)),
2845 AnyValue::Float32(f) => col(name.as_str()).eq(lit(*f)),
2846 AnyValue::Float64(f) => col(name.as_str()).eq(lit(*f)),
2847 AnyValue::String(s) => col(name.as_str()).eq(lit(s.to_string())),
2848 AnyValue::StringOwned(s) => col(name.as_str()).eq(lit(s.clone())),
2849 _ => {
2850 let s = av.to_string();
2852 col(name.as_str()).cast(DataType::String).eq(lit(s))
2853 }
2854 };
2855 pred = Some(match pred {
2856 None => clause,
2857 Some(p) => p.and(clause),
2858 });
2859 }
2860 Ok(pred.unwrap_or_else(|| lit(true)))
2861}
2862
2863fn is_map_format(dtype: &DataType) -> bool {
2865 if let DataType::List(inner) = dtype {
2866 if let DataType::Struct(fields) = inner.as_ref() {
2867 let has_key = fields.iter().any(|f| f.name == "key");
2868 let has_value = fields.iter().any(|f| f.name == "value");
2869 return has_key && has_value;
2870 }
2871 }
2872 false
2873}
2874
2875fn map_value_string_to_json(s: &str) -> JsonValue {
2879 JsonValue::String(s.to_string())
2880}
2881
2882fn float_to_json_number(f: f64) -> JsonValue {
2885 const EPSILON: f64 = 1e-6;
2886 if f.is_finite() {
2887 let r = f.round();
2888 if (f - r).abs() < EPSILON {
2889 if let Some(n) = serde_json::Number::from_f64(r) {
2890 return JsonValue::Number(n);
2891 }
2892 }
2893 }
2894 serde_json::Number::from_f64(f)
2895 .map(JsonValue::Number)
2896 .unwrap_or(JsonValue::Null)
2897}
2898
2899fn date_days_to_json(days: i32) -> JsonValue {
2901 let epoch = robin_sparkless_core::date_utils::epoch_naive_date();
2902 epoch
2903 .checked_add_signed(chrono::TimeDelta::days(days as i64))
2904 .map(|d| JsonValue::String(d.format("%Y-%m-%d").to_string()))
2905 .unwrap_or(JsonValue::Null)
2906}
2907
2908fn datetime_anyvalue_to_json_iso(val: i64, unit: &TimeUnit) -> JsonValue {
2910 let micros = match unit {
2911 TimeUnit::Nanoseconds => val.checked_div(1000),
2912 TimeUnit::Microseconds => Some(val),
2913 TimeUnit::Milliseconds => val.checked_mul(1000),
2914 };
2915 micros
2916 .and_then(chrono::DateTime::from_timestamp_micros)
2917 .map(|dt| JsonValue::String(dt.format("%Y-%m-%dT%H:%M:%S%.6f").to_string()))
2918 .unwrap_or(JsonValue::Null)
2919}
2920
2921fn struct_string_to_json_object(s: &str, fields: &[Field]) -> Option<JsonValue> {
2924 use serde_json::Map;
2925 if fields.is_empty() {
2926 return None;
2927 }
2928 let trimmed = s.trim();
2929 let inner = trimmed
2930 .strip_prefix('{')
2931 .and_then(|t| t.strip_suffix('}'))
2932 .map(|t| t.trim())
2933 .unwrap_or(trimmed);
2934 let mut obj = Map::new();
2935 if fields.len() == 1 {
2936 let f = &fields[0];
2937 let val = match &f.dtype {
2938 DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => inner
2939 .parse::<i64>()
2940 .ok()
2941 .map(serde_json::Number::from)
2942 .map(JsonValue::Number),
2943 DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => inner
2944 .parse::<u64>()
2945 .ok()
2946 .map(serde_json::Number::from)
2947 .map(JsonValue::Number),
2948 DataType::Float32 | DataType::Float64 => inner
2949 .parse::<f64>()
2950 .ok()
2951 .filter(|f| f.is_finite())
2952 .and_then(|f| serde_json::Number::from_f64(f).map(JsonValue::Number)),
2953 DataType::String => Some(JsonValue::String(
2954 inner
2955 .strip_prefix('"')
2956 .and_then(|t| t.strip_suffix('"'))
2957 .unwrap_or(inner)
2958 .to_string(),
2959 )),
2960 DataType::Boolean => {
2961 if inner.eq_ignore_ascii_case("true") {
2962 Some(JsonValue::Bool(true))
2963 } else if inner.eq_ignore_ascii_case("false") {
2964 Some(JsonValue::Bool(false))
2965 } else {
2966 None
2967 }
2968 }
2969 _ => None,
2970 }?;
2971 obj.insert(f.name.to_string(), val);
2972 return Some(JsonValue::Object(obj));
2973 }
2974 let parts: Vec<&str> = inner.splitn(fields.len(), ", ").map(|p| p.trim()).collect();
2976 for (i, f) in fields.iter().enumerate() {
2977 let part = parts.get(i).unwrap_or(&"").trim();
2978 let part_unescaped = part
2979 .strip_prefix('"')
2980 .and_then(|t| t.strip_suffix('"'))
2981 .unwrap_or(part);
2982 let val = if part.is_empty() || (part_unescaped.is_empty() && part != "\"\"") {
2983 JsonValue::Null
2984 } else {
2985 match &f.dtype {
2986 DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => part
2987 .parse::<i64>()
2988 .ok()
2989 .map(serde_json::Number::from)
2990 .map(JsonValue::Number)
2991 .unwrap_or(JsonValue::Null),
2992 DataType::Float32 | DataType::Float64 => part
2993 .parse::<f64>()
2994 .ok()
2995 .filter(|x| x.is_finite())
2996 .and_then(serde_json::Number::from_f64)
2997 .map(JsonValue::Number)
2998 .unwrap_or(JsonValue::Null),
2999 DataType::String => JsonValue::String(part_unescaped.to_string()),
3000 DataType::Boolean => {
3001 if part.eq_ignore_ascii_case("true") {
3002 JsonValue::Bool(true)
3003 } else if part.eq_ignore_ascii_case("false") {
3004 JsonValue::Bool(false)
3005 } else {
3006 JsonValue::Null
3007 }
3008 }
3009 _ => JsonValue::Null,
3010 }
3011 };
3012 obj.insert(f.name.to_string(), val);
3013 }
3014 Some(JsonValue::Object(obj))
3015}
3016
3017fn any_value_to_json(av: &AnyValue<'_>, dtype: &DataType) -> Result<JsonValue, PolarsError> {
3023 use serde_json::Map;
3024 if matches!(dtype, DataType::String) {
3026 if let Some(s) = av.get_str() {
3027 return Ok(JsonValue::String(s.to_string()));
3028 }
3029 if matches!(
3030 av,
3031 AnyValue::Int8(_)
3032 | AnyValue::Int16(_)
3033 | AnyValue::Int32(_)
3034 | AnyValue::Int64(_)
3035 | AnyValue::UInt8(_)
3036 | AnyValue::UInt16(_)
3037 | AnyValue::UInt32(_)
3038 | AnyValue::UInt64(_)
3039 | AnyValue::Float32(_)
3040 | AnyValue::Float64(_)
3041 | AnyValue::Boolean(_)
3042 ) {
3043 return Ok(JsonValue::String(av.to_string()));
3044 }
3045 }
3046 Ok(match av {
3047 AnyValue::Null => JsonValue::Null,
3048 AnyValue::Boolean(b) => JsonValue::Bool(*b),
3049 AnyValue::Int32(i) if matches!(dtype, DataType::Date) => date_days_to_json(*i),
3051 AnyValue::Int64(i) if matches!(dtype, DataType::Datetime(_, _)) => match dtype {
3052 DataType::Datetime(unit, _) => datetime_anyvalue_to_json_iso(*i, unit),
3053 _ => datetime_anyvalue_to_json_iso(*i, &TimeUnit::Microseconds),
3054 },
3055 AnyValue::Int32(i) if matches!(dtype, DataType::Float32 | DataType::Float64) => {
3057 float_to_json_number(*i as f64)
3058 }
3059 AnyValue::Int64(i) if matches!(dtype, DataType::Float32 | DataType::Float64) => {
3060 float_to_json_number(*i as f64)
3061 }
3062 AnyValue::Int8(i) => JsonValue::Number(serde_json::Number::from(*i)),
3063 AnyValue::Int16(i) => JsonValue::Number(serde_json::Number::from(*i)),
3064 AnyValue::Int32(i) => JsonValue::Number(serde_json::Number::from(*i)),
3065 AnyValue::Int64(i) => JsonValue::Number(serde_json::Number::from(*i)),
3066 AnyValue::UInt8(u) => JsonValue::Number(serde_json::Number::from(*u)),
3067 AnyValue::UInt16(u) => JsonValue::Number(serde_json::Number::from(*u)),
3068 AnyValue::UInt32(u) => JsonValue::Number(serde_json::Number::from(*u)),
3069 AnyValue::UInt64(u) => JsonValue::Number(serde_json::Number::from(*u)),
3070 AnyValue::Float32(f) => float_to_json_number(f64::from(*f)),
3071 AnyValue::Float64(f) => float_to_json_number(*f),
3072 AnyValue::String(s) => {
3073 if matches!(dtype, DataType::List(_)) {
3075 if let Ok(parsed) = serde_json::from_str::<JsonValue>(s) {
3076 if parsed.is_array() {
3077 parsed
3078 } else {
3079 JsonValue::String(s.to_string())
3080 }
3081 } else {
3082 JsonValue::String(s.to_string())
3083 }
3084 } else if let DataType::Struct(fields) = dtype {
3085 if let Ok(parsed) = serde_json::from_str::<JsonValue>(s) {
3086 if parsed.is_object() {
3087 parsed
3088 } else if let Some(obj) = struct_string_to_json_object(s, fields) {
3089 obj
3090 } else {
3091 JsonValue::String(s.to_string())
3092 }
3093 } else if let Some(obj) = struct_string_to_json_object(s, fields) {
3094 obj
3095 } else {
3096 JsonValue::String(s.to_string())
3097 }
3098 } else {
3099 JsonValue::String(s.to_string())
3100 }
3101 }
3102 AnyValue::StringOwned(s) => {
3103 let s_ref = s.as_ref();
3104 if matches!(dtype, DataType::List(_)) {
3105 if let Ok(parsed) = serde_json::from_str::<JsonValue>(s_ref) {
3106 if parsed.is_array() {
3107 parsed
3108 } else {
3109 JsonValue::String(s_ref.to_string())
3110 }
3111 } else {
3112 JsonValue::String(s_ref.to_string())
3113 }
3114 } else if let DataType::Struct(fields) = dtype {
3115 if let Ok(parsed) = serde_json::from_str::<JsonValue>(s_ref) {
3116 if parsed.is_object() {
3117 parsed
3118 } else if let Some(obj) = struct_string_to_json_object(s_ref, fields) {
3119 obj
3120 } else {
3121 JsonValue::String(s_ref.to_string())
3122 }
3123 } else if let Some(obj) = struct_string_to_json_object(s_ref, fields) {
3124 obj
3125 } else {
3126 JsonValue::String(s_ref.to_string())
3127 }
3128 } else {
3129 JsonValue::String(s_ref.to_string())
3130 }
3131 }
3132 AnyValue::List(s) => {
3133 if is_map_format(dtype) {
3134 let mut entries: Vec<(String, JsonValue)> = Vec::new();
3138 let mut has_string_value = false;
3139 let mut has_numeric_or_bool_value = false;
3140 for i in 0..s.len() {
3141 if let Ok(elem) = s.get(i) {
3142 let (k, v) = match &elem {
3143 AnyValue::Struct(_, _, fields) => {
3144 let mut k = None;
3145 let mut v = None;
3146 for (fld_av, fld) in elem._iter_struct_av().zip(fields.iter()) {
3147 if fld.name == "key" {
3148 if matches!(fld_av, AnyValue::Null) {
3149 return Err(PolarsError::ComputeError(
3150 "Cannot create map with null key (PySpark: NULL_MAP_KEY)".into(),
3151 ));
3152 }
3153 k = fld_av
3154 .get_str()
3155 .map(|s| s.to_string())
3156 .or_else(|| Some(fld_av.to_string()));
3157 } else if fld.name == "value" {
3158 v = Some(if matches!(fld.dtype, DataType::String) {
3159 if let Some(s) = fld_av.get_str() {
3160 map_value_string_to_json(s)
3161 } else {
3162 any_value_to_json(&fld_av, &fld.dtype)?
3163 }
3164 } else {
3165 any_value_to_json(&fld_av, &fld.dtype)?
3166 });
3167 }
3168 }
3169 (k, v)
3170 }
3171 AnyValue::StructOwned(payload) => {
3172 let (values, fields) = &**payload;
3173 let mut k = None;
3174 let mut v = None;
3175 for (fld_av, fld) in values.iter().zip(fields.iter()) {
3176 if fld.name == "key" {
3177 if matches!(fld_av, AnyValue::Null) {
3178 return Err(PolarsError::ComputeError(
3179 "Cannot create map with null key (PySpark: NULL_MAP_KEY)".into(),
3180 ));
3181 }
3182 k = fld_av
3183 .get_str()
3184 .map(|s| s.to_string())
3185 .or_else(|| Some(fld_av.to_string()));
3186 } else if fld.name == "value" {
3187 v = Some(if matches!(fld.dtype, DataType::String) {
3188 if let Some(s) = fld_av.get_str() {
3189 map_value_string_to_json(s)
3190 } else {
3191 any_value_to_json(fld_av, &fld.dtype)?
3192 }
3193 } else {
3194 any_value_to_json(fld_av, &fld.dtype)?
3195 });
3196 }
3197 }
3198 (k, v)
3199 }
3200 _ => (None, None),
3201 };
3202 if let (Some(key), Some(val)) = (k, v) {
3203 if matches!(val, JsonValue::String(_)) {
3204 has_string_value = true;
3205 } else if matches!(val, JsonValue::Number(_) | JsonValue::Bool(_)) {
3206 has_numeric_or_bool_value = true;
3207 }
3208 entries.push((key, val));
3209 }
3210 }
3211 }
3212 if has_string_value && has_numeric_or_bool_value {
3213 for (_, v) in entries.iter_mut() {
3214 match v {
3215 JsonValue::Number(n) => {
3216 *v = JsonValue::String(n.to_string());
3217 }
3218 JsonValue::Bool(b) => {
3219 *v = JsonValue::String(b.to_string());
3220 }
3221 _ => {}
3222 }
3223 }
3224 }
3225 let mut obj = Map::new();
3226 for (key, val) in entries {
3227 obj.insert(key, val);
3228 }
3229 JsonValue::Object(obj)
3230 } else {
3231 let inner_dtype = match dtype {
3232 DataType::List(inner) => inner.as_ref(),
3233 _ => dtype,
3234 };
3235 let arr: Vec<JsonValue> = (0..s.len())
3236 .filter_map(|i| s.get(i).ok())
3237 .map(|a| any_value_to_json(&a, inner_dtype))
3238 .collect::<Result<Vec<_>, _>>()?;
3239 JsonValue::Array(arr)
3240 }
3241 }
3242 AnyValue::Struct(_, _, fields) => {
3243 let mut vals: Vec<JsonValue> = Vec::with_capacity(fields.len());
3244 for (fld_av, fld) in av._iter_struct_av().zip(fields.iter()) {
3245 vals.push(any_value_to_json(&fld_av, &fld.dtype)?);
3246 }
3247 if vals.iter().all(|v| matches!(v, JsonValue::Null)) {
3248 JsonValue::Null
3249 } else {
3250 let mut obj = Map::new();
3251 for (fld, v) in fields.iter().zip(vals) {
3252 obj.insert(fld.name.to_string(), v);
3253 }
3254 JsonValue::Object(obj)
3255 }
3256 }
3257 AnyValue::StructOwned(payload) => {
3258 let (values, fields) = &**payload;
3259 let vals: Vec<JsonValue> = values
3260 .iter()
3261 .zip(fields.iter())
3262 .map(|(fld_av, fld)| any_value_to_json(fld_av, &fld.dtype))
3263 .collect::<Result<Vec<_>, _>>()?;
3264 if vals.iter().all(|v| matches!(v, JsonValue::Null)) {
3265 JsonValue::Null
3266 } else {
3267 let mut obj = Map::new();
3268 for (fld, v) in fields.iter().zip(vals) {
3269 obj.insert(fld.name.to_string(), v);
3270 }
3271 JsonValue::Object(obj)
3272 }
3273 }
3274 AnyValue::Date(days) => date_days_to_json(*days),
3275 AnyValue::Datetime(val, unit, _) => datetime_anyvalue_to_json_iso(*val, unit),
3276 AnyValue::DatetimeOwned(val, unit, _) => datetime_anyvalue_to_json_iso(*val, unit),
3277 _ => JsonValue::Null,
3278 })
3279}
3280
3281pub fn broadcast(df: &DataFrame) -> DataFrame {
3283 df.clone()
3284}
3285
#[cfg(test)]
mod tests {
    use super::*;
    use polars::prelude::{NamedFrom, Series};

    /// Wraps equal-length Polars series in a crate `DataFrame`.
    fn frame_from_series(columns: Vec<Series>) -> DataFrame {
        let polars_df = polars::prelude::DataFrame::new_infer_height(
            columns.into_iter().map(Into::into).collect(),
        )
        .unwrap();
        DataFrame::from_polars(polars_df)
    }

    /// A numeric-looking string column compared with an integer literal matches by value.
    #[test]
    fn coerce_string_numeric_root_in_filter() {
        let df = frame_from_series(vec![Series::new("str_col".into(), &["123", "456"])]);
        let matching = df.filter(col("str_col").eq(lit(123i64))).unwrap();
        assert_eq!(matching.count().unwrap(), 1);
    }

    /// A numeric column compared with a string literal is coerced and keeps its type.
    #[test]
    fn coerce_numeric_column_eq_string_literal() {
        let df = frame_from_series(vec![Series::new("value".into(), &[100i64, 200i64, 50i64])]);
        let matching = df.filter(col("value").eq(lit("100"))).unwrap();
        assert_eq!(matching.count().unwrap(), 1);
        let collected = matching.collect_as_json_rows().unwrap();
        let first_value = collected[0].get("value").and_then(|v| v.as_i64());
        assert_eq!(first_value, Some(100));
    }

    /// `contains` built from the crate's column API works as a filter predicate.
    #[test]
    fn filter_with_string_contains_predicate() {
        use crate::functions::col;
        use serde_json::json;

        let session = crate::session::SparkSession::builder()
            .app_name("filter_contains_test")
            .get_or_create();
        let schema = vec![
            ("id".to_string(), "bigint".to_string()),
            ("name".to_string(), "string".to_string()),
        ];
        let input_rows = vec![
            vec![json!(1), json!("alice")],
            vec![json!(2), json!("bob")],
            vec![json!(3), json!("charlie")],
        ];
        let df = session
            .create_dataframe_from_rows(input_rows, schema, false, false)
            .unwrap();
        let predicate: polars::prelude::Expr = col("name").contains("lic").into_expr();
        let remaining = df.filter(predicate).unwrap().count().unwrap();
        assert_eq!(
            remaining, 1,
            "filter(name.contains(\"lic\")) should return one row (alice)"
        );
    }

    /// Column names, case-insensitive resolution and dtypes are available before collect.
    #[test]
    fn lazy_schema_columns_resolve_before_collect() {
        let session = SparkSession::builder()
            .app_name("lazy_mod_tests")
            .get_or_create();
        let records = vec![
            (1i64, 25i64, "a".to_string()),
            (2i64, 30i64, "b".to_string()),
        ];
        let df = session
            .create_dataframe(records, vec!["id", "age", "name"])
            .unwrap();
        assert_eq!(df.columns().unwrap(), vec!["id", "age", "name"]);
        assert_eq!(df.resolve_column_name("AGE").unwrap(), "age");
        assert!(df.get_column_dtype("id").unwrap().is_integer());
    }

    /// Wrapping a Polars `LazyFrame` yields a usable DataFrame.
    #[test]
    fn lazy_from_lazy_produces_valid_df() {
        let _session = SparkSession::builder()
            .app_name("lazy_mod_tests")
            .get_or_create();
        let eager = polars::prelude::df!("x" => &[1i64, 2, 3]).unwrap();
        let df = DataFrame::from_lazy_with_options(eager.lazy(), false);
        assert_eq!(df.columns().unwrap(), vec!["x"]);
        assert_eq!(df.count().unwrap(), 3);
    }

    /// Null cells are present in collected rows as explicit JSON `null`.
    #[test]
    fn collect_preserves_null_as_json_null() {
        let _session = SparkSession::builder()
            .app_name("collect_null_test")
            .get_or_create();
        let df = frame_from_series(vec![
            Series::new("id".into(), &[1i64, 2i64, 3i64]),
            Series::new("value".into(), vec![Some(10i64), None, Some(30i64)]),
        ]);
        let collected = df.collect_as_json_rows().unwrap();
        assert_eq!(collected.len(), 3);
        assert_eq!(collected[0].get("value").and_then(|v| v.as_i64()), Some(10));
        assert!(collected[1].contains_key("value"));
        assert!(matches!(collected[1].get("value"), Some(JsonValue::Null)));
        assert_eq!(collected[2].get("value").and_then(|v| v.as_i64()), Some(30));
    }

    /// `pivot` is a stub whose error points users at `crosstab`.
    #[test]
    fn pivot_raises_not_implemented() {
        let session = SparkSession::builder()
            .app_name("pivot_stub_test")
            .get_or_create();
        let records = vec![
            (1i64, 25i64, "a".to_string()),
            (2i64, 30i64, "b".to_string()),
        ];
        let df = session
            .create_dataframe(records, vec!["id", "age", "name"])
            .unwrap();
        let error = match df.pivot("name", None) {
            Err(e) => e,
            Ok(_) => panic!("pivot should not be implemented"),
        };
        let message = error.to_string();
        let mentions_both =
            message.contains("pivot is not yet implemented") && message.contains("crosstab");
        assert!(
            mentions_both,
            "pivot stub should mention crosstab: {}",
            message
        );
    }

    /// Floats within rounding noise of an integer are collected as that integer.
    #[test]
    fn collect_rounds_float_near_integer() {
        let _session = SparkSession::builder()
            .app_name("float_round_test")
            .get_or_create();
        let df = frame_from_series(vec![Series::new(
            "result".into(),
            &[7.999_999_999_999_998_f64, 8.0],
        )]);
        let collected = df.collect_as_json_rows().unwrap();
        assert_eq!(collected.len(), 2);
        for row in &collected {
            assert_eq!(row.get("result").and_then(|v| v.as_f64()), Some(8.0));
        }
    }

    /// A dotted path selects a leaf nested two structs deep, named after the leaf.
    #[test]
    fn select_nested_struct_field_outer_inner_leaf() {
        use serde_json::json;

        let session = SparkSession::builder()
            .app_name("nested_struct_test")
            .get_or_create();
        let schema = vec![(
            "outer".to_string(),
            "struct<inner:struct<leaf:int>>".to_string(),
        )];
        let input_rows = vec![vec![json!({"inner": {"leaf": 7}})]];
        let df = session
            .create_dataframe_from_rows(input_rows, schema, false, false)
            .unwrap();
        let selected = df.select(vec!["outer.inner.leaf"]).unwrap();
        let collected = selected.collect_as_json_rows().unwrap();
        assert_eq!(collected.len(), 1);
        assert_eq!(
            collected[0].get("leaf").and_then(|v| v.as_i64()),
            Some(7),
            "nested struct field outer.inner.leaf should resolve to 7"
        );
    }
}