robin_sparkless/dataframe/transformations.rs

1//! DataFrame transformation operations: filter, select, with_column, order_by,
2//! union, distinct, drop, dropna, fillna, limit, with_column_renamed,
3//! replace, cross_join, describe, subtract, intersect,
4//! sample, random_split, first, head, take, tail, is_empty, to_df.
5
6use super::DataFrame;
7use crate::functions::SortOrder;
8use crate::type_coercion::find_common_type;
9use polars::prelude::{
10    DataType, Expr, IntoLazy, IntoSeries, NamedFrom, PlSmallStr, PolarsError, Selector, Series,
11    UnionArgs, UniqueKeepStrategy, col,
12};
13use std::collections::HashMap;
14use std::sync::Arc;
15
16/// Select columns (returns a new DataFrame). Preserves case_sensitive on result.
17pub fn select(
18    df: &DataFrame,
19    cols: Vec<&str>,
20    case_sensitive: bool,
21) -> Result<DataFrame, PolarsError> {
22    let resolved: Vec<String> = cols
23        .iter()
24        .map(|c| df.resolve_column_name(c))
25        .collect::<Result<Vec<_>, _>>()?;
26    let exprs: Vec<Expr> = resolved.iter().map(|s| col(s.as_str())).collect();
27    let lf = df.lazy_frame().select(&exprs);
28    Ok(super::DataFrame::from_lazy_with_options(lf, case_sensitive))
29}
30
31/// Select using column expressions (e.g. F.regexp_extract_all(...).alias("m")). Preserves case_sensitive.
32/// Column names in expressions are resolved per df's case sensitivity (PySpark parity).
33/// Duplicate output names are disambiguated with _1, _2, ... so select(col("num").cast("string"), col("num").cast("int")) works (issue #213).
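/// A hedged usage sketch (not compiled as a doctest; `df` is assumed to have an Int64 column "v",
/// like the fixture in this module's tests):
/// ```ignore
/// use polars::prelude::{col, DataType};
/// // Both expressions resolve to the output name "v"; the second one is renamed to "v_1".
/// let out = select_with_exprs(
///     &df,
///     vec![col("v").cast(DataType::String), col("v")],
///     false,
/// )?;
/// assert_eq!(out.columns()?, vec!["v".to_string(), "v_1".to_string()]);
/// ```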
34pub fn select_with_exprs(
35    df: &DataFrame,
36    exprs: Vec<Expr>,
37    case_sensitive: bool,
38) -> Result<DataFrame, PolarsError> {
39    let exprs: Vec<Expr> = exprs
40        .into_iter()
41        .map(|e| df.resolve_expr_column_names(e))
42        .collect::<Result<Vec<_>, _>>()?;
43    let mut name_count: HashMap<String, u32> = HashMap::new();
44    let exprs: Vec<Expr> = exprs
45        .into_iter()
46        .map(|e| {
47            let base_name = polars_plan::utils::expr_output_name(&e)
48                .map(|s| s.to_string())
49                .unwrap_or_else(|_| "_".to_string());
50            let count = name_count.entry(base_name.clone()).or_insert(0);
51            *count += 1;
52            let final_name = if *count == 1 {
53                base_name
54            } else {
55                format!("{}_{}", base_name, *count - 1)
56            };
57            if *count == 1 {
58                e
59            } else {
60                e.alias(final_name.as_str())
61            }
62        })
63        .collect();
64    let lf = df.lazy_frame().select(&exprs);
65    Ok(super::DataFrame::from_lazy_with_options(lf, case_sensitive))
66}
67
68/// Filter rows using a Polars expression. Preserves case_sensitive on result.
69/// Column names in the condition are resolved per df's case sensitivity (PySpark parity).
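/// A hedged usage sketch (not compiled as a doctest; `df` is assumed to have an Int64 column "v",
/// like the fixture in this module's tests):
/// ```ignore
/// use polars::prelude::{col, lit};
/// // Keep rows where v > 15; column names in the condition follow df's case sensitivity.
/// let out = filter(&df, col("v").gt(lit(15i64)), false)?;
/// ```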
70pub fn filter(
71    df: &DataFrame,
72    condition: Expr,
73    case_sensitive: bool,
74) -> Result<DataFrame, PolarsError> {
75    let condition = df.resolve_expr_column_names(condition)?;
76    let condition = df.coerce_string_numeric_comparisons(condition)?;
77    let lf = df.lazy_frame().filter(condition);
78    Ok(super::DataFrame::from_lazy_with_options(lf, case_sensitive))
79}
80
81/// Add or replace a column. Handles deferred rand/randn and Python UDF (UdfCall).
82pub fn with_column(
83    df: &DataFrame,
84    column_name: &str,
85    column: &crate::column::Column,
86    case_sensitive: bool,
87) -> Result<DataFrame, PolarsError> {
88    // Deferred rand/randn: collect eagerly so each row gets an independent random value
89    if let Some(deferred) = column.deferred {
90        match deferred {
91            crate::column::DeferredRandom::Rand(seed) => {
92                let pl_df = df.collect_inner()?;
93                let mut pl_df = pl_df.as_ref().clone();
94                let n = pl_df.height();
95                let series = crate::udfs::series_rand_n(column_name, n, seed);
96                pl_df.with_column(series.into())?;
97                return Ok(super::DataFrame::from_polars_with_options(
98                    pl_df,
99                    case_sensitive,
100                ));
101            }
102            crate::column::DeferredRandom::Randn(seed) => {
103                let pl_df = df.collect_inner()?;
104                let mut pl_df = pl_df.as_ref().clone();
105                let n = pl_df.height();
106                let series = crate::udfs::series_randn_n(column_name, n, seed);
107                pl_df.with_column(series.into())?;
108                return Ok(super::DataFrame::from_polars_with_options(
109                    pl_df,
110                    case_sensitive,
111                ));
112            }
113        }
114    }
115    let expr = df.resolve_expr_column_names(column.expr().clone())?;
116    let expr = df.coerce_string_numeric_comparisons(expr)?;
117    let lf = df.lazy_frame().with_column(expr.alias(column_name));
118    Ok(super::DataFrame::from_lazy_with_options(lf, case_sensitive))
119}
120
121/// Order by columns (sort). Preserves case_sensitive on result.
122pub fn order_by(
123    df: &DataFrame,
124    column_names: Vec<&str>,
125    ascending: Vec<bool>,
126    case_sensitive: bool,
127) -> Result<DataFrame, PolarsError> {
128    use polars::prelude::*;
129    let mut asc = ascending;
130    while asc.len() < column_names.len() {
131        asc.push(true);
132    }
133    asc.truncate(column_names.len());
134    let resolved: Vec<String> = column_names
135        .iter()
136        .map(|c| df.resolve_column_name(c))
137        .collect::<Result<Vec<_>, _>>()?;
138    let exprs: Vec<Expr> = resolved.iter().map(|s| col(s.as_str())).collect();
139    let descending: Vec<bool> = asc.iter().map(|&a| !a).collect();
140    // PySpark default: ASC nulls first (nulls_last=false), DESC nulls last (nulls_last=true).
141    let nulls_last: Vec<bool> = descending.clone();
142    let lf = df.lazy_frame().sort_by_exprs(
143        exprs,
144        SortMultipleOptions::new()
145            .with_order_descending_multi(descending)
146            .with_nulls_last_multi(nulls_last),
147    );
148    Ok(super::DataFrame::from_lazy_with_options(lf, case_sensitive))
149}
150
151/// Order by sort expressions (asc/desc with nulls_first/last). Preserves case_sensitive on result.
152/// Column names in sort expressions are resolved per df's case sensitivity (PySpark parity).
153pub fn order_by_exprs(
154    df: &DataFrame,
155    sort_orders: Vec<SortOrder>,
156    case_sensitive: bool,
157) -> Result<DataFrame, PolarsError> {
158    use polars::prelude::*;
159    if sort_orders.is_empty() {
160        return Ok(super::DataFrame::from_lazy_with_options(
161            df.lazy_frame(),
162            case_sensitive,
163        ));
164    }
165    let exprs: Vec<Expr> = sort_orders
166        .iter()
167        .map(|s| df.resolve_expr_column_names(s.expr().clone()))
168        .collect::<Result<Vec<_>, _>>()?;
169    let descending: Vec<bool> = sort_orders.iter().map(|s| s.descending).collect();
170    let nulls_last: Vec<bool> = sort_orders.iter().map(|s| s.nulls_last).collect();
171    let opts = SortMultipleOptions::new()
172        .with_order_descending_multi(descending)
173        .with_nulls_last_multi(nulls_last);
174    let lf = df.lazy_frame().sort_by_exprs(exprs, opts);
175    Ok(super::DataFrame::from_lazy_with_options(lf, case_sensitive))
176}
177
178/// Union (unionAll): stack another DataFrame vertically. Schemas must match (same columns, same order).
179/// When column types differ (e.g. String vs Int64), both sides are coerced to a common type (PySpark parity #551).
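/// A hedged usage sketch (not compiled as a doctest; `left` and `right` are assumed to share the
/// column names ["id", "name"], with "id" being Int64 on the left and String on the right):
/// ```ignore
/// // Both "id" columns are cast to a common type before the vertical concat.
/// let out = union(&left, &right, false)?;
/// assert_eq!(out.count()?, left.count()? + right.count()?);
/// ```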
180pub fn union(
181    left: &DataFrame,
182    right: &DataFrame,
183    case_sensitive: bool,
184) -> Result<DataFrame, PolarsError> {
185    let left_names = left.columns()?;
186    let right_names = right.columns()?;
187    if left_names != right_names {
188        return Err(PolarsError::InvalidOperation(
189            format!(
190                "union: column order/names must match. Left: {:?}, Right: {:?}",
191                left_names, right_names
192            )
193            .into(),
194        ));
195    }
196    let mut left_exprs: Vec<Expr> = Vec::with_capacity(left_names.len());
197    let mut right_exprs: Vec<Expr> = Vec::with_capacity(right_names.len());
198    for name in &left_names {
199        let resolved_left = left.resolve_column_name(name)?;
200        let resolved_right = right.resolve_column_name(name)?;
201        let left_dtype = left.get_column_dtype(name).unwrap_or(DataType::Null);
202        let right_dtype = right.get_column_dtype(name).unwrap_or(DataType::Null);
203        let target = if left_dtype == DataType::Null {
204            right_dtype.clone()
205        } else if right_dtype == DataType::Null || left_dtype == right_dtype {
206            left_dtype.clone()
207        } else {
208            find_common_type(&left_dtype, &right_dtype)?
209        };
210        let left_expr = if left_dtype == target {
211            col(resolved_left.as_str())
212        } else {
213            col(resolved_left.as_str()).cast(target.clone())
214        };
215        let right_expr = if right_dtype == target {
216            col(resolved_right.as_str())
217        } else {
218            col(resolved_right.as_str()).cast(target)
219        };
220        left_exprs.push(left_expr.alias(name.as_str()));
221        right_exprs.push(right_expr.alias(name.as_str()));
222    }
223    let lf1 = left.lazy_frame().select(&left_exprs);
224    let lf2 = right.lazy_frame().select(&right_exprs);
225    let out = polars::prelude::concat([lf1, lf2], UnionArgs::default())?;
226    Ok(super::DataFrame::from_lazy_with_options(
227        out,
228        case_sensitive,
229    ))
230}
231
232/// Union by name: stack vertically, aligning columns by name.
233/// When allow_missing_columns is true: result has all columns from both sides (missing filled with null).
234/// When false: result has only left columns; right must have all left columns.
235/// When same-named columns have different types (e.g. String vs Int64), coerces to a common type (PySpark parity #603).
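/// A hedged usage sketch (not compiled as a doctest; mirrors the #603 regression test at the
/// bottom of this file, where `left` has an Int64 "id" and `right` a String "id"):
/// ```ignore
/// // Missing columns are null-filled and "id" is coerced to a common type.
/// let out = union_by_name(&left, &right, true, false)?;
/// assert_eq!(out.count()?, 2);
/// ```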
236pub fn union_by_name(
237    left: &DataFrame,
238    right: &DataFrame,
239    allow_missing_columns: bool,
240    case_sensitive: bool,
241) -> Result<DataFrame, PolarsError> {
242    use crate::type_coercion::find_common_type;
243    use polars::prelude::*;
244
245    let left_names = left.columns()?;
246    let right_names = right.columns()?;
247    let contains = |names: &[String], name: &str| -> bool {
248        if case_sensitive {
249            names.iter().any(|n| n.as_str() == name)
250        } else {
251            let name_lower = name.to_lowercase();
252            names
253                .iter()
254                .any(|n| n.as_str().to_lowercase() == name_lower)
255        }
256    };
257    let resolve = |names: &[String], name: &str| -> Option<String> {
258        if case_sensitive {
259            names.iter().find(|n| n.as_str() == name).cloned()
260        } else {
261            let name_lower = name.to_lowercase();
262            names
263                .iter()
264                .find(|n| n.as_str().to_lowercase() == name_lower)
265                .cloned()
266        }
267    };
268    let all_columns: Vec<String> = if allow_missing_columns {
269        let mut out = left_names.clone();
270        for r in &right_names {
271            if !contains(&out, r.as_str()) {
272                out.push(r.clone());
273            }
274        }
275        out
276    } else {
277        left_names.clone()
278    };
279    // Per-column common type for coercion when left/right types differ (#603).
280    let mut left_exprs: Vec<Expr> = Vec::with_capacity(all_columns.len());
281    let mut right_exprs: Vec<Expr> = Vec::with_capacity(all_columns.len());
282    for c in &all_columns {
283        let left_has = resolve(&left_names, c.as_str());
284        let right_has = resolve(&right_names, c.as_str());
285        let left_dtype = left_has.as_ref().and_then(|r| left.get_column_dtype(r));
286        let right_dtype = right_has.as_ref().and_then(|r| right.get_column_dtype(r));
287        // #613: When one side's dtype is unknown (None), use String as common type so we never
288        // cast string to int (which would fail); both columns can safely cast to String.
289        let common_dtype = match (&left_dtype, &right_dtype) {
290            (Some(lt), Some(rt)) if lt != rt => find_common_type(lt, rt).map_err(|e| {
291                PolarsError::ComputeError(
292                    format!("union_by_name: column '{}' type coercion: {}", c, e).into(),
293                )
294            })?,
295            (Some(lt), Some(_)) => lt.clone(),
296            (Some(lt), None) | (None, Some(lt)) => {
297                // One side unknown: coerce to String so we never cast string→int (PySpark union promotes to string).
298                if lt == &polars::prelude::DataType::String {
299                    lt.clone()
300                } else {
301                    polars::prelude::DataType::String
302                }
303            }
304            (None, None) => polars::prelude::DataType::Null,
305        };
306        let left_expr = match &left_has {
307            Some(r) => col(r.as_str()).cast(common_dtype.clone()).alias(c.as_str()),
308            None => polars::prelude::lit(polars::prelude::NULL)
309                .cast(common_dtype.clone())
310                .alias(c.as_str()),
311        };
312        left_exprs.push(left_expr);
313        let right_expr = match &right_has {
314            Some(r) => col(r.as_str()).cast(common_dtype.clone()).alias(c.as_str()),
315            None if allow_missing_columns => polars::prelude::lit(polars::prelude::NULL)
316                .cast(common_dtype)
317                .alias(c.as_str()),
318            None => {
319                return Err(PolarsError::InvalidOperation(
320                    format!(
321                        "union_by_name: column '{}' missing in right DataFrame (allow_missing_columns=False)",
322                        c
323                    )
324                    .into(),
325                ));
326            }
327        };
328        right_exprs.push(right_expr);
329    }
330    let lf1 = left.lazy_frame().select(&left_exprs);
331    let lf2 = right.lazy_frame().select(&right_exprs);
332    let out = polars::prelude::concat([lf1, lf2], UnionArgs::default())?;
333    Ok(super::DataFrame::from_lazy_with_options(
334        out,
335        case_sensitive,
336    ))
337}
338
339/// Distinct: drop duplicate rows (all columns or subset).
340pub fn distinct(
341    df: &DataFrame,
342    subset: Option<Vec<&str>>,
343    case_sensitive: bool,
344) -> Result<DataFrame, PolarsError> {
345    let subset_names: Option<Vec<String>> = subset
346        .map(|cols| {
347            cols.iter()
348                .map(|s| df.resolve_column_name(s))
349                .collect::<Result<Vec<_>, _>>()
350        })
351        .transpose()?;
352    let subset_selector: Option<Selector> = subset_names.map(|names| Selector::ByName {
353        names: Arc::from(names.into_iter().map(PlSmallStr::from).collect::<Vec<_>>()),
354        strict: false,
355    });
356    let lf = df
357        .lazy_frame()
358        .unique(subset_selector, UniqueKeepStrategy::First);
359    Ok(super::DataFrame::from_lazy_with_options(lf, case_sensitive))
360}
361
362/// Drop one or more columns.
363pub fn drop(
364    df: &DataFrame,
365    columns: Vec<&str>,
366    case_sensitive: bool,
367) -> Result<DataFrame, PolarsError> {
368    let resolved: Vec<String> = columns
369        .iter()
370        .map(|c| df.resolve_column_name(c))
371        .collect::<Result<Vec<_>, _>>()?;
372    let all_names = df.columns()?;
373    let to_keep: Vec<Expr> = all_names
374        .iter()
375        .filter(|n| !resolved.iter().any(|r| r == n.as_str()))
376        .map(|n| col(n.as_str()))
377        .collect();
378    let lf = df.lazy_frame().select(&to_keep);
379    Ok(super::DataFrame::from_lazy_with_options(lf, case_sensitive))
380}
381
382/// Drop rows with nulls (all columns or subset). PySpark na.drop(subset, how, thresh).
383/// - how: "any" (default) = drop if any null in subset; "all" = drop only if all null in subset.
384/// - thresh: if set, keep row if it has at least this many non-null values in subset (overrides how).
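/// A hedged usage sketch (not compiled as a doctest; `df` is a hypothetical frame with nullable
/// columns "id" and "v"):
/// ```ignore
/// // Keep rows with at least 2 non-null values among "id" and "v"; thresh overrides how.
/// let kept = dropna(&df, Some(vec!["id", "v"]), "any", Some(2), false)?;
/// // Drop rows only when every subset column is null.
/// let kept_all = dropna(&df, Some(vec!["id", "v"]), "all", None, false)?;
/// ```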
385pub fn dropna(
386    df: &DataFrame,
387    subset: Option<Vec<&str>>,
388    how: &str,
389    thresh: Option<usize>,
390    case_sensitive: bool,
391) -> Result<DataFrame, PolarsError> {
392    use polars::prelude::*;
393    let cols: Vec<String> = match &subset {
394        Some(c) => c
395            .iter()
396            .map(|n| df.resolve_column_name(n))
397            .collect::<Result<Vec<_>, _>>()?,
398        None => df.columns()?,
399    };
400    let col_exprs: Vec<Expr> = cols.iter().map(|c| col(c.as_str())).collect();
401    let base_lf = df.lazy_frame();
402    let lf = if let Some(n) = thresh {
403        // Keep row if number of non-null in subset >= n
404        let count_expr: Expr = col_exprs
405            .iter()
406            .map(|e| e.clone().is_not_null().cast(DataType::Int32))
407            .fold(lit(0i32), |a, b| a + b);
408        base_lf.filter(count_expr.gt_eq(lit(n as i32)))
409    } else if how.eq_ignore_ascii_case("all") {
410        // Drop only when all subset columns are null → keep when any is not null
411        let any_not_null: Expr = col_exprs
412            .into_iter()
413            .map(|e| e.is_not_null())
414            .fold(lit(false), |a, b| a.or(b));
415        base_lf.filter(any_not_null)
416    } else {
417        // how == "any" (default): drop if any null in subset
418        let subset_selector = Selector::ByName {
419            names: Arc::from(
420                cols.iter()
421                    .map(|s| PlSmallStr::from(s.as_str()))
422                    .collect::<Vec<_>>(),
423            ),
424            strict: false,
425        };
426        base_lf.drop_nulls(Some(subset_selector))
427    };
428    Ok(super::DataFrame::from_lazy_with_options(lf, case_sensitive))
429}
430
431/// Fill nulls with a literal expression. If subset is Some, only those columns are filled; else all.
432/// PySpark na.fill(value, subset=...).
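/// A hedged usage sketch (not compiled as a doctest; `df` is assumed to have a nullable Int64
/// column "v"):
/// ```ignore
/// use polars::prelude::lit;
/// // Fill nulls in "v" with 0, leaving other columns untouched.
/// let filled = fillna(&df, lit(0i64), Some(vec!["v"]), false)?;
/// ```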
433pub fn fillna(
434    df: &DataFrame,
435    value_expr: Expr,
436    subset: Option<Vec<&str>>,
437    case_sensitive: bool,
438) -> Result<DataFrame, PolarsError> {
439    use polars::prelude::*;
440    let exprs: Vec<Expr> = match subset {
441        Some(cols) => cols
442            .iter()
443            .map(|n| {
444                let resolved = df.resolve_column_name(n)?;
445                Ok(col(resolved.as_str()).fill_null(value_expr.clone()))
446            })
447            .collect::<Result<Vec<_>, PolarsError>>()?,
448        None => df
449            .columns()?
450            .iter()
451            .map(|n| col(n.as_str()).fill_null(value_expr.clone()))
452            .collect(),
453    };
454    let lf = df.lazy_frame().with_columns(exprs);
455    Ok(super::DataFrame::from_lazy_with_options(lf, case_sensitive))
456}
457
458/// Limit: return first n rows.
459pub fn limit(df: &DataFrame, n: usize, case_sensitive: bool) -> Result<DataFrame, PolarsError> {
460    // limit is a transformation: slice(0, n) on lazy
461    let lf = df.lazy_frame().slice(0, n as u32);
462    Ok(super::DataFrame::from_lazy_with_options(lf, case_sensitive))
463}
464
465/// Rename a column (old_name -> new_name).
466pub fn with_column_renamed(
467    df: &DataFrame,
468    old_name: &str,
469    new_name: &str,
470    case_sensitive: bool,
471) -> Result<DataFrame, PolarsError> {
472    let resolved = df.resolve_column_name(old_name)?;
473    let lf = df
474        .lazy_frame()
475        .rename([resolved.as_str()], [new_name], true);
476    Ok(super::DataFrame::from_lazy_with_options(lf, case_sensitive))
477}
478
479/// Replace values in a column: where column == old_value, use new_value. PySpark replace (single column).
480pub fn replace(
481    df: &DataFrame,
482    column_name: &str,
483    old_value: Expr,
484    new_value: Expr,
485    case_sensitive: bool,
486) -> Result<DataFrame, PolarsError> {
487    use polars::prelude::*;
488    let resolved = df.resolve_column_name(column_name)?;
489    let repl = when(col(resolved.as_str()).eq(old_value))
490        .then(new_value)
491        .otherwise(col(resolved.as_str()));
492    let lf = df.lazy_frame().with_column(repl.alias(resolved.as_str()));
493    Ok(super::DataFrame::from_lazy_with_options(lf, case_sensitive))
494}
495
496/// Cross join: cartesian product of two DataFrames. PySpark crossJoin.
497pub fn cross_join(
498    left: &DataFrame,
499    right: &DataFrame,
500    case_sensitive: bool,
501) -> Result<DataFrame, PolarsError> {
502    let lf_left = left.lazy_frame();
503    let lf_right = right.lazy_frame();
504    let out = lf_left.cross_join(lf_right, None);
505    Ok(super::DataFrame::from_lazy_with_options(
506        out,
507        case_sensitive,
508    ))
509}
510
511/// Summary statistics (count, mean, std, min, max). PySpark describe.
512/// Builds a summary DataFrame with a "summary" column (PySpark name) and one column per numeric input column.
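/// A hedged usage sketch (not compiled as a doctest; `df` is the (id, v, label) fixture used by
/// this module's tests):
/// ```ignore
/// let stats = describe(&df, false)?;
/// // Five rows (count, mean, stddev, min, max) in the "summary" column, plus one
/// // string-typed column per numeric input column ("id" and "v" here).
/// assert_eq!(stats.count()?, 5);
/// ```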
513pub fn describe(df: &DataFrame, case_sensitive: bool) -> Result<DataFrame, PolarsError> {
514    use polars::prelude::*;
515    let pl_df = df.collect_inner()?.as_ref().clone();
516    let mut stat_values: Vec<Column> = Vec::new();
517    for col in pl_df.columns() {
518        let s = col.as_materialized_series();
519        let dtype = s.dtype();
520        if dtype.is_numeric() {
521            let name = s.name().clone();
522            let count = s.len() as i64 - s.null_count() as i64;
523            let mean_f = s.mean().unwrap_or(f64::NAN);
524            let std_f = s.std(1).unwrap_or(f64::NAN);
525            let s_f64 = s.cast(&DataType::Float64)?;
526            let ca = s_f64
527                .f64()
528                .map_err(|_| PolarsError::ComputeError("cast to f64 failed".into()))?;
529            let min_f = ca.min().unwrap_or(f64::NAN);
530            let max_f = ca.max().unwrap_or(f64::NAN);
531            // PySpark describe/summary returns string type for value columns
532            let is_float = matches!(dtype, DataType::Float64 | DataType::Float32);
533            let count_s = count.to_string();
534            let mean_s = if mean_f.is_nan() {
535                "None".to_string()
536            } else {
537                format!("{:.1}", mean_f)
538            };
539            let std_s = if std_f.is_nan() {
540                "None".to_string()
541            } else {
542                format!("{:.1}", std_f)
543            };
544            let min_s = if min_f.is_nan() {
545                "None".to_string()
546            } else if min_f.fract() == 0.0 && is_float {
547                format!("{:.1}", min_f)
548            } else if min_f.fract() == 0.0 {
549                format!("{:.0}", min_f)
550            } else {
551                format!("{min_f}")
552            };
553            let max_s = if max_f.is_nan() {
554                "None".to_string()
555            } else if max_f.fract() == 0.0 && is_float {
556                format!("{:.1}", max_f)
557            } else if max_f.fract() == 0.0 {
558                format!("{:.0}", max_f)
559            } else {
560                format!("{max_f}")
561            };
562            let series = Series::new(
563                name,
564                [
565                    count_s.as_str(),
566                    mean_s.as_str(),
567                    std_s.as_str(),
568                    min_s.as_str(),
569                    max_s.as_str(),
570                ],
571            );
572            stat_values.push(series.into());
573        }
574    }
575    if stat_values.is_empty() {
576        // No numeric columns: return minimal describe with just summary column (PySpark name)
577        let stat_col = Series::new(
578            "summary".into(),
579            &["count", "mean", "stddev", "min", "max" as &str],
580        )
581        .into();
582        let empty: Vec<f64> = Vec::new();
583        let empty_series = Series::new("placeholder".into(), empty).into();
584        let out_pl = polars::prelude::DataFrame::new_infer_height(vec![stat_col, empty_series])?;
585        return Ok(super::DataFrame::from_polars_with_options(
586            out_pl,
587            case_sensitive,
588        ));
589    }
590    let summary_col = Series::new(
591        "summary".into(),
592        &["count", "mean", "stddev", "min", "max" as &str],
593    )
594    .into();
595    let mut cols: Vec<Column> = vec![summary_col];
596    cols.extend(stat_values);
597    let out_pl = polars::prelude::DataFrame::new_infer_height(cols)?;
598    Ok(super::DataFrame::from_polars_with_options(
599        out_pl,
600        case_sensitive,
601    ))
602}
603
604/// Set difference: rows in left that are not in right (by all columns). PySpark subtract / except.
605/// Aligns right column names to left (case-insensitively when case_sensitive is false) so subtract works when casing differs.
606pub fn subtract(
607    left: &DataFrame,
608    right: &DataFrame,
609    case_sensitive: bool,
610) -> Result<DataFrame, PolarsError> {
611    use polars::prelude::*;
612    let left_names = left.columns()?;
613    let right_names = right.columns()?;
614    let right_on: Vec<Expr> = left_names
615        .iter()
616        .map(|ln| {
617            let resolved = if case_sensitive {
618                right_names
619                    .iter()
620                    .find(|rn| rn.as_str() == ln.as_str())
621                    .cloned()
622                    .ok_or_else(|| {
623                        PolarsError::ColumnNotFound(
624                            format!("subtract: column '{}' not found on right", ln).into(),
625                        )
626                    })?
627            } else {
628                let ln_lower = ln.to_lowercase();
629                right_names
630                    .iter()
631                    .find(|rn| rn.to_lowercase() == ln_lower)
632                    .cloned()
633                    .ok_or_else(|| {
634                        PolarsError::ColumnNotFound(
635                            format!("subtract: column '{}' not found on right", ln).into(),
636                        )
637                    })?
638            };
639            Ok(col(resolved.as_str()))
640        })
641        .collect::<Result<Vec<_>, PolarsError>>()?;
642    let left_on: Vec<Expr> = left_names.iter().map(|n| col(n.as_str())).collect();
643    let right_lf = right.lazy_frame();
644    let left_lf = left.lazy_frame();
645    let anti = left_lf.join(right_lf, left_on, right_on, JoinArgs::new(JoinType::Anti));
646    Ok(super::DataFrame::from_lazy_with_options(
647        anti,
648        case_sensitive,
649    ))
650}
651
652/// Set intersection: rows that appear in both DataFrames (by all columns). PySpark intersect.
653/// Aligns right column names to left (case-insensitively when case_sensitive is false) so intersect works when casing differs.
654pub fn intersect(
655    left: &DataFrame,
656    right: &DataFrame,
657    case_sensitive: bool,
658) -> Result<DataFrame, PolarsError> {
659    use polars::prelude::*;
660    let left_names = left.columns()?;
661    let right_names = right.columns()?;
662    let right_on: Vec<Expr> = left_names
663        .iter()
664        .map(|ln| {
665            let resolved = if case_sensitive {
666                right_names
667                    .iter()
668                    .find(|rn| rn.as_str() == ln.as_str())
669                    .cloned()
670                    .ok_or_else(|| {
671                        PolarsError::ColumnNotFound(
672                            format!("intersect: column '{}' not found on right", ln).into(),
673                        )
674                    })?
675            } else {
676                let ln_lower = ln.to_lowercase();
677                right_names
678                    .iter()
679                    .find(|rn| rn.to_lowercase() == ln_lower)
680                    .cloned()
681                    .ok_or_else(|| {
682                        PolarsError::ColumnNotFound(
683                            format!("intersect: column '{}' not found on right", ln).into(),
684                        )
685                    })?
686            };
687            Ok(col(resolved.as_str()))
688        })
689        .collect::<Result<Vec<_>, PolarsError>>()?;
690    let left_on: Vec<Expr> = left_names.iter().map(|n| col(n.as_str())).collect();
691    let left_lf = left.lazy_frame();
692    let right_lf = right.lazy_frame();
693    let semi = left_lf
694        .join(right_lf, left_on, right_on, JoinArgs::new(JoinType::Semi))
695        .unique(None, UniqueKeepStrategy::First);
696    Ok(super::DataFrame::from_lazy_with_options(
697        semi,
698        case_sensitive,
699    ))
700}
701
702// ---------- Batch A: sample, first/head/take/tail, is_empty, to_df ----------
703
704/// Sample a fraction of rows. PySpark sample(withReplacement, fraction, seed).
705pub fn sample(
706    df: &DataFrame,
707    with_replacement: bool,
708    fraction: f64,
709    seed: Option<u64>,
710    case_sensitive: bool,
711) -> Result<DataFrame, PolarsError> {
712    use polars::prelude::Series;
713    let pl = df.collect_inner()?;
714    let n = pl.height();
715    if n == 0 {
716        return Ok(super::DataFrame::from_lazy_with_options(
717            polars::prelude::DataFrame::empty().lazy(),
718            case_sensitive,
719        ));
720    }
721    let take_n = (n as f64 * fraction).round() as usize;
722    let take_n = take_n.min(n).max(0);
723    if take_n == 0 {
724        return Ok(super::DataFrame::from_lazy_with_options(
725            pl.as_ref().head(Some(0)).lazy(),
726            case_sensitive,
727        ));
728    }
729    let idx_series = Series::new("idx".into(), (0..n).map(|i| i as u32).collect::<Vec<_>>());
730    let sampled_idx = idx_series.sample_n(take_n, with_replacement, true, seed)?;
731    let idx_ca = sampled_idx
732        .u32()
733        .map_err(|_| PolarsError::ComputeError("sample: expected u32 indices".into()))?;
734    let pl_df = pl.as_ref().take(idx_ca)?;
735    Ok(super::DataFrame::from_polars_with_options(
736        pl_df,
737        case_sensitive,
738    ))
739}
740
741/// Split DataFrame by weights (random split). PySpark randomSplit(weights, seed).
742/// Returns one DataFrame per weight; weights are normalized to fractions.
743/// Each row is assigned to exactly one split (disjoint partitions).
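/// A hedged usage sketch (not compiled as a doctest; `df` is any DataFrame and 42 an arbitrary seed):
/// ```ignore
/// // Weights [3.0, 1.0] are normalized to cumulative fractions [0.75, 1.0];
/// // each row is assigned to exactly one of the two splits.
/// let parts = random_split(&df, &[3.0, 1.0], Some(42), false)?;
/// assert_eq!(parts.len(), 2);
/// ```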
744pub fn random_split(
745    df: &DataFrame,
746    weights: &[f64],
747    seed: Option<u64>,
748    case_sensitive: bool,
749) -> Result<Vec<DataFrame>, PolarsError> {
750    let total: f64 = weights.iter().sum();
751    if total <= 0.0 || weights.is_empty() {
752        return Ok(Vec::new());
753    }
754    let pl = df.collect_inner()?;
755    let n = pl.height();
756    if n == 0 {
757        return Ok(weights.iter().map(|_| super::DataFrame::empty()).collect());
758    }
759    // Normalize weights to cumulative fractions: e.g. [0.25, 0.25, 0.5] -> [0.25, 0.5, 1.0]
760    let mut cum = Vec::with_capacity(weights.len());
761    let mut acc = 0.0_f64;
762    for w in weights {
763        acc += w / total;
764        cum.push(acc);
765    }
766    // Assign each row index to one bucket using a single seeded RNG (disjoint split).
767    use polars::prelude::Series;
768    use rand::Rng;
769    use rand::SeedableRng;
770    let mut rng = rand::rngs::StdRng::seed_from_u64(seed.unwrap_or(0));
771    let mut bucket_indices: Vec<Vec<u32>> = (0..weights.len()).map(|_| Vec::new()).collect();
772    for i in 0..n {
773        let r: f64 = rng.r#gen();
774        let bucket = cum
775            .iter()
776            .position(|&c| r < c)
777            .unwrap_or(weights.len().saturating_sub(1));
778        bucket_indices[bucket].push(i as u32);
779    }
780    let pl = pl.as_ref();
781    let mut out = Vec::with_capacity(weights.len());
782    for indices in bucket_indices {
783        if indices.is_empty() {
784            out.push(super::DataFrame::from_polars_with_options(
785                pl.clone().head(Some(0)),
786                case_sensitive,
787            ));
788        } else {
789            let idx_series = Series::new("idx".into(), indices);
790            let idx_ca = idx_series.u32().map_err(|_| {
791                PolarsError::ComputeError("random_split: expected u32 indices".into())
792            })?;
793            let taken = pl.take(idx_ca)?;
794            out.push(super::DataFrame::from_polars_with_options(
795                taken,
796                case_sensitive,
797            ));
798        }
799    }
800    Ok(out)
801}
802
803/// Stratified sample by column value. PySpark sampleBy(col, fractions, seed).
804/// fractions: list of (value as Expr literal, fraction to sample for that value).
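/// A hedged usage sketch (not compiled as a doctest; `df` is assumed to have a String column "label"):
/// ```ignore
/// use polars::prelude::lit;
/// // Keep roughly half of the "a" rows and every "b" row (without replacement).
/// let sampled = sample_by(&df, "label", &[(lit("a"), 0.5), (lit("b"), 1.0)], Some(7), false)?;
/// ```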
805pub fn sample_by(
806    df: &DataFrame,
807    col_name: &str,
808    fractions: &[(Expr, f64)],
809    seed: Option<u64>,
810    case_sensitive: bool,
811) -> Result<DataFrame, PolarsError> {
812    use polars::prelude::*;
813    if fractions.is_empty() {
814        return Ok(super::DataFrame::from_lazy_with_options(
815            df.lazy_frame().slice(0, 0),
816            case_sensitive,
817        ));
818    }
819    let resolved = df.resolve_column_name(col_name)?;
820    let mut parts = Vec::with_capacity(fractions.len());
821    for (value_expr, frac) in fractions {
822        let cond = col(resolved.as_str()).eq(value_expr.clone());
823        let filtered = df.lazy_frame().filter(cond).collect()?;
824        if filtered.height() == 0 {
825            parts.push(filtered.head(Some(0)));
826            continue;
827        }
828        let sampled = sample(
829            &super::DataFrame::from_polars_with_options(filtered, case_sensitive),
830            false,
831            *frac,
832            seed,
833            case_sensitive,
834        )?;
835        parts.push(sampled.collect_inner()?.as_ref().clone());
836    }
837    let mut out = parts
838        .first()
839        .ok_or_else(|| PolarsError::ComputeError("sample_by: no parts".into()))?
840        .clone();
841    for p in parts.iter().skip(1) {
842        out.vstack_mut(p)?;
843    }
844    Ok(super::DataFrame::from_polars_with_options(
845        out,
846        case_sensitive,
847    ))
848}
849
850/// First row as a DataFrame (one row). PySpark first().
851/// Uses limit(1) then collect so that orderBy (and other plan steps) are applied before taking
852/// the first row (issue #579: first() after orderBy must return first in sort order, not storage order).
853pub fn first(df: &DataFrame, case_sensitive: bool) -> Result<DataFrame, PolarsError> {
854    let limited = limit(df, 1, case_sensitive)?;
855    let pl_df = limited.collect_inner()?.as_ref().clone();
856    Ok(super::DataFrame::from_polars_with_options(
857        pl_df,
858        case_sensitive,
859    ))
860}
861
862/// First n rows. PySpark head(n). Same as limit.
863pub fn head(df: &DataFrame, n: usize, case_sensitive: bool) -> Result<DataFrame, PolarsError> {
864    limit(df, n, case_sensitive)
865}
866
867/// Take first n rows (alias for limit). PySpark take(n).
868pub fn take(df: &DataFrame, n: usize, case_sensitive: bool) -> Result<DataFrame, PolarsError> {
869    limit(df, n, case_sensitive)
870}
871
872/// Last n rows. PySpark tail(n).
873pub fn tail(df: &DataFrame, n: usize, case_sensitive: bool) -> Result<DataFrame, PolarsError> {
874    let pl = df.collect_inner()?;
875    let total = pl.height();
876    let skip = total.saturating_sub(n);
877    let pl_df = pl.as_ref().clone().slice(skip as i64, n);
878    Ok(super::DataFrame::from_polars_with_options(
879        pl_df,
880        case_sensitive,
881    ))
882}
883
884/// Whether the DataFrame has zero rows. PySpark isEmpty.
885pub fn is_empty(df: &DataFrame) -> bool {
886    df.count().map(|n| n == 0).unwrap_or(true)
887}
888
889/// Rename columns. PySpark toDF(*colNames). The number of names must equal the number of existing columns.
890pub fn to_df(
891    df: &DataFrame,
892    names: &[&str],
893    case_sensitive: bool,
894) -> Result<DataFrame, PolarsError> {
895    let cols = df.columns()?;
896    if names.len() != cols.len() {
897        return Err(PolarsError::ComputeError(
898            format!(
899                "toDF: expected {} column names, got {}",
900                cols.len(),
901                names.len()
902            )
903            .into(),
904        ));
905    }
906    let pl_df = df.collect_inner()?;
907    let mut pl_df = pl_df.as_ref().clone();
908    for (old, new) in cols.iter().zip(names.iter()) {
909        pl_df.rename(old.as_str(), (*new).into())?;
910    }
911    Ok(super::DataFrame::from_polars_with_options(
912        pl_df,
913        case_sensitive,
914    ))
915}
916
917// ---------- Batch B: toJSON, explain, printSchema ----------
918
919fn any_value_to_serde_value(av: &polars::prelude::AnyValue) -> serde_json::Value {
920    use polars::prelude::AnyValue;
921    use serde_json::Number;
922    match av {
923        AnyValue::Null => serde_json::Value::Null,
924        AnyValue::Boolean(v) => serde_json::Value::Bool(*v),
925        AnyValue::Int8(v) => serde_json::Value::Number(Number::from(*v as i64)),
926        AnyValue::Int32(v) => serde_json::Value::Number(Number::from(*v)),
927        AnyValue::Int64(v) => serde_json::Value::Number(Number::from(*v)),
928        AnyValue::UInt32(v) => serde_json::Value::Number(Number::from(*v)),
929        AnyValue::Float64(v) => Number::from_f64(*v)
930            .map(serde_json::Value::Number)
931            .unwrap_or(serde_json::Value::Null),
932        AnyValue::String(v) => serde_json::Value::String(v.to_string()),
933        _ => serde_json::Value::String(format!("{av:?}")),
934    }
935}
936
937/// Collect rows as JSON strings (one JSON object per row). PySpark toJSON.
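/// A hedged usage sketch (not compiled as a doctest; `df` is the three-row (id, v, label) fixture
/// from this module's tests):
/// ```ignore
/// let rows = to_json(&df)?;
/// // One serialized object per row, e.g. a string like {"id":1,"label":"a","v":10}
/// // (key order depends on serde_json's map representation).
/// assert_eq!(rows.len(), 3);
/// ```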
938pub fn to_json(df: &DataFrame) -> Result<Vec<String>, PolarsError> {
939    use polars::prelude::*;
940    let collected = df.collect_inner()?;
941    let pl = collected.as_ref();
942    let names = pl.get_column_names();
943    let mut out = Vec::with_capacity(pl.height());
944    for r in 0..pl.height() {
945        let mut row = serde_json::Map::new();
946        for (i, name) in names.iter().enumerate() {
947            let col = pl
948                .columns()
949                .get(i)
950                .ok_or_else(|| PolarsError::ComputeError("to_json: column index".into()))?;
951            let series = col.as_materialized_series();
952            let av = series
953                .get(r)
954                .map_err(|e| PolarsError::ComputeError(e.to_string().into()))?;
955            row.insert(name.to_string(), any_value_to_serde_value(&av));
956        }
957        out.push(
958            serde_json::to_string(&row)
959                .map_err(|e| PolarsError::ComputeError(e.to_string().into()))?,
960        );
961    }
962    Ok(out)
963}
964
965/// Return a string describing the execution plan. PySpark explain.
966pub fn explain(_df: &DataFrame) -> String {
967    "DataFrame (eager Polars backend)".to_string()
968}
969
970/// Return schema as a tree string. PySpark printSchema (we return string; caller can print).
971pub fn print_schema(df: &DataFrame) -> Result<String, PolarsError> {
972    let schema = df.schema()?;
973    let mut s = "root\n".to_string();
974    for f in schema.fields() {
975        let dt = match &f.data_type {
976            crate::schema::DataType::String => "string",
977            crate::schema::DataType::Integer => "int",
978            crate::schema::DataType::Long => "bigint",
979            crate::schema::DataType::Double => "double",
980            crate::schema::DataType::Boolean => "boolean",
981            crate::schema::DataType::Date => "date",
982            crate::schema::DataType::Timestamp => "timestamp",
983            _ => "string",
984        };
985        s.push_str(&format!(" |-- {}: {}\n", f.name, dt));
986    }
987    Ok(s)
988}
989
990// ---------- Batch D: selectExpr, colRegex, withColumns, withColumnsRenamed, na ----------
991
992/// Select by expression strings. Minimal support: plain column names; an " as alias" suffix is parsed but the alias is not applied. PySpark selectExpr.
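/// A hedged usage sketch (not compiled as a doctest; `df` is assumed to have columns "id" and "label"):
/// ```ignore
/// let out = select_expr(&df, &["id".to_string(), "label as name".to_string()], false)?;
/// // Only the column selection is applied; the " as name" alias is dropped, so the
/// // output columns are still ["id", "label"].
/// ```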
993pub fn select_expr(
994    df: &DataFrame,
995    exprs: &[String],
996    case_sensitive: bool,
997) -> Result<DataFrame, PolarsError> {
998    let mut cols = Vec::new();
999    for e in exprs {
1000        let e = e.trim();
1001        if let Some((left, right)) = e.split_once(" as ") {
1002            let col_name = left.trim();
1003            let _alias = right.trim();
1004            cols.push(df.resolve_column_name(col_name)?);
1005        } else {
1006            cols.push(df.resolve_column_name(e)?);
1007        }
1008    }
1009    let refs: Vec<&str> = cols.iter().map(|s| s.as_str()).collect();
1010    select(df, refs, case_sensitive)
1011}
1012
1013/// Select columns whose names match the regex pattern. PySpark colRegex.
1014pub fn col_regex(
1015    df: &DataFrame,
1016    pattern: &str,
1017    case_sensitive: bool,
1018) -> Result<DataFrame, PolarsError> {
1019    let re = regex::Regex::new(pattern).map_err(|e| {
1020        PolarsError::ComputeError(format!("colRegex: invalid pattern {pattern:?}: {e}").into())
1021    })?;
1022    let names = df.columns()?;
1023    let matched: Vec<&str> = names
1024        .iter()
1025        .filter(|n| re.is_match(n))
1026        .map(|s| s.as_str())
1027        .collect();
1028    if matched.is_empty() {
1029        return Err(PolarsError::ComputeError(
1030            format!("colRegex: no columns matched pattern {pattern:?}").into(),
1031        ));
1032    }
1033    select(df, matched, case_sensitive)
1034}
1035
1036/// Add or replace multiple columns. PySpark withColumns. Uses Column so deferred rand/randn get per-row values.
1037pub fn with_columns(
1038    df: &DataFrame,
1039    exprs: &[(String, crate::column::Column)],
1040    case_sensitive: bool,
1041) -> Result<DataFrame, PolarsError> {
1042    let pl = df.collect_inner()?.as_ref().clone();
1043    let mut current = super::DataFrame::from_polars_with_options(pl, case_sensitive);
1044    for (name, col) in exprs {
1045        current = with_column(&current, name, col, case_sensitive)?;
1046    }
1047    Ok(current)
1048}
1049
1050/// Rename multiple columns. PySpark withColumnsRenamed.
1051pub fn with_columns_renamed(
1052    df: &DataFrame,
1053    renames: &[(String, String)],
1054    case_sensitive: bool,
1055) -> Result<DataFrame, PolarsError> {
1056    let mut mapping = Vec::new();
1057    for (old_name, new_name) in renames {
1058        let resolved = df.resolve_column_name(old_name)?;
1059        mapping.push((resolved, new_name.clone()));
1060    }
1061    let mut lf = df.lazy_frame();
1062    for (old, new) in mapping {
1063        lf = lf.rename([old.as_str()], [new.as_str()], true);
1064    }
1065    Ok(super::DataFrame::from_lazy_with_options(lf, case_sensitive))
1066}
1067
1068/// NA sub-API builder. PySpark df.na().fill(...) / .drop(...).
1069pub struct DataFrameNa<'a> {
1070    pub(crate) df: &'a DataFrame,
1071}
1072
1073impl<'a> DataFrameNa<'a> {
1074    /// Fill nulls with the given value. PySpark na.fill(value, subset=...).
1075    pub fn fill(&self, value: Expr, subset: Option<Vec<&str>>) -> Result<DataFrame, PolarsError> {
1076        fillna(self.df, value, subset, self.df.case_sensitive)
1077    }
1078
1079    /// Replace values in columns. PySpark na.replace(to_replace, value, subset=None).
1080    pub fn replace(
1081        &self,
1082        old_value: Expr,
1083        new_value: Expr,
1084        subset: Option<Vec<&str>>,
1085    ) -> Result<DataFrame, PolarsError> {
1086        let cols: Vec<String> = match &subset {
1087            Some(s) => s.iter().map(|x| (*x).to_string()).collect(),
1088            None => self.df.columns()?,
1089        };
1090        let mut result = self.df.clone();
1091        for col_name in &cols {
1092            result = replace(
1093                &result,
1094                col_name.as_str(),
1095                old_value.clone(),
1096                new_value.clone(),
1097                self.df.case_sensitive,
1098            )?;
1099        }
1100        Ok(result)
1101    }
1102
1103    /// Drop rows with nulls. PySpark na.drop(subset=..., how=..., thresh=...).
1104    pub fn drop(
1105        &self,
1106        subset: Option<Vec<&str>>,
1107        how: &str,
1108        thresh: Option<usize>,
1109    ) -> Result<DataFrame, PolarsError> {
1110        dropna(self.df, subset, how, thresh, self.df.case_sensitive)
1111    }
1112}
1113
1114// ---------- Batch E: offset, transform, freqItems, approxQuantile, crosstab, melt, exceptAll, intersectAll ----------
1115
1116/// Skip first n rows. PySpark offset(n).
1117pub fn offset(df: &DataFrame, n: usize, case_sensitive: bool) -> Result<DataFrame, PolarsError> {
1118    let lf = df.lazy_frame().slice(n as i64, u32::MAX);
1119    Ok(super::DataFrame::from_lazy_with_options(lf, case_sensitive))
1120}
1121
1122/// Transform DataFrame by a function. PySpark transform(func).
1123pub fn transform<F>(df: &DataFrame, f: F) -> Result<DataFrame, PolarsError>
1124where
1125    F: FnOnce(DataFrame) -> Result<DataFrame, PolarsError>,
1126{
1127    let df_out = f(df.clone())?;
1128    Ok(df_out)
1129}
1130
1131/// Frequent items. PySpark freqItems. Returns one row with columns {col}_freqItems (array of values with frequency >= support).
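/// A hedged usage sketch (not compiled as a doctest; `df` is assumed to have a String column "label"):
/// ```ignore
/// // Values of "label" appearing in at least 30% of rows, returned as a single-row
/// // list column named "label_freqItems".
/// let fi = freq_items(&df, &["label"], 0.3, false)?;
/// assert_eq!(fi.count()?, 1);
/// ```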
1132pub fn freq_items(
1133    df: &DataFrame,
1134    columns: &[&str],
1135    support: f64,
1136    case_sensitive: bool,
1137) -> Result<DataFrame, PolarsError> {
1138    use polars::prelude::SeriesMethods;
1139    if columns.is_empty() {
1140        return Ok(super::DataFrame::from_lazy_with_options(
1141            df.lazy_frame().slice(0, 0),
1142            case_sensitive,
1143        ));
1144    }
1145    let support = support.clamp(1e-4, 1.0);
1146    let collected = df.collect_inner()?;
1147    let pl_df = collected.as_ref();
1148    let n_total = pl_df.height() as f64;
1149    if n_total == 0.0 {
1150        let mut out = Vec::with_capacity(columns.len());
1151        for col_name in columns {
1152            let resolved = df.resolve_column_name(col_name)?;
1153            let s = pl_df
1154                .column(resolved.as_str())?
1155                .as_series()
1156                .ok_or_else(|| PolarsError::ComputeError("column not a series".into()))?
1157                .clone();
1158            let empty_sub = s.head(Some(0));
1159            let list_chunked = polars::prelude::ListChunked::from_iter([empty_sub].into_iter())
1160                .with_name(format!("{resolved}_freqItems").into());
1161            out.push(list_chunked.into_series().into());
1162        }
1163        return Ok(super::DataFrame::from_polars_with_options(
1164            polars::prelude::DataFrame::new_infer_height(out)?,
1165            case_sensitive,
1166        ));
1167    }
1168    let mut out_series = Vec::with_capacity(columns.len());
1169    for col_name in columns {
1170        let resolved = df.resolve_column_name(col_name)?;
1171        let s = pl_df
1172            .column(resolved.as_str())?
1173            .as_series()
1174            .ok_or_else(|| PolarsError::ComputeError("column not a series".into()))?
1175            .clone();
1176        let vc = s.value_counts(false, false, "counts".into(), false)?;
1177        let count_col = vc
1178            .column("counts")
1179            .map_err(|_| PolarsError::ComputeError("value_counts missing counts column".into()))?;
1180        let counts = count_col
1181            .u32()
1182            .map_err(|_| PolarsError::ComputeError("freq_items: counts column not u32".into()))?;
1183        let value_col_name = s.name();
1184        let values_col = vc
1185            .column(value_col_name.as_str())
1186            .map_err(|_| PolarsError::ComputeError("value_counts missing value column".into()))?;
1187        let threshold = (support * n_total).ceil() as u32;
1188        let indices: Vec<u32> = counts
1189            .into_iter()
1190            .enumerate()
1191            .filter_map(|(i, c)| {
1192                if c? >= threshold {
1193                    Some(i as u32)
1194                } else {
1195                    None
1196                }
1197            })
1198            .collect();
1199        let idx_series = Series::new("idx".into(), indices);
1200        let idx_ca = idx_series
1201            .u32()
1202            .map_err(|_| PolarsError::ComputeError("freq_items: index series not u32".into()))?;
1203        let values_series = values_col
1204            .as_series()
1205            .ok_or_else(|| PolarsError::ComputeError("value column not a series".into()))?;
1206        let filtered = values_series.take(idx_ca)?;
1207        let list_chunked = polars::prelude::ListChunked::from_iter([filtered].into_iter())
1208            .with_name(format!("{resolved}_freqItems").into());
1209        let list_row = list_chunked.into_series();
1210        out_series.push(list_row.into());
1211    }
1212    let out_df = polars::prelude::DataFrame::new_infer_height(out_series)?;
1213    Ok(super::DataFrame::from_polars_with_options(
1214        out_df,
1215        case_sensitive,
1216    ))
1217}
1218
1219/// Approximate quantiles. PySpark approxQuantile. Returns one column "quantile" with one row per probability.
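/// A hedged usage sketch (not compiled as a doctest; `df` is assumed to have a numeric column "v"):
/// ```ignore
/// let q = approx_quantile(&df, "v", &[0.25, 0.5, 0.75], false)?;
/// // One row per requested probability in the single "quantile" column.
/// assert_eq!(q.count()?, 3);
/// ```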
1220pub fn approx_quantile(
1221    df: &DataFrame,
1222    column: &str,
1223    probabilities: &[f64],
1224    case_sensitive: bool,
1225) -> Result<DataFrame, PolarsError> {
1226    use polars::prelude::{ChunkQuantile, QuantileMethod};
1227    if probabilities.is_empty() {
1228        return Ok(super::DataFrame::from_polars_with_options(
1229            polars::prelude::DataFrame::new_infer_height(vec![
1230                Series::new("quantile".into(), Vec::<f64>::new()).into(),
1231            ])?,
1232            case_sensitive,
1233        ));
1234    }
1235    let resolved = df.resolve_column_name(column)?;
1236    let collected = df.collect_inner()?;
1237    let s = collected
1238        .column(resolved.as_str())?
1239        .as_series()
1240        .ok_or_else(|| PolarsError::ComputeError("approx_quantile: column not a series".into()))?
1241        .clone();
1242    let s_f64 = s.cast(&polars::prelude::DataType::Float64)?;
1243    let ca = s_f64
1244        .f64()
1245        .map_err(|_| PolarsError::ComputeError("approx_quantile: need numeric column".into()))?;
1246    let mut quantiles = Vec::with_capacity(probabilities.len());
1247    for &p in probabilities {
1248        let q = ca.quantile(p, QuantileMethod::Linear)?;
1249        quantiles.push(q.unwrap_or(f64::NAN));
1250    }
1251    let out_df = polars::prelude::DataFrame::new_infer_height(vec![
1252        Series::new("quantile".into(), quantiles).into(),
1253    ])?;
1254    Ok(super::DataFrame::from_polars_with_options(
1255        out_df,
1256        case_sensitive,
1257    ))
1258}
1259
1260/// Cross-tabulation. PySpark crosstab. Returns long format (col1, col2, count); for wide format use pivot on the result.
1261pub fn crosstab(
1262    df: &DataFrame,
1263    col1: &str,
1264    col2: &str,
1265    case_sensitive: bool,
1266) -> Result<DataFrame, PolarsError> {
1267    use polars::prelude::*;
1268    let c1 = df.resolve_column_name(col1)?;
1269    let c2 = df.resolve_column_name(col2)?;
1270    let collected = df.collect_inner()?;
1271    let pl_df = collected.as_ref();
1272    let grouped = pl_df
1273        .clone()
1274        .lazy()
1275        .group_by([col(c1.as_str()), col(c2.as_str())])
1276        .agg([len().alias("count")])
1277        .collect()?;
1278    Ok(super::DataFrame::from_polars_with_options(
1279        grouped,
1280        case_sensitive,
1281    ))
1282}
1283
1284/// Unpivot (melt). PySpark melt. Long format with id_vars kept, plus "variable" and "value" columns.
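/// A hedged usage sketch (not compiled as a doctest; `df` is a hypothetical frame with an "id"
/// column and two numeric measure columns "q1" and "q2"):
/// ```ignore
/// // Result columns: id, variable, value — one output row per (input row, measure column).
/// let long = melt(&df, &["id"], &["q1", "q2"], false)?;
/// ```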
1285pub fn melt(
1286    df: &DataFrame,
1287    id_vars: &[&str],
1288    value_vars: &[&str],
1289    case_sensitive: bool,
1290) -> Result<DataFrame, PolarsError> {
1291    use polars::prelude::*;
1292    let collected = df.collect_inner()?;
1293    let pl_df = collected.as_ref();
1294    if value_vars.is_empty() {
1295        return Ok(super::DataFrame::from_polars_with_options(
1296            pl_df.head(Some(0)),
1297            case_sensitive,
1298        ));
1299    }
1300    let id_resolved: Vec<String> = id_vars
1301        .iter()
1302        .map(|s| df.resolve_column_name(s).map(|r| r.to_string()))
1303        .collect::<Result<Vec<_>, _>>()?;
1304    let value_resolved: Vec<String> = value_vars
1305        .iter()
1306        .map(|s| df.resolve_column_name(s).map(|r| r.to_string()))
1307        .collect::<Result<Vec<_>, _>>()?;
1308    let mut parts = Vec::with_capacity(value_vars.len());
1309    for vname in &value_resolved {
1310        let select_cols: Vec<&str> = id_resolved
1311            .iter()
1312            .map(|s| s.as_str())
1313            .chain([vname.as_str()])
1314            .collect();
1315        let mut part = pl_df.select(select_cols)?;
1316        let var_series = Series::new("variable".into(), vec![vname.as_str(); part.height()]);
1317        part.with_column(var_series.into())?;
1318        part.rename(vname.as_str(), "value".into())?;
1319        parts.push(part);
1320    }
1321    let mut out = parts
1322        .first()
1323        .ok_or_else(|| PolarsError::ComputeError("melt: no value columns".into()))?
1324        .clone();
1325    for p in parts.iter().skip(1) {
1326        out.vstack_mut(p)?;
1327    }
1328    let col_order: Vec<&str> = id_resolved
1329        .iter()
1330        .map(|s| s.as_str())
1331        .chain(["variable", "value"])
1332        .collect();
1333    let out = out.select(col_order)?;
1334    Ok(super::DataFrame::from_polars_with_options(
1335        out,
1336        case_sensitive,
1337    ))
1338}
1339
1340/// Set difference. PySpark exceptAll. Simplified implementation: currently identical to subtract, so right-side multiplicities are not subtracted per occurrence.
1341pub fn except_all(
1342    left: &DataFrame,
1343    right: &DataFrame,
1344    case_sensitive: bool,
1345) -> Result<DataFrame, PolarsError> {
1346    subtract(left, right, case_sensitive)
1347}
1348
1349/// Set intersection. PySpark intersectAll. Simplified implementation: currently identical to intersect, so duplicate rows are not preserved.
1350pub fn intersect_all(
1351    left: &DataFrame,
1352    right: &DataFrame,
1353    case_sensitive: bool,
1354) -> Result<DataFrame, PolarsError> {
1355    intersect(left, right, case_sensitive)
1356}
1357
1358#[cfg(test)]
1359mod tests {
1360    use super::{distinct, drop, dropna, first, head, limit, offset, order_by, union_by_name};
1361    use crate::{DataFrame, SparkSession};
1362    use serde_json::json;
1363
1364    fn test_df() -> DataFrame {
1365        let spark = SparkSession::builder()
1366            .app_name("transform_tests")
1367            .get_or_create();
1368        spark
1369            .create_dataframe(
1370                vec![
1371                    (1i64, 10i64, "a".to_string()),
1372                    (2i64, 20i64, "b".to_string()),
1373                    (3i64, 30i64, "c".to_string()),
1374                ],
1375                vec!["id", "v", "label"],
1376            )
1377            .unwrap()
1378    }
1379
1380    #[test]
1381    fn limit_zero() {
1382        let df = test_df();
1383        let out = limit(&df, 0, false).unwrap();
1384        assert_eq!(out.count().unwrap(), 0);
1385    }
1386
1387    #[test]
1388    fn limit_more_than_rows() {
1389        let df = test_df();
1390        let out = limit(&df, 10, false).unwrap();
1391        assert_eq!(out.count().unwrap(), 3);
1392    }
1393
1394    #[test]
1395    fn distinct_on_empty() {
1396        let spark = SparkSession::builder()
1397            .app_name("transform_tests")
1398            .get_or_create();
1399        let df = spark
1400            .create_dataframe(vec![] as Vec<(i64, i64, String)>, vec!["a", "b", "c"])
1401            .unwrap();
1402        let out = distinct(&df, None, false).unwrap();
1403        assert_eq!(out.count().unwrap(), 0);
1404    }
1405
1406    #[test]
1407    fn first_returns_one_row() {
1408        let df = test_df();
1409        let out = first(&df, false).unwrap();
1410        assert_eq!(out.count().unwrap(), 1);
1411    }
1412
1413    /// Issue #579: first() after orderBy must return first row in sort order, not storage order.
1414    #[test]
1415    fn first_after_order_by_returns_first_in_sort_order() {
1416        use polars::prelude::df;
1417        let spark = SparkSession::builder()
1418            .app_name("transform_tests")
1419            .get_or_create();
1420        let pl = df![
1421            "name" => ["Charlie", "Alice", "Bob"],
1422            "value" => [3i64, 1i64, 2i64],
1423        ]
1424        .unwrap();
1425        let df = spark.create_dataframe_from_polars(pl);
1426        let ordered = order_by(&df, vec!["value"], vec![true], false).unwrap();
1427        let one = first(&ordered, false).unwrap();
1428        let collected = one.collect_inner().unwrap();
1429        let name_series = collected.column("name").unwrap();
1430        let first_name = name_series.str().unwrap().get(0).unwrap();
1431        assert_eq!(
1432            first_name, "Alice",
1433            "first() after orderBy(value) must return row with min value (Alice=1), not first in storage (Charlie)"
1434        );
1435    }
1436
1437    #[test]
1438    fn head_n() {
1439        let df = test_df();
1440        let out = head(&df, 2, false).unwrap();
1441        assert_eq!(out.count().unwrap(), 2);
1442    }
1443
1444    #[test]
1445    fn offset_skip_first() {
1446        let df = test_df();
1447        let out = offset(&df, 1, false).unwrap();
1448        assert_eq!(out.count().unwrap(), 2);
1449    }
1450
1451    #[test]
1452    fn offset_beyond_length_returns_empty() {
1453        let df = test_df();
1454        let out = offset(&df, 10, false).unwrap();
1455        assert_eq!(out.count().unwrap(), 0);
1456    }
1457
1458    #[test]
1459    fn drop_column() {
1460        let df = test_df();
1461        let out = drop(&df, vec!["v"], false).unwrap();
1462        let cols = out.columns().unwrap();
1463        assert!(!cols.contains(&"v".to_string()));
1464        assert_eq!(out.count().unwrap(), 3);
1465    }
1466
1467    /// Issue #603: unionByName with same-named columns of different types (e.g. id Int vs id String) must coerce and succeed.
1468    #[test]
1469    fn union_by_name_coerces_different_column_types() {
1470        use polars::prelude::df;
1471
1472        let spark = SparkSession::builder()
1473            .app_name("transform_tests")
1474            .get_or_create();
1475        let left_pl = df!("id" => &[1i64], "name" => &["a"]).unwrap();
1476        let left = spark.create_dataframe_from_polars(left_pl);
1477        let schema = vec![
1478            ("id".to_string(), "string".to_string()),
1479            ("name".to_string(), "string".to_string()),
1480        ];
1481        let right = spark
1482            .create_dataframe_from_rows(vec![vec![json!("2"), json!("b")]], schema)
1483            .unwrap();
1484        let out = union_by_name(&left, &right, true, false)
1485            .expect("issue #603: union_by_name must coerce id Int64 vs String");
1486        assert_eq!(out.count().unwrap(), 2);
1487    }
1488
1489    #[test]
1490    fn dropna_all_columns() {
1491        let df = test_df();
1492        let out = dropna(&df, None, "any", None, false).unwrap();
1493        assert_eq!(out.count().unwrap(), 3);
1494    }
1495}