// balancer_maths_rust/common/log_exp_math.rs

//! Logarithmic and exponential math utilities (`pow`, `exp`, `ln`) for signed and unsigned fixed-point arithmetic.
2use crate::common::errors::PoolError;
3use alloy_primitives::{uint, I256, U256};
4use std::str::FromStr;
5
6// Constants for LogExpMath
7pub const ONE_18: I256 = I256::from_raw(uint!(1000000000000000000_U256));
8pub const MAX_NATURAL_EXPONENT: I256 = I256::from_raw(uint!(130000000000000000000_U256));
9pub const LN_36_LOWER_BOUND: I256 = I256::from_raw(uint!(900000000000000000_U256));
10pub const LN_36_UPPER_BOUND: I256 = I256::from_raw(uint!(1100000000000000000_U256));
11// Precomputed value of 2^254 / HUNDRED_WAD
12pub const MILD_EXPONENT_BOUND: U256 =
13    uint!(289480223093290488558927462521719769633174961664101410098_U256);
14
15// RAY constant for 36 decimal precision
16pub const RAY: I256 = I256::from_raw(uint!(1000000000000000000000000000000000000_U256));
17
18// 18 decimal constants
19pub const X0: I256 = I256::from_raw(uint!(128000000000000000000_U256)); // 2^7
20pub const A0: I256 = I256::from_raw(uint!(
21    38877084059945950922200000000000000000000000000000000000_U256
22)); // e^(x0) (no decimals)
23pub const X1: I256 = I256::from_raw(uint!(64000000000000000000_U256)); // 2^6
24pub const A1: I256 = I256::from_raw(uint!(6235149080811616882910000000_U256)); // e^(x1) (no decimals)
25
26// 20 decimal constants
27pub const X2: I256 = I256::from_raw(uint!(3200000000000000000000_U256)); // 2^5
28pub const A2: I256 = I256::from_raw(uint!(7896296018268069516100000000000000_U256)); // e^(x2)
29pub const X3: I256 = I256::from_raw(uint!(1600000000000000000000_U256)); // 2^4
30pub const A3: I256 = I256::from_raw(uint!(888611052050787263676000000_U256)); // e^(x3)
31pub const X4: I256 = I256::from_raw(uint!(800000000000000000000_U256)); // 2^3
32pub const A4: I256 = I256::from_raw(uint!(298095798704172827474000_U256)); // e^(x4)
33pub const X5: I256 = I256::from_raw(uint!(400000000000000000000_U256)); // 2^2
34pub const A5: I256 = I256::from_raw(uint!(5459815003314423907810_U256)); // e^(x5)
35pub const X6: I256 = I256::from_raw(uint!(200000000000000000000_U256)); // 2^1
36pub const A6: I256 = I256::from_raw(uint!(738905609893065022723_U256)); // e^(x6)
37pub const X7: I256 = I256::from_raw(uint!(100000000000000000000_U256)); // 2^0
38pub const A7: I256 = I256::from_raw(uint!(271828182845904523536_U256)); // e^(x7)
39pub const X8: I256 = I256::from_raw(uint!(50000000000000000000_U256)); // 2^-1
40pub const A8: I256 = I256::from_raw(uint!(164872127070012814685_U256)); // e^(x8)
41pub const X9: I256 = I256::from_raw(uint!(25000000000000000000_U256)); // 2^-2
42pub const A9: I256 = I256::from_raw(uint!(128402541668774148407_U256)); // e^(x9)
43pub const X10: I256 = I256::from_raw(uint!(12500000000000000000_U256)); // 2^-3
44pub const A10: I256 = I256::from_raw(uint!(113314845306682631683_U256)); // e^(x10)
45pub const X11: I256 = I256::from_raw(uint!(6250000000000000000_U256)); // 2^-4
46pub const A11: I256 = I256::from_raw(uint!(106449445891785942956_U256)); // e^(x11)
47
48pub const ONE_20: I256 = I256::from_raw(uint!(100000000000000000000_U256));
49
50// This gives you 2^256 - 41e18, which is the two's complement representation of -41e18
51const MIN_NATURAL_EXPONENT_ABS: U256 = uint!(41000000000000000000_U256);
52pub const MIN_NATURAL_EXPONENT: I256 =
53    I256::from_raw(U256::ZERO.wrapping_sub(MIN_NATURAL_EXPONENT_ABS));
54
55/// Calculate x^y using logarithmic and exponential properties
56pub fn pow(x: &U256, y: &U256) -> Result<U256, PoolError> {
57    if y.is_zero() {
58        // We solve the 0^0 indetermination by making it equal one.
59        return Ok(ONE_18.into_raw());
60    }
61
62    if x.is_zero() {
63        return Ok(U256::ZERO);
64    }
65
66    let x_int256 = I256::from_raw(*x);
67
68    // This prevents y * ln(x) from overflowing, and at the same time guarantees y fits in the signed 256 bit range.
69    if y >= &MILD_EXPONENT_BOUND {
70        return Err(PoolError::MathOverflow);
71    }
72    let y_int256 = I256::from_raw(*y);
73
74    let logx_times_y = if x_int256 > LN_36_LOWER_BOUND && x_int256 < LN_36_UPPER_BOUND {
75        let ln_36_x = ln_36(&x_int256)?;
76        // ln_36_x has 36 decimal places, so multiplying by y_int256 isn't as straightforward, since we can't just
77        // bring y_int256 to 36 decimal places, as it might overflow. Instead, we perform two 18 decimal
78        // multiplications and add the results: one with the first 18 decimals of ln_36_x, and one with the
79        // (downscaled) last 18 decimals.
80        (ln_36_x / ONE_18) * y_int256 + ((ln_36_x % ONE_18) * y_int256) / ONE_18
81    } else {
82        ln(&x_int256)? * y_int256
83    };
84
85    let logx_times_y = logx_times_y / ONE_18;
86
87    // Finally, we compute exp(y * ln(x)) to arrive at x^y
88    if logx_times_y < MIN_NATURAL_EXPONENT || logx_times_y > MAX_NATURAL_EXPONENT {
89        return Err(PoolError::MathOverflow);
90    }
91
92    exp(&logx_times_y).map(|result| result.into_raw())
93}
94
95/// Calculate exponential function e^x
96fn exp(x: &I256) -> Result<I256, PoolError> {
97    if x < &MIN_NATURAL_EXPONENT || x > &MAX_NATURAL_EXPONENT {
98        return Err(PoolError::MathOverflow);
99    }
100
101    if x.is_negative() {
102        // We only handle positive exponents: e^(-x) is computed as 1 / e^x. We can safely make x positive since it
103        // fits in the signed 256 bit range (as it is larger than MIN_NATURAL_EXPONENT).
104        // Fixed point division requires multiplying by ONE_18.
105        return Ok((ONE_18 * ONE_18) / exp(&(-*x))?);
106    }
107
108    let mut x = *x;
109    let first_an = if x >= X0 {
110        x -= X0;
111        A0
112    } else if x >= X1 {
113        x -= X1;
114        A1
115    } else {
116        I256::ONE
117    };
118
119    // We now transform x into a 20 decimal fixed point number, to have enhanced precision when computing the
120    // smaller terms.
121    x *= I256::from_str("100").unwrap();
122
123    // `product` is the accumulated product of all a_n (except a0 and a1), which starts at 20 decimal fixed point
124    // one. Recall that fixed point multiplication requires dividing by ONE_20.
125    let mut product = ONE_20;
126
127    if x >= X2 {
128        x -= X2;
129        product = (product * A2) / ONE_20;
130    }
131    if x >= X3 {
132        x -= X3;
133        product = (product * A3) / ONE_20;
134    }
135    if x >= X4 {
136        x -= X4;
137        product = (product * A4) / ONE_20;
138    }
139    if x >= X5 {
140        x -= X5;
141        product = (product * A5) / ONE_20;
142    }
143    if x >= X6 {
144        x -= X6;
145        product = (product * A6) / ONE_20;
146    }
147    if x >= X7 {
148        x -= X7;
149        product = (product * A7) / ONE_20;
150    }
151    if x >= X8 {
152        x -= X8;
153        product = (product * A8) / ONE_20;
154    }
155    if x >= X9 {
156        x -= X9;
157        product = (product * A9) / ONE_20;
158    }
159
160    // x10 and x11 are unnecessary here since we have high enough precision already.
161
162    // Now we need to compute e^x, where x is small (in particular, it is smaller than x9). We use the Taylor series
163    // expansion for e^x: 1 + x + (x^2 / 2!) + (x^3 / 3!) + ... + (x^n / n!).
164
165    let mut series_sum = ONE_20; // The initial one in the sum, with 20 decimal places.
166    let mut term = x; // Each term in the sum, where the nth term is (x^n / n!).
167
168    // The first term is simply x.
169    series_sum += term;
170
171    // Each term (x^n / n!) equals the previous one times x, divided by n. Since x is a fixed point number,
172    // multiplying by it requires dividing by HUNDRED_WAD, but dividing by the non-fixed point n values does not.
173
174    term = (term * x) / ONE_20 / I256::from_str("2").unwrap();
175    series_sum += term;
176
177    term = (term * x) / ONE_20 / I256::from_str("3").unwrap();
178    series_sum += term;
179
180    term = (term * x) / ONE_20 / I256::from_str("4").unwrap();
181    series_sum += term;
182
183    term = (term * x) / ONE_20 / I256::from_str("5").unwrap();
184    series_sum += term;
185
186    term = (term * x) / ONE_20 / I256::from_str("6").unwrap();
187    series_sum += term;
188
189    term = (term * x) / ONE_20 / I256::from_str("7").unwrap();
190    series_sum += term;
191
192    term = (term * x) / ONE_20 / I256::from_str("8").unwrap();
193    series_sum += term;
194
195    term = (term * x) / ONE_20 / I256::from_str("9").unwrap();
196    series_sum += term;
197
198    term = (term * x) / ONE_20 / I256::from_str("10").unwrap();
199    series_sum += term;
200
201    term = (term * x) / ONE_20 / I256::from_str("11").unwrap();
202    series_sum += term;
203
204    term = (term * x) / ONE_20 / I256::from_str("12").unwrap();
205    series_sum += term;
206
207    // 12 Taylor terms are sufficient for 18 decimal precision.
208
209    // Finally, we multiply by 2^7 / 2^7 = 1 and add all the terms up to compute the result.
210    // If the first argument is 0 (x = 0), then we want the result to be 1, as e^0 = 1.
211    let result = ((product * series_sum) / ONE_20) * first_an / I256::from_str("100").unwrap();
212
213    Ok(result)
214}
215
216/// Calculate natural logarithm ln(x) with signed 18 decimal fixed point argument
217fn ln(x: &I256) -> Result<I256, PoolError> {
218    let mut a = *x;
219
220    if a < ONE_18 {
221        // Since ln(a^k) = k * ln(a), we can compute ln(a) as ln(a) = ln((1/a)^(-1)) = - ln((1/a)). If a is less
222        // than one, 1/a will be greater than one, and this if statement will not be entered in the recursive call.
223        // Fixed point division requires multiplying by ONE_18.
224        return Ok(I256::MINUS_ONE * ln(&((ONE_18 * ONE_18) / a))?);
225    }
226
227    // First, we use the fact that ln^(a * b) = ln(a) + ln(b) to decompose ln(a) into a sum of powers of two, which
228    // we call x_n, where x_n == 2^(7 - n), which are the natural logarithm of precomputed quantities a_n (that is,
229    // ln(a_n) = x_n). We choose the first x_n, x0, to equal 2^7 because the exponential of all larger powers cannot
230    // be represented as 18 fixed point decimal numbers in 256 bits, and are therefore larger than a.
231    // At the end of this process we will have the sum of all x_n = ln(a_n) that apply, and the remainder of this
232    // decomposition, which will be lower than the smallest a_n.
233    // ln(a) = k_0 * x_0 + k_1 * x_1 + ... + k_n * x_n + ln(remainder), where each k_n equals either 0 or 1.
234    // We mutate a by subtracting a_n, making it the remainder of the decomposition.
235
236    // For reasons related to how `exp` works, the first two a_n (e^(2^7) and e^(2^6)) are not stored as fixed point
237    // numbers with 18 decimals, but instead as plain integers with 0 decimals, so we need to multiply them by
238    // ONE_18 to convert them to fixed point.
239    // For each a_n, we test if that term is present in the decomposition (if a is larger than it), and if so divide
240    // by it and compute the accumulated sum.
241
242    let mut sum = I256::ZERO;
243    if a >= (A0 * ONE_18) {
244        a /= A0; // Integer, not fixed point division
245        sum += X0;
246    }
247
248    if a >= (A1 * ONE_18) {
249        a /= A1; // Integer, not fixed point division
250        sum += X1;
251    }
252
253    // All other a_n and x_n are stored as 20 digit fixed point numbers, so we convert the sum and a to this format.
254    sum *= I256::from_str("100").unwrap();
255    a *= I256::from_str("100").unwrap();
256
257    // Because further a_n are 20 digit fixed point numbers, we multiply by ONE_20 when dividing by them.
258
259    if a >= A2 {
260        a = (a * ONE_20) / A2;
261        sum += X2;
262    }
263
264    if a >= A3 {
265        a = (a * ONE_20) / A3;
266        sum += X3;
267    }
268
269    if a >= A4 {
270        a = (a * ONE_20) / A4;
271        sum += X4;
272    }
273
274    if a >= A5 {
275        a = (a * ONE_20) / A5;
276        sum += X5;
277    }
278
279    if a >= A6 {
280        a = (a * ONE_20) / A6;
281        sum += X6;
282    }
283
284    if a >= A7 {
285        a = (a * ONE_20) / A7;
286        sum += X7;
287    }
288
289    if a >= A8 {
290        a = (a * ONE_20) / A8;
291        sum += X8;
292    }
293
294    if a >= A9 {
295        a = (a * ONE_20) / A9;
296        sum += X9;
297    }
298
299    if a >= A10 {
300        a = (a * ONE_20) / A10;
301        sum += X10;
302    }
303
304    if a >= A11 {
305        a = (a * ONE_20) / A11;
306        sum += X11;
307    }
308
309    // a is now a small number (smaller than a_11, which roughly equals 1.06). This means we can use a Taylor series
310    // that converges rapidly for values of `a` close to one - the same one used in ln_36.
311    // Let z = (a - 1) / (a + 1).
312    // ln(a) = 2 * (z + z^3 / 3 + z^5 / 5 + z^7 / 7 + ... + z^(2 * n + 1) / (2 * n + 1))
313
314    // Recall that 20 digit fixed point division requires multiplying by ONE_20, and multiplication requires
315    // division by ONE_20.
316    let z = ((a - ONE_20) * ONE_20) / (a + ONE_20);
317    let z_squared = (z * z) / ONE_20;
318
319    // num is the numerator of the series: the z^(2 * n + 1) term
320    let mut num = z;
321
322    // seriesSum holds the accumulated sum of each term in the series, starting with the initial z
323    let mut series_sum = num;
324
325    // In each step, the numerator is multiplied by z^2
326    num = (num * z_squared) / ONE_20;
327    series_sum += num / I256::from_str("3").unwrap();
328
329    num = (num * z_squared) / ONE_20;
330    series_sum += num / I256::from_str("5").unwrap();
331
332    num = (num * z_squared) / ONE_20;
333    series_sum += num / I256::from_str("7").unwrap();
334
335    num = (num * z_squared) / ONE_20;
336    series_sum += num / I256::from_str("9").unwrap();
337
338    num = (num * z_squared) / ONE_20;
339    series_sum += num / I256::from_str("11").unwrap();
340
341    // 6 Taylor terms are sufficient for 36 decimal precision.
342
343    // Finally, we multiply by 2 (non fixed point) to compute ln(remainder)
344    series_sum *= I256::from_str("2").unwrap();
345
346    // We now have the sum of all x_n present, and the Taylor approximation of the logarithm of the remainder (both
347    // with 20 decimals). All that remains is to sum these two, and then drop two digits to return a 18 decimal
348    // value.
349
350    Ok((sum + series_sum) / I256::from_str("100").unwrap())
351}
352
353/// Calculate natural logarithm with 36 decimal precision
354fn ln_36(x: &I256) -> Result<I256, PoolError> {
355    let mut x = *x;
356    // Since ln(1) = 0, a value of x close to one will yield a very small result, which makes using 36 digits
357    // worthwhile.
358
359    // First, we transform x to a 36 digit fixed point value.
360    x *= ONE_18;
361
362    // We will use the following Taylor expansion, which converges very rapidly. Let z = (x - 1) / (x + 1).
363    // ln(x) = 2 * (z + z^3 / 3 + z^5 / 5 + z^7 / 7 + ... + z^(2 * n + 1) / (2 * n + 1))
364
365    // Recall that 36 digit fixed point division requires multiplying by ONE_36, and multiplication requires
366    // division by ONE_36.
367    let z = ((x - RAY) * RAY) / (x + RAY);
368    let z_squared = (z * z) / RAY;
369
370    // num is the numerator of the series: the z^(2 * n + 1) term
371    let mut num = z;
372
373    // seriesSum holds the accumulated sum of each term in the series, starting with the initial z
374    let mut series_sum = num;
375
376    // In each step, the numerator is multiplied by z^2
377    num = (num * z_squared) / RAY;
378    series_sum += num / I256::from_str("3").unwrap();
379
380    num = (num * z_squared) / RAY;
381    series_sum += num / I256::from_str("5").unwrap();
382
383    num = (num * z_squared) / RAY;
384    series_sum += num / I256::from_str("7").unwrap();
385
386    num = (num * z_squared) / RAY;
387    series_sum += num / I256::from_str("9").unwrap();
388
389    num = (num * z_squared) / RAY;
390    series_sum += num / I256::from_str("11").unwrap();
391
392    num = (num * z_squared) / RAY;
393    series_sum += num / I256::from_str("13").unwrap();
394
395    num = (num * z_squared) / RAY;
396    series_sum += num / I256::from_str("15").unwrap();
397
398    // 8 Taylor terms are sufficient for 36 decimal precision.
399
400    // All that remains is multiplying by 2 (non fixed point).
401    Ok(series_sum * I256::from_str("2").unwrap())
402}