cranelift_frontend/
switch.rs

1use super::HashMap;
2use crate::frontend::FunctionBuilder;
3use alloc::vec::Vec;
4use cranelift_codegen::ir::condcodes::IntCC;
5use cranelift_codegen::ir::*;
6
7type EntryIndex = u128;
8
9/// Unlike with `br_table`, `Switch` cases may be sparse or non-0-based.
10/// They emit efficient code using branches, jump tables, or a combination of both.
11///
12/// # Example
13///
14/// ```rust
15/// # use cranelift_codegen::ir::types::*;
16/// # use cranelift_codegen::ir::{UserFuncName, Function, Signature, InstBuilder};
17/// # use cranelift_codegen::isa::CallConv;
18/// # use cranelift_frontend::{FunctionBuilder, FunctionBuilderContext, Switch};
19/// #
20/// # let mut sig = Signature::new(CallConv::SystemV);
21/// # let mut fn_builder_ctx = FunctionBuilderContext::new();
22/// # let mut func = Function::with_name_signature(UserFuncName::user(0, 0), sig);
23/// # let mut builder = FunctionBuilder::new(&mut func, &mut fn_builder_ctx);
24/// #
25/// # let entry = builder.create_block();
26/// # builder.switch_to_block(entry);
27/// #
28/// let block0 = builder.create_block();
29/// let block1 = builder.create_block();
30/// let block2 = builder.create_block();
31/// let fallback = builder.create_block();
32///
33/// let val = builder.ins().iconst(I32, 1);
34///
35/// let mut switch = Switch::new();
36/// switch.set_entry(0, block0);
37/// switch.set_entry(1, block1);
38/// switch.set_entry(7, block2);
39/// switch.emit(&mut builder, val, fallback);
40/// ```
41#[derive(Debug, Default)]
42pub struct Switch {
43    cases: HashMap<EntryIndex, Block>,
44}
45
46impl Switch {
47    /// Create a new empty switch
48    pub fn new() -> Self {
49        Self {
50            cases: HashMap::new(),
51        }
52    }
53
54    /// Set a switch entry
55    pub fn set_entry(&mut self, index: EntryIndex, block: Block) {
56        let prev = self.cases.insert(index, block);
57        assert!(
58            prev.is_none(),
59            "Tried to set the same entry {} twice",
60            index
61        );
62    }
63
64    /// Get a reference to all existing entries
65    pub fn entries(&self) -> &HashMap<EntryIndex, Block> {
66        &self.cases
67    }
68
69    /// Turn the `cases` `HashMap` into a list of `ContiguousCaseRange`s.
70    ///
71    /// # Postconditions
72    ///
73    /// * Every entry will be represented.
74    /// * The `ContiguousCaseRange`s will not overlap.
75    /// * Between two `ContiguousCaseRange`s there will be at least one entry index.
76    /// * No `ContiguousCaseRange`s will be empty.
77    fn collect_contiguous_case_ranges(self) -> Vec<ContiguousCaseRange> {
78        log::trace!("build_contiguous_case_ranges before: {:#?}", self.cases);
79        let mut cases = self.cases.into_iter().collect::<Vec<(_, _)>>();
80        cases.sort_by_key(|&(index, _)| index);
81
82        let mut contiguous_case_ranges: Vec<ContiguousCaseRange> = vec![];
83        let mut last_index = None;
84        for (index, block) in cases {
85            match last_index {
86                None => contiguous_case_ranges.push(ContiguousCaseRange::new(index)),
87                Some(last_index) => {
88                    if index > last_index + 1 {
89                        contiguous_case_ranges.push(ContiguousCaseRange::new(index));
90                    }
91                }
92            }
93            contiguous_case_ranges
94                .last_mut()
95                .unwrap()
96                .blocks
97                .push(block);
98            last_index = Some(index);
99        }
100
101        log::trace!(
102            "build_contiguous_case_ranges after: {:#?}",
103            contiguous_case_ranges
104        );
105
106        contiguous_case_ranges
107    }
108
109    /// Binary search for the right `ContiguousCaseRange`.
110    fn build_search_tree<'a>(
111        bx: &mut FunctionBuilder,
112        val: Value,
113        otherwise: Block,
114        contiguous_case_ranges: &'a [ContiguousCaseRange],
115    ) {
116        // If no switch cases were added to begin with, we can just emit `jump otherwise`.
117        if contiguous_case_ranges.is_empty() {
118            bx.ins().jump(otherwise, &[]);
119            return;
120        }
121
122        // Avoid allocation in the common case
123        if contiguous_case_ranges.len() <= 3 {
124            Self::build_search_branches(bx, val, otherwise, contiguous_case_ranges);
125            return;
126        }
127
128        let mut stack = Vec::new();
129        stack.push((None, contiguous_case_ranges));
130
131        while let Some((block, contiguous_case_ranges)) = stack.pop() {
132            if let Some(block) = block {
133                bx.switch_to_block(block);
134            }
135
136            if contiguous_case_ranges.len() <= 3 {
137                Self::build_search_branches(bx, val, otherwise, contiguous_case_ranges);
138            } else {
139                let split_point = contiguous_case_ranges.len() / 2;
140                let (left, right) = contiguous_case_ranges.split_at(split_point);
141
142                let left_block = bx.create_block();
143                let right_block = bx.create_block();
144
145                let first_index = right[0].first_index;
146                let should_take_right_side =
147                    icmp_imm_u128(bx, IntCC::UnsignedGreaterThanOrEqual, val, first_index);
148                bx.ins()
149                    .brif(should_take_right_side, right_block, &[], left_block, &[]);
150
151                bx.seal_block(left_block);
152                bx.seal_block(right_block);
153
154                stack.push((Some(left_block), left));
155                stack.push((Some(right_block), right));
156            }
157        }
158    }
159
160    /// Linear search for the right `ContiguousCaseRange`.
161    fn build_search_branches<'a>(
162        bx: &mut FunctionBuilder,
163        val: Value,
164        otherwise: Block,
165        contiguous_case_ranges: &'a [ContiguousCaseRange],
166    ) {
167        for (ix, range) in contiguous_case_ranges.iter().enumerate().rev() {
168            let alternate = if ix == 0 {
169                otherwise
170            } else {
171                bx.create_block()
172            };
173
174            if range.first_index == 0 {
175                assert_eq!(alternate, otherwise);
176
177                if let Some(block) = range.single_block() {
178                    bx.ins().brif(val, otherwise, &[], block, &[]);
179                } else {
180                    Self::build_jump_table(bx, val, otherwise, 0, &range.blocks);
181                }
182            } else {
183                if let Some(block) = range.single_block() {
184                    let is_good_val = icmp_imm_u128(bx, IntCC::Equal, val, range.first_index);
185                    bx.ins().brif(is_good_val, block, &[], alternate, &[]);
186                } else {
187                    let is_good_val = icmp_imm_u128(
188                        bx,
189                        IntCC::UnsignedGreaterThanOrEqual,
190                        val,
191                        range.first_index,
192                    );
193                    let jt_block = bx.create_block();
194                    bx.ins().brif(is_good_val, jt_block, &[], alternate, &[]);
195                    bx.seal_block(jt_block);
196                    bx.switch_to_block(jt_block);
197                    Self::build_jump_table(bx, val, otherwise, range.first_index, &range.blocks);
198                }
199            }
200
201            if alternate != otherwise {
202                bx.seal_block(alternate);
203                bx.switch_to_block(alternate);
204            }
205        }
206    }
207
208    fn build_jump_table(
209        bx: &mut FunctionBuilder,
210        val: Value,
211        otherwise: Block,
212        first_index: EntryIndex,
213        blocks: &[Block],
214    ) {
215        // There are currently no 128bit systems supported by rustc, but once we do ensure that
216        // we don't silently ignore a part of the jump table for 128bit integers on 128bit systems.
217        assert!(
218            u32::try_from(blocks.len()).is_ok(),
219            "Jump tables bigger than 2^32-1 are not yet supported"
220        );
221
222        let jt_data = JumpTableData::new(
223            bx.func.dfg.block_call(otherwise, &[]),
224            &blocks
225                .iter()
226                .map(|block| bx.func.dfg.block_call(*block, &[]))
227                .collect::<Vec<_>>(),
228        );
229        let jump_table = bx.create_jump_table(jt_data);
230
231        let discr = if first_index == 0 {
232            val
233        } else {
234            if let Ok(first_index) = u64::try_from(first_index) {
235                bx.ins().iadd_imm(val, (first_index as i64).wrapping_neg())
236            } else {
237                let (lsb, msb) = (first_index as u64, (first_index >> 64) as u64);
238                let lsb = bx.ins().iconst(types::I64, lsb as i64);
239                let msb = bx.ins().iconst(types::I64, msb as i64);
240                let index = bx.ins().iconcat(lsb, msb);
241                bx.ins().isub(val, index)
242            }
243        };
244
245        let discr = match bx.func.dfg.value_type(discr).bits() {
246            bits if bits > 32 => {
247                // Check for overflow of cast to u32. This is the max supported jump table entries.
248                let new_block = bx.create_block();
249                let bigger_than_u32 =
250                    bx.ins()
251                        .icmp_imm(IntCC::UnsignedGreaterThan, discr, u32::MAX as i64);
252                bx.ins()
253                    .brif(bigger_than_u32, otherwise, &[], new_block, &[]);
254                bx.seal_block(new_block);
255                bx.switch_to_block(new_block);
256
257                // Cast to i32, as br_table is not implemented for i64/i128
258                bx.ins().ireduce(types::I32, discr)
259            }
260            bits if bits < 32 => bx.ins().uextend(types::I32, discr),
261            _ => discr,
262        };
263
264        bx.ins().br_table(discr, jump_table);
265    }
266
267    /// Build the switch
268    ///
269    /// # Arguments
270    ///
271    /// * The function builder to emit to
272    /// * The value to switch on
273    /// * The default block
274    pub fn emit(self, bx: &mut FunctionBuilder, val: Value, otherwise: Block) {
275        // Validate that the type of `val` is sufficiently wide to address all cases.
276        let max = self.cases.keys().max().copied().unwrap_or(0);
277        let val_ty = bx.func.dfg.value_type(val);
278        let val_ty_max = val_ty.bounds(false).1;
279        if max > val_ty_max {
280            panic!(
281                "The index type {} does not fit the maximum switch entry of {}",
282                val_ty, max
283            );
284        }
285
286        let contiguous_case_ranges = self.collect_contiguous_case_ranges();
287        Self::build_search_tree(bx, val, otherwise, &contiguous_case_ranges);
288    }
289}
290
291fn icmp_imm_u128(bx: &mut FunctionBuilder, cond: IntCC, x: Value, y: u128) -> Value {
292    if bx.func.dfg.value_type(x) != types::I128 {
293        assert!(u64::try_from(y).is_ok());
294        bx.ins().icmp_imm(cond, x, y as i64)
295    } else if let Ok(index) = i64::try_from(y) {
296        bx.ins().icmp_imm(cond, x, index)
297    } else {
298        let (lsb, msb) = (y as u64, (y >> 64) as u64);
299        let lsb = bx.ins().iconst(types::I64, lsb as i64);
300        let msb = bx.ins().iconst(types::I64, msb as i64);
301        let index = bx.ins().iconcat(lsb, msb);
302        bx.ins().icmp(cond, x, index)
303    }
304}
305
306/// This represents a contiguous range of cases to switch on.
307///
308/// For example 10 => block1, 11 => block2, 12 => block7 will be represented as:
309///
310/// ```plain
311/// ContiguousCaseRange {
312///     first_index: 10,
313///     blocks: vec![Block::from_u32(1), Block::from_u32(2), Block::from_u32(7)]
314/// }
315/// ```
316#[derive(Debug)]
317struct ContiguousCaseRange {
318    /// The entry index of the first case. Eg. 10 when the entry indexes are 10, 11, 12 and 13.
319    first_index: EntryIndex,
320
321    /// The blocks to jump to sorted in ascending order of entry index.
322    blocks: Vec<Block>,
323}
324
325impl ContiguousCaseRange {
326    fn new(first_index: EntryIndex) -> Self {
327        Self {
328            first_index,
329            blocks: Vec::new(),
330        }
331    }
332
333    /// Returns `Some` block when there is only a single block in this range.
334    fn single_block(&self) -> Option<Block> {
335        if self.blocks.len() == 1 {
336            Some(self.blocks[0])
337        } else {
338            None
339        }
340    }
341}
342
#[cfg(test)]
mod tests {
    use super::*;
    use crate::frontend::FunctionBuilderContext;
    use alloc::string::ToString;

