mod aggregations;
mod joins;
mod stats;
mod transformations;

pub use aggregations::{CubeRollupData, GroupedData};
pub use joins::{join, JoinType};
pub use stats::DataFrameStat;
pub use transformations::{
    filter, order_by, order_by_exprs, select, select_with_exprs, with_column, DataFrameNa,
};

use crate::column::Column;
use crate::functions::SortOrder;
use crate::schema::StructType;
use crate::session::SparkSession;
use crate::type_coercion::coerce_for_pyspark_comparison;
use polars::prelude::{
    col, lit, AnyValue, DataFrame as PlDataFrame, DataType, Expr, PlSmallStr, PolarsError,
    SchemaNamesAndDtypes,
};
use serde_json::Value as JsonValue;
use std::collections::{HashMap, HashSet};
use std::path::Path;
use std::sync::Arc;

/// Default for `spark.sql.caseSensitive`: like PySpark, column names resolve
/// case-insensitively unless the session opts in to case sensitivity.
const DEFAULT_CASE_SENSITIVE: bool = false;

/// A PySpark-style DataFrame backed by an in-memory Polars [`PlDataFrame`].
#[derive(Clone)]
pub struct DataFrame {
    pub(crate) df: Arc<PlDataFrame>,
    pub(crate) case_sensitive: bool,
}

impl DataFrame {
    pub fn from_polars(df: PlDataFrame) -> Self {
        DataFrame {
            df: Arc::new(df),
            case_sensitive: DEFAULT_CASE_SENSITIVE,
        }
    }

    pub fn from_polars_with_options(df: PlDataFrame, case_sensitive: bool) -> Self {
        DataFrame {
            df: Arc::new(df),
            case_sensitive,
        }
    }

    pub fn empty() -> Self {
        DataFrame {
            df: Arc::new(PlDataFrame::empty()),
            case_sensitive: DEFAULT_CASE_SENSITIVE,
        }
    }

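    /// Rewrites every bare `Expr::Column` in `expr` to the canonical column
    /// name stored in this DataFrame, honoring the session's case
    /// sensitivity. Names produced by an `alias(..)` inside the expression
    /// are left alone, since they do not exist in the schema yet.
    ///
    /// A sketch of intended usage; the column names here are hypothetical:
    ///
    /// ```ignore
    /// // Schema stores "Age"; a PySpark-style query may say col("age").
    /// let resolved = df.resolve_expr_column_names(col("age").gt(lit(21)))?;
    /// // `resolved` now references the stored name "Age".
    /// ```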
    pub fn resolve_expr_column_names(&self, expr: Expr) -> Result<Expr, PolarsError> {
        let df = self;
        // First pass: collect names introduced by `alias(..)` so the second
        // pass does not try to resolve them against the schema.
        let mut alias_output_names: HashSet<String> = HashSet::new();
        let _ = expr.clone().try_map_expr(|e| {
            if let Expr::Alias(_, name) = &e {
                alias_output_names.insert(name.as_str().to_string());
            }
            Ok(e)
        })?;
        // Second pass: rewrite remaining column references to their canonical
        // stored names.
        expr.try_map_expr(move |e| {
            if let Expr::Column(name) = &e {
                let name_str = name.as_str();
                if alias_output_names.contains(name_str) {
                    return Ok(e);
                }
                let resolved = df.resolve_column_name(name_str)?;
                return Ok(Expr::Column(PlSmallStr::from(resolved.as_str())));
            }
            Ok(e)
        })
    }

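    /// Mirrors PySpark's implicit casting when a string column is compared to
    /// a numeric literal: instead of Polars' lexicographic string comparison,
    /// the operands are coerced so that `col("str_col") == 123` matches the
    /// row holding `"123"`. Applied both to the root expression and, via
    /// `try_map_expr`, to every nested comparison.
    ///
    /// A sketch mirroring the test at the bottom of this module:
    ///
    /// ```ignore
    /// let expr = df.coerce_string_numeric_comparisons(col("str_col").eq(lit(123i64)))?;
    /// // `expr` now compares numerically, as PySpark would.
    /// ```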
    pub fn coerce_string_numeric_comparisons(&self, expr: Expr) -> Result<Expr, PolarsError> {
        use polars::prelude::{DataType, LiteralValue, Operator};
        use std::sync::Arc;

        fn is_numeric_literal(expr: &Expr) -> bool {
            matches!(
                expr,
                Expr::Literal(
                    LiteralValue::Int32(_)
                        | LiteralValue::Int64(_)
                        | LiteralValue::UInt32(_)
                        | LiteralValue::UInt64(_)
                        | LiteralValue::Float32(_)
                        | LiteralValue::Float64(_)
                        | LiteralValue::Int(_)
                        | LiteralValue::Float(_)
                )
            )
        }

        fn literal_dtype(lv: &LiteralValue) -> DataType {
            match lv {
                LiteralValue::Int32(_) => DataType::Int32,
                LiteralValue::Int64(_) => DataType::Int64,
                LiteralValue::UInt32(_) => DataType::UInt32,
                LiteralValue::UInt64(_) => DataType::UInt64,
                LiteralValue::Float32(_) => DataType::Float32,
                LiteralValue::Float64(_) => DataType::Float64,
                LiteralValue::Int(_) | LiteralValue::Float(_) => DataType::Float64,
                _ => DataType::Float64,
            }
        }

        // Handle a comparison sitting at the root of the expression tree.
        let expr = {
            if let Expr::BinaryExpr { left, op, right } = &expr {
                let is_comparison_op = matches!(
                    op,
                    Operator::Eq
                        | Operator::NotEq
                        | Operator::Lt
                        | Operator::LtEq
                        | Operator::Gt
                        | Operator::GtEq
                );
                let left_is_col = matches!(&**left, Expr::Column(_));
                let right_is_col = matches!(&**right, Expr::Column(_));
                let left_is_numeric_lit =
                    matches!(&**left, Expr::Literal(_)) && is_numeric_literal(left.as_ref());
                let right_is_numeric_lit =
                    matches!(&**right, Expr::Literal(_)) && is_numeric_literal(right.as_ref());
                let root_is_col_vs_numeric = is_comparison_op
                    && ((left_is_col && right_is_numeric_lit)
                        || (right_is_col && left_is_numeric_lit));
                if root_is_col_vs_numeric {
                    let (new_left, new_right) = if left_is_col && right_is_numeric_lit {
                        let lit_ty = match &**right {
                            Expr::Literal(lv) => literal_dtype(lv),
                            _ => DataType::Float64,
                        };
                        coerce_for_pyspark_comparison(
                            (*left).as_ref().clone(),
                            (*right).as_ref().clone(),
                            &DataType::String,
                            &lit_ty,
                            op,
                        )
                        .map_err(|e| PolarsError::ComputeError(e.to_string().into()))?
                    } else {
                        let lit_ty = match &**left {
                            Expr::Literal(lv) => literal_dtype(lv),
                            _ => DataType::Float64,
                        };
                        coerce_for_pyspark_comparison(
                            (*left).as_ref().clone(),
                            (*right).as_ref().clone(),
                            &lit_ty,
                            &DataType::String,
                            op,
                        )
                        .map_err(|e| PolarsError::ComputeError(e.to_string().into()))?
                    };
                    Expr::BinaryExpr {
                        left: Arc::new(new_left),
                        op: *op,
                        right: Arc::new(new_right),
                    }
                } else {
                    expr
                }
            } else {
                expr
            }
        };

        // Rewrite any nested column-vs-numeric-literal comparisons.
        expr.try_map_expr(move |e| {
            if let Expr::BinaryExpr { left, op, right } = e {
                let is_comparison_op = matches!(
                    op,
                    Operator::Eq
                        | Operator::NotEq
                        | Operator::Lt
                        | Operator::LtEq
                        | Operator::Gt
                        | Operator::GtEq
                );
                if !is_comparison_op {
                    return Ok(Expr::BinaryExpr { left, op, right });
                }

                let left_is_col = matches!(&*left, Expr::Column(_));
                let right_is_col = matches!(&*right, Expr::Column(_));
                let left_is_lit = matches!(&*left, Expr::Literal(_));
                let right_is_lit = matches!(&*right, Expr::Literal(_));

                let left_is_numeric_lit = left_is_lit && is_numeric_literal(left.as_ref());
                let right_is_numeric_lit = right_is_lit && is_numeric_literal(right.as_ref());

                let (new_left, new_right) = if left_is_col && right_is_numeric_lit {
                    let lit_ty = match &*right {
                        Expr::Literal(lv) => literal_dtype(lv),
                        _ => DataType::Float64,
                    };
                    coerce_for_pyspark_comparison(
                        (*left).clone(),
                        (*right).clone(),
                        &DataType::String,
                        &lit_ty,
                        &op,
                    )
                    .map_err(|e| PolarsError::ComputeError(e.to_string().into()))?
                } else if right_is_col && left_is_numeric_lit {
                    let lit_ty = match &*left {
                        Expr::Literal(lv) => literal_dtype(lv),
                        _ => DataType::Float64,
                    };
                    coerce_for_pyspark_comparison(
                        (*left).clone(),
                        (*right).clone(),
                        &lit_ty,
                        &DataType::String,
                        &op,
                    )
                    .map_err(|e| PolarsError::ComputeError(e.to_string().into()))?
                } else {
                    return Ok(Expr::BinaryExpr { left, op, right });
                };

                Ok(Expr::BinaryExpr {
                    left: Arc::new(new_left),
                    op,
                    right: Arc::new(new_right),
                })
            } else {
                Ok(e)
            }
        })
    }

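    /// Resolves `name` against the schema, exactly when the session is
    /// case-sensitive and case-insensitively otherwise, returning the stored
    /// spelling. Unknown names produce a `ColumnNotFound` error listing the
    /// available columns.
    ///
    /// A sketch (hypothetical schema storing "Age"):
    ///
    /// ```ignore
    /// assert_eq!(df.resolve_column_name("age")?, "Age");
    /// ```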
    pub fn resolve_column_name(&self, name: &str) -> Result<String, PolarsError> {
        let names = self.df.get_column_names();
        if self.case_sensitive {
            if names.iter().any(|n| *n == name) {
                return Ok(name.to_string());
            }
        } else {
            let name_lower = name.to_lowercase();
            for n in names {
                if n.to_lowercase() == name_lower {
                    return Ok(n.to_string());
                }
            }
        }
        let available: Vec<String> = self
            .df
            .get_column_names()
            .iter()
            .map(|s| s.to_string())
            .collect();
        Err(PolarsError::ColumnNotFound(
            format!(
                "Column '{}' not found. Available columns: [{}]. Check spelling and case sensitivity (spark.sql.caseSensitive).",
                name,
                available.join(", ")
            )
            .into(),
        ))
    }

    pub fn schema(&self) -> Result<StructType, PolarsError> {
        Ok(StructType::from_polars_schema(&self.df.schema()))
    }

    pub fn columns(&self) -> Result<Vec<String>, PolarsError> {
        Ok(self
            .df
            .get_column_names()
            .iter()
            .map(|s| s.to_string())
            .collect())
    }

    pub fn count(&self) -> Result<usize, PolarsError> {
        Ok(self.df.height())
    }

    pub fn show(&self, n: Option<usize>) -> Result<(), PolarsError> {
        let n = n.unwrap_or(20);
        println!("{}", self.df.head(Some(n)));
        Ok(())
    }

    pub fn collect(&self) -> Result<Arc<PlDataFrame>, PolarsError> {
        Ok(self.df.clone())
    }

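    /// Materializes every row as a map from column name to JSON value, using
    /// [`any_value_to_json`] for the scalar conversion. Convenient for
    /// serializing small results; cost is O(rows × columns).
    ///
    /// A sketch with a hypothetical column "id":
    ///
    /// ```ignore
    /// for row in df.collect_as_json_rows()? {
    ///     println!("{:?}", row["id"]);
    /// }
    /// ```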
    pub fn collect_as_json_rows(&self) -> Result<Vec<HashMap<String, JsonValue>>, PolarsError> {
        let df = self.df.as_ref();
        let names = df.get_column_names();
        let nrows = df.height();
        let mut rows = Vec::with_capacity(nrows);
        for i in 0..nrows {
            let mut row = HashMap::with_capacity(names.len());
            for (col_idx, name) in names.iter().enumerate() {
                let s = df
                    .get_columns()
                    .get(col_idx)
                    .ok_or_else(|| PolarsError::ComputeError("column index out of range".into()))?;
                let av = s.get(i)?;
                let jv = any_value_to_json(av);
                row.insert(name.to_string(), jv);
            }
            rows.push(row);
        }
        Ok(rows)
    }

    pub fn select_exprs(&self, exprs: Vec<Expr>) -> Result<DataFrame, PolarsError> {
        transformations::select_with_exprs(self, exprs, self.case_sensitive)
    }

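    /// Projects the given columns. Under case-insensitive resolution the
    /// output keeps the caller's spelling, matching PySpark: selecting "AGE"
    /// from a frame storing "age" yields a column named "AGE".
    ///
    /// A sketch (hypothetical schema storing "age"):
    ///
    /// ```ignore
    /// let out = df.select(vec!["AGE"])?;
    /// assert_eq!(out.columns()?, vec!["AGE".to_string()]);
    /// ```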
    pub fn select(&self, cols: Vec<&str>) -> Result<DataFrame, PolarsError> {
        let resolved: Vec<String> = cols
            .iter()
            .map(|c| self.resolve_column_name(c))
            .collect::<Result<Vec<_>, _>>()?;
        let refs: Vec<&str> = resolved.iter().map(|s| s.as_str()).collect();
        let mut result = transformations::select(self, refs, self.case_sensitive)?;
        if !self.case_sensitive {
            // Rename back to the caller's spelling so the output matches the request.
            for (requested, res) in cols.iter().zip(resolved.iter()) {
                if *requested != res.as_str() {
                    result = result.with_column_renamed(res, requested)?;
                }
            }
        }
        Ok(result)
    }

    pub fn filter(&self, condition: Expr) -> Result<DataFrame, PolarsError> {
        transformations::filter(self, condition, self.case_sensitive)
    }

    pub fn column(&self, name: &str) -> Result<Column, PolarsError> {
        let resolved = self.resolve_column_name(name)?;
        Ok(Column::new(resolved))
    }

    pub fn with_column(&self, column_name: &str, col: &Column) -> Result<DataFrame, PolarsError> {
        transformations::with_column(self, column_name, col, self.case_sensitive)
    }

    pub fn with_column_expr(
        &self,
        column_name: &str,
        expr: Expr,
    ) -> Result<DataFrame, PolarsError> {
        let col = Column::from_expr(expr, None);
        self.with_column(column_name, &col)
    }

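    /// Groups by the given columns (resolved to their stored names) and
    /// returns a [`GroupedData`] driving a lazy Polars group-by.
    ///
    /// A sketch with a hypothetical column "dept"; the aggregation call is
    /// illustrative and depends on `GroupedData`'s API:
    ///
    /// ```ignore
    /// let grouped = df.group_by(vec!["dept"])?;
    /// // e.g. grouped.agg(...), per GroupedData's methods.
    /// ```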
    pub fn group_by(&self, column_names: Vec<&str>) -> Result<GroupedData, PolarsError> {
        use polars::prelude::*;
        let resolved: Vec<String> = column_names
            .iter()
            .map(|c| self.resolve_column_name(c))
            .collect::<Result<Vec<_>, _>>()?;
        let exprs: Vec<Expr> = resolved.iter().map(|name| col(name.as_str())).collect();
        let pl_df = self.df.as_ref().clone();
        let lazy_grouped = pl_df.clone().lazy().group_by(exprs);
        Ok(GroupedData {
            df: pl_df,
            lazy_grouped,
            grouping_cols: resolved,
            case_sensitive: self.case_sensitive,
        })
    }

    pub fn cube(&self, column_names: Vec<&str>) -> Result<CubeRollupData, PolarsError> {
        let resolved: Vec<String> = column_names
            .iter()
            .map(|c| self.resolve_column_name(c))
            .collect::<Result<Vec<_>, _>>()?;
        Ok(CubeRollupData {
            df: self.df.as_ref().clone(),
            grouping_cols: resolved,
            case_sensitive: self.case_sensitive,
            is_cube: true,
        })
    }

    pub fn rollup(&self, column_names: Vec<&str>) -> Result<CubeRollupData, PolarsError> {
        let resolved: Vec<String> = column_names
            .iter()
            .map(|c| self.resolve_column_name(c))
            .collect::<Result<Vec<_>, _>>()?;
        Ok(CubeRollupData {
            df: self.df.as_ref().clone(),
            grouping_cols: resolved,
            case_sensitive: self.case_sensitive,
            is_cube: false,
        })
    }

    pub fn join(
        &self,
        other: &DataFrame,
        on: Vec<&str>,
        how: JoinType,
    ) -> Result<DataFrame, PolarsError> {
        let resolved: Vec<String> = on
            .iter()
            .map(|c| self.resolve_column_name(c))
            .collect::<Result<Vec<_>, _>>()?;
        let on_refs: Vec<&str> = resolved.iter().map(|s| s.as_str()).collect();
        join(self, other, on_refs, how, self.case_sensitive)
    }

    pub fn order_by(
        &self,
        column_names: Vec<&str>,
        ascending: Vec<bool>,
    ) -> Result<DataFrame, PolarsError> {
        let resolved: Vec<String> = column_names
            .iter()
            .map(|c| self.resolve_column_name(c))
            .collect::<Result<Vec<_>, _>>()?;
        let refs: Vec<&str> = resolved.iter().map(|s| s.as_str()).collect();
        transformations::order_by(self, refs, ascending, self.case_sensitive)
    }

    pub fn order_by_exprs(&self, sort_orders: Vec<SortOrder>) -> Result<DataFrame, PolarsError> {
        transformations::order_by_exprs(self, sort_orders, self.case_sensitive)
    }

    pub fn union(&self, other: &DataFrame) -> Result<DataFrame, PolarsError> {
        transformations::union(self, other, self.case_sensitive)
    }

    pub fn union_by_name(&self, other: &DataFrame) -> Result<DataFrame, PolarsError> {
        transformations::union_by_name(self, other, self.case_sensitive)
    }

    pub fn distinct(&self, subset: Option<Vec<&str>>) -> Result<DataFrame, PolarsError> {
        transformations::distinct(self, subset, self.case_sensitive)
    }

    pub fn drop(&self, columns: Vec<&str>) -> Result<DataFrame, PolarsError> {
        transformations::drop(self, columns, self.case_sensitive)
    }

    pub fn dropna(&self, subset: Option<Vec<&str>>) -> Result<DataFrame, PolarsError> {
        transformations::dropna(self, subset, self.case_sensitive)
    }

    pub fn fillna(&self, value: Expr) -> Result<DataFrame, PolarsError> {
        transformations::fillna(self, value, self.case_sensitive)
    }

    pub fn limit(&self, n: usize) -> Result<DataFrame, PolarsError> {
        transformations::limit(self, n, self.case_sensitive)
    }

    pub fn with_column_renamed(
        &self,
        old_name: &str,
        new_name: &str,
    ) -> Result<DataFrame, PolarsError> {
        transformations::with_column_renamed(self, old_name, new_name, self.case_sensitive)
    }

    pub fn replace(
        &self,
        column_name: &str,
        old_value: Expr,
        new_value: Expr,
    ) -> Result<DataFrame, PolarsError> {
        transformations::replace(self, column_name, old_value, new_value, self.case_sensitive)
    }

    pub fn cross_join(&self, other: &DataFrame) -> Result<DataFrame, PolarsError> {
        transformations::cross_join(self, other, self.case_sensitive)
    }

    pub fn describe(&self) -> Result<DataFrame, PolarsError> {
        transformations::describe(self, self.case_sensitive)
    }

    // Caching, checkpointing, and partitioning controls are no-ops in this
    // eager, in-memory engine; they exist for PySpark API compatibility.
    pub fn cache(&self) -> Result<DataFrame, PolarsError> {
        Ok(self.clone())
    }

    pub fn persist(&self) -> Result<DataFrame, PolarsError> {
        Ok(self.clone())
    }

    pub fn unpersist(&self) -> Result<DataFrame, PolarsError> {
        Ok(self.clone())
    }

    pub fn subtract(&self, other: &DataFrame) -> Result<DataFrame, PolarsError> {
        transformations::subtract(self, other, self.case_sensitive)
    }

    pub fn intersect(&self, other: &DataFrame) -> Result<DataFrame, PolarsError> {
        transformations::intersect(self, other, self.case_sensitive)
    }

    pub fn sample(
        &self,
        with_replacement: bool,
        fraction: f64,
        seed: Option<u64>,
    ) -> Result<DataFrame, PolarsError> {
        transformations::sample(self, with_replacement, fraction, seed, self.case_sensitive)
    }

    pub fn random_split(
        &self,
        weights: &[f64],
        seed: Option<u64>,
    ) -> Result<Vec<DataFrame>, PolarsError> {
        transformations::random_split(self, weights, seed, self.case_sensitive)
    }

    pub fn sample_by(
        &self,
        col_name: &str,
        fractions: &[(Expr, f64)],
        seed: Option<u64>,
    ) -> Result<DataFrame, PolarsError> {
        transformations::sample_by(self, col_name, fractions, seed, self.case_sensitive)
    }

    pub fn first(&self) -> Result<DataFrame, PolarsError> {
        transformations::first(self, self.case_sensitive)
    }

    pub fn head(&self, n: usize) -> Result<DataFrame, PolarsError> {
        transformations::head(self, n, self.case_sensitive)
    }

    pub fn take(&self, n: usize) -> Result<DataFrame, PolarsError> {
        transformations::take(self, n, self.case_sensitive)
    }

    pub fn tail(&self, n: usize) -> Result<DataFrame, PolarsError> {
        transformations::tail(self, n, self.case_sensitive)
    }

    pub fn is_empty(&self) -> bool {
        transformations::is_empty(self)
    }

    pub fn to_df(&self, names: Vec<&str>) -> Result<DataFrame, PolarsError> {
        transformations::to_df(self, &names, self.case_sensitive)
    }

    pub fn stat(&self) -> DataFrameStat<'_> {
        DataFrameStat { df: self }
    }

    pub fn corr(&self) -> Result<DataFrame, PolarsError> {
        self.stat().corr_matrix()
    }

    pub fn corr_cols(&self, col1: &str, col2: &str) -> Result<f64, PolarsError> {
        self.stat().corr(col1, col2)
    }

    pub fn cov_cols(&self, col1: &str, col2: &str) -> Result<f64, PolarsError> {
        self.stat().cov(col1, col2)
    }

    pub fn summary(&self) -> Result<DataFrame, PolarsError> {
        self.describe()
    }

    pub fn to_json(&self) -> Result<Vec<String>, PolarsError> {
        transformations::to_json(self)
    }

    pub fn explain(&self) -> String {
        transformations::explain(self)
    }

    pub fn print_schema(&self) -> Result<String, PolarsError> {
        transformations::print_schema(self)
    }

    pub fn checkpoint(&self) -> Result<DataFrame, PolarsError> {
        Ok(self.clone())
    }

    pub fn local_checkpoint(&self) -> Result<DataFrame, PolarsError> {
        Ok(self.clone())
    }

    pub fn repartition(&self, _num_partitions: usize) -> Result<DataFrame, PolarsError> {
        Ok(self.clone())
    }

    pub fn repartition_by_range(
        &self,
        _num_partitions: usize,
        _cols: Vec<&str>,
    ) -> Result<DataFrame, PolarsError> {
        Ok(self.clone())
    }

    pub fn dtypes(&self) -> Result<Vec<(String, String)>, PolarsError> {
        let schema = self.df.schema();
        Ok(schema
            .iter_names_and_dtypes()
            .map(|(name, dtype)| (name.to_string(), format!("{dtype:?}")))
            .collect())
    }

    pub fn sort_within_partitions(
        &self,
        _cols: &[crate::functions::SortOrder],
    ) -> Result<DataFrame, PolarsError> {
        Ok(self.clone())
    }

    pub fn coalesce(&self, _num_partitions: usize) -> Result<DataFrame, PolarsError> {
        Ok(self.clone())
    }

    pub fn hint(&self, _name: &str, _params: &[i32]) -> Result<DataFrame, PolarsError> {
        Ok(self.clone())
    }

    pub fn is_local(&self) -> bool {
        true
    }

    pub fn input_files(&self) -> Vec<String> {
        Vec::new()
    }

    pub fn same_semantics(&self, _other: &DataFrame) -> bool {
        false
    }

    pub fn semantic_hash(&self) -> u64 {
        0
    }

    pub fn observe(&self, _name: &str, _expr: Expr) -> Result<DataFrame, PolarsError> {
        Ok(self.clone())
    }

    pub fn with_watermark(
        &self,
        _event_time: &str,
        _delay: &str,
    ) -> Result<DataFrame, PolarsError> {
        Ok(self.clone())
    }

    pub fn select_expr(&self, exprs: &[String]) -> Result<DataFrame, PolarsError> {
        transformations::select_expr(self, exprs, self.case_sensitive)
    }

    pub fn col_regex(&self, pattern: &str) -> Result<DataFrame, PolarsError> {
        transformations::col_regex(self, pattern, self.case_sensitive)
    }

    pub fn with_columns(&self, exprs: &[(String, Column)]) -> Result<DataFrame, PolarsError> {
        transformations::with_columns(self, exprs, self.case_sensitive)
    }

    pub fn with_columns_renamed(
        &self,
        renames: &[(String, String)],
    ) -> Result<DataFrame, PolarsError> {
        transformations::with_columns_renamed(self, renames, self.case_sensitive)
    }

    pub fn na(&self) -> DataFrameNa<'_> {
        DataFrameNa { df: self }
    }

    pub fn offset(&self, n: usize) -> Result<DataFrame, PolarsError> {
        transformations::offset(self, n, self.case_sensitive)
    }

    pub fn transform<F>(&self, f: F) -> Result<DataFrame, PolarsError>
    where
        F: FnOnce(DataFrame) -> Result<DataFrame, PolarsError>,
    {
        transformations::transform(self, f)
    }

    pub fn freq_items(&self, columns: &[&str], support: f64) -> Result<DataFrame, PolarsError> {
        transformations::freq_items(self, columns, support, self.case_sensitive)
    }

    pub fn approx_quantile(
        &self,
        column: &str,
        probabilities: &[f64],
    ) -> Result<DataFrame, PolarsError> {
        transformations::approx_quantile(self, column, probabilities, self.case_sensitive)
    }

    pub fn crosstab(&self, col1: &str, col2: &str) -> Result<DataFrame, PolarsError> {
        transformations::crosstab(self, col1, col2, self.case_sensitive)
    }

    pub fn melt(&self, id_vars: &[&str], value_vars: &[&str]) -> Result<DataFrame, PolarsError> {
        transformations::melt(self, id_vars, value_vars, self.case_sensitive)
    }

    pub fn pivot(
        &self,
        _pivot_col: &str,
        _values: Option<Vec<&str>>,
    ) -> Result<DataFrame, PolarsError> {
        Err(PolarsError::InvalidOperation(
            "pivot is not yet implemented; use crosstab(col1, col2) for two-column cross-tabulation."
                .into(),
        ))
    }

    pub fn except_all(&self, other: &DataFrame) -> Result<DataFrame, PolarsError> {
        transformations::except_all(self, other, self.case_sensitive)
    }

    pub fn intersect_all(&self, other: &DataFrame) -> Result<DataFrame, PolarsError> {
        transformations::intersect_all(self, other, self.case_sensitive)
    }

    #[cfg(feature = "delta")]
    pub fn write_delta(
        &self,
        path: impl AsRef<std::path::Path>,
        overwrite: bool,
    ) -> Result<(), PolarsError> {
        crate::delta::write_delta(self.df.as_ref(), path, overwrite)
    }

    #[cfg(not(feature = "delta"))]
    pub fn write_delta(
        &self,
        _path: impl AsRef<std::path::Path>,
        _overwrite: bool,
    ) -> Result<(), PolarsError> {
        Err(PolarsError::InvalidOperation(
            "Delta Lake requires the 'delta' feature. Build with --features delta.".into(),
        ))
    }

    /// Registers this DataFrame under `name` in the session's table registry.
    /// Despite the name, nothing is written to disk here.
    pub fn save_as_delta_table(&self, session: &crate::session::SparkSession, name: &str) {
        session.register_table(name, self.clone());
    }

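    /// Entry point to the builder-style writer, defaulting to Parquet with
    /// overwrite semantics.
    ///
    /// A sketch of intended usage (path and option values are illustrative):
    ///
    /// ```ignore
    /// df.write()
    ///     .mode(WriteMode::Append)
    ///     .format(WriteFormat::Csv)
    ///     .option("sep", ";")
    ///     .save("out.csv")?;
    /// ```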
    pub fn write(&self) -> DataFrameWriter<'_> {
        DataFrameWriter {
            df: self,
            mode: WriteMode::Overwrite,
            format: WriteFormat::Parquet,
            options: HashMap::new(),
            partition_by: Vec::new(),
        }
    }
}

#[derive(Clone, Copy, PartialEq, Eq)]
pub enum WriteMode {
    Overwrite,
    Append,
}

/// Spark-style table save modes for [`DataFrameWriter::save_as_table`].
#[derive(Clone, Copy, PartialEq, Eq)]
pub enum SaveMode {
    /// Fail if the table already exists.
    ErrorIfExists,
    /// Replace any existing table contents.
    Overwrite,
    /// Append rows to an existing table, creating it if absent.
    Append,
    /// Do nothing if the table already exists.
    Ignore,
}

#[derive(Clone, Copy)]
pub enum WriteFormat {
    Parquet,
    Csv,
    Json,
}

pub struct DataFrameWriter<'a> {
    df: &'a DataFrame,
    mode: WriteMode,
    format: WriteFormat,
    options: HashMap<String, String>,
    partition_by: Vec<String>,
}

impl<'a> DataFrameWriter<'a> {
    pub fn mode(mut self, mode: WriteMode) -> Self {
        self.mode = mode;
        self
    }

    pub fn format(mut self, format: WriteFormat) -> Self {
        self.format = format;
        self
    }

    pub fn option(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.options.insert(key.into(), value.into());
        self
    }

    pub fn options(mut self, opts: impl IntoIterator<Item = (String, String)>) -> Self {
        for (k, v) in opts {
            self.options.insert(k, v);
        }
        self
    }

    pub fn partition_by(mut self, cols: impl IntoIterator<Item = impl Into<String>>) -> Self {
        self.partition_by = cols.into_iter().map(|s| s.into()).collect();
        self
    }

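    /// Registers the frame as a named table in the session and, when the
    /// session has a warehouse directory, persists it there as
    /// `<warehouse>/<name>/data.parquet`. `mode` follows Spark's SaveMode
    /// semantics (error / overwrite / append / ignore).
    ///
    /// A sketch; "people" is a hypothetical table name:
    ///
    /// ```ignore
    /// df.write().save_as_table(&session, "people", SaveMode::Overwrite)?;
    /// ```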
    pub fn save_as_table(
        &self,
        session: &SparkSession,
        name: &str,
        mode: SaveMode,
    ) -> Result<(), PolarsError> {
        use polars::prelude::*;
        use std::fs;
        use std::path::Path;

        let warehouse_path = session.warehouse_dir().map(|w| Path::new(w).join(name));
        let warehouse_exists = warehouse_path.as_ref().is_some_and(|p| p.is_dir());

        fn persist_to_warehouse(
            df: &crate::dataframe::DataFrame,
            dir: &Path,
        ) -> Result<(), PolarsError> {
            use std::fs;
            fs::create_dir_all(dir).map_err(|e| {
                PolarsError::ComputeError(format!("saveAsTable: create dir: {e}").into())
            })?;
            let file_path = dir.join("data.parquet");
            df.write()
                .mode(crate::dataframe::WriteMode::Overwrite)
                .format(crate::dataframe::WriteFormat::Parquet)
                .save(&file_path)
        }

        let final_df = match mode {
            SaveMode::ErrorIfExists => {
                if session.saved_table_exists(name) || warehouse_exists {
                    return Err(PolarsError::InvalidOperation(
                        format!(
                            "Table or view '{name}' already exists. SaveMode is ErrorIfExists."
                        )
                        .into(),
                    ));
                }
                if let Some(ref p) = warehouse_path {
                    persist_to_warehouse(self.df, p)?;
                }
                self.df.clone()
            }
            SaveMode::Overwrite => {
                if let Some(ref p) = warehouse_path {
                    let _ = fs::remove_dir_all(p);
                    persist_to_warehouse(self.df, p)?;
                }
                self.df.clone()
            }
            SaveMode::Append => {
                // Prefer the in-session table; otherwise fall back to the
                // parquet data persisted in the warehouse directory.
                let existing_pl = if let Some(existing) = session.get_saved_table(name) {
                    existing.df.as_ref().clone()
                } else if let (Some(ref p), true) = (warehouse_path.as_ref(), warehouse_exists) {
                    let data_file = p.join("data.parquet");
                    let read_path = if data_file.is_file() {
                        data_file.as_path()
                    } else {
                        p.as_ref()
                    };
                    let lf = LazyFrame::scan_parquet(read_path, ScanArgsParquet::default())
                        .map_err(|e| {
                            PolarsError::ComputeError(
                                format!("saveAsTable append: read warehouse: {e}").into(),
                            )
                        })?;
                    lf.collect().map_err(|e| {
                        PolarsError::ComputeError(
                            format!("saveAsTable append: collect: {e}").into(),
                        )
                    })?
                } else {
                    // Nothing to append to: behave like a fresh save.
                    session.register_table(name, self.df.clone());
                    if let Some(ref p) = warehouse_path {
                        persist_to_warehouse(self.df, p)?;
                    }
                    return Ok(());
                };
                let new_pl = self.df.df.as_ref().clone();
                let existing_cols: Vec<&str> = existing_pl
                    .get_column_names()
                    .iter()
                    .map(|s| s.as_str())
                    .collect();
                let new_cols = new_pl.get_column_names();
                let missing: Vec<_> = existing_cols
                    .iter()
                    .filter(|c| !new_cols.iter().any(|n| n.as_str() == **c))
                    .collect();
                if !missing.is_empty() {
                    return Err(PolarsError::InvalidOperation(
                        format!(
                            "saveAsTable append: new DataFrame missing columns: {:?}",
                            missing
                        )
                        .into(),
                    ));
                }
                // Reorder the new rows to the existing schema before stacking.
                let new_ordered = new_pl.select(existing_cols.iter().copied())?;
                let mut combined = existing_pl;
                combined.vstack_mut(&new_ordered)?;
                let merged = crate::dataframe::DataFrame::from_polars_with_options(
                    combined,
                    self.df.case_sensitive,
                );
                if let Some(ref p) = warehouse_path {
                    let _ = fs::remove_dir_all(p);
                    persist_to_warehouse(&merged, p)?;
                }
                merged
            }
            SaveMode::Ignore => {
                if session.saved_table_exists(name) || warehouse_exists {
                    return Ok(());
                }
                if let Some(ref p) = warehouse_path {
                    persist_to_warehouse(self.df, p)?;
                }
                self.df.clone()
            }
        };
        session.register_table(name, final_df);
        Ok(())
    }

    pub fn parquet(&self, path: impl AsRef<std::path::Path>) -> Result<(), PolarsError> {
        DataFrameWriter {
            df: self.df,
            mode: self.mode,
            format: WriteFormat::Parquet,
            options: self.options.clone(),
            partition_by: self.partition_by.clone(),
        }
        .save(path)
    }

    pub fn csv(&self, path: impl AsRef<std::path::Path>) -> Result<(), PolarsError> {
        DataFrameWriter {
            df: self.df,
            mode: self.mode,
            format: WriteFormat::Csv,
            options: self.options.clone(),
            partition_by: self.partition_by.clone(),
        }
        .save(path)
    }

    pub fn json(&self, path: impl AsRef<std::path::Path>) -> Result<(), PolarsError> {
        DataFrameWriter {
            df: self.df,
            mode: self.mode,
            format: WriteFormat::Json,
            options: self.options.clone(),
            partition_by: self.partition_by.clone(),
        }
        .save(path)
    }

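    /// Writes the frame to `path` in the configured format. In append mode
    /// the existing file (if any) is read back, concatenated with the new
    /// rows, and rewritten, since these writers do not append in place.
    /// Partitioned writes are delegated to `save_partitioned`.
    ///
    /// A sketch (file name illustrative):
    ///
    /// ```ignore
    /// df.write().format(WriteFormat::Json).save("out.json")?;
    /// ```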
    pub fn save(&self, path: impl AsRef<std::path::Path>) -> Result<(), PolarsError> {
        use polars::prelude::*;
        let path = path.as_ref();
        let to_write: PlDataFrame = match self.mode {
            WriteMode::Overwrite => self.df.df.as_ref().clone(),
            WriteMode::Append => {
                if self.partition_by.is_empty() {
                    // Read back whatever already exists so the whole file can
                    // be rewritten with the new rows appended.
                    let existing: Option<PlDataFrame> = if path.exists() && path.is_file() {
                        match self.format {
                            WriteFormat::Parquet => {
                                LazyFrame::scan_parquet(path, ScanArgsParquet::default())
                                    .and_then(|lf| lf.collect())
                                    .ok()
                            }
                            WriteFormat::Csv => LazyCsvReader::new(path)
                                .with_has_header(true)
                                .finish()
                                .and_then(|lf| lf.collect())
                                .ok(),
                            WriteFormat::Json => LazyJsonLineReader::new(path)
                                .finish()
                                .and_then(|lf| lf.collect())
                                .ok(),
                        }
                    } else {
                        None
                    };
                    match existing {
                        Some(existing) => {
                            let lfs: [LazyFrame; 2] =
                                [existing.lazy(), self.df.df.as_ref().clone().lazy()];
                            concat(lfs, UnionArgs::default())?.collect()?
                        }
                        None => self.df.df.as_ref().clone(),
                    }
                } else {
                    // Partitioned appends are handled per partition below.
                    self.df.df.as_ref().clone()
                }
            }
        };

        if !self.partition_by.is_empty() {
            return self.save_partitioned(path, &to_write);
        }

        match self.format {
            WriteFormat::Parquet => {
                let mut file = std::fs::File::create(path).map_err(|e| {
                    PolarsError::ComputeError(format!("write parquet create: {e}").into())
                })?;
                let mut df_mut = to_write;
                ParquetWriter::new(&mut file)
                    .finish(&mut df_mut)
                    .map_err(|e| PolarsError::ComputeError(format!("write parquet: {e}").into()))?;
            }
            WriteFormat::Csv => {
                let has_header = self
                    .options
                    .get("header")
                    .map(|v| v.eq_ignore_ascii_case("true") || v == "1")
                    .unwrap_or(true);
                let delimiter = self
                    .options
                    .get("sep")
                    .and_then(|s| s.bytes().next())
                    .unwrap_or(b',');
                let mut file = std::fs::File::create(path).map_err(|e| {
                    PolarsError::ComputeError(format!("write csv create: {e}").into())
                })?;
                CsvWriter::new(&mut file)
                    .include_header(has_header)
                    .with_separator(delimiter)
                    .finish(&mut to_write.clone())
                    .map_err(|e| PolarsError::ComputeError(format!("write csv: {e}").into()))?;
            }
            WriteFormat::Json => {
                let mut file = std::fs::File::create(path).map_err(|e| {
                    PolarsError::ComputeError(format!("write json create: {e}").into())
                })?;
                JsonWriter::new(&mut file)
                    .finish(&mut to_write.clone())
                    .map_err(|e| PolarsError::ComputeError(format!("write json: {e}").into()))?;
            }
        }
        Ok(())
    }

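    /// Hive-style partitioned write: one `key=value` directory level per
    /// partition column, with the partition columns themselves dropped from
    /// the data files. Append mode picks the next free `part-NNNNN` index in
    /// each partition directory instead of clobbering `part-00000`.
    ///
    /// A sketch of the resulting layout (hypothetical column "year"):
    ///
    /// ```text
    /// out/year=2023/part-00000.parquet
    /// out/year=2024/part-00000.parquet
    /// ```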
    fn save_partitioned(&self, path: &Path, to_write: &PlDataFrame) -> Result<(), PolarsError> {
        use polars::prelude::*;
        let resolved: Vec<String> = self
            .partition_by
            .iter()
            .map(|c| self.df.resolve_column_name(c))
            .collect::<Result<Vec<_>, _>>()?;
        let all_names = to_write.get_column_names();
        let data_cols: Vec<&str> = all_names
            .iter()
            .filter(|n| !resolved.iter().any(|r| r == n.as_str()))
            .map(|n| n.as_str())
            .collect();

        // One output directory per distinct combination of partition values.
        let unique_keys = to_write
            .select(resolved.iter().map(|s| s.as_str()).collect::<Vec<_>>())?
            .unique::<Option<&[String]>, String>(
                None,
                polars::prelude::UniqueKeepStrategy::First,
                None,
            )?;

        if self.mode == WriteMode::Overwrite && path.exists() {
            if path.is_dir() {
                std::fs::remove_dir_all(path).map_err(|e| {
                    PolarsError::ComputeError(
                        format!("write partitioned: remove_dir_all: {e}").into(),
                    )
                })?;
            } else {
                std::fs::remove_file(path).map_err(|e| {
                    PolarsError::ComputeError(format!("write partitioned: remove_file: {e}").into())
                })?;
            }
        }
        std::fs::create_dir_all(path).map_err(|e| {
            PolarsError::ComputeError(format!("write partitioned: create_dir_all: {e}").into())
        })?;

        let ext = match self.format {
            WriteFormat::Parquet => "parquet",
            WriteFormat::Csv => "csv",
            WriteFormat::Json => "json",
        };

        for row_idx in 0..unique_keys.height() {
            let row = unique_keys
                .get(row_idx)
                .ok_or_else(|| PolarsError::ComputeError("partition_row: get row".into()))?;
            let filter_expr = partition_row_to_filter_expr(&resolved, &row)?;
            let subset = to_write.clone().lazy().filter(filter_expr).collect()?;
            // Partition columns are encoded in the path, not the data file.
            let subset = subset.select(data_cols.iter().copied())?;
            if subset.height() == 0 {
                continue;
            }

            let part_path: std::path::PathBuf = resolved
                .iter()
                .zip(row.iter())
                .map(|(name, av)| format!("{}={}", name, format_partition_value(av)))
                .fold(path.to_path_buf(), |p, seg| p.join(seg));
            std::fs::create_dir_all(&part_path).map_err(|e| {
                PolarsError::ComputeError(
                    format!("write partitioned: create_dir_all partition: {e}").into(),
                )
            })?;

            // In append mode, continue numbering after the highest existing
            // part file; otherwise start at part-00000.
            let file_idx = if self.mode == WriteMode::Append {
                let suffix = format!(".{ext}");
                let max_n = std::fs::read_dir(&part_path)
                    .map(|rd| {
                        rd.filter_map(Result::ok)
                            .filter_map(|e| {
                                e.file_name().to_str().and_then(|s| {
                                    s.strip_prefix("part-")
                                        .and_then(|t| t.strip_suffix(&suffix))
                                        .and_then(|t| t.parse::<u32>().ok())
                                })
                            })
                            .max()
                            .unwrap_or(0)
                    })
                    .unwrap_or(0);
                max_n + 1
            } else {
                0
            };
            let filename = format!("part-{file_idx:05}.{ext}");
            let file_path = part_path.join(&filename);

            match self.format {
                WriteFormat::Parquet => {
                    let mut file = std::fs::File::create(&file_path).map_err(|e| {
                        PolarsError::ComputeError(
                            format!("write partitioned parquet create: {e}").into(),
                        )
                    })?;
                    let mut df_mut = subset;
                    ParquetWriter::new(&mut file)
                        .finish(&mut df_mut)
                        .map_err(|e| {
                            PolarsError::ComputeError(
                                format!("write partitioned parquet: {e}").into(),
                            )
                        })?;
                }
                WriteFormat::Csv => {
                    let has_header = self
                        .options
                        .get("header")
                        .map(|v| v.eq_ignore_ascii_case("true") || v == "1")
                        .unwrap_or(true);
                    let delimiter = self
                        .options
                        .get("sep")
                        .and_then(|s| s.bytes().next())
                        .unwrap_or(b',');
                    let mut file = std::fs::File::create(&file_path).map_err(|e| {
                        PolarsError::ComputeError(
                            format!("write partitioned csv create: {e}").into(),
                        )
                    })?;
                    CsvWriter::new(&mut file)
                        .include_header(has_header)
                        .with_separator(delimiter)
                        .finish(&mut subset.clone())
                        .map_err(|e| {
                            PolarsError::ComputeError(format!("write partitioned csv: {e}").into())
                        })?;
                }
                WriteFormat::Json => {
                    let mut file = std::fs::File::create(&file_path).map_err(|e| {
                        PolarsError::ComputeError(
                            format!("write partitioned json create: {e}").into(),
                        )
                    })?;
                    JsonWriter::new(&mut file)
                        .finish(&mut subset.clone())
                        .map_err(|e| {
                            PolarsError::ComputeError(format!("write partitioned json: {e}").into())
                        })?;
                }
            }
        }
        Ok(())
    }
}

/// Renders a partition value for use in a `key=value` path segment. Nulls map
/// to Hive's `__HIVE_DEFAULT_PARTITION__`, and path separators are replaced
/// with `_` so a value cannot escape its directory segment.
fn format_partition_value(av: &AnyValue<'_>) -> String {
    let s = match av {
        AnyValue::Null => "__HIVE_DEFAULT_PARTITION__".to_string(),
        AnyValue::Boolean(b) => b.to_string(),
        AnyValue::Int32(i) => i.to_string(),
        AnyValue::Int64(i) => i.to_string(),
        AnyValue::UInt32(u) => u.to_string(),
        AnyValue::UInt64(u) => u.to_string(),
        AnyValue::Float32(f) => f.to_string(),
        AnyValue::Float64(f) => f.to_string(),
        AnyValue::String(s) => s.to_string(),
        AnyValue::StringOwned(s) => s.as_str().to_string(),
        AnyValue::Date(d) => d.to_string(),
        _ => av.to_string(),
    };
    s.replace([std::path::MAIN_SEPARATOR, '/'], "_")
}

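/// Builds the predicate selecting the rows belonging to one partition: an
/// AND over `col == value` clauses (or `is_null` for null keys). Value types
/// the match does not handle natively are compared through a string cast.
///
/// A sketch (hypothetical column "year"):
///
/// ```ignore
/// // For the key row (year = 2024), yields col("year") == lit(2024).
/// let pred = partition_row_to_filter_expr(&cols, &row)?;
/// ```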
fn partition_row_to_filter_expr(
    col_names: &[String],
    row: &[AnyValue<'_>],
) -> Result<Expr, PolarsError> {
    if col_names.len() != row.len() {
        return Err(PolarsError::ComputeError(
            format!(
                "partition_row_to_filter_expr: {} columns but {} row values",
                col_names.len(),
                row.len()
            )
            .into(),
        ));
    }
    let mut pred = None::<Expr>;
    for (name, av) in col_names.iter().zip(row.iter()) {
        let clause = match av {
            AnyValue::Null => col(name.as_str()).is_null(),
            AnyValue::Boolean(b) => col(name.as_str()).eq(lit(*b)),
            AnyValue::Int32(i) => col(name.as_str()).eq(lit(*i)),
            AnyValue::Int64(i) => col(name.as_str()).eq(lit(*i)),
            AnyValue::UInt32(u) => col(name.as_str()).eq(lit(*u)),
            AnyValue::UInt64(u) => col(name.as_str()).eq(lit(*u)),
            AnyValue::Float32(f) => col(name.as_str()).eq(lit(*f)),
            AnyValue::Float64(f) => col(name.as_str()).eq(lit(*f)),
            AnyValue::String(s) => col(name.as_str()).eq(lit(s.to_string())),
            AnyValue::StringOwned(s) => col(name.as_str()).eq(lit(s.clone())),
            _ => {
                // Fall back to a string comparison for other value types.
                let s = av.to_string();
                col(name.as_str()).cast(DataType::String).eq(lit(s))
            }
        };
        pred = Some(match pred {
            None => clause,
            Some(p) => p.and(clause),
        });
    }
    Ok(pred.unwrap_or_else(|| lit(true)))
}

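/// Converts one Polars scalar to JSON. Non-finite floats have no JSON number
/// representation and become `null`, as does any dtype the match does not
/// cover (nested types, dates, etc.).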
fn any_value_to_json(av: AnyValue<'_>) -> JsonValue {
    match av {
        AnyValue::Null => JsonValue::Null,
        AnyValue::Boolean(b) => JsonValue::Bool(b),
        AnyValue::Int32(i) => JsonValue::Number(serde_json::Number::from(i)),
        AnyValue::Int64(i) => JsonValue::Number(serde_json::Number::from(i)),
        AnyValue::UInt32(u) => JsonValue::Number(serde_json::Number::from(u)),
        AnyValue::UInt64(u) => JsonValue::Number(serde_json::Number::from(u)),
        AnyValue::Float32(f) => serde_json::Number::from_f64(f64::from(f))
            .map(JsonValue::Number)
            .unwrap_or(JsonValue::Null),
        AnyValue::Float64(f) => serde_json::Number::from_f64(f)
            .map(JsonValue::Number)
            .unwrap_or(JsonValue::Null),
        AnyValue::String(s) => JsonValue::String(s.to_string()),
        AnyValue::StringOwned(s) => JsonValue::String(s.to_string()),
        _ => JsonValue::Null,
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use polars::prelude::{NamedFrom, Series};

    #[test]
    fn coerce_string_numeric_root_in_filter() {
        // PySpark semantics: the string "123" should match the numeric 123.
        let s = Series::new("str_col".into(), &["123", "456"]);
        let pl_df = polars::prelude::DataFrame::new(vec![s.into()]).unwrap();
        let df = DataFrame::from_polars(pl_df);
        let expr = col("str_col").eq(lit(123i64));
        let out = df.filter(expr).unwrap();
        assert_eq!(out.count().unwrap(), 1);
    }
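
    // A small companion test (not in the original file) exercising the
    // default case-insensitive column resolution; the column name and
    // values are arbitrary.
    #[test]
    fn resolve_column_name_is_case_insensitive_by_default() {
        let s = Series::new("Age".into(), &[1i64, 2, 3]);
        let pl_df = polars::prelude::DataFrame::new(vec![s.into()]).unwrap();
        let df = DataFrame::from_polars(pl_df);
        // Lookup succeeds regardless of case and returns the stored spelling.
        assert_eq!(df.resolve_column_name("age").unwrap(), "Age");
        assert!(df.resolve_column_name("missing").is_err());
    }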
}