use polars::prelude::*;
/// Alias for the Polars `Operator` type; used in this file as the
/// comparison-operator parameter of `coerce_for_pyspark_comparison`.
pub type CompareOp = polars::prelude::Operator;
/// Rank of a type in the coercion ladder: when two ranked types meet, the one
/// with the higher rank is chosen as the common type.
///
/// The derived `Ord` follows the explicit discriminants below.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
// `Decimal` is declared for completeness but `dtype_to_precedence` never
// produces it, hence the lint suppression.
#[allow(dead_code)]
enum TypePrecedence {
    /// 32-bit signed integer.
    Int = 1,
    /// 64-bit signed integer.
    Long = 2,
    /// Decimal; currently has no dtype mapped onto it.
    Decimal = 3,
    /// 32-bit float.
    Float = 4,
    /// 64-bit float.
    Double = 5,
    /// String, the widest rank.
    String = 6,
}
/// Looks up the coercion rank of `dtype`.
///
/// Only `Int32`, `Int64`, `Float32`, `Float64` and `String` are ranked; every
/// other dtype yields `None`.
fn dtype_to_precedence(dtype: &DataType) -> Option<TypePrecedence> {
    let rank = match dtype {
        DataType::Int32 => TypePrecedence::Int,
        DataType::Int64 => TypePrecedence::Long,
        DataType::Float32 => TypePrecedence::Float,
        DataType::Float64 => TypePrecedence::Double,
        DataType::String => TypePrecedence::String,
        _ => return None,
    };
    Some(rank)
}
/// Determines the dtype both operands should be cast to before they are
/// compared or combined.
///
/// Resolution order:
/// 1. `DataType::Unknown(_)` on either side is treated as `String`.
/// 2. Identical dtypes are returned unchanged.
/// 3. If both dtypes are ranked by `dtype_to_precedence`, the higher-ranked
///    one wins.
/// 4. Otherwise two numeric dtypes widen to `Float64`, and any pairing that
///    involves `String` becomes `String`.
///
/// # Errors
///
/// Returns `PolarsError::ComputeError` when no rule applies (for example two
/// different non-numeric, non-string dtypes).
pub fn find_common_type(left: &DataType, right: &DataType) -> Result<DataType, PolarsError> {
    // Normalise by reference so no dtype is cloned just to be inspected.
    let string = DataType::String;
    let left = if matches!(left, DataType::Unknown(_)) {
        &string
    } else {
        left
    };
    let right = if matches!(right, DataType::Unknown(_)) {
        &string
    } else {
        right
    };

    // Identical dtypes need no coercion. This must run before the numeric
    // fallback below: otherwise equal numeric dtypes outside the precedence
    // table (e.g. Int8/Int8 or UInt64/UInt64) would be widened to Float64,
    // which cannot represent every 64-bit integer exactly.
    if left == right {
        return Ok(left.clone());
    }

    if let (Some(l), Some(r)) = (dtype_to_precedence(left), dtype_to_precedence(right)) {
        let target_prec = l.max(r);
        return match target_prec {
            TypePrecedence::Int => Ok(DataType::Int32),
            TypePrecedence::Long => Ok(DataType::Int64),
            TypePrecedence::Float => Ok(DataType::Float32),
            TypePrecedence::Double => Ok(DataType::Float64),
            TypePrecedence::String => Ok(DataType::String),
            // No dtype maps to Decimal today; listing the variant explicitly
            // (instead of `_`) makes the compiler flag any rank added later.
            TypePrecedence::Decimal => Err(PolarsError::ComputeError(
                format!(
                    "Type coercion: unsupported type precedence {target_prec:?}. Supported: Int32, Int64, Float32, Float64, String."
                )
                .into(),
            )),
        };
    }

    // At least one side is not in the precedence table.
    if is_numeric(left) && is_numeric(right) {
        Ok(DataType::Float64)
    } else if left == &DataType::String || right == &DataType::String {
        Ok(DataType::String)
    } else {
        Err(PolarsError::ComputeError(
            format!(
                "Type coercion: cannot find common type for {left:?} and {right:?}. Hint: use cast() to align types, or ensure both are numeric or both are string."
            )
            .into(),
        ))
    }
}
/// Picks the dtype join keys should be cast to.
///
/// Differs from `find_common_type` in one way: when one side is numeric and
/// the other is `String`, the numeric dtype is chosen. All other pairings are
/// delegated to `find_common_type`. `DataType::Unknown(_)` is treated as
/// `String` on either side.
///
/// # Errors
///
/// Propagates the error from `find_common_type`.
pub fn find_common_type_for_join(
    left: &DataType,
    right: &DataType,
) -> Result<DataType, PolarsError> {
    let normalize = |dtype: &DataType| match dtype {
        DataType::Unknown(_) => DataType::String,
        known => known.clone(),
    };
    let (lhs, rhs) = (normalize(left), normalize(right));

    // Numeric vs. string: keep the numeric side's dtype.
    if rhs == DataType::String && is_numeric(&lhs) {
        return Ok(lhs);
    }
    if lhs == DataType::String && is_numeric(&rhs) {
        return Ok(rhs);
    }

    find_common_type(&lhs, &rhs)
}
/// Builds two column expressions cast to their common dtype.
///
/// The common dtype comes from `find_common_type`. Both returned expressions
/// are aliased to the same `alias`.
///
/// # Errors
///
/// Propagates the error from `find_common_type`.
pub fn coerce_expr_pair(
    left_name: &str,
    right_name: &str,
    left_dtype: &DataType,
    right_dtype: &DataType,
    alias: &str,
) -> Result<(Expr, Expr), PolarsError> {
    find_common_type(left_dtype, right_dtype).map(|target| {
        let lhs = col(left_name).cast(target.clone()).alias(alias);
        let rhs = col(right_name).cast(target).alias(alias);
        (lhs, rhs)
    })
}
/// Join-key variant of `coerce_expr_pair`.
///
/// The common dtype comes from `find_common_type_for_join`. Both returned
/// expressions are aliased to the same `alias`.
///
/// # Errors
///
/// Propagates the error from `find_common_type_for_join`.
pub fn coerce_expr_pair_for_join(
    left_name: &str,
    right_name: &str,
    left_dtype: &DataType,
    right_dtype: &DataType,
    alias: &str,
) -> Result<(Expr, Expr), PolarsError> {
    find_common_type_for_join(left_dtype, right_dtype).map(|target| {
        let lhs = col(left_name).cast(target.clone()).alias(alias);
        let rhs = col(right_name).cast(target).alias(alias);
        (lhs, rhs)
    })
}
/// Returns `true` for the 8/16/32/64-bit signed and unsigned integer dtypes
/// and for `Float32`/`Float64`; `false` for everything else.
fn is_numeric(dtype: &DataType) -> bool {
    let signed_int = matches!(
        dtype,
        DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64
    );
    let unsigned_int = matches!(
        dtype,
        DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64
    );
    let float = matches!(dtype, DataType::Float32 | DataType::Float64);

    signed_int || unsigned_int || float
}
/// Public entry point for the private `is_numeric` check; returns exactly
/// what `is_numeric(dtype)` returns.
pub fn is_numeric_public(dtype: &DataType) -> bool {
is_numeric(dtype)
}
/// Returns `true` when `dtype` is `Date` or any `Datetime` (regardless of its
/// two parameters).
fn is_date_or_datetime(dtype: &DataType) -> bool {
    matches!(dtype, DataType::Date | DataType::Datetime(..))
}
/// Wraps `expr` in a cast to `target_type` (delegates to `Expr::cast`).
pub fn coerce_to_type(expr: Expr, target_type: DataType) -> Expr {
expr.cast(target_type)
}
/// Casts the two operands of a comparison to their common dtype.
///
/// When the declared dtypes are already equal both expressions are returned
/// untouched. Otherwise the common dtype is obtained from `find_common_type`
/// and only the side(s) whose dtype differs from it are cast.
///
/// # Errors
///
/// Propagates the error from `find_common_type`.
pub fn coerce_for_comparison(
    left: Expr,
    right: Expr,
    left_type: &DataType,
    right_type: &DataType,
) -> Result<(Expr, Expr), PolarsError> {
    if left_type == right_type {
        return Ok((left, right));
    }

    let target = find_common_type(left_type, right_type)?;

    // Cast only when the operand is not already of the target dtype.
    let align = |expr: Expr, current: &DataType| {
        if current == &target {
            expr
        } else {
            coerce_to_type(expr, target.clone())
        }
    };

    Ok((align(left, left_type), align(right, right_type)))
}
pub fn coerce_for_pyspark_comparison(
left: Expr,
right: Expr,
left_type: &DataType,
right_type: &DataType,
_op: &CompareOp,
) -> Result<(Expr, Expr), PolarsError> {
use crate::column::Column;
if is_numeric(left_type) && is_numeric(right_type) {
return coerce_for_comparison(left, right, left_type, right_type);
}
fn wrap_try_to_number(expr: Expr) -> Result<Expr, PolarsError> {
let col = Column::from_expr(expr, None);
let coerced = crate::functions::try_to_number(&col, None)
.map_err(|e| PolarsError::ComputeError(e.into()))?;
Ok(coerced.into_expr())
}
let string_numeric = (left_type == &DataType::String && is_numeric(right_type))
|| (right_type == &DataType::String && is_numeric(left_type));
if string_numeric {
let left_out = if left_type == &DataType::String {
wrap_try_to_number(left)?
} else if is_numeric(left_type) {
coerce_to_type(left, DataType::Float64)
} else {
left
};
let right_out = if right_type == &DataType::String {
wrap_try_to_number(right)?
} else if is_numeric(right_type) {
coerce_to_type(right, DataType::Float64)
} else {
right
};
return Ok((left_out, right_out));
}
fn wrap_try_to_temporal(expr: Expr, target: &DataType) -> Result<Expr, PolarsError> {
if matches!(target, DataType::Date | DataType::Datetime(_, _)) {
if matches!(&expr, Expr::Literal(_)) {
return Ok(expr.cast(target.clone()));
}
}
let col = Column::from_expr(expr, None);
let type_name = match target {
DataType::Date => "date",
DataType::Datetime(..) => "timestamp",
_ => {
return Err(PolarsError::ComputeError(
"date or datetime type required".to_string().into(),
));
}
};
let coerced = crate::functions::try_cast(&col, type_name)
.map_err(|e| PolarsError::ComputeError(e.into()))?;
Ok(coerced.into_expr())
}
let temporal_string = (is_date_or_datetime(left_type) && right_type == &DataType::String)
|| (left_type == &DataType::String && is_date_or_datetime(right_type));
if temporal_string {
let left_out = if left_type == &DataType::String {
wrap_try_to_temporal(left, right_type)?
} else {
left
};
let right_out = if right_type == &DataType::String {
wrap_try_to_temporal(right, left_type)?
} else {
right
};
return Ok((left_out, right_out));
}
let date_vs_datetime = (left_type == &DataType::Date
&& matches!(right_type, DataType::Datetime(_, _)))
|| (matches!(left_type, DataType::Datetime(_, _)) && right_type == &DataType::Date);
if date_vs_datetime {
let target_dt = if matches!(left_type, DataType::Datetime(_, _)) {
left_type.clone()
} else {
right_type.clone()
};
let left_out = if left_type == &DataType::Date {
coerce_to_type(left, target_dt.clone())
} else {
left
};
let right_out = if right_type == &DataType::Date {
coerce_to_type(right, target_dt)
} else {
right
};
return Ok((left_out, right_out));
}
if left_type == right_type && !is_numeric(left_type) {
return Ok((left, right));
}
coerce_for_comparison(left, right, left_type, right_type)
}
/// Infers a dtype from the expression alone, without a schema.
///
/// Only literals are supported: the literal's own dtype is returned, except
/// that `DataType::Unknown(_)` is reported as `Float64`. Any non-literal
/// expression yields `None`.
pub fn infer_type_from_expr(expr: &Expr) -> Option<DataType> {
    if let Expr::Literal(value) = expr {
        match value.get_datatype() {
            DataType::Unknown(_) => Some(DataType::Float64),
            concrete => Some(concrete),
        }
    } else {
        None
    }
}
/// Placeholder check: the argument is ignored and the result is always
/// `false`.
fn is_non_numeric_string_literal(_expr: &Expr) -> bool {
    false
}
pub fn coerce_for_pyspark_eq_null_safe(
left: Expr,
right: Expr,
) -> Result<(Expr, Expr), PolarsError> {
if matches!(infer_type_from_expr(&left), Some(DataType::Null))
|| matches!(infer_type_from_expr(&right), Some(DataType::Null))
{
return Ok((left, right));
}
let left_inferred = infer_type_from_expr(&left);
let right_inferred = infer_type_from_expr(&right);
let _left_non_numeric_str_lit = is_non_numeric_string_literal(&left);
let _right_non_numeric_str_lit = is_non_numeric_string_literal(&right);
let (left_ty, right_ty) = match (left_inferred.clone(), right_inferred.clone()) {
(Some(lt), Some(rt)) => (lt, rt),
(Some(DataType::String), None) if matches!(right, Expr::Column(_)) => {
(DataType::String, DataType::String)
}
(None, Some(DataType::String)) if matches!(left, Expr::Column(_)) => {
(DataType::String, DataType::String)
}
(Some(DataType::String), None) => (DataType::String, DataType::Float64),
(None, Some(DataType::String)) => (DataType::Float64, DataType::String),
(Some(lt), None) => (lt.clone(), lt),
(None, Some(rt)) => {
if is_numeric(&rt) && matches!(left, Expr::Column(_)) {
(DataType::String, rt)
} else {
(rt.clone(), rt)
}
}
(None, None) => (DataType::String, DataType::String),
};
if left_inferred.is_none() && right_inferred.is_none() {
let left_str = left.cast(DataType::String);
let right_str = right.cast(DataType::String);
return Ok((left_str, right_str));
}
coerce_for_pyspark_comparison(left, right, &left_ty, &right_ty, &CompareOp::Eq)
}
/// Coerces the operands of an arithmetic expression.
///
/// Only the string-vs-numeric pairing is touched: the string side goes through
/// `crate::functions::try_to_number` and the numeric side is cast to
/// `Float64`. Every other pairing is returned unchanged.
///
/// # Errors
///
/// Returns `PolarsError::ComputeError` when `try_to_number` fails.
pub fn coerce_for_pyspark_arithmetic(
    left: Expr,
    right: Expr,
    left_type: &DataType,
    right_type: &DataType,
) -> Result<(Expr, Expr), PolarsError> {
    use crate::column::Column;

    /// Routes `expr` through `crate::functions::try_to_number`.
    fn wrap_try_to_number(expr: Expr) -> Result<Expr, PolarsError> {
        let column = Column::from_expr(expr, None);
        crate::functions::try_to_number(&column, None)
            .map(|parsed| parsed.into_expr())
            .map_err(|e| PolarsError::ComputeError(e.into()))
    }

    let left_is_str = left_type == &DataType::String;
    let right_is_str = right_type == &DataType::String;
    let mixed = (left_is_str && is_numeric(right_type))
        || (right_is_str && is_numeric(left_type));
    if !mixed {
        return Ok((left, right));
    }

    // From here on exactly one side is a string and the other is numeric.
    let to_number = |expr: Expr, is_string: bool| -> Result<Expr, PolarsError> {
        if is_string {
            wrap_try_to_number(expr)
        } else {
            Ok(coerce_to_type(expr, DataType::Float64))
        }
    };

    let left_out = to_number(left, left_is_str)?;
    let right_out = to_number(right, right_is_str)?;
    Ok((left_out, right_out))
}
// Unit tests for the coercion helpers; each test builds a small DataFrame and
// evaluates the coerced expressions through the lazy API.
#[cfg(test)]
mod tests {
use super::*;
use polars::prelude::{IntoLazy, df};
// Int32 vs. Int64 columns holding equal values: after coercion an equality
// filter must keep all three rows.
#[test]
fn numeric_numeric_uses_standard_coercion() -> Result<(), PolarsError> {
let df = df!(
"a" => &[1i32, 2, 3],
"b" => &[1i64, 2, 3]
)?;
let a = col("a");
let b = col("b");
let (ac, bc) = coerce_for_pyspark_comparison(
a.clone(),
b.clone(),
&DataType::Int32,
&DataType::Int64,
&CompareOp::Eq,
)?;
let out = df.lazy().filter(ac.eq(bc)).collect()?;
assert_eq!(out.height(), 3);
Ok(())
}
// String column vs. an i64 literal through the null-safe-equality coercion:
// both sides are expected to come out as f64, with "abc" becoming null and
// "123" becoming 123.0.
#[test]
fn eqnullsafe_string_column_vs_int_literal_coerces_like_pyspark() -> Result<(), PolarsError> {
let df = df!(
"str_col" => &["abc", "123"]
)?;
let a = col("str_col");
let lit_123 = lit(123i64);
let (ac, bc) = coerce_for_pyspark_eq_null_safe(a.clone(), lit_123)?;
let lf = df.lazy().select(&[
ac.clone().alias("left"),
bc.clone().alias("right"),
ac.clone().eq(bc.clone()).alias("eq_raw"),
ac.is_null().alias("left_null"),
bc.is_null().alias("right_null"),
]);
let out = lf.collect()?;
let left = out.column("left")?.f64()?;
let right = out.column("right")?.f64()?;
let eq_raw = out.column("eq_raw")?.bool()?;
let left_null = out.column("left_null")?.bool()?;
let right_null = out.column("right_null")?.bool()?;
// Row 0: "abc" does not parse, so the left side is null.
assert!(left.get(0).is_none());
assert_eq!(right.get(0), Some(123.0));
// Plain `eq` against a null may yield either null or false; both are accepted.
assert!(eq_raw.get(0).is_none() || eq_raw.get(0) == Some(false));
assert_eq!(left_null.get(0), Some(true));
assert_eq!(right_null.get(0), Some(false));
// Row 1: "123" parses to 123.0 and equals the literal.
assert_eq!(left.get(1), Some(123.0));
assert_eq!(right.get(1), Some(123.0));
assert_eq!(eq_raw.get(1), Some(true));
assert_eq!(left_null.get(1), Some(false));
assert_eq!(right_null.get(1), Some(false));
Ok(())
}
// String column vs. Int32 column: exactly one of the three rows is expected
// to survive the equality filter after coercion.
#[test]
fn string_numeric_uses_try_to_number() -> Result<(), PolarsError> {
let df = df!(
"s" => &["123", " 45.5 ", "abc"],
"n" => &[123i32, 46, 0]
)?;
let s_expr = col("s");
let n_expr = col("n");
let (s_coerced, n_coerced) = coerce_for_pyspark_comparison(
s_expr.clone(),
n_expr.clone(),
&DataType::String,
&DataType::Int32,
&CompareOp::Eq,
)?;
let out = df.lazy().filter(s_coerced.eq(n_coerced)).collect()?;
assert_eq!(out.height(), 1);
Ok(())
}
// Datetime column vs. Date column: 2024-01-14 23:00:00 must compare as less
// than 2024-01-15 once the date side is coerced.
#[test]
fn date_datetime_comparison_coerces_date_to_datetime() -> Result<(), PolarsError> {
use chrono::{NaiveDate, NaiveDateTime};
use polars::prelude::*;
let ts = NaiveDateTime::parse_from_str("2024-01-14 23:00:00", "%Y-%m-%d %H:%M:%S").unwrap();
let dt = NaiveDate::from_ymd_opt(2024, 1, 15).unwrap();
let df = df!(
"ts_col" => [ts],
"date_col" => [dt]
)?;
// Pin the column dtypes to the ones passed to the coercion call below.
let df = df
.lazy()
.with_columns([
col("ts_col").cast(DataType::Datetime(TimeUnit::Microseconds, None)),
col("date_col").cast(DataType::Date),
])
.collect()?;
let lf = df.lazy();
let ts_expr = col("ts_col");
let date_expr = col("date_col");
let (ts_c, date_c) = coerce_for_pyspark_comparison(
ts_expr,
date_expr,
&DataType::Datetime(TimeUnit::Microseconds, None),
&DataType::Date,
&CompareOp::Lt,
)?;
let out = lf.filter(ts_c.lt(date_c)).collect()?;
assert_eq!(
out.height(),
1,
"#615: datetime < date should return one row"
);
Ok(())
}
// Nullable Int64 column vs. an i64 literal via `Column::eq_null_safe`:
// expected result is [true, false, false], i.e. the null row yields false.
// NOTE(review): the test name mentions a string literal, but the literal is
// built with `lit_i64(123)` — confirm which case was intended.
#[test]
fn eq_null_safe_int_column_vs_string_literal_coerces() -> Result<(), PolarsError> {
use crate::column::Column;
use polars::prelude::df;
let df = df!(
"val" => &[Some(123i64), Some(456i64), None],
)?;
let lf = df.lazy();
// The local `col` shadows the `col` function from here on.
let col = Column::from_expr(col("val"), None);
let lit_col = crate::functions::lit_i64(123);
let result_col = col.eq_null_safe(&lit_col);
let out = lf
.with_column(result_col.into_expr().alias("match"))
.collect()?;
let matches = out.column("match")?.bool()?;
let got: Vec<Option<bool>> = matches.into_iter().collect();
assert_eq!(got, vec![Some(true), Some(false), Some(false)]);
Ok(())
}
}