use polars::prelude::*;
/// Comparison operator accepted by [`coerce_for_pyspark_comparison`]; an alias
/// for Polars' `Operator`.
pub type CompareOp = polars::prelude::Operator;
/// Ranking used by [`find_common_type`] to pick the wider of two dtypes.
///
/// Variants are declared in ascending discriminant order and the derived
/// `PartialOrd`/`Ord` compare them in that order, so a higher-ranked variant
/// wins when two dtypes are merged.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
enum TypePrecedence {
    Int = 1,
    Long = 2,
    // Reserved slot between the integer and float ranks. No dtype is mapped
    // to it yet, so the variant is never constructed; the allow is scoped to
    // this variant only so that any other dead variant is still reported.
    #[allow(dead_code)]
    Decimal = 3,
    Float = 4,
    Double = 5,
    String = 6,
}
/// Maps a Polars dtype onto its slot in the [`TypePrecedence`] ranking.
///
/// Only Int32, Int64, Float32, Float64 and String are ranked; every other
/// dtype yields `None`.
fn dtype_to_precedence(dtype: &DataType) -> Option<TypePrecedence> {
    let rank = match dtype {
        DataType::Int32 => TypePrecedence::Int,
        DataType::Int64 => TypePrecedence::Long,
        DataType::Float32 => TypePrecedence::Float,
        DataType::Float64 => TypePrecedence::Double,
        DataType::String => TypePrecedence::String,
        _ => return None,
    };
    Some(rank)
}
/// Finds the dtype both sides of a binary operation should be cast to.
///
/// Resolution order:
/// 1. Identical dtypes are returned unchanged.
/// 2. When both dtypes are ranked (Int32, Int64, Float32, Float64, String),
///    the higher-ranked one wins, e.g. Int32 vs Float64 yields Float64 and
///    any ranked dtype vs String yields String.
/// 3. Any other pair of numeric dtypes falls back to Float64.
///
/// # Errors
///
/// Returns `PolarsError::ComputeError` when the dtypes differ and are neither
/// both ranked nor both numeric (e.g. Boolean vs Int32).
pub fn find_common_type(left: &DataType, right: &DataType) -> Result<DataType, PolarsError> {
    // Identical dtypes never need widening. Checking this first keeps dtypes
    // outside the ranked set (e.g. Int8 vs Int8, UInt32 vs UInt32) as
    // themselves instead of letting the numeric fallback below promote them
    // to Float64.
    if left == right {
        return Ok(left.clone());
    }
    match (dtype_to_precedence(left), dtype_to_precedence(right)) {
        (Some(l), Some(r)) => {
            let target_prec = l.max(r);
            match target_prec {
                TypePrecedence::Int => Ok(DataType::Int32),
                TypePrecedence::Long => Ok(DataType::Int64),
                TypePrecedence::Float => Ok(DataType::Float32),
                TypePrecedence::Double => Ok(DataType::Float64),
                TypePrecedence::String => Ok(DataType::String),
                // No dtype currently maps to Decimal. It is listed explicitly
                // (instead of `_`) so a newly added variant fails to compile
                // here until it is given a target dtype.
                TypePrecedence::Decimal => Err(PolarsError::ComputeError(
                    format!(
                        "Type coercion: unsupported type precedence {target_prec:?}. Supported: Int32, Int64, Float32, Float64, String."
                    )
                    .into(),
                )),
            }
        }
        // Numeric pairs with at least one unranked side widen to Float64.
        _ if is_numeric(left) && is_numeric(right) => Ok(DataType::Float64),
        _ => Err(PolarsError::ComputeError(
            format!(
                "Type coercion: cannot find common type for {left:?} and {right:?}. Hint: use cast() to align types, or ensure both are numeric or both are string."
            )
            .into(),
        )),
    }
}
/// Reports whether `dtype` is a primitive integer or float type.
///
/// Only the signed and unsigned 8- to 64-bit integers plus Float32/Float64
/// count; every other dtype returns `false`.
fn is_numeric(dtype: &DataType) -> bool {
    let signed_int = matches!(
        dtype,
        DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64
    );
    let unsigned_int = matches!(
        dtype,
        DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64
    );
    let float = matches!(dtype, DataType::Float32 | DataType::Float64);
    signed_int || unsigned_int || float
}
pub fn coerce_to_type(expr: Expr, target_type: DataType) -> Expr {
expr.cast(target_type)
}
/// Casts whichever side(s) need it so both expressions share a dtype.
///
/// Equal dtypes short-circuit with no cast. Otherwise the target dtype comes
/// from [`find_common_type`], and only a side whose dtype differs from that
/// target is wrapped in a cast.
///
/// # Errors
///
/// Propagates the error from [`find_common_type`] when no common dtype exists.
pub fn coerce_for_comparison(
    left: Expr,
    right: Expr,
    left_type: &DataType,
    right_type: &DataType,
) -> Result<(Expr, Expr), PolarsError> {
    if left_type == right_type {
        return Ok((left, right));
    }
    let target = find_common_type(left_type, right_type)?;
    // Leave a side untouched when it already has the target dtype.
    let cast_if_needed = |expr: Expr, current: &DataType| -> Expr {
        if *current == target {
            expr
        } else {
            coerce_to_type(expr, target.clone())
        }
    };
    let lhs = cast_if_needed(left, left_type);
    let rhs = cast_if_needed(right, right_type);
    Ok((lhs, rhs))
}
/// Aligns the two sides of a comparison using PySpark-style rules.
///
/// * Both sides numeric: delegates to [`coerce_for_comparison`].
/// * One side `String`, the other numeric: the string side is converted with
///   `crate::functions::try_to_number`, and the numeric side is cast to
///   `Float64`.
/// * Identical (necessarily non-numeric) dtypes: returned unchanged.
/// * Anything else: delegates to [`coerce_for_comparison`].
///
/// `_op` is accepted but not currently used.
///
/// # Errors
///
/// Propagates failures from `try_to_number` (wrapped as
/// `PolarsError::ComputeError`) and from [`coerce_for_comparison`].
pub fn coerce_for_pyspark_comparison(
    left: Expr,
    right: Expr,
    left_type: &DataType,
    right_type: &DataType,
    _op: &CompareOp,
) -> Result<(Expr, Expr), PolarsError> {
    use crate::column::Column;

    /// Runs the crate's `try_to_number` over `expr` and returns the result as an `Expr`.
    fn wrap_try_to_number(expr: Expr) -> Result<Expr, PolarsError> {
        let col = Column::from_expr(expr, None);
        let coerced = crate::functions::try_to_number(&col, None)
            .map_err(|e| PolarsError::ComputeError(e.into()))?;
        Ok(coerced.into_expr())
    }

    let left_is_num = is_numeric(left_type);
    let right_is_num = is_numeric(right_type);
    if left_is_num && right_is_num {
        return coerce_for_comparison(left, right, left_type, right_type);
    }

    let left_is_str = *left_type == DataType::String;
    let right_is_str = *right_type == DataType::String;
    if (left_is_str && right_is_num) || (right_is_str && left_is_num) {
        // Per side: strings are converted to numbers and numerics are widened
        // to Float64. The pass-through arm mirrors the original branch, although
        // a mixed string/numeric pair never reaches it.
        let align = |expr: Expr, is_str: bool, is_num: bool| -> Result<Expr, PolarsError> {
            if is_str {
                wrap_try_to_number(expr)
            } else if is_num {
                Ok(coerce_to_type(expr, DataType::Float64))
            } else {
                Ok(expr)
            }
        };
        let left_out = align(left, left_is_str, left_is_num)?;
        let right_out = align(right, right_is_str, right_is_num)?;
        return Ok((left_out, right_out));
    }

    // The both-numeric case returned above, so equal dtypes reaching this
    // point are non-numeric and compare as-is.
    if left_type == right_type {
        return Ok((left, right));
    }
    coerce_for_comparison(left, right, left_type, right_type)
}
#[cfg(test)]
mod tests {
    use super::*;
    use polars::prelude::{df, IntoLazy};

