1mod layout;
2
3use std::{iter, slice};
4
5use crate::types::{HugrSumType, TypingSession};
6
7use anyhow::{Result, anyhow, bail, ensure};
8use delegate::delegate;
9use hugr_core::types::TypeRow;
10use inkwell::{
11 builder::Builder,
12 context::Context,
13 types::{AnyType, AsTypeRef, BasicType, BasicTypeEnum, IntType, StructType},
14 values::{AnyValue, AsValueRef, BasicValue, BasicValueEnum, IntValue, StructValue},
15};
16use itertools::{Itertools as _, izip};
17
18pub fn elidable_type<'c>(ty: impl BasicType<'c>) -> bool {
25 let ty = ty.as_basic_type_enum();
26 match ty {
27 BasicTypeEnum::ArrayType(array_type) => array_type.is_empty(),
28 BasicTypeEnum::StructType(struct_type) => struct_type.count_fields() == 0,
29 _ => false,
30 }
31}
32
33fn get_variant_typerow(sum_type: &HugrSumType, tag: u32) -> Result<TypeRow> {
34 sum_type
35 .get_variant(tag as usize)
36 .ok_or(anyhow!("Bad variant index {tag} in {sum_type}"))
37 .and_then(|tr| Ok(TypeRow::try_from(tr.clone())?))
38}
39
40fn basic_type_undef<'c>(t: impl BasicType<'c>) -> BasicValueEnum<'c> {
42 let t = t.as_basic_type_enum();
43 match t {
44 BasicTypeEnum::ArrayType(t) => t.get_undef().as_basic_value_enum(),
45 BasicTypeEnum::FloatType(t) => t.get_undef().as_basic_value_enum(),
46 BasicTypeEnum::IntType(t) => t.get_undef().as_basic_value_enum(),
47 BasicTypeEnum::PointerType(t) => t.get_undef().as_basic_value_enum(),
48 BasicTypeEnum::StructType(t) => t.get_undef().as_basic_value_enum(),
49 BasicTypeEnum::VectorType(t) => t.get_undef().as_basic_value_enum(),
50 BasicTypeEnum::ScalableVectorType(t) => t.get_undef().as_basic_value_enum(),
51 }
52}
53
54fn basic_type_poison<'c>(t: impl BasicType<'c>) -> BasicValueEnum<'c> {
56 let t = t.as_basic_type_enum();
57 match t {
58 BasicTypeEnum::ArrayType(t) => t.get_poison().as_basic_value_enum(),
59 BasicTypeEnum::FloatType(t) => t.get_poison().as_basic_value_enum(),
60 BasicTypeEnum::IntType(t) => t.get_poison().as_basic_value_enum(),
61 BasicTypeEnum::PointerType(t) => t.get_poison().as_basic_value_enum(),
62 BasicTypeEnum::StructType(t) => t.get_poison().as_basic_value_enum(),
63 BasicTypeEnum::VectorType(t) => t.get_poison().as_basic_value_enum(),
64 BasicTypeEnum::ScalableVectorType(t) => t.get_poison().as_basic_value_enum(),
65 }
66}
67
/// The LLVM representation of a HUGR sum type.
///
/// Wraps a private `LLVMSumTypeEnum`, which chooses a representation based on
/// the number of variants and on which fields survive layout. `Display` is
/// derived via `derive_more` from the wrapped representation.
#[derive(Debug, Clone, derive_more::Display)]
pub struct LLVMSumType<'c>(LLVMSumTypeEnum<'c>);
93
94impl<'c> LLVMSumType<'c> {
95 delegate! {
96 to self.0 {
97 #[must_use] pub fn value_type(&self) -> BasicTypeEnum<'c>;
99 #[must_use] pub fn tag_type(&self) -> IntType<'c>;
101 #[must_use] pub fn num_variants(&self) -> usize;
103 #[must_use] pub fn num_fields_for_variant(&self, tag: usize) -> usize;
106 #[must_use] pub fn fields_for_variant(&self, tag: usize) -> &[BasicTypeEnum<'c>];
109 }
110 }
111
112 pub fn try_from_hugr_type(
118 session: &TypingSession<'c, '_>,
119 sum_type: HugrSumType,
120 ) -> Result<Self> {
121 let variants = (0..sum_type.num_variants())
122 .map(|i| {
123 let tr = get_variant_typerow(&sum_type, i as u32)?;
124 tr.iter()
125 .map(|t| session.llvm_type(t))
126 .collect::<Result<Vec<_>>>()
127 })
128 .collect::<Result<Vec<_>>>()?;
129 Self::try_new(session.iw_context(), variants)
130 }
131
132 pub fn try_new(
137 context: &'c Context,
138 variant_types: impl Into<Vec<Vec<BasicTypeEnum<'c>>>>,
139 ) -> Result<Self> {
140 Ok(Self(LLVMSumTypeEnum::try_new(
141 context,
142 variant_types.into(),
143 )?))
144 }
145
146 #[must_use]
148 pub fn get_undef(&self) -> impl BasicValue<'c> + use<'c> {
149 basic_type_undef(self.0.value_type())
150 }
151
152 #[must_use]
154 pub fn get_poison(&self) -> impl BasicValue<'c> + use<'c> {
155 basic_type_poison(self.0.value_type())
156 }
157
158 pub fn build_tag(
161 &self,
162 builder: &Builder<'c>,
163 tag: usize,
164 vs: Vec<BasicValueEnum<'c>>,
165 ) -> Result<LLVMSumValue<'c>> {
166 self.value(self.0.build_tag(builder, tag, vs)?)
167 }
168
169 pub fn value(&self, value: impl BasicValue<'c>) -> Result<LLVMSumValue<'c>> {
173 LLVMSumValue::try_new(value, self.clone())
174 }
175}
176
/// The possible LLVM representations of a sum type.
///
/// `LLVMSumTypeEnum::try_new` picks the representation from the number of
/// variants and from how many fields `layout::layout_variants` keeps.
#[derive(Debug, Clone)]
enum LLVMSumTypeEnum<'c> {
    /// A sum with no variants. `LLVMSumValue::try_new` refuses to build values
    /// of it; its `i1` `tag_type` doubles as its value type.
    Void { tag_type: IntType<'c> },
    /// A sum with one variant for which layout kept no fields. Values are of
    /// the empty struct type `value_type`.
    Unit {
        /// The LLVM types of all fields of the single variant.
        field_types: Vec<BasicTypeEnum<'c>>,
        /// The tag type (`i1`); the tag is always 0.
        tag_type: IntType<'c>,
        /// The empty struct type `{}`.
        value_type: StructType<'c>,
    },
    /// A sum with two or more variants for which layout kept no fields.
    /// Values are just the tag.
    NoFields {
        /// The LLVM field types of each variant.
        variant_types: Vec<Vec<BasicTypeEnum<'c>>>,
        /// The tag type, sized by `tag_width_for_num_variants`; also the value
        /// type.
        value_type: IntType<'c>,
    },
    /// A sum with one variant for which layout kept exactly one field. Values
    /// are the value of that field, unwrapped.
    SingleVariantSingleField {
        /// The LLVM types of all fields of the single variant.
        field_types: Vec<BasicTypeEnum<'c>>,
        /// Index into `field_types` of the one field kept by layout.
        field_index: usize,
        /// The tag type (`i1`); the tag is always 0.
        tag_type: IntType<'c>,
    },
    /// A sum with one variant for which layout kept two or more fields. Values
    /// are a struct of the kept fields.
    SingleVariantMultiField {
        /// The LLVM types of all fields of the single variant.
        field_types: Vec<BasicTypeEnum<'c>>,
        /// For each entry of `field_types`: `Some(j)` if that field is stored at
        /// index `j` of `value_type`, `None` if layout elided it.
        field_indices: Vec<Option<usize>>,
        /// The tag type (`i1`); the tag is always 0.
        tag_type: IntType<'c>,
        /// The struct of kept fields.
        value_type: StructType<'c>,
    },
    /// A sum with two or more variants for which layout kept at least one
    /// field. Values are a struct whose field 0 is the tag, followed by the
    /// fields produced by layout.
    MultiVariant {
        /// The LLVM field types of each variant.
        variant_types: Vec<Vec<BasicTypeEnum<'c>>>,
        /// `field_indices[t][i]` is `Some(j)` if field `i` of variant `t` is
        /// stored at index `j + 1` of `value_type` (after the tag), `None` if
        /// layout elided it.
        field_indices: Vec<Vec<Option<usize>>>,
        /// The struct type; its field 0 is the tag.
        value_type: StructType<'c>,
    },
}
264
/// Bit width of a tag able to distinguish `num_variants` variants:
/// `ceil(log2(num_variants))`, but never less than one bit.
///
/// `num_variants` must be at least 1 (checked with `debug_assert!`).
fn tag_width_for_num_variants(num_variants: usize) -> u32 {
    debug_assert!(num_variants >= 1);
    match num_variants {
        // A lone variant still gets a (constant) 1-bit tag.
        1 => 1,
        // Number of significant bits in the largest tag, `n - 1`.
        n => usize::BITS - (n - 1).leading_zeros(),
    }
}
273
274impl<'c> LLVMSumTypeEnum<'c> {
275 pub fn try_new(
278 context: &'c Context,
279 variant_types: Vec<Vec<BasicTypeEnum<'c>>>,
280 ) -> Result<Self> {
281 let result = match variant_types.len() {
282 0 => Self::Void {
283 tag_type: context.bool_type(),
284 },
285 1 => {
286 let variant_types = variant_types.into_iter().exactly_one().unwrap();
287 let (fields, field_indices) =
288 layout::layout_variants(slice::from_ref(&variant_types));
289 let field_indices = field_indices.into_iter().exactly_one().unwrap();
290 match fields.len() {
291 0 => Self::Unit {
292 field_types: variant_types,
293 tag_type: context.bool_type(),
294 value_type: context.struct_type(&[], false),
295 },
296 1 => {
297 let field_index = field_indices
298 .into_iter()
299 .enumerate()
300 .filter_map(|(i, f_i)| f_i.is_some().then_some(i))
301 .exactly_one()
302 .unwrap();
303 Self::SingleVariantSingleField {
304 field_types: variant_types,
305 field_index,
306 tag_type: context.bool_type(),
307 }
308 }
309 _num_fields => Self::SingleVariantMultiField {
310 field_types: variant_types,
311 field_indices,
312 tag_type: context.bool_type(),
313 value_type: context.struct_type(&fields, false),
314 },
315 }
316 }
317 num_variants => {
318 let (mut fields, field_indices) = layout::layout_variants(&variant_types);
319 let tag_type =
320 context.custom_width_int_type(tag_width_for_num_variants(num_variants));
321 if fields.is_empty() {
322 Self::NoFields {
323 variant_types,
324 value_type: tag_type,
325 }
326 } else {
327 fields.insert(0, tag_type.into());
329 let value_type = context.struct_type(&fields, false);
330 Self::MultiVariant {
331 variant_types,
332 field_indices,
333 value_type,
334 }
335 }
336 }
337 };
338 Ok(result)
339 }
340
341 pub fn build_tag(
349 &self,
350 builder: &Builder<'c>,
351 tag: usize,
352 vs: Vec<BasicValueEnum<'c>>,
353 ) -> Result<BasicValueEnum<'c>> {
354 ensure!(tag < self.num_variants());
355 ensure!(vs.len() == self.num_fields_for_variant(tag));
356 ensure!(iter::zip(&vs, self.fields_for_variant(tag)).all(|(x, y)| &x.get_type() == y));
357 let value = match self {
358 Self::Void { .. } => bail!("Can't tag an empty sum"),
359 Self::Unit { value_type, .. } => value_type.get_undef().as_basic_value_enum(),
360 Self::NoFields { value_type, .. } => value_type
361 .const_int(tag as u64, false)
362 .as_basic_value_enum(),
363 Self::SingleVariantSingleField { field_index, .. } => vs[*field_index],
364 Self::SingleVariantMultiField {
365 value_type,
366 field_indices,
367 ..
368 } => {
369 let mut value = value_type.get_poison();
370 for (mb_i, v) in itertools::zip_eq(field_indices, vs) {
371 if let Some(i) = mb_i {
372 value = builder
373 .build_insert_value(value, v, *i as u32, "")?
374 .into_struct_value();
375 }
376 }
377 value.as_basic_value_enum()
378 }
379 Self::MultiVariant {
380 field_indices,
381 variant_types,
382 value_type,
383 } => {
384 let variant_field_types = &variant_types[tag];
385 let variant_field_indices = &field_indices[tag];
386 let mut value = builder
387 .build_insert_value(
388 value_type.get_poison(),
389 self.tag_type().const_int(tag as u64, false),
390 0,
391 "",
392 )?
393 .into_struct_value();
394 for (t, mb_i, v) in izip!(variant_field_types, variant_field_indices, vs) {
395 ensure!(&v.get_type() == t);
396 if let Some(i) = mb_i {
397 value = builder
398 .build_insert_value(value, v, *i as u32 + 1, "")?
399 .into_struct_value();
400 }
401 }
402 value.as_basic_value_enum()
403 }
404 };
405 debug_assert_eq!(value.get_type(), self.value_type());
406 Ok(value)
407 }
408
409 pub fn tag_type(&self) -> IntType<'c> {
411 match self {
412 Self::Void { tag_type, .. } => *tag_type,
413 Self::Unit { tag_type, .. } => *tag_type,
414 Self::NoFields { value_type, .. } => *value_type,
415 Self::SingleVariantSingleField { tag_type, .. } => *tag_type,
416 Self::SingleVariantMultiField { tag_type, .. } => *tag_type,
417 Self::MultiVariant { value_type, .. } => value_type
418 .get_field_type_at_index(0)
419 .unwrap()
420 .into_int_type(),
421 }
422 }
423
424 pub fn value_type(&self) -> BasicTypeEnum<'c> {
426 match self {
427 Self::Void { tag_type, .. } => (*tag_type).into(),
428 Self::Unit { value_type, .. } => (*value_type).into(),
429 Self::NoFields { value_type, .. } => (*value_type).into(),
430 Self::SingleVariantSingleField {
431 field_index,
432 field_types: variant_types,
433 ..
434 } => variant_types[*field_index],
435 Self::SingleVariantMultiField { value_type, .. }
436 | Self::MultiVariant { value_type, .. } => (*value_type).into(),
437 }
438 }
439
440 pub fn num_variants(&self) -> usize {
442 match self {
443 Self::Void { .. } => 0,
444 Self::Unit { .. }
445 | Self::SingleVariantSingleField { .. }
446 | Self::SingleVariantMultiField { .. } => 1,
447 Self::NoFields { variant_types, .. } | Self::MultiVariant { variant_types, .. } => {
448 variant_types.len()
449 }
450 }
451 }
452
453 pub(self) fn num_fields_for_variant(&self, tag: usize) -> usize {
456 self.fields_for_variant(tag).len()
457 }
458
459 pub(self) fn fields_for_variant(&self, tag: usize) -> &[BasicTypeEnum<'c>] {
462 assert!(tag < self.num_variants());
463 match self {
464 Self::Void { .. } => unreachable!("Void has no valid tag"),
465 Self::SingleVariantSingleField { field_types, .. }
466 | Self::SingleVariantMultiField { field_types, .. }
467 | Self::Unit { field_types, .. } => &field_types[..],
468 Self::MultiVariant { variant_types, .. } | Self::NoFields { variant_types, .. } => {
469 &variant_types[tag]
470 }
471 }
472 }
473}
474
475impl<'c> From<LLVMSumTypeEnum<'c>> for BasicTypeEnum<'c> {
476 fn from(value: LLVMSumTypeEnum<'c>) -> Self {
477 value.value_type()
478 }
479}
480
/// Formats a sum representation by delegating to the LLVM value type
/// returned by `value_type()`.
impl std::fmt::Display for LLVMSumTypeEnum<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        self.value_type().fmt(f)
    }
}
486
487unsafe impl AsTypeRef for LLVMSumType<'_> {
488 fn as_type_ref(&self) -> inkwell::llvm_sys::prelude::LLVMTypeRef {
489 BasicTypeEnum::from(self.0.clone()).as_type_ref()
490 }
491}
492
493unsafe impl<'c> AnyType<'c> for LLVMSumType<'c> {}
494
495unsafe impl<'c> BasicType<'c> for LLVMSumType<'c> {}
496
/// An LLVM value of a sum type, paired with that [`LLVMSumType`].
///
/// Field 0 is the underlying LLVM value and field 1 its sum type. The public
/// constructor, `LLVMSumValue::try_new`, checks that the value's LLVM type is
/// the sum's value type and rejects `Void` sums.
#[derive(Debug)]
pub struct LLVMSumValue<'c>(BasicValueEnum<'c>, LLVMSumType<'c>);
501
502impl<'c> From<LLVMSumValue<'c>> for BasicValueEnum<'c> {
503 fn from(value: LLVMSumValue<'c>) -> Self {
504 value.0.as_basic_value_enum()
505 }
506}
507
508unsafe impl AsValueRef for LLVMSumValue<'_> {
509 fn as_value_ref(&self) -> inkwell::llvm_sys::prelude::LLVMValueRef {
510 self.0.as_value_ref()
511 }
512}
513
514unsafe impl<'c> AnyValue<'c> for LLVMSumValue<'c> {}
515
516unsafe impl<'c> BasicValue<'c> for LLVMSumValue<'c> {}
517
518impl<'c> LLVMSumValue<'c> {
519 pub fn try_new(value: impl BasicValue<'c>, sum_type: LLVMSumType<'c>) -> Result<Self> {
520 let value = value.as_basic_value_enum();
521 ensure!(
522 !matches!(sum_type.0, LLVMSumTypeEnum::Void { .. }),
523 "Cannot construct LLVMSumValue of a Void sum"
524 );
525 ensure!(
526 value.get_type() == sum_type.value_type(),
527 "Cannot construct LLVMSumValue of type {sum_type} from value of type {}",
528 value.get_type()
529 );
530 Ok(Self(value, sum_type))
531 }
532
533 #[must_use]
534 pub fn get_type(&self) -> LLVMSumType<'c> {
535 self.1.clone()
536 }
537
538 pub fn build_get_tag(&self, builder: &Builder<'c>) -> Result<IntValue<'c>> {
542 let result = match self.get_type().0 {
543 LLVMSumTypeEnum::Void { .. } => bail!("Cannot get tag of void sum"),
544 LLVMSumTypeEnum::Unit { tag_type, .. }
545 | LLVMSumTypeEnum::SingleVariantSingleField { tag_type, .. }
546 | LLVMSumTypeEnum::SingleVariantMultiField { tag_type, .. } => {
547 anyhow::Ok(tag_type.const_int(0, false))
548 }
549 LLVMSumTypeEnum::NoFields { .. } => Ok(self.0.into_int_value()),
550 LLVMSumTypeEnum::MultiVariant { .. } => {
551 let value: StructValue = self.0.into_struct_value();
552 Ok(builder.build_extract_value(value, 0, "")?.into_int_value())
553 }
554 }?;
555 debug_assert_eq!(result.get_type(), self.tag_type());
556 Ok(result)
557 }
558
559 pub fn build_untag(
564 &self,
565 builder: &Builder<'c>,
566 tag: usize,
567 ) -> Result<Vec<BasicValueEnum<'c>>> {
568 ensure!(tag < self.num_variants(), "Bad tag {tag} in {}", self.1);
569 let results =
570 match self.get_type().0 {
571 LLVMSumTypeEnum::Void { .. } => bail!("Cannot untag void sum"),
572 LLVMSumTypeEnum::Unit {
573 field_types: variant_types,
574 ..
575 } => anyhow::Ok(
576 variant_types
577 .into_iter()
578 .map(basic_type_undef)
579 .collect_vec(),
580 ),
581 LLVMSumTypeEnum::NoFields { variant_types, .. } => Ok(variant_types[tag]
582 .iter()
583 .copied()
584 .map(basic_type_undef)
585 .collect()),
586 LLVMSumTypeEnum::SingleVariantSingleField {
587 field_types: variant_types,
588 field_index,
589 ..
590 } => Ok(variant_types
591 .iter()
592 .enumerate()
593 .map(|(i, t)| {
594 if i == field_index {
595 self.0
596 } else {
597 basic_type_undef(*t)
598 }
599 })
600 .collect()),
601 LLVMSumTypeEnum::SingleVariantMultiField {
602 field_types: variant_types,
603 field_indices,
604 ..
605 } => itertools::zip_eq(variant_types, field_indices)
606 .map(|(t, mb_i)| {
607 if let Some(i) = mb_i {
608 Ok(builder.build_extract_value(
609 self.0.into_struct_value(),
610 i as u32,
611 "",
612 )?)
613 } else {
614 Ok(basic_type_undef(t))
615 }
616 })
617 .collect(),
618 LLVMSumTypeEnum::MultiVariant {
619 variant_types,
620 field_indices,
621 ..
622 } => {
623 let value = self.0.into_struct_value();
624 itertools::zip_eq(&variant_types[tag], &field_indices[tag])
625 .map(|(ty, mb_i)| {
626 if let Some(i) = mb_i {
627 Ok(builder.build_extract_value(value, *i as u32 + 1, "")?)
628 } else {
629 Ok(basic_type_undef(*ty))
630 }
631 })
632 .collect()
633 }
634 }?;
635 #[cfg(debug_assertions)]
636 {
637 let result_types = results
638 .iter()
639 .map(inkwell::values::BasicValueEnum::get_type)
640 .collect_vec();
641 assert_eq!(&result_types, self.get_type().fields_for_variant(tag));
642 }
643 Ok(results)
644 }
645
646 pub fn build_destructure(
647 &self,
648 builder: &Builder<'c>,
649 mut handler: impl FnMut(&Builder<'c>, usize, Vec<BasicValueEnum<'c>>) -> Result<()>,
650 ) -> Result<()> {
651 let orig_bb = builder
652 .get_insert_block()
653 .ok_or(anyhow!("No current insertion point"))?;
654 let context = orig_bb.get_context();
655 let mut last_bb = orig_bb;
656 let tag_ty = self.tag_type();
657
658 let mut cases = vec![];
659
660 for var_i in 0..self.1.num_variants() {
661 let bb = context.insert_basic_block_after(last_bb, "");
662 last_bb = bb;
663 cases.push((tag_ty.const_int(var_i as u64, false), bb));
664
665 builder.position_at_end(bb);
666 let inputs = self.build_untag(builder, var_i)?;
667 handler(builder, var_i, inputs)?;
668 }
669
670 builder.position_at_end(orig_bb);
671 let tag = self.build_get_tag(builder)?;
672 builder.build_switch(tag, cases[0].1, &cases[1..])?;
673
674 Ok(())
675 }
676
677 delegate! {
678 to self.1 {
679 #[must_use] pub fn tag_type(&self) -> IntType<'c>;
681 #[must_use] pub fn num_variants(&self) -> usize;
683 }
684 }
685}
686
#[cfg(test)]
mod test {
    use hugr_core::extension::prelude::{bool_t, usize_t};
    use insta::assert_snapshot;
    use rstest::{Context, rstest};

