1use crate::ast::{BinaryOperator, Expression, LogicalExpression, Predicate};
2use crate::ffi::ERR_BUF_MAX_LEN;
3use crate::schema::Schema;
4use bitflags::bitflags;
5use std::cmp::min;
6use std::ffi;
7use std::os::raw::c_char;
8use std::slice::from_raw_parts_mut;
9
10use std::iter::Iterator;
11
/// Stack-based depth-first iterator over the [`Predicate`] leaves of an
/// [`Expression`] tree.
struct PredicateIterator<'a> {
    /// Sub-expressions still to be visited; the next one is popped from the end.
    stack: Vec<&'a Expression>,
}
15
16impl<'a> PredicateIterator<'a> {
17 fn new(expr: &'a Expression) -> Self {
18 Self { stack: vec![expr] }
19 }
20}
21
22impl<'a> Iterator for PredicateIterator<'a> {
23 type Item = &'a Predicate;
24
25 fn next(&mut self) -> Option<Self::Item> {
26 while let Some(expr) = self.stack.pop() {
27 match expr {
28 Expression::Logical(l) => match l.as_ref() {
29 LogicalExpression::And(l, r) | LogicalExpression::Or(l, r) => {
30 self.stack.push(l);
31 self.stack.push(r);
32 }
33 LogicalExpression::Not(r) => {
34 self.stack.push(r);
35 }
36 },
37 Expression::Predicate(p) => return Some(p),
38 }
39 }
40 None
41 }
42}
43
44impl Expression {
45 fn iter_predicates(&self) -> PredicateIterator {
46 PredicateIterator::new(self)
47 }
48}
49
bitflags! {
    /// One bit per [`BinaryOperator`] variant (see the `From<&BinaryOperator>`
    /// impl), used to report which operators an expression uses.
    ///
    /// `expression_validate` hands the raw value to C callers via `bits()`, so
    /// the bit positions below are part of the FFI contract and must not change.
    #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
    #[repr(C)]
    pub struct BinaryOperatorFlags: u64 {
        /// [`BinaryOperator::Equals`].
        const EQUALS = 1 << 0;
        /// [`BinaryOperator::NotEquals`].
        const NOT_EQUALS = 1 << 1;
        /// [`BinaryOperator::Regex`].
        const REGEX = 1 << 2;
        /// [`BinaryOperator::Prefix`].
        const PREFIX = 1 << 3;
        /// [`BinaryOperator::Postfix`].
        const POSTFIX = 1 << 4;
        /// [`BinaryOperator::Greater`].
        const GREATER = 1 << 5;
        /// [`BinaryOperator::GreaterOrEqual`].
        const GREATER_OR_EQUAL = 1 << 6;
        /// [`BinaryOperator::Less`].
        const LESS = 1 << 7;
        /// [`BinaryOperator::LessOrEqual`].
        const LESS_OR_EQUAL = 1 << 8;
        /// [`BinaryOperator::In`].
        const IN = 1 << 9;
        /// [`BinaryOperator::NotIn`].
        const NOT_IN = 1 << 10;
        /// [`BinaryOperator::Contains`].
        const CONTAINS = 1 << 11;

        /// Every bit not assigned to an operator above (the complement of their
        /// union). When adding an operator, also add it to this expression.
        const UNUSED = !(Self::EQUALS.bits()
            | Self::NOT_EQUALS.bits()
            | Self::REGEX.bits()
            | Self::PREFIX.bits()
            | Self::POSTFIX.bits()
            | Self::GREATER.bits()
            | Self::GREATER_OR_EQUAL.bits()
            | Self::LESS.bits()
            | Self::LESS_OR_EQUAL.bits()
            | Self::IN.bits()
            | Self::NOT_IN.bits()
            | Self::CONTAINS.bits());
    }
}
81
82impl From<&BinaryOperator> for BinaryOperatorFlags {
83 fn from(op: &BinaryOperator) -> Self {
84 match op {
85 BinaryOperator::Equals => Self::EQUALS,
86 BinaryOperator::NotEquals => Self::NOT_EQUALS,
87 BinaryOperator::Regex => Self::REGEX,
88 BinaryOperator::Prefix => Self::PREFIX,
89 BinaryOperator::Postfix => Self::POSTFIX,
90 BinaryOperator::Greater => Self::GREATER,
91 BinaryOperator::GreaterOrEqual => Self::GREATER_OR_EQUAL,
92 BinaryOperator::Less => Self::LESS,
93 BinaryOperator::LessOrEqual => Self::LESS_OR_EQUAL,
94 BinaryOperator::In => Self::IN,
95 BinaryOperator::NotIn => Self::NOT_IN,
96 BinaryOperator::Contains => Self::CONTAINS,
97 }
98 }
99}
100
/// `expression_validate` result: the expression parsed and validated successfully.
pub const ATC_ROUTER_EXPRESSION_VALIDATE_OK: i64 = 0;
/// `expression_validate` result: parsing or schema validation failed; the error
/// message was written to `errbuf`.
pub const ATC_ROUTER_EXPRESSION_VALIDATE_FAILED: i64 = 1;
/// `expression_validate` result: `fields_buf` was too small to hold all field names.
pub const ATC_ROUTER_EXPRESSION_VALIDATE_BUF_TOO_SMALL: i64 = 2;
104
105#[no_mangle]
148pub unsafe extern "C" fn expression_validate(
149 atc: *const u8,
150 schema: &Schema,
151 fields_buf: *mut u8,
152 fields_buf_len: *mut usize,
153 fields_total: *mut usize,
154 operators: *mut u64,
155 errbuf: *mut u8,
156 errbuf_len: *mut usize,
157) -> i64 {
158 use std::collections::HashSet;
159
160 use crate::parser::parse;
161 use crate::semantics::Validate;
162
163 let atc = ffi::CStr::from_ptr(atc as *const c_char).to_str().unwrap();
164 let errbuf = from_raw_parts_mut(errbuf, ERR_BUF_MAX_LEN);
165
166 let result = parse(atc).map_err(|e| e.to_string());
168 if let Err(e) = result {
169 let errlen = min(e.len(), *errbuf_len);
170 errbuf[..errlen].copy_from_slice(&e.as_bytes()[..errlen]);
171 *errbuf_len = errlen;
172 return ATC_ROUTER_EXPRESSION_VALIDATE_FAILED;
173 }
174 let ast = result.unwrap();
176
177 if let Err(e) = ast.validate(schema).map_err(|e| e.to_string()) {
179 let errlen = min(e.len(), *errbuf_len);
180 errbuf[..errlen].copy_from_slice(&e.as_bytes()[..errlen]);
181 *errbuf_len = errlen;
182 return ATC_ROUTER_EXPRESSION_VALIDATE_FAILED;
183 }
184
185 let mut ops = BinaryOperatorFlags::empty();
187 let mut existed_fields = HashSet::new();
188 let mut total_fields_length = 0;
189 let mut fields_buf_ptr = fields_buf;
190 *fields_total = 0;
191
192 for pred in ast.iter_predicates() {
193 ops |= BinaryOperatorFlags::from(&pred.op);
194
195 let field = pred.lhs.var_name.as_str();
196
197 if existed_fields.insert(field) {
198 let field = ffi::CString::new(field).unwrap();
201 let field_slice = field.as_bytes_with_nul();
202 let field_len = field_slice.len();
203
204 *fields_total += 1;
205 total_fields_length += field_len;
206
207 if *fields_buf_len < total_fields_length {
208 return ATC_ROUTER_EXPRESSION_VALIDATE_BUF_TOO_SMALL;
209 }
210
211 let fields_buf = from_raw_parts_mut(fields_buf_ptr, field_len);
212 fields_buf.copy_from_slice(field_slice);
213 fields_buf_ptr = fields_buf_ptr.add(field_len);
214 }
215 }
216
217 *operators = ops.bits();
218
219 ATC_ROUTER_EXPRESSION_VALIDATE_OK
220}
221
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::Type;

    /// Calls [`expression_validate`] through its raw-pointer interface and
    /// decodes the outputs.
    ///
    /// On `ATC_ROUTER_EXPRESSION_VALIDATE_OK` it returns three things:
    /// - the decoded field names, sorted so assertions do not depend on
    ///   emission order;
    /// - the fields-buffer length reported back by the callee;
    /// - the operator bitflags.
    ///
    /// Otherwise it returns the status code and the error message. The message
    /// is empty for `ATC_ROUTER_EXPRESSION_VALIDATE_BUF_TOO_SMALL`.
    fn expr_validate_on(
        schema: &Schema,
        atc: &str,
        fields_buf_size: usize,
    ) -> Result<(Vec<String>, usize, u64), (i64, String)> {
        let atc = ffi::CString::new(atc).unwrap();
        let mut errbuf = vec![b'X'; ERR_BUF_MAX_LEN];
        let mut errbuf_len = ERR_BUF_MAX_LEN;

        let mut fields_buf = vec![0u8; fields_buf_size];
        let mut fields_buf_len = fields_buf.len();
        let mut fields_total = 0;
        let mut operators = 0u64;

        let result = unsafe {
            expression_validate(
                atc.as_bytes().as_ptr(),
                schema,
                fields_buf.as_mut_ptr(),
                &mut fields_buf_len,
                &mut fields_total,
                &mut operators,
                errbuf.as_mut_ptr(),
                &mut errbuf_len,
            )
        };

        match result {
            ATC_ROUTER_EXPRESSION_VALIDATE_OK => {
                // Fields are packed back to back as NUL-terminated strings.
                let mut fields = Vec::<String>::with_capacity(fields_total);
                let mut p = 0;
                for _ in 0..fields_total {
                    let field = unsafe { ffi::CStr::from_ptr(fields_buf[p..].as_ptr().cast()) };
                    fields.push(field.to_string_lossy().to_string());
                    p += field.to_bytes_with_nul().len();
                }
                assert_eq!(fields_buf_len, p, "Fields buffer length mismatch");
                fields.sort();
                Ok((fields, fields_buf_len, operators))
            }
            ATC_ROUTER_EXPRESSION_VALIDATE_FAILED => {
                let err = String::from_utf8(errbuf[..errbuf_len].to_vec()).unwrap();
                Err((result, err))
            }
            ATC_ROUTER_EXPRESSION_VALIDATE_BUF_TOO_SMALL => Err((result, String::new())),
            _ => panic!("Unknown error code"),
        }
    }

    #[test]
    fn test_expression_validate_success() {
        let atc = r##"net.protocol ~ "^https?$" && net.dst.port == 80 && (net.src.ip not in 10.0.0.0/16 || net.src.ip in 10.0.1.0/24) && http.path contains "hello""##;

        let mut schema = Schema::default();
        schema.add_field("net.protocol", Type::String);
        schema.add_field("net.dst.port", Type::Int);
        schema.add_field("net.src.ip", Type::IpAddr);
        schema.add_field("http.path", Type::String);

        let (fields, fields_buf_len, ops) =
            expr_validate_on(&schema, atc, 47).expect("Validation failed");

        assert_eq!(
            ops,
            (BinaryOperatorFlags::EQUALS
                | BinaryOperatorFlags::REGEX
                | BinaryOperatorFlags::IN
                | BinaryOperatorFlags::NOT_IN
                | BinaryOperatorFlags::CONTAINS)
                .bits(),
            "Operators mismatch"
        );
        assert_eq!(
            fields,
            ["http.path", "net.dst.port", "net.protocol", "net.src.ip"],
            "Fields mismatch"
        );
        assert_eq!(fields_buf_len, 47, "Fields buffer length mismatch");
    }

    #[test]
    fn test_expression_validate_failed_parse() {
        let atc = r##"net.protocol ~ "^https?$" && net.dst.port == 80 && (net.src.ip not in 10.0.0.0/16 || net.src.ip in 10.0.1.0) && http.path contains "hello""##;

        let mut schema = Schema::default();
        schema.add_field("net.protocol", Type::String);
        schema.add_field("net.dst.port", Type::Int);
        schema.add_field("net.src.ip", Type::IpAddr);
        schema.add_field("http.path", Type::String);

        let (err_code, err_message) = expr_validate_on(&schema, atc, 1024)
            .expect_err("Validation unexpectedly succeeded");

        assert_eq!(
            err_code, ATC_ROUTER_EXPRESSION_VALIDATE_FAILED,
            "Error code mismatch"
        );
        assert_eq!(
            err_message, "In/NotIn operators only supports IP in CIDR",
            "Error message mismatch"
        );
    }

    #[test]
    fn test_expression_validate_failed_validate() {
        let atc = r##"net.protocol ~ "^https?$" && net.dst.port == 80 && (net.src.ip not in 10.0.0.0/16 || net.src.ip in 10.0.1.0/24) && http.path contains "hello""##;

        let mut schema = Schema::default();
        schema.add_field("net.protocol", Type::String);
        schema.add_field("net.dst.port", Type::Int);
        schema.add_field("net.src.ip", Type::IpAddr);

        let (err_code, err_message) = expr_validate_on(&schema, atc, 1024)
            .expect_err("Validation unexpectedly succeeded");

        assert_eq!(
            err_code, ATC_ROUTER_EXPRESSION_VALIDATE_FAILED,
            "Error code mismatch"
        );
        assert_eq!(err_message, "Unknown LHS field", "Error message mismatch");
    }

    #[test]
    fn test_expression_validate_buf_too_small() {
        let atc = r##"net.protocol ~ "^https?$" && net.dst.port == 80 && (net.src.ip not in 10.0.0.0/16 || net.src.ip in 10.0.1.0/24) && http.path contains "hello""##;

        let mut schema = Schema::default();
        schema.add_field("net.protocol", Type::String);
        schema.add_field("net.dst.port", Type::Int);
        schema.add_field("net.src.ip", Type::IpAddr);
        schema.add_field("http.path", Type::String);

        // 47 bytes are needed for the four NUL-terminated field names.
        let (err_code, _) = expr_validate_on(&schema, atc, 46)
            .expect_err("Validation unexpectedly succeeded");

        assert_eq!(
            err_code, ATC_ROUTER_EXPRESSION_VALIDATE_BUF_TOO_SMALL,
            "Error code mismatch"
        );
    }
}