use super::{TypeBuilder, WithSelf};
use crate::{
error::Result,
metadata::{
drop_overrides::DropOverridesMeta, dup_overrides::DupOverridesMeta,
realloc_bindings::ReallocBindingsMeta, MetadataStorage,
},
utils::ProgramRegistryExt,
};
use cairo_lang_sierra::{
extensions::{
core::{CoreLibfunc, CoreType},
types::InfoAndTypeConcreteType,
},
program_registry::ProgramRegistry,
};
use melior::{
dialect::{arith, llvm, ods},
ir::{attribute::IntegerAttribute, r#type::IntegerType, Block, Location, Module, Type},
Context,
};
use melior::{
dialect::{arith::CmpiPredicate, func, scf},
ir::BlockLike,
};
use melior::{
helpers::{ArithBlockExt, BuiltinBlockExt, GepIndex, LlvmBlockExt},
ir::Region,
};
/// Build the MLIR type used to represent this array type.
///
/// Before producing the type, a duplication override (see [`build_dup`]) and a
/// drop override (see [`build_drop`]) are registered for `info.self_ty()`.
///
/// The returned type is an unpacked LLVM struct made of one pointer followed by
/// three 32-bit integers. Field 0 is the pointer that the dup/drop overrides
/// treat as the address of the array's metadata block.
///
/// # Errors
///
/// Propagates any error raised while registering either override.
pub fn build<'ctx>(
    context: &'ctx Context,
    module: &Module<'ctx>,
    registry: &ProgramRegistry<CoreType, CoreLibfunc>,
    metadata: &mut MetadataStorage,
    info: WithSelf<InfoAndTypeConcreteType>,
) -> Result<Type<'ctx>> {
    let self_ty = info.self_ty();

    // Duplication override.
    DupOverridesMeta::register_with(context, module, registry, metadata, self_ty, |metadata| {
        Ok(Some(build_dup(context, module, registry, metadata, &info)?))
    })?;

    // Drop override.
    DropOverridesMeta::register_with(context, module, registry, metadata, self_ty, |metadata| {
        Ok(Some(build_drop(
            context, module, registry, metadata, &info,
        )?))
    })?;

    let pointer_field = llvm::r#type::pointer(context, 0);
    let counter_field: Type<'ctx> = IntegerType::new(context, 32).into();
    let fields = [pointer_field, counter_field, counter_field, counter_field];

    Ok(llvm::r#type::r#struct(context, &fields, false))
}
pub fn build_dup<'ctx>(
context: &'ctx Context,
module: &Module<'ctx>,
registry: &ProgramRegistry<CoreType, CoreLibfunc>,
metadata: &mut MetadataStorage,
info: &WithSelf<InfoAndTypeConcreteType>,
) -> Result<Region<'ctx>> {
let location = Location::unknown(context);
let value_ty = registry.build_type(context, module, metadata, info.self_ty())?;
let region = Region::new();
let entry = region.append_block(Block::new(&[(value_ty, location)]));
let metadata_ptr = entry.extract_value(
context,
location,
entry.argument(0)?.into(),
llvm::r#type::pointer(context, 0),
0,
)?;
let null_ptr =
entry.append_op_result(llvm::zero(llvm::r#type::pointer(context, 0), location))?;
let is_empty = entry.append_op_result(
ods::llvm::icmp(
context,
IntegerType::new(context, 1).into(),
metadata_ptr,
null_ptr,
IntegerAttribute::new(IntegerType::new(context, 64).into(), 0).into(),
location,
)
.into(),
)?;
entry.append_operation(scf::r#if(
is_empty,
&[],
{
let region = Region::new();
let block = region.append_block(Block::new(&[]));
block.append_operation(scf::r#yield(&[], location));
region
},
{
let region = Region::new();
let block = region.append_block(Block::new(&[]));
let refcount_ptr = block.gep(
context,
location,
metadata_ptr,
&[GepIndex::Const(0), GepIndex::Const(0)],
get_metadata_llvm_type(context),
)?;
let ref_count = block.load(
context,
location,
refcount_ptr,
IntegerType::new(context, 32).into(),
)?;
let k1 = block.const_int(context, location, 1, 32)?;
let ref_count = block.append_op_result(arith::addi(ref_count, k1, location))?;
block.store(context, location, refcount_ptr, ref_count)?;
block.append_operation(scf::r#yield(&[], location));
region
},
location,
));
entry.append_operation(func::r#return(
&[entry.argument(0)?.into(), entry.argument(0)?.into()],
location,
));
Ok(region)
}
pub fn build_drop<'ctx>(
context: &'ctx Context,
module: &Module<'ctx>,
registry: &ProgramRegistry<CoreType, CoreLibfunc>,
metadata: &mut MetadataStorage,
info: &WithSelf<InfoAndTypeConcreteType>,
) -> Result<Region<'ctx>> {
let location = Location::unknown(context);
if metadata.get::<ReallocBindingsMeta>().is_none() {
metadata.insert(ReallocBindingsMeta::new(context, module));
}
let value_ty = registry.build_type(context, module, metadata, info.self_ty())?;
let elem_ty = registry.get_type(&info.ty)?;
let elem_stride = elem_ty.layout(registry)?.pad_to_align().size();
let elem_ty = elem_ty.build(context, module, registry, metadata, &info.ty)?;
let region = Region::new();
let entry = region.append_block(Block::new(&[(value_ty, location)]));
let metadata_ptr = entry.extract_value(
context,
location,
entry.argument(0)?.into(),
llvm::r#type::pointer(context, 0),
0,
)?;
let null_ptr =
entry.append_op_result(llvm::zero(llvm::r#type::pointer(context, 0), location))?;
let is_null = entry.append_op_result(
ods::llvm::icmp(
context,
IntegerType::new(context, 1).into(),
metadata_ptr,
null_ptr,
IntegerAttribute::new(IntegerType::new(context, 64).into(), 0).into(),
location,
)
.into(),
)?;
entry.append_operation(scf::r#if(
is_null,
&[],
{
let region = Region::new();
let block = region.append_block(Block::new(&[]));
block.append_operation(scf::r#yield(&[], location));
region
},
{
let region = Region::new();
let block = region.append_block(Block::new(&[]));
let refcount_ptr = block.gep(
context,
location,
metadata_ptr,
&[GepIndex::Const(0), GepIndex::Const(0)],
get_metadata_llvm_type(context),
)?;
let ref_count = block.load(
context,
location,
refcount_ptr,
IntegerType::new(context, 32).into(),
)?;
let k1 = block.const_int(context, location, 1, 32)?;
let is_shared = block.append_op_result(arith::cmpi(
context,
CmpiPredicate::Ne,
ref_count,
k1,
location,
))?;
block.append_operation(scf::r#if(
is_shared,
&[],
{
let region = Region::new();
let block = region.append_block(Block::new(&[]));
let ref_count = block.append_op_result(arith::subi(ref_count, k1, location))?;
block.store(context, location, refcount_ptr, ref_count)?;
block.append_operation(scf::r#yield(&[], location));
region
},
{
let region = Region::new();
let block = region.append_block(Block::new(&[]));
if DropOverridesMeta::is_overriden(metadata, &info.ty) {
let k0 = block.const_int(context, location, 0, 64)?;
let elem_stride = block.const_int(context, location, elem_stride, 64)?;
let max_len_ptr = block.gep(
context,
location,
metadata_ptr,
&[GepIndex::Const(0), GepIndex::Const(1)],
get_metadata_llvm_type(context),
)?;
let max_len = block.load(
context,
location,
max_len_ptr,
IntegerType::new(context, 32).into(),
)?;
let max_len =
block.extui(max_len, IntegerType::new(context, 64).into(), location)?;
let offset_end = block.muli(max_len, elem_stride, location)?;
let data_ptr_ptr = block.gep(
context,
location,
metadata_ptr,
&[GepIndex::Const(0), GepIndex::Const(2)],
get_metadata_llvm_type(context),
)?;
let data_ptr = block.load(
context,
location,
data_ptr_ptr,
llvm::r#type::pointer(context, 0),
)?;
block.append_operation(scf::r#for(
k0,
offset_end,
elem_stride,
{
let region = Region::new();
let block = region.append_block(Block::new(&[(
IntegerType::new(context, 64).into(),
location,
)]));
let elem_offset = block.argument(0)?.into();
let elem_ptr = block.gep(
context,
location,
data_ptr,
&[GepIndex::Value(elem_offset)],
IntegerType::new(context, 8).into(),
)?;
let elem_val = block.load(context, location, elem_ptr, elem_ty)?;
DropOverridesMeta::invoke_override(
context, registry, module, &block, &block, location, metadata,
&info.ty, elem_val,
)?;
block.append_operation(scf::r#yield(&[], location));
region
},
location,
));
}
let data_ptr_ptr = block.gep(
context,
location,
metadata_ptr,
&[GepIndex::Const(0), GepIndex::Const(2)],
get_metadata_llvm_type(context),
)?;
let data_ptr = block.load(
context,
location,
data_ptr_ptr,
llvm::r#type::pointer(context, 0),
)?;
block.append_operation(ReallocBindingsMeta::free(context, data_ptr, location)?);
block.append_operation(ReallocBindingsMeta::free(
context,
metadata_ptr,
location,
)?);
block.append_operation(scf::r#yield(&[], location));
region
},
location,
));
block.append_operation(scf::r#yield(&[], location));
region
},
location,
));
entry.append_operation(func::r#return(&[], location));
Ok(region)
}
/// Host-side view of the metadata block that an array value points to.
///
/// The field order and types match the LLVM struct returned by
/// [`get_metadata_llvm_type`] (`i32`, `i32`, pointer); `#[repr(C)]` keeps the
/// declared field order.
#[repr(C)]
pub struct ArrayMetadata {
    /// Counter incremented by the code from [`build_dup`] and decremented (or
    /// compared against 1) by the code from [`build_drop`].
    pub refcount: u32,
    /// Number of element slots the drop code visits when the element type has
    /// a drop override.
    pub max_len: u32,
    /// Pointer to the element buffer; freed by the drop code before the
    /// metadata block itself.
    pub data_ptr: *mut u8,
}
/// Size in bytes of [`ArrayMetadata`] as laid out by the host compiler.
pub fn calc_metadata_size() -> usize {
    std::alloc::Layout::new::<ArrayMetadata>().size()
}
/// LLVM struct type describing the array metadata block.
///
/// The struct is unpacked and holds, in order, two 32-bit integers and one
/// pointer, matching the field order of [`ArrayMetadata`].
pub fn get_metadata_llvm_type(context: &Context) -> Type<'_> {
    let counter_ty: Type<'_> = IntegerType::new(context, 32).into();
    let data_ptr_ty = llvm::r#type::pointer(context, 0);

    let fields = [counter_ty, counter_ty, data_ptr_ty];
    llvm::r#type::r#struct(context, &fields, false)
}
#[cfg(test)]
mod test {
    use crate::{
        utils::testing::{get_compiled_program, run_program},
        values::Value,
    };
    use pretty_assertions_sorted::assert_eq;

    /// Wrap a list of integers as an array value of felts.
    fn felt_array(items: &[i32]) -> Value {
        Value::Array(
            items
                .iter()
                .map(|&item| Value::Felt252(item.into()))
                .collect(),
        )
    }

    /// Runs the nested-arrays program and checks that both inner arrays come
    /// back with their full contents.
    #[test]
    fn test_array_snapshot_deep_clone() {
        let program = get_compiled_program("test_data_artifacts/programs/types/nested_arrays");
        let result = run_program(&program, "run_test", &[]).return_value;

        let expected = Value::Array(vec![felt_array(&[1, 2, 3]), felt_array(&[4, 5, 6])]);

        assert_eq!(result, expected);
    }
}