use ctutils::{CtEq, CtGt, CtLt, CtSelect};
use hybrid_array::{
ArraySize,
typenum::{Shleft, U1, U13, Unsigned},
};
use module_lattice::{Field, Truncate};
// Prime field used throughout this module. The last macro argument, 8_380_417, is the field
// modulus (exposed as `BaseField::Q`), and the element integer type is `u32` (see `Int` below).
// NOTE(review): the `u64` / `u128` arguments are presumably the wider types used for
// intermediate products — confirm against `module_lattice::define_field!`.
module_lattice::define_field!(BaseField, u32, u64, u128, 8_380_417);
/// Integer type backing a single `BaseField` element.
pub(crate) type Int = <BaseField as Field>::Int;
/// A single `BaseField` element.
pub(crate) type Elem = module_lattice::Elem<BaseField>;
/// A polynomial with `BaseField` coefficients.
pub(crate) type Polynomial = module_lattice::Polynomial<BaseField>;
/// A vector of polynomials whose length is given by the type-level size `K`.
pub(crate) type Vector<K> = module_lattice::Vector<BaseField, K>;
/// A polynomial in NTT (number-theoretic transform) representation.
pub(crate) type NttPolynomial = module_lattice::NttPolynomial<BaseField>;
/// A vector of NTT-domain polynomials whose length is given by `K`.
pub(crate) type NttVector<K> = module_lattice::NttVector<BaseField, K>;
/// A matrix of NTT-domain polynomials with dimensions given by `K` and `L`.
pub(crate) type NttMatrix<K, L> = module_lattice::NttMatrix<BaseField, K, L>;
/// Constant-time Barrett reduction modulo the type-level integer `Self`.
///
/// `MULTIPLIER / 2^SHIFT` is a fixed-point approximation of `1 / Self`; implementors choose
/// both constants.
pub(crate) trait BarrettReduce: Unsigned {
    /// Number of fractional bits in the reciprocal approximation.
    const SHIFT: usize;
    /// Fixed-point reciprocal of the modulus, scaled by `2^SHIFT`.
    const MULTIPLIER: u64;

    /// Returns `x mod Self` without data-dependent branches.
    ///
    /// The quotient is estimated with one multiply and shift. The resulting candidate
    /// remainder is assumed to be less than twice the modulus (true when `MULTIPLIER` does not
    /// exceed `2^SHIFT / Self` and the estimate is off by at most one). It is then corrected
    /// with a single constant-time conditional subtraction.
    fn reduce(x: u32) -> u32 {
        let modulus = Self::U64;
        let wide = u64::from(x);

        let estimate = (wide * Self::MULTIPLIER) >> Self::SHIFT;
        let candidate = wide - estimate * modulus;

        let already_reduced = candidate.ct_lt(&modulus);
        let subtracted: u32 = Truncate::truncate(candidate.wrapping_sub(modulus));
        let kept: u32 = Truncate::truncate(candidate);

        // Picks `kept` when the candidate is already below the modulus, else `subtracted`.
        u32::ct_select(&subtracted, &kept, already_reduced)
    }
}
/// Blanket implementation deriving the Barrett constants from the modulus value.
///
/// `SHIFT` is twice the bit length of `M`, and `MULTIPLIER` is `floor(2^SHIFT / M)`.
impl<M: Unsigned> BarrettReduce for M {
    #[allow(clippy::as_conversions)]
    const SHIFT: usize = {
        let bit_length = M::U64.ilog2() + 1;
        2 * bit_length as usize
    };

    #[allow(clippy::integer_division_remainder_used, reason = "constant")]
    const MULTIPLIER: u64 = (1u64 << Self::SHIFT) / M::U64;
}
/// Splits a field element into a high part and a centered low part with respect to a
/// type-level modulus (the ML-DSA `Decompose` operation).
pub(crate) trait Decompose {
    /// Returns `(r1, r0)` such that `r1 * TwoGamma2 + r0` is congruent to `self` modulo q.
    fn decompose<TwoGamma2>(self) -> (Elem, Elem)
    where
        TwoGamma2: Unsigned;
}
pub(crate) trait ConstantTimeDiv: Unsigned {
const CT_DIV_SHIFT: usize;
const CT_DIV_MULTIPLIER: u64;
#[allow(clippy::inline_always)] #[inline(always)]
fn ct_div(x: u32) -> u32 {
let x64 = u64::from(x);
let quotient = (x64 * Self::CT_DIV_MULTIPLIER) >> Self::CT_DIV_SHIFT;
#[allow(clippy::cast_possible_truncation, clippy::as_conversions)]
let result = quotient as u32;
result
}
}
/// Blanket implementation: a 48-bit fixed-point reciprocal rounded up, `ceil(2^48 / M)`.
impl<M: Unsigned> ConstantTimeDiv for M {
    const CT_DIV_SHIFT: usize = 48;

