nb_blocking_util/
lib.rs

1use proc_macro2::Span;
2use quote::quote;
3use syn::{
4    parse_macro_input,
5    token::{Brace, Break},
6    Block, Expr, ExprBlock, ExprBreak, ItemFn, Stmt,
7};
8
9#[proc_macro_attribute]
10pub fn blocking(
11    _attr: proc_macro::TokenStream,
12    input: proc_macro::TokenStream,
13) -> proc_macro::TokenStream {
14    let input = parse_macro_input!(input as ItemFn);
15
16    let result = blocking_impl(input);
17    let r = quote! {#result};
18
19    r.into()
20}
21
22fn df_expr() -> Expr {
23    Expr::Break(ExprBreak {
24        attrs: vec![],
25        break_token: Break {
26            span: Span::call_site(),
27        },
28        expr: None,
29        label: None,
30    })
31}
32
33fn blocking_stmts(stmts: &mut [Stmt]) {
34    for stmt in stmts.iter_mut() {
35        match stmt {
36            Stmt::Local(local_var) => {
37                if let Some((_, exp)) = &mut local_var.init {
38                    blocking_expr(exp);
39                }
40            }
41            Stmt::Expr(e) => {
42                blocking_expr(e);
43            }
44            Stmt::Semi(e, _) => {
45                blocking_expr(e);
46            }
47            _ => {}
48        }
49    }
50}
51
52fn blocking_block(block: &mut Block) {
53    blocking_stmts(&mut block.stmts);
54}
55
56fn blocking_expr(e: &mut Expr) {
57    match e {
58        Expr::Array(a) => {
59            for el in a.elems.iter_mut() {
60                blocking_expr(el);
61            }
62        }
63        Expr::Assign(a) => {
64            blocking_expr(&mut a.left);
65            blocking_expr(&mut a.right);
66        }
67        Expr::AssignOp(a) => {
68            blocking_expr(&mut a.left);
69            blocking_expr(&mut a.right);
70        }
71        Expr::Async(a) => {
72            *e = Expr::Block(ExprBlock {
73                block: std::mem::replace(
74                    &mut a.block,
75                    Block {
76                        brace_token: Brace {
77                            span: Span::call_site(),
78                        },
79                        stmts: vec![],
80                    },
81                ),
82                attrs: vec![],
83                label: None,
84            });
85
86            blocking_expr(e);
87        }
88        Expr::Await(a) => {
89            *e = std::mem::replace(&mut *a.base, df_expr());
90            blocking_expr(e);
91        }
92        Expr::Binary(b) => {
93            blocking_expr(&mut b.left);
94            blocking_expr(&mut b.right);
95        }
96        Expr::Block(b) => {
97            blocking_block(&mut b.block);
98        }
99        Expr::Box(b) => {
100            blocking_expr(&mut b.expr);
101        }
102        Expr::Break(b) => {
103            if let Some(brexpr) = &mut b.expr {
104                blocking_expr(brexpr);
105            }
106        }
107        Expr::Call(c) => {
108            blocking_expr(&mut c.func);
109
110            for arg in c.args.iter_mut() {
111                blocking_expr(arg);
112            }
113        }
114        Expr::Cast(c) => {
115            blocking_expr(&mut c.expr);
116        }
117        Expr::Field(f) => {
118            blocking_expr(&mut f.base);
119        }
120        Expr::ForLoop(e) => {
121            blocking_expr(&mut e.expr);
122
123            blocking_block(&mut e.body);
124        }
125        Expr::Group(g) => {
126            blocking_expr(&mut g.expr);
127        }
128        Expr::If(cond) => {
129            blocking_expr(&mut cond.cond);
130
131            blocking_block(&mut cond.then_branch);
132
133            if let Some((_, else_branch)) = &mut cond.else_branch {
134                blocking_expr(else_branch);
135            }
136        }
137        Expr::Index(ind) => {
138            blocking_expr(&mut ind.expr);
139
140            blocking_expr(&mut ind.index);
141        }
142        Expr::Let(l) => {
143            blocking_expr(&mut l.expr);
144        }
145        Expr::Loop(l) => {
146            blocking_block(&mut l.body);
147        }
148        Expr::Match(m) => {
149            blocking_expr(&mut m.expr);
150
151            for arm in m.arms.iter_mut() {
152                blocking_expr(&mut arm.body);
153            }
154        }
155        Expr::MethodCall(mc) => {
156            blocking_expr(&mut mc.receiver);
157
158            for arg in mc.args.iter_mut() {
159                blocking_expr(arg);
160            }
161        }
162        Expr::Paren(p) => {
163            blocking_expr(&mut p.expr);
164        }
165        Expr::Range(r) => {
166            if let Some(from) = &mut r.from {
167                blocking_expr(from);
168            }
169            if let Some(to) = &mut r.to {
170                blocking_expr(to);
171            }
172        }
173        Expr::Reference(r) => {
174            blocking_expr(&mut r.expr);
175        }
176        Expr::Repeat(r) => {
177            blocking_expr(&mut r.expr);
178        }
179        Expr::Return(r) => {
180            if let Some(rval) = &mut r.expr {
181                blocking_expr(rval);
182            }
183        }
184        Expr::Try(t) => {
185            blocking_expr(&mut t.expr);
186        }
187        Expr::TryBlock(tb) => {
188            blocking_block(&mut tb.block);
189        }
190        Expr::Tuple(tp) => {
191            for tv in tp.elems.iter_mut() {
192                blocking_expr(tv);
193            }
194        }
195        Expr::Unary(un) => {
196            blocking_expr(&mut un.expr);
197        }
198        Expr::Unsafe(uns) => {
199            blocking_block(&mut uns.block);
200        }
201        Expr::While(w) => {
202            blocking_expr(&mut w.cond);
203
204            blocking_block(&mut w.body);
205        }
206        Expr::Yield(y) => {
207            if let Some(e) = &mut y.expr {
208                blocking_expr(e);
209            }
210        }
211        _ => {}
212    }
213}
214
215fn blocking_impl(mut input: ItemFn) -> ItemFn {
216    if input.sig.asyncness.is_none() {
217        return input;
218    }
219
220    input.sig.asyncness = None;
221
222    blocking_stmts(&mut input.block.stmts);
223
224    input
225}