robin_sparkless_expr/type_coercion.rs

use polars::prelude::*;

/// Comparison operators of interest for PySpark-style coercion.
///
/// We keep a local alias to avoid leaking polars::prelude in public signatures unnecessarily.
pub type CompareOp = polars::prelude::Operator;

/// Type precedence for ANSI SQL type coercion.
/// Higher-precedence types are coerced to; lower-precedence types are coerced from.
/// The derived Ord follows the explicit discriminants, so String is the widest type here.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
#[allow(dead_code)] // Decimal reserved for future use
enum TypePrecedence {
    Int = 1,
    Long = 2,
    Decimal = 3,
    Float = 4,
    Double = 5,
    String = 6,
}

/// Convert a Polars DataType to its TypePrecedence, if it participates in coercion.
fn dtype_to_precedence(dtype: &DataType) -> Option<TypePrecedence> {
    match dtype {
        DataType::Int32 => Some(TypePrecedence::Int),
        DataType::Int64 => Some(TypePrecedence::Long),
        DataType::Float32 => Some(TypePrecedence::Float),
        DataType::Float64 => Some(TypePrecedence::Double),
        DataType::String => Some(TypePrecedence::String),
        // Decimal: add when Polars exposes Decimal in the public API / dtype set we use
        _ => None,
    }
}

/// Determine the common type for two columns based on PySpark's type precedence rules.
/// Returns the highest-precedence type of the pair, i.e. the widest type both sides can be coerced to.
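///
/// A minimal illustration of the rule (a sketch, not compiled as a doctest):
///
/// ```ignore
/// // Int32 (precedence 1) vs Float64 (precedence 5) -> Float64.
/// assert_eq!(find_common_type(&DataType::Int32, &DataType::Float64)?, DataType::Float64);
/// // String has the highest precedence, so string vs numeric -> String.
/// assert_eq!(find_common_type(&DataType::String, &DataType::Int64)?, DataType::String);
/// ```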
pub fn find_common_type(left: &DataType, right: &DataType) -> Result<DataType, PolarsError> {
    let left_prec = dtype_to_precedence(left);
    let right_prec = dtype_to_precedence(right);

    match (left_prec, right_prec) {
        (Some(l), Some(r)) => {
            // Return the type with higher precedence
            let target_prec = if l > r { l } else { r };
            match target_prec {
                TypePrecedence::Int => Ok(DataType::Int32),
                TypePrecedence::Long => Ok(DataType::Int64),
                TypePrecedence::Float => Ok(DataType::Float32),
                TypePrecedence::Double => Ok(DataType::Float64),
                TypePrecedence::String => Ok(DataType::String),
                _ => Err(PolarsError::ComputeError(
                    format!(
                        "Type coercion: unsupported type precedence {target_prec:?}. Supported: Int32, Int64, Float32, Float64, String."
                    )
                    .into(),
                )),
            }
        }
        _ => {
            // If types don't match known precedence, try to find a common type
            if is_numeric(left) && is_numeric(right) {
                Ok(DataType::Float64)
            } else if left == right {
                Ok(left.clone())
            } else if left == &DataType::String || right == &DataType::String {
                // #613: unionByName string vs numeric -> coerce to String (PySpark parity)
                Ok(DataType::String)
            } else {
                Err(PolarsError::ComputeError(
                    format!(
                        "Type coercion: cannot find common type for {left:?} and {right:?}. Hint: use cast() to align types, or ensure both are numeric or both are string."
                    )
                    .into(),
                ))
            }
        }
    }
}

/// Check if a DataType is numeric
fn is_numeric(dtype: &DataType) -> bool {
    matches!(
        dtype,
        DataType::Int8
            | DataType::Int16
            | DataType::Int32
            | DataType::Int64
            | DataType::UInt8
            | DataType::UInt16
            | DataType::UInt32
            | DataType::UInt64
            | DataType::Float32
            | DataType::Float64
    )
}

/// Check if a DataType is date/datetime (temporal types that we can cast from string via try_cast).
fn is_date_or_datetime(dtype: &DataType) -> bool {
    matches!(dtype, DataType::Date | DataType::Datetime(_, _))
}

/// Coerce a column expression to a target type
pub fn coerce_to_type(expr: Expr, target_type: DataType) -> Expr {
    expr.cast(target_type)
}

/// Coerce two expressions to their common type for comparison
pub fn coerce_for_comparison(
    left: Expr,
    right: Expr,
    left_type: &DataType,
    right_type: &DataType,
) -> Result<(Expr, Expr), PolarsError> {
    if left_type == right_type {
        // Same type, no coercion needed
        return Ok((left, right));
    }

    let common_type = find_common_type(left_type, right_type)?;

    let left_coerced = if left_type == &common_type {
        left
    } else {
        coerce_to_type(left, common_type.clone())
    };

    let right_coerced = if right_type == &common_type {
        right
    } else {
        coerce_to_type(right, common_type)
    };

    Ok((left_coerced, right_coerced))
}

/// Coerce two expressions for PySpark-style comparison semantics.
///
/// This extends [`coerce_for_comparison`] to handle string–numeric combinations by
/// parsing string values to numbers (double) instead of erroring, mirroring PySpark:
///
/// - String values that parse as numbers (e.g. "123", " 45.6 ") are compared numerically.
/// - Non‑numeric strings behave as null under numeric comparison (non-matching in filters).
///
/// For plain numeric–numeric inputs, it delegates to [`coerce_for_comparison`].
pub fn coerce_for_pyspark_comparison(
    left: Expr,
    right: Expr,
    left_type: &DataType,
    right_type: &DataType,
    _op: &CompareOp,
) -> Result<(Expr, Expr), PolarsError> {
    use crate::column::Column;

    // Fast-path: both numeric -> existing numeric coercion.
    if is_numeric(left_type) && is_numeric(right_type) {
        return coerce_for_comparison(left, right, left_type, right_type);
    }

    // Helper to wrap an Expr in try_to_number (double) semantics when it represents a value
    // that should be interpreted as numeric if possible.
    fn wrap_try_to_number(expr: Expr) -> Result<Expr, PolarsError> {
        let col = Column::from_expr(expr, None);
        let coerced = crate::functions::try_to_number(&col, None)
            .map_err(|e| PolarsError::ComputeError(e.into()))?;
        Ok(coerced.into_expr())
    }

    // String–numeric (or numeric–string): route string side through try_to_number and
    // cast numeric side to Float64 so both sides line up.
    let string_numeric = (left_type == &DataType::String && is_numeric(right_type))
        || (right_type == &DataType::String && is_numeric(left_type));

    if string_numeric {
        let left_out = if left_type == &DataType::String {
            wrap_try_to_number(left)?
        } else if is_numeric(left_type) {
            coerce_to_type(left, DataType::Float64)
        } else {
            left
        };

        let right_out = if right_type == &DataType::String {
            wrap_try_to_number(right)?
        } else if is_numeric(right_type) {
            coerce_to_type(right, DataType::Float64)
        } else {
            right
        };

        return Ok((left_out, right_out));
    }

    // Date/datetime vs string: cast string side to the temporal type (PySpark implicit cast).
    fn wrap_try_to_temporal(expr: Expr, target: &DataType) -> Result<Expr, PolarsError> {
        let col = Column::from_expr(expr, None);
        let type_name = match target {
            DataType::Date => "date",
            DataType::Datetime(..) => "timestamp",
            _ => {
                return Err(PolarsError::ComputeError(
                    "date or datetime type required".to_string().into(),
                ));
            }
        };
        let coerced = crate::functions::try_cast(&col, type_name)
            .map_err(|e| PolarsError::ComputeError(e.into()))?;
        Ok(coerced.into_expr())
    }

    let temporal_string = (is_date_or_datetime(left_type) && right_type == &DataType::String)
        || (left_type == &DataType::String && is_date_or_datetime(right_type));

    if temporal_string {
        let left_out = if left_type == &DataType::String {
            wrap_try_to_temporal(left, right_type)?
        } else {
            left
        };
        let right_out = if right_type == &DataType::String {
            wrap_try_to_temporal(right, left_type)?
        } else {
            right
        };
        return Ok((left_out, right_out));
    }

    // #615: Date vs datetime comparison (e.g. datetime_col < date_col): cast Date to Datetime
    // so both sides are comparable (PySpark treats date as start-of-day timestamp).
    let date_vs_datetime = (left_type == &DataType::Date
        && matches!(right_type, DataType::Datetime(_, _)))
        || (matches!(left_type, DataType::Datetime(_, _)) && right_type == &DataType::Date);
    if date_vs_datetime {
        let target_dt = if matches!(left_type, DataType::Datetime(_, _)) {
            left_type.clone()
        } else {
            right_type.clone()
        };
        let left_out = if left_type == &DataType::Date {
            coerce_to_type(left, target_dt.clone())
        } else {
            left
        };
        let right_out = if right_type == &DataType::Date {
            coerce_to_type(right, target_dt)
        } else {
            right
        };
        return Ok((left_out, right_out));
    }

    // Equal non-numeric types: leave as-is for now.
    if left_type == right_type && !is_numeric(left_type) {
        return Ok((left, right));
    }

    // Fallback to generic comparison coercion (may error with a clear message).
    coerce_for_comparison(left, right, left_type, right_type)
}

/// Infer DataType from an expression when it is a literal (for coercion heuristics).
pub fn infer_type_from_expr(expr: &Expr) -> Option<DataType> {
    match expr {
        Expr::Literal(lv) => {
            let dt = lv.get_datatype();
            // Unknown literal dtypes (e.g. unresolved numeric literals) default to
            // Float64 so comparisons can proceed numerically.
            Some(if matches!(dt, DataType::Unknown(_)) {
                DataType::Float64
            } else {
                dt
            })
        }
        _ => None,
    }
}

/// Coerce left/right for eq_null_safe so string–numeric comparisons behave like PySpark
/// (try_to_number on the string side).
/// Types are inferred from literals; non-literal expressions are assumed to be String,
/// so string–numeric pairs still get coerced.
pub fn coerce_for_pyspark_eq_null_safe(
    left: Expr,
    right: Expr,
) -> Result<(Expr, Expr), PolarsError> {
    let left_ty = infer_type_from_expr(&left).unwrap_or(DataType::String);
    let right_ty = infer_type_from_expr(&right).unwrap_or(DataType::String);
    coerce_for_pyspark_comparison(left, right, &left_ty, &right_ty, &CompareOp::Eq)
}

#[cfg(test)]
mod tests {
    use super::*;
    use polars::prelude::{IntoLazy, df};

    #[test]
    fn numeric_numeric_uses_standard_coercion() -> Result<(), PolarsError> {
        let df = df!(
            "a" => &[1i32, 2, 3],
            "b" => &[1i64, 2, 3]
        )?;

        let (ac, bc) = coerce_for_pyspark_comparison(
            col("a"),
            col("b"),
            &DataType::Int32,
            &DataType::Int64,
            &CompareOp::Eq,
        )?;

        // After coercion both sides should be comparable without error and all rows match.
        let out = df.lazy().filter(ac.eq(bc)).collect()?;
        assert_eq!(out.height(), 3);
        Ok(())
    }
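
    /// A minimal sketch of the precedence table: the wider side wins, and String
    /// is the widest; identical types outside the table pass through unchanged.
    #[test]
    fn find_common_type_follows_precedence() -> Result<(), PolarsError> {
        // Higher precedence wins within the known numeric/string set.
        assert_eq!(find_common_type(&DataType::Int32, &DataType::Int64)?, DataType::Int64);
        assert_eq!(find_common_type(&DataType::Int64, &DataType::Float32)?, DataType::Float32);
        assert_eq!(find_common_type(&DataType::Float64, &DataType::String)?, DataType::String);
        // No precedence entry, but equal types: returned as-is.
        assert_eq!(find_common_type(&DataType::Boolean, &DataType::Boolean)?, DataType::Boolean);
        Ok(())
    }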

    #[test]
    fn string_numeric_uses_try_to_number() -> Result<(), PolarsError> {
        let df = df!(
            "s" => &["123", " 45.5 ", "abc"],
            "n" => &[123i32, 46, 0]
        )?;

        let (s_coerced, n_coerced) = coerce_for_pyspark_comparison(
            col("s"),
            col("n"),
            &DataType::String,
            &DataType::Int32,
            &CompareOp::Eq,
        )?;

        let out = df.lazy().filter(s_coerced.eq(n_coerced)).collect()?;

        // Only the first row matches ("123" == 123); " 45.5 " != 46, "abc" -> null (non-match).
        assert_eq!(out.height(), 1);
        Ok(())
    }
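
    /// A minimal sketch of the error path: non-numeric, non-string, unequal types
    /// have no common type and should surface the ComputeError from find_common_type.
    #[test]
    fn incompatible_types_report_clear_error() {
        let err = find_common_type(&DataType::Boolean, &DataType::Date).unwrap_err();
        // Assumes the error's Display output includes the ComputeError message text.
        assert!(err.to_string().contains("cannot find common type"));
    }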

    /// #615: datetime_col < date_col must return rows (PySpark: date as start-of-day for comparison).
    #[test]
    fn date_datetime_comparison_coerces_date_to_datetime() -> Result<(), PolarsError> {
        use chrono::{NaiveDate, NaiveDateTime};
        use polars::prelude::*;

        let ts = NaiveDateTime::parse_from_str("2024-01-14 23:00:00", "%Y-%m-%d %H:%M:%S").unwrap();
        let dt = NaiveDate::from_ymd_opt(2024, 1, 15).unwrap();
        let df = df!(
            "ts_col" => [ts],
            "date_col" => [dt]
        )?;
        let df = df
            .lazy()
            .with_columns([
                col("ts_col").cast(DataType::Datetime(TimeUnit::Microseconds, None)),
                col("date_col").cast(DataType::Date),
            ])
            .collect()?;
        let lf = df.lazy();

        let ts_expr = col("ts_col");
        let date_expr = col("date_col");
        let (ts_c, date_c) = coerce_for_pyspark_comparison(
            ts_expr,
            date_expr,
            &DataType::Datetime(TimeUnit::Microseconds, None),
            &DataType::Date,
            &CompareOp::Lt,
        )?;

        let out = lf.filter(ts_c.lt(date_c)).collect()?;
        assert_eq!(
            out.height(),
            1,
            "#615: datetime < date should return one row"
        );
        Ok(())
    }
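
    /// A minimal sketch: type inference only fires for literals; non-literal
    /// expressions yield None (callers then assume String).
    #[test]
    fn infer_type_only_from_literals() {
        assert_eq!(infer_type_from_expr(&lit(1i64)), Some(DataType::Int64));
        assert_eq!(infer_type_from_expr(&col("x")), None);
    }

    /// A minimal sketch of eq_null_safe coercion, assuming try_to_number parses
    /// numeric strings and yields null for non-numeric ones (as documented above).
    #[test]
    fn eq_null_safe_coerces_string_vs_numeric_literal() -> Result<(), PolarsError> {
        let df = df!("s" => &["123", "abc"])?;
        // col("s") is non-literal -> assumed String; lit(123i32) -> Int32.
        let (l, r) = coerce_for_pyspark_eq_null_safe(col("s"), lit(123i32))?;
        let out = df.lazy().filter(l.eq(r)).collect()?;
        assert_eq!(out.height(), 1, "only the parseable string should match");
        Ok(())
    }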
}