faer_macros/
lib.rs

1use quote::quote;
2use std::collections::HashMap;
3use std::iter;
4use syn::punctuated::Punctuated;
5use syn::token::Comma;
6use syn::visit_mut::{self, VisitMut};
7use syn::{Expr, ExprCall, ExprParen, ExprPath, ExprReference, Ident, Macro, Path, PathSegment};
8
9struct MigrationCtx(HashMap<&'static str, &'static str>);
10
11impl visit_mut::VisitMut for MigrationCtx {
12	fn visit_macro_mut(&mut self, i: &mut syn::Macro) {
13		if let Ok(mut expr) = i.parse_body_with(Punctuated::<Expr, Comma>::parse_terminated) {
14			for expr in expr.iter_mut() {
15				self.visit_expr_mut(expr);
16			}
17
18			*i = Macro {
19				path: i.path.clone(),
20				bang_token: i.bang_token,
21				delimiter: i.delimiter.clone(),
22				tokens: quote! { #expr },
23			};
24		}
25	}
26
27	fn visit_expr_mut(&mut self, i: &mut syn::Expr) {
28		visit_mut::visit_expr_mut(self, i);
29
30		match i {
31			Expr::MethodCall(call) if call.method == "faer_add" => {
32				*i = Expr::Binary(syn::ExprBinary {
33					attrs: vec![],
34					left: call.receiver.clone(),
35					op: syn::BinOp::Add(Default::default()),
36					right: Box::new(call.args[0].clone()),
37				});
38				*i = Expr::Paren(ExprParen {
39					attrs: vec![],
40					paren_token: Default::default(),
41					expr: Box::new(i.clone()),
42				});
43			},
44			Expr::MethodCall(call) if call.method == "faer_sub" => {
45				*i = Expr::Binary(syn::ExprBinary {
46					attrs: vec![],
47					left: call.receiver.clone(),
48					op: syn::BinOp::Sub(Default::default()),
49					right: Box::new(call.args[0].clone()),
50				});
51				*i = Expr::Paren(ExprParen {
52					attrs: vec![],
53					paren_token: Default::default(),
54					expr: Box::new(i.clone()),
55				});
56			},
57			Expr::MethodCall(call) if call.method == "faer_mul" => {
58				*i = Expr::Binary(syn::ExprBinary {
59					attrs: vec![],
60					left: call.receiver.clone(),
61					op: syn::BinOp::Mul(Default::default()),
62					right: Box::new(call.args[0].clone()),
63				});
64				*i = Expr::Paren(ExprParen {
65					attrs: vec![],
66					paren_token: Default::default(),
67					expr: Box::new(i.clone()),
68				});
69			},
70			Expr::MethodCall(call) if call.method == "faer_div" => {
71				*i = Expr::Binary(syn::ExprBinary {
72					attrs: vec![],
73					left: call.receiver.clone(),
74					op: syn::BinOp::Div(Default::default()),
75					right: Box::new(call.args[0].clone()),
76				});
77				*i = Expr::Paren(ExprParen {
78					attrs: vec![],
79					paren_token: Default::default(),
80					expr: Box::new(i.clone()),
81				});
82			},
83			Expr::MethodCall(call) if call.method == "faer_neg" => {
84				*i = Expr::Unary(syn::ExprUnary {
85					attrs: vec![],
86					op: syn::UnOp::Neg(Default::default()),
87					expr: call.receiver.clone(),
88				});
89				*i = Expr::Paren(ExprParen {
90					attrs: vec![],
91					paren_token: Default::default(),
92					expr: Box::new(i.clone()),
93				});
94			},
95
96			Expr::MethodCall(call) if call.method.to_string().starts_with("faer_") => {
97				if let Some(new_method) = self.0.get(&*call.method.to_string()).map(|x| &**x) {
98					*i = math_expr(
99						&Ident::new(new_method, call.method.span()),
100						std::iter::once(&*call.receiver).chain(call.args.iter()),
101					)
102				}
103			},
104
105			_ => {},
106		}
107	}
108}
109
110struct MathCtx;
111
112fn ident_expr(ident: &syn::Ident) -> Expr {
113	Expr::Path(ExprPath {
114		attrs: vec![],
115		qself: None,
116		path: Path {
117			leading_colon: None,
118			segments: Punctuated::from_iter(iter::once(PathSegment {
119				ident: ident.clone(),
120				arguments: syn::PathArguments::None,
121			})),
122		},
123	})
124}
125
126impl visit_mut::VisitMut for MathCtx {
127	fn visit_macro_mut(&mut self, i: &mut syn::Macro) {
128		if let Ok(mut expr) = i.parse_body_with(Punctuated::<Expr, Comma>::parse_terminated) {
129			for expr in expr.iter_mut() {
130				self.visit_expr_mut(expr);
131			}
132
133			*i = Macro {
134				path: i.path.clone(),
135				bang_token: i.bang_token,
136				delimiter: i.delimiter.clone(),
137				tokens: quote! { #expr },
138			};
139		}
140	}
141
142	fn visit_expr_mut(&mut self, i: &mut syn::Expr) {
143		visit_mut::visit_expr_mut(self, i);
144
145		match i {
146			Expr::Unary(unary) => match unary.op {
147				syn::UnOp::Neg(minus) => {
148					*i = Expr::Call(ExprCall {
149						attrs: vec![],
150						func: Box::new(ident_expr(&Ident::new("neg", minus.span))),
151						paren_token: Default::default(),
152						args: std::iter::once((*unary.expr).clone())
153							.map(|e| {
154								Expr::Reference(ExprReference {
155									attrs: vec![],
156									and_token: Default::default(),
157									mutability: None,
158									expr: Box::new(e),
159								})
160							})
161							.collect(),
162					})
163				},
164				_ => {},
165			},
166			Expr::Binary(binop) => {
167				let func = match binop.op {
168					syn::BinOp::Add(plus) => Some(Ident::new("add", plus.span)),
169					syn::BinOp::Sub(minus) => Some(Ident::new("sub", minus.span)),
170					syn::BinOp::Mul(star) => Some(Ident::new("mul", star.span)),
171					syn::BinOp::Div(star) => Some(Ident::new("div", star.span)),
172					_ => None,
173				};
174				if let Some(func) = func {
175					*i = Expr::Call(ExprCall {
176						attrs: vec![],
177						func: Box::new(ident_expr(&func)),
178						paren_token: Default::default(),
179						args: [(*binop.left).clone(), (*binop.right).clone()]
180							.into_iter()
181							.map(|e| {
182								Expr::Reference(ExprReference {
183									attrs: vec![],
184									and_token: Default::default(),
185									mutability: None,
186									expr: Box::new(e),
187								})
188							})
189							.collect(),
190					})
191				}
192			},
193
194			Expr::Call(call) => match &*call.func {
195				Expr::Path(e) if e.path.get_ident().is_some() => {
196					let name = &*e.path.get_ident().unwrap().to_string();
197					if matches!(
198						name,
199						"sqrt"
200							| "from_real" | "copy"
201							| "max" | "min" | "conj"
202							| "absmax" | "abs2" | "abs1"
203							| "abs" | "add" | "sub"
204							| "div" | "mul" | "mul_real"
205							| "mul_pow2" | "hypot"
206							| "neg" | "recip" | "real"
207							| "imag" | "is_nan" | "is_finite"
208							| "is_zero" | "lt_zero"
209							| "gt_zero" | "le_zero"
210							| "ge_zero"
211					) {
212						call.args.iter_mut().for_each(|x| {
213							*x = Expr::Reference(ExprReference {
214								attrs: vec![],
215								and_token: Default::default(),
216								mutability: None,
217								expr: Box::new(x.clone()),
218							})
219						})
220					}
221				},
222				_ => {},
223			},
224			_ => {},
225		}
226	}
227}
228
229fn math_expr<'a>(method: &Ident, args: impl Iterator<Item = &'a Expr>) -> Expr {
230	Expr::Call(ExprCall {
231		attrs: vec![],
232
233		paren_token: Default::default(),
234		args: args.cloned().collect(),
235		func: Box::new(ident_expr(method)),
236	})
237}
238
239#[proc_macro_attribute]
240pub fn math(_: proc_macro::TokenStream, item: proc_macro::TokenStream) -> proc_macro::TokenStream {
241	let Ok(mut item) = syn::parse::<syn::ItemFn>(item.clone()) else {
242		return item;
243	};
244	let mut rust_ctx = MathCtx;
245	rust_ctx.visit_item_fn_mut(&mut item);
246	let item = quote! { #item };
247	item.into()
248}
249
250#[proc_macro_attribute]
251pub fn migrate(_: proc_macro::TokenStream, item: proc_macro::TokenStream) -> proc_macro::TokenStream {
252	let Ok(mut item) = syn::parse::<syn::ItemFn>(item.clone()) else {
253		return item;
254	};
255
256	let mut rust_ctx = MigrationCtx(
257		[
258			//
259			("faer_add", "add"),
260			("faer_sub", "sub"),
261			("faer_mul", "mul"),
262			("faer_div", "div"),
263			("faer_neg", "neg"),
264			("faer_inv", "recip"),
265			("faer_abs", "abs"),
266			("faer_abs2", "abs2"),
267			("faer_sqrt", "sqrt"),
268			("faer_conj", "conj"),
269			("faer_real", "real"),
270			("faer_scale_real", "mul_real"),
271			("faer_scale_power_of_two", "mul_pow2"),
272		]
273		.into_iter()
274		.collect(),
275	);
276	rust_ctx.visit_item_fn_mut(&mut item);
277	let mut rust_ctx = MathCtx;
278	rust_ctx.visit_item_fn_mut(&mut item);
279
280	let item = quote! { #item };
281	item.into()
282}