astro_float_macro/
lib.rs

1//! Macros for multiple precision floating point numbers library `astro-float`.
2//!
3//! See main crate [docs](https://docs.rs/astro-float/latest/astro_float/).
4
5#![deny(missing_docs)]
6#![deny(clippy::suspicious)]
7
8mod util;
9
10use astro_float_num::{Consts, EXPONENT_BIT_SIZE};
11use proc_macro2::TokenStream;
12use quote::quote;
13use syn::{
14    parse::Parse, spanned::Spanned, BinOp, Error, Expr, ExprBinary, ExprCall, ExprGroup, ExprLit,
15    ExprParen, ExprPath, ExprUnary, Lit, Token, UnOp,
16};
17use util::{check_arg_num, str_to_bigfloat_expr};
18
// Speculative error estimate.
// It is used as the initial value of an error slot, before the actual error is known.
// Every slot is summed into the working precision of the generated code, so starting
// with a generous value helps to avoid recalculations caused by the estimate growing later.
const SPEC_ADD_ERR: usize = 32;
23
/// Parsed arguments of the macro: `<expression>, <context>`.
struct MacroInput {
    /// The expression to translate (everything before the comma).
    expr: Expr,
    /// The expression after the comma. The generated code borrows it mutably and reads
    /// the precision, rounding mode, exponent range and constants cache from it.
    ctx: Expr,
}
28
29impl Parse for MacroInput {
30    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
31        let expr = input.parse()?;
32        input.parse::<Token![,]>()?;
33
34        let ctx = input.parse()?;
35
36        Ok(MacroInput { expr, ctx })
37    }
38}
39
40fn traverse_binary(
41    expr: &ExprBinary,
42    err: &mut Vec<usize>,
43    cc: &mut Consts,
44) -> Result<TokenStream, Error> {
45    let left_expr = traverse_expr(&expr.left, err, cc)?;
46    let right_expr = traverse_expr(&expr.right, err, cc)?;
47
48    let errs_id = err.len();
49
50    let ts = match expr.op {
51        BinOp::Add(_) => {
52            err.push(2);
53            quote!({
54                let arg1 = #left_expr;
55                let arg2 = #right_expr;
56                let ret = astro_float::BigFloat::add(&arg1, &arg2, p_wrk, astro_float::RoundingMode::None);
57                if arg1.inexact() || arg2.inexact() {
58                    if let (Some(e1), Some(e2), Some(e3)) = (arg1.exponent(), arg2.exponent(), ret.exponent()) {
59                        if (e1 as isize - e2 as isize).abs() <= 1 && arg1.sign() != arg2.sign() {
60                            let newerr = (e1.max(e2) as isize - e3 as isize).unsigned_abs() + 1;
61                            if errs[#errs_id] < newerr {
62                                errs[#errs_id] = newerr;
63                                continue;
64                            }
65                        }
66                    }
67                }
68                ret
69            })
70        }
71        BinOp::Sub(_) => {
72            err.push(2);
73            quote!({
74                let arg1 = #left_expr;
75                let arg2 = #right_expr;
76                let ret = astro_float::BigFloat::sub(&arg1, &arg2, p_wrk, astro_float::RoundingMode::None);
77                if arg1.inexact() || arg2.inexact() {
78                    if let (Some(e1), Some(e2), Some(e3)) = (arg1.exponent(), arg2.exponent(), ret.exponent()) {
79                        if (e1 as isize - e2 as isize).abs() <= 1 && arg1.sign() == arg2.sign() {
80                            let newerr = (e1.max(e2) as isize - e3 as isize).unsigned_abs() + 1;
81                            if errs[#errs_id] < newerr {
82                                errs[#errs_id] = newerr;
83                                continue;
84                            }
85                        }
86                    }
87                }
88                ret
89            })
90        }
91        BinOp::Mul(_) => {
92            err.push(3);
93            quote!(
94                astro_float::BigFloat::mul(&(#left_expr), &(#right_expr), p_wrk, astro_float::RoundingMode::None))
95        }
96        BinOp::Div(_) => {
97            err.push(3);
98            quote!(astro_float::BigFloat::div(&(#left_expr), &(#right_expr), p_wrk, astro_float::RoundingMode::None))
99        }
100        BinOp::Rem(_) => {
101            quote!(astro_float::BigFloat::rem(&(#left_expr), &(#right_expr)))
102        }
103        _ => return Err(Error::new(
104            expr.span(),
105            "unexpected binary operator. Only \"+\", \"-\", \"*\", \"/\", and \"%\" are allowed.",
106        )),
107    };
108
109    Ok(ts)
110}
111
112fn one_arg_fun(
113    fun: TokenStream,
114    expr: &ExprCall,
115    initial_err: usize,
116    err: &mut Vec<usize>,
117    cc: &mut Consts,
118    use_cc: bool,
119) -> Result<TokenStream, Error> {
120    check_arg_num(1, expr)?;
121
122    let arg = traverse_expr(&expr.args[0], err, cc)?;
123    err.push(initial_err);
124
125    let ret = if use_cc {
126        quote!(#fun(&(#arg), p_wrk, astro_float::RoundingMode::None, cc))
127    } else {
128        quote!(#fun(&(#arg), p_wrk, astro_float::RoundingMode::None))
129    };
130
131    Ok(ret)
132}
133
134fn one_arg_fun_errcheck(
135    fun: TokenStream,
136    expr: &ExprCall,
137    initial_err: usize,
138    err: &mut Vec<usize>,
139    errcheck: TokenStream,
140    cc: &mut Consts,
141) -> Result<TokenStream, Error> {
142    check_arg_num(1, expr)?;
143
144    let arg = traverse_expr(&expr.args[0], err, cc)?;
145    let errs_id = err.len();
146    err.push(initial_err);
147
148    Ok(quote!({
149        let arg = #arg;
150
151        let newerr = astro_float::macro_util::compute_added_err(#errcheck);
152        if errs[#errs_id] < newerr {
153            errs[#errs_id] = newerr;
154            continue;
155        }
156
157        #fun(&arg, p_wrk, astro_float::RoundingMode::None, cc)
158    }))
159}
160
161fn trig_fun(
162    fun: TokenStream,
163    expr: &ExprCall,
164    initial_err: usize,
165    err: &mut Vec<usize>,
166    errfun: TokenStream,
167    cc: &mut Consts,
168) -> Result<TokenStream, Error> {
169    check_arg_num(1, expr)?;
170
171    let arg = traverse_expr(&expr.args[0], err, cc)?;
172    let errs_id = err.len();
173    err.push(initial_err);
174
175    Ok(quote!({
176        let arg = astro_float::macro_util::check_exponent_range(#arg, emin, emax);
177
178        let newerr = astro_float::macro_util::compute_added_err(astro_float::macro_util::ErrAlgo::Trig(&arg, p_wrk, #errfun, cc, emin));
179        if errs[#errs_id] < newerr {
180            errs[#errs_id] = newerr;
181            continue;
182        }
183
184        #fun(&arg, p_wrk, astro_float::RoundingMode::None, cc)
185    }))
186}
187
188fn two_arg_fun_errcheck(
189    fun: TokenStream,
190    expr: &ExprCall,
191    initial_err: usize,
192    err: &mut Vec<usize>,
193    errcheck: TokenStream,
194    cc: &mut Consts,
195) -> Result<TokenStream, Error> {
196    check_arg_num(2, expr)?;
197
198    let arg1 = traverse_expr(&expr.args[0], err, cc)?;
199    let arg2 = traverse_expr(&expr.args[1], err, cc)?;
200
201    let errs_id = err.len();
202
203    err.push(initial_err);
204
205    Ok(quote!({
206        let arg1 = #arg1;
207        let arg2 = #arg2;
208
209        let newerr = astro_float::macro_util::compute_added_err(#errcheck);
210        if errs[#errs_id] < newerr {
211            errs[#errs_id] = newerr;
212            continue;
213        }
214
215        #fun(&arg1, &arg2, p_wrk, astro_float::RoundingMode::None, cc)
216    }))
217}
218
219fn traverse_call(
220    expr: &ExprCall,
221    err: &mut Vec<usize>,
222    cc: &mut Consts,
223) -> Result<TokenStream, Error> {
224    let errmes = "unexpected function name. Only \"recip\", \"sqrt\", \"cbrt\", \"ln\", \"log2\", \"log10\", \"log\", \"exp\", \"pow\", \"sin\", \"cos\", \"tan\", \"asin\", \"acos\", \"atan\", \"sinh\", \"cosh\", \"tanh\", \"asinh\", \"acosh\", \"atanh\" are allowed.";
225
226    if let Expr::Path(fun) = expr.func.as_ref() {
227        if let Some(fname) = fun.path.get_ident() {
228            let ts = match fname.to_string().as_str() {
229                "recip" => one_arg_fun(
230                    quote!(astro_float::BigFloat::reciprocal),
231                    expr,
232                    2,
233                    err,
234                    cc,
235                    false,
236                ),
237                "sqrt" => one_arg_fun(quote!(astro_float::BigFloat::sqrt), expr, 1, err, cc, false),
238                "cbrt" => one_arg_fun(quote!(astro_float::BigFloat::cbrt), expr, 1, err, cc, false),
239                "ln" => one_arg_fun_errcheck(
240                    quote!(astro_float::BigFloat::ln),
241                    expr,
242                    SPEC_ADD_ERR,
243                    err,
244                    quote!(astro_float::macro_util::ErrAlgo::Log(&arg, 2, emin)),
245                    cc,
246                ),
247                "log2" => one_arg_fun_errcheck(
248                    quote!(astro_float::BigFloat::log2),
249                    expr,
250                    SPEC_ADD_ERR,
251                    err,
252                    quote!(astro_float::macro_util::ErrAlgo::Log(&arg, 3, emin)),
253                    cc,
254                ),
255                "log10" => one_arg_fun_errcheck(
256                    quote!(astro_float::BigFloat::log10),
257                    expr,
258                    SPEC_ADD_ERR,
259                    err,
260                    quote!(astro_float::macro_util::ErrAlgo::Log(&arg, 6, emin)),
261                    cc,
262                ),
263                "log" => two_arg_fun_errcheck(
264                    quote!(astro_float::BigFloat::log),
265                    expr,
266                    SPEC_ADD_ERR,
267                    err,
268                    quote!(astro_float::macro_util::ErrAlgo::Log2(&arg2, &arg1, emin)),
269                    cc,
270                ),
271                "exp" => one_arg_fun(
272                    quote!(astro_float::BigFloat::exp),
273                    expr,
274                    EXPONENT_BIT_SIZE + 1,
275                    err,
276                    cc,
277                    true,
278                ),
279                "pow" => two_arg_fun_errcheck(
280                    quote!(astro_float::BigFloat::pow),
281                    expr,
282                    EXPONENT_BIT_SIZE + SPEC_ADD_ERR,
283                    err,
284                    quote!(astro_float::macro_util::ErrAlgo::Pow(&arg1, &arg2, emin)),
285                    cc,
286                ),
287                "sin" => trig_fun(
288                    quote!(astro_float::BigFloat::sin),
289                    expr,
290                    SPEC_ADD_ERR,
291                    err,
292                    quote!(astro_float::macro_util::TrigFun::Sin),
293                    cc,
294                ),
295                "cos" => trig_fun(
296                    quote!(astro_float::BigFloat::cos),
297                    expr,
298                    SPEC_ADD_ERR,
299                    err,
300                    quote!(astro_float::macro_util::TrigFun::Cos),
301                    cc,
302                ),
303                "tan" => trig_fun(
304                    quote!(astro_float::BigFloat::tan),
305                    expr,
306                    SPEC_ADD_ERR,
307                    err,
308                    quote!(astro_float::macro_util::TrigFun::Tan),
309                    cc,
310                ),
311                "asin" => one_arg_fun_errcheck(
312                    quote!(astro_float::BigFloat::asin),
313                    expr,
314                    SPEC_ADD_ERR / 2,
315                    err,
316                    quote!(astro_float::macro_util::ErrAlgo::Asin(&arg, emin)),
317                    cc,
318                ),
319                "acos" => one_arg_fun_errcheck(
320                    quote!(astro_float::BigFloat::acos),
321                    expr,
322                    SPEC_ADD_ERR / 2,
323                    err,
324                    quote!(astro_float::macro_util::ErrAlgo::Acos(&arg, emin)),
325                    cc,
326                ),
327                "atan" => one_arg_fun(quote!(astro_float::BigFloat::atan), expr, 2, err, cc, true),
328                "sinh" => one_arg_fun(
329                    quote!(astro_float::BigFloat::sinh),
330                    expr,
331                    EXPONENT_BIT_SIZE + 1,
332                    err,
333                    cc,
334                    true,
335                ),
336                "cosh" => one_arg_fun(
337                    quote!(astro_float::BigFloat::cosh),
338                    expr,
339                    EXPONENT_BIT_SIZE + 1,
340                    err,
341                    cc,
342                    true,
343                ),
344                "tanh" => one_arg_fun(quote!(astro_float::BigFloat::tanh), expr, 2, err, cc, true),
345                "asinh" => {
346                    one_arg_fun(quote!(astro_float::BigFloat::asinh), expr, 2, err, cc, true)
347                }
348                "acosh" => one_arg_fun_errcheck(
349                    quote!(astro_float::BigFloat::acosh),
350                    expr,
351                    SPEC_ADD_ERR,
352                    err,
353                    quote!(astro_float::macro_util::ErrAlgo::Acosh(&arg, emin)),
354                    cc,
355                ),
356                "atanh" => one_arg_fun_errcheck(
357                    quote!(astro_float::BigFloat::atanh),
358                    expr,
359                    SPEC_ADD_ERR,
360                    err,
361                    quote!(astro_float::macro_util::ErrAlgo::Atanh(&arg, emin)),
362                    cc,
363                ),
364                _ => return Err(Error::new(expr.span(), errmes)),
365            }?;
366
367            return Ok(ts);
368        }
369    }
370    Err(Error::new(expr.span(), errmes))
371}
372
373fn traverse_group(
374    expr: &ExprGroup,
375    err: &mut Vec<usize>,
376    cc: &mut Consts,
377) -> Result<TokenStream, Error> {
378    traverse_expr(&expr.expr, err, cc)
379}
380
381fn traverse_lit(expr: &ExprLit, cc: &mut Consts) -> Result<TokenStream, Error> {
382    let span = expr.span();
383
384    match &expr.lit {
385        Lit::Str(v) => str_to_bigfloat_expr(&v.value(), span, cc),
386        Lit::Int(v) => str_to_bigfloat_expr(v.base10_digits(), span, cc),
387        Lit::Float(v) => str_to_bigfloat_expr(v.base10_digits(), span, cc),
388        _ => Err(Error::new(
389            expr.span(),
390            "unexpected literal. Only string, integer, or floating point literals are supported.",
391        )),
392    }
393}
394
395fn traverse_paren(
396    expr: &ExprParen,
397    err: &mut Vec<usize>,
398    cc: &mut Consts,
399) -> Result<TokenStream, Error> {
400    traverse_expr(&expr.expr, err, cc)
401}
402
403fn traverse_path(expr: &ExprPath) -> Result<TokenStream, Error> {
404    Ok(if expr.path.is_ident("pi") {
405        quote!({ cc.pi(p_wrk, astro_float::RoundingMode::None) })
406    } else if expr.path.is_ident("e") {
407        quote!({ cc.e(p_wrk, astro_float::RoundingMode::None) })
408    } else if expr.path.is_ident("ln_2") {
409        quote!({ cc.ln_2(p_wrk, astro_float::RoundingMode::None) })
410    } else if expr.path.is_ident("ln_10") {
411        quote!({ cc.ln_10(p_wrk, astro_float::RoundingMode::None) })
412    } else {
413        quote!({
414            let mut arg = astro_float::BigFloat::from_ext((#expr).clone(), p_wrk, astro_float::RoundingMode::ToEven, cc);
415            arg.set_inexact(false);
416            arg = astro_float::macro_util::check_exponent_range(arg, emin, emax);
417            arg
418        })
419    })
420}
421
422fn traverse_unary(
423    expr: &ExprUnary,
424    err: &mut Vec<usize>,
425    cc: &mut Consts,
426) -> Result<TokenStream, Error> {
427    let op_expr = traverse_expr(&expr.expr, err, cc)?;
428
429    match expr.op {
430        UnOp::Neg(_) => Ok(quote!(astro_float::BigFloat::neg(&(#op_expr)))),
431        _ => Err(Error::new(
432            expr.span(),
433            "unexpected unary operator. Only \"-\" is allowed.",
434        )),
435    }
436}
437
438fn traverse_expr(expr: &Expr, err: &mut Vec<usize>, cc: &mut Consts) -> Result<TokenStream, Error> {
439    match expr {
440        Expr::Binary(e) => traverse_binary(e, err,cc),
441        Expr::Call(e) => traverse_call(e, err,cc),
442        Expr::Group(e) => traverse_group(e, err,cc),
443        Expr::Lit(e) => traverse_lit(e, cc),
444        Expr::Paren(e) => traverse_paren(e, err, cc),
445        Expr::Path(e) => traverse_path(e),
446        Expr::Unary(e) => traverse_unary(e, err, cc),
447        _ => Err(Error::new(expr.span(), "unexpected expression. Only operators \"+\", \"-\", \"*\", \"/\", \"%\", functions \"recip\", \"sqrt\", \"cbrt\", \"ln\", \"log2\", \"log10\", \"log\", \"exp\", \"pow\", \"sin\", \"cos\", \"tan\", \"asin\", \"acos\", \"atan\", \"sinh\", \"cosh\", \"tanh\", \"asinh\", \"acosh\", \"atanh\", literals and variables, and grouping with parentheses are supported.")),
448    }
449}
450
451// Docs for the macro are in the astro-float crate.
452
453#[proc_macro]
454#[allow(missing_docs)]
455pub fn expr(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
456    let pmi = syn::parse_macro_input!(input as MacroInput);
457
458    let MacroInput { expr, ctx } = pmi;
459
460    let mut err = Vec::new();
461
462    let mut cc = Consts::new().expect("Failed to initialize constant cache.");
463
464    let expr = traverse_expr(&expr, &mut err, &mut cc).unwrap_or_else(|e| e.to_compile_error());
465
466    let err_sz = err.len();
467
468    let ret = quote!({
469        use astro_float::FromExt;
470        use astro_float::ctx::Contextable;
471
472        let mut ctx = &mut (#ctx);
473        let p: usize = ctx.precision();
474        let rm = ctx.rounding_mode();
475        let emin = ctx.emin();
476        let emax = ctx.emax();
477        let cc = ctx.consts();
478
479        let mut p_rnd = p + astro_float::WORD_BIT_SIZE;
480        let mut errs: [usize; #err_sz] = [#(#err, )*];
481
482        loop {
483            let p_wrk = p_rnd.saturating_add(errs.iter().sum());
484
485            let mut ret: astro_float::BigFloat = (#expr).into();
486
487            if let Err(err) = ret.set_precision(p, rm) {
488                ret = astro_float::BigFloat::nan(Some(err));
489            }
490
491            break astro_float::macro_util::check_exponent_range(ret, emin, emax);
492        }
493    });
494
495    ret.into()
496}