Skip to main content

appentry_derive/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::Span;
3use quote::quote;
4
5#[proc_macro_attribute]
6pub fn appentry(_args: TokenStream, input: TokenStream) -> TokenStream {
7    let input_fn = syn::parse_macro_input!(input as syn::ItemFn);
8    let fn_sig = &input_fn.sig;
9
10    // Extract function name
11    let fn_name = fn_sig.ident.to_string();
12    let fn_ident = fn_sig.ident.clone();
13
14    // Extract argument information
15    let mut arg_names = Vec::new();
16    let mut arg_types = Vec::new();
17    let mut arg_descs = Vec::new();
18
19    // Extract doc comments from the function to find parameter descriptions
20    let doc_comments = extract_doc_comments(&input_fn.attrs);
21
22    // Extract function description from doc comments (first non-whitespace line)
23    let func_desc = extract_func_description(&doc_comments);
24
25    for arg in &fn_sig.inputs {
26        if let syn::FnArg::Typed(syn::PatType { pat, ty, .. }) = arg {
27            let arg_name = match pat.as_ref() {
28                syn::Pat::Ident(ident) => ident.ident.to_string(),
29                _ => "_".to_string(),
30            };
31            let arg_type_str = quote! { #ty }.to_string();
32
33            // Look for description of this parameter in doc comments
34            let arg_desc = extract_param_description(&doc_comments, &arg_name);
35
36            arg_names.push(arg_name);
37            arg_types.push(arg_type_str);
38            arg_descs.push(arg_desc);
39        }
40    }
41
42    // Convert to static arrays
43    let arg_count = arg_names.len();
44    let arg_names_literals: Vec<syn::LitStr> = arg_names
45        .iter()
46        .map(|name| syn::LitStr::new(name, Span::call_site()))
47        .collect();
48    let arg_types_literals: Vec<syn::LitStr> = arg_types
49        .iter()
50        .map(|ty| syn::LitStr::new(ty, Span::call_site()))
51        .collect();
52    let arg_descs_literals: Vec<proc_macro2::TokenStream> = arg_descs
53        .iter()
54        .map(|desc_opt| match desc_opt {
55            Some(desc) => {
56                let lit_str = syn::LitStr::new(desc, Span::call_site());
57                quote! { Some(#lit_str) }
58            }
59            None => quote! { None },
60        })
61        .collect();
62
63    // Generate a wrapper function that handles arguments
64    let wrapper_fn_name = syn::Ident::new(&format!("appentry_{}", fn_name), Span::call_site());
65
66    // Process the original function's parameters to generate appropriate argument extraction
67    let mut inputs_with_names = Vec::new();
68    for (_i, input) in fn_sig.inputs.iter().enumerate() {
69        if let syn::FnArg::Typed(syn::PatType { pat, ty, .. }) = input {
70            if let syn::Pat::Ident(ident) = pat.as_ref() {
71                let arg_name = &ident.ident;
72                let arg_name_str = arg_name.to_string();
73                let short_arg = format!("-{}", arg_name_str.chars().next().unwrap_or('_'));
74                let long_arg = format!("--{}", arg_name_str);
75
76                // Check if the type is bool to handle differently
77                let is_bool = if let syn::Type::Path(type_path) = ty.as_ref() {
78                    type_path
79                        .path
80                        .segments
81                        .last()
82                        .map_or(false, |seg| seg.ident == "bool")
83                } else {
84                    false
85                };
86
87                inputs_with_names.push((
88                    arg_name.clone(),
89                    ty.clone(),
90                    short_arg,
91                    long_arg,
92                    is_bool,
93                ));
94            }
95        }
96    }
97
98    let param_processing: Vec<proc_macro2::TokenStream> = inputs_with_names
99        .iter()
100        .map(|(arg_ident, _, short_arg, long_arg, is_bool)| {
101            let short_arg_lit = syn::LitStr::new(short_arg, Span::call_site());
102            let long_arg_lit = syn::LitStr::new(long_arg, Span::call_site());
103            if *is_bool {
104                // For boolean arguments, we can just check if the flag exists
105                quote! {
106                    let #arg_ident = ::appentry::get_arg_from_name(args, &[#short_arg_lit, #long_arg_lit]);
107                }
108            } else {
109                quote! {
110                    let #arg_ident = ::appentry::get_arg_from_name(args, &[#short_arg_lit, #long_arg_lit]);
111                }
112            }
113        })
114        .collect();
115
116    let arg_refs: Vec<syn::Ident> = inputs_with_names
117        .iter()
118        .map(|(name, _, _, _, _)| name.clone())
119        .collect();
120
121    // Check if the return type is Result<_, _>
122    let has_result_return = if let syn::ReturnType::Type(_, ty) = &fn_sig.output {
123        if let syn::Type::Path(type_path) = ty.as_ref() {
124            type_path
125                .path
126                .segments
127                .last()
128                .map_or(false, |segment| segment.ident == "Result")
129        } else {
130            false
131        }
132    } else {
133        false
134    };
135
136    // Generate the inventory submission code
137    let fn_name_literal = syn::LitStr::new(&fn_name, Span::call_site());
138    let original_function = &input_fn;
139
140    let call_with_result_handling = if has_result_return {
141        quote! {
142            #fn_ident(#(#arg_refs),*)?;
143        }
144    } else {
145        quote! {
146            #fn_ident(#(#arg_refs),*);
147        }
148    };
149
150    let expanded = quote! {
151        // The original function
152        #original_function
153
154        // Create a wrapper function to handle arguments
155        fn #wrapper_fn_name(args: &mut std::collections::HashMap<String, Option<String>>) -> anyhow::Result<()> {
156            #(#param_processing)*
157            #call_with_result_handling
158            Ok(())
159        }
160
161        // Submit the function info to inventory directly
162        ::inventory::submit! {
163            {
164                const ARGS: [::appentry::ArgInfo; #arg_count] = [
165                    #(
166                        ::appentry::ArgInfo::new_with_desc(
167                            #arg_names_literals,
168                            #arg_types_literals,
169                            #arg_descs_literals
170                        ),
171                    )*
172                ];
173                ::appentry::FunctionInfo::new_with_desc(
174                    #fn_name_literal,
175                    #func_desc,
176                    &ARGS,
177                    #wrapper_fn_name as fn(&mut std::collections::HashMap<String, Option<String>>) -> anyhow::Result<()>
178                )
179            }
180        }
181    };
182
183    //panic!("{}", expanded.to_string());
184    expanded.into()
185}
186
187fn extract_doc_comments(attrs: &[syn::Attribute]) -> String {
188    attrs
189        .iter()
190        .filter(|attr| attr.path().is_ident("doc"))
191        .filter_map(|attr| {
192            if let syn::Meta::NameValue(syn::MetaNameValue {
193                value:
194                    syn::Expr::Lit(syn::ExprLit {
195                        lit: syn::Lit::Str(lit_str),
196                        ..
197                    }),
198                ..
199            }) = attr.meta.clone()
200            {
201                Some(lit_str.value())
202            } else {
203                None
204            }
205        })
206        .collect::<Vec<_>>()
207        .join("\n")
208}
209
/// Looks up the description of `param_name` in a function's collected doc text.
///
/// Two formats are recognised. Each is tried against every (trimmed) line before the next
/// format is considered:
///
/// 1. Rustdoc list style: a line containing `` `param_name` `` followed later on the same
///    line by ` - `, e.g. ``* `name` - The name of the person to greet``. This also covers
///    entries under a `# Arguments` heading.
/// 2. Colon style: a line that starts with `param_name`, then optional whitespace, then
///    `:`, e.g. `name: The name of the person to greet`.
///
/// Returns the trimmed text after the separator, or `None` if no line yields a non-empty
/// description (or `param_name` is empty).
fn extract_param_description(doc_comments: &str, param_name: &str) -> Option<String> {
    if param_name.is_empty() {
        return None;
    }

    let backticked = format!("`{}`", param_name);

    let list_style = |line: &str| -> Option<String> {
        let start = line.find(&backticked)?;
        // Offset of " - " measured from the backticked name; +3 skips the separator itself.
        let dash = line[start..].find(" - ")?;
        let desc = line[start + dash + 3..].trim();
        (!desc.is_empty()).then(|| desc.to_string())
    };

    let colon_style = |line: &str| -> Option<String> {
        // Require `:` right after the name (modulo spaces) so that parameter `n` does not
        // pick up the `name: ...` line, and `name` does not pick up `names: ...`.
        let rest = line.strip_prefix(param_name)?.trim_start();
        let desc = rest.strip_prefix(':')?.trim();
        (!desc.is_empty()).then(|| desc.to_string())
    };

    doc_comments
        .lines()
        .map(str::trim)
        .find_map(list_style)
        .or_else(|| doc_comments.lines().map(str::trim).find_map(colon_style))
}
276
277fn extract_func_description(doc_comments: &str) -> Option<proc_macro2::TokenStream> {
278    // Split the doc comments into lines and find the first meaningful description
279    // Skip any empty lines or lines that are just whitespace
280    let lines: Vec<&str> = doc_comments.lines().collect();
281
282    for line in lines {
283        let trimmed = line.trim();
284        // Skip empty lines or lines that start with '#' (headers) or '*' (list items)
285        if !trimmed.is_empty() && !trimmed.starts_with('#') && !trimmed.starts_with('*') {
286            // This is likely the function description
287            // Avoid returning argument descriptions
288            if !trimmed.to_lowercase().contains("arguments")
289                && !trimmed.to_lowercase().contains("params")
290                && !trimmed.to_lowercase().contains(":")
291            {
292                let lit_str = syn::LitStr::new(trimmed, Span::call_site());
293                return Some(quote! { Some(#lit_str) });
294            }
295        }
296    }
297    Some(quote! { None })
298}