Skip to main content

appentry_derive/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::Span;
3use quote::quote;
4
5#[proc_macro_attribute]
6pub fn appentry(args: TokenStream, input: TokenStream) -> TokenStream {
7    let input_fn = syn::parse_macro_input!(input as syn::ItemFn);
8    let fn_sig = &input_fn.sig;
9
10    // Extract function name
11    let fn_name = fn_sig.ident.to_string();
12    let fn_ident = fn_sig.ident.clone();
13
14    // Parse the arguments to extract aliases
15    let alias_exprs: Vec<String> = if !args.is_empty() {
16        let parsed_args = syn::parse_macro_input!(args as syn::ExprArray);
17        parsed_args
18            .elems
19            .iter()
20            .filter_map(|expr| {
21                if let syn::Expr::Lit(syn::ExprLit {
22                    lit: syn::Lit::Str(lit_str),
23                    ..
24                }) = expr
25                {
26                    Some(lit_str.value())
27                } else {
28                    None
29                }
30            })
31            .collect()
32    } else {
33        vec![fn_name.to_lowercase()] // Default alias is lowercase function name
34    };
35
36    // Extract argument information
37    let mut arg_names = Vec::new();
38    let mut arg_types = Vec::new();
39    let mut arg_descs = Vec::new();
40
41    // Extract doc comments from the function to find parameter descriptions
42    let doc_comments = extract_doc_comments(&input_fn.attrs);
43
44    // Extract function description from doc comments (first non-whitespace line)
45    let func_desc = extract_func_description(&doc_comments);
46
47    for arg in &fn_sig.inputs {
48        if let syn::FnArg::Typed(syn::PatType { pat, ty, .. }) = arg {
49            let arg_name = match pat.as_ref() {
50                syn::Pat::Ident(ident) => ident.ident.to_string(),
51                _ => "_".to_string(),
52            };
53            let arg_type_str = quote! { #ty }.to_string();
54
55            // Look for description of this parameter in doc comments
56            let arg_desc = extract_param_description(&doc_comments, &arg_name);
57
58            arg_names.push(arg_name);
59            arg_types.push(arg_type_str);
60            arg_descs.push(arg_desc);
61        }
62    }
63
64    // Convert to static arrays
65    let arg_count = arg_names.len();
66    let arg_names_literals: Vec<syn::LitStr> = arg_names
67        .iter()
68        .map(|name| syn::LitStr::new(name, Span::call_site()))
69        .collect();
70    let arg_types_literals: Vec<syn::LitStr> = arg_types
71        .iter()
72        .map(|ty| syn::LitStr::new(ty, Span::call_site()))
73        .collect();
74    let arg_descs_literals: Vec<proc_macro2::TokenStream> = arg_descs
75        .iter()
76        .map(|desc_opt| match desc_opt {
77            Some(desc) => {
78                let lit_str = syn::LitStr::new(desc, Span::call_site());
79                quote! { Some(#lit_str) }
80            }
81            None => quote! { None },
82        })
83        .collect();
84
85    // Generate a wrapper function that handles arguments
86    let wrapper_fn_name = syn::Ident::new(&format!("appentry_{}", fn_name), Span::call_site());
87
88    // Process the original function's parameters to generate appropriate argument extraction
89    let mut inputs_with_names = Vec::new();
90    for (_i, input) in fn_sig.inputs.iter().enumerate() {
91        if let syn::FnArg::Typed(syn::PatType { pat, ty, .. }) = input {
92            if let syn::Pat::Ident(ident) = pat.as_ref() {
93                let arg_name = &ident.ident;
94                let arg_name_str = arg_name.to_string();
95                let short_arg = format!("-{}", arg_name_str.chars().next().unwrap_or('_'));
96                let long_arg = format!("--{}", arg_name_str);
97
98                // Check if the type is bool to handle differently
99                let is_bool = if let syn::Type::Path(type_path) = ty.as_ref() {
100                    type_path
101                        .path
102                        .segments
103                        .last()
104                        .map_or(false, |seg| seg.ident == "bool")
105                } else {
106                    false
107                };
108
109                inputs_with_names.push((
110                    arg_name.clone(),
111                    ty.clone(),
112                    short_arg,
113                    long_arg,
114                    is_bool,
115                ));
116            }
117        }
118    }
119
120    let param_processing: Vec<proc_macro2::TokenStream> = inputs_with_names
121        .iter()
122        .map(|(arg_ident, _, short_arg, long_arg, is_bool)| {
123            let short_arg_lit = syn::LitStr::new(short_arg, Span::call_site());
124            let long_arg_lit = syn::LitStr::new(long_arg, Span::call_site());
125            if *is_bool {
126                // For boolean arguments, we can just check if the flag exists
127                quote! {
128                    let #arg_ident = ::appentry::get_arg_from_name(args, &[#short_arg_lit, #long_arg_lit]);
129                }
130            } else {
131                quote! {
132                    let #arg_ident = ::appentry::get_arg_from_name(args, &[#short_arg_lit, #long_arg_lit]);
133                }
134            }
135        })
136        .collect();
137
138    let arg_refs: Vec<syn::Ident> = inputs_with_names
139        .iter()
140        .map(|(name, _, _, _, _)| name.clone())
141        .collect();
142
143    // Check if the return type is Result<_, _>
144    let has_result_return = if let syn::ReturnType::Type(_, ty) = &fn_sig.output {
145        if let syn::Type::Path(type_path) = ty.as_ref() {
146            type_path
147                .path
148                .segments
149                .last()
150                .map_or(false, |segment| segment.ident == "Result")
151        } else {
152            false
153        }
154    } else {
155        false
156    };
157
158    // Generate the inventory submission code
159    let fn_name_literal = syn::LitStr::new(&fn_name, Span::call_site());
160    let original_function = &input_fn;
161
162    let call_with_result_handling = if has_result_return {
163        quote! {
164            #fn_ident(#(#arg_refs),*)?;
165        }
166    } else {
167        quote! {
168            #fn_ident(#(#arg_refs),*);
169        }
170    };
171
172    // Create static array literals for aliases
173    let alias_literals: Vec<syn::LitStr> = alias_exprs
174        .iter()
175        .map(|alias| syn::LitStr::new(alias, Span::call_site()))
176        .collect();
177    let alias_count = alias_literals.len();
178
179    let expanded = quote! {
180        // The original function
181        #original_function
182
183        // Create a wrapper function to handle arguments
184        fn #wrapper_fn_name(args: &mut std::collections::HashMap<String, Option<String>>) -> anyhow::Result<()> {
185            #(#param_processing)*
186            #call_with_result_handling
187            Ok(())
188        }
189
190        // Submit the function info to inventory directly
191        ::inventory::submit! {
192            {
193                const ALIASES: [&str; #alias_count] = [#(#alias_literals),*];
194                const ARGS: [::appentry::ArgInfo; #arg_count] = [
195                    #(
196                        ::appentry::ArgInfo::new_with_desc(
197                            #arg_names_literals,
198                            #arg_types_literals,
199                            #arg_descs_literals
200                        ),
201                    )*
202                ];
203                ::appentry::FunctionInfo::new_with_desc(
204                    #fn_name_literal,
205                    &ALIASES,
206                    #func_desc,
207                    &ARGS,
208                    #wrapper_fn_name as fn(&mut std::collections::HashMap<String, Option<String>>) -> anyhow::Result<()>
209                )
210            }
211        }
212    };
213
214    //panic!("{}", expanded.to_string());
215    expanded.into()
216}
217
218fn extract_doc_comments(attrs: &[syn::Attribute]) -> String {
219    attrs
220        .iter()
221        .filter(|attr| attr.path().is_ident("doc"))
222        .filter_map(|attr| {
223            if let syn::Meta::NameValue(syn::MetaNameValue {
224                value:
225                    syn::Expr::Lit(syn::ExprLit {
226                        lit: syn::Lit::Str(lit_str),
227                        ..
228                    }),
229                ..
230            }) = attr.meta.clone()
231            {
232                Some(lit_str.value())
233            } else {
234                None
235            }
236        })
237        .collect::<Vec<_>>()
238        .join("\n")
239}
240
/// Looks up the doc-comment description of the parameter `param_name`.
///
/// Two formats are recognised. Each is tried across the whole doc text, in this order:
/// 1. A line containing `` `param_name` `` followed later by ` - `, as in the rustdoc
///    convention ``* `name` - The name of the person to greet``. The text after the
///    first ` - ` following the backticked name is returned. This also covers lines in
///    a `# Arguments` section, which conventionally use this form.
/// 2. A line that starts with `param_name` as a whole word and contains a `:`, as in
///    `name: The person` or `name (String): The person`. The text after the first
///    `:` is returned.
///
/// Within each format the first line with a non-empty (trimmed) description wins.
/// Returns `None` when nothing matches or `param_name` is empty.
fn extract_param_description(doc_comments: &str, param_name: &str) -> Option<String> {
    if param_name.is_empty() {
        return None;
    }

    let backticked = format!("`{}`", param_name);

    // Format 1: "* `name` - description".
    let dashed = doc_comments.lines().find_map(|line| {
        let line = line.trim();
        let after_name = &line[line.find(&backticked)?..];
        let (_, desc) = after_name.split_once(" - ")?;
        non_empty_description(desc)
    });
    if dashed.is_some() {
        return dashed;
    }

    // Format 2: "name: description".
    doc_comments.lines().find_map(|line| {
        let rest = line.trim().strip_prefix(param_name)?;
        // Word boundary: parameter `n` must not pick up "name: ...", nor `name` "names: ...".
        if rest.starts_with(|c: char| c.is_alphanumeric() || c == '_') {
            return None;
        }
        let (_, desc) = rest.split_once(':')?;
        non_empty_description(desc)
    })
}

/// Trims `text` and returns it owned, or `None` when only whitespace remains.
fn non_empty_description(text: &str) -> Option<String> {
    let text = text.trim();
    (!text.is_empty()).then(|| text.to_string())
}
307
308fn extract_func_description(doc_comments: &str) -> Option<proc_macro2::TokenStream> {
309    // Split the doc comments into lines and find the first meaningful description
310    // Skip any empty lines or lines that are just whitespace
311    let lines: Vec<&str> = doc_comments.lines().collect();
312
313    for line in lines {
314        let trimmed = line.trim();
315        // Skip empty lines or lines that start with '#' (headers) or '*' (list items)
316        if !trimmed.is_empty() && !trimmed.starts_with('#') && !trimmed.starts_with('*') {
317            // This is likely the function description
318            // Avoid returning argument descriptions
319            if !trimmed.to_lowercase().contains("arguments")
320                && !trimmed.to_lowercase().contains("params")
321                && !trimmed.to_lowercase().contains(":")
322            {
323                let lit_str = syn::LitStr::new(trimmed, Span::call_site());
324                return Some(quote! { Some(#lit_str) });
325            }
326        }
327    }
328    Some(quote! { None })
329}