1use crate::attribute::{ContainerAttributes, FieldAttributes};
2use virtue::prelude::*;
3
/// Prefix handed to `to_token_tree_with_prefix` / `to_string_with_prefix` when the
/// generated `encode` match arms bind and then encode variant fields.
// NOTE(review): presumably only unnamed (tuple) fields actually receive the prefix,
// e.g. `field_0`; confirm against virtue's `IdentOrIndex` implementation.
const TUPLE_FIELD_PREFIX: &str = "field_";
5
/// Input for the enum derive: the enum's variants plus its parsed container attributes.
///
/// The `generate_*` methods consume this value and emit the `Encode`, `Decode` and
/// `BorrowDecode` implementations through a `Generator`.
pub(crate) struct DeriveEnum {
    /// Variants of the enum; their position in this list is used as the encoded index.
    pub variants: Vec<EnumVariant>,
    /// Container-level attributes (crate name, custom bounds, decode context) read by
    /// the generators.
    pub attributes: ContainerAttributes,
}
10
11impl DeriveEnum {
12 fn iter_fields(&self) -> EnumVariantIterator {
13 EnumVariantIterator {
14 idx: 0,
15 variants: &self.variants,
16 }
17 }
18
19 pub fn generate_encode(self, generator: &mut Generator) -> Result<()> {
20 let crate_name = self.attributes.crate_name.as_str();
21 generator
22 .impl_for(format!("{}::Encode", crate_name))
23 .modify_generic_constraints(|generics, where_constraints| {
24 if let Some((bounds, lit)) =
25 (self.attributes.encode_bounds.as_ref()).or(self.attributes.bounds.as_ref())
26 {
27 where_constraints.clear();
28 where_constraints
29 .push_parsed_constraint(bounds)
30 .map_err(|e| e.with_span(lit.span()))?;
31 } else {
32 for g in generics.iter_generics() {
33 where_constraints
34 .push_constraint(g, format!("{}::Encode", crate_name))
35 .unwrap();
36 }
37 }
38 Ok(())
39 })?
40 .generate_fn("encode")
41 .with_generic_deps("__E", [format!("{}::enc::Encoder", crate_name)])
42 .with_self_arg(FnSelfArg::RefSelf)
43 .with_arg("encoder", "&mut __E")
44 .with_return_type(format!(
45 "core::result::Result<(), {}::error::EncodeError>",
46 crate_name
47 ))
48 .body(|fn_body| {
49 fn_body.ident_str("match");
50 fn_body.ident_str("self");
51 fn_body.group(Delimiter::Brace, |match_body| {
52 if self.variants.is_empty() {
53 self.encode_empty_enum_case(match_body)?;
54 }
55 for (variant_index, variant) in self.iter_fields() {
56 match_body.ident_str("Self");
58 match_body.puncts("::");
59 match_body.ident(variant.name.clone());
60
61 if let Some(fields) = variant.fields.as_ref() {
64 let delimiter = fields.delimiter();
65 match_body.group(delimiter, |field_body| {
66 for (idx, field_name) in fields.names().into_iter().enumerate() {
67 if idx != 0 {
68 field_body.punct(',');
69 }
70 field_body.push(
71 field_name.to_token_tree_with_prefix(TUPLE_FIELD_PREFIX),
72 );
73 }
74 Ok(())
75 })?;
76 }
77
78 match_body.puncts("=>");
81
82 match_body.group(Delimiter::Brace, |body| {
91 body.push_parsed(format!("<u32 as {}::Encode>::encode", crate_name))?;
93 body.group(Delimiter::Parenthesis, |args| {
94 args.punct('&');
95 args.group(Delimiter::Parenthesis, |num| {
96 num.extend(variant_index);
97 Ok(())
98 })?;
99 args.punct(',');
100 args.push_parsed("encoder")?;
101 Ok(())
102 })?;
103 body.punct('?');
104 body.punct(';');
105 if let Some(fields) = variant.fields.as_ref() {
107 for field_name in fields.names() {
108 let attributes = field_name
109 .attributes()
110 .get_attribute::<FieldAttributes>()?
111 .unwrap_or_default();
112 if attributes.with_serde {
113 body.push_parsed(format!(
114 "{0}::Encode::encode(&{0}::serde::Compat({1}), encoder)?;",
115 crate_name,
116 field_name.to_string_with_prefix(TUPLE_FIELD_PREFIX),
117 ))?;
118 } else {
119 body.push_parsed(format!(
120 "{0}::Encode::encode({1}, encoder)?;",
121 crate_name,
122 field_name.to_string_with_prefix(TUPLE_FIELD_PREFIX),
123 ))?;
124 }
125 }
126 }
127 body.push_parsed("core::result::Result::Ok(())")?;
128 Ok(())
129 })?;
130 match_body.punct(',');
131 }
132 Ok(())
133 })?;
134 Ok(())
135 })?;
136 Ok(())
137 }
138
139 fn encode_empty_enum_case(&self, builder: &mut StreamBuilder) -> Result {
142 builder.push_parsed("_ => core::unreachable!()").map(|_| ())
143 }
144
145 fn invalid_variant_case(&self, enum_name: &str, result: &mut StreamBuilder) -> Result {
147 let crate_name = self.attributes.crate_name.as_str();
148
149 result.ident_str("variant");
164 result.puncts("=>");
165 result.push_parsed("core::result::Result::Err")?;
166 result.group(Delimiter::Parenthesis, |err_inner| {
167 err_inner.push_parsed(format!(
168 "{}::error::DecodeError::UnexpectedVariant",
169 crate_name
170 ))?;
171 err_inner.group(Delimiter::Brace, |variant_inner| {
172 variant_inner.ident_str("found");
173 variant_inner.punct(':');
174 variant_inner.ident_str("variant");
175 variant_inner.punct(',');
176
177 variant_inner.ident_str("type_name");
178 variant_inner.punct(':');
179 variant_inner.lit_str(enum_name);
180 variant_inner.punct(',');
181
182 variant_inner.ident_str("allowed");
183 variant_inner.punct(':');
184
185 if self.variants.iter().any(|i| i.value.is_some()) {
186 variant_inner.push_parsed(format!(
188 "&{}::error::AllowedEnumVariants::Allowed",
189 crate_name
190 ))?;
191 variant_inner.group(Delimiter::Parenthesis, |allowed_inner| {
192 allowed_inner.punct('&');
193 allowed_inner.group(Delimiter::Bracket, |allowed_slice| {
194 for (idx, (ident, _)) in self.iter_fields().enumerate() {
195 if idx != 0 {
196 allowed_slice.punct(',');
197 }
198 allowed_slice.extend(ident);
199 }
200 Ok(())
201 })?;
202 Ok(())
203 })?;
204 } else {
205 variant_inner.push_parsed(format!(
207 "&{0}::error::AllowedEnumVariants::Range {{ min: 0, max: {1} }}",
208 crate_name,
209 self.variants.len() - 1
210 ))?;
211 }
212 Ok(())
213 })?;
214 Ok(())
215 })?;
216 Ok(())
217 }
218
219 pub fn generate_decode(self, generator: &mut Generator) -> Result<()> {
220 let crate_name = self.attributes.crate_name.as_str();
221
222 let decode_context = if let Some((decode_context, _)) = &self.attributes.decode_context {
223 decode_context.as_str()
224 } else {
225 "__Context"
226 };
227 let enum_name = generator.target_name().to_string();
230
231 let mut impl_for = generator.impl_for(format!("{}::Decode", crate_name));
232
233 if self.attributes.decode_context.is_none() {
234 impl_for = impl_for.with_impl_generics(["__Context"]);
235 }
236
237 impl_for
238 .with_trait_generics([decode_context])
239 .modify_generic_constraints(|generics, where_constraints| {
240 if let Some((bounds, lit)) = (self.attributes.decode_bounds.as_ref()).or(self.attributes.bounds.as_ref()) {
241 where_constraints.clear();
242 where_constraints.push_parsed_constraint(bounds).map_err(|e| e.with_span(lit.span()))?;
243 } else {
244 for g in generics.iter_generics() {
245 where_constraints.push_constraint(g, format!("{}::Decode<__Context>", crate_name))?;
246 }
247 }
248 Ok(())
249 })?
250 .generate_fn("decode")
251 .with_generic_deps("__D", [format!("{}::de::Decoder<Context = {}>", crate_name, decode_context)])
252 .with_arg("decoder", "&mut __D")
253 .with_return_type(format!("core::result::Result<Self, {}::error::DecodeError>", crate_name))
254 .body(|fn_builder| {
255 if self.variants.is_empty() {
256 fn_builder.push_parsed(format!(
257 "core::result::Result::Err({}::error::DecodeError::EmptyEnum {{ type_name: core::any::type_name::<Self>() }})",
258 crate_name
259 ))?;
260 } else {
261 fn_builder
262 .push_parsed(format!(
263 "let variant_index = <u32 as {}::Decode::<__D::Context>>::decode(decoder)?;",
264 crate_name
265 ))?;
266 fn_builder.push_parsed("match variant_index")?;
267 fn_builder.group(Delimiter::Brace, |variant_case| {
268 for (mut variant_index, variant) in self.iter_fields() {
269 if variant_index.len() > 1 {
271 variant_case.push_parsed("x if x == ")?;
272 variant_case.extend(variant_index);
273 } else {
274 variant_case.push(variant_index.remove(0));
275 }
276 variant_case.puncts("=>");
277 variant_case.push_parsed("core::result::Result::Ok")?;
278 variant_case.group(Delimiter::Parenthesis, |variant_case_body| {
279 variant_case_body.ident_str("Self");
283 variant_case_body.puncts("::");
284 variant_case_body.ident(variant.name.clone());
285
286 variant_case_body.group(Delimiter::Brace, |variant_body| {
287 if let Some(fields) = variant.fields.as_ref() {
288 let is_tuple = matches!(fields, Fields::Tuple(_));
289 for (idx, field) in fields.names().into_iter().enumerate() {
290 if is_tuple {
291 variant_body.lit_usize(idx);
292 } else {
293 variant_body.ident(field.unwrap_ident().clone());
294 }
295 variant_body.punct(':');
296 let attributes = field.attributes().get_attribute::<FieldAttributes>()?.unwrap_or_default();
297 if attributes.with_serde {
298 variant_body
299 .push_parsed(format!(
300 "<{0}::serde::Compat<_> as {0}::Decode::<__D::Context>>::decode(decoder)?.0,",
301 crate_name
302 ))?;
303 } else {
304 variant_body
305 .push_parsed(format!(
306 "{}::Decode::<__D::Context>::decode(decoder)?,",
307 crate_name
308 ))?;
309 }
310 }
311 }
312 Ok(())
313 })?;
314 Ok(())
315 })?;
316 variant_case.punct(',');
317 }
318
319 self.invalid_variant_case(&enum_name, variant_case)
321 })?;
322 }
323 Ok(())
324 })?;
325 self.generate_borrow_decode(generator)?;
326 Ok(())
327 }
328
329 pub fn generate_borrow_decode(self, generator: &mut Generator) -> Result<()> {
330 let crate_name = &self.attributes.crate_name;
331
332 let decode_context = if let Some((decode_context, _)) = &self.attributes.decode_context {
333 decode_context.as_str()
334 } else {
335 "__Context"
336 };
337
338 let enum_name = generator.target_name().to_string();
340
341 let mut impl_for = generator
342 .impl_for_with_lifetimes(format!("{}::BorrowDecode", crate_name), ["__de"])
343 .with_trait_generics([decode_context]);
344 if self.attributes.decode_context.is_none() {
345 impl_for = impl_for.with_impl_generics(["__Context"]);
346 }
347
348 impl_for
349 .modify_generic_constraints(|generics, where_constraints| {
350 if let Some((bounds, lit)) = (self.attributes.borrow_decode_bounds.as_ref()).or(self.attributes.bounds.as_ref()) {
351 where_constraints.clear();
352 where_constraints.push_parsed_constraint(bounds).map_err(|e| e.with_span(lit.span()))?;
353 } else {
354 for g in generics.iter_generics() {
355 where_constraints.push_constraint(g, format!("{}::de::BorrowDecode<'__de, {}>", crate_name, decode_context)).unwrap();
356 }
357 for lt in generics.iter_lifetimes() {
358 where_constraints.push_parsed_constraint(format!("'__de: '{}", lt.ident))?;
359 }
360 }
361 Ok(())
362 })?
363 .generate_fn("borrow_decode")
364 .with_generic_deps("__D", [format!("{}::de::BorrowDecoder<'__de, Context = {}>", crate_name, decode_context)])
365 .with_arg("decoder", "&mut __D")
366 .with_return_type(format!("core::result::Result<Self, {}::error::DecodeError>", crate_name))
367 .body(|fn_builder| {
368 if self.variants.is_empty() {
369 fn_builder.push_parsed(format!(
370 "core::result::Result::Err({}::error::DecodeError::EmptyEnum {{ type_name: core::any::type_name::<Self>() }})",
371 crate_name
372 ))?;
373 } else {
374 fn_builder
375 .push_parsed(format!("let variant_index = <u32 as {}::Decode::<__D::Context>>::decode(decoder)?;", crate_name))?;
376 fn_builder.push_parsed("match variant_index")?;
377 fn_builder.group(Delimiter::Brace, |variant_case| {
378 for (mut variant_index, variant) in self.iter_fields() {
379 if variant_index.len() > 1 {
381 variant_case.push_parsed("x if x == ")?;
382 variant_case.extend(variant_index);
383 } else {
384 variant_case.push(variant_index.remove(0));
385 }
386 variant_case.puncts("=>");
387 variant_case.push_parsed("core::result::Result::Ok")?;
388 variant_case.group(Delimiter::Parenthesis, |variant_case_body| {
389 variant_case_body.ident_str("Self");
393 variant_case_body.puncts("::");
394 variant_case_body.ident(variant.name.clone());
395
396 variant_case_body.group(Delimiter::Brace, |variant_body| {
397 if let Some(fields) = variant.fields.as_ref() {
398 let is_tuple = matches!(fields, Fields::Tuple(_));
399 for (idx, field) in fields.names().into_iter().enumerate() {
400 if is_tuple {
401 variant_body.lit_usize(idx);
402 } else {
403 variant_body.ident(field.unwrap_ident().clone());
404 }
405 variant_body.punct(':');
406 let attributes = field.attributes().get_attribute::<FieldAttributes>()?.unwrap_or_default();
407 if attributes.with_serde {
408 variant_body
409 .push_parsed(format!("<{0}::serde::BorrowCompat<_> as {0}::BorrowDecode::<__D::Context>>::borrow_decode(decoder)?.0,", crate_name))?;
410 } else {
411 variant_body.push_parsed(format!("{}::BorrowDecode::<__D::Context>::borrow_decode(decoder)?,", crate_name))?;
412 }
413 }
414 }
415 Ok(())
416 })?;
417 Ok(())
418 })?;
419 variant_case.punct(',');
420 }
421
422 self.invalid_variant_case(&enum_name, variant_case)
424 })?;
425 }
426 Ok(())
427 })?;
428 Ok(())
429 }
430}
431
/// Iterator over a slice of enum variants that also yields each variant's index tokens.
struct EnumVariantIterator<'a> {
    /// The variants being iterated.
    variants: &'a [EnumVariant],
    /// Position of the next variant to yield; also the index emitted for it.
    idx: usize,
}
436
437impl<'a> Iterator for EnumVariantIterator<'a> {
438 type Item = (Vec<TokenTree>, &'a EnumVariant);
439
440 fn next(&mut self) -> Option<Self::Item> {
441 let idx = self.idx;
442 let variant = self.variants.get(self.idx)?;
443 self.idx += 1;
444
445 let tokens = vec![TokenTree::Literal(Literal::u32_suffixed(idx as u32))];
446
447 Some((tokens, variant))
448 }
449}