1use proc_macro2::TokenStream;
5use quote::{ToTokens, quote};
6use syn::{spanned::Spanned, *};
7
/// Binds each `name = value` pair as a local and then runs `quote!` over the
/// remaining tokens, so the bound locals can be interpolated as `#name`.
///
/// `quote_with!(a = f(); b = g() => #a + #b)` is shorthand for
/// `{ let a = f(); let b = g(); quote!(#a + #b) }`.
macro_rules! quote_with {
    ($($name:ident = $value:expr);+ => $($body:tt)+) => {{
        $(
            let $name = $value;
        )+
        quote!($($body)+)
    }};
}
14trait Sub {
15 fn sub_bin(&self, op: BinOp, left: TokenStream, right: TokenStream) -> TokenStream;
16 fn sub_unop(&self, op: UnOp, x: TokenStream) -> TokenStream;
17}
18
19struct Basic;
20impl Sub for Basic {
21 fn sub_bin(&self, op: BinOp, left: TokenStream, right: TokenStream) -> TokenStream {
22 use syn::BinOp::*;
23 match op {
24 Add(_) => quote!((#left).add(#right)),
25 Sub(_) => quote!((#left).sub(#right)),
26 Mul(_) => quote!((#left).mul(#right)),
27 Div(_) => quote!((#left).div(#right)),
28 Rem(_) => quote!((#left).rem(#right)),
29 And(_) => quote!((#left).and(#right)),
30 Or(_) => quote!((#left).or(#right)),
31 BitXor(_) => quote!((#left).bitxor(#right)),
32 BitAnd(_) => quote!((#left).bitand(#right)),
33 BitOr(_) => quote!((#left).bitor(#right)),
34 Shl(_) => quote!((#left).shl(#right)),
35 Shr(_) => quote!((#left).shr(#right)),
36 Eq(_) => quote!((#left).eq(#right)),
37 Lt(_) => quote!((#left).lt(#right)),
38 Le(_) => quote!((#left).le(#right)),
39 Ne(_) => quote!((#left).ne(#right)),
40 Ge(_) => quote!((#left).ge(#right)),
41 Gt(_) => quote!((#left).gt(#right)),
42 e => {
44 Error::new(e.span(), format!("{}", quote!(op #e not supported))).to_compile_error()
45 }
46 }
47 }
48
49 fn sub_unop(&self, op: UnOp, x: TokenStream) -> TokenStream {
50 match op {
51 UnOp::Deref(_) => quote!((#x).deref()),
52 UnOp::Not(_) => quote!((#x).not()),
53 UnOp::Neg(_) => quote!((#x).neg()),
54 e => Error::new(
55 e.span(),
56 "it would appear a new operation has been added! please tell me.",
57 )
58 .to_compile_error(),
59 }
60 }
61}
62
63struct Wrapping;
64impl Sub for Wrapping {
65 fn sub_bin(&self, op: BinOp, left: TokenStream, right: TokenStream) -> TokenStream {
66 use syn::BinOp::*;
67 match op {
68 Add(_) => quote!((#left).wrapping_add(#right)),
69 Sub(_) => quote!((#left).wrapping_sub(#right)),
70 Mul(_) => quote!((#left).wrapping_mul(#right)),
71 Div(_) => quote!((#left).wrapping_div(#right)),
72 Rem(_) => quote!((#left).wrapping_rem(#right)),
73 Shl(_) => quote!((#left).wrapping_shl(#right)),
74 Shr(_) => quote!((#left).wrapping_shr(#right)),
75
76 SubAssign(_) => quote!(#left = #left.wrapping_sub(#right)),
77 AddAssign(_) => quote!(#left = #left.wrapping_add(#right)),
78 MulAssign(_) => quote!(#left = #left.wrapping_mul(#right)),
79 DivAssign(_) => quote!(#left = #left.wrapping_div(#right)),
80 RemAssign(_) => quote!(#left = #left.wrapping_rem(#right)),
81 ShlAssign(_) => quote!(#left = #left.wrapping_shl(#right)),
82 ShrAssign(_) => quote!(#left = #left.wrapping_shr(#right)),
83
84 _ => match (|| {
85 syn::Result::Ok(
86 ExprBinary {
87 attrs: Default::default(),
88 left: syn::parse(left.into())?,
89 op,
90 right: syn::parse(right.into())?,
91 }
92 .to_token_stream(),
93 )
94 })() {
95 Ok(x) => x,
96 Err(e) => e.into_compile_error(),
97 },
98 }
99 }
100
101 fn sub_unop(&self, op: UnOp, x: TokenStream) -> TokenStream {
102 match op {
103 UnOp::Neg(_) => quote!((#x).wrapping_neg()),
104 _ => quote!(#op #x),
105 }
106 }
107}
108
109struct Saturating;
110impl Sub for Saturating {
111 fn sub_bin(&self, op: BinOp, left: TokenStream, right: TokenStream) -> TokenStream {
112 use syn::BinOp::*;
113 match op {
114 Add(_) => quote!((#left).saturating_add(#right)),
115 Sub(_) => quote!((#left).saturating_sub(#right)),
116 Mul(_) => quote!((#left).saturating_mul(#right)),
117 Div(_) => quote!((#left).saturating_div(#right)),
118 Rem(_) => quote!((#left).saturating_rem(#right)),
119 Shl(_) => quote!((#left).saturating_shl(#right)),
120 Shr(_) => quote!((#left).saturating_shr(#right)),
121
122 SubAssign(_) => quote!(#left = #left.saturating_sub(#right)),
123 AddAssign(_) => quote!(#left = #left.saturating_add(#right)),
124 MulAssign(_) => quote!(#left = #left.saturating_mul(#right)),
125 DivAssign(_) => quote!(#left = #left.saturating_div(#right)),
126 RemAssign(_) => quote!(#left = #left.saturating_rem(#right)),
127 ShlAssign(_) => quote!(#left = #left.saturating_shl(#right)),
128 ShrAssign(_) => quote!(#left = #left.saturating_shr(#right)),
129
130 _ => match (|| {
131 syn::Result::Ok(
132 ExprBinary {
133 attrs: Default::default(),
134 left: syn::parse(left.into())?,
135 op,
136 right: syn::parse(right.into())?,
137 }
138 .to_token_stream(),
139 )
140 })() {
141 Ok(x) => x,
142 Err(e) => e.into_compile_error(),
143 },
144 }
145 }
146
147 fn sub_unop(&self, op: UnOp, x: TokenStream) -> TokenStream {
148 match op {
149 UnOp::Neg(_) => quote!((#x).saturating_neg()),
150 _ => quote!(#op #x),
151 }
152 }
153}
154
155struct Algebraic;
156impl Sub for Algebraic {
157 fn sub_bin(&self, op: BinOp, left: TokenStream, right: TokenStream) -> TokenStream {
158 use syn::BinOp::*;
159 match op {
160 Add(_) => quote!(core::intrinsics::fadd_algebraic(#left, #right)),
161 Sub(_) => quote!(core::intrinsics::fsub_algebraic(#left, #right)),
162 Mul(_) => quote!(core::intrinsics::fmul_algebraic(#left, #right)),
163 Div(_) => quote!(core::intrinsics::fdiv_algebraic(#left, #right)),
164 Rem(_) => quote!(core::intrinsics::frem_algebraic(#left, #right)),
165
166 _ => match (|| {
167 syn::Result::Ok(
168 ExprBinary {
169 attrs: Default::default(),
170 left: syn::parse(left.into())?,
171 op,
172 right: syn::parse(right.into())?,
173 }
174 .to_token_stream(),
175 )
176 })() {
177 Ok(x) => x,
178 Err(e) => e.into_compile_error(),
179 },
180 }
181 }
182
183 fn sub_unop(&self, op: UnOp, x: TokenStream) -> TokenStream {
184 quote!(#op #x)
185 }
186}
187
188struct Fast;
189impl Sub for Fast {
190 fn sub_bin(&self, op: BinOp, left: TokenStream, right: TokenStream) -> TokenStream {
191 use syn::BinOp::*;
192 match op {
193 Add(_) => quote!(core::intrinsics::fadd_fast(#left, #right)),
194 Sub(_) => quote!(core::intrinsics::fsub_fast(#left, #right)),
195 Mul(_) => quote!(core::intrinsics::fmul_fast(#left, #right)),
196 Div(_) => quote!(core::intrinsics::fdiv_fast(#left, #right)),
197 Rem(_) => quote!(core::intrinsics::frem_fast(#left, #right)),
198 Eq(_) => quote!(((#left) + 0.0).to_bits() == ((#right) + 0.0).to_bits()),
199
200 _ => match (|| {
201 syn::Result::Ok(
202 ExprBinary {
203 attrs: Default::default(),
204 left: syn::parse(left.into())?,
205 op,
206 right: syn::parse(right.into())?,
207 }
208 .to_token_stream(),
209 )
210 })() {
211 Ok(x) => x,
212 Err(e) => e.into_compile_error(),
213 },
214 }
215 }
216
217 fn sub_unop(&self, op: UnOp, x: TokenStream) -> TokenStream {
218 quote!(#op #x)
219 }
220}
221
222fn walk(sub: &impl Sub, e: Expr) -> TokenStream {
223 let walk = |e| walk(sub, e);
224 let map_block = |b| map_block(sub, b);
225 match e {
226 Expr::Binary(ExprBinary {
227 left, op, right, ..
228 }) => {
229 let left = walk(*left);
230 let right = walk(*right);
231 sub.sub_bin(op, left, right)
232 }
233 Expr::Unary(ExprUnary { op, expr, .. }) => sub.sub_unop(op, walk(*expr)),
234 Expr::Break(ExprBreak {
235 label,
236 expr: Some(expr),
237 ..
238 }) => {
239 let expr = walk(*expr);
240 quote!(#label #expr)
241 }
242 Expr::Call(ExprCall { func, args, .. }) => {
243 let f = walk(*func);
244 let args = args.into_iter().map(walk);
245 quote!(#f ( #(#args),* ))
246 }
247 Expr::Closure(ExprClosure {
248 lifetimes,
249 constness,
250 movability,
251 asyncness,
252 capture,
253 inputs,
254 output,
255 body,
256 ..
257 }) => {
258 let body = walk(*body);
259 quote!(#lifetimes #constness #movability #asyncness #capture |#inputs| #output { #body })
260 }
261 Expr::ForLoop(ExprForLoop {
262 label,
263 pat,
264 expr,
265 body,
266 ..
267 }) => {
268 let (expr, body) = (walk(*expr), map_block(body));
269 quote!(#label for #pat in #expr #body)
270 }
271 Expr::Let(ExprLet { pat, expr, .. }) => {
272 quote_with!(expr = walk(*expr) => let #pat = #expr)
273 }
274 Expr::Const(ExprConst { block, .. }) => {
275 quote_with!(block =map_block(block) => const #block)
276 }
277 Expr::Range(ExprRange {
278 start, limits, end, ..
279 }) => {
280 let (start, end) = (start.map(|x| walk(*x)), end.map(|x| walk(*x)));
281 quote!((#start #limits #end))
282 }
283 Expr::Return(ExprReturn { expr, .. }) => {
284 let expr = expr.map(|x| walk(*x));
285 quote!(return #expr;)
286 }
287 Expr::Try(ExprTry { expr, .. }) => {
288 let expr = walk(*expr);
289 quote!(#expr ?)
290 }
291 Expr::TryBlock(ExprTryBlock { block, .. }) => {
292 let block = map_block(block);
293 quote!(try #block)
294 }
295 Expr::Unsafe(ExprUnsafe { block, .. }) => {
296 quote_with!(block =map_block(block) => unsafe #block)
297 }
298 Expr::While(ExprWhile {
299 label, cond, body, ..
300 }) => {
301 quote_with!(cond = walk(*cond); body =map_block(body) => #label while #cond #body)
302 }
303 Expr::Index(ExprIndex { expr, index, .. }) => {
304 let expr = walk(*expr);
305 let index = walk(*index);
306 quote!(#expr [ #index ])
307 }
308 Expr::Loop(ExprLoop { label, body, .. }) => {
309 quote_with!(body =map_block(body) => #label loop #body)
310 }
311 Expr::Reference(ExprReference {
312 mutability, expr, ..
313 }) => {
314 let expr = walk(*expr);
315 quote!(& #mutability #expr)
316 }
317 Expr::MethodCall(ExprMethodCall {
318 receiver,
319 method,
320 turbofish,
321 args,
322 ..
323 }) => {
324 let receiver = walk(*receiver);
325 let args = args.into_iter().map(walk);
326 quote!(#receiver . #method #turbofish (#(#args,)*))
327 }
328 Expr::Match(ExprMatch { expr, arms, .. }) => {
329 let arms = arms.into_iter().map(
330 |Arm {
331 pat,
332 guard,
333
334 body,
335 ..
337 }| {
338 let b = walk(*body);
339 let guard = match guard {
340 Some((i, x)) => {
341 let z = walk(*x);
342 quote! { #i #z }
343 }
344 None => quote! {},
345 };
346 quote! { #pat #guard => { #b } }
347 },
348 );
349 quote!(match #expr { #(#arms)* })
350 }
351 Expr::If(ExprIf {
352 cond,
353 then_branch,
354 else_branch: Some((_, else_branch)),
355 ..
356 }) => {
357 let (cond, then_branch, else_branch) =
358 (walk(*cond), map_block(then_branch), walk(*else_branch));
359 quote!(if #cond #then_branch else #else_branch)
360 }
361 Expr::If(ExprIf {
362 cond, then_branch, ..
363 }) => {
364 let (cond, then_branch) = (walk(*cond), map_block(then_branch));
365 quote!(if #cond #then_branch)
366 }
367 Expr::Async(ExprAsync {
368 attrs,
369 capture,
370 block,
371 ..
372 }) => {
373 let block = map_block(block);
374 quote!(#(#attrs)* async #capture #block)
375 }
376 Expr::Await(ExprAwait { base, .. }) => {
377 let base = walk(*base);
378 quote!(#base.await)
379 }
380 Expr::Assign(ExprAssign { left, right, .. }) => {
381 let (left, right) = (walk(*left), walk(*right));
382 quote!(#left = #right;)
383 }
384 Expr::Paren(ExprParen { expr, .. }) => {
385 let expr = walk(*expr);
386 quote!(#expr)
387 }
388 Expr::Tuple(ExprTuple { elems, .. }) => {
389 let ts = elems.into_iter().map(walk);
390 quote!((#(#ts,)*))
391 }
392 Expr::Array(ExprArray { elems, .. }) => {
393 let ts = elems.into_iter().map(walk);
394 quote!([#(#ts,)*])
395 }
396 Expr::Repeat(ExprRepeat { expr, len, .. }) => {
397 let x = walk(*expr);
398 let len = walk(*len);
399 quote!([ #x ; #len ])
400 }
401 Expr::Block(ExprBlock {
402 block,
403 label: Some(label),
404 ..
405 }) => {
406 let b = map_block(block);
407 quote! { #label: #b }
408 }
409 Expr::Block(ExprBlock { block, .. }) => map_block(block),
410 Expr::Cast(ExprCast {
411 expr, as_token, ty, ..
412 }) => {
413 let e = walk(*expr);
414 quote! { #e #as_token #ty }
415 }
416 e => quote!(#e),
417 }
418}
419
420fn map_block(sub: &impl Sub, Block { stmts, .. }: Block) -> TokenStream {
421 let stmts = stmts.into_iter().map(|x| walk_stmt(sub, x));
422 quote! { { #(#stmts)* } }
423}
424
425fn walk_stmt(sub: &impl Sub, x: Stmt) -> TokenStream {
426 let walk = |e| walk(sub, e);
427 match x {
428 Stmt::Local(Local {
429 pat,
430 init:
431 Some(LocalInit {
432 expr,
433 diverge: Some((_, diverge)),
434 ..
435 }),
436 ..
437 }) => {
438 let expr = walk(*expr);
439 let diverge = walk(*diverge);
440 quote!(let #pat = #expr else { #diverge };)
441 }
442 Stmt::Local(Local {
443 pat,
444 init: Some(LocalInit { expr, .. }),
445 ..
446 }) => {
447 let expr = walk(*expr);
448 quote!(let #pat = #expr;)
449 }
450 Stmt::Item(x) => walk_item(sub, x),
451 Stmt::Expr(e, t) => {
452 let e = walk(e);
453 quote!(#e #t)
454 }
455 e => quote!(#e),
456 }
457}
458
459fn walk_item(sub: &impl Sub, x: Item) -> TokenStream {
460 let walk = |e| walk(sub, e);
461 match x {
462 Item::Const(ItemConst {
463 vis,
464 ident,
465 ty,
466 expr,
467 ..
468 }) => {
469 let expr = walk(*expr);
470 quote!(#vis const #ident : #ty = #expr;)
471 }
472 Item::Fn(ItemFn {
473 vis,
474 attrs,
475 sig,
476 block,
477 }) => {
478 let block = map_block(sub, *block);
479 quote!( #(#attrs)* #vis #sig #block)
480 }
481 Item::Impl(ItemImpl {
482 attrs,
483 unsafety,
484 defaultness,
485 generics,
486 trait_,
487 self_ty,
488 items,
489 ..
490 }) => {
491 let items = items.into_iter().map(|x| match x {
492 ImplItem::Const(ImplItemConst {
493 vis,
494 attrs,
495 defaultness,
496 ident,
497 ty,
498 expr,
499 ..
500 }) => {
501 let expr = walk(expr);
502 quote!(#(#attrs)* #vis #defaultness const #ident: #ty = #expr;)
503 }
504 ImplItem::Fn(ImplItemFn {
505 attrs,
506 vis,
507 defaultness,
508 sig,
509 block,
510 }) => {
511 let block = map_block(sub, block);
512 quote!(#(#attrs)* #vis #defaultness #sig #block)
513 }
514 e => quote!(#e),
515 });
516 let trait_ = trait_.map(|(n, pat, fr)| quote!(#n #pat #fr));
517 quote!(#(#attrs)* #unsafety #defaultness impl #generics #trait_ #self_ty { #(#items)* })
518 }
519 Item::Mod(ItemMod {
520 attrs,
521 vis,
522 ident,
523 content: Some((_, content)),
524 ..
525 }) => {
526 let content = content.into_iter().map(|x| walk_item(sub, x));
527 quote!(#(#attrs)* #vis mod #ident { #(#content)* })
528 }
529 Item::Static(ItemStatic {
530 attrs,
531 vis,
532 mutability,
533 ident,
534 ty,
535 expr,
536 ..
537 }) => {
538 let expr = walk(*expr);
539 quote!(#(#attrs)* #vis static #mutability #ident: #ty = #expr)
540 }
541 e => quote!(#e),
542 }
543}
544
/// Parses `$input` and lowers it with the `Sub` strategy `$t`, evaluating to a
/// `proc_macro::TokenStream`.
///
/// Parsing happens in two attempts:
/// 1. The input is parsed as a single expression.
/// 2. If that fails, it is parsed as a sequence of statements/items, as inside
///    a block, so several items (e.g. multiple `fn`s) can be passed at once.
///
/// If the second attempt also fails, or yields no statements, the error from
/// the expression parse is emitted as a compile error.
macro_rules! walk {
    ($input:ident, $t:expr) => {
        match parse::<Expr>($input.clone()) {
            Ok(expr) => walk(&$t, expr),
            Err(expr_err) => match syn::parse::Parser::parse(Block::parse_within, $input) {
                Ok(stmts) if !stmts.is_empty() => {
                    let sub = $t;
                    let stmts = stmts.into_iter().map(|stmt| walk_stmt(&sub, stmt));
                    quote!(#(#stmts)*)
                }
                _ => expr_err.to_compile_error(),
            },
        }
        .into()
    };
}
559
560#[proc_macro]
561pub fn math(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
562 walk!(input, Basic {})
563}
564
565#[proc_macro]
566pub fn fast(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
567 walk!(input, Fast {})
568}
569
570#[proc_macro]
571pub fn algebraic(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
572 walk!(input, Algebraic {})
573}
574
575#[proc_macro]
576pub fn wrapping(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
577 walk!(input, Wrapping {})
578}
579
580#[proc_macro]
581pub fn saturating(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
582 walk!(input, Saturating {})
583}
584
585#[proc_macro_attribute]
586pub fn apply(
587 args: proc_macro::TokenStream,
588 input: proc_macro::TokenStream,
589) -> proc_macro::TokenStream {
590 match &*args.to_string() {
591 "basic" | "" => math(input),
592 "fast" => fast(input),
593 "algebraic" => algebraic(input),
594 "wrapping" => wrapping(input),
595 "saturating" => saturating(input),
596 _ => {
597 quote! { compile_error!("type must be {fast, basic, algebraic, wrapping, saturating}") }
598 .into()
599 }
600 }
601}