1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{ItemFn, ReturnType, Type, parse_quote};
4
5#[proc_macro_attribute]
6pub fn ffrt(_args: TokenStream, input: TokenStream) -> TokenStream {
7 convert(input.into())
8 .unwrap_or_else(|err| err.into_compile_error())
9 .into()
10}
11
12fn convert(input: proc_macro2::TokenStream) -> Result<proc_macro2::TokenStream, syn::Error> {
13 let func = syn::parse2::<ItemFn>(input)?;
14
15 if func.sig.asyncness.is_none() {
16 return Err(syn::Error::new_spanned(
17 func,
18 "ffrt macro only supports async functions",
19 ));
20 }
21
22 let func_name = &func.sig.ident;
24 let func_vis = &func.vis;
25 let func_attrs = &func.attrs;
26 let func_inputs = &func.sig.inputs;
27 let func_body = &func.block;
28 let func_output = &func.sig.output;
29
30 let mut param_names = Vec::new();
32 for input in func_inputs.iter() {
33 if let syn::FnArg::Typed(pat_type) = input {
34 param_names.push(&pat_type.pat);
35 }
36 }
37
38 if let ReturnType::Type(_, ty) = func_output {
40 if let Type::Path(type_path) = &**ty {
41 if let Some(segment) = type_path.path.segments.last() {
42 if segment.ident == "Result" && !is_napi_ohos_path(&type_path.path) {
43 return Err(syn::Error::new_spanned(
44 ty,
45 "ffrt macro requires napi_ohos::Result, not std::result::Result or other Result types",
46 ));
47 }
48 }
49 }
50 }
51
52 let inner_return_type = match func_output {
54 ReturnType::Default => {
55 parse_quote!(())
57 }
58 ReturnType::Type(_, ty) => {
59 if is_result_type(ty) {
61 extract_result_inner_type(ty).unwrap_or_else(|| parse_quote!(()))
63 } else {
64 (**ty).clone()
66 }
67 }
68 };
69
70 let returns_result = match func_output {
72 ReturnType::Default => false,
73 ReturnType::Type(_, ty) => is_result_type(ty),
74 };
75
76 let async_body = if returns_result {
78 quote! {
80 #func_body
81 }
82 } else {
83 quote! {
87 {
88 Ok(#func_body)
89 }
90 }
91 };
92
93 Ok(quote! {
95 #(#func_attrs)*
96 #[napi_derive_ohos::napi]
97 #func_vis fn #func_name<'env>(
98 env: &'env napi_ohos::Env,
99 #func_inputs
100 ) -> napi_ohos::Result<napi_ohos::bindgen_prelude::PromiseRaw<'env, #inner_return_type>> {
101 use ohos_ext::SpawnLocalExt;
102
103 env.spawn_local(async move #async_body)
104 }
105 })
106}
107
108fn is_result_type(ty: &Type) -> bool {
109 if let Type::Path(type_path) = ty {
110 if let Some(segment) = type_path.path.segments.last() {
112 if segment.ident == "Result" {
113 return is_napi_ohos_path(&type_path.path);
115 }
116 }
117 }
118 false
119}
120
121fn is_napi_ohos_path(path: &syn::Path) -> bool {
122 let path_str = path
127 .segments
128 .iter()
129 .map(|s| s.ident.to_string())
130 .collect::<Vec<_>>()
131 .join("::");
132
133 path_str == "napi_ohos::Result"
135 || (path.segments.len() == 1 && path.segments[0].ident == "Result")
136}
137
138fn extract_result_inner_type(ty: &Type) -> Option<Type> {
139 if let Type::Path(type_path) = ty {
140 if let Some(segment) = type_path.path.segments.last() {
141 if segment.ident == "Result" && is_napi_ohos_path(&type_path.path) {
142 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
143 if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() {
144 return Some(inner_ty.clone());
145 }
146 }
147 }
148 }
149 }
150 None
151}