cynic_codegen/input_object_derive/
mod.rs

1use {
2    proc_macro2::{Span, TokenStream},
3    std::collections::HashSet,
4};
5
6use crate::{
7    error::Errors,
8    generics_for_serde,
9    idents::RenameAll,
10    schema::{
11        types::{InputObjectType, InputValue},
12        Schema,
13    },
14    suggestions::FieldSuggestionError,
15};
16
17mod field_serializer;
18use field_serializer::FieldSerializer;
19
20pub(crate) mod input;
21
22#[cfg(test)]
23mod tests;
24
25pub use input::InputObjectDeriveInput;
26use {crate::suggestions::guess_field, input::InputObjectDeriveField};
27
28pub fn input_object_derive(ast: &syn::DeriveInput) -> Result<TokenStream, syn::Error> {
29    use darling::FromDeriveInput;
30
31    let struct_span = ast.ident.span();
32
33    match InputObjectDeriveInput::from_derive_input(ast) {
34        Ok(input) => {
35            input_object_derive_impl(input, struct_span).or_else(|e| Ok(e.to_compile_errors()))
36        }
37        Err(e) => Ok(e.write_errors()),
38    }
39}
40
41pub fn input_object_derive_impl(
42    input: InputObjectDeriveInput,
43    struct_span: Span,
44) -> Result<TokenStream, Errors> {
45    use quote::quote;
46
47    let schema = Schema::new(input.schema_input()?);
48
49    let input_object = schema
50        .lookup::<InputObjectType<'_>>(&input.graphql_type_name())
51        .map_err(|e| syn::Error::new(input.graphql_type_span(), e))?;
52
53    let rename_all = input.rename_all.unwrap_or(RenameAll::CamelCase);
54
55    if let darling::ast::Data::Struct(fields) = &input.data {
56        let ident = &input.ident;
57        let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
58        let generics_with_ser = generics_for_serde::with_serialize_bounds(&input.generics);
59        let (impl_generics_with_ser, _, where_clause_with_ser) = generics_with_ser.split_for_impl();
60        let input_marker_ident = input_object.marker_ident().to_rust_ident();
61        let schema_module = input.schema_module();
62        let graphql_type_name = proc_macro2::Literal::string(input_object.name.as_ref());
63
64        let pairs = pair_fields(
65            &fields.fields,
66            input_object,
67            rename_all,
68            input.require_all_fields,
69            &struct_span,
70        )?;
71
72        let field_serializers = pairs
73            .into_iter()
74            .map(|(rust_field, graphql_field)| {
75                FieldSerializer::new(rust_field, graphql_field, &schema_module)
76            })
77            .collect::<Vec<_>>();
78
79        let errors = field_serializers
80            .iter()
81            .filter_map(|fs| fs.validate())
82            .collect::<Errors>();
83
84        if !errors.is_empty() {
85            return Ok(errors.to_compile_errors());
86        }
87
88        let typechecks = field_serializers
89            .iter()
90            .map(|fs| fs.type_check(&impl_generics, where_clause, &schema));
91        let map_serializer_ident = proc_macro2::Ident::new("map_serializer", Span::call_site());
92        let field_inserts = field_serializers
93            .iter()
94            .map(|fs| fs.field_insert_call(&map_serializer_ident));
95
96        let map_len = field_serializers.len();
97
98        Ok(quote! {
99            #[automatically_derived]
100            impl #impl_generics cynic::InputObject for #ident #ty_generics #where_clause_with_ser {
101                type SchemaType = #schema_module::#input_marker_ident;
102            }
103
104            #[automatically_derived]
105            impl #impl_generics_with_ser cynic::serde::Serialize for #ident #ty_generics #where_clause_with_ser {
106                fn serialize<__S>(&self, serializer: __S) -> Result<__S::Ok, __S::Error>
107                where
108                    __S: cynic::serde::Serializer,
109                {
110                    use cynic::serde::ser::SerializeMap;
111                    #(#typechecks)*
112
113                    let mut map_serializer = serializer.serialize_map(Some(#map_len))?;
114
115                    #(#field_inserts)*
116
117                    map_serializer.end()
118                }
119            }
120
121            cynic::impl_coercions!(#ident #ty_generics [#impl_generics] [#where_clause], #schema_module::#input_marker_ident);
122
123            #[automatically_derived]
124            impl #impl_generics #schema_module::variable::Variable for #ident #ty_generics #where_clause {
125                const TYPE: cynic::variables::VariableType = cynic::variables::VariableType::Named(#graphql_type_name);
126            }
127        })
128    } else {
129        Err(syn::Error::new(
130            struct_span,
131            "InputObject can only be derived on a struct".to_string(),
132        )
133        .into())
134    }
135}
136
137fn pair_fields<'a>(
138    fields: &'a [InputObjectDeriveField],
139    input_object_def: InputObjectType<'a>,
140    rename_all: RenameAll,
141    require_all_fields: bool,
142    struct_span: &Span,
143) -> Result<Vec<(&'a InputObjectDeriveField, InputValue<'a>)>, Errors> {
144    let mut result = Vec::new();
145    let mut unknown_fields = Vec::new();
146
147    for field in fields {
148        let ident = field.graphql_ident(rename_all);
149        match input_object_def.field(&ident) {
150            Some(schema_field) => result.push((field, schema_field)),
151            None => unknown_fields.push(field),
152        }
153    }
154
155    let required_fields = if require_all_fields {
156        input_object_def.fields.iter().collect::<HashSet<_>>()
157    } else {
158        input_object_def
159            .fields
160            .iter()
161            .filter(|f| f.is_required())
162            .collect::<HashSet<_>>()
163    };
164
165    let provided_fields = result
166        .iter()
167        .map(|(_, field)| field)
168        .cloned()
169        .collect::<HashSet<_>>();
170
171    let missing_fields = required_fields
172        .difference(&provided_fields)
173        .collect::<Vec<_>>();
174
175    if missing_fields.is_empty() && unknown_fields.is_empty() {
176        return Ok(result.into_iter().map(|(l, r)| (l, r.clone())).collect());
177    }
178
179    let field_candidates = input_object_def
180        .fields
181        .iter()
182        .map(|f| f.name.as_str())
183        .collect::<Vec<_>>();
184
185    let mut errors = unknown_fields
186        .into_iter()
187        .map(|field| {
188            let field_name = &field.graphql_ident(rename_all);
189            let graphql_name = field_name.graphql_name();
190            let expected_field = graphql_name.as_str();
191            let suggested_field = guess_field(field_candidates.iter().copied(), expected_field);
192            syn::Error::new(
193                field_name.span(),
194                FieldSuggestionError {
195                    expected_field,
196                    graphql_type_name: input_object_def.name.as_ref(),
197                    suggested_field,
198                },
199            )
200        })
201        .map(Errors::from)
202        .collect::<Errors>();
203
204    if !missing_fields.is_empty() {
205        let missing_fields_string = missing_fields
206            .into_iter()
207            .map(|f| f.name.as_str().to_string())
208            .collect::<Vec<_>>()
209            .join(", ");
210
211        errors.push(syn::Error::new(
212            *struct_span,
213            format!(
214                "This InputObject is missing these fields: {}",
215                missing_fields_string
216            ),
217        ))
218    }
219
220    Err(errors)
221}
222
#[cfg(test)]
mod test {
    use assert_matches::assert_matches;

