1use super::HashMap;
2use crate::frontend::FunctionBuilder;
3use alloc::vec::Vec;
4use cranelift_codegen::ir::condcodes::IntCC;
5use cranelift_codegen::ir::*;
6
7type EntryIndex = u128;
8
9#[derive(Debug, Default)]
42pub struct Switch {
43 cases: HashMap<EntryIndex, Block>,
44}
45
46impl Switch {
47 pub fn new() -> Self {
49 Self {
50 cases: HashMap::new(),
51 }
52 }
53
54 pub fn set_entry(&mut self, index: EntryIndex, block: Block) {
56 let prev = self.cases.insert(index, block);
57 assert!(prev.is_none(), "Tried to set the same entry {index} twice");
58 }
59
60 pub fn entries(&self) -> &HashMap<EntryIndex, Block> {
62 &self.cases
63 }
64
65 fn collect_contiguous_case_ranges(self) -> Vec<ContiguousCaseRange> {
74 log::trace!("build_contiguous_case_ranges before: {:#?}", self.cases);
75 let mut cases = self.cases.into_iter().collect::<Vec<(_, _)>>();
76 cases.sort_by_key(|&(index, _)| index);
77
78 let mut contiguous_case_ranges: Vec<ContiguousCaseRange> = vec![];
79 let mut last_index = None;
80 for (index, block) in cases {
81 match last_index {
82 None => contiguous_case_ranges.push(ContiguousCaseRange::new(index)),
83 Some(last_index) => {
84 if index > last_index + 1 {
85 contiguous_case_ranges.push(ContiguousCaseRange::new(index));
86 }
87 }
88 }
89 contiguous_case_ranges
90 .last_mut()
91 .unwrap()
92 .blocks
93 .push(block);
94 last_index = Some(index);
95 }
96
97 log::trace!("build_contiguous_case_ranges after: {contiguous_case_ranges:#?}");
98
99 contiguous_case_ranges
100 }
101
102 fn build_search_tree<'a>(
104 bx: &mut FunctionBuilder,
105 val: Value,
106 otherwise: Block,
107 contiguous_case_ranges: &'a [ContiguousCaseRange],
108 ) {
109 if contiguous_case_ranges.is_empty() {
111 bx.ins().jump(otherwise, &[]);
112 return;
113 }
114
115 if contiguous_case_ranges.len() <= 3 {
117 Self::build_search_branches(bx, val, otherwise, contiguous_case_ranges);
118 return;
119 }
120
121 let mut stack = Vec::new();
122 stack.push((None, contiguous_case_ranges));
123
124 while let Some((block, contiguous_case_ranges)) = stack.pop() {
125 if let Some(block) = block {
126 bx.switch_to_block(block);
127 }
128
129 if contiguous_case_ranges.len() <= 3 {
130 Self::build_search_branches(bx, val, otherwise, contiguous_case_ranges);
131 } else {
132 let split_point = contiguous_case_ranges.len() / 2;
133 let (left, right) = contiguous_case_ranges.split_at(split_point);
134
135 let left_block = bx.create_block();
136 let right_block = bx.create_block();
137
138 let first_index = right[0].first_index;
139 let should_take_right_side =
140 icmp_imm_u128(bx, IntCC::UnsignedGreaterThanOrEqual, val, first_index);
141 bx.ins()
142 .brif(should_take_right_side, right_block, &[], left_block, &[]);
143
144 bx.seal_block(left_block);
145 bx.seal_block(right_block);
146
147 stack.push((Some(left_block), left));
148 stack.push((Some(right_block), right));
149 }
150 }
151 }
152
153 fn build_search_branches<'a>(
155 bx: &mut FunctionBuilder,
156 val: Value,
157 otherwise: Block,
158 contiguous_case_ranges: &'a [ContiguousCaseRange],
159 ) {
160 for (ix, range) in contiguous_case_ranges.iter().enumerate().rev() {
161 let alternate = if ix == 0 {
162 otherwise
163 } else {
164 bx.create_block()
165 };
166
167 if range.first_index == 0 {
168 assert_eq!(alternate, otherwise);
169
170 if let Some(block) = range.single_block() {
171 bx.ins().brif(val, otherwise, &[], block, &[]);
172 } else {
173 Self::build_jump_table(bx, val, otherwise, 0, &range.blocks);
174 }
175 } else {
176 if let Some(block) = range.single_block() {
177 let is_good_val = icmp_imm_u128(bx, IntCC::Equal, val, range.first_index);
178 bx.ins().brif(is_good_val, block, &[], alternate, &[]);
179 } else {
180 let is_good_val = icmp_imm_u128(
181 bx,
182 IntCC::UnsignedGreaterThanOrEqual,
183 val,
184 range.first_index,
185 );
186 let jt_block = bx.create_block();
187 bx.ins().brif(is_good_val, jt_block, &[], alternate, &[]);
188 bx.seal_block(jt_block);
189 bx.switch_to_block(jt_block);
190 Self::build_jump_table(bx, val, otherwise, range.first_index, &range.blocks);
191 }
192 }
193
194 if alternate != otherwise {
195 bx.seal_block(alternate);
196 bx.switch_to_block(alternate);
197 }
198 }
199 }
200
201 fn build_jump_table(
202 bx: &mut FunctionBuilder,
203 val: Value,
204 otherwise: Block,
205 first_index: EntryIndex,
206 blocks: &[Block],
207 ) {
208 assert!(
211 u32::try_from(blocks.len()).is_ok(),
212 "Jump tables bigger than 2^32-1 are not yet supported"
213 );
214
215 let jt_data = JumpTableData::new(
216 bx.func.dfg.block_call(otherwise, &[]),
217 &blocks
218 .iter()
219 .map(|block| bx.func.dfg.block_call(*block, &[]))
220 .collect::<Vec<_>>(),
221 );
222 let jump_table = bx.create_jump_table(jt_data);
223
224 let discr = if first_index == 0 {
225 val
226 } else {
227 if let Ok(first_index) = u64::try_from(first_index) {
228 bx.ins().iadd_imm(val, (first_index as i64).wrapping_neg())
229 } else {
230 let (lsb, msb) = (first_index as u64, (first_index >> 64) as u64);
231 let lsb = bx.ins().iconst(types::I64, lsb as i64);
232 let msb = bx.ins().iconst(types::I64, msb as i64);
233 let index = bx.ins().iconcat(lsb, msb);
234 bx.ins().isub(val, index)
235 }
236 };
237
238 let discr = match bx.func.dfg.value_type(discr).bits() {
239 bits if bits > 32 => {
240 let new_block = bx.create_block();
242 let bigger_than_u32 =
243 bx.ins()
244 .icmp_imm(IntCC::UnsignedGreaterThan, discr, u32::MAX as i64);
245 bx.ins()
246 .brif(bigger_than_u32, otherwise, &[], new_block, &[]);
247 bx.seal_block(new_block);
248 bx.switch_to_block(new_block);
249
250 bx.ins().ireduce(types::I32, discr)
252 }
253 bits if bits < 32 => bx.ins().uextend(types::I32, discr),
254 _ => discr,
255 };
256
257 bx.ins().br_table(discr, jump_table);
258 }
259
260 pub fn emit(self, bx: &mut FunctionBuilder, val: Value, otherwise: Block) {
268 let max = self.cases.keys().max().copied().unwrap_or(0);
270 let val_ty = bx.func.dfg.value_type(val);
271 let val_ty_max = val_ty.bounds(false).1;
272 if max > val_ty_max {
273 panic!("The index type {val_ty} does not fit the maximum switch entry of {max}");
274 }
275
276 let contiguous_case_ranges = self.collect_contiguous_case_ranges();
277 Self::build_search_tree(bx, val, otherwise, &contiguous_case_ranges);
278 }
279}
280
281fn icmp_imm_u128(bx: &mut FunctionBuilder, cond: IntCC, x: Value, y: u128) -> Value {
282 if bx.func.dfg.value_type(x) != types::I128 {
283 assert!(u64::try_from(y).is_ok());
284 bx.ins().icmp_imm(cond, x, y as i64)
285 } else if let Ok(index) = i64::try_from(y) {
286 bx.ins().icmp_imm(cond, x, index)
287 } else {
288 let (lsb, msb) = (y as u64, (y >> 64) as u64);
289 let lsb = bx.ins().iconst(types::I64, lsb as i64);
290 let msb = bx.ins().iconst(types::I64, msb as i64);
291 let index = bx.ins().iconcat(lsb, msb);
292 bx.ins().icmp(cond, x, index)
293 }
294}
295
296#[derive(Debug)]
307struct ContiguousCaseRange {
308 first_index: EntryIndex,
310
311 blocks: Vec<Block>,
313}
314
315impl ContiguousCaseRange {
316 fn new(first_index: EntryIndex) -> Self {
317 Self {
318 first_index,
319 blocks: Vec::new(),
320 }
321 }
322
323 fn single_block(&self) -> Option<Block> {
325 if self.blocks.len() == 1 {
326 Some(self.blocks[0])
327 } else {
328 None
329 }
330 }
331}
332
// Unit tests for `Switch` lowering. Most tests build a tiny function, emit a switch, and
// compare the printed CLIF body against the expected text.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::frontend::FunctionBuilderContext;
    use alloc::string::ToString;

    // Builds a function whose entry block switches on `iconst.i8 0`.
    // - Each listed index is mapped to a freshly created block, in list order.
    // - `block$default` is the fallback target.
    // Evaluates to the printed function with its header line and closing brace stripped.
    macro_rules! setup {
        ($default:expr, [$($index:expr,)*]) => {{
            let mut func = Function::new();
            let mut func_ctx = FunctionBuilderContext::new();
            {
                let mut bx = FunctionBuilder::new(&mut func, &mut func_ctx);
                let block = bx.create_block();
                bx.switch_to_block(block);
                let val = bx.ins().iconst(types::I8, 0);
                let mut switch = Switch::new();
                // NOTE(review): presumably keeps `mut` used (no `unused_mut` warning) when
                // the case list is empty.
                let _ = &mut switch;
                $(
                    let block = bx.create_block();
                    switch.set_entry($index, block);
                )*
                switch.emit(&mut bx, val, Block::with_number($default).unwrap());
            }
            func
                .to_string()
                .trim_start_matches("function u0:0() fast {\n")
                .trim_end_matches("\n}\n")
                .to_string()
        }};
    }

    // No cases: the switch is just an unconditional jump to the default block.
    #[test]
    fn switch_empty() {
        let func = setup!(42, []);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i8 0
    jump block42"
        );
    }

    // A single case at index 0 is tested with `brif` on the value itself.
    #[test]
    fn switch_zero() {
        let func = setup!(0, [0,]);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i8 0
    brif v0, block0, block1  ; v0 = 0"
        );
    }

    // A single nonzero case is tested with `icmp_imm eq`.
    #[test]
    fn switch_single() {
        let func = setup!(0, [1,]);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i8 0
    v1 = icmp_imm eq v0, 1  ; v0 = 0
    brif v1, block1, block0"
        );
    }

    // Two consecutive cases starting at 0 become a jump table indexed by the value directly.
    #[test]
    fn switch_bool() {
        let func = setup!(0, [0, 1,]);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i8 0
    v1 = uextend.i32 v0  ; v0 = 0
    br_table v1, block0, [block1, block2]"
        );
    }

    // Two non-consecutive cases: the higher one is tested first, then the one at 0.
    #[test]
    fn switch_two_gap() {
        let func = setup!(0, [0, 2,]);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i8 0
    v1 = icmp_imm eq v0, 2  ; v0 = 0
    brif v1, block2, block3

block3:
    brif.i8 v0, block0, block1  ; v0 = 0"
        );
    }

    // Five ranges ({0,1}, {5}, {7}, {10,11,12}) exceed the linear-chain limit, so a binary
    // search split is emitted before the per-range tests.
    #[test]
    fn switch_many() {
        let func = setup!(0, [0, 1, 5, 7, 10, 11, 12,]);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i8 0
    v1 = icmp_imm uge v0, 7  ; v0 = 0
    brif v1, block9, block8

block9:
    v2 = icmp_imm.i8 uge v0, 10  ; v0 = 0
    brif v2, block11, block10

block11:
    v3 = iadd_imm.i8 v0, -10  ; v0 = 0
    v4 = uextend.i32 v3
    br_table v4, block0, [block5, block6, block7]

block10:
    v5 = icmp_imm.i8 eq v0, 7  ; v0 = 0
    brif v5, block4, block0

block8:
    v6 = icmp_imm.i8 eq v0, 5  ; v0 = 0
    brif v6, block3, block12

block12:
    v7 = uextend.i32 v0  ; v0 = 0
    br_table v7, block0, [block1, block2]"
        );
    }

    // Index 0x80 on an i8 value prints as the immediate -128.
    #[test]
    fn switch_min_index_value() {
        let func = setup!(0, [i8::MIN as u8 as u128, 1,]);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i8 0
    v1 = icmp_imm eq v0, -128  ; v0 = 0
    brif v1, block1, block3

block3:
    v2 = icmp_imm.i8 eq v0, 1  ; v0 = 0
    brif v2, block2, block0"
        );
    }

    // Index 0x7f on an i8 value prints as the immediate 127.
    #[test]
    fn switch_max_index_value() {
        let func = setup!(0, [i8::MAX as u8 as u128, 1,]);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i8 0
    v1 = icmp_imm eq v0, 127  ; v0 = 0
    brif v1, block1, block3

block3:
    v2 = icmp_imm.i8 eq v0, 1  ; v0 = 0
    brif v2, block2, block0"
        )
    }

    // Index 0xff is its own range; {0, 1} still becomes a jump table without an offset.
    #[test]
    fn switch_optimal_codegen() {
        let func = setup!(0, [-1i8 as u8 as u128, 0, 1,]);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i8 0
    v1 = icmp_imm eq v0, -1  ; v0 = 0
    brif v1, block1, block4

block4:
    v2 = uextend.i32 v0  ; v0 = 0
    br_table v2, block0, [block2, block3]"
        );
    }

    // `emit` must reject a case index that does not fit the scrutinee's type.
    #[test]
    #[should_panic(
        expected = "The index type i8 does not fit the maximum switch entry of 4683743612477887600"
    )]
    fn switch_rejects_small_inputs() {
        setup!(1, [0x4100_0000_00bf_d470,]);
    }

    // Emits switches over every integer type, for key sets that use both the linear chain
    // (one range) and the binary search tree (six ranges), and finalizes each function.
    #[test]
    fn switch_seal_generated_blocks() {
        let cases = &[vec![0, 1, 2], vec![0, 1, 2, 10, 11, 12, 20, 30, 40, 50]];

        for case in cases {
            for typ in &[types::I8, types::I16, types::I32, types::I64, types::I128] {
                eprintln!("Testing {typ:?} with keys: {case:?}");
                do_case(case, *typ);
            }
        }

        // Emits a switch over `keys` on a constant of type `typ`, terminates every target
        // block with `return`, and finalizes the builder.
        // NOTE(review): relies on `finalize` to catch blocks left unsealed by `emit` — confirm.
        fn do_case(keys: &[u128], typ: Type) {
            let mut func = Function::new();
            let mut builder_ctx = FunctionBuilderContext::new();
            let mut builder = FunctionBuilder::new(&mut func, &mut builder_ctx);

            let root_block = builder.create_block();
            let default_block = builder.create_block();
            let mut switch = Switch::new();

            let case_blocks = keys
                .iter()
                .map(|key| {
                    let block = builder.create_block();
                    switch.set_entry(*key, block);
                    block
                })
                .collect::<Vec<_>>();

            builder.seal_block(root_block);
            builder.switch_to_block(root_block);

            let val = builder.ins().iconst(typ, 1);
            switch.emit(&mut builder, val, default_block);

            // Case blocks and the default block are not sealed by `emit`; do it here.
            for &block in case_blocks.iter().chain(std::iter::once(&default_block)) {
                builder.seal_block(block);
                builder.switch_to_block(block);
                builder.ins().return_(&[]);
            }

            builder.finalize();
        }
    }

    // An i64 scrutinee is range-checked against u32::MAX before `ireduce` to the i32
    // `br_table` index.
    #[test]
    fn switch_64bit() {
        let mut func = Function::new();
        let mut func_ctx = FunctionBuilderContext::new();
        {
            let mut bx = FunctionBuilder::new(&mut func, &mut func_ctx);
            let block0 = bx.create_block();
            bx.switch_to_block(block0);
            let val = bx.ins().iconst(types::I64, 0);
            let mut switch = Switch::new();
            let block1 = bx.create_block();
            switch.set_entry(1, block1);
            let block2 = bx.create_block();
            switch.set_entry(0, block2);
            let block3 = bx.create_block();
            switch.emit(&mut bx, val, block3);
        }
        let func = func
            .to_string()
            .trim_start_matches("function u0:0() fast {\n")
            .trim_end_matches("\n}\n")
            .to_string();
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i64 0
    v1 = icmp_imm ugt v0, 0xffff_ffff  ; v0 = 0
    brif v1, block3, block4

block4:
    v2 = ireduce.i32 v0  ; v0 = 0
    br_table v2, block3, [block2, block1]"
        );
    }

    // Same as `switch_64bit`, but with an i128 scrutinee.
    #[test]
    fn switch_128bit() {
        let mut func = Function::new();
        let mut func_ctx = FunctionBuilderContext::new();
        {
            let mut bx = FunctionBuilder::new(&mut func, &mut func_ctx);
            let block0 = bx.create_block();
            bx.switch_to_block(block0);
            let val = bx.ins().iconst(types::I64, 0);
            let val = bx.ins().uextend(types::I128, val);
            let mut switch = Switch::new();
            let block1 = bx.create_block();
            switch.set_entry(1, block1);
            let block2 = bx.create_block();
            switch.set_entry(0, block2);
            let block3 = bx.create_block();
            switch.emit(&mut bx, val, block3);
        }
        let func = func
            .to_string()
            .trim_start_matches("function u0:0() fast {\n")
            .trim_end_matches("\n}\n")
            .to_string();
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i64 0
    v1 = uextend.i128 v0  ; v0 = 0
    v2 = icmp_imm ugt v1, 0xffff_ffff
    brif v2, block3, block4

block4:
    v3 = ireduce.i32 v1
    br_table v3, block3, [block2, block1]"
        );
    }

    // u64::MAX exceeds i64::MAX, so the i128 comparison cannot use a sign-extended
    // `icmp_imm` immediate. The constant is built with `iconcat` and compared with `icmp`.
    #[test]
    fn switch_128bit_max_u64() {
        let mut func = Function::new();
        let mut func_ctx = FunctionBuilderContext::new();
        {
            let mut bx = FunctionBuilder::new(&mut func, &mut func_ctx);
            let block0 = bx.create_block();
            bx.switch_to_block(block0);
            let val = bx.ins().iconst(types::I64, 0);
            let val = bx.ins().uextend(types::I128, val);
            let mut switch = Switch::new();
            let block1 = bx.create_block();
            switch.set_entry(u64::MAX.into(), block1);
            let block2 = bx.create_block();
            switch.set_entry(0, block2);
            let block3 = bx.create_block();
            switch.emit(&mut bx, val, block3);
        }
        let func = func
            .to_string()
            .trim_start_matches("function u0:0() fast {\n")
            .trim_end_matches("\n}\n")
            .to_string();
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i64 0
    v1 = uextend.i128 v0  ; v0 = 0
    v2 = iconst.i64 -1
    v3 = iconst.i64 0
    v4 = iconcat v2, v3  ; v2 = -1, v3 = 0
    v5 = icmp eq v1, v4
    brif v5, block1, block4

block4:
    brif.i128 v1, block3, block2"
        );
    }
}