Skip to main content

cryptography/public_key/
bigint.rs

1//! A small pure-Rust bigint foundation for public-key primitives.
2//!
//! The representation uses little-endian `u64` limbs because the surrounding
//! algorithms are naturally word-oriented. This is intentionally simple:
//! schoolbook multiplication (with a Karatsuba split for large, balanced
//! operands) and bitwise long division are easy to audit and track the
//! public-key formulas directly, while keeping the public-key layer fully in
//! Rust with no external arithmetic backend.
8//!
9//! Local references for planned multiplication-kernel upgrades:
10//! - `pubs/comba-1990-exponentiation-cryptosystems-on-the-ibm-pc.pdf`
11//! - `pubs/karatsuba-ofman-1963-multiplication-of-multidigit-numbers-on-automata.pdf`
12
13use core::cmp::Ordering;
14
// Minimum limb count of the *shorter* operand before `should_use_karatsuba`
// picks the recursive split. Heuristic crossover where the recursive split
// starts beating schoolbook in this pure-Rust implementation on our benchmark
// hardware.
const KARATSUBA_THRESHOLD_LIMBS: usize = 32;
// Largest allowed `longer_len / shorter_len` ratio for the Karatsuba path.
// Limit highly lopsided splits; beyond this ratio the extra recursion/temporary
// cost usually outweighs Karatsuba's multiplication count reduction.
const KARATSUBA_MAX_IMBALANCE: usize = 2;
21
/// Sign of a [`BigInt`].
///
/// Zero has its own variant, so the sign is one of exactly three states
/// rather than a boolean "negative" flag.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Sign {
    /// Strictly positive value.
    Positive,
    /// Strictly negative value.
    Negative,
    /// Zero.
    Zero,
}
32
/// Unsigned multiprecision integer stored as little-endian `u64` limbs.
///
/// Invariant: the limb vector is normalized. Zero is the empty vector and a
/// non-zero value never ends in a zero limb; the `Ord` implementation relies
/// on this.
// NOTE(review): the derived `PartialEq` and `Debug` operate directly on the
// limbs. Confirm that is acceptable wherever a `BigUint` holds secret material.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct BigUint {
    // Limb 0 holds the least-significant 64 bits.
    limbs: Vec<u64>,
}
38
/// Signed multiprecision integer used by later public-key helpers.
///
/// Stored as a [`Sign`] together with an unsigned [`BigUint`] magnitude.
// NOTE(review): presumably `Sign::Zero` is only ever paired with a zero
// magnitude; the code maintaining that is outside this section - confirm.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct BigInt {
    // Sign of the value; see `Sign`.
    sign: Sign,
    // Absolute value.
    magnitude: BigUint,
}
45
/// Montgomery arithmetic context for a fixed odd modulus.
///
/// Public-key schemes spend most of their time doing repeated modular
/// multiplication under one long-lived odd modulus. Precomputing the
/// Montgomery constants once avoids paying the setup cost on every multiply
/// while keeping the scheme code readable.
///
/// Built with [`MontgomeryCtx::new`], which returns `None` for a zero or even
/// modulus.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct MontgomeryCtx {
    // The odd modulus `n`; its limb count fixes the Montgomery radix `R`.
    modulus: BigUint,
    // n0_inv = -n^{-1} mod 2^64 (Montgomery reduction coefficient).
    n0_inv: u64,
    // R^2 mod n with R = 2^(64 * limbs(n)): conversion factor into Montgomery form.
    r2_mod: BigUint,
    // 1 encoded in Montgomery form, i.e. R mod n.
    one_mont: BigUint,
}
62
63impl Ord for BigUint {
64    fn cmp(&self, other: &Self) -> Ordering {
65        // Ordering assumes normalized limb vectors (no most-significant zero
66        // limbs). All constructors/arithmetic paths call `normalize()`.
67        debug_assert!(
68            self.limbs.last().copied() != Some(0),
69            "BigUint invariant: no leading zero limbs",
70        );
71        debug_assert!(
72            other.limbs.last().copied() != Some(0),
73            "BigUint invariant: no leading zero limbs",
74        );
75        match self.limbs.len().cmp(&other.limbs.len()) {
76            Ordering::Equal => {}
77            ord => return ord,
78        }
79
80        for (&lhs, &rhs) in self.limbs.iter().rev().zip(other.limbs.iter().rev()) {
81            match lhs.cmp(&rhs) {
82                Ordering::Equal => {}
83                ord => return ord,
84            }
85        }
86
87        Ordering::Equal
88    }
89}
90
91impl PartialOrd for BigUint {
92    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
93        Some(self.cmp(other))
94    }
95}
96
97impl BigUint {
98    /// Construct zero.
99    #[must_use]
100    pub fn zero() -> Self {
101        Self { limbs: Vec::new() }
102    }
103
104    /// Construct one.
105    #[must_use]
106    pub fn one() -> Self {
107        Self { limbs: vec![1] }
108    }
109
110    /// Construct from a machine word.
111    #[must_use]
112    pub fn from_u64(value: u64) -> Self {
113        if value == 0 {
114            Self::zero()
115        } else {
116            Self { limbs: vec![value] }
117        }
118    }
119
120    /// Construct from a `u128`.
121    ///
122    /// # Panics
123    ///
124    /// Panics only if the internal limb split invariants fail unexpectedly.
125    #[must_use]
126    pub fn from_u128(value: u128) -> Self {
127        if value == 0 {
128            return Self::zero();
129        }
130
131        let lo =
132            u64::try_from(value & u128::from(u64::MAX)).expect("low 64 bits always fit into u64");
133        let hi = u64::try_from(value >> 64).expect("high 64 bits always fit into u64");
134        if hi == 0 {
135            Self { limbs: vec![lo] }
136        } else {
137            Self {
138                limbs: vec![lo, hi],
139            }
140        }
141    }
142
143    /// Decode big-endian bytes.
144    ///
145    /// Internally, limb 0 always stores the least-significant 64 bits.
146    #[must_use]
147    pub fn from_be_bytes(bytes: &[u8]) -> Self {
148        if bytes.is_empty() {
149            return Self::zero();
150        }
151
152        let mut limbs = Vec::with_capacity(bytes.len().div_ceil(8));
153        let mut acc = 0u64;
154        let mut shift = 0u32;
155
156        // Walk bytes from least-significant (last byte of the big-endian input)
157        // to most-significant, packing eight bytes at a time into a 64-bit limb.
158        // When `shift` reaches 64, the current limb is full — push it and start
159        // the next one.  Any remaining bytes at the end form a partial limb.
160        for &byte in bytes.iter().rev() {
161            acc |= u64::from(byte) << shift;
162            shift += 8;
163            if shift == 64 {
164                limbs.push(acc);
165                acc = 0;
166                shift = 0;
167            }
168        }
169
170        if shift != 0 {
171            limbs.push(acc);
172        }
173
174        let mut out = Self { limbs };
175        out.normalize();
176        out
177    }
178
179    /// Encode as big-endian bytes without leading zero bytes.
180    ///
181    /// Internally, limb 0 stores the least-significant 64 bits, so encoding
182    /// walks the limbs in reverse order and strips only the leading zero bytes
183    /// introduced by the fixed-width `u64` representation.
184    ///
185    /// # Panics
186    ///
187    /// Panics only if the internal representation is corrupt and a non-zero
188    /// value contains no non-zero bytes.
189    #[must_use]
190    pub fn to_be_bytes(&self) -> Vec<u8> {
191        if self.is_zero() {
192            return vec![0];
193        }
194
195        let mut out = Vec::with_capacity(self.limbs.len() * 8);
196        for &limb in self.limbs.iter().rev() {
197            out.extend_from_slice(&limb.to_be_bytes());
198        }
199
200        let first_nonzero = out
201            .iter()
202            .position(|&byte| byte != 0)
203            .expect("non-zero bigint must encode to at least one non-zero byte");
204        out.drain(0..first_nonzero);
205        out
206    }
207
208    /// Return whether the value is zero.
209    #[must_use]
210    pub fn is_zero(&self) -> bool {
211        self.limbs.is_empty()
212    }
213
214    /// Return whether the value is odd.
215    #[must_use]
216    pub fn is_odd(&self) -> bool {
217        !self.is_zero() && (self.limbs[0] & 1) == 1
218    }
219
220    /// Return whether the value is exactly one.
221    #[must_use]
222    pub fn is_one(&self) -> bool {
223        self.limbs.len() == 1 && self.limbs[0] == 1
224    }
225
226    /// Number of significant bits.
227    ///
228    /// # Panics
229    ///
230    /// Panics only if the internal representation is corrupt and a non-zero
231    /// value contains no limbs.
232    #[must_use]
233    pub fn bits(&self) -> usize {
234        if self.is_zero() {
235            return 0;
236        }
237
238        let top = *self
239            .limbs
240            .last()
241            .expect("non-zero bigint has at least one limb");
242        let top_bits = (u64::BITS - top.leading_zeros()) as usize;
243        (self.limbs.len() - 1) * 64 + top_bits
244    }
245
246    /// Integer square root: the largest `r` such that `r^2 <= self`.
247    #[must_use]
248    pub fn sqrt_floor(&self) -> Self {
249        if self.is_zero() {
250            return Self::zero();
251        }
252        if self.is_one() {
253            return Self::one();
254        }
255
256        let mut low = Self::one();
257        let mut high = Self::zero();
258        // Choose `high` so the search starts with `low^2 <= self < high^2`.
259        // Setting bit `ceil(bits(self) / 2)` makes
260        // `high = 2^ceil(bits(self)/2)`, so `high^2 >= 2^bits(self) > self`.
261        // That gives the binary search a proved upper bound from the start.
262        high.set_bit(self.bits().div_ceil(2));
263
264        while {
265            let next_low = low.add_ref(&Self::one());
266            next_low < high
267        } {
268            let mut middle = low.add_ref(&high);
269            middle.shr1();
270            let square = middle.square_ref();
271            if square <= *self {
272                low = middle;
273            } else {
274                high = middle;
275            }
276        }
277
278        low
279    }
280
281    /// Test bit `index`.
282    #[must_use]
283    pub fn bit(&self, index: usize) -> bool {
284        let limb = index / 64;
285        let shift = index % 64;
286        if limb >= self.limbs.len() {
287            false
288        } else {
289            ((self.limbs[limb] >> shift) & 1) == 1
290        }
291    }
292
293    /// Set bit `index`.
294    pub fn set_bit(&mut self, index: usize) {
295        let limb = index / 64;
296        let shift = index % 64;
297        if self.limbs.len() <= limb {
298            self.limbs.resize(limb + 1, 0);
299        }
300        self.limbs[limb] |= 1u64 << shift;
301    }
302
303    /// Add another bigint in place.
304    ///
305    /// # Panics
306    ///
307    /// Panics only if the internal `u128` accumulator cannot be split back
308    /// into `u64` limbs, which would indicate a logic error.
309    pub fn add_assign_ref(&mut self, other: &Self) {
310        if other.is_zero() {
311            return;
312        }
313
314        if self.limbs.len() < other.limbs.len() {
315            self.limbs.resize(other.limbs.len(), 0);
316        }
317
318        let mut carry = 0u128;
319        for i in 0..other.limbs.len() {
320            let sum = u128::from(self.limbs[i]) + u128::from(other.limbs[i]) + carry;
321            self.limbs[i] = low_u64(sum);
322            carry = sum >> 64;
323        }
324
325        let mut i = other.limbs.len();
326        while carry != 0 && i < self.limbs.len() {
327            let sum = u128::from(self.limbs[i]) + carry;
328            self.limbs[i] = low_u64(sum);
329            carry = sum >> 64;
330            i += 1;
331        }
332
333        if carry != 0 {
334            self.limbs
335                .push(u64::try_from(carry).expect("final carry from u64 addition is at most 1"));
336        }
337    }
338
339    /// Return `self + other`.
340    #[must_use]
341    pub fn add_ref(&self, other: &Self) -> Self {
342        let mut out = self.clone();
343        out.add_assign_ref(other);
344        out
345    }
346
347    /// Subtract another bigint in place. Panics if `self < other`.
348    ///
349    /// # Panics
350    ///
351    /// Panics if `self < other`.
352    pub fn sub_assign_ref(&mut self, other: &Self) {
353        assert!((*self).cmp(other) != Ordering::Less, "BigUint underflow");
354        if other.is_zero() {
355            return;
356        }
357
358        let mut borrow = 0u128;
359        for i in 0..self.limbs.len() {
360            let lhs = u128::from(self.limbs[i]);
361            let rhs = if i < other.limbs.len() {
362                u128::from(other.limbs[i])
363            } else {
364                0
365            };
366
367            let subtrahend = rhs + borrow;
368            if lhs >= subtrahend {
369                self.limbs[i] = low_u64(lhs - subtrahend);
370                borrow = 0;
371            } else {
372                self.limbs[i] = low_u64((1u128 << 64) + lhs - subtrahend);
373                borrow = 1;
374            }
375        }
376
377        self.normalize();
378    }
379
380    /// Return `self - other`. Panics if `self < other`.
381    #[must_use]
382    pub fn sub_ref(&self, other: &Self) -> Self {
383        let mut out = self.clone();
384        out.sub_assign_ref(other);
385        out
386    }
387
388    /// Multiply two big integers.
389    ///
390    /// # Panics
391    ///
392    /// Panics only if the internal `u128` accumulators cannot be split back
393    /// into `u64` limbs, which would indicate a logic error.
394    #[must_use]
395    pub fn mul_ref(&self, other: &Self) -> Self {
396        if self.is_zero() || other.is_zero() {
397            return Self::zero();
398        }
399
400        if Self::should_use_karatsuba(self, other) {
401            return self.mul_karatsuba_ref(other);
402        }
403
404        Self::mul_schoolbook_ref(self, other)
405    }
406
407    /// Multiply a value by itself.
408    #[must_use]
409    pub fn square_ref(&self) -> Self {
410        self.mul_ref(self)
411    }
412
413    fn split_at_limb(&self, split: usize) -> (Self, Self) {
414        let low_end = split.min(self.limbs.len());
415        let mut low = Self {
416            limbs: self.limbs[..low_end].to_vec(),
417        };
418        low.normalize();
419
420        if split >= self.limbs.len() {
421            return (low, Self::zero());
422        }
423
424        let mut high = Self {
425            limbs: self.limbs[split..].to_vec(),
426        };
427        high.normalize();
428        (low, high)
429    }
430
431    fn should_use_karatsuba(lhs: &Self, rhs: &Self) -> bool {
432        let short = lhs.limbs.len().min(rhs.limbs.len());
433        let long = lhs.limbs.len().max(rhs.limbs.len());
434        short >= KARATSUBA_THRESHOLD_LIMBS && long <= short * KARATSUBA_MAX_IMBALANCE
435    }
436
437    fn mul_karatsuba_ref(&self, other: &Self) -> Self {
438        let split = self.limbs.len().max(other.limbs.len()) / 2;
439        if split == 0 {
440            return Self::mul_schoolbook_ref(self, other);
441        }
442
443        let (a0, a1) = self.split_at_limb(split);
444        let (b0, b1) = other.split_at_limb(split);
445        if a1.is_zero() || b1.is_zero() {
446            return Self::mul_schoolbook_ref(self, other);
447        }
448
449        let z0 = a0.mul_ref(&b0);
450        let z2 = a1.mul_ref(&b1);
451
452        let a_sum = a0.add_ref(&a1);
453        let b_sum = b0.add_ref(&b1);
454        let mut z1 = a_sum.mul_ref(&b_sum);
455        z1.sub_assign_ref(&z0);
456        z1.sub_assign_ref(&z2);
457
458        let mut out = z0;
459        z1.shl_bits(split * 64);
460        out.add_assign_ref(&z1);
461
462        let mut z2_shifted = z2;
463        z2_shifted.shl_bits(split * 128);
464        out.add_assign_ref(&z2_shifted);
465        out
466    }
467
468    fn mul_schoolbook_ref(lhs: &Self, rhs: &Self) -> Self {
469        let mut out = vec![0u64; lhs.limbs.len() + rhs.limbs.len()];
470        for (i, &lhs_limb) in lhs.limbs.iter().enumerate() {
471            let mut carry = 0u128;
472            for (j, &rhs_limb) in rhs.limbs.iter().enumerate() {
473                let idx = i + j;
474                let acc =
475                    u128::from(out[idx]) + u128::from(lhs_limb) * u128::from(rhs_limb) + carry;
476                out[idx] = low_u64(acc);
477                carry = acc >> 64;
478            }
479
480            let mut idx = i + rhs.limbs.len();
481            while carry != 0 {
482                let acc = u128::from(out[idx]) + carry;
483                out[idx] = low_u64(acc);
484                carry = acc >> 64;
485                idx += 1;
486            }
487        }
488
489        let mut result = Self { limbs: out };
490        // A normalized non-zero multiplicand and multiplier cannot produce a
491        // spuriously zero high limb except through the carry chain itself, so
492        // one post-pass normalization is enough.
493        result.normalize();
494        result
495    }
496
497    /// Shift left by one bit.
498    pub fn shl1(&mut self) {
499        if self.is_zero() {
500            return;
501        }
502
503        let mut carry = 0u64;
504        for limb in &mut self.limbs {
505            let next = *limb >> 63;
506            *limb = (*limb << 1) | carry;
507            carry = next;
508        }
509
510        if carry != 0 {
511            self.limbs.push(carry);
512        }
513        // A left shift on an already-normalized value cannot introduce a
514        // leading zero limb, so no normalize() pass is required here.
515    }
516
517    /// Shift right by one bit.
518    pub fn shr1(&mut self) {
519        if self.is_zero() {
520            return;
521        }
522
523        let mut carry = 0u64;
524        for limb in self.limbs.iter_mut().rev() {
525            let next = (*limb & 1) << 63;
526            *limb = (*limb >> 1) | carry;
527            carry = next;
528        }
529
530        self.normalize();
531    }
532
533    /// XOR another bigint into `self` in place (GF(2^m) field addition).
534    ///
535    /// Extends `self.limbs` with zeros if shorter than `other.limbs`, then
536    /// XORs each corresponding limb pair.  The result is normalized to strip
537    /// any leading zero limbs produced by XOR cancellation.
538    pub fn bitxor_assign(&mut self, other: &BigUint) {
539        if self.limbs.len() < other.limbs.len() {
540            self.limbs.resize(other.limbs.len(), 0);
541        }
542        for (s, &o) in self.limbs.iter_mut().zip(other.limbs.iter()) {
543            *s ^= o;
544        }
545        self.normalize();
546    }
547
548    /// Left-shift by `n` bits.
549    ///
550    /// Implemented as `n / 64` full-limb shifts (inserting zero limbs at the
551    /// low end) followed by up to 63 single-bit left shifts, which avoids
552    /// undefined behaviour from shifting a `u64` by 64 or more positions.
553    pub fn shl_bits(&mut self, n: usize) {
554        if self.is_zero() || n == 0 {
555            return;
556        }
557        let limb_shifts = n / 64;
558        let bit_shifts = n % 64;
559        // Full-limb shift: prepend zeros at the low (index 0) end.
560        if limb_shifts > 0 {
561            let mut new_limbs = vec![0u64; limb_shifts];
562            new_limbs.extend_from_slice(&self.limbs);
563            self.limbs = new_limbs;
564        }
565        // Remaining bit-level shift (0 < bit_shifts < 64, so 64 - bit_shifts is safe).
566        if bit_shifts > 0 {
567            let mut carry = 0u64;
568            for limb in &mut self.limbs {
569                let next_carry = *limb >> (64 - bit_shifts);
570                *limb = (*limb << bit_shifts) | carry;
571                carry = next_carry;
572            }
573            if carry != 0 {
574                self.limbs.push(carry);
575            }
576        }
577        // A left-shift on a normalized value cannot introduce a leading zero
578        // limb, so no normalize() pass is needed here.
579    }
580
581    /// Compute `self mod modulus`.
582    #[must_use]
583    pub fn modulo(&self, modulus: &Self) -> Self {
584        let (_, remainder) = self.div_rem(modulus);
585        remainder
586    }
587
588    /// Compute the remainder modulo a machine word.
589    ///
590    /// # Panics
591    ///
592    /// Panics if `modulus == 0`.
593    #[must_use]
594    pub fn rem_u64(&self, modulus: u64) -> u64 {
595        assert!(modulus != 0, "division by zero");
596        if self.is_zero() {
597            return 0;
598        }
599
600        let mut remainder = 0u128;
601        // Horner's method in base `2^64`: carry the remainder of the already
602        // processed high limbs, then append the next limb as the next base
603        // digit before reducing again.
604        for &limb in self.limbs.iter().rev() {
605            let acc = (remainder << 64) | u128::from(limb);
606            remainder = acc % u128::from(modulus);
607        }
608
609        u64::try_from(remainder).expect("remainder modulo u64 fits into u64")
610    }
611
612    /// Compute `(lhs * rhs) mod modulus`.
613    ///
614    /// Odd moduli use a fresh Montgomery context so the common public-key path
615    /// avoids the division-heavy fallback. Even moduli keep the old
616    /// double-and-add reducer because Montgomery requires an odd modulus.
617    /// Rewriting one multiplicand as `y - 1` plus one extra add can change the
618    /// operand parity, but it does not change the modulus parity; the core
619    /// Montgomery requirement is `gcd(R, n) = 1`, so an even modulus still
620    /// needs a non-Montgomery path.
621    ///
622    /// # Panics
623    ///
624    /// Panics if `modulus == 0`.
625    #[must_use]
626    pub fn mod_mul(lhs: &Self, rhs: &Self, modulus: &Self) -> Self {
627        assert!(!modulus.is_zero(), "modulus must be non-zero");
628        if modulus == &Self::one() {
629            return Self::zero();
630        }
631        if let Some(ctx) = MontgomeryCtx::new(modulus) {
632            return ctx.mul(lhs, rhs);
633        }
634        Self::mod_mul_plain(lhs, rhs, modulus)
635    }
636
637    /// Compute `(lhs * rhs) mod modulus` using the simple double-and-add
638    /// fallback implementation.
639    ///
640    /// The result is mathematically correct, but repeated division-based
641    /// reduction makes it much slower than Montgomery multiplication for the
642    /// odd moduli that dominate public-key code. The current scheme code only
643    /// reaches this path for even moduli, so it remains as the explicit
644    /// fallback and readable reference for non-Montgomery cases.
645    #[must_use]
646    pub(crate) fn mod_mul_plain(lhs: &Self, rhs: &Self, modulus: &Self) -> Self {
647        if lhs.is_zero() || rhs.is_zero() {
648            return Self::zero();
649        }
650
651        let mut a = lhs.modulo(modulus);
652        let mut b = rhs.clone();
653        let mut out = Self::zero();
654        while !b.is_zero() {
655            if b.is_odd() {
656                out = out.add_ref(&a).modulo(modulus);
657            }
658            a = a.add_ref(&a).modulo(modulus);
659            b.shr1();
660        }
661        out
662    }
663
664    /// Return `(quotient, remainder)` for Euclidean division. Panics on zero divisor.
665    ///
666    /// # Panics
667    ///
668    /// Panics if `divisor == 0`.
669    #[must_use]
670    pub fn div_rem(&self, divisor: &Self) -> (Self, Self) {
671        assert!(!divisor.is_zero(), "division by zero");
672        if self.cmp(divisor) == Ordering::Less {
673            return (Self::zero(), self.clone());
674        }
675
676        let mut quotient = Self::zero();
677        let mut remainder = Self::zero();
678
679        // Bit-by-bit long division. `remainder` holds the partially
680        // reconstructed dividend prefix; each step shifts it left, appends the
681        // next source bit, and subtracts the divisor if the prefix is already
682        // large enough.
683        for bit in (0..self.bits()).rev() {
684            remainder.shl1();
685            if self.bit(bit) {
686                if remainder.is_zero() {
687                    remainder.limbs.push(1);
688                } else {
689                    remainder.limbs[0] |= 1;
690                }
691            }
692
693            if remainder.cmp(divisor) != Ordering::Less {
694                remainder.sub_assign_ref(divisor);
695                quotient.set_bit(bit);
696            }
697        }
698
699        (quotient, remainder)
700    }
701
702    fn normalize(&mut self) {
703        // Canonical representation invariant:
704        // - zero has `limbs.is_empty()`
705        // - non-zero values have a non-zero top limb
706        while self.limbs.last().copied() == Some(0) {
707            self.limbs.pop();
708        }
709    }
710
711    fn limb_or_zero(&self, idx: usize) -> u64 {
712        self.limbs.get(idx).copied().unwrap_or(0)
713    }
714
715    fn montgomery_mul_odd_with_workspace(
716        lhs: &Self,
717        rhs: &Self,
718        modulus: &Self,
719        n0_inv: u64,
720        workspace: &mut Vec<u64>,
721    ) -> Self {
722        debug_assert!(modulus.is_odd(), "Montgomery path requires an odd modulus");
723        let width = modulus.limbs.len();
724        // `2 * width` limbs hold the schoolbook product. The extra two limbs
725        // are carry headroom so neither pass can run off the end.
726        let needed = width * 2 + 2;
727        if workspace.len() != needed {
728            workspace.resize(needed, 0);
729        } else {
730            workspace.fill(0);
731        }
732
733        // First pass: accumulate the ordinary product `lhs * rhs`.
734        for i in 0..width {
735            let lhs_limb = lhs.limb_or_zero(i);
736            let mut carry = 0u128;
737            for j in 0..width {
738                let idx = i + j;
739                let acc = u128::from(workspace[idx])
740                    + u128::from(lhs_limb) * u128::from(rhs.limb_or_zero(j))
741                    + carry;
742                workspace[idx] = low_u64(acc);
743                carry = acc >> 64;
744            }
745
746            let mut idx = i + width;
747            while carry != 0 {
748                let acc = u128::from(workspace[idx]) + carry;
749                workspace[idx] = low_u64(acc);
750                carry = acc >> 64;
751                idx += 1;
752            }
753        }
754
755        // Second pass: Montgomery reduction. Choose `m` so the current low
756        // limb cancels modulo `2^64`, then add `m * modulus`. Each round
757        // zeros one more low limb; after `width` rounds the discarded low half
758        // accounts for the implicit division by `R = 2^(64w)`, so the high
759        // half is `lhs * rhs * R^-1 mod n`. That is why copying out
760        // `workspace[width..]` yields the Montgomery product.
761        for i in 0..width {
762            let m = workspace[i].wrapping_mul(n0_inv);
763            let mut carry = 0u128;
764            for j in 0..width {
765                let idx = i + j;
766                let acc = u128::from(workspace[idx])
767                    + u128::from(m) * u128::from(modulus.limb_or_zero(j))
768                    + carry;
769                workspace[idx] = low_u64(acc);
770                carry = acc >> 64;
771            }
772
773            let mut idx = i + width;
774            while carry != 0 {
775                let acc = u128::from(workspace[idx]) + carry;
776                workspace[idx] = low_u64(acc);
777                carry = acc >> 64;
778                idx += 1;
779            }
780        }
781
782        let mut out = Self {
783            limbs: workspace[width..=(width * 2)].to_vec(),
784        };
785        out.normalize();
786        // Montgomery reduction leaves a value in `[0, 2n)`, so at most one
787        // subtraction is needed to return to the canonical residue range.
788        if out >= *modulus {
789            out.sub_assign_ref(modulus);
790        }
791        out
792    }
793}
794
795impl MontgomeryCtx {
796    fn encode_with_workspace(&self, value: &BigUint, workspace: &mut Vec<u64>) -> BigUint {
797        if value.is_zero() {
798            return BigUint::zero();
799        }
800
801        BigUint::montgomery_mul_odd_with_workspace(
802            &value.modulo(&self.modulus),
803            &self.r2_mod,
804            &self.modulus,
805            self.n0_inv,
806            workspace,
807        )
808    }
809
810    fn decode_with_workspace(&self, value: &BigUint, workspace: &mut Vec<u64>) -> BigUint {
811        BigUint::montgomery_mul_odd_with_workspace(
812            value,
813            &BigUint::one(),
814            &self.modulus,
815            self.n0_inv,
816            workspace,
817        )
818    }
819
820    fn pow_encoded_with_workspace(
821        &self,
822        base_mont: &BigUint,
823        exponent: &BigUint,
824        workspace: &mut Vec<u64>,
825    ) -> BigUint {
826        if self.modulus == BigUint::one() {
827            return BigUint::zero();
828        }
829
830        let mut result = self.one_mont.clone();
831        let mut power = base_mont.clone();
832
833        for bit in 0..exponent.bits() {
834            if exponent.bit(bit) {
835                result = BigUint::montgomery_mul_odd_with_workspace(
836                    &result,
837                    &power,
838                    &self.modulus,
839                    self.n0_inv,
840                    workspace,
841                );
842            }
843            power = BigUint::montgomery_mul_odd_with_workspace(
844                &power,
845                &power,
846                &self.modulus,
847                self.n0_inv,
848                workspace,
849            );
850        }
851
852        self.decode_with_workspace(&result, workspace)
853    }
854
855    /// Build a Montgomery context for a non-zero odd modulus.
856    #[must_use]
857    pub fn new(modulus: &BigUint) -> Option<Self> {
858        if modulus.is_zero() || !modulus.is_odd() {
859            return None;
860        }
861
862        let n0_inv = montgomery_n0_inv(modulus.limbs[0]);
863
864        // With `w` limbs, Montgomery arithmetic uses `R = 2^(64w)`. `R^2 mod
865        // n` is the standard conversion factor for entering the Montgomery
866        // domain because `montgomery_mul(a, R^2) = a * R^2 * R^-1 = aR`, the
867        // Montgomery encoding of the ordinary residue `a`.
868        let mut r2 = BigUint::zero();
869        r2.set_bit(modulus.limbs.len() * 128);
870        let r2_mod = r2.modulo(modulus);
871
872        // `R mod n` is the Montgomery encoding of 1, stored so exponentiation
873        // can start its accumulator in the correct domain.
874        let mut r = BigUint::zero();
875        r.set_bit(modulus.limbs.len() * 64);
876        let one_mont = r.modulo(modulus);
877
878        Some(Self {
879            modulus: modulus.clone(),
880            n0_inv,
881            r2_mod,
882            one_mont,
883        })
884    }
885
886    /// Return the odd modulus this context was built for.
887    #[must_use]
888    pub fn modulus(&self) -> &BigUint {
889        &self.modulus
890    }
891
892    /// Convert an ordinary residue into Montgomery form.
893    #[must_use]
894    pub fn encode(&self, value: &BigUint) -> BigUint {
895        let mut workspace = Vec::new();
896        self.encode_with_workspace(value, &mut workspace)
897    }
898
899    /// Convert a Montgomery residue back to the ordinary representation.
900    #[must_use]
901    pub fn decode(&self, value: &BigUint) -> BigUint {
902        let mut workspace = Vec::new();
903        self.decode_with_workspace(value, &mut workspace)
904    }
905
906    /// Multiply two ordinary residues modulo the context modulus.
907    #[must_use]
908    pub fn mul(&self, lhs: &BigUint, rhs: &BigUint) -> BigUint {
909        let mut workspace = Vec::new();
910        let lhs_mont = self.encode_with_workspace(lhs, &mut workspace);
911        let rhs_mont = self.encode_with_workspace(rhs, &mut workspace);
912        let product_mont = BigUint::montgomery_mul_odd_with_workspace(
913            &lhs_mont,
914            &rhs_mont,
915            &self.modulus,
916            self.n0_inv,
917            &mut workspace,
918        );
919        self.decode_with_workspace(&product_mont, &mut workspace)
920    }
921
922    /// Square one ordinary residue modulo the context modulus.
923    #[must_use]
924    pub fn square(&self, value: &BigUint) -> BigUint {
925        let mut workspace = Vec::new();
926        let value_mont = self.encode_with_workspace(value, &mut workspace);
927        let square_mont = BigUint::montgomery_mul_odd_with_workspace(
928            &value_mont,
929            &value_mont,
930            &self.modulus,
931            self.n0_inv,
932            &mut workspace,
933        );
934        self.decode_with_workspace(&square_mont, &mut workspace)
935    }
936
937    /// Compute `base^exponent mod modulus` inside the context.
938    #[must_use]
939    pub fn pow(&self, base: &BigUint, exponent: &BigUint) -> BigUint {
940        let mut workspace = Vec::new();
941        let base_mont = self.encode_with_workspace(&base.modulo(&self.modulus), &mut workspace);
942        self.pow_encoded_with_workspace(&base_mont, exponent, &mut workspace)
943    }
944
945    /// Compute `base^exponent mod modulus` with `base` already in Montgomery form.
946    ///
947    /// This is useful when callers reuse the same base and can cache the
948    /// encoded value once.
949    #[must_use]
950    pub fn pow_encoded(&self, base_mont: &BigUint, exponent: &BigUint) -> BigUint {
951        let mut workspace = Vec::new();
952        self.pow_encoded_with_workspace(base_mont, exponent, &mut workspace)
953    }
954}
955
956impl Drop for BigUint {
957    fn drop(&mut self) {
958        // BigUint backs private exponents, prime factors, and nonces in the
959        // public-key layer. Clear the limb buffer on drop so those values do
960        // not linger in freed heap memory.
961        crate::ct::zeroize_slice(self.limbs.as_mut_slice());
962    }
963}
964
/// Return the least-significant 64 bits of `value`.
#[inline]
fn low_u64(value: u128) -> u64 {
    // Mask with all 64 low bits set; everything above bit 63 is discarded.
    const LOW_MASK: u128 = (1u128 << u64::BITS) - 1;
    let masked = value & LOW_MASK;
    u64::try_from(masked).expect("masked low 64 bits always fit into u64")
}
969
/// Compute `-n0^-1 mod 2^64` for an odd `n0`, the per-limb constant used by
/// Montgomery reduction.
fn montgomery_n0_inv(n0: u64) -> u64 {
    debug_assert!(n0 & 1 == 1, "Montgomery path requires an odd modulus");
    // Newton iteration modulo 2^64. The seed `1` is correct in its lowest bit
    // for every odd `n0`, and each step doubles the number of correct low
    // bits (1 -> 2 -> 4 -> 8 -> 16 -> 32 -> 64), so six steps yield the full
    // 64-bit inverse.
    let inverse = (0..6).fold(1u64, |approx, _| {
        let correction = 2u64.wrapping_sub(n0.wrapping_mul(approx));
        approx.wrapping_mul(correction)
    });
    // Montgomery reduction needs the negated inverse.
    inverse.wrapping_neg()
}
981
982impl BigInt {
983    /// Construct zero.
984    #[must_use]
985    pub fn zero() -> Self {
986        Self {
987            sign: Sign::Zero,
988            magnitude: BigUint::zero(),
989        }
990    }
991
992    /// Construct from an explicit sign and magnitude.
993    #[must_use]
994    pub fn from_parts(sign: Sign, magnitude: BigUint) -> Self {
995        if magnitude.is_zero() {
996            return Self::zero();
997        }
998
999        let canonical_sign = match sign {
1000            Sign::Zero => Sign::Positive,
1001            other => other,
1002        };
1003
1004        Self {
1005            sign: canonical_sign,
1006            magnitude,
1007        }
1008    }
1009
1010    /// Construct a non-negative signed integer from an unsigned value.
1011    #[must_use]
1012    pub fn from_biguint(magnitude: BigUint) -> Self {
1013        Self::from_parts(Sign::Positive, magnitude)
1014    }
1015
1016    /// Return the sign.
1017    #[must_use]
1018    pub fn sign(&self) -> Sign {
1019        self.sign
1020    }
1021
1022    /// Return the absolute value.
1023    #[must_use]
1024    pub fn magnitude(&self) -> &BigUint {
1025        &self.magnitude
1026    }
1027
1028    /// Negate the integer.
1029    #[must_use]
1030    pub fn negated(&self) -> Self {
1031        let sign = match self.sign {
1032            Sign::Positive => Sign::Negative,
1033            Sign::Negative => Sign::Positive,
1034            Sign::Zero => Sign::Zero,
1035        };
1036        Self {
1037            sign,
1038            magnitude: self.magnitude.clone(),
1039        }
1040    }
1041
1042    /// Return `self + other`.
1043    #[must_use]
1044    pub fn add_ref(&self, other: &Self) -> Self {
1045        match (self.sign, other.sign) {
1046            (Sign::Zero, _) => other.clone(),
1047            (_, Sign::Zero) => self.clone(),
1048            (Sign::Positive, Sign::Positive) => {
1049                Self::from_parts(Sign::Positive, self.magnitude.add_ref(&other.magnitude))
1050            }
1051            (Sign::Negative, Sign::Negative) => {
1052                Self::from_parts(Sign::Negative, self.magnitude.add_ref(&other.magnitude))
1053            }
1054            (Sign::Positive, Sign::Negative) => self.sub_ref(&other.negated()),
1055            (Sign::Negative, Sign::Positive) => other.sub_ref(&self.negated()),
1056        }
1057    }
1058
1059    /// Return `self - other`.
1060    #[must_use]
1061    pub fn sub_ref(&self, other: &Self) -> Self {
1062        match (self.sign, other.sign) {
1063            (_, Sign::Zero) => self.clone(),
1064            (Sign::Zero, _) => other.negated(),
1065            (Sign::Positive, Sign::Negative) => {
1066                Self::from_parts(Sign::Positive, self.magnitude.add_ref(&other.magnitude))
1067            }
1068            (Sign::Negative, Sign::Positive) => {
1069                Self::from_parts(Sign::Negative, self.magnitude.add_ref(&other.magnitude))
1070            }
1071            (Sign::Positive, Sign::Positive) => match self.magnitude.cmp(&other.magnitude) {
1072                Ordering::Greater => {
1073                    Self::from_parts(Sign::Positive, self.magnitude.sub_ref(&other.magnitude))
1074                }
1075                Ordering::Less => {
1076                    Self::from_parts(Sign::Negative, other.magnitude.sub_ref(&self.magnitude))
1077                }
1078                Ordering::Equal => Self::zero(),
1079            },
1080            (Sign::Negative, Sign::Negative) => match self.magnitude.cmp(&other.magnitude) {
1081                Ordering::Greater => {
1082                    Self::from_parts(Sign::Negative, self.magnitude.sub_ref(&other.magnitude))
1083                }
1084                Ordering::Less => {
1085                    Self::from_parts(Sign::Positive, other.magnitude.sub_ref(&self.magnitude))
1086                }
1087                Ordering::Equal => Self::zero(),
1088            },
1089        }
1090    }
1091
1092    /// Return `self * factor` for a non-negative factor.
1093    #[must_use]
1094    pub fn mul_biguint_ref(&self, factor: &BigUint) -> Self {
1095        if factor.is_zero() || self.sign == Sign::Zero {
1096            return Self::zero();
1097        }
1098
1099        Self::from_parts(self.sign, self.magnitude.mul_ref(factor))
1100    }
1101
1102    /// Reduce modulo a positive modulus and return the least non-negative residue.
1103    ///
1104    /// # Panics
1105    ///
1106    /// Panics if `modulus == 0`.
1107    #[must_use]
1108    pub fn modulo_positive(&self, modulus: &BigUint) -> BigUint {
1109        assert!(!modulus.is_zero(), "modulus must be non-zero");
1110        match self.sign {
1111            Sign::Zero => BigUint::zero(),
1112            Sign::Positive => self.magnitude.modulo(modulus),
1113            Sign::Negative => {
1114                let rem = self.magnitude.modulo(modulus);
1115                if rem.is_zero() {
1116                    BigUint::zero()
1117                } else {
1118                    modulus.sub_ref(&rem)
1119                }
1120            }
1121        }
1122    }
1123}
1124
#[cfg(test)]
mod tests {
    use super::{BigInt, BigUint, MontgomeryCtx, Sign};

