1use super::HashMap;
2use crate::frontend::FunctionBuilder;
3use alloc::vec::Vec;
4use cranelift_codegen::ir::condcodes::IntCC;
5use cranelift_codegen::ir::*;
6
7type EntryIndex = u128;
8
9#[derive(Debug, Default)]
42pub struct Switch {
43 cases: HashMap<EntryIndex, Block>,
44}
45
46impl Switch {
47 pub fn new() -> Self {
49 Self {
50 cases: HashMap::new(),
51 }
52 }
53
54 pub fn set_entry(&mut self, index: EntryIndex, block: Block) {
56 let prev = self.cases.insert(index, block);
57 assert!(
58 prev.is_none(),
59 "Tried to set the same entry {} twice",
60 index
61 );
62 }
63
64 pub fn entries(&self) -> &HashMap<EntryIndex, Block> {
66 &self.cases
67 }
68
69 fn collect_contiguous_case_ranges(self) -> Vec<ContiguousCaseRange> {
78 log::trace!("build_contiguous_case_ranges before: {:#?}", self.cases);
79 let mut cases = self.cases.into_iter().collect::<Vec<(_, _)>>();
80 cases.sort_by_key(|&(index, _)| index);
81
82 let mut contiguous_case_ranges: Vec<ContiguousCaseRange> = vec![];
83 let mut last_index = None;
84 for (index, block) in cases {
85 match last_index {
86 None => contiguous_case_ranges.push(ContiguousCaseRange::new(index)),
87 Some(last_index) => {
88 if index > last_index + 1 {
89 contiguous_case_ranges.push(ContiguousCaseRange::new(index));
90 }
91 }
92 }
93 contiguous_case_ranges
94 .last_mut()
95 .unwrap()
96 .blocks
97 .push(block);
98 last_index = Some(index);
99 }
100
101 log::trace!(
102 "build_contiguous_case_ranges after: {:#?}",
103 contiguous_case_ranges
104 );
105
106 contiguous_case_ranges
107 }
108
109 fn build_search_tree<'a>(
111 bx: &mut FunctionBuilder,
112 val: Value,
113 otherwise: Block,
114 contiguous_case_ranges: &'a [ContiguousCaseRange],
115 ) {
116 if contiguous_case_ranges.is_empty() {
118 bx.ins().jump(otherwise, &[]);
119 return;
120 }
121
122 if contiguous_case_ranges.len() <= 3 {
124 Self::build_search_branches(bx, val, otherwise, contiguous_case_ranges);
125 return;
126 }
127
128 let mut stack = Vec::new();
129 stack.push((None, contiguous_case_ranges));
130
131 while let Some((block, contiguous_case_ranges)) = stack.pop() {
132 if let Some(block) = block {
133 bx.switch_to_block(block);
134 }
135
136 if contiguous_case_ranges.len() <= 3 {
137 Self::build_search_branches(bx, val, otherwise, contiguous_case_ranges);
138 } else {
139 let split_point = contiguous_case_ranges.len() / 2;
140 let (left, right) = contiguous_case_ranges.split_at(split_point);
141
142 let left_block = bx.create_block();
143 let right_block = bx.create_block();
144
145 let first_index = right[0].first_index;
146 let should_take_right_side =
147 icmp_imm_u128(bx, IntCC::UnsignedGreaterThanOrEqual, val, first_index);
148 bx.ins()
149 .brif(should_take_right_side, right_block, &[], left_block, &[]);
150
151 bx.seal_block(left_block);
152 bx.seal_block(right_block);
153
154 stack.push((Some(left_block), left));
155 stack.push((Some(right_block), right));
156 }
157 }
158 }
159
160 fn build_search_branches<'a>(
162 bx: &mut FunctionBuilder,
163 val: Value,
164 otherwise: Block,
165 contiguous_case_ranges: &'a [ContiguousCaseRange],
166 ) {
167 for (ix, range) in contiguous_case_ranges.iter().enumerate().rev() {
168 let alternate = if ix == 0 {
169 otherwise
170 } else {
171 bx.create_block()
172 };
173
174 if range.first_index == 0 {
175 assert_eq!(alternate, otherwise);
176
177 if let Some(block) = range.single_block() {
178 bx.ins().brif(val, otherwise, &[], block, &[]);
179 } else {
180 Self::build_jump_table(bx, val, otherwise, 0, &range.blocks);
181 }
182 } else {
183 if let Some(block) = range.single_block() {
184 let is_good_val = icmp_imm_u128(bx, IntCC::Equal, val, range.first_index);
185 bx.ins().brif(is_good_val, block, &[], alternate, &[]);
186 } else {
187 let is_good_val = icmp_imm_u128(
188 bx,
189 IntCC::UnsignedGreaterThanOrEqual,
190 val,
191 range.first_index,
192 );
193 let jt_block = bx.create_block();
194 bx.ins().brif(is_good_val, jt_block, &[], alternate, &[]);
195 bx.seal_block(jt_block);
196 bx.switch_to_block(jt_block);
197 Self::build_jump_table(bx, val, otherwise, range.first_index, &range.blocks);
198 }
199 }
200
201 if alternate != otherwise {
202 bx.seal_block(alternate);
203 bx.switch_to_block(alternate);
204 }
205 }
206 }
207
208 fn build_jump_table(
209 bx: &mut FunctionBuilder,
210 val: Value,
211 otherwise: Block,
212 first_index: EntryIndex,
213 blocks: &[Block],
214 ) {
215 assert!(
218 u32::try_from(blocks.len()).is_ok(),
219 "Jump tables bigger than 2^32-1 are not yet supported"
220 );
221
222 let jt_data = JumpTableData::new(
223 bx.func.dfg.block_call(otherwise, &[]),
224 &blocks
225 .iter()
226 .map(|block| bx.func.dfg.block_call(*block, &[]))
227 .collect::<Vec<_>>(),
228 );
229 let jump_table = bx.create_jump_table(jt_data);
230
231 let discr = if first_index == 0 {
232 val
233 } else {
234 if let Ok(first_index) = u64::try_from(first_index) {
235 bx.ins().iadd_imm(val, (first_index as i64).wrapping_neg())
236 } else {
237 let (lsb, msb) = (first_index as u64, (first_index >> 64) as u64);
238 let lsb = bx.ins().iconst(types::I64, lsb as i64);
239 let msb = bx.ins().iconst(types::I64, msb as i64);
240 let index = bx.ins().iconcat(lsb, msb);
241 bx.ins().isub(val, index)
242 }
243 };
244
245 let discr = match bx.func.dfg.value_type(discr).bits() {
246 bits if bits > 32 => {
247 let new_block = bx.create_block();
249 let bigger_than_u32 =
250 bx.ins()
251 .icmp_imm(IntCC::UnsignedGreaterThan, discr, u32::MAX as i64);
252 bx.ins()
253 .brif(bigger_than_u32, otherwise, &[], new_block, &[]);
254 bx.seal_block(new_block);
255 bx.switch_to_block(new_block);
256
257 bx.ins().ireduce(types::I32, discr)
259 }
260 bits if bits < 32 => bx.ins().uextend(types::I32, discr),
261 _ => discr,
262 };
263
264 bx.ins().br_table(discr, jump_table);
265 }
266
267 pub fn emit(self, bx: &mut FunctionBuilder, val: Value, otherwise: Block) {
275 let max = self.cases.keys().max().copied().unwrap_or(0);
277 let val_ty = bx.func.dfg.value_type(val);
278 let val_ty_max = val_ty.bounds(false).1;
279 if max > val_ty_max {
280 panic!(
281 "The index type {} does not fit the maximum switch entry of {}",
282 val_ty, max
283 );
284 }
285
286 let contiguous_case_ranges = self.collect_contiguous_case_ranges();
287 Self::build_search_tree(bx, val, otherwise, &contiguous_case_ranges);
288 }
289}
290
291fn icmp_imm_u128(bx: &mut FunctionBuilder, cond: IntCC, x: Value, y: u128) -> Value {
292 if bx.func.dfg.value_type(x) != types::I128 {
293 assert!(u64::try_from(y).is_ok());
294 bx.ins().icmp_imm(cond, x, y as i64)
295 } else if let Ok(index) = i64::try_from(y) {
296 bx.ins().icmp_imm(cond, x, index)
297 } else {
298 let (lsb, msb) = (y as u64, (y >> 64) as u64);
299 let lsb = bx.ins().iconst(types::I64, lsb as i64);
300 let msb = bx.ins().iconst(types::I64, msb as i64);
301 let index = bx.ins().iconcat(lsb, msb);
302 bx.ins().icmp(cond, x, index)
303 }
304}
305
306#[derive(Debug)]
317struct ContiguousCaseRange {
318 first_index: EntryIndex,
320
321 blocks: Vec<Block>,
323}
324
325impl ContiguousCaseRange {
326 fn new(first_index: EntryIndex) -> Self {
327 Self {
328 first_index,
329 blocks: Vec::new(),
330 }
331 }
332
333 fn single_block(&self) -> Option<Block> {
335 if self.blocks.len() == 1 {
336 Some(self.blocks[0])
337 } else {
338 None
339 }
340 }
341}
342
#[cfg(test)]
mod tests {
    use super::*;
    use crate::frontend::FunctionBuilderContext;
    use alloc::string::ToString;

