1use super::core::{BinaryOperator, Expr, Literal, UnaryOperator};
7use super::projection::ColumnProjection;
8use super::schema::{ColumnMeta, ExprSchema};
9use super::ExprDataType;
10use crate::error::{Error, Result};
11use std::collections::HashMap;
12
13#[derive(Debug, Clone, PartialEq, Eq)]
15pub struct InferredType {
16 pub data_type: ExprDataType,
18 pub nullable: bool,
20}
21
22impl InferredType {
23 pub fn new(data_type: ExprDataType, nullable: bool) -> Self {
25 Self {
26 data_type,
27 nullable,
28 }
29 }
30
31 pub fn boolean(nullable: bool) -> Self {
33 Self::new(ExprDataType::Boolean, nullable)
34 }
35
36 pub fn integer(nullable: bool) -> Self {
38 Self::new(ExprDataType::Integer, nullable)
39 }
40
41 pub fn float(nullable: bool) -> Self {
43 Self::new(ExprDataType::Float, nullable)
44 }
45
46 pub fn string(nullable: bool) -> Self {
48 Self::new(ExprDataType::String, nullable)
49 }
50
51 pub fn date(nullable: bool) -> Self {
53 Self::new(ExprDataType::Date, nullable)
54 }
55
56 pub fn timestamp(nullable: bool) -> Self {
58 Self::new(ExprDataType::Timestamp, nullable)
59 }
60}
61
62pub struct ExprValidator<'a> {
64 schema: &'a ExprSchema,
66 function_types: HashMap<String, (ExprDataType, Vec<ExprDataType>)>,
68}
69
70impl<'a> ExprValidator<'a> {
71 pub fn new(schema: &'a ExprSchema) -> Self {
73 let mut function_types = HashMap::new();
75
76 function_types.insert(
78 "lower".to_string(),
79 (ExprDataType::String, vec![ExprDataType::String]),
80 );
81 function_types.insert(
82 "upper".to_string(),
83 (ExprDataType::String, vec![ExprDataType::String]),
84 );
85 function_types.insert(
86 "concat".to_string(),
87 (
88 ExprDataType::String,
89 vec![ExprDataType::String, ExprDataType::String],
90 ),
91 );
92 function_types.insert(
93 "trim".to_string(),
94 (ExprDataType::String, vec![ExprDataType::String]),
95 );
96
97 function_types.insert(
99 "abs".to_string(),
100 (ExprDataType::Float, vec![ExprDataType::Float]),
101 );
102 function_types.insert(
103 "round".to_string(),
104 (
105 ExprDataType::Float,
106 vec![ExprDataType::Float, ExprDataType::Integer],
107 ),
108 );
109 function_types.insert(
110 "floor".to_string(),
111 (ExprDataType::Float, vec![ExprDataType::Float]),
112 );
113 function_types.insert(
114 "ceiling".to_string(),
115 (ExprDataType::Float, vec![ExprDataType::Float]),
116 );
117
118 function_types.insert(
120 "date_trunc".to_string(),
121 (
122 ExprDataType::Timestamp,
123 vec![ExprDataType::String, ExprDataType::Timestamp],
124 ),
125 );
126 function_types.insert(
127 "extract".to_string(),
128 (
129 ExprDataType::Integer,
130 vec![ExprDataType::String, ExprDataType::Timestamp],
131 ),
132 );
133
134 function_types.insert(
136 "min".to_string(),
137 (ExprDataType::Float, vec![ExprDataType::Float]),
138 );
139 function_types.insert(
140 "max".to_string(),
141 (ExprDataType::Float, vec![ExprDataType::Float]),
142 );
143 function_types.insert(
144 "sum".to_string(),
145 (ExprDataType::Float, vec![ExprDataType::Float]),
146 );
147 function_types.insert(
148 "avg".to_string(),
149 (ExprDataType::Float, vec![ExprDataType::Float]),
150 );
151 function_types.insert(
152 "count".to_string(),
153 (ExprDataType::Integer, vec![ExprDataType::String]),
154 );
155
156 Self {
157 schema,
158 function_types,
159 }
160 }
161
162 pub fn add_udf(
164 &mut self,
165 name: impl Into<String>,
166 return_type: ExprDataType,
167 parameter_types: Vec<ExprDataType>,
168 ) -> &mut Self {
169 self.function_types
170 .insert(name.into(), (return_type, parameter_types));
171 self
172 }
173
174 pub fn validate_expr(&self, expr: &Expr) -> Result<InferredType> {
176 match expr {
177 Expr::Column(name) => {
178 if let Some(col_meta) = self.schema.column(name) {
180 Ok(InferredType::new(
181 col_meta.data_type.clone(),
182 col_meta.nullable,
183 ))
184 } else {
185 Err(Error::InvalidOperation(format!(
186 "Column '{}' not found in schema",
187 name
188 )))
189 }
190 }
191 Expr::Literal(lit) => {
192 match lit {
194 Literal::Null => Err(Error::InvalidOperation(
195 "Cannot infer type for NULL literal without context".to_string(),
196 )),
197 Literal::Boolean(_) => Ok(InferredType::boolean(false)),
198 Literal::Integer(_) => Ok(InferredType::integer(false)),
199 Literal::Float(_) => Ok(InferredType::float(false)),
200 Literal::String(_) => Ok(InferredType::string(false)),
201 }
202 }
203 Expr::BinaryOp { left, op, right } => {
204 let left_type = self.validate_expr(left)?;
206 let right_type = self.validate_expr(right)?;
207
208 let nullable = left_type.nullable || right_type.nullable;
210
211 match op {
213 BinaryOperator::Add
215 | BinaryOperator::Subtract
216 | BinaryOperator::Multiply
217 | BinaryOperator::Divide
218 | BinaryOperator::Modulo => {
219 match (left_type.data_type.clone(), right_type.data_type.clone()) {
221 (ExprDataType::Integer, ExprDataType::Integer) => {
223 if *op == BinaryOperator::Divide {
224 Ok(InferredType::float(nullable))
225 } else {
226 Ok(InferredType::integer(nullable))
227 }
228 }
229 (ExprDataType::Float, _) | (_, ExprDataType::Float) => {
231 Ok(InferredType::float(nullable))
232 }
233 _ => Err(Error::InvalidOperation(format!(
235 "Invalid operand types for arithmetic operation: {:?} {} {:?}",
236 left_type.data_type, op, right_type.data_type
237 ))),
238 }
239 }
240 BinaryOperator::Equal
242 | BinaryOperator::NotEqual
243 | BinaryOperator::LessThan
244 | BinaryOperator::LessThanOrEqual
245 | BinaryOperator::GreaterThan
246 | BinaryOperator::GreaterThanOrEqual => {
247 match (left_type.data_type.clone(), right_type.data_type.clone()) {
250 (a, b) if a == b => Ok(InferredType::boolean(nullable)),
252 (ExprDataType::Integer, ExprDataType::Float)
254 | (ExprDataType::Float, ExprDataType::Integer) => {
255 Ok(InferredType::boolean(nullable))
256 }
257 (ExprDataType::Date, ExprDataType::Timestamp)
259 | (ExprDataType::Timestamp, ExprDataType::Date) => {
260 Ok(InferredType::boolean(nullable))
261 }
262 _ => Err(Error::InvalidOperation(format!(
264 "Invalid operand types for comparison: {:?} {} {:?}",
265 left_type.data_type.clone(),
266 op,
267 right_type.data_type.clone()
268 ))),
269 }
270 }
271 BinaryOperator::And | BinaryOperator::Or => {
273 if left_type.data_type == ExprDataType::Boolean
275 && right_type.data_type == ExprDataType::Boolean
276 {
277 Ok(InferredType::boolean(nullable))
278 } else {
279 Err(Error::InvalidOperation(format!(
280 "Logical operations require boolean operands, got {:?} and {:?}",
281 left_type.data_type, right_type.data_type
282 )))
283 }
284 }
285 BinaryOperator::BitwiseAnd
287 | BinaryOperator::BitwiseOr
288 | BinaryOperator::BitwiseXor => {
289 if left_type.data_type == ExprDataType::Integer
291 && right_type.data_type == ExprDataType::Integer
292 {
293 Ok(InferredType::integer(nullable))
294 } else {
295 Err(Error::InvalidOperation(format!(
296 "Bitwise operations require integer operands, got {:?} and {:?}",
297 left_type.data_type, right_type.data_type
298 )))
299 }
300 }
301 BinaryOperator::Like => {
303 if left_type.data_type == ExprDataType::String
305 && right_type.data_type == ExprDataType::String
306 {
307 Ok(InferredType::boolean(nullable))
308 } else {
309 Err(Error::InvalidOperation(format!(
310 "LIKE operation requires string operands, got {:?} and {:?}",
311 left_type.data_type, right_type.data_type
312 )))
313 }
314 }
315 BinaryOperator::Concat => {
317 if left_type.data_type == ExprDataType::String
319 && right_type.data_type == ExprDataType::String
320 {
321 Ok(InferredType::string(nullable))
322 } else {
323 Err(Error::InvalidOperation(format!(
324 "String concatenation requires string operands, got {:?} and {:?}",
325 left_type.data_type, right_type.data_type
326 )))
327 }
328 }
329 }
330 }
331 Expr::UnaryOp { op, expr } => {
332 let expr_type = self.validate_expr(expr)?;
334
335 match op {
336 UnaryOperator::Negate => {
338 match expr_type.data_type {
340 ExprDataType::Integer => Ok(InferredType::integer(expr_type.nullable)),
341 ExprDataType::Float => Ok(InferredType::float(expr_type.nullable)),
342 _ => Err(Error::InvalidOperation(format!(
343 "Negation requires numeric operand, got {:?}",
344 expr_type.data_type
345 ))),
346 }
347 }
348 UnaryOperator::Not => {
350 if expr_type.data_type == ExprDataType::Boolean {
352 Ok(InferredType::boolean(expr_type.nullable))
353 } else {
354 Err(Error::InvalidOperation(format!(
355 "Logical NOT requires boolean operand, got {:?}",
356 expr_type.data_type
357 )))
358 }
359 }
360 UnaryOperator::IsNull | UnaryOperator::IsNotNull => {
362 Ok(InferredType::boolean(false))
364 }
365 }
366 }
367 Expr::Function { name, args } => {
368 if let Some((return_type, param_types)) = self.function_types.get(name) {
370 if args.len() != param_types.len() {
372 return Err(Error::InvalidOperation(format!(
373 "Function '{}' expects {} arguments, got {}",
374 name,
375 param_types.len(),
376 args.len()
377 )));
378 }
379
380 for (i, (arg, expected_type)) in args.iter().zip(param_types.iter()).enumerate()
382 {
383 let arg_type = self.validate_expr(arg)?;
384
385 if arg_type.data_type != *expected_type {
387 return Err(Error::InvalidOperation(
388 format!("Function '{}' argument {} has invalid type: expected {:?}, got {:?}",
389 name, i + 1, expected_type, arg_type.data_type)
390 ));
391 }
392 }
393
394 Ok(InferredType::new(return_type.clone(), true))
396 } else {
397 Err(Error::InvalidOperation(format!(
398 "Function '{}' not found",
399 name
400 )))
401 }
402 }
403 Expr::Case {
404 when_then,
405 else_expr,
406 } => {
407 if when_then.is_empty() {
408 return Err(Error::InvalidOperation(
409 "CASE expression must have at least one WHEN clause".to_string(),
410 ));
411 }
412
413 for (when, _) in when_then.iter() {
415 let when_type = self.validate_expr(when)?;
416 if when_type.data_type != ExprDataType::Boolean {
417 return Err(Error::InvalidOperation(format!(
418 "CASE WHEN condition must be boolean, got {:?}",
419 when_type.data_type
420 )));
421 }
422 }
423
424 let first_then_type = self.validate_expr(&when_then[0].1)?;
426 let mut nullable = first_then_type.nullable;
427
428 for (_, then) in when_then.iter().skip(1) {
430 let then_type = self.validate_expr(then)?;
431 if then_type.data_type != first_then_type.data_type {
432 return Err(Error::InvalidOperation(format!(
433 "CASE THEN expressions must have the same type: {:?} vs {:?}",
434 first_then_type.data_type, then_type.data_type
435 )));
436 }
437 nullable = nullable || then_type.nullable;
438 }
439
440 if let Some(else_expr) = else_expr {
442 let else_type = self.validate_expr(else_expr)?;
443 if else_type.data_type != first_then_type.data_type {
444 return Err(Error::InvalidOperation(
445 format!("CASE ELSE expression must have the same type as THEN expressions: {:?} vs {:?}",
446 first_then_type.data_type, else_type.data_type)
447 ));
448 }
449 nullable = nullable || else_type.nullable;
450 } else {
451 nullable = true;
453 }
454
455 Ok(InferredType::new(first_then_type.data_type, nullable))
456 }
457 Expr::Cast {
458 expr: inner,
459 data_type,
460 } => {
461 let inner_type = self.validate_expr(inner)?;
463
464 let valid_cast = match (inner_type.data_type.clone(), data_type) {
466 (ExprDataType::Integer, ExprDataType::Float)
468 | (ExprDataType::Float, ExprDataType::Integer) => true,
469
470 (ExprDataType::String, ExprDataType::Integer)
472 | (ExprDataType::String, ExprDataType::Float)
473 | (ExprDataType::Integer, ExprDataType::String)
474 | (ExprDataType::Float, ExprDataType::String) => true,
475
476 (ExprDataType::Date, ExprDataType::Timestamp)
478 | (ExprDataType::Timestamp, ExprDataType::Date) => true,
479
480 (ExprDataType::String, ExprDataType::Date)
482 | (ExprDataType::String, ExprDataType::Timestamp)
483 | (ExprDataType::Date, ExprDataType::String)
484 | (ExprDataType::Timestamp, ExprDataType::String) => true,
485
486 (ExprDataType::Boolean, ExprDataType::Integer)
488 | (ExprDataType::Boolean, ExprDataType::String)
489 | (ExprDataType::Integer, ExprDataType::Boolean)
490 | (ExprDataType::String, ExprDataType::Boolean) => true,
491
492 (a, b) if a == *b => true,
494
495 _ => false,
497 };
498
499 if valid_cast {
500 Ok(InferredType::new(data_type.clone(), inner_type.nullable))
501 } else {
502 Err(Error::InvalidOperation(format!(
503 "Invalid cast from {:?} to {:?}",
504 inner_type.data_type.clone(),
505 data_type
506 )))
507 }
508 }
509 Expr::Coalesce { exprs } => {
510 if exprs.is_empty() {
511 return Err(Error::InvalidOperation(
512 "COALESCE expression must have at least one argument".to_string(),
513 ));
514 }
515
516 let first_type = self.validate_expr(&exprs[0])?;
518 let mut nullable = first_type.nullable;
519
520 for expr in exprs.iter().skip(1) {
522 let expr_type = self.validate_expr(expr)?;
523 if expr_type.data_type != first_type.data_type {
524 return Err(Error::InvalidOperation(format!(
525 "COALESCE expressions must have the same type: {:?} vs {:?}",
526 first_type.data_type, expr_type.data_type
527 )));
528 }
529 nullable = nullable && expr_type.nullable;
530 }
531
532 Ok(InferredType::new(first_type.data_type, nullable))
534 }
535 }
536 }
537
538 pub fn validate_projections(
540 &self,
541 projections: &[ColumnProjection],
542 ) -> Result<HashMap<String, InferredType>> {
543 let mut result = HashMap::new();
544
545 for projection in projections {
546 let expr_type = self.validate_expr(&projection.expr)?;
548
549 let output_name = projection.output_name();
551
552 result.insert(output_name, expr_type);
554 }
555
556 Ok(result)
557 }
558}