use proc_macro2::TokenStream;
use quote::{quote, format_ident};
use syn::{Expr, ExprCall, ExprMethodCall};
use std::collections::HashMap;
use super::CompileError;
/// How a native call site was written in the source being compiled.
///
/// Assigned syntactically when a call is registered: method calls are always `Method`;
/// for plain calls `NativeCallCollector::determine_call_kind` inspects the callee.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CallKind {
    /// A call through a single-segment path, e.g. `foo(x)`.
    Function,
    /// A method call on a receiver, e.g. `v.push(x)`.
    Method,
    /// A call through a multi-segment path, e.g. `Vec::new()`. This also matches
    /// module-qualified free functions such as `std::mem::swap(a, b)`.
    AssociatedFn,
    /// Any callee that is not a path, e.g. `(self.f)(x)` or `make()(x)`.
    Closure,
}
/// One deduplicated native call site recorded by `NativeCallCollector`.
// `arg_count`, `call_kind` and `call_key` are written but never read in this file;
// the allowance silences the resulting dead-code lint.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct NativeCallInfo {
    /// Slot of this call in the native table; also names its wrapper (`__native_<index>`).
    pub index: usize,
    /// The original call or method-call expression.
    pub call_expr: Expr,
    /// Number of `u64` slots the wrapper reads (for method calls this includes the receiver).
    pub arg_count: usize,
    /// Syntactic classification of the call site.
    pub call_kind: CallKind,
    /// Deduplication key under which this call was registered.
    pub call_key: String,
}
/// Accumulates native call sites registered through `register_call` / `register_method`,
/// deduplicates them, and generates the wrapper functions and dispatch table for them.
pub struct NativeCallCollector {
    /// Registered calls in registration order; `calls[i].index == i` always holds.
    calls: Vec<NativeCallInfo>,
    /// Maps each call's deduplication key to its position in `calls`.
    call_map: HashMap<String, usize>,
}
impl NativeCallCollector {
pub fn new() -> Self {
Self {
calls: Vec::new(),
call_map: HashMap::new(),
}
}
pub fn has_calls(&self) -> bool {
!self.calls.is_empty()
}
#[allow(dead_code)]
pub fn len(&self) -> usize {
self.calls.len()
}
pub fn register_call(&mut self, call: &ExprCall) -> Result<usize, CompileError> {
let call_key = self.call_to_key_expr(&Expr::Call(call.clone()));
if let Some(&index) = self.call_map.get(&call_key) {
return Ok(index);
}
let call_kind = self.determine_call_kind(&call.func);
let index = self.calls.len();
let info = NativeCallInfo {
index,
call_expr: Expr::Call(call.clone()),
arg_count: call.args.len(),
call_kind,
call_key: call_key.clone(),
};
self.calls.push(info);
self.call_map.insert(call_key, index);
Ok(index)
}
pub fn register_method(&mut self, method: &ExprMethodCall) -> Result<usize, CompileError> {
let call_key = self.method_to_key(method);
if let Some(&index) = self.call_map.get(&call_key) {
return Ok(index);
}
let index = self.calls.len();
let info = NativeCallInfo {
index,
call_expr: Expr::MethodCall(method.clone()),
arg_count: method.args.len() + 1, call_kind: CallKind::Method,
call_key: call_key.clone(),
};
self.calls.push(info);
self.call_map.insert(call_key, index);
Ok(index)
}
pub fn generate_wrappers(&self) -> TokenStream {
let wrappers: Vec<TokenStream> = self.calls.iter().map(|info| {
let wrapper_name = format_ident!("__native_{}", info.index);
match &info.call_expr {
Expr::Call(call) => self.generate_call_wrapper(&wrapper_name, call, info),
Expr::MethodCall(method) => self.generate_method_wrapper(&wrapper_name, method, info),
_ => quote! {},
}
}).collect();
quote! {
#(#wrappers)*
}
}
pub fn generate_table(&self) -> TokenStream {
if self.calls.is_empty() {
return quote! {
let __native_table: &[fn(&[u64]) -> u64] = &[];
};
}
let wrapper_refs: Vec<TokenStream> = self.calls.iter().map(|info| {
let wrapper_name = format_ident!("__native_{}", info.index);
quote! { #wrapper_name }
}).collect();
let count = self.calls.len();
quote! {
let __native_table: [fn(&[u64]) -> u64; #count] = [
#(#wrapper_refs),*
];
}
}
fn generate_call_wrapper(&self, wrapper_name: &syn::Ident, call: &ExprCall, _info: &NativeCallInfo) -> TokenStream {
let func = &call.func;
let arg_count = call.args.len();
let arg_extracts: Vec<TokenStream> = (0..arg_count).map(|i| {
quote! { args[#i] }
}).collect();
let call_expr = if arg_extracts.is_empty() {
quote! { #func() }
} else {
quote! { #func(#(#arg_extracts as _),*) }
};
quote! {
#[inline(never)]
fn #wrapper_name(args: &[u64]) -> u64 {
let __result = #call_expr;
__to_u64(__result)
}
}
}
fn generate_method_wrapper(&self, wrapper_name: &syn::Ident, method: &ExprMethodCall, _info: &NativeCallInfo) -> TokenStream {
let method_name = &method.method;
let arg_count = method.args.len();
let receiver_extract = quote! { args[0] };
let arg_extracts: Vec<TokenStream> = (0..arg_count).map(|i| {
let idx = i + 1; quote! { args[#idx] }
}).collect();
let call_expr = if arg_extracts.is_empty() {
quote! {
let __receiver_ptr = #receiver_extract as *mut ();
unsafe { (*(__receiver_ptr as *mut _)).#method_name() }
}
} else {
quote! {
let __receiver_ptr = #receiver_extract as *mut ();
unsafe { (*(__receiver_ptr as *mut _)).#method_name(#(#arg_extracts as _),*) }
}
};
quote! {
#[inline(never)]
fn #wrapper_name(args: &[u64]) -> u64 {
let __result = #call_expr;
__to_u64(__result)
}
}
}
pub fn generate_conversion_helper() -> TokenStream {
quote! {
#[inline(always)]
fn __to_u64<T>(value: T) -> u64 {
let size = std::mem::size_of::<T>();
if size == 0 {
return 0;
}
if size <= 8 {
let mut result = 0u64;
unsafe {
std::ptr::copy_nonoverlapping(
&value as *const T as *const u8,
&mut result as *mut u64 as *mut u8,
size,
);
}
std::mem::forget(value);
result
} else {
let boxed = Box::new(value);
Box::into_raw(boxed) as u64
}
}
#[inline(always)]
fn __from_u64_bool(value: u64) -> bool {
value != 0
}
}
}
fn determine_call_kind(&self, func: &Expr) -> CallKind {
match func {
Expr::Path(path) => {
if path.path.segments.len() > 1 {
CallKind::AssociatedFn
} else {
CallKind::Function
}
}
_ => CallKind::Closure,
}
}
fn call_to_key_expr(&self, expr: &Expr) -> String {
match expr {
Expr::Call(call) => {
let func_key = self.expr_to_key(&call.func);
let args_key: Vec<String> = call.args.iter()
.map(|arg| self.expr_to_key(arg))
.collect();
format!("call:{}({})", func_key, args_key.join(","))
}
_ => format!("{:?}", expr),
}
}
fn method_to_key(&self, method: &ExprMethodCall) -> String {
let receiver_key = self.expr_to_key(&method.receiver);
let method_name = method.method.to_string();
let args_key: Vec<String> = method.args.iter()
.map(|arg| self.expr_to_key(arg))
.collect();
format!("method:{}.{}({})", receiver_key, method_name, args_key.join(","))
}
fn expr_to_key(&self, expr: &Expr) -> String {
match expr {
Expr::Path(path) => {
path.path.segments.iter()
.map(|s| s.ident.to_string())
.collect::<Vec<_>>()
.join("::")
}
Expr::Lit(lit) => format!("{:?}", lit.lit),
_ => format!("{:p}", expr), }
}
}
impl Default for NativeCallCollector {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `tokens` as an expression and returns it as a plain call.
    fn parse_call(tokens: TokenStream) -> ExprCall {
        match syn::parse2::<Expr>(tokens) {
            Ok(Expr::Call(call)) => call,
            other => panic!("expected a call expression, got {:?}", other),
        }
    }

    /// Parses `tokens` as an expression and returns it as a method call.
    fn parse_method(tokens: TokenStream) -> ExprMethodCall {
        match syn::parse2::<Expr>(tokens) {
            Ok(Expr::MethodCall(method)) => method,
            other => panic!("expected a method call expression, got {:?}", other),
        }
    }

    // `CompileError`'s trait impls are not visible from here, so avoid relying on `Debug`
    // (which `unwrap()` would need) and match explicitly instead.
    fn register_call(collector: &mut NativeCallCollector, tokens: TokenStream) -> usize {
        match collector.register_call(&parse_call(tokens)) {
            Ok(index) => index,
            Err(_) => panic!("register_call failed"),
        }
    }

    fn register_method(collector: &mut NativeCallCollector, tokens: TokenStream) -> usize {
        match collector.register_method(&parse_method(tokens)) {
            Ok(index) => index,
            Err(_) => panic!("register_method failed"),
        }
    }

    #[test]
    fn test_collector_new() {
        let collector = NativeCallCollector::new();
        assert_eq!(collector.len(), 0);
        assert!(!collector.has_calls());
    }

    #[test]
    fn test_default_is_empty() {
        let collector = NativeCallCollector::default();
        assert_eq!(collector.len(), 0);
        assert!(!collector.has_calls());
    }

    #[test]
    fn test_identical_calls_are_deduplicated() {
        let mut collector = NativeCallCollector::new();
        let first = register_call(&mut collector, quote!(foo(1, 2)));
        let second = register_call(&mut collector, quote!(foo(1, 2)));
        assert_eq!(first, 0);
        assert_eq!(second, first);
        assert_eq!(collector.len(), 1);
        assert!(collector.has_calls());
    }

    #[test]
    fn test_distinct_callees_get_distinct_indices() {
        let mut collector = NativeCallCollector::new();
        assert_eq!(register_call(&mut collector, quote!(foo(1))), 0);
        assert_eq!(register_call(&mut collector, quote!(bar(1))), 1);
        assert_eq!(collector.len(), 2);
    }

    #[test]
    fn test_call_kinds() {
        let mut collector = NativeCallCollector::new();
        let function = register_call(&mut collector, quote!(foo(1)));
        let associated = register_call(&mut collector, quote!(Vec::new()));
        assert_eq!(collector.calls[function].call_kind, CallKind::Function);
        assert_eq!(collector.calls[associated].call_kind, CallKind::AssociatedFn);
        assert_eq!(collector.calls[associated].arg_count, 0);
    }

    #[test]
    fn test_method_registration() {
        let mut collector = NativeCallCollector::new();
        let first = register_method(&mut collector, quote!(v.push(1)));
        let second = register_method(&mut collector, quote!(v.push(1)));
        assert_eq!(first, second);
        assert_eq!(collector.len(), 1);
        // The receiver occupies one argument slot.
        assert_eq!(collector.calls[first].arg_count, 2);
        assert_eq!(collector.calls[first].call_kind, CallKind::Method);
    }

    #[test]
    fn test_empty_table() {
        let collector = NativeCallCollector::new();
        assert!(collector.generate_table().to_string().contains("__native_table"));
    }
}