composite_modulus_proofs 0.1.0

Proofs about several properties of a composite modulus: square-free, product of 2 primes, a Blum integer
Documentation
//! Proof that the modulus `N` has exactly two distinct prime divisors, so `N` may be of the form `p^a*q^b` but not `p*q*r`
//! where `p`, `q` and `r` are distinct primes. Described in section 3.4 of the paper [Efficient Noninteractive Certification of RSA Moduli and Beyond](https://eprint.iacr.org/2018/057)

use crate::{
    error::Error,
    math::{
        jacobi::jacobi_symbol_vartime, prime_check::is_prime_power,
        sqrt::sqrt_mod_composite_given_prime_factors_as_mtg_params_and_p_inv,
    },
    setup::{Modulus, Primes, PrimesWithPrecomp},
    util::uint_le_bytes,
};
use alloc::{vec, vec::Vec};
use ark_std::{cfg_into_iter, cfg_iter_mut, io::Write};
use crypto_bigint::{
    modular::{MontyForm, MontyParams, SafeGcdInverter},
    Concat, Odd, PrecomputeInverter, Split, Uint,
};
use crypto_primes::is_prime_with_rng;
use digest::{ExtendableOutput, Update};
use rand_core::CryptoRngCore;

#[cfg(feature = "parallel")]
use rayon::prelude::*;

/// Proof that a modulus `N` has exactly two distinct prime divisors.
///
/// How many challenges are issued, and therefore how many responses the proof carries, is
/// determined by the security parameter `KAPPA` (see [`num_challenges`]).
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ProofTwoPrimeDivisors<const MODULUS_LIMBS: usize, const KAPPA: usize>(
    /// One entry per challenge: `Some(s)` holds a square root `s` of the challenge, while `None`
    /// marks a challenge that is a quadratic non-residue and was therefore left unanswered.
    pub Vec<Option<Uint<MODULUS_LIMBS>>>,
);

impl<const MODULUS_LIMBS: usize, const KAPPA: usize> ProofTwoPrimeDivisors<MODULUS_LIMBS, KAPPA> {
    pub fn new<
        D: Default + Update + ExtendableOutput,
        W: Write + AsRef<[u8]>,
        const PRIME_LIMBS: usize,
        const PRIME_UNSAT_LIMBS: usize,
    >(
        primes: Primes<PRIME_LIMBS>,
        modulus: &Modulus<MODULUS_LIMBS>,
        nonce: &[u8],
        transcript: &mut W,
    ) -> Result<Self, Error>
    where
        Uint<PRIME_LIMBS>: Concat<Output = Uint<MODULUS_LIMBS>>,
        Uint<MODULUS_LIMBS>: Split<Output = Uint<PRIME_LIMBS>>,
        Odd<Uint<PRIME_LIMBS>>: PrecomputeInverter<
            Inverter = SafeGcdInverter<PRIME_LIMBS, PRIME_UNSAT_LIMBS>,
            Output = Uint<PRIME_LIMBS>,
        >,
    {
        let p_mtg = MontyParams::new(primes.p);
        let q_mtg = MontyParams::new(primes.q);
        // gcd(p, q) will be 1
        let p_inv = MontyForm::new(p_mtg.modulus(), q_mtg).inv().unwrap();
        Self::new_inner::<D, W, PRIME_LIMBS>(p_mtg, q_mtg, p_inv, modulus, nonce, transcript)
    }

    pub fn new_given_precomputation<
        D: Default + Update + ExtendableOutput,
        W: Write + AsRef<[u8]>,
        const PRIME_LIMBS: usize,
    >(
        primes: PrimesWithPrecomp<PRIME_LIMBS, MODULUS_LIMBS>,
        modulus: &Modulus<MODULUS_LIMBS>,
        nonce: &[u8],
        transcript: &mut W,
    ) -> Result<Self, Error>
    where
        Uint<PRIME_LIMBS>: Concat<Output = Uint<MODULUS_LIMBS>>,
    {
        Self::new_inner::<D, W, PRIME_LIMBS>(
            primes.p_mtg,
            primes.q_mtg,
            primes.p_inv,
            modulus,
            nonce,
            transcript,
        )
    }

    pub fn num_responses(&self) -> usize {
        self.0.len()
    }

