lower_macros/
lib.rs

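//! A little macro crate.
//!
//! Provides macros that rewrite operators into method calls, e.g. `math!` turns `a + b` into `(a).add(b)` for when you can't easily overload the `Add` trait. `wrapping!`, `saturating!`, `fast!` and `algebraic!` lower arithmetic to the `wrapping_*` / `saturating_*` integer methods or to float intrinsics, and `#[apply(...)]` applies one of these lowerings to a whole item.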
4use proc_macro2::TokenStream;
5use quote::{quote, ToTokens};
6use syn::{spanned::Spanned, *};
7
/// Binds every `name = value` pair as a local, then runs `quote!` over the
/// tokens after `=>` so they can interpolate those locals with `#name`.
macro_rules! quote_with {
    ($($name:ident = $value:expr);+ => $($body:tt)+) => {{
        $(
            let $name = $value;
        )+
        quote!($($body)+)
    }};
}
14trait Sub {
15    fn sub_bin(&self, op: BinOp, left: TokenStream, right: TokenStream) -> TokenStream;
16    fn sub_unop(&self, op: UnOp, x: TokenStream) -> TokenStream;
17}
18
19struct Basic;
20impl Sub for Basic {
21    fn sub_bin(&self, op: BinOp, left: TokenStream, right: TokenStream) -> TokenStream {
22        use syn::BinOp::*;
23        match op {
24            Add(_) => quote!((#left).add(#right)),
25            Sub(_) => quote!((#left).sub(#right)),
26            Mul(_) => quote!((#left).mul(#right)),
27            Div(_) => quote!((#left).div(#right)),
28            Rem(_) => quote!((#left).rem(#right)),
29            And(_) => quote!((#left).and(#right)),
30            Or(_) => quote!((#left).or(#right)),
31            BitXor(_) => quote!((#left).bitxor(#right)),
32            BitAnd(_) => quote!((#left).bitand(#right)),
33            BitOr(_) => quote!((#left).bitor(#right)),
34            Shl(_) => quote!((#left).shl(#right)),
35            Shr(_) => quote!((#left).shr(#right)),
36            Eq(_) => quote!((#left).eq(#right)),
37            Lt(_) => quote!((#left).lt(#right)),
38            Le(_) => quote!((#left).le(#right)),
39            Ne(_) => quote!((#left).ne(#right)),
40            Ge(_) => quote!((#left).ge(#right)),
41            Gt(_) => quote!((#left).gt(#right)),
42            // don't support assigning ops
43            e => {
44                Error::new(e.span(), format!("{}", quote!(op #e not supported))).to_compile_error()
45            }
46        }
47    }
48
49    fn sub_unop(&self, op: UnOp, x: TokenStream) -> TokenStream {
50        match op {
51            UnOp::Deref(_) => quote!((#x).deref()),
52            UnOp::Not(_) => quote!((#x).not()),
53            UnOp::Neg(_) => quote!((#x).neg()),
54            e => Error::new(
55                e.span(),
56                "it would appear a new operation has been added! please tell me.",
57            )
58            .to_compile_error(),
59        }
60    }
61}
62
63struct Wrapping;
64impl Sub for Wrapping {
65    fn sub_bin(&self, op: BinOp, left: TokenStream, right: TokenStream) -> TokenStream {
66        use syn::BinOp::*;
67        match op {
68            Add(_) => quote!((#left).wrapping_add(#right)),
69            Sub(_) => quote!((#left).wrapping_sub(#right)),
70            Mul(_) => quote!((#left).wrapping_mul(#right)),
71            Div(_) => quote!((#left).wrapping_div(#right)),
72            Rem(_) => quote!((#left).wrapping_rem(#right)),
73            Shl(_) => quote!((#left).wrapping_shl(#right)),
74            Shr(_) => quote!((#left).wrapping_shr(#right)),
75
76            SubAssign(_) => quote!(#left = #left.wrapping_sub(#right)),
77            AddAssign(_) => quote!(#left = #left.wrapping_add(#right)),
78            MulAssign(_) => quote!(#left = #left.wrapping_mul(#right)),
79            DivAssign(_) => quote!(#left = #left.wrapping_div(#right)),
80            RemAssign(_) => quote!(#left = #left.wrapping_rem(#right)),
81            ShlAssign(_) => quote!(#left = #left.wrapping_shl(#right)),
82            ShrAssign(_) => quote!(#left = #left.wrapping_shr(#right)),
83
84            _ => quote!((#left) #op (#right)),
85        }
86    }
87
88    fn sub_unop(&self, op: UnOp, x: TokenStream) -> TokenStream {
89        match op {
90            UnOp::Neg(_) => quote!((#x).wrapping_neg()),
91            _ => quote!(#op #x),
92        }
93    }
94}
95
96struct Saturating;
97impl Sub for Saturating {
98    fn sub_bin(&self, op: BinOp, left: TokenStream, right: TokenStream) -> TokenStream {
99        use syn::BinOp::*;
100        match op {
101            Add(_) => quote!((#left).saturating_add(#right)),
102            Sub(_) => quote!((#left).saturating_sub(#right)),
103            Mul(_) => quote!((#left).saturating_mul(#right)),
104            Div(_) => quote!((#left).saturating_div(#right)),
105            Rem(_) => quote!((#left).saturating_rem(#right)),
106            Shl(_) => quote!((#left).saturating_shl(#right)),
107            Shr(_) => quote!((#left).saturating_shr(#right)),
108
109            SubAssign(_) => quote!(#left = #left.saturating_sub(#right)),
110            AddAssign(_) => quote!(#left = #left.saturating_add(#right)),
111            MulAssign(_) => quote!(#left = #left.saturating_mul(#right)),
112            DivAssign(_) => quote!(#left = #left.saturating_div(#right)),
113            RemAssign(_) => quote!(#left = #left.saturating_rem(#right)),
114            ShlAssign(_) => quote!(#left = #left.saturating_shl(#right)),
115            ShrAssign(_) => quote!(#left = #left.saturating_shr(#right)),
116
117            _ => quote!((#left) #op (#right)),
118        }
119    }
120
121    fn sub_unop(&self, op: UnOp, x: TokenStream) -> TokenStream {
122        match op {
123            UnOp::Neg(_) => quote!((#x).saturating_neg()),
124            _ => quote!(#op #x),
125        }
126    }
127}
128
129struct Algebraic;
130impl Sub for Algebraic {
131    fn sub_bin(&self, op: BinOp, left: TokenStream, right: TokenStream) -> TokenStream {
132        use syn::BinOp::*;
133        match op {
134            Add(_) => quote!(core::intrinsics::fadd_algebraic(#left, #right)),
135            Sub(_) => quote!(core::intrinsics::fsub_algebraic(#left, #right)),
136            Mul(_) => quote!(core::intrinsics::fmul_algebraic(#left, #right)),
137            Div(_) => quote!(core::intrinsics::fdiv_algebraic(#left, #right)),
138            Rem(_) => quote!(core::intrinsics::frem_algebraic(#left, #right)),
139            And(_) => quote!(core::intrinsics::fand_algebraic(#left, #right)),
140            _ => quote!((#left) #op (#right)),
141        }
142    }
143
144    fn sub_unop(&self, op: UnOp, x: TokenStream) -> TokenStream {
145        quote!(#op #x)
146    }
147}
148
149struct Fast;
150impl Sub for Fast {
151    fn sub_bin(&self, op: BinOp, left: TokenStream, right: TokenStream) -> TokenStream {
152        use syn::BinOp::*;
153        match op {
154            Add(_) => quote!(core::intrinsics::fadd_fast(#left, #right)),
155            Sub(_) => quote!(core::intrinsics::fsub_fast(#left, #right)),
156            Mul(_) => quote!(core::intrinsics::fmul_fast(#left, #right)),
157            Div(_) => quote!(core::intrinsics::fdiv_fast(#left, #right)),
158            Rem(_) => quote!(core::intrinsics::frem_fast(#left, #right)),
159            And(_) => quote!(core::intrinsics::fand_fast(#left, #right)),
160            Eq(_) => quote!(/* eq */ ((#left) + 0.0).to_bits() == ((#right) + 0.0).to_bits()),
161            _ => quote!((#left) #op (#right)),
162        }
163    }
164
165    fn sub_unop(&self, op: UnOp, x: TokenStream) -> TokenStream {
166        quote!(#op #x)
167    }
168}
169
170fn walk(sub: &impl Sub, e: Expr) -> TokenStream {
171    let walk = |e| walk(sub, e);
172    let map_block = |b| map_block(sub, b);
173    match e {
174        Expr::Binary(ExprBinary {
175            left, op, right, ..
176        }) => {
177            let left = walk(*left);
178            let right = walk(*right);
179            sub.sub_bin(op, left, right)
180        }
181        Expr::Unary(ExprUnary { op, expr, .. }) => sub.sub_unop(op, walk(*expr)),
182        Expr::Break(ExprBreak {
183            label,
184            expr: Some(expr),
185            ..
186        }) => {
187            let expr = walk(*expr);
188            quote!(#label #expr)
189        }
190        Expr::Call(ExprCall { func, args, .. }) => {
191            let f = walk(*func);
192            let args = args.into_iter().map(walk);
193            quote!(#f ( #(#args),* ))
194        }
195        Expr::Closure(ExprClosure {
196            lifetimes,
197            constness,
198            movability,
199            asyncness,
200            capture,
201            inputs,
202            output,
203            body,
204            ..
205        }) => {
206            let body = walk(*body);
207            quote!(#lifetimes #constness #movability #asyncness #capture |#inputs| #output { #body })
208        }
209        Expr::ForLoop(ExprForLoop {
210            label,
211            pat,
212            expr,
213            body,
214            ..
215        }) => {
216            let (expr, body) = (walk(*expr), map_block(body));
217            quote!(#label for #pat in #expr #body)
218        }
219        Expr::Let(ExprLet { pat, expr, .. }) => {
220            quote_with!(expr = walk(*expr) => let #pat = #expr)
221        }
222        Expr::Const(ExprConst { block, .. }) => {
223            quote_with!(block =map_block(block) => const #block)
224        }
225        Expr::Range(ExprRange {
226            start, limits, end, ..
227        }) => {
228            let (start, end) = (start.map(|x| walk(*x)), end.map(|x| walk(*x)));
229            quote!((#start #limits #end))
230        }
231        Expr::Return(ExprReturn { expr, .. }) => {
232            let expr = expr.map(|x| walk(*x));
233            quote!(return #expr;)
234        }
235        Expr::Try(ExprTry { expr, .. }) => {
236            let expr = walk(*expr);
237            quote!(#expr ?)
238        }
239        Expr::TryBlock(ExprTryBlock { block, .. }) => {
240            let block = map_block(block);
241            quote!(try #block)
242        }
243        Expr::Unsafe(ExprUnsafe { block, .. }) => {
244            quote_with!(block =map_block(block) => unsafe #block)
245        }
246        Expr::While(ExprWhile {
247            label, cond, body, ..
248        }) => {
249            quote_with!(cond = walk(*cond); body =map_block(body) => #label while #cond #body)
250        }
251        Expr::Index(ExprIndex { expr, index, .. }) => {
252            let expr = walk(*expr);
253            let index = walk(*index);
254            quote!(#expr [ #index ])
255        }
256        Expr::Loop(ExprLoop { label, body, .. }) => {
257            quote_with!(body =map_block(body) => #label loop #body)
258        }
259        Expr::Reference(ExprReference {
260            mutability, expr, ..
261        }) => {
262            let expr = walk(*expr);
263            quote!(& #mutability #expr)
264        }
265        Expr::MethodCall(ExprMethodCall {
266            receiver,
267            method,
268            turbofish,
269            args,
270            ..
271        }) => {
272            let receiver = walk(*receiver);
273            let args = args.into_iter().map(walk);
274            quote!(#receiver . #method #turbofish (#(#args,)*))
275        }
276        Expr::If(ExprIf {
277            cond,
278            then_branch,
279            else_branch: Some((_, else_branch)),
280            ..
281        }) => {
282            let (cond, then_branch, else_branch) =
283                (walk(*cond), map_block(then_branch), walk(*else_branch));
284            quote!(if #cond #then_branch else #else_branch)
285        }
286        Expr::If(ExprIf {
287            cond, then_branch, ..
288        }) => {
289            let (cond, then_branch) = (walk(*cond), map_block(then_branch));
290            quote!(if #cond #then_branch)
291        }
292        Expr::Async(ExprAsync {
293            attrs,
294            capture,
295            block,
296            ..
297        }) => {
298            let block = map_block(block);
299            quote!(#(#attrs)* async #capture #block)
300        }
301        Expr::Await(ExprAwait { base, .. }) => {
302            let base = walk(*base);
303            quote!(#base.await)
304        }
305        Expr::Assign(ExprAssign { left, right, .. }) => {
306            let (left, right) = (walk(*left), walk(*right));
307            quote!(#left = #right;)
308        }
309        Expr::Paren(ExprParen { expr, .. }) => {
310            let expr = walk(*expr);
311            quote!(#expr)
312        }
313        Expr::Tuple(ExprTuple { elems, .. }) => {
314            let ts = elems.into_iter().map(walk);
315            quote!((#(#ts,)*))
316        }
317        Expr::Array(ExprArray { elems, .. }) => {
318            let ts = elems.into_iter().map(walk);
319            quote!([#(#ts,)*])
320        }
321        Expr::Repeat(ExprRepeat { expr, len, .. }) => {
322            let x = walk(*expr);
323            let len = walk(*len);
324            quote!([ #x ; #len ])
325        }
326        Expr::Block(ExprBlock {
327            block,
328            label: Some(label),
329            ..
330        }) => {
331            let b = map_block(block);
332            quote! { #label: #b }
333        }
334        Expr::Block(ExprBlock { block, .. }) => map_block(block),
335        e => quote!(#e),
336    }
337}
338
339fn map_block(sub: &impl Sub, Block { stmts, .. }: Block) -> TokenStream {
340    let stmts = stmts.into_iter().map(|x| walk_stmt(sub, x));
341    quote! { { #(#stmts)* } }
342}
343
344fn walk_stmt(sub: &impl Sub, x: Stmt) -> TokenStream {
345    let walk = |e| walk(sub, e);
346    match x {
347        Stmt::Local(Local {
348            pat,
349            init:
350                Some(LocalInit {
351                    expr,
352                    diverge: Some((_, diverge)),
353                    ..
354                }),
355            ..
356        }) => {
357            let expr = walk(*expr);
358            let diverge = walk(*diverge);
359            quote!(let #pat = #expr else { #diverge };)
360        }
361        Stmt::Local(Local {
362            pat,
363            init: Some(LocalInit { expr, .. }),
364            ..
365        }) => {
366            let expr = walk(*expr);
367            quote!(let #pat = #expr;)
368        }
369        Stmt::Item(x) => walk_item(sub, x),
370        Stmt::Expr(e, t) => {
371            let e = walk(e);
372            quote!(#e #t)
373        }
374        e => quote!(#e),
375    }
376}
377
378fn walk_item(sub: &impl Sub, x: Item) -> TokenStream {
379    let walk = |e| walk(sub, e);
380    match x {
381        Item::Const(ItemConst {
382            vis,
383            ident,
384            ty,
385            expr,
386            ..
387        }) => {
388            let expr = walk(*expr);
389            quote!(#vis const #ident : #ty = #expr;)
390        }
391        Item::Fn(ItemFn {
392            vis,
393            attrs,
394            sig,
395            block,
396        }) => {
397            let block = map_block(sub, *block);
398            quote!( #(#attrs)* #vis #sig #block)
399        }
400        Item::Impl(ItemImpl {
401            attrs,
402            unsafety,
403            defaultness,
404            generics,
405            trait_,
406            self_ty,
407            items,
408            ..
409        }) => {
410            let items = items.into_iter().map(|x| match x {
411                ImplItem::Const(ImplItemConst {
412                    vis,
413                    attrs,
414                    defaultness,
415                    ident,
416                    ty,
417                    expr,
418                    ..
419                }) => {
420                    let expr = walk(expr);
421                    quote!(#(#attrs)* #vis #defaultness const #ident: #ty = #expr;)
422                }
423                ImplItem::Fn(ImplItemFn {
424                    attrs,
425                    vis,
426                    defaultness,
427                    sig,
428                    block,
429                }) => {
430                    let block = map_block(sub, block);
431                    quote!(#(#attrs)* #vis #defaultness #sig #block)
432                }
433                e => quote!(#e),
434            });
435            let trait_ = trait_.map(|(n, pat, fr)| quote!(#n #pat #fr));
436            quote!(#(#attrs)* #unsafety #defaultness impl #generics #trait_ #self_ty { #(#items)* })
437        }
438        Item::Mod(ItemMod {
439            attrs,
440            vis,
441            ident,
442            content: Some((_, content)),
443            ..
444        }) => {
445            let content = content.into_iter().map(|x| walk_item(sub, x));
446            quote!(#(#attrs)* #vis mod #ident { #(#content)* })
447        }
448        Item::Static(ItemStatic {
449            attrs,
450            vis,
451            mutability,
452            ident,
453            ty,
454            expr,
455            ..
456        }) => {
457            let expr = walk(*expr);
458            quote!(#(#attrs)* #vis static #mutability #ident: #ty = #expr)
459        }
460        e => quote!(#e),
461    }
462}
463
/// Parses `$input` as an expression, falling back to a statement (which also
/// covers items), and rewrites it with the lowering `$t`.
///
/// If neither parse succeeds, the error from the expression parse is reported.
macro_rules! walk {
    ($input:ident,$t:expr) => {{
        let lowered: TokenStream = match parse::<Expr>($input.clone()) {
            Ok(expr) => walk(&$t, expr),
            Err(expr_error) => match parse::<Stmt>($input) {
                Ok(stmt) => walk_stmt(&$t, stmt),
                Err(_) => expr_error.to_compile_error().into_token_stream(),
            },
        };
        lowered.into()
    }};
}
478
479#[proc_macro]
480pub fn math(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
481    walk!(input, Basic {})
482}
483
484#[proc_macro]
485pub fn fast(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
486    walk!(input, Fast {})
487}
488
489#[proc_macro]
490pub fn algebraic(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
491    walk!(input, Algebraic {})
492}
493
494#[proc_macro]
495pub fn wrapping(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
496    walk!(input, Wrapping {})
497}
498
499#[proc_macro]
500pub fn saturating(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
501    walk!(input, Saturating {})
502}
503
504#[proc_macro_attribute]
505pub fn apply(
506    args: proc_macro::TokenStream,
507    input: proc_macro::TokenStream,
508) -> proc_macro::TokenStream {
509    match &*args.to_string() {
510        "basic" | "" => math(input),
511        "fast" => fast(input),
512        "algebraic" => algebraic(input),
513        "wrapping" => wrapping(input),
514        "saturating" => saturating(input),
515        _ => {
516            quote! { compile_error!("type must be {fast, basic, algebraic, wrapping, saturating}") }
517                .into()
518        }
519    }
520}