ergo_pin/
lib.rs

1//! <p align="center"><i>Immobilis ergo pin</i></p>
2//!
3//! **Ergo**nomic stack **pin**ning for Rust.
4//!
//! `ergo-pin` exports a single proc-macro attribute, `#[ergo_pin]`, that can be applied to an
//! item, a block, or a `tt`-accepting macro invocation to provide the "magical" `pin!` macro
//! within that scope. You can consider this `pin!` macro equivalent to a function with the
//! signature:
9//!
10//! ```ignore
11//! extern "bla̴ck̀ mag̸ic͘" fn pin!<T>(t: T) -> Pin<&'local mut T>;
12//! ```
13//!
//! It takes any value and returns a `Pin<&mut _>` to that value, with the correct local stack
//! lifetime.
16//!
17//! # Examples
18//!
19//! ## Pin values inside functions
20//!
21//! ```rust
22//! use core::pin::Pin;
23//! use ergo_pin::ergo_pin;
24//!
25//! struct Foo;
26//!
27//! impl Foo {
28//!     fn foo(self: Pin<&mut Self>) -> usize {
29//!         5
30//!     }
31//! }
32//!
33//! #[ergo_pin]
34//! fn foo() -> usize {
35//!     pin!(Foo).foo()
36//! }
37//!
38//! assert_eq!(foo(), 5);
39//! ```
40//!
41//! ## Pin values in blocks (requires unstable features)
42//!
43#![cfg_attr(feature = "nightly-tests", doc = "```rust")]
44#![cfg_attr(not(feature = "nightly-tests"), doc = "```ignore")]
45//! #![feature(stmt_expr_attributes, proc_macro_hygiene)]
46//!
47//! use core::pin::Pin;
48//! use ergo_pin::ergo_pin;
49//!
50//! struct Foo;
51//!
52//! impl Foo {
53//!     fn foo(self: Pin<&mut Self>) -> usize {
54//!         5
55//!     }
56//! }
57//!
58//! fn foo() -> usize {
59//!     #[ergo_pin] {
60//!         pin!(Foo).foo()
61//!     }
62//! }
63//!
64//! assert_eq!(foo(), 5);
65//! ```
66//!
67//! ## Pin values in other macros that accept normal Rust code (requires unstable features)
68//!
69#![cfg_attr(feature = "nightly-tests", doc = "```rust")]
70#![cfg_attr(not(feature = "nightly-tests"), doc = "```ignore")]
71//! #![feature(proc_macro_hygiene)]
72//!
73//! use core::pin::Pin;
74//! use ergo_pin::ergo_pin;
75//!
76//! struct Foo;
77//!
78//! impl Foo {
79//!     fn foo(self: Pin<&mut Self>) -> usize {
80//!         5
81//!     }
82//! }
83//!
84//! macro_rules! bar {
85//!     ($($tokens:tt)+) => { $($tokens)+ };
86//! }
87//!
88//! fn foo() -> usize {
89//!     #[ergo_pin]
90//!     bar! {
91//!         pin!(Foo).foo()
92//!     }
93//! }
94//!
95//! assert_eq!(foo(), 5);
96//! ```
97//!
98//! ## Pin values inside any function of an impl
99//!
//! (Note: this does _not_ descend into other macro invocations inside the annotated code, as
//! their arguments may not be written in normal Rust syntax.)
102//!
103//! ```rust
104//! use core::pin::Pin;
105//! use ergo_pin::ergo_pin;
106//!
107//! struct Foo;
108//!
109//! impl Foo {
110//!     fn foo(self: Pin<&mut Self>) -> usize {
111//!         5
112//!     }
113//! }
114//!
115//! struct Bar;
116//!
117//! #[ergo_pin]
118//! impl Bar {
119//!     fn bar() -> usize {
120//!         pin!(Foo).foo()
121//!     }
122//! }
123//!
124//! assert_eq!(Bar::bar(), 5);
125//! ```
126
127extern crate proc_macro;
128
129use quote::{quote, ToTokens};
130use syn::fold::Fold;
131
132#[derive(Default)]
133struct Visitor {
134    counter: usize,
135    pinned: Vec<(syn::Ident, syn::Expr)>,
136}
137
138impl Visitor {
139    fn new() -> Self {
140        Self::default()
141    }
142
143    fn gen_ident(&mut self) -> syn::Ident {
144        let string = format!("__ergo_pin_{}", self.counter);
145        self.counter += 1;
146        syn::Ident::new(&string, proc_macro2::Span::call_site())
147    }
148}
149
150impl Fold for Visitor {
151    fn fold_block(&mut self, block: syn::Block) -> syn::Block {
152        syn::Block {
153            brace_token: block.brace_token,
154            stmts: block
155                .stmts
156                .into_iter()
157                .flat_map(|stmt| {
158                    let prior = std::mem::replace(&mut self.pinned, vec![]);
159                    let stmt = self.fold_stmt(stmt);
160                    std::mem::replace(&mut self.pinned, prior)
161                        .into_iter()
162                        .flat_map(|(ident, expr)| {
163                            syn::parse::<syn::Block>(
164                                quote!({
165                                    let mut #ident = #expr;
166                                    let #ident = unsafe {
167                                        ::core::pin::Pin::new_unchecked(&mut #ident)
168                                    };
169                                })
170                                .into(),
171                            )
172                            .unwrap()
173                            .stmts
174                        })
175                        .chain(std::iter::once(stmt))
176                })
177                .collect(),
178        }
179    }
180
181    fn fold_expr(&mut self, expr: syn::Expr) -> syn::Expr {
182        let pin = syn::Ident::new("pin", proc_macro2::Span::call_site());
183        if let syn::Expr::Macro(expr) = expr {
184            if expr.mac.path.is_ident(&pin) {
185                let ident = self.gen_ident();
186                self.pinned
187                    .push((ident.clone(), syn::parse(expr.mac.tokens.into()).unwrap()));
188                syn::Expr::Path(syn::ExprPath {
189                    attrs: vec![],
190                    qself: None,
191                    path: ident.into(),
192                })
193            } else {
194                syn::fold::fold_expr_macro(self, expr).into()
195            }
196        } else {
197            syn::fold::fold_expr(self, expr)
198        }
199    }
200
201    fn fold_expr_while(&mut self, expr: syn::ExprWhile) -> syn::ExprWhile {
202        syn::ExprWhile {
203            attrs: expr.attrs,
204            label: expr.label,
205            while_token: expr.while_token,
206            cond: Box::new(if let syn::Expr::Let(cond) = *expr.cond {
207                syn::Expr::Let(syn::ExprLet {
208                    expr: Box::new(syn::Expr::Block(syn::ExprBlock {
209                        attrs: vec![],
210                        label: None,
211                        block: self.fold_block(syn::Block {
212                            brace_token: syn::token::Brace {
213                                span: proc_macro2::Span::call_site(),
214                            },
215                            stmts: vec![syn::Stmt::Expr(*cond.expr)],
216                        }),
217                    })),
218                    ..cond
219                })
220            } else {
221                syn::Expr::Block(syn::ExprBlock {
222                    attrs: vec![],
223                    label: None,
224                    block: self.fold_block(syn::Block {
225                        brace_token: syn::token::Brace {
226                            span: proc_macro2::Span::call_site(),
227                        },
228                        stmts: vec![syn::Stmt::Expr(*expr.cond)],
229                    }),
230                })
231            }),
232            body: self.fold_block(expr.body),
233        }
234    }
235}
236
237/// The main attribute, see crate level docs for details.
238#[proc_macro_attribute]
239pub fn ergo_pin(
240    _attrs: proc_macro::TokenStream,
241    code: proc_macro::TokenStream,
242) -> proc_macro::TokenStream {
243    let mut visitor = Visitor::new();
244
245    if let Ok(mac) = syn::parse::<syn::Macro>(code.clone()) {
246        let tokens = mac.tokens;
247        if let Ok(block) = syn::parse::<syn::Block>(quote!({ #tokens }).into()) {
248            let block = visitor.fold_block(block);
249            let tokens = block.stmts.into_iter().map(|stmt| quote!(#stmt)).collect();
250            return syn::Macro { tokens, ..mac }.into_token_stream().into();
251        }
252    }
253
254    if let Ok(item) = syn::parse::<syn::Item>(code.clone()) {
255        return visitor.fold_item(item).into_token_stream().into();
256    }
257
258    if let Ok(block) = syn::parse::<syn::Block>(code.clone()) {
259        return visitor.fold_block(block).into_token_stream().into();
260    }
261
262    panic!("Could not parse input")
263}