    #[allow(clippy::integer_division_remainder_used, reason = "constant")]
    const CT_DIV_MULTIPLIER: u64 = {
        let scale: u64 = 1 << Self::CT_DIV_SHIFT;
        scale.div_ceil(M::U64)
    };
}
/// Constant-time decomposition of a field element into `(r1, r0)` with
/// `r1 * TwoGamma2 + r0 ≡ self (mod q)`.
impl Decompose for Elem {
    fn decompose<TwoGamma2: Unsigned>(self) -> (Elem, Elem) {
        // `self` is already owned (and `Elem` is `Copy`), so no clone is needed.
        let low = self.mod_plus_minus::<TwoGamma2>();
        // Removing the centered remainder leaves a value congruent to a multiple of `TwoGamma2`.
        let diff = self - low;

        // When `self - low == q - 1` the high part would wrap around. It is mapped to 0 and
        // the low part is decremented to compensate, which keeps the congruence above.
        let wraps = diff.0.ct_eq(&(BaseField::Q - 1));

        let high_normal = TwoGamma2::ct_div(diff.0);
        let low_wrapped = (low - Elem::new(1)).0;

        let high = Elem::new(u32::ct_select(&high_normal, &0, wraps));
        let low = Elem::new(u32::ct_select(&low.0, &low_wrapped, wraps));
        (high, low)
    }
}
/// Coefficient-wise algebraic helpers shared by field elements, polynomials, and vectors.
#[allow(clippy::module_name_repetitions)]
pub(crate) trait AlgebraExt: Sized {
    /// Centered reduction modulo `M`. Each coefficient is mapped to the representative of
    /// `x mod M` no larger than `M / 2` in absolute value; negative results are stored modulo q.
    fn mod_plus_minus<M>(&self) -> Self
    where
        M: Unsigned;

    /// Largest absolute value among the centered representatives of the coefficients.
    fn infinity_norm(&self) -> Int;

    /// Splits into `(r1, r0)` where `r0` is a centered remainder and `r1` the remaining
    /// high part; the implementations in this file use the modulus `2^13`.
    fn power2round(&self) -> (Self, Self);

    /// High part of the decomposition with respect to `TwoGamma2`.
    fn high_bits<TwoGamma2>(&self) -> Self
    where
        TwoGamma2: Unsigned;

    /// Low part of the decomposition with respect to `TwoGamma2`.
    fn low_bits<TwoGamma2>(&self) -> Self
    where
        TwoGamma2: Unsigned;
}
impl AlgebraExt for Elem {
    /// Constant-time centered reduction modulo `M`.
    fn mod_plus_minus<M: Unsigned>(&self) -> Self {
        let raw_mod = Elem::new(M::reduce(self.0));
        // Representatives above `M / 2` are shifted down by `M` (wrapping modulo q).
        let in_lower_half = !raw_mod.0.ct_gt(&(M::U32 >> 1));
        Elem::new(u32::ct_select(
            &(raw_mod - Elem::new(M::U32)).0,
            &raw_mod.0,
            in_lower_half,
        ))
    }

    /// Absolute value of the centered representative, selected in constant time.
    fn infinity_norm(&self) -> u32 {
        let in_lower_half = !self.0.ct_gt(&(BaseField::Q >> 1));
        u32::ct_select(&(BaseField::Q - self.0), &self.0, in_lower_half)
    }

    /// Splits into `(r1, r0)` with `r0 = self mod± 2^13` and `r1 = (self - r0) >> 13`.
    fn power2round(&self) -> (Self, Self) {
        type D = U13;
        type Pow2D = Shleft<U1, D>;
        // `Elem` is `Copy`; dereference instead of cloning.
        let r_plus = *self;
        let r0 = r_plus.mod_plus_minus::<Pow2D>();
        let r1 = Elem::new((r_plus - r0).0 >> D::USIZE);
        (r1, r0)
    }

