composite_modulus_proofs 0.1.0

Proofs about several properties of a composite modulus: that it is square-free, a product of two primes, or a Blum integer
Documentation
use crate::{error::Error, math};
use alloc::vec::Vec;
use core::ops::Sub;
use crypto_bigint::{
    modular::SafeGcdInverter, Integer, NonZero, Odd, PrecomputeInverter, RandomMod, Uint, Zero,
};
use crypto_primes::generate_prime_with_rng;
use rand_core::CryptoRngCore;

/// Generate a random Blum prime, i.e. a prime `p` with `p ≡ 3 (mod 4)`.
///
/// Blum primes allow for an efficient proof that a composite number factors into Blum primes.
/// Safe primes are also Blum primes.
///
/// Candidates are requested from `generate_prime_with_rng` with a bit length of
/// `Uint::<PRIME_LIMBS>::BITS`. Any candidate that is not `3 mod 4` is discarded and a fresh one
/// is drawn (rejection sampling).
pub fn blum_prime<R: CryptoRngCore, const PRIME_LIMBS: usize>(rng: &mut R) -> Uint<PRIME_LIMBS> {
    loop {
        let candidate: Uint<PRIME_LIMBS> =
            generate_prime_with_rng(rng, Uint::<PRIME_LIMBS>::BITS);
        if math::misc::is_3_mod_4(&candidate) {
            return candidate;
        }
    }
}

/// Serialize `n` to bytes by writing each word of `n.as_words()` in little-endian byte order.
///
/// Words are emitted in the order `as_words()` yields them. The output length is therefore
/// `LIMBS * size_of::<Word>()` bytes.
pub fn uint_le_bytes<const LIMBS: usize>(n: &Uint<LIMBS>) -> Vec<u8> {
    n.as_words()
        .iter()
        .flat_map(|word| word.to_le_bytes())
        .collect()
}

/// Given `x, p, q`, return `x^-1 mod p-1` and `x^-1 mod q-1`
pub fn get_inv_mod_p_minus_1_and_q_minus_1<
    const PRIME_LIMBS: usize,
    const MODULUS_LIMBS: usize,
    const PRIME_UNSAT_LIMBS: usize,
>(
    x: &Uint<MODULUS_LIMBS>,
    p: &Uint<PRIME_LIMBS>,
    q: &Uint<PRIME_LIMBS>,
) -> Result<(Uint<PRIME_LIMBS>, Uint<PRIME_LIMBS>), Error>
where
    Odd<Uint<PRIME_LIMBS>>: PrecomputeInverter<
        Inverter = SafeGcdInverter<PRIME_LIMBS, PRIME_UNSAT_LIMBS>,
        Output = Uint<PRIME_LIMBS>,
    >,
{
    let (x_inv_p_minus_1, x_inv_q_minus_1) = join!(
        {
            let p_minus_1 = p.sub(&Uint::ONE);
            let x_mod_p_minus_1 = x
                .rem(&p_minus_1.resize().to_nz().unwrap())
                .resize::<PRIME_LIMBS>();
            x_mod_p_minus_1.inv_mod(&p_minus_1)
        },
        {
            let q_minus_1 = q.sub(&Uint::ONE);
            let x_mod_q_minus_1 = x
                .rem(&q_minus_1.resize().to_nz().unwrap())
                .resize::<PRIME_LIMBS>();
            x_mod_q_minus_1.inv_mod(&q_minus_1)
        }
    );
    if x_inv_p_minus_1.is_none().into() {
        // gcd(x, p-1) won't be 1
        return Err(Error::NotInvertible);
    }
    let x_inv_p_minus_1 = x_inv_p_minus_1.unwrap();
    if x_inv_q_minus_1.is_none().into() {
        // gcd(x, q-1) won't be 1
        return Err(Error::NotInvertible);
    }
    let x_inv_q_minus_1 = x_inv_q_minus_1.unwrap();
    Ok((x_inv_p_minus_1, x_inv_q_minus_1))
}

