cryptography-rs 0.6.2

Block ciphers, hashes, public-key, and post-quantum primitives implemented directly from their specifications and original papers.
Documentation
//! RSA public-key primitive (Rivest, Shamir, Adleman, 1978).
//!
//! This module exposes the core RSA trapdoor permutation directly: key
//! derivation from explicit primes plus modular exponentiation for
//! encrypt/decrypt. Standards-based message formatting lives in `rsa_pkcs1`,
//! and standard key containers live in `rsa_io`.

use core::fmt;

use crate::public_key::bigint::{BigUint, MontgomeryCtx};
use crate::public_key::primes::{
    gcd, is_probable_prime, lcm, mod_inverse, mod_pow, random_probable_prime,
};
use crate::Csprng;

/// Public key for the core RSA primitive.
///
/// Holds the public exponent `e` and the modulus `n`; `encrypt_raw` computes
/// `m^e mod n` from exactly these two values. Both components are public, so
/// the derived `Debug` prints them.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct RsaPublicKey {
    /// Public exponent `e`.
    e: BigUint,
    /// Modulus `n` (equal to `p * q` for keys built by `Rsa`).
    n: BigUint,
}

/// Private key for the core RSA primitive.
///
/// Stores the full private exponent together with the CRT components that
/// `decrypt_raw` uses: the prime factors, the reduced exponents, the
/// recombination coefficient, and a Montgomery context per prime.
///
/// `Debug` is deliberately not derived; a hand-written impl prints a redacted
/// placeholder instead of key material.
///
/// NOTE(review): the derived `PartialEq` delegates to `BigUint`'s equality,
/// which may not be constant-time. Confirm this before comparing keys on any
/// attacker-influenced path.
#[derive(Clone, Eq, PartialEq)]
pub struct RsaPrivateKey {
    /// Public exponent `e` (exposed in-crate via `public_exponent`).
    e: BigUint,
    /// Private exponent `d`; `Rsa` derives it as `e^-1 mod lcm(p - 1, q - 1)`.
    d: BigUint,
    /// Modulus `n = p * q`.
    n: BigUint,
    /// First prime factor.
    p: BigUint,
    /// Second prime factor.
    q: BigUint,
    /// CRT exponent `dP = d mod (p - 1)`.
    d_p: BigUint,
    /// CRT exponent `dQ = d mod (q - 1)`.
    d_q: BigUint,
    /// CRT recombination coefficient `qInv = q^-1 mod p`.
    q_inv: BigUint,
    /// Montgomery context for modulus `p`, used to compute `c^dP mod p`.
    p_ctx: MontgomeryCtx,
    /// Montgomery context for modulus `q`, used to compute `c^dQ mod q`.
    q_ctx: MontgomeryCtx,
}

/// Namespace wrapper for the core RSA construction.
///
/// Carries no state. Key derivation (`from_primes`,
/// `from_primes_with_exponent`) and key generation (`generate`,
/// `generate_with_exponent`) are associated functions on this type.
pub struct Rsa;

impl RsaPublicKey {
    /// Assemble a public key from a raw exponent and modulus.
    ///
    /// Neither value is validated here.
    #[must_use]
    pub(crate) fn from_components(e: BigUint, n: BigUint) -> Self {
        Self { n, e }
    }

    /// Borrow the public exponent `e`.
    #[must_use]
    pub fn exponent(&self) -> &BigUint {
        &self.e
    }

    /// Borrow the modulus `n` (equal to `p * q` for keys built by `Rsa`).
    #[must_use]
    pub fn modulus(&self) -> &BigUint {
        &self.n
    }

    /// Compute the raw public operation `message^e mod n`.
    ///
    /// This is textbook RSA: a deterministic trapdoor permutation with no
    /// padding and no randomness. Equal messages therefore map to equal
    /// ciphertexts. That missing semantic security is the gap that padding
    /// schemes such as OAEP close on top of this arithmetic.
    ///
    /// The result depends only on `message mod n`, so only messages below
    /// the modulus survive a round trip through `decrypt_raw`.
    ///
    /// Keys produced by `Rsa::from_primes` / `Rsa::generate` use an exponent
    /// of the form `2^k + 1` (starting at `65_537 = 0x10001`), which has just
    /// two set bits. That sparsity is why this is usually much cheaper than
    /// the private-key `decrypt_raw`.
    #[must_use]
    pub fn encrypt_raw(&self, message: &BigUint) -> BigUint {
        let (exponent, modulus) = (self.exponent(), self.modulus());
        mod_pow(message, exponent, modulus)
    }
}