    /// High part of `decompose::<TwoGamma2>`.
    fn high_bits<TwoGamma2: Unsigned>(&self) -> Self {
        self.decompose::<TwoGamma2>().0
    }

    /// Low part of `decompose::<TwoGamma2>`.
    fn low_bits<TwoGamma2: Unsigned>(&self) -> Self {
        self.decompose::<TwoGamma2>().1
    }
}
impl AlgebraExt for Polynomial {
    /// Applies centered reduction modulo `M` to every coefficient.
    fn mod_plus_minus<M: Unsigned>(&self) -> Self {
        Self(self.0.iter().map(AlgebraExt::mod_plus_minus::<M>).collect())
    }

    /// Maximum coefficient norm.
    ///
    /// Folds with constant-time compare/select (like the rest of this module) instead of
    /// `Iterator::max`, whose comparison outcome may be compiled to a branch. The fold starts
    /// at 0, so no panic path is needed.
    fn infinity_norm(&self) -> u32 {
        self.0
            .iter()
            .map(AlgebraExt::infinity_norm)
            .fold(0, |acc, norm| u32::ct_select(&acc, &norm, norm.ct_gt(&acc)))
    }

    /// Coefficient-wise `power2round`, returning the high and low polynomials.
    fn power2round(&self) -> (Self, Self) {
        let mut r1 = Self::default();
        let mut r0 = Self::default();
        for ((hi, lo), x) in r1.0.iter_mut().zip(r0.0.iter_mut()).zip(self.0.iter()) {
            (*hi, *lo) = x.power2round();
        }
        (r1, r0)
    }

    /// Coefficient-wise high part of the decomposition with respect to `TwoGamma2`.
    fn high_bits<TwoGamma2: Unsigned>(&self) -> Self {
        Self(
            self.0
                .iter()
                .map(AlgebraExt::high_bits::<TwoGamma2>)
                .collect(),
        )
    }

    /// Coefficient-wise low part of the decomposition with respect to `TwoGamma2`.
    fn low_bits<TwoGamma2: Unsigned>(&self) -> Self {
        Self(
            self.0
                .iter()
                .map(AlgebraExt::low_bits::<TwoGamma2>)
                .collect(),
        )
    }
}
impl<K: ArraySize> AlgebraExt for Vector<K> {
    /// Applies centered reduction modulo `M` to every polynomial.
    fn mod_plus_minus<M: Unsigned>(&self) -> Self {
        Self(self.0.iter().map(AlgebraExt::mod_plus_minus::<M>).collect())
    }

    /// Maximum norm over all polynomials.
    ///
    /// Folds with constant-time compare/select instead of `Iterator::max`, whose comparison
    /// outcome may be compiled to a branch. Starting the fold at 0 also means a zero-length
    /// vector yields 0 instead of panicking.
    fn infinity_norm(&self) -> u32 {
        self.0
            .iter()
            .map(AlgebraExt::infinity_norm)
            .fold(0, |acc, norm| u32::ct_select(&acc, &norm, norm.ct_gt(&acc)))
    }

    /// Polynomial-wise `power2round`, returning the high and low vectors.
    fn power2round(&self) -> (Self, Self) {
        let mut r1 = Self::default();
        let mut r0 = Self::default();
        for ((hi, lo), x) in r1.0.iter_mut().zip(r0.0.iter_mut()).zip(self.0.iter()) {
            (*hi, *lo) = x.power2round();
        }
        (r1, r0)
    }

    /// Polynomial-wise high part of the decomposition with respect to `TwoGamma2`.
    fn high_bits<TwoGamma2: Unsigned>(&self) -> Self {
        Self(
            self.0
                .iter()
                .map(AlgebraExt::high_bits::<TwoGamma2>)
                .collect(),
        )
    }