    pub fn verify<
        R: CryptoRngCore,
        D: Default + Update + ExtendableOutput,
        W: Write + AsRef<[u8]>,
    >(
        &self,
        rng: &mut R,
        modulus: &Modulus<MODULUS_LIMBS>,
        nonce: &[u8],
        transcript: &mut W,
    ) -> Result<(), Error> {
        if is_prime_with_rng(rng, modulus.0.as_ref()) {
            return Err(Error::ModulusIsPrime);
        }
        if is_prime_power(rng, &modulus.0).is_some() {
            return Err(Error::ModulusIsPrimePower);
        }
        let num_chals = num_challenges::<KAPPA>();
        if self.0.len() != num_chals as usize {
            return Err(Error::InsufficientResponses(
                self.0.len(),
                num_chals as usize,
            ));
        }
        let num_quadratic_residues = self.0.iter().filter(|s| s.is_some()).count();
        if num_quadratic_residues < 3 * (num_chals as usize / 8) {
            return Err(Error::InsufficientQuadraticResidues(
                num_quadratic_residues,
                3 * (num_chals as usize / 8),
            ));
        }
        let challenges = Self::challenges::<D, W>(modulus, nonce, num_chals, transcript);
        let params = MontyParams::new_vartime(modulus.0);
        #[allow(unused_mut)]
        let mut res = cfg_into_iter!(0..num_chals as usize).map(|i| {
            if let Some(s) = self.0[i] {
                if !modulus.is_greater_than(&s) {
                    return Err(Error::GreaterThanModulus(i as u32));
                }
                if MontyForm::new(&s, params).square().retrieve() != challenges[i] {
                    return Err(Error::InvalidProofForIndex(i as u32));
                }
                Ok(())
            } else {
                Ok(())
            }
        });

        report_error_if_any!(res)
    }

    /// Generate challenges by rejection sampling
    pub fn challenges<D: Default + Update + ExtendableOutput, W: Write + AsRef<[u8]>>(
        modulus: &Modulus<MODULUS_LIMBS>,
        nonce: &[u8],
        num_challenges: u32,
        transcript: &mut W,
    ) -> Vec<Uint<MODULUS_LIMBS>> {
        let mut chals = vec![Uint::<MODULUS_LIMBS>::ZERO; num_challenges as usize];
        let mut ctr = 0;
        transcript
            .write_all(b"proof-that-modulus-has-2-prime-divisors")
            .unwrap();
        transcript.write_all(&uint_le_bytes(&modulus.0)).unwrap();
        // Not hashing in MODULUS_LIMBS since it can change if proof generation and verification happen on different arch. (64 bit vs 32 bit)
        transcript
            .write_all(modulus.size_for_hashing().as_slice())
            .unwrap();
        transcript.write_all(nonce).unwrap();
        let mut c_bytes = vec![0; Uint::<MODULUS_LIMBS>::BYTES];
        let mut i = 0_u32;
        while ctr < num_challenges {
            transcript.write_all(i.to_le_bytes().as_slice()).unwrap();
            D::digest_xof(transcript.as_ref(), &mut c_bytes);
            let c = Uint::<MODULUS_LIMBS>::from_le_slice(&c_bytes);
            // NOTE: For prover, jacobi could be computed using primes making proof gen faster
            // NOTE: Computing Jacobi symbol in variable time is fine since both the modulus and challenges are public
            if modulus.is_greater_than(&c)
                && jacobi_symbol_vartime(c.clone(), modulus.0.clone()).is_one()
            {
                chals[ctr as usize] = c;
                ctr += 1;
            }
            i += 1;
        }
        chals
    }

    pub fn new_inner<
        D: Default + Update + ExtendableOutput,
        W: Write + AsRef<[u8]>,
        const PRIME_LIMBS: usize,
    >(
        p_mtg: MontyParams<PRIME_LIMBS>,
        q_mtg: MontyParams<PRIME_LIMBS>,
        p_inv: MontyForm<PRIME_LIMBS>,
        modulus: &Modulus<MODULUS_LIMBS>,
        nonce: &[u8],
        transcript: &mut W,
    ) -> Result<Self, Error>
    where
        Uint<PRIME_LIMBS>: Concat<Output = Uint<MODULUS_LIMBS>>,
    {
        let num_chals = num_challenges::<KAPPA>();
        let challenges = Self::challenges::<D, W>(modulus, nonce, num_chals, transcript);
        let mut responses = vec![None; num_chals as usize];
        cfg_iter_mut!(responses).enumerate().for_each(|(i, s)| {
            *s = sqrt_mod_composite_given_prime_factors_as_mtg_params_and_p_inv::<
                PRIME_LIMBS,
                MODULUS_LIMBS,
            >(challenges[i], p_mtg, q_mtg, p_inv);
        });
        Ok(Self(responses))
    }
}

