use std::cell::UnsafeCell;
use std::collections::HashMap;
use std::sync::Arc;
use relon_eval_api::{NativeArgs, NativeFnCaps, RelonFunction, RuntimeError, Value};
use relon_parser::TokenRange;
/// Runtime state block shared between the Rust host and AOT-compiled native code.
///
/// `#[repr(C)]` fixes the field layout so that each field can be addressed through the
/// matching `ARENA_STATE_OFFSET_*` constant (the `arena_state_offsets_match_repr_c_layout`
/// test checks that the two agree). Every field is an `UnsafeCell` because it is written
/// through `&self`, both by the methods on this type and by `relon_llvm_call_native`.
/// Presumably generated code also reads and writes these fields via raw pointers — TODO
/// confirm against the codegen.
#[repr(C)]
pub struct ArenaState {
    /// Start address of the arena buffer (`arena.as_mut_ptr()` in `new`).
    pub arena_base: UnsafeCell<usize>,
    /// Arena length in bytes, as recorded by `new`.
    pub arena_len: UnsafeCell<u32>,
    /// Initialised to 0 by `new`; presumably an allocation offset into the arena that
    /// native code advances — TODO confirm.
    pub tail_cursor: UnsafeCell<u32>,
    /// Initialised to 0 by `new`; presumably the current offset within the scratch
    /// region — TODO confirm.
    pub scratch_cursor: UnsafeCell<u32>,
    /// The `scratch_base` value passed to `new`. NOTE(review): looks like the start offset
    /// of the scratch region inside the arena — confirm.
    pub scratch_base: UnsafeCell<u32>,
    /// Last recorded `NativeTrap` code; 0 means no trap has been recorded.
    pub trap_code: UnsafeCell<u64>,
    /// Address of the installed `HostFnRegistry`, or 0 when none is installed.
    pub host_fns: UnsafeCell<usize>,
    /// Budget stored by `set_step_budget` (0 after `new`); presumably consumed by native
    /// code to enforce a step limit — TODO confirm.
    pub step_budget: UnsafeCell<i64>,
}
/// Offset of the `#[repr(C)]` field that follows a field at `prev_offset` spanning
/// `prev_size` bytes, where the next field has alignment `next_align` (a power of two).
///
/// This mirrors the C layout rule: the next field starts at the first multiple of its
/// alignment at or after the end of the previous field.
const fn arena_state_next_offset(prev_offset: u32, prev_size: usize, next_align: usize) -> u32 {
    let end = prev_offset as usize + prev_size;
    ((end + next_align - 1) & !(next_align - 1)) as u32
}

