use crate::{cbd::*, ntt::*, params::*, reduce::*, symmetric::*};
/// A polynomial stored as `KYBER_N` signed 16-bit coefficients.
///
/// The type itself enforces no range on the coefficients. Operations such as
/// `poly_add`/`poly_sub` do not reduce, so values may be negative or exceed
/// `KYBER_Q` between reductions.
#[derive(Clone, Copy)]
pub(crate) struct Poly {
    pub(crate) coeffs: [i16; KYBER_N],
}

impl Default for Poly {
    /// Returns the all-zero polynomial.
    fn default() -> Self {
        Self {
            coeffs: [0; KYBER_N],
        }
    }
}

impl Poly {
    /// Creates the all-zero polynomial (equivalent to [`Poly::default`]).
    pub(crate) fn new() -> Self {
        Poly::default()
    }
}
/// Compresses every coefficient of `a` to `d` bits and packs the result into `r`.
///
/// The width is fixed by `KYBER_POLY_COMPRESSED_BYTES`:
/// - 128 bytes selects `d = 4` (two coefficients per byte).
/// - 160 bytes selects `d = 5` (eight coefficients per five bytes).
///
/// Each coefficient becomes `floor((2^d * u + q/2) / q) mod 2^d`, i.e.
/// `round(2^d * u / q) mod 2^d`, where `u` is the coefficient lifted into
/// `[0, q)`. Only one multiple of `q` is added to negative inputs, so
/// coefficients are assumed to lie in `[-q, q)`. Only the first
/// `KYBER_POLY_COMPRESSED_BYTES` bytes of `r` are written.
///
/// # Panics
///
/// Panics if `r` is shorter than `KYBER_POLY_COMPRESSED_BYTES` bytes, or if
/// that constant is neither 128 nor 160.
pub(crate) fn poly_compress(r: &mut [u8], a: Poly) {
    /// Compresses one coefficient to `d` bits.
    ///
    /// Every intermediate is a `u32` for both widths, so no non-negative `i16`
    /// input can overflow them. A `u16` intermediate would overflow at `d = 4`
    /// for inputs of 3992 or more (16 * 3992 + 1664 = 65536).
    fn compress_coeff(c: i16, d: u32) -> u8 {
        // Branch-free lift: `c >> 15` is all ones exactly when `c` is negative.
        let u = c + ((c >> 15) & KYBER_Q as i16);
        let mut tmp = ((u as u32) << d) + KYBER_Q as u32 / 2;
        // Multiply-and-shift stands in for `/ q` (315 / 2^20 ≈ 1 / 3329).
        // It is exact for every `u` in `[0, q)` when `d` is 4 or 5, and avoids
        // a division instruction whose latency may depend on its operands.
        tmp *= 315;
        tmp >>= 20;
        (tmp & ((1u32 << d) - 1)) as u8
    }

    let mut t = [0u8; 8];
    match KYBER_POLY_COMPRESSED_BYTES {
        128 => {
            for (i, group) in a.coeffs.chunks_exact(8).enumerate() {
                for (tj, &c) in t.iter_mut().zip(group) {
                    *tj = compress_coeff(c, 4);
                }
                // Two 4-bit values per byte, low nibble first.
                let k = 4 * i;
                r[k] = t[0] | (t[1] << 4);
                r[k + 1] = t[2] | (t[3] << 4);
                r[k + 2] = t[4] | (t[5] << 4);
                r[k + 3] = t[6] | (t[7] << 4);
            }
        }
        160 => {
            for (i, group) in a.coeffs.chunks_exact(8).enumerate() {
                for (tj, &c) in t.iter_mut().zip(group) {
                    *tj = compress_coeff(c, 5);
                }
                // Eight 5-bit values (40 bits) packed least-significant bit
                // first into five bytes.
                let k = 5 * i;
                r[k] = t[0] | (t[1] << 5);
                r[k + 1] = (t[1] >> 3) | (t[2] << 2) | (t[3] << 7);
                r[k + 2] = (t[3] >> 1) | (t[4] << 4);
                r[k + 3] = (t[4] >> 4) | (t[5] << 1) | (t[6] << 6);
                r[k + 4] = (t[6] >> 2) | (t[7] << 3);
            }
        }
        _ => panic!(
            "KYBER_POLY_COMPRESSED_BYTES needs to be one of (128, 160)"
        ),
    }
}
/// Unpacks `d`-bit compressed values from `a` and decompresses them into `r`.
///
/// `d` is 4 when `KYBER_POLY_COMPRESSED_BYTES == 128` and 5 when it is 160.
/// Each value `y` becomes `floor((y * q + 2^(d-1)) / 2^d)`, i.e.
/// `round(q * y / 2^d)`. Every coefficient of `r` is overwritten.
///
/// # Panics
///
/// Panics if `a` is shorter than `KYBER_POLY_COMPRESSED_BYTES` bytes, or if
/// that constant is neither 128 nor 160.
pub(crate) fn poly_decompress(r: &mut Poly, a: &[u8]) {
    let q = KYBER_Q as u32;
    match KYBER_POLY_COMPRESSED_BYTES {
        128 => {
            // d = 4: each byte carries two values, low nibble first.
            for (pair, &byte) in
                r.coeffs.chunks_exact_mut(2).zip(&a[..KYBER_N / 2])
            {
                pair[0] = ((u32::from(byte & 15) * q + 8) >> 4) as i16;
                pair[1] = ((u32::from(byte >> 4) * q + 8) >> 4) as i16;
            }
        }
        160 => {
            // d = 5: every five bytes carry eight values, least-significant
            // bit first.
            let packed = &a[..5 * (KYBER_N / 8)];
            for (out, b) in
                r.coeffs.chunks_exact_mut(8).zip(packed.chunks_exact(5))
            {
                let t = [
                    b[0],
                    (b[0] >> 5) | (b[1] << 3),
                    b[1] >> 2,
                    (b[1] >> 7) | (b[2] << 1),
                    (b[2] >> 4) | (b[3] << 4),
                    b[3] >> 1,
                    (b[3] >> 6) | (b[4] << 2),
                    b[4] >> 3,
                ];
                for (c, &y) in out.iter_mut().zip(&t) {
                    // Only the low five bits of each unpacked byte belong to it.
                    *c = (((u32::from(y) & 31) * q + 16) >> 5) as i16;
                }
            }
        }
        _ => panic!(
            "KYBER_POLY_COMPRESSED_BYTES needs to be either (128, 160)"
        ),
    }
}
/// Serializes `a` into `r` as 12-bit little-endian values, two coefficients
/// per three bytes (`3 * KYBER_N / 2` bytes in total).
///
/// Negative coefficients get one `KYBER_Q` added (branch-free) before packing,
/// so inputs are assumed to lie in `[-q, q)`. Only the low 12 bits of each
/// lifted value are stored.
///
/// # Panics
///
/// Panics if `r` is shorter than `3 * KYBER_N / 2` bytes.
pub(crate) fn poly_tobytes(r: &mut [u8], a: Poly) {
    let q = KYBER_Q as i16;
    // Adds q exactly when `c` is negative: `c >> 15` is all ones for negatives.
    let lift = |c: i16| (c + ((c >> 15) & q)) as u16;
    for (i, pair) in a.coeffs.chunks_exact(2).enumerate() {
        let lo = lift(pair[0]);
        let hi = lift(pair[1]);
        r[3 * i] = lo as u8;
        r[3 * i + 1] = ((lo >> 8) | (hi << 4)) as u8;
        r[3 * i + 2] = (hi >> 4) as u8;
    }
}
pub(crate) fn poly_frombytes(r: &mut Poly, a: &[u8]) {
for i in 0..(KYBER_N / 2) {
r.coeffs[2 * i] = ((a[3 * i]) as u16
| ((a[3 * i + 1] as u16) << 8) & 0xFFF)
as i16;
r.coeffs[2 * i + 1] = ((a[3 * i + 1] >> 4) as u16
| ((a[3 * i + 2] as u16) << 4) & 0xFFF)
as i16;
}
}
/// Fills `r` with noise for parameter `KYBER_ETA1`.
///
/// Expands `seed` and `nonce` with `prf` into `KYBER_ETA1 * KYBER_N / 4` bytes
/// (`2 * KYBER_ETA1` bits per coefficient). Those bytes are then converted by
/// `poly_cbd_eta1`, the centered-binomial sampler from `crate::cbd`.
pub(crate) fn poly_getnoise_eta1(r: &mut Poly, seed: &[u8], nonce: u8) {
    const PRF_OUTPUT_LEN: usize = KYBER_ETA1 * KYBER_N / 4;
    let mut prf_output = [0u8; PRF_OUTPUT_LEN];
    prf(&mut prf_output, PRF_OUTPUT_LEN, seed, nonce);
    poly_cbd_eta1(r, &prf_output);
}
/// Fills `r` with noise for parameter `KYBER_ETA2`.
///
/// Expands `seed` and `nonce` with `prf` into `KYBER_ETA2 * KYBER_N / 4` bytes
/// (`2 * KYBER_ETA2` bits per coefficient). Those bytes are then converted by
/// `poly_cbd_eta2`, the centered-binomial sampler from `crate::cbd`.
pub(crate) fn poly_getnoise_eta2(r: &mut Poly, seed: &[u8], nonce: u8) {
    const PRF_OUTPUT_LEN: usize = KYBER_ETA2 * KYBER_N / 4;
    let mut prf_output = [0u8; PRF_OUTPUT_LEN];
    prf(&mut prf_output, PRF_OUTPUT_LEN, seed, nonce);
    poly_cbd_eta2(r, &prf_output);
}
/// Runs `ntt` (from `crate::ntt`, presumably the forward transform) on the
/// coefficients of `r` in place, then reduces every coefficient with
/// `poly_reduce`.
pub(crate) fn poly_ntt(r: &mut Poly) {
    let coeffs = &mut r.coeffs;
    ntt(coeffs);
    poly_reduce(r);
}
/// Runs `invntt` (from `crate::ntt`) on the coefficients of `r` in place.
///
/// NOTE(review): the `_tomont` suffix suggests `invntt` also leaves the result
/// scaled by the Montgomery factor, as in the Kyber reference code. Confirm
/// against `invntt` itself.
pub(crate) fn poly_invntt_tomont(r: &mut Poly) {
    let coeffs = &mut r.coeffs;
    invntt(coeffs);
}
/// Multiplies `a` and `b` coefficient-group-wise into `r`.
///
/// Coefficients are processed in groups of four. `basemul` combines the first
/// pair of each group with `ZETAS[64 + i]` and the second pair with its
/// negation.
///
/// NOTE(review): this mirrors the Kyber reference, where inputs are in the NTT
/// domain and each pair is a degree-1 product modulo `X^2 - zeta`. Confirm
/// against `basemul`.
pub(crate) fn poly_basemul(r: &mut Poly, a: &Poly, b: &Poly) {
    let zetas = &ZETAS[64..64 + KYBER_N / 4];
    for (i, &zeta) in zetas.iter().enumerate() {
        let lo = 4 * i;
        let hi = lo + 2;
        basemul(&mut r.coeffs[lo..], &a.coeffs[lo..], &b.coeffs[lo..], zeta);
        basemul(&mut r.coeffs[hi..], &a.coeffs[hi..], &b.coeffs[hi..], -zeta);
    }
}
/// Multiplies every coefficient of `r` by `2^32 mod q` and passes the product
/// through `montgomery_reduce`.
///
/// NOTE(review): assuming `montgomery_reduce(x)` returns `x * 2^-16 mod q`
/// (the usual Kyber definition), this leaves each coefficient multiplied by
/// `2^16`, i.e. in Montgomery form. Confirm against `crate::reduce`.
pub(crate) fn poly_tomont(r: &mut Poly) {
    let factor = ((1u64 << 32) % KYBER_Q as u64) as i32;
    for c in r.coeffs.iter_mut() {
        *c = montgomery_reduce(i32::from(*c) * factor);
    }
}
/// Replaces every coefficient of `r` with `barrett_reduce` of itself.
///
/// The output range is whatever `barrett_reduce` (from `crate::reduce`)
/// guarantees. NOTE(review): confirm whether that is centered or `[0, q)`.
pub(crate) fn poly_reduce(r: &mut Poly) {
    for c in r.coeffs.iter_mut() {
        *c = barrett_reduce(*c);
    }
}
pub(crate) fn poly_add(r: &mut Poly, b: &Poly) {
#[allow(clippy::needless_range_loop)]
for i in 0..KYBER_N {
r.coeffs[i] += b.coeffs[i];
}
}
pub(crate) fn poly_sub(r: &mut Poly, a: &Poly) {
#[allow(clippy::needless_range_loop)]
for i in 0..KYBER_N {
r.coeffs[i] = a.coeffs[i] - r.coeffs[i];
}
}
/// Expands the `KYBER_N / 8`-byte message `msg` into `r`, one bit per
/// coefficient, least-significant bit of each byte first.
///
/// A set bit becomes `ceil(q / 2)` and a clear bit becomes 0. The choice is
/// made with a mask rather than a branch on the message bit.
///
/// # Panics
///
/// Panics if `msg` is shorter than `KYBER_N / 8` bytes.
pub(crate) fn poly_frommsg(r: &mut Poly, msg: &[u8]) {
    let half_q = KYBER_Q.div_ceil(2) as u16;
    for (i, group) in r.coeffs.chunks_exact_mut(8).enumerate() {
        let byte = u16::from(msg[i]);
        for (j, coeff) in group.iter_mut().enumerate() {
            // 0x0000 when bit j is clear, 0xFFFF when it is set.
            let mask = ((byte >> j) & 1).wrapping_neg();
            *coeff = (mask & half_q) as i16;
        }
    }
}
pub(crate) fn poly_tomsg(msg: &mut [u8], a: Poly) {
let mut t: u32;
#[allow(clippy::needless_range_loop)]
for i in 0..KYBER_N / 8 {
msg[i] = 0;
for j in 0..8 {
t = a.coeffs[8 * i + j] as u32;
t <<= 1;
t = t.wrapping_add(1665);
t = t.wrapping_mul(80635);
t >>= 28;
t &= 1;
msg[i] |= (t << j) as u8;
}
}
}
/// Generic-parameter counterpart of `poly_compress`.
///
/// Compresses every coefficient of `a` to `P::DV` bits and packs the result
/// into `r`:
/// - `DV = 4`: two coefficients per byte.
/// - `DV = 5`: eight coefficients per five bytes.
///
/// Each coefficient becomes `round(2^DV * u / q) mod 2^DV`, where `u` is the
/// coefficient lifted into `[0, q)`. Only one multiple of `q` is added to
/// negative inputs, so coefficients are assumed to lie in `[-q, q)`. Exactly
/// `32 * P::DV` bytes of `r` are written.
///
/// # Panics
///
/// Panics if `r` is shorter than `32 * P::DV` bytes. `P::DV` values other than
/// 4 or 5 hit `unreachable!`.
#[allow(dead_code)]
pub(crate) fn poly_compress_generic<P: crate::paramsets::MlKemParams>(
    r: &mut [u8],
    a: Poly,
) {
    /// Compresses one coefficient to `d` bits.
    ///
    /// Every intermediate is a `u32` for both widths, so no non-negative `i16`
    /// input can overflow them. A `u16` intermediate would overflow at `d = 4`
    /// for inputs of 3992 or more (16 * 3992 + 1664 = 65536).
    fn compress_coeff(c: i16, d: u32) -> u8 {
        // Branch-free lift: `c >> 15` is all ones exactly when `c` is negative.
        let u = c + ((c >> 15) & KYBER_Q as i16);
        let mut tmp = ((u as u32) << d) + KYBER_Q as u32 / 2;
        // Multiply-and-shift stands in for `/ q` (315 / 2^20 ≈ 1 / 3329).
        // It is exact for every `u` in `[0, q)` when `d` is 4 or 5.
        tmp *= 315;
        tmp >>= 20;
        (tmp & ((1u32 << d) - 1)) as u8
    }

    let mut t = [0u8; 8];
    match P::DV {
        4 => {
            for (i, group) in a.coeffs.chunks_exact(8).enumerate() {
                for (tj, &c) in t.iter_mut().zip(group) {
                    *tj = compress_coeff(c, 4);
                }
                // Two 4-bit values per byte, low nibble first.
                let k = 4 * i;
                r[k] = t[0] | (t[1] << 4);
                r[k + 1] = t[2] | (t[3] << 4);
                r[k + 2] = t[4] | (t[5] << 4);
                r[k + 3] = t[6] | (t[7] << 4);
            }
        }
        5 => {
            for (i, group) in a.coeffs.chunks_exact(8).enumerate() {
                for (tj, &c) in t.iter_mut().zip(group) {
                    *tj = compress_coeff(c, 5);
                }
                // Eight 5-bit values (40 bits) packed least-significant bit
                // first into five bytes.
                let k = 5 * i;
                r[k] = t[0] | (t[1] << 5);
                r[k + 1] = (t[1] >> 3) | (t[2] << 2) | (t[3] << 7);
                r[k + 2] = (t[3] >> 1) | (t[4] << 4);
                r[k + 3] = (t[4] >> 4) | (t[5] << 1) | (t[6] << 6);
                r[k + 4] = (t[6] >> 2) | (t[7] << 3);
            }
        }
        _ => unreachable!("DV must be 4 or 5 per FIPS 203 §6"),
    }
}
/// Generic-parameter counterpart of `poly_decompress`.
///
/// Unpacks `P::DV`-bit values from `a` and maps each `y` to
/// `floor((y * q + 2^(DV-1)) / 2^DV)`, i.e. `round(q * y / 2^DV)`. Every
/// coefficient of `r` is overwritten, and `32 * P::DV` bytes of `a` are read.
///
/// # Panics
///
/// Panics if `a` is shorter than `32 * P::DV` bytes. `P::DV` values other than
/// 4 or 5 hit `unreachable!`.
#[allow(dead_code)]
pub(crate) fn poly_decompress_generic<P: crate::paramsets::MlKemParams>(
    r: &mut Poly,
    a: &[u8],
) {
    let q = KYBER_Q as u32;
    match P::DV {
        4 => {
            // Each byte carries two values, low nibble first.
            for (pair, &byte) in
                r.coeffs.chunks_exact_mut(2).zip(&a[..KYBER_N / 2])
            {
                pair[0] = ((u32::from(byte & 15) * q + 8) >> 4) as i16;
                pair[1] = ((u32::from(byte >> 4) * q + 8) >> 4) as i16;
            }
        }
        5 => {
            // Every five bytes carry eight values, least-significant bit first.
            let packed = &a[..5 * (KYBER_N / 8)];
            for (out, b) in
                r.coeffs.chunks_exact_mut(8).zip(packed.chunks_exact(5))
            {
                let t = [
                    b[0],
                    (b[0] >> 5) | (b[1] << 3),
                    b[1] >> 2,
                    (b[1] >> 7) | (b[2] << 1),
                    (b[2] >> 4) | (b[3] << 4),
                    b[3] >> 1,
                    (b[3] >> 6) | (b[4] << 2),
                    b[4] >> 3,
                ];
                for (c, &y) in out.iter_mut().zip(&t) {
                    // Only the low five bits of each unpacked byte belong to it.
                    *c = (((u32::from(y) & 31) * q + 16) >> 5) as i16;
                }
            }
        }
        _ => unreachable!("DV must be 4 or 5 per FIPS 203 §6"),
    }
}
/// Number of bytes one polynomial occupies after compression to `P::DV` bits
/// per coefficient.
#[allow(dead_code)]
pub(crate) const fn poly_compressed_len<P: crate::paramsets::MlKemParams>(
) -> usize {
    // Each bit of precision costs one bit per coefficient:
    // 256 coefficients / 8 bits per byte = 32 bytes per bit.
    const BYTES_PER_BIT_OF_PRECISION: usize = 32;
    P::DV * BYTES_PER_BIT_OF_PRECISION
}
#[cfg(test)]
mod poly_generic_tests {
    #![allow(unused_imports)]
    use super::*;
    use crate::paramsets::MlKemParams;

