use crate::{
error::{panic::ToNativeAssertError, Error as CoreLibfuncBuilderError, Result},
metadata::MetadataStorage,
native_panic,
types::TypeBuilder,
utils::ProgramRegistryExt,
};
use bumpalo::Bump;
use cairo_lang_sierra::{
extensions::{
core::{CoreConcreteLibfunc, CoreLibfunc, CoreType, CoreTypeConcrete},
int::{
signed::{Sint16Traits, Sint32Traits, Sint64Traits, Sint8Traits},
unsigned::{Uint16Traits, Uint32Traits, Uint64Traits, Uint8Traits},
},
lib_func::{BranchSignature, ParamSignature},
starknet::StarknetTypeConcrete,
ConcreteLibfunc,
},
ids::FunctionId,
program_registry::ProgramRegistry,
};
use itertools::Itertools;
use melior::{
dialect::{arith, cf, llvm, ods},
helpers::{ArithBlockExt, BuiltinBlockExt, LlvmBlockExt},
ir::{
attribute::{FlatSymbolRefAttribute, StringAttribute, TypeAttribute},
operation::OperationBuilder,
r#type::IntegerType,
Attribute, Block, BlockLike, BlockRef, Location, Module, Region, Value,
},
Context,
};
use num_bigint::BigInt;
use std::{
cell::Cell,
error::Error,
ops::Deref,
sync::atomic::{AtomicBool, Ordering},
};
mod array;
mod r#bool;
mod bounded_int;
mod r#box;
mod bytes31;
mod cast;
mod circuit;
mod r#const;
mod coupon;
mod debug;
mod drop;
mod dup;
mod ec;
mod r#enum;
mod felt252;
mod felt252_dict;
mod felt252_dict_entry;
mod function_call;
mod gas;
mod int;
mod int_range;
mod mem;
mod nullable;
mod pedersen;
mod poseidon;
mod starknet;
mod r#struct;
mod uint256;
mod uint512;
pub trait LibfuncBuilder {
type Error: Error;
fn build<'ctx, 'this>(
&self,
context: &'ctx Context,
registry: &ProgramRegistry<CoreType, CoreLibfunc>,
entry: &'this Block<'ctx>,
location: Location<'ctx>,
helper: &LibfuncHelper<'ctx, 'this>,
metadata: &mut MetadataStorage,
) -> Result<()>;
fn is_function_call(&self) -> Option<&FunctionId>;
}
impl LibfuncBuilder for CoreConcreteLibfunc {
type Error = CoreLibfuncBuilderError;
fn build<'ctx, 'this>(
&self,
context: &'ctx Context,
registry: &ProgramRegistry<CoreType, CoreLibfunc>,
entry: &'this Block<'ctx>,
location: Location<'ctx>,
helper: &LibfuncHelper<'ctx, 'this>,
metadata: &mut MetadataStorage,
) -> Result<()> {
match self {
Self::ApTracking(_) | Self::BranchAlign(_) | Self::UnconditionalJump(_) => {
build_noop::<0, false>(
context,
registry,
entry,
location,
helper,
metadata,
self.param_signatures(),
)
}
Self::Array(selector) => self::array::build(
context, registry, entry, location, helper, metadata, selector,
),
Self::Bool(selector) => self::r#bool::build(
context, registry, entry, location, helper, metadata, selector,
),
Self::BoundedInt(info) => {
self::bounded_int::build(context, registry, entry, location, helper, metadata, info)
}
Self::Box(selector) => self::r#box::build(
context, registry, entry, location, helper, metadata, selector,
),
Self::Bytes31(selector) => self::bytes31::build(
context, registry, entry, location, helper, metadata, selector,
),
Self::Cast(selector) => self::cast::build(
context, registry, entry, location, helper, metadata, selector,
),
Self::Circuit(info) => {
self::circuit::build(context, registry, entry, location, helper, metadata, info)
}
Self::Const(selector) => self::r#const::build(
context, registry, entry, location, helper, metadata, selector,
),
Self::Coupon(selector) => self::coupon::build(
context, registry, entry, location, helper, metadata, selector,
),
Self::CouponCall(info) => self::function_call::build(
context, registry, entry, location, helper, metadata, info,
),
Self::Debug(selector) => self::debug::build(
context, registry, entry, location, helper, metadata, selector,
),
Self::Trace(_) => native_panic!("Implement trace libfunc"),
Self::Drop(info) => {
self::drop::build(context, registry, entry, location, helper, metadata, info)
}
Self::Dup(info) | Self::SnapshotTake(info) => {
self::dup::build(context, registry, entry, location, helper, metadata, info)
}
Self::Ec(selector) => self::ec::build(
context, registry, entry, location, helper, metadata, selector,
),
Self::Enum(selector) => self::r#enum::build(
context, registry, entry, location, helper, metadata, selector,
),
Self::Felt252(selector) => self::felt252::build(
context, registry, entry, location, helper, metadata, selector,
),
Self::Felt252Dict(selector) => self::felt252_dict::build(
context, registry, entry, location, helper, metadata, selector,
),
Self::Felt252SquashedDict(_) => {
native_panic!("Implement felt252_squashed_dict libfunc")
}
Self::Felt252DictEntry(selector) => self::felt252_dict_entry::build(
context, registry, entry, location, helper, metadata, selector,
),
Self::FunctionCall(info) => self::function_call::build(
context, registry, entry, location, helper, metadata, info,
),
Self::Gas(selector) => self::gas::build(
context, registry, entry, location, helper, metadata, selector,
),
Self::IntRange(selector) => self::int_range::build(
context, registry, entry, location, helper, metadata, selector,
),
Self::Blake(_) => native_panic!("Implement blake libfunc"),
Self::Mem(selector) => self::mem::build(
context, registry, entry, location, helper, metadata, selector,
),
Self::Nullable(selector) => self::nullable::build(
context, registry, entry, location, helper, metadata, selector,
),
Self::Pedersen(selector) => self::pedersen::build(
context, registry, entry, location, helper, metadata, selector,
),
Self::Poseidon(selector) => self::poseidon::build(
context, registry, entry, location, helper, metadata, selector,
),
Self::Sint8(selector) => self::int::build_signed::<Sint8Traits>(
context, registry, entry, location, helper, metadata, selector,
),
Self::Sint16(selector) => self::int::build_signed::<Sint16Traits>(
context, registry, entry, location, helper, metadata, selector,
),
Self::Sint32(selector) => self::int::build_signed::<Sint32Traits>(
context, registry, entry, location, helper, metadata, selector,
),
Self::Sint64(selector) => self::int::build_signed::<Sint64Traits>(
context, registry, entry, location, helper, metadata, selector,
),
Self::Sint128(selector) => self::int::build_i128(
context, registry, entry, location, helper, metadata, selector,
),
Self::Starknet(selector) => self::starknet::build(
context, registry, entry, location, helper, metadata, selector,
),
Self::Struct(selector) => self::r#struct::build(
context, registry, entry, location, helper, metadata, selector,
),
Self::Uint8(selector) => self::int::build_unsigned::<Uint8Traits>(
context, registry, entry, location, helper, metadata, selector,
),
Self::Uint16(selector) => self::int::build_unsigned::<Uint16Traits>(
context, registry, entry, location, helper, metadata, selector,
),
Self::Uint32(selector) => self::int::build_unsigned::<Uint32Traits>(
context, registry, entry, location, helper, metadata, selector,
),
Self::Uint64(selector) => self::int::build_unsigned::<Uint64Traits>(
context, registry, entry, location, helper, metadata, selector,
),
Self::Uint128(selector) => self::int::build_u128(
context, registry, entry, location, helper, metadata, selector,
),
Self::Uint256(selector) => self::uint256::build(
context, registry, entry, location, helper, metadata, selector,
),
Self::Uint512(selector) => self::uint512::build(
context, registry, entry, location, helper, metadata, selector,
),
Self::UnwrapNonZero(info) => build_noop::<1, false>(
context,
registry,
entry,
location,
helper,
metadata,
&info.signature.param_signatures,
),
Self::QM31(_) => native_panic!("Implement QM31 libfunc"),
Self::UnsafePanic(_) => native_panic!("Implement unsafe_panic libfunc"),
Self::GasReserve(_) => native_panic!("Implement gas_reserve libfunc"),
Self::DummyFunctionCall(_) => native_panic!("Implement dummy_function_call libfunc"),
}
}
fn is_function_call(&self) -> Option<&FunctionId> {
match self {
CoreConcreteLibfunc::FunctionCall(info) => Some(&info.function.id),
CoreConcreteLibfunc::CouponCall(info) => Some(&info.function.id),
_ => None,
}
}
}
pub struct LibfuncHelper<'ctx, 'this>
where
'this: 'ctx,
{
pub module: &'this Module<'ctx>,
pub init_block: &'this BlockRef<'ctx, 'this>,
pub region: &'this Region<'ctx>,
pub blocks_arena: &'this Bump,
pub last_block: Cell<&'this BlockRef<'ctx, 'this>>,
pub branches: Vec<(&'this Block<'ctx>, Vec<BranchArg<'ctx, 'this>>)>,
pub results: Vec<Vec<Cell<Option<Value<'ctx, 'this>>>>>,
#[cfg(feature = "with-libfunc-profiling")]
pub profiler: Option<(
crate::metadata::profiler::ProfilerMeta,
cairo_lang_sierra::program::StatementIdx,
(Value<'ctx, 'this>, Value<'ctx, 'this>),
)>,
}
impl<'ctx, 'this> LibfuncHelper<'ctx, 'this>
where
'this: 'ctx,
{
#[doc(hidden)]
pub(crate) fn results(self) -> Result<Vec<Vec<Value<'ctx, 'this>>>> {
self.results
.into_iter()
.enumerate()
.map(|(branch_idx, x)| {
x.into_iter()
.enumerate()
.map(|(arg_idx, x)| {
x.into_inner().to_native_assert_error(&format!(
"Argument #{arg_idx} of branch {branch_idx} doesn't have a value."
))
})
.collect()
})
.collect()
}
pub fn init_block(&self) -> &'this Block<'ctx> {
self.init_block
}
pub fn append_block(&self, block: Block<'ctx>) -> &'this Block<'ctx> {
let block = self
.region
.insert_block_after(*self.last_block.get(), block);
let block_ref: &'this mut BlockRef<'ctx, 'this> = self.blocks_arena.alloc(block);
self.last_block.set(block_ref);
block_ref
}
fn br(
&self,
block: &'this Block<'ctx>,
branch: usize,
results: &[Value<'ctx, 'this>],
location: Location<'ctx>,
) -> Result<()> {
let (successor, operands) = &self.branches[branch];
for (dst, src) in self.results[branch].iter().zip(results) {
dst.replace(Some(*src));
}
let destination_operands = operands
.iter()
.copied()
.map(|op| match op {
BranchArg::External(x) => x,
BranchArg::Returned(i) => results[i],
})
.collect::<Vec<_>>();
#[cfg(feature = "with-libfunc-profiling")]
self.push_profiler_frame(
unsafe { self.context().to_ref() },
self.module,
block,
location,
)?;
block.append_operation(cf::br(successor, &destination_operands, location));
Ok(())
}
fn cond_br(
&self,
context: &'ctx Context,
block: &'this Block<'ctx>,
condition: Value<'ctx, 'this>,
branches: [usize; 2],
results: [&[Value<'ctx, 'this>]; 2],
location: Location<'ctx>,
) -> Result<()> {
let (block_true, args_true) = {
let (successor, operands) = &self.branches[branches[0]];
for (dst, src) in self.results[branches[0]].iter().zip(results[0]) {
dst.replace(Some(*src));
}
let destination_operands = operands
.iter()
.copied()
.map(|op| match op {
BranchArg::External(x) => x,
BranchArg::Returned(i) => results[0][i],
})
.collect::<Vec<_>>();
(*successor, destination_operands)
};
let (block_false, args_false) = {
let (successor, operands) = &self.branches[branches[1]];
for (dst, src) in self.results[branches[1]].iter().zip(results[1]) {
dst.replace(Some(*src));
}
let destination_operands = operands
.iter()
.copied()
.map(|op| match op {
BranchArg::External(x) => x,
BranchArg::Returned(i) => results[1][i],
})
.collect::<Vec<_>>();
(*successor, destination_operands)
};
#[cfg(feature = "with-libfunc-profiling")]
self.push_profiler_frame(context, self.module, block, location)?;
block.append_operation(cf::cond_br(
context,
condition,
block_true,
block_false,
&args_true,
&args_false,
location,
));
Ok(())
}
#[cfg(feature = "with-libfunc-profiling")]
fn push_profiler_frame(
&self,
context: &'ctx Context,
module: &'this Module,
block: &'this Block<'ctx>,
location: Location<'ctx>,
) -> Result<()> {
if let Some((profiler_meta, statement_idx, t0)) = self.profiler.as_ref() {
let t0 = *t0;
let t1 = profiler_meta.measure_timestamp(context, block, location)?;
profiler_meta.push_frame(context, module, block, statement_idx.0, t0, t1, location)?;
}
Ok(())
}
}
impl<'ctx> Deref for LibfuncHelper<'ctx, '_> {
type Target = Module<'ctx>;
fn deref(&self) -> &Self::Target {
self.module
}
}
#[derive(Clone, Copy, Debug)]
pub enum BranchArg<'ctx, 'this> {
External(Value<'ctx, 'this>),
Returned(usize),
}
fn increment_builtin_counter<'ctx: 'a, 'a>(
context: &'ctx Context,
block: &'ctx Block<'ctx>,
location: Location<'ctx>,
value: Value<'ctx, '_>,
) -> crate::error::Result<Value<'ctx, 'a>> {
increment_builtin_counter_by(context, block, location, value, 1)
}
fn increment_builtin_counter_by<'ctx: 'a, 'a>(
context: &'ctx Context,
block: &'ctx Block<'ctx>,
location: Location<'ctx>,
value: Value<'ctx, '_>,
amount: impl Into<BigInt>,
) -> crate::error::Result<Value<'ctx, 'a>> {
Ok(block.append_op_result(arith::addi(
value,
block.const_int(context, location, amount.into(), 64)?,
location,
))?)
}
fn increment_builtin_counter_conditionally_by<'ctx: 'a, 'a>(
context: &'ctx Context,
block: &'ctx Block<'ctx>,
location: Location<'ctx>,
value_to_inc: Value<'ctx, '_>,
true_amount: impl Into<BigInt>,
false_amount: impl Into<BigInt>,
condition: Value<'ctx, '_>,
) -> crate::error::Result<Value<'ctx, 'a>> {
let true_amount_value = block.const_int(context, location, true_amount.into(), 64)?;
let false_amount_value = block.const_int(context, location, false_amount.into(), 64)?;
let true_incremented =
block.append_op_result(arith::addi(value_to_inc, true_amount_value, location))?;
let false_incremented =
block.append_op_result(arith::addi(value_to_inc, false_amount_value, location))?;
Ok(block.append_op_result(arith::select(
condition,
true_incremented,
false_incremented,
location,
))?)
}
fn build_noop<'ctx, 'this, const N: usize, const PROCESS_BUILTINS: bool>(
context: &'ctx Context,
registry: &ProgramRegistry<CoreType, CoreLibfunc>,
entry: &'this Block<'ctx>,
location: Location<'ctx>,
helper: &LibfuncHelper<'ctx, 'this>,
_metadata: &mut MetadataStorage,
param_signatures: &[ParamSignature],
) -> Result<()> {
let mut params = Vec::with_capacity(N);
#[allow(clippy::needless_range_loop)]
for i in 0..N {
let param_ty = registry.get_type(¶m_signatures[i].ty)?;
let mut param_val = entry.argument(i)?.into();
if PROCESS_BUILTINS
&& param_ty.is_builtin()
&& !matches!(
param_ty,
CoreTypeConcrete::BuiltinCosts(_)
| CoreTypeConcrete::Coupon(_)
| CoreTypeConcrete::GasBuiltin(_)
| CoreTypeConcrete::Starknet(StarknetTypeConcrete::System(_))
)
{
param_val = increment_builtin_counter(context, entry, location, param_val)?;
}
params.push(param_val);
}
helper.br(entry, 0, ¶ms, location)
}
#[allow(dead_code)]
pub fn build_mock_libfunc<'ctx, 'this>(
context: &'ctx Context,
registry: &ProgramRegistry<CoreType, CoreLibfunc>,
entry: &'this Block<'ctx>,
location: Location<'ctx>,
helper: &LibfuncHelper<'ctx, 'this>,
metadata: &mut MetadataStorage,
branch_signatures: &[BranchSignature],
) -> Result<()> {
let mut args = Vec::new();
for arg_idx in 0..entry.argument_count() {
args.push(entry.arg(arg_idx)?);
}
let flag_type = IntegerType::new(context, 8).into();
let ptr_type = llvm::r#type::pointer(context, 0);
let result_type = llvm::r#type::r#struct(context, &[flag_type, ptr_type], false);
let result_ptr = build_mock_runtime_call(context, helper, entry, &args, location)?;
let result = entry.load(context, location, result_ptr, result_type)?;
let flag = entry.extract_value(context, location, result, flag_type, 0)?;
let payload_ptr = entry.extract_value(context, location, result, ptr_type, 1)?;
let branches_idxs = (0..branch_signatures.len()).collect_vec();
let default_block = {
let block = helper.append_block(Block::new(&[]));
block.append_operation(llvm::unreachable(location));
block
};
let mut destinations = Vec::new();
for &branch_idx in &branches_idxs {
let block = helper.append_block(Block::new(&[]));
let mut branch_types = Vec::new();
for branch_var in &branch_signatures[branch_idx].vars {
let branch_var_type = registry.build_type(context, helper, metadata, &branch_var.ty)?;
branch_types.push(branch_var_type);
}
let branch_type = llvm::r#type::r#struct(
context,
&(0..branch_types.len()).map(|_| ptr_type).collect_vec(),
false,
);
let branch_result = block.load(context, location, payload_ptr, branch_type)?;
let mut branch_results = Vec::new();
for (var_idx, var_type) in branch_types.iter().enumerate() {
let var_ptr =
block.extract_value(context, location, branch_result, ptr_type, var_idx)?;
let var = block.load(context, location, var_ptr, *var_type)?;
branch_results.push(var);
}
helper.br(block, branch_idx, &branch_results, location)?;
let operands: &[Value] = &[];
destinations.push((block, operands));
}
entry.append_operation(cf::switch(
context,
&branches_idxs.iter().map(|&x| x as i64).collect_vec(),
flag,
flag_type,
(default_block, &[]),
&destinations[..],
location,
)?);
Ok(())
}
#[allow(dead_code)]
pub fn build_mock_runtime_call<'c, 'a>(
context: &'c Context,
module: &Module,
block: &'a Block<'c>,
args: &[Value<'c, 'a>],
location: Location<'c>,
) -> Result<Value<'c, 'a>> {
let ptr_type = llvm::r#type::pointer(context, 0);
static MOCK_RUNTIME_SYMBOL_DECLARED: AtomicBool = AtomicBool::new(false);
if !MOCK_RUNTIME_SYMBOL_DECLARED.swap(true, Ordering::Relaxed) {
module.body().append_operation(
ods::llvm::mlir_global(
context,
Region::new(),
TypeAttribute::new(ptr_type),
StringAttribute::new(context, "cairo_native__mock"),
Attribute::parse(context, "#llvm.linkage<weak>")
.ok_or(CoreLibfuncBuilderError::ParseAttributeError)?,
location,
)
.into(),
);
}
let function_ptr_ptr = block.append_op_result(
ods::llvm::mlir_addressof(
context,
ptr_type,
FlatSymbolRefAttribute::new(context, "cairo_native__mock"),
location,
)
.into(),
)?;
let function_ptr = block.load(context, location, function_ptr_ptr, ptr_type)?;
let result = block.append_op_result(
OperationBuilder::new("llvm.call", location)
.add_operands(&[function_ptr])
.add_operands(args)
.add_results(&[llvm::r#type::pointer(context, 0)])
.build()?,
)?;
Ok(result)
}