1#![warn(clippy::all)]
2#![warn(clippy::pedantic)]
3#![allow(clippy::missing_panics_doc)]
4
5use heck::ToLowerCamelCase;
6use proc_macro2::{Span, TokenStream};
7use quote::quote;
8use syn::{
9 parenthesized,
10 parse::Parse,
11 parse2, parse_macro_input,
12 punctuated::Punctuated,
13 token::{Comma, Paren},
14 Error, Ident, ImplItem, ItemImpl, LitStr, Path, Result, Token, Type, Visibility,
15};
16
17#[proc_macro_attribute]
18pub fn composable_object(
19 _: proc_macro::TokenStream,
20 item: proc_macro::TokenStream,
21) -> proc_macro::TokenStream {
22 let item_impl = parse_macro_input!(item as ItemImpl);
23 expand_composable_object(&item_impl).into()
24}
25
26#[proc_macro]
27pub fn composite_object(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
28 let input = parse_macro_input!(input as CompositeObjectInput);
29 let context = input
30 .context_ty
31 .map_or_else(|| parse2(quote! { () }).unwrap(), |input| input.ty);
32 expand_composite_object(&input.vis, &input.ident, &context, &input.composables).into()
33}
34
35struct CompositeObjectInput {
36 vis: Visibility,
37 ident: Ident,
38 context_ty: Option<CompositeObjectCustomContextType>,
39 #[allow(dead_code)]
40 paren: Paren,
41 composables: Punctuated<Path, Comma>,
42}
43
44impl Parse for CompositeObjectInput {
45 fn parse(input: syn::parse::ParseStream) -> Result<Self> {
46 let vis = input.parse()?;
47 let ident = input.parse()?;
48 let context_ty = if input.peek(Token![<]) {
49 Some(input.parse()?)
50 } else {
51 None
52 };
53 let composables;
54 let paren = parenthesized!(composables in input);
55 Ok(Self {
56 vis,
57 ident,
58 context_ty,
59 paren,
60 composables: composables.parse_terminated(Path::parse)?,
61 })
62 }
63}
64
65struct CompositeObjectCustomContextType {
66 #[allow(dead_code)]
67 left_angle_bracket: Token![<],
68 #[allow(dead_code)]
69 context_ident: Ident,
70 #[allow(dead_code)]
71 eq_token: Token![=],
72 ty: Type,
73 #[allow(dead_code)]
74 right_angle_bracket: Token![>],
75}
76
77impl Parse for CompositeObjectCustomContextType {
78 fn parse(input: syn::parse::ParseStream) -> Result<Self> {
79 let left_angle_bracket = input.parse()?;
80 let context_ident = input.parse::<Ident>()?;
81 if context_ident != "Context" {
82 return Err(Error::new(context_ident.span(), "expected `Context`"));
83 }
84 let eq_token = input.parse()?;
85 let ty = input.parse()?;
86 let right_angle_bracket = input.parse()?;
87 Ok(Self {
88 left_angle_bracket,
89 context_ident,
90 eq_token,
91 ty,
92 right_angle_bracket,
93 })
94 }
95}
96
97fn expand_composable_object(item_impl: &ItemImpl) -> TokenStream {
98 let ty = &item_impl.self_ty;
99
100 let fields = item_impl
101 .items
102 .iter()
103 .filter_map(|item| {
104 if let ImplItem::Method(method) = item {
105 Some(method)
106 } else {
107 None
108 }
109 })
110 .map(|method| {
111 LitStr::new(
112 &method.sig.ident.to_string().to_lower_camel_case(),
113 Span::call_site(),
114 )
115 });
116
117 quote! {
118 impl ::juniper_compose::ComposableObject for #ty {
119 fn fields() -> &'static [&'static str] {
120 &[#( #fields ),*]
121 }
122 }
123
124 #item_impl
125 }
126}
127
128fn expand_composite_object<P>(
129 vis: &Visibility,
130 name: &Ident,
131 context: &Type,
132 composables: &Punctuated<Path, P>,
133) -> TokenStream {
134 let name_lit = LitStr::new(&name.to_string(), Span::call_site());
135 let impl_graphql_type = expand_impl_graphql_type(name, &name_lit, composables.iter());
136 let impl_graphql_value =
137 expand_impl_graphql_value(name, &name_lit, context, composables.iter());
138 let impl_graphql_value_async =
139 expand_impl_graphql_value_async(name, &name_lit, composables.iter());
140 quote! {
141 #[derive(::std::default::Default)]
142 #vis struct #name;
143 #impl_graphql_type
144 #impl_graphql_value
145 #impl_graphql_value_async
146 }
147}
148
149fn expand_impl_graphql_type<'a>(
150 name: &Ident,
151 name_lit: &LitStr,
152 composables: impl IntoIterator<Item = &'a Path>,
153) -> TokenStream {
154 let composables = composables.into_iter();
155 quote! {
156 impl ::juniper::GraphQLType for #name {
157 fn name(info: &Self::TypeInfo) -> ::std::option::Option<&str> {
158 ::std::option::Option::Some(#name_lit)
159 }
160
161 fn meta<'r>(
162 info: &Self::TypeInfo,
163 registry: &mut ::juniper::executor::Registry<'r, ::juniper::DefaultScalarValue>
164 ) -> ::juniper::meta::MetaType<'r, ::juniper::DefaultScalarValue>
165 where
166 ::juniper::DefaultScalarValue: 'r
167 {
168 let mut fields = ::std::vec![];
169 let mut seen_field_names = ::std::collections::HashSet::<&str>::new();
170
171 #(
172 let composable_meta = <#composables as ::juniper::GraphQLType>::meta(info, registry);
173
174 for field_name in <#composables as ::juniper_compose::ComposableObject>::fields() {
175 if !seen_field_names.insert(field_name) {
176 ::std::panic!("Conflicting field in composed objects: {}", field_name);
177 }
178
179 let composable_field = composable_meta
180 .field_by_name(field_name)
181 .unwrap_or_else(|| {
182 ::std::panic!(
183 "Incorrect implementation of ComposableObject on type {}: unknown field {}",
184 <#composables as ::juniper::GraphQLType>::name(&()).unwrap_or("<anonymous>"), field_name
185 )
186 });
187
188 fields.push(::juniper::meta::Field {
189 name: composable_field.name.clone(),
190 description: composable_field.description.clone(),
191 arguments: composable_field.arguments.as_ref().map(|arguments| {
192 arguments
193 .iter()
194 .map(|argument| ::juniper::meta::Argument {
195 name: argument.name.clone(),
196 description: argument.description.clone(),
197 arg_type: ::juniper_compose::type_to_owned(&argument.arg_type),
198 default_value: argument.default_value.clone(),
199 })
200 .collect()
201 }),
202 field_type: ::juniper_compose::type_to_owned(&composable_field.field_type),
203 deprecation_status: composable_field.deprecation_status.clone(),
204 });
205 }
206 )*
207
208 registry.build_object_type::<Self>(&(), &fields).into_meta()
209 }
210 }
211 }
212}
213
214fn expand_impl_graphql_value<'a>(
215 name: &Ident,
216 name_lit: &LitStr,
217 context: &Type,
218 composables: impl IntoIterator<Item = &'a Path>,
219) -> TokenStream {
220 let composables = composables.into_iter();
221 quote! {
222 impl ::juniper::GraphQLValue for #name {
223 type Context = #context;
224 type TypeInfo = ();
225
226 fn type_name<'i>(&self, info: &'i Self::TypeInfo) -> Option<&'i str> {
227 <Self as ::juniper::GraphQLType>::name(info)
228 }
229
230 fn resolve_field(
231 &self,
232 info: &Self::TypeInfo,
233 field_name: &str,
234 arguments: &::juniper::Arguments<'_, ::juniper::DefaultScalarValue>,
235 executor: &::juniper::executor::Executor<'_, '_, Self::Context, ::juniper::DefaultScalarValue>
236 ) -> ::juniper::executor::ExecutionResult<::juniper::DefaultScalarValue> {
237 #(
238 if <#composables as ::juniper_compose::ComposableObject>::fields().contains(&field_name) {
239 return <#composables as ::juniper::GraphQLValue>::resolve_field(
240 &<#composables as ::std::default::Default>::default(),
241 info,
242 field_name,
243 arguments,
244 executor
245 );
246 }
247 )*
248 Err(::juniper::FieldError::from(::std::format!(
249 "Field `{}` not found on type `{}`",
250 field_name,
251 #name_lit,
252 )))
253 }
254
255 fn concrete_type_name(
256 &self,
257 context: &Self::Context,
258 info: &Self::TypeInfo
259 ) -> String {
260 String::from(#name_lit)
261 }
262 }
263 }
264}
265
266fn expand_impl_graphql_value_async<'a>(
267 name: &Ident,
268 name_lit: &LitStr,
269 composables: impl IntoIterator<Item = &'a Path>,
270) -> TokenStream {
271 let composables = composables.into_iter();
272 quote! {
273 impl ::juniper::GraphQLValueAsync for #name
274 where
275 Self::TypeInfo: Sync,
276 Self::Context: Sync,
277 {
278 fn resolve_field_async<'a>(
279 &'a self,
280 info: &'a Self::TypeInfo,
281 field_name: &'a str,
282 arguments: &'a ::juniper::Arguments<'_, ::juniper::DefaultScalarValue>,
283 executor: &'a ::juniper::executor::Executor<'_, '_, Self::Context, ::juniper::DefaultScalarValue>
284 ) -> ::juniper::BoxFuture<'a, ::juniper::executor::ExecutionResult<::juniper::DefaultScalarValue>> {
285 #(
286 if <#composables as ::juniper_compose::ComposableObject>::fields().contains(&field_name) {
287 return ::std::boxed::Box::pin(async move {
288 <#composables as ::juniper::GraphQLValueAsync>::resolve_field_async(
289 &<#composables as ::std::default::Default>::default(),
290 info,
291 field_name,
292 arguments,
293 executor
294 ).await
295 })
296 }
297 )*
298 ::std::boxed::Box::pin(async move { Err(::juniper::FieldError::from(::std::format!(
299 "Field `{}` not found on type `{}`",
300 field_name,
301 #name_lit,
302 ))) })
303 }
304 }
305 }
306}