// biscuit_auth/datalog/expression.rs

1use crate::error;
2
3use super::Term;
4use super::{SymbolTable, TemporarySymbolTable};
5use regex::Regex;
6use std::collections::HashMap;
7
8#[derive(Debug, Clone, PartialEq, Hash, Eq)]
9pub struct Expression {
10    pub ops: Vec<Op>,
11}
12
13#[derive(Debug, Clone, PartialEq, Hash, Eq)]
14pub enum Op {
15    Value(Term),
16    Unary(Unary),
17    Binary(Binary),
18}
19
/// Unary operation code: an operator taking a single operand.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum Unary {
    /// Boolean negation, printed as `!x`.
    Negate,
    /// Grouping parentheses: evaluates to its operand unchanged, printed as `(x)`.
    Parens,
    /// Length of a string, byte array or set, printed as `x.length()`.
    Length,
}
27
28impl Unary {
29    fn evaluate(
30        &self,
31        value: Term,
32        symbols: &TemporarySymbolTable,
33    ) -> Result<Term, error::Expression> {
34        match (self, value) {
35            (Unary::Negate, Term::Bool(b)) => Ok(Term::Bool(!b)),
36            (Unary::Parens, i) => Ok(i),
37            (Unary::Length, Term::Str(i)) => symbols
38                .get_symbol(i)
39                .map(|s| Term::Integer(s.len() as i64))
40                .ok_or(error::Expression::UnknownSymbol(i)),
41            (Unary::Length, Term::Bytes(s)) => Ok(Term::Integer(s.len() as i64)),
42            (Unary::Length, Term::Set(s)) => Ok(Term::Integer(s.len() as i64)),
43            _ => {
44                //println!("unexpected value type on the stack");
45                Err(error::Expression::InvalidType)
46            }
47        }
48    }
49
50    pub fn print(&self, value: String, _symbols: &SymbolTable) -> String {
51        match self {
52            Unary::Negate => format!("!{}", value),
53            Unary::Parens => format!("({})", value),
54            Unary::Length => format!("{}.length()", value),
55        }
56    }
57}
58
/// Binary operation code: an operator taking a left and a right operand.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum Binary {
    /// `<` on integers and dates.
    LessThan,
    /// `>` on integers and dates.
    GreaterThan,
    /// `<=` on integers and dates.
    LessOrEqual,
    /// `>=` on integers and dates.
    GreaterOrEqual,
    /// `==` on integers, strings, dates, byte arrays, sets and booleans.
    Equal,
    /// `.contains()`: substring test on strings; on a set, a superset test
    /// against another set or a membership test for a single value.
    Contains,
    /// `.starts_with()` on strings.
    Prefix,
    /// `.ends_with()` on strings.
    Suffix,
    /// `.matches()`: tests a string against a regular expression.
    Regex,
    /// `+`: integer addition or string concatenation.
    Add,
    /// `-` on integers.
    Sub,
    /// `*` on integers.
    Mul,
    /// `/` on integers.
    Div,
    /// `&&` on booleans.
    And,
    /// `||` on booleans.
    Or,
    /// `.intersection()` of two sets.
    Intersection,
    /// `.union()` of two sets.
    Union,
    /// `&` on integers.
    BitwiseAnd,
    /// `|` on integers.
    BitwiseOr,
    /// `^` on integers.
    BitwiseXor,
    /// `!=` on the same operand types as `Equal`.
    NotEqual,
}
84
85impl Binary {
86    fn evaluate(
87        &self,
88        left: Term,
89        right: Term,
90        symbols: &mut TemporarySymbolTable,
91    ) -> Result<Term, error::Expression> {
92        match (self, left, right) {
93            // integer
94            (Binary::LessThan, Term::Integer(i), Term::Integer(j)) => Ok(Term::Bool(i < j)),
95            (Binary::GreaterThan, Term::Integer(i), Term::Integer(j)) => Ok(Term::Bool(i > j)),
96            (Binary::LessOrEqual, Term::Integer(i), Term::Integer(j)) => Ok(Term::Bool(i <= j)),
97            (Binary::GreaterOrEqual, Term::Integer(i), Term::Integer(j)) => Ok(Term::Bool(i >= j)),
98            (Binary::Equal, Term::Integer(i), Term::Integer(j)) => Ok(Term::Bool(i == j)),
99            (Binary::NotEqual, Term::Integer(i), Term::Integer(j)) => Ok(Term::Bool(i != j)),
100            (Binary::Add, Term::Integer(i), Term::Integer(j)) => i
101                .checked_add(j)
102                .map(Term::Integer)
103                .ok_or(error::Expression::Overflow),
104            (Binary::Sub, Term::Integer(i), Term::Integer(j)) => i
105                .checked_sub(j)
106                .map(Term::Integer)
107                .ok_or(error::Expression::Overflow),
108            (Binary::Mul, Term::Integer(i), Term::Integer(j)) => i
109                .checked_mul(j)
110                .map(Term::Integer)
111                .ok_or(error::Expression::Overflow),
112            (Binary::Div, Term::Integer(i), Term::Integer(j)) => i
113                .checked_div(j)
114                .map(Term::Integer)
115                .ok_or(error::Expression::DivideByZero),
116            (Binary::BitwiseAnd, Term::Integer(i), Term::Integer(j)) => Ok(Term::Integer(i & j)),
117            (Binary::BitwiseOr, Term::Integer(i), Term::Integer(j)) => Ok(Term::Integer(i | j)),
118            (Binary::BitwiseXor, Term::Integer(i), Term::Integer(j)) => Ok(Term::Integer(i ^ j)),
119
120            // string
121            (Binary::Prefix, Term::Str(s), Term::Str(pref)) => {
122                match (symbols.get_symbol(s), symbols.get_symbol(pref)) {
123                    (Some(s), Some(pref)) => Ok(Term::Bool(s.starts_with(pref))),
124                    (Some(_), None) => Err(error::Expression::UnknownSymbol(pref)),
125                    _ => Err(error::Expression::UnknownSymbol(s)),
126                }
127            }
128            (Binary::Suffix, Term::Str(s), Term::Str(suff)) => {
129                match (symbols.get_symbol(s), symbols.get_symbol(suff)) {
130                    (Some(s), Some(suff)) => Ok(Term::Bool(s.ends_with(suff))),
131                    (Some(_), None) => Err(error::Expression::UnknownSymbol(suff)),
132                    _ => Err(error::Expression::UnknownSymbol(s)),
133                }
134            }
135            (Binary::Regex, Term::Str(s), Term::Str(r)) => {
136                match (symbols.get_symbol(s), symbols.get_symbol(r)) {
137                    (Some(s), Some(r)) => Ok(Term::Bool(
138                        Regex::new(r).map(|re| re.is_match(s)).unwrap_or(false),
139                    )),
140                    (Some(_), None) => Err(error::Expression::UnknownSymbol(r)),
141                    _ => Err(error::Expression::UnknownSymbol(s)),
142                }
143            }
144            (Binary::Contains, Term::Str(s), Term::Str(pattern)) => {
145                match (symbols.get_symbol(s), symbols.get_symbol(pattern)) {
146                    (Some(s), Some(pattern)) => Ok(Term::Bool(s.contains(pattern))),
147                    (Some(_), None) => Err(error::Expression::UnknownSymbol(pattern)),
148                    _ => Err(error::Expression::UnknownSymbol(s)),
149                }
150            }
151            (Binary::Add, Term::Str(s1), Term::Str(s2)) => {
152                match (symbols.get_symbol(s1), symbols.get_symbol(s2)) {
153                    (Some(s1), Some(s2)) => {
154                        let s = format!("{}{}", s1, s2);
155                        let sym = symbols.insert(&s);
156                        Ok(Term::Str(sym))
157                    }
158                    (Some(_), None) => Err(error::Expression::UnknownSymbol(s2)),
159                    _ => Err(error::Expression::UnknownSymbol(s1)),
160                }
161            }
162            (Binary::Equal, Term::Str(i), Term::Str(j)) => Ok(Term::Bool(i == j)),
163            (Binary::NotEqual, Term::Str(i), Term::Str(j)) => Ok(Term::Bool(i != j)),
164
165            // date
166            (Binary::LessThan, Term::Date(i), Term::Date(j)) => Ok(Term::Bool(i < j)),
167            (Binary::GreaterThan, Term::Date(i), Term::Date(j)) => Ok(Term::Bool(i > j)),
168            (Binary::LessOrEqual, Term::Date(i), Term::Date(j)) => Ok(Term::Bool(i <= j)),
169            (Binary::GreaterOrEqual, Term::Date(i), Term::Date(j)) => Ok(Term::Bool(i >= j)),
170            (Binary::Equal, Term::Date(i), Term::Date(j)) => Ok(Term::Bool(i == j)),
171            (Binary::NotEqual, Term::Date(i), Term::Date(j)) => Ok(Term::Bool(i != j)),
172
173            // symbol
174
175            // byte array
176            (Binary::Equal, Term::Bytes(i), Term::Bytes(j)) => Ok(Term::Bool(i == j)),
177            (Binary::NotEqual, Term::Bytes(i), Term::Bytes(j)) => Ok(Term::Bool(i != j)),
178
179            // set
180            (Binary::Equal, Term::Set(set), Term::Set(s)) => Ok(Term::Bool(set == s)),
181            (Binary::NotEqual, Term::Set(set), Term::Set(s)) => Ok(Term::Bool(set != s)),
182            (Binary::Intersection, Term::Set(set), Term::Set(s)) => {
183                Ok(Term::Set(set.intersection(&s).cloned().collect()))
184            }
185            (Binary::Union, Term::Set(set), Term::Set(s)) => {
186                Ok(Term::Set(set.union(&s).cloned().collect()))
187            }
188            (Binary::Contains, Term::Set(set), Term::Set(s)) => Ok(Term::Bool(set.is_superset(&s))),
189            (Binary::Contains, Term::Set(set), Term::Integer(i)) => {
190                Ok(Term::Bool(set.contains(&Term::Integer(i))))
191            }
192            (Binary::Contains, Term::Set(set), Term::Date(i)) => {
193                Ok(Term::Bool(set.contains(&Term::Date(i))))
194            }
195            (Binary::Contains, Term::Set(set), Term::Bool(i)) => {
196                Ok(Term::Bool(set.contains(&Term::Bool(i))))
197            }
198            (Binary::Contains, Term::Set(set), Term::Str(i)) => {
199                Ok(Term::Bool(set.contains(&Term::Str(i))))
200            }
201            (Binary::Contains, Term::Set(set), Term::Bytes(i)) => {
202                Ok(Term::Bool(set.contains(&Term::Bytes(i))))
203            }
204
205            // boolean
206            (Binary::And, Term::Bool(i), Term::Bool(j)) => Ok(Term::Bool(i & j)),
207            (Binary::Or, Term::Bool(i), Term::Bool(j)) => Ok(Term::Bool(i | j)),
208            (Binary::Equal, Term::Bool(i), Term::Bool(j)) => Ok(Term::Bool(i == j)),
209            (Binary::NotEqual, Term::Bool(i), Term::Bool(j)) => Ok(Term::Bool(i != j)),
210
211            _ => {
212                //println!("unexpected value type on the stack");
213                Err(error::Expression::InvalidType)
214            }
215        }
216    }
217
218    pub fn print(&self, left: String, right: String, _symbols: &SymbolTable) -> String {
219        match self {
220            Binary::LessThan => format!("{} < {}", left, right),
221            Binary::GreaterThan => format!("{} > {}", left, right),
222            Binary::LessOrEqual => format!("{} <= {}", left, right),
223            Binary::GreaterOrEqual => format!("{} >= {}", left, right),
224            Binary::Equal => format!("{} == {}", left, right),
225            Binary::NotEqual => format!("{} != {}", left, right),
226            Binary::Contains => format!("{}.contains({})", left, right),
227            Binary::Prefix => format!("{}.starts_with({})", left, right),
228            Binary::Suffix => format!("{}.ends_with({})", left, right),
229            Binary::Regex => format!("{}.matches({})", left, right),
230            Binary::Add => format!("{} + {}", left, right),
231            Binary::Sub => format!("{} - {}", left, right),
232            Binary::Mul => format!("{} * {}", left, right),
233            Binary::Div => format!("{} / {}", left, right),
234            Binary::And => format!("{} && {}", left, right),
235            Binary::Or => format!("{} || {}", left, right),
236            Binary::Intersection => format!("{}.intersection({})", left, right),
237            Binary::Union => format!("{}.union({})", left, right),
238            Binary::BitwiseAnd => format!("{} & {}", left, right),
239            Binary::BitwiseOr => format!("{} | {}", left, right),
240            Binary::BitwiseXor => format!("{} ^ {}", left, right),
241        }
242    }
243}
244
245impl Expression {
246    pub fn evaluate(
247        &self,
248        values: &HashMap<u32, Term>,
249        symbols: &mut TemporarySymbolTable,
250    ) -> Result<Term, error::Expression> {
251        let mut stack: Vec<Term> = Vec::new();
252
253        for op in self.ops.iter() {
254            //println!("op: {:?}\t| stack: {:?}", op, stack);
255            match op {
256                Op::Value(Term::Variable(i)) => match values.get(i) {
257                    Some(term) => stack.push(term.clone()),
258                    None => {
259                        //println!("unknown variable {}", i);
260                        return Err(error::Expression::UnknownVariable(*i));
261                    }
262                },
263                Op::Value(term) => stack.push(term.clone()),
264                Op::Unary(unary) => match stack.pop() {
265                    None => {
266                        //println!("expected a value on the stack");
267                        return Err(error::Expression::InvalidStack);
268                    }
269                    Some(term) => stack.push(unary.evaluate(term, symbols)?),
270                },
271                Op::Binary(binary) => match (stack.pop(), stack.pop()) {
272                    (Some(right_term), Some(left_term)) => {
273                        stack.push(binary.evaluate(left_term, right_term, symbols)?)
274                    }
275
276                    _ => {
277                        //println!("expected two values on the stack");
278                        return Err(error::Expression::InvalidStack);
279                    }
280                },
281            }
282        }
283
284        if stack.len() == 1 {
285            Ok(stack.remove(0))
286        } else {
287            Err(error::Expression::InvalidStack)
288        }
289    }
290
291    pub fn print(&self, symbols: &SymbolTable) -> Option<String> {
292        let mut stack: Vec<String> = Vec::new();
293
294        for op in self.ops.iter() {
295            //println!("op: {:?}\t| stack: {:?}", op, stack);
296            match op {
297                Op::Value(i) => stack.push(symbols.print_term(i)),
298                Op::Unary(unary) => match stack.pop() {
299                    None => return None,
300                    Some(s) => stack.push(unary.print(s, symbols)),
301                },
302                Op::Binary(binary) => match (stack.pop(), stack.pop()) {
303                    (Some(right), Some(left)) => stack.push(binary.print(left, right, symbols)),
304                    _ => return None,
305                },
306            }
307        }
308
309        if stack.len() == 1 {
310            Some(stack.remove(0))
311        } else {
312            None
313        }
314    }
315}
316
#[cfg(test)]
mod tests {
    use super::*;
    use crate::datalog::{SymbolTable, TemporarySymbolTable};

