ergo_pin/lib.rs
1//! <p align="center"><i>Immobilis ergo pin</i></p>
2//!
3//! **Ergo**nomic stack **pin**ning for Rust.
4//!
//! `ergo-pin` exports a single proc-macro attribute, `#[ergo_pin]`, that can be applied to an
//! item, a block, or an invocation of a `tt`-accepting macro to provide the "magical" `pin!`
//! macro within that scope. You can think of this `pin!` macro as equivalent to a function
//! with the signature:
9//!
10//! ```ignore
11//! extern "bla̴ck̀ mag̸ic͘" fn pin!<T>(t: T) -> Pin<&'local mut T>;
12//! ```
13//!
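//! It takes any value and returns a `Pin<&mut _>` to it, with the correct local stack
//! lifetime. Under the hood, the value is moved into a hidden local binding declared in the
//! innermost enclosing block, just before the statement containing the `pin!`, so the `pin!`
//! argument is evaluated before the rest of that statement.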
16//!
17//! # Examples
18//!
19//! ## Pin values inside functions
20//!
21//! ```rust
22//! use core::pin::Pin;
23//! use ergo_pin::ergo_pin;
24//!
25//! struct Foo;
26//!
27//! impl Foo {
28//! fn foo(self: Pin<&mut Self>) -> usize {
29//! 5
30//! }
31//! }
32//!
33//! #[ergo_pin]
34//! fn foo() -> usize {
35//! pin!(Foo).foo()
36//! }
37//!
38//! assert_eq!(foo(), 5);
39//! ```
40//!
41//! ## Pin values in blocks (requires unstable features)
42//!
43#![cfg_attr(feature = "nightly-tests", doc = "```rust")]
44#![cfg_attr(not(feature = "nightly-tests"), doc = "```ignore")]
45//! #![feature(stmt_expr_attributes, proc_macro_hygiene)]
46//!
47//! use core::pin::Pin;
48//! use ergo_pin::ergo_pin;
49//!
50//! struct Foo;
51//!
52//! impl Foo {
53//! fn foo(self: Pin<&mut Self>) -> usize {
54//! 5
55//! }
56//! }
57//!
58//! fn foo() -> usize {
59//! #[ergo_pin] {
60//! pin!(Foo).foo()
61//! }
62//! }
63//!
64//! assert_eq!(foo(), 5);
65//! ```
66//!
67//! ## Pin values in other macros that accept normal Rust code (requires unstable features)
68//!
69#![cfg_attr(feature = "nightly-tests", doc = "```rust")]
70#![cfg_attr(not(feature = "nightly-tests"), doc = "```ignore")]
71//! #![feature(proc_macro_hygiene)]
72//!
73//! use core::pin::Pin;
74//! use ergo_pin::ergo_pin;
75//!
76//! struct Foo;
77//!
78//! impl Foo {
79//! fn foo(self: Pin<&mut Self>) -> usize {
80//! 5
81//! }
82//! }
83//!
84//! macro_rules! bar {
85//! ($($tokens:tt)+) => { $($tokens)+ };
86//! }
87//!
88//! fn foo() -> usize {
89//! #[ergo_pin]
90//! bar! {
91//! pin!(Foo).foo()
92//! }
93//! }
94//!
95//! assert_eq!(foo(), 5);
96//! ```
97//!
98//! ## Pin values inside any function of an impl
99//!
100//! (Note: this does _not_ descend into macros of the inner code as they may not be using normal
101//! Rust code syntax.)
102//!
103//! ```rust
104//! use core::pin::Pin;
105//! use ergo_pin::ergo_pin;
106//!
107//! struct Foo;
108//!
109//! impl Foo {
110//! fn foo(self: Pin<&mut Self>) -> usize {
111//! 5
112//! }
113//! }
114//!
115//! struct Bar;
116//!
117//! #[ergo_pin]
118//! impl Bar {
119//! fn bar() -> usize {
120//! pin!(Foo).foo()
121//! }
122//! }
123//!
124//! assert_eq!(Bar::bar(), 5);
125//! ```
126
127extern crate proc_macro;
128
129use quote::{quote, ToTokens};
130use syn::fold::Fold;
131
132#[derive(Default)]
133struct Visitor {
134 counter: usize,
135 pinned: Vec<(syn::Ident, syn::Expr)>,
136}
137
138impl Visitor {
139 fn new() -> Self {
140 Self::default()
141 }
142
143 fn gen_ident(&mut self) -> syn::Ident {
144 let string = format!("__ergo_pin_{}", self.counter);
145 self.counter += 1;
146 syn::Ident::new(&string, proc_macro2::Span::call_site())
147 }
148}
149
150impl Fold for Visitor {
151 fn fold_block(&mut self, block: syn::Block) -> syn::Block {
152 syn::Block {
153 brace_token: block.brace_token,
154 stmts: block
155 .stmts
156 .into_iter()
157 .flat_map(|stmt| {
158 let prior = std::mem::replace(&mut self.pinned, vec![]);
159 let stmt = self.fold_stmt(stmt);
160 std::mem::replace(&mut self.pinned, prior)
161 .into_iter()
162 .flat_map(|(ident, expr)| {
163 syn::parse::<syn::Block>(
164 quote!({
165 let mut #ident = #expr;
166 let #ident = unsafe {
167 ::core::pin::Pin::new_unchecked(&mut #ident)
168 };
169 })
170 .into(),
171 )
172 .unwrap()
173 .stmts
174 })
175 .chain(std::iter::once(stmt))
176 })
177 .collect(),
178 }
179 }
180
181 fn fold_expr(&mut self, expr: syn::Expr) -> syn::Expr {
182 let pin = syn::Ident::new("pin", proc_macro2::Span::call_site());
183 if let syn::Expr::Macro(expr) = expr {
184 if expr.mac.path.is_ident(&pin) {
185 let ident = self.gen_ident();
186 self.pinned
187 .push((ident.clone(), syn::parse(expr.mac.tokens.into()).unwrap()));
188 syn::Expr::Path(syn::ExprPath {
189 attrs: vec![],
190 qself: None,
191 path: ident.into(),
192 })
193 } else {
194 syn::fold::fold_expr_macro(self, expr).into()
195 }
196 } else {
197 syn::fold::fold_expr(self, expr)
198 }
199 }
200
201 fn fold_expr_while(&mut self, expr: syn::ExprWhile) -> syn::ExprWhile {
202 syn::ExprWhile {
203 attrs: expr.attrs,
204 label: expr.label,
205 while_token: expr.while_token,
206 cond: Box::new(if let syn::Expr::Let(cond) = *expr.cond {
207 syn::Expr::Let(syn::ExprLet {
208 expr: Box::new(syn::Expr::Block(syn::ExprBlock {
209 attrs: vec![],
210 label: None,
211 block: self.fold_block(syn::Block {
212 brace_token: syn::token::Brace {
213 span: proc_macro2::Span::call_site(),
214 },
215 stmts: vec![syn::Stmt::Expr(*cond.expr)],
216 }),
217 })),
218 ..cond
219 })
220 } else {
221 syn::Expr::Block(syn::ExprBlock {
222 attrs: vec![],
223 label: None,
224 block: self.fold_block(syn::Block {
225 brace_token: syn::token::Brace {
226 span: proc_macro2::Span::call_site(),
227 },
228 stmts: vec![syn::Stmt::Expr(*expr.cond)],
229 }),
230 })
231 }),
232 body: self.fold_block(expr.body),
233 }
234 }
235}
236
237/// The main attribute, see crate level docs for details.
238#[proc_macro_attribute]
239pub fn ergo_pin(
240 _attrs: proc_macro::TokenStream,
241 code: proc_macro::TokenStream,
242) -> proc_macro::TokenStream {
243 let mut visitor = Visitor::new();
244
245 if let Ok(mac) = syn::parse::<syn::Macro>(code.clone()) {
246 let tokens = mac.tokens;
247 if let Ok(block) = syn::parse::<syn::Block>(quote!({ #tokens }).into()) {
248 let block = visitor.fold_block(block);
249 let tokens = block.stmts.into_iter().map(|stmt| quote!(#stmt)).collect();
250 return syn::Macro { tokens, ..mac }.into_token_stream().into();
251 }
252 }
253
254 if let Ok(item) = syn::parse::<syn::Item>(code.clone()) {
255 return visitor.fold_item(item).into_token_stream().into();
256 }
257
258 if let Ok(block) = syn::parse::<syn::Block>(code.clone()) {
259 return visitor.fold_block(block).into_token_stream().into();
260 }
261
262 panic!("Could not parse input")
263}