lower_macros/
lib.rs

1//! a lil macro crate.
2//!
//! provides a handy macro for converting `a + b` to `a.add(b)` for when you can't easily overload the `Add` trait.
4use proc_macro2::TokenStream;
5use quote::{ToTokens, quote};
6use syn::{spanned::Spanned, *};
7
/// Binds each `name = value` pair (pairs are separated by `;`) to a local
/// variable, then expands `quote!` over the tokens that follow `=>`, so the
/// template can interpolate the freshly bound names with `#name`.
macro_rules! quote_with {
    ($($name:ident = $value:expr);+ => $($template:tt)+) => {
        {
            $(
                let $name = $value;
            )+
            quote!($($template)+)
        }
    };
}
14trait Sub {
15    fn sub_bin(&self, op: BinOp, left: TokenStream, right: TokenStream) -> TokenStream;
16    fn sub_unop(&self, op: UnOp, x: TokenStream) -> TokenStream;
17}
18
19struct Basic;
20impl Sub for Basic {
21    fn sub_bin(&self, op: BinOp, left: TokenStream, right: TokenStream) -> TokenStream {
22        use syn::BinOp::*;
23        match op {
24            Add(_) => quote!((#left).add(#right)),
25            Sub(_) => quote!((#left).sub(#right)),
26            Mul(_) => quote!((#left).mul(#right)),
27            Div(_) => quote!((#left).div(#right)),
28            Rem(_) => quote!((#left).rem(#right)),
29            And(_) => quote!((#left).and(#right)),
30            Or(_) => quote!((#left).or(#right)),
31            BitXor(_) => quote!((#left).bitxor(#right)),
32            BitAnd(_) => quote!((#left).bitand(#right)),
33            BitOr(_) => quote!((#left).bitor(#right)),
34            Shl(_) => quote!((#left).shl(#right)),
35            Shr(_) => quote!((#left).shr(#right)),
36            Eq(_) => quote!((#left).eq(#right)),
37            Lt(_) => quote!((#left).lt(#right)),
38            Le(_) => quote!((#left).le(#right)),
39            Ne(_) => quote!((#left).ne(#right)),
40            Ge(_) => quote!((#left).ge(#right)),
41            Gt(_) => quote!((#left).gt(#right)),
42            // don't support assigning ops
43            e => {
44                Error::new(e.span(), format!("{}", quote!(op #e not supported))).to_compile_error()
45            }
46        }
47    }
48
49    fn sub_unop(&self, op: UnOp, x: TokenStream) -> TokenStream {
50        match op {
51            UnOp::Deref(_) => quote!((#x).deref()),
52            UnOp::Not(_) => quote!((#x).not()),
53            UnOp::Neg(_) => quote!((#x).neg()),
54            e => Error::new(
55                e.span(),
56                "it would appear a new operation has been added! please tell me.",
57            )
58            .to_compile_error(),
59        }
60    }
61}
62
63struct Wrapping;
64impl Sub for Wrapping {
65    fn sub_bin(&self, op: BinOp, left: TokenStream, right: TokenStream) -> TokenStream {
66        use syn::BinOp::*;
67        match op {
68            Add(_) => quote!((#left).wrapping_add(#right)),
69            Sub(_) => quote!((#left).wrapping_sub(#right)),
70            Mul(_) => quote!((#left).wrapping_mul(#right)),
71            Div(_) => quote!((#left).wrapping_div(#right)),
72            Rem(_) => quote!((#left).wrapping_rem(#right)),
73            Shl(_) => quote!((#left).wrapping_shl(#right)),
74            Shr(_) => quote!((#left).wrapping_shr(#right)),
75
76            SubAssign(_) => quote!(#left = #left.wrapping_sub(#right)),
77            AddAssign(_) => quote!(#left = #left.wrapping_add(#right)),
78            MulAssign(_) => quote!(#left = #left.wrapping_mul(#right)),
79            DivAssign(_) => quote!(#left = #left.wrapping_div(#right)),
80            RemAssign(_) => quote!(#left = #left.wrapping_rem(#right)),
81            ShlAssign(_) => quote!(#left = #left.wrapping_shl(#right)),
82            ShrAssign(_) => quote!(#left = #left.wrapping_shr(#right)),
83
84            _ => match (|| {
85                syn::Result::Ok(
86                    ExprBinary {
87                        attrs: Default::default(),
88                        left: syn::parse(left.into())?,
89                        op,
90                        right: syn::parse(right.into())?,
91                    }
92                    .to_token_stream(),
93                )
94            })() {
95                Ok(x) => x,
96                Err(e) => e.into_compile_error(),
97            },
98        }
99    }
100
101    fn sub_unop(&self, op: UnOp, x: TokenStream) -> TokenStream {
102        match op {
103            UnOp::Neg(_) => quote!((#x).wrapping_neg()),
104            _ => quote!(#op #x),
105        }
106    }
107}
108
109struct Saturating;
110impl Sub for Saturating {
111    fn sub_bin(&self, op: BinOp, left: TokenStream, right: TokenStream) -> TokenStream {
112        use syn::BinOp::*;
113        match op {
114            Add(_) => quote!((#left).saturating_add(#right)),
115            Sub(_) => quote!((#left).saturating_sub(#right)),
116            Mul(_) => quote!((#left).saturating_mul(#right)),
117            Div(_) => quote!((#left).saturating_div(#right)),
118            Rem(_) => quote!((#left).saturating_rem(#right)),
119            Shl(_) => quote!((#left).saturating_shl(#right)),
120            Shr(_) => quote!((#left).saturating_shr(#right)),
121
122            SubAssign(_) => quote!(#left = #left.saturating_sub(#right)),
123            AddAssign(_) => quote!(#left = #left.saturating_add(#right)),
124            MulAssign(_) => quote!(#left = #left.saturating_mul(#right)),
125            DivAssign(_) => quote!(#left = #left.saturating_div(#right)),
126            RemAssign(_) => quote!(#left = #left.saturating_rem(#right)),
127            ShlAssign(_) => quote!(#left = #left.saturating_shl(#right)),
128            ShrAssign(_) => quote!(#left = #left.saturating_shr(#right)),
129
130            _ => match (|| {
131                syn::Result::Ok(
132                    ExprBinary {
133                        attrs: Default::default(),
134                        left: syn::parse(left.into())?,
135                        op,
136                        right: syn::parse(right.into())?,
137                    }
138                    .to_token_stream(),
139                )
140            })() {
141                Ok(x) => x,
142                Err(e) => e.into_compile_error(),
143            },
144        }
145    }
146
147    fn sub_unop(&self, op: UnOp, x: TokenStream) -> TokenStream {
148        match op {
149            UnOp::Neg(_) => quote!((#x).saturating_neg()),
150            _ => quote!(#op #x),
151        }
152    }
153}
154
155struct Algebraic;
156impl Sub for Algebraic {
157    fn sub_bin(&self, op: BinOp, left: TokenStream, right: TokenStream) -> TokenStream {
158        use syn::BinOp::*;
159        match op {
160            Add(_) => quote!(core::intrinsics::fadd_algebraic(#left, #right)),
161            Sub(_) => quote!(core::intrinsics::fsub_algebraic(#left, #right)),
162            Mul(_) => quote!(core::intrinsics::fmul_algebraic(#left, #right)),
163            Div(_) => quote!(core::intrinsics::fdiv_algebraic(#left, #right)),
164            Rem(_) => quote!(core::intrinsics::frem_algebraic(#left, #right)),
165
166            _ => match (|| {
167                syn::Result::Ok(
168                    ExprBinary {
169                        attrs: Default::default(),
170                        left: syn::parse(left.into())?,
171                        op,
172                        right: syn::parse(right.into())?,
173                    }
174                    .to_token_stream(),
175                )
176            })() {
177                Ok(x) => x,
178                Err(e) => e.into_compile_error(),
179            },
180        }
181    }
182
183    fn sub_unop(&self, op: UnOp, x: TokenStream) -> TokenStream {
184        quote!(#op #x)
185    }
186}
187
188struct Fast;
189impl Sub for Fast {
190    fn sub_bin(&self, op: BinOp, left: TokenStream, right: TokenStream) -> TokenStream {
191        use syn::BinOp::*;
192        match op {
193            Add(_) => quote!(core::intrinsics::fadd_fast(#left, #right)),
194            Sub(_) => quote!(core::intrinsics::fsub_fast(#left, #right)),
195            Mul(_) => quote!(core::intrinsics::fmul_fast(#left, #right)),
196            Div(_) => quote!(core::intrinsics::fdiv_fast(#left, #right)),
197            Rem(_) => quote!(core::intrinsics::frem_fast(#left, #right)),
198            Eq(_) => quote!(/* eq */ ((#left) + 0.0).to_bits() == ((#right) + 0.0).to_bits()),
199
200            _ => match (|| {
201                syn::Result::Ok(
202                    ExprBinary {
203                        attrs: Default::default(),
204                        left: syn::parse(left.into())?,
205                        op,
206                        right: syn::parse(right.into())?,
207                    }
208                    .to_token_stream(),
209                )
210            })() {
211                Ok(x) => x,
212                Err(e) => e.into_compile_error(),
213            },
214        }
215    }
216
217    fn sub_unop(&self, op: UnOp, x: TokenStream) -> TokenStream {
218        quote!(#op #x)
219    }
220}
221
222fn walk(sub: &impl Sub, e: Expr) -> TokenStream {
223    let walk = |e| walk(sub, e);
224    let map_block = |b| map_block(sub, b);
225    match e {
226        Expr::Binary(ExprBinary {
227            left, op, right, ..
228        }) => {
229            let left = walk(*left);
230            let right = walk(*right);
231            sub.sub_bin(op, left, right)
232        }
233        Expr::Unary(ExprUnary { op, expr, .. }) => sub.sub_unop(op, walk(*expr)),
234        Expr::Break(ExprBreak {
235            label,
236            expr: Some(expr),
237            ..
238        }) => {
239            let expr = walk(*expr);
240            quote!(#label #expr)
241        }
242        Expr::Call(ExprCall { func, args, .. }) => {
243            let f = walk(*func);
244            let args = args.into_iter().map(walk);
245            quote!(#f ( #(#args),* ))
246        }
247        Expr::Closure(ExprClosure {
248            lifetimes,
249            constness,
250            movability,
251            asyncness,
252            capture,
253            inputs,
254            output,
255            body,
256            ..
257        }) => {
258            let body = walk(*body);
259            quote!(#lifetimes #constness #movability #asyncness #capture |#inputs| #output { #body })
260        }
261        Expr::ForLoop(ExprForLoop {
262            label,
263            pat,
264            expr,
265            body,
266            ..
267        }) => {
268            let (expr, body) = (walk(*expr), map_block(body));
269            quote!(#label for #pat in #expr #body)
270        }
271        Expr::Let(ExprLet { pat, expr, .. }) => {
272            quote_with!(expr = walk(*expr) => let #pat = #expr)
273        }
274        Expr::Const(ExprConst { block, .. }) => {
275            quote_with!(block =map_block(block) => const #block)
276        }
277        Expr::Range(ExprRange {
278            start, limits, end, ..
279        }) => {
280            let (start, end) = (start.map(|x| walk(*x)), end.map(|x| walk(*x)));
281            quote!((#start #limits #end))
282        }
283        Expr::Return(ExprReturn { expr, .. }) => {
284            let expr = expr.map(|x| walk(*x));
285            quote!(return #expr;)
286        }
287        Expr::Try(ExprTry { expr, .. }) => {
288            let expr = walk(*expr);
289            quote!(#expr ?)
290        }
291        Expr::TryBlock(ExprTryBlock { block, .. }) => {
292            let block = map_block(block);
293            quote!(try #block)
294        }
295        Expr::Unsafe(ExprUnsafe { block, .. }) => {
296            quote_with!(block =map_block(block) => unsafe #block)
297        }
298        Expr::While(ExprWhile {
299            label, cond, body, ..
300        }) => {
301            quote_with!(cond = walk(*cond); body =map_block(body) => #label while #cond #body)
302        }
303        Expr::Index(ExprIndex { expr, index, .. }) => {
304            let expr = walk(*expr);
305            let index = walk(*index);
306            quote!(#expr [ #index ])
307        }
308        Expr::Loop(ExprLoop { label, body, .. }) => {
309            quote_with!(body =map_block(body) => #label loop #body)
310        }
311        Expr::Reference(ExprReference {
312            mutability, expr, ..
313        }) => {
314            let expr = walk(*expr);
315            quote!(& #mutability #expr)
316        }
317        Expr::MethodCall(ExprMethodCall {
318            receiver,
319            method,
320            turbofish,
321            args,
322            ..
323        }) => {
324            let receiver = walk(*receiver);
325            let args = args.into_iter().map(walk);
326            quote!(#receiver . #method #turbofish (#(#args,)*))
327        }
328        Expr::Match(ExprMatch { expr, arms, .. }) => {
329            let arms = arms.into_iter().map(
330                |Arm {
331                     pat,
332                     guard,
333
334                     body,
335                     //  comma,
336                     ..
337                 }| {
338                    let b = walk(*body);
339                    let guard = match guard {
340                        Some((i, x)) => {
341                            let z = walk(*x);
342                            quote! { #i #z }
343                        }
344                        None => quote! {},
345                    };
346                    quote! { #pat #guard => { #b } }
347                },
348            );
349            quote!(match #expr { #(#arms)* })
350        }
351        Expr::If(ExprIf {
352            cond,
353            then_branch,
354            else_branch: Some((_, else_branch)),
355            ..
356        }) => {
357            let (cond, then_branch, else_branch) =
358                (walk(*cond), map_block(then_branch), walk(*else_branch));
359            quote!(if #cond #then_branch else #else_branch)
360        }
361        Expr::If(ExprIf {
362            cond, then_branch, ..
363        }) => {
364            let (cond, then_branch) = (walk(*cond), map_block(then_branch));
365            quote!(if #cond #then_branch)
366        }
367        Expr::Async(ExprAsync {
368            attrs,
369            capture,
370            block,
371            ..
372        }) => {
373            let block = map_block(block);
374            quote!(#(#attrs)* async #capture #block)
375        }
376        Expr::Await(ExprAwait { base, .. }) => {
377            let base = walk(*base);
378            quote!(#base.await)
379        }
380        Expr::Assign(ExprAssign { left, right, .. }) => {
381            let (left, right) = (walk(*left), walk(*right));
382            quote!(#left = #right;)
383        }
384        Expr::Paren(ExprParen { expr, .. }) => {
385            let expr = walk(*expr);
386            quote!(#expr)
387        }
388        Expr::Tuple(ExprTuple { elems, .. }) => {
389            let ts = elems.into_iter().map(walk);
390            quote!((#(#ts,)*))
391        }
392        Expr::Array(ExprArray { elems, .. }) => {
393            let ts = elems.into_iter().map(walk);
394            quote!([#(#ts,)*])
395        }
396        Expr::Repeat(ExprRepeat { expr, len, .. }) => {
397            let x = walk(*expr);
398            let len = walk(*len);
399            quote!([ #x ; #len ])
400        }
401        Expr::Block(ExprBlock {
402            block,
403            label: Some(label),
404            ..
405        }) => {
406            let b = map_block(block);
407            quote! { #label: #b }
408        }
409        Expr::Block(ExprBlock { block, .. }) => map_block(block),
410        Expr::Cast(ExprCast {
411            expr, as_token, ty, ..
412        }) => {
413            let e = walk(*expr);
414            quote! { #e #as_token #ty }
415        }
416        e => quote!(#e),
417    }
418}
419
420fn map_block(sub: &impl Sub, Block { stmts, .. }: Block) -> TokenStream {
421    let stmts = stmts.into_iter().map(|x| walk_stmt(sub, x));
422    quote! { { #(#stmts)* } }
423}
424
425fn walk_stmt(sub: &impl Sub, x: Stmt) -> TokenStream {
426    let walk = |e| walk(sub, e);
427    match x {
428        Stmt::Local(Local {
429            pat,
430            init:
431                Some(LocalInit {
432                    expr,
433                    diverge: Some((_, diverge)),
434                    ..
435                }),
436            ..
437        }) => {
438            let expr = walk(*expr);
439            let diverge = walk(*diverge);
440            quote!(let #pat = #expr else { #diverge };)
441        }
442        Stmt::Local(Local {
443            pat,
444            init: Some(LocalInit { expr, .. }),
445            ..
446        }) => {
447            let expr = walk(*expr);
448            quote!(let #pat = #expr;)
449        }
450        Stmt::Item(x) => walk_item(sub, x),
451        Stmt::Expr(e, t) => {
452            let e = walk(e);
453            quote!(#e #t)
454        }
455        e => quote!(#e),
456    }
457}
458
459fn walk_item(sub: &impl Sub, x: Item) -> TokenStream {
460    let walk = |e| walk(sub, e);
461    match x {
462        Item::Const(ItemConst {
463            vis,
464            ident,
465            ty,
466            expr,
467            ..
468        }) => {
469            let expr = walk(*expr);
470            quote!(#vis const #ident : #ty = #expr;)
471        }
472        Item::Fn(ItemFn {
473            vis,
474            attrs,
475            sig,
476            block,
477        }) => {
478            let block = map_block(sub, *block);
479            quote!( #(#attrs)* #vis #sig #block)
480        }
481        Item::Impl(ItemImpl {
482            attrs,
483            unsafety,
484            defaultness,
485            generics,
486            trait_,
487            self_ty,
488            items,
489            ..
490        }) => {
491            let items = items.into_iter().map(|x| match x {
492                ImplItem::Const(ImplItemConst {
493                    vis,
494                    attrs,
495                    defaultness,
496                    ident,
497                    ty,
498                    expr,
499                    ..
500                }) => {
501                    let expr = walk(expr);
502                    quote!(#(#attrs)* #vis #defaultness const #ident: #ty = #expr;)
503                }
504                ImplItem::Fn(ImplItemFn {
505                    attrs,
506                    vis,
507                    defaultness,
508                    sig,
509                    block,
510                }) => {
511                    let block = map_block(sub, block);
512                    quote!(#(#attrs)* #vis #defaultness #sig #block)
513                }
514                e => quote!(#e),
515            });
516            let trait_ = trait_.map(|(n, pat, fr)| quote!(#n #pat #fr));
517            quote!(#(#attrs)* #unsafety #defaultness impl #generics #trait_ #self_ty { #(#items)* })
518        }
519        Item::Mod(ItemMod {
520            attrs,
521            vis,
522            ident,
523            content: Some((_, content)),
524            ..
525        }) => {
526            let content = content.into_iter().map(|x| walk_item(sub, x));
527            quote!(#(#attrs)* #vis mod #ident { #(#content)* })
528        }
529        Item::Static(ItemStatic {
530            attrs,
531            vis,
532            mutability,
533            ident,
534            ty,
535            expr,
536            ..
537        }) => {
538            let expr = walk(*expr);
539            quote!(#(#attrs)* #vis static #mutability #ident: #ty = #expr)
540        }
541        e => quote!(#e),
542    }
543}
544
/// Shared body of the proc-macro entry points.
///
/// Parses `$input` as an expression and lowers it with the substitution
/// `$t`. If that parse fails, `$input` is parsed as a statement instead; if
/// that fails as well, the error of the expression parse is emitted as a
/// compile error. The result is converted with `.into()`.
macro_rules! walk {
    ($input:ident, $t:expr) => {{
        let lowered: TokenStream = match parse::<Expr>($input.clone()) {
            Ok(expr) => walk(&$t, expr),
            Err(expr_error) => match parse::<Stmt>($input) {
                Ok(stmt) => walk_stmt(&$t, stmt),
                Err(_) => expr_error.to_compile_error(),
            },
        };
        lowered.into()
    }};
}
559
560#[proc_macro]
561pub fn math(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
562    walk!(input, Basic {})
563}
564
565#[proc_macro]
566pub fn fast(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
567    walk!(input, Fast {})
568}
569
570#[proc_macro]
571pub fn algebraic(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
572    walk!(input, Algebraic {})
573}
574
575#[proc_macro]
576pub fn wrapping(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
577    walk!(input, Wrapping {})
578}
579
580#[proc_macro]
581pub fn saturating(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
582    walk!(input, Saturating {})
583}
584
585#[proc_macro_attribute]
586pub fn apply(
587    args: proc_macro::TokenStream,
588    input: proc_macro::TokenStream,
589) -> proc_macro::TokenStream {
590    match &*args.to_string() {
591        "basic" | "" => math(input),
592        "fast" => fast(input),
593        "algebraic" => algebraic(input),
594        "wrapping" => wrapping(input),
595        "saturating" => saturating(input),
596        _ => {
597            quote! { compile_error!("type must be {fast, basic, algebraic, wrapping, saturating}") }
598                .into()
599        }
600    }
601}