1#![allow(clippy::many_single_char_names)]
2#![allow(clippy::if_same_then_else)]
3#![allow(clippy::neg_cmp_op_on_partial_ord)]
4
5use super::{
6 Assignment, BVOperator, BitVector, Formula, OperandSide, Solver, SolverError, Symbol, SymbolId,
7};
8use divisors::get_divisors;
9use log::{log_enabled, trace, Level};
10use rand::{distributions::Uniform, random, seq::SliceRandom, thread_rng, Rng};
11use std::time::{Duration, Instant};
12
/// Bit-vector solver that searches for a satisfying assignment within a time budget.
///
/// The only state is the `timeout`, which is forwarded to the search routine on
/// every `solve_impl` call.
#[derive(Debug, Clone)]
pub struct MonsterSolver {
    /// Maximum wall-clock time a single solve attempt may take.
    timeout: Duration,
}
16
17impl Default for MonsterSolver {
18 fn default() -> Self {
19 Self::new(Duration::new(3, 0))
20 }
21}
22
23impl MonsterSolver {
24 pub fn new(timeout: Duration) -> Self {
25 Self { timeout }
26 }
27}
28
29impl Solver for MonsterSolver {
30 fn name() -> &'static str {
31 "Monster"
32 }
33
34 fn solve_impl<F: Formula>(&self, formula: &F) -> Result<Option<Assignment>, SolverError> {
35 let ab = initialize_ab(formula);
36
37 sat(formula, ab, self.timeout)
38 }
39}
40
41fn is_invertible(op: BVOperator, s: BitVector, t: BitVector, d: OperandSide) -> bool {
43 match op {
44 BVOperator::Add => true,
45 BVOperator::Sub => true,
46 BVOperator::Mul => (-s | s) & t == t,
47 BVOperator::Divu => match d {
48 OperandSide::Lhs => {
49 if s == BitVector(0) {
50 t == BitVector::ones()
51 } else {
52 !t.mulo(s)
53 }
54 }
55 OperandSide::Rhs => {
56 if t == BitVector(0) {
57 s != BitVector::ones()
58 } else {
59 t == BitVector::ones() || !(s < t)
60 }
61 }
62 },
63 BVOperator::Sltu => match d {
64 OperandSide::Lhs => {
65 if t != BitVector(0) {
66 !(s == BitVector(0))
67 } else {
68 true
69 }
70 }
71 OperandSide::Rhs => {
72 if t != BitVector(0) {
73 !(s == BitVector::ones())
74 } else {
75 true
76 }
77 }
78 },
79 BVOperator::Remu => match d {
80 OperandSide::Lhs => !(s <= t),
81 OperandSide::Rhs => {
82 if s == t {
83 true
84 } else {
85 !((s < t) || ((t != BitVector(0)) && t == s - BitVector(1)) || (s - t <= t))
86 }
87 }
88 },
89 BVOperator::Not => true,
90 BVOperator::BitwiseAnd => (t & s) == t,
91 BVOperator::Equals => true,
92 }
93}
94
95fn initialize_ab<F: Formula>(formula: &F) -> Vec<BitVector> {
98 let max_id = formula
100 .symbol_ids()
101 .max()
102 .expect("formula should not be empty");
103
104 let mut ab = Vec::with_capacity(std::mem::size_of::<BitVector>() * (max_id + 1));
105 unsafe {
106 ab.set_len(max_id + 1);
107 }
108
109 formula.symbol_ids().for_each(|i| {
110 ab[i] = match formula[i] {
111 Symbol::Constant(c) => c,
112 _ => BitVector(random::<u64>()),
113 };
114 });
115
116 if log_enabled!(Level::Trace) {
117 formula
118 .symbol_ids()
119 .filter(|i| matches!(formula[*i], Symbol::Input(_)))
120 .for_each(|i| {
121 trace!("initialize: x{} <- {:#x}", i, ab[i].0);
122 });
123 }
124
125 formula.symbol_ids().for_each(|i| match formula[i] {
127 Symbol::Input(_) | Symbol::Constant(_) => {
128 formula
129 .dependencies(i)
130 .for_each(|n| propagate_assignment(formula, &mut ab, n));
131 }
132 _ => {}
133 });
134
135 ab
136}
137
138fn select<F: Formula>(
142 f: &F,
143 idx: SymbolId,
144 t: BitVector,
145 ab: &[BitVector],
146) -> (SymbolId, SymbolId, OperandSide) {
147 if let (lhs, Some(rhs)) = f.operands(idx) {
148 fn is_constant<F: Formula>(f: &F, n: SymbolId) -> bool {
149 matches!(f[n], Symbol::Constant(_))
150 }
151
152 #[allow(clippy::if_same_then_else)]
153 if is_constant(f, lhs) {
154 (rhs, lhs, OperandSide::Rhs)
155 } else if is_constant(f, rhs) {
156 (lhs, rhs, OperandSide::Lhs)
157 } else if is_essential(f, lhs, OperandSide::Lhs, rhs, t, ab) {
158 (lhs, rhs, OperandSide::Lhs)
159 } else if is_essential(f, rhs, OperandSide::Rhs, lhs, t, ab) {
160 (rhs, lhs, OperandSide::Rhs)
161 } else if random() {
162 (rhs, lhs, OperandSide::Rhs)
163 } else {
164 (lhs, rhs, OperandSide::Lhs)
165 }
166 } else {
167 panic!("can only select path for binary operators")
168 }
169}
170
171fn compute_inverse_value(op: BVOperator, s: BitVector, t: BitVector, d: OperandSide) -> BitVector {
172 match op {
173 BVOperator::Add => t - s,
174 BVOperator::Sub => match d {
175 OperandSide::Lhs => t + s,
176 OperandSide::Rhs => s - t,
177 },
178 BVOperator::Mul => {
179 let y = s >> s.ctz();
180
181 let y_inv = y
182 .modinverse()
183 .expect("a modular inverse has to exist iff operator is invertible");
184
185 let result = (t >> s.ctz()) * y_inv;
186
187 let to_shift = 64 - s.ctz();
188
189 let arbitrary_bit_mask = if to_shift == 64 {
190 BitVector(0)
191 } else {
192 BitVector::ones() << to_shift
193 };
194
195 let arbitrary_bits = BitVector(random::<u64>()) & arbitrary_bit_mask;
196
197 result | arbitrary_bits
198 }
199 BVOperator::Sltu => match d {
200 OperandSide::Lhs => {
201 if t == BitVector(0) {
202 BitVector(thread_rng().sample(Uniform::new_inclusive(s.0, BitVector::ones().0)))
204 } else {
205 BitVector(thread_rng().sample(Uniform::new(0, s.0)))
207 }
208 }
209 OperandSide::Rhs => {
210 if t == BitVector(0) {
211 BitVector(thread_rng().sample(Uniform::new_inclusive(0, s.0)))
213 } else {
214 BitVector(
216 thread_rng().sample(Uniform::new_inclusive(s.0 + 1, BitVector::ones().0)),
217 )
218 }
219 }
220 },
221 BVOperator::Divu => match d {
222 OperandSide::Lhs => {
223 if (t == BitVector::ones()) && (s == BitVector(1)) {
224 BitVector::ones()
225 } else {
226 let range_start = t * s;
227 if range_start.0.overflowing_add(s.0 - 1).1 {
228 BitVector(
229 thread_rng()
230 .sample(Uniform::new_inclusive(range_start.0, u64::max_value())),
231 )
232 } else {
233 BitVector(thread_rng().sample(Uniform::new_inclusive(
234 range_start.0,
235 range_start.0 + (s.0 - 1),
236 )))
237 }
238 }
239 }
240 OperandSide::Rhs => {
241 if (t == s) && t == BitVector::ones() {
242 BitVector(thread_rng().sample(Uniform::new_inclusive(0, 1)))
243 } else if (t == BitVector::ones()) && (s != BitVector::ones()) {
244 BitVector(0)
245 } else {
246 s / t
247 }
248 }
249 },
250 BVOperator::Remu => match d {
251 OperandSide::Lhs => {
252 let y = BitVector(
253 thread_rng().sample(Uniform::new_inclusive(1, ((BitVector::ones() - t) / s).0)),
254 );
255 assert!(
257 !s.0.overflowing_mul(y.0).1,
258 "multiplication overflow in REMU inverse"
259 );
260 assert!(
261 !t.0.overflowing_add(y.0 * s.0).1,
262 "addition overflow in REMU inverse"
263 );
264 y * s + t
265 }
266 OperandSide::Rhs => {
267 if s == t {
268 let x = BitVector(
269 thread_rng().sample(Uniform::new_inclusive(t.0, BitVector::ones().0)),
270 );
271 if x == t {
272 BitVector(0)
273 } else {
274 x
275 }
276 } else {
277 let mut v = get_divisors(s.0 - t.0);
278 v.push(1);
279 v.push(s.0 - t.0);
280 v = v.into_iter().filter(|x| x > &t.0).collect();
281
282 BitVector(*v.choose(&mut rand::thread_rng()).unwrap())
283 }
284 }
285 },
286 BVOperator::BitwiseAnd => BitVector(random::<u64>()) | t,
287 BVOperator::Equals => {
288 if t == BitVector(0) {
289 loop {
290 let r = BitVector(random::<u64>());
291 if r != s {
292 break r;
293 }
294 }
295 } else {
296 s
297 }
298 }
299 _ => unreachable!("unknown operator or unary operator: {:?}", op),
300 }
301}
302
303fn compute_consistent_value(op: BVOperator, t: BitVector, d: OperandSide) -> BitVector {
304 match op {
305 BVOperator::Add | BVOperator::Sub | BVOperator::Equals => BitVector(random::<u64>()),
306 BVOperator::Mul => BitVector({
307 if t == BitVector(0) {
308 0
309 } else {
310 let mut r;
311 loop {
312 r = random::<u128>();
313 if r != 0 {
314 break;
315 }
316 }
317 if t.ctz() < r.trailing_zeros() {
318 r >>= r.trailing_zeros() - t.ctz();
319 }
320 assert!(t.ctz() >= r.trailing_zeros());
321 r as u64
322 }
323 }),
324 BVOperator::Divu => match d {
325 OperandSide::Lhs => {
326 if (t == BitVector::ones()) || (t == BitVector(0)) {
327 BitVector(thread_rng().sample(Uniform::new_inclusive(0, u64::max_value() - 1)))
328 } else {
329 let mut y = BitVector(0);
330 while !(y != BitVector(0)) && !(y.mulo(t)) {
331 y = BitVector(
332 thread_rng().sample(Uniform::new_inclusive(0, u64::max_value())),
333 );
334 }
335
336 y * t
337 }
338 }
339 OperandSide::Rhs => {
340 if t == BitVector::ones() {
341 BitVector(thread_rng().sample(Uniform::new_inclusive(0, 1)))
342 } else {
343 BitVector(
344 thread_rng().sample(Uniform::new_inclusive(0, u64::max_value() / t.0)),
345 )
346 }
347 }
348 },
349 BVOperator::Sltu => match d {
350 OperandSide::Lhs => {
351 if t == BitVector(0) {
352 BitVector(thread_rng().sample(Uniform::new_inclusive(0, BitVector::ones().0)))
354 } else {
355 BitVector(thread_rng().sample(Uniform::new(0, BitVector::ones().0)))
357 }
358 }
359 OperandSide::Rhs => {
360 if t == BitVector(0) {
361 BitVector(thread_rng().sample(Uniform::new_inclusive(0, BitVector::ones().0)))
363 } else {
364 BitVector(thread_rng().sample(Uniform::new(1, BitVector::ones().0)))
366 }
367 }
368 },
369 BVOperator::Remu => match d {
370 OperandSide::Lhs => {
371 if t == BitVector::ones() {
372 BitVector::ones()
373 } else {
374 BitVector(thread_rng().sample(Uniform::new_inclusive(t.0, BitVector::ones().0)))
375 }
376 }
377 OperandSide::Rhs => {
378 if t == BitVector::ones() {
379 BitVector(0)
380 } else {
381 BitVector(
382 thread_rng().sample(Uniform::new_inclusive(t.0 + 1, BitVector::ones().0)),
383 )
384 }
385 }
386 },
387 BVOperator::BitwiseAnd => BitVector(random::<u64>()) | t,
388 _ => unreachable!("unknown operator for consistent value: {:?}", op),
389 }
390}
391
392fn compute_inverse_value_for_unary_op(op: BVOperator, t: BitVector) -> BitVector {
393 match op {
394 BVOperator::Not => {
395 if t == BitVector(0) {
396 BitVector(1)
397 } else {
398 BitVector(0)
399 }
400 }
401 _ => unreachable!("not unary operator: {:?}", op),
402 }
403}
404
405const CHOOSE_INVERSE: f64 = 0.90;
406
407#[allow(clippy::too_many_arguments)]
409fn value<F: Formula>(
410 f: &F,
411 n: SymbolId,
412 ns: SymbolId,
413 side: OperandSide,
414 t: BitVector,
415 ab: &[BitVector],
416) -> BitVector {
417 let s = ab[ns];
418
419 match &f[n] {
420 Symbol::Operator(op) => {
421 let consistent = compute_consistent_value(*op, t, side);
422
423 if is_invertible(*op, s, t, side) {
424 let inverse = compute_inverse_value(*op, s, t, side);
425 let choose_inverse =
426 rand::thread_rng().gen_range(0.0_f64..=1.0_f64) < CHOOSE_INVERSE;
427
428 if choose_inverse {
429 inverse
430 } else {
431 consistent
432 }
433 } else {
434 consistent
435 }
436 }
437 _ => unimplemented!(),
438 }
439}
440
441fn is_essential<F: Formula>(
442 formula: &F,
443 this: SymbolId,
444 on_side: OperandSide,
445 other: SymbolId,
446 t: BitVector,
447 ab: &[BitVector],
448) -> bool {
449 let ab_nx = ab[this];
450
451 match &formula[other] {
452 Symbol::Operator(op) => !is_invertible(*op, ab_nx, t, on_side.other()),
453 Symbol::Constant(_) | Symbol::Input(_) => false,
455 }
456}
457
458fn update_assignment<F: Formula>(f: &F, ab: &mut Vec<BitVector>, n: SymbolId, v: BitVector) {
459 ab[n] = v;
460
461 assert!(
462 matches!(f[n], Symbol::Input(_)),
463 "only inputs can be assigned"
464 );
465
466 trace!("update: x{} <- {:#x}", n, v.0);
467
468 f.dependencies(n)
469 .for_each(|n| propagate_assignment(f, ab, n));
470}
471
472fn propagate_assignment<F: Formula>(f: &F, ab: &mut Vec<BitVector>, n: SymbolId) {
473 fn update_binary<F: Formula, Op>(f: &F, ab: &mut Vec<BitVector>, n: SymbolId, s: &str, op: Op)
474 where
475 Op: FnOnce(BitVector, BitVector) -> BitVector,
476 {
477 if let (lhs, Some(rhs)) = f.operands(n) {
478 let result = op(ab[lhs], ab[rhs]);
479
480 trace!(
481 "propagate: x{} := x{}({:#x}) {} x{}({:#x}) |- x{} <- {:#x}",
482 n,
483 lhs,
484 ab[lhs].0,
485 s,
486 rhs,
487 ab[rhs].0,
488 n,
489 result.0
490 );
491
492 ab[n] = result;
493 } else {
494 panic!("can not update binary operator with 1 operand")
495 }
496 }
497
498 fn update_unary<F: Formula, Op>(f: &F, ab: &mut Vec<BitVector>, n: SymbolId, s: &str, op: Op)
499 where
500 Op: FnOnce(BitVector) -> BitVector,
501 {
502 if let (p, None) = f.operands(n) {
503 let result = op(ab[p]);
504
505 trace!(
506 "propagate: x{} := {}x{}({:#x}) |- x{} <- {:#x}",
507 n,
508 s,
509 p,
510 ab[p].0,
511 n,
512 result.0
513 );
514
515 ab[n] = result;
516 } else {
517 panic!("can not update unary operator with more than one operand")
518 }
519 }
520
521 match &f[n] {
522 Symbol::Operator(op) => {
523 match op {
524 BVOperator::Add => update_binary(f, ab, n, "+", |l, r| l + r),
525 BVOperator::Sub => update_binary(f, ab, n, "-", |l, r| l - r),
526 BVOperator::Mul => update_binary(f, ab, n, "*", |l, r| l * r),
527 BVOperator::Divu => update_binary(f, ab, n, "/", |l, r| l / r),
528 BVOperator::BitwiseAnd => update_binary(f, ab, n, "&", |l, r| l & r),
529 BVOperator::Sltu => update_binary(f, ab, n, "<", |l, r| {
530 if l < r {
531 BitVector(1)
532 } else {
533 BitVector(0)
534 }
535 }),
536 BVOperator::Remu => update_binary(f, ab, n, "%", |l, r| l % r),
537 BVOperator::Equals => update_binary(f, ab, n, "=", |l, r| {
538 if l == r {
539 BitVector(1)
540 } else {
541 BitVector(0)
542 }
543 }),
544 BVOperator::Not => update_unary(f, ab, n, "!", |x| {
545 if x == BitVector(0) {
546 BitVector(1)
547 } else {
548 BitVector(0)
549 }
550 }),
551 }
552 f.dependencies(n)
553 .for_each(|n| propagate_assignment(f, ab, n));
555 }
556 _ => unreachable!(),
557 }
558}
559
560fn sat<F: Formula>(
562 formula: &F,
563 mut ab: Vec<BitVector>,
564 timeout_time: Duration,
565) -> Result<Option<Assignment>, SolverError> {
566 let mut iterations = 0;
567
568 let start_time = Instant::now();
569
570 let root = formula.root();
571
572 while ab[root] != BitVector(1) {
573 let mut n = root;
574 let mut t = BitVector(1);
575
576 iterations += 1;
577 trace!("search {}: x{} <- 0x1", iterations, root);
578
579 while !formula.is_operand(n) {
580 if start_time.elapsed() > timeout_time {
581 return Err(SolverError::Timeout);
582 }
583 let (v, nx) = match formula[n] {
584 Symbol::Operator(op) => {
585 if op.is_unary() {
586 let nx = formula.operand(n);
587
588 let v = compute_inverse_value_for_unary_op(op, t);
589
590 trace!(
591 "search {}: x{}({:#x}) = {}x{}({:#x}) |- x{} <- {:#x}",
592 iterations,
593 n,
594 t.0,
595 op,
596 nx,
597 ab[nx].0,
598 nx,
599 v.0
600 );
601
602 (v, nx)
603 } else {
604 let (nx, ns, side) = select(formula, n, t, &ab);
605
606 let v = value(formula, n, ns, side, t, &ab);
607
608 if log_enabled!(Level::Trace) {
609 let (lhs, rhs) = if side == OperandSide::Lhs {
610 (nx, ns)
611 } else {
612 (ns, nx)
613 };
614
615 trace!(
616 "search {}: x{}({:#x}) := x{}({:#x}) {} x{}({:#x}) |- x{} <- {:#x}",
617 iterations,
618 n,
619 t.0,
620 lhs,
621 ab[lhs].0,
622 op,
623 rhs,
624 ab[rhs].0,
625 nx,
626 v.0
627 );
628 }
629
630 (v, nx)
631 }
632 }
633 _ => panic!("non instruction node found"),
634 };
635
636 t = v;
637 n = nx;
638 }
639
640 update_assignment(formula, &mut ab, n, t);
641 }
642
643 let assignment: Assignment = formula.symbol_ids().map(|i| (i, ab[i])).collect();
644
645 Ok(Some(assignment))
646}
647
// Unit tests for the solver: two end-to-end solves on tiny data-flow graphs and
// checks of the invertibility conditions, inverse values and consistent values
// for MUL, SLTU, DIVU and REMU.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::engine::symbolic_state::{DataFlowGraph, FormulaView, SymbolicValue};
    use crate::solver::*;

    // Creates a graph that contains a single input node named "x0" and returns
    // the graph together with the handle of that node.
    fn create_data_flow_with_input() -> (DataFlowGraph, SymbolicValue) {
        let mut formula = DataFlowGraph::new();

        let input = Symbol::Input(String::from("x0"));
        let input_idx = formula.add_node(input);

        (formula, input_idx)
    }

    // Adds an `Equals` node comparing `to` (placed on side `on`) with a new
    // constant node holding `constant` (placed on the other side). Returns the
    // handle of the `Equals` node.
    fn add_equals_constraint(
        data_flow: &mut DataFlowGraph,
        to: SymbolicValue,
        on: OperandSide,
        constant: u64,
    ) -> SymbolicValue {
        let constrain = Symbol::Operator(BVOperator::Equals);
        let constrain_idx = data_flow.add_node(constrain);

        let constrain_c = Symbol::Constant(BitVector(constant));
        let constrain_c_idx = data_flow.add_node(constrain_c);

        data_flow.add_edge(to, constrain_idx, on);
        data_flow.add_edge(constrain_c_idx, constrain_idx, on.other());

        constrain_idx
    }

    // x0 == 10 must be solved with x0 = 10.
    #[test]
    fn solve_trivial_equals_constraint() {
        let (mut data_flow, input_idx) = create_data_flow_with_input();

        let root = add_equals_constraint(&mut data_flow, input_idx, OperandSide::Lhs, 10);

        let solver = MonsterSolver::default();
        let formula = FormulaView::new(&data_flow, root);
        let result = solver.solve(&formula);

        assert!(result.is_ok(), "solver did not time out");
        let unwrapped_result = result.unwrap();

        assert!(
            unwrapped_result.is_some(),
            "has result for trivial equals constraint"
        );
        assert_eq!(
            *unwrapped_result.unwrap().get(&input_idx.index()).unwrap(),
            BitVector(10),
            "solver result of trivial equal constrain has right value"
        );
    }

    // (x0 + 3) == 10 must be solved with x0 = 7.
    #[test]
    fn solve_bvadd() {
        let (mut data_flow, input_idx) = create_data_flow_with_input();

        let constant = Symbol::Constant(BitVector(3));
        let constant_idx = data_flow.add_node(constant);

        let instr = Symbol::Operator(BVOperator::Add);
        let instr_idx = data_flow.add_node(instr);

        data_flow.add_edge(input_idx, instr_idx, OperandSide::Lhs);
        data_flow.add_edge(constant_idx, instr_idx, OperandSide::Rhs);

        let root = add_equals_constraint(&mut data_flow, instr_idx, OperandSide::Lhs, 10);

        let solver = MonsterSolver::default();
        let formula = FormulaView::new(&data_flow, root);
        let result = solver.solve(&formula);

        assert!(result.is_ok(), "solver did not time out");
        let unwrapped_result = result.unwrap();

        assert!(unwrapped_result.is_some(), "has result for trivial add op");
        assert_eq!(
            *unwrapped_result.unwrap().get(&input_idx.index()).unwrap(),
            BitVector(7),
            "solver result of trivial add op has right value"
        );
    }

    // Asserts that `is_invertible(op, s, t, d)` equals `result`. The two arms
    // differ only in how the failure message renders the operand order.
    fn test_invertibility(
        op: BVOperator,
        s: u64,
        t: u64,
        d: OperandSide,
        result: bool,
        msg: &'static str,
    ) {
        let s = BitVector(s);
        let t = BitVector(t);

        match d {
            OperandSide::Lhs => {
                assert_eq!(
                    is_invertible(op, s, t, d),
                    result,
                    "x {:?} {:?} == {:?} {}",
                    op,
                    s,
                    t,
                    msg
                );
            }
            OperandSide::Rhs => {
                assert_eq!(
                    is_invertible(op, s, t, d),
                    result,
                    "{:?} {:?} x == {:?} {}",
                    s,
                    op,
                    t,
                    msg
                );
            }
        }
    }

    // Computes an inverse value for (op, s, t, d) and checks, using the
    // reference implementation `f` of the operator, that plugging it in on side
    // `d` really produces `t`.
    fn test_inverse_value_computation<F>(op: BVOperator, s: u64, t: u64, d: OperandSide, f: F)
    where
        F: FnOnce(BitVector, BitVector) -> BitVector,
    {
        let s = BitVector(s);
        let t = BitVector(t);

        let computed = compute_inverse_value(op, s, t, d);

        match d {
            OperandSide::Lhs => {
                assert_eq!(
                    f(computed, s),
                    t,
                    "{:?} {:?} {:?} == {:?}",
                    computed,
                    op,
                    s,
                    t
                );
            }
            OperandSide::Rhs => {
                assert_eq!(
                    f(s, computed),
                    t,
                    "{:?} {:?} {:?} == {:?}",
                    s,
                    op,
                    computed,
                    t
                );
            }
        }
    }

    // Computes a consistent value for (op, t, d), derives a partner operand for
    // it and checks with the reference implementation `f` that the pair yields
    // `t`. Only ADD, MUL, SLTU and DIVU are supported here.
    fn test_consistent_value_computation<F>(op: BVOperator, t: u64, d: OperandSide, f: F)
    where
        F: FnOnce(BitVector, BitVector) -> BitVector,
    {
        let t = BitVector(t);

        let computed = compute_consistent_value(op, t, d);

        // the partner operand: here the consistent value plays the role of `s`
        let inverse = match op {
            BVOperator::Add => t - computed,
            BVOperator::Mul => {
                assert!(
                    is_invertible(op, computed, t, d),
                    "choose values which are invertible..."
                );

                compute_inverse_value(op, computed, t, d)
            }
            BVOperator::Sltu => compute_inverse_value(op, computed, t, d),
            BVOperator::Divu => {
                assert!(is_invertible(op, computed, t, d));
                compute_inverse_value(op, computed, t, d)
            }
            _ => unimplemented!(),
        };

        if d == OperandSide::Lhs {
            assert_eq!(
                f(inverse, computed),
                t,
                "{:?} {:?} {:?} == {:?}",
                inverse,
                op,
                computed,
                t
            );
        } else {
            assert_eq!(
                f(computed, inverse),
                t,
                "{:?} {:?} {:?} == {:?}",
                computed,
                op,
                inverse,
                t
            );
        }
    }

    // Short aliases for the operators exercised below.
    const MUL: BVOperator = BVOperator::Mul;
    const SLTU: BVOperator = BVOperator::Sltu;
    const DIVU: BVOperator = BVOperator::Divu;
    const REMU: BVOperator = BVOperator::Remu;

    #[test]
    fn check_invertibility_condition_for_divu() {
        test_invertibility(DIVU, 0b1, 0b1, OperandSide::Lhs, true, "trivial divu");
        test_invertibility(DIVU, 0b1, 0b1, OperandSide::Rhs, true, "trivial divu");

        test_invertibility(DIVU, 3, 2, OperandSide::Lhs, true, "x / 3 = 2");
        test_invertibility(DIVU, 6, 2, OperandSide::Rhs, true, "6 / x = 2");

        test_invertibility(DIVU, 0, 2, OperandSide::Lhs, false, "x / 0 = 2");
        test_invertibility(DIVU, 0, 2, OperandSide::Rhs, false, "0 / x = 2");

        test_invertibility(DIVU, 5, 6, OperandSide::Rhs, false, "5 / x = 6");
    }

    #[test]
    fn check_invertibility_condition_for_mul() {
        let side = OperandSide::Lhs;

        test_invertibility(MUL, 0b1, 0b1, side, true, "trivial multiplication");
        test_invertibility(MUL, 0b10, 0b1, side, false, "operand bigger than result");
        // NOTE(review): the next two cases use identical inputs and differ only
        // in the message (which also contains the typo "invsere") — confirm
        // whether a different input was intended for one of them.
        test_invertibility(
            MUL,
            0b10,
            0b10,
            side,
            true,
            "operand with undetermined bits and possible invsere",
        );
        test_invertibility(
            MUL,
            0b10,
            0b10,
            side,
            true,
            "operand with undetermined bits and no inverse value",
        );
        test_invertibility(
            MUL,
            0b100,
            0b100,
            side,
            true,
            "operand with undetermined bits and no inverse value",
        );
        test_invertibility(
            MUL,
            0b10,
            0b1100,
            side,
            true,
            "operand with undetermined bits and no inverse value",
        );
    }

    #[test]
    fn check_invertibility_condition_for_sltu() {
        let mut side = OperandSide::Lhs;

        test_invertibility(SLTU, 0, 1, side, false, "x < 0 == 1 FALSE");
        test_invertibility(SLTU, 1, 1, side, true, "x < 1 == 1 TRUE");
        test_invertibility(
            SLTU,
            u64::max_value(),
            0,
            side,
            true,
            "x < max_value == 0 TRUE",
        );

        side = OperandSide::Rhs;

        test_invertibility(SLTU, 0, 1, side, true, "0 < x == 1 TRUE");
        test_invertibility(SLTU, 0, 0, side, true, "0 < x == 0 TRUE");
        test_invertibility(
            SLTU,
            u64::max_value(),
            1,
            side,
            false,
            "max_value < x == 1 FALSE",
        );
        test_invertibility(
            SLTU,
            u64::max_value(),
            0,
            side,
            true,
            "max_value < x == 0 TRUE",
        );
    }

    #[test]
    fn check_invertibility_condition_for_remu() {
        let mut side = OperandSide::Lhs;

        test_invertibility(REMU, 3, 2, side, true, "x mod 3 = 2 TRUE");
        test_invertibility(REMU, 3, 3, side, false, "x mod 3 = 3 FALSE");

        side = OperandSide::Rhs;

        test_invertibility(REMU, 3, 3, side, true, "3 mod x = 3 TRUE");
        test_invertibility(REMU, 3, 2, side, false, "3 mod x = 2 FALSE");
        test_invertibility(REMU, 5, 3, side, false, "5 mod x = 3 FALSE");
    }

    #[test]
    fn compute_inverse_values_for_mul() {
        let side = OperandSide::Lhs;

        // reference implementation of the operator
        fn f(l: BitVector, r: BitVector) -> BitVector {
            l * r
        }

        test_inverse_value_computation(MUL, 0b1, 0b1, side, f);
        test_inverse_value_computation(MUL, 0b10, 0b10, side, f);
        test_inverse_value_computation(MUL, 0b100, 0b100, side, f);
        test_inverse_value_computation(MUL, 0b10, 0b1100, side, f);
    }

    #[test]
    fn compute_inverse_values_for_sltu() {
        let mut side = OperandSide::Lhs;

        // reference implementation of the operator
        fn f(l: BitVector, r: BitVector) -> BitVector {
            if l < r {
                BitVector(1)
            } else {
                BitVector(0)
            }
        }

        test_inverse_value_computation(SLTU, u64::max_value(), 0, side, f);
        test_inverse_value_computation(SLTU, 0, 0, side, f);
        test_inverse_value_computation(SLTU, 1, 1, side, f);

        side = OperandSide::Rhs;

        test_inverse_value_computation(SLTU, 0, 0, side, f);
        test_inverse_value_computation(SLTU, u64::max_value() - 1, 1, side, f);
        test_inverse_value_computation(SLTU, 1, 1, side, f);
    }

    #[test]
    fn compute_inverse_values_for_divu() {
        // reference implementation of the operator
        fn f(l: BitVector, r: BitVector) -> BitVector {
            l / r
        }

        test_inverse_value_computation(DIVU, 0b1, 0b1, OperandSide::Lhs, f);
        test_inverse_value_computation(DIVU, 0b1, 0b1, OperandSide::Rhs, f);

        test_inverse_value_computation(DIVU, 2, 3, OperandSide::Lhs, f);
        test_inverse_value_computation(DIVU, 6, 2, OperandSide::Rhs, f);
    }

    #[test]
    fn compute_inverse_values_for_remu() {
        // reference implementation of the operator
        fn f(l: BitVector, r: BitVector) -> BitVector {
            l % r
        }

        test_inverse_value_computation(REMU, u64::max_value(), 0, OperandSide::Lhs, f);
        test_inverse_value_computation(
            REMU,
            u64::max_value() - 1,
            u64::max_value() - 1,
            OperandSide::Rhs,
            f,
        );
        test_inverse_value_computation(REMU, 3, 2, OperandSide::Lhs, f);
        test_inverse_value_computation(REMU, 5, 2, OperandSide::Rhs, f);
        test_inverse_value_computation(REMU, 3, 3, OperandSide::Rhs, f);
    }

    #[test]
    fn compute_consistent_values_for_mul() {
        let side = OperandSide::Lhs;

        // reference implementation of the operator
        fn f(l: BitVector, r: BitVector) -> BitVector {
            l * r
        }

        test_consistent_value_computation(MUL, 0b110, side, f);
        test_consistent_value_computation(MUL, 0b101, side, f);
        test_consistent_value_computation(MUL, 0b11, side, f);
        test_consistent_value_computation(MUL, 0b100, side, f);
    }

    #[test]
    fn compute_consistent_values_for_sltu() {
        let mut side = OperandSide::Lhs;

        // reference implementation of the operator
        fn f(l: BitVector, r: BitVector) -> BitVector {
            if l < r {
                BitVector(1)
            } else {
                BitVector(0)
            }
        }

        test_consistent_value_computation(SLTU, 0, side, f);
        test_consistent_value_computation(SLTU, 1, side, f);

        side = OperandSide::Rhs;

        test_consistent_value_computation(SLTU, 0, side, f);
        test_consistent_value_computation(SLTU, 1, side, f);
    }
}