    /// Deterministic polynomial mixing positive and negative coefficients,
    /// all strictly inside `(-q, q)`.
    fn build_test_poly() -> Poly {
        let mut p = Poly::new();
        for (j, c) in p.coeffs.iter_mut().enumerate() {
            let v = (j * 113 + 7) % (KYBER_Q + 50);
            *c = if j % 5 == 0 {
                -(v as i16 % KYBER_Q as i16)
            } else {
                v as i16 % KYBER_Q as i16
            };
        }
        p
    }

    /// Round-trips `build_test_poly()` through the generic compress and
    /// decompress for `P`.
    ///
    /// Checks that every coefficient is recovered to within
    /// `ceil(q / 2^(DV + 1))` modulo `q`, the Compress/Decompress rounding
    /// error. Does not depend on which `kyber*` feature is enabled, so both
    /// the DV = 4 and DV = 5 paths are exercised.
    fn assert_generic_roundtrip_within_bound<P: MlKemParams>() {
        let p = build_test_poly();
        let len = poly_compressed_len::<P>();
        let mut buf = [0u8; 160];
        poly_compress_generic::<P>(&mut buf[..len], p);
        let mut back = Poly::new();
        poly_decompress_generic::<P>(&mut back, &buf[..len]);
        let q = KYBER_Q as i32;
        let bound = (q + (1i32 << (P::DV + 1)) - 1) >> (P::DV + 1);
        for (i, (&x, &y)) in
            p.coeffs.iter().zip(back.coeffs.iter()).enumerate()
        {
            let diff = (i32::from(y) - i32::from(x)).rem_euclid(q);
            let dist = diff.min(q - diff);
            assert!(
                dist <= bound,
                "coefficient {i}: {x} -> {y} (distance {dist} > {bound})"
            );
        }
    }

