1use super::DataFrame;
7use crate::functions::SortOrder;
8use crate::type_coercion::{coerce_expr_pair, find_common_type};
9use polars::prelude::{
10 DataType, Expr, Float64Chunked, IntoLazy, IntoSeries, NamedFrom, PlSmallStr, PolarsError,
11 Selector, Series, UnionArgs, UniqueKeepStrategy, col,
12};
13use std::collections::HashMap;
14
15fn series_as_f64_ca(s: &Series, context: &str) -> Result<Float64Chunked, PolarsError> {
16 let s_f64 = s.cast(&DataType::Float64)?;
17 let ca = s_f64.f64().map_err(|_| {
18 PolarsError::ComputeError(format!("{}: need numeric/f64 column", context).into())
19 })?;
20 Ok(ca.clone())
21}
22use std::sync::Arc;
23
24pub fn select(
26 df: &DataFrame,
27 cols: Vec<&str>,
28 case_sensitive: bool,
29) -> Result<DataFrame, PolarsError> {
30 let resolved: Vec<String> = cols
31 .iter()
32 .map(|c| df.resolve_column_name(c))
33 .collect::<Result<Vec<_>, _>>()?;
34 let exprs: Vec<Expr> = resolved.iter().map(|s| col(s.as_str())).collect();
35 let lf = df.lazy_frame().select(&exprs);
36 Ok(super::DataFrame::from_lazy_with_options(lf, case_sensitive))
37}
38
39pub fn select_with_exprs(
43 df: &DataFrame,
44 exprs: Vec<Expr>,
45 case_sensitive: bool,
46) -> Result<DataFrame, PolarsError> {
47 let exprs: Vec<Expr> = exprs
48 .into_iter()
49 .map(|e| df.resolve_expr_column_names(e))
50 .collect::<Result<Vec<_>, _>>()?;
51 let mut name_count: HashMap<String, u32> = HashMap::new();
52 let exprs: Vec<Expr> = exprs
53 .into_iter()
54 .map(|e| {
55 let base_name = polars_plan::utils::expr_output_name(&e)
56 .map(|s| s.to_string())
57 .unwrap_or_else(|_| "_".to_string());
58 let count = name_count.entry(base_name.clone()).or_insert(0);
59 *count += 1;
60 let final_name = if *count == 1 {
61 base_name
62 } else {
63 format!("{}_{}", base_name, *count - 1)
64 };
65 if *count == 1 {
66 e
67 } else {
68 e.alias(final_name.as_str())
69 }
70 })
71 .collect();
72 let lf = df.lazy_frame().select(&exprs);
73 Ok(super::DataFrame::from_lazy_with_options(lf, case_sensitive))
74}
75
76#[derive(Clone)]
79pub enum SelectItem<'a> {
80 ColumnName(&'a str),
82 Expr(Expr),
84}
85
86pub fn select_items(
88 df: &DataFrame,
89 items: Vec<SelectItem<'_>>,
90 case_sensitive: bool,
91) -> Result<DataFrame, PolarsError> {
92 let mut exprs = Vec::with_capacity(items.len());
93 for item in items {
94 match item {
95 SelectItem::ColumnName(name) => {
96 let resolved = df.resolve_column_name(name)?;
97 exprs.push(col(resolved));
98 }
99 SelectItem::Expr(e) => {
100 let resolved = df.resolve_expr_column_names(e)?;
101 exprs.push(resolved);
102 }
103 }
104 }
105 select_with_exprs(df, exprs, case_sensitive)
106}
107
108pub fn filter(
111 df: &DataFrame,
112 condition: Expr,
113 case_sensitive: bool,
114) -> Result<DataFrame, PolarsError> {
115 let condition = df.resolve_expr_column_names(condition)?;
116 let condition = df.coerce_string_numeric_comparisons(condition)?;
117 let lf = df.lazy_frame().filter(condition);
118 Ok(super::DataFrame::from_lazy_with_options(lf, case_sensitive))
119}
120
121pub fn with_column(
123 df: &DataFrame,
124 column_name: &str,
125 column: &crate::column::Column,
126 case_sensitive: bool,
127) -> Result<DataFrame, PolarsError> {
128 if let Some(deferred) = column.deferred {
130 match deferred {
131 crate::column::DeferredRandom::Rand(seed) => {
132 let pl_df = df.collect_inner()?;
133 let mut pl_df = pl_df.as_ref().clone();
134 let n = pl_df.height();
135 let series = crate::udfs::series_rand_n(column_name, n, seed);
136 pl_df.with_column(series.into())?;
137 return Ok(super::DataFrame::from_polars_with_options(
138 pl_df,
139 case_sensitive,
140 ));
141 }
142 crate::column::DeferredRandom::Randn(seed) => {
143 let pl_df = df.collect_inner()?;
144 let mut pl_df = pl_df.as_ref().clone();
145 let n = pl_df.height();
146 let series = crate::udfs::series_randn_n(column_name, n, seed);
147 pl_df.with_column(series.into())?;
148 return Ok(super::DataFrame::from_polars_with_options(
149 pl_df,
150 case_sensitive,
151 ));
152 }
153 }
154 }
155 let expr = df.resolve_expr_column_names(column.expr().clone())?;
156 let expr = df.coerce_string_numeric_comparisons(expr)?;
157 let lf = df.lazy_frame().with_column(expr.alias(column_name));
158 Ok(super::DataFrame::from_lazy_with_options(lf, case_sensitive))
159}
160
161pub fn order_by(
163 df: &DataFrame,
164 column_names: Vec<&str>,
165 ascending: Vec<bool>,
166 case_sensitive: bool,
167) -> Result<DataFrame, PolarsError> {
168 use polars::prelude::*;
169 let mut asc = ascending;
170 while asc.len() < column_names.len() {
171 asc.push(true);
172 }
173 asc.truncate(column_names.len());
174 let resolved: Vec<String> = column_names
175 .iter()
176 .map(|c| df.resolve_column_name(c))
177 .collect::<Result<Vec<_>, _>>()?;
178 let exprs: Vec<Expr> = resolved.iter().map(|s| col(s.as_str())).collect();
179 let descending: Vec<bool> = asc.iter().map(|&a| !a).collect();
180 let nulls_last: Vec<bool> = descending.clone();
182 let lf = df.lazy_frame().sort_by_exprs(
183 exprs,
184 SortMultipleOptions::new()
185 .with_order_descending_multi(descending)
186 .with_nulls_last_multi(nulls_last),
187 );
188 Ok(super::DataFrame::from_lazy_with_options(lf, case_sensitive))
189}
190
191pub fn order_by_exprs(
194 df: &DataFrame,
195 sort_orders: Vec<SortOrder>,
196 case_sensitive: bool,
197) -> Result<DataFrame, PolarsError> {
198 use polars::prelude::*;
199 if sort_orders.is_empty() {
200 return Ok(super::DataFrame::from_lazy_with_options(
201 df.lazy_frame(),
202 case_sensitive,
203 ));
204 }
205 let exprs: Vec<Expr> = sort_orders
206 .iter()
207 .map(|s| df.resolve_expr_column_names(s.expr().clone()))
208 .collect::<Result<Vec<_>, _>>()?;
209 let descending: Vec<bool> = sort_orders.iter().map(|s| s.descending).collect();
210 let nulls_last: Vec<bool> = sort_orders.iter().map(|s| s.nulls_last).collect();
211 let opts = SortMultipleOptions::new()
212 .with_order_descending_multi(descending)
213 .with_nulls_last_multi(nulls_last);
214 let lf = df.lazy_frame().sort_by_exprs(exprs, opts);
215 Ok(super::DataFrame::from_lazy_with_options(lf, case_sensitive))
216}
217
218pub fn union(
221 left: &DataFrame,
222 right: &DataFrame,
223 case_sensitive: bool,
224) -> Result<DataFrame, PolarsError> {
225 let left_names = left.columns()?;
226 let right_names = right.columns()?;
227 if left_names != right_names {
228 return Err(PolarsError::InvalidOperation(
229 format!(
230 "union: column order/names must match. Left: {:?}, Right: {:?}",
231 left_names, right_names
232 )
233 .into(),
234 ));
235 }
236 let mut left_exprs: Vec<Expr> = Vec::with_capacity(left_names.len());
237 let mut right_exprs: Vec<Expr> = Vec::with_capacity(right_names.len());
238 for name in &left_names {
239 let resolved_left = left.resolve_column_name(name)?;
240 let resolved_right = right.resolve_column_name(name)?;
241 let left_dtype = left.get_column_dtype(name).unwrap_or(DataType::Null);
242 let right_dtype = right.get_column_dtype(name).unwrap_or(DataType::Null);
243 let target = if left_dtype == DataType::Null {
244 right_dtype.clone()
245 } else if right_dtype == DataType::Null || left_dtype == right_dtype {
246 left_dtype.clone()
247 } else {
248 find_common_type(&left_dtype, &right_dtype)?
249 };
250 let left_expr = if left_dtype == target {
251 col(resolved_left.as_str())
252 } else {
253 col(resolved_left.as_str()).cast(target.clone())
254 };
255 let right_expr = if right_dtype == target {
256 col(resolved_right.as_str())
257 } else {
258 col(resolved_right.as_str()).cast(target)
259 };
260 left_exprs.push(left_expr.alias(name.as_str()));
261 right_exprs.push(right_expr.alias(name.as_str()));
262 }
263 let lf1 = left.lazy_frame().select(&left_exprs);
264 let lf2 = right.lazy_frame().select(&right_exprs);
265 let out = polars::prelude::concat([lf1, lf2], UnionArgs::default())?;
266 Ok(super::DataFrame::from_lazy_with_options(
267 out,
268 case_sensitive,
269 ))
270}
271
272pub fn union_by_name(
277 left: &DataFrame,
278 right: &DataFrame,
279 allow_missing_columns: bool,
280 case_sensitive: bool,
281) -> Result<DataFrame, PolarsError> {
282 use crate::type_coercion::find_common_type;
283 use polars::prelude::*;
284
285 let left_names = left.columns()?;
286 let right_names = right.columns()?;
287 let contains = |names: &[String], name: &str| -> bool {
288 if case_sensitive {
289 names.iter().any(|n| n.as_str() == name)
290 } else {
291 let name_lower = name.to_lowercase();
292 names
293 .iter()
294 .any(|n| n.as_str().to_lowercase() == name_lower)
295 }
296 };
297 let resolve = |names: &[String], name: &str| -> Option<String> {
298 if case_sensitive {
299 names.iter().find(|n| n.as_str() == name).cloned()
300 } else {
301 let name_lower = name.to_lowercase();
302 names
303 .iter()
304 .find(|n| n.as_str().to_lowercase() == name_lower)
305 .cloned()
306 }
307 };
308 let all_columns: Vec<String> = if allow_missing_columns {
309 let mut out = left_names.clone();
310 for r in &right_names {
311 if !contains(&out, r.as_str()) {
312 out.push(r.clone());
313 }
314 }
315 out
316 } else {
317 left_names.clone()
318 };
319 let mut left_exprs: Vec<Expr> = Vec::with_capacity(all_columns.len());
321 let mut right_exprs: Vec<Expr> = Vec::with_capacity(all_columns.len());
322 for c in &all_columns {
323 let left_has = resolve(&left_names, c.as_str());
324 let right_has = resolve(&right_names, c.as_str());
325 let left_dtype = left_has.as_ref().and_then(|r| left.get_column_dtype(r));
326 let right_dtype = right_has.as_ref().and_then(|r| right.get_column_dtype(r));
327 if let (Some(l), Some(r)) = (&left_has, &right_has) {
329 if let (Some(lt), Some(rt)) = (&left_dtype, &right_dtype) {
330 if lt != rt {
331 let (le, re) = coerce_expr_pair(l, r, lt, rt, c).map_err(|e| {
332 PolarsError::ComputeError(
333 format!("union_by_name: column '{}' type coercion: {}", c, e).into(),
334 )
335 })?;
336 left_exprs.push(le);
337 right_exprs.push(re);
338 continue;
339 }
340 }
341 }
342 let common_dtype = match (&left_dtype, &right_dtype) {
345 (Some(lt), Some(rt)) if lt != rt => find_common_type(lt, rt).map_err(|e| {
346 PolarsError::ComputeError(
347 format!("union_by_name: column '{}' type coercion: {}", c, e).into(),
348 )
349 })?,
350 (Some(lt), Some(_)) => lt.clone(),
351 (Some(lt), None) | (None, Some(lt)) => {
352 if lt == &polars::prelude::DataType::String {
354 lt.clone()
355 } else {
356 polars::prelude::DataType::String
357 }
358 }
359 (None, None) => polars::prelude::DataType::Null,
360 };
361 let left_expr = match &left_has {
362 Some(r) => col(r.as_str()).cast(common_dtype.clone()).alias(c.as_str()),
363 None => polars::prelude::lit(polars::prelude::NULL)
364 .cast(common_dtype.clone())
365 .alias(c.as_str()),
366 };
367 left_exprs.push(left_expr);
368 let right_expr = match &right_has {
369 Some(r) => col(r.as_str()).cast(common_dtype.clone()).alias(c.as_str()),
370 None if allow_missing_columns => polars::prelude::lit(polars::prelude::NULL)
371 .cast(common_dtype)
372 .alias(c.as_str()),
373 None => {
374 return Err(PolarsError::InvalidOperation(
375 format!(
376 "union_by_name: column '{}' missing in right DataFrame (allow_missing_columns=False)",
377 c
378 )
379 .into(),
380 ));
381 }
382 };
383 right_exprs.push(right_expr);
384 }
385 let lf1 = left.lazy_frame().select(&left_exprs);
386 let lf2 = right.lazy_frame().select(&right_exprs);
387 let out = polars::prelude::concat([lf1, lf2], UnionArgs::default())?;
388 Ok(super::DataFrame::from_lazy_with_options(
389 out,
390 case_sensitive,
391 ))
392}
393
394pub fn distinct(
396 df: &DataFrame,
397 subset: Option<Vec<&str>>,
398 case_sensitive: bool,
399) -> Result<DataFrame, PolarsError> {
400 let subset_names: Option<Vec<String>> = subset
401 .map(|cols| {
402 cols.iter()
403 .map(|s| df.resolve_column_name(s))
404 .collect::<Result<Vec<_>, _>>()
405 })
406 .transpose()?;
407 let subset_selector: Option<Selector> = subset_names.map(|names| Selector::ByName {
408 names: Arc::from(names.into_iter().map(PlSmallStr::from).collect::<Vec<_>>()),
409 strict: false,
410 });
411 let lf = df
412 .lazy_frame()
413 .unique(subset_selector, UniqueKeepStrategy::First);
414 Ok(super::DataFrame::from_lazy_with_options(lf, case_sensitive))
415}
416
417pub fn drop(
419 df: &DataFrame,
420 columns: Vec<&str>,
421 case_sensitive: bool,
422) -> Result<DataFrame, PolarsError> {
423 let resolved: Vec<String> = columns
424 .iter()
425 .map(|c| df.resolve_column_name(c))
426 .collect::<Result<Vec<_>, _>>()?;
427 let all_names = df.columns()?;
428 let to_keep: Vec<Expr> = all_names
429 .iter()
430 .filter(|n| !resolved.iter().any(|r| r == n.as_str()))
431 .map(|n| col(n.as_str()))
432 .collect();
433 let lf = df.lazy_frame().select(&to_keep);
434 Ok(super::DataFrame::from_lazy_with_options(lf, case_sensitive))
435}
436
437pub fn dropna(
441 df: &DataFrame,
442 subset: Option<Vec<&str>>,
443 how: &str,
444 thresh: Option<usize>,
445 case_sensitive: bool,
446) -> Result<DataFrame, PolarsError> {
447 use polars::prelude::*;
448 let cols: Vec<String> = match &subset {
449 Some(c) => c
450 .iter()
451 .map(|n| df.resolve_column_name(n))
452 .collect::<Result<Vec<_>, _>>()?,
453 None => df.columns()?,
454 };
455 let col_exprs: Vec<Expr> = cols.iter().map(|c| col(c.as_str())).collect();
456 let base_lf = df.lazy_frame();
457 let lf = if let Some(n) = thresh {
458 let count_expr: Expr = col_exprs
460 .iter()
461 .map(|e| e.clone().is_not_null().cast(DataType::Int32))
462 .fold(lit(0i32), |a, b| a + b);
463 base_lf.filter(count_expr.gt_eq(lit(n as i32)))
464 } else if how.eq_ignore_ascii_case("all") {
465 let any_not_null: Expr = col_exprs
467 .into_iter()
468 .map(|e| e.is_not_null())
469 .fold(lit(false), |a, b| a.or(b));
470 base_lf.filter(any_not_null)
471 } else {
472 let subset_selector = Selector::ByName {
474 names: Arc::from(
475 cols.iter()
476 .map(|s| PlSmallStr::from(s.as_str()))
477 .collect::<Vec<_>>(),
478 ),
479 strict: false,
480 };
481 base_lf.drop_nulls(Some(subset_selector))
482 };
483 Ok(super::DataFrame::from_lazy_with_options(lf, case_sensitive))
484}
485
486pub fn fillna(
489 df: &DataFrame,
490 value_expr: Expr,
491 subset: Option<Vec<&str>>,
492 case_sensitive: bool,
493) -> Result<DataFrame, PolarsError> {
494 use polars::prelude::*;
495 let exprs: Vec<Expr> = match subset {
496 Some(cols) => cols
497 .iter()
498 .map(|n| {
499 let resolved = df.resolve_column_name(n)?;
500 Ok(col(resolved.as_str()).fill_null(value_expr.clone()))
501 })
502 .collect::<Result<Vec<_>, PolarsError>>()?,
503 None => df
504 .columns()?
505 .iter()
506 .map(|n| col(n.as_str()).fill_null(value_expr.clone()))
507 .collect(),
508 };
509 let lf = df.lazy_frame().with_columns(exprs);
510 Ok(super::DataFrame::from_lazy_with_options(lf, case_sensitive))
511}
512
513pub fn limit(df: &DataFrame, n: usize, case_sensitive: bool) -> Result<DataFrame, PolarsError> {
515 let lf = df.lazy_frame().slice(0, n as u32);
517 Ok(super::DataFrame::from_lazy_with_options(lf, case_sensitive))
518}
519
520pub fn with_column_renamed(
522 df: &DataFrame,
523 old_name: &str,
524 new_name: &str,
525 case_sensitive: bool,
526) -> Result<DataFrame, PolarsError> {
527 let resolved = df.resolve_column_name(old_name)?;
528 let lf = df
529 .lazy_frame()
530 .rename([resolved.as_str()], [new_name], true);
531 Ok(super::DataFrame::from_lazy_with_options(lf, case_sensitive))
532}
533
534pub fn replace(
536 df: &DataFrame,
537 column_name: &str,
538 old_value: Expr,
539 new_value: Expr,
540 case_sensitive: bool,
541) -> Result<DataFrame, PolarsError> {
542 use polars::prelude::*;
543 let resolved = df.resolve_column_name(column_name)?;
544 let repl = when(col(resolved.as_str()).eq(old_value))
545 .then(new_value)
546 .otherwise(col(resolved.as_str()));
547 let lf = df.lazy_frame().with_column(repl.alias(resolved.as_str()));
548 Ok(super::DataFrame::from_lazy_with_options(lf, case_sensitive))
549}
550
551pub fn cross_join(
553 left: &DataFrame,
554 right: &DataFrame,
555 case_sensitive: bool,
556) -> Result<DataFrame, PolarsError> {
557 let lf_left = left.lazy_frame();
558 let lf_right = right.lazy_frame();
559 let out = lf_left.cross_join(lf_right, None);
560 Ok(super::DataFrame::from_lazy_with_options(
561 out,
562 case_sensitive,
563 ))
564}
565
566pub fn describe(df: &DataFrame, case_sensitive: bool) -> Result<DataFrame, PolarsError> {
569 use polars::prelude::*;
570 let pl_df = df.collect_inner()?.as_ref().clone();
571 let mut stat_values: Vec<Column> = Vec::new();
572 for col in pl_df.columns() {
573 let s = col.as_materialized_series();
574 let dtype = s.dtype();
575 if dtype.is_numeric() {
576 let name = s.name().clone();
577 let count = s.len() as i64 - s.null_count() as i64;
578 let mean_f = s.mean().unwrap_or(f64::NAN);
579 let std_f = s.std(1).unwrap_or(f64::NAN);
580 let ca = series_as_f64_ca(s, "describe")?;
581 let min_f = ca.min().unwrap_or(f64::NAN);
582 let max_f = ca.max().unwrap_or(f64::NAN);
583 let is_float = matches!(dtype, DataType::Float64 | DataType::Float32);
585 let count_s = count.to_string();
586 let mean_s = if mean_f.is_nan() {
587 "None".to_string()
588 } else {
589 format!("{:.1}", mean_f)
590 };
591 let std_s = if std_f.is_nan() {
592 "None".to_string()
593 } else {
594 format!("{:.1}", std_f)
595 };
596 let min_s = if min_f.is_nan() {
597 "None".to_string()
598 } else if min_f.fract() == 0.0 && is_float {
599 format!("{:.1}", min_f)
600 } else if min_f.fract() == 0.0 {
601 format!("{:.0}", min_f)
602 } else {
603 format!("{min_f}")
604 };
605 let max_s = if max_f.is_nan() {
606 "None".to_string()
607 } else if max_f.fract() == 0.0 && is_float {
608 format!("{:.1}", max_f)
609 } else if max_f.fract() == 0.0 {
610 format!("{:.0}", max_f)
611 } else {
612 format!("{max_f}")
613 };
614 let series = Series::new(
615 name,
616 [
617 count_s.as_str(),
618 mean_s.as_str(),
619 std_s.as_str(),
620 min_s.as_str(),
621 max_s.as_str(),
622 ],
623 );
624 stat_values.push(series.into());
625 }
626 }
627 if stat_values.is_empty() {
628 let stat_col = Series::new(
630 "summary".into(),
631 &["count", "mean", "stddev", "min", "max" as &str],
632 )
633 .into();
634 let empty: Vec<f64> = Vec::new();
635 let empty_series = Series::new("placeholder".into(), empty).into();
636 let out_pl = polars::prelude::DataFrame::new_infer_height(vec![stat_col, empty_series])?;
637 return Ok(super::DataFrame::from_polars_with_options(
638 out_pl,
639 case_sensitive,
640 ));
641 }
642 let summary_col = Series::new(
643 "summary".into(),
644 &["count", "mean", "stddev", "min", "max" as &str],
645 )
646 .into();
647 let mut cols: Vec<Column> = vec![summary_col];
648 cols.extend(stat_values);
649 let out_pl = polars::prelude::DataFrame::new_infer_height(cols)?;
650 Ok(super::DataFrame::from_polars_with_options(
651 out_pl,
652 case_sensitive,
653 ))
654}
655
656pub fn subtract(
659 left: &DataFrame,
660 right: &DataFrame,
661 case_sensitive: bool,
662) -> Result<DataFrame, PolarsError> {
663 use polars::prelude::*;
664 let left_names = left.columns()?;
665 let right_names = right.columns()?;
666 let right_on: Vec<Expr> = left_names
667 .iter()
668 .map(|ln| {
669 let resolved = if case_sensitive {
670 right_names
671 .iter()
672 .find(|rn| rn.as_str() == ln.as_str())
673 .cloned()
674 .ok_or_else(|| {
675 PolarsError::ColumnNotFound(
676 format!("subtract: column '{}' not found on right", ln).into(),
677 )
678 })?
679 } else {
680 let ln_lower = ln.to_lowercase();
681 right_names
682 .iter()
683 .find(|rn| rn.to_lowercase() == ln_lower)
684 .cloned()
685 .ok_or_else(|| {
686 PolarsError::ColumnNotFound(
687 format!("subtract: column '{}' not found on right", ln).into(),
688 )
689 })?
690 };
691 Ok(col(resolved.as_str()))
692 })
693 .collect::<Result<Vec<_>, PolarsError>>()?;
694 let left_on: Vec<Expr> = left_names.iter().map(|n| col(n.as_str())).collect();
695 let right_lf = right.lazy_frame();
696 let left_lf = left.lazy_frame();
697 let anti = left_lf.join(right_lf, left_on, right_on, JoinArgs::new(JoinType::Anti));
698 Ok(super::DataFrame::from_lazy_with_options(
699 anti,
700 case_sensitive,
701 ))
702}
703
704pub fn intersect(
707 left: &DataFrame,
708 right: &DataFrame,
709 case_sensitive: bool,
710) -> Result<DataFrame, PolarsError> {
711 use polars::prelude::*;
712 let left_names = left.columns()?;
713 let right_names = right.columns()?;
714 let right_on: Vec<Expr> = left_names
715 .iter()
716 .map(|ln| {
717 let resolved = if case_sensitive {
718 right_names
719 .iter()
720 .find(|rn| rn.as_str() == ln.as_str())
721 .cloned()
722 .ok_or_else(|| {
723 PolarsError::ColumnNotFound(
724 format!("intersect: column '{}' not found on right", ln).into(),
725 )
726 })?
727 } else {
728 let ln_lower = ln.to_lowercase();
729 right_names
730 .iter()
731 .find(|rn| rn.to_lowercase() == ln_lower)
732 .cloned()
733 .ok_or_else(|| {
734 PolarsError::ColumnNotFound(
735 format!("intersect: column '{}' not found on right", ln).into(),
736 )
737 })?
738 };
739 Ok(col(resolved.as_str()))
740 })
741 .collect::<Result<Vec<_>, PolarsError>>()?;
742 let left_on: Vec<Expr> = left_names.iter().map(|n| col(n.as_str())).collect();
743 let left_lf = left.lazy_frame();
744 let right_lf = right.lazy_frame();
745 let semi = left_lf
746 .join(right_lf, left_on, right_on, JoinArgs::new(JoinType::Semi))
747 .unique(None, UniqueKeepStrategy::First);
748 Ok(super::DataFrame::from_lazy_with_options(
749 semi,
750 case_sensitive,
751 ))
752}
753
754pub fn sample(
758 df: &DataFrame,
759 with_replacement: bool,
760 fraction: f64,
761 seed: Option<u64>,
762 case_sensitive: bool,
763) -> Result<DataFrame, PolarsError> {
764 use polars::prelude::Series;
765 let pl = df.collect_inner()?;
766 let n = pl.height();
767 if n == 0 {
768 return Ok(super::DataFrame::from_lazy_with_options(
769 polars::prelude::DataFrame::empty().lazy(),
770 case_sensitive,
771 ));
772 }
773 let take_n = (n as f64 * fraction).round() as usize;
774 let take_n = take_n.min(n).max(0);
775 if take_n == 0 {
776 return Ok(super::DataFrame::from_lazy_with_options(
777 pl.as_ref().head(Some(0)).lazy(),
778 case_sensitive,
779 ));
780 }
781 let idx_series = Series::new("idx".into(), (0..n).map(|i| i as u32).collect::<Vec<_>>());
782 let sampled_idx = idx_series.sample_n(take_n, with_replacement, true, seed)?;
783 let idx_ca = sampled_idx
784 .u32()
785 .map_err(|_| PolarsError::ComputeError("sample: expected u32 indices".into()))?;
786 let pl_df = pl.as_ref().take(idx_ca)?;
787 Ok(super::DataFrame::from_polars_with_options(
788 pl_df,
789 case_sensitive,
790 ))
791}
792
793pub fn random_split(
797 df: &DataFrame,
798 weights: &[f64],
799 seed: Option<u64>,
800 case_sensitive: bool,
801) -> Result<Vec<DataFrame>, PolarsError> {
802 let total: f64 = weights.iter().sum();
803 if total <= 0.0 || weights.is_empty() {
804 return Ok(Vec::new());
805 }
806 let pl = df.collect_inner()?;
807 let n = pl.height();
808 if n == 0 {
809 return Ok(weights.iter().map(|_| super::DataFrame::empty()).collect());
810 }
811 let mut cum = Vec::with_capacity(weights.len());
813 let mut acc = 0.0_f64;
814 for w in weights {
815 acc += w / total;
816 cum.push(acc);
817 }
818 use polars::prelude::Series;
820 use rand::Rng;
821 use rand::SeedableRng;
822 let mut rng = rand::rngs::StdRng::seed_from_u64(seed.unwrap_or(0));
823 let mut bucket_indices: Vec<Vec<u32>> = (0..weights.len()).map(|_| Vec::new()).collect();
824 for i in 0..n {
825 let r: f64 = rng.r#gen();
826 let bucket = cum
827 .iter()
828 .position(|&c| r < c)
829 .unwrap_or(weights.len().saturating_sub(1));
830 bucket_indices[bucket].push(i as u32);
831 }
832 let pl = pl.as_ref();
833 let mut out = Vec::with_capacity(weights.len());
834 for indices in bucket_indices {
835 if indices.is_empty() {
836 out.push(super::DataFrame::from_polars_with_options(
837 pl.clone().head(Some(0)),
838 case_sensitive,
839 ));
840 } else {
841 let idx_series = Series::new("idx".into(), indices);
842 let idx_ca = idx_series.u32().map_err(|_| {
843 PolarsError::ComputeError("random_split: expected u32 indices".into())
844 })?;
845 let taken = pl.take(idx_ca)?;
846 out.push(super::DataFrame::from_polars_with_options(
847 taken,
848 case_sensitive,
849 ));
850 }
851 }
852 Ok(out)
853}
854
855pub fn sample_by(
858 df: &DataFrame,
859 col_name: &str,
860 fractions: &[(Expr, f64)],
861 seed: Option<u64>,
862 case_sensitive: bool,
863) -> Result<DataFrame, PolarsError> {
864 use polars::prelude::*;
865 if fractions.is_empty() {
866 return Ok(super::DataFrame::from_lazy_with_options(
867 df.lazy_frame().slice(0, 0),
868 case_sensitive,
869 ));
870 }
871 let resolved = df.resolve_column_name(col_name)?;
872 let mut parts = Vec::with_capacity(fractions.len());
873 for (value_expr, frac) in fractions {
874 let cond = col(resolved.as_str()).eq(value_expr.clone());
875 let filtered = df.lazy_frame().filter(cond).collect()?;
876 if filtered.height() == 0 {
877 parts.push(filtered.head(Some(0)));
878 continue;
879 }
880 let sampled = sample(
881 &super::DataFrame::from_polars_with_options(filtered, case_sensitive),
882 false,
883 *frac,
884 seed,
885 case_sensitive,
886 )?;
887 parts.push(sampled.collect_inner()?.as_ref().clone());
888 }
889 let mut out = parts
890 .first()
891 .ok_or_else(|| PolarsError::ComputeError("sample_by: no parts".into()))?
892 .clone();
893 for p in parts.iter().skip(1) {
894 out.vstack_mut(p)?;
895 }
896 Ok(super::DataFrame::from_polars_with_options(
897 out,
898 case_sensitive,
899 ))
900}
901
902pub fn first(df: &DataFrame, case_sensitive: bool) -> Result<DataFrame, PolarsError> {
906 let limited = limit(df, 1, case_sensitive)?;
907 let pl_df = limited.collect_inner()?.as_ref().clone();
908 Ok(super::DataFrame::from_polars_with_options(
909 pl_df,
910 case_sensitive,
911 ))
912}
913
914pub fn head(df: &DataFrame, n: usize, case_sensitive: bool) -> Result<DataFrame, PolarsError> {
916 limit(df, n, case_sensitive)
917}
918
919pub fn take(df: &DataFrame, n: usize, case_sensitive: bool) -> Result<DataFrame, PolarsError> {
921 limit(df, n, case_sensitive)
922}
923
924pub fn tail(df: &DataFrame, n: usize, case_sensitive: bool) -> Result<DataFrame, PolarsError> {
926 let pl = df.collect_inner()?;
927 let total = pl.height();
928 let skip = total.saturating_sub(n);
929 let pl_df = pl.as_ref().clone().slice(skip as i64, n);
930 Ok(super::DataFrame::from_polars_with_options(
931 pl_df,
932 case_sensitive,
933 ))
934}
935
936pub fn is_empty(df: &DataFrame) -> bool {
938 df.count().map(|n| n == 0).unwrap_or(true)
939}
940
941pub fn to_df(
943 df: &DataFrame,
944 names: &[&str],
945 case_sensitive: bool,
946) -> Result<DataFrame, PolarsError> {
947 let cols = df.columns()?;
948 if names.len() != cols.len() {
949 return Err(PolarsError::ComputeError(
950 format!(
951 "toDF: expected {} column names, got {}",
952 cols.len(),
953 names.len()
954 )
955 .into(),
956 ));
957 }
958 let pl_df = df.collect_inner()?;
959 let mut pl_df = pl_df.as_ref().clone();
960 for (old, new) in cols.iter().zip(names.iter()) {
961 pl_df.rename(old.as_str(), (*new).into())?;
962 }
963 Ok(super::DataFrame::from_polars_with_options(
964 pl_df,
965 case_sensitive,
966 ))
967}
968
969fn any_value_to_serde_value(av: &polars::prelude::AnyValue) -> serde_json::Value {
972 use polars::prelude::AnyValue;
973 use serde_json::Number;
974 match av {
975 AnyValue::Null => serde_json::Value::Null,
976 AnyValue::Boolean(v) => serde_json::Value::Bool(*v),
977 AnyValue::Int8(v) => serde_json::Value::Number(Number::from(*v as i64)),
978 AnyValue::Int32(v) => serde_json::Value::Number(Number::from(*v)),
979 AnyValue::Int64(v) => serde_json::Value::Number(Number::from(*v)),
980 AnyValue::UInt32(v) => serde_json::Value::Number(Number::from(*v)),
981 AnyValue::Float64(v) => Number::from_f64(*v)
982 .map(serde_json::Value::Number)
983 .unwrap_or(serde_json::Value::Null),
984 AnyValue::String(v) => serde_json::Value::String(v.to_string()),
985 _ => serde_json::Value::String(format!("{av:?}")),
986 }
987}
988
989pub fn to_json(df: &DataFrame) -> Result<Vec<String>, PolarsError> {
991 use polars::prelude::*;
992 let collected = df.collect_inner()?;
993 let pl = collected.as_ref();
994 let names = pl.get_column_names();
995 let mut out = Vec::with_capacity(pl.height());
996 for r in 0..pl.height() {
997 let mut row = serde_json::Map::new();
998 for (i, name) in names.iter().enumerate() {
999 let col = pl
1000 .columns()
1001 .get(i)
1002 .ok_or_else(|| PolarsError::ComputeError("to_json: column index".into()))?;
1003 let series = col.as_materialized_series();
1004 let av = series
1005 .get(r)
1006 .map_err(|e| PolarsError::ComputeError(e.to_string().into()))?;
1007 row.insert(name.to_string(), any_value_to_serde_value(&av));
1008 }
1009 out.push(
1010 serde_json::to_string(&row)
1011 .map_err(|e| PolarsError::ComputeError(e.to_string().into()))?,
1012 );
1013 }
1014 Ok(out)
1015}
1016
1017pub fn explain(_df: &DataFrame) -> String {
1019 "DataFrame (eager Polars backend)".to_string()
1020}
1021
1022pub fn print_schema(df: &DataFrame) -> Result<String, PolarsError> {
1024 let schema = df.schema()?;
1025 let mut s = "root\n".to_string();
1026 for f in schema.fields() {
1027 let dt = match &f.data_type {
1028 crate::schema::DataType::String => "string",
1029 crate::schema::DataType::Integer => "int",
1030 crate::schema::DataType::Long => "bigint",
1031 crate::schema::DataType::Double => "double",
1032 crate::schema::DataType::Boolean => "boolean",
1033 crate::schema::DataType::Date => "date",
1034 crate::schema::DataType::Timestamp => "timestamp",
1035 _ => "string",
1036 };
1037 s.push_str(&format!(" |-- {}: {}\n", f.name, dt));
1038 }
1039 Ok(s)
1040}
1041
1042pub fn select_expr(
1046 df: &DataFrame,
1047 exprs: &[String],
1048 case_sensitive: bool,
1049) -> Result<DataFrame, PolarsError> {
1050 let mut cols = Vec::new();
1051 for e in exprs {
1052 let e = e.trim();
1053 if let Some((left, right)) = e.split_once(" as ") {
1054 let col_name = left.trim();
1055 let _alias = right.trim();
1056 cols.push(df.resolve_column_name(col_name)?);
1057 } else {
1058 cols.push(df.resolve_column_name(e)?);
1059 }
1060 }
1061 let refs: Vec<&str> = cols.iter().map(|s| s.as_str()).collect();
1062 select(df, refs, case_sensitive)
1063}
1064
1065pub fn col_regex(
1067 df: &DataFrame,
1068 pattern: &str,
1069 case_sensitive: bool,
1070) -> Result<DataFrame, PolarsError> {
1071 let re = regex::Regex::new(pattern).map_err(|e| {
1072 PolarsError::ComputeError(format!("colRegex: invalid pattern {pattern:?}: {e}").into())
1073 })?;
1074 let names = df.columns()?;
1075 let matched: Vec<&str> = names
1076 .iter()
1077 .filter(|n| re.is_match(n))
1078 .map(|s| s.as_str())
1079 .collect();
1080 if matched.is_empty() {
1081 return Err(PolarsError::ComputeError(
1082 format!("colRegex: no columns matched pattern {pattern:?}").into(),
1083 ));
1084 }
1085 select(df, matched, case_sensitive)
1086}
1087
1088pub fn with_columns(
1090 df: &DataFrame,
1091 exprs: &[(String, crate::column::Column)],
1092 case_sensitive: bool,
1093) -> Result<DataFrame, PolarsError> {
1094 let pl = df.collect_inner()?.as_ref().clone();
1095 let mut current = super::DataFrame::from_polars_with_options(pl, case_sensitive);
1096 for (name, col) in exprs {
1097 current = with_column(¤t, name, col, case_sensitive)?;
1098 }
1099 Ok(current)
1100}
1101
1102pub fn with_columns_renamed(
1104 df: &DataFrame,
1105 renames: &[(String, String)],
1106 case_sensitive: bool,
1107) -> Result<DataFrame, PolarsError> {
1108 let mut mapping = Vec::new();
1109 for (old_name, new_name) in renames {
1110 let resolved = df.resolve_column_name(old_name)?;
1111 mapping.push((resolved, new_name.clone()));
1112 }
1113 let mut lf = df.lazy_frame();
1114 for (old, new) in mapping {
1115 lf = lf.rename([old.as_str()], [new.as_str()], true);
1116 }
1117 Ok(super::DataFrame::from_lazy_with_options(lf, case_sensitive))
1118}
1119
1120pub struct DataFrameNa<'a> {
1122 pub(crate) df: &'a DataFrame,
1123}
1124
1125impl<'a> DataFrameNa<'a> {
1126 pub fn new(df: &'a DataFrame) -> Self {
1128 DataFrameNa { df }
1129 }
1130
1131 pub fn fill(&self, value: Expr, subset: Option<Vec<&str>>) -> Result<DataFrame, PolarsError> {
1133 fillna(self.df, value, subset, self.df.case_sensitive)
1134 }
1135
1136 pub fn replace(
1138 &self,
1139 old_value: Expr,
1140 new_value: Expr,
1141 subset: Option<Vec<&str>>,
1142 ) -> Result<DataFrame, PolarsError> {
1143 let cols: Vec<String> = match &subset {
1144 Some(s) => s.iter().map(|x| (*x).to_string()).collect(),
1145 None => self.df.columns()?,
1146 };
1147 let mut result = self.df.clone();
1148 for col_name in &cols {
1149 result = replace(
1150 &result,
1151 col_name.as_str(),
1152 old_value.clone(),
1153 new_value.clone(),
1154 self.df.case_sensitive,
1155 )?;
1156 }
1157 Ok(result)
1158 }
1159
1160 pub fn drop(
1162 &self,
1163 subset: Option<Vec<&str>>,
1164 how: &str,
1165 thresh: Option<usize>,
1166 ) -> Result<DataFrame, PolarsError> {
1167 dropna(self.df, subset, how, thresh, self.df.case_sensitive)
1168 }
1169}
1170
1171pub fn offset(df: &DataFrame, n: usize, case_sensitive: bool) -> Result<DataFrame, PolarsError> {
1175 let lf = df.lazy_frame().slice(n as i64, u32::MAX);
1176 Ok(super::DataFrame::from_lazy_with_options(lf, case_sensitive))
1177}
1178
1179pub fn transform<F>(df: &DataFrame, f: F) -> Result<DataFrame, PolarsError>
1181where
1182 F: FnOnce(DataFrame) -> Result<DataFrame, PolarsError>,
1183{
1184 let df_out = f(df.clone())?;
1185 Ok(df_out)
1186}
1187
1188pub fn freq_items(
1190 df: &DataFrame,
1191 columns: &[&str],
1192 support: f64,
1193 case_sensitive: bool,
1194) -> Result<DataFrame, PolarsError> {
1195 use polars::prelude::SeriesMethods;
1196 if columns.is_empty() {
1197 return Ok(super::DataFrame::from_lazy_with_options(
1198 df.lazy_frame().slice(0, 0),
1199 case_sensitive,
1200 ));
1201 }
1202 let support = support.clamp(1e-4, 1.0);
1203 let collected = df.collect_inner()?;
1204 let pl_df = collected.as_ref();
1205 let n_total = pl_df.height() as f64;
1206 if n_total == 0.0 {
1207 let mut out = Vec::with_capacity(columns.len());
1208 for col_name in columns {
1209 let resolved = df.resolve_column_name(col_name)?;
1210 let s = pl_df
1211 .column(resolved.as_str())?
1212 .as_series()
1213 .ok_or_else(|| PolarsError::ComputeError("column not a series".into()))?
1214 .clone();
1215 let empty_sub = s.head(Some(0));
1216 let list_chunked = polars::prelude::ListChunked::from_iter([empty_sub].into_iter())
1217 .with_name(format!("{resolved}_freqItems").into());
1218 out.push(list_chunked.into_series().into());
1219 }
1220 return Ok(super::DataFrame::from_polars_with_options(
1221 polars::prelude::DataFrame::new_infer_height(out)?,
1222 case_sensitive,
1223 ));
1224 }
1225 let mut out_series = Vec::with_capacity(columns.len());
1226 for col_name in columns {
1227 let resolved = df.resolve_column_name(col_name)?;
1228 let s = pl_df
1229 .column(resolved.as_str())?
1230 .as_series()
1231 .ok_or_else(|| PolarsError::ComputeError("column not a series".into()))?
1232 .clone();
1233 let vc = s.value_counts(false, false, "counts".into(), false)?;
1234 let count_col = vc
1235 .column("counts")
1236 .map_err(|_| PolarsError::ComputeError("value_counts missing counts column".into()))?;
1237 let counts = count_col
1238 .u32()
1239 .map_err(|_| PolarsError::ComputeError("freq_items: counts column not u32".into()))?;
1240 let value_col_name = s.name();
1241 let values_col = vc
1242 .column(value_col_name.as_str())
1243 .map_err(|_| PolarsError::ComputeError("value_counts missing value column".into()))?;
1244 let threshold = (support * n_total).ceil() as u32;
1245 let indices: Vec<u32> = counts
1246 .into_iter()
1247 .enumerate()
1248 .filter_map(|(i, c)| {
1249 if c? >= threshold {
1250 Some(i as u32)
1251 } else {
1252 None
1253 }
1254 })
1255 .collect();
1256 let idx_series = Series::new("idx".into(), indices);
1257 let idx_ca = idx_series
1258 .u32()
1259 .map_err(|_| PolarsError::ComputeError("freq_items: index series not u32".into()))?;
1260 let values_series = values_col
1261 .as_series()
1262 .ok_or_else(|| PolarsError::ComputeError("value column not a series".into()))?;
1263 let filtered = values_series.take(idx_ca)?;
1264 let list_chunked = polars::prelude::ListChunked::from_iter([filtered].into_iter())
1265 .with_name(format!("{resolved}_freqItems").into());
1266 let list_row = list_chunked.into_series();
1267 out_series.push(list_row.into());
1268 }
1269 let out_df = polars::prelude::DataFrame::new_infer_height(out_series)?;
1270 Ok(super::DataFrame::from_polars_with_options(
1271 out_df,
1272 case_sensitive,
1273 ))
1274}
1275
1276pub fn approx_quantile(
1278 df: &DataFrame,
1279 column: &str,
1280 probabilities: &[f64],
1281 case_sensitive: bool,
1282) -> Result<DataFrame, PolarsError> {
1283 use polars::prelude::{ChunkQuantile, QuantileMethod};
1284 if probabilities.is_empty() {
1285 return Ok(super::DataFrame::from_polars_with_options(
1286 polars::prelude::DataFrame::new_infer_height(vec![
1287 Series::new("quantile".into(), Vec::<f64>::new()).into(),
1288 ])?,
1289 case_sensitive,
1290 ));
1291 }
1292 let resolved = df.resolve_column_name(column)?;
1293 let collected = df.collect_inner()?;
1294 let s = collected
1295 .column(resolved.as_str())?
1296 .as_series()
1297 .ok_or_else(|| PolarsError::ComputeError("approx_quantile: column not a series".into()))?
1298 .clone();
1299 let ca = series_as_f64_ca(&s, "approx_quantile")?;
1300 let mut quantiles = Vec::with_capacity(probabilities.len());
1301 for &p in probabilities {
1302 let q = ca.quantile(p, QuantileMethod::Linear)?;
1303 quantiles.push(q.unwrap_or(f64::NAN));
1304 }
1305 let out_df = polars::prelude::DataFrame::new_infer_height(vec![
1306 Series::new("quantile".into(), quantiles).into(),
1307 ])?;
1308 Ok(super::DataFrame::from_polars_with_options(
1309 out_df,
1310 case_sensitive,
1311 ))
1312}
1313
1314pub fn crosstab(
1316 df: &DataFrame,
1317 col1: &str,
1318 col2: &str,
1319 case_sensitive: bool,
1320) -> Result<DataFrame, PolarsError> {
1321 use polars::prelude::*;
1322 let c1 = df.resolve_column_name(col1)?;
1323 let c2 = df.resolve_column_name(col2)?;
1324 let collected = df.collect_inner()?;
1325 let pl_df = collected.as_ref();
1326 let grouped = pl_df
1327 .clone()
1328 .lazy()
1329 .group_by([col(c1.as_str()), col(c2.as_str())])
1330 .agg([len().alias("count")])
1331 .collect()?;
1332 Ok(super::DataFrame::from_polars_with_options(
1333 grouped,
1334 case_sensitive,
1335 ))
1336}
1337
1338pub fn melt(
1340 df: &DataFrame,
1341 id_vars: &[&str],
1342 value_vars: &[&str],
1343 case_sensitive: bool,
1344) -> Result<DataFrame, PolarsError> {
1345 use polars::prelude::*;
1346 let collected = df.collect_inner()?;
1347 let pl_df = collected.as_ref();
1348 if value_vars.is_empty() {
1349 return Ok(super::DataFrame::from_polars_with_options(
1350 pl_df.head(Some(0)),
1351 case_sensitive,
1352 ));
1353 }
1354 let id_resolved: Vec<String> = id_vars
1355 .iter()
1356 .map(|s| df.resolve_column_name(s).map(|r| r.to_string()))
1357 .collect::<Result<Vec<_>, _>>()?;
1358 let value_resolved: Vec<String> = value_vars
1359 .iter()
1360 .map(|s| df.resolve_column_name(s).map(|r| r.to_string()))
1361 .collect::<Result<Vec<_>, _>>()?;
1362 let mut parts = Vec::with_capacity(value_vars.len());
1363 for vname in &value_resolved {
1364 let select_cols: Vec<&str> = id_resolved
1365 .iter()
1366 .map(|s| s.as_str())
1367 .chain([vname.as_str()])
1368 .collect();
1369 let mut part = pl_df.select(select_cols)?;
1370 let var_series = Series::new("variable".into(), vec![vname.as_str(); part.height()]);
1371 part.with_column(var_series.into())?;
1372 part.rename(vname.as_str(), "value".into())?;
1373 parts.push(part);
1374 }
1375 let mut out = parts
1376 .first()
1377 .ok_or_else(|| PolarsError::ComputeError("melt: no value columns".into()))?
1378 .clone();
1379 for p in parts.iter().skip(1) {
1380 out.vstack_mut(p)?;
1381 }
1382 let col_order: Vec<&str> = id_resolved
1383 .iter()
1384 .map(|s| s.as_str())
1385 .chain(["variable", "value"])
1386 .collect();
1387 let out = out.select(col_order)?;
1388 Ok(super::DataFrame::from_polars_with_options(
1389 out,
1390 case_sensitive,
1391 ))
1392}
1393
/// `exceptAll` entry point: currently delegates directly to `subtract` with
/// the same arguments.
///
/// NOTE(review): in Spark, `exceptAll` keeps duplicates (multiset difference),
/// while `subtract` is `EXCEPT DISTINCT`. Confirm that this crate's `subtract`
/// preserves duplicate rows. If it deduplicates, this function is not
/// Spark-compatible.
pub fn except_all(
    left: &DataFrame,
    right: &DataFrame,
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    subtract(left, right, case_sensitive)
}
1402
/// `intersectAll` entry point: currently delegates directly to `intersect`
/// with the same arguments.
///
/// NOTE(review): in Spark, `intersectAll` keeps duplicates (multiset
/// intersection), while `intersect` is `INTERSECT DISTINCT`. Confirm that this
/// crate's `intersect` preserves duplicate rows. If it deduplicates, this
/// function is not Spark-compatible.
pub fn intersect_all(
    left: &DataFrame,
    right: &DataFrame,
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    intersect(left, right, case_sensitive)
}
1411
#[cfg(test)]
mod tests {
    use super::{distinct, drop, dropna, first, head, limit, offset, order_by, union_by_name};
    use crate::{DataFrame, SparkSession};
    use serde_json::json;