/// Byte offset of `ArenaState::arena_base` (first field).
pub const ARENA_STATE_OFFSET_BASE: u32 = 0;
/// Byte offset of `ArenaState::arena_len`.
pub const ARENA_STATE_OFFSET_LEN: u32 = arena_state_next_offset(
    ARENA_STATE_OFFSET_BASE,
    std::mem::size_of::<usize>(),
    std::mem::align_of::<u32>(),
);
/// Byte offset of `ArenaState::tail_cursor`.
pub const ARENA_STATE_OFFSET_TAIL_CURSOR: u32 = arena_state_next_offset(
    ARENA_STATE_OFFSET_LEN,
    std::mem::size_of::<u32>(),
    std::mem::align_of::<u32>(),
);
/// Byte offset of `ArenaState::scratch_cursor`.
pub const ARENA_STATE_OFFSET_SCRATCH_CURSOR: u32 = arena_state_next_offset(
    ARENA_STATE_OFFSET_TAIL_CURSOR,
    std::mem::size_of::<u32>(),
    std::mem::align_of::<u32>(),
);
/// Byte offset of `ArenaState::scratch_base`.
pub const ARENA_STATE_OFFSET_SCRATCH_BASE: u32 = arena_state_next_offset(
    ARENA_STATE_OFFSET_SCRATCH_CURSOR,
    std::mem::size_of::<u32>(),
    std::mem::align_of::<u32>(),
);
/// Byte offset of `ArenaState::trap_code`.
///
/// This offset is derived rather than hard-coded. Padding before a `u64` depends on the
/// target's `align_of::<u64>()`, so a literal `24` only matches 64-bit layouts.
pub const ARENA_STATE_OFFSET_TRAP_CODE: u32 = arena_state_next_offset(
    ARENA_STATE_OFFSET_SCRATCH_BASE,
    std::mem::size_of::<u32>(),
    std::mem::align_of::<u64>(),
);
/// Byte offset of `ArenaState::host_fns`.
// NOTE(review): the allow was present originally; this offset may be unused by codegen in
// some builds — confirm before removing.
#[allow(dead_code)]
pub const ARENA_STATE_OFFSET_HOST_FNS: u32 = arena_state_next_offset(
    ARENA_STATE_OFFSET_TRAP_CODE,
    std::mem::size_of::<u64>(),
    std::mem::align_of::<usize>(),
);
/// Byte offset of `ArenaState::step_budget`.
///
/// This accounts for the `i64` alignment. On 32-bit targets the end of `host_fns` is not
/// necessarily 8-byte aligned.
pub const ARENA_STATE_OFFSET_STEP_BUDGET: u32 = arena_state_next_offset(
    ARENA_STATE_OFFSET_HOST_FNS,
    std::mem::size_of::<usize>(),
    std::mem::align_of::<i64>(),
);
/// Trap codes stored in `ArenaState::trap_code` when a native operation fails.
///
/// `NativeTrap::runtime_error_from_code` lifts each code back into a `RuntimeError`.
/// Discriminants are part of the native ABI and must not change.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u64)]
pub enum NativeTrap {
    /// Lifted to `RuntimeError::DivisionByZero`.
    DivisionByZero = 1,
    /// Lifted to `RuntimeError::IndexOutOfBounds`.
    BoundsViolation = 2,
    /// Lifted to `RuntimeError::CapabilityDenied`.
    CapabilityDenied = 3,
    /// Lifted to `RuntimeError::StepLimitExceeded`.
    ResourceExhausted = 4,
    /// Recorded by `relon_llvm_call_native` when no registry is installed or nothing is
    /// registered at the import index; lifted to `RuntimeError::Unsupported`.
    HostFnMissing = 5,
    /// Lifted to `RuntimeError::NumericOverflow`.
    NumericOverflow = 6,
    /// Recorded by `relon_llvm_call_native` when the host fn fails or returns a value that
    /// cannot be passed back as an `i64`; lifted to `RuntimeError::Unsupported`.
    HostFnError = 7,
    /// Lifted to `RuntimeError::TypeMismatch` ("no matching arm").
    NoMatch = 8,
}
impl NativeTrap {
    /// Decodes a raw trap code (as stored in `ArenaState::trap_code`) into a `NativeTrap`.
    ///
    /// Returns `None` for `0` and for any value that is not one of this enum's
    /// discriminants. Comparing against `variant as u64` keeps this decoder tied to the
    /// enum definition instead of a second copy of the numbers.
    pub fn from_code(code: u64) -> Option<Self> {
        [
            Self::DivisionByZero,
            Self::BoundsViolation,
            Self::CapabilityDenied,
            Self::ResourceExhausted,
            Self::HostFnMissing,
            Self::NumericOverflow,
            Self::HostFnError,
            Self::NoMatch,
        ]
        .iter()
        .copied()
        .find(|&trap| trap as u64 == code)
    }

