hugr_llvm/
sum.rs

1mod layout;
2
3use std::{iter, slice};
4
5use crate::types::{HugrSumType, TypingSession};
6
7use anyhow::{Result, anyhow, bail, ensure};
8use delegate::delegate;
9use hugr_core::types::TypeRow;
10use inkwell::{
11    builder::Builder,
12    context::Context,
13    types::{AnyType, AsTypeRef, BasicType, BasicTypeEnum, IntType, StructType},
14    values::{AnyValue, AsValueRef, BasicValue, BasicValueEnum, IntValue, StructValue},
15};
16use itertools::{Itertools as _, izip};
17
18/// An elidable type is one that holds no information, for example `{}`, the
19/// empty struct.
20///
21/// Currently the following types are elidable:
22///   * Empty structs, which may be packed, unpacked, named, or unnamed
23///   * Empty arrays of any type.
24pub fn elidable_type<'c>(ty: impl BasicType<'c>) -> bool {
25    let ty = ty.as_basic_type_enum();
26    match ty {
27        BasicTypeEnum::ArrayType(array_type) => array_type.is_empty(),
28        BasicTypeEnum::StructType(struct_type) => struct_type.count_fields() == 0,
29        _ => false,
30    }
31}
32
33fn get_variant_typerow(sum_type: &HugrSumType, tag: u32) -> Result<TypeRow> {
34    sum_type
35        .get_variant(tag as usize)
36        .ok_or(anyhow!("Bad variant index {tag} in {sum_type}"))
37        .and_then(|tr| Ok(TypeRow::try_from(tr.clone())?))
38}
39
40/// Returns an `undef` value for any [`BasicType`].
41fn basic_type_undef<'c>(t: impl BasicType<'c>) -> BasicValueEnum<'c> {
42    let t = t.as_basic_type_enum();
43    match t {
44        BasicTypeEnum::ArrayType(t) => t.get_undef().as_basic_value_enum(),
45        BasicTypeEnum::FloatType(t) => t.get_undef().as_basic_value_enum(),
46        BasicTypeEnum::IntType(t) => t.get_undef().as_basic_value_enum(),
47        BasicTypeEnum::PointerType(t) => t.get_undef().as_basic_value_enum(),
48        BasicTypeEnum::StructType(t) => t.get_undef().as_basic_value_enum(),
49        BasicTypeEnum::VectorType(t) => t.get_undef().as_basic_value_enum(),
50        BasicTypeEnum::ScalableVectorType(t) => t.get_undef().as_basic_value_enum(),
51    }
52}
53
54/// Returns an `poison` value for any [`BasicType`].
55fn basic_type_poison<'c>(t: impl BasicType<'c>) -> BasicValueEnum<'c> {
56    let t = t.as_basic_type_enum();
57    match t {
58        BasicTypeEnum::ArrayType(t) => t.get_poison().as_basic_value_enum(),
59        BasicTypeEnum::FloatType(t) => t.get_poison().as_basic_value_enum(),
60        BasicTypeEnum::IntType(t) => t.get_poison().as_basic_value_enum(),
61        BasicTypeEnum::PointerType(t) => t.get_poison().as_basic_value_enum(),
62        BasicTypeEnum::StructType(t) => t.get_poison().as_basic_value_enum(),
63        BasicTypeEnum::VectorType(t) => t.get_poison().as_basic_value_enum(),
64        BasicTypeEnum::ScalableVectorType(t) => t.get_poison().as_basic_value_enum(),
65    }
66}
67
68#[derive(Debug, Clone, derive_more::Display)]
69/// The opaque representation of a [`HugrSumType`].
70///
71/// Provides an `impl`s of `BasicType`, allowing interoperation with other
72/// inkwell tools.
73///
74/// To obtain an [`LLVMSumType`] corresponding to a [`HugrSumType`] use
75/// [`LLVMSumType::try_new`] or [`LLVMSumType::try_from_hugr_type`].
76///
77/// Any such [`LLVMSumType`] has a fixed underlying LLVM type, which can be
78/// obtained by [`BasicType::as_basic_type_enum`] or [`LLVMSumType::value_type`].
79/// Note this type is unspecified, and we go to some effort to ensure that it is
80/// minimal and efficient. Users should not expect this type to remain the same
81/// across versions.
82///
83/// Unit types such as empty structs(`{}`) are elided from the LLVM type where
84/// possible. See [`elidable_type`] for the specification of which types are
85/// elided.
86///
87/// Each [`LLVMSumType`] has an associated [`IntType`] tag type, which can be
88/// obtained via [`LLVMSumType::tag_type`].
89///
90/// The value type [`LLVMSumValue`] represents values of this type. To obtain an
91/// [`LLVMSumValue`] use [`LLVMSumType::build_tag`] or [`LLVMSumType::value`].
92pub struct LLVMSumType<'c>(LLVMSumTypeEnum<'c>);
93
94impl<'c> LLVMSumType<'c> {
95    delegate! {
96        to self.0 {
97            /// The underlying LLVM type.
98            #[must_use] pub fn value_type(&self) -> BasicTypeEnum<'c>;
99            /// The type of the value that would be returned by [LLVMSumValue::build_get_tag].
100            #[must_use] pub fn tag_type(&self) -> IntType<'c>;
101            /// The number of variants in the represented [HugrSumType].
102            #[must_use] pub fn num_variants(&self) -> usize;
103            /// The number of fields in the `tag`th variant of the represented [HugrSumType].
104            /// Panics if `tag` is out of bounds.
105            #[must_use] pub fn num_fields_for_variant(&self, tag: usize) -> usize;
106            /// The LLVM types representing the fields in the `tag` variant of the represented [HugrSumType].
107            /// Panics if `tag` is out of bounds.
108            #[must_use] pub fn fields_for_variant(&self, tag: usize) -> &[BasicTypeEnum<'c>];
109        }
110    }
111
112    /// Constructs a new [`LLVMSumType`] from a [`HugrSumType`], using `session` to
113    /// determine the types of the fields.
114    ///
115    /// Returns an error if the type of any field cannot be converted by
116    /// `session`, or if `sum_type` has no variants.
117    pub fn try_from_hugr_type(
118        session: &TypingSession<'c, '_>,
119        sum_type: HugrSumType,
120    ) -> Result<Self> {
121        let variants = (0..sum_type.num_variants())
122            .map(|i| {
123                let tr = get_variant_typerow(&sum_type, i as u32)?;
124                tr.iter()
125                    .map(|t| session.llvm_type(t))
126                    .collect::<Result<Vec<_>>>()
127            })
128            .collect::<Result<Vec<_>>>()?;
129        Self::try_new(session.iw_context(), variants)
130    }
131
132    /// Constructs a new [`LLVMSumType`] from a `Vec` of variants.
133    /// Each variant is a `Vec` of LLVM types each corresponding to a field in the sum.
134    ///
135    /// Returns an error if `variant_types` is empty;
136    pub fn try_new(
137        context: &'c Context,
138        variant_types: impl Into<Vec<Vec<BasicTypeEnum<'c>>>>,
139    ) -> Result<Self> {
140        Ok(Self(LLVMSumTypeEnum::try_new(
141            context,
142            variant_types.into(),
143        )?))
144    }
145
146    /// Returns an constant `undef` value of the underlying LLVM type.
147    #[must_use]
148    pub fn get_undef(&self) -> impl BasicValue<'c> + use<'c> {
149        basic_type_undef(self.0.value_type())
150    }
151
152    /// Returns an constant `poison` value of the underlying LLVM type.
153    #[must_use]
154    pub fn get_poison(&self) -> impl BasicValue<'c> + use<'c> {
155        basic_type_poison(self.0.value_type())
156    }
157
158    /// Emits instructions to construct an [`LLVMSumValue`] of this type. The
159    /// value will represent the `tag`th variant.
160    pub fn build_tag(
161        &self,
162        builder: &Builder<'c>,
163        tag: usize,
164        vs: Vec<BasicValueEnum<'c>>,
165    ) -> Result<LLVMSumValue<'c>> {
166        self.value(self.0.build_tag(builder, tag, vs)?)
167    }
168
169    /// Returns an [`LLVMSumValue`] of this type.
170    ///
171    /// Returns an error if `value.get_type() != self.value_type()`.
172    pub fn value(&self, value: impl BasicValue<'c>) -> Result<LLVMSumValue<'c>> {
173        LLVMSumValue::try_new(value, self.clone())
174    }
175}
176
177/// The internal representation of a [`HugrSumType`].
178///
179/// This type is not public, so that it can be changed without breaking users.
180#[derive(Debug, Clone)]
181enum LLVMSumTypeEnum<'c> {
182    /// A Sum type with no variants. It's representation is unspecified.
183    ///
184    /// Values of this type can only be constructed by [`get_poison`].
185    Void { tag_type: IntType<'c> },
186    /// A Sum type with a single variant and all-elidable fields.
187    /// Represented by `{}`
188    /// Values of this type contain no information, so they never need to be
189    /// stored. One can always use `undef` to materialize a value of this type.
190    /// Represented by an empty struct.
191    Unit {
192        /// The LLVM types of the fields. One entry for each field in the Hugr
193        /// variant. Each field must be elidable.
194        field_types: Vec<BasicTypeEnum<'c>>,
195        /// The LLVM type of the tag. Always `i1` for now.
196        /// We store it here so because otherwise we would need a &[Context] to
197        /// construct it.
198        tag_type: IntType<'c>,
199        /// The underlying LLVM type. Always `{}` for now.
200        value_type: StructType<'c>,
201    },
202    /// A Sum type with more than one variant and all elidable fields.
203    /// Values of this type contain information only in their tag.
204    /// Represented by the value of their tag.
205    NoFields {
206        /// The LLVM types of the fields. One entry for each variant, with that
207        /// entry containing one entry per Hugr field in the variant. Each field
208        /// must be elidable.
209        variant_types: Vec<Vec<BasicTypeEnum<'c>>>,
210        /// The underlying LLVM type. For now it is the smallest integer type
211        /// large enough to index the variants.
212        value_type: IntType<'c>,
213    },
214    /// A Sum type with a single variant and exactly one non-elidable field.
215    /// Values of this type contain information only in the value of their
216    /// non-elidable field.
217    /// Represented by the value of their non-elidable field.
218    SingleVariantSingleField {
219        /// The LLVM types of the fields. One entry for each Hugr field in the single
220        /// variant.
221        field_types: Vec<BasicTypeEnum<'c>>,
222        /// The index into `variant_types` of the non-elidable field.
223        field_index: usize,
224        /// The LLVM type of the tag. Always `i1` for now.
225        /// We store it here so because otherwise we would need a &[Context] to
226        /// construct it.
227        tag_type: IntType<'c>,
228    },
229    /// A Sum type with a single variant and more than one non-elidable field.
230    /// Values of this type contain information in the values of their
231    /// non-elidable fields.
232    /// Represented by a struct containing each non-elidable field.
233    SingleVariantMultiField {
234        /// The LLVM types of the fields. One entry for each Hugr field in the
235        /// single variant.
236        field_types: Vec<BasicTypeEnum<'c>>,
237        /// For each field, an index into the fields of `value_type`
238        field_indices: Vec<Option<usize>>,
239        /// The LLVM type of the tag. Always `i1` for now.
240        /// We store it here so because otherwise we would need a &[Context] to
241        /// construct it.
242        tag_type: IntType<'c>,
243        /// The underlying LLVM type. Has one field for each non-elidable field
244        /// in the single variant.
245        value_type: StructType<'c>,
246    },
247    /// A Sum type with multiple variants and at least one non-elidable field.
248    /// Values of this type contain information in their tag and in the values
249    /// of their non-elidable fields.
250    /// Represented by a struct containing a tag and fields enough to store the
251    /// non-elidable fields of any one variant.
252    MultiVariant {
253        /// The LLVM types of the fields. One entry for each variant, with that
254        /// entry containing one entry per Hugr field in the variant.
255        variant_types: Vec<Vec<BasicTypeEnum<'c>>>,
256        /// For each field in each variant, an index into the fields of `value_type`.
257        field_indices: Vec<Vec<Option<usize>>>,
258        /// The underlying LLVM type. The first field is of `tag_type`. The
259        /// remaining fields are minimal such that any one variant can be
260        /// injectively mapped into those fields.
261        value_type: StructType<'c>,
262    },
263}
264
/// Returns the smallest bit width of an integer type able to represent every
/// tag value in `0..num_variants`.
///
/// The result is always at least 1, because LLVM integer types cannot have
/// zero width. A single-variant sum therefore gets a 1-bit tag, and so does a
/// zero-variant sum in release builds; debug builds assert `num_variants >= 1`.
fn tag_width_for_num_variants(num_variants: usize) -> u32 {
    debug_assert!(num_variants >= 1);
    // `checked_sub` avoids the release-mode wrap-around that `num_variants - 1`
    // would hit for 0 (which previously produced a 64-bit tag).
    match num_variants.checked_sub(1) {
        // At most the tag value 0 needs representing.
        None | Some(0) => 1,
        // The largest tag is `max_tag`; it needs one bit more than its `ilog2`.
        Some(max_tag) => max_tag.ilog2() + 1,
    }
}
273
274impl<'c> LLVMSumTypeEnum<'c> {
275    /// Constructs a new [`LLVMSumTypeEnum`] from a `Vec` of variants.
276    /// Each variant is a `Vec` of LLVM types each corresponding to a field in the sum.
277    pub fn try_new(
278        context: &'c Context,
279        variant_types: Vec<Vec<BasicTypeEnum<'c>>>,
280    ) -> Result<Self> {
281        let result = match variant_types.len() {
282            0 => Self::Void {
283                tag_type: context.bool_type(),
284            },
285            1 => {
286                let variant_types = variant_types.into_iter().exactly_one().unwrap();
287                let (fields, field_indices) =
288                    layout::layout_variants(slice::from_ref(&variant_types));
289                let field_indices = field_indices.into_iter().exactly_one().unwrap();
290                match fields.len() {
291                    0 => Self::Unit {
292                        field_types: variant_types,
293                        tag_type: context.bool_type(),
294                        value_type: context.struct_type(&[], false),
295                    },
296                    1 => {
297                        let field_index = field_indices
298                            .into_iter()
299                            .enumerate()
300                            .filter_map(|(i, f_i)| f_i.is_some().then_some(i))
301                            .exactly_one()
302                            .unwrap();
303                        Self::SingleVariantSingleField {
304                            field_types: variant_types,
305                            field_index,
306                            tag_type: context.bool_type(),
307                        }
308                    }
309                    _num_fields => Self::SingleVariantMultiField {
310                        field_types: variant_types,
311                        field_indices,
312                        tag_type: context.bool_type(),
313                        value_type: context.struct_type(&fields, false),
314                    },
315                }
316            }
317            num_variants => {
318                let (mut fields, field_indices) = layout::layout_variants(&variant_types);
319                let tag_type =
320                    context.custom_width_int_type(tag_width_for_num_variants(num_variants));
321                if fields.is_empty() {
322                    Self::NoFields {
323                        variant_types,
324                        value_type: tag_type,
325                    }
326                } else {
327                    // prefix the tag fields
328                    fields.insert(0, tag_type.into());
329                    let value_type = context.struct_type(&fields, false);
330                    Self::MultiVariant {
331                        variant_types,
332                        field_indices,
333                        value_type,
334                    }
335                }
336            }
337        };
338        Ok(result)
339    }
340
341    /// Emit instructions to build a value of type `LLVMSumType`, being of variant `tag`.
342    ///
343    /// Returns an error if:
344    ///   * `tag` is out of bounds
345    ///   * `vs` does not have a length equal to the length of the `tag`th
346    ///     variant of the represented Hugr type.
347    ///   * Any entry of `vs` does not have the expected type.
348    pub fn build_tag(
349        &self,
350        builder: &Builder<'c>,
351        tag: usize,
352        vs: Vec<BasicValueEnum<'c>>,
353    ) -> Result<BasicValueEnum<'c>> {
354        ensure!(tag < self.num_variants());
355        ensure!(vs.len() == self.num_fields_for_variant(tag));
356        ensure!(iter::zip(&vs, self.fields_for_variant(tag)).all(|(x, y)| &x.get_type() == y));
357        let value = match self {
358            Self::Void { .. } => bail!("Can't tag an empty sum"),
359            Self::Unit { value_type, .. } => value_type.get_undef().as_basic_value_enum(),
360            Self::NoFields { value_type, .. } => value_type
361                .const_int(tag as u64, false)
362                .as_basic_value_enum(),
363            Self::SingleVariantSingleField { field_index, .. } => vs[*field_index],
364            Self::SingleVariantMultiField {
365                value_type,
366                field_indices,
367                ..
368            } => {
369                let mut value = value_type.get_poison();
370                for (mb_i, v) in itertools::zip_eq(field_indices, vs) {
371                    if let Some(i) = mb_i {
372                        value = builder
373                            .build_insert_value(value, v, *i as u32, "")?
374                            .into_struct_value();
375                    }
376                }
377                value.as_basic_value_enum()
378            }
379            Self::MultiVariant {
380                field_indices,
381                variant_types,
382                value_type,
383            } => {
384                let variant_field_types = &variant_types[tag];
385                let variant_field_indices = &field_indices[tag];
386                let mut value = builder
387                    .build_insert_value(
388                        value_type.get_poison(),
389                        self.tag_type().const_int(tag as u64, false),
390                        0,
391                        "",
392                    )?
393                    .into_struct_value();
394                for (t, mb_i, v) in izip!(variant_field_types, variant_field_indices, vs) {
395                    ensure!(&v.get_type() == t);
396                    if let Some(i) = mb_i {
397                        value = builder
398                            .build_insert_value(value, v, *i as u32 + 1, "")?
399                            .into_struct_value();
400                    }
401                }
402                value.as_basic_value_enum()
403            }
404        };
405        debug_assert_eq!(value.get_type(), self.value_type());
406        Ok(value)
407    }
408
409    /// Get the type of the value that would be returned by `build_get_tag`.
410    pub fn tag_type(&self) -> IntType<'c> {
411        match self {
412            Self::Void { tag_type, .. } => *tag_type,
413            Self::Unit { tag_type, .. } => *tag_type,
414            Self::NoFields { value_type, .. } => *value_type,
415            Self::SingleVariantSingleField { tag_type, .. } => *tag_type,
416            Self::SingleVariantMultiField { tag_type, .. } => *tag_type,
417            Self::MultiVariant { value_type, .. } => value_type
418                .get_field_type_at_index(0)
419                .unwrap()
420                .into_int_type(),
421        }
422    }
423
424    /// The underlying LLVM type.
425    pub fn value_type(&self) -> BasicTypeEnum<'c> {
426        match self {
427            Self::Void { tag_type, .. } => (*tag_type).into(),
428            Self::Unit { value_type, .. } => (*value_type).into(),
429            Self::NoFields { value_type, .. } => (*value_type).into(),
430            Self::SingleVariantSingleField {
431                field_index,
432                field_types: variant_types,
433                ..
434            } => variant_types[*field_index],
435            Self::SingleVariantMultiField { value_type, .. }
436            | Self::MultiVariant { value_type, .. } => (*value_type).into(),
437        }
438    }
439
440    /// The number of variants in the represented [`HugrSumType`].
441    pub fn num_variants(&self) -> usize {
442        match self {
443            Self::Void { .. } => 0,
444            Self::Unit { .. }
445            | Self::SingleVariantSingleField { .. }
446            | Self::SingleVariantMultiField { .. } => 1,
447            Self::NoFields { variant_types, .. } | Self::MultiVariant { variant_types, .. } => {
448                variant_types.len()
449            }
450        }
451    }
452
453    /// The number of fields in the `tag`th variant of the represented [`HugrSumType`].
454    /// Panics if `tag` is out of bounds.
455    pub(self) fn num_fields_for_variant(&self, tag: usize) -> usize {
456        self.fields_for_variant(tag).len()
457    }
458
459    /// The LLVM types representing the fields in the `tag` variant of the
460    /// represented [`HugrSumType`].  Panics if `tag` is out of bounds.
461    pub(self) fn fields_for_variant(&self, tag: usize) -> &[BasicTypeEnum<'c>] {
462        assert!(tag < self.num_variants());
463        match self {
464            Self::Void { .. } => unreachable!("Void has no valid tag"),
465            Self::SingleVariantSingleField { field_types, .. }
466            | Self::SingleVariantMultiField { field_types, .. }
467            | Self::Unit { field_types, .. } => &field_types[..],
468            Self::MultiVariant { variant_types, .. } | Self::NoFields { variant_types, .. } => {
469                &variant_types[tag]
470            }
471        }
472    }
473}
474
475impl<'c> From<LLVMSumTypeEnum<'c>> for BasicTypeEnum<'c> {
476    fn from(value: LLVMSumTypeEnum<'c>) -> Self {
477        value.value_type()
478    }
479}
480
481impl std::fmt::Display for LLVMSumTypeEnum<'_> {
482    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
483        self.value_type().fmt(f)
484    }
485}
486
487unsafe impl AsTypeRef for LLVMSumType<'_> {
488    fn as_type_ref(&self) -> inkwell::llvm_sys::prelude::LLVMTypeRef {
489        BasicTypeEnum::from(self.0.clone()).as_type_ref()
490    }
491}
492
493unsafe impl<'c> AnyType<'c> for LLVMSumType<'c> {}
494
495unsafe impl<'c> BasicType<'c> for LLVMSumType<'c> {}
496
497/// A Value equivalent of [`LLVMSumType`]. Represents a [`HugrSumType`] Value on the
498/// wire, offering functions for inspecting and deconstructing such Values.
499#[derive(Debug)]
500pub struct LLVMSumValue<'c>(BasicValueEnum<'c>, LLVMSumType<'c>);
501
502impl<'c> From<LLVMSumValue<'c>> for BasicValueEnum<'c> {
503    fn from(value: LLVMSumValue<'c>) -> Self {
504        value.0.as_basic_value_enum()
505    }
506}
507
508unsafe impl AsValueRef for LLVMSumValue<'_> {
509    fn as_value_ref(&self) -> inkwell::llvm_sys::prelude::LLVMValueRef {
510        self.0.as_value_ref()
511    }
512}
513
514unsafe impl<'c> AnyValue<'c> for LLVMSumValue<'c> {}
515
516unsafe impl<'c> BasicValue<'c> for LLVMSumValue<'c> {}
517
518impl<'c> LLVMSumValue<'c> {
519    pub fn try_new(value: impl BasicValue<'c>, sum_type: LLVMSumType<'c>) -> Result<Self> {
520        let value = value.as_basic_value_enum();
521        ensure!(
522            !matches!(sum_type.0, LLVMSumTypeEnum::Void { .. }),
523            "Cannot construct LLVMSumValue of a Void sum"
524        );
525        ensure!(
526            value.get_type() == sum_type.value_type(),
527            "Cannot construct LLVMSumValue of type {sum_type} from value of type {}",
528            value.get_type()
529        );
530        Ok(Self(value, sum_type))
531    }
532
533    #[must_use]
534    pub fn get_type(&self) -> LLVMSumType<'c> {
535        self.1.clone()
536    }
537
538    /// Emit instructions to read the tag of a value of type `LLVMSumType`.
539    ///
540    /// The type of the value is that returned by [`LLVMSumType::tag_type`].
541    pub fn build_get_tag(&self, builder: &Builder<'c>) -> Result<IntValue<'c>> {
542        let result = match self.get_type().0 {
543            LLVMSumTypeEnum::Void { .. } => bail!("Cannot get tag of void sum"),
544            LLVMSumTypeEnum::Unit { tag_type, .. }
545            | LLVMSumTypeEnum::SingleVariantSingleField { tag_type, .. }
546            | LLVMSumTypeEnum::SingleVariantMultiField { tag_type, .. } => {
547                anyhow::Ok(tag_type.const_int(0, false))
548            }
549            LLVMSumTypeEnum::NoFields { .. } => Ok(self.0.into_int_value()),
550            LLVMSumTypeEnum::MultiVariant { .. } => {
551                let value: StructValue = self.0.into_struct_value();
552                Ok(builder.build_extract_value(value, 0, "")?.into_int_value())
553            }
554        }?;
555        debug_assert_eq!(result.get_type(), self.tag_type());
556        Ok(result)
557    }
558
559    /// Emit instructions to read the inner values of a value of type
560    /// `LLVMSumType`, on the assumption that it's tag is `tag`.
561    ///
562    /// If it's tag is not `tag`, the returned values are unspecified.
563    pub fn build_untag(
564        &self,
565        builder: &Builder<'c>,
566        tag: usize,
567    ) -> Result<Vec<BasicValueEnum<'c>>> {
568        ensure!(tag < self.num_variants(), "Bad tag {tag} in {}", self.1);
569        let results =
570            match self.get_type().0 {
571                LLVMSumTypeEnum::Void { .. } => bail!("Cannot untag void sum"),
572                LLVMSumTypeEnum::Unit {
573                    field_types: variant_types,
574                    ..
575                } => anyhow::Ok(
576                    variant_types
577                        .into_iter()
578                        .map(basic_type_undef)
579                        .collect_vec(),
580                ),
581                LLVMSumTypeEnum::NoFields { variant_types, .. } => Ok(variant_types[tag]
582                    .iter()
583                    .copied()
584                    .map(basic_type_undef)
585                    .collect()),
586                LLVMSumTypeEnum::SingleVariantSingleField {
587                    field_types: variant_types,
588                    field_index,
589                    ..
590                } => Ok(variant_types
591                    .iter()
592                    .enumerate()
593                    .map(|(i, t)| {
594                        if i == field_index {
595                            self.0
596                        } else {
597                            basic_type_undef(*t)
598                        }
599                    })
600                    .collect()),
601                LLVMSumTypeEnum::SingleVariantMultiField {
602                    field_types: variant_types,
603                    field_indices,
604                    ..
605                } => itertools::zip_eq(variant_types, field_indices)
606                    .map(|(t, mb_i)| {
607                        if let Some(i) = mb_i {
608                            Ok(builder.build_extract_value(
609                                self.0.into_struct_value(),
610                                i as u32,
611                                "",
612                            )?)
613                        } else {
614                            Ok(basic_type_undef(t))
615                        }
616                    })
617                    .collect(),
618                LLVMSumTypeEnum::MultiVariant {
619                    variant_types,
620                    field_indices,
621                    ..
622                } => {
623                    let value = self.0.into_struct_value();
624                    itertools::zip_eq(&variant_types[tag], &field_indices[tag])
625                        .map(|(ty, mb_i)| {
626                            if let Some(i) = mb_i {
627                                Ok(builder.build_extract_value(value, *i as u32 + 1, "")?)
628                            } else {
629                                Ok(basic_type_undef(*ty))
630                            }
631                        })
632                        .collect()
633                }
634            }?;
635        #[cfg(debug_assertions)]
636        {
637            let result_types = results
638                .iter()
639                .map(inkwell::values::BasicValueEnum::get_type)
640                .collect_vec();
641            assert_eq!(&result_types, self.get_type().fields_for_variant(tag));
642        }
643        Ok(results)
644    }
645
646    pub fn build_destructure(
647        &self,
648        builder: &Builder<'c>,
649        mut handler: impl FnMut(&Builder<'c>, usize, Vec<BasicValueEnum<'c>>) -> Result<()>,
650    ) -> Result<()> {
651        let orig_bb = builder
652            .get_insert_block()
653            .ok_or(anyhow!("No current insertion point"))?;
654        let context = orig_bb.get_context();
655        let mut last_bb = orig_bb;
656        let tag_ty = self.tag_type();
657
658        let mut cases = vec![];
659
660        for var_i in 0..self.1.num_variants() {
661            let bb = context.insert_basic_block_after(last_bb, "");
662            last_bb = bb;
663            cases.push((tag_ty.const_int(var_i as u64, false), bb));
664
665            builder.position_at_end(bb);
666            let inputs = self.build_untag(builder, var_i)?;
667            handler(builder, var_i, inputs)?;
668        }
669
670        builder.position_at_end(orig_bb);
671        let tag = self.build_get_tag(builder)?;
672        builder.build_switch(tag, cases[0].1, &cases[1..])?;
673
674        Ok(())
675    }
676
677    delegate! {
678        to self.1 {
679            /// Get the type of the value that would be returned by `build_get_tag`.
680            #[must_use] pub fn tag_type(&self) -> IntType<'c>;
681            /// The number of variants in the represented [HugrSumType].
682            #[must_use] pub fn num_variants(&self) -> usize;
683        }
684    }
685}
686
#[cfg(test)]
mod test {
    use hugr_core::extension::prelude::{bool_t, usize_t};
    use insta::assert_snapshot;
    use rstest::{Context, rstest};

