Skip to main content

opaque_enum_macros/
lib.rs

1//! Proc macro implementation for `opaque-enum`.
2
3use proc_macro::TokenStream;
4use quote::{format_ident, quote};
5use syn::parse::{Parse, ParseStream};
6use syn::{
7    Attribute, Fields, FnArg, Ident, ImplItem, ImplItemFn, Item, ItemEnum, ItemImpl, LifetimeParam,
8    Pat, Path, ReturnType, Token, Type, TypePath, Visibility, parse_quote,
9};
10
11/// Hides enum variants behind an opaque struct wrapper.
12///
13/// This macro lets a public type keep an enum-like authoring experience while
14/// exposing an opaque wrapper instead of public enum variants. This prevents
15/// breaking changes when you add, remove, or modify variants.
16///
17/// # Examples
18///
19/// ```ignore
20/// # use opaque_enum_macros::opaque_enum;
21/// use std::fmt::{self, Display, Formatter};
22///
23/// #[opaque_enum]
24/// #[derive(Debug)]
25/// pub enum DatabaseError {
26///     ConnectionFailed(String),
27///     QueryFailed { query: String, reason: String },
28///     PermissionDenied,
29/// }
30///
31/// #[opaque_enum]
32/// impl Display for DatabaseError {
33///     fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
34///         match self {
35///             Self::ConnectionFailed(err) => write!(f, "connection failed: {err}"),
36///             Self::QueryFailed { query, reason } => {
37///                 write!(f, "query `{query}` failed: {reason}")
38///             }
39///             Self::PermissionDenied => write!(f, "permission denied"),
40///         }
41///     }
42/// }
43/// ```
44///
45/// You can also opt-in to boxing the representation by specifying `wrapper = Box`:
46///
47/// ```ignore
48/// # use opaque_enum_macros::opaque_enum;
49/// #[opaque_enum(wrapper = Box)]
50/// #[derive(Debug)]
51/// pub enum LargeError {
52///     Variant1([u8; 1024]),
53///     Variant2,
54/// }
55/// ```
56#[proc_macro_attribute]
57pub fn opaque_enum(attr: TokenStream, item: TokenStream) -> TokenStream {
58    let args = match syn::parse::<OpaqueArgs>(attr) {
59        Ok(args) => args,
60        Err(err) => return err.to_compile_error().into(),
61    };
62
63    match syn::parse::<Item>(item) {
64        Ok(Item::Enum(item_enum)) => expand_enum(args, item_enum).into(),
65        Ok(Item::Impl(item_impl)) => expand_impl(item_impl).into(),
66        Ok(other) => syn::Error::new_spanned(
67            other,
68            "`#[opaque_enum]` can only be applied to enums and impl blocks",
69        )
70        .to_compile_error()
71        .into(),
72        Err(err) => err.to_compile_error().into(),
73    }
74}
75
/// How the generated wrapper struct stores the hidden inner enum.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Storage {
    /// The inner enum is stored directly in the wrapper's field.
    Inline,
    /// The inner enum is stored behind a `Box`.
    Boxed,
}
81
82#[derive(Clone, Copy, Debug)]
83struct OpaqueArgs {
84    storage: Storage,
85}
86
87impl Parse for OpaqueArgs {
88    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
89        if input.is_empty() {
90            return Ok(Self {
91                storage: Storage::Inline,
92            });
93        }
94
95        let key: Ident = input.parse()?;
96        if key != "wrapper" {
97            return Err(syn::Error::new_spanned(
98                key,
99                "expected `wrapper = Box` or no arguments",
100            ));
101        }
102
103        input.parse::<Token![=]>()?;
104        let value: Ident = input.parse()?;
105        if value != "Box" {
106            return Err(syn::Error::new_spanned(
107                value,
108                "only `wrapper = Box` is currently supported",
109            ));
110        }
111
112        if !input.is_empty() {
113            input.parse::<Token![,]>()?;
114            if !input.is_empty() {
115                return Err(input.error("unexpected extra opaque_enum arguments"));
116            }
117        }
118
119        Ok(Self {
120            storage: Storage::Boxed,
121        })
122    }
123}
124
125fn expand_enum(args: OpaqueArgs, item: ItemEnum) -> proc_macro2::TokenStream {
126    let ItemEnum {
127        attrs,
128        vis,
129        ident,
130        generics,
131        variants,
132        ..
133    } = item;
134    let inner_ident = inner_ident(&ident);
135    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
136    let constructor_vis = constructor_vis(&vis);
137    let constructors = variants
138        .iter()
139        .map(|variant| constructor(&constructor_vis, &ident, &inner_ident, variant));
140    let public_attrs = public_attrs(&attrs);
141    let storage_field = storage_field(args.storage, &inner_ident, &ty_generics);
142    let from_body = from_body(args.storage);
143    let into_inner_body = into_inner_body(args.storage);
144    let as_inner_body = as_inner_body(args.storage);
145    let as_inner_mut_body = as_inner_mut_body(args.storage);
146    let projection_impls = projection_impls(args.storage, &ident, &inner_ident, &generics);
147    let repr = (args.storage == Storage::Inline).then(|| quote!(#[repr(transparent)]));
148
149    quote! {
150        #repr
151        #(#public_attrs)*
152        #vis struct #ident #generics #where_clause {
153            inner: #storage_field,
154        }
155
156        #(#attrs)*
157        enum #inner_ident #generics #where_clause {
158            #variants
159        }
160
161        impl #impl_generics #ident #ty_generics #where_clause {
162            #(#constructors)*
163
164            #[doc(hidden)]
165            fn __opaque_into_inner(self) -> #inner_ident #ty_generics {
166                #into_inner_body
167            }
168
169            #[doc(hidden)]
170            fn __opaque_as_inner(&self) -> &#inner_ident #ty_generics {
171                #as_inner_body
172            }
173
174            #[doc(hidden)]
175            fn __opaque_as_inner_mut(&mut self) -> &mut #inner_ident #ty_generics {
176                #as_inner_mut_body
177            }
178        }
179
180        impl #impl_generics ::std::convert::From<#inner_ident #ty_generics>
181            for #ident #ty_generics
182            #where_clause
183        {
184            fn from(inner: #inner_ident #ty_generics) -> Self {
185                #from_body
186            }
187        }
188
189        #projection_impls
190    }
191}
192
193fn projection_impls(
194    storage: Storage,
195    ident: &Ident,
196    inner_ident: &Ident,
197    generics: &syn::Generics,
198) -> proc_macro2::TokenStream {
199    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
200
201    let mut ref_generics = generics.clone();
202    ref_generics.params.insert(
203        0,
204        syn::GenericParam::Lifetime(LifetimeParam::new(parse_quote!('__opaque))),
205    );
206    let (ref_impl_generics, _, ref_where_clause) = ref_generics.split_for_impl();
207
208    let container_impls = (storage == Storage::Inline).then(|| {
209        quote! {
210            impl #impl_generics ::opaque_enum::OpaqueProject<#inner_ident #ty_generics>
211                for ::std::sync::Arc<#ident #ty_generics>
212                #where_clause
213            {
214                type Output = ::std::sync::Arc<#inner_ident #ty_generics>;
215
216                fn project(self) -> Self::Output {
217                    let ptr = ::std::sync::Arc::into_raw(self);
218                    // SAFETY: inline `#[opaque_enum]` emits a transparent
219                    // wrapper over the inner enum and implements
220                    // `OpaqueTransparent` for the wrapper.
221                    unsafe { ::std::sync::Arc::from_raw(ptr.cast::<#inner_ident #ty_generics>()) }
222                }
223            }
224
225            impl #impl_generics ::opaque_enum::OpaqueProject<#inner_ident #ty_generics>
226                for ::std::rc::Rc<#ident #ty_generics>
227                #where_clause
228            {
229                type Output = ::std::rc::Rc<#inner_ident #ty_generics>;
230
231                fn project(self) -> Self::Output {
232                    let ptr = ::std::rc::Rc::into_raw(self);
233                    // SAFETY: see the analogous `Arc` implementation above.
234                    unsafe { ::std::rc::Rc::from_raw(ptr.cast::<#inner_ident #ty_generics>()) }
235                }
236            }
237        }
238    });
239
240    quote! {
241        impl #impl_generics ::opaque_enum::OpaqueProject<#inner_ident #ty_generics>
242            for #ident #ty_generics
243            #where_clause
244        {
245            type Output = #inner_ident #ty_generics;
246
247            fn project(self) -> Self::Output {
248                self.__opaque_into_inner()
249            }
250        }
251
252        impl #ref_impl_generics ::opaque_enum::OpaqueProject<#inner_ident #ty_generics>
253            for &'__opaque #ident #ty_generics
254            #ref_where_clause
255        {
256            type Output = &'__opaque #inner_ident #ty_generics;
257
258            fn project(self) -> Self::Output {
259                self.__opaque_as_inner()
260            }
261        }
262
263        impl #ref_impl_generics ::opaque_enum::OpaqueProject<#inner_ident #ty_generics>
264            for &'__opaque mut #ident #ty_generics
265            #ref_where_clause
266        {
267            type Output = &'__opaque mut #inner_ident #ty_generics;
268
269            fn project(self) -> Self::Output {
270                self.__opaque_as_inner_mut()
271            }
272        }
273
274        impl #ref_impl_generics ::opaque_enum::OpaqueProject<#inner_ident #ty_generics>
275            for ::std::pin::Pin<&'__opaque #ident #ty_generics>
276            #ref_where_clause
277        {
278            type Output = ::std::pin::Pin<&'__opaque #inner_ident #ty_generics>;
279
280            fn project(self) -> Self::Output {
281                // SAFETY: pinning is structurally transparent for immutable references.
282                unsafe { self.map_unchecked(|wrapper| wrapper.__opaque_as_inner()) }
283            }
284        }
285
286        impl #ref_impl_generics ::opaque_enum::OpaqueProject<#inner_ident #ty_generics>
287            for ::std::pin::Pin<&'__opaque mut #ident #ty_generics>
288            #ref_where_clause
289        {
290            type Output = ::std::pin::Pin<&'__opaque mut #inner_ident #ty_generics>;
291
292            fn project(self) -> Self::Output {
293                // SAFETY: wrapper struct is transparent or boxes the inner type,
294                // preserving pinning guarantees.
295                unsafe { self.map_unchecked_mut(|wrapper| wrapper.__opaque_as_inner_mut()) }
296            }
297        }
298
299        #container_impls
300    }
301}
302
303fn storage_field(
304    storage: Storage,
305    inner_ident: &Ident,
306    ty_generics: &syn::TypeGenerics<'_>,
307) -> proc_macro2::TokenStream {
308    match storage {
309        Storage::Inline => quote!(#inner_ident #ty_generics),
310        Storage::Boxed => quote!(::std::boxed::Box<#inner_ident #ty_generics>),
311    }
312}
313
314fn from_body(storage: Storage) -> proc_macro2::TokenStream {
315    match storage {
316        Storage::Inline => quote!(Self { inner }),
317        Storage::Boxed => quote!(Self {
318            inner: ::std::boxed::Box::new(inner)
319        }),
320    }
321}
322
323fn into_inner_body(storage: Storage) -> proc_macro2::TokenStream {
324    match storage {
325        Storage::Inline => quote!(self.inner),
326        Storage::Boxed => quote!(*self.inner),
327    }
328}
329
330fn as_inner_body(storage: Storage) -> proc_macro2::TokenStream {
331    match storage {
332        Storage::Inline => quote!(&self.inner),
333        Storage::Boxed => quote!(self.inner.as_ref()),
334    }
335}
336
337fn as_inner_mut_body(storage: Storage) -> proc_macro2::TokenStream {
338    match storage {
339        Storage::Inline => quote!(&mut self.inner),
340        Storage::Boxed => quote!(self.inner.as_mut()),
341    }
342}
343
344fn constructor_vis(public_vis: &Visibility) -> Visibility {
345    match public_vis {
346        Visibility::Public(_) => parse_quote!(pub(crate)),
347        // TODO
348        other => other.clone(),
349    }
350}
351
352fn constructor(
353    vis: &Visibility,
354    public_ident: &Ident,
355    inner_ident: &Ident,
356    variant: &syn::Variant,
357) -> proc_macro2::TokenStream {
358    let variant_ident = &variant.ident;
359    let attrs = doc_attrs(&variant.attrs);
360
361    match &variant.fields {
362        Fields::Unit => {
363            quote! {
364                #(#attrs)*
365                #[allow(non_snake_case)]
366                #vis fn #variant_ident() -> Self {
367                    #public_ident::from(#inner_ident::#variant_ident)
368                }
369            }
370        }
371        Fields::Unnamed(fields) => {
372            let args = fields.unnamed.iter().enumerate().map(|(index, field)| {
373                let ident = format_ident!("field_{index}");
374                let ty = &field.ty;
375                quote!(#ident: #ty)
376            });
377            let values = (0..fields.unnamed.len()).map(|index| format_ident!("field_{index}"));
378            quote! {
379                #(#attrs)*
380                #[allow(non_snake_case)]
381                #vis fn #variant_ident(#(#args),*) -> Self {
382                    #public_ident::from(#inner_ident::#variant_ident(#(#values),*))
383                }
384            }
385        }
386        Fields::Named(fields) => {
387            let args = fields.named.iter().map(|field| {
388                let ident = field.ident.as_ref().expect("named field has an ident");
389                let ty = &field.ty;
390                quote!(#ident: #ty)
391            });
392            let values = fields
393                .named
394                .iter()
395                .map(|field| field.ident.as_ref().expect("named field has an ident"));
396            quote! {
397                #(#attrs)*
398                #[allow(non_snake_case)]
399                #vis fn #variant_ident(#(#args),*) -> Self {
400                    #public_ident::from(#inner_ident::#variant_ident { #(#values),* })
401                }
402            }
403        }
404    }
405}
406
407#[allow(clippy::single_match_else)]
408fn expand_impl(item: ItemImpl) -> proc_macro2::TokenStream {
409    let Some(self_type_path) = self_type_path(&item.self_ty) else {
410        return syn::Error::new_spanned(
411            item.self_ty,
412            "`#[opaque_enum]` impl target must be a plain type path",
413        )
414        .to_compile_error();
415    };
416
417    let inner_ty = inner_ty(self_type_path);
418    let inner_impl = inner_impl(&item, &inner_ty);
419
420    let wrappers = match item
421        .items
422        .iter()
423        .map(|impl_item| wrapper_item(item.trait_.as_ref(), &inner_ty, impl_item))
424        .collect::<syn::Result<Vec<_>>>()
425    {
426        Ok(wrappers) => wrappers,
427        Err(err) => return err.to_compile_error(),
428    };
429
430    let attrs = &item.attrs;
431    let defaultness = &item.defaultness;
432    let unsafety = &item.unsafety;
433    let impl_token = &item.impl_token;
434    let generics = &item.generics;
435    let self_ty = &item.self_ty;
436    let public_impl = match &item.trait_ {
437        Some((bang, trait_path, for_token)) => quote! {
438            #(#attrs)*
439            #defaultness #unsafety #impl_token #generics #bang #trait_path #for_token #self_ty {
440                #(#wrappers)*
441            }
442        },
443        None => quote! {
444            #(#attrs)*
445            #defaultness #unsafety #impl_token #generics #self_ty {
446                #(#wrappers)*
447            }
448        },
449    };
450
451    quote! {
452        #public_impl
453        #inner_impl
454    }
455}
456
457fn wrapper_item(
458    trait_: Option<&(Option<Token![!]>, Path, Token![for])>,
459    inner_ty: &Type,
460    item: &ImplItem,
461) -> syn::Result<proc_macro2::TokenStream> {
462    let ImplItem::Fn(function) = item else {
463        return Err(syn::Error::new_spanned(
464            item,
465            "`#[opaque_enum]` impl blocks currently support methods only",
466        ));
467    };
468    wrapper_fn(trait_, inner_ty, function)
469}
470
471fn wrapper_fn(
472    trait_: Option<&(Option<Token![!]>, Path, Token![for])>,
473    inner_ty: &Type,
474    function: &ImplItemFn,
475) -> syn::Result<proc_macro2::TokenStream> {
476    if function.sig.asyncness.is_some() {
477        return Err(syn::Error::new_spanned(
478            function.sig.asyncness,
479            "`#[opaque_enum]` does not yet support async methods",
480        ));
481    }
482    if function.sig.constness.is_some() {
483        return Err(syn::Error::new_spanned(
484            function.sig.constness,
485            "`#[opaque_enum]` does not yet support const methods",
486        ));
487    }
488
489    let attrs = &function.attrs;
490    let vis = &function.vis;
491    let defaultness = &function.defaultness;
492    let sig = &function.sig;
493    let method = &function.sig.ident;
494    let args = function_args(function)?;
495    let receiver = has_receiver(function);
496    let call = inner_call(trait_, inner_ty, method, receiver, &args);
497    // NOTE: only a bare `-> Self` return is detected and wrapped with `Into::into`.
498    // Methods returning composite types that *contain* `Self` (e.g. `Option<Self>`,
499    // `Result<Self, E>`) are not rewritten and will produce a type-mismatch compile
500    // error. An `InverseProject` trait is planned to handle those cases.
501    let body = if returns_self(&function.sig.output) {
502        quote!({
503            ::std::convert::Into::into(#call)
504        })
505    } else {
506        quote!({
507            #call
508        })
509    };
510
511    Ok(quote! {
512        #(#attrs)*
513        #defaultness #vis #sig #body
514    })
515}
516
517fn inner_call(
518    trait_: Option<&(Option<Token![!]>, Path, Token![for])>,
519    inner_ty: &Type,
520    method: &Ident,
521    receiver: bool,
522    args: &[Ident],
523) -> proc_macro2::TokenStream {
524    let mut call_args = Vec::new();
525    if receiver {
526        call_args.push(quote!(
527            ::opaque_enum::OpaqueProject::<#inner_ty>::project(self)
528        ));
529    }
530    call_args.extend(args.iter().map(|arg| quote!(#arg)));
531
532    match trait_ {
533        Some((_, trait_path, _)) => {
534            quote!(<#inner_ty as #trait_path>::#method(#(#call_args),*))
535        }
536        None => {
537            quote!(<#inner_ty>::#method(#(#call_args),*))
538        }
539    }
540}
541
542fn function_args(function: &ImplItemFn) -> syn::Result<Vec<Ident>> {
543    function
544        .sig
545        .inputs
546        .iter()
547        .filter_map(|arg| match arg {
548            FnArg::Receiver(_) => None,
549            FnArg::Typed(arg) => Some(arg),
550        })
551        .map(|arg| match arg.pat.as_ref() {
552            Pat::Ident(pat_ident) => Ok(pat_ident.ident.clone()),
553            _ => Err(syn::Error::new_spanned(
554                &arg.pat,
555                "`#[opaque_enum]` forwarding requires simple identifier arguments",
556            )),
557        })
558        .collect()
559}
560
561fn has_receiver(function: &ImplItemFn) -> bool {
562    matches!(function.sig.inputs.first(), Some(FnArg::Receiver(_)))
563}
564
565// Returns true only for a bare `-> Self`. References (`-> &Self`, `-> &mut Self`)
566// and composite types (`-> Option<Self>`) are intentionally excluded: wrapping
567// references requires transmuting the pointer (only sound for inline storage), and
568// wrapping composites requires a not-yet-implemented `InverseProject` pass.
569fn returns_self(output: &ReturnType) -> bool {
570    matches!(output, ReturnType::Type(_, ty) if type_is_self(ty))
571}
572
573fn type_is_self(ty: &Type) -> bool {
574    matches!(ty, Type::Path(type_path) if type_path.path.is_ident("Self"))
575}
576
577fn public_attrs(attrs: &[Attribute]) -> Vec<&Attribute> {
578    attrs
579        .iter()
580        .filter(|attr| !attr.path().is_ident("repr"))
581        .collect()
582}
583
584fn doc_attrs(attrs: &[Attribute]) -> Vec<&Attribute> {
585    attrs
586        .iter()
587        .filter(|attr| attr.path().is_ident("doc"))
588        .collect()
589}
590
591fn self_type_path(ty: &Type) -> Option<&TypePath> {
592    if let Type::Path(type_path) = ty {
593        Some(type_path)
594    } else {
595        None
596    }
597}
598
599// Repoints the impl block's `Self` type to the inner enum type. This means all
600// `self.method()` calls inside the decorated block are resolved against the inner
601// enum, not the public wrapper. As a result, only methods defined in other
602// `#[opaque_enum]`-decorated `impl` blocks are callable on `self`; methods
603// defined solely on the outer wrapper type are not in scope here.
604fn inner_impl(item_impl: &ItemImpl, inner_ty: &Type) -> ItemImpl {
605    let mut inner_impl = item_impl.clone();
606    *inner_impl.self_ty = inner_ty.clone();
607    inner_impl
608}
609
610fn inner_ty(type_path: &TypePath) -> Type {
611    let mut type_path = type_path.clone();
612    let self_ident = &mut type_path.path.segments.last_mut().unwrap().ident;
613
614    let inner_ident = inner_ident(self_ident);
615
616    *self_ident = inner_ident;
617
618    Type::Path(type_path)
619}
620
621fn inner_ident(ident: &Ident) -> Ident {
622    format_ident!("{ident}Inner")
623}