1use super::{Approx, Real};
2use num::bigint::Sign;
3use num::{BigInt, One, Signed, ToPrimitive};
4use std::mem::discriminant;
5
6#[derive(Clone, Debug)]
7pub enum Primitive {
8 Int(BigInt),
9 Shift(Real, i32),
10 Add(Real, Real),
11 Neg(Real),
12 Mul(Real, Real),
13 Square(Real),
14 Inv(Real),
15 Atan(BigInt),
16 Exp(Real),
17 Cos(Real),
18 Ln(Real),
19 Asin(Real),
20 Sqrt(Real),
21}
22
23impl Primitive {
24 pub fn same_type(&self, other: &Self) -> bool {
25 discriminant(self) == discriminant(other)
26 }
27
28 pub fn approximate(&mut self, precision: i32, mut outer: Real) -> Approx {
29 match self {
30 Primitive::Int(value) => Approx::new(value).scale_to(precision),
31 Primitive::Shift(value, bits) => value.appr(precision - *bits).prec(precision),
37 Primitive::Add(lhs, rhs) => {
38 let lhs = lhs.appr(precision - 2);
39 let rhs = rhs.appr(precision - 2);
40 (lhs + rhs).scale_to(precision)
41 }
42 Primitive::Neg(value) => value.appr(precision) * -1,
43 Primitive::Mul(lhs, rhs) => {
44 let mut v1 = lhs;
45 let mut v2 = rhs;
46
47 let half = (precision >> 1) - 1;
48 let mut msd1 = v1.msd(half);
49
50 if msd1.is_none() {
51 let msd2 = v2.msd(half);
52 if msd2.is_none() {
53 return Approx::new(&BigInt::ZERO).scale_to(precision);
54 } else {
55 msd1 = msd2;
56 let tmp = v1;
57 v1 = v2;
58 v2 = tmp;
59 }
60 }
61
62 let prec2 = precision - msd1.unwrap() - 3;
63 let appr2 = v2.appr(prec2);
64 if appr2.value.sign() == Sign::NoSign {
65 return Approx::new(&BigInt::ZERO).scale_to(precision);
66 }
67 let msd2 = v2.known_msd();
68 let prec1 = precision - msd2.unwrap() - 3;
69 let appr1 = v1.appr(prec1);
70 (appr1 * appr2).scale_to(precision)
71 }
72 Primitive::Square(value) => {
73 let half = (precision >> 1) - 1;
74 let msd = value.msd(half);
75
76 if msd.is_none() {
77 return Approx::new(&BigInt::ZERO).scale_to(precision);
78 }
79
80 let prec = precision - msd.unwrap() - 3;
81 let appr = value.appr(prec);
82 if appr.value.sign() == Sign::NoSign {
83 return Approx::new(&BigInt::ZERO).scale_to(precision);
84 }
85 (appr.clone() * appr).scale_to(precision)
86 }
87 Primitive::Inv(value) => {
88 let msd = value.iter_msd().unwrap();
89 let inv = 1 - msd;
90 let digits = inv - precision + 3;
91 let prec = msd - digits;
92 let log_factor = -precision - prec;
93 if log_factor < 0 {
94 Approx::new(&BigInt::ZERO).prec(precision)
95 } else {
96 let mut divident = BigInt::one() << log_factor;
97 let divisor = value.appr(prec).value;
98 divident += divisor.abs() >> 1;
99 let res = divident / divisor;
100 Approx::new(&res).prec(precision)
101 }
102 }
103 Primitive::Atan(value) => {
104 if precision >= 1 {
105 Approx::new(&BigInt::ZERO).scale_to(precision)
106 } else {
107 let iterations = -precision / 2 + 2;
108 let prec = precision - super::bound_log2(2 * iterations) - 2;
109 let scaled_one = BigInt::one() << -prec;
110 let sq = value.clone() * value.clone();
111 let inv = scaled_one / value.clone();
112
113 let mut pow = inv.clone();
114 let mut term = inv.clone();
115 let mut sum = term.clone();
116 let mut sign = 1;
117 let mut i = 1;
118
119 let max_err = BigInt::one() << (precision - 2 - prec);
120 while term.abs() >= max_err {
121 i += 2;
122 pow /= &sq;
123 sign = -sign;
124 term = &pow / (sign * i);
125 sum += &term;
126 }
127
128 Approx::new(&sum).prec(prec).scale_to(precision)
129 }
130 }
131 Primitive::Exp(value) => {
132 if precision >= 1 {
133 Approx::new(&BigInt::ZERO).scale_to(precision)
134 } else {
135 let iter = -precision / 2 + 2;
136 let calc_prec = precision - super::bound_log2(2 * iter) - 4;
137 let op_prec = precision - 3;
138 let value = value.appr(op_prec).value;
139
140 let mut term = BigInt::one() << -calc_prec;
141 let mut sum = term.clone();
142 let mut i = 0;
143 let max_err = BigInt::one() << (precision - 4 - calc_prec);
144
145 while term.abs() >= max_err {
146 i += 1;
147 term = super::scale(term * &value, op_prec);
148 term /= i;
149 sum += &term;
150 }
151
152 Approx::new(&sum).prec(calc_prec).scale_to(precision)
153 }
154 }
155 Primitive::Cos(value) => {
156 if precision >= 1 {
157 Approx::new(&BigInt::ZERO).scale_to(precision)
158 } else {
159 let iter = -precision / 2 + 4;
160 let calc_prec = precision - super::bound_log2(2 * iter) - 4;
161 let op_prec = precision - 2;
162 let value = value.appr(op_prec).value;
163
164 let mut term = BigInt::one() << -calc_prec;
165 let mut sum = term.clone();
166 let mut i = 0;
167 let max_err = BigInt::one() << (precision - 4 - calc_prec);
168
169 while term.abs() >= max_err {
170 i += 2;
171 term = super::scale(term * &value, op_prec);
172 term = super::scale(term * &value, op_prec);
173 term /= -i * (i - 1);
174 sum += &term;
175 }
176
177 Approx::new(&sum).prec(calc_prec).scale_to(precision)
178 }
179 }
180 Primitive::Ln(value) => {
181 if precision >= 0 {
182 Approx::new(&BigInt::ZERO).scale_to(precision)
183 } else {
184 let iter = -precision;
185 let calc_prec = precision - super::bound_log2(2 * iter) - 4;
186 let op_prec = precision - 3;
187 let appr = value.appr(op_prec);
188
189 let mut x_nth = appr.scale_to(calc_prec).value;
190 let mut term = x_nth.clone();
191 let mut sum = term.clone();
192 let mut i = 1;
193 let mut sign = 1;
194 let max_err = BigInt::one() << (precision - 4 - calc_prec);
195
196 while term.abs() >= max_err {
197 i += 1;
198 sign = -sign;
199 x_nth = super::scale(x_nth * appr.value.clone(), op_prec);
200 term = &x_nth / i * sign;
201 sum += &term;
202 }
203
204 Approx::new(&sum).prec(calc_prec).scale_to(precision)
205 }
206 }
207 Primitive::Asin(value) => {
208 if precision >= 2 {
209 Approx::new(&BigInt::ZERO).scale_to(precision)
210 } else {
211 let iter = -3 * precision / 2 + 4;
212 let calc_prec = precision - super::bound_log2(2 * iter) - 4;
213 let op_prec = precision - 3;
214 let value = value.appr(op_prec).value;
215
216 let mut term = &value << (op_prec - calc_prec);
217 let mut sum = term.clone();
218 let mut factor = term.clone();
219 let mut exp = 1;
220 let max_err = BigInt::one() << (precision - 4 - calc_prec);
221
222 while term.abs() >= max_err {
223 exp += 2;
224 factor *= exp - 2;
225 factor = super::scale(factor * &value, op_prec + 2);
226 factor *= &value;
227 factor /= exp - 1;
228 factor = super::scale(factor, op_prec - 2);
229 term = &factor / exp;
230 sum += &term;
231 }
232
233 Approx::new(&sum).prec(calc_prec).scale_to(precision)
234 }
235 }
236 Primitive::Sqrt(value) => {
237 const FP_PREC: i32 = 50;
238 const FP_OP_PREC: i32 = 60;
239
240 let op_prec = 2 * precision - 1;
241 let msd = value.iter_msd_n(op_prec);
242 let zero = Approx::new(&BigInt::ZERO).scale_to(precision);
243 match msd {
244 None => zero,
245 Some(msd) if msd <= op_prec => zero,
246 Some(msd) => {
247 let msd = msd / 2;
248 let digits = msd - precision;
249
250 if digits > FP_PREC {
251 let appr_digits = digits / 2 + 6;
252 let appr_prec = msd - appr_digits;
253 let prod_prec = 2 * appr_prec;
254
255 let op_appr = value.appr(prod_prec).value;
256 let last_appr = outer.appr(appr_prec).value;
257
258 let mut numerator = &last_appr * &last_appr + op_appr;
259 numerator = super::scale(numerator, appr_prec - precision);
260 let res = (numerator / last_appr + BigInt::one()) >> 1;
261 Approx::new(&res).prec(precision)
262 } else {
263 let op_prec = (msd - FP_OP_PREC) & !1;
264 let prec = op_prec - FP_OP_PREC;
265 let appr = (value.appr(op_prec).value << FP_OP_PREC).to_f64().unwrap();
266 let sqrt: BigInt = (appr.sqrt() as i64).into();
267 Approx::new(&sqrt).prec(prec / 2).scale_to(precision)
268 }
269 }
270 }
271 }
272 }
273 }
274}