    use crate::{
        test::{TestContext, llvm_ctx},
        types::HugrType,
    };

    use super::*;

    /// The tag width is the minimal number of bits for the tags `0..num_variants`,
    /// including around powers of two.
    #[rstest]
    #[case(1, 1)]
    #[case(2, 1)]
    #[case(3, 2)]
    #[case(4, 2)]
    #[case(5, 3)]
    #[case(8, 3)]
    #[case(9, 4)]
    #[case(16, 4)]
    #[case(17, 5)]
    fn tag_width(#[case] num_variants: usize, #[case] expected: u32) {
        assert_eq!(tag_width_for_num_variants(num_variants), expected);
    }

    /// Checks the LLVM type chosen for each shape of Hugr sum type.
    #[rstest]
    fn sum_types(mut llvm_ctx: TestContext) {
        llvm_ctx.add_extensions(
            super::super::custom::CodegenExtsBuilder::add_default_prelude_extensions,
        );
        let ts = llvm_ctx.get_typing_session();
        let iwc = ts.iw_context();
        // `BasicTypeEnum` is `Copy`, so these can be reused without cloning.
        let empty_struct = iwc.struct_type(&[], false).as_basic_type_enum();
        let i1 = iwc.bool_type().as_basic_type_enum();
        let i2 = iwc.custom_width_int_type(2).as_basic_type_enum();
        let i64 = iwc.i64_type().as_basic_type_enum();

        {
            // no-variants -> i1
            let hugr_type = HugrType::new_unit_sum(0);
            assert_eq!(ts.llvm_type(&hugr_type).unwrap(), i1);
        }

        {
            // one-variant-no-fields -> empty_struct
            let hugr_type = HugrType::UNIT;
            assert_eq!(ts.llvm_type(&hugr_type).unwrap(), empty_struct);
        }

        {
            // one-variant-elidable-fields -> empty_struct
            let hugr_type = HugrType::new_tuple(vec![HugrType::UNIT, HugrType::UNIT]);
            assert_eq!(ts.llvm_type(&hugr_type).unwrap(), empty_struct);
        }

        {
            // multi-variant-no-fields -> bare tag
            let hugr_type = bool_t();
            assert_eq!(ts.llvm_type(&hugr_type).unwrap(), i1);
        }

        {
            // multi-variant-elidable-fields -> bare tag
            let hugr_type = HugrType::new_sum(vec![vec![HugrType::UNIT]; 3]);
            assert_eq!(ts.llvm_type(&hugr_type).unwrap(), i2);
        }

        {
            // one-variant-one-field -> bare field
            let hugr_type = HugrType::new_tuple(vec![usize_t()]);
            assert_eq!(ts.llvm_type(&hugr_type).unwrap(), i64);
        }

        {
            // one-variant-one-non-elidable-field -> bare field
            let hugr_type = HugrType::new_tuple(vec![HugrType::UNIT, usize_t()]);
            assert_eq!(ts.llvm_type(&hugr_type).unwrap(), i64);
        }

        {
            // one-variant-multi-field -> struct-of-fields
            let hugr_type = HugrType::new_tuple(vec![usize_t(), bool_t(), HugrType::UNIT]);
            let llvm_type = iwc.struct_type(&[i64, i1], false).into();
            assert_eq!(ts.llvm_type(&hugr_type).unwrap(), llvm_type);
        }

        {
            // multi-variant-multi-field -> struct-of-fields-with-tag
            let hugr_type1 =
                HugrType::new_sum([vec![bool_t(), HugrType::UNIT, usize_t()], vec![usize_t()]]);
            let hugr_type2 = HugrType::new_sum([vec![usize_t(), bool_t()], vec![usize_t()]]);
            let llvm_type = iwc.struct_type(&[i1, i64, i1], false).into();
            assert_eq!(ts.llvm_type(&hugr_type1).unwrap(), llvm_type);
            assert_eq!(ts.llvm_type(&hugr_type2).unwrap(), llvm_type);
        }
    }

