use alloc::{fmt, vec::Vec};
use core::{
cmp::{Ordering, PartialEq},
ops,
};
use bytemuck::{CheckedBitPattern, NoUninit, Zeroable};
use crate::field::{self, Elem as FieldElem};
/// Marker type selecting the BabyBear base field ([`Elem`]) and its degree-4 extension
/// ([`ExtElem`]).
pub struct BabyBear;

impl field::Field for BabyBear {
    type ExtElem = ExtElem;
    type Elem = Elem;
}
/// `P^-1 mod 2^32` (so `M * P ≡ 1 (mod 2^32)`), used by the Montgomery reduction in `mul`.
const M: u32 = 0x8800_0001;
/// `R^2 mod P` with `R = 2^32`; multiplying by it in Montgomery form converts into Montgomery form.
const R2: u32 = 1_172_168_163;
/// An element of the BabyBear prime field, held internally in Montgomery form.
#[derive(Clone, Copy, Eq, NoUninit, Zeroable)]
#[repr(transparent)]
pub struct Elem(u32);

/// Descriptive alias for the BabyBear base-field element type.
pub type BabyBearElem = Elem;
impl Default for Elem {
fn default() -> Self {
Self::ZERO
}
}
impl fmt::Debug for Elem {
    /// Prints the canonical (decoded, non-Montgomery) value as zero-padded hexadecimal.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let canonical = decode(self.0);
        write!(f, "0x{:08x}", canonical)
    }
}
/// The BabyBear prime, `15 * 2^27 + 1`.
pub const P: u32 = (15 << 27) + 1;
/// [`P`] widened to `u64` for reductions of 64-bit intermediates.
const P_U64: u64 = P as u64;
/// Number of `u32` words used to serialize one base-field element.
const WORDS: usize = 1;
impl field::Elem for Elem {
const INVALID: Self = Elem(0xffffffff);
const ZERO: Self = Elem::new(0);
const ONE: Self = Elem::new(1);
const WORDS: usize = WORDS;
fn inv(self) -> Self {
self.ensure_valid().pow((P - 2) as usize)
}
fn random(rng: &mut impl rand_core::RngCore) -> Self {
let mut val: u64 = 0;
for _ in 0..6 {
val <<= 32;
val += rng.next_u32() as u64;
val %= P as u64;
}
Elem::from(val as u32)
}
fn from_u64(val: u64) -> Self {
Elem::from(val)
}
fn to_u32_words(&self) -> Vec<u32> {
Vec::<u32>::from([self.0])
}
fn from_u32_words(val: &[u32]) -> Self {
Self(val[0])
}
fn is_valid(&self) -> bool {
self.0 != Self::INVALID.0
}
fn is_reduced(&self) -> bool {
self.0 < P
}
}
// SAFETY: `Elem` is `#[repr(transparent)]` over `u32`, so `Bits = u32` has the identical layout.
// The check additionally accepts only values below `P`, which also excludes the `INVALID`
// sentinel `0xffff_ffff`.
unsafe impl CheckedBitPattern for Elem {
    type Bits = u32;

