1use std::fmt;
2
3use thiserror::Error;
4
5use crate::ast as odata_ast;
6
7pub use crate::ast::Value as ODataValue;
8
/// The value type of a filterable field, used to check that filter literals fit the field.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FieldKind {
    /// Text field; literals must be strings. The only kind accepted by
    /// `contains`, `startswith` and `endswith`.
    String,
    /// Integer field; literals must be numbers.
    I64,
    /// Floating-point field; literals must be numbers.
    F64,
    /// Boolean field; literals must be booleans.
    Bool,
    /// UUID field; literals must be UUIDs.
    Uuid,
    /// Date-time field; literals must be date-time values.
    DateTimeUtc,
    /// Calendar-date field; literals must be dates.
    Date,
    /// Time-of-day field; literals must be times.
    Time,
    /// Decimal field; literals must be numbers.
    Decimal,
}
21
22impl fmt::Display for FieldKind {
23 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
24 match self {
25 FieldKind::String => write!(f, "String"),
26 FieldKind::I64 => write!(f, "I64"),
27 FieldKind::F64 => write!(f, "F64"),
28 FieldKind::Bool => write!(f, "Bool"),
29 FieldKind::Uuid => write!(f, "Uuid"),
30 FieldKind::DateTimeUtc => write!(f, "DateTimeUtc"),
31 FieldKind::Date => write!(f, "Date"),
32 FieldKind::Time => write!(f, "Time"),
33 FieldKind::Decimal => write!(f, "Decimal"),
34 }
35 }
36}
37
38pub trait FilterField: Copy + Eq + std::hash::Hash + fmt::Debug + 'static {
39 const FIELDS: &'static [Self];
40
41 fn name(&self) -> &'static str;
42
43 fn kind(&self) -> FieldKind;
44
45 fn from_name(name: &str) -> Option<Self> {
46 Self::FIELDS
47 .iter()
48 .copied()
49 .find(|f| f.name().eq_ignore_ascii_case(name))
50 }
51}
52
/// Operators that can appear in a [`FilterNode`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum FilterOp {
    /// Equality comparison (`eq`).
    Eq,
    /// Inequality comparison (`ne`).
    Ne,
    /// Greater-than comparison (`gt`).
    Gt,
    /// Greater-or-equal comparison (`ge`).
    Ge,
    /// Less-than comparison (`lt`).
    Lt,
    /// Less-or-equal comparison (`le`).
    Le,
    /// Substring test (`contains`).
    Contains,
    /// Prefix test (`startswith`).
    StartsWith,
    /// Suffix test (`endswith`).
    EndsWith,
    /// Logical conjunction of composite children.
    And,
    /// Logical disjunction of composite children.
    Or,
}
67
68impl fmt::Display for FilterOp {
69 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
70 match self {
71 FilterOp::Eq => write!(f, "eq"),
72 FilterOp::Ne => write!(f, "ne"),
73 FilterOp::Gt => write!(f, "gt"),
74 FilterOp::Ge => write!(f, "ge"),
75 FilterOp::Lt => write!(f, "lt"),
76 FilterOp::Le => write!(f, "le"),
77 FilterOp::Contains => write!(f, "contains"),
78 FilterOp::StartsWith => write!(f, "startswith"),
79 FilterOp::EndsWith => write!(f, "endswith"),
80 FilterOp::And => write!(f, "and"),
81 FilterOp::Or => write!(f, "or"),
82 }
83 }
84}
85
86#[derive(Debug, Clone)]
87pub enum FilterNode<F: FilterField> {
88 Binary {
89 field: F,
90 op: FilterOp,
91 value: ODataValue,
92 },
93 Composite {
94 op: FilterOp,
95 children: Vec<FilterNode<F>>,
96 },
97 Not(Box<FilterNode<F>>),
98}
99
100impl<F: FilterField> FilterNode<F> {
101 pub fn binary(field: F, op: FilterOp, value: ODataValue) -> Self {
102 FilterNode::Binary { field, op, value }
103 }
104
105 #[must_use]
106 pub fn and(children: Vec<FilterNode<F>>) -> Self {
107 FilterNode::Composite {
108 op: FilterOp::And,
109 children,
110 }
111 }
112
113 #[must_use]
114 pub fn or(children: Vec<FilterNode<F>>) -> Self {
115 FilterNode::Composite {
116 op: FilterOp::Or,
117 children,
118 }
119 }
120
121 #[allow(clippy::should_implement_trait)]
122 pub fn not(inner: FilterNode<F>) -> Self {
123 FilterNode::Not(Box::new(inner))
124 }
125}
126
127#[derive(Debug, Error, Clone)]
128pub enum FilterError {
129 #[error("Unknown field: {0}")]
130 UnknownField(String),
131
132 #[error("Type mismatch for field {field}: expected {expected}, got {got}")]
133 TypeMismatch {
134 field: String,
135 expected: FieldKind,
136 got: String,
137 },
138
139 #[error("Unsupported operation: {0}")]
140 UnsupportedOperation(String),
141
142 #[error("Invalid filter expression: {0}")]
143 InvalidExpression(String),
144
145 #[error("Field-to-field comparisons are not supported")]
146 FieldToFieldComparison,
147
148 #[error("Bare identifier in filter: {0}")]
149 BareIdentifier(String),
150
151 #[error("Bare literal in filter")]
152 BareLiteral,
153}
154
155pub type FilterResult<T> = Result<T, FilterError>;
156
157#[allow(unexpected_cfgs)]
158pub fn parse_odata_filter<F: FilterField>(raw: &str) -> FilterResult<FilterNode<F>> {
165 #[cfg(feature = "with-odata-params")]
166 {
167 use odata_params::filters::parse_str;
168
169 let ast = parse_str(raw).map_err(|e| FilterError::InvalidExpression(format!("{e:?}")))?;
170 let ast: odata_ast::Expr = ast.into();
171 convert_expr_to_filter_node::<F>(&ast)
172 }
173
174 #[cfg(not(feature = "with-odata-params"))]
175 {
176 let _ = raw;
177 Err(FilterError::InvalidExpression(
178 "OData filter parsing requires 'with-odata-params' feature".to_owned(),
179 ))
180 }
181}
182
183pub fn convert_expr_to_filter_node<F: FilterField>(
190 expr: &odata_ast::Expr,
191) -> FilterResult<FilterNode<F>> {
192 use odata_ast::Expr as E;
193
194 match expr {
195 E::And(left, right) => {
196 let left_node = convert_expr_to_filter_node::<F>(left)?;
197 let right_node = convert_expr_to_filter_node::<F>(right)?;
198 Ok(FilterNode::and(vec![left_node, right_node]))
199 }
200 E::Or(left, right) => {
201 let left_node = convert_expr_to_filter_node::<F>(left)?;
202 let right_node = convert_expr_to_filter_node::<F>(right)?;
203 Ok(FilterNode::or(vec![left_node, right_node]))
204 }
205 E::Not(inner) => {
206 let inner_node = convert_expr_to_filter_node::<F>(inner)?;
207 Ok(FilterNode::not(inner_node))
208 }
209
210 E::Compare(left, op, right) => {
211 let (field_name, value) = match (&**left, &**right) {
212 (E::Identifier(name), E::Value(val)) => (name.as_str(), val.clone()),
213 (E::Identifier(_), E::Identifier(_)) => {
214 return Err(FilterError::FieldToFieldComparison);
215 }
216 _ => {
217 return Err(FilterError::InvalidExpression(
218 "Comparison must be between field and value".to_owned(),
219 ));
220 }
221 };
222
223 let field = F::from_name(field_name)
224 .ok_or_else(|| FilterError::UnknownField(field_name.to_owned()))?;
225
226 validate_value_type(field, &value)?;
227
228 let filter_op = match op {
229 odata_ast::CompareOperator::Eq => FilterOp::Eq,
230 odata_ast::CompareOperator::Ne => FilterOp::Ne,
231 odata_ast::CompareOperator::Gt => FilterOp::Gt,
232 odata_ast::CompareOperator::Ge => FilterOp::Ge,
233 odata_ast::CompareOperator::Lt => FilterOp::Lt,
234 odata_ast::CompareOperator::Le => FilterOp::Le,
235 };
236
237 Ok(FilterNode::binary(field, filter_op, value))
238 }
239
240 E::Function(func_name, args) => {
241 let name_lower = func_name.to_ascii_lowercase();
242 match (name_lower.as_str(), args.as_slice()) {
243 (
244 "contains",
245 [
246 E::Identifier(field_name),
247 E::Value(odata_ast::Value::String(s)),
248 ],
249 ) => {
250 let field = F::from_name(field_name)
251 .ok_or_else(|| FilterError::UnknownField(field_name.clone()))?;
252
253 if field.kind() != FieldKind::String {
254 return Err(FilterError::TypeMismatch {
255 field: field_name.clone(),
256 expected: FieldKind::String,
257 got: "non-string".to_owned(),
258 });
259 }
260
261 Ok(FilterNode::binary(
262 field,
263 FilterOp::Contains,
264 odata_ast::Value::String(s.clone()),
265 ))
266 }
267 (
268 "startswith",
269 [
270 E::Identifier(field_name),
271 E::Value(odata_ast::Value::String(s)),
272 ],
273 ) => {
274 let field = F::from_name(field_name)
275 .ok_or_else(|| FilterError::UnknownField(field_name.clone()))?;
276
277 if field.kind() != FieldKind::String {
278 return Err(FilterError::TypeMismatch {
279 field: field_name.clone(),
280 expected: FieldKind::String,
281 got: "non-string".to_owned(),
282 });
283 }
284
285 Ok(FilterNode::binary(
286 field,
287 FilterOp::StartsWith,
288 odata_ast::Value::String(s.clone()),
289 ))
290 }
291 (
292 "endswith",
293 [
294 E::Identifier(field_name),
295 E::Value(odata_ast::Value::String(s)),
296 ],
297 ) => {
298 let field = F::from_name(field_name)
299 .ok_or_else(|| FilterError::UnknownField(field_name.clone()))?;
300
301 if field.kind() != FieldKind::String {
302 return Err(FilterError::TypeMismatch {
303 field: field_name.clone(),
304 expected: FieldKind::String,
305 got: "non-string".to_owned(),
306 });
307 }
308
309 Ok(FilterNode::binary(
310 field,
311 FilterOp::EndsWith,
312 odata_ast::Value::String(s.clone()),
313 ))
314 }
315 _ => Err(FilterError::UnsupportedOperation(format!(
316 "Function '{func_name}'"
317 ))),
318 }
319 }
320
321 E::In(_left, _list) => Err(FilterError::UnsupportedOperation(
322 "IN operator not yet supported in typed filters".to_owned(),
323 )),
324
325 E::Identifier(name) => Err(FilterError::BareIdentifier(name.clone())),
326 E::Value(_) => Err(FilterError::BareLiteral),
327 }
328}
329
330fn validate_value_type<F: FilterField>(field: F, value: &odata_ast::Value) -> FilterResult<()> {
331 use odata_ast::Value as V;
332
333 let kind = field.kind();
334 let matches = matches!(
335 (kind, value),
336 (FieldKind::String, V::String(_))
337 | (
338 FieldKind::I64 | FieldKind::F64 | FieldKind::Decimal,
339 V::Number(_)
340 )
341 | (FieldKind::Bool, V::Bool(_))
342 | (FieldKind::Uuid, V::Uuid(_))
343 | (FieldKind::DateTimeUtc, V::DateTime(_))
344 | (FieldKind::Date, V::Date(_))
345 | (FieldKind::Time, V::Time(_))
346 );
347
348 if matches {
349 Ok(())
350 } else {
351 Err(FilterError::TypeMismatch {
352 field: field.name().to_owned(),
353 expected: kind,
354 got: value.to_string(),
355 })
356 }
357}