    /// Advance the deterministic 64-bit LCG and return its new state.
    fn lcg_next(state: &mut u64) -> u64 {
        let next = state.wrapping_mul(6364136223846793005).wrapping_add(1);
        *state = next;
        next
    }

    /// Build a `words`-limb value from the LCG stream, forcing the most
    /// significant limb to be non-zero.
    fn seeded_biguint(words: usize, state: &mut u64) -> BigUint {
        let mut limbs: Vec<u64> = (0..words).map(|_| lcg_next(state)).collect();
        if let Some(top) = limbs.last_mut() {
            if *top == 0 {
                *top = 1;
            }
        }
        BigUint { limbs }
    }

    /// Montgomery context for the prime `1_000_000_007` shared by the
    /// small-arithmetic Montgomery tests.
    fn small_prime_ctx() -> MontgomeryCtx {
        let modulus = BigUint::from_u64(1_000_000_007);
        MontgomeryCtx::new(&modulus).expect("odd modulus builds a context")
    }

    #[test]
    fn bytes_roundtrip() {
        const BYTES: [u8; 10] = [0x12, 0x34, 0x56, 0x78, 0x9a, 0xbc, 0xde, 0xf0, 0x11, 0x22];
        let value = BigUint::from_be_bytes(&BYTES);
        assert_eq!(value.to_be_bytes(), BYTES.to_vec());
    }

