use crate::{
clone_option_mut,
debug::libfunc_to_name,
error::{panic::ToNativeAssertError, Error},
libfuncs::{BranchArg, LibfuncBuilder, LibfuncHelper},
metadata::{
gas::{GasCost, GasMetadata},
tail_recursion::TailRecursionMeta,
MetadataStorage,
},
native_assert, native_panic,
statistics::Statistics,
types::TypeBuilder,
utils::{generate_function_name, walk_ir::walk_mlir_block},
};
use bumpalo::Bump;
use cairo_lang_sierra::{
edit_state::EditState,
extensions::{
core::{CoreConcreteLibfunc, CoreLibfunc, CoreType},
ConcreteLibfunc,
},
ids::{ConcreteTypeId, VarId},
program::{Function, Invocation, Program, Statement, StatementIdx},
program_registry::ProgramRegistry,
};
use cairo_lang_sierra_to_casm::environment::gas_wallet::GasWallet;
use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
use itertools::Itertools;
use melior::{
dialect::{
arith::CmpiPredicate,
cf, func, index,
llvm::{self, LoadStoreOptions},
memref,
},
helpers::{ArithBlockExt, BuiltinBlockExt, LlvmBlockExt},
ir::{
attribute::{
DenseI64ArrayAttribute, FlatSymbolRefAttribute, IntegerAttribute, StringAttribute,
TypeAttribute,
},
operation::OperationBuilder,
r#type::{FunctionType, IntegerType, MemRefType},
Attribute, AttributeLike, Block, BlockLike, BlockRef, Identifier, Location, Module, Region,
Type, Value,
},
Context,
};
use mlir_sys::{
mlirDisctinctAttrCreate, mlirLLVMDICompileUnitAttrGet, mlirLLVMDIFileAttrGet,
mlirLLVMDIModuleAttrGet, mlirLLVMDIModuleAttrGetScope, mlirLLVMDISubprogramAttrGet,
mlirLLVMDISubroutineTypeAttrGet, MlirLLVMDIEmissionKind_MlirLLVMDIEmissionKindFull,
MlirLLVMDINameTableKind_MlirLLVMDINameTableKindDefault,
};
use std::{
cell::Cell,
collections::{hash_map::Entry, BTreeMap, HashMap, HashSet},
ops::Deref,
};
/// Blocks generated for every reachable statement, keyed by statement index.
///
/// The second tuple element is the block that implements the statement itself. The first one is
/// only present when the statement has a landing block; it then also holds the ids of the state
/// variables that the landing block's arguments correspond to, in argument order.
type BlockStorage<'ctx, 'blk> = HashMap<
    StatementIdx,
    (
        Option<(BlockRef<'ctx, 'blk>, Vec<VarId>)>,
        BlockRef<'ctx, 'blk>,
    ),
