round5 0.1.2

Implementation of the Round5 post-quantum public-key encryption (PKE) and key-encapsulation mechanism (KEM) algorithms.
Documentation
/// Reduces `a` modulo `b` by masking with `b - 1`, returning the low 16 bits.
///
/// Masking equals a true (non-negative) modulo only when `b` is a power of two.
/// Callers in this file presumably always pass power-of-two moduli (confirm).
/// Negative inputs are handled through their two's-complement bit pattern.
fn modulo(a: i32, b: u32) -> u16 {
    let mask = b - 1;
    let reduced = (a as u32) & mask;
    reduced as u16
}

/// Multiplies the `len`-coefficient polynomial `cyc_pol` by `(x - 1)`.
/// The `len + 1` coefficients are written into `ntru_pol`, each reduced with `modulo`.
///
/// For `a(x) = sum a_i x^i`, the product `(x - 1) a(x)` has these coefficients:
/// `-a_0` at degree 0, `a_{i-1} - a_i` at degree `i` for `1 <= i < len`,
/// and `a_{len-1}` at degree `len`.
///
/// `ntru_pol` must hold at least `len + 1` coefficients and `cyc_pol` at least `len`.
fn lift_poly(ntru_pol: &mut [u16], cyc_pol: &[u16], len: usize, module: u32) {
    // Constant term.
    ntru_pol[0] = modulo(-i32::from(cyc_pol[0]), module);

    // Middle terms: difference of each adjacent coefficient pair.
    let differences = cyc_pol[..len]
        .windows(2)
        .map(|pair| i32::from(pair[0]) - i32::from(pair[1]));
    for (dst, diff) in ntru_pol[1..len].iter_mut().zip(differences) {
        *dst = modulo(diff, module);
    }

    // Leading term.
    ntru_pol[len] = modulo(i32::from(cyc_pol[len - 1]), module);
}

/// Inverts `lift_poly`: recovers `len` coefficients `a_i` from the product `(x - 1) a(x)` in `ntru_pol`.
///
/// Since the lifted coefficients are `p_0 = -a_0` and `p_i = a_{i-1} - a_i`,
/// the original ones follow by the recurrence `a_0 = -p_0`, `a_i = a_{i-1} - p_i`.
/// Each step is reduced with `modulo`. The top coefficient `ntru_pol[len]` is not read.
fn unlift_poly(cyc_pol: &mut [u16], ntru_pol: &[u16], len: usize, module: u32) {
    // `previous` always holds the most recently recovered coefficient a_{i-1}.
    let mut previous = modulo(-i32::from(ntru_pol[0]), module);
    cyc_pol[0] = previous;

    for (i, &lifted) in ntru_pol[..len].iter().enumerate().skip(1) {
        previous = modulo(i32::from(previous) - i32::from(lifted), module);
        cyc_pol[i] = previous;
    }
}

/// Cyclic (negacyclic-free) convolution of `pol_a` and `pol_b` into `result`.
///
/// The product term `a_i * b_j` is accumulated at index `(i + j) % result.len()`.
/// This is multiplication modulo `x^len - 1`, where `len` is `result.len()`.
/// The accumulator is reduced with `modulo` after every addition. Reducing per step
/// (rather than once at the end) keeps the result identical even if `module` is not a power of two.
fn mult_poly_ntru(result: &mut [u16], pol_a: &[u16], pol_b: &[u16], module: u32) {
    result.fill(0);
    let ring_len = result.len();

    for (i, &coef_a) in pol_a.iter().enumerate() {
        for (j, &coef_b) in pol_b.iter().enumerate() {
            let slot = &mut result[(i + j) % ring_len];
            // Both factors are u16, so the product plus a u16 accumulator fits in u32.
            let acc = u32::from(*slot) + u32::from(coef_a) * u32::from(coef_b);
            *slot = modulo(acc as i32, module);
        }
    }
}

