use crate::{
error::Error,
execution_result::{ContractExecutionResult, ExecutionResult},
metadata::{
felt252_dict::Felt252DictOverrides, gas::GasMetadata, runtime_bindings::setup_runtime,
},
module::NativeModule,
starknet::{DummySyscallHandler, StarknetSyscallHandler},
utils::generate_function_name,
values::Value,
OptLevel,
};
use cairo_lang_sierra::{
extensions::core::{CoreLibfunc, CoreType},
ids::{ConcreteTypeId, FunctionId},
program::FunctionSignature,
program_registry::ProgramRegistry,
};
use educe::Educe;
use libc::c_void;
use libloading::Library;
use starknet_types_core::felt::Felt;
use std::{io, mem::transmute};
use tempfile::NamedTempFile;
/// Executor for Sierra programs that were compiled ahead of time into a shared library.
///
/// Compiled functions are resolved by symbol name in the loaded [`Library`] and invoked via
/// `super::invoke_dynamic`. The registry supplies function signatures, and the gas metadata
/// supplies each call's initial gas. The dict overrides supply per-type drop-function
/// symbols.
///
/// Fields are dropped in declaration order, so `library` is unloaded before the remaining
/// fields are dropped; keep that in mind before reordering them.
#[derive(Educe)]
#[educe(Debug)]
pub struct AotNativeExecutor {
    /// The loaded shared library holding the compiled program.
    #[educe(Debug(ignore))]
    library: Library,
    /// Registry of the Sierra program; used to look up function signatures.
    #[educe(Debug(ignore))]
    registry: ProgramRegistry<CoreType, CoreLibfunc>,
    /// Used to compute the initial available gas for each invocation.
    gas_metadata: GasMetadata,
    /// `get_drop_fn` maps a `ConcreteTypeId` to the name of a drop-override symbol in `library`.
    dict_overrides: Felt252DictOverrides,
}
// SAFETY(review): these impls are asserted manually, not derived. The `&self` methods visible
// in this file only read the fields. Soundness presumably also requires that
// `ProgramRegistry`, `GasMetadata` and `Felt252DictOverrides` have no thread-unsafe interior
// mutability — TODO confirm.
unsafe impl Send for AotNativeExecutor {}
unsafe impl Sync for AotNativeExecutor {}
impl AotNativeExecutor {
pub fn new(
library: Library,
registry: ProgramRegistry<CoreType, CoreLibfunc>,
gas_metadata: GasMetadata,
dict_overrides: Felt252DictOverrides,
) -> Self {
let executor = Self {
library,
registry,
gas_metadata,
dict_overrides,
};
setup_runtime(|name| executor.find_symbol_ptr(name));
#[cfg(feature = "with-debug-utils")]
crate::metadata::debug_utils::setup_runtime(|name| executor.find_symbol_ptr(name));
#[cfg(feature = "with-trace-dump")]
crate::metadata::trace_dump::setup_runtime(|name| executor.find_symbol_ptr(name));
#[cfg(feature = "with-libfunc-profiling")]
crate::metadata::profiler::setup_runtime(|name| executor.find_symbol_ptr(name));
executor
}
pub fn from_native_module(module: NativeModule, opt_level: OptLevel) -> Result<Self, Error> {
let NativeModule {
module,
registry,
mut metadata,
} = module;
let library_path = NamedTempFile::new()?
.into_temp_path()
.keep()
.map_err(io::Error::from)?;
let object_data = crate::module_to_object(&module, opt_level, None)?;
crate::object_to_shared_lib(&object_data, &library_path, None)?;
Ok(Self::new(
unsafe { Library::new(&library_path)? },
registry,
metadata.remove().ok_or(Error::MissingMetadata)?,
metadata.remove().unwrap_or_default(),
))
}
pub fn invoke_dynamic(
&self,
function_id: &FunctionId,
args: &[Value],
gas: Option<u64>,
) -> Result<ExecutionResult, Error> {
let available_gas = self
.gas_metadata
.get_initial_available_gas(function_id, gas)?;
super::invoke_dynamic(
&self.registry,
self.find_function_ptr(function_id)?,
self.extract_signature(function_id)?,
args,
available_gas,
Option::<DummySyscallHandler>::None,
self.build_find_dict_drop_override(),
)
}
pub fn invoke_dynamic_with_syscall_handler(
&self,
function_id: &FunctionId,
args: &[Value],
gas: Option<u64>,
syscall_handler: impl StarknetSyscallHandler,
) -> Result<ExecutionResult, Error> {
let available_gas = self
.gas_metadata
.get_initial_available_gas(function_id, gas)?;
super::invoke_dynamic(
&self.registry,
self.find_function_ptr(function_id)?,
self.extract_signature(function_id)?,
args,
available_gas,
Some(syscall_handler),
self.build_find_dict_drop_override(),
)
}
pub fn invoke_contract_dynamic(
&self,
function_id: &FunctionId,
args: &[Felt],
gas: Option<u64>,
syscall_handler: impl StarknetSyscallHandler,
) -> Result<ContractExecutionResult, Error> {
let available_gas = self
.gas_metadata
.get_initial_available_gas(function_id, gas)?;
ContractExecutionResult::from_execution_result(super::invoke_dynamic(
&self.registry,
self.find_function_ptr(function_id)?,
self.extract_signature(function_id)?,
&[Value::Struct {
fields: vec![Value::Array(
args.iter().cloned().map(Value::Felt252).collect(),
)],
debug_name: None,
}],
available_gas,
Some(syscall_handler),
self.build_find_dict_drop_override(),
)?)
}
pub fn find_function_ptr(&self, function_id: &FunctionId) -> Result<*mut c_void, Error> {
let function_name = generate_function_name(function_id, false);
let function_name = format!("_mlir_ciface_{function_name}");
unsafe {
Ok(self
.library
.get::<extern "C" fn()>(function_name.as_bytes())?
.into_raw()
.into_raw())
}
}
pub fn find_symbol_ptr(&self, name: &str) -> Option<*mut c_void> {
unsafe {
self.library
.get::<*mut ()>(name.as_bytes())
.ok()
.map(|x| x.into_raw().into_raw())
}
}
fn extract_signature(&self, function_id: &FunctionId) -> Result<&FunctionSignature, Error> {
Ok(&self.registry.get_function(function_id)?.signature)
}
fn build_find_dict_drop_override(
&self,
) -> impl '_ + Copy + Fn(&ConcreteTypeId) -> Option<extern "C" fn(*mut c_void)> {
|type_id| {
self.dict_overrides
.get_drop_fn(type_id)
.and_then(|symbol| self.find_symbol_ptr(symbol))
.map(|ptr| unsafe { transmute(ptr as *const ()) })
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        context::NativeContext, include_contract, starknet_stub::StubSyscallHandler,
        utils::testing::load_program,
    };
    use cairo_lang_sierra::program::Program;
    use rstest::*;

