Skip to main content

crunchyroll_rs_internal/
lib.rs

1mod util;
2
3use crate::util::IdentList;
4use darling::FromDeriveInput;
5use proc_macro::TokenStream;
6use quote::{ToTokens, quote};
7use syn::__private::{Span, TokenStream2};
8use syn::spanned::Spanned;
9use syn::{
10    Data, DeriveInput, GenericArgument, Ident, Path, PathArguments, PathSegment, Type,
11    parse_macro_input,
12};
13
14#[derive(FromDeriveInput)]
15#[darling(attributes(request))]
16struct DeriveRequestOpts {
17    executor: Option<IdentList>,
18}
19
20#[proc_macro_derive(Request, attributes(request))]
21pub fn derive_request(input: TokenStream) -> TokenStream {
22    let derive_input = parse_macro_input!(input as DeriveInput);
23    if !matches!(derive_input.data, Data::Struct(_)) && !matches!(derive_input.data, Data::Enum(_))
24    {
25        return syn::Error::new(derive_input.span(), "Only allowed on structs and enums")
26            .to_compile_error()
27            .into();
28    }
29
30    let DeriveInput {
31        ident,
32        generics,
33        data,
34        ..
35    } = &derive_input;
36    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
37
38    let request_opts = DeriveRequestOpts::from_derive_input(&derive_input).unwrap();
39    let mut executor_fields = request_opts.executor.as_ref().map(|f| f.to_vec());
40
41    let mut impl_executor = vec![];
42
43    match data {
44        Data::Struct(data_struct) => {
45            for field in data_struct.fields.iter() {
46                let Some(ident) = &field.ident else {
47                    continue;
48                };
49
50                for (i, executor_ident) in executor_fields.iter().flatten().enumerate() {
51                    if ident != executor_ident {
52                        continue;
53                    }
54
55                    let Type::Path(ty) = field.ty.clone() else {
56                        unreachable!()
57                    };
58                    impl_executor.push(derive_request_check(quote! { self.#ident }, &ty.path));
59
60                    executor_fields.as_mut().unwrap().remove(i);
61                    break;
62                }
63
64                if let Type::Path(ty) = field.clone().ty {
65                    let segment = ty.path.segments.last().unwrap();
66                    if segment.ident == "Arc" && segment_types(segment)[0].is_ident("Executor") {
67                        impl_executor.push(quote! {
68                            self.#ident = executor.clone();
69                        })
70                    }
71                }
72            }
73        }
74        Data::Enum(_) if request_opts.executor.is_some() => {
75            return syn::Error::new(
76                request_opts.executor.unwrap().span(),
77                "Executor fields aren't allowed on enums",
78            )
79            .to_compile_error()
80            .into();
81        }
82        _ => (),
83    }
84
85    if let Some(first_field) = executor_fields.iter().flatten().next() {
86        return syn::Error::new(
87            first_field.span(),
88            format!("Executor field not found: {first_field}"),
89        )
90        .to_compile_error()
91        .into();
92    }
93
94    let expanded = quote! {
95        impl #impl_generics crate::Request for #ident #ty_generics # where_clause {
96            async fn __set_executor(&mut self, executor: std::sync::Arc<crate::Executor>) {
97                #(#impl_executor)*
98            }
99        }
100    };
101    expanded.into()
102}
103
104fn derive_request_check(set_path: TokenStream2, path: &Path) -> TokenStream2 {
105    let segment = path.segments.last().unwrap();
106
107    let _deep_set_path = set_path.to_string();
108    let deep_set_path = _deep_set_path.split('.').next_back().unwrap();
109
110    if segment.ident == "Option" {
111        let options_set_path = Ident::new(
112            format!("{}{}", "option_", deep_set_path).as_str(),
113            Span::call_site(),
114        );
115        let ty = &segment_types(segment)[0];
116        let check = derive_request_check(options_set_path.to_token_stream(), ty);
117        quote! {
118            if let Some(#options_set_path) = &mut #set_path {
119                #check
120            }
121        }
122    } else if segment.ident == "Vec" {
123        let vec_set_path = Ident::new(
124            format!("{}{}", "vec_", deep_set_path).as_str(),
125            Span::call_site(),
126        );
127        let ty = &segment_types(segment)[0];
128        let check = derive_request_check(vec_set_path.to_token_stream(), ty);
129        quote! {
130            for #vec_set_path in #set_path.iter_mut() {
131                #check
132            }
133        }
134    } else if segment.ident == "HashMap" {
135        let hash_map_set_path = Ident::new(
136            format!("{}{}", "hash_map_", deep_set_path).as_str(),
137            Span::call_site(),
138        );
139        let ty = &segment_types(segment)[1];
140        let check = derive_request_check(hash_map_set_path.to_token_stream(), ty);
141        quote! {
142            for #hash_map_set_path in #set_path.values_mut() {
143                #check
144            }
145        }
146    } else {
147        quote! {
148            #set_path.__set_executor(executor.clone()).await;
149        }
150    }
151}
152
153fn segment_types(segment: &PathSegment) -> Vec<Path> {
154    let args = if let PathArguments::AngleBracketed(args) = &segment.arguments {
155        &args.args
156    } else {
157        unreachable!()
158    };
159    args.iter()
160        .map(|a| {
161            if let GenericArgument::Type(t) = a {
162                t
163            } else {
164                unreachable!()
165            }
166        })
167        .map(|t| {
168            if let Type::Path(ty) = t {
169                ty.path.clone()
170            } else {
171                unreachable!()
172            }
173        })
174        .collect()
175}