1use kyu_common::{KyuError, KyuResult};
7use kyu_types::LogicalType;
8use kyu_types::type_utils::{are_comparable, arithmetic_result_type, implicit_cast_cost};
9
10use crate::bound_expr::BoundExpression;
11
12pub fn try_coerce(expr: BoundExpression, target: &LogicalType) -> KyuResult<BoundExpression> {
17 let from = expr.result_type();
18 if from == target {
19 return Ok(expr);
20 }
21 if implicit_cast_cost(from, target).is_some() {
23 Ok(BoundExpression::Cast {
24 expr: Box::new(expr),
25 target_type: target.clone(),
26 })
27 } else {
28 Err(KyuError::Binder(format!(
29 "cannot implicitly cast {} to {}",
30 from.type_name(),
31 target.type_name(),
32 )))
33 }
34}
35
36pub fn common_type(types: &[LogicalType]) -> KyuResult<LogicalType> {
41 if types.is_empty() {
42 return Ok(LogicalType::Any);
43 }
44
45 let mut result = types[0].clone();
46 for ty in &types[1..] {
47 if *ty == result {
48 continue;
49 }
50 let cost_lr = implicit_cast_cost(&result, ty);
52 let cost_rl = implicit_cast_cost(ty, &result);
53
54 match (cost_lr, cost_rl) {
55 (Some(_), Some(0)) => {
56 }
58 (Some(cl), Some(cr)) => {
59 if cl <= cr {
60 result = ty.clone();
61 }
62 }
64 (Some(_), None) => {
65 result = ty.clone();
66 }
67 (None, Some(_)) => {
68 }
70 (None, None) => {
71 return Err(KyuError::Binder(format!(
72 "incompatible types: {} and {}",
73 result.type_name(),
74 ty.type_name(),
75 )));
76 }
77 }
78 }
79
80 Ok(result)
81}
82
83pub fn coerce_binary_arithmetic(
86 left: BoundExpression,
87 right: BoundExpression,
88) -> KyuResult<(BoundExpression, BoundExpression, LogicalType)> {
89 let lt = left.result_type().clone();
90 let rt = right.result_type().clone();
91
92 let result = arithmetic_result_type(<, &rt).ok_or_else(|| {
93 KyuError::Binder(format!(
94 "arithmetic not defined for {} and {}",
95 lt.type_name(),
96 rt.type_name(),
97 ))
98 })?;
99
100 let left = try_coerce(left, &result)?;
101 let right = try_coerce(right, &result)?;
102
103 Ok((left, right, result))
104}
105
106pub fn coerce_comparison(
109 left: BoundExpression,
110 right: BoundExpression,
111) -> KyuResult<(BoundExpression, BoundExpression)> {
112 let lt = left.result_type().clone();
113 let rt = right.result_type().clone();
114
115 if !are_comparable(<, &rt) {
116 return Err(KyuError::Binder(format!(
117 "cannot compare {} and {}",
118 lt.type_name(),
119 rt.type_name(),
120 )));
121 }
122
123 let target = common_type(&[lt, rt])?;
125 let left = try_coerce(left, &target)?;
126 let right = try_coerce(right, &target)?;
127
128 Ok((left, right))
129}
130
131pub fn coerce_concat(
133 left: BoundExpression,
134 right: BoundExpression,
135) -> KyuResult<(BoundExpression, BoundExpression)> {
136 let left = try_coerce(left, &LogicalType::String)?;
137 let right = try_coerce(right, &LogicalType::String)?;
138 Ok((left, right))
139}
140
#[cfg(test)]
mod tests {
    use super::*;
    use kyu_types::TypedValue;
    use smol_str::SmolStr;

    /// Builds a literal expression with an explicit static type.
    fn literal(value: TypedValue, result_type: LogicalType) -> BoundExpression {
        BoundExpression::Literal { value, result_type }
    }

    fn int32(v: i32) -> BoundExpression {
        literal(TypedValue::Int32(v), LogicalType::Int32)
    }

    fn int64(v: i64) -> BoundExpression {
        literal(TypedValue::Int64(v), LogicalType::Int64)
    }

    fn double(v: f64) -> BoundExpression {
        literal(TypedValue::Double(v), LogicalType::Double)
    }

    fn string(s: &str) -> BoundExpression {
        literal(TypedValue::String(SmolStr::new(s)), LogicalType::String)
    }

    fn boolean(v: bool) -> BoundExpression {
        literal(TypedValue::Bool(v), LogicalType::Bool)
    }

    /// A `NULL` literal, statically typed as `Any`.
    fn null() -> BoundExpression {
        literal(TypedValue::Null, LogicalType::Any)
    }

    #[test]
    fn coerce_same_type_noop() {
        let coerced = try_coerce(int64(42), &LogicalType::Int64).unwrap();
        assert!(matches!(coerced, BoundExpression::Literal { .. }));
    }

    #[test]
    fn coerce_int32_to_int64() {
        let coerced = try_coerce(int32(42), &LogicalType::Int64).unwrap();
        assert!(matches!(
            coerced,
            BoundExpression::Cast {
                target_type: LogicalType::Int64,
                ..
            }
        ));
    }

    #[test]
    fn coerce_int_to_double() {
        let coerced = try_coerce(int64(42), &LogicalType::Double).unwrap();
        assert!(matches!(
            coerced,
            BoundExpression::Cast {
                target_type: LogicalType::Double,
                ..
            }
        ));
    }

    #[test]
    fn coerce_incompatible_error() {
        assert!(try_coerce(boolean(true), &LogicalType::Int64).is_err());
    }

    #[test]
    fn coerce_null_to_any_type() {
        let coerced = try_coerce(null(), &LogicalType::Int64).unwrap();
        assert!(matches!(
            coerced,
            BoundExpression::Cast {
                target_type: LogicalType::Int64,
                ..
            }
        ));
    }

    #[test]
    fn common_type_same() {
        let types = [LogicalType::Int64, LogicalType::Int64];
        assert_eq!(common_type(&types).unwrap(), LogicalType::Int64);
    }

    #[test]
    fn common_type_widening() {
        let types = [LogicalType::Int32, LogicalType::Int64];
        assert_eq!(common_type(&types).unwrap(), LogicalType::Int64);
    }

    #[test]
    fn common_type_int_float() {
        let types = [LogicalType::Int32, LogicalType::Int64, LogicalType::Double];
        assert_eq!(common_type(&types).unwrap(), LogicalType::Double);
    }

    #[test]
    fn common_type_incompatible() {
        let types = [LogicalType::Bool, LogicalType::Int64];
        assert!(common_type(&types).is_err());
    }

    #[test]
    fn binary_arithmetic_coercion() {
        let (lhs, rhs, ty) = coerce_binary_arithmetic(int32(1), int64(2)).unwrap();
        assert_eq!(ty, LogicalType::Int64);
        assert!(matches!(lhs, BoundExpression::Cast { .. }));
        assert!(matches!(rhs, BoundExpression::Literal { .. }));
    }

    #[test]
    fn comparison_coercion() {
        let (lhs, rhs) = coerce_comparison(int32(1), double(2.0)).unwrap();
        assert_eq!(lhs.result_type(), &LogicalType::Double);
        assert_eq!(rhs.result_type(), &LogicalType::Double);
    }

    #[test]
    fn concat_coercion() {
        let (lhs, rhs) = coerce_concat(string("a"), int64(1)).unwrap();
        assert_eq!(lhs.result_type(), &LogicalType::String);
        assert_eq!(rhs.result_type(), &LogicalType::String);
    }
}