balancer_maths_rust/common/
log_exp_math.rs

1//! Logarithmic and exponential math utilities for fixed-point arithmetic
2
3use crate::common::constants::WAD;
4use crate::common::errors::PoolError;
5use lazy_static::lazy_static;
6use num_bigint::BigInt;
7use num_traits::{One, Signed, Zero};
8use std::str::FromStr;
9
10lazy_static! {
11    // Constants for LogExpMath
12    static ref MAX_NATURAL_EXPONENT: BigInt = BigInt::from(130000000000000000000i128);
13    static ref MIN_NATURAL_EXPONENT: BigInt = BigInt::from(-41000000000000000000i128);
14    static ref LN_36_LOWER_BOUND: BigInt = BigInt::from(900000000000000000i128);
15    static ref LN_36_UPPER_BOUND: BigInt = BigInt::from(1100000000000000000i128);
16    // Precomputed value of 2^254 / HUNDRED_WAD
17    static ref MILD_EXPONENT_BOUND: BigInt = BigInt::from_str("289480223093290488558927462521719769633174961664101410098").unwrap();
18
19    // RAY constant for 36 decimal precision
20    static ref RAY: BigInt = BigInt::from_str("1000000000000000000000000000000000000").unwrap();
21
22    // 18 decimal constants
23    static ref X0: BigInt = BigInt::from(128000000000000000000i128); // 2^7
24    static ref A0: BigInt = BigInt::from_str("38877084059945950922200000000000000000000000000000000000").unwrap(); // e^(x0) (no decimals)
25    static ref X1: BigInt = BigInt::from(64000000000000000000i128); // 2^6
26    static ref A1: BigInt = BigInt::from(6235149080811616882910000000i128); // e^(x1) (no decimals)
27
28    // 20 decimal constants
29    static ref X2: BigInt = BigInt::from(3200000000000000000000i128); // 2^5
30    static ref A2: BigInt = BigInt::from_str("7896296018268069516100000000000000").unwrap(); // e^(x2)
31    static ref X3: BigInt = BigInt::from(1600000000000000000000i128); // 2^4
32    static ref A3: BigInt = BigInt::from_str("888611052050787263676000000").unwrap(); // e^(x3)
33    static ref X4: BigInt = BigInt::from(800000000000000000000i128); // 2^3
34    static ref A4: BigInt = BigInt::from_str("298095798704172827474000").unwrap(); // e^(x4)
35    static ref X5: BigInt = BigInt::from(400000000000000000000i128); // 2^2
36    static ref A5: BigInt = BigInt::from_str("5459815003314423907810").unwrap(); // e^(x5)
37    static ref X6: BigInt = BigInt::from(200000000000000000000i128); // 2^1
38    static ref A6: BigInt = BigInt::from_str("738905609893065022723").unwrap(); // e^(x6)
39    static ref X7: BigInt = BigInt::from(100000000000000000000i128); // 2^0
40    static ref A7: BigInt = BigInt::from_str("271828182845904523536").unwrap(); // e^(x7)
41    static ref X8: BigInt = BigInt::from(50000000000000000000i128); // 2^-1
42    static ref A8: BigInt = BigInt::from_str("164872127070012814685").unwrap(); // e^(x8)
43    static ref X9: BigInt = BigInt::from(25000000000000000000i128); // 2^-2
44    static ref A9: BigInt = BigInt::from_str("128402541668774148407").unwrap(); // e^(x9)
45    static ref X10: BigInt = BigInt::from(12500000000000000000i128); // 2^-3
46    static ref A10: BigInt = BigInt::from_str("113314845306682631683").unwrap(); // e^(x10)
47    static ref X11: BigInt = BigInt::from(6250000000000000000i128); // 2^-4
48    static ref A11: BigInt = BigInt::from_str("106449445891785942956").unwrap(); // e^(x11)
49
50    static ref HUNDRED_WAD: BigInt = BigInt::from(100000000000000000000i128);
51}
52
53/// Calculate x^y using logarithmic and exponential properties
54pub fn pow(x: &BigInt, y: &BigInt) -> Result<BigInt, PoolError> {
55    if y.is_zero() {
56        // We solve the 0^0 indetermination by making it equal one.
57        return Ok(WAD.clone());
58    }
59
60    if x.is_zero() {
61        return Ok(BigInt::zero());
62    }
63
64    // The ln function takes a signed value, so we need to make sure x fits in the signed 256 bit range.
65    if x >= &BigInt::from_str(
66        "57896044618658097711785492504343953926634992332820282019728792003956564819968",
67    )
68    .unwrap()
69    {
70        return Err(PoolError::MathOverflow);
71    }
72    let x_int256 = x.clone();
73
74    // This prevents y * ln(x) from overflowing, and at the same time guarantees y fits in the signed 256 bit range.
75    if y >= &*MILD_EXPONENT_BOUND {
76        return Err(PoolError::MathOverflow);
77    }
78    let y_int256 = y.clone();
79
80    let logx_times_y = if x_int256 > *LN_36_LOWER_BOUND && x_int256 < *LN_36_UPPER_BOUND {
81        let ln_36_x = ln_36(&x_int256)?;
82        // ln_36_x has 36 decimal places, so multiplying by y_int256 isn't as straightforward, since we can't just
83        // bring y_int256 to 36 decimal places, as it might overflow. Instead, we perform two 18 decimal
84        // multiplications and add the results: one with the first 18 decimals of ln_36_x, and one with the
85        // (downscaled) last 18 decimals.
86        (ln_36_x.clone() / &*WAD) * &y_int256 + ((ln_36_x.clone() % &*WAD) * &y_int256) / &*WAD
87    } else {
88        ln(&x_int256)? * &y_int256
89    };
90
91    let logx_times_y = logx_times_y / &*WAD;
92
93    // Finally, we compute exp(y * ln(x)) to arrive at x^y
94    if logx_times_y < *MIN_NATURAL_EXPONENT || logx_times_y > *MAX_NATURAL_EXPONENT {
95        return Err(PoolError::MathOverflow);
96    }
97
98    exp(&logx_times_y)
99}
100
101/// Calculate exponential function e^x
102fn exp(x: &BigInt) -> Result<BigInt, PoolError> {
103    if x < &MIN_NATURAL_EXPONENT || x > &MAX_NATURAL_EXPONENT {
104        return Err(PoolError::MathOverflow);
105    }
106
107    if x.is_negative() {
108        // We only handle positive exponents: e^(-x) is computed as 1 / e^x. We can safely make x positive since it
109        // fits in the signed 256 bit range (as it is larger than MIN_NATURAL_EXPONENT).
110        // Fixed point division requires multiplying by ONE_18.
111        return Ok((&*WAD * &*WAD) / exp(&(-x))?);
112    }
113
114    let mut x = x.clone();
115    let first_an = if x >= *X0 {
116        x -= &*X0;
117        A0.clone()
118    } else if x >= *X1 {
119        x -= &*X1;
120        A1.clone()
121    } else {
122        BigInt::one()
123    };
124
125    // We now transform x into a 20 decimal fixed point number, to have enhanced precision when computing the
126    // smaller terms.
127    x *= BigInt::from(100);
128
129    // `product` is the accumulated product of all a_n (except a0 and a1), which starts at 20 decimal fixed point
130    // one. Recall that fixed point multiplication requires dividing by ONE_20.
131    let mut product = HUNDRED_WAD.clone();
132
133    if x >= *X2 {
134        x -= &*X2;
135        product = (product * &*A2) / &*HUNDRED_WAD;
136    }
137    if x >= *X3 {
138        x -= &*X3;
139        product = (product * &*A3) / &*HUNDRED_WAD;
140    }
141    if x >= *X4 {
142        x -= &*X4;
143        product = (product * &*A4) / &*HUNDRED_WAD;
144    }
145    if x >= *X5 {
146        x -= &*X5;
147        product = (product * &*A5) / &*HUNDRED_WAD;
148    }
149    if x >= *X6 {
150        x -= &*X6;
151        product = (product * &*A6) / &*HUNDRED_WAD;
152    }
153    if x >= *X7 {
154        x -= &*X7;
155        product = (product * &*A7) / &*HUNDRED_WAD;
156    }
157    if x >= *X8 {
158        x -= &*X8;
159        product = (product * &*A8) / &*HUNDRED_WAD;
160    }
161    if x >= *X9 {
162        x -= &*X9;
163        product = (product * &*A9) / &*HUNDRED_WAD;
164    }
165
166    // x10 and x11 are unnecessary here since we have high enough precision already.
167
168    // Now we need to compute e^x, where x is small (in particular, it is smaller than x9). We use the Taylor series
169    // expansion for e^x: 1 + x + (x^2 / 2!) + (x^3 / 3!) + ... + (x^n / n!).
170
171    let mut series_sum = HUNDRED_WAD.clone(); // The initial one in the sum, with 20 decimal places.
172    let mut term = x.clone(); // Each term in the sum, where the nth term is (x^n / n!).
173
174    // The first term is simply x.
175    series_sum += &term;
176
177    // Each term (x^n / n!) equals the previous one times x, divided by n. Since x is a fixed point number,
178    // multiplying by it requires dividing by HUNDRED_WAD, but dividing by the non-fixed point n values does not.
179
180    term = (term * &x) / &*HUNDRED_WAD / &BigInt::from(2);
181    series_sum += &term;
182
183    term = (term * &x) / &*HUNDRED_WAD / &BigInt::from(3);
184    series_sum += &term;
185
186    term = (term * &x) / &*HUNDRED_WAD / &BigInt::from(4);
187    series_sum += &term;
188
189    term = (term * &x) / &*HUNDRED_WAD / &BigInt::from(5);
190    series_sum += &term;
191
192    term = (term * &x) / &*HUNDRED_WAD / &BigInt::from(6);
193    series_sum += &term;
194
195    term = (term * &x) / &*HUNDRED_WAD / &BigInt::from(7);
196    series_sum += &term;
197
198    term = (term * &x) / &*HUNDRED_WAD / &BigInt::from(8);
199    series_sum += &term;
200
201    term = (term * &x) / &*HUNDRED_WAD / &BigInt::from(9);
202    series_sum += &term;
203
204    term = (term * &x) / &*HUNDRED_WAD / &BigInt::from(10);
205    series_sum += &term;
206
207    term = (term * &x) / &*HUNDRED_WAD / &BigInt::from(11);
208    series_sum += &term;
209
210    term = (term * &x) / &*HUNDRED_WAD / &BigInt::from(12);
211    series_sum += &term;
212
213    // 12 Taylor terms are sufficient for 18 decimal precision.
214
215    // Finally, we multiply by 2^7 / 2^7 = 1 and add all the terms up to compute the result.
216    // If the first argument is 0 (x = 0), then we want the result to be 1, as e^0 = 1.
217    let result = ((product * series_sum) / &*HUNDRED_WAD) * &first_an / BigInt::from(100);
218
219    Ok(result)
220}
221
222/// Calculate natural logarithm ln(x) with signed 18 decimal fixed point argument
223fn ln(x: &BigInt) -> Result<BigInt, PoolError> {
224    let mut a = x.clone();
225
226    if a < *WAD {
227        // Since ln(a^k) = k * ln(a), we can compute ln(a) as ln(a) = ln((1/a)^(-1)) = - ln((1/a)). If a is less
228        // than one, 1/a will be greater than one, and this if statement will not be entered in the recursive call.
229        // Fixed point division requires multiplying by ONE_18.
230        return Ok(-BigInt::one() * ln(&((&*WAD * &*WAD) / &a))?);
231    }
232
233    // First, we use the fact that ln^(a * b) = ln(a) + ln(b) to decompose ln(a) into a sum of powers of two, which
234    // we call x_n, where x_n == 2^(7 - n), which are the natural logarithm of precomputed quantities a_n (that is,
235    // ln(a_n) = x_n). We choose the first x_n, x0, to equal 2^7 because the exponential of all larger powers cannot
236    // be represented as 18 fixed point decimal numbers in 256 bits, and are therefore larger than a.
237    // At the end of this process we will have the sum of all x_n = ln(a_n) that apply, and the remainder of this
238    // decomposition, which will be lower than the smallest a_n.
239    // ln(a) = k_0 * x_0 + k_1 * x_1 + ... + k_n * x_n + ln(remainder), where each k_n equals either 0 or 1.
240    // We mutate a by subtracting a_n, making it the remainder of the decomposition.
241
242    // For reasons related to how `exp` works, the first two a_n (e^(2^7) and e^(2^6)) are not stored as fixed point
243    // numbers with 18 decimals, but instead as plain integers with 0 decimals, so we need to multiply them by
244    // ONE_18 to convert them to fixed point.
245    // For each a_n, we test if that term is present in the decomposition (if a is larger than it), and if so divide
246    // by it and compute the accumulated sum.
247
248    let mut sum = BigInt::zero();
249    if a >= (&*A0 * &*WAD) {
250        a /= &*A0; // Integer, not fixed point division
251        sum += &*X0;
252    }
253
254    if a >= (&*A1 * &*WAD) {
255        a /= &*A1; // Integer, not fixed point division
256        sum += &*X1;
257    }
258
259    // All other a_n and x_n are stored as 20 digit fixed point numbers, so we convert the sum and a to this format.
260    sum *= BigInt::from(100);
261    a *= BigInt::from(100);
262
263    // Because further a_n are 20 digit fixed point numbers, we multiply by ONE_20 when dividing by them.
264
265    if a >= *A2 {
266        a = (&a * &*HUNDRED_WAD) / &*A2;
267        sum += &*X2;
268    }
269
270    if a >= *A3 {
271        a = (&a * &*HUNDRED_WAD) / &*A3;
272        sum += &*X3;
273    }
274
275    if a >= *A4 {
276        a = (&a * &*HUNDRED_WAD) / &*A4;
277        sum += &*X4;
278    }
279
280    if a >= *A5 {
281        a = (&a * &*HUNDRED_WAD) / &*A5;
282        sum += &*X5;
283    }
284
285    if a >= *A6 {
286        a = (&a * &*HUNDRED_WAD) / &*A6;
287        sum += &*X6;
288    }
289
290    if a >= *A7 {
291        a = (&a * &*HUNDRED_WAD) / &*A7;
292        sum += &*X7;
293    }
294
295    if a >= *A8 {
296        a = (&a * &*HUNDRED_WAD) / &*A8;
297        sum += &*X8;
298    }
299
300    if a >= *A9 {
301        a = (&a * &*HUNDRED_WAD) / &*A9;
302        sum += &*X9;
303    }
304
305    if a >= *A10 {
306        a = (&a * &*HUNDRED_WAD) / &*A10;
307        sum += &*X10;
308    }
309
310    if a >= *A11 {
311        a = (&a * &*HUNDRED_WAD) / &*A11;
312        sum += &*X11;
313    }
314
315    // a is now a small number (smaller than a_11, which roughly equals 1.06). This means we can use a Taylor series
316    // that converges rapidly for values of `a` close to one - the same one used in ln_36.
317    // Let z = (a - 1) / (a + 1).
318    // ln(a) = 2 * (z + z^3 / 3 + z^5 / 5 + z^7 / 7 + ... + z^(2 * n + 1) / (2 * n + 1))
319
320    // Recall that 20 digit fixed point division requires multiplying by ONE_20, and multiplication requires
321    // division by ONE_20.
322    let z = ((&a - &*HUNDRED_WAD) * &*HUNDRED_WAD) / (&a + &*HUNDRED_WAD);
323    let z_squared = (&z * &z) / &*HUNDRED_WAD;
324
325    // num is the numerator of the series: the z^(2 * n + 1) term
326    let mut num = z.clone();
327
328    // seriesSum holds the accumulated sum of each term in the series, starting with the initial z
329    let mut series_sum = num.clone();
330
331    // In each step, the numerator is multiplied by z^2
332    num = (&num * &z_squared) / &*HUNDRED_WAD;
333    series_sum += &num / &BigInt::from(3);
334
335    num = (&num * &z_squared) / &*HUNDRED_WAD;
336    series_sum += &num / &BigInt::from(5);
337
338    num = (&num * &z_squared) / &*HUNDRED_WAD;
339    series_sum += &num / &BigInt::from(7);
340
341    num = (&num * &z_squared) / &*HUNDRED_WAD;
342    series_sum += &num / &BigInt::from(9);
343
344    num = (&num * &z_squared) / &*HUNDRED_WAD;
345    series_sum += &num / &BigInt::from(11);
346
347    // 6 Taylor terms are sufficient for 36 decimal precision.
348
349    // Finally, we multiply by 2 (non fixed point) to compute ln(remainder)
350    series_sum *= BigInt::from(2);
351
352    // We now have the sum of all x_n present, and the Taylor approximation of the logarithm of the remainder (both
353    // with 20 decimals). All that remains is to sum these two, and then drop two digits to return a 18 decimal
354    // value.
355
356    Ok((sum + series_sum) / BigInt::from(100))
357}
358
359/// Calculate natural logarithm with 36 decimal precision
360fn ln_36(x: &BigInt) -> Result<BigInt, PoolError> {
361    let mut x = x.clone();
362    // Since ln(1) = 0, a value of x close to one will yield a very small result, which makes using 36 digits
363    // worthwhile.
364
365    // First, we transform x to a 36 digit fixed point value.
366    x *= &*WAD;
367
368    // We will use the following Taylor expansion, which converges very rapidly. Let z = (x - 1) / (x + 1).
369    // ln(x) = 2 * (z + z^3 / 3 + z^5 / 5 + z^7 / 7 + ... + z^(2 * n + 1) / (2 * n + 1))
370
371    // Recall that 36 digit fixed point division requires multiplying by ONE_36, and multiplication requires
372    // division by ONE_36.
373    let z = ((&x - &*RAY) * &*RAY) / (&x + &*RAY);
374    let z_squared = (&z * &z) / &*RAY;
375
376    // num is the numerator of the series: the z^(2 * n + 1) term
377    let mut num = z.clone();
378
379    // seriesSum holds the accumulated sum of each term in the series, starting with the initial z
380    let mut series_sum = num.clone();
381
382    // In each step, the numerator is multiplied by z^2
383    num = (&num * &z_squared) / &*RAY;
384    series_sum += &num / &BigInt::from(3);
385
386    num = (&num * &z_squared) / &*RAY;
387    series_sum += &num / &BigInt::from(5);
388
389    num = (&num * &z_squared) / &*RAY;
390    series_sum += &num / &BigInt::from(7);
391
392    num = (&num * &z_squared) / &*RAY;
393    series_sum += &num / &BigInt::from(9);
394
395    num = (&num * &z_squared) / &*RAY;
396    series_sum += &num / &BigInt::from(11);
397
398    num = (&num * &z_squared) / &*RAY;
399    series_sum += &num / &BigInt::from(13);
400
401    num = (&num * &z_squared) / &*RAY;
402    series_sum += &num / &BigInt::from(15);
403
404    // 8 Taylor terms are sufficient for 36 decimal precision.
405
406    // All that remains is multiplying by 2 (non fixed point).
407    Ok(series_sum * BigInt::from(2))
408}