Skip to main content

num_modular/
double.rs

1//! This module implements a double width integer type based on the largest built-in integer (u128)
2//! Part of the optimization comes from `ethnum` and `zkp-u256` crates.
3
4use core::ops::*;
5
/// The widest unsigned integer type built into the language (at present [u128]).
#[allow(non_camel_case_types)]
pub type umax = u128;

/// The widest signed integer type built into the language (at present [i128]).
#[allow(non_camel_case_types)]
pub type imax = i128;

/// Number of bits in one half of a [umax] word.
const HALF_BITS: u32 = umax::BITS >> 1;
15
16// Split umax into hi and lo parts. Tt's critical to use inline here
17#[inline(always)]
18const fn split(v: umax) -> (umax, umax) {
19    (v >> HALF_BITS, v & (umax::MAX >> HALF_BITS))
20}
21
22#[inline(always)]
23const fn div_rem(n: umax, d: umax) -> (umax, umax) {
24    (n / d, n % d)
25}
26
27#[allow(non_camel_case_types)]
28#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
29/// A double width integer type based on the largest built-in integer type [umax] (currently [u128]), and
30/// to support double-width operations on it is the only goal for this type.
31///
32/// Although it can be regarded as u256, it's not as feature-rich as in other crates
33/// since it's only designed to support this crate and few other crates (will be noted in comments).
34pub struct udouble {
35    /// Most significant part
36    pub hi: umax,
37    /// Least significant part
38    pub lo: umax,
39}
40
41impl udouble {
42    pub const MAX: Self = Self {
43        lo: umax::MAX,
44        hi: umax::MAX,
45    };
46
47    //> (used in u128::addm)
48    #[inline]
49    pub const fn widening_add(lhs: umax, rhs: umax) -> Self {
50        let (sum, carry) = lhs.overflowing_add(rhs);
51        udouble {
52            hi: carry as umax,
53            lo: sum,
54        }
55    }
56
57    /// Calculate multiplication of two [umax] integers with result represented in double width integer
58    // equivalent to umul_ppmm, can be implemented efficiently with carrying_mul and widening_mul implemented (rust#85532)
59    //> (used in u128::mulm, MersenneInt, Montgomery::<u128>::{reduce, mul}, num-order::NumHash)
60    #[inline]
61    pub const fn widening_mul(lhs: umax, rhs: umax) -> Self {
62        let ((x1, x0), (y1, y0)) = (split(lhs), split(rhs));
63
64        let z2 = x1 * y1;
65        let (c0, z0) = split(x0 * y0); // c0 <= umax::MAX - 1
66        let (c1, z1) = split(x1 * y0 + c0);
67        let z2 = z2 + c1;
68        let (c1, z1) = split(x0 * y1 + z1);
69        Self {
70            hi: z2 + c1,
71            lo: z0 | z1 << HALF_BITS,
72        }
73    }
74
75    /// Optimized squaring function for [umax] integers
76    //> (used in Montgomery::<u128>::{square})
77    #[inline]
78    pub const fn widening_square(x: umax) -> Self {
79        // the algorithm here is basically the same as widening_mul
80        let (x1, x0) = split(x);
81
82        let z2 = x1 * x1;
83        let m = x1 * x0;
84        let (c0, z0) = split(x0 * x0);
85        let (c1, z1) = split(m + c0);
86        let z2 = z2 + c1;
87        let (c1, z1) = split(m + z1);
88        Self {
89            hi: z2 + c1,
90            lo: z0 | z1 << HALF_BITS,
91        }
92    }
93
94    //> (used in Montgomery::<u128>::reduce)
95    #[inline]
96    pub const fn overflowing_add(&self, rhs: Self) -> (Self, bool) {
97        let (lo, carry) = self.lo.overflowing_add(rhs.lo);
98        let (hi, of1) = self.hi.overflowing_add(rhs.hi);
99        let (hi, of2) = hi.overflowing_add(carry as umax);
100        (Self { lo, hi }, of1 || of2)
101    }
102
103    // double by double multiplication, listed here in case of future use
104    #[allow(dead_code)]
105    fn overflowing_mul(&self, rhs: Self) -> (Self, bool) {
106        let c2 = self.hi != 0 && rhs.hi != 0;
107        let Self { lo: z0, hi: c0 } = Self::widening_mul(self.lo, rhs.lo);
108        let (z1x, c1x) = umax::overflowing_mul(self.lo, rhs.hi);
109        let (z1y, c1y) = umax::overflowing_mul(self.hi, rhs.lo);
110        let (z1z, c1z) = umax::overflowing_add(z1x, z1y);
111        let (z1, c1) = z1z.overflowing_add(c0);
112        (Self { hi: z1, lo: z0 }, c1x | c1y | c1z | c1 | c2)
113    }
114
115    /// Multiplication of double width and single width
116    //> (used in num-order:NumHash)
117    #[inline]
118    pub const fn overflowing_mul1(&self, rhs: umax) -> (Self, bool) {
119        let Self { lo: z0, hi: c0 } = Self::widening_mul(self.lo, rhs);
120        let (z1, c1) = self.hi.overflowing_mul(rhs);
121        let (z1, cs1) = z1.overflowing_add(c0);
122        (Self { hi: z1, lo: z0 }, c1 | cs1)
123    }
124
125    /// Multiplication of double width and single width
126    //> (used in Self::mul::<umax>)
127    #[inline]
128    pub fn checked_mul1(&self, rhs: umax) -> Option<Self> {
129        let Self { lo: z0, hi: c0 } = Self::widening_mul(self.lo, rhs);
130        let z1 = self.hi.checked_mul(rhs)?.checked_add(c0)?;
131        Some(Self { hi: z1, lo: z0 })
132    }
133
134    //> (used in num-order::NumHash)
135    #[inline]
136    pub fn checked_shl(self, rhs: u32) -> Option<Self> {
137        if rhs < umax::BITS * 2 {
138            Some(self << rhs)
139        } else {
140            None
141        }
142    }
143
144    //> (not used yet)
145    #[inline]
146    pub fn checked_shr(self, rhs: u32) -> Option<Self> {
147        if rhs < umax::BITS * 2 {
148            Some(self >> rhs)
149        } else {
150            None
151        }
152    }
153}
154
155impl From<umax> for udouble {
156    #[inline]
157    fn from(v: umax) -> Self {
158        Self { lo: v, hi: 0 }
159    }
160}
161
162impl Add for udouble {
163    type Output = udouble;
164
165    // equivalent to add_ssaaaa
166    #[inline]
167    fn add(self, rhs: Self) -> Self::Output {
168        let (lo, carry) = self.lo.overflowing_add(rhs.lo);
169        let hi = self.hi + rhs.hi + carry as umax;
170        Self { lo, hi }
171    }
172}
173//> (used in Self::div_rem)
174impl Add<umax> for udouble {
175    type Output = udouble;
176    #[inline]
177    fn add(self, rhs: umax) -> Self::Output {
178        let (lo, carry) = self.lo.overflowing_add(rhs);
179        let hi = if carry { self.hi + 1 } else { self.hi };
180        Self { lo, hi }
181    }
182}
183impl AddAssign for udouble {
184    #[inline]
185    fn add_assign(&mut self, rhs: Self) {
186        let (lo, carry) = self.lo.overflowing_add(rhs.lo);
187        self.lo = lo;
188        self.hi += rhs.hi + carry as umax;
189    }
190}
191impl AddAssign<umax> for udouble {
192    #[inline]
193    fn add_assign(&mut self, rhs: umax) {
194        let (lo, carry) = self.lo.overflowing_add(rhs);
195        self.lo = lo;
196        if carry {
197            self.hi += 1
198        }
199    }
200}
201
202//> (used in test of Add)
203impl Sub for udouble {
204    type Output = Self;
205    #[inline]
206    fn sub(self, rhs: Self) -> Self::Output {
207        let carry = self.lo < rhs.lo;
208        let lo = self.lo.wrapping_sub(rhs.lo);
209        let hi = self.hi - rhs.hi - carry as umax;
210        Self { lo, hi }
211    }
212}
213impl Sub<umax> for udouble {
214    type Output = Self;
215    #[inline]
216    fn sub(self, rhs: umax) -> Self::Output {
217        let carry = self.lo < rhs;
218        let lo = self.lo.wrapping_sub(rhs);
219        let hi = if carry { self.hi - 1 } else { self.hi };
220        Self { lo, hi }
221    }
222}
223//> (used in test of AddAssign)
224impl SubAssign for udouble {
225    #[inline]
226    fn sub_assign(&mut self, rhs: Self) {
227        let carry = self.lo < rhs.lo;
228        self.lo = self.lo.wrapping_sub(rhs.lo);
229        self.hi -= rhs.hi + carry as umax;
230    }
231}
232impl SubAssign<umax> for udouble {
233    #[inline]
234    fn sub_assign(&mut self, rhs: umax) {
235        let carry = self.lo < rhs;
236        self.lo = self.lo.wrapping_sub(rhs);
237        if carry {
238            self.hi -= 1;
239        }
240    }
241}
242
243macro_rules! impl_sh_ops {
244    ($t:ty) => {
245        //> (used in Self::checked_shl)
246        impl Shl<$t> for udouble {
247            type Output = Self;
248            #[inline]
249            fn shl(self, rhs: $t) -> Self::Output {
250                match rhs {
251                    0 => self, // avoid shifting by full bits, which is UB
252                    s if s >= umax::BITS as $t => Self {
253                        hi: self.lo << (s - umax::BITS as $t),
254                        lo: 0,
255                    },
256                    s => Self {
257                        lo: self.lo << s,
258                        hi: (self.hi << s) | (self.lo >> (umax::BITS as $t - s)),
259                    },
260                }
261            }
262        }
263        //> (not used yet)
264        impl ShlAssign<$t> for udouble {
265            #[inline]
266            fn shl_assign(&mut self, rhs: $t) {
267                match rhs {
268                    0 => {}
269                    s if s >= umax::BITS as $t => {
270                        self.hi = self.lo << (s - umax::BITS as $t);
271                        self.lo = 0;
272                    }
273                    s => {
274                        self.hi <<= s;
275                        self.hi |= self.lo >> (umax::BITS as $t - s);
276                        self.lo <<= s;
277                    }
278                }
279            }
280        }
281        //> (used in Self::checked_shr)
282        impl Shr<$t> for udouble {
283            type Output = Self;
284            #[inline]
285            fn shr(self, rhs: $t) -> Self::Output {
286                match rhs {
287                    0 => self,
288                    s if s >= umax::BITS as $t => Self {
289                        lo: self.hi >> (rhs - umax::BITS as $t),
290                        hi: 0,
291                    },
292                    s => Self {
293                        hi: self.hi >> s,
294                        lo: (self.lo >> s) | (self.hi << (umax::BITS as $t - s)),
295                    },
296                }
297            }
298        }
299        //> (not used yet)
300        impl ShrAssign<$t> for udouble {
301            #[inline]
302            fn shr_assign(&mut self, rhs: $t) {
303                match rhs {
304                    0 => {}
305                    s if s >= umax::BITS as $t => {
306                        self.lo = self.hi >> (rhs - umax::BITS as $t);
307                        self.hi = 0;
308                    }
309                    s => {
310                        self.lo >>= s;
311                        self.lo |= self.hi << (umax::BITS as $t - s);
312                        self.hi >>= s;
313                    }
314                }
315            }
316        }
317    };
318}
319
320// only implement most useful ones, so that we don't need to optimize so many variants
321impl_sh_ops!(u8);
322impl_sh_ops!(u16);
323impl_sh_ops!(u32);
324
325//> (not used yet)
326impl BitAnd for udouble {
327    type Output = Self;
328    #[inline]
329    fn bitand(self, rhs: Self) -> Self::Output {
330        Self {
331            lo: self.lo & rhs.lo,
332            hi: self.hi & rhs.hi,
333        }
334    }
335}
336//> (not used yet)
337impl BitAndAssign for udouble {
338    #[inline]
339    fn bitand_assign(&mut self, rhs: Self) {
340        self.lo &= rhs.lo;
341        self.hi &= rhs.hi;
342    }
343}
344//> (not used yet)
345impl BitOr for udouble {
346    type Output = Self;
347    #[inline]
348    fn bitor(self, rhs: Self) -> Self::Output {
349        Self {
350            lo: self.lo | rhs.lo,
351            hi: self.hi | rhs.hi,
352        }
353    }
354}
355//> (not used yet)
356impl BitOrAssign for udouble {
357    #[inline]
358    fn bitor_assign(&mut self, rhs: Self) {
359        self.lo |= rhs.lo;
360        self.hi |= rhs.hi;
361    }
362}
363//> (not used yet)
364impl BitXor for udouble {
365    type Output = Self;
366    #[inline]
367    fn bitxor(self, rhs: Self) -> Self::Output {
368        Self {
369            lo: self.lo ^ rhs.lo,
370            hi: self.hi ^ rhs.hi,
371        }
372    }
373}
374//> (not used yet)
375impl BitXorAssign for udouble {
376    #[inline]
377    fn bitxor_assign(&mut self, rhs: Self) {
378        self.lo ^= rhs.lo;
379        self.hi ^= rhs.hi;
380    }
381}
382//> (not used yet)
383impl Not for udouble {
384    type Output = Self;
385    #[inline]
386    fn not(self) -> Self::Output {
387        Self {
388            lo: !self.lo,
389            hi: !self.hi,
390        }
391    }
392}
393
394impl udouble {
395    //> (used in Self::div_rem)
396    #[inline]
397    pub const fn leading_zeros(self) -> u32 {
398        if self.hi == 0 {
399            self.lo.leading_zeros() + umax::BITS
400        } else {
401            self.hi.leading_zeros()
402        }
403    }
404
405    // double by double division (long division), it's not the most efficient algorithm.
406    // listed here in case of future use
407    #[allow(dead_code)]
408    fn div_rem_2by2(self, other: Self) -> (Self, Self) {
409        let mut n = self; // numerator
410        let mut d = other; // denominator
411        let mut q = Self { lo: 0, hi: 0 }; // quotient
412
413        let nbits = (2 * umax::BITS - n.leading_zeros()) as u16; // assuming umax = u128
414        let dbits = (2 * umax::BITS - d.leading_zeros()) as u16;
415        assert!(dbits != 0, "division by zero");
416
417        // Early return in case we are dividing by a larger number than us
418        if nbits < dbits {
419            return (q, n);
420        }
421
422        // Bitwise long division
423        let mut shift = nbits - dbits;
424        d <<= shift;
425        loop {
426            if n >= d {
427                q += 1;
428                n -= d;
429            }
430            if shift == 0 {
431                break;
432            }
433
434            d >>= 1u8;
435            q <<= 1u8;
436            shift -= 1;
437        }
438        (q, n)
439    }
440
441    // double by single to single division.
442    // equivalent to `udiv_qrnnd` in C or `divq` in assembly.
443    //> (used in Self::{div, rem}::<umax>)
444    fn div_rem_2by1(self, other: umax) -> (umax, umax) {
445        // the following algorithm comes from `ethnum` crate
446        const B: umax = 1 << HALF_BITS; // number base (64 bits)
447
448        // Normalize the divisor.
449        let s = other.leading_zeros();
450        let (n, d) = (self << s, other << s); // numerator, denominator
451        let (d1, d0) = split(d);
452        let (n1, n0) = split(n.lo); // split lower part of dividend
453
454        // Compute the first quotient digit q1.
455        let (mut q1, mut rhat) = div_rem(n.hi, d1);
456
457        // q1 has at most error 2. No more than 2 iterations.
458        while q1 >= B || q1 * d0 > B * rhat + n1 {
459            q1 -= 1;
460            rhat += d1;
461            if rhat >= B {
462                break;
463            }
464        }
465
466        let r21 =
467            n.hi.wrapping_mul(B)
468                .wrapping_add(n1)
469                .wrapping_sub(q1.wrapping_mul(d));
470
471        // Compute the second quotient digit q0.
472        let (mut q0, mut rhat) = div_rem(r21, d1);
473
474        // q0 has at most error 2. No more than 2 iterations.
475        while q0 >= B || q0 * d0 > B * rhat + n0 {
476            q0 -= 1;
477            rhat += d1;
478            if rhat >= B {
479                break;
480            }
481        }
482
483        let r = (r21
484            .wrapping_mul(B)
485            .wrapping_add(n0)
486            .wrapping_sub(q0.wrapping_mul(d)))
487            >> s;
488        let q = q1 * B + q0;
489        (q, r)
490    }
491}
492
493impl Mul<umax> for udouble {
494    type Output = Self;
495    #[inline]
496    fn mul(self, rhs: umax) -> Self::Output {
497        self.checked_mul1(rhs).expect("multiplication overflow!")
498    }
499}
500
501impl Div<umax> for udouble {
502    type Output = Self;
503    #[inline]
504    fn div(self, rhs: umax) -> Self::Output {
505        // self.div_rem(rhs.into()).0
506        if self.hi < rhs {
507            // The result fits in 128 bits.
508            Self {
509                lo: self.div_rem_2by1(rhs).0,
510                hi: 0,
511            }
512        } else {
513            let (q, r) = div_rem(self.hi, rhs);
514            Self {
515                lo: Self { lo: self.lo, hi: r }.div_rem_2by1(rhs).0,
516                hi: q,
517            }
518        }
519    }
520}
521
522//> (used in Montgomery::<u128>::transform)
523impl Rem<umax> for udouble {
524    type Output = umax;
525    #[inline]
526    fn rem(self, rhs: umax) -> Self::Output {
527        if self.hi < rhs {
528            // The result fits in 128 bits.
529            self.div_rem_2by1(rhs).1
530        } else {
531            Self {
532                lo: self.lo,
533                hi: self.hi % rhs,
534            }
535            .div_rem_2by1(rhs)
536            .1
537        }
538    }
539}
540
#[cfg(test)]
mod tests {
    use super::*;
    use rand::random;

