Skip to main content

etk_analyze/
cfg.rs

1use crate::blocks::annotated::ExitExt;
2
3use etk_dasm::blocks::annotated::{AnnotatedBlock, Exit};
4
5use petgraph::dot::Dot;
6use petgraph::graph::{Graph, NodeIndex};
7
8use std::collections::BTreeMap;
9use std::convert::TryInto;
10use std::fmt;
11
12use z3::ast::{Ast, BV};
13use z3::SatResult;
14
15#[derive(Debug, Clone)]
16enum Node {
17    Terminate,
18    BadJump,
19    Block(Box<AnnotatedBlock>),
20}
21
22impl fmt::Display for Node {
23    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
24        let block = match self {
25            Self::Terminate => return write!(f, "<terminate>"),
26            Self::BadJump => return write!(f, "<bad-jump>"),
27            Self::Block(b) => b,
28        };
29
30        write!(f, "Offset: 0x{:x}", block.offset)
31    }
32}
33
/// Edge weight used by the control flow graph; edges carry no data.
struct Edge;

impl fmt::Display for Edge {
    /// Edges are unlabelled, so nothing is written.
    fn fmt(&self, _formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        Ok(())
    }
}
41
42impl Node {
43    fn unwrap_block(&self) -> &AnnotatedBlock {
44        match self {
45            Self::Block(b) => b,
46            _ => panic!("not a block"),
47        }
48    }
49}
50
51pub struct ControlFlowGraph {
52    by_offset: BTreeMap<usize, NodeIndex>,
53    graph: Graph<Node, Edge>,
54}
55
56impl fmt::Debug for ControlFlowGraph {
57    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
58        #[derive(Debug)]
59        struct ControlFlowGraph<'a> {
60            #[allow(dead_code)]
61            by_offset: &'a BTreeMap<usize, NodeIndex>,
62            #[allow(dead_code)]
63            node_count: usize,
64            #[allow(dead_code)]
65            edge_count: usize,
66        }
67
68        let helper = ControlFlowGraph {
69            by_offset: &self.by_offset,
70            node_count: self.graph.node_count(),
71            edge_count: self.graph.edge_count(),
72        };
73
74        helper.fmt(f)
75    }
76}
77
78impl ControlFlowGraph {
79    pub fn new<I>(blocks: I) -> Self
80    where
81        I: Iterator<Item = AnnotatedBlock>,
82    {
83        let mut graph = Graph::<Node, Edge>::new();
84        let mut by_offset = BTreeMap::new();
85        let mut jump_targets = Vec::new();
86
87        let terminate = graph.add_node(Node::Terminate);
88        let bad_jump = graph.add_node(Node::BadJump);
89
90        for block in blocks {
91            let is_jump_target = block.jump_target;
92            let offset = block.offset;
93            let idx = graph.add_node(Node::Block(Box::new(block)));
94            let replaced = by_offset.insert(offset, idx);
95            assert_eq!(replaced, None);
96
97            if is_jump_target {
98                jump_targets.push(idx);
99            }
100        }
101
102        for idx in by_offset.values() {
103            let idx = *idx;
104            let node = &graph[idx];
105
106            let block = match node {
107                Node::Block(b) => b,
108                _ => continue,
109            };
110
111            let exit = block.exit.erase();
112
113            let mut fall_through_idx = None;
114
115            // Add edge to fall-through (aka the next instruction after this block.)
116            if let Some(fall_through) = block.exit.fall_through() {
117                let next = by_offset.get(&fall_through);
118                if let Some(next_idx) = next {
119                    // If the fallthrough matches a block, add an edge to it.
120                    graph.add_edge(idx, *next_idx, Edge);
121                    fall_through_idx = Some(next_idx);
122                } else {
123                    // If the fallthough doesn't match a block, add an edge to
124                    // <terminate>.
125                    graph.add_edge(idx, terminate, Edge);
126                }
127            };
128
129            match exit {
130                Exit::Unconditional(_) => (),
131                Exit::Branch { .. } => (),
132                Exit::Terminate => {
133                    // Terminate isn't a jump, so it can never be a bad one.
134                    graph.add_edge(idx, terminate, Edge);
135                    continue;
136                }
137                Exit::FallThrough(_) => {
138                    // Fallthrough is never a jump, so it can never be a bad one.
139                    continue;
140                }
141            }
142
143            // Assume all jumps can be bad.
144            graph.add_edge(idx, bad_jump, Edge);
145
146            for jump_target in jump_targets.iter() {
147                if Some(jump_target) == fall_through_idx {
148                    // Edge was added earlier.
149                    continue;
150                }
151
152                // Assume all jumps can go to any jump target.
153                graph.add_edge(idx, *jump_target, Edge);
154            }
155        }
156
157        Self { by_offset, graph }
158    }
159
160    fn shallow_block(&mut self, from: NodeIndex, to: NodeIndex) -> bool {
161        let from = self.graph[from].unwrap_block();
162        let to = self.graph[to].unwrap_block();
163
164        let config = z3::Config::new();
165        let context = z3::Context::new(&config);
166        let target = BV::from_u64(&context, to.offset.try_into().unwrap(), 256);
167
168        let ast = match from.exit.to_z3(&context) {
169            Exit::Terminate => unreachable!(),
170            Exit::FallThrough(f) => {
171                return f == to.offset;
172            }
173            Exit::Unconditional(u) => u,
174            Exit::Branch {
175                when_true,
176                when_false,
177                condition,
178            } => {
179                let when_false: u64 = when_false.try_into().unwrap();
180                let when_false = BV::from_u64(&context, when_false, 256);
181                let zero = BV::from_u64(&context, 0, 256);
182                condition._eq(&zero).ite(&when_false, &when_true)
183            }
184        };
185
186        let solver = z3::Solver::new(&context);
187        solver.assert(&ast._eq(&target));
188        let result = solver.check();
189
190        !matches!(result, SatResult::Unsat)
191    }
192
193    fn shallow_bad_jump(&mut self, from: NodeIndex) -> bool {
194        let from = self.graph[from].unwrap_block();
195
196        let config = z3::Config::new();
197        let context = z3::Context::new(&config);
198        let solver = z3::Solver::new(&context);
199
200        let ast = match from.exit.to_z3(&context) {
201            Exit::FallThrough(_) => return false,
202            Exit::Terminate => unreachable!(),
203            Exit::Unconditional(u) => u,
204            Exit::Branch {
205                when_true,
206                condition,
207                ..
208            } => {
209                let zero = BV::from_u64(&context, 0, 256);
210                solver.assert(&zero._eq(&condition).not());
211                when_true
212            }
213        };
214
215        for (offset, to_idx) in self.by_offset.iter() {
216            if !self.graph[*to_idx].unwrap_block().jump_target {
217                continue;
218            }
219
220            let bv = BV::from_u64(&context, (*offset).try_into().unwrap(), 256);
221            solver.assert(&bv._eq(&ast).not());
222        }
223
224        let result = solver.check();
225
226        !matches!(result, SatResult::Unsat)
227    }
228
229    fn shallow_terminate(&mut self, from: NodeIndex) -> bool {
230        let from = self.graph[from].unwrap_block();
231
232        let config = z3::Config::new();
233        let context = z3::Context::new(&config);
234        let solver = z3::Solver::new(&context);
235
236        let ast = match from.exit.to_z3(&context) {
237            Exit::FallThrough(_) => return true,
238            Exit::Terminate => return true,
239            Exit::Unconditional(_) => unreachable!(),
240            Exit::Branch { condition, .. } => {
241                let zero = BV::from_u64(&context, 0, 256);
242                zero._eq(&condition)
243            }
244        };
245
246        solver.assert(&ast);
247        let result = solver.check();
248
249        !matches!(result, SatResult::Unsat)
250    }
251
252    // https://github.com/rust-lang/rust-clippy/issues/6420
253    #[allow(clippy::needless_collect)]
254    pub fn refine_shallow(&mut self) {
255        let indexes: Vec<_> = self
256            .by_offset
257            .values()
258            .filter_map(|idx| {
259                let node = &self.graph[*idx];
260                match node {
261                    Node::Block(_) => Some(*idx),
262                    _ => None,
263                }
264            })
265            .collect();
266
267        for idx in indexes.into_iter() {
268            let neighbors_indexes: Vec<_> = self.graph.neighbors(idx).collect();
269
270            for neighbor_idx in neighbors_indexes.into_iter() {
271                let neighbor = &self.graph[neighbor_idx];
272
273                let keep = match neighbor {
274                    Node::Block(_) => self.shallow_block(idx, neighbor_idx),
275                    Node::BadJump => self.shallow_bad_jump(idx),
276                    Node::Terminate => self.shallow_terminate(idx),
277                };
278
279                if !keep {
280                    let edge = self.graph.find_edge(idx, neighbor_idx).unwrap();
281                    self.graph.remove_edge(edge);
282                }
283            }
284        }
285    }
286
287    pub fn render(&self) -> impl '_ + fmt::Display {
288        Dot::new(&self.graph)
289    }
290}
291
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;

    use etk_asm::disasm::Disassembler;
    use etk_asm::ingest::Ingest;

    use etk_dasm::blocks::basic::Separator;

    use super::*;

    /// Destination of an expected (or forbidden) edge.
    #[derive(Debug, Copy, Clone)]
    enum N {
        /// The block starting at this offset.
        Offset(usize),
        /// The synthetic `<bad-jump>` node.
        BadJump,
        /// The synthetic `<terminate>` node.
        Terminate,
    }

    impl From<usize> for N {
        fn from(offset: usize) -> Self {
            Self::Offset(offset)
        }
    }

    /// One CFG test case.
    ///
    /// `source` is assembled, split into blocks, turned into a graph, and
    /// refined with `refine_shallow`. Every `(from_offset, destination)` pair
    /// in `connected` must then be an edge, and every pair in `unconnected`
    /// must not be.
    struct CfgTest<C, U> {
        source: &'static str,
        connected: C,
        unconnected: U,
    }

    impl<C, U, Ci, Ui> CfgTest<C, U>
    where
        C: IntoIterator<Item = Ci>,
        U: IntoIterator<Item = Ui>,
        Ci: std::borrow::Borrow<(usize, N)>,
        Ui: std::borrow::Borrow<(usize, N)>,
    {
        /// Assembles `source`, feeding the result into a `Disassembler`.
        fn compile(&self) -> Disassembler {
            let mut output = Disassembler::new();
            Ingest::new(&mut output)
                .ingest("./test", self.source)
                .unwrap();
            output
        }

        /// Resolves a `(from_offset, destination)` pair to node indexes.
        ///
        /// Relies on `ControlFlowGraph::new` adding `<terminate>` and
        /// `<bad-jump>` as nodes 0 and 1; asserts that this holds.
        fn find_nodes(cfg: &ControlFlowGraph, from_off: usize, to_n: N) -> (NodeIndex, NodeIndex) {
            let terminate_idx: NodeIndex = 0.into();
            let bad_jump_idx: NodeIndex = 1.into();
            assert_matches!(cfg.graph[terminate_idx], Node::Terminate);
            assert_matches!(cfg.graph[bad_jump_idx], Node::BadJump);

            let from_idx = cfg.by_offset[&from_off];
            let to_idx = match to_n {
                N::Offset(offset) => cfg.by_offset[&offset],
                N::BadJump => bad_jump_idx,
                N::Terminate => terminate_idx,
            };

            (from_idx, to_idx)
        }

        /// Runs the test case, panicking on the first mismatched edge.
        fn check(self) {
            let mut program = self.compile();
            let mut separator = Separator::new();

            separator.push_all(program.ops());

            let blocks = separator
                .take()
                .into_iter()
                .chain(separator.finish().into_iter())
                .map(|x| AnnotatedBlock::annotate(&x));

            let mut cfg = ControlFlowGraph::new(blocks);
            cfg.refine_shallow();

            // `(usize, N)` is `Copy`, so dereference instead of cloning.
            let connected = self
                .connected
                .into_iter()
                .map(|x| *x.borrow())
                .map(|(f, t)| Self::find_nodes(&cfg, f, t))
                .map(|(f, t)| (f, t, true));

            let unconnected = self
                .unconnected
                .into_iter()
                .map(|x| *x.borrow())
                .map(|(f, t)| Self::find_nodes(&cfg, f, t))
                .map(|(f, t)| (f, t, false));

            for (from_idx, to_idx, connected) in connected.chain(unconnected) {
                let from = &cfg.graph[from_idx];
                let to = &cfg.graph[to_idx];

                let found = cfg.graph.find_edge(from_idx, to_idx).is_some();
                if connected && !found {
                    panic!(
                        "edge between {} and {} was expected, but not found",
                        from, to,
                    );
                } else if !connected && found {
                    panic!(
                        "edge between {} and {} was not expected, but was found",
                        from, to,
                    );
                }
            }
        }
    }

    #[test]
    fn empty() {
        let source = "";

        CfgTest {
            source,
            connected: &[],
            unconnected: &[],
        }
        .check();
    }

    #[test]
    fn debug_summary() {
        // Only the two synthetic nodes exist, and nothing links them.
        let cfg = ControlFlowGraph::new(std::iter::empty::<AnnotatedBlock>());
        assert_eq!(
            format!("{:?}", cfg),
            "ControlFlowGraph { by_offset: {}, node_count: 2, edge_count: 0 }",
        );
    }

    #[test]
    fn just_stop() {
        let source = "stop";

        CfgTest {
            source,
            connected: &[(0, N::Terminate)],
            unconnected: &[],
        }
        .check();
    }

    #[test]
    fn just_pc() {
        let source = "pc";

        CfgTest {
            source,
            connected: &[(0, N::Terminate)],
            unconnected: &[],
        }
        .check();
    }

    #[test]
    fn just_bad_jump() {
        let source = r#"
            push1 0
            jump
        "#;

        CfgTest {
            source,
            connected: &[(0, N::BadJump)],
            unconnected: &[(0, N::Terminate), (0, N::Offset(0))],
        }
        .check();
    }

    #[test]
    fn infinite_loop() {
        let source = r#"
            jumpdest
            push1 0
            jump
        "#;

        CfgTest {
            source,
            connected: &[(0, N::Offset(0))],
            unconnected: &[(0, N::Terminate), (0, N::BadJump)],
        }
        .check();
    }

    #[test]
    fn infinite_loop_with_branch() {
        let source = r#"
            jumpdest
            push1 1
            push1 0
            jumpi
        "#;

        CfgTest {
            source,
            connected: &[(0, N::Offset(0))],
            unconnected: &[(0, N::Terminate), (0, N::BadJump)],
        }
        .check();
    }

    #[test]
    fn fallthrough_branch() {
        let source = r#"
            jumpdest
            push1 0
            push1 100
            jumpi
        "#;

        CfgTest {
            source,
            connected: &[(0, N::Terminate)],
            unconnected: &[(0, N::Offset(0)), (0, N::BadJump)],
        }
        .check();
    }

    #[test]
    fn diamond_branch() {
        let source = r#"
            pc
            calldataload
            push1 target
            jumpi

            push1 exit
            jump

            target:
                jumpdest
                push1 exit
                jump

            exit:
                jumpdest
        "#;

        CfgTest {
            source,
            connected: &[
                (0, N::Offset(5)),
                (0, N::Offset(8)),
                (5, N::Offset(12)),
                (8, N::Offset(12)),
                (12, N::Terminate),
            ],
            unconnected: &[
                (0, N::Offset(0)),
                (0, N::Offset(12)),
                (0, N::BadJump),
                (0, N::Terminate),
                (5, N::Offset(0)),
                (5, N::Offset(5)),
                (5, N::Offset(8)),
                (5, N::BadJump),
                (5, N::Terminate),
                (8, N::Offset(0)),
                (8, N::Offset(8)),
                (8, N::Offset(5)),
                (8, N::BadJump),
                (8, N::Terminate),
                (12, N::Offset(0)),
                (12, N::Offset(8)),
                (12, N::Offset(5)),
                (12, N::Offset(12)),
                (12, N::BadJump),
            ],
        }
        .check();
    }

    #[test]
    fn memory_jump() {
        let source = r#"
            push1 target
            push1 0
            mstore
            push1 0
            mload
            jump

            target:
                jumpdest
        "#;

        CfgTest {
            source,
            connected: &[(0, N::Offset(9)), (9, N::Terminate)],
            unconnected: &[
                // TODO: Until the memory stuff is better, can't prove:
                // (0, N::BadJump),
                (0, N::Terminate),
                (9, N::BadJump),
                (9, N::Offset(0)),
            ],
        }
        .check();
    }

    #[test]
    fn storage_jump() {
        let source = r#"
            push1 target
            push1 0
            sstore
            push1 0
            sload
            jump

            target:
                jumpdest
        "#;

        CfgTest {
            source,
            connected: &[(0, N::Offset(9)), (9, N::Terminate)],
            unconnected: &[
                // TODO: Until the storage stuff is better, can't prove:
                // (0, N::BadJump),
                (0, N::Terminate),
                (9, N::BadJump),
                (9, N::Offset(0)),
            ],
        }
        .check();
    }

    #[test]
    fn shr_branch() {
        let source = r#"
            push32 0x23b872dd00000000000000000000000000000000000000000000000000000000
            push1 224
            shr
            push4 0x23b872dd
            eq
            push4 transfer_from
            jumpi

            stop

            transfer_from:
            jumpdest
            stop
        "#;

        CfgTest {
            source,
            connected: &[
                (0, N::Offset(0x31)),
                (0x30, N::Terminate),
                (0x31, N::Terminate),
            ],
            unconnected: &[
                (0, N::Offset(0)),
                (0, N::Offset(0x30)),
                (0, N::BadJump),
                (0, N::Terminate),
                (0x30, N::Offset(0x30)),
                (0x30, N::Offset(0x31)),
                (0x30, N::Offset(0)),
                (0x30, N::BadJump),
                (0x31, N::Offset(0x31)),
                (0x31, N::Offset(0x30)),
                (0x31, N::Offset(0)),
                (0x31, N::BadJump),
            ],
        }
        .check();
    }
}