    /// Polynomial-wise low part of the decomposition with respect to `TwoGamma2`.
    fn low_bits<TwoGamma2: Unsigned>(&self) -> Self {
        Self(
            self.0
                .iter()
                .map(AlgebraExt::low_bits::<TwoGamma2>)
                .collect(),
        )
    }
}
#[cfg(test)]
#[allow(clippy::integer_division_remainder_used, reason = "tests")]
mod test {
    use super::*;
    use crate::{MlDsa65, ParameterSet};
    type Mod = <MlDsa65 as ParameterSet>::TwoGamma2;
    const MOD: u32 = Mod::U32;
    const MOD_ELEM: Elem = Elem::new(MOD);
    #[test]
    fn mod_plus_minus() {
        for x in 0..MOD {
            let x = Elem::new(x);
            let x0 = x.mod_plus_minus::<Mod>();
            let positive_bound = x0.0 <= MOD / 2;
            let negative_bound = x0.0 > BaseField::Q - MOD / 2;
            assert!(positive_bound || negative_bound);
            let xn = x + MOD_ELEM;
            let x0n = x0 + MOD_ELEM;
            assert_eq!(xn.0 % MOD, x0n.0 % MOD);
        }
    }
    #[test]
    fn decompose() {
        for x in 0..MOD {
            let x = Elem::new(x);
            let (x1, x0) = x.decompose::<Mod>();
            let positive_bound = x0.0 <= MOD / 2;
            let negative_bound = x0.0 >= BaseField::Q - MOD / 2;
            assert!(positive_bound || negative_bound);
            let xx = (MOD * x1.0 + x0.0) % BaseField::Q;
            assert_eq!(xx, x.0);
        }
    }
    #[test]
    fn barrett_reduce_boundary() {
        let m_minus_1 = Mod::U32 - 1;
        assert_eq!(Mod::reduce(m_minus_1), m_minus_1);
        assert_eq!(Mod::reduce(Mod::U32), 0);
        assert_eq!(Mod::reduce(Mod::U32 + 1), 1);
        assert_eq!(Mod::reduce(2 * Mod::U32 - 1), m_minus_1);
        assert_eq!(Mod::reduce(2 * Mod::U32), 0);
    }
    #[test]
    fn constant_time_div_accuracy() {
        for x in 0..1000 {
            assert_eq!(Mod::ct_div(x), x / Mod::U32);
        }
        for x in (BaseField::Q - 1000)..BaseField::Q {
            assert_eq!(Mod::ct_div(x), x / Mod::U32);
        }
    }
    #[test]
    fn decompose_edge_case() {
        let q_minus_1 = Elem::new(BaseField::Q - 1);
        let (r1, r0) = q_minus_1.decompose::<Mod>();
        let reconstructed = (MOD * r1.0 + r0.0) % BaseField::Q;
        assert_eq!(reconstructed, q_minus_1.0);
    }
    #[test]
    fn high_low_bits_consistency() {
        for x in [0, 1, MOD / 2, MOD - 1, MOD, MOD + 1, BaseField::Q - 1] {
            let elem = Elem::new(x);
            let (decomp_high, decomp_low) = elem.decompose::<Mod>();
            assert_eq!(elem.high_bits::<Mod>(), decomp_high);
            assert_eq!(elem.low_bits::<Mod>(), decomp_low);
        }
    }
    /// `power2round` must return a centered low part modulo 2^13 and reconstruct the input.
    #[test]
    fn power2round() {
        const TWO_POW_D: u32 = 1 << 13;
        let samples = (0..BaseField::Q).step_by(997).chain([BaseField::Q - 1]);
        for x in samples {
            let (r1, r0) = Elem::new(x).power2round();
            let positive_bound = r0.0 <= TWO_POW_D / 2;
            let negative_bound = r0.0 > BaseField::Q - TWO_POW_D / 2;
            assert!(positive_bound || negative_bound);
            let reconstructed = (r1.0 * TWO_POW_D + r0.0) % BaseField::Q;
            assert_eq!(reconstructed, x);
        }
    }
    /// The element norm is the absolute value of the centered representative.
    #[test]
    fn infinity_norm_elem() {
        let half = BaseField::Q / 2;
        assert_eq!(Elem::new(0).infinity_norm(), 0);
        assert_eq!(Elem::new(1).infinity_norm(), 1);
        assert_eq!(Elem::new(BaseField::Q - 1).infinity_norm(), 1);
        assert_eq!(Elem::new(half).infinity_norm(), half);
        assert_eq!(Elem::new(half + 1).infinity_norm(), half);
    }
    /// The polynomial norm is the maximum coefficient norm.
    #[test]
    fn infinity_norm_polynomial() {
        let mut p = Polynomial::default();
        assert_eq!(p.infinity_norm(), 0);
        p.0[3] = Elem::new(2);
        p.0[7] = Elem::new(BaseField::Q - 5);
        assert_eq!(p.infinity_norm(), 5);
    }
}