bincode_derive/
derive_enum.rs

1use crate::attribute::{ContainerAttributes, FieldAttributes};
2use virtue::prelude::*;
3
/// Prefix handed to virtue's `to_token_tree_with_prefix` / `to_string_with_prefix`
/// when the generated `encode` match binds a variant's fields as locals.
///
/// NOTE(review): by its name this is presumably only applied to unnamed (tuple)
/// fields, turning `0` into a usable identifier such as `field_0` — confirm
/// against virtue's `IdentOrIndex` behaviour.
const TUPLE_FIELD_PREFIX: &str = "field_";
5
6pub(crate) struct DeriveEnum {
7    pub variants: Vec<EnumVariant>,
8    pub attributes: ContainerAttributes,
9}
10
11impl DeriveEnum {
12    fn iter_fields(&self) -> EnumVariantIterator {
13        EnumVariantIterator {
14            idx: 0,
15            variants: &self.variants,
16        }
17    }
18
19    pub fn generate_encode(self, generator: &mut Generator) -> Result<()> {
20        let crate_name = self.attributes.crate_name.as_str();
21        generator
22            .impl_for(format!("{}::Encode", crate_name))
23            .modify_generic_constraints(|generics, where_constraints| {
24                if let Some((bounds, lit)) =
25                    (self.attributes.encode_bounds.as_ref()).or(self.attributes.bounds.as_ref())
26                {
27                    where_constraints.clear();
28                    where_constraints
29                        .push_parsed_constraint(bounds)
30                        .map_err(|e| e.with_span(lit.span()))?;
31                } else {
32                    for g in generics.iter_generics() {
33                        where_constraints
34                            .push_constraint(g, format!("{}::Encode", crate_name))
35                            .unwrap();
36                    }
37                }
38                Ok(())
39            })?
40            .generate_fn("encode")
41            .with_generic_deps("__E", [format!("{}::enc::Encoder", crate_name)])
42            .with_self_arg(FnSelfArg::RefSelf)
43            .with_arg("encoder", "&mut __E")
44            .with_return_type(format!(
45                "core::result::Result<(), {}::error::EncodeError>",
46                crate_name
47            ))
48            .body(|fn_body| {
49                fn_body.ident_str("match");
50                fn_body.ident_str("self");
51                fn_body.group(Delimiter::Brace, |match_body| {
52                    if self.variants.is_empty() {
53                        self.encode_empty_enum_case(match_body)?;
54                    }
55                    for (variant_index, variant) in self.iter_fields() {
56                        // Self::Variant
57                        match_body.ident_str("Self");
58                        match_body.puncts("::");
59                        match_body.ident(variant.name.clone());
60
61                        // if we have any fields, declare them here
62                        // Self::Variant { a, b, c }
63                        if let Some(fields) = variant.fields.as_ref() {
64                            let delimiter = fields.delimiter();
65                            match_body.group(delimiter, |field_body| {
66                                for (idx, field_name) in fields.names().into_iter().enumerate() {
67                                    if idx != 0 {
68                                        field_body.punct(',');
69                                    }
70                                    field_body.push(
71                                        field_name.to_token_tree_with_prefix(TUPLE_FIELD_PREFIX),
72                                    );
73                                }
74                                Ok(())
75                            })?;
76                        }
77
78                        // Arrow
79                        // Self::Variant { a, b, c } =>
80                        match_body.puncts("=>");
81
82                        // Body of this variant
83                        // Note that the fields are available as locals because of the match destructuring above
84                        // {
85                        //      encoder.encode_u32(n)?;
86                        //      bincode::Encode::encode(a, encoder)?;
87                        //      bincode::Encode::encode(b, encoder)?;
88                        //      bincode::Encode::encode(c, encoder)?;
89                        // }
90                        match_body.group(Delimiter::Brace, |body| {
91                            // variant index
92                            body.push_parsed(format!("<u32 as {}::Encode>::encode", crate_name))?;
93                            body.group(Delimiter::Parenthesis, |args| {
94                                args.punct('&');
95                                args.group(Delimiter::Parenthesis, |num| {
96                                    num.extend(variant_index);
97                                    Ok(())
98                                })?;
99                                args.punct(',');
100                                args.push_parsed("encoder")?;
101                                Ok(())
102                            })?;
103                            body.punct('?');
104                            body.punct(';');
105                            // If we have any fields, encode them all one by one
106                            if let Some(fields) = variant.fields.as_ref() {
107                                for field_name in fields.names() {
108                                    let attributes = field_name
109                                        .attributes()
110                                        .get_attribute::<FieldAttributes>()?
111                                        .unwrap_or_default();
112                                    if attributes.with_serde {
113                                        body.push_parsed(format!(
114                                        "{0}::Encode::encode(&{0}::serde::Compat({1}), encoder)?;",
115                                        crate_name,
116                                        field_name.to_string_with_prefix(TUPLE_FIELD_PREFIX),
117                                    ))?;
118                                    } else {
119                                        body.push_parsed(format!(
120                                            "{0}::Encode::encode({1}, encoder)?;",
121                                            crate_name,
122                                            field_name.to_string_with_prefix(TUPLE_FIELD_PREFIX),
123                                        ))?;
124                                    }
125                                }
126                            }
127                            body.push_parsed("core::result::Result::Ok(())")?;
128                            Ok(())
129                        })?;
130                        match_body.punct(',');
131                    }
132                    Ok(())
133                })?;
134                Ok(())
135            })?;
136        Ok(())
137    }
138
139    /// If we're encoding an empty enum, we need to add an empty case in the form of:
140    /// `_ => core::unreachable!(),`
141    fn encode_empty_enum_case(&self, builder: &mut StreamBuilder) -> Result {
142        builder.push_parsed("_ => core::unreachable!()").map(|_| ())
143    }
144
145    /// Build the catch-all case for an int-to-enum decode implementation
146    fn invalid_variant_case(&self, enum_name: &str, result: &mut StreamBuilder) -> Result {
147        let crate_name = self.attributes.crate_name.as_str();
148
149        // we'll be generating:
150        // variant => Err(
151        //    bincode::error::DecodeError::UnexpectedVariant {
152        //        found: variant,
153        //        type_name: <enum_name>
154        //        allowed: ...,
155        //    }
156        // )
157        //
158        // Where allowed is either:
159        // - bincode::error::AllowedEnumVariants::Range { min: 0, max: <max> }
160        //   if we have no fixed value variants
161        // - bincode::error::AllowedEnumVariants::Allowed(&[<variant1>, <variant2>, ...])
162        //   if we have fixed value variants
163        result.ident_str("variant");
164        result.puncts("=>");
165        result.push_parsed("core::result::Result::Err")?;
166        result.group(Delimiter::Parenthesis, |err_inner| {
167            err_inner.push_parsed(format!(
168                "{}::error::DecodeError::UnexpectedVariant",
169                crate_name
170            ))?;
171            err_inner.group(Delimiter::Brace, |variant_inner| {
172                variant_inner.ident_str("found");
173                variant_inner.punct(':');
174                variant_inner.ident_str("variant");
175                variant_inner.punct(',');
176
177                variant_inner.ident_str("type_name");
178                variant_inner.punct(':');
179                variant_inner.lit_str(enum_name);
180                variant_inner.punct(',');
181
182                variant_inner.ident_str("allowed");
183                variant_inner.punct(':');
184
185                if self.variants.iter().any(|i| i.value.is_some()) {
186                    // we have fixed values, implement AllowedEnumVariants::Allowed
187                    variant_inner.push_parsed(format!(
188                        "&{}::error::AllowedEnumVariants::Allowed",
189                        crate_name
190                    ))?;
191                    variant_inner.group(Delimiter::Parenthesis, |allowed_inner| {
192                        allowed_inner.punct('&');
193                        allowed_inner.group(Delimiter::Bracket, |allowed_slice| {
194                            for (idx, (ident, _)) in self.iter_fields().enumerate() {
195                                if idx != 0 {
196                                    allowed_slice.punct(',');
197                                }
198                                allowed_slice.extend(ident);
199                            }
200                            Ok(())
201                        })?;
202                        Ok(())
203                    })?;
204                } else {
205                    // no fixed values, implement a range
206                    variant_inner.push_parsed(format!(
207                        "&{0}::error::AllowedEnumVariants::Range {{ min: 0, max: {1} }}",
208                        crate_name,
209                        self.variants.len() - 1
210                    ))?;
211                }
212                Ok(())
213            })?;
214            Ok(())
215        })?;
216        Ok(())
217    }
218
219    pub fn generate_decode(self, generator: &mut Generator) -> Result<()> {
220        let crate_name = self.attributes.crate_name.as_str();
221
222        let decode_context = if let Some((decode_context, _)) = &self.attributes.decode_context {
223            decode_context.as_str()
224        } else {
225            "__Context"
226        };
227        // Remember to keep this mostly in sync with generate_borrow_decode
228
229        let enum_name = generator.target_name().to_string();
230
231        let mut impl_for = generator.impl_for(format!("{}::Decode", crate_name));
232
233        if self.attributes.decode_context.is_none() {
234            impl_for = impl_for.with_impl_generics(["__Context"]);
235        }
236
237        impl_for
238            .with_trait_generics([decode_context])
239            .modify_generic_constraints(|generics, where_constraints| {
240                if let Some((bounds, lit)) = (self.attributes.decode_bounds.as_ref()).or(self.attributes.bounds.as_ref()) {
241                    where_constraints.clear();
242                    where_constraints.push_parsed_constraint(bounds).map_err(|e| e.with_span(lit.span()))?;
243                } else {
244                    for g in generics.iter_generics() {
245                        where_constraints.push_constraint(g, format!("{}::Decode<__Context>", crate_name))?;
246                    }
247                }
248                Ok(())
249            })?
250            .generate_fn("decode")
251            .with_generic_deps("__D", [format!("{}::de::Decoder<Context = {}>", crate_name, decode_context)])
252            .with_arg("decoder", "&mut __D")
253            .with_return_type(format!("core::result::Result<Self, {}::error::DecodeError>", crate_name))
254            .body(|fn_builder| {
255                if self.variants.is_empty() {
256                    fn_builder.push_parsed(format!(
257                        "core::result::Result::Err({}::error::DecodeError::EmptyEnum {{ type_name: core::any::type_name::<Self>() }})",
258                        crate_name
259                    ))?;
260                } else {
261                    fn_builder
262                        .push_parsed(format!(
263                            "let variant_index = <u32 as {}::Decode::<__D::Context>>::decode(decoder)?;",
264                            crate_name
265                        ))?;
266                    fn_builder.push_parsed("match variant_index")?;
267                    fn_builder.group(Delimiter::Brace, |variant_case| {
268                        for (mut variant_index, variant) in self.iter_fields() {
269                            // idx => Ok(..)
270                            if variant_index.len() > 1 {
271                                variant_case.push_parsed("x if x == ")?;
272                                variant_case.extend(variant_index);
273                            } else {
274                                variant_case.push(variant_index.remove(0));
275                            }
276                            variant_case.puncts("=>");
277                            variant_case.push_parsed("core::result::Result::Ok")?;
278                            variant_case.group(Delimiter::Parenthesis, |variant_case_body| {
279                                // Self::Variant { }
280                                // Self::Variant { 0: ..., 1: ... 2: ... },
281                                // Self::Variant { a: ..., b: ... c: ... },
282                                variant_case_body.ident_str("Self");
283                                variant_case_body.puncts("::");
284                                variant_case_body.ident(variant.name.clone());
285
286                                variant_case_body.group(Delimiter::Brace, |variant_body| {
287                                    if let Some(fields) = variant.fields.as_ref() {
288                                        let is_tuple = matches!(fields, Fields::Tuple(_));
289                                        for (idx, field) in fields.names().into_iter().enumerate() {
290                                            if is_tuple {
291                                                variant_body.lit_usize(idx);
292                                            } else {
293                                                variant_body.ident(field.unwrap_ident().clone());
294                                            }
295                                            variant_body.punct(':');
296                                            let attributes = field.attributes().get_attribute::<FieldAttributes>()?.unwrap_or_default();
297                                            if attributes.with_serde {
298                                                variant_body
299                                                    .push_parsed(format!(
300                                                        "<{0}::serde::Compat<_> as {0}::Decode::<__D::Context>>::decode(decoder)?.0,",
301                                                        crate_name
302                                                    ))?;
303                                            } else {
304                                                variant_body
305                                                    .push_parsed(format!(
306                                                        "{}::Decode::<__D::Context>::decode(decoder)?,",
307                                                        crate_name
308                                                    ))?;
309                                            }
310                                        }
311                                    }
312                                    Ok(())
313                                })?;
314                                Ok(())
315                            })?;
316                            variant_case.punct(',');
317                        }
318
319                        // invalid idx
320                        self.invalid_variant_case(&enum_name, variant_case)
321                    })?;
322                }
323                Ok(())
324            })?;
325        self.generate_borrow_decode(generator)?;
326        Ok(())
327    }
328
329    pub fn generate_borrow_decode(self, generator: &mut Generator) -> Result<()> {
330        let crate_name = &self.attributes.crate_name;
331
332        let decode_context = if let Some((decode_context, _)) = &self.attributes.decode_context {
333            decode_context.as_str()
334        } else {
335            "__Context"
336        };
337
338        // Remember to keep this mostly in sync with generate_decode
339        let enum_name = generator.target_name().to_string();
340
341        let mut impl_for = generator
342            .impl_for_with_lifetimes(format!("{}::BorrowDecode", crate_name), ["__de"])
343            .with_trait_generics([decode_context]);
344        if self.attributes.decode_context.is_none() {
345            impl_for = impl_for.with_impl_generics(["__Context"]);
346        }
347
348        impl_for
349            .modify_generic_constraints(|generics, where_constraints| {
350                if let Some((bounds, lit)) = (self.attributes.borrow_decode_bounds.as_ref()).or(self.attributes.bounds.as_ref()) {
351                    where_constraints.clear();
352                    where_constraints.push_parsed_constraint(bounds).map_err(|e| e.with_span(lit.span()))?;
353                } else {
354                    for g in generics.iter_generics() {
355                        where_constraints.push_constraint(g, format!("{}::de::BorrowDecode<'__de, {}>", crate_name, decode_context)).unwrap();
356                    }
357                    for lt in generics.iter_lifetimes() {
358                        where_constraints.push_parsed_constraint(format!("'__de: '{}", lt.ident))?;
359                    }
360                }
361                Ok(())
362            })?
363            .generate_fn("borrow_decode")
364            .with_generic_deps("__D", [format!("{}::de::BorrowDecoder<'__de, Context = {}>", crate_name, decode_context)])
365            .with_arg("decoder", "&mut __D")
366            .with_return_type(format!("core::result::Result<Self, {}::error::DecodeError>", crate_name))
367            .body(|fn_builder| {
368                if self.variants.is_empty() {
369                    fn_builder.push_parsed(format!(
370                        "core::result::Result::Err({}::error::DecodeError::EmptyEnum {{ type_name: core::any::type_name::<Self>() }})",
371                        crate_name
372                    ))?;
373                } else {
374                    fn_builder
375                        .push_parsed(format!("let variant_index = <u32 as {}::Decode::<__D::Context>>::decode(decoder)?;", crate_name))?;
376                    fn_builder.push_parsed("match variant_index")?;
377                    fn_builder.group(Delimiter::Brace, |variant_case| {
378                        for (mut variant_index, variant) in self.iter_fields() {
379                            // idx => Ok(..)
380                            if variant_index.len() > 1 {
381                                variant_case.push_parsed("x if x == ")?;
382                                variant_case.extend(variant_index);
383                            } else {
384                                variant_case.push(variant_index.remove(0));
385                            }
386                            variant_case.puncts("=>");
387                            variant_case.push_parsed("core::result::Result::Ok")?;
388                            variant_case.group(Delimiter::Parenthesis, |variant_case_body| {
389                                // Self::Variant { }
390                                // Self::Variant { 0: ..., 1: ... 2: ... },
391                                // Self::Variant { a: ..., b: ... c: ... },
392                                variant_case_body.ident_str("Self");
393                                variant_case_body.puncts("::");
394                                variant_case_body.ident(variant.name.clone());
395
396                                variant_case_body.group(Delimiter::Brace, |variant_body| {
397                                    if let Some(fields) = variant.fields.as_ref() {
398                                        let is_tuple = matches!(fields, Fields::Tuple(_));
399                                        for (idx, field) in fields.names().into_iter().enumerate() {
400                                            if is_tuple {
401                                                variant_body.lit_usize(idx);
402                                            } else {
403                                                variant_body.ident(field.unwrap_ident().clone());
404                                            }
405                                            variant_body.punct(':');
406                                            let attributes = field.attributes().get_attribute::<FieldAttributes>()?.unwrap_or_default();
407                                            if attributes.with_serde {
408                                                variant_body
409                                                    .push_parsed(format!("<{0}::serde::BorrowCompat<_> as {0}::BorrowDecode::<__D::Context>>::borrow_decode(decoder)?.0,", crate_name))?;
410                                            } else {
411                                                variant_body.push_parsed(format!("{}::BorrowDecode::<__D::Context>::borrow_decode(decoder)?,", crate_name))?;
412                                            }
413                                        }
414                                    }
415                                    Ok(())
416                                })?;
417                                Ok(())
418                            })?;
419                            variant_case.punct(',');
420                        }
421
422                        // invalid idx
423                        self.invalid_variant_case(&enum_name, variant_case)
424                    })?;
425                }
426                Ok(())
427            })?;
428        Ok(())
429    }
430}
431
432struct EnumVariantIterator<'a> {
433    variants: &'a [EnumVariant],
434    idx: usize,
435}
436
437impl<'a> Iterator for EnumVariantIterator<'a> {
438    type Item = (Vec<TokenTree>, &'a EnumVariant);
439
440    fn next(&mut self) -> Option<Self::Item> {
441        let idx = self.idx;
442        let variant = self.variants.get(self.idx)?;
443        self.idx += 1;
444
445        let tokens = vec![TokenTree::Literal(Literal::u32_suffixed(idx as u32))];
446
447        Some((tokens, variant))
448    }
449}