    /// Lifts a raw trap code into the `RuntimeError` reported to the caller.
    ///
    /// Mapping:
    /// - `DivisionByZero` → `RuntimeError::DivisionByZero`
    /// - `BoundsViolation` → `RuntimeError::IndexOutOfBounds`
    /// - `CapabilityDenied` → `RuntimeError::CapabilityDenied` (no cap bit)
    /// - `ResourceExhausted` → `RuntimeError::StepLimitExceeded` (no limit)
    /// - `NumericOverflow` → `RuntimeError::NumericOverflow`
    /// - `NoMatch` → `RuntimeError::TypeMismatch`
    /// - `HostFnMissing`, `HostFnError`, and any unknown code (including 0) →
    ///   `RuntimeError::Unsupported`
    ///
    /// The code alone carries no source position, so every error uses
    /// `TokenRange::default()`. The match is over enum variants and is exhaustive, so
    /// adding a variant forces an explicit mapping decision here.
    pub fn runtime_error_from_code(code: u64) -> RuntimeError {
        match Self::from_code(code) {
            Some(Self::DivisionByZero) => RuntimeError::DivisionByZero(TokenRange::default()),
            Some(Self::BoundsViolation) => RuntimeError::IndexOutOfBounds {
                range: TokenRange::default(),
            },
            Some(Self::CapabilityDenied) => RuntimeError::CapabilityDenied {
                cap_bit: None,
                reason: "llvm-aot: host-fn call denied by capability gate".to_string(),
                range: TokenRange::default(),
            },
            Some(Self::ResourceExhausted) => RuntimeError::StepLimitExceeded {
                limit: None,
                range: TokenRange::default(),
            },
            Some(Self::NumericOverflow) => RuntimeError::NumericOverflow(TokenRange::default()),
            Some(Self::NoMatch) => RuntimeError::TypeMismatch {
                expected: "a matching arm".to_string(),
                found: "no matching arm".to_string(),
                range: TokenRange::default(),
            },
            Some(Self::HostFnMissing | Self::HostFnError) | None => RuntimeError::Unsupported {
                reason: "llvm-aot: native-fn dispatch failed (host fn missing / errored / \
                         returned a non-scalar value)"
                    .to_string(),
            },
        }
    }
}
impl ArenaState {
    /// Builds a fresh state block describing `arena`.
    ///
    /// Only the buffer's start address and length are stored, as plain integers. The
    /// returned value therefore does not borrow `arena`. Anything that later dereferences
    /// `arena_base` relies on the caller keeping the buffer alive and in place.
    ///
    /// `arena_len` is a `u32`. Buffers longer than `u32::MAX` bytes are recorded as
    /// `u32::MAX` instead of being wrapped by a truncating cast. The recorded length then
    /// never exceeds the real one and is never spuriously small.
    ///
    /// Both cursors, the trap code, the host-fn slot and the step budget start at 0.
    pub fn new(arena: &mut [u8], scratch_base: u32) -> Self {
        // `arena.len() as u32` would wrap (4 GiB + 16 bytes -> 16); saturate instead.
        let arena_len = u32::try_from(arena.len()).unwrap_or(u32::MAX);
        Self {
            arena_base: UnsafeCell::new(arena.as_mut_ptr() as usize),
            arena_len: UnsafeCell::new(arena_len),
            tail_cursor: UnsafeCell::new(0),
            scratch_cursor: UnsafeCell::new(0),
            scratch_base: UnsafeCell::new(scratch_base),
            trap_code: UnsafeCell::new(0),
            host_fns: UnsafeCell::new(0),
            step_budget: UnsafeCell::new(0),
        }
    }

    /// Stores `budget` in the `step_budget` slot.
    pub fn set_step_budget(&self, budget: i64) {
        // SAFETY: `ArenaState` holds `UnsafeCell`s and is therefore `!Sync`, so no other
        // thread can reach this cell, and no reference into it outlives this statement.
        unsafe {
            *self.step_budget.get() = budget;
        }
    }

    /// Records `registry` as the host-function table consulted by `relon_llvm_call_native`.
    ///
    /// # Safety
    ///
    /// `registry` must either be null (no host functions installed) or point to a
    /// `HostFnRegistry` that stays valid, and is not mutated, for as long as
    /// `relon_llvm_call_native` may be called with this state.
    pub unsafe fn install_host_fns(&self, registry: *const HostFnRegistry) {
        // SAFETY: same single-threaded cell write as in `set_step_budget`. The validity of
        // `registry` is the caller's obligation and only matters once it is dereferenced.
        unsafe {
            *self.host_fns.get() = registry as usize;
        }
    }

    /// Returns the current trap code (0 if nothing has been recorded).
    pub fn trap_code(&self) -> u64 {
        // SAFETY: plain copy out of a `!Sync` cell; no reference escapes.
        unsafe { *self.trap_code.get() }
    }

    /// Returns the current value of the `tail_cursor` slot.
    // Kept as a public inspection hook even though some builds never call it.
    #[allow(dead_code)]
    pub fn tail_cursor(&self) -> u32 {
        // SAFETY: plain copy out of a `!Sync` cell; no reference escapes.
        unsafe { *self.tail_cursor.get() }
    }
}
/// Table of host functions callable from AOT-compiled code, keyed by import index.
#[derive(Clone, Default)]
pub struct HostFnRegistry {
    host_fns: HashMap<u32, Arc<dyn RelonFunction>>,
}

