use core::ops::{Add, Mul, Sub};
#[cfg(feature = "zeroize")]
use zeroize::Zeroize;
pub(crate) const KYBER_Q: u32 = 3329;
pub fn barrett_reduce(value: u32) -> u32 {
debug_assert!(value < KYBER_Q.pow(2));
const MUL: u64 = 5039;
const SHIFT: u64 = 24;
let quo: u32 = ((u64::from(value) * MUL) >> SHIFT) as u32;
let r = value - (quo * KYBER_Q);
debug_assert!((0..KYBER_Q * 2).contains(&r));
let ret = conditional_sub_u32(r);
debug_assert!((0..KYBER_Q).contains(&ret));
ret
}
fn conditional_sub_u32(a: u32) -> u32 {
let t: u32 = a.overflowing_sub(KYBER_Q).0;
let mask: u32 = 0u32.overflowing_sub(t >> 31).0;
(t & !mask) | (a & mask)
}
#[derive(Clone, Copy, PartialEq, Debug)]
pub struct FieldElement(pub(crate) u32);
#[cfg(feature = "zeroize")]
impl Zeroize for FieldElement {
fn zeroize(&mut self) {
self.0.zeroize();
}
}
impl FieldElement {
pub fn new(value: u32) -> Self {
debug_assert!((0..KYBER_Q).contains(&value));
Self(value)
}
pub fn zero() -> Self {
Self(0)
}
pub fn compress(&self, d: u8) -> u32 {
debug_assert!((1..=11).contains(&d));
const MUL: u64 = 5039;
const SHIFT: u64 = 24;
let div: u32 = self.0 << d;
let mut quo: u32 = ((u64::from(div) * MUL) >> SHIFT) as u32;
let rem: u32 = div - (quo * KYBER_Q);
quo += ((KYBER_Q / 2).overflowing_sub(rem).0 >> 31) & 1;
quo += ((KYBER_Q + KYBER_Q / 2 - rem) >> 31) & 1;
let mask: u32 = (1 << d as u32) - 1;
((quo & mask) as u16) as u32
}
pub fn decompress(y: u32, d: u8) -> Self {
debug_assert!((1..=11).contains(&d));
let div: u32 = y * KYBER_Q;
let mut quo: u32 = div >> d as u32;
quo += (div >> (d as u32 - 1)) & 1;
debug_assert!(quo < KYBER_Q);
FieldElement(quo)
}
#[cfg(all(test, feature = "safe_api"))]
pub fn random() -> Self {
use rand::prelude::*;
Self(rand::rng().random_range(0..KYBER_Q))
}
}
impl Add for FieldElement {
type Output = Self;
fn add(self, other: Self) -> Self {
let x: u32 = self.0 + other.0;
Self(conditional_sub_u32(x))
}
}
impl Sub for FieldElement {
type Output = Self;
fn sub(self, other: Self) -> Self {
let x: u32 = self.0.overflowing_sub(other.0).0.overflowing_add(KYBER_Q).0;
Self(conditional_sub_u32(x))
}
}
impl Mul for FieldElement {
type Output = Self;
fn mul(self, other: Self) -> Self {
Self(barrett_reduce(self.0 * other.0))
}
}
#[cfg(test)]
mod test_field_modular_arithmetic {
use super::*;
#[test]
fn test_field_ops_add() {
for x in 0..KYBER_Q {
for y in 0..KYBER_Q {
let fe_add_ret = FieldElement(x) + FieldElement(y);
let num_add_ret = (x + y) % KYBER_Q;
assert!(fe_add_ret.0 < KYBER_Q);
assert_eq!(fe_add_ret.0, num_add_ret);
}
}
}
#[test]
fn test_field_ops_sub() {
for x in 0..KYBER_Q {
for y in 0..KYBER_Q {
let fe_sub_ret = FieldElement(x) - FieldElement(y);
let num_sub_ret = (x as i32 - y as i32 + KYBER_Q as i32) % KYBER_Q as i32;
assert!(fe_sub_ret.0 < KYBER_Q);
assert_eq!(fe_sub_ret.0, num_sub_ret as u32);
}
}
}
#[test]
fn test_field_ops_mul() {
for x in 0..KYBER_Q {
for y in 0..KYBER_Q {
let fe_mul_ret = FieldElement(x) * FieldElement(y);
let num_mul_ret = (x * y) % KYBER_Q;
assert_eq!(fe_mul_ret.0, num_mul_ret);
}
}
}
fn conditional_sub_i16(a: i16, modulo: i16) -> i16 {
let t: i16 = a - modulo;
let mask: i16 = t >> 15;
(t & !mask) | (a & mask)
}
#[test]
fn test_conditional_sub() {
for a in 0..KYBER_Q * 2 {
if a >= KYBER_Q {
assert_eq!(conditional_sub_u32(a), a - KYBER_Q);
assert_eq!(
conditional_sub_i16(a as i16, KYBER_Q as i16),
(a - KYBER_Q) as i16
);
} else {
assert_eq!(conditional_sub_u32(a), a);
assert_eq!(conditional_sub_i16(a as i16, KYBER_Q as i16), a as i16);
}
}
}
#[test]
fn test_field_reduced_state() {
for a in 0..KYBER_Q.pow(2) {
let reduced: u32 = barrett_reduce(a);
assert!((0..KYBER_Q).contains(&reduced));
assert_eq!(reduced, a % KYBER_Q);
}
}
}
#[cfg(test)]
mod test_compression {
use super::*;
use num_rational::*;
const COMPRESSION_D: [u8; 6] = [1, 4, 5, 6, 10, 11];
fn ratcompress(x: i32, d: u32) -> i32 {
let m: i32 = 2i32.pow(d);
let mut r = Rational32::new(x * m, KYBER_Q as i32);
r = r.round() % m;
r.to_integer()
}
fn ratdecompress(y: i32, d: u32) -> i32 {
let m: i32 = 2i32.pow(d);
let mut r = Rational32::new(y * KYBER_Q as i32, m);
r = r.round();
r.to_integer()
}
#[test]
fn test_compress_with_rational() {
for d in COMPRESSION_D {
for x in 0..KYBER_Q {
let fe = FieldElement::new(x);
assert_eq!(fe.compress(d) as i32, ratcompress(fe.0 as i32, d.into()));
}
}
}
#[test]
fn test_decompress_with_rational() {
for d in COMPRESSION_D {
for y in 0..2u32.pow(d as u32) {
assert_eq!(
FieldElement::decompress(y, d).0 as i32,
ratdecompress(y as i32, d.into())
);
}
}
}
#[test]
fn test_compress_decompress_roundtrip_d_lt_12() {
for d in COMPRESSION_D {
for x in 0..2u32.pow(d as u32) {
assert_eq!(FieldElement::decompress(x, d).compress(d), x);
}
}
}
}