1use polars::prelude::*;
2
3pub type CompareOp = polars::prelude::Operator;
7
/// Ranking of the data types this module can pick as a common coercion target.
///
/// The derived `Ord` follows declaration order, which agrees with the explicit
/// discriminants, so the "greater" of two variants is the wider type.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
// `Decimal` is declared but never constructed in this file, hence the allow.
#[allow(dead_code)]
enum TypePrecedence {
    /// Rank of `DataType::Int32`.
    Int = 1,
    /// Rank of `DataType::Int64`.
    Long = 2,
    /// Reserved rank; no polars dtype is mapped to it in this file.
    Decimal = 3,
    /// Rank of `DataType::Float32`.
    Float = 4,
    /// Rank of `DataType::Float64`.
    Double = 5,
    /// Rank of `DataType::String`; outranks every numeric variant.
    String = 6,
}
20
21fn dtype_to_precedence(dtype: &DataType) -> Option<TypePrecedence> {
23 match dtype {
24 DataType::Int32 => Some(TypePrecedence::Int),
25 DataType::Int64 => Some(TypePrecedence::Long),
26 DataType::Float32 => Some(TypePrecedence::Float),
27 DataType::Float64 => Some(TypePrecedence::Double),
28 DataType::String => Some(TypePrecedence::String),
29 _ => None,
31 }
32}
33
34pub fn find_common_type(left: &DataType, right: &DataType) -> Result<DataType, PolarsError> {
37 let left_prec = dtype_to_precedence(left);
38 let right_prec = dtype_to_precedence(right);
39
40 match (left_prec, right_prec) {
41 (Some(l), Some(r)) => {
42 let target_prec = if l > r { l } else { r };
44 match target_prec {
45 TypePrecedence::Int => Ok(DataType::Int32),
46 TypePrecedence::Long => Ok(DataType::Int64),
47 TypePrecedence::Float => Ok(DataType::Float32),
48 TypePrecedence::Double => Ok(DataType::Float64),
49 TypePrecedence::String => Ok(DataType::String),
50 _ => Err(PolarsError::ComputeError(
51 format!(
52 "Type coercion: unsupported type precedence {target_prec:?}. Supported: Int32, Int64, Float32, Float64, String."
53 )
54 .into(),
55 )),
56 }
57 }
58 _ => {
59 if is_numeric(left) && is_numeric(right) {
61 Ok(DataType::Float64)
62 } else if left == right {
63 Ok(left.clone())
64 } else if left == &DataType::String || right == &DataType::String {
65 Ok(DataType::String)
67 } else {
68 Err(PolarsError::ComputeError(
69 format!(
70 "Type coercion: cannot find common type for {left:?} and {right:?}. Hint: use cast() to align types, or ensure both are numeric or both are string."
71 )
72 .into(),
73 ))
74 }
75 }
76 }
77}
78
79fn is_numeric(dtype: &DataType) -> bool {
81 matches!(
82 dtype,
83 DataType::Int8
84 | DataType::Int16
85 | DataType::Int32
86 | DataType::Int64
87 | DataType::UInt8
88 | DataType::UInt16
89 | DataType::UInt32
90 | DataType::UInt64
91 | DataType::Float32
92 | DataType::Float64
93 )
94}
95
96fn is_date_or_datetime(dtype: &DataType) -> bool {
98 matches!(dtype, DataType::Date | DataType::Datetime(_, _))
99}
100
101pub fn coerce_to_type(expr: Expr, target_type: DataType) -> Expr {
103 expr.cast(target_type)
104}
105
106pub fn coerce_for_comparison(
108 left: Expr,
109 right: Expr,
110 left_type: &DataType,
111 right_type: &DataType,
112) -> Result<(Expr, Expr), PolarsError> {
113 if left_type == right_type {
114 return Ok((left, right));
116 }
117
118 let common_type = find_common_type(left_type, right_type)?;
119
120 let left_coerced = if left_type == &common_type {
121 left
122 } else {
123 coerce_to_type(left, common_type.clone())
124 };
125
126 let right_coerced = if right_type == &common_type {
127 right
128 } else {
129 coerce_to_type(right, common_type)
130 };
131
132 Ok((left_coerced, right_coerced))
133}
134
135pub fn coerce_for_pyspark_comparison(
145 left: Expr,
146 right: Expr,
147 left_type: &DataType,
148 right_type: &DataType,
149 _op: &CompareOp,
150) -> Result<(Expr, Expr), PolarsError> {
151 use crate::column::Column;
152
153 if is_numeric(left_type) && is_numeric(right_type) {
155 return coerce_for_comparison(left, right, left_type, right_type);
156 }
157
158 fn wrap_try_to_number(expr: Expr) -> Result<Expr, PolarsError> {
161 let col = Column::from_expr(expr, None);
162 let coerced = crate::functions::try_to_number(&col, None)
163 .map_err(|e| PolarsError::ComputeError(e.into()))?;
164 Ok(coerced.into_expr())
165 }
166
167 let string_numeric = (left_type == &DataType::String && is_numeric(right_type))
170 || (right_type == &DataType::String && is_numeric(left_type));
171
172 if string_numeric {
173 let left_out = if left_type == &DataType::String {
174 wrap_try_to_number(left)?
175 } else if is_numeric(left_type) {
176 coerce_to_type(left, DataType::Float64)
177 } else {
178 left
179 };
180
181 let right_out = if right_type == &DataType::String {
182 wrap_try_to_number(right)?
183 } else if is_numeric(right_type) {
184 coerce_to_type(right, DataType::Float64)
185 } else {
186 right
187 };
188
189 return Ok((left_out, right_out));
190 }
191
192 fn wrap_try_to_temporal(expr: Expr, target: &DataType) -> Result<Expr, PolarsError> {
194 let col = Column::from_expr(expr, None);
195 let type_name = match target {
196 DataType::Date => "date",
197 DataType::Datetime(..) => "timestamp",
198 _ => {
199 return Err(PolarsError::ComputeError(
200 "date or datetime type required".to_string().into(),
201 ));
202 }
203 };
204 let coerced = crate::functions::try_cast(&col, type_name)
205 .map_err(|e| PolarsError::ComputeError(e.into()))?;
206 Ok(coerced.into_expr())
207 }
208
209 let temporal_string = (is_date_or_datetime(left_type) && right_type == &DataType::String)
210 || (left_type == &DataType::String && is_date_or_datetime(right_type));
211
212 if temporal_string {
213 let left_out = if left_type == &DataType::String {
214 wrap_try_to_temporal(left, right_type)?
215 } else {
216 left
217 };
218 let right_out = if right_type == &DataType::String {
219 wrap_try_to_temporal(right, left_type)?
220 } else {
221 right
222 };
223 return Ok((left_out, right_out));
224 }
225
226 let date_vs_datetime = (left_type == &DataType::Date
229 && matches!(right_type, DataType::Datetime(_, _)))
230 || (matches!(left_type, DataType::Datetime(_, _)) && right_type == &DataType::Date);
231 if date_vs_datetime {
232 let target_dt = if matches!(left_type, DataType::Datetime(_, _)) {
233 left_type.clone()
234 } else {
235 right_type.clone()
236 };
237 let left_out = if left_type == &DataType::Date {
238 coerce_to_type(left, target_dt.clone())
239 } else {
240 left
241 };
242 let right_out = if right_type == &DataType::Date {
243 coerce_to_type(right, target_dt)
244 } else {
245 right
246 };
247 return Ok((left_out, right_out));
248 }
249
250 if left_type == right_type && !is_numeric(left_type) {
252 return Ok((left, right));
253 }
254
255 coerce_for_comparison(left, right, left_type, right_type)
257}
258
259pub fn infer_type_from_expr(expr: &Expr) -> Option<DataType> {
261 match expr {
262 Expr::Literal(lv) => {
263 let dt = lv.get_datatype();
264 Some(if matches!(dt, DataType::Unknown(_)) {
265 DataType::Float64
266 } else {
267 dt
268 })
269 }
270 _ => None,
271 }
272}
273
274pub fn coerce_for_pyspark_eq_null_safe(
277 left: Expr,
278 right: Expr,
279) -> Result<(Expr, Expr), PolarsError> {
280 let left_ty = infer_type_from_expr(&left).unwrap_or(DataType::String);
281 let right_ty = infer_type_from_expr(&right).unwrap_or(DataType::String);
282 coerce_for_pyspark_comparison(left, right, &left_ty, &right_ty, &CompareOp::Eq)
283}
284
#[cfg(test)]
mod tests {
    use super::*;
    use polars::prelude::{IntoLazy, df};

