inherit_methods_macro/lib.rs
1//! Inherit methods from a field automatically (via procedural macros).
2//!
3//! # Motivation
4//!
5//! While Rust is partially inspired by the object-oriented programming (OOP) paradigm
6//! and has some typical OOP features (like objects, encapsulation, and polymorphism),
//! it is not an OOP language. One piece of evidence is the lack of _inheritance_, which is an
//! important pillar of OOP. But don't get me wrong: this lack of inheritance is actually a
9//! good thing since it promotes the practice of
10//! [_composition over inheritance_](https://en.wikipedia.org/wiki/Composition_over_inheritance)
11//! in Rust programs. Despite all the benefits of composition, Rust programmers
//! have to write trivial [forwarding methods](https://en.wikipedia.org/wiki/Forwarding_(object-oriented_programming)),
13//! which is a tedious task, especially when you have to write many of them.
14//!
15//! To address this pain point of using composition in Rust, the crate provides a convenient
//! procedural macro that generates forwarding methods automatically for you. In other words,
17//! your structs can now "inherit" methods from their fields, enjoying the best of both worlds:
18//! the convenience of inheritance and the flexibility of composition.
19//!
20//! # Examples
21//!
//! ## Implementing the newtype idiom
23//!
24//! Suppose that you want to create a new struct named `Stack<T>`, which can be implemented by
25//! simply wrapping around `Vec<T>` and exposing only a subset of the APIs of `Vec`. Here is
26//! how this crate can help you do it easily.
27//!
28//! ```rust
29//! use inherit_methods_macro::inherit_methods;
30//!
31//! pub struct Stack<T>(Vec<T>);
32//!
33//! // Annotate an impl block with #[inherit_methods(from = "...")] to enable automatically
//! // inheriting methods from a field, which is specified by the `from` attribute.
35//! #[inherit_methods(from = "self.0")]
//! // This prevents rustfmt from issuing false alarms due to the way that this crate extends
37//! // the Rust syntax (i.e., allowing method definitions without code blocks).
38//! #[rustfmt::skip]
39//! impl<T> Stack<T> {
40//! // Normal methods can be implemented with inherited methods in the same impl block.
41//! pub fn new() -> Self {
42//! Self(Vec::new())
43//! }
44//!
45//! // All methods without code blocks will "inherit" the implementation of Vec by
46//! // forwarding their method calls to self.0.
47//! pub fn push(&mut self, value: T);
48//! pub fn pop(&mut self) -> Option<T>;
49//! pub fn len(&self) -> usize;
50//! }
51//! ```
52//!
53//! If you want to derive common traits (like `AsRef` and `Deref`) for a wrapper type, check out
54//! the [shrinkwraprs](https://crates.io/crates/shrinkwraprs) crate.
55//!
56//! ## Emulating the classic OOP inheritance
57//!
58//! In many OOP frameworks or applications, it is useful to have a base class from which all objects
59//! inherit. In this example, we would like to do the same thing, creating a base class
//! (the `Object` trait for the interface and the `ObjectBase` struct for the implementation)
//! that all objects should "inherit".
62//!
63//! ```rust
64//! use std::sync::atomic::{AtomicU64, Ordering};
65//! use std::sync::Mutex;
66//!
67//! use inherit_methods_macro::inherit_methods;
68//!
69//! pub trait Object {
70//! fn type_name(&self) -> &'static str;
71//! fn object_id(&self) -> u64;
72//! fn name(&self) -> String;
73//! fn set_name(&self, new_name: String);
74//! }
75//!
76//! struct ObjectBase {
77//! object_id: u64,
78//! name: Mutex<String>,
79//! }
80//!
81//! impl ObjectBase {
82//! pub fn new() -> Self {
83//! static NEXT_ID: AtomicU64 = AtomicU64::new(0);
84//! Self {
85//! object_id: NEXT_ID.fetch_add(1, Ordering::Relaxed),
86//! name: Mutex::new(String::new()),
87//! }
88//! }
89//!
90//! pub fn object_id(&self) -> u64 {
91//! self.object_id
92//! }
93//!
94//! pub fn name(&self) -> String {
95//! self.name.lock().unwrap().clone()
96//! }
97//!
98//! pub fn set_name(&self, new_name: String) {
99//! *self.name.lock().unwrap() = new_name;
100//! }
101//! }
102//!
103//! struct DummyObject {
104//! base: ObjectBase,
105//! }
106//!
107//! impl DummyObject {
108//! pub fn new() -> Self {
109//! Self {
110//! base: ObjectBase::new(),
111//! }
112//! }
113//! }
114//!
115//! #[inherit_methods(from = "self.base")]
116//! #[rustfmt::skip]
117//! impl Object for DummyObject {
118//! // Give this method an implementation specific to this type
119//! fn type_name(&self) -> &'static str {
120//! "DummyObject"
121//! }
122//!
123//! // Inherit methods from the base class
124//! fn object_id(&self) -> u64;
125//! fn name(&self) -> String;
126//! fn set_name(&self, new_name: String);
127//! }
128//! ```
129
130// TODO: fix the compatibility issue with cargo-fmt.
131
132extern crate proc_macro;
133
134use darling::FromMeta;
135use proc_macro2::{Punct, Spacing, TokenStream};
136use quote::{quote, ToTokens, TokenStreamExt};
137use syn::{
138 AttributeArgs, Block, Expr, FnArg, Ident, ImplItem, ImplItemMethod, Item, ItemImpl, Pat, Stmt,
139};
140
/// Arguments accepted by `#[inherit_methods(...)]`, parsed from the attribute with darling.
#[derive(Debug, FromMeta)]
struct MacroAttr {
    /// Source text of the expression that inherited methods forward their calls to
    /// (e.g. `"self.0"` or `"self.base"`). It is parsed as a Rust expression during
    /// expansion. When the argument is omitted, `default_from_val` supplies `"self.0"`.
    #[darling(default = "default_from_val")]
    from: String,
}
145}
146
147fn default_from_val() -> String {
148 "self.0".to_string()
149}
150
151#[proc_macro_attribute]
152pub fn inherit_methods(
153 attr: proc_macro::TokenStream,
154 item: proc_macro::TokenStream,
155) -> proc_macro::TokenStream {
156 let attr = {
157 let attr_tokens = syn::parse_macro_input!(attr as AttributeArgs);
158 match MacroAttr::from_list(&attr_tokens) {
159 Ok(attr) => attr,
160 Err(e) => {
161 return e.write_errors().into();
162 }
163 }
164 };
165 let item_impl = syn::parse_macro_input!(item as syn::ItemImpl);
166 do_inherit_methods(attr, item_impl).into()
167}
168
169fn do_inherit_methods(attr: MacroAttr, mut item_impl: ItemImpl) -> TokenStream {
170 // Parse the field to which we will forward method calls
171 let field: Expr = syn::parse_str(&attr.from).unwrap();
172
173 // Transform this impl item by adding method forwarding code to inherited methods.
174 for impl_item in &mut item_impl.items {
175 let impl_item_method = match is_method_missing_fn_block(impl_item) {
176 Some(method) => method,
177 None => continue,
178 };
179 add_fn_block(impl_item_method, &field);
180 }
181 item_impl.into_token_stream()
182}
183
184// Returns whether an item inside `impl XXX { ... }` is a method without code block.
185fn is_method_missing_fn_block(impl_item: &mut ImplItem) -> Option<&mut ImplItemMethod> {
186 // We only care about method items.
187 let impl_item_method = if let ImplItem::Method(method) = impl_item {
188 method
189 } else {
190 return None;
191 };
192 // We only care about methods without a code block.
193 if !impl_item_method.block.is_empty() {
194 return None;
195 }
196 Some(impl_item_method)
197}
198
199// Add a code block of method forwarding for the method item.
200fn add_fn_block(impl_item_method: &mut ImplItemMethod, field: &Expr) {
201 let fn_sig = &impl_item_method.sig;
202 let fn_name = &fn_sig.ident;
203 let fn_arg_tokens = {
204 // Extract all argument idents (except self) from the signature
205 let fn_arg_idents: Vec<&Ident> = fn_sig
206 .inputs
207 .iter()
208 .filter_map(|fn_arg| match fn_arg {
209 FnArg::Receiver(_) => None,
210 FnArg::Typed(pat_type) => Some(pat_type),
211 })
212 .filter_map(|pat_type| match &*pat_type.pat {
213 Pat::Ident(pat_ident) => Some(&pat_ident.ident),
214 _ => None,
215 })
216 .collect();
217
218 // Combine all arguments into a comma-separated token stream
219 let mut fn_arg_tokens = TokenStream::new();
220 for fn_arg_ident in fn_arg_idents {
221 let fn_arg_ident = fn_arg_ident.clone();
222 fn_arg_tokens.append(fn_arg_ident);
223 fn_arg_tokens.append(Punct::new(',', Spacing::Alone));
224 }
225 fn_arg_tokens
226 };
227
228 let new_fn_block: Block = {
229 let new_fn_tokens = quote! {
230 // This is the code block added to the incomplete method, which
231 // is just forwarding the function call to the field.
232 {
233 #field.#fn_name(#fn_arg_tokens)
234 }
235 };
236 syn::parse(new_fn_tokens.into()).unwrap()
237 };
238 impl_item_method.block = new_fn_block;
239}
240
241trait BlockExt {
242 /// Check if a block is empty, which means only contains a ";".
243 fn is_empty(&self) -> bool;
244}
245
246impl BlockExt for Block {
247 fn is_empty(&self) -> bool {
248 if self.stmts.len() == 0 {
249 return true;
250 }
251 if self.stmts.len() > 1 {
252 return false;
253 }
254
255 if let Stmt::Item(item) = &self.stmts[0] {
256 if let Item::Verbatim(token_stream) = item {
257 token_stream.to_string().trim() == ";"
258 } else {
259 false
260 }
261 } else {
262 false
263 }
264 }
265}