robin_sparkless/dataframe/transformations.rs

//! DataFrame transformation operations: select/filter/with_column, ordering,
//! set operations (union, unionByName, distinct, subtract, intersect), null
//! handling (dropna/fillna), sampling and splits (sample, sampleBy, randomSplit),
//! row access (limit, first, head, take, tail, offset, is_empty), renames
//! (with_column_renamed, to_df), and utilities (replace, cross_join, describe,
//! toJSON, explain, printSchema, selectExpr, colRegex, freqItems, approxQuantile,
//! crosstab, melt, exceptAll, intersectAll).

use super::DataFrame;
use crate::functions::SortOrder;
use polars::prelude::{
    col, Expr, IntoLazy, IntoSeries, NamedFrom, PolarsError, Series, UnionArgs, UniqueKeepStrategy,
};
use std::collections::HashMap;

/// Select columns (returns a new DataFrame). Preserves case_sensitive on result.
pub fn select(
    df: &DataFrame,
    cols: Vec<&str>,
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    let selected = df.df.select(cols)?;
    Ok(super::DataFrame::from_polars_with_options(
        selected,
        case_sensitive,
    ))
}

/// Select using column expressions (e.g. F.regexp_extract_all(...).alias("m")). Preserves case_sensitive.
/// Column names in expressions are resolved per df's case sensitivity (PySpark parity).
/// Duplicate output names are disambiguated with _1, _2, ... so select(col("num").cast("string"), col("num").cast("int")) works (issue #213).
pub fn select_with_exprs(
    df: &DataFrame,
    exprs: Vec<Expr>,
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    let exprs: Vec<Expr> = exprs
        .into_iter()
        .map(|e| df.resolve_expr_column_names(e))
        .collect::<Result<Vec<_>, _>>()?;
    let mut name_count: HashMap<String, u32> = HashMap::new();
    let exprs: Vec<Expr> = exprs
        .into_iter()
        .map(|e| {
            let base_name = polars_plan::utils::expr_output_name(&e)
                .map(|s| s.to_string())
                .unwrap_or_else(|_| "_".to_string());
            let count = name_count.entry(base_name.clone()).or_insert(0);
            *count += 1;
            let final_name = if *count == 1 {
                base_name
            } else {
                format!("{}_{}", base_name, *count - 1)
            };
            if *count == 1 {
                e
            } else {
                e.alias(final_name.as_str())
            }
        })
        .collect();
    let lf = df.df.as_ref().clone().lazy();
    let out_df = lf.select(exprs).collect()?;
    Ok(super::DataFrame::from_polars_with_options(
        out_df,
        case_sensitive,
    ))
}
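
// Example for select_with_exprs (sketch): duplicate output names receive numeric
// suffixes, so selecting the same source column twice yields "num" and "num_1"
// instead of a duplicate-name error. The frame `df` here is a placeholder.
//
//     let out = select_with_exprs(&df, vec![col("num"), col("num")], false)?;
//     // out has columns ["num", "num_1"]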

/// Filter rows using a Polars expression. Preserves case_sensitive on result.
/// Column names in the condition are resolved per df's case sensitivity (PySpark parity).
pub fn filter(
    df: &DataFrame,
    condition: Expr,
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    let condition = df.resolve_expr_column_names(condition)?;
    let condition = df.coerce_string_numeric_comparisons(condition)?;
    let lf = df.df.as_ref().clone().lazy().filter(condition);
    let out_df = lf.collect()?;
    Ok(super::DataFrame::from_polars_with_options(
        out_df,
        case_sensitive,
    ))
}

/// Add or replace a column. Handles deferred rand/randn and Python UDF (UdfCall).
pub fn with_column(
    df: &DataFrame,
    column_name: &str,
    column: &crate::column::Column,
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    // Python UDF: eager execution at UDF boundary
    #[cfg(feature = "pyo3")]
    if let Some((ref udf_name, ref args)) = column.udf_call {
        if let Some(session) = crate::session::get_thread_udf_session() {
            return crate::python::execute_python_udf(
                df,
                column_name,
                udf_name,
                args,
                case_sensitive,
                &session,
            );
        }
    }

    if let Some(deferred) = column.deferred {
        match deferred {
            crate::column::DeferredRandom::Rand(seed) => {
                let mut pl_df = df.df.as_ref().clone();
                let n = pl_df.height();
                let series = crate::udfs::series_rand_n(column_name, n, seed);
                pl_df.with_column(series)?;
                return Ok(super::DataFrame::from_polars_with_options(
                    pl_df,
                    case_sensitive,
                ));
            }
            crate::column::DeferredRandom::Randn(seed) => {
                let mut pl_df = df.df.as_ref().clone();
                let n = pl_df.height();
                let series = crate::udfs::series_randn_n(column_name, n, seed);
                pl_df.with_column(series)?;
                return Ok(super::DataFrame::from_polars_with_options(
                    pl_df,
                    case_sensitive,
                ));
            }
        }
    }
    let expr = df.resolve_expr_column_names(column.expr().clone())?;
    let lf = df.df.as_ref().clone().lazy();
    let lf_with_col = lf.with_column(expr.alias(column_name));
    let pl_df = lf_with_col.collect()?;
    Ok(super::DataFrame::from_polars_with_options(
        pl_df,
        case_sensitive,
    ))
}
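
// Example for with_column (sketch): a Column carrying a deferred Rand/Randn marker
// is expanded here into a full-length random Series instead of being evaluated as a
// Polars expression. `F::rand(Some(42))` below is illustrative only; the actual
// constructor that sets `Column::deferred` lives in the functions/column modules.
//
//     let with_r = with_column(&df, "r", &F::rand(Some(42)), false)?;
//     // "r" holds one uniform random value per row, seeded with 42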

/// Order by columns (sort). Preserves case_sensitive on result.
pub fn order_by(
    df: &DataFrame,
    column_names: Vec<&str>,
    ascending: Vec<bool>,
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    use polars::prelude::*;
    let mut asc = ascending;
    while asc.len() < column_names.len() {
        asc.push(true);
    }
    asc.truncate(column_names.len());
    let lf = df.df.as_ref().clone().lazy();
    let exprs: Vec<Expr> = column_names.iter().map(|name| col(*name)).collect();
    let descending: Vec<bool> = asc.iter().map(|&a| !a).collect();
    let sorted = lf.sort_by_exprs(
        exprs,
        SortMultipleOptions::new().with_order_descending_multi(descending),
    );
    let pl_df = sorted.collect()?;
    Ok(super::DataFrame::from_polars_with_options(
        pl_df,
        case_sensitive,
    ))
}
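
// Example for order_by (sketch): missing ascending flags default to true and extra
// flags are truncated, so a single flag still covers every listed column.
//
//     let sorted = order_by(&df, vec!["label", "v"], vec![false], false)?;
//     // "label" sorts descending, "v" falls back to ascending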

/// Order by sort expressions (asc/desc with nulls_first/last). Preserves case_sensitive on result.
/// Column names in sort expressions are resolved per df's case sensitivity (PySpark parity).
pub fn order_by_exprs(
    df: &DataFrame,
    sort_orders: Vec<SortOrder>,
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    use polars::prelude::*;
    if sort_orders.is_empty() {
        return Ok(super::DataFrame::from_polars_with_options(
            df.df.as_ref().clone(),
            case_sensitive,
        ));
    }
    let exprs: Vec<Expr> = sort_orders
        .iter()
        .map(|s| df.resolve_expr_column_names(s.expr().clone()))
        .collect::<Result<Vec<_>, _>>()?;
    let descending: Vec<bool> = sort_orders.iter().map(|s| s.descending).collect();
    let nulls_last: Vec<bool> = sort_orders.iter().map(|s| s.nulls_last).collect();
    let opts = SortMultipleOptions::new()
        .with_order_descending_multi(descending)
        .with_nulls_last_multi(nulls_last);
    let lf = df.df.as_ref().clone().lazy();
    let sorted = lf.sort_by_exprs(exprs, opts);
    let pl_df = sorted.collect()?;
    Ok(super::DataFrame::from_polars_with_options(
        pl_df,
        case_sensitive,
    ))
}

/// Union (unionAll): stack another DataFrame vertically. Schemas must match (same columns, same order).
pub fn union(
    left: &DataFrame,
    right: &DataFrame,
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    let lf1 = left.df.as_ref().clone().lazy();
    let lf2 = right.df.as_ref().clone().lazy();
    let out = polars::prelude::concat([lf1, lf2], UnionArgs::default())?.collect()?;
    Ok(super::DataFrame::from_polars_with_options(
        out,
        case_sensitive,
    ))
}

/// Union by name: stack vertically, aligning columns by name. Right columns are reordered to match left; missing columns in right become nulls.
pub fn union_by_name(
    left: &DataFrame,
    right: &DataFrame,
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    use polars::prelude::*;
    let left_names = left.df.get_column_names();
    let right_df = right.df.as_ref();
    let right_names = right_df.get_column_names();
    let resolve_right = |name: &str| -> Option<String> {
        if case_sensitive {
            right_names
                .iter()
                .find(|n| n.as_str() == name)
                .map(|s| s.as_str().to_string())
        } else {
            let name_lower = name.to_lowercase();
            right_names
                .iter()
                .find(|n| n.as_str().to_lowercase() == name_lower)
                .map(|s| s.as_str().to_string())
        }
    };
    let mut exprs: Vec<Expr> = Vec::with_capacity(left_names.len());
    for left_col in left_names.iter() {
        let left_str = left_col.as_str();
        if let Some(r) = resolve_right(left_str) {
            exprs.push(col(r.as_str()));
        } else {
            exprs.push(Expr::Literal(polars::prelude::LiteralValue::Null).alias(left_str));
        }
    }
    let right_aligned = right_df.clone().lazy().select(exprs).collect()?;
    let lf1 = left.df.as_ref().clone().lazy();
    let lf2 = right_aligned.lazy();
    let out = polars::prelude::concat([lf1, lf2], UnionArgs::default())?.collect()?;
    Ok(super::DataFrame::from_polars_with_options(
        out,
        case_sensitive,
    ))
}
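
// Example for union_by_name (sketch): columns are matched by name, so the right frame
// may list its columns in a different order; a column present only on the left comes
// back as nulls for the rows taken from the right.
//
//     // left: ["id", "v"], right: ["v", "id"]  -> rows stack in left's ["id", "v"] order
//     // left: ["id", "v"], right: ["id"]       -> "v" is null for rows from the right
//     let stacked = union_by_name(&left, &right, false)?;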

/// Distinct: drop duplicate rows (all columns or subset).
pub fn distinct(
    df: &DataFrame,
    subset: Option<Vec<&str>>,
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    let lf = df.df.as_ref().clone().lazy();
    let subset_names: Option<Vec<String>> =
        subset.map(|cols| cols.iter().map(|s| (*s).to_string()).collect());
    let lf = lf.unique(subset_names, UniqueKeepStrategy::First);
    let pl_df = lf.collect()?;
    Ok(super::DataFrame::from_polars_with_options(
        pl_df,
        case_sensitive,
    ))
}

/// Drop one or more columns.
pub fn drop(
    df: &DataFrame,
    columns: Vec<&str>,
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    let resolved: Vec<String> = columns
        .iter()
        .map(|c| df.resolve_column_name(c))
        .collect::<Result<Vec<_>, _>>()?;
    let all_names = df.df.get_column_names();
    let to_keep: Vec<&str> = all_names
        .iter()
        .filter(|n| !resolved.iter().any(|r| r == n.as_str()))
        .map(|n| n.as_str())
        .collect();
    let pl_df = df.df.select(to_keep)?;
    Ok(super::DataFrame::from_polars_with_options(
        pl_df,
        case_sensitive,
    ))
}

/// Drop rows with nulls (all columns or subset).
pub fn dropna(
    df: &DataFrame,
    subset: Option<Vec<&str>>,
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    let lf = df.df.as_ref().clone().lazy();
    let subset_exprs: Option<Vec<Expr>> = match subset {
        Some(cols) => Some(cols.iter().map(|c| col(*c)).collect()),
        None => Some(
            df.df
                .get_column_names()
                .iter()
                .map(|n| col(n.as_str()))
                .collect(),
        ),
    };
    let lf = lf.drop_nulls(subset_exprs);
    let pl_df = lf.collect()?;
    Ok(super::DataFrame::from_polars_with_options(
        pl_df,
        case_sensitive,
    ))
}

/// Fill nulls with a literal expression, applied to every column. Column-specific fill (a per-column value map) is left to a future extension.
pub fn fillna(
    df: &DataFrame,
    value_expr: Expr,
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    use polars::prelude::*;
    let exprs: Vec<Expr> = df
        .df
        .get_column_names()
        .iter()
        .map(|n| col(n.as_str()).fill_null(value_expr.clone()))
        .collect();
    let pl_df = df
        .df
        .as_ref()
        .clone()
        .lazy()
        .with_columns(exprs)
        .collect()?;
    Ok(super::DataFrame::from_polars_with_options(
        pl_df,
        case_sensitive,
    ))
}
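
// Example for fillna (sketch): the same fill value is applied to every column, so it
// only takes effect where Polars can reconcile the literal's type with the column.
//
//     let filled = fillna(&df, lit(0i64), false)?;
//     // nulls in integer columns become 0; non-matching dtypes follow Polars'
//     // fill_null type rules and may error at collect time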

/// Limit: return first n rows.
pub fn limit(df: &DataFrame, n: usize, case_sensitive: bool) -> Result<DataFrame, PolarsError> {
    let pl_df = df.df.as_ref().clone().head(Some(n));
    Ok(super::DataFrame::from_polars_with_options(
        pl_df,
        case_sensitive,
    ))
}

/// Rename a column (old_name -> new_name).
pub fn with_column_renamed(
    df: &DataFrame,
    old_name: &str,
    new_name: &str,
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    let resolved = df.resolve_column_name(old_name)?;
    let mut pl_df = df.df.as_ref().clone();
    pl_df.rename(resolved.as_str(), new_name.into())?;
    Ok(super::DataFrame::from_polars_with_options(
        pl_df,
        case_sensitive,
    ))
}

/// Replace values in a column: where column == old_value, use new_value. PySpark replace (single column).
pub fn replace(
    df: &DataFrame,
    column_name: &str,
    old_value: Expr,
    new_value: Expr,
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    use polars::prelude::*;
    let resolved = df.resolve_column_name(column_name)?;
    let repl = when(col(resolved.as_str()).eq(old_value))
        .then(new_value)
        .otherwise(col(resolved.as_str()));
    let pl_df = df
        .df
        .as_ref()
        .clone()
        .lazy()
        .with_column(repl.alias(resolved.as_str()))
        .collect()?;
    Ok(super::DataFrame::from_polars_with_options(
        pl_df,
        case_sensitive,
    ))
}

/// Cross join: cartesian product of two DataFrames. PySpark crossJoin.
pub fn cross_join(
    left: &DataFrame,
    right: &DataFrame,
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    let lf_left = left.df.as_ref().clone().lazy();
    let lf_right = right.df.as_ref().clone().lazy();
    let out = lf_left.cross_join(lf_right, None);
    let pl_df = out.collect()?;
    Ok(super::DataFrame::from_polars_with_options(
        pl_df,
        case_sensitive,
    ))
}

/// Summary statistics (count, mean, std, min, max). PySpark describe.
/// Builds a summary DataFrame with a "summary" column (PySpark name) and one column per numeric input column.
pub fn describe(df: &DataFrame, case_sensitive: bool) -> Result<DataFrame, PolarsError> {
    use polars::prelude::*;
    let pl_df = df.df.as_ref().clone();
    let mut stat_values: Vec<Column> = Vec::new();
    for col in pl_df.get_columns() {
        let s = col.as_materialized_series();
        let dtype = s.dtype();
        if dtype.is_numeric() {
            let name = s.name().clone();
            let count = s.len() as i64 - s.null_count() as i64;
            let mean_f = s.mean().unwrap_or(f64::NAN);
            let std_f = s.std(1).unwrap_or(f64::NAN);
            let s_f64 = s.cast(&DataType::Float64)?;
            let ca = s_f64
                .f64()
                .map_err(|_| PolarsError::ComputeError("cast to f64 failed".into()))?;
            let min_f = ca.min().unwrap_or(f64::NAN);
            let max_f = ca.max().unwrap_or(f64::NAN);
            // PySpark describe/summary returns string type for value columns
            let is_float = matches!(dtype, DataType::Float64 | DataType::Float32);
            let count_s = count.to_string();
            let mean_s = if mean_f.is_nan() {
                "None".to_string()
            } else {
                format!("{:.1}", mean_f)
            };
            let std_s = if std_f.is_nan() {
                "None".to_string()
            } else {
                format!("{:.1}", std_f)
            };
            let min_s = if min_f.is_nan() {
                "None".to_string()
            } else if min_f.fract() == 0.0 && is_float {
                format!("{:.1}", min_f)
            } else if min_f.fract() == 0.0 {
                format!("{:.0}", min_f)
            } else {
                format!("{min_f}")
            };
            let max_s = if max_f.is_nan() {
                "None".to_string()
            } else if max_f.fract() == 0.0 && is_float {
                format!("{:.1}", max_f)
            } else if max_f.fract() == 0.0 {
                format!("{:.0}", max_f)
            } else {
                format!("{max_f}")
            };
            let series = Series::new(
                name,
                [
                    count_s.as_str(),
                    mean_s.as_str(),
                    std_s.as_str(),
                    min_s.as_str(),
                    max_s.as_str(),
                ],
            );
            stat_values.push(series.into());
        }
    }
    if stat_values.is_empty() {
        // No numeric columns: return a minimal describe with just the summary column
        // (PySpark name). No placeholder value column is added, since every column in a
        // Polars DataFrame must match the 5-row summary column's length.
        let stat_col = Series::new(
            "summary".into(),
            &["count", "mean", "stddev", "min", "max"],
        )
        .into();
        let out_pl = polars::prelude::DataFrame::new(vec![stat_col])?;
        return Ok(super::DataFrame::from_polars_with_options(
            out_pl,
            case_sensitive,
        ));
    }
    let summary_col = Series::new(
        "summary".into(),
        &["count", "mean", "stddev", "min", "max"],
    )
    .into();
    let mut cols: Vec<Column> = vec![summary_col];
    cols.extend(stat_values);
    let out_pl = polars::prelude::DataFrame::new(cols)?;
    Ok(super::DataFrame::from_polars_with_options(
        out_pl,
        case_sensitive,
    ))
}
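
// Example for describe (sketch of the output shape): for a frame with Int64 columns
// "id" = [1, 2, 3] and "v" = [10, 20, 30] (as in the tests below), the stats are
// string-typed, one row per statistic:
//
//     // summary | id    | v
//     // count   | "3"   | "3"
//     // mean    | "2.0" | "20.0"
//     // stddev  | "1.0" | "10.0"
//     // min     | "1"   | "10"
//     // max     | "3"   | "30"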

/// Set difference: rows in left that are not in right (by all columns). PySpark subtract / except.
pub fn subtract(
    left: &DataFrame,
    right: &DataFrame,
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    use polars::prelude::*;
    let left_names = left.df.get_column_names();
    let left_on: Vec<Expr> = left_names.iter().map(|n| col(n.as_str())).collect();
    let right_on: Vec<Expr> = left_names.iter().map(|n| col(n.as_str())).collect();
    let right_lf = right.df.as_ref().clone().lazy();
    let left_lf = left.df.as_ref().clone().lazy();
    let anti = left_lf.join(right_lf, left_on, right_on, JoinArgs::new(JoinType::Anti));
    let pl_df = anti.collect()?;
    Ok(super::DataFrame::from_polars_with_options(
        pl_df,
        case_sensitive,
    ))
}

/// Set intersection: rows that appear in both DataFrames (by all columns). PySpark intersect.
pub fn intersect(
    left: &DataFrame,
    right: &DataFrame,
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    use polars::prelude::*;
    let left_names = left.df.get_column_names();
    let left_on: Vec<Expr> = left_names.iter().map(|n| col(n.as_str())).collect();
    let right_on: Vec<Expr> = left_names.iter().map(|n| col(n.as_str())).collect();
    let left_lf = left.df.as_ref().clone().lazy();
    let right_lf = right.df.as_ref().clone().lazy();
    let semi = left_lf.join(right_lf, left_on, right_on, JoinArgs::new(JoinType::Semi));
    let pl_df = semi.collect()?;
    Ok(super::DataFrame::from_polars_with_options(
        pl_df,
        case_sensitive,
    ))
}

// ---------- Batch A: sample, first/head/take/tail, is_empty, to_df ----------

/// Sample a fraction of rows. PySpark sample(withReplacement, fraction, seed).
pub fn sample(
    df: &DataFrame,
    with_replacement: bool,
    fraction: f64,
    seed: Option<u64>,
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    use polars::prelude::Series;
    let n = df.df.height();
    if n == 0 {
        return Ok(super::DataFrame::from_polars_with_options(
            df.df.as_ref().clone(),
            case_sensitive,
        ));
    }
    let take_n = (n as f64 * fraction).round() as usize;
    let take_n = take_n.min(n).max(0);
    if take_n == 0 {
        return Ok(super::DataFrame::from_polars_with_options(
            df.df.as_ref().clone().head(Some(0)),
            case_sensitive,
        ));
    }
    let idx_series = Series::new("idx".into(), (0..n).map(|i| i as u32).collect::<Vec<_>>());
    let sampled_idx = idx_series.sample_n(take_n, with_replacement, true, seed)?;
    let idx_ca = sampled_idx
        .u32()
        .map_err(|_| PolarsError::ComputeError("sample: expected u32 indices".into()))?;
    let pl_df = df.df.as_ref().take(idx_ca)?;
    Ok(super::DataFrame::from_polars_with_options(
        pl_df,
        case_sensitive,
    ))
}
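
// Example for sample (sketch): unlike Spark's Bernoulli sampling, this draws an exact
// number of rows, round(n * fraction) capped at n, by sampling row indices.
//
//     let s = sample(&df, false, 0.5, Some(7), false)?;
//     // on a 3-row frame this returns round(1.5) = 2 rows, order shuffled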

/// Split DataFrame by weights (random split). PySpark randomSplit(weights, seed).
/// Returns one DataFrame per weight; weights are normalized to fractions.
/// Each row is assigned to exactly one split (disjoint partitions).
pub fn random_split(
    df: &DataFrame,
    weights: &[f64],
    seed: Option<u64>,
    case_sensitive: bool,
) -> Result<Vec<DataFrame>, PolarsError> {
    let total: f64 = weights.iter().sum();
    if total <= 0.0 || weights.is_empty() {
        return Ok(Vec::new());
    }
    let n = df.df.height();
    if n == 0 {
        return Ok(weights.iter().map(|_| super::DataFrame::empty()).collect());
    }
    // Normalize weights to cumulative fractions: e.g. [0.25, 0.25, 0.5] -> [0.25, 0.5, 1.0]
    let mut cum = Vec::with_capacity(weights.len());
    let mut acc = 0.0_f64;
    for w in weights {
        acc += w / total;
        cum.push(acc);
    }
    // Assign each row index to one bucket using a single seeded RNG (disjoint split).
    use polars::prelude::Series;
    use rand::Rng;
    use rand::SeedableRng;
    let mut rng = rand::rngs::StdRng::seed_from_u64(seed.unwrap_or(0));
    let mut bucket_indices: Vec<Vec<u32>> = (0..weights.len()).map(|_| Vec::new()).collect();
    for i in 0..n {
        let r: f64 = rng.gen();
        let bucket = cum
            .iter()
            .position(|&c| r < c)
            .unwrap_or(weights.len().saturating_sub(1));
        bucket_indices[bucket].push(i as u32);
    }
    let pl = df.df.as_ref();
    let mut out = Vec::with_capacity(weights.len());
    for indices in bucket_indices {
        if indices.is_empty() {
            out.push(super::DataFrame::from_polars_with_options(
                pl.clone().head(Some(0)),
                case_sensitive,
            ));
        } else {
            let idx_series = Series::new("idx".into(), indices);
            let idx_ca = idx_series.u32().map_err(|_| {
                PolarsError::ComputeError("random_split: expected u32 indices".into())
            })?;
            let taken = pl.take(idx_ca)?;
            out.push(super::DataFrame::from_polars_with_options(
                taken,
                case_sensitive,
            ));
        }
    }
    Ok(out)
}
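
// Example for random_split (sketch): weights [3.0, 1.0] normalize to cumulative cut
// points [0.75, 1.0]; each row draws one uniform value and lands in the first bucket
// whose cut point exceeds it, so the splits are disjoint and cover every row.
//
//     let parts = random_split(&df, &[3.0, 1.0], Some(42), false)?;
//     assert_eq!(parts.len(), 2);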

/// Stratified sample by column value. PySpark sampleBy(col, fractions, seed).
/// fractions: list of (value as Expr literal, fraction to sample for that value).
pub fn sample_by(
    df: &DataFrame,
    col_name: &str,
    fractions: &[(Expr, f64)],
    seed: Option<u64>,
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    use polars::prelude::*;
    if fractions.is_empty() {
        return Ok(super::DataFrame::from_polars_with_options(
            df.df.as_ref().clone().head(Some(0)),
            case_sensitive,
        ));
    }
    let resolved = df.resolve_column_name(col_name)?;
    let mut parts = Vec::with_capacity(fractions.len());
    for (value_expr, frac) in fractions {
        let cond = col(resolved.as_str()).eq(value_expr.clone());
        let filtered = df.df.as_ref().clone().lazy().filter(cond).collect()?;
        if filtered.height() == 0 {
            parts.push(filtered.head(Some(0)));
            continue;
        }
        let sampled = sample(
            &super::DataFrame::from_polars_with_options(filtered, case_sensitive),
            false,
            *frac,
            seed,
            case_sensitive,
        )?;
        parts.push(sampled.df.as_ref().clone());
    }
    let mut out = parts
        .first()
        .ok_or_else(|| PolarsError::ComputeError("sample_by: no parts".into()))?
        .clone();
    for p in parts.iter().skip(1) {
        out.vstack_mut(p)?;
    }
    Ok(super::DataFrame::from_polars_with_options(
        out,
        case_sensitive,
    ))
}

/// First row as a DataFrame (one row). PySpark first().
pub fn first(df: &DataFrame, case_sensitive: bool) -> Result<DataFrame, PolarsError> {
    let pl_df = df.df.as_ref().clone().head(Some(1));
    Ok(super::DataFrame::from_polars_with_options(
        pl_df,
        case_sensitive,
    ))
}

/// First n rows. PySpark head(n). Same as limit.
pub fn head(df: &DataFrame, n: usize, case_sensitive: bool) -> Result<DataFrame, PolarsError> {
    limit(df, n, case_sensitive)
}

/// Take first n rows (alias for limit). PySpark take(n).
pub fn take(df: &DataFrame, n: usize, case_sensitive: bool) -> Result<DataFrame, PolarsError> {
    limit(df, n, case_sensitive)
}

/// Last n rows. PySpark tail(n).
pub fn tail(df: &DataFrame, n: usize, case_sensitive: bool) -> Result<DataFrame, PolarsError> {
    let total = df.df.height();
    let skip = total.saturating_sub(n);
    let pl_df = df.df.as_ref().clone().slice(skip as i64, n);
    Ok(super::DataFrame::from_polars_with_options(
        pl_df,
        case_sensitive,
    ))
}

/// Whether the DataFrame has zero rows. PySpark isEmpty.
pub fn is_empty(df: &DataFrame) -> bool {
    df.df.height() == 0
}

/// Rename columns. PySpark toDF(*colNames). Names must match length of columns.
pub fn to_df(
    df: &DataFrame,
    names: &[&str],
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    let cols = df.df.get_column_names();
    if names.len() != cols.len() {
        return Err(PolarsError::ComputeError(
            format!(
                "toDF: expected {} column names, got {}",
                cols.len(),
                names.len()
            )
            .into(),
        ));
    }
    let mut pl_df = df.df.as_ref().clone();
    for (old, new) in cols.iter().zip(names.iter()) {
        pl_df.rename(old.as_str(), (*new).into())?;
    }
    Ok(super::DataFrame::from_polars_with_options(
        pl_df,
        case_sensitive,
    ))
}

// ---------- Batch B: toJSON, explain, printSchema ----------

fn any_value_to_serde_value(av: &polars::prelude::AnyValue) -> serde_json::Value {
    use polars::prelude::AnyValue;
    use serde_json::Number;
    match av {
        AnyValue::Null => serde_json::Value::Null,
        AnyValue::Boolean(v) => serde_json::Value::Bool(*v),
        AnyValue::Int8(v) => serde_json::Value::Number(Number::from(*v as i64)),
        AnyValue::Int32(v) => serde_json::Value::Number(Number::from(*v)),
        AnyValue::Int64(v) => serde_json::Value::Number(Number::from(*v)),
        AnyValue::UInt32(v) => serde_json::Value::Number(Number::from(*v)),
        AnyValue::Float64(v) => Number::from_f64(*v)
            .map(serde_json::Value::Number)
            .unwrap_or(serde_json::Value::Null),
        AnyValue::String(v) => serde_json::Value::String(v.to_string()),
        _ => serde_json::Value::String(format!("{av:?}")),
    }
}

/// Collect rows as JSON strings (one JSON object per row). PySpark toJSON.
pub fn to_json(df: &DataFrame) -> Result<Vec<String>, PolarsError> {
    use polars::prelude::*;
    let pl = df.df.as_ref();
    let names = pl.get_column_names();
    let mut out = Vec::with_capacity(pl.height());
    for r in 0..pl.height() {
        let mut row = serde_json::Map::new();
        for (i, name) in names.iter().enumerate() {
            let col = pl
                .get_columns()
                .get(i)
                .ok_or_else(|| PolarsError::ComputeError("to_json: column index".into()))?;
            let series = col.as_materialized_series();
            let av = series
                .get(r)
                .map_err(|e| PolarsError::ComputeError(e.to_string().into()))?;
            row.insert(name.to_string(), any_value_to_serde_value(&av));
        }
        out.push(
            serde_json::to_string(&row)
                .map_err(|e| PolarsError::ComputeError(e.to_string().into()))?,
        );
    }
    Ok(out)
}
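
// Example for to_json (sketch): each row becomes one JSON object keyed by column name;
// unsupported AnyValue variants fall back to their Debug string.
//
//     let rows = to_json(&df)?;
//     // rows[0] is something like {"id":1,"label":"a","v":10}; key order depends on
//     // serde_json's map implementation (sorted unless "preserve_order" is enabled)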

/// Return a string describing the execution plan. PySpark explain.
pub fn explain(_df: &DataFrame) -> String {
    "DataFrame (eager Polars backend)".to_string()
}

/// Return schema as a tree string. PySpark printSchema (we return string; caller can print).
pub fn print_schema(df: &DataFrame) -> Result<String, PolarsError> {
    let schema = df.schema()?;
    let mut s = "root\n".to_string();
    for f in schema.fields() {
        let dt = match &f.data_type {
            crate::schema::DataType::String => "string",
            crate::schema::DataType::Integer => "int",
            crate::schema::DataType::Long => "bigint",
            crate::schema::DataType::Double => "double",
            crate::schema::DataType::Boolean => "boolean",
            crate::schema::DataType::Date => "date",
            crate::schema::DataType::Timestamp => "timestamp",
            _ => "string",
        };
        s.push_str(&format!(" |-- {}: {}\n", f.name, dt));
    }
    Ok(s)
}
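
// Example for print_schema (sketch): the output mirrors PySpark's printSchema tree,
// without the "(nullable = ...)" suffix. For the test frame below it would read:
//
//     // root
//     //  |-- id: bigint
//     //  |-- v: bigint
//     //  |-- label: string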

// ---------- Batch D: selectExpr, colRegex, withColumns, withColumnsRenamed, na ----------

/// Select by expression strings. PySpark selectExpr. Minimal support: each entry is a plain column name; an " as alias" suffix is accepted but the alias is currently ignored.
pub fn select_expr(
    df: &DataFrame,
    exprs: &[String],
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    let mut cols = Vec::new();
    for e in exprs {
        let e = e.trim();
        if let Some((left, right)) = e.split_once(" as ") {
            let col_name = left.trim();
            let _alias = right.trim();
            cols.push(df.resolve_column_name(col_name)?);
        } else {
            cols.push(df.resolve_column_name(e)?);
        }
    }
    let refs: Vec<&str> = cols.iter().map(|s| s.as_str()).collect();
    select(df, refs, case_sensitive)
}
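
// Example for select_expr (sketch): expressions beyond bare names are not parsed, so
// arithmetic like "v * 2" fails name resolution; an alias selects the underlying
// column unchanged.
//
//     let out = select_expr(&df, &["id".to_string(), "v as value".to_string()], false)?;
//     // out has columns ["id", "v"]; the "value" alias is dropped by this minimal parser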

/// Select columns whose names match the regex pattern. PySpark colRegex.
pub fn col_regex(
    df: &DataFrame,
    pattern: &str,
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    let re = regex::Regex::new(pattern).map_err(|e| {
        PolarsError::ComputeError(format!("colRegex: invalid pattern {pattern:?}: {e}").into())
    })?;
    let names = df.df.get_column_names();
    let matched: Vec<&str> = names
        .iter()
        .filter(|n| re.is_match(n.as_str()))
        .map(|s| s.as_str())
        .collect();
    if matched.is_empty() {
        return Err(PolarsError::ComputeError(
            format!("colRegex: no columns matched pattern {pattern:?}").into(),
        ));
    }
    select(df, matched, case_sensitive)
}

/// Add or replace multiple columns. PySpark withColumns. Uses Column so deferred rand/randn get per-row values.
pub fn with_columns(
    df: &DataFrame,
    exprs: &[(String, crate::column::Column)],
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    let mut current =
        super::DataFrame::from_polars_with_options(df.df.as_ref().clone(), case_sensitive);
    for (name, col) in exprs {
        current = with_column(&current, name, col, case_sensitive)?;
    }
    Ok(current)
}

/// Rename multiple columns. PySpark withColumnsRenamed.
pub fn with_columns_renamed(
    df: &DataFrame,
    renames: &[(String, String)],
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    let mut out = df.df.as_ref().clone();
    for (old_name, new_name) in renames {
        let resolved = df.resolve_column_name(old_name)?;
        out.rename(resolved.as_str(), new_name.as_str().into())?;
    }
    Ok(super::DataFrame::from_polars_with_options(
        out,
        case_sensitive,
    ))
}

/// NA sub-API builder. PySpark df.na().fill(...) / .drop(...).
pub struct DataFrameNa<'a> {
    pub(crate) df: &'a DataFrame,
}

impl<'a> DataFrameNa<'a> {
    /// Fill nulls with the given value. PySpark na.fill(value).
    pub fn fill(&self, value: Expr) -> Result<DataFrame, PolarsError> {
        fillna(self.df, value, self.df.case_sensitive)
    }

    /// Drop rows with nulls in the given columns (or all). PySpark na.drop(subset).
    pub fn drop(&self, subset: Option<Vec<&str>>) -> Result<DataFrame, PolarsError> {
        dropna(self.df, subset, self.df.case_sensitive)
    }
}

// ---------- Batch E: offset, transform, freqItems, approxQuantile, crosstab, melt, exceptAll, intersectAll ----------

/// Skip first n rows. PySpark offset(n).
pub fn offset(df: &DataFrame, n: usize, case_sensitive: bool) -> Result<DataFrame, PolarsError> {
    let total = df.df.height();
    let len = total.saturating_sub(n);
    let pl_df = df.df.as_ref().clone().slice(n as i64, len);
    Ok(super::DataFrame::from_polars_with_options(
        pl_df,
        case_sensitive,
    ))
}

/// Transform DataFrame by a function. PySpark transform(func).
pub fn transform<F>(df: &DataFrame, f: F) -> Result<DataFrame, PolarsError>
where
    F: FnOnce(DataFrame) -> Result<DataFrame, PolarsError>,
{
    let df_out = f(df.clone())?;
    Ok(df_out)
}

/// Frequent items. PySpark freqItems. Returns one row with columns {col}_freqItems (array of values with frequency >= support).
pub fn freq_items(
    df: &DataFrame,
    columns: &[&str],
    support: f64,
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    use polars::prelude::SeriesMethods;
    if columns.is_empty() {
        return Ok(super::DataFrame::from_polars_with_options(
            df.df.as_ref().clone().head(Some(0)),
            case_sensitive,
        ));
    }
    let support = support.clamp(1e-4, 1.0);
    let pl_df = df.df.as_ref();
    let n_total = pl_df.height() as f64;
    if n_total == 0.0 {
        let mut out = Vec::with_capacity(columns.len());
        for col_name in columns {
            let resolved = df.resolve_column_name(col_name)?;
            let s = pl_df
                .column(resolved.as_str())?
                .as_series()
                .ok_or_else(|| PolarsError::ComputeError("column not a series".into()))?
                .clone();
            let empty_sub = s.head(Some(0));
            let list_chunked = polars::prelude::ListChunked::from_iter([empty_sub].into_iter())
                .with_name(format!("{resolved}_freqItems").into());
            out.push(list_chunked.into_series().into());
        }
        return Ok(super::DataFrame::from_polars_with_options(
            polars::prelude::DataFrame::new(out)?,
            case_sensitive,
        ));
    }
    let mut out_series = Vec::with_capacity(columns.len());
    for col_name in columns {
        let resolved = df.resolve_column_name(col_name)?;
        let s = pl_df
            .column(resolved.as_str())?
            .as_series()
            .ok_or_else(|| PolarsError::ComputeError("column not a series".into()))?
            .clone();
        let vc = s.value_counts(false, false, "counts".into(), false)?;
        let count_col = vc
            .column("counts")
            .map_err(|_| PolarsError::ComputeError("value_counts missing counts column".into()))?;
        let counts = count_col
            .u32()
            .map_err(|_| PolarsError::ComputeError("freq_items: counts column not u32".into()))?;
        let value_col_name = s.name();
        let values_col = vc
            .column(value_col_name.as_str())
            .map_err(|_| PolarsError::ComputeError("value_counts missing value column".into()))?;
        let threshold = (support * n_total).ceil() as u32;
        let indices: Vec<u32> = counts
            .into_iter()
            .enumerate()
            .filter_map(|(i, c)| {
                if c? >= threshold {
                    Some(i as u32)
                } else {
                    None
                }
            })
            .collect();
        let idx_series = Series::new("idx".into(), indices);
        let idx_ca = idx_series
            .u32()
            .map_err(|_| PolarsError::ComputeError("freq_items: index series not u32".into()))?;
        let values_series = values_col
            .as_series()
            .ok_or_else(|| PolarsError::ComputeError("value column not a series".into()))?;
        let filtered = values_series.take(idx_ca)?;
        let list_chunked = polars::prelude::ListChunked::from_iter([filtered].into_iter())
            .with_name(format!("{resolved}_freqItems").into());
        let list_row = list_chunked.into_series();
        out_series.push(list_row.into());
    }
    let out_df = polars::prelude::DataFrame::new(out_series)?;
    Ok(super::DataFrame::from_polars_with_options(
        out_df,
        case_sensitive,
    ))
}
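
// Example for freq_items (sketch): the cutoff is ceil(support * row_count) occurrences,
// so on a 10-row column with support 0.25 a value must appear at least 3 times to be
// reported. The result is a single row; each requested column yields one list column
// named "{col}_freqItems".
//
//     let fi = freq_items(&df, &["label"], 0.25, false)?;
//     // fi has one row and one column "label_freqItems"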

/// Approximate quantiles. PySpark approxQuantile. Returns one column "quantile" with one row per probability.
pub fn approx_quantile(
    df: &DataFrame,
    column: &str,
    probabilities: &[f64],
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    use polars::prelude::{ChunkQuantile, QuantileMethod};
    if probabilities.is_empty() {
        return Ok(super::DataFrame::from_polars_with_options(
            polars::prelude::DataFrame::new(vec![Series::new(
                "quantile".into(),
                Vec::<f64>::new(),
            )
            .into()])?,
            case_sensitive,
        ));
    }
    let resolved = df.resolve_column_name(column)?;
    let s = df
        .df
        .as_ref()
        .column(resolved.as_str())?
        .as_series()
        .ok_or_else(|| PolarsError::ComputeError("approx_quantile: column not a series".into()))?
        .clone();
    let s_f64 = s.cast(&polars::prelude::DataType::Float64)?;
    let ca = s_f64
        .f64()
        .map_err(|_| PolarsError::ComputeError("approx_quantile: need numeric column".into()))?;
    let mut quantiles = Vec::with_capacity(probabilities.len());
    for &p in probabilities {
        let q = ca.quantile(p, QuantileMethod::Linear)?;
        quantiles.push(q.unwrap_or(f64::NAN));
    }
    let out_df =
        polars::prelude::DataFrame::new(vec![Series::new("quantile".into(), quantiles).into()])?;
    Ok(super::DataFrame::from_polars_with_options(
        out_df,
        case_sensitive,
    ))
}

/// Cross-tabulation. PySpark crosstab. Returns long format (col1, col2, count); for wide format use pivot on the result.
pub fn crosstab(
    df: &DataFrame,
    col1: &str,
    col2: &str,
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    use polars::prelude::*;
    let c1 = df.resolve_column_name(col1)?;
    let c2 = df.resolve_column_name(col2)?;
    let pl_df = df.df.as_ref();
    let grouped = pl_df
        .clone()
        .lazy()
        .group_by([col(c1.as_str()), col(c2.as_str())])
        .agg([len().alias("count")])
        .collect()?;
    Ok(super::DataFrame::from_polars_with_options(
        grouped,
        case_sensitive,
    ))
}
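
// Example for crosstab (sketch): unlike PySpark's wide crosstab, this returns one row
// per (col1, col2) pair with its count; pivoting into the wide layout is left to the
// caller.
//
//     let ct = crosstab(&df, "label", "v", false)?;
//     // columns: ["label", "v", "count"], one row per distinct (label, v) combination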

/// Unpivot (melt). PySpark melt. Long format with id_vars kept, plus "variable" and "value" columns.
pub fn melt(
    df: &DataFrame,
    id_vars: &[&str],
    value_vars: &[&str],
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    use polars::prelude::*;
    let pl_df = df.df.as_ref();
    if value_vars.is_empty() {
        return Ok(super::DataFrame::from_polars_with_options(
            pl_df.head(Some(0)),
            case_sensitive,
        ));
    }
    let id_resolved: Vec<String> = id_vars
        .iter()
        .map(|s| df.resolve_column_name(s).map(|r| r.to_string()))
        .collect::<Result<Vec<_>, _>>()?;
    let value_resolved: Vec<String> = value_vars
        .iter()
        .map(|s| df.resolve_column_name(s).map(|r| r.to_string()))
        .collect::<Result<Vec<_>, _>>()?;
    let mut parts = Vec::with_capacity(value_vars.len());
    for vname in &value_resolved {
        let select_cols: Vec<&str> = id_resolved
            .iter()
            .map(|s| s.as_str())
            .chain([vname.as_str()])
            .collect();
        let mut part = pl_df.select(select_cols)?;
        let var_series = Series::new("variable".into(), vec![vname.as_str(); part.height()]);
        part.with_column(var_series)?;
        part.rename(vname.as_str(), "value".into())?;
        parts.push(part);
    }
    let mut out = parts
        .first()
        .ok_or_else(|| PolarsError::ComputeError("melt: no value columns".into()))?
        .clone();
    for p in parts.iter().skip(1) {
        out.vstack_mut(p)?;
    }
    let col_order: Vec<&str> = id_resolved
        .iter()
        .map(|s| s.as_str())
        .chain(["variable", "value"])
        .collect();
    let out = out.select(col_order)?;
    Ok(super::DataFrame::from_polars_with_options(
        out,
        case_sensitive,
    ))
}
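
// Example for melt (sketch): each value column contributes one stacked block, so
// melting two value columns over a 3-row frame yields 6 rows in long format.
//
//     let long = melt(&df, &["label"], &["id", "v"], false)?;
//     // columns: ["label", "variable", "value"]; "variable" holds "id" or "v"
//     // note: the blocks are vertically stacked, so the value columns must share a dtype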

/// Set difference keeping duplicates. PySpark exceptAll. Simple impl: same as subtract.
pub fn except_all(
    left: &DataFrame,
    right: &DataFrame,
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    subtract(left, right, case_sensitive)
}

/// Set intersection keeping duplicates. PySpark intersectAll. Simple impl: same as intersect.
pub fn intersect_all(
    left: &DataFrame,
    right: &DataFrame,
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    intersect(left, right, case_sensitive)
}

#[cfg(test)]
mod tests {
    use super::{distinct, drop, dropna, first, head, limit, offset};
    use crate::{DataFrame, SparkSession};

    fn test_df() -> DataFrame {
        let spark = SparkSession::builder()
            .app_name("transform_tests")
            .get_or_create();
        spark
            .create_dataframe(
                vec![
                    (1i64, 10i64, "a".to_string()),
                    (2i64, 20i64, "b".to_string()),
                    (3i64, 30i64, "c".to_string()),
                ],
                vec!["id", "v", "label"],
            )
            .unwrap()
    }

    #[test]
    fn limit_zero() {
        let df = test_df();
        let out = limit(&df, 0, false).unwrap();
        assert_eq!(out.count().unwrap(), 0);
    }

    #[test]
    fn limit_more_than_rows() {
        let df = test_df();
        let out = limit(&df, 10, false).unwrap();
        assert_eq!(out.count().unwrap(), 3);
    }

    #[test]
    fn distinct_on_empty() {
        let spark = SparkSession::builder()
            .app_name("transform_tests")
            .get_or_create();
        let df = spark
            .create_dataframe(vec![] as Vec<(i64, i64, String)>, vec!["a", "b", "c"])
            .unwrap();
        let out = distinct(&df, None, false).unwrap();
        assert_eq!(out.count().unwrap(), 0);
    }

    #[test]
    fn first_returns_one_row() {
        let df = test_df();
        let out = first(&df, false).unwrap();
        assert_eq!(out.count().unwrap(), 1);
    }

    #[test]
    fn head_n() {
        let df = test_df();
        let out = head(&df, 2, false).unwrap();
        assert_eq!(out.count().unwrap(), 2);
    }

    #[test]
    fn offset_skip_first() {
        let df = test_df();
        let out = offset(&df, 1, false).unwrap();
        assert_eq!(out.count().unwrap(), 2);
    }

    #[test]
    fn offset_beyond_length_returns_empty() {
        let df = test_df();
        let out = offset(&df, 10, false).unwrap();
        assert_eq!(out.count().unwrap(), 0);
    }

    #[test]
    fn drop_column() {
        let df = test_df();
        let out = drop(&df, vec!["v"], false).unwrap();
        let cols = out.columns().unwrap();
        assert!(!cols.contains(&"v".to_string()));
        assert_eq!(out.count().unwrap(), 3);
    }

    #[test]
    fn dropna_all_columns() {
        let df = test_df();
        let out = dropna(&df, None, false).unwrap();
        assert_eq!(out.count().unwrap(), 3);
    }
}