    use crate::{
        test::{TestContext, llvm_ctx},
        types::HugrType,
    };

    use super::*;

    /// The tag width is the smallest number of bits that can hold the tags
    /// `0..num_variants`, but never less than one bit.
    #[rstest]
    #[case(1, 1)]
    #[case(2, 1)]
    #[case(3, 2)]
    #[case(4, 2)]
    #[case(5, 3)]
    #[case(8, 3)]
    #[case(9, 4)]
    fn tag_width(#[case] num_variants: usize, #[case] expected: u32) {
        assert_eq!(tag_width_for_num_variants(num_variants), expected);
    }

    /// Checks the LLVM type chosen for a range of HUGR sum and tuple types,
    /// covering each representation.
    #[rstest]
    fn sum_types(mut llvm_ctx: TestContext) {
        // Presumably needed so that `usize_t()` below can be lowered.
        llvm_ctx.add_extensions(
            super::super::custom::CodegenExtsBuilder::add_default_prelude_extensions,
        );
        let ts = llvm_ctx.get_typing_session();
        let iwc = ts.iw_context();
        let empty_struct = iwc.struct_type(&[], false).as_basic_type_enum();
        let i1 = iwc.bool_type().as_basic_type_enum();
        let i2 = iwc.custom_width_int_type(2).as_basic_type_enum();
        let i64 = iwc.i64_type().as_basic_type_enum();

        {
            // No variants: `Void`, typed as its `i1` tag.
            let hugr_type = HugrType::new_unit_sum(0);
            assert_eq!(ts.llvm_type(&hugr_type).unwrap(), i1);
        }

        {
            // One variant with no fields: the empty struct.
            let hugr_type = HugrType::UNIT;
            assert_eq!(ts.llvm_type(&hugr_type).unwrap(), empty_struct.clone());
        }

        {
            // One variant whose fields are all elided: still the empty struct.
            let hugr_type = HugrType::new_tuple(vec![HugrType::UNIT, HugrType::UNIT]);
            assert_eq!(ts.llvm_type(&hugr_type).unwrap(), empty_struct.clone());
        }

        {
            // Two field-less variants: just a 1-bit tag.
            let hugr_type = bool_t();
            assert_eq!(ts.llvm_type(&hugr_type).unwrap(), i1);
        }

        {
            // Three variants whose only field is elided: just a 2-bit tag.
            let hugr_type = HugrType::new_sum(vec![vec![HugrType::UNIT]; 3]);
            assert_eq!(ts.llvm_type(&hugr_type).unwrap(), i2);
        }

        {
            // One variant with one field: that field's type, unwrapped.
            let hugr_type = HugrType::new_tuple(vec![usize_t()]);
            assert_eq!(ts.llvm_type(&hugr_type).unwrap(), i64);
        }

        {
            // One field remains once the unit is elided.
            let hugr_type = HugrType::new_tuple(vec![HugrType::UNIT, usize_t()]);
            assert_eq!(ts.llvm_type(&hugr_type).unwrap(), i64);
        }

        {
            // One variant, several kept fields: a struct of them (unit elided).
            let hugr_type = HugrType::new_tuple(vec![usize_t(), bool_t(), HugrType::UNIT]);
            let llvm_type = iwc.struct_type(&[i64, i1], false).into();
            assert_eq!(ts.llvm_type(&hugr_type).unwrap(), llvm_type);
        }

        {
            // Several variants with fields: a struct with the `i1` tag first.
            // Both sums are expected to lay out to the same struct.
            let hugr_type1 =
                HugrType::new_sum([vec![bool_t(), HugrType::UNIT, usize_t()], vec![usize_t()]]);
            let hugr_type2 = HugrType::new_sum([vec![usize_t(), bool_t()], vec![usize_t()]]);
            let llvm_type = iwc.struct_type(&[i1, i64, i1], false).into();
            assert_eq!(ts.llvm_type(&hugr_type1).unwrap(), llvm_type);
            assert_eq!(ts.llvm_type(&hugr_type2).unwrap(), llvm_type);
        }
    }

