partial_enum/
lib.rs

1#![feature(never_type)]
2#![feature(exhaustive_patterns)]
3
//! A proc-macro for generating partial enums from a template enum. This partial
//! enum contains the same number of variants as the template but can disable a
//! subset of these variants at compile time. The goal is to specialize an enum
//! with a finer-grained variant set for each API.
8//!
//! This is useful for handling errors. A common pattern is to define an enum
//! with all possible errors and use this for the entire API surface. Albeit
//! simple, this representation can fail to represent exact error scenarios by
//! allowing errors that cannot happen.
13//!
14//! Take an API responsible for decoding messages from a socket.
15//!
16//! ```
17//! # struct ConnectError;
18//! # struct ReadError;
19//! # struct DecodeError;
20//! # struct Socket;
21//! # struct Bytes;
22//! # struct Message;
23//! enum Error {
24//!     Connect(ConnectError),
25//!     Read(ReadError),
26//!     Decode(DecodeError),
27//! }
28//!
29//! fn connect() -> Result<Socket, Error> {
30//!     Ok(Socket)
31//! }
32//!
33//! fn read(sock: &mut Socket) -> Result<Bytes, Error> {
34//!     Ok(Bytes)
35//! }
36//!
37//! fn decode(bytes: Bytes) -> Result<Message, Error> {
38//!     Err(Error::Decode(DecodeError))
39//! }
40//! ```
41//!
//! The same error enum is used all over the place and exposes variants that do
//! not match the API: `decode` returns a `DecodeError` but nothing prevents
//! it from returning a `ConnectError`. For such low-level APIs, we could replace
//! `Error` with the matching error type, like `ConnectError` for `connect`. The
//! downside is that composing such functions forces us to define custom
//! enums:
48//!
49//! ```
50//! # struct ReadError;
51//! # struct DecodeError;
52//! # struct Socket;
53//! # struct Bytes;
54//! # struct Message;
55//! enum NextMessageError {
56//!     Read(ReadError),
57//!     Decode(DecodeError),
58//! }
59//!
60//! impl From<ReadError> for NextMessageError {
61//!     fn from(err: ReadError) -> Self {
62//!         NextMessageError::Read(err)
63//!     }
64//! }
65//!
66//! impl From<DecodeError> for NextMessageError {
67//!     fn from(err: DecodeError) -> Self {
68//!         NextMessageError::Decode(err)
69//!     }
70//! }
71//!
72//! fn read(sock: &mut Socket) -> Result<Bytes, ReadError> {
73//!     Ok(Bytes)
74//! }
75//!
76//! fn decode(bytes: Bytes) -> Result<Message, DecodeError> {
77//!     Err(DecodeError)
78//! }
79//!
80//! fn next_message(sock: &mut Socket) -> Result<Message, NextMessageError> {
81//!     let payload = read(sock)?;
82//!     let message = decode(payload)?;
83//!     Ok(message)
84//! }
85//! ```
86//!
//! This proc-macro intends to ease the composition of APIs that do not share
//! the exact same errors by generating a new generic enum where each variant
//! can be disabled one by one. We can then redefine our API like so:
90//!
91//! ```
92//! # #![feature(never_type)]
93//! # mod example {
94//! # struct ConnectError;
95//! # struct ReadError;
96//! # struct DecodeError;
97//! # struct Socket;
98//! # struct Bytes;
99//! # struct Message;
100//! #[derive(partial_enum::Enum)]
101//! enum Error {
102//!     Connect(ConnectError),
103//!     Read(ReadError),
104//!     Decode(DecodeError),
105//! }
106//!
107//! use partial::Error as E;
108//!
109//! fn connect() -> Result<Socket, E<ConnectError, !, !>> {
110//!     Ok(Socket)
111//! }
112//!
113//! fn read(sock: &mut Socket) -> Result<Bytes, E<!, ReadError, !>> {
114//!     Ok(Bytes)
115//! }
116//!
117//! fn decode(bytes: Bytes) -> Result<Message, E<!, !, DecodeError>> {
118//!     Err(DecodeError)?
119//! }
120//!
121//! fn next_message(sock: &mut Socket) -> Result<Message, E<!, ReadError, DecodeError>> {
122//!     let payload = read(sock)?;
123//!     let message = decode(payload)?;
124//!     Ok(message)
125//! }
126//! # }
127//! ```
128//!
//! Notice that the `next_message` implementation is unaltered and the signature
//! clearly states that only `ReadError` and `DecodeError` can be returned. The
//! caller would never be able to match on `Error::Connect`. The `decode` implementation
//! uses the `?` operator to convert `DecodeError` to the partial enum. By using the
//! nightly feature `exhaustive_patterns`, the match statement does not even
//! need to list the disabled variants.
135//!
136//! ```
137//! #![feature(exhaustive_patterns)]
138//! # #![feature(never_type)]
139//! # mod example {
140//! # struct ConnectError;
141//! # struct ReadError;
142//! # struct DecodeError;
143//! # struct Socket;
144//! # struct Bytes;
145//! # struct Message;
146//! # #[derive(partial_enum::Enum)]
147//! # enum Error {
148//! #     Connect(ConnectError),
149//! #     Read(ReadError),
150//! #     Decode(DecodeError),
151//! # }
152//! # use partial::Error as E;
153//! # fn connect() -> Result<Socket, E<ConnectError, !, !>> { Ok(Socket) }
154//! # fn read(sock: &mut Socket) -> Result<Bytes, E<!, ReadError, !>> { Ok(Bytes) }
155//! # fn decode(bytes: Bytes) -> Result<Message, E<!, !, DecodeError>> { Err(DecodeError)? }
156//! # fn next_message(sock: &mut Socket) -> Result<Message, E<!, ReadError, DecodeError>> {
157//! #     let payload = read(sock)?;
158//! #     let message = decode(payload)?;
159//! #     Ok(message)
160//! # }
161//! fn read_one_message() -> Result<Message, Error> {
162//!     let mut socket = connect()?;
163//!     match next_message(&mut socket) {
164//!         Ok(msg) => Ok(msg),
165//!         Err(E::Read(_)) => {
166//!             // Retry...
167//!             next_message(&mut socket).map_err(Error::from)
168//!         }
169//!         Err(E::Decode(err)) => Err(Error::Decode(err)),
170//!     }
171//! }
172//! # }
173//! ```
174//!
175//! # Rust version
176//!
177//! By default, the empty placeholder is the unit type `()`. The generated code
178//! is compatible with the stable compiler. When the `never` feature is enabled,
179//! the never type `!` is used instead. This requires a nightly compiler and the
180//! nightly feature `#![feature(never_type)]`.
181
182extern crate proc_macro;
183use permutation::Permutations;
184use proc_macro::TokenStream;
185use proc_macro2::Span;
186use quote::ToTokens;
187use syn::{
188    parse::{Parse, ParseStream},
189    punctuated::Punctuated,
190    spanned::Spanned,
191    token::Paren,
192    Fields, Ident, ItemEnum, Token, Type, TypeNever, TypeTuple, Visibility,
193};
194
195mod permutation;
196
197/// Create the partial version of this enum.
198///
/// This macro generates another enum of the same name, in a sub-module called
/// `partial`. This enum has the same variant identifiers as the original but
/// each associated type is now generic: an enum with `N` variants will have `N`
/// generic parameters. Each of those types can be instantiated with either the
/// original type or the never type `!`. No other type can be substituted. This
/// effectively creates an enum capable of disabling several variants. The enum
/// with no disabled variant is functionally equivalent to the original enum.
206///
207/// # Restrictions
208///
209/// Some restrictions are applied on the original enum for the macro to work:
210///
/// * generic parameters are not supported
/// * named variants are not supported
/// * unit variants are not supported
/// * unnamed variants must contain exactly one type
215///
216/// # Example
217///
218/// The following `derive` statement:
219///
220/// ```
221/// # #![feature(never_type)]
222/// # mod example {
223/// # struct Foo;
224/// # struct Bar;
225/// #[derive(partial_enum::Enum)]
226/// enum Error {
227///     Foo(Foo),
228///     Bar(Bar),
229/// }
230/// # }
231/// ```
232///
233/// will generate the following enum:
234///
235/// ```
236/// mod partial {
237///     enum Error<Foo, Bar> {
238///         Foo(Foo),
239///         Bar(Bar),
240///     }
241/// }
242/// ```
243///
/// where `Foo` can only be instantiated by `Foo` or `!` and `Bar` can only be
/// instantiated by `Bar` or `!`. `From` implementations are provided for all
/// valid morphisms: such a conversion is valid if and only if, for each variant
/// type, we never go from a non-`!` type to the `!` type. This would otherwise
/// allow us to forget this variant and pretend we can never match on it. The
/// compiler would rightfully complain that we're trying to instantiate an
/// uninhabited type.
251#[proc_macro_derive(Enum)]
252pub fn derive_error(item: TokenStream) -> TokenStream {
253    let e: Enum = syn::parse_macro_input!(item as Enum);
254    e.to_tokens().to_token_stream().into()
255}
256
257struct Enum(PartialEnum);
258
259#[derive(Clone)]
260struct PartialEnum {
261    vis: Visibility,
262    ident: Ident,
263    variants: Vec<Variant>,
264}
265
266#[derive(Clone)]
267struct Variant {
268    ident: Ident,
269    typ: Type,
270}
271
272impl Parse for Enum {
273    fn parse(input: ParseStream) -> syn::Result<Self> {
274        let enum_: ItemEnum = input.parse()?;
275        if !enum_.generics.params.is_empty() {
276            return Err(syn::Error::new(
277                enum_.span(),
278                "generic parameters are not supported",
279            ));
280        }
281
282        let mut variants = vec![];
283        for variant in enum_.variants.into_iter() {
284            match variant.fields {
285                Fields::Named(_) => {
286                    return Err(syn::Error::new(
287                        variant.fields.span(),
288                        "named field is not supported",
289                    ))
290                }
291                Fields::Unnamed(ref fields) if fields.unnamed.len() != 1 => {
292                    return Err(syn::Error::new(
293                        variant.fields.span(),
294                        "only one field is supported",
295                    ))
296                }
297                Fields::Unnamed(mut fields) => {
298                    let field = fields.unnamed.pop().unwrap().into_value();
299                    variants.push(Variant {
300                        ident: variant.ident,
301                        typ: field.ty,
302                    });
303                }
304                Fields::Unit => {
305                    return Err(syn::Error::new(
306                        variant.fields.span(),
307                        "unit field is not supported",
308                    ))
309                }
310            }
311        }
312
313        Ok(Enum(PartialEnum {
314            vis: enum_.vis,
315            ident: enum_.ident,
316            variants,
317        }))
318    }
319}
320
321impl Enum {
322    fn to_tokens(&self) -> impl ToTokens {
323        let enum_vis = &self.vis;
324        let enum_name = quote::format_ident!("{}", self.ident);
325        let empty_type = empty_token();
326
327        let mut variant_generics = vec![];
328        let mut variant_traits = vec![];
329        let mut variant_idents = vec![];
330        let mut variant_types = vec![];
331        for variant in &self.variants {
332            variant_generics.push(quote::format_ident!("{}", variant.ident));
333            variant_traits.push(quote::format_ident!("{}Bound", variant.ident));
334            variant_idents.push(&variant.ident);
335            variant_types.push(&variant.typ);
336        }
337
338        let mut from_impls = vec![];
339        for to in self.generate_all_partial_enums() {
340            let to_type = to.enum_tokens();
341            for from in self.generate_convertible_partial_enums(&to) {
342                let from_type = from.enum_tokens();
343                from_impls.push(quote::quote!(
344                    impl From<#from_type> for #to_type {
345                        fn from(value: #from_type) -> Self {
346                            #[allow(unreachable_code)]
347                            match value {
348                                #(#enum_name::#variant_idents(x) => Self::#variant_idents(x),)*
349                            }
350                        }
351                    }
352                ));
353            }
354            from_impls.push(quote::quote!(
355                impl From<#to_type> for super::#enum_name {
356                    fn from(value: #to_type) -> Self {
357                        #[allow(unreachable_code)]
358                        match value {
359                            #(#enum_name::#variant_idents(x) => Self::#variant_idents(x),)*
360                        }
361                    }
362                }
363
364            ));
365        }
366
367        // Implement conversion from a single variant type to any partial enum.
368        // The only constrain is that the corresponding variant type cannot be
369        // generic.
370        for (idx, (variant_type, variant_ident)) in
371            variant_types.iter().zip(&variant_idents).enumerate()
372        {
373            // Generate the destination type which is the generic version of the
374            // partial enum with the concrete type as the `idx`th position.
375            let (left, mut right) = variant_generics.split_at(idx);
376            if let &[_, ref right_1 @ ..] = right {
377                right = right_1;
378            }
379            let to_type = quote::quote!(#enum_name<#(#left,)* #variant_type, #(#right),*>);
380
381            // The `idx`th generic parameter is removed because it is a concrete type for this conversion.
382            let mut variant_generics = variant_generics.clone();
383            let mut variant_traits = variant_traits.clone();
384            variant_generics.remove(idx);
385            variant_traits.remove(idx);
386
387            from_impls.push(quote::quote!(
388                impl<#(#variant_generics: #variant_traits),*> From<#variant_type> for #to_type {
389                    fn from(value: #variant_type) -> Self {
390                        Self::#variant_ident(value)
391                    }
392                }
393            ));
394        }
395
396        quote::quote!(
397            #enum_vis mod partial {
398                #(use super::#variant_types;)*
399
400                pub enum #enum_name<#(#variant_generics: #variant_traits),*> {
401                    #(#variant_idents(#variant_generics)),*
402                }
403
404                #(
405                pub trait #variant_traits {}
406                impl #variant_traits for #variant_types {}
407                impl #variant_traits for #empty_type {}
408                )*
409
410                #(#from_impls)*
411            }
412        )
413    }
414
415    fn generate_all_partial_enums(&self) -> Vec<PartialEnum> {
416        let span = Span::call_site();
417        let empty_type = if cfg!(feature = "never") {
418            Type::Never(TypeNever {
419                bang_token: Token![!]([span]),
420            })
421        } else {
422            Type::Tuple(TypeTuple {
423                paren_token: Paren { span },
424                elems: Punctuated::new(),
425            })
426        };
427
428        let mut enums = vec![];
429        for perm in Permutations::new(self.variants.len()) {
430            let mut enum_ = self.0.clone();
431            for (i, is_concrete) in perm.enumerate() {
432                if !is_concrete {
433                    enum_.variants[i].typ = empty_type.clone();
434                }
435            }
436            enums.push(enum_);
437        }
438        enums
439    }
440
441    fn generate_convertible_partial_enums(&self, to: &PartialEnum) -> Vec<PartialEnum> {
442        self.generate_all_partial_enums()
443            .into_iter()
444            .filter(|from| from.is_convertible_to(to))
445            .filter(|from| from != to)
446            .collect()
447    }
448}
449
450impl std::ops::Deref for Enum {
451    type Target = PartialEnum;
452    fn deref(&self) -> &Self::Target {
453        &self.0
454    }
455}
456
457impl PartialEq for PartialEnum {
458    fn eq(&self, other: &Self) -> bool {
459        self.ident == other.ident && self.variants == other.variants
460    }
461}
462
463impl PartialEnum {
464    fn enum_tokens(&self) -> impl ToTokens {
465        let enum_name = &self.ident;
466        let variant_types = self.variants.iter().map(|variant| &variant.typ);
467        quote::quote!(#enum_name<#(#variant_types,)*>)
468    }
469
470    fn is_convertible_to(&self, to: &PartialEnum) -> bool {
471        assert_eq!(self.variants.len(), to.variants.len());
472        for (from, to) in self.variants.iter().zip(&to.variants) {
473            if from.is_concrete() && to.is_never() {
474                return false;
475            }
476        }
477        true
478    }
479}
480
481impl Variant {
482    fn is_never(&self) -> bool {
483        matches!(self.typ, Type::Never(_))
484    }
485
486    fn is_concrete(&self) -> bool {
487        !self.is_never()
488    }
489}
490
491impl PartialEq for Variant {
492    fn eq(&self, other: &Self) -> bool {
493        self.ident == other.ident && self.is_concrete() == other.is_concrete()
494    }
495}
496
497fn empty_token() -> impl ToTokens {
498    if cfg!(feature = "never") {
499        quote::quote!(!)
500    } else {
501        quote::quote!(())
502    }
503}