cranelift_frontend/
switch.rs

1use super::HashMap;
2use crate::frontend::FunctionBuilder;
3use alloc::vec::Vec;
4use cranelift_codegen::ir::condcodes::IntCC;
5use cranelift_codegen::ir::*;
6
/// Index of a single switch case.
///
/// It is a `u128`, which is wide enough to hold case values for every integer type that the
/// tests in this file switch on, up to and including `I128`.
type EntryIndex = u128;
8
/// A builder for a multi-way branch on an integer value.
///
/// Unlike with `br_table`, `Switch` cases may be sparse or non-0-based.
/// They emit efficient code using branches, jump tables, or a combination of both.
///
/// # Example
///
/// ```rust
/// # use cranelift_codegen::ir::types::*;
/// # use cranelift_codegen::ir::{UserFuncName, Function, Signature, InstBuilder};
/// # use cranelift_codegen::isa::CallConv;
/// # use cranelift_frontend::{FunctionBuilder, FunctionBuilderContext, Switch};
/// #
/// # let mut sig = Signature::new(CallConv::SystemV);
/// # let mut fn_builder_ctx = FunctionBuilderContext::new();
/// # let mut func = Function::with_name_signature(UserFuncName::user(0, 0), sig);
/// # let mut builder = FunctionBuilder::new(&mut func, &mut fn_builder_ctx);
/// #
/// # let entry = builder.create_block();
/// # builder.switch_to_block(entry);
/// #
/// let block0 = builder.create_block();
/// let block1 = builder.create_block();
/// let block2 = builder.create_block();
/// let fallback = builder.create_block();
///
/// let val = builder.ins().iconst(I32, 1);
///
/// let mut switch = Switch::new();
/// switch.set_entry(0, block0);
/// switch.set_entry(1, block1);
/// switch.set_entry(7, block2);
/// switch.emit(&mut builder, val, fallback);
/// ```
#[derive(Debug, Default)]
pub struct Switch {
    /// Target block of every case, keyed by the case index it was registered under.
    cases: HashMap<EntryIndex, Block>,
}
45
46impl Switch {
47    /// Create a new empty switch
48    pub fn new() -> Self {
49        Self {
50            cases: HashMap::new(),
51        }
52    }
53
54    /// Set a switch entry
55    pub fn set_entry(&mut self, index: EntryIndex, block: Block) {
56        let prev = self.cases.insert(index, block);
57        assert!(prev.is_none(), "Tried to set the same entry {index} twice");
58    }
59
60    /// Get a reference to all existing entries
61    pub fn entries(&self) -> &HashMap<EntryIndex, Block> {
62        &self.cases
63    }
64
65    /// Turn the `cases` `HashMap` into a list of `ContiguousCaseRange`s.
66    ///
67    /// # Postconditions
68    ///
69    /// * Every entry will be represented.
70    /// * The `ContiguousCaseRange`s will not overlap.
71    /// * Between two `ContiguousCaseRange`s there will be at least one entry index.
72    /// * No `ContiguousCaseRange`s will be empty.
73    fn collect_contiguous_case_ranges(self) -> Vec<ContiguousCaseRange> {
74        log::trace!("build_contiguous_case_ranges before: {:#?}", self.cases);
75        let mut cases = self.cases.into_iter().collect::<Vec<(_, _)>>();
76        cases.sort_by_key(|&(index, _)| index);
77
78        let mut contiguous_case_ranges: Vec<ContiguousCaseRange> = vec![];
79        let mut last_index = None;
80        for (index, block) in cases {
81            match last_index {
82                None => contiguous_case_ranges.push(ContiguousCaseRange::new(index)),
83                Some(last_index) => {
84                    if index > last_index + 1 {
85                        contiguous_case_ranges.push(ContiguousCaseRange::new(index));
86                    }
87                }
88            }
89            contiguous_case_ranges
90                .last_mut()
91                .unwrap()
92                .blocks
93                .push(block);
94            last_index = Some(index);
95        }
96
97        log::trace!("build_contiguous_case_ranges after: {contiguous_case_ranges:#?}");
98
99        contiguous_case_ranges
100    }
101
102    /// Binary search for the right `ContiguousCaseRange`.
103    fn build_search_tree<'a>(
104        bx: &mut FunctionBuilder,
105        val: Value,
106        otherwise: Block,
107        contiguous_case_ranges: &'a [ContiguousCaseRange],
108    ) {
109        // If no switch cases were added to begin with, we can just emit `jump otherwise`.
110        if contiguous_case_ranges.is_empty() {
111            bx.ins().jump(otherwise, &[]);
112            return;
113        }
114
115        // Avoid allocation in the common case
116        if contiguous_case_ranges.len() <= 3 {
117            Self::build_search_branches(bx, val, otherwise, contiguous_case_ranges);
118            return;
119        }
120
121        let mut stack = Vec::new();
122        stack.push((None, contiguous_case_ranges));
123
124        while let Some((block, contiguous_case_ranges)) = stack.pop() {
125            if let Some(block) = block {
126                bx.switch_to_block(block);
127            }
128
129            if contiguous_case_ranges.len() <= 3 {
130                Self::build_search_branches(bx, val, otherwise, contiguous_case_ranges);
131            } else {
132                let split_point = contiguous_case_ranges.len() / 2;
133                let (left, right) = contiguous_case_ranges.split_at(split_point);
134
135                let left_block = bx.create_block();
136                let right_block = bx.create_block();
137
138                let first_index = right[0].first_index;
139                let should_take_right_side =
140                    icmp_imm_u128(bx, IntCC::UnsignedGreaterThanOrEqual, val, first_index);
141                bx.ins()
142                    .brif(should_take_right_side, right_block, &[], left_block, &[]);
143
144                bx.seal_block(left_block);
145                bx.seal_block(right_block);
146
147                stack.push((Some(left_block), left));
148                stack.push((Some(right_block), right));
149            }
150        }
151    }
152
153    /// Linear search for the right `ContiguousCaseRange`.
154    fn build_search_branches<'a>(
155        bx: &mut FunctionBuilder,
156        val: Value,
157        otherwise: Block,
158        contiguous_case_ranges: &'a [ContiguousCaseRange],
159    ) {
160        for (ix, range) in contiguous_case_ranges.iter().enumerate().rev() {
161            let alternate = if ix == 0 {
162                otherwise
163            } else {
164                bx.create_block()
165            };
166
167            if range.first_index == 0 {
168                assert_eq!(alternate, otherwise);
169
170                if let Some(block) = range.single_block() {
171                    bx.ins().brif(val, otherwise, &[], block, &[]);
172                } else {
173                    Self::build_jump_table(bx, val, otherwise, 0, &range.blocks);
174                }
175            } else {
176                if let Some(block) = range.single_block() {
177                    let is_good_val = icmp_imm_u128(bx, IntCC::Equal, val, range.first_index);
178                    bx.ins().brif(is_good_val, block, &[], alternate, &[]);
179                } else {
180                    let is_good_val = icmp_imm_u128(
181                        bx,
182                        IntCC::UnsignedGreaterThanOrEqual,
183                        val,
184                        range.first_index,
185                    );
186                    let jt_block = bx.create_block();
187                    bx.ins().brif(is_good_val, jt_block, &[], alternate, &[]);
188                    bx.seal_block(jt_block);
189                    bx.switch_to_block(jt_block);
190                    Self::build_jump_table(bx, val, otherwise, range.first_index, &range.blocks);
191                }
192            }
193
194            if alternate != otherwise {
195                bx.seal_block(alternate);
196                bx.switch_to_block(alternate);
197            }
198        }
199    }
200
201    fn build_jump_table(
202        bx: &mut FunctionBuilder,
203        val: Value,
204        otherwise: Block,
205        first_index: EntryIndex,
206        blocks: &[Block],
207    ) {
208        // There are currently no 128bit systems supported by rustc, but once we do ensure that
209        // we don't silently ignore a part of the jump table for 128bit integers on 128bit systems.
210        assert!(
211            u32::try_from(blocks.len()).is_ok(),
212            "Jump tables bigger than 2^32-1 are not yet supported"
213        );
214
215        let jt_data = JumpTableData::new(
216            bx.func.dfg.block_call(otherwise, &[]),
217            &blocks
218                .iter()
219                .map(|block| bx.func.dfg.block_call(*block, &[]))
220                .collect::<Vec<_>>(),
221        );
222        let jump_table = bx.create_jump_table(jt_data);
223
224        let discr = if first_index == 0 {
225            val
226        } else {
227            if let Ok(first_index) = u64::try_from(first_index) {
228                bx.ins().iadd_imm(val, (first_index as i64).wrapping_neg())
229            } else {
230                let (lsb, msb) = (first_index as u64, (first_index >> 64) as u64);
231                let lsb = bx.ins().iconst(types::I64, lsb as i64);
232                let msb = bx.ins().iconst(types::I64, msb as i64);
233                let index = bx.ins().iconcat(lsb, msb);
234                bx.ins().isub(val, index)
235            }
236        };
237
238        let discr = match bx.func.dfg.value_type(discr).bits() {
239            bits if bits > 32 => {
240                // Check for overflow of cast to u32. This is the max supported jump table entries.
241                let new_block = bx.create_block();
242                let bigger_than_u32 =
243                    bx.ins()
244                        .icmp_imm(IntCC::UnsignedGreaterThan, discr, u32::MAX as i64);
245                bx.ins()
246                    .brif(bigger_than_u32, otherwise, &[], new_block, &[]);
247                bx.seal_block(new_block);
248                bx.switch_to_block(new_block);
249
250                // Cast to i32, as br_table is not implemented for i64/i128
251                bx.ins().ireduce(types::I32, discr)
252            }
253            bits if bits < 32 => bx.ins().uextend(types::I32, discr),
254            _ => discr,
255        };
256
257        bx.ins().br_table(discr, jump_table);
258    }
259
260    /// Build the switch
261    ///
262    /// # Arguments
263    ///
264    /// * The function builder to emit to
265    /// * The value to switch on
266    /// * The default block
267    pub fn emit(self, bx: &mut FunctionBuilder, val: Value, otherwise: Block) {
268        // Validate that the type of `val` is sufficiently wide to address all cases.
269        let max = self.cases.keys().max().copied().unwrap_or(0);
270        let val_ty = bx.func.dfg.value_type(val);
271        let val_ty_max = val_ty.bounds(false).1;
272        if max > val_ty_max {
273            panic!("The index type {val_ty} does not fit the maximum switch entry of {max}");
274        }
275
276        let contiguous_case_ranges = self.collect_contiguous_case_ranges();
277        Self::build_search_tree(bx, val, otherwise, &contiguous_case_ranges);
278    }
279}
280
281fn icmp_imm_u128(bx: &mut FunctionBuilder, cond: IntCC, x: Value, y: u128) -> Value {
282    if bx.func.dfg.value_type(x) != types::I128 {
283        assert!(u64::try_from(y).is_ok());
284        bx.ins().icmp_imm(cond, x, y as i64)
285    } else if let Ok(index) = i64::try_from(y) {
286        bx.ins().icmp_imm(cond, x, index)
287    } else {
288        let (lsb, msb) = (y as u64, (y >> 64) as u64);
289        let lsb = bx.ins().iconst(types::I64, lsb as i64);
290        let msb = bx.ins().iconst(types::I64, msb as i64);
291        let index = bx.ins().iconcat(lsb, msb);
292        bx.ins().icmp(cond, x, index)
293    }
294}
295
296/// This represents a contiguous range of cases to switch on.
297///
298/// For example 10 => block1, 11 => block2, 12 => block7 will be represented as:
299///
300/// ```plain
301/// ContiguousCaseRange {
302///     first_index: 10,
303///     blocks: vec![Block::from_u32(1), Block::from_u32(2), Block::from_u32(7)]
304/// }
305/// ```
306#[derive(Debug)]
307struct ContiguousCaseRange {
308    /// The entry index of the first case. Eg. 10 when the entry indexes are 10, 11, 12 and 13.
309    first_index: EntryIndex,
310
311    /// The blocks to jump to sorted in ascending order of entry index.
312    blocks: Vec<Block>,
313}
314
315impl ContiguousCaseRange {
316    fn new(first_index: EntryIndex) -> Self {
317        Self {
318            first_index,
319            blocks: Vec::new(),
320        }
321    }
322
323    /// Returns `Some` block when there is only a single block in this range.
324    fn single_block(&self) -> Option<Block> {
325        if self.blocks.len() == 1 {
326            Some(self.blocks[0])
327        } else {
328            None
329        }
330    }
331}
332
333#[cfg(test)]
334mod tests {
335    use super::*;
336    use crate::frontend::FunctionBuilderContext;
337    use alloc::string::ToString;
338
    // Builds a function that switches on an `i8` constant 0 and returns its textual form.
    //
    // * `$default` is the number of the block used as the default target.
    // * Each `$index` gets a freshly created block registered as its switch entry, in the
    //   order the indexes are listed.
    //
    // The result is the printed function with the `function u0:0() fast {` header line and the
    // closing brace trimmed off.
    macro_rules! setup {
        ($default:expr, [$($index:expr,)*]) => {{
            let mut func = Function::new();
            let mut func_ctx = FunctionBuilderContext::new();
            {
                let mut bx = FunctionBuilder::new(&mut func, &mut func_ctx);
                let block = bx.create_block();
                bx.switch_to_block(block);
                let val = bx.ins().iconst(types::I8, 0);
                let mut switch = Switch::new();
                // NOTE(review): presumably keeps `switch` counted as mutably used when the
                // index list is empty and the repetition below expands to nothing — confirm.
                let _ = &mut switch;
                $(
                    let block = bx.create_block();
                    switch.set_entry($index, block);
                )*
                switch.emit(&mut bx, val, Block::with_number($default).unwrap());
            }
            func
                .to_string()
                .trim_start_matches("function u0:0() fast {\n")
                .trim_end_matches("\n}\n")
                .to_string()
        }};
    }
363
    /// A switch without any entries is expected to emit just a jump to the default block.
    #[test]
    fn switch_empty() {
        let func = setup!(42, []);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i8 0
    jump block42"
        );
    }
