1use crate::blocks::annotated::ExitExt;
2
3use etk_dasm::blocks::annotated::{AnnotatedBlock, Exit};
4
5use petgraph::dot::Dot;
6use petgraph::graph::{Graph, NodeIndex};
7
8use std::collections::BTreeMap;
9use std::convert::TryInto;
10use std::fmt;
11
12use z3::ast::{Ast, BV};
13use z3::SatResult;
14
15#[derive(Debug, Clone)]
16enum Node {
17 Terminate,
18 BadJump,
19 Block(Box<AnnotatedBlock>),
20}
21
22impl fmt::Display for Node {
23 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
24 let block = match self {
25 Self::Terminate => return write!(f, "<terminate>"),
26 Self::BadJump => return write!(f, "<bad-jump>"),
27 Self::Block(b) => b,
28 };
29
30 write!(f, "Offset: 0x{:x}", block.offset)
31 }
32}
33
/// Edge weight of the control flow graph; it carries no data.
struct Edge;

impl fmt::Display for Edge {
    /// Writes nothing, so an edge always formats as the empty string.
    fn fmt(&self, _formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        Ok(())
    }
}
41
42impl Node {
43 fn unwrap_block(&self) -> &AnnotatedBlock {
44 match self {
45 Self::Block(b) => b,
46 _ => panic!("not a block"),
47 }
48 }
49}
50
51pub struct ControlFlowGraph {
52 by_offset: BTreeMap<usize, NodeIndex>,
53 graph: Graph<Node, Edge>,
54}
55
56impl fmt::Debug for ControlFlowGraph {
57 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
58 #[derive(Debug)]
59 struct ControlFlowGraph<'a> {
60 #[allow(dead_code)]
61 by_offset: &'a BTreeMap<usize, NodeIndex>,
62 #[allow(dead_code)]
63 node_count: usize,
64 #[allow(dead_code)]
65 edge_count: usize,
66 }
67
68 let helper = ControlFlowGraph {
69 by_offset: &self.by_offset,
70 node_count: self.graph.node_count(),
71 edge_count: self.graph.edge_count(),
72 };
73
74 helper.fmt(f)
75 }
76}
77
78impl ControlFlowGraph {
79 pub fn new<I>(blocks: I) -> Self
80 where
81 I: Iterator<Item = AnnotatedBlock>,
82 {
83 let mut graph = Graph::<Node, Edge>::new();
84 let mut by_offset = BTreeMap::new();
85 let mut jump_targets = Vec::new();
86
87 let terminate = graph.add_node(Node::Terminate);
88 let bad_jump = graph.add_node(Node::BadJump);
89
90 for block in blocks {
91 let is_jump_target = block.jump_target;
92 let offset = block.offset;
93 let idx = graph.add_node(Node::Block(Box::new(block)));
94 let replaced = by_offset.insert(offset, idx);
95 assert_eq!(replaced, None);
96
97 if is_jump_target {
98 jump_targets.push(idx);
99 }
100 }
101
102 for idx in by_offset.values() {
103 let idx = *idx;
104 let node = &graph[idx];
105
106 let block = match node {
107 Node::Block(b) => b,
108 _ => continue,
109 };
110
111 let exit = block.exit.erase();
112
113 let mut fall_through_idx = None;
114
115 if let Some(fall_through) = block.exit.fall_through() {
117 let next = by_offset.get(&fall_through);
118 if let Some(next_idx) = next {
119 graph.add_edge(idx, *next_idx, Edge);
121 fall_through_idx = Some(next_idx);
122 } else {
123 graph.add_edge(idx, terminate, Edge);
126 }
127 };
128
129 match exit {
130 Exit::Unconditional(_) => (),
131 Exit::Branch { .. } => (),
132 Exit::Terminate => {
133 graph.add_edge(idx, terminate, Edge);
135 continue;
136 }
137 Exit::FallThrough(_) => {
138 continue;
140 }
141 }
142
143 graph.add_edge(idx, bad_jump, Edge);
145
146 for jump_target in jump_targets.iter() {
147 if Some(jump_target) == fall_through_idx {
148 continue;
150 }
151
152 graph.add_edge(idx, *jump_target, Edge);
154 }
155 }
156
157 Self { by_offset, graph }
158 }
159
160 fn shallow_block(&mut self, from: NodeIndex, to: NodeIndex) -> bool {
161 let from = self.graph[from].unwrap_block();
162 let to = self.graph[to].unwrap_block();
163
164 let config = z3::Config::new();
165 let context = z3::Context::new(&config);
166 let target = BV::from_u64(&context, to.offset.try_into().unwrap(), 256);
167
168 let ast = match from.exit.to_z3(&context) {
169 Exit::Terminate => unreachable!(),
170 Exit::FallThrough(f) => {
171 return f == to.offset;
172 }
173 Exit::Unconditional(u) => u,
174 Exit::Branch {
175 when_true,
176 when_false,
177 condition,
178 } => {
179 let when_false: u64 = when_false.try_into().unwrap();
180 let when_false = BV::from_u64(&context, when_false, 256);
181 let zero = BV::from_u64(&context, 0, 256);
182 condition._eq(&zero).ite(&when_false, &when_true)
183 }
184 };
185
186 let solver = z3::Solver::new(&context);
187 solver.assert(&ast._eq(&target));
188 let result = solver.check();
189
190 !matches!(result, SatResult::Unsat)
191 }
192
193 fn shallow_bad_jump(&mut self, from: NodeIndex) -> bool {
194 let from = self.graph[from].unwrap_block();
195
196 let config = z3::Config::new();
197 let context = z3::Context::new(&config);
198 let solver = z3::Solver::new(&context);
199
200 let ast = match from.exit.to_z3(&context) {
201 Exit::FallThrough(_) => return false,
202 Exit::Terminate => unreachable!(),
203 Exit::Unconditional(u) => u,
204 Exit::Branch {
205 when_true,
206 condition,
207 ..
208 } => {
209 let zero = BV::from_u64(&context, 0, 256);
210 solver.assert(&zero._eq(&condition).not());
211 when_true
212 }
213 };
214
215 for (offset, to_idx) in self.by_offset.iter() {
216 if !self.graph[*to_idx].unwrap_block().jump_target {
217 continue;
218 }
219
220 let bv = BV::from_u64(&context, (*offset).try_into().unwrap(), 256);
221 solver.assert(&bv._eq(&ast).not());
222 }
223
224 let result = solver.check();
225
226 !matches!(result, SatResult::Unsat)
227 }
228
229 fn shallow_terminate(&mut self, from: NodeIndex) -> bool {
230 let from = self.graph[from].unwrap_block();
231
232 let config = z3::Config::new();
233 let context = z3::Context::new(&config);
234 let solver = z3::Solver::new(&context);
235
236 let ast = match from.exit.to_z3(&context) {
237 Exit::FallThrough(_) => return true,
238 Exit::Terminate => return true,
239 Exit::Unconditional(_) => unreachable!(),
240 Exit::Branch { condition, .. } => {
241 let zero = BV::from_u64(&context, 0, 256);
242 zero._eq(&condition)
243 }
244 };
245
246 solver.assert(&ast);
247 let result = solver.check();
248
249 !matches!(result, SatResult::Unsat)
250 }
251
252 #[allow(clippy::needless_collect)]
254 pub fn refine_shallow(&mut self) {
255 let indexes: Vec<_> = self
256 .by_offset
257 .values()
258 .filter_map(|idx| {
259 let node = &self.graph[*idx];
260 match node {
261 Node::Block(_) => Some(*idx),
262 _ => None,
263 }
264 })
265 .collect();
266
267 for idx in indexes.into_iter() {
268 let neighbors_indexes: Vec<_> = self.graph.neighbors(idx).collect();
269
270 for neighbor_idx in neighbors_indexes.into_iter() {
271 let neighbor = &self.graph[neighbor_idx];
272
273 let keep = match neighbor {
274 Node::Block(_) => self.shallow_block(idx, neighbor_idx),
275 Node::BadJump => self.shallow_bad_jump(idx),
276 Node::Terminate => self.shallow_terminate(idx),
277 };
278
279 if !keep {
280 let edge = self.graph.find_edge(idx, neighbor_idx).unwrap();
281 self.graph.remove_edge(edge);
282 }
283 }
284 }
285 }
286
287 pub fn render(&self) -> impl '_ + fmt::Display {
288 Dot::new(&self.graph)
289 }
290}
291
292#[cfg(test)]
293mod tests {
294 use assert_matches::assert_matches;
295
296 use etk_asm::disasm::Disassembler;
297 use etk_asm::ingest::Ingest;
298
299 use etk_dasm::blocks::basic::Separator;
300
301 use super::*;
302
303 #[derive(Debug, Copy, Clone)]
304 enum N {
305 Offset(usize),
306 BadJump,
307 Terminate,
308 }
309
310 impl From<usize> for N {
311 fn from(offset: usize) -> Self {
312 Self::Offset(offset)
313 }
314 }
315
316 struct CfgTest<C, U> {
317 source: &'static str,
318 connected: C,
319 unconnected: U,
320 }
321
322 impl<C, U, Ci, Ui> CfgTest<C, U>
323 where
324 C: IntoIterator<Item = Ci>,
325 U: IntoIterator<Item = Ui>,
326 Ci: std::borrow::Borrow<(usize, N)>,
327 Ui: std::borrow::Borrow<(usize, N)>,
328 {
329 fn compile(&self) -> Disassembler {
330 let mut output = Disassembler::new();
331 Ingest::new(&mut output)
332 .ingest("./test", self.source)
333 .unwrap();
334 output
335 }
336
337 fn find_nodes(cfg: &ControlFlowGraph, from_off: usize, to_n: N) -> (NodeIndex, NodeIndex) {
338 let terminate_idx: NodeIndex = 0.into();
339 let bad_jump_idx: NodeIndex = 1.into();
340 assert_matches!(cfg.graph[terminate_idx], Node::Terminate);
341 assert_matches!(cfg.graph[bad_jump_idx], Node::BadJump);
342
343 let from_idx = cfg.by_offset[&from_off];
344 let to_idx = match to_n {
345 N::Offset(offset) => cfg.by_offset[&offset],
346 N::BadJump => bad_jump_idx,
347 N::Terminate => terminate_idx,
348 };
349
350 (from_idx, to_idx)
351 }
352
353 fn check(self) {
354 let mut program = self.compile();
355 let mut separator = Separator::new();
356
357 separator.push_all(program.ops());
358
359 let blocks = separator
360 .take()
361 .into_iter()
362 .chain(separator.finish().into_iter())
363 .map(|x| AnnotatedBlock::annotate(&x));
364
365 let mut cfg = ControlFlowGraph::new(blocks);
366 cfg.refine_shallow();
367
368 let connected = self
369 .connected
370 .into_iter()
371 .map(|x| x.borrow().clone())
372 .map(|(f, t)| Self::find_nodes(&cfg, f, t))
373 .map(|(f, t)| (f, t, true));
374
375 let unconnected = self
376 .unconnected
377 .into_iter()
378 .map(|x| x.borrow().clone())
379 .map(|(f, t)| Self::find_nodes(&cfg, f, t))
380 .map(|(f, t)| (f, t, false));
381
382 for (from_idx, to_idx, connected) in connected.chain(unconnected) {
383 let from = &cfg.graph[from_idx];
384 let to = &cfg.graph[to_idx];
385
386 let found = cfg.graph.find_edge(from_idx, to_idx).is_some();
387 if connected && !found {
388 panic!(
389 "edge between {} and {} was expected, but not found",
390 from, to,
391 );
392 } else if !connected && found {
393 panic!(
394 "edge between {} and {} was not expected, but was found",
395 from, to,
396 );
397 }
398 }
399 }
400 }
401
402 #[test]
403 fn empty() {
404 let source = "";
405
406 CfgTest {
407 source,
408 connected: &[],
409 unconnected: &[],
410 }
411 .check();
412 }
413
414 #[test]
415 fn just_stop() {
416 let source = "stop";
417
418 CfgTest {
419 source,
420 connected: &[(0, N::Terminate)],
421 unconnected: &[],
422 }
423 .check();
424 }
425
426 #[test]
427 fn just_pc() {
428 let source = "pc";
429
430 CfgTest {
431 source,
432 connected: &[(0, N::Terminate)],
433 unconnected: &[],
434 }
435 .check();
436 }
437
438 #[test]
439 fn just_bad_jump() {
440 let source = r#"
441 push1 0
442 jump
443 "#;
444
445 CfgTest {
446 source,
447 connected: &[(0, N::BadJump)],
448 unconnected: &[(0, N::Terminate), (0, N::Offset(0))],
449 }
450 .check();
451 }
452
453 #[test]
454 fn infinite_loop() {
455 let source = r#"
456 jumpdest
457 push1 0
458 jump
459 "#;
460
461 CfgTest {
462 source,
463 connected: &[(0, N::Offset(0))],
464 unconnected: &[(0, N::Terminate), (0, N::BadJump)],
465 }
466 .check();
467 }
468
469 #[test]
470 fn infinite_loop_with_branch() {
471 let source = r#"
472 jumpdest
473 push1 1
474 push1 0
475 jumpi
476 "#;
477
478 CfgTest {
479 source,
480 connected: &[(0, N::Offset(0))],
481 unconnected: &[(0, N::Terminate), (0, N::BadJump)],
482 }
483 .check();
484 }
485
486 #[test]
487 fn fallthrough_branch() {
488 let source = r#"
489 jumpdest
490 push1 0
491 push1 100
492 jumpi
493 "#;
494
495 CfgTest {
496 source,
497 connected: &[(0, N::Terminate)],
498 unconnected: &[(0, N::Offset(0)), (0, N::BadJump)],
499 }
500 .check();
501 }
502
503 #[test]
504 fn diamond_branch() {
505 let source = r#"
506 pc
507 calldataload
508 push1 target
509 jumpi
510
511 push1 exit
512 jump
513
514 target:
515 jumpdest
516 push1 exit
517 jump
518
519 exit:
520 jumpdest
521 "#;
522
523 CfgTest {
524 source,
525 connected: &[
526 (0, N::Offset(5)),
527 (0, N::Offset(8)),
528 (5, N::Offset(12)),
529 (8, N::Offset(12)),
530 (12, N::Terminate),
531 ],
532 unconnected: &[
533 (0, N::Offset(0)),
534 (0, N::Offset(12)),
535 (0, N::BadJump),
536 (0, N::Terminate),
537 (5, N::Offset(0)),
538 (5, N::Offset(5)),
539 (5, N::Offset(8)),
540 (5, N::BadJump),
541 (5, N::Terminate),
542 (8, N::Offset(0)),
543 (8, N::Offset(8)),
544 (8, N::Offset(5)),
545 (8, N::BadJump),
546 (8, N::Terminate),
547 (12, N::Offset(0)),
548 (12, N::Offset(8)),
549 (12, N::Offset(5)),
550 (12, N::Offset(12)),
551 (12, N::BadJump),
552 ],
553 }
554 .check();
555 }
556
557 #[test]
558 fn memory_jump() {
559 let source = r#"
560 push1 target
561 push1 0
562 mstore
563 push1 0
564 mload
565 jump
566
567 target:
568 jumpdest
569 "#;
570
571 CfgTest {
572 source,
573 connected: &[(0, N::Offset(9)), (9, N::Terminate)],
574 unconnected: &[
575 (0, N::Terminate),
578 (9, N::BadJump),
579 (9, N::Offset(0)),
580 ],
581 }
582 .check();
583 }
584
585 #[test]
586 fn storage_jump() {
587 let source = r#"
588 push1 target
589 push1 0
590 sstore
591 push1 0
592 sload
593 jump
594
595 target:
596 jumpdest
597 "#;
598
599 CfgTest {
600 source,
601 connected: &[(0, N::Offset(9)), (9, N::Terminate)],
602 unconnected: &[
603 (0, N::Terminate),
606 (9, N::BadJump),
607 (9, N::Offset(0)),
608 ],
609 }
610 .check();
611 }
612
613 #[test]
614 fn shr_branch() {
615 let source = r#"
616 push32 0x23b872dd00000000000000000000000000000000000000000000000000000000
617 push1 224
618 shr
619 push4 0x23b872dd
620 eq
621 push4 transfer_from
622 jumpi
623
624 stop
625
626 transfer_from:
627 jumpdest
628 stop
629 "#;
630
631 CfgTest {
632 source,
633 connected: &[
634 (0, N::Offset(0x31)),
635 (0x30, N::Terminate),
636 (0x31, N::Terminate),
637 ],
638 unconnected: &[
639 (0, N::Offset(0)),
640 (0, N::Offset(0x30)),
641 (0, N::BadJump),
642 (0, N::Terminate),
643 (0x30, N::Offset(0x30)),
644 (0x30, N::Offset(0x31)),
645 (0x30, N::Offset(0)),
646 (0x30, N::BadJump),
647 (0x31, N::Offset(0x31)),
648 (0x31, N::Offset(0x30)),
649 (0x31, N::Offset(0)),
650 (0x31, N::BadJump),
651 ],
652 }
653 .check();
654 }
655}