//! round5 0.1.2
//!
//! Implementation of the Round5 post-quantum PKE and KEM algorithms.
use crate::types::Random;
use crate::parameters::Parameters;
use crate::xef::{xef_compute, xef_fixerr};
use crate::r5_core::{create_a, create_s_t, create_r_t};
use crate::pack::{pack, pack_pk, unpack_pk, pack_ct, unpack_ct};
use crate::r5_core::matrix::{transpose_matrix, mult_matrix, round_matrix, decompress_matrix};


/// Generates a Round5 CPA-PKE key pair.
///
/// Draws a public seed `sigma` and a secret seed from `drbg`, then:
/// 1. Expands `sigma` into the public matrix `A` and the secret seed into `S^T`.
/// 2. Computes `B = A * S` (mod `q`).
/// 3. Rounds `B` from `q_bits` down to `p_bits`.
/// 4. Packs `sigma` and `B` into `pk`.
///
/// * `pk` – receives the packed public key built by `pack_pk` from `sigma` and `B`.
/// * `sk` – its first `kappa_bytes` bytes are overwritten with the secret seed.
/// * `drbg` – randomness source for both seeds.
/// * `params` – the Round5 parameter set.
///
/// # Panics
/// If `sk` is shorter than `params.kappa_bytes`.
pub fn r5_cpa_pke_keygen(pk: &mut [u8], sk: &mut [u8], drbg: &mut dyn Random, params: &Parameters) {
    let kappa_bytes = params.kappa_bytes as usize;
    let k = params.k as usize;
    let n = params.n as usize;
    let n_bar = params.n_bar as usize;
    let len_b = k * n_bar * n;

    // Public seed sigma, expanded into the public matrix A.
    let mut sigma = vec![0u8; kappa_bytes];
    drbg.fill_bytes(&mut sigma);
    let a = create_a(&sigma, params);

    // Secret seed. S^T must be derived from exactly the kappa_bytes-long seed
    // that r5_cpa_pke_decrypt reads back from `sk`. A fixed 16-byte prefix
    // would give a different S whenever kappa_bytes != 16.
    drbg.fill_bytes(&mut sk[..kappa_bytes]);
    let s_t = create_s_t(&sk[..kappa_bytes], params);

    // B = A * S (mod q), then compress q_bits -> p_bits.
    let s = transpose_matrix(&s_t, n_bar, k, n);
    let mut b = mult_matrix(&a, k, k, &s, k, n_bar, n, params.q, false);
    round_matrix(&mut b, k * n_bar, n, params.q_bits as u16, params.p_bits as u16, params.h1);

    pack_pk(pk, &sigma, &b, len_b, params.p_bits);
}

/// Encrypts the first `kappa_bytes` bytes of `m` under the packed public key `pk`.
///
/// `rho` is the seed from which the ephemeral matrix `R^T` is generated via
/// `create_r_t`. The ciphertext written to `ct` consists of:
/// * `U^T`, where `U = A^T * R` (mod `q`) rounded down to `p_bits`;
/// * `v`, the `mu` entries of `B^T * R` (mod `p`) rounded down to `t_bits`,
///   with the (optionally XEf-extended) message added in.
#[allow(clippy::many_single_char_names)]
pub fn r5_cpa_pke_encrypt(ct: &mut [u8], pk: &[u8], m: &[u8], rho: &[u8], params: &Parameters) {
    let k = params.k as usize;
    let n = params.n as usize;
    let n_bar = params.n_bar as usize;
    let m_bar = params.m_bar as usize;
    let kappa_bytes = params.kappa_bytes as usize;
    let b_len = k * n_bar * n;

    // Split the public key into the seed sigma and the matrix B.
    let mut seed = vec![0u8; kappa_bytes];
    let mut b_mat = vec![0u16; b_len];
    unpack_pk(&mut seed, &mut b_mat, pk, kappa_bytes, b_len, params.p_bits);

    let a = create_a(&seed, params);
    let r_t = create_r_t(rho, params);
    let a_t = transpose_matrix(&a, k, k, n);
    let r = transpose_matrix(&r_t, m_bar, k, n);

    // U = A^T * R (mod q), compressed q_bits -> p_bits, then transposed for packing.
    let mut u = mult_matrix(&a_t, k, k, &r, k, m_bar, n, params.q, false);
    round_matrix(&mut u, k * m_bar, n, params.q_bits as u16, params.p_bits as u16, params.h2);
    let u_t = transpose_matrix(&u, k, m_bar, n);
    let b_t = transpose_matrix(&b_mat, k, n_bar, n);

    // X = B^T * R (mod p). v is a vector of `mu` scalars, hence one
    // coefficient per element when compressing p_bits -> t_bits.
    let use_xi = params.xe != 0 || params.f != 0;
    let mut x = mult_matrix(&b_t, n_bar, k, &r, k, m_bar, n, params.p as u32, use_xi);
    round_matrix(&mut x, params.mu as usize, 1, params.p_bits as u16, params.t_bits as u16, params.h2);

    // Codeword: the message, followed by the redundancy written by
    // xef_compute when XEf is enabled (xe != 0).
    let codeword_len = bits_to_bytes!(params.mu * params.b_bits as u16) as usize;
    let mut codeword = vec![0u8; codeword_len];
    codeword[..kappa_bytes].copy_from_slice(&m[..kappa_bytes]);
    if params.xe != 0 {
        xef_compute(&mut codeword, kappa_bytes, params.f);
    }

    let v = add_msg(params.mu as usize, &x, &codeword, params.b_bits as u16, params.t_bits);
    pack_ct(ct, &u_t, params.p_bits, &v, params.t_bits);
}

