1use alloc::{fmt, vec::Vec};
22use core::{
23 cmp::{Ordering, PartialEq},
24 ops,
25};
26
27use bytemuck::{CheckedBitPattern, NoUninit, Zeroable};
28
29use crate::field::{self, Elem as FieldElem};
30
/// Marker type naming the Baby Bear field. Its `field::Field` implementation
/// selects [`Elem`] as the base element and [`ExtElem`] as the extension element.
pub struct BabyBear;
34
impl field::Field for BabyBear {
    /// Base-field element type.
    type Elem = Elem;
    /// Extension-field element type (`EXT_SIZE` base coefficients).
    type ExtElem = ExtElem;
}
39
/// Montgomery reduction constant: the inverse of `P` modulo 2^32
/// (`M * P` wraps to 1 in `u32` arithmetic). Used by `mul`.
const M: u32 = 0x88000001;
/// `R^2 mod P` for `R = 2^32`; `encode` multiplies by it to enter Montgomery form.
const R2: u32 = 1172168163;
43
/// An element of the Baby Bear base field.
///
/// The wrapped `u32` is the Montgomery-encoded representation produced by
/// `encode`; `Elem::new` performs the encoding and `as_u32` undoes it.
/// The raw pattern `0xffffffff` is reserved for `INVALID`.
#[derive(Eq, Clone, Copy, NoUninit, Zeroable)]
#[repr(transparent)]
pub struct Elem(u32);
67
/// Alias for the Baby Bear base-field element type [`Elem`].
pub type BabyBearElem = Elem;
70
71impl Default for Elem {
72 fn default() -> Self {
73 Self::ZERO
74 }
75}
76
77impl fmt::Debug for Elem {
78 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
79 write!(f, "0x{:08x}", decode(self.0))
80 }
81}
82
/// The Baby Bear prime modulus, `15 * 2^27 + 1`.
pub const P: u32 = 15 * (1 << 27) + 1;

/// `P` widened to `u64` for 64-bit intermediate arithmetic.
const P_U64: u64 = P as u64;

/// Number of `u32` words used to store one base-field element.
const WORDS: usize = 1;
91
92impl field::Elem for Elem {
93 const INVALID: Self = Elem(0xffffffff);
94 const ZERO: Self = Elem::new(0);
95 const ONE: Self = Elem::new(1);
96 const WORDS: usize = WORDS;
97
98 fn inv(self) -> Self {
106 self.ensure_valid().pow((P - 2) as usize)
107 }
108
109 fn random(rng: &mut impl rand_core::RngCore) -> Self {
111 let mut val: u64 = 0;
134 for _ in 0..6 {
135 val <<= 32;
136 val += rng.next_u32() as u64;
137 val %= P as u64;
138 }
139 Elem::from(val as u32)
140 }
141
142 fn from_u64(val: u64) -> Self {
143 Elem::from(val)
144 }
145
146 fn to_u32_words(&self) -> Vec<u32> {
147 Vec::<u32>::from([self.0])
148 }
149
150 fn from_u32_words(val: &[u32]) -> Self {
151 Self(val[0])
152 }
153
154 fn is_valid(&self) -> bool {
155 self.0 != Self::INVALID.0
156 }
157
158 fn is_reduced(&self) -> bool {
159 self.0 < P
160 }
161}
162
163unsafe impl CheckedBitPattern for Elem {
164 type Bits = u32;
165
166 fn is_valid_bit_pattern(bits: &u32) -> bool {
168 *bits < P
169 }
170}
171
// Expands a comma-separated list of integer literals (a trailing comma is
// allowed) into an array expression whose entries are `Elem::new(literal)`.
macro_rules! rou_array {
    [$($x:literal),* $(,)?] => {
        [$(Elem::new($x)),* ]
    }
}
177
impl field::RootsOfUnity for Elem {
    /// Largest power-of-two exponent covered by the tables below
    /// (`P - 1 = 15 * 2^27`); both tables have `MAX_ROU_PO2 + 1` entries.
    const MAX_ROU_PO2: usize = 27;

