moverox_codegen/move_enum/
mod.rs1use std::collections::{HashMap, HashSet};
2
3use move_syn::{FieldsKind, ItemPath};
4use quote::quote;
5use unsynn::{Ident, ToTokens as _, TokenStream};
6
7use crate::generics::GenericsExt;
8use crate::{ItemContext, Result, named_fields, positional_fields};
9
10pub(super) fn to_rust(
12 this: &move_syn::Enum,
13 otw_types: HashSet<Ident>,
14 ctx: ItemContext,
15) -> Result<TokenStream> {
16 let move_syn::Enum {
17 ident, generics, ..
18 } = this;
19
20 let extra_attrs: TokenStream = ctx
21 .package
22 .into_iter()
23 .map(unsynn::ToTokens::to_token_stream)
24 .map(|addr| quote!(#[move_(address = #addr)]))
25 .chain(ctx.module.map(|ident| quote!(#[move_(module = #ident)])))
26 .collect();
27
28 let type_generics = generics
29 .as_ref()
30 .map(|g| g.type_generics(ctx.thecrate, otw_types))
31 .transpose()
32 .map(Option::unwrap_or_default)?;
33
34 let mut phantoms = unused_phantoms(this);
35 let variants = this
36 .variants()
37 .map(|var| variant_to_rust(var, &std::mem::take(&mut phantoms), ctx.address_map));
39
40 let thecrate = ctx.thecrate;
43 let serde_crate = format!("{thecrate}::serde").replace(" ", "");
44 Ok(quote! {
45 #[derive(
46 Clone,
47 Debug,
48 PartialEq,
49 Eq,
50 Hash,
51 #thecrate::traits::MoveDatatype,
52 #thecrate::serde::Deserialize,
53 #thecrate::serde::Serialize,
54 )]
55 #[move_(crate = #thecrate::traits)]
56 #[serde(crate = #serde_crate)]
57 #extra_attrs
58 #[allow(non_snake_case)]
59 pub enum #ident #type_generics {
60 #(#variants),*
61 }
62 })
63}
64
65fn unused_phantoms(this: &move_syn::Enum) -> Vec<Ident> {
67 let Some(generics) = this.generics.as_ref() else {
68 return Vec::new(); };
70
71 let maybe_phantom_leaf_types: HashSet<_> = enum_leaf_types(this)
72 .filter_map(|path| match path {
73 ItemPath::Ident(ident) => Some(ident),
74 _ => None,
75 })
76 .collect();
77
78 generics
79 .phantoms()
80 .filter(|&ident| !maybe_phantom_leaf_types.contains(ident))
81 .cloned()
82 .collect()
83}
84
85fn enum_leaf_types(this: &move_syn::Enum) -> Box<dyn Iterator<Item = &ItemPath> + '_> {
89 this.variants()
90 .flat_map(|var| &var.fields)
91 .flat_map(|fields| match fields {
92 FieldsKind::Positional(positional) => {
93 leaf_types_recursive(positional.fields().map(|field| &field.ty).boxed())
94 }
95 FieldsKind::Named(named) => {
96 leaf_types_recursive(named.fields().map(|field| &field.ty).boxed())
97 }
98 })
99 .boxed()
100}
101
102fn leaf_types_recursive<'a>(
103 types: Box<dyn Iterator<Item = &'a move_syn::Type> + 'a>,
104) -> Box<dyn Iterator<Item = &'a ItemPath> + 'a> {
105 types
106 .into_iter()
107 .flat_map(|t| {
108 t.type_args.as_ref().map_or_else(
109 || std::iter::once(&t.path).boxed(),
110 |t_args| leaf_types_recursive(t_args.types().boxed()),
111 )
112 })
113 .boxed()
114}
115
116fn variant_to_rust(
117 this: &move_syn::EnumVariant,
118 phantoms: &[Ident],
119 address_map: &HashMap<Ident, TokenStream>,
120) -> TokenStream {
121 use move_syn::FieldsKind as K;
122 let move_syn::EnumVariant {
123 attrs,
124 ident,
125 fields,
126 } = this;
127 let attrs = attrs
128 .iter()
129 .filter(|attr| attr.is_doc())
130 .map(|attr| attr.to_token_stream());
131
132 let bool_if_empty = false;
134 let visibility = false;
136
137 let default_fields = (!phantoms.is_empty()).then(|| {
140 positional_fields::to_rust(
141 &Default::default(),
142 phantoms.iter(),
143 address_map,
144 bool_if_empty,
145 visibility,
146 )
147 });
148
149 let fields = fields
150 .as_ref()
151 .map(|kind| match kind {
152 K::Named(named) => {
153 named_fields::to_rust(named, phantoms.iter(), address_map, visibility)
154 }
155 K::Positional(positional) => positional_fields::to_rust(
156 positional,
157 phantoms.iter(),
158 address_map,
159 bool_if_empty,
160 visibility,
161 ),
162 })
163 .or(default_fields)
164 .unwrap_or_default();
165
166 quote! {
167 #(#attrs)*
168 #ident #fields
169 }
170}
171
/// Extension trait that erases an iterator's concrete type behind a
/// `Box<dyn Iterator>`.
///
/// This lets differently-typed iterators share one type, for example the
/// arms of a `match` or the recursive return of `leaf_types_recursive`.
trait BoxedIter<'it>: Iterator + 'it {
    /// Moves `self` onto the heap as a trait object that lives for `'it`.
    fn boxed(self) -> Box<dyn Iterator<Item = Self::Item> + 'it>
    where
        Self: Sized,
    {
        Box::new(self) as Box<dyn Iterator<Item = Self::Item> + 'it>
    }
}

/// Blanket impl: every iterator that outlives `'it` gets `boxed()`.
impl<'it, I> BoxedIter<'it> for I where I: Iterator + 'it {}