Skip to main content

appentry_derive/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::Span;
3use quote::quote;
4
5#[proc_macro_attribute]
6pub fn appentry(args: TokenStream, input: TokenStream) -> TokenStream {
7    let input_fn = syn::parse_macro_input!(input as syn::ItemFn);
8    let fn_sig = &input_fn.sig;
9
10    // Parse macro arguments
11    let args_str = args.to_string();
12    let is_default = args_str.contains("default");
13
14    // Extract function name
15    let fn_name = fn_sig.ident.to_string();
16    let fn_ident = fn_sig.ident.clone();
17
18    // Extract argument information
19    let mut arg_names = Vec::new();
20    let mut arg_types = Vec::new();
21    let mut arg_descs = Vec::new();
22
23    // Extract doc comments from the function to find parameter descriptions
24    let doc_comments = extract_doc_comments(&input_fn.attrs);
25
26    // Extract function description from doc comments (first non-whitespace line)
27    let func_desc = extract_func_description(&doc_comments);
28
29    for arg in &fn_sig.inputs {
30        if let syn::FnArg::Typed(syn::PatType { pat, ty, .. }) = arg {
31            let arg_name = match pat.as_ref() {
32                syn::Pat::Ident(ident) => ident.ident.to_string(),
33                _ => "_".to_string(),
34            };
35            let arg_type_str = quote! { #ty }.to_string();
36
37            // Look for description of this parameter in doc comments
38            let arg_desc = extract_param_description(&doc_comments, &arg_name);
39
40            arg_names.push(arg_name);
41            arg_types.push(arg_type_str);
42            arg_descs.push(arg_desc);
43        }
44    }
45
46    // Convert to static arrays
47    let arg_count = arg_names.len();
48    let arg_names_literals: Vec<syn::LitStr> = arg_names
49        .iter()
50        .map(|name| syn::LitStr::new(name, Span::call_site()))
51        .collect();
52    let arg_types_literals: Vec<syn::LitStr> = arg_types
53        .iter()
54        .map(|ty| syn::LitStr::new(ty, Span::call_site()))
55        .collect();
56    let arg_descs_literals: Vec<proc_macro2::TokenStream> = arg_descs
57        .iter()
58        .map(|desc_opt| match desc_opt {
59            Some(desc) => {
60                let lit_str = syn::LitStr::new(desc, Span::call_site());
61                quote! { Some(#lit_str) }
62            }
63            None => quote! { None },
64        })
65        .collect();
66
67    // Generate a wrapper function that handles arguments
68    let wrapper_fn_name = syn::Ident::new(&format!("appentry_{}", fn_name), Span::call_site());
69
70    // Process the original function's parameters to generate appropriate argument extraction
71    let mut inputs_with_names = Vec::new();
72    for (_i, input) in fn_sig.inputs.iter().enumerate() {
73        if let syn::FnArg::Typed(syn::PatType { pat, ty, .. }) = input {
74            if let syn::Pat::Ident(ident) = pat.as_ref() {
75                let arg_name = &ident.ident;
76                let arg_name_str = arg_name.to_string();
77                let short_arg = format!("-{}", arg_name_str.chars().next().unwrap_or('_'));
78                let long_arg = format!("--{}", arg_name_str);
79
80                // Check if the type is bool to handle differently
81                let is_bool = if let syn::Type::Path(type_path) = ty.as_ref() {
82                    type_path
83                        .path
84                        .segments
85                        .last()
86                        .map_or(false, |seg| seg.ident == "bool")
87                } else {
88                    false
89                };
90
91                inputs_with_names.push((
92                    arg_name.clone(),
93                    ty.clone(),
94                    short_arg,
95                    long_arg,
96                    is_bool,
97                ));
98            }
99        }
100    }
101
102    let param_processing: Vec<proc_macro2::TokenStream> = inputs_with_names
103        .iter()
104        .map(|(arg_ident, _, short_arg, long_arg, is_bool)| {
105            let short_arg_lit = syn::LitStr::new(short_arg, Span::call_site());
106            let long_arg_lit = syn::LitStr::new(long_arg, Span::call_site());
107            if *is_bool {
108                // For boolean arguments, we can just check if the flag exists
109                quote! {
110                    let #arg_ident = ::appentry::get_arg_from_name(args, &[#short_arg_lit, #long_arg_lit]);
111                }
112            } else {
113                quote! {
114                    let #arg_ident = ::appentry::get_arg_from_name(args, &[#short_arg_lit, #long_arg_lit]);
115                }
116            }
117        })
118        .collect();
119
120    // Generate parameter extraction for async context (to avoid lifetime issues)
121    let async_param_processing: Vec<proc_macro2::TokenStream> = inputs_with_names
122        .iter()
123        .map(|(arg_ident, _ty, short_arg, long_arg, is_bool)| {
124            let short_arg_lit = syn::LitStr::new(short_arg, Span::call_site());
125            let long_arg_lit = syn::LitStr::new(long_arg, Span::call_site());
126            // For async context, we'll call the same function but in async context
127            if *is_bool {
128                quote! {
129                    let #arg_ident = ::appentry::get_arg_from_name(args, &[#short_arg_lit, #long_arg_lit]);
130                }
131            } else {
132                quote! {
133                    let #arg_ident = ::appentry::get_arg_from_name(args, &[#short_arg_lit, #long_arg_lit]);
134                }
135            }
136        })
137        .collect();
138
139    let arg_refs: Vec<syn::Ident> = inputs_with_names
140        .iter()
141        .map(|(name, _, _, _, _)| name.clone())
142        .collect();
143
144    // Check if the return type is Result<_, _>
145    let has_result_return = if let syn::ReturnType::Type(_, ty) = &fn_sig.output {
146        if let syn::Type::Path(type_path) = ty.as_ref() {
147            type_path
148                .path
149                .segments
150                .last()
151                .map_or(false, |segment| segment.ident == "Result")
152        } else {
153            false
154        }
155    } else {
156        false
157    };
158
159    // Check if the original function is async
160    let is_async = fn_sig.asyncness.is_some();
161
162    // Generate the inventory submission code
163    let fn_name_literal = syn::LitStr::new(&fn_name, Span::call_site());
164    let original_function = &input_fn;
165
166    // Define the wrapper function based on whether the original function is async
167    let wrapper_function_definition = if is_async {
168        let call_with_result_handling = match has_result_return {
169            true => quote! { #fn_ident(#(#arg_refs),*).await?; },
170            false => quote! { #fn_ident(#(#arg_refs),*).await; },
171        };
172        let async_wrapper = quote! {
173            fn #wrapper_fn_name(args: &mut std::collections::HashMap<String, Option<String>>) -> std::pin::Pin<Box<dyn std::future::Future<Output = anyhow::Result<()>>>> {
174                #(#async_param_processing)* // Process parameters synchronously
175                Box::pin(async move {
176                    #call_with_result_handling
177                    Ok(())
178                })
179            }
180        };
181        async_wrapper
182    } else {
183        let call_with_result_handling = match has_result_return {
184            true => quote! { #fn_ident(#(#arg_refs),*)?; },
185            false => quote! { #fn_ident(#(#arg_refs),*); },
186        };
187        quote! {
188            fn #wrapper_fn_name(args: &mut std::collections::HashMap<String, Option<String>>) -> anyhow::Result<()> {
189                #(#param_processing)*
190                #call_with_result_handling
191                Ok(())
192            }
193        }
194    };
195
196    // Define the method type based on whether the original function is async
197    let method_type = if is_async {
198        quote! {
199            ::appentry::AppEntryMethod::Async(#wrapper_fn_name)
200        }
201    } else {
202        quote! {
203            ::appentry::AppEntryMethod::Sync(#wrapper_fn_name)
204        }
205    };
206
207    let expanded = quote! {
208        // The original function
209        #original_function
210
211        // Create a wrapper function to handle arguments
212        #wrapper_function_definition
213
214        // Submit the function info to inventory directly
215        ::inventory::submit! {
216            {
217                const ARGS: [::appentry::ArgInfo; #arg_count] = [
218                    #(
219                        ::appentry::ArgInfo::new_with_desc(
220                            #arg_names_literals,
221                            #arg_types_literals,
222                            #arg_descs_literals
223                        ),
224                    )*
225                ];
226                ::appentry::FunctionInfo::new(
227                    #fn_name_literal,
228                    #is_default,
229                    #func_desc,
230                    &ARGS,
231                    #method_type
232                )
233            }
234        }
235    };
236
237    //panic!("{}", expanded.to_string());
238    expanded.into()
239}
240
241fn extract_doc_comments(attrs: &[syn::Attribute]) -> String {
242    attrs
243        .iter()
244        .filter(|attr| attr.path().is_ident("doc"))
245        .filter_map(|attr| {
246            if let syn::Meta::NameValue(syn::MetaNameValue {
247                value:
248                    syn::Expr::Lit(syn::ExprLit {
249                        lit: syn::Lit::Str(lit_str),
250                        ..
251                    }),
252                ..
253            }) = attr.meta.clone()
254            {
255                Some(lit_str.value())
256            } else {
257                None
258            }
259        })
260        .collect::<Vec<_>>()
261        .join("\n")
262}
263
/// Finds the description of parameter `param_name` in a function's doc text.
///
/// Two layouts are recognised. Every line is checked for the first layout before
/// any line is checked for the second:
///
/// 1. The backticked name followed later on the same line by ` - `, as in the
///    rustdoc `# Arguments` convention ``* `name` - The name of the person``.
///    The text after the first ` - ` that follows the backticked name is used.
/// 2. `name: description`. The line must start with exactly `param_name`,
///    optionally followed by whitespace, then `:`. A longer word such as
///    `names:` therefore does not match the parameter `name`.
///
/// Lines and the extracted description are trimmed, and matches with an empty
/// description are skipped. Returns `None` if `param_name` is empty or if no
/// line matches.
fn extract_param_description(doc_comments: &str, param_name: &str) -> Option<String> {
    // An empty name would match every line in the `name: description` layout.
    if param_name.is_empty() {
        return None;
    }

    const SEPARATOR: &str = " - ";
    let backticked = format!("`{}`", param_name);

    // Layout 1: "* `name` - description". Scanning every line also covers the
    // `# Arguments` section, so no separate section-aware pass is needed.
    let dash_style = doc_comments.lines().find_map(|line| {
        let line = line.trim();
        let from_name = &line[line.find(&backticked)?..];
        let desc = from_name[from_name.find(SEPARATOR)? + SEPARATOR.len()..].trim();
        (!desc.is_empty()).then(|| desc.to_string())
    });

    // Layout 2: "name: description", consulted only if layout 1 found nothing.
    dash_style.or_else(|| {
        doc_comments.lines().find_map(|line| {
            let after_name = line.trim().strip_prefix(param_name)?;
            let desc = after_name.trim_start().strip_prefix(':')?.trim();
            (!desc.is_empty()).then(|| desc.to_string())
        })
    })
}

331fn extract_func_description(doc_comments: &str) -> Option<proc_macro2::TokenStream> {
332    // Split the doc comments into lines and find the first meaningful description
333    // Skip any empty lines or lines that are just whitespace
334    let lines: Vec<&str> = doc_comments.lines().collect();
335
336    for line in lines {
337        let trimmed = line.trim();
338        // Skip empty lines or lines that start with '#' (headers) or '*' (list items)
339        if !trimmed.is_empty() && !trimmed.starts_with('#') && !trimmed.starts_with('*') {
340            // This is likely the function description
341            // Avoid returning argument descriptions
342            if !trimmed.to_lowercase().contains("arguments")
343                && !trimmed.to_lowercase().contains("params")
344                && !trimmed.to_lowercase().contains(":")
345            {
346                let lit_str = syn::LitStr::new(trimmed, Span::call_site());
347                return Some(quote! { Some(#lit_str) });
348            }
349        }
350    }
351    Some(quote! { None })
352}