Skip to main content

pandrs/distributed/expr/
validator.rs

1//! # Expression Validation
2//!
3//! This module provides validation capabilities for expressions, ensuring type safety
4//! before execution.
5
6use super::core::{BinaryOperator, Expr, Literal, UnaryOperator};
7use super::projection::ColumnProjection;
8use super::schema::{ColumnMeta, ExprSchema};
9use super::ExprDataType;
10use crate::error::{Error, Result};
11use std::collections::HashMap;
12
13/// Inferred type for an expression
14#[derive(Debug, Clone, PartialEq, Eq)]
15pub struct InferredType {
16    /// Data type
17    pub data_type: ExprDataType,
18    /// Whether the expression can be null
19    pub nullable: bool,
20}
21
22impl InferredType {
23    /// Creates a new inferred type
24    pub fn new(data_type: ExprDataType, nullable: bool) -> Self {
25        Self {
26            data_type,
27            nullable,
28        }
29    }
30
31    /// Gets a boolean type
32    pub fn boolean(nullable: bool) -> Self {
33        Self::new(ExprDataType::Boolean, nullable)
34    }
35
36    /// Gets an integer type
37    pub fn integer(nullable: bool) -> Self {
38        Self::new(ExprDataType::Integer, nullable)
39    }
40
41    /// Gets a float type
42    pub fn float(nullable: bool) -> Self {
43        Self::new(ExprDataType::Float, nullable)
44    }
45
46    /// Gets a string type
47    pub fn string(nullable: bool) -> Self {
48        Self::new(ExprDataType::String, nullable)
49    }
50
51    /// Gets a date type
52    pub fn date(nullable: bool) -> Self {
53        Self::new(ExprDataType::Date, nullable)
54    }
55
56    /// Gets a timestamp type
57    pub fn timestamp(nullable: bool) -> Self {
58        Self::new(ExprDataType::Timestamp, nullable)
59    }
60}
61
62/// Validates expressions using schema information
63pub struct ExprValidator<'a> {
64    /// Schema to validate against
65    schema: &'a ExprSchema,
66    /// Function return types
67    function_types: HashMap<String, (ExprDataType, Vec<ExprDataType>)>,
68}
69
70impl<'a> ExprValidator<'a> {
71    /// Creates a new validator
72    pub fn new(schema: &'a ExprSchema) -> Self {
73        // Initialize with standard functions and their return types
74        let mut function_types = HashMap::new();
75
76        // String functions
77        function_types.insert(
78            "lower".to_string(),
79            (ExprDataType::String, vec![ExprDataType::String]),
80        );
81        function_types.insert(
82            "upper".to_string(),
83            (ExprDataType::String, vec![ExprDataType::String]),
84        );
85        function_types.insert(
86            "concat".to_string(),
87            (
88                ExprDataType::String,
89                vec![ExprDataType::String, ExprDataType::String],
90            ),
91        );
92        function_types.insert(
93            "trim".to_string(),
94            (ExprDataType::String, vec![ExprDataType::String]),
95        );
96
97        // Numeric functions
98        function_types.insert(
99            "abs".to_string(),
100            (ExprDataType::Float, vec![ExprDataType::Float]),
101        );
102        function_types.insert(
103            "round".to_string(),
104            (
105                ExprDataType::Float,
106                vec![ExprDataType::Float, ExprDataType::Integer],
107            ),
108        );
109        function_types.insert(
110            "floor".to_string(),
111            (ExprDataType::Float, vec![ExprDataType::Float]),
112        );
113        function_types.insert(
114            "ceiling".to_string(),
115            (ExprDataType::Float, vec![ExprDataType::Float]),
116        );
117
118        // Date/Time functions
119        function_types.insert(
120            "date_trunc".to_string(),
121            (
122                ExprDataType::Timestamp,
123                vec![ExprDataType::String, ExprDataType::Timestamp],
124            ),
125        );
126        function_types.insert(
127            "extract".to_string(),
128            (
129                ExprDataType::Integer,
130                vec![ExprDataType::String, ExprDataType::Timestamp],
131            ),
132        );
133
134        // Aggregate functions
135        function_types.insert(
136            "min".to_string(),
137            (ExprDataType::Float, vec![ExprDataType::Float]),
138        );
139        function_types.insert(
140            "max".to_string(),
141            (ExprDataType::Float, vec![ExprDataType::Float]),
142        );
143        function_types.insert(
144            "sum".to_string(),
145            (ExprDataType::Float, vec![ExprDataType::Float]),
146        );
147        function_types.insert(
148            "avg".to_string(),
149            (ExprDataType::Float, vec![ExprDataType::Float]),
150        );
151        function_types.insert(
152            "count".to_string(),
153            (ExprDataType::Integer, vec![ExprDataType::String]),
154        );
155
156        Self {
157            schema,
158            function_types,
159        }
160    }
161
162    /// Adds a user-defined function
163    pub fn add_udf(
164        &mut self,
165        name: impl Into<String>,
166        return_type: ExprDataType,
167        parameter_types: Vec<ExprDataType>,
168    ) -> &mut Self {
169        self.function_types
170            .insert(name.into(), (return_type, parameter_types));
171        self
172    }
173
174    /// Validates an expression and infers its type
175    pub fn validate_expr(&self, expr: &Expr) -> Result<InferredType> {
176        match expr {
177            Expr::Column(name) => {
178                // Check if column exists in schema
179                if let Some(col_meta) = self.schema.column(name) {
180                    Ok(InferredType::new(
181                        col_meta.data_type.clone(),
182                        col_meta.nullable,
183                    ))
184                } else {
185                    Err(Error::InvalidOperation(format!(
186                        "Column '{}' not found in schema",
187                        name
188                    )))
189                }
190            }
191            Expr::Literal(lit) => {
192                // Infer type from literal
193                match lit {
194                    Literal::Null => Err(Error::InvalidOperation(
195                        "Cannot infer type for NULL literal without context".to_string(),
196                    )),
197                    Literal::Boolean(_) => Ok(InferredType::boolean(false)),
198                    Literal::Integer(_) => Ok(InferredType::integer(false)),
199                    Literal::Float(_) => Ok(InferredType::float(false)),
200                    Literal::String(_) => Ok(InferredType::string(false)),
201                }
202            }
203            Expr::BinaryOp { left, op, right } => {
204                // Validate operands
205                let left_type = self.validate_expr(left)?;
206                let right_type = self.validate_expr(right)?;
207
208                // Result is nullable if either operand is nullable
209                let nullable = left_type.nullable || right_type.nullable;
210
211                // Determine result type based on operator and operand types
212                match op {
213                    // Arithmetic operations
214                    BinaryOperator::Add
215                    | BinaryOperator::Subtract
216                    | BinaryOperator::Multiply
217                    | BinaryOperator::Divide
218                    | BinaryOperator::Modulo => {
219                        // For arithmetic operations, result is numeric
220                        match (left_type.data_type.clone(), right_type.data_type.clone()) {
221                            // If both are integers, result is integer (except division)
222                            (ExprDataType::Integer, ExprDataType::Integer) => {
223                                if *op == BinaryOperator::Divide {
224                                    Ok(InferredType::float(nullable))
225                                } else {
226                                    Ok(InferredType::integer(nullable))
227                                }
228                            }
229                            // If either is float, result is float
230                            (ExprDataType::Float, _) | (_, ExprDataType::Float) => {
231                                Ok(InferredType::float(nullable))
232                            }
233                            // Other combinations are invalid
234                            _ => Err(Error::InvalidOperation(format!(
235                                "Invalid operand types for arithmetic operation: {:?} {} {:?}",
236                                left_type.data_type, op, right_type.data_type
237                            ))),
238                        }
239                    }
240                    // Comparison operations
241                    BinaryOperator::Equal
242                    | BinaryOperator::NotEqual
243                    | BinaryOperator::LessThan
244                    | BinaryOperator::LessThanOrEqual
245                    | BinaryOperator::GreaterThan
246                    | BinaryOperator::GreaterThanOrEqual => {
247                        // For comparison operations, result is boolean
248                        // Check if operand types are comparable
249                        match (left_type.data_type.clone(), right_type.data_type.clone()) {
250                            // Same types are comparable
251                            (a, b) if a == b => Ok(InferredType::boolean(nullable)),
252                            // Integer and float are comparable
253                            (ExprDataType::Integer, ExprDataType::Float)
254                            | (ExprDataType::Float, ExprDataType::Integer) => {
255                                Ok(InferredType::boolean(nullable))
256                            }
257                            // Date and timestamp are comparable
258                            (ExprDataType::Date, ExprDataType::Timestamp)
259                            | (ExprDataType::Timestamp, ExprDataType::Date) => {
260                                Ok(InferredType::boolean(nullable))
261                            }
262                            // Other combinations are invalid
263                            _ => Err(Error::InvalidOperation(format!(
264                                "Invalid operand types for comparison: {:?} {} {:?}",
265                                left_type.data_type.clone(),
266                                op,
267                                right_type.data_type.clone()
268                            ))),
269                        }
270                    }
271                    // Logical operations
272                    BinaryOperator::And | BinaryOperator::Or => {
273                        // Both operands must be boolean
274                        if left_type.data_type == ExprDataType::Boolean
275                            && right_type.data_type == ExprDataType::Boolean
276                        {
277                            Ok(InferredType::boolean(nullable))
278                        } else {
279                            Err(Error::InvalidOperation(format!(
280                                "Logical operations require boolean operands, got {:?} and {:?}",
281                                left_type.data_type, right_type.data_type
282                            )))
283                        }
284                    }
285                    // Bitwise operations
286                    BinaryOperator::BitwiseAnd
287                    | BinaryOperator::BitwiseOr
288                    | BinaryOperator::BitwiseXor => {
289                        // Both operands must be integers
290                        if left_type.data_type == ExprDataType::Integer
291                            && right_type.data_type == ExprDataType::Integer
292                        {
293                            Ok(InferredType::integer(nullable))
294                        } else {
295                            Err(Error::InvalidOperation(format!(
296                                "Bitwise operations require integer operands, got {:?} and {:?}",
297                                left_type.data_type, right_type.data_type
298                            )))
299                        }
300                    }
301                    // Like pattern matching
302                    BinaryOperator::Like => {
303                        // Left operand must be string, right operand must be string
304                        if left_type.data_type == ExprDataType::String
305                            && right_type.data_type == ExprDataType::String
306                        {
307                            Ok(InferredType::boolean(nullable))
308                        } else {
309                            Err(Error::InvalidOperation(format!(
310                                "LIKE operation requires string operands, got {:?} and {:?}",
311                                left_type.data_type, right_type.data_type
312                            )))
313                        }
314                    }
315                    // String concatenation
316                    BinaryOperator::Concat => {
317                        // Both operands must be strings
318                        if left_type.data_type == ExprDataType::String
319                            && right_type.data_type == ExprDataType::String
320                        {
321                            Ok(InferredType::string(nullable))
322                        } else {
323                            Err(Error::InvalidOperation(format!(
324                                "String concatenation requires string operands, got {:?} and {:?}",
325                                left_type.data_type, right_type.data_type
326                            )))
327                        }
328                    }
329                }
330            }
331            Expr::UnaryOp { op, expr } => {
332                // Validate operand
333                let expr_type = self.validate_expr(expr)?;
334
335                match op {
336                    // Negation
337                    UnaryOperator::Negate => {
338                        // Operand must be numeric
339                        match expr_type.data_type {
340                            ExprDataType::Integer => Ok(InferredType::integer(expr_type.nullable)),
341                            ExprDataType::Float => Ok(InferredType::float(expr_type.nullable)),
342                            _ => Err(Error::InvalidOperation(format!(
343                                "Negation requires numeric operand, got {:?}",
344                                expr_type.data_type
345                            ))),
346                        }
347                    }
348                    // Logical NOT
349                    UnaryOperator::Not => {
350                        // Operand must be boolean
351                        if expr_type.data_type == ExprDataType::Boolean {
352                            Ok(InferredType::boolean(expr_type.nullable))
353                        } else {
354                            Err(Error::InvalidOperation(format!(
355                                "Logical NOT requires boolean operand, got {:?}",
356                                expr_type.data_type
357                            )))
358                        }
359                    }
360                    // IS NULL and IS NOT NULL
361                    UnaryOperator::IsNull | UnaryOperator::IsNotNull => {
362                        // Can be applied to any type, result is boolean
363                        Ok(InferredType::boolean(false))
364                    }
365                }
366            }
367            Expr::Function { name, args } => {
368                // Check if function exists
369                if let Some((return_type, param_types)) = self.function_types.get(name) {
370                    // Validate number of arguments
371                    if args.len() != param_types.len() {
372                        return Err(Error::InvalidOperation(format!(
373                            "Function '{}' expects {} arguments, got {}",
374                            name,
375                            param_types.len(),
376                            args.len()
377                        )));
378                    }
379
380                    // Validate each argument
381                    for (i, (arg, expected_type)) in args.iter().zip(param_types.iter()).enumerate()
382                    {
383                        let arg_type = self.validate_expr(arg)?;
384
385                        // Check if argument type matches expected type
386                        if arg_type.data_type != *expected_type {
387                            return Err(Error::InvalidOperation(
388                                format!("Function '{}' argument {} has invalid type: expected {:?}, got {:?}",
389                                    name, i + 1, expected_type, arg_type.data_type)
390                            ));
391                        }
392                    }
393
394                    // Return function's return type
395                    Ok(InferredType::new(return_type.clone(), true))
396                } else {
397                    Err(Error::InvalidOperation(format!(
398                        "Function '{}' not found",
399                        name
400                    )))
401                }
402            }
403            Expr::Case {
404                when_then,
405                else_expr,
406            } => {
407                if when_then.is_empty() {
408                    return Err(Error::InvalidOperation(
409                        "CASE expression must have at least one WHEN clause".to_string(),
410                    ));
411                }
412
413                // Validate WHEN conditions (must be boolean)
414                for (when, _) in when_then.iter() {
415                    let when_type = self.validate_expr(when)?;
416                    if when_type.data_type != ExprDataType::Boolean {
417                        return Err(Error::InvalidOperation(format!(
418                            "CASE WHEN condition must be boolean, got {:?}",
419                            when_type.data_type
420                        )));
421                    }
422                }
423
424                // Get type of first THEN expression
425                let first_then_type = self.validate_expr(&when_then[0].1)?;
426                let mut nullable = first_then_type.nullable;
427
428                // Validate that all THEN expressions have the same type
429                for (_, then) in when_then.iter().skip(1) {
430                    let then_type = self.validate_expr(then)?;
431                    if then_type.data_type != first_then_type.data_type {
432                        return Err(Error::InvalidOperation(format!(
433                            "CASE THEN expressions must have the same type: {:?} vs {:?}",
434                            first_then_type.data_type, then_type.data_type
435                        )));
436                    }
437                    nullable = nullable || then_type.nullable;
438                }
439
440                // If ELSE expression exists, validate it has the same type as THEN expressions
441                if let Some(else_expr) = else_expr {
442                    let else_type = self.validate_expr(else_expr)?;
443                    if else_type.data_type != first_then_type.data_type {
444                        return Err(Error::InvalidOperation(
445                            format!("CASE ELSE expression must have the same type as THEN expressions: {:?} vs {:?}",
446                                first_then_type.data_type, else_type.data_type)
447                        ));
448                    }
449                    nullable = nullable || else_type.nullable;
450                } else {
451                    // If no ELSE expression, result is nullable (missing case)
452                    nullable = true;
453                }
454
455                Ok(InferredType::new(first_then_type.data_type, nullable))
456            }
457            Expr::Cast {
458                expr: inner,
459                data_type,
460            } => {
461                // Validate inner expression
462                let inner_type = self.validate_expr(inner)?;
463
464                // Check if cast is valid
465                let valid_cast = match (inner_type.data_type.clone(), data_type) {
466                    // Numeric conversions are valid
467                    (ExprDataType::Integer, ExprDataType::Float)
468                    | (ExprDataType::Float, ExprDataType::Integer) => true,
469
470                    // String to/from numeric conversions are valid
471                    (ExprDataType::String, ExprDataType::Integer)
472                    | (ExprDataType::String, ExprDataType::Float)
473                    | (ExprDataType::Integer, ExprDataType::String)
474                    | (ExprDataType::Float, ExprDataType::String) => true,
475
476                    // Date/Timestamp conversions are valid
477                    (ExprDataType::Date, ExprDataType::Timestamp)
478                    | (ExprDataType::Timestamp, ExprDataType::Date) => true,
479
480                    // String to/from date/timestamp conversions are valid
481                    (ExprDataType::String, ExprDataType::Date)
482                    | (ExprDataType::String, ExprDataType::Timestamp)
483                    | (ExprDataType::Date, ExprDataType::String)
484                    | (ExprDataType::Timestamp, ExprDataType::String) => true,
485
486                    // Boolean conversions
487                    (ExprDataType::Boolean, ExprDataType::Integer)
488                    | (ExprDataType::Boolean, ExprDataType::String)
489                    | (ExprDataType::Integer, ExprDataType::Boolean)
490                    | (ExprDataType::String, ExprDataType::Boolean) => true,
491
492                    // Same type is always valid
493                    (a, b) if a == *b => true,
494
495                    // Other conversions are invalid
496                    _ => false,
497                };
498
499                if valid_cast {
500                    Ok(InferredType::new(data_type.clone(), inner_type.nullable))
501                } else {
502                    Err(Error::InvalidOperation(format!(
503                        "Invalid cast from {:?} to {:?}",
504                        inner_type.data_type.clone(),
505                        data_type
506                    )))
507                }
508            }
509            Expr::Coalesce { exprs } => {
510                if exprs.is_empty() {
511                    return Err(Error::InvalidOperation(
512                        "COALESCE expression must have at least one argument".to_string(),
513                    ));
514                }
515
516                // Get type of first expression
517                let first_type = self.validate_expr(&exprs[0])?;
518                let mut nullable = first_type.nullable;
519
520                // Validate that all expressions have the same type
521                for expr in exprs.iter().skip(1) {
522                    let expr_type = self.validate_expr(expr)?;
523                    if expr_type.data_type != first_type.data_type {
524                        return Err(Error::InvalidOperation(format!(
525                            "COALESCE expressions must have the same type: {:?} vs {:?}",
526                            first_type.data_type, expr_type.data_type
527                        )));
528                    }
529                    nullable = nullable && expr_type.nullable;
530                }
531
532                // Result is nullable only if all expressions are nullable
533                Ok(InferredType::new(first_type.data_type, nullable))
534            }
535        }
536    }
537
538    /// Validates a list of projections
539    pub fn validate_projections(
540        &self,
541        projections: &[ColumnProjection],
542    ) -> Result<HashMap<String, InferredType>> {
543        let mut result = HashMap::new();
544
545        for projection in projections {
546            // Validate the expression
547            let expr_type = self.validate_expr(&projection.expr)?;
548
549            // Get the output name
550            let output_name = projection.output_name();
551
552            // Add to result
553            result.insert(output_name, expr_type);
554        }
555
556        Ok(result)
557    }
558}