hugr_llvm/
sum.rs

1mod layout;
2
3use std::{iter, slice};
4
5use crate::types::{HugrSumType, TypingSession};
6
7use anyhow::{anyhow, bail, ensure, Result};
8use delegate::delegate;
9use hugr_core::types::TypeRow;
10use inkwell::{
11    builder::Builder,
12    context::Context,
13    types::{AnyType, AsTypeRef, BasicType, BasicTypeEnum, IntType, StructType},
14    values::{AnyValue, AsValueRef, BasicValue, BasicValueEnum, IntValue, StructValue},
15};
16use itertools::{izip, Itertools as _};
17
18/// An elidable type is one that holds no information, for example `{}`, the
19/// empty struct.
20///
21/// Currently the following types are elidable:
22///   * Empty structs, which may be packed, unpacked, named, or unnamed
23///   * Empty arrays of any type.
24pub fn elidable_type<'c>(ty: impl BasicType<'c>) -> bool {
25    let ty = ty.as_basic_type_enum();
26    match ty {
27        BasicTypeEnum::ArrayType(array_type) => array_type.is_empty(),
28        BasicTypeEnum::StructType(struct_type) => struct_type.count_fields() == 0,
29        _ => false,
30    }
31}
32
33fn get_variant_typerow(sum_type: &HugrSumType, tag: u32) -> Result<TypeRow> {
34    sum_type
35        .get_variant(tag as usize)
36        .ok_or(anyhow!("Bad variant index {tag} in {sum_type}"))
37        .and_then(|tr| Ok(TypeRow::try_from(tr.clone())?))
38}
39
40/// Returns an `undef` value for any [BasicType].
41fn basic_type_undef<'c>(t: impl BasicType<'c>) -> BasicValueEnum<'c> {
42    let t = t.as_basic_type_enum();
43    match t {
44        BasicTypeEnum::ArrayType(t) => t.get_undef().as_basic_value_enum(),
45        BasicTypeEnum::FloatType(t) => t.get_undef().as_basic_value_enum(),
46        BasicTypeEnum::IntType(t) => t.get_undef().as_basic_value_enum(),
47        BasicTypeEnum::PointerType(t) => t.get_undef().as_basic_value_enum(),
48        BasicTypeEnum::StructType(t) => t.get_undef().as_basic_value_enum(),
49        BasicTypeEnum::VectorType(t) => t.get_undef().as_basic_value_enum(),
50    }
51}
52
53/// Returns an `poison` value for any [BasicType].
54fn basic_type_poison<'c>(t: impl BasicType<'c>) -> BasicValueEnum<'c> {
55    let t = t.as_basic_type_enum();
56    match t {
57        BasicTypeEnum::ArrayType(t) => t.get_poison().as_basic_value_enum(),
58        BasicTypeEnum::FloatType(t) => t.get_poison().as_basic_value_enum(),
59        BasicTypeEnum::IntType(t) => t.get_poison().as_basic_value_enum(),
60        BasicTypeEnum::PointerType(t) => t.get_poison().as_basic_value_enum(),
61        BasicTypeEnum::StructType(t) => t.get_poison().as_basic_value_enum(),
62        BasicTypeEnum::VectorType(t) => t.get_poison().as_basic_value_enum(),
63    }
64}
65
66#[derive(Debug, Clone, derive_more::Display)]
67/// The opaque representation of a [HugrSumType].
68///
69/// Provides an `impl`s of `BasicType`, allowing interoperation with other
70/// inkwell tools.
71///
72/// To obtain an [LLVMSumType] corresponding to a [HugrSumType] use
73/// [LLVMSumType::try_new] or [LLVMSumType::try_from_hugr_type].
74///
75/// Any such [LLVMSumType] has a fixed underlying LLVM type, which can be
76/// obtained by [BasicType::as_basic_type_enum] or [LLVMSumType::value_type].
77/// Note this type is unspecified, and we go to some effort to ensure that it is
78/// minimal and efficient. Users should not expect this type to remain the same
79/// across versions.
80///
81/// Unit types such as empty structs(`{}`) are elided from the LLVM type where
82/// possible. See [elidable_type] for the specification of which types are
83/// elided.
84///
85/// Each [LLVMSumType] has an associated [IntType] tag type, which can be
86/// obtained via [LLVMSumType::tag_type].
87///
88/// The value type [LLVMSumValue] represents values of this type. To obtain an
89/// [LLVMSumValue] use [LLVMSumType::build_tag] or [LLVMSumType::value].
90pub struct LLVMSumType<'c>(LLVMSumTypeEnum<'c>);
91
92impl<'c> LLVMSumType<'c> {
93    delegate! {
94        to self.0 {
95            /// The underlying LLVM type.
96            pub fn value_type(&self) -> BasicTypeEnum<'c>;
97            /// The type of the value that would be returned by [LLVMSumValue::build_get_tag].
98            pub fn tag_type(&self) -> IntType<'c>;
99            /// The number of variants in the represented [HugrSumType].
100            pub fn num_variants(&self) -> usize;
101            /// The number of fields in the `tag`th variant of the represented [HugrSumType].
102            /// Panics if `tag` is out of bounds.
103            pub fn num_fields_for_variant(&self, tag: usize) -> usize;
104            /// The LLVM types representing the fields in the `tag` variant of the represented [HugrSumType].
105            /// Panics if `tag` is out of bounds.
106            pub fn fields_for_variant(&self, tag: usize) -> &[BasicTypeEnum<'c>];
107        }
108    }
109
110    /// Constructs a new [LLVMSumType] from a [HugrSumType], using `session` to
111    /// determine the types of the fields.
112    ///
113    /// Returns an error if the type of any field cannot be converted by
114    /// `session`, or if `sum_type` has no variants.
115    pub fn try_from_hugr_type(
116        session: &TypingSession<'c, '_>,
117        sum_type: HugrSumType,
118    ) -> Result<Self> {
119        let variants = (0..sum_type.num_variants())
120            .map(|i| {
121                let tr = get_variant_typerow(&sum_type, i as u32)?;
122                tr.iter()
123                    .map(|t| session.llvm_type(t))
124                    .collect::<Result<Vec<_>>>()
125            })
126            .collect::<Result<Vec<_>>>()?;
127        Self::try_new(session.iw_context(), variants)
128    }
129
130    /// Constructs a new [LLVMSumType] from a `Vec` of variants.
131    /// Each variant is a `Vec` of LLVM types each corresponding to a field in the sum.
132    ///
133    /// Returns an error if `variant_types` is empty;
134    pub fn try_new(
135        context: &'c Context,
136        variant_types: impl Into<Vec<Vec<BasicTypeEnum<'c>>>>,
137    ) -> Result<Self> {
138        Ok(Self(LLVMSumTypeEnum::try_new(
139            context,
140            variant_types.into(),
141        )?))
142    }
143
144    /// Returns an constant `undef` value of the underlying LLVM type.
145    pub fn get_undef(&self) -> impl BasicValue<'c> {
146        basic_type_undef(self.0.value_type())
147    }
148
149    /// Returns an constant `poison` value of the underlying LLVM type.
150    pub fn get_poison(&self) -> impl BasicValue<'c> {
151        basic_type_poison(self.0.value_type())
152    }
153
154    /// Emits instructions to construct an [LLVMSumValue] of this type. The
155    /// value will represent the `tag`th variant.
156    pub fn build_tag(
157        &self,
158        builder: &Builder<'c>,
159        tag: usize,
160        vs: Vec<BasicValueEnum<'c>>,
161    ) -> Result<LLVMSumValue<'c>> {
162        self.value(self.0.build_tag(builder, tag, vs)?)
163    }
164
165    /// Returns an [LLVMSumValue] of this type.
166    ///
167    /// Returns an error if `value.get_type() != self.value_type()`.
168    pub fn value(&self, value: impl BasicValue<'c>) -> Result<LLVMSumValue<'c>> {
169        LLVMSumValue::try_new(value, self.clone())
170    }
171}
172
173/// The internal representation of a [HugrSumType].
174///
175/// This type is not public, so that it can be changed without breaking users.
176#[derive(Debug, Clone)]
177enum LLVMSumTypeEnum<'c> {
178    /// A Sum type with no variants. It's representation is unspecified.
179    ///
180    /// Values of this type can only be constructed by [get_poison].
181    Void { tag_type: IntType<'c> },
182    /// A Sum type with a single variant and all-elidable fields.
183    /// Represented by `{}`
184    /// Values of this type contain no information, so they never need to be
185    /// stored. One can always use `undef` to materialize a value of this type.
186    /// Represented by an empty struct.
187    Unit {
188        /// The LLVM types of the fields. One entry for each field in the Hugr
189        /// variant. Each field must be elidable.
190        field_types: Vec<BasicTypeEnum<'c>>,
191        /// The LLVM type of the tag. Always `i1` for now.
192        /// We store it here so because otherwise we would need a &[Context] to
193        /// construct it.
194        tag_type: IntType<'c>,
195        /// The underlying LLVM type. Always `{}` for now.
196        value_type: StructType<'c>,
197    },
198    /// A Sum type with more than one variant and all elidable fields.
199    /// Values of this type contain information only in their tag.
200    /// Represented by the value of their tag.
201    NoFields {
202        /// The LLVM types of the fields. One entry for each variant, with that
203        /// entry containing one entry per Hugr field in the variant. Each field
204        /// must be elidable.
205        variant_types: Vec<Vec<BasicTypeEnum<'c>>>,
206        /// The underlying LLVM type. For now it is the smallest integer type
207        /// large enough to index the variants.
208        value_type: IntType<'c>,
209    },
210    /// A Sum type with a single variant and exactly one non-elidable field.
211    /// Values of this type contain information only in the value of their
212    /// non-elidable field.
213    /// Represented by the value of their non-elidable field.
214    SingleVariantSingleField {
215        /// The LLVM types of the fields. One entry for each Hugr field in the single
216        /// variant.
217        field_types: Vec<BasicTypeEnum<'c>>,
218        /// The index into variant_types of the non-elidable field.
219        field_index: usize,
220        /// The LLVM type of the tag. Always `i1` for now.
221        /// We store it here so because otherwise we would need a &[Context] to
222        /// construct it.
223        tag_type: IntType<'c>,
224    },
225    /// A Sum type with a single variant and more than one non-elidable field.
226    /// Values of this type contain information in the values of their
227    /// non-elidable fields.
228    /// Represented by a struct containing each non-elidable field.
229    SingleVariantMultiField {
230        /// The LLVM types of the fields. One entry for each Hugr field in the
231        /// single variant.
232        field_types: Vec<BasicTypeEnum<'c>>,
233        /// For each field, an index into the fields of `value_type`
234        field_indices: Vec<Option<usize>>,
235        /// The LLVM type of the tag. Always `i1` for now.
236        /// We store it here so because otherwise we would need a &[Context] to
237        /// construct it.
238        tag_type: IntType<'c>,
239        /// The underlying LLVM type. Has one field for each non-elidable field
240        /// in the single variant.
241        value_type: StructType<'c>,
242    },
243    /// A Sum type with multiple variants and at least one non-elidable field.
244    /// Values of this type contain information in their tag and in the values
245    /// of their non-elidable fields.
246    /// Represented by a struct containing a tag and fields enough to store the
247    /// non-elidable fields of any one variant.
248    MultiVariant {
249        /// The LLVM types of the fields. One entry for each variant, with that
250        /// entry containing one entry per Hugr field in the variant.
251        variant_types: Vec<Vec<BasicTypeEnum<'c>>>,
252        /// For each field in each variant, an index into the fields of `value_type`.
253        field_indices: Vec<Vec<Option<usize>>>,
254        /// The underlying LLVM type. The first field is of `tag_type`. The
255        /// remaining fields are minimal such that any one variant can be
256        /// injectively mapped into those fields.
257        value_type: StructType<'c>,
258    },
259}
260
/// Returns the smallest bit width of an integer type that can hold every value
/// in `0..num_variants`, i.e. every valid tag. The result is never less than
/// one, so a single-variant sum still gets an `i1` tag.
fn tag_width_for_num_variants(num_variants: usize) -> u32 {
    debug_assert!(num_variants >= 1);
    // The largest tag is `num_variants - 1`; its bit length is the number of
    // bits remaining once its leading zeros are discounted.
    let largest_tag = num_variants - 1;
    let bits_needed = usize::BITS - largest_tag.leading_zeros();
    bits_needed.max(1)
}
269
270impl<'c> LLVMSumTypeEnum<'c> {
271    /// Constructs a new [LLVMSumTypeEnum] from a `Vec` of variants.
272    /// Each variant is a `Vec` of LLVM types each corresponding to a field in the sum.
273    pub fn try_new(
274        context: &'c Context,
275        variant_types: Vec<Vec<BasicTypeEnum<'c>>>,
276    ) -> Result<Self> {
277        let result = match variant_types.len() {
278            0 => Self::Void {
279                tag_type: context.bool_type(),
280            },
281            1 => {
282                let variant_types = variant_types.into_iter().exactly_one().unwrap();
283                let (fields, field_indices) =
284                    layout::layout_variants(slice::from_ref(&variant_types));
285                let field_indices = field_indices.into_iter().exactly_one().unwrap();
286                match fields.len() {
287                    0 => Self::Unit {
288                        field_types: variant_types,
289                        tag_type: context.bool_type(),
290                        value_type: context.struct_type(&[], false),
291                    },
292                    1 => {
293                        let field_index = field_indices
294                            .into_iter()
295                            .enumerate()
296                            .filter_map(|(i, f_i)| f_i.is_some().then_some(i))
297                            .exactly_one()
298                            .unwrap();
299                        Self::SingleVariantSingleField {
300                            field_types: variant_types,
301                            field_index,
302                            tag_type: context.bool_type(),
303                        }
304                    }
305                    _num_fields => Self::SingleVariantMultiField {
306                        field_types: variant_types,
307                        field_indices,
308                        tag_type: context.bool_type(),
309                        value_type: context.struct_type(&fields, false),
310                    },
311                }
312            }
313            num_variants => {
314                let (mut fields, field_indices) = layout::layout_variants(&variant_types);
315                let tag_type =
316                    context.custom_width_int_type(tag_width_for_num_variants(num_variants));
317                if fields.is_empty() {
318                    Self::NoFields {
319                        variant_types,
320                        value_type: tag_type,
321                    }
322                } else {
323                    // prefix the tag fields
324                    fields.insert(0, tag_type.into());
325                    let value_type = context.struct_type(&fields, false);
326                    Self::MultiVariant {
327                        variant_types,
328                        field_indices,
329                        value_type,
330                    }
331                }
332            }
333        };
334        Ok(result)
335    }
336
337    /// Emit instructions to build a value of type `LLVMSumType`, being of variant `tag`.
338    ///
339    /// Returns an error if:
340    ///   * `tag` is out of bounds
341    ///   * `vs` does not have a length equal to the length of the `tag`th
342    ///     variant of the represented Hugr type.
343    ///   * Any entry of `vs` does not have the expected type.
344    pub fn build_tag(
345        &self,
346        builder: &Builder<'c>,
347        tag: usize,
348        vs: Vec<BasicValueEnum<'c>>,
349    ) -> Result<BasicValueEnum<'c>> {
350        ensure!(tag < self.num_variants());
351        ensure!(vs.len() == self.num_fields_for_variant(tag));
352        ensure!(iter::zip(&vs, self.fields_for_variant(tag)).all(|(x, y)| &x.get_type() == y));
353        let value = match self {
354            Self::Void { .. } => bail!("Can't tag an empty sum"),
355            Self::Unit { value_type, .. } => value_type.get_undef().as_basic_value_enum(),
356            Self::NoFields { value_type, .. } => value_type
357                .const_int(tag as u64, false)
358                .as_basic_value_enum(),
359            Self::SingleVariantSingleField { field_index, .. } => vs[*field_index],
360            Self::SingleVariantMultiField {
361                value_type,
362                field_indices,
363                ..
364            } => {
365                let mut value = value_type.get_poison();
366                for (mb_i, v) in itertools::zip_eq(field_indices, vs) {
367                    if let Some(i) = mb_i {
368                        value = builder
369                            .build_insert_value(value, v, *i as u32, "")?
370                            .into_struct_value();
371                    }
372                }
373                value.as_basic_value_enum()
374            }
375            Self::MultiVariant {
376                field_indices,
377                variant_types,
378                value_type,
379            } => {
380                let variant_field_types = &variant_types[tag];
381                let variant_field_indices = &field_indices[tag];
382                let mut value = builder
383                    .build_insert_value(
384                        value_type.get_poison(),
385                        self.tag_type().const_int(tag as u64, false),
386                        0,
387                        "",
388                    )?
389                    .into_struct_value();
390                for (t, mb_i, v) in izip!(variant_field_types, variant_field_indices, vs) {
391                    ensure!(&v.get_type() == t);
392                    if let Some(i) = mb_i {
393                        value = builder
394                            .build_insert_value(value, v, *i as u32 + 1, "")?
395                            .into_struct_value();
396                    }
397                }
398                value.as_basic_value_enum()
399            }
400        };
401        debug_assert_eq!(value.get_type(), self.value_type());
402        Ok(value)
403    }
404
405    /// Get the type of the value that would be returned by `build_get_tag`.
406    pub fn tag_type(&self) -> IntType<'c> {
407        match self {
408            Self::Void { tag_type, .. } => *tag_type,
409            Self::Unit { tag_type, .. } => *tag_type,
410            Self::NoFields { value_type, .. } => *value_type,
411            Self::SingleVariantSingleField { tag_type, .. } => *tag_type,
412            Self::SingleVariantMultiField { tag_type, .. } => *tag_type,
413            Self::MultiVariant { value_type, .. } => value_type
414                .get_field_type_at_index(0)
415                .unwrap()
416                .into_int_type(),
417        }
418    }
419
420    /// The underlying LLVM type.
421    pub fn value_type(&self) -> BasicTypeEnum<'c> {
422        match self {
423            Self::Void { tag_type, .. } => (*tag_type).into(),
424            Self::Unit { value_type, .. } => (*value_type).into(),
425            Self::NoFields { value_type, .. } => (*value_type).into(),
426            Self::SingleVariantSingleField {
427                field_index,
428                field_types: variant_types,
429                ..
430            } => variant_types[*field_index],
431            Self::SingleVariantMultiField { value_type, .. }
432            | Self::MultiVariant { value_type, .. } => (*value_type).into(),
433        }
434    }
435
436    /// The number of variants in the represented [HugrSumType].
437    pub fn num_variants(&self) -> usize {
438        match self {
439            Self::Void { .. } => 0,
440            Self::Unit { .. }
441            | Self::SingleVariantSingleField { .. }
442            | Self::SingleVariantMultiField { .. } => 1,
443            Self::NoFields { variant_types, .. } | Self::MultiVariant { variant_types, .. } => {
444                variant_types.len()
445            }
446        }
447    }
448
449    /// The number of fields in the `tag`th variant of the represented [HugrSumType].
450    /// Panics if `tag` is out of bounds.
451    pub(self) fn num_fields_for_variant(&self, tag: usize) -> usize {
452        self.fields_for_variant(tag).len()
453    }
454
455    /// The LLVM types representing the fields in the `tag` variant of the
456    /// represented [HugrSumType].  Panics if `tag` is out of bounds.
457    pub(self) fn fields_for_variant(&self, tag: usize) -> &[BasicTypeEnum<'c>] {
458        assert!(tag < self.num_variants());
459        match self {
460            Self::Void { .. } => unreachable!("Void has no valid tag"),
461            Self::SingleVariantSingleField { field_types, .. }
462            | Self::SingleVariantMultiField { field_types, .. }
463            | Self::Unit { field_types, .. } => &field_types[..],
464            Self::MultiVariant { variant_types, .. } | Self::NoFields { variant_types, .. } => {
465                &variant_types[tag]
466            }
467        }
468    }
469}
470
471impl<'c> From<LLVMSumTypeEnum<'c>> for BasicTypeEnum<'c> {
472    fn from(value: LLVMSumTypeEnum<'c>) -> Self {
473        value.value_type()
474    }
475}
476
477impl std::fmt::Display for LLVMSumTypeEnum<'_> {
478    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
479        self.value_type().fmt(f)
480    }
481}
482
483unsafe impl AsTypeRef for LLVMSumType<'_> {
484    fn as_type_ref(&self) -> inkwell::llvm_sys::prelude::LLVMTypeRef {
485        BasicTypeEnum::from(self.0.clone()).as_type_ref()
486    }
487}
488
489unsafe impl<'c> AnyType<'c> for LLVMSumType<'c> {}
490
491unsafe impl<'c> BasicType<'c> for LLVMSumType<'c> {}
492
493/// A Value equivalent of [LLVMSumType]. Represents a [HugrSumType] Value on the
494/// wire, offering functions for inspecting and deconstructing such Values.
495#[derive(Debug)]
496pub struct LLVMSumValue<'c>(BasicValueEnum<'c>, LLVMSumType<'c>);
497
498impl<'c> From<LLVMSumValue<'c>> for BasicValueEnum<'c> {
499    fn from(value: LLVMSumValue<'c>) -> Self {
500        value.0.as_basic_value_enum()
501    }
502}
503
504unsafe impl AsValueRef for LLVMSumValue<'_> {
505    fn as_value_ref(&self) -> inkwell::llvm_sys::prelude::LLVMValueRef {
506        self.0.as_value_ref()
507    }
508}
509
510unsafe impl<'c> AnyValue<'c> for LLVMSumValue<'c> {}
511
512unsafe impl<'c> BasicValue<'c> for LLVMSumValue<'c> {}
513
514impl<'c> LLVMSumValue<'c> {
515    pub fn try_new(value: impl BasicValue<'c>, sum_type: LLVMSumType<'c>) -> Result<Self> {
516        let value = value.as_basic_value_enum();
517        ensure!(
518            !matches!(sum_type.0, LLVMSumTypeEnum::Void { .. }),
519            "Cannot construct LLVMSumValue of a Void sum"
520        );
521        ensure!(
522            value.get_type() == sum_type.value_type(),
523            "Cannot construct LLVMSumValue of type {sum_type} from value of type {}",
524            value.get_type()
525        );
526        Ok(Self(value, sum_type))
527    }
528
529    pub fn get_type(&self) -> LLVMSumType<'c> {
530        self.1.clone()
531    }
532
533    /// Emit instructions to read the tag of a value of type `LLVMSumType`.
534    ///
535    /// The type of the value is that returned by [LLVMSumType::tag_type].
536    pub fn build_get_tag(&self, builder: &Builder<'c>) -> Result<IntValue<'c>> {
537        let result = match self.get_type().0 {
538            LLVMSumTypeEnum::Void { .. } => bail!("Cannot get tag of void sum"),
539            LLVMSumTypeEnum::Unit { tag_type, .. }
540            | LLVMSumTypeEnum::SingleVariantSingleField { tag_type, .. }
541            | LLVMSumTypeEnum::SingleVariantMultiField { tag_type, .. } => {
542                anyhow::Ok(tag_type.const_int(0, false))
543            }
544            LLVMSumTypeEnum::NoFields { .. } => Ok(self.0.into_int_value()),
545            LLVMSumTypeEnum::MultiVariant { .. } => {
546                let value: StructValue = self.0.into_struct_value();
547                Ok(builder.build_extract_value(value, 0, "")?.into_int_value())
548            }
549        }?;
550        debug_assert_eq!(result.get_type(), self.tag_type());
551        Ok(result)
552    }
553
554    /// Emit instructions to read the inner values of a value of type
555    /// `LLVMSumType`, on the assumption that it's tag is `tag`.
556    ///
557    /// If it's tag is not `tag`, the returned values are unspecified.
558    pub fn build_untag(
559        &self,
560        builder: &Builder<'c>,
561        tag: usize,
562    ) -> Result<Vec<BasicValueEnum<'c>>> {
563        ensure!(tag < self.num_variants(), "Bad tag {tag} in {}", self.1);
564        let results =
565            match self.get_type().0 {
566                LLVMSumTypeEnum::Void { .. } => bail!("Cannot untag void sum"),
567                LLVMSumTypeEnum::Unit {
568                    field_types: variant_types,
569                    ..
570                } => anyhow::Ok(
571                    variant_types
572                        .into_iter()
573                        .map(basic_type_undef)
574                        .collect_vec(),
575                ),
576                LLVMSumTypeEnum::NoFields { variant_types, .. } => Ok(variant_types[tag]
577                    .iter()
578                    .copied()
579                    .map(basic_type_undef)
580                    .collect()),
581                LLVMSumTypeEnum::SingleVariantSingleField {
582                    field_types: variant_types,
583                    field_index,
584                    ..
585                } => Ok(variant_types
586                    .iter()
587                    .enumerate()
588                    .map(|(i, t)| {
589                        if i == field_index {
590                            self.0
591                        } else {
592                            basic_type_undef(*t)
593                        }
594                    })
595                    .collect()),
596                LLVMSumTypeEnum::SingleVariantMultiField {
597                    field_types: variant_types,
598                    field_indices,
599                    ..
600                } => itertools::zip_eq(variant_types, field_indices)
601                    .map(|(t, mb_i)| {
602                        if let Some(i) = mb_i {
603                            Ok(builder.build_extract_value(
604                                self.0.into_struct_value(),
605                                i as u32,
606                                "",
607                            )?)
608                        } else {
609                            Ok(basic_type_undef(t))
610                        }
611                    })
612                    .collect(),
613                LLVMSumTypeEnum::MultiVariant {
614                    variant_types,
615                    field_indices,
616                    ..
617                } => {
618                    let value = self.0.into_struct_value();
619                    itertools::zip_eq(&variant_types[tag], &field_indices[tag])
620                        .map(|(ty, mb_i)| {
621                            if let Some(i) = mb_i {
622                                Ok(builder.build_extract_value(value, *i as u32 + 1, "")?)
623                            } else {
624                                Ok(basic_type_undef(*ty))
625                            }
626                        })
627                        .collect()
628                }
629            }?;
630        #[cfg(debug_assertions)]
631        {
632            let result_types = results.iter().map(|x| x.get_type()).collect_vec();
633            assert_eq!(&result_types, self.get_type().fields_for_variant(tag));
634        }
635        Ok(results)
636    }
637
638    pub fn build_destructure(
639        &self,
640        builder: &Builder<'c>,
641        mut handler: impl FnMut(&Builder<'c>, usize, Vec<BasicValueEnum<'c>>) -> Result<()>,
642    ) -> Result<()> {
643        let orig_bb = builder
644            .get_insert_block()
645            .ok_or(anyhow!("No current insertion point"))?;
646        let context = orig_bb.get_context();
647        let mut last_bb = orig_bb;
648        let tag_ty = self.tag_type();
649
650        let mut cases = vec![];
651
652        for var_i in 0..self.1.num_variants() {
653            let bb = context.insert_basic_block_after(last_bb, "");
654            last_bb = bb;
655            cases.push((tag_ty.const_int(var_i as u64, false), bb));
656
657            builder.position_at_end(bb);
658            let inputs = self.build_untag(builder, var_i)?;
659            handler(builder, var_i, inputs)?;
660        }
661
662        builder.position_at_end(orig_bb);
663        let tag = self.build_get_tag(builder)?;
664        builder.build_switch(tag, cases[0].1, &cases[1..])?;
665
666        Ok(())
667    }
668
669    delegate! {
670        to self.1 {
671            /// Get the type of the value that would be returned by `build_get_tag`.
672            pub fn tag_type(&self) -> IntType<'c>;
673            /// The number of variants in the represented [HugrSumType].
674            pub fn num_variants(&self) -> usize;
675        }
676    }
677}
678
#[cfg(test)]
mod test {
    use hugr_core::extension::prelude::{bool_t, usize_t};
    use insta::assert_snapshot;
    use rstest::{rstest, Context};