>;
/// Compiles every function of `program` into `module`.
///
/// When the `NATIVE_DEBUG_DUMP` environment variable is `1` or `true`, the Sierra program text is
/// first written to `program.sierra` in the current directory. The debug locations attached to
/// the generated MLIR refer to lines of that file.
///
/// # Errors
///
/// Returns an error if writing the debug dump fails or if any function fails to compile.
#[allow(clippy::too_many_arguments)]
pub fn compile(
    context: &Context,
    module: &Module,
    program: &Program,
    registry: &ProgramRegistry<CoreType, CoreLibfunc>,
    metadata: &mut MetadataStorage,
    di_compile_unit_id: Attribute,
    ignore_debug_names: bool,
    stats: Option<&mut Statistics>,
) -> Result<(), Error> {
    let dump_requested = matches!(
        std::env::var("NATIVE_DEBUG_DUMP").as_deref(),
        Ok("1" | "true")
    );
    if dump_requested {
        std::fs::write("program.sierra", program.to_string())?;
    }

    // Line number of statement 0 inside `program.sierra`. The `+ 3` presumably accounts for the
    // separator lines between the type, libfunc and statement sections plus 1-based line
    // numbering — confirm against `Program`'s `Display` implementation.
    let sierra_stmt_start_offset =
        program.type_declarations.len() + program.libfunc_declarations.len() + 3;

    for function in program.funcs.iter() {
        tracing::info!("Compiling function `{}`.", function.id);
        compile_func(
            context,
            module,
            registry,
            function,
            &program.statements,
            metadata,
            di_compile_unit_id,
            sierra_stmt_start_offset,
            ignore_debug_names,
            clone_option_mut!(stats),
        )?;
    }

    tracing::info!("The program was compiled successfully.");
    Ok(())
}
#[allow(clippy::too_many_arguments)]
fn compile_func(
context: &Context,
module: &Module,
registry: &ProgramRegistry<CoreType, CoreLibfunc>,
function: &Function,
statements: &[Statement],
metadata: &mut MetadataStorage,
di_compile_unit_id: Attribute,
sierra_stmt_start_offset: usize,
ignore_debug_names: bool,
stats: Option<&mut Statistics>,
) -> Result<(), Error> {
let fn_location = Location::new(
context,
"program.sierra",
sierra_stmt_start_offset + function.entry_point.0,
0,
);
let region = Region::new();
let blocks_arena = Bump::new();
let mut arg_types = extract_types(
context,
module,
&function.signature.param_types,
registry,
metadata,
)
.collect::<Result<Vec<_>, _>>()?;
let mut return_types = extract_types(
context,
module,
&function.signature.ret_types,
registry,
metadata,
)
.collect::<Result<Vec<_>, _>>()?;
#[cfg(feature = "with-trace-dump")]
let mut var_types: HashMap<VarId, ConcreteTypeId> = HashMap::new();
for (ty, type_info) in
arg_types
.iter_mut()
.zip(function.signature.param_types.iter().filter_map(|type_id| {
let type_info = match registry.get_type(type_id) {
Ok(x) => x,
Err(e) => return Some(Err(e.into())),
};
let is_zst = match type_info.is_zst(registry) {
Ok(x) => x,
Err(e) => return Some(Err(e)),
};
if type_info.is_builtin() && is_zst {
None
} else {
Some(Ok(type_info))
}
}))
{
let type_info = type_info?;
if type_info.is_memory_allocated(registry)? {
*ty = llvm::r#type::pointer(context, 0);
}
}
let return_type_infos = function
.signature
.ret_types
.iter()
.filter_map(|type_id| {
let type_info = match registry.get_type(type_id) {
Ok(x) => x,
Err(e) => return Some(Err(e.into())),
};
let is_zst = match type_info.is_zst(registry) {
Ok(x) => x,
Err(e) => return Some(Err(e)),
};
if type_info.is_builtin() && is_zst {
None
} else {
Some(Ok((type_id, type_info)))
}
})
.collect::<Result<Vec<_>, _>>()?;
let has_return_ptr = if return_type_infos.len() > 1 {
Some(false)
} else if return_type_infos
.first()
.map(|(_, type_info)| type_info.is_memory_allocated(registry))
.transpose()?
== Some(true)
{
assert_eq!(return_types.len(), 1);
return_types.remove(0);
arg_types.insert(0, llvm::r#type::pointer(context, 0));
Some(true)
} else {
None
};
let function_name = generate_function_name(&function.id, ignore_debug_names);
let function_name_for_inner = generate_function_name(&function.id, false);
let di_subprogram = unsafe {
let file_attr = Attribute::from_raw(mlirLLVMDIFileAttrGet(
context.to_raw(),
StringAttribute::new(context, "program.sierra").to_raw(),
StringAttribute::new(context, ".").to_raw(),
));
let compile_unit = {
Attribute::from_raw(mlirLLVMDICompileUnitAttrGet(
context.to_raw(),
di_compile_unit_id.to_raw(),
0x0002, file_attr.to_raw(),
StringAttribute::new(context, "cairo-native").to_raw(),
false,
MlirLLVMDIEmissionKind_MlirLLVMDIEmissionKindFull,
MlirLLVMDINameTableKind_MlirLLVMDINameTableKindDefault,
))
};
let di_module = mlirLLVMDIModuleAttrGet(
context.to_raw(),
file_attr.to_raw(),
compile_unit.to_raw(),
StringAttribute::new(context, "LLVMDialectModule").to_raw(),
StringAttribute::new(context, "").to_raw(),
StringAttribute::new(context, "").to_raw(),
StringAttribute::new(context, "").to_raw(),
0,
false,
);
let module_scope = mlirLLVMDIModuleAttrGetScope(di_module);
Attribute::from_raw({
let id = mlirDisctinctAttrCreate(
StringAttribute::new(context, &format!("fn_{}", function.id.id)).to_raw(),
);
let ty = mlirLLVMDISubroutineTypeAttrGet(
context.to_raw(),
0x0, 0,
std::ptr::null(),
);
mlirLLVMDISubprogramAttrGet(
context.to_raw(),
id,
module_scope,
file_attr.to_raw(),
StringAttribute::new(context, &function_name).to_raw(),
StringAttribute::new(context, &function_name).to_raw(),
file_attr.to_raw(),
(sierra_stmt_start_offset + function.entry_point.0) as u32,
(sierra_stmt_start_offset + function.entry_point.0) as u32,
0x8, ty,
)
})
};
tracing::debug!("Generating function structure (region with blocks).");
let (entry_block, blocks, is_recursive) = generate_function_structure(
context,
module,
®ion,
registry,
function,
statements,
metadata,
sierra_stmt_start_offset,
)?;
tracing::debug!("Generating the function implementation.");
let pre_entry_block_args = arg_types
.iter()
.map(|ty| {
(
*ty,
Location::new(
context,
"program.sierra",
sierra_stmt_start_offset + function.entry_point.0,
0,
),
)
})
.collect::<Vec<_>>();
let pre_entry_block =
region.insert_block_before(entry_block, Block::new(&pre_entry_block_args));
let initial_state = {
let mut values = OrderedHashMap::default();
let mut count = 0;
for param in &function.params {
let type_info = registry.get_type(¶m.ty)?;
let location = Location::new(
context,
"program.sierra",
sierra_stmt_start_offset + function.entry_point.0,
0,
);
values.insert(
param.id.clone(),
if type_info.is_builtin() && type_info.is_zst(registry)? {
pre_entry_block
.append_operation(llvm::undef(
type_info.build(context, module, registry, metadata, ¶m.ty)?,
location,
))
.result(0)?
.into()
} else {
let value = entry_block.argument(count)?.into();
count += 1;
value
},
);
#[cfg(feature = "with-trace-dump")]
var_types.insert(param.id.clone(), param.ty.clone());
}
values
};
tracing::trace!("Implementing the entry block.");
entry_block.append_operation(cf::br(
&blocks[&function.entry_point].1,
&match &statements[function.entry_point.0] {
Statement::Invocation(x) => &x.args,
Statement::Return(x) => x,
}
.iter()
.map(|x| initial_state[x])
.collect::<Vec<_>>(),
{
Location::new(
context,
"program.sierra",
sierra_stmt_start_offset + function.entry_point.0,
0,
)
},
));
let mut tailrec_state = Option::<(Value, BlockRef)>::None;
foreach_statement_in_function::<_, Error>(
statements,
function.entry_point,
initial_state,
|statement_idx, mut state| {
if let Some(gas_metadata) = metadata.get::<GasMetadata>() {
let gas_cost = gas_metadata.get_gas_costs_for_statement(statement_idx);
let gas_wallet = gas_metadata.get_gas_wallet(statement_idx);
metadata.remove::<GasCost>();
metadata.remove::<GasWallet>();
metadata.insert(gas_wallet);
metadata.insert(GasCost(gas_cost));
}
let (landing_block, block) = &blocks[&statement_idx];
if let Some((landing_block, _)) = landing_block {
tracing::trace!("Implementing the statement {statement_idx}'s landing block.");
state = state
.keys()
.sorted_by_key(|x| x.id)
.enumerate()
.map(|(idx, var_id)| Ok((var_id.clone(), landing_block.argument(idx)?.into())))
.collect::<Result<_, Error>>()?;
landing_block.append_operation(cf::br(
block,
&state.clone().take_vars(
match &statements[statement_idx.0] {
Statement::Invocation(x) => &x.args,
Statement::Return(x) => x,
}
.iter(),
)?,
Location::name(
context,
&format!("landing_block(stmt_idx={})", statement_idx),
fn_location,
),
));
}
Ok(match &statements[statement_idx.0] {
Statement::Invocation(invocation) => {
tracing::trace!(
"Implementing the invocation statement at {statement_idx}: {}.",
invocation.libfunc_id
);
let location = Location::new(
context,
"program.sierra",
sierra_stmt_start_offset + statement_idx.0,
0,
);
#[cfg(feature = "with-debug-utils")]
{
if let Ok(x) = std::env::var("NATIVE_DEBUG_TRAP_AT_STMT") {
if x.eq_ignore_ascii_case(&statement_idx.0.to_string()) {
block.append_operation(
melior::dialect::ods::llvm::intr_debugtrap(context, location)
.into(),
);
}
}
}
let libfunc_name = if invocation.libfunc_id.debug_name.is_some() {
format!("{}(stmt_idx={})", invocation.libfunc_id, statement_idx)
} else {
let libf = registry.get_libfunc(&invocation.libfunc_id)?;
format!("{}(stmt_idx={})", libfunc_to_name(libf), statement_idx)
};
#[cfg(feature = "with-trace-dump")]
crate::utils::trace_dump::build_state_snapshot(
context,
registry,
module,
block,
location,
metadata,
statement_idx,
&state,
&var_types,
);
state.take_vars(invocation.args.iter())?;
let libfunc = registry.get_libfunc(&invocation.libfunc_id)?;
if is_recursive {
if let Some(target) = libfunc.is_function_call() {
if target == &function.id && state.is_empty() {
let location = Location::name(
context,
&format!("recursion_counter({})", libfunc_name),
location,
);
let op0 = pre_entry_block.insert_operation(
0,
memref::alloca(
context,
MemRefType::new(Type::index(context), &[], None, None),
&[],
&[],
None,
location,
),
);
let op1 = pre_entry_block.insert_operation_after(
op0,
index::constant(
context,
IntegerAttribute::new(Type::index(context), 0),
location,
),
);
pre_entry_block.insert_operation_after(
op1,
memref::store(
op1.result(0)?.into(),
op0.result(0)?.into(),
&[],
location,
),
);
metadata
.insert(TailRecursionMeta::new(
op0.result(0)?.into(),
&entry_block,
))
.to_native_assert_error(
"tail recursion metadata shouldn't be inserted",
)?;
}
}
}
#[allow(unused_mut)]
let mut helper = LibfuncHelper {
module,
init_block: &pre_entry_block,
region: ®ion,
blocks_arena: &blocks_arena,
last_block: Cell::new(block),
branches: generate_branching_targets(
&blocks,
statements,
statement_idx,
invocation,
&state,
),
results: invocation
.branches
.iter()
.map(|x| vec![Cell::new(None); x.results.len()])
.collect::<Vec<_>>(),
#[cfg(feature = "with-libfunc-profiling")]
profiler: match libfunc {
CoreConcreteLibfunc::FunctionCall(_) => {
None
}
_ => match metadata.remove::<crate::metadata::profiler::ProfilerMeta>()
{
Some(profiler_meta) => {
let t0 = profiler_meta
.measure_timestamp(context, block, location)?;
Some((profiler_meta, statement_idx, t0))
}
None => None,
},
},
};
libfunc.build(
context,
registry,
block,
Location::name(context, &libfunc_name, location),
&helper,
metadata,
)?;
if let Some(&mut ref mut stats) = stats {
let mut operations = 0;
walk_mlir_block(*block, *helper.last_block.get(), &mut |_| operations += 1);
let name = libfunc_to_name(libfunc).to_string();
*stats.mlir_operations_by_libfunc.entry(name).or_insert(0) += operations;
}
native_assert!(
block.terminator().is_some(),
"libfunc {} had no terminator",
libfunc_name
);
#[cfg(feature = "with-libfunc-profiling")]
if let Some((profiler_meta, _, _)) = helper.profiler.take() {
metadata.insert(profiler_meta);
}
if let Some(tailrec_meta) = metadata.remove::<TailRecursionMeta>() {
if let Some(return_block) = tailrec_meta.return_target() {
tailrec_state = Some((tailrec_meta.depth_counter(), return_block));
}
}
#[cfg(feature = "with-trace-dump")]
for (branch_signature, branch_info) in
libfunc.branch_signatures().iter().zip(&invocation.branches)
{
for (var_info, var_id) in
branch_signature.vars.iter().zip(&branch_info.results)
{
var_types.insert(var_id.clone(), var_info.ty.clone());
}
}
StatementCompileResult::Processed(
invocation
.branches
.iter()
.zip(helper.results()?)
.map(|(branch_info, result_values)| {
native_assert!(
branch_info.results.len() == result_values.len(),
"Mismatched number of returned values from branch."
);
let mut new_state = state.clone();
new_state.put_vars(
branch_info.results.iter().zip(result_values.into_iter()),
)?;
Ok(new_state)
})
.collect::<Result<_, Error>>()?,
)
}
Statement::Return(var_ids) => {
tracing::trace!("Implementing the return statement at {statement_idx}");
let location = Location::name(
context,
&format!("return(stmt_idx={})", statement_idx),
Location::new(
context,
"program.sierra",
sierra_stmt_start_offset + statement_idx.0,
0,
),
);
#[cfg(feature = "with-trace-dump")]
if !is_recursive || tailrec_state.is_some() {
crate::utils::trace_dump::build_state_snapshot(
context,
registry,
module,
block,
location,
metadata,
statement_idx,
&state,
&var_types,
);
}
let mut values = state.take_vars(var_ids.iter())?;
let mut block = *block;
if is_recursive {
match tailrec_state {
None => {
return Ok(StatementCompileResult::Deferred);
}
Some((depth_counter, recursion_target)) => {
let location = Location::name(
context,
&format!("return(stmt_idx={}, tail_recursion)", statement_idx),
Location::new(
context,
"program.sierra",
sierra_stmt_start_offset + statement_idx.0,
0,
),
);
let cont_block = region.insert_block_after(block, Block::new(&[]));
let depth_counter_value = block.append_op_result(memref::load(
depth_counter,
&[],
location,
))?;
let k0 = block.const_int_from_type(
context,
location,
0,
Type::index(context),
)?;
let is_zero_depth = block.append_op_result(index::cmp(
context,
CmpiPredicate::Eq,
depth_counter_value,
k0,
location,
))?;
let k1 = block.const_int_from_type(
context,
location,
1,
Type::index(context),
)?;
let depth_counter_value = block.append_op_result(index::sub(
depth_counter_value,
k1,
location,
))?;
block.append_operation(memref::store(
depth_counter_value,
depth_counter,
&[],
location,
));
let recursive_values = match has_return_ptr {
Some(true) => function
.signature
.ret_types
.iter()
.zip(&values)
.filter_map(|(type_id, value)| {
let type_info = match registry.get_type(type_id) {
Ok(x) => x,
Err(e) => return Some(Err(e.into())),
};
let is_zst = match type_info.is_zst(registry) {
Ok(x) => x,
Err(e) => return Some(Err(e)),
};
let is_memory_allocated =
match type_info.is_memory_allocated(registry) {
Ok(x) => x,
Err(e) => return Some(Err(e)),
};
if is_zst || is_memory_allocated {
None
} else {
Some(Ok(*value))
}
})
.collect::<Result<Vec<_>, _>>()?,
Some(false) => function
.signature
.ret_types
.iter()
.zip(&values)
.filter_map(|(type_id, value)| {
let type_info = match registry.get_type(type_id) {
Ok(x) => x,
Err(e) => return Some(Err(e.into())),
};
let is_zst = match type_info.is_zst(registry) {
Ok(x) => x,
Err(e) => return Some(Err(e)),
};
if is_zst {
None
} else {
Some(Ok(*value))
}
})
.collect::<Result<Vec<_>, _>>()?,
None => native_panic!("not yet implemented"),
};
block.append_operation(cf::cond_br(
context,
is_zero_depth,
&cont_block,
&recursion_target,
&[],
&recursive_values,
location,
));
block = cont_block;
}
}
}
for (idx, type_id) in function.signature.ret_types.iter().enumerate().rev() {
let type_info = registry.get_type(type_id)?;
if type_info.is_builtin() && type_info.is_zst(registry)? {
values.remove(idx);
}
}
if Some(true) == has_return_ptr {
let (_ret_type_id, ret_type_info) = return_type_infos[0];
let ret_layout = ret_type_info.layout(registry)?;
let ptr = values.remove(0);
block.append_operation(llvm::store(
context,
ptr,
pre_entry_block.arg(0)?,
location,
LoadStoreOptions::new().align(Some(IntegerAttribute::new(
IntegerType::new(context, 64).into(),
ret_layout.align() as i64,
))),
));
}
block.append_operation(llvm::r#return(
Some({
let res_ty = llvm::r#type::r#struct(context, &return_types, false);
values.iter().enumerate().try_fold(
block.append_op_result(llvm::undef(res_ty, location))?,
|acc, (idx, x)| {
block.append_op_result(llvm::insert_value(
context,
acc,
DenseI64ArrayAttribute::new(context, &[idx as i64]),
*x,
location,
))
},
)?
}),
location,
));
StatementCompileResult::Processed(Vec::new())
}
})
},
)?;
{
let mut arg_values = Vec::with_capacity(function.signature.param_types.len());
for (i, type_id_and_info) in function
.signature
.param_types
.iter()
.filter_map(|type_id| {
registry
.get_type(type_id)
.map(|type_info| {
let is_zst = match type_info.is_zst(registry) {
Ok(x) => x,
Err(e) => return Some(Err(e)),
};
if type_info.is_builtin() && is_zst {
None
} else {
Some(Ok((type_id, type_info)))
}
})
.map_err(Error::from)
.transpose()
.map(|x| match x {
Ok(Ok(x)) => Ok(x),
Ok(Err(e)) | Err(e) => Err(e),
})
})
.enumerate()
{
let (type_id, type_info) = type_id_and_info?;
let mut value = pre_entry_block
.argument((has_return_ptr == Some(true)) as usize + i)?
.into();
if type_info.is_memory_allocated(registry)? {
value = pre_entry_block
.append_operation(llvm::load(
context,
value,
type_info.build(context, module, registry, metadata, type_id)?,
fn_location,
LoadStoreOptions::new().align(Some(IntegerAttribute::new(
IntegerType::new(context, 64).into(),
type_info.layout(registry)?.align() as i64,
))),
))
.result(0)?
.into();
}
arg_values.push(value);
}
pre_entry_block.append_operation(cf::br(&entry_block, &arg_values, fn_location));
}
let inner_function_name = format!("impl${function_name_for_inner}");
module.body().append_operation(llvm::func(
context,
StringAttribute::new(context, &inner_function_name),
TypeAttribute::new(llvm::r#type::function(
llvm::r#type::r#struct(context, &return_types, false),
&arg_types,
false,
)),
region,
&[
(
Identifier::new(context, "sym_visibility"),
StringAttribute::new(context, "private").into(),
),
(
Identifier::new(context, "linkage"),
Attribute::parse(context, "#llvm.linkage<private>")
.ok_or(Error::ParseAttributeError)?,
),
(
Identifier::new(context, "CConv"),
Attribute::parse(context, "#llvm.cconv<fastcc>")
.ok_or(Error::ParseAttributeError)?,
),
],
Location::fused(
context,
&[Location::new(
context,
"program.sierra",
sierra_stmt_start_offset + function.entry_point.0,
0,
)],
di_subprogram,
),
));
generate_entry_point_wrapper(
context,
module,
function_name.as_ref(),
&inner_function_name,
&pre_entry_block_args,
&return_types,
Location::new(
context,
"program.sierra",
sierra_stmt_start_offset + function.entry_point.0,
0,
),
)?;
tracing::debug!("Done generating function {}.", function.id);
Ok(())
}
#[allow(clippy::too_many_arguments)]
fn generate_function_structure<'c, 'a>(
context: &'c Context,
module: &'a Module<'c>,
region: &'a Region<'c>,
registry: &ProgramRegistry<CoreType, CoreLibfunc>,
function: &Function,
statements: &[Statement],
metadata_storage: &mut MetadataStorage,
sierra_stmt_start_offset: usize,
) -> Result<(BlockRef<'c, 'a>, BlockStorage<'c, 'a>, bool), Error> {
let initial_state: OrderedHashMap<_, Type> = function
.params
.iter()
.zip(&function.signature.param_types)
.map(|(param, ty)| {
let type_info = registry.get_type(ty)?;
Ok((
param.id.clone(),
type_info.build(context, module, registry, metadata_storage, ty)?,
))
})
.collect::<Result<_, Error>>()?;
let mut blocks = BTreeMap::new();
let mut predecessors = HashMap::from([(function.entry_point, (initial_state.clone(), 0))]);
let mut num_tail_recursions = 0usize;
foreach_statement_in_function::<_, Error>(
statements,
function.entry_point,
initial_state,
|statement_idx, state| {
let block = {
if let std::collections::btree_map::Entry::Vacant(e) = blocks.entry(statement_idx.0)
{
e.insert(Block::new(&[]));
blocks
.get_mut(&statement_idx.0)
.to_native_assert_error("block should exist")?
} else {
native_panic!("statement index already present in block")
}
};
Ok(match &statements[statement_idx.0] {
Statement::Invocation(invocation) => {
tracing::trace!(
"Creating block for invocation statement at index {statement_idx}: {}",
invocation.libfunc_id
);
let (state, types) = {
let mut state = state.clone();
let types = state.take_vars(invocation.args.iter())?;
(state, types)
};
let location = Location::new(
context,
"program.sierra",
sierra_stmt_start_offset + statement_idx.0,
0,
);
for ty in types {
block.add_argument(ty, location);
}
let libfunc = registry.get_libfunc(&invocation.libfunc_id)?;
if let CoreConcreteLibfunc::FunctionCall(info) = libfunc {
if info.function.id == function.id && state.is_empty() {
num_tail_recursions += 1;
}
}
StatementCompileResult::Processed(
invocation
.branches
.iter()
.zip(libfunc.branch_signatures())
.map(|(branch, branch_signature)| {
let mut new_state = state.clone();
new_state.put_vars(
branch.results.iter().zip(
branch_signature
.vars
.iter()
.map(|var_info| -> Result<_, Error> {
registry.get_type(&var_info.ty)?.build(
context,
module,
registry,
metadata_storage,
&var_info.ty,
)
})
.collect::<Result<Vec<_>, _>>()?,
),
)?;
let (prev_state, pred_count) = match predecessors
.entry(statement_idx.next(branch.target))
{
Entry::Occupied(entry) => entry.into_mut(),
Entry::Vacant(entry) => entry.insert((new_state.clone(), 0)),
};
native_assert!(
prev_state.eq_unordered(&new_state),
"Branch target states do not match."
);
*pred_count += 1;
Ok(new_state)
})
.collect::<Result<_, Error>>()?,
)
}
Statement::Return(var_ids) => {
tracing::trace!(
"Creating block for return statement at index {statement_idx}."
);
let (state, types) = {
let mut state = state.clone();
let types = state.take_vars(var_ids.iter())?;
(state, types)
};
native_assert!(
state.is_empty(),
"State must be empty after a return statement."
);
let location = Location::new(
context,
"program.sierra",
sierra_stmt_start_offset + statement_idx.0,
0,
);
for ty in types {
block.add_argument(ty, location);
}
StatementCompileResult::Processed(Vec::new())
}
})
},
)?;
tracing::trace!("Generating function entry block.");
let entry_block = region.append_block(Block::new(&{
extract_types(
context,
module,
&function.signature.param_types,
registry,
metadata_storage,
)
.map(|ty| {
Ok((
ty?,
Location::new(
context,
"program.sierra",
sierra_stmt_start_offset + function.entry_point.0,
0,
),
))
})
.collect::<Result<Vec<_>, Error>>()?
}));
let blocks = blocks
.into_iter()
.map(|(i, block)| {
let statement_idx = StatementIdx(i);
tracing::trace!("Inserting block for statement at index {statement_idx}.");
let libfunc_block = region.append_block(block);
let landing_block = (predecessors[&statement_idx].1 > 1).then(|| {
tracing::trace!(
"Generating a landing block for the statement at index {statement_idx}."
);
(
region.insert_block_before(
libfunc_block,
Block::new(
&predecessors[&statement_idx]
.0
.iter()
.map(|(var_id, ty)| (var_id.id, *ty))
.collect::<BTreeMap<_, _>>()
.into_values()
.map(|ty| {
(
ty,
Location::new(
context,
"program.sierra",
sierra_stmt_start_offset + statement_idx.0,
0,
),
)
})
.collect::<Vec<_>>(),
),
),
predecessors[&statement_idx]
.0
.clone()
.into_iter()
.sorted_by_key(|(k, _)| k.id)
.collect::<Vec<_>>(),
)
});
(statement_idx, (landing_block, libfunc_block))
})
.collect::<HashMap<_, _>>();
Ok((
entry_block,
blocks
.into_iter()
.map(|(k, v)| {
(
k,
(
v.0.map(|x| (x.0, x.1.into_iter().map(|x| x.0).collect::<Vec<_>>())),
v.1,
),
)
})
.collect(),
num_tail_recursions == 1,
))
}
/// Lazily builds the MLIR type of each id in `type_ids`, preserving their order.
///
/// Builtin zero-sized types are skipped entirely, since they do not appear in generated
/// signatures. Each item is `Err` if the type cannot be resolved, queried or built.
fn extract_types<'c: 'a, 'a>(
    context: &'c Context,
    module: &'a Module<'c>,
    type_ids: &'a [ConcreteTypeId],
    registry: &'a ProgramRegistry<CoreType, CoreLibfunc>,
    metadata_storage: &'a mut MetadataStorage,
) -> impl 'a + Iterator<Item = Result<Type<'c>, Error>> {
    type_ids.iter().filter_map(|type_id| {
        let info = match registry.get_type(type_id) {
            Ok(info) => info,
            Err(err) => return Some(Err(err.into())),
        };

        let skip = match info.is_zst(registry) {
            Ok(zero_sized) => zero_sized && info.is_builtin(),
            Err(err) => return Some(Err(err)),
        };
        if skip {
            return None;
        }

        Some(info.build(context, module, registry, metadata_storage, type_id))
    })
}
/// Walks every statement reachable from `entry_point`, calling `closure` once per statement.
///
/// `closure` receives the statement index and the state flowing into it. It returns one of:
/// - `StatementCompileResult::Processed`, with one state per outgoing branch in branch order (a
///   return statement has none). The branch targets are then queued with those states.
/// - `StatementCompileResult::Deferred`, to be retried after every other pending statement.
///
/// The most recently queued target is visited next, so traversal is depth-first. A statement is
/// processed at most once: later queue entries for an already processed statement are ignored,
/// so the closure only ever sees the state of the first entry popped for each statement.
///
/// Whether the closure defers a statement must only depend on what it has processed so far.
/// If every pending statement keeps being deferred, no progress is possible.
///
/// # Errors
///
/// Propagates the first error returned by `closure`.
///
/// # Panics
///
/// Panics in two cases:
/// - the number of returned branch states differs from the statement's branch count;
/// - no progress is possible because some pending statement was deferred twice with no statement
///   processed in between. Previously this case looped forever.
fn foreach_statement_in_function<S, E>(
    statements: &[Statement],
    entry_point: StatementIdx,
    initial_state: S,
    mut closure: impl FnMut(StatementIdx, S) -> Result<StatementCompileResult<Vec<S>>, E>,
) -> Result<(), E>
where
    S: Clone,
{
    // Used as a stack (push/pop at the back). Deferred statements go to the front so they are
    // retried only after everything else pending; a `VecDeque` makes that O(1).
    let mut queue = std::collections::VecDeque::from([(entry_point, initial_state)]);
    let mut visited = HashSet::new();
    // Deferrals since a statement was last processed. Deferred entries stay queued and nothing
    // new is queued without processing, so once this exceeds the queue length some entry has
    // been deferred twice in a row with no progress in between.
    let mut deferrals_without_progress = 0usize;

    while let Some((statement_idx, state)) = queue.pop_back() {
        if !visited.insert(statement_idx) {
            continue;
        }

        match closure(statement_idx, state.clone())? {
            StatementCompileResult::Processed(branch_states) => {
                deferrals_without_progress = 0;

                let branches = match &statements[statement_idx.0] {
                    Statement::Invocation(x) => x.branches.as_slice(),
                    Statement::Return(_) => &[],
                };
                assert_eq!(
                    branches.len(),
                    branch_states.len(),
                    "Returned number of states must match the number of branches."
                );

                queue.extend(
                    branches
                        .iter()
                        .map(|branch| statement_idx.next(branch.target))
                        .zip(branch_states),
                );
            }
            StatementCompileResult::Deferred => {
                tracing::trace!("Statement {statement_idx}'s compilation has been deferred.");

                visited.remove(&statement_idx);
                queue.push_front((statement_idx, state));

                deferrals_without_progress += 1;
                assert!(
                    deferrals_without_progress <= queue.len(),
                    "Statement compilation made no progress: every pending statement was deferred."
                );
            }
        }
    }

    Ok(())
}
/// Computes, for every branch of `invocation` in branch order, the block to jump to and where
/// each jump argument comes from.
///
/// The destination depends on the target statement:
/// - If it has a landing block, the jump goes there and carries every variable listed for that
///   landing block, in the same order.
/// - Otherwise the jump goes straight to the statement block and carries only the variables the
///   target statement consumes.
///
/// Each argument is either the `i`-th result of the branch (`BranchArg::Returned(i)`) or a value
/// already present in `state` (`BranchArg::External`).
///
/// # Panics
///
/// Panics (through indexing) if a target statement has no entry in `blocks`, or if a required
/// variable is neither a branch result nor present in `state`.
fn generate_branching_targets<'ctx, 'this, 'a>(
    blocks: &'this BlockStorage<'ctx, 'this>,
    statements: &'this [Statement],
    statement_idx: StatementIdx,
    invocation: &'this Invocation,
    state: &OrderedHashMap<VarId, Value<'ctx, 'this>>,
) -> Vec<(&'this Block<'ctx>, Vec<BranchArg<'ctx, 'this>>)>
where
    'this: 'ctx,
{
    let mut targets = Vec::with_capacity(invocation.branches.len());

    for branch in &invocation.branches {
        let target_idx = statement_idx.next(branch.target);
        let (landing, statement_block) = &blocks[&target_idx];

        // A variable is either produced by this branch or already live before the invocation.
        let resolve = |var_id: &VarId| match branch.results.iter().position(|id| id == var_id) {
            Some(i) => BranchArg::Returned(i),
            None => BranchArg::External(state[var_id]),
        };

        let target = match landing {
            Some((landing_block, state_vars)) => (
                landing_block.deref(),
                state_vars.iter().map(resolve).collect::<Vec<_>>(),
            ),
            None => {
                let consumed = match &statements[target_idx.0] {
                    Statement::Invocation(x) => &x.args,
                    Statement::Return(x) => x,
                };
                (
                    statement_block.deref(),
                    consumed.iter().map(resolve).collect::<Vec<_>>(),
                )
            }
        };

        targets.push(target);
    }

    targets
}
#[allow(clippy::too_many_arguments)]
fn generate_entry_point_wrapper<'c>(
context: &'c Context,
module: &Module<'c>,
public_symbol: &str,
private_symbol: &str,
arg_types: &[(Type<'c>, Location<'c>)],
ret_types: &[Type<'c>],
location: Location<'c>,
) -> Result<(), Error> {
let region = Region::new();
let block = region.append_block(Block::new(arg_types));
let mut args = Vec::with_capacity(arg_types.len());
for i in 0..arg_types.len() {
args.push(block.argument(i)?.into());
}
let result = block.append_op_result(
OperationBuilder::new("llvm.call", location)
.add_attributes(&[
(
Identifier::new(context, "callee"),
FlatSymbolRefAttribute::new(context, private_symbol).into(),
),
(
Identifier::new(context, "CConv"),
Attribute::parse(context, "#llvm.cconv<fastcc>")
.ok_or(Error::ParseAttributeError)?,
),
])
.add_operands(&args)
.add_results(&[llvm::r#type::r#struct(context, ret_types, false)])
.build()?,
)?;
let mut returns = Vec::with_capacity(ret_types.len());
for (i, ty) in ret_types.iter().enumerate() {
returns.push(block.extract_value(context, location, result, *ty, i)?);
}
block.append_operation(func::r#return(&returns, location));
module.body().append_operation(func::func(
context,
StringAttribute::new(context, public_symbol),
TypeAttribute::new(
FunctionType::new(
context,
&arg_types.iter().map(|x| x.0).collect::<Vec<_>>(),
ret_types,
)
.into(),
),
region,
&[
(
Identifier::new(context, "sym_visibility"),
StringAttribute::new(context, "public").into(),
),
(
Identifier::new(context, "llvm.linkage"),
Attribute::parse(context, "#llvm.linkage<private>")
.ok_or(Error::ParseAttributeError)?,
),
(
Identifier::new(context, "llvm.CConv"),
Attribute::parse(context, "#llvm.cconv<fastcc>")
.ok_or(Error::ParseAttributeError)?,
),
(
Identifier::new(context, "llvm.emit_c_interface"),
Attribute::unit(context),
),
],
location,
));
Ok(())
}
/// Outcome of visiting one statement in `foreach_statement_in_function`.
#[derive(Clone, Debug)]
enum StatementCompileResult<Branches> {
    /// The statement was handled; carries the state of each outgoing branch, in branch order.
    Processed(Branches),
    /// The statement cannot be handled yet and must be retried after the other pending ones.
    Deferred,
}