Skip to main content

robin_sparkless/dataframe/
mod.rs

1//! DataFrame module: main tabular type and submodules for transformations, aggregations, joins, stats.
2
3mod aggregations;
4mod joins;
5mod stats;
6mod transformations;
7
8pub use aggregations::{CubeRollupData, GroupedData};
9pub use joins::{join, JoinType};
10pub use stats::DataFrameStat;
11pub use transformations::{
12    filter, order_by, order_by_exprs, select, select_with_exprs, with_column, DataFrameNa,
13};
14
15use crate::column::Column;
16use crate::functions::SortOrder;
17use crate::schema::StructType;
18use crate::session::SparkSession;
19use crate::type_coercion::coerce_for_pyspark_comparison;
20use polars::prelude::{
21    col, lit, AnyValue, DataFrame as PlDataFrame, DataType, Expr, PlSmallStr, PolarsError,
22    SchemaNamesAndDtypes,
23};
24use serde_json::Value as JsonValue;
25use std::collections::{HashMap, HashSet};
26use std::path::Path;
27use std::sync::Arc;
28
/// Value used for `spark.sql.caseSensitive` when none is given. As in PySpark the
/// default is `false`, i.e. column names are matched case-insensitively.
const DEFAULT_CASE_SENSITIVE: bool = false;
31
32/// DataFrame - main tabular data structure.
33/// Thin wrapper around an eager Polars `DataFrame`.
34pub struct DataFrame {
35    pub(crate) df: Arc<PlDataFrame>,
36    /// When false (default), column names are matched case-insensitively (PySpark behavior).
37    pub(crate) case_sensitive: bool,
38}
39
40impl DataFrame {
41    /// Create a new DataFrame from a Polars DataFrame (case-insensitive column matching by default).
42    pub fn from_polars(df: PlDataFrame) -> Self {
43        DataFrame {
44            df: Arc::new(df),
45            case_sensitive: DEFAULT_CASE_SENSITIVE,
46        }
47    }
48
49    /// Create a new DataFrame from a Polars DataFrame with explicit case sensitivity.
50    /// When `case_sensitive` is false, column resolution is case-insensitive (PySpark default).
51    pub fn from_polars_with_options(df: PlDataFrame, case_sensitive: bool) -> Self {
52        DataFrame {
53            df: Arc::new(df),
54            case_sensitive,
55        }
56    }
57
58    /// Create an empty DataFrame
59    pub fn empty() -> Self {
60        DataFrame {
61            df: Arc::new(PlDataFrame::empty()),
62            case_sensitive: DEFAULT_CASE_SENSITIVE,
63        }
64    }
65
66    /// Resolve column names in a Polars expression against this DataFrame's schema.
67    /// When case_sensitive is false, column references (e.g. col("name")) are resolved
68    /// case-insensitively (PySpark default). Use before filter/select_with_exprs/order_by_exprs.
69    /// Names that appear as alias outputs (e.g. in expr.alias("partial")) are not resolved
70    /// as input columns, so select(col("x").substr(1, 3).alias("partial")),
71    /// when().then().otherwise().alias("result"), and col("x").rank().over([]).alias("rank") work (issues #200, #212).
72    pub fn resolve_expr_column_names(&self, expr: Expr) -> Result<Expr, PolarsError> {
73        let df = self;
74        let mut alias_output_names: HashSet<String> = HashSet::new();
75        let _ = expr.clone().try_map_expr(|e| {
76            if let Expr::Alias(_, name) = &e {
77                alias_output_names.insert(name.as_str().to_string());
78            }
79            Ok(e)
80        })?;
81        expr.try_map_expr(move |e| {
82            if let Expr::Column(name) = &e {
83                let name_str = name.as_str();
84                if alias_output_names.contains(name_str) {
85                    return Ok(e);
86                }
87                let resolved = df.resolve_column_name(name_str)?;
88                return Ok(Expr::Column(PlSmallStr::from(resolved.as_str())));
89            }
90            Ok(e)
91        })
92    }
93
94    /// Rewrite comparison expressions to apply PySpark-style type coercion.
95    ///
96    /// This walks the expression tree and, for comparison operators where one side is
97    /// a column and the other is a numeric literal, delegates to
98    /// `coerce_for_pyspark_comparison` so that string–numeric comparisons behave like
99    /// PySpark (string values parsed to numbers where possible, invalid strings treated
100    /// as null/non-matching).
101    pub fn coerce_string_numeric_comparisons(&self, expr: Expr) -> Result<Expr, PolarsError> {
102        use polars::prelude::{DataType, LiteralValue, Operator};
103        use std::sync::Arc;
104
105        fn is_numeric_literal(expr: &Expr) -> bool {
106            matches!(
107                expr,
108                Expr::Literal(
109                    LiteralValue::Int32(_)
110                        | LiteralValue::Int64(_)
111                        | LiteralValue::UInt32(_)
112                        | LiteralValue::UInt64(_)
113                        | LiteralValue::Float32(_)
114                        | LiteralValue::Float64(_)
115                        | LiteralValue::Int(_)   // dynamic int (e.g. lit(123) from some code paths)
116                        | LiteralValue::Float(_) // dynamic float
117                )
118            )
119        }
120
121        fn literal_dtype(lv: &LiteralValue) -> DataType {
122            match lv {
123                LiteralValue::Int32(_) => DataType::Int32,
124                LiteralValue::Int64(_) => DataType::Int64,
125                LiteralValue::UInt32(_) => DataType::UInt32,
126                LiteralValue::UInt64(_) => DataType::UInt64,
127                LiteralValue::Float32(_) => DataType::Float32,
128                LiteralValue::Float64(_) => DataType::Float64,
129                LiteralValue::Int(_) | LiteralValue::Float(_) => DataType::Float64,
130                _ => DataType::Float64,
131            }
132        }
133
134        // Apply root-level coercion first so the top-level filter condition (e.g. col("str_col") == lit(123))
135        // is always rewritten even if try_map_expr traversal does not hit the root in the expected order.
136        let expr = {
137            if let Expr::BinaryExpr { left, op, right } = &expr {
138                let is_comparison_op = matches!(
139                    op,
140                    Operator::Eq
141                        | Operator::NotEq
142                        | Operator::Lt
143                        | Operator::LtEq
144                        | Operator::Gt
145                        | Operator::GtEq
146                );
147                let left_is_col = matches!(&**left, Expr::Column(_));
148                let right_is_col = matches!(&**right, Expr::Column(_));
149                let left_is_numeric_lit =
150                    matches!(&**left, Expr::Literal(_)) && is_numeric_literal(left.as_ref());
151                let right_is_numeric_lit =
152                    matches!(&**right, Expr::Literal(_)) && is_numeric_literal(right.as_ref());
153                let root_is_col_vs_numeric = is_comparison_op
154                    && ((left_is_col && right_is_numeric_lit)
155                        || (right_is_col && left_is_numeric_lit));
156                if root_is_col_vs_numeric {
157                    let (new_left, new_right) = if left_is_col && right_is_numeric_lit {
158                        let lit_ty = match &**right {
159                            Expr::Literal(lv) => literal_dtype(lv),
160                            _ => DataType::Float64,
161                        };
162                        coerce_for_pyspark_comparison(
163                            (*left).as_ref().clone(),
164                            (*right).as_ref().clone(),
165                            &DataType::String,
166                            &lit_ty,
167                            op,
168                        )
169                        .map_err(|e| PolarsError::ComputeError(e.to_string().into()))?
170                    } else {
171                        let lit_ty = match &**left {
172                            Expr::Literal(lv) => literal_dtype(lv),
173                            _ => DataType::Float64,
174                        };
175                        coerce_for_pyspark_comparison(
176                            (*left).as_ref().clone(),
177                            (*right).as_ref().clone(),
178                            &lit_ty,
179                            &DataType::String,
180                            op,
181                        )
182                        .map_err(|e| PolarsError::ComputeError(e.to_string().into()))?
183                    };
184                    Expr::BinaryExpr {
185                        left: Arc::new(new_left),
186                        op: *op,
187                        right: Arc::new(new_right),
188                    }
189                } else {
190                    expr
191                }
192            } else {
193                expr
194            }
195        };
196
197        // Then walk the tree for nested comparisons (e.g. (col("a")==1) & (col("b")==2)).
198        expr.try_map_expr(move |e| {
199            if let Expr::BinaryExpr { left, op, right } = e {
200                let is_comparison_op = matches!(
201                    op,
202                    Operator::Eq
203                        | Operator::NotEq
204                        | Operator::Lt
205                        | Operator::LtEq
206                        | Operator::Gt
207                        | Operator::GtEq
208                );
209                if !is_comparison_op {
210                    return Ok(Expr::BinaryExpr { left, op, right });
211                }
212
213                let left_is_col = matches!(&*left, Expr::Column(_));
214                let right_is_col = matches!(&*right, Expr::Column(_));
215                let left_is_lit = matches!(&*left, Expr::Literal(_));
216                let right_is_lit = matches!(&*right, Expr::Literal(_));
217
218                let left_is_numeric_lit = left_is_lit && is_numeric_literal(left.as_ref());
219                let right_is_numeric_lit = right_is_lit && is_numeric_literal(right.as_ref());
220
221                // Heuristic: for column-vs-numeric-literal, treat the column as "string-like"
222                // and the literal as numeric, so coerce_for_pyspark_comparison will route
223                // the column through try_to_number and compare as doubles.
224                let (new_left, new_right) = if left_is_col && right_is_numeric_lit {
225                    let lit_ty = match &*right {
226                        Expr::Literal(lv) => literal_dtype(lv),
227                        _ => DataType::Float64,
228                    };
229                    coerce_for_pyspark_comparison(
230                        (*left).clone(),
231                        (*right).clone(),
232                        &DataType::String,
233                        &lit_ty,
234                        &op,
235                    )
236                    .map_err(|e| PolarsError::ComputeError(e.to_string().into()))?
237                } else if right_is_col && left_is_numeric_lit {
238                    let lit_ty = match &*left {
239                        Expr::Literal(lv) => literal_dtype(lv),
240                        _ => DataType::Float64,
241                    };
242                    coerce_for_pyspark_comparison(
243                        (*left).clone(),
244                        (*right).clone(),
245                        &lit_ty,
246                        &DataType::String,
247                        &op,
248                    )
249                    .map_err(|e| PolarsError::ComputeError(e.to_string().into()))?
250                } else {
251                    // Leave other comparison forms (col-col, lit-lit, non-numeric) unchanged.
252                    return Ok(Expr::BinaryExpr { left, op, right });
253                };
254
255                Ok(Expr::BinaryExpr {
256                    left: Arc::new(new_left),
257                    op,
258                    right: Arc::new(new_right),
259                })
260            } else {
261                Ok(e)
262            }
263        })
264    }
265
266    /// Resolve a logical column name to the actual column name in the schema.
267    /// When case_sensitive is false, matches case-insensitively.
268    pub fn resolve_column_name(&self, name: &str) -> Result<String, PolarsError> {
269        let names = self.df.get_column_names();
270        if self.case_sensitive {
271            if names.iter().any(|n| *n == name) {
272                return Ok(name.to_string());
273            }
274        } else {
275            let name_lower = name.to_lowercase();
276            for n in names {
277                if n.to_lowercase() == name_lower {
278                    return Ok(n.to_string());
279                }
280            }
281        }
282        let available: Vec<String> = self
283            .df
284            .get_column_names()
285            .iter()
286            .map(|s| s.to_string())
287            .collect();
288        Err(PolarsError::ColumnNotFound(
289            format!(
290                "Column '{}' not found. Available columns: [{}]. Check spelling and case sensitivity (spark.sql.caseSensitive).",
291                name,
292                available.join(", ")
293            )
294            .into(),
295        ))
296    }
297
298    /// Get the schema of the DataFrame
299    pub fn schema(&self) -> Result<StructType, PolarsError> {
300        Ok(StructType::from_polars_schema(&self.df.schema()))
301    }
302
303    /// Get column names
304    pub fn columns(&self) -> Result<Vec<String>, PolarsError> {
305        Ok(self
306            .df
307            .get_column_names()
308            .iter()
309            .map(|s| s.to_string())
310            .collect())
311    }
312
313    /// Count the number of rows (action - triggers execution)
314    pub fn count(&self) -> Result<usize, PolarsError> {
315        Ok(self.df.height())
316    }
317
318    /// Show the first n rows
319    pub fn show(&self, n: Option<usize>) -> Result<(), PolarsError> {
320        let n = n.unwrap_or(20);
321        println!("{}", self.df.head(Some(n)));
322        Ok(())
323    }
324
325    /// Collect the DataFrame (action - triggers execution)
326    pub fn collect(&self) -> Result<Arc<PlDataFrame>, PolarsError> {
327        Ok(self.df.clone())
328    }
329
330    /// Collect as rows of column-name -> JSON value. For use by language bindings (Node, etc.).
331    pub fn collect_as_json_rows(&self) -> Result<Vec<HashMap<String, JsonValue>>, PolarsError> {
332        let df = self.df.as_ref();
333        let names = df.get_column_names();
334        let nrows = df.height();
335        let mut rows = Vec::with_capacity(nrows);
336        for i in 0..nrows {
337            let mut row = HashMap::with_capacity(names.len());
338            for (col_idx, name) in names.iter().enumerate() {
339                let s = df
340                    .get_columns()
341                    .get(col_idx)
342                    .ok_or_else(|| PolarsError::ComputeError("column index out of range".into()))?;
343                let av = s.get(i)?;
344                let jv = any_value_to_json(av);
345                row.insert(name.to_string(), jv);
346            }
347            rows.push(row);
348        }
349        Ok(rows)
350    }
351
352    /// Select columns (returns a new DataFrame).
353    /// Accepts either column names (strings) or Column expressions (e.g. from regexp_extract_all(...).alias("m")).
354    /// Column names are resolved according to case sensitivity.
355    pub fn select_exprs(&self, exprs: Vec<Expr>) -> Result<DataFrame, PolarsError> {
356        transformations::select_with_exprs(self, exprs, self.case_sensitive)
357    }
358
359    /// Select columns by name (returns a new DataFrame).
360    /// Column names are resolved according to case sensitivity.
361    pub fn select(&self, cols: Vec<&str>) -> Result<DataFrame, PolarsError> {
362        let resolved: Vec<String> = cols
363            .iter()
364            .map(|c| self.resolve_column_name(c))
365            .collect::<Result<Vec<_>, _>>()?;
366        let refs: Vec<&str> = resolved.iter().map(|s| s.as_str()).collect();
367        let mut result = transformations::select(self, refs, self.case_sensitive)?;
368        // When case-insensitive, PySpark returns column names in requested (e.g. lowercase) form.
369        if !self.case_sensitive {
370            for (requested, res) in cols.iter().zip(resolved.iter()) {
371                if *requested != res.as_str() {
372                    result = result.with_column_renamed(res, requested)?;
373                }
374            }
375        }
376        Ok(result)
377    }
378
379    /// Filter rows using a Polars expression.
380    pub fn filter(&self, condition: Expr) -> Result<DataFrame, PolarsError> {
381        transformations::filter(self, condition, self.case_sensitive)
382    }
383
384    /// Get a column reference by name (for building expressions).
385    /// Respects case sensitivity: when false, "Age" resolves to column "age" if present.
386    pub fn column(&self, name: &str) -> Result<Column, PolarsError> {
387        let resolved = self.resolve_column_name(name)?;
388        Ok(Column::new(resolved))
389    }
390
391    /// Add or replace a column. Use a [`Column`] (e.g. from `col("x")`, `rand(42)`, `randn(42)`).
392    /// For `rand`/`randn`, generates one distinct value per row (PySpark-like).
393    pub fn with_column(&self, column_name: &str, col: &Column) -> Result<DataFrame, PolarsError> {
394        transformations::with_column(self, column_name, col, self.case_sensitive)
395    }
396
397    /// Add or replace a column using an expression. Prefer [`with_column`](Self::with_column) with a `Column` for rand/randn (per-row values).
398    pub fn with_column_expr(
399        &self,
400        column_name: &str,
401        expr: Expr,
402    ) -> Result<DataFrame, PolarsError> {
403        let col = Column::from_expr(expr, None);
404        self.with_column(column_name, &col)
405    }
406
407    /// Group by columns (returns GroupedData for aggregation).
408    /// Column names are resolved according to case sensitivity.
409    pub fn group_by(&self, column_names: Vec<&str>) -> Result<GroupedData, PolarsError> {
410        use polars::prelude::*;
411        let resolved: Vec<String> = column_names
412            .iter()
413            .map(|c| self.resolve_column_name(c))
414            .collect::<Result<Vec<_>, _>>()?;
415        let exprs: Vec<Expr> = resolved.iter().map(|name| col(name.as_str())).collect();
416        let pl_df = self.df.as_ref().clone();
417        let lazy_grouped = pl_df.clone().lazy().group_by(exprs);
418        Ok(GroupedData {
419            df: pl_df,
420            lazy_grouped,
421            grouping_cols: resolved,
422            case_sensitive: self.case_sensitive,
423        })
424    }
425
426    /// Cube: multiple grouping sets (all subsets of columns), then union (PySpark cube).
427    pub fn cube(&self, column_names: Vec<&str>) -> Result<CubeRollupData, PolarsError> {
428        let resolved: Vec<String> = column_names
429            .iter()
430            .map(|c| self.resolve_column_name(c))
431            .collect::<Result<Vec<_>, _>>()?;
432        Ok(CubeRollupData {
433            df: self.df.as_ref().clone(),
434            grouping_cols: resolved,
435            case_sensitive: self.case_sensitive,
436            is_cube: true,
437        })
438    }
439
440    /// Rollup: grouping sets (prefixes of columns), then union (PySpark rollup).
441    pub fn rollup(&self, column_names: Vec<&str>) -> Result<CubeRollupData, PolarsError> {
442        let resolved: Vec<String> = column_names
443            .iter()
444            .map(|c| self.resolve_column_name(c))
445            .collect::<Result<Vec<_>, _>>()?;
446        Ok(CubeRollupData {
447            df: self.df.as_ref().clone(),
448            grouping_cols: resolved,
449            case_sensitive: self.case_sensitive,
450            is_cube: false,
451        })
452    }
453
454    /// Join with another DataFrame on the given columns.
455    /// Join column names are resolved on the left (and right must have matching names).
456    pub fn join(
457        &self,
458        other: &DataFrame,
459        on: Vec<&str>,
460        how: JoinType,
461    ) -> Result<DataFrame, PolarsError> {
462        let resolved: Vec<String> = on
463            .iter()
464            .map(|c| self.resolve_column_name(c))
465            .collect::<Result<Vec<_>, _>>()?;
466        let on_refs: Vec<&str> = resolved.iter().map(|s| s.as_str()).collect();
467        join(self, other, on_refs, how, self.case_sensitive)
468    }
469
470    /// Order by columns (sort).
471    /// Column names are resolved according to case sensitivity.
472    pub fn order_by(
473        &self,
474        column_names: Vec<&str>,
475        ascending: Vec<bool>,
476    ) -> Result<DataFrame, PolarsError> {
477        let resolved: Vec<String> = column_names
478            .iter()
479            .map(|c| self.resolve_column_name(c))
480            .collect::<Result<Vec<_>, _>>()?;
481        let refs: Vec<&str> = resolved.iter().map(|s| s.as_str()).collect();
482        transformations::order_by(self, refs, ascending, self.case_sensitive)
483    }
484
485    /// Order by sort expressions (asc/desc with nulls_first/last).
486    pub fn order_by_exprs(&self, sort_orders: Vec<SortOrder>) -> Result<DataFrame, PolarsError> {
487        transformations::order_by_exprs(self, sort_orders, self.case_sensitive)
488    }
489
490    /// Union (unionAll): stack another DataFrame vertically. Schemas must match (same columns, same order).
491    pub fn union(&self, other: &DataFrame) -> Result<DataFrame, PolarsError> {
492        transformations::union(self, other, self.case_sensitive)
493    }
494
495    /// Union by name: stack vertically, aligning columns by name.
496    pub fn union_by_name(&self, other: &DataFrame) -> Result<DataFrame, PolarsError> {
497        transformations::union_by_name(self, other, self.case_sensitive)
498    }
499
500    /// Distinct: drop duplicate rows (all columns or optional subset).
501    pub fn distinct(&self, subset: Option<Vec<&str>>) -> Result<DataFrame, PolarsError> {
502        transformations::distinct(self, subset, self.case_sensitive)
503    }
504
505    /// Drop one or more columns.
506    pub fn drop(&self, columns: Vec<&str>) -> Result<DataFrame, PolarsError> {
507        transformations::drop(self, columns, self.case_sensitive)
508    }
509
510    /// Drop rows with nulls (all columns or optional subset).
511    pub fn dropna(&self, subset: Option<Vec<&str>>) -> Result<DataFrame, PolarsError> {
512        transformations::dropna(self, subset, self.case_sensitive)
513    }
514
515    /// Fill nulls with a literal expression (applied to all columns).
516    pub fn fillna(&self, value: Expr) -> Result<DataFrame, PolarsError> {
517        transformations::fillna(self, value, self.case_sensitive)
518    }
519
520    /// Limit: return first n rows.
521    pub fn limit(&self, n: usize) -> Result<DataFrame, PolarsError> {
522        transformations::limit(self, n, self.case_sensitive)
523    }
524
525    /// Rename a column (old_name -> new_name).
526    pub fn with_column_renamed(
527        &self,
528        old_name: &str,
529        new_name: &str,
530    ) -> Result<DataFrame, PolarsError> {
531        transformations::with_column_renamed(self, old_name, new_name, self.case_sensitive)
532    }
533
534    /// Replace values in a column (old_value -> new_value). PySpark replace.
535    pub fn replace(
536        &self,
537        column_name: &str,
538        old_value: Expr,
539        new_value: Expr,
540    ) -> Result<DataFrame, PolarsError> {
541        transformations::replace(self, column_name, old_value, new_value, self.case_sensitive)
542    }
543
544    /// Cross join with another DataFrame (cartesian product). PySpark crossJoin.
545    pub fn cross_join(&self, other: &DataFrame) -> Result<DataFrame, PolarsError> {
546        transformations::cross_join(self, other, self.case_sensitive)
547    }
548
549    /// Summary statistics. PySpark describe.
550    pub fn describe(&self) -> Result<DataFrame, PolarsError> {
551        transformations::describe(self, self.case_sensitive)
552    }
553
554    /// No-op: execution is eager by default. PySpark cache.
555    pub fn cache(&self) -> Result<DataFrame, PolarsError> {
556        Ok(self.clone())
557    }
558
559    /// No-op: execution is eager by default. PySpark persist.
560    pub fn persist(&self) -> Result<DataFrame, PolarsError> {
561        Ok(self.clone())
562    }
563
564    /// No-op. PySpark unpersist.
565    pub fn unpersist(&self) -> Result<DataFrame, PolarsError> {
566        Ok(self.clone())
567    }
568
569    /// Set difference: rows in self not in other. PySpark subtract / except.
570    pub fn subtract(&self, other: &DataFrame) -> Result<DataFrame, PolarsError> {
571        transformations::subtract(self, other, self.case_sensitive)
572    }
573
574    /// Set intersection: rows in both self and other. PySpark intersect.
575    pub fn intersect(&self, other: &DataFrame) -> Result<DataFrame, PolarsError> {
576        transformations::intersect(self, other, self.case_sensitive)
577    }
578
579    /// Sample a fraction of rows. PySpark sample(withReplacement, fraction, seed).
580    pub fn sample(
581        &self,
582        with_replacement: bool,
583        fraction: f64,
584        seed: Option<u64>,
585    ) -> Result<DataFrame, PolarsError> {
586        transformations::sample(self, with_replacement, fraction, seed, self.case_sensitive)
587    }
588
589    /// Split into multiple DataFrames by weights. PySpark randomSplit(weights, seed).
590    pub fn random_split(
591        &self,
592        weights: &[f64],
593        seed: Option<u64>,
594    ) -> Result<Vec<DataFrame>, PolarsError> {
595        transformations::random_split(self, weights, seed, self.case_sensitive)
596    }
597
598    /// Stratified sample by column value. PySpark sampleBy(col, fractions, seed).
599    /// fractions: list of (value as Expr, fraction) for that stratum.
600    pub fn sample_by(
601        &self,
602        col_name: &str,
603        fractions: &[(Expr, f64)],
604        seed: Option<u64>,
605    ) -> Result<DataFrame, PolarsError> {
606        transformations::sample_by(self, col_name, fractions, seed, self.case_sensitive)
607    }
608
609    /// First row as a one-row DataFrame. PySpark first().
610    pub fn first(&self) -> Result<DataFrame, PolarsError> {
611        transformations::first(self, self.case_sensitive)
612    }
613
614    /// First n rows. PySpark head(n).
615    pub fn head(&self, n: usize) -> Result<DataFrame, PolarsError> {
616        transformations::head(self, n, self.case_sensitive)
617    }
618
619    /// Take first n rows. PySpark take(n).
620    pub fn take(&self, n: usize) -> Result<DataFrame, PolarsError> {
621        transformations::take(self, n, self.case_sensitive)
622    }
623
624    /// Last n rows. PySpark tail(n).
625    pub fn tail(&self, n: usize) -> Result<DataFrame, PolarsError> {
626        transformations::tail(self, n, self.case_sensitive)
627    }
628
629    /// True if the DataFrame has zero rows. PySpark isEmpty.
630    pub fn is_empty(&self) -> bool {
631        transformations::is_empty(self)
632    }
633
634    /// Rename columns. PySpark toDF(*colNames).
635    pub fn to_df(&self, names: Vec<&str>) -> Result<DataFrame, PolarsError> {
636        transformations::to_df(self, &names, self.case_sensitive)
637    }
638
639    /// Statistical helper. PySpark df.stat().cov / .corr.
640    pub fn stat(&self) -> DataFrameStat<'_> {
641        DataFrameStat { df: self }
642    }
643
644    /// Correlation matrix of all numeric columns. PySpark df.corr() returns a DataFrame of pairwise correlations.
645    pub fn corr(&self) -> Result<DataFrame, PolarsError> {
646        self.stat().corr_matrix()
647    }
648
649    /// Pearson correlation between two columns (scalar). PySpark df.corr(col1, col2).
650    pub fn corr_cols(&self, col1: &str, col2: &str) -> Result<f64, PolarsError> {
651        self.stat().corr(col1, col2)
652    }
653
654    /// Sample covariance between two columns (scalar). PySpark df.cov(col1, col2).
655    pub fn cov_cols(&self, col1: &str, col2: &str) -> Result<f64, PolarsError> {
656        self.stat().cov(col1, col2)
657    }
658
659    /// Summary statistics (alias for describe). PySpark summary.
660    pub fn summary(&self) -> Result<DataFrame, PolarsError> {
661        self.describe()
662    }
663
664    /// Collect rows as JSON strings (one per row). PySpark toJSON.
665    pub fn to_json(&self) -> Result<Vec<String>, PolarsError> {
666        transformations::to_json(self)
667    }
668
669    /// Return execution plan description. PySpark explain.
670    pub fn explain(&self) -> String {
671        transformations::explain(self)
672    }
673
674    /// Return schema as tree string. PySpark printSchema (returns string; print to stdout if needed).
675    pub fn print_schema(&self) -> Result<String, PolarsError> {
676        transformations::print_schema(self)
677    }
678
679    /// No-op: Polars backend is eager. PySpark checkpoint.
680    pub fn checkpoint(&self) -> Result<DataFrame, PolarsError> {
681        Ok(self.clone())
682    }
683
684    /// No-op: Polars backend is eager. PySpark localCheckpoint.
685    pub fn local_checkpoint(&self) -> Result<DataFrame, PolarsError> {
686        Ok(self.clone())
687    }
688
689    /// No-op: single partition in Polars. PySpark repartition(n).
690    pub fn repartition(&self, _num_partitions: usize) -> Result<DataFrame, PolarsError> {
691        Ok(self.clone())
692    }
693
694    /// No-op: Polars has no range partitioning. PySpark repartitionByRange(n, cols).
695    pub fn repartition_by_range(
696        &self,
697        _num_partitions: usize,
698        _cols: Vec<&str>,
699    ) -> Result<DataFrame, PolarsError> {
700        Ok(self.clone())
701    }
702
703    /// Column names and dtype strings. PySpark dtypes. Returns (name, dtype_string) per column.
704    pub fn dtypes(&self) -> Result<Vec<(String, String)>, PolarsError> {
705        let schema = self.df.schema();
706        Ok(schema
707            .iter_names_and_dtypes()
708            .map(|(name, dtype)| (name.to_string(), format!("{dtype:?}")))
709            .collect())
710    }
711
712    /// No-op: we don't model partitions. PySpark sortWithinPartitions. Same as orderBy for compatibility.
713    pub fn sort_within_partitions(
714        &self,
715        _cols: &[crate::functions::SortOrder],
716    ) -> Result<DataFrame, PolarsError> {
717        Ok(self.clone())
718    }
719
720    /// No-op: single partition in Polars. PySpark coalesce(n).
721    pub fn coalesce(&self, _num_partitions: usize) -> Result<DataFrame, PolarsError> {
722        Ok(self.clone())
723    }
724
725    /// No-op. PySpark hint (query planner hint).
726    pub fn hint(&self, _name: &str, _params: &[i32]) -> Result<DataFrame, PolarsError> {
727        Ok(self.clone())
728    }
729
730    /// Returns true (eager single-node). PySpark isLocal.
731    pub fn is_local(&self) -> bool {
732        true
733    }
734
735    /// Returns empty vec (no file sources). PySpark inputFiles.
736    pub fn input_files(&self) -> Vec<String> {
737        Vec::new()
738    }
739
740    /// No-op; returns false. PySpark sameSemantics.
741    pub fn same_semantics(&self, _other: &DataFrame) -> bool {
742        false
743    }
744
745    /// No-op; returns 0. PySpark semanticHash.
746    pub fn semantic_hash(&self) -> u64 {
747        0
748    }
749
750    /// No-op. PySpark observe (metrics).
751    pub fn observe(&self, _name: &str, _expr: Expr) -> Result<DataFrame, PolarsError> {
752        Ok(self.clone())
753    }
754
755    /// No-op. PySpark withWatermark (streaming).
756    pub fn with_watermark(
757        &self,
758        _event_time: &str,
759        _delay: &str,
760    ) -> Result<DataFrame, PolarsError> {
761        Ok(self.clone())
762    }
763
764    /// Select by expression strings (minimal: column names, optionally "col as alias"). PySpark selectExpr.
765    pub fn select_expr(&self, exprs: &[String]) -> Result<DataFrame, PolarsError> {
766        transformations::select_expr(self, exprs, self.case_sensitive)
767    }
768
769    /// Select columns whose names match the regex. PySpark colRegex.
770    pub fn col_regex(&self, pattern: &str) -> Result<DataFrame, PolarsError> {
771        transformations::col_regex(self, pattern, self.case_sensitive)
772    }
773
774    /// Add or replace multiple columns. PySpark withColumns. Accepts `Column` so rand/randn get per-row values.
775    pub fn with_columns(&self, exprs: &[(String, Column)]) -> Result<DataFrame, PolarsError> {
776        transformations::with_columns(self, exprs, self.case_sensitive)
777    }
778
779    /// Rename multiple columns. PySpark withColumnsRenamed.
780    pub fn with_columns_renamed(
781        &self,
782        renames: &[(String, String)],
783    ) -> Result<DataFrame, PolarsError> {
784        transformations::with_columns_renamed(self, renames, self.case_sensitive)
785    }
786
787    /// NA sub-API. PySpark df.na().
788    pub fn na(&self) -> DataFrameNa<'_> {
789        DataFrameNa { df: self }
790    }
791
792    /// Skip first n rows. PySpark offset(n).
793    pub fn offset(&self, n: usize) -> Result<DataFrame, PolarsError> {
794        transformations::offset(self, n, self.case_sensitive)
795    }
796
797    /// Transform by a function. PySpark transform(func).
798    pub fn transform<F>(&self, f: F) -> Result<DataFrame, PolarsError>
799    where
800        F: FnOnce(DataFrame) -> Result<DataFrame, PolarsError>,
801    {
802        transformations::transform(self, f)
803    }
804
805    /// Frequent items. PySpark freqItems (stub).
806    pub fn freq_items(&self, columns: &[&str], support: f64) -> Result<DataFrame, PolarsError> {
807        transformations::freq_items(self, columns, support, self.case_sensitive)
808    }
809
810    /// Approximate quantiles. PySpark approxQuantile (stub).
811    pub fn approx_quantile(
812        &self,
813        column: &str,
814        probabilities: &[f64],
815    ) -> Result<DataFrame, PolarsError> {
816        transformations::approx_quantile(self, column, probabilities, self.case_sensitive)
817    }
818
819    /// Cross-tabulation. PySpark crosstab (stub).
820    pub fn crosstab(&self, col1: &str, col2: &str) -> Result<DataFrame, PolarsError> {
821        transformations::crosstab(self, col1, col2, self.case_sensitive)
822    }
823
824    /// Unpivot (melt). PySpark melt (stub).
825    pub fn melt(&self, id_vars: &[&str], value_vars: &[&str]) -> Result<DataFrame, PolarsError> {
826        transformations::melt(self, id_vars, value_vars, self.case_sensitive)
827    }
828
829    /// Pivot (wide format). PySpark pivot. Stub: not yet implemented; use crosstab for two-column count.
830    pub fn pivot(
831        &self,
832        _pivot_col: &str,
833        _values: Option<Vec<&str>>,
834    ) -> Result<DataFrame, PolarsError> {
835        Err(PolarsError::InvalidOperation(
836            "pivot is not yet implemented; use crosstab(col1, col2) for two-column cross-tabulation."
837                .into(),
838        ))
839    }
840
841    /// Set difference keeping duplicates. PySpark exceptAll.
842    pub fn except_all(&self, other: &DataFrame) -> Result<DataFrame, PolarsError> {
843        transformations::except_all(self, other, self.case_sensitive)
844    }
845
846    /// Set intersection keeping duplicates. PySpark intersectAll.
847    pub fn intersect_all(&self, other: &DataFrame) -> Result<DataFrame, PolarsError> {
848        transformations::intersect_all(self, other, self.case_sensitive)
849    }
850
851    /// Write this DataFrame to a Delta table at the given path.
852    /// Requires the `delta` feature. If `overwrite` is true, replaces the table; otherwise appends.
853    #[cfg(feature = "delta")]
854    pub fn write_delta(
855        &self,
856        path: impl AsRef<std::path::Path>,
857        overwrite: bool,
858    ) -> Result<(), PolarsError> {
859        crate::delta::write_delta(self.df.as_ref(), path, overwrite)
860    }
861
862    /// Stub when `delta` feature is disabled.
863    #[cfg(not(feature = "delta"))]
864    pub fn write_delta(
865        &self,
866        _path: impl AsRef<std::path::Path>,
867        _overwrite: bool,
868    ) -> Result<(), PolarsError> {
869        Err(PolarsError::InvalidOperation(
870            "Delta Lake requires the 'delta' feature. Build with --features delta.".into(),
871        ))
872    }
873
874    /// Register this DataFrame as an in-memory "delta table" by name (same namespace as saveAsTable). Readable via `read_delta(name)` or `table(name)`.
875    pub fn save_as_delta_table(&self, session: &crate::session::SparkSession, name: &str) {
876        session.register_table(name, self.clone());
877    }
878
879    /// Return a writer for generic format (parquet, csv, json). PySpark-style write API.
880    pub fn write(&self) -> DataFrameWriter<'_> {
881        DataFrameWriter {
882            df: self,
883            mode: WriteMode::Overwrite,
884            format: WriteFormat::Parquet,
885            options: HashMap::new(),
886            partition_by: Vec::new(),
887        }
888    }
889}
890
/// Write mode for path-based saves (PySpark `DataFrameWriter.mode`): overwrite or append.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WriteMode {
    /// Replace whatever is currently at the target path.
    Overwrite,
    /// Add the new rows to the data already at the target path.
    Append,
}
897
/// Save mode for `saveAsTable` (PySpark's default is `ErrorIfExists`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SaveMode {
    /// Fail if the table already exists (PySpark default).
    ErrorIfExists,
    /// Replace the existing table.
    Overwrite,
    /// Append to the existing table, creating it if missing. Columns are aligned by name.
    Append,
    /// Do nothing if the table exists; create it otherwise.
    Ignore,
}
910
/// Output format for generic writes (PySpark `DataFrameWriter.format`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WriteFormat {
    /// Apache Parquet.
    Parquet,
    /// Delimited text; honours the writer's `header` and `sep` options.
    Csv,
    /// JSON, written with Polars' `JsonWriter` and read back as JSON lines on append.
    Json,
}
918
919/// Builder for writing DataFrame to path (PySpark DataFrameWriter).
920pub struct DataFrameWriter<'a> {
921    df: &'a DataFrame,
922    mode: WriteMode,
923    format: WriteFormat,
924    options: HashMap<String, String>,
925    partition_by: Vec<String>,
926}
927
928impl<'a> DataFrameWriter<'a> {
929    pub fn mode(mut self, mode: WriteMode) -> Self {
930        self.mode = mode;
931        self
932    }
933
934    pub fn format(mut self, format: WriteFormat) -> Self {
935        self.format = format;
936        self
937    }
938
939    /// Add a single option (PySpark: option(key, value)). Returns self for chaining.
940    pub fn option(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
941        self.options.insert(key.into(), value.into());
942        self
943    }
944
945    /// Add multiple options (PySpark: options(**kwargs)). Returns self for chaining.
946    pub fn options(mut self, opts: impl IntoIterator<Item = (String, String)>) -> Self {
947        for (k, v) in opts {
948            self.options.insert(k, v);
949        }
950        self
951    }
952
953    /// Partition output by the given columns (PySpark: partitionBy(cols)).
954    pub fn partition_by(mut self, cols: impl IntoIterator<Item = impl Into<String>>) -> Self {
955        self.partition_by = cols.into_iter().map(|s| s.into()).collect();
956        self
957    }
958
959    /// Save the DataFrame as a table (PySpark: saveAsTable). In-memory by default; when spark.sql.warehouse.dir is set, persists to disk for cross-session access.
960    pub fn save_as_table(
961        &self,
962        session: &SparkSession,
963        name: &str,
964        mode: SaveMode,
965    ) -> Result<(), PolarsError> {
966        use polars::prelude::*;
967        use std::fs;
968        use std::path::Path;
969
970        let warehouse_path = session.warehouse_dir().map(|w| Path::new(w).join(name));
971        let warehouse_exists = warehouse_path.as_ref().is_some_and(|p| p.is_dir());
972
973        fn persist_to_warehouse(
974            df: &crate::dataframe::DataFrame,
975            dir: &Path,
976        ) -> Result<(), PolarsError> {
977            use std::fs;
978            fs::create_dir_all(dir).map_err(|e| {
979                PolarsError::ComputeError(format!("saveAsTable: create dir: {e}").into())
980            })?;
981            let file_path = dir.join("data.parquet");
982            df.write()
983                .mode(crate::dataframe::WriteMode::Overwrite)
984                .format(crate::dataframe::WriteFormat::Parquet)
985                .save(&file_path)
986        }
987
988        let final_df = match mode {
989            SaveMode::ErrorIfExists => {
990                if session.saved_table_exists(name) || warehouse_exists {
991                    return Err(PolarsError::InvalidOperation(
992                        format!(
993                            "Table or view '{name}' already exists. SaveMode is ErrorIfExists."
994                        )
995                        .into(),
996                    ));
997                }
998                if let Some(ref p) = warehouse_path {
999                    persist_to_warehouse(self.df, p)?;
1000                }
1001                self.df.clone()
1002            }
1003            SaveMode::Overwrite => {
1004                if let Some(ref p) = warehouse_path {
1005                    let _ = fs::remove_dir_all(p);
1006                    persist_to_warehouse(self.df, p)?;
1007                }
1008                self.df.clone()
1009            }
1010            SaveMode::Append => {
1011                let existing_pl = if let Some(existing) = session.get_saved_table(name) {
1012                    existing.df.as_ref().clone()
1013                } else if let (Some(ref p), true) = (warehouse_path.as_ref(), warehouse_exists) {
1014                    // Read from warehouse (data.parquet convention)
1015                    let data_file = p.join("data.parquet");
1016                    let read_path = if data_file.is_file() {
1017                        data_file.as_path()
1018                    } else {
1019                        p.as_ref()
1020                    };
1021                    let lf = LazyFrame::scan_parquet(read_path, ScanArgsParquet::default())
1022                        .map_err(|e| {
1023                            PolarsError::ComputeError(
1024                                format!("saveAsTable append: read warehouse: {e}").into(),
1025                            )
1026                        })?;
1027                    lf.collect().map_err(|e| {
1028                        PolarsError::ComputeError(
1029                            format!("saveAsTable append: collect: {e}").into(),
1030                        )
1031                    })?
1032                } else {
1033                    // New table
1034                    session.register_table(name, self.df.clone());
1035                    if let Some(ref p) = warehouse_path {
1036                        persist_to_warehouse(self.df, p)?;
1037                    }
1038                    return Ok(());
1039                };
1040                let new_pl = self.df.df.as_ref().clone();
1041                let existing_cols: Vec<&str> = existing_pl
1042                    .get_column_names()
1043                    .iter()
1044                    .map(|s| s.as_str())
1045                    .collect();
1046                let new_cols = new_pl.get_column_names();
1047                let missing: Vec<_> = existing_cols
1048                    .iter()
1049                    .filter(|c| !new_cols.iter().any(|n| n.as_str() == **c))
1050                    .collect();
1051                if !missing.is_empty() {
1052                    return Err(PolarsError::InvalidOperation(
1053                        format!(
1054                            "saveAsTable append: new DataFrame missing columns: {:?}",
1055                            missing
1056                        )
1057                        .into(),
1058                    ));
1059                }
1060                let new_ordered = new_pl.select(existing_cols.iter().copied())?;
1061                let mut combined = existing_pl;
1062                combined.vstack_mut(&new_ordered)?;
1063                let merged = crate::dataframe::DataFrame::from_polars_with_options(
1064                    combined,
1065                    self.df.case_sensitive,
1066                );
1067                if let Some(ref p) = warehouse_path {
1068                    let _ = fs::remove_dir_all(p);
1069                    persist_to_warehouse(&merged, p)?;
1070                }
1071                merged
1072            }
1073            SaveMode::Ignore => {
1074                if session.saved_table_exists(name) || warehouse_exists {
1075                    return Ok(());
1076                }
1077                if let Some(ref p) = warehouse_path {
1078                    persist_to_warehouse(self.df, p)?;
1079                }
1080                self.df.clone()
1081            }
1082        };
1083        session.register_table(name, final_df);
1084        Ok(())
1085    }
1086
1087    /// Write as Parquet (PySpark: parquet(path)). Equivalent to format(Parquet).save(path).
1088    pub fn parquet(&self, path: impl AsRef<std::path::Path>) -> Result<(), PolarsError> {
1089        DataFrameWriter {
1090            df: self.df,
1091            mode: self.mode,
1092            format: WriteFormat::Parquet,
1093            options: self.options.clone(),
1094            partition_by: self.partition_by.clone(),
1095        }
1096        .save(path)
1097    }
1098
1099    /// Write as CSV (PySpark: csv(path)). Equivalent to format(Csv).save(path).
1100    pub fn csv(&self, path: impl AsRef<std::path::Path>) -> Result<(), PolarsError> {
1101        DataFrameWriter {
1102            df: self.df,
1103            mode: self.mode,
1104            format: WriteFormat::Csv,
1105            options: self.options.clone(),
1106            partition_by: self.partition_by.clone(),
1107        }
1108        .save(path)
1109    }
1110
1111    /// Write as JSON lines (PySpark: json(path)). Equivalent to format(Json).save(path).
1112    pub fn json(&self, path: impl AsRef<std::path::Path>) -> Result<(), PolarsError> {
1113        DataFrameWriter {
1114            df: self.df,
1115            mode: self.mode,
1116            format: WriteFormat::Json,
1117            options: self.options.clone(),
1118            partition_by: self.partition_by.clone(),
1119        }
1120        .save(path)
1121    }
1122
1123    /// Write to path. Overwrite replaces; append reads existing (if any) and concatenates then writes.
1124    /// With partition_by, path is a directory; each partition is written as path/col1=val1/col2=val2/... with partition columns omitted from the file (Spark/Hive style).
1125    pub fn save(&self, path: impl AsRef<std::path::Path>) -> Result<(), PolarsError> {
1126        use polars::prelude::*;
1127        let path = path.as_ref();
1128        let to_write: PlDataFrame = match self.mode {
1129            WriteMode::Overwrite => self.df.df.as_ref().clone(),
1130            WriteMode::Append => {
1131                if self.partition_by.is_empty() {
1132                    let existing: Option<PlDataFrame> = if path.exists() && path.is_file() {
1133                        match self.format {
1134                            WriteFormat::Parquet => {
1135                                LazyFrame::scan_parquet(path, ScanArgsParquet::default())
1136                                    .and_then(|lf| lf.collect())
1137                                    .ok()
1138                            }
1139                            WriteFormat::Csv => LazyCsvReader::new(path)
1140                                .with_has_header(true)
1141                                .finish()
1142                                .and_then(|lf| lf.collect())
1143                                .ok(),
1144                            WriteFormat::Json => LazyJsonLineReader::new(path)
1145                                .finish()
1146                                .and_then(|lf| lf.collect())
1147                                .ok(),
1148                        }
1149                    } else {
1150                        None
1151                    };
1152                    match existing {
1153                        Some(existing) => {
1154                            let lfs: [LazyFrame; 2] =
1155                                [existing.lazy(), self.df.df.as_ref().clone().lazy()];
1156                            concat(lfs, UnionArgs::default())?.collect()?
1157                        }
1158                        None => self.df.df.as_ref().clone(),
1159                    }
1160                } else {
1161                    self.df.df.as_ref().clone()
1162                }
1163            }
1164        };
1165
1166        if !self.partition_by.is_empty() {
1167            return self.save_partitioned(path, &to_write);
1168        }
1169
1170        match self.format {
1171            WriteFormat::Parquet => {
1172                let mut file = std::fs::File::create(path).map_err(|e| {
1173                    PolarsError::ComputeError(format!("write parquet create: {e}").into())
1174                })?;
1175                let mut df_mut = to_write;
1176                ParquetWriter::new(&mut file)
1177                    .finish(&mut df_mut)
1178                    .map_err(|e| PolarsError::ComputeError(format!("write parquet: {e}").into()))?;
1179            }
1180            WriteFormat::Csv => {
1181                let has_header = self
1182                    .options
1183                    .get("header")
1184                    .map(|v| v.eq_ignore_ascii_case("true") || v == "1")
1185                    .unwrap_or(true);
1186                let delimiter = self
1187                    .options
1188                    .get("sep")
1189                    .and_then(|s| s.bytes().next())
1190                    .unwrap_or(b',');
1191                let mut file = std::fs::File::create(path).map_err(|e| {
1192                    PolarsError::ComputeError(format!("write csv create: {e}").into())
1193                })?;
1194                CsvWriter::new(&mut file)
1195                    .include_header(has_header)
1196                    .with_separator(delimiter)
1197                    .finish(&mut to_write.clone())
1198                    .map_err(|e| PolarsError::ComputeError(format!("write csv: {e}").into()))?;
1199            }
1200            WriteFormat::Json => {
1201                let mut file = std::fs::File::create(path).map_err(|e| {
1202                    PolarsError::ComputeError(format!("write json create: {e}").into())
1203                })?;
1204                JsonWriter::new(&mut file)
1205                    .finish(&mut to_write.clone())
1206                    .map_err(|e| PolarsError::ComputeError(format!("write json: {e}").into()))?;
1207            }
1208        }
1209        Ok(())
1210    }
1211
1212    /// Write partitioned by columns: path/col1=val1/col2=val2/part-00000.{ext}. Partition columns are not written into the file (Spark/Hive style).
1213    fn save_partitioned(&self, path: &Path, to_write: &PlDataFrame) -> Result<(), PolarsError> {
1214        use polars::prelude::*;
1215        let resolved: Vec<String> = self
1216            .partition_by
1217            .iter()
1218            .map(|c| self.df.resolve_column_name(c))
1219            .collect::<Result<Vec<_>, _>>()?;
1220        let all_names = to_write.get_column_names();
1221        let data_cols: Vec<&str> = all_names
1222            .iter()
1223            .filter(|n| !resolved.iter().any(|r| r == n.as_str()))
1224            .map(|n| n.as_str())
1225            .collect();
1226
1227        let unique_keys = to_write
1228            .select(resolved.iter().map(|s| s.as_str()).collect::<Vec<_>>())?
1229            .unique::<Option<&[String]>, String>(
1230                None,
1231                polars::prelude::UniqueKeepStrategy::First,
1232                None,
1233            )?;
1234
1235        if self.mode == WriteMode::Overwrite && path.exists() {
1236            if path.is_dir() {
1237                std::fs::remove_dir_all(path).map_err(|e| {
1238                    PolarsError::ComputeError(
1239                        format!("write partitioned: remove_dir_all: {e}").into(),
1240                    )
1241                })?;
1242            } else {
1243                std::fs::remove_file(path).map_err(|e| {
1244                    PolarsError::ComputeError(format!("write partitioned: remove_file: {e}").into())
1245                })?;
1246            }
1247        }
1248        std::fs::create_dir_all(path).map_err(|e| {
1249            PolarsError::ComputeError(format!("write partitioned: create_dir_all: {e}").into())
1250        })?;
1251
1252        let ext = match self.format {
1253            WriteFormat::Parquet => "parquet",
1254            WriteFormat::Csv => "csv",
1255            WriteFormat::Json => "json",
1256        };
1257
1258        for row_idx in 0..unique_keys.height() {
1259            let row = unique_keys
1260                .get(row_idx)
1261                .ok_or_else(|| PolarsError::ComputeError("partition_row: get row".into()))?;
1262            let filter_expr = partition_row_to_filter_expr(&resolved, &row)?;
1263            let subset = to_write.clone().lazy().filter(filter_expr).collect()?;
1264            let subset = subset.select(data_cols.iter().copied())?;
1265            if subset.height() == 0 {
1266                continue;
1267            }
1268
1269            let part_path: std::path::PathBuf = resolved
1270                .iter()
1271                .zip(row.iter())
1272                .map(|(name, av)| format!("{}={}", name, format_partition_value(av)))
1273                .fold(path.to_path_buf(), |p, seg| p.join(seg));
1274            std::fs::create_dir_all(&part_path).map_err(|e| {
1275                PolarsError::ComputeError(
1276                    format!("write partitioned: create_dir_all partition: {e}").into(),
1277                )
1278            })?;
1279
1280            let file_idx = if self.mode == WriteMode::Append {
1281                let suffix = format!(".{ext}");
1282                let max_n = std::fs::read_dir(&part_path)
1283                    .map(|rd| {
1284                        rd.filter_map(Result::ok)
1285                            .filter_map(|e| {
1286                                e.file_name().to_str().and_then(|s| {
1287                                    s.strip_prefix("part-")
1288                                        .and_then(|t| t.strip_suffix(&suffix))
1289                                        .and_then(|t| t.parse::<u32>().ok())
1290                                })
1291                            })
1292                            .max()
1293                            .unwrap_or(0)
1294                    })
1295                    .unwrap_or(0);
1296                max_n + 1
1297            } else {
1298                0
1299            };
1300            let filename = format!("part-{file_idx:05}.{ext}");
1301            let file_path = part_path.join(&filename);
1302
1303            match self.format {
1304                WriteFormat::Parquet => {
1305                    let mut file = std::fs::File::create(&file_path).map_err(|e| {
1306                        PolarsError::ComputeError(
1307                            format!("write partitioned parquet create: {e}").into(),
1308                        )
1309                    })?;
1310                    let mut df_mut = subset;
1311                    ParquetWriter::new(&mut file)
1312                        .finish(&mut df_mut)
1313                        .map_err(|e| {
1314                            PolarsError::ComputeError(
1315                                format!("write partitioned parquet: {e}").into(),
1316                            )
1317                        })?;
1318                }
1319                WriteFormat::Csv => {
1320                    let has_header = self
1321                        .options
1322                        .get("header")
1323                        .map(|v| v.eq_ignore_ascii_case("true") || v == "1")
1324                        .unwrap_or(true);
1325                    let delimiter = self
1326                        .options
1327                        .get("sep")
1328                        .and_then(|s| s.bytes().next())
1329                        .unwrap_or(b',');
1330                    let mut file = std::fs::File::create(&file_path).map_err(|e| {
1331                        PolarsError::ComputeError(
1332                            format!("write partitioned csv create: {e}").into(),
1333                        )
1334                    })?;
1335                    CsvWriter::new(&mut file)
1336                        .include_header(has_header)
1337                        .with_separator(delimiter)
1338                        .finish(&mut subset.clone())
1339                        .map_err(|e| {
1340                            PolarsError::ComputeError(format!("write partitioned csv: {e}").into())
1341                        })?;
1342                }
1343                WriteFormat::Json => {
1344                    let mut file = std::fs::File::create(&file_path).map_err(|e| {
1345                        PolarsError::ComputeError(
1346                            format!("write partitioned json create: {e}").into(),
1347                        )
1348                    })?;
1349                    JsonWriter::new(&mut file)
1350                        .finish(&mut subset.clone())
1351                        .map_err(|e| {
1352                            PolarsError::ComputeError(format!("write partitioned json: {e}").into())
1353                        })?;
1354                }
1355            }
1356        }
1357        Ok(())
1358    }
1359}
1360
1361impl Clone for DataFrame {
1362    fn clone(&self) -> Self {
1363        DataFrame {
1364            df: self.df.clone(),
1365            case_sensitive: self.case_sensitive,
1366        }
1367    }
1368}
1369
1370/// Format a partition column value for use in a directory name (Spark/Hive style).
1371/// Null becomes "__HIVE_DEFAULT_PARTITION__"; other values use string representation with path-unsafe chars replaced.
1372fn format_partition_value(av: &AnyValue<'_>) -> String {
1373    let s = match av {
1374        AnyValue::Null => "__HIVE_DEFAULT_PARTITION__".to_string(),
1375        AnyValue::Boolean(b) => b.to_string(),
1376        AnyValue::Int32(i) => i.to_string(),
1377        AnyValue::Int64(i) => i.to_string(),
1378        AnyValue::UInt32(u) => u.to_string(),
1379        AnyValue::UInt64(u) => u.to_string(),
1380        AnyValue::Float32(f) => f.to_string(),
1381        AnyValue::Float64(f) => f.to_string(),
1382        AnyValue::String(s) => s.to_string(),
1383        AnyValue::StringOwned(s) => s.as_str().to_string(),
1384        AnyValue::Date(d) => d.to_string(),
1385        _ => av.to_string(),
1386    };
1387    // Replace path separators and other unsafe chars so the value is a valid path segment
1388    s.replace([std::path::MAIN_SEPARATOR, '/'], "_")
1389}
1390
1391/// Build a filter expression that matches rows where partition columns equal the given row values.
1392fn partition_row_to_filter_expr(
1393    col_names: &[String],
1394    row: &[AnyValue<'_>],
1395) -> Result<Expr, PolarsError> {
1396    if col_names.len() != row.len() {
1397        return Err(PolarsError::ComputeError(
1398            format!(
1399                "partition_row_to_filter_expr: {} columns but {} row values",
1400                col_names.len(),
1401                row.len()
1402            )
1403            .into(),
1404        ));
1405    }
1406    let mut pred = None::<Expr>;
1407    for (name, av) in col_names.iter().zip(row.iter()) {
1408        let clause = match av {
1409            AnyValue::Null => col(name.as_str()).is_null(),
1410            AnyValue::Boolean(b) => col(name.as_str()).eq(lit(*b)),
1411            AnyValue::Int32(i) => col(name.as_str()).eq(lit(*i)),
1412            AnyValue::Int64(i) => col(name.as_str()).eq(lit(*i)),
1413            AnyValue::UInt32(u) => col(name.as_str()).eq(lit(*u)),
1414            AnyValue::UInt64(u) => col(name.as_str()).eq(lit(*u)),
1415            AnyValue::Float32(f) => col(name.as_str()).eq(lit(*f)),
1416            AnyValue::Float64(f) => col(name.as_str()).eq(lit(*f)),
1417            AnyValue::String(s) => col(name.as_str()).eq(lit(s.to_string())),
1418            AnyValue::StringOwned(s) => col(name.as_str()).eq(lit(s.clone())),
1419            _ => {
1420                // Fallback: compare as string
1421                let s = av.to_string();
1422                col(name.as_str()).cast(DataType::String).eq(lit(s))
1423            }
1424        };
1425        pred = Some(match pred {
1426            None => clause,
1427            Some(p) => p.and(clause),
1428        });
1429    }
1430    Ok(pred.unwrap_or_else(|| lit(true)))
1431}
1432
1433/// Convert Polars AnyValue to serde_json::Value for language bindings (Node, etc.).
1434fn any_value_to_json(av: AnyValue<'_>) -> JsonValue {
1435    match av {
1436        AnyValue::Null => JsonValue::Null,
1437        AnyValue::Boolean(b) => JsonValue::Bool(b),
1438        AnyValue::Int32(i) => JsonValue::Number(serde_json::Number::from(i)),
1439        AnyValue::Int64(i) => JsonValue::Number(serde_json::Number::from(i)),
1440        AnyValue::UInt32(u) => JsonValue::Number(serde_json::Number::from(u)),
1441        AnyValue::UInt64(u) => JsonValue::Number(serde_json::Number::from(u)),
1442        AnyValue::Float32(f) => serde_json::Number::from_f64(f64::from(f))
1443            .map(JsonValue::Number)
1444            .unwrap_or(JsonValue::Null),
1445        AnyValue::Float64(f) => serde_json::Number::from_f64(f)
1446            .map(JsonValue::Number)
1447            .unwrap_or(JsonValue::Null),
1448        AnyValue::String(s) => JsonValue::String(s.to_string()),
1449        AnyValue::StringOwned(s) => JsonValue::String(s.to_string()),
1450        _ => JsonValue::Null,
1451    }
1452}
1453
#[cfg(test)]
mod tests {
    use super::*;
    use polars::prelude::{NamedFrom, Series};

    /// Issue #235: at the root of a filter predicate, a string column compared with an integer
    /// literal must be coerced, so exactly one of "123"/"456" matches 123.
    #[test]
    fn coerce_string_numeric_root_in_filter() {
        let str_col = Series::new("str_col".into(), &["123", "456"]);
        let polars_frame = polars::prelude::DataFrame::new(vec![str_col.into()]).unwrap();
        let frame = DataFrame::from_polars(polars_frame);
        let predicate = col("str_col").eq(lit(123i64));
        let filtered = frame.filter(predicate).unwrap();
        assert_eq!(filtered.count().unwrap(), 1);
    }
}