computable_real/
prim.rs

1use super::{Approx, Real};
2use num::bigint::Sign;
3use num::{BigInt, One, Signed, ToPrimitive};
4use std::mem::discriminant;
5
6#[derive(Clone, Debug)]
7pub enum Primitive {
8    Int(BigInt),
9    Shift(Real, i32),
10    Add(Real, Real),
11    Neg(Real),
12    Mul(Real, Real),
13    Square(Real),
14    Inv(Real),
15    Atan(BigInt),
16    Exp(Real),
17    Cos(Real),
18    Ln(Real),
19    Asin(Real),
20    Sqrt(Real),
21}
22
23impl Primitive {
24    pub fn same_type(&self, other: &Self) -> bool {
25        discriminant(self) == discriminant(other)
26    }
27
28    pub fn approximate(&mut self, precision: i32, mut outer: Real) -> Approx {
29        match self {
30            Primitive::Int(value) => Approx::new(value).scale_to(precision),
31            // If f(n) ~ x <-> |f(n)*2^n - x| < 2^n,
32            //  then g(n) = f(n-k) ~ x * 2^k
33            //  as |f(n-k)*2^(n-k) - x| < 2^(n-k)
34            //      => |f(n-k)*2^n - x*2^k| < 2^n
35            //      => |g(n)*2^n - x*2^k| < 2^n
36            Primitive::Shift(value, bits) => value.appr(precision - *bits).prec(precision),
37            Primitive::Add(lhs, rhs) => {
38                let lhs = lhs.appr(precision - 2);
39                let rhs = rhs.appr(precision - 2);
40                (lhs + rhs).scale_to(precision)
41            }
42            Primitive::Neg(value) => value.appr(precision) * -1,
43            Primitive::Mul(lhs, rhs) => {
44                let mut v1 = lhs;
45                let mut v2 = rhs;
46
47                let half = (precision >> 1) - 1;
48                let mut msd1 = v1.msd(half);
49
50                if msd1.is_none() {
51                    let msd2 = v2.msd(half);
52                    if msd2.is_none() {
53                        return Approx::new(&BigInt::ZERO).scale_to(precision);
54                    } else {
55                        msd1 = msd2;
56                        let tmp = v1;
57                        v1 = v2;
58                        v2 = tmp;
59                    }
60                }
61
62                let prec2 = precision - msd1.unwrap() - 3;
63                let appr2 = v2.appr(prec2);
64                if appr2.value.sign() == Sign::NoSign {
65                    return Approx::new(&BigInt::ZERO).scale_to(precision);
66                }
67                let msd2 = v2.known_msd();
68                let prec1 = precision - msd2.unwrap() - 3;
69                let appr1 = v1.appr(prec1);
70                (appr1 * appr2).scale_to(precision)
71            }
72            Primitive::Square(value) => {
73                let half = (precision >> 1) - 1;
74                let msd = value.msd(half);
75
76                if msd.is_none() {
77                    return Approx::new(&BigInt::ZERO).scale_to(precision);
78                }
79
80                let prec = precision - msd.unwrap() - 3;
81                let appr = value.appr(prec);
82                if appr.value.sign() == Sign::NoSign {
83                    return Approx::new(&BigInt::ZERO).scale_to(precision);
84                }
85                (appr.clone() * appr).scale_to(precision)
86            }
87            Primitive::Inv(value) => {
88                let msd = value.iter_msd().unwrap();
89                let inv = 1 - msd;
90                let digits = inv - precision + 3;
91                let prec = msd - digits;
92                let log_factor = -precision - prec;
93                if log_factor < 0 {
94                    Approx::new(&BigInt::ZERO).prec(precision)
95                } else {
96                    let mut divident = BigInt::one() << log_factor;
97                    let divisor = value.appr(prec).value;
98                    divident += divisor.abs() >> 1;
99                    let res = divident / divisor;
100                    Approx::new(&res).prec(precision)
101                }
102            }
103            Primitive::Atan(value) => {
104                if precision >= 1 {
105                    Approx::new(&BigInt::ZERO).scale_to(precision)
106                } else {
107                    let iterations = -precision / 2 + 2;
108                    let prec = precision - super::bound_log2(2 * iterations) - 2;
109                    let scaled_one = BigInt::one() << -prec;
110                    let sq = value.clone() * value.clone();
111                    let inv = scaled_one / value.clone();
112
113                    let mut pow = inv.clone();
114                    let mut term = inv.clone();
115                    let mut sum = term.clone();
116                    let mut sign = 1;
117                    let mut i = 1;
118
119                    let max_err = BigInt::one() << (precision - 2 - prec);
120                    while term.abs() >= max_err {
121                        i += 2;
122                        pow /= &sq;
123                        sign = -sign;
124                        term = &pow / (sign * i);
125                        sum += &term;
126                    }
127
128                    Approx::new(&sum).prec(prec).scale_to(precision)
129                }
130            }
131            Primitive::Exp(value) => {
132                if precision >= 1 {
133                    Approx::new(&BigInt::ZERO).scale_to(precision)
134                } else {
135                    let iter = -precision / 2 + 2;
136                    let calc_prec = precision - super::bound_log2(2 * iter) - 4;
137                    let op_prec = precision - 3;
138                    let value = value.appr(op_prec).value;
139
140                    let mut term = BigInt::one() << -calc_prec;
141                    let mut sum = term.clone();
142                    let mut i = 0;
143                    let max_err = BigInt::one() << (precision - 4 - calc_prec);
144
145                    while term.abs() >= max_err {
146                        i += 1;
147                        term = super::scale(term * &value, op_prec);
148                        term /= i;
149                        sum += &term;
150                    }
151
152                    Approx::new(&sum).prec(calc_prec).scale_to(precision)
153                }
154            }
155            Primitive::Cos(value) => {
156                if precision >= 1 {
157                    Approx::new(&BigInt::ZERO).scale_to(precision)
158                } else {
159                    let iter = -precision / 2 + 4;
160                    let calc_prec = precision - super::bound_log2(2 * iter) - 4;
161                    let op_prec = precision - 2;
162                    let value = value.appr(op_prec).value;
163
164                    let mut term = BigInt::one() << -calc_prec;
165                    let mut sum = term.clone();
166                    let mut i = 0;
167                    let max_err = BigInt::one() << (precision - 4 - calc_prec);
168
169                    while term.abs() >= max_err {
170                        i += 2;
171                        term = super::scale(term * &value, op_prec);
172                        term = super::scale(term * &value, op_prec);
173                        term /= -i * (i - 1);
174                        sum += &term;
175                    }
176
177                    Approx::new(&sum).prec(calc_prec).scale_to(precision)
178                }
179            }
180            Primitive::Ln(value) => {
181                if precision >= 0 {
182                    Approx::new(&BigInt::ZERO).scale_to(precision)
183                } else {
184                    let iter = -precision;
185                    let calc_prec = precision - super::bound_log2(2 * iter) - 4;
186                    let op_prec = precision - 3;
187                    let appr = value.appr(op_prec);
188
189                    let mut x_nth = appr.scale_to(calc_prec).value;
190                    let mut term = x_nth.clone();
191                    let mut sum = term.clone();
192                    let mut i = 1;
193                    let mut sign = 1;
194                    let max_err = BigInt::one() << (precision - 4 - calc_prec);
195
196                    while term.abs() >= max_err {
197                        i += 1;
198                        sign = -sign;
199                        x_nth = super::scale(x_nth * appr.value.clone(), op_prec);
200                        term = &x_nth / i * sign;
201                        sum += &term;
202                    }
203
204                    Approx::new(&sum).prec(calc_prec).scale_to(precision)
205                }
206            }
207            Primitive::Asin(value) => {
208                if precision >= 2 {
209                    Approx::new(&BigInt::ZERO).scale_to(precision)
210                } else {
211                    let iter = -3 * precision / 2 + 4;
212                    let calc_prec = precision - super::bound_log2(2 * iter) - 4;
213                    let op_prec = precision - 3;
214                    let value = value.appr(op_prec).value;
215
216                    let mut term = &value << (op_prec - calc_prec);
217                    let mut sum = term.clone();
218                    let mut factor = term.clone();
219                    let mut exp = 1;
220                    let max_err = BigInt::one() << (precision - 4 - calc_prec);
221
222                    while term.abs() >= max_err {
223                        exp += 2;
224                        factor *= exp - 2;
225                        factor = super::scale(factor * &value, op_prec + 2);
226                        factor *= &value;
227                        factor /= exp - 1;
228                        factor = super::scale(factor, op_prec - 2);
229                        term = &factor / exp;
230                        sum += &term;
231                    }
232
233                    Approx::new(&sum).prec(calc_prec).scale_to(precision)
234                }
235            }
236            Primitive::Sqrt(value) => {
237                const FP_PREC: i32 = 50;
238                const FP_OP_PREC: i32 = 60;
239
240                let op_prec = 2 * precision - 1;
241                let msd = value.iter_msd_n(op_prec);
242                let zero = Approx::new(&BigInt::ZERO).scale_to(precision);
243                match msd {
244                    None => zero,
245                    Some(msd) if msd <= op_prec => zero,
246                    Some(msd) => {
247                        let msd = msd / 2;
248                        let digits = msd - precision;
249
250                        if digits > FP_PREC {
251                            let appr_digits = digits / 2 + 6;
252                            let appr_prec = msd - appr_digits;
253                            let prod_prec = 2 * appr_prec;
254
255                            let op_appr = value.appr(prod_prec).value;
256                            let last_appr = outer.appr(appr_prec).value;
257
258                            let mut numerator = &last_appr * &last_appr + op_appr;
259                            numerator = super::scale(numerator, appr_prec - precision);
260                            let res = (numerator / last_appr + BigInt::one()) >> 1;
261                            Approx::new(&res).prec(precision)
262                        } else {
263                            let op_prec = (msd - FP_OP_PREC) & !1;
264                            let prec = op_prec - FP_OP_PREC;
265                            let appr = (value.appr(op_prec).value << FP_OP_PREC).to_f64().unwrap();
266                            let sqrt: BigInt = (appr.sqrt() as i64).into();
267                            Approx::new(&sqrt).prec(prec / 2).scale_to(precision)
268                        }
269                    }
270                }
271            }
272        }
273    }
274}