/// Number of challenges `m` used for security parameter `KAPPA`:
/// `m = ceil(32 * KAPPA * ln(2))`.
///
/// With this choice `exp(-m / 32) <= 2^-KAPPA`.
pub const fn num_challenges<const KAPPA: usize>() -> u32 {
    // ln(2) = 0.6931471..., over-approximated as LN2_NUM / LN2_DEN = 0.69315. Rounding this way
    // can only add challenges, never remove them. Do not use 0.3: that is log10(2), not ln(2),
    // and yields fewer than half of the required challenges.
    const LN2_NUM: u64 = 69_315;
    const LN2_DEN: u64 = 100_000;
    // Work in u64 so the intermediate product cannot overflow on 32-bit targets.
    let scaled = (KAPPA as u64) * 32 * LN2_NUM;
    // ceil(scaled / LN2_DEN)
    scaled.div_ceil(LN2_DEN) as u32
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{safegcd_nlimbs, util::timing_info};
    use crypto_bigint::{U128, U64};
    use rand_core::OsRng;
    use sha3::Shake256;
    use std::time::Instant;

    /// Repeatedly prove (with and without precomputation) and verify for the given primes,
    /// printing timing statistics.
    macro_rules! check_given_primes {
        ( $num_iters: ident, $prime_type:ident, $primes: ident ) => {
            const KAPPA: usize = 128;
            const NUM_CHALLENGES: usize = num_challenges::<KAPPA>() as usize;

            const PRIME_LIMBS: usize = $prime_type::LIMBS;
            const PRIME_UNSAT_LIMBS: usize = safegcd_nlimbs!(Uint::<PRIME_LIMBS>::BITS as usize);
            const MODULUS_LIMBS: usize = PRIME_LIMBS * 2;

            let mut rng = OsRng::default();
            let modulus = Modulus::<MODULUS_LIMBS>::new(&$primes);
            let primes_with_crt = PrimesWithPrecomp::from($primes.clone());
            let nonce = b"123";

            let mut prove_times = vec![];
            let mut prove_with_precomp_times = vec![];
            let mut ver_times = vec![];

            println!(
                "Running {} iterations for {} bits prime - {} challenges for {} bits of security",
                $num_iters,
                $prime_type::BITS,
                NUM_CHALLENGES,
                KAPPA
            );
            for _ in 0..$num_iters {
                let mut transcript = vec![];
                let start = Instant::now();
                let proof = ProofTwoPrimeDivisors::<MODULUS_LIMBS, KAPPA>::new::<
                    Shake256,
                    _,
                    PRIME_LIMBS,
                    PRIME_UNSAT_LIMBS,
                >($primes.clone(), &modulus, nonce, &mut transcript)
                .unwrap();
                prove_times.push(start.elapsed());

                assert_eq!(proof.num_responses(), NUM_CHALLENGES);

                let mut transcript = vec![];
                let start = Instant::now();
                proof
                    .verify::<OsRng, Shake256, _>(&mut rng, &modulus, nonce, &mut transcript)
                    .unwrap();
                ver_times.push(start.elapsed());

                let mut transcript = vec![];
                let start = Instant::now();
                let proof =
                    ProofTwoPrimeDivisors::<MODULUS_LIMBS, KAPPA>::new_given_precomputation::<
                        Shake256,
                        _,
                        PRIME_LIMBS,
                    >(primes_with_crt.clone(), &modulus, nonce, &mut transcript)
                    .unwrap();
                prove_with_precomp_times.push(start.elapsed());

                assert_eq!(proof.num_responses(), NUM_CHALLENGES);

                let mut transcript = vec![];
                let start = Instant::now();
                proof
                    .verify::<OsRng, Shake256, _>(&mut rng, &modulus, nonce, &mut transcript)
                    .unwrap();
                ver_times.push(start.elapsed());
            }

            println!("Proving time: {:?}", timing_info(prove_times));
            println!(
                "Proving time with precomputation: {:?}",
                timing_info(prove_with_precomp_times)
            );
            println!("Verification time: {:?}", timing_info(ver_times));
        };
    }

    #[test]
    fn bad_proofs() {
        const KAPPA: usize = 128;

        const PRIME_LIMBS: usize = U64::LIMBS;
        const PRIME_UNSAT_LIMBS: usize = safegcd_nlimbs!(Uint::<PRIME_LIMBS>::BITS as usize);
        const MODULUS_LIMBS: usize = PRIME_LIMBS * 2;

        let mut rng = OsRng::default();
        let primes = Primes::<{ U64::LIMBS }>::new(&mut rng);
        let modulus = Modulus::<MODULUS_LIMBS>::new(&primes);
        let nonce = b"123";

        let mut transcript = vec![];
        let proof = ProofTwoPrimeDivisors::<MODULUS_LIMBS, KAPPA>::new::<
            Shake256,
            _,
            PRIME_LIMBS,
            PRIME_UNSAT_LIMBS,
        >(primes.clone(), &modulus, nonce, &mut transcript)
        .unwrap();

        let mut transcript = vec![];
        proof
            .verify::<OsRng, Shake256, _>(&mut rng, &modulus, nonce, &mut transcript)
            .unwrap();

        // Number of responses < number of challenges
        let mut bad_proof = proof.clone();
        bad_proof.0.remove(0);
        let mut transcript = vec![];
        assert!(matches!(
            bad_proof.verify::<OsRng, Shake256, _>(&mut rng, &modulus, nonce, &mut transcript),
            Err(Error::InsufficientResponses(..))
        ));

        // A tampered response must be rejected
        let mut bad_proof = proof.clone();
        let idx = bad_proof
            .0
            .iter()
            .position(|r| r.is_some())
            .expect("an honest proof answers some challenges");
        let s = bad_proof.0[idx].unwrap();
        bad_proof.0[idx] = Some(s.wrapping_add(&Uint::ONE));
        let mut transcript = vec![];
        assert!(bad_proof
            .verify::<OsRng, Shake256, _>(&mut rng, &modulus, nonce, &mut transcript)
            .is_err());

        // Modulus of form p^2 should fail. The proof must be verified against the same
        // (p^2) modulus it was created for; verifying against `modulus` would fail merely
        // because the challenges differ.
        let mut same_primes = primes.clone();
        same_primes.q = same_primes.p;
        let modulus_with_square = Modulus::<MODULUS_LIMBS>::new(&same_primes);

        let mut transcript = vec![];
        let proof_for_square = ProofTwoPrimeDivisors::<MODULUS_LIMBS, KAPPA>::new::<
            Shake256,
            _,
            PRIME_LIMBS,
            PRIME_UNSAT_LIMBS,
        >(primes.clone(), &modulus_with_square, nonce, &mut transcript)
        .unwrap();
        let mut transcript = vec![];
        assert!(proof_for_square
            .verify::<OsRng, Shake256, _>(
                &mut rng,
                &modulus_with_square,
                nonce,
                &mut transcript
            )
            .is_err());
    }

    #[test]
    fn proof() {
        let mut rng = OsRng::default();
        let primes = Primes::<{ U128::LIMBS }>::new(&mut rng);
        let num_iters = 10;
        check_given_primes!(num_iters, U128, primes);
    }

    // TODO: Uncomment after optimizing prime-power check
    // #[test]
    // fn proof_with_1024_bit_primes() {
    //     let (p, q) = get_1024_bit_primes();
    //     let primes = Primes::from_primes(p, q);
    //     let num_iters = 10;
    //     check_given_primes!(num_iters, U1024, primes);
    // }
    //
    // #[test]
    // fn proof_with_2048_bit_primes() {
    //     let (p, q) = get_2048_bit_primes();
    //     let primes = Primes::from_primes(p, q);
    //     let num_iters = 10;
    //     check_given_primes!(num_iters, U2048, primes);
    // }
}