    #[test]
    fn test_construction() {
        // from widening operators
        assert_eq!(udouble { hi: 0, lo: 2 }, udouble::widening_add(1, 1));
        assert_eq!(
            udouble {
                hi: 1,
                lo: umax::MAX - 1
            },
            udouble::widening_add(umax::MAX, umax::MAX)
        );

        assert_eq!(udouble { hi: 0, lo: 1 }, udouble::widening_mul(1, 1));
        assert_eq!(udouble { hi: 0, lo: 1 }, udouble::widening_square(1));
        assert_eq!(
            udouble { hi: 1 << 32, lo: 0 },
            udouble::widening_mul(1 << 80, 1 << 80)
        );
        assert_eq!(
            udouble { hi: 1 << 32, lo: 0 },
            udouble::widening_square(1 << 80)
        );
        assert_eq!(
            udouble {
                hi: 1 << 32,
                lo: 2 << 120 | 1 << 80
            },
            udouble::widening_mul(1 << 80 | 1 << 40, 1 << 80 | 1 << 40)
        );
        assert_eq!(
            udouble {
                hi: 1 << 32,
                lo: 2 << 120 | 1 << 80
            },
            udouble::widening_square(1 << 80 | 1 << 40)
        );
        assert_eq!(
            udouble {
                hi: umax::MAX - 1,
                lo: 1
            },
            udouble::widening_mul(umax::MAX, umax::MAX)
        );
        assert_eq!(
            udouble {
                hi: umax::MAX - 1,
                lo: 1
            },
            udouble::widening_square(umax::MAX)
        );
    }