    #[test]
    fn add_sub_mul_small_values() {
        let a = BigUint::from_u128(1_000_000_000_000);
        let b = BigUint::from_u128(777_777_777_777);
        let one = BigUint::from_u64(1);

        let sum = a.add_ref(&b);
        assert_eq!(sum, BigUint::from_u128(1_777_777_777_777));

        let difference = a.sub_ref(&one);
        assert_eq!(difference, BigUint::from_u128(999_999_999_999));

        let product = a.mul_ref(&b);
        assert_eq!(
            product,
            BigUint::from_u128(777_777_777_777_000_000_000_000)
        );
    }

    #[test]
    fn square_ref_matches_mul_ref() {
        let mut state: u64 = 0x9e37_79b9_7f4a_7c15;
        for limb_count in [1usize, 2, 8, 32, 48] {
            for _round in 0..8 {
                let sample = seeded_biguint(limb_count, &mut state);
                let squared = sample.square_ref();
                let multiplied = sample.mul_ref(&sample);
                assert_eq!(squared, multiplied);
            }
        }
    }

    #[test]
    fn karatsuba_dispatch_matches_schoolbook() {
        let mut state: u64 = 0x243f_6a88_85a3_08d3;
        for limb_count in [32usize, 40, 64] {
            for _round in 0..6 {
                let left = seeded_biguint(limb_count, &mut state);
                let right = seeded_biguint(limb_count, &mut state);
                assert_eq!(
                    left.mul_ref(&right),
                    BigUint::mul_schoolbook_ref(&left, &right)
                );
            }
        }
    }

