injectify_impl/
lib.rs

1use proc_macro::TokenStream;
2use quote::{format_ident, quote, ToTokens};
3use syn::parse::Nothing;
4use syn::spanned::Spanned;
5use syn::{parse_macro_input, Data, DataStruct, DeriveInput, Error, Expr};
6
7#[proc_macro_attribute]
8#[allow(non_snake_case)]
9pub fn Injectify(args: TokenStream, input: TokenStream) -> TokenStream {
10    let _ = parse_macro_input!(args as Nothing);
11    let input = parse_macro_input!(input as DeriveInput);
12
13    match input.clone().data {
14        Data::Struct(data) => injectify_struct_impl(data, input),
15        Data::Enum(_) => todo!(),
16        _ => Error::new(input.into_token_stream().span(), "Must be a `struct`")
17            .into_compile_error()
18            .to_token_stream()
19            .into(),
20    }
21}
22
23fn injectify_struct_impl(struct_data: DataStruct, derive_input: DeriveInput) -> TokenStream {
24    // Original data
25    let vis = derive_input.vis;
26    let ident = derive_input.ident;
27    let attrs = derive_input.attrs;
28    let generics_params = derive_input.generics.params;
29    let generics_where = derive_input.generics.where_clause;
30
31    // New generics to insert into struct
32    let mut generated_generics = Vec::new();
33
34    let fields: Vec<_> = struct_data
35        .fields
36        .iter()
37        .map(|field| {
38            let vis = &field.vis;
39            let attrs = &field.attrs;
40            let ident = &field.ident;
41            let field_type = field.ty.to_token_stream().to_string();
42
43            // Field to modify
44            if field_type.starts_with("impl ") {
45                let trait_str = field_type
46                    .strip_prefix("impl ")
47                    .expect("Should have prefix");
48                let impl_trait: Expr = syn::parse_str(trait_str).expect("Should be an expression");
49                let generic = format_ident!("_IJ_{}", generated_generics.len());
50
51                generated_generics.push(quote!(
52                    #generic: #impl_trait,
53                ));
54
55                quote!(
56                    #(#attrs)*
57                    #vis #ident: #generic,
58                )
59            }
60            // Keep field as is
61            else {
62                quote!(#field,)
63            }
64        })
65        .collect();
66
67    quote!(
68        #(#attrs)*
69        #vis struct #ident <#(#generated_generics)* #generics_params> #generics_where {
70            #(#fields)*
71        }
72    )
73    .into()
74}