    /// Table of roots of unity, given as ordinary integers and encoded via
    /// `Elem::new`. Entry 0 is 1 and entry 1 is `P - 1`.
    // NOTE(review): presumably entry `i` is a primitive 2^i-th root of unity
    // (the last entries are 137^2 = 18769 and 137) — confirm against the
    // `field::RootsOfUnity` contract.
    const ROU_FWD: &'static [Elem] = &rou_array![
        1, 2013265920, 284861408, 1801542727, 567209306, 740045640, 918899846, 1881002012,
        1453957774, 65325759, 1538055801, 515192888, 483885487, 157393079, 1695124103, 2005211659,
        1540072241, 88064245, 1542985445, 1269900459, 1461624142, 825701067, 682402162, 1311873874,
        1164520853, 352275361, 18769, 137
    ];

    /// Companion table to `ROU_FWD`, same length and encoding.
    // NOTE(review): presumably entry `i` is the multiplicative inverse of
    // `ROU_FWD[i]` — confirm against the `field::RootsOfUnity` contract.
    const ROU_REV: &'static [Elem] = &rou_array![
        1, 2013265920, 1728404513, 1592366214, 196396260, 1253260071, 72041623, 1091445674,
        145223211, 1446820157, 1030796471, 2010749425, 1827366325, 1239938613, 246299276,
        596347512, 1893145354, 246074437, 1525739923, 1194341128, 1463599021, 704606912, 95395244,
        15672543, 647517488, 584175179, 137728885, 749463956
    ];
}
199
200impl Elem {
201 pub const fn new(x: u32) -> Self {
203 Self(encode(x % P))
204 }
205
206 pub const fn new_raw(x: u32) -> Self {
210 Self(x)
211 }
212
213 pub const fn as_u32(&self) -> u32 {
215 decode(self.0)
216 }
217
218 pub const fn as_u32_montgomery(&self) -> u32 {
221 self.0
222 }
223}
224
225impl ops::Add for Elem {
226 type Output = Self;
227
228 fn add(self, rhs: Self) -> Self {
230 Elem(add(self.ensure_valid().0, rhs.ensure_valid().0))
231 }
232}
233
234impl ops::AddAssign for Elem {
235 fn add_assign(&mut self, rhs: Self) {
237 self.0 = add(self.ensure_valid().0, rhs.ensure_valid().0)
238 }
239}
240
241impl ops::Sub for Elem {
242 type Output = Self;
243
244 fn sub(self, rhs: Self) -> Self {
246 Elem(sub(self.ensure_valid().0, rhs.ensure_valid().0))
247 }
248}
249
250impl ops::SubAssign for Elem {
251 fn sub_assign(&mut self, rhs: Self) {
253 self.0 = sub(self.ensure_valid().0, rhs.ensure_valid().0)
254 }
255}
256
257impl ops::Mul for Elem {
258 type Output = Self;
259
260 fn mul(self, rhs: Self) -> Self {
262 Elem(mul(self.ensure_valid().0, rhs.ensure_valid().0))
263 }
264}
265
266impl ops::MulAssign for Elem {
267 fn mul_assign(&mut self, rhs: Self) {
269 self.0 = mul(self.ensure_valid().0, rhs.ensure_valid().0)
270 }
271}
272
273impl ops::Neg for Elem {
274 type Output = Self;
275
276 fn neg(self) -> Self {
277 Elem(0) - *self.ensure_valid()
278 }
279}
280
281impl PartialEq<Elem> for Elem {
282 fn eq(&self, rhs: &Self) -> bool {
283 self.ensure_valid().0 == rhs.ensure_valid().0
284 }
285}
286
287impl Ord for Elem {
288 fn cmp(&self, rhs: &Self) -> Ordering {
289 decode(self.ensure_valid().0).cmp(&decode(rhs.ensure_valid().0))
290 }
291}
292
293impl PartialOrd for Elem {
294 fn partial_cmp(&self, rhs: &Self) -> Option<Ordering> {
295 Some(self.cmp(rhs))
296 }
297}
298
299impl From<Elem> for u32 {
300 fn from(x: Elem) -> Self {
301 decode(x.0)
302 }
303}
304
305impl From<Elem> for u64 {
306 fn from(x: Elem) -> Self {
307 decode(x.0).into()
308 }
309}
310
311impl From<u32> for Elem {
312 fn from(x: u32) -> Self {
313 Elem::new(x)
314 }
315}
316
317impl From<u64> for Elem {
318 fn from(x: u64) -> Self {
319 Elem::new((x % P_U64) as u32)
320 }
321}
322
323fn add(lhs: u32, rhs: u32) -> u32 {
325 let x = lhs.wrapping_add(rhs);
326 if x >= P {
327 x - P
328 } else {
329 x
330 }
331}
332
333fn sub(lhs: u32, rhs: u32) -> u32 {
335 let x = lhs.wrapping_sub(rhs);
336 if x > P {
337 x.wrapping_add(P)
338 } else {
339 x
340 }
341}
342
343const fn mul(lhs: u32, rhs: u32) -> u32 {
346 let mut o64: u64 = (lhs as u64).wrapping_mul(rhs as u64);
348 let low: u32 = 0u32.wrapping_sub(o64 as u32);
350 let red = M.wrapping_mul(low);
352 o64 += (red as u64).wrapping_mul(P_U64);
354 let ret = (o64 >> 32) as u32;
356 if ret >= P {
358 ret - P
359 } else {
360 ret
361 }
362}
363
364const fn encode(a: u32) -> u32 {
366 mul(R2, a)
367}
368
369const fn decode(a: u32) -> u32 {
371 mul(1, a)
372}
373
/// Number of base-field coefficients stored in one [`ExtElem`].
const EXT_SIZE: usize = 4;
376
/// An element of the extension field over Baby Bear, stored as `EXT_SIZE`
/// base-field coefficients `[a0, a1, a2, a3]` of `a0 + a1*x + a2*x^2 + a3*x^3`.
///
/// Products are reduced with `x^4 = NBETA` (that is, `x^4 + 11 = 0`), as
/// implemented in `MulAssign for ExtElem`.
#[derive(Eq, Clone, Copy, Zeroable)]
#[repr(transparent)]
pub struct ExtElem([Elem; EXT_SIZE]);
388
// SAFETY: `ExtElem` is `#[repr(transparent)]` over `[Elem; EXT_SIZE]` and is
// `Copy`; `Elem` is itself `#[repr(transparent)]` over `u32` and derives
// `NoUninit`, so the array has no padding or uninitialized bytes.
unsafe impl NoUninit for ExtElem {}
393
394unsafe impl CheckedBitPattern for ExtElem {
395 type Bits = [u32; EXT_SIZE];
396
397 fn is_valid_bit_pattern(bits: &[u32; EXT_SIZE]) -> bool {
399 let mut valid = true;
400 for x in bits {
401 valid &= *x < P;
402 }
403 valid
404 }
405}
406
/// Alias for the Baby Bear extension-field element type [`ExtElem`].
pub type BabyBearExtElem = ExtElem;
409
410impl Default for ExtElem {
411 fn default() -> Self {
412 Self::ZERO
413 }
414}
415
416impl fmt::Debug for ExtElem {
417 fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
418 write!(
419 f,
420 "[{:?}, {:?}, {:?}, {:?}]",
421 self.0[0], self.0[1], self.0[2], self.0[3]
422 )
423 }
424}
425
426impl field::Elem for ExtElem {
427 const INVALID: Self = ExtElem([Elem::INVALID, Elem::INVALID, Elem::INVALID, Elem::INVALID]);
428 const ZERO: Self = ExtElem::zero();
429 const ONE: Self = ExtElem::one();
430 const WORDS: usize = WORDS * EXT_SIZE;
431
432 fn random(rng: &mut impl rand_core::RngCore) -> Self {
434 Self([
437 Elem::random(rng),
438 Elem::random(rng),
439 Elem::random(rng),
440 Elem::random(rng),
441 ])
442 }
443
444 fn pow(self, n: usize) -> Self {
446 let mut n = n;
447 let mut tot = ExtElem::ONE;
448 let mut x = *self.ensure_valid();
449 while n != 0 {
450 if n % 2 == 1 {
451 tot *= x;
452 }
453 n /= 2;
454 x *= x;
455 }
456 tot
457 }
458
459 fn inv(self) -> Self {
461 let a = &self.ensure_valid().0;
462 let mut b0 = a[0] * a[0] + BETA * (a[1] * (a[3] + a[3]) - a[2] * a[2]);
471 let mut b2 = a[0] * (a[2] + a[2]) - a[1] * a[1] + BETA * (a[3] * a[3]);
472 let c = b0 * b0 + BETA * b2 * b2;
476 let ic = c.inv();
479 b0 *= ic;
486 b2 *= ic;
487 ExtElem([
488 a[0] * b0 + BETA * a[2] * b2,
489 -a[1] * b0 + NBETA * a[3] * b2,
490 -a[0] * b2 + a[2] * b0,
491 a[1] * b2 - a[3] * b0,
492 ])
493 }
494
495 fn from_u64(val: u64) -> Self {
497 Self([Elem::from_u64(val), Elem::ZERO, Elem::ZERO, Elem::ZERO])
498 }
499
500 fn to_u32_words(&self) -> Vec<u32> {
501 self.elems()
502 .iter()
503 .flat_map(|elem| elem.to_u32_words())
504 .collect()
505 }
506
507 fn from_u32_words(val: &[u32]) -> Self {
508 field::ExtElem::from_subelems(val.iter().map(|word| Elem(*word)))
509 }
510
511 fn is_valid(&self) -> bool {
516 self.0[0].is_valid()
517 }
518
519 fn is_reduced(&self) -> bool {
520 self.0.iter().all(|x| x.is_reduced())
521 }
522}
523
524impl field::ExtElem for ExtElem {
525 const EXT_SIZE: usize = EXT_SIZE;
526
527 type SubElem = Elem;
528
529 fn from_subfield(elem: &Elem) -> Self {
530 Self::from([*elem.ensure_valid(), Elem::ZERO, Elem::ZERO, Elem::ZERO])
531 }
532
533 fn from_subelems(elems: impl IntoIterator<Item = Self::SubElem>) -> Self {
534 let mut iter = elems.into_iter();
535 let elem = Self::from([
536 *iter.next().unwrap().ensure_valid(),
537 *iter.next().unwrap().ensure_valid(),
538 *iter.next().unwrap().ensure_valid(),
539 *iter.next().unwrap().ensure_valid(),
540 ]);
541 assert!(
542 iter.next().is_none(),
543 "Extra elements passed to create element in extension field"
544 );
545 elem
546 }
547
548 fn subelems(&self) -> &[Elem] {
550 &self.ensure_valid().0
551 }
552}
553
554impl PartialEq<ExtElem> for ExtElem {
555 fn eq(&self, rhs: &Self) -> bool {
556 self.ensure_valid().0 == rhs.ensure_valid().0
557 }
558}
559
560impl From<[Elem; EXT_SIZE]> for ExtElem {
561 fn from(val: [Elem; EXT_SIZE]) -> Self {
562 if cfg!(debug_assertions) {
563 for elem in val.iter() {
564 elem.ensure_valid();
565 }
566 }
567 ExtElem(val)
568 }
569}
570
/// The constant 11 from the extension's defining relation `x^4 + 11 = 0`.
const BETA: Elem = Elem::new(11);
/// `-BETA` in the base field (`P - 11`); this is the value `x^4` reduces to.
const NBETA: Elem = Elem::new(P - 11);
573
/// `const`-context validity check: returns `x` unchanged, and in builds with
/// debug assertions panics if its raw word equals the `INVALID` marker.
const fn const_ensure_valid(x: Elem) -> Elem {
    debug_assert!(x.0 != Elem::INVALID.0);
    x
}
579
580impl ExtElem {
581 pub const fn new(x0: Elem, x1: Elem, x2: Elem, x3: Elem) -> Self {
583 Self([
584 const_ensure_valid(x0),
585 const_ensure_valid(x1),
586 const_ensure_valid(x2),
587 const_ensure_valid(x3),
588 ])
589 }
590
591 pub fn from_fp(x: Elem) -> Self {
593 Self([x, Elem::new(0), Elem::new(0), Elem::new(0)])
594 }
595
596 pub const fn from_u32(x0: u32) -> Self {
598 Self([Elem::new(x0), Elem::new(0), Elem::new(0), Elem::new(0)])
599 }
600
601 const fn zero() -> Self {
603 Self::from_u32(0)
604 }
605
606 const fn one() -> Self {
608 Self::from_u32(1)
609 }
610
611 pub fn const_part(self) -> Elem {
613 self.ensure_valid().0[0]
614 }
615
616 pub fn elems(&self) -> &[Elem] {
618 &self.ensure_valid().0
619 }
620}
621
622impl ops::Add for ExtElem {
623 type Output = Self;
624
625 fn add(self, rhs: Self) -> Self {
627 let mut lhs = self;
628 lhs += rhs;
629 lhs
630 }
631}
632
633impl ops::AddAssign for ExtElem {
634 fn add_assign(&mut self, rhs: Self) {
636 for i in 0..self.0.len() {
637 self.0[i] += rhs.0[i];
638 }
639 }
640}
641
642impl ops::Add<Elem> for ExtElem {
643 type Output = Self;
644
645 fn add(self, rhs: Elem) -> Self {
647 let mut lhs = self;
648 lhs += rhs;
649 lhs
650 }
651}
652
653impl ops::Add<ExtElem> for Elem {
654 type Output = ExtElem;
655
656 fn add(self, rhs: ExtElem) -> ExtElem {
658 let mut lhs = ExtElem::from(self);
659 lhs += rhs;
660 lhs
661 }
662}
663
664impl ops::AddAssign<Elem> for ExtElem {
665 fn add_assign(&mut self, rhs: Elem) {
667 self.0[0] += rhs;
668 }
669}
670
671impl ops::Sub for ExtElem {
672 type Output = Self;
673
674 fn sub(self, rhs: Self) -> Self {
676 let mut lhs = self;
677 lhs -= rhs;
678 lhs
679 }
680}
681
682impl ops::SubAssign for ExtElem {
683 fn sub_assign(&mut self, rhs: Self) {
685 for i in 0..self.0.len() {
686 self.0[i] -= rhs.0[i];
687 }
688 }
689}
690
691impl ops::Sub<Elem> for ExtElem {
692 type Output = Self;
693
694 fn sub(self, rhs: Elem) -> Self {
696 let mut lhs = self;
697 lhs -= rhs;
698 lhs
699 }
700}
701
702impl ops::Sub<ExtElem> for Elem {
703 type Output = ExtElem;
704
705 fn sub(self, rhs: ExtElem) -> ExtElem {
707 let mut lhs = ExtElem::from(self);
708 lhs -= rhs;
709 lhs
710 }
711}
712
713impl ops::SubAssign<Elem> for ExtElem {
714 fn sub_assign(&mut self, rhs: Elem) {
716 self.0[0] -= rhs;
717 }
718}
719
720impl ops::MulAssign<Elem> for ExtElem {
721 fn mul_assign(&mut self, rhs: Elem) {
724 for i in 0..self.0.len() {
725 self.0[i] *= rhs;
726 }
727 }
728}
729
730impl ops::Mul<Elem> for ExtElem {
731 type Output = Self;
732
733 fn mul(self, rhs: Elem) -> Self {
735 let mut lhs = self;
736 lhs *= rhs;
737 lhs
738 }
739}
740
741impl ops::Mul<ExtElem> for Elem {
742 type Output = ExtElem;
743
744 fn mul(self, rhs: ExtElem) -> ExtElem {
746 rhs * self
747 }
748}
749
750impl ops::MulAssign for ExtElem {
757 #[inline(always)]
758 fn mul_assign(&mut self, rhs: Self) {
759 let a = &self.0;
761 let b = &rhs.0;
762 self.0 = [
763 a[0] * b[0] + NBETA * (a[1] * b[3] + a[2] * b[2] + a[3] * b[1]),
764 a[0] * b[1] + a[1] * b[0] + NBETA * (a[2] * b[3] + a[3] * b[2]),
765 a[0] * b[2] + a[1] * b[1] + a[2] * b[0] + NBETA * (a[3] * b[3]),
766 a[0] * b[3] + a[1] * b[2] + a[2] * b[1] + a[3] * b[0],
767 ];
768 }
769}
770
771impl ops::Mul for ExtElem {
772 type Output = ExtElem;
773
774 #[inline(always)]
775 fn mul(self, rhs: ExtElem) -> ExtElem {
776 let mut lhs = self;
777 lhs *= rhs;
778 lhs
779 }
780}
781
782impl ops::Neg for ExtElem {
783 type Output = Self;
784
785 fn neg(self) -> Self {
786 ExtElem::ZERO - self
787 }
788}
789
790impl From<u32> for ExtElem {
791 fn from(x: u32) -> Self {
792 Self([Elem::from(x), Elem::ZERO, Elem::ZERO, Elem::ZERO])
793 }
794}
795
796impl From<Elem> for ExtElem {
797 fn from(x: Elem) -> Self {
798 Self([x, Elem::ZERO, Elem::ZERO, Elem::ZERO])
799 }
800}
801
// Unit tests for the Baby Bear base and extension field implementations.
#[cfg(test)]
mod tests {
    use alloc::{vec, vec::Vec};

