nullable_utils_macros/
lib.rs1use proc_macro2::TokenStream;
11use quote::format_ident;
12use quote::quote;
13use quote::quote_spanned;
14use quote::ToTokens as _;
15use syn::braced;
16use syn::parse::Parse;
17use syn::parse::ParseStream;
18use syn::parse_macro_input;
19use syn::parse_quote;
20use syn::punctuated::Punctuated;
21use syn::spanned::Spanned as _;
22use syn::token;
23use syn::token::Comma;
24use syn::token::Enum;
25use syn::Attribute;
26use syn::Block;
27use syn::Field;
28use syn::Fields;
29use syn::FieldsUnnamed;
30use syn::FnArg;
31use syn::Ident;
32use syn::ItemEnum;
33use syn::Signature;
34use syn::Token;
35use syn::Variant;
36use syn::Visibility;
37
38#[proc_macro]
51pub fn nullable_wrapper(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
52 let wrapper = parse_macro_input!(input as NullableWrapper);
53 let expanded = expand(wrapper);
54 proc_macro::TokenStream::from(expanded)
55}
56
57fn expand(wrapper: NullableWrapper) -> TokenStream {
58 let NullableWrapper {
59 attrs,
60 vis,
61 enum_token,
62 ident,
63 variants,
64 fns,
65 } = wrapper;
66
67 let (enum_ident, struct_impl) = expand_struct_wrapper(&attrs, &vis, ident, &fns);
68
69 let fns = fns.into_iter().map(
70 |WrapperFn {
71 attrs,
72 sig,
73 default,
74 ..
75 }| {
76 let method = &sig.ident;
77 let args: Punctuated<_, Comma> = sig
78 .inputs
79 .iter()
80 .filter_map(|arg| match arg {
81 FnArg::Receiver(_) => None,
82 FnArg::Typed(pat) => Some(&pat.pat),
83 })
84 .collect();
85
86 let body = default.map_or_else(
87 || {
88 let matchers = variants.iter().map(
89 |Variant { ident, .. }| quote!(Self::#ident(inner) => inner.#method(#args)),
90 );
91
92 quote!({
93 match self {
94 #(#matchers),*
95 }
96 })
97 },
98 Block::into_token_stream,
99 );
100
101 quote! {
102 #(#attrs)*
103 #sig #body
104 }
105 },
106 );
107
108 let from_impls = variants.iter().map(|var @ Variant { ident, fields, .. }| {
109 let Fields::Unnamed(FieldsUnnamed { unnamed, .. }) = fields else {
111 panic!()
112 };
113 let Field { ty, .. } = &unnamed[0];
114
115 quote_spanned! { var.span() =>
116 impl From<#ty> for #enum_ident {
117 fn from(value: #ty) -> Self {
118 Self::#ident(value)
119 }
120 }
121 }
122 });
123
124 let try_into_impls = variants.iter().map(|var @ Variant { ident, fields, .. }| {
125 let Fields::Unnamed(FieldsUnnamed { unnamed, .. }) = fields else {
127 panic!()
128 };
129 let Field { ty, .. } = &unnamed[0];
130
131 quote_spanned! { var.span() =>
132 impl TryFrom<#enum_ident> for #ty {
133 type Error = ();
134
135 fn try_from(value: #enum_ident) -> Result<Self, Self::Error> {
136 match value {
137 #enum_ident::#ident(inner) => Ok(inner),
138 _ => Err(())
139 }
140 }
141 }
142 }
143 });
144
145 let expanded = quote! {
146 #struct_impl
147
148 #(#attrs)*
149 #enum_token #enum_ident {
150 #variants
151 }
152
153 impl #enum_ident {
154 #(#fns)*
155 }
156
157 #(#from_impls)*
158
159 #(#try_into_impls)*
160 };
161 expanded
162}
163
164fn expand_struct_wrapper(
165 attrs: &[Attribute],
166 vis: &Visibility,
167 ident: Ident,
168 fns: &[WrapperFn],
169) -> (Ident, TokenStream) {
170 let Visibility::Public(pub_token) = vis else {
171 return (ident, TokenStream::new());
172 };
173
174 let enum_ident = format_ident!("{}Inner", ident);
175
176 let fns = fns.iter().map(
177 |WrapperFn {
178 attrs, vis, sig, ..
179 }| {
180 let method = &sig.ident;
181 let args: Punctuated<_, Comma> = sig
182 .inputs
183 .iter()
184 .filter_map(|arg| match arg {
185 FnArg::Receiver(_) => None,
186 FnArg::Typed(pat) => Some(&pat.pat),
187 })
188 .collect();
189
190 let body = quote!({
191 self.0.#method(#args)
192 });
193
194 quote! {
195 #(#attrs)*
196 #vis #sig #body
197 }
198 },
199 );
200
201 let token_stream = quote! {
202 #(#attrs)*
203 #[repr(transparent)]
204 #pub_token struct #ident(#enum_ident);
205
206 impl #ident {
207 #(#fns)*
208 }
209
210 impl<T> From<T> for #ident where #enum_ident: From<T> {
211 fn from(value: T) -> Self {
212 Self(#enum_ident::from(value))
213 }
214 }
215 };
216
217 (enum_ident, token_stream)
218}
219
220struct NullableWrapper {
221 attrs: Vec<Attribute>,
222 vis: Visibility,
223 enum_token: Enum,
224 ident: Ident,
225 variants: Punctuated<Variant, Comma>,
226 fns: Vec<WrapperFn>,
227}
228
229impl Parse for NullableWrapper {
231 fn parse(input: ParseStream) -> syn::Result<Self> {
232 let ItemEnum {
233 attrs,
234 vis,
235 enum_token,
236 ident,
237 mut variants,
238 ..
239 } = input.parse()?;
240
241 for variant in &mut variants {
242 match variant.fields {
243 Fields::Unit => {
244 let name = &variant.ident;
245 variant.fields = Fields::Unnamed(parse_quote!((#name)));
246 }
247 Fields::Unnamed(FieldsUnnamed {
248 ref mut unnamed, ..
249 }) if unnamed.len() == 1 => {}
250 _ => {
251 return Err(syn::Error::new_spanned(
252 &variant,
253 "only unit and new-type variants are supported",
254 ))
255 }
256 }
257 }
258 let mut fns = Vec::new();
261 if !input.is_empty() {
262 let content;
263 braced!(content in input);
264
265 while !content.is_empty() {
266 fns.push(content.parse()?);
267 }
268 }
269
270 Ok(NullableWrapper {
271 attrs,
272 vis,
273 enum_token,
274 ident,
275 variants,
276 fns,
277 })
278 }
279}
280
281struct WrapperFn {
282 pub attrs: Vec<Attribute>,
283 pub vis: Visibility,
284 pub sig: Signature,
285 pub default: Option<Block>,
286 pub semi_token: Option<Token![;]>,
287}
288
289impl Parse for WrapperFn {
290 fn parse(input: ParseStream) -> syn::Result<Self> {
291 let attrs = input.call(Attribute::parse_outer)?;
292 let vis: Visibility = input.parse()?;
293 let sig: Signature = input.parse()?;
294
295 let lookahead = input.lookahead1();
296 let (default, semi_token) = if lookahead.peek(token::Brace) {
297 let block = input.parse()?;
298 (Some(block), None)
299 } else if lookahead.peek(Token![;]) {
300 let semi_token: Token![;] = input.parse()?;
301 (None, Some(semi_token))
302 } else {
303 return Err(lookahead.error());
304 };
305
306 Ok(Self {
307 attrs,
308 vis,
309 sig,
310 default,
311 semi_token,
312 })
313 }
314}