impl RsaPrivateKey {
    /// Borrow the private exponent `d`.
    #[must_use]
    pub fn exponent(&self) -> &BigUint {
        &self.d
    }

    /// Borrow the modulus `n = p * q` shared with the matching public key.
    #[must_use]
    pub fn modulus(&self) -> &BigUint {
        &self.n
    }

    /// Borrow the public exponent `e` stored alongside the private material.
    #[must_use]
    pub(crate) fn public_exponent(&self) -> &BigUint {
        &self.e
    }

    /// Borrow the first prime factor `p`.
    #[must_use]
    pub(crate) fn prime1(&self) -> &BigUint {
        &self.p
    }

    /// Borrow the second prime factor `q`.
    #[must_use]
    pub(crate) fn prime2(&self) -> &BigUint {
        &self.q
    }

    /// Borrow the CRT exponent `dP = d mod (p - 1)`.
    #[must_use]
    pub(crate) fn crt_exponent1(&self) -> &BigUint {
        &self.d_p
    }

    /// Borrow the CRT exponent `dQ = d mod (q - 1)`.
    #[must_use]
    pub(crate) fn crt_exponent2(&self) -> &BigUint {
        &self.d_q
    }

    /// Borrow the CRT coefficient `qInv = q^-1 mod p`.
    #[must_use]
    pub(crate) fn crt_coefficient(&self) -> &BigUint {
        &self.q_inv
    }

    /// Apply the raw private operation using CRT recombination.
    ///
    /// Instead of one exponentiation by the full-size `d` modulo `n`, this
    /// performs two half-size exponentiations (by `dP` modulo `p` and by `dQ`
    /// modulo `q`) and stitches the results together. Those private exponents
    /// are still dense, so even with CRT the sparse-exponent `encrypt_raw` is
    /// usually faster.
    ///
    /// Only `ciphertext mod p` and `ciphertext mod q` are used, so a
    /// ciphertext at or above `n` behaves like `ciphertext mod n`.
    #[must_use]
    pub fn decrypt_raw(&self, ciphertext: &BigUint) -> BigUint {
        // RSADP with CRT (RFC 8017, section 5.1.2):
        //   m_p = c^dP mod p
        //   m_q = c^dQ mod q
        //   h   = qInv * (m_p - m_q) mod p
        //   m   = m_q + q * h
        let m_p = self.p_ctx.pow(&ciphertext.modulo(&self.p), &self.d_p);
        let m_q = self.q_ctx.pow(&ciphertext.modulo(&self.q), &self.d_q);

        // m_q is only reduced mod q, and q may exceed p, so bring it into
        // [0, p) first. Adding p when m_p is the smaller operand then keeps
        // the difference non-negative and inside [0, p).
        let m_q_reduced = m_q.modulo(&self.p);
        let diff = if m_p < m_q_reduced {
            m_p.add_ref(&self.p).sub_ref(&m_q_reduced)
        } else {
            m_p.sub_ref(&m_q_reduced)
        };

        let h = BigUint::mod_mul(&self.q_inv, &diff, &self.p);
        self.q.mul_ref(&h).add_ref(&m_q)
    }
}

impl fmt::Debug for RsaPrivateKey {
    /// Emit a fixed placeholder so no secret component ever reaches logs,
    /// panic messages, or `{:?}` output.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "RsaPrivateKey(<redacted>)")
    }
}

impl Rsa {
    /// Derive a raw RSA key pair from explicit primes and an explicit exponent.
    ///
    /// Returns `None` if the inputs are equal, composite, the exponent is not
    /// greater than one, or the exponent is not invertible modulo
    /// `lambda = lcm(p - 1, q - 1)`.
    #[must_use]
    pub fn from_primes_with_exponent(
        p: &BigUint,
        q: &BigUint,
        exponent: &BigUint,
    ) -> Option<(RsaPublicKey, RsaPrivateKey)> {
        if p == q || !is_probable_prime(p) || !is_probable_prime(q) {
            return None;
        }
        if exponent <= &BigUint::one() {
            return None;
        }

        let p_minus_one = p.sub_ref(&BigUint::one());
        let q_minus_one = q.sub_ref(&BigUint::one());
        let lambda = lcm(&p_minus_one, &q_minus_one);
        if gcd(exponent, &lambda) != BigUint::one() {
            return None;
        }

        let d = mod_inverse(exponent, &lambda)?;
        let n = p.mul_ref(q);
        let d_p = d.modulo(&p_minus_one);
        let d_q = d.modulo(&q_minus_one);
        let q_inv = mod_inverse(q, p)?;
        let p_ctx = MontgomeryCtx::new(p)?;
        let q_ctx = MontgomeryCtx::new(q)?;

        Some((
            RsaPublicKey {
                e: exponent.clone(),
                n: n.clone(),
            },
            RsaPrivateKey {
                e: exponent.clone(),
                d,
                n,
                p: p.clone(),
                q: q.clone(),
                d_p,
                d_q,
                q_inv,
                p_ctx,
                q_ctx,
            },
        ))
    }

