monster/solver/
monster.rs

1#![allow(clippy::many_single_char_names)]
2#![allow(clippy::if_same_then_else)]
3#![allow(clippy::neg_cmp_op_on_partial_ord)]
4
5use super::{
6    Assignment, BVOperator, BitVector, Formula, OperandSide, Solver, SolverError, Symbol, SymbolId,
7};
8use divisors::get_divisors;
9use log::{log_enabled, trace, Level};
10use rand::{distributions::Uniform, random, seq::SliceRandom, thread_rng, Rng};
11use std::time::{Duration, Instant};
12
/// Solver that searches for a satisfying assignment until a wall-clock timeout expires.
pub struct MonsterSolver {
    /// Time budget handed to the search for a single formula.
    timeout: Duration,
}

impl Default for MonsterSolver {
    /// Creates a solver with a timeout of three seconds.
    fn default() -> Self {
        MonsterSolver::new(Duration::from_secs(3))
    }
}

impl MonsterSolver {
    /// Creates a solver that gives up once `timeout` has elapsed.
    pub fn new(timeout: Duration) -> Self {
        MonsterSolver { timeout }
    }
}
28
29impl Solver for MonsterSolver {
30    fn name() -> &'static str {
31        "Monster"
32    }
33
34    fn solve_impl<F: Formula>(&self, formula: &F) -> Result<Option<Assignment>, SolverError> {
35        let ab = initialize_ab(formula);
36
37        sat(formula, ab, self.timeout)
38    }
39}
40
41// check if invertibility condition is met
42fn is_invertible(op: BVOperator, s: BitVector, t: BitVector, d: OperandSide) -> bool {
43    match op {
44        BVOperator::Add => true,
45        BVOperator::Sub => true,
46        BVOperator::Mul => (-s | s) & t == t,
47        BVOperator::Divu => match d {
48            OperandSide::Lhs => {
49                if s == BitVector(0) {
50                    t == BitVector::ones()
51                } else {
52                    !t.mulo(s)
53                }
54            }
55            OperandSide::Rhs => {
56                if t == BitVector(0) {
57                    s != BitVector::ones()
58                } else {
59                    t == BitVector::ones() || !(s < t)
60                }
61            }
62        },
63        BVOperator::Sltu => match d {
64            OperandSide::Lhs => {
65                if t != BitVector(0) {
66                    !(s == BitVector(0))
67                } else {
68                    true
69                }
70            }
71            OperandSide::Rhs => {
72                if t != BitVector(0) {
73                    !(s == BitVector::ones())
74                } else {
75                    true
76                }
77            }
78        },
79        BVOperator::Remu => match d {
80            OperandSide::Lhs => !(s <= t),
81            OperandSide::Rhs => {
82                if s == t {
83                    true
84                } else {
85                    !((s < t) || ((t != BitVector(0)) && t == s - BitVector(1)) || (s - t <= t))
86                }
87            }
88        },
89        BVOperator::Not => true,
90        BVOperator::BitwiseAnd => (t & s) == t,
91        BVOperator::Equals => true,
92    }
93}
94
95// initialize bit vectors with a consistent initial assignment to the formula
96// inputs are initialized with random values
97fn initialize_ab<F: Formula>(formula: &F) -> Vec<BitVector> {
98    // Initialize values for all input/const nodes
99    let max_id = formula
100        .symbol_ids()
101        .max()
102        .expect("formula should not be empty");
103
104    let mut ab = Vec::with_capacity(std::mem::size_of::<BitVector>() * (max_id + 1));
105    unsafe {
106        ab.set_len(max_id + 1);
107    }
108
109    formula.symbol_ids().for_each(|i| {
110        ab[i] = match formula[i] {
111            Symbol::Constant(c) => c,
112            _ => BitVector(random::<u64>()),
113        };
114    });
115
116    if log_enabled!(Level::Trace) {
117        formula
118            .symbol_ids()
119            .filter(|i| matches!(formula[*i], Symbol::Input(_)))
120            .for_each(|i| {
121                trace!("initialize: x{} <- {:#x}", i, ab[i].0);
122            });
123    }
124
125    // Propagate all values down when all input/const nodes are initialized
126    formula.symbol_ids().for_each(|i| match formula[i] {
127        Symbol::Input(_) | Symbol::Constant(_) => {
128            formula
129                .dependencies(i)
130                .for_each(|n| propagate_assignment(formula, &mut ab, n));
131        }
132        _ => {}
133    });
134
135    ab
136}
137
138// selects a child node to propagate downwards
139// always selects an "essential" input if there is one
140// otherwise selects an input undeterministically
141fn select<F: Formula>(
142    f: &F,
143    idx: SymbolId,
144    t: BitVector,
145    ab: &[BitVector],
146) -> (SymbolId, SymbolId, OperandSide) {
147    if let (lhs, Some(rhs)) = f.operands(idx) {
148        fn is_constant<F: Formula>(f: &F, n: SymbolId) -> bool {
149            matches!(f[n], Symbol::Constant(_))
150        }
151
152        #[allow(clippy::if_same_then_else)]
153        if is_constant(f, lhs) {
154            (rhs, lhs, OperandSide::Rhs)
155        } else if is_constant(f, rhs) {
156            (lhs, rhs, OperandSide::Lhs)
157        } else if is_essential(f, lhs, OperandSide::Lhs, rhs, t, ab) {
158            (lhs, rhs, OperandSide::Lhs)
159        } else if is_essential(f, rhs, OperandSide::Rhs, lhs, t, ab) {
160            (rhs, lhs, OperandSide::Rhs)
161        } else if random() {
162            (rhs, lhs, OperandSide::Rhs)
163        } else {
164            (lhs, rhs, OperandSide::Lhs)
165        }
166    } else {
167        panic!("can only select path for binary operators")
168    }
169}
170
171fn compute_inverse_value(op: BVOperator, s: BitVector, t: BitVector, d: OperandSide) -> BitVector {
172    match op {
173        BVOperator::Add => t - s,
174        BVOperator::Sub => match d {
175            OperandSide::Lhs => t + s,
176            OperandSide::Rhs => s - t,
177        },
178        BVOperator::Mul => {
179            let y = s >> s.ctz();
180
181            let y_inv = y
182                .modinverse()
183                .expect("a modular inverse has to exist iff operator is invertible");
184
185            let result = (t >> s.ctz()) * y_inv;
186
187            let to_shift = 64 - s.ctz();
188
189            let arbitrary_bit_mask = if to_shift == 64 {
190                BitVector(0)
191            } else {
192                BitVector::ones() << to_shift
193            };
194
195            let arbitrary_bits = BitVector(random::<u64>()) & arbitrary_bit_mask;
196
197            result | arbitrary_bits
198        }
199        BVOperator::Sltu => match d {
200            OperandSide::Lhs => {
201                if t == BitVector(0) {
202                    // x<s == false; therefore we need a random x>=s
203                    BitVector(thread_rng().sample(Uniform::new_inclusive(s.0, BitVector::ones().0)))
204                } else {
205                    // x<s == true; therefore we need a random x<s
206                    BitVector(thread_rng().sample(Uniform::new(0, s.0)))
207                }
208            }
209            OperandSide::Rhs => {
210                if t == BitVector(0) {
211                    // s<x == false; therefore we need a random x<=s
212                    BitVector(thread_rng().sample(Uniform::new_inclusive(0, s.0)))
213                } else {
214                    // s<x == true; therefore we need a random x>s
215                    BitVector(
216                        thread_rng().sample(Uniform::new_inclusive(s.0 + 1, BitVector::ones().0)),
217                    )
218                }
219            }
220        },
221        BVOperator::Divu => match d {
222            OperandSide::Lhs => {
223                if (t == BitVector::ones()) && (s == BitVector(1)) {
224                    BitVector::ones()
225                } else {
226                    let range_start = t * s;
227                    if range_start.0.overflowing_add(s.0 - 1).1 {
228                        BitVector(
229                            thread_rng()
230                                .sample(Uniform::new_inclusive(range_start.0, u64::max_value())),
231                        )
232                    } else {
233                        BitVector(thread_rng().sample(Uniform::new_inclusive(
234                            range_start.0,
235                            range_start.0 + (s.0 - 1),
236                        )))
237                    }
238                }
239            }
240            OperandSide::Rhs => {
241                if (t == s) && t == BitVector::ones() {
242                    BitVector(thread_rng().sample(Uniform::new_inclusive(0, 1)))
243                } else if (t == BitVector::ones()) && (s != BitVector::ones()) {
244                    BitVector(0)
245                } else {
246                    s / t
247                }
248            }
249        },
250        BVOperator::Remu => match d {
251            OperandSide::Lhs => {
252                let y = BitVector(
253                    thread_rng().sample(Uniform::new_inclusive(1, ((BitVector::ones() - t) / s).0)),
254                );
255                // below computation cannot overflow due to how `y` was chosen
256                assert!(
257                    !s.0.overflowing_mul(y.0).1,
258                    "multiplication overflow in REMU inverse"
259                );
260                assert!(
261                    !t.0.overflowing_add(y.0 * s.0).1,
262                    "addition overflow in REMU inverse"
263                );
264                y * s + t
265            }
266            OperandSide::Rhs => {
267                if s == t {
268                    let x = BitVector(
269                        thread_rng().sample(Uniform::new_inclusive(t.0, BitVector::ones().0)),
270                    );
271                    if x == t {
272                        BitVector(0)
273                    } else {
274                        x
275                    }
276                } else {
277                    let mut v = get_divisors(s.0 - t.0);
278                    v.push(1);
279                    v.push(s.0 - t.0);
280                    v = v.into_iter().filter(|x| x > &t.0).collect();
281
282                    BitVector(*v.choose(&mut rand::thread_rng()).unwrap())
283                }
284            }
285        },
286        BVOperator::BitwiseAnd => BitVector(random::<u64>()) | t,
287        BVOperator::Equals => {
288            if t == BitVector(0) {
289                loop {
290                    let r = BitVector(random::<u64>());
291                    if r != s {
292                        break r;
293                    }
294                }
295            } else {
296                s
297            }
298        }
299        _ => unreachable!("unknown operator or unary operator: {:?}", op),
300    }
301}
302
303fn compute_consistent_value(op: BVOperator, t: BitVector, d: OperandSide) -> BitVector {
304    match op {
305        BVOperator::Add | BVOperator::Sub | BVOperator::Equals => BitVector(random::<u64>()),
306        BVOperator::Mul => BitVector({
307            if t == BitVector(0) {
308                0
309            } else {
310                let mut r;
311                loop {
312                    r = random::<u128>();
313                    if r != 0 {
314                        break;
315                    }
316                }
317                if t.ctz() < r.trailing_zeros() {
318                    r >>= r.trailing_zeros() - t.ctz();
319                }
320                assert!(t.ctz() >= r.trailing_zeros());
321                r as u64
322            }
323        }),
324        BVOperator::Divu => match d {
325            OperandSide::Lhs => {
326                if (t == BitVector::ones()) || (t == BitVector(0)) {
327                    BitVector(thread_rng().sample(Uniform::new_inclusive(0, u64::max_value() - 1)))
328                } else {
329                    let mut y = BitVector(0);
330                    while !(y != BitVector(0)) && !(y.mulo(t)) {
331                        y = BitVector(
332                            thread_rng().sample(Uniform::new_inclusive(0, u64::max_value())),
333                        );
334                    }
335
336                    y * t
337                }
338            }
339            OperandSide::Rhs => {
340                if t == BitVector::ones() {
341                    BitVector(thread_rng().sample(Uniform::new_inclusive(0, 1)))
342                } else {
343                    BitVector(
344                        thread_rng().sample(Uniform::new_inclusive(0, u64::max_value() / t.0)),
345                    )
346                }
347            }
348        },
349        BVOperator::Sltu => match d {
350            OperandSide::Lhs => {
351                if t == BitVector(0) {
352                    // x<s == false
353                    BitVector(thread_rng().sample(Uniform::new_inclusive(0, BitVector::ones().0)))
354                } else {
355                    // x<s == true
356                    BitVector(thread_rng().sample(Uniform::new(0, BitVector::ones().0)))
357                }
358            }
359            OperandSide::Rhs => {
360                if t == BitVector(0) {
361                    // s<x == false
362                    BitVector(thread_rng().sample(Uniform::new_inclusive(0, BitVector::ones().0)))
363                } else {
364                    // s<x == true
365                    BitVector(thread_rng().sample(Uniform::new(1, BitVector::ones().0)))
366                }
367            }
368        },
369        BVOperator::Remu => match d {
370            OperandSide::Lhs => {
371                if t == BitVector::ones() {
372                    BitVector::ones()
373                } else {
374                    BitVector(thread_rng().sample(Uniform::new_inclusive(t.0, BitVector::ones().0)))
375                }
376            }
377            OperandSide::Rhs => {
378                if t == BitVector::ones() {
379                    BitVector(0)
380                } else {
381                    BitVector(
382                        thread_rng().sample(Uniform::new_inclusive(t.0 + 1, BitVector::ones().0)),
383                    )
384                }
385            }
386        },
387        BVOperator::BitwiseAnd => BitVector(random::<u64>()) | t,
388        _ => unreachable!("unknown operator for consistent value: {:?}", op),
389    }
390}
391
392fn compute_inverse_value_for_unary_op(op: BVOperator, t: BitVector) -> BitVector {
393    match op {
394        BVOperator::Not => {
395            if t == BitVector(0) {
396                BitVector(1)
397            } else {
398                BitVector(0)
399            }
400        }
401        _ => unreachable!("not unary operator: {:?}", op),
402    }
403}
404
/// Threshold used by `value`: a random number drawn from `[0, 1]` has to be below it for the
/// inverse value to be chosen over the consistent value of an invertible operator.
const CHOOSE_INVERSE: f64 = 0.90;
406
407// computes an inverse/consistent target value
408#[allow(clippy::too_many_arguments)]
409fn value<F: Formula>(
410    f: &F,
411    n: SymbolId,
412    ns: SymbolId,
413    side: OperandSide,
414    t: BitVector,
415    ab: &[BitVector],
416) -> BitVector {
417    let s = ab[ns];
418
419    match &f[n] {
420        Symbol::Operator(op) => {
421            let consistent = compute_consistent_value(*op, t, side);
422
423            if is_invertible(*op, s, t, side) {
424                let inverse = compute_inverse_value(*op, s, t, side);
425                let choose_inverse =
426                    rand::thread_rng().gen_range(0.0_f64..=1.0_f64) < CHOOSE_INVERSE;
427
428                if choose_inverse {
429                    inverse
430                } else {
431                    consistent
432                }
433            } else {
434                consistent
435            }
436        }
437        _ => unimplemented!(),
438    }
439}
440
441fn is_essential<F: Formula>(
442    formula: &F,
443    this: SymbolId,
444    on_side: OperandSide,
445    other: SymbolId,
446    t: BitVector,
447    ab: &[BitVector],
448) -> bool {
449    let ab_nx = ab[this];
450
451    match &formula[other] {
452        Symbol::Operator(op) => !is_invertible(*op, ab_nx, t, on_side.other()),
453        // TODO: not mentioned in paper => improvised. is that really true?
454        Symbol::Constant(_) | Symbol::Input(_) => false,
455    }
456}
457
458fn update_assignment<F: Formula>(f: &F, ab: &mut Vec<BitVector>, n: SymbolId, v: BitVector) {
459    ab[n] = v;
460
461    assert!(
462        matches!(f[n], Symbol::Input(_)),
463        "only inputs can be assigned"
464    );
465
466    trace!("update: x{} <- {:#x}", n, v.0);
467
468    f.dependencies(n)
469        .for_each(|n| propagate_assignment(f, ab, n));
470}
471
472fn propagate_assignment<F: Formula>(f: &F, ab: &mut Vec<BitVector>, n: SymbolId) {
473    fn update_binary<F: Formula, Op>(f: &F, ab: &mut Vec<BitVector>, n: SymbolId, s: &str, op: Op)
474    where
475        Op: FnOnce(BitVector, BitVector) -> BitVector,
476    {
477        if let (lhs, Some(rhs)) = f.operands(n) {
478            let result = op(ab[lhs], ab[rhs]);
479
480            trace!(
481                "propagate: x{} := x{}({:#x}) {} x{}({:#x}) |- x{} <- {:#x}",
482                n,
483                lhs,
484                ab[lhs].0,
485                s,
486                rhs,
487                ab[rhs].0,
488                n,
489                result.0
490            );
491
492            ab[n] = result;
493        } else {
494            panic!("can not update binary operator with 1 operand")
495        }
496    }
497
498    fn update_unary<F: Formula, Op>(f: &F, ab: &mut Vec<BitVector>, n: SymbolId, s: &str, op: Op)
499    where
500        Op: FnOnce(BitVector) -> BitVector,
501    {
502        if let (p, None) = f.operands(n) {
503            let result = op(ab[p]);
504
505            trace!(
506                "propagate: x{} := {}x{}({:#x}) |- x{} <- {:#x}",
507                n,
508                s,
509                p,
510                ab[p].0,
511                n,
512                result.0
513            );
514
515            ab[n] = result;
516        } else {
517            panic!("can not update unary operator with more than one operand")
518        }
519    }
520
521    match &f[n] {
522        Symbol::Operator(op) => {
523            match op {
524                BVOperator::Add => update_binary(f, ab, n, "+", |l, r| l + r),
525                BVOperator::Sub => update_binary(f, ab, n, "-", |l, r| l - r),
526                BVOperator::Mul => update_binary(f, ab, n, "*", |l, r| l * r),
527                BVOperator::Divu => update_binary(f, ab, n, "/", |l, r| l / r),
528                BVOperator::BitwiseAnd => update_binary(f, ab, n, "&", |l, r| l & r),
529                BVOperator::Sltu => update_binary(f, ab, n, "<", |l, r| {
530                    if l < r {
531                        BitVector(1)
532                    } else {
533                        BitVector(0)
534                    }
535                }),
536                BVOperator::Remu => update_binary(f, ab, n, "%", |l, r| l % r),
537                BVOperator::Equals => update_binary(f, ab, n, "=", |l, r| {
538                    if l == r {
539                        BitVector(1)
540                    } else {
541                        BitVector(0)
542                    }
543                }),
544                BVOperator::Not => update_unary(f, ab, n, "!", |x| {
545                    if x == BitVector(0) {
546                        BitVector(1)
547                    } else {
548                        BitVector(0)
549                    }
550                }),
551            }
552            f.dependencies(n)
553                //f.neighbors_directed(n, Direction::Outgoing)
554                .for_each(|n| propagate_assignment(f, ab, n));
555        }
556        _ => unreachable!(),
557    }
558}
559
560// can only handle one Equals constraint with constant
561fn sat<F: Formula>(
562    formula: &F,
563    mut ab: Vec<BitVector>,
564    timeout_time: Duration,
565) -> Result<Option<Assignment>, SolverError> {
566    let mut iterations = 0;
567
568    let start_time = Instant::now();
569
570    let root = formula.root();
571
572    while ab[root] != BitVector(1) {
573        let mut n = root;
574        let mut t = BitVector(1);
575
576        iterations += 1;
577        trace!("search {}: x{} <- 0x1", iterations, root);
578
579        while !formula.is_operand(n) {
580            if start_time.elapsed() > timeout_time {
581                return Err(SolverError::Timeout);
582            }
583            let (v, nx) = match formula[n] {
584                Symbol::Operator(op) => {
585                    if op.is_unary() {
586                        let nx = formula.operand(n);
587
588                        let v = compute_inverse_value_for_unary_op(op, t);
589
590                        trace!(
591                            "search {}: x{}({:#x}) = {}x{}({:#x}) |- x{} <- {:#x}",
592                            iterations,
593                            n,
594                            t.0,
595                            op,
596                            nx,
597                            ab[nx].0,
598                            nx,
599                            v.0
600                        );
601
602                        (v, nx)
603                    } else {
604                        let (nx, ns, side) = select(formula, n, t, &ab);
605
606                        let v = value(formula, n, ns, side, t, &ab);
607
608                        if log_enabled!(Level::Trace) {
609                            let (lhs, rhs) = if side == OperandSide::Lhs {
610                                (nx, ns)
611                            } else {
612                                (ns, nx)
613                            };
614
615                            trace!(
616                                "search {}: x{}({:#x}) := x{}({:#x}) {} x{}({:#x}) |- x{} <- {:#x}",
617                                iterations,
618                                n,
619                                t.0,
620                                lhs,
621                                ab[lhs].0,
622                                op,
623                                rhs,
624                                ab[rhs].0,
625                                nx,
626                                v.0
627                            );
628                        }
629
630                        (v, nx)
631                    }
632                }
633                _ => panic!("non instruction node found"),
634            };
635
636            t = v;
637            n = nx;
638        }
639
640        update_assignment(formula, &mut ab, n, t);
641    }
642
643    let assignment: Assignment = formula.symbol_ids().map(|i| (i, ab[i])).collect();
644
645    Ok(Some(assignment))
646}
647
648#[cfg(test)]
649mod tests {
650    use super::*;
651    use crate::engine::symbolic_state::{DataFlowGraph, FormulaView, SymbolicValue};
652    use crate::solver::*;
653
654    fn create_data_flow_with_input() -> (DataFlowGraph, SymbolicValue) {
655        let mut formula = DataFlowGraph::new();
656
657        let input = Symbol::Input(String::from("x0"));
658        let input_idx = formula.add_node(input);
659
660        (formula, input_idx)
661    }
662
663    fn add_equals_constraint(
664        data_flow: &mut DataFlowGraph,
665        to: SymbolicValue,
666        on: OperandSide,
667        constant: u64,
668    ) -> SymbolicValue {
669        let constrain = Symbol::Operator(BVOperator::Equals);
670        let constrain_idx = data_flow.add_node(constrain);
671
672        let constrain_c = Symbol::Constant(BitVector(constant));
673        let constrain_c_idx = data_flow.add_node(constrain_c);
674
675        data_flow.add_edge(to, constrain_idx, on);
676        data_flow.add_edge(constrain_c_idx, constrain_idx, on.other());
677
678        constrain_idx
679    }
680
681    #[test]
682    fn solve_trivial_equals_constraint() {
683        let (mut data_flow, input_idx) = create_data_flow_with_input();
684
685        let root = add_equals_constraint(&mut data_flow, input_idx, OperandSide::Lhs, 10);
686
687        let solver = MonsterSolver::default();
688        let formula = FormulaView::new(&data_flow, root);
689        let result = solver.solve(&formula);
690
691        assert!(result.is_ok(), "solver did not time out");
692        let unwrapped_result = result.unwrap();
693
694        assert!(
695            unwrapped_result.is_some(),
696            "has result for trivial equals constraint"
697        );
698        assert_eq!(
699            *unwrapped_result.unwrap().get(&input_idx.index()).unwrap(),
700            BitVector(10),
701            "solver result of trivial equal constrain has right value"
702        );
703    }
704
705    #[test]
706    fn solve_bvadd() {
707        let (mut data_flow, input_idx) = create_data_flow_with_input();
708
709        let constant = Symbol::Constant(BitVector(3));
710        let constant_idx = data_flow.add_node(constant);
711
712        let instr = Symbol::Operator(BVOperator::Add);
713        let instr_idx = data_flow.add_node(instr);
714
715        data_flow.add_edge(input_idx, instr_idx, OperandSide::Lhs);
716        data_flow.add_edge(constant_idx, instr_idx, OperandSide::Rhs);
717
718        let root = add_equals_constraint(&mut data_flow, instr_idx, OperandSide::Lhs, 10);
719
720        let solver = MonsterSolver::default();
721        let formula = FormulaView::new(&data_flow, root);
722        let result = solver.solve(&formula);
723
724        assert!(result.is_ok(), "solver did not time out");
725        let unwrapped_result = result.unwrap();
726
727        assert!(unwrapped_result.is_some(), "has result for trivial add op");
728        assert_eq!(
729            *unwrapped_result.unwrap().get(&input_idx.index()).unwrap(),
730            BitVector(7),
731            "solver result of trivial add op has right value"
732        );
733    }
734
735    fn test_invertibility(
736        op: BVOperator,
737        s: u64,
738        t: u64,
739        d: OperandSide,
740        result: bool,
741        msg: &'static str,
742    ) {
743        let s = BitVector(s);
744        let t = BitVector(t);
745
746        match d {
747            OperandSide::Lhs => {
748                assert_eq!(
749                    is_invertible(op, s, t, d),
750                    result,
751                    "x {:?} {:?} == {:?}   {}",
752                    op,
753                    s,
754                    t,
755                    msg
756                );
757            }
758            OperandSide::Rhs => {
759                assert_eq!(
760                    is_invertible(op, s, t, d),
761                    result,
762                    "{:?} {:?} x == {:?}   {}",
763                    s,
764                    op,
765                    t,
766                    msg
767                );
768            }
769        }
770    }
771
772    fn test_inverse_value_computation<F>(op: BVOperator, s: u64, t: u64, d: OperandSide, f: F)
773    where
774        F: FnOnce(BitVector, BitVector) -> BitVector,
775    {
776        let s = BitVector(s);
777        let t = BitVector(t);
778
779        let computed = compute_inverse_value(op, s, t, d);
780
781        // prove: computed <> s == t        where <> is the binary operator
782
783        match d {
784            OperandSide::Lhs => {
785                assert_eq!(
786                    f(computed, s),
787                    t,
788                    "{:?} {:?} {:?} == {:?}",
789                    computed,
790                    op,
791                    s,
792                    t
793                );
794            }
795            OperandSide::Rhs => {
796                assert_eq!(
797                    f(s, computed),
798                    t,
799                    "{:?} {:?} {:?} == {:?}",
800                    s,
801                    op,
802                    computed,
803                    t
804                );
805            }
806        }
807    }
808
    /// Computes a consistent value for `op`, `t` and side `d`, derives a matching second
    /// operand and asserts that applying `f` to both yields `t`.
    ///
    /// Only `Add`, `Mul`, `Sltu` and `Divu` are supported; every other operator hits
    /// `unimplemented!()`.
    fn test_consistent_value_computation<F>(op: BVOperator, t: u64, d: OperandSide, f: F)
    where
        F: FnOnce(BitVector, BitVector) -> BitVector,
    {
        let t = BitVector(t);

        let computed = compute_consistent_value(op, t, d);

        // TODO: How to test consistent values?
        // To prove that there exists a y, we would have to compute an inverse value, which is
        // not always possible.
        // (the original note continues with the unfinished sentence "I think, Alastairs")
        // prove: Ey.(computed <> y == t)        where <> is the binary bit vector operator
        //

        // NOTE(review): `computed` is passed as the fixed operand `s` here while `d` is
        // reused unchanged - confirm that this is the intended side for non-commutative
        // operators.
        let inverse = match op {
            BVOperator::Add => t - computed,
            BVOperator::Mul => {
                assert!(
                    is_invertible(op, computed, t, d),
                    "choose values which are invertible..."
                );

                compute_inverse_value(op, computed, t, d)
            }
            BVOperator::Sltu => compute_inverse_value(op, computed, t, d),
            BVOperator::Divu => {
                assert!(is_invertible(op, computed, t, d));
                compute_inverse_value(op, computed, t, d)
            }
            _ => unimplemented!(),
        };

        // `inverse` is placed on side `d`, `computed` on the other side
        if d == OperandSide::Lhs {
            assert_eq!(
                f(inverse, computed),
                t,
                "{:?} {:?} {:?} == {:?}",
                inverse,
                op,
                computed,
                t
            );
        } else {
            assert_eq!(
                f(computed, inverse),
                t,
                "{:?} {:?} {:?} == {:?}",
                computed,
                op,
                inverse,
                t
            );
        }
    }
864
865    // TODO: add tests for ADD
866    // TODO: add tests for SUB
867
    // Shorthands for the operators used by the tests below.
    const MUL: BVOperator = BVOperator::Mul;
    const SLTU: BVOperator = BVOperator::Sltu;
    const DIVU: BVOperator = BVOperator::Divu;
    const REMU: BVOperator = BVOperator::Remu;
872
873    #[test]
874    fn check_invertibility_condition_for_divu() {
875        test_invertibility(DIVU, 0b1, 0b1, OperandSide::Lhs, true, "trivial divu");
876        test_invertibility(DIVU, 0b1, 0b1, OperandSide::Rhs, true, "trivial divu");
877
878        test_invertibility(DIVU, 3, 2, OperandSide::Lhs, true, "x / 3 = 2");
879        test_invertibility(DIVU, 6, 2, OperandSide::Rhs, true, "6 / x = 2");
880
881        test_invertibility(DIVU, 0, 2, OperandSide::Lhs, false, "x / 0 = 2");
882        test_invertibility(DIVU, 0, 2, OperandSide::Rhs, false, "0 / x = 2");
883
884        test_invertibility(DIVU, 5, 6, OperandSide::Rhs, false, "5 / x = 6");
885    }
886
887    #[test]
888    fn check_invertibility_condition_for_mul() {
889        let side = OperandSide::Lhs;
890
891        test_invertibility(MUL, 0b1, 0b1, side, true, "trivial multiplication");
892        test_invertibility(MUL, 0b10, 0b1, side, false, "operand bigger than result");
893        test_invertibility(
894            MUL,
895            0b10,
896            0b10,
897            side,
898            true,
899            "operand with undetermined bits and possible invsere",
900        );
901        test_invertibility(
902            MUL,
903            0b10,
904            0b10,
905            side,
906            true,
907            "operand with undetermined bits and no inverse value",
908        );
909        test_invertibility(
910            MUL,
911            0b100,
912            0b100,
913            side,
914            true,
915            "operand with undetermined bits and no inverse value",
916        );
917        test_invertibility(
918            MUL,
919            0b10,
920            0b1100,
921            side,
922            true,
923            "operand with undetermined bits and no inverse value",
924        );
925    }
926
927    #[test]
928    fn check_invertibility_condition_for_sltu() {
929        let mut side = OperandSide::Lhs;
930
931        test_invertibility(SLTU, 0, 1, side, false, "x < 0 == 1 FALSE");
932        test_invertibility(SLTU, 1, 1, side, true, "x < 1 == 1 TRUE");
933        test_invertibility(
934            SLTU,
935            u64::max_value(),
936            0,
937            side,
938            true,
939            "x < max_value == 0 TRUE",
940        );
941
942        side = OperandSide::Rhs;
943
944        test_invertibility(SLTU, 0, 1, side, true, "0 < x == 1 TRUE");
945        test_invertibility(SLTU, 0, 0, side, true, "0 < x == 0 TRUE");
946        test_invertibility(
947            SLTU,
948            u64::max_value(),
949            1,
950            side,
951            false,
952            "max_value < x == 1 FALSE",
953        );
954        test_invertibility(
955            SLTU,
956            u64::max_value(),
957            0,
958            side,
959            true,
960            "max_value < x == 0 TRUE",
961        );
962    }
963
964    #[test]
965    fn check_invertibility_condition_for_remu() {
966        let mut side = OperandSide::Lhs;
967
968        test_invertibility(REMU, 3, 2, side, true, "x mod 3 = 2 TRUE");
969        test_invertibility(REMU, 3, 3, side, false, "x mod 3 = 3 FALSE");
970
971        side = OperandSide::Rhs;
972
973        test_invertibility(REMU, 3, 3, side, true, "3 mod x = 3 TRUE");
974        test_invertibility(REMU, 3, 2, side, false, "3 mod x = 2 FALSE");
975        test_invertibility(REMU, 5, 3, side, false, "5 mod x = 3 FALSE");
976    }
977
978    #[test]
979    fn compute_inverse_values_for_mul() {
980        let side = OperandSide::Lhs;
981
982        fn f(l: BitVector, r: BitVector) -> BitVector {
983            l * r
984        }
985
986        // test only for values which are actually invertible
987        test_inverse_value_computation(MUL, 0b1, 0b1, side, f);
988        test_inverse_value_computation(MUL, 0b10, 0b10, side, f);
989        test_inverse_value_computation(MUL, 0b100, 0b100, side, f);
990        test_inverse_value_computation(MUL, 0b10, 0b1100, side, f);
991    }
992
993    #[test]
994    fn compute_inverse_values_for_sltu() {
995        let mut side = OperandSide::Lhs;
996
997        fn f(l: BitVector, r: BitVector) -> BitVector {
998            if l < r {
999                BitVector(1)
1000            } else {
1001                BitVector(0)
1002            }
1003        }
1004
1005        // test only for values which are actually invertible
1006        test_inverse_value_computation(SLTU, u64::max_value(), 0, side, f);
1007        test_inverse_value_computation(SLTU, 0, 0, side, f);
1008        test_inverse_value_computation(SLTU, 1, 1, side, f);
1009
1010        side = OperandSide::Rhs;
1011
1012        test_inverse_value_computation(SLTU, 0, 0, side, f);
1013        test_inverse_value_computation(SLTU, u64::max_value() - 1, 1, side, f);
1014        test_inverse_value_computation(SLTU, 1, 1, side, f);
1015    }
1016
1017    #[test]
1018    fn compute_inverse_values_for_divu() {
1019        fn f(l: BitVector, r: BitVector) -> BitVector {
1020            l / r
1021        }
1022
1023        // test only for values which are actually invertible
1024        test_inverse_value_computation(DIVU, 0b1, 0b1, OperandSide::Lhs, f);
1025        test_inverse_value_computation(DIVU, 0b1, 0b1, OperandSide::Rhs, f);
1026
1027        test_inverse_value_computation(DIVU, 2, 3, OperandSide::Lhs, f);
1028        test_inverse_value_computation(DIVU, 6, 2, OperandSide::Rhs, f);
1029    }
1030
1031    #[test]
1032    fn compute_inverse_values_for_remu() {
1033        fn f(l: BitVector, r: BitVector) -> BitVector {
1034            l % r
1035        }
1036
1037        // test only for values which are actually invertible
1038        test_inverse_value_computation(REMU, u64::max_value(), 0, OperandSide::Lhs, f);
1039        test_inverse_value_computation(
1040            REMU,
1041            u64::max_value() - 1,
1042            u64::max_value() - 1,
1043            OperandSide::Rhs,
1044            f,
1045        );
1046        test_inverse_value_computation(REMU, 3, 2, OperandSide::Lhs, f);
1047        test_inverse_value_computation(REMU, 5, 2, OperandSide::Rhs, f);
1048        test_inverse_value_computation(REMU, 3, 3, OperandSide::Rhs, f);
1049    }
1050
1051    #[test]
1052    fn compute_consistent_values_for_mul() {
1053        let side = OperandSide::Lhs;
1054
1055        fn f(l: BitVector, r: BitVector) -> BitVector {
1056            l * r
1057        }
1058
1059        // test only for values which actually have a consistent value
1060        test_consistent_value_computation(MUL, 0b110, side, f);
1061        test_consistent_value_computation(MUL, 0b101, side, f);
1062        test_consistent_value_computation(MUL, 0b11, side, f);
1063        test_consistent_value_computation(MUL, 0b100, side, f);
1064    }
1065
1066    #[test]
1067    fn compute_consistent_values_for_sltu() {
1068        let mut side = OperandSide::Lhs;
1069
1070        fn f(l: BitVector, r: BitVector) -> BitVector {
1071            if l < r {
1072                BitVector(1)
1073            } else {
1074                BitVector(0)
1075            }
1076        }
1077
1078        // test only for values which actually have a consistent value
1079        test_consistent_value_computation(SLTU, 0, side, f);
1080        test_consistent_value_computation(SLTU, 1, side, f);
1081
1082        side = OperandSide::Rhs;
1083
1084        // test only for values which actually have a consistent value
1085        test_consistent_value_computation(SLTU, 0, side, f);
1086        test_consistent_value_computation(SLTU, 1, side, f);
1087    }
1088}