cynic_codegen/input_object_derive/
mod.rs1use {
2 proc_macro2::{Span, TokenStream},
3 std::collections::HashSet,
4};
5
6use crate::{
7 error::Errors,
8 generics_for_serde,
9 idents::RenameAll,
10 schema::{
11 types::{InputObjectType, InputValue},
12 Schema,
13 },
14 suggestions::FieldSuggestionError,
15};
16
17mod field_serializer;
18use field_serializer::FieldSerializer;
19
20pub(crate) mod input;
21
22#[cfg(test)]
23mod tests;
24
25pub use input::InputObjectDeriveInput;
26use {crate::suggestions::guess_field, input::InputObjectDeriveField};
27
28pub fn input_object_derive(ast: &syn::DeriveInput) -> Result<TokenStream, syn::Error> {
29 use darling::FromDeriveInput;
30
31 let struct_span = ast.ident.span();
32
33 match InputObjectDeriveInput::from_derive_input(ast) {
34 Ok(input) => {
35 input_object_derive_impl(input, struct_span).or_else(|e| Ok(e.to_compile_errors()))
36 }
37 Err(e) => Ok(e.write_errors()),
38 }
39}
40
41pub fn input_object_derive_impl(
42 input: InputObjectDeriveInput,
43 struct_span: Span,
44) -> Result<TokenStream, Errors> {
45 use quote::quote;
46
47 let schema = Schema::new(input.schema_input()?);
48
49 let input_object = schema
50 .lookup::<InputObjectType<'_>>(&input.graphql_type_name())
51 .map_err(|e| syn::Error::new(input.graphql_type_span(), e))?;
52
53 let rename_all = input.rename_all.unwrap_or(RenameAll::CamelCase);
54
55 if let darling::ast::Data::Struct(fields) = &input.data {
56 let ident = &input.ident;
57 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
58 let generics_with_ser = generics_for_serde::with_serialize_bounds(&input.generics);
59 let (impl_generics_with_ser, _, where_clause_with_ser) = generics_with_ser.split_for_impl();
60 let input_marker_ident = input_object.marker_ident().to_rust_ident();
61 let schema_module = input.schema_module();
62 let graphql_type_name = proc_macro2::Literal::string(input_object.name.as_ref());
63
64 let pairs = pair_fields(
65 &fields.fields,
66 input_object,
67 rename_all,
68 input.require_all_fields,
69 &struct_span,
70 )?;
71
72 let field_serializers = pairs
73 .into_iter()
74 .map(|(rust_field, graphql_field)| {
75 FieldSerializer::new(rust_field, graphql_field, &schema_module)
76 })
77 .collect::<Vec<_>>();
78
79 let errors = field_serializers
80 .iter()
81 .filter_map(|fs| fs.validate())
82 .collect::<Errors>();
83
84 if !errors.is_empty() {
85 return Ok(errors.to_compile_errors());
86 }
87
88 let typechecks = field_serializers
89 .iter()
90 .map(|fs| fs.type_check(&impl_generics, where_clause, &schema));
91 let map_serializer_ident = proc_macro2::Ident::new("map_serializer", Span::call_site());
92 let field_inserts = field_serializers
93 .iter()
94 .map(|fs| fs.field_insert_call(&map_serializer_ident));
95
96 let map_len = field_serializers.len();
97
98 Ok(quote! {
99 #[automatically_derived]
100 impl #impl_generics cynic::InputObject for #ident #ty_generics #where_clause_with_ser {
101 type SchemaType = #schema_module::#input_marker_ident;
102 }
103
104 #[automatically_derived]
105 impl #impl_generics_with_ser cynic::serde::Serialize for #ident #ty_generics #where_clause_with_ser {
106 fn serialize<__S>(&self, serializer: __S) -> Result<__S::Ok, __S::Error>
107 where
108 __S: cynic::serde::Serializer,
109 {
110 use cynic::serde::ser::SerializeMap;
111 #(#typechecks)*
112
113 let mut map_serializer = serializer.serialize_map(Some(#map_len))?;
114
115 #(#field_inserts)*
116
117 map_serializer.end()
118 }
119 }
120
121 cynic::impl_coercions!(#ident #ty_generics [#impl_generics] [#where_clause], #schema_module::#input_marker_ident);
122
123 #[automatically_derived]
124 impl #impl_generics #schema_module::variable::Variable for #ident #ty_generics #where_clause {
125 const TYPE: cynic::variables::VariableType = cynic::variables::VariableType::Named(#graphql_type_name);
126 }
127 })
128 } else {
129 Err(syn::Error::new(
130 struct_span,
131 "InputObject can only be derived on a struct".to_string(),
132 )
133 .into())
134 }
135}
136
137fn pair_fields<'a>(
138 fields: &'a [InputObjectDeriveField],
139 input_object_def: InputObjectType<'a>,
140 rename_all: RenameAll,
141 require_all_fields: bool,
142 struct_span: &Span,
143) -> Result<Vec<(&'a InputObjectDeriveField, InputValue<'a>)>, Errors> {
144 let mut result = Vec::new();
145 let mut unknown_fields = Vec::new();
146
147 for field in fields {
148 let ident = field.graphql_ident(rename_all);
149 match input_object_def.field(&ident) {
150 Some(schema_field) => result.push((field, schema_field)),
151 None => unknown_fields.push(field),
152 }
153 }
154
155 let required_fields = if require_all_fields {
156 input_object_def.fields.iter().collect::<HashSet<_>>()
157 } else {
158 input_object_def
159 .fields
160 .iter()
161 .filter(|f| f.is_required())
162 .collect::<HashSet<_>>()
163 };
164
165 let provided_fields = result
166 .iter()
167 .map(|(_, field)| field)
168 .cloned()
169 .collect::<HashSet<_>>();
170
171 let missing_fields = required_fields
172 .difference(&provided_fields)
173 .collect::<Vec<_>>();
174
175 if missing_fields.is_empty() && unknown_fields.is_empty() {
176 return Ok(result.into_iter().map(|(l, r)| (l, r.clone())).collect());
177 }
178
179 let field_candidates = input_object_def
180 .fields
181 .iter()
182 .map(|f| f.name.as_str())
183 .collect::<Vec<_>>();
184
185 let mut errors = unknown_fields
186 .into_iter()
187 .map(|field| {
188 let field_name = &field.graphql_ident(rename_all);
189 let graphql_name = field_name.graphql_name();
190 let expected_field = graphql_name.as_str();
191 let suggested_field = guess_field(field_candidates.iter().copied(), expected_field);
192 syn::Error::new(
193 field_name.span(),
194 FieldSuggestionError {
195 expected_field,
196 graphql_type_name: input_object_def.name.as_ref(),
197 suggested_field,
198 },
199 )
200 })
201 .map(Errors::from)
202 .collect::<Errors>();
203
204 if !missing_fields.is_empty() {
205 let missing_fields_string = missing_fields
206 .into_iter()
207 .map(|f| f.name.as_str().to_string())
208 .collect::<Vec<_>>()
209 .join(", ");
210
211 errors.push(syn::Error::new(
212 *struct_span,
213 format!(
214 "This InputObject is missing these fields: {}",
215 missing_fields_string
216 ),
217 ))
218 }
219
220 Err(errors)
221}
222
#[cfg(test)]
mod test {
    use assert_matches::assert_matches;