    /// For each representation, builds a function that takes a sum value,
    /// reads its tag, untags it as variant `tag`, and re-tags the resulting
    /// fields. Checks that the module verifies, then snapshots its IR.
    #[rstest]
    #[case::unit(HugrSumType::new_unary(1), 0)]
    #[case::unit_elided_fields(HugrSumType::new([HugrType::UNIT]), 0)]
    #[case::nofields(HugrSumType::new_unary(4), 2)]
    #[case::nofields_elided_fields(HugrSumType::new([vec![HugrType::UNIT], vec![]]), 0)]
    #[case::one_variant_one_field(HugrSumType::new([bool_t()]), 0)]
    #[case::one_variant_one_field_elided_fields(HugrSumType::new([vec![HugrType::UNIT,bool_t()]]), 0)]
    #[case::one_variant_two_fields(HugrSumType::new([vec![bool_t(),bool_t()]]), 0)]
    #[case::one_variant_two_fields_elided_fields(HugrSumType::new([vec![bool_t(),HugrType::UNIT,bool_t()]]), 0)]
    #[case::two_variant_one_field(HugrSumType::new([vec![bool_t()],vec![]]), 1)]
    #[case::two_variant_one_field_elided_fields(HugrSumType::new([vec![bool_t()],vec![HugrType::UNIT]]), 1)]
    fn build_untag_tag(
        #[context] rstest_ctx: Context,
        llvm_ctx: TestContext,
        #[case] sum: HugrSumType,
        #[case] tag: usize,
    ) {
        let module = {
            let ts = llvm_ctx.get_typing_session();
            let iwc = llvm_ctx.iw_context();
            let module = iwc.create_module("");
            let llvm_ty = ts.llvm_sum_type(sum.clone()).unwrap();
            // fn untag_tag(sum) -> sum
            let func_ty = llvm_ty.fn_type(&[llvm_ty.as_basic_type_enum().into()], false);
            let func = module.add_function("untag_tag", func_ty, None);
            let bb = iwc.append_basic_block(func, "");
            let builder = iwc.create_builder();
            builder.position_at_end(bb);
            let value = llvm_ty.value(func.get_nth_param(0).unwrap()).unwrap();
            // The tag itself is unused; this just exercises `build_get_tag`.
            let _tag = value.build_get_tag(&builder).unwrap();
            let fields = value.build_untag(&builder, tag).unwrap();
            let new_value = llvm_ty.build_tag(&builder, tag, fields).unwrap();
            let _ = builder.build_return(Some(&new_value));
            module.verify().unwrap();
            module
        };

        // Name each snapshot after the rstest case description.
        let mut insta_settings = insta::Settings::clone_current();
        insta_settings.set_snapshot_suffix(rstest_ctx.description.unwrap());
        insta_settings.bind(|| {
            assert_snapshot!(module.to_string());
        });
    }
}