Skip to main content

crypto_bigint/uint/
mul.rs

1//! [`Uint`] multiplication operations.
2
3use crate::{
4    Checked, CheckedMul, Choice, Concat, ConcatenatingMul, ConcatenatingSquare, CtOption, Limb,
5    Mul, MulAssign, Uint, Wrapping, WrappingMul,
6};
7
8pub(crate) mod karatsuba;
9pub(crate) mod schoolbook;
10
11impl<const LIMBS: usize> Uint<LIMBS> {
12    /// Multiply `self` by `rhs`, returning a concatenated "wide" result.
13    #[must_use]
14    pub const fn concatenating_mul<const RHS_LIMBS: usize, const WIDE_LIMBS: usize>(
15        &self,
16        rhs: &Uint<RHS_LIMBS>,
17    ) -> Uint<WIDE_LIMBS>
18    where
19        Self: Concat<RHS_LIMBS, Output = Uint<WIDE_LIMBS>>,
20    {
21        let (lo, hi) = self.widening_mul(rhs);
22        Uint::concat_mixed(&lo, &hi)
23    }
24
25    /// Compute "wide" multiplication as a 2-tuple containing the `(lo, hi)` components of the product, whose sizes
26    /// correspond to the sizes of the operands.
27    #[deprecated(since = "0.7.0", note = "please use `widening_mul` instead")]
28    #[must_use]
29    pub const fn split_mul<const RHS_LIMBS: usize>(
30        &self,
31        rhs: &Uint<RHS_LIMBS>,
32    ) -> (Self, Uint<RHS_LIMBS>) {
33        self.widening_mul(rhs)
34    }
35
36    /// Compute "wide" multiplication as a 2-tuple containing the `(lo, hi)` components of the product, whose sizes
37    /// correspond to the sizes of the operands.
38    #[inline(always)]
39    #[must_use]
40    pub const fn widening_mul<const RHS_LIMBS: usize>(
41        &self,
42        rhs: &Uint<RHS_LIMBS>,
43    ) -> (Self, Uint<RHS_LIMBS>) {
44        karatsuba::widening_mul_fixed(self.as_uint_ref(), rhs.as_uint_ref())
45    }
46
47    /// Perform wrapping multiplication, discarding overflow.
48    #[must_use]
49    pub const fn wrapping_mul<const RHS_LIMBS: usize>(&self, rhs: &Uint<RHS_LIMBS>) -> Self {
50        karatsuba::wrapping_mul_fixed::<LIMBS>(self.as_uint_ref(), rhs.as_uint_ref()).0
51    }
52
53    /// Perform saturating multiplication, returning `MAX` on overflow.
54    #[must_use]
55    pub const fn saturating_mul<const RHS_LIMBS: usize>(&self, rhs: &Uint<RHS_LIMBS>) -> Self {
56        let (lo, overflow) = self.overflowing_mul(rhs);
57        Self::select(&lo, &Self::MAX, overflow)
58    }
59
60    /// Perform wrapping multiplication, checking that the result fits in the original [`Uint`] size.
61    #[must_use]
62    pub const fn checked_mul<const RHS_LIMBS: usize>(
63        &self,
64        rhs: &Uint<RHS_LIMBS>,
65    ) -> CtOption<Uint<LIMBS>> {
66        let (lo, overflow) = self.overflowing_mul(rhs);
67        CtOption::new(lo, overflow.not())
68    }
69
70    /// Perform overflowing multiplication, returning the wrapped result and a `Choice`
71    /// indicating whether overflow occurred.
72    #[inline(always)]
73    #[must_use]
74    pub(crate) const fn overflowing_mul<const RHS_LIMBS: usize>(
75        &self,
76        rhs: &Uint<RHS_LIMBS>,
77    ) -> (Uint<LIMBS>, Choice) {
78        let (lo, carry) = karatsuba::wrapping_mul_fixed(self.as_uint_ref(), rhs.as_uint_ref());
79        let overflow = self
80            .as_uint_ref()
81            .check_mul_overflow(rhs.as_uint_ref(), carry.is_nonzero());
82        (lo, overflow)
83    }
84
85    /// Perform multiplication by a Limb, returning the wrapped result and a Limb overflow.
86    pub(crate) const fn overflowing_mul_limb(&self, rhs: Limb) -> (Self, Limb) {
87        let mut ret = [Limb::ZERO; LIMBS];
88        let mut i = 0;
89        let mut carry = Limb::ZERO;
90        while i < LIMBS {
91            (ret[i], carry) = self.limbs[i].carrying_mul_add(rhs, Limb::ZERO, carry);
92            i += 1;
93        }
94        (Uint::new(ret), carry)
95    }
96
97    /// Perform wrapping multiplication by a Limb, discarding overflow.
98    pub(crate) const fn wrapping_mul_limb(&self, rhs: Limb) -> Self {
99        self.overflowing_mul_limb(rhs).0
100    }
101}
102
103/// Squaring operations
104impl<const LIMBS: usize> Uint<LIMBS> {
105    /// Square self, returning a "wide" result in two parts as (lo, hi).
106    #[inline(always)]
107    #[must_use]
108    #[deprecated(since = "0.7.0", note = "please use `widening_square` instead")]
109    pub const fn square_wide(&self) -> (Self, Self) {
110        self.widening_square()
111    }
112
113    /// Square self, returning a "wide" result in two parts as `(lo, hi)`.
114    #[inline(always)]
115    #[must_use]
116    pub const fn widening_square(&self) -> (Self, Self) {
117        karatsuba::widening_square_fixed(self.as_uint_ref())
118    }
119
120    /// Square self, returning a concatenated "wide" result.
121    #[must_use]
122    pub const fn concatenating_square<const WIDE_LIMBS: usize>(&self) -> Uint<WIDE_LIMBS>
123    where
124        Self: Concat<LIMBS, Output = Uint<WIDE_LIMBS>>,
125    {
126        let (lo, hi) = self.widening_square();
127        Uint::concat_mixed(&lo, &hi)
128    }
129
130    /// Square self, checking that the result fits in the original [`Uint`] size.
131    #[must_use]
132    pub const fn checked_square(&self) -> CtOption<Uint<LIMBS>> {
133        let (lo, overflow) = self.overflowing_square();
134        CtOption::new(lo, overflow.not())
135    }
136
137    /// Perform wrapping square, discarding overflow.
138    #[must_use]
139    pub const fn wrapping_square(&self) -> Uint<LIMBS> {
140        karatsuba::wrapping_square_fixed(self.as_uint_ref()).0
141    }
142
143    /// Perform saturating squaring, returning `MAX` on overflow.
144    #[must_use]
145    pub const fn saturating_square(&self) -> Self {
146        let (lo, overflow) = self.overflowing_square();
147        Self::select(&lo, &Self::MAX, overflow)
148    }
149
150    /// Perform overflowing squaring, returning the wrapped result and a `Choice`
151    /// indicating whether overflow occurred.
152    #[inline(always)]
153    #[must_use]
154    pub(crate) const fn overflowing_square(&self) -> (Uint<LIMBS>, Choice) {
155        let (lo, carry) = karatsuba::wrapping_square_fixed(self.as_uint_ref());
156        let overflow = self.as_uint_ref().check_square_overflow(carry.is_nonzero());
157        (lo, overflow)
158    }
159}
160
161impl<const LIMBS: usize, const WIDE_LIMBS: usize> Uint<LIMBS>
162where
163    Self: Concat<LIMBS, Output = Uint<WIDE_LIMBS>>,
164{
165    /// Square self, returning a concatenated "wide" result.
166    #[must_use]
167    #[deprecated(since = "0.7.0", note = "please use `concatenating_square` instead")]
168    pub const fn square(&self) -> Uint<WIDE_LIMBS> {
169        let (lo, hi) = self.widening_square();
170        lo.concat(&hi)
171    }
172}
173
174impl<const LIMBS: usize, const RHS_LIMBS: usize> CheckedMul<Uint<RHS_LIMBS>> for Uint<LIMBS> {
175    fn checked_mul(&self, rhs: &Uint<RHS_LIMBS>) -> CtOption<Self> {
176        self.checked_mul(rhs)
177    }
178}
179
180impl<const LIMBS: usize, const RHS_LIMBS: usize> Mul<Uint<RHS_LIMBS>> for Uint<LIMBS> {
181    type Output = Uint<LIMBS>;
182
183    fn mul(self, rhs: Uint<RHS_LIMBS>) -> Self {
184        self.mul(&rhs)
185    }
186}
187
188impl<const LIMBS: usize, const RHS_LIMBS: usize> Mul<&Uint<RHS_LIMBS>> for Uint<LIMBS> {
189    type Output = Uint<LIMBS>;
190
191    fn mul(self, rhs: &Uint<RHS_LIMBS>) -> Self {
192        (&self).mul(rhs)
193    }
194}
195
196impl<const LIMBS: usize, const RHS_LIMBS: usize> Mul<Uint<RHS_LIMBS>> for &Uint<LIMBS> {
197    type Output = Uint<LIMBS>;
198
199    fn mul(self, rhs: Uint<RHS_LIMBS>) -> Self::Output {
200        self.mul(&rhs)
201    }
202}
203
204impl<const LIMBS: usize, const RHS_LIMBS: usize> Mul<&Uint<RHS_LIMBS>> for &Uint<LIMBS> {
205    type Output = Uint<LIMBS>;
206
207    fn mul(self, rhs: &Uint<RHS_LIMBS>) -> Self::Output {
208        self.checked_mul(rhs)
209            .expect("attempted to multiply with overflow")
210    }
211}
212
213impl<const LIMBS: usize, const RHS_LIMBS: usize> MulAssign<Uint<RHS_LIMBS>> for Uint<LIMBS> {
214    fn mul_assign(&mut self, rhs: Uint<RHS_LIMBS>) {
215        *self = self.mul(&rhs);
216    }
217}
218
219impl<const LIMBS: usize, const RHS_LIMBS: usize> MulAssign<&Uint<RHS_LIMBS>> for Uint<LIMBS> {
220    fn mul_assign(&mut self, rhs: &Uint<RHS_LIMBS>) {
221        *self = self.mul(rhs);
222    }
223}
224
225impl<const LIMBS: usize> MulAssign<Wrapping<Uint<LIMBS>>> for Wrapping<Uint<LIMBS>> {
226    fn mul_assign(&mut self, other: Wrapping<Uint<LIMBS>>) {
227        *self = *self * other;
228    }
229}
230
231impl<const LIMBS: usize> MulAssign<&Wrapping<Uint<LIMBS>>> for Wrapping<Uint<LIMBS>> {
232    fn mul_assign(&mut self, other: &Wrapping<Uint<LIMBS>>) {
233        *self = *self * other;
234    }
235}
236
237impl<const LIMBS: usize> MulAssign<Checked<Uint<LIMBS>>> for Checked<Uint<LIMBS>> {
238    fn mul_assign(&mut self, other: Checked<Uint<LIMBS>>) {
239        *self = *self * other;
240    }
241}
242
243impl<const LIMBS: usize> MulAssign<&Checked<Uint<LIMBS>>> for Checked<Uint<LIMBS>> {
244    fn mul_assign(&mut self, other: &Checked<Uint<LIMBS>>) {
245        *self = *self * other;
246    }
247}
248
249impl<const LIMBS: usize, const RHS_LIMBS: usize, const WIDE_LIMBS: usize>
250    ConcatenatingMul<Uint<RHS_LIMBS>> for Uint<LIMBS>
251where
252    Self: Concat<RHS_LIMBS, Output = Uint<WIDE_LIMBS>>,
253{
254    type Output = Uint<WIDE_LIMBS>;
255
256    #[inline]
257    fn concatenating_mul(&self, rhs: Uint<RHS_LIMBS>) -> Self::Output {
258        self.concatenating_mul(&rhs)
259    }
260}
261
262impl<const LIMBS: usize, const RHS_LIMBS: usize, const WIDE_LIMBS: usize>
263    ConcatenatingMul<&Uint<RHS_LIMBS>> for Uint<LIMBS>
264where
265    Self: Concat<RHS_LIMBS, Output = Uint<WIDE_LIMBS>>,
266{
267    type Output = Uint<WIDE_LIMBS>;
268
269    #[inline]
270    fn concatenating_mul(&self, rhs: &Uint<RHS_LIMBS>) -> Self::Output {
271        self.concatenating_mul(rhs)
272    }
273}
274
275impl<const LIMBS: usize, const WIDE_LIMBS: usize> ConcatenatingSquare for Uint<LIMBS>
276where
277    Self: Concat<LIMBS, Output = Uint<WIDE_LIMBS>>,
278{
279    type Output = Uint<WIDE_LIMBS>;
280
281    #[inline]
282    fn concatenating_square(&self) -> Self::Output {
283        self.concatenating_square()
284    }
285}
286
287impl<const LIMBS: usize> WrappingMul for Uint<LIMBS> {
288    fn wrapping_mul(&self, v: &Self) -> Self {
289        self.wrapping_mul(v)
290    }
291}
292
#[cfg(test)]
mod tests {
    use crate::{
        ConcatenatingMul, ConcatenatingSquare, Limb, U64, U128, U192, U256, Uint, Wrapping,
    };

