use super::{TypeBuilder, WithSelf};
use crate::{
error::Result,
metadata::{
drop_overrides::DropOverridesMeta, dup_overrides::DupOverridesMeta, MetadataStorage,
},
native_panic,
utils::{get_integer_layout, ProgramRegistryExt},
};
use cairo_lang_sierra::{
extensions::{
core::{CoreLibfunc, CoreType},
enm::EnumConcreteType,
},
ids::ConcreteTypeId,
program_registry::ProgramRegistry,
};
use melior::{
dialect::{cf, func, llvm},
helpers::{BuiltinBlockExt, LlvmBlockExt},
ir::{r#type::IntegerType, Block, BlockLike, Location, Module, Region, Type, Value},
Context,
};
use std::{
alloc::Layout,
collections::{hash_map::Entry, HashMap},
};
/// A melior [`Type`] paired with the [`Layout`] (size and alignment) associated with it.
pub type TypeLayout<'ctx> = (Type<'ctx>, Layout);
/// Builds the type of an enum and registers its dup/drop overrides.
///
/// Before producing the type, a dup override and a drop override are registered for
/// `info.self_ty()`. Each one is only generated (through [`build_dup`] / [`build_drop`]) when
/// at least one variant payload is itself reported as overridden; variant payload types are
/// built one by one until such a variant is found.
///
/// The returned type depends on the number of variants:
///
/// - no variants: a zero-length array of `i8`;
/// - one variant: the type of that variant's payload, with no tag;
/// - several variants whose payloads are all zero-sized: a struct made of the tag integer and a
///   zero-length `i8` array;
/// - otherwise: a struct made of an integer as wide as the enum's alignment followed by an `i8`
///   array covering the remaining `size - align` bytes.
///
/// # Errors
///
/// Propagates any error raised while registering the overrides, looking up or building a
/// variant type, or computing a layout.
pub fn build<'ctx>(
    context: &'ctx Context,
    module: &Module<'ctx>,
    registry: &ProgramRegistry<CoreType, CoreLibfunc>,
    metadata: &mut MetadataStorage,
    info: WithSelf<EnumConcreteType>,
) -> Result<Type<'ctx>> {
    DupOverridesMeta::register_with(
        context,
        module,
        registry,
        metadata,
        info.self_ty(),
        |metadata| {
            // Stop at the first variant with an override: that alone requires a custom dup.
            for variant in &info.variants {
                registry.build_type(context, module, metadata, variant)?;
                if DupOverridesMeta::is_overriden(metadata, variant) {
                    return build_dup(context, module, registry, metadata, &info).map(Some);
                }
            }
            Ok(None)
        },
    )?;
    DropOverridesMeta::register_with(
        context,
        module,
        registry,
        metadata,
        info.self_ty(),
        |metadata| {
            // Stop at the first variant with an override: that alone requires a custom drop.
            for variant in &info.variants {
                registry.build_type(context, module, metadata, variant)?;
                if DropOverridesMeta::is_overriden(metadata, variant) {
                    return build_drop(context, module, registry, metadata, &info).map(Some);
                }
            }
            Ok(None)
        },
    )?;

    // Bits required to tell the variants apart (zero when there is at most one variant).
    let tag_bits = info.variants.len().next_power_of_two().trailing_zeros();
    let tag_layout = get_integer_layout(tag_bits);

    // The enum layout is the largest size and the largest alignment over every
    // "tag followed by payload" layout, starting from the bare tag.
    let mut layout = tag_layout;
    for variant in &info.variants {
        let payload_layout = registry.get_type(variant)?.layout(registry)?;
        let (tagged_layout, _) = tag_layout.extend(payload_layout)?;
        layout = Layout::from_size_align(
            layout.size().max(tagged_layout.size()),
            layout.align().max(tagged_layout.align()),
        )?;
    }

    let i8_ty: Type<'ctx> = IntegerType::new(context, 8).into();
    let enum_ty = match &info.variants[..] {
        [] => llvm::r#type::array(i8_ty, 0),
        [payload] => registry.build_type(context, module, metadata, payload)?,
        variants => {
            // Only inspected when there are at least two variants; stops at the first payload
            // that is not zero-sized.
            let mut all_zst = true;
            for type_id in variants {
                if !registry.get_type(type_id)?.is_zst(registry)? {
                    all_zst = false;
                    break;
                }
            }

            let (head_ty, tail_len): (Type<'ctx>, u32) = if all_zst {
                (IntegerType::new(context, tag_bits).into(), 0)
            } else {
                (
                    IntegerType::new(context, (8 * layout.align()) as u32).into(),
                    (layout.size() - layout.align()) as u32,
                )
            };

            llvm::r#type::r#struct(
                context,
                &[head_ty, llvm::r#type::array(i8_ty, tail_len)],
                false,
            )
        }
    };

    Ok(enum_ty)
}
/// Builds the region implementing the duplication of an enum value.
///
/// The region's entry block receives the enum value as its only argument and every path that
/// returns does so with two values of the enum type.
///
/// - With a single variant the payload's dup override is invoked directly on the argument.
/// - With several variants the value is stored to a stack slot, the tag is loaded from it and a
///   switch jumps to a per-variant block. That block loads the slot as a `{tag, payload}`
///   struct, invokes the payload's dup override and re-wraps each resulting payload into a
///   full enum value by going through the same slot.
///
/// # Errors
///
/// Fails for an enum without variants, and propagates any error raised while building types or
/// emitting operations.
fn build_dup<'ctx>(
    context: &'ctx Context,
    module: &Module<'ctx>,
    registry: &ProgramRegistry<CoreType, CoreLibfunc>,
    metadata: &mut MetadataStorage,
    info: &WithSelf<EnumConcreteType>,
) -> Result<Region<'ctx>> {
    let location = Location::unknown(context);
    let self_ty = registry.build_type(context, module, metadata, info.self_ty())?;

    let region = Region::new();
    let entry = region.append_block(Block::new(&[(self_ty, location)]));

    let (layout, (tag_ty, _), variant_tys) = crate::types::r#enum::get_type_for_variants(
        context,
        module,
        registry,
        metadata,
        &info.variants,
    )?;

    match &variant_tys[..] {
        [] => native_panic!("attempt to clone a zero-variant enum"),
        [_] => {
            let duplicated = DupOverridesMeta::invoke_override(
                context,
                registry,
                module,
                &entry,
                &entry,
                location,
                metadata,
                &info.variants[0],
                entry.arg(0)?,
            )?;
            entry.append_operation(func::r#return(&[duplicated.0, duplicated.1], location));
        }
        _ => {
            // Spill the enum to the stack so it can be reinterpreted per variant.
            let ptr = entry.alloca1(context, location, self_ty, layout.align())?;
            entry.store(context, location, ptr, entry.arg(0)?)?;

            // One block per distinct payload type id; variants sharing a payload type share
            // the same block.
            let mut variant_blocks = HashMap::new();
            for (variant_id, &(payload_ty, _)) in info.variants.iter().zip(variant_tys.iter()) {
                let Entry::Vacant(slot) = variant_blocks.entry(variant_id.id) else {
                    continue;
                };
                let block = slot.insert(region.append_block(Block::new(&[])));

                let container_ty = llvm::r#type::r#struct(context, &[tag_ty, payload_ty], false);
                let container = block.load(context, location, ptr, container_ty)?;
                let payload = block.extract_value(context, location, container, payload_ty, 1)?;

                let duplicated = DupOverridesMeta::invoke_override(
                    context, registry, module, block, block, location, metadata, variant_id,
                    payload,
                )?;

                // Re-wrap the first resulting payload and read it back as the enum type.
                let first = block.insert_value(context, location, container, duplicated.0, 1)?;
                block.store(context, location, ptr, first)?;
                let first = block.load(context, location, ptr, self_ty)?;

                // Same for the second resulting payload, reusing the stack slot.
                let second = block.insert_value(context, location, container, duplicated.1, 1)?;
                block.store(context, location, ptr, second)?;
                let second = block.load(context, location, ptr, self_ty)?;

                block.append_operation(func::r#return(&[first, second], location));
            }

            let default_block = region.append_block(Block::new(&[]));
            let tag_value = entry.load(context, location, ptr, tag_ty)?;

            // Case `i` of the switch jumps to the block of the `i`-th variant.
            let destinations = info
                .variants
                .iter()
                .map(|id| (&*variant_blocks[&id.id], &[] as &[Value]))
                .collect::<Vec<_>>();
            entry.append_operation(cf::switch(
                context,
                &(0..info.variants.len() as _).collect::<Vec<_>>(),
                tag_value,
                tag_ty,
                (&default_block, &[]),
                &destinations,
                location,
            )?);

            // Any other tag value is marked as unreachable.
            default_block.append_operation(llvm::unreachable(location));
        }
    }

    Ok(region)
}
/// Builds the region implementing the drop of an enum value.
///
/// The region's entry block receives the enum value as its only argument and every path that
/// returns does so without values.
///
/// - With a single variant the payload's drop override is invoked directly on the argument.
/// - With several variants the value is stored to a stack slot, the tag is loaded from it and a
///   switch jumps to a per-variant block. That block loads the slot as a `{tag, payload}`
///   struct, extracts the payload and invokes its drop override.
///
/// # Errors
///
/// Fails for an enum without variants, and propagates any error raised while building types or
/// emitting operations.
fn build_drop<'ctx>(
    context: &'ctx Context,
    module: &Module<'ctx>,
    registry: &ProgramRegistry<CoreType, CoreLibfunc>,
    metadata: &mut MetadataStorage,
    info: &WithSelf<EnumConcreteType>,
) -> Result<Region<'ctx>> {
    let location = Location::unknown(context);
    let self_ty = registry.build_type(context, module, metadata, info.self_ty())?;

    let region = Region::new();
    let entry = region.append_block(Block::new(&[(self_ty, location)]));

    let (layout, (tag_ty, _), variant_tys) = crate::types::r#enum::get_type_for_variants(
        context,
        module,
        registry,
        metadata,
        &info.variants,
    )?;

    match &variant_tys[..] {
        [] => native_panic!("attempt to drop a zero-variant enum"),
        [_] => {
            DropOverridesMeta::invoke_override(
                context,
                registry,
                module,
                &entry,
                &entry,
                location,
                metadata,
                &info.variants[0],
                entry.arg(0)?,
            )?;
            entry.append_operation(func::r#return(&[], location));
        }
        _ => {
            // Spill the enum to the stack so it can be reinterpreted per variant.
            let ptr = entry.alloca1(context, location, self_ty, layout.align())?;
            entry.store(context, location, ptr, entry.arg(0)?)?;

            // One block per distinct payload type id; variants sharing a payload type share
            // the same block.
            let mut variant_blocks = HashMap::new();
            for (variant_id, &(payload_ty, _)) in info.variants.iter().zip(variant_tys.iter()) {
                let Entry::Vacant(slot) = variant_blocks.entry(variant_id.id) else {
                    continue;
                };
                let block = slot.insert(region.append_block(Block::new(&[])));

                let container_ty = llvm::r#type::r#struct(context, &[tag_ty, payload_ty], false);
                let container = block.load(context, location, ptr, container_ty)?;
                let payload = block.extract_value(context, location, container, payload_ty, 1)?;

                DropOverridesMeta::invoke_override(
                    context, registry, module, block, block, location, metadata, variant_id,
                    payload,
                )?;
                block.append_operation(func::r#return(&[], location));
            }

            let default_block = region.append_block(Block::new(&[]));
            let tag_value = entry.load(context, location, ptr, tag_ty)?;

            // Case `i` of the switch jumps to the block of the `i`-th variant.
            let destinations = info
                .variants
                .iter()
                .map(|id| (&*variant_blocks[&id.id], &[] as &[Value]))
                .collect::<Vec<_>>();
            entry.append_operation(cf::switch(
                context,
                &(0..info.variants.len() as _).collect::<Vec<_>>(),
                tag_value,
                tag_ty,
                (&default_block, &[]),
                &destinations,
                location,
            )?);

            // Any other tag value is marked as unreachable.
            default_block.append_operation(llvm::unreachable(location));
        }
    }

    Ok(region)
}
/// Computes the layouts involved in an enum made of the given variants.
///
/// Returns, in order:
///
/// 1. the layout of the whole enum: the largest size and the largest alignment over every
///    "tag followed by payload" layout, starting from the bare tag layout;
/// 2. the layout of the tag, whose bit width is derived from the number of variants;
/// 3. the layout of each variant's payload, in the same order as `variants`.
///
/// # Errors
///
/// Propagates any error raised while looking up a variant type or computing a layout.
pub fn get_layout_for_variants(
    registry: &ProgramRegistry<CoreType, CoreLibfunc>,
    variants: &[ConcreteTypeId],
) -> Result<(Layout, Layout, Vec<Layout>)> {
    let tag_bits = variants.len().next_power_of_two().trailing_zeros();
    let tag_layout = get_integer_layout(tag_bits);

    let mut payload_layouts = Vec::with_capacity(variants.len());
    let enum_layout = variants.iter().try_fold(tag_layout, |widest, variant_id| {
        let payload_layout = registry.get_type(variant_id)?.layout(registry)?;
        let (tagged_layout, _) = tag_layout.extend(payload_layout)?;
        let merged = Layout::from_size_align(
            widest.size().max(tagged_layout.size()),
            widest.align().max(tagged_layout.align()),
        )?;
        payload_layouts.push(payload_layout);
        Result::Ok(merged)
    })?;

    Ok((enum_layout, tag_layout, payload_layouts))
}
/// Builds the types and computes the layouts involved in an enum made of the given variants.
///
/// Returns, in order:
///
/// 1. the layout of the whole enum: the largest size and the largest alignment over every
///    "tag followed by payload" layout, starting from the bare tag layout;
/// 2. the tag's integer type and layout, whose bit width is derived from the number of
///    variants;
/// 3. the type and layout of each variant's payload, in the same order as `variants`.
///
/// # Errors
///
/// Propagates any error raised while building a variant type or computing a layout.
pub fn get_type_for_variants<'ctx>(
    context: &'ctx Context,
    module: &Module<'ctx>,
    registry: &ProgramRegistry<CoreType, CoreLibfunc>,
    metadata: &mut MetadataStorage,
    variants: &[ConcreteTypeId],
) -> Result<(Layout, TypeLayout<'ctx>, Vec<TypeLayout<'ctx>>)> {
    let tag_bits = variants.len().next_power_of_two().trailing_zeros();
    let tag_layout = get_integer_layout(tag_bits);
    let tag_ty: Type = IntegerType::new(context, tag_bits).into();

    let mut payloads = Vec::with_capacity(variants.len());
    let enum_layout = variants.iter().try_fold(tag_layout, |widest, variant_id| {
        let payload = registry.build_type_with_layout(context, module, metadata, variant_id)?;
        let (tagged_layout, _) = tag_layout.extend(payload.1)?;
        let merged = Layout::from_size_align(
            widest.size().max(tagged_layout.size()),
            widest.align().max(tagged_layout.align()),
        )?;
        payloads.push(payload);
        Result::Ok(merged)
    })?;

    Ok((enum_layout, (tag_ty, tag_layout), payloads))
}
#[cfg(test)]
mod test {
    use crate::{metadata::MetadataStorage, types::TypeBuilder, utils::testing::load_program};
    use cairo_lang_sierra::{
        extensions::core::{CoreLibfunc, CoreType},
        program_registry::ProgramRegistry,
    };
    use melior::{
        ir::{r#type::IntegerType, Location, Module},
        Context,
    };

    /// Builds every type declared by the single-variant enum test program and checks that none
    /// of them is built as a zero-width integer (`i0`).
    #[test]
    fn enum_type_single_variant_no_i0() {
        let program =
            load_program("test_data_artifacts/programs/types/enum_type_single_variant_no_i0");

        let context = Context::new();
        let registry = ProgramRegistry::<CoreType, CoreLibfunc>::new(&program).unwrap();
        let module = Module::new(Location::unknown(&context));
        let mut metadata = MetadataStorage::new();

        let i0_ty = IntegerType::new(&context, 0).into();
        let builds_i0 = program
            .type_declarations
            .iter()
            .map(|ty| (&ty.id, registry.get_type(&ty.id).unwrap()))
            .map(|(id, ty)| {
                ty.build(&context, &module, &registry, &mut metadata, id)
                    .unwrap()
            })
            .any(|built_ty| built_ty == i0_ty);

        // The result of the search must be checked; otherwise the test can never fail on it.
        assert!(
            !builds_i0,
            "no type declared by the program should be built as a zero-width integer"
        );
    }
}