impl std::fmt::Debug for HostFnRegistry {
    /// Prints only how many functions are registered; the functions themselves are opaque.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let count = self.host_fns.len();
        f.debug_struct("HostFnRegistry")
            .field("host_fn_count", &count)
            .finish()
    }
}

impl HostFnRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers `func` under `import_idx`, replacing whatever was stored there before.
    pub fn register(&mut self, import_idx: u32, func: Arc<dyn RelonFunction>) {
        self.host_fns.insert(import_idx, func);
    }

    /// Looks up the function registered under `import_idx`, if any.
    pub fn resolve(&self, import_idx: u32) -> Option<&Arc<dyn RelonFunction>> {
        let key = import_idx;
        self.host_fns.get(&key)
    }

    /// Number of registered functions.
    pub fn len(&self) -> usize {
        self.host_fns.len()
    }

    /// `true` when no function has been registered.
    pub fn is_empty(&self) -> bool {
        self.host_fns.is_empty()
    }
}
/// Capability object handed (through `llvm_native_caps`) to host functions invoked by
/// `relon_llvm_call_native`.
///
/// Calling back into Relon is not supported on this path. Every `call_relon` request
/// fails with `RuntimeError::Unsupported`.
struct LlvmNativeFnCaps;

impl NativeFnCaps for LlvmNativeFnCaps {
    fn call_relon(
        &self,
        _func: &Value,
        _args: Vec<Value>,
        _range: TokenRange,
    ) -> Result<Value, RuntimeError> {
        let reason = String::from("llvm-aot host fn: call_relon callback unsupported");
        Err(RuntimeError::Unsupported { reason })
    }
}
/// Returns a handle to the process-wide `LlvmNativeFnCaps` instance.
///
/// The instance is created lazily on first use and shared afterwards, so each call only
/// bumps a reference count.
fn llvm_native_caps() -> Arc<dyn NativeFnCaps> {
    static SHARED_CAPS: std::sync::OnceLock<Arc<dyn NativeFnCaps>> = std::sync::OnceLock::new();
    let shared = SHARED_CAPS.get_or_init(|| {
        let caps: Arc<dyn NativeFnCaps> = Arc::new(LlvmNativeFnCaps);
        caps
    });
    Arc::clone(shared)
}
/// Symbol name of [`relon_llvm_call_native`].
pub const RELON_LLVM_CALL_NATIVE_SYMBOL: &str = "relon_llvm_call_native";

