Skip to main content

modkit_odata/
filter.rs

1use std::fmt;
2
3use thiserror::Error;
4
5use crate::ast as odata_ast;
6
7pub use crate::ast::Value as ODataValue;
8
/// The value kind of a filterable field.
///
/// Literal values appearing in a filter are checked against the field's kind before a typed
/// filter node is built.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FieldKind {
    /// Text field; the only kind accepted by `contains` / `startswith` / `endswith`.
    String,
    /// 64-bit integer field; accepts numeric literals.
    I64,
    /// 64-bit floating-point field; accepts numeric literals.
    F64,
    /// Boolean field; accepts boolean literals.
    Bool,
    /// UUID field; accepts UUID literals.
    Uuid,
    /// Date-time field; accepts date-time literals.
    DateTimeUtc,
    /// Calendar-date field; accepts date literals.
    Date,
    /// Time-of-day field; accepts time literals.
    Time,
    /// Decimal field; accepts numeric literals.
    Decimal,
}
21
22impl fmt::Display for FieldKind {
23    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
24        match self {
25            FieldKind::String => write!(f, "String"),
26            FieldKind::I64 => write!(f, "I64"),
27            FieldKind::F64 => write!(f, "F64"),
28            FieldKind::Bool => write!(f, "Bool"),
29            FieldKind::Uuid => write!(f, "Uuid"),
30            FieldKind::DateTimeUtc => write!(f, "DateTimeUtc"),
31            FieldKind::Date => write!(f, "Date"),
32            FieldKind::Time => write!(f, "Time"),
33            FieldKind::Decimal => write!(f, "Decimal"),
34        }
35    }
36}
37
38pub trait FilterField: Copy + Eq + std::hash::Hash + fmt::Debug + 'static {
39    const FIELDS: &'static [Self];
40
41    fn name(&self) -> &'static str;
42
43    fn kind(&self) -> FieldKind;
44
45    fn from_name(name: &str) -> Option<Self> {
46        Self::FIELDS
47            .iter()
48            .copied()
49            .find(|f| f.name().eq_ignore_ascii_case(name))
50    }
51}
52
/// Operators that can appear in a typed filter tree.
///
/// `Eq`..`Le` and the three string functions are used on `Binary` nodes; `And` / `Or` are used
/// on `Composite` nodes.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum FilterOp {
    /// `eq`
    Eq,
    /// `ne`
    Ne,
    /// `gt`
    Gt,
    /// `ge`
    Ge,
    /// `lt`
    Lt,
    /// `le`
    Le,
    /// `contains(field, 'text')`
    Contains,
    /// `startswith(field, 'text')`
    StartsWith,
    /// `endswith(field, 'text')`
    EndsWith,
    /// Logical conjunction.
    And,
    /// Logical disjunction.
    Or,
}
67
68impl fmt::Display for FilterOp {
69    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
70        match self {
71            FilterOp::Eq => write!(f, "eq"),
72            FilterOp::Ne => write!(f, "ne"),
73            FilterOp::Gt => write!(f, "gt"),
74            FilterOp::Ge => write!(f, "ge"),
75            FilterOp::Lt => write!(f, "lt"),
76            FilterOp::Le => write!(f, "le"),
77            FilterOp::Contains => write!(f, "contains"),
78            FilterOp::StartsWith => write!(f, "startswith"),
79            FilterOp::EndsWith => write!(f, "endswith"),
80            FilterOp::And => write!(f, "and"),
81            FilterOp::Or => write!(f, "or"),
82        }
83    }
84}
85
86#[derive(Debug, Clone)]
87pub enum FilterNode<F: FilterField> {
88    Binary {
89        field: F,
90        op: FilterOp,
91        value: ODataValue,
92    },
93    Composite {
94        op: FilterOp,
95        children: Vec<FilterNode<F>>,
96    },
97    Not(Box<FilterNode<F>>),
98}
99
100impl<F: FilterField> FilterNode<F> {
101    pub fn binary(field: F, op: FilterOp, value: ODataValue) -> Self {
102        FilterNode::Binary { field, op, value }
103    }
104
105    #[must_use]
106    pub fn and(children: Vec<FilterNode<F>>) -> Self {
107        FilterNode::Composite {
108            op: FilterOp::And,
109            children,
110        }
111    }
112
113    #[must_use]
114    pub fn or(children: Vec<FilterNode<F>>) -> Self {
115        FilterNode::Composite {
116            op: FilterOp::Or,
117            children,
118        }
119    }
120
121    #[allow(clippy::should_implement_trait)]
122    pub fn not(inner: FilterNode<F>) -> Self {
123        FilterNode::Not(Box::new(inner))
124    }
125}
126
127#[derive(Debug, Error, Clone)]
128pub enum FilterError {
129    #[error("Unknown field: {0}")]
130    UnknownField(String),
131
132    #[error("Type mismatch for field {field}: expected {expected}, got {got}")]
133    TypeMismatch {
134        field: String,
135        expected: FieldKind,
136        got: String,
137    },
138
139    #[error("Unsupported operation: {0}")]
140    UnsupportedOperation(String),
141
142    #[error("Invalid filter expression: {0}")]
143    InvalidExpression(String),
144
145    #[error("Field-to-field comparisons are not supported")]
146    FieldToFieldComparison,
147
148    #[error("Bare identifier in filter: {0}")]
149    BareIdentifier(String),
150
151    #[error("Bare literal in filter")]
152    BareLiteral,
153}
154
155pub type FilterResult<T> = Result<T, FilterError>;
156
157#[allow(unexpected_cfgs)]
158/// Parse an `OData` filter string into a typed `FilterNode`.
159///
160/// # Errors
161///
162/// Returns `FilterError::InvalidExpression` if parsing fails, the required feature is disabled,
163/// or the expression cannot be converted into a typed filter node.
164pub fn parse_odata_filter<F: FilterField>(raw: &str) -> FilterResult<FilterNode<F>> {
165    #[cfg(feature = "with-odata-params")]
166    {
167        use odata_params::filters::parse_str;
168
169        let ast = parse_str(raw).map_err(|e| FilterError::InvalidExpression(format!("{e:?}")))?;
170        let ast: odata_ast::Expr = ast.into();
171        convert_expr_to_filter_node::<F>(&ast)
172    }
173
174    #[cfg(not(feature = "with-odata-params"))]
175    {
176        let _ = raw;
177        Err(FilterError::InvalidExpression(
178            "OData filter parsing requires 'with-odata-params' feature".to_owned(),
179        ))
180    }
181}
182
183/// Convert a parsed `OData` AST expression into a typed `FilterNode`.
184///
185/// # Errors
186///
187/// Returns `FilterError` if the expression is invalid, references unknown fields, uses unsupported
188/// operations, or contains type mismatches.
189pub fn convert_expr_to_filter_node<F: FilterField>(
190    expr: &odata_ast::Expr,
191) -> FilterResult<FilterNode<F>> {
192    use odata_ast::Expr as E;
193
194    match expr {
195        E::And(left, right) => {
196            let left_node = convert_expr_to_filter_node::<F>(left)?;
197            let right_node = convert_expr_to_filter_node::<F>(right)?;
198            Ok(FilterNode::and(vec![left_node, right_node]))
199        }
200        E::Or(left, right) => {
201            let left_node = convert_expr_to_filter_node::<F>(left)?;
202            let right_node = convert_expr_to_filter_node::<F>(right)?;
203            Ok(FilterNode::or(vec![left_node, right_node]))
204        }
205        E::Not(inner) => {
206            let inner_node = convert_expr_to_filter_node::<F>(inner)?;
207            Ok(FilterNode::not(inner_node))
208        }
209
210        E::Compare(left, op, right) => {
211            let (field_name, value) = match (&**left, &**right) {
212                (E::Identifier(name), E::Value(val)) => (name.as_str(), val.clone()),
213                (E::Identifier(_), E::Identifier(_)) => {
214                    return Err(FilterError::FieldToFieldComparison);
215                }
216                _ => {
217                    return Err(FilterError::InvalidExpression(
218                        "Comparison must be between field and value".to_owned(),
219                    ));
220                }
221            };
222
223            let field = F::from_name(field_name)
224                .ok_or_else(|| FilterError::UnknownField(field_name.to_owned()))?;
225
226            validate_value_type(field, &value)?;
227
228            let filter_op = match op {
229                odata_ast::CompareOperator::Eq => FilterOp::Eq,
230                odata_ast::CompareOperator::Ne => FilterOp::Ne,
231                odata_ast::CompareOperator::Gt => FilterOp::Gt,
232                odata_ast::CompareOperator::Ge => FilterOp::Ge,
233                odata_ast::CompareOperator::Lt => FilterOp::Lt,
234                odata_ast::CompareOperator::Le => FilterOp::Le,
235            };
236
237            Ok(FilterNode::binary(field, filter_op, value))
238        }
239
240        E::Function(func_name, args) => {
241            let name_lower = func_name.to_ascii_lowercase();
242            match (name_lower.as_str(), args.as_slice()) {
243                (
244                    "contains",
245                    [
246                        E::Identifier(field_name),
247                        E::Value(odata_ast::Value::String(s)),
248                    ],
249                ) => {
250                    let field = F::from_name(field_name)
251                        .ok_or_else(|| FilterError::UnknownField(field_name.clone()))?;
252
253                    if field.kind() != FieldKind::String {
254                        return Err(FilterError::TypeMismatch {
255                            field: field_name.clone(),
256                            expected: FieldKind::String,
257                            got: "non-string".to_owned(),
258                        });
259                    }
260
261                    Ok(FilterNode::binary(
262                        field,
263                        FilterOp::Contains,
264                        odata_ast::Value::String(s.clone()),
265                    ))
266                }
267                (
268                    "startswith",
269                    [
270                        E::Identifier(field_name),
271                        E::Value(odata_ast::Value::String(s)),
272                    ],
273                ) => {
274                    let field = F::from_name(field_name)
275                        .ok_or_else(|| FilterError::UnknownField(field_name.clone()))?;
276
277                    if field.kind() != FieldKind::String {
278                        return Err(FilterError::TypeMismatch {
279                            field: field_name.clone(),
280                            expected: FieldKind::String,
281                            got: "non-string".to_owned(),
282                        });
283                    }
284
285                    Ok(FilterNode::binary(
286                        field,
287                        FilterOp::StartsWith,
288                        odata_ast::Value::String(s.clone()),
289                    ))
290                }
291                (
292                    "endswith",
293                    [
294                        E::Identifier(field_name),
295                        E::Value(odata_ast::Value::String(s)),
296                    ],
297                ) => {
298                    let field = F::from_name(field_name)
299                        .ok_or_else(|| FilterError::UnknownField(field_name.clone()))?;
300
301                    if field.kind() != FieldKind::String {
302                        return Err(FilterError::TypeMismatch {
303                            field: field_name.clone(),
304                            expected: FieldKind::String,
305                            got: "non-string".to_owned(),
306                        });
307                    }
308
309                    Ok(FilterNode::binary(
310                        field,
311                        FilterOp::EndsWith,
312                        odata_ast::Value::String(s.clone()),
313                    ))
314                }
315                _ => Err(FilterError::UnsupportedOperation(format!(
316                    "Function '{func_name}'"
317                ))),
318            }
319        }
320
321        E::In(_left, _list) => Err(FilterError::UnsupportedOperation(
322            "IN operator not yet supported in typed filters".to_owned(),
323        )),
324
325        E::Identifier(name) => Err(FilterError::BareIdentifier(name.clone())),
326        E::Value(_) => Err(FilterError::BareLiteral),
327    }
328}
329
330fn validate_value_type<F: FilterField>(field: F, value: &odata_ast::Value) -> FilterResult<()> {
331    use odata_ast::Value as V;
332
333    let kind = field.kind();
334    let matches = matches!(
335        (kind, value),
336        (FieldKind::String, V::String(_))
337            | (
338                FieldKind::I64 | FieldKind::F64 | FieldKind::Decimal,
339                V::Number(_)
340            )
341            | (FieldKind::Bool, V::Bool(_))
342            | (FieldKind::Uuid, V::Uuid(_))
343            | (FieldKind::DateTimeUtc, V::DateTime(_))
344            | (FieldKind::Date, V::Date(_))
345            | (FieldKind::Time, V::Time(_))
346    );
347
348    if matches {
349        Ok(())
350    } else {
351        Err(FilterError::TypeMismatch {
352            field: field.name().to_owned(),
353            expected: kind,
354            got: value.to_string(),
355        })
356    }
357}