use crate::algebra::{Additive, Field, FieldNTT, Multiplicative, Object, Random, Ring};
use commonware_codec::{FixedSize, Read, Write};
use core::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};
use rand_core::CryptoRngCore;
/// The field modulus: P = 2^64 - 2^32 + 1 (the "Goldilocks" prime).
/// `u64::wrapping_neg(1 << 32)` evaluates to 2^64 - 2^32 without overflow.
const P: u64 = u64::wrapping_neg(1 << 32) + 1;
/// An element of the prime field of order P = 2^64 - 2^32 + 1.
///
/// Invariant: the wrapped `u64` is always canonical (fully reduced, in
/// `[0, P)`), which makes derived `PartialEq`/`Eq` correct residue equality.
#[derive(Clone, Copy, PartialEq, Eq)]
pub struct F(u64);
// Wire format: a field element occupies exactly as many bytes as a raw u64.
impl FixedSize for F {
    /// Encoded length in bytes, inherited from the underlying `u64`.
    const SIZE: usize = <u64 as FixedSize>::SIZE;
}
// Codec write: emit the canonical u64 representation unchanged.
impl Write for F {
    fn write(&self, buf: &mut impl bytes::BufMut) {
        let Self(raw) = self;
        raw.write(buf)
    }
}
impl Read for F {
type Cfg = <u64 as Read>::Cfg;
fn read_cfg(
buf: &mut impl bytes::Buf,
cfg: &Self::Cfg,
) -> Result<Self, commonware_codec::Error> {
let x = u64::read_cfg(buf, cfg)?;
if x >= P {
return Err(commonware_codec::Error::Invalid("F", "out of range"));
}
Ok(Self(x))
}
}
impl core::fmt::Debug for F {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
write!(f, "{:016X}", self.0)
}
}
#[cfg(any(test, feature = "arbitrary"))]
impl arbitrary::Arbitrary<'_> for F {
    // Draw a raw 64-bit word and reduce it into the field. Residues below
    // 2^64 - P (~2^32 of them) are produced twice as often as the rest,
    // which is acceptable bias for fuzzing inputs.
    fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
        u.arbitrary::<u64>().map(Self::reduce_64)
    }
}
impl F {
    /// Multiplicative generator of F* (only used by tests to re-derive the
    /// root-of-unity constants below; tests pin it to 7^133).
    #[cfg(test)]
    pub const GENERATOR: Self = Self(0xd64f951101aff9bf);
    /// Root of unity of order 2^32
    /// (tests confirm it equals GENERATOR^((P-1) >> 32)).
    pub const ROOT_OF_UNITY: Self = Self(0xee41f5320c4ea145);
    /// GENERATOR^(2^32): an element outside the order-2^32 subgroup, used as
    /// a coset shift for NTT evaluation domains.
    pub const NOT_ROOT_OF_UNITY: Self = Self(0x79bc2f50acd74161);
    /// Multiplicative inverse of `NOT_ROOT_OF_UNITY`
    /// (tests confirm the product is one).
    pub const NOT_ROOT_OF_UNITY_INV: Self = Self(0x1036c4023580ce8d);
    /// Additive identity.
    const ZERO: Self = Self(0);
    /// Multiplicative identity.
    const ONE: Self = Self(1);

    /// Addition mod P. Inputs must be canonical (< P); output is canonical.
    const fn add_inner(self, b: Self) -> Self {
        let (addition, overflow) = self.0.overflowing_add(b.0);
        let (subtraction, underflow) = addition.overflowing_sub(P);
        // The true sum is < 2P, so at most one subtraction of P is needed.
        // Reduce when either the 64-bit add wrapped (true sum >= 2^64 > P) or
        // the unwrapped sum is already >= P (trial subtraction did not
        // underflow).
        if overflow || !underflow {
            Self(subtraction)
        } else {
            Self(addition)
        }
    }

    /// Subtraction mod P. Inputs must be canonical; output is canonical.
    const fn sub_inner(self, b: Self) -> Self {
        let (subtraction, underflow) = self.0.overflowing_sub(b.0);
        if underflow {
            // Went below zero: adding P back lands the result in [0, P).
            Self(subtraction.wrapping_add(P))
        } else {
            Self(subtraction)
        }
    }

    /// Maps an arbitrary u64 into canonical form. Any u64 is < 2P, so a
    /// single conditional subtraction of P suffices.
    const fn reduce_64(x: u64) -> Self {
        let (subtraction, underflow) = x.overflowing_sub(P);
        if underflow {
            Self(x)
        } else {
            Self(subtraction)
        }
    }

