1use crate::error;
2
3use super::Term;
4use super::{SymbolTable, TemporarySymbolTable};
5use regex::Regex;
6use std::collections::HashMap;
7
/// A datalog expression stored as a flat list of operations in postfix order.
///
/// [`Expression::evaluate`] walks `ops` left to right with a value stack:
/// `Op::Value` pushes a term, `Op::Unary` pops one operand, and `Op::Binary`
/// pops two operands (right first, then left). A well-formed expression leaves
/// exactly one value on the stack.
#[derive(Debug, Clone, PartialEq, Hash, Eq)]
pub struct Expression {
    /// The operations, in evaluation (postfix) order.
    pub ops: Vec<Op>,
}
12
/// One instruction of an [`Expression`].
#[derive(Debug, Clone, PartialEq, Hash, Eq)]
pub enum Op {
    /// Push a term onto the stack. A `Term::Variable` is replaced by its bound
    /// value when the expression is evaluated.
    Value(Term),
    /// Pop one operand, apply the operator, push the result.
    Unary(Unary),
    /// Pop two operands (right, then left), apply the operator, push the result.
    Binary(Binary),
}
19
20#[derive(Debug, Clone, PartialEq, Hash, Eq)]
22pub enum Unary {
23 Negate,
24 Parens,
25 Length,
26}
27
28impl Unary {
29 fn evaluate(
30 &self,
31 value: Term,
32 symbols: &TemporarySymbolTable,
33 ) -> Result<Term, error::Expression> {
34 match (self, value) {
35 (Unary::Negate, Term::Bool(b)) => Ok(Term::Bool(!b)),
36 (Unary::Parens, i) => Ok(i),
37 (Unary::Length, Term::Str(i)) => symbols
38 .get_symbol(i)
39 .map(|s| Term::Integer(s.len() as i64))
40 .ok_or(error::Expression::UnknownSymbol(i)),
41 (Unary::Length, Term::Bytes(s)) => Ok(Term::Integer(s.len() as i64)),
42 (Unary::Length, Term::Set(s)) => Ok(Term::Integer(s.len() as i64)),
43 _ => {
44 Err(error::Expression::InvalidType)
46 }
47 }
48 }
49
50 pub fn print(&self, value: String, _symbols: &SymbolTable) -> String {
51 match self {
52 Unary::Negate => format!("!{}", value),
53 Unary::Parens => format!("({})", value),
54 Unary::Length => format!("{}.length()", value),
55 }
56 }
57}
58
/// Operators that take two operands (left and right).
///
/// Operand types accepted by each operator are those matched in
/// `Binary::evaluate`; any other combination is an `InvalidType` error.
#[derive(Debug, Clone, PartialEq, Hash, Eq)]
pub enum Binary {
    /// `<` on integers or dates.
    LessThan,
    /// `>` on integers or dates.
    GreaterThan,
    /// `<=` on integers or dates.
    LessOrEqual,
    /// `>=` on integers or dates.
    GreaterOrEqual,
    /// `==` on integers, strings, dates, byte arrays, sets or booleans.
    Equal,
    /// `.contains()`: substring test on strings, superset test on two sets, or
    /// membership of an integer/date/boolean/string/bytes value in a set.
    Contains,
    /// `.starts_with()` on strings.
    Prefix,
    /// `.ends_with()` on strings.
    Suffix,
    /// `.matches()`: the right string is used as a regular expression against the left one.
    Regex,
    /// `+`: overflow-checked integer addition, or string concatenation.
    Add,
    /// `-`: overflow-checked integer subtraction.
    Sub,
    /// `*`: overflow-checked integer multiplication.
    Mul,
    /// `/`: integer division; dividing by zero is an error.
    Div,
    /// `&&` on booleans. Both operands are already evaluated when it is applied.
    And,
    /// `||` on booleans. Both operands are already evaluated when it is applied.
    Or,
    /// `.intersection()` of two sets.
    Intersection,
    /// `.union()` of two sets.
    Union,
    /// `&` on integers.
    BitwiseAnd,
    /// `|` on integers.
    BitwiseOr,
    /// `^` on integers.
    BitwiseXor,
    /// `!=` on integers, strings, dates, byte arrays, sets or booleans.
    NotEqual,
}
84
85impl Binary {
86 fn evaluate(
87 &self,
88 left: Term,
89 right: Term,
90 symbols: &mut TemporarySymbolTable,
91 ) -> Result<Term, error::Expression> {
92 match (self, left, right) {
93 (Binary::LessThan, Term::Integer(i), Term::Integer(j)) => Ok(Term::Bool(i < j)),
95 (Binary::GreaterThan, Term::Integer(i), Term::Integer(j)) => Ok(Term::Bool(i > j)),
96 (Binary::LessOrEqual, Term::Integer(i), Term::Integer(j)) => Ok(Term::Bool(i <= j)),
97 (Binary::GreaterOrEqual, Term::Integer(i), Term::Integer(j)) => Ok(Term::Bool(i >= j)),
98 (Binary::Equal, Term::Integer(i), Term::Integer(j)) => Ok(Term::Bool(i == j)),
99 (Binary::NotEqual, Term::Integer(i), Term::Integer(j)) => Ok(Term::Bool(i != j)),
100 (Binary::Add, Term::Integer(i), Term::Integer(j)) => i
101 .checked_add(j)
102 .map(Term::Integer)
103 .ok_or(error::Expression::Overflow),
104 (Binary::Sub, Term::Integer(i), Term::Integer(j)) => i
105 .checked_sub(j)
106 .map(Term::Integer)
107 .ok_or(error::Expression::Overflow),
108 (Binary::Mul, Term::Integer(i), Term::Integer(j)) => i
109 .checked_mul(j)
110 .map(Term::Integer)
111 .ok_or(error::Expression::Overflow),
112 (Binary::Div, Term::Integer(i), Term::Integer(j)) => i
113 .checked_div(j)
114 .map(Term::Integer)
115 .ok_or(error::Expression::DivideByZero),
116 (Binary::BitwiseAnd, Term::Integer(i), Term::Integer(j)) => Ok(Term::Integer(i & j)),
117 (Binary::BitwiseOr, Term::Integer(i), Term::Integer(j)) => Ok(Term::Integer(i | j)),
118 (Binary::BitwiseXor, Term::Integer(i), Term::Integer(j)) => Ok(Term::Integer(i ^ j)),
119
120 (Binary::Prefix, Term::Str(s), Term::Str(pref)) => {
122 match (symbols.get_symbol(s), symbols.get_symbol(pref)) {
123 (Some(s), Some(pref)) => Ok(Term::Bool(s.starts_with(pref))),
124 (Some(_), None) => Err(error::Expression::UnknownSymbol(pref)),
125 _ => Err(error::Expression::UnknownSymbol(s)),
126 }
127 }
128 (Binary::Suffix, Term::Str(s), Term::Str(suff)) => {
129 match (symbols.get_symbol(s), symbols.get_symbol(suff)) {
130 (Some(s), Some(suff)) => Ok(Term::Bool(s.ends_with(suff))),
131 (Some(_), None) => Err(error::Expression::UnknownSymbol(suff)),
132 _ => Err(error::Expression::UnknownSymbol(s)),
133 }
134 }
135 (Binary::Regex, Term::Str(s), Term::Str(r)) => {
136 match (symbols.get_symbol(s), symbols.get_symbol(r)) {
137 (Some(s), Some(r)) => Ok(Term::Bool(
138 Regex::new(r).map(|re| re.is_match(s)).unwrap_or(false),
139 )),
140 (Some(_), None) => Err(error::Expression::UnknownSymbol(r)),
141 _ => Err(error::Expression::UnknownSymbol(s)),
142 }
143 }
144 (Binary::Contains, Term::Str(s), Term::Str(pattern)) => {
145 match (symbols.get_symbol(s), symbols.get_symbol(pattern)) {
146 (Some(s), Some(pattern)) => Ok(Term::Bool(s.contains(pattern))),
147 (Some(_), None) => Err(error::Expression::UnknownSymbol(pattern)),
148 _ => Err(error::Expression::UnknownSymbol(s)),
149 }
150 }
151 (Binary::Add, Term::Str(s1), Term::Str(s2)) => {
152 match (symbols.get_symbol(s1), symbols.get_symbol(s2)) {
153 (Some(s1), Some(s2)) => {
154 let s = format!("{}{}", s1, s2);
155 let sym = symbols.insert(&s);
156 Ok(Term::Str(sym))
157 }
158 (Some(_), None) => Err(error::Expression::UnknownSymbol(s2)),
159 _ => Err(error::Expression::UnknownSymbol(s1)),
160 }
161 }
162 (Binary::Equal, Term::Str(i), Term::Str(j)) => Ok(Term::Bool(i == j)),
163 (Binary::NotEqual, Term::Str(i), Term::Str(j)) => Ok(Term::Bool(i != j)),
164
165 (Binary::LessThan, Term::Date(i), Term::Date(j)) => Ok(Term::Bool(i < j)),
167 (Binary::GreaterThan, Term::Date(i), Term::Date(j)) => Ok(Term::Bool(i > j)),
168 (Binary::LessOrEqual, Term::Date(i), Term::Date(j)) => Ok(Term::Bool(i <= j)),
169 (Binary::GreaterOrEqual, Term::Date(i), Term::Date(j)) => Ok(Term::Bool(i >= j)),
170 (Binary::Equal, Term::Date(i), Term::Date(j)) => Ok(Term::Bool(i == j)),
171 (Binary::NotEqual, Term::Date(i), Term::Date(j)) => Ok(Term::Bool(i != j)),
172
173 (Binary::Equal, Term::Bytes(i), Term::Bytes(j)) => Ok(Term::Bool(i == j)),
177 (Binary::NotEqual, Term::Bytes(i), Term::Bytes(j)) => Ok(Term::Bool(i != j)),
178
179 (Binary::Equal, Term::Set(set), Term::Set(s)) => Ok(Term::Bool(set == s)),
181 (Binary::NotEqual, Term::Set(set), Term::Set(s)) => Ok(Term::Bool(set != s)),
182 (Binary::Intersection, Term::Set(set), Term::Set(s)) => {
183 Ok(Term::Set(set.intersection(&s).cloned().collect()))
184 }
185 (Binary::Union, Term::Set(set), Term::Set(s)) => {
186 Ok(Term::Set(set.union(&s).cloned().collect()))
187 }
188 (Binary::Contains, Term::Set(set), Term::Set(s)) => Ok(Term::Bool(set.is_superset(&s))),
189 (Binary::Contains, Term::Set(set), Term::Integer(i)) => {
190 Ok(Term::Bool(set.contains(&Term::Integer(i))))
191 }
192 (Binary::Contains, Term::Set(set), Term::Date(i)) => {
193 Ok(Term::Bool(set.contains(&Term::Date(i))))
194 }
195 (Binary::Contains, Term::Set(set), Term::Bool(i)) => {
196 Ok(Term::Bool(set.contains(&Term::Bool(i))))
197 }
198 (Binary::Contains, Term::Set(set), Term::Str(i)) => {
199 Ok(Term::Bool(set.contains(&Term::Str(i))))
200 }
201 (Binary::Contains, Term::Set(set), Term::Bytes(i)) => {
202 Ok(Term::Bool(set.contains(&Term::Bytes(i))))
203 }
204
205 (Binary::And, Term::Bool(i), Term::Bool(j)) => Ok(Term::Bool(i & j)),
207 (Binary::Or, Term::Bool(i), Term::Bool(j)) => Ok(Term::Bool(i | j)),
208 (Binary::Equal, Term::Bool(i), Term::Bool(j)) => Ok(Term::Bool(i == j)),
209 (Binary::NotEqual, Term::Bool(i), Term::Bool(j)) => Ok(Term::Bool(i != j)),
210
211 _ => {
212 Err(error::Expression::InvalidType)
214 }
215 }
216 }
217
218 pub fn print(&self, left: String, right: String, _symbols: &SymbolTable) -> String {
219 match self {
220 Binary::LessThan => format!("{} < {}", left, right),
221 Binary::GreaterThan => format!("{} > {}", left, right),
222 Binary::LessOrEqual => format!("{} <= {}", left, right),
223 Binary::GreaterOrEqual => format!("{} >= {}", left, right),
224 Binary::Equal => format!("{} == {}", left, right),
225 Binary::NotEqual => format!("{} != {}", left, right),
226 Binary::Contains => format!("{}.contains({})", left, right),
227 Binary::Prefix => format!("{}.starts_with({})", left, right),
228 Binary::Suffix => format!("{}.ends_with({})", left, right),
229 Binary::Regex => format!("{}.matches({})", left, right),
230 Binary::Add => format!("{} + {}", left, right),
231 Binary::Sub => format!("{} - {}", left, right),
232 Binary::Mul => format!("{} * {}", left, right),
233 Binary::Div => format!("{} / {}", left, right),
234 Binary::And => format!("{} && {}", left, right),
235 Binary::Or => format!("{} || {}", left, right),
236 Binary::Intersection => format!("{}.intersection({})", left, right),
237 Binary::Union => format!("{}.union({})", left, right),
238 Binary::BitwiseAnd => format!("{} & {}", left, right),
239 Binary::BitwiseOr => format!("{} | {}", left, right),
240 Binary::BitwiseXor => format!("{} ^ {}", left, right),
241 }
242 }
243}
244
245impl Expression {
246 pub fn evaluate(
247 &self,
248 values: &HashMap<u32, Term>,
249 symbols: &mut TemporarySymbolTable,
250 ) -> Result<Term, error::Expression> {
251 let mut stack: Vec<Term> = Vec::new();
252
253 for op in self.ops.iter() {
254 match op {
256 Op::Value(Term::Variable(i)) => match values.get(i) {
257 Some(term) => stack.push(term.clone()),
258 None => {
259 return Err(error::Expression::UnknownVariable(*i));
261 }
262 },
263 Op::Value(term) => stack.push(term.clone()),
264 Op::Unary(unary) => match stack.pop() {
265 None => {
266 return Err(error::Expression::InvalidStack);
268 }
269 Some(term) => stack.push(unary.evaluate(term, symbols)?),
270 },
271 Op::Binary(binary) => match (stack.pop(), stack.pop()) {
272 (Some(right_term), Some(left_term)) => {
273 stack.push(binary.evaluate(left_term, right_term, symbols)?)
274 }
275
276 _ => {
277 return Err(error::Expression::InvalidStack);
279 }
280 },
281 }
282 }
283
284 if stack.len() == 1 {
285 Ok(stack.remove(0))
286 } else {
287 Err(error::Expression::InvalidStack)
288 }
289 }
290
291 pub fn print(&self, symbols: &SymbolTable) -> Option<String> {
292 let mut stack: Vec<String> = Vec::new();
293
294 for op in self.ops.iter() {
295 match op {
297 Op::Value(i) => stack.push(symbols.print_term(i)),
298 Op::Unary(unary) => match stack.pop() {
299 None => return None,
300 Some(s) => stack.push(unary.print(s, symbols)),
301 },
302 Op::Binary(binary) => match (stack.pop(), stack.pop()) {
303 (Some(right), Some(left)) => stack.push(binary.print(left, right, symbols)),
304 _ => return None,
305 },
306 }
307 }
308
309 if stack.len() == 1 {
310 Some(stack.remove(0))
311 } else {
312 None
313 }
314 }
315}
316
#[cfg(test)]
mod tests {
    use super::*;
    use crate::datalog::{SymbolTable, TemporarySymbolTable};