374
    /// A lone entry at index 0 is expected to need no comparison: the switched value itself is
    /// the `brif` condition, with the default block (`block0`) as the non-zero target.
    #[test]
    fn switch_zero() {
        let func = setup!(0, [0,]);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i8 0
    brif v0, block0, block1  ; v0 = 0"
        );
    }
385
    /// A lone entry at a non-zero index is expected to be matched with one equality comparison.
    #[test]
    fn switch_single() {
        let func = setup!(0, [1,]);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i8 0
    v1 = icmp_imm eq v0, 1  ; v0 = 0
    brif v1, block1, block0"
        );
    }
397
    /// The adjacent entries 0 and 1 are expected to become one jump table, indexed by the
    /// switched value zero-extended to `i32`.
    #[test]
    fn switch_bool() {
        let func = setup!(0, [0, 1,]);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i8 0
    v1 = uextend.i32 v0  ; v0 = 0
    br_table v1, block0, [block1, block2]"
        );
    }
409
    /// The entries 0 and 2 are not adjacent, so two separate checks are expected: first the
    /// equality test for 2, then the zero test for 0 in a new block.
    #[test]
    fn switch_two_gap() {
        let func = setup!(0, [0, 2,]);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i8 0
    v1 = icmp_imm eq v0, 2  ; v0 = 0
    brif v1, block2, block3

block3:
    brif.i8 v0, block0, block1  ; v0 = 0"
        );
    }