    fn is_valid_bit_pattern(bits: &u32) -> bool {
        P > *bits
    }
}
/// Expands a list of canonical integer literals into an array of [`Elem`]s via `Elem::new`.
macro_rules! rou_array {
    ($($v:literal),* $(,)?) => {
        [$(Elem::new($v),)*]
    };
}
/// Precomputed power-of-two roots of unity for the BabyBear field.
impl field::RootsOfUnity for Elem {
    /// Largest `k` with a tabulated `2^k`-th root of unity. Because `P - 1 = 15 * 2^27`, 27 is the
    /// full two-adicity of the field.
    const MAX_ROU_PO2: usize = 27;
    /// `ROU_FWD[k]` is presumably a primitive `2^k`-th root of unity. Entry 1 is `P - 1` (i.e. -1),
    /// and each entry appears to square to the previous one (e.g. `137^2 = 18769`). TODO confirm
    /// for all entries.
    const ROU_FWD: &'static [Elem] = &rou_array![
        1, 2013265920, 284861408, 1801542727, 567209306, 740045640, 918899846, 1881002012,
        1453957774, 65325759, 1538055801, 515192888, 483885487, 157393079, 1695124103, 2005211659,
        1540072241, 88064245, 1542985445, 1269900459, 1461624142, 825701067, 682402162, 1311873874,
        1164520853, 352275361, 18769, 137
    ];
    /// `ROU_REV[k]` appears to be the inverse of `ROU_FWD[k]`. Entries 0-2 check out: entry 2 equals
    /// `P - ROU_FWD[2]`, the inverse of a 4th root of unity. TODO confirm for higher `k`.
    const ROU_REV: &'static [Elem] = &rou_array![
        1, 2013265920, 1728404513, 1592366214, 196396260, 1253260071, 72041623, 1091445674,
        145223211, 1446820157, 1030796471, 2010749425, 1827366325, 1239938613, 246299276,
        596347512, 1893145354, 246074437, 1525739923, 1194341128, 1463599021, 704606912, 95395244,
        15672543, 647517488, 584175179, 137728885, 749463956
    ];
}
impl Elem {
pub const fn new(x: u32) -> Self {
Self(encode(x % P))
}
pub const fn new_raw(x: u32) -> Self {
Self(x)
}
pub const fn as_u32(&self) -> u32 {
decode(self.0)
}
pub const fn as_u32_montgomery(&self) -> u32 {
self.0
}
}
impl ops::Add for Elem {
    type Output = Self;

    /// Modular addition; both operands are validated first.
    fn add(self, rhs: Self) -> Self {
        let lhs_raw = self.ensure_valid().0;
        let rhs_raw = rhs.ensure_valid().0;
        Self(add(lhs_raw, rhs_raw))
    }
}
impl ops::AddAssign for Elem {
    /// In-place modular addition; both operands are validated first.
    fn add_assign(&mut self, rhs: Self) {
        let lhs = self.ensure_valid().0;
        let rhs = rhs.ensure_valid().0;
        self.0 = add(lhs, rhs);
    }
}
impl ops::Sub for Elem {
    type Output = Self;

    /// Modular subtraction; both operands are validated first.
    fn sub(self, rhs: Self) -> Self {
        let lhs_raw = self.ensure_valid().0;
        let rhs_raw = rhs.ensure_valid().0;
        Self(sub(lhs_raw, rhs_raw))
    }
}
impl ops::SubAssign for Elem {
    /// In-place modular subtraction; both operands are validated first.
    fn sub_assign(&mut self, rhs: Self) {
        let lhs = self.ensure_valid().0;
        let rhs = rhs.ensure_valid().0;
        self.0 = sub(lhs, rhs);
    }
}
impl ops::Mul for Elem {
    type Output = Self;

    /// Montgomery multiplication; both operands are validated first.
    fn mul(self, rhs: Self) -> Self {
        let lhs_raw = self.ensure_valid().0;
        let rhs_raw = rhs.ensure_valid().0;
        Self(mul(lhs_raw, rhs_raw))
    }
}
impl ops::MulAssign for Elem {
    /// In-place Montgomery multiplication; both operands are validated first.
    fn mul_assign(&mut self, rhs: Self) {
        let lhs = self.ensure_valid().0;
        let rhs = rhs.ensure_valid().0;
        self.0 = mul(lhs, rhs);
    }
}
impl ops::Neg for Elem {
    type Output = Self;