    /// Three-row fixture with columns `id` (i64), `v` (i64) and `label` (String).
    fn test_df() -> DataFrame {
        let spark = SparkSession::builder()
            .app_name("transform_tests")
            .get_or_create();
        spark
            .create_dataframe(
                vec![
                    (1i64, 10i64, "a".to_string()),
                    (2i64, 20i64, "b".to_string()),
                    (3i64, 30i64, "c".to_string()),
                ],
                vec!["id", "v", "label"],
            )
            .unwrap()
    }

    // limit(0) yields an empty frame.
    #[test]
    fn limit_zero() {
        let df = test_df();
        let out = limit(&df, 0, false).unwrap();
        assert_eq!(out.count().unwrap(), 0);
    }

    // A limit above the row count returns every row rather than failing.
    #[test]
    fn limit_more_than_rows() {
        let df = test_df();
        let out = limit(&df, 10, false).unwrap();
        assert_eq!(out.count().unwrap(), 3);
    }

    // distinct() on a zero-row frame stays empty.
    #[test]
    fn distinct_on_empty() {
        let spark = SparkSession::builder()
            .app_name("transform_tests")
            .get_or_create();
        let df = spark
            .create_dataframe(vec![] as Vec<(i64, i64, String)>, vec!["a", "b", "c"])
            .unwrap();
        let out = distinct(&df, None, false).unwrap();
        assert_eq!(out.count().unwrap(), 0);
    }

