Skip to main content

kyu_expression/
coercion.rs

1//! Type coercion — implicit cast insertion for expression binding.
2//!
3//! Wraps `kyu_types::type_utils` functions and inserts `BoundExpression::Cast`
4//! nodes when implicit casts are needed.
5
6use kyu_common::{KyuError, KyuResult};
7use kyu_types::LogicalType;
8use kyu_types::type_utils::{are_comparable, arithmetic_result_type, implicit_cast_cost};
9
10use crate::bound_expr::BoundExpression;
11
12/// Try to coerce an expression to the target type.
13///
14/// Returns the original expression if types already match, wraps in a `Cast`
15/// node if an implicit cast is possible, or returns an error.
16pub fn try_coerce(expr: BoundExpression, target: &LogicalType) -> KyuResult<BoundExpression> {
17    let from = expr.result_type();
18    if from == target {
19        return Ok(expr);
20    }
21    // Any can be cast to anything (e.g., NULL literal).
22    if implicit_cast_cost(from, target).is_some() {
23        Ok(BoundExpression::Cast {
24            expr: Box::new(expr),
25            target_type: target.clone(),
26        })
27    } else {
28        Err(KyuError::Binder(format!(
29            "cannot implicitly cast {} to {}",
30            from.type_name(),
31            target.type_name(),
32        )))
33    }
34}
35
36/// Find the common type for a list of types.
37///
38/// All types must be implicitly castable to the result. Used for CASE
39/// branches, UNION columns, and IN list elements.
40pub fn common_type(types: &[LogicalType]) -> KyuResult<LogicalType> {
41    if types.is_empty() {
42        return Ok(LogicalType::Any);
43    }
44
45    let mut result = types[0].clone();
46    for ty in &types[1..] {
47        if *ty == result {
48            continue;
49        }
50        // Try casting left to right, or right to left, pick the cheaper direction.
51        let cost_lr = implicit_cast_cost(&result, ty);
52        let cost_rl = implicit_cast_cost(ty, &result);
53
54        match (cost_lr, cost_rl) {
55            (Some(_), Some(0)) => {
56                // result can be cast to ty, but ty is exactly result — keep result.
57            }
58            (Some(cl), Some(cr)) => {
59                if cl <= cr {
60                    result = ty.clone();
61                }
62                // else keep result
63            }
64            (Some(_), None) => {
65                result = ty.clone();
66            }
67            (None, Some(_)) => {
68                // ty can be cast to result — keep result.
69            }
70            (None, None) => {
71                return Err(KyuError::Binder(format!(
72                    "incompatible types: {} and {}",
73                    result.type_name(),
74                    ty.type_name(),
75                )));
76            }
77        }
78    }
79
80    Ok(result)
81}
82
83/// Determine the result type for a binary arithmetic operation and coerce
84/// both operands if needed. Returns (coerced_left, coerced_right, result_type).
85pub fn coerce_binary_arithmetic(
86    left: BoundExpression,
87    right: BoundExpression,
88) -> KyuResult<(BoundExpression, BoundExpression, LogicalType)> {
89    let lt = left.result_type().clone();
90    let rt = right.result_type().clone();
91
92    let result = arithmetic_result_type(&lt, &rt).ok_or_else(|| {
93        KyuError::Binder(format!(
94            "arithmetic not defined for {} and {}",
95            lt.type_name(),
96            rt.type_name(),
97        ))
98    })?;
99
100    let left = try_coerce(left, &result)?;
101    let right = try_coerce(right, &result)?;
102
103    Ok((left, right, result))
104}
105
106/// Coerce both operands of a comparison to a common comparable type.
107/// Returns (coerced_left, coerced_right).
108pub fn coerce_comparison(
109    left: BoundExpression,
110    right: BoundExpression,
111) -> KyuResult<(BoundExpression, BoundExpression)> {
112    let lt = left.result_type().clone();
113    let rt = right.result_type().clone();
114
115    if !are_comparable(&lt, &rt) {
116        return Err(KyuError::Binder(format!(
117            "cannot compare {} and {}",
118            lt.type_name(),
119            rt.type_name(),
120        )));
121    }
122
123    // Find common type and coerce both sides.
124    let target = common_type(&[lt, rt])?;
125    let left = try_coerce(left, &target)?;
126    let right = try_coerce(right, &target)?;
127
128    Ok((left, right))
129}
130
131/// Coerce both sides of a string concatenation to String.
132pub fn coerce_concat(
133    left: BoundExpression,
134    right: BoundExpression,
135) -> KyuResult<(BoundExpression, BoundExpression)> {
136    let left = try_coerce(left, &LogicalType::String)?;
137    let right = try_coerce(right, &LogicalType::String)?;
138    Ok((left, right))
139}
140
#[cfg(test)]
mod tests {
    use super::*;
    use kyu_types::TypedValue;
    use smol_str::SmolStr;

    /// Build a literal expression with an explicit result type.
    fn lit(value: TypedValue, result_type: LogicalType) -> BoundExpression {
        BoundExpression::Literal { value, result_type }
    }

    fn lit_int32(v: i32) -> BoundExpression {
        lit(TypedValue::Int32(v), LogicalType::Int32)
    }

