1use proc_macro2::TokenStream;
5use quote::{quote, ToTokens};
6use syn::{spanned::Spanned, *};
7
/// Binds each `name = value` pair to a local and then runs `quote!` over the
/// tokens after `=>`, so the freshly bound names can be interpolated.
///
/// `quote_with!(a = f(x); b = g(y) => #a + #b)` expands to
/// `{ let a = f(x); let b = g(y); quote!(#a + #b) }`; the bindings are
/// evaluated in the order they are written.
macro_rules! quote_with {
    ($($name:ident = $value:expr);+ => $($body:tt)+) => {{
        $(
            let $name = $value;
        )+
        quote!($($body)+)
    }};
}
14trait Sub {
15 fn sub_bin(&self, op: BinOp, left: TokenStream, right: TokenStream) -> TokenStream;
16 fn sub_unop(&self, op: UnOp, x: TokenStream) -> TokenStream;
17}
18
19struct Basic;
20impl Sub for Basic {
21 fn sub_bin(&self, op: BinOp, left: TokenStream, right: TokenStream) -> TokenStream {
22 use syn::BinOp::*;
23 match op {
24 Add(_) => quote!((#left).add(#right)),
25 Sub(_) => quote!((#left).sub(#right)),
26 Mul(_) => quote!((#left).mul(#right)),
27 Div(_) => quote!((#left).div(#right)),
28 Rem(_) => quote!((#left).rem(#right)),
29 And(_) => quote!((#left).and(#right)),
30 Or(_) => quote!((#left).or(#right)),
31 BitXor(_) => quote!((#left).bitxor(#right)),
32 BitAnd(_) => quote!((#left).bitand(#right)),
33 BitOr(_) => quote!((#left).bitor(#right)),
34 Shl(_) => quote!((#left).shl(#right)),
35 Shr(_) => quote!((#left).shr(#right)),
36 Eq(_) => quote!((#left).eq(#right)),
37 Lt(_) => quote!((#left).lt(#right)),
38 Le(_) => quote!((#left).le(#right)),
39 Ne(_) => quote!((#left).ne(#right)),
40 Ge(_) => quote!((#left).ge(#right)),
41 Gt(_) => quote!((#left).gt(#right)),
42 e => {
44 Error::new(e.span(), format!("{}", quote!(op #e not supported))).to_compile_error()
45 }
46 }
47 }
48
49 fn sub_unop(&self, op: UnOp, x: TokenStream) -> TokenStream {
50 match op {
51 UnOp::Deref(_) => quote!((#x).deref()),
52 UnOp::Not(_) => quote!((#x).not()),
53 UnOp::Neg(_) => quote!((#x).neg()),
54 e => Error::new(
55 e.span(),
56 "it would appear a new operation has been added! please tell me.",
57 )
58 .to_compile_error(),
59 }
60 }
61}
62
63struct Wrapping;
64impl Sub for Wrapping {
65 fn sub_bin(&self, op: BinOp, left: TokenStream, right: TokenStream) -> TokenStream {
66 use syn::BinOp::*;
67 match op {
68 Add(_) => quote!((#left).wrapping_add(#right)),
69 Sub(_) => quote!((#left).wrapping_sub(#right)),
70 Mul(_) => quote!((#left).wrapping_mul(#right)),
71 Div(_) => quote!((#left).wrapping_div(#right)),
72 Rem(_) => quote!((#left).wrapping_rem(#right)),
73 Shl(_) => quote!((#left).wrapping_shl(#right)),
74 Shr(_) => quote!((#left).wrapping_shr(#right)),
75
76 _ => quote!((#left) #op (#right)),
77 }
78 }
79
80 fn sub_unop(&self, op: UnOp, x: TokenStream) -> TokenStream {
81 match op {
82 UnOp::Neg(_) => quote!((#x).wrapping_neg()),
83 _ => quote!(#op #x),
84 }
85 }
86}
87
88struct Saturating;
89impl Sub for Saturating {
90 fn sub_bin(&self, op: BinOp, left: TokenStream, right: TokenStream) -> TokenStream {
91 use syn::BinOp::*;
92 match op {
93 Add(_) => quote!((#left).saturating_add(#right)),
94 Sub(_) => quote!((#left).saturating_sub(#right)),
95 Mul(_) => quote!((#left).saturating_mul(#right)),
96 Div(_) => quote!((#left).saturating_div(#right)),
97 Rem(_) => quote!((#left).saturating_rem(#right)),
98 Shl(_) => quote!((#left).saturating_shl(#right)),
99 Shr(_) => quote!((#left).saturating_shr(#right)),
100
101 _ => quote!((#left) #op (#right)),
102 }
103 }
104
105 fn sub_unop(&self, op: UnOp, x: TokenStream) -> TokenStream {
106 match op {
107 UnOp::Neg(_) => quote!((#x).saturating_neg()),
108 _ => quote!(#op #x),
109 }
110 }
111}
112
113struct Algebraic;
114impl Sub for Algebraic {
115 fn sub_bin(&self, op: BinOp, left: TokenStream, right: TokenStream) -> TokenStream {
116 use syn::BinOp::*;
117 match op {
118 Add(_) => quote!(core::intrinsics::fadd_algebraic(#left, #right)),
119 Sub(_) => quote!(core::intrinsics::fsub_algebraic(#left, #right)),
120 Mul(_) => quote!(core::intrinsics::fmul_algebraic(#left, #right)),
121 Div(_) => quote!(core::intrinsics::fdiv_algebraic(#left, #right)),
122 Rem(_) => quote!(core::intrinsics::frem_algebraic(#left, #right)),
123 And(_) => quote!(core::intrinsics::fand_algebraic(#left, #right)),
124 _ => quote!((#left) #op (#right)),
125 }
126 }
127
128 fn sub_unop(&self, op: UnOp, x: TokenStream) -> TokenStream {
129 quote!(#op #x)
130 }
131}
132
133struct Fast;
134impl Sub for Fast {
135 fn sub_bin(&self, op: BinOp, left: TokenStream, right: TokenStream) -> TokenStream {
136 use syn::BinOp::*;
137 match op {
138 Add(_) => quote!(core::intrinsics::fadd_fast(#left, #right)),
139 Sub(_) => quote!(core::intrinsics::fsub_fast(#left, #right)),
140 Mul(_) => quote!(core::intrinsics::fmul_fast(#left, #right)),
141 Div(_) => quote!(core::intrinsics::fdiv_fast(#left, #right)),
142 Rem(_) => quote!(core::intrinsics::frem_fast(#left, #right)),
143 And(_) => quote!(core::intrinsics::fand_fast(#left, #right)),
144 Eq(_) => quote!(((#left) + 0.0).to_bits() == ((#right) + 0.0).to_bits()),
145 _ => quote!((#left) #op (#right)),
146 }
147 }
148
149 fn sub_unop(&self, op: UnOp, x: TokenStream) -> TokenStream {
150 quote!(#op #x)
151 }
152}
153
154fn walk(sub: &impl Sub, e: Expr) -> TokenStream {
155 let walk = |e| walk(sub, e);
156 let map_block = |b| map_block(sub, b);
157 match e {
158 Expr::Binary(ExprBinary {
159 left, op, right, ..
160 }) => {
161 let left = walk(*left);
162 let right = walk(*right);
163 sub.sub_bin(op, left, right)
164 }
165 Expr::Unary(ExprUnary { op, expr, .. }) => sub.sub_unop(op, walk(*expr)),
166 Expr::Break(ExprBreak {
167 label,
168 expr: Some(expr),
169 ..
170 }) => {
171 let expr = walk(*expr);
172 quote!(#label #expr)
173 }
174 Expr::Call(ExprCall { func, args, .. }) => {
175 let f = walk(*func);
176 let args = args.into_iter().map(walk);
177 quote!(#f ( #(#args),* ))
178 }
179 Expr::Closure(ExprClosure {
180 lifetimes,
181 constness,
182 movability,
183 asyncness,
184 capture,
185 inputs,
186 output,
187 body,
188 ..
189 }) => {
190 let body = walk(*body);
191 quote!(#lifetimes #constness #movability #asyncness #capture |#inputs| #output #body)
192 }
193 Expr::ForLoop(ExprForLoop {
194 label,
195 pat,
196 expr,
197 body,
198 ..
199 }) => {
200 let (expr, body) = (walk(*expr), map_block(body));
201 quote!(#label for #pat in #expr #body)
202 }
203 Expr::Let(ExprLet { pat, expr, .. }) => {
204 quote_with!(expr = walk(*expr) => let #pat = #expr)
205 }
206 Expr::Const(ExprConst { block, .. }) => {
207 quote_with!(block =map_block(block) => const #block)
208 }
209 Expr::Range(ExprRange {
210 start, limits, end, ..
211 }) => {
212 let (start, end) = (start.map(|x| walk(*x)), end.map(|x| walk(*x)));
213 quote!((#start #limits #end))
214 }
215 Expr::Return(ExprReturn { expr, .. }) => {
216 let expr = expr.map(|x| walk(*x));
217 quote!(return #expr;)
218 }
219 Expr::Try(ExprTry { expr, .. }) => {
220 let expr = walk(*expr);
221 quote!(#expr ?)
222 }
223 Expr::TryBlock(ExprTryBlock { block, .. }) => {
224 let block = map_block(block);
225 quote!(try #block)
226 }
227 Expr::Unsafe(ExprUnsafe { block, .. }) => {
228 quote_with!(block =map_block(block) => unsafe #block)
229 }
230 Expr::While(ExprWhile {
231 label, cond, body, ..
232 }) => {
233 quote_with!(cond = walk(*cond); body =map_block(body) => #label while #cond #body)
234 }
235 Expr::Index(ExprIndex { expr, index, .. }) => {
236 let expr = walk(*expr);
237 let index = walk(*index);
238 quote!(#expr [ #index ])
239 }
240 Expr::Loop(ExprLoop { label, body, .. }) => {
241 quote_with!(body =map_block(body) => #label loop #body)
242 }
243 Expr::Reference(ExprReference {
244 mutability, expr, ..
245 }) => {
246 let expr = walk(*expr);
247 quote!(& #mutability #expr)
248 }
249 Expr::MethodCall(ExprMethodCall {
250 receiver,
251 method,
252 turbofish,
253 args,
254 ..
255 }) => {
256 let receiver = walk(*receiver);
257 let args = args.into_iter().map(walk);
258 quote!(#receiver . #method #turbofish (#(#args,)*))
259 }
260 Expr::If(ExprIf {
261 cond,
262 then_branch,
263 else_branch: Some((_, else_branch)),
264 ..
265 }) => {
266 let (cond, then_branch, else_branch) =
267 (walk(*cond), map_block(then_branch), walk(*else_branch));
268 quote!(if #cond #then_branch else #else_branch)
269 }
270 Expr::If(ExprIf {
271 cond, then_branch, ..
272 }) => {
273 let (cond, then_branch) = (walk(*cond), map_block(then_branch));
274 quote!(if #cond #then_branch)
275 }
276 Expr::Async(ExprAsync {
277 attrs,
278 capture,
279 block,
280 ..
281 }) => {
282 let block = map_block(block);
283 quote!(#(#attrs)* async #capture #block)
284 }
285 Expr::Await(ExprAwait { base, .. }) => {
286 let base = walk(*base);
287 quote!(#base.await)
288 }
289 Expr::Assign(ExprAssign { left, right, .. }) => {
290 let (left, right) = (walk(*left), walk(*right));
291 quote!(#left = #right;)
292 }
293 Expr::Paren(ExprParen { expr, .. }) => {
294 let expr = walk(*expr);
295 quote!(#expr)
296 }
297 Expr::Tuple(ExprTuple { elems, .. }) => {
298 let ts = elems.into_iter().map(walk);
299 quote!((#(#ts,)*))
300 }
301 Expr::Array(ExprArray { elems, .. }) => {
302 let ts = elems.into_iter().map(walk);
303 quote!([#(#ts,)*])
304 }
305 Expr::Repeat(ExprRepeat { expr, len, .. }) => {
306 let x = walk(*expr);
307 let len = walk(*len);
308 quote!([ #x ; #len ])
309 }
310 Expr::Block(ExprBlock {
311 block,
312 label: Some(label),
313 ..
314 }) => {
315 let b = map_block(block);
316 quote! { #label: #b }
317 }
318 Expr::Block(ExprBlock { block, .. }) => map_block(block),
319 e => quote!(#e),
320 }
321}
322
323fn map_block(sub: &impl Sub, Block { stmts, .. }: Block) -> TokenStream {
324 let stmts = stmts.into_iter().map(|x| walk_stmt(sub, x));
325 quote! { { #(#stmts)* } }
326}
327
328fn walk_stmt(sub: &impl Sub, x: Stmt) -> TokenStream {
329 let walk = |e| walk(sub, e);
330 match x {
331 Stmt::Local(Local {
332 pat,
333 init:
334 Some(LocalInit {
335 expr,
336 diverge: Some((_, diverge)),
337 ..
338 }),
339 ..
340 }) => {
341 let expr = walk(*expr);
342 let diverge = walk(*diverge);
343 quote!(let #pat = #expr else { #diverge };)
344 }
345 Stmt::Local(Local {
346 pat,
347 init: Some(LocalInit { expr, .. }),
348 ..
349 }) => {
350 let expr = walk(*expr);
351 quote!(let #pat = #expr;)
352 }
353 Stmt::Item(x) => walk_item(sub, x),
354 Stmt::Expr(e, t) => {
355 let e = walk(e);
356 quote!(#e #t)
357 }
358 e => quote!(#e),
359 }
360}
361
362fn walk_item(sub: &impl Sub, x: Item) -> TokenStream {
363 let walk = |e| walk(sub, e);
364 match x {
365 Item::Const(ItemConst {
366 vis,
367 ident,
368 ty,
369 expr,
370 ..
371 }) => {
372 let expr = walk(*expr);
373 quote!(#vis const #ident : #ty = #expr;)
374 }
375 Item::Fn(ItemFn {
376 vis,
377 attrs,
378 sig,
379 block,
380 }) => {
381 let block = map_block(sub, *block);
382 quote!( #(#attrs)* #vis #sig #block)
383 }
384 Item::Impl(ItemImpl {
385 attrs,
386 unsafety,
387 defaultness,
388 generics,
389 trait_,
390 self_ty,
391 items,
392 ..
393 }) => {
394 let items = items.into_iter().map(|x| match x {
395 ImplItem::Const(ImplItemConst {
396 vis,
397 attrs,
398 defaultness,
399 ident,
400 ty,
401 expr,
402 ..
403 }) => {
404 let expr = walk(expr);
405 quote!(#(#attrs)* #vis #defaultness const #ident: #ty = #expr;)
406 }
407 ImplItem::Fn(ImplItemFn {
408 attrs,
409 vis,
410 defaultness,
411 sig,
412 block,
413 }) => {
414 let block = map_block(sub, block);
415 quote!(#(#attrs)* #vis #defaultness #sig #block)
416 }
417 e => quote!(#e),
418 });
419 let trait_ = trait_.map(|(n, pat, fr)| quote!(#n #pat #fr));
420 quote!(#(#attrs)* #unsafety #defaultness impl #generics #trait_ #self_ty { #(#items)* })
421 }
422 Item::Mod(ItemMod {
423 attrs,
424 vis,
425 ident,
426 content: Some((_, content)),
427 ..
428 }) => {
429 let content = content.into_iter().map(|x| walk_item(sub, x));
430 quote!(#(#attrs)* #vis mod #ident { #(#content)* })
431 }
432 Item::Static(ItemStatic {
433 attrs,
434 vis,
435 mutability,
436 ident,
437 ty,
438 expr,
439 ..
440 }) => {
441 let expr = walk(*expr);
442 quote!(#(#attrs)* #vis static #mutability #ident: #ty = #expr)
443 }
444 e => quote!(#e),
445 }
446}
447
/// Parses `$input` as an expression, or failing that as a single statement or
/// item. It rewrites the result with the `Sub` implementation `$t` and converts
/// it back into a `proc_macro::TokenStream`.
///
/// When neither parse succeeds, the error from the expression parse is
/// reported.
macro_rules! walk {
    ($input:ident, $t:expr) => {{
        let rewritten: proc_macro2::TokenStream = match parse::<Expr>($input.clone()) {
            Ok(expr) => walk(&$t, expr),
            Err(expr_err) => match parse::<Stmt>($input) {
                Ok(stmt) => walk_stmt(&$t, stmt),
                Err(_) => expr_err.to_compile_error(),
            },
        };
        proc_macro::TokenStream::from(rewritten)
    }};
}
462
463#[proc_macro]
464pub fn math(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
465 walk!(input, Basic {})
466}
467
468#[proc_macro]
469pub fn fast(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
470 walk!(input, Fast {})
471}
472
473#[proc_macro]
474pub fn algebraic(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
475 walk!(input, Algebraic {})
476}
477
478#[proc_macro]
479pub fn wrapping(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
480 walk!(input, Wrapping {})
481}
482
483#[proc_macro]
484pub fn saturating(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
485 walk!(input, Saturating {})
486}
487
488#[proc_macro_attribute]
489pub fn apply(
490 args: proc_macro::TokenStream,
491 input: proc_macro::TokenStream,
492) -> proc_macro::TokenStream {
493 match &*args.to_string() {
494 "basic" | "" => math(input),
495 "fast" => fast(input),
496 "algebraic" => algebraic(input),
497 "wrapping" => wrapping(input),
498 "saturating" => saturating(input),
499 _ => quote! { compile_error!("type must be {fast, basic, algebraic, wrapping}") }.into(),
500 }
501}