use crate::{
error::Error,
execution_result::{ContractExecutionResult, ExecutionResult},
metadata::{gas::GasMetadata, runtime_bindings::setup_runtime},
module::NativeModule,
starknet::{DummySyscallHandler, StarknetSyscallHandler},
utils::generate_function_name,
values::Value,
OptLevel,
};
use cairo_lang_sierra::{
extensions::core::{CoreLibfunc, CoreType},
ids::FunctionId,
program::FunctionSignature,
program_registry::ProgramRegistry,
};
use educe::Educe;
use libc::c_void;
use libloading::Library;
use starknet_types_core::felt::Felt;
use std::io;
use tempfile::NamedTempFile;
#[derive(Educe)]
#[educe(Debug)]
pub struct AotNativeExecutor {
#[educe(Debug(ignore))]
library: Library,
#[educe(Debug(ignore))]
registry: ProgramRegistry<CoreType, CoreLibfunc>,
gas_metadata: GasMetadata,
}
unsafe impl Send for AotNativeExecutor {}
unsafe impl Sync for AotNativeExecutor {}
impl AotNativeExecutor {
pub fn new(
library: Library,
registry: ProgramRegistry<CoreType, CoreLibfunc>,
gas_metadata: GasMetadata,
) -> Self {
let executor = Self {
library,
registry,
gas_metadata,
};
setup_runtime(|name| executor.find_symbol_ptr(name));
#[cfg(feature = "with-debug-utils")]
crate::metadata::debug_utils::setup_runtime(|name| executor.find_symbol_ptr(name));
#[cfg(feature = "with-trace-dump")]
crate::metadata::trace_dump::setup_runtime(|name| executor.find_symbol_ptr(name));
#[cfg(feature = "with-libfunc-profiling")]
crate::metadata::profiler::setup_runtime(|name| executor.find_symbol_ptr(name));
executor
}
pub fn from_native_module(module: NativeModule, opt_level: OptLevel) -> Result<Self, Error> {
let NativeModule {
module,
registry,
mut metadata,
} = module;
let library_path = NamedTempFile::new()?
.into_temp_path()
.keep()
.map_err(io::Error::from)?;
let object_data = crate::module_to_object(&module, opt_level, None)?;
crate::object_to_shared_lib(&object_data, &library_path, None)?;
Ok(Self::new(
unsafe { Library::new(&library_path)? },
registry,
metadata.remove().ok_or(Error::MissingMetadata)?,
))
}
pub fn invoke_dynamic(
&self,
function_id: &FunctionId,
args: &[Value],
gas: Option<u64>,
) -> Result<ExecutionResult, Error> {
let available_gas = self
.gas_metadata
.get_initial_available_gas(function_id, gas)?;
super::invoke_dynamic(
&self.registry,
self.find_function_ptr(function_id)?,
self.extract_signature(function_id)?,
args,
available_gas,
Option::<DummySyscallHandler>::None,
)
}
pub fn invoke_dynamic_with_syscall_handler(
&self,
function_id: &FunctionId,
args: &[Value],
gas: Option<u64>,
syscall_handler: impl StarknetSyscallHandler,
) -> Result<ExecutionResult, Error> {
let available_gas = self
.gas_metadata
.get_initial_available_gas(function_id, gas)?;
super::invoke_dynamic(
&self.registry,
self.find_function_ptr(function_id)?,
self.extract_signature(function_id)?,
args,
available_gas,
Some(syscall_handler),
)
}
pub fn invoke_contract_dynamic(
&self,
function_id: &FunctionId,
args: &[Felt],
gas: Option<u64>,
syscall_handler: impl StarknetSyscallHandler,
) -> Result<ContractExecutionResult, Error> {
let available_gas = self
.gas_metadata
.get_initial_available_gas(function_id, gas)?;
ContractExecutionResult::from_execution_result(super::invoke_dynamic(
&self.registry,
self.find_function_ptr(function_id)?,
self.extract_signature(function_id)?,
&[Value::Struct {
fields: vec![Value::Array(
args.iter().cloned().map(Value::Felt252).collect(),
)],
debug_name: None,
}],
available_gas,
Some(syscall_handler),
)?)
}
pub fn find_function_ptr(&self, function_id: &FunctionId) -> Result<*mut c_void, Error> {
let function_name = generate_function_name(function_id, false);
let function_name = format!("_mlir_ciface_{function_name}");
unsafe {
Ok(self
.library
.get::<extern "C" fn()>(function_name.as_bytes())?
.into_raw()
.into_raw())
}
}
pub fn find_symbol_ptr(&self, name: &str) -> Option<*mut c_void> {
unsafe {
self.library
.get::<*mut ()>(name.as_bytes())
.ok()
.map(|x| x.into_raw().into_raw())
}
}
fn extract_signature(&self, function_id: &FunctionId) -> Result<&FunctionSignature, Error> {
Ok(&self.registry.get_function(function_id)?.signature)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{
context::NativeContext, include_contract, starknet_stub::StubSyscallHandler,
utils::testing::load_program,
};
use cairo_lang_sierra::program::Program;
use rstest::*;
#[fixture]
fn program() -> Program {
load_program("test_data_artifacts/programs/executor_aot")
}
#[fixture]
fn starknet_program() -> Program {
include_contract!("test_data_artifacts/contracts/simple_storage_42.contract.json")
.extract_sierra_program(true)
.unwrap()
.program
}
#[rstest]
#[case(OptLevel::None)]
#[case(OptLevel::Default)]
#[case(OptLevel::Aggressive)]
fn test_invoke_dynamic(program: Program, #[case] optlevel: OptLevel) {
let native_context = NativeContext::new();
let module = native_context
.compile(&program, false, Some(Default::default()), None)
.expect("failed to compile context");
let executor = AotNativeExecutor::from_native_module(module, optlevel).unwrap();
let entrypoint_function_id = &program.funcs.first().expect("should have a function").id;
let result = executor
.invoke_dynamic(entrypoint_function_id, &[], Some(u64::MAX))
.unwrap();
assert_eq!(result.return_value, Value::Felt252(Felt::from(42)));
}
#[rstest]
#[case(OptLevel::None)]
#[case(OptLevel::Default)]
#[case(OptLevel::Aggressive)]
fn test_invoke_dynamic_with_syscall_handler(program: Program, #[case] optlevel: OptLevel) {
let native_context = NativeContext::new();
let module = native_context
.compile(&program, false, Some(Default::default()), None)
.expect("failed to compile context");
let executor = AotNativeExecutor::from_native_module(module, optlevel).unwrap();
let entrypoint_function_id = &program.funcs.get(1).expect("should have a function").id;
let syscall_handler = &mut StubSyscallHandler::default();
let expected_value = Felt::from(123);
syscall_handler.block_hash.insert(1, expected_value);
let result = executor
.invoke_dynamic_with_syscall_handler(
entrypoint_function_id,
&[],
Some(u64::MAX),
syscall_handler,
)
.unwrap();
let expected_value = Value::Enum {
tag: 0,
value: Value::Struct {
fields: vec![Value::Felt252(expected_value)],
debug_name: Some("Tuple<felt252>".into()),
}
.into(),
debug_name: Some("core::panics::PanicResult::<(core::felt252,)>".into()),
};
assert_eq!(result.return_value, expected_value);
}
#[rstest]
#[case(OptLevel::None)]
#[case(OptLevel::Default)]
#[case(OptLevel::Aggressive)]
fn test_invoke_contract_dynamic(starknet_program: Program, #[case] optlevel: OptLevel) {
let native_context = NativeContext::new();
let module = native_context
.compile(&starknet_program, false, Some(Default::default()), None)
.expect("failed to compile context");
let executor = AotNativeExecutor::from_native_module(module, optlevel).unwrap();
let entrypoint_function_id = &starknet_program
.funcs
.iter()
.find(|f| {
f.id.debug_name
.as_ref()
.map(|name| name.contains("__wrapper__ISimpleStorageImpl__get"))
.unwrap_or_default()
})
.expect("should have a function")
.id;
let result = executor
.invoke_contract_dynamic(
entrypoint_function_id,
&[],
Some(u64::MAX),
&mut StubSyscallHandler::default(),
)
.unwrap();
assert_eq!(result.return_values, vec![Felt::from(42)]);
}
}