    use rand::{Rng, SeedableRng};

    use super::{field, Elem, ExtElem, P, P_U64};
    use crate::field::Elem as FieldElem;

    // Runs the shared roots-of-unity checks from `field::tests` against `Elem`.
    #[test]
    pub fn roots_of_unity() {
        field::tests::test_roots_of_unity::<Elem>();
    }

    // Runs the shared base-field checks from `field::tests`, passing the modulus.
    #[test]
    pub fn field_ops() {
        field::tests::test_field_ops::<Elem>(P_U64);
    }

    // Runs the shared extension-field checks from `field::tests`.
    #[test]
    pub fn ext_field_ops() {
        field::tests::test_ext_field_ops::<ExtElem>();
    }

    // Checks `x * c1` and `c0 + x * c1` against fixed expected coefficients.
    #[test]
    pub fn linear() {
        let x = ExtElem::new(
            Elem::new(1880084280),
            Elem::new(1788985953),
            Elem::new(1273325207),
            Elem::new(277471107),
        );
        let c0 = ExtElem::new(
            Elem::new(1582815482),
            Elem::new(2011839994),
            Elem::new(589901),
            Elem::new(698998108),
        );
        let c1 = ExtElem::new(
            Elem::new(1262573828),
            Elem::new(1903841444),
            Elem::new(1738307519),
            Elem::new(100967278),
        );

        assert_eq!(
            x * c1,
            ExtElem::new(
                Elem::new(876029217),
                Elem::new(1948387849),
                Elem::new(498773186),
                Elem::new(1997003991)
            )
        );
        assert_eq!(
            c0 + x * c1,
            ExtElem::new(
                Elem::new(445578778),
                Elem::new(1946961922),
                Elem::new(499363087),
                Elem::new(682736178)
            )
        );
    }

