mod aggregations;
mod joins;
mod stats;
mod transformations;

pub use aggregations::{CubeRollupData, GroupedData, PivotedGroupedData};
pub use joins::{JoinType, join};
pub use stats::DataFrameStat;
pub use transformations::{
    DataFrameNa, SelectItem, filter, order_by, order_by_exprs, select, select_items,
    select_with_exprs, with_column,
};

use crate::column::Column;
use crate::error::{EngineError, polars_to_core_error};
use crate::functions::SortOrder;
use crate::schema::{StructType, StructTypePolarsExt};
use crate::session::SparkSession;
use crate::type_coercion::coerce_for_pyspark_comparison;
use polars::prelude::{
    AnyValue, DataFrame as PlDataFrame, DataType, Expr, IntoLazy, LazyFrame, PlSmallStr,
    PolarsError, Schema, SchemaNamesAndDtypes, UnknownKind, col, lit,
};
use serde_json::Value as JsonValue;
use std::collections::{HashMap, HashSet};
use std::path::Path;
use std::sync::Arc;

const DEFAULT_CASE_SENSITIVE: bool = false;

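/// Internal storage for a [`DataFrame`]: either an eagerly materialized
/// Polars frame or a lazy logical plan.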
#[allow(clippy::large_enum_variant)]
pub(crate) enum DataFrameInner {
    #[allow(dead_code)]
    Eager(Arc<PlDataFrame>),
    Lazy(LazyFrame),
}

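/// A PySpark-style DataFrame backed by a Polars [`LazyFrame`].
///
/// `case_sensitive` mirrors `spark.sql.caseSensitive` and controls column
/// name resolution; `alias` stores the name given via [`DataFrame::alias`].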
pub struct DataFrame {
    pub(crate) inner: DataFrameInner,
    pub(crate) case_sensitive: bool,
    pub(crate) alias: Option<String>,
}

impl DataFrame {
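    /// Wraps an eager Polars frame, converting it to a lazy plan with the
    /// default (case-insensitive) column resolution.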
    pub fn from_polars(df: PlDataFrame) -> Self {
        let lf = df.lazy();
        DataFrame {
            inner: DataFrameInner::Lazy(lf),
            case_sensitive: DEFAULT_CASE_SENSITIVE,
            alias: None,
        }
    }

    pub fn from_polars_with_options(df: PlDataFrame, case_sensitive: bool) -> Self {
        let lf = df.lazy();
        DataFrame {
            inner: DataFrameInner::Lazy(lf),
            case_sensitive,
            alias: None,
        }
    }

    pub fn from_lazy(lf: LazyFrame) -> Self {
        DataFrame {
            inner: DataFrameInner::Lazy(lf),
            case_sensitive: DEFAULT_CASE_SENSITIVE,
            alias: None,
        }
    }

    pub fn from_lazy_with_options(lf: LazyFrame, case_sensitive: bool) -> Self {
        DataFrame {
            inner: DataFrameInner::Lazy(lf),
            case_sensitive,
            alias: None,
        }
    }

    pub(crate) fn with_case_insensitive_column_resolution(self) -> Self {
        DataFrame {
            inner: self.inner,
            case_sensitive: false,
            alias: self.alias,
        }
    }

    pub fn empty() -> Self {
        DataFrame {
            inner: DataFrameInner::Lazy(PlDataFrame::empty().lazy()),
            case_sensitive: DEFAULT_CASE_SENSITIVE,
            alias: None,
        }
    }

    pub(crate) fn lazy_frame(&self) -> LazyFrame {
        match &self.inner {
            DataFrameInner::Eager(df) => df.as_ref().clone().lazy(),
            DataFrameInner::Lazy(lf) => lf.clone(),
        }
    }

    pub(crate) fn collect_inner(&self) -> Result<Arc<PlDataFrame>, PolarsError> {
        match &self.inner {
            DataFrameInner::Eager(df) => Ok(df.clone()),
            DataFrameInner::Lazy(lf) => Ok(Arc::new(lf.clone().collect()?)),
        }
    }

    pub fn alias(&self, name: &str) -> Self {
        let lf = self.lazy_frame();
        DataFrame {
            inner: DataFrameInner::Lazy(lf),
            case_sensitive: self.case_sensitive,
            alias: Some(name.to_string()),
        }
    }

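    /// Rewrites every `Expr::Column` in `expr` to its schema-resolved name,
    /// honoring case sensitivity. Dotted names (`a.b.c`) resolve the first
    /// segment as a column and the rest as nested struct field accesses;
    /// names produced by an alias inside the expression are left untouched.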
    pub fn resolve_expr_column_names(&self, expr: Expr) -> Result<Expr, PolarsError> {
        let df = self;
        let mut alias_output_names: HashSet<String> = HashSet::new();
        let _ = expr.clone().try_map_expr(|e| {
            if let Expr::Alias(_, name) = &e {
                alias_output_names.insert(name.as_str().to_string());
            }
            Ok(e)
        })?;
        expr.try_map_expr(move |e| {
            if let Expr::Column(name) = &e {
                let name_str = name.as_str();
                if alias_output_names.contains(name_str) {
                    return Ok(e);
                }
                if name_str.is_empty() {
                    return Ok(e);
                }
                if name_str.contains('.') {
                    let parts: Vec<&str> = name_str.split('.').collect();
                    let first = parts[0];
                    let rest = &parts[1..];
                    if rest.is_empty() {
                        return Err(PolarsError::ColumnNotFound(
                            format!("Column '{}': trailing dot not allowed", name_str).into(),
                        ));
                    }
                    let resolved = df.resolve_column_name(first)?;
                    let mut expr = col(PlSmallStr::from(resolved.as_str()));
                    for field in rest {
                        expr = expr.struct_().field_by_name(field);
                    }
                    return Ok(expr);
                }
                let resolved = df.resolve_column_name(name_str)?;
                return Ok(Expr::Column(PlSmallStr::from(resolved.as_str())));
            }
            Ok(e)
        })
    }

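    /// Rewrites comparison expressions so that mixed-type comparisons follow
    /// PySpark semantics (e.g. a string column compared to a numeric literal,
    /// or a date/datetime column compared to a string literal). The root
    /// expression is handled first, then the rewrite is applied recursively.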
    pub fn coerce_string_numeric_comparisons(&self, expr: Expr) -> Result<Expr, PolarsError> {
        use polars::prelude::{DataType, LiteralValue, Operator};
        use std::sync::Arc;

        fn is_numeric_literal(expr: &Expr) -> bool {
            match expr {
                Expr::Literal(lv) => {
                    let dt = lv.get_datatype();
                    dt.is_numeric()
                        || matches!(
                            dt,
                            DataType::Unknown(UnknownKind::Int(_))
                                | DataType::Unknown(UnknownKind::Float)
                        )
                }
                _ => false,
            }
        }

        fn literal_dtype(lv: &LiteralValue) -> DataType {
            let dt = lv.get_datatype();
            if matches!(
                dt,
                DataType::Unknown(UnknownKind::Int(_)) | DataType::Unknown(UnknownKind::Float)
            ) {
                DataType::Float64
            } else {
                dt
            }
        }

        let expr = {
            if let Expr::BinaryExpr { left, op, right } = &expr {
                let is_comparison_op = matches!(
                    op,
                    Operator::Eq
                        | Operator::NotEq
                        | Operator::Lt
                        | Operator::LtEq
                        | Operator::Gt
                        | Operator::GtEq
                );
                let left_is_col = matches!(&**left, Expr::Column(_));
                let right_is_col = matches!(&**right, Expr::Column(_));
                let left_is_numeric_lit =
                    matches!(&**left, Expr::Literal(_)) && is_numeric_literal(left.as_ref());
                let right_is_numeric_lit =
                    matches!(&**right, Expr::Literal(_)) && is_numeric_literal(right.as_ref());
                let left_is_string_lit = matches!(
                    &**left,
                    Expr::Literal(lv) if lv.get_datatype() == DataType::String
                );
                let right_is_string_lit = matches!(
                    &**right,
                    Expr::Literal(lv) if lv.get_datatype() == DataType::String
                );
                let root_is_col_vs_numeric = is_comparison_op
                    && ((left_is_col && right_is_numeric_lit)
                        || (right_is_col && left_is_numeric_lit));
                let root_is_col_vs_string = is_comparison_op
                    && ((left_is_col && right_is_string_lit)
                        || (right_is_col && left_is_string_lit));
                if root_is_col_vs_numeric {
                    let (new_left, new_right) = if left_is_col && right_is_numeric_lit {
                        let lit_ty = match &**right {
                            Expr::Literal(lv) => literal_dtype(lv),
                            _ => DataType::Float64,
                        };
                        coerce_for_pyspark_comparison(
                            (*left).as_ref().clone(),
                            (*right).as_ref().clone(),
                            &DataType::String,
                            &lit_ty,
                            op,
                        )
                        .map_err(|e| PolarsError::ComputeError(e.to_string().into()))?
                    } else {
                        let lit_ty = match &**left {
                            Expr::Literal(lv) => literal_dtype(lv),
                            _ => DataType::Float64,
                        };
                        coerce_for_pyspark_comparison(
                            (*left).as_ref().clone(),
                            (*right).as_ref().clone(),
                            &lit_ty,
                            &DataType::String,
                            op,
                        )
                        .map_err(|e| PolarsError::ComputeError(e.to_string().into()))?
                    };
                    Expr::BinaryExpr {
                        left: Arc::new(new_left),
                        op: *op,
                        right: Arc::new(new_right),
                    }
                } else if root_is_col_vs_string {
                    let col_name = if left_is_col {
                        if let Expr::Column(n) = &**left {
                            n.as_str()
                        } else {
                            unreachable!()
                        }
                    } else if let Expr::Column(n) = &**right {
                        n.as_str()
                    } else {
                        unreachable!()
                    };
                    if let Some(col_dtype) = self.get_column_dtype(col_name) {
                        if matches!(col_dtype, DataType::Date | DataType::Datetime(_, _)) {
                            let (left_ty, right_ty) = if left_is_col {
                                (col_dtype.clone(), DataType::String)
                            } else {
                                (DataType::String, col_dtype.clone())
                            };
                            let (new_left, new_right) = coerce_for_pyspark_comparison(
                                (*left).as_ref().clone(),
                                (*right).as_ref().clone(),
                                &left_ty,
                                &right_ty,
                                op,
                            )
                            .map_err(|e| PolarsError::ComputeError(e.to_string().into()))?;
                            return Ok(Expr::BinaryExpr {
                                left: Arc::new(new_left),
                                op: *op,
                                right: Arc::new(new_right),
                            });
                        }
                    }
                    expr
                } else if is_comparison_op && left_is_col && right_is_col {
                    let left_name = if let Expr::Column(n) = &**left {
                        n.as_str()
                    } else {
                        unreachable!()
                    };
                    let right_name = if let Expr::Column(n) = &**right {
                        n.as_str()
                    } else {
                        unreachable!()
                    };
                    if let (Some(left_ty), Some(right_ty)) = (
                        self.get_column_dtype(left_name),
                        self.get_column_dtype(right_name),
                    ) {
                        if left_ty != right_ty {
                            if let Ok((new_left, new_right)) = coerce_for_pyspark_comparison(
                                (*left).as_ref().clone(),
                                (*right).as_ref().clone(),
                                &left_ty,
                                &right_ty,
                                op,
                            ) {
                                return Ok(Expr::BinaryExpr {
                                    left: Arc::new(new_left),
                                    op: *op,
                                    right: Arc::new(new_right),
                                });
                            }
                        }
                    }
                    expr
                } else {
                    expr
                }
            } else {
                expr
            }
        };

        expr.try_map_expr(move |e| {
            if let Expr::BinaryExpr { left, op, right } = e {
                let is_comparison_op = matches!(
                    op,
                    Operator::Eq
                        | Operator::NotEq
                        | Operator::Lt
                        | Operator::LtEq
                        | Operator::Gt
                        | Operator::GtEq
                );
                if !is_comparison_op {
                    return Ok(Expr::BinaryExpr { left, op, right });
                }

                let left_is_col = matches!(&*left, Expr::Column(_));
                let right_is_col = matches!(&*right, Expr::Column(_));
                let left_is_lit = matches!(&*left, Expr::Literal(_));
                let right_is_lit = matches!(&*right, Expr::Literal(_));
                let left_is_string_lit =
                    matches!(&*left, Expr::Literal(lv) if lv.get_datatype() == DataType::String);
                let right_is_string_lit =
                    matches!(&*right, Expr::Literal(lv) if lv.get_datatype() == DataType::String);

                let left_is_numeric_lit = left_is_lit && is_numeric_literal(left.as_ref());
                let right_is_numeric_lit = right_is_lit && is_numeric_literal(right.as_ref());

                let (new_left, new_right) = if left_is_col && right_is_numeric_lit {
                    let lit_ty = match &*right {
                        Expr::Literal(lv) => literal_dtype(lv),
                        _ => DataType::Float64,
                    };
                    coerce_for_pyspark_comparison(
                        (*left).clone(),
                        (*right).clone(),
                        &DataType::String,
                        &lit_ty,
                        &op,
                    )
                    .map_err(|e| PolarsError::ComputeError(e.to_string().into()))?
                } else if right_is_col && left_is_numeric_lit {
                    let lit_ty = match &*left {
                        Expr::Literal(lv) => literal_dtype(lv),
                        _ => DataType::Float64,
                    };
                    coerce_for_pyspark_comparison(
                        (*left).clone(),
                        (*right).clone(),
                        &lit_ty,
                        &DataType::String,
                        &op,
                    )
                    .map_err(|e| PolarsError::ComputeError(e.to_string().into()))?
                } else if (left_is_col && right_is_string_lit)
                    || (right_is_col && left_is_string_lit)
                {
                    let col_name = if left_is_col {
                        if let Expr::Column(n) = &*left {
                            n.as_str()
                        } else {
                            unreachable!()
                        }
                    } else if let Expr::Column(n) = &*right {
                        n.as_str()
                    } else {
                        unreachable!()
                    };
                    if let Some(col_dtype) = self.get_column_dtype(col_name) {
                        if matches!(col_dtype, DataType::Date | DataType::Datetime(_, _)) {
                            let (left_ty, right_ty) = if left_is_col {
                                (col_dtype.clone(), DataType::String)
                            } else {
                                (DataType::String, col_dtype.clone())
                            };
                            let (new_l, new_r) = coerce_for_pyspark_comparison(
                                (*left).clone(),
                                (*right).clone(),
                                &left_ty,
                                &right_ty,
                                &op,
                            )
                            .map_err(|e| PolarsError::ComputeError(e.to_string().into()))?;
                            return Ok(Expr::BinaryExpr {
                                left: Arc::new(new_l),
                                op,
                                right: Arc::new(new_r),
                            });
                        }
                    }
                    return Ok(Expr::BinaryExpr { left, op, right });
                } else {
                    return Ok(Expr::BinaryExpr { left, op, right });
                };

                Ok(Expr::BinaryExpr {
                    left: Arc::new(new_left),
                    op,
                    right: Arc::new(new_right),
                })
            } else {
                Ok(e)
            }
        })
    }

    fn schema_or_collect(&self) -> Result<Arc<Schema>, PolarsError> {
        match &self.inner {
            DataFrameInner::Eager(df) => Ok(Arc::clone(df.schema())),
            DataFrameInner::Lazy(lf) => Ok(lf.clone().collect_schema()?),
        }
    }

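    /// Resolves `name` against the current schema, returning the exact
    /// column name as stored. With case-insensitive resolution the first
    /// case-insensitive match wins; otherwise the name must match exactly.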
    pub fn resolve_column_name(&self, name: &str) -> Result<String, PolarsError> {
        let schema = self.schema_or_collect()?;
        let names: Vec<String> = schema
            .iter_names_and_dtypes()
            .map(|(n, _)| n.to_string())
            .collect();
        if self.case_sensitive {
            if names.iter().any(|n| n == name) {
                return Ok(name.to_string());
            }
        } else {
            let name_lower = name.to_lowercase();
            for n in &names {
                if n.to_lowercase() == name_lower {
                    return Ok(n.clone());
                }
            }
        }
        let available = names.join(", ");
        Err(PolarsError::ColumnNotFound(
            format!(
                "Column '{}' not found. Available columns: [{}]. Check spelling and case sensitivity (spark.sql.caseSensitive).",
                name,
                available
            )
            .into(),
        ))
    }

    pub fn schema(&self) -> Result<StructType, PolarsError> {
        let s = self.schema_or_collect()?;
        Ok(StructType::from_polars_schema(&s))
    }

    pub fn schema_engine(&self) -> Result<StructType, EngineError> {
        self.schema().map_err(polars_to_core_error)
    }

    pub fn get_column_dtype(&self, name: &str) -> Option<DataType> {
        let resolved = self.resolve_column_name(name).ok()?;
        self.schema_or_collect()
            .ok()?
            .iter_names_and_dtypes()
            .find(|(n, _)| n.to_string() == resolved)
            .map(|(_, dt)| dt.clone())
    }

    pub fn get_column_data_type(&self, name: &str) -> Option<crate::schema::DataType> {
        let resolved = self.resolve_column_name(name).ok()?;
        let st = self.schema().ok()?;
        st.fields()
            .iter()
            .find(|f| f.name == resolved)
            .map(|f| f.data_type.clone())
    }

    pub fn columns(&self) -> Result<Vec<String>, PolarsError> {
        let schema = self.schema_or_collect()?;
        Ok(schema
            .iter_names_and_dtypes()
            .map(|(n, _)| n.to_string())
            .collect())
    }

    pub fn columns_engine(&self) -> Result<Vec<String>, EngineError> {
        self.columns().map_err(polars_to_core_error)
    }

    pub fn count(&self) -> Result<usize, PolarsError> {
        Ok(self.collect_inner()?.height())
    }

    pub fn count_engine(&self) -> Result<usize, EngineError> {
        self.count().map_err(polars_to_core_error)
    }

    pub fn show(&self, n: Option<usize>) -> Result<(), PolarsError> {
        let n = n.unwrap_or(20);
        let df = self.collect_inner()?;
        println!("{}", df.head(Some(n)));
        Ok(())
    }

    pub fn collect(&self) -> Result<Arc<PlDataFrame>, PolarsError> {
        self.collect_inner()
    }

    pub fn collect_as_json_rows_engine(
        &self,
    ) -> Result<Vec<HashMap<String, JsonValue>>, EngineError> {
        self.collect_as_json_rows().map_err(polars_to_core_error)
    }

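    /// Collects the frame and converts every row into a map of column name
    /// to JSON value; nested lists, structs, and map-shaped columns are
    /// converted recursively via `any_value_to_json`.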
    pub fn collect_as_json_rows(&self) -> Result<Vec<HashMap<String, JsonValue>>, PolarsError> {
        let collected = self.collect_inner()?;
        let names = collected.get_column_names();
        let nrows = collected.height();
        let mut rows = Vec::with_capacity(nrows);
        for i in 0..nrows {
            let mut row = HashMap::with_capacity(names.len());
            for (col_idx, name) in names.iter().enumerate() {
                let s = collected
                    .columns()
                    .get(col_idx)
                    .ok_or_else(|| PolarsError::ComputeError("column index out of range".into()))?;
                let av = s.get(i)?;
                let jv = any_value_to_json(&av, s.dtype());
                row.insert(name.to_string(), jv);
            }
            rows.push(row);
        }
        Ok(rows)
    }

    pub fn to_json_rows(&self) -> Result<String, EngineError> {
        let rows = self.collect_as_json_rows().map_err(polars_to_core_error)?;
        serde_json::to_string(&rows).map_err(Into::into)
    }

    pub fn select_exprs(&self, exprs: Vec<Expr>) -> Result<DataFrame, PolarsError> {
        transformations::select_with_exprs(self, exprs, self.case_sensitive)
    }

    pub fn select(&self, cols: Vec<&str>) -> Result<DataFrame, PolarsError> {
        let resolved: Vec<String> = cols
            .iter()
            .map(|c| self.resolve_column_name(c))
            .collect::<Result<Vec<_>, _>>()?;
        let refs: Vec<&str> = resolved.iter().map(|s| s.as_str()).collect();
        let mut result = transformations::select(self, refs, self.case_sensitive)?;
        if !self.case_sensitive {
            for (requested, res) in cols.iter().zip(resolved.iter()) {
                if *requested != res.as_str() {
                    result = result.with_column_renamed(res, requested)?;
                }
            }
        }
        Ok(result)
    }

    pub fn select_engine(&self, cols: Vec<&str>) -> Result<DataFrame, EngineError> {
        self.select(cols).map_err(polars_to_core_error)
    }

    pub fn select_items(&self, items: Vec<SelectItem<'_>>) -> Result<DataFrame, PolarsError> {
        transformations::select_items(self, items, self.case_sensitive)
    }

    pub fn filter(&self, condition: Expr) -> Result<DataFrame, PolarsError> {
        transformations::filter(self, condition, self.case_sensitive)
    }

    pub fn filter_engine(&self, condition: Expr) -> Result<DataFrame, EngineError> {
        self.filter(condition).map_err(polars_to_core_error)
    }

    pub fn column(&self, name: &str) -> Result<Column, PolarsError> {
        let resolved = self.resolve_column_name(name)?;
        Ok(Column::new(resolved))
    }

    pub fn with_column(&self, column_name: &str, col: &Column) -> Result<DataFrame, PolarsError> {
        transformations::with_column(self, column_name, col, self.case_sensitive)
    }

    pub fn with_column_engine(
        &self,
        column_name: &str,
        col: &Column,
    ) -> Result<DataFrame, EngineError> {
        self.with_column(column_name, col)
            .map_err(polars_to_core_error)
    }

    pub fn with_column_expr(
        &self,
        column_name: &str,
        expr: Expr,
    ) -> Result<DataFrame, PolarsError> {
        let col = Column::from_expr(expr, None);
        self.with_column(column_name, &col)
    }

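    /// Groups the frame by the given column names (resolved against the
    /// schema) and returns a [`GroupedData`] handle for aggregation.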
    pub fn group_by(&self, column_names: Vec<&str>) -> Result<GroupedData, PolarsError> {
        use polars::prelude::*;
        let resolved: Vec<String> = column_names
            .iter()
            .map(|c| self.resolve_column_name(c))
            .collect::<Result<Vec<_>, _>>()?;
        let exprs: Vec<Expr> = resolved.iter().map(|name| col(name.as_str())).collect();
        let lf = self.lazy_frame();
        let lazy_grouped = lf.clone().group_by(exprs);
        Ok(GroupedData {
            lf,
            lazy_grouped,
            grouping_cols: resolved,
            case_sensitive: self.case_sensitive,
        })
    }

    pub fn group_by_engine(&self, column_names: Vec<&str>) -> Result<GroupedData, EngineError> {
        self.group_by(column_names).map_err(polars_to_core_error)
    }

    pub fn group_by_exprs(
        &self,
        exprs: Vec<Expr>,
        grouping_col_names: Vec<String>,
    ) -> Result<GroupedData, PolarsError> {
        use polars::prelude::*;
        if exprs.len() != grouping_col_names.len() {
            return Err(PolarsError::ComputeError(
                format!(
                    "group_by_exprs: {} expressions but {} output names",
                    exprs.len(),
                    grouping_col_names.len()
                )
                .into(),
            ));
        }
        let resolved: Vec<Expr> = exprs
            .into_iter()
            .map(|e| self.resolve_expr_column_names(e))
            .collect::<Result<Vec<_>, _>>()?;
        let lf = self.lazy_frame();
        let lazy_grouped = lf.clone().group_by(resolved);
        Ok(GroupedData {
            lf,
            lazy_grouped,
            grouping_cols: grouping_col_names,
            case_sensitive: self.case_sensitive,
        })
    }

    pub fn cube(&self, column_names: Vec<&str>) -> Result<CubeRollupData, PolarsError> {
        let resolved: Vec<String> = column_names
            .iter()
            .map(|c| self.resolve_column_name(c))
            .collect::<Result<Vec<_>, _>>()?;
        Ok(CubeRollupData {
            lf: self.lazy_frame(),
            grouping_cols: resolved,
            case_sensitive: self.case_sensitive,
            is_cube: true,
        })
    }

    pub fn rollup(&self, column_names: Vec<&str>) -> Result<CubeRollupData, PolarsError> {
        let resolved: Vec<String> = column_names
            .iter()
            .map(|c| self.resolve_column_name(c))
            .collect::<Result<Vec<_>, _>>()?;
        Ok(CubeRollupData {
            lf: self.lazy_frame(),
            grouping_cols: resolved,
            case_sensitive: self.case_sensitive,
            is_cube: false,
        })
    }

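    /// Whole-frame aggregation (no grouping): resolves column names in each
    /// aggregation expression, disambiguates duplicate output names, and
    /// collects the result eagerly.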
789 pub fn agg(&self, aggregations: Vec<Expr>) -> Result<DataFrame, PolarsError> {
793 let resolved: Vec<Expr> = aggregations
794 .into_iter()
795 .map(|e| self.resolve_expr_column_names(e))
796 .collect::<Result<Vec<_>, _>>()?;
797 let disambiguated = aggregations::disambiguate_agg_output_names(resolved);
798 let pl_df = self.lazy_frame().select(disambiguated).collect()?;
799 Ok(Self::from_polars_with_options(pl_df, self.case_sensitive))
800 }
801
802 pub fn join(
805 &self,
806 other: &DataFrame,
807 on: Vec<&str>,
808 how: JoinType,
809 ) -> Result<DataFrame, PolarsError> {
810 let resolved: Vec<String> = on
811 .iter()
812 .map(|c| self.resolve_column_name(c))
813 .collect::<Result<Vec<_>, _>>()?;
814 let on_refs: Vec<&str> = resolved.iter().map(|s| s.as_str()).collect();
815 join(self, other, on_refs, how, self.case_sensitive)
816 }
817
818 pub fn order_by(
823 &self,
824 column_names: Vec<&str>,
825 ascending: Vec<bool>,
826 ) -> Result<DataFrame, PolarsError> {
827 let resolved: Vec<String> = column_names
828 .iter()
829 .map(|c| self.resolve_column_name(c))
830 .collect::<Result<Vec<_>, _>>()?;
831 let refs: Vec<&str> = resolved.iter().map(|s| s.as_str()).collect();
832 transformations::order_by(self, refs, ascending, self.case_sensitive)
833 }
834
835 pub fn order_by_exprs(&self, sort_orders: Vec<SortOrder>) -> Result<DataFrame, PolarsError> {
837 transformations::order_by_exprs(self, sort_orders, self.case_sensitive)
838 }
839
840 pub fn union(&self, other: &DataFrame) -> Result<DataFrame, PolarsError> {
842 transformations::union(self, other, self.case_sensitive)
843 }
844
845 pub fn union_all(&self, other: &DataFrame) -> Result<DataFrame, PolarsError> {
847 self.union(other)
848 }
849
850 pub fn union_by_name(
852 &self,
853 other: &DataFrame,
854 allow_missing_columns: bool,
855 ) -> Result<DataFrame, PolarsError> {
856 transformations::union_by_name(self, other, allow_missing_columns, self.case_sensitive)
857 }
858
859 pub fn distinct(&self, subset: Option<Vec<&str>>) -> Result<DataFrame, PolarsError> {
861 transformations::distinct(self, subset, self.case_sensitive)
862 }
863
864 pub fn drop(&self, columns: Vec<&str>) -> Result<DataFrame, PolarsError> {
866 transformations::drop(self, columns, self.case_sensitive)
867 }
868
869 pub fn dropna(
871 &self,
872 subset: Option<Vec<&str>>,
873 how: &str,
874 thresh: Option<usize>,
875 ) -> Result<DataFrame, PolarsError> {
876 transformations::dropna(self, subset, how, thresh, self.case_sensitive)
877 }
878
879 pub fn fillna(&self, value: Expr, subset: Option<Vec<&str>>) -> Result<DataFrame, PolarsError> {
881 transformations::fillna(self, value, subset, self.case_sensitive)
882 }
883
884 pub fn limit(&self, n: usize) -> Result<DataFrame, PolarsError> {
886 transformations::limit(self, n, self.case_sensitive)
887 }
888
889 pub fn limit_engine(&self, n: usize) -> Result<DataFrame, EngineError> {
891 self.limit(n).map_err(polars_to_core_error)
892 }
893
894 pub fn with_column_renamed(
896 &self,
897 old_name: &str,
898 new_name: &str,
899 ) -> Result<DataFrame, PolarsError> {
900 transformations::with_column_renamed(self, old_name, new_name, self.case_sensitive)
901 }
902
903 pub fn replace(
905 &self,
906 column_name: &str,
907 old_value: Expr,
908 new_value: Expr,
909 ) -> Result<DataFrame, PolarsError> {
910 transformations::replace(self, column_name, old_value, new_value, self.case_sensitive)
911 }
912
913 pub fn cross_join(&self, other: &DataFrame) -> Result<DataFrame, PolarsError> {
915 transformations::cross_join(self, other, self.case_sensitive)
916 }
917
918 pub fn describe(&self) -> Result<DataFrame, PolarsError> {
920 transformations::describe(self, self.case_sensitive)
921 }
922
923 pub fn cache(&self) -> Result<DataFrame, PolarsError> {
925 Ok(self.clone())
926 }
927
928 pub fn persist(&self) -> Result<DataFrame, PolarsError> {
930 Ok(self.clone())
931 }
932
933 pub fn unpersist(&self) -> Result<DataFrame, PolarsError> {
935 Ok(self.clone())
936 }
937
938 pub fn subtract(&self, other: &DataFrame) -> Result<DataFrame, PolarsError> {
940 transformations::subtract(self, other, self.case_sensitive)
941 }
942
943 pub fn intersect(&self, other: &DataFrame) -> Result<DataFrame, PolarsError> {
945 transformations::intersect(self, other, self.case_sensitive)
946 }
947
948 pub fn sample(
950 &self,
951 with_replacement: bool,
952 fraction: f64,
953 seed: Option<u64>,
954 ) -> Result<DataFrame, PolarsError> {
955 transformations::sample(self, with_replacement, fraction, seed, self.case_sensitive)
956 }
957
958 pub fn random_split(
960 &self,
961 weights: &[f64],
962 seed: Option<u64>,
963 ) -> Result<Vec<DataFrame>, PolarsError> {
964 transformations::random_split(self, weights, seed, self.case_sensitive)
965 }
966
967 pub fn sample_by(
970 &self,
971 col_name: &str,
972 fractions: &[(Expr, f64)],
973 seed: Option<u64>,
974 ) -> Result<DataFrame, PolarsError> {
975 transformations::sample_by(self, col_name, fractions, seed, self.case_sensitive)
976 }
977
978 pub fn first(&self) -> Result<DataFrame, PolarsError> {
980 transformations::first(self, self.case_sensitive)
981 }
982
983 pub fn head(&self, n: usize) -> Result<DataFrame, PolarsError> {
985 transformations::head(self, n, self.case_sensitive)
986 }
987
988 pub fn take(&self, n: usize) -> Result<DataFrame, PolarsError> {
990 transformations::take(self, n, self.case_sensitive)
991 }
992
993 pub fn tail(&self, n: usize) -> Result<DataFrame, PolarsError> {
995 transformations::tail(self, n, self.case_sensitive)
996 }
997
998 pub fn is_empty(&self) -> bool {
1000 transformations::is_empty(self)
1001 }
1002
1003 pub fn to_df(&self, names: Vec<&str>) -> Result<DataFrame, PolarsError> {
1005 transformations::to_df(self, &names, self.case_sensitive)
1006 }
1007
1008 pub fn stat(&self) -> DataFrameStat<'_> {
1010 DataFrameStat { df: self }
1011 }
1012
1013 pub fn corr(&self) -> Result<DataFrame, PolarsError> {
1015 self.stat().corr_matrix()
1016 }
1017
1018 pub fn corr_cols(&self, col1: &str, col2: &str) -> Result<f64, PolarsError> {
1020 self.stat().corr(col1, col2)
1021 }
1022
1023 pub fn cov_cols(&self, col1: &str, col2: &str) -> Result<f64, PolarsError> {
1025 self.stat().cov(col1, col2)
1026 }
1027
1028 pub fn summary(&self) -> Result<DataFrame, PolarsError> {
1030 self.describe()
1031 }
1032
1033 pub fn to_json(&self) -> Result<Vec<String>, PolarsError> {
1035 transformations::to_json(self)
1036 }
1037
1038 pub fn explain(&self) -> String {
1040 transformations::explain(self)
1041 }
1042
1043 pub fn print_schema(&self) -> Result<String, PolarsError> {
1045 transformations::print_schema(self)
1046 }
1047
1048 pub fn checkpoint(&self) -> Result<DataFrame, PolarsError> {
1050 Ok(self.clone())
1051 }
1052
1053 pub fn local_checkpoint(&self) -> Result<DataFrame, PolarsError> {
1055 Ok(self.clone())
1056 }
1057
1058 pub fn repartition(&self, _num_partitions: usize) -> Result<DataFrame, PolarsError> {
1060 Ok(self.clone())
1061 }
1062
1063 pub fn repartition_by_range(
1065 &self,
1066 _num_partitions: usize,
1067 _cols: Vec<&str>,
1068 ) -> Result<DataFrame, PolarsError> {
1069 Ok(self.clone())
1070 }
1071
1072 pub fn dtypes(&self) -> Result<Vec<(String, String)>, PolarsError> {
1074 let schema = self.schema_or_collect()?;
1075 Ok(schema
1076 .iter_names_and_dtypes()
1077 .map(|(name, dtype)| (name.to_string(), format!("{dtype:?}")))
1078 .collect())
1079 }
1080
1081 pub fn sort_within_partitions(
1083 &self,
1084 _cols: &[crate::functions::SortOrder],
1085 ) -> Result<DataFrame, PolarsError> {
1086 Ok(self.clone())
1087 }
1088
1089 pub fn coalesce(&self, _num_partitions: usize) -> Result<DataFrame, PolarsError> {
1091 Ok(self.clone())
1092 }
1093
1094 pub fn hint(&self, _name: &str, _params: &[i32]) -> Result<DataFrame, PolarsError> {
1096 Ok(self.clone())
1097 }
1098
1099 pub fn is_local(&self) -> bool {
1101 true
1102 }
1103
1104 pub fn input_files(&self) -> Vec<String> {
1106 Vec::new()
1107 }
1108
1109 pub fn same_semantics(&self, _other: &DataFrame) -> bool {
1111 false
1112 }
1113
1114 pub fn semantic_hash(&self) -> u64 {
1116 0
1117 }
1118
1119 pub fn observe(&self, _name: &str, _expr: Expr) -> Result<DataFrame, PolarsError> {
1121 Ok(self.clone())
1122 }
1123
1124 pub fn with_watermark(
1126 &self,
1127 _event_time: &str,
1128 _delay: &str,
1129 ) -> Result<DataFrame, PolarsError> {
1130 Ok(self.clone())
1131 }
1132
1133 pub fn select_expr(&self, exprs: &[String]) -> Result<DataFrame, PolarsError> {
1135 transformations::select_expr(self, exprs, self.case_sensitive)
1136 }
1137
1138 pub fn col_regex(&self, pattern: &str) -> Result<DataFrame, PolarsError> {
1140 transformations::col_regex(self, pattern, self.case_sensitive)
1141 }
1142
1143 pub fn with_columns(&self, exprs: &[(String, Column)]) -> Result<DataFrame, PolarsError> {
1145 transformations::with_columns(self, exprs, self.case_sensitive)
1146 }
1147
1148 pub fn with_columns_renamed(
1150 &self,
1151 renames: &[(String, String)],
1152 ) -> Result<DataFrame, PolarsError> {
1153 transformations::with_columns_renamed(self, renames, self.case_sensitive)
1154 }
1155
1156 pub fn na(&self) -> DataFrameNa<'_> {
1158 DataFrameNa { df: self }
1159 }
1160
1161 pub fn offset(&self, n: usize) -> Result<DataFrame, PolarsError> {
1163 transformations::offset(self, n, self.case_sensitive)
1164 }
1165
1166 pub fn transform<F>(&self, f: F) -> Result<DataFrame, PolarsError>
1168 where
1169 F: FnOnce(DataFrame) -> Result<DataFrame, PolarsError>,
1170 {
1171 transformations::transform(self, f)
1172 }
1173
1174 pub fn freq_items(&self, columns: &[&str], support: f64) -> Result<DataFrame, PolarsError> {
1176 transformations::freq_items(self, columns, support, self.case_sensitive)
1177 }
1178
1179 pub fn approx_quantile(
1181 &self,
1182 column: &str,
1183 probabilities: &[f64],
1184 ) -> Result<DataFrame, PolarsError> {
1185 transformations::approx_quantile(self, column, probabilities, self.case_sensitive)
1186 }
1187
1188 pub fn crosstab(&self, col1: &str, col2: &str) -> Result<DataFrame, PolarsError> {
1190 transformations::crosstab(self, col1, col2, self.case_sensitive)
1191 }
1192
1193 pub fn melt(&self, id_vars: &[&str], value_vars: &[&str]) -> Result<DataFrame, PolarsError> {
1195 transformations::melt(self, id_vars, value_vars, self.case_sensitive)
1196 }
1197
1198 pub fn unpivot(&self, ids: &[&str], values: &[&str]) -> Result<DataFrame, PolarsError> {
1200 transformations::melt(self, ids, values, self.case_sensitive)
1201 }
1202
1203 pub fn pivot(
1205 &self,
1206 _pivot_col: &str,
1207 _values: Option<Vec<&str>>,
1208 ) -> Result<DataFrame, PolarsError> {
1209 Err(PolarsError::InvalidOperation(
1210 "pivot is not yet implemented; use crosstab(col1, col2) for two-column cross-tabulation."
1211 .into(),
1212 ))
1213 }
1214
1215 pub fn except_all(&self, other: &DataFrame) -> Result<DataFrame, PolarsError> {
1217 transformations::except_all(self, other, self.case_sensitive)
1218 }
1219
1220 pub fn intersect_all(&self, other: &DataFrame) -> Result<DataFrame, PolarsError> {
1222 transformations::intersect_all(self, other, self.case_sensitive)
1223 }
1224
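    /// Materializes the frame and writes it as a Delta table at `path`.
    /// Only available with the `delta` feature; the fallback variant below
    /// returns an explanatory error.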
    #[cfg(feature = "delta")]
    pub fn write_delta(
        &self,
        path: impl AsRef<std::path::Path>,
        overwrite: bool,
    ) -> Result<(), PolarsError> {
        crate::delta::write_delta(self.collect_inner()?.as_ref(), path, overwrite)
    }

    #[cfg(not(feature = "delta"))]
    pub fn write_delta(
        &self,
        _path: impl AsRef<std::path::Path>,
        _overwrite: bool,
    ) -> Result<(), PolarsError> {
        Err(PolarsError::InvalidOperation(
            "Delta Lake requires the 'delta' feature. Build with --features delta.".into(),
        ))
    }

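    /// Registers this frame under `name` in the session catalog. Despite the
    /// name, no Delta log is written here; on-disk Delta output goes through
    /// `write_delta`.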
    pub fn save_as_delta_table(&self, session: &crate::session::SparkSession, name: &str) {
        session.register_table(name, self.clone());
    }

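    /// Returns a builder for writing this frame to disk. Defaults to
    /// Parquet format in overwrite mode.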
    pub fn write(&self) -> DataFrameWriter<'_> {
        DataFrameWriter {
            df: self,
            mode: WriteMode::Overwrite,
            format: WriteFormat::Parquet,
            options: HashMap::new(),
            partition_by: Vec::new(),
        }
    }
}

#[derive(Clone, Copy, PartialEq, Eq)]
pub enum WriteMode {
    Overwrite,
    Append,
}

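/// Save semantics for `save_as_table`, mirroring Spark's `SaveMode`:
/// fail if the table exists, replace it, append to it, or silently skip.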
#[derive(Clone, Copy, PartialEq, Eq)]
pub enum SaveMode {
    ErrorIfExists,
    Overwrite,
    Append,
    Ignore,
}

#[derive(Clone, Copy)]
pub enum WriteFormat {
    Parquet,
    Csv,
    Json,
}

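/// Fluent writer configured via `DataFrame::write()`: set the mode, format,
/// options (e.g. `header`, `sep` for CSV), and partition columns, then call
/// `save` or a format shortcut.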
pub struct DataFrameWriter<'a> {
    df: &'a DataFrame,
    mode: WriteMode,
    format: WriteFormat,
    options: HashMap<String, String>,
    partition_by: Vec<String>,
}

impl<'a> DataFrameWriter<'a> {
    pub fn mode(mut self, mode: WriteMode) -> Self {
        self.mode = mode;
        self
    }

    pub fn format(mut self, format: WriteFormat) -> Self {
        self.format = format;
        self
    }

    pub fn option(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.options.insert(key.into(), value.into());
        self
    }

    pub fn options(mut self, opts: impl IntoIterator<Item = (String, String)>) -> Self {
        for (k, v) in opts {
            self.options.insert(k, v);
        }
        self
    }

    pub fn partition_by(mut self, cols: impl IntoIterator<Item = impl Into<String>>) -> Self {
        self.partition_by = cols.into_iter().map(|s| s.into()).collect();
        self
    }

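    /// Registers the frame as a session table, persisting it to the
    /// warehouse directory (as `<table>/data.parquet`) when one is
    /// configured. Behavior on an already-existing table follows `mode`.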
    pub fn save_as_table(
        &self,
        session: &SparkSession,
        name: &str,
        mode: SaveMode,
    ) -> Result<(), PolarsError> {
        use polars::prelude::*;
        use std::fs;
        use std::path::Path;

        let warehouse_path = session.warehouse_dir().map(|w| Path::new(w).join(name));
        let warehouse_exists = warehouse_path.as_ref().is_some_and(|p| p.is_dir());

        fn persist_to_warehouse(
            df: &crate::dataframe::DataFrame,
            dir: &Path,
        ) -> Result<(), PolarsError> {
            use std::fs;
            fs::create_dir_all(dir).map_err(|e| {
                PolarsError::ComputeError(format!("saveAsTable: create dir: {e}").into())
            })?;
            let file_path = dir.join("data.parquet");
            df.write()
                .mode(crate::dataframe::WriteMode::Overwrite)
                .format(crate::dataframe::WriteFormat::Parquet)
                .save(&file_path)
        }

        let final_df = match mode {
            SaveMode::ErrorIfExists => {
                if session.saved_table_exists(name) || warehouse_exists {
                    return Err(PolarsError::InvalidOperation(
                        format!(
                            "Table or view '{name}' already exists. SaveMode is ErrorIfExists."
                        )
                        .into(),
                    ));
                }
                if let Some(ref p) = warehouse_path {
                    persist_to_warehouse(self.df, p)?;
                }
                self.df.clone()
            }
            SaveMode::Overwrite => {
                if let Some(ref p) = warehouse_path {
                    let _ = fs::remove_dir_all(p);
                    persist_to_warehouse(self.df, p)?;
                }
                self.df.clone()
            }
            SaveMode::Append => {
                let existing_pl = if let Some(existing) = session.get_saved_table(name) {
                    existing.collect_inner()?.as_ref().clone()
                } else if let (Some(ref p), true) = (warehouse_path.as_ref(), warehouse_exists) {
                    let data_file = p.join("data.parquet");
                    let read_path = if data_file.is_file() {
                        data_file.as_path()
                    } else {
                        p.as_ref()
                    };
                    let pl_path =
                        polars::prelude::PlRefPath::try_from_path(read_path).map_err(|e| {
                            PolarsError::ComputeError(
                                format!("saveAsTable append: path: {e}").into(),
                            )
                        })?;
                    let lf = LazyFrame::scan_parquet(pl_path, ScanArgsParquet::default()).map_err(
                        |e| {
                            PolarsError::ComputeError(
                                format!("saveAsTable append: read warehouse: {e}").into(),
                            )
                        },
                    )?;
                    lf.collect().map_err(|e| {
                        PolarsError::ComputeError(
                            format!("saveAsTable append: collect: {e}").into(),
                        )
                    })?
                } else {
                    session.register_table(name, self.df.clone());
                    if let Some(ref p) = warehouse_path {
                        persist_to_warehouse(self.df, p)?;
                    }
                    return Ok(());
                };
                let new_pl = self.df.collect_inner()?.as_ref().clone();
                let existing_cols: Vec<&str> = existing_pl
                    .get_column_names()
                    .iter()
                    .map(|s| s.as_str())
                    .collect();
                let new_cols = new_pl.get_column_names();
                let missing: Vec<_> = existing_cols
                    .iter()
                    .filter(|c| !new_cols.iter().any(|n| n.as_str() == **c))
                    .collect();
                if !missing.is_empty() {
                    return Err(PolarsError::InvalidOperation(
                        format!(
                            "saveAsTable append: new DataFrame missing columns: {:?}",
                            missing
                        )
                        .into(),
                    ));
                }
                let new_ordered = new_pl.select(existing_cols.iter().copied())?;
                let mut combined = existing_pl;
                combined.vstack_mut(&new_ordered)?;
                let merged = crate::dataframe::DataFrame::from_polars_with_options(
                    combined,
                    self.df.case_sensitive,
                );
                if let Some(ref p) = warehouse_path {
                    let _ = fs::remove_dir_all(p);
                    persist_to_warehouse(&merged, p)?;
                }
                merged
            }
            SaveMode::Ignore => {
                if session.saved_table_exists(name) || warehouse_exists {
                    return Ok(());
                }
                if let Some(ref p) = warehouse_path {
                    persist_to_warehouse(self.df, p)?;
                }
                self.df.clone()
            }
        };
        session.register_table(name, final_df);
        Ok(())
    }

    pub fn parquet(&self, path: impl AsRef<std::path::Path>) -> Result<(), PolarsError> {
        DataFrameWriter {
            df: self.df,
            mode: self.mode,
            format: WriteFormat::Parquet,
            options: self.options.clone(),
            partition_by: self.partition_by.clone(),
        }
        .save(path)
    }

    pub fn csv(&self, path: impl AsRef<std::path::Path>) -> Result<(), PolarsError> {
        DataFrameWriter {
            df: self.df,
            mode: self.mode,
            format: WriteFormat::Csv,
            options: self.options.clone(),
            partition_by: self.partition_by.clone(),
        }
        .save(path)
    }

    pub fn json(&self, path: impl AsRef<std::path::Path>) -> Result<(), PolarsError> {
        DataFrameWriter {
            df: self.df,
            mode: self.mode,
            format: WriteFormat::Json,
            options: self.options.clone(),
            partition_by: self.partition_by.clone(),
        }
        .save(path)
    }

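    /// Writes the frame to `path` in the configured format. In append mode
    /// any existing file at `path` is read back and the new rows are
    /// concatenated; with partition columns set, writing is delegated to
    /// `save_partitioned`.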
    pub fn save(&self, path: impl AsRef<std::path::Path>) -> Result<(), PolarsError> {
        use polars::prelude::*;
        let path = path.as_ref();
        let to_write: PlDataFrame = match self.mode {
            WriteMode::Overwrite => self.df.collect_inner()?.as_ref().clone(),
            WriteMode::Append => {
                if self.partition_by.is_empty() {
                    let existing: Option<PlDataFrame> = if path.exists() && path.is_file() {
                        match self.format {
                            WriteFormat::Parquet => polars::prelude::PlRefPath::try_from_path(path)
                                .ok()
                                .and_then(|pl_path| {
                                    LazyFrame::scan_parquet(pl_path, ScanArgsParquet::default())
                                        .and_then(|lf| lf.collect())
                                        .ok()
                                }),
                            WriteFormat::Csv => polars::prelude::PlRefPath::try_from_path(path)
                                .ok()
                                .and_then(|pl_path| {
                                    LazyCsvReader::new(pl_path)
                                        .with_has_header(true)
                                        .finish()
                                        .and_then(|lf| lf.collect())
                                        .ok()
                                }),
                            WriteFormat::Json => polars::prelude::PlRefPath::try_from_path(path)
                                .ok()
                                .and_then(|pl_path| {
                                    LazyJsonLineReader::new(pl_path)
                                        .finish()
                                        .and_then(|lf| lf.collect())
                                        .ok()
                                }),
                        }
                    } else {
                        None
                    };
                    match existing {
                        Some(existing) => {
                            let lfs: [LazyFrame; 2] = [
                                existing.clone().lazy(),
                                self.df.collect_inner()?.as_ref().clone().lazy(),
                            ];
                            concat(lfs, UnionArgs::default())?.collect()?
                        }
                        None => self.df.collect_inner()?.as_ref().clone(),
                    }
                } else {
                    self.df.collect_inner()?.as_ref().clone()
                }
            }
        };

        if !self.partition_by.is_empty() {
            return self.save_partitioned(path, &to_write);
        }

        match self.format {
            WriteFormat::Parquet => {
                let mut file = std::fs::File::create(path).map_err(|e| {
                    PolarsError::ComputeError(format!("write parquet create: {e}").into())
                })?;
                let mut df_mut = to_write;
                ParquetWriter::new(&mut file)
                    .finish(&mut df_mut)
                    .map_err(|e| PolarsError::ComputeError(format!("write parquet: {e}").into()))?;
            }
            WriteFormat::Csv => {
                let has_header = self
                    .options
                    .get("header")
                    .map(|v| v.eq_ignore_ascii_case("true") || v == "1")
                    .unwrap_or(true);
                let delimiter = self
                    .options
                    .get("sep")
                    .and_then(|s| s.bytes().next())
                    .unwrap_or(b',');
                let mut file = std::fs::File::create(path).map_err(|e| {
                    PolarsError::ComputeError(format!("write csv create: {e}").into())
                })?;
                CsvWriter::new(&mut file)
                    .include_header(has_header)
                    .with_separator(delimiter)
                    .finish(&mut to_write.clone())
                    .map_err(|e| PolarsError::ComputeError(format!("write csv: {e}").into()))?;
            }
            WriteFormat::Json => {
                let mut file = std::fs::File::create(path).map_err(|e| {
                    PolarsError::ComputeError(format!("write json create: {e}").into())
                })?;
                JsonWriter::new(&mut file)
                    .finish(&mut to_write.clone())
                    .map_err(|e| PolarsError::ComputeError(format!("write json: {e}").into()))?;
            }
        }
        Ok(())
    }

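    /// Hive-style partitioned write: computes the distinct values of the
    /// partition columns, filters the frame once per combination, and writes
    /// each subset (minus the partition columns) under
    /// `<path>/<col>=<value>/.../part-NNNNN.<ext>`.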
    fn save_partitioned(&self, path: &Path, to_write: &PlDataFrame) -> Result<(), PolarsError> {
        use polars::prelude::*;
        let resolved: Vec<String> = self
            .partition_by
            .iter()
            .map(|c| self.df.resolve_column_name(c))
            .collect::<Result<Vec<_>, _>>()?;
        let all_names = to_write.get_column_names();
        let data_cols: Vec<&str> = all_names
            .iter()
            .filter(|n| !resolved.iter().any(|r| r == n.as_str()))
            .map(|n| n.as_str())
            .collect();

        let unique_keys = to_write
            .select(resolved.iter().map(|s| s.as_str()).collect::<Vec<_>>())?
            .unique::<Option<&[String]>, String>(
                None,
                polars::prelude::UniqueKeepStrategy::First,
                None,
            )?;

        if self.mode == WriteMode::Overwrite && path.exists() {
            if path.is_dir() {
                std::fs::remove_dir_all(path).map_err(|e| {
                    PolarsError::ComputeError(
                        format!("write partitioned: remove_dir_all: {e}").into(),
                    )
                })?;
            } else {
                std::fs::remove_file(path).map_err(|e| {
                    PolarsError::ComputeError(format!("write partitioned: remove_file: {e}").into())
                })?;
            }
        }
        std::fs::create_dir_all(path).map_err(|e| {
            PolarsError::ComputeError(format!("write partitioned: create_dir_all: {e}").into())
        })?;

        let ext = match self.format {
            WriteFormat::Parquet => "parquet",
            WriteFormat::Csv => "csv",
            WriteFormat::Json => "json",
        };

        for row_idx in 0..unique_keys.height() {
            let row = unique_keys
                .get(row_idx)
                .ok_or_else(|| PolarsError::ComputeError("partition_row: get row".into()))?;
            let filter_expr = partition_row_to_filter_expr(&resolved, &row)?;
            let subset = to_write.clone().lazy().filter(filter_expr).collect()?;
            let subset = subset.select(data_cols.iter().copied())?;
            if subset.height() == 0 {
                continue;
            }

            let part_path: std::path::PathBuf = resolved
                .iter()
                .zip(row.iter())
                .map(|(name, av)| format!("{}={}", name, format_partition_value(av)))
                .fold(path.to_path_buf(), |p, seg| p.join(seg));
            std::fs::create_dir_all(&part_path).map_err(|e| {
                PolarsError::ComputeError(
                    format!("write partitioned: create_dir_all partition: {e}").into(),
                )
            })?;

            let file_idx = if self.mode == WriteMode::Append {
                let suffix = format!(".{ext}");
                let max_n = std::fs::read_dir(&part_path)
                    .map(|rd| {
                        rd.filter_map(Result::ok)
                            .filter_map(|e| {
                                e.file_name().to_str().and_then(|s| {
                                    s.strip_prefix("part-")
                                        .and_then(|t| t.strip_suffix(&suffix))
                                        .and_then(|t| t.parse::<u32>().ok())
                                })
                            })
                            .max()
                            .unwrap_or(0)
                    })
                    .unwrap_or(0);
                max_n + 1
            } else {
                0
            };
            let filename = format!("part-{file_idx:05}.{ext}");
            let file_path = part_path.join(&filename);

            match self.format {
                WriteFormat::Parquet => {
                    let mut file = std::fs::File::create(&file_path).map_err(|e| {
                        PolarsError::ComputeError(
                            format!("write partitioned parquet create: {e}").into(),
                        )
                    })?;
                    let mut df_mut = subset;
                    ParquetWriter::new(&mut file)
                        .finish(&mut df_mut)
                        .map_err(|e| {
                            PolarsError::ComputeError(
                                format!("write partitioned parquet: {e}").into(),
                            )
                        })?;
                }
                WriteFormat::Csv => {
                    let has_header = self
                        .options
                        .get("header")
                        .map(|v| v.eq_ignore_ascii_case("true") || v == "1")
                        .unwrap_or(true);
                    let delimiter = self
                        .options
                        .get("sep")
                        .and_then(|s| s.bytes().next())
                        .unwrap_or(b',');
                    let mut file = std::fs::File::create(&file_path).map_err(|e| {
                        PolarsError::ComputeError(
                            format!("write partitioned csv create: {e}").into(),
                        )
                    })?;
                    CsvWriter::new(&mut file)
                        .include_header(has_header)
                        .with_separator(delimiter)
                        .finish(&mut subset.clone())
                        .map_err(|e| {
                            PolarsError::ComputeError(format!("write partitioned csv: {e}").into())
                        })?;
                }
                WriteFormat::Json => {
                    let mut file = std::fs::File::create(&file_path).map_err(|e| {
                        PolarsError::ComputeError(
                            format!("write partitioned json create: {e}").into(),
                        )
                    })?;
                    JsonWriter::new(&mut file)
                        .finish(&mut subset.clone())
                        .map_err(|e| {
                            PolarsError::ComputeError(format!("write partitioned json: {e}").into())
                        })?;
                }
            }
        }
        Ok(())
    }
}

impl Clone for DataFrame {
    fn clone(&self) -> Self {
        DataFrame {
            inner: match &self.inner {
                DataFrameInner::Eager(df) => DataFrameInner::Eager(df.clone()),
                DataFrameInner::Lazy(lf) => DataFrameInner::Lazy(lf.clone()),
            },
            case_sensitive: self.case_sensitive,
            alias: self.alias.clone(),
        }
    }
}

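/// Renders a partition key value for use in a directory segment. Nulls use
/// the Hive default partition name, and path separators are replaced so the
/// value cannot escape its directory.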
fn format_partition_value(av: &AnyValue<'_>) -> String {
    let s = match av {
        AnyValue::Null => "__HIVE_DEFAULT_PARTITION__".to_string(),
        AnyValue::Boolean(b) => b.to_string(),
        AnyValue::Int32(i) => i.to_string(),
        AnyValue::Int64(i) => i.to_string(),
        AnyValue::UInt32(u) => u.to_string(),
        AnyValue::UInt64(u) => u.to_string(),
        AnyValue::Float32(f) => f.to_string(),
        AnyValue::Float64(f) => f.to_string(),
        AnyValue::String(s) => s.to_string(),
        AnyValue::StringOwned(s) => s.as_str().to_string(),
        AnyValue::Date(d) => d.to_string(),
        _ => av.to_string(),
    };
    s.replace([std::path::MAIN_SEPARATOR, '/'], "_")
}

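/// Builds a predicate that selects exactly the rows belonging to one
/// partition-key combination, AND-ing one equality (or `is_null`) clause per
/// partition column. Values without a native literal are compared as strings.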
fn partition_row_to_filter_expr(
    col_names: &[String],
    row: &[AnyValue<'_>],
) -> Result<Expr, PolarsError> {
    if col_names.len() != row.len() {
        return Err(PolarsError::ComputeError(
            format!(
                "partition_row_to_filter_expr: {} columns but {} row values",
                col_names.len(),
                row.len()
            )
            .into(),
        ));
    }
    let mut pred = None::<Expr>;
    for (name, av) in col_names.iter().zip(row.iter()) {
        let clause = match av {
            AnyValue::Null => col(name.as_str()).is_null(),
            AnyValue::Boolean(b) => col(name.as_str()).eq(lit(*b)),
            AnyValue::Int32(i) => col(name.as_str()).eq(lit(*i)),
            AnyValue::Int64(i) => col(name.as_str()).eq(lit(*i)),
            AnyValue::UInt32(u) => col(name.as_str()).eq(lit(*u)),
            AnyValue::UInt64(u) => col(name.as_str()).eq(lit(*u)),
            AnyValue::Float32(f) => col(name.as_str()).eq(lit(*f)),
            AnyValue::Float64(f) => col(name.as_str()).eq(lit(*f)),
            AnyValue::String(s) => col(name.as_str()).eq(lit(s.to_string())),
            AnyValue::StringOwned(s) => col(name.as_str()).eq(lit(s.clone())),
            _ => {
                let s = av.to_string();
                col(name.as_str()).cast(DataType::String).eq(lit(s))
            }
        };
        pred = Some(match pred {
            None => clause,
            Some(p) => p.and(clause),
        });
    }
    Ok(pred.unwrap_or_else(|| lit(true)))
}

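/// Heuristic for Spark map columns: a map is stored as a list of structs
/// with `key` and `value` fields.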
fn is_map_format(dtype: &DataType) -> bool {
    if let DataType::List(inner) = dtype {
        if let DataType::Struct(fields) = inner.as_ref() {
            let has_key = fields.iter().any(|f| f.name == "key");
            let has_value = fields.iter().any(|f| f.name == "value");
            return has_key && has_value;
        }
    }
    false
}

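/// Converts a single Polars `AnyValue` to JSON. Map-shaped lists become JSON
/// objects, other lists become arrays, structs become objects, and anything
/// unhandled becomes `null`.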
fn any_value_to_json(av: &AnyValue<'_>, dtype: &DataType) -> JsonValue {
    use serde_json::Map;
    match av {
        AnyValue::Null => JsonValue::Null,
        AnyValue::Boolean(b) => JsonValue::Bool(*b),
        AnyValue::Int32(i) => JsonValue::Number(serde_json::Number::from(*i)),
        AnyValue::Int64(i) => JsonValue::Number(serde_json::Number::from(*i)),
        AnyValue::UInt32(u) => JsonValue::Number(serde_json::Number::from(*u)),
        AnyValue::UInt64(u) => JsonValue::Number(serde_json::Number::from(*u)),
        AnyValue::Float32(f) => serde_json::Number::from_f64(f64::from(*f))
            .map(JsonValue::Number)
            .unwrap_or(JsonValue::Null),
        AnyValue::Float64(f) => serde_json::Number::from_f64(*f)
            .map(JsonValue::Number)
            .unwrap_or(JsonValue::Null),
        AnyValue::String(s) => JsonValue::String(s.to_string()),
        AnyValue::StringOwned(s) => JsonValue::String(s.to_string()),
        AnyValue::List(s) => {
            if is_map_format(dtype) {
                let mut obj = Map::new();
                for i in 0..s.len() {
                    if let Ok(elem) = s.get(i) {
                        let (k, v) = match &elem {
                            AnyValue::Struct(_, _, fields) => {
                                let mut k = None;
                                let mut v = None;
                                for (fld_av, fld) in elem._iter_struct_av().zip(fields.iter()) {
                                    if fld.name == "key" {
                                        k = fld_av
                                            .get_str()
                                            .map(|s| s.to_string())
                                            .or_else(|| Some(fld_av.to_string()));
                                    } else if fld.name == "value" {
                                        v = Some(any_value_to_json(&fld_av, &fld.dtype));
                                    }
                                }
                                (k, v)
                            }
                            AnyValue::StructOwned(payload) => {
                                let (values, fields) = &**payload;
                                let mut k = None;
                                let mut v = None;
                                for (fld_av, fld) in values.iter().zip(fields.iter()) {
                                    if fld.name == "key" {
                                        k = fld_av
                                            .get_str()
                                            .map(|s| s.to_string())
                                            .or_else(|| Some(fld_av.to_string()));
                                    } else if fld.name == "value" {
                                        v = Some(any_value_to_json(fld_av, &fld.dtype));
                                    }
                                }
                                (k, v)
                            }
                            _ => (None, None),
                        };
                        if let (Some(key), Some(val)) = (k, v) {
                            obj.insert(key, val);
                        }
                    }
                }
                JsonValue::Object(obj)
            } else {
                let inner_dtype = match dtype {
                    DataType::List(inner) => inner.as_ref(),
                    _ => dtype,
                };
                let arr: Vec<JsonValue> = (0..s.len())
                    .filter_map(|i| s.get(i).ok())
                    .map(|a| any_value_to_json(&a, inner_dtype))
                    .collect();
                JsonValue::Array(arr)
            }
        }
        AnyValue::Struct(_, _, fields) => {
            let mut obj = Map::new();
            for (fld_av, fld) in av._iter_struct_av().zip(fields.iter()) {
                obj.insert(fld.name.to_string(), any_value_to_json(&fld_av, &fld.dtype));
            }
            JsonValue::Object(obj)
        }
        AnyValue::StructOwned(payload) => {
            let (values, fields) = &**payload;
            let mut obj = Map::new();
            for (fld_av, fld) in values.iter().zip(fields.iter()) {
                obj.insert(fld.name.to_string(), any_value_to_json(fld_av, &fld.dtype));
            }
            JsonValue::Object(obj)
        }
        _ => JsonValue::Null,
    }
}

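/// No-op analogue of Spark's `broadcast` hint: there is no cluster here, so
/// the frame is returned unchanged.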
pub fn broadcast(df: &DataFrame) -> DataFrame {
    df.clone()
}

#[cfg(test)]
mod tests {
    use super::*;
    use polars::prelude::{NamedFrom, Series};

    #[test]
    fn coerce_string_numeric_root_in_filter() {
        let s = Series::new("str_col".into(), &["123", "456"]);
        let pl_df = polars::prelude::DataFrame::new_infer_height(vec![s.into()]).unwrap();
        let df = DataFrame::from_polars(pl_df);
        let expr = col("str_col").eq(lit(123i64));
        let out = df.filter(expr).unwrap();
        assert_eq!(out.count().unwrap(), 1);
    }

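    // A minimal sketch exercising the case-insensitive rename behavior of
    // `select`: the requested casing ("X") differs from the stored casing
    // ("x"), so the output column is renamed to match the request. The column
    // name and frame here are illustrative, not from the original test suite.
    #[test]
    fn select_renames_to_requested_casing() {
        let pl_df = polars::prelude::df!("x" => &[1i64, 2, 3]).unwrap();
        let df = DataFrame::from_polars(pl_df);
        let out = df.select(vec!["X"]).unwrap();
        assert_eq!(out.columns().unwrap(), vec!["X"]);
    }
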
    #[test]
    fn lazy_schema_columns_resolve_before_collect() {
        let spark = SparkSession::builder()
            .app_name("lazy_mod_tests")
            .get_or_create();
        let df = spark
            .create_dataframe(
                vec![
                    (1i64, 25i64, "a".to_string()),
                    (2i64, 30i64, "b".to_string()),
                ],
                vec!["id", "age", "name"],
            )
            .unwrap();
        assert_eq!(df.columns().unwrap(), vec!["id", "age", "name"]);
        assert_eq!(df.resolve_column_name("AGE").unwrap(), "age");
        assert!(df.get_column_dtype("id").unwrap().is_integer());
    }

    #[test]
    fn lazy_from_lazy_produces_valid_df() {
        let _spark = SparkSession::builder()
            .app_name("lazy_mod_tests")
            .get_or_create();
        let pl_df = polars::prelude::df!("x" => &[1i64, 2, 3]).unwrap();
        let df = DataFrame::from_lazy_with_options(pl_df.lazy(), false);
        assert_eq!(df.columns().unwrap(), vec!["x"]);
        assert_eq!(df.count().unwrap(), 3);
    }
}