lower_macros/
lib.rs

//! a lil macro crate.
//!
//! provides a handy macro for converting `a + b` to `a.add(b)` for when you can't easily overload the `Add` trait.
4use proc_macro2::TokenStream;
5use quote::{quote, ToTokens};
6use syn::{spanned::Spanned, *};
7
/// Binds every `name = value` pair (separated by `;`) with `let`, then passes
/// the tokens after `=>` to `quote!`, so each bound name can be interpolated
/// as `#name`.
macro_rules! quote_with {
    ($($name:ident = $value:expr);+ => $($body:tt)+) => {{
        $(let $name = $value;)+
        quote!($($body)+)
    }};
}
14trait Sub {
15    fn sub_bin(&self, op: BinOp, left: TokenStream, right: TokenStream) -> TokenStream;
16    fn sub_unop(&self, op: UnOp, x: TokenStream) -> TokenStream;
17}
18
19struct Basic;
20impl Sub for Basic {
21    fn sub_bin(&self, op: BinOp, left: TokenStream, right: TokenStream) -> TokenStream {
22        use syn::BinOp::*;
23        match op {
24            Add(_) => quote!((#left).add(#right)),
25            Sub(_) => quote!((#left).sub(#right)),
26            Mul(_) => quote!((#left).mul(#right)),
27            Div(_) => quote!((#left).div(#right)),
28            Rem(_) => quote!((#left).rem(#right)),
29            And(_) => quote!((#left).and(#right)),
30            Or(_) => quote!((#left).or(#right)),
31            BitXor(_) => quote!((#left).bitxor(#right)),
32            BitAnd(_) => quote!((#left).bitand(#right)),
33            BitOr(_) => quote!((#left).bitor(#right)),
34            Shl(_) => quote!((#left).shl(#right)),
35            Shr(_) => quote!((#left).shr(#right)),
36            Eq(_) => quote!((#left).eq(#right)),
37            Lt(_) => quote!((#left).lt(#right)),
38            Le(_) => quote!((#left).le(#right)),
39            Ne(_) => quote!((#left).ne(#right)),
40            Ge(_) => quote!((#left).ge(#right)),
41            Gt(_) => quote!((#left).gt(#right)),
42            // don't support assigning ops
43            e => {
44                Error::new(e.span(), format!("{}", quote!(op #e not supported))).to_compile_error()
45            }
46        }
47    }
48
49    fn sub_unop(&self, op: UnOp, x: TokenStream) -> TokenStream {
50        match op {
51            UnOp::Deref(_) => quote!((#x).deref()),
52            UnOp::Not(_) => quote!((#x).not()),
53            UnOp::Neg(_) => quote!((#x).neg()),
54            e => Error::new(
55                e.span(),
56                "it would appear a new operation has been added! please tell me.",
57            )
58            .to_compile_error(),
59        }
60    }
61}
62
63struct Wrapping;
64impl Sub for Wrapping {
65    fn sub_bin(&self, op: BinOp, left: TokenStream, right: TokenStream) -> TokenStream {
66        use syn::BinOp::*;
67        match op {
68            Add(_) => quote!((#left).wrapping_add(#right)),
69            Sub(_) => quote!((#left).wrapping_sub(#right)),
70            Mul(_) => quote!((#left).wrapping_mul(#right)),
71            Div(_) => quote!((#left).wrapping_div(#right)),
72            Rem(_) => quote!((#left).wrapping_rem(#right)),
73            Shl(_) => quote!((#left).wrapping_shl(#right)),
74            Shr(_) => quote!((#left).wrapping_shr(#right)),
75
76            _ => quote!((#left) #op (#right)),
77        }
78    }
79
80    fn sub_unop(&self, op: UnOp, x: TokenStream) -> TokenStream {
81        match op {
82            UnOp::Neg(_) => quote!((#x).wrapping_neg()),
83            _ => quote!(#op #x),
84        }
85    }
86}
87
88struct Algebraic;
89impl Sub for Algebraic {
90    fn sub_bin(&self, op: BinOp, left: TokenStream, right: TokenStream) -> TokenStream {
91        use syn::BinOp::*;
92        match op {
93            Add(_) => quote!(core::intrinsics::fadd_algebraic(#left, #right)),
94            Sub(_) => quote!(core::intrinsics::fsub_algebraic(#left, #right)),
95            Mul(_) => quote!(core::intrinsics::fmul_algebraic(#left, #right)),
96            Div(_) => quote!(core::intrinsics::fdiv_algebraic(#left, #right)),
97            Rem(_) => quote!(core::intrinsics::frem_algebraic(#left, #right)),
98            And(_) => quote!(core::intrinsics::fand_algebraic(#left, #right)),
99            _ => quote!((#left) #op (#right)),
100        }
101    }
102
103    fn sub_unop(&self, op: UnOp, x: TokenStream) -> TokenStream {
104        quote!(#op #x)
105    }
106}
107
108struct Fast;
109impl Sub for Fast {
110    fn sub_bin(&self, op: BinOp, left: TokenStream, right: TokenStream) -> TokenStream {
111        use syn::BinOp::*;
112        match op {
113            Add(_) => quote!(core::intrinsics::fadd_fast(#left, #right)),
114            Sub(_) => quote!(core::intrinsics::fsub_fast(#left, #right)),
115            Mul(_) => quote!(core::intrinsics::fmul_fast(#left, #right)),
116            Div(_) => quote!(core::intrinsics::fdiv_fast(#left, #right)),
117            Rem(_) => quote!(core::intrinsics::frem_fast(#left, #right)),
118            And(_) => quote!(core::intrinsics::fand_fast(#left, #right)),
119            Eq(_) => quote!(/* eq */ ((#left) + 0.0).to_bits() == ((#right) + 0.0).to_bits()),
120            _ => quote!((#left) #op (#right)),
121        }
122    }
123
124    fn sub_unop(&self, op: UnOp, x: TokenStream) -> TokenStream {
125        quote!(#op #x)
126    }
127}
128
129fn walk(sub: &impl Sub, e: Expr) -> TokenStream {
130    let walk = |e| walk(sub, e);
131    let map_block = |b| map_block(sub, b);
132    match e {
133        Expr::Binary(ExprBinary {
134            left, op, right, ..
135        }) => {
136            let left = walk(*left);
137            let right = walk(*right);
138            sub.sub_bin(op, left, right)
139        }
140        Expr::Unary(ExprUnary { op, expr, .. }) => sub.sub_unop(op, walk(*expr)),
141        Expr::Break(ExprBreak {
142            label,
143            expr: Some(expr),
144            ..
145        }) => {
146            let expr = walk(*expr);
147            quote!(#label #expr)
148        }
149        Expr::Call(ExprCall { func, args, .. }) => {
150            let f = walk(*func);
151            let args = args.into_iter().map(walk);
152            quote!(#f ( #(#args),* ))
153        }
154        Expr::Closure(ExprClosure {
155            lifetimes,
156            constness,
157            movability,
158            asyncness,
159            capture,
160            inputs,
161            output,
162            body,
163            ..
164        }) => {
165            let body = walk(*body);
166            quote!(#lifetimes #constness #movability #asyncness #capture |#inputs| #output #body)
167        }
168        Expr::ForLoop(ExprForLoop {
169            label,
170            pat,
171            expr,
172            body,
173            ..
174        }) => {
175            let (expr, body) = (walk(*expr), map_block(body));
176            quote!(#label for #pat in #expr #body)
177        }
178        Expr::Let(ExprLet { pat, expr, .. }) => {
179            quote_with!(expr = walk(*expr) => let #pat = #expr)
180        }
181        Expr::Const(ExprConst { block, .. }) => {
182            quote_with!(block =map_block(block) => const #block)
183        }
184        Expr::Range(ExprRange {
185            start, limits, end, ..
186        }) => {
187            let (start, end) = (start.map(|x| walk(*x)), end.map(|x| walk(*x)));
188            quote!((#start #limits #end))
189        }
190        Expr::Return(ExprReturn { expr, .. }) => {
191            let expr = expr.map(|x| walk(*x));
192            quote!(return #expr;)
193        }
194        Expr::Try(ExprTry { expr, .. }) => {
195            let expr = walk(*expr);
196            quote!(#expr ?)
197        }
198        Expr::TryBlock(ExprTryBlock { block, .. }) => {
199            let block = map_block(block);
200            quote!(try #block)
201        }
202        Expr::Unsafe(ExprUnsafe { block, .. }) => {
203            quote_with!(block =map_block(block) => unsafe #block)
204        }
205        Expr::While(ExprWhile {
206            label, cond, body, ..
207        }) => {
208            quote_with!(cond = walk(*cond); body =map_block(body) => #label while #cond #body)
209        }
210        Expr::Index(ExprIndex { expr, index, .. }) => {
211            let expr = walk(*expr);
212            let index = walk(*index);
213            quote!(#expr [ #index ])
214        }
215        Expr::Loop(ExprLoop { label, body, .. }) => {
216            quote_with!(body =map_block(body) => #label loop #body)
217        }
218        Expr::Reference(ExprReference {
219            mutability, expr, ..
220        }) => {
221            let expr = walk(*expr);
222            quote!(& #mutability #expr)
223        }
224        Expr::MethodCall(ExprMethodCall {
225            receiver,
226            method,
227            turbofish,
228            args,
229            ..
230        }) => {
231            let receiver = walk(*receiver);
232            let args = args.into_iter().map(walk);
233            quote!(#receiver . #method #turbofish (#(#args,)*))
234        }
235        Expr::If(ExprIf {
236            cond,
237            then_branch,
238            else_branch: Some((_, else_branch)),
239            ..
240        }) => {
241            let (cond, then_branch, else_branch) =
242                (walk(*cond), map_block(then_branch), walk(*else_branch));
243            quote!(if #cond #then_branch else #else_branch)
244        }
245        Expr::If(ExprIf {
246            cond, then_branch, ..
247        }) => {
248            let (cond, then_branch) = (walk(*cond), map_block(then_branch));
249            quote!(if #cond #then_branch)
250        }
251        Expr::Async(ExprAsync {
252            attrs,
253            capture,
254            block,
255            ..
256        }) => {
257            let block = map_block(block);
258            quote!(#(#attrs)* async #capture #block)
259        }
260        Expr::Await(ExprAwait { base, .. }) => {
261            let base = walk(*base);
262            quote!(#base.await)
263        }
264        Expr::Assign(ExprAssign { left, right, .. }) => {
265            let (left, right) = (walk(*left), walk(*right));
266            quote!(#left = #right;)
267        }
268        Expr::Paren(ExprParen { expr, .. }) => {
269            let expr = walk(*expr);
270            quote!(#expr)
271        }
272        Expr::Tuple(ExprTuple { elems, .. }) => {
273            let ts = elems.into_iter().map(walk);
274            quote!((#(#ts,)*))
275        }
276        Expr::Array(ExprArray { elems, .. }) => {
277            let ts = elems.into_iter().map(walk);
278            quote!([#(#ts,)*])
279        }
280        Expr::Repeat(ExprRepeat { expr, len, .. }) => {
281            let x = walk(*expr);
282            let len = walk(*len);
283            quote!([ #x ; #len ])
284        }
285        Expr::Block(ExprBlock {
286            block,
287            label: Some(label),
288            ..
289        }) => {
290            let b = map_block(block);
291            quote! { #label: #b }
292        }
293        Expr::Block(ExprBlock { block, .. }) => map_block(block),
294        e => quote!(#e),
295    }
296}
297
298fn map_block(sub: &impl Sub, Block { stmts, .. }: Block) -> TokenStream {
299    let stmts = stmts.into_iter().map(|x| walk_stmt(sub, x));
300    quote! { { #(#stmts)* } }
301}
302
303fn walk_stmt(sub: &impl Sub, x: Stmt) -> TokenStream {
304    let walk = |e| walk(sub, e);
305    match x {
306        Stmt::Local(Local {
307            pat,
308            init:
309                Some(LocalInit {
310                    expr,
311                    diverge: Some((_, diverge)),
312                    ..
313                }),
314            ..
315        }) => {
316            let expr = walk(*expr);
317            let diverge = walk(*diverge);
318            quote!(let #pat = #expr else { #diverge };)
319        }
320        Stmt::Local(Local {
321            pat,
322            init: Some(LocalInit { expr, .. }),
323            ..
324        }) => {
325            let expr = walk(*expr);
326            quote!(let #pat = #expr;)
327        }
328        Stmt::Item(x) => walk_item(sub, x),
329        Stmt::Expr(e, t) => {
330            let e = walk(e);
331            quote!(#e #t)
332        }
333        e => quote!(#e),
334    }
335}
336
337fn walk_item(sub: &impl Sub, x: Item) -> TokenStream {
338    let walk = |e| walk(sub, e);
339    match x {
340        Item::Const(ItemConst {
341            vis,
342            ident,
343            ty,
344            expr,
345            ..
346        }) => {
347            let expr = walk(*expr);
348            quote!(#vis const #ident : #ty = #expr;)
349        }
350        Item::Fn(ItemFn {
351            vis,
352            attrs,
353            sig,
354            block,
355        }) => {
356            let block = map_block(sub, *block);
357            quote!( #(#attrs)* #vis #sig #block)
358        }
359        Item::Impl(ItemImpl {
360            attrs,
361            unsafety,
362            defaultness,
363            generics,
364            trait_,
365            self_ty,
366            items,
367            ..
368        }) => {
369            let items = items.into_iter().map(|x| match x {
370                ImplItem::Const(ImplItemConst {
371                    vis,
372                    attrs,
373                    defaultness,
374                    ident,
375                    ty,
376                    expr,
377                    ..
378                }) => {
379                    let expr = walk(expr);
380                    quote!(#(#attrs)* #vis #defaultness const #ident: #ty = #expr;)
381                }
382                ImplItem::Fn(ImplItemFn {
383                    attrs,
384                    vis,
385                    defaultness,
386                    sig,
387                    block,
388                }) => {
389                    let block = map_block(sub, block);
390                    quote!(#(#attrs)* #vis #defaultness #sig #block)
391                }
392                e => quote!(#e),
393            });
394            let trait_ = trait_.map(|(n, pat, fr)| quote!(#n #pat #fr));
395            quote!(#(#attrs)* #unsafety #defaultness impl #generics #trait_ #self_ty { #(#items)* })
396        }
397        Item::Mod(ItemMod {
398            attrs,
399            vis,
400            ident,
401            content: Some((_, content)),
402            ..
403        }) => {
404            let content = content.into_iter().map(|x| walk_item(sub, x));
405            quote!(#(#attrs)* #vis mod #ident { #(#content)* })
406        }
407        Item::Static(ItemStatic {
408            attrs,
409            vis,
410            mutability,
411            ident,
412            ty,
413            expr,
414            ..
415        }) => {
416            let expr = walk(*expr);
417            quote!(#(#attrs)* #vis static #mutability #ident: #ty = #expr)
418        }
419        e => quote!(#e),
420    }
421}
422
/// Entry-point body shared by the proc macros: parses `$input` as an
/// expression and rewrites it with `$t`; if that fails, retries as a
/// statement. When both parses fail, the expression parse error is emitted as
/// a compile error. The result is converted with `.into()`.
macro_rules! walk {
    ($input:ident, $t:expr) => {
        match parse::<Expr>($input.clone()) {
            Ok(expr) => walk(&$t, expr),
            Err(expr_error) => match parse::<Stmt>($input) {
                Ok(stmt) => walk_stmt(&$t, stmt),
                Err(_) => expr_error.to_compile_error().into_token_stream(),
            },
        }
        .into()
    };
}
437
438#[proc_macro]
439pub fn math(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
440    walk!(input, Basic {})
441}
442
443#[proc_macro]
444pub fn fast(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
445    walk!(input, Fast {})
446}
447
448#[proc_macro]
449pub fn algebraic(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
450    walk!(input, Algebraic {})
451}
452
453#[proc_macro]
454pub fn wrapping(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
455    walk!(input, Wrapping {})
456}
457
458#[proc_macro_attribute]
459pub fn apply(
460    args: proc_macro::TokenStream,
461    input: proc_macro::TokenStream,
462) -> proc_macro::TokenStream {
463    match &*args.to_string() {
464        "basic" | "" => math(input),
465        "fast" => fast(input),
466        "algebraic" => algebraic(input),
467        "wrapping" => wrapping(input),
468        _ => quote! { compile_error!("type must be {fast, basic, algebraic, wrapping}") }.into(),
469    }
470}