    /// Additive inverse, computed as `0 - self` (raw zero is also zero in Montgomery form).
    fn neg(self) -> Self {
        let value = *self.ensure_valid();
        Self::new_raw(0) - value
    }
}
impl PartialEq<Elem> for Elem {
    /// Compares internal representations after validating both sides. Montgomery encoding is a
    /// bijection, so this is value equality.
    fn eq(&self, rhs: &Self) -> bool {
        let lhs = self.ensure_valid();
        let rhs = rhs.ensure_valid();
        lhs.0 == rhs.0
    }
}
impl Ord for Elem {
    /// Orders elements by their decoded canonical value, not by the Montgomery representation.
    fn cmp(&self, rhs: &Self) -> Ordering {
        let lhs_value = decode(self.ensure_valid().0);
        let rhs_value = decode(rhs.ensure_valid().0);
        Ord::cmp(&lhs_value, &rhs_value)
    }
}
impl PartialOrd for Elem {
    /// Delegates to the total order from [`Ord`].
    fn partial_cmp(&self, rhs: &Self) -> Option<Ordering> {
        let ordering = Ord::cmp(self, rhs);
        Some(ordering)
    }
}
impl From<Elem> for u32 {
fn from(x: Elem) -> Self {
decode(x.0)
}
}
impl From<Elem> for u64 {
fn from(x: Elem) -> Self {
decode(x.0).into()
}
}
impl From<u32> for Elem {
fn from(x: u32) -> Self {
Elem::new(x)
}
}
impl From<u64> for Elem {
    /// Reduces the 64-bit input mod `P` before encoding, so no high bits are lost.
    fn from(x: u64) -> Self {
        let reduced = x % P_U64;
        Self::new(reduced as u32)
    }
}
fn add(lhs: u32, rhs: u32) -> u32 {
let x = lhs.wrapping_add(rhs);
if x >= P {
x - P
} else {
x
}
}
/// Subtracts two values in `[0, P)` mod `P`.
///
/// When `lhs < rhs`, the wrapped difference lands above `P`. Adding `P` back, with wrap-around,
/// yields `P - (rhs - lhs)`.
fn sub(lhs: u32, rhs: u32) -> u32 {
    let diff = lhs.wrapping_sub(rhs);
    if diff <= P {
        diff
    } else {
        diff.wrapping_add(P)
    }
}
/// Montgomery product: returns `lhs * rhs * 2^-32 mod P` for Montgomery-form inputs.
const fn mul(lhs: u32, rhs: u32) -> u32 {
    // Full 64-bit product of the operands.
    let wide: u64 = (lhs as u64).wrapping_mul(rhs as u64);
    // Since `M * P ≡ 1 (mod 2^32)`, choosing `q = M * (-wide mod 2^32)` makes
    // `wide + q * P` divisible by 2^32.
    let q: u32 = M.wrapping_mul((wide as u32).wrapping_neg());
    let reduced = ((wide + (q as u64).wrapping_mul(P_U64)) >> 32) as u32;
    // A single conditional subtraction brings the result into [0, P).
    if reduced < P {
        reduced
    } else {
        reduced - P
    }
}
/// Converts a canonical value into Montgomery form (`a * 2^32 mod P`) via a product with `R2`.
const fn encode(a: u32) -> u32 {
    mul(a, R2)
}
/// Converts a Montgomery-form value back to canonical form (a Montgomery product with 1).
const fn decode(a: u32) -> u32 {
    mul(a, 1)
}
/// Degree of the extension field over the base field.
const EXT_SIZE: usize = 4;

/// An element of the degree-4 extension of BabyBear. It is stored as `EXT_SIZE` base-field
/// coefficients, where index `i` multiplies `X^i`.
#[derive(Clone, Copy, Eq, Zeroable)]
#[repr(transparent)]
pub struct ExtElem([Elem; EXT_SIZE]);

// SAFETY: `ExtElem` is `#[repr(transparent)]` over `[Elem; EXT_SIZE]`. `Elem` derives `NoUninit`,
// and arrays have no padding between elements, so `ExtElem` has no uninitialized bytes.
unsafe impl NoUninit for ExtElem {}
// SAFETY: `ExtElem` is `#[repr(transparent)]` over `[Elem; EXT_SIZE]`, and `Elem` is
// `#[repr(transparent)]` over `u32`, so `[u32; EXT_SIZE]` has the identical layout. A pattern is
// accepted only if every word is below `P`.
unsafe impl CheckedBitPattern for ExtElem {
    type Bits = [u32; EXT_SIZE];

