mod aggregations;
mod joins;
mod stats;
mod transformations;

pub use aggregations::{CubeRollupData, GroupedData, PivotedGroupedData};
pub use joins::{JoinType, join};
pub use stats::DataFrameStat;
pub use transformations::{
    DataFrameNa, filter, order_by, order_by_exprs, select, select_with_exprs, with_column,
};

use crate::column::Column;
use crate::error::EngineError;
use crate::functions::SortOrder;
use crate::schema::StructType;
use crate::session::SparkSession;
use crate::type_coercion::coerce_for_pyspark_comparison;
use polars::prelude::{
    AnyValue, DataFrame as PlDataFrame, DataType, Expr, IntoLazy, LazyFrame, PlSmallStr,
    PolarsError, Schema, SchemaNamesAndDtypes, UnknownKind, col, lit,
};
use serde_json::Value as JsonValue;
use std::collections::{HashMap, HashSet};
use std::path::Path;
use std::sync::Arc;

const DEFAULT_CASE_SENSITIVE: bool = false;

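/// Backing storage for a [`DataFrame`]: either a materialized Polars frame
/// or a lazy query plan. Operations stay lazy where possible and only
/// collect on demand.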
#[allow(clippy::large_enum_variant)]
pub(crate) enum DataFrameInner {
    #[allow(dead_code)]
    Eager(Arc<PlDataFrame>),
    Lazy(LazyFrame),
}

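/// A PySpark-style DataFrame facade over a Polars [`LazyFrame`], carrying
/// the case-sensitivity flag used for column resolution and an optional
/// alias set via [`DataFrame::alias`].
///
/// A minimal usage sketch (doctest `ignore`d because the crate path is
/// assumed):
///
/// ```ignore
/// let pl_df = polars::prelude::df!("id" => &[1i64, 2, 3]).unwrap();
/// let frame = DataFrame::from_polars(pl_df);
/// assert_eq!(frame.select(vec!["id"]).unwrap().count().unwrap(), 3);
/// ```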
pub struct DataFrame {
    pub(crate) inner: DataFrameInner,
    pub(crate) case_sensitive: bool,
    pub(crate) alias: Option<String>,
}

impl DataFrame {
    pub fn from_polars(df: PlDataFrame) -> Self {
        let lf = df.lazy();
        DataFrame {
            inner: DataFrameInner::Lazy(lf),
            case_sensitive: DEFAULT_CASE_SENSITIVE,
            alias: None,
        }
    }

    pub fn from_polars_with_options(df: PlDataFrame, case_sensitive: bool) -> Self {
        let lf = df.lazy();
        DataFrame {
            inner: DataFrameInner::Lazy(lf),
            case_sensitive,
            alias: None,
        }
    }

    pub fn from_lazy(lf: LazyFrame) -> Self {
        DataFrame {
            inner: DataFrameInner::Lazy(lf),
            case_sensitive: DEFAULT_CASE_SENSITIVE,
            alias: None,
        }
    }

    pub fn from_lazy_with_options(lf: LazyFrame, case_sensitive: bool) -> Self {
        DataFrame {
            inner: DataFrameInner::Lazy(lf),
            case_sensitive,
            alias: None,
        }
    }

    pub(crate) fn with_case_insensitive_column_resolution(self) -> Self {
        DataFrame {
            inner: self.inner,
            case_sensitive: false,
            alias: self.alias,
        }
    }

    pub fn empty() -> Self {
        DataFrame {
            inner: DataFrameInner::Lazy(PlDataFrame::empty().lazy()),
            case_sensitive: DEFAULT_CASE_SENSITIVE,
            alias: None,
        }
    }

    pub(crate) fn lazy_frame(&self) -> LazyFrame {
        match &self.inner {
            DataFrameInner::Eager(df) => df.as_ref().clone().lazy(),
            DataFrameInner::Lazy(lf) => lf.clone(),
        }
    }

    pub(crate) fn collect_inner(&self) -> Result<Arc<PlDataFrame>, PolarsError> {
        match &self.inner {
            DataFrameInner::Eager(df) => Ok(df.clone()),
            DataFrameInner::Lazy(lf) => Ok(Arc::new(lf.clone().collect()?)),
        }
    }

    pub fn alias(&self, name: &str) -> Self {
        let lf = self.lazy_frame();
        DataFrame {
            inner: DataFrameInner::Lazy(lf),
            case_sensitive: self.case_sensitive,
            alias: Some(name.to_string()),
        }
    }

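    /// Rewrites bare `Expr::Column` references to the frame's canonical
    /// column names (honoring `case_sensitive`) and expands dotted names
    /// like `a.b.c` into struct field access on column `a`. Names produced
    /// by an `alias(...)` inside the expression are left untouched so they
    /// are not mistaken for physical columns.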
    pub fn resolve_expr_column_names(&self, expr: Expr) -> Result<Expr, PolarsError> {
        let df = self;
        let mut alias_output_names: HashSet<String> = HashSet::new();
        let _ = expr.clone().try_map_expr(|e| {
            if let Expr::Alias(_, name) = &e {
                alias_output_names.insert(name.as_str().to_string());
            }
            Ok(e)
        })?;
        expr.try_map_expr(move |e| {
            if let Expr::Column(name) = &e {
                let name_str = name.as_str();
                if alias_output_names.contains(name_str) {
                    return Ok(e);
                }
                if name_str.is_empty() {
                    return Ok(e);
                }
                if name_str.contains('.') {
                    let parts: Vec<&str> = name_str.split('.').collect();
                    let first = parts[0];
                    let rest = &parts[1..];
                    // `split('.')` yields an empty segment for a leading or
                    // trailing dot, so reject those explicitly instead of
                    // probing for an empty struct field name.
                    if first.is_empty() || rest.iter().any(|f| f.is_empty()) {
                        return Err(PolarsError::ColumnNotFound(
                            format!("Column '{}': empty path segment not allowed", name_str)
                                .into(),
                        ));
                    }
                    let resolved = df.resolve_column_name(first)?;
                    let mut expr = col(PlSmallStr::from(resolved.as_str()));
                    for field in rest {
                        expr = expr.struct_().field_by_name(field);
                    }
                    return Ok(expr);
                }
                let resolved = df.resolve_column_name(name_str)?;
                return Ok(Expr::Column(PlSmallStr::from(resolved.as_str())));
            }
            Ok(e)
        })
    }

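    /// Applies PySpark-style implicit casts to comparison expressions:
    /// string columns compared against numeric literals, date/datetime
    /// columns compared against string literals, and column-vs-column
    /// comparisons with mismatched dtypes. The cast rules themselves live
    /// in [`coerce_for_pyspark_comparison`]; this method only detects the
    /// shapes and routes them there, once for the root expression and once
    /// for every nested node via `try_map_expr`.
    ///
    /// A sketch of the observable effect (hypothetical data, `ignore`d):
    ///
    /// ```ignore
    /// // With a string column "str_col" holding "123" and "456", this
    /// // filter matches one row even though the literal is numeric:
    /// let out = df.filter(col("str_col").eq(lit(123i64)))?;
    /// assert_eq!(out.count()?, 1);
    /// ```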
    pub fn coerce_string_numeric_comparisons(&self, expr: Expr) -> Result<Expr, PolarsError> {
        use polars::prelude::{DataType, LiteralValue, Operator};
        use std::sync::Arc;

        fn is_numeric_literal(expr: &Expr) -> bool {
            match expr {
                Expr::Literal(lv) => {
                    let dt = lv.get_datatype();
                    dt.is_numeric()
                        || matches!(
                            dt,
                            DataType::Unknown(UnknownKind::Int(_))
                                | DataType::Unknown(UnknownKind::Float)
                        )
                }
                _ => false,
            }
        }

        fn literal_dtype(lv: &LiteralValue) -> DataType {
            let dt = lv.get_datatype();
            if matches!(
                dt,
                DataType::Unknown(UnknownKind::Int(_)) | DataType::Unknown(UnknownKind::Float)
            ) {
                DataType::Float64
            } else {
                dt
            }
        }

        let expr = {
            if let Expr::BinaryExpr { left, op, right } = &expr {
                let is_comparison_op = matches!(
                    op,
                    Operator::Eq
                        | Operator::NotEq
                        | Operator::Lt
                        | Operator::LtEq
                        | Operator::Gt
                        | Operator::GtEq
                );
                let left_is_col = matches!(&**left, Expr::Column(_));
                let right_is_col = matches!(&**right, Expr::Column(_));
                let left_is_numeric_lit =
                    matches!(&**left, Expr::Literal(_)) && is_numeric_literal(left.as_ref());
                let right_is_numeric_lit =
                    matches!(&**right, Expr::Literal(_)) && is_numeric_literal(right.as_ref());
                let left_is_string_lit = matches!(
                    &**left,
                    Expr::Literal(lv) if lv.get_datatype() == DataType::String
                );
                let right_is_string_lit = matches!(
                    &**right,
                    Expr::Literal(lv) if lv.get_datatype() == DataType::String
                );
                let root_is_col_vs_numeric = is_comparison_op
                    && ((left_is_col && right_is_numeric_lit)
                        || (right_is_col && left_is_numeric_lit));
                let root_is_col_vs_string = is_comparison_op
                    && ((left_is_col && right_is_string_lit)
                        || (right_is_col && left_is_string_lit));
                if root_is_col_vs_numeric {
                    let (new_left, new_right) = if left_is_col && right_is_numeric_lit {
                        let lit_ty = match &**right {
                            Expr::Literal(lv) => literal_dtype(lv),
                            _ => DataType::Float64,
                        };
                        coerce_for_pyspark_comparison(
                            (*left).as_ref().clone(),
                            (*right).as_ref().clone(),
                            &DataType::String,
                            &lit_ty,
                            op,
                        )
                        .map_err(|e| PolarsError::ComputeError(e.to_string().into()))?
                    } else {
                        let lit_ty = match &**left {
                            Expr::Literal(lv) => literal_dtype(lv),
                            _ => DataType::Float64,
                        };
                        coerce_for_pyspark_comparison(
                            (*left).as_ref().clone(),
                            (*right).as_ref().clone(),
                            &lit_ty,
                            &DataType::String,
                            op,
                        )
                        .map_err(|e| PolarsError::ComputeError(e.to_string().into()))?
                    };
                    Expr::BinaryExpr {
                        left: Arc::new(new_left),
                        op: *op,
                        right: Arc::new(new_right),
                    }
                } else if root_is_col_vs_string {
                    let col_name = if left_is_col {
                        if let Expr::Column(n) = &**left {
                            n.as_str()
                        } else {
                            unreachable!()
                        }
                    } else if let Expr::Column(n) = &**right {
                        n.as_str()
                    } else {
                        unreachable!()
                    };
                    if let Some(col_dtype) = self.get_column_dtype(col_name) {
                        if matches!(col_dtype, DataType::Date | DataType::Datetime(_, _)) {
                            let (left_ty, right_ty) = if left_is_col {
                                (col_dtype.clone(), DataType::String)
                            } else {
                                (DataType::String, col_dtype.clone())
                            };
                            let (new_left, new_right) = coerce_for_pyspark_comparison(
                                (*left).as_ref().clone(),
                                (*right).as_ref().clone(),
                                &left_ty,
                                &right_ty,
                                op,
                            )
                            .map_err(|e| PolarsError::ComputeError(e.to_string().into()))?;
                            return Ok(Expr::BinaryExpr {
                                left: Arc::new(new_left),
                                op: *op,
                                right: Arc::new(new_right),
                            });
                        }
                    }
                    expr
                } else if is_comparison_op && left_is_col && right_is_col {
                    let left_name = if let Expr::Column(n) = &**left {
                        n.as_str()
                    } else {
                        unreachable!()
                    };
                    let right_name = if let Expr::Column(n) = &**right {
                        n.as_str()
                    } else {
                        unreachable!()
                    };
                    if let (Some(left_ty), Some(right_ty)) = (
                        self.get_column_dtype(left_name),
                        self.get_column_dtype(right_name),
                    ) {
                        if left_ty != right_ty {
                            if let Ok((new_left, new_right)) = coerce_for_pyspark_comparison(
                                (*left).as_ref().clone(),
                                (*right).as_ref().clone(),
                                &left_ty,
                                &right_ty,
                                op,
                            ) {
                                return Ok(Expr::BinaryExpr {
                                    left: Arc::new(new_left),
                                    op: *op,
                                    right: Arc::new(new_right),
                                });
                            }
                        }
                    }
                    expr
                } else {
                    expr
                }
            } else {
                expr
            }
        };

        expr.try_map_expr(move |e| {
            if let Expr::BinaryExpr { left, op, right } = e {
                let is_comparison_op = matches!(
                    op,
                    Operator::Eq
                        | Operator::NotEq
                        | Operator::Lt
                        | Operator::LtEq
                        | Operator::Gt
                        | Operator::GtEq
                );
                if !is_comparison_op {
                    return Ok(Expr::BinaryExpr { left, op, right });
                }

                let left_is_col = matches!(&*left, Expr::Column(_));
                let right_is_col = matches!(&*right, Expr::Column(_));
                let left_is_lit = matches!(&*left, Expr::Literal(_));
                let right_is_lit = matches!(&*right, Expr::Literal(_));
                let left_is_string_lit =
                    matches!(&*left, Expr::Literal(lv) if lv.get_datatype() == DataType::String);
                let right_is_string_lit =
                    matches!(&*right, Expr::Literal(lv) if lv.get_datatype() == DataType::String);

                let left_is_numeric_lit = left_is_lit && is_numeric_literal(left.as_ref());
                let right_is_numeric_lit = right_is_lit && is_numeric_literal(right.as_ref());

                let (new_left, new_right) = if left_is_col && right_is_numeric_lit {
                    let lit_ty = match &*right {
                        Expr::Literal(lv) => literal_dtype(lv),
                        _ => DataType::Float64,
                    };
                    coerce_for_pyspark_comparison(
                        (*left).clone(),
                        (*right).clone(),
                        &DataType::String,
                        &lit_ty,
                        &op,
                    )
                    .map_err(|e| PolarsError::ComputeError(e.to_string().into()))?
                } else if right_is_col && left_is_numeric_lit {
                    let lit_ty = match &*left {
                        Expr::Literal(lv) => literal_dtype(lv),
                        _ => DataType::Float64,
                    };
                    coerce_for_pyspark_comparison(
                        (*left).clone(),
                        (*right).clone(),
                        &lit_ty,
                        &DataType::String,
                        &op,
                    )
                    .map_err(|e| PolarsError::ComputeError(e.to_string().into()))?
                } else if (left_is_col && right_is_string_lit)
                    || (right_is_col && left_is_string_lit)
                {
                    let col_name = if left_is_col {
                        if let Expr::Column(n) = &*left {
                            n.as_str()
                        } else {
                            unreachable!()
                        }
                    } else if let Expr::Column(n) = &*right {
                        n.as_str()
                    } else {
                        unreachable!()
                    };
                    if let Some(col_dtype) = self.get_column_dtype(col_name) {
                        if matches!(col_dtype, DataType::Date | DataType::Datetime(_, _)) {
                            let (left_ty, right_ty) = if left_is_col {
                                (col_dtype.clone(), DataType::String)
                            } else {
                                (DataType::String, col_dtype.clone())
                            };
                            let (new_l, new_r) = coerce_for_pyspark_comparison(
                                (*left).clone(),
                                (*right).clone(),
                                &left_ty,
                                &right_ty,
                                &op,
                            )
                            .map_err(|e| PolarsError::ComputeError(e.to_string().into()))?;
                            return Ok(Expr::BinaryExpr {
                                left: Arc::new(new_l),
                                op,
                                right: Arc::new(new_r),
                            });
                        }
                    }
                    return Ok(Expr::BinaryExpr { left, op, right });
                } else {
                    return Ok(Expr::BinaryExpr { left, op, right });
                };

                Ok(Expr::BinaryExpr {
                    left: Arc::new(new_left),
                    op,
                    right: Arc::new(new_right),
                })
            } else {
                Ok(e)
            }
        })
    }

    fn schema_or_collect(&self) -> Result<Arc<Schema>, PolarsError> {
        match &self.inner {
            DataFrameInner::Eager(df) => Ok(Arc::clone(df.schema())),
            DataFrameInner::Lazy(lf) => Ok(lf.clone().collect_schema()?),
        }
    }

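    /// Resolves `name` against the schema, returning the canonical column
    /// name as stored. With `case_sensitive == false` (the default, matching
    /// Spark's `spark.sql.caseSensitive`), matching is done on lowercased
    /// names:
    ///
    /// ```ignore
    /// // With a column "age" and case-insensitive resolution:
    /// assert_eq!(df.resolve_column_name("AGE").unwrap(), "age");
    /// ```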
    pub fn resolve_column_name(&self, name: &str) -> Result<String, PolarsError> {
        let schema = self.schema_or_collect()?;
        let names: Vec<String> = schema
            .iter_names_and_dtypes()
            .map(|(n, _)| n.to_string())
            .collect();
        if self.case_sensitive {
            if names.iter().any(|n| n == name) {
                return Ok(name.to_string());
            }
        } else {
            let name_lower = name.to_lowercase();
            for n in &names {
                if n.to_lowercase() == name_lower {
                    return Ok(n.clone());
                }
            }
        }
        let available = names.join(", ");
        Err(PolarsError::ColumnNotFound(
            format!(
                "Column '{}' not found. Available columns: [{}]. Check spelling and case sensitivity (spark.sql.caseSensitive).",
                name, available
            )
            .into(),
        ))
    }

    pub fn schema(&self) -> Result<StructType, PolarsError> {
        let s = self.schema_or_collect()?;
        Ok(StructType::from_polars_schema(&s))
    }

    pub fn schema_engine(&self) -> Result<StructType, EngineError> {
        self.schema().map_err(EngineError::from)
    }

    pub fn get_column_dtype(&self, name: &str) -> Option<DataType> {
        let resolved = self.resolve_column_name(name).ok()?;
        self.schema_or_collect()
            .ok()?
            .iter_names_and_dtypes()
            .find(|(n, _)| n.to_string() == resolved)
            .map(|(_, dt)| dt.clone())
    }

    pub fn get_column_data_type(&self, name: &str) -> Option<crate::schema::DataType> {
        let resolved = self.resolve_column_name(name).ok()?;
        let st = self.schema().ok()?;
        st.fields()
            .iter()
            .find(|f| f.name == resolved)
            .map(|f| f.data_type.clone())
    }

    pub fn columns(&self) -> Result<Vec<String>, PolarsError> {
        let schema = self.schema_or_collect()?;
        Ok(schema
            .iter_names_and_dtypes()
            .map(|(n, _)| n.to_string())
            .collect())
    }

    pub fn columns_engine(&self) -> Result<Vec<String>, EngineError> {
        self.columns().map_err(EngineError::from)
    }

    pub fn count(&self) -> Result<usize, PolarsError> {
        Ok(self.collect_inner()?.height())
    }

    pub fn count_engine(&self) -> Result<usize, EngineError> {
        self.count().map_err(EngineError::from)
    }

    pub fn show(&self, n: Option<usize>) -> Result<(), PolarsError> {
        let n = n.unwrap_or(20);
        let df = self.collect_inner()?;
        println!("{}", df.head(Some(n)));
        Ok(())
    }

    pub fn collect(&self) -> Result<Arc<PlDataFrame>, PolarsError> {
        self.collect_inner()
    }

    pub fn collect_as_json_rows_engine(
        &self,
    ) -> Result<Vec<HashMap<String, JsonValue>>, EngineError> {
        self.collect_as_json_rows().map_err(EngineError::from)
    }

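    /// Materializes the frame and converts each row into a map from column
    /// name to JSON value via [`any_value_to_json`].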
    pub fn collect_as_json_rows(&self) -> Result<Vec<HashMap<String, JsonValue>>, PolarsError> {
        let collected = self.collect_inner()?;
        let names = collected.get_column_names();
        let nrows = collected.height();
        let mut rows = Vec::with_capacity(nrows);
        for i in 0..nrows {
            let mut row = HashMap::with_capacity(names.len());
            for (col_idx, name) in names.iter().enumerate() {
                let s = collected
                    .columns()
                    .get(col_idx)
                    .ok_or_else(|| PolarsError::ComputeError("column index out of range".into()))?;
                let av = s.get(i)?;
                let jv = any_value_to_json(&av, s.dtype());
                row.insert(name.to_string(), jv);
            }
            rows.push(row);
        }
        Ok(rows)
    }

    pub fn to_json_rows(&self) -> Result<String, EngineError> {
        let rows = self.collect_as_json_rows()?;
        serde_json::to_string(&rows).map_err(Into::into)
    }

    pub fn select_exprs(&self, exprs: Vec<Expr>) -> Result<DataFrame, PolarsError> {
        transformations::select_with_exprs(self, exprs, self.case_sensitive)
    }

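    /// Projects the given columns. Under case-insensitive resolution the
    /// canonical columns are selected and then renamed back to the caller's
    /// spelling, mirroring PySpark, which preserves the requested
    /// capitalization in the output schema.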
    pub fn select(&self, cols: Vec<&str>) -> Result<DataFrame, PolarsError> {
        let resolved: Vec<String> = cols
            .iter()
            .map(|c| self.resolve_column_name(c))
            .collect::<Result<Vec<_>, _>>()?;
        let refs: Vec<&str> = resolved.iter().map(|s| s.as_str()).collect();
        let mut result = transformations::select(self, refs, self.case_sensitive)?;
        if !self.case_sensitive {
            for (requested, res) in cols.iter().zip(resolved.iter()) {
                if *requested != res.as_str() {
                    result = result.with_column_renamed(res, requested)?;
                }
            }
        }
        Ok(result)
    }

    pub fn select_engine(&self, cols: Vec<&str>) -> Result<DataFrame, EngineError> {
        self.select(cols).map_err(EngineError::from)
    }

    pub fn filter(&self, condition: Expr) -> Result<DataFrame, PolarsError> {
        transformations::filter(self, condition, self.case_sensitive)
    }

    pub fn filter_engine(&self, condition: Expr) -> Result<DataFrame, EngineError> {
        self.filter(condition).map_err(EngineError::from)
    }

    pub fn column(&self, name: &str) -> Result<Column, PolarsError> {
        let resolved = self.resolve_column_name(name)?;
        Ok(Column::new(resolved))
    }

    pub fn with_column(&self, column_name: &str, col: &Column) -> Result<DataFrame, PolarsError> {
        transformations::with_column(self, column_name, col, self.case_sensitive)
    }

    pub fn with_column_engine(
        &self,
        column_name: &str,
        col: &Column,
    ) -> Result<DataFrame, EngineError> {
        self.with_column(column_name, col)
            .map_err(EngineError::from)
    }

    pub fn with_column_expr(
        &self,
        column_name: &str,
        expr: Expr,
    ) -> Result<DataFrame, PolarsError> {
        let col = Column::from_expr(expr, None);
        self.with_column(column_name, &col)
    }

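    /// Groups by the given columns, returning a [`GroupedData`] that keeps
    /// the lazy plan together with the resolved grouping column names for a
    /// later aggregation step.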
    pub fn group_by(&self, column_names: Vec<&str>) -> Result<GroupedData, PolarsError> {
        use polars::prelude::*;
        let resolved: Vec<String> = column_names
            .iter()
            .map(|c| self.resolve_column_name(c))
            .collect::<Result<Vec<_>, _>>()?;
        let exprs: Vec<Expr> = resolved.iter().map(|name| col(name.as_str())).collect();
        let lf = self.lazy_frame();
        let lazy_grouped = lf.clone().group_by(exprs);
        Ok(GroupedData {
            lf,
            lazy_grouped,
            grouping_cols: resolved,
            case_sensitive: self.case_sensitive,
        })
    }

    pub fn group_by_engine(&self, column_names: Vec<&str>) -> Result<GroupedData, EngineError> {
        self.group_by(column_names).map_err(EngineError::from)
    }

    pub fn group_by_exprs(
        &self,
        exprs: Vec<Expr>,
        grouping_col_names: Vec<String>,
    ) -> Result<GroupedData, PolarsError> {
        use polars::prelude::*;
        if exprs.len() != grouping_col_names.len() {
            return Err(PolarsError::ComputeError(
                format!(
                    "group_by_exprs: {} exprs but {} names",
                    exprs.len(),
                    grouping_col_names.len()
                )
                .into(),
            ));
        }
        let resolved: Vec<Expr> = exprs
            .into_iter()
            .map(|e| self.resolve_expr_column_names(e))
            .collect::<Result<Vec<_>, _>>()?;
        let lf = self.lazy_frame();
        let lazy_grouped = lf.clone().group_by(resolved);
        Ok(GroupedData {
            lf,
            lazy_grouped,
            grouping_cols: grouping_col_names,
            case_sensitive: self.case_sensitive,
        })
    }

    pub fn cube(&self, column_names: Vec<&str>) -> Result<CubeRollupData, PolarsError> {
        let resolved: Vec<String> = column_names
            .iter()
            .map(|c| self.resolve_column_name(c))
            .collect::<Result<Vec<_>, _>>()?;
        Ok(CubeRollupData {
            lf: self.lazy_frame(),
            grouping_cols: resolved,
            case_sensitive: self.case_sensitive,
            is_cube: true,
        })
    }

    pub fn rollup(&self, column_names: Vec<&str>) -> Result<CubeRollupData, PolarsError> {
        let resolved: Vec<String> = column_names
            .iter()
            .map(|c| self.resolve_column_name(c))
            .collect::<Result<Vec<_>, _>>()?;
        Ok(CubeRollupData {
            lf: self.lazy_frame(),
            grouping_cols: resolved,
            case_sensitive: self.case_sensitive,
            is_cube: false,
        })
    }

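    /// Global (ungrouped) aggregation: resolves column names inside each
    /// expression, disambiguates duplicate output names, and collects
    /// eagerly into a new frame.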
    pub fn agg(&self, aggregations: Vec<Expr>) -> Result<DataFrame, PolarsError> {
        let resolved: Vec<Expr> = aggregations
            .into_iter()
            .map(|e| self.resolve_expr_column_names(e))
            .collect::<Result<Vec<_>, _>>()?;
        let disambiguated = aggregations::disambiguate_agg_output_names(resolved);
        let pl_df = self.lazy_frame().select(disambiguated).collect()?;
        Ok(Self::from_polars_with_options(pl_df, self.case_sensitive))
    }

    pub fn join(
        &self,
        other: &DataFrame,
        on: Vec<&str>,
        how: JoinType,
    ) -> Result<DataFrame, PolarsError> {
        let resolved: Vec<String> = on
            .iter()
            .map(|c| self.resolve_column_name(c))
            .collect::<Result<Vec<_>, _>>()?;
        let on_refs: Vec<&str> = resolved.iter().map(|s| s.as_str()).collect();
        join(self, other, on_refs, how, self.case_sensitive)
    }

    pub fn order_by(
        &self,
        column_names: Vec<&str>,
        ascending: Vec<bool>,
    ) -> Result<DataFrame, PolarsError> {
        let resolved: Vec<String> = column_names
            .iter()
            .map(|c| self.resolve_column_name(c))
            .collect::<Result<Vec<_>, _>>()?;
        let refs: Vec<&str> = resolved.iter().map(|s| s.as_str()).collect();
        transformations::order_by(self, refs, ascending, self.case_sensitive)
    }

    pub fn order_by_exprs(&self, sort_orders: Vec<SortOrder>) -> Result<DataFrame, PolarsError> {
        transformations::order_by_exprs(self, sort_orders, self.case_sensitive)
    }

    pub fn union(&self, other: &DataFrame) -> Result<DataFrame, PolarsError> {
        transformations::union(self, other, self.case_sensitive)
    }

    pub fn union_all(&self, other: &DataFrame) -> Result<DataFrame, PolarsError> {
        self.union(other)
    }

    pub fn union_by_name(
        &self,
        other: &DataFrame,
        allow_missing_columns: bool,
    ) -> Result<DataFrame, PolarsError> {
        transformations::union_by_name(self, other, allow_missing_columns, self.case_sensitive)
    }

    pub fn distinct(&self, subset: Option<Vec<&str>>) -> Result<DataFrame, PolarsError> {
        transformations::distinct(self, subset, self.case_sensitive)
    }

    pub fn drop(&self, columns: Vec<&str>) -> Result<DataFrame, PolarsError> {
        transformations::drop(self, columns, self.case_sensitive)
    }

    pub fn dropna(
        &self,
        subset: Option<Vec<&str>>,
        how: &str,
        thresh: Option<usize>,
    ) -> Result<DataFrame, PolarsError> {
        transformations::dropna(self, subset, how, thresh, self.case_sensitive)
    }

    pub fn fillna(&self, value: Expr, subset: Option<Vec<&str>>) -> Result<DataFrame, PolarsError> {
        transformations::fillna(self, value, subset, self.case_sensitive)
    }

    pub fn limit(&self, n: usize) -> Result<DataFrame, PolarsError> {
        transformations::limit(self, n, self.case_sensitive)
    }

    pub fn limit_engine(&self, n: usize) -> Result<DataFrame, EngineError> {
        self.limit(n).map_err(EngineError::from)
    }

    pub fn with_column_renamed(
        &self,
        old_name: &str,
        new_name: &str,
    ) -> Result<DataFrame, PolarsError> {
        transformations::with_column_renamed(self, old_name, new_name, self.case_sensitive)
    }

    pub fn replace(
        &self,
        column_name: &str,
        old_value: Expr,
        new_value: Expr,
    ) -> Result<DataFrame, PolarsError> {
        transformations::replace(self, column_name, old_value, new_value, self.case_sensitive)
    }

    pub fn cross_join(&self, other: &DataFrame) -> Result<DataFrame, PolarsError> {
        transformations::cross_join(self, other, self.case_sensitive)
    }

    pub fn describe(&self) -> Result<DataFrame, PolarsError> {
        transformations::describe(self, self.case_sensitive)
    }

    pub fn cache(&self) -> Result<DataFrame, PolarsError> {
        Ok(self.clone())
    }

    pub fn persist(&self) -> Result<DataFrame, PolarsError> {
        Ok(self.clone())
    }

    pub fn unpersist(&self) -> Result<DataFrame, PolarsError> {
        Ok(self.clone())
    }

    pub fn subtract(&self, other: &DataFrame) -> Result<DataFrame, PolarsError> {
        transformations::subtract(self, other, self.case_sensitive)
    }

    pub fn intersect(&self, other: &DataFrame) -> Result<DataFrame, PolarsError> {
        transformations::intersect(self, other, self.case_sensitive)
    }

    pub fn sample(
        &self,
        with_replacement: bool,
        fraction: f64,
        seed: Option<u64>,
    ) -> Result<DataFrame, PolarsError> {
        transformations::sample(self, with_replacement, fraction, seed, self.case_sensitive)
    }

    pub fn random_split(
        &self,
        weights: &[f64],
        seed: Option<u64>,
    ) -> Result<Vec<DataFrame>, PolarsError> {
        transformations::random_split(self, weights, seed, self.case_sensitive)
    }

    pub fn sample_by(
        &self,
        col_name: &str,
        fractions: &[(Expr, f64)],
        seed: Option<u64>,
    ) -> Result<DataFrame, PolarsError> {
        transformations::sample_by(self, col_name, fractions, seed, self.case_sensitive)
    }

    pub fn first(&self) -> Result<DataFrame, PolarsError> {
        transformations::first(self, self.case_sensitive)
    }

    pub fn head(&self, n: usize) -> Result<DataFrame, PolarsError> {
        transformations::head(self, n, self.case_sensitive)
    }

    pub fn take(&self, n: usize) -> Result<DataFrame, PolarsError> {
        transformations::take(self, n, self.case_sensitive)
    }

    pub fn tail(&self, n: usize) -> Result<DataFrame, PolarsError> {
        transformations::tail(self, n, self.case_sensitive)
    }

    pub fn is_empty(&self) -> bool {
        transformations::is_empty(self)
    }

    pub fn to_df(&self, names: Vec<&str>) -> Result<DataFrame, PolarsError> {
        transformations::to_df(self, &names, self.case_sensitive)
    }

    pub fn stat(&self) -> DataFrameStat<'_> {
        DataFrameStat { df: self }
    }

    pub fn corr(&self) -> Result<DataFrame, PolarsError> {
        self.stat().corr_matrix()
    }

    pub fn corr_cols(&self, col1: &str, col2: &str) -> Result<f64, PolarsError> {
        self.stat().corr(col1, col2)
    }

    pub fn cov_cols(&self, col1: &str, col2: &str) -> Result<f64, PolarsError> {
        self.stat().cov(col1, col2)
    }

    pub fn summary(&self) -> Result<DataFrame, PolarsError> {
        self.describe()
    }

    pub fn to_json(&self) -> Result<Vec<String>, PolarsError> {
        transformations::to_json(self)
    }

    pub fn explain(&self) -> String {
        transformations::explain(self)
    }

    pub fn print_schema(&self) -> Result<String, PolarsError> {
        transformations::print_schema(self)
    }

    pub fn checkpoint(&self) -> Result<DataFrame, PolarsError> {
        Ok(self.clone())
    }

    pub fn local_checkpoint(&self) -> Result<DataFrame, PolarsError> {
        Ok(self.clone())
    }

    pub fn repartition(&self, _num_partitions: usize) -> Result<DataFrame, PolarsError> {
        Ok(self.clone())
    }

    pub fn repartition_by_range(
        &self,
        _num_partitions: usize,
        _cols: Vec<&str>,
    ) -> Result<DataFrame, PolarsError> {
        Ok(self.clone())
    }

    pub fn dtypes(&self) -> Result<Vec<(String, String)>, PolarsError> {
        let schema = self.schema_or_collect()?;
        Ok(schema
            .iter_names_and_dtypes()
            .map(|(name, dtype)| (name.to_string(), format!("{dtype:?}")))
            .collect())
    }

    pub fn sort_within_partitions(
        &self,
        _cols: &[crate::functions::SortOrder],
    ) -> Result<DataFrame, PolarsError> {
        Ok(self.clone())
    }

    pub fn coalesce(&self, _num_partitions: usize) -> Result<DataFrame, PolarsError> {
        Ok(self.clone())
    }

    pub fn hint(&self, _name: &str, _params: &[i32]) -> Result<DataFrame, PolarsError> {
        Ok(self.clone())
    }

    pub fn is_local(&self) -> bool {
        true
    }

    pub fn input_files(&self) -> Vec<String> {
        Vec::new()
    }

    pub fn same_semantics(&self, _other: &DataFrame) -> bool {
        false
    }

    pub fn semantic_hash(&self) -> u64 {
        0
    }

    pub fn observe(&self, _name: &str, _expr: Expr) -> Result<DataFrame, PolarsError> {
        Ok(self.clone())
    }

    pub fn with_watermark(
        &self,
        _event_time: &str,
        _delay: &str,
    ) -> Result<DataFrame, PolarsError> {
        Ok(self.clone())
    }

    pub fn select_expr(&self, exprs: &[String]) -> Result<DataFrame, PolarsError> {
        transformations::select_expr(self, exprs, self.case_sensitive)
    }

    pub fn col_regex(&self, pattern: &str) -> Result<DataFrame, PolarsError> {
        transformations::col_regex(self, pattern, self.case_sensitive)
    }

    pub fn with_columns(&self, exprs: &[(String, Column)]) -> Result<DataFrame, PolarsError> {
        transformations::with_columns(self, exprs, self.case_sensitive)
    }

    pub fn with_columns_renamed(
        &self,
        renames: &[(String, String)],
    ) -> Result<DataFrame, PolarsError> {
        transformations::with_columns_renamed(self, renames, self.case_sensitive)
    }

    pub fn na(&self) -> DataFrameNa<'_> {
        DataFrameNa { df: self }
    }

    pub fn offset(&self, n: usize) -> Result<DataFrame, PolarsError> {
        transformations::offset(self, n, self.case_sensitive)
    }

    pub fn transform<F>(&self, f: F) -> Result<DataFrame, PolarsError>
    where
        F: FnOnce(DataFrame) -> Result<DataFrame, PolarsError>,
    {
        transformations::transform(self, f)
    }

    pub fn freq_items(&self, columns: &[&str], support: f64) -> Result<DataFrame, PolarsError> {
        transformations::freq_items(self, columns, support, self.case_sensitive)
    }

    pub fn approx_quantile(
        &self,
        column: &str,
        probabilities: &[f64],
    ) -> Result<DataFrame, PolarsError> {
        transformations::approx_quantile(self, column, probabilities, self.case_sensitive)
    }

    pub fn crosstab(&self, col1: &str, col2: &str) -> Result<DataFrame, PolarsError> {
        transformations::crosstab(self, col1, col2, self.case_sensitive)
    }

    pub fn melt(&self, id_vars: &[&str], value_vars: &[&str]) -> Result<DataFrame, PolarsError> {
        transformations::melt(self, id_vars, value_vars, self.case_sensitive)
    }

    pub fn unpivot(&self, ids: &[&str], values: &[&str]) -> Result<DataFrame, PolarsError> {
        transformations::melt(self, ids, values, self.case_sensitive)
    }

    pub fn pivot(
        &self,
        _pivot_col: &str,
        _values: Option<Vec<&str>>,
    ) -> Result<DataFrame, PolarsError> {
        Err(PolarsError::InvalidOperation(
            "pivot is not yet implemented; use crosstab(col1, col2) for two-column cross-tabulation."
                .into(),
        ))
    }

    pub fn except_all(&self, other: &DataFrame) -> Result<DataFrame, PolarsError> {
        transformations::except_all(self, other, self.case_sensitive)
    }

    pub fn intersect_all(&self, other: &DataFrame) -> Result<DataFrame, PolarsError> {
        transformations::intersect_all(self, other, self.case_sensitive)
    }

    #[cfg(feature = "delta")]
    pub fn write_delta(
        &self,
        path: impl AsRef<std::path::Path>,
        overwrite: bool,
    ) -> Result<(), PolarsError> {
        crate::delta::write_delta(self.collect_inner()?.as_ref(), path, overwrite)
    }

    #[cfg(not(feature = "delta"))]
    pub fn write_delta(
        &self,
        _path: impl AsRef<std::path::Path>,
        _overwrite: bool,
    ) -> Result<(), PolarsError> {
        Err(PolarsError::InvalidOperation(
            "Delta Lake requires the 'delta' feature. Build with --features delta.".into(),
        ))
    }

    pub fn save_as_delta_table(&self, session: &crate::session::SparkSession, name: &str) {
        session.register_table(name, self.clone());
    }

    pub fn write(&self) -> DataFrameWriter<'_> {
        DataFrameWriter {
            df: self,
            mode: WriteMode::Overwrite,
            format: WriteFormat::Parquet,
            options: HashMap::new(),
            partition_by: Vec::new(),
        }
    }
}

#[derive(Clone, Copy, PartialEq, Eq)]
pub enum WriteMode {
    Overwrite,
    Append,
}

#[derive(Clone, Copy, PartialEq, Eq)]
pub enum SaveMode {
    ErrorIfExists,
    Overwrite,
    Append,
    Ignore,
}

#[derive(Clone, Copy)]
pub enum WriteFormat {
    Parquet,
    Csv,
    Json,
}

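/// Builder for writing a [`DataFrame`] to disk, loosely following PySpark's
/// `df.write` API.
///
/// A minimal sketch (the path is a placeholder, doctest `ignore`d):
///
/// ```ignore
/// df.write()
///     .mode(WriteMode::Append)
///     .format(WriteFormat::Csv)
///     .option("sep", ";")
///     .save("out/data.csv")?;
/// ```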
pub struct DataFrameWriter<'a> {
    df: &'a DataFrame,
    mode: WriteMode,
    format: WriteFormat,
    options: HashMap<String, String>,
    partition_by: Vec<String>,
}

impl<'a> DataFrameWriter<'a> {
    pub fn mode(mut self, mode: WriteMode) -> Self {
        self.mode = mode;
        self
    }

    pub fn format(mut self, format: WriteFormat) -> Self {
        self.format = format;
        self
    }

    pub fn option(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.options.insert(key.into(), value.into());
        self
    }

    pub fn options(mut self, opts: impl IntoIterator<Item = (String, String)>) -> Self {
        for (k, v) in opts {
            self.options.insert(k, v);
        }
        self
    }

    pub fn partition_by(mut self, cols: impl IntoIterator<Item = impl Into<String>>) -> Self {
        self.partition_by = cols.into_iter().map(|s| s.into()).collect();
        self
    }

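    /// Registers the frame as a session table and, when the session has a
    /// warehouse directory, persists it as `<warehouse>/<name>/data.parquet`.
    /// `SaveMode` governs conflicts: `ErrorIfExists` fails, `Overwrite`
    /// replaces, `Append` stacks onto existing rows (reordering new columns
    /// to match), and `Ignore` is a no-op when the table already exists.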
    pub fn save_as_table(
        &self,
        session: &SparkSession,
        name: &str,
        mode: SaveMode,
    ) -> Result<(), PolarsError> {
        use polars::prelude::*;
        use std::fs;
        use std::path::Path;

        let warehouse_path = session.warehouse_dir().map(|w| Path::new(w).join(name));
        let warehouse_exists = warehouse_path.as_ref().is_some_and(|p| p.is_dir());

        fn persist_to_warehouse(
            df: &crate::dataframe::DataFrame,
            dir: &Path,
        ) -> Result<(), PolarsError> {
            use std::fs;
            fs::create_dir_all(dir).map_err(|e| {
                PolarsError::ComputeError(format!("saveAsTable: create dir: {e}").into())
            })?;
            let file_path = dir.join("data.parquet");
            df.write()
                .mode(crate::dataframe::WriteMode::Overwrite)
                .format(crate::dataframe::WriteFormat::Parquet)
                .save(&file_path)
        }

        let final_df = match mode {
            SaveMode::ErrorIfExists => {
                if session.saved_table_exists(name) || warehouse_exists {
                    return Err(PolarsError::InvalidOperation(
                        format!(
                            "Table or view '{name}' already exists. SaveMode is ErrorIfExists."
                        )
                        .into(),
                    ));
                }
                if let Some(ref p) = warehouse_path {
                    persist_to_warehouse(self.df, p)?;
                }
                self.df.clone()
            }
            SaveMode::Overwrite => {
                if let Some(ref p) = warehouse_path {
                    let _ = fs::remove_dir_all(p);
                    persist_to_warehouse(self.df, p)?;
                }
                self.df.clone()
            }
            SaveMode::Append => {
                let existing_pl = if let Some(existing) = session.get_saved_table(name) {
                    existing.collect_inner()?.as_ref().clone()
                } else if let (Some(ref p), true) = (warehouse_path.as_ref(), warehouse_exists) {
                    let data_file = p.join("data.parquet");
                    let read_path = if data_file.is_file() {
                        data_file.as_path()
                    } else {
                        p.as_ref()
                    };
                    let pl_path =
                        polars::prelude::PlRefPath::try_from_path(read_path).map_err(|e| {
                            PolarsError::ComputeError(
                                format!("saveAsTable append: path: {e}").into(),
                            )
                        })?;
                    let lf = LazyFrame::scan_parquet(pl_path, ScanArgsParquet::default())
                        .map_err(|e| {
                            PolarsError::ComputeError(
                                format!("saveAsTable append: read warehouse: {e}").into(),
                            )
                        })?;
                    lf.collect().map_err(|e| {
                        PolarsError::ComputeError(
                            format!("saveAsTable append: collect: {e}").into(),
                        )
                    })?
                } else {
                    session.register_table(name, self.df.clone());
                    if let Some(ref p) = warehouse_path {
                        persist_to_warehouse(self.df, p)?;
                    }
                    return Ok(());
                };
                let new_pl = self.df.collect_inner()?.as_ref().clone();
                let existing_cols: Vec<&str> = existing_pl
                    .get_column_names()
                    .iter()
                    .map(|s| s.as_str())
                    .collect();
                let new_cols = new_pl.get_column_names();
                let missing: Vec<_> = existing_cols
                    .iter()
                    .filter(|c| !new_cols.iter().any(|n| n.as_str() == **c))
                    .collect();
                if !missing.is_empty() {
                    return Err(PolarsError::InvalidOperation(
                        format!(
                            "saveAsTable append: new DataFrame missing columns: {:?}",
                            missing
                        )
                        .into(),
                    ));
                }
                let new_ordered = new_pl.select(existing_cols.iter().copied())?;
                let mut combined = existing_pl;
                combined.vstack_mut(&new_ordered)?;
                let merged = crate::dataframe::DataFrame::from_polars_with_options(
                    combined,
                    self.df.case_sensitive,
                );
                if let Some(ref p) = warehouse_path {
                    let _ = fs::remove_dir_all(p);
                    persist_to_warehouse(&merged, p)?;
                }
                merged
            }
            SaveMode::Ignore => {
                if session.saved_table_exists(name) || warehouse_exists {
                    return Ok(());
                }
                if let Some(ref p) = warehouse_path {
                    persist_to_warehouse(self.df, p)?;
                }
                self.df.clone()
            }
        };
        session.register_table(name, final_df);
        Ok(())
    }

    pub fn parquet(&self, path: impl AsRef<std::path::Path>) -> Result<(), PolarsError> {
        DataFrameWriter {
            df: self.df,
            mode: self.mode,
            format: WriteFormat::Parquet,
            options: self.options.clone(),
            partition_by: self.partition_by.clone(),
        }
        .save(path)
    }

    pub fn csv(&self, path: impl AsRef<std::path::Path>) -> Result<(), PolarsError> {
        DataFrameWriter {
            df: self.df,
            mode: self.mode,
            format: WriteFormat::Csv,
            options: self.options.clone(),
            partition_by: self.partition_by.clone(),
        }
        .save(path)
    }

    pub fn json(&self, path: impl AsRef<std::path::Path>) -> Result<(), PolarsError> {
        DataFrameWriter {
            df: self.df,
            mode: self.mode,
            format: WriteFormat::Json,
            options: self.options.clone(),
            partition_by: self.partition_by.clone(),
        }
        .save(path)
    }

    pub fn save(&self, path: impl AsRef<std::path::Path>) -> Result<(), PolarsError> {
        use polars::prelude::*;
        let path = path.as_ref();
        let to_write: PlDataFrame = match self.mode {
            WriteMode::Overwrite => self.df.collect_inner()?.as_ref().clone(),
            WriteMode::Append => {
                if self.partition_by.is_empty() {
                    let existing: Option<PlDataFrame> = if path.exists() && path.is_file() {
                        match self.format {
                            WriteFormat::Parquet => polars::prelude::PlRefPath::try_from_path(path)
                                .ok()
                                .and_then(|pl_path| {
                                    LazyFrame::scan_parquet(pl_path, ScanArgsParquet::default())
                                        .and_then(|lf| lf.collect())
                                        .ok()
                                }),
                            WriteFormat::Csv => polars::prelude::PlRefPath::try_from_path(path)
                                .ok()
                                .and_then(|pl_path| {
                                    LazyCsvReader::new(pl_path)
                                        .with_has_header(true)
                                        .finish()
                                        .and_then(|lf| lf.collect())
                                        .ok()
                                }),
                            WriteFormat::Json => polars::prelude::PlRefPath::try_from_path(path)
                                .ok()
                                .and_then(|pl_path| {
                                    LazyJsonLineReader::new(pl_path)
                                        .finish()
                                        .and_then(|lf| lf.collect())
                                        .ok()
                                }),
                        }
                    } else {
                        None
                    };
                    match existing {
                        Some(existing) => {
                            let lfs: [LazyFrame; 2] = [
                                existing.clone().lazy(),
                                self.df.collect_inner()?.as_ref().clone().lazy(),
                            ];
                            concat(lfs, UnionArgs::default())?.collect()?
                        }
                        None => self.df.collect_inner()?.as_ref().clone(),
                    }
                } else {
                    self.df.collect_inner()?.as_ref().clone()
                }
            }
        };

        if !self.partition_by.is_empty() {
            return self.save_partitioned(path, &to_write);
        }

        match self.format {
            WriteFormat::Parquet => {
                let mut file = std::fs::File::create(path).map_err(|e| {
                    PolarsError::ComputeError(format!("write parquet create: {e}").into())
                })?;
                let mut df_mut = to_write;
                ParquetWriter::new(&mut file)
                    .finish(&mut df_mut)
                    .map_err(|e| PolarsError::ComputeError(format!("write parquet: {e}").into()))?;
            }
            WriteFormat::Csv => {
                let has_header = self
                    .options
                    .get("header")
                    .map(|v| v.eq_ignore_ascii_case("true") || v == "1")
                    .unwrap_or(true);
                let delimiter = self
                    .options
                    .get("sep")
                    .and_then(|s| s.bytes().next())
                    .unwrap_or(b',');
                let mut file = std::fs::File::create(path).map_err(|e| {
                    PolarsError::ComputeError(format!("write csv create: {e}").into())
                })?;
                CsvWriter::new(&mut file)
                    .include_header(has_header)
                    .with_separator(delimiter)
                    .finish(&mut to_write.clone())
                    .map_err(|e| PolarsError::ComputeError(format!("write csv: {e}").into()))?;
            }
            WriteFormat::Json => {
                let mut file = std::fs::File::create(path).map_err(|e| {
                    PolarsError::ComputeError(format!("write json create: {e}").into())
                })?;
                JsonWriter::new(&mut file)
                    .finish(&mut to_write.clone())
                    .map_err(|e| PolarsError::ComputeError(format!("write json: {e}").into()))?;
            }
        }
        Ok(())
    }

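    /// Writes one file per distinct partition-key combination using
    /// Hive-style `key=value` directory segments, e.g. a path shaped like
    /// `out/year=2024/part-00000.parquet` (illustrative, not taken from real
    /// data). `Append` picks the next free `part-NNNNN` index per directory;
    /// `Overwrite` clears the target first.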
    fn save_partitioned(&self, path: &Path, to_write: &PlDataFrame) -> Result<(), PolarsError> {
        use polars::prelude::*;
        let resolved: Vec<String> = self
            .partition_by
            .iter()
            .map(|c| self.df.resolve_column_name(c))
            .collect::<Result<Vec<_>, _>>()?;
        let all_names = to_write.get_column_names();
        let data_cols: Vec<&str> = all_names
            .iter()
            .filter(|n| !resolved.iter().any(|r| r == n.as_str()))
            .map(|n| n.as_str())
            .collect();

        let unique_keys = to_write
            .select(resolved.iter().map(|s| s.as_str()).collect::<Vec<_>>())?
            .unique::<Option<&[String]>, String>(
                None,
                polars::prelude::UniqueKeepStrategy::First,
                None,
            )?;

        if self.mode == WriteMode::Overwrite && path.exists() {
            if path.is_dir() {
                std::fs::remove_dir_all(path).map_err(|e| {
                    PolarsError::ComputeError(
                        format!("write partitioned: remove_dir_all: {e}").into(),
                    )
                })?;
            } else {
                std::fs::remove_file(path).map_err(|e| {
                    PolarsError::ComputeError(format!("write partitioned: remove_file: {e}").into())
                })?;
            }
        }
        std::fs::create_dir_all(path).map_err(|e| {
            PolarsError::ComputeError(format!("write partitioned: create_dir_all: {e}").into())
        })?;

        let ext = match self.format {
            WriteFormat::Parquet => "parquet",
            WriteFormat::Csv => "csv",
            WriteFormat::Json => "json",
        };

        for row_idx in 0..unique_keys.height() {
            let row = unique_keys
                .get(row_idx)
                .ok_or_else(|| PolarsError::ComputeError("partition_row: get row".into()))?;
            let filter_expr = partition_row_to_filter_expr(&resolved, &row)?;
            let subset = to_write.clone().lazy().filter(filter_expr).collect()?;
            let subset = subset.select(data_cols.iter().copied())?;
            if subset.height() == 0 {
                continue;
            }

            let part_path: std::path::PathBuf = resolved
                .iter()
                .zip(row.iter())
                .map(|(name, av)| format!("{}={}", name, format_partition_value(av)))
                .fold(path.to_path_buf(), |p, seg| p.join(seg));
            std::fs::create_dir_all(&part_path).map_err(|e| {
                PolarsError::ComputeError(
                    format!("write partitioned: create_dir_all partition: {e}").into(),
                )
            })?;

            let file_idx = if self.mode == WriteMode::Append {
                let suffix = format!(".{ext}");
                let max_n = std::fs::read_dir(&part_path)
                    .map(|rd| {
                        rd.filter_map(Result::ok)
                            .filter_map(|e| {
                                e.file_name().to_str().and_then(|s| {
                                    s.strip_prefix("part-")
                                        .and_then(|t| t.strip_suffix(&suffix))
                                        .and_then(|t| t.parse::<u32>().ok())
                                })
                            })
                            .max()
                            .unwrap_or(0)
                    })
                    .unwrap_or(0);
                max_n + 1
            } else {
                0
            };
            let filename = format!("part-{file_idx:05}.{ext}");
            let file_path = part_path.join(&filename);

            match self.format {
                WriteFormat::Parquet => {
                    let mut file = std::fs::File::create(&file_path).map_err(|e| {
                        PolarsError::ComputeError(
                            format!("write partitioned parquet create: {e}").into(),
                        )
                    })?;
                    let mut df_mut = subset;
                    ParquetWriter::new(&mut file)
                        .finish(&mut df_mut)
                        .map_err(|e| {
                            PolarsError::ComputeError(
                                format!("write partitioned parquet: {e}").into(),
                            )
                        })?;
                }
                WriteFormat::Csv => {
                    let has_header = self
                        .options
                        .get("header")
                        .map(|v| v.eq_ignore_ascii_case("true") || v == "1")
                        .unwrap_or(true);
                    let delimiter = self
                        .options
                        .get("sep")
                        .and_then(|s| s.bytes().next())
                        .unwrap_or(b',');
                    let mut file = std::fs::File::create(&file_path).map_err(|e| {
                        PolarsError::ComputeError(
                            format!("write partitioned csv create: {e}").into(),
                        )
                    })?;
                    CsvWriter::new(&mut file)
                        .include_header(has_header)
                        .with_separator(delimiter)
                        .finish(&mut subset.clone())
                        .map_err(|e| {
                            PolarsError::ComputeError(format!("write partitioned csv: {e}").into())
                        })?;
                }
                WriteFormat::Json => {
                    let mut file = std::fs::File::create(&file_path).map_err(|e| {
                        PolarsError::ComputeError(
                            format!("write partitioned json create: {e}").into(),
                        )
                    })?;
                    JsonWriter::new(&mut file)
                        .finish(&mut subset.clone())
                        .map_err(|e| {
                            PolarsError::ComputeError(format!("write partitioned json: {e}").into())
                        })?;
                }
            }
        }
        Ok(())
    }
}

impl Clone for DataFrame {
    fn clone(&self) -> Self {
        DataFrame {
            inner: match &self.inner {
                DataFrameInner::Eager(df) => DataFrameInner::Eager(df.clone()),
                DataFrameInner::Lazy(lf) => DataFrameInner::Lazy(lf.clone()),
            },
            case_sensitive: self.case_sensitive,
            alias: self.alias.clone(),
        }
    }
}

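/// Renders a partition key for use in a `key=value` path segment. Nulls map
/// to Hive's `__HIVE_DEFAULT_PARTITION__`, and path separators are replaced
/// with `_` so values cannot escape the partition directory.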
fn format_partition_value(av: &AnyValue<'_>) -> String {
    let s = match av {
        AnyValue::Null => "__HIVE_DEFAULT_PARTITION__".to_string(),
        AnyValue::Boolean(b) => b.to_string(),
        AnyValue::Int32(i) => i.to_string(),
        AnyValue::Int64(i) => i.to_string(),
        AnyValue::UInt32(u) => u.to_string(),
        AnyValue::UInt64(u) => u.to_string(),
        AnyValue::Float32(f) => f.to_string(),
        AnyValue::Float64(f) => f.to_string(),
        AnyValue::String(s) => s.to_string(),
        AnyValue::StringOwned(s) => s.as_str().to_string(),
        AnyValue::Date(d) => d.to_string(),
        _ => av.to_string(),
    };
    s.replace([std::path::MAIN_SEPARATOR, '/'], "_")
}

fn partition_row_to_filter_expr(
    col_names: &[String],
    row: &[AnyValue<'_>],
) -> Result<Expr, PolarsError> {
    if col_names.len() != row.len() {
        return Err(PolarsError::ComputeError(
            format!(
                "partition_row_to_filter_expr: {} columns but {} row values",
                col_names.len(),
                row.len()
            )
            .into(),
        ));
    }
    let mut pred = None::<Expr>;
    for (name, av) in col_names.iter().zip(row.iter()) {
        let clause = match av {
            AnyValue::Null => col(name.as_str()).is_null(),
            AnyValue::Boolean(b) => col(name.as_str()).eq(lit(*b)),
            AnyValue::Int32(i) => col(name.as_str()).eq(lit(*i)),
            AnyValue::Int64(i) => col(name.as_str()).eq(lit(*i)),
            AnyValue::UInt32(u) => col(name.as_str()).eq(lit(*u)),
            AnyValue::UInt64(u) => col(name.as_str()).eq(lit(*u)),
            AnyValue::Float32(f) => col(name.as_str()).eq(lit(*f)),
            AnyValue::Float64(f) => col(name.as_str()).eq(lit(*f)),
            AnyValue::String(s) => col(name.as_str()).eq(lit(s.to_string())),
            AnyValue::StringOwned(s) => col(name.as_str()).eq(lit(s.clone())),
            _ => {
                let s = av.to_string();
                col(name.as_str()).cast(DataType::String).eq(lit(s))
            }
        };
        pred = Some(match pred {
            None => clause,
            Some(p) => p.and(clause),
        });
    }
    Ok(pred.unwrap_or_else(|| lit(true)))
}

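/// Spark encodes maps as `list<struct<key, value>>`; detect that shape so
/// [`any_value_to_json`] can emit a JSON object instead of an array.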
fn is_map_format(dtype: &DataType) -> bool {
    if let DataType::List(inner) = dtype {
        if let DataType::Struct(fields) = inner.as_ref() {
            let has_key = fields.iter().any(|f| f.name == "key");
            let has_value = fields.iter().any(|f| f.name == "value");
            return has_key && has_value;
        }
    }
    false
}

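/// Converts a single Polars [`AnyValue`] into a [`JsonValue`], recursing
/// through lists and structs. Map-shaped lists (see [`is_map_format`])
/// become JSON objects keyed by the struct's `key` field; non-finite floats
/// serialize as `null` because JSON numbers cannot represent them.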
fn any_value_to_json(av: &AnyValue<'_>, dtype: &DataType) -> JsonValue {
    use serde_json::Map;
    match av {
        AnyValue::Null => JsonValue::Null,
        AnyValue::Boolean(b) => JsonValue::Bool(*b),
        AnyValue::Int32(i) => JsonValue::Number(serde_json::Number::from(*i)),
        AnyValue::Int64(i) => JsonValue::Number(serde_json::Number::from(*i)),
        AnyValue::UInt32(u) => JsonValue::Number(serde_json::Number::from(*u)),
        AnyValue::UInt64(u) => JsonValue::Number(serde_json::Number::from(*u)),
        AnyValue::Float32(f) => serde_json::Number::from_f64(f64::from(*f))
            .map(JsonValue::Number)
            .unwrap_or(JsonValue::Null),
        AnyValue::Float64(f) => serde_json::Number::from_f64(*f)
            .map(JsonValue::Number)
            .unwrap_or(JsonValue::Null),
        AnyValue::String(s) => JsonValue::String(s.to_string()),
        AnyValue::StringOwned(s) => JsonValue::String(s.to_string()),
        AnyValue::List(s) => {
            if is_map_format(dtype) {
                let mut obj = Map::new();
                for i in 0..s.len() {
                    if let Ok(elem) = s.get(i) {
                        let (k, v) = match &elem {
                            AnyValue::Struct(_, _, fields) => {
                                let mut k = None;
                                let mut v = None;
                                for (fld_av, fld) in elem._iter_struct_av().zip(fields.iter()) {
                                    if fld.name == "key" {
                                        k = fld_av
                                            .get_str()
                                            .map(|s| s.to_string())
                                            .or_else(|| Some(fld_av.to_string()));
                                    } else if fld.name == "value" {
                                        v = Some(any_value_to_json(&fld_av, &fld.dtype));
                                    }
                                }
                                (k, v)
                            }
                            AnyValue::StructOwned(payload) => {
                                let (values, fields) = &**payload;
                                let mut k = None;
                                let mut v = None;
                                for (fld_av, fld) in values.iter().zip(fields.iter()) {
                                    if fld.name == "key" {
                                        k = fld_av
                                            .get_str()
                                            .map(|s| s.to_string())
                                            .or_else(|| Some(fld_av.to_string()));
                                    } else if fld.name == "value" {
                                        v = Some(any_value_to_json(fld_av, &fld.dtype));
                                    }
                                }
                                (k, v)
                            }
                            _ => (None, None),
                        };
                        if let (Some(key), Some(val)) = (k, v) {
                            obj.insert(key, val);
                        }
                    }
                }
                JsonValue::Object(obj)
            } else {
                let inner_dtype = match dtype {
                    DataType::List(inner) => inner.as_ref(),
                    _ => dtype,
                };
                let arr: Vec<JsonValue> = (0..s.len())
                    .filter_map(|i| s.get(i).ok())
                    .map(|a| any_value_to_json(&a, inner_dtype))
                    .collect();
                JsonValue::Array(arr)
            }
        }
        AnyValue::Struct(_, _, fields) => {
            let mut obj = Map::new();
            for (fld_av, fld) in av._iter_struct_av().zip(fields.iter()) {
                obj.insert(fld.name.to_string(), any_value_to_json(&fld_av, &fld.dtype));
            }
            JsonValue::Object(obj)
        }
        AnyValue::StructOwned(payload) => {
            let (values, fields) = &**payload;
            let mut obj = Map::new();
            for (fld_av, fld) in values.iter().zip(fields.iter()) {
                obj.insert(fld.name.to_string(), any_value_to_json(fld_av, &fld.dtype));
            }
            JsonValue::Object(obj)
        }
        _ => JsonValue::Null,
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use polars::prelude::{NamedFrom, Series};

    #[test]
    fn coerce_string_numeric_root_in_filter() {
        let s = Series::new("str_col".into(), &["123", "456"]);
        let pl_df = polars::prelude::DataFrame::new_infer_height(vec![s.into()]).unwrap();
        let df = DataFrame::from_polars(pl_df);
        let expr = col("str_col").eq(lit(123i64));
        let out = df.filter(expr).unwrap();
        assert_eq!(out.count().unwrap(), 1);
    }

    #[test]
    fn lazy_schema_columns_resolve_before_collect() {
        let spark = SparkSession::builder()
            .app_name("lazy_mod_tests")
            .get_or_create();
        let df = spark
            .create_dataframe(
                vec![
                    (1i64, 25i64, "a".to_string()),
                    (2i64, 30i64, "b".to_string()),
                ],
                vec!["id", "age", "name"],
            )
            .unwrap();
        assert_eq!(df.columns().unwrap(), vec!["id", "age", "name"]);
        assert_eq!(df.resolve_column_name("AGE").unwrap(), "age");
        assert!(df.get_column_dtype("id").unwrap().is_integer());
    }

    #[test]
    fn lazy_from_lazy_produces_valid_df() {
        let _spark = SparkSession::builder()
            .app_name("lazy_mod_tests")
            .get_or_create();
        let pl_df = polars::prelude::df!("x" => &[1i64, 2, 3]).unwrap();
        let df = DataFrame::from_lazy_with_options(pl_df.lazy(), false);
        assert_eq!(df.columns().unwrap(), vec!["x"]);
        assert_eq!(df.count().unwrap(), 3);
    }
}