isla_lib/ir/
linearize.rs

1// BSD 2-Clause License
2//
3// Copyright (c) 2020 Alasdair Armstrong
4//
5// All rights reserved.
6//
7// Redistribution and use in source and binary forms, with or without
8// modification, are permitted provided that the following conditions are
9// met:
10//
11// 1. Redistributions of source code must retain the above copyright
12// notice, this list of conditions and the following disclaimer.
13//
14// 2. Redistributions in binary form must reproduce the above copyright
15// notice, this list of conditions and the following disclaimer in the
16// documentation and/or other materials provided with the distribution.
17//
18// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
19// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
20// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
21// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
22// HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
23// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
24// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
25// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
26// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
27// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
//! This module provides a function [linearize()] that converts IR
//! from function bodies containing branches and other control flow
//! into a linear sequence of instructions without any control flow.
33//!
34//! The way this works is as follows:
35//!
36//! ```text
37//!     A    A: declare x; if b ...
38//!    / \   B: then { x = f(x) }
39//!   B   C  C: else { x = g(x) }
40//!    \ /   D: return x
41//!     D
42//! ```
43//!
44//! This is then converted into SSA form, like:
45//!
46//! ```text
47//!     A    A: declare x/1; if b
48//!    / \   B: then { x/2 = f(x/1) }
49//!   B   C  C: else { x/3 = g(x/1) }
50//!    \ /   D: x/4 = φ(x/2, x/3); return x/4
51//!     D
52//! ```
53//!
54//! Finally, we come out of SSA form by placing the control flow graph
55//! into topological order, and replacing the phi functions with `ite`
56//! functions that map directly to the `ite` construct in the SMT
57//! solver.
58//!
59//! ```text
60//!    A     A: declare x/1;
61//!    |     B: declare x/2;
62//!    B        x/2 = f(x/1);
63//!    |     C: declare x/3;
64//!    C        x/3 = g(x/1);
65//!    |     D: declare x/4;
66//!    D        x/4 = ite(b, x/2, x/3);
67//!             return x/4
68//! ```
69//!
70//! The obvious limitations of this are that the function in question
71//! needs to be pure (it can only read architectural state), and its
72//! control flow graph must be acyclic so it can be placed into a
73//! topological order.
74
75use petgraph::algo;
76use petgraph::graph::{EdgeIndex, NodeIndex};
77use petgraph::Direction;
78use std::cmp;
79use std::ops::{BitAnd, BitOr};
80
81use super::ssa::{unssa_ty, BlockInstr, BlockLoc, Edge, SSAName, Terminator, CFG};
82use super::*;
83use crate::config::ISAConfig;
84use crate::primop::{binary_primops, variadic_primops};
85
/// The reachability of a node in an SSA graph is determined by a
/// boolean formula over edges which can be taken to reach that node.
///
/// Formulas are built with the `|` and `&` operators implemented for
/// this type, and turned into an IR expression with `Reachability::exp`.
#[derive(Clone)]
enum Reachability {
    /// The constant true formula (used for the root of the graph).
    True,
    /// The constant false formula.
    False,
    /// Holds when the given control flow graph edge is taken.
    Edge(EdgeIndex),
    /// Conjunction of two formulas.
    And(Box<Reachability>, Box<Reachability>),
    /// Disjunction of two formulas.
    Or(Box<Reachability>, Box<Reachability>),
}
96
97fn terminator_reachability_exp(terminator: &Terminator, edge: &Edge) -> Exp<SSAName> {
98    match (terminator, edge) {
99        (Terminator::Continue, Edge::Continue) => Exp::Bool(true),
100        (Terminator::Goto(_), Edge::Goto) => Exp::Bool(true),
101        (Terminator::Jump(exp, _, _), Edge::Jump(true)) => exp.clone(),
102        (Terminator::Jump(exp, _, _), Edge::Jump(false)) => Exp::Call(Op::Not, vec![exp.clone()]),
103        (_, _) => panic!("Bad terminator/edge pair in SSA"),
104    }
105}
106
107impl Reachability {
108    fn exp<B: BV>(&self, cfg: &CFG<B>) -> Exp<SSAName> {
109        use Reachability::*;
110        match self {
111            True => Exp::Bool(true),
112            False => Exp::Bool(false),
113            Edge(edge) => {
114                if let Some((pred, _)) = cfg.graph.edge_endpoints(*edge) {
115                    terminator_reachability_exp(&cfg.graph[pred].terminator, &cfg.graph[*edge])
116                } else {
117                    panic!("Edge in reachability condition does not exist!")
118                }
119            }
120            And(lhs, rhs) => Exp::Call(Op::And, vec![lhs.exp(cfg), rhs.exp(cfg)]),
121            Or(lhs, rhs) => Exp::Call(Op::Or, vec![lhs.exp(cfg), rhs.exp(cfg)]),
122        }
123    }
124}
125
126impl BitOr for Reachability {
127    type Output = Self;
128
129    fn bitor(self, rhs: Self) -> Self::Output {
130        use Reachability::*;
131        match (self, rhs) {
132            (True, _) => True,
133            (_, True) => True,
134            (False, rhs) => rhs,
135            (lhs, False) => lhs,
136            (lhs, rhs) => Or(Box::new(lhs), Box::new(rhs)),
137        }
138    }
139}
140
141impl BitAnd for Reachability {
142    type Output = Self;
143
144    fn bitand(self, rhs: Self) -> Self::Output {
145        use Reachability::*;
146        match (self, rhs) {
147            (True, rhs) => rhs,
148            (lhs, True) => lhs,
149            (False, _) => False,
150            (_, False) => False,
151            (lhs, rhs) => And(Box::new(lhs), Box::new(rhs)),
152        }
153    }
154}
155
156/// Computes the reachability condition for each node in an acyclic graph.
157fn compute_reachability<B: BV>(cfg: &CFG<B>, topo_order: &[NodeIndex]) -> HashMap<NodeIndex, Reachability> {
158    let mut reachability: HashMap<NodeIndex, Reachability> = HashMap::new();
159
160    for ix in topo_order {
161        let mut r = if *ix == cfg.root { Reachability::True } else { Reachability::False };
162
163        for pred in cfg.graph.neighbors_directed(*ix, Direction::Incoming) {
164            let edge = cfg.graph.find_edge(pred, *ix).unwrap();
165            let (pred, _) = cfg.graph.edge_endpoints(edge).unwrap();
166            let pred_r = reachability.get(&pred).unwrap().clone();
167            r = r | (pred_r & Reachability::Edge(edge))
168        }
169
170        reachability.insert(*ix, r);
171    }
172
173    reachability
174}
175
176fn unssa_loc(loc: &BlockLoc, symtab: &mut Symtab, names: &mut HashMap<SSAName, Name>) -> Loc<Name> {
177    use Loc::*;
178    match loc {
179        BlockLoc::Id(id) => Id(id.unssa(symtab, names)),
180        BlockLoc::Field(loc, _, field) => Field(Box::new(unssa_loc(loc, symtab, names)), field.unssa(symtab, names)),
181        BlockLoc::Addr(loc) => Addr(Box::new(unssa_loc(loc, symtab, names))),
182    }
183}
184
185fn unssa_exp(exp: &Exp<SSAName>, symtab: &mut Symtab, names: &mut HashMap<SSAName, Name>) -> Exp<Name> {
186    use Exp::*;
187    match exp {
188        Id(id) => Id(id.unssa(symtab, names)),
189        Ref(r) => Ref(r.unssa(symtab, names)),
190        Bool(b) => Bool(*b),
191        Bits(bv) => Bits(*bv),
192        String(s) => String(s.clone()),
193        Unit => Unit,
194        I64(n) => I64(*n),
195        I128(n) => I128(*n),
196        Undefined(ty) => Undefined(unssa_ty(ty)),
197        Struct(s, fields) => Struct(
198            s.unssa(symtab, names),
199            fields.iter().map(|(field, exp)| (field.unssa(symtab, names), unssa_exp(exp, symtab, names))).collect(),
200        ),
201        Kind(ctor, exp) => Kind(ctor.unssa(symtab, names), Box::new(unssa_exp(exp, symtab, names))),
202        Unwrap(ctor, exp) => Unwrap(ctor.unssa(symtab, names), Box::new(unssa_exp(exp, symtab, names))),
203        Field(exp, field) => Field(Box::new(unssa_exp(exp, symtab, names)), field.unssa(symtab, names)),
204        Call(op, args) => Call(*op, args.iter().map(|arg| unssa_exp(arg, symtab, names)).collect()),
205    }
206}
207
208fn unssa_block_instr<B: BV>(
209    instr: &BlockInstr<B>,
210    symtab: &mut Symtab,
211    names: &mut HashMap<SSAName, Name>,
212) -> Instr<Name, B> {
213    use BlockInstr::*;
214    match instr {
215        Decl(v, ty) => Instr::Decl(v.unssa(symtab, names), unssa_ty(ty)),
216        Init(v, ty, exp) => Instr::Init(v.unssa(symtab, names), unssa_ty(ty), unssa_exp(exp, symtab, names)),
217        Copy(loc, exp) => Instr::Copy(unssa_loc(loc, symtab, names), unssa_exp(exp, symtab, names)),
218        Monomorphize(v) => Instr::Monomorphize(v.unssa(symtab, names)),
219        Call(loc, ext, f, args) => Instr::Call(
220            unssa_loc(loc, symtab, names),
221            *ext,
222            *f,
223            args.iter().map(|arg| unssa_exp(arg, symtab, names)).collect(),
224        ),
225        PrimopUnary(loc, fptr, exp) => {
226            Instr::PrimopUnary(unssa_loc(loc, symtab, names), *fptr, unssa_exp(exp, symtab, names))
227        }
228        PrimopBinary(loc, fptr, exp1, exp2) => Instr::PrimopBinary(
229            unssa_loc(loc, symtab, names),
230            *fptr,
231            unssa_exp(exp1, symtab, names),
232            unssa_exp(exp2, symtab, names),
233        ),
234        PrimopVariadic(loc, fptr, args) => Instr::PrimopVariadic(
235            unssa_loc(loc, symtab, names),
236            *fptr,
237            args.iter().map(|arg| unssa_exp(arg, symtab, names)).collect(),
238        ),
239    }
240}
241
242fn apply_label<B: BV>(label: &mut Option<usize>, instr: Instr<Name, B>) -> LabeledInstr<B> {
243    if let Some(label) = label.take() {
244        LabeledInstr::Labeled(label, instr)
245    } else {
246        LabeledInstr::Unlabeled(instr)
247    }
248}
249
250#[allow(clippy::too_many_arguments)]
251fn ite_chain<B: BV>(
252    label: &mut Option<usize>,
253    i: usize,
254    path_conds: &[Exp<SSAName>],
255    id: Name,
256    first: SSAName,
257    rest: &[SSAName],
258    ty: &Ty<Name>,
259    names: &mut HashMap<SSAName, Name>,
260    symtab: &mut Symtab,
261    linearized: &mut Vec<LabeledInstr<B>>,
262) {
263    let ite = *variadic_primops::<B>().get("ite").unwrap();
264
265    if let Some((second, rest)) = rest.split_first() {
266        let gs = symtab.gensym();
267        linearized.push(apply_label(label, Instr::Decl(gs, ty.clone())));
268        ite_chain(label, i + 1, path_conds, gs, *second, rest, ty, names, symtab, linearized);
269        linearized.push(apply_label(
270            label,
271            Instr::PrimopVariadic(
272                Loc::Id(id),
273                ite,
274                vec![unssa_exp(&path_conds[i], symtab, names), Exp::Id(first.unssa(symtab, names)), Exp::Id(gs)],
275            ),
276        ))
277    } else {
278        linearized.push(apply_label(label, Instr::Copy(Loc::Id(id), Exp::Id(first.unssa(symtab, names)))))
279    }
280}
281
282#[allow(clippy::too_many_arguments)]
283fn linearize_phi<B: BV>(
284    label: &mut Option<usize>,
285    id: SSAName,
286    args: &[SSAName],
287    n: NodeIndex,
288    cfg: &CFG<B>,
289    reachability: &HashMap<NodeIndex, Reachability>,
290    names: &mut HashMap<SSAName, Name>,
291    types: &HashMap<Name, Ty<Name>>,
292    symtab: &mut Symtab,
293    linearized: &mut Vec<LabeledInstr<B>>,
294) {
295    let mut path_conds = Vec::new();
296
297    for pred in cfg.graph.neighbors_directed(n, Direction::Incoming) {
298        let edge = cfg.graph.find_edge(pred, n).unwrap();
299        let cond = reachability[&pred].clone() & Reachability::Edge(edge);
300        path_conds.push(cond.exp(cfg))
301    }
302
303    // A phi function with no arguments has been explicitly pruned, so
304    // we do nothing in that case.
305    if let Some((first, rest)) = args.split_first() {
306        let ty = &types[&id.base_name()];
307        ite_chain(label, 0, &path_conds, id.unssa(symtab, names), *first, rest, ty, names, symtab, linearized)
308    }
309}
310
311fn linearize_block<B: BV>(
312    n: NodeIndex,
313    cfg: &CFG<B>,
314    reachability: &HashMap<NodeIndex, Reachability>,
315    names: &mut HashMap<SSAName, Name>,
316    types: &HashMap<Name, Ty<Name>>,
317    symtab: &mut Symtab,
318    linearized: &mut Vec<LabeledInstr<B>>,
319) {
320    let block = cfg.graph.node_weight(n).unwrap();
321    let mut label = block.label;
322
323    for (id, args) in &block.phis {
324        let ty = &types[&id.base_name()];
325
326        linearized.push(apply_label(&mut label, Instr::Decl(id.unssa(symtab, names), ty.clone())));
327
328        // We never have to insert ites for phi functions with unit
329        // types, and in fact cannot because unit is always concrete.
330        match ty {
331            Ty::Unit => (),
332            _ => linearize_phi(&mut label, *id, args, n, cfg, reachability, names, types, symtab, linearized),
333        }
334    }
335
336    for instr in &block.instrs {
337        if let Some((id, prev_id)) = instr.write_ssa() {
338            if instr.declares().is_none() {
339                let ty = types[&id.base_name()].clone();
340                let instr = match prev_id {
341                    Some(prev_id) => Instr::Init(id.unssa(symtab, names), ty, Exp::Id(prev_id.unssa(symtab, names))),
342                    None => Instr::Decl(id.unssa(symtab, names), ty),
343                };
344                linearized.push(apply_label(&mut label, instr))
345            }
346        }
347        linearized.push(apply_label(&mut label, unssa_block_instr(instr, symtab, names)))
348    }
349}
350
351// Linearized functions must be pure - so any assertions ought to be provable
352// and we can remove the assertion.  (Actually we change it to true to avoid
353// renumbering/relabelling.)  To be sure of correctness the self test should
354// be used.
355fn drop_assertions<B: BV>(instrs: &[Instr<Name, B>]) -> Vec<Instr<Name, B>> {
356    instrs
357        .iter()
358        .map(|instr| match instr {
359            Instr::Call(l, ext, op, args) if *op == SAIL_ASSERT => {
360                Instr::Call(l.clone(), *ext, *op, vec![Exp::Bool(true), args[1].clone()])
361            }
362            _ => instr.clone(),
363        })
364        .collect()
365}
366
367pub fn linearize<B: BV>(instrs: Vec<Instr<Name, B>>, ret_ty: &Ty<Name>, symtab: &mut Symtab) -> Vec<Instr<Name, B>> {
368    use LabeledInstr::*;
369
370    let instrs = drop_assertions(&instrs);
371    let labeled = prune_labels(label_instrs(instrs));
372    let mut cfg = CFG::new(&labeled);
373    cfg.ssa();
374
375    if let Ok(topo_order) = algo::toposort(&cfg.graph, None) {
376        let reachability = compute_reachability(&cfg, &topo_order);
377        let types = cfg.all_vars_typed(ret_ty);
378        let mut linearized = Vec::new();
379        let mut names = HashMap::new();
380        let mut last_return = -1;
381
382        for ix in cfg.graph.node_indices() {
383            let node = &cfg.graph[ix];
384            for instr in &node.instrs {
385                if let Some((id, _)) = instr.write_ssa() {
386                    if id.base_name() == RETURN {
387                        last_return = cmp::max(id.ssa_number(), last_return)
388                    }
389                }
390            }
391            for (id, _) in &node.phis {
392                if id.base_name() == RETURN {
393                    last_return = cmp::max(id.ssa_number(), last_return)
394                }
395            }
396        }
397
398        for ix in &topo_order {
399            linearize_block(*ix, &cfg, &reachability, &mut names, &types, symtab, &mut linearized)
400        }
401
402        if last_return >= 0 {
403            linearized.push(Unlabeled(Instr::Copy(
404                Loc::Id(RETURN),
405                Exp::Id(SSAName::new_ssa(RETURN, last_return).unssa(symtab, &mut names)),
406            )))
407        }
408        linearized.push(Unlabeled(Instr::End));
409
410        unlabel_instrs(linearized)
411    } else {
412        unlabel_instrs(labeled)
413    }
414}
415
416/// Test that a rewritten function body is equivalent to the original
417/// body by constructing a symbolic execution problem that proves
418/// this. Note that this function should called with an uninitialized
419/// architecture.
420#[allow(clippy::too_many_arguments)]
421pub fn self_test<'ir, B: BV>(
422    num_threads: usize,
423    mut arch: Vec<Def<Name, B>>,
424    mut symtab: Symtab<'ir>,
425    isa_config: &ISAConfig<B>,
426    args: &[Name],
427    arg_tys: &[Ty<Name>],
428    ret_ty: &Ty<Name>,
429    instrs1: Vec<Instr<Name, B>>,
430    instrs2: Vec<Instr<Name, B>>,
431) -> bool {
432    use crate::executor;
433    use crate::init::{initialize_architecture, Initialized};
434    use std::sync::atomic::{AtomicBool, Ordering};
435
436    let fn1 = symtab.intern("self_test_fn1#");
437    let fn2 = symtab.intern("self_test_fn2#");
438    let comparison = symtab.intern("self_test_compare#");
439
440    arch.push(Def::Val(fn1, arg_tys.to_vec(), ret_ty.clone()));
441    arch.push(Def::Fn(fn1, args.to_vec(), instrs1));
442
443    arch.push(Def::Val(fn2, arg_tys.to_vec(), ret_ty.clone()));
444    arch.push(Def::Fn(fn2, args.to_vec(), instrs2));
445
446    arch.push(Def::Val(comparison, arg_tys.to_vec(), Ty::Bool));
447    arch.push(Def::Fn(comparison, args.to_vec(), {
448        use super::Instr::*;
449        let x = symtab.gensym();
450        let y = symtab.gensym();
451        let eq_anything = *binary_primops::<B>().get("eq_anything").unwrap();
452        vec![
453            Decl(x, ret_ty.clone()),
454            Call(Loc::Id(x), false, fn1, args.iter().map(|id| Exp::Id(*id)).collect()),
455            Decl(y, ret_ty.clone()),
456            Call(Loc::Id(y), false, fn2, args.iter().map(|id| Exp::Id(*id)).collect()),
457            PrimopBinary(Loc::Id(RETURN), eq_anything, Exp::Id(x), Exp::Id(y)),
458            End,
459        ]
460    }));
461
462    let Initialized { regs, lets, shared_state } =
463        initialize_architecture(&mut arch, symtab, isa_config, AssertionMode::Optimistic);
464
465    let (args, _, instrs) = shared_state.functions.get(&comparison).unwrap();
466    let task_state = executor::TaskState::new();
467    let task =
468        executor::LocalFrame::new(comparison, args, None, instrs).add_lets(&lets).add_regs(&regs).task(0, &task_state);
469    let result = Arc::new(AtomicBool::new(true));
470
471    executor::start_multi(num_threads, None, vec![task], &shared_state, result.clone(), &executor::all_unsat_collector);
472
473    result.load(Ordering::Acquire)
474}