424
    /// The entries form four contiguous ranges (0-1, 5, 7 and 10-12), which is more than the
    /// linear search handles at once. A binary split on `uge 7` is therefore expected first,
    /// followed by linear checks and jump tables inside each half.
    #[test]
    fn switch_many() {
        let func = setup!(0, [0, 1, 5, 7, 10, 11, 12,]);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i8 0
    v1 = icmp_imm uge v0, 7  ; v0 = 0
    brif v1, block9, block8

block9:
    v2 = icmp_imm.i8 uge v0, 10  ; v0 = 0
    brif v2, block11, block10

block11:
    v3 = iadd_imm.i8 v0, -10  ; v0 = 0
    v4 = uextend.i32 v3
    br_table v4, block0, [block5, block6, block7]

block10:
    v5 = icmp_imm.i8 eq v0, 7  ; v0 = 0
    brif v5, block4, block0

block8:
    v6 = icmp_imm.i8 eq v0, 5  ; v0 = 0
    brif v6, block3, block12

block12:
    v7 = uextend.i32 v0  ; v0 = 0
    br_table v7, block0, [block1, block2]"
        );
    }
457
    /// Uses the bit pattern of `i8::MIN` (entry index 128) as a case on an `i8` value. The
    /// expected comparison immediate is printed as `-128`.
    #[test]
    fn switch_min_index_value() {
        let func = setup!(0, [i8::MIN as u8 as u128, 1,]);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i8 0
    v1 = icmp_imm eq v0, -128  ; v0 = 0
    brif v1, block1, block3

block3:
    v2 = icmp_imm.i8 eq v0, 1  ; v0 = 0
    brif v2, block2, block0"
        );
    }
