computable_real/
prim.rs

1use crate::{Approx, Comparator, Real};
2
3use num::bigint::Sign;
4use num::{BigInt, One, Signed, ToPrimitive};
5
/// Returns `ceil(log2(|n| + 1))`, i.e. the number of bits needed to write
/// `|n|` in binary (0 for `n == 0`).
///
/// The series approximations in `Primitive::approximate` subtract this
/// (applied to twice their iteration estimate) from the target precision to
/// reserve guard bits.
///
/// Computed by integer bit counting rather than `f64::log2`. This is exact,
/// and it cannot overflow: the former `n.abs() + 1` overflowed `i32` for
/// `n == i32::MIN` and `n == i32::MAX`.
fn bound_log2(n: i32) -> i32 {
    // `unsigned_abs` is total on i32 (i32::MIN -> 2^31), and a u32 has at
    // most 32 significant bits, so the cast back to i32 cannot truncate.
    (u32::BITS - n.unsigned_abs().leading_zeros()) as i32
}
9
10#[derive(Clone, Debug)]
11pub enum Primitive {
12    Int(BigInt),
13    Shift(Real, i32),
14    Add(Real, Real),
15    Neg(Real),
16    Mul(Real, Real),
17    Square(Real),
18    Inv(Real),
19    /// 1/atan(x), used to approximate pi
20    Atan(BigInt),
21    /// exp(x), where |x| < 1/2
22    Exp(Real),
23    /// cos(x), where |x| < 1
24    Cos(Real),
25    /// ln(x+1), where |x| < 1/2
26    Ln(Real),
27    /// asin(x), where |x| < (1/2)^(1/3)
28    Asin(Real),
29    Sqrt(Real),
30    Select(Real, Real, Real),
31}
32
33impl Primitive {
34    pub(crate) fn approximate(&mut self, precision: i32, mut outer: Real) -> Approx {
35        match self {
36            Primitive::Int(value) => Approx::new(value).scale_to(precision),
37            Primitive::Shift(value, bits) => value.appr(precision - *bits).prec(precision),
38            Primitive::Add(lhs, rhs) => {
39                let lhs = lhs.appr(precision - 2);
40                let rhs = rhs.appr(precision - 2);
41                (lhs + rhs).scale_to(precision)
42            }
43            Primitive::Neg(value) => value.appr(precision) * -1,
44            Primitive::Mul(lhs, rhs) => {
45                let mut v1 = lhs;
46                let mut v2 = rhs;
47
48                let half = (precision >> 1) - 1;
49                let mut msd1 = v1.msd(half);
50
51                if msd1.is_none() {
52                    let msd2 = v2.msd(half);
53                    if msd2.is_none() {
54                        return Approx::new(&BigInt::ZERO).scale_to(precision);
55                    } else {
56                        msd1 = msd2;
57                        std::mem::swap(&mut v1, &mut v2);
58                    }
59                }
60
61                let prec2 = precision - msd1.unwrap() - 3;
62                let appr2 = v2.appr(prec2);
63                if appr2.value.sign() == Sign::NoSign {
64                    return Approx::new(&BigInt::ZERO).scale_to(precision);
65                }
66                let msd2 = v2.known_msd();
67                let prec1 = precision - msd2.unwrap() - 3;
68                let appr1 = v1.appr(prec1);
69                (appr1 * appr2).scale_to(precision)
70            }
71            Primitive::Square(value) => {
72                let half = (precision >> 1) - 1;
73                let msd = value.msd(half);
74
75                if msd.is_none() {
76                    return Approx::new(&BigInt::ZERO).scale_to(precision);
77                }
78
79                let prec = precision - msd.unwrap() - 3;
80                let appr = value.appr(prec);
81                if appr.value.sign() == Sign::NoSign {
82                    return Approx::new(&BigInt::ZERO).scale_to(precision);
83                }
84                (appr.clone() * appr).scale_to(precision)
85            }
86            Primitive::Inv(value) => {
87                let msd = value.iter_msd().unwrap();
88                let inv = 1 - msd;
89                let digits = inv - precision + 3;
90                let prec = msd - digits;
91                let log_factor = -precision - prec;
92                if log_factor < 0 {
93                    Approx::new(&BigInt::ZERO).prec(precision)
94                } else {
95                    let mut divident = BigInt::one() << log_factor;
96                    let divisor = value.appr(prec).value;
97                    divident += divisor.abs() >> 1;
98                    let res = divident / divisor;
99                    Approx::new(&res).prec(precision)
100                }
101            }
102            Primitive::Atan(value) => {
103                if precision >= 1 {
104                    Approx::new(&BigInt::ZERO).scale_to(precision)
105                } else {
106                    let iterations = -precision / 2 + 2;
107                    let prec = precision - bound_log2(2 * iterations) - 2;
108                    let scaled_one = BigInt::one() << -prec;
109                    let sq = value.clone() * value.clone();
110                    let inv = scaled_one / value.clone();
111
112                    let mut pow = inv.clone();
113                    let mut term = inv.clone();
114                    let mut sum = term.clone();
115                    let mut sign = 1;
116                    let mut i = 1;
117
118                    let max_err = BigInt::one() << (precision - 2 - prec);
119                    while term.abs() >= max_err {
120                        i += 2;
121                        pow /= &sq;
122                        sign = -sign;
123                        term = &pow / (sign * i);
124                        sum += &term;
125                    }
126
127                    Approx::new(&sum).prec(prec).scale_to(precision)
128                }
129            }
130            Primitive::Exp(value) => {
131                if precision >= 1 {
132                    Approx::new(&BigInt::ZERO).scale_to(precision)
133                } else {
134                    let iter = -precision / 2 + 2;
135                    let calc_prec = precision - bound_log2(2 * iter) - 4;
136                    let op_prec = precision - 3;
137                    let value = value.appr(op_prec).value;
138
139                    let mut term = BigInt::one() << -calc_prec;
140                    let mut sum = term.clone();
141                    let mut i = 0;
142                    let max_err = BigInt::one() << (precision - 4 - calc_prec);
143
144                    while term.abs() >= max_err {
145                        i += 1;
146                        term = super::scale(term * &value, op_prec);
147                        term /= i;
148                        sum += &term;
149                    }
150
151                    Approx::new(&sum).prec(calc_prec).scale_to(precision)
152                }
153            }
154            Primitive::Cos(value) => {
155                if precision >= 1 {
156                    Approx::new(&BigInt::ZERO).scale_to(precision)
157                } else {
158                    let iter = -precision / 2 + 4;
159                    let calc_prec = precision - bound_log2(2 * iter) - 4;
160                    let op_prec = precision - 2;
161                    let value = value.appr(op_prec).value;
162
163                    let mut term = BigInt::one() << -calc_prec;
164                    let mut sum = term.clone();
165                    let mut i = 0;
166                    let max_err = BigInt::one() << (precision - 4 - calc_prec);
167
168                    while term.abs() >= max_err {
169                        i += 2;
170                        term = super::scale(term * &value, op_prec);
171                        term = super::scale(term * &value, op_prec);
172                        term /= -i * (i - 1);
173                        sum += &term;
174                    }
175
176                    Approx::new(&sum).prec(calc_prec).scale_to(precision)
177                }
178            }
179            Primitive::Ln(value) => {
180                if precision >= 0 {
181                    Approx::new(&BigInt::ZERO).scale_to(precision)
182                } else {
183                    let iter = -precision;
184                    let calc_prec = precision - bound_log2(2 * iter) - 4;
185                    let op_prec = precision - 3;
186                    let appr = value.appr(op_prec);
187
188                    let mut x_nth = appr.scale_to(calc_prec).value;
189                    let mut term = x_nth.clone();
190                    let mut sum = term.clone();
191                    let mut i = 1;
192                    let mut sign = 1;
193                    let max_err = BigInt::one() << (precision - 4 - calc_prec);
194
195                    while term.abs() >= max_err {
196                        i += 1;
197                        sign = -sign;
198                        x_nth = super::scale(x_nth * appr.value.clone(), op_prec);
199                        term = &x_nth / i * sign;
200                        sum += &term;
201                    }
202
203                    Approx::new(&sum).prec(calc_prec).scale_to(precision)
204                }
205            }
206            Primitive::Asin(value) => {
207                if precision >= 2 {
208                    Approx::new(&BigInt::ZERO).scale_to(precision)
209                } else {
210                    let iter = -3 * precision / 2 + 4;
211                    let calc_prec = precision - bound_log2(2 * iter) - 4;
212                    let op_prec = precision - 3;
213                    let value = value.appr(op_prec).value;
214
215                    let mut term = &value << (op_prec - calc_prec);
216                    let mut sum = term.clone();
217                    let mut factor = term.clone();
218                    let mut exp = 1;
219                    let max_err = BigInt::one() << (precision - 4 - calc_prec);
220
221                    while term.abs() >= max_err {
222                        exp += 2;
223                        factor *= exp - 2;
224                        factor = super::scale(factor * &value, op_prec + 2);
225                        factor *= &value;
226                        factor /= exp - 1;
227                        factor = super::scale(factor, op_prec - 2);
228                        term = &factor / exp;
229                        sum += &term;
230                    }
231
232                    Approx::new(&sum).prec(calc_prec).scale_to(precision)
233                }
234            }
235            Primitive::Sqrt(value) => {
236                const FP_PREC: i32 = 50;
237                const FP_OP_PREC: i32 = 60;
238
239                let op_prec = 2 * precision - 1;
240                let msd = value.iter_msd_n(op_prec);
241                let zero = Approx::new(&BigInt::ZERO).scale_to(precision);
242                match msd {
243                    None => zero,
244                    Some(msd) if msd <= op_prec => zero,
245                    Some(msd) => {
246                        let msd = msd / 2;
247                        let digits = msd - precision;
248
249                        if digits > FP_PREC {
250                            let appr_digits = digits / 2 + 6;
251                            let appr_prec = msd - appr_digits;
252                            let prod_prec = 2 * appr_prec;
253
254                            let op_appr = value.appr(prod_prec).value;
255                            let last_appr = outer.appr(appr_prec).value;
256
257                            let mut numerator = &last_appr * &last_appr + op_appr;
258                            numerator = super::scale(numerator, appr_prec - precision);
259                            let res = (numerator / last_appr + BigInt::one()) >> 1;
260                            Approx::new(&res).prec(precision)
261                        } else {
262                            let op_prec = (msd - FP_OP_PREC) & !1;
263                            let prec = op_prec - FP_OP_PREC;
264                            let appr = (value.appr(op_prec).value << FP_OP_PREC).to_f64().unwrap();
265                            let sqrt: BigInt = (appr.sqrt() as i64).into();
266                            Approx::new(&sqrt).prec(prec / 2).scale_to(precision)
267                        }
268                    }
269                }
270            }
271            Primitive::Select(cond, lhs, rhs) => {
272                let appr = cond.appr(-20).value;
273                if appr.is_negative() {
274                    return lhs.appr(precision);
275                } else if appr.is_positive() {
276                    return rhs.appr(precision);
277                }
278
279                let appr1 = lhs.appr(precision - 1);
280                let appr2 = rhs.appr(precision - 1);
281                if (&appr1.value - &appr2.value).abs() < BigInt::one() {
282                    return appr1.scale_to(precision);
283                }
284
285                match Comparator::new().signum(cond) {
286                    -1 => appr1.scale_to(precision),
287                    1 => appr2.scale_to(precision),
288                    _ => unreachable!(),
289                }
290            }
291        }
292    }
293}