    /// `!(1 < $2)` with `$2 = 0` evaluates to `true`.
    #[test]
    fn negate() {
        let mut symbols = SymbolTable::new();
        for name in ["test1", "test2", "var1"] {
            symbols.insert(name);
        }
        let mut tmp_symbols = TemporarySymbolTable::new(&symbols);

        let values: HashMap<u32, Term> = vec![(2, Term::Integer(0))].into_iter().collect();

        let e = Expression {
            ops: vec![
                Op::Value(Term::Integer(1)),
                Op::Value(Term::Variable(2)),
                Op::Binary(Binary::LessThan),
                Op::Unary(Unary::Parens),
                Op::Unary(Unary::Negate),
            ],
        };
        println!("ops: {:?}", e.ops);
        println!("print: {}", e.print(&symbols).unwrap());

        assert_eq!(e.evaluate(&values, &mut tmp_symbols), Ok(Term::Bool(true)));
    }

    /// Bitwise operators on integers.
    #[test]
    fn bitwise() {
        let cases = [
            (Binary::BitwiseAnd, 9, 10, 8),
            (Binary::BitwiseAnd, 9, 1, 1),
            (Binary::BitwiseAnd, 9, 0, 0),
            (Binary::BitwiseOr, 1, 2, 3),
            (Binary::BitwiseOr, 2, 2, 2),
            (Binary::BitwiseOr, 2, 0, 2),
            (Binary::BitwiseXor, 1, 0, 1),
            (Binary::BitwiseXor, 1, 1, 0),
        ];

        for (op, lhs, rhs, expected) in cases {
            let symbols = SymbolTable::new();
            let mut tmp_symbols = TemporarySymbolTable::new(&symbols);

            let e = Expression {
                ops: vec![
                    Op::Value(Term::Integer(lhs)),
                    Op::Value(Term::Integer(rhs)),
                    Op::Binary(op),
                ],
            };
            println!("ops: {:?}", e.ops);
            println!("print: {}", e.print(&symbols).unwrap());

            assert_eq!(
                e.evaluate(&HashMap::new(), &mut tmp_symbols),
                Ok(Term::Integer(expected))
            );
        }
    }

