1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
use darling::FromDeriveInput;
use proc_macro::TokenStream;
use quote::{quote, ToTokens};
use syn::__private::{Span, TokenStream2};
use syn::{
    parse_macro_input, Data, DeriveInput, GenericArgument, Ident, Path, PathArguments, PathSegment,
    Type,
};

/// Options parsed by darling from `#[request(...)]` attributes on the type that derives
/// `Request`.
#[derive(FromDeriveInput)]
#[darling(attributes(request))]
struct DeriveRequestOpts {
    /// Names listed in `#[request(executor(...))]`. A named struct field whose name matches
    /// one of these paths gets statements generated by `derive_request_check`, which forward
    /// the executor into it. `None` when the attribute is absent.
    executor: Option<darling::util::PathList>,
}

#[proc_macro_derive(Request, attributes(request))]
pub fn derive_request(input: TokenStream) -> TokenStream {
    let derive_input = parse_macro_input!(input as DeriveInput);
    let DeriveInput {
        ident,
        generics,
        data,
        ..
    } = &derive_input;
    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();

    let request_opts: DeriveRequestOpts =
        DeriveRequestOpts::from_derive_input(&derive_input).unwrap();
    let executor_fields = request_opts.executor.unwrap_or_default();

    let mut impl_executor = vec![];

    if let Data::Struct(data_struct) = data {
        for field in data_struct.fields.iter() {
            if let Some(ident) = &field.ident {
                for path in executor_fields.iter() {
                    if path.is_ident(ident) {
                        let ty = if let Type::Path(ty) = field.clone().ty {
                            ty
                        } else {
                            unreachable!()
                        };
                        impl_executor.push(derive_request_check(quote! { self.#ident }, &ty.path));
                        continue;
                    }
                }

                if let Type::Path(ty) = field.clone().ty {
                    let segment = ty.path.segments.last().unwrap();
                    if segment.ident == "Arc" && segment_types(segment)[0].is_ident("Executor") {
                        impl_executor.push(quote! {
                            self.#ident = executor.clone();
                        })
                    }
                }
            }
        }
    };

    let expanded = quote! {
        #[async_trait::async_trait]
        impl #impl_generics crate::Request for #ident #ty_generics # where_clause {
            async fn __set_executor(&mut self, executor: std::sync::Arc<crate::Executor>) {
                #(#impl_executor)*
            }
        }
    };
    expanded.into()
}

fn derive_request_check(set_path: TokenStream2, path: &Path) -> TokenStream2 {
    let segment = path.segments.last().unwrap();

    let _deep_set_path = set_path.to_string();
    let deep_set_path = _deep_set_path.split('.').last().unwrap();

    if segment.ident == "Option" {
        let options_set_path = Ident::new(
            format!("{}{}", "option_", deep_set_path).as_str(),
            Span::call_site(),
        );
        let ty = &segment_types(segment)[0];
        let check = derive_request_check(options_set_path.to_token_stream(), ty);
        quote! {
            if let Some(#options_set_path) = &mut #set_path {
                #check
            }
        }
    } else if segment.ident == "Vec" {
        let vec_set_path = Ident::new(
            format!("{}{}", "vec_", deep_set_path).as_str(),
            Span::call_site(),
        );
        let ty = &segment_types(segment)[0];
        let check = derive_request_check(vec_set_path.to_token_stream(), ty);
        quote! {
            for #vec_set_path in #set_path.iter_mut() {
                #check
            }
        }
    } else if segment.ident == "HashMap" {
        let hash_map_set_path = Ident::new(
            format!("{}{}", "hash_map_", deep_set_path).as_str(),
            Span::call_site(),
        );
        let ty = &segment_types(segment)[1];
        let check = derive_request_check(hash_map_set_path.to_token_stream(), ty);
        quote! {
            for #hash_map_set_path in #set_path.values_mut() {
                #check
            }
        }
    } else {
        quote! {
            #set_path.__set_executor(executor.clone()).await;
        }
    }
}

/// Returns the generic type arguments of `segment` as paths, in declaration order.
/// For example, `HashMap<K, V>` yields `[K, V]`.
///
/// # Panics
/// Panics with a descriptive message if `segment` has no angle-bracketed arguments, or if any
/// argument is not a plain path type (lifetimes, references, trait objects, tuples, slices,
/// ...). Users of the derive can write these shapes, so they are reported with context rather
/// than treated as unreachable.
fn segment_types(segment: &PathSegment) -> Vec<Path> {
    let args = match &segment.arguments {
        PathArguments::AngleBracketed(args) => &args.args,
        _ => panic!(
            "expected angle-bracketed generic arguments on `{}`",
            segment.ident
        ),
    };
    args.iter()
        .map(|arg| match arg {
            GenericArgument::Type(Type::Path(ty)) => ty.path.clone(),
            other => panic!(
                "unsupported generic argument `{}` in `{}<...>`: only plain path types are supported",
                other.to_token_stream(),
                segment.ident
            ),
        })
        .collect()
}