/// Trampoline through which AOT-compiled code calls a registered host function.
///
/// The function looks up `import_idx` in the [`HostFnRegistry`] installed on `state`. It
/// wraps the `arg_count` integers at `args_ptr` as positional `Value::Int` arguments,
/// calls the host function, and narrows the result to an `i64`:
///
/// - `Value::Int(v)` → `v`
/// - `Value::Bool(b)` → `0` / `1`
/// - an option-`None` value → `0`
///
/// On failure it returns `0` and records a [`NativeTrap`] in `state.trap_code`. A
/// successful call never clears that slot. Failures map as follows:
///
/// - no registry installed, or nothing registered at `import_idx` → `HostFnMissing`
/// - `args_ptr` is null while `arg_count > 0` → `HostFnError`
/// - the host fn returns an error or any other kind of value → `HostFnError`
/// - the host fn panics → `HostFnError`
///
/// Panics raised by the host function are caught so they never unwind across the
/// `extern "C"` boundary.
///
/// # Safety
///
/// - `state` must be non-null and point to a live `ArenaState`.
/// - Any registry installed through `ArenaState::install_host_fns` must still be valid.
/// - When `arg_count > 0` and `args_ptr` is non-null, `args_ptr` must point to
///   `arg_count` readable, initialized `i64` values.
pub unsafe extern "C" fn relon_llvm_call_native(
    state: *const ArenaState,
    import_idx: u32,
    args_ptr: *const i64,
    arg_count: u32,
) -> i64 {
    // SAFETY: the caller guarantees `state` points to a live `ArenaState`.
    let state = unsafe { &*state };
    // SAFETY: plain copy of a `usize` out of a `!Sync` cell; no reference escapes.
    let registry_ptr = unsafe { *state.host_fns.get() } as *const HostFnRegistry;
    let record_trap = |code: NativeTrap| {
        // SAFETY: `ArenaState` is `!Sync` and no other reference into this cell is live.
        unsafe {
            *state.trap_code.get() = code as u64;
        }
    };

    if registry_ptr.is_null() {
        record_trap(NativeTrap::HostFnMissing);
        return 0;
    }
    // SAFETY: a non-null slot was set through `install_host_fns`, whose contract keeps
    // the registry valid and unmutated while native calls can happen.
    let registry = unsafe { &*registry_ptr };
    // Borrowing is enough: the registry outlives this call, so no `Arc` clone is needed.
    let Some(func) = registry.resolve(import_idx) else {
        record_trap(NativeTrap::HostFnMissing);
        return 0;
    };

    let args_slice: &[i64] = match arg_count {
        0 => &[][..],
        _ if args_ptr.is_null() => {
            // `from_raw_parts` on a null pointer is UB even for a "valid" count; fail soft.
            record_trap(NativeTrap::HostFnError);
            return 0;
        }
        // SAFETY: non-null, and the caller guarantees `n` readable, initialized `i64`s.
        n => unsafe { std::slice::from_raw_parts(args_ptr, n as usize) },
    };
    let packed: Vec<Value> = args_slice.iter().copied().map(Value::Int).collect();
    let native_args = NativeArgs::from_positional(packed, llvm_native_caps());

    // A panic must not unwind through `extern "C"`: that is UB on older toolchains and a
    // process abort on newer ones. Surface it as a trap instead.
    let outcome = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
        func.call(native_args, TokenRange::default())
    }));
    match outcome {
        Ok(Ok(Value::Int(v))) => v,
        Ok(Ok(Value::Bool(b))) => i64::from(b),
        Ok(Ok(v)) if v.is_option_none() => 0,
        // Host error, a value that is not representable as a scalar, or a caught panic.
        _ => {
            record_trap(NativeTrap::HostFnError);
            0
        }
    }
}
#[inline]
pub fn relon_llvm_call_native_addr() -> usize {
relon_llvm_call_native as *const () as usize
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The exported offsets must agree with the compiler's actual `#[repr(C)]` layout.
    #[test]
    fn arena_state_offsets_match_repr_c_layout() {
        let mut buf = [0u8; 16];
        let state = ArenaState::new(&mut buf, 16);
        let base = &state as *const _ as usize;
        assert_eq!(
            (state.arena_base.get() as usize) - base,
            ARENA_STATE_OFFSET_BASE as usize
        );
        assert_eq!(
            (state.arena_len.get() as usize) - base,
            ARENA_STATE_OFFSET_LEN as usize
        );
        assert_eq!(
            (state.tail_cursor.get() as usize) - base,
            ARENA_STATE_OFFSET_TAIL_CURSOR as usize
        );
        assert_eq!(
            (state.scratch_cursor.get() as usize) - base,
            ARENA_STATE_OFFSET_SCRATCH_CURSOR as usize
        );
        assert_eq!(
            (state.scratch_base.get() as usize) - base,
            ARENA_STATE_OFFSET_SCRATCH_BASE as usize
        );
        assert_eq!(
            (state.trap_code.get() as usize) - base,
            ARENA_STATE_OFFSET_TRAP_CODE as usize
        );
        assert_eq!(
            (state.host_fns.get() as usize) - base,
            ARENA_STATE_OFFSET_HOST_FNS as usize
        );
        assert_eq!(
            (state.step_budget.get() as usize) - base,
            ARENA_STATE_OFFSET_STEP_BUDGET as usize
        );
    }

    /// Host fn that returns its first positional `Int` plus one, and errors otherwise.
    struct AddOne;
    impl RelonFunction for AddOne {
        fn call(&self, args: NativeArgs, _r: TokenRange) -> Result<Value, RuntimeError> {
            match args.positional.first() {
                Some(Value::Int(x)) => Ok(Value::Int(x + 1)),
                _ => Err(RuntimeError::Unsupported {
                    reason: "AddOne expects Int".into(),
                }),
            }
        }
    }

    /// Host fn that always returns `Bool(true)`.
    struct AlwaysTrue;
    impl RelonFunction for AlwaysTrue {
        fn call(&self, _args: NativeArgs, _r: TokenRange) -> Result<Value, RuntimeError> {
            Ok(Value::Bool(true))
        }
    }

    #[test]
    fn call_native_helper_dispatches_registered_fn() {
        let mut registry = HostFnRegistry::new();
        registry.register(0, Arc::new(AddOne));
        let mut buf = [0u8; 16];
        let state = ArenaState::new(&mut buf, 16);
        unsafe { state.install_host_fns(&registry as *const _) };
        let args = [41i64];
        let r = unsafe { relon_llvm_call_native(&state as *const _, 0, args.as_ptr(), 1) };
        assert_eq!(r, 42);
        assert_eq!(state.trap_code(), 0);
    }

    #[test]
    fn call_native_helper_narrows_bool_to_one() {
        let mut registry = HostFnRegistry::new();
        registry.register(3, Arc::new(AlwaysTrue));
        let mut buf = [0u8; 16];
        let state = ArenaState::new(&mut buf, 16);
        unsafe { state.install_host_fns(&registry as *const _) };
        let r = unsafe { relon_llvm_call_native(&state as *const _, 3, std::ptr::null(), 0) };
        assert_eq!(r, 1);
        assert_eq!(state.trap_code(), 0);
    }

    #[test]
    fn call_native_helper_traps_when_host_fn_errors() {
        let mut registry = HostFnRegistry::new();
        registry.register(0, Arc::new(AddOne));
        let mut buf = [0u8; 16];
        let state = ArenaState::new(&mut buf, 16);
        unsafe { state.install_host_fns(&registry as *const _) };
        // No arguments: AddOne returns an error.
        let r = unsafe { relon_llvm_call_native(&state as *const _, 0, std::ptr::null(), 0) };
        assert_eq!(r, 0);
        assert_eq!(state.trap_code(), NativeTrap::HostFnError as u64);
    }

    #[test]
    fn call_native_helper_traps_when_unregistered() {
        let registry = HostFnRegistry::new();
        let mut buf = [0u8; 16];
        let state = ArenaState::new(&mut buf, 16);
        unsafe { state.install_host_fns(&registry as *const _) };
        let r = unsafe { relon_llvm_call_native(&state as *const _, 7, std::ptr::null(), 0) };
        assert_eq!(r, 0);
        assert_eq!(state.trap_code(), NativeTrap::HostFnMissing as u64);
    }

    #[test]
    fn call_native_helper_traps_when_no_registry() {
        let mut buf = [0u8; 16];
        let state = ArenaState::new(&mut buf, 16);
        let r = unsafe { relon_llvm_call_native(&state as *const _, 0, std::ptr::null(), 0) };
        assert_eq!(r, 0);
        assert_eq!(state.trap_code(), NativeTrap::HostFnMissing as u64);
    }

    #[test]
    fn native_trap_bounds_code_lifts_to_index_out_of_bounds() {
        assert!(matches!(
            NativeTrap::runtime_error_from_code(NativeTrap::DivisionByZero as u64),
            RuntimeError::DivisionByZero(_)
        ));
        assert!(matches!(
            NativeTrap::runtime_error_from_code(NativeTrap::BoundsViolation as u64),
            RuntimeError::IndexOutOfBounds { .. }
        ));
        assert!(matches!(
            NativeTrap::runtime_error_from_code(NativeTrap::ResourceExhausted as u64),
            RuntimeError::StepLimitExceeded { .. }
        ));
    }

    #[test]
    fn native_trap_remaining_codes_lift_to_expected_errors() {
        assert!(matches!(
            NativeTrap::runtime_error_from_code(NativeTrap::CapabilityDenied as u64),
            RuntimeError::CapabilityDenied { .. }
        ));
        assert!(matches!(
            NativeTrap::runtime_error_from_code(NativeTrap::NumericOverflow as u64),
            RuntimeError::NumericOverflow(_)
        ));
        assert!(matches!(
            NativeTrap::runtime_error_from_code(NativeTrap::NoMatch as u64),
            RuntimeError::TypeMismatch { .. }
        ));
        // Host-fn failures and unknown codes (including "no trap") all surface as Unsupported.
        for code in [
            0,
            NativeTrap::HostFnMissing as u64,
            NativeTrap::HostFnError as u64,
            99,
        ] {
            assert!(matches!(
                NativeTrap::runtime_error_from_code(code),
                RuntimeError::Unsupported { .. }
            ));
        }
    }
}