    /// Render `func` as CLIF text with the signature header and closing brace removed.
    fn body_text(func: &Function) -> String {
        let text = func.to_string();
        text.trim_start_matches("function u0:0() fast {\n")
            .trim_end_matches("\n}\n")
            .to_string()
    }

    /// Emit a switch on an `i8` zero constant with the given case indices (targets are
    /// fresh blocks in order) and default `block$default`, returning the function body.
    macro_rules! setup {
        ($default:expr, [$($index:expr,)*]) => {{
            let mut func = Function::new();
            let mut builder_ctx = FunctionBuilderContext::new();
            {
                let mut bx = FunctionBuilder::new(&mut func, &mut builder_ctx);
                let entry = bx.create_block();
                bx.switch_to_block(entry);
                let val = bx.ins().iconst(types::I8, 0);
                // `switch` is never mutated when the case list is empty.
                #[allow(unused_mut)]
                let mut switch = Switch::new();
                $(
                    let case_block = bx.create_block();
                    switch.set_entry($index, case_block);
                )*
                switch.emit(&mut bx, val, Block::with_number($default).unwrap());
            }
            body_text(&func)
        }};
    }

    /// Emit a switch on an `i64` zero constant (optionally widened to `i128`) with
    /// `high_key => block1`, `0 => block2` and default `block3`; return the function body.
    fn two_case_switch(widen_to_i128: bool, high_key: u128) -> String {
        let mut func = Function::new();
        let mut builder_ctx = FunctionBuilderContext::new();
        {
            let mut bx = FunctionBuilder::new(&mut func, &mut builder_ctx);
            let block0 = bx.create_block();
            bx.switch_to_block(block0);
            let narrow = bx.ins().iconst(types::I64, 0);
            let val = if widen_to_i128 {
                bx.ins().uextend(types::I128, narrow)
            } else {
                narrow
            };
            let mut switch = Switch::new();
            let block1 = bx.create_block();
            switch.set_entry(high_key, block1);
            let block2 = bx.create_block();
            switch.set_entry(0, block2);
            let block3 = bx.create_block();
            switch.emit(&mut bx, val, block3);
        }
        body_text(&func)
    }

