use proc_macro::TokenStream;
use proc_macro2::TokenStream as TokenStream2;
use quote::quote;
use syn::{
parse_macro_input, FnArg, Ident, ItemFn, Pat, PatType, ReturnType, Type, TypePath,
TypeReference,
};
/// The three numeric scalar flavours recognised by `#[xs_sub]`.
///
/// `classify_scalar` maps a type whose last path segment is `IV`, `UV` or
/// `NV` onto the corresponding variant; `sv_flavor_input` and
/// `sv_flavor_setter` pick the matching `::libperl_rs` reader / setter.
#[derive(Clone, Copy)]
enum SvFlavor {
// Type named `IV`; read with `SvIV`, written with `Perl_sv_setiv`.
Iv,
// Type named `UV`; read with `SvUV`, written with `Perl_sv_setuv`.
Uv,
// Type named `NV`; read with `SvNV`, written with `Perl_sv_setnv`.
Nv,
}
/// How a single parameter of an `#[xs_sub]` function is obtained by the
/// generated trampoline. Produced by `classify_arg_type`.
#[derive(Clone, Copy)]
enum ArgKind {
// By-value `IV` / `UV` / `NV`, read from the stack slot.
In(SvFlavor),
// `&mut IV|UV|NV`: read from the stack slot before the call and written
// back into the same SV after the call.
Out(SvFlavor),
// `&CStr`, borrowed from the SV's string buffer.
InCStr,
// `&str`, same as `InCStr` plus a UTF-8 validity check (croaks on failure).
InStr,
// `*mut SV`, the raw stack entry.
InRawSv,
// `Sv` wrapper built from the raw stack entry.
InSv,
// `&Perl`: the interpreter context. Not taken from the Perl stack and
// must be the first parameter.
PerlContext,
// `&Av`: the stack entry must be a reference to an array.
InAvRef,
// `&Hv`: the stack entry must be a reference to a hash.
InHvRef,
}
/// How the return value of an `#[xs_sub]` function is placed on the Perl
/// stack by the generated trampoline. Produced by `classify_ret_type`.
#[derive(Clone)]
enum RetKind {
// `IV` / `UV` / `NV`: one new mortal SV set with the flavour's setter.
Scalar(SvFlavor),
// `bool`: one new mortal SV set via `Perl_sv_setiv(.. as IV)`.
Bool,
// `()` or no return type: nothing is returned to Perl.
Unit,
// `String`: one new mortal SV holding the bytes, flagged `SVf_UTF8`.
String_,
// `Vec<IV|UV|NV>`: one mortal SV per element.
VecScalar(SvFlavor),
// `Vec<String>`: one mortal, `SVf_UTF8`-flagged SV per element.
VecString,
// `Result<T, String>`: `Err` croaks with the message, `Ok` is handled
// according to the boxed inner kind.
ResultErrString(Box<RetKind>),
// `*mut SV`: refcount is incremented and the SV mortalised.
RawSv,
// `Option<*mut SV>`: `None` yields the value of `sv_undef_ptr`.
OptionRawSv,
// `Sv` (also used for `Rv<..>`): pushed via `.as_ptr()` like `RawSv`.
Sv,
// `Option<Sv>` / `Option<Rv<..>>`: `None` yields the value of `sv_undef_ptr`.
OptionSv,
}
/// One classified parameter of the annotated function.
struct ArgSpec {
// Identifier of the parameter as written by the user; also used to derive
// the names of the trampoline's helper locals (`__svp_<name>`, ...).
name: Ident,
// How the trampoline obtains the value for this parameter.
kind: ArgKind,
}
/// Expands a function annotated with `#[xs_sub]`.
///
/// The expansion consists of two items:
///
/// 1. a private copy of the user's function, renamed `__xs_body_<name>`;
/// 2. a `pub extern "C"` trampoline carrying the original name. It checks the
///    argument count, converts each Perl stack entry to the Rust parameter
///    type, calls the body, writes `&mut` scalars back into their SVs and
///    finally stores the return value(s) on the Perl stack.
///
/// Unsupported parameter or return types, a `self` receiver, a non-identifier
/// parameter pattern, or a `&Perl` parameter that is not first are reported
/// as compile errors spanned on the offending tokens.
///
/// `_attr` is ignored.
pub fn xs_sub(_attr: TokenStream, item: TokenStream) -> TokenStream {
    let func = parse_macro_input!(item as ItemFn);

    // ---- Classify every parameter -------------------------------------
    let mut arg_specs: Vec<ArgSpec> = Vec::new();
    for arg in &func.sig.inputs {
        match arg {
            FnArg::Receiver(r) => {
                return error(r, "`#[xs_sub]` does not support `self` receiver");
            }
            FnArg::Typed(PatType { pat, ty, .. }) => {
                let name = match pat.as_ref() {
                    Pat::Ident(p) => p.ident.clone(),
                    other => return error(other, "`#[xs_sub]` argument must be a plain identifier"),
                };
                let kind = match classify_arg_type(ty) {
                    Some(k) => k,
                    None => {
                        // Keep this list in sync with `classify_arg_type`.
                        return error(
                            ty,
                            "`#[xs_sub]` argument must be `IV` / `UV` / `NV` \
                             / `&mut IV|UV|NV` / `&CStr` / `&str` / `*mut SV` \
                             / `Sv` / `&Perl` / `&Av` / `&Hv`",
                        );
                    }
                };
                if matches!(kind, ArgKind::PerlContext) && !arg_specs.is_empty() {
                    return error(
                        ty,
                        "`&Perl` (interpreter context) must be the first \
                         parameter of an `#[xs_sub]`",
                    );
                }
                arg_specs.push(ArgSpec { name, kind });
            }
        }
    }

    // ---- Classify the return type -------------------------------------
    let ret_kind = match &func.sig.output {
        ReturnType::Default => RetKind::Unit,
        ReturnType::Type(_, ty) => match classify_ret_type(ty) {
            Some(k) => k,
            None => {
                // Keep this list in sync with `classify_ret_type`.
                return error(
                    ty,
                    "`#[xs_sub]` return must be `()` / `bool` / `IV` / `UV` / `NV` \
                     / `String` / `Vec<IV|UV|NV|String>` / `*mut SV` / `Sv` / `Rv<..>` \
                     / `Option<*mut SV|Sv|Rv<..>>` / `Result<T, String>`",
                );
            }
        },
    };

    let fn_name = &func.sig.ident;
    let ret_ty_for_user = match &func.sig.output {
        ReturnType::Default => quote! { () },
        ReturnType::Type(_, ty) => quote! { #ty },
    };

    // ---- The user's function, renamed and made private ----------------
    // The original name is taken over by the trampoline below.
    let body_fn_name = quote::format_ident!("__xs_body_{}", fn_name);
    let mut body_fn_item = func.clone();
    body_fn_item.sig.ident = body_fn_name.clone();
    body_fn_item.vis = syn::Visibility::Inherited;
    let body_fn = quote! {
        #[allow(non_snake_case)]
        #[inline]
        #body_fn_item
    };

    // Expressions passed to the body for each parameter, in declaration order.
    let user_arg_call: Vec<TokenStream2> = arg_specs
        .iter()
        .map(|s| {
            let n = &s.name;
            match s.kind {
                ArgKind::In(_)
                | ArgKind::InCStr
                | ArgKind::InStr
                | ArgKind::InRawSv
                | ArgKind::InSv => quote! { #n },
                ArgKind::Out(_) => quote! { &mut #n },
                ArgKind::PerlContext => quote! { __perl_ref },
                ArgKind::InAvRef | ArgKind::InHvRef => quote! { &#n },
            }
        })
        .collect();

    // `&Perl` does not come from the Perl stack, so it is excluded from both
    // the expected argument count and the usage string.
    let arg_count = arg_specs
        .iter()
        .filter(|s| !matches!(s.kind, ArgKind::PerlContext))
        .count();
    let usage_str = arg_specs
        .iter()
        .filter(|a| !matches!(a.kind, ArgKind::PerlContext))
        .map(|a| a.name.to_string())
        .collect::<Vec<_>>()
        .join(", ");

    let needs_perl_ref = arg_specs
        .iter()
        .any(|s| matches!(s.kind, ArgKind::PerlContext));
    let perl_ref_setup: TokenStream2 = if needs_perl_ref {
        // The wrapper is held in `ManuallyDrop`, so its destructor never runs
        // when the trampoline returns.
        quote! {
            let __perl_ctx_storage = ::core::mem::ManuallyDrop::new(unsafe {
                ::libperl_rs::Perl::from_raw_unchecked(my_perl)
            });
            let __perl_ref: &::libperl_rs::Perl = &*__perl_ctx_storage;
        }
    } else {
        quote! {}
    };

    let usage_cstring = std::ffi::CString::new(usage_str)
        .expect("usage string contains interior nul");
    let usage_lit = syn::LitCStr::new(usage_cstring.as_c_str(), proc_macro2::Span::call_site());

    // ---- Threaded vs. non-threaded perl -------------------------------
    // `cfg!` is resolved when this macro crate itself is compiled, so the
    // `perl_useithreads` cfg must be set for this crate's build.
    let threaded = cfg!(perl_useithreads);
    // `my_perl,` / `my_perl` / nothing: the interpreter argument of perl API
    // calls, present only on threaded builds.
    let myperl_arg_prefix: TokenStream2 = if threaded { quote! { my_perl, } } else { quote! {} };
    let myperl_arg_only: TokenStream2 = if threaded { quote! { my_perl } } else { quote! {} };
    let trampoline_params = if threaded {
        quote! {
            my_perl: *mut ::libperl_rs::PerlInterpreter,
            cv: *mut ::libperl_rs::CV,
        }
    } else {
        quote! { cv: *mut ::libperl_rs::CV, }
    };
    let null_check = if threaded {
        quote! { if my_perl.is_null() { return; } }
    } else {
        // Non-threaded: define a dummy `my_perl` so the `PL_*!(my_perl)`
        // macro invocations below are spelled the same in both modes.
        quote! {
            #[allow(unused_variables)]
            let my_perl: *mut ::libperl_rs::PerlInterpreter = ::core::ptr::null_mut();
        }
    };
    // Pops one entry off the mark stack.
    let pop_mark = if threaded {
        quote! {
            unsafe {
                (*my_perl).Imarkstack_ptr = (*my_perl).Imarkstack_ptr.sub(1);
            }
        }
    } else {
        quote! {
            unsafe {
                ::libperl_rs::PL_markstack_ptr = ::libperl_rs::PL_markstack_ptr.sub(1);
            }
        }
    };
    // Closure that sets the stack pointer so that exactly `n` values starting
    // at index `__ax` are returned.
    let sp_writer = if threaded {
        quote! {
            let __set_sp_for_n = move |n: usize| unsafe {
                (*my_perl).Istack_sp =
                    ::libperl_rs::PL_stack_base!(my_perl).add(__ax + n - 1);
            };
        }
    } else {
        quote! {
            let __set_sp_for_n = move |n: usize| unsafe {
                ::libperl_rs::PL_stack_sp =
                    ::libperl_rs::PL_stack_base!(my_perl).add(__ax + n - 1);
            };
        }
    };

    // ---- Per-argument extraction code ---------------------------------
    // Index of the next Perl stack slot; only advanced for parameters that
    // actually come from the stack (i.e. not `&Perl`).
    let mut __stack_idx: usize = 0;
    let arg_extractions: Vec<TokenStream2> = arg_specs
        .iter()
        .map(|spec| {
            if matches!(spec.kind, ArgKind::PerlContext) {
                return quote! {};
            }
            let name = &spec.name;
            let svp_ident = quote::format_ident!("__svp_{}", name);
            let i_lit = syn::Index::from(__stack_idx);
            __stack_idx += 1;
            // Pointer to this argument's stack slot. `Out` parameters reuse it
            // for the write-back after the body has run.
            // NOTE(review): the pointer is captured before the body call; if
            // the body can make perl reallocate its stack it would be stale
            // by write-back time — confirm.
            let svp_capture = quote! {
                let #svp_ident: *mut *mut ::libperl_rs::SV = unsafe {
                    ::libperl_rs::PL_stack_base!(my_perl).add(__ax + #i_lit)
                };
            };
            match spec.kind {
                ArgKind::In(flavor) | ArgKind::Out(flavor) => {
                    let is_mut = matches!(spec.kind, ArgKind::Out(_));
                    let (rust_ty, reader) = sv_flavor_input(flavor);
                    // `Out` parameters need a mutable local to lend `&mut`.
                    let mut_kw = if is_mut { quote! { mut } } else { quote! {} };
                    quote! {
                        #svp_capture
                        let #mut_kw #name: #rust_ty = unsafe {
                            ::libperl_rs::#reader(#myperl_arg_prefix *#svp_ident)
                        };
                    }
                }
                ArgKind::InRawSv => {
                    quote! {
                        #svp_capture
                        let #name: *mut ::libperl_rs::SV = unsafe { *#svp_ident };
                    }
                }
                ArgKind::InSv => {
                    quote! {
                        #svp_capture
                        let #name: ::libperl_rs::Sv = unsafe {
                            ::libperl_rs::Sv::from_raw_unchecked(*#svp_ident)
                        };
                    }
                }
                ArgKind::InAvRef | ArgKind::InHvRef => {
                    let (rust_ty, expected_svtype, type_str, ctor) = match spec.kind {
                        ArgKind::InAvRef => (
                            quote! { ::libperl_rs::Av },
                            quote! { ::libperl_rs::svtype::SVt_PVAV },
                            "ARRAY",
                            quote! { ::libperl_rs::Av::from_raw_unchecked },
                        ),
                        ArgKind::InHvRef => (
                            quote! { ::libperl_rs::Hv },
                            quote! { ::libperl_rs::svtype::SVt_PVHV },
                            "HASH",
                            quote! { ::libperl_rs::Hv::from_raw_unchecked },
                        ),
                        _ => unreachable!(),
                    };
                    let err_msg = format!("argument `{}` must be a {} reference", name, type_str);
                    let err_lit = syn::LitCStr::new(
                        std::ffi::CString::new(err_msg).unwrap().as_c_str(),
                        proc_macro2::Span::call_site(),
                    );
                    // Croak unless the argument is a reference whose referent
                    // has the expected SV type; then wrap the referent.
                    quote! {
                        #svp_capture
                        let #name: #rust_ty = unsafe {
                            let __sv: *mut ::libperl_rs::SV = *#svp_ident;
                            if ::libperl_rs::SvROK(__sv) == 0 {
                                ::libperl_rs::Perl_croak(
                                    #myperl_arg_prefix
                                    #err_lit.as_ptr(),
                                );
                            }
                            let __target: *mut ::libperl_rs::SV = ::libperl_rs::SvRV(__sv);
                            if ::libperl_rs::SvTYPE(__target) != #expected_svtype {
                                ::libperl_rs::Perl_croak(
                                    #myperl_arg_prefix
                                    #err_lit.as_ptr(),
                                );
                            }
                            #ctor(__target as _)
                        };
                    }
                }
                ArgKind::PerlContext => unreachable!("handled by early return above"),
                ArgKind::InCStr | ArgKind::InStr => {
                    let pv_ident = quote::format_ident!("__pv_{}", name);
                    let len_ident = quote::format_ident!("__pvlen_{}", name);
                    let cstr_ident = quote::format_ident!("__cstr_{}", name);
                    // The `CStr` is built with `from_ptr`, i.e. it ends at the
                    // first NUL byte; the length reported by perl is not used.
                    let extract_pv = quote! {
                        #svp_capture
                        let mut #len_ident: ::libperl_rs::STRLEN = 0;
                        let #pv_ident: *const ::core::ffi::c_char = unsafe {
                            ::libperl_rs::Perl_sv_2pv_flags(
                                #myperl_arg_prefix
                                *#svp_ident,
                                &mut #len_ident,
                                ::libperl_rs::SV_GMAGIC as _,
                            )
                        };
                        let #cstr_ident: &::core::ffi::CStr =
                            unsafe { ::core::ffi::CStr::from_ptr(#pv_ident) };
                    };
                    if matches!(spec.kind, ArgKind::InCStr) {
                        quote! {
                            #extract_pv
                            let #name: &::core::ffi::CStr = #cstr_ident;
                        }
                    } else {
                        let usage_err_lit = syn::LitCStr::new(
                            std::ffi::CString::new(format!(
                                "argument `{}` is not valid UTF-8",
                                name
                            ))
                            .unwrap()
                            .as_c_str(),
                            proc_macro2::Span::call_site(),
                        );
                        quote! {
                            #extract_pv
                            let #name: &str = match #cstr_ident.to_str() {
                                ::core::result::Result::Ok(s) => s,
                                ::core::result::Result::Err(_) => {
                                    unsafe {
                                        ::libperl_rs::Perl_croak(
                                            #myperl_arg_prefix
                                            #usage_err_lit.as_ptr(),
                                        );
                                    }
                                }
                            };
                        }
                    }
                }
            }
        })
        .collect();

    // ---- Write-back of `&mut` scalars after the body call -------------
    let out_writebacks: Vec<TokenStream2> = arg_specs
        .iter()
        .filter_map(|spec| {
            let ArgKind::Out(flavor) = spec.kind else { return None; };
            let name = &spec.name;
            let svp_ident = quote::format_ident!("__svp_{}", name);
            let setter = sv_flavor_setter(flavor);
            Some(quote! {
                unsafe {
                    ::libperl_rs::#setter(#myperl_arg_prefix *#svp_ident, #name);
                }
            })
        })
        .collect();

    // ---- Return value handling ----------------------------------------
    // `Result<T, String>` is peeled first: `Err` croaks, `Ok(v)` continues
    // with the push code for `T`.
    let (unwrap_code, push_kind) = match ret_kind {
        RetKind::ResultErrString(inner) => (
            quote! {
                let __ret = match __raw {
                    ::core::result::Result::Ok(v) => v,
                    ::core::result::Result::Err(__e) => {
                        let __msg = ::std::ffi::CString::new(__e)
                            .unwrap_or_else(|_| ::std::ffi::CString::new(
                                "xs_sub: error message contained interior NUL",
                            ).unwrap());
                        unsafe {
                            ::libperl_rs::Perl_croak(
                                #myperl_arg_prefix
                                c"%s\n".as_ptr(),
                                __msg.as_ptr(),
                            );
                        }
                    }
                };
            },
            *inner,
        ),
        other => (quote! { let __ret = __raw; }, other),
    };
    let push_code: TokenStream2 = match push_kind {
        RetKind::Scalar(flavor) => {
            let setter = sv_flavor_setter(flavor);
            quote! {
                let __targ = unsafe { ::libperl_rs::Perl_sv_newmortal(#myperl_arg_only) };
                unsafe { ::libperl_rs::#setter(#myperl_arg_prefix __targ, __ret); }
                unsafe {
                    *::libperl_rs::PL_stack_base!(my_perl).add(__ax + 0) = __targ;
                }
                __set_sp_for_n(1);
            }
        }
        RetKind::Bool => quote! {
            let __targ = unsafe { ::libperl_rs::Perl_sv_newmortal(#myperl_arg_only) };
            unsafe {
                ::libperl_rs::Perl_sv_setiv(
                    #myperl_arg_prefix __targ,
                    __ret as ::libperl_rs::IV,
                );
            }
            unsafe {
                *::libperl_rs::PL_stack_base!(my_perl).add(__ax + 0) = __targ;
            }
            __set_sp_for_n(1);
        },
        RetKind::String_ => quote! {
            let __targ = unsafe { ::libperl_rs::Perl_sv_newmortal(#myperl_arg_only) };
            let __bytes: &[u8] = __ret.as_bytes();
            unsafe {
                ::libperl_rs::Perl_sv_setpvn(
                    #myperl_arg_prefix
                    __targ,
                    __bytes.as_ptr() as *const ::core::ffi::c_char,
                    __bytes.len() as _,
                );
                // Set the UTF-8 flag on the new SV.
                let __cur_flags: i64 = (*__targ).sv_flags as i64;
                let __new_flags: i64 = __cur_flags | (::libperl_rs::SVf_UTF8 as i64);
                (*__targ).sv_flags = __new_flags as _;
            }
            unsafe {
                *::libperl_rs::PL_stack_base!(my_perl).add(__ax + 0) = __targ;
            }
            __set_sp_for_n(1);
        },
        RetKind::Unit => quote! {
            let _ = __ret;
            __set_sp_for_n(0);
        },
        RetKind::VecScalar(flavor) => {
            let setter = sv_flavor_setter(flavor);
            quote! {
                let __n: usize = __ret.len();
                // Ask perl for room for `__n` more entries before writing.
                unsafe {
                    let _ = ::libperl_rs::Perl_stack_grow(
                        #myperl_arg_prefix
                        ::libperl_rs::PL_stack_sp!(my_perl),
                        ::libperl_rs::PL_stack_sp!(my_perl),
                        __n as _,
                    );
                }
                for (__i, __val) in __ret.iter().enumerate() {
                    unsafe {
                        let __targ = ::libperl_rs::Perl_sv_newmortal(#myperl_arg_only);
                        ::libperl_rs::#setter(#myperl_arg_prefix __targ, *__val);
                        *::libperl_rs::PL_stack_base!(my_perl).add(__ax + __i) = __targ;
                    }
                }
                __set_sp_for_n(__n);
            }
        }
        RetKind::VecString => quote! {
            let __n: usize = __ret.len();
            // Ask perl for room for `__n` more entries before writing.
            unsafe {
                let _ = ::libperl_rs::Perl_stack_grow(
                    #myperl_arg_prefix
                    ::libperl_rs::PL_stack_sp!(my_perl),
                    ::libperl_rs::PL_stack_sp!(my_perl),
                    __n as _,
                );
            }
            for (__i, __val) in __ret.iter().enumerate() {
                unsafe {
                    let __targ = ::libperl_rs::Perl_sv_newmortal(#myperl_arg_only);
                    let __bytes: &[u8] = __val.as_bytes();
                    ::libperl_rs::Perl_sv_setpvn(
                        #myperl_arg_prefix
                        __targ,
                        __bytes.as_ptr() as *const ::core::ffi::c_char,
                        __bytes.len() as _,
                    );
                    let __cur_flags: i64 = (*__targ).sv_flags as i64;
                    (*__targ).sv_flags =
                        (__cur_flags | (::libperl_rs::SVf_UTF8 as i64)) as _;
                    *::libperl_rs::PL_stack_base!(my_perl).add(__ax + __i) = __targ;
                }
            }
            __set_sp_for_n(__n);
        },
        RetKind::ResultErrString(_) => unreachable!("Result is unwrapped before push"),
        RetKind::RawSv => quote! {
            let __sv: *mut ::libperl_rs::SV = __ret;
            // Bump the refcount, then mortalise the SV.
            unsafe {
                let __mortal = ::libperl_rs::Perl_sv_2mortal(
                    #myperl_arg_prefix
                    ::libperl_rs::sv_refcnt_inc(__sv),
                );
                *::libperl_rs::PL_stack_base!(my_perl).add(__ax + 0) = __mortal;
            }
            __set_sp_for_n(1);
        },
        RetKind::Sv => quote! {
            let __sv: *mut ::libperl_rs::SV = __ret.as_ptr();
            unsafe {
                let __mortal = ::libperl_rs::Perl_sv_2mortal(
                    #myperl_arg_prefix
                    ::libperl_rs::sv_refcnt_inc(__sv),
                );
                *::libperl_rs::PL_stack_base!(my_perl).add(__ax + 0) = __mortal;
            }
            __set_sp_for_n(1);
        },
        RetKind::OptionSv => quote! {
            let __pushed: *mut ::libperl_rs::SV = match __ret {
                ::core::option::Option::Some(__sv_wrap) => unsafe {
                    ::libperl_rs::Perl_sv_2mortal(
                        #myperl_arg_prefix
                        ::libperl_rs::sv_refcnt_inc(__sv_wrap.as_ptr()),
                    )
                },
                ::core::option::Option::None => ::libperl_rs::sv_undef_ptr(my_perl),
            };
            unsafe {
                *::libperl_rs::PL_stack_base!(my_perl).add(__ax + 0) = __pushed;
            }
            __set_sp_for_n(1);
        },
        RetKind::OptionRawSv => {
            quote! {
                let __pushed: *mut ::libperl_rs::SV = match __ret {
                    ::core::option::Option::Some(__sv) => unsafe {
                        ::libperl_rs::Perl_sv_2mortal(
                            #myperl_arg_prefix
                            ::libperl_rs::sv_refcnt_inc(__sv),
                        )
                    },
                    ::core::option::Option::None => ::libperl_rs::sv_undef_ptr(my_perl),
                };
                unsafe {
                    *::libperl_rs::PL_stack_base!(my_perl).add(__ax + 0) = __pushed;
                }
                __set_sp_for_n(1);
            }
        }
    };
    let return_push: TokenStream2 = quote! {
        #unwrap_code
        #push_code
    };

    // ---- Assemble the trampoline --------------------------------------
    // Order inside the generated function: read and pop the mark, compute
    // `__ax` (index of the first argument) and the item count, check the
    // count, extract arguments, call the body, write back `&mut` scalars,
    // then store the return value(s) and set the stack pointer.
    let trampoline = quote! {
        #body_fn
        #[allow(unused_variables, unreachable_code)]
        pub extern "C" fn #fn_name( #trampoline_params ) {
            #null_check
            let __sp: *mut *mut ::libperl_rs::SV = ::libperl_rs::PL_stack_sp!(my_perl);
            let __mark_ax = unsafe {
                *::libperl_rs::PL_markstack_ptr!(my_perl)
            };
            #pop_mark
            let __mark = unsafe {
                ::libperl_rs::PL_stack_base!(my_perl).add(__mark_ax as usize)
            };
            let __ax: usize = (__mark_ax as usize).wrapping_add(1);
            let __items = unsafe { __sp.offset_from(__mark) };
            if __items != #arg_count as isize {
                unsafe {
                    ::libperl_rs::Perl_croak_xs_usage(cv, #usage_lit.as_ptr());
                }
                return;
            }
            #( #arg_extractions )*
            #perl_ref_setup
            let __raw: #ret_ty_for_user = #body_fn_name( #( #user_arg_call ),* );
            #( #out_writebacks )*
            #sp_writer
            #return_push
        }
    };
    trampoline.into()
}
/// Maps a parameter type onto the [`ArgKind`] describing how the trampoline
/// must obtain it, or `None` when the type is not supported.
///
/// * `&mut T` is accepted only when `T` is a scalar (`IV` / `UV` / `NV`).
/// * `&T` is accepted for `str`, `CStr`, `Perl`, and for `Av` / `Hv` written
///   without generic arguments. Only the last path segment is inspected.
/// * Non-reference types are tried as `*mut SV`, then `Sv`, then scalar.
fn classify_arg_type(ty: &Type) -> Option<ArgKind> {
    // Everything that is not a reference is handled up front.
    let Type::Reference(TypeReference { mutability, elem, .. }) = ty else {
        return if is_raw_sv_ptr(ty) {
            Some(ArgKind::InRawSv)
        } else if is_sv_newtype(ty) {
            Some(ArgKind::InSv)
        } else {
            classify_scalar(ty).map(ArgKind::In)
        };
    };

    // `&mut` is reserved for in/out scalars.
    if mutability.is_some() {
        return classify_scalar(elem).map(ArgKind::Out);
    }

    let Type::Path(TypePath { path, .. }) = elem.as_ref() else {
        return None;
    };
    let segment = path.segments.last()?;
    let type_name = segment.ident.to_string();
    let without_generics = segment.arguments.is_none();
    match (type_name.as_str(), without_generics) {
        ("str", _) => Some(ArgKind::InStr),
        ("CStr", _) => Some(ArgKind::InCStr),
        ("Perl", _) => Some(ArgKind::PerlContext),
        ("Av", true) => Some(ArgKind::InAvRef),
        ("Hv", true) => Some(ArgKind::InHvRef),
        _ => None,
    }
}
/// Maps a return type onto the [`RetKind`] describing how the trampoline
/// must hand it back to Perl, or `None` when the type is not supported.
///
/// Checks run in this order: `()`, `*mut SV`, `Sv`, `Rv<..>`, then the named
/// wrappers `bool` / `String` / `Vec` / `Result` / `Option` (matched on the
/// last path segment), and finally the plain scalars.
fn classify_ret_type(ty: &Type) -> Option<RetKind> {
    if matches!(ty, Type::Tuple(tuple) if tuple.elems.is_empty()) {
        return Some(RetKind::Unit);
    }
    if is_raw_sv_ptr(ty) {
        return Some(RetKind::RawSv);
    }
    // `Sv` and `Rv<..>` are both returned through `.as_ptr()`.
    if is_sv_newtype(ty) || is_rv_wrapper(ty) {
        return Some(RetKind::Sv);
    }

    if let Type::Path(TypePath { path, .. }) = ty {
        let segment = path.segments.last()?;
        let type_name = segment.ident.to_string();
        let generics = &segment.arguments;
        match type_name.as_str() {
            "bool" => return Some(RetKind::Bool),
            "String" => return Some(RetKind::String_),
            "Vec" => return classify_vec_inner(generics),
            "Result" => return classify_result_inner(generics),
            "Option" => return classify_option_inner(generics),
            // Anything else may still be a scalar; fall through.
            _ => {}
        }
    }

    classify_scalar(ty).map(RetKind::Scalar)
}
/// Classifies the payload of an `Option<..>` return type.
///
/// Only `*mut SV`, `Sv` and `Rv<..>` payloads are supported; anything else
/// (or a missing first generic argument) yields `None`.
fn classify_option_inner(args: &syn::PathArguments) -> Option<RetKind> {
    generic_arg_n(args, 0).and_then(|payload| {
        if is_raw_sv_ptr(payload) {
            Some(RetKind::OptionRawSv)
        } else if is_sv_newtype(payload) || is_rv_wrapper(payload) {
            Some(RetKind::OptionSv)
        } else {
            None
        }
    })
}
/// Returns `true` when `ty` is spelled `*mut <path>::SV`.
///
/// `*const` pointers are rejected, and only the last segment of the pointee
/// path is compared against `SV`.
fn is_raw_sv_ptr(ty: &Type) -> bool {
    match ty {
        Type::Ptr(ptr) if ptr.const_token.is_none() && ptr.mutability.is_some() => {
            match ptr.elem.as_ref() {
                Type::Path(TypePath { path, .. }) => path
                    .segments
                    .last()
                    .is_some_and(|segment| segment.ident == "SV"),
                _ => false,
            }
        }
        _ => false,
    }
}
/// Returns `true` when `ty` is a path whose last segment is `Sv` written
/// without any generic arguments.
fn is_sv_newtype(ty: &Type) -> bool {
    match ty {
        Type::Path(TypePath { path, .. }) => match path.segments.last() {
            Some(segment) => segment.ident == "Sv" && segment.arguments.is_none(),
            None => false,
        },
        _ => false,
    }
}
/// Returns `true` when `ty` is a path whose last segment is `Rv<..>`, i.e.
/// `Rv` followed by angle-bracketed generic arguments.
fn is_rv_wrapper(ty: &Type) -> bool {
    if let Type::Path(TypePath { path, .. }) = ty {
        if let Some(segment) = path.segments.last() {
            return segment.ident == "Rv"
                && matches!(segment.arguments, syn::PathArguments::AngleBracketed(_));
        }
    }
    false
}
/// Classifies the element type of a `Vec<..>` return type.
///
/// `Vec<String>` becomes [`RetKind::VecString`]; a scalar element becomes
/// [`RetKind::VecScalar`]; everything else is unsupported.
fn classify_vec_inner(args: &syn::PathArguments) -> Option<RetKind> {
    let element = generic_arg_n(args, 0)?;
    let element_is_string = matches!(
        element,
        Type::Path(TypePath { path, .. })
            if path.segments.last().is_some_and(|segment| segment.ident == "String")
    );
    if element_is_string {
        Some(RetKind::VecString)
    } else {
        classify_scalar(element).map(RetKind::VecScalar)
    }
}
/// Classifies a `Result<T, E>` return type.
///
/// Both generic arguments must be present, `E` must be a path ending in
/// `String`, and `T` must itself be a supported return type.
fn classify_result_inner(args: &syn::PathArguments) -> Option<RetKind> {
    let (ok_ty, err_ty) = (generic_arg_n(args, 0)?, generic_arg_n(args, 1)?);

    // The error side is checked first; only then is the `Ok` type classified.
    let Type::Path(TypePath { path: err_path, .. }) = err_ty else {
        return None;
    };
    match err_path.segments.last() {
        Some(segment) if segment.ident == "String" => classify_ret_type(ok_ty)
            .map(|ok_kind| RetKind::ResultErrString(Box::new(ok_kind))),
        _ => None,
    }
}
/// Returns the `n`-th (zero-based) angle-bracketed generic argument when it
/// is a type; `None` for other kinds of generic argument, for an index out of
/// range, or when `args` is not angle-bracketed.
fn generic_arg_n(args: &syn::PathArguments, n: usize) -> Option<&Type> {
    match args {
        syn::PathArguments::AngleBracketed(bracketed) => {
            match bracketed.args.iter().nth(n) {
                Some(syn::GenericArgument::Type(ty)) => Some(ty),
                _ => None,
            }
        }
        _ => None,
    }
}
/// Recognises the scalar types `IV`, `UV` and `NV` by the last segment of
/// their path; any other type yields `None`.
fn classify_scalar(ty: &Type) -> Option<SvFlavor> {
    let Type::Path(TypePath { path, .. }) = ty else {
        return None;
    };
    let type_name = path.segments.last()?.ident.to_string();
    match type_name.as_str() {
        "IV" => Some(SvFlavor::Iv),
        "UV" => Some(SvFlavor::Uv),
        "NV" => Some(SvFlavor::Nv),
        _ => None,
    }
}
/// For a scalar flavour, returns the fully-qualified Rust type tokens and the
/// identifier of the `::libperl_rs` function used to read it from an SV.
fn sv_flavor_input(flavor: SvFlavor) -> (TokenStream2, Ident) {
    let (rust_ty, reader_name) = match flavor {
        SvFlavor::Iv => (quote! { ::libperl_rs::IV }, "SvIV"),
        SvFlavor::Uv => (quote! { ::libperl_rs::UV }, "SvUV"),
        SvFlavor::Nv => (quote! { ::libperl_rs::NV }, "SvNV"),
    };
    let reader = Ident::new(reader_name, proc_macro2::Span::call_site());
    (rust_ty, reader)
}
/// Returns the identifier of the `::libperl_rs` function that stores a value
/// of the given scalar flavour into an SV.
fn sv_flavor_setter(flavor: SvFlavor) -> Ident {
    let span = proc_macro2::Span::call_site();
    match flavor {
        SvFlavor::Iv => Ident::new("Perl_sv_setiv", span),
        SvFlavor::Uv => Ident::new("Perl_sv_setuv", span),
        SvFlavor::Nv => Ident::new("Perl_sv_setnv", span),
    }
}
fn error<T: quote::ToTokens>(spanned: T, msg: &str) -> TokenStream {
syn::Error::new_spanned(spanned, msg).to_compile_error().into()
}