Skip to main content

robin_sparkless/dataframe/
mod.rs

1//! DataFrame module: main tabular type and submodules for transformations, aggregations, joins, stats.
2
3mod aggregations;
4mod joins;
5mod stats;
6mod transformations;
7
8pub use aggregations::{CubeRollupData, GroupedData};
9pub use joins::{join, JoinType};
10pub use stats::DataFrameStat;
11pub use transformations::{
12    filter, order_by, order_by_exprs, select, select_with_exprs, with_column, DataFrameNa,
13};
14
15use crate::column::Column;
16use crate::functions::SortOrder;
17use crate::schema::StructType;
18use polars::prelude::{
19    col, lit, AnyValue, DataFrame as PlDataFrame, DataType, Expr, PolarsError, SchemaNamesAndDtypes,
20};
21use serde_json::Value as JsonValue;
22use std::collections::HashMap;
23use std::path::Path;
24use std::sync::Arc;
25
/// Fallback value of `spark.sql.caseSensitive` used when no explicit setting is supplied.
/// PySpark ships with `false`, meaning column names are matched regardless of letter case.
const DEFAULT_CASE_SENSITIVE: bool = false;
28
29/// DataFrame - main tabular data structure.
30/// Thin wrapper around an eager Polars `DataFrame`.
31pub struct DataFrame {
32    pub(crate) df: Arc<PlDataFrame>,
33    /// When false (default), column names are matched case-insensitively (PySpark behavior).
34    pub(crate) case_sensitive: bool,
35}
36
37impl DataFrame {
38    /// Create a new DataFrame from a Polars DataFrame (case-insensitive column matching by default).
39    pub fn from_polars(df: PlDataFrame) -> Self {
40        DataFrame {
41            df: Arc::new(df),
42            case_sensitive: DEFAULT_CASE_SENSITIVE,
43        }
44    }
45
46    /// Create a new DataFrame from a Polars DataFrame with explicit case sensitivity.
47    /// When `case_sensitive` is false, column resolution is case-insensitive (PySpark default).
48    pub fn from_polars_with_options(df: PlDataFrame, case_sensitive: bool) -> Self {
49        DataFrame {
50            df: Arc::new(df),
51            case_sensitive,
52        }
53    }
54
55    /// Create an empty DataFrame
56    pub fn empty() -> Self {
57        DataFrame {
58            df: Arc::new(PlDataFrame::empty()),
59            case_sensitive: DEFAULT_CASE_SENSITIVE,
60        }
61    }
62
63    /// Resolve a logical column name to the actual column name in the schema.
64    /// When case_sensitive is false, matches case-insensitively.
65    pub fn resolve_column_name(&self, name: &str) -> Result<String, PolarsError> {
66        let names = self.df.get_column_names();
67        if self.case_sensitive {
68            if names.iter().any(|n| *n == name) {
69                return Ok(name.to_string());
70            }
71        } else {
72            let name_lower = name.to_lowercase();
73            for n in names {
74                if n.to_lowercase() == name_lower {
75                    return Ok(n.to_string());
76                }
77            }
78        }
79        let available: Vec<String> = self
80            .df
81            .get_column_names()
82            .iter()
83            .map(|s| s.to_string())
84            .collect();
85        Err(PolarsError::ColumnNotFound(
86            format!(
87                "Column '{}' not found. Available columns: [{}]. Check spelling and case sensitivity (spark.sql.caseSensitive).",
88                name,
89                available.join(", ")
90            )
91            .into(),
92        ))
93    }
94
95    /// Get the schema of the DataFrame
96    pub fn schema(&self) -> Result<StructType, PolarsError> {
97        Ok(StructType::from_polars_schema(&self.df.schema()))
98    }
99
100    /// Get column names
101    pub fn columns(&self) -> Result<Vec<String>, PolarsError> {
102        Ok(self
103            .df
104            .get_column_names()
105            .iter()
106            .map(|s| s.to_string())
107            .collect())
108    }
109
110    /// Count the number of rows (action - triggers execution)
111    pub fn count(&self) -> Result<usize, PolarsError> {
112        Ok(self.df.height())
113    }
114
115    /// Show the first n rows
116    pub fn show(&self, n: Option<usize>) -> Result<(), PolarsError> {
117        let n = n.unwrap_or(20);
118        println!("{}", self.df.head(Some(n)));
119        Ok(())
120    }
121
122    /// Collect the DataFrame (action - triggers execution)
123    pub fn collect(&self) -> Result<Arc<PlDataFrame>, PolarsError> {
124        Ok(self.df.clone())
125    }
126
127    /// Collect as rows of column-name -> JSON value. For use by language bindings (Node, etc.).
128    pub fn collect_as_json_rows(&self) -> Result<Vec<HashMap<String, JsonValue>>, PolarsError> {
129        let df = self.df.as_ref();
130        let names = df.get_column_names();
131        let nrows = df.height();
132        let mut rows = Vec::with_capacity(nrows);
133        for i in 0..nrows {
134            let mut row = HashMap::with_capacity(names.len());
135            for (col_idx, name) in names.iter().enumerate() {
136                let s = df
137                    .get_columns()
138                    .get(col_idx)
139                    .ok_or_else(|| PolarsError::ComputeError("column index out of range".into()))?;
140                let av = s.get(i)?;
141                let jv = any_value_to_json(av);
142                row.insert(name.to_string(), jv);
143            }
144            rows.push(row);
145        }
146        Ok(rows)
147    }
148
149    /// Select columns (returns a new DataFrame).
150    /// Accepts either column names (strings) or Column expressions (e.g. from regexp_extract_all(...).alias("m")).
151    /// Column names are resolved according to case sensitivity.
152    pub fn select_exprs(&self, exprs: Vec<Expr>) -> Result<DataFrame, PolarsError> {
153        transformations::select_with_exprs(self, exprs, self.case_sensitive)
154    }
155
156    /// Select columns by name (returns a new DataFrame).
157    /// Column names are resolved according to case sensitivity.
158    pub fn select(&self, cols: Vec<&str>) -> Result<DataFrame, PolarsError> {
159        let resolved: Vec<String> = cols
160            .iter()
161            .map(|c| self.resolve_column_name(c))
162            .collect::<Result<Vec<_>, _>>()?;
163        let refs: Vec<&str> = resolved.iter().map(|s| s.as_str()).collect();
164        let mut result = transformations::select(self, refs, self.case_sensitive)?;
165        // When case-insensitive, PySpark returns column names in requested (e.g. lowercase) form.
166        if !self.case_sensitive {
167            for (requested, res) in cols.iter().zip(resolved.iter()) {
168                if *requested != res.as_str() {
169                    result = result.with_column_renamed(res, requested)?;
170                }
171            }
172        }
173        Ok(result)
174    }
175
176    /// Filter rows using a Polars expression.
177    pub fn filter(&self, condition: Expr) -> Result<DataFrame, PolarsError> {
178        transformations::filter(self, condition, self.case_sensitive)
179    }
180
181    /// Get a column reference by name (for building expressions).
182    /// Respects case sensitivity: when false, "Age" resolves to column "age" if present.
183    pub fn column(&self, name: &str) -> Result<Column, PolarsError> {
184        let resolved = self.resolve_column_name(name)?;
185        Ok(Column::new(resolved))
186    }
187
188    /// Add or replace a column. Use a [`Column`] (e.g. from `col("x")`, `rand(42)`, `randn(42)`).
189    /// For `rand`/`randn`, generates one distinct value per row (PySpark-like).
190    pub fn with_column(&self, column_name: &str, col: &Column) -> Result<DataFrame, PolarsError> {
191        transformations::with_column(self, column_name, col, self.case_sensitive)
192    }
193
194    /// Add or replace a column using an expression. Prefer [`with_column`](Self::with_column) with a `Column` for rand/randn (per-row values).
195    pub fn with_column_expr(
196        &self,
197        column_name: &str,
198        expr: Expr,
199    ) -> Result<DataFrame, PolarsError> {
200        let col = Column::from_expr(expr, None);
201        self.with_column(column_name, &col)
202    }
203
204    /// Group by columns (returns GroupedData for aggregation).
205    /// Column names are resolved according to case sensitivity.
206    pub fn group_by(&self, column_names: Vec<&str>) -> Result<GroupedData, PolarsError> {
207        use polars::prelude::*;
208        let resolved: Vec<String> = column_names
209            .iter()
210            .map(|c| self.resolve_column_name(c))
211            .collect::<Result<Vec<_>, _>>()?;
212        let exprs: Vec<Expr> = resolved.iter().map(|name| col(name.as_str())).collect();
213        let lazy_grouped = self.df.as_ref().clone().lazy().group_by(exprs);
214        Ok(GroupedData {
215            lazy_grouped,
216            grouping_cols: resolved,
217            case_sensitive: self.case_sensitive,
218        })
219    }
220
221    /// Cube: multiple grouping sets (all subsets of columns), then union (PySpark cube).
222    pub fn cube(&self, column_names: Vec<&str>) -> Result<CubeRollupData, PolarsError> {
223        let resolved: Vec<String> = column_names
224            .iter()
225            .map(|c| self.resolve_column_name(c))
226            .collect::<Result<Vec<_>, _>>()?;
227        Ok(CubeRollupData {
228            df: self.df.as_ref().clone(),
229            grouping_cols: resolved,
230            case_sensitive: self.case_sensitive,
231            is_cube: true,
232        })
233    }
234
235    /// Rollup: grouping sets (prefixes of columns), then union (PySpark rollup).
236    pub fn rollup(&self, column_names: Vec<&str>) -> Result<CubeRollupData, PolarsError> {
237        let resolved: Vec<String> = column_names
238            .iter()
239            .map(|c| self.resolve_column_name(c))
240            .collect::<Result<Vec<_>, _>>()?;
241        Ok(CubeRollupData {
242            df: self.df.as_ref().clone(),
243            grouping_cols: resolved,
244            case_sensitive: self.case_sensitive,
245            is_cube: false,
246        })
247    }
248
249    /// Join with another DataFrame on the given columns.
250    /// Join column names are resolved on the left (and right must have matching names).
251    pub fn join(
252        &self,
253        other: &DataFrame,
254        on: Vec<&str>,
255        how: JoinType,
256    ) -> Result<DataFrame, PolarsError> {
257        let resolved: Vec<String> = on
258            .iter()
259            .map(|c| self.resolve_column_name(c))
260            .collect::<Result<Vec<_>, _>>()?;
261        let on_refs: Vec<&str> = resolved.iter().map(|s| s.as_str()).collect();
262        join(self, other, on_refs, how, self.case_sensitive)
263    }
264
265    /// Order by columns (sort).
266    /// Column names are resolved according to case sensitivity.
267    pub fn order_by(
268        &self,
269        column_names: Vec<&str>,
270        ascending: Vec<bool>,
271    ) -> Result<DataFrame, PolarsError> {
272        let resolved: Vec<String> = column_names
273            .iter()
274            .map(|c| self.resolve_column_name(c))
275            .collect::<Result<Vec<_>, _>>()?;
276        let refs: Vec<&str> = resolved.iter().map(|s| s.as_str()).collect();
277        transformations::order_by(self, refs, ascending, self.case_sensitive)
278    }
279
280    /// Order by sort expressions (asc/desc with nulls_first/last).
281    pub fn order_by_exprs(&self, sort_orders: Vec<SortOrder>) -> Result<DataFrame, PolarsError> {
282        transformations::order_by_exprs(self, sort_orders, self.case_sensitive)
283    }
284
285    /// Union (unionAll): stack another DataFrame vertically. Schemas must match (same columns, same order).
286    pub fn union(&self, other: &DataFrame) -> Result<DataFrame, PolarsError> {
287        transformations::union(self, other, self.case_sensitive)
288    }
289
290    /// Union by name: stack vertically, aligning columns by name.
291    pub fn union_by_name(&self, other: &DataFrame) -> Result<DataFrame, PolarsError> {
292        transformations::union_by_name(self, other, self.case_sensitive)
293    }
294
295    /// Distinct: drop duplicate rows (all columns or optional subset).
296    pub fn distinct(&self, subset: Option<Vec<&str>>) -> Result<DataFrame, PolarsError> {
297        transformations::distinct(self, subset, self.case_sensitive)
298    }
299
300    /// Drop one or more columns.
301    pub fn drop(&self, columns: Vec<&str>) -> Result<DataFrame, PolarsError> {
302        transformations::drop(self, columns, self.case_sensitive)
303    }
304
305    /// Drop rows with nulls (all columns or optional subset).
306    pub fn dropna(&self, subset: Option<Vec<&str>>) -> Result<DataFrame, PolarsError> {
307        transformations::dropna(self, subset, self.case_sensitive)
308    }
309
310    /// Fill nulls with a literal expression (applied to all columns).
311    pub fn fillna(&self, value: Expr) -> Result<DataFrame, PolarsError> {
312        transformations::fillna(self, value, self.case_sensitive)
313    }
314
315    /// Limit: return first n rows.
316    pub fn limit(&self, n: usize) -> Result<DataFrame, PolarsError> {
317        transformations::limit(self, n, self.case_sensitive)
318    }
319
320    /// Rename a column (old_name -> new_name).
321    pub fn with_column_renamed(
322        &self,
323        old_name: &str,
324        new_name: &str,
325    ) -> Result<DataFrame, PolarsError> {
326        transformations::with_column_renamed(self, old_name, new_name, self.case_sensitive)
327    }
328
329    /// Replace values in a column (old_value -> new_value). PySpark replace.
330    pub fn replace(
331        &self,
332        column_name: &str,
333        old_value: Expr,
334        new_value: Expr,
335    ) -> Result<DataFrame, PolarsError> {
336        transformations::replace(self, column_name, old_value, new_value, self.case_sensitive)
337    }
338
339    /// Cross join with another DataFrame (cartesian product). PySpark crossJoin.
340    pub fn cross_join(&self, other: &DataFrame) -> Result<DataFrame, PolarsError> {
341        transformations::cross_join(self, other, self.case_sensitive)
342    }
343
344    /// Summary statistics. PySpark describe.
345    pub fn describe(&self) -> Result<DataFrame, PolarsError> {
346        transformations::describe(self, self.case_sensitive)
347    }
348
349    /// No-op: execution is eager by default. PySpark cache.
350    pub fn cache(&self) -> Result<DataFrame, PolarsError> {
351        Ok(self.clone())
352    }
353
354    /// No-op: execution is eager by default. PySpark persist.
355    pub fn persist(&self) -> Result<DataFrame, PolarsError> {
356        Ok(self.clone())
357    }
358
359    /// No-op. PySpark unpersist.
360    pub fn unpersist(&self) -> Result<DataFrame, PolarsError> {
361        Ok(self.clone())
362    }
363
364    /// Set difference: rows in self not in other. PySpark subtract / except.
365    pub fn subtract(&self, other: &DataFrame) -> Result<DataFrame, PolarsError> {
366        transformations::subtract(self, other, self.case_sensitive)
367    }
368
369    /// Set intersection: rows in both self and other. PySpark intersect.
370    pub fn intersect(&self, other: &DataFrame) -> Result<DataFrame, PolarsError> {
371        transformations::intersect(self, other, self.case_sensitive)
372    }
373
374    /// Sample a fraction of rows. PySpark sample(withReplacement, fraction, seed).
375    pub fn sample(
376        &self,
377        with_replacement: bool,
378        fraction: f64,
379        seed: Option<u64>,
380    ) -> Result<DataFrame, PolarsError> {
381        transformations::sample(self, with_replacement, fraction, seed, self.case_sensitive)
382    }
383
384    /// Split into multiple DataFrames by weights. PySpark randomSplit(weights, seed).
385    pub fn random_split(
386        &self,
387        weights: &[f64],
388        seed: Option<u64>,
389    ) -> Result<Vec<DataFrame>, PolarsError> {
390        transformations::random_split(self, weights, seed, self.case_sensitive)
391    }
392
393    /// Stratified sample by column value. PySpark sampleBy(col, fractions, seed).
394    /// fractions: list of (value as Expr, fraction) for that stratum.
395    pub fn sample_by(
396        &self,
397        col_name: &str,
398        fractions: &[(Expr, f64)],
399        seed: Option<u64>,
400    ) -> Result<DataFrame, PolarsError> {
401        transformations::sample_by(self, col_name, fractions, seed, self.case_sensitive)
402    }
403
404    /// First row as a one-row DataFrame. PySpark first().
405    pub fn first(&self) -> Result<DataFrame, PolarsError> {
406        transformations::first(self, self.case_sensitive)
407    }
408
409    /// First n rows. PySpark head(n).
410    pub fn head(&self, n: usize) -> Result<DataFrame, PolarsError> {
411        transformations::head(self, n, self.case_sensitive)
412    }
413
414    /// Take first n rows. PySpark take(n).
415    pub fn take(&self, n: usize) -> Result<DataFrame, PolarsError> {
416        transformations::take(self, n, self.case_sensitive)
417    }
418
419    /// Last n rows. PySpark tail(n).
420    pub fn tail(&self, n: usize) -> Result<DataFrame, PolarsError> {
421        transformations::tail(self, n, self.case_sensitive)
422    }
423
424    /// True if the DataFrame has zero rows. PySpark isEmpty.
425    pub fn is_empty(&self) -> bool {
426        transformations::is_empty(self)
427    }
428
429    /// Rename columns. PySpark toDF(*colNames).
430    pub fn to_df(&self, names: Vec<&str>) -> Result<DataFrame, PolarsError> {
431        transformations::to_df(self, &names, self.case_sensitive)
432    }
433
434    /// Statistical helper. PySpark df.stat().cov / .corr.
435    pub fn stat(&self) -> DataFrameStat<'_> {
436        DataFrameStat { df: self }
437    }
438
439    /// Correlation matrix of all numeric columns. PySpark df.corr() returns a DataFrame of pairwise correlations.
440    pub fn corr(&self) -> Result<DataFrame, PolarsError> {
441        self.stat().corr_matrix()
442    }
443
444    /// Pearson correlation between two columns (scalar). PySpark df.corr(col1, col2).
445    pub fn corr_cols(&self, col1: &str, col2: &str) -> Result<f64, PolarsError> {
446        self.stat().corr(col1, col2)
447    }
448
449    /// Sample covariance between two columns (scalar). PySpark df.cov(col1, col2).
450    pub fn cov_cols(&self, col1: &str, col2: &str) -> Result<f64, PolarsError> {
451        self.stat().cov(col1, col2)
452    }
453
454    /// Summary statistics (alias for describe). PySpark summary.
455    pub fn summary(&self) -> Result<DataFrame, PolarsError> {
456        self.describe()
457    }
458
459    /// Collect rows as JSON strings (one per row). PySpark toJSON.
460    pub fn to_json(&self) -> Result<Vec<String>, PolarsError> {
461        transformations::to_json(self)
462    }
463
464    /// Return execution plan description. PySpark explain.
465    pub fn explain(&self) -> String {
466        transformations::explain(self)
467    }
468
469    /// Return schema as tree string. PySpark printSchema (returns string; print to stdout if needed).
470    pub fn print_schema(&self) -> Result<String, PolarsError> {
471        transformations::print_schema(self)
472    }
473
474    /// No-op: Polars backend is eager. PySpark checkpoint.
475    pub fn checkpoint(&self) -> Result<DataFrame, PolarsError> {
476        Ok(self.clone())
477    }
478
479    /// No-op: Polars backend is eager. PySpark localCheckpoint.
480    pub fn local_checkpoint(&self) -> Result<DataFrame, PolarsError> {
481        Ok(self.clone())
482    }
483
484    /// No-op: single partition in Polars. PySpark repartition(n).
485    pub fn repartition(&self, _num_partitions: usize) -> Result<DataFrame, PolarsError> {
486        Ok(self.clone())
487    }
488
489    /// No-op: Polars has no range partitioning. PySpark repartitionByRange(n, cols).
490    pub fn repartition_by_range(
491        &self,
492        _num_partitions: usize,
493        _cols: Vec<&str>,
494    ) -> Result<DataFrame, PolarsError> {
495        Ok(self.clone())
496    }
497
498    /// Column names and dtype strings. PySpark dtypes. Returns (name, dtype_string) per column.
499    pub fn dtypes(&self) -> Result<Vec<(String, String)>, PolarsError> {
500        let schema = self.df.schema();
501        Ok(schema
502            .iter_names_and_dtypes()
503            .map(|(name, dtype)| (name.to_string(), format!("{dtype:?}")))
504            .collect())
505    }
506
507    /// No-op: we don't model partitions. PySpark sortWithinPartitions. Same as orderBy for compatibility.
508    pub fn sort_within_partitions(
509        &self,
510        _cols: &[crate::functions::SortOrder],
511    ) -> Result<DataFrame, PolarsError> {
512        Ok(self.clone())
513    }
514
515    /// No-op: single partition in Polars. PySpark coalesce(n).
516    pub fn coalesce(&self, _num_partitions: usize) -> Result<DataFrame, PolarsError> {
517        Ok(self.clone())
518    }
519
520    /// No-op. PySpark hint (query planner hint).
521    pub fn hint(&self, _name: &str, _params: &[i32]) -> Result<DataFrame, PolarsError> {
522        Ok(self.clone())
523    }
524
525    /// Returns true (eager single-node). PySpark isLocal.
526    pub fn is_local(&self) -> bool {
527        true
528    }
529
530    /// Returns empty vec (no file sources). PySpark inputFiles.
531    pub fn input_files(&self) -> Vec<String> {
532        Vec::new()
533    }
534
535    /// No-op; returns false. PySpark sameSemantics.
536    pub fn same_semantics(&self, _other: &DataFrame) -> bool {
537        false
538    }
539
540    /// No-op; returns 0. PySpark semanticHash.
541    pub fn semantic_hash(&self) -> u64 {
542        0
543    }
544
545    /// No-op. PySpark observe (metrics).
546    pub fn observe(&self, _name: &str, _expr: Expr) -> Result<DataFrame, PolarsError> {
547        Ok(self.clone())
548    }
549
550    /// No-op. PySpark withWatermark (streaming).
551    pub fn with_watermark(
552        &self,
553        _event_time: &str,
554        _delay: &str,
555    ) -> Result<DataFrame, PolarsError> {
556        Ok(self.clone())
557    }
558
559    /// Select by expression strings (minimal: column names, optionally "col as alias"). PySpark selectExpr.
560    pub fn select_expr(&self, exprs: &[String]) -> Result<DataFrame, PolarsError> {
561        transformations::select_expr(self, exprs, self.case_sensitive)
562    }
563
564    /// Select columns whose names match the regex. PySpark colRegex.
565    pub fn col_regex(&self, pattern: &str) -> Result<DataFrame, PolarsError> {
566        transformations::col_regex(self, pattern, self.case_sensitive)
567    }
568
569    /// Add or replace multiple columns. PySpark withColumns. Accepts `Column` so rand/randn get per-row values.
570    pub fn with_columns(&self, exprs: &[(String, Column)]) -> Result<DataFrame, PolarsError> {
571        transformations::with_columns(self, exprs, self.case_sensitive)
572    }
573
574    /// Rename multiple columns. PySpark withColumnsRenamed.
575    pub fn with_columns_renamed(
576        &self,
577        renames: &[(String, String)],
578    ) -> Result<DataFrame, PolarsError> {
579        transformations::with_columns_renamed(self, renames, self.case_sensitive)
580    }
581
582    /// NA sub-API. PySpark df.na().
583    pub fn na(&self) -> DataFrameNa<'_> {
584        DataFrameNa { df: self }
585    }
586
587    /// Skip first n rows. PySpark offset(n).
588    pub fn offset(&self, n: usize) -> Result<DataFrame, PolarsError> {
589        transformations::offset(self, n, self.case_sensitive)
590    }
591
592    /// Transform by a function. PySpark transform(func).
593    pub fn transform<F>(&self, f: F) -> Result<DataFrame, PolarsError>
594    where
595        F: FnOnce(DataFrame) -> Result<DataFrame, PolarsError>,
596    {
597        transformations::transform(self, f)
598    }
599
600    /// Frequent items. PySpark freqItems (stub).
601    pub fn freq_items(&self, columns: &[&str], support: f64) -> Result<DataFrame, PolarsError> {
602        transformations::freq_items(self, columns, support, self.case_sensitive)
603    }
604
605    /// Approximate quantiles. PySpark approxQuantile (stub).
606    pub fn approx_quantile(
607        &self,
608        column: &str,
609        probabilities: &[f64],
610    ) -> Result<DataFrame, PolarsError> {
611        transformations::approx_quantile(self, column, probabilities, self.case_sensitive)
612    }
613
614    /// Cross-tabulation. PySpark crosstab (stub).
615    pub fn crosstab(&self, col1: &str, col2: &str) -> Result<DataFrame, PolarsError> {
616        transformations::crosstab(self, col1, col2, self.case_sensitive)
617    }
618
619    /// Unpivot (melt). PySpark melt (stub).
620    pub fn melt(&self, id_vars: &[&str], value_vars: &[&str]) -> Result<DataFrame, PolarsError> {
621        transformations::melt(self, id_vars, value_vars, self.case_sensitive)
622    }
623
624    /// Pivot (wide format). PySpark pivot. Stub: not yet implemented; use crosstab for two-column count.
625    pub fn pivot(
626        &self,
627        _pivot_col: &str,
628        _values: Option<Vec<&str>>,
629    ) -> Result<DataFrame, PolarsError> {
630        Err(PolarsError::InvalidOperation(
631            "pivot is not yet implemented; use crosstab(col1, col2) for two-column cross-tabulation."
632                .into(),
633        ))
634    }
635
636    /// Set difference keeping duplicates. PySpark exceptAll.
637    pub fn except_all(&self, other: &DataFrame) -> Result<DataFrame, PolarsError> {
638        transformations::except_all(self, other, self.case_sensitive)
639    }
640
641    /// Set intersection keeping duplicates. PySpark intersectAll.
642    pub fn intersect_all(&self, other: &DataFrame) -> Result<DataFrame, PolarsError> {
643        transformations::intersect_all(self, other, self.case_sensitive)
644    }
645
646    /// Write this DataFrame to a Delta table at the given path.
647    /// Requires the `delta` feature. If `overwrite` is true, replaces the table; otherwise appends.
648    #[cfg(feature = "delta")]
649    pub fn write_delta(
650        &self,
651        path: impl AsRef<std::path::Path>,
652        overwrite: bool,
653    ) -> Result<(), PolarsError> {
654        crate::delta::write_delta(self.df.as_ref(), path, overwrite)
655    }
656
657    /// Stub when `delta` feature is disabled.
658    #[cfg(not(feature = "delta"))]
659    pub fn write_delta(
660        &self,
661        _path: impl AsRef<std::path::Path>,
662        _overwrite: bool,
663    ) -> Result<(), PolarsError> {
664        Err(PolarsError::InvalidOperation(
665            "Delta Lake requires the 'delta' feature. Build with --features delta.".into(),
666        ))
667    }
668
669    /// Return a writer for generic format (parquet, csv, json). PySpark-style write API.
670    pub fn write(&self) -> DataFrameWriter<'_> {
671        DataFrameWriter {
672            df: self,
673            mode: WriteMode::Overwrite,
674            format: WriteFormat::Parquet,
675            options: HashMap::new(),
676            partition_by: Vec::new(),
677        }
678    }
679}
680
/// Write mode: overwrite or append (PySpark DataFrameWriter.mode).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum WriteMode {
    /// Replace whatever currently exists at the target path.
    Overwrite,
    /// Combine the new rows with output already at the target path, when there is any
    /// (see `DataFrameWriter::save` for the exact rules).
    Append,
}
687
/// Output format for generic write (PySpark DataFrameWriter.format).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum WriteFormat {
    /// Apache Parquet file.
    Parquet,
    /// Comma-separated values (header and separator configurable via writer options).
    Csv,
    /// Newline-delimited JSON.
    Json,
}
695
696/// Builder for writing DataFrame to path (PySpark DataFrameWriter).
697pub struct DataFrameWriter<'a> {
698    df: &'a DataFrame,
699    mode: WriteMode,
700    format: WriteFormat,
701    options: HashMap<String, String>,
702    partition_by: Vec<String>,
703}
704
705impl<'a> DataFrameWriter<'a> {
706    pub fn mode(mut self, mode: WriteMode) -> Self {
707        self.mode = mode;
708        self
709    }
710
711    pub fn format(mut self, format: WriteFormat) -> Self {
712        self.format = format;
713        self
714    }
715
716    /// Add a single option (PySpark: option(key, value)). Returns self for chaining.
717    pub fn option(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
718        self.options.insert(key.into(), value.into());
719        self
720    }
721
722    /// Add multiple options (PySpark: options(**kwargs)). Returns self for chaining.
723    pub fn options(mut self, opts: impl IntoIterator<Item = (String, String)>) -> Self {
724        for (k, v) in opts {
725            self.options.insert(k, v);
726        }
727        self
728    }
729
730    /// Partition output by the given columns (PySpark: partitionBy(cols)).
731    pub fn partition_by(mut self, cols: impl IntoIterator<Item = impl Into<String>>) -> Self {
732        self.partition_by = cols.into_iter().map(|s| s.into()).collect();
733        self
734    }
735
736    /// Write as Parquet (PySpark: parquet(path)). Equivalent to format(Parquet).save(path).
737    pub fn parquet(&self, path: impl AsRef<std::path::Path>) -> Result<(), PolarsError> {
738        DataFrameWriter {
739            df: self.df,
740            mode: self.mode,
741            format: WriteFormat::Parquet,
742            options: self.options.clone(),
743            partition_by: self.partition_by.clone(),
744        }
745        .save(path)
746    }
747
748    /// Write as CSV (PySpark: csv(path)). Equivalent to format(Csv).save(path).
749    pub fn csv(&self, path: impl AsRef<std::path::Path>) -> Result<(), PolarsError> {
750        DataFrameWriter {
751            df: self.df,
752            mode: self.mode,
753            format: WriteFormat::Csv,
754            options: self.options.clone(),
755            partition_by: self.partition_by.clone(),
756        }
757        .save(path)
758    }
759
760    /// Write as JSON lines (PySpark: json(path)). Equivalent to format(Json).save(path).
761    pub fn json(&self, path: impl AsRef<std::path::Path>) -> Result<(), PolarsError> {
762        DataFrameWriter {
763            df: self.df,
764            mode: self.mode,
765            format: WriteFormat::Json,
766            options: self.options.clone(),
767            partition_by: self.partition_by.clone(),
768        }
769        .save(path)
770    }
771
772    /// Write to path. Overwrite replaces; append reads existing (if any) and concatenates then writes.
773    /// With partition_by, path is a directory; each partition is written as path/col1=val1/col2=val2/... with partition columns omitted from the file (Spark/Hive style).
774    pub fn save(&self, path: impl AsRef<std::path::Path>) -> Result<(), PolarsError> {
775        use polars::prelude::*;
776        let path = path.as_ref();
777        let to_write: PlDataFrame = match self.mode {
778            WriteMode::Overwrite => self.df.df.as_ref().clone(),
779            WriteMode::Append => {
780                if self.partition_by.is_empty() {
781                    let existing: Option<PlDataFrame> = if path.exists() && path.is_file() {
782                        match self.format {
783                            WriteFormat::Parquet => {
784                                LazyFrame::scan_parquet(path, ScanArgsParquet::default())
785                                    .and_then(|lf| lf.collect())
786                                    .ok()
787                            }
788                            WriteFormat::Csv => LazyCsvReader::new(path)
789                                .with_has_header(true)
790                                .finish()
791                                .and_then(|lf| lf.collect())
792                                .ok(),
793                            WriteFormat::Json => LazyJsonLineReader::new(path)
794                                .finish()
795                                .and_then(|lf| lf.collect())
796                                .ok(),
797                        }
798                    } else {
799                        None
800                    };
801                    match existing {
802                        Some(existing) => {
803                            let lfs: [LazyFrame; 2] =
804                                [existing.lazy(), self.df.df.as_ref().clone().lazy()];
805                            concat(lfs, UnionArgs::default())?.collect()?
806                        }
807                        None => self.df.df.as_ref().clone(),
808                    }
809                } else {
810                    self.df.df.as_ref().clone()
811                }
812            }
813        };
814
815        if !self.partition_by.is_empty() {
816            return self.save_partitioned(path, &to_write);
817        }
818
819        match self.format {
820            WriteFormat::Parquet => {
821                let mut file = std::fs::File::create(path).map_err(|e| {
822                    PolarsError::ComputeError(format!("write parquet create: {e}").into())
823                })?;
824                let mut df_mut = to_write;
825                ParquetWriter::new(&mut file)
826                    .finish(&mut df_mut)
827                    .map_err(|e| PolarsError::ComputeError(format!("write parquet: {e}").into()))?;
828            }
829            WriteFormat::Csv => {
830                let has_header = self
831                    .options
832                    .get("header")
833                    .map(|v| v.eq_ignore_ascii_case("true") || v == "1")
834                    .unwrap_or(true);
835                let delimiter = self
836                    .options
837                    .get("sep")
838                    .and_then(|s| s.bytes().next())
839                    .unwrap_or(b',');
840                let mut file = std::fs::File::create(path).map_err(|e| {
841                    PolarsError::ComputeError(format!("write csv create: {e}").into())
842                })?;
843                CsvWriter::new(&mut file)
844                    .include_header(has_header)
845                    .with_separator(delimiter)
846                    .finish(&mut to_write.clone())
847                    .map_err(|e| PolarsError::ComputeError(format!("write csv: {e}").into()))?;
848            }
849            WriteFormat::Json => {
850                let mut file = std::fs::File::create(path).map_err(|e| {
851                    PolarsError::ComputeError(format!("write json create: {e}").into())
852                })?;
853                JsonWriter::new(&mut file)
854                    .finish(&mut to_write.clone())
855                    .map_err(|e| PolarsError::ComputeError(format!("write json: {e}").into()))?;
856            }
857        }
858        Ok(())
859    }
860
861    /// Write partitioned by columns: path/col1=val1/col2=val2/part-00000.{ext}. Partition columns are not written into the file (Spark/Hive style).
862    fn save_partitioned(&self, path: &Path, to_write: &PlDataFrame) -> Result<(), PolarsError> {
863        use polars::prelude::*;
864        let resolved: Vec<String> = self
865            .partition_by
866            .iter()
867            .map(|c| self.df.resolve_column_name(c))
868            .collect::<Result<Vec<_>, _>>()?;
869        let all_names = to_write.get_column_names();
870        let data_cols: Vec<&str> = all_names
871            .iter()
872            .filter(|n| !resolved.iter().any(|r| r == n.as_str()))
873            .map(|n| n.as_str())
874            .collect();
875
876        let unique_keys = to_write
877            .select(resolved.iter().map(|s| s.as_str()).collect::<Vec<_>>())?
878            .unique::<Option<&[String]>, String>(
879                None,
880                polars::prelude::UniqueKeepStrategy::First,
881                None,
882            )?;
883
884        if self.mode == WriteMode::Overwrite && path.exists() {
885            if path.is_dir() {
886                std::fs::remove_dir_all(path).map_err(|e| {
887                    PolarsError::ComputeError(
888                        format!("write partitioned: remove_dir_all: {e}").into(),
889                    )
890                })?;
891            } else {
892                std::fs::remove_file(path).map_err(|e| {
893                    PolarsError::ComputeError(format!("write partitioned: remove_file: {e}").into())
894                })?;
895            }
896        }
897        std::fs::create_dir_all(path).map_err(|e| {
898            PolarsError::ComputeError(format!("write partitioned: create_dir_all: {e}").into())
899        })?;
900
901        let ext = match self.format {
902            WriteFormat::Parquet => "parquet",
903            WriteFormat::Csv => "csv",
904            WriteFormat::Json => "json",
905        };
906
907        for row_idx in 0..unique_keys.height() {
908            let row = unique_keys
909                .get(row_idx)
910                .ok_or_else(|| PolarsError::ComputeError("partition_row: get row".into()))?;
911            let filter_expr = partition_row_to_filter_expr(&resolved, &row)?;
912            let subset = to_write.clone().lazy().filter(filter_expr).collect()?;
913            let subset = subset.select(data_cols.iter().copied())?;
914            if subset.height() == 0 {
915                continue;
916            }
917
918            let part_path: std::path::PathBuf = resolved
919                .iter()
920                .zip(row.iter())
921                .map(|(name, av)| format!("{}={}", name, format_partition_value(av)))
922                .fold(path.to_path_buf(), |p, seg| p.join(seg));
923            std::fs::create_dir_all(&part_path).map_err(|e| {
924                PolarsError::ComputeError(
925                    format!("write partitioned: create_dir_all partition: {e}").into(),
926                )
927            })?;
928
929            let file_idx = if self.mode == WriteMode::Append {
930                let suffix = format!(".{ext}");
931                let max_n = std::fs::read_dir(&part_path)
932                    .map(|rd| {
933                        rd.filter_map(Result::ok)
934                            .filter_map(|e| {
935                                e.file_name().to_str().and_then(|s| {
936                                    s.strip_prefix("part-")
937                                        .and_then(|t| t.strip_suffix(&suffix))
938                                        .and_then(|t| t.parse::<u32>().ok())
939                                })
940                            })
941                            .max()
942                            .unwrap_or(0)
943                    })
944                    .unwrap_or(0);
945                max_n + 1
946            } else {
947                0
948            };
949            let filename = format!("part-{file_idx:05}.{ext}");
950            let file_path = part_path.join(&filename);
951
952            match self.format {
953                WriteFormat::Parquet => {
954                    let mut file = std::fs::File::create(&file_path).map_err(|e| {
955                        PolarsError::ComputeError(
956                            format!("write partitioned parquet create: {e}").into(),
957                        )
958                    })?;
959                    let mut df_mut = subset;
960                    ParquetWriter::new(&mut file)
961                        .finish(&mut df_mut)
962                        .map_err(|e| {
963                            PolarsError::ComputeError(
964                                format!("write partitioned parquet: {e}").into(),
965                            )
966                        })?;
967                }
968                WriteFormat::Csv => {
969                    let has_header = self
970                        .options
971                        .get("header")
972                        .map(|v| v.eq_ignore_ascii_case("true") || v == "1")
973                        .unwrap_or(true);
974                    let delimiter = self
975                        .options
976                        .get("sep")
977                        .and_then(|s| s.bytes().next())
978                        .unwrap_or(b',');
979                    let mut file = std::fs::File::create(&file_path).map_err(|e| {
980                        PolarsError::ComputeError(
981                            format!("write partitioned csv create: {e}").into(),
982                        )
983                    })?;
984                    CsvWriter::new(&mut file)
985                        .include_header(has_header)
986                        .with_separator(delimiter)
987                        .finish(&mut subset.clone())
988                        .map_err(|e| {
989                            PolarsError::ComputeError(format!("write partitioned csv: {e}").into())
990                        })?;
991                }
992                WriteFormat::Json => {
993                    let mut file = std::fs::File::create(&file_path).map_err(|e| {
994                        PolarsError::ComputeError(
995                            format!("write partitioned json create: {e}").into(),
996                        )
997                    })?;
998                    JsonWriter::new(&mut file)
999                        .finish(&mut subset.clone())
1000                        .map_err(|e| {
1001                            PolarsError::ComputeError(format!("write partitioned json: {e}").into())
1002                        })?;
1003                }
1004            }
1005        }
1006        Ok(())
1007    }
1008}
1009
1010impl Clone for DataFrame {
1011    fn clone(&self) -> Self {
1012        DataFrame {
1013            df: self.df.clone(),
1014            case_sensitive: self.case_sensitive,
1015        }
1016    }
1017}
1018
1019/// Format a partition column value for use in a directory name (Spark/Hive style).
1020/// Null becomes "__HIVE_DEFAULT_PARTITION__"; other values use string representation with path-unsafe chars replaced.
1021fn format_partition_value(av: &AnyValue<'_>) -> String {
1022    let s = match av {
1023        AnyValue::Null => "__HIVE_DEFAULT_PARTITION__".to_string(),
1024        AnyValue::Boolean(b) => b.to_string(),
1025        AnyValue::Int32(i) => i.to_string(),
1026        AnyValue::Int64(i) => i.to_string(),
1027        AnyValue::UInt32(u) => u.to_string(),
1028        AnyValue::UInt64(u) => u.to_string(),
1029        AnyValue::Float32(f) => f.to_string(),
1030        AnyValue::Float64(f) => f.to_string(),
1031        AnyValue::String(s) => s.to_string(),
1032        AnyValue::StringOwned(s) => s.as_str().to_string(),
1033        AnyValue::Date(d) => d.to_string(),
1034        _ => av.to_string(),
1035    };
1036    // Replace path separators and other unsafe chars so the value is a valid path segment
1037    s.replace([std::path::MAIN_SEPARATOR, '/'], "_")
1038}
1039
1040/// Build a filter expression that matches rows where partition columns equal the given row values.
1041fn partition_row_to_filter_expr(
1042    col_names: &[String],
1043    row: &[AnyValue<'_>],
1044) -> Result<Expr, PolarsError> {
1045    if col_names.len() != row.len() {
1046        return Err(PolarsError::ComputeError(
1047            format!(
1048                "partition_row_to_filter_expr: {} columns but {} row values",
1049                col_names.len(),
1050                row.len()
1051            )
1052            .into(),
1053        ));
1054    }
1055    let mut pred = None::<Expr>;
1056    for (name, av) in col_names.iter().zip(row.iter()) {
1057        let clause = match av {
1058            AnyValue::Null => col(name.as_str()).is_null(),
1059            AnyValue::Boolean(b) => col(name.as_str()).eq(lit(*b)),
1060            AnyValue::Int32(i) => col(name.as_str()).eq(lit(*i)),
1061            AnyValue::Int64(i) => col(name.as_str()).eq(lit(*i)),
1062            AnyValue::UInt32(u) => col(name.as_str()).eq(lit(*u)),
1063            AnyValue::UInt64(u) => col(name.as_str()).eq(lit(*u)),
1064            AnyValue::Float32(f) => col(name.as_str()).eq(lit(*f)),
1065            AnyValue::Float64(f) => col(name.as_str()).eq(lit(*f)),
1066            AnyValue::String(s) => col(name.as_str()).eq(lit(s.to_string())),
1067            AnyValue::StringOwned(s) => col(name.as_str()).eq(lit(s.clone())),
1068            _ => {
1069                // Fallback: compare as string
1070                let s = av.to_string();
1071                col(name.as_str()).cast(DataType::String).eq(lit(s))
1072            }
1073        };
1074        pred = Some(match pred {
1075            None => clause,
1076            Some(p) => p.and(clause),
1077        });
1078    }
1079    Ok(pred.unwrap_or_else(|| lit(true)))
1080}
1081
1082/// Convert Polars AnyValue to serde_json::Value for language bindings (Node, etc.).
1083fn any_value_to_json(av: AnyValue<'_>) -> JsonValue {
1084    match av {
1085        AnyValue::Null => JsonValue::Null,
1086        AnyValue::Boolean(b) => JsonValue::Bool(b),
1087        AnyValue::Int32(i) => JsonValue::Number(serde_json::Number::from(i)),
1088        AnyValue::Int64(i) => JsonValue::Number(serde_json::Number::from(i)),
1089        AnyValue::UInt32(u) => JsonValue::Number(serde_json::Number::from(u)),
1090        AnyValue::UInt64(u) => JsonValue::Number(serde_json::Number::from(u)),
1091        AnyValue::Float32(f) => serde_json::Number::from_f64(f64::from(f))
1092            .map(JsonValue::Number)
1093            .unwrap_or(JsonValue::Null),
1094        AnyValue::Float64(f) => serde_json::Number::from_f64(f)
1095            .map(JsonValue::Number)
1096            .unwrap_or(JsonValue::Null),
1097        AnyValue::String(s) => JsonValue::String(s.to_string()),
1098        AnyValue::StringOwned(s) => JsonValue::String(s.to_string()),
1099        _ => JsonValue::Null,
1100    }
1101}