1#![forbid(unsafe_code)]
2
3use std::collections::HashMap;
6
7use proc_macro::TokenStream;
8use proc_macro2::{Span, TokenStream as TokenStream2};
9use proc_macro_crate::{crate_name, FoundCrate};
10use quote::{format_ident, quote};
11use syn::{
12 ext::IdentExt, parse_macro_input, spanned::Spanned, Attribute, Data, DeriveInput, Field,
13 Fields, Ident, LitStr,
14};
15
16#[proc_macro_derive(FeatherSerialize, attributes(serde))]
17pub fn derive_feather_serialize(input: TokenStream) -> TokenStream {
18 let input = parse_macro_input!(input as DeriveInput);
19 match expand_serialize(&input) {
20 Ok(output) => output.into(),
21 Err(error) => error.into_compile_error().into(),
22 }
23}
24
25#[proc_macro_derive(FeatherDeserialize, attributes(serde))]
26pub fn derive_feather_deserialize(input: TokenStream) -> TokenStream {
27 let input = parse_macro_input!(input as DeriveInput);
28 match expand_deserialize(&input) {
29 Ok(output) => output.into(),
30 Err(error) => error.into_compile_error().into(),
31 }
32}
33
34struct ContainerAttrOptions {
35 rename: Option<LitStr>,
36}
37
38#[derive(Default)]
39struct FieldAttrOptions {
40 rename: Option<LitStr>,
41 default: bool,
42 skip_serializing: bool,
43 skip_deserializing: bool,
44}
45
46struct ParsedField {
47 ident: Ident,
48 ty: syn::Type,
49 serialized_name: LitStr,
50 default: bool,
51 skip_serializing: bool,
52 skip_deserializing: bool,
53}
54
55struct ParsedStruct {
56 ident: Ident,
57 struct_name: LitStr,
58 fields: Vec<ParsedField>,
59}
60
61#[derive(Clone, Copy)]
62enum WireDirection {
63 Serialize,
64 Deserialize,
65}
66
67impl WireDirection {
68 fn includes(self, field: &ParsedField) -> bool {
69 match self {
70 Self::Serialize => !field.skip_serializing,
71 Self::Deserialize => !field.skip_deserializing,
72 }
73 }
74
75 fn name(self) -> &'static str {
76 match self {
77 Self::Serialize => "serialization",
78 Self::Deserialize => "deserialization",
79 }
80 }
81}
82
83fn expand_serialize(input: &DeriveInput) -> syn::Result<TokenStream2> {
84 let parsed = parse_input(input, "FeatherSerialize")?;
85 let crate_path = serde_feather_path();
86
87 let included_fields: Vec<&ParsedField> = parsed
88 .fields
89 .iter()
90 .filter(|field| !field.skip_serializing)
91 .collect();
92
93 let field_count = included_fields.len();
94 let serialize_fields = included_fields.into_iter().map(|field| {
95 let field_ident = &field.ident;
96 let field_name = &field.serialized_name;
97 quote! {
98 #crate_path::serde::ser::SerializeStruct::serialize_field(
99 &mut state,
100 #field_name,
101 &self.#field_ident,
102 )?;
103 }
104 });
105
106 let struct_ident = &parsed.ident;
107 let struct_name = &parsed.struct_name;
108
109 Ok(quote! {
110 impl #crate_path::serde::ser::Serialize for #struct_ident {
111 fn serialize<S>(
112 &self,
113 serializer: S,
114 ) -> ::core::result::Result<S::Ok, S::Error>
115 where
116 S: #crate_path::serde::ser::Serializer,
117 {
118 let mut state = #crate_path::serde::ser::Serializer::serialize_struct(
119 serializer,
120 #struct_name,
121 #field_count,
122 )?;
123 #(#serialize_fields)*
124 #crate_path::serde::ser::SerializeStruct::end(state)
125 }
126 }
127 })
128}
129
130fn expand_deserialize(input: &DeriveInput) -> syn::Result<TokenStream2> {
131 let parsed = parse_input(input, "FeatherDeserialize")?;
132 let crate_path = serde_feather_path();
133
134 struct DeserBinding {
135 field_index: usize,
136 binding_ident: Ident,
137 field_name: LitStr,
138 field_ty: syn::Type,
139 default: bool,
140 }
141
142 let mut bindings = Vec::<DeserBinding>::new();
143 for (index, field) in parsed.fields.iter().enumerate() {
144 if field.skip_deserializing {
145 continue;
146 }
147
148 bindings.push(DeserBinding {
149 field_index: index,
150 binding_ident: format_ident!("__feather_field_{index}"),
151 field_name: field.serialized_name.clone(),
152 field_ty: field.ty.clone(),
153 default: field.default,
154 });
155 }
156
157 let field_bindings: Vec<TokenStream2> = bindings
158 .iter()
159 .map(|binding| {
160 let binding_ident = &binding.binding_ident;
161 let field_ty = &binding.field_ty;
162 quote! { let mut #binding_ident: ::core::option::Option<#field_ty> = ::core::option::Option::None; }
163 })
164 .collect();
165 let field_bindings_in_map = field_bindings.clone();
166 let field_bindings_in_seq = field_bindings;
167
168 let field_setter_match_arms = bindings.iter().enumerate().map(|(binding_index, binding)| {
169 let field_index = binding_index;
170 let binding_ident = &binding.binding_ident;
171 let field_name = &binding.field_name;
172 let field_ty = &binding.field_ty;
173 quote! {
174 #field_index => {
175 if #binding_ident.is_some() {
176 return ::core::result::Result::Err(
177 #crate_path::serde::de::Error::duplicate_field(#field_name),
178 );
179 }
180 #binding_ident = ::core::option::Option::Some(#crate_path::serde::de::MapAccess::next_value::<#field_ty>(&mut map)?);
181 }
182 }
183 });
184
185 let known_fields: Vec<LitStr> = bindings
186 .iter()
187 .map(|binding| binding.field_name.clone())
188 .collect();
189 let known_fields_in_map = known_fields.clone();
190
191 let construct_fields: Vec<TokenStream2> = parsed
192 .fields
193 .iter()
194 .enumerate()
195 .map(|(index, field)| {
196 let field_ident = &field.ident;
197 let field_name = &field.serialized_name;
198 if field.skip_deserializing {
199 return quote! {
200 #field_ident: ::core::default::Default::default()
201 };
202 }
203
204 let binding_ident = bindings
205 .iter()
206 .find(|binding| binding.field_index == index)
207 .expect("binding for non-skipped field")
208 .binding_ident
209 .clone();
210
211 if field.default {
212 quote! {
213 #field_ident: #binding_ident.unwrap_or_default()
214 }
215 } else {
216 quote! {
217 #field_ident: match #binding_ident {
218 ::core::option::Option::Some(value) => value,
219 ::core::option::Option::None => {
220 return ::core::result::Result::Err(
221 #crate_path::serde::de::Error::missing_field(#field_name),
222 );
223 }
224 }
225 }
226 }
227 })
228 .collect();
229 let construct_fields_in_map = construct_fields.clone();
230 let construct_fields_in_seq = construct_fields;
231
232 let seq_field_decode_steps = bindings.iter().enumerate().map(|(seq_index, binding)| {
233 let binding_ident = &binding.binding_ident;
234 let field_ty = &binding.field_ty;
235 if binding.default {
236 quote! {
237 if let ::core::option::Option::Some(value) =
238 #crate_path::serde::de::SeqAccess::next_element::<#field_ty>(&mut seq)?
239 {
240 #binding_ident = ::core::option::Option::Some(value);
241 }
242 }
243 } else {
244 quote! {
245 #binding_ident =
246 match #crate_path::serde::de::SeqAccess::next_element::<#field_ty>(&mut seq)? {
247 ::core::option::Option::Some(value) => ::core::option::Option::Some(value),
248 ::core::option::Option::None => {
249 return ::core::result::Result::Err(
250 #crate_path::serde::de::Error::invalid_length(#seq_index, &self),
251 );
252 }
253 };
254 }
255 }
256 });
257
258 let seq_expected_len = bindings.len();
259
260 let struct_ident = &parsed.ident;
261 let struct_name = &parsed.struct_name;
262
263 Ok(quote! {
264 impl<'de> #crate_path::serde::de::Deserialize<'de> for #struct_ident {
265 fn deserialize<D>(deserializer: D) -> ::core::result::Result<Self, D::Error>
266 where
267 D: #crate_path::serde::de::Deserializer<'de>,
268 {
269 struct __FeatherVisitor;
270
271 impl<'de> #crate_path::serde::de::Visitor<'de> for __FeatherVisitor {
272 type Value = #struct_ident;
273
274 fn expecting(
275 &self,
276 formatter: &mut ::core::fmt::Formatter<'_>,
277 ) -> ::core::fmt::Result {
278 ::core::write!(formatter, "struct {}", #struct_name)
279 }
280
281 fn visit_map<V>(
282 self,
283 mut map: V,
284 ) -> ::core::result::Result<Self::Value, V::Error>
285 where
286 V: #crate_path::serde::de::MapAccess<'de>,
287 {
288 const __FEATHER_FIELDS: &[&str] = &[#(#known_fields_in_map),*];
289 #(#field_bindings_in_map)*
290 while let ::core::option::Option::Some(key) = #crate_path::serde::de::MapAccess::next_key::<#crate_path::__private::OwnedFieldName>(&mut map)?
291 {
292 match #crate_path::__private::select_field_index(key.as_str(), __FEATHER_FIELDS) {
293 ::core::option::Option::Some(index) => match index {
294 #(#field_setter_match_arms)*
295 _ => {
296 let _: #crate_path::serde::de::IgnoredAny =
297 #crate_path::serde::de::MapAccess::next_value(&mut map)?;
298 }
299 },
300 ::core::option::Option::None => {
301 let _: #crate_path::serde::de::IgnoredAny =
302 #crate_path::serde::de::MapAccess::next_value(&mut map)?;
303 }
304 }
305 }
306
307 ::core::result::Result::Ok(#struct_ident {
308 #(#construct_fields_in_map,)*
309 })
310 }
311
312 fn visit_seq<V>(
313 self,
314 mut seq: V,
315 ) -> ::core::result::Result<Self::Value, V::Error>
316 where
317 V: #crate_path::serde::de::SeqAccess<'de>,
318 {
319 #(#field_bindings_in_seq)*
320 #(#seq_field_decode_steps)*
321
322 if #crate_path::serde::de::SeqAccess::next_element::<#crate_path::serde::de::IgnoredAny>(&mut seq)?.is_some() {
323 return ::core::result::Result::Err(
324 #crate_path::serde::de::Error::invalid_length(#seq_expected_len + 1, &self),
325 );
326 }
327
328 ::core::result::Result::Ok(#struct_ident {
329 #(#construct_fields_in_seq,)*
330 })
331 }
332 }
333
334 const __FEATHER_FIELDS: &[&str] = &[#(#known_fields),*];
335 #crate_path::serde::de::Deserializer::deserialize_struct(
336 deserializer,
337 #struct_name,
338 __FEATHER_FIELDS,
339 __FeatherVisitor,
340 )
341 }
342 }
343 })
344}
345
346fn parse_input(input: &DeriveInput, macro_name: &str) -> syn::Result<ParsedStruct> {
347 if !input.generics.params.is_empty() || input.generics.where_clause.is_some() {
348 return Err(syn::Error::new_spanned(
349 &input.generics,
350 format!("{macro_name} only supports non-generic structs in this MVP"),
351 ));
352 }
353
354 let container_options = parse_container_attributes(&input.attrs)?;
355 let struct_name = container_options
356 .rename
357 .unwrap_or_else(|| LitStr::new(&input.ident.to_string(), input.ident.span()));
358
359 let named_fields = match &input.data {
360 Data::Struct(data_struct) => match &data_struct.fields {
361 Fields::Named(fields) => &fields.named,
362 _ => {
363 return Err(syn::Error::new_spanned(
364 &data_struct.fields,
365 format!("{macro_name} only supports structs with named fields in this MVP"),
366 ))
367 }
368 },
369 _ => {
370 return Err(syn::Error::new_spanned(
371 &input.ident,
372 format!("{macro_name} only supports structs in this MVP"),
373 ))
374 }
375 };
376
377 let mut parsed_fields = Vec::with_capacity(named_fields.len());
378 for field in named_fields {
379 parsed_fields.push(parse_field(field)?);
380 }
381
382 validate_unique_wire_field_names(&parsed_fields, WireDirection::Serialize)?;
383 validate_unique_wire_field_names(&parsed_fields, WireDirection::Deserialize)?;
384
385 Ok(ParsedStruct {
386 ident: input.ident.clone(),
387 struct_name,
388 fields: parsed_fields,
389 })
390}
391
392fn validate_unique_wire_field_names(
393 parsed_fields: &[ParsedField],
394 direction: WireDirection,
395) -> syn::Result<()> {
396 let mut seen_by_name: HashMap<String, String> = HashMap::new();
397
398 for field in parsed_fields {
399 if !direction.includes(field) {
400 continue;
401 }
402
403 let wire_name = field.serialized_name.value();
404 let current_field = field.ident.to_string();
405 if let Some(previous_field) = seen_by_name.insert(wire_name.clone(), current_field) {
406 return Err(syn::Error::new(
407 field.serialized_name.span(),
408 format!(
409 "duplicate wire field name `{wire_name}` in {}; conflicts with field \
410 `{previous_field}`",
411 direction.name()
412 ),
413 ));
414 }
415 }
416
417 Ok(())
418}
419
420fn parse_container_attributes(attrs: &[Attribute]) -> syn::Result<ContainerAttrOptions> {
421 let mut options = ContainerAttrOptions { rename: None };
422
423 for attr in attrs {
424 if !attr.path().is_ident("serde") {
425 continue;
426 }
427
428 attr.parse_nested_meta(|meta| {
429 if meta.path.is_ident("rename") {
430 let rename_value: LitStr = meta.value()?.parse()?;
431 if options.rename.replace(rename_value).is_some() {
432 return Err(meta.error("duplicate serde container attribute `rename`"));
433 }
434 return Ok(());
435 }
436
437 Err(meta.error("unsupported serde container attribute; supported attributes: `rename`"))
438 })?;
439 }
440
441 Ok(options)
442}
443
444fn parse_field(field: &Field) -> syn::Result<ParsedField> {
445 let field_ident = field.ident.clone().ok_or_else(|| {
446 syn::Error::new(
447 field.span(),
448 "Feather derives only support fields with identifiers",
449 )
450 })?;
451
452 let mut options = FieldAttrOptions::default();
453
454 for attr in &field.attrs {
455 if !attr.path().is_ident("serde") {
456 continue;
457 }
458
459 attr.parse_nested_meta(|meta| {
460 if meta.path.is_ident("rename") {
461 let rename_value: LitStr = meta.value()?.parse()?;
462 if options.rename.replace(rename_value).is_some() {
463 return Err(meta.error("duplicate serde field attribute `rename`"));
464 }
465 return Ok(());
466 }
467
468 if meta.path.is_ident("default") {
469 ensure_flag_meta_has_no_value(&meta, "default")?;
470 if options.default {
471 return Err(meta.error("duplicate serde field attribute `default`"));
472 }
473 options.default = true;
474 return Ok(());
475 }
476
477 if meta.path.is_ident("skip") {
478 ensure_flag_meta_has_no_value(&meta, "skip")?;
479 if options.skip_serializing || options.skip_deserializing {
480 return Err(meta.error(
481 "serde field attribute `skip` conflicts with previously declared `skip`, \
482 `skip_serializing`, or `skip_deserializing`",
483 ));
484 }
485 options.skip_serializing = true;
486 options.skip_deserializing = true;
487 return Ok(());
488 }
489
490 if meta.path.is_ident("skip_serializing") {
491 ensure_flag_meta_has_no_value(&meta, "skip_serializing")?;
492 if options.skip_serializing {
493 return Err(meta.error("duplicate serde field attribute `skip_serializing`"));
494 }
495 if options.skip_deserializing {
496 return Err(meta.error(
497 "serde field attributes `skip_serializing` and `skip_deserializing` \
498 cannot be combined",
499 ));
500 }
501 options.skip_serializing = true;
502 return Ok(());
503 }
504
505 if meta.path.is_ident("skip_deserializing") {
506 ensure_flag_meta_has_no_value(&meta, "skip_deserializing")?;
507 if options.skip_deserializing {
508 return Err(meta.error("duplicate serde field attribute `skip_deserializing`"));
509 }
510 if options.skip_serializing {
511 return Err(meta.error(
512 "serde field attributes `skip_serializing` and `skip_deserializing` \
513 cannot be combined",
514 ));
515 }
516 options.skip_deserializing = true;
517 return Ok(());
518 }
519
520 Err(meta.error(
521 "unsupported serde field attribute; supported attributes: `rename`, `default`, \
522 `skip`, `skip_serializing`, `skip_deserializing`",
523 ))
524 })?;
525 }
526
527 let serialized_name = options
528 .rename
529 .unwrap_or_else(|| LitStr::new(&field_ident.unraw().to_string(), field_ident.span()));
530
531 Ok(ParsedField {
532 ident: field_ident,
533 ty: field.ty.clone(),
534 serialized_name,
535 default: options.default,
536 skip_serializing: options.skip_serializing,
537 skip_deserializing: options.skip_deserializing,
538 })
539}
540
541fn ensure_flag_meta_has_no_value(
542 meta: &syn::meta::ParseNestedMeta<'_>,
543 name: &str,
544) -> syn::Result<()> {
545 if !meta.input.peek(syn::Token![=]) && !meta.input.peek(syn::token::Paren) {
546 return Ok(());
547 }
548
549 Err(meta.error(format!(
550 "serde field attribute `{name}` does not accept a value"
551 )))
552}
553
554fn serde_feather_path() -> TokenStream2 {
555 match crate_name("serde-feather") {
556 Ok(FoundCrate::Itself) => quote!(crate),
557 Ok(FoundCrate::Name(name)) => {
558 let ident = Ident::new(&name.replace('-', "_"), Span::call_site());
559 quote!(::#ident)
560 }
561 Err(_) => quote!(::serde_feather),
562 }
563}