1use proc_macro2::TokenStream;
5use quote::{quote, ToTokens};
6use syn::{spanned::Spanned, *};
7
/// Binds each `name = value` pair to a local of that name, then expands to
/// `quote!` over the tokens after `=>`, so the bindings can be interpolated
/// with `#name`.
///
/// `quote_with!(a = f(); b = g() => #a + #b)` is shorthand for
/// `{ let a = f(); let b = g(); quote!(#a + #b) }`.
macro_rules! quote_with {
    ($($name:ident = $value:expr);+ => $($body:tt)+) => {{
        $(
            let $name = $value;
        )+
        quote!($($body)+)
    }};
}
14trait Sub {
15 fn sub_bin(&self, op: BinOp, left: TokenStream, right: TokenStream) -> TokenStream;
16 fn sub_unop(&self, op: UnOp, x: TokenStream) -> TokenStream;
17}
18
19struct Basic;
20impl Sub for Basic {
21 fn sub_bin(&self, op: BinOp, left: TokenStream, right: TokenStream) -> TokenStream {
22 use syn::BinOp::*;
23 match op {
24 Add(_) => quote!((#left).add(#right)),
25 Sub(_) => quote!((#left).sub(#right)),
26 Mul(_) => quote!((#left).mul(#right)),
27 Div(_) => quote!((#left).div(#right)),
28 Rem(_) => quote!((#left).rem(#right)),
29 And(_) => quote!((#left).and(#right)),
30 Or(_) => quote!((#left).or(#right)),
31 BitXor(_) => quote!((#left).bitxor(#right)),
32 BitAnd(_) => quote!((#left).bitand(#right)),
33 BitOr(_) => quote!((#left).bitor(#right)),
34 Shl(_) => quote!((#left).shl(#right)),
35 Shr(_) => quote!((#left).shr(#right)),
36 Eq(_) => quote!((#left).eq(#right)),
37 Lt(_) => quote!((#left).lt(#right)),
38 Le(_) => quote!((#left).le(#right)),
39 Ne(_) => quote!((#left).ne(#right)),
40 Ge(_) => quote!((#left).ge(#right)),
41 Gt(_) => quote!((#left).gt(#right)),
42 e => {
44 Error::new(e.span(), format!("{}", quote!(op #e not supported))).to_compile_error()
45 }
46 }
47 }
48
49 fn sub_unop(&self, op: UnOp, x: TokenStream) -> TokenStream {
50 match op {
51 UnOp::Deref(_) => quote!((#x).deref()),
52 UnOp::Not(_) => quote!((#x).not()),
53 UnOp::Neg(_) => quote!((#x).neg()),
54 e => Error::new(
55 e.span(),
56 "it would appear a new operation has been added! please tell me.",
57 )
58 .to_compile_error(),
59 }
60 }
61}
62
63struct Wrapping;
64impl Sub for Wrapping {
65 fn sub_bin(&self, op: BinOp, left: TokenStream, right: TokenStream) -> TokenStream {
66 use syn::BinOp::*;
67 match op {
68 Add(_) => quote!((#left).wrapping_add(#right)),
69 Sub(_) => quote!((#left).wrapping_sub(#right)),
70 Mul(_) => quote!((#left).wrapping_mul(#right)),
71 Div(_) => quote!((#left).wrapping_div(#right)),
72 Rem(_) => quote!((#left).wrapping_rem(#right)),
73 Shl(_) => quote!((#left).wrapping_shl(#right)),
74 Shr(_) => quote!((#left).wrapping_shr(#right)),
75
76 _ => quote!((#left) #op (#right)),
77 }
78 }
79
80 fn sub_unop(&self, op: UnOp, x: TokenStream) -> TokenStream {
81 match op {
82 UnOp::Neg(_) => quote!((#x).wrapping_neg()),
83 _ => quote!(#op #x),
84 }
85 }
86}
87
88struct Algebraic;
89impl Sub for Algebraic {
90 fn sub_bin(&self, op: BinOp, left: TokenStream, right: TokenStream) -> TokenStream {
91 use syn::BinOp::*;
92 match op {
93 Add(_) => quote!(core::intrinsics::fadd_algebraic(#left, #right)),
94 Sub(_) => quote!(core::intrinsics::fsub_algebraic(#left, #right)),
95 Mul(_) => quote!(core::intrinsics::fmul_algebraic(#left, #right)),
96 Div(_) => quote!(core::intrinsics::fdiv_algebraic(#left, #right)),
97 Rem(_) => quote!(core::intrinsics::frem_algebraic(#left, #right)),
98 And(_) => quote!(core::intrinsics::fand_algebraic(#left, #right)),
99 _ => quote!((#left) #op (#right)),
100 }
101 }
102
103 fn sub_unop(&self, op: UnOp, x: TokenStream) -> TokenStream {
104 quote!(#op #x)
105 }
106}
107
108struct Fast;
109impl Sub for Fast {
110 fn sub_bin(&self, op: BinOp, left: TokenStream, right: TokenStream) -> TokenStream {
111 use syn::BinOp::*;
112 match op {
113 Add(_) => quote!(core::intrinsics::fadd_fast(#left, #right)),
114 Sub(_) => quote!(core::intrinsics::fsub_fast(#left, #right)),
115 Mul(_) => quote!(core::intrinsics::fmul_fast(#left, #right)),
116 Div(_) => quote!(core::intrinsics::fdiv_fast(#left, #right)),
117 Rem(_) => quote!(core::intrinsics::frem_fast(#left, #right)),
118 And(_) => quote!(core::intrinsics::fand_fast(#left, #right)),
119 Eq(_) => quote!(((#left) + 0.0).to_bits() == ((#right) + 0.0).to_bits()),
120 _ => quote!((#left) #op (#right)),
121 }
122 }
123
124 fn sub_unop(&self, op: UnOp, x: TokenStream) -> TokenStream {
125 quote!(#op #x)
126 }
127}
128
129fn walk(sub: &impl Sub, e: Expr) -> TokenStream {
130 let walk = |e| walk(sub, e);
131 let map_block = |b| map_block(sub, b);
132 match e {
133 Expr::Binary(ExprBinary {
134 left, op, right, ..
135 }) => {
136 let left = walk(*left);
137 let right = walk(*right);
138 sub.sub_bin(op, left, right)
139 }
140 Expr::Unary(ExprUnary { op, expr, .. }) => sub.sub_unop(op, walk(*expr)),
141 Expr::Break(ExprBreak {
142 label,
143 expr: Some(expr),
144 ..
145 }) => {
146 let expr = walk(*expr);
147 quote!(#label #expr)
148 }
149 Expr::Call(ExprCall { func, args, .. }) => {
150 let f = walk(*func);
151 let args = args.into_iter().map(walk);
152 quote!(#f ( #(#args),* ))
153 }
154 Expr::Closure(ExprClosure {
155 lifetimes,
156 constness,
157 movability,
158 asyncness,
159 capture,
160 inputs,
161 output,
162 body,
163 ..
164 }) => {
165 let body = walk(*body);
166 quote!(#lifetimes #constness #movability #asyncness #capture |#inputs| #output #body)
167 }
168 Expr::ForLoop(ExprForLoop {
169 label,
170 pat,
171 expr,
172 body,
173 ..
174 }) => {
175 let (expr, body) = (walk(*expr), map_block(body));
176 quote!(#label for #pat in #expr #body)
177 }
178 Expr::Let(ExprLet { pat, expr, .. }) => {
179 quote_with!(expr = walk(*expr) => let #pat = #expr)
180 }
181 Expr::Const(ExprConst { block, .. }) => {
182 quote_with!(block =map_block(block) => const #block)
183 }
184 Expr::Range(ExprRange {
185 start, limits, end, ..
186 }) => {
187 let (start, end) = (start.map(|x| walk(*x)), end.map(|x| walk(*x)));
188 quote!((#start #limits #end))
189 }
190 Expr::Return(ExprReturn { expr, .. }) => {
191 let expr = expr.map(|x| walk(*x));
192 quote!(return #expr;)
193 }
194 Expr::Try(ExprTry { expr, .. }) => {
195 let expr = walk(*expr);
196 quote!(#expr ?)
197 }
198 Expr::TryBlock(ExprTryBlock { block, .. }) => {
199 let block = map_block(block);
200 quote!(try #block)
201 }
202 Expr::Unsafe(ExprUnsafe { block, .. }) => {
203 quote_with!(block =map_block(block) => unsafe #block)
204 }
205 Expr::While(ExprWhile {
206 label, cond, body, ..
207 }) => {
208 quote_with!(cond = walk(*cond); body =map_block(body) => #label while #cond #body)
209 }
210 Expr::Index(ExprIndex { expr, index, .. }) => {
211 let expr = walk(*expr);
212 let index = walk(*index);
213 quote!(#expr [ #index ])
214 }
215 Expr::Loop(ExprLoop { label, body, .. }) => {
216 quote_with!(body =map_block(body) => #label loop #body)
217 }
218 Expr::Reference(ExprReference {
219 mutability, expr, ..
220 }) => {
221 let expr = walk(*expr);
222 quote!(& #mutability #expr)
223 }
224 Expr::MethodCall(ExprMethodCall {
225 receiver,
226 method,
227 turbofish,
228 args,
229 ..
230 }) => {
231 let receiver = walk(*receiver);
232 let args = args.into_iter().map(walk);
233 quote!(#receiver . #method #turbofish (#(#args,)*))
234 }
235 Expr::If(ExprIf {
236 cond,
237 then_branch,
238 else_branch: Some((_, else_branch)),
239 ..
240 }) => {
241 let (cond, then_branch, else_branch) =
242 (walk(*cond), map_block(then_branch), walk(*else_branch));
243 quote!(if #cond #then_branch else #else_branch)
244 }
245 Expr::If(ExprIf {
246 cond, then_branch, ..
247 }) => {
248 let (cond, then_branch) = (walk(*cond), map_block(then_branch));
249 quote!(if #cond #then_branch)
250 }
251 Expr::Async(ExprAsync {
252 attrs,
253 capture,
254 block,
255 ..
256 }) => {
257 let block = map_block(block);
258 quote!(#(#attrs)* async #capture #block)
259 }
260 Expr::Await(ExprAwait { base, .. }) => {
261 let base = walk(*base);
262 quote!(#base.await)
263 }
264 Expr::Assign(ExprAssign { left, right, .. }) => {
265 let (left, right) = (walk(*left), walk(*right));
266 quote!(#left = #right;)
267 }
268 Expr::Paren(ExprParen { expr, .. }) => {
269 let expr = walk(*expr);
270 quote!(#expr)
271 }
272 Expr::Tuple(ExprTuple { elems, .. }) => {
273 let ts = elems.into_iter().map(walk);
274 quote!((#(#ts,)*))
275 }
276 Expr::Array(ExprArray { elems, .. }) => {
277 let ts = elems.into_iter().map(walk);
278 quote!([#(#ts,)*])
279 }
280 Expr::Repeat(ExprRepeat { expr, len, .. }) => {
281 let x = walk(*expr);
282 let len = walk(*len);
283 quote!([ #x ; #len ])
284 }
285 Expr::Block(ExprBlock {
286 block,
287 label: Some(label),
288 ..
289 }) => {
290 let b = map_block(block);
291 quote! { #label: #b }
292 }
293 Expr::Block(ExprBlock { block, .. }) => map_block(block),
294 e => quote!(#e),
295 }
296}
297
298fn map_block(sub: &impl Sub, Block { stmts, .. }: Block) -> TokenStream {
299 let stmts = stmts.into_iter().map(|x| walk_stmt(sub, x));
300 quote! { { #(#stmts)* } }
301}
302
303fn walk_stmt(sub: &impl Sub, x: Stmt) -> TokenStream {
304 let walk = |e| walk(sub, e);
305 match x {
306 Stmt::Local(Local {
307 pat,
308 init:
309 Some(LocalInit {
310 expr,
311 diverge: Some((_, diverge)),
312 ..
313 }),
314 ..
315 }) => {
316 let expr = walk(*expr);
317 let diverge = walk(*diverge);
318 quote!(let #pat = #expr else { #diverge };)
319 }
320 Stmt::Local(Local {
321 pat,
322 init: Some(LocalInit { expr, .. }),
323 ..
324 }) => {
325 let expr = walk(*expr);
326 quote!(let #pat = #expr;)
327 }
328 Stmt::Item(x) => walk_item(sub, x),
329 Stmt::Expr(e, t) => {
330 let e = walk(e);
331 quote!(#e #t)
332 }
333 e => quote!(#e),
334 }
335}
336
337fn walk_item(sub: &impl Sub, x: Item) -> TokenStream {
338 let walk = |e| walk(sub, e);
339 match x {
340 Item::Const(ItemConst {
341 vis,
342 ident,
343 ty,
344 expr,
345 ..
346 }) => {
347 let expr = walk(*expr);
348 quote!(#vis const #ident : #ty = #expr;)
349 }
350 Item::Fn(ItemFn {
351 vis,
352 attrs,
353 sig,
354 block,
355 }) => {
356 let block = map_block(sub, *block);
357 quote!( #(#attrs)* #vis #sig #block)
358 }
359 Item::Impl(ItemImpl {
360 attrs,
361 unsafety,
362 defaultness,
363 generics,
364 trait_,
365 self_ty,
366 items,
367 ..
368 }) => {
369 let items = items.into_iter().map(|x| match x {
370 ImplItem::Const(ImplItemConst {
371 vis,
372 attrs,
373 defaultness,
374 ident,
375 ty,
376 expr,
377 ..
378 }) => {
379 let expr = walk(expr);
380 quote!(#(#attrs)* #vis #defaultness const #ident: #ty = #expr;)
381 }
382 ImplItem::Fn(ImplItemFn {
383 attrs,
384 vis,
385 defaultness,
386 sig,
387 block,
388 }) => {
389 let block = map_block(sub, block);
390 quote!(#(#attrs)* #vis #defaultness #sig #block)
391 }
392 e => quote!(#e),
393 });
394 let trait_ = trait_.map(|(n, pat, fr)| quote!(#n #pat #fr));
395 quote!(#(#attrs)* #unsafety #defaultness impl #generics #trait_ #self_ty { #(#items)* })
396 }
397 Item::Mod(ItemMod {
398 attrs,
399 vis,
400 ident,
401 content: Some((_, content)),
402 ..
403 }) => {
404 let content = content.into_iter().map(|x| walk_item(sub, x));
405 quote!(#(#attrs)* #vis mod #ident { #(#content)* })
406 }
407 Item::Static(ItemStatic {
408 attrs,
409 vis,
410 mutability,
411 ident,
412 ty,
413 expr,
414 ..
415 }) => {
416 let expr = walk(*expr);
417 quote!(#(#attrs)* #vis static #mutability #ident: #ty = #expr)
418 }
419 e => quote!(#e),
420 }
421}
422
/// Expands the token stream `$input` with the rewriting strategy `$t`.
///
/// The tokens are first parsed as an expression and, failing that, as a
/// single statement (which also covers items such as `fn` or `impl`). If
/// neither parse succeeds, the *expression* parse error becomes the output as
/// a `compile_error!`. Evaluates to a `proc_macro::TokenStream`.
macro_rules! walk {
    ($input:ident, $t:expr) => {{
        let strategy = $t;
        let expanded: TokenStream = match parse::<Expr>($input.clone()) {
            Ok(expr) => walk(&strategy, expr),
            Err(expr_error) => match parse::<Stmt>($input) {
                Ok(stmt) => walk_stmt(&strategy, stmt),
                Err(_) => expr_error.to_compile_error().into_token_stream(),
            },
        };
        expanded.into()
    }};
}
437
438#[proc_macro]
439pub fn math(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
440 walk!(input, Basic {})
441}
442
443#[proc_macro]
444pub fn fast(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
445 walk!(input, Fast {})
446}
447
448#[proc_macro]
449pub fn algebraic(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
450 walk!(input, Algebraic {})
451}
452
453#[proc_macro]
454pub fn wrapping(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
455 walk!(input, Wrapping {})
456}
457
458#[proc_macro_attribute]
459pub fn apply(
460 args: proc_macro::TokenStream,
461 input: proc_macro::TokenStream,
462) -> proc_macro::TokenStream {
463 match &*args.to_string() {
464 "basic" | "" => math(input),
465 "fast" => fast(input),
466 "algebraic" => algebraic(input),
467 "wrapping" => wrapping(input),
468 _ => quote! { compile_error!("type must be {fast, basic, algebraic, wrapping}") }.into(),
469 }
470}