    use crate::{
        test::{llvm_ctx, TestContext},
        types::HugrType,
    };

    use super::*;

    // Includes power-of-two boundaries, where the width must step up by one
    // bit just past each power of two.
    #[rstest]
    #[case(1, 1)]
    #[case(2, 1)]
    #[case(3, 2)]
    #[case(4, 2)]
    #[case(5, 3)]
    #[case(8, 3)]
    #[case(9, 4)]
    #[case(16, 4)]
    #[case(17, 5)]
    #[case(256, 8)]
    #[case(257, 9)]
    fn tag_width(#[case] num_variants: usize, #[case] expected: u32) {
        assert_eq!(tag_width_for_num_variants(num_variants), expected);
    }

    #[rstest]
    fn sum_types(mut llvm_ctx: TestContext) {
        llvm_ctx.add_extensions(|cem| cem.add_default_prelude_extensions());
        let ts = llvm_ctx.get_typing_session();
        let iwc = ts.iw_context();
        // `BasicTypeEnum` is `Copy`, so these can be reused freely below.
        let empty_struct = iwc.struct_type(&[], false).as_basic_type_enum();
        let i1 = iwc.bool_type().as_basic_type_enum();
        let i2 = iwc.custom_width_int_type(2).as_basic_type_enum();
        let i64 = iwc.i64_type().as_basic_type_enum();

        {
            // no-variants -> i1
            let hugr_type = HugrType::new_unit_sum(0);
            assert_eq!(ts.llvm_type(&hugr_type).unwrap(), i1);
        }

        {
            // one-variant-no-fields -> empty_struct
            let hugr_type = HugrType::UNIT;
            assert_eq!(ts.llvm_type(&hugr_type).unwrap(), empty_struct);
        }

        {
            // one-variant-elidable-fields -> empty_struct
            let hugr_type = HugrType::new_tuple(vec![HugrType::UNIT, HugrType::UNIT]);
            assert_eq!(ts.llvm_type(&hugr_type).unwrap(), empty_struct);
        }

        {
            // multi-variant-no-fields -> bare tag
            let hugr_type = bool_t();
            assert_eq!(ts.llvm_type(&hugr_type).unwrap(), i1);
        }

        {
            // multi-variant-elidable-fields -> bare tag
            let hugr_type = HugrType::new_sum(vec![vec![HugrType::UNIT]; 3]);
            assert_eq!(ts.llvm_type(&hugr_type).unwrap(), i2);
        }

        {
            // one-variant-one-field -> bare field
            let hugr_type = HugrType::new_tuple(vec![usize_t()]);
            assert_eq!(ts.llvm_type(&hugr_type).unwrap(), i64);
        }

        {
            // one-variant-one-non-elidable-field -> bare field
            let hugr_type = HugrType::new_tuple(vec![HugrType::UNIT, usize_t()]);
            assert_eq!(ts.llvm_type(&hugr_type).unwrap(), i64);
        }

        {
            // one-variant-multi-field -> struct-of-fields
            let hugr_type = HugrType::new_tuple(vec![usize_t(), bool_t(), HugrType::UNIT]);
            let llvm_type = iwc.struct_type(&[i64, i1], false).into();
            assert_eq!(ts.llvm_type(&hugr_type).unwrap(), llvm_type);
        }

        {
            // multi-variant-multi-field -> struct-of-fields-with-tag
            let hugr_type1 =
                HugrType::new_sum([vec![bool_t(), HugrType::UNIT, usize_t()], vec![usize_t()]]);
            let hugr_type2 = HugrType::new_sum([vec![usize_t(), bool_t()], vec![usize_t()]]);
            let llvm_type = iwc.struct_type(&[i1, i64, i1], false).into();
            assert_eq!(ts.llvm_type(&hugr_type1).unwrap(), llvm_type);
            assert_eq!(ts.llvm_type(&hugr_type2).unwrap(), llvm_type);
        }
    }