    /// Int32 vs Int64 takes the plain numeric path, so all three equal rows match.
    #[test]
    fn numeric_numeric_uses_standard_coercion() -> Result<(), PolarsError> {
        let frame = df!(
            "a" => &[1i32, 2, 3],
            "b" => &[1i64, 2, 3]
        )?;
        let (lhs, rhs) = coerce_for_pyspark_comparison(
            col("a"),
            col("b"),
            &DataType::Int32,
            &DataType::Int64,
            &CompareOp::Eq,
        )?;
        let matched = frame.lazy().filter(lhs.eq(rhs)).collect()?;
        assert_eq!(matched.height(), 3);
        Ok(())
    }

    /// String vs Int32 routes the string side through `try_to_number`; only the
    /// "123" / 123 row is expected to survive the equality filter.
    #[test]
    fn string_numeric_uses_try_to_number() -> Result<(), PolarsError> {
        let frame = df!(
            "s" => &["123", " 45.5 ", "abc"],
            "n" => &[123i32, 46, 0]
        )?;
        let (lhs, rhs) = coerce_for_pyspark_comparison(
            col("s"),
            col("n"),
            &DataType::String,
            &DataType::Int32,
            &CompareOp::Eq,
        )?;
        let matched = frame.lazy().filter(lhs.eq(rhs)).collect()?;
        assert_eq!(matched.height(), 1);
        Ok(())
    }
}