    #[test]
    #[cfg(feature = "kyber768")]
    fn poly_compress_matches_existing_kyber768() {
        use crate::MlKem768;
        let p = build_test_poly();
        let mut buf_existing = [0u8; KYBER_POLY_COMPRESSED_BYTES];
        poly_compress(&mut buf_existing, p);
        let mut buf_generic = [0u8; poly_compressed_len::<MlKem768>()];
        poly_compress_generic::<MlKem768>(&mut buf_generic, p);
        assert_eq!(buf_existing.as_slice(), buf_generic.as_slice());
    }

    #[test]
    #[cfg(feature = "kyber512")]
    fn poly_compress_matches_existing_kyber512() {
        use crate::MlKem512;
        let p = build_test_poly();
        let mut buf_existing = [0u8; KYBER_POLY_COMPRESSED_BYTES];
        poly_compress(&mut buf_existing, p);
        let mut buf_generic = [0u8; poly_compressed_len::<MlKem512>()];
        poly_compress_generic::<MlKem512>(&mut buf_generic, p);
        assert_eq!(buf_existing.as_slice(), buf_generic.as_slice());
    }

    #[test]
    #[cfg(feature = "kyber1024")]
    fn poly_compress_matches_existing_kyber1024() {
        use crate::MlKem1024;
        let p = build_test_poly();
        let mut buf_existing = [0u8; KYBER_POLY_COMPRESSED_BYTES];
        poly_compress(&mut buf_existing, p);
        let mut buf_generic = [0u8; poly_compressed_len::<MlKem1024>()];
        poly_compress_generic::<MlKem1024>(&mut buf_generic, p);
        assert_eq!(buf_existing.as_slice(), buf_generic.as_slice());
    }

