// odata_params/filters/validate.rs

use std::iter::repeat;

use super::{Expr, FunctionsTypeMap, IdentifiersTypeMap, Type, ValidationError, Value};

4impl Expr {
5    /// Validates if the types within the expression are correct and
6    /// if the expression overall is a boolean type.
7    ///
8    /// A `Result` which is `Ok(true)` if the expression is a valid boolean
9    /// expression, or an `Err` with a `ValidationError` if the types are not valid.
10    ///
11    /// ```
12    /// use std::collections::HashMap;
13    /// use odata_params::filters::{Expr, FunctionsTypeMap, IdentifiersTypeMap, Type};
14    ///
15    /// let mut id_map = HashMap::new();
16    /// id_map.insert("value".to_string(), Type::Boolean);
17    /// let identifiers = IdentifiersTypeMap::from(id_map);
18    ///
19    /// let functions = FunctionsTypeMap::from(HashMap::new());
20    ///
21    /// let expr = Expr::Identifier("value".to_string());
22    ///
23    /// assert_eq!(expr.are_types_valid(&identifiers, &functions), Ok(true));
24    /// ```
25    pub fn are_types_valid(
26        &self,
27        identifiers: &IdentifiersTypeMap,
28        functions: &FunctionsTypeMap,
29    ) -> Result<bool, ValidationError> {
30        let overall_type = self.validate(identifiers, functions)?;
31
32        Ok(overall_type == Type::Boolean)
33    }
34
35    /// Validates the types within the expression.
36    ///
37    /// A `Result` which is `Ok` with the type of the expression if the types
38    /// are valid, or an `Err` with a `ValidationError` if the types are not valid.
39    ///
40    /// ```
41    /// use std::collections::HashMap;
42    /// use odata_params::filters::{Expr, FunctionsTypeMap, IdentifiersTypeMap, Type};
43    ///
44    /// let mut id_map = HashMap::new();
45    /// id_map.insert("value".to_string(), Type::Number);
46    /// let identifiers = IdentifiersTypeMap::from(id_map);
47    ///
48    /// let mut func_map = HashMap::new();
49    /// func_map.insert(
50    ///     "sum".to_string(),
51    ///     (vec![Type::Number], None, Type::Number),
52    /// );
53    /// let functions = FunctionsTypeMap::from(func_map);
54    ///
55    /// let expr = Expr::Function("sum".to_string(), vec![Expr::Identifier("value".to_string())]);
56    ///
57    /// assert_eq!(expr.validate(&identifiers, &functions), Ok(Type::Number));
58    /// ```
59    pub fn validate(
60        &self,
61        identifiers: &IdentifiersTypeMap,
62        functions: &FunctionsTypeMap,
63    ) -> Result<Type, ValidationError> {
64        match self {
65            Expr::Or(lhs, rhs) | Expr::And(lhs, rhs) => {
66                let lhs_type = Self::validate(lhs, identifiers, functions)?;
67                let rhs_type = Self::validate(rhs, identifiers, functions)?;
68
69                if lhs_type == Type::Boolean && rhs_type == Type::Boolean {
70                    Ok(Type::Boolean)
71                } else {
72                    Err(ValidationError::LogicalJoinRequiresBooleans {
73                        lhs: lhs_type,
74                        rhs: rhs_type,
75                    })
76                }
77            }
78
79            Expr::Not(inner) => {
80                let inner_type = Self::validate(inner, identifiers, functions)?;
81
82                if inner_type == Type::Boolean {
83                    Ok(Type::Boolean)
84                } else {
85                    Err(ValidationError::LogicalNotRequiresBoolean { given: inner_type })
86                }
87            }
88
89            Expr::Compare(lhs, _op, rhs) => {
90                let lhs_type = Self::validate(lhs, identifiers, functions)?;
91                let rhs_type = Self::validate(rhs, identifiers, functions)?;
92
93                if lhs_type == rhs_type {
94                    Ok(Type::Boolean)
95                } else {
96                    Err(ValidationError::ComparingIncompatibleTypes {
97                        lhs: lhs_type,
98                        rhs: rhs_type,
99                    })
100                }
101            }
102
103            Expr::In(lhs, values) => {
104                let lhs_type = Self::validate(lhs, identifiers, functions)?;
105
106                for value in values {
107                    let value_type = Self::validate(value, identifiers, functions)?;
108
109                    if lhs_type != value_type {
110                        return Err(ValidationError::ComparingIncompatibleTypes {
111                            lhs: lhs_type,
112                            rhs: value_type,
113                        });
114                    }
115                }
116
117                Ok(Type::Boolean)
118            }
119
120            Expr::Function(function, args) => {
121                let (types, variadic, ret) = functions.0.get(function).ok_or_else(|| {
122                    ValidationError::UndefinedFunction {
123                        name: function.to_owned(),
124                    }
125                })?;
126
127                println!(":: {types:?}, {variadic:?}, {args:?}");
128
129                if (variadic.is_none() && types.len() != args.len())
130                    || (variadic.is_some() && types.len() > args.len())
131                {
132                    return Err(ValidationError::IncorrectFunctionArgumentsCount {
133                        name: function.to_owned(),
134                        is_variadic: variadic.is_some(),
135                        expected: types.len(),
136                        given: args.len(),
137                    });
138                }
139
140                // It should be safe to setup an infinite chain of nulls when
141                // `variadic` is not set since we should have already exited
142                // early when `variadic` is None and `types` have a different
143                // length than the given arguments.
144                //
145                // This is needed to have consistent types without needing to
146                // collect eagerly. The `.zip` is what keeps the infinite
147                // iterator fixed to the length of given arguments.
148                let types = args.iter().zip(
149                    types
150                        .iter()
151                        .copied()
152                        .chain(repeat(variadic.unwrap_or(Type::Null))),
153                );
154
155                for (index, (arg, expected_type)) in types.enumerate() {
156                    let arg_type = Self::validate(arg, identifiers, functions)?;
157
158                    if arg_type != expected_type {
159                        return Err(ValidationError::IncorrectFunctionArgumentType {
160                            name: function.to_owned(),
161                            position: index + 1,
162                            expected: expected_type,
163                            given: arg_type,
164                        });
165                    }
166                }
167
168                Ok(*ret)
169            }
170
171            Expr::Identifier(identifier) => {
172                identifiers.0.get(identifier).copied().ok_or_else(|| {
173                    ValidationError::UndefinedIdentifier {
174                        name: identifier.to_owned(),
175                    }
176                })
177            }
178
179            Expr::Value(value) => Ok(match value {
180                Value::Null => Type::Null,
181                Value::Bool(_) => Type::Boolean,
182                Value::Number(_) => Type::Number,
183                Value::Uuid(_) => Type::Uuid,
184                Value::DateTime(_) => Type::DateTime,
185                Value::Date(_) => Type::Date,
186                Value::Time(_) => Type::Time,
187                Value::String(_) => Type::String,
188            }),
189        }
190    }
191}