1use proc_macro2::TokenStream;
5use quote::{quote, ToTokens};
6use syn::{spanned::Spanned, *};
7
/// `quote!` with local bindings: `quote_with!(a = expr_a; b = expr_b => tokens using #a #b)`
/// evaluates each `name = value` binding in order and then expands `quote!` over the tokens
/// after `=>`, so those tokens can interpolate the bound names.
macro_rules! quote_with {
    ($($name:ident = $value:expr);+ => $($body:tt)+) => {{
        $(let $name = $value;)+
        quote!($($body)+)
    }};
}
14trait Sub {
15 fn sub_bin(&self, op: BinOp, left: TokenStream, right: TokenStream) -> TokenStream;
16 fn sub_unop(&self, op: UnOp, x: TokenStream) -> TokenStream;
17}
18
19struct Basic;
20impl Sub for Basic {
21 fn sub_bin(&self, op: BinOp, left: TokenStream, right: TokenStream) -> TokenStream {
22 use syn::BinOp::*;
23 match op {
24 Add(_) => quote!((#left).add(#right)),
25 Sub(_) => quote!((#left).sub(#right)),
26 Mul(_) => quote!((#left).mul(#right)),
27 Div(_) => quote!((#left).div(#right)),
28 Rem(_) => quote!((#left).rem(#right)),
29 And(_) => quote!((#left).and(#right)),
30 Or(_) => quote!((#left).or(#right)),
31 BitXor(_) => quote!((#left).bitxor(#right)),
32 BitAnd(_) => quote!((#left).bitand(#right)),
33 BitOr(_) => quote!((#left).bitor(#right)),
34 Shl(_) => quote!((#left).shl(#right)),
35 Shr(_) => quote!((#left).shr(#right)),
36 Eq(_) => quote!((#left).eq(#right)),
37 Lt(_) => quote!((#left).lt(#right)),
38 Le(_) => quote!((#left).le(#right)),
39 Ne(_) => quote!((#left).ne(#right)),
40 Ge(_) => quote!((#left).ge(#right)),
41 Gt(_) => quote!((#left).gt(#right)),
42 e => {
44 Error::new(e.span(), format!("{}", quote!(op #e not supported))).to_compile_error()
45 }
46 }
47 }
48
49 fn sub_unop(&self, op: UnOp, x: TokenStream) -> TokenStream {
50 match op {
51 UnOp::Deref(_) => quote!((#x).deref()),
52 UnOp::Not(_) => quote!((#x).not()),
53 UnOp::Neg(_) => quote!((#x).neg()),
54 e => Error::new(
55 e.span(),
56 "it would appear a new operation has been added! please tell me.",
57 )
58 .to_compile_error(),
59 }
60 }
61}
62
63struct Wrapping;
64impl Sub for Wrapping {
65 fn sub_bin(&self, op: BinOp, left: TokenStream, right: TokenStream) -> TokenStream {
66 use syn::BinOp::*;
67 match op {
68 Add(_) => quote!((#left).wrapping_add(#right)),
69 Sub(_) => quote!((#left).wrapping_sub(#right)),
70 Mul(_) => quote!((#left).wrapping_mul(#right)),
71 Div(_) => quote!((#left).wrapping_div(#right)),
72 Rem(_) => quote!((#left).wrapping_rem(#right)),
73 Shl(_) => quote!((#left).wrapping_shl(#right)),
74 Shr(_) => quote!((#left).wrapping_shr(#right)),
75
76 SubAssign(_) => quote!(#left = #left.wrapping_sub(#right)),
77 AddAssign(_) => quote!(#left = #left.wrapping_add(#right)),
78 MulAssign(_) => quote!(#left = #left.wrapping_mul(#right)),
79 DivAssign(_) => quote!(#left = #left.wrapping_div(#right)),
80 RemAssign(_) => quote!(#left = #left.wrapping_rem(#right)),
81 ShlAssign(_) => quote!(#left = #left.wrapping_shl(#right)),
82 ShrAssign(_) => quote!(#left = #left.wrapping_shr(#right)),
83
84 _ => quote!((#left) #op (#right)),
85 }
86 }
87
88 fn sub_unop(&self, op: UnOp, x: TokenStream) -> TokenStream {
89 match op {
90 UnOp::Neg(_) => quote!((#x).wrapping_neg()),
91 _ => quote!(#op #x),
92 }
93 }
94}
95
96struct Saturating;
97impl Sub for Saturating {
98 fn sub_bin(&self, op: BinOp, left: TokenStream, right: TokenStream) -> TokenStream {
99 use syn::BinOp::*;
100 match op {
101 Add(_) => quote!((#left).saturating_add(#right)),
102 Sub(_) => quote!((#left).saturating_sub(#right)),
103 Mul(_) => quote!((#left).saturating_mul(#right)),
104 Div(_) => quote!((#left).saturating_div(#right)),
105 Rem(_) => quote!((#left).saturating_rem(#right)),
106 Shl(_) => quote!((#left).saturating_shl(#right)),
107 Shr(_) => quote!((#left).saturating_shr(#right)),
108
109 SubAssign(_) => quote!(#left = #left.saturating_sub(#right)),
110 AddAssign(_) => quote!(#left = #left.saturating_add(#right)),
111 MulAssign(_) => quote!(#left = #left.saturating_mul(#right)),
112 DivAssign(_) => quote!(#left = #left.saturating_div(#right)),
113 RemAssign(_) => quote!(#left = #left.saturating_rem(#right)),
114 ShlAssign(_) => quote!(#left = #left.saturating_shl(#right)),
115 ShrAssign(_) => quote!(#left = #left.saturating_shr(#right)),
116
117 _ => quote!((#left) #op (#right)),
118 }
119 }
120
121 fn sub_unop(&self, op: UnOp, x: TokenStream) -> TokenStream {
122 match op {
123 UnOp::Neg(_) => quote!((#x).saturating_neg()),
124 _ => quote!(#op #x),
125 }
126 }
127}
128
129struct Algebraic;
130impl Sub for Algebraic {
131 fn sub_bin(&self, op: BinOp, left: TokenStream, right: TokenStream) -> TokenStream {
132 use syn::BinOp::*;
133 match op {
134 Add(_) => quote!(core::intrinsics::fadd_algebraic(#left, #right)),
135 Sub(_) => quote!(core::intrinsics::fsub_algebraic(#left, #right)),
136 Mul(_) => quote!(core::intrinsics::fmul_algebraic(#left, #right)),
137 Div(_) => quote!(core::intrinsics::fdiv_algebraic(#left, #right)),
138 Rem(_) => quote!(core::intrinsics::frem_algebraic(#left, #right)),
139 And(_) => quote!(core::intrinsics::fand_algebraic(#left, #right)),
140 _ => quote!((#left) #op (#right)),
141 }
142 }
143
144 fn sub_unop(&self, op: UnOp, x: TokenStream) -> TokenStream {
145 quote!(#op #x)
146 }
147}
148
149struct Fast;
150impl Sub for Fast {
151 fn sub_bin(&self, op: BinOp, left: TokenStream, right: TokenStream) -> TokenStream {
152 use syn::BinOp::*;
153 match op {
154 Add(_) => quote!(core::intrinsics::fadd_fast(#left, #right)),
155 Sub(_) => quote!(core::intrinsics::fsub_fast(#left, #right)),
156 Mul(_) => quote!(core::intrinsics::fmul_fast(#left, #right)),
157 Div(_) => quote!(core::intrinsics::fdiv_fast(#left, #right)),
158 Rem(_) => quote!(core::intrinsics::frem_fast(#left, #right)),
159 And(_) => quote!(core::intrinsics::fand_fast(#left, #right)),
160 Eq(_) => quote!(((#left) + 0.0).to_bits() == ((#right) + 0.0).to_bits()),
161 _ => quote!((#left) #op (#right)),
162 }
163 }
164
165 fn sub_unop(&self, op: UnOp, x: TokenStream) -> TokenStream {
166 quote!(#op #x)
167 }
168}
169
170fn walk(sub: &impl Sub, e: Expr) -> TokenStream {
171 let walk = |e| walk(sub, e);
172 let map_block = |b| map_block(sub, b);
173 match e {
174 Expr::Binary(ExprBinary {
175 left, op, right, ..
176 }) => {
177 let left = walk(*left);
178 let right = walk(*right);
179 sub.sub_bin(op, left, right)
180 }
181 Expr::Unary(ExprUnary { op, expr, .. }) => sub.sub_unop(op, walk(*expr)),
182 Expr::Break(ExprBreak {
183 label,
184 expr: Some(expr),
185 ..
186 }) => {
187 let expr = walk(*expr);
188 quote!(#label #expr)
189 }
190 Expr::Call(ExprCall { func, args, .. }) => {
191 let f = walk(*func);
192 let args = args.into_iter().map(walk);
193 quote!(#f ( #(#args),* ))
194 }
195 Expr::Closure(ExprClosure {
196 lifetimes,
197 constness,
198 movability,
199 asyncness,
200 capture,
201 inputs,
202 output,
203 body,
204 ..
205 }) => {
206 let body = walk(*body);
207 quote!(#lifetimes #constness #movability #asyncness #capture |#inputs| #output { #body })
208 }
209 Expr::ForLoop(ExprForLoop {
210 label,
211 pat,
212 expr,
213 body,
214 ..
215 }) => {
216 let (expr, body) = (walk(*expr), map_block(body));
217 quote!(#label for #pat in #expr #body)
218 }
219 Expr::Let(ExprLet { pat, expr, .. }) => {
220 quote_with!(expr = walk(*expr) => let #pat = #expr)
221 }
222 Expr::Const(ExprConst { block, .. }) => {
223 quote_with!(block =map_block(block) => const #block)
224 }
225 Expr::Range(ExprRange {
226 start, limits, end, ..
227 }) => {
228 let (start, end) = (start.map(|x| walk(*x)), end.map(|x| walk(*x)));
229 quote!((#start #limits #end))
230 }
231 Expr::Return(ExprReturn { expr, .. }) => {
232 let expr = expr.map(|x| walk(*x));
233 quote!(return #expr;)
234 }
235 Expr::Try(ExprTry { expr, .. }) => {
236 let expr = walk(*expr);
237 quote!(#expr ?)
238 }
239 Expr::TryBlock(ExprTryBlock { block, .. }) => {
240 let block = map_block(block);
241 quote!(try #block)
242 }
243 Expr::Unsafe(ExprUnsafe { block, .. }) => {
244 quote_with!(block =map_block(block) => unsafe #block)
245 }
246 Expr::While(ExprWhile {
247 label, cond, body, ..
248 }) => {
249 quote_with!(cond = walk(*cond); body =map_block(body) => #label while #cond #body)
250 }
251 Expr::Index(ExprIndex { expr, index, .. }) => {
252 let expr = walk(*expr);
253 let index = walk(*index);
254 quote!(#expr [ #index ])
255 }
256 Expr::Loop(ExprLoop { label, body, .. }) => {
257 quote_with!(body =map_block(body) => #label loop #body)
258 }
259 Expr::Reference(ExprReference {
260 mutability, expr, ..
261 }) => {
262 let expr = walk(*expr);
263 quote!(& #mutability #expr)
264 }
265 Expr::MethodCall(ExprMethodCall {
266 receiver,
267 method,
268 turbofish,
269 args,
270 ..
271 }) => {
272 let receiver = walk(*receiver);
273 let args = args.into_iter().map(walk);
274 quote!(#receiver . #method #turbofish (#(#args,)*))
275 }
276 Expr::Match(ExprMatch { expr, arms, .. }) => {
277 let arms = arms.into_iter().map(
278 |Arm {
279 pat,
280 guard,
281
282 body,
283 ..
285 }| {
286 let b = walk(*body);
287 let guard = match guard {
288 Some((i, x)) => {
289 let z = walk(*x);
290 quote! { #i #z }
291 }
292 None => quote! {},
293 };
294 quote! { #pat #guard => { #b } }
295 },
296 );
297 quote!(match #expr { #(#arms)* })
298 }
299 Expr::If(ExprIf {
300 cond,
301 then_branch,
302 else_branch: Some((_, else_branch)),
303 ..
304 }) => {
305 let (cond, then_branch, else_branch) =
306 (walk(*cond), map_block(then_branch), walk(*else_branch));
307 quote!(if #cond #then_branch else #else_branch)
308 }
309 Expr::If(ExprIf {
310 cond, then_branch, ..
311 }) => {
312 let (cond, then_branch) = (walk(*cond), map_block(then_branch));
313 quote!(if #cond #then_branch)
314 }
315 Expr::Async(ExprAsync {
316 attrs,
317 capture,
318 block,
319 ..
320 }) => {
321 let block = map_block(block);
322 quote!(#(#attrs)* async #capture #block)
323 }
324 Expr::Await(ExprAwait { base, .. }) => {
325 let base = walk(*base);
326 quote!(#base.await)
327 }
328 Expr::Assign(ExprAssign { left, right, .. }) => {
329 let (left, right) = (walk(*left), walk(*right));
330 quote!(#left = #right;)
331 }
332 Expr::Paren(ExprParen { expr, .. }) => {
333 let expr = walk(*expr);
334 quote!(#expr)
335 }
336 Expr::Tuple(ExprTuple { elems, .. }) => {
337 let ts = elems.into_iter().map(walk);
338 quote!((#(#ts,)*))
339 }
340 Expr::Array(ExprArray { elems, .. }) => {
341 let ts = elems.into_iter().map(walk);
342 quote!([#(#ts,)*])
343 }
344 Expr::Repeat(ExprRepeat { expr, len, .. }) => {
345 let x = walk(*expr);
346 let len = walk(*len);
347 quote!([ #x ; #len ])
348 }
349 Expr::Block(ExprBlock {
350 block,
351 label: Some(label),
352 ..
353 }) => {
354 let b = map_block(block);
355 quote! { #label: #b }
356 }
357 Expr::Block(ExprBlock { block, .. }) => map_block(block),
358 Expr::Cast(ExprCast {
359 expr, as_token, ty, ..
360 }) => {
361 let e = walk(*expr);
362 quote! { #e #as_token #ty }
363 }
364 e => quote!(#e),
365 }
366}
367
368fn map_block(sub: &impl Sub, Block { stmts, .. }: Block) -> TokenStream {
369 let stmts = stmts.into_iter().map(|x| walk_stmt(sub, x));
370 quote! { { #(#stmts)* } }
371}
372
373fn walk_stmt(sub: &impl Sub, x: Stmt) -> TokenStream {
374 let walk = |e| walk(sub, e);
375 match x {
376 Stmt::Local(Local {
377 pat,
378 init:
379 Some(LocalInit {
380 expr,
381 diverge: Some((_, diverge)),
382 ..
383 }),
384 ..
385 }) => {
386 let expr = walk(*expr);
387 let diverge = walk(*diverge);
388 quote!(let #pat = #expr else { #diverge };)
389 }
390 Stmt::Local(Local {
391 pat,
392 init: Some(LocalInit { expr, .. }),
393 ..
394 }) => {
395 let expr = walk(*expr);
396 quote!(let #pat = #expr;)
397 }
398 Stmt::Item(x) => walk_item(sub, x),
399 Stmt::Expr(e, t) => {
400 let e = walk(e);
401 quote!(#e #t)
402 }
403 e => quote!(#e),
404 }
405}
406
407fn walk_item(sub: &impl Sub, x: Item) -> TokenStream {
408 let walk = |e| walk(sub, e);
409 match x {
410 Item::Const(ItemConst {
411 vis,
412 ident,
413 ty,
414 expr,
415 ..
416 }) => {
417 let expr = walk(*expr);
418 quote!(#vis const #ident : #ty = #expr;)
419 }
420 Item::Fn(ItemFn {
421 vis,
422 attrs,
423 sig,
424 block,
425 }) => {
426 let block = map_block(sub, *block);
427 quote!( #(#attrs)* #vis #sig #block)
428 }
429 Item::Impl(ItemImpl {
430 attrs,
431 unsafety,
432 defaultness,
433 generics,
434 trait_,
435 self_ty,
436 items,
437 ..
438 }) => {
439 let items = items.into_iter().map(|x| match x {
440 ImplItem::Const(ImplItemConst {
441 vis,
442 attrs,
443 defaultness,
444 ident,
445 ty,
446 expr,
447 ..
448 }) => {
449 let expr = walk(expr);
450 quote!(#(#attrs)* #vis #defaultness const #ident: #ty = #expr;)
451 }
452 ImplItem::Fn(ImplItemFn {
453 attrs,
454 vis,
455 defaultness,
456 sig,
457 block,
458 }) => {
459 let block = map_block(sub, block);
460 quote!(#(#attrs)* #vis #defaultness #sig #block)
461 }
462 e => quote!(#e),
463 });
464 let trait_ = trait_.map(|(n, pat, fr)| quote!(#n #pat #fr));
465 quote!(#(#attrs)* #unsafety #defaultness impl #generics #trait_ #self_ty { #(#items)* })
466 }
467 Item::Mod(ItemMod {
468 attrs,
469 vis,
470 ident,
471 content: Some((_, content)),
472 ..
473 }) => {
474 let content = content.into_iter().map(|x| walk_item(sub, x));
475 quote!(#(#attrs)* #vis mod #ident { #(#content)* })
476 }
477 Item::Static(ItemStatic {
478 attrs,
479 vis,
480 mutability,
481 ident,
482 ty,
483 expr,
484 ..
485 }) => {
486 let expr = walk(*expr);
487 quote!(#(#attrs)* #vis static #mutability #ident: #ty = #expr)
488 }
489 e => quote!(#e),
490 }
491}
492
/// Lowers the proc-macro input `$input` with the strategy `$t`.
///
/// `$input` is parsed as an expression first. If that fails, it is parsed as a statement,
/// which also covers items. If both parses fail, the *expression* parse error is emitted as a
/// compile error. The result is converted into a `proc_macro::TokenStream`.
macro_rules! walk {
    ($input:ident, $t:expr) => {{
        let strategy = $t;
        let lowered = match parse::<Expr>($input.clone()) {
            Ok(expr) => walk(&strategy, expr),
            Err(expr_error) => match parse::<Stmt>($input) {
                Ok(stmt) => walk_stmt(&strategy, stmt),
                Err(_) => expr_error.to_compile_error(),
            },
        };
        lowered.into()
    }};
}
507
508#[proc_macro]
509pub fn math(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
510 walk!(input, Basic {})
511}
512
513#[proc_macro]
514pub fn fast(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
515 walk!(input, Fast {})
516}
517
518#[proc_macro]
519pub fn algebraic(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
520 walk!(input, Algebraic {})
521}
522
523#[proc_macro]
524pub fn wrapping(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
525 walk!(input, Wrapping {})
526}
527
528#[proc_macro]
529pub fn saturating(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
530 walk!(input, Saturating {})
531}
532
533#[proc_macro_attribute]
534pub fn apply(
535 args: proc_macro::TokenStream,
536 input: proc_macro::TokenStream,
537) -> proc_macro::TokenStream {
538 match &*args.to_string() {
539 "basic" | "" => math(input),
540 "fast" => fast(input),
541 "algebraic" => algebraic(input),
542 "wrapping" => wrapping(input),
543 "saturating" => saturating(input),
544 _ => {
545 quote! { compile_error!("type must be {fast, basic, algebraic, wrapping, saturating}") }
546 .into()
547 }
548 }
549}