    #[test]
    #[cfg(feature = "kyber768")]
    fn poly_decompress_matches_existing_kyber768() {
        use crate::MlKem768;
        let p = build_test_poly();
        let mut buf = [0u8; KYBER_POLY_COMPRESSED_BYTES];
        poly_compress(&mut buf, p);
        let mut p_existing = Poly::new();
        poly_decompress(&mut p_existing, &buf);
        let mut p_generic = Poly::new();
        poly_decompress_generic::<MlKem768>(&mut p_generic, &buf);
        assert_eq!(p_existing.coeffs, p_generic.coeffs);
    }

    #[test]
    #[cfg(feature = "kyber512")]
    fn poly_decompress_matches_existing_kyber512() {
        use crate::MlKem512;
        let p = build_test_poly();
        let mut buf = [0u8; KYBER_POLY_COMPRESSED_BYTES];
        poly_compress(&mut buf, p);
        let mut p_existing = Poly::new();
        poly_decompress(&mut p_existing, &buf);
        let mut p_generic = Poly::new();
        poly_decompress_generic::<MlKem512>(&mut p_generic, &buf);
        assert_eq!(p_existing.coeffs, p_generic.coeffs);
    }

    #[test]
    #[cfg(feature = "kyber1024")]
    fn poly_decompress_matches_existing_kyber1024() {
        use crate::MlKem1024;
        let p = build_test_poly();
        let mut buf = [0u8; KYBER_POLY_COMPRESSED_BYTES];
        poly_compress(&mut buf, p);
        let mut p_existing = Poly::new();
        poly_decompress(&mut p_existing, &buf);
        let mut p_generic = Poly::new();
        poly_decompress_generic::<MlKem1024>(&mut p_generic, &buf);
        assert_eq!(p_existing.coeffs, p_generic.coeffs);
    }