    fn is_valid_bit_pattern(bits: &[u32; EXT_SIZE]) -> bool {
        bits.iter().all(|&word| word < P)
    }
}
pub type BabyBearExtElem = ExtElem;
impl Default for ExtElem {
fn default() -> Self {
Self::ZERO
}
}
impl fmt::Debug for ExtElem {
    /// Prints the four coefficients, lowest degree first, using each coefficient's `Debug` form.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let [c0, c1, c2, c3] = &self.0;
        write!(f, "[{:?}, {:?}, {:?}, {:?}]", c0, c1, c2, c3)
    }
}
impl field::Elem for ExtElem {
const INVALID: Self = ExtElem([Elem::INVALID, Elem::INVALID, Elem::INVALID, Elem::INVALID]);
const ZERO: Self = ExtElem::zero();
const ONE: Self = ExtElem::one();
const WORDS: usize = WORDS * EXT_SIZE;
fn random(rng: &mut impl rand_core::RngCore) -> Self {
Self([
Elem::random(rng),
Elem::random(rng),
Elem::random(rng),
Elem::random(rng),
])
}
fn pow(self, n: usize) -> Self {
let mut n = n;
let mut tot = ExtElem::ONE;
let mut x = *self.ensure_valid();
while n != 0 {
if n % 2 == 1 {
tot *= x;
}
n /= 2;
x *= x;
}
tot
}
fn inv(self) -> Self {
let a = &self.ensure_valid().0;
let mut b0 = a[0] * a[0] + BETA * (a[1] * (a[3] + a[3]) - a[2] * a[2]);
let mut b2 = a[0] * (a[2] + a[2]) - a[1] * a[1] + BETA * (a[3] * a[3]);
let c = b0 * b0 + BETA * b2 * b2;
let ic = c.inv();
b0 *= ic;
b2 *= ic;
ExtElem([
a[0] * b0 + BETA * a[2] * b2,
-a[1] * b0 + NBETA * a[3] * b2,
-a[0] * b2 + a[2] * b0,
a[1] * b2 - a[3] * b0,
])
}
fn from_u64(val: u64) -> Self {
Self([Elem::from_u64(val), Elem::ZERO, Elem::ZERO, Elem::ZERO])
}
fn to_u32_words(&self) -> Vec<u32> {
self.elems()
.iter()
.flat_map(|elem| elem.to_u32_words())
.collect()
}
fn from_u32_words(val: &[u32]) -> Self {
field::ExtElem::from_subelems(val.iter().map(|word| Elem(*word)))
}
fn is_valid(&self) -> bool {
self.0[0].is_valid()
}
fn is_reduced(&self) -> bool {
self.0.iter().all(|x| x.is_reduced())
}
}
impl field::ExtElem for ExtElem {
    const EXT_SIZE: usize = EXT_SIZE;

    type SubElem = Elem;

    /// Embeds a base-field element as the constant coefficient (`elem + 0*X + 0*X^2 + 0*X^3`).
    fn from_subfield(elem: &Elem) -> Self {
        Self::from([*elem.ensure_valid(), Elem::ZERO, Elem::ZERO, Elem::ZERO])
    }

    /// Builds an element from exactly `EXT_SIZE` base-field coefficients, lowest degree first.
    ///
    /// # Panics
    ///
    /// Panics if `elems` yields fewer or more than `EXT_SIZE` items, or if any coefficient fails
    /// `ensure_valid`.
    fn from_subelems(elems: impl IntoIterator<Item = Self::SubElem>) -> Self {
        let mut iter = elems.into_iter();
        let mut coeffs = [Elem::ZERO; EXT_SIZE];
        for coeff in coeffs.iter_mut() {
            // `expect` instead of a bare `unwrap` so a short input reports what went wrong.
            *coeff = *iter
                .next()
                .expect("Too few elements passed to create element in extension field")
                .ensure_valid();
        }
        let elem = Self::from(coeffs);
        assert!(
            iter.next().is_none(),
            "Extra elements passed to create element in extension field"
        );
        elem
    }