    #[test]
    fn widening_mul_zero_and_one() {
        assert_eq!(U64::ZERO.widening_mul(&U64::ZERO), (U64::ZERO, U64::ZERO));
        assert_eq!(U64::ZERO.widening_mul(&U64::ONE), (U64::ZERO, U64::ZERO));
        assert_eq!(U64::ONE.widening_mul(&U64::ZERO), (U64::ZERO, U64::ZERO));
        assert_eq!(U64::ONE.widening_mul(&U64::ONE), (U64::ONE, U64::ZERO));
    }

    #[test]
    fn widening_mul_lo_only() {
        let primes: &[u32] = &[3, 5, 17, 257, 65537];

        for &a_int in primes {
            for &b_int in primes {
                let (lo, hi) = U64::from_u32(a_int).widening_mul(&U64::from_u32(b_int));
                let expected = U64::from_u64(u64::from(a_int) * u64::from(b_int));
                assert_eq!(lo, expected);
                assert!(bool::from(hi.is_zero()));
                assert_eq!(lo, U64::from_u32(a_int).wrapping_mul(&U64::from_u32(b_int)));
            }
        }
    }

    #[test]
    fn mul_concat_even() {
        assert_eq!(U64::ZERO.concatenating_mul(&U64::MAX), U128::ZERO);
        assert_eq!(U64::MAX.concatenating_mul(&U64::ZERO), U128::ZERO);
        assert_eq!(
            U64::MAX.concatenating_mul(&U64::MAX),
            U128::from_u128(0xfffffffffffffffe_0000000000000001)
        );
        assert_eq!(
            U64::ONE.concatenating_mul(&U64::MAX),
            U128::from_u128(0x0000000000000000_ffffffffffffffff)
        );
    }