    use crate::schema::SchemaInput;

    use super::*;

    /// One required (`field_one`) and one optional (`field_two`) input field.
    static SCHEMA: &str = r#"
        input TestType {
            field_one: String!,
            field_two: String
        }
    "#;

    /// Builds a `String`-typed derive field named `name` with no `rename` or
    /// `skip_serializing_if` attributes.
    fn string_field(name: &str) -> InputObjectDeriveField {
        InputObjectDeriveField {
            ident: Some(proc_macro2::Ident::new(name, Span::call_site())),
            ty: syn::parse_quote! { String },
            rename: None,
            skip_serializing_if: None,
        }
    }

    #[test]
    fn test_join_fields_when_all_required() {
        let schema = Schema::new(SchemaInput::from_sdl(SCHEMA).unwrap());
        let input_object = schema.lookup("TestType").unwrap();

        // `field_two` is optional in the schema, but `require_all_fields` is
        // set, so omitting it must fail.
        let fields = [string_field("field_one")];

        let outcome = pair_fields(
            &fields,
            input_object,
            RenameAll::None,
            true,
            &Span::call_site(),
        );

        assert_matches!(outcome, Err(_))
    }

    #[test]
    fn test_join_fields_when_required_field_missing() {
        let schema = Schema::new(SchemaInput::from_sdl(SCHEMA).unwrap());
        let input_object = schema.lookup("TestType").unwrap();

        // Only the optional field is supplied; the required `field_one` is
        // absent.
        let fields = [string_field("field_two")];

        let outcome = pair_fields(
            &fields,
            input_object,
            RenameAll::None,
            false,
            &Span::call_site(),
        );

        assert_matches!(outcome, Err(_))
    }

    #[test]
    fn test_join_fields_when_not_required() {
        let schema = Schema::new(SchemaInput::from_sdl(SCHEMA).unwrap());
        let input_object = schema.lookup::<InputObjectType<'_>>("TestType").unwrap();

        let fields = [string_field("field_one")];

        let outcome = pair_fields(
            &fields,
            input_object.clone(),
            RenameAll::None,
            false,
            &Span::call_site(),
        );

        assert_matches!(outcome, Ok(_));

        // The single pairing must point at our own field and at the schema's
        // first declared field.
        let mut pairs = outcome.unwrap().into_iter();
        let (rust_field_ref, input_field_ref) = pairs.next().unwrap();
        assert!(std::ptr::eq(rust_field_ref, fields.first().unwrap()));
        assert_eq!(&input_field_ref, input_object.fields.first().unwrap());
    }
}