all_the_same/
lib.rs

//! If you have ever written code that looks like this:
2//!
3//! ```
4//! use std::io;
5//! use std::pin::Pin;
6//! use std::task::{Context, Poll};
7//! use tokio::io::AsyncWrite;
8//! use tokio::net::{TcpStream, UnixStream};
9//!
10//! enum Stream {
11//!     Tcp(TcpStream),
12//!     Unix(UnixStream),
13//!     Custom(Box<dyn AsyncWrite + Unpin + 'static>),
14//! }
15//!
16//! impl AsyncWrite for Stream {
17//!     fn poll_write(
18//!         self: Pin<&mut Self>,
19//!         cx: &mut Context<'_>,
20//!         buf: &[u8],
21//!     ) -> Poll<Result<usize, io::Error>> {
22//!         match self.get_mut() {
23//!             Stream::Tcp(s) => Pin::new(s).poll_write(cx, buf),
24//!             Stream::Unix(s) => Pin::new(s).poll_write(cx, buf),
25//!             Stream::Custom(s) => Pin::new(s).poll_write(cx, buf),
26//!         }
27//!     }
28//!
29//!     fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
30//!         match self.get_mut() {
31//!             Stream::Tcp(s) => Pin::new(s).poll_shutdown(cx),
32//!             Stream::Unix(s) => Pin::new(s).poll_shutdown(cx),
33//!             Stream::Custom(s) => Pin::new(s).poll_shutdown(cx),
34//!         }
35//!     }
36//!
37//!     fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
38//!         match self.get_mut() {
39//!             Stream::Tcp(s) => Pin::new(s).poll_flush(cx),
40//!             Stream::Unix(s) => Pin::new(s).poll_flush(cx),
41//!             Stream::Custom(s) => Pin::new(s).poll_flush(cx),
42//!         }
43//!     }
44//! }
45//! ```
46//!
//! then, with the help of this macro, you can replace it with:
48//! ```
49//! use std::io;
50//! use std::pin::Pin;
51//! use std::task::{Context, Poll};
52//! use tokio::io::AsyncWrite;
53//! use tokio::net::{TcpStream, UnixStream};
54//! use all_the_same::all_the_same;
55//!
56//! enum Stream {
57//!     Tcp(TcpStream),
58//!     Unix(UnixStream),
59//!     Custom(Box<dyn AsyncWrite + Unpin + 'static>),
60//! }
61//!
62//! impl AsyncWrite for Stream {
63//!     fn poll_write(
64//!         self: Pin<&mut Self>,
65//!         cx: &mut Context<'_>,
66//!         buf: &[u8],
67//!     ) -> Poll<Result<usize, io::Error>> {
68//!         all_the_same!(match self.get_mut() {
69//!             Stream::[Tcp, Unix, Custom](s) => Pin::new(s).poll_write(cx, buf)
70//!         })
71//!     }
72//!
73//!     fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
74//!         all_the_same!(match self.get_mut() {
75//!             Stream::[Tcp, Unix, Custom](s) => Pin::new(s).poll_shutdown(cx)
76//!         })
77//!     }
78//!
79//!     fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
80//!         all_the_same!(match self.get_mut() {
81//!             Stream::[Tcp, Unix, Custom](s) => Pin::new(s).poll_flush(cx)
82//!         })
83//!     }
84//! }
85//! ```
86//!
//! # Feature-gated enum variants, etc.
88//!
//! You can also put attributes on individual variants; they are copied onto the generated match
//! arms. This is useful for dealing with feature-gated enum variants:
91//!
92//! ```
93//! use all_the_same::all_the_same;
94//!
95//! enum Variants {
96//!     Foo(String),
97//!     
98//!     #[cfg(test)]
99//!     Bar(String)
100//! }
101//!
102//! impl Variants {
103//!     pub fn value(&self) -> &str {
104//!         all_the_same!(match self {
105//!             Self::[Foo, #[cfg(test)]Bar](v) => v
106//!         })
107//!     }
108//! }
109//! ```
110
111use proc_macro::TokenStream;
112use quote::quote;
113use syn::parse::{Parse, ParseStream};
114use syn::punctuated::Punctuated;
115use syn::token::Comma;
116use syn::{braced, bracketed, parenthesized, parse_macro_input, Attribute, Expr, Ident, Token};
117
118struct Variant {
119    attrs: Vec<Attribute>,
120    name: Ident,
121}
122
123impl Parse for Variant {
124    fn parse(input: ParseStream) -> syn::Result<Self> {
125        Ok(Variant {
126            attrs: input.call(Attribute::parse_outer)?,
127            name: input.parse()?,
128        })
129    }
130}
131
132struct Args {
133    expr: Expr,
134    enum_name: Option<Ident>,
135    variants: Punctuated<Variant, Comma>,
136    inner_name: Ident,
137    arm_expr: Expr,
138}
139
140impl Parse for Args {
141    fn parse(input: ParseStream) -> syn::Result<Self> {
142        let match_body_content;
143
144        Ok(Args {
145            expr: {
146                input.parse::<Token!(match)>()?;
147
148                Expr::parse_without_eager_brace(input)?
149            },
150            enum_name: {
151                braced!(match_body_content in input);
152
153                let enum_name = match_body_content.parse::<Option<Ident>>()?;
154
155                if enum_name.is_none() {
156                    match_body_content.parse::<Token!(Self)>()?;
157                }
158
159                enum_name
160            },
161            variants: {
162                match_body_content.parse::<Token!(::)>()?;
163
164                let variants_list_content;
165
166                bracketed!(variants_list_content in match_body_content);
167
168                variants_list_content.parse_terminated(Variant::parse)?
169            },
170            inner_name: {
171                let variant_payload_content;
172
173                parenthesized!(variant_payload_content in match_body_content);
174
175                variant_payload_content.parse()?
176            },
177            arm_expr: {
178                match_body_content.parse::<Token!(=>)>()?;
179
180                match_body_content.parse()?
181            },
182        })
183    }
184}
185
186/// The macro itself.
187#[proc_macro]
188pub fn all_the_same(item: TokenStream) -> TokenStream {
189    let args = parse_macro_input!(item as Args);
190
191    let expr = &args.expr;
192    let enum_name = &args.enum_name;
193    let inner_name = &args.inner_name;
194    let arm_expr = &args.arm_expr;
195
196    let enum_name = match enum_name {
197        Some(name) => quote!(#name),
198        None => quote!(Self),
199    };
200
201    let arms = args.variants.iter().map(|variant| {
202        let name = &variant.name;
203        let attrs = &variant.attrs;
204
205        quote! {
206            #(#attrs)*
207            #enum_name::#name(#inner_name) => #arm_expr
208        }
209    });
210
211    quote! {
212        match #expr {
213            #(#arms),*
214        }
215    }
216    .into()
217}