    /// Symbol table holding the extra symbols the printing tests rely on.
    fn sample_symbols() -> SymbolTable {
        let mut table = SymbolTable::new();
        for name in ["test1", "test2", "var1"] {
            table.insert(name);
        }
        table
    }

    #[test]
    fn negate() {
        let symbols = sample_symbols();
        let mut tmp_symbols = TemporarySymbolTable::new(&symbols);

        // !(1 < $var), with variable 2 bound to 0
        let expr = Expression {
            ops: vec![
                Op::Value(Term::Integer(1)),
                Op::Value(Term::Variable(2)),
                Op::Binary(Binary::LessThan),
                Op::Unary(Unary::Parens),
                Op::Unary(Unary::Negate),
            ],
        };
        println!("ops: {:?}", expr.ops);
        println!("print: {}", expr.print(&symbols).unwrap());

        let bindings: HashMap<u32, Term> = std::iter::once((2, Term::Integer(0))).collect();
        assert_eq!(expr.evaluate(&bindings, &mut tmp_symbols), Ok(Term::Bool(true)));
    }

    #[test]
    fn bitwise() {
        let cases = [
            (Binary::BitwiseAnd, 9, 10, 8),
            (Binary::BitwiseAnd, 9, 1, 1),
            (Binary::BitwiseAnd, 9, 0, 0),
            (Binary::BitwiseOr, 1, 2, 3),
            (Binary::BitwiseOr, 2, 2, 2),
            (Binary::BitwiseOr, 2, 0, 2),
            (Binary::BitwiseXor, 1, 0, 1),
            (Binary::BitwiseXor, 1, 1, 0),
        ];

        for (operator, lhs, rhs, expected) in cases {
            let symbols = SymbolTable::new();
            let mut tmp_symbols = TemporarySymbolTable::new(&symbols);

            let expr = Expression {
                ops: vec![
                    Op::Value(Term::Integer(lhs)),
                    Op::Value(Term::Integer(rhs)),
                    Op::Binary(operator),
                ],
            };
            println!("ops: {:?}", expr.ops);
            println!("print: {}", expr.print(&symbols).unwrap());

            let outcome = expr.evaluate(&HashMap::new(), &mut tmp_symbols);
            assert_eq!(outcome, Ok(Term::Integer(expected)));
        }
    }

