1use crate::util::{
2 fire_protobuf_crate, repr_as_i32, variants_no_fields, variants_with_fields
3};
4use crate::attr::FieldAttr;
5
6use std::iter;
7
8use proc_macro2::TokenStream;
9use syn::{
10 DeriveInput, Error, Ident, Generics, Data, DataStruct, DataEnum,
11 Fields, Attribute
12};
13use quote::quote;
14
15
16pub(crate) fn expand(input: DeriveInput) -> Result<TokenStream, Error> {
17 let DeriveInput { attrs, ident, generics, data, .. } = input;
18
19 match data {
20 Data::Struct(d) => expand_struct(ident, generics, d),
21 Data::Enum(e) => expand_enum(attrs, ident, generics, e),
22 Data::Union(_) => Err(Error::new(ident.span(), "union not supported"))
23 }
24}
25
26
27fn expand_struct(
28 ident: Ident,
29 generics: Generics,
30 d: DataStruct
31) -> Result<TokenStream, Error> {
32 let fields = match d.fields {
33 Fields::Named(f) => f.named,
34 _ => return Err(Error::new(ident.span(), "only named structs"))
35 };
36
37 let fields: Vec<_> = fields.into_iter()
39 .map(|f| Ok((FieldAttr::from_attrs(&f.attrs)?, f)))
40 .collect::<Result<_, Error>>()?;
41
42 let fire = fire_protobuf_crate()?;
43 let fire_encode = quote!(#fire::encode);
44
45 let wire_type = quote!(#fire::WireType::Len);
47 let wire_type_const = quote!(
48 const WIRE_TYPE: #fire::WireType = #wire_type;
49 );
50
51 let enctrait = quote!(#fire_encode::EncodeMessage);
52
53 let encoded_size_fields: Vec<_> = fields.iter()
54 .map(|(attr, f)| {
55 let id = &f.ident;
56 let fieldnum = &attr.fieldnum;
57 quote!(
58 if !#enctrait::is_default(&self.#id) {
59 #enctrait::encoded_size(
60 &mut self.#id,
61 Some(#fire_encode::FieldOpt::new(#fieldnum)),
62 &mut size
63 )?;
64 }
65 )
66 })
67 .collect();
68
69 let encoded_size = quote!(
70 fn encoded_size(
71 &mut self,
72 field: Option<#fire_encode::FieldOpt>,
73 builder: &mut #fire_encode::SizeBuilder
74 ) -> std::result::Result<(), #fire_encode::EncodeError> {
75 let mut size = #fire_encode::SizeBuilder::new();
76 #(
77 #encoded_size_fields
78 )*
79 let fields_size = size.finish();
80
81 if let Some(field) = field {
82 builder.write_tag(field.num, #wire_type);
83 builder.write_len(fields_size);
84 }
85
86 builder.write_bytes(fields_size);
87
88 Ok(())
89 }
90 );
91
92
93 let encode_fields: Vec<_> = fields.iter()
94 .map(|(attr, f)| {
95 let id = &f.ident;
96 let fieldnum = &attr.fieldnum;
97 quote!(
98 if !#enctrait::is_default(&self.#id) {
99 #enctrait::encode(
100 &mut self.#id,
101 Some(#fire_encode::FieldOpt::new(#fieldnum)),
102 encoder
103 )?;
104 }
105 )
106 })
107 .collect();
108
109
110 let encode = quote!(
111 fn encode<B>(
112 &mut self,
113 field: Option<#fire_encode::FieldOpt>,
114 encoder: &mut #fire_encode::MessageEncoder<B>
115 ) -> std::result::Result<(), #fire_encode::EncodeError>
116 where B: #fire::bytes::BytesWrite {
117 #[cfg(debug_assertions)]
118 let mut dbg_fields_size = None;
119
120 if let Some(field) = field {
123 encoder.write_tag(field.num, #wire_type)?;
124
125 let mut size = #fire_encode::SizeBuilder::new();
126 #(
127 #encoded_size_fields
128 )*
129 let fields_size = size.finish();
130
131 encoder.write_len(fields_size)?;
132
133 #[cfg(debug_assertions)]
134 {
135 dbg_fields_size = Some(fields_size);
136 }
137 }
138
139 #[cfg(debug_assertions)]
140 let prev_len = encoder.written_len();
141
142 #(
143 #encode_fields
144 )*
145
146 #[cfg(debug_assertions)]
147 if let Some(fields_size) = dbg_fields_size {
148 let added_len = encoder.written_len() - prev_len;
149 assert_eq!(fields_size, added_len as u64,
150 "encoded size does not match actual size");
151 }
152
153 Ok(())
154 }
155 );
156
157 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
158
159
160 Ok(quote!(
161 impl #impl_generics #enctrait for #ident #ty_generics #where_clause {
162 #wire_type_const
163 fn is_default(&self) -> bool { false }
164 #encoded_size
165 #encode
166 }
167 ))
168}
169
170fn expand_enum(
171 attrs: Vec<Attribute>,
172 ident: Ident,
173 generics: Generics,
174 d: DataEnum
175) -> Result<TokenStream, Error> {
176 let repr_as_i32 = repr_as_i32(attrs)?;
177
178 if repr_as_i32 {
179 expand_enum_no_fields(ident, generics, d)
180 } else {
181 expand_enum_with_fields(ident, generics, d)
182 }
183}
184
185fn expand_enum_no_fields(
186 ident: Ident,
187 generics: Generics,
188 d: DataEnum
189) -> Result<TokenStream, Error> {
190 let (variants, default_variant) = variants_no_fields(d.variants)?;
192
193 let fire = fire_protobuf_crate()?;
194 let fire_encode = quote!(#fire::encode);
195
196 let wire_type = quote!(#fire::WireType::Varint);
198 let wire_type_const = quote!(
199 const WIRE_TYPE: #fire::WireType = #wire_type;
200 );
201
202 let enctrait = quote!(#fire_encode::EncodeMessage);
203
204 let match_variants: Vec<_> = variants.iter()
205 .chain(iter::once(&default_variant))
206 .map(|(num, ident)| quote!(Self::#ident => #num))
207 .collect();
208
209 let default_ident = default_variant.1;
210
211 let is_default = quote!(
212 fn is_default(&self) -> bool {
213 matches!(self, Self::#default_ident)
214 }
215 );
216
217 let encoded_size = quote!(
218 fn encoded_size(
219 &mut self,
220 field: Option<#fire_encode::FieldOpt>,
221 builder: &mut #fire_encode::SizeBuilder
222 ) -> std::result::Result<(), #fire_encode::EncodeError> {
223 if let Some(field) = field {
224 builder.write_tag(field.num, #wire_type);
225 }
226
227 let varint: i32 = match self {
228 #(#match_variants),*
229 };
230
231 builder.write_varint(varint as u64);
232
233 Ok(())
234 }
235 );
236
237 let encode = quote!(
238 fn encode<B>(
239 &mut self,
240 field: Option<#fire_encode::FieldOpt>,
241 encoder: &mut #fire_encode::MessageEncoder<B>
242 ) -> std::result::Result<(), #fire_encode::EncodeError>
243 where B: #fire::bytes::BytesWrite {
244 if let Some(field) = field {
245 encoder.write_tag(field.num, #wire_type)?;
246 }
247
248 let varint: i32 = match self {
249 #(#match_variants),*
250 };
251
252 encoder.write_varint(varint as u64)
253 }
254 );
255
256 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
257
258 Ok(quote!(
259 impl #impl_generics #enctrait for #ident #ty_generics #where_clause {
260 #wire_type_const
261 #is_default
262 #encoded_size
263 #encode
264 }
265 ))
266}
267
268fn expand_enum_with_fields(
269 ident: Ident,
270 generics: Generics,
271 d: DataEnum
272) -> Result<TokenStream, Error> {
273 let variants = variants_with_fields(d.variants)?;
275
276 let fire = fire_protobuf_crate()?;
277 let fire_encode = quote!(#fire::encode);
278
279 let wire_type = quote!(#fire::WireType::Len);
281 let wire_type_const = quote!(
282 const WIRE_TYPE: #fire::WireType = #wire_type;
283 );
284
285 let enctrait = quote!(#fire_encode::EncodeMessage);
286
287 let encoded_size_variants: Vec<_> = variants.iter()
288 .map(|(attr, ident, field)| {
289 let fieldnum = &attr.fieldnum;
290
291 if let Some(_) = field {
292 quote!(
293 Self::#ident(v) => {
294 #enctrait::encoded_size(
295 v,
296 Some(#fire_encode::FieldOpt::new(#fieldnum)),
297 &mut size
298 )?
299 }
300 )
301 } else {
302 quote!(
303 Self::#ident => {
304 size.write_empty_field(#fieldnum)
305 }
306 )
307 }
308 })
309 .collect();
310
311 let encoded_size = quote!(
312 fn encoded_size(
313 &mut self,
314 field: Option<#fire_encode::FieldOpt>,
315 builder: &mut #fire_encode::SizeBuilder
316 ) -> std::result::Result<(), #fire_encode::EncodeError> {
317 let mut size = #fire_encode::SizeBuilder::new();
318 match self {
319 #(#encoded_size_variants),*
320 }
321 let size = size.finish();
322
323 if let Some(field) = field {
324 builder.write_tag(field.num, #wire_type);
325 builder.write_len(size);
326 }
327
328 builder.write_bytes(size);
329
330 Ok(())
331 }
332 );
333
334 let encode_variants: Vec<_> = variants.iter()
335 .map(|(attr, ident, field)| {
336 let fieldnum = &attr.fieldnum;
337
338 if let Some(_) = field {
339 quote!(
340 Self::#ident(v) => {
341 #enctrait::encode(
342 v,
343 Some(#fire_encode::FieldOpt::new(#fieldnum)),
344 encoder
345 )?
346 }
347 )
348 } else {
349 quote!(
350 Self::#ident => {
351 encoder.write_empty_field(#fieldnum)?
352 }
353 )
354 }
355 })
356 .collect();
357
358 let encode = quote!(
359 fn encode<B>(
360 &mut self,
361 field: Option<#fire_encode::FieldOpt>,
362 encoder: &mut #fire_encode::MessageEncoder<B>
363 ) -> std::result::Result<(), #fire_encode::EncodeError>
364 where B: #fire::bytes::BytesWrite {
365 #[cfg(debug_assertions)]
366 let mut dbg_fields_size = None;
367
368 if let Some(field) = field {
371 encoder.write_tag(field.num, #wire_type)?;
372
373 let mut size = #fire_encode::SizeBuilder::new();
374 match self {
375 #(
376 #encoded_size_variants
377 )*
378 }
379 let fields_size = size.finish();
380
381 encoder.write_len(fields_size)?;
382
383 #[cfg(debug_assertions)]
384 {
385 dbg_fields_size = Some(fields_size);
386 }
387 }
388
389 #[cfg(debug_assertions)]
390 let prev_len = encoder.written_len();
391
392 match self {
393 #(
394 #encode_variants
395 )*
396 }
397
398 #[cfg(debug_assertions)]
399 if let Some(fields_size) = dbg_fields_size {
400 let added_len = encoder.written_len() - prev_len;
401 assert_eq!(fields_size, added_len as u64,
402 "encoded size does not match actual size");
403 }
404
405 Ok(())
406 }
407 );
408
409 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
410
411 Ok(quote!(
412 impl #impl_generics #enctrait for #ident #ty_generics #where_clause {
413 #wire_type_const
414 fn is_default(&self) -> bool { false }
415 #encoded_size
416 #encode
417 }
418 ))
419}