fire_protobuf_codegen/
encode.rs

1use crate::util::{
2	fire_protobuf_crate, repr_as_i32, variants_no_fields, variants_with_fields
3};
4use crate::attr::FieldAttr;
5
6use std::iter;
7
8use proc_macro2::TokenStream;
9use syn::{
10	DeriveInput, Error, Ident, Generics, Data, DataStruct, DataEnum,
11	Fields, Attribute
12};
13use quote::quote;
14
15
16pub(crate) fn expand(input: DeriveInput) -> Result<TokenStream, Error> {
17	let DeriveInput { attrs, ident, generics, data, .. } = input;
18
19	match data {
20		Data::Struct(d) => expand_struct(ident, generics, d),
21		Data::Enum(e) => expand_enum(attrs, ident, generics, e),
22		Data::Union(_) => Err(Error::new(ident.span(), "union not supported"))
23	}
24}
25
26
27fn expand_struct(
28	ident: Ident,
29	generics: Generics,
30	d: DataStruct
31) -> Result<TokenStream, Error> {
32	let fields = match d.fields {
33		Fields::Named(f) => f.named,
34		_ => return Err(Error::new(ident.span(), "only named structs"))
35	};
36
37	// parse fields
38	let fields: Vec<_> = fields.into_iter()
39		.map(|f| Ok((FieldAttr::from_attrs(&f.attrs)?, f)))
40		.collect::<Result<_, Error>>()?;
41
42	let fire = fire_protobuf_crate()?;
43	let fire_encode = quote!(#fire::encode);
44
45	// the wire type for structs is always len
46	let wire_type = quote!(#fire::WireType::Len);
47	let wire_type_const = quote!(
48		const WIRE_TYPE: #fire::WireType = #wire_type;
49	);
50
51	let enctrait = quote!(#fire_encode::EncodeMessage);
52
53	let encoded_size_fields: Vec<_> = fields.iter()
54		.map(|(attr, f)| {
55			let id = &f.ident;
56			let fieldnum = &attr.fieldnum;
57			quote!(
58				if !#enctrait::is_default(&self.#id) {
59					#enctrait::encoded_size(
60						&mut self.#id,
61						Some(#fire_encode::FieldOpt::new(#fieldnum)),
62						&mut size
63					)?;
64				}
65			)
66		})
67		.collect();
68
69	let encoded_size = quote!(
70		fn encoded_size(
71			&mut self,
72			field: Option<#fire_encode::FieldOpt>,
73			builder: &mut #fire_encode::SizeBuilder
74		) -> std::result::Result<(), #fire_encode::EncodeError> {
75			let mut size = #fire_encode::SizeBuilder::new();
76			#(
77				#encoded_size_fields
78			)*
79			let fields_size = size.finish();
80
81			if let Some(field) = field {
82				builder.write_tag(field.num, #wire_type);
83				builder.write_len(fields_size);
84			}
85
86			builder.write_bytes(fields_size);
87
88			Ok(())
89		}
90	);
91
92
93	let encode_fields: Vec<_> = fields.iter()
94		.map(|(attr, f)| {
95			let id = &f.ident;
96			let fieldnum = &attr.fieldnum;
97			quote!(
98				if !#enctrait::is_default(&self.#id) {
99					#enctrait::encode(
100						&mut self.#id,
101						Some(#fire_encode::FieldOpt::new(#fieldnum)),
102						encoder
103					)?;
104				}
105			)
106		})
107		.collect();
108
109
110	let encode = quote!(
111		fn encode<B>(
112			&mut self,
113			field: Option<#fire_encode::FieldOpt>,
114			encoder: &mut #fire_encode::MessageEncoder<B>
115		) -> std::result::Result<(), #fire_encode::EncodeError>
116		where B: #fire::bytes::BytesWrite {
117			#[cfg(debug_assertions)]
118			let mut dbg_fields_size = None;
119
120			// we don't need to get the size if we don't need to write
121			// the size
122			if let Some(field) = field {
123				encoder.write_tag(field.num, #wire_type)?;
124
125				let mut size = #fire_encode::SizeBuilder::new();
126				#(
127					#encoded_size_fields
128				)*
129				let fields_size = size.finish();
130
131				encoder.write_len(fields_size)?;
132
133				#[cfg(debug_assertions)]
134				{
135					dbg_fields_size = Some(fields_size);
136				}
137			}
138
139			#[cfg(debug_assertions)]
140			let prev_len = encoder.written_len();
141
142			#(
143				#encode_fields
144			)*
145
146			#[cfg(debug_assertions)]
147			if let Some(fields_size) = dbg_fields_size {
148				let added_len = encoder.written_len() - prev_len;
149				assert_eq!(fields_size, added_len as u64,
150					"encoded size does not match actual size");
151			}
152
153			Ok(())
154		}
155	);
156
157	let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
158
159
160	Ok(quote!(
161		impl #impl_generics #enctrait for #ident #ty_generics #where_clause {
162			#wire_type_const
163			fn is_default(&self) -> bool { false }
164			#encoded_size
165			#encode
166		}
167	))
168}
169
170fn expand_enum(
171	attrs: Vec<Attribute>,
172	ident: Ident,
173	generics: Generics,
174	d: DataEnum
175) -> Result<TokenStream, Error> {
176	let repr_as_i32 = repr_as_i32(attrs)?;
177
178	if repr_as_i32 {
179		expand_enum_no_fields(ident, generics, d)
180	} else {
181		expand_enum_with_fields(ident, generics, d)
182	}
183}
184
185fn expand_enum_no_fields(
186	ident: Ident,
187	generics: Generics,
188	d: DataEnum
189) -> Result<TokenStream, Error> {
190	// (fieldnum, ident)
191	let (variants, default_variant) = variants_no_fields(d.variants)?;
192
193	let fire = fire_protobuf_crate()?;
194	let fire_encode = quote!(#fire::encode);
195
196	// the wire type for structs is always len
197	let wire_type = quote!(#fire::WireType::Varint);
198	let wire_type_const = quote!(
199		const WIRE_TYPE: #fire::WireType = #wire_type;
200	);
201
202	let enctrait = quote!(#fire_encode::EncodeMessage);
203
204	let match_variants: Vec<_> = variants.iter()
205		.chain(iter::once(&default_variant))
206		.map(|(num, ident)| quote!(Self::#ident => #num))
207		.collect();
208
209	let default_ident = default_variant.1;
210
211	let is_default = quote!(
212		fn is_default(&self) -> bool {
213			matches!(self, Self::#default_ident)
214		}
215	);
216
217	let encoded_size = quote!(
218		fn encoded_size(
219			&mut self,
220			field: Option<#fire_encode::FieldOpt>,
221			builder: &mut #fire_encode::SizeBuilder
222		) -> std::result::Result<(), #fire_encode::EncodeError> {
223			if let Some(field) = field {
224				builder.write_tag(field.num, #wire_type);
225			}
226
227			let varint: i32 = match self {
228				#(#match_variants),*
229			};
230
231			builder.write_varint(varint as u64);
232
233			Ok(())
234		}
235	);
236
237	let encode = quote!(
238		fn encode<B>(
239			&mut self,
240			field: Option<#fire_encode::FieldOpt>,
241			encoder: &mut #fire_encode::MessageEncoder<B>
242		) -> std::result::Result<(), #fire_encode::EncodeError>
243		where B: #fire::bytes::BytesWrite {
244			if let Some(field) = field {
245				encoder.write_tag(field.num, #wire_type)?;
246			}
247
248			let varint: i32 = match self {
249				#(#match_variants),*
250			};
251
252			encoder.write_varint(varint as u64)
253		}
254	);
255
256	let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
257
258	Ok(quote!(
259		impl #impl_generics #enctrait for #ident #ty_generics #where_clause {
260			#wire_type_const
261			#is_default
262			#encoded_size
263			#encode
264		}
265	))
266}
267
268fn expand_enum_with_fields(
269	ident: Ident,
270	generics: Generics,
271	d: DataEnum
272) -> Result<TokenStream, Error> {
273	// (FieldAttr, ident, Option<field>)
274	let variants = variants_with_fields(d.variants)?;
275
276	let fire = fire_protobuf_crate()?;
277	let fire_encode = quote!(#fire::encode);
278
279	// the wire type for structs is always len
280	let wire_type = quote!(#fire::WireType::Len);
281	let wire_type_const = quote!(
282		const WIRE_TYPE: #fire::WireType = #wire_type;
283	);
284
285	let enctrait = quote!(#fire_encode::EncodeMessage);
286
287	let encoded_size_variants: Vec<_> = variants.iter()
288		.map(|(attr, ident, field)| {
289			let fieldnum = &attr.fieldnum;
290
291			if let Some(_) = field {
292				quote!(
293					Self::#ident(v) => {
294						#enctrait::encoded_size(
295							v,
296							Some(#fire_encode::FieldOpt::new(#fieldnum)),
297							&mut size
298						)?
299					}
300				)
301			} else {
302				quote!(
303					Self::#ident => {
304						size.write_empty_field(#fieldnum)
305					}
306				)
307			}
308		})
309		.collect();
310
311	let encoded_size = quote!(
312		fn encoded_size(
313			&mut self,
314			field: Option<#fire_encode::FieldOpt>,
315			builder: &mut #fire_encode::SizeBuilder
316		) -> std::result::Result<(), #fire_encode::EncodeError> {
317			let mut size = #fire_encode::SizeBuilder::new();
318			match self {
319				#(#encoded_size_variants),*
320			}
321			let size = size.finish();
322
323			if let Some(field) = field {
324				builder.write_tag(field.num, #wire_type);
325				builder.write_len(size);
326			}
327
328			builder.write_bytes(size);
329
330			Ok(())
331		}
332	);
333
334	let encode_variants: Vec<_> = variants.iter()
335		.map(|(attr, ident, field)| {
336			let fieldnum = &attr.fieldnum;
337
338			if let Some(_) = field {
339				quote!(
340					Self::#ident(v) => {
341						#enctrait::encode(
342							v,
343							Some(#fire_encode::FieldOpt::new(#fieldnum)),
344							encoder
345						)?
346					}
347				)
348			} else {
349				quote!(
350					Self::#ident => {
351						encoder.write_empty_field(#fieldnum)?
352					}
353				)
354			}
355		})
356		.collect();
357
358	let encode = quote!(
359		fn encode<B>(
360			&mut self,
361			field: Option<#fire_encode::FieldOpt>,
362			encoder: &mut #fire_encode::MessageEncoder<B>
363		) -> std::result::Result<(), #fire_encode::EncodeError>
364		where B: #fire::bytes::BytesWrite {
365			#[cfg(debug_assertions)]
366			let mut dbg_fields_size = None;
367
368			/// we don't need to get the size if we don't need to write
369			/// the size
370			if let Some(field) = field {
371				encoder.write_tag(field.num, #wire_type)?;
372
373				let mut size = #fire_encode::SizeBuilder::new();
374				match self {
375					#(
376						#encoded_size_variants
377					)*
378				}
379				let fields_size = size.finish();
380
381				encoder.write_len(fields_size)?;
382
383				#[cfg(debug_assertions)]
384				{
385					dbg_fields_size = Some(fields_size);
386				}
387			}
388
389			#[cfg(debug_assertions)]
390			let prev_len = encoder.written_len();
391
392			match self {
393				#(
394					#encode_variants
395				)*
396			}
397
398			#[cfg(debug_assertions)]
399			if let Some(fields_size) = dbg_fields_size {
400				let added_len = encoder.written_len() - prev_len;
401				assert_eq!(fields_size, added_len as u64,
402					"encoded size does not match actual size");
403			}
404
405			Ok(())
406		}
407	);
408
409	let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
410
411	Ok(quote!(
412		impl #impl_generics #enctrait for #ident #ty_generics #where_clause {
413			#wire_type_const
414			fn is_default(&self) -> bool { false }
415			#encoded_size
416			#encode
417		}
418	))
419}