1use proc_macro2::TokenStream;
5use quote::{quote, ToTokens};
6use syn::{spanned::Spanned, *};
7
/// Binds each `name = value` pair as a local variable, then expands `quote!`
/// over the remaining tokens so those bindings can be interpolated as `#name`.
macro_rules! quote_with {
    ($($name:ident = $value:expr);+ => $($body:tt)+) => {{
        $(
            let $name = $value;
        )+
        quote!($($body)+)
    }};
}
14trait Sub {
15 fn sub_bin(&self, op: BinOp, left: TokenStream, right: TokenStream) -> TokenStream;
16 fn sub_unop(&self, op: UnOp, x: TokenStream) -> TokenStream;
17}
18
19struct Basic;
20impl Sub for Basic {
21 fn sub_bin(&self, op: BinOp, left: TokenStream, right: TokenStream) -> TokenStream {
22 use syn::BinOp::*;
23 match op {
24 Add(_) => quote!((#left).add(#right)),
25 Sub(_) => quote!((#left).sub(#right)),
26 Mul(_) => quote!((#left).mul(#right)),
27 Div(_) => quote!((#left).div(#right)),
28 Rem(_) => quote!((#left).rem(#right)),
29 And(_) => quote!((#left).and(#right)),
30 Or(_) => quote!((#left).or(#right)),
31 BitXor(_) => quote!((#left).bitxor(#right)),
32 BitAnd(_) => quote!((#left).bitand(#right)),
33 BitOr(_) => quote!((#left).bitor(#right)),
34 Shl(_) => quote!((#left).shl(#right)),
35 Shr(_) => quote!((#left).shr(#right)),
36 Eq(_) => quote!((#left).eq(#right)),
37 Lt(_) => quote!((#left).lt(#right)),
38 Le(_) => quote!((#left).le(#right)),
39 Ne(_) => quote!((#left).ne(#right)),
40 Ge(_) => quote!((#left).ge(#right)),
41 Gt(_) => quote!((#left).gt(#right)),
42 e => {
44 Error::new(e.span(), format!("{}", quote!(op #e not supported))).to_compile_error()
45 }
46 }
47 }
48
49 fn sub_unop(&self, op: UnOp, x: TokenStream) -> TokenStream {
50 match op {
51 UnOp::Deref(_) => quote!((#x).deref()),
52 UnOp::Not(_) => quote!((#x).not()),
53 UnOp::Neg(_) => quote!((#x).neg()),
54 e => Error::new(
55 e.span(),
56 "it would appear a new operation has been added! please tell me.",
57 )
58 .to_compile_error(),
59 }
60 }
61}
62
63struct Wrapping;
64impl Sub for Wrapping {
65 fn sub_bin(&self, op: BinOp, left: TokenStream, right: TokenStream) -> TokenStream {
66 use syn::BinOp::*;
67 match op {
68 Add(_) => quote!((#left).wrapping_add(#right)),
69 Sub(_) => quote!((#left).wrapping_sub(#right)),
70 Mul(_) => quote!((#left).wrapping_mul(#right)),
71 Div(_) => quote!((#left).wrapping_div(#right)),
72 Rem(_) => quote!((#left).wrapping_rem(#right)),
73 Shl(_) => quote!((#left).wrapping_shl(#right)),
74 Shr(_) => quote!((#left).wrapping_shr(#right)),
75
76 SubAssign(_) => quote!(#left = #left.wrapping_sub(#right)),
77 AddAssign(_) => quote!(#left = #left.wrapping_add(#right)),
78 MulAssign(_) => quote!(#left = #left.wrapping_mul(#right)),
79 DivAssign(_) => quote!(#left = #left.wrapping_div(#right)),
80 RemAssign(_) => quote!(#left = #left.wrapping_rem(#right)),
81 ShlAssign(_) => quote!(#left = #left.wrapping_shl(#right)),
82 ShrAssign(_) => quote!(#left = #left.wrapping_shr(#right)),
83
84 _ => quote!((#left) #op (#right)),
85 }
86 }
87
88 fn sub_unop(&self, op: UnOp, x: TokenStream) -> TokenStream {
89 match op {
90 UnOp::Neg(_) => quote!((#x).wrapping_neg()),
91 _ => quote!(#op #x),
92 }
93 }
94}
95
96struct Saturating;
97impl Sub for Saturating {
98 fn sub_bin(&self, op: BinOp, left: TokenStream, right: TokenStream) -> TokenStream {
99 use syn::BinOp::*;
100 match op {
101 Add(_) => quote!((#left).saturating_add(#right)),
102 Sub(_) => quote!((#left).saturating_sub(#right)),
103 Mul(_) => quote!((#left).saturating_mul(#right)),
104 Div(_) => quote!((#left).saturating_div(#right)),
105 Rem(_) => quote!((#left).saturating_rem(#right)),
106 Shl(_) => quote!((#left).saturating_shl(#right)),
107 Shr(_) => quote!((#left).saturating_shr(#right)),
108
109 SubAssign(_) => quote!(#left = #left.saturating_sub(#right)),
110 AddAssign(_) => quote!(#left = #left.saturating_add(#right)),
111 MulAssign(_) => quote!(#left = #left.saturating_mul(#right)),
112 DivAssign(_) => quote!(#left = #left.saturating_div(#right)),
113 RemAssign(_) => quote!(#left = #left.saturating_rem(#right)),
114 ShlAssign(_) => quote!(#left = #left.saturating_shl(#right)),
115 ShrAssign(_) => quote!(#left = #left.saturating_shr(#right)),
116
117 _ => quote!((#left) #op (#right)),
118 }
119 }
120
121 fn sub_unop(&self, op: UnOp, x: TokenStream) -> TokenStream {
122 match op {
123 UnOp::Neg(_) => quote!((#x).saturating_neg()),
124 _ => quote!(#op #x),
125 }
126 }
127}
128
129struct Algebraic;
130impl Sub for Algebraic {
131 fn sub_bin(&self, op: BinOp, left: TokenStream, right: TokenStream) -> TokenStream {
132 use syn::BinOp::*;
133 match op {
134 Add(_) => quote!(core::intrinsics::fadd_algebraic(#left, #right)),
135 Sub(_) => quote!(core::intrinsics::fsub_algebraic(#left, #right)),
136 Mul(_) => quote!(core::intrinsics::fmul_algebraic(#left, #right)),
137 Div(_) => quote!(core::intrinsics::fdiv_algebraic(#left, #right)),
138 Rem(_) => quote!(core::intrinsics::frem_algebraic(#left, #right)),
139 And(_) => quote!(core::intrinsics::fand_algebraic(#left, #right)),
140 _ => quote!((#left) #op (#right)),
141 }
142 }
143
144 fn sub_unop(&self, op: UnOp, x: TokenStream) -> TokenStream {
145 quote!(#op #x)
146 }
147}
148
149struct Fast;
150impl Sub for Fast {
151 fn sub_bin(&self, op: BinOp, left: TokenStream, right: TokenStream) -> TokenStream {
152 use syn::BinOp::*;
153 match op {
154 Add(_) => quote!(core::intrinsics::fadd_fast(#left, #right)),
155 Sub(_) => quote!(core::intrinsics::fsub_fast(#left, #right)),
156 Mul(_) => quote!(core::intrinsics::fmul_fast(#left, #right)),
157 Div(_) => quote!(core::intrinsics::fdiv_fast(#left, #right)),
158 Rem(_) => quote!(core::intrinsics::frem_fast(#left, #right)),
159 And(_) => quote!(core::intrinsics::fand_fast(#left, #right)),
160 Eq(_) => quote!(((#left) + 0.0).to_bits() == ((#right) + 0.0).to_bits()),
161 _ => quote!((#left) #op (#right)),
162 }
163 }
164
165 fn sub_unop(&self, op: UnOp, x: TokenStream) -> TokenStream {
166 quote!(#op #x)
167 }
168}
169
170fn walk(sub: &impl Sub, e: Expr) -> TokenStream {
171 let walk = |e| walk(sub, e);
172 let map_block = |b| map_block(sub, b);
173 match e {
174 Expr::Binary(ExprBinary {
175 left, op, right, ..
176 }) => {
177 let left = walk(*left);
178 let right = walk(*right);
179 sub.sub_bin(op, left, right)
180 }
181 Expr::Unary(ExprUnary { op, expr, .. }) => sub.sub_unop(op, walk(*expr)),
182 Expr::Break(ExprBreak {
183 label,
184 expr: Some(expr),
185 ..
186 }) => {
187 let expr = walk(*expr);
188 quote!(#label #expr)
189 }
190 Expr::Call(ExprCall { func, args, .. }) => {
191 let f = walk(*func);
192 let args = args.into_iter().map(walk);
193 quote!(#f ( #(#args),* ))
194 }
195 Expr::Closure(ExprClosure {
196 lifetimes,
197 constness,
198 movability,
199 asyncness,
200 capture,
201 inputs,
202 output,
203 body,
204 ..
205 }) => {
206 let body = walk(*body);
207 quote!(#lifetimes #constness #movability #asyncness #capture |#inputs| #output { #body })
208 }
209 Expr::ForLoop(ExprForLoop {
210 label,
211 pat,
212 expr,
213 body,
214 ..
215 }) => {
216 let (expr, body) = (walk(*expr), map_block(body));
217 quote!(#label for #pat in #expr #body)
218 }
219 Expr::Let(ExprLet { pat, expr, .. }) => {
220 quote_with!(expr = walk(*expr) => let #pat = #expr)
221 }
222 Expr::Const(ExprConst { block, .. }) => {
223 quote_with!(block =map_block(block) => const #block)
224 }
225 Expr::Range(ExprRange {
226 start, limits, end, ..
227 }) => {
228 let (start, end) = (start.map(|x| walk(*x)), end.map(|x| walk(*x)));
229 quote!((#start #limits #end))
230 }
231 Expr::Return(ExprReturn { expr, .. }) => {
232 let expr = expr.map(|x| walk(*x));
233 quote!(return #expr;)
234 }
235 Expr::Try(ExprTry { expr, .. }) => {
236 let expr = walk(*expr);
237 quote!(#expr ?)
238 }
239 Expr::TryBlock(ExprTryBlock { block, .. }) => {
240 let block = map_block(block);
241 quote!(try #block)
242 }
243 Expr::Unsafe(ExprUnsafe { block, .. }) => {
244 quote_with!(block =map_block(block) => unsafe #block)
245 }
246 Expr::While(ExprWhile {
247 label, cond, body, ..
248 }) => {
249 quote_with!(cond = walk(*cond); body =map_block(body) => #label while #cond #body)
250 }
251 Expr::Index(ExprIndex { expr, index, .. }) => {
252 let expr = walk(*expr);
253 let index = walk(*index);
254 quote!(#expr [ #index ])
255 }
256 Expr::Loop(ExprLoop { label, body, .. }) => {
257 quote_with!(body =map_block(body) => #label loop #body)
258 }
259 Expr::Reference(ExprReference {
260 mutability, expr, ..
261 }) => {
262 let expr = walk(*expr);
263 quote!(& #mutability #expr)
264 }
265 Expr::MethodCall(ExprMethodCall {
266 receiver,
267 method,
268 turbofish,
269 args,
270 ..
271 }) => {
272 let receiver = walk(*receiver);
273 let args = args.into_iter().map(walk);
274 quote!(#receiver . #method #turbofish (#(#args,)*))
275 }
276 Expr::Match(ExprMatch { expr, arms, .. }) => {
277 let arms = arms.into_iter().map(
278 |Arm {
279 pat,
280 guard,
281
282 body,
283 ..
285 }| {
286 let b = walk(*body);
287 let guard = match guard {
288 Some((i, x)) => {
289 let z = walk(*x);
290 quote! { #i #z }
291 }
292 None => quote! {},
293 };
294 quote! { #pat #guard => { #b } }
295 },
296 );
297 quote!(match #expr { #(#arms)* })
298 }
299 Expr::If(ExprIf {
300 cond,
301 then_branch,
302 else_branch: Some((_, else_branch)),
303 ..
304 }) => {
305 let (cond, then_branch, else_branch) =
306 (walk(*cond), map_block(then_branch), walk(*else_branch));
307 quote!(if #cond #then_branch else #else_branch)
308 }
309 Expr::If(ExprIf {
310 cond, then_branch, ..
311 }) => {
312 let (cond, then_branch) = (walk(*cond), map_block(then_branch));
313 quote!(if #cond #then_branch)
314 }
315 Expr::Async(ExprAsync {
316 attrs,
317 capture,
318 block,
319 ..
320 }) => {
321 let block = map_block(block);
322 quote!(#(#attrs)* async #capture #block)
323 }
324 Expr::Await(ExprAwait { base, .. }) => {
325 let base = walk(*base);
326 quote!(#base.await)
327 }
328 Expr::Assign(ExprAssign { left, right, .. }) => {
329 let (left, right) = (walk(*left), walk(*right));
330 quote!(#left = #right;)
331 }
332 Expr::Paren(ExprParen { expr, .. }) => {
333 let expr = walk(*expr);
334 quote!(#expr)
335 }
336 Expr::Tuple(ExprTuple { elems, .. }) => {
337 let ts = elems.into_iter().map(walk);
338 quote!((#(#ts,)*))
339 }
340 Expr::Array(ExprArray { elems, .. }) => {
341 let ts = elems.into_iter().map(walk);
342 quote!([#(#ts,)*])
343 }
344 Expr::Repeat(ExprRepeat { expr, len, .. }) => {
345 let x = walk(*expr);
346 let len = walk(*len);
347 quote!([ #x ; #len ])
348 }
349 Expr::Block(ExprBlock {
350 block,
351 label: Some(label),
352 ..
353 }) => {
354 let b = map_block(block);
355 quote! { #label: #b }
356 }
357 Expr::Block(ExprBlock { block, .. }) => map_block(block),
358 e => quote!(#e),
359 }
360}
361
362fn map_block(sub: &impl Sub, Block { stmts, .. }: Block) -> TokenStream {
363 let stmts = stmts.into_iter().map(|x| walk_stmt(sub, x));
364 quote! { { #(#stmts)* } }
365}
366
367fn walk_stmt(sub: &impl Sub, x: Stmt) -> TokenStream {
368 let walk = |e| walk(sub, e);
369 match x {
370 Stmt::Local(Local {
371 pat,
372 init:
373 Some(LocalInit {
374 expr,
375 diverge: Some((_, diverge)),
376 ..
377 }),
378 ..
379 }) => {
380 let expr = walk(*expr);
381 let diverge = walk(*diverge);
382 quote!(let #pat = #expr else { #diverge };)
383 }
384 Stmt::Local(Local {
385 pat,
386 init: Some(LocalInit { expr, .. }),
387 ..
388 }) => {
389 let expr = walk(*expr);
390 quote!(let #pat = #expr;)
391 }
392 Stmt::Item(x) => walk_item(sub, x),
393 Stmt::Expr(e, t) => {
394 let e = walk(e);
395 quote!(#e #t)
396 }
397 e => quote!(#e),
398 }
399}
400
401fn walk_item(sub: &impl Sub, x: Item) -> TokenStream {
402 let walk = |e| walk(sub, e);
403 match x {
404 Item::Const(ItemConst {
405 vis,
406 ident,
407 ty,
408 expr,
409 ..
410 }) => {
411 let expr = walk(*expr);
412 quote!(#vis const #ident : #ty = #expr;)
413 }
414 Item::Fn(ItemFn {
415 vis,
416 attrs,
417 sig,
418 block,
419 }) => {
420 let block = map_block(sub, *block);
421 quote!( #(#attrs)* #vis #sig #block)
422 }
423 Item::Impl(ItemImpl {
424 attrs,
425 unsafety,
426 defaultness,
427 generics,
428 trait_,
429 self_ty,
430 items,
431 ..
432 }) => {
433 let items = items.into_iter().map(|x| match x {
434 ImplItem::Const(ImplItemConst {
435 vis,
436 attrs,
437 defaultness,
438 ident,
439 ty,
440 expr,
441 ..
442 }) => {
443 let expr = walk(expr);
444 quote!(#(#attrs)* #vis #defaultness const #ident: #ty = #expr;)
445 }
446 ImplItem::Fn(ImplItemFn {
447 attrs,
448 vis,
449 defaultness,
450 sig,
451 block,
452 }) => {
453 let block = map_block(sub, block);
454 quote!(#(#attrs)* #vis #defaultness #sig #block)
455 }
456 e => quote!(#e),
457 });
458 let trait_ = trait_.map(|(n, pat, fr)| quote!(#n #pat #fr));
459 quote!(#(#attrs)* #unsafety #defaultness impl #generics #trait_ #self_ty { #(#items)* })
460 }
461 Item::Mod(ItemMod {
462 attrs,
463 vis,
464 ident,
465 content: Some((_, content)),
466 ..
467 }) => {
468 let content = content.into_iter().map(|x| walk_item(sub, x));
469 quote!(#(#attrs)* #vis mod #ident { #(#content)* })
470 }
471 Item::Static(ItemStatic {
472 attrs,
473 vis,
474 mutability,
475 ident,
476 ty,
477 expr,
478 ..
479 }) => {
480 let expr = walk(*expr);
481 quote!(#(#attrs)* #vis static #mutability #ident: #ty = #expr)
482 }
483 e => quote!(#e),
484 }
485}
486
/// Parses `$input` as an expression, or failing that as a single statement
/// (which covers items such as the `fn` under `#[apply]`). It rewrites the
/// result with the strategy `$t` and converts it into the caller's token-stream
/// type. If neither parse succeeds, the expression parse error is emitted.
macro_rules! walk {
    ($input:ident, $t:expr) => {{
        let rewritten: TokenStream = match parse::<Expr>($input.clone()) {
            Ok(expr) => walk(&$t, expr),
            Err(expr_error) => match parse::<Stmt>($input) {
                Ok(stmt) => walk_stmt(&$t, stmt),
                Err(_) => expr_error.to_compile_error().into_token_stream(),
            },
        };
        rewritten.into()
    }};
}
501
502#[proc_macro]
503pub fn math(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
504 walk!(input, Basic {})
505}
506
507#[proc_macro]
508pub fn fast(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
509 walk!(input, Fast {})
510}
511
512#[proc_macro]
513pub fn algebraic(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
514 walk!(input, Algebraic {})
515}
516
517#[proc_macro]
518pub fn wrapping(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
519 walk!(input, Wrapping {})
520}
521
522#[proc_macro]
523pub fn saturating(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
524 walk!(input, Saturating {})
525}
526
527#[proc_macro_attribute]
528pub fn apply(
529 args: proc_macro::TokenStream,
530 input: proc_macro::TokenStream,
531) -> proc_macro::TokenStream {
532 match &*args.to_string() {
533 "basic" | "" => math(input),
534 "fast" => fast(input),
535 "algebraic" => algebraic(input),
536 "wrapping" => wrapping(input),
537 "saturating" => saturating(input),
538 _ => {
539 quote! { compile_error!("type must be {fast, basic, algebraic, wrapping, saturating}") }
540 .into()
541 }
542 }
543}