1use proc_macro2::TokenStream;
5use quote::{quote, ToTokens};
6use syn::{spanned::Spanned, *};
7
/// Binds one or more `name = value` pairs (separated by `;`) as locals, in
/// order, and then expands the tokens after `=>` through `quote!`, so the
/// template can interpolate those locals with `#name`.
macro_rules! quote_with {
    ($($name:ident = $value:expr);+ => $($template:tt)+) => {{
        $(let $name = $value;)+
        quote!($($template)+)
    }};
}
/// A lowering strategy: decides which tokens replace each operator that the
/// expression walker encounters.
trait Sub {
    /// Produces the replacement for the binary operation `left op right`.
    /// `left` and `right` are the already-rewritten operand tokens.
    fn sub_bin(&self, op: BinOp, left: TokenStream, right: TokenStream) -> TokenStream;
    /// Produces the replacement for the unary operation `op x`.
    /// `x` is the already-rewritten operand.
    fn sub_unop(&self, op: UnOp, x: TokenStream) -> TokenStream;
}
18
19struct Basic;
20impl Sub for Basic {
21 fn sub_bin(&self, op: BinOp, left: TokenStream, right: TokenStream) -> TokenStream {
22 use syn::BinOp::*;
23 match op {
24 Add(_) => quote!((#left).add(#right)),
25 Sub(_) => quote!((#left).sub(#right)),
26 Mul(_) => quote!((#left).mul(#right)),
27 Div(_) => quote!((#left).div(#right)),
28 Rem(_) => quote!((#left).rem(#right)),
29 And(_) => quote!((#left).and(#right)),
30 Or(_) => quote!((#left).or(#right)),
31 BitXor(_) => quote!((#left).bitxor(#right)),
32 BitAnd(_) => quote!((#left).bitand(#right)),
33 BitOr(_) => quote!((#left).bitor(#right)),
34 Shl(_) => quote!((#left).shl(#right)),
35 Shr(_) => quote!((#left).shr(#right)),
36 Eq(_) => quote!((#left).eq(#right)),
37 Lt(_) => quote!((#left).lt(#right)),
38 Le(_) => quote!((#left).le(#right)),
39 Ne(_) => quote!((#left).ne(#right)),
40 Ge(_) => quote!((#left).ge(#right)),
41 Gt(_) => quote!((#left).gt(#right)),
42 e => {
44 Error::new(e.span(), format!("{}", quote!(op #e not supported))).to_compile_error()
45 }
46 }
47 }
48
49 fn sub_unop(&self, op: UnOp, x: TokenStream) -> TokenStream {
50 match op {
51 UnOp::Deref(_) => quote!((#x).deref()),
52 UnOp::Not(_) => quote!((#x).not()),
53 UnOp::Neg(_) => quote!((#x).neg()),
54 e => Error::new(
55 e.span(),
56 "it would appear a new operation has been added! please tell me.",
57 )
58 .to_compile_error(),
59 }
60 }
61}
62
63struct Wrapping;
64impl Sub for Wrapping {
65 fn sub_bin(&self, op: BinOp, left: TokenStream, right: TokenStream) -> TokenStream {
66 use syn::BinOp::*;
67 match op {
68 Add(_) => quote!((#left).wrapping_add(#right)),
69 Sub(_) => quote!((#left).wrapping_sub(#right)),
70 Mul(_) => quote!((#left).wrapping_mul(#right)),
71 Div(_) => quote!((#left).wrapping_div(#right)),
72 Rem(_) => quote!((#left).wrapping_rem(#right)),
73 Shl(_) => quote!((#left).wrapping_shl(#right)),
74 Shr(_) => quote!((#left).wrapping_shr(#right)),
75
76 _ => quote!((#left) #op (#right)),
77 }
78 }
79
80 fn sub_unop(&self, op: UnOp, x: TokenStream) -> TokenStream {
81 match op {
82 UnOp::Neg(_) => quote!((#x).wrapping_neg()),
83 _ => quote!(#op #x),
84 }
85 }
86}
87
88struct Algebraic;
89impl Sub for Algebraic {
90 fn sub_bin(&self, op: BinOp, left: TokenStream, right: TokenStream) -> TokenStream {
91 use syn::BinOp::*;
92 match op {
93 Add(_) => quote!(core::intrinsics::fadd_algebraic(#left, #right)),
94 Sub(_) => quote!(core::intrinsics::fsub_algebraic(#left, #right)),
95 Mul(_) => quote!(core::intrinsics::fmul_algebraic(#left, #right)),
96 Div(_) => quote!(core::intrinsics::fdiv_algebraic(#left, #right)),
97 Rem(_) => quote!(core::intrinsics::frem_algebraic(#left, #right)),
98 And(_) => quote!(core::intrinsics::fand_algebraic(#left, #right)),
99 _ => quote!((#left) #op (#right)),
100 }
101 }
102
103 fn sub_unop(&self, op: UnOp, x: TokenStream) -> TokenStream {
104 quote!(#op #x)
105 }
106}
107
108struct Fast;
109impl Sub for Fast {
110 fn sub_bin(&self, op: BinOp, left: TokenStream, right: TokenStream) -> TokenStream {
111 use syn::BinOp::*;
112 match op {
113 Add(_) => quote!(core::intrinsics::fadd_fast(#left, #right)),
114 Sub(_) => quote!(core::intrinsics::fsub_fast(#left, #right)),
115 Mul(_) => quote!(core::intrinsics::fmul_fast(#left, #right)),
116 Div(_) => quote!(core::intrinsics::fdiv_fast(#left, #right)),
117 Rem(_) => quote!(core::intrinsics::frem_fast(#left, #right)),
118 And(_) => quote!(core::intrinsics::fand_fast(#left, #right)),
119 Eq(_) => quote!(((#left) + 0.0).to_bits() == ((#right) + 0.0).to_bits()),
120 _ => quote!((#left) #op (#right)),
121 }
122 }
123
124 fn sub_unop(&self, op: UnOp, x: TokenStream) -> TokenStream {
125 quote!(#op #x)
126 }
127}
128
129fn walk(sub: &impl Sub, e: Expr) -> TokenStream {
130 let walk = |e| walk(sub, e);
131 let map_block = |b| map_block(sub, b);
132 match e {
133 Expr::Binary(ExprBinary {
134 left, op, right, ..
135 }) => {
136 let left = walk(*left);
137 let right = walk(*right);
138 sub.sub_bin(op, left, right)
139 }
140 Expr::Unary(ExprUnary { op, expr, .. }) => sub.sub_unop(op, walk(*expr)),
141 Expr::Break(ExprBreak {
142 label,
143 expr: Some(expr),
144 ..
145 }) => {
146 let expr = walk(*expr);
147 quote!(#label #expr)
148 }
149 Expr::Call(ExprCall { func, args, .. }) => {
150 let f = walk(*func);
151 let args = args.into_iter().map(walk);
152 quote!(#f ( #(#args),* ))
153 }
154 Expr::Closure(ExprClosure {
155 lifetimes,
156 constness,
157 movability,
158 asyncness,
159 capture,
160 inputs,
161 output,
162 body,
163 ..
164 }) => {
165 let body = walk(*body);
166 quote!(#lifetimes #constness #movability #asyncness #capture |#inputs| #output #body)
167 }
168 Expr::ForLoop(ExprForLoop {
169 label,
170 pat,
171 expr,
172 body,
173 ..
174 }) => {
175 let (expr, body) = (walk(*expr), map_block(body));
176 quote!(#label for #pat in #expr #body)
177 }
178 Expr::Let(ExprLet { pat, expr, .. }) => {
179 quote_with!(expr = walk(*expr) => let #pat = #expr)
180 }
181 Expr::Const(ExprConst { block, .. }) => {
182 quote_with!(block =map_block(block) => const #block)
183 }
184 Expr::Range(ExprRange {
185 start, limits, end, ..
186 }) => {
187 let (start, end) = (start.map(|x| walk(*x)), end.map(|x| walk(*x)));
188 quote!((#start #limits #end))
189 }
190 Expr::Return(ExprReturn { expr, .. }) => {
191 let expr = expr.map(|x| walk(*x));
192 quote!(return #expr;)
193 }
194 Expr::Try(ExprTry { expr, .. }) => {
195 let expr = walk(*expr);
196 quote!(#expr ?)
197 }
198 Expr::TryBlock(ExprTryBlock { block, .. }) => {
199 let block = map_block(block);
200 quote!(try #block)
201 }
202 Expr::Unsafe(ExprUnsafe { block, .. }) => {
203 quote_with!(block =map_block(block) => unsafe #block)
204 }
205 Expr::While(ExprWhile {
206 label, cond, body, ..
207 }) => {
208 quote_with!(cond = walk(*cond); body =map_block(body) => #label while #cond #body)
209 }
210 Expr::Loop(ExprLoop { label, body, .. }) => {
211 quote_with!(body =map_block(body) => #label loop #body)
212 }
213 Expr::Reference(ExprReference {
214 mutability, expr, ..
215 }) => {
216 let expr = walk(*expr);
217 quote!(& #mutability #expr)
218 }
219 Expr::MethodCall(ExprMethodCall {
220 receiver,
221 method,
222 turbofish,
223 args,
224 ..
225 }) => {
226 let receiver = walk(*receiver);
227 let args = args.into_iter().map(walk);
228 quote!(#receiver . #method #turbofish (#(#args,)*))
229 }
230 Expr::If(ExprIf {
231 cond,
232 then_branch,
233 else_branch: Some((_, else_branch)),
234 ..
235 }) => {
236 let (cond, then_branch, else_branch) =
237 (walk(*cond), map_block(then_branch), walk(*else_branch));
238 quote!(if #cond #then_branch else #else_branch)
239 }
240 Expr::If(ExprIf {
241 cond, then_branch, ..
242 }) => {
243 let (cond, then_branch) = (walk(*cond), map_block(then_branch));
244 quote!(if #cond #then_branch)
245 }
246 Expr::Async(ExprAsync {
247 attrs,
248 capture,
249 block,
250 ..
251 }) => {
252 let block = map_block(block);
253 quote!(#(#attrs)* async #capture #block)
254 }
255 Expr::Await(ExprAwait { base, .. }) => {
256 let base = walk(*base);
257 quote!(#base.await)
258 }
259 Expr::Assign(ExprAssign { left, right, .. }) => {
260 let (left, right) = (walk(*left), walk(*right));
261 quote!(#left = #right;)
262 }
263 Expr::Paren(ExprParen { expr, .. }) => {
264 let expr = walk(*expr);
265 quote!(#expr)
266 }
267 Expr::Tuple(ExprTuple { elems, .. }) => {
268 let ts = elems.into_iter().map(walk);
269 quote!((#(#ts,)*))
270 }
271 Expr::Array(ExprArray { elems, .. }) => {
272 let ts = elems.into_iter().map(walk);
273 quote!([#(#ts,)*])
274 }
275 Expr::Repeat(ExprRepeat { expr, len, .. }) => {
276 let x = walk(*expr);
277 let len = walk(*len);
278 quote!([ #x ; #len ])
279 }
280 Expr::Block(ExprBlock {
281 block,
282 label: Some(label),
283 ..
284 }) => {
285 let b = map_block(block);
286 quote! { #label: #b }
287 }
288 Expr::Block(ExprBlock { block, .. }) => map_block(block),
289 e => quote!(#e),
290 }
291}
292
293fn map_block(sub: &impl Sub, Block { stmts, .. }: Block) -> TokenStream {
294 let stmts = stmts.into_iter().map(|x| walk_stmt(sub, x));
295 quote! { { #(#stmts)* } }
296}
297
298fn walk_stmt(sub: &impl Sub, x: Stmt) -> TokenStream {
299 let walk = |e| walk(sub, e);
300 match x {
301 Stmt::Local(Local {
302 pat,
303 init:
304 Some(LocalInit {
305 expr,
306 diverge: Some((_, diverge)),
307 ..
308 }),
309 ..
310 }) => {
311 let expr = walk(*expr);
312 let diverge = walk(*diverge);
313 quote!(let #pat = #expr else { #diverge };)
314 }
315 Stmt::Local(Local {
316 pat,
317 init: Some(LocalInit { expr, .. }),
318 ..
319 }) => {
320 let expr = walk(*expr);
321 quote!(let #pat = #expr;)
322 }
323 Stmt::Item(x) => walk_item(sub, x),
324 Stmt::Expr(e, t) => {
325 let e = walk(e);
326 quote!(#e #t)
327 }
328 e => quote!(#e),
329 }
330}
331
332fn walk_item(sub: &impl Sub, x: Item) -> TokenStream {
333 let walk = |e| walk(sub, e);
334 match x {
335 Item::Const(ItemConst {
336 vis,
337 ident,
338 ty,
339 expr,
340 ..
341 }) => {
342 let expr = walk(*expr);
343 quote!(#vis const #ident : #ty = #expr;)
344 }
345 Item::Fn(ItemFn {
346 vis,
347 attrs,
348 sig,
349 block,
350 }) => {
351 let block = map_block(sub, *block);
352 quote!( #(#attrs)* #vis #sig #block)
353 }
354 Item::Impl(ItemImpl {
355 attrs,
356 unsafety,
357 defaultness,
358 generics,
359 trait_,
360 self_ty,
361 items,
362 ..
363 }) => {
364 let items = items.into_iter().map(|x| match x {
365 ImplItem::Const(ImplItemConst {
366 vis,
367 attrs,
368 defaultness,
369 ident,
370 ty,
371 expr,
372 ..
373 }) => {
374 let expr = walk(expr);
375 quote!(#(#attrs)* #vis #defaultness const #ident: #ty = #expr;)
376 }
377 ImplItem::Fn(ImplItemFn {
378 attrs,
379 vis,
380 defaultness,
381 sig,
382 block,
383 }) => {
384 let block = map_block(sub, block);
385 quote!(#(#attrs)* #vis #defaultness #sig #block)
386 }
387 e => quote!(#e),
388 });
389 let trait_ = trait_.map(|(n, pat, fr)| quote!(#n #pat #fr));
390 quote!(#(#attrs)* #unsafety #defaultness impl #generics #trait_ #self_ty { #(#items)* })
391 }
392 Item::Mod(ItemMod {
393 attrs,
394 vis,
395 ident,
396 content: Some((_, content)),
397 ..
398 }) => {
399 let content = content.into_iter().map(|x| walk_item(sub, x));
400 quote!(#(#attrs)* #vis mod #ident { #(#content)* })
401 }
402 Item::Static(ItemStatic {
403 attrs,
404 vis,
405 mutability,
406 ident,
407 ty,
408 expr,
409 ..
410 }) => {
411 let expr = walk(*expr);
412 quote!(#(#attrs)* #vis static #mutability #ident: #ty = #expr)
413 }
414 e => quote!(#e),
415 }
416}
417
/// Shared body of the function-like macros: parses `$input` as an expression
/// and rewrites it with the lowering `$t`; if that fails, retries the input
/// as a statement. When both parses fail, the compile error from the
/// expression parse is returned. The result is converted with `.into()`.
macro_rules! walk {
    ($input:ident, $t:expr) => {
        match parse::<Expr>($input.clone()) {
            Ok(expr) => walk(&$t, expr),
            Err(err) => {
                let expr_error = err.to_compile_error().into_token_stream();
                parse::<Stmt>($input)
                    .map(|stmt| walk_stmt(&$t, stmt))
                    .unwrap_or(expr_error)
            }
        }
        .into()
    };
}
432
433#[proc_macro]
434pub fn math(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
435 walk!(input, Basic {})
436}
437
438#[proc_macro]
439pub fn fast(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
440 walk!(input, Fast {})
441}
442
443#[proc_macro]
444pub fn algebraic(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
445 walk!(input, Algebraic {})
446}
447
448#[proc_macro]
449pub fn wrapping(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
450 walk!(input, Wrapping {})
451}
452
453#[proc_macro_attribute]
454pub fn apply(
455 args: proc_macro::TokenStream,
456 input: proc_macro::TokenStream,
457) -> proc_macro::TokenStream {
458 match &*args.to_string() {
459 "basic" | "" => math(input),
460 "fast" => fast(input),
461 "algebraic" => algebraic(input),
462 "wrapping" => wrapping(input),
463 _ => quote! { compile_error!("type must be {fast, basic, algebraic, wrapping}") }.into(),
464 }
465}