    // Checks the field axioms on 1000 random triples drawn from a seeded RNG.
    #[test]
    fn isa_field() {
        let mut rng = rand::rngs::SmallRng::seed_from_u64(2);
        for _ in 0..1_000 {
            let a = ExtElem::random(&mut rng);
            let b = ExtElem::random(&mut rng);
            let c = ExtElem::random(&mut rng);
            // Commutativity of addition and multiplication.
            assert_eq!(a + b, b + a);
            assert_eq!(a * b, b * a);
            // Associativity of addition and multiplication.
            assert_eq!(a + (b + c), (a + b) + c);
            assert_eq!(a * (b * c), (a * b) * c);
            // Distributivity.
            assert_eq!(a * (b + c), a * b + a * c);
            // Multiplicative inverse (zero has none).
            if a != ExtElem::ZERO {
                assert_eq!(a.inv() * a, ExtElem::from(1));
            }
            // Additive inverse.
            assert_eq!(ExtElem::ZERO - a, -a);
            assert_eq!(a + (-a), ExtElem::ZERO);
        }
    }

    // The base-field inverse of 5 times 5 is 1.
    #[test]
    fn inv() {
        assert_eq!(Elem::new(5).inv() * Elem::new(5), Elem::new(1));
    }

    // Spot-checks `pow` on small exponents, one fixed value, and the
    // exponents `P - 2` (inverse) and `P - 1` (which yields 1).
    #[test]
    fn pow() {
        assert_eq!(Elem::new(5).pow(0), Elem::new(1));
        assert_eq!(Elem::new(5).pow(1), Elem::new(5));
        assert_eq!(Elem::new(5).pow(2), Elem::new(25));
        assert_eq!(Elem::new(5).pow(1000), Elem::new(589699054));
        assert_eq!(
            Elem::new(5).pow((P - 2) as usize) * Elem::new(5),
            Elem::new(1)
        );
        assert_eq!(Elem::new(5).pow((P - 1) as usize), Elem::new(1));
    }

