Skip to main content

feat_hijekt/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::Span;
3use quote::{format_ident, quote};
4use syn::{
5    parenthesized,
6    parse::{Parse, ParseStream},
7    parse_macro_input, parse_quote,
8    punctuated::Punctuated,
9    visit_mut::VisitMut,
10    Block, Error, Fields, Ident, Item, ItemFn, ItemStruct, Lit, LitStr, Pat, Result, Stmt, Token,
11};
12
/// Options collected from a `#[hijekt(...)]` attribute.
#[derive(Debug, Clone, Default)]
struct HijektConfig {
    /// Cargo feature that gates the rewritten item (`feat = "..."`).
    feat: String,
    /// Zero-argument functions called at the start of the rewritten function.
    begin: Vec<String>,
    /// Functions called at the start with a reference to every forwarded argument.
    begin_with: Vec<String>,
    /// Zero-argument functions called at the end of the rewritten function.
    end: Vec<String>,
    /// Local `fn` items / `let` bindings removed from a function body, or field names
    /// removed from a struct.
    rm: Vec<String>,
    /// Function whose call replaces the whole function body (`swap = "..."`).
    replace: Option<String>,
    /// Fields appended to a struct, written as `"name: Type"` or a bare `"Type"`.
    add: Vec<String>,
}
23
24impl HijektConfig {
25    fn feature_flag(&self) -> String {
26        self.feat.clone()
27    }
28
29    fn is_simple_feature_only(&self) -> bool {
30        self.begin.is_empty()
31            && self.begin_with.is_empty()
32            && self.end.is_empty()
33            && self.rm.is_empty()
34            && self.replace.is_none()
35            && self.add.is_empty()
36    }
37
38    fn parse_meta_item(&mut self, meta: syn::meta::ParseNestedMeta) -> Result<()> {
39        if meta.path.is_ident("feat") {
40            let value = meta.value()?;
41            let lit: LitStr = value.parse()?;
42            self.feat = lit.value();
43            return Ok(());
44        }
45
46        if meta.path.is_ident("begin") {
47            if meta.input.peek(Token![=]) {
48                let value = meta.value()?;
49                let lit: LitStr = value.parse()?;
50                self.begin.push(lit.value());
51            } else if meta.input.peek(syn::token::Paren) {
52                meta.parse_nested_meta(|nested| {
53                    let lit: LitStr = nested.input.parse()?;
54                    self.begin.push(lit.value());
55                    Ok(())
56                })?;
57            }
58            return Ok(());
59        }
60
61        if meta.path.is_ident("begin_with") {
62            if meta.input.peek(Token![=]) {
63                let value = meta.value()?;
64                let lit: LitStr = value.parse()?;
65                self.begin_with.push(lit.value());
66            } else if meta.input.peek(syn::token::Paren) {
67                meta.parse_nested_meta(|nested| {
68                    let lit: LitStr = nested.input.parse()?;
69                    self.begin_with.push(lit.value());
70                    Ok(())
71                })?;
72            }
73            return Ok(());
74        }
75
76        if meta.path.is_ident("end") {
77            if meta.input.peek(Token![=]) {
78                let value = meta.value()?;
79                let lit: LitStr = value.parse()?;
80                self.end.push(lit.value());
81            } else if meta.input.peek(syn::token::Paren) {
82                meta.parse_nested_meta(|nested| {
83                    let lit: LitStr = nested.input.parse()?;
84                    self.end.push(lit.value());
85                    Ok(())
86                })?;
87            }
88            return Ok(());
89        }
90
91        if meta.path.is_ident("rm") {
92            if meta.input.peek(Token![=]) {
93                let value = meta.value()?;
94                let lit: LitStr = value.parse()?;
95                self.rm.push(lit.value());
96            } else if meta.input.peek(syn::token::Paren) {
97                let content;
98                parenthesized!(content in meta.input);
99                let items: Punctuated<Lit, Token![,]> =
100                    content.parse_terminated(Lit::parse, Token![,])?;
101                for item in items {
102                    if let Lit::Str(litstr) = item {
103                        self.rm.push(litstr.value());
104                    }
105                }
106            }
107            return Ok(());
108        }
109
110        if meta.path.is_ident("swap") {
111            let value = meta.value()?;
112            let lit: LitStr = value.parse()?;
113            self.replace = Some(lit.value());
114            return Ok(());
115        }
116
117        if meta.path.is_ident("add") {
118            if meta.input.peek(Token![=]) {
119                let value = meta.value()?;
120                let lit: LitStr = value.parse()?;
121                self.add.push(lit.value());
122            } else if meta.input.peek(syn::token::Paren) {
123                let content;
124                parenthesized!(content in meta.input);
125                let items: Punctuated<Lit, Token![,]> =
126                    content.parse_terminated(Lit::parse, Token![,])?;
127                for item in items {
128                    if let Lit::Str(litstr) = item {
129                        self.add.push(litstr.value());
130                    }
131                }
132            }
133            return Ok(());
134        }
135
136        Err(meta.error("unrecognized hijekt attribute"))
137    }
138}
139
140struct HijektArgs {
141    config: HijektConfig,
142}
143
144impl Parse for HijektArgs {
145    fn parse(input: ParseStream) -> Result<Self> {
146        let mut config = HijektConfig::default();
147
148        let metas = Punctuated::<syn::Meta, Token![,]>::parse_terminated(input)?;
149
150        for meta in metas {
151            match meta {
152                syn::Meta::NameValue(nv) => {
153                    if nv.path.is_ident("feat") {
154                        if let syn::Expr::Lit(lit) = &nv.value {
155                            if let syn::Lit::Str(s) = &lit.lit {
156                                config.feat = s.value();
157                            }
158                        }
159                    } else if nv.path.is_ident("begin") {
160                        if let syn::Expr::Lit(lit) = &nv.value {
161                            if let syn::Lit::Str(s) = &lit.lit {
162                                config.begin.push(s.value());
163                            }
164                        }
165                    } else if nv.path.is_ident("begin_with") {
166                        if let syn::Expr::Lit(lit) = &nv.value {
167                            if let syn::Lit::Str(s) = &lit.lit {
168                                config.begin_with.push(s.value());
169                            }
170                        }
171                    } else if nv.path.is_ident("end") {
172                        if let syn::Expr::Lit(lit) = &nv.value {
173                            if let syn::Lit::Str(s) = &lit.lit {
174                                config.end.push(s.value());
175                            }
176                        }
177                    } else if nv.path.is_ident("swap") {
178                        if let syn::Expr::Lit(lit) = &nv.value {
179                            if let syn::Lit::Str(s) = &lit.lit {
180                                config.replace = Some(s.value());
181                            }
182                        }
183                    } else if nv.path.is_ident("rm") {
184                        if let syn::Expr::Lit(lit) = &nv.value {
185                            if let syn::Lit::Str(s) = &lit.lit {
186                                config.rm.push(s.value());
187                            }
188                        }
189                    } else if nv.path.is_ident("add") {
190                        if let syn::Expr::Lit(lit) = &nv.value {
191                            if let syn::Lit::Str(s) = &lit.lit {
192                                config.add.push(s.value());
193                            }
194                        }
195                    }
196                }
197                syn::Meta::List(list) => {
198                    if list.path.is_ident("rm") {
199                        let nested = list
200                            .parse_args_with(Punctuated::<LitStr, Token![,]>::parse_terminated)?;
201                        for lit in nested {
202                            config.rm.push(lit.value());
203                        }
204                    } else if list.path.is_ident("begin") {
205                        let nested = list
206                            .parse_args_with(Punctuated::<LitStr, Token![,]>::parse_terminated)?;
207                        for lit in nested {
208                            config.begin.push(lit.value());
209                        }
210                    } else if list.path.is_ident("begin_with") {
211                        let nested = list
212                            .parse_args_with(Punctuated::<LitStr, Token![,]>::parse_terminated)?;
213                        for lit in nested {
214                            config.begin_with.push(lit.value());
215                        }
216                    } else if list.path.is_ident("end") {
217                        let nested = list
218                            .parse_args_with(Punctuated::<LitStr, Token![,]>::parse_terminated)?;
219                        for lit in nested {
220                            config.end.push(lit.value());
221                        }
222                    } else if list.path.is_ident("add") {
223                        let nested = list
224                            .parse_args_with(Punctuated::<LitStr, Token![,]>::parse_terminated)?;
225                        for lit in nested {
226                            config.add.push(lit.value());
227                        }
228                    }
229                }
230                syn::Meta::Path(path) => {
231                    return Err(Error::new_spanned(path, "expected key-value or list"));
232                }
233            }
234        }
235
236        if config.feat.is_empty() {
237            return Err(Error::new(Span::call_site(), "feat attribute is required"));
238        }
239
240        Ok(HijektArgs { config })
241    }
242}
243
244#[proc_macro_attribute]
245pub fn hijekt(args: TokenStream, input: TokenStream) -> TokenStream {
246    let args = parse_macro_input!(args as HijektArgs);
247    let config = args.config;
248
249    if config.is_simple_feature_only() {
250        let feat_flag = config.feature_flag();
251        let item = parse_macro_input!(input as Item);
252        return TokenStream::from(quote! {
253            #[cfg(feature = #feat_flag)]
254            #item
255        });
256    }
257
258    if let Ok(item_fn) = syn::parse::<ItemFn>(input.clone()) {
259        return handle_function(config, item_fn);
260    }
261
262    if let Ok(item_struct) = syn::parse::<ItemStruct>(input.clone()) {
263        return handle_struct(config, item_struct);
264    }
265
266    let feat_flag = config.feature_flag();
267    let item = parse_macro_input!(input as Item);
268    TokenStream::from(quote! {
269        #[cfg(feature = #feat_flag)]
270        #item
271    })
272}
273
274fn handle_function(config: HijektConfig, func: ItemFn) -> TokenStream {
275    let feat_flag = config.feature_flag();
276    let original = func.clone();
277
278    if let Some(replace_fn) = &config.replace {
279        let replace_ident: Ident = syn::parse_str(replace_fn).unwrap();
280        let vis = &func.vis;
281        let sig = &func.sig;
282        let attrs = &func.attrs;
283
284        let args: Vec<_> = sig
285            .inputs
286            .iter()
287            .filter_map(|arg| match arg {
288                syn::FnArg::Typed(pat_type) => match &*pat_type.pat {
289                    syn::Pat::Ident(ident) => Some(quote! { #ident }),
290                    _ => None,
291                },
292                syn::FnArg::Receiver(_) => Some(quote! { self }),
293            })
294            .collect();
295
296        let begin_calls: Vec<Stmt> = config
297            .begin
298            .iter()
299            .map(|begin_fn| {
300                let begin_ident: Ident = syn::parse_str(begin_fn).unwrap();
301                parse_quote! { #begin_ident(); }
302            })
303            .collect();
304
305        let begin_with_calls: Vec<Stmt> = config
306            .begin_with
307            .iter()
308            .map(|begin_fn| {
309                let begin_ident: Ident = syn::parse_str(begin_fn).unwrap();
310                let ref_args: Vec<_> = sig
311                    .inputs
312                    .iter()
313                    .filter_map(|arg| match arg {
314                        syn::FnArg::Typed(pat_type) => match &*pat_type.pat {
315                            syn::Pat::Ident(ident) => Some(quote! { &#ident }),
316                            _ => None,
317                        },
318                        syn::FnArg::Receiver(_) => Some(quote! { &self }),
319                    })
320                    .collect();
321                parse_quote! { #begin_ident(#(#ref_args),*); }
322            })
323            .collect();
324
325        let end_calls: Vec<Stmt> = config
326            .end
327            .iter()
328            .map(|end_fn| {
329                let end_ident: Ident = syn::parse_str(end_fn).unwrap();
330                parse_quote! { #end_ident(); }
331            })
332            .collect();
333
334        let has_return = !matches!(sig.output, syn::ReturnType::Default);
335        let swap_body = if has_return {
336            if !end_calls.is_empty() {
337                quote! {
338                    #(#begin_calls)*
339                    #(#begin_with_calls)*
340                    let __result = #replace_ident(#(#args),*);
341                    #(#end_calls)*
342                    __result
343                }
344            } else {
345                quote! {
346                    #(#begin_calls)*
347                    #(#begin_with_calls)*
348                    #replace_ident(#(#args),*)
349                }
350            }
351        } else {
352            quote! {
353                #(#begin_calls)*
354                #(#begin_with_calls)*
355                #replace_ident(#(#args),*);
356                #(#end_calls)*
357            }
358        };
359
360        return TokenStream::from(quote! {
361            #(#attrs)*
362            #[cfg(feature = #feat_flag)]
363            #vis #sig {
364                #swap_body
365            }
366
367            #[cfg(not(feature = #feat_flag))]
368            #original
369        });
370    }
371
372    let mut modified = func.clone();
373
374    for rm_target in &config.rm {
375        let mut remover = ItemRemover {
376            targets: vec![rm_target.clone()],
377        };
378        remover.visit_block_mut(&mut modified.block);
379    }
380
381    for begin_fn in config.begin.iter().rev() {
382        let begin_ident: Ident = syn::parse_str(begin_fn).unwrap();
383        modified
384            .block
385            .stmts
386            .insert(0, parse_quote! { #begin_ident(); });
387    }
388
389    for begin_fn in config.begin_with.iter().rev() {
390        let begin_ident: Ident = syn::parse_str(begin_fn).unwrap();
391        let ref_args: Vec<_> = func
392            .sig
393            .inputs
394            .iter()
395            .filter_map(|arg| match arg {
396                syn::FnArg::Typed(pat_type) => match &*pat_type.pat {
397                    syn::Pat::Ident(ident) => Some(quote! { &#ident }),
398                    _ => None,
399                },
400                syn::FnArg::Receiver(_) => Some(quote! { &self }),
401            })
402            .collect();
403        modified
404            .block
405            .stmts
406            .insert(0, parse_quote! { #begin_ident(#(#ref_args),*); });
407    }
408
409    if !config.end.is_empty() {
410        inject_at_end(&mut modified.block, &config.end);
411    }
412
413    TokenStream::from(quote! {
414        #[cfg(feature = #feat_flag)]
415        #modified
416
417        #[cfg(not(feature = #feat_flag))]
418        #original
419    })
420}
421
422fn handle_struct(config: HijektConfig, item: ItemStruct) -> TokenStream {
423    let feat_flag = config.feature_flag();
424    let original = item.clone();
425    let mut modified = item.clone();
426
427    for rm_field in &config.rm {
428        if let Fields::Named(ref mut fields) = modified.fields {
429            fields.named = fields
430                .named
431                .iter()
432                .filter(|f| {
433                    f.ident
434                        .as_ref()
435                        .map(|i| i.to_string() != *rm_field)
436                        .unwrap_or(true)
437                })
438                .cloned()
439                .collect();
440        }
441    }
442
443    for add_spec in &config.add {
444        if let Fields::Named(ref mut fields) = modified.fields {
445            if add_spec.contains(':') {
446                // Parse "field_name: Type"
447                let parts: Vec<&str> = add_spec.splitn(2, ':').collect();
448                if parts.len() == 2 {
449                    let field_name = parts[0].trim();
450                    let type_str = parts[1].trim();
451
452                    if let Ok(field_ident) = syn::parse_str::<Ident>(field_name) {
453                        if let Ok(field_type) = syn::parse_str::<syn::Type>(type_str) {
454                            fields.named.push(parse_quote! {
455                                pub #field_ident: #field_type
456                            });
457                        }
458                    }
459                }
460            } else {
461                // Only type specified, generate field name
462                let sanitized_name = add_spec
463                    .to_lowercase()
464                    .replace("::", "_")
465                    .replace('<', "_")
466                    .replace('>', "")
467                    .replace(' ', "")
468                    .replace(',', "_");
469
470                let field_name = format_ident!("hijekt_{}", sanitized_name);
471
472                if let Ok(field_type) = syn::parse_str::<syn::Type>(add_spec) {
473                    fields.named.push(parse_quote! {
474                        pub #field_name: #field_type
475                    });
476                }
477            }
478        }
479    }
480
481    TokenStream::from(quote! {
482        #[cfg(feature = #feat_flag)]
483        #modified
484
485        #[cfg(not(feature = #feat_flag))]
486        #original
487    })
488}
489
490fn inject_at_end(block: &mut Block, end_fns: &[String]) {
491    let has_implicit_return = block
492        .stmts
493        .last()
494        .map_or(false, |stmt| matches!(stmt, Stmt::Expr(_, None)));
495
496    let end_calls: Vec<Stmt> = end_fns
497        .iter()
498        .map(|end_fn| {
499            let end_ident: Ident = syn::parse_str(end_fn).unwrap();
500            parse_quote! { #end_ident(); }
501        })
502        .collect();
503
504    if has_implicit_return {
505        if let Some(Stmt::Expr(expr, None)) = block.stmts.pop() {
506            block.stmts.push(parse_quote! {
507                let __hijekt_result = #expr;
508            });
509
510            block.stmts.extend(end_calls);
511
512            block
513                .stmts
514                .push(Stmt::Expr(parse_quote! { __hijekt_result }, None));
515        }
516    } else {
517        block.stmts.extend(end_calls);
518    }
519}
520
521struct ItemRemover {
522    targets: Vec<String>,
523}
524
525impl VisitMut for ItemRemover {
526    fn visit_block_mut(&mut self, block: &mut Block) {
527        block.stmts.retain(|stmt| match stmt {
528            Stmt::Item(Item::Fn(func)) => !self.targets.contains(&func.sig.ident.to_string()),
529            Stmt::Local(local) => {
530                if let Pat::Ident(ident) = &local.pat {
531                    !self.targets.contains(&ident.ident.to_string())
532                } else {
533                    true
534                }
535            }
536            _ => true,
537        });
538
539        syn::visit_mut::visit_block_mut(self, block);
540    }
541}