    /// Reduces a 128-bit value (e.g. a product of two canonical elements)
    /// modulo P.
    ///
    /// Decompose x = a + b*2^64 + c*2^96 with `a` the low 64 bits, `b` the
    /// next 32, and `c` the top 32. Since 2^64 == 2^32 - 1 (mod P) and
    /// 2^96 == -1 (mod P), x == a - c + b*(2^32 - 1) (mod P).
    const fn reduce_128(x: u128) -> Self {
        let a = x as u64;
        let b = ((x >> 64) & 0xFF_FF_FF_FF) as u64;
        let c = (x >> 96) as u64;
        // (b << 32) - b computes b*(2^32 - 1) without overflow (b < 2^32).
        // The intermediate after sub_inner may be non-canonical (>= P), but
        // the running total stays < 2P, so add_inner's single conditional
        // reduction still yields a canonical result.
        Self(a).sub_inner(Self(c)).add_inner(Self((b << 32) - b))
    }

    /// Multiplication mod P via a full 128-bit product plus Goldilocks
    /// folding.
    const fn mul_inner(self, b: Self) -> Self {
        Self::reduce_128((self.0 as u128) * (b.0 as u128))
    }

    /// Additive inverse: 0 - self.
    const fn neg_inner(self) -> Self {
        Self::ZERO.sub_inner(self)
    }

    /// Multiplicative inverse via Fermat's little theorem (self^(P-2)).
    /// Note: maps zero to zero. `exp` is provided elsewhere in the crate and
    /// takes the exponent as little-endian u64 limbs.
    pub fn inv(self) -> Self {
        self.exp(&[P - 2])
    }

    /// Re-chunks a stream of raw 64-bit words into field elements carrying
    /// 63 bits of payload each, so every emitted element is < 2^63 < P and
    /// therefore canonical. The final element may be zero-padded.
    pub fn stream_from_u64s(inner: impl Iterator<Item = u64>) -> impl Iterator<Item = Self> {
        // Bit buffer: `acc` holds up to 126 pending bits (oldest in the low
        // bits), `acc_bits` counts how many are valid.
        struct Iter<I> {
            acc: u128,
            acc_bits: u32,
            inner: I,
        }
        impl<I: Iterator<Item = u64>> Iterator for Iter<I> {
            type Item = F;
            fn next(&mut self) -> Option<Self::Item> {
                // Refill until a full 63-bit chunk is buffered or the input
                // runs dry.
                while self.acc_bits < 63 {
                    let Some(x) = self.inner.next() else {
                        break;
                    };
                    let x = u128::from(x);
                    self.acc |= x << self.acc_bits;
                    self.acc_bits += 64;
                }
                if self.acc_bits > 0 {
                    // saturating_sub covers a short (zero-padded) final chunk.
                    self.acc_bits = self.acc_bits.saturating_sub(63);
                    let out = F((self.acc as u64) & ((1 << 63) - 1));
                    self.acc >>= 63;
                    return Some(out);
                }
                None
            }
        }
        Iter {
            acc: 0,
            acc_bits: 0,
            inner,
        }
    }

    /// Inverse of [`Self::stream_from_u64s`]: takes the low 63 bits of each
    /// element and repacks them into full 64-bit words. The round trip may
    /// append one extra padding word, which callers truncate.
    pub fn stream_to_u64s(inner: impl Iterator<Item = Self>) -> impl Iterator<Item = u64> {
        // Same bit-buffer scheme as above, but draining 64 bits per output.
        struct Iter<I> {
            acc: u128,
            acc_bits: u32,
            inner: I,
        }
        impl<I: Iterator<Item = F>> Iterator for Iter<I> {
            type Item = u64;
            fn next(&mut self) -> Option<Self::Item> {
                // Refill until a full 64-bit word is buffered; each element
                // contributes its low 63 bits.
                while self.acc_bits < 64 {
                    let Some(F(x)) = self.inner.next() else {
                        break;
                    };
                    let x = u128::from(x & ((1 << 63) - 1));
                    self.acc |= x << self.acc_bits;
                    self.acc_bits += 63;
                }
                if self.acc_bits > 0 {
                    // saturating_sub covers a short (zero-padded) final word.
                    self.acc_bits = self.acc_bits.saturating_sub(64);
                    let out = self.acc as u64;
                    self.acc >>= 64;
                    return Some(out);
                }
                None
            }
        }
        Iter {
            acc: 0,
            acc_bits: 0,
            inner,
        }
    }