    #[test]
    fn test_ops() {
        const ONE: udouble = udouble { hi: 0, lo: 1 };
        const TWO: udouble = udouble { hi: 0, lo: 2 };
        const MAX: udouble = udouble {
            hi: 0,
            lo: umax::MAX,
        };
        const ONEZERO: udouble = udouble { hi: 1, lo: 0 };
        const ONEMAX: udouble = udouble {
            hi: 1,
            lo: umax::MAX,
        };
        const TWOZERO: udouble = udouble { hi: 2, lo: 0 };

        assert_eq!(ONE + MAX, ONEZERO);
        assert_eq!(ONE + ONEMAX, TWOZERO);
        assert_eq!(ONEZERO - ONE, MAX);
        assert_eq!(ONEZERO - MAX, ONE);
        assert_eq!(TWOZERO - ONE, ONEMAX);
        assert_eq!(TWOZERO - ONEMAX, ONE);

        assert_eq!(ONE << umax::BITS, ONEZERO);
        assert_eq!((MAX << 1u8) + 1, ONEMAX);
        assert_eq!(
            ONE << 200u8,
            udouble {
                lo: 0,
                hi: 1 << (200 - umax::BITS)
            }
        );
        assert_eq!(ONEZERO >> umax::BITS, ONE);
        assert_eq!(ONEMAX >> 1u8, MAX);

        assert_eq!(ONE * MAX.lo, MAX);
        assert_eq!(ONEMAX * ONE.lo, ONEMAX);
        assert_eq!(ONEMAX * TWO.lo, ONEMAX + ONEMAX);
        assert_eq!(MAX / ONE.lo, MAX);
        assert_eq!(MAX / MAX.lo, ONE);
        assert_eq!(ONE / MAX.lo, udouble { lo: 0, hi: 0 });
        assert_eq!(ONEMAX / ONE.lo, ONEMAX);
        assert_eq!(ONEMAX / MAX.lo, TWO);
        assert_eq!(ONEMAX / TWO.lo, MAX);
        assert_eq!(ONE % MAX.lo, 1);
        assert_eq!(TWO % MAX.lo, 2);
        assert_eq!(ONEMAX % MAX.lo, 1);
        assert_eq!(ONEMAX % TWO.lo, 1);

        assert_eq!(ONEMAX.checked_mul1(MAX.lo), None);
        assert_eq!(TWOZERO.checked_mul1(MAX.lo), None);
    }