    /// Renders `func` as text and strips the `function u0:0() fast {` header
    /// and the closing brace, leaving only the blocks.
    fn body_text(func: &Function) -> alloc::string::String {
        func.to_string()
            .trim_start_matches("function u0:0() fast {\n")
            .trim_end_matches("\n}\n")
            .to_string()
    }

    /// Builds a function that switches on an `i8` constant `0` with one fresh
    /// block per listed index and block number `$default` as the fallback,
    /// then evaluates to the rendered function body.
    macro_rules! setup {
        ($default:expr, [$($index:expr,)*]) => {{
            let mut func = Function::new();
            let mut func_ctx = FunctionBuilderContext::new();
            {
                let mut bx = FunctionBuilder::new(&mut func, &mut func_ctx);
                let block = bx.create_block();
                bx.switch_to_block(block);
                let val = bx.ins().iconst(types::I8, 0);
                // With an empty index list `switch` is never mutated.
                #[allow(unused_mut)]
                let mut switch = Switch::new();
                $(
                    let block = bx.create_block();
                    switch.set_entry($index, block);
                )*
                switch.emit(&mut bx, val, Block::with_number($default).unwrap());
            }
            body_text(&func)
        }};
    }

    /// A switch without cases is a plain jump to the default block.
    #[test]
    fn switch_empty() {
        let func = setup!(42, []);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i8 0
    jump block42"
        );
    }

    /// A lone case `0` branches directly on the value.
    #[test]
    fn switch_zero() {
        let func = setup!(0, [0,]);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i8 0
    brif v0, block0, block1  ; v0 = 0"
        );
    }

    /// A lone non-zero case is an equality test.
    #[test]
    fn switch_single() {
        let func = setup!(0, [1,]);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i8 0
    v1 = icmp_imm eq v0, 1  ; v0 = 0
    brif v1, block1, block0"
        );
    }

    /// Two consecutive cases starting at zero become one jump table.
    #[test]
    fn switch_bool() {
        let func = setup!(0, [0, 1,]);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i8 0
    v1 = uextend.i32 v0  ; v0 = 0
    br_table v1, block0, [block1, block2]"
        );
    }

    /// Two cases separated by a gap become two chained tests.
    #[test]
    fn switch_two_gap() {
        let func = setup!(0, [0, 2,]);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i8 0
    v1 = icmp_imm eq v0, 2  ; v0 = 0
    brif v1, block2, block3

block3:
    brif.i8 v0, block0, block1  ; v0 = 0"
        );
    }

    /// More than three ranges exercise the binary-search split.
    #[test]
    fn switch_many() {
        let func = setup!(0, [0, 1, 5, 7, 10, 11, 12,]);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i8 0
    v1 = icmp_imm uge v0, 7  ; v0 = 0
    brif v1, block9, block8

block9:
    v2 = icmp_imm.i8 uge v0, 10  ; v0 = 0
    brif v2, block11, block10

block11:
    v3 = iadd_imm.i8 v0, -10  ; v0 = 0
    v4 = uextend.i32 v3
    br_table v4, block0, [block5, block6, block7]

block10:
    v5 = icmp_imm.i8 eq v0, 7  ; v0 = 0
    brif v5, block4, block0

block8:
    v6 = icmp_imm.i8 eq v0, 5  ; v0 = 0
    brif v6, block3, block12

block12:
    v7 = uextend.i32 v0  ; v0 = 0
    br_table v7, block0, [block1, block2]"
        );
    }

    /// The bit pattern of `i8::MIN` is accepted as an (unsigned) index.
    #[test]
    fn switch_min_index_value() {
        let func = setup!(0, [i8::MIN as u8 as u128, 1,]);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i8 0
    v1 = icmp_imm eq v0, -128  ; v0 = 0
    brif v1, block1, block3

block3:
    v2 = icmp_imm.i8 eq v0, 1  ; v0 = 0
    brif v2, block2, block0"
        );
    }

    /// `i8::MAX` is accepted as an index.
    #[test]
    fn switch_max_index_value() {
        let func = setup!(0, [i8::MAX as u8 as u128, 1,]);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i8 0
    v1 = icmp_imm eq v0, 127  ; v0 = 0
    brif v1, block1, block3

block3:
    v2 = icmp_imm.i8 eq v0, 1  ; v0 = 0
    brif v2, block2, block0"
        );
    }

    /// Index 255 is not contiguous with 0 and 1, so it gets its own test while
    /// 0 and 1 share a jump table.
    #[test]
    fn switch_optimal_codegen() {
        let func = setup!(0, [-1i8 as u8 as u128, 0, 1,]);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i8 0
    v1 = icmp_imm eq v0, -1  ; v0 = 0
    brif v1, block1, block4

block4:
    v2 = uextend.i32 v0  ; v0 = 0
    br_table v2, block0, [block2, block3]"
        );
    }

    /// An index that does not fit the switched-on type is rejected by `emit`.
    #[test]
    #[should_panic(
        expected = "The index type i8 does not fit the maximum switch entry of 4683743612477887600"
    )]
    fn switch_rejects_small_inputs() {
        setup!(1, [0x4100_0000_00bf_d470,]);
    }

    /// Every block created internally by `emit` must be sealed, for all
    /// integer widths and for both the small and the search-tree paths.
    #[test]
    fn switch_seal_generated_blocks() {
        let cases = &[vec![0, 1, 2], vec![0, 1, 2, 10, 11, 12, 20, 30, 40, 50]];

        for case in cases {
            for typ in &[types::I8, types::I16, types::I32, types::I64, types::I128] {
                eprintln!("Testing {:?} with keys: {:?}", typ, case);
                do_case(case, *typ);
            }
        }

        fn do_case(keys: &[u128], typ: Type) {
            let mut func = Function::new();
            let mut builder_ctx = FunctionBuilderContext::new();
            let mut builder = FunctionBuilder::new(&mut func, &mut builder_ctx);

            let root_block = builder.create_block();
            let default_block = builder.create_block();
            let mut switch = Switch::new();

            let case_blocks = keys
                .iter()
                .map(|key| {
                    let block = builder.create_block();
                    switch.set_entry(*key, block);
                    block
                })
                .collect::<Vec<_>>();

            builder.seal_block(root_block);
            builder.switch_to_block(root_block);

            let val = builder.ins().iconst(typ, 1);
            switch.emit(&mut builder, val, default_block);

            // Only the blocks created by this test are sealed here; anything
            // `emit` created has to be sealed by `emit` itself.
            for &block in case_blocks.iter().chain(std::iter::once(&default_block)) {
                builder.seal_block(block);
                builder.switch_to_block(block);
                builder.ins().return_(&[]);
            }

            // NOTE(review): presumably this panics if a block was left
            // unsealed - confirm against `FunctionBuilder::finalize`.
            builder.finalize();
        }
    }

    /// A 64-bit index is range-checked against `u32::MAX` and reduced to
    /// `i32` before the jump table.
    #[test]
    fn switch_64bit() {
        let mut func = Function::new();
        let mut func_ctx = FunctionBuilderContext::new();
        {
            let mut bx = FunctionBuilder::new(&mut func, &mut func_ctx);
            let block0 = bx.create_block();
            bx.switch_to_block(block0);
            let val = bx.ins().iconst(types::I64, 0);
            let mut switch = Switch::new();
            let block1 = bx.create_block();
            switch.set_entry(1, block1);
            let block2 = bx.create_block();
            switch.set_entry(0, block2);
            let block3 = bx.create_block();
            switch.emit(&mut bx, val, block3);
        }
        let func = body_text(&func);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i64 0
    v1 = icmp_imm ugt v0, 0xffff_ffff  ; v0 = 0
    brif v1, block3, block4

block4:
    v2 = ireduce.i32 v0  ; v0 = 0
    br_table v2, block3, [block2, block1]"
        );
    }

    /// Same as `switch_64bit`, but switching on an `i128` value.
    #[test]
    fn switch_128bit() {
        let mut func = Function::new();
        let mut func_ctx = FunctionBuilderContext::new();
        {
            let mut bx = FunctionBuilder::new(&mut func, &mut func_ctx);
            let block0 = bx.create_block();
            bx.switch_to_block(block0);
            let val = bx.ins().iconst(types::I64, 0);
            let val = bx.ins().uextend(types::I128, val);
            let mut switch = Switch::new();
            let block1 = bx.create_block();
            switch.set_entry(1, block1);
            let block2 = bx.create_block();
            switch.set_entry(0, block2);
            let block3 = bx.create_block();
            switch.emit(&mut bx, val, block3);
        }
        let func = body_text(&func);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i64 0
    v1 = uextend.i128 v0  ; v0 = 0
    v2 = icmp_imm ugt v1, 0xffff_ffff
    brif v2, block3, block4

block4:
    v3 = ireduce.i32 v1
    br_table v3, block3, [block2, block1]"
        );
    }

    /// An `i128` comparison against `u64::MAX` cannot use an `i64` immediate
    /// and must build the constant with `iconcat`.
    #[test]
    fn switch_128bit_max_u64() {
        let mut func = Function::new();
        let mut func_ctx = FunctionBuilderContext::new();
        {
            let mut bx = FunctionBuilder::new(&mut func, &mut func_ctx);
            let block0 = bx.create_block();
            bx.switch_to_block(block0);
            let val = bx.ins().iconst(types::I64, 0);
            let val = bx.ins().uextend(types::I128, val);
            let mut switch = Switch::new();
            let block1 = bx.create_block();
            switch.set_entry(u64::MAX.into(), block1);
            let block2 = bx.create_block();
            switch.set_entry(0, block2);
            let block3 = bx.create_block();
            switch.emit(&mut bx, val, block3);
        }
        let func = body_text(&func);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i64 0
    v1 = uextend.i128 v0  ; v0 = 0
    v2 = iconst.i64 -1
    v3 = iconst.i64 0
    v4 = iconcat v2, v3  ; v2 = -1, v3 = 0
    v5 = icmp eq v1, v4
    brif v5, block1, block4

block4:
    brif.i128 v1, block3, block2"
        );
    }
}