    // Compares field add/sub/mul with the same operations done on `u64`
    // values and reduced through `Elem::from`.
    #[test]
    fn compare_native() {
        let mut rng = rand::rngs::SmallRng::seed_from_u64(2);
        for _ in 0..100_000 {
            let fa = Elem::random(&mut rng);
            let fb = Elem::random(&mut rng);
            let a: u64 = fa.into();
            let b: u64 = fb.into();
            assert_eq!(fa + fb, Elem::from(a + b));
            // `P_U64 - b` keeps the native subtraction from going negative.
            assert_eq!(fa - fb, Elem::from(a + (P_U64 - b)));
            assert_eq!(fa * fb, Elem::from(a * b));
        }
    }

    // Comparing against `INVALID` must panic; the test is ignored when debug
    // assertions are off.
    #[test]
    #[cfg_attr(not(debug_assertions), ignore)]
    #[should_panic(expected = "assertion failed: self.is_valid")]
    fn compare_against_invalid() {
        let _ = Elem::ZERO == Elem::INVALID;
    }

    // Round-trips elements through `to_u32_words` / `from_u32_words`, and
    // arbitrary raw words through the reverse direction.
    #[test]
    fn u32s_conversions() {
        let mut rng = rand::rngs::SmallRng::seed_from_u64(2);
        for _ in 0..100 {
            let elem = Elem::random(&mut rng);
            assert_eq!(elem, Elem::from_u32_words(&elem.to_u32_words()));

            let val: u32 = rng.gen();
            assert_eq!(val, Elem::from_u32_words(&[val]).to_u32_words()[0]);
        }
        for _ in 0..100 {
            let elem = ExtElem::random(&mut rng);
            assert_eq!(elem, ExtElem::from_u32_words(&elem.to_u32_words()));

            let vec: Vec<u32> = vec![rng.gen(), rng.gen(), rng.gen(), rng.gen()];

            assert_eq!(vec, ExtElem::from_u32_words(&vec).to_u32_words());
        }
    }
}