attr_bounds/
lib.rs

1//! An attribute macro to stipulate bounds.
2//!
//! The attribute applies bounds to `struct`s, `enum`s, `union`s, `trait`s, `fn`s (including trait method declarations), `type` aliases, associated `type`s, and `impl` blocks.
4//!
5//! ```rust
6//! use attr_bounds::bounds;
7//!
8//! #[bounds(T: Copy)]
9//! pub struct Wrapper<T>(T);
10//!
11//! let var = Wrapper(42);
12//! ```
13//!
14//! ```compile_fail
15//! use attr_bounds::bounds;
16//!
17//! #[bounds(T: Copy)]
18//! pub struct Wrapper<T>(T);
19//!
20//! let var = Wrapper(Vec::<i32>::new());
21//! //                ^^^^^^^^^^^^^^^^^ the trait `Copy` is not implemented for `Vec<i32>`
22//! ```
23//!
24//! # Usage notes
25//!
//! The attribute is mainly intended for conditional compilation, e.g. combined with `cfg_attr` as below; when a bound is unconditional, you can simply write it in a regular `where` clause instead.
27//!
28//! ```rust
29//! use attr_bounds::bounds;
30//!
31//! #[cfg(feature = "unstable_feature_a")]
32//! pub trait UnstableA {}
33//! #[cfg(feature = "unstable_feature_b")]
34//! pub trait UnstableB {}
35//!
36//! #[cfg_attr(feature = "unstable_feature_a", bounds(Self: UnstableA))]
37//! #[cfg_attr(feature = "unstable_feature_b", bounds(Self: UnstableB))]
38//! pub trait Trait {}
39//!
40//! #[cfg(feature = "unstable_feature_a")]
41//! impl UnstableA for () {}
42//! #[cfg(feature = "unstable_feature_b")]
43//! impl UnstableB for () {}
44//!
45//! impl Trait for () {}
46//! ```
47
48use proc_macro2::TokenStream;
49use quote::{quote, ToTokens};
50use syn::{
51    parse::{discouraged::Speculative, Parse},
52    parse_macro_input,
53    punctuated::Punctuated,
54    ItemEnum, ItemFn, ItemImpl, ItemStruct, ItemTrait, ItemType, ItemUnion, Signature, Token,
55    TraitItemFn, TraitItemType, WhereClause, WherePredicate,
56};
57
58enum Item {
59    Enum(ItemEnum),
60    Fn(ItemFn),
61    Impl(ItemImpl),
62    Struct(ItemStruct),
63    Trait(ItemTrait),
64    Type(ItemType),
65    Union(ItemUnion),
66    AssocType(TraitItemType),
67    FnDecl(TraitItemFn),
68}
69
70impl Item {
71    fn make_where_clause(&mut self) -> &mut WhereClause {
72        let generics = match self {
73            Item::Enum(ItemEnum { generics, .. })
74            | Item::Fn(ItemFn {
75                sig: Signature { generics, .. },
76                ..
77            })
78            | Item::Impl(ItemImpl { generics, .. })
79            | Item::Struct(ItemStruct { generics, .. })
80            | Item::Trait(ItemTrait { generics, .. })
81            | Item::Type(ItemType { generics, .. })
82            | Item::Union(ItemUnion { generics, .. })
83            | Item::AssocType(TraitItemType { generics, .. })
84            | Item::FnDecl(TraitItemFn {
85                sig: Signature { generics, .. },
86                ..
87            }) => generics,
88        };
89        generics.make_where_clause()
90    }
91}
92
93impl Parse for Item {
94    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
95        let fork = input.fork();
96        if let Ok(item) = fork
97            .parse::<syn::Item>()
98            .map_or(Err(()), |item| match item {
99                syn::Item::Enum(item) => Ok(Item::Enum(item)),
100                syn::Item::Fn(item) => Ok(Item::Fn(item)),
101                syn::Item::Impl(item) => Ok(Item::Impl(item)),
102                syn::Item::Struct(item) => Ok(Item::Struct(item)),
103                syn::Item::Trait(item) => Ok(Item::Trait(item)),
104                syn::Item::Type(item) => Ok(Item::Type(item)),
105                syn::Item::Union(item) => Ok(Item::Union(item)),
106                _ => Err(()),
107            })
108        {
109            input.advance_to(&fork);
110            return Ok(item);
111        }
112
113        if let Ok(item) = input
114            .parse::<syn::TraitItem>()
115            .map_or(Err(()), |item| match item {
116                syn::TraitItem::Fn(item) => Ok(Item::FnDecl(item)),
117                syn::TraitItem::Type(item) => Ok(Item::AssocType(item)),
118                _ => Err(()),
119            })
120        {
121            return Ok(item);
122        }
123
124        Err(input.error("Unexpected item."))
125    }
126}
127
128impl ToTokens for Item {
129    fn to_tokens(&self, tokens: &mut TokenStream) {
130        match self {
131            Item::Enum(item) => item.to_tokens(tokens),
132            Item::Fn(item) => item.to_tokens(tokens),
133            Item::Impl(item) => item.to_tokens(tokens),
134            Item::Struct(item) => item.to_tokens(tokens),
135            Item::Trait(item) => item.to_tokens(tokens),
136            Item::Type(item) => item.to_tokens(tokens),
137            Item::Union(item) => item.to_tokens(tokens),
138            Item::AssocType(item) => item.to_tokens(tokens),
139            Item::FnDecl(item) => item.to_tokens(tokens),
140        }
141    }
142}
143
144/// Applies bounds to an item.
145///
146/// You can specify bounds with <i>[WhereClauseItem]</i>s.
147///
148/// # Examples
149/// ```rust
150/// use attr_bounds::bounds;
151///
152/// #[bounds(
153///     A: Clone,
154///     for<'a> &'a A: std::ops::Add<&'a A, Output = A>,
155///     B: Clone,
156/// )]
157/// pub struct Pair<A, B>(A, B);
158///
159/// let pair = Pair(42, vec!['a', 'b', 'c']);
160/// ```
161///
162/// [WhereClauseItem]: https://doc.rust-lang.org/reference/items/generics.html#where-clauses
163#[proc_macro_attribute]
164pub fn bounds(
165    attr: proc_macro::TokenStream,
166    input: proc_macro::TokenStream,
167) -> proc_macro::TokenStream {
168    let parser = Punctuated::<WherePredicate, Token![,]>::parse_terminated;
169    let attr = parse_macro_input!(attr with parser);
170
171    match syn::parse::<Item>(input) {
172        Ok(mut item) => {
173            let where_clause = item.make_where_clause();
174            where_clause.predicates.extend(attr);
175            item.into_token_stream().into()
176        }
177        Err(_) => {
178            // Using the compile_error!() macro to highlight the attribute in reporting an error.
179            quote! {
180                compile_error!("The attribute may only be applied to `struct`s, `enum`s, `union`s, `trait`s, `fn`s, `type`s, and `impl` blocks.");
181            }.into()
182        }
183    }
184}