    use crate::schema::SchemaInput;

    use super::*;

    static SCHEMA: &str = r#"
        input TestType {
            field_one: String!,
            field_two: String
        }
        "#;

    /// Builds a derive field for a Rust field called `name` of type `String`,
    /// with no `rename` or `skip_serializing_if` attributes.
    fn string_field(name: &str) -> InputObjectDeriveField {
        InputObjectDeriveField {
            ident: Some(proc_macro2::Ident::new(name, Span::call_site())),
            ty: syn::parse_quote! { String },
            rename: None,
            skip_serializing_if: None,
        }
    }

    #[test]
    fn test_join_fields_when_all_required() {
        // With `require_all_fields` set, leaving out `field_two` is an error.
        let schema = Schema::new(SchemaInput::from_sdl(SCHEMA).unwrap());
        let input_object = schema.lookup("TestType").unwrap();
        let fields = [string_field("field_one")];

        let result = pair_fields(&fields, input_object, RenameAll::None, true, &Span::call_site());

        assert_matches!(result, Err(_))
    }

    #[test]
    fn test_join_fields_when_required_field_missing() {
        // `field_one` is declared `String!` in SCHEMA, so omitting it is an error
        // even without `require_all_fields`.
        let schema = Schema::new(SchemaInput::from_sdl(SCHEMA).unwrap());
        let input_object = schema.lookup("TestType").unwrap();
        let fields = [string_field("field_two")];

        let result = pair_fields(&fields, input_object, RenameAll::None, false, &Span::call_site());

        assert_matches!(result, Err(_))
    }

    #[test]
    fn test_join_fields_when_not_required() {
        // Without `require_all_fields`, the nullable `field_two` may be left out.
        let schema = Schema::new(SchemaInput::from_sdl(SCHEMA).unwrap());
        let input_object = schema.lookup::<InputObjectType<'_>>("TestType").unwrap();
        let fields = [string_field("field_one")];

        let result = pair_fields(
            &fields,
            input_object.clone(),
            RenameAll::None,
            false,
            &Span::call_site(),
        );
        assert_matches!(result, Ok(_));

        let pairs = result.unwrap();
        let (rust_field, schema_field) = pairs.into_iter().next().unwrap();

        // The pairing must hand back a reference to the caller's own field.
        assert!(std::ptr::eq(rust_field, fields.first().unwrap()));
        assert_eq!(&schema_field, input_object.fields.first().unwrap());
    }
}