lower_macros/
lib.rs

1//! a lil macro crate.
2//!
//! Provides a handy macro for converting `a + b` to `a.add(b)`, for when you can't easily overload the `Add` trait.
4use proc_macro2::TokenStream;
5use quote::{quote, ToTokens};
6use syn::{spanned::Spanned, *};
7
/// Binds each `name = value` pair (separated by `;`) to a local, in order, and
/// then expands `quote!` over the tokens that follow `=>`, so the template can
/// interpolate those locals.
macro_rules! quote_with {
    ($($name:ident = $value:expr);+ => $($template:tt)+) => {{
        $(
            let $name = $value;
        )+
        quote!($($template)+)
    }};
}
/// Strategy for lowering operators: an implementor decides which tokens replace
/// a binary or unary operation, given the operator and its operand tokens.
trait Sub {
    /// Returns the tokens that stand in for `left op right`.
    fn sub_bin(&self, op: BinOp, left: TokenStream, right: TokenStream) -> TokenStream;
    /// Returns the tokens that stand in for `op x`.
    fn sub_unop(&self, op: UnOp, x: TokenStream) -> TokenStream;
}
18
19struct Basic;
20impl Sub for Basic {
21    fn sub_bin(&self, op: BinOp, left: TokenStream, right: TokenStream) -> TokenStream {
22        use syn::BinOp::*;
23        match op {
24            Add(_) => quote!((#left).add(#right)),
25            Sub(_) => quote!((#left).sub(#right)),
26            Mul(_) => quote!((#left).mul(#right)),
27            Div(_) => quote!((#left).div(#right)),
28            Rem(_) => quote!((#left).rem(#right)),
29            And(_) => quote!((#left).and(#right)),
30            Or(_) => quote!((#left).or(#right)),
31            BitXor(_) => quote!((#left).bitxor(#right)),
32            BitAnd(_) => quote!((#left).bitand(#right)),
33            BitOr(_) => quote!((#left).bitor(#right)),
34            Shl(_) => quote!((#left).shl(#right)),
35            Shr(_) => quote!((#left).shr(#right)),
36            Eq(_) => quote!((#left).eq(#right)),
37            Lt(_) => quote!((#left).lt(#right)),
38            Le(_) => quote!((#left).le(#right)),
39            Ne(_) => quote!((#left).ne(#right)),
40            Ge(_) => quote!((#left).ge(#right)),
41            Gt(_) => quote!((#left).gt(#right)),
42            // don't support assigning ops
43            e => {
44                Error::new(e.span(), format!("{}", quote!(op #e not supported))).to_compile_error()
45            }
46        }
47    }
48
49    fn sub_unop(&self, op: UnOp, x: TokenStream) -> TokenStream {
50        match op {
51            UnOp::Deref(_) => quote!((#x).deref()),
52            UnOp::Not(_) => quote!((#x).not()),
53            UnOp::Neg(_) => quote!((#x).neg()),
54            e => Error::new(
55                e.span(),
56                "it would appear a new operation has been added! please tell me.",
57            )
58            .to_compile_error(),
59        }
60    }
61}
62
63struct Wrapping;
64impl Sub for Wrapping {
65    fn sub_bin(&self, op: BinOp, left: TokenStream, right: TokenStream) -> TokenStream {
66        use syn::BinOp::*;
67        match op {
68            Add(_) => quote!((#left).wrapping_add(#right)),
69            Sub(_) => quote!((#left).wrapping_sub(#right)),
70            Mul(_) => quote!((#left).wrapping_mul(#right)),
71            Div(_) => quote!((#left).wrapping_div(#right)),
72            Rem(_) => quote!((#left).wrapping_rem(#right)),
73            Shl(_) => quote!((#left).wrapping_shl(#right)),
74            Shr(_) => quote!((#left).wrapping_shr(#right)),
75
76            SubAssign(_) => quote!(#left = #left.wrapping_sub(#right)),
77            AddAssign(_) => quote!(#left = #left.wrapping_add(#right)),
78            MulAssign(_) => quote!(#left = #left.wrapping_mul(#right)),
79            DivAssign(_) => quote!(#left = #left.wrapping_div(#right)),
80            RemAssign(_) => quote!(#left = #left.wrapping_rem(#right)),
81            ShlAssign(_) => quote!(#left = #left.wrapping_shl(#right)),
82            ShrAssign(_) => quote!(#left = #left.wrapping_shr(#right)),
83
84            _ => quote!((#left) #op (#right)),
85        }
86    }
87
88    fn sub_unop(&self, op: UnOp, x: TokenStream) -> TokenStream {
89        match op {
90            UnOp::Neg(_) => quote!((#x).wrapping_neg()),
91            _ => quote!(#op #x),
92        }
93    }
94}
95
96struct Saturating;
97impl Sub for Saturating {
98    fn sub_bin(&self, op: BinOp, left: TokenStream, right: TokenStream) -> TokenStream {
99        use syn::BinOp::*;
100        match op {
101            Add(_) => quote!((#left).saturating_add(#right)),
102            Sub(_) => quote!((#left).saturating_sub(#right)),
103            Mul(_) => quote!((#left).saturating_mul(#right)),
104            Div(_) => quote!((#left).saturating_div(#right)),
105            Rem(_) => quote!((#left).saturating_rem(#right)),
106            Shl(_) => quote!((#left).saturating_shl(#right)),
107            Shr(_) => quote!((#left).saturating_shr(#right)),
108
109            SubAssign(_) => quote!(#left = #left.saturating_sub(#right)),
110            AddAssign(_) => quote!(#left = #left.saturating_add(#right)),
111            MulAssign(_) => quote!(#left = #left.saturating_mul(#right)),
112            DivAssign(_) => quote!(#left = #left.saturating_div(#right)),
113            RemAssign(_) => quote!(#left = #left.saturating_rem(#right)),
114            ShlAssign(_) => quote!(#left = #left.saturating_shl(#right)),
115            ShrAssign(_) => quote!(#left = #left.saturating_shr(#right)),
116
117            _ => quote!((#left) #op (#right)),
118        }
119    }
120
121    fn sub_unop(&self, op: UnOp, x: TokenStream) -> TokenStream {
122        match op {
123            UnOp::Neg(_) => quote!((#x).saturating_neg()),
124            _ => quote!(#op #x),
125        }
126    }
127}
128
129struct Algebraic;
130impl Sub for Algebraic {
131    fn sub_bin(&self, op: BinOp, left: TokenStream, right: TokenStream) -> TokenStream {
132        use syn::BinOp::*;
133        match op {
134            Add(_) => quote!(core::intrinsics::fadd_algebraic(#left, #right)),
135            Sub(_) => quote!(core::intrinsics::fsub_algebraic(#left, #right)),
136            Mul(_) => quote!(core::intrinsics::fmul_algebraic(#left, #right)),
137            Div(_) => quote!(core::intrinsics::fdiv_algebraic(#left, #right)),
138            Rem(_) => quote!(core::intrinsics::frem_algebraic(#left, #right)),
139            And(_) => quote!(core::intrinsics::fand_algebraic(#left, #right)),
140            _ => quote!((#left) #op (#right)),
141        }
142    }
143
144    fn sub_unop(&self, op: UnOp, x: TokenStream) -> TokenStream {
145        quote!(#op #x)
146    }
147}
148
149struct Fast;
150impl Sub for Fast {
151    fn sub_bin(&self, op: BinOp, left: TokenStream, right: TokenStream) -> TokenStream {
152        use syn::BinOp::*;
153        match op {
154            Add(_) => quote!(core::intrinsics::fadd_fast(#left, #right)),
155            Sub(_) => quote!(core::intrinsics::fsub_fast(#left, #right)),
156            Mul(_) => quote!(core::intrinsics::fmul_fast(#left, #right)),
157            Div(_) => quote!(core::intrinsics::fdiv_fast(#left, #right)),
158            Rem(_) => quote!(core::intrinsics::frem_fast(#left, #right)),
159            And(_) => quote!(core::intrinsics::fand_fast(#left, #right)),
160            Eq(_) => quote!(/* eq */ ((#left) + 0.0).to_bits() == ((#right) + 0.0).to_bits()),
161            _ => quote!((#left) #op (#right)),
162        }
163    }
164
165    fn sub_unop(&self, op: UnOp, x: TokenStream) -> TokenStream {
166        quote!(#op #x)
167    }
168}
169
170fn walk(sub: &impl Sub, e: Expr) -> TokenStream {
171    let walk = |e| walk(sub, e);
172    let map_block = |b| map_block(sub, b);
173    match e {
174        Expr::Binary(ExprBinary {
175            left, op, right, ..
176        }) => {
177            let left = walk(*left);
178            let right = walk(*right);
179            sub.sub_bin(op, left, right)
180        }
181        Expr::Unary(ExprUnary { op, expr, .. }) => sub.sub_unop(op, walk(*expr)),
182        Expr::Break(ExprBreak {
183            label,
184            expr: Some(expr),
185            ..
186        }) => {
187            let expr = walk(*expr);
188            quote!(#label #expr)
189        }
190        Expr::Call(ExprCall { func, args, .. }) => {
191            let f = walk(*func);
192            let args = args.into_iter().map(walk);
193            quote!(#f ( #(#args),* ))
194        }
195        Expr::Closure(ExprClosure {
196            lifetimes,
197            constness,
198            movability,
199            asyncness,
200            capture,
201            inputs,
202            output,
203            body,
204            ..
205        }) => {
206            let body = walk(*body);
207            quote!(#lifetimes #constness #movability #asyncness #capture |#inputs| #output { #body })
208        }
209        Expr::ForLoop(ExprForLoop {
210            label,
211            pat,
212            expr,
213            body,
214            ..
215        }) => {
216            let (expr, body) = (walk(*expr), map_block(body));
217            quote!(#label for #pat in #expr #body)
218        }
219        Expr::Let(ExprLet { pat, expr, .. }) => {
220            quote_with!(expr = walk(*expr) => let #pat = #expr)
221        }
222        Expr::Const(ExprConst { block, .. }) => {
223            quote_with!(block =map_block(block) => const #block)
224        }
225        Expr::Range(ExprRange {
226            start, limits, end, ..
227        }) => {
228            let (start, end) = (start.map(|x| walk(*x)), end.map(|x| walk(*x)));
229            quote!((#start #limits #end))
230        }
231        Expr::Return(ExprReturn { expr, .. }) => {
232            let expr = expr.map(|x| walk(*x));
233            quote!(return #expr;)
234        }
235        Expr::Try(ExprTry { expr, .. }) => {
236            let expr = walk(*expr);
237            quote!(#expr ?)
238        }
239        Expr::TryBlock(ExprTryBlock { block, .. }) => {
240            let block = map_block(block);
241            quote!(try #block)
242        }
243        Expr::Unsafe(ExprUnsafe { block, .. }) => {
244            quote_with!(block =map_block(block) => unsafe #block)
245        }
246        Expr::While(ExprWhile {
247            label, cond, body, ..
248        }) => {
249            quote_with!(cond = walk(*cond); body =map_block(body) => #label while #cond #body)
250        }
251        Expr::Index(ExprIndex { expr, index, .. }) => {
252            let expr = walk(*expr);
253            let index = walk(*index);
254            quote!(#expr [ #index ])
255        }
256        Expr::Loop(ExprLoop { label, body, .. }) => {
257            quote_with!(body =map_block(body) => #label loop #body)
258        }
259        Expr::Reference(ExprReference {
260            mutability, expr, ..
261        }) => {
262            let expr = walk(*expr);
263            quote!(& #mutability #expr)
264        }
265        Expr::MethodCall(ExprMethodCall {
266            receiver,
267            method,
268            turbofish,
269            args,
270            ..
271        }) => {
272            let receiver = walk(*receiver);
273            let args = args.into_iter().map(walk);
274            quote!(#receiver . #method #turbofish (#(#args,)*))
275        }
276        Expr::Match(ExprMatch { expr, arms, .. }) => {
277            let arms = arms.into_iter().map(
278                |Arm {
279                     pat,
280                     guard,
281
282                     body,
283                     //  comma,
284                     ..
285                 }| {
286                    let b = walk(*body);
287                    let guard = match guard {
288                        Some((i, x)) => {
289                            let z = walk(*x);
290                            quote! { #i #z }
291                        }
292                        None => quote! {},
293                    };
294                    quote! { #pat #guard => { #b } }
295                },
296            );
297            quote!(match #expr { #(#arms)* })
298        }
299        Expr::If(ExprIf {
300            cond,
301            then_branch,
302            else_branch: Some((_, else_branch)),
303            ..
304        }) => {
305            let (cond, then_branch, else_branch) =
306                (walk(*cond), map_block(then_branch), walk(*else_branch));
307            quote!(if #cond #then_branch else #else_branch)
308        }
309        Expr::If(ExprIf {
310            cond, then_branch, ..
311        }) => {
312            let (cond, then_branch) = (walk(*cond), map_block(then_branch));
313            quote!(if #cond #then_branch)
314        }
315        Expr::Async(ExprAsync {
316            attrs,
317            capture,
318            block,
319            ..
320        }) => {
321            let block = map_block(block);
322            quote!(#(#attrs)* async #capture #block)
323        }
324        Expr::Await(ExprAwait { base, .. }) => {
325            let base = walk(*base);
326            quote!(#base.await)
327        }
328        Expr::Assign(ExprAssign { left, right, .. }) => {
329            let (left, right) = (walk(*left), walk(*right));
330            quote!(#left = #right;)
331        }
332        Expr::Paren(ExprParen { expr, .. }) => {
333            let expr = walk(*expr);
334            quote!(#expr)
335        }
336        Expr::Tuple(ExprTuple { elems, .. }) => {
337            let ts = elems.into_iter().map(walk);
338            quote!((#(#ts,)*))
339        }
340        Expr::Array(ExprArray { elems, .. }) => {
341            let ts = elems.into_iter().map(walk);
342            quote!([#(#ts,)*])
343        }
344        Expr::Repeat(ExprRepeat { expr, len, .. }) => {
345            let x = walk(*expr);
346            let len = walk(*len);
347            quote!([ #x ; #len ])
348        }
349        Expr::Block(ExprBlock {
350            block,
351            label: Some(label),
352            ..
353        }) => {
354            let b = map_block(block);
355            quote! { #label: #b }
356        }
357        Expr::Block(ExprBlock { block, .. }) => map_block(block),
358        Expr::Cast(ExprCast {
359            expr, as_token, ty, ..
360        }) => {
361            let e = walk(*expr);
362            quote! { #e #as_token #ty }
363        }
364        e => quote!(#e),
365    }
366}
367
368fn map_block(sub: &impl Sub, Block { stmts, .. }: Block) -> TokenStream {
369    let stmts = stmts.into_iter().map(|x| walk_stmt(sub, x));
370    quote! { { #(#stmts)* } }
371}
372
373fn walk_stmt(sub: &impl Sub, x: Stmt) -> TokenStream {
374    let walk = |e| walk(sub, e);
375    match x {
376        Stmt::Local(Local {
377            pat,
378            init:
379                Some(LocalInit {
380                    expr,
381                    diverge: Some((_, diverge)),
382                    ..
383                }),
384            ..
385        }) => {
386            let expr = walk(*expr);
387            let diverge = walk(*diverge);
388            quote!(let #pat = #expr else { #diverge };)
389        }
390        Stmt::Local(Local {
391            pat,
392            init: Some(LocalInit { expr, .. }),
393            ..
394        }) => {
395            let expr = walk(*expr);
396            quote!(let #pat = #expr;)
397        }
398        Stmt::Item(x) => walk_item(sub, x),
399        Stmt::Expr(e, t) => {
400            let e = walk(e);
401            quote!(#e #t)
402        }
403        e => quote!(#e),
404    }
405}
406
407fn walk_item(sub: &impl Sub, x: Item) -> TokenStream {
408    let walk = |e| walk(sub, e);
409    match x {
410        Item::Const(ItemConst {
411            vis,
412            ident,
413            ty,
414            expr,
415            ..
416        }) => {
417            let expr = walk(*expr);
418            quote!(#vis const #ident : #ty = #expr;)
419        }
420        Item::Fn(ItemFn {
421            vis,
422            attrs,
423            sig,
424            block,
425        }) => {
426            let block = map_block(sub, *block);
427            quote!( #(#attrs)* #vis #sig #block)
428        }
429        Item::Impl(ItemImpl {
430            attrs,
431            unsafety,
432            defaultness,
433            generics,
434            trait_,
435            self_ty,
436            items,
437            ..
438        }) => {
439            let items = items.into_iter().map(|x| match x {
440                ImplItem::Const(ImplItemConst {
441                    vis,
442                    attrs,
443                    defaultness,
444                    ident,
445                    ty,
446                    expr,
447                    ..
448                }) => {
449                    let expr = walk(expr);
450                    quote!(#(#attrs)* #vis #defaultness const #ident: #ty = #expr;)
451                }
452                ImplItem::Fn(ImplItemFn {
453                    attrs,
454                    vis,
455                    defaultness,
456                    sig,
457                    block,
458                }) => {
459                    let block = map_block(sub, block);
460                    quote!(#(#attrs)* #vis #defaultness #sig #block)
461                }
462                e => quote!(#e),
463            });
464            let trait_ = trait_.map(|(n, pat, fr)| quote!(#n #pat #fr));
465            quote!(#(#attrs)* #unsafety #defaultness impl #generics #trait_ #self_ty { #(#items)* })
466        }
467        Item::Mod(ItemMod {
468            attrs,
469            vis,
470            ident,
471            content: Some((_, content)),
472            ..
473        }) => {
474            let content = content.into_iter().map(|x| walk_item(sub, x));
475            quote!(#(#attrs)* #vis mod #ident { #(#content)* })
476        }
477        Item::Static(ItemStatic {
478            attrs,
479            vis,
480            mutability,
481            ident,
482            ty,
483            expr,
484            ..
485        }) => {
486            let expr = walk(*expr);
487            quote!(#(#attrs)* #vis static #mutability #ident: #ty = #expr)
488        }
489        e => quote!(#e),
490    }
491}
492
/// Shared body of the entry points: parses `$input` as an expression and lowers
/// it with the strategy `$t`; if that parse fails, retries as a statement. When
/// both parses fail, the error from the expression parse is reported.
macro_rules! walk {
    ($input:ident, $t:expr) => {{
        let lowered = match parse::<Expr>($input.clone()) {
            Ok(expr) => walk(&$t, expr),
            Err(expr_err) => match parse::<Stmt>($input) {
                Ok(stmt) => walk_stmt(&$t, stmt),
                Err(_) => expr_err.to_compile_error(),
            },
        };
        lowered.into()
    }};
}
507
508#[proc_macro]
509pub fn math(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
510    walk!(input, Basic {})
511}
512
513#[proc_macro]
514pub fn fast(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
515    walk!(input, Fast {})
516}
517
518#[proc_macro]
519pub fn algebraic(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
520    walk!(input, Algebraic {})
521}
522
523#[proc_macro]
524pub fn wrapping(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
525    walk!(input, Wrapping {})
526}
527
528#[proc_macro]
529pub fn saturating(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
530    walk!(input, Saturating {})
531}
532
533#[proc_macro_attribute]
534pub fn apply(
535    args: proc_macro::TokenStream,
536    input: proc_macro::TokenStream,
537) -> proc_macro::TokenStream {
538    match &*args.to_string() {
539        "basic" | "" => math(input),
540        "fast" => fast(input),
541        "algebraic" => algebraic(input),
542        "wrapping" => wrapping(input),
543        "saturating" => saturating(input),
544        _ => {
545            quote! { compile_error!("type must be {fast, basic, algebraic, wrapping, saturating}") }
546                .into()
547        }
548    }
549}