1use darling::ast::Fields;
2use proc_macro2::TokenStream;
3use quote::{quote, ToTokens};
4
5use crate::{FieldInfo, FieldsCollector, FieldsHelper};
6
/// Implements `VariantInfo<$field>` for the variant receiver type `$variant`.
///
/// The generated impl reads three fields of the receiver by name:
/// - `ident: syn::Ident`
/// - `discriminant: Option<syn::Expr>`
/// - `fields: darling::ast::Fields<$field>`
///
/// so the receiver (e.g. a struct deriving `darling::FromVariant`) must
/// declare them with exactly those names and types.
///
/// The expansion refers to `syn` and `darling` by their plain crate names,
/// so both must be resolvable where the macro is invoked.
#[macro_export]
macro_rules! variant_info {
    ($variant:path, $field:path) => {
        impl $crate::VariantInfo<$field> for $variant {
            fn ident(&self) -> &syn::Ident {
                &self.ident
            }

            fn discriminant(&self) -> &Option<syn::Expr> {
                &self.discriminant
            }

            fn fields(&self) -> &darling::ast::Fields<$field> {
                &self.fields
            }
        }
    };
}
26
27pub trait VariantInfo<F: FieldInfo> {
29 fn ident(&self) -> &syn::Ident;
31 fn discriminant(&self) -> &Option<syn::Expr>;
33 fn fields(&self) -> &Fields<F>;
35 fn discriminant_expr(&self) -> TokenStream {
37 self.discriminant()
38 .as_ref()
39 .map(|d| {
40 let expr = d.to_token_stream();
41 quote!(= #expr)
42 })
43 .unwrap_or_default()
44 }
45}
46
47pub struct VariantsHelper<'v, V: VariantInfo<F>, F: FieldInfo> {
49 variants: Vec<&'v V>,
50 variant_filter: Option<Box<dyn Fn(&V) -> bool + 'v>>,
51 variant_attributes: Option<Box<dyn Fn(&V) -> Option<TokenStream> + 'v>>,
52 include_extra_variants: Vec<(TokenStream, Option<TokenStream>)>,
53 ignore_all_extra_variants: Option<TokenStream>,
54 include_wrapper: bool,
55 left_collector: Option<Box<dyn Fn(&V, FieldsHelper<'_, F>) -> TokenStream + 'v>>,
56 right_collector: Option<Box<dyn Fn(&V, FieldsHelper<'_, F>) -> TokenStream + 'v>>,
57}
58
59impl<'v, V: VariantInfo<F>, F: FieldInfo> VariantsHelper<'v, V, F> {
60 pub fn new(variants: &'v [V]) -> Self {
62 Self {
63 variants: variants.iter().collect(),
64 variant_filter: None,
65 variant_attributes: None,
66 include_extra_variants: Vec::new(),
67 ignore_all_extra_variants: None,
68 include_wrapper: true,
69 left_collector: None,
70 right_collector: None,
71 }
72 }
73
74 pub fn filtering_variants<P>(mut self, predicate: P) -> Self
78 where
79 P: Fn(&V) -> bool + 'v,
80 {
81 self.variant_filter = Some(Box::new(predicate));
82 self
83 }
84
85 pub fn with_variant_attributes<P>(mut self, predicate: P) -> Self
87 where
88 P: Fn(&V) -> Option<TokenStream> + 'v,
89 {
90 self.variant_attributes = Some(Box::new(predicate));
91 self
92 }
93
94 pub fn include_extra_variants(
97 mut self,
98 include_extra_variants: impl IntoIterator<Item = (impl ToTokens, Option<impl ToTokens>)>,
99 ) -> Self {
100 let mut include_extra_variants = include_extra_variants
101 .into_iter()
102 .map(|(l, r)| (l.to_token_stream(), r.map(|t| t.to_token_stream())))
103 .collect::<Vec<_>>();
104 self.include_extra_variants.append(&mut include_extra_variants);
105 self
106 }
107
108 pub fn ignore_all_extra_variants(mut self, right_side: Option<TokenStream>) -> Self {
112 self.ignore_all_extra_variants = right_side;
113 self
114 }
115
116 pub fn include_wrapper(mut self, include_wrapper: bool) -> Self {
118 self.include_wrapper = include_wrapper;
119 self
120 }
121
122 pub fn left_collector<C>(mut self, left_collector: C) -> Self
126 where
127 C: Fn(&V, FieldsHelper<'_, F>) -> TokenStream + 'v,
128 {
129 self.left_collector = Some(Box::new(left_collector));
130 self
131 }
132
133 pub fn right_collector<C>(mut self, right_collector: C) -> Self
137 where
138 C: Fn(&V, FieldsHelper<'_, F>) -> TokenStream + 'v,
139 {
140 self.right_collector = Some(Box::new(right_collector));
141 self
142 }
143
144 pub fn collect(self) -> TokenStream {
175 let left_collector = self
176 .left_collector
177 .unwrap_or_else(|| Box::new(VariantsCollector::empty));
178
179 let mut variants = self
180 .variants
181 .into_iter()
182 .filter(|&v| {
183 if let Some(variant_filter_fn) = &self.variant_filter {
184 variant_filter_fn(v)
185 } else {
186 true
187 }
188 })
189 .map(|v| {
190 let attrs = if let Some(attrs_fn) = &self.variant_attributes {
191 attrs_fn(v)
192 } else {
193 None
194 }
195 .unwrap_or_default();
196
197 let left = left_collector(v, FieldsHelper::new(v.fields()));
198 let right = if let Some(right_collector) = &self.right_collector {
199 let right = right_collector(v, FieldsHelper::new(v.fields()));
200 quote!(=> #right)
201 } else {
202 TokenStream::default()
203 };
204
205 quote!(
206 #attrs
207 #left #right
208 )
209 })
210 .collect::<Vec<_>>();
211
212 for (left, right) in self.include_extra_variants {
213 let right = right.map(|r| quote!(=> #r)).unwrap_or_default();
214 variants.push(quote!(#left #right));
215 }
216
217 if let Some(right) = self.ignore_all_extra_variants {
218 variants.push(quote!(_ => #right));
219 }
220
221 if self.include_wrapper {
222 quote!(
223 {
224 #( #variants ),*
225 }
226 )
227 } else {
228 quote!( #( #variants ),* )
229 }
230 }
231}
232
233pub struct VariantsCollector;
235impl VariantsCollector {
236 pub fn empty<V, F>(_v: &V, _f: FieldsHelper<'_, F>) -> TokenStream
238 where
239 V: VariantInfo<F>,
240 F: FieldInfo,
241 {
242 TokenStream::default()
243 }
244
245 pub fn variant_definition<V, F>(v: &V, fields: FieldsHelper<'_, F>) -> TokenStream
259 where
260 V: VariantInfo<F>,
261 F: FieldInfo,
262 {
263 let ident = v.ident();
264 let dis = v.discriminant_expr();
265 let fields_expr = fields.collect();
266 quote!(
267 #ident #dis #fields_expr
268 )
269 }
270
271 pub fn variant_fields_collector<V, F>(ty: impl ToTokens) -> impl Fn(&V, FieldsHelper<'_, F>) -> TokenStream
290 where
291 V: VariantInfo<F>,
292 F: FieldInfo,
293 {
294 move |v, fields| {
295 let ident = v.ident();
296 let right = fields.right_collector(FieldsCollector::ident).collect();
297 quote!( #ty::#ident #right )
298 }
299 }
300}
301
#[cfg(test)]
mod tests {
    // NOTE(review): the reason for this allow is not visible in the file;
    // confirm whether it is still needed.
    #![allow(clippy::manual_unwrap_or_default)]

    use darling::{FromField, FromVariant};
    use quote::quote;
    use syn::Result;

    use super::*;
    use crate::field_info;

    /// Field receiver for the test enum; `#[tst(skip)]` sets `skip`.
    #[derive(FromField, Clone)]
    #[darling(attributes(tst))]
    struct FieldReceiver {
        /// Field name; `None` for tuple fields.
        ident: Option<syn::Ident>,
        /// Field visibility.
        vis: syn::Visibility,
        /// Field type.
        ty: syn::Type,

        /// Set by `#[tst(skip)]`; defaults to `false`.
        #[darling(default)]
        pub skip: bool,
    }
    field_info!(FieldReceiver);

    /// Variant receiver for the test enum; `#[tst(skip)]` sets `skip`.
    #[derive(FromVariant, Clone)]
    #[darling(attributes(tst))]
    struct VariantReceiver {
        /// Variant name.
        ident: syn::Ident,
        /// Explicit discriminant, if any.
        discriminant: Option<syn::Expr>,
        /// Parsed fields.
        fields: Fields<FieldReceiver>,

        /// Set by `#[tst(skip)]`; defaults to `false`.
        #[darling(default)]
        skip: bool,
    }
    variant_info!(VariantReceiver, FieldReceiver);

    /// Parses the enum shared by every scenario below into variant receivers.
    fn sample_variants() -> Result<Vec<VariantReceiver>> {
        let input: syn::DeriveInput = syn::parse2(quote! {
            pub enum MyEnum {
                Variant1,
                #[tst(skip)]
                Variant2,
                Variant3 (String, i64),
                Variant4 {
                    field_1: String,
                    #[tst(skip)]
                    field_2: i32,
                    field_3: bool,
                }
            }
        })?;
        let variants = darling::ast::Data::<VariantReceiver, FieldReceiver>::try_from(&input.data)?
            .take_enum()
            .unwrap();
        Ok(variants)
    }

    /// Compares two token streams by their string rendering.
    fn assert_tokens_eq(actual: TokenStream, expected: TokenStream) {
        assert_eq!(actual.to_string(), expected.to_string());
    }

    #[test]
    fn test_variant_helper() -> Result<()> {
        let variants = sample_variants()?;

        // Skip-marked variants and fields are dropped from the definition.
        let actual = VariantsHelper::new(&variants)
            .filtering_variants(|v| !v.skip)
            .left_collector(|v, fields| {
                let ident = &v.ident;
                let dis = v.discriminant_expr();
                let fields_expr = fields.filtering(|_ix, f| !f.skip).collect();
                quote!(#ident #dis #fields_expr)
            })
            .collect();
        #[rustfmt::skip]
        let expected = quote!({
            Variant1,
            Variant3 (String, i64),
            Variant4 {
                field_1: String,
                field_3: bool
            }
        });
        assert_tokens_eq(actual, expected);

        // Skip-marked variants are kept but tagged; skipped fields are filled
        // with `..Default::default()`.
        let actual = VariantsHelper::new(&variants)
            .with_variant_attributes(|v| v.skip.then(|| quote!(#[skipped])))
            .left_collector(|v, fields| {
                let ident = &v.ident;
                let dis = v.discriminant_expr();
                let fields_expr = fields
                    .filtering(|_ix, f| !f.skip)
                    .include_all_default(true)
                    .collect();
                quote!(#ident #dis #fields_expr)
            })
            .collect();
        #[rustfmt::skip]
        let expected = quote!({
            Variant1,
            #[skipped]
            Variant2,
            Variant3 (String, i64, .. Default::default()),
            Variant4 {
                field_1: String,
                field_3: bool,
                .. Default::default()
            }
        });
        assert_tokens_eq(actual, expected);

        // Skip-marked fields are kept but tagged, and each field list ends
        // with `..`.
        let actual = VariantsHelper::new(&variants)
            .filtering_variants(|v| !v.skip)
            .left_collector(|v, fields| {
                let ident = &v.ident;
                let dis = v.discriminant_expr();
                let fields_expr = fields
                    .with_attributes(|_ix, f| f.skip.then(|| quote!(#[skipped])))
                    .ignore_all_extra(true)
                    .collect();
                quote!(#ident #dis #fields_expr)
            })
            .collect();
        #[rustfmt::skip]
        let expected = quote!({
            Variant1,
            Variant3 (String, i64, ..),
            Variant4 {
                field_1: String,
                #[skipped]
                field_2: i32,
                field_3: bool,
                ..
            }
        });
        assert_tokens_eq(actual, expected);

        // Match arms from MyEnum1 to MyEnum2, binding only non-skipped fields.
        let actual = VariantsHelper::new(&variants)
            .filtering_variants(|v| !v.skip)
            .left_collector(|v, fields| {
                let ident = &v.ident;
                let bindings = fields
                    .filtering(|_ix, f| !f.skip)
                    .right_collector(FieldsCollector::ident)
                    .collect();
                quote!( MyEnum1::#ident #bindings )
            })
            .right_collector(|v, fields| {
                let ident = &v.ident;
                let bindings = fields
                    .filtering(|_ix, f| !f.skip)
                    .right_collector(FieldsCollector::ident)
                    .collect();
                quote!( MyEnum2::#ident #bindings )
            })
            .collect();
        #[rustfmt::skip]
        let expected = quote!({
            MyEnum1::Variant1 => MyEnum2::Variant1,
            MyEnum1::Variant3 (v_0, v_1) => MyEnum2::Variant3 (v_0, v_1),
            MyEnum1::Variant4 {
                field_1: field_1,
                field_3: field_3
            } => MyEnum2::Variant4 {
                field_1: field_1,
                field_3: field_3
            }
        });
        assert_tokens_eq(actual, expected);

        // All variants. The pattern side ends with `..` and the constructor
        // side ends with `..Default::default()`.
        let actual = VariantsHelper::new(&variants)
            .left_collector(|v, fields| {
                let ident = &v.ident;
                let bindings = fields
                    .ignore_all_extra(true)
                    .right_collector(FieldsCollector::ident)
                    .collect();
                quote!( MyEnum1::#ident #bindings )
            })
            .right_collector(|v, fields| {
                let ident = &v.ident;
                let bindings = fields
                    .include_all_default(true)
                    .right_collector(FieldsCollector::ident)
                    .collect();
                quote!( MyEnum2::#ident #bindings )
            })
            .collect();
        #[rustfmt::skip]
        let expected = quote!({
            MyEnum1::Variant1 => MyEnum2::Variant1,
            MyEnum1::Variant2 => MyEnum2::Variant2,
            MyEnum1::Variant3 (v_0, v_1, ..) => MyEnum2::Variant3 (v_0, v_1, ..Default::default()),
            MyEnum1::Variant4 {
                field_1: field_1,
                field_2: field_2,
                field_3: field_3,
                ..
            } => MyEnum2::Variant4 {
                field_1: field_1,
                field_2: field_2,
                field_3: field_3,
                ..Default::default()
            }
        });
        assert_tokens_eq(actual, expected);

        Ok(())
    }
}