    #[test]
    fn division_roundtrip() {
        let dividend = BigUint::from_u128(1_234_567_890_123_456_789);
        let divisor = BigUint::from_u64(37);
        let (quotient, remainder) = dividend.div_rem(&divisor);
        assert_eq!(quotient, BigUint::from_u128(33_366_699_733_066_399));
        assert_eq!(remainder, BigUint::from_u64(26));

        // quotient * divisor + remainder must reproduce the dividend.
        let rebuilt = quotient.mul_ref(&divisor).add_ref(&remainder);
        assert_eq!(rebuilt, dividend);
    }

    #[test]
    fn sqrt_floor_small_values() {
        let cases: [(u64, u64); 6] = [(0, 0), (1, 1), (2, 1), (15, 3), (16, 4), (17, 4)];
        for (input, expected_root) in cases {
            assert_eq!(
                BigUint::from_u64(input).sqrt_floor(),
                BigUint::from_u64(expected_root)
            );
        }

        // 131_091^2 - 400: just below a perfect square, so the floor is 131_090.
        let near_square = BigUint::from_u128(17_184_849_881);
        assert_eq!(near_square.sqrt_floor(), BigUint::from_u64(131_090));
    }

    #[test]
    fn mod_mul_matches_small_arithmetic() {
        let modulus = BigUint::from_u64(1_000_000_007);
        let lhs = BigUint::from_u64(123_456_789);
        let rhs = BigUint::from_u64(987_654_321);
        let product = BigUint::mod_mul(&lhs, &rhs, &modulus);
        assert_eq!(product, BigUint::from_u64(259_106_859));
    }

