lower_macros/
lib.rs

//! A small macro crate.
//!
//! Provides handy macros for converting `a + b` to `a.add(b)`, for when you can't easily overload the `Add` trait.
4use proc_macro2::TokenStream;
5use quote::{quote, ToTokens};
6use syn::{spanned::Spanned, *};
7
/// Binds one or more `name = value` pairs (separated by `;`) as locals, then
/// feeds the tokens after `=>` to `quote!` so they can interpolate those locals.
macro_rules! quote_with {
    ($($name:ident = $value:expr);+ => $($body:tt)+) => {{
        $(let $name = $value;)+
        quote!($($body)+)
    }};
}
14trait Sub {
15    fn sub_bin(&self, op: BinOp, left: TokenStream, right: TokenStream) -> TokenStream;
16    fn sub_unop(&self, op: UnOp, x: TokenStream) -> TokenStream;
17}
18
19struct Basic;
20impl Sub for Basic {
21    fn sub_bin(&self, op: BinOp, left: TokenStream, right: TokenStream) -> TokenStream {
22        use syn::BinOp::*;
23        match op {
24            Add(_) => quote!((#left).add(#right)),
25            Sub(_) => quote!((#left).sub(#right)),
26            Mul(_) => quote!((#left).mul(#right)),
27            Div(_) => quote!((#left).div(#right)),
28            Rem(_) => quote!((#left).rem(#right)),
29            And(_) => quote!((#left).and(#right)),
30            Or(_) => quote!((#left).or(#right)),
31            BitXor(_) => quote!((#left).bitxor(#right)),
32            BitAnd(_) => quote!((#left).bitand(#right)),
33            BitOr(_) => quote!((#left).bitor(#right)),
34            Shl(_) => quote!((#left).shl(#right)),
35            Shr(_) => quote!((#left).shr(#right)),
36            Eq(_) => quote!((#left).eq(#right)),
37            Lt(_) => quote!((#left).lt(#right)),
38            Le(_) => quote!((#left).le(#right)),
39            Ne(_) => quote!((#left).ne(#right)),
40            Ge(_) => quote!((#left).ge(#right)),
41            Gt(_) => quote!((#left).gt(#right)),
42            // don't support assigning ops
43            e => {
44                Error::new(e.span(), format!("{}", quote!(op #e not supported))).to_compile_error()
45            }
46        }
47    }
48
49    fn sub_unop(&self, op: UnOp, x: TokenStream) -> TokenStream {
50        match op {
51            UnOp::Deref(_) => quote!((#x).deref()),
52            UnOp::Not(_) => quote!((#x).not()),
53            UnOp::Neg(_) => quote!((#x).neg()),
54            e => Error::new(
55                e.span(),
56                "it would appear a new operation has been added! please tell me.",
57            )
58            .to_compile_error(),
59        }
60    }
61}
62
63struct Wrapping;
64impl Sub for Wrapping {
65    fn sub_bin(&self, op: BinOp, left: TokenStream, right: TokenStream) -> TokenStream {
66        use syn::BinOp::*;
67        match op {
68            Add(_) => quote!((#left).wrapping_add(#right)),
69            Sub(_) => quote!((#left).wrapping_sub(#right)),
70            Mul(_) => quote!((#left).wrapping_mul(#right)),
71            Div(_) => quote!((#left).wrapping_div(#right)),
72            Rem(_) => quote!((#left).wrapping_rem(#right)),
73            Shl(_) => quote!((#left).wrapping_shl(#right)),
74            Shr(_) => quote!((#left).wrapping_shr(#right)),
75
76            _ => quote!((#left) #op (#right)),
77        }
78    }
79
80    fn sub_unop(&self, op: UnOp, x: TokenStream) -> TokenStream {
81        match op {
82            UnOp::Neg(_) => quote!((#x).wrapping_neg()),
83            _ => quote!(#op #x),
84        }
85    }
86}
87
88struct Algebraic;
89impl Sub for Algebraic {
90    fn sub_bin(&self, op: BinOp, left: TokenStream, right: TokenStream) -> TokenStream {
91        use syn::BinOp::*;
92        match op {
93            Add(_) => quote!(core::intrinsics::fadd_algebraic(#left, #right)),
94            Sub(_) => quote!(core::intrinsics::fsub_algebraic(#left, #right)),
95            Mul(_) => quote!(core::intrinsics::fmul_algebraic(#left, #right)),
96            Div(_) => quote!(core::intrinsics::fdiv_algebraic(#left, #right)),
97            Rem(_) => quote!(core::intrinsics::frem_algebraic(#left, #right)),
98            And(_) => quote!(core::intrinsics::fand_algebraic(#left, #right)),
99            _ => quote!((#left) #op (#right)),
100        }
101    }
102
103    fn sub_unop(&self, op: UnOp, x: TokenStream) -> TokenStream {
104        quote!(#op #x)
105    }
106}
107
108struct Fast;
109impl Sub for Fast {
110    fn sub_bin(&self, op: BinOp, left: TokenStream, right: TokenStream) -> TokenStream {
111        use syn::BinOp::*;
112        match op {
113            Add(_) => quote!(core::intrinsics::fadd_fast(#left, #right)),
114            Sub(_) => quote!(core::intrinsics::fsub_fast(#left, #right)),
115            Mul(_) => quote!(core::intrinsics::fmul_fast(#left, #right)),
116            Div(_) => quote!(core::intrinsics::fdiv_fast(#left, #right)),
117            Rem(_) => quote!(core::intrinsics::frem_fast(#left, #right)),
118            And(_) => quote!(core::intrinsics::fand_fast(#left, #right)),
119            Eq(_) => quote!(/* eq */ ((#left) + 0.0).to_bits() == ((#right) + 0.0).to_bits()),
120            _ => quote!((#left) #op (#right)),
121        }
122    }
123
124    fn sub_unop(&self, op: UnOp, x: TokenStream) -> TokenStream {
125        quote!(#op #x)
126    }
127}
128
129fn walk(sub: &impl Sub, e: Expr) -> TokenStream {
130    let walk = |e| walk(sub, e);
131    let map_block = |b| map_block(sub, b);
132    match e {
133        Expr::Binary(ExprBinary {
134            left, op, right, ..
135        }) => {
136            let left = walk(*left);
137            let right = walk(*right);
138            sub.sub_bin(op, left, right)
139        }
140        Expr::Unary(ExprUnary { op, expr, .. }) => sub.sub_unop(op, walk(*expr)),
141        Expr::Break(ExprBreak {
142            label,
143            expr: Some(expr),
144            ..
145        }) => {
146            let expr = walk(*expr);
147            quote!(#label #expr)
148        }
149        Expr::Call(ExprCall { func, args, .. }) => {
150            let f = walk(*func);
151            let args = args.into_iter().map(walk);
152            quote!(#f ( #(#args),* ))
153        }
154        Expr::Closure(ExprClosure {
155            lifetimes,
156            constness,
157            movability,
158            asyncness,
159            capture,
160            inputs,
161            output,
162            body,
163            ..
164        }) => {
165            let body = walk(*body);
166            quote!(#lifetimes #constness #movability #asyncness #capture |#inputs| #output #body)
167        }
168        Expr::ForLoop(ExprForLoop {
169            label,
170            pat,
171            expr,
172            body,
173            ..
174        }) => {
175            let (expr, body) = (walk(*expr), map_block(body));
176            quote!(#label for #pat in #expr #body)
177        }
178        Expr::Let(ExprLet { pat, expr, .. }) => {
179            quote_with!(expr = walk(*expr) => let #pat = #expr)
180        }
181        Expr::Const(ExprConst { block, .. }) => {
182            quote_with!(block =map_block(block) => const #block)
183        }
184        Expr::Range(ExprRange {
185            start, limits, end, ..
186        }) => {
187            let (start, end) = (start.map(|x| walk(*x)), end.map(|x| walk(*x)));
188            quote!((#start #limits #end))
189        }
190        Expr::Return(ExprReturn { expr, .. }) => {
191            let expr = expr.map(|x| walk(*x));
192            quote!(return #expr;)
193        }
194        Expr::Try(ExprTry { expr, .. }) => {
195            let expr = walk(*expr);
196            quote!(#expr ?)
197        }
198        Expr::TryBlock(ExprTryBlock { block, .. }) => {
199            let block = map_block(block);
200            quote!(try #block)
201        }
202        Expr::Unsafe(ExprUnsafe { block, .. }) => {
203            quote_with!(block =map_block(block) => unsafe #block)
204        }
205        Expr::While(ExprWhile {
206            label, cond, body, ..
207        }) => {
208            quote_with!(cond = walk(*cond); body =map_block(body) => #label while #cond #body)
209        }
210        Expr::Loop(ExprLoop { label, body, .. }) => {
211            quote_with!(body =map_block(body) => #label loop #body)
212        }
213        Expr::Reference(ExprReference {
214            mutability, expr, ..
215        }) => {
216            let expr = walk(*expr);
217            quote!(& #mutability #expr)
218        }
219        Expr::MethodCall(ExprMethodCall {
220            receiver,
221            method,
222            turbofish,
223            args,
224            ..
225        }) => {
226            let receiver = walk(*receiver);
227            let args = args.into_iter().map(walk);
228            quote!(#receiver . #method #turbofish (#(#args,)*))
229        }
230        Expr::If(ExprIf {
231            cond,
232            then_branch,
233            else_branch: Some((_, else_branch)),
234            ..
235        }) => {
236            let (cond, then_branch, else_branch) =
237                (walk(*cond), map_block(then_branch), walk(*else_branch));
238            quote!(if #cond #then_branch else #else_branch)
239        }
240        Expr::If(ExprIf {
241            cond, then_branch, ..
242        }) => {
243            let (cond, then_branch) = (walk(*cond), map_block(then_branch));
244            quote!(if #cond #then_branch)
245        }
246        Expr::Async(ExprAsync {
247            attrs,
248            capture,
249            block,
250            ..
251        }) => {
252            let block = map_block(block);
253            quote!(#(#attrs)* async #capture #block)
254        }
255        Expr::Await(ExprAwait { base, .. }) => {
256            let base = walk(*base);
257            quote!(#base.await)
258        }
259        Expr::Assign(ExprAssign { left, right, .. }) => {
260            let (left, right) = (walk(*left), walk(*right));
261            quote!(#left = #right;)
262        }
263        Expr::Paren(ExprParen { expr, .. }) => {
264            let expr = walk(*expr);
265            quote!(#expr)
266        }
267        Expr::Tuple(ExprTuple { elems, .. }) => {
268            let ts = elems.into_iter().map(walk);
269            quote!((#(#ts,)*))
270        }
271        Expr::Array(ExprArray { elems, .. }) => {
272            let ts = elems.into_iter().map(walk);
273            quote!([#(#ts,)*])
274        }
275        Expr::Repeat(ExprRepeat { expr, len, .. }) => {
276            let x = walk(*expr);
277            let len = walk(*len);
278            quote!([ #x ; #len ])
279        }
280        Expr::Block(ExprBlock {
281            block,
282            label: Some(label),
283            ..
284        }) => {
285            let b = map_block(block);
286            quote! { #label: #b }
287        }
288        Expr::Block(ExprBlock { block, .. }) => map_block(block),
289        e => quote!(#e),
290    }
291}
292
293fn map_block(sub: &impl Sub, Block { stmts, .. }: Block) -> TokenStream {
294    let stmts = stmts.into_iter().map(|x| walk_stmt(sub, x));
295    quote! { { #(#stmts)* } }
296}
297
298fn walk_stmt(sub: &impl Sub, x: Stmt) -> TokenStream {
299    let walk = |e| walk(sub, e);
300    match x {
301        Stmt::Local(Local {
302            pat,
303            init:
304                Some(LocalInit {
305                    expr,
306                    diverge: Some((_, diverge)),
307                    ..
308                }),
309            ..
310        }) => {
311            let expr = walk(*expr);
312            let diverge = walk(*diverge);
313            quote!(let #pat = #expr else { #diverge };)
314        }
315        Stmt::Local(Local {
316            pat,
317            init: Some(LocalInit { expr, .. }),
318            ..
319        }) => {
320            let expr = walk(*expr);
321            quote!(let #pat = #expr;)
322        }
323        Stmt::Item(x) => walk_item(sub, x),
324        Stmt::Expr(e, t) => {
325            let e = walk(e);
326            quote!(#e #t)
327        }
328        e => quote!(#e),
329    }
330}
331
332fn walk_item(sub: &impl Sub, x: Item) -> TokenStream {
333    let walk = |e| walk(sub, e);
334    match x {
335        Item::Const(ItemConst {
336            vis,
337            ident,
338            ty,
339            expr,
340            ..
341        }) => {
342            let expr = walk(*expr);
343            quote!(#vis const #ident : #ty = #expr;)
344        }
345        Item::Fn(ItemFn {
346            vis,
347            attrs,
348            sig,
349            block,
350        }) => {
351            let block = map_block(sub, *block);
352            quote!( #(#attrs)* #vis #sig #block)
353        }
354        Item::Impl(ItemImpl {
355            attrs,
356            unsafety,
357            defaultness,
358            generics,
359            trait_,
360            self_ty,
361            items,
362            ..
363        }) => {
364            let items = items.into_iter().map(|x| match x {
365                ImplItem::Const(ImplItemConst {
366                    vis,
367                    attrs,
368                    defaultness,
369                    ident,
370                    ty,
371                    expr,
372                    ..
373                }) => {
374                    let expr = walk(expr);
375                    quote!(#(#attrs)* #vis #defaultness const #ident: #ty = #expr;)
376                }
377                ImplItem::Fn(ImplItemFn {
378                    attrs,
379                    vis,
380                    defaultness,
381                    sig,
382                    block,
383                }) => {
384                    let block = map_block(sub, block);
385                    quote!(#(#attrs)* #vis #defaultness #sig #block)
386                }
387                e => quote!(#e),
388            });
389            let trait_ = trait_.map(|(n, pat, fr)| quote!(#n #pat #fr));
390            quote!(#(#attrs)* #unsafety #defaultness impl #generics #trait_ #self_ty { #(#items)* })
391        }
392        Item::Mod(ItemMod {
393            attrs,
394            vis,
395            ident,
396            content: Some((_, content)),
397            ..
398        }) => {
399            let content = content.into_iter().map(|x| walk_item(sub, x));
400            quote!(#(#attrs)* #vis mod #ident { #(#content)* })
401        }
402        Item::Static(ItemStatic {
403            attrs,
404            vis,
405            mutability,
406            ident,
407            ty,
408            expr,
409            ..
410        }) => {
411            let expr = walk(*expr);
412            quote!(#(#attrs)* #vis static #mutability #ident: #ty = #expr)
413        }
414        e => quote!(#e),
415    }
416}
417
/// Shared body of the proc-macro entry points.
///
/// Tries to parse `$input` as an expression and rewrite it with `$t`; if that
/// fails, tries again as a statement. When both parses fail, the compile
/// error from the *expression* parse is what gets emitted.
macro_rules! walk {
    ($input:ident,$t:expr) => {
        match parse::<Expr>($input.clone()) {
            Ok(expr) => walk(&$t, expr),
            Err(expr_err) => match parse::<Stmt>($input) {
                Ok(stmt) => walk_stmt(&$t, stmt),
                Err(_) => expr_err.to_compile_error().into_token_stream(),
            },
        }
        .into()
    };
}
432
433#[proc_macro]
434pub fn math(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
435    walk!(input, Basic {})
436}
437
438#[proc_macro]
439pub fn fast(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
440    walk!(input, Fast {})
441}
442
443#[proc_macro]
444pub fn algebraic(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
445    walk!(input, Algebraic {})
446}
447
448#[proc_macro]
449pub fn wrapping(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
450    walk!(input, Wrapping {})
451}
452
453#[proc_macro_attribute]
454pub fn apply(
455    args: proc_macro::TokenStream,
456    input: proc_macro::TokenStream,
457) -> proc_macro::TokenStream {
458    match &*args.to_string() {
459        "basic" | "" => math(input),
460        "fast" => fast(input),
461        "algebraic" => algebraic(input),
462        "wrapping" => wrapping(input),
463        _ => quote! { compile_error!("type must be {fast, basic, algebraic, wrapping}") }.into(),
464    }
465}