    #[test]
    fn switch_empty() {
        let func = setup!(42, []);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i8 0
    jump block42"
        );
    }

    #[test]
    fn switch_zero() {
        let func = setup!(0, [0,]);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i8 0
    brif v0, block0, block1  ; v0 = 0"
        );
    }

    #[test]
    fn switch_single() {
        let func = setup!(0, [1,]);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i8 0
    v1 = icmp_imm eq v0, 1  ; v0 = 0
    brif v1, block1, block0"
        );
    }

    #[test]
    fn switch_bool() {
        let func = setup!(0, [0, 1,]);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i8 0
    v1 = uextend.i32 v0  ; v0 = 0
    br_table v1, block0, [block1, block2]"
        );
    }

    #[test]
    fn switch_two_gap() {
        let func = setup!(0, [0, 2,]);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i8 0
    v1 = icmp_imm eq v0, 2  ; v0 = 0
    brif v1, block2, block3

block3:
    brif.i8 v0, block0, block1  ; v0 = 0"
        );
    }

    #[test]
    fn switch_many() {
        let func = setup!(0, [0, 1, 5, 7, 10, 11, 12,]);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i8 0
    v1 = icmp_imm uge v0, 7  ; v0 = 0
    brif v1, block9, block8

block9:
    v2 = icmp_imm.i8 uge v0, 10  ; v0 = 0
    brif v2, block11, block10

block11:
    v3 = iadd_imm.i8 v0, -10  ; v0 = 0
    v4 = uextend.i32 v3
    br_table v4, block0, [block5, block6, block7]

block10:
    v5 = icmp_imm.i8 eq v0, 7  ; v0 = 0
    brif v5, block4, block0

block8:
    v6 = icmp_imm.i8 eq v0, 5  ; v0 = 0
    brif v6, block3, block12

block12:
    v7 = uextend.i32 v0  ; v0 = 0
    br_table v7, block0, [block1, block2]"
        );
    }

    #[test]
    fn switch_min_index_value() {
        let func = setup!(0, [i8::MIN as u8 as u128, 1,]);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i8 0
    v1 = icmp_imm eq v0, -128  ; v0 = 0
    brif v1, block1, block3

block3:
    v2 = icmp_imm.i8 eq v0, 1  ; v0 = 0
    brif v2, block2, block0"
        );
    }

    #[test]
    fn switch_max_index_value() {
        let func = setup!(0, [i8::MAX as u8 as u128, 1,]);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i8 0
    v1 = icmp_imm eq v0, 127  ; v0 = 0
    brif v1, block1, block3

block3:
    v2 = icmp_imm.i8 eq v0, 1  ; v0 = 0
    brif v2, block2, block0"
        );
    }

    #[test]
    fn switch_optimal_codegen() {
        let func = setup!(0, [-1i8 as u8 as u128, 0, 1,]);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i8 0
    v1 = icmp_imm eq v0, -1  ; v0 = 0
    brif v1, block1, block4

block4:
    v2 = uextend.i32 v0  ; v0 = 0
    br_table v2, block0, [block2, block3]"
        );
    }

    #[test]
    #[should_panic(
        expected = "The index type i8 does not fit the maximum switch entry of 4683743612477887600"
    )]
    fn switch_rejects_small_inputs() {
        // Regression test: a cmp used to be emitted with a type too narrow to
        // represent a large index.
        //
        // See: https://github.com/bytecodealliance/wasmtime/pull/4502#issuecomment-1191961677
        setup!(1, [0x4100_0000_00bf_d470,]);
    }

    #[test]
    fn switch_seal_generated_blocks() {
        let key_sets: [&[u128]; 2] = [&[0, 1, 2], &[0, 1, 2, 10, 11, 12, 20, 30, 40, 50]];
        let value_types = [types::I8, types::I16, types::I32, types::I64, types::I128];

        for &keys in &key_sets {
            for &ty in &value_types {
                eprintln!("Testing {:?} with keys: {:?}", ty, keys);
                do_case(keys, ty);
            }
        }

        fn do_case(keys: &[u128], typ: Type) {
            let mut func = Function::new();
            let mut builder_ctx = FunctionBuilderContext::new();
            let mut builder = FunctionBuilder::new(&mut func, &mut builder_ctx);

            let root_block = builder.create_block();
            let default_block = builder.create_block();
            let mut switch = Switch::new();

            let mut case_blocks = Vec::with_capacity(keys.len());
            for &key in keys {
                let block = builder.create_block();
                switch.set_entry(key, block);
                case_blocks.push(block);
            }

            builder.seal_block(root_block);
            builder.switch_to_block(root_block);

            let val = builder.ins().iconst(typ, 1);
            switch.emit(&mut builder, val, default_block);

            // The switch leaves its target blocks unsealed; seal each and give it a body.
            for block in case_blocks.into_iter().chain(std::iter::once(default_block)) {
                builder.seal_block(block);
                builder.switch_to_block(block);
                builder.ins().return_(&[]);
            }

            builder.finalize(); // Will panic if some blocks are not sealed
        }
    }

    #[test]
    fn switch_64bit() {
        let func = two_case_switch(false, 1);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i64 0
    v1 = icmp_imm ugt v0, 0xffff_ffff  ; v0 = 0
    brif v1, block3, block4

block4:
    v2 = ireduce.i32 v0  ; v0 = 0
    br_table v2, block3, [block2, block1]"
        );
    }

    #[test]
    fn switch_128bit() {
        let func = two_case_switch(true, 1);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i64 0
    v1 = uextend.i128 v0  ; v0 = 0
    v2 = icmp_imm ugt v1, 0xffff_ffff
    brif v2, block3, block4

block4:
    v3 = ireduce.i32 v1
    br_table v3, block3, [block2, block1]"
        );
    }

    #[test]
    fn switch_128bit_max_u64() {
        let func = two_case_switch(true, u64::MAX.into());
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i64 0
    v1 = uextend.i128 v0  ; v0 = 0
    v2 = iconst.i64 -1
    v3 = iconst.i64 0
    v4 = iconcat v2, v3  ; v2 = -1, v3 = 0
    v5 = icmp eq v1, v4
    brif v5, block1, block4

block4:
    brif.i128 v1, block3, block2"
        );
    }
}