use proc_macro::TokenStream;
use quote::quote;
use syn::{ItemFn, ReturnType, Type, parse_quote};
#[proc_macro_attribute]
pub fn ffrt(_args: TokenStream, input: TokenStream) -> TokenStream {
convert(input.into())
.unwrap_or_else(|err| err.into_compile_error())
.into()
}
fn convert(input: proc_macro2::TokenStream) -> Result<proc_macro2::TokenStream, syn::Error> {
let func = syn::parse2::<ItemFn>(input)?;
if func.sig.asyncness.is_none() {
return Err(syn::Error::new_spanned(
func,
"ffrt macro only supports async functions",
));
}
let func_name = &func.sig.ident;
let func_vis = &func.vis;
let func_attrs = &func.attrs;
let func_inputs = &func.sig.inputs;
let func_body = &func.block;
let func_output = &func.sig.output;
let mut param_names = Vec::new();
for input in func_inputs.iter() {
if let syn::FnArg::Typed(pat_type) = input {
param_names.push(&pat_type.pat);
}
}
if let ReturnType::Type(_, ty) = func_output {
if let Type::Path(type_path) = &**ty {
if let Some(segment) = type_path.path.segments.last() {
if segment.ident == "Result" && !is_napi_ohos_path(&type_path.path) {
return Err(syn::Error::new_spanned(
ty,
"ffrt macro requires napi_ohos::Result, not std::result::Result or other Result types",
));
}
}
}
}
let inner_return_type = match func_output {
ReturnType::Default => {
parse_quote!(())
}
ReturnType::Type(_, ty) => {
if is_result_type(ty) {
extract_result_inner_type(ty).unwrap_or_else(|| parse_quote!(()))
} else {
(**ty).clone()
}
}
};
let returns_result = match func_output {
ReturnType::Default => false,
ReturnType::Type(_, ty) => is_result_type(ty),
};
let async_body = if returns_result {
quote! {
#func_body
}
} else {
quote! {
{
Ok(#func_body)
}
}
};
Ok(quote! {
#(#func_attrs)*
#[napi_derive_ohos::napi]
#func_vis fn #func_name<'env>(
env: &'env napi_ohos::Env,
#func_inputs
) -> napi_ohos::Result<napi_ohos::bindgen_prelude::PromiseRaw<'env, #inner_return_type>> {
use ohos_ext::SpawnLocalExt;
env.spawn_local(async move #async_body)
}
})
}
/// Reports whether `ty` is a path type whose final segment is named `Result`
/// and whose path is accepted by [`is_napi_ohos_path`].
///
/// Non-path types and paths ending in any other identifier yield `false`.
fn is_result_type(ty: &Type) -> bool {
    let path = match ty {
        Type::Path(type_path) => &type_path.path,
        _ => return false,
    };

    match path.segments.last() {
        Some(last) if last.ident == "Result" => is_napi_ohos_path(path),
        _ => false,
    }
}
/// Reports whether `path` is spelled either as a lone `Result` segment or as
/// exactly `napi_ohos::Result`.
///
/// Only the segment identifiers are compared; generic arguments attached to a
/// segment are not inspected.
fn is_napi_ohos_path(path: &syn::Path) -> bool {
    let segments = &path.segments;
    match segments.len() {
        1 => segments[0].ident == "Result",
        2 => segments[0].ident == "napi_ohos" && segments[1].ident == "Result",
        _ => false,
    }
}
/// Returns a clone of the first generic argument of an accepted `Result` type.
///
/// Yields `None` when `ty` is not a path type, when its last segment is not
/// `Result` or the path is rejected by [`is_napi_ohos_path`], when the segment
/// has no angle-bracketed arguments, or when the first argument is not a type.
fn extract_result_inner_type(ty: &Type) -> Option<Type> {
    let path = match ty {
        Type::Path(type_path) => &type_path.path,
        _ => return None,
    };

    let last = path.segments.last()?;
    if last.ident != "Result" || !is_napi_ohos_path(path) {
        return None;
    }

    let generics = match &last.arguments {
        syn::PathArguments::AngleBracketed(generics) => generics,
        _ => return None,
    };

    match generics.args.first() {
        Some(syn::GenericArgument::Type(inner)) => Some(inner.clone()),
        _ => None,
    }
}