473
    /// Uses `i8::MAX` (entry index 127) as a case on an `i8` value next to the entry 1. Both
    /// are expected to be matched by separate equality comparisons.
    #[test]
    fn switch_max_index_value() {
        let func = setup!(0, [i8::MAX as u8 as u128, 1,]);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i8 0
    v1 = icmp_imm eq v0, 127  ; v0 = 0
    brif v1, block1, block3

block3:
    v2 = icmp_imm.i8 eq v0, 1  ; v0 = 0
    brif v2, block2, block0"
        )
    }
489
    /// The bit pattern of `-1i8` is the entry index 255, which is not adjacent to 0 and 1. It is
    /// expected to get its own equality comparison, while 0 and 1 share a jump table.
    #[test]
    fn switch_optimal_codegen() {
        let func = setup!(0, [-1i8 as u8 as u128, 0, 1,]);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i8 0
    v1 = icmp_imm eq v0, -1  ; v0 = 0
    brif v1, block1, block4

block4:
    v2 = uextend.i32 v0  ; v0 = 0
    br_table v2, block0, [block2, block3]"
        );
    }
505
    /// Emitting a switch on an `i8` value must panic when an entry index does not fit in `i8`.
    #[test]
    #[should_panic(
        expected = "The index type i8 does not fit the maximum switch entry of 4683743612477887600"
    )]
    fn switch_rejects_small_inputs() {
        // This is a regression test for a bug that we found where we would emit a cmp
        // with a type that was not able to fully represent a large index.
        //
        // See: https://github.com/bytecodealliance/wasmtime/pull/4502#issuecomment-1191961677
        setup!(1, [0x4100_0000_00bf_d470,]);
    }