    #[fixture]
    fn program() -> Program {
        load_program("test_data_artifacts/programs/executor_aot")
    }

    #[fixture]
    fn starknet_program() -> Program {
        include_contract!("test_data_artifacts/contracts/simple_storage_42.contract.json")
            .extract_sierra_program(true)
            .unwrap()
            .program
    }

    /// Compiles `program` inside `context` and loads the result as an AOT executor built at
    /// `opt_level`. The caller owns `context`, so it stays alive for the whole test.
    fn compile_executor(
        context: &NativeContext,
        program: &Program,
        opt_level: OptLevel,
    ) -> AotNativeExecutor {
        let native_module = context
            .compile(program, false, Some(Default::default()), None)
            .expect("failed to compile context");
        AotNativeExecutor::from_native_module(native_module, opt_level).unwrap()
    }

    #[rstest]
    #[case(OptLevel::None)]
    #[case(OptLevel::Default)]
    #[case(OptLevel::Aggressive)]
    fn test_invoke_dynamic(program: Program, #[case] optlevel: OptLevel) {
        let context = NativeContext::new();
        let executor = compile_executor(&context, &program, optlevel);

        let entrypoint = program.funcs.first().expect("should have a function");
        let outcome = executor
            .invoke_dynamic(&entrypoint.id, &[], Some(u64::MAX))
            .unwrap();

        assert_eq!(outcome.return_value, Value::Felt252(Felt::from(42)));
    }

    #[rstest]
    #[case(OptLevel::None)]
    #[case(OptLevel::Default)]
    #[case(OptLevel::Aggressive)]
    fn test_invoke_dynamic_with_syscall_handler(program: Program, #[case] optlevel: OptLevel) {
        let context = NativeContext::new();
        let executor = compile_executor(&context, &program, optlevel);

        let entrypoint = program.funcs.get(1).expect("should have a function");

        // Seed the stub so block 1 resolves to a known hash.
        let block_hash = Felt::from(123);
        let mut stub = StubSyscallHandler::default();
        stub.block_hash.insert(1, block_hash);

        let outcome = executor
            .invoke_dynamic_with_syscall_handler(&entrypoint.id, &[], Some(u64::MAX), &mut stub)
            .unwrap();

        let inner = Value::Struct {
            fields: vec![Value::Felt252(block_hash)],
            debug_name: Some("Tuple<felt252>".into()),
        };
        let expected = Value::Enum {
            tag: 0,
            value: inner.into(),
            debug_name: Some("core::panics::PanicResult::<(core::felt252,)>".into()),
        };
        assert_eq!(outcome.return_value, expected);
    }

    #[rstest]
    #[case(OptLevel::None)]
    #[case(OptLevel::Default)]
    #[case(OptLevel::Aggressive)]
    fn test_invoke_contract_dynamic(starknet_program: Program, #[case] optlevel: OptLevel) {
        let context = NativeContext::new();
        let executor = compile_executor(&context, &starknet_program, optlevel);

        let entrypoint = starknet_program
            .funcs
            .iter()
            .find(|func| {
                matches!(
                    &func.id.debug_name,
                    Some(name) if name.contains("__wrapper__ISimpleStorageImpl__get")
                )
            })
            .expect("should have a function");

        let mut stub = StubSyscallHandler::default();
        let outcome = executor
            .invoke_contract_dynamic(&entrypoint.id, &[], Some(u64::MAX), &mut stub)
            .unwrap();

        assert_eq!(outcome.return_values, vec![Felt::from(42)]);
    }
}