    /// Int32 vs Int64 columns holding equal values must all match after coercion.
    #[test]
    fn numeric_numeric_uses_standard_coercion() -> Result<(), PolarsError> {
        let frame = df!(
            "a" => &[1i32, 2, 3],
            "b" => &[1i64, 2, 3]
        )?;

        let (lhs, rhs) = coerce_for_pyspark_comparison(
            col("a"),
            col("b"),
            &DataType::Int32,
            &DataType::Int64,
            &CompareOp::Eq,
        )?;

        let matched = frame.lazy().filter(lhs.eq(rhs)).collect()?;
        assert_eq!(matched.height(), 3);
        Ok(())
    }

    /// String vs Int32: only the row whose string parses to the same number survives.
    #[test]
    fn string_numeric_uses_try_to_number() -> Result<(), PolarsError> {
        let frame = df!(
            "s" => &["123", " 45.5 ", "abc"],
            "n" => &[123i32, 46, 0]
        )?;

        let (text_side, number_side) = coerce_for_pyspark_comparison(
            col("s"),
            col("n"),
            &DataType::String,
            &DataType::Int32,
            &CompareOp::Eq,
        )?;

        let matched = frame
            .lazy()
            .filter(text_side.eq(number_side))
            .collect()?;

        assert_eq!(matched.height(), 1);
        Ok(())
    }

    /// Datetime vs Date: the date side is lifted to datetime so `<` compares instants.
    #[test]
    fn date_datetime_comparison_coerces_date_to_datetime() -> Result<(), PolarsError> {
        use chrono::{NaiveDate, NaiveDateTime};
        use polars::prelude::*;

        let late_evening =
            NaiveDateTime::parse_from_str("2024-01-14 23:00:00", "%Y-%m-%d %H:%M:%S").unwrap();
        let next_day = NaiveDate::from_ymd_opt(2024, 1, 15).unwrap();
        let raw = df!(
            "ts_col" => [late_evening],
            "date_col" => [next_day]
        )?;

        // Pin the column dtypes to exactly what is declared to the coercion helper.
        let typed = raw
            .lazy()
            .with_columns([
                col("ts_col").cast(DataType::Datetime(TimeUnit::Microseconds, None)),
                col("date_col").cast(DataType::Date),
            ])
            .collect()?;

        let (ts_side, date_side) = coerce_for_pyspark_comparison(
            col("ts_col"),
            col("date_col"),
            &DataType::Datetime(TimeUnit::Microseconds, None),
            &DataType::Date,
            &CompareOp::Lt,
        )?;

        let matched = typed.lazy().filter(ts_side.lt(date_side)).collect()?;
        assert_eq!(
            matched.height(),
            1,
            "#615: datetime < date should return one row"
        );
        Ok(())
    }
}