1#![deny(missing_docs)]
6#![deny(clippy::suspicious)]
7
8mod util;
9
10use astro_float_num::{Consts, EXPONENT_BIT_SIZE};
11use proc_macro2::TokenStream;
12use quote::quote;
13use syn::{
14 parse::Parse, spanned::Spanned, BinOp, Error, Expr, ExprBinary, ExprCall, ExprGroup, ExprLit,
15 ExprParen, ExprPath, ExprUnary, Lit, Token, UnOp,
16};
17use util::{check_arg_num, str_to_bigfloat_expr};
18
/// Initial error-bit estimate reserved for functions whose added error is re-estimated at run
/// time (`ln`, `log2`, `log10`, `log`, `pow`, `sin`, `cos`, `tan`, `acosh`, `atanh`; half of it
/// for `asin`/`acos`). It contributes to the working precision of the generated code.
const SPEC_ADD_ERR: usize = 32;
23
24struct MacroInput {
25 expr: Expr,
26 ctx: Expr,
27}
28
29impl Parse for MacroInput {
30 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
31 let expr = input.parse()?;
32 input.parse::<Token![,]>()?;
33
34 let ctx = input.parse()?;
35
36 Ok(MacroInput { expr, ctx })
37 }
38}
39
40fn traverse_binary(
41 expr: &ExprBinary,
42 err: &mut Vec<usize>,
43 cc: &mut Consts,
44) -> Result<TokenStream, Error> {
45 let left_expr = traverse_expr(&expr.left, err, cc)?;
46 let right_expr = traverse_expr(&expr.right, err, cc)?;
47
48 let errs_id = err.len();
49
50 let ts = match expr.op {
51 BinOp::Add(_) => {
52 err.push(2);
53 quote!({
54 let arg1 = #left_expr;
55 let arg2 = #right_expr;
56 let ret = astro_float::BigFloat::add(&arg1, &arg2, p_wrk, astro_float::RoundingMode::None);
57 if arg1.inexact() || arg2.inexact() {
58 if let (Some(e1), Some(e2), Some(e3)) = (arg1.exponent(), arg2.exponent(), ret.exponent()) {
59 if (e1 as isize - e2 as isize).abs() <= 1 && arg1.sign() != arg2.sign() {
60 let newerr = (e1.max(e2) as isize - e3 as isize).unsigned_abs() + 1;
61 if errs[#errs_id] < newerr {
62 errs[#errs_id] = newerr;
63 continue;
64 }
65 }
66 }
67 }
68 ret
69 })
70 }
71 BinOp::Sub(_) => {
72 err.push(2);
73 quote!({
74 let arg1 = #left_expr;
75 let arg2 = #right_expr;
76 let ret = astro_float::BigFloat::sub(&arg1, &arg2, p_wrk, astro_float::RoundingMode::None);
77 if arg1.inexact() || arg2.inexact() {
78 if let (Some(e1), Some(e2), Some(e3)) = (arg1.exponent(), arg2.exponent(), ret.exponent()) {
79 if (e1 as isize - e2 as isize).abs() <= 1 && arg1.sign() == arg2.sign() {
80 let newerr = (e1.max(e2) as isize - e3 as isize).unsigned_abs() + 1;
81 if errs[#errs_id] < newerr {
82 errs[#errs_id] = newerr;
83 continue;
84 }
85 }
86 }
87 }
88 ret
89 })
90 }
91 BinOp::Mul(_) => {
92 err.push(3);
93 quote!(
94 astro_float::BigFloat::mul(&(#left_expr), &(#right_expr), p_wrk, astro_float::RoundingMode::None))
95 }
96 BinOp::Div(_) => {
97 err.push(3);
98 quote!(astro_float::BigFloat::div(&(#left_expr), &(#right_expr), p_wrk, astro_float::RoundingMode::None))
99 }
100 BinOp::Rem(_) => {
101 quote!(astro_float::BigFloat::rem(&(#left_expr), &(#right_expr)))
102 }
103 _ => return Err(Error::new(
104 expr.span(),
105 "unexpected binary operator. Only \"+\", \"-\", \"*\", \"/\", and \"%\" are allowed.",
106 )),
107 };
108
109 Ok(ts)
110}
111
112fn one_arg_fun(
113 fun: TokenStream,
114 expr: &ExprCall,
115 initial_err: usize,
116 err: &mut Vec<usize>,
117 cc: &mut Consts,
118 use_cc: bool,
119) -> Result<TokenStream, Error> {
120 check_arg_num(1, expr)?;
121
122 let arg = traverse_expr(&expr.args[0], err, cc)?;
123 err.push(initial_err);
124
125 let ret = if use_cc {
126 quote!(#fun(&(#arg), p_wrk, astro_float::RoundingMode::None, cc))
127 } else {
128 quote!(#fun(&(#arg), p_wrk, astro_float::RoundingMode::None))
129 };
130
131 Ok(ret)
132}
133
134fn one_arg_fun_errcheck(
135 fun: TokenStream,
136 expr: &ExprCall,
137 initial_err: usize,
138 err: &mut Vec<usize>,
139 errcheck: TokenStream,
140 cc: &mut Consts,
141) -> Result<TokenStream, Error> {
142 check_arg_num(1, expr)?;
143
144 let arg = traverse_expr(&expr.args[0], err, cc)?;
145 let errs_id = err.len();
146 err.push(initial_err);
147
148 Ok(quote!({
149 let arg = #arg;
150
151 let newerr = astro_float::macro_util::compute_added_err(#errcheck);
152 if errs[#errs_id] < newerr {
153 errs[#errs_id] = newerr;
154 continue;
155 }
156
157 #fun(&arg, p_wrk, astro_float::RoundingMode::None, cc)
158 }))
159}
160
161fn trig_fun(
162 fun: TokenStream,
163 expr: &ExprCall,
164 initial_err: usize,
165 err: &mut Vec<usize>,
166 errfun: TokenStream,
167 cc: &mut Consts,
168) -> Result<TokenStream, Error> {
169 check_arg_num(1, expr)?;
170
171 let arg = traverse_expr(&expr.args[0], err, cc)?;
172 let errs_id = err.len();
173 err.push(initial_err);
174
175 Ok(quote!({
176 let arg = astro_float::macro_util::check_exponent_range(#arg, emin, emax);
177
178 let newerr = astro_float::macro_util::compute_added_err(astro_float::macro_util::ErrAlgo::Trig(&arg, p_wrk, #errfun, cc, emin));
179 if errs[#errs_id] < newerr {
180 errs[#errs_id] = newerr;
181 continue;
182 }
183
184 #fun(&arg, p_wrk, astro_float::RoundingMode::None, cc)
185 }))
186}
187
188fn two_arg_fun_errcheck(
189 fun: TokenStream,
190 expr: &ExprCall,
191 initial_err: usize,
192 err: &mut Vec<usize>,
193 errcheck: TokenStream,
194 cc: &mut Consts,
195) -> Result<TokenStream, Error> {
196 check_arg_num(2, expr)?;
197
198 let arg1 = traverse_expr(&expr.args[0], err, cc)?;
199 let arg2 = traverse_expr(&expr.args[1], err, cc)?;
200
201 let errs_id = err.len();
202
203 err.push(initial_err);
204
205 Ok(quote!({
206 let arg1 = #arg1;
207 let arg2 = #arg2;
208
209 let newerr = astro_float::macro_util::compute_added_err(#errcheck);
210 if errs[#errs_id] < newerr {
211 errs[#errs_id] = newerr;
212 continue;
213 }
214
215 #fun(&arg1, &arg2, p_wrk, astro_float::RoundingMode::None, cc)
216 }))
217}
218
219fn traverse_call(
220 expr: &ExprCall,
221 err: &mut Vec<usize>,
222 cc: &mut Consts,
223) -> Result<TokenStream, Error> {
224 let errmes = "unexpected function name. Only \"recip\", \"sqrt\", \"cbrt\", \"ln\", \"log2\", \"log10\", \"log\", \"exp\", \"pow\", \"sin\", \"cos\", \"tan\", \"asin\", \"acos\", \"atan\", \"sinh\", \"cosh\", \"tanh\", \"asinh\", \"acosh\", \"atanh\" are allowed.";
225
226 if let Expr::Path(fun) = expr.func.as_ref() {
227 if let Some(fname) = fun.path.get_ident() {
228 let ts = match fname.to_string().as_str() {
229 "recip" => one_arg_fun(
230 quote!(astro_float::BigFloat::reciprocal),
231 expr,
232 2,
233 err,
234 cc,
235 false,
236 ),
237 "sqrt" => one_arg_fun(quote!(astro_float::BigFloat::sqrt), expr, 1, err, cc, false),
238 "cbrt" => one_arg_fun(quote!(astro_float::BigFloat::cbrt), expr, 1, err, cc, false),
239 "ln" => one_arg_fun_errcheck(
240 quote!(astro_float::BigFloat::ln),
241 expr,
242 SPEC_ADD_ERR,
243 err,
244 quote!(astro_float::macro_util::ErrAlgo::Log(&arg, 2, emin)),
245 cc,
246 ),
247 "log2" => one_arg_fun_errcheck(
248 quote!(astro_float::BigFloat::log2),
249 expr,
250 SPEC_ADD_ERR,
251 err,
252 quote!(astro_float::macro_util::ErrAlgo::Log(&arg, 3, emin)),
253 cc,
254 ),
255 "log10" => one_arg_fun_errcheck(
256 quote!(astro_float::BigFloat::log10),
257 expr,
258 SPEC_ADD_ERR,
259 err,
260 quote!(astro_float::macro_util::ErrAlgo::Log(&arg, 6, emin)),
261 cc,
262 ),
263 "log" => two_arg_fun_errcheck(
264 quote!(astro_float::BigFloat::log),
265 expr,
266 SPEC_ADD_ERR,
267 err,
268 quote!(astro_float::macro_util::ErrAlgo::Log2(&arg2, &arg1, emin)),
269 cc,
270 ),
271 "exp" => one_arg_fun(
272 quote!(astro_float::BigFloat::exp),
273 expr,
274 EXPONENT_BIT_SIZE + 1,
275 err,
276 cc,
277 true,
278 ),
279 "pow" => two_arg_fun_errcheck(
280 quote!(astro_float::BigFloat::pow),
281 expr,
282 EXPONENT_BIT_SIZE + SPEC_ADD_ERR,
283 err,
284 quote!(astro_float::macro_util::ErrAlgo::Pow(&arg1, &arg2, emin)),
285 cc,
286 ),
287 "sin" => trig_fun(
288 quote!(astro_float::BigFloat::sin),
289 expr,
290 SPEC_ADD_ERR,
291 err,
292 quote!(astro_float::macro_util::TrigFun::Sin),
293 cc,
294 ),
295 "cos" => trig_fun(
296 quote!(astro_float::BigFloat::cos),
297 expr,
298 SPEC_ADD_ERR,
299 err,
300 quote!(astro_float::macro_util::TrigFun::Cos),
301 cc,
302 ),
303 "tan" => trig_fun(
304 quote!(astro_float::BigFloat::tan),
305 expr,
306 SPEC_ADD_ERR,
307 err,
308 quote!(astro_float::macro_util::TrigFun::Tan),
309 cc,
310 ),
311 "asin" => one_arg_fun_errcheck(
312 quote!(astro_float::BigFloat::asin),
313 expr,
314 SPEC_ADD_ERR / 2,
315 err,
316 quote!(astro_float::macro_util::ErrAlgo::Asin(&arg, emin)),
317 cc,
318 ),
319 "acos" => one_arg_fun_errcheck(
320 quote!(astro_float::BigFloat::acos),
321 expr,
322 SPEC_ADD_ERR / 2,
323 err,
324 quote!(astro_float::macro_util::ErrAlgo::Acos(&arg, emin)),
325 cc,
326 ),
327 "atan" => one_arg_fun(quote!(astro_float::BigFloat::atan), expr, 2, err, cc, true),
328 "sinh" => one_arg_fun(
329 quote!(astro_float::BigFloat::sinh),
330 expr,
331 EXPONENT_BIT_SIZE + 1,
332 err,
333 cc,
334 true,
335 ),
336 "cosh" => one_arg_fun(
337 quote!(astro_float::BigFloat::cosh),
338 expr,
339 EXPONENT_BIT_SIZE + 1,
340 err,
341 cc,
342 true,
343 ),
344 "tanh" => one_arg_fun(quote!(astro_float::BigFloat::tanh), expr, 2, err, cc, true),
345 "asinh" => {
346 one_arg_fun(quote!(astro_float::BigFloat::asinh), expr, 2, err, cc, true)
347 }
348 "acosh" => one_arg_fun_errcheck(
349 quote!(astro_float::BigFloat::acosh),
350 expr,
351 SPEC_ADD_ERR,
352 err,
353 quote!(astro_float::macro_util::ErrAlgo::Acosh(&arg, emin)),
354 cc,
355 ),
356 "atanh" => one_arg_fun_errcheck(
357 quote!(astro_float::BigFloat::atanh),
358 expr,
359 SPEC_ADD_ERR,
360 err,
361 quote!(astro_float::macro_util::ErrAlgo::Atanh(&arg, emin)),
362 cc,
363 ),
364 _ => return Err(Error::new(expr.span(), errmes)),
365 }?;
366
367 return Ok(ts);
368 }
369 }
370 Err(Error::new(expr.span(), errmes))
371}
372
373fn traverse_group(
374 expr: &ExprGroup,
375 err: &mut Vec<usize>,
376 cc: &mut Consts,
377) -> Result<TokenStream, Error> {
378 traverse_expr(&expr.expr, err, cc)
379}
380
381fn traverse_lit(expr: &ExprLit, cc: &mut Consts) -> Result<TokenStream, Error> {
382 let span = expr.span();
383
384 match &expr.lit {
385 Lit::Str(v) => str_to_bigfloat_expr(&v.value(), span, cc),
386 Lit::Int(v) => str_to_bigfloat_expr(v.base10_digits(), span, cc),
387 Lit::Float(v) => str_to_bigfloat_expr(v.base10_digits(), span, cc),
388 _ => Err(Error::new(
389 expr.span(),
390 "unexpected literal. Only string, integer, or floating point literals are supported.",
391 )),
392 }
393}
394
395fn traverse_paren(
396 expr: &ExprParen,
397 err: &mut Vec<usize>,
398 cc: &mut Consts,
399) -> Result<TokenStream, Error> {
400 traverse_expr(&expr.expr, err, cc)
401}
402
403fn traverse_path(expr: &ExprPath) -> Result<TokenStream, Error> {
404 Ok(if expr.path.is_ident("pi") {
405 quote!({ cc.pi(p_wrk, astro_float::RoundingMode::None) })
406 } else if expr.path.is_ident("e") {
407 quote!({ cc.e(p_wrk, astro_float::RoundingMode::None) })
408 } else if expr.path.is_ident("ln_2") {
409 quote!({ cc.ln_2(p_wrk, astro_float::RoundingMode::None) })
410 } else if expr.path.is_ident("ln_10") {
411 quote!({ cc.ln_10(p_wrk, astro_float::RoundingMode::None) })
412 } else {
413 quote!({
414 let mut arg = astro_float::BigFloat::from_ext((#expr).clone(), p_wrk, astro_float::RoundingMode::ToEven, cc);
415 arg.set_inexact(false);
416 arg = astro_float::macro_util::check_exponent_range(arg, emin, emax);
417 arg
418 })
419 })
420}
421
422fn traverse_unary(
423 expr: &ExprUnary,
424 err: &mut Vec<usize>,
425 cc: &mut Consts,
426) -> Result<TokenStream, Error> {
427 let op_expr = traverse_expr(&expr.expr, err, cc)?;
428
429 match expr.op {
430 UnOp::Neg(_) => Ok(quote!(astro_float::BigFloat::neg(&(#op_expr)))),
431 _ => Err(Error::new(
432 expr.span(),
433 "unexpected unary operator. Only \"-\" is allowed.",
434 )),
435 }
436}
437
438fn traverse_expr(expr: &Expr, err: &mut Vec<usize>, cc: &mut Consts) -> Result<TokenStream, Error> {
439 match expr {
440 Expr::Binary(e) => traverse_binary(e, err,cc),
441 Expr::Call(e) => traverse_call(e, err,cc),
442 Expr::Group(e) => traverse_group(e, err,cc),
443 Expr::Lit(e) => traverse_lit(e, cc),
444 Expr::Paren(e) => traverse_paren(e, err, cc),
445 Expr::Path(e) => traverse_path(e),
446 Expr::Unary(e) => traverse_unary(e, err, cc),
447 _ => Err(Error::new(expr.span(), "unexpected expression. Only operators \"+\", \"-\", \"*\", \"/\", \"%\", functions \"recip\", \"sqrt\", \"cbrt\", \"ln\", \"log2\", \"log10\", \"log\", \"exp\", \"pow\", \"sin\", \"cos\", \"tan\", \"asin\", \"acos\", \"atan\", \"sinh\", \"cosh\", \"tanh\", \"asinh\", \"acosh\", \"atanh\", literals and variables, and grouping with parentheses are supported.")),
448 }
449}
450
451#[proc_macro]
454#[allow(missing_docs)]
455pub fn expr(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
456 let pmi = syn::parse_macro_input!(input as MacroInput);
457
458 let MacroInput { expr, ctx } = pmi;
459
460 let mut err = Vec::new();
461
462 let mut cc = Consts::new().expect("Failed to initialize constant cache.");
463
464 let expr = traverse_expr(&expr, &mut err, &mut cc).unwrap_or_else(|e| e.to_compile_error());
465
466 let err_sz = err.len();
467
468 let ret = quote!({
469 use astro_float::FromExt;
470 use astro_float::ctx::Contextable;
471
472 let mut ctx = &mut (#ctx);
473 let p: usize = ctx.precision();
474 let rm = ctx.rounding_mode();
475 let emin = ctx.emin();
476 let emax = ctx.emax();
477 let cc = ctx.consts();
478
479 let mut p_rnd = p + astro_float::WORD_BIT_SIZE;
480 let mut errs: [usize; #err_sz] = [#(#err, )*];
481
482 loop {
483 let p_wrk = p_rnd.saturating_add(errs.iter().sum());
484
485 let mut ret: astro_float::BigFloat = (#expr).into();
486
487 if let Err(err) = ret.set_precision(p, rm) {
488 ret = astro_float::BigFloat::nan(Some(err));
489 }
490
491 break astro_float::macro_util::check_exponent_range(ret, emin, emax);
492 }
493 });
494
495 ret.into()
496}