517
518    #[test]
519    fn switch_seal_generated_blocks() {
520        let cases = &[vec![0, 1, 2], vec![0, 1, 2, 10, 11, 12, 20, 30, 40, 50]];
521
522        for case in cases {
523            for typ in &[types::I8, types::I16, types::I32, types::I64, types::I128] {
524                eprintln!("Testing {typ:?} with keys: {case:?}");
525                do_case(case, *typ);
526            }
527        }
528
529        fn do_case(keys: &[u128], typ: Type) {
530            let mut func = Function::new();
531            let mut builder_ctx = FunctionBuilderContext::new();
532            let mut builder = FunctionBuilder::new(&mut func, &mut builder_ctx);
533
534            let root_block = builder.create_block();
535            let default_block = builder.create_block();
536            let mut switch = Switch::new();
537
538            let case_blocks = keys
539                .iter()
540                .map(|key| {
541                    let block = builder.create_block();
542                    switch.set_entry(*key, block);
543                    block
544                })
545                .collect::<Vec<_>>();
546
547            builder.seal_block(root_block);
548            builder.switch_to_block(root_block);
549
550            let val = builder.ins().iconst(typ, 1);
551            switch.emit(&mut builder, val, default_block);
552
553            for &block in case_blocks.iter().chain(std::iter::once(&default_block)) {
554                builder.seal_block(block);
555                builder.switch_to_block(block);
556                builder.ins().return_(&[]);
557            }
558
559            builder.finalize(); // Will panic if some blocks are not sealed
560        }
561    }
562
563    #[test]
564    fn switch_64bit() {
565        let mut func = Function::new();
566        let mut func_ctx = FunctionBuilderContext::new();
567        {
568            let mut bx = FunctionBuilder::new(&mut func, &mut func_ctx);
569            let block0 = bx.create_block();
570            bx.switch_to_block(block0);
571            let val = bx.ins().iconst(types::I64, 0);
572            let mut switch = Switch::new();
573            let block1 = bx.create_block();
574            switch.set_entry(1, block1);
575            let block2 = bx.create_block();
576            switch.set_entry(0, block2);
577            let block3 = bx.create_block();
578            switch.emit(&mut bx, val, block3);
579        }
580        let func = func
581            .to_string()
582            .trim_start_matches("function u0:0() fast {\n")
583            .trim_end_matches("\n}\n")
584            .to_string();
585        assert_eq_output!(
586            func,
587            "block0:
588    v0 = iconst.i64 0
589    v1 = icmp_imm ugt v0, 0xffff_ffff  ; v0 = 0
590    brif v1, block3, block4
591
592block4:
593    v2 = ireduce.i32 v0  ; v0 = 0
594    br_table v2, block3, [block2, block1]"
595        );
596    }
597
598    #[test]
599    fn switch_128bit() {
600        let mut func = Function::new();
601        let mut func_ctx = FunctionBuilderContext::new();
602        {
603            let mut bx = FunctionBuilder::new(&mut func, &mut func_ctx);
604            let block0 = bx.create_block();
605            bx.switch_to_block(block0);
606            let val = bx.ins().iconst(types::I64, 0);
607            let val = bx.ins().uextend(types::I128, val);
608            let mut switch = Switch::new();
609            let block1 = bx.create_block();
610            switch.set_entry(1, block1);
611            let block2 = bx.create_block();
612            switch.set_entry(0, block2);
613            let block3 = bx.create_block();
614            switch.emit(&mut bx, val, block3);
615        }
616        let func = func
617            .to_string()
618            .trim_start_matches("function u0:0() fast {\n")
619            .trim_end_matches("\n}\n")
620            .to_string();
621        assert_eq_output!(
622            func,
623            "block0:
624    v0 = iconst.i64 0
625    v1 = uextend.i128 v0  ; v0 = 0
626    v2 = icmp_imm ugt v1, 0xffff_ffff
627    brif v2, block3, block4
628
629block4:
630    v3 = ireduce.i32 v1
631    br_table v3, block3, [block2, block1]"
632        );
633    }
634
635    #[test]
636    fn switch_128bit_max_u64() {
637        let mut func = Function::new();
638        let mut func_ctx = FunctionBuilderContext::new();
639        {
640            let mut bx = FunctionBuilder::new(&mut func, &mut func_ctx);
641            let block0 = bx.create_block();
642            bx.switch_to_block(block0);
643            let val = bx.ins().iconst(types::I64, 0);
644            let val = bx.ins().uextend(types::I128, val);
645            let mut switch = Switch::new();
646            let block1 = bx.create_block();
647            switch.set_entry(u64::MAX.into(), block1);
648            let block2 = bx.create_block();
649            switch.set_entry(0, block2);
650            let block3 = bx.create_block();
651            switch.emit(&mut bx, val, block3);
652        }
653        let func = func
654            .to_string()
655            .trim_start_matches("function u0:0() fast {\n")
656            .trim_end_matches("\n}\n")
657            .to_string();
658        assert_eq_output!(
659            func,
660            "block0:
661    v0 = iconst.i64 0
662    v1 = uextend.i128 v0  ; v0 = 0
663    v2 = iconst.i64 -1
664    v3 = iconst.i64 0
665    v4 = iconcat v2, v3  ; v2 = -1, v3 = 0
666    v5 = icmp eq v1, v4
667    brif v5, block1, block4
668
669block4:
670    brif.i128 v1, block3, block2"
671        );
672    }
673}