/// Decrypts the ciphertext `ct` with the secret key `sk`, writing the result into `m`.
///
/// Steps:
/// 1. Re-derive `S^T` from the first `kappa_bytes` bytes of `sk`.
/// 2. Compute `X' = S^T * U` (mod `p`).
/// 3. Subtract `X'` from the decompressed `v` and round each difference down
///    to `b_bits`.
/// 4. Pack the rounded symbols into a byte string. When XEf is enabled
///    (`xe != 0`), run `xef_compute` and then `xef_fixerr` on that string.
///
/// The first `m.len()` bytes of the result are copied into `m`.
///
/// # Panics
/// If `sk` is shorter than `kappa_bytes`, or `m` is longer than the packed
/// codeword buffer.
pub fn r5_cpa_pke_decrypt(m: &mut [u8], sk: &[u8], ct: &[u8], params: &Parameters) {
    let k = params.k as usize;
    let n = params.n as usize;
    let n_bar = params.n_bar as usize;
    let m_bar = params.m_bar as usize;
    let mu = params.mu as usize;
    let kappa_bytes = params.kappa_bytes as usize;

    let s_t = create_s_t(&sk[..kappa_bytes], params);

    // Split the ciphertext into U^T and v.
    let mut u_t = vec![0u16; k * m_bar * n];
    let mut v = vec![0u16; mu];
    unpack_ct(&mut u_t, &mut v, ct, params.p_bits as usize, params.t_bits as usize);

    // X' = S^T * U (mod p)
    let u = transpose_matrix(&u_t, m_bar, k, n);
    let use_xi = params.xe != 0 || params.f != 0;
    let x_prime = mult_matrix(&s_t, n_bar, k, &u, k, m_bar, n, params.p as u32, use_xi);

    // Bring v back from t_bits to p_bits so it can be combined with X'.
    decompress_matrix(&mut v, mu, 1, params.t_bits as u16, params.p_bits as u16);

    // (v - X') & (p - 1), then rounded p_bits -> b_bits per symbol.
    let mut symbols = diff_msg(mu, &v, &x_prime, params.p);
    round_matrix(&mut symbols, mu, 1, params.p_bits as u16, params.b_bits as u16, params.h3);

    let mut codeword = vec![0u8; bits_to_bytes!(params.mu * params.b_bits as u16) as usize];
    pack(&mut codeword, &symbols, mu, params.b_bits);

    if params.xe != 0 {
        xef_compute(&mut codeword, kappa_bytes, params.f);
        xef_fixerr(&mut codeword, kappa_bytes, params.f);
    }

    let out_len = m.len();
    m.copy_from_slice(&codeword[..out_len]);
}

/// Adds a packed message to the first `len` entries of `matrix`.
///
/// `m` is read as a bit stream of `bits_coeff`-bit symbols, starting at the
/// least-significant bit of `m[0]`. For each `i`:
/// * symbol `i` is shifted left by `scaling_factor - bits_coeff` bits;
/// * it is added to `matrix[i]`;
/// * the sum is reduced modulo `2^scaling_factor`.
///
/// Preconditions:
/// * `bits_coeff <= 8`, so a symbol spans at most two bytes;
/// * `bits_coeff <= scaling_factor`;
/// * `scaling_factor - bits_coeff < 16`.
///
/// # Panics
/// If `matrix` has fewer than `len` elements, or `m` does not contain the
/// `len * bits_coeff` bits that are read.
fn add_msg(len: usize, matrix: &[u16], m: &[u8], bits_coeff: u16, scaling_factor: u8) -> Vec<u16> {
    let coeff_bits = usize::from(bits_coeff);
    let scale_shift = u32::from(scaling_factor) - u32::from(bits_coeff);
    // Masks are built in u32 so a width of 16 does not overflow the shift.
    let symbol_mask = ((1u32 << bits_coeff) - 1) as u16;
    let modulus_mask = ((1u32 << scaling_factor) - 1) as u16;

    matrix[..len]
        .iter()
        .enumerate()
        .map(|(i, &coeff)| {
            let bits_done = i * coeff_bits;
            let idx = bits_done >> 3;
            let bit_idx = bits_done & 7;
            let mut val = u16::from(m[idx]) >> bit_idx;
            if bit_idx + coeff_bits > 8 {
                // Get spill-over bits from the next message byte.
                val |= u16::from(m[idx + 1]) << (8 - bit_idx);
            }
            // Keep only this symbol's bits and wrap the addition. Otherwise
            // stray higher message bits can push the u16 sum past 65535,
            // which panics in debug builds. Neither step changes the result
            // modulo 2^scaling_factor.
            coeff.wrapping_add((val & symbol_mask) << scale_shift) & modulus_mask
        })
        .collect()
}

/// Returns `(matrix_a[i] - matrix_b[i]) & (module - 1)` for the first `len` entries.
///
/// The subtraction wraps in `u16`. For a power-of-two `module`, the result is
/// therefore the difference modulo `module`. A `module` of 0 is treated as
/// 2^16 (mask `0xFFFF`).
///
/// # Panics
/// If either slice holds fewer than `len` elements.
fn diff_msg(len: usize, matrix_a: &[u16], matrix_b: &[u16], module: u16) -> Vec<u16> {
    // Wrapping arithmetic: a signed `as i16` subtraction would overflow (and
    // panic in debug builds) for coefficients >= 2^15, while producing the
    // very same bit pattern otherwise.
    let mask = module.wrapping_sub(1);
    matrix_a[..len]
        .iter()
        .zip(&matrix_b[..len])
        .map(|(&a, &b)| a.wrapping_sub(b) & mask)
        .collect()
}