// aip_filtering/ast.rs

use std::fmt::{self, Display, Formatter};
use std::time::Duration;

use chrono::{DateTime, FixedOffset};
use itertools::Itertools;
use pest::error::{Error, ErrorVariant};
use pest::iterators::{Pair, Pairs};
use pest::Parser;
use pest_derive::Parser;
10
11#[derive(Parser)]
12#[grammar = "grammar.pest"]
13pub struct FilterParser;
14
15impl FilterParser {
16    pub fn parse_str(input: &str) -> Result<Filter, Error<Rule>> {
17        let mut pairs: Pairs<Rule> = Self::parse(Rule::filter, input)?;
18
19        match pairs.next() {
20            None => Ok(Filter::None),
21
22            Some(pair) => Ok(Filter::Some(Expression::parse(pair).unwrap())),
23        }
24    }
25}
26
27pub type Filter<'a> = Option<Expression<'a>>;
28
29#[derive(Debug)]
30pub enum Value<'a> {
31    Bool(bool),
32    Duration(Duration),
33    Float(f64),
34    Int(i64),
35    String(&'a str),
36    Text(&'a str),
37    Timestamp(DateTime<FixedOffset>),
38}
39
40impl<'a> Value<'a> {
41    pub fn parse(pair: Pair<'a, Rule>) -> Result<Self, ()> {
42        debug_assert_eq!(pair.as_rule(), Rule::value);
43
44        let inner_pair = pair.into_inner().next().unwrap();
45
46        match inner_pair.as_rule() {
47            Rule::string => {
48                let str = inner_pair.into_inner().as_str();
49
50                if let Ok(value) = DateTime::parse_from_rfc3339(str) {
51                    return Ok(Value::Timestamp(value));
52                }
53
54                return Ok(Value::String(str))
55            }
56
57            Rule::text => {
58                let str = inner_pair.as_str();
59
60                if let Ok(value) = str.parse() {
61                    return Ok(Value::Bool(value));
62                }
63
64                if str.ends_with("s") {
65                    if let Ok(secs) = str[..str.len() - 1].parse() {
66                        return Ok(Value::Duration(Duration::from_secs_f64(secs)))
67                    }
68                }
69
70                if let Ok(value) = str.parse() {
71                    return Ok(Value::Int(value));
72                }
73
74                if let Ok(value) = str.parse() {
75                    return Ok(Value::Float(value));
76                }
77
78                if let Ok(value) = DateTime::parse_from_rfc3339(str) {
79                    return Ok(Value::Timestamp(value));
80                }
81
82                Ok(Value::Text(str))
83            }
84
85            _ => unreachable!()
86        }
87    }
88
89    pub fn string_repr(&self) -> String {
90        match self {
91            Self::Bool(value) => value.to_string(),
92            Self::Duration(value) => format!("{}s", value.as_secs_f64()),
93            Self::Float(value) => value.to_string(),
94            Self::Int(value) => value.to_string(),
95            Self::String(value) => value.to_string(),
96            Self::Text(value) => value.to_string(),
97            Self::Timestamp(value) => value.to_rfc3339(),
98        }
99    }
100}
101
102impl Display for Value<'_> {
103    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
104        match self {
105            Self::Bool(value) => write!(f, "{}", value),
106            Self::Duration(value) => write!(f, "{}s", value.as_secs_f64()),
107            Self::Float(value) => write!(f, "{}", value),
108            Self::Int(value) => write!(f, "{}", value),
109            Self::String(value) => write!(f, "\"{}\"", value),
110            Self::Text(value) => write!(f, "{}", value),
111            Self::Timestamp(value) => write!(f, "\"{}\"", value.to_rfc3339()),
112        }
113    }
114}
115
116#[derive(Clone, Copy, Debug, Eq, PartialEq)]
117pub enum Comparator {
118    Eq,
119    Gt,
120    GtEq,
121    Has,
122    Lt,
123    LtEq,
124    Ne,
125}
126
127impl Comparator {
128    pub fn parse(pair: Pair<Rule>) -> Result<Self, ()> {
129        debug_assert_eq!(pair.as_rule(), Rule::comparator);
130
131        let inner_pair = pair.into_inner().next().unwrap();
132
133        Ok(match inner_pair.as_rule() {
134            Rule::eq => Self::Eq,
135            Rule::gt => Self::Gt,
136            Rule::gt_eq => Self::GtEq,
137            Rule::has => Self::Has,
138            Rule::lt => Self::Lt,
139            Rule::lt_eq => Self::LtEq,
140            Rule::ne => Self::Ne,
141            _ => return Err(()),
142        })
143    }
144}
145
146impl Display for Comparator {
147    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
148        write!(
149            f,
150            "{}",
151            match self {
152                Self::Eq => " = ",
153                Self::Gt => " > ",
154                Self::GtEq => " >= ",
155                Self::Has => ":",
156                Self::Lt => " < ",
157                Self::LtEq => " <= ",
158                Self::Ne => " != ",
159            }
160        )
161    }
162}
163
/// Logical connective joining two filter expressions.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BinOp {
    /// Logical conjunction, rendered as ` AND `.
    And,
    /// Logical disjunction, rendered as ` OR `.
    Or,
}