    /// Checked arithmetic reports division by zero and overflow as errors.
    #[test]
    fn checked() {
        let symbols = SymbolTable::new();
        let mut tmp_symbols = TemporarySymbolTable::new(&symbols);
        let values = HashMap::new();

        let cases = [
            (1, 0, Binary::Div, error::Expression::DivideByZero),
            (1, i64::MAX, Binary::Add, error::Expression::Overflow),
            (-10, i64::MAX, Binary::Sub, error::Expression::Overflow),
            (2, i64::MAX, Binary::Mul, error::Expression::Overflow),
        ];

        for (lhs, rhs, op, expected) in cases {
            let e = Expression {
                ops: vec![
                    Op::Value(Term::Integer(lhs)),
                    Op::Value(Term::Integer(rhs)),
                    Op::Binary(op),
                ],
            };
            assert_eq!(e.evaluate(&values, &mut tmp_symbols), Err(expected));
        }
    }

    /// Printing follows the postfix structure without adding parentheses.
    #[test]
    fn printer() {
        let mut symbols = SymbolTable::new();
        for name in ["test1", "test2", "var1"] {
            symbols.insert(name);
        }

        let cases = [
            (
                vec![
                    Op::Value(Term::Integer(-1)),
                    Op::Value(Term::Variable(1026)),
                    Op::Binary(Binary::LessThan),
                ],
                "-1 < $var1",
            ),
            (
                vec![
                    Op::Value(Term::Integer(1)),
                    Op::Value(Term::Integer(2)),
                    Op::Value(Term::Integer(3)),
                    Op::Binary(Binary::Add),
                    Op::Binary(Binary::LessThan),
                ],
                "1 < 2 + 3",
            ),
            (
                vec![
                    Op::Value(Term::Integer(1)),
                    Op::Value(Term::Integer(2)),
                    Op::Binary(Binary::Add),
                    Op::Value(Term::Integer(3)),
                    Op::Binary(Binary::LessThan),
                ],
                "1 + 2 < 3",
            ),
        ];

        for (ops, expected) in cases {
            println!("ops: {:?}", ops);
            let e = Expression { ops };
            assert_eq!(e.print(&symbols).unwrap(), expected);
        }
    }
}