    /// Returns the coefficients, lowest degree first, after validating the element.
    fn subelems(&self) -> &[Elem] {
        &self.ensure_valid().0
    }
}
impl PartialEq<ExtElem> for ExtElem {
    /// Coefficient-wise equality. Each coefficient comparison also validates that coefficient.
    fn eq(&self, rhs: &Self) -> bool {
        let lhs = self.ensure_valid();
        let rhs = rhs.ensure_valid();
        lhs.0.iter().zip(rhs.0.iter()).all(|(l, r)| l == r)
    }
}
impl From<[Elem; EXT_SIZE]> for ExtElem {
fn from(val: [Elem; EXT_SIZE]) -> Self {
if cfg!(debug_assertions) {
for elem in val.iter() {
elem.ensure_valid();
}
}
ExtElem(val)
}
}
/// Extension-field constant. `ExtElem` multiplication reduces with `X^4 = -BETA`, i.e. it works
/// modulo `X^4 + 11`.
const BETA: Elem = Elem::new(11);
/// `-BETA` in the base field (`P - 11`). It is the factor applied to product terms of degree 4
/// and above during reduction.
const NBETA: Elem = Elem::new(P - 11);
/// `const`-compatible validity check used by `ExtElem::new`.
///
/// In builds with debug assertions, panics if `x` holds the `INVALID` sentinel. Otherwise it
/// returns `x` unchanged.
const fn const_ensure_valid(x: Elem) -> Elem {
    debug_assert!(x.0 != Elem::INVALID.0);
    x
}
impl ExtElem {
    /// Builds an element from its four coefficients, lowest degree first. In debug builds each
    /// coefficient is checked against the `INVALID` sentinel.
    pub const fn new(x0: Elem, x1: Elem, x2: Elem, x3: Elem) -> Self {
        let coeffs = [
            const_ensure_valid(x0),
            const_ensure_valid(x1),
            const_ensure_valid(x2),
            const_ensure_valid(x3),
        ];
        Self(coeffs)
    }

    /// Embeds a base-field element as the constant coefficient, without validating it.
    pub fn from_fp(x: Elem) -> Self {
        let zero = Elem::new(0);
        Self([x, zero, zero, zero])
    }

    /// Embeds a canonical `u32` (reduced mod `P`) as the constant coefficient.
    pub const fn from_u32(x0: u32) -> Self {
        let zero = Elem::new(0);
        Self([Elem::new(x0), zero, zero, zero])
    }

    const fn zero() -> Self {
        Self::from_u32(0)
    }

    const fn one() -> Self {
        Self::from_u32(1)
    }

    /// Returns the constant (degree-0) coefficient.
    pub fn const_part(self) -> Elem {
        self.elems()[0]
    }

    /// Returns all coefficients, lowest degree first, after validating the element.
    pub fn elems(&self) -> &[Elem] {
        &self.ensure_valid().0[..]
    }
}
impl ops::Add for ExtElem {
    type Output = Self;

    /// Coefficient-wise addition, delegating to `AddAssign`.
    fn add(self, rhs: Self) -> Self {
        let mut sum = self;
        ops::AddAssign::add_assign(&mut sum, rhs);
        sum
    }
}
impl ops::AddAssign for ExtElem {
    /// Adds `rhs` coefficient by coefficient.
    fn add_assign(&mut self, rhs: Self) {
        for (dst, src) in self.0.iter_mut().zip(rhs.0.iter()) {
            *dst += *src;
        }
    }
}
impl ops::Add<Elem> for ExtElem {
    type Output = Self;

    /// Adds a base-field element to the constant coefficient only.
    fn add(self, rhs: Elem) -> Self {
        let [c0, c1, c2, c3] = self.0;
        Self([c0 + rhs, c1, c2, c3])
    }
}
impl ops::Add<ExtElem> for Elem {
    type Output = ExtElem;

    /// Lifts `self` into the extension field, then adds.
    fn add(self, rhs: ExtElem) -> ExtElem {
        ExtElem::from(self) + rhs
    }
}
impl ops::AddAssign<Elem> for ExtElem {
    /// Adds a base-field element to the constant coefficient in place.
    fn add_assign(&mut self, rhs: Elem) {
        let [constant, ..] = &mut self.0;
        *constant += rhs;
    }
}
impl ops::Sub for ExtElem {
    type Output = Self;

    /// Coefficient-wise subtraction, delegating to `SubAssign`.
    fn sub(self, rhs: Self) -> Self {
        let mut diff = self;
        ops::SubAssign::sub_assign(&mut diff, rhs);
        diff
    }
}
impl ops::SubAssign for ExtElem {
    /// Subtracts `rhs` coefficient by coefficient.
    fn sub_assign(&mut self, rhs: Self) {
        for (dst, src) in self.0.iter_mut().zip(rhs.0.iter()) {
            *dst -= *src;
        }
    }
}
impl ops::Sub<Elem> for ExtElem {
    type Output = Self;