    #[test]
    fn mul_concat_mixed() {
        let a = U64::from_u64(0x0011223344556677);
        let b = U128::from_u128(0x8899aabbccddeeff_8899aabbccddeeff);
        let expected = U192::from(&b).saturating_mul(&a);
        assert_eq!(a.concatenating_mul(&b), expected);
        assert_eq!(ConcatenatingMul::concatenating_mul(&a, &b), expected);
        assert_eq!(b.concatenating_mul(&a), expected);
        assert_eq!(ConcatenatingMul::concatenating_mul(&b, &a), expected);
    }

    #[test]
    fn wrapping_mul_even() {
        assert_eq!(U64::ZERO.wrapping_mul(&U64::MAX), U64::ZERO);
        assert_eq!(U64::MAX.wrapping_mul(&U64::ZERO), U64::ZERO);
        assert_eq!(U64::MAX.wrapping_mul(&U64::MAX), U64::ONE);
        assert_eq!(U64::ONE.wrapping_mul(&U64::MAX), U64::MAX);
    }

    #[test]
    fn wrapping_mul_mixed() {
        let a = U64::from_u64(0x0011223344556677);
        let b = U128::from_u128(0x8899aabbccddeeff_8899aabbccddeeff);
        let expected = U192::from(&b).saturating_mul(&a);
        assert_eq!(b.wrapping_mul(&a), expected.resize());
        assert_eq!(a.wrapping_mul(&b), expected.resize());
    }