    /// Derive a raw RSA key pair from explicit primes using the Python
    /// reference's default exponent search.
    ///
    /// The search starts at `2^16 + 1 = 65_537`, the standard sparse public
    /// exponent: it is prime, has only two set bits, and therefore keeps the
    /// public operation cheap. The loop then increments the power until it
    /// finds a value coprime to `lambda = lcm(p - 1, q - 1)`. This terminates
    /// quickly in practice because `lambda` has only finitely many prime
    /// factors, so some Fermat-like exponent in the sequence must be coprime
    /// to it.
    #[must_use]
    pub fn from_primes(p: &BigUint, q: &BigUint) -> Option<(RsaPublicKey, RsaPrivateKey)> {
        if p == q || !is_probable_prime(p) || !is_probable_prime(q) {
            return None;
        }

        let p_minus_one = p.sub_ref(&BigUint::one());
        let q_minus_one = q.sub_ref(&BigUint::one());
        let lambda = lcm(&p_minus_one, &q_minus_one);

        let mut exponent_bit = 16usize;
        loop {
            let mut exponent = BigUint::zero();
            exponent.set_bit(exponent_bit);
            exponent = exponent.add_ref(&BigUint::one());
            if gcd(&exponent, &lambda) == BigUint::one() {
                return Self::from_primes_with_exponent(p, q, &exponent);
            }
            exponent_bit += 1;
        }
    }

    /// Generate an RSA key pair from a CSPRNG and explicit public exponent.
    ///
    /// This keeps the arithmetic primitive usable without forcing callers to
    /// provide their own prime search. The generated primes are screened with
    /// the in-tree Miller-Rabin helper rather than a dedicated external
    /// multiprecision backend, so this remains the crate's built-in reference
    /// key-generation path rather than a substitute for a hardened PKI stack.
    #[must_use]
    pub fn generate_with_exponent<R: Csprng>(
        rng: &mut R,
        bits: usize,
        exponent: &BigUint,
    ) -> Option<(RsaPublicKey, RsaPrivateKey)> {
        // Below 32 total bits, the split primes become so small that the key
        // space is trivially enumerable and the "search until invertible"
        // exponent logic stops being meaningful as a cryptographic API.
        if bits < 32 {
            return None;
        }

        let p_bits = bits / 2;
        let q_bits = bits - p_bits;
        loop {
            let p = random_probable_prime(rng, p_bits)?;
            let q = random_probable_prime(rng, q_bits)?;
            if let Some(keypair) = Self::from_primes_with_exponent(&p, &q, exponent) {
                return Some(keypair);
            }
        }
    }

