1use quote::quote;
2use std::collections::HashMap;
3use std::iter;
4use syn::punctuated::Punctuated;
5use syn::token::Comma;
6use syn::visit_mut::{self, VisitMut};
7use syn::{Expr, ExprCall, ExprParen, ExprPath, ExprReference, Ident, Macro, Path, PathSegment};
8
9struct MigrationCtx(HashMap<&'static str, &'static str>);
10
11impl visit_mut::VisitMut for MigrationCtx {
12 fn visit_macro_mut(&mut self, i: &mut syn::Macro) {
13 if let Ok(mut expr) = i.parse_body_with(Punctuated::<Expr, Comma>::parse_terminated) {
14 for expr in expr.iter_mut() {
15 self.visit_expr_mut(expr);
16 }
17
18 *i = Macro {
19 path: i.path.clone(),
20 bang_token: i.bang_token,
21 delimiter: i.delimiter.clone(),
22 tokens: quote! { #expr },
23 };
24 }
25 }
26
27 fn visit_expr_mut(&mut self, i: &mut syn::Expr) {
28 visit_mut::visit_expr_mut(self, i);
29
30 match i {
31 Expr::MethodCall(call) if call.method == "faer_add" => {
32 *i = Expr::Binary(syn::ExprBinary {
33 attrs: vec![],
34 left: call.receiver.clone(),
35 op: syn::BinOp::Add(Default::default()),
36 right: Box::new(call.args[0].clone()),
37 });
38 *i = Expr::Paren(ExprParen {
39 attrs: vec![],
40 paren_token: Default::default(),
41 expr: Box::new(i.clone()),
42 });
43 },
44 Expr::MethodCall(call) if call.method == "faer_sub" => {
45 *i = Expr::Binary(syn::ExprBinary {
46 attrs: vec![],
47 left: call.receiver.clone(),
48 op: syn::BinOp::Sub(Default::default()),
49 right: Box::new(call.args[0].clone()),
50 });
51 *i = Expr::Paren(ExprParen {
52 attrs: vec![],
53 paren_token: Default::default(),
54 expr: Box::new(i.clone()),
55 });
56 },
57 Expr::MethodCall(call) if call.method == "faer_mul" => {
58 *i = Expr::Binary(syn::ExprBinary {
59 attrs: vec![],
60 left: call.receiver.clone(),
61 op: syn::BinOp::Mul(Default::default()),
62 right: Box::new(call.args[0].clone()),
63 });
64 *i = Expr::Paren(ExprParen {
65 attrs: vec![],
66 paren_token: Default::default(),
67 expr: Box::new(i.clone()),
68 });
69 },
70 Expr::MethodCall(call) if call.method == "faer_div" => {
71 *i = Expr::Binary(syn::ExprBinary {
72 attrs: vec![],
73 left: call.receiver.clone(),
74 op: syn::BinOp::Div(Default::default()),
75 right: Box::new(call.args[0].clone()),
76 });
77 *i = Expr::Paren(ExprParen {
78 attrs: vec![],
79 paren_token: Default::default(),
80 expr: Box::new(i.clone()),
81 });
82 },
83 Expr::MethodCall(call) if call.method == "faer_neg" => {
84 *i = Expr::Unary(syn::ExprUnary {
85 attrs: vec![],
86 op: syn::UnOp::Neg(Default::default()),
87 expr: call.receiver.clone(),
88 });
89 *i = Expr::Paren(ExprParen {
90 attrs: vec![],
91 paren_token: Default::default(),
92 expr: Box::new(i.clone()),
93 });
94 },
95
96 Expr::MethodCall(call) if call.method.to_string().starts_with("faer_") => {
97 if let Some(new_method) = self.0.get(&*call.method.to_string()).map(|x| &**x) {
98 *i = math_expr(
99 &Ident::new(new_method, call.method.span()),
100 std::iter::once(&*call.receiver).chain(call.args.iter()),
101 )
102 }
103 },
104
105 _ => {},
106 }
107 }
108}
109
/// Stateless visitor that rewrites `+`, `-`, `*`, `/` and unary `-` into calls to the free
/// functions `add`/`sub`/`mul`/`div`/`neg`, passing operands by reference. It also wraps in `&`
/// the arguments of calls to a fixed set of math function names.
struct MathCtx;
111
112fn ident_expr(ident: &syn::Ident) -> Expr {
113 Expr::Path(ExprPath {
114 attrs: vec![],
115 qself: None,
116 path: Path {
117 leading_colon: None,
118 segments: Punctuated::from_iter(iter::once(PathSegment {
119 ident: ident.clone(),
120 arguments: syn::PathArguments::None,
121 })),
122 },
123 })
124}
125
126impl visit_mut::VisitMut for MathCtx {
127 fn visit_macro_mut(&mut self, i: &mut syn::Macro) {
128 if let Ok(mut expr) = i.parse_body_with(Punctuated::<Expr, Comma>::parse_terminated) {
129 for expr in expr.iter_mut() {
130 self.visit_expr_mut(expr);
131 }
132
133 *i = Macro {
134 path: i.path.clone(),
135 bang_token: i.bang_token,
136 delimiter: i.delimiter.clone(),
137 tokens: quote! { #expr },
138 };
139 }
140 }
141
142 fn visit_expr_mut(&mut self, i: &mut syn::Expr) {
143 visit_mut::visit_expr_mut(self, i);
144
145 match i {
146 Expr::Unary(unary) => match unary.op {
147 syn::UnOp::Neg(minus) => {
148 *i = Expr::Call(ExprCall {
149 attrs: vec![],
150 func: Box::new(ident_expr(&Ident::new("neg", minus.span))),
151 paren_token: Default::default(),
152 args: std::iter::once((*unary.expr).clone())
153 .map(|e| {
154 Expr::Reference(ExprReference {
155 attrs: vec![],
156 and_token: Default::default(),
157 mutability: None,
158 expr: Box::new(e),
159 })
160 })
161 .collect(),
162 })
163 },
164 _ => {},
165 },
166 Expr::Binary(binop) => {
167 let func = match binop.op {
168 syn::BinOp::Add(plus) => Some(Ident::new("add", plus.span)),
169 syn::BinOp::Sub(minus) => Some(Ident::new("sub", minus.span)),
170 syn::BinOp::Mul(star) => Some(Ident::new("mul", star.span)),
171 syn::BinOp::Div(star) => Some(Ident::new("div", star.span)),
172 _ => None,
173 };
174 if let Some(func) = func {
175 *i = Expr::Call(ExprCall {
176 attrs: vec![],
177 func: Box::new(ident_expr(&func)),
178 paren_token: Default::default(),
179 args: [(*binop.left).clone(), (*binop.right).clone()]
180 .into_iter()
181 .map(|e| {
182 Expr::Reference(ExprReference {
183 attrs: vec![],
184 and_token: Default::default(),
185 mutability: None,
186 expr: Box::new(e),
187 })
188 })
189 .collect(),
190 })
191 }
192 },
193
194 Expr::Call(call) => match &*call.func {
195 Expr::Path(e) if e.path.get_ident().is_some() => {
196 let name = &*e.path.get_ident().unwrap().to_string();
197 if matches!(
198 name,
199 "sqrt"
200 | "from_real" | "copy"
201 | "max" | "min" | "conj"
202 | "absmax" | "abs2" | "abs1"
203 | "abs" | "add" | "sub"
204 | "div" | "mul" | "mul_real"
205 | "mul_pow2" | "hypot"
206 | "neg" | "recip" | "real"
207 | "imag" | "is_nan" | "is_finite"
208 | "is_zero" | "lt_zero"
209 | "gt_zero" | "le_zero"
210 | "ge_zero"
211 ) {
212 call.args.iter_mut().for_each(|x| {
213 *x = Expr::Reference(ExprReference {
214 attrs: vec![],
215 and_token: Default::default(),
216 mutability: None,
217 expr: Box::new(x.clone()),
218 })
219 })
220 }
221 },
222 _ => {},
223 },
224 _ => {},
225 }
226 }
227}
228
229fn math_expr<'a>(method: &Ident, args: impl Iterator<Item = &'a Expr>) -> Expr {
230 Expr::Call(ExprCall {
231 attrs: vec![],
232
233 paren_token: Default::default(),
234 args: args.cloned().collect(),
235 func: Box::new(ident_expr(method)),
236 })
237}
238
239#[proc_macro_attribute]
240pub fn math(_: proc_macro::TokenStream, item: proc_macro::TokenStream) -> proc_macro::TokenStream {
241 let Ok(mut item) = syn::parse::<syn::ItemFn>(item.clone()) else {
242 return item;
243 };
244 let mut rust_ctx = MathCtx;
245 rust_ctx.visit_item_fn_mut(&mut item);
246 let item = quote! { #item };
247 item.into()
248}
249
250#[proc_macro_attribute]
251pub fn migrate(_: proc_macro::TokenStream, item: proc_macro::TokenStream) -> proc_macro::TokenStream {
252 let Ok(mut item) = syn::parse::<syn::ItemFn>(item.clone()) else {
253 return item;
254 };
255
256 let mut rust_ctx = MigrationCtx(
257 [
258 ("faer_add", "add"),
260 ("faer_sub", "sub"),
261 ("faer_mul", "mul"),
262 ("faer_div", "div"),
263 ("faer_neg", "neg"),
264 ("faer_inv", "recip"),
265 ("faer_abs", "abs"),
266 ("faer_abs2", "abs2"),
267 ("faer_sqrt", "sqrt"),
268 ("faer_conj", "conj"),
269 ("faer_real", "real"),
270 ("faer_scale_real", "mul_real"),
271 ("faer_scale_power_of_two", "mul_pow2"),
272 ]
273 .into_iter()
274 .collect(),
275 );
276 rust_ctx.visit_item_fn_mut(&mut item);
277 let mut rust_ctx = MathCtx;
278 rust_ctx.visit_item_fn_mut(&mut item);
279
280 let item = quote! { #item };
281 item.into()
282}