use crate::{
error::{Error, Result},
libfuncs::LibfuncHelper,
utils::get_integer_layout,
};
use cairo_lang_sierra::extensions::qm31::QM31BinaryOperator;
use itertools::Itertools;
use melior::{
dialect::{
arith::{self, CmpiPredicate},
cf, llvm, ods,
},
helpers::{ArithBlockExt, BuiltinBlockExt, LlvmBlockExt},
ir::{
attribute::{FlatSymbolRefAttribute, StringAttribute, TypeAttribute},
operation::OperationBuilder,
r#type::IntegerType,
Attribute, Block, BlockLike, Identifier, Location, Module, OperationRef, Region, Type,
Value,
},
Context,
};
use std::{
alloc::Layout,
collections::HashSet,
ffi::{c_int, c_void},
};
/// A function that generated code calls at runtime.
///
/// Most variants are reached through a pointer-typed global with weak linkage whose name is
/// `RuntimeBinding::symbol`; `setup_runtime` writes `RuntimeBinding::function_ptr` into it.
/// The extended Euclidean algorithm variants and `CircuitArithOperation` are instead emitted
/// directly as functions inside the module and have no Rust implementation.
#[derive(Clone, Copy, Debug, Hash, Eq, PartialEq)]
enum RuntimeBinding {
Pedersen,
HadesPermutation,
EcStateTryFinalizeNz,
EcStateAddMul,
EcStateInit,
EcStateAdd,
EcPointTryNewNz,
EcPointFromXNz,
DictNew,
DictGet,
DictSquash,
GetCostsBuiltin,
BlakeCompress,
DebugPrint,
// Generated in-module by `build_egcd_function` (no Rust implementation).
U31ExtendedEuclideanAlgorithm,
U252ExtendedEuclideanAlgorithm,
U384ExtendedEuclideanAlgorithm,
// Generated in-module by `build_circuit_arith_operation` (no Rust implementation).
CircuitArithOperation,
DictIntoEntries,
QM31Add,
QM31Sub,
QM31Mul,
QM31Div,
ArenaAlloc,
// Only available when the crate is built with the `with-cheatcode` feature.
#[cfg(feature = "with-cheatcode")]
VtableCheatcode,
}
impl RuntimeBinding {
const fn symbol(self) -> &'static str {
match self {
RuntimeBinding::DebugPrint => "cairo_native__libfunc__debug__print",
RuntimeBinding::Pedersen => "cairo_native__libfunc__pedersen",
RuntimeBinding::HadesPermutation => "cairo_native__libfunc__hades_permutation",
RuntimeBinding::EcStateTryFinalizeNz => {
"cairo_native__libfunc__ec__ec_state_try_finalize_nz"
}
RuntimeBinding::EcStateAddMul => "cairo_native__libfunc__ec__ec_state_add_mul",
RuntimeBinding::EcStateInit => "cairo_native__libfunc__ec__ec_state_init",
RuntimeBinding::EcStateAdd => "cairo_native__libfunc__ec__ec_state_add",
RuntimeBinding::EcPointTryNewNz => "cairo_native__libfunc__ec__ec_point_try_new_nz",
RuntimeBinding::EcPointFromXNz => "cairo_native__libfunc__ec__ec_point_from_x_nz",
RuntimeBinding::DictNew => "cairo_native__dict_new",
RuntimeBinding::DictGet => "cairo_native__dict_get",
RuntimeBinding::DictSquash => "cairo_native__dict_squash",
RuntimeBinding::GetCostsBuiltin => "cairo_native__get_costs_builtin",
RuntimeBinding::BlakeCompress => "cairo_native__libfunc__blake_compress",
RuntimeBinding::U31ExtendedEuclideanAlgorithm => {
"cairo_native__u31_extended_euclidean_algorithm"
}
RuntimeBinding::U252ExtendedEuclideanAlgorithm => {
"cairo_native__u252_extended_euclidean_algorithm"
}
RuntimeBinding::U384ExtendedEuclideanAlgorithm => {
"cairo_native__u384_extended_euclidean_algorithm"
}
RuntimeBinding::CircuitArithOperation => "cairo_native__circuit_arith_operation",
RuntimeBinding::DictIntoEntries => "cairo_native__dict_into_entries",
RuntimeBinding::QM31Add => "cairo_native__libfunc__qm31__qm31_add",
RuntimeBinding::QM31Sub => "cairo_native__libfunc__qm31__qm31_sub",
RuntimeBinding::QM31Mul => "cairo_native__libfunc__qm31__qm31_mul",
RuntimeBinding::QM31Div => "cairo_native__libfunc__qm31__qm31_div",
RuntimeBinding::ArenaAlloc => "cairo_native__arena_alloc",
#[cfg(feature = "with-cheatcode")]
RuntimeBinding::VtableCheatcode => "cairo_native__vtable_cheatcode",
}
}
const fn function_ptr(self) -> Option<*const ()> {
let function_ptr = match self {
RuntimeBinding::DebugPrint => {
crate::runtime::cairo_native__libfunc__debug__print as *const ()
}
RuntimeBinding::Pedersen => {
crate::runtime::cairo_native__libfunc__pedersen as *const ()
}
RuntimeBinding::HadesPermutation => {
crate::runtime::cairo_native__libfunc__hades_permutation as *const ()
}
RuntimeBinding::EcStateTryFinalizeNz => {
crate::runtime::cairo_native__libfunc__ec__ec_state_try_finalize_nz as *const ()
}
RuntimeBinding::EcStateAddMul => {
crate::runtime::cairo_native__libfunc__ec__ec_state_add_mul as *const ()
}
RuntimeBinding::EcStateInit => {
crate::runtime::cairo_native__libfunc__ec__ec_state_init as *const ()
}
RuntimeBinding::EcStateAdd => {
crate::runtime::cairo_native__libfunc__ec__ec_state_add as *const ()
}
RuntimeBinding::EcPointTryNewNz => {
crate::runtime::cairo_native__libfunc__ec__ec_point_try_new_nz as *const ()
}
RuntimeBinding::EcPointFromXNz => {
crate::runtime::cairo_native__libfunc__ec__ec_point_from_x_nz as *const ()
}
RuntimeBinding::DictNew => crate::runtime::cairo_native__dict_new as *const (),
RuntimeBinding::DictGet => crate::runtime::cairo_native__dict_get as *const (),
RuntimeBinding::DictSquash => crate::runtime::cairo_native__dict_squash as *const (),
RuntimeBinding::GetCostsBuiltin => {
crate::runtime::cairo_native__get_costs_builtin as *const ()
}
RuntimeBinding::DictIntoEntries => {
crate::runtime::cairo_native__dict_into_entries as *const ()
}
RuntimeBinding::QM31Add => {
crate::runtime::cairo_native__libfunc__qm31__qm31_add as *const ()
}
RuntimeBinding::QM31Sub => {
crate::runtime::cairo_native__libfunc__qm31__qm31_sub as *const ()
}
RuntimeBinding::QM31Mul => {
crate::runtime::cairo_native__libfunc__qm31__qm31_mul as *const ()
}
RuntimeBinding::QM31Div => {
crate::runtime::cairo_native__libfunc__qm31__qm31_div as *const ()
}
RuntimeBinding::BlakeCompress => {
crate::runtime::cairo_native__libfunc__blake_compress as *const ()
}
RuntimeBinding::U31ExtendedEuclideanAlgorithm
| RuntimeBinding::U252ExtendedEuclideanAlgorithm
| RuntimeBinding::U384ExtendedEuclideanAlgorithm => return None,
RuntimeBinding::CircuitArithOperation => return None,
RuntimeBinding::ArenaAlloc => crate::runtime::cairo_native__arena_alloc as *const (),
#[cfg(feature = "with-cheatcode")]
RuntimeBinding::VtableCheatcode => {
crate::starknet::cairo_native__vtable_cheatcode as *const ()
}
};
Some(function_ptr)
}
}
/// Modular arithmetic operation performed by the generated circuit arithmetic helper.
///
/// The discriminants are part of the generated code's contract. They are passed to the helper
/// as a 2-bit tag and also used as its `cf.switch` case values, so they must stay `0`, `1`
/// and `2`.
#[repr(u8)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum CircuitArithOperationType {
    /// `(lhs + rhs) % modulus`.
    Add = 0,
    /// `(lhs - rhs) % modulus`, evaluated as `(lhs + modulus - rhs) % modulus`.
    Sub = 1,
    /// `(lhs * rhs) % modulus`.
    Mul = 2,
}
/// Per-module bookkeeping for runtime bindings.
///
/// Each helper inserts its `RuntimeBinding` into `active_map` and only emits the binding's
/// declaration (weak global or generated function) when the insert reports a new entry. This
/// makes every declaration appear at most once per module.
#[derive(Debug, Default)]
pub struct RuntimeBindingsMeta {
// Bindings whose declaration has already been emitted into the module.
active_map: HashSet<RuntimeBinding>,
}
impl RuntimeBindingsMeta {
fn build_function<'c, 'a>(
&mut self,
context: &'c Context,
module: &Module,
block: &'a Block<'c>,
location: Location<'c>,
binding: RuntimeBinding,
) -> Result<Value<'c, 'a>> {
if self.active_map.insert(binding) {
module.body().append_operation(
ods::llvm::mlir_global(
context,
Region::new(),
TypeAttribute::new(llvm::r#type::pointer(context, 0)),
StringAttribute::new(context, binding.symbol()),
Attribute::parse(context, "#llvm.linkage<weak>")
.ok_or(Error::ParseAttributeError)?,
location,
)
.into(),
);
}
let global_address = block.append_op_result(
ods::llvm::mlir_addressof(
context,
llvm::r#type::pointer(context, 0),
FlatSymbolRefAttribute::new(context, binding.symbol()),
location,
)
.into(),
)?;
Ok(block.load(
context,
location,
global_address,
llvm::r#type::pointer(context, 0),
)?)
}
pub fn u31_extended_euclidean_algorithm<'c, 'a>(
&mut self,
context: &'c Context,
module: &Module,
block: &'a Block<'c>,
location: Location<'c>,
a: Value<'c, '_>,
b: Value<'c, '_>,
) -> Result<Value<'c, 'a>>
where
'c: 'a,
{
let integer_type = IntegerType::new(context, 31).into();
let func_symbol = RuntimeBinding::U31ExtendedEuclideanAlgorithm.symbol();
if self
.active_map
.insert(RuntimeBinding::U31ExtendedEuclideanAlgorithm)
{
build_egcd_function(module, context, location, func_symbol, integer_type)?;
}
let return_type = llvm::r#type::r#struct(context, &[integer_type, integer_type], false);
Ok(block
.append_operation(
OperationBuilder::new("llvm.call", location)
.add_attributes(&[(
Identifier::new(context, "callee"),
FlatSymbolRefAttribute::new(context, func_symbol).into(),
)])
.add_operands(&[a, b])
.add_results(&[return_type])
.build()?,
)
.result(0)?
.into())
}
pub fn u252_extended_euclidean_algorithm<'c, 'a>(
&mut self,
context: &'c Context,
module: &Module,
block: &'a Block<'c>,
location: Location<'c>,
a: Value<'c, '_>,
b: Value<'c, '_>,
) -> Result<Value<'c, 'a>>
where
'c: 'a,
{
let integer_type = IntegerType::new(context, 252).into();
let func_symbol = RuntimeBinding::U252ExtendedEuclideanAlgorithm.symbol();
if self
.active_map
.insert(RuntimeBinding::U252ExtendedEuclideanAlgorithm)
{
build_egcd_function(module, context, location, func_symbol, integer_type)?;
}
let return_type = llvm::r#type::r#struct(context, &[integer_type, integer_type], false);
Ok(block
.append_operation(
OperationBuilder::new("llvm.call", location)
.add_attributes(&[(
Identifier::new(context, "callee"),
FlatSymbolRefAttribute::new(context, func_symbol).into(),
)])
.add_operands(&[a, b])
.add_results(&[return_type])
.build()?,
)
.result(0)?
.into())
}
pub fn u384_extended_euclidean_algorithm<'c, 'a>(
&mut self,
context: &'c Context,
module: &Module,
block: &'a Block<'c>,
location: Location<'c>,
a: Value<'c, '_>,
b: Value<'c, '_>,
) -> Result<Value<'c, 'a>>
where
'c: 'a,
{
let integer_type = IntegerType::new(context, 384).into();
let func_symbol = RuntimeBinding::U384ExtendedEuclideanAlgorithm.symbol();
if self
.active_map
.insert(RuntimeBinding::U384ExtendedEuclideanAlgorithm)
{
build_egcd_function(module, context, location, func_symbol, integer_type)?;
}
let return_type = llvm::r#type::r#struct(context, &[integer_type, integer_type], false);
Ok(block
.append_operation(
OperationBuilder::new("llvm.call", location)
.add_attributes(&[(
Identifier::new(context, "callee"),
FlatSymbolRefAttribute::new(context, func_symbol).into(),
)])
.add_operands(&[a, b])
.add_results(&[return_type])
.build()?,
)
.result(0)?
.into())
}
/// Emits a call to the module-local circuit arithmetic helper.
///
/// On first use, the helper is generated into `module` by `build_circuit_arith_operation`.
/// `op_type` is passed as a 2-bit tag ahead of `lhs_value`, `rhs_value` and `circuit_modulus`.
/// The returned value is the helper's `i384` result.
#[allow(clippy::too_many_arguments)]
pub fn circuit_arith_operation<'c, 'a>(
    &mut self,
    context: &'c Context,
    module: &Module,
    block: &'a Block<'c>,
    location: Location<'c>,
    op_type: CircuitArithOperationType,
    lhs_value: Value<'c, '_>,
    rhs_value: Value<'c, '_>,
    circuit_modulus: Value<'c, '_>,
) -> Result<Value<'c, 'a>>
where
    'c: 'a,
{
    let binding = RuntimeBinding::CircuitArithOperation;
    let callee = binding.symbol();
    if self.active_map.insert(binding) {
        build_circuit_arith_operation(context, module, location, callee)?;
    }

    // The helper selects the operation from a 2-bit tag.
    let tag = block.const_int(context, location, op_type as u8, 2)?;
    let u384_ty = IntegerType::new(context, 384).into();

    let call = OperationBuilder::new("llvm.call", location)
        .add_attributes(&[(
            Identifier::new(context, "callee"),
            FlatSymbolRefAttribute::new(context, callee).into(),
        )])
        .add_operands(&[tag, lhs_value, rhs_value, circuit_modulus])
        .add_results(&[u384_ty])
        .build()?;
    Ok(block.append_op_result(call)?)
}
/// Emits a call to the `DebugPrint` runtime binding.
///
/// The call receives `(target_fd, values_ptr, values_len)` and its `i32` result is returned.
#[allow(clippy::too_many_arguments)]
pub fn libfunc_debug_print<'c, 'a>(
    &mut self,
    context: &'c Context,
    module: &Module,
    block: &'a Block<'c>,
    target_fd: Value<'c, '_>,
    values_ptr: Value<'c, '_>,
    values_len: Value<'c, '_>,
    location: Location<'c>,
) -> Result<Value<'c, 'a>>
where
    'c: 'a,
{
    let callee =
        self.build_function(context, module, block, location, RuntimeBinding::DebugPrint)?;
    let i32_ty = IntegerType::new(context, 32).into();

    let call = OperationBuilder::new("llvm.call", location)
        .add_operands(&[callee, target_fd, values_ptr, values_len])
        .add_results(&[i32_ty])
        .build()?;
    Ok(block.append_op_result(call)?)
}
/// Emits a call to the `Pedersen` runtime binding.
///
/// The call receives `(dst_ptr, lhs_ptr, rhs_ptr)` and produces no results. The appended call
/// operation is returned.
#[allow(clippy::too_many_arguments)]
pub fn libfunc_pedersen<'c, 'a>(
    &mut self,
    context: &'c Context,
    module: &Module,
    block: &'a Block<'c>,
    dst_ptr: Value<'c, '_>,
    lhs_ptr: Value<'c, '_>,
    rhs_ptr: Value<'c, '_>,
    location: Location<'c>,
) -> Result<OperationRef<'c, 'a>>
where
    'c: 'a,
{
    let callee =
        self.build_function(context, module, block, location, RuntimeBinding::Pedersen)?;
    let call = OperationBuilder::new("llvm.call", location)
        .add_operands(&[callee, dst_ptr, lhs_ptr, rhs_ptr])
        .build()?;
    Ok(block.append_operation(call))
}
/// Emits a call to the `HadesPermutation` runtime binding.
///
/// The call receives `(op0_ptr, op1_ptr, op2_ptr)` and produces no results. The appended call
/// operation is returned.
#[allow(clippy::too_many_arguments)]
pub fn libfunc_hades_permutation<'c, 'a>(
    &mut self,
    context: &'c Context,
    module: &Module,
    block: &'a Block<'c>,
    op0_ptr: Value<'c, '_>,
    op1_ptr: Value<'c, '_>,
    op2_ptr: Value<'c, '_>,
    location: Location<'c>,
) -> Result<OperationRef<'c, 'a>>
where
    'c: 'a,
{
    let binding = RuntimeBinding::HadesPermutation;
    let callee = self.build_function(context, module, block, location, binding)?;
    let call = OperationBuilder::new("llvm.call", location)
        .add_operands(&[callee, op0_ptr, op1_ptr, op2_ptr])
        .build()?;
    Ok(block.append_operation(call))
}
/// Emits a call to the `BlakeCompress` runtime binding.
///
/// The call receives `(out_state, state, message, count_bytes, finalize)` and produces no
/// results. The appended call operation is returned.
#[allow(clippy::too_many_arguments)]
pub fn libfunc_blake_compress<'c, 'a>(
    &mut self,
    context: &'c Context,
    module: &Module,
    block: &'a Block<'c>,
    out_state: Value<'c, 'a>,
    state: Value<'c, 'a>,
    message: Value<'c, 'a>,
    count_bytes: Value<'c, 'a>,
    finalize: Value<'c, 'a>,
    location: Location<'c>,
) -> Result<OperationRef<'c, 'a>>
where
    'c: 'a,
{
    let binding = RuntimeBinding::BlakeCompress;
    let callee = self.build_function(context, module, block, location, binding)?;
    let operands = [callee, out_state, state, message, count_bytes, finalize];
    let call = OperationBuilder::new("llvm.call", location)
        .add_operands(&operands)
        .build()?;
    Ok(block.append_operation(call))
}
/// Emits a call to the `EcPointFromXNz` runtime binding.
///
/// The call receives `point_ptr` and has a single `i1` result. The appended call operation is
/// returned so the caller can read that result.
pub fn libfunc_ec_point_from_x_nz<'c, 'a>(
    &mut self,
    context: &'c Context,
    module: &Module,
    block: &'a Block<'c>,
    point_ptr: Value<'c, '_>,
    location: Location<'c>,
) -> Result<OperationRef<'c, 'a>>
where
    'c: 'a,
{
    let binding = RuntimeBinding::EcPointFromXNz;
    let callee = self.build_function(context, module, block, location, binding)?;
    let i1_ty = IntegerType::new(context, 1).into();
    let call = OperationBuilder::new("llvm.call", location)
        .add_operands(&[callee, point_ptr])
        .add_results(&[i1_ty])
        .build()?;
    Ok(block.append_operation(call))
}
/// Emits a call to the `EcPointTryNewNz` runtime binding.
///
/// The call receives `point_ptr` and has a single `i1` result. The appended call operation is
/// returned so the caller can read that result.
pub fn libfunc_ec_point_try_new_nz<'c, 'a>(
    &mut self,
    context: &'c Context,
    module: &Module,
    block: &'a Block<'c>,
    point_ptr: Value<'c, '_>,
    location: Location<'c>,
) -> Result<OperationRef<'c, 'a>>
where
    'c: 'a,
{
    let binding = RuntimeBinding::EcPointTryNewNz;
    let callee = self.build_function(context, module, block, location, binding)?;
    let i1_ty = IntegerType::new(context, 1).into();
    let call = OperationBuilder::new("llvm.call", location)
        .add_operands(&[callee, point_ptr])
        .add_results(&[i1_ty])
        .build()?;
    Ok(block.append_operation(call))
}
/// Emits a call to the `EcStateInit` runtime binding.
///
/// The call receives `state_ptr` and produces no results. The appended call operation is
/// returned.
pub fn libfunc_ec_state_init<'c, 'a>(
    &mut self,
    context: &'c Context,
    module: &Module,
    block: &'a Block<'c>,
    state_ptr: Value<'c, '_>,
    location: Location<'c>,
) -> Result<OperationRef<'c, 'a>>
where
    'c: 'a,
{
    let binding = RuntimeBinding::EcStateInit;
    let callee = self.build_function(context, module, block, location, binding)?;
    let call = OperationBuilder::new("llvm.call", location)
        .add_operands(&[callee, state_ptr])
        .build()?;
    Ok(block.append_operation(call))
}
/// Emits a call to the `EcStateAdd` runtime binding.
///
/// The call receives `(state_ptr, point_ptr)` and produces no results. The appended call
/// operation is returned.
pub fn libfunc_ec_state_add<'c, 'a>(
    &mut self,
    context: &'c Context,
    module: &Module,
    block: &'a Block<'c>,
    state_ptr: Value<'c, '_>,
    point_ptr: Value<'c, '_>,
    location: Location<'c>,
) -> Result<OperationRef<'c, 'a>>
where
    'c: 'a,
{
    let binding = RuntimeBinding::EcStateAdd;
    let callee = self.build_function(context, module, block, location, binding)?;
    let call = OperationBuilder::new("llvm.call", location)
        .add_operands(&[callee, state_ptr, point_ptr])
        .build()?;
    Ok(block.append_operation(call))
}
/// Emits a call to the `EcStateAddMul` runtime binding.
///
/// The call receives `(state_ptr, scalar_ptr, point_ptr)` and produces no results. The appended
/// call operation is returned.
#[allow(clippy::too_many_arguments)]
pub fn libfunc_ec_state_add_mul<'c, 'a>(
    &mut self,
    context: &'c Context,
    module: &Module,
    block: &'a Block<'c>,
    state_ptr: Value<'c, '_>,
    scalar_ptr: Value<'c, '_>,
    point_ptr: Value<'c, '_>,
    location: Location<'c>,
) -> Result<OperationRef<'c, 'a>>
where
    'c: 'a,
{
    let binding = RuntimeBinding::EcStateAddMul;
    let callee = self.build_function(context, module, block, location, binding)?;
    let call = OperationBuilder::new("llvm.call", location)
        .add_operands(&[callee, state_ptr, scalar_ptr, point_ptr])
        .build()?;
    Ok(block.append_operation(call))
}
/// Emits a call to the `EcStateTryFinalizeNz` runtime binding.
///
/// The call receives `(point_ptr, state_ptr)` and has a single `i1` result. The appended call
/// operation is returned so the caller can read that result.
pub fn libfunc_ec_state_try_finalize_nz<'c, 'a>(
    &mut self,
    context: &'c Context,
    module: &Module,
    block: &'a Block<'c>,
    point_ptr: Value<'c, '_>,
    state_ptr: Value<'c, '_>,
    location: Location<'c>,
) -> Result<OperationRef<'c, 'a>>
where
    'c: 'a,
{
    let binding = RuntimeBinding::EcStateTryFinalizeNz;
    let callee = self.build_function(context, module, block, location, binding)?;
    let i1_ty = IntegerType::new(context, 1).into();
    let call = OperationBuilder::new("llvm.call", location)
        .add_operands(&[callee, point_ptr, state_ptr])
        .add_results(&[i1_ty])
        .build()?;
    Ok(block.append_operation(call))
}
/// Emits a call to the QM31 runtime binding that matches `op` and loads its result.
///
/// A stack slot of type `[4 x i31]` is allocated in `block`. The binding is called with
/// `(lhs_ptr, rhs_ptr, res_ptr)`, and the slot's contents are loaded and returned.
#[allow(clippy::too_many_arguments)]
pub fn libfunc_qm31_bin_op<'c, 'a>(
    &mut self,
    context: &'c Context,
    module: &Module,
    block: &'a Block<'c>,
    lhs_ptr: Value<'c, '_>,
    rhs_ptr: Value<'c, '_>,
    op: QM31BinaryOperator,
    location: Location<'c>,
) -> Result<Value<'c, 'a>>
where
    'c: 'a,
{
    let binding = match op {
        QM31BinaryOperator::Add => RuntimeBinding::QM31Add,
        QM31BinaryOperator::Sub => RuntimeBinding::QM31Sub,
        QM31BinaryOperator::Mul => RuntimeBinding::QM31Mul,
        QM31BinaryOperator::Div => RuntimeBinding::QM31Div,
    };

    let limb_ty = IntegerType::new(context, 31).into();
    let qm31_ty = llvm::r#type::array(limb_ty, 4);
    // NOTE(review): the result slot is allocated in `block` rather than in an entry block; if
    // `block` can be executed repeatedly within one frame, this may grow the stack — confirm.
    let res_ptr = block.alloca1(context, location, qm31_ty, get_integer_layout(31).align())?;

    let callee = self.build_function(context, module, block, location, binding)?;
    let call = OperationBuilder::new("llvm.call", location)
        .add_operands(&[callee, lhs_ptr, rhs_ptr, res_ptr])
        .build()?;
    block.append_operation(call);

    Ok(block.load(context, location, res_ptr, qm31_ty)?)
}
/// Emits a call to the `ArenaAlloc` runtime binding.
///
/// The call receives `(size, align)` and its pointer result is returned.
pub fn arena_alloc<'c, 'a>(
    &mut self,
    context: &'c Context,
    module: &Module,
    block: &'a Block<'c>,
    location: Location<'c>,
    size: Value<'c, 'a>,
    align: Value<'c, 'a>,
) -> Result<Value<'c, 'a>>
where
    'c: 'a,
{
    let binding = RuntimeBinding::ArenaAlloc;
    let callee = self.build_function(context, module, block, location, binding)?;
    let ptr_ty = llvm::r#type::pointer(context, 0);
    let call = OperationBuilder::new("llvm.call", location)
        .add_operands(&[callee, size, align])
        .add_results(&[ptr_ty])
        .build()?;
    Ok(block.append_op_result(call)?)
}
/// Emits a call to the `DictNew` runtime binding.
///
/// `layout`'s size and alignment are materialized as two `i64` constants and passed as the
/// call's arguments. The call's pointer result is returned.
pub fn dict_new<'c, 'a>(
    &mut self,
    context: &'c Context,
    module: &Module,
    block: &'a Block<'c>,
    location: Location<'c>,
    layout: Layout,
) -> Result<Value<'c, 'a>>
where
    'c: 'a,
{
    let function =
        self.build_function(context, module, block, location, RuntimeBinding::DictNew)?;

    let i64_ty = IntegerType::new(context, 64).into();
    let size = block.const_int_from_type(context, location, layout.size(), i64_ty)?;
    let align = block.const_int_from_type(context, location, layout.align(), i64_ty)?;

    Ok(block.append_op_result(
        OperationBuilder::new("llvm.call", location)
            .add_operands(&[function, size, align])
            .add_results(&[llvm::r#type::pointer(context, 0)])
            .build()?,
    )?)
}
/// Emits a call to the `DictGet` runtime binding.
///
/// A pointer-sized stack slot is allocated in `helper`'s init block and passed as the third
/// argument, after `dict_ptr` and `key_ptr`. After the call, the slot is loaded back.
///
/// Returns `(is_present, value_ptr)`: the call's `c_int`-width integer result and the pointer
/// read from the slot.
pub fn dict_get<'c, 'a>(
    &mut self,
    context: &'c Context,
    helper: &LibfuncHelper<'c, 'a>,
    block: &'a Block<'c>,
    dict_ptr: Value<'c, 'a>,
    key_ptr: Value<'c, 'a>,
    location: Location<'c>,
) -> Result<(Value<'c, 'a>, Value<'c, 'a>)>
where
    'c: 'a,
{
    let function =
        self.build_function(context, helper, block, location, RuntimeBinding::DictGet)?;

    // Out-parameter slot, allocated once in the init block instead of per call site.
    let value_ptr = helper.init_block().alloca1(
        context,
        location,
        llvm::r#type::pointer(context, 0),
        align_of::<*mut ()>(),
    )?;

    let is_present = block.append_op_result(
        OperationBuilder::new("llvm.call", location)
            .add_operands(&[function, dict_ptr, key_ptr, value_ptr])
            .add_results(&[IntegerType::new(context, c_int::BITS).into()])
            .build()?,
    )?;

    let value_ptr = block.load(
        context,
        location,
        value_ptr,
        llvm::r#type::pointer(context, 0),
    )?;
    Ok((is_present, value_ptr))
}
/// Emits a call to the `DictSquash` runtime binding.
///
/// The call receives `(dict_ptr, range_check_ptr, gas_ptr)` and has a single `i64` result. The
/// appended call operation is returned so the caller can read that result.
#[allow(clippy::too_many_arguments)]
pub fn dict_squash<'c, 'a>(
    &mut self,
    context: &'c Context,
    module: &Module,
    block: &'a Block<'c>,
    dict_ptr: Value<'c, 'a>,
    range_check_ptr: Value<'c, 'a>,
    gas_ptr: Value<'c, 'a>,
    location: Location<'c>,
) -> Result<OperationRef<'c, 'a>>
where
    'c: 'a,
{
    let binding = RuntimeBinding::DictSquash;
    let callee = self.build_function(context, module, block, location, binding)?;
    let i64_ty = IntegerType::new(context, 64).into();
    let call = OperationBuilder::new("llvm.call", location)
        .add_operands(&[callee, dict_ptr, range_check_ptr, gas_ptr])
        .add_results(&[i64_ty])
        .build()?;
    Ok(block.append_operation(call))
}
/// Emits a call to the `DictIntoEntries` runtime binding.
///
/// The call receives `(dict_ptr, array_ptr)` and produces no results. The appended call
/// operation is returned.
pub fn dict_into_entries<'c, 'a>(
    &mut self,
    context: &'c Context,
    helper: &LibfuncHelper<'c, 'a>,
    block: &'a Block<'c>,
    dict_ptr: Value<'c, 'a>,
    array_ptr: Value<'c, 'a>,
    location: Location<'c>,
) -> Result<OperationRef<'c, 'a>>
where
    'c: 'a,
{
    let function = self.build_function(
        context,
        helper,
        block,
        location,
        RuntimeBinding::DictIntoEntries,
    )?;
    Ok(block.append_operation(
        OperationBuilder::new("llvm.call", location)
            .add_operands(&[function, dict_ptr, array_ptr])
            .build()?,
    ))
}
/// Emits a call to the `GetCostsBuiltin` runtime binding.
///
/// The call takes no arguments and has a single pointer result. The appended call operation is
/// returned so the caller can read that result.
pub fn get_costs_builtin<'c, 'a>(
    &mut self,
    context: &'c Context,
    module: &Module,
    block: &'a Block<'c>,
    location: Location<'c>,
) -> Result<OperationRef<'c, 'a>>
where
    'c: 'a,
{
    let function = self.build_function(
        context,
        module,
        block,
        location,
        RuntimeBinding::GetCostsBuiltin,
    )?;
    Ok(block.append_operation(
        OperationBuilder::new("llvm.call", location)
            .add_operands(&[function])
            .add_results(&[llvm::r#type::pointer(context, 0)])
            .build()?,
    ))
}
/// Emits a call to the `VtableCheatcode` runtime binding (only with the `with-cheatcode`
/// feature).
///
/// The call receives `(result_ptr, selector_ptr, args)` and produces no results. The appended
/// call operation is returned.
#[allow(clippy::too_many_arguments)]
#[cfg(feature = "with-cheatcode")]
pub fn vtable_cheatcode<'c, 'a>(
    &mut self,
    context: &'c Context,
    module: &Module,
    block: &'a Block<'c>,
    location: Location<'c>,
    result_ptr: Value<'c, 'a>,
    selector_ptr: Value<'c, 'a>,
    args: Value<'c, 'a>,
) -> Result<OperationRef<'c, 'a>>
where
    'c: 'a,
{
    let binding = RuntimeBinding::VtableCheatcode;
    let callee = self.build_function(context, module, block, location, binding)?;
    let call = OperationBuilder::new("llvm.call", location)
        .add_operands(&[callee, result_ptr, selector_ptr, args])
        .build()?;
    Ok(block.append_operation(call))
}
}
/// Writes the Rust implementation of every runtime binding into its global.
///
/// For each binding, `find_symbol_ptr` is asked for the address of the global named
/// `RuntimeBinding::symbol`. If it returns an address and the binding has a Rust
/// implementation (`RuntimeBinding::function_ptr` is `Some`), that function pointer is stored
/// into the global. Bindings whose global is not found are left untouched.
pub fn setup_runtime(find_symbol_ptr: impl Fn(&str) -> Option<*mut c_void>) {
    // Each binding backed by a weak global, listed exactly once. The in-module helpers (extended
    // Euclidean algorithms, circuit arithmetic) are generated as functions and have no global.
    for binding in [
        RuntimeBinding::DebugPrint,
        RuntimeBinding::Pedersen,
        RuntimeBinding::HadesPermutation,
        RuntimeBinding::EcStateTryFinalizeNz,
        RuntimeBinding::EcStateAddMul,
        RuntimeBinding::EcStateInit,
        RuntimeBinding::EcStateAdd,
        RuntimeBinding::EcPointTryNewNz,
        RuntimeBinding::EcPointFromXNz,
        RuntimeBinding::DictNew,
        RuntimeBinding::DictGet,
        RuntimeBinding::DictSquash,
        RuntimeBinding::GetCostsBuiltin,
        RuntimeBinding::BlakeCompress,
        RuntimeBinding::DictIntoEntries,
        RuntimeBinding::QM31Add,
        RuntimeBinding::QM31Sub,
        RuntimeBinding::QM31Mul,
        RuntimeBinding::QM31Div,
        RuntimeBinding::ArenaAlloc,
        #[cfg(feature = "with-cheatcode")]
        RuntimeBinding::VtableCheatcode,
    ] {
        let Some(global) = find_symbol_ptr(binding.symbol()) else {
            continue;
        };
        let Some(function_ptr) = binding.function_ptr() else {
            continue;
        };
        // SAFETY: `find_symbol_ptr` is expected to return the address of the pointer-typed,
        // writable global that `RuntimeBindingsMeta::build_function` declares under this symbol.
        // Such an address is valid and suitably aligned for a `*const ()` store; callers of
        // `setup_runtime` must uphold this.
        unsafe { global.cast::<*const ()>().write(function_ptr) };
    }
}
/// Declares `func_symbol` in `module` as an extended Euclidean algorithm over `integer_type`.
///
/// The generated `no_inline` function takes `(a, b)` and returns the struct `{ gcd, coeff }`.
/// It iterates the extended Euclidean algorithm starting from the remainders `(b, a)` and tracks
/// only the coefficient of `a`, so on exit `coeff * a ≡ gcd (mod b)`. If `coeff` is negative
/// when `integer_type` is read as signed, `b` is added to it once.
///
/// When `a` is zero the loop is skipped and `{ b, 0 }` is returned directly. The first loop
/// iteration divides by `a`, and an unsigned division by zero is undefined behavior in LLVM.
///
/// # Errors
///
/// Fails if a block argument cannot be read or an operation cannot be built.
fn build_egcd_function<'ctx>(
    module: &Module,
    context: &'ctx Context,
    location: Location<'ctx>,
    func_symbol: &str,
    integer_type: Type,
) -> Result<()> {
    let result_type = llvm::r#type::r#struct(context, &[integer_type, integer_type], false);

    let region = Region::new();
    // Arguments: (a, b), where `b` also serves as the modulus.
    let entry_block = region.append_block(Block::new(&[
        (integer_type, location),
        (integer_type, location),
    ]));
    // Loop state: (old_r, new_r, old_s, new_s).
    let loop_block = region.append_block(Block::new(&[
        (integer_type, location),
        (integer_type, location),
        (integer_type, location),
        (integer_type, location),
    ]));
    // Result: (gcd, Bézout coefficient of `a`).
    let end_block = region.append_block(Block::new(&[
        (integer_type, location),
        (integer_type, location),
    ]));

    let value = entry_block.arg(0)?;
    let modulus = entry_block.arg(1)?;
    {
        let zero = entry_block.const_int_from_type(context, location, 0, integer_type)?;
        let one = entry_block.const_int_from_type(context, location, 1, integer_type)?;
        // gcd(0, b) = b with coefficient 0; branching here avoids evaluating `b / 0` in the loop.
        let value_is_zero =
            entry_block.cmpi(context, CmpiPredicate::Eq, value, zero, location)?;
        entry_block.append_operation(cf::cond_br(
            context,
            value_is_zero,
            &end_block,
            &loop_block,
            &[modulus, zero],
            &[modulus, value, zero, one],
            location,
        ));
    }

    {
        let old_r = loop_block.arg(0)?;
        let new_r = loop_block.arg(1)?;
        let old_s = loop_block.arg(2)?;
        let new_s = loop_block.arg(3)?;

        // `new_r` is never zero here: it is `a != 0` on entry and only re-enters the loop when
        // the next remainder is non-zero.
        let quotient = loop_block.append_op_result(arith::divui(old_r, new_r, location))?;
        let quotient_by_new_r = loop_block.muli(quotient, new_r, location)?;
        let quotient_by_new_s = loop_block.muli(quotient, new_s, location)?;
        let next_new_r =
            loop_block.append_op_result(arith::subi(old_r, quotient_by_new_r, location))?;
        let next_new_s =
            loop_block.append_op_result(arith::subi(old_s, quotient_by_new_s, location))?;

        let zero = loop_block.const_int_from_type(context, location, 0, integer_type)?;
        let next_new_r_is_zero =
            loop_block.cmpi(context, CmpiPredicate::Eq, next_new_r, zero, location)?;
        loop_block.append_operation(cf::cond_br(
            context,
            next_new_r_is_zero,
            &end_block,
            &loop_block,
            &[new_r, new_s],
            &[new_r, next_new_r, new_s, next_new_s],
            location,
        ));
    }

    {
        let gcd = end_block.arg(0)?;
        let bezout_coeff = end_block.arg(1)?;

        // Map a negative coefficient back into range by adding the modulus once.
        let zero = end_block.const_int_from_type(context, location, 0, integer_type)?;
        let is_negative =
            end_block.cmpi(context, CmpiPredicate::Slt, bezout_coeff, zero, location)?;
        let wrapped_bezout_coeff = end_block.addi(bezout_coeff, modulus, location)?;
        let bezout_coeff = end_block.append_op_result(arith::select(
            is_negative,
            wrapped_bezout_coeff,
            bezout_coeff,
            location,
        ))?;

        let results = end_block.append_op_result(llvm::undef(result_type, location))?;
        let results = end_block.insert_values(context, location, results, &[gcd, bezout_coeff])?;
        end_block.append_operation(llvm::r#return(Some(results), location));
    }

    module.body().append_operation(llvm::func(
        context,
        StringAttribute::new(context, func_symbol),
        TypeAttribute::new(llvm::r#type::function(
            result_type,
            &[integer_type, integer_type],
            false,
        )),
        region,
        &[(Identifier::new(context, "no_inline"), Attribute::unit(context))],
        location,
    ));
    Ok(())
}
fn build_circuit_arith_operation<'ctx>(
context: &'ctx Context,
module: &Module,
location: Location<'ctx>,
func_symbol: &str,
) -> Result<()> {
let func_name = StringAttribute::new(context, func_symbol);
let u2_ty = IntegerType::new(context, 2).into();
let u384_ty: Type = IntegerType::new(context, 384).into();
let u385_ty: Type = IntegerType::new(context, 385).into();
let u768_ty = IntegerType::new(context, 768).into();
let region = Region::new();
let entry_block = region.append_block(Block::new(&[
(u2_ty, location),
(u384_ty, location),
(u384_ty, location),
(u384_ty, location),
]));
let op_tag = entry_block.arg(0)?;
let lhs = entry_block.arg(1)?;
let rhs = entry_block.arg(2)?;
let modulus = entry_block.arg(3)?;
let ops = [
CircuitArithOperationType::Add,
CircuitArithOperationType::Sub,
CircuitArithOperationType::Mul,
];
let op_blocks = ops
.into_iter()
.map(|op| (op, Block::new(&[])))
.collect_vec();
let default_block = region.append_block(Block::new(&[]));
let cases_values = ops.iter().map(|&op| op as i64).collect_vec();
{
default_block.append_operation(llvm::unreachable(location));
}
for (tag, block) in op_blocks.iter() {
let result = match tag {
CircuitArithOperationType::Add => {
let lhs = block.extui(lhs, u385_ty, location)?;
let rhs = block.extui(rhs, u385_ty, location)?;
let modulus = block.extui(modulus, u385_ty, location)?;
let result = block.addi(lhs, rhs, location)?;
block.append_op_result(arith::remui(result, modulus, location))?
}
CircuitArithOperationType::Sub => {
let lhs = block.extui(lhs, u385_ty, location)?;
let rhs = block.extui(rhs, u385_ty, location)?;
let modulus = block.extui(modulus, u385_ty, location)?;
let partial_result = block.addi(lhs, modulus, location)?;
let result = block.subi(partial_result, rhs, location)?;
block.append_op_result(arith::remui(result, modulus, location))?
}
CircuitArithOperationType::Mul => {
let lhs = block.extui(lhs, u768_ty, location)?;
let rhs = block.extui(rhs, u768_ty, location)?;
let modulus = block.extui(modulus, u768_ty, location)?;
let result = block.muli(lhs, rhs, location)?;
block.append_op_result(arith::remui(result, modulus, location))?
}
};
let result = block.trunci(result, u384_ty, location)?;
block.append_operation(llvm::r#return(Some(result), location));
}
entry_block.append_operation(cf::switch(
context,
&cases_values,
op_tag,
u2_ty,
(&default_block, &[]),
&op_blocks
.iter()
.map(|(_, block)| (block, [].as_slice()))
.collect::<Vec<_>>(),
location,
)?);
for (_, block) in op_blocks.into_iter() {
region.append_block(block);
}
module.body().append_operation(llvm::func(
context,
func_name,
TypeAttribute::new(llvm::r#type::function(
u384_ty,
&[u2_ty, u384_ty, u384_ty, u384_ty],
false,
)),
region,
&[(
Identifier::new(context, "no_inline"),
Attribute::unit(context),
)],
location,
));
Ok(())
}