// cnfy-uint 0.2.3
// Zero-dependency 256-bit unsigned integer arithmetic for cryptographic applications.
//! Widening multiplication of [`U256`] by a `u128` scalar into a [`U384`].
use super::U256;
use crate::u384::U384;
use core::ops::Mul;

/// Widening multiplication of a 256-bit integer by a `u128` scalar,
/// producing a 384-bit result.
///
/// Splits `rhs` into two 64-bit digits (least significant first) and runs
/// one pass of schoolbook multiplication per digit, accumulating into six
/// little-endian result limbs with carry propagation. The result is
/// returned as a [`U384`].
///
/// This operation cannot overflow: `(2^256 − 1) × (2^128 − 1) < 2^384`.
///
/// # Examples
///
/// ```
/// use cnfy_uint::u256::U256;
/// use cnfy_uint::u384::U384;
///
/// let a = U256::from_be_limbs([0, 0, 0, 7]);
/// let result = a * 6u128;
/// assert_eq!(result, U384::from_be_limbs([0, 0, 0, 0, 0, 42]));
/// ```
impl Mul<u128> for U256 {
    type Output = U384;

    #[inline]
    fn mul(self, rhs: u128) -> U384 {
        // The scalar as two 64-bit digits, least significant first.
        let digits = [rhs as u64, (rhs >> 64) as u64];

        // Six little-endian result limbs: 256 + 128 bits always fit in 384.
        let mut acc = [0u64; 6];

        // Schoolbook: for digit `d` at limb offset `shift`, fold
        // `self × d` into the accumulator starting at limb `shift`.
        for (shift, &d) in digits.iter().enumerate() {
            let mut carry = 0u64;
            for (i, &limb) in self.0.iter().enumerate() {
                // limb·d + acc + carry ≤ (2^64 − 1)² + 2·(2^64 − 1)
                //                      = 2^128 − 1, so u128 never wraps.
                let t = (limb as u128) * (d as u128)
                    + (acc[shift + i] as u128)
                    + (carry as u128);
                acc[shift + i] = t as u64;
                carry = (t >> 64) as u64;
            }
            // The limb just above this partial product is still zero
            // (shift 0 hasn't touched limb 4; shift 1 hasn't touched
            // limb 5), so the final carry can be stored directly.
            acc[shift + 4] = carry;
        }

        U384(acc)
    }
}

#[cfg(test)]
mod ai_tests {
    use super::*;

    /// Sanity check: 7 × 6 = 42.
    #[test]
    fn small_product() {
        let seven = U256::from_be_limbs([0, 0, 0, 7]);
        let expected = U384::from_be_limbs([0, 0, 0, 0, 0, 42]);
        assert_eq!(seven * 6u128, expected);
    }

    /// Zero is annihilating: x × 0 = 0.
    #[test]
    fn mul_by_zero() {
        let x = U256::from_be_limbs([0x1234, 0x5678, 0x9ABC, 0xDEF0]);
        assert_eq!(x * 0u128, U384::from_be_limbs([0; 6]));
    }

    /// Multiplicative identity: x × 1 = x in the lower 256 bits.
    #[test]
    fn mul_by_one() {
        let limbs = [0x1234, 0x5678, 0x9ABC, 0xDEF0];
        let x = U256::from_be_limbs(limbs);
        let widened =
            U384::from_be_limbs([0, 0, limbs[0], limbs[1], limbs[2], limbs[3]]);
        assert_eq!(x * 1u128, widened);
    }

    /// Multiplying by a value that only uses the high 64 bits of u128.
    #[test]
    fn mul_by_high_u128() {
        let three = U256::from_be_limbs([0, 0, 0, 3]);
        // 3 × (5 << 64) = 15 << 64 → limb index 4 from the top holds 15.
        assert_eq!(
            three * (5u128 << 64),
            U384::from_be_limbs([0, 0, 0, 0, 15, 0]),
        );
    }

    /// MAX × MAX_u128 produces the correct 384-bit result.
    #[test]
    fn max_times_max_u128() {
        // (2^256 − 1)(2^128 − 1) = 2^384 − 2^256 − 2^128 + 1
        let m = u64::MAX;
        let expected = U384::from_be_limbs([m, m - 1, m, m, 0, 1]);
        assert_eq!(U256::MAX * u128::MAX, expected);
    }

    /// Carry propagation across all limbs.
    #[test]
    fn full_width_carry() {
        let m = u64::MAX;
        let all_ones = U256::from_be_limbs([m, m, m, m]);
        // (2^256 − 1) × 2 = 2^257 − 2
        let expected = U384::from_be_limbs([0, 1, m, m, m, m - 1]);
        assert_eq!(all_ones * 2u128, expected);
    }

    /// Cross-validates against `Mul<U256>` for a u128 that fits in u64.
    #[test]
    fn matches_mul_for_small_rhs() {
        let a = U256::from_be_limbs([0xAAAA, 0xBBBB, 0xCCCC, 0xDDDD]);
        let rhs = 0x12345678u128;

        // Same product via the U256 × U256 path (a U512, [u64; 8] BE).
        let wide = a * U256::from_be_limbs([0, 0, 0, rhs as u64]);
        let w = wide.to_be_limbs();

        // A 256-bit × 64-bit product occupies at most 320 bits, so the
        // two most-significant limbs of the 512-bit result must be zero.
        assert_eq!((w[0], w[1]), (0, 0));
        assert_eq!(
            a * rhs,
            U384::from_be_limbs([w[2], w[3], w[4], w[5], w[6], w[7]]),
        );
    }

    /// Full u128 scalar: multiplying by (2^64 + 1) doubles and adds.
    #[test]
    fn mul_by_two_pow_64_plus_one() {
        let a = U256::from_be_limbs([0, 0, 0, u64::MAX]);
        // (2^64 − 1)(2^64 + 1) = 2^128 − 1
        let expected =
            U384::from_be_limbs([0, 0, 0, 0, u64::MAX, u64::MAX]);
        assert_eq!(a * ((1u128 << 64) + 1), expected);
    }
}