    // first() returns exactly one row.
    #[test]
    fn first_returns_one_row() {
        let df = test_df();
        let out = first(&df, false).unwrap();
        assert_eq!(out.count().unwrap(), 1);
    }

    // first() must honour a preceding order_by, not the storage order.
    #[test]
    fn first_after_order_by_returns_first_in_sort_order() {
        use polars::prelude::df;
        let spark = SparkSession::builder()
            .app_name("transform_tests")
            .get_or_create();
        let pl = df![
            "name" => ["Charlie", "Alice", "Bob"],
            "value" => [3i64, 1i64, 2i64],
        ]
        .unwrap();
        let df = spark.create_dataframe_from_polars(pl);
        let ordered = order_by(&df, vec!["value"], vec![true], false).unwrap();
        let one = first(&ordered, false).unwrap();
        let collected = one.collect_inner().unwrap();
        let name_series = collected.column("name").unwrap();
        let first_name = name_series.str().unwrap().get(0).unwrap();
        assert_eq!(
            first_name, "Alice",
            "first() after orderBy(value) must return row with min value (Alice=1), not first in storage (Charlie)"
        );
    }

    // head(n) returns n rows when n is below the row count.
    #[test]
    fn head_n() {
        let df = test_df();
        let out = head(&df, 2, false).unwrap();
        assert_eq!(out.count().unwrap(), 2);
    }