/// Multiplies two `n`-coefficient polynomials (`n == result.len()`) via a cyclic product of length `n + 1`.
///
/// - `is_xi == false`: `pol_a` is lifted by `(x - 1)` with `lift_poly`. The product is
///   mapped back to `n` coefficients with `unlift_poly`.
/// - `is_xi == true`: `pol_a` is used as-is (with a zero top coefficient). The result
///   is taken from coefficients `1..=n` of the product. This is presumably the Round5 ξ-variant
///   (multiplication modulo `x^(n+1) - 1`); confirm against the specification.
///
/// In both cases `pol_b` is zero-padded to `n + 1` coefficients.
fn mult_poly(result: &mut [u16], pol_a: &[u16], pol_b: &[u16], module: u32, is_xi: bool) {
    let n = result.len();

    // Left operand, widened to n + 1 coefficients (the buffer starts zeroed).
    let mut lifted_a = vec![0u16; n + 1];
    if is_xi {
        lifted_a[..n].copy_from_slice(&pol_a[..n]);
    } else {
        lift_poly(&mut lifted_a, pol_a, n, module);
    }

    // Right operand, zero-padded to n + 1 coefficients.
    let mut padded_b = vec![0u16; n + 1];
    padded_b[..n].copy_from_slice(&pol_b[..n]);

    let mut product = vec![0u16; n + 1];
    mult_poly_ntru(&mut product, &lifted_a, &padded_b, module);

    if is_xi {
        result.copy_from_slice(&product[1..]);
    } else {
        unlift_poly(result, &product, n, module);
    }
}

/// Adds `pol_b` into `pol_a` coefficient by coefficient, reducing each sum with `modulo`.
///
/// The sum is formed in `i32`, as the other helpers in this file do before calling
/// `modulo`. This means two `u16` coefficients whose sum exceeds `u16::MAX` cannot trigger
/// an arithmetic-overflow panic in debug builds (e.g. with a 2^16 modulus).
/// Release-build results are unchanged: the overflowed bit lies above any `u16` mask.
///
/// # Panics
/// Panics if `pol_b` is shorter than `pol_a`.
fn add_poly_in_place(pol_a: &mut [u16], pol_b: &[u16], module: u32) {
    let pol_b = &pol_b[..pol_a.len()];
    for (a, &b) in pol_a.iter_mut().zip(pol_b) {
        *a = modulo(i32::from(*a) + i32::from(b), module);
    }
}

/// Multiplies two matrices whose entries are polynomials of `els` coefficients.
///
/// Both inputs are flat, row-major arrays: entry `(r, c)` of a matrix with `cols` columns
/// occupies `[(r * cols + c) * els, (r * cols + c + 1) * els)`. Each output entry is the sum
/// over `k` of `mult_poly(left[i][k], right[k][j])`, accumulated with `add_poly_in_place`.
/// `is_xi` is forwarded to `mult_poly`.
///
/// Returns an `l_rows x r_cols` matrix in the same layout.
///
/// # Panics
/// Panics if `l_cols != r_rows`, or if either input is too short for its stated dimensions.
pub fn mult_matrix(
    left: &[u16],
    l_rows: usize,
    l_cols: usize,
    right: &[u16],
    r_rows: usize,
    r_cols: usize,
    els: usize,
    module: u32,
    is_xi: bool,
) -> Vec<u16> {
    assert!(l_cols == r_rows);
    let left_row_stride = l_cols * els;
    let right_row_stride = r_cols * els;
    let mut product = vec![0u16; l_rows * r_cols * els];
    // Scratch buffer reused for every entry-by-entry polynomial product.
    let mut scratch = vec![0u16; els];

    for row in 0..l_rows {
        for col in 0..r_cols {
            let out_start = row * right_row_stride + col * els;
            for inner in 0..l_cols {
                let a_start = row * left_row_stride + inner * els;
                let b_start = inner * right_row_stride + col * els;
                mult_poly(
                    &mut scratch,
                    &left[a_start..a_start + els],
                    &right[b_start..b_start + els],
                    module,
                    is_xi,
                );
                add_poly_in_place(&mut product[out_start..out_start + els], &scratch, module);
            }
        }
    }
    product
}