    /// Generate an RSA key pair using the Python reference's default exponent
    /// search.
    #[must_use]
    pub fn generate<R: Csprng>(rng: &mut R, bits: usize) -> Option<(RsaPublicKey, RsaPrivateKey)> {
        // Match the explicit-exponent generator's floor for the same reason:
        // below 32 bits the result is too small to be a meaningful RSA key.
        if bits < 32 {
            return None;
        }

        let p_bits = bits / 2;
        let q_bits = bits - p_bits;
        loop {
            let p = random_probable_prime(rng, p_bits)?;
            let q = random_probable_prime(rng, q_bits)?;
            if let Some(keypair) = Self::from_primes(&p, &q) {
                return Some(keypair);
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::Rsa;
    use crate::public_key::bigint::BigUint;
    use crate::CtrDrbgAes256;

    /// Textbook example primes giving `n = 61 * 53 = 3_233`.
    fn classic_primes() -> (BigUint, BigUint) {
        (BigUint::from_u64(61), BigUint::from_u64(53))
    }

    #[test]
    fn derive_reference_key_with_default_exponent() {
        let (p, q) = classic_primes();
        let (public, private) = Rsa::from_primes(&p, &q).expect("valid RSA key");
        assert_eq!(public.modulus(), &BigUint::from_u64(3_233));
        assert_eq!(public.exponent(), &BigUint::from_u64(65_537));
        assert_eq!(private.exponent(), &BigUint::from_u64(413));
        assert_eq!(private.modulus(), &BigUint::from_u64(3_233));
    }

    #[test]
    fn roundtrip_small_messages() {
        let (p, q) = classic_primes();
        let (public, private) = Rsa::from_primes(&p, &q).expect("valid RSA key");

        for msg in [0u64, 1, 2, 65, 123, 3_232] {
            let message = BigUint::from_u64(msg);
            let ciphertext = public.encrypt_raw(&message);
            let plaintext = private.decrypt_raw(&ciphertext);
            assert_eq!(plaintext, message);
        }
    }

    #[test]
    fn exact_small_ciphertext_matches_reference() {
        let (p, q) = classic_primes();
        let (public, private) = Rsa::from_primes(&p, &q).expect("valid RSA key");
        let message = BigUint::from_u64(65);
        let ciphertext = public.encrypt_raw(&message);
        assert_eq!(ciphertext, BigUint::from_u64(2_790));
        assert_eq!(private.decrypt_raw(&ciphertext), message);
    }

    #[test]
    fn raw_rsa_is_multiplicatively_homomorphic() {
        let (p, q) = classic_primes();
        let (public, private) = Rsa::from_primes(&p, &q).expect("valid RSA key");
        let left = BigUint::from_u64(12);
        let right = BigUint::from_u64(17);

        let left_cipher = public.encrypt_raw(&left);
        let right_cipher = public.encrypt_raw(&right);
        let combined_cipher = BigUint::mod_mul(&left_cipher, &right_cipher, public.modulus());
        let decrypted = private.decrypt_raw(&combined_cipher);
        let expected = BigUint::mod_mul(&left, &right, public.modulus());

        assert_eq!(decrypted, expected);
    }

    #[test]
    fn explicit_exponent_matches_classic_example() {
        let (p, q) = classic_primes();
        let exponent = BigUint::from_u64(17);
        let (public, private) =
            Rsa::from_primes_with_exponent(&p, &q, &exponent).expect("valid RSA key");
        assert_eq!(public.exponent(), &BigUint::from_u64(17));
        assert_eq!(private.exponent(), &BigUint::from_u64(413));
    }

    #[test]
    fn rejects_non_invertible_exponent() {
        let p = BigUint::from_u64(11);
        let q = BigUint::from_u64(13);
        let exponent = BigUint::from_u64(3);
        assert!(Rsa::from_primes_with_exponent(&p, &q, &exponent).is_none());
    }

    #[test]
    fn rejects_trivial_exponents() {
        let (p, q) = classic_primes();
        for e in [0u64, 1] {
            let exponent = BigUint::from_u64(e);
            assert!(Rsa::from_primes_with_exponent(&p, &q, &exponent).is_none());
        }
    }

    #[test]
    fn rejects_equal_or_composite_primes() {
        let (p, _) = classic_primes();
        let exponent = BigUint::from_u64(17);
        // Equal primes: n = p^2 is not a valid RSA modulus.
        assert!(Rsa::from_primes_with_exponent(&p, &p, &exponent).is_none());
        assert!(Rsa::from_primes(&p, &p).is_none());

        // 51 = 3 * 17 is composite.
        let composite = BigUint::from_u64(51);
        assert!(Rsa::from_primes_with_exponent(&p, &composite, &exponent).is_none());
        assert!(Rsa::from_primes(&composite, &p).is_none());
    }

    #[test]
    fn generate_keypair_roundtrip() {
        let seed = [0x55u8; 48];
        let mut drbg = CtrDrbgAes256::new(&seed);
        let (public, private) = Rsa::generate(&mut drbg, 64).expect("generated RSA key");
        assert!(public.modulus().bits() >= 63);
        let message = BigUint::from_u64(42);
        let ciphertext = public.encrypt_raw(&message);
        assert_eq!(private.decrypt_raw(&ciphertext), message);
    }

    #[test]
    fn generate_with_explicit_exponent_roundtrip() {
        let seed = [0x55u8; 48];
        let mut drbg = CtrDrbgAes256::new(&seed);
        let exponent = BigUint::from_u64(65_537);
        let (public, private) =
            Rsa::generate_with_exponent(&mut drbg, 64, &exponent).expect("generated RSA key");
        assert_eq!(public.exponent(), &exponent);
        assert!(public.modulus().bits() >= 63);
        let message = BigUint::from_u64(42);
        let ciphertext = public.encrypt_raw(&message);
        assert_eq!(private.decrypt_raw(&ciphertext), message);
    }

    #[test]
    fn generators_reject_moduli_below_floor() {
        let seed = [0x55u8; 48];
        let mut drbg = CtrDrbgAes256::new(&seed);
        assert!(Rsa::generate(&mut drbg, 31).is_none());
        let exponent = BigUint::from_u64(65_537);
        assert!(Rsa::generate_with_exponent(&mut drbg, 31, &exponent).is_none());
    }
}