sway_ir/optimize/
cse.rs

//! Value numbering based common subexpression elimination.
//! Reference: Loren Taylor Simpson, "Value-Driven Redundancy Elimination" (PhD thesis, Rice University, 1996).
3
4use core::panic;
5use itertools::Itertools;
6use rustc_hash::{FxHashMap, FxHashSet, FxHasher};
7use slotmap::Key;
8use std::{
9    collections::hash_map,
10    fmt::Debug,
11    hash::{Hash, Hasher},
12};
13
14use crate::{
15    AnalysisResults, BinaryOpKind, Context, DebugWithContext, DomTree, Function, InstOp, IrError,
16    Pass, PassMutability, PostOrder, Predicate, ScopedPass, Type, UnaryOpKind, Value,
17    DOMINATORS_NAME, POSTORDER_NAME,
18};
19
20pub const CSE_NAME: &str = "cse";
21
22pub fn create_cse_pass() -> Pass {
23    Pass {
24        name: CSE_NAME,
25        descr: "Common subexpression elimination",
26        runner: ScopedPass::FunctionPass(PassMutability::Transform(cse)),
27        deps: vec![POSTORDER_NAME, DOMINATORS_NAME],
28    }
29}
30
31#[derive(Clone, Copy, Eq, PartialEq, Hash, DebugWithContext)]
32enum ValueNumber {
33    // Top of the lattice = Don't know = uninitialized
34    Top,
35    // Belongs to a congruence class represented by the inner value.
36    Number(Value),
37}
38
39impl Debug for ValueNumber {
40    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41        match self {
42            Self::Top => write!(f, "Top"),
43            Self::Number(arg0) => write!(f, "v{:?}", arg0.0.data()),
44        }
45    }
46}
47
48#[derive(Clone, Debug, Eq, PartialEq, Hash, DebugWithContext)]
49enum Expr {
50    Phi(Vec<ValueNumber>),
51    UnaryOp {
52        op: UnaryOpKind,
53        arg: ValueNumber,
54    },
55    BinaryOp {
56        op: BinaryOpKind,
57        arg1: ValueNumber,
58        arg2: ValueNumber,
59    },
60    BitCast(ValueNumber, Type),
61    CastPtr(ValueNumber, Type),
62    Cmp(Predicate, ValueNumber, ValueNumber),
63    GetElemPtr {
64        base: ValueNumber,
65        elem_ptr_ty: Type,
66        indices: Vec<ValueNumber>,
67    },
68    IntToPtr(ValueNumber, Type),
69    PtrToInt(ValueNumber, Type),
70}
71
72/// Convert an instruction to an expression for hashing
73/// Instructions that we don't handle will have their value numbers be equal to themselves.
74fn instr_to_expr(context: &Context, vntable: &VNTable, instr: Value) -> Option<Expr> {
75    match &instr.get_instruction(context).unwrap().op {
76        InstOp::AsmBlock(_, _) => None,
77        InstOp::UnaryOp { op, arg } => Some(Expr::UnaryOp {
78            op: *op,
79            arg: vntable.value_map.get(arg).cloned().unwrap(),
80        }),
81        InstOp::BinaryOp { op, arg1, arg2 } => Some(Expr::BinaryOp {
82            op: *op,
83            arg1: vntable.value_map.get(arg1).cloned().unwrap(),
84            arg2: vntable.value_map.get(arg2).cloned().unwrap(),
85        }),
86        InstOp::BitCast(val, ty) => Some(Expr::BitCast(
87            vntable.value_map.get(val).cloned().unwrap(),
88            *ty,
89        )),
90        InstOp::Branch(_) => None,
91        InstOp::Call(_, _) => None,
92        InstOp::CastPtr(val, ty) => Some(Expr::CastPtr(
93            vntable.value_map.get(val).cloned().unwrap(),
94            *ty,
95        )),
96        InstOp::Cmp(pred, val1, val2) => Some(Expr::Cmp(
97            *pred,
98            vntable.value_map.get(val1).cloned().unwrap(),
99            vntable.value_map.get(val2).cloned().unwrap(),
100        )),
101        InstOp::ConditionalBranch { .. } => None,
102        InstOp::ContractCall { .. } => None,
103        InstOp::FuelVm(_) => None,
104        InstOp::GetLocal(_) => None,
105        InstOp::GetGlobal(_) => None,
106        InstOp::GetConfig(_, _) => None,
107        InstOp::GetStorageKey(_) => None,
108        InstOp::GetElemPtr {
109            base,
110            elem_ptr_ty,
111            indices,
112        } => Some(Expr::GetElemPtr {
113            base: vntable.value_map.get(base).cloned().unwrap(),
114            elem_ptr_ty: *elem_ptr_ty,
115            indices: indices
116                .iter()
117                .map(|idx| vntable.value_map.get(idx).cloned().unwrap())
118                .collect(),
119        }),
120        InstOp::IntToPtr(val, ty) => Some(Expr::IntToPtr(
121            vntable.value_map.get(val).cloned().unwrap(),
122            *ty,
123        )),
124        InstOp::Load(_) => None,
125        InstOp::Alloc { .. } => None,
126        InstOp::MemCopyBytes { .. } => None,
127        InstOp::MemCopyVal { .. } => None,
128        InstOp::MemClearVal { .. } => None,
129        InstOp::Nop => None,
130        InstOp::PtrToInt(val, ty) => Some(Expr::PtrToInt(
131            vntable.value_map.get(val).cloned().unwrap(),
132            *ty,
133        )),
134        InstOp::Ret(_, _) => None,
135        InstOp::Store { .. } => None,
136    }
137}
138
139/// Convert a PHI argument to Expr
140fn phi_to_expr(context: &Context, vntable: &VNTable, phi_arg: Value) -> Expr {
141    let phi_arg = phi_arg.get_argument(context).unwrap();
142    let phi_args = phi_arg
143        .block
144        .pred_iter(context)
145        .map(|pred| {
146            let incoming_val = phi_arg
147                .get_val_coming_from(context, pred)
148                .expect("No parameter from predecessor");
149            vntable.value_map.get(&incoming_val).cloned().unwrap()
150        })
151        .collect();
152    Expr::Phi(phi_args)
153}
154
155#[derive(Default)]
156struct VNTable {
157    value_map: FxHashMap<Value, ValueNumber>,
158    expr_map: FxHashMap<Expr, ValueNumber>,
159}
160
161impl Debug for VNTable {
162    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
163        writeln!(f, "value_map:")?;
164        self.value_map.iter().for_each(|(key, value)| {
165            if format!("v{:?}", key.0.data()) == "v620v3" {
166                writeln!(f, "\tv{:?} -> {:?}", key.0.data(), value).expect("writeln! failed");
167            }
168        });
169        Ok(())
170    }
171}
172
173/// Wrapper around [DomTree::dominates] to work at instruction level.
174/// Does `inst1` dominate `inst2` ?
175fn dominates(context: &Context, dom_tree: &DomTree, inst1: Value, inst2: Value) -> bool {
176    let block1 = match &context.values[inst1.0].value {
177        crate::ValueDatum::Argument(arg) => arg.block,
178        crate::ValueDatum::Constant(_) => {
179            panic!("Shouldn't be querying dominance info for constants")
180        }
181        crate::ValueDatum::Instruction(i) => i.parent,
182    };
183    let block2 = match &context.values[inst2.0].value {
184        crate::ValueDatum::Argument(arg) => arg.block,
185        crate::ValueDatum::Constant(_) => {
186            panic!("Shouldn't be querying dominance info for constants")
187        }
188        crate::ValueDatum::Instruction(i) => i.parent,
189    };
190
191    if block1 == block2 {
192        let inst1_idx = block1
193            .instruction_iter(context)
194            .position(|inst| inst == inst1)
195            // Not found indicates a block argument
196            .unwrap_or(0);
197        let inst2_idx = block1
198            .instruction_iter(context)
199            .position(|inst| inst == inst2)
200            // Not found indicates a block argument
201            .unwrap_or(0);
202        inst1_idx < inst2_idx
203    } else {
204        dom_tree.dominates(block1, block2)
205    }
206}
207
208pub fn cse(
209    context: &mut Context,
210    analyses: &AnalysisResults,
211    function: Function,
212) -> Result<bool, IrError> {
213    let mut vntable = VNTable::default();
214
215    // Function arg values map to themselves.
216    for arg in function.args_iter(context) {
217        vntable.value_map.insert(arg.1, ValueNumber::Number(arg.1));
218    }
219
220    // Map all other arg values map to Top.
221    for block in function.block_iter(context).skip(1) {
222        for arg in block.arg_iter(context) {
223            vntable.value_map.insert(*arg, ValueNumber::Top);
224        }
225    }
226
227    // Initialize all instructions and constants. Constants need special treatment.
228    // They don't have PartialEq implemented. So we need to value number them manually.
229    // This map maps the hash of a constant value to all possible collisions of it.
230    let mut const_map = FxHashMap::<u64, Vec<Value>>::default();
231    for (_, inst) in function.instruction_iter(context) {
232        vntable.value_map.insert(inst, ValueNumber::Top);
233        for (const_opd_val, const_opd_const) in inst
234            .get_instruction(context)
235            .unwrap()
236            .op
237            .get_operands()
238            .iter()
239            .filter_map(|opd| opd.get_constant(context).map(|copd| (opd, copd)))
240        {
241            let mut state = FxHasher::default();
242            const_opd_const.hash(&mut state);
243            let hash = state.finish();
244            if let Some(existing_const) = const_map.get(&hash).and_then(|consts| {
245                consts.iter().find(|val| {
246                    let c = val
247                        .get_constant(context)
248                        .expect("const_map can only contain consts");
249                    const_opd_const == c
250                })
251            }) {
252                vntable
253                    .value_map
254                    .insert(*const_opd_val, ValueNumber::Number(*existing_const));
255            } else {
256                const_map
257                    .entry(hash)
258                    .and_modify(|consts| consts.push(*const_opd_val))
259                    .or_insert_with(|| vec![*const_opd_val]);
260                vntable
261                    .value_map
262                    .insert(*const_opd_val, ValueNumber::Number(*const_opd_val));
263            }
264        }
265    }
266
267    // We need to iterate over the blocks in RPO.
268    let post_order: &PostOrder = analyses.get_analysis_result(function);
269
270    // RPO based value number (Sec 4.2).
271    let mut changed = true;
272    while changed {
273        changed = false;
274        // For each block in RPO:
275        for (block_idx, block) in post_order.po_to_block.iter().rev().enumerate() {
276            // Process PHIs and then the other instructions.
277            if block_idx != 0 {
278                // Entry block arguments are not PHIs.
279                for (phi, expr_opt) in block
280                    .arg_iter(context)
281                    .map(|arg| (*arg, Some(phi_to_expr(context, &vntable, *arg))))
282                    .collect_vec()
283                {
284                    let expr = expr_opt.expect("PHIs must always translate to a valid Expr");
285                    // We first try to see if PHIs can be simplified into a single value.
286                    let vn = {
287                        let Expr::Phi(ref phi_args) = expr else {
288                            panic!("Expr must be a PHI")
289                        };
290                        phi_args
291                            .iter()
292                            .map(|vn| Some(*vn))
293                            .reduce(|vn1, vn2| {
294                                // Here `None` indicates Bottom of the lattice.
295                                if let (Some(vn1), Some(vn2)) = (vn1, vn2) {
296                                    match (vn1, vn2) {
297                                        (ValueNumber::Top, ValueNumber::Top) => {
298                                            Some(ValueNumber::Top)
299                                        }
300                                        (ValueNumber::Top, ValueNumber::Number(vn))
301                                        | (ValueNumber::Number(vn), ValueNumber::Top) => {
302                                            Some(ValueNumber::Number(vn))
303                                        }
304                                        (ValueNumber::Number(vn1), ValueNumber::Number(vn2)) => {
305                                            (vn1 == vn2).then_some(ValueNumber::Number(vn1))
306                                        }
307                                    }
308                                } else {
309                                    None
310                                }
311                            })
312                            .flatten()
313                            // The PHI couldn't be simplified to a single ValueNumber.
314                            .unwrap_or(ValueNumber::Number(phi))
315                    };
316
317                    match vntable.value_map.entry(phi) {
318                        hash_map::Entry::Occupied(occ) if *occ.get() == vn => {}
319                        _ => {
320                            changed = true;
321                            vntable.value_map.insert(phi, vn);
322                        }
323                    }
324                }
325            }
326
327            for (inst, expr_opt) in block
328                .instruction_iter(context)
329                .map(|instr| (instr, instr_to_expr(context, &vntable, instr)))
330                .collect_vec()
331            {
332                // lookup(expr, x)
333                let vn = if let Some(expr) = expr_opt {
334                    match vntable.expr_map.entry(expr) {
335                        hash_map::Entry::Occupied(occ) => *occ.get(),
336                        hash_map::Entry::Vacant(vac) => *(vac.insert(ValueNumber::Number(inst))),
337                    }
338                } else {
339                    // Instructions that always map to their own value number
340                    // (i.e., they can never be equal to some other instruction).
341                    ValueNumber::Number(inst)
342                };
343                match vntable.value_map.entry(inst) {
344                    hash_map::Entry::Occupied(occ) if *occ.get() == vn => {}
345                    _ => {
346                        changed = true;
347                        vntable.value_map.insert(inst, vn);
348                    }
349                }
350            }
351        }
352        vntable.expr_map.clear();
353    }
354
355    // create a partition of congruent (equal) values.
356    let mut partition = FxHashMap::<ValueNumber, FxHashSet<Value>>::default();
357    vntable.value_map.iter().for_each(|(v, vn)| {
358        // If v is a constant or its value number is itself, don't add to the partition.
359        // The latter condition is so that we have only > 1 sized partitions.
360        if v.is_constant(context)
361            || matches!(vn, ValueNumber::Top)
362            || matches!(vn, ValueNumber::Number(v2) if (v == v2 || v2.is_constant(context)))
363        {
364            return;
365        }
366        partition
367            .entry(*vn)
368            .and_modify(|part| {
369                part.insert(*v);
370            })
371            .or_insert(vec![*v].into_iter().collect());
372    });
373
374    // For convenience, now add back back `v` into `partition[VN[v]]` if it isn't already there.
375    partition.iter_mut().for_each(|(vn, v_part)| {
376        let ValueNumber::Number(v) = vn else {
377            panic!("We cannot have Top at this point");
378        };
379        v_part.insert(*v);
380        assert!(
381            v_part.len() > 1,
382            "We've only created partitions with size greater than 1"
383        );
384    });
385
386    // There are two ways to replace congruent values (see the paper cited, Sec 5).
387    // 1. Dominator based. If v1 and v2 are equal, v1 dominates v2, we just remove v2
388    // and replace its uses with v1. Simple, and what we're going to do.
389    // 2. AVAIL based. More powerful, but also requires a data-flow-analysis for AVAIL
390    // and later on, mem2reg again since replacements will need breaking SSA.
391    let dom_tree: &DomTree = analyses.get_analysis_result(function);
392    let mut replace_map = FxHashMap::<Value, Value>::default();
393    let mut modified = false;
394    // Check every set in the partition.
395    partition.iter().for_each(|(_leader, vals)| {
396        // Iterate over every pair of values, checking if one can replace the other.
397        for v_pair in vals.iter().combinations(2) {
398            let (v1, v2) = (*v_pair[0], *v_pair[1]);
399            if dominates(context, dom_tree, v1, v2) {
400                modified = true;
401                replace_map.insert(v2, v1);
402            } else if dominates(context, dom_tree, v2, v1) {
403                modified = true;
404                replace_map.insert(v1, v2);
405            }
406        }
407    });
408
409    function.replace_values(context, &replace_map, None);
410
411    Ok(modified)
412}