/// Transposes a `rows x cols` matrix whose entries are polynomials of `els` coefficients.
///
/// Input and output are flat, row-major arrays in which each entry is a contiguous run of
/// `els` values. The result is a `cols x rows` matrix in the same layout. Coefficients
/// inside an entry keep their order.
pub fn transpose_matrix(matrix: &[u16], rows: usize, cols: usize, els: usize) -> Vec<u16> {
    let mut transposed = Vec::with_capacity(rows * cols * els);
    // Emit output rows (former columns) in order, copying each entry wholesale.
    for col in 0..cols {
        for row in 0..rows {
            let start = (row * cols + col) * els;
            transposed.extend_from_slice(&matrix[start..start + els]);
        }
    }
    transposed
}

/// Rounds `x` from `a_bits` to `b_bits` bits.
///
/// Steps: add `rounding_constant`, shift right by `a_bits - b_bits`, then mask with `b_mask`.
/// `rounding_constant` is presumably `2^(a_bits - b_bits - 1)` for round-to-nearest; confirm at call sites.
///
/// The addition is done in `u32`. With 16-bit inputs, `x + rounding_constant` can exceed
/// `u16::MAX`, which panicked in debug builds when computed in `u16`. For `a_bits <= 16`
/// and a `b_bits`-wide mask, the overflowed bit lands outside the mask, so release-build results are unchanged.
///
/// # Panics
/// Panics (debug builds) if `b_bits > a_bits`.
fn round_element(x: u16, a_bits: u16, b_bits: u16, b_mask: u16, rounding_constant: u16) -> u16 {
    let sum = u32::from(x) + u32::from(rounding_constant);
    ((sum >> (a_bits - b_bits)) & u32::from(b_mask)) as u16
}

/// Rounds the first `len * els` coefficients of `matrix` in place from `a` bits to `b` bits, using `round_element`.
///
/// The `b`-bit mask is built in `u32` so that `b == 16` yields `0xFFFF`.
/// `1u16 << 16` would panic in debug builds and produce a zero mask (all outputs 0) in release builds.
///
/// # Panics
/// Panics if `matrix` holds fewer than `len * els` coefficients. `b` is expected to be at most 16.
pub fn round_matrix(matrix: &mut [u16], len: usize, els: usize, a: u16, b: u16, rounding_constant: u16) {
    let b_mask = ((1u32 << b) - 1) as u16;
    for coeff in &mut matrix[..len * els] {
        *coeff = round_element(*coeff, a, b, b_mask, rounding_constant);
    }
}

/// Scales `x` from `a_bits` up to `b_bits` bits.
///
/// It shifts left by `b_bits - a_bits` and keeps only the bits in `b_mask`.
/// Bits shifted past bit 15 are discarded.
fn decompress_element(x: u16, a_bits: u16, b_bits: u16, b_mask: u16) -> u16 {
    let shift = b_bits - a_bits;
    let widened = x << shift;
    b_mask & widened
}

/// Decompresses the first `len * els` coefficients of `matrix` in place from `a_bits` to `b_bits` bits, using `decompress_element`.
///
/// The `b_bits`-wide mask is built in `u32` so that `b_bits == 16` yields `0xFFFF`.
/// `1u16 << 16` would panic in debug builds and produce a zero mask (all outputs 0) in release builds.
///
/// # Panics
/// Panics if `matrix` holds fewer than `len * els` coefficients. `b_bits` is expected to be at most 16.
pub fn decompress_matrix(matrix: &mut [u16], len: usize, els: usize, a_bits: u16, b_bits: u16) {
    let b_mask = ((1u32 << b_bits) - 1) as u16;
    for coeff in &mut matrix[..len * els] {
        *coeff = decompress_element(*coeff, a_bits, b_bits, b_mask);
    }
}