    #[test]
    fn montgomery_mod_pow_matches_small_arithmetic() {
        let ctx = small_prime_ctx();
        let base = BigUint::from_u64(123_456_789);
        let exponent = BigUint::from_u64(65_537);
        let power = ctx.pow(&base, &exponent);
        assert_eq!(power, BigUint::from_u64(560_583_526));
    }

    #[test]
    fn montgomery_ctx_mul_matches_small_arithmetic() {
        let ctx = small_prime_ctx();
        let lhs = BigUint::from_u64(123_456_789);
        let rhs = BigUint::from_u64(987_654_321);
        let product = ctx.mul(&lhs, &rhs);
        assert_eq!(product, BigUint::from_u64(259_106_859));
    }

    #[test]
    fn mod_mul_even_modulus_uses_fallback_path() {
        // 37 * 19 = 703, and 703 mod 100 = 3.
        let modulus = BigUint::from_u64(100);
        let lhs = BigUint::from_u64(37);
        let rhs = BigUint::from_u64(19);
        assert_eq!(BigUint::mod_mul(&lhs, &rhs, &modulus), BigUint::from_u64(3));
    }

    #[test]
    fn bigint_sign_normalization() {
        // A zero magnitude collapses to `Sign::Zero` whatever sign was requested.
        let negative_zero = BigInt::from_parts(Sign::Negative, BigUint::zero());
        assert_eq!(negative_zero.sign(), Sign::Zero);

        let seven = BigInt::from_parts(Sign::Positive, BigUint::from_u64(7));
        assert_eq!(seven.negated().sign(), Sign::Negative);
        assert_eq!(seven.magnitude(), &BigUint::from_u64(7));
    }

    #[test]
    fn bigint_add_sub_and_modulo() {
        let ten = BigInt::from_biguint(BigUint::from_u64(10));
        let minus_three = BigInt::from_parts(Sign::Negative, BigUint::from_u64(3));

        let sum = ten.add_ref(&minus_three);
        assert_eq!(sum, BigInt::from_biguint(BigUint::from_u64(7)));

        let difference = minus_three.sub_ref(&ten);
        assert_eq!(
            difference,
            BigInt::from_parts(Sign::Negative, BigUint::from_u64(13))
        );

        let residue = BigInt::from_parts(Sign::Negative, BigUint::from_u64(3))
            .modulo_positive(&BigUint::from_u64(11));
        assert_eq!(residue, BigUint::from_u64(8));
    }
}