/// Return a number `x mod n` such that `gcd(x, a) = 1`, found by repeated random sampling.
///
/// Each candidate is drawn with `random_mod(rng, n)`, so it lies in `[0, n)`. Sampling continues
/// until a candidate's gcd with `a` equals one. If no such value exists in `[0, n)`, this never
/// returns.
pub fn get_coprime<R: CryptoRngCore, const LIMBS: usize, const UNSAT_LIMBS: usize>(
    rng: &mut R,
    a: &Uint<LIMBS>,
    n: &NonZero<Uint<LIMBS>>,
) -> Uint<LIMBS>
where
    Odd<Uint<LIMBS>>: PrecomputeInverter<Inverter = SafeGcdInverter<LIMBS, UNSAT_LIMBS>>,
{
    loop {
        let candidate = Uint::<LIMBS>::random_mod(rng, n);
        if candidate.gcd(a) == Uint::<LIMBS>::ONE {
            return candidate;
        }
    }
}

/// Assert that `val` fits in `max_limbs` limbs, i.e. that every limb at index `max_limbs` or
/// above (the remaining most significant limbs) is zero.
///
/// # Panics
/// Panics if any limb at index `>= max_limbs` is non-zero. When `max_limbs >= L` there is
/// nothing to check and the call always succeeds.
pub fn ensure_size_bound<const L: usize>(val: &Uint<L>, max_limbs: usize) {
    for l_i in val.as_limbs().iter().skip(max_limbs) {
        assert!(bool::from(l_i.is_zero()));
    }
}

/// Test fixture: two hard-coded values parsed from decimal into `U1024` and wrapped as `Odd`.
///
/// They are presumably primes; primality is not checked here. Panics if a literal fails to parse
/// or is even.
#[cfg(test)]
pub fn get_1024_bit_primes() -> (Odd<crypto_bigint::U1024>, Odd<crypto_bigint::U1024>) {
    use crypto_bigint::U1024;

    let odd_from_decimal =
        |digits: &str| U1024::from_str_radix_vartime(digits, 10).unwrap().to_odd().unwrap();
    let p = odd_from_decimal("148677972634832330983979593310074301486537017973460461278300587514468301043894574906886127642530475786889672304776052879927627556769456140664043088700743909632312483413393134504352834240399191134336344285483935856491230340093391784574980688823380828143810804684752914935441384845195613674104960646037368551517");
    let q = odd_from_decimal("158741574437007245654463598139927898730476924736461654463975966787719309357536545869203069369466212089132653564188443272208127277664424448947476335413293018778018615899291704693105620242763173357203898195318179150836424196645745308205164116144020613415407736216097185962171301808761138424668335445923774195463");
    (p, q)
}

/// Test fixture: two hard-coded values parsed from decimal into `U2048` and wrapped as `Odd`.
///
/// They are presumably primes; primality is not checked here. Panics if a literal fails to parse
/// or is even.
#[cfg(test)]
pub fn get_2048_bit_primes() -> (Odd<crypto_bigint::U2048>, Odd<crypto_bigint::U2048>) {
    use crypto_bigint::U2048;

    let odd_from_decimal =
        |digits: &str| U2048::from_str_radix_vartime(digits, 10).unwrap().to_odd().unwrap();
    let p = odd_from_decimal("29714581929123975538113401757096867247503888049897126155282036684655427098443105525014011037627595171636270743123002658539126362781975975175765337944068032414914877908601576682891727277414354084913151212699556099504403364557952921342801004492280996715668400103640816843970991636313372745470315455035628601408170417079028041375322988613489555126184463766534396540607235696364068780046050136089443239241198755363075399416619880240793665666130686930042641834471008631848126179567943667666801104241898884410812817279169595932728564045398540809381698710218625876508851295613368971979430951746728583910413116439939078434559");
    let q = odd_from_decimal("26092039125439665744416238260398697435648406017098864449978544271624805738059383134259926966553183020513772201496445138041100372380433951022528474017361803675300903527015075913474169512090459118347512405005520042799270078794768712536842118407057375282490800716584000679340618387331368881454328913585366779623070416709172563900009042884661367457056955039492864910532308631507979947887262253149026114965531208152102534718129699718880396068707567121640888946634505008511577162806565588378302758525657914103598229285420198323100812024493357088882840233483389168400067067993178810813818498522088103582183754940421621446417");
    (p, q)
}

