ohos_ext_macro/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{ItemFn, ReturnType, Type, parse_quote};
4
5#[proc_macro_attribute]
6pub fn ffrt(_args: TokenStream, input: TokenStream) -> TokenStream {
7    convert(input.into())
8        .unwrap_or_else(|err| err.into_compile_error())
9        .into()
10}
11
12fn convert(input: proc_macro2::TokenStream) -> Result<proc_macro2::TokenStream, syn::Error> {
13    let func = syn::parse2::<ItemFn>(input)?;
14
15    if func.sig.asyncness.is_none() {
16        return Err(syn::Error::new_spanned(
17            func,
18            "ffrt macro only supports async functions",
19        ));
20    }
21
22    // Extract function metadata
23    let func_name = &func.sig.ident;
24    let func_vis = &func.vis;
25    let func_attrs = &func.attrs;
26    let func_inputs = &func.sig.inputs;
27    let func_body = &func.block;
28    let func_output = &func.sig.output;
29
30    // Collect parameter names for the async block
31    let mut param_names = Vec::new();
32    for input in func_inputs.iter() {
33        if let syn::FnArg::Typed(pat_type) = input {
34            param_names.push(&pat_type.pat);
35        }
36    }
37
38    // Check for incorrect Result usage
39    if let ReturnType::Type(_, ty) = func_output {
40        if let Type::Path(type_path) = &**ty {
41            if let Some(segment) = type_path.path.segments.last() {
42                if segment.ident == "Result" && !is_napi_ohos_path(&type_path.path) {
43                    return Err(syn::Error::new_spanned(
44                        ty,
45                        "ffrt macro requires napi_ohos::Result, not std::result::Result or other Result types",
46                    ));
47                }
48            }
49        }
50    }
51
52    // Determine the inner return type (what the async function returns)
53    let inner_return_type = match func_output {
54        ReturnType::Default => {
55            // If no return type, default to ()
56            parse_quote!(())
57        }
58        ReturnType::Type(_, ty) => {
59            // Check if it's already a Result type
60            if is_result_type(ty) {
61                // Extract T from Result<T>
62                extract_result_inner_type(ty).unwrap_or_else(|| parse_quote!(()))
63            } else {
64                // Not a Result, use as-is
65                (**ty).clone()
66            }
67        }
68    };
69
70    // Determine if original function returns Result
71    let returns_result = match func_output {
72        ReturnType::Default => false,
73        ReturnType::Type(_, ty) => is_result_type(ty),
74    };
75
76    // Build the async block body
77    let async_body = if returns_result {
78        // Already returns Result, use as-is
79        quote! {
80            #func_body
81        }
82    } else {
83        // Wrap the entire function body result in Ok()
84        // For fn() -> T, this becomes Ok({ body })
85        // For fn() -> (), this becomes Ok({ body })
86        quote! {
87            {
88                Ok(#func_body)
89            }
90        }
91    };
92
93    // Generate the wrapper function
94    Ok(quote! {
95        #(#func_attrs)*
96        #[napi_derive_ohos::napi]
97        #func_vis fn #func_name<'env>(
98            env: &'env napi_ohos::Env,
99            #func_inputs
100        ) -> napi_ohos::Result<napi_ohos::bindgen_prelude::PromiseRaw<'env, #inner_return_type>> {
101            use ohos_ext::SpawnLocalExt;
102
103            env.spawn_local(async move #async_body)
104        }
105    })
106}
107
108fn is_result_type(ty: &Type) -> bool {
109    if let Type::Path(type_path) = ty {
110        // Check if the last segment is "Result"
111        if let Some(segment) = type_path.path.segments.last() {
112            if segment.ident == "Result" {
113                // Check if it's from napi_ohos
114                return is_napi_ohos_path(&type_path.path);
115            }
116        }
117    }
118    false
119}
120
121fn is_napi_ohos_path(path: &syn::Path) -> bool {
122    // Accept the following patterns:
123    // - napi_ohos::Result
124    // - Result (if imported from napi_ohos)
125
126    let path_str = path
127        .segments
128        .iter()
129        .map(|s| s.ident.to_string())
130        .collect::<Vec<_>>()
131        .join("::");
132
133    // Check if it's explicitly napi_ohos::Result or just Result (assumed to be imported)
134    path_str == "napi_ohos::Result"
135        || (path.segments.len() == 1 && path.segments[0].ident == "Result")
136}
137
138fn extract_result_inner_type(ty: &Type) -> Option<Type> {
139    if let Type::Path(type_path) = ty {
140        if let Some(segment) = type_path.path.segments.last() {
141            if segment.ident == "Result" && is_napi_ohos_path(&type_path.path) {
142                if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
143                    if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() {
144                        return Some(inner_ty.clone());
145                    }
146                }
147            }
148        }
149    }
150    None
151}