1use proc_macro::TokenStream;
2use proc_macro2::Span;
3use quote::{format_ident, quote};
4use syn::{
5 parenthesized,
6 parse::{Parse, ParseStream},
7 parse_macro_input, parse_quote,
8 punctuated::Punctuated,
9 visit_mut::VisitMut,
10 Block, Error, Fields, Ident, Item, ItemFn, ItemStruct, Lit, LitStr, Pat, Result, Stmt, Token,
11};
12
/// Options collected from the arguments of a `#[hijekt(...)]` attribute.
#[derive(Debug, Clone, Default)]
struct HijektConfig {
    /// Cargo feature name placed in the generated `#[cfg(feature = "...")]` attributes.
    feat: String,
    /// Names of functions called with no arguments at the start of the function.
    begin: Vec<String>,
    /// Names of functions called at the start with a reference to each parameter.
    begin_with: Vec<String>,
    /// Names of functions called with no arguments at the end of the function.
    end: Vec<String>,
    /// Names to remove: nested `fn` items and `let` bindings inside a function,
    /// or named fields of a struct.
    rm: Vec<String>,
    /// Name of the function called in place of the original body (set by the `swap` key).
    replace: Option<String>,
    /// Field specs appended to a struct, either `name: Type` or a bare type.
    add: Vec<String>,
}
23
24impl HijektConfig {
25 fn feature_flag(&self) -> String {
26 self.feat.clone()
27 }
28
29 fn is_simple_feature_only(&self) -> bool {
30 self.begin.is_empty()
31 && self.begin_with.is_empty()
32 && self.end.is_empty()
33 && self.rm.is_empty()
34 && self.replace.is_none()
35 && self.add.is_empty()
36 }
37
38 fn parse_meta_item(&mut self, meta: syn::meta::ParseNestedMeta) -> Result<()> {
39 if meta.path.is_ident("feat") {
40 let value = meta.value()?;
41 let lit: LitStr = value.parse()?;
42 self.feat = lit.value();
43 return Ok(());
44 }
45
46 if meta.path.is_ident("begin") {
47 if meta.input.peek(Token![=]) {
48 let value = meta.value()?;
49 let lit: LitStr = value.parse()?;
50 self.begin.push(lit.value());
51 } else if meta.input.peek(syn::token::Paren) {
52 meta.parse_nested_meta(|nested| {
53 let lit: LitStr = nested.input.parse()?;
54 self.begin.push(lit.value());
55 Ok(())
56 })?;
57 }
58 return Ok(());
59 }
60
61 if meta.path.is_ident("begin_with") {
62 if meta.input.peek(Token![=]) {
63 let value = meta.value()?;
64 let lit: LitStr = value.parse()?;
65 self.begin_with.push(lit.value());
66 } else if meta.input.peek(syn::token::Paren) {
67 meta.parse_nested_meta(|nested| {
68 let lit: LitStr = nested.input.parse()?;
69 self.begin_with.push(lit.value());
70 Ok(())
71 })?;
72 }
73 return Ok(());
74 }
75
76 if meta.path.is_ident("end") {
77 if meta.input.peek(Token![=]) {
78 let value = meta.value()?;
79 let lit: LitStr = value.parse()?;
80 self.end.push(lit.value());
81 } else if meta.input.peek(syn::token::Paren) {
82 meta.parse_nested_meta(|nested| {
83 let lit: LitStr = nested.input.parse()?;
84 self.end.push(lit.value());
85 Ok(())
86 })?;
87 }
88 return Ok(());
89 }
90
91 if meta.path.is_ident("rm") {
92 if meta.input.peek(Token![=]) {
93 let value = meta.value()?;
94 let lit: LitStr = value.parse()?;
95 self.rm.push(lit.value());
96 } else if meta.input.peek(syn::token::Paren) {
97 let content;
98 parenthesized!(content in meta.input);
99 let items: Punctuated<Lit, Token![,]> =
100 content.parse_terminated(Lit::parse, Token![,])?;
101 for item in items {
102 if let Lit::Str(litstr) = item {
103 self.rm.push(litstr.value());
104 }
105 }
106 }
107 return Ok(());
108 }
109
110 if meta.path.is_ident("swap") {
111 let value = meta.value()?;
112 let lit: LitStr = value.parse()?;
113 self.replace = Some(lit.value());
114 return Ok(());
115 }
116
117 if meta.path.is_ident("add") {
118 if meta.input.peek(Token![=]) {
119 let value = meta.value()?;
120 let lit: LitStr = value.parse()?;
121 self.add.push(lit.value());
122 } else if meta.input.peek(syn::token::Paren) {
123 let content;
124 parenthesized!(content in meta.input);
125 let items: Punctuated<Lit, Token![,]> =
126 content.parse_terminated(Lit::parse, Token![,])?;
127 for item in items {
128 if let Lit::Str(litstr) = item {
129 self.add.push(litstr.value());
130 }
131 }
132 }
133 return Ok(());
134 }
135
136 Err(meta.error("unrecognized hijekt attribute"))
137 }
138}
139
140struct HijektArgs {
141 config: HijektConfig,
142}
143
144impl Parse for HijektArgs {
145 fn parse(input: ParseStream) -> Result<Self> {
146 let mut config = HijektConfig::default();
147
148 let metas = Punctuated::<syn::Meta, Token![,]>::parse_terminated(input)?;
149
150 for meta in metas {
151 match meta {
152 syn::Meta::NameValue(nv) => {
153 if nv.path.is_ident("feat") {
154 if let syn::Expr::Lit(lit) = &nv.value {
155 if let syn::Lit::Str(s) = &lit.lit {
156 config.feat = s.value();
157 }
158 }
159 } else if nv.path.is_ident("begin") {
160 if let syn::Expr::Lit(lit) = &nv.value {
161 if let syn::Lit::Str(s) = &lit.lit {
162 config.begin.push(s.value());
163 }
164 }
165 } else if nv.path.is_ident("begin_with") {
166 if let syn::Expr::Lit(lit) = &nv.value {
167 if let syn::Lit::Str(s) = &lit.lit {
168 config.begin_with.push(s.value());
169 }
170 }
171 } else if nv.path.is_ident("end") {
172 if let syn::Expr::Lit(lit) = &nv.value {
173 if let syn::Lit::Str(s) = &lit.lit {
174 config.end.push(s.value());
175 }
176 }
177 } else if nv.path.is_ident("swap") {
178 if let syn::Expr::Lit(lit) = &nv.value {
179 if let syn::Lit::Str(s) = &lit.lit {
180 config.replace = Some(s.value());
181 }
182 }
183 } else if nv.path.is_ident("rm") {
184 if let syn::Expr::Lit(lit) = &nv.value {
185 if let syn::Lit::Str(s) = &lit.lit {
186 config.rm.push(s.value());
187 }
188 }
189 } else if nv.path.is_ident("add") {
190 if let syn::Expr::Lit(lit) = &nv.value {
191 if let syn::Lit::Str(s) = &lit.lit {
192 config.add.push(s.value());
193 }
194 }
195 }
196 }
197 syn::Meta::List(list) => {
198 if list.path.is_ident("rm") {
199 let nested = list
200 .parse_args_with(Punctuated::<LitStr, Token![,]>::parse_terminated)?;
201 for lit in nested {
202 config.rm.push(lit.value());
203 }
204 } else if list.path.is_ident("begin") {
205 let nested = list
206 .parse_args_with(Punctuated::<LitStr, Token![,]>::parse_terminated)?;
207 for lit in nested {
208 config.begin.push(lit.value());
209 }
210 } else if list.path.is_ident("begin_with") {
211 let nested = list
212 .parse_args_with(Punctuated::<LitStr, Token![,]>::parse_terminated)?;
213 for lit in nested {
214 config.begin_with.push(lit.value());
215 }
216 } else if list.path.is_ident("end") {
217 let nested = list
218 .parse_args_with(Punctuated::<LitStr, Token![,]>::parse_terminated)?;
219 for lit in nested {
220 config.end.push(lit.value());
221 }
222 } else if list.path.is_ident("add") {
223 let nested = list
224 .parse_args_with(Punctuated::<LitStr, Token![,]>::parse_terminated)?;
225 for lit in nested {
226 config.add.push(lit.value());
227 }
228 }
229 }
230 syn::Meta::Path(path) => {
231 return Err(Error::new_spanned(path, "expected key-value or list"));
232 }
233 }
234 }
235
236 if config.feat.is_empty() {
237 return Err(Error::new(Span::call_site(), "feat attribute is required"));
238 }
239
240 Ok(HijektArgs { config })
241 }
242}
243
244#[proc_macro_attribute]
245pub fn hijekt(args: TokenStream, input: TokenStream) -> TokenStream {
246 let args = parse_macro_input!(args as HijektArgs);
247 let config = args.config;
248
249 if config.is_simple_feature_only() {
250 let feat_flag = config.feature_flag();
251 let item = parse_macro_input!(input as Item);
252 return TokenStream::from(quote! {
253 #[cfg(feature = #feat_flag)]
254 #item
255 });
256 }
257
258 if let Ok(item_fn) = syn::parse::<ItemFn>(input.clone()) {
259 return handle_function(config, item_fn);
260 }
261
262 if let Ok(item_struct) = syn::parse::<ItemStruct>(input.clone()) {
263 return handle_struct(config, item_struct);
264 }
265
266 let feat_flag = config.feature_flag();
267 let item = parse_macro_input!(input as Item);
268 TokenStream::from(quote! {
269 #[cfg(feature = #feat_flag)]
270 #item
271 })
272}
273
274fn handle_function(config: HijektConfig, func: ItemFn) -> TokenStream {
275 let feat_flag = config.feature_flag();
276 let original = func.clone();
277
278 if let Some(replace_fn) = &config.replace {
279 let replace_ident: Ident = syn::parse_str(replace_fn).unwrap();
280 let vis = &func.vis;
281 let sig = &func.sig;
282 let attrs = &func.attrs;
283
284 let args: Vec<_> = sig
285 .inputs
286 .iter()
287 .filter_map(|arg| match arg {
288 syn::FnArg::Typed(pat_type) => match &*pat_type.pat {
289 syn::Pat::Ident(ident) => Some(quote! { #ident }),
290 _ => None,
291 },
292 syn::FnArg::Receiver(_) => Some(quote! { self }),
293 })
294 .collect();
295
296 let begin_calls: Vec<Stmt> = config
297 .begin
298 .iter()
299 .map(|begin_fn| {
300 let begin_ident: Ident = syn::parse_str(begin_fn).unwrap();
301 parse_quote! { #begin_ident(); }
302 })
303 .collect();
304
305 let begin_with_calls: Vec<Stmt> = config
306 .begin_with
307 .iter()
308 .map(|begin_fn| {
309 let begin_ident: Ident = syn::parse_str(begin_fn).unwrap();
310 let ref_args: Vec<_> = sig
311 .inputs
312 .iter()
313 .filter_map(|arg| match arg {
314 syn::FnArg::Typed(pat_type) => match &*pat_type.pat {
315 syn::Pat::Ident(ident) => Some(quote! { &#ident }),
316 _ => None,
317 },
318 syn::FnArg::Receiver(_) => Some(quote! { &self }),
319 })
320 .collect();
321 parse_quote! { #begin_ident(#(#ref_args),*); }
322 })
323 .collect();
324
325 let end_calls: Vec<Stmt> = config
326 .end
327 .iter()
328 .map(|end_fn| {
329 let end_ident: Ident = syn::parse_str(end_fn).unwrap();
330 parse_quote! { #end_ident(); }
331 })
332 .collect();
333
334 let has_return = !matches!(sig.output, syn::ReturnType::Default);
335 let swap_body = if has_return {
336 if !end_calls.is_empty() {
337 quote! {
338 #(#begin_calls)*
339 #(#begin_with_calls)*
340 let __result = #replace_ident(#(#args),*);
341 #(#end_calls)*
342 __result
343 }
344 } else {
345 quote! {
346 #(#begin_calls)*
347 #(#begin_with_calls)*
348 #replace_ident(#(#args),*)
349 }
350 }
351 } else {
352 quote! {
353 #(#begin_calls)*
354 #(#begin_with_calls)*
355 #replace_ident(#(#args),*);
356 #(#end_calls)*
357 }
358 };
359
360 return TokenStream::from(quote! {
361 #(#attrs)*
362 #[cfg(feature = #feat_flag)]
363 #vis #sig {
364 #swap_body
365 }
366
367 #[cfg(not(feature = #feat_flag))]
368 #original
369 });
370 }
371
372 let mut modified = func.clone();
373
374 for rm_target in &config.rm {
375 let mut remover = ItemRemover {
376 targets: vec![rm_target.clone()],
377 };
378 remover.visit_block_mut(&mut modified.block);
379 }
380
381 for begin_fn in config.begin.iter().rev() {
382 let begin_ident: Ident = syn::parse_str(begin_fn).unwrap();
383 modified
384 .block
385 .stmts
386 .insert(0, parse_quote! { #begin_ident(); });
387 }
388
389 for begin_fn in config.begin_with.iter().rev() {
390 let begin_ident: Ident = syn::parse_str(begin_fn).unwrap();
391 let ref_args: Vec<_> = func
392 .sig
393 .inputs
394 .iter()
395 .filter_map(|arg| match arg {
396 syn::FnArg::Typed(pat_type) => match &*pat_type.pat {
397 syn::Pat::Ident(ident) => Some(quote! { &#ident }),
398 _ => None,
399 },
400 syn::FnArg::Receiver(_) => Some(quote! { &self }),
401 })
402 .collect();
403 modified
404 .block
405 .stmts
406 .insert(0, parse_quote! { #begin_ident(#(#ref_args),*); });
407 }
408
409 if !config.end.is_empty() {
410 inject_at_end(&mut modified.block, &config.end);
411 }
412
413 TokenStream::from(quote! {
414 #[cfg(feature = #feat_flag)]
415 #modified
416
417 #[cfg(not(feature = #feat_flag))]
418 #original
419 })
420}
421
422fn handle_struct(config: HijektConfig, item: ItemStruct) -> TokenStream {
423 let feat_flag = config.feature_flag();
424 let original = item.clone();
425 let mut modified = item.clone();
426
427 for rm_field in &config.rm {
428 if let Fields::Named(ref mut fields) = modified.fields {
429 fields.named = fields
430 .named
431 .iter()
432 .filter(|f| {
433 f.ident
434 .as_ref()
435 .map(|i| i.to_string() != *rm_field)
436 .unwrap_or(true)
437 })
438 .cloned()
439 .collect();
440 }
441 }
442
443 for add_spec in &config.add {
444 if let Fields::Named(ref mut fields) = modified.fields {
445 if add_spec.contains(':') {
446 let parts: Vec<&str> = add_spec.splitn(2, ':').collect();
448 if parts.len() == 2 {
449 let field_name = parts[0].trim();
450 let type_str = parts[1].trim();
451
452 if let Ok(field_ident) = syn::parse_str::<Ident>(field_name) {
453 if let Ok(field_type) = syn::parse_str::<syn::Type>(type_str) {
454 fields.named.push(parse_quote! {
455 pub #field_ident: #field_type
456 });
457 }
458 }
459 }
460 } else {
461 let sanitized_name = add_spec
463 .to_lowercase()
464 .replace("::", "_")
465 .replace('<', "_")
466 .replace('>', "")
467 .replace(' ', "")
468 .replace(',', "_");
469
470 let field_name = format_ident!("hijekt_{}", sanitized_name);
471
472 if let Ok(field_type) = syn::parse_str::<syn::Type>(add_spec) {
473 fields.named.push(parse_quote! {
474 pub #field_name: #field_type
475 });
476 }
477 }
478 }
479 }
480
481 TokenStream::from(quote! {
482 #[cfg(feature = #feat_flag)]
483 #modified
484
485 #[cfg(not(feature = #feat_flag))]
486 #original
487 })
488}
489
490fn inject_at_end(block: &mut Block, end_fns: &[String]) {
491 let has_implicit_return = block
492 .stmts
493 .last()
494 .map_or(false, |stmt| matches!(stmt, Stmt::Expr(_, None)));
495
496 let end_calls: Vec<Stmt> = end_fns
497 .iter()
498 .map(|end_fn| {
499 let end_ident: Ident = syn::parse_str(end_fn).unwrap();
500 parse_quote! { #end_ident(); }
501 })
502 .collect();
503
504 if has_implicit_return {
505 if let Some(Stmt::Expr(expr, None)) = block.stmts.pop() {
506 block.stmts.push(parse_quote! {
507 let __hijekt_result = #expr;
508 });
509
510 block.stmts.extend(end_calls);
511
512 block
513 .stmts
514 .push(Stmt::Expr(parse_quote! { __hijekt_result }, None));
515 }
516 } else {
517 block.stmts.extend(end_calls);
518 }
519}
520
/// Visitor that deletes statements from blocks by name.
struct ItemRemover {
    /// Names of the nested functions and `let` bindings to delete.
    targets: Vec<String>,
}
524
525impl VisitMut for ItemRemover {
526 fn visit_block_mut(&mut self, block: &mut Block) {
527 block.stmts.retain(|stmt| match stmt {
528 Stmt::Item(Item::Fn(func)) => !self.targets.contains(&func.sig.ident.to_string()),
529 Stmt::Local(local) => {
530 if let Pat::Ident(ident) = &local.pat {
531 !self.targets.contains(&ident.ident.to_string())
532 } else {
533 true
534 }
535 }
536 _ => true,
537 });
538
539 syn::visit_mut::visit_block_mut(self, block);
540 }
541}