    #[test]
    fn checked_mul_ok() {
        let n = U64::from_u32(0xffff_ffff);
        assert_eq!(
            n.checked_mul(&n).unwrap(),
            U64::from_u64(0xffff_fffe_0000_0001)
        );
        assert_eq!(U64::ZERO.checked_mul(&U64::ZERO).unwrap(), U64::ZERO);
    }

    #[test]
    fn checked_mul_overflow() {
        let n = U64::MAX;
        assert!(bool::from(n.checked_mul(&n).is_none()));
    }

    #[test]
    fn saturating_mul_no_overflow() {
        let n = U64::from_u8(8);
        assert_eq!(n.saturating_mul(&n), U64::from_u8(64));
    }

    #[test]
    fn saturating_mul_overflow() {
        let a = U64::from(0xffff_ffff_ffff_ffffu64);
        let b = U64::from(2u8);
        assert_eq!(a.saturating_mul(&b), U64::MAX);
    }

    /// Every owned/borrowed combination of the `*` and `*=` operators must agree.
    #[test]
    fn mul_operators() {
        let n = U64::from_u32(0xffff_ffff);
        let expected = U64::from_u64(0xffff_fffe_0000_0001);
        assert_eq!(n * n, expected);
        assert_eq!(n * &n, expected);
        assert_eq!(&n * n, expected);
        assert_eq!(&n * &n, expected);

        let mut acc = n;
        acc *= n;
        assert_eq!(acc, expected);

        let mut acc = n;
        acc *= &n;
        assert_eq!(acc, expected);
    }

