/// Reduces `a` into `0..b` by masking its bit pattern with `b - 1`.
///
/// The mask trick is only a true modulo when `b` is a power of two. Negative
/// `a` is handled through its two's-complement bit pattern, so
/// `modulo(-1, 8) == 7`. The result is truncated to `u16`, so `b` is expected
/// to be at most `1 << 16`.
///
/// # Panics
/// In debug builds, panics if `b == 0` (`b - 1` underflows).
fn modulo(a: i32, b: u32) -> u16 {
    let mask = b - 1;
    let reduced = (a as u32) & mask;
    reduced as u16
}
/// Multiplies the `len`-coefficient polynomial `cyc_pol` by `(x - 1)`, writing
/// the `len + 1` coefficients of the product into `ntru_pol`.
///
/// Coefficient `i` of the product is `cyc_pol[i - 1] - cyc_pol[i]`, where
/// out-of-range terms are zero. Each coefficient is reduced with `modulo`.
///
/// # Panics
/// Panics if `cyc_pol` has fewer than `len` coefficients, if `ntru_pol` has
/// fewer than `len + 1` slots, or if `len == 0`.
fn lift_poly(ntru_pol: &mut [u16], cyc_pol: &[u16], len: usize, module: u32) {
    // Constant term: 0 - c_0.
    ntru_pol[0] = modulo(-i32::from(cyc_pol[0]), module);
    // Middle terms: c_{k} - c_{k+1} lands at degree k + 1.
    for (k, pair) in cyc_pol.windows(2).take(len.saturating_sub(1)).enumerate() {
        ntru_pol[k + 1] = modulo(i32::from(pair[0]) - i32::from(pair[1]), module);
    }
    // Leading term: c_{len-1} - 0.
    ntru_pol[len] = modulo(i32::from(cyc_pol[len - 1]), module);
}
/// Divides `ntru_pol` by `(x - 1)`, writing `len` coefficients into `cyc_pol`.
///
/// This reverses `lift_poly`. The recurrence `c_0 = -n_0`,
/// `c_i = c_{i-1} - n_i` recovers the cofactor from the low `len` coefficients
/// of `ntru_pol`; `ntru_pol[len]` is never read. Each coefficient is reduced
/// with `modulo`.
///
/// # Panics
/// Panics if `cyc_pol` or `ntru_pol` is empty, or either is shorter than `len`.
fn unlift_poly(cyc_pol: &mut [u16], ntru_pol: &[u16], len: usize, module: u32) {
    let mut running = modulo(-i32::from(ntru_pol[0]), module);
    cyc_pol[0] = running;
    for idx in 1..len {
        // `running` always holds cyc_pol[idx - 1].
        running = modulo(i32::from(running) - i32::from(ntru_pol[idx]), module);
        cyc_pol[idx] = running;
    }
}
/// Cyclic (negacyclic-free) convolution of `pol_a` and `pol_b` into `result`.
///
/// Computes `pol_a * pol_b` modulo `x^n - 1`, where `n = result.len()`, by
/// schoolbook multiplication. Each accumulation is reduced with `modulo`.
/// `result` is zeroed first. A u32 accumulator is wide enough because
/// `(2^16 - 1)^2 + (2^16 - 1) < 2^32`.
///
/// # Panics
/// Panics if `result` is empty while both inputs are non-empty (`% 0`).
fn mult_poly_ntru(result: &mut [u16], pol_a: &[u16], pol_b: &[u16], module: u32) {
    result.fill(0);
    let n = result.len();
    for i in 0..pol_a.len() {
        let coef_a = u32::from(pol_a[i]);
        for j in 0..pol_b.len() {
            let slot = (i + j) % n;
            let acc = u32::from(result[slot]) + coef_a * u32::from(pol_b[j]);
            result[slot] = modulo(acc as i32, module);
        }
    }
}
/// Multiplies two polynomials of `result.len()` coefficients into `result`.
///
/// Let `n = result.len()`. Both operands are placed in `n + 1`-coefficient
/// buffers and multiplied cyclically by `mult_poly_ntru`, i.e. modulo
/// `x^(n+1) - 1`, with coefficients reduced by `modulo`.
///
/// * `is_xi == false`: `pol_a` is first multiplied by `(x - 1)` with
///   `lift_poly`, and the cyclic product is divided back by `(x - 1)` with
///   `unlift_poly`. Since `x^(n+1) - 1 = (x - 1)(1 + x + ... + x^n)`, this yields
///   `pol_a * pol_b` modulo `1 + x + ... + x^n`.
/// * `is_xi == true`: `pol_a` is used unchanged, and `result` receives
///   coefficients `1..=n` of the cyclic product.
///   NOTE(review): presumably the scheme's "xi" ring variant. Confirm the
///   one-coefficient offset against the specification.
///
/// # Panics
/// Panics if `pol_a` or `pol_b` has fewer than `n` coefficients. In the
/// non-xi mode it also panics when `n == 0`.
fn mult_poly(result: &mut [u16], pol_a: &[u16], pol_b: &[u16], module: u32, is_xi: bool) {
    let n = result.len();
    let extended = n + 1;
    let mut lhs = vec![0u16; extended];
    let mut rhs = vec![0u16; extended];
    let mut product = vec![0u16; extended];

    // The top coefficient of `lhs`/`rhs` stays at its zero-initialised value
    // unless lift_poly fills it.
    match is_xi {
        true => lhs[..n].copy_from_slice(&pol_a[..n]),
        false => lift_poly(&mut lhs, pol_a, n, module),
    }
    rhs[..n].copy_from_slice(&pol_b[..n]);

    mult_poly_ntru(&mut product, &lhs, &rhs, module);

    match is_xi {
        true => result.copy_from_slice(&product[1..]),
        false => unlift_poly(result, &product, n, module),
    }
}
/// Adds `pol_b` to `pol_a` coefficient-wise, reducing each sum with `modulo`.
///
/// Only the first `pol_a.len()` coefficients of `pol_b` are used.
///
/// The addition wraps modulo 2^16 before reduction. The original
/// unchecked `+=` panicked in debug builds whenever a sum reached 2^16 (e.g.
/// `module == 1 << 16`, or unreduced inputs). Wrapping is exact here:
/// `modulo` masks with `module - 1`, and for a power-of-two `module` at most
/// 2^16 a carry out of bit 15 never affects the kept bits. Larger moduli are
/// truncated to `u16` by `modulo` anyway.
///
/// # Panics
/// Panics if `pol_b` is shorter than `pol_a`.
fn add_poly_in_place(pol_a: &mut [u16], pol_b: &[u16], module: u32) {
    let pol_b = &pol_b[..pol_a.len()];
    for (acc, &term) in pol_a.iter_mut().zip(pol_b) {
        *acc = modulo(i32::from(acc.wrapping_add(term)), module);
    }
}
/// Multiplies an `l_rows x l_cols` matrix by an `r_rows x r_cols` matrix of polynomials.
///
/// Both matrices are stored row-major, and each entry is a run of `els`
/// contiguous coefficients. Entry products are computed with `mult_poly`
/// (passing `module` and `is_xi` through) and accumulated with
/// `add_poly_in_place`. Returns the `l_rows x r_cols` product in the same
/// layout. `r_rows` is only used to check the inner dimensions.
///
/// # Panics
/// Panics if `l_cols != r_rows`, or if `left`/`right` are shorter than
/// their stated dimensions imply.
#[allow(clippy::too_many_arguments)] // public signature kept for existing callers
pub fn mult_matrix(
    left: &[u16],
    l_rows: usize,
    l_cols: usize,
    right: &[u16],
    r_rows: usize,
    r_cols: usize,
    els: usize,
    module: u32,
    is_xi: bool,
) -> Vec<u16> {
    assert!(l_cols == r_rows);
    let mut product = vec![0u16; l_rows * r_cols * els];
    // One scratch polynomial reused for every entry-by-entry product.
    let mut scratch = vec![0u16; els];
    for row in 0..l_rows {
        for col in 0..r_cols {
            let out = row * (r_cols * els) + col * els;
            for inner in 0..l_cols {
                let lhs = row * (l_cols * els) + inner * els;
                let rhs = inner * (r_cols * els) + col * els;
                mult_poly(
                    &mut scratch,
                    &left[lhs..lhs + els],
                    &right[rhs..rhs + els],
                    module,
                    is_xi,
                );
                add_poly_in_place(&mut product[out..out + els], &scratch, module);
            }
        }
    }
    product
}
/// Transposes a row-major `rows x cols` matrix whose entries are `els` contiguous values.
///
/// Each entry is moved as a unit, and the values inside an entry keep their
/// order. Returns a new `cols x rows` matrix in the same layout.
///
/// # Panics
/// Panics if `matrix` holds fewer than `rows * cols * els` values.
pub fn transpose_matrix(matrix: &[u16], rows: usize, cols: usize, els: usize) -> Vec<u16> {
    let mut transposed = vec![0u16; rows * cols * els];
    for r in 0..rows {
        for c in 0..cols {
            let src = r * (cols * els) + c * els;
            let dst = c * (rows * els) + r * els;
            transposed[dst..dst + els].copy_from_slice(&matrix[src..src + els]);
        }
    }
    transposed
}
/// Rounds an `a_bits`-bit value down to `b_bits` bits.
///
/// Adds `rounding_constant`, shifts right by `a_bits - b_bits`, and keeps the
/// bits selected by `b_mask` (expected to be `2^b_bits - 1`).
///
/// The addition wraps modulo 2^16. The previous unchecked `+` panicked in debug
/// builds for e.g. `a_bits == 16` or inputs not yet reduced below `2^a_bits`.
/// Wrapping is exact: the kept bits are `a_bits - b_bits .. a_bits`, all below
/// bit 16 when `a_bits <= 16`, so a carry out of the `u16` never reaches them.
///
/// # Panics
/// In debug builds, panics if `b_bits > a_bits` or `a_bits - b_bits >= 16`.
fn round_element(x: u16, a_bits: u16, b_bits: u16, b_mask: u16, rounding_constant: u16) -> u16 {
    (x.wrapping_add(rounding_constant) >> (a_bits - b_bits)) & b_mask
}
/// Rounds the first `len * els` values of `matrix` in place from `a` bits to `b` bits.
///
/// Each value goes through `round_element` with mask `2^b - 1` and the given
/// `rounding_constant`.
///
/// The mask is built in `u32`, so `b == 16` yields `0xFFFF`. Previously
/// `(1u16 << 16) - 1` panicked in debug builds. In release builds the shift
/// amount was masked to 0, which produced a zero mask and silently zeroed
/// every element.
///
/// # Panics
/// Panics if `matrix` holds fewer than `len * els` values. In debug builds,
/// panics if `b > 16`.
pub fn round_matrix(matrix: &mut [u16], len: usize, els: usize, a: u16, b: u16, rounding_constant: u16) {
    debug_assert!(b <= 16, "round_matrix: b must be at most 16 bits");
    let b_mask = ((1u32 << b) - 1) as u16;
    for x in &mut matrix[..len * els] {
        *x = round_element(*x, a, b, b_mask, rounding_constant);
    }
}
/// Expands an `a_bits`-bit value to `b_bits` bits.
///
/// Shifts `x` left by `b_bits - a_bits` and keeps the bits selected by `b_mask`.
///
/// # Panics
/// In debug builds, panics if `a_bits > b_bits` or `b_bits - a_bits >= 16`.
fn decompress_element(x: u16, a_bits: u16, b_bits: u16, b_mask: u16) -> u16 {
    let widened = x << (b_bits - a_bits);
    b_mask & widened
}
/// Expands the first `len * els` values of `matrix` in place from `a_bits` to `b_bits` bits.
///
/// Each value goes through `decompress_element` with mask `2^b_bits - 1`.
///
/// The mask is built in `u32`, so `b_bits == 16` yields `0xFFFF`. Previously
/// `(1u16 << 16) - 1` panicked in debug builds. In release builds the shift
/// amount was masked to 0, which produced a zero mask and silently zeroed
/// every element.
///
/// # Panics
/// Panics if `matrix` holds fewer than `len * els` values. In debug builds,
/// panics if `b_bits > 16`.
pub fn decompress_matrix(matrix: &mut [u16], len: usize, els: usize, a_bits: u16, b_bits: u16) {
    debug_assert!(b_bits <= 16, "decompress_matrix: b_bits must be at most 16");
    let b_mask = ((1u32 << b_bits) - 1) as u16;
    for x in &mut matrix[..len * els] {
        *x = decompress_element(*x, a_bits, b_bits, b_mask);
    }
}