// appentry_derive/lib.rs

use proc_macro::TokenStream;
use proc_macro2::Span;
use quote::quote;

5#[proc_macro_attribute]
6pub fn appentry(_args: TokenStream, input: TokenStream) -> TokenStream {
7    let input_fn = syn::parse_macro_input!(input as syn::ItemFn);
8    let fn_sig = &input_fn.sig;
9
10    // Extract function name
11    let fn_name = fn_sig.ident.to_string();
12    let fn_ident = fn_sig.ident.clone();
13
14    // Extract argument information
15    let mut arg_names = Vec::new();
16    let mut arg_types = Vec::new();
17    let mut arg_descs = Vec::new();
18
19    // Extract doc comments from the function to find parameter descriptions
20    let doc_comments = extract_doc_comments(&input_fn.attrs);
21
22    // Extract function description from doc comments (first non-whitespace line)
23    let func_desc = extract_func_description(&doc_comments);
24
25    for arg in &fn_sig.inputs {
26        if let syn::FnArg::Typed(syn::PatType { pat, ty, .. }) = arg {
27            let arg_name = match pat.as_ref() {
28                syn::Pat::Ident(ident) => ident.ident.to_string(),
29                _ => "_".to_string(),
30            };
31            let arg_type_str = quote! { #ty }.to_string();
32
33            // Look for description of this parameter in doc comments
34            let arg_desc = extract_param_description(&doc_comments, &arg_name);
35
36            arg_names.push(arg_name);
37            arg_types.push(arg_type_str);
38            arg_descs.push(arg_desc);
39        }
40    }
41
42    // Convert to static arrays
43    let arg_count = arg_names.len();
44    let arg_names_literals: Vec<syn::LitStr> = arg_names
45        .iter()
46        .map(|name| syn::LitStr::new(name, Span::call_site()))
47        .collect();
48    let arg_types_literals: Vec<syn::LitStr> = arg_types
49        .iter()
50        .map(|ty| syn::LitStr::new(ty, Span::call_site()))
51        .collect();
52    let arg_descs_literals: Vec<proc_macro2::TokenStream> = arg_descs
53        .iter()
54        .map(|desc_opt| match desc_opt {
55            Some(desc) => {
56                let lit_str = syn::LitStr::new(desc, Span::call_site());
57                quote! { Some(#lit_str) }
58            }
59            None => quote! { None },
60        })
61        .collect();
62
63    // Generate a wrapper function that handles arguments
64    let wrapper_fn_name = syn::Ident::new(&format!("appentry_{}", fn_name), Span::call_site());
65
66    // Process the original function's parameters to generate appropriate argument extraction
67    let mut inputs_with_names = Vec::new();
68    for (_i, input) in fn_sig.inputs.iter().enumerate() {
69        if let syn::FnArg::Typed(syn::PatType { pat, ty, .. }) = input {
70            if let syn::Pat::Ident(ident) = pat.as_ref() {
71                let arg_name = &ident.ident;
72                let arg_name_str = arg_name.to_string();
73                let short_arg = format!("-{}", arg_name_str.chars().next().unwrap_or('_'));
74                let long_arg = format!("--{}", arg_name_str);
75
76                // Check if the type is bool to handle differently
77                let is_bool = if let syn::Type::Path(type_path) = ty.as_ref() {
78                    type_path
79                        .path
80                        .segments
81                        .last()
82                        .map_or(false, |seg| seg.ident == "bool")
83                } else {
84                    false
85                };
86
87                inputs_with_names.push((
88                    arg_name.clone(),
89                    ty.clone(),
90                    short_arg,
91                    long_arg,
92                    is_bool,
93                ));
94            }
95        }
96    }
97
98    let param_processing: Vec<proc_macro2::TokenStream> = inputs_with_names
99        .iter()
100        .map(|(arg_ident, _, short_arg, long_arg, is_bool)| {
101            let short_arg_lit = syn::LitStr::new(short_arg, Span::call_site());
102            let long_arg_lit = syn::LitStr::new(long_arg, Span::call_site());
103            if *is_bool {
104                // For boolean arguments, we can just check if the flag exists
105                quote! {
106                    let #arg_ident = ::appentry::get_arg_from_name(args, &[#short_arg_lit, #long_arg_lit]);
107                }
108            } else {
109                quote! {
110                    let #arg_ident = ::appentry::get_arg_from_name(args, &[#short_arg_lit, #long_arg_lit]);
111                }
112            }
113        })
114        .collect();
115
116    // Generate parameter extraction for async context (to avoid lifetime issues)
117    let async_param_processing: Vec<proc_macro2::TokenStream> = inputs_with_names
118        .iter()
119        .map(|(arg_ident, _ty, short_arg, long_arg, is_bool)| {
120            let short_arg_lit = syn::LitStr::new(short_arg, Span::call_site());
121            let long_arg_lit = syn::LitStr::new(long_arg, Span::call_site());
122            // For async context, we'll call the same function but in async context
123            if *is_bool {
124                quote! {
125                    let #arg_ident = ::appentry::get_arg_from_name(args, &[#short_arg_lit, #long_arg_lit]);
126                }
127            } else {
128                quote! {
129                    let #arg_ident = ::appentry::get_arg_from_name(args, &[#short_arg_lit, #long_arg_lit]);
130                }
131            }
132        })
133        .collect();
134
135    let arg_refs: Vec<syn::Ident> = inputs_with_names
136        .iter()
137        .map(|(name, _, _, _, _)| name.clone())
138        .collect();
139
140    // Check if the return type is Result<_, _>
141    let has_result_return = if let syn::ReturnType::Type(_, ty) = &fn_sig.output {
142        if let syn::Type::Path(type_path) = ty.as_ref() {
143            type_path
144                .path
145                .segments
146                .last()
147                .map_or(false, |segment| segment.ident == "Result")
148        } else {
149            false
150        }
151    } else {
152        false
153    };
154
155    // Check if the original function is async
156    let is_async = fn_sig.asyncness.is_some();
157
158    // Generate the inventory submission code
159    let fn_name_literal = syn::LitStr::new(&fn_name, Span::call_site());
160    let original_function = &input_fn;
161
162    // Define the wrapper function based on whether the original function is async
163    let wrapper_function_definition = if is_async {
164        let call_with_result_handling = match has_result_return {
165            true => quote! { #fn_ident(#(#arg_refs),*).await?; },
166            false => quote! { #fn_ident(#(#arg_refs),*).await; },
167        };
168        let async_wrapper = quote! {
169            fn #wrapper_fn_name(args: &mut std::collections::HashMap<String, Option<String>>) -> std::pin::Pin<Box<dyn std::future::Future<Output = anyhow::Result<()>>>> {
170                #(#async_param_processing)* // Process parameters synchronously
171                Box::pin(async move {
172                    #call_with_result_handling
173                    Ok(())
174                })
175            }
176        };
177        async_wrapper
178    } else {
179        let call_with_result_handling = match has_result_return {
180            true => quote! { #fn_ident(#(#arg_refs),*)?; },
181            false => quote! { #fn_ident(#(#arg_refs),*); },
182        };
183        quote! {
184            fn #wrapper_fn_name(args: &mut std::collections::HashMap<String, Option<String>>) -> anyhow::Result<()> {
185                #(#param_processing)*
186                #call_with_result_handling
187                Ok(())
188            }
189        }
190    };
191
192    // Define the method type based on whether the original function is async
193    let method_type = if is_async {
194        quote! {
195            ::appentry::AppEntryMethod::Async(#wrapper_fn_name)
196        }
197    } else {
198        quote! {
199            ::appentry::AppEntryMethod::Sync(#wrapper_fn_name)
200        }
201    };
202
203    let expanded = quote! {
204        // The original function
205        #original_function
206
207        // Create a wrapper function to handle arguments
208        #wrapper_function_definition
209
210        // Submit the function info to inventory directly
211        ::inventory::submit! {
212            {
213                const ARGS: [::appentry::ArgInfo; #arg_count] = [
214                    #(
215                        ::appentry::ArgInfo::new_with_desc(
216                            #arg_names_literals,
217                            #arg_types_literals,
218                            #arg_descs_literals
219                        ),
220                    )*
221                ];
222                ::appentry::FunctionInfo::new(
223                    #fn_name_literal,
224                    #func_desc,
225                    &ARGS,
226                    #method_type
227                )
228            }
229        }
230    };
231
232    //panic!("{}", expanded.to_string());
233    expanded.into()
234}
235
236fn extract_doc_comments(attrs: &[syn::Attribute]) -> String {
237    attrs
238        .iter()
239        .filter(|attr| attr.path().is_ident("doc"))
240        .filter_map(|attr| {
241            if let syn::Meta::NameValue(syn::MetaNameValue {
242                value:
243                    syn::Expr::Lit(syn::ExprLit {
244                        lit: syn::Lit::Str(lit_str),
245                        ..
246                    }),
247                ..
248            }) = attr.meta.clone()
249            {
250                Some(lit_str.value())
251            } else {
252                None
253            }
254        })
255        .collect::<Vec<_>>()
256        .join("\n")
257}
258
/// Looks up the description of parameter `param_name` in a function's doc text.
///
/// Two formats are recognised. Each is tried over the whole text, in this order:
///
/// 1. Rustdoc list style, e.g. ``* `name` - The person to greet``. This also covers
///    entries under a `# Arguments` header. The description is the trimmed text
///    after the first ` - ` that follows the backticked name.
/// 2. A line of the form `name: description`. The line (after trimming) must start
///    with exactly `param_name`, followed by optional whitespace and a colon. So
///    `names: …` or `name is: …` do not count for `name`.
///
/// Returns `None` if no line yields a non-empty description.
fn extract_param_description(doc_comments: &str, param_name: &str) -> Option<String> {
    let backticked = format!("`{}`", param_name);
    const DASH: &str = " - ";

    let dash_style = doc_comments.lines().find_map(|line| {
        let trimmed = line.trim();
        let from_name = &trimmed[trimmed.find(backticked.as_str())?..];
        let dash_pos = from_name.find(DASH)?;
        let desc = from_name[dash_pos + DASH.len()..].trim();
        (!desc.is_empty()).then(|| desc.to_string())
    });

    dash_style.or_else(|| {
        doc_comments.lines().find_map(|line| {
            let rest = line.trim().strip_prefix(param_name)?;
            let desc = rest.trim_start().strip_prefix(':')?.trim();
            (!desc.is_empty()).then(|| desc.to_string())
        })
    })
}

326fn extract_func_description(doc_comments: &str) -> Option<proc_macro2::TokenStream> {
327    // Split the doc comments into lines and find the first meaningful description
328    // Skip any empty lines or lines that are just whitespace
329    let lines: Vec<&str> = doc_comments.lines().collect();
330
331    for line in lines {
332        let trimmed = line.trim();
333        // Skip empty lines or lines that start with '#' (headers) or '*' (list items)
334        if !trimmed.is_empty() && !trimmed.starts_with('#') && !trimmed.starts_with('*') {
335            // This is likely the function description
336            // Avoid returning argument descriptions
337            if !trimmed.to_lowercase().contains("arguments")
338                && !trimmed.to_lowercase().contains("params")
339                && !trimmed.to_lowercase().contains(":")
340            {
341                let lit_str = syn::LitStr::new(trimmed, Span::call_site());
342                return Some(quote! { Some(#lit_str) });
343            }
344        }
345    }
346    Some(quote! { None })
347}