lower_macros/
lib.rs

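//! A lil macro crate.
//!
//! Provides handy macros for converting `a + b` to `a.add(b)` for when you can't easily overload the `Add` trait.
//! Besides `math!`, there are `wrapping!`, `saturating!`, `fast!` and `algebraic!` variants, plus the
//! `#[apply(...)]` attribute for rewriting a whole item.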
4use proc_macro2::TokenStream;
5use quote::{quote, ToTokens};
6use syn::{spanned::Spanned, *};
7
/// Binds each `name = value` pair as a local variable, then runs `quote!` over
/// the tokens after `=>`, so those bindings can be interpolated as `#name`.
macro_rules! quote_with {
    ($($name:ident = $value:expr);+ => $($body:tt)+) => {{
        $(
            let $name = $value;
        )+
        quote!($($body)+)
    }};
}
14trait Sub {
15    fn sub_bin(&self, op: BinOp, left: TokenStream, right: TokenStream) -> TokenStream;
16    fn sub_unop(&self, op: UnOp, x: TokenStream) -> TokenStream;
17}
18
19struct Basic;
20impl Sub for Basic {
21    fn sub_bin(&self, op: BinOp, left: TokenStream, right: TokenStream) -> TokenStream {
22        use syn::BinOp::*;
23        match op {
24            Add(_) => quote!((#left).add(#right)),
25            Sub(_) => quote!((#left).sub(#right)),
26            Mul(_) => quote!((#left).mul(#right)),
27            Div(_) => quote!((#left).div(#right)),
28            Rem(_) => quote!((#left).rem(#right)),
29            And(_) => quote!((#left).and(#right)),
30            Or(_) => quote!((#left).or(#right)),
31            BitXor(_) => quote!((#left).bitxor(#right)),
32            BitAnd(_) => quote!((#left).bitand(#right)),
33            BitOr(_) => quote!((#left).bitor(#right)),
34            Shl(_) => quote!((#left).shl(#right)),
35            Shr(_) => quote!((#left).shr(#right)),
36            Eq(_) => quote!((#left).eq(#right)),
37            Lt(_) => quote!((#left).lt(#right)),
38            Le(_) => quote!((#left).le(#right)),
39            Ne(_) => quote!((#left).ne(#right)),
40            Ge(_) => quote!((#left).ge(#right)),
41            Gt(_) => quote!((#left).gt(#right)),
42            // don't support assigning ops
43            e => {
44                Error::new(e.span(), format!("{}", quote!(op #e not supported))).to_compile_error()
45            }
46        }
47    }
48
49    fn sub_unop(&self, op: UnOp, x: TokenStream) -> TokenStream {
50        match op {
51            UnOp::Deref(_) => quote!((#x).deref()),
52            UnOp::Not(_) => quote!((#x).not()),
53            UnOp::Neg(_) => quote!((#x).neg()),
54            e => Error::new(
55                e.span(),
56                "it would appear a new operation has been added! please tell me.",
57            )
58            .to_compile_error(),
59        }
60    }
61}
62
63struct Wrapping;
64impl Sub for Wrapping {
65    fn sub_bin(&self, op: BinOp, left: TokenStream, right: TokenStream) -> TokenStream {
66        use syn::BinOp::*;
67        match op {
68            Add(_) => quote!((#left).wrapping_add(#right)),
69            Sub(_) => quote!((#left).wrapping_sub(#right)),
70            Mul(_) => quote!((#left).wrapping_mul(#right)),
71            Div(_) => quote!((#left).wrapping_div(#right)),
72            Rem(_) => quote!((#left).wrapping_rem(#right)),
73            Shl(_) => quote!((#left).wrapping_shl(#right)),
74            Shr(_) => quote!((#left).wrapping_shr(#right)),
75
76            _ => quote!((#left) #op (#right)),
77        }
78    }
79
80    fn sub_unop(&self, op: UnOp, x: TokenStream) -> TokenStream {
81        match op {
82            UnOp::Neg(_) => quote!((#x).wrapping_neg()),
83            _ => quote!(#op #x),
84        }
85    }
86}
87
88struct Saturating;
89impl Sub for Saturating {
90    fn sub_bin(&self, op: BinOp, left: TokenStream, right: TokenStream) -> TokenStream {
91        use syn::BinOp::*;
92        match op {
93            Add(_) => quote!((#left).saturating_add(#right)),
94            Sub(_) => quote!((#left).saturating_sub(#right)),
95            Mul(_) => quote!((#left).saturating_mul(#right)),
96            Div(_) => quote!((#left).saturating_div(#right)),
97            Rem(_) => quote!((#left).saturating_rem(#right)),
98            Shl(_) => quote!((#left).saturating_shl(#right)),
99            Shr(_) => quote!((#left).saturating_shr(#right)),
100
101            _ => quote!((#left) #op (#right)),
102        }
103    }
104
105    fn sub_unop(&self, op: UnOp, x: TokenStream) -> TokenStream {
106        match op {
107            UnOp::Neg(_) => quote!((#x).saturating_neg()),
108            _ => quote!(#op #x),
109        }
110    }
111}
112
113struct Algebraic;
114impl Sub for Algebraic {
115    fn sub_bin(&self, op: BinOp, left: TokenStream, right: TokenStream) -> TokenStream {
116        use syn::BinOp::*;
117        match op {
118            Add(_) => quote!(core::intrinsics::fadd_algebraic(#left, #right)),
119            Sub(_) => quote!(core::intrinsics::fsub_algebraic(#left, #right)),
120            Mul(_) => quote!(core::intrinsics::fmul_algebraic(#left, #right)),
121            Div(_) => quote!(core::intrinsics::fdiv_algebraic(#left, #right)),
122            Rem(_) => quote!(core::intrinsics::frem_algebraic(#left, #right)),
123            And(_) => quote!(core::intrinsics::fand_algebraic(#left, #right)),
124            _ => quote!((#left) #op (#right)),
125        }
126    }
127
128    fn sub_unop(&self, op: UnOp, x: TokenStream) -> TokenStream {
129        quote!(#op #x)
130    }
131}
132
133struct Fast;
134impl Sub for Fast {
135    fn sub_bin(&self, op: BinOp, left: TokenStream, right: TokenStream) -> TokenStream {
136        use syn::BinOp::*;
137        match op {
138            Add(_) => quote!(core::intrinsics::fadd_fast(#left, #right)),
139            Sub(_) => quote!(core::intrinsics::fsub_fast(#left, #right)),
140            Mul(_) => quote!(core::intrinsics::fmul_fast(#left, #right)),
141            Div(_) => quote!(core::intrinsics::fdiv_fast(#left, #right)),
142            Rem(_) => quote!(core::intrinsics::frem_fast(#left, #right)),
143            And(_) => quote!(core::intrinsics::fand_fast(#left, #right)),
144            Eq(_) => quote!(/* eq */ ((#left) + 0.0).to_bits() == ((#right) + 0.0).to_bits()),
145            _ => quote!((#left) #op (#right)),
146        }
147    }
148
149    fn sub_unop(&self, op: UnOp, x: TokenStream) -> TokenStream {
150        quote!(#op #x)
151    }
152}
153
154fn walk(sub: &impl Sub, e: Expr) -> TokenStream {
155    let walk = |e| walk(sub, e);
156    let map_block = |b| map_block(sub, b);
157    match e {
158        Expr::Binary(ExprBinary {
159            left, op, right, ..
160        }) => {
161            let left = walk(*left);
162            let right = walk(*right);
163            sub.sub_bin(op, left, right)
164        }
165        Expr::Unary(ExprUnary { op, expr, .. }) => sub.sub_unop(op, walk(*expr)),
166        Expr::Break(ExprBreak {
167            label,
168            expr: Some(expr),
169            ..
170        }) => {
171            let expr = walk(*expr);
172            quote!(#label #expr)
173        }
174        Expr::Call(ExprCall { func, args, .. }) => {
175            let f = walk(*func);
176            let args = args.into_iter().map(walk);
177            quote!(#f ( #(#args),* ))
178        }
179        Expr::Closure(ExprClosure {
180            lifetimes,
181            constness,
182            movability,
183            asyncness,
184            capture,
185            inputs,
186            output,
187            body,
188            ..
189        }) => {
190            let body = walk(*body);
191            quote!(#lifetimes #constness #movability #asyncness #capture |#inputs| #output #body)
192        }
193        Expr::ForLoop(ExprForLoop {
194            label,
195            pat,
196            expr,
197            body,
198            ..
199        }) => {
200            let (expr, body) = (walk(*expr), map_block(body));
201            quote!(#label for #pat in #expr #body)
202        }
203        Expr::Let(ExprLet { pat, expr, .. }) => {
204            quote_with!(expr = walk(*expr) => let #pat = #expr)
205        }
206        Expr::Const(ExprConst { block, .. }) => {
207            quote_with!(block =map_block(block) => const #block)
208        }
209        Expr::Range(ExprRange {
210            start, limits, end, ..
211        }) => {
212            let (start, end) = (start.map(|x| walk(*x)), end.map(|x| walk(*x)));
213            quote!((#start #limits #end))
214        }
215        Expr::Return(ExprReturn { expr, .. }) => {
216            let expr = expr.map(|x| walk(*x));
217            quote!(return #expr;)
218        }
219        Expr::Try(ExprTry { expr, .. }) => {
220            let expr = walk(*expr);
221            quote!(#expr ?)
222        }
223        Expr::TryBlock(ExprTryBlock { block, .. }) => {
224            let block = map_block(block);
225            quote!(try #block)
226        }
227        Expr::Unsafe(ExprUnsafe { block, .. }) => {
228            quote_with!(block =map_block(block) => unsafe #block)
229        }
230        Expr::While(ExprWhile {
231            label, cond, body, ..
232        }) => {
233            quote_with!(cond = walk(*cond); body =map_block(body) => #label while #cond #body)
234        }
235        Expr::Index(ExprIndex { expr, index, .. }) => {
236            let expr = walk(*expr);
237            let index = walk(*index);
238            quote!(#expr [ #index ])
239        }
240        Expr::Loop(ExprLoop { label, body, .. }) => {
241            quote_with!(body =map_block(body) => #label loop #body)
242        }
243        Expr::Reference(ExprReference {
244            mutability, expr, ..
245        }) => {
246            let expr = walk(*expr);
247            quote!(& #mutability #expr)
248        }
249        Expr::MethodCall(ExprMethodCall {
250            receiver,
251            method,
252            turbofish,
253            args,
254            ..
255        }) => {
256            let receiver = walk(*receiver);
257            let args = args.into_iter().map(walk);
258            quote!(#receiver . #method #turbofish (#(#args,)*))
259        }
260        Expr::If(ExprIf {
261            cond,
262            then_branch,
263            else_branch: Some((_, else_branch)),
264            ..
265        }) => {
266            let (cond, then_branch, else_branch) =
267                (walk(*cond), map_block(then_branch), walk(*else_branch));
268            quote!(if #cond #then_branch else #else_branch)
269        }
270        Expr::If(ExprIf {
271            cond, then_branch, ..
272        }) => {
273            let (cond, then_branch) = (walk(*cond), map_block(then_branch));
274            quote!(if #cond #then_branch)
275        }
276        Expr::Async(ExprAsync {
277            attrs,
278            capture,
279            block,
280            ..
281        }) => {
282            let block = map_block(block);
283            quote!(#(#attrs)* async #capture #block)
284        }
285        Expr::Await(ExprAwait { base, .. }) => {
286            let base = walk(*base);
287            quote!(#base.await)
288        }
289        Expr::Assign(ExprAssign { left, right, .. }) => {
290            let (left, right) = (walk(*left), walk(*right));
291            quote!(#left = #right;)
292        }
293        Expr::Paren(ExprParen { expr, .. }) => {
294            let expr = walk(*expr);
295            quote!(#expr)
296        }
297        Expr::Tuple(ExprTuple { elems, .. }) => {
298            let ts = elems.into_iter().map(walk);
299            quote!((#(#ts,)*))
300        }
301        Expr::Array(ExprArray { elems, .. }) => {
302            let ts = elems.into_iter().map(walk);
303            quote!([#(#ts,)*])
304        }
305        Expr::Repeat(ExprRepeat { expr, len, .. }) => {
306            let x = walk(*expr);
307            let len = walk(*len);
308            quote!([ #x ; #len ])
309        }
310        Expr::Block(ExprBlock {
311            block,
312            label: Some(label),
313            ..
314        }) => {
315            let b = map_block(block);
316            quote! { #label: #b }
317        }
318        Expr::Block(ExprBlock { block, .. }) => map_block(block),
319        e => quote!(#e),
320    }
321}
322
323fn map_block(sub: &impl Sub, Block { stmts, .. }: Block) -> TokenStream {
324    let stmts = stmts.into_iter().map(|x| walk_stmt(sub, x));
325    quote! { { #(#stmts)* } }
326}
327
328fn walk_stmt(sub: &impl Sub, x: Stmt) -> TokenStream {
329    let walk = |e| walk(sub, e);
330    match x {
331        Stmt::Local(Local {
332            pat,
333            init:
334                Some(LocalInit {
335                    expr,
336                    diverge: Some((_, diverge)),
337                    ..
338                }),
339            ..
340        }) => {
341            let expr = walk(*expr);
342            let diverge = walk(*diverge);
343            quote!(let #pat = #expr else { #diverge };)
344        }
345        Stmt::Local(Local {
346            pat,
347            init: Some(LocalInit { expr, .. }),
348            ..
349        }) => {
350            let expr = walk(*expr);
351            quote!(let #pat = #expr;)
352        }
353        Stmt::Item(x) => walk_item(sub, x),
354        Stmt::Expr(e, t) => {
355            let e = walk(e);
356            quote!(#e #t)
357        }
358        e => quote!(#e),
359    }
360}
361
362fn walk_item(sub: &impl Sub, x: Item) -> TokenStream {
363    let walk = |e| walk(sub, e);
364    match x {
365        Item::Const(ItemConst {
366            vis,
367            ident,
368            ty,
369            expr,
370            ..
371        }) => {
372            let expr = walk(*expr);
373            quote!(#vis const #ident : #ty = #expr;)
374        }
375        Item::Fn(ItemFn {
376            vis,
377            attrs,
378            sig,
379            block,
380        }) => {
381            let block = map_block(sub, *block);
382            quote!( #(#attrs)* #vis #sig #block)
383        }
384        Item::Impl(ItemImpl {
385            attrs,
386            unsafety,
387            defaultness,
388            generics,
389            trait_,
390            self_ty,
391            items,
392            ..
393        }) => {
394            let items = items.into_iter().map(|x| match x {
395                ImplItem::Const(ImplItemConst {
396                    vis,
397                    attrs,
398                    defaultness,
399                    ident,
400                    ty,
401                    expr,
402                    ..
403                }) => {
404                    let expr = walk(expr);
405                    quote!(#(#attrs)* #vis #defaultness const #ident: #ty = #expr;)
406                }
407                ImplItem::Fn(ImplItemFn {
408                    attrs,
409                    vis,
410                    defaultness,
411                    sig,
412                    block,
413                }) => {
414                    let block = map_block(sub, block);
415                    quote!(#(#attrs)* #vis #defaultness #sig #block)
416                }
417                e => quote!(#e),
418            });
419            let trait_ = trait_.map(|(n, pat, fr)| quote!(#n #pat #fr));
420            quote!(#(#attrs)* #unsafety #defaultness impl #generics #trait_ #self_ty { #(#items)* })
421        }
422        Item::Mod(ItemMod {
423            attrs,
424            vis,
425            ident,
426            content: Some((_, content)),
427            ..
428        }) => {
429            let content = content.into_iter().map(|x| walk_item(sub, x));
430            quote!(#(#attrs)* #vis mod #ident { #(#content)* })
431        }
432        Item::Static(ItemStatic {
433            attrs,
434            vis,
435            mutability,
436            ident,
437            ty,
438            expr,
439            ..
440        }) => {
441            let expr = walk(*expr);
442            quote!(#(#attrs)* #vis static #mutability #ident: #ty = #expr)
443        }
444        e => quote!(#e),
445    }
446}
447
/// Shared body of the public macros. It parses `$input` as an expression and
/// rewrites it with strategy `$t`. If that parse fails, it retries as a
/// statement, which also covers items such as `fn`. If both parses fail, the
/// error from the expression parse is emitted.
macro_rules! walk {
    ($input:ident,$t:expr) => {{
        let lowered: TokenStream = match parse::<Expr>($input.clone()) {
            Ok(expr) => walk(&$t, expr),
            Err(expr_err) => match parse::<Stmt>($input) {
                Ok(stmt) => walk_stmt(&$t, stmt),
                Err(_) => expr_err.to_compile_error().into_token_stream(),
            },
        };
        lowered.into()
    }};
}
462
463#[proc_macro]
464pub fn math(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
465    walk!(input, Basic {})
466}
467
468#[proc_macro]
469pub fn fast(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
470    walk!(input, Fast {})
471}
472
473#[proc_macro]
474pub fn algebraic(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
475    walk!(input, Algebraic {})
476}
477
478#[proc_macro]
479pub fn wrapping(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
480    walk!(input, Wrapping {})
481}
482
483#[proc_macro]
484pub fn saturating(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
485    walk!(input, Saturating {})
486}
487
488#[proc_macro_attribute]
489pub fn apply(
490    args: proc_macro::TokenStream,
491    input: proc_macro::TokenStream,
492) -> proc_macro::TokenStream {
493    match &*args.to_string() {
494        "basic" | "" => math(input),
495        "fast" => fast(input),
496        "algebraic" => algebraic(input),
497        "wrapping" => wrapping(input),
498        "saturating" => saturating(input),
499        _ => quote! { compile_error!("type must be {fast, basic, algebraic, wrapping}") }.into(),
500    }
501}