/// Test fixture: two hard-coded values parsed from decimal into `U2048` and wrapped as `Odd`.
///
/// Judging by the name these are safe primes; that is not verified here. Panics if a literal
/// fails to parse or is even.
#[cfg(test)]
pub fn get_2048_bit_safe_primes() -> (Odd<crypto_bigint::U2048>, Odd<crypto_bigint::U2048>) {
    use crypto_bigint::U2048;

    let odd_from_decimal =
        |digits: &str| U2048::from_str_radix_vartime(digits, 10).unwrap().to_odd().unwrap();
    let p = odd_from_decimal("23093532801597257451454381557585179208352219469996223544542251941125929985815328540294320327083403876625764190953125279671562776569829469977754837977660907494272720872744118688954552777320725349302820917206471328709268440057497173556110093507109159650501332138554465002621986967089677757947294649095623792743511572429592405392345379377007414393355408430787864397488365978743025506620607605394063743439543716230291539484727265794652768105972048568268512105896683396253132610854793405134863041270384588126711178995834562956288110234716675748703948224739725827512799619678465195116501600185960806050690780624564098862167");
    let q = odd_from_decimal("27441210845657273617368603460727888336379992179179191092446524089518616591599800118528745709154657469277593363732484491471733193023229162827563838534799137915324202178048469744380119540014376456920236011406100750478948335864862173652924000872084798185461480531087839288787764186200527107298944748739807583053607364412650274679881635899363634979798017717302760644341109448171462179576142547939267814110605055016734623893363184973093060865599225649699956472550357183473247689111507499502010925534633450875096806669232120821578782795252255617752723970426931062930170368795686450195928841797174717467477081623532058484059");
    (p, q)
}

/// Summarise the timings of an operation repeated several times.
///
/// Returns `"total | [min, median, max]"`, each formatted with `{:.2?}`. For an even number of
/// samples the median is the mean of the two middle values. An empty `times` yields
/// `"0.00ns | []"` instead of panicking on out-of-bounds indexing.
#[cfg(test)]
pub fn timing_info(mut times: Vec<std::time::Duration>) -> String {
    if times.is_empty() {
        return format!("{:.2?} | []", std::time::Duration::ZERO);
    }
    // Equal `Duration`s are indistinguishable, so an unstable sort gives the same order.
    times.sort_unstable();
    let mid = times.len() / 2;
    let median = if times.len() % 2 == 0 {
        (times[mid - 1] + times[mid]) / 2
    } else {
        times[mid]
    };
    let total = times.iter().sum::<std::time::Duration>();
    format!(
        "{:.2?} | [{:.2?}, {:.2?}, {:.2?}]",
        total,
        times[0],
        median,
        times[times.len() - 1]
    )
}

/// Number of bits needed to write `N` in binary (its bit length), computed at compile time.
///
/// This equals `floor(log2(N)) + 1` for `N >= 1` and `0` for `N == 0`. Note that it is one more
/// than `floor(log2(N))`: e.g. `log_base_2::<4>() == 3` and `log_base_2::<8>() == 4`.
/// NOTE(review): the name suggests a plain `log2`; confirm callers expect the bit length.
pub const fn log_base_2<const N: usize>() -> usize {
    (usize::BITS - N.leading_zeros()) as usize
}