    /// Subtracts a base-field element from the constant coefficient only.
    fn sub(self, rhs: Elem) -> Self {
        let [c0, c1, c2, c3] = self.0;
        Self([c0 - rhs, c1, c2, c3])
    }
}
impl ops::Sub<ExtElem> for Elem {
    type Output = ExtElem;

    /// Lifts `self` into the extension field, then subtracts.
    fn sub(self, rhs: ExtElem) -> ExtElem {
        ExtElem::from(self) - rhs
    }
}
impl ops::SubAssign<Elem> for ExtElem {
    /// Subtracts a base-field element from the constant coefficient in place.
    fn sub_assign(&mut self, rhs: Elem) {
        let [constant, ..] = &mut self.0;
        *constant -= rhs;
    }
}
impl ops::MulAssign<Elem> for ExtElem {
    /// Scales every coefficient by a base-field element.
    fn mul_assign(&mut self, rhs: Elem) {
        self.0.iter_mut().for_each(|coeff| *coeff *= rhs);
    }
}
impl ops::Mul<Elem> for ExtElem {
    type Output = Self;

    /// Scales every coefficient by a base-field element.
    fn mul(self, rhs: Elem) -> Self {
        Self(self.0.map(|coeff| coeff * rhs))
    }
}
impl ops::Mul<ExtElem> for Elem {
    type Output = ExtElem;

    /// Scalar multiplication is commutative, so this scales `rhs` by `self`.
    fn mul(self, rhs: ExtElem) -> ExtElem {
        let mut scaled = rhs;
        scaled *= self;
        scaled
    }
}
impl ops::MulAssign for ExtElem {
    /// Polynomial product reduced with `X^4 = NBETA` (i.e. modulo `X^4 + 11`).
    #[inline(always)]
    fn mul_assign(&mut self, rhs: Self) {
        let [a0, a1, a2, a3] = self.0;
        let [b0, b1, b2, b3] = rhs.0;
        // Terms whose degrees sum to 4 or more wrap around, multiplied by NBETA.
        self.0 = [
            a0 * b0 + NBETA * (a1 * b3 + a2 * b2 + a3 * b1),
            a0 * b1 + a1 * b0 + NBETA * (a2 * b3 + a3 * b2),
            a0 * b2 + a1 * b1 + a2 * b0 + NBETA * (a3 * b3),
            a0 * b3 + a1 * b2 + a2 * b1 + a3 * b0,
        ];
    }
}
impl ops::Mul for ExtElem {
    type Output = ExtElem;

    /// Extension-field product, delegating to `MulAssign`.
    #[inline(always)]
    fn mul(self, rhs: ExtElem) -> ExtElem {
        let mut product = self;
        ops::MulAssign::mul_assign(&mut product, rhs);
        product
    }
}
impl ops::Neg for ExtElem {
type Output = Self;
fn neg(self) -> Self {
ExtElem::ZERO - self
}
}
impl From<u32> for ExtElem {
    /// Reduces `x` into the base field and embeds it as the constant coefficient.
    fn from(x: u32) -> Self {
        Self::from(Elem::from(x))
    }
}
impl From<Elem> for ExtElem {
fn from(x: Elem) -> Self {
Self([x, Elem::ZERO, Elem::ZERO, Elem::ZERO])
}
}
#[cfg(test)]
mod tests {
    use alloc::{vec, vec::Vec};

    use rand::{Rng, SeedableRng};

    use super::{field, Elem, ExtElem, P, P_U64};
    use crate::field::Elem as FieldElem;

    /// Deterministic RNG shared by the randomized tests (every test uses seed 2).
    fn seeded_rng() -> rand::rngs::SmallRng {
        rand::rngs::SmallRng::seed_from_u64(2)
    }

    /// Builds an extension element from four canonical coefficients, lowest degree first.
    fn ext(coeffs: [u32; 4]) -> ExtElem {
        let [c0, c1, c2, c3] = coeffs;
        ExtElem::new(Elem::new(c0), Elem::new(c1), Elem::new(c2), Elem::new(c3))
    }

    #[test]
    pub fn roots_of_unity() {
        field::tests::test_roots_of_unity::<Elem>();
    }