    #[test]
    fn test_assign_ops() {
        for _ in 0..10 {
            let x = udouble {
                hi: random::<u32>() as umax,
                lo: random(),
            };
            let y = udouble {
                hi: random::<u32>() as umax,
                lo: random(),
            };
            let mut z = x;

            z += y;
            assert_eq!(z, x + y);
            z -= y;
            assert_eq!(z, x);
        }
    }

    /// Division and remainder must satisfy `x == (x / d) * d + x % d` with `x % d < d`.
    /// Random full-width numerators hit both branches of `Div` (hi < d and hi >= d), and
    /// random divisor magnitudes exercise the normalization inside `div_rem_2by1`.
    #[test]
    fn test_div_rem_consistency() {
        for _ in 0..100 {
            let x = udouble {
                hi: random(),
                lo: random(),
            };
            let d: umax = random::<umax>() >> (random::<u32>() % umax::BITS);
            let d = d.max(1);
            let q = x / d;
            let r = x % d;
            assert!(r < d);
            // q * d <= x, so neither the multiplication nor the addition can overflow.
            assert_eq!(q * d + r, x);
        }
    }

    #[test]
    fn test_widening_square_matches_mul() {
        for _ in 0..100 {
            let x: umax = random();
            assert_eq!(udouble::widening_square(x), udouble::widening_mul(x, x));
        }
    }

    #[test]
    fn test_checked_shifts_and_leading_zeros() {
        let one = udouble { hi: 0, lo: 1 };
        assert_eq!(
            one.checked_shl(2 * umax::BITS - 1),
            Some(udouble {
                hi: 1 << (umax::BITS - 1),
                lo: 0
            })
        );
        assert_eq!(one.checked_shl(2 * umax::BITS), None);
        assert_eq!(udouble::MAX.checked_shr(2 * umax::BITS - 1), Some(one));
        assert_eq!(udouble::MAX.checked_shr(2 * umax::BITS), None);
        assert_eq!(one.leading_zeros(), 2 * umax::BITS - 1);
        assert_eq!(udouble::MAX.leading_zeros(), 0);
        assert_eq!(udouble { hi: 0, lo: 0 }.leading_zeros(), 2 * umax::BITS);
    }
}