1extern crate pest;
2
3use crate::ast::{
4 BinaryOperator, Expression, Lhs, LhsTransformations, LogicalExpression, Predicate, Value,
5};
6use cidr::{IpCidr, Ipv4Cidr, Ipv6Cidr};
7use pest::error::Error as ParseError;
8use pest::error::ErrorVariant;
9use pest::iterators::Pair;
10use pest::pratt_parser::Assoc as AssocNew;
11use pest::pratt_parser::{Op, PrattParser};
12use pest::Parser;
13use regex::Regex;
14use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
15
/// Result type shared by every parsing routine in this file: the success value
/// or a pest parse error keyed on this grammar's `Rule` enum.
type ParseResult<T> = Result<T, ParseError<Rule>>;
/// Conversion of an arbitrary `Result` into a [`ParseResult`] whose error is
/// attached to the source span of a given grammar pair.
trait IntoParseResult<T> {
    /// Converts `self` into a [`ParseResult`], using `pair` to locate the
    /// error in the parsed input.
    // The lint is silenced rather than changing the error type of `ParseResult`.
    #[allow(clippy::result_large_err)]
    fn into_parse_result(self, pair: &Pair<Rule>) -> ParseResult<T>;
}
24
25impl<T, E> IntoParseResult<T> for Result<T, E>
26where
27 E: ToString,
28{
29 fn into_parse_result(self, pair: &Pair<Rule>) -> ParseResult<T> {
30 self.map_err(|e| {
31 let span = pair.as_span();
32
33 let err_var = ErrorVariant::CustomError {
34 message: e.to_string(),
35 };
36
37 ParseError::new_from_span(err_var, span)
38 })
39 }
40}
41
/// Parser generated by pest from `atc_grammar.pest`, bundled with the Pratt
/// parser used to combine terms joined by infix logical operators.
#[derive(Parser)]
#[grammar = "atc_grammar.pest"]
struct ATCParser {
    /// Operator table handed to `parse_expression` when resolving infix operators.
    pratt_parser: PrattParser<Rule>,
}
47
// Parses the text of `$pair` as the integer type `$int` in base `$base`.
// A failed conversion becomes a `ParseError` located at `$pair`'s span.
macro_rules! parse_num {
    ($pair:expr, $int:ident, $base:expr) => {
        $int::from_str_radix($pair.as_str(), $base).into_parse_result(&$pair)
    };
}
53
54impl ATCParser {
55 fn new() -> Self {
56 Self {
57 pratt_parser: PrattParser::new()
58 .op(Op::infix(Rule::and_op, AssocNew::Left))
59 .op(Op::infix(Rule::or_op, AssocNew::Left)),
60 }
61 }
62 #[allow(clippy::result_large_err)] fn parse_matcher(&mut self, source: &str) -> ParseResult<Expression> {
65 let pairs = ATCParser::parse(Rule::matcher, source)?;
66 let expr_pair = pairs.peek().unwrap().into_inner().peek().unwrap();
67 let rule = expr_pair.as_rule();
68 match rule {
69 Rule::expression => parse_expression(expr_pair, &self.pratt_parser),
70 _ => unreachable!(),
71 }
72 }
73}
74
75#[allow(clippy::result_large_err)] fn parse_ident(pair: Pair<Rule>) -> ParseResult<String> {
77 Ok(pair.as_str().into())
78}
79
80#[allow(clippy::result_large_err)] fn parse_lhs(pair: Pair<Rule>) -> ParseResult<Lhs> {
82 let pairs = pair.into_inner();
83 let pair = pairs.peek().unwrap();
84 let rule = pair.as_rule();
85 Ok(match rule {
86 Rule::transform_func => parse_transform_func(pair)?,
87 Rule::ident => {
88 let var = parse_ident(pair)?;
89 Lhs {
90 var_name: var,
91 transformations: Vec::new(),
92 }
93 }
94 _ => unreachable!(),
95 })
96}
97
98#[allow(clippy::result_large_err)] fn parse_rhs(pair: Pair<Rule>) -> ParseResult<Value> {
101 let pairs = pair.into_inner();
102 let pair = pairs.peek().unwrap();
103 let rule = pair.as_rule();
104 Ok(match rule {
105 Rule::str_literal => Value::String(parse_str_literal(pair)?),
106 Rule::rawstr_literal => Value::String(parse_rawstr_literal(pair)?),
107 Rule::ipv4_cidr_literal => Value::IpCidr(IpCidr::V4(parse_ipv4_cidr_literal(pair)?)),
108 Rule::ipv6_cidr_literal => Value::IpCidr(IpCidr::V6(parse_ipv6_cidr_literal(pair)?)),
109 Rule::ipv4_literal => Value::IpAddr(IpAddr::V4(parse_ipv4_literal(pair)?)),
110 Rule::ipv6_literal => Value::IpAddr(IpAddr::V6(parse_ipv6_literal(pair)?)),
111 Rule::int_literal => Value::Int(parse_int_literal(pair)?),
112 _ => unreachable!(),
113 })
114}
115
116#[allow(clippy::result_large_err)] fn parse_str_literal(pair: Pair<Rule>) -> ParseResult<String> {
119 let char_pairs = pair.into_inner();
120 let mut s = String::new();
121 for char_pair in char_pairs {
122 let rule = char_pair.as_rule();
123 match rule {
124 Rule::str_esc => s.push(parse_str_esc(char_pair)),
125 Rule::str_char => s.push(parse_str_char(char_pair)),
126 _ => unreachable!(),
127 }
128 }
129 Ok(s)
130}
131
132#[allow(clippy::result_large_err)] fn parse_rawstr_literal(pair: Pair<Rule>) -> ParseResult<String> {
136 let char_pairs = pair.into_inner();
137 let mut s = String::new();
138 for char_pair in char_pairs {
139 let rule = char_pair.as_rule();
140 match rule {
141 Rule::rawstr_char => s.push(parse_str_char(char_pair)),
142 _ => unreachable!(),
143 }
144 }
145 Ok(s)
146}
147
148fn parse_str_esc(pair: Pair<Rule>) -> char {
149 match pair.as_str() {
150 r#"\""# => '"',
151 r#"\\"# => '\\',
152 r#"\n"# => '\n',
153 r#"\r"# => '\r',
154 r#"\t"# => '\t',
155
156 _ => unreachable!(),
157 }
158}
159fn parse_str_char(pair: Pair<Rule>) -> char {
160 return pair.as_str().chars().next().unwrap();
161}
162
163#[allow(clippy::result_large_err)] fn parse_ipv4_cidr_literal(pair: Pair<Rule>) -> ParseResult<Ipv4Cidr> {
165 pair.as_str().parse().into_parse_result(&pair)
166}
167
168#[allow(clippy::result_large_err)] fn parse_ipv6_cidr_literal(pair: Pair<Rule>) -> ParseResult<Ipv6Cidr> {
170 pair.as_str().parse().into_parse_result(&pair)
171}
172
173#[allow(clippy::result_large_err)] fn parse_ipv4_literal(pair: Pair<Rule>) -> ParseResult<Ipv4Addr> {
175 pair.as_str().parse().into_parse_result(&pair)
176}
177
178#[allow(clippy::result_large_err)] fn parse_ipv6_literal(pair: Pair<Rule>) -> ParseResult<Ipv6Addr> {
180 pair.as_str().parse().into_parse_result(&pair)
181}
182
183#[allow(clippy::result_large_err)] fn parse_int_literal(pair: Pair<Rule>) -> ParseResult<i64> {
185 let is_neg = pair.as_str().starts_with('-');
186 let pairs = pair.into_inner();
187 let pair = pairs.peek().unwrap(); let rule = pair.as_rule();
189 let radix = match rule {
190 Rule::hex_digits => 16,
191 Rule::oct_digits => 8,
192 Rule::dec_digits => 10,
193 _ => unreachable!(),
194 };
195
196 let mut num = parse_num!(pair, i64, radix)?;
197
198 if is_neg {
199 num = -num;
200 }
201
202 Ok(num)
203}
204
205#[allow(clippy::result_large_err)] fn parse_predicate(pair: Pair<Rule>) -> ParseResult<Predicate> {
208 let mut pairs = pair.into_inner();
209 let lhs = parse_lhs(pairs.next().unwrap())?;
210 let op = parse_binary_operator(pairs.next().unwrap());
211 let rhs_pair = pairs.next().unwrap();
212 let rhs = parse_rhs(rhs_pair.clone())?;
213 Ok(Predicate {
214 lhs,
215 rhs: if op == BinaryOperator::Regex {
216 if let Value::String(s) = rhs {
217 let r = Regex::new(&s).map_err(|e| {
218 ParseError::new_from_span(
219 ErrorVariant::CustomError {
220 message: e.to_string(),
221 },
222 rhs_pair.as_span(),
223 )
224 })?;
225
226 Value::Regex(r)
227 } else {
228 return Err(ParseError::new_from_span(
229 ErrorVariant::CustomError {
230 message: "regex operator can only be used with String operands".to_string(),
231 },
232 rhs_pair.as_span(),
233 ));
234 }
235 } else {
236 rhs
237 },
238 op,
239 })
240}
241#[allow(clippy::result_large_err)] fn parse_transform_func(pair: Pair<Rule>) -> ParseResult<Lhs> {
244 let span = pair.as_span();
245 let pairs = pair.into_inner();
246 let mut pairs = pairs.peekable();
247 let func_name = pairs.next().unwrap().as_str().to_string();
248 let mut lhs = parse_lhs(pairs.next().unwrap())?;
249 lhs.transformations.push(match func_name.as_str() {
250 "lower" => LhsTransformations::Lower,
251 "any" => LhsTransformations::Any,
252 unknown => {
253 return Err(ParseError::new_from_span(
254 ErrorVariant::CustomError {
255 message: format!("unknown transformation function: {}", unknown),
256 },
257 span,
258 ));
259 }
260 });
261
262 Ok(lhs)
263}
264
265fn parse_binary_operator(pair: Pair<Rule>) -> BinaryOperator {
268 let rule = pair.as_str();
269 use BinaryOperator as BinaryOp;
270 match rule {
271 "==" => BinaryOp::Equals,
272 "!=" => BinaryOp::NotEquals,
273 "~" => BinaryOp::Regex,
274 "^=" => BinaryOp::Prefix,
275 "=^" => BinaryOp::Postfix,
276 ">=" => BinaryOp::GreaterOrEqual,
277 ">" => BinaryOp::Greater,
278 "<=" => BinaryOp::LessOrEqual,
279 "<" => BinaryOp::Less,
280 "in" => BinaryOp::In,
281 "not in" => BinaryOp::NotIn,
282 "contains" => BinaryOp::Contains,
283 _ => unreachable!(),
284 }
285}
286
287#[allow(clippy::result_large_err)] fn parse_parenthesised_expression(
290 pair: Pair<Rule>,
291 pratt: &PrattParser<Rule>,
292) -> ParseResult<Expression> {
293 let mut pairs = pair.into_inner();
294 let pair = pairs.next().unwrap();
295 let rule = pair.as_rule();
296 match rule {
297 Rule::expression => parse_expression(pair, pratt),
298 Rule::not_op => Ok(Expression::Logical(Box::new(LogicalExpression::Not(
299 parse_expression(pairs.next().unwrap(), pratt)?,
300 )))),
301 _ => unreachable!(),
302 }
303}
304
305#[allow(clippy::result_large_err)] fn parse_term(pair: Pair<Rule>, pratt: &PrattParser<Rule>) -> ParseResult<Expression> {
308 let pairs = pair.into_inner();
309 let inner_rule = pairs.peek().unwrap();
310 let rule = inner_rule.as_rule();
311 match rule {
312 Rule::predicate => Ok(Expression::Predicate(parse_predicate(inner_rule)?)),
313 Rule::parenthesised_expression => parse_parenthesised_expression(inner_rule, pratt),
314 _ => unreachable!(),
315 }
316}
317
318#[allow(clippy::result_large_err)] fn parse_expression(pair: Pair<Rule>, pratt: &PrattParser<Rule>) -> ParseResult<Expression> {
321 let pairs = pair.into_inner();
322 pratt
323 .map_primary(|operand| match operand.as_rule() {
324 Rule::term => parse_term(operand, pratt),
325 _ => unreachable!(),
326 })
327 .map_infix(|lhs, op, rhs| {
328 Ok(match op.as_rule() {
329 Rule::and_op => Expression::Logical(Box::new(LogicalExpression::And(lhs?, rhs?))),
330 Rule::or_op => Expression::Logical(Box::new(LogicalExpression::Or(lhs?, rhs?))),
331 _ => unreachable!(),
332 })
333 })
334 .parse(pairs)
335}
336
337#[allow(clippy::result_large_err)] pub fn parse(source: &str) -> ParseResult<Expression> {
339 ATCParser::new().parse_matcher(source)
340}
341
#[cfg(test)]
mod tests {
    use super::*;

    /// A bare `!` outside parentheses is rejected; each case pins the rendered
    /// error, including the position where a term was expected.
    #[test]
    fn test_bad_syntax() {
        let cases = [
            (
                "! a == 1",
                " --> 1:1\n |\n1 | ! a == 1\n | ^---\n |\n = expected term",
            ),
            (
                "a == 1 || ! b == 2",
                " --> 1:11\n |\n1 | a == 1 || ! b == 2\n | ^---\n |\n = expected term",
            ),
            (
                "(a == 1 || b == 2) && ! c == 3",
                " --> 1:23\n |\n1 | (a == 1 || b == 2) && ! c == 3\n | ^---\n |\n = expected term",
            ),
        ];

        for &(source, expected) in cases.iter() {
            assert_eq!(parse(source).unwrap_err().to_string(), expected);
        }
    }
}