juniper_compose_macros/
lib.rs

1#![warn(clippy::all)]
2#![warn(clippy::pedantic)]
3#![allow(clippy::missing_panics_doc)]
4
5use heck::ToLowerCamelCase;
6use proc_macro2::{Span, TokenStream};
7use quote::quote;
8use syn::{
9    parenthesized,
10    parse::Parse,
11    parse2, parse_macro_input,
12    punctuated::Punctuated,
13    token::{Comma, Paren},
14    Error, Ident, ImplItem, ItemImpl, LitStr, Path, Result, Token, Type, Visibility,
15};
16
17#[proc_macro_attribute]
18pub fn composable_object(
19    _: proc_macro::TokenStream,
20    item: proc_macro::TokenStream,
21) -> proc_macro::TokenStream {
22    let item_impl = parse_macro_input!(item as ItemImpl);
23    expand_composable_object(&item_impl).into()
24}
25
26#[proc_macro]
27pub fn composite_object(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
28    let input = parse_macro_input!(input as CompositeObjectInput);
29    let context = input
30        .context_ty
31        .map_or_else(|| parse2(quote! { () }).unwrap(), |input| input.ty);
32    expand_composite_object(&input.vis, &input.ident, &context, &input.composables).into()
33}
34
35struct CompositeObjectInput {
36    vis: Visibility,
37    ident: Ident,
38    context_ty: Option<CompositeObjectCustomContextType>,
39    #[allow(dead_code)]
40    paren: Paren,
41    composables: Punctuated<Path, Comma>,
42}
43
44impl Parse for CompositeObjectInput {
45    fn parse(input: syn::parse::ParseStream) -> Result<Self> {
46        let vis = input.parse()?;
47        let ident = input.parse()?;
48        let context_ty = if input.peek(Token![<]) {
49            Some(input.parse()?)
50        } else {
51            None
52        };
53        let composables;
54        let paren = parenthesized!(composables in input);
55        Ok(Self {
56            vis,
57            ident,
58            context_ty,
59            paren,
60            composables: composables.parse_terminated(Path::parse)?,
61        })
62    }
63}
64
65struct CompositeObjectCustomContextType {
66    #[allow(dead_code)]
67    left_angle_bracket: Token![<],
68    #[allow(dead_code)]
69    context_ident: Ident,
70    #[allow(dead_code)]
71    eq_token: Token![=],
72    ty: Type,
73    #[allow(dead_code)]
74    right_angle_bracket: Token![>],
75}
76
77impl Parse for CompositeObjectCustomContextType {
78    fn parse(input: syn::parse::ParseStream) -> Result<Self> {
79        let left_angle_bracket = input.parse()?;
80        let context_ident = input.parse::<Ident>()?;
81        if context_ident != "Context" {
82            return Err(Error::new(context_ident.span(), "expected `Context`"));
83        }
84        let eq_token = input.parse()?;
85        let ty = input.parse()?;
86        let right_angle_bracket = input.parse()?;
87        Ok(Self {
88            left_angle_bracket,
89            context_ident,
90            eq_token,
91            ty,
92            right_angle_bracket,
93        })
94    }
95}
96
97fn expand_composable_object(item_impl: &ItemImpl) -> TokenStream {
98    let ty = &item_impl.self_ty;
99
100    let fields = item_impl
101        .items
102        .iter()
103        .filter_map(|item| {
104            if let ImplItem::Method(method) = item {
105                Some(method)
106            } else {
107                None
108            }
109        })
110        .map(|method| {
111            LitStr::new(
112                &method.sig.ident.to_string().to_lower_camel_case(),
113                Span::call_site(),
114            )
115        });
116
117    quote! {
118        impl ::juniper_compose::ComposableObject for #ty {
119            fn fields() -> &'static [&'static str] {
120                &[#( #fields ),*]
121            }
122        }
123
124        #item_impl
125    }
126}
127
128fn expand_composite_object<P>(
129    vis: &Visibility,
130    name: &Ident,
131    context: &Type,
132    composables: &Punctuated<Path, P>,
133) -> TokenStream {
134    let name_lit = LitStr::new(&name.to_string(), Span::call_site());
135    let impl_graphql_type = expand_impl_graphql_type(name, &name_lit, composables.iter());
136    let impl_graphql_value =
137        expand_impl_graphql_value(name, &name_lit, context, composables.iter());
138    let impl_graphql_value_async =
139        expand_impl_graphql_value_async(name, &name_lit, composables.iter());
140    quote! {
141        #[derive(::std::default::Default)]
142        #vis struct #name;
143        #impl_graphql_type
144        #impl_graphql_value
145        #impl_graphql_value_async
146    }
147}
148
149fn expand_impl_graphql_type<'a>(
150    name: &Ident,
151    name_lit: &LitStr,
152    composables: impl IntoIterator<Item = &'a Path>,
153) -> TokenStream {
154    let composables = composables.into_iter();
155    quote! {
156        impl ::juniper::GraphQLType for #name {
157            fn name(info: &Self::TypeInfo) -> ::std::option::Option<&str> {
158                ::std::option::Option::Some(#name_lit)
159            }
160
161            fn meta<'r>(
162                info: &Self::TypeInfo,
163                registry: &mut ::juniper::executor::Registry<'r, ::juniper::DefaultScalarValue>
164            ) -> ::juniper::meta::MetaType<'r, ::juniper::DefaultScalarValue>
165            where
166                ::juniper::DefaultScalarValue: 'r
167            {
168                let mut fields = ::std::vec![];
169                let mut seen_field_names = ::std::collections::HashSet::<&str>::new();
170
171                #(
172                    let composable_meta = <#composables as ::juniper::GraphQLType>::meta(info, registry);
173
174                    for field_name in <#composables as ::juniper_compose::ComposableObject>::fields() {
175                        if !seen_field_names.insert(field_name) {
176                            ::std::panic!("Conflicting field in composed objects: {}", field_name);
177                        }
178
179                        let composable_field = composable_meta
180                            .field_by_name(field_name)
181                            .unwrap_or_else(|| {
182                                ::std::panic!(
183                                    "Incorrect implementation of ComposableObject on type {}: unknown field {}",
184                                    <#composables as ::juniper::GraphQLType>::name(&()).unwrap_or("<anonymous>"), field_name
185                                )
186                            });
187
188                        fields.push(::juniper::meta::Field {
189                            name: composable_field.name.clone(),
190                            description: composable_field.description.clone(),
191                            arguments: composable_field.arguments.as_ref().map(|arguments| {
192                                arguments
193                                    .iter()
194                                    .map(|argument| ::juniper::meta::Argument {
195                                        name: argument.name.clone(),
196                                        description: argument.description.clone(),
197                                        arg_type: ::juniper_compose::type_to_owned(&argument.arg_type),
198                                        default_value: argument.default_value.clone(),
199                                    })
200                                    .collect()
201                            }),
202                            field_type: ::juniper_compose::type_to_owned(&composable_field.field_type),
203                            deprecation_status: composable_field.deprecation_status.clone(),
204                        });
205                    }
206                )*
207
208                registry.build_object_type::<Self>(&(), &fields).into_meta()
209            }
210        }
211    }
212}
213
214fn expand_impl_graphql_value<'a>(
215    name: &Ident,
216    name_lit: &LitStr,
217    context: &Type,
218    composables: impl IntoIterator<Item = &'a Path>,
219) -> TokenStream {
220    let composables = composables.into_iter();
221    quote! {
222        impl ::juniper::GraphQLValue for #name {
223            type Context = #context;
224            type TypeInfo = ();
225
226            fn type_name<'i>(&self, info: &'i Self::TypeInfo) -> Option<&'i str> {
227                <Self as ::juniper::GraphQLType>::name(info)
228            }
229
230            fn resolve_field(
231                &self,
232                info: &Self::TypeInfo,
233                field_name: &str,
234                arguments: &::juniper::Arguments<'_, ::juniper::DefaultScalarValue>,
235                executor: &::juniper::executor::Executor<'_, '_, Self::Context, ::juniper::DefaultScalarValue>
236            ) -> ::juniper::executor::ExecutionResult<::juniper::DefaultScalarValue> {
237                #(
238                    if <#composables as ::juniper_compose::ComposableObject>::fields().contains(&field_name) {
239                        return <#composables as ::juniper::GraphQLValue>::resolve_field(
240                            &<#composables as ::std::default::Default>::default(),
241                            info,
242                            field_name,
243                            arguments,
244                            executor
245                        );
246                    }
247                )*
248                Err(::juniper::FieldError::from(::std::format!(
249                    "Field `{}` not found on type `{}`",
250                    field_name,
251                    #name_lit,
252                )))
253            }
254
255            fn concrete_type_name(
256                &self,
257                context: &Self::Context,
258                info: &Self::TypeInfo
259            ) -> String {
260                String::from(#name_lit)
261            }
262        }
263    }
264}
265
266fn expand_impl_graphql_value_async<'a>(
267    name: &Ident,
268    name_lit: &LitStr,
269    composables: impl IntoIterator<Item = &'a Path>,
270) -> TokenStream {
271    let composables = composables.into_iter();
272    quote! {
273        impl ::juniper::GraphQLValueAsync for #name
274        where
275            Self::TypeInfo: Sync,
276            Self::Context: Sync,
277        {
278            fn resolve_field_async<'a>(
279                &'a self,
280                info: &'a Self::TypeInfo,
281                field_name: &'a str,
282                arguments: &'a ::juniper::Arguments<'_, ::juniper::DefaultScalarValue>,
283                executor: &'a ::juniper::executor::Executor<'_, '_, Self::Context, ::juniper::DefaultScalarValue>
284            ) -> ::juniper::BoxFuture<'a, ::juniper::executor::ExecutionResult<::juniper::DefaultScalarValue>> {
285                #(
286                    if <#composables as ::juniper_compose::ComposableObject>::fields().contains(&field_name) {
287                        return ::std::boxed::Box::pin(async move {
288                            <#composables as ::juniper::GraphQLValueAsync>::resolve_field_async(
289                                &<#composables as ::std::default::Default>::default(),
290                                info,
291                                field_name,
292                                arguments,
293                                executor
294                            ).await
295                        })
296                    }
297                )*
298                ::std::boxed::Box::pin(async move { Err(::juniper::FieldError::from(::std::format!(
299                    "Field `{}` not found on type `{}`",
300                    field_name,
301                    #name_lit,
302                ))) })
303            }
304        }
305    }
306}