impl Display for BinOp {
    /// Writes the keyword surrounded by single spaces, so it can sit directly
    /// between two rendered operands.
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        let keyword = match *self {
            BinOp::And => " AND ",
            BinOp::Or => " OR ",
        };

        f.write_str(keyword)
    }
}
182
/// Prefix operator applied to a single expression.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum UnOp {
    /// Unary minus, rendered as `-` with no trailing space.
    Neg,
    /// Logical negation, rendered as `NOT ` with a trailing space.
    Not,
}

impl Display for UnOp {
    /// Writes the operator prefix so it can be followed directly by the
    /// rendered operand.
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        f.write_str(match *self {
            UnOp::Neg => "-",
            UnOp::Not => "NOT ",
        })
    }
}
201
202#[derive(Debug)]
203pub enum Expression<'a> {
204    Binary {
205        lhs: Box<Expression<'a>>,
206        op: BinOp,
207        rhs: Box<Expression<'a>>,
208    },
209
210    FCall {
211        name: &'a str,
212        args: Vec<Expression<'a>>,
213    },
214
215    Member {
216        value: Value<'a>,
217        path: Vec<Value<'a>>,
218    },
219
220    Restriction {
221        lhs: Box<Expression<'a>>,
222        op: Comparator,
223        rhs: Box<Expression<'a>>,
224    },
225
226    Sequence(Vec<Expression<'a>>),
227
228    Unary {
229        op: UnOp,
230        rhs: Box<Expression<'a>>,
231    },
232
233    Value(Value<'a>),
234}
235
236impl<'a> Expression<'a> {
237    pub fn parse(pair: Pair<'a, Rule>) -> Result<Self, ()> {
238        debug_assert_eq!(pair.as_rule(), Rule::expression);
239
240        let mut inner = pair.into_inner();
241
242        let mut expr = Expression::parse_impl(inner.next().unwrap())?;
243
244        while let Some(rhs_pair) = inner.next() {
245            let rhs = Expression::parse_impl(rhs_pair)?;
246
247            expr = Expression::Binary {
248                lhs: Box::new(expr),
249                op: BinOp::And,
250                rhs: Box::new(rhs),
251            }
252        }
253
254        Ok(expr)
255    }
256
257    fn parse_impl(pair: Pair<'a, Rule>) -> Result<Self, ()> {
258        match pair.as_rule() {
259            Rule::expression => Self::parse(pair),
260
261            Rule::factor => {
262                let mut inner = pair.into_inner();
263
264                let mut expr = Expression::parse_impl(inner.next().unwrap())?;
265
266                while let Some(rhs_pair) = inner.next() {
267                    let rhs = Expression::parse_impl(rhs_pair)?;
268
269                    expr = Expression::Binary {
270                        lhs: Box::new(expr),
271                        op: BinOp::Or,
272                        rhs: Box::new(rhs),
273                    };
274                }
275
276                Ok(expr)
277            }
278
279            Rule::sequence => {
280                let mut inner = pair.into_inner();
281
282                match inner.clone().count() {
283                    0 => Err(()),
284
285                    1 => Expression::parse_impl(inner.next().unwrap()),
286
287                    _ => {
288                        let exprs = inner
289                            .map(|pair| Expression::parse_impl(pair))
290                            .try_collect()?;
291
292                        Ok(Expression::Sequence(exprs))
293                    }
294                }
295            }
296
297            Rule::term => {
298                let mut inner = pair.into_inner();
299
300                match inner.next().unwrap() {
301                    inner_pair if inner_pair.as_rule() == Rule::negation => {
302                        let op = match inner_pair.as_str() {
303                            "-" => UnOp::Neg,
304                            "NOT" => UnOp::Not,
305                            _ => unreachable!(),
306                        };
307
308                        let term_pair = inner.next().unwrap();
309
310                        let rhs = Expression::parse_impl(term_pair)?;
311
312                        Ok(Expression::Unary {
313                            op,
314                            rhs: Box::new(rhs),
315                        })
316                    }
317
318                    term_pair => Ok(Expression::parse_impl(term_pair)?),
319                }
320            }
321
322            Rule::restriction => {
323                let mut inner = pair.into_inner();
324
325                let lhs = Expression::parse_impl(inner.next().unwrap())?;
326
327                match inner.next() {
328                    None => Ok(lhs),
329
330                    Some(comparator_pair) => {
331                        let op = Comparator::parse(comparator_pair)?;
332
333                        let rhs = Expression::parse_impl(inner.next().unwrap())?;
334
335                        Ok(Expression::Restriction {
336                            lhs: Box::new(lhs),
337                            op,
338                            rhs: Box::new(rhs),
339                        })
340                    }
341                }
342            }
343
344            Rule::member => {
345                let mut inner = pair.into_inner();
346
347                let value = Value::parse(inner.next().unwrap())?;
348
349                match inner.peek() {
350                    None => Ok(Expression::Value(value)),
351
352                    Some(_) => {
353                        let mut path = Vec::new();
354
355                        for field_pair in inner {
356                            // Should always be equal when coming from Pest
357                            debug_assert_eq!(field_pair.as_rule(), Rule::field);
358
359                            let inner_pair = field_pair.into_inner().next().unwrap();
360
361                            path.push(parse_field(inner_pair)?);
362                        }
363
364                        Ok(Expression::Member { value, path })
365                    }
366                }
367            }
368
369            Rule::function => {
370                let mut inner = pair.into_inner();
371
372                let name = {
373                    let name_pair = inner.next().unwrap();
374
375                    debug_assert_eq!(name_pair.as_rule(), Rule::function_name);
376
377                    name_pair.as_str()
378                };
379
380                let args = inner
381                    .next()
382                    .map(|arg_list_pair| {
383                        debug_assert_eq!(arg_list_pair.as_rule(), Rule::arg_list);
384
385                        arg_list_pair
386                            .into_inner()
387                            .map(|pair| Expression::parse_impl(pair))
388                            .try_collect()
389                    })
390                    .unwrap_or_else(|| Ok(Vec::new()))?;
391
392                Ok(Expression::FCall { name, args })
393            }
394
395            Rule::value => {
396                let inner_pair = pair.into_inner().next().unwrap();
397
398                Ok(Expression::Value(Value::parse(inner_pair)?))
399            }
400
401            _ => unimplemented!("{:?}", pair),
402        }
403    }
404}
405
406impl Display for Expression<'_> {
407    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
408        match self {
409            Self::Binary { lhs, op, rhs } => write!(f, "({}{}{})", lhs, op, rhs),
410
411            Self::FCall { name, args } => {
412                let args_str = args
413                    .iter()
414                    .map(ToString::to_string)
415                    .collect::<Vec<String>>()
416                    .join(", ");
417
418                write!(f, "({}({}))", name, args_str)
419            }
420
421            Self::Member { value, path } => {
422                let path_str = path
423                    .iter()
424                    .map(ToString::to_string)
425                    .collect::<Vec<String>>()
426                    .join(".");
427
428                write!(f, "({}.{})", value, path_str)
429            }
430
431            Self::Restriction { lhs, op, rhs } => write!(f, "({}{}{})", lhs, op, rhs),
432
433            Self::Sequence(expressions) => {
434                let joined_str = expressions
435                    .iter()
436                    .map(|expr| format!("{}", expr))
437                    .collect::<Vec<String>>()
438                    .join(" ");
439
440                write!(f, "({})", joined_str)
441            }
442
443            Self::Unary { op, rhs } => write!(f, "({}{})", op, rhs),
444
445            Self::Value(value) => value.fmt(f),
446        }
447    }
448}
449
450fn parse_field(pair: Pair<Rule>) -> Result<Value, ()> {
451    match pair.as_rule() {
452        Rule::value => Value::parse(pair),
453        Rule::keyword => Ok(Value::Text(pair.as_str())),
454        _ => Err(()),
455    }
456}