lower_macros/
lib.rs

//! a lil macro crate.
//!
//! provides handy macros for converting `a + b` to `a.add(b)` (and similar lowerings of the other operators) for when you can't easily overload the `Add` trait.
4use proc_macro2::TokenStream;
5use quote::{quote, ToTokens};
6use syn::{spanned::Spanned, *};
7
/// Binds each `name = value` pair as a local, then runs `quote!` over the
/// tokens after `=>`, so those names can be interpolated as `#name`.
///
/// `quote_with!(a = x; b = y => #a #b)` is shorthand for
/// `{ let a = x; let b = y; quote!(#a #b) }`.
macro_rules! quote_with {
    ($($name:ident = $value:expr);+ => $($body:tt)+) => {{
        $(
            let $name = $value;
        )+
        quote!($($body)+)
    }};
}
/// A lowering strategy: decides which tokens replace an operator once its
/// operand(s) have already been lowered by `walk`.
trait Sub {
    /// Replace `left op right`; `left` and `right` are the already-lowered operand tokens.
    fn sub_bin(&self, op: BinOp, left: TokenStream, right: TokenStream) -> TokenStream;
    /// Replace `op x`; `x` is the already-lowered operand tokens.
    fn sub_unop(&self, op: UnOp, x: TokenStream) -> TokenStream;
}
18
19struct Basic;
20impl Sub for Basic {
21    fn sub_bin(&self, op: BinOp, left: TokenStream, right: TokenStream) -> TokenStream {
22        use syn::BinOp::*;
23        match op {
24            Add(_) => quote!((#left).add(#right)),
25            Sub(_) => quote!((#left).sub(#right)),
26            Mul(_) => quote!((#left).mul(#right)),
27            Div(_) => quote!((#left).div(#right)),
28            Rem(_) => quote!((#left).rem(#right)),
29            And(_) => quote!((#left).and(#right)),
30            Or(_) => quote!((#left).or(#right)),
31            BitXor(_) => quote!((#left).bitxor(#right)),
32            BitAnd(_) => quote!((#left).bitand(#right)),
33            BitOr(_) => quote!((#left).bitor(#right)),
34            Shl(_) => quote!((#left).shl(#right)),
35            Shr(_) => quote!((#left).shr(#right)),
36            Eq(_) => quote!((#left).eq(#right)),
37            Lt(_) => quote!((#left).lt(#right)),
38            Le(_) => quote!((#left).le(#right)),
39            Ne(_) => quote!((#left).ne(#right)),
40            Ge(_) => quote!((#left).ge(#right)),
41            Gt(_) => quote!((#left).gt(#right)),
42            // don't support assigning ops
43            e => {
44                Error::new(e.span(), format!("{}", quote!(op #e not supported))).to_compile_error()
45            }
46        }
47    }
48
49    fn sub_unop(&self, op: UnOp, x: TokenStream) -> TokenStream {
50        match op {
51            UnOp::Deref(_) => quote!((#x).deref()),
52            UnOp::Not(_) => quote!((#x).not()),
53            UnOp::Neg(_) => quote!((#x).neg()),
54            e => Error::new(
55                e.span(),
56                "it would appear a new operation has been added! please tell me.",
57            )
58            .to_compile_error(),
59        }
60    }
61}
62
63struct Wrapping;
64impl Sub for Wrapping {
65    fn sub_bin(&self, op: BinOp, left: TokenStream, right: TokenStream) -> TokenStream {
66        use syn::BinOp::*;
67        match op {
68            Add(_) => quote!((#left).wrapping_add(#right)),
69            Sub(_) => quote!((#left).wrapping_sub(#right)),
70            Mul(_) => quote!((#left).wrapping_mul(#right)),
71            Div(_) => quote!((#left).wrapping_div(#right)),
72            Rem(_) => quote!((#left).wrapping_rem(#right)),
73            Shl(_) => quote!((#left).wrapping_shl(#right)),
74            Shr(_) => quote!((#left).wrapping_shr(#right)),
75
76            SubAssign(_) => quote!(#left = #left.wrapping_sub(#right)),
77            AddAssign(_) => quote!(#left = #left.wrapping_add(#right)),
78            MulAssign(_) => quote!(#left = #left.wrapping_mul(#right)),
79            DivAssign(_) => quote!(#left = #left.wrapping_div(#right)),
80            RemAssign(_) => quote!(#left = #left.wrapping_rem(#right)),
81            ShlAssign(_) => quote!(#left = #left.wrapping_shl(#right)),
82            ShrAssign(_) => quote!(#left = #left.wrapping_shr(#right)),
83
84            _ => quote!((#left) #op (#right)),
85        }
86    }
87
88    fn sub_unop(&self, op: UnOp, x: TokenStream) -> TokenStream {
89        match op {
90            UnOp::Neg(_) => quote!((#x).wrapping_neg()),
91            _ => quote!(#op #x),
92        }
93    }
94}
95
96struct Saturating;
97impl Sub for Saturating {
98    fn sub_bin(&self, op: BinOp, left: TokenStream, right: TokenStream) -> TokenStream {
99        use syn::BinOp::*;
100        match op {
101            Add(_) => quote!((#left).saturating_add(#right)),
102            Sub(_) => quote!((#left).saturating_sub(#right)),
103            Mul(_) => quote!((#left).saturating_mul(#right)),
104            Div(_) => quote!((#left).saturating_div(#right)),
105            Rem(_) => quote!((#left).saturating_rem(#right)),
106            Shl(_) => quote!((#left).saturating_shl(#right)),
107            Shr(_) => quote!((#left).saturating_shr(#right)),
108
109            SubAssign(_) => quote!(#left = #left.saturating_sub(#right)),
110            AddAssign(_) => quote!(#left = #left.saturating_add(#right)),
111            MulAssign(_) => quote!(#left = #left.saturating_mul(#right)),
112            DivAssign(_) => quote!(#left = #left.saturating_div(#right)),
113            RemAssign(_) => quote!(#left = #left.saturating_rem(#right)),
114            ShlAssign(_) => quote!(#left = #left.saturating_shl(#right)),
115            ShrAssign(_) => quote!(#left = #left.saturating_shr(#right)),
116
117            _ => quote!((#left) #op (#right)),
118        }
119    }
120
121    fn sub_unop(&self, op: UnOp, x: TokenStream) -> TokenStream {
122        match op {
123            UnOp::Neg(_) => quote!((#x).saturating_neg()),
124            _ => quote!(#op #x),
125        }
126    }
127}
128
129struct Algebraic;
130impl Sub for Algebraic {
131    fn sub_bin(&self, op: BinOp, left: TokenStream, right: TokenStream) -> TokenStream {
132        use syn::BinOp::*;
133        match op {
134            Add(_) => quote!(core::intrinsics::fadd_algebraic(#left, #right)),
135            Sub(_) => quote!(core::intrinsics::fsub_algebraic(#left, #right)),
136            Mul(_) => quote!(core::intrinsics::fmul_algebraic(#left, #right)),
137            Div(_) => quote!(core::intrinsics::fdiv_algebraic(#left, #right)),
138            Rem(_) => quote!(core::intrinsics::frem_algebraic(#left, #right)),
139            And(_) => quote!(core::intrinsics::fand_algebraic(#left, #right)),
140            _ => quote!((#left) #op (#right)),
141        }
142    }
143
144    fn sub_unop(&self, op: UnOp, x: TokenStream) -> TokenStream {
145        quote!(#op #x)
146    }
147}
148
149struct Fast;
150impl Sub for Fast {
151    fn sub_bin(&self, op: BinOp, left: TokenStream, right: TokenStream) -> TokenStream {
152        use syn::BinOp::*;
153        match op {
154            Add(_) => quote!(core::intrinsics::fadd_fast(#left, #right)),
155            Sub(_) => quote!(core::intrinsics::fsub_fast(#left, #right)),
156            Mul(_) => quote!(core::intrinsics::fmul_fast(#left, #right)),
157            Div(_) => quote!(core::intrinsics::fdiv_fast(#left, #right)),
158            Rem(_) => quote!(core::intrinsics::frem_fast(#left, #right)),
159            And(_) => quote!(core::intrinsics::fand_fast(#left, #right)),
160            Eq(_) => quote!(/* eq */ ((#left) + 0.0).to_bits() == ((#right) + 0.0).to_bits()),
161            _ => quote!((#left) #op (#right)),
162        }
163    }
164
165    fn sub_unop(&self, op: UnOp, x: TokenStream) -> TokenStream {
166        quote!(#op #x)
167    }
168}
169
170fn walk(sub: &impl Sub, e: Expr) -> TokenStream {
171    let walk = |e| walk(sub, e);
172    let map_block = |b| map_block(sub, b);
173    match e {
174        Expr::Binary(ExprBinary {
175            left, op, right, ..
176        }) => {
177            let left = walk(*left);
178            let right = walk(*right);
179            sub.sub_bin(op, left, right)
180        }
181        Expr::Unary(ExprUnary { op, expr, .. }) => sub.sub_unop(op, walk(*expr)),
182        Expr::Break(ExprBreak {
183            label,
184            expr: Some(expr),
185            ..
186        }) => {
187            let expr = walk(*expr);
188            quote!(#label #expr)
189        }
190        Expr::Call(ExprCall { func, args, .. }) => {
191            let f = walk(*func);
192            let args = args.into_iter().map(walk);
193            quote!(#f ( #(#args),* ))
194        }
195        Expr::Closure(ExprClosure {
196            lifetimes,
197            constness,
198            movability,
199            asyncness,
200            capture,
201            inputs,
202            output,
203            body,
204            ..
205        }) => {
206            let body = walk(*body);
207            quote!(#lifetimes #constness #movability #asyncness #capture |#inputs| #output { #body })
208        }
209        Expr::ForLoop(ExprForLoop {
210            label,
211            pat,
212            expr,
213            body,
214            ..
215        }) => {
216            let (expr, body) = (walk(*expr), map_block(body));
217            quote!(#label for #pat in #expr #body)
218        }
219        Expr::Let(ExprLet { pat, expr, .. }) => {
220            quote_with!(expr = walk(*expr) => let #pat = #expr)
221        }
222        Expr::Const(ExprConst { block, .. }) => {
223            quote_with!(block =map_block(block) => const #block)
224        }
225        Expr::Range(ExprRange {
226            start, limits, end, ..
227        }) => {
228            let (start, end) = (start.map(|x| walk(*x)), end.map(|x| walk(*x)));
229            quote!((#start #limits #end))
230        }
231        Expr::Return(ExprReturn { expr, .. }) => {
232            let expr = expr.map(|x| walk(*x));
233            quote!(return #expr;)
234        }
235        Expr::Try(ExprTry { expr, .. }) => {
236            let expr = walk(*expr);
237            quote!(#expr ?)
238        }
239        Expr::TryBlock(ExprTryBlock { block, .. }) => {
240            let block = map_block(block);
241            quote!(try #block)
242        }
243        Expr::Unsafe(ExprUnsafe { block, .. }) => {
244            quote_with!(block =map_block(block) => unsafe #block)
245        }
246        Expr::While(ExprWhile {
247            label, cond, body, ..
248        }) => {
249            quote_with!(cond = walk(*cond); body =map_block(body) => #label while #cond #body)
250        }
251        Expr::Index(ExprIndex { expr, index, .. }) => {
252            let expr = walk(*expr);
253            let index = walk(*index);
254            quote!(#expr [ #index ])
255        }
256        Expr::Loop(ExprLoop { label, body, .. }) => {
257            quote_with!(body =map_block(body) => #label loop #body)
258        }
259        Expr::Reference(ExprReference {
260            mutability, expr, ..
261        }) => {
262            let expr = walk(*expr);
263            quote!(& #mutability #expr)
264        }
265        Expr::MethodCall(ExprMethodCall {
266            receiver,
267            method,
268            turbofish,
269            args,
270            ..
271        }) => {
272            let receiver = walk(*receiver);
273            let args = args.into_iter().map(walk);
274            quote!(#receiver . #method #turbofish (#(#args,)*))
275        }
276        Expr::Match(ExprMatch { expr, arms, .. }) => {
277            let arms = arms.into_iter().map(
278                |Arm {
279                     pat,
280                     guard,
281
282                     body,
283                     //  comma,
284                     ..
285                 }| {
286                    let b = walk(*body);
287                    let guard = match guard {
288                        Some((i, x)) => {
289                            let z = walk(*x);
290                            quote! { #i #z }
291                        }
292                        None => quote! {},
293                    };
294                    quote! { #pat #guard => { #b } }
295                },
296            );
297            quote!(match #expr { #(#arms)* })
298        }
299        Expr::If(ExprIf {
300            cond,
301            then_branch,
302            else_branch: Some((_, else_branch)),
303            ..
304        }) => {
305            let (cond, then_branch, else_branch) =
306                (walk(*cond), map_block(then_branch), walk(*else_branch));
307            quote!(if #cond #then_branch else #else_branch)
308        }
309        Expr::If(ExprIf {
310            cond, then_branch, ..
311        }) => {
312            let (cond, then_branch) = (walk(*cond), map_block(then_branch));
313            quote!(if #cond #then_branch)
314        }
315        Expr::Async(ExprAsync {
316            attrs,
317            capture,
318            block,
319            ..
320        }) => {
321            let block = map_block(block);
322            quote!(#(#attrs)* async #capture #block)
323        }
324        Expr::Await(ExprAwait { base, .. }) => {
325            let base = walk(*base);
326            quote!(#base.await)
327        }
328        Expr::Assign(ExprAssign { left, right, .. }) => {
329            let (left, right) = (walk(*left), walk(*right));
330            quote!(#left = #right;)
331        }
332        Expr::Paren(ExprParen { expr, .. }) => {
333            let expr = walk(*expr);
334            quote!(#expr)
335        }
336        Expr::Tuple(ExprTuple { elems, .. }) => {
337            let ts = elems.into_iter().map(walk);
338            quote!((#(#ts,)*))
339        }
340        Expr::Array(ExprArray { elems, .. }) => {
341            let ts = elems.into_iter().map(walk);
342            quote!([#(#ts,)*])
343        }
344        Expr::Repeat(ExprRepeat { expr, len, .. }) => {
345            let x = walk(*expr);
346            let len = walk(*len);
347            quote!([ #x ; #len ])
348        }
349        Expr::Block(ExprBlock {
350            block,
351            label: Some(label),
352            ..
353        }) => {
354            let b = map_block(block);
355            quote! { #label: #b }
356        }
357        Expr::Block(ExprBlock { block, .. }) => map_block(block),
358        e => quote!(#e),
359    }
360}
361
362fn map_block(sub: &impl Sub, Block { stmts, .. }: Block) -> TokenStream {
363    let stmts = stmts.into_iter().map(|x| walk_stmt(sub, x));
364    quote! { { #(#stmts)* } }
365}
366
367fn walk_stmt(sub: &impl Sub, x: Stmt) -> TokenStream {
368    let walk = |e| walk(sub, e);
369    match x {
370        Stmt::Local(Local {
371            pat,
372            init:
373                Some(LocalInit {
374                    expr,
375                    diverge: Some((_, diverge)),
376                    ..
377                }),
378            ..
379        }) => {
380            let expr = walk(*expr);
381            let diverge = walk(*diverge);
382            quote!(let #pat = #expr else { #diverge };)
383        }
384        Stmt::Local(Local {
385            pat,
386            init: Some(LocalInit { expr, .. }),
387            ..
388        }) => {
389            let expr = walk(*expr);
390            quote!(let #pat = #expr;)
391        }
392        Stmt::Item(x) => walk_item(sub, x),
393        Stmt::Expr(e, t) => {
394            let e = walk(e);
395            quote!(#e #t)
396        }
397        e => quote!(#e),
398    }
399}
400
401fn walk_item(sub: &impl Sub, x: Item) -> TokenStream {
402    let walk = |e| walk(sub, e);
403    match x {
404        Item::Const(ItemConst {
405            vis,
406            ident,
407            ty,
408            expr,
409            ..
410        }) => {
411            let expr = walk(*expr);
412            quote!(#vis const #ident : #ty = #expr;)
413        }
414        Item::Fn(ItemFn {
415            vis,
416            attrs,
417            sig,
418            block,
419        }) => {
420            let block = map_block(sub, *block);
421            quote!( #(#attrs)* #vis #sig #block)
422        }
423        Item::Impl(ItemImpl {
424            attrs,
425            unsafety,
426            defaultness,
427            generics,
428            trait_,
429            self_ty,
430            items,
431            ..
432        }) => {
433            let items = items.into_iter().map(|x| match x {
434                ImplItem::Const(ImplItemConst {
435                    vis,
436                    attrs,
437                    defaultness,
438                    ident,
439                    ty,
440                    expr,
441                    ..
442                }) => {
443                    let expr = walk(expr);
444                    quote!(#(#attrs)* #vis #defaultness const #ident: #ty = #expr;)
445                }
446                ImplItem::Fn(ImplItemFn {
447                    attrs,
448                    vis,
449                    defaultness,
450                    sig,
451                    block,
452                }) => {
453                    let block = map_block(sub, block);
454                    quote!(#(#attrs)* #vis #defaultness #sig #block)
455                }
456                e => quote!(#e),
457            });
458            let trait_ = trait_.map(|(n, pat, fr)| quote!(#n #pat #fr));
459            quote!(#(#attrs)* #unsafety #defaultness impl #generics #trait_ #self_ty { #(#items)* })
460        }
461        Item::Mod(ItemMod {
462            attrs,
463            vis,
464            ident,
465            content: Some((_, content)),
466            ..
467        }) => {
468            let content = content.into_iter().map(|x| walk_item(sub, x));
469            quote!(#(#attrs)* #vis mod #ident { #(#content)* })
470        }
471        Item::Static(ItemStatic {
472            attrs,
473            vis,
474            mutability,
475            ident,
476            ty,
477            expr,
478            ..
479        }) => {
480            let expr = walk(*expr);
481            quote!(#(#attrs)* #vis static #mutability #ident: #ty = #expr)
482        }
483        e => quote!(#e),
484    }
485}
486
/// Shared body of the function-like macros. It parses `$input` as an
/// expression and lowers it with the strategy `$t`.
///
/// If `$input` is not an expression, it is parsed as a statement instead
/// (which also covers items such as a whole `fn`). If that fails too, the
/// expression parse error is emitted. The result is converted into the
/// caller's `proc_macro::TokenStream` via `.into()`.
macro_rules! walk {
    ($input:ident, $t:expr) => {{
        let lowered: TokenStream = match parse::<Expr>($input.clone()) {
            Ok(expr) => walk(&$t, expr),
            Err(expr_err) => match parse::<Stmt>($input) {
                Ok(stmt) => walk_stmt(&$t, stmt),
                Err(_) => expr_err.to_compile_error(),
            },
        };
        lowered.into()
    }};
}
501
502#[proc_macro]
503pub fn math(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
504    walk!(input, Basic {})
505}
506
507#[proc_macro]
508pub fn fast(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
509    walk!(input, Fast {})
510}
511
512#[proc_macro]
513pub fn algebraic(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
514    walk!(input, Algebraic {})
515}
516
517#[proc_macro]
518pub fn wrapping(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
519    walk!(input, Wrapping {})
520}
521
522#[proc_macro]
523pub fn saturating(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
524    walk!(input, Saturating {})
525}
526
527#[proc_macro_attribute]
528pub fn apply(
529    args: proc_macro::TokenStream,
530    input: proc_macro::TokenStream,
531) -> proc_macro::TokenStream {
532    match &*args.to_string() {
533        "basic" | "" => math(input),
534        "fast" => fast(input),
535        "algebraic" => algebraic(input),
536        "wrapping" => wrapping(input),
537        "saturating" => saturating(input),
538        _ => {
539            quote! { compile_error!("type must be {fast, basic, algebraic, wrapping, saturating}") }
540                .into()
541        }
542    }
543}