    #[test]
    fn poly_generic_roundtrip_error_bound_all_param_sets() {
        use crate::{MlKem1024, MlKem512, MlKem768};
        assert_generic_roundtrip_within_bound::<MlKem512>();
        assert_generic_roundtrip_within_bound::<MlKem768>();
        assert_generic_roundtrip_within_bound::<MlKem1024>();
    }

    #[test]
    fn poly_compressed_len_formula() {
        use crate::{MlKem1024, MlKem512, MlKem768};
        assert_eq!(poly_compressed_len::<MlKem512>(), 128);
        assert_eq!(poly_compressed_len::<MlKem768>(), 128);
        assert_eq!(poly_compressed_len::<MlKem1024>(), 160);
    }
}
/// Generic-parameter counterpart of `poly_getnoise_eta1`.
///
/// Fills `r` with noise for parameter set `P`: expands `seed` and `nonce` with
/// `prf` into `P::ETA1 * KYBER_N / 4` bytes, then passes them to
/// `poly_cbd_eta1_generic`.
///
/// # Panics
///
/// Panics if `P::ETA1 > 3`, because the stack buffer is sized for η1 = 3.
#[allow(dead_code)]
pub(crate) fn poly_getnoise_eta1_generic<P: crate::paramsets::MlKemParams>(
    r: &mut Poly,
    seed: &[u8],
    nonce: u8,
) {
    // Largest η1 among the FIPS 203 parameter sets is 3 (ML-KEM-512); only
    // the first `needed` bytes are filled and consumed.
    const CAPACITY: usize = 3 * KYBER_N / 4;
    let needed = P::ETA1 * KYBER_N / 4;
    let mut scratch = [0u8; CAPACITY];
    prf(&mut scratch[..needed], needed, seed, nonce);
    poly_cbd_eta1_generic::<P>(r, &scratch[..needed]);
}
/// Generic-parameter counterpart of `poly_getnoise_eta2`.
///
/// Expands `seed` and `nonce` with `prf` into `2 * KYBER_N / 4` bytes, then
/// passes them to `poly_cbd_eta2_generic`. The length hard-codes η2 = 2, the
/// value used by every FIPS 203 parameter set. `P` is only forwarded to the
/// sampler.
#[allow(dead_code)]
pub(crate) fn poly_getnoise_eta2_generic<P: crate::paramsets::MlKemParams>(
    r: &mut Poly,
    seed: &[u8],
    nonce: u8,
) {
    const ETA2: usize = 2;
    const PRF_OUTPUT_LEN: usize = ETA2 * KYBER_N / 4;
    let mut prf_output = [0u8; PRF_OUTPUT_LEN];
    prf(&mut prf_output, PRF_OUTPUT_LEN, seed, nonce);
    poly_cbd_eta2_generic::<P>(r, &prf_output);
}
#[cfg(test)]
mod poly_getnoise_generic_tests {
    #![allow(unused_imports)]
    use super::*;
    use crate::paramsets::MlKemParams;