    // offset(1) drops only the first row.
    #[test]
    fn offset_skip_first() {
        let df = test_df();
        let out = offset(&df, 1, false).unwrap();
        assert_eq!(out.count().unwrap(), 2);
    }

    // An offset past the end yields an empty frame, not an error.
    #[test]
    fn offset_beyond_length_returns_empty() {
        let df = test_df();
        let out = offset(&df, 10, false).unwrap();
        assert_eq!(out.count().unwrap(), 0);
    }

    // drop() removes the named column and keeps every row.
    #[test]
    fn drop_column() {
        let df = test_df();
        let out = drop(&df, vec!["v"], false).unwrap();
        let cols = out.columns().unwrap();
        assert!(!cols.contains(&"v".to_string()));
        assert_eq!(out.count().unwrap(), 3);
    }

    // Regression for issue #603: union_by_name must reconcile an Int64 `id`
    // on the left with a String `id` on the right instead of failing.
    #[test]
    fn union_by_name_coerces_different_column_types() {
        use polars::prelude::df;

        let spark = SparkSession::builder()
            .app_name("transform_tests")
            .get_or_create();
        let left_pl = df!("id" => &[1i64], "name" => &["a"]).unwrap();
        let left = spark.create_dataframe_from_polars(left_pl);
        let schema = vec![
            ("id".to_string(), "string".to_string()),
            ("name".to_string(), "string".to_string()),
        ];
        let right = spark
            .create_dataframe_from_rows(vec![vec![json!("2"), json!("b")]], schema)
            .unwrap();
        let out = union_by_name(&left, &right, true, false)
            .expect("issue #603: union_by_name must coerce id Int64 vs String");
        assert_eq!(out.count().unwrap(), 2);
    }

    // dropna("any") over all columns keeps every row of a null-free fixture.
    #[test]
    fn dropna_all_columns() {
        let df = test_df();
        let out = dropna(&df, None, "any", None, false).unwrap();
        assert_eq!(out.count().unwrap(), 3);
    }
}