1use proc_macro::TokenStream;
4use quote::{format_ident, quote};
5use syn::parse::{Parse, ParseStream};
6use syn::{
7 Attribute, Fields, FnArg, Ident, ImplItem, ImplItemFn, Item, ItemEnum, ItemImpl, LifetimeParam,
8 Pat, Path, ReturnType, Token, Type, TypePath, Visibility, parse_quote,
9};
10
11#[proc_macro_attribute]
57pub fn opaque_enum(attr: TokenStream, item: TokenStream) -> TokenStream {
58 let args = match syn::parse::<OpaqueArgs>(attr) {
59 Ok(args) => args,
60 Err(err) => return err.to_compile_error().into(),
61 };
62
63 match syn::parse::<Item>(item) {
64 Ok(Item::Enum(item_enum)) => expand_enum(args, item_enum).into(),
65 Ok(Item::Impl(item_impl)) => expand_impl(item_impl).into(),
66 Ok(other) => syn::Error::new_spanned(
67 other,
68 "`#[opaque_enum]` can only be applied to enums and impl blocks",
69 )
70 .to_compile_error()
71 .into(),
72 Err(err) => err.to_compile_error().into(),
73 }
74}
75
/// Where the generated wrapper struct keeps its inner enum.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
enum Storage {
    /// The inner enum is the wrapper's field directly (used when the
    /// attribute has no arguments).
    Inline,
    /// The inner enum lives behind a `Box` (selected by `wrapper = Box`).
    Boxed,
}
81
82#[derive(Clone, Copy, Debug)]
83struct OpaqueArgs {
84 storage: Storage,
85}
86
87impl Parse for OpaqueArgs {
88 fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
89 if input.is_empty() {
90 return Ok(Self {
91 storage: Storage::Inline,
92 });
93 }
94
95 let key: Ident = input.parse()?;
96 if key != "wrapper" {
97 return Err(syn::Error::new_spanned(
98 key,
99 "expected `wrapper = Box` or no arguments",
100 ));
101 }
102
103 input.parse::<Token![=]>()?;
104 let value: Ident = input.parse()?;
105 if value != "Box" {
106 return Err(syn::Error::new_spanned(
107 value,
108 "only `wrapper = Box` is currently supported",
109 ));
110 }
111
112 if !input.is_empty() {
113 input.parse::<Token![,]>()?;
114 if !input.is_empty() {
115 return Err(input.error("unexpected extra opaque_enum arguments"));
116 }
117 }
118
119 Ok(Self {
120 storage: Storage::Boxed,
121 })
122 }
123}
124
125fn expand_enum(args: OpaqueArgs, item: ItemEnum) -> proc_macro2::TokenStream {
126 let ItemEnum {
127 attrs,
128 vis,
129 ident,
130 generics,
131 variants,
132 ..
133 } = item;
134 let inner_ident = inner_ident(&ident);
135 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
136 let constructor_vis = constructor_vis(&vis);
137 let constructors = variants
138 .iter()
139 .map(|variant| constructor(&constructor_vis, &ident, &inner_ident, variant));
140 let public_attrs = public_attrs(&attrs);
141 let storage_field = storage_field(args.storage, &inner_ident, &ty_generics);
142 let from_body = from_body(args.storage);
143 let into_inner_body = into_inner_body(args.storage);
144 let as_inner_body = as_inner_body(args.storage);
145 let as_inner_mut_body = as_inner_mut_body(args.storage);
146 let projection_impls = projection_impls(args.storage, &ident, &inner_ident, &generics);
147 let repr = (args.storage == Storage::Inline).then(|| quote!(#[repr(transparent)]));
148
149 quote! {
150 #repr
151 #(#public_attrs)*
152 #vis struct #ident #generics #where_clause {
153 inner: #storage_field,
154 }
155
156 #(#attrs)*
157 enum #inner_ident #generics #where_clause {
158 #variants
159 }
160
161 impl #impl_generics #ident #ty_generics #where_clause {
162 #(#constructors)*
163
164 #[doc(hidden)]
165 fn __opaque_into_inner(self) -> #inner_ident #ty_generics {
166 #into_inner_body
167 }
168
169 #[doc(hidden)]
170 fn __opaque_as_inner(&self) -> &#inner_ident #ty_generics {
171 #as_inner_body
172 }
173
174 #[doc(hidden)]
175 fn __opaque_as_inner_mut(&mut self) -> &mut #inner_ident #ty_generics {
176 #as_inner_mut_body
177 }
178 }
179
180 impl #impl_generics ::std::convert::From<#inner_ident #ty_generics>
181 for #ident #ty_generics
182 #where_clause
183 {
184 fn from(inner: #inner_ident #ty_generics) -> Self {
185 #from_body
186 }
187 }
188
189 #projection_impls
190 }
191}
192
193fn projection_impls(
194 storage: Storage,
195 ident: &Ident,
196 inner_ident: &Ident,
197 generics: &syn::Generics,
198) -> proc_macro2::TokenStream {
199 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
200
201 let mut ref_generics = generics.clone();
202 ref_generics.params.insert(
203 0,
204 syn::GenericParam::Lifetime(LifetimeParam::new(parse_quote!('__opaque))),
205 );
206 let (ref_impl_generics, _, ref_where_clause) = ref_generics.split_for_impl();
207
208 let container_impls = (storage == Storage::Inline).then(|| {
209 quote! {
210 impl #impl_generics ::opaque_enum::OpaqueProject<#inner_ident #ty_generics>
211 for ::std::sync::Arc<#ident #ty_generics>
212 #where_clause
213 {
214 type Output = ::std::sync::Arc<#inner_ident #ty_generics>;
215
216 fn project(self) -> Self::Output {
217 let ptr = ::std::sync::Arc::into_raw(self);
218 unsafe { ::std::sync::Arc::from_raw(ptr.cast::<#inner_ident #ty_generics>()) }
222 }
223 }
224
225 impl #impl_generics ::opaque_enum::OpaqueProject<#inner_ident #ty_generics>
226 for ::std::rc::Rc<#ident #ty_generics>
227 #where_clause
228 {
229 type Output = ::std::rc::Rc<#inner_ident #ty_generics>;
230
231 fn project(self) -> Self::Output {
232 let ptr = ::std::rc::Rc::into_raw(self);
233 unsafe { ::std::rc::Rc::from_raw(ptr.cast::<#inner_ident #ty_generics>()) }
235 }
236 }
237 }
238 });
239
240 quote! {
241 impl #impl_generics ::opaque_enum::OpaqueProject<#inner_ident #ty_generics>
242 for #ident #ty_generics
243 #where_clause
244 {
245 type Output = #inner_ident #ty_generics;
246
247 fn project(self) -> Self::Output {
248 self.__opaque_into_inner()
249 }
250 }
251
252 impl #ref_impl_generics ::opaque_enum::OpaqueProject<#inner_ident #ty_generics>
253 for &'__opaque #ident #ty_generics
254 #ref_where_clause
255 {
256 type Output = &'__opaque #inner_ident #ty_generics;
257
258 fn project(self) -> Self::Output {
259 self.__opaque_as_inner()
260 }
261 }
262
263 impl #ref_impl_generics ::opaque_enum::OpaqueProject<#inner_ident #ty_generics>
264 for &'__opaque mut #ident #ty_generics
265 #ref_where_clause
266 {
267 type Output = &'__opaque mut #inner_ident #ty_generics;
268
269 fn project(self) -> Self::Output {
270 self.__opaque_as_inner_mut()
271 }
272 }
273
274 impl #ref_impl_generics ::opaque_enum::OpaqueProject<#inner_ident #ty_generics>
275 for ::std::pin::Pin<&'__opaque #ident #ty_generics>
276 #ref_where_clause
277 {
278 type Output = ::std::pin::Pin<&'__opaque #inner_ident #ty_generics>;
279
280 fn project(self) -> Self::Output {
281 unsafe { self.map_unchecked(|wrapper| wrapper.__opaque_as_inner()) }
283 }
284 }
285
286 impl #ref_impl_generics ::opaque_enum::OpaqueProject<#inner_ident #ty_generics>
287 for ::std::pin::Pin<&'__opaque mut #ident #ty_generics>
288 #ref_where_clause
289 {
290 type Output = ::std::pin::Pin<&'__opaque mut #inner_ident #ty_generics>;
291
292 fn project(self) -> Self::Output {
293 unsafe { self.map_unchecked_mut(|wrapper| wrapper.__opaque_as_inner_mut()) }
296 }
297 }
298
299 #container_impls
300 }
301}
302
303fn storage_field(
304 storage: Storage,
305 inner_ident: &Ident,
306 ty_generics: &syn::TypeGenerics<'_>,
307) -> proc_macro2::TokenStream {
308 match storage {
309 Storage::Inline => quote!(#inner_ident #ty_generics),
310 Storage::Boxed => quote!(::std::boxed::Box<#inner_ident #ty_generics>),
311 }
312}
313
314fn from_body(storage: Storage) -> proc_macro2::TokenStream {
315 match storage {
316 Storage::Inline => quote!(Self { inner }),
317 Storage::Boxed => quote!(Self {
318 inner: ::std::boxed::Box::new(inner)
319 }),
320 }
321}
322
323fn into_inner_body(storage: Storage) -> proc_macro2::TokenStream {
324 match storage {
325 Storage::Inline => quote!(self.inner),
326 Storage::Boxed => quote!(*self.inner),
327 }
328}
329
330fn as_inner_body(storage: Storage) -> proc_macro2::TokenStream {
331 match storage {
332 Storage::Inline => quote!(&self.inner),
333 Storage::Boxed => quote!(self.inner.as_ref()),
334 }
335}
336
337fn as_inner_mut_body(storage: Storage) -> proc_macro2::TokenStream {
338 match storage {
339 Storage::Inline => quote!(&mut self.inner),
340 Storage::Boxed => quote!(self.inner.as_mut()),
341 }
342}
343
344fn constructor_vis(public_vis: &Visibility) -> Visibility {
345 match public_vis {
346 Visibility::Public(_) => parse_quote!(pub(crate)),
347 other => other.clone(),
349 }
350}
351
352fn constructor(
353 vis: &Visibility,
354 public_ident: &Ident,
355 inner_ident: &Ident,
356 variant: &syn::Variant,
357) -> proc_macro2::TokenStream {
358 let variant_ident = &variant.ident;
359 let attrs = doc_attrs(&variant.attrs);
360
361 match &variant.fields {
362 Fields::Unit => {
363 quote! {
364 #(#attrs)*
365 #[allow(non_snake_case)]
366 #vis fn #variant_ident() -> Self {
367 #public_ident::from(#inner_ident::#variant_ident)
368 }
369 }
370 }
371 Fields::Unnamed(fields) => {
372 let args = fields.unnamed.iter().enumerate().map(|(index, field)| {
373 let ident = format_ident!("field_{index}");
374 let ty = &field.ty;
375 quote!(#ident: #ty)
376 });
377 let values = (0..fields.unnamed.len()).map(|index| format_ident!("field_{index}"));
378 quote! {
379 #(#attrs)*
380 #[allow(non_snake_case)]
381 #vis fn #variant_ident(#(#args),*) -> Self {
382 #public_ident::from(#inner_ident::#variant_ident(#(#values),*))
383 }
384 }
385 }
386 Fields::Named(fields) => {
387 let args = fields.named.iter().map(|field| {
388 let ident = field.ident.as_ref().expect("named field has an ident");
389 let ty = &field.ty;
390 quote!(#ident: #ty)
391 });
392 let values = fields
393 .named
394 .iter()
395 .map(|field| field.ident.as_ref().expect("named field has an ident"));
396 quote! {
397 #(#attrs)*
398 #[allow(non_snake_case)]
399 #vis fn #variant_ident(#(#args),*) -> Self {
400 #public_ident::from(#inner_ident::#variant_ident { #(#values),* })
401 }
402 }
403 }
404 }
405}
406
407#[allow(clippy::single_match_else)]
408fn expand_impl(item: ItemImpl) -> proc_macro2::TokenStream {
409 let Some(self_type_path) = self_type_path(&item.self_ty) else {
410 return syn::Error::new_spanned(
411 item.self_ty,
412 "`#[opaque_enum]` impl target must be a plain type path",
413 )
414 .to_compile_error();
415 };
416
417 let inner_ty = inner_ty(self_type_path);
418 let inner_impl = inner_impl(&item, &inner_ty);
419
420 let wrappers = match item
421 .items
422 .iter()
423 .map(|impl_item| wrapper_item(item.trait_.as_ref(), &inner_ty, impl_item))
424 .collect::<syn::Result<Vec<_>>>()
425 {
426 Ok(wrappers) => wrappers,
427 Err(err) => return err.to_compile_error(),
428 };
429
430 let attrs = &item.attrs;
431 let defaultness = &item.defaultness;
432 let unsafety = &item.unsafety;
433 let impl_token = &item.impl_token;
434 let generics = &item.generics;
435 let self_ty = &item.self_ty;
436 let public_impl = match &item.trait_ {
437 Some((bang, trait_path, for_token)) => quote! {
438 #(#attrs)*
439 #defaultness #unsafety #impl_token #generics #bang #trait_path #for_token #self_ty {
440 #(#wrappers)*
441 }
442 },
443 None => quote! {
444 #(#attrs)*
445 #defaultness #unsafety #impl_token #generics #self_ty {
446 #(#wrappers)*
447 }
448 },
449 };
450
451 quote! {
452 #public_impl
453 #inner_impl
454 }
455}
456
457fn wrapper_item(
458 trait_: Option<&(Option<Token![!]>, Path, Token![for])>,
459 inner_ty: &Type,
460 item: &ImplItem,
461) -> syn::Result<proc_macro2::TokenStream> {
462 let ImplItem::Fn(function) = item else {
463 return Err(syn::Error::new_spanned(
464 item,
465 "`#[opaque_enum]` impl blocks currently support methods only",
466 ));
467 };
468 wrapper_fn(trait_, inner_ty, function)
469}
470
471fn wrapper_fn(
472 trait_: Option<&(Option<Token![!]>, Path, Token![for])>,
473 inner_ty: &Type,
474 function: &ImplItemFn,
475) -> syn::Result<proc_macro2::TokenStream> {
476 if function.sig.asyncness.is_some() {
477 return Err(syn::Error::new_spanned(
478 function.sig.asyncness,
479 "`#[opaque_enum]` does not yet support async methods",
480 ));
481 }
482 if function.sig.constness.is_some() {
483 return Err(syn::Error::new_spanned(
484 function.sig.constness,
485 "`#[opaque_enum]` does not yet support const methods",
486 ));
487 }
488
489 let attrs = &function.attrs;
490 let vis = &function.vis;
491 let defaultness = &function.defaultness;
492 let sig = &function.sig;
493 let method = &function.sig.ident;
494 let args = function_args(function)?;
495 let receiver = has_receiver(function);
496 let call = inner_call(trait_, inner_ty, method, receiver, &args);
497 let body = if returns_self(&function.sig.output) {
502 quote!({
503 ::std::convert::Into::into(#call)
504 })
505 } else {
506 quote!({
507 #call
508 })
509 };
510
511 Ok(quote! {
512 #(#attrs)*
513 #defaultness #vis #sig #body
514 })
515}
516
517fn inner_call(
518 trait_: Option<&(Option<Token![!]>, Path, Token![for])>,
519 inner_ty: &Type,
520 method: &Ident,
521 receiver: bool,
522 args: &[Ident],
523) -> proc_macro2::TokenStream {
524 let mut call_args = Vec::new();
525 if receiver {
526 call_args.push(quote!(
527 ::opaque_enum::OpaqueProject::<#inner_ty>::project(self)
528 ));
529 }
530 call_args.extend(args.iter().map(|arg| quote!(#arg)));
531
532 match trait_ {
533 Some((_, trait_path, _)) => {
534 quote!(<#inner_ty as #trait_path>::#method(#(#call_args),*))
535 }
536 None => {
537 quote!(<#inner_ty>::#method(#(#call_args),*))
538 }
539 }
540}
541
542fn function_args(function: &ImplItemFn) -> syn::Result<Vec<Ident>> {
543 function
544 .sig
545 .inputs
546 .iter()
547 .filter_map(|arg| match arg {
548 FnArg::Receiver(_) => None,
549 FnArg::Typed(arg) => Some(arg),
550 })
551 .map(|arg| match arg.pat.as_ref() {
552 Pat::Ident(pat_ident) => Ok(pat_ident.ident.clone()),
553 _ => Err(syn::Error::new_spanned(
554 &arg.pat,
555 "`#[opaque_enum]` forwarding requires simple identifier arguments",
556 )),
557 })
558 .collect()
559}
560
561fn has_receiver(function: &ImplItemFn) -> bool {
562 matches!(function.sig.inputs.first(), Some(FnArg::Receiver(_)))
563}
564
565fn returns_self(output: &ReturnType) -> bool {
570 matches!(output, ReturnType::Type(_, ty) if type_is_self(ty))
571}
572
573fn type_is_self(ty: &Type) -> bool {
574 matches!(ty, Type::Path(type_path) if type_path.path.is_ident("Self"))
575}
576
577fn public_attrs(attrs: &[Attribute]) -> Vec<&Attribute> {
578 attrs
579 .iter()
580 .filter(|attr| !attr.path().is_ident("repr"))
581 .collect()
582}
583
584fn doc_attrs(attrs: &[Attribute]) -> Vec<&Attribute> {
585 attrs
586 .iter()
587 .filter(|attr| attr.path().is_ident("doc"))
588 .collect()
589}
590
591fn self_type_path(ty: &Type) -> Option<&TypePath> {
592 if let Type::Path(type_path) = ty {
593 Some(type_path)
594 } else {
595 None
596 }
597}
598
599fn inner_impl(item_impl: &ItemImpl, inner_ty: &Type) -> ItemImpl {
605 let mut inner_impl = item_impl.clone();
606 *inner_impl.self_ty = inner_ty.clone();
607 inner_impl
608}
609
610fn inner_ty(type_path: &TypePath) -> Type {
611 let mut type_path = type_path.clone();
612 let self_ident = &mut type_path.path.segments.last_mut().unwrap().ident;
613
614 let inner_ident = inner_ident(self_ident);
615
616 *self_ident = inner_ident;
617
618 Type::Path(type_path)
619}
620
/// Name of the private inner enum generated for the wrapper `ident`
/// (`Foo` -> `FooInner`).
fn inner_ident(ident: &Ident) -> Ident {
    // Implicit `{ident}` capture: `format_ident!` gives the result a
    // call-site span rather than `ident`'s span.
    format_ident!("{ident}Inner")
}