multi_default_trait_impl/
lib.rs

//! Define multiple default implementations for a trait.
//!
//! This library contains two attribute macros: `default_trait_impl`, which defines a default trait
//! implementation, and `trait_impl`, which uses a default trait implementation you've defined.
//!
//! This is particularly useful in testing, when many of your mocked types will have very similar
//! trait implementations, but you do not want the canonical default trait implementation to use
//! mocked values.
//!
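//! # Example
//!
//! The snippets below are fragments: the `Car` trait definition and the imports of the two
//! macros are omitted, so they are marked `ignore` and are not compiled as doctests.
//!
//! First, define a default trait implementation for the trait `Car`:
//!
//! ```ignore
//! #[default_trait_impl]
//! impl Car for NewCar {
//!     fn get_mileage(&self) -> Option<usize> { Some(6000) }
//!     fn has_bluetooth(&self) -> bool { true }
//! }
//! ```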
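//!
//! `NewCar` does not need to be defined beforehand. The attribute consumes the `impl` block and
//! expands to nothing, so `NewCar` never becomes a real type; it only names this set of defaults.
//! Because the defaults are recorded while the compiler expands macros, the
//! `#[default_trait_impl]` block must be expanded before any `#[trait_impl]` block that uses it.
//!
//! Next, use the `NewCar` defaults to implement `Car` for your types, overriding only what differs:
//!
//! ```ignore
//! struct NewOldFashionedCar;
//!
//! #[trait_impl]
//! impl NewCar for NewOldFashionedCar {
//!     fn has_bluetooth(&self) -> bool { false }
//! }
//!
//! struct WellUsedNewCar;
//!
//! #[trait_impl]
//! impl NewCar for WellUsedNewCar {
//!     fn get_mileage(&self) -> Option<usize> { Some(100000) }
//! }
//! ```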
//!
//! Both structs now use the `NewCar` defaults for every method they do not override, without any
//! change to the canonical `Car` default implementation:
//!
//! ```ignore
//! fn main() {
//!     assert_eq!(NewOldFashionedCar.get_mileage(), Some(6000));
//!     assert_eq!(NewOldFashionedCar.has_bluetooth(), false);
//!     assert_eq!(WellUsedNewCar.get_mileage(), Some(100000));
//!     assert_eq!(WellUsedNewCar.has_bluetooth(), true);
//! }
//! ```
54
55extern crate proc_macro;
56use proc_macro::TokenStream;
57use syn::{parse_macro_input, parse_str, Ident, ImplItem, ImplItemMethod, ItemImpl, Type};
58use quote::quote;
59use std::collections::{HashSet, HashMap};
60use std::sync::Mutex;
61use proc_macro2::Span;
62
63#[macro_use]
64extern crate lazy_static;
65
66lazy_static!{
67    static ref DEFAULT_TRAIT_IMPLS: Mutex<HashMap<String, DefaultTraitImpl>> = Mutex::new(HashMap::new());
68}
69
/// One set of defaults recorded by `#[default_trait_impl]`, stored in the registry under the
/// pseudo-trait name (the self type of the default `impl` block).
struct DefaultTraitImpl {
    /// Every item of the default `impl` block, each rendered to a token string. Despite the
    /// field name, this holds all items of the block, not only methods.
    pub methods: Vec<String>,
    /// Name of the real trait these defaults implement.
    pub trait_name: String,
}
74
75#[proc_macro_attribute]
76pub fn default_trait_impl(_: TokenStream, input: TokenStream) -> TokenStream {
77    let input = parse_macro_input!(input as ItemImpl);
78
79    let pseudotrait = match *input.self_ty {
80        Type::Path(type_path) => {
81            match type_path.path.get_ident() {
82                Some(ident) => ident.to_string(),
83                None => return syntax_invalid_error(),
84            }
85        },
86        _ => return syntax_invalid_error(),
87    };
88
89    let trait_name = match input.trait_ {
90        Some(trait_tuple) => {
91            match trait_tuple.1.get_ident() {
92                Some(ident) => ident.to_string(),
93                None => return syntax_invalid_error(),
94            }
95        },
96        _ => return syntax_invalid_error(),
97    };
98
99    let methods: Vec<String> = input.items.iter().map(|method| {
100        return quote! {
101            #method
102        }.to_string()
103    }).collect();
104
105    DEFAULT_TRAIT_IMPLS.lock().unwrap().insert(pseudotrait, DefaultTraitImpl { trait_name, methods });
106
107    TokenStream::new()
108}
109
110fn syntax_invalid_error() -> TokenStream {
111    return quote! {
112        compile_error!("`default_trait_impl` expects to be given a syntactially valid trait implementation");
113    }.into()
114}
115
116#[proc_macro_attribute]
117pub fn trait_impl(_: TokenStream, input: TokenStream) -> TokenStream {
118    let mut input = parse_macro_input!(input as ItemImpl);
119
120    let trait_name = match &input.trait_ {
121        Some(trait_tuple) => {
122            match trait_tuple.1.get_ident() {
123                Some(ident) => ident.to_string(),
124                None => return syntax_invalid_error(),
125            }
126        },
127        _ => return syntax_invalid_error(),
128    };
129
130    let mut methods = HashSet::new();
131    for item in &input.items {
132        if let ImplItem::Method(method) = item {
133            methods.insert(method.sig.ident.to_string());
134        }
135    }
136
137    match DEFAULT_TRAIT_IMPLS.lock().unwrap().get(&trait_name) {
138        Some(default_impl) => {
139            if let Some(trait_tuple) = &mut input.trait_ {
140                trait_tuple.1.segments[0].ident = Ident::new(&default_impl.trait_name, Span::call_site());
141            }
142
143            for default_impl_method in &default_impl.methods {
144                let parsed_default_method: ImplItemMethod = parse_str(default_impl_method).unwrap();
145                if !methods.contains(&parsed_default_method.sig.ident.to_string()) {
146                    input.items.push(ImplItem::Method(parsed_default_method));
147                }
148            }
149        },
150        _ => return quote! {
151            compile_error!("`trait_impl` expects there to be a `default_trait_impl` for the trait it implements");
152        }.into()
153    }
154
155    let res = quote! {
156        #input
157    };
158    res.into()
159}