robin_sparkless/dataframe/transformations.rs

//! DataFrame transformation operations: filter, select, with_column, order_by,
//! union, union_by_name, distinct, drop, dropna, fillna, limit, with_column_renamed,
//! replace, cross_join, describe, subtract, intersect,
//! sample, sample_by, random_split, first, head, take, tail, is_empty, to_df,
//! to_json, explain, print_schema, select_expr, col_regex, with_columns,
//! with_columns_renamed, offset, transform, freq_items, approx_quantile,
//! crosstab, melt, except_all, intersect_all.

use super::DataFrame;
use crate::functions::SortOrder;
use polars::prelude::{
    col, Expr, IntoLazy, IntoSeries, NamedFrom, PolarsError, Series, UnionArgs, UniqueKeepStrategy,
};

/// Select columns (returns a new DataFrame). Preserves case_sensitive on result.
pub fn select(
    df: &DataFrame,
    cols: Vec<&str>,
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    let selected = df.df.select(cols)?;
    Ok(super::DataFrame::from_polars_with_options(
        selected,
        case_sensitive,
    ))
}

/// Select using column expressions (e.g. F.regexp_extract_all(...).alias("m")). Preserves case_sensitive.
pub fn select_with_exprs(
    df: &DataFrame,
    exprs: Vec<Expr>,
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    let lf = df.df.as_ref().clone().lazy();
    let out_df = lf.select(exprs).collect()?;
    Ok(super::DataFrame::from_polars_with_options(
        out_df,
        case_sensitive,
    ))
}

/// Filter rows using a Polars expression. Preserves case_sensitive on result.
pub fn filter(
    df: &DataFrame,
    condition: Expr,
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    let lf = df.df.as_ref().clone().lazy().filter(condition);
    let out_df = lf.collect()?;
    Ok(super::DataFrame::from_polars_with_options(
        out_df,
        case_sensitive,
    ))
}

/// Add or replace a column. Handles deferred rand/randn by generating a full-length series (PySpark-like).
pub fn with_column(
    df: &DataFrame,
    column_name: &str,
    column: &crate::column::Column,
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    if let Some(deferred) = column.deferred {
        match deferred {
            crate::column::DeferredRandom::Rand(seed) => {
                let mut pl_df = df.df.as_ref().clone();
                let n = pl_df.height();
                let series = crate::udfs::series_rand_n(column_name, n, seed);
                pl_df.with_column(series)?;
                return Ok(super::DataFrame::from_polars_with_options(
                    pl_df,
                    case_sensitive,
                ));
            }
            crate::column::DeferredRandom::Randn(seed) => {
                let mut pl_df = df.df.as_ref().clone();
                let n = pl_df.height();
                let series = crate::udfs::series_randn_n(column_name, n, seed);
                pl_df.with_column(series)?;
                return Ok(super::DataFrame::from_polars_with_options(
                    pl_df,
                    case_sensitive,
                ));
            }
        }
    }
    let lf = df.df.as_ref().clone().lazy();
    let lf_with_col = lf.with_column(column.expr().clone().alias(column_name));
    let pl_df = lf_with_col.collect()?;
    Ok(super::DataFrame::from_polars_with_options(
        pl_df,
        case_sensitive,
    ))
}

/// Order by columns (sort). Preserves case_sensitive on result.
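///
/// # Example
///
/// A sketch (not compiled as a doctest); `df` is assumed to be the three-row
/// `(id, v, label)` frame built the same way as in the tests at the bottom of this file.
///
/// ```ignore
/// // Sort by `v` descending; entries missing from `ascending` default to true.
/// let sorted = order_by(&df, vec!["v"], vec![false], false).unwrap();
/// assert_eq!(sorted.count().unwrap(), 3);
/// ```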
pub fn order_by(
    df: &DataFrame,
    column_names: Vec<&str>,
    ascending: Vec<bool>,
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    use polars::prelude::*;
    let mut asc = ascending;
    while asc.len() < column_names.len() {
        asc.push(true);
    }
    asc.truncate(column_names.len());
    let lf = df.df.as_ref().clone().lazy();
    let exprs: Vec<Expr> = column_names.iter().map(|name| col(*name)).collect();
    let descending: Vec<bool> = asc.iter().map(|&a| !a).collect();
    let sorted = lf.sort_by_exprs(
        exprs,
        SortMultipleOptions::new().with_order_descending_multi(descending),
    );
    let pl_df = sorted.collect()?;
    Ok(super::DataFrame::from_polars_with_options(
        pl_df,
        case_sensitive,
    ))
}

/// Order by sort expressions (asc/desc with nulls_first/last). Preserves case_sensitive on result.
pub fn order_by_exprs(
    df: &DataFrame,
    sort_orders: Vec<SortOrder>,
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    use polars::prelude::*;
    if sort_orders.is_empty() {
        return Ok(super::DataFrame::from_polars_with_options(
            df.df.as_ref().clone(),
            case_sensitive,
        ));
    }
    let exprs: Vec<Expr> = sort_orders.iter().map(|s| s.expr().clone()).collect();
    let descending: Vec<bool> = sort_orders.iter().map(|s| s.descending).collect();
    let nulls_last: Vec<bool> = sort_orders.iter().map(|s| s.nulls_last).collect();
    let opts = SortMultipleOptions::new()
        .with_order_descending_multi(descending)
        .with_nulls_last_multi(nulls_last);
    let lf = df.df.as_ref().clone().lazy();
    let sorted = lf.sort_by_exprs(exprs, opts);
    let pl_df = sorted.collect()?;
    Ok(super::DataFrame::from_polars_with_options(
        pl_df,
        case_sensitive,
    ))
}

/// Union (unionAll): stack another DataFrame vertically. Schemas must match (same columns, same order).
pub fn union(
    left: &DataFrame,
    right: &DataFrame,
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    let lf1 = left.df.as_ref().clone().lazy();
    let lf2 = right.df.as_ref().clone().lazy();
    let out = polars::prelude::concat([lf1, lf2], UnionArgs::default())?.collect()?;
    Ok(super::DataFrame::from_polars_with_options(
        out,
        case_sensitive,
    ))
}

/// Union by name: stack vertically, aligning columns by name. Right columns are reordered to match left; missing columns in right become nulls.
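///
/// # Example
///
/// A sketch (not compiled as a doctest); both inputs are assumed to be built via
/// `create_dataframe` as in the tests below, with the right frame's columns in a
/// different order than the left's.
///
/// ```ignore
/// // Rows of `right` are reordered to `left`'s schema before stacking.
/// let out = union_by_name(&left, &right, false).unwrap();
/// assert_eq!(out.count().unwrap(), 6); // 3 rows from each side
/// ```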
pub fn union_by_name(
    left: &DataFrame,
    right: &DataFrame,
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    use polars::prelude::*;
    let left_names = left.df.get_column_names();
    let right_df = right.df.as_ref();
    let right_names = right_df.get_column_names();
    let resolve_right = |name: &str| -> Option<String> {
        if case_sensitive {
            right_names
                .iter()
                .find(|n| n.as_str() == name)
                .map(|s| s.as_str().to_string())
        } else {
            let name_lower = name.to_lowercase();
            right_names
                .iter()
                .find(|n| n.as_str().to_lowercase() == name_lower)
                .map(|s| s.as_str().to_string())
        }
    };
    let mut exprs: Vec<Expr> = Vec::with_capacity(left_names.len());
    for left_col in left_names.iter() {
        let left_str = left_col.as_str();
        if let Some(r) = resolve_right(left_str) {
            exprs.push(col(r.as_str()));
        } else {
            exprs.push(Expr::Literal(polars::prelude::LiteralValue::Null).alias(left_str));
        }
    }
    let right_aligned = right_df.clone().lazy().select(exprs).collect()?;
    let lf1 = left.df.as_ref().clone().lazy();
    let lf2 = right_aligned.lazy();
    let out = polars::prelude::concat([lf1, lf2], UnionArgs::default())?.collect()?;
    Ok(super::DataFrame::from_polars_with_options(
        out,
        case_sensitive,
    ))
}

/// Distinct: drop duplicate rows (all columns or subset).
pub fn distinct(
    df: &DataFrame,
    subset: Option<Vec<&str>>,
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    let lf = df.df.as_ref().clone().lazy();
    let subset_names: Option<Vec<String>> =
        subset.map(|cols| cols.iter().map(|s| (*s).to_string()).collect());
    let lf = lf.unique(subset_names, UniqueKeepStrategy::First);
    let pl_df = lf.collect()?;
    Ok(super::DataFrame::from_polars_with_options(
        pl_df,
        case_sensitive,
    ))
}

/// Drop one or more columns.
pub fn drop(
    df: &DataFrame,
    columns: Vec<&str>,
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    let resolved: Vec<String> = columns
        .iter()
        .map(|c| df.resolve_column_name(c))
        .collect::<Result<Vec<_>, _>>()?;
    let all_names = df.df.get_column_names();
    let to_keep: Vec<&str> = all_names
        .iter()
        .filter(|n| !resolved.iter().any(|r| r == n.as_str()))
        .map(|n| n.as_str())
        .collect();
    let pl_df = df.df.select(to_keep)?;
    Ok(super::DataFrame::from_polars_with_options(
        pl_df,
        case_sensitive,
    ))
}

/// Drop rows with nulls (all columns or subset).
pub fn dropna(
    df: &DataFrame,
    subset: Option<Vec<&str>>,
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    let lf = df.df.as_ref().clone().lazy();
    let subset_exprs: Option<Vec<Expr>> = match subset {
        Some(cols) => Some(cols.iter().map(|c| col(*c)).collect()),
        None => Some(
            df.df
                .get_column_names()
                .iter()
                .map(|n| col(n.as_str()))
                .collect(),
        ),
    };
    let lf = lf.drop_nulls(subset_exprs);
    let pl_df = lf.collect()?;
    Ok(super::DataFrame::from_polars_with_options(
        pl_df,
        case_sensitive,
    ))
}

/// Fill nulls with a literal expression (applied to all columns). Column-specific fill values are not yet supported.
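///
/// # Example
///
/// A sketch (not compiled as a doctest), using `polars::prelude::lit` to build the
/// fill value; whether the literal fits every column's dtype is the caller's concern.
///
/// ```ignore
/// use polars::prelude::lit;
/// // Replace nulls in every column with the literal 0.
/// let filled = fillna(&df, lit(0), false).unwrap();
/// ```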
pub fn fillna(
    df: &DataFrame,
    value_expr: Expr,
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    use polars::prelude::*;
    let exprs: Vec<Expr> = df
        .df
        .get_column_names()
        .iter()
        .map(|n| col(n.as_str()).fill_null(value_expr.clone()))
        .collect();
    let pl_df = df
        .df
        .as_ref()
        .clone()
        .lazy()
        .with_columns(exprs)
        .collect()?;
    Ok(super::DataFrame::from_polars_with_options(
        pl_df,
        case_sensitive,
    ))
}

/// Limit: return first n rows.
pub fn limit(df: &DataFrame, n: usize, case_sensitive: bool) -> Result<DataFrame, PolarsError> {
    let pl_df = df.df.as_ref().clone().head(Some(n));
    Ok(super::DataFrame::from_polars_with_options(
        pl_df,
        case_sensitive,
    ))
}

/// Rename a column (old_name -> new_name).
pub fn with_column_renamed(
    df: &DataFrame,
    old_name: &str,
    new_name: &str,
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    let resolved = df.resolve_column_name(old_name)?;
    let mut pl_df = df.df.as_ref().clone();
    pl_df.rename(resolved.as_str(), new_name.into())?;
    Ok(super::DataFrame::from_polars_with_options(
        pl_df,
        case_sensitive,
    ))
}

/// Replace values in a column: where column == old_value, use new_value. PySpark replace (single column).
pub fn replace(
    df: &DataFrame,
    column_name: &str,
    old_value: Expr,
    new_value: Expr,
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    use polars::prelude::*;
    let resolved = df.resolve_column_name(column_name)?;
    let repl = when(col(resolved.as_str()).eq(old_value))
        .then(new_value)
        .otherwise(col(resolved.as_str()));
    let pl_df = df
        .df
        .as_ref()
        .clone()
        .lazy()
        .with_column(repl.alias(resolved.as_str()))
        .collect()?;
    Ok(super::DataFrame::from_polars_with_options(
        pl_df,
        case_sensitive,
    ))
}

/// Cross join: cartesian product of two DataFrames. PySpark crossJoin.
pub fn cross_join(
    left: &DataFrame,
    right: &DataFrame,
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    let lf_left = left.df.as_ref().clone().lazy();
    let lf_right = right.df.as_ref().clone().lazy();
    let out = lf_left.cross_join(lf_right, None);
    let pl_df = out.collect()?;
    Ok(super::DataFrame::from_polars_with_options(
        pl_df,
        case_sensitive,
    ))
}

/// Summary statistics (count, mean, std, min, max). PySpark describe.
/// Builds a summary DataFrame with a "summary" column (PySpark name) and one column per numeric input column.
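///
/// # Example
///
/// A sketch (not compiled as a doctest); the result always has five rows, one per
/// statistic in the "summary" column.
///
/// ```ignore
/// let stats = describe(&df, false).unwrap();
/// assert_eq!(stats.count().unwrap(), 5); // count, mean, stddev, min, max
/// ```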
pub fn describe(df: &DataFrame, case_sensitive: bool) -> Result<DataFrame, PolarsError> {
    use polars::prelude::*;
    let pl_df = df.df.as_ref().clone();
    let mut stat_values: Vec<Column> = Vec::new();
    for col in pl_df.get_columns() {
        let s = col.as_materialized_series();
        let dtype = s.dtype();
        if dtype.is_numeric() {
            let name = s.name().clone();
            let count = s.len() as i64 - s.null_count() as i64;
            let mean_f = s.mean().unwrap_or(f64::NAN);
            let std_f = s.std(1).unwrap_or(f64::NAN);
            let s_f64 = s.cast(&DataType::Float64)?;
            let ca = s_f64
                .f64()
                .map_err(|_| PolarsError::ComputeError("cast to f64 failed".into()))?;
            let min_f = ca.min().unwrap_or(f64::NAN);
            let max_f = ca.max().unwrap_or(f64::NAN);
            // PySpark describe/summary returns string type for value columns
            let is_float = matches!(dtype, DataType::Float64 | DataType::Float32);
            let count_s = count.to_string();
            let mean_s = if mean_f.is_nan() {
                "None".to_string()
            } else {
                format!("{:.1}", mean_f)
            };
            let std_s = if std_f.is_nan() {
                "None".to_string()
            } else {
                format!("{:.1}", std_f)
            };
            let min_s = if min_f.is_nan() {
                "None".to_string()
            } else if min_f.fract() == 0.0 && is_float {
                format!("{:.1}", min_f)
            } else if min_f.fract() == 0.0 {
                format!("{:.0}", min_f)
            } else {
                format!("{min_f}")
            };
            let max_s = if max_f.is_nan() {
                "None".to_string()
            } else if max_f.fract() == 0.0 && is_float {
                format!("{:.1}", max_f)
            } else if max_f.fract() == 0.0 {
                format!("{:.0}", max_f)
            } else {
                format!("{max_f}")
            };
            let series = Series::new(
                name,
                [
                    count_s.as_str(),
                    mean_s.as_str(),
                    std_s.as_str(),
                    min_s.as_str(),
                    max_s.as_str(),
                ],
            );
            stat_values.push(series.into());
        }
    }
    if stat_values.is_empty() {
        // No numeric columns: return minimal describe with just the summary column (PySpark name).
        // A zero-length placeholder column alongside the 5-row summary column would make
        // DataFrame::new fail with a height mismatch, so only the summary column is returned.
        let stat_col = Series::new(
            "summary".into(),
            &["count", "mean", "stddev", "min", "max" as &str],
        )
        .into();
        let out_pl = polars::prelude::DataFrame::new(vec![stat_col])?;
        return Ok(super::DataFrame::from_polars_with_options(
            out_pl,
            case_sensitive,
        ));
    }
    let summary_col = Series::new(
        "summary".into(),
        &["count", "mean", "stddev", "min", "max" as &str],
    )
    .into();
    let mut cols: Vec<Column> = vec![summary_col];
    cols.extend(stat_values);
    let out_pl = polars::prelude::DataFrame::new(cols)?;
    Ok(super::DataFrame::from_polars_with_options(
        out_pl,
        case_sensitive,
    ))
}

/// Set difference: rows in left that are not in right (by all columns). PySpark subtract / except.
pub fn subtract(
    left: &DataFrame,
    right: &DataFrame,
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    use polars::prelude::*;
    let left_names = left.df.get_column_names();
    let left_on: Vec<Expr> = left_names.iter().map(|n| col(n.as_str())).collect();
    let right_on: Vec<Expr> = left_names.iter().map(|n| col(n.as_str())).collect();
    let right_lf = right.df.as_ref().clone().lazy();
    let left_lf = left.df.as_ref().clone().lazy();
    let anti = left_lf.join(right_lf, left_on, right_on, JoinArgs::new(JoinType::Anti));
    let pl_df = anti.collect()?;
    Ok(super::DataFrame::from_polars_with_options(
        pl_df,
        case_sensitive,
    ))
}

/// Set intersection: rows that appear in both DataFrames (by all columns). PySpark intersect.
pub fn intersect(
    left: &DataFrame,
    right: &DataFrame,
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    use polars::prelude::*;
    let left_names = left.df.get_column_names();
    let left_on: Vec<Expr> = left_names.iter().map(|n| col(n.as_str())).collect();
    let right_on: Vec<Expr> = left_names.iter().map(|n| col(n.as_str())).collect();
    let left_lf = left.df.as_ref().clone().lazy();
    let right_lf = right.df.as_ref().clone().lazy();
    let semi = left_lf.join(right_lf, left_on, right_on, JoinArgs::new(JoinType::Semi));
    let pl_df = semi.collect()?;
    Ok(super::DataFrame::from_polars_with_options(
        pl_df,
        case_sensitive,
    ))
}

// ---------- Batch A: sample, first/head/take/tail, is_empty, to_df ----------

/// Sample a fraction of rows. PySpark sample(withReplacement, fraction, seed).
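///
/// # Example
///
/// A sketch (not compiled as a doctest); the sampled row count is
/// `round(height * fraction)`, capped at the input height.
///
/// ```ignore
/// // Roughly half of the rows, reproducibly, without replacement.
/// let half = sample(&df, false, 0.5, Some(42), false).unwrap();
/// ```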
pub fn sample(
    df: &DataFrame,
    with_replacement: bool,
    fraction: f64,
    seed: Option<u64>,
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    use polars::prelude::Series;
    let n = df.df.height();
    if n == 0 {
        return Ok(super::DataFrame::from_polars_with_options(
            df.df.as_ref().clone(),
            case_sensitive,
        ));
    }
    let take_n = (n as f64 * fraction).round() as usize;
    let take_n = take_n.min(n).max(0);
    if take_n == 0 {
        return Ok(super::DataFrame::from_polars_with_options(
            df.df.as_ref().clone().head(Some(0)),
            case_sensitive,
        ));
    }
    let idx_series = Series::new("idx".into(), (0..n).map(|i| i as u32).collect::<Vec<_>>());
    let sampled_idx = idx_series.sample_n(take_n, with_replacement, true, seed)?;
    let idx_ca = sampled_idx
        .u32()
        .map_err(|_| PolarsError::ComputeError("sample: expected u32 indices".into()))?;
    let pl_df = df.df.as_ref().take(idx_ca)?;
    Ok(super::DataFrame::from_polars_with_options(
        pl_df,
        case_sensitive,
    ))
}

/// Split DataFrame by weights (random split). PySpark randomSplit(weights, seed).
/// Returns one DataFrame per weight; weights are normalized to fractions.
/// Each row is assigned to exactly one split (disjoint partitions).
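///
/// # Example
///
/// A sketch (not compiled as a doctest); the splits are disjoint and together cover
/// every input row.
///
/// ```ignore
/// let parts = random_split(&df, &[0.75, 0.25], Some(7), false).unwrap();
/// assert_eq!(parts.len(), 2); // one DataFrame per weight; row counts sum to the input height
/// ```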
pub fn random_split(
    df: &DataFrame,
    weights: &[f64],
    seed: Option<u64>,
    case_sensitive: bool,
) -> Result<Vec<DataFrame>, PolarsError> {
    let total: f64 = weights.iter().sum();
    if total <= 0.0 || weights.is_empty() {
        return Ok(Vec::new());
    }
    let n = df.df.height();
    if n == 0 {
        return Ok(weights.iter().map(|_| super::DataFrame::empty()).collect());
    }
    // Normalize weights to cumulative fractions: e.g. [0.25, 0.25, 0.5] -> [0.25, 0.5, 1.0]
    let mut cum = Vec::with_capacity(weights.len());
    let mut acc = 0.0_f64;
    for w in weights {
        acc += w / total;
        cum.push(acc);
    }
    // Assign each row index to one bucket using a single seeded RNG (disjoint split).
    use polars::prelude::Series;
    use rand::Rng;
    use rand::SeedableRng;
    let mut rng = rand::rngs::StdRng::seed_from_u64(seed.unwrap_or(0));
    let mut bucket_indices: Vec<Vec<u32>> = (0..weights.len()).map(|_| Vec::new()).collect();
    for i in 0..n {
        let r: f64 = rng.gen();
        let bucket = cum
            .iter()
            .position(|&c| r < c)
            .unwrap_or(weights.len().saturating_sub(1));
        bucket_indices[bucket].push(i as u32);
    }
    let pl = df.df.as_ref();
    let mut out = Vec::with_capacity(weights.len());
    for indices in bucket_indices {
        if indices.is_empty() {
            out.push(super::DataFrame::from_polars_with_options(
                pl.clone().head(Some(0)),
                case_sensitive,
            ));
        } else {
            let idx_series = Series::new("idx".into(), indices);
            let idx_ca = idx_series.u32().map_err(|_| {
                PolarsError::ComputeError("random_split: expected u32 indices".into())
            })?;
            let taken = pl.take(idx_ca)?;
            out.push(super::DataFrame::from_polars_with_options(
                taken,
                case_sensitive,
            ));
        }
    }
    Ok(out)
}

/// Stratified sample by column value. PySpark sampleBy(col, fractions, seed).
/// fractions: list of (value as Expr literal, fraction to sample for that value).
pub fn sample_by(
    df: &DataFrame,
    col_name: &str,
    fractions: &[(Expr, f64)],
    seed: Option<u64>,
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    use polars::prelude::*;
    if fractions.is_empty() {
        return Ok(super::DataFrame::from_polars_with_options(
            df.df.as_ref().clone().head(Some(0)),
            case_sensitive,
        ));
    }
    let resolved = df.resolve_column_name(col_name)?;
    let mut parts = Vec::with_capacity(fractions.len());
    for (value_expr, frac) in fractions {
        let cond = col(resolved.as_str()).eq(value_expr.clone());
        let filtered = df.df.as_ref().clone().lazy().filter(cond).collect()?;
        if filtered.height() == 0 {
            parts.push(filtered.head(Some(0)));
            continue;
        }
        let sampled = sample(
            &super::DataFrame::from_polars_with_options(filtered, case_sensitive),
            false,
            *frac,
            seed,
            case_sensitive,
        )?;
        parts.push(sampled.df.as_ref().clone());
    }
    let mut out = parts
        .first()
        .ok_or_else(|| PolarsError::ComputeError("sample_by: no parts".into()))?
        .clone();
    for p in parts.iter().skip(1) {
        out.vstack_mut(p)?;
    }
    Ok(super::DataFrame::from_polars_with_options(
        out,
        case_sensitive,
    ))
}

/// First row as a DataFrame (one row). PySpark first().
pub fn first(df: &DataFrame, case_sensitive: bool) -> Result<DataFrame, PolarsError> {
    let pl_df = df.df.as_ref().clone().head(Some(1));
    Ok(super::DataFrame::from_polars_with_options(
        pl_df,
        case_sensitive,
    ))
}

/// First n rows. PySpark head(n). Same as limit.
pub fn head(df: &DataFrame, n: usize, case_sensitive: bool) -> Result<DataFrame, PolarsError> {
    limit(df, n, case_sensitive)
}

/// Take first n rows (alias for limit). PySpark take(n).
pub fn take(df: &DataFrame, n: usize, case_sensitive: bool) -> Result<DataFrame, PolarsError> {
    limit(df, n, case_sensitive)
}

/// Last n rows. PySpark tail(n).
pub fn tail(df: &DataFrame, n: usize, case_sensitive: bool) -> Result<DataFrame, PolarsError> {
    let total = df.df.height();
    let skip = total.saturating_sub(n);
    let pl_df = df.df.as_ref().clone().slice(skip as i64, n);
    Ok(super::DataFrame::from_polars_with_options(
        pl_df,
        case_sensitive,
    ))
}

/// Whether the DataFrame has zero rows. PySpark isEmpty.
pub fn is_empty(df: &DataFrame) -> bool {
    df.df.height() == 0
}

/// Rename columns. PySpark toDF(*colNames). Names must match length of columns.
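///
/// # Example
///
/// A sketch (not compiled as a doctest); the slice must contain exactly one name per
/// existing column.
///
/// ```ignore
/// let renamed = to_df(&df, &["a", "b", "c"], false).unwrap();
/// assert!(renamed.columns().unwrap().contains(&"a".to_string()));
/// ```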
pub fn to_df(
    df: &DataFrame,
    names: &[&str],
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    let cols = df.df.get_column_names();
    if names.len() != cols.len() {
        return Err(PolarsError::ComputeError(
            format!(
                "toDF: expected {} column names, got {}",
                cols.len(),
                names.len()
            )
            .into(),
        ));
    }
    let mut pl_df = df.df.as_ref().clone();
    for (old, new) in cols.iter().zip(names.iter()) {
        pl_df.rename(old.as_str(), (*new).into())?;
    }
    Ok(super::DataFrame::from_polars_with_options(
        pl_df,
        case_sensitive,
    ))
}

// ---------- Batch B: toJSON, explain, printSchema ----------

fn any_value_to_serde_value(av: &polars::prelude::AnyValue) -> serde_json::Value {
    use polars::prelude::AnyValue;
    use serde_json::Number;
    match av {
        AnyValue::Null => serde_json::Value::Null,
        AnyValue::Boolean(v) => serde_json::Value::Bool(*v),
        AnyValue::Int8(v) => serde_json::Value::Number(Number::from(*v as i64)),
        AnyValue::Int32(v) => serde_json::Value::Number(Number::from(*v)),
        AnyValue::Int64(v) => serde_json::Value::Number(Number::from(*v)),
        AnyValue::UInt32(v) => serde_json::Value::Number(Number::from(*v)),
        AnyValue::Float64(v) => Number::from_f64(*v)
            .map(serde_json::Value::Number)
            .unwrap_or(serde_json::Value::Null),
        AnyValue::String(v) => serde_json::Value::String(v.to_string()),
        _ => serde_json::Value::String(format!("{av:?}")),
    }
}

/// Collect rows as JSON strings (one JSON object per row). PySpark toJSON.
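///
/// # Example
///
/// A sketch (not compiled as a doctest); one JSON string per row, keyed by column name.
///
/// ```ignore
/// let rows = to_json(&df).unwrap();
/// assert_eq!(rows.len(), 3);
/// assert!(rows[0].contains("\"id\""));
/// ```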
pub fn to_json(df: &DataFrame) -> Result<Vec<String>, PolarsError> {
    use polars::prelude::*;
    let pl = df.df.as_ref();
    let names = pl.get_column_names();
    let mut out = Vec::with_capacity(pl.height());
    for r in 0..pl.height() {
        let mut row = serde_json::Map::new();
        for (i, name) in names.iter().enumerate() {
            let col = pl
                .get_columns()
                .get(i)
                .ok_or_else(|| PolarsError::ComputeError("to_json: column index".into()))?;
            let series = col.as_materialized_series();
            let av = series
                .get(r)
                .map_err(|e| PolarsError::ComputeError(e.to_string().into()))?;
            row.insert(name.to_string(), any_value_to_serde_value(&av));
        }
        out.push(
            serde_json::to_string(&row)
                .map_err(|e| PolarsError::ComputeError(e.to_string().into()))?,
        );
    }
    Ok(out)
}

/// Return a string describing the execution plan. PySpark explain.
pub fn explain(_df: &DataFrame) -> String {
    "DataFrame (eager Polars backend)".to_string()
}

/// Return schema as a tree string. PySpark printSchema (we return string; caller can print).
pub fn print_schema(df: &DataFrame) -> Result<String, PolarsError> {
    let schema = df.schema()?;
    let mut s = "root\n".to_string();
    for f in schema.fields() {
        let dt = match &f.data_type {
            crate::schema::DataType::String => "string",
            crate::schema::DataType::Integer => "int",
            crate::schema::DataType::Long => "bigint",
            crate::schema::DataType::Double => "double",
            crate::schema::DataType::Boolean => "boolean",
            crate::schema::DataType::Date => "date",
            crate::schema::DataType::Timestamp => "timestamp",
            _ => "string",
        };
        s.push_str(&format!(" |-- {}: {}\n", f.name, dt));
    }
    Ok(s)
}

// ---------- Batch D: selectExpr, colRegex, withColumns, withColumnsRenamed, na ----------

/// Select by expression strings. Minimal support: plain column names, with an optional " as " alias that is parsed but currently ignored. PySpark selectExpr.
pub fn select_expr(
    df: &DataFrame,
    exprs: &[String],
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    let mut cols = Vec::new();
    for e in exprs {
        let e = e.trim();
        if let Some((left, right)) = e.split_once(" as ") {
            let col_name = left.trim();
            let _alias = right.trim();
            cols.push(df.resolve_column_name(col_name)?);
        } else {
            cols.push(df.resolve_column_name(e)?);
        }
    }
    let refs: Vec<&str> = cols.iter().map(|s| s.as_str()).collect();
    select(df, refs, case_sensitive)
}

/// Select columns whose names match the regex pattern. PySpark colRegex.
pub fn col_regex(
    df: &DataFrame,
    pattern: &str,
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    let re = regex::Regex::new(pattern).map_err(|e| {
        PolarsError::ComputeError(format!("colRegex: invalid pattern {pattern:?}: {e}").into())
    })?;
    let names = df.df.get_column_names();
    let matched: Vec<&str> = names
        .iter()
        .filter(|n| re.is_match(n.as_str()))
        .map(|s| s.as_str())
        .collect();
    if matched.is_empty() {
        return Err(PolarsError::ComputeError(
            format!("colRegex: no columns matched pattern {pattern:?}").into(),
        ));
    }
    select(df, matched, case_sensitive)
}

/// Add or replace multiple columns. PySpark withColumns. Uses Column so deferred rand/randn get per-row values.
pub fn with_columns(
    df: &DataFrame,
    exprs: &[(String, crate::column::Column)],
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    let mut current =
        super::DataFrame::from_polars_with_options(df.df.as_ref().clone(), case_sensitive);
    for (name, col) in exprs {
        current = with_column(&current, name, col, case_sensitive)?;
    }
    Ok(current)
}

/// Rename multiple columns. PySpark withColumnsRenamed.
pub fn with_columns_renamed(
    df: &DataFrame,
    renames: &[(String, String)],
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    let mut out = df.df.as_ref().clone();
    for (old_name, new_name) in renames {
        let resolved = df.resolve_column_name(old_name)?;
        out.rename(resolved.as_str(), new_name.as_str().into())?;
    }
    Ok(super::DataFrame::from_polars_with_options(
        out,
        case_sensitive,
    ))
}

/// NA sub-API builder. PySpark df.na().fill(...) / .drop(...).
pub struct DataFrameNa<'a> {
    pub(crate) df: &'a DataFrame,
}

impl<'a> DataFrameNa<'a> {
    /// Fill nulls with the given value. PySpark na.fill(value).
    pub fn fill(&self, value: Expr) -> Result<DataFrame, PolarsError> {
        fillna(self.df, value, self.df.case_sensitive)
    }

    /// Drop rows with nulls in the given columns (or all). PySpark na.drop(subset).
    pub fn drop(&self, subset: Option<Vec<&str>>) -> Result<DataFrame, PolarsError> {
        dropna(self.df, subset, self.df.case_sensitive)
    }
}

// ---------- Batch E: offset, transform, freqItems, approxQuantile, crosstab, melt, exceptAll, intersectAll ----------

/// Skip first n rows. PySpark offset(n).
pub fn offset(df: &DataFrame, n: usize, case_sensitive: bool) -> Result<DataFrame, PolarsError> {
    let total = df.df.height();
    let len = total.saturating_sub(n);
    let pl_df = df.df.as_ref().clone().slice(n as i64, len);
    Ok(super::DataFrame::from_polars_with_options(
        pl_df,
        case_sensitive,
    ))
}

/// Transform DataFrame by a function. PySpark transform(func).
pub fn transform<F>(df: &DataFrame, f: F) -> Result<DataFrame, PolarsError>
where
    F: FnOnce(DataFrame) -> Result<DataFrame, PolarsError>,
{
    let df_out = f(df.clone())?;
    Ok(df_out)
}

/// Frequent items. PySpark freqItems. Returns one row with columns {col}_freqItems (array of values with frequency >= support).
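///
/// # Example
///
/// A sketch (not compiled as a doctest); the output is a single row with one list
/// column per requested input column.
///
/// ```ignore
/// // Values of `label` appearing in at least 30% of rows.
/// let frequent = freq_items(&df, &["label"], 0.3, false).unwrap();
/// assert_eq!(frequent.count().unwrap(), 1);
/// ```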
pub fn freq_items(
    df: &DataFrame,
    columns: &[&str],
    support: f64,
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    use polars::prelude::SeriesMethods;
    if columns.is_empty() {
        return Ok(super::DataFrame::from_polars_with_options(
            df.df.as_ref().clone().head(Some(0)),
            case_sensitive,
        ));
    }
    let support = support.clamp(1e-4, 1.0);
    let pl_df = df.df.as_ref();
    let n_total = pl_df.height() as f64;
    if n_total == 0.0 {
        let mut out = Vec::with_capacity(columns.len());
        for col_name in columns {
            let resolved = df.resolve_column_name(col_name)?;
            let s = pl_df
                .column(resolved.as_str())?
                .as_series()
                .ok_or_else(|| PolarsError::ComputeError("column not a series".into()))?
                .clone();
            let empty_sub = s.head(Some(0));
            let list_chunked = polars::prelude::ListChunked::from_iter([empty_sub].into_iter())
                .with_name(format!("{resolved}_freqItems").into());
            out.push(list_chunked.into_series().into());
        }
        return Ok(super::DataFrame::from_polars_with_options(
            polars::prelude::DataFrame::new(out)?,
            case_sensitive,
        ));
    }
    let mut out_series = Vec::with_capacity(columns.len());
    for col_name in columns {
        let resolved = df.resolve_column_name(col_name)?;
        let s = pl_df
            .column(resolved.as_str())?
            .as_series()
            .ok_or_else(|| PolarsError::ComputeError("column not a series".into()))?
            .clone();
        let vc = s.value_counts(false, false, "counts".into(), false)?;
        let count_col = vc
            .column("counts")
            .map_err(|_| PolarsError::ComputeError("value_counts missing counts column".into()))?;
        let counts = count_col
            .u32()
            .map_err(|_| PolarsError::ComputeError("freq_items: counts column not u32".into()))?;
        let value_col_name = s.name();
        let values_col = vc
            .column(value_col_name.as_str())
            .map_err(|_| PolarsError::ComputeError("value_counts missing value column".into()))?;
        let threshold = (support * n_total).ceil() as u32;
        let indices: Vec<u32> = counts
            .into_iter()
            .enumerate()
            .filter_map(|(i, c)| {
                if c? >= threshold {
                    Some(i as u32)
                } else {
                    None
                }
            })
            .collect();
        let idx_series = Series::new("idx".into(), indices);
        let idx_ca = idx_series
            .u32()
            .map_err(|_| PolarsError::ComputeError("freq_items: index series not u32".into()))?;
        let values_series = values_col
            .as_series()
            .ok_or_else(|| PolarsError::ComputeError("value column not a series".into()))?;
        let filtered = values_series.take(idx_ca)?;
        let list_chunked = polars::prelude::ListChunked::from_iter([filtered].into_iter())
            .with_name(format!("{resolved}_freqItems").into());
        let list_row = list_chunked.into_series();
        out_series.push(list_row.into());
    }
    let out_df = polars::prelude::DataFrame::new(out_series)?;
    Ok(super::DataFrame::from_polars_with_options(
        out_df,
        case_sensitive,
    ))
}

/// Approximate quantiles. PySpark approxQuantile. Returns one column "quantile" with one row per probability.
pub fn approx_quantile(
    df: &DataFrame,
    column: &str,
    probabilities: &[f64],
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    use polars::prelude::{ChunkQuantile, QuantileMethod};
    if probabilities.is_empty() {
        return Ok(super::DataFrame::from_polars_with_options(
            polars::prelude::DataFrame::new(vec![Series::new(
                "quantile".into(),
                Vec::<f64>::new(),
            )
            .into()])?,
            case_sensitive,
        ));
    }
    let resolved = df.resolve_column_name(column)?;
    let s = df
        .df
        .as_ref()
        .column(resolved.as_str())?
        .as_series()
        .ok_or_else(|| PolarsError::ComputeError("approx_quantile: column not a series".into()))?
        .clone();
    let s_f64 = s.cast(&polars::prelude::DataType::Float64)?;
    let ca = s_f64
        .f64()
        .map_err(|_| PolarsError::ComputeError("approx_quantile: need numeric column".into()))?;
    let mut quantiles = Vec::with_capacity(probabilities.len());
    for &p in probabilities {
        let q = ca.quantile(p, QuantileMethod::Linear)?;
        quantiles.push(q.unwrap_or(f64::NAN));
    }
    let out_df =
        polars::prelude::DataFrame::new(vec![Series::new("quantile".into(), quantiles).into()])?;
    Ok(super::DataFrame::from_polars_with_options(
        out_df,
        case_sensitive,
    ))
}

/// Cross-tabulation. PySpark crosstab. Returns long format (col1, col2, count); for wide format use pivot on the result.
pub fn crosstab(
    df: &DataFrame,
    col1: &str,
    col2: &str,
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    use polars::prelude::*;
    let c1 = df.resolve_column_name(col1)?;
    let c2 = df.resolve_column_name(col2)?;
    let pl_df = df.df.as_ref();
    let grouped = pl_df
        .clone()
        .lazy()
        .group_by([col(c1.as_str()), col(c2.as_str())])
        .agg([len().alias("count")])
        .collect()?;
    Ok(super::DataFrame::from_polars_with_options(
        grouped,
        case_sensitive,
    ))
}

/// Unpivot (melt). PySpark melt. Long format with id_vars kept, plus "variable" and "value" columns.
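///
/// # Example
///
/// A sketch (not compiled as a doctest); each value column contributes `height` rows
/// to the long output.
///
/// ```ignore
/// // Keep `id`, unpivot `v` into ("variable", "value") pairs.
/// let long = melt(&df, &["id"], &["v"], false).unwrap();
/// assert_eq!(long.count().unwrap(), 3);
/// ```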
pub fn melt(
    df: &DataFrame,
    id_vars: &[&str],
    value_vars: &[&str],
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    use polars::prelude::*;
    let pl_df = df.df.as_ref();
    if value_vars.is_empty() {
        return Ok(super::DataFrame::from_polars_with_options(
            pl_df.head(Some(0)),
            case_sensitive,
        ));
    }
    let id_resolved: Vec<String> = id_vars
        .iter()
        .map(|s| df.resolve_column_name(s).map(|r| r.to_string()))
        .collect::<Result<Vec<_>, _>>()?;
    let value_resolved: Vec<String> = value_vars
        .iter()
        .map(|s| df.resolve_column_name(s).map(|r| r.to_string()))
        .collect::<Result<Vec<_>, _>>()?;
    let mut parts = Vec::with_capacity(value_vars.len());
    for vname in &value_resolved {
        let select_cols: Vec<&str> = id_resolved
            .iter()
            .map(|s| s.as_str())
            .chain([vname.as_str()])
            .collect();
        let mut part = pl_df.select(select_cols)?;
        let var_series = Series::new("variable".into(), vec![vname.as_str(); part.height()]);
        part.with_column(var_series)?;
        part.rename(vname.as_str(), "value".into())?;
        parts.push(part);
    }
    let mut out = parts
        .first()
        .ok_or_else(|| PolarsError::ComputeError("melt: no value columns".into()))?
        .clone();
    for p in parts.iter().skip(1) {
        out.vstack_mut(p)?;
    }
    let col_order: Vec<&str> = id_resolved
        .iter()
        .map(|s| s.as_str())
        .chain(["variable", "value"])
        .collect();
    let out = out.select(col_order)?;
    Ok(super::DataFrame::from_polars_with_options(
        out,
        case_sensitive,
    ))
}

/// Set difference keeping duplicates. PySpark exceptAll. Simple impl: same as subtract.
pub fn except_all(
    left: &DataFrame,
    right: &DataFrame,
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    subtract(left, right, case_sensitive)
}

/// Set intersection keeping duplicates. PySpark intersectAll. Simple impl: same as intersect.
pub fn intersect_all(
    left: &DataFrame,
    right: &DataFrame,
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    intersect(left, right, case_sensitive)
}

#[cfg(test)]
mod tests {
    use super::{distinct, drop, dropna, first, head, limit, offset};
    use crate::{DataFrame, SparkSession};

    fn test_df() -> DataFrame {
        let spark = SparkSession::builder()
            .app_name("transform_tests")
            .get_or_create();
        spark
            .create_dataframe(
                vec![
                    (1i64, 10i64, "a".to_string()),
                    (2i64, 20i64, "b".to_string()),
                    (3i64, 30i64, "c".to_string()),
                ],
                vec!["id", "v", "label"],
            )
            .unwrap()
    }

    #[test]
    fn limit_zero() {
        let df = test_df();
        let out = limit(&df, 0, false).unwrap();
        assert_eq!(out.count().unwrap(), 0);
    }

    #[test]
    fn limit_more_than_rows() {
        let df = test_df();
        let out = limit(&df, 10, false).unwrap();
        assert_eq!(out.count().unwrap(), 3);
    }

    #[test]
    fn distinct_on_empty() {
        let spark = SparkSession::builder()
            .app_name("transform_tests")
            .get_or_create();
        let df = spark
            .create_dataframe(vec![] as Vec<(i64, i64, String)>, vec!["a", "b", "c"])
            .unwrap();
        let out = distinct(&df, None, false).unwrap();
        assert_eq!(out.count().unwrap(), 0);
    }

    #[test]
    fn first_returns_one_row() {
        let df = test_df();
        let out = first(&df, false).unwrap();
        assert_eq!(out.count().unwrap(), 1);
    }

    #[test]
    fn head_n() {
        let df = test_df();
        let out = head(&df, 2, false).unwrap();
        assert_eq!(out.count().unwrap(), 2);
    }

    #[test]
    fn offset_skip_first() {
        let df = test_df();
        let out = offset(&df, 1, false).unwrap();
        assert_eq!(out.count().unwrap(), 2);
    }

    #[test]
    fn offset_beyond_length_returns_empty() {
        let df = test_df();
        let out = offset(&df, 10, false).unwrap();
        assert_eq!(out.count().unwrap(), 0);
    }

    #[test]
    fn drop_column() {
        let df = test_df();
        let out = drop(&df, vec!["v"], false).unwrap();
        let cols = out.columns().unwrap();
        assert!(!cols.contains(&"v".to_string()));
        assert_eq!(out.count().unwrap(), 3);
    }

    #[test]
    fn dropna_all_columns() {
        let df = test_df();
        let out = dropna(&df, None, false).unwrap();
        assert_eq!(out.count().unwrap(), 3);
    }
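
    // Additional sketches for functions defined above that the original tests do not
    // cover; they rely only on the same test_df()/count() helpers used elsewhere in
    // this module and state assumptions about expected behavior, not a spec.
    #[test]
    fn tail_returns_last_rows() {
        let df = test_df();
        let out = super::tail(&df, 2, false).unwrap();
        assert_eq!(out.count().unwrap(), 2);
    }

    #[test]
    fn subtract_self_is_empty() {
        let df = test_df();
        let out = super::subtract(&df, &df, false).unwrap();
        assert_eq!(out.count().unwrap(), 0);
    }

    #[test]
    fn intersect_self_keeps_all_rows() {
        let df = test_df();
        let out = super::intersect(&df, &df, false).unwrap();
        assert_eq!(out.count().unwrap(), 3);
    }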
}