    fn lit_int64(v: i64) -> BoundExpression {
        lit(TypedValue::Int64(v), LogicalType::Int64)
    }

    fn lit_double(v: f64) -> BoundExpression {
        lit(TypedValue::Double(v), LogicalType::Double)
    }

    fn lit_str(s: &str) -> BoundExpression {
        lit(TypedValue::String(SmolStr::new(s)), LogicalType::String)
    }

    fn lit_bool(v: bool) -> BoundExpression {
        lit(TypedValue::Bool(v), LogicalType::Bool)
    }

    /// A NULL literal, typed as `Any`.
    fn lit_null() -> BoundExpression {
        lit(TypedValue::Null, LogicalType::Any)
    }

    #[test]
    fn coerce_same_type_noop() {
        let expr = lit_int64(42);
        let result = try_coerce(expr, &LogicalType::Int64).unwrap();
        assert!(matches!(result, BoundExpression::Literal { .. }));
    }

    #[test]
    fn coerce_int32_to_int64() {
        let expr = lit_int32(42);
        let result = try_coerce(expr, &LogicalType::Int64).unwrap();
        assert!(matches!(
            result,
            BoundExpression::Cast {
                target_type: LogicalType::Int64,
                ..
            }
        ));
    }

    #[test]
    fn coerce_int_to_double() {
        let expr = lit_int64(42);
        let result = try_coerce(expr, &LogicalType::Double).unwrap();
        assert!(matches!(
            result,
            BoundExpression::Cast {
                target_type: LogicalType::Double,
                ..
            }
        ));
    }

    #[test]
    fn coerce_incompatible_error() {
        let expr = lit_bool(true);
        let result = try_coerce(expr, &LogicalType::Int64);
        assert!(result.is_err());
    }

    #[test]
    fn coerce_incompatible_error_is_binder() {
        // The error must surface as a binder error with the documented prefix.
        match try_coerce(lit_bool(true), &LogicalType::Int64) {
            Err(KyuError::Binder(msg)) => assert!(msg.starts_with("cannot implicitly cast")),
            _ => panic!("expected a binder error for Bool -> Int64"),
        }
    }

    #[test]
    fn coerce_null_to_any_type() {
        let expr = lit_null();
        let result = try_coerce(expr, &LogicalType::Int64).unwrap();
        assert!(matches!(
            result,
            BoundExpression::Cast {
                target_type: LogicalType::Int64,
                ..
            }
        ));
    }

    #[test]
    fn common_type_empty_is_any() {
        assert_eq!(common_type(&[]).unwrap(), LogicalType::Any);
    }

    #[test]
    fn common_type_single_type() {
        assert_eq!(
            common_type(&[LogicalType::Double]).unwrap(),
            LogicalType::Double
        );
    }

    #[test]
    fn common_type_same() {
        let result = common_type(&[LogicalType::Int64, LogicalType::Int64]).unwrap();
        assert_eq!(result, LogicalType::Int64);
    }

    #[test]
    fn common_type_widening() {
        let result = common_type(&[LogicalType::Int32, LogicalType::Int64]).unwrap();
        assert_eq!(result, LogicalType::Int64);
    }

    #[test]
    fn common_type_int_float() {
        let result =
            common_type(&[LogicalType::Int32, LogicalType::Int64, LogicalType::Double]).unwrap();
        assert_eq!(result, LogicalType::Double);
    }

    #[test]
    fn common_type_incompatible() {
        let result = common_type(&[LogicalType::Bool, LogicalType::Int64]);
        assert!(result.is_err());
    }

    #[test]
    fn binary_arithmetic_coercion() {
        let (l, r, rt) = coerce_binary_arithmetic(lit_int32(1), lit_int64(2)).unwrap();
        assert_eq!(rt, LogicalType::Int64);
        assert!(matches!(l, BoundExpression::Cast { .. }));
        assert!(matches!(r, BoundExpression::Literal { .. }));
    }

    #[test]
    fn comparison_coercion() {
        let (l, r) = coerce_comparison(lit_int32(1), lit_double(2.0)).unwrap();
        assert_eq!(l.result_type(), &LogicalType::Double);
        assert_eq!(r.result_type(), &LogicalType::Double);
    }

    #[test]
    fn comparison_incompatible_error() {
        // Fails either in `are_comparable` or in `common_type` (Bool/Int64 has
        // no common type, see `common_type_incompatible`).
        assert!(coerce_comparison(lit_bool(true), lit_int64(1)).is_err());
    }

    #[test]
    fn concat_coercion() {
        let (l, r) = coerce_concat(lit_str("a"), lit_int64(1)).unwrap();
        assert_eq!(l.result_type(), &LogicalType::String);
        assert_eq!(r.result_type(), &LogicalType::String);
    }

    #[test]
    fn concat_strings_noop() {
        // Operands that are already strings must not be wrapped in casts.
        let (l, r) = coerce_concat(lit_str("a"), lit_str("b")).unwrap();
        assert!(matches!(l, BoundExpression::Literal { .. }));
        assert!(matches!(r, BoundExpression::Literal { .. }));
    }
}