    #[rstest]
    #[case::unit(HugrSumType::new_unary(1), 0)]
    #[case::unit_elided_fields(HugrSumType::new([HugrType::UNIT]), 0)]
    #[case::nofields(HugrSumType::new_unary(4), 2)]
    #[case::nofields_elided_fields(HugrSumType::new([vec![HugrType::UNIT], vec![]]), 0)]
    #[case::one_variant_one_field(HugrSumType::new([bool_t()]), 0)]
    #[case::one_variant_one_field_elided_fields(HugrSumType::new([vec![HugrType::UNIT,bool_t()]]), 0)]
    #[case::one_variant_two_fields(HugrSumType::new([vec![bool_t(),bool_t()]]), 0)]
    #[case::one_variant_two_fields_elided_fields(HugrSumType::new([vec![bool_t(),HugrType::UNIT,bool_t()]]), 0)]
    #[case::two_variant_one_field(HugrSumType::new([vec![bool_t()],vec![]]), 1)]
    #[case::two_variant_one_field_elided_fields(HugrSumType::new([vec![bool_t()],vec![HugrType::UNIT]]), 1)]
    fn build_untag_tag(
        #[context] rstest_ctx: Context,
        llvm_ctx: TestContext,
        #[case] sum: HugrSumType,
        #[case] tag: usize,
    ) {
        // Round-trips a parameter through `build_untag` and `build_tag`, and
        // snapshots the resulting module.
        let module = {
            let ts = llvm_ctx.get_typing_session();
            let iwc = llvm_ctx.iw_context();
            let module = iwc.create_module("");
            let llvm_ty = ts.llvm_sum_type(sum.clone()).unwrap();
            let func_ty = llvm_ty.fn_type(&[llvm_ty.as_basic_type_enum().into()], false);
            let func = module.add_function("untag_tag", func_ty, None);
            let bb = iwc.append_basic_block(func, "");
            let builder = iwc.create_builder();
            builder.position_at_end(bb);
            let value = llvm_ty.value(func.get_nth_param(0).unwrap()).unwrap();
            let _tag = value.build_get_tag(&builder).unwrap();
            let fields = value.build_untag(&builder, tag).unwrap();
            let new_value = llvm_ty.build_tag(&builder, tag, fields).unwrap();
            let _ = builder.build_return(Some(&new_value));
            module.verify().unwrap();
            module
        };

        let mut insta_settings = insta::Settings::clone_current();
        insta_settings.set_snapshot_suffix(rstest_ctx.description.unwrap());
        insta_settings.bind(|| {
            assert_snapshot!(module.to_string());
        });
    }
}