monster/path_exploration/
shortest_path.rs1use super::{ControlFlowGraph, ExplorationStrategy, ProcedureCallId};
2use anyhow::{Context as AnyhowContext, Result};
3use itertools::Itertools;
4use log::trace;
5use petgraph::{
6 algo::dijkstra,
7 dot::Dot,
8 graph::NodeIndex,
9 prelude::*,
10 visit::{EdgeRef, Reversed},
11};
12use riscu::{Instruction, Program, Register};
13use std::{collections::HashMap, fmt, fs::File, io::Write, path::Path};
14
15pub struct ShortestPathStrategy {
16 cfg: ControlFlowGraph,
17 distance: HashMap<NodeIndex, u64>,
18 entry_address: u64,
19}
20
21impl ShortestPathStrategy {
22 pub fn compute_for(program: &Program) -> Result<Self> {
23 let cfg = time_info!("generate CFG from binary", {
24 ControlFlowGraph::build_for(program)
25 .context("Could not build control flow graph from program")?
26 });
27
28 let distance = time_info!("computing shortest paths in CFG", {
29 compute_distances(&cfg)
30 });
31
32 Ok(Self {
33 cfg,
34 distance,
35 entry_address: program.code.address,
36 })
37 }
38
39 pub fn write_cfg_with_distances_to_file<P>(&self, path: P) -> Result<()>
40 where
41 P: AsRef<Path>,
42 {
43 File::create(path)
44 .and_then(|mut f| write!(f, "{:?}", self).and_then(|_| f.sync_all()))
45 .context("Failed to write control flow graph to file")?;
46
47 Ok(())
48 }
49
50 pub fn distances(&self) -> &HashMap<NodeIndex, u64> {
51 &self.distance
52 }
53
54 pub fn create_cfg_with_distances(
55 &self,
56 ) -> Graph<(Instruction, Option<u64>), Option<ProcedureCallId>> {
57 self.cfg
58 .graph
59 .map(|i, n| (*n, self.distance.get(&i).copied()), |_, e| *e)
60 }
61
62 fn address_to_cfg_idx(&self, address: u64) -> NodeIndex {
63 NodeIndex::new((address - self.entry_address) as usize / 4)
64 }
65}
66
67impl ExplorationStrategy for ShortestPathStrategy {
68 fn choose_path(&self, branch1: u64, branch2: u64) -> u64 {
69 let distance1 = self.distance.get(&self.address_to_cfg_idx(branch1));
70 let distance2 = self.distance.get(&self.address_to_cfg_idx(branch2));
71
72 trace!(
73 "branch distance: d1={:?}, d2={:?} |- choose smallest",
74 distance1,
75 distance2
76 );
77
78 match (distance1, distance2) {
79 (Some(distance1), Some(distance2)) => {
80 if distance1 > distance2 {
81 branch2
82 } else {
83 branch1
84 }
85 }
86 (Some(_), None) => branch1,
87 (None, Some(_)) => branch2,
88 _ => panic!(
89 "both branches {} and {} are not reachable!",
90 branch1, branch2
91 ),
92 }
93 }
94}
95
96pub fn compute_distances(cfg: &ControlFlowGraph) -> HashMap<NodeIndex, u64> {
97 let unrolled = time_info!("unrolling CFG", { compute_unrolled_cfg(cfg) });
98
99 let unrolled_reversed = Reversed(&unrolled);
100
101 let exit_node = unrolled
102 .node_indices()
103 .find(|i| unrolled.edges_directed(*i, Direction::Outgoing).count() == 0)
104 .expect("every valid CFG has to to have on exit node");
105
106 time_info!("computing distances from exit on unrolled CFG", {
107 let distances = dijkstra(unrolled_reversed, exit_node, None, |_| 1_u64);
108
109 let distances_for_idx = unrolled
110 .node_indices()
111 .filter_map(|i| distances.get(&i).map(|d| (unrolled[i], *d)))
112 .into_group_map();
113
114 distances_for_idx
115 .into_iter()
116 .filter_map(|(k, v)| v.into_iter().min().map(|min| (k, min)))
117 .collect::<HashMap<NodeIndex, u64>>()
118 })
119}
120
121impl fmt::Debug for ShortestPathStrategy {
122 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
123 let cfg_with_distances = self.create_cfg_with_distances();
124
125 let dot_graph = Dot::with_config(&cfg_with_distances, &[]);
126
127 write!(f, "{:?}", dot_graph)
128 }
129}
130
131impl fmt::Display for ShortestPathStrategy {
132 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
133 write!(f, "{:?}", self)
134 }
135}
136
137pub type UnrolledCfg = Graph<NodeIndex, ()>;
138
139pub fn compute_unrolled_cfg(cfg: &ControlFlowGraph) -> UnrolledCfg {
140 Context::new(cfg).compute_unrolled()
141}
142
/// Why the traversal of an unrolling context stopped.
#[derive(Debug, Clone, Copy)]
enum ExitReason {
    /// The traversal was linked back to a node that already exists
    /// (a backward jump or a recursive call).
    AlreadyVisited,
    /// A `jalr` returned from the procedure call being traversed.
    ProcedureReturn,
    /// An `ecall` without successors in the CFG was reached.
    ExitSyscall,
}
149
150#[derive(Clone)]
151pub struct Context<'a> {
152 cfg: &'a ControlFlowGraph,
153 idx: NodeIndex,
154 id: Option<u64>,
155 visited: HashMap<NodeIndex, NodeIndex>,
156 exit_reason: Option<ExitReason>,
157 caller: *const Context<'a>,
158}
159
160impl<'a> Context<'a> {
161 fn new(cfg: &'a ControlFlowGraph) -> Self {
162 Self {
163 cfg,
164 idx: NodeIndex::new(0),
165 id: None,
166 visited: HashMap::new(),
167 exit_reason: None,
168 caller: std::ptr::null(),
169 }
170 }
171
172 fn compute_unrolled(&mut self) -> UnrolledCfg {
173 let mut g = UnrolledCfg::new();
174
175 let n = g.add_node(self.idx);
176
177 self.visited.insert(self.idx, n);
178
179 self.traverse(&mut g);
180
181 g
182 }
183
184 fn next(&self) -> NodeIndex {
185 self.cfg
186 .graph
187 .neighbors_directed(self.idx, Direction::Outgoing)
188 .next()
189 .expect("instruction has a followup instruction")
190 }
191
192 fn visit_unsafe(&mut self, idx: NodeIndex, n: NodeIndex, g: &mut UnrolledCfg) {
193 let runtime_location = *self
194 .visited
195 .get(&self.idx)
196 .expect("current instruction has an runtime location at this point");
197
198 g.update_edge(runtime_location, n, ());
199
200 trace!(
201 "visit: id={:?}, idx={}, instr={:?}",
202 self.id,
203 idx.index(),
204 self.cfg.graph[idx]
205 );
206
207 self.visited.insert(idx, n);
208 self.idx = idx;
209 }
210
211 fn visit(&mut self, idx: NodeIndex, g: &mut UnrolledCfg) {
212 let n = self
213 .visited
214 .get(&idx)
215 .copied()
216 .unwrap_or_else(|| g.add_node(idx));
217
218 self.visit_unsafe(idx, n, g);
219 }
220
221 fn find_call_on_stack(
222 &self,
223 jal_idx: NodeIndex,
224 proc_entry_idx: NodeIndex,
225 ) -> Option<NodeIndex> {
226 unsafe {
227 let mut p: *const Context = self;
228
229 loop {
230 if (*p).caller.is_null() {
231 break None;
232 } else if (*(*p).caller).idx == jal_idx {
233 if let Some(proc_entry_node) = (*p).visited.get(&proc_entry_idx) {
234 break Some(*proc_entry_node);
235 }
236 } else {
237 p = (*p).caller;
238 }
239 }
240 }
241 }
242
243 fn traverse(&mut self, g: &mut UnrolledCfg) {
244 let graph = &self.cfg.graph;
245
246 while self.exit_reason.is_none() {
247 match graph[self.idx] {
248 Instruction::Jal(jtype) if jtype.rd() != Register::Zero => {
249 let jump_idx = self.next();
250
251 if let Some(ProcedureCallId::Call(id)) = self
252 .cfg
253 .graph
254 .edges_directed(self.idx, Direction::Outgoing)
255 .next()
256 .expect("A procedure call (jal) always has to have an outgoing edge")
257 .weight()
258 {
259 if let Some(proc_entry_node) =
260 self.find_call_on_stack(self.idx, self.next())
261 {
262 self.visit_unsafe(jump_idx, proc_entry_node, g);
263 trace!("jal: (procedure) visited => exiting");
264
265 self.exit_reason = Some(ExitReason::AlreadyVisited);
266 } else {
267 let mut other = self.clone();
268 trace!("call {:p}: id={}", &other, *id);
269
270 other.id = Some(*id);
271 other.caller = self;
272
273 other.visited = HashMap::new();
274 other.visited.insert(
275 other.idx,
276 *self
277 .visited
278 .get(&self.idx)
279 .expect("has been visited already"),
280 );
281
282 other.visit(jump_idx, g);
283 other.traverse(g);
284
285 trace!("returned from function");
286
287 match other.exit_reason {
288 Some(ExitReason::ProcedureReturn) => {
289 self.idx = other.idx;
290 self.visited.insert(
291 self.idx,
292 *other.visited.get(&other.idx).expect(
293 "instruction (return) has to have an runtime location",
294 ),
295 );
296 }
297 Some(_) => {
298 self.idx = other.idx;
299 self.exit_reason = other.exit_reason;
300 }
301 _ => unreachable!("reason: {:?}", other.exit_reason),
302 }
303 }
304 } else {
305 panic!("this has to be a procedure call edge")
306 }
307 }
308 Instruction::Jal(jtype) if jtype.rd() == Register::Zero && jtype.imm() <= 0 => {
309 let jump_idx = self.next();
311
312 if self.visited.contains_key(&jump_idx) {
313 self.visit(jump_idx, g);
314 trace!("jal: (loop) visited => exiting");
315 self.exit_reason = Some(ExitReason::AlreadyVisited);
316 } else {
317 self.visit(jump_idx, g);
318 }
319 }
320 Instruction::Jalr(_) => {
321 let mut return_edges = graph.edges_directed(self.idx, Direction::Outgoing);
322
323 let return_idx = return_edges
324 .find(
325 |e| matches!(e.weight(), Some(ProcedureCallId::Return(id)) if self.id == Some(*id)),
326 )
327 .expect("no matching jalr for jal of type procedure call found")
328 .target();
329
330 self.visit(return_idx, g);
331 self.exit_reason = Some(ExitReason::ProcedureReturn);
332
333 trace!("jalr: exiting");
334 }
335 Instruction::Beq(_) => {
336 let mut neighbors = graph.neighbors_directed(self.idx, Direction::Outgoing);
337
338 let first = neighbors.next().expect("BEQ creates 2 branches");
339 let second = neighbors.next().expect("BEQ creates 2 branches");
340
341 let mut other = self.clone();
342
343 other.visit(first, g);
344 other.traverse(g);
345
346 self.visited = other.visited;
347 self.visit(second, g);
348 self.traverse(g);
349
350 match (other.exit_reason, self.exit_reason) {
351 (Some(ExitReason::ProcedureReturn), _) => {
352 self.idx = other.idx;
353 self.exit_reason = other.exit_reason;
354 }
355 (_, Some(ExitReason::ProcedureReturn)) => {}
356 (Some(ExitReason::ExitSyscall), _) => {
357 self.idx = other.idx;
358 self.exit_reason = other.exit_reason;
359 }
360
361 (_, Some(ExitReason::ExitSyscall)) => {}
362 (Some(ExitReason::AlreadyVisited), _) => {
363 self.idx = other.idx;
364 self.exit_reason = other.exit_reason;
365 }
366
367 (_, Some(ExitReason::AlreadyVisited)) => {}
368 _ => panic!("can not return address of return"),
369 };
370 }
371 Instruction::Ecall(_) => {
372 if graph.edges_directed(self.idx, Direction::Outgoing).count() == 0 {
373 trace!("ecall: (exit) => exiting");
374 self.exit_reason = Some(ExitReason::ExitSyscall);
375 } else {
376 trace!("ecall: (not exit) => go on");
377 self.visit(self.next(), g);
378 }
379 }
380 _ => self.visit(self.next(), g),
381 };
382 }
383 }
384}