    /// The `*` operator accepts a right-hand side of a different width.
    #[test]
    fn mul_operator_mixed_sizes() {
        let a = U128::from_u64(u64::MAX);
        let b = U64::from_u64(u64::MAX);
        assert_eq!(a * b, U128::from_u128(0xfffffffffffffffe_0000000000000001));
    }

    /// The `*` operator panics when the product does not fit.
    #[test]
    #[should_panic(expected = "attempted to multiply with overflow")]
    fn mul_operator_overflow_panics() {
        let _ = U64::MAX * U64::from_u8(2);
    }

    #[test]
    fn wrapping_mul_assign() {
        let mut w = Wrapping(U64::MAX);
        w *= Wrapping(U64::MAX);
        assert_eq!(w.0, U64::ONE);
        w *= &Wrapping(U64::MAX);
        assert_eq!(w.0, U64::MAX);
    }

    #[test]
    fn concatenating_square() {
        let n = U64::from_u64(0xffff_ffff_ffff_ffff);
        let (lo, hi) = n.concatenating_square().split();
        assert_eq!(lo, U64::from_u64(1));
        assert_eq!(hi, U64::from_u64(0xffff_ffff_ffff_fffe));
        let check = ConcatenatingSquare::concatenating_square(&n).split();
        assert_eq!(check, (lo, hi));
    }

    #[test]
    fn concatenating_square_larger() {
        let n = U256::MAX;
        let (lo, hi) = n.concatenating_square().split();
        assert_eq!(lo, U256::ONE);
        assert_eq!(hi, U256::MAX.wrapping_sub(&U256::ONE));
    }

    #[test]
    fn checked_square() {
        let n = U256::from_u64(u64::MAX).wrapping_add(&U256::ONE);
        let n2 = n.checked_square();
        assert!(n2.is_some().to_bool());
        let n4 = n2.unwrap().checked_square();
        assert!(n4.is_none().to_bool());
        let z = U256::ZERO.checked_square();
        assert!(z.is_some().to_bool());
        let m = U256::MAX.checked_square();
        assert!(m.is_none().to_bool());
    }

    #[test]
    fn wrapping_square() {
        let n = U256::from_u64(u64::MAX).wrapping_add(&U256::ONE);
        let n2 = n.wrapping_square();
        assert_eq!(n2, U256::from_u128(u128::MAX).wrapping_add(&U256::ONE));
        let n4 = n2.wrapping_square();
        assert_eq!(n4, U256::ZERO);
    }