    #[test]
    pub fn field_ops() {
        field::tests::test_field_ops::<Elem>(P_U64);
    }

    #[test]
    pub fn ext_field_ops() {
        field::tests::test_ext_field_ops::<ExtElem>();
    }

    /// Checks a product and a multiply-add against precomputed expected values.
    #[test]
    pub fn linear() {
        let x = ext([1880084280, 1788985953, 1273325207, 277471107]);
        let c0 = ext([1582815482, 2011839994, 589901, 698998108]);
        let c1 = ext([1262573828, 1903841444, 1738307519, 100967278]);

        let product = x * c1;
        assert_eq!(product, ext([876029217, 1948387849, 498773186, 1997003991]));
        assert_eq!(c0 + product, ext([445578778, 1946961922, 499363087, 682736178]));
    }

    /// Spot-checks the field axioms on random extension elements.
    #[test]
    fn isa_field() {
        let mut rng = seeded_rng();
        for _ in 0..1_000 {
            let a = ExtElem::random(&mut rng);
            let b = ExtElem::random(&mut rng);
            let c = ExtElem::random(&mut rng);
            // Commutativity.
            assert_eq!(a + b, b + a);
            assert_eq!(a * b, b * a);
            // Associativity.
            assert_eq!(a + (b + c), (a + b) + c);
            assert_eq!(a * (b * c), (a * b) * c);
            // Distributivity.
            assert_eq!(a * (b + c), a * b + a * c);
            // Inverses.
            if a != ExtElem::ZERO {
                assert_eq!(a.inv() * a, ExtElem::from(1));
            }
            assert_eq!(ExtElem::ZERO - a, -a);
            assert_eq!(a + (-a), ExtElem::ZERO);
        }
    }

    #[test]
    fn inv() {
        let five = Elem::new(5);
        assert_eq!(five.inv() * five, Elem::new(1));
    }

    #[test]
    fn pow() {
        let five = Elem::new(5);
        let cases: [(usize, u32); 4] = [(0, 1), (1, 5), (2, 25), (1000, 589699054)];
        for (exp, expected) in cases {
            assert_eq!(five.pow(exp), Elem::new(expected));
        }
        // Fermat: 5^(P-2) is the inverse of 5, and 5^(P-1) is one.
        assert_eq!(five.pow((P - 2) as usize) * five, Elem::new(1));
        assert_eq!(five.pow((P - 1) as usize), Elem::new(1));
    }

    /// Cross-checks field arithmetic against native `u64` arithmetic reduced mod `P`.
    #[test]
    fn compare_native() {
        let mut rng = seeded_rng();
        for _ in 0..100_000 {
            let lhs = Elem::random(&mut rng);
            let rhs = Elem::random(&mut rng);
            let a: u64 = lhs.into();
            let b: u64 = rhs.into();
            assert_eq!(lhs + rhs, Elem::from(a + b));
            assert_eq!(lhs - rhs, Elem::from(a + (P_U64 - b)));
            assert_eq!(lhs * rhs, Elem::from(a * b));
        }
    }

    #[test]
    #[cfg_attr(not(debug_assertions), ignore)]
    #[should_panic(expected = "assertion failed: self.is_valid")]
    fn compare_against_invalid() {
        let _ = Elem::ZERO == Elem::INVALID;
    }

    /// Round-trips elements and arbitrary words through the `u32`-word encoding.
    #[test]
    fn u32s_conversions() {
        let mut rng = seeded_rng();
        for _ in 0..100 {
            let elem = Elem::random(&mut rng);
            assert_eq!(elem, Elem::from_u32_words(&elem.to_u32_words()));
            let word: u32 = rng.random();
            assert_eq!(word, Elem::from_u32_words(&[word]).to_u32_words()[0]);
        }
        for _ in 0..100 {
            let elem = ExtElem::random(&mut rng);
            assert_eq!(elem, ExtElem::from_u32_words(&elem.to_u32_words()));
            let words: Vec<u32> = vec![rng.random(), rng.random(), rng.random(), rng.random()];
            assert_eq!(words, ExtElem::from_u32_words(&words).to_u32_words());
        }
    }
}