use subtle::{
Choice,
ConditionallySelectable,
ConstantTimeGreater,
};
use zeroize::{
Zeroize,
ZeroizeOnDrop,
};
use crate::coeff::{
COEFFICIENTS_IN_SIMD_UNIT,
Coefficients,
FieldElement,
SIMD_UNITS_IN_RING_ELEMENT,
};
use crate::constants::{
COEFFICIENTS_IN_RING_ELEMENT,
FIELD_MODULUS,
};
use crate::field::{
add_coeffs,
reduce_element,
reduce_poly_simd,
subtract_coeffs,
};
use crate::ntt::{
intt_montgomery,
ntt_forward_simd,
ntt_multiply_montgomery,
};
/// Signed `a > b`, returned as a `subtle::Choice`.
///
/// `subtle` only offers `ct_gt` on unsigned integers, so both operands are
/// biased by flipping their sign bit: that maps the signed order of `i32`
/// onto the unsigned order of `u32` without changing which value is larger.
#[inline]
fn ct_gt_i32(a: i32, b: i32) -> Choice {
    const SIGN_BIT: u32 = 0x8000_0000;
    let lhs = (a as u32) ^ SIGN_BIT;
    let rhs = (b as u32) ^ SIGN_BIT;
    lhs.ct_gt(&rhs)
}
/// Branch-free absolute value of `coefficient`.
///
/// `sign` is `0` for non-negative inputs and `-1` (all ones) for negative
/// ones, so `(x ^ sign) - sign` is `x` in the first case and `!x + 1 == -x`
/// in the second.
///
/// The subtraction is wrapping on purpose: the previous formulation
/// (`x - (sign & (x << 1))`) overflowed the checked `-` for inputs below
/// `-2^30` and panicked in debug builds even though the magnitude is
/// representable. With this form every input is accepted; `i32::MIN`, whose
/// magnitude does not fit in `i32`, maps to itself (same as `wrapping_abs`).
#[inline]
fn centered_abs_i32(coefficient: i32) -> i32 {
    // Arithmetic shift: replicates the sign bit across the whole word.
    let sign = coefficient >> 31;
    (coefficient ^ sign).wrapping_sub(sign)
}
/// A polynomial stored as `COEFFICIENTS_IN_RING_ELEMENT` coefficients.
///
/// The `Zeroize` / `ZeroizeOnDrop` derives wipe the coefficient array when a
/// value is dropped.
///
/// NOTE(review): the derived `Debug` prints the coefficients and the derived
/// `PartialEq` compares them with ordinary (not constant-time) equality —
/// confirm this is acceptable for whatever secret material this type carries.
#[derive(Clone, Debug, Eq, PartialEq, Hash, Zeroize, ZeroizeOnDrop)]
pub struct Poly {
/// Coefficients indexed by degree (index `i` is treated as the `X^i` term by
/// `mul_negacyclic`).
pub coeffs: [FieldElement; COEFFICIENTS_IN_RING_ELEMENT],
}
impl Poly {
#[must_use]
pub const fn zero() -> Self {
Self {
coeffs: [0; COEFFICIENTS_IN_RING_ELEMENT],
}
}
/// Wraps an existing coefficient array in a `Poly`.
///
/// The array is stored as given; no reduction or range check is performed.
#[must_use]
pub const fn from_coeffs(coeffs: [FieldElement; COEFFICIENTS_IN_RING_ELEMENT]) -> Self {
Self { coeffs }
}
/// Adds `rhs` to `self` coefficient by coefficient, passing every sum
/// through `reduce_element`.
///
/// The sum itself is a plain `i32` addition, so it relies on the operands
/// being small enough not to overflow.
pub fn add_assign(&mut self, rhs: &Self) {
    for (lhs, &addend) in self.coeffs.iter_mut().zip(rhs.coeffs.iter()) {
        *lhs = reduce_element(*lhs + addend);
    }
}
/// Subtracts `rhs` from `self` coefficient by coefficient, passing every
/// difference through `reduce_element`.
///
/// The difference itself is a plain `i32` subtraction, so it relies on the
/// operands being small enough not to overflow.
pub fn sub_assign(&mut self, rhs: &Self) {
    for (lhs, &subtrahend) in self.coeffs.iter_mut().zip(rhs.coeffs.iter()) {
        *lhs = reduce_element(*lhs - subtrahend);
    }
}
pub fn scalar_mul_assign(&mut self, k: i32) {
for c in &mut self.coeffs {
*c = reduce_element((*c as i64 * k as i64) as i32);
}
}
/// Schoolbook product of `self` and `rhs` with negacyclic wrap-around.
///
/// A term whose degree `i + j` reaches `COEFFICIENTS_IN_RING_ELEMENT` is
/// folded back to degree `i + j - COEFFICIENTS_IN_RING_ELEMENT` with its sign
/// flipped, i.e. the product is taken modulo `X^N + 1`. Sums are accumulated
/// in `i64`, mapped to the non-negative remainder modulo `FIELD_MODULUS`, and
/// finally passed through `reduce_element`.
#[must_use]
pub fn mul_negacyclic(&self, rhs: &Self) -> Self {
    const N: usize = COEFFICIENTS_IN_RING_ELEMENT;
    let modulus = FIELD_MODULUS as i64;

    let mut wide = [0i64; N];
    for (i, &x) in self.coeffs.iter().enumerate() {
        for (j, &y) in rhs.coeffs.iter().enumerate() {
            let term = (x as i64).wrapping_mul(y as i64);
            let degree = i + j;
            if degree >= N {
                // X^N == -1 in the quotient ring.
                wide[degree - N] -= term;
            } else {
                wide[degree] += term;
            }
        }
    }

    let mut coeffs = [0i32; N];
    for (dst, sum) in coeffs.iter_mut().zip(wide) {
        // Non-negative remainder in `0..modulus`, which fits in `i32`.
        *dst = reduce_element(sum.rem_euclid(modulus) as i32);
    }
    Self { coeffs }
}
/// Largest coefficient magnitude after centering.
///
/// A coefficient greater than `FIELD_MODULUS / 2` is replaced by
/// `coefficient - FIELD_MODULUS` before its absolute value is taken; all
/// other coefficients are used as they are. Both the centering and the
/// running maximum are chosen with `conditional_select` instead of
/// data-dependent branches.
#[must_use]
pub fn infinity_norm(&self) -> i32 {
    let half_q = FIELD_MODULUS / 2;
    self.coeffs.iter().fold(0i32, |max_so_far, &coeff| {
        let above_half = ct_gt_i32(coeff, half_q);
        let shifted = coeff.wrapping_sub(FIELD_MODULUS);
        let centered = i32::conditional_select(&coeff, &shifted, above_half);
        let magnitude = centered_abs_i32(centered);
        let is_new_max = ct_gt_i32(magnitude, max_so_far);
        i32::conditional_select(&max_so_far, &magnitude, is_new_max)
    })
}
/// Returns a set `Choice` when `infinity_norm() <= bound`, i.e. when the norm
/// does not exceed `bound`.
#[must_use]
pub fn norm_within_bound(&self, bound: i32) -> Choice {
    // "within" is the negation of "strictly greater than".
    !ct_gt_i32(self.infinity_norm(), bound)
}
/// Replaces every coefficient by `reduce_element(c)`, then adds
/// `FIELD_MODULUS` to the ones that came out negative.
///
/// Presumably this yields representatives in `0..FIELD_MODULUS`; that holds
/// as long as `reduce_element` never returns less than `-FIELD_MODULUS`
/// (not visible from this file — confirm against its definition).
pub fn normalize_mod_q_assign(&mut self) {
    for coeff in self.coeffs.iter_mut() {
        let reduced = reduce_element(*coeff);
        // All ones when `reduced` is negative, zero otherwise.
        let negative_mask = reduced >> 31;
        *coeff = reduced + (negative_mask & FIELD_MODULUS);
    }
}
/// Returns a copy of `self` with every coefficient multiplied by `scalar`
/// modulo `FIELD_MODULUS`.
///
/// `scalar` is first reduced modulo `FIELD_MODULUS`; each product is computed
/// in `i64`, mapped to its non-negative remainder, and then passed through
/// `reduce_element`.
#[must_use]
pub fn scalar_mul_by_u32_mod_q(&self, scalar: u32) -> Poly {
    let modulus = FIELD_MODULUS as i64;
    let factor = (scalar % (FIELD_MODULUS as u32)) as i64;
    let mut scaled = self.clone();
    scaled.coeffs.iter_mut().for_each(|coeff| {
        let product = (*coeff as i64 * factor).rem_euclid(modulus);
        *coeff = reduce_element(product as i32);
    });
    scaled
}
/// Copies the coefficients into `SIMD_UNITS_IN_RING_ELEMENT` units of
/// `COEFFICIENTS_IN_SIMD_UNIT` values each, preserving their order.
#[must_use]
pub fn to_simd(&self) -> [Coefficients; SIMD_UNITS_IN_RING_ELEMENT] {
    let mut units = [Coefficients::default(); SIMD_UNITS_IN_RING_ELEMENT];
    for (unit_idx, unit) in units.iter_mut().enumerate() {
        let start = unit_idx * COEFFICIENTS_IN_SIMD_UNIT;
        let end = start + COEFFICIENTS_IN_SIMD_UNIT;
        unit.values.copy_from_slice(&self.coeffs[start..end]);
    }
    units
}
/// Inverse of `to_simd`: concatenates the values of every SIMD unit, in
/// order, into one flat coefficient array.
fn from_simd(simd: &[Coefficients; SIMD_UNITS_IN_RING_ELEMENT]) -> Self {
    let mut flat = [0i32; COEFFICIENTS_IN_RING_ELEMENT];
    for unit_idx in 0..SIMD_UNITS_IN_RING_ELEMENT {
        let start = unit_idx * COEFFICIENTS_IN_SIMD_UNIT;
        let end = start + COEFFICIENTS_IN_SIMD_UNIT;
        flat[start..end].copy_from_slice(&simd[unit_idx].values);
    }
    Self { coeffs: flat }
}
/// Converts to SIMD layout and applies `ntt_forward_simd`, returning the
/// result as an `NttPoly`. `self` is left untouched.
#[must_use]
pub fn to_ntt(&self) -> NttPoly {
    let mut units = self.to_simd();
    ntt_forward_simd(&mut units);
    NttPoly { simd: units }
}
}
/// A polynomial held as `SIMD_UNITS_IN_RING_ELEMENT` SIMD units, as produced
/// by `Poly::to_ntt` (which runs `ntt_forward_simd` on the units).
///
/// NOTE(review): unlike `Poly`, this type does not derive `Zeroize` /
/// `ZeroizeOnDrop` — confirm whether transformed values need wiping too.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct NttPoly {
/// The SIMD units, in the same order `Poly::to_simd` emits them.
pub(crate) simd: [Coefficients; SIMD_UNITS_IN_RING_ELEMENT],
}
impl NttPoly {
/// Returns an `NttPoly` whose every SIMD unit is `Coefficients::default()`
/// (presumably all-zero values — confirm against `Coefficients`).
#[must_use]
pub fn zero() -> Self {
Self {
simd: [Coefficients::default(); SIMD_UNITS_IN_RING_ELEMENT],
}
}
/// Flattens the SIMD units into one array of
/// `COEFFICIENTS_IN_RING_ELEMENT` values, unit after unit, without applying
/// any transform or reduction.
#[must_use]
pub fn packed_ntt_coefficients(&self) -> [FieldElement; COEFFICIENTS_IN_RING_ELEMENT] {
    let mut flat = [0i32; COEFFICIENTS_IN_RING_ELEMENT];
    for (unit_idx, unit) in self.simd.iter().enumerate() {
        let start = unit_idx * COEFFICIENTS_IN_SIMD_UNIT;
        let end = start + COEFFICIENTS_IN_SIMD_UNIT;
        flat[start..end].copy_from_slice(&unit.values);
    }
    flat
}
/// Borrows the underlying SIMD units.
#[must_use]
pub fn as_simd(&self) -> &[Coefficients; SIMD_UNITS_IN_RING_ELEMENT] {
&self.simd
}
/// Mutably borrows the underlying SIMD units so callers can modify them in
/// place.
pub fn as_simd_mut(&mut self) -> &mut [Coefficients; SIMD_UNITS_IN_RING_ELEMENT] {
&mut self.simd
}
/// Multiplies `self` by `rhs` unit by unit using `ntt_multiply_montgomery`,
/// storing the result in `self`.
pub fn pointwise_mul_assign(&mut self, rhs: &Self) {
    for (lhs_unit, rhs_unit) in self.simd.iter_mut().zip(rhs.simd.iter()) {
        ntt_multiply_montgomery(lhs_unit, rhs_unit);
    }
}
/// Consumes the value, applies `intt_montgomery` followed by
/// `reduce_poly_simd` to its units, and repacks them as a `Poly`.
#[must_use]
pub fn to_poly(self) -> Poly {
    let mut units = self.simd;
    intt_montgomery(&mut units);
    reduce_poly_simd(&mut units);
    Poly::from_simd(&units)
}
/// Adds `rhs` to `self` unit by unit via `add_coeffs`.
pub fn add_assign(&mut self, rhs: &Self) {
    for (lhs_unit, rhs_unit) in self.simd.iter_mut().zip(rhs.simd.iter()) {
        add_coeffs(lhs_unit, rhs_unit);
    }
}
/// Subtracts `rhs` from `self` unit by unit via `subtract_coeffs`.
pub fn sub_assign(&mut self, rhs: &Self) {
    for (lhs_unit, rhs_unit) in self.simd.iter_mut().zip(rhs.simd.iter()) {
        subtract_coeffs(lhs_unit, rhs_unit);
    }
}
}
/// Splits a flat buffer of `COEFFICIENTS_IN_RING_ELEMENT` coefficients into
/// SIMD units, using the same layout as `Poly::to_simd`.
#[must_use]
pub fn simd_from_i256(
    buf: &[i32; COEFFICIENTS_IN_RING_ELEMENT],
) -> [Coefficients; SIMD_UNITS_IN_RING_ELEMENT] {
    let poly = Poly::from_coeffs(*buf);
    poly.to_simd()
}
/// Returns a set `Choice` when every polynomial in `polys` satisfies
/// `norm_within_bound(bound)`.
///
/// All polynomials are always inspected (the results are AND-ed together
/// rather than short-circuited); an empty slice yields a set `Choice`.
#[must_use]
pub fn polys_norm_within_bound(polys: &[Poly], bound: i32) -> Choice {
    polys
        .iter()
        .fold(Choice::from(1u8), |all_ok, poly| all_ok & poly.norm_within_bound(bound))
}
#[cfg(test)]
mod tests {
use super::*;
/// Advances a 64-bit linear congruential generator (`state * M + 1`, wrapping)
/// and returns the upper 32 bits of the new state.
///
/// Deterministic test-data source only.
fn lcg_step(state: &mut u64) -> u32 {
    const MULTIPLIER: u64 = 6364136223846793005;
    let next = state.wrapping_mul(MULTIPLIER).wrapping_add(1);
    *state = next;
    (next >> 32) as u32
}
/// Builds a polynomial whose coefficients are drawn from `lcg_step` and lie
/// in `-bound..=bound`.
fn small_poly(state: &mut u64, bound: i32) -> Poly {
    // Number of distinct values in `-bound..=bound`.
    let span = (2 * bound + 1) as u32;
    let mut coeffs = [0i32; COEFFICIENTS_IN_RING_ELEMENT];
    coeffs.iter_mut().for_each(|slot| {
        let offset = (lcg_step(state) % span) as i32;
        *slot = offset - bound;
    });
    Poly::from_coeffs(coeffs)
}
/// `to_ntt` followed by `to_poly` must act as multiplication by one fixed
/// scalar: the scalar is read off the round trip of the constant polynomial
/// `1` and then checked against random small polynomials.
#[test]
fn ntt_inverse_has_expected_linear_scale() {
    let mut unit = [0i32; COEFFICIENTS_IN_RING_ELEMENT];
    unit[0] = 1;
    let scale = Poly::from_coeffs(unit).to_ntt().to_poly().coeffs[0] as i64;

    let mut rng_state = 0xC0DEC0DE_u64;
    for _ in 0..16 {
        let poly = small_poly(&mut rng_state, 8);
        // `to_ntt` borrows, so no clone is needed to keep `poly` around.
        let round_trip = poly.to_ntt().to_poly();
        for (&orig, &got) in poly.coeffs.iter().zip(round_trip.coeffs.iter()) {
            let expected = reduce_element((orig as i64 * scale) as i32);
            assert_eq!(expected, got);
        }
    }
}
/// Pointwise multiplication in the NTT domain followed by `to_poly` must
/// equal the schoolbook negacyclic product for small random inputs.
#[test]
fn ntt_pointwise_matches_schoolbook_for_small_coeffs() {
    let mut rng_state = 0xDEADBEEF_u64;
    for _ in 0..4 {
        let lhs = small_poly(&mut rng_state, 8);
        let rhs = small_poly(&mut rng_state, 8);
        let expected = lhs.mul_negacyclic(&rhs);

        let mut product = lhs.to_ntt();
        let rhs_ntt = rhs.to_ntt();
        product.pointwise_mul_assign(&rhs_ntt);

        assert_eq!(expected, product.to_poly());
    }
}
/// Straightforward, branching re-implementation of `Poly::infinity_norm`
/// used as the oracle in the tests below.
fn infinity_norm_branchy_reference(p: &Poly) -> i32 {
    let half = FIELD_MODULUS / 2;
    p.coeffs
        .iter()
        .map(|&c| {
            let centered = if c > half { c - FIELD_MODULUS } else { c };
            centered.abs()
        })
        .fold(0i32, i32::max)
}
/// Compares `infinity_norm` with the branching reference on random
/// coefficients strictly between `-q` and `q`, then on hand-picked edge values
/// around `0`, `q / 2` and `q - 1` (each together with its negation).
#[test]
fn infinity_norm_matches_branchy_reference() {
    let q = FIELD_MODULUS;

    let mut rng_state = 0xA11CE_u64;
    for _ in 0..256 {
        let mut coeffs = [0i32; COEFFICIENTS_IN_RING_ELEMENT];
        for slot in coeffs.iter_mut() {
            *slot = (lcg_step(&mut rng_state) as i32) % q;
        }
        let poly = Poly::from_coeffs(coeffs);
        assert_eq!(poly.infinity_norm(), infinity_norm_branchy_reference(&poly));
    }

    for edge in [0, 1, q / 2, q / 2 + 1, q - 1] {
        let mut poly = Poly::zero();
        poly.coeffs[0] = edge;
        poly.coeffs[1] = -edge;
        assert_eq!(poly.infinity_norm(), infinity_norm_branchy_reference(&poly));
    }
}
/// Smoke test: `normalize_mod_q_assign` must bring both an over-range and a
/// negative coefficient into `0..FIELD_MODULUS`, and
/// `scalar_mul_by_u32_mod_q` must agree with a direct multiply-and-reduce.
#[test]
fn normalize_mod_q_and_scalar_mul_smoke() {
let mut p = Poly::zero();
// One coefficient just above the modulus.
p.coeffs[0] = FIELD_MODULUS + 5;
p.normalize_mod_q_assign();
assert!((0..FIELD_MODULUS).contains(&p.coeffs[0]));
// A negative coefficient; the second call re-normalizes the whole polynomial.
p.coeffs[1] = -3;
p.normalize_mod_q_assign();
assert!((0..FIELD_MODULUS).contains(&p.coeffs[1]));
let scaled = p.scalar_mul_by_u32_mod_q(3);
// Only coefficient 0 is checked; the plain `i32` multiply by 3 cannot
// overflow because the coefficient was just asserted to be below the modulus.
assert_eq!(scaled.coeffs[0], reduce_element(p.coeffs[0] * 3));
}
}