    #[test]
    #[cfg(feature = "kyber768")]
    fn poly_getnoise_eta1_matches_existing_kyber768() {
        use crate::MlKem768;
        let seed = [0xAAu8; KYBER_SYM_BYTES];
        let mut p_e = Poly::new();
        let mut p_g = Poly::new();
        poly_getnoise_eta1(&mut p_e, &seed, 7);
        poly_getnoise_eta1_generic::<MlKem768>(&mut p_g, &seed, 7);
        assert_eq!(p_e.coeffs, p_g.coeffs);
    }

    #[test]
    #[cfg(feature = "kyber512")]
    fn poly_getnoise_eta1_matches_existing_kyber512() {
        use crate::MlKem512;
        let seed = [0xAAu8; KYBER_SYM_BYTES];
        let mut p_e = Poly::new();
        let mut p_g = Poly::new();
        poly_getnoise_eta1(&mut p_e, &seed, 7);
        poly_getnoise_eta1_generic::<MlKem512>(&mut p_g, &seed, 7);
        assert_eq!(p_e.coeffs, p_g.coeffs);
    }

    #[test]
    #[cfg(feature = "kyber1024")]
    fn poly_getnoise_eta1_matches_existing_kyber1024() {
        use crate::MlKem1024;
        let seed = [0xAAu8; KYBER_SYM_BYTES];
        let mut p_e = Poly::new();
        let mut p_g = Poly::new();
        poly_getnoise_eta1(&mut p_e, &seed, 7);
        poly_getnoise_eta1_generic::<MlKem1024>(&mut p_g, &seed, 7);
        assert_eq!(p_e.coeffs, p_g.coeffs);
    }