    /// Number of field elements needed to carry `bits` payload bits at 63
    /// usable bits per element.
    pub const fn bits_to_elements(bits: usize) -> usize {
        bits.div_ceil(63)
    }

    /// Little-endian bytes of the canonical representation.
    pub const fn to_le_bytes(&self) -> [u8; 8] {
        self.0.to_le_bytes()
    }
}
// `Object` marker impl (see `crate::algebra`); no additional items required.
impl Object for F {}
// Uniform sampling over [0, P) by rejection: draw 64-bit words and discard
// any that fall outside the field. The rejection probability per draw is
// (2^64 - P) / 2^64, i.e. roughly 2^-32.
impl Random for F {
    fn random(mut rng: impl CryptoRngCore) -> Self {
        let mut candidate = rng.next_u64();
        while candidate >= P {
            candidate = rng.next_u64();
        }
        Self(candidate)
    }
}
impl Add for F {
type Output = Self;
fn add(self, b: Self) -> Self::Output {
self.add_inner(b)
}
}
impl<'a> Add<&'a Self> for F {
type Output = Self;
fn add(self, rhs: &'a Self) -> Self::Output {
self + *rhs
}
}
impl<'a> AddAssign<&'a Self> for F {
fn add_assign(&mut self, rhs: &'a Self) {
*self = *self + rhs
}
}
impl<'a> Sub<&'a Self> for F {
type Output = Self;
fn sub(self, rhs: &'a Self) -> Self::Output {
self - *rhs
}
}
impl<'a> SubAssign<&'a Self> for F {
fn sub_assign(&mut self, rhs: &'a Self) {
*self = *self - rhs;
}
}
impl Additive for F {
fn zero() -> Self {
Self::ZERO
}
}
impl Sub for F {
type Output = Self;
fn sub(self, b: Self) -> Self::Output {
self.sub_inner(b)
}
}
impl Mul for F {
type Output = Self;
fn mul(self, b: Self) -> Self::Output {
Self::mul_inner(self, b)
}
}
impl<'a> Mul<&'a Self> for F {
type Output = Self;
fn mul(self, rhs: &'a Self) -> Self::Output {
self * *rhs
}
}
impl<'a> MulAssign<&'a Self> for F {
fn mul_assign(&mut self, rhs: &'a Self) {
*self = *self * rhs;
}
}
impl Multiplicative for F {}
impl Neg for F {
type Output = Self;
fn neg(self) -> Self::Output {
self.neg_inner()
}
}
impl From<u64> for F {
fn from(value: u64) -> Self {
Self::reduce_64(value)
}
}
impl Ring for F {
fn one() -> Self {
Self::ONE
}
}
impl Field for F {
fn inv(&self) -> Self {
Self::inv(*self)
}
}
// NTT support: the multiplicative group contains a subgroup of order 2^32,
// enabling radix-2 NTTs up to that size.
impl FieldNTT for F {
    const MAX_LG_ROOT_ORDER: u8 = 32;

    // Derive a root of order 2^lg by repeatedly squaring the maximal-order
    // root; each squaring halves the order.
    fn root_of_unity(lg: u8) -> Option<Self> {
        if lg > Self::MAX_LG_ROOT_ORDER {
            return None;
        }
        let squarings = Self::MAX_LG_ROOT_ORDER - lg;
        Some((0..squarings).fold(Self::ROOT_OF_UNITY, |w, _| w * w))
    }

    // Shift an evaluation domain onto a disjoint coset via an element
    // outside the order-2^32 subgroup.
    fn coset_shift() -> Self {
        Self::NOT_ROOT_OF_UNITY
    }

    fn coset_shift_inv() -> Self {
        Self::NOT_ROOT_OF_UNITY_INV
    }

    // Halve modulo P: even residues shift right directly; odd residues add P
    // first (yielding an even 65-bit value congruent to self) and halve,
    // with the carry bit supplying bit 63 of the result.
    fn div_2(&self) -> Self {
        let x = self.0;
        if x & 1 == 1 {
            let (sum, carry) = x.overflowing_add(P);
            Self((u64::from(carry) << 63) | (sum >> 1))
        } else {
            Self(x >> 1)
        }
    }
}
#[cfg(any(test, feature = "fuzz"))]
pub mod fuzz {
    use super::*;
    use crate::algebra::test_suites;
    use arbitrary::{Arbitrary, Unstructured};
    use commonware_codec::{Encode as _, ReadExt as _};

