atc_router/ffi/
expression.rs

1use crate::ast::{BinaryOperator, Expression, LogicalExpression, Predicate};
2use crate::ffi::ERR_BUF_MAX_LEN;
3use crate::schema::Schema;
4use bitflags::bitflags;
5use std::cmp::min;
6use std::ffi;
7use std::os::raw::c_char;
8use std::slice::from_raw_parts_mut;
9
10use std::iter::Iterator;
11
/// Iterator over every [`Predicate`] leaf reachable from an [`Expression`].
///
/// The traversal is iterative: pending sub-expressions are kept on an explicit
/// stack instead of recursing.
struct PredicateIterator<'a> {
    /// Sub-expressions that still have to be visited; the next one to be
    /// processed is popped from the end.
    stack: Vec<&'a Expression>,
}
15
16impl<'a> PredicateIterator<'a> {
17    fn new(expr: &'a Expression) -> Self {
18        Self { stack: vec![expr] }
19    }
20}
21
22impl<'a> Iterator for PredicateIterator<'a> {
23    type Item = &'a Predicate;
24
25    fn next(&mut self) -> Option<Self::Item> {
26        while let Some(expr) = self.stack.pop() {
27            match expr {
28                Expression::Logical(l) => match l.as_ref() {
29                    LogicalExpression::And(l, r) | LogicalExpression::Or(l, r) => {
30                        self.stack.push(l);
31                        self.stack.push(r);
32                    }
33                    LogicalExpression::Not(r) => {
34                        self.stack.push(r);
35                    }
36                },
37                Expression::Predicate(p) => return Some(p),
38            }
39        }
40        None
41    }
42}
43
44impl Expression {
45    fn iter_predicates(&self) -> PredicateIterator {
46        PredicateIterator::new(self)
47    }
48}
49
bitflags! {
    /// Set of binary operators, one bit per operator, stored in a `u64`.
    ///
    /// `expression_validate` reports the operators used by an expression by
    /// writing the raw bits of this type through its `operators` pointer.
    #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
    #[repr(C)]
    pub struct BinaryOperatorFlags: u64 /* One bit per operator: at most 64 BinaryOperators can be represented */ {
        const EQUALS = 1 << 0;
        const NOT_EQUALS = 1 << 1;
        const REGEX = 1 << 2;
        const PREFIX = 1 << 3;
        const POSTFIX = 1 << 4;
        const GREATER = 1 << 5;
        const GREATER_OR_EQUAL = 1 << 6;
        const LESS = 1 << 7;
        const LESS_OR_EQUAL = 1 << 8;
        const IN = 1 << 9;
        const NOT_IN = 1 << 10;
        const CONTAINS = 1 << 11;

        /// Complement of every flag defined above, i.e. all bits that are not
        /// assigned to an operator. Keep this list in sync when adding a flag.
        const UNUSED = !(Self::EQUALS.bits()
            | Self::NOT_EQUALS.bits()
            | Self::REGEX.bits()
            | Self::PREFIX.bits()
            | Self::POSTFIX.bits()
            | Self::GREATER.bits()
            | Self::GREATER_OR_EQUAL.bits()
            | Self::LESS.bits()
            | Self::LESS_OR_EQUAL.bits()
            | Self::IN.bits()
            | Self::NOT_IN.bits()
            | Self::CONTAINS.bits());
    }
}
81
82impl From<&BinaryOperator> for BinaryOperatorFlags {
83    fn from(op: &BinaryOperator) -> Self {
84        match op {
85            BinaryOperator::Equals => Self::EQUALS,
86            BinaryOperator::NotEquals => Self::NOT_EQUALS,
87            BinaryOperator::Regex => Self::REGEX,
88            BinaryOperator::Prefix => Self::PREFIX,
89            BinaryOperator::Postfix => Self::POSTFIX,
90            BinaryOperator::Greater => Self::GREATER,
91            BinaryOperator::GreaterOrEqual => Self::GREATER_OR_EQUAL,
92            BinaryOperator::Less => Self::LESS,
93            BinaryOperator::LessOrEqual => Self::LESS_OR_EQUAL,
94            BinaryOperator::In => Self::IN,
95            BinaryOperator::NotIn => Self::NOT_IN,
96            BinaryOperator::Contains => Self::CONTAINS,
97        }
98    }
99}
100
/// Return code of `expression_validate`: the expression parsed and validated
/// successfully.
pub const ATC_ROUTER_EXPRESSION_VALIDATE_OK: i64 = 0;
/// Return code of `expression_validate`: parsing or schema validation failed;
/// an error message is written to the caller's error buffer.
pub const ATC_ROUTER_EXPRESSION_VALIDATE_FAILED: i64 = 1;
/// Return code of `expression_validate`: the caller-provided fields buffer is
/// too small to hold all field names.
pub const ATC_ROUTER_EXPRESSION_VALIDATE_BUF_TOO_SMALL: i64 = 2;
104
105/// Validates an ATC expression against a schema and get its elements.
106///
107/// # Arguments
108///
109/// - `atc`: a C-style string representing the ATC expression.
110/// - `schema`: a valid pointer to a [`Schema`] object, as returned by [`schema_new`].
111/// - `fields_buf`: a buffer for storing the fields used in the expression.
112/// - `fields_buf_len`: a pointer to the length of `fields_buf`.
113/// - `fields_total`: a pointer for storing the total number of unique fields used in the expression.
114/// - `operators`: a pointer for storing the bitflags representing used operators.
115/// - `errbuf`: a buffer to store any error messages.
116/// - `errbuf_len`: a pointer to the length of the error message buffer.
117///
118/// # Returns
119///
120/// An integer indicating the validation result:
121/// - `ATC_ROUTER_EXPRESSION_VALIDATE_OK` (0): Validation succeeded.
122/// - `ATC_ROUTER_EXPRESSION_VALIDATE_FAILED` (1): Validation failed; `errbuf` and `errbuf_len` will be updated with an error message.
123/// - `ATC_ROUTER_EXPRESSION_VALIDATE_BUF_TOO_SMALL` (2): The provided `fields_buf` is too small.
124///
125/// If `fields_buf_len` indicates that `fields_buf` is sufficient, this function writes the used fields to `fields_buf`, each field terminated by `\0`.
126/// It stores the total number of fields in `fields_total`.
127///
128/// If `fields_buf_len` indicates that `fields_buf` is insufficient, it returns `ATC_ROUTER_EXPRESSION_VALIDATE_BUF_TOO_SMALL`.
129///
130/// It writes the used operators as bitflags to `operators`.
131/// Bitflags are defined by `BinaryOperatorFlags` and must exclude bits from `BinaryOperatorFlags::UNUSED`.
132///
133///
134/// # Safety
135///
136/// Violating any of the following constraints results in undefined behavior:
137///
138/// - `atc` must be a valid pointer to a C-style string, properly aligned, and must not contain an internal `\0`.
139/// - `schema` must be a valid pointer returned by [`schema_new`].
140/// - `fields_buf`, must be valid for writing `fields_buf_len * size_of::<u8>()` bytes and properly aligned.
141/// - `fields_buf_len` must be a valid pointer to write `size_of::<usize>()` bytes and properly aligned.
142/// - `fields_total` must be a valid pointer to write `size_of::<usize>()` bytes and properly aligned.
143/// - `operators` must be a valid pointer to write `size_of::<u64>()` bytes and properly aligned.
144/// - `errbuf` must be valid for reading and writing `errbuf_len * size_of::<u8>()` bytes and properly aligned.
145/// - `errbuf_len` must be a valid pointer for reading and writing `size_of::<usize>()` bytes and properly aligned.
146
147#[no_mangle]
148pub unsafe extern "C" fn expression_validate(
149    atc: *const u8,
150    schema: &Schema,
151    fields_buf: *mut u8,
152    fields_buf_len: *mut usize,
153    fields_total: *mut usize,
154    operators: *mut u64,
155    errbuf: *mut u8,
156    errbuf_len: *mut usize,
157) -> i64 {
158    use std::collections::HashSet;
159
160    use crate::parser::parse;
161    use crate::semantics::Validate;
162
163    let atc = ffi::CStr::from_ptr(atc as *const c_char).to_str().unwrap();
164    let errbuf = from_raw_parts_mut(errbuf, ERR_BUF_MAX_LEN);
165
166    // Parse the expression
167    let result = parse(atc).map_err(|e| e.to_string());
168    if let Err(e) = result {
169        let errlen = min(e.len(), *errbuf_len);
170        errbuf[..errlen].copy_from_slice(&e.as_bytes()[..errlen]);
171        *errbuf_len = errlen;
172        return ATC_ROUTER_EXPRESSION_VALIDATE_FAILED;
173    }
174    // Unwrap is safe since we've already checked for error
175    let ast = result.unwrap();
176
177    // Validate expression with schema
178    if let Err(e) = ast.validate(schema).map_err(|e| e.to_string()) {
179        let errlen = min(e.len(), *errbuf_len);
180        errbuf[..errlen].copy_from_slice(&e.as_bytes()[..errlen]);
181        *errbuf_len = errlen;
182        return ATC_ROUTER_EXPRESSION_VALIDATE_FAILED;
183    }
184
185    // Iterate over predicates to get fields and operators
186    let mut ops = BinaryOperatorFlags::empty();
187    let mut existed_fields = HashSet::new();
188    let mut total_fields_length = 0;
189    let mut fields_buf_ptr = fields_buf;
190    *fields_total = 0;
191
192    for pred in ast.iter_predicates() {
193        ops |= BinaryOperatorFlags::from(&pred.op);
194
195        let field = pred.lhs.var_name.as_str();
196
197        if existed_fields.insert(field) {
198            // Fields is not existed yet.
199            // Unwrap is safe since `field` cannot contain '\0' as `atc` must not contain any internal `\0`.
200            let field = ffi::CString::new(field).unwrap();
201            let field_slice = field.as_bytes_with_nul();
202            let field_len = field_slice.len();
203
204            *fields_total += 1;
205            total_fields_length += field_len;
206
207            if *fields_buf_len < total_fields_length {
208                return ATC_ROUTER_EXPRESSION_VALIDATE_BUF_TOO_SMALL;
209            }
210
211            let fields_buf = from_raw_parts_mut(fields_buf_ptr, field_len);
212            fields_buf.copy_from_slice(field_slice);
213            fields_buf_ptr = fields_buf_ptr.add(field_len);
214        }
215    }
216
217    *operators = ops.bits();
218
219    ATC_ROUTER_EXPRESSION_VALIDATE_OK
220}
221
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::Type;

    /// Runs `expression_validate` on `atc` with a fields buffer of
    /// `fields_buf_size` bytes.
    ///
    /// On success returns the sorted field names, the reported fields buffer
    /// length and the operator bits; on failure returns the error code and the
    /// error message (empty for `ATC_ROUTER_EXPRESSION_VALIDATE_BUF_TOO_SMALL`).
    fn expr_validate_on(
        schema: &Schema,
        atc: &str,
        fields_buf_size: usize,
    ) -> Result<(Vec<String>, usize, u64), (i64, String)> {
        let atc = ffi::CString::new(atc).unwrap();
        let mut errbuf = vec![b'X'; ERR_BUF_MAX_LEN];
        let mut errbuf_len = ERR_BUF_MAX_LEN;

        let mut fields_buf = vec![0u8; fields_buf_size];
        let mut fields_buf_len = fields_buf.len();
        let mut fields_total = 0;
        let mut operators = 0u64;

        // SAFETY: every pointer refers to a live local of the expected size.
        // The expression pointer is taken from the NUL-inclusive slice so the
        // terminator read by the callee is within the pointed-to allocation.
        let result = unsafe {
            expression_validate(
                atc.as_bytes_with_nul().as_ptr(),
                schema,
                fields_buf.as_mut_ptr(),
                &mut fields_buf_len,
                &mut fields_total,
                &mut operators,
                errbuf.as_mut_ptr(),
                &mut errbuf_len,
            )
        };

        match result {
            ATC_ROUTER_EXPRESSION_VALIDATE_OK => {
                let mut fields = Vec::<String>::with_capacity(fields_total);
                let mut p = 0;
                for _ in 0..fields_total {
                    // SAFETY: on success the callee wrote `fields_total`
                    // NUL-terminated names back to back into `fields_buf`.
                    let field = unsafe { ffi::CStr::from_ptr(fields_buf[p..].as_ptr().cast()) };
                    let len = field.to_bytes().len() + 1;
                    fields.push(field.to_string_lossy().to_string());
                    p += len;
                }
                assert_eq!(fields_buf_len, p, "Fields buffer length mismatch");
                // Sort so that assertions do not depend on traversal order.
                fields.sort();
                Ok((fields, fields_buf_len, operators))
            }
            ATC_ROUTER_EXPRESSION_VALIDATE_FAILED => {
                let err = String::from_utf8(errbuf[..errbuf_len].to_vec()).unwrap();
                Err((result, err))
            }
            ATC_ROUTER_EXPRESSION_VALIDATE_BUF_TOO_SMALL => Err((result, String::new())),
            _ => panic!("Unknown error code"),
        }
    }

    #[test]
    fn test_expression_validate_success() {
        let atc = r##"net.protocol ~ "^https?$" && net.dst.port == 80 && (net.src.ip not in 10.0.0.0/16 || net.src.ip in 10.0.1.0/24) && http.path contains "hello""##;

        let mut schema = Schema::default();
        schema.add_field("net.protocol", Type::String);
        schema.add_field("net.dst.port", Type::Int);
        schema.add_field("net.src.ip", Type::IpAddr);
        schema.add_field("http.path", Type::String);

        // 47 bytes is exactly the four field names plus their NUL terminators.
        let result = expr_validate_on(&schema, atc, 47);

        assert!(result.is_ok(), "Validation failed");
        let (fields, fields_buf_len, ops) = result.unwrap(); // Unwrap is safe since we've already asserted it
        assert_eq!(
            ops,
            (BinaryOperatorFlags::EQUALS
                | BinaryOperatorFlags::REGEX
                | BinaryOperatorFlags::IN
                | BinaryOperatorFlags::NOT_IN
                | BinaryOperatorFlags::CONTAINS)
                .bits(),
            "Operators mismatch"
        );
        assert_eq!(
            fields,
            vec![
                "http.path".to_string(),
                "net.dst.port".to_string(),
                "net.protocol".to_string(),
                "net.src.ip".to_string()
            ],
            "Fields mismatch"
        );
        assert_eq!(fields_buf_len, 47, "Fields buffer length mismatch");
    }

    #[test]
    fn test_expression_validate_failed_parse() {
        // `10.0.1.0` lacks a prefix length, which the `in` operator rejects.
        let atc = r##"net.protocol ~ "^https?$" && net.dst.port == 80 && (net.src.ip not in 10.0.0.0/16 || net.src.ip in 10.0.1.0) && http.path contains "hello""##;

        let mut schema = Schema::default();
        schema.add_field("net.protocol", Type::String);
        schema.add_field("net.dst.port", Type::Int);
        schema.add_field("net.src.ip", Type::IpAddr);
        schema.add_field("http.path", Type::String);

        let result = expr_validate_on(&schema, atc, 1024);

        assert!(result.is_err(), "Validation unexcepted success");
        let (err_code, err_message) = result.unwrap_err(); // Unwrap is safe since we've already asserted it
        assert_eq!(
            err_code, ATC_ROUTER_EXPRESSION_VALIDATE_FAILED,
            "Error code mismatch"
        );
        assert_eq!(
            err_message,
            "In/NotIn operators only supports IP in CIDR".to_string(),
            "Error message mismatch"
        );
    }

    #[test]
    fn test_expression_validate_failed_validate() {
        let atc = r##"net.protocol ~ "^https?$" && net.dst.port == 80 && (net.src.ip not in 10.0.0.0/16 || net.src.ip in 10.0.1.0/24) && http.path contains "hello""##;

        // `http.path` is deliberately left out of the schema.
        let mut schema = Schema::default();
        schema.add_field("net.protocol", Type::String);
        schema.add_field("net.dst.port", Type::Int);
        schema.add_field("net.src.ip", Type::IpAddr);

        let result = expr_validate_on(&schema, atc, 1024);

        assert!(result.is_err(), "Validation unexcepted success");
        let (err_code, err_message) = result.unwrap_err(); // Unwrap is safe since we've already asserted it
        assert_eq!(
            err_code, ATC_ROUTER_EXPRESSION_VALIDATE_FAILED,
            "Error code mismatch"
        );
        assert_eq!(
            err_message,
            "Unknown LHS field".to_string(),
            "Error message mismatch"
        );
    }

    #[test]
    fn test_expression_validate_buf_too_small() {
        let atc = r##"net.protocol ~ "^https?$" && net.dst.port == 80 && (net.src.ip not in 10.0.0.0/16 || net.src.ip in 10.0.1.0/24) && http.path contains "hello""##;

        let mut schema = Schema::default();
        schema.add_field("net.protocol", Type::String);
        schema.add_field("net.dst.port", Type::Int);
        schema.add_field("net.src.ip", Type::IpAddr);
        schema.add_field("http.path", Type::String);

        // One byte short of the 47 bytes the four fields need.
        let result = expr_validate_on(&schema, atc, 46);

        assert!(result.is_err(), "Validation unexpectedly succeeded");
        let (err_code, _) = result.unwrap_err(); // Unwrap is safe since we've already asserted it
        assert_eq!(
            err_code, ATC_ROUTER_EXPRESSION_VALIDATE_BUF_TOO_SMALL,
            "Error code mismatch"
        );
    }
}