    /// Round-trips a function parameter through `build_get_tag`, `build_untag`
    /// and `build_tag`, verifies the module and snapshots the emitted IR.
    #[rstest]
    #[case::unit(HugrSumType::new_unary(1), 0)]
    #[case::unit_elided_fields(HugrSumType::new([HugrType::UNIT]), 0)]
    #[case::nofields(HugrSumType::new_unary(4), 2)]
    #[case::nofields_elided_fields(HugrSumType::new([vec![HugrType::UNIT], vec![]]), 0)]
    #[case::one_variant_one_field(HugrSumType::new([bool_t()]), 0)]
    #[case::one_variant_one_field_elided_fields(HugrSumType::new([vec![HugrType::UNIT,bool_t()]]), 0)]
    #[case::one_variant_two_fields(HugrSumType::new([vec![bool_t(),bool_t()]]), 0)]
    #[case::one_variant_two_fields_elided_fields(HugrSumType::new([vec![bool_t(),HugrType::UNIT,bool_t()]]), 0)]
    #[case::two_variant_one_field(HugrSumType::new([vec![bool_t()],vec![]]), 1)]
    #[case::two_variant_one_field_elided_fields(HugrSumType::new([vec![bool_t()],vec![HugrType::UNIT]]), 1)]
    fn build_untag_tag(
        #[context] rstest_ctx: Context,
        llvm_ctx: TestContext,
        #[case] sum: HugrSumType,
        #[case] tag: usize,
    ) {
        let module = {
            let ts = llvm_ctx.get_typing_session();
            let iwc = llvm_ctx.iw_context();
            let module = iwc.create_module("");
            let llvm_ty = ts.llvm_sum_type(sum.clone()).unwrap();
            let func_ty = llvm_ty.fn_type(&[llvm_ty.as_basic_type_enum().into()], false);
            let func = module.add_function("untag_tag", func_ty, None);
            let bb = iwc.append_basic_block(func, "");
            let builder = iwc.create_builder();
            builder.position_at_end(bb);
            let value = llvm_ty.value(func.get_nth_param(0).unwrap()).unwrap();
            let _tag = value.build_get_tag(&builder).unwrap();
            let fields = value.build_untag(&builder, tag).unwrap();
            let new_value = llvm_ty.build_tag(&builder, tag, fields).unwrap();
            let _ = builder.build_return(Some(&new_value));
            module.verify().unwrap();
            module
        };

        let mut insta_settings = insta::Settings::clone_current();
        insta_settings.set_snapshot_suffix(rstest_ctx.description.unwrap());
        insta_settings.bind(|| {
            assert_snapshot!(module.to_string());
        });
    }
}