    // A u64 drawn from [P, u64::MAX]: always outside the canonical range, so
    // decoding its encoding must fail.
    #[derive(Debug)]
    pub struct NonCanonicalU64(pub u64);
    impl Arbitrary<'_> for NonCanonicalU64 {
        fn arbitrary(u: &mut Unstructured<'_>) -> arbitrary::Result<Self> {
            Ok(Self(u.int_in_range(P..=u64::MAX)?))
        }
    }

    // One fuzzer-selected scenario to exercise per run.
    #[derive(Debug, Arbitrary)]
    pub enum Plan {
        // Round-trip raw words through the 63-bit element packing.
        StreamRoundtrip(Vec<u64>),
        // Non-canonical encodings must be rejected by `Read`.
        ReadRejectsOutOfRange(NonCanonicalU64),
        // Run the shared field/NTT algebra test suite against F.
        FuzzField,
    }
    impl Plan {
        pub fn run(self, u: &mut Unstructured<'_>) -> arbitrary::Result<()> {
            match self {
                Self::StreamRoundtrip(data) => {
                    let mut roundtrip =
                        F::stream_to_u64s(F::stream_from_u64s(data.clone().into_iter()))
                            .collect::<Vec<_>>();
                    // The packing may emit a trailing padding word; drop any
                    // excess before comparing against the input.
                    roundtrip.truncate(data.len());
                    assert_eq!(data, roundtrip);
                }
                Self::ReadRejectsOutOfRange(NonCanonicalU64(x)) => {
                    // Encode the raw (out-of-range) u64 and confirm decoding
                    // fails with the exact error `read_cfg` produces.
                    let result = F::read(&mut x.encode());
                    assert!(matches!(
                        result,
                        Err(commonware_codec::Error::Invalid("F", "out of range"))
                    ));
                }
                Self::FuzzField => {
                    test_suites::fuzz_field_ntt::<F>(u)?;
                }
            }
            Ok(())
        }
    }

    #[test]
    fn test_fuzz() {
        commonware_invariants::minifuzz::test(|u| u.arbitrary::<Plan>()?.run(u));
    }

    // Regression: x == P exactly (the boundary of the canonical range) must
    // be rejected by decoding.
    #[test]
    fn test_read_cfg_rejects_modulus_regression_case() {
        let mut u = Unstructured::new(&[]);
        Plan::ReadRejectsOutOfRange(NonCanonicalU64(P))
            .run(&mut u)
            .expect("regression plan should succeed");
    }
}
#[cfg(test)]
mod test {
    use super::*;

    // Guards the hard-coded GENERATOR constant by re-deriving it from the
    // expression it was computed from (7^133).
    #[test]
    fn test_generator_calculation() {
        assert_eq!(F::GENERATOR, F(7).exp(&[133]));
    }

    // ROOT_OF_UNITY must equal GENERATOR^((P-1)/2^32), an element of order
    // dividing 2^32.
    #[test]
    fn test_root_of_unity_calculation() {
        assert_eq!(F::ROOT_OF_UNITY, F::GENERATOR.exp(&[(P - 1) >> 32]));
    }

    // NOT_ROOT_OF_UNITY must equal GENERATOR^(2^32), placing it outside the
    // order-2^32 subgroup.
    #[test]
    fn test_not_root_of_unity_calculation() {
        assert_eq!(F::NOT_ROOT_OF_UNITY, F::GENERATOR.exp(&[1 << 32]));
    }

    // The hard-coded inverse constant must actually invert the coset shift.
    #[test]
    fn test_not_root_of_unity_inv_calculation() {
        assert_eq!(F::NOT_ROOT_OF_UNITY * F::NOT_ROOT_OF_UNITY_INV, F::one());
    }

    // Pins a known power: ROOT_OF_UNITY^(2^26) (a 64th root of unity if the
    // root has full order 2^32) must equal 8.
    #[test]
    fn test_root_of_unity_exp() {
        assert_eq!(F::ROOT_OF_UNITY.exp(&[1 << 26]), F(8));
    }

    // Codec conformance suite for F's fixed-size encoding.
    #[cfg(feature = "arbitrary")]
    mod conformance {
        use super::*;
        use commonware_codec::conformance::CodecConformance;
        commonware_conformance::conformance_tests! {
            CodecConformance<F>
        }
    }
}