    #[test]
    fn checked() {
        let symbols = SymbolTable::new();
        let mut tmp_symbols = TemporarySymbolTable::new(&symbols);
        let bindings = HashMap::new();

        let cases = [
            (1, 0, Binary::Div, error::Expression::DivideByZero),
            (1, i64::MAX, Binary::Add, error::Expression::Overflow),
            (-10, i64::MAX, Binary::Sub, error::Expression::Overflow),
            (2, i64::MAX, Binary::Mul, error::Expression::Overflow),
        ];

        for (lhs, rhs, operator, expected) in cases {
            let expr = Expression {
                ops: vec![
                    Op::Value(Term::Integer(lhs)),
                    Op::Value(Term::Integer(rhs)),
                    Op::Binary(operator),
                ],
            };
            assert_eq!(expr.evaluate(&bindings, &mut tmp_symbols), Err(expected));
        }
    }

    #[test]
    fn printer() {
        let symbols = sample_symbols();

        let cases = [
            (
                vec![
                    Op::Value(Term::Integer(-1)),
                    Op::Value(Term::Variable(1026)),
                    Op::Binary(Binary::LessThan),
                ],
                "-1 < $var1",
            ),
            (
                vec![
                    Op::Value(Term::Integer(1)),
                    Op::Value(Term::Integer(2)),
                    Op::Value(Term::Integer(3)),
                    Op::Binary(Binary::Add),
                    Op::Binary(Binary::LessThan),
                ],
                "1 < 2 + 3",
            ),
            (
                vec![
                    Op::Value(Term::Integer(1)),
                    Op::Value(Term::Integer(2)),
                    Op::Binary(Binary::Add),
                    Op::Value(Term::Integer(3)),
                    Op::Binary(Binary::LessThan),
                ],
                "1 + 2 < 3",
            ),
        ];

        for (ops, expected) in cases {
            println!("ops: {:?}", ops);
            let expr = Expression { ops };
            assert_eq!(expr.print(&symbols).unwrap(), expected);
        }
    }
}