1use crate::bytes::FieldBytes;
10use crate::error::Error;
11use crate::field::Field;
12
/// The BabyBear prime, `p = 15 * 2^27 + 1 = 2^31 - 2^27 + 1 = 2_013_265_921`.
///
/// `p - 1` is divisible by `2^27`, which is the property that distinguishes
/// BabyBear from other 31-bit primes (the previous value, `2^31 - 1`, is the
/// Mersenne-31 prime, whose `p - 1` has 2-adicity 1).
///
/// Because `p < 2^31`, the product of two reduced elements is below `2^62` and
/// fits in a `u64`, which the plain `u64` arithmetic in this module relies on.
const P: u64 = 2_013_265_921;
15
/// An element of the prime field of order `P`.
///
/// The wrapped `u64` is private and is kept in canonical form: the
/// constructors and arithmetic operators in this module always reduce it into
/// `0..P`. That invariant is what lets the derived `PartialEq`, `Eq` and
/// `Hash` coincide with equality of field elements.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct BabyBear(u64);
39
40impl BabyBear {
41 #[must_use]
43 pub fn new(value: u64) -> Self {
44 Self(value % P)
45 }
46
47 #[must_use]
49 pub fn value(self) -> u64 {
50 self.0
51 }
52}
53
54impl core::fmt::Display for BabyBear {
55 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
56 write!(f, "{}", self.0)
57 }
58}
59
60impl std::ops::Add for BabyBear {
61 type Output = Self;
62 fn add(self, rhs: Self) -> Self {
63 Self((self.0 + rhs.0) % P)
64 }
65}
66
67impl std::ops::Sub for BabyBear {
68 type Output = Self;
69 fn sub(self, rhs: Self) -> Self {
70 Self((self.0 + P - rhs.0) % P)
71 }
72}
73
74impl std::ops::Mul for BabyBear {
75 type Output = Self;
76 fn mul(self, rhs: Self) -> Self {
77 Self((self.0 * rhs.0) % P)
78 }
79}
80
81impl std::ops::Neg for BabyBear {
82 type Output = Self;
83 fn neg(self) -> Self {
84 Self((P - self.0) % P)
85 }
86}
87
88impl Field for BabyBear {
89 fn zero() -> Self {
90 Self(0)
91 }
92
93 fn one() -> Self {
94 Self(1)
95 }
96
97 fn inv(&self) -> Result<Self, Error> {
98 if self.0 == 0 {
99 Err(Error::DivisionByZero)
100 } else {
101 Ok(Self(pow_mod(self.0, P - 2, P)))
103 }
104 }
105}
106
107impl FieldBytes for BabyBear {
108 fn to_le_bytes(&self) -> Vec<u8> {
109 u32::try_from(self.0)
112 .map(|n| n.to_le_bytes().to_vec())
113 .unwrap_or_default()
114 }
115
116 fn from_le_bytes(bytes: &[u8]) -> Result<Self, Error> {
117 bytes
118 .get(..4)
119 .ok_or(Error::InvalidFieldEncoding)
120 .and_then(|slice| <[u8; 4]>::try_from(slice).map_err(|_| Error::InvalidFieldEncoding))
121 .map(u32::from_le_bytes)
122 .map(|n| Self::new(u64::from(n)))
123 }
124}
125
/// Computes `base^exp mod modulus` by right-to-left square-and-multiply.
///
/// Successive squares `base^(2^i) mod modulus` are generated lazily, and only
/// those whose bit `i` is set in `exp` are multiplied into the result.
/// Products are widened to `u128` before reduction, so the helper is correct
/// for every non-zero `u64` modulus, not only for moduli below `2^32`.
///
/// `pow_mod(x, 0, m)` is `1 mod m` (hence `0` when `m == 1`).
///
/// # Panics
///
/// Panics if `modulus` is zero.
// The `as u64` in `mul_mod` is lossless: it narrows a remainder that is
// strictly less than `modulus`, which is itself a `u64`.
#[allow(clippy::cast_possible_truncation)]
fn pow_mod(base: u64, exp: u64, modulus: u64) -> u64 {
    let wide_modulus = u128::from(modulus);
    let mul_mod = |a: u64, b: u64| (u128::from(a) * u128::from(b) % wide_modulus) as u64;

    // Only bits up to the highest set bit of `exp` contribute, so iteration
    // stops there instead of always walking all 64 bit positions.
    let significant_bits = u64::BITS - exp.leading_zeros();

    std::iter::successors(Some(base % modulus), |&square| Some(mul_mod(square, square)))
        .zip(0..significant_bits)
        .filter(|&(_, bit)| (exp >> bit) & 1 == 1)
        .fold(1 % modulus, |acc, (square, _)| mul_mod(acc, square))
}
138
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn zero_is_additive_identity() {
        let a = BabyBear::new(123_456);
        assert_eq!(a + BabyBear::zero(), a);
        assert_eq!(BabyBear::zero() + a, a);
    }

    #[test]
    fn one_is_multiplicative_identity() {
        let a = BabyBear::new(123_456);
        assert_eq!(a * BabyBear::one(), a);
        assert_eq!(BabyBear::one() * a, a);
    }

    #[test]
    fn additive_inverse() {
        let a = BabyBear::new(999_999);
        assert_eq!(a + (-a), BabyBear::zero());
    }

    #[test]
    fn multiplicative_inverse() -> Result<(), Error> {
        let a = BabyBear::new(42);
        let a_inv = a.inv()?;
        assert_eq!(a * a_inv, BabyBear::one());
        Ok(())
    }

    #[test]
    fn inverse_of_zero_fails() {
        let result = BabyBear::zero().inv();
        assert!(result.is_err());
    }

    #[test]
    fn sample_inverses() -> Result<(), Error> {
        let samples = [1u64, 2, 7, 100, 1_000_000, P - 1, P - 2];
        samples.iter().try_for_each(|&v| {
            let a = BabyBear::new(v);
            let a_inv = a.inv()?;
            assert_eq!(a * a_inv, BabyBear::one(), "failed for {v}");
            Ok(())
        })
    }

    #[test]
    fn inverse_is_involution() -> Result<(), Error> {
        let a = BabyBear::new(123_456_789);
        assert_eq!(a.inv()?.inv()?, a);
        Ok(())
    }

    #[test]
    fn subtraction_is_add_neg() {
        let a = BabyBear::new(1_000_000);
        let b = BabyBear::new(500_000);
        assert_eq!(a - b, a + (-b));
    }

    // The following three tests exercise the reduction branch of each
    // operator, which the small-value tests above never reach.
    #[test]
    fn addition_wraps_around() {
        let max = BabyBear::new(P - 1);
        assert_eq!(max + BabyBear::one(), BabyBear::zero());
        assert_eq!(max + max, BabyBear::new(P - 2));
    }

    #[test]
    fn subtraction_wraps_around() {
        assert_eq!(BabyBear::zero() - BabyBear::one(), BabyBear::new(P - 1));
    }

    #[test]
    fn negation_of_zero_is_zero() {
        assert_eq!(-BabyBear::zero(), BabyBear::zero());
    }

    #[test]
    fn multiplication_of_largest_elements() {
        // (P - 1)^2 = (-1)^2 = 1; this is the largest product `Mul` computes.
        let max = BabyBear::new(P - 1);
        assert_eq!(max * max, BabyBear::one());
    }

    #[test]
    fn multiplication_is_commutative() {
        let a = BabyBear::new(12_345);
        let b = BabyBear::new(67_890);
        assert_eq!(a * b, b * a);
    }

    #[test]
    fn distributivity() {
        let a = BabyBear::new(111);
        let b = BabyBear::new(222);
        let c = BabyBear::new(333);
        assert_eq!(a * (b + c), a * b + a * c);
    }

    #[test]
    fn new_reduces_mod_p() {
        assert_eq!(BabyBear::new(P), BabyBear::new(0));
        assert_eq!(BabyBear::new(P + 1), BabyBear::new(1));
        assert_eq!(BabyBear::new(2 * P), BabyBear::new(0));
    }

    #[test]
    fn value_is_canonical() {
        assert_eq!(BabyBear::new(P + 5).value(), 5);
    }

    #[test]
    fn bytes_roundtrip() -> Result<(), Error> {
        let a = BabyBear::new(1_234_567);
        let bytes = a.to_le_bytes();
        let b = BabyBear::from_le_bytes(&bytes)?;
        assert_eq!(a, b);
        Ok(())
    }

    #[test]
    fn bytes_zero_roundtrip() -> Result<(), Error> {
        let a = BabyBear::zero();
        let b = BabyBear::from_le_bytes(&a.to_le_bytes())?;
        assert_eq!(a, b);
        Ok(())
    }

    #[test]
    fn bytes_max_roundtrip() -> Result<(), Error> {
        let a = BabyBear::new(P - 1);
        let b = BabyBear::from_le_bytes(&a.to_le_bytes())?;
        assert_eq!(a, b);
        Ok(())
    }

    #[test]
    fn encoding_is_four_bytes() {
        assert_eq!(BabyBear::new(P - 1).to_le_bytes().len(), 4);
        assert_eq!(BabyBear::zero().to_le_bytes().len(), 4);
    }

    #[test]
    fn bytes_empty_fails() {
        let result = BabyBear::from_le_bytes(&[]);
        assert!(result.is_err());
    }

    #[test]
    fn bytes_short_input_fails() {
        assert!(BabyBear::from_le_bytes(&[1, 2, 3]).is_err());
    }
}