
robin_sparkless/dataframe/transformations.rs

//! DataFrame transformation operations: filter, select, with_column(s), order_by,
//! union, union_by_name, distinct, drop, dropna, fillna, limit, with_column_renamed,
//! with_columns_renamed, replace, cross_join, describe, subtract, intersect,
//! sample, sample_by, random_split, first, head, take, tail, is_empty, to_df,
//! to_json, explain, print_schema, select_expr, col_regex, offset, transform,
//! freq_items, approx_quantile, crosstab, melt, except_all, and intersect_all.

use super::DataFrame;
use crate::functions::SortOrder;
use polars::prelude::{
    col, Expr, IntoLazy, IntoSeries, NamedFrom, PolarsError, Series, UnionArgs, UniqueKeepStrategy,
};

/// Select columns (returns a new DataFrame). Preserves case_sensitive on result.
pub fn select(
    df: &DataFrame,
    cols: Vec<&str>,
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    let selected = df.df.select(cols)?;
    Ok(super::DataFrame::from_polars_with_options(
        selected,
        case_sensitive,
    ))
}

/// Filter rows using a Polars expression. Preserves case_sensitive on result.
pub fn filter(
    df: &DataFrame,
    condition: Expr,
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    let lf = df.df.as_ref().clone().lazy().filter(condition);
    let out_df = lf.collect()?;
    Ok(super::DataFrame::from_polars_with_options(
        out_df,
        case_sensitive,
    ))
}
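
// Usage sketch (illustrative addition, not part of the original module): chains
// `filter` and `select` using the same SparkSession helpers as the `tests` module
// at the bottom of this file. `polars::prelude::{col, lit}` are assumed to be
// available through the crate's polars dependency.
#[cfg(test)]
mod filter_select_example {
    use super::{filter, select};
    use crate::SparkSession;
    use polars::prelude::{col, lit};

    #[test]
    fn filter_then_select_sketch() {
        let spark = SparkSession::builder().app_name("examples").get_or_create();
        let df = spark
            .create_dataframe(
                vec![(1i64, 10i64, "a".to_string()), (2i64, 20i64, "b".to_string())],
                vec!["id", "v", "label"],
            )
            .unwrap();
        // Keep rows where v > 15, then project two columns.
        let filtered = filter(&df, col("v").gt(lit(15i64)), false).unwrap();
        let projected = select(&filtered, vec!["id", "label"], false).unwrap();
        assert_eq!(projected.count().unwrap(), 1);
    }
}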

/// Add or replace a column. Handles deferred rand/randn by generating a full-length series (PySpark-like).
pub fn with_column(
    df: &DataFrame,
    column_name: &str,
    column: &crate::column::Column,
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    if let Some(deferred) = column.deferred {
        match deferred {
            crate::column::DeferredRandom::Rand(seed) => {
                let mut pl_df = df.df.as_ref().clone();
                let n = pl_df.height();
                let series = crate::udfs::series_rand_n(column_name, n, seed);
                pl_df.with_column(series)?;
                return Ok(super::DataFrame::from_polars_with_options(
                    pl_df,
                    case_sensitive,
                ));
            }
            crate::column::DeferredRandom::Randn(seed) => {
                let mut pl_df = df.df.as_ref().clone();
                let n = pl_df.height();
                let series = crate::udfs::series_randn_n(column_name, n, seed);
                pl_df.with_column(series)?;
                return Ok(super::DataFrame::from_polars_with_options(
                    pl_df,
                    case_sensitive,
                ));
            }
        }
    }
    let lf = df.df.as_ref().clone().lazy();
    let lf_with_col = lf.with_column(column.expr().clone().alias(column_name));
    let pl_df = lf_with_col.collect()?;
    Ok(super::DataFrame::from_polars_with_options(
        pl_df,
        case_sensitive,
    ))
}

/// Order by columns (sort). Preserves case_sensitive on result.
pub fn order_by(
    df: &DataFrame,
    column_names: Vec<&str>,
    ascending: Vec<bool>,
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    use polars::prelude::*;
    let mut asc = ascending;
    while asc.len() < column_names.len() {
        asc.push(true);
    }
    asc.truncate(column_names.len());
    let lf = df.df.as_ref().clone().lazy();
    let exprs: Vec<Expr> = column_names.iter().map(|name| col(*name)).collect();
    let descending: Vec<bool> = asc.iter().map(|&a| !a).collect();
    let sorted = lf.sort_by_exprs(
        exprs,
        SortMultipleOptions::new().with_order_descending_multi(descending),
    );
    let pl_df = sorted.collect()?;
    Ok(super::DataFrame::from_polars_with_options(
        pl_df,
        case_sensitive,
    ))
}

/// Order by sort expressions (asc/desc with nulls_first/last). Preserves case_sensitive on result.
pub fn order_by_exprs(
    df: &DataFrame,
    sort_orders: Vec<SortOrder>,
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    use polars::prelude::*;
    if sort_orders.is_empty() {
        return Ok(super::DataFrame::from_polars_with_options(
            df.df.as_ref().clone(),
            case_sensitive,
        ));
    }
    let exprs: Vec<Expr> = sort_orders.iter().map(|s| s.expr().clone()).collect();
    let descending: Vec<bool> = sort_orders.iter().map(|s| s.descending).collect();
    let nulls_last: Vec<bool> = sort_orders.iter().map(|s| s.nulls_last).collect();
    let opts = SortMultipleOptions::new()
        .with_order_descending_multi(descending)
        .with_nulls_last_multi(nulls_last);
    let lf = df.df.as_ref().clone().lazy();
    let sorted = lf.sort_by_exprs(exprs, opts);
    let pl_df = sorted.collect()?;
    Ok(super::DataFrame::from_polars_with_options(
        pl_df,
        case_sensitive,
    ))
}

/// Union (unionAll): stack another DataFrame vertically. Schemas must match (same columns, same order).
pub fn union(
    left: &DataFrame,
    right: &DataFrame,
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    let lf1 = left.df.as_ref().clone().lazy();
    let lf2 = right.df.as_ref().clone().lazy();
    let out = polars::prelude::concat([lf1, lf2], UnionArgs::default())?.collect()?;
    Ok(super::DataFrame::from_polars_with_options(
        out,
        case_sensitive,
    ))
}

/// Union by name: stack vertically, aligning columns by name. Right columns are reordered to match left; missing columns in right become nulls.
pub fn union_by_name(
    left: &DataFrame,
    right: &DataFrame,
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    use polars::prelude::*;
    let left_names = left.df.get_column_names();
    let right_df = right.df.as_ref();
    let right_names = right_df.get_column_names();
    let resolve_right = |name: &str| -> Option<String> {
        if case_sensitive {
            right_names
                .iter()
                .find(|n| n.as_str() == name)
                .map(|s| s.as_str().to_string())
        } else {
            let name_lower = name.to_lowercase();
            right_names
                .iter()
                .find(|n| n.as_str().to_lowercase() == name_lower)
                .map(|s| s.as_str().to_string())
        }
    };
    let mut exprs: Vec<Expr> = Vec::with_capacity(left_names.len());
    for left_col in left_names.iter() {
        let left_str = left_col.as_str();
        if let Some(r) = resolve_right(left_str) {
            exprs.push(col(r.as_str()));
        } else {
            exprs.push(Expr::Literal(polars::prelude::LiteralValue::Null).alias(left_str));
        }
    }
    let right_aligned = right_df.clone().lazy().select(exprs).collect()?;
    let lf1 = left.df.as_ref().clone().lazy();
    let lf2 = right_aligned.lazy();
    let out = polars::prelude::concat([lf1, lf2], UnionArgs::default())?.collect()?;
    Ok(super::DataFrame::from_polars_with_options(
        out,
        case_sensitive,
    ))
}
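
// Usage sketch (illustrative addition): `union_by_name` aligns the right-hand
// columns by name, so the right DataFrame may list its columns in a different
// order. Uses the SparkSession helpers from the `tests` module below.
#[cfg(test)]
mod union_by_name_example {
    use super::union_by_name;
    use crate::SparkSession;

    #[test]
    fn aligns_columns_by_name_sketch() {
        let spark = SparkSession::builder().app_name("examples").get_or_create();
        let left = spark
            .create_dataframe(vec![(1i64, 10i64, "a".to_string())], vec!["id", "v", "label"])
            .unwrap();
        // Same column types, different column order: v, id, label.
        let right = spark
            .create_dataframe(vec![(20i64, 2i64, "b".to_string())], vec!["v", "id", "label"])
            .unwrap();
        let out = union_by_name(&left, &right, false).unwrap();
        assert_eq!(out.count().unwrap(), 2);
        assert!(out.columns().unwrap().contains(&"id".to_string()));
    }
}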

/// Distinct: drop duplicate rows (all columns or subset).
pub fn distinct(
    df: &DataFrame,
    subset: Option<Vec<&str>>,
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    let lf = df.df.as_ref().clone().lazy();
    let subset_names: Option<Vec<String>> =
        subset.map(|cols| cols.iter().map(|s| (*s).to_string()).collect());
    let lf = lf.unique(subset_names, UniqueKeepStrategy::First);
    let pl_df = lf.collect()?;
    Ok(super::DataFrame::from_polars_with_options(
        pl_df,
        case_sensitive,
    ))
}

/// Drop one or more columns.
pub fn drop(
    df: &DataFrame,
    columns: Vec<&str>,
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    let resolved: Vec<String> = columns
        .iter()
        .map(|c| df.resolve_column_name(c))
        .collect::<Result<Vec<_>, _>>()?;
    let all_names = df.df.get_column_names();
    let to_keep: Vec<&str> = all_names
        .iter()
        .filter(|n| !resolved.iter().any(|r| r == n.as_str()))
        .map(|n| n.as_str())
        .collect();
    let pl_df = df.df.select(to_keep)?;
    Ok(super::DataFrame::from_polars_with_options(
        pl_df,
        case_sensitive,
    ))
}

/// Drop rows with nulls (all columns or subset).
pub fn dropna(
    df: &DataFrame,
    subset: Option<Vec<&str>>,
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    let lf = df.df.as_ref().clone().lazy();
    let subset_exprs: Option<Vec<Expr>> = match subset {
        Some(cols) => Some(cols.iter().map(|c| col(*c)).collect()),
        None => Some(
            df.df
                .get_column_names()
                .iter()
                .map(|n| col(n.as_str()))
                .collect(),
        ),
    };
    let lf = lf.drop_nulls(subset_exprs);
    let pl_df = lf.collect()?;
    Ok(super::DataFrame::from_polars_with_options(
        pl_df,
        case_sensitive,
    ))
}

/// Fill nulls with a literal expression (applied to all columns). For column-specific fill, use a map in a future extension.
pub fn fillna(
    df: &DataFrame,
    value_expr: Expr,
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    use polars::prelude::*;
    let exprs: Vec<Expr> = df
        .df
        .get_column_names()
        .iter()
        .map(|n| col(n.as_str()).fill_null(value_expr.clone()))
        .collect();
    let pl_df = df
        .df
        .as_ref()
        .clone()
        .lazy()
        .with_columns(exprs)
        .collect()?;
    Ok(super::DataFrame::from_polars_with_options(
        pl_df,
        case_sensitive,
    ))
}

/// Limit: return first n rows.
pub fn limit(df: &DataFrame, n: usize, case_sensitive: bool) -> Result<DataFrame, PolarsError> {
    let pl_df = df.df.as_ref().clone().head(Some(n));
    Ok(super::DataFrame::from_polars_with_options(
        pl_df,
        case_sensitive,
    ))
}

/// Rename a column (old_name -> new_name).
pub fn with_column_renamed(
    df: &DataFrame,
    old_name: &str,
    new_name: &str,
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    let resolved = df.resolve_column_name(old_name)?;
    let mut pl_df = df.df.as_ref().clone();
    pl_df.rename(resolved.as_str(), new_name.into())?;
    Ok(super::DataFrame::from_polars_with_options(
        pl_df,
        case_sensitive,
    ))
}

/// Replace values in a column: where column == old_value, use new_value. PySpark replace (single column).
pub fn replace(
    df: &DataFrame,
    column_name: &str,
    old_value: Expr,
    new_value: Expr,
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    use polars::prelude::*;
    let resolved = df.resolve_column_name(column_name)?;
    let repl = when(col(resolved.as_str()).eq(old_value))
        .then(new_value)
        .otherwise(col(resolved.as_str()));
    let pl_df = df
        .df
        .as_ref()
        .clone()
        .lazy()
        .with_column(repl.alias(resolved.as_str()))
        .collect()?;
    Ok(super::DataFrame::from_polars_with_options(
        pl_df,
        case_sensitive,
    ))
}
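
// Usage sketch (illustrative addition): `replace` swaps matching values in one
// column via a when/then/otherwise expression. `polars::prelude::lit` is assumed
// to be available through the crate's polars dependency.
#[cfg(test)]
mod replace_example {
    use super::replace;
    use crate::SparkSession;
    use polars::prelude::lit;

    #[test]
    fn replace_single_value_sketch() {
        let spark = SparkSession::builder().app_name("examples").get_or_create();
        let df = spark
            .create_dataframe(
                vec![(1i64, 10i64, "a".to_string()), (2i64, 20i64, "b".to_string())],
                vec!["id", "v", "label"],
            )
            .unwrap();
        // Rows where label == "a" get label "z"; all other rows are unchanged.
        let out = replace(&df, "label", lit("a"), lit("z"), false).unwrap();
        assert_eq!(out.count().unwrap(), 2);
    }
}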

/// Cross join: cartesian product of two DataFrames. PySpark crossJoin.
pub fn cross_join(
    left: &DataFrame,
    right: &DataFrame,
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    let lf_left = left.df.as_ref().clone().lazy();
    let lf_right = right.df.as_ref().clone().lazy();
    let out = lf_left.cross_join(lf_right, None);
    let pl_df = out.collect()?;
    Ok(super::DataFrame::from_polars_with_options(
        pl_df,
        case_sensitive,
    ))
}

/// Summary statistics (count, mean, std, min, max). PySpark describe.
/// Builds a summary DataFrame with a "summary" column (PySpark name) and one column per numeric input column.
pub fn describe(df: &DataFrame, case_sensitive: bool) -> Result<DataFrame, PolarsError> {
    use polars::prelude::*;
    let pl_df = df.df.as_ref().clone();
    let mut stat_values: Vec<Column> = Vec::new();
    for col in pl_df.get_columns() {
        let s = col.as_materialized_series();
        let dtype = s.dtype();
        if dtype.is_numeric() {
            let name = s.name().clone();
            let count = s.len() as i64 - s.null_count() as i64;
            let mean_f = s.mean().unwrap_or(f64::NAN);
            let std_f = s.std(1).unwrap_or(f64::NAN);
            let s_f64 = s.cast(&DataType::Float64)?;
            let ca = s_f64
                .f64()
                .map_err(|_| PolarsError::ComputeError("cast to f64 failed".into()))?;
            let min_f = ca.min().unwrap_or(f64::NAN);
            let max_f = ca.max().unwrap_or(f64::NAN);
            // PySpark describe/summary returns string type for value columns
            let is_float = matches!(dtype, DataType::Float64 | DataType::Float32);
            let count_s = count.to_string();
            let mean_s = if mean_f.is_nan() {
                "None".to_string()
            } else {
                format!("{:.1}", mean_f)
            };
            let std_s = if std_f.is_nan() {
                "None".to_string()
            } else {
                format!("{:.1}", std_f)
            };
            let min_s = if min_f.is_nan() {
                "None".to_string()
            } else if min_f.fract() == 0.0 && is_float {
                format!("{:.1}", min_f)
            } else if min_f.fract() == 0.0 {
                format!("{:.0}", min_f)
            } else {
                format!("{min_f}")
            };
            let max_s = if max_f.is_nan() {
                "None".to_string()
            } else if max_f.fract() == 0.0 && is_float {
                format!("{:.1}", max_f)
            } else if max_f.fract() == 0.0 {
                format!("{:.0}", max_f)
            } else {
                format!("{max_f}")
            };
            let series = Series::new(
                name,
                [
                    count_s.as_str(),
                    mean_s.as_str(),
                    std_s.as_str(),
                    min_s.as_str(),
                    max_s.as_str(),
                ],
            );
            stat_values.push(series.into());
        }
    }
    if stat_values.is_empty() {
        // No numeric columns: return a minimal describe with just the summary column (PySpark name).
        // A zero-length placeholder column would not match the five summary rows, so the frame is
        // built from the summary column alone.
        let stat_col = Series::new(
            "summary".into(),
            &["count", "mean", "stddev", "min", "max"],
        )
        .into();
        let out_pl = polars::prelude::DataFrame::new(vec![stat_col])?;
        return Ok(super::DataFrame::from_polars_with_options(
            out_pl,
            case_sensitive,
        ));
    }
    let summary_col = Series::new(
        "summary".into(),
        &["count", "mean", "stddev", "min", "max"],
    )
    .into();
    let mut cols: Vec<Column> = vec![summary_col];
    cols.extend(stat_values);
    let out_pl = polars::prelude::DataFrame::new(cols)?;
    Ok(super::DataFrame::from_polars_with_options(
        out_pl,
        case_sensitive,
    ))
}
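
// Usage sketch (illustrative addition): `describe` yields five summary rows
// (count, mean, stddev, min, max) plus a "summary" column, with one value column
// per numeric input column.
#[cfg(test)]
mod describe_example {
    use super::describe;
    use crate::SparkSession;

    #[test]
    fn describe_shape_sketch() {
        let spark = SparkSession::builder().app_name("examples").get_or_create();
        let df = spark
            .create_dataframe(
                vec![(1i64, 10i64, "a".to_string()), (2i64, 20i64, "b".to_string())],
                vec!["id", "v", "label"],
            )
            .unwrap();
        let out = describe(&df, false).unwrap();
        assert_eq!(out.count().unwrap(), 5);
        assert!(out.columns().unwrap().contains(&"summary".to_string()));
    }
}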

/// Set difference: rows in left that are not in right (by all columns). PySpark subtract / except.
pub fn subtract(
    left: &DataFrame,
    right: &DataFrame,
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    use polars::prelude::*;
    let left_names = left.df.get_column_names();
    let left_on: Vec<Expr> = left_names.iter().map(|n| col(n.as_str())).collect();
    let right_on: Vec<Expr> = left_names.iter().map(|n| col(n.as_str())).collect();
    let right_lf = right.df.as_ref().clone().lazy();
    let left_lf = left.df.as_ref().clone().lazy();
    let anti = left_lf.join(right_lf, left_on, right_on, JoinArgs::new(JoinType::Anti));
    let pl_df = anti.collect()?;
    Ok(super::DataFrame::from_polars_with_options(
        pl_df,
        case_sensitive,
    ))
}

/// Set intersection: rows that appear in both DataFrames (by all columns). PySpark intersect.
pub fn intersect(
    left: &DataFrame,
    right: &DataFrame,
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    use polars::prelude::*;
    let left_names = left.df.get_column_names();
    let left_on: Vec<Expr> = left_names.iter().map(|n| col(n.as_str())).collect();
    let right_on: Vec<Expr> = left_names.iter().map(|n| col(n.as_str())).collect();
    let left_lf = left.df.as_ref().clone().lazy();
    let right_lf = right.df.as_ref().clone().lazy();
    let semi = left_lf.join(right_lf, left_on, right_on, JoinArgs::new(JoinType::Semi));
    let pl_df = semi.collect()?;
    Ok(super::DataFrame::from_polars_with_options(
        pl_df,
        case_sensitive,
    ))
}
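
// Usage sketch (illustrative addition): `subtract` and `intersect` are anti/semi
// joins on all columns, so subtracting a DataFrame from itself is empty while
// intersecting it with itself keeps its (distinct) rows.
#[cfg(test)]
mod subtract_intersect_example {
    use super::{intersect, subtract};
    use crate::SparkSession;

    #[test]
    fn self_subtract_and_intersect_sketch() {
        let spark = SparkSession::builder().app_name("examples").get_or_create();
        let df = spark
            .create_dataframe(
                vec![(1i64, 10i64, "a".to_string()), (2i64, 20i64, "b".to_string())],
                vec!["id", "v", "label"],
            )
            .unwrap();
        assert_eq!(subtract(&df, &df, false).unwrap().count().unwrap(), 0);
        assert_eq!(intersect(&df, &df, false).unwrap().count().unwrap(), 2);
    }
}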

// ---------- Batch A: sample, first/head/take/tail, is_empty, to_df ----------

/// Sample a fraction of rows. PySpark sample(withReplacement, fraction, seed).
pub fn sample(
    df: &DataFrame,
    with_replacement: bool,
    fraction: f64,
    seed: Option<u64>,
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    use polars::prelude::Series;
    let n = df.df.height();
    if n == 0 {
        return Ok(super::DataFrame::from_polars_with_options(
            df.df.as_ref().clone(),
            case_sensitive,
        ));
    }
    let take_n = (n as f64 * fraction).round() as usize;
    let take_n = take_n.min(n).max(0);
    if take_n == 0 {
        return Ok(super::DataFrame::from_polars_with_options(
            df.df.as_ref().clone().head(Some(0)),
            case_sensitive,
        ));
    }
    let idx_series = Series::new("idx".into(), (0..n).map(|i| i as u32).collect::<Vec<_>>());
    let sampled_idx = idx_series.sample_n(take_n, with_replacement, true, seed)?;
    let idx_ca = sampled_idx
        .u32()
        .map_err(|_| PolarsError::ComputeError("sample: expected u32 indices".into()))?;
    let pl_df = df.df.as_ref().take(idx_ca)?;
    Ok(super::DataFrame::from_polars_with_options(
        pl_df,
        case_sensitive,
    ))
}
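
// Usage sketch (illustrative addition): `sample` draws round(n * fraction) rows,
// so a 0.5 fraction over two rows keeps one row regardless of the seed.
#[cfg(test)]
mod sample_example {
    use super::sample;
    use crate::SparkSession;

    #[test]
    fn sample_half_sketch() {
        let spark = SparkSession::builder().app_name("examples").get_or_create();
        let df = spark
            .create_dataframe(
                vec![(1i64, 10i64, "a".to_string()), (2i64, 20i64, "b".to_string())],
                vec!["id", "v", "label"],
            )
            .unwrap();
        let out = sample(&df, false, 0.5, Some(7), false).unwrap();
        assert_eq!(out.count().unwrap(), 1);
    }
}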

/// Split DataFrame by weights (random split). PySpark randomSplit(weights, seed).
/// Returns one DataFrame per weight; weights are normalized to fractions. Each split is
/// sampled independently from the full DataFrame (with the same seed), so rows may appear
/// in more than one split, unlike PySpark's disjoint splits.
pub fn random_split(
    df: &DataFrame,
    weights: &[f64],
    seed: Option<u64>,
    case_sensitive: bool,
) -> Result<Vec<DataFrame>, PolarsError> {
    let total: f64 = weights.iter().sum();
    if total <= 0.0 || weights.is_empty() {
        return Ok(Vec::new());
    }
    let n = df.df.height();
    if n == 0 {
        return Ok(weights.iter().map(|_| super::DataFrame::empty()).collect());
    }
    let mut out = Vec::with_capacity(weights.len());
    let mut start = 0_usize;
    for w in weights {
        let frac = w / total;
        let take_n = (n as f64 * frac).round() as usize;
        let take_n = take_n.min(n - start).max(0);
        if take_n == 0 {
            out.push(super::DataFrame::from_polars_with_options(
                df.df.as_ref().clone().head(Some(0)),
                case_sensitive,
            ));
            continue;
        }
        use polars::prelude::Series;
        let idx_series = Series::new("idx".into(), (0..n).map(|i| i as u32).collect::<Vec<_>>());
        let sampled_idx = idx_series.sample_n(take_n, false, true, seed)?;
        let idx_ca = sampled_idx
            .u32()
            .map_err(|_| PolarsError::ComputeError("random_split: expected u32 indices".into()))?;
        let sampled = df.df.as_ref().take(idx_ca)?;
        start += take_n;
        out.push(super::DataFrame::from_polars_with_options(
            sampled,
            case_sensitive,
        ));
    }
    Ok(out)
}

/// Stratified sample by column value. PySpark sampleBy(col, fractions, seed).
/// fractions: list of (value as Expr literal, fraction to sample for that value).
pub fn sample_by(
    df: &DataFrame,
    col_name: &str,
    fractions: &[(Expr, f64)],
    seed: Option<u64>,
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    use polars::prelude::*;
    if fractions.is_empty() {
        return Ok(super::DataFrame::from_polars_with_options(
            df.df.as_ref().clone().head(Some(0)),
            case_sensitive,
        ));
    }
    let resolved = df.resolve_column_name(col_name)?;
    let mut parts = Vec::with_capacity(fractions.len());
    for (value_expr, frac) in fractions {
        let cond = col(resolved.as_str()).eq(value_expr.clone());
        let filtered = df.df.as_ref().clone().lazy().filter(cond).collect()?;
        if filtered.height() == 0 {
            parts.push(filtered.head(Some(0)));
            continue;
        }
        let sampled = sample(
            &super::DataFrame::from_polars_with_options(filtered, case_sensitive),
            false,
            *frac,
            seed,
            case_sensitive,
        )?;
        parts.push(sampled.df.as_ref().clone());
    }
    let mut out = parts
        .first()
        .ok_or_else(|| PolarsError::ComputeError("sample_by: no parts".into()))?
        .clone();
    for p in parts.iter().skip(1) {
        out.vstack_mut(p)?;
    }
    Ok(super::DataFrame::from_polars_with_options(
        out,
        case_sensitive,
    ))
}

/// First row as a DataFrame (one row). PySpark first().
pub fn first(df: &DataFrame, case_sensitive: bool) -> Result<DataFrame, PolarsError> {
    let pl_df = df.df.as_ref().clone().head(Some(1));
    Ok(super::DataFrame::from_polars_with_options(
        pl_df,
        case_sensitive,
    ))
}

/// First n rows. PySpark head(n). Same as limit.
pub fn head(df: &DataFrame, n: usize, case_sensitive: bool) -> Result<DataFrame, PolarsError> {
    limit(df, n, case_sensitive)
}

/// Take first n rows (alias for limit). PySpark take(n).
pub fn take(df: &DataFrame, n: usize, case_sensitive: bool) -> Result<DataFrame, PolarsError> {
    limit(df, n, case_sensitive)
}

/// Last n rows. PySpark tail(n).
pub fn tail(df: &DataFrame, n: usize, case_sensitive: bool) -> Result<DataFrame, PolarsError> {
    let total = df.df.height();
    let skip = total.saturating_sub(n);
    let pl_df = df.df.as_ref().clone().slice(skip as i64, n);
    Ok(super::DataFrame::from_polars_with_options(
        pl_df,
        case_sensitive,
    ))
}

/// Whether the DataFrame has zero rows. PySpark isEmpty.
pub fn is_empty(df: &DataFrame) -> bool {
    df.df.height() == 0
}

/// Rename columns. PySpark toDF(*colNames). Names must match length of columns.
pub fn to_df(
    df: &DataFrame,
    names: &[&str],
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    let cols = df.df.get_column_names();
    if names.len() != cols.len() {
        return Err(PolarsError::ComputeError(
            format!(
                "toDF: expected {} column names, got {}",
                cols.len(),
                names.len()
            )
            .into(),
        ));
    }
    let mut pl_df = df.df.as_ref().clone();
    for (old, new) in cols.iter().zip(names.iter()) {
        pl_df.rename(old.as_str(), (*new).into())?;
    }
    Ok(super::DataFrame::from_polars_with_options(
        pl_df,
        case_sensitive,
    ))
}
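
// Usage sketch (illustrative addition): `to_df` renames every column positionally
// and rejects a name list whose length does not match the schema.
#[cfg(test)]
mod to_df_example {
    use super::to_df;
    use crate::SparkSession;

    #[test]
    fn rename_all_columns_sketch() {
        let spark = SparkSession::builder().app_name("examples").get_or_create();
        let df = spark
            .create_dataframe(vec![(1i64, 10i64, "a".to_string())], vec!["id", "v", "label"])
            .unwrap();
        let out = to_df(&df, &["a", "b", "c"], false).unwrap();
        assert!(out.columns().unwrap().contains(&"b".to_string()));
        assert!(to_df(&df, &["only_one"], false).is_err());
    }
}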

// ---------- Batch B: toJSON, explain, printSchema ----------

fn any_value_to_serde_value(av: &polars::prelude::AnyValue) -> serde_json::Value {
    use polars::prelude::AnyValue;
    use serde_json::Number;
    match av {
        AnyValue::Null => serde_json::Value::Null,
        AnyValue::Boolean(v) => serde_json::Value::Bool(*v),
        AnyValue::Int8(v) => serde_json::Value::Number(Number::from(*v as i64)),
        AnyValue::Int32(v) => serde_json::Value::Number(Number::from(*v)),
        AnyValue::Int64(v) => serde_json::Value::Number(Number::from(*v)),
        AnyValue::UInt32(v) => serde_json::Value::Number(Number::from(*v)),
        AnyValue::Float64(v) => Number::from_f64(*v)
            .map(serde_json::Value::Number)
            .unwrap_or(serde_json::Value::Null),
        AnyValue::String(v) => serde_json::Value::String(v.to_string()),
        _ => serde_json::Value::String(format!("{av:?}")),
    }
}

/// Collect rows as JSON strings (one JSON object per row). PySpark toJSON.
pub fn to_json(df: &DataFrame) -> Result<Vec<String>, PolarsError> {
    use polars::prelude::*;
    let pl = df.df.as_ref();
    let names = pl.get_column_names();
    let mut out = Vec::with_capacity(pl.height());
    for r in 0..pl.height() {
        let mut row = serde_json::Map::new();
        for (i, name) in names.iter().enumerate() {
            let col = pl
                .get_columns()
                .get(i)
                .ok_or_else(|| PolarsError::ComputeError("to_json: column index".into()))?;
            let series = col.as_materialized_series();
            let av = series
                .get(r)
                .map_err(|e| PolarsError::ComputeError(e.to_string().into()))?;
            row.insert(name.to_string(), any_value_to_serde_value(&av));
        }
        out.push(
            serde_json::to_string(&row)
                .map_err(|e| PolarsError::ComputeError(e.to_string().into()))?,
        );
    }
    Ok(out)
}
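
// Usage sketch (illustrative addition): `to_json` serializes each row to its own
// JSON object string, so a two-row DataFrame yields two strings.
#[cfg(test)]
mod to_json_example {
    use super::to_json;
    use crate::SparkSession;

    #[test]
    fn one_json_object_per_row_sketch() {
        let spark = SparkSession::builder().app_name("examples").get_or_create();
        let df = spark
            .create_dataframe(
                vec![(1i64, 10i64, "a".to_string()), (2i64, 20i64, "b".to_string())],
                vec!["id", "v", "label"],
            )
            .unwrap();
        let rows = to_json(&df).unwrap();
        assert_eq!(rows.len(), 2);
        assert!(rows[0].contains("\"id\":1"));
    }
}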

/// Return a string describing the execution plan. PySpark explain.
pub fn explain(_df: &DataFrame) -> String {
    "DataFrame (eager Polars backend)".to_string()
}

/// Return schema as a tree string. PySpark printSchema (we return string; caller can print).
pub fn print_schema(df: &DataFrame) -> Result<String, PolarsError> {
    let schema = df.schema()?;
    let mut s = "root\n".to_string();
    for f in schema.fields() {
        let dt = match &f.data_type {
            crate::schema::DataType::String => "string",
            crate::schema::DataType::Integer => "int",
            crate::schema::DataType::Long => "bigint",
            crate::schema::DataType::Double => "double",
            crate::schema::DataType::Boolean => "boolean",
            crate::schema::DataType::Date => "date",
            crate::schema::DataType::Timestamp => "timestamp",
            _ => "string",
        };
        s.push_str(&format!(" |-- {}: {}\n", f.name, dt));
    }
    Ok(s)
}
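
// Usage sketch (illustrative addition): `print_schema` returns the PySpark-style
// tree as a string beginning with "root" and one " |-- name: type" line per column.
#[cfg(test)]
mod print_schema_example {
    use super::print_schema;
    use crate::SparkSession;

    #[test]
    fn schema_tree_sketch() {
        let spark = SparkSession::builder().app_name("examples").get_or_create();
        let df = spark
            .create_dataframe(vec![(1i64, 10i64, "a".to_string())], vec!["id", "v", "label"])
            .unwrap();
        let tree = print_schema(&df).unwrap();
        assert!(tree.starts_with("root"));
        assert!(tree.contains(" |-- id:"));
    }
}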

// ---------- Batch D: selectExpr, colRegex, withColumns, withColumnsRenamed, na ----------

/// Select by expression strings. Minimal support: comma-separated column names; an ` as ` alias is parsed but currently ignored. PySpark selectExpr.
pub fn select_expr(
    df: &DataFrame,
    exprs: &[String],
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    let mut cols = Vec::new();
    for e in exprs {
        let e = e.trim();
        if let Some((left, right)) = e.split_once(" as ") {
            let col_name = left.trim();
            let _alias = right.trim();
            cols.push(df.resolve_column_name(col_name)?);
        } else {
            cols.push(df.resolve_column_name(e)?);
        }
    }
    let refs: Vec<&str> = cols.iter().map(|s| s.as_str()).collect();
    select(df, refs, case_sensitive)
}

/// Select columns whose names match the regex pattern. PySpark colRegex.
pub fn col_regex(
    df: &DataFrame,
    pattern: &str,
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    let re = regex::Regex::new(pattern).map_err(|e| {
        PolarsError::ComputeError(format!("colRegex: invalid pattern {pattern:?}: {e}").into())
    })?;
    let names = df.df.get_column_names();
    let matched: Vec<&str> = names
        .iter()
        .filter(|n| re.is_match(n.as_str()))
        .map(|s| s.as_str())
        .collect();
    if matched.is_empty() {
        return Err(PolarsError::ComputeError(
            format!("colRegex: no columns matched pattern {pattern:?}").into(),
        ));
    }
    select(df, matched, case_sensitive)
}
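
// Usage sketch (illustrative addition): `col_regex` keeps the columns whose names
// match the pattern and errors when nothing matches.
#[cfg(test)]
mod col_regex_example {
    use super::col_regex;
    use crate::SparkSession;

    #[test]
    fn regex_projection_sketch() {
        let spark = SparkSession::builder().app_name("examples").get_or_create();
        let df = spark
            .create_dataframe(vec![(1i64, 10i64, "a".to_string())], vec!["id", "v", "label"])
            .unwrap();
        let out = col_regex(&df, "^(id|v)$", false).unwrap();
        assert_eq!(out.columns().unwrap().len(), 2);
        assert!(col_regex(&df, "^missing$", false).is_err());
    }
}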

/// Add or replace multiple columns. PySpark withColumns. Uses Column so deferred rand/randn get per-row values.
pub fn with_columns(
    df: &DataFrame,
    exprs: &[(String, crate::column::Column)],
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    let mut current =
        super::DataFrame::from_polars_with_options(df.df.as_ref().clone(), case_sensitive);
    for (name, col) in exprs {
        current = with_column(&current, name, col, case_sensitive)?;
    }
    Ok(current)
}

/// Rename multiple columns. PySpark withColumnsRenamed.
pub fn with_columns_renamed(
    df: &DataFrame,
    renames: &[(String, String)],
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    let mut out = df.df.as_ref().clone();
    for (old_name, new_name) in renames {
        let resolved = df.resolve_column_name(old_name)?;
        out.rename(resolved.as_str(), new_name.as_str().into())?;
    }
    Ok(super::DataFrame::from_polars_with_options(
        out,
        case_sensitive,
    ))
}

/// NA sub-API builder. PySpark df.na().fill(...) / .drop(...).
pub struct DataFrameNa<'a> {
    pub(crate) df: &'a DataFrame,
}

impl<'a> DataFrameNa<'a> {
    /// Fill nulls with the given value. PySpark na.fill(value).
    pub fn fill(&self, value: Expr) -> Result<DataFrame, PolarsError> {
        fillna(self.df, value, self.df.case_sensitive)
    }

    /// Drop rows with nulls in the given columns (or all). PySpark na.drop(subset).
    pub fn drop(&self, subset: Option<Vec<&str>>) -> Result<DataFrame, PolarsError> {
        dropna(self.df, subset, self.df.case_sensitive)
    }
}

// ---------- Batch E: offset, transform, freqItems, approxQuantile, crosstab, melt, exceptAll, intersectAll ----------

/// Skip first n rows. PySpark offset(n).
pub fn offset(df: &DataFrame, n: usize, case_sensitive: bool) -> Result<DataFrame, PolarsError> {
    let total = df.df.height();
    let len = total.saturating_sub(n);
    let pl_df = df.df.as_ref().clone().slice(n as i64, len);
    Ok(super::DataFrame::from_polars_with_options(
        pl_df,
        case_sensitive,
    ))
}

/// Transform DataFrame by a function. PySpark transform(func).
pub fn transform<F>(df: &DataFrame, f: F) -> Result<DataFrame, PolarsError>
where
    F: FnOnce(DataFrame) -> Result<DataFrame, PolarsError>,
{
    let df_out = f(df.clone())?;
    Ok(df_out)
}

/// Frequent items. PySpark freqItems. Returns one row with columns {col}_freqItems (array of values with frequency >= support).
pub fn freq_items(
    df: &DataFrame,
    columns: &[&str],
    support: f64,
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    use polars::prelude::SeriesMethods;
    if columns.is_empty() {
        return Ok(super::DataFrame::from_polars_with_options(
            df.df.as_ref().clone().head(Some(0)),
            case_sensitive,
        ));
    }
    let support = support.clamp(1e-4, 1.0);
    let pl_df = df.df.as_ref();
    let n_total = pl_df.height() as f64;
    if n_total == 0.0 {
        let mut out = Vec::with_capacity(columns.len());
        for col_name in columns {
            let resolved = df.resolve_column_name(col_name)?;
            let s = pl_df
                .column(resolved.as_str())?
                .as_series()
                .ok_or_else(|| PolarsError::ComputeError("column not a series".into()))?
                .clone();
            let empty_sub = s.head(Some(0));
            let list_chunked = polars::prelude::ListChunked::from_iter([empty_sub].into_iter())
                .with_name(format!("{resolved}_freqItems").into());
            out.push(list_chunked.into_series().into());
        }
        return Ok(super::DataFrame::from_polars_with_options(
            polars::prelude::DataFrame::new(out)?,
            case_sensitive,
        ));
    }
    let mut out_series = Vec::with_capacity(columns.len());
    for col_name in columns {
        let resolved = df.resolve_column_name(col_name)?;
        let s = pl_df
            .column(resolved.as_str())?
            .as_series()
            .ok_or_else(|| PolarsError::ComputeError("column not a series".into()))?
            .clone();
        let vc = s.value_counts(false, false, "counts".into(), false)?;
        let count_col = vc
            .column("counts")
            .map_err(|_| PolarsError::ComputeError("value_counts missing counts column".into()))?;
        let counts = count_col
            .u32()
            .map_err(|_| PolarsError::ComputeError("freq_items: counts column not u32".into()))?;
        let value_col_name = s.name();
        let values_col = vc
            .column(value_col_name.as_str())
            .map_err(|_| PolarsError::ComputeError("value_counts missing value column".into()))?;
        let threshold = (support * n_total).ceil() as u32;
        let indices: Vec<u32> = counts
            .into_iter()
            .enumerate()
            .filter_map(|(i, c)| {
                if c? >= threshold {
                    Some(i as u32)
                } else {
                    None
                }
            })
            .collect();
        let idx_series = Series::new("idx".into(), indices);
        let idx_ca = idx_series
            .u32()
            .map_err(|_| PolarsError::ComputeError("freq_items: index series not u32".into()))?;
        let values_series = values_col
            .as_series()
            .ok_or_else(|| PolarsError::ComputeError("value column not a series".into()))?;
        let filtered = values_series.take(idx_ca)?;
        let list_chunked = polars::prelude::ListChunked::from_iter([filtered].into_iter())
            .with_name(format!("{resolved}_freqItems").into());
        let list_row = list_chunked.into_series();
        out_series.push(list_row.into());
    }
    let out_df = polars::prelude::DataFrame::new(out_series)?;
    Ok(super::DataFrame::from_polars_with_options(
        out_df,
        case_sensitive,
    ))
}

/// Approximate quantiles. PySpark approxQuantile. Returns one column "quantile" with one row per probability.
pub fn approx_quantile(
    df: &DataFrame,
    column: &str,
    probabilities: &[f64],
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    use polars::prelude::{ChunkQuantile, QuantileMethod};
    if probabilities.is_empty() {
        return Ok(super::DataFrame::from_polars_with_options(
            polars::prelude::DataFrame::new(vec![Series::new(
                "quantile".into(),
                Vec::<f64>::new(),
            )
            .into()])?,
            case_sensitive,
        ));
    }
    let resolved = df.resolve_column_name(column)?;
    let s = df
        .df
        .as_ref()
        .column(resolved.as_str())?
        .as_series()
        .ok_or_else(|| PolarsError::ComputeError("approx_quantile: column not a series".into()))?
        .clone();
    let s_f64 = s.cast(&polars::prelude::DataType::Float64)?;
    let ca = s_f64
        .f64()
        .map_err(|_| PolarsError::ComputeError("approx_quantile: need numeric column".into()))?;
    let mut quantiles = Vec::with_capacity(probabilities.len());
    for &p in probabilities {
        let q = ca.quantile(p, QuantileMethod::Linear)?;
        quantiles.push(q.unwrap_or(f64::NAN));
    }
    let out_df =
        polars::prelude::DataFrame::new(vec![Series::new("quantile".into(), quantiles).into()])?;
    Ok(super::DataFrame::from_polars_with_options(
        out_df,
        case_sensitive,
    ))
}
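
// Usage sketch (illustrative addition): `approx_quantile` returns one "quantile"
// row per requested probability, computed with linear interpolation.
#[cfg(test)]
mod approx_quantile_example {
    use super::approx_quantile;
    use crate::SparkSession;

    #[test]
    fn quantile_rows_sketch() {
        let spark = SparkSession::builder().app_name("examples").get_or_create();
        let df = spark
            .create_dataframe(
                vec![(1i64, 10i64, "a".to_string()), (2i64, 20i64, "b".to_string())],
                vec!["id", "v", "label"],
            )
            .unwrap();
        let out = approx_quantile(&df, "v", &[0.25, 0.5, 0.75], false).unwrap();
        assert_eq!(out.count().unwrap(), 3);
        assert!(out.columns().unwrap().contains(&"quantile".to_string()));
    }
}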

/// Cross-tabulation. PySpark crosstab. Returns long format (col1, col2, count); for wide format use pivot on the result.
pub fn crosstab(
    df: &DataFrame,
    col1: &str,
    col2: &str,
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    use polars::prelude::*;
    let c1 = df.resolve_column_name(col1)?;
    let c2 = df.resolve_column_name(col2)?;
    let pl_df = df.df.as_ref();
    let grouped = pl_df
        .clone()
        .lazy()
        .group_by([col(c1.as_str()), col(c2.as_str())])
        .agg([len().alias("count")])
        .collect()?;
    Ok(super::DataFrame::from_polars_with_options(
        grouped,
        case_sensitive,
    ))
}

/// Unpivot (melt). PySpark melt. Long format with id_vars kept, plus "variable" and "value" columns.
pub fn melt(
    df: &DataFrame,
    id_vars: &[&str],
    value_vars: &[&str],
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    use polars::prelude::*;
    let pl_df = df.df.as_ref();
    if value_vars.is_empty() {
        return Ok(super::DataFrame::from_polars_with_options(
            pl_df.head(Some(0)),
            case_sensitive,
        ));
    }
    let id_resolved: Vec<String> = id_vars
        .iter()
        .map(|s| df.resolve_column_name(s).map(|r| r.to_string()))
        .collect::<Result<Vec<_>, _>>()?;
    let value_resolved: Vec<String> = value_vars
        .iter()
        .map(|s| df.resolve_column_name(s).map(|r| r.to_string()))
        .collect::<Result<Vec<_>, _>>()?;
    let mut parts = Vec::with_capacity(value_vars.len());
    for vname in &value_resolved {
        let select_cols: Vec<&str> = id_resolved
            .iter()
            .map(|s| s.as_str())
            .chain([vname.as_str()])
            .collect();
        let mut part = pl_df.select(select_cols)?;
        let var_series = Series::new("variable".into(), vec![vname.as_str(); part.height()]);
        part.with_column(var_series)?;
        part.rename(vname.as_str(), "value".into())?;
        parts.push(part);
    }
    let mut out = parts
        .first()
        .ok_or_else(|| PolarsError::ComputeError("melt: no value columns".into()))?
        .clone();
    for p in parts.iter().skip(1) {
        out.vstack_mut(p)?;
    }
    let col_order: Vec<&str> = id_resolved
        .iter()
        .map(|s| s.as_str())
        .chain(["variable", "value"])
        .collect();
    let out = out.select(col_order)?;
    Ok(super::DataFrame::from_polars_with_options(
        out,
        case_sensitive,
    ))
}
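
// Usage sketch (illustrative addition): `melt` keeps the id columns and stacks
// each value column into "variable"/"value" pairs, one block per value column.
#[cfg(test)]
mod melt_example {
    use super::melt;
    use crate::SparkSession;

    #[test]
    fn melt_single_value_var_sketch() {
        let spark = SparkSession::builder().app_name("examples").get_or_create();
        let df = spark
            .create_dataframe(
                vec![(1i64, 10i64, "a".to_string()), (2i64, 20i64, "b".to_string())],
                vec!["id", "v", "label"],
            )
            .unwrap();
        let out = melt(&df, &["id"], &["v"], false).unwrap();
        assert_eq!(out.count().unwrap(), 2);
        assert!(out.columns().unwrap().contains(&"variable".to_string()));
        assert!(out.columns().unwrap().contains(&"value".to_string()));
    }
}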

/// Set difference keeping duplicates. PySpark exceptAll. Simple impl: same as subtract.
pub fn except_all(
    left: &DataFrame,
    right: &DataFrame,
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    subtract(left, right, case_sensitive)
}

/// Set intersection keeping duplicates. PySpark intersectAll. Simple impl: same as intersect.
pub fn intersect_all(
    left: &DataFrame,
    right: &DataFrame,
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    intersect(left, right, case_sensitive)
}

#[cfg(test)]
mod tests {
    use super::{distinct, drop, dropna, first, head, limit, offset};
    use crate::{DataFrame, SparkSession};

    fn test_df() -> DataFrame {
        let spark = SparkSession::builder()
            .app_name("transform_tests")
            .get_or_create();
        spark
            .create_dataframe(
                vec![
                    (1i64, 10i64, "a".to_string()),
                    (2i64, 20i64, "b".to_string()),
                    (3i64, 30i64, "c".to_string()),
                ],
                vec!["id", "v", "label"],
            )
            .unwrap()
    }

    #[test]
    fn limit_zero() {
        let df = test_df();
        let out = limit(&df, 0, false).unwrap();
        assert_eq!(out.count().unwrap(), 0);
    }

    #[test]
    fn limit_more_than_rows() {
        let df = test_df();
        let out = limit(&df, 10, false).unwrap();
        assert_eq!(out.count().unwrap(), 3);
    }

    #[test]
    fn distinct_on_empty() {
        let spark = SparkSession::builder()
            .app_name("transform_tests")
            .get_or_create();
        let df = spark
            .create_dataframe(vec![] as Vec<(i64, i64, String)>, vec!["a", "b", "c"])
            .unwrap();
        let out = distinct(&df, None, false).unwrap();
        assert_eq!(out.count().unwrap(), 0);
    }

    #[test]
    fn first_returns_one_row() {
        let df = test_df();
        let out = first(&df, false).unwrap();
        assert_eq!(out.count().unwrap(), 1);
    }

    #[test]
    fn head_n() {
        let df = test_df();
        let out = head(&df, 2, false).unwrap();
        assert_eq!(out.count().unwrap(), 2);
    }

    #[test]
    fn offset_skip_first() {
        let df = test_df();
        let out = offset(&df, 1, false).unwrap();
        assert_eq!(out.count().unwrap(), 2);
    }

    #[test]
    fn offset_beyond_length_returns_empty() {
        let df = test_df();
        let out = offset(&df, 10, false).unwrap();
        assert_eq!(out.count().unwrap(), 0);
    }

    #[test]
    fn drop_column() {
        let df = test_df();
        let out = drop(&df, vec!["v"], false).unwrap();
        let cols = out.columns().unwrap();
        assert!(!cols.contains(&"v".to_string()));
        assert_eq!(out.count().unwrap(), 3);
    }

    #[test]
    fn dropna_all_columns() {
        let df = test_df();
        let out = dropna(&df, None, false).unwrap();
        assert_eq!(out.count().unwrap(), 3);
    }
}