mod aggregations;
mod joins;
mod stats;
mod transformations;
pub use aggregations::{CubeRollupData, GroupedData, PivotedGroupedData};
pub use joins::{join, JoinType};
pub use stats::DataFrameStat;
pub use transformations::{
filter, order_by, order_by_exprs, select, select_with_exprs, with_column, DataFrameNa,
};
use crate::column::Column;
use crate::functions::SortOrder;
use crate::schema::StructType;
use crate::session::SparkSession;
use crate::type_coercion::coerce_for_pyspark_comparison;
use polars::prelude::{
col, lit, AnyValue, DataFrame as PlDataFrame, DataType, Expr, IntoLazy, LazyFrame, PlSmallStr,
PolarsError, Schema, SchemaNamesAndDtypes,
};
use serde_json::Value as JsonValue;
use std::collections::{HashMap, HashSet};
use std::path::Path;
use std::sync::Arc;
/// Case-sensitivity used by the constructors that do not take an explicit option.
/// `false` makes column-name resolution compare names after lowercasing both sides.
const DEFAULT_CASE_SENSITIVE: bool = false;
/// Backing storage of a [`DataFrame`]: either a materialized Polars frame or a lazy plan.
#[allow(clippy::large_enum_variant)]
pub(crate) enum DataFrameInner {
/// Materialized frame shared behind an `Arc`. None of the constructors visible in this
/// part of the file build this variant, which is why `dead_code` is allowed.
#[allow(dead_code)]
Eager(Arc<PlDataFrame>),
/// Lazy query plan, collected on demand.
Lazy(LazyFrame),
}
/// A PySpark-style DataFrame backed by Polars.
pub struct DataFrame {
/// Underlying data: an eager frame or a lazy plan.
pub(crate) inner: DataFrameInner,
/// When `true`, column names must match exactly; otherwise they are compared lowercased.
pub(crate) case_sensitive: bool,
/// Optional alias attached via [`DataFrame::alias`]; `None` for freshly constructed frames.
pub(crate) alias: Option<String>,
}
impl DataFrame {
pub fn from_polars(df: PlDataFrame) -> Self {
let lf = df.lazy();
DataFrame {
inner: DataFrameInner::Lazy(lf),
case_sensitive: DEFAULT_CASE_SENSITIVE,
alias: None,
}
}
pub fn from_polars_with_options(df: PlDataFrame, case_sensitive: bool) -> Self {
let lf = df.lazy();
DataFrame {
inner: DataFrameInner::Lazy(lf),
case_sensitive,
alias: None,
}
}
pub fn from_lazy(lf: LazyFrame) -> Self {
DataFrame {
inner: DataFrameInner::Lazy(lf),
case_sensitive: DEFAULT_CASE_SENSITIVE,
alias: None,
}
}
pub fn from_lazy_with_options(lf: LazyFrame, case_sensitive: bool) -> Self {
DataFrame {
inner: DataFrameInner::Lazy(lf),
case_sensitive,
alias: None,
}
}
pub fn empty() -> Self {
DataFrame {
inner: DataFrameInner::Lazy(PlDataFrame::empty().lazy()),
case_sensitive: DEFAULT_CASE_SENSITIVE,
alias: None,
}
}
/// Returns a `LazyFrame` for this frame: a clone of the stored plan, or a lazy
/// view over a clone of the eager data.
pub(crate) fn lazy_frame(&self) -> LazyFrame {
    match self.inner {
        DataFrameInner::Lazy(ref plan) => plan.clone(),
        DataFrameInner::Eager(ref frame) => frame.as_ref().clone().lazy(),
    }
}
/// Materializes this frame.
///
/// An eager frame hands back another handle to the same `Arc`; a lazy frame
/// collects a clone of its plan on every call.
///
/// # Errors
/// Propagates any `PolarsError` raised while collecting the lazy plan.
pub(crate) fn collect_inner(&self) -> Result<Arc<PlDataFrame>, PolarsError> {
    let materialized = match &self.inner {
        DataFrameInner::Lazy(plan) => Arc::new(plan.clone().collect()?),
        DataFrameInner::Eager(shared) => shared.clone(),
    };
    Ok(materialized)
}
pub fn alias(&self, name: &str) -> Self {
let lf = self.lazy_frame();
DataFrame {
inner: DataFrameInner::Lazy(lf),
case_sensitive: self.case_sensitive,
alias: Some(name.to_string()),
}
}
/// Rewrites every `Expr::Column` inside `expr` so that it refers to the actual
/// schema column name, as decided by [`DataFrame::resolve_column_name`].
///
/// * Column references whose name equals an alias defined anywhere in `expr`
///   are left untouched.
/// * A dotted name such as `a.b.c` is treated as struct access: `a` is resolved
///   as a column and `b`, `c` become nested `struct_().field_by_name(..)` calls.
///
/// # Errors
/// Returns `PolarsError::ColumnNotFound` when a column cannot be resolved or
/// when a dotted name ends with `.` (for example `"a."`).
pub fn resolve_expr_column_names(&self, expr: Expr) -> Result<Expr, PolarsError> {
    let df = self;
    // First pass: collect every alias name introduced by the expression.
    let mut alias_output_names: HashSet<String> = HashSet::new();
    let _ = expr.clone().try_map_expr(|e| {
        if let Expr::Alias(_, name) = &e {
            alias_output_names.insert(name.as_str().to_string());
        }
        Ok(e)
    })?;
    // Second pass: resolve column references against the schema.
    expr.try_map_expr(move |e| {
        if let Expr::Column(name) = &e {
            let name_str = name.as_str();
            if alias_output_names.contains(name_str) {
                return Ok(e);
            }
            if name_str.contains('.') {
                // `split('.')` never yields an empty tail for a dotted name
                // ("a." -> ["a", ""]), so the trailing dot has to be detected on
                // the string itself rather than on the number of parts.
                if name_str.ends_with('.') {
                    return Err(PolarsError::ColumnNotFound(
                        format!("Column '{}': trailing dot not allowed", name_str).into(),
                    ));
                }
                let mut segments = name_str.split('.');
                let first = segments.next().unwrap_or_default();
                let resolved = df.resolve_column_name(first)?;
                let mut expr = col(PlSmallStr::from(resolved.as_str()));
                for field in segments {
                    expr = expr.struct_().field_by_name(field);
                }
                return Ok(expr);
            }
            let resolved = df.resolve_column_name(name_str)?;
            return Ok(Expr::Column(PlSmallStr::from(resolved.as_str())));
        }
        Ok(e)
    })
}
/// Rewrites comparison expressions (`==`, `!=`, `<`, `<=`, `>`, `>=`) so that both
/// operands go through `coerce_for_pyspark_comparison` before being compared.
///
/// The root expression is inspected first, then the same rules (except the
/// column-vs-column one) are applied to the expression tree via `try_map_expr`:
///
/// * column vs numeric literal: the column side is described as `DataType::String`
///   and the literal side by `literal_dtype`.
/// * column vs string literal: only rewritten when the column's schema dtype is
///   `Date` or `Datetime`.
/// * column vs column (root only): rewritten when the two schema dtypes differ and
///   the coercion succeeds; a coercion failure leaves the expression unchanged.
///
/// # Errors
/// Coercion failures for the literal cases are surfaced as `PolarsError::ComputeError`.
///
/// NOTE(review): the column dtype is not looked up for the numeric-literal case; it is
/// always passed as `String`. Confirm `coerce_for_pyspark_comparison` is meant to be
/// called that way for non-string columns.
pub fn coerce_string_numeric_comparisons(&self, expr: Expr) -> Result<Expr, PolarsError> {
use polars::prelude::{DataType, LiteralValue, Operator};
use std::sync::Arc;
// True for literal expressions holding one of the numeric `LiteralValue` variants.
fn is_numeric_literal(expr: &Expr) -> bool {
matches!(
expr,
Expr::Literal(
LiteralValue::Int32(_)
| LiteralValue::Int64(_)
| LiteralValue::UInt32(_)
| LiteralValue::UInt64(_)
| LiteralValue::Float32(_)
| LiteralValue::Float64(_)
| LiteralValue::Int(_) | LiteralValue::Float(_) )
)
}
// Dtype reported for a literal; untyped `Int`/`Float` and any other variant map to Float64.
fn literal_dtype(lv: &LiteralValue) -> DataType {
match lv {
LiteralValue::Int32(_) => DataType::Int32,
LiteralValue::Int64(_) => DataType::Int64,
LiteralValue::UInt32(_) => DataType::UInt32,
LiteralValue::UInt64(_) => DataType::UInt64,
LiteralValue::Float32(_) => DataType::Float32,
LiteralValue::Float64(_) => DataType::Float64,
LiteralValue::Int(_) | LiteralValue::Float(_) => DataType::Float64,
_ => DataType::Float64,
}
}
// Phase 1: handle the root expression. Some branches return early from the function.
let expr = {
if let Expr::BinaryExpr { left, op, right } = &expr {
let is_comparison_op = matches!(
op,
Operator::Eq
| Operator::NotEq
| Operator::Lt
| Operator::LtEq
| Operator::Gt
| Operator::GtEq
);
let left_is_col = matches!(&**left, Expr::Column(_));
let right_is_col = matches!(&**right, Expr::Column(_));
let left_is_numeric_lit =
matches!(&**left, Expr::Literal(_)) && is_numeric_literal(left.as_ref());
let right_is_numeric_lit =
matches!(&**right, Expr::Literal(_)) && is_numeric_literal(right.as_ref());
let left_is_string_lit = matches!(&**left, Expr::Literal(LiteralValue::String(_)));
let right_is_string_lit =
matches!(&**right, Expr::Literal(LiteralValue::String(_)));
let root_is_col_vs_numeric = is_comparison_op
&& ((left_is_col && right_is_numeric_lit)
|| (right_is_col && left_is_numeric_lit));
let root_is_col_vs_string = is_comparison_op
&& ((left_is_col && right_is_string_lit)
|| (right_is_col && left_is_string_lit));
if root_is_col_vs_numeric {
// Column side is described as String; literal side by its own dtype.
let (new_left, new_right) = if left_is_col && right_is_numeric_lit {
let lit_ty = match &**right {
Expr::Literal(lv) => literal_dtype(lv),
_ => DataType::Float64,
};
coerce_for_pyspark_comparison(
(*left).as_ref().clone(),
(*right).as_ref().clone(),
&DataType::String,
&lit_ty,
op,
)
.map_err(|e| PolarsError::ComputeError(e.to_string().into()))?
} else {
let lit_ty = match &**left {
Expr::Literal(lv) => literal_dtype(lv),
_ => DataType::Float64,
};
coerce_for_pyspark_comparison(
(*left).as_ref().clone(),
(*right).as_ref().clone(),
&lit_ty,
&DataType::String,
op,
)
.map_err(|e| PolarsError::ComputeError(e.to_string().into()))?
};
Expr::BinaryExpr {
left: Arc::new(new_left),
op: *op,
right: Arc::new(new_right),
}
} else if root_is_col_vs_string {
// The `unreachable!()` arms are guarded by `left_is_col` / `right_is_col` above.
let col_name = if left_is_col {
if let Expr::Column(n) = &**left {
n.as_str()
} else {
unreachable!()
}
} else if let Expr::Column(n) = &**right {
n.as_str()
} else {
unreachable!()
};
if let Some(col_dtype) = self.get_column_dtype(col_name) {
// Only date-like columns are coerced against string literals.
if matches!(col_dtype, DataType::Date | DataType::Datetime(_, _)) {
let (left_ty, right_ty) = if left_is_col {
(col_dtype.clone(), DataType::String)
} else {
(DataType::String, col_dtype.clone())
};
let (new_left, new_right) = coerce_for_pyspark_comparison(
(*left).as_ref().clone(),
(*right).as_ref().clone(),
&left_ty,
&right_ty,
op,
)
.map_err(|e| PolarsError::ComputeError(e.to_string().into()))?;
return Ok(Expr::BinaryExpr {
left: Arc::new(new_left),
op: *op,
right: Arc::new(new_right),
});
}
}
expr
} else if is_comparison_op && left_is_col && right_is_col {
let left_name = if let Expr::Column(n) = &**left {
n.as_str()
} else {
unreachable!()
};
let right_name = if let Expr::Column(n) = &**right {
n.as_str()
} else {
unreachable!()
};
if let (Some(left_ty), Some(right_ty)) = (
self.get_column_dtype(left_name),
self.get_column_dtype(right_name),
) {
if left_ty != right_ty {
// Best effort: a failed coercion falls through and keeps the original expression.
if let Ok((new_left, new_right)) = coerce_for_pyspark_comparison(
(*left).as_ref().clone(),
(*right).as_ref().clone(),
&left_ty,
&right_ty,
op,
) {
return Ok(Expr::BinaryExpr {
left: Arc::new(new_left),
op: *op,
right: Arc::new(new_right),
});
}
}
}
expr
} else {
expr
}
} else {
expr
}
};
// Phase 2: apply the literal-comparison rules through `try_map_expr`.
expr.try_map_expr(move |e| {
if let Expr::BinaryExpr { left, op, right } = e {
let is_comparison_op = matches!(
op,
Operator::Eq
| Operator::NotEq
| Operator::Lt
| Operator::LtEq
| Operator::Gt
| Operator::GtEq
);
if !is_comparison_op {
return Ok(Expr::BinaryExpr { left, op, right });
}
let left_is_col = matches!(&*left, Expr::Column(_));
let right_is_col = matches!(&*right, Expr::Column(_));
let left_is_lit = matches!(&*left, Expr::Literal(_));
let right_is_lit = matches!(&*right, Expr::Literal(_));
let left_is_string_lit = matches!(&*left, Expr::Literal(LiteralValue::String(_)));
let right_is_string_lit = matches!(&*right, Expr::Literal(LiteralValue::String(_)));
let left_is_numeric_lit = left_is_lit && is_numeric_literal(left.as_ref());
let right_is_numeric_lit = right_is_lit && is_numeric_literal(right.as_ref());
let (new_left, new_right) = if left_is_col && right_is_numeric_lit {
let lit_ty = match &*right {
Expr::Literal(lv) => literal_dtype(lv),
_ => DataType::Float64,
};
coerce_for_pyspark_comparison(
(*left).clone(),
(*right).clone(),
&DataType::String,
&lit_ty,
&op,
)
.map_err(|e| PolarsError::ComputeError(e.to_string().into()))?
} else if right_is_col && left_is_numeric_lit {
let lit_ty = match &*left {
Expr::Literal(lv) => literal_dtype(lv),
_ => DataType::Float64,
};
coerce_for_pyspark_comparison(
(*left).clone(),
(*right).clone(),
&lit_ty,
&DataType::String,
&op,
)
.map_err(|e| PolarsError::ComputeError(e.to_string().into()))?
} else if (left_is_col && right_is_string_lit)
|| (right_is_col && left_is_string_lit)
{
let col_name = if left_is_col {
if let Expr::Column(n) = &*left {
n.as_str()
} else {
unreachable!()
}
} else if let Expr::Column(n) = &*right {
n.as_str()
} else {
unreachable!()
};
if let Some(col_dtype) = self.get_column_dtype(col_name) {
if matches!(col_dtype, DataType::Date | DataType::Datetime(_, _)) {
let (left_ty, right_ty) = if left_is_col {
(col_dtype.clone(), DataType::String)
} else {
(DataType::String, col_dtype.clone())
};
let (new_l, new_r) = coerce_for_pyspark_comparison(
(*left).clone(),
(*right).clone(),
&left_ty,
&right_ty,
&op,
)
.map_err(|e| PolarsError::ComputeError(e.to_string().into()))?;
return Ok(Expr::BinaryExpr {
left: Arc::new(new_l),
op,
right: Arc::new(new_r),
});
}
}
// Non-date column vs string literal: nothing to coerce.
return Ok(Expr::BinaryExpr { left, op, right });
} else {
return Ok(Expr::BinaryExpr { left, op, right });
};
Ok(Expr::BinaryExpr {
left: Arc::new(new_left),
op,
right: Arc::new(new_right),
})
} else {
Ok(e)
}
})
}
/// Returns the Polars schema: read directly from an eager frame, or obtained by
/// calling `collect_schema` on a clone of the lazy plan.
///
/// # Errors
/// Propagates the `PolarsError` from `collect_schema`.
fn schema_or_collect(&self) -> Result<Arc<Schema>, PolarsError> {
    let schema = match &self.inner {
        DataFrameInner::Lazy(plan) => plan.clone().collect_schema()?,
        DataFrameInner::Eager(frame) => Arc::new(frame.schema()),
    };
    Ok(schema)
}
/// Maps a user-supplied column name to the name stored in the schema.
///
/// With `case_sensitive` set, the name must match exactly. Otherwise the first
/// schema column whose lowercased name equals the lowercased input is returned.
///
/// # Errors
/// Returns `PolarsError::ColumnNotFound` listing the available columns when no
/// column matches, and propagates schema-resolution errors.
pub fn resolve_column_name(&self, name: &str) -> Result<String, PolarsError> {
    let schema = self.schema_or_collect()?;
    let names: Vec<String> = schema
        .iter_names_and_dtypes()
        .map(|(n, _)| n.to_string())
        .collect();
    let hit = if self.case_sensitive {
        names.iter().find(|candidate| candidate.as_str() == name)
    } else {
        let wanted = name.to_lowercase();
        names
            .iter()
            .find(|candidate| candidate.to_lowercase() == wanted)
    };
    match hit {
        Some(found) => Ok(found.clone()),
        None => {
            let available = names.join(", ");
            let message = format!(
                "Column '{}' not found. Available columns: [{}]. Check spelling and case sensitivity (spark.sql.caseSensitive).",
                name,
                available
            );
            Err(PolarsError::ColumnNotFound(message.into()))
        }
    }
}
/// Returns the frame's schema converted with `StructType::from_polars_schema`.
///
/// # Errors
/// Propagates schema-resolution errors.
pub fn schema(&self) -> Result<StructType, PolarsError> {
    self.schema_or_collect()
        .map(|polars_schema| StructType::from_polars_schema(&polars_schema))
}
/// Looks up the Polars dtype of the column matching `name`.
///
/// The match rules are the same as in [`DataFrame::resolve_column_name`]: exact
/// match when `case_sensitive` is set, otherwise the first column whose
/// lowercased name equals the lowercased input. The schema is resolved only
/// once per call.
///
/// Returns `None` when the schema cannot be resolved or no column matches.
pub fn get_column_dtype(&self, name: &str) -> Option<DataType> {
    let schema = self.schema_or_collect().ok()?;
    // `Some(lowercased name)` in case-insensitive mode, `None` otherwise.
    let wanted_lower = (!self.case_sensitive).then(|| name.to_lowercase());
    schema
        .iter_names_and_dtypes()
        .find(|(candidate, _)| match &wanted_lower {
            Some(lower) => candidate.as_str().to_lowercase() == *lower,
            None => candidate.as_str() == name,
        })
        .map(|(_, dtype)| dtype.clone())
}
/// Returns the column names in schema order.
///
/// # Errors
/// Propagates schema-resolution errors.
pub fn columns(&self) -> Result<Vec<String>, PolarsError> {
    let schema = self.schema_or_collect()?;
    let mut names = Vec::new();
    for (name, _) in schema.iter_names_and_dtypes() {
        names.push(name.to_string());
    }
    Ok(names)
}
/// Returns the number of rows, materializing the frame to obtain it.
///
/// # Errors
/// Propagates errors from collecting the frame.
pub fn count(&self) -> Result<usize, PolarsError> {
    let frame = self.collect_inner()?;
    Ok(frame.height())
}
/// Prints the first `n` rows (20 when `n` is `None`) to standard output.
///
/// # Errors
/// Propagates errors from collecting the frame.
pub fn show(&self, n: Option<usize>) -> Result<(), PolarsError> {
    let limit = match n {
        Some(rows) => rows,
        None => 20,
    };
    let frame = self.collect_inner()?;
    let preview = frame.head(Some(limit));
    println!("{}", preview);
    Ok(())
}
/// Materializes the frame; public entry point for `collect_inner`.
///
/// # Errors
/// Propagates errors from collecting the frame.
pub fn collect(&self) -> Result<Arc<PlDataFrame>, PolarsError> {
self.collect_inner()
}
/// Materializes the frame and returns one `column name -> JSON value` map per row.
/// Each cell is converted with `any_value_to_json`.
///
/// # Errors
/// Propagates errors from collecting the frame or from reading a cell.
pub fn collect_as_json_rows(&self) -> Result<Vec<HashMap<String, JsonValue>>, PolarsError> {
    let frame = self.collect_inner()?;
    let names = frame.get_column_names();
    let columns = frame.get_columns();
    let mut rows = Vec::with_capacity(frame.height());
    for row_idx in 0..frame.height() {
        let mut record = HashMap::with_capacity(names.len());
        for (col_idx, name) in names.iter().enumerate() {
            let series = match columns.get(col_idx) {
                Some(found) => found,
                None => {
                    return Err(PolarsError::ComputeError(
                        "column index out of range".into(),
                    ))
                }
            };
            let cell = series.get(row_idx)?;
            record.insert(name.to_string(), any_value_to_json(cell));
        }
        rows.push(record);
    }
    Ok(rows)
}
/// Selects by expressions; delegates to `transformations::select_with_exprs`
/// with this frame's case-sensitivity setting.
pub fn select_exprs(&self, exprs: Vec<Expr>) -> Result<DataFrame, PolarsError> {
transformations::select_with_exprs(self, exprs, self.case_sensitive)
}
/// Selects columns by name.
///
/// Each requested name is resolved against the schema first. In
/// case-insensitive mode, any column whose stored name differs from the
/// requested spelling is renamed to the requested spelling in the result.
///
/// # Errors
/// Fails when a name cannot be resolved or when the selection/rename fails.
pub fn select(&self, cols: Vec<&str>) -> Result<DataFrame, PolarsError> {
    let mut resolved: Vec<String> = Vec::with_capacity(cols.len());
    for requested in &cols {
        resolved.push(self.resolve_column_name(requested)?);
    }
    let refs: Vec<&str> = resolved.iter().map(String::as_str).collect();
    let selected = transformations::select(self, refs, self.case_sensitive)?;
    if self.case_sensitive {
        return Ok(selected);
    }
    // Give the output the spelling the caller asked for.
    cols.iter()
        .zip(&resolved)
        .filter(|(requested, actual)| **requested != actual.as_str())
        .try_fold(selected, |frame, (requested, actual)| {
            frame.with_column_renamed(actual, requested)
        })
}
/// Filters rows; delegates to `transformations::filter` with this frame's
/// case-sensitivity setting.
pub fn filter(&self, condition: Expr) -> Result<DataFrame, PolarsError> {
transformations::filter(self, condition, self.case_sensitive)
}
/// Returns a `Column` built from the schema-resolved spelling of `name`.
///
/// # Errors
/// Fails with `ColumnNotFound` when `name` cannot be resolved.
pub fn column(&self, name: &str) -> Result<Column, PolarsError> {
let resolved = self.resolve_column_name(name)?;
Ok(Column::new(resolved))
}
/// Adds or replaces a column; delegates to `transformations::with_column` with
/// this frame's case-sensitivity setting.
pub fn with_column(&self, column_name: &str, col: &Column) -> Result<DataFrame, PolarsError> {
transformations::with_column(self, column_name, col, self.case_sensitive)
}
/// Same as [`DataFrame::with_column`], but takes a raw Polars expression, which
/// is wrapped with `Column::from_expr(expr, None)` first.
pub fn with_column_expr(
    &self,
    column_name: &str,
    expr: Expr,
) -> Result<DataFrame, PolarsError> {
    let wrapped = Column::from_expr(expr, None);
    self.with_column(column_name, &wrapped)
}
/// Starts a grouped aggregation over the given columns.
///
/// Names are resolved against the schema first; the resolved names are stored
/// as `grouping_cols` in the returned `GroupedData`.
///
/// # Errors
/// Fails when any name cannot be resolved.
pub fn group_by(&self, column_names: Vec<&str>) -> Result<GroupedData, PolarsError> {
    use polars::prelude::*;
    let mut grouping_cols: Vec<String> = Vec::with_capacity(column_names.len());
    for requested in &column_names {
        grouping_cols.push(self.resolve_column_name(requested)?);
    }
    let keys: Vec<Expr> = grouping_cols
        .iter()
        .map(|name| col(name.as_str()))
        .collect();
    let lf = self.lazy_frame();
    let lazy_grouped = lf.clone().group_by(keys);
    Ok(GroupedData {
        lf,
        lazy_grouped,
        grouping_cols,
        case_sensitive: self.case_sensitive,
    })
}
/// Starts a grouped aggregation keyed by arbitrary expressions.
///
/// `grouping_col_names` must have one entry per expression and is stored
/// as-is in the returned `GroupedData`. Column references inside each
/// expression are resolved with [`DataFrame::resolve_expr_column_names`].
///
/// # Errors
/// Returns `ComputeError` when the two lengths differ, and propagates column
/// resolution errors.
pub fn group_by_exprs(
    &self,
    exprs: Vec<Expr>,
    grouping_col_names: Vec<String>,
) -> Result<GroupedData, PolarsError> {
    use polars::prelude::*;
    let (expr_count, name_count) = (exprs.len(), grouping_col_names.len());
    if expr_count != name_count {
        let message = format!(
            "group_by_exprs: {} exprs but {} names",
            expr_count, name_count
        );
        return Err(PolarsError::ComputeError(message.into()));
    }
    let mut keys: Vec<Expr> = Vec::with_capacity(expr_count);
    for key_expr in exprs {
        keys.push(self.resolve_expr_column_names(key_expr)?);
    }
    let lf = self.lazy_frame();
    let lazy_grouped = lf.clone().group_by(keys);
    Ok(GroupedData {
        lf,
        lazy_grouped,
        grouping_cols: grouping_col_names,
        case_sensitive: self.case_sensitive,
    })
}
/// Builds a `CubeRollupData` in cube mode (`is_cube: true`) over the
/// schema-resolved grouping columns.
///
/// # Errors
/// Fails when any name cannot be resolved.
pub fn cube(&self, column_names: Vec<&str>) -> Result<CubeRollupData, PolarsError> {
    let mut grouping_cols: Vec<String> = Vec::with_capacity(column_names.len());
    for requested in &column_names {
        grouping_cols.push(self.resolve_column_name(requested)?);
    }
    Ok(CubeRollupData {
        lf: self.lazy_frame(),
        grouping_cols,
        case_sensitive: self.case_sensitive,
        is_cube: true,
    })
}
/// Builds a `CubeRollupData` in rollup mode (`is_cube: false`) over the
/// schema-resolved grouping columns.
///
/// # Errors
/// Fails when any name cannot be resolved.
pub fn rollup(&self, column_names: Vec<&str>) -> Result<CubeRollupData, PolarsError> {
    let mut grouping_cols: Vec<String> = Vec::with_capacity(column_names.len());
    for requested in &column_names {
        grouping_cols.push(self.resolve_column_name(requested)?);
    }
    Ok(CubeRollupData {
        lf: self.lazy_frame(),
        grouping_cols,
        case_sensitive: self.case_sensitive,
        is_cube: false,
    })
}
/// Aggregates over the whole frame (no grouping) and collects the result eagerly.
///
/// Column references are resolved, the expressions are passed through
/// `aggregations::disambiguate_agg_output_names`, and the selection is collected.
///
/// # Errors
/// Propagates column resolution and collection errors.
pub fn agg(&self, aggregations: Vec<Expr>) -> Result<DataFrame, PolarsError> {
    let mut resolved: Vec<Expr> = Vec::with_capacity(aggregations.len());
    for agg_expr in aggregations {
        resolved.push(self.resolve_expr_column_names(agg_expr)?);
    }
    let outputs = aggregations::disambiguate_agg_output_names(resolved);
    let collected = self.lazy_frame().select(outputs).collect()?;
    Ok(Self::from_polars_with_options(collected, self.case_sensitive))
}
/// Joins this frame with `other` on the given key columns.
///
/// Key names are resolved against this frame's schema before being handed to
/// the free function `join`, together with this frame's case-sensitivity setting.
///
/// # Errors
/// Fails when a key cannot be resolved on this frame, or when the join fails.
pub fn join(
    &self,
    other: &DataFrame,
    on: Vec<&str>,
    how: JoinType,
) -> Result<DataFrame, PolarsError> {
    let mut keys: Vec<String> = Vec::with_capacity(on.len());
    for key in &on {
        keys.push(self.resolve_column_name(key)?);
    }
    let key_refs: Vec<&str> = keys.iter().map(String::as_str).collect();
    join(self, other, key_refs, how, self.case_sensitive)
}
/// Sorts by the given columns; `ascending` is forwarded unchanged to
/// `transformations::order_by`.
///
/// # Errors
/// Fails when a column cannot be resolved or the sort fails.
pub fn order_by(
    &self,
    column_names: Vec<&str>,
    ascending: Vec<bool>,
) -> Result<DataFrame, PolarsError> {
    let mut sort_cols: Vec<String> = Vec::with_capacity(column_names.len());
    for requested in &column_names {
        sort_cols.push(self.resolve_column_name(requested)?);
    }
    let refs: Vec<&str> = sort_cols.iter().map(String::as_str).collect();
    transformations::order_by(self, refs, ascending, self.case_sensitive)
}
// The methods below are thin wrappers: each forwards its arguments to the function
// of the same name in `transformations`, adding this frame's `case_sensitive` flag
// unless noted otherwise.

/// Sorts by `SortOrder` expressions; delegates to `transformations::order_by_exprs`.
pub fn order_by_exprs(&self, sort_orders: Vec<SortOrder>) -> Result<DataFrame, PolarsError> {
transformations::order_by_exprs(self, sort_orders, self.case_sensitive)
}
/// Delegates to `transformations::union`.
pub fn union(&self, other: &DataFrame) -> Result<DataFrame, PolarsError> {
transformations::union(self, other, self.case_sensitive)
}
/// Alias for [`DataFrame::union`].
pub fn union_all(&self, other: &DataFrame) -> Result<DataFrame, PolarsError> {
self.union(other)
}
/// Delegates to `transformations::union_by_name`, forwarding `allow_missing_columns`.
pub fn union_by_name(
&self,
other: &DataFrame,
allow_missing_columns: bool,
) -> Result<DataFrame, PolarsError> {
transformations::union_by_name(self, other, allow_missing_columns, self.case_sensitive)
}
/// Delegates to `transformations::distinct`, forwarding the optional column subset.
pub fn distinct(&self, subset: Option<Vec<&str>>) -> Result<DataFrame, PolarsError> {
transformations::distinct(self, subset, self.case_sensitive)
}
/// Delegates to `transformations::drop`.
pub fn drop(&self, columns: Vec<&str>) -> Result<DataFrame, PolarsError> {
transformations::drop(self, columns, self.case_sensitive)
}
/// Delegates to `transformations::dropna`, forwarding `subset`, `how` and `thresh`.
pub fn dropna(
&self,
subset: Option<Vec<&str>>,
how: &str,
thresh: Option<usize>,
) -> Result<DataFrame, PolarsError> {
transformations::dropna(self, subset, how, thresh, self.case_sensitive)
}
/// Delegates to `transformations::fillna`.
pub fn fillna(&self, value: Expr, subset: Option<Vec<&str>>) -> Result<DataFrame, PolarsError> {
transformations::fillna(self, value, subset, self.case_sensitive)
}
/// Delegates to `transformations::limit`.
pub fn limit(&self, n: usize) -> Result<DataFrame, PolarsError> {
transformations::limit(self, n, self.case_sensitive)
}
/// Delegates to `transformations::with_column_renamed`.
pub fn with_column_renamed(
&self,
old_name: &str,
new_name: &str,
) -> Result<DataFrame, PolarsError> {
transformations::with_column_renamed(self, old_name, new_name, self.case_sensitive)
}
/// Delegates to `transformations::replace`.
pub fn replace(
&self,
column_name: &str,
old_value: Expr,
new_value: Expr,
) -> Result<DataFrame, PolarsError> {
transformations::replace(self, column_name, old_value, new_value, self.case_sensitive)
}
/// Delegates to `transformations::cross_join`.
pub fn cross_join(&self, other: &DataFrame) -> Result<DataFrame, PolarsError> {
transformations::cross_join(self, other, self.case_sensitive)
}
/// Delegates to `transformations::describe`.
pub fn describe(&self) -> Result<DataFrame, PolarsError> {
transformations::describe(self, self.case_sensitive)
}
/// Returns a clone of this frame; the body performs no other work.
pub fn cache(&self) -> Result<DataFrame, PolarsError> {
Ok(self.clone())
}
/// Returns a clone of this frame; the body performs no other work.
pub fn persist(&self) -> Result<DataFrame, PolarsError> {
Ok(self.clone())
}
/// Returns a clone of this frame; the body performs no other work.
pub fn unpersist(&self) -> Result<DataFrame, PolarsError> {
Ok(self.clone())
}
/// Delegates to `transformations::subtract`.
pub fn subtract(&self, other: &DataFrame) -> Result<DataFrame, PolarsError> {
transformations::subtract(self, other, self.case_sensitive)
}
/// Delegates to `transformations::intersect`.
pub fn intersect(&self, other: &DataFrame) -> Result<DataFrame, PolarsError> {
transformations::intersect(self, other, self.case_sensitive)
}
/// Delegates to `transformations::sample`.
pub fn sample(
&self,
with_replacement: bool,
fraction: f64,
seed: Option<u64>,
) -> Result<DataFrame, PolarsError> {
transformations::sample(self, with_replacement, fraction, seed, self.case_sensitive)
}
/// Delegates to `transformations::random_split`.
pub fn random_split(
&self,
weights: &[f64],
seed: Option<u64>,
) -> Result<Vec<DataFrame>, PolarsError> {
transformations::random_split(self, weights, seed, self.case_sensitive)
}
/// Delegates to `transformations::sample_by`.
pub fn sample_by(
&self,
col_name: &str,
fractions: &[(Expr, f64)],
seed: Option<u64>,
) -> Result<DataFrame, PolarsError> {
transformations::sample_by(self, col_name, fractions, seed, self.case_sensitive)
}
/// Delegates to `transformations::first`.
pub fn first(&self) -> Result<DataFrame, PolarsError> {
transformations::first(self, self.case_sensitive)
}
/// Delegates to `transformations::head`.
pub fn head(&self, n: usize) -> Result<DataFrame, PolarsError> {
transformations::head(self, n, self.case_sensitive)
}
/// Delegates to `transformations::take`.
pub fn take(&self, n: usize) -> Result<DataFrame, PolarsError> {
transformations::take(self, n, self.case_sensitive)
}
/// Delegates to `transformations::tail`.
pub fn tail(&self, n: usize) -> Result<DataFrame, PolarsError> {
transformations::tail(self, n, self.case_sensitive)
}
/// Delegates to `transformations::is_empty` (no case-sensitivity argument).
pub fn is_empty(&self) -> bool {
transformations::is_empty(self)
}
/// Delegates to `transformations::to_df`, passing `names` as a slice.
pub fn to_df(&self, names: Vec<&str>) -> Result<DataFrame, PolarsError> {
transformations::to_df(self, &names, self.case_sensitive)
}
/// Returns a `DataFrameStat` helper borrowing this frame.
pub fn stat(&self) -> DataFrameStat<'_> {
DataFrameStat { df: self }
}
/// Shortcut for `self.stat().corr_matrix()`.
pub fn corr(&self) -> Result<DataFrame, PolarsError> {
self.stat().corr_matrix()
}
/// Shortcut for `self.stat().corr(col1, col2)`.
pub fn corr_cols(&self, col1: &str, col2: &str) -> Result<f64, PolarsError> {
self.stat().corr(col1, col2)
}
/// Shortcut for `self.stat().cov(col1, col2)`.
pub fn cov_cols(&self, col1: &str, col2: &str) -> Result<f64, PolarsError> {
self.stat().cov(col1, col2)
}
/// Alias for [`DataFrame::describe`].
pub fn summary(&self) -> Result<DataFrame, PolarsError> {
self.describe()
}
/// Delegates to `transformations::to_json`.
pub fn to_json(&self) -> Result<Vec<String>, PolarsError> {
transformations::to_json(self)
}
/// Delegates to `transformations::explain`.
pub fn explain(&self) -> String {
transformations::explain(self)
}
/// Delegates to `transformations::print_schema`; the schema text is returned as a `String`.
pub fn print_schema(&self) -> Result<String, PolarsError> {
transformations::print_schema(self)
}
/// Returns a clone of this frame; the body performs no other work.
pub fn checkpoint(&self) -> Result<DataFrame, PolarsError> {
Ok(self.clone())
}
/// Returns a clone of this frame; the body performs no other work.
pub fn local_checkpoint(&self) -> Result<DataFrame, PolarsError> {
Ok(self.clone())
}
/// Returns a clone of this frame; `_num_partitions` is ignored.
pub fn repartition(&self, _num_partitions: usize) -> Result<DataFrame, PolarsError> {
Ok(self.clone())
}
/// Returns a clone of this frame; both arguments are ignored.
pub fn repartition_by_range(
&self,
_num_partitions: usize,
_cols: Vec<&str>,
) -> Result<DataFrame, PolarsError> {
Ok(self.clone())
}
/// Returns `(column name, dtype)` pairs in schema order, where the dtype is the
/// `Debug` rendering of the Polars `DataType`.
///
/// # Errors
/// Propagates schema-resolution errors.
pub fn dtypes(&self) -> Result<Vec<(String, String)>, PolarsError> {
    let schema = self.schema_or_collect()?;
    let mut pairs = Vec::new();
    for (name, dtype) in schema.iter_names_and_dtypes() {
        pairs.push((name.to_string(), format!("{dtype:?}")));
    }
    Ok(pairs)
}
/// Returns a clone of this frame; `_cols` is ignored and no sorting is performed here.
pub fn sort_within_partitions(
&self,
_cols: &[crate::functions::SortOrder],
) -> Result<DataFrame, PolarsError> {
Ok(self.clone())
}
/// Returns a clone of this frame; `_num_partitions` is ignored.
pub fn coalesce(&self, _num_partitions: usize) -> Result<DataFrame, PolarsError> {
Ok(self.clone())
}
/// Returns a clone of this frame; the hint name and parameters are ignored.
pub fn hint(&self, _name: &str, _params: &[i32]) -> Result<DataFrame, PolarsError> {
Ok(self.clone())
}
/// Always returns `true`.
pub fn is_local(&self) -> bool {
true
}
/// Always returns an empty list.
pub fn input_files(&self) -> Vec<String> {
Vec::new()
}
/// Always returns `false`; `_other` is not inspected.
pub fn same_semantics(&self, _other: &DataFrame) -> bool {
false
}
/// Always returns `0`.
pub fn semantic_hash(&self) -> u64 {
0
}
/// Returns a clone of this frame; the name and expression are ignored.
pub fn observe(&self, _name: &str, _expr: Expr) -> Result<DataFrame, PolarsError> {
Ok(self.clone())
}
/// Returns a clone of this frame; both arguments are ignored.
pub fn with_watermark(
&self,
_event_time: &str,
_delay: &str,
) -> Result<DataFrame, PolarsError> {
Ok(self.clone())
}
/// Delegates to `transformations::select_expr` with this frame's case-sensitivity setting.
pub fn select_expr(&self, exprs: &[String]) -> Result<DataFrame, PolarsError> {
transformations::select_expr(self, exprs, self.case_sensitive)
}
/// Delegates to `transformations::col_regex` with this frame's case-sensitivity setting.
pub fn col_regex(&self, pattern: &str) -> Result<DataFrame, PolarsError> {
transformations::col_regex(self, pattern, self.case_sensitive)
}
/// Delegates to `transformations::with_columns` with this frame's case-sensitivity setting.
pub fn with_columns(&self, exprs: &[(String, Column)]) -> Result<DataFrame, PolarsError> {
transformations::with_columns(self, exprs, self.case_sensitive)
}
/// Delegates to `transformations::with_columns_renamed` with this frame's case-sensitivity setting.
pub fn with_columns_renamed(
&self,
renames: &[(String, String)],
) -> Result<DataFrame, PolarsError> {
transformations::with_columns_renamed(self, renames, self.case_sensitive)
}
/// Returns a `DataFrameNa` helper borrowing this frame.
pub fn na(&self) -> DataFrameNa<'_> {
DataFrameNa { df: self }
}
/// Delegates to `transformations::offset` with this frame's case-sensitivity setting.
pub fn offset(&self, n: usize) -> Result<DataFrame, PolarsError> {
transformations::offset(self, n, self.case_sensitive)
}
/// Delegates to `transformations::transform`, passing the closure `f` through.
pub fn transform<F>(&self, f: F) -> Result<DataFrame, PolarsError>
where
F: FnOnce(DataFrame) -> Result<DataFrame, PolarsError>,
{
transformations::transform(self, f)
}
/// Delegates to `transformations::freq_items` with this frame's case-sensitivity setting.
pub fn freq_items(&self, columns: &[&str], support: f64) -> Result<DataFrame, PolarsError> {
transformations::freq_items(self, columns, support, self.case_sensitive)
}
/// Delegates to `transformations::approx_quantile` with this frame's case-sensitivity setting.
pub fn approx_quantile(
&self,
column: &str,
probabilities: &[f64],
) -> Result<DataFrame, PolarsError> {
transformations::approx_quantile(self, column, probabilities, self.case_sensitive)
}
/// Delegates to `transformations::crosstab` with this frame's case-sensitivity setting.
pub fn crosstab(&self, col1: &str, col2: &str) -> Result<DataFrame, PolarsError> {
transformations::crosstab(self, col1, col2, self.case_sensitive)
}
/// Delegates to `transformations::melt` with this frame's case-sensitivity setting.
pub fn melt(&self, id_vars: &[&str], value_vars: &[&str]) -> Result<DataFrame, PolarsError> {
transformations::melt(self, id_vars, value_vars, self.case_sensitive)
}
/// Same as [`DataFrame::melt`]: forwards to `transformations::melt`.
pub fn unpivot(&self, ids: &[&str], values: &[&str]) -> Result<DataFrame, PolarsError> {
transformations::melt(self, ids, values, self.case_sensitive)
}
/// Not implemented: always returns `PolarsError::InvalidOperation`; both arguments are ignored.
pub fn pivot(
&self,
_pivot_col: &str,
_values: Option<Vec<&str>>,
) -> Result<DataFrame, PolarsError> {
Err(PolarsError::InvalidOperation(
"pivot is not yet implemented; use crosstab(col1, col2) for two-column cross-tabulation."
.into(),
))
}
/// Delegates to `transformations::except_all` with this frame's case-sensitivity setting.
pub fn except_all(&self, other: &DataFrame) -> Result<DataFrame, PolarsError> {
transformations::except_all(self, other, self.case_sensitive)
}
/// Delegates to `transformations::intersect_all` with this frame's case-sensitivity setting.
pub fn intersect_all(&self, other: &DataFrame) -> Result<DataFrame, PolarsError> {
transformations::intersect_all(self, other, self.case_sensitive)
}
/// Writes the collected frame with `crate::delta::write_delta`, forwarding `path`
/// and `overwrite`. Only compiled with the `delta` feature.
///
/// # Errors
/// Propagates errors from collecting the frame and from the Delta writer.
#[cfg(feature = "delta")]
pub fn write_delta(
&self,
path: impl AsRef<std::path::Path>,
overwrite: bool,
) -> Result<(), PolarsError> {
crate::delta::write_delta(self.collect_inner()?.as_ref(), path, overwrite)
}
/// Fallback when the `delta` feature is disabled: always returns
/// `PolarsError::InvalidOperation` and ignores both arguments.
#[cfg(not(feature = "delta"))]
pub fn write_delta(
&self,
_path: impl AsRef<std::path::Path>,
_overwrite: bool,
) -> Result<(), PolarsError> {
Err(PolarsError::InvalidOperation(
"Delta Lake requires the 'delta' feature. Build with --features delta.".into(),
))
}
/// Registers a clone of this frame under `name` via `session.register_table`.
/// Nothing is written to disk by this method itself.
pub fn save_as_delta_table(&self, session: &crate::session::SparkSession, name: &str) {
session.register_table(name, self.clone());
}
/// Returns a [`DataFrameWriter`] borrowing this frame, preset to overwrite mode,
/// Parquet format, no options and no partition columns.
pub fn write(&self) -> DataFrameWriter<'_> {
    let options = HashMap::new();
    let partition_by = Vec::new();
    DataFrameWriter {
        df: self,
        mode: WriteMode::Overwrite,
        format: WriteFormat::Parquet,
        options,
        partition_by,
    }
}
}
/// How [`DataFrameWriter::save`] treats the target path.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WriteMode {
    /// Write only the new data to the target.
    Overwrite,
    /// Combine the new data with what is already stored at the target.
    Append,
}
/// Behavior of `DataFrameWriter::save_as_table` when the target table may already exist.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SaveMode {
    /// Fail if the table already exists.
    ErrorIfExists,
    /// Replace any existing table.
    Overwrite,
    /// Add the new rows to the existing table, or create it if absent.
    Append,
    /// Do nothing if the table already exists.
    Ignore,
}
/// Output file format used by [`DataFrameWriter`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WriteFormat {
    /// Apache Parquet.
    Parquet,
    /// Delimited text.
    Csv,
    /// JSON.
    Json,
}
/// Builder-style writer for a borrowed [`DataFrame`], created by `DataFrame::write`.
pub struct DataFrameWriter<'a> {
// Frame whose data will be written.
df: &'a DataFrame,
// Overwrite or append.
mode: WriteMode,
// Output file format.
format: WriteFormat,
// Free-form writer options; `save` reads "header" and "sep" for CSV output.
options: HashMap<String, String>,
// Partition column names; when non-empty, `save` hands off to `save_partitioned`.
partition_by: Vec<String>,
}
impl<'a> DataFrameWriter<'a> {
/// Sets the write mode and returns the writer for chaining.
pub fn mode(mut self, mode: WriteMode) -> Self {
self.mode = mode;
self
}
/// Sets the output format and returns the writer for chaining.
pub fn format(mut self, format: WriteFormat) -> Self {
self.format = format;
self
}
/// Sets a single option, replacing any previous value stored under the same key.
pub fn option(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
self.options.insert(key.into(), value.into());
self
}
/// Merges several options into the writer; later pairs replace earlier values
/// stored under the same key.
pub fn options(mut self, opts: impl IntoIterator<Item = (String, String)>) -> Self {
    self.options.extend(opts);
    self
}
/// Replaces the list of partition columns with `cols`.
pub fn partition_by(mut self, cols: impl IntoIterator<Item = impl Into<String>>) -> Self {
    let mut names: Vec<String> = Vec::new();
    for name in cols {
        names.push(name.into());
    }
    self.partition_by = names;
    self
}
/// Registers this writer's frame as table `name` in `session` and, when the session
/// has a warehouse directory, persists it to `<warehouse>/<name>/data.parquet`.
///
/// Behavior per `mode`:
/// * `ErrorIfExists`: fails with `InvalidOperation` if the table is registered or its
///   warehouse directory exists; otherwise persists and registers.
/// * `Overwrite`: removes the warehouse directory (removal errors are ignored),
///   persists, then registers.
/// * `Append`: loads the existing rows (from the session if registered, else from the
///   warehouse), stacks the new rows underneath using the existing column order, then
///   persists and registers the merged frame. New-frame columns that the existing
///   table lacks are dropped by the reordering select. If no table exists yet, the
///   frame is registered and persisted as-is.
/// * `Ignore`: returns `Ok(())` without changes if the table already exists;
///   otherwise persists and registers.
///
/// # Errors
/// Returns `InvalidOperation` for `ErrorIfExists` conflicts and for an append whose
/// new frame lacks columns of the existing table; `ComputeError` for directory
/// creation and warehouse read failures; and propagates write/collect errors.
pub fn save_as_table(
&self,
session: &SparkSession,
name: &str,
mode: SaveMode,
) -> Result<(), PolarsError> {
use polars::prelude::*;
use std::fs;
use std::path::Path;
// `None` when the session has no warehouse directory configured.
let warehouse_path = session.warehouse_dir().map(|w| Path::new(w).join(name));
let warehouse_exists = warehouse_path.as_ref().is_some_and(|p| p.is_dir());
// Writes `df` as Parquet to `<dir>/data.parquet`, creating `dir` if needed.
fn persist_to_warehouse(
df: &crate::dataframe::DataFrame,
dir: &Path,
) -> Result<(), PolarsError> {
use std::fs;
fs::create_dir_all(dir).map_err(|e| {
PolarsError::ComputeError(format!("saveAsTable: create dir: {e}").into())
})?;
let file_path = dir.join("data.parquet");
df.write()
.mode(crate::dataframe::WriteMode::Overwrite)
.format(crate::dataframe::WriteFormat::Parquet)
.save(&file_path)
}
let final_df = match mode {
SaveMode::ErrorIfExists => {
if session.saved_table_exists(name) || warehouse_exists {
return Err(PolarsError::InvalidOperation(
format!(
"Table or view '{name}' already exists. SaveMode is ErrorIfExists."
)
.into(),
));
}
if let Some(ref p) = warehouse_path {
persist_to_warehouse(self.df, p)?;
}
self.df.clone()
}
SaveMode::Overwrite => {
if let Some(ref p) = warehouse_path {
// Best effort: the directory may not exist yet.
let _ = fs::remove_dir_all(p);
persist_to_warehouse(self.df, p)?;
}
self.df.clone()
}
SaveMode::Append => {
// Prefer the table registered in the session; fall back to the warehouse copy.
let existing_pl = if let Some(existing) = session.get_saved_table(name) {
existing.collect_inner()?.as_ref().clone()
} else if let (Some(ref p), true) = (warehouse_path.as_ref(), warehouse_exists) {
let data_file = p.join("data.parquet");
// Read `data.parquet` when present, otherwise scan the directory path itself.
let read_path = if data_file.is_file() {
data_file.as_path()
} else {
p.as_ref()
};
let lf = LazyFrame::scan_parquet(read_path, ScanArgsParquet::default())
.map_err(|e| {
PolarsError::ComputeError(
format!("saveAsTable append: read warehouse: {e}").into(),
)
})?;
lf.collect().map_err(|e| {
PolarsError::ComputeError(
format!("saveAsTable append: collect: {e}").into(),
)
})?
} else {
// Nothing to append to: behave like a first save. The table is registered
// before the warehouse write here, unlike the other branches.
session.register_table(name, self.df.clone());
if let Some(ref p) = warehouse_path {
persist_to_warehouse(self.df, p)?;
}
return Ok(());
};
let new_pl = self.df.collect_inner()?.as_ref().clone();
let existing_cols: Vec<&str> = existing_pl
.get_column_names()
.iter()
.map(|s| s.as_str())
.collect();
let new_cols = new_pl.get_column_names();
// Every existing column must be present in the new frame.
let missing: Vec<_> = existing_cols
.iter()
.filter(|c| !new_cols.iter().any(|n| n.as_str() == **c))
.collect();
if !missing.is_empty() {
return Err(PolarsError::InvalidOperation(
format!(
"saveAsTable append: new DataFrame missing columns: {:?}",
missing
)
.into(),
));
}
// Reorder (and trim) the new frame to the existing column order before stacking.
let new_ordered = new_pl.select(existing_cols.iter().copied())?;
let mut combined = existing_pl;
combined.vstack_mut(&new_ordered)?;
let merged = crate::dataframe::DataFrame::from_polars_with_options(
combined,
self.df.case_sensitive,
);
if let Some(ref p) = warehouse_path {
let _ = fs::remove_dir_all(p);
persist_to_warehouse(&merged, p)?;
}
merged
}
SaveMode::Ignore => {
if session.saved_table_exists(name) || warehouse_exists {
return Ok(());
}
if let Some(ref p) = warehouse_path {
persist_to_warehouse(self.df, p)?;
}
self.df.clone()
}
};
session.register_table(name, final_df);
Ok(())
}
/// Saves to `path` as Parquet, keeping this writer's mode, options and partition
/// columns but overriding the format.
pub fn parquet(&self, path: impl AsRef<std::path::Path>) -> Result<(), PolarsError> {
DataFrameWriter {
df: self.df,
mode: self.mode,
format: WriteFormat::Parquet,
options: self.options.clone(),
partition_by: self.partition_by.clone(),
}
.save(path)
}
/// Saves to `path` as CSV, keeping this writer's mode, options and partition
/// columns but overriding the format.
pub fn csv(&self, path: impl AsRef<std::path::Path>) -> Result<(), PolarsError> {
DataFrameWriter {
df: self.df,
mode: self.mode,
format: WriteFormat::Csv,
options: self.options.clone(),
partition_by: self.partition_by.clone(),
}
.save(path)
}
/// Saves to `path` as JSON, keeping this writer's mode, options and partition
/// columns but overriding the format.
pub fn json(&self, path: impl AsRef<std::path::Path>) -> Result<(), PolarsError> {
DataFrameWriter {
df: self.df,
mode: self.mode,
format: WriteFormat::Json,
options: self.options.clone(),
partition_by: self.partition_by.clone(),
}
.save(path)
}
/// Writes the frame to `path` in the configured format.
///
/// * With partition columns set, the collected frame is handed to
///   `save_partitioned` (for both modes) and nothing else is written here.
/// * `WriteMode::Overwrite` creates or truncates `path` and writes the frame.
/// * `WriteMode::Append` (unpartitioned) reads the existing file at `path`, if it
///   is a non-empty regular file, concatenates the new rows after the existing
///   ones and rewrites the file. A missing or zero-length file is treated as
///   "nothing to append to".
///
/// CSV options: `header` (`"true"`/`"1"` enable it, default on) and `sep` (first
/// byte of the value, default `,`). The same settings are used when re-reading an
/// existing CSV file for append, so files written with a custom separator
/// round-trip. NOTE(review): appending to a header-less CSV still fails at concat
/// time because the re-read columns get generated names; confirm whether that
/// case needs support.
///
/// # Errors
/// * Propagates errors from collecting the frame.
/// * In append mode, a failure to read the existing file is returned as
///   `ComputeError` instead of being ignored; ignoring it would truncate the file
///   and silently drop the rows already stored in it.
/// * File creation and writer failures are returned as `ComputeError`.
pub fn save(&self, path: impl AsRef<std::path::Path>) -> Result<(), PolarsError> {
    use polars::prelude::*;
    let path = path.as_ref();

    // CSV settings are needed both for writing and for re-reading on append.
    let has_header = self
        .options
        .get("header")
        .map(|v| v.eq_ignore_ascii_case("true") || v == "1")
        .unwrap_or(true);
    let delimiter = self
        .options
        .get("sep")
        .and_then(|s| s.bytes().next())
        .unwrap_or(b',');

    let new_rows: PlDataFrame = self.df.collect_inner()?.as_ref().clone();

    let append_to_existing_file = self.mode == WriteMode::Append
        && self.partition_by.is_empty()
        && path.is_file()
        && std::fs::metadata(path).map(|m| m.len() > 0).unwrap_or(false);

    let mut to_write: PlDataFrame = if append_to_existing_file {
        let existing = match self.format {
            WriteFormat::Parquet => LazyFrame::scan_parquet(path, ScanArgsParquet::default())
                .and_then(|lf| lf.collect()),
            WriteFormat::Csv => LazyCsvReader::new(path)
                .with_has_header(has_header)
                .with_separator(delimiter)
                .finish()
                .and_then(|lf| lf.collect()),
            WriteFormat::Json => LazyJsonLineReader::new(path)
                .finish()
                .and_then(|lf| lf.collect()),
        }
        .map_err(|e| {
            PolarsError::ComputeError(format!("write append: read existing file: {e}").into())
        })?;
        let lfs: [LazyFrame; 2] = [existing.lazy(), new_rows.lazy()];
        concat(lfs, UnionArgs::default())?.collect()?
    } else {
        new_rows
    };

    if !self.partition_by.is_empty() {
        return self.save_partitioned(path, &to_write);
    }

    match self.format {
        WriteFormat::Parquet => {
            let mut file = std::fs::File::create(path).map_err(|e| {
                PolarsError::ComputeError(format!("write parquet create: {e}").into())
            })?;
            ParquetWriter::new(&mut file)
                .finish(&mut to_write)
                .map_err(|e| PolarsError::ComputeError(format!("write parquet: {e}").into()))?;
        }
        WriteFormat::Csv => {
            let mut file = std::fs::File::create(path).map_err(|e| {
                PolarsError::ComputeError(format!("write csv create: {e}").into())
            })?;
            CsvWriter::new(&mut file)
                .include_header(has_header)
                .with_separator(delimiter)
                .finish(&mut to_write)
                .map_err(|e| PolarsError::ComputeError(format!("write csv: {e}").into()))?;
        }
        WriteFormat::Json => {
            let mut file = std::fs::File::create(path).map_err(|e| {
                PolarsError::ComputeError(format!("write json create: {e}").into())
            })?;
            JsonWriter::new(&mut file)
                .finish(&mut to_write)
                .map_err(|e| PolarsError::ComputeError(format!("write json: {e}").into()))?;
        }
    }
    Ok(())
}
/// Writes `to_write` under `path` as a Hive-style partitioned directory tree.
///
/// For every distinct combination of the partition columns a directory
/// `path/<col>=<value>/...` is created and the matching rows (without the
/// partition columns) are written to `part-NNNNN.<ext>` inside it.
///
/// * `WriteMode::Overwrite`: an existing `path` (file or directory) is removed
///   first and every partition gets `part-00000`.
/// * `WriteMode::Append`: the partition directory is scanned for existing
///   `part-<n>.<ext>` files and the new file gets the highest `n` plus one.
///
/// # Errors
/// Returns a `PolarsError` when a partition column cannot be resolved, when
/// filtering/selecting fails, or when any filesystem or encoding step fails.
/// In append mode a partition directory that cannot be listed is reported
/// as an error instead of being treated as empty, so an existing part file is
/// never overwritten because of a failed listing.
fn save_partitioned(&self, path: &Path, to_write: &PlDataFrame) -> Result<(), PolarsError> {
    use polars::prelude::*;
    let resolved: Vec<String> = self
        .partition_by
        .iter()
        .map(|c| self.df.resolve_column_name(c))
        .collect::<Result<Vec<_>, _>>()?;
    let all_names = to_write.get_column_names();
    // Columns that end up inside the files: everything that is not a partition key.
    let data_cols: Vec<&str> = all_names
        .iter()
        .filter(|n| !resolved.iter().any(|r| r == n.as_str()))
        .map(|n| n.as_str())
        .collect();
    let unique_keys = to_write
        .select(resolved.iter().map(|s| s.as_str()).collect::<Vec<_>>())?
        .unique::<Option<&[String]>, String>(
            None,
            polars::prelude::UniqueKeepStrategy::First,
            None,
        )?;
    if self.mode == WriteMode::Overwrite && path.exists() {
        if path.is_dir() {
            std::fs::remove_dir_all(path).map_err(|e| {
                PolarsError::ComputeError(
                    format!("write partitioned: remove_dir_all: {e}").into(),
                )
            })?;
        } else {
            std::fs::remove_file(path).map_err(|e| {
                PolarsError::ComputeError(format!("write partitioned: remove_file: {e}").into())
            })?;
        }
    }
    std::fs::create_dir_all(path).map_err(|e| {
        PolarsError::ComputeError(format!("write partitioned: create_dir_all: {e}").into())
    })?;
    let ext = match self.format {
        WriteFormat::Parquet => "parquet",
        WriteFormat::Csv => "csv",
        WriteFormat::Json => "json",
    };
    // Loop-invariant values: CSV options and the suffix used to recognise part files.
    let has_header = self
        .options
        .get("header")
        .map(|v| v.eq_ignore_ascii_case("true") || v == "1")
        .unwrap_or(true);
    let delimiter = self
        .options
        .get("sep")
        .and_then(|s| s.bytes().next())
        .unwrap_or(b',');
    let suffix = format!(".{ext}");

    for row_idx in 0..unique_keys.height() {
        let row = unique_keys
            .get(row_idx)
            .ok_or_else(|| PolarsError::ComputeError("partition_row: get row".into()))?;
        // Select the rows belonging to this key combination, then drop the key columns.
        let filter_expr = partition_row_to_filter_expr(&resolved, &row)?;
        let subset = to_write.clone().lazy().filter(filter_expr).collect()?;
        let mut subset = subset.select(data_cols.iter().copied())?;
        if subset.height() == 0 {
            continue;
        }
        let part_path: std::path::PathBuf = resolved
            .iter()
            .zip(row.iter())
            .map(|(name, av)| format!("{}={}", name, format_partition_value(av)))
            .fold(path.to_path_buf(), |p, seg| p.join(seg));
        std::fs::create_dir_all(&part_path).map_err(|e| {
            PolarsError::ComputeError(
                format!("write partitioned: create_dir_all partition: {e}").into(),
            )
        })?;
        let file_idx = if self.mode == WriteMode::Append {
            let entries = std::fs::read_dir(&part_path).map_err(|e| {
                PolarsError::ComputeError(
                    format!("write partitioned: read_dir partition: {e}").into(),
                )
            })?;
            // Highest existing `part-<n>.<ext>` index; unreadable entries and
            // names that do not match the pattern are ignored.
            let max_n = entries
                .filter_map(Result::ok)
                .filter_map(|e| {
                    e.file_name().to_str().and_then(|s| {
                        s.strip_prefix("part-")
                            .and_then(|t| t.strip_suffix(&suffix))
                            .and_then(|t| t.parse::<u32>().ok())
                    })
                })
                .max()
                .unwrap_or(0);
            max_n + 1
        } else {
            0
        };
        let filename = format!("part-{file_idx:05}.{ext}");
        let file_path = part_path.join(&filename);
        match self.format {
            WriteFormat::Parquet => {
                let mut file = std::fs::File::create(&file_path).map_err(|e| {
                    PolarsError::ComputeError(
                        format!("write partitioned parquet create: {e}").into(),
                    )
                })?;
                ParquetWriter::new(&mut file)
                    .finish(&mut subset)
                    .map_err(|e| {
                        PolarsError::ComputeError(
                            format!("write partitioned parquet: {e}").into(),
                        )
                    })?;
            }
            WriteFormat::Csv => {
                let mut file = std::fs::File::create(&file_path).map_err(|e| {
                    PolarsError::ComputeError(
                        format!("write partitioned csv create: {e}").into(),
                    )
                })?;
                CsvWriter::new(&mut file)
                    .include_header(has_header)
                    .with_separator(delimiter)
                    .finish(&mut subset)
                    .map_err(|e| {
                        PolarsError::ComputeError(format!("write partitioned csv: {e}").into())
                    })?;
            }
            WriteFormat::Json => {
                let mut file = std::fs::File::create(&file_path).map_err(|e| {
                    PolarsError::ComputeError(
                        format!("write partitioned json create: {e}").into(),
                    )
                })?;
                JsonWriter::new(&mut file)
                    .finish(&mut subset)
                    .map_err(|e| {
                        PolarsError::ComputeError(format!("write partitioned json: {e}").into())
                    })?;
            }
        }
    }
    Ok(())
}
}
impl Clone for DataFrame {
fn clone(&self) -> Self {
DataFrame {
inner: match &self.inner {
DataFrameInner::Eager(df) => DataFrameInner::Eager(df.clone()),
DataFrameInner::Lazy(lf) => DataFrameInner::Lazy(lf.clone()),
},
case_sensitive: self.case_sensitive,
alias: self.alias.clone(),
}
}
}
/// Renders a partition key value as the `<value>` part of a `name=value`
/// partition directory.
///
/// * Nulls become `__HIVE_DEFAULT_PARTITION__`.
/// * Booleans, 32/64-bit integers, floats and strings use their plain
///   `to_string` form (strings without surrounding quotes).
/// * Every other value, including dates, uses `AnyValue`'s `Display` output.
///   `AnyValue::Date` carries a raw day count, so formatting the payload
///   directly would name the directory after that integer rather than the
///   calendar date.
///
/// Path separators (`/` and the platform's main separator) in the result are
/// replaced with `_` so the value always stays a single path component.
fn format_partition_value(av: &AnyValue<'_>) -> String {
    let rendered = match av {
        AnyValue::Null => "__HIVE_DEFAULT_PARTITION__".to_string(),
        AnyValue::Boolean(b) => b.to_string(),
        AnyValue::Int32(i) => i.to_string(),
        AnyValue::Int64(i) => i.to_string(),
        AnyValue::UInt32(u) => u.to_string(),
        AnyValue::UInt64(u) => u.to_string(),
        AnyValue::Float32(f) => f.to_string(),
        AnyValue::Float64(f) => f.to_string(),
        AnyValue::String(s) => s.to_string(),
        AnyValue::StringOwned(s) => s.as_str().to_string(),
        // Dates and all remaining variants: rely on the value's own Display.
        _ => av.to_string(),
    };
    rendered.replace([std::path::MAIN_SEPARATOR, '/'], "_")
}
fn partition_row_to_filter_expr(
col_names: &[String],
row: &[AnyValue<'_>],
) -> Result<Expr, PolarsError> {
if col_names.len() != row.len() {
return Err(PolarsError::ComputeError(
format!(
"partition_row_to_filter_expr: {} columns but {} row values",
col_names.len(),
row.len()
)
.into(),
));
}
let mut pred = None::<Expr>;
for (name, av) in col_names.iter().zip(row.iter()) {
let clause = match av {
AnyValue::Null => col(name.as_str()).is_null(),
AnyValue::Boolean(b) => col(name.as_str()).eq(lit(*b)),
AnyValue::Int32(i) => col(name.as_str()).eq(lit(*i)),
AnyValue::Int64(i) => col(name.as_str()).eq(lit(*i)),
AnyValue::UInt32(u) => col(name.as_str()).eq(lit(*u)),
AnyValue::UInt64(u) => col(name.as_str()).eq(lit(*u)),
AnyValue::Float32(f) => col(name.as_str()).eq(lit(*f)),
AnyValue::Float64(f) => col(name.as_str()).eq(lit(*f)),
AnyValue::String(s) => col(name.as_str()).eq(lit(s.to_string())),
AnyValue::StringOwned(s) => col(name.as_str()).eq(lit(s.clone())),
_ => {
let s = av.to_string();
col(name.as_str()).cast(DataType::String).eq(lit(s))
}
};
pred = Some(match pred {
None => clause,
Some(p) => p.and(clause),
});
}
Ok(pred.unwrap_or_else(|| lit(true)))
}
/// Converts a single Polars value into a `serde_json` value.
///
/// * Null -> `null`
/// * Boolean -> JSON bool
/// * Signed/unsigned integers of 8, 16, 32 and 64 bits -> JSON number
/// * Floats -> JSON number, or `null` when the value is not finite
///   (`Number::from_f64` rejects NaN and infinities)
/// * Strings -> JSON string
/// * Anything else -> `null`
fn any_value_to_json(av: AnyValue<'_>) -> JsonValue {
    match av {
        AnyValue::Null => JsonValue::Null,
        AnyValue::Boolean(b) => JsonValue::Bool(b),
        // 8/16-bit integers are handled explicitly so they are not lost to the
        // catch-all `null` arm below.
        AnyValue::Int8(i) => JsonValue::Number(serde_json::Number::from(i)),
        AnyValue::Int16(i) => JsonValue::Number(serde_json::Number::from(i)),
        AnyValue::Int32(i) => JsonValue::Number(serde_json::Number::from(i)),
        AnyValue::Int64(i) => JsonValue::Number(serde_json::Number::from(i)),
        AnyValue::UInt8(u) => JsonValue::Number(serde_json::Number::from(u)),
        AnyValue::UInt16(u) => JsonValue::Number(serde_json::Number::from(u)),
        AnyValue::UInt32(u) => JsonValue::Number(serde_json::Number::from(u)),
        AnyValue::UInt64(u) => JsonValue::Number(serde_json::Number::from(u)),
        AnyValue::Float32(f) => serde_json::Number::from_f64(f64::from(f))
            .map(JsonValue::Number)
            .unwrap_or(JsonValue::Null),
        AnyValue::Float64(f) => serde_json::Number::from_f64(f)
            .map(JsonValue::Number)
            .unwrap_or(JsonValue::Null),
        AnyValue::String(s) => JsonValue::String(s.to_string()),
        AnyValue::StringOwned(s) => JsonValue::String(s.to_string()),
        _ => JsonValue::Null,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use polars::prelude::{NamedFrom, Series};

    /// Filtering a string column against an integer literal keeps exactly the
    /// one row holding "123".
    #[test]
    fn coerce_string_numeric_root_in_filter() {
        let strings = Series::new("str_col".into(), &["123", "456"]);
        let raw = polars::prelude::DataFrame::new(vec![strings.into()]).unwrap();
        let frame = DataFrame::from_polars(raw);
        let predicate = col("str_col").eq(lit(123i64));
        let matched = frame.filter(predicate).unwrap();
        assert_eq!(matched.count().unwrap(), 1);
    }

    /// Column names, case-insensitive resolution and dtypes are available on a
    /// session-created frame.
    #[test]
    fn lazy_schema_columns_resolve_before_collect() {
        let session = SparkSession::builder()
            .app_name("lazy_mod_tests")
            .get_or_create();
        let rows = vec![
            (1i64, 25i64, "a".to_string()),
            (2i64, 30i64, "b".to_string()),
        ];
        let frame = session
            .create_dataframe(rows, vec!["id", "age", "name"])
            .unwrap();
        assert_eq!(frame.columns().unwrap(), vec!["id", "age", "name"]);
        assert_eq!(frame.resolve_column_name("AGE").unwrap(), "age");
        assert!(frame.get_column_dtype("id").unwrap().is_integer());
    }

    /// A frame built directly from a `LazyFrame` reports its columns and row count.
    #[test]
    fn lazy_from_lazy_produces_valid_df() {
        let _session = SparkSession::builder()
            .app_name("lazy_mod_tests")
            .get_or_create();
        let raw = polars::prelude::df!("x" => &[1i64, 2, 3]).unwrap();
        let frame = DataFrame::from_lazy_with_options(raw.lazy(), false);
        assert_eq!(frame.columns().unwrap(), vec!["x"]);
        assert_eq!(frame.count().unwrap(), 3);
    }
}