    #[test]
    #[cfg(feature = "kyber768")]
    fn poly_getnoise_eta2_matches_existing_kyber768() {
        use crate::MlKem768;
        let seed = [0xBBu8; KYBER_SYM_BYTES];
        let mut p_e = Poly::new();
        let mut p_g = Poly::new();
        poly_getnoise_eta2(&mut p_e, &seed, 13);
        poly_getnoise_eta2_generic::<MlKem768>(&mut p_g, &seed, 13);
        assert_eq!(p_e.coeffs, p_g.coeffs);
    }

    #[test]
    #[cfg(feature = "kyber512")]
    fn poly_getnoise_eta2_matches_existing_kyber512() {
        use crate::MlKem512;
        let seed = [0xBBu8; KYBER_SYM_BYTES];
        let mut p_e = Poly::new();
        let mut p_g = Poly::new();
        poly_getnoise_eta2(&mut p_e, &seed, 13);
        poly_getnoise_eta2_generic::<MlKem512>(&mut p_g, &seed, 13);
        assert_eq!(p_e.coeffs, p_g.coeffs);
    }

    #[test]
    #[cfg(feature = "kyber1024")]
    fn poly_getnoise_eta2_matches_existing_kyber1024() {
        use crate::MlKem1024;
        let seed = [0xBBu8; KYBER_SYM_BYTES];
        let mut p_e = Poly::new();
        let mut p_g = Poly::new();
        poly_getnoise_eta2(&mut p_e, &seed, 13);
        poly_getnoise_eta2_generic::<MlKem1024>(&mut p_g, &seed, 13);
        assert_eq!(p_e.coeffs, p_g.coeffs);
    }
}