partial_enum/lib.rs
1#![feature(never_type)]
2#![feature(exhaustive_patterns)]
3
//! A proc-macro for generating partial enums from a template enum. A partial
//! enum contains the same number of variants as the template but can disable a
//! subset of these variants at compile time. The goal is to specialize an enum
//! with a finer-grained set of variants for each API.
8//!
//! This is useful for handling errors. A common pattern is to define an enum
//! with all possible errors and use it for the entire API surface. Although
//! simple, this representation cannot express the exact error scenarios of
//! each function, because it allows errors that cannot happen.
13//!
14//! Take an API responsible for decoding messages from a socket.
15//!
16//! ```
17//! # struct ConnectError;
18//! # struct ReadError;
19//! # struct DecodeError;
20//! # struct Socket;
21//! # struct Bytes;
22//! # struct Message;
23//! enum Error {
24//! Connect(ConnectError),
25//! Read(ReadError),
26//! Decode(DecodeError),
27//! }
28//!
29//! fn connect() -> Result<Socket, Error> {
30//! Ok(Socket)
31//! }
32//!
33//! fn read(sock: &mut Socket) -> Result<Bytes, Error> {
34//! Ok(Bytes)
35//! }
36//!
37//! fn decode(bytes: Bytes) -> Result<Message, Error> {
38//! Err(Error::Decode(DecodeError))
39//! }
40//! ```
41//!
//! The same error enum is used everywhere and exposes variants that do not
//! match the API: `decode` only returns a `DecodeError`, but nothing prevents
//! it from returning a `ConnectError`. For such low-level APIs, we could
//! replace `Error` with the matching error type, e.g. `ConnectError` for
//! `connect`. The downside is that composing such functions forces us to
//! define custom enums:
48//!
49//! ```
50//! # struct ReadError;
51//! # struct DecodeError;
52//! # struct Socket;
53//! # struct Bytes;
54//! # struct Message;
55//! enum NextMessageError {
56//! Read(ReadError),
57//! Decode(DecodeError),
58//! }
59//!
60//! impl From<ReadError> for NextMessageError {
61//! fn from(err: ReadError) -> Self {
62//! NextMessageError::Read(err)
63//! }
64//! }
65//!
66//! impl From<DecodeError> for NextMessageError {
67//! fn from(err: DecodeError) -> Self {
68//! NextMessageError::Decode(err)
69//! }
70//! }
71//!
72//! fn read(sock: &mut Socket) -> Result<Bytes, ReadError> {
73//! Ok(Bytes)
74//! }
75//!
76//! fn decode(bytes: Bytes) -> Result<Message, DecodeError> {
77//! Err(DecodeError)
78//! }
79//!
80//! fn next_message(sock: &mut Socket) -> Result<Message, NextMessageError> {
81//! let payload = read(sock)?;
82//! let message = decode(payload)?;
83//! Ok(message)
84//! }
85//! ```
86//!
//! This proc-macro intends to ease the composition of APIs that do not share
//! the exact same errors by generating a new generic enum in which each
//! variant can be disabled individually. We can then redefine our API like so:
90//!
91//! ```
92//! # #![feature(never_type)]
93//! # mod example {
94//! # struct ConnectError;
95//! # struct ReadError;
96//! # struct DecodeError;
97//! # struct Socket;
98//! # struct Bytes;
99//! # struct Message;
100//! #[derive(partial_enum::Enum)]
101//! enum Error {
102//! Connect(ConnectError),
103//! Read(ReadError),
104//! Decode(DecodeError),
105//! }
106//!
107//! use partial::Error as E;
108//!
109//! fn connect() -> Result<Socket, E<ConnectError, !, !>> {
110//! Ok(Socket)
111//! }
112//!
113//! fn read(sock: &mut Socket) -> Result<Bytes, E<!, ReadError, !>> {
114//! Ok(Bytes)
115//! }
116//!
117//! fn decode(bytes: Bytes) -> Result<Message, E<!, !, DecodeError>> {
118//! Err(DecodeError)?
119//! }
120//!
121//! fn next_message(sock: &mut Socket) -> Result<Message, E<!, ReadError, DecodeError>> {
122//! let payload = read(sock)?;
123//! let message = decode(payload)?;
124//! Ok(message)
125//! }
126//! # }
127//! ```
128//!
//! Notice that the `next_message` implementation is unaltered and the signature
//! clearly states that only `ReadError` and `DecodeError` can be returned. A
//! caller never has to handle the `Connect` variant. The `decode` implementation
//! uses the `?` operator to convert `DecodeError` into the partial enum. With the
//! nightly feature `exhaustive_patterns`, a `match` does not even need to list
//! the disabled variants.
135//!
136//! ```
137//! #![feature(exhaustive_patterns)]
138//! # #![feature(never_type)]
139//! # mod example {
140//! # struct ConnectError;
141//! # struct ReadError;
142//! # struct DecodeError;
143//! # struct Socket;
144//! # struct Bytes;
145//! # struct Message;
146//! # #[derive(partial_enum::Enum)]
147//! # enum Error {
148//! # Connect(ConnectError),
149//! # Read(ReadError),
150//! # Decode(DecodeError),
151//! # }
152//! # use partial::Error as E;
153//! # fn connect() -> Result<Socket, E<ConnectError, !, !>> { Ok(Socket) }
154//! # fn read(sock: &mut Socket) -> Result<Bytes, E<!, ReadError, !>> { Ok(Bytes) }
155//! # fn decode(bytes: Bytes) -> Result<Message, E<!, !, DecodeError>> { Err(DecodeError)? }
156//! # fn next_message(sock: &mut Socket) -> Result<Message, E<!, ReadError, DecodeError>> {
157//! # let payload = read(sock)?;
158//! # let message = decode(payload)?;
159//! # Ok(message)
160//! # }
161//! fn read_one_message() -> Result<Message, Error> {
162//! let mut socket = connect()?;
163//! match next_message(&mut socket) {
164//! Ok(msg) => Ok(msg),
165//! Err(E::Read(_)) => {
166//! // Retry...
167//! next_message(&mut socket).map_err(Error::from)
168//! }
169//! Err(E::Decode(err)) => Err(Error::Decode(err)),
170//! }
171//! }
172//! # }
173//! ```
174//!
175//! # Rust version
176//!
177//! By default, the empty placeholder is the unit type `()`. The generated code
178//! is compatible with the stable compiler. When the `never` feature is enabled,
179//! the never type `!` is used instead. This requires a nightly compiler and the
180//! nightly feature `#![feature(never_type)]`.
181
182extern crate proc_macro;
183use permutation::Permutations;
184use proc_macro::TokenStream;
185use proc_macro2::Span;
186use quote::ToTokens;
187use syn::{
188 parse::{Parse, ParseStream},
189 punctuated::Punctuated,
190 spanned::Spanned,
191 token::Paren,
192 Fields, Ident, ItemEnum, Token, Type, TypeNever, TypeTuple, Visibility,
193};
194
195mod permutation;
196
197/// Create the partial version of this enum.
198///
199/// This macro generates another enum of the same name, in a sub-module called
200/// `partial`. This enum have the same variant identifiers as the original but
201/// each associated type is now generic: an enum with `N` variants will have `N`
202/// generic parameters. Each of those types can be instantiated with either the
203/// original type or the never type `!`. No other type can be substituted. This
204/// effectively creates an enum capable of disabling several variants. The enum
205/// with no disabled variant is functionally equivalent to the original enum.
206///
207/// # Restrictions
208///
209/// Some restrictions are applied on the original enum for the macro to work:
210///
211/// * generic parameters are not supported
212/// * named variant are not supported
213/// * unit variant are not supported
214/// * unnamed variants must only contain one type
215///
216/// # Example
217///
218/// The following `derive` statement:
219///
220/// ```
221/// # #![feature(never_type)]
222/// # mod example {
223/// # struct Foo;
224/// # struct Bar;
225/// #[derive(partial_enum::Enum)]
226/// enum Error {
227/// Foo(Foo),
228/// Bar(Bar),
229/// }
230/// # }
231/// ```
232///
233/// will generate the following enum:
234///
235/// ```
236/// mod partial {
237/// enum Error<Foo, Bar> {
238/// Foo(Foo),
239/// Bar(Bar),
240/// }
241/// }
242/// ```
243///
244/// where `Foo` can only be instantiated by `Foo` or `!` and `Bar` can only be
245/// instantiated by `Bar` or `!`. `From` implementations are provided for all
246/// valid morphisms: such conversion is valid if and only if, for each variant
247/// type, we never go from a non-`!` type to the `!` type. This would otherwise
248/// allow to forget this variant and pretend we can never match on it. The
249/// compiler will rightfully complains that we're trying to instantiate an
250/// uninhabited type.
251#[proc_macro_derive(Enum)]
252pub fn derive_error(item: TokenStream) -> TokenStream {
253 let e: Enum = syn::parse_macro_input!(item as Enum);
254 e.to_tokens().to_token_stream().into()
255}
256
257struct Enum(PartialEnum);
258
259#[derive(Clone)]
260struct PartialEnum {
261 vis: Visibility,
262 ident: Ident,
263 variants: Vec<Variant>,
264}
265
266#[derive(Clone)]
267struct Variant {
268 ident: Ident,
269 typ: Type,
270}
271
272impl Parse for Enum {
273 fn parse(input: ParseStream) -> syn::Result<Self> {
274 let enum_: ItemEnum = input.parse()?;
275 if !enum_.generics.params.is_empty() {
276 return Err(syn::Error::new(
277 enum_.span(),
278 "generic parameters are not supported",
279 ));
280 }
281
282 let mut variants = vec![];
283 for variant in enum_.variants.into_iter() {
284 match variant.fields {
285 Fields::Named(_) => {
286 return Err(syn::Error::new(
287 variant.fields.span(),
288 "named field is not supported",
289 ))
290 }
291 Fields::Unnamed(ref fields) if fields.unnamed.len() != 1 => {
292 return Err(syn::Error::new(
293 variant.fields.span(),
294 "only one field is supported",
295 ))
296 }
297 Fields::Unnamed(mut fields) => {
298 let field = fields.unnamed.pop().unwrap().into_value();
299 variants.push(Variant {
300 ident: variant.ident,
301 typ: field.ty,
302 });
303 }
304 Fields::Unit => {
305 return Err(syn::Error::new(
306 variant.fields.span(),
307 "unit field is not supported",
308 ))
309 }
310 }
311 }
312
313 Ok(Enum(PartialEnum {
314 vis: enum_.vis,
315 ident: enum_.ident,
316 variants,
317 }))
318 }
319}
320
321impl Enum {
322 fn to_tokens(&self) -> impl ToTokens {
323 let enum_vis = &self.vis;
324 let enum_name = quote::format_ident!("{}", self.ident);
325 let empty_type = empty_token();
326
327 let mut variant_generics = vec![];
328 let mut variant_traits = vec![];
329 let mut variant_idents = vec![];
330 let mut variant_types = vec![];
331 for variant in &self.variants {
332 variant_generics.push(quote::format_ident!("{}", variant.ident));
333 variant_traits.push(quote::format_ident!("{}Bound", variant.ident));
334 variant_idents.push(&variant.ident);
335 variant_types.push(&variant.typ);
336 }
337
338 let mut from_impls = vec![];
339 for to in self.generate_all_partial_enums() {
340 let to_type = to.enum_tokens();
341 for from in self.generate_convertible_partial_enums(&to) {
342 let from_type = from.enum_tokens();
343 from_impls.push(quote::quote!(
344 impl From<#from_type> for #to_type {
345 fn from(value: #from_type) -> Self {
346 #[allow(unreachable_code)]
347 match value {
348 #(#enum_name::#variant_idents(x) => Self::#variant_idents(x),)*
349 }
350 }
351 }
352 ));
353 }
354 from_impls.push(quote::quote!(
355 impl From<#to_type> for super::#enum_name {
356 fn from(value: #to_type) -> Self {
357 #[allow(unreachable_code)]
358 match value {
359 #(#enum_name::#variant_idents(x) => Self::#variant_idents(x),)*
360 }
361 }
362 }
363
364 ));
365 }
366
367 // Implement conversion from a single variant type to any partial enum.
368 // The only constrain is that the corresponding variant type cannot be
369 // generic.
370 for (idx, (variant_type, variant_ident)) in
371 variant_types.iter().zip(&variant_idents).enumerate()
372 {
373 // Generate the destination type which is the generic version of the
374 // partial enum with the concrete type as the `idx`th position.
375 let (left, mut right) = variant_generics.split_at(idx);
376 if let &[_, ref right_1 @ ..] = right {
377 right = right_1;
378 }
379 let to_type = quote::quote!(#enum_name<#(#left,)* #variant_type, #(#right),*>);
380
381 // The `idx`th generic parameter is removed because it is a concrete type for this conversion.
382 let mut variant_generics = variant_generics.clone();
383 let mut variant_traits = variant_traits.clone();
384 variant_generics.remove(idx);
385 variant_traits.remove(idx);
386
387 from_impls.push(quote::quote!(
388 impl<#(#variant_generics: #variant_traits),*> From<#variant_type> for #to_type {
389 fn from(value: #variant_type) -> Self {
390 Self::#variant_ident(value)
391 }
392 }
393 ));
394 }
395
396 quote::quote!(
397 #enum_vis mod partial {
398 #(use super::#variant_types;)*
399
400 pub enum #enum_name<#(#variant_generics: #variant_traits),*> {
401 #(#variant_idents(#variant_generics)),*
402 }
403
404 #(
405 pub trait #variant_traits {}
406 impl #variant_traits for #variant_types {}
407 impl #variant_traits for #empty_type {}
408 )*
409
410 #(#from_impls)*
411 }
412 )
413 }
414
415 fn generate_all_partial_enums(&self) -> Vec<PartialEnum> {
416 let span = Span::call_site();
417 let empty_type = if cfg!(feature = "never") {
418 Type::Never(TypeNever {
419 bang_token: Token,
420 })
421 } else {
422 Type::Tuple(TypeTuple {
423 paren_token: Paren { span },
424 elems: Punctuated::new(),
425 })
426 };
427
428 let mut enums = vec![];
429 for perm in Permutations::new(self.variants.len()) {
430 let mut enum_ = self.0.clone();
431 for (i, is_concrete) in perm.enumerate() {
432 if !is_concrete {
433 enum_.variants[i].typ = empty_type.clone();
434 }
435 }
436 enums.push(enum_);
437 }
438 enums
439 }
440
441 fn generate_convertible_partial_enums(&self, to: &PartialEnum) -> Vec<PartialEnum> {
442 self.generate_all_partial_enums()
443 .into_iter()
444 .filter(|from| from.is_convertible_to(to))
445 .filter(|from| from != to)
446 .collect()
447 }
448}
449
450impl std::ops::Deref for Enum {
451 type Target = PartialEnum;
452 fn deref(&self) -> &Self::Target {
453 &self.0
454 }
455}
456
457impl PartialEq for PartialEnum {
458 fn eq(&self, other: &Self) -> bool {
459 self.ident == other.ident && self.variants == other.variants
460 }
461}
462
463impl PartialEnum {
464 fn enum_tokens(&self) -> impl ToTokens {
465 let enum_name = &self.ident;
466 let variant_types = self.variants.iter().map(|variant| &variant.typ);
467 quote::quote!(#enum_name<#(#variant_types,)*>)
468 }
469
470 fn is_convertible_to(&self, to: &PartialEnum) -> bool {
471 assert_eq!(self.variants.len(), to.variants.len());
472 for (from, to) in self.variants.iter().zip(&to.variants) {
473 if from.is_concrete() && to.is_never() {
474 return false;
475 }
476 }
477 true
478 }
479}
480
481impl Variant {
482 fn is_never(&self) -> bool {
483 matches!(self.typ, Type::Never(_))
484 }
485
486 fn is_concrete(&self) -> bool {
487 !self.is_never()
488 }
489}
490
491impl PartialEq for Variant {
492 fn eq(&self, other: &Self) -> bool {
493 self.ident == other.ident && self.is_concrete() == other.is_concrete()
494 }
495}
496
497fn empty_token() -> impl ToTokens {
498 if cfg!(feature = "never") {
499 quote::quote!(!)
500 } else {
501 quote::quote!(())
502 }
503}