// lean_decimal/int128.rs

1use crate::{UnderlyingInt, bits_to_digits};
2
3impl UnderlyingInt for u128 {
4    const ZERO: Self = 0;
5    const ONE: Self = 1;
6    const TEN: Self = 10;
7    const HUNDRED: Self = 100;
8    const MAX_MATISSA: Self = Self::MAX >> Self::META_BITS;
9    const MIN_UNDERINT: Self = (1 << 127) | Self::MAX_MATISSA;
10
11    const BITS: u32 = 128;
12    const MAX_SCALE: u32 = 36;
13
14    type Signed = i128;
15
16    fn to_signed(self, sign: u8) -> Self::Signed {
17        let i = self as i128; // self as mantissa fits 127-bits
18        if sign == 0 { i } else { -i }
19    }
20    fn from_signed(s: Self::Signed) -> (Self, u8) {
21        (s.unsigned_abs(), (s < 0) as u8)
22    }
23
24    fn as_u32(self) -> u32 {
25        self as u32
26    }
27    fn from_u32(n: u32) -> Self {
28        n as Self
29    }
30
31    fn leading_zeros(self) -> u32 {
32        self.leading_zeros()
33    }
34
35    // caller must make sure that no overflow
36    fn mul_exp(self, iexp: u32) -> Self {
37        self * get_exp(iexp)
38    }
39
40    // we check the overflow
41    fn checked_mul_exp(self, iexp: u32) -> Option<Self> {
42        self.checked_mul(get_exp(iexp))
43    }
44
45    fn div_exp(self, iexp: u32) -> Self {
46        let n = self + get_exp(iexp) / 2; // no addition overflow. exp is even.
47
48        // SAFETY: self < MAX_MANTISSA, so n fits in 127-bit
49        unsafe { div_pow10::bit128::unchecked_div_single_r1b(n, iexp) }
50    }
51
52    fn div_rem_exp(self, iexp: u32) -> (Self, Self) {
53        // SAFETY: self < MAX_MANTISSA, so n fits in 127-bit
54        let q = unsafe { div_pow10::bit128::unchecked_div_single_r1b(self, iexp) };
55        (q, self - q * get_exp(iexp))
56    }
57
58    #[inline]
59    fn mul_with_sum_scale(self, right: Self, sum_scale: u32) -> Option<(Self, u32)> {
60        if self.leading_zeros() + right.leading_zeros() >= Self::BITS + Self::META_BITS {
61            // fast path, keep the code simple
62            let p = self * right;
63            if sum_scale <= Self::MAX_SCALE {
64                Some((p, sum_scale))
65            } else {
66                Some((p.div_exp(sum_scale - Self::MAX_SCALE), Self::MAX_SCALE))
67            }
68        } else {
69            // full path
70            mul_with_sum_scale_full(self, right, sum_scale)
71        }
72    }
73
74    #[inline]
75    fn div_with_scales(self, d: Self, s_scale: u32, d_scale: u32) -> Option<(Self, u32)> {
76        let diff_scale = s_scale.saturating_sub(d_scale);
77        let max_scale = Self::MAX_SCALE - diff_scale;
78
79        // TODO optimize 64-bit divisor too
80        let (mut q, mut r, mut act_scale) = match u32::try_from(d) {
81            Ok(d) => div_with_scales_by32(self, d, max_scale),
82            Err(_) => div_with_scales_full(self, d, max_scale),
83        };
84
85        // increase the scale if d_scale > s_scale
86        let min_scale = d_scale.saturating_sub(s_scale);
87        if act_scale < min_scale {
88            (q, r) = increase_scale(q, r, d, min_scale - act_scale)?;
89            act_scale = min_scale;
90        }
91        // reduce the scale if division exactly
92        else if r == 0 && act_scale > min_scale {
93            let (q0, red_scale) = reduce_scale(q, act_scale - min_scale);
94            q = q0;
95            act_scale -= red_scale;
96        }
97
98        // round
99        if r * 2 >= d {
100            if q == u128::MAX_MATISSA {
101                // The final scale must be 0 (but why?), there is no room to reduce
102                debug_assert_eq!(diff_scale + act_scale - min_scale, 0);
103                return None;
104            }
105            q += 1;
106        }
107
108        Some((q, diff_scale + act_scale - min_scale))
109    }
110}
111
112// the caller must make sure that: @i < 39
113fn get_exp(i: u32) -> u128 {
114    // Although in most cases, 36 (which is MAX_SCALE) is enough,
115    // but in some cases we need more.
116    debug_assert!(i < 39);
117
118    unsafe { *ALL_EXPS.get_unchecked(i as usize) }
119}
120
/// Returns `(n / d, n % d)`.
///
/// When both operands fit in 64 bits, the division is done in `u64`, which is
/// much cheaper than a full 128-bit division.
///
/// # Panics
/// Panics if `d == 0`.
#[inline]
fn div_rem(n: u128, d: u128) -> (u128, u128) {
    match (u64::try_from(n), u64::try_from(d)) {
        (Ok(n64), Ok(d64)) => ((n64 / d64) as u128, (n64 % d64) as u128),
        _ => (n / d, n % d),
    }
}
132
/// Full 128×128→256-bit multiplication.
///
/// Returns `(high, low)` such that `a * b == high * 2^128 + low`. The operands
/// are split into 64-bit halves and the four partial products are combined
/// with explicit carry tracking.
#[inline]
const fn mul2(a: u128, b: u128) -> (u128, u128) {
    const LO_MASK: u128 = u64::MAX as u128;
    let (a_hi, a_lo) = (a >> 64, a & LO_MASK);
    let (b_hi, b_lo) = (b >> 64, b & LO_MASK);

    let lo_lo = a_lo * b_lo;
    let cross_ab = a_lo * b_hi;
    let cross_ba = a_hi * b_lo;
    let hi_hi = a_hi * b_hi;

    // The cross-term sum may carry past 128 bits; that carry is worth 2^192,
    // i.e. 2^64 in the high word.
    let (cross, cross_carry) = cross_ab.overflowing_add(cross_ba);
    // Low half of the cross term lands in the upper 64 bits of the low word.
    let (low, low_carry) = lo_lo.overflowing_add(cross << 64);
    let high = hi_hi + (cross >> 64) + ((cross_carry as u128) << 64) + low_carry as u128;
    (high, low)
}
144
145fn reduce_scale(n: u128, max_scale: u32) -> (u128, u32) {
146    if n == 0 {
147        return (0, 0);
148    }
149    if n < 1 << 96 {
150        reduce_scale_96(n, max_scale)
151    } else {
152        reduce_scale_full(n, max_scale)
153    }
154}
155
156fn reduce_scale_96(mut n: u128, max_scale: u32) -> (u128, u32) {
157    let mut left_scale = max_scale;
158    while n as u8 == 0 && left_scale >= 8 {
159        let (q, r) = div96_by32(n, 10000_0000);
160        if r != 0 {
161            break;
162        }
163        n = q;
164        left_scale -= 8;
165    }
166    if n & 0xF == 0 && left_scale >= 4 {
167        let (q, r) = div96_by32(n, 10000);
168        if r == 0 {
169            n = q;
170            left_scale -= 4;
171        }
172    }
173    if n & 0x3 == 0 && left_scale >= 2 {
174        let (q, r) = div96_by32(n, 100);
175        if r == 0 {
176            n = q;
177            left_scale -= 2;
178        }
179    }
180    if n & 0x1 == 0 && left_scale >= 1 {
181        let (q, r) = div96_by32(n, 10);
182        if r == 0 {
183            n = q;
184            left_scale -= 1;
185        }
186    }
187    (n, max_scale - left_scale)
188}
189
190fn reduce_scale_full(mut n: u128, max_scale: u32) -> (u128, u32) {
191    let mut left_scale = max_scale;
192    while n as u8 == 0 && left_scale >= 8 {
193        // SAFETY: n < MAX_MANTISSA, so fits in 127-bit
194        let q = unsafe { div_pow10::bit128::unchecked_div_single_r1b(n, 8) };
195        if (q as u32).wrapping_mul(10000_0000) != n as u32 {
196            break;
197        }
198        n = q;
199        left_scale -= 8;
200    }
201    if n & 0xF == 0 && left_scale >= 4 {
202        // SAFETY: n < MAX_MANTISSA, so fits in 127-bit
203        let q = unsafe { div_pow10::bit128::unchecked_div_single_r1b(n, 4) };
204        if (q as u32).wrapping_mul(10000) == n as u32 {
205            n = q;
206            left_scale -= 4;
207        }
208    }
209    if n & 0x3 == 0 && left_scale >= 2 {
210        // SAFETY: n < MAX_MANTISSA, so fits in 127-bit
211        let q = unsafe { div_pow10::bit128::unchecked_div_single_r1b(n, 2) };
212        if (q as u32).wrapping_mul(100) == n as u32 {
213            n = q;
214            left_scale -= 2;
215        }
216    }
217    if n & 0x1 == 0 && left_scale >= 1 {
218        // SAFETY: n < MAX_MANTISSA, so fits in 127-bit
219        let q = unsafe { div_pow10::bit128::unchecked_div_single_r1b(n, 1) };
220        if (q as u32).wrapping_mul(10) == n as u32 {
221            n = q;
222            left_scale -= 1;
223        }
224    }
225    (n, max_scale - left_scale)
226}
227
228#[cold]
229fn increase_scale(q: u128, r: u128, d: u128, scale: u32) -> Option<(u128, u128)> {
230    let (q2, r2) = div_rem(r.checked_mul_exp(scale)?, d);
231    let q = q.checked_mul_exp(scale)?.checked_add(q2)?;
232    if q <= u128::MAX_MATISSA {
233        Some((q, r2))
234    } else {
235        None
236    }
237}
238
239fn mul_with_sum_scale_full(a: u128, b: u128, sum_scale: u32) -> Option<(u128, u32)> {
240    let (high, low) = mul2(a, b);
241
242    if high == 0 {
243        // the production @low is in range [MAX_MATISSA / 2, u128::MAX]
244        return big128_with_sum_scale(low, sum_scale);
245    }
246
247    // check the mantissa @high..@low
248    //
249    // It's hard to calculate how many digits to shrink exactly, so here we
250    // get the ceiling value @clear_digits, which may be 1 more than need.
251    // The value may be MAX_SCALE+1 at biggest.
252    let bits = 128 + u128::META_BITS - high.leading_zeros();
253    let mut clear_digits = bits_to_digits(bits) + 1; // +1 for ceiling
254
255    // check the scale @sum_scale
256    if sum_scale > u128::MAX_SCALE {
257        // normal case
258        clear_digits = clear_digits.max(sum_scale - u128::MAX_SCALE);
259    } else if clear_digits > sum_scale {
260        // edge case, overflow, return None
261        if clear_digits == sum_scale + 1 {
262            // edge case in edge case. The overflow may be false because
263            // the @clear_digits maybe 1 more than need. So here we give
264            // it one more chance by decreasing it by 1 and check it later.
265            clear_digits -= 1;
266        } else {
267            return None;
268        }
269    }
270
271    // prepare for rounding
272    let (low, carry) = low.overflowing_add(get_exp(clear_digits) / 2);
273    let high = high + carry as u128;
274
275    // SAFETY: high > 10.pow(clear_digits) because of the META_BITS
276    let (q, _r) = unsafe { div_pow10::bit128::unchecked_div_double(high, low, clear_digits) };
277
278    // handle the edge case above
279    if q > u128::MAX_MATISSA {
280        debug_assert_eq!(clear_digits, sum_scale);
281        return None;
282    }
283
284    Some((q, sum_scale - clear_digits))
285}
286
287// reduce the @man to under MAX_MANTISSA
288#[cold]
289fn big128_with_sum_scale(man: u128, sum_scale: u32) -> Option<(u128, u32)> {
290    // check the mantissa @man, which is in range [MAX_MATISSA / 2, u128::MAX]
291    let mut clear_digits = if man > u128::MAX_MATISSA * 100 {
292        3
293    } else if man > u128::MAX_MATISSA * 10 {
294        2
295    } else if man > u128::MAX_MATISSA {
296        1
297    } else {
298        0
299    };
300
301    // check the scale @sum_scale
302    if sum_scale > u128::MAX_SCALE {
303        clear_digits = clear_digits.max(sum_scale - u128::MAX_SCALE);
304    } else if clear_digits > sum_scale {
305        return None; // overflow
306    }
307
308    // rescale if need
309    if clear_digits == 0 {
310        Some((man, sum_scale))
311    } else {
312        // can not call div_exp() because @man is larger than MAX_MANTASSI
313        //
314        // SAFETY: man > 10.pow(clear_digits) because of the META_BITS
315        let mut q = unsafe { div_pow10::bit128::unchecked_div_single(man, clear_digits) };
316        let exp = get_exp(clear_digits);
317        let r = man - q * exp;
318        if r >= exp / 2 {
319            q += 1;
320        }
321
322        Some((q, sum_scale - clear_digits))
323    }
324}
325
326fn div_with_scales_by32(n: u128, d: u32, max_scale: u32) -> (u128, u128, u32) {
327    let (q, r) = div128_by32(n, d);
328    if r == 0 {
329        return (q, 0, 0);
330    }
331
332    // find the biggest @act_scale that
333    // - r * 10.pow(act_scale) fits in u128 (128-bit)
334    // - q * 10.pow(act_scale) fits in mantissa (121-bit)
335    // - act_scale <= max_scale
336    let avail_bits = r.leading_zeros().min(q.leading_zeros() - u128::META_BITS);
337    let act_scale = bits_to_digits(avail_bits).min(max_scale);
338    if act_scale == 0 {
339        return (q, r, 0);
340    }
341
342    // We do the division once only, but no loop. So in worst cases (@d is big
343    // and @n < @d), the @q may only have 96 significant bits. Although it is
344    // not fill the 121 bits, it should be enough.
345    let (q2, r2) = div128_by32(r.mul_exp(act_scale), d);
346    (q.mul_exp(act_scale) + q2, r2, act_scale)
347}
348
349fn div_with_scales_full(n: u128, d: u128, max_scale: u32) -> (u128, u128, u32) {
350    let (mut q, mut r) = div_rem(n, d);
351
352    // long division
353    let mut act_scale = 0;
354    while r != 0 {
355        let avail_bits = r.leading_zeros().min(q.leading_zeros() - u128::META_BITS);
356        let scale = bits_to_digits(avail_bits).min(max_scale - act_scale);
357        if scale == 0 {
358            break;
359        }
360
361        r = r.mul_exp(scale);
362        let q2 = r / d;
363        r -= q2 * d;
364
365        q = q.mul_exp(scale) + q2;
366        act_scale += scale;
367    }
368
369    (q, r, act_scale)
370}
371
/// Divides a dividend of at most 96 bits by a 32-bit divisor.
///
/// Returns `(quotient, remainder)`. The dividend is processed as a 64-bit
/// head (bits 32..96) followed by a 32-bit tail, so each step is a plain
/// `u64` division.
///
/// # Panics
/// Panics if `b == 0`. In debug builds, also panics if `a` does not fit in
/// 96 bits. Without that check, the `high << 32` below would silently discard
/// the upper bits and return a wrong quotient.
fn div96_by32(a: u128, b: u32) -> (u128, u128) {
    debug_assert!(a >> 96 == 0, "div96_by32: dividend must fit in 96 bits");
    let b = b as u64;

    let high = (a >> 64) as u64;
    let low = a as u64;

    // Bits 32..96 of the dividend; fits in u64 because high < 2^32.
    let n = high << 32 | low >> 32;
    let q1 = n / b;
    let r = n % b;

    // Remainder followed by bits 0..32; r < b <= u32::MAX, so no overflow.
    let n = r << 32 | low & u32::MAX as u64;
    let q2 = n / b;
    let r = n % b;

    (((q1 as u128) << 32) | q2 as u128, r as u128)
}
388
/// Divides a full 128-bit dividend by a 32-bit divisor.
///
/// Returns `(quotient, remainder)`. Schoolbook division over a 64-bit head
/// and two 32-bit limbs. Each intermediate `(rem << 32) | limb` fits in `u64`
/// because `rem < divisor <= u32::MAX`.
///
/// # Panics
/// Panics if `b == 0`.
fn div128_by32(a: u128, b: u32) -> (u128, u128) {
    let divisor = b as u64;
    let top = (a >> 64) as u64;
    let mid = (a >> 32) as u32 as u64;
    let bottom = a as u32 as u64;

    let (q_top, rem) = (top / divisor, top % divisor);

    let chunk = (rem << 32) | mid;
    let (q_mid, rem) = (chunk / divisor, chunk % divisor);

    let chunk = (rem << 32) | bottom;
    let (q_bottom, rem) = (chunk / divisor, chunk % divisor);

    let quotient = ((q_top as u128) << 64) | ((q_mid as u128) << 32) | q_bottom as u128;
    (quotient, rem as u128)
}
411
/// Powers of ten `10^0 ..= 10^38`, indexed by exponent. `10^38` is the
/// largest power of ten that fits in a `u128`.
const ALL_EXPS: [u128; 39] = {
    let mut table = [1_u128; 39];
    let mut i = 1;
    while i < 39 {
        table[i] = table[i - 1] * 10;
        i += 1;
    }
    table
};