1use crate::{Approx, Comparator, Real};
2
3use num::bigint::Sign;
4use num::{BigInt, One, Signed, ToPrimitive};
5
/// Returns `ceil(log2(|n| + 1))`, i.e. the number of bits in the binary
/// representation of `|n|` (0 for `n == 0`).
///
/// Used to bound the accumulated rounding error of a series with a given
/// number of terms.
fn bound_log2(n: i32) -> i32 {
    // `ceil(log2(|n| + 1))` equals the bit length of `|n|`. Computing it with
    // integer bit operations is exact and, unlike `n.abs() + 1`, cannot
    // overflow at `i32::MIN` or `i32::MAX`. The result is at most 32, so the
    // cast back to `i32` is lossless.
    (u32::BITS - n.unsigned_abs().leading_zeros()) as i32
}
9
10#[derive(Clone, Debug)]
11pub enum Primitive {
12 Int(BigInt),
13 Shift(Real, i32),
14 Add(Real, Real),
15 Neg(Real),
16 Mul(Real, Real),
17 Square(Real),
18 Inv(Real),
19 Atan(BigInt),
21 Exp(Real),
23 Cos(Real),
25 Ln(Real),
27 Asin(Real),
29 Sqrt(Real),
30 Select(Real, Real, Real),
31}
32
33impl Primitive {
34 pub(crate) fn approximate(&mut self, precision: i32, mut outer: Real) -> Approx {
35 match self {
36 Primitive::Int(value) => Approx::new(value).scale_to(precision),
37 Primitive::Shift(value, bits) => value.appr(precision - *bits).prec(precision),
38 Primitive::Add(lhs, rhs) => {
39 let lhs = lhs.appr(precision - 2);
40 let rhs = rhs.appr(precision - 2);
41 (lhs + rhs).scale_to(precision)
42 }
43 Primitive::Neg(value) => value.appr(precision) * -1,
44 Primitive::Mul(lhs, rhs) => {
45 let mut v1 = lhs;
46 let mut v2 = rhs;
47
48 let half = (precision >> 1) - 1;
49 let mut msd1 = v1.msd(half);
50
51 if msd1.is_none() {
52 let msd2 = v2.msd(half);
53 if msd2.is_none() {
54 return Approx::new(&BigInt::ZERO).scale_to(precision);
55 } else {
56 msd1 = msd2;
57 std::mem::swap(&mut v1, &mut v2);
58 }
59 }
60
61 let prec2 = precision - msd1.unwrap() - 3;
62 let appr2 = v2.appr(prec2);
63 if appr2.value.sign() == Sign::NoSign {
64 return Approx::new(&BigInt::ZERO).scale_to(precision);
65 }
66 let msd2 = v2.known_msd();
67 let prec1 = precision - msd2.unwrap() - 3;
68 let appr1 = v1.appr(prec1);
69 (appr1 * appr2).scale_to(precision)
70 }
71 Primitive::Square(value) => {
72 let half = (precision >> 1) - 1;
73 let msd = value.msd(half);
74
75 if msd.is_none() {
76 return Approx::new(&BigInt::ZERO).scale_to(precision);
77 }
78
79 let prec = precision - msd.unwrap() - 3;
80 let appr = value.appr(prec);
81 if appr.value.sign() == Sign::NoSign {
82 return Approx::new(&BigInt::ZERO).scale_to(precision);
83 }
84 (appr.clone() * appr).scale_to(precision)
85 }
86 Primitive::Inv(value) => {
87 let msd = value.iter_msd().unwrap();
88 let inv = 1 - msd;
89 let digits = inv - precision + 3;
90 let prec = msd - digits;
91 let log_factor = -precision - prec;
92 if log_factor < 0 {
93 Approx::new(&BigInt::ZERO).prec(precision)
94 } else {
95 let mut divident = BigInt::one() << log_factor;
96 let divisor = value.appr(prec).value;
97 divident += divisor.abs() >> 1;
98 let res = divident / divisor;
99 Approx::new(&res).prec(precision)
100 }
101 }
102 Primitive::Atan(value) => {
103 if precision >= 1 {
104 Approx::new(&BigInt::ZERO).scale_to(precision)
105 } else {
106 let iterations = -precision / 2 + 2;
107 let prec = precision - bound_log2(2 * iterations) - 2;
108 let scaled_one = BigInt::one() << -prec;
109 let sq = value.clone() * value.clone();
110 let inv = scaled_one / value.clone();
111
112 let mut pow = inv.clone();
113 let mut term = inv.clone();
114 let mut sum = term.clone();
115 let mut sign = 1;
116 let mut i = 1;
117
118 let max_err = BigInt::one() << (precision - 2 - prec);
119 while term.abs() >= max_err {
120 i += 2;
121 pow /= &sq;
122 sign = -sign;
123 term = &pow / (sign * i);
124 sum += &term;
125 }
126
127 Approx::new(&sum).prec(prec).scale_to(precision)
128 }
129 }
130 Primitive::Exp(value) => {
131 if precision >= 1 {
132 Approx::new(&BigInt::ZERO).scale_to(precision)
133 } else {
134 let iter = -precision / 2 + 2;
135 let calc_prec = precision - bound_log2(2 * iter) - 4;
136 let op_prec = precision - 3;
137 let value = value.appr(op_prec).value;
138
139 let mut term = BigInt::one() << -calc_prec;
140 let mut sum = term.clone();
141 let mut i = 0;
142 let max_err = BigInt::one() << (precision - 4 - calc_prec);
143
144 while term.abs() >= max_err {
145 i += 1;
146 term = super::scale(term * &value, op_prec);
147 term /= i;
148 sum += &term;
149 }
150
151 Approx::new(&sum).prec(calc_prec).scale_to(precision)
152 }
153 }
154 Primitive::Cos(value) => {
155 if precision >= 1 {
156 Approx::new(&BigInt::ZERO).scale_to(precision)
157 } else {
158 let iter = -precision / 2 + 4;
159 let calc_prec = precision - bound_log2(2 * iter) - 4;
160 let op_prec = precision - 2;
161 let value = value.appr(op_prec).value;
162
163 let mut term = BigInt::one() << -calc_prec;
164 let mut sum = term.clone();
165 let mut i = 0;
166 let max_err = BigInt::one() << (precision - 4 - calc_prec);
167
168 while term.abs() >= max_err {
169 i += 2;
170 term = super::scale(term * &value, op_prec);
171 term = super::scale(term * &value, op_prec);
172 term /= -i * (i - 1);
173 sum += &term;
174 }
175
176 Approx::new(&sum).prec(calc_prec).scale_to(precision)
177 }
178 }
179 Primitive::Ln(value) => {
180 if precision >= 0 {
181 Approx::new(&BigInt::ZERO).scale_to(precision)
182 } else {
183 let iter = -precision;
184 let calc_prec = precision - bound_log2(2 * iter) - 4;
185 let op_prec = precision - 3;
186 let appr = value.appr(op_prec);
187
188 let mut x_nth = appr.scale_to(calc_prec).value;
189 let mut term = x_nth.clone();
190 let mut sum = term.clone();
191 let mut i = 1;
192 let mut sign = 1;
193 let max_err = BigInt::one() << (precision - 4 - calc_prec);
194
195 while term.abs() >= max_err {
196 i += 1;
197 sign = -sign;
198 x_nth = super::scale(x_nth * appr.value.clone(), op_prec);
199 term = &x_nth / i * sign;
200 sum += &term;
201 }
202
203 Approx::new(&sum).prec(calc_prec).scale_to(precision)
204 }
205 }
206 Primitive::Asin(value) => {
207 if precision >= 2 {
208 Approx::new(&BigInt::ZERO).scale_to(precision)
209 } else {
210 let iter = -3 * precision / 2 + 4;
211 let calc_prec = precision - bound_log2(2 * iter) - 4;
212 let op_prec = precision - 3;
213 let value = value.appr(op_prec).value;
214
215 let mut term = &value << (op_prec - calc_prec);
216 let mut sum = term.clone();
217 let mut factor = term.clone();
218 let mut exp = 1;
219 let max_err = BigInt::one() << (precision - 4 - calc_prec);
220
221 while term.abs() >= max_err {
222 exp += 2;
223 factor *= exp - 2;
224 factor = super::scale(factor * &value, op_prec + 2);
225 factor *= &value;
226 factor /= exp - 1;
227 factor = super::scale(factor, op_prec - 2);
228 term = &factor / exp;
229 sum += &term;
230 }
231
232 Approx::new(&sum).prec(calc_prec).scale_to(precision)
233 }
234 }
235 Primitive::Sqrt(value) => {
236 const FP_PREC: i32 = 50;
237 const FP_OP_PREC: i32 = 60;
238
239 let op_prec = 2 * precision - 1;
240 let msd = value.iter_msd_n(op_prec);
241 let zero = Approx::new(&BigInt::ZERO).scale_to(precision);
242 match msd {
243 None => zero,
244 Some(msd) if msd <= op_prec => zero,
245 Some(msd) => {
246 let msd = msd / 2;
247 let digits = msd - precision;
248
249 if digits > FP_PREC {
250 let appr_digits = digits / 2 + 6;
251 let appr_prec = msd - appr_digits;
252 let prod_prec = 2 * appr_prec;
253
254 let op_appr = value.appr(prod_prec).value;
255 let last_appr = outer.appr(appr_prec).value;
256
257 let mut numerator = &last_appr * &last_appr + op_appr;
258 numerator = super::scale(numerator, appr_prec - precision);
259 let res = (numerator / last_appr + BigInt::one()) >> 1;
260 Approx::new(&res).prec(precision)
261 } else {
262 let op_prec = (msd - FP_OP_PREC) & !1;
263 let prec = op_prec - FP_OP_PREC;
264 let appr = (value.appr(op_prec).value << FP_OP_PREC).to_f64().unwrap();
265 let sqrt: BigInt = (appr.sqrt() as i64).into();
266 Approx::new(&sqrt).prec(prec / 2).scale_to(precision)
267 }
268 }
269 }
270 }
271 Primitive::Select(cond, lhs, rhs) => {
272 let appr = cond.appr(-20).value;
273 if appr.is_negative() {
274 return lhs.appr(precision);
275 } else if appr.is_positive() {
276 return rhs.appr(precision);
277 }
278
279 let appr1 = lhs.appr(precision - 1);
280 let appr2 = rhs.appr(precision - 1);
281 if (&appr1.value - &appr2.value).abs() < BigInt::one() {
282 return appr1.scale_to(precision);
283 }
284
285 match Comparator::new().signum(cond) {
286 -1 => appr1.scale_to(precision),
287 1 => appr2.scale_to(precision),
288 _ => unreachable!(),
289 }
290 }
291 }
292 }
293}