use super::{TypeBuilder, WithSelf};
use crate::{
error::Result,
metadata::MetadataStorage,
utils::{get_integer_layout, ProgramRegistryExt},
};
use cairo_lang_sierra::{
extensions::{
core::{CoreLibfunc, CoreType},
enm::EnumConcreteType,
},
ids::ConcreteTypeId,
program_registry::ProgramRegistry,
};
use melior::{
dialect::llvm,
ir::{r#type::IntegerType, Module, Type},
Context,
};
use std::alloc::Layout;
/// An MLIR type paired with the memory [`Layout`] that accompanies it.
pub type TypeLayout<'ctx> = (Type<'ctx>, Layout);
/// Build the MLIR type used to represent an enum.
///
/// The shape of the resulting type depends on the number of variants:
///
/// - no variants: a zero-length array of `i8`;
/// - exactly one variant: the type built for that variant, with no tag;
/// - two or more variants, all reported as zero-sized by `is_zst`: a struct
///   made of a `tag_bits`-wide integer followed by a zero-length `i8` array;
/// - otherwise: a struct made of an integer that is `8 * align` bits wide
///   followed by an `i8` array of `size - align` elements, where `size` and
///   `align` are the largest size and alignment found among the
///   "tag followed by payload" layouts of the variants.
///
/// # Errors
///
/// Propagates any error produced while looking up a variant in `registry`,
/// computing a layout, querying `is_zst` or building the single variant's
/// type.
pub fn build<'ctx>(
    context: &'ctx Context,
    module: &Module<'ctx>,
    registry: &ProgramRegistry<CoreType, CoreLibfunc>,
    metadata: &mut MetadataStorage,
    info: WithSelf<EnumConcreteType>,
) -> Result<Type<'ctx>> {
    let variant_ids = info.variants.as_slice();

    // Number of bits needed to give each variant a distinct tag value
    // (zero when there are fewer than two variants).
    let tag_bits = variant_ids.len().next_power_of_two().trailing_zeros();
    let tag_layout = get_integer_layout(tag_bits);

    // Widest size and alignment over every "tag followed by payload" layout,
    // starting from the layout of the bare tag.
    let mut layout = tag_layout;
    for variant_id in variant_ids {
        let (tagged_layout, _) =
            tag_layout.extend(registry.get_type(variant_id)?.layout(registry)?)?;
        layout = Layout::from_size_align(
            layout.size().max(tagged_layout.size()),
            layout.align().max(tagged_layout.align()),
        )?;
    }

    let i8_ty: Type<'ctx> = IntegerType::new(context, 8).into();

    let enum_ty = match variant_ids {
        [] => llvm::r#type::array(i8_ty, 0),
        [only_variant] => registry.build_type(context, module, metadata, only_variant)?,
        _ => {
            // Stop scanning at the first variant that is not zero-sized.
            let mut all_zst = true;
            for variant_id in variant_ids {
                if !registry.get_type(variant_id)?.is_zst(registry)? {
                    all_zst = false;
                    break;
                }
            }

            if all_zst {
                llvm::r#type::r#struct(
                    context,
                    &[
                        IntegerType::new(context, tag_bits).into(),
                        llvm::r#type::array(i8_ty, 0),
                    ],
                    false,
                )
            } else {
                llvm::r#type::r#struct(
                    context,
                    &[
                        IntegerType::new(context, (8 * layout.align()) as u32).into(),
                        llvm::r#type::array(i8_ty, (layout.size() - layout.align()) as u32),
                    ],
                    false,
                )
            }
        }
    };

    Ok(enum_ty)
}
/// Compute the layouts involved in an enum made of `variants`.
///
/// Returns a tuple of:
///
/// 1. a layout whose size and alignment are the largest found among the tag
///    layout and every "tag followed by payload" layout;
/// 2. the layout of the tag integer;
/// 3. the payload layout of each variant, in the same order as `variants`.
///
/// # Errors
///
/// Propagates any error produced while looking up a variant in `registry` or
/// while computing or combining layouts.
pub fn get_layout_for_variants(
    registry: &ProgramRegistry<CoreType, CoreLibfunc>,
    variants: &[ConcreteTypeId],
) -> Result<(Layout, Layout, Vec<Layout>)> {
    // The tag needs just enough bits to tell every variant apart.
    let tag_layout = get_integer_layout(variants.len().next_power_of_two().trailing_zeros());

    let mut payload_layouts = Vec::with_capacity(variants.len());
    let enum_layout = variants
        .iter()
        .try_fold(tag_layout, |widest, variant_id| {
            let payload_layout = registry.get_type(variant_id)?.layout(registry)?;
            let (tagged_layout, _) = tag_layout.extend(payload_layout)?;
            let merged = Layout::from_size_align(
                widest.size().max(tagged_layout.size()),
                widest.align().max(tagged_layout.align()),
            )?;

            payload_layouts.push(payload_layout);
            Result::Ok(merged)
        })?;

    Ok((enum_layout, tag_layout, payload_layouts))
}
/// Build the types and layouts involved in an enum made of `variants`.
///
/// Returns a tuple of:
///
/// 1. a layout whose size and alignment are the largest found among the tag
///    layout and every "tag followed by payload" layout;
/// 2. the tag's integer type together with its layout;
/// 3. the type and layout of each variant's payload, in the same order as
///    `variants`.
///
/// # Errors
///
/// Propagates any error produced while building a variant's type or while
/// combining layouts.
pub fn get_type_for_variants<'ctx>(
    context: &'ctx Context,
    module: &Module<'ctx>,
    registry: &ProgramRegistry<CoreType, CoreLibfunc>,
    metadata: &mut MetadataStorage,
    variants: &[ConcreteTypeId],
) -> Result<(Layout, TypeLayout<'ctx>, Vec<TypeLayout<'ctx>>)> {
    // The tag needs just enough bits to tell every variant apart.
    let tag_bits = variants.len().next_power_of_two().trailing_zeros();
    let tag_layout = get_integer_layout(tag_bits);
    let tag_ty: Type = IntegerType::new(context, tag_bits).into();

    let mut enum_layout = tag_layout;
    let mut payloads = Vec::with_capacity(variants.len());
    for variant_id in variants {
        let payload = registry.build_type_with_layout(context, module, metadata, variant_id)?;

        // Only the combined layout matters here; the payload offset is unused.
        let (tagged_layout, _) = tag_layout.extend(payload.1)?;
        enum_layout = Layout::from_size_align(
            enum_layout.size().max(tagged_layout.size()),
            enum_layout.align().max(tagged_layout.align()),
        )?;

        payloads.push(payload);
    }

    Ok((enum_layout, (tag_ty, tag_layout), payloads))
}
#[cfg(test)]
mod test {
    use crate::{metadata::MetadataStorage, types::TypeBuilder, utils::testing::load_program};
    use cairo_lang_sierra::{
        extensions::core::{CoreLibfunc, CoreType},
        program_registry::ProgramRegistry,
    };
    use melior::{
        ir::{r#type::IntegerType, Location, Module},
        Context,
    };

    /// Building every type declared by the program must never yield a
    /// zero-width integer (`i0`) type.
    #[test]
    fn enum_type_single_variant_no_i0() {
        let program =
            load_program("test_data_artifacts/programs/types/enum_type_single_variant_no_i0");
        let context = Context::new();
        let registry = ProgramRegistry::<CoreType, CoreLibfunc>::new(&program).unwrap();
        let module = Module::new(Location::unknown(&context));
        let mut metadata = MetadataStorage::new();

        let i0_ty = IntegerType::new(&context, 0).into();

        // The result of the scan has to be asserted on; otherwise the test
        // can never fail, whatever types are produced.
        assert!(
            !program
                .type_declarations
                .iter()
                .map(|ty| (&ty.id, registry.get_type(&ty.id).unwrap()))
                .map(|(id, ty)| {
                    ty.build(&context, &module, &registry, &mut metadata, id)
                        .unwrap()
                })
                .any(|width| width == i0_ty),
            "no declared type should be built as an `i0` integer"
        );
    }
}