    #[test]
    fn saturating_square() {
        let n = U256::from_u64(u64::MAX).wrapping_add(&U256::ONE);
        let n2 = n.saturating_square();
        assert_eq!(n2, U256::from_u128(u128::MAX).wrapping_add(&U256::ONE));
        let n4 = n2.saturating_square();
        assert_eq!(n4, U256::MAX);
    }

    #[cfg(feature = "rand_core")]
    #[test]
    fn mul_cmp() {
        use crate::{Random, U4096};
        use rand_core::SeedableRng;
        let mut rng = chacha20::ChaCha8Rng::seed_from_u64(1);

        let rounds = if cfg!(miri) { 10 } else { 50 };
        for _ in 0..rounds {
            let a = U4096::random_from_rng(&mut rng);
            assert_eq!(a.concatenating_mul(&a), a.concatenating_square(), "a = {a}");
            assert_eq!(a.widening_mul(&a), a.widening_square(), "a = {a}");
            assert_eq!(a.wrapping_mul(&a), a.wrapping_square(), "a = {a}");
            assert_eq!(a.saturating_mul(&a), a.saturating_square(), "a = {a}");
        }
    }

    #[test]
    fn checked_mul_sizes() {
        const SIZE_A: usize = 4;
        const SIZE_B: usize = 8;

        for n in 0..Uint::<SIZE_A>::BITS {
            let mut a = Uint::<SIZE_A>::ZERO;
            a = a.set_bit_vartime(n, true);

            for m in (0..Uint::<SIZE_B>::BITS).step_by(16) {
                let mut b = Uint::<SIZE_B>::ZERO;
                b = b.set_bit_vartime(m, true);
                let res = a.widening_mul(&b);
                let res_overflow = res.1.is_nonzero();
                let checked = a.checked_mul(&b);
                assert_eq!(checked.is_some().to_bool(), res_overflow.not().to_bool());
                assert_eq!(
                    checked.as_inner_unchecked(),
                    &res.0,
                    "a = 2**{n}, b = 2**{m}"
                );
            }
        }
    }

    #[test]
    fn checked_square_sizes() {
        const SIZE: usize = 4;

        for n in 0..Uint::<SIZE>::BITS {
            let mut a = Uint::<SIZE>::ZERO;
            a = a.set_bit_vartime(n, true);

            let res = a.widening_square();
            let res_overflow = res.1.is_nonzero();
            let checked = a.checked_square();
            assert_eq!(checked.is_some().to_bool(), res_overflow.not().to_bool());
            assert_eq!(checked.as_inner_unchecked(), &res.0, "a = 2**{n}");
        }
    }

    #[test]
    fn overflowing_mul_limb() {
        let (max_lo, max_hi) = U128::MAX.widening_mul(&U128::from(Limb::MAX));

        let result = U128::ZERO.overflowing_mul_limb(Limb::ZERO);
        assert_eq!(result, (U128::ZERO, Limb::ZERO));
        let result = U128::ZERO.overflowing_mul_limb(Limb::ONE);
        assert_eq!(result, (U128::ZERO, Limb::ZERO));
        let result = U128::MAX.overflowing_mul_limb(Limb::ZERO);
        assert_eq!(result, (U128::ZERO, Limb::ZERO));
        let result = U128::MAX.overflowing_mul_limb(Limb::ONE);
        assert_eq!(result, (U128::MAX, Limb::ZERO));
        let result = U128::MAX.overflowing_mul_limb(Limb::MAX);
        assert_eq!(result, (max_lo, max_hi.limbs[0]));

        assert_eq!(U128::ZERO.wrapping_mul_limb(Limb::ZERO), U128::ZERO);
        assert_eq!(U128::ZERO.wrapping_mul_limb(Limb::ONE), U128::ZERO);
        assert_eq!(U128::MAX.wrapping_mul_limb(Limb::ZERO), U128::ZERO);
        assert_eq!(U128::MAX.wrapping_mul_limb(Limb::ONE), U128::MAX);
        assert_eq!(U128::MAX.wrapping_mul_limb(Limb::MAX), max_lo);
    }
}
519}