monster/path_exploration/
shortest_path.rs

1use super::{ControlFlowGraph, ExplorationStrategy, ProcedureCallId};
2use anyhow::{Context as AnyhowContext, Result};
3use itertools::Itertools;
4use log::trace;
5use petgraph::{
6    algo::dijkstra,
7    dot::Dot,
8    graph::NodeIndex,
9    prelude::*,
10    visit::{EdgeRef, Reversed},
11};
12use riscu::{Instruction, Program, Register};
13use std::{collections::HashMap, fmt, fs::File, io::Write, path::Path};
14
/// Exploration strategy that prefers the branch with the smaller precomputed
/// distance value.
///
/// All data is computed once in `compute_for` and only read afterwards.
pub struct ShortestPathStrategy {
    /// Control flow graph built from the program.
    cfg: ControlFlowGraph,
    /// Distance value per CFG node, as returned by `compute_distances`.
    /// Nodes without an entry are treated as unreachable by `choose_path`.
    distance: HashMap<NodeIndex, u64>,
    /// Value of `program.code.address`; used to translate instruction
    /// addresses into CFG node indices.
    entry_address: u64,
}
20
21impl ShortestPathStrategy {
22    pub fn compute_for(program: &Program) -> Result<Self> {
23        let cfg = time_info!("generate CFG from binary", {
24            ControlFlowGraph::build_for(program)
25                .context("Could not build control flow graph from program")?
26        });
27
28        let distance = time_info!("computing shortest paths in CFG", {
29            compute_distances(&cfg)
30        });
31
32        Ok(Self {
33            cfg,
34            distance,
35            entry_address: program.code.address,
36        })
37    }
38
39    pub fn write_cfg_with_distances_to_file<P>(&self, path: P) -> Result<()>
40    where
41        P: AsRef<Path>,
42    {
43        File::create(path)
44            .and_then(|mut f| write!(f, "{:?}", self).and_then(|_| f.sync_all()))
45            .context("Failed to write control flow graph to file")?;
46
47        Ok(())
48    }
49
50    pub fn distances(&self) -> &HashMap<NodeIndex, u64> {
51        &self.distance
52    }
53
54    pub fn create_cfg_with_distances(
55        &self,
56    ) -> Graph<(Instruction, Option<u64>), Option<ProcedureCallId>> {
57        self.cfg
58            .graph
59            .map(|i, n| (*n, self.distance.get(&i).copied()), |_, e| *e)
60    }
61
62    fn address_to_cfg_idx(&self, address: u64) -> NodeIndex {
63        NodeIndex::new((address - self.entry_address) as usize / 4)
64    }
65}
66
67impl ExplorationStrategy for ShortestPathStrategy {
68    fn choose_path(&self, branch1: u64, branch2: u64) -> u64 {
69        let distance1 = self.distance.get(&self.address_to_cfg_idx(branch1));
70        let distance2 = self.distance.get(&self.address_to_cfg_idx(branch2));
71
72        trace!(
73            "branch distance: d1={:?}, d2={:?} |- choose smallest",
74            distance1,
75            distance2
76        );
77
78        match (distance1, distance2) {
79            (Some(distance1), Some(distance2)) => {
80                if distance1 > distance2 {
81                    branch2
82                } else {
83                    branch1
84                }
85            }
86            (Some(_), None) => branch1,
87            (None, Some(_)) => branch2,
88            _ => panic!(
89                "both branches {} and {} are not reachable!",
90                branch1, branch2
91            ),
92        }
93    }
94}
95
96pub fn compute_distances(cfg: &ControlFlowGraph) -> HashMap<NodeIndex, u64> {
97    let unrolled = time_info!("unrolling CFG", { compute_unrolled_cfg(cfg) });
98
99    let unrolled_reversed = Reversed(&unrolled);
100
101    let exit_node = unrolled
102        .node_indices()
103        .find(|i| unrolled.edges_directed(*i, Direction::Outgoing).count() == 0)
104        .expect("every valid CFG has to to have on exit node");
105
106    time_info!("computing distances from exit on unrolled CFG", {
107        let distances = dijkstra(unrolled_reversed, exit_node, None, |_| 1_u64);
108
109        let distances_for_idx = unrolled
110            .node_indices()
111            .filter_map(|i| distances.get(&i).map(|d| (unrolled[i], *d)))
112            .into_group_map();
113
114        distances_for_idx
115            .into_iter()
116            .filter_map(|(k, v)| v.into_iter().min().map(|min| (k, min)))
117            .collect::<HashMap<NodeIndex, u64>>()
118    })
119}
120
121impl fmt::Debug for ShortestPathStrategy {
122    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
123        let cfg_with_distances = self.create_cfg_with_distances();
124
125        let dot_graph = Dot::with_config(&cfg_with_distances, &[]);
126
127        write!(f, "{:?}", dot_graph)
128    }
129}
130
131impl fmt::Display for ShortestPathStrategy {
132    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
133        write!(f, "{:?}", self)
134    }
135}
136
/// Graph built by `compute_unrolled_cfg`: every node's weight is the index of
/// a node in the original control flow graph, edges carry no weight.
pub type UnrolledCfg = Graph<NodeIndex, ()>;
138
139pub fn compute_unrolled_cfg(cfg: &ControlFlowGraph) -> UnrolledCfg {
140    Context::new(cfg).compute_unrolled()
141}
142
/// Why a `Context` stopped traversing the control flow graph.
#[derive(Clone, Copy, Debug)]
enum ExitReason {
    /// A backward `jal` led to an already visited node, or a procedure call
    /// was found to be active on the context stack already.
    AlreadyVisited,
    /// A `jalr` instruction was reached.
    ProcedureReturn,
    /// An `ecall` instruction without outgoing edges was reached.
    ExitSyscall,
}
149
/// Traversal state used while unrolling a control flow graph.
///
/// A context is cloned at every `beq` and a new one is derived for every
/// procedure call; derived contexts are linked to their caller via `caller`.
#[derive(Clone)]
pub struct Context<'a> {
    /// Control flow graph that is being traversed.
    cfg: &'a ControlFlowGraph,
    /// CFG node the traversal is currently positioned at.
    idx: NodeIndex,
    /// Id of the procedure call this context is executing; `None` for the
    /// root context.
    id: Option<u64>,
    /// Maps visited CFG nodes to the node representing them in the unrolled graph.
    visited: HashMap<NodeIndex, NodeIndex>,
    /// Set once the traversal of this context has finished.
    exit_reason: Option<ExitReason>,
    /// Context that performed the procedure call which created this context;
    /// null for the root context.
    caller: *const Context<'a>,
}
159
160impl<'a> Context<'a> {
161    fn new(cfg: &'a ControlFlowGraph) -> Self {
162        Self {
163            cfg,
164            idx: NodeIndex::new(0),
165            id: None,
166            visited: HashMap::new(),
167            exit_reason: None,
168            caller: std::ptr::null(),
169        }
170    }
171
172    fn compute_unrolled(&mut self) -> UnrolledCfg {
173        let mut g = UnrolledCfg::new();
174
175        let n = g.add_node(self.idx);
176
177        self.visited.insert(self.idx, n);
178
179        self.traverse(&mut g);
180
181        g
182    }
183
184    fn next(&self) -> NodeIndex {
185        self.cfg
186            .graph
187            .neighbors_directed(self.idx, Direction::Outgoing)
188            .next()
189            .expect("instruction has a followup instruction")
190    }
191
192    fn visit_unsafe(&mut self, idx: NodeIndex, n: NodeIndex, g: &mut UnrolledCfg) {
193        let runtime_location = *self
194            .visited
195            .get(&self.idx)
196            .expect("current instruction has an runtime location at this point");
197
198        g.update_edge(runtime_location, n, ());
199
200        trace!(
201            "visit: id={:?}, idx={}, instr={:?}",
202            self.id,
203            idx.index(),
204            self.cfg.graph[idx]
205        );
206
207        self.visited.insert(idx, n);
208        self.idx = idx;
209    }
210
211    fn visit(&mut self, idx: NodeIndex, g: &mut UnrolledCfg) {
212        let n = self
213            .visited
214            .get(&idx)
215            .copied()
216            .unwrap_or_else(|| g.add_node(idx));
217
218        self.visit_unsafe(idx, n, g);
219    }
220
221    fn find_call_on_stack(
222        &self,
223        jal_idx: NodeIndex,
224        proc_entry_idx: NodeIndex,
225    ) -> Option<NodeIndex> {
226        unsafe {
227            let mut p: *const Context = self;
228
229            loop {
230                if (*p).caller.is_null() {
231                    break None;
232                } else if (*(*p).caller).idx == jal_idx {
233                    if let Some(proc_entry_node) = (*p).visited.get(&proc_entry_idx) {
234                        break Some(*proc_entry_node);
235                    }
236                } else {
237                    p = (*p).caller;
238                }
239            }
240        }
241    }
242
243    fn traverse(&mut self, g: &mut UnrolledCfg) {
244        let graph = &self.cfg.graph;
245
246        while self.exit_reason.is_none() {
247            match graph[self.idx] {
248                Instruction::Jal(jtype) if jtype.rd() != Register::Zero => {
249                    let jump_idx = self.next();
250
251                    if let Some(ProcedureCallId::Call(id)) = self
252                        .cfg
253                        .graph
254                        .edges_directed(self.idx, Direction::Outgoing)
255                        .next()
256                        .expect("A procedure call (jal) always has to have an outgoing edge")
257                        .weight()
258                    {
259                        if let Some(proc_entry_node) =
260                            self.find_call_on_stack(self.idx, self.next())
261                        {
262                            self.visit_unsafe(jump_idx, proc_entry_node, g);
263                            trace!("jal: (procedure) visited => exiting");
264
265                            self.exit_reason = Some(ExitReason::AlreadyVisited);
266                        } else {
267                            let mut other = self.clone();
268                            trace!("call {:p}: id={}", &other, *id);
269
270                            other.id = Some(*id);
271                            other.caller = self;
272
273                            other.visited = HashMap::new();
274                            other.visited.insert(
275                                other.idx,
276                                *self
277                                    .visited
278                                    .get(&self.idx)
279                                    .expect("has been visited already"),
280                            );
281
282                            other.visit(jump_idx, g);
283                            other.traverse(g);
284
285                            trace!("returned from function");
286
287                            match other.exit_reason {
288                                Some(ExitReason::ProcedureReturn) => {
289                                    self.idx = other.idx;
290                                    self.visited.insert(
291                                        self.idx,
292                                        *other.visited.get(&other.idx).expect(
293                                            "instruction (return) has to have an runtime location",
294                                        ),
295                                    );
296                                }
297                                Some(_) => {
298                                    self.idx = other.idx;
299                                    self.exit_reason = other.exit_reason;
300                                }
301                                _ => unreachable!("reason: {:?}", other.exit_reason),
302                            }
303                        }
304                    } else {
305                        panic!("this has to be a procedure call edge")
306                    }
307                }
308                Instruction::Jal(jtype) if jtype.rd() == Register::Zero && jtype.imm() <= 0 => {
309                    // end of while loop
310                    let jump_idx = self.next();
311
312                    if self.visited.contains_key(&jump_idx) {
313                        self.visit(jump_idx, g);
314                        trace!("jal: (loop) visited => exiting");
315                        self.exit_reason = Some(ExitReason::AlreadyVisited);
316                    } else {
317                        self.visit(jump_idx, g);
318                    }
319                }
320                Instruction::Jalr(_) => {
321                    let mut return_edges = graph.edges_directed(self.idx, Direction::Outgoing);
322
323                    let return_idx = return_edges
324                        .find(
325                            |e| matches!(e.weight(), Some(ProcedureCallId::Return(id)) if self.id == Some(*id)),
326                        )
327                        .expect("no matching jalr for jal of type procedure call found")
328                        .target();
329
330                    self.visit(return_idx, g);
331                    self.exit_reason = Some(ExitReason::ProcedureReturn);
332
333                    trace!("jalr: exiting");
334                }
335                Instruction::Beq(_) => {
336                    let mut neighbors = graph.neighbors_directed(self.idx, Direction::Outgoing);
337
338                    let first = neighbors.next().expect("BEQ creates 2 branches");
339                    let second = neighbors.next().expect("BEQ creates 2 branches");
340
341                    let mut other = self.clone();
342
343                    other.visit(first, g);
344                    other.traverse(g);
345
346                    self.visited = other.visited;
347                    self.visit(second, g);
348                    self.traverse(g);
349
350                    match (other.exit_reason, self.exit_reason) {
351                        (Some(ExitReason::ProcedureReturn), _) => {
352                            self.idx = other.idx;
353                            self.exit_reason = other.exit_reason;
354                        }
355                        (_, Some(ExitReason::ProcedureReturn)) => {}
356                        (Some(ExitReason::ExitSyscall), _) => {
357                            self.idx = other.idx;
358                            self.exit_reason = other.exit_reason;
359                        }
360
361                        (_, Some(ExitReason::ExitSyscall)) => {}
362                        (Some(ExitReason::AlreadyVisited), _) => {
363                            self.idx = other.idx;
364                            self.exit_reason = other.exit_reason;
365                        }
366
367                        (_, Some(ExitReason::AlreadyVisited)) => {}
368                        _ => panic!("can not return address of return"),
369                    };
370                }
371                Instruction::Ecall(_) => {
372                    if graph.edges_directed(self.idx, Direction::Outgoing).count() == 0 {
373                        trace!("ecall: (exit) => exiting");
374                        self.exit_reason = Some(ExitReason::ExitSyscall);
375                    } else {
376                        trace!("ecall: (not exit) => go on");
377                        self.visit(self.next(), g);
378                    }
379                }
380                _ => self.visit(self.next(), g),
381            };
382        }
383    }
384}