1use core::fmt::{self, Debug, Display, Formatter};
2use core::iter::{Product, Sum};
3use core::ops::{Add, AddAssign, Div, Mul, MulAssign, Neg, Sub, SubAssign};
4
5use num_bigint::BigUint;
6use p3_field::{
7 exp_1420470955, exp_u64_by_squaring, halve_u32, AbstractField, Field, Packable, PrimeField,
8 PrimeField31, PrimeField32, PrimeField64, TwoAdicField,
9};
10use rand::distributions::{Distribution, Standard};
11use rand::Rng;
12use serde::{Deserialize, Deserializer, Serialize};
13
/// The field modulus: `2^31 - 2^24 + 1`.
const P: u32 = 0x7f00_0001;

/// Number of bits in the Montgomery radix, i.e. the radix is `2^MONTY_BITS`.
const MONTY_BITS: u32 = 32;

/// Inverse of `P` modulo `2^MONTY_BITS` (`P * MONTY_MU == 1` in wrapping `u32` arithmetic).
const MONTY_MU: u32 = 0x8100_0001;

/// Bit mask selecting the low `MONTY_BITS` bits of a wider integer.
const MONTY_MASK: u32 = ((1u64 << MONTY_BITS) - 1) as u32;
29
/// An element of the prime field with modulus `P`.
///
/// `#[repr(transparent)]` gives the type the same layout as its single `u32` field.
#[derive(Copy, Clone, Default, Eq, Hash, PartialEq)]
#[repr(transparent)]
pub struct KoalaBear {
    /// The stored representation of the element. Constructors in this file fill it with
    /// the Montgomery form produced by `to_monty`.
    pub(crate) value: u32,
}
38
39impl KoalaBear {
40 #[inline]
42 pub(crate) const fn new(n: u32) -> Self {
43 Self { value: to_monty(n) }
44 }
45}
46
47impl Ord for KoalaBear {
48 #[inline]
49 fn cmp(&self, other: &Self) -> core::cmp::Ordering {
50 self.as_canonical_u32().cmp(&other.as_canonical_u32())
51 }
52}
53
54impl PartialOrd for KoalaBear {
55 #[inline]
56 fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
57 Some(self.cmp(other))
58 }
59}
60
61impl Display for KoalaBear {
62 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
63 Display::fmt(&self.as_canonical_u32(), f)
64 }
65}
66
67impl Debug for KoalaBear {
68 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
69 Debug::fmt(&self.as_canonical_u32(), f)
70 }
71}
72
73impl Distribution<KoalaBear> for Standard {
74 #[inline]
75 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> KoalaBear {
76 loop {
77 let next_u31 = rng.next_u32() >> 1;
78 let is_canonical = next_u31 < P;
79 if is_canonical {
80 return KoalaBear { value: next_u31 };
81 }
82 }
83 }
84}
85
86impl Serialize for KoalaBear {
87 fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
88 serializer.serialize_u32(self.as_canonical_u32())
89 }
90}
91
92impl<'de> Deserialize<'de> for KoalaBear {
93 fn deserialize<D: Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
94 let val = u32::deserialize(d)?;
95 Ok(KoalaBear::from_canonical_u32(val))
96 }
97}
98
99const MONTY_ZERO: u32 = to_monty(0);
100const MONTY_ONE: u32 = to_monty(1);
101const MONTY_TWO: u32 = to_monty(2);
102const MONTY_NEG_ONE: u32 = to_monty(P - 1);
103
104impl Packable for KoalaBear {}
105
106impl AbstractField for KoalaBear {
107 type F = Self;
108
109 fn zero() -> Self {
110 Self { value: MONTY_ZERO }
111 }
112 fn one() -> Self {
113 Self { value: MONTY_ONE }
114 }
115 fn two() -> Self {
116 Self { value: MONTY_TWO }
117 }
118 fn neg_one() -> Self {
119 Self {
120 value: MONTY_NEG_ONE,
121 }
122 }
123
124 #[inline]
125 fn from_f(f: Self::F) -> Self {
126 f
127 }
128
129 #[inline]
130 fn from_bool(b: bool) -> Self {
131 Self::from_canonical_u32(b as u32)
132 }
133
134 #[inline]
135 fn from_canonical_u8(n: u8) -> Self {
136 Self::from_canonical_u32(n as u32)
137 }
138
139 #[inline]
140 fn from_canonical_u16(n: u16) -> Self {
141 Self::from_canonical_u32(n as u32)
142 }
143
144 #[inline]
145 fn from_canonical_u32(n: u32) -> Self {
146 debug_assert!(n < P);
147 Self::from_wrapped_u32(n)
148 }
149
150 #[inline]
151 fn from_canonical_u64(n: u64) -> Self {
152 debug_assert!(n < P as u64);
153 Self::from_canonical_u32(n as u32)
154 }
155
156 #[inline]
157 fn from_canonical_usize(n: usize) -> Self {
158 debug_assert!(n < P as usize);
159 Self::from_canonical_u32(n as u32)
160 }
161
162 #[inline]
163 fn from_wrapped_u32(n: u32) -> Self {
164 Self { value: to_monty(n) }
165 }
166
167 #[inline]
168 fn from_wrapped_u64(n: u64) -> Self {
169 Self {
170 value: to_monty_64(n),
171 }
172 }
173
174 #[inline]
175 fn generator() -> Self {
176 Self::from_canonical_u32(0x3)
177 }
178}
179
180impl Field for KoalaBear {
181 #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
182 type Packing = crate::PackedKoalaBearNeon;
183 #[cfg(all(
184 target_arch = "x86_64",
185 target_feature = "avx2",
186 not(target_feature = "avx512f")
187 ))]
188 type Packing = crate::PackedKoalaBearAVX2;
189 #[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
190 type Packing = crate::PackedKoalaBearAVX512;
191 #[cfg(not(any(
192 all(target_arch = "aarch64", target_feature = "neon"),
193 all(
194 target_arch = "x86_64",
195 target_feature = "avx2",
196 not(target_feature = "avx512f")
197 ),
198 all(target_arch = "x86_64", target_feature = "avx512f"),
199 )))]
200 type Packing = Self;
201
202 #[inline]
203 fn mul_2exp_u64(&self, exp: u64) -> Self {
204 let product = (self.value as u64) << exp;
205 let value = (product % (P as u64)) as u32;
206 Self { value }
207 }
208
209 #[inline]
210 fn exp_u64_generic<AF: AbstractField<F = Self>>(val: AF, power: u64) -> AF {
211 match power {
212 1420470955 => exp_1420470955(val), _ => exp_u64_by_squaring(val, power),
214 }
215 }
216
217 fn try_inverse(&self) -> Option<Self> {
218 if self.is_zero() {
219 return None;
220 }
221
222 let p1 = *self;
227 let p10 = p1.square();
228 let p11 = p10 * p1;
229 let p1100 = p11.exp_power_of_2(2);
230 let p1111 = p1100 * p11;
231 let p110000 = p1100.exp_power_of_2(2);
232 let p111111 = p110000 * p1111;
233 let p1111110000 = p111111.exp_power_of_2(4);
234 let p1111111111 = p1111110000 * p1111;
235 let p11111101111 = p1111111111 * p1111110000;
236 let p111111011110000000000 = p11111101111.exp_power_of_2(10);
237 let p111111011111111111111 = p111111011110000000000 * p1111111111;
238 let p1111110111111111111110000000000 = p111111011111111111111.exp_power_of_2(10);
239 let p1111110111111111111111111111111 = p1111110111111111111110000000000 * p1111111111;
240
241 Some(p1111110111111111111111111111111)
242 }
243
244 #[inline]
245 fn halve(&self) -> Self {
246 KoalaBear {
247 value: halve_u32::<P>(self.value),
248 }
249 }
250
251 #[inline]
252 fn order() -> BigUint {
253 P.into()
254 }
255}
256
257impl PrimeField for KoalaBear {
258 fn as_canonical_biguint(&self) -> BigUint {
259 <Self as PrimeField32>::as_canonical_u32(self).into()
260 }
261}
262
263impl PrimeField64 for KoalaBear {
264 const ORDER_U64: u64 = <Self as PrimeField32>::ORDER_U32 as u64;
265
266 #[inline]
267 fn as_canonical_u64(&self) -> u64 {
268 u64::from(self.as_canonical_u32())
269 }
270}
271
272impl PrimeField32 for KoalaBear {
273 const ORDER_U32: u32 = P;
274
275 #[inline]
276 fn as_canonical_u32(&self) -> u32 {
277 from_monty(self.value)
278 }
279}
280
281impl PrimeField31 for KoalaBear {}
282
283impl TwoAdicField for KoalaBear {
284 const TWO_ADICITY: usize = 24;
285
286 fn two_adic_generator(bits: usize) -> Self {
287 assert!(bits <= Self::TWO_ADICITY);
288 match bits {
289 0 => Self::one(),
290 1 => Self::from_canonical_u32(0x7f000000),
291 2 => Self::from_canonical_u32(0x7e010002),
292 3 => Self::from_canonical_u32(0x6832fe4a),
293 4 => Self::from_canonical_u32(0x8dbd69c),
294 5 => Self::from_canonical_u32(0xa28f031),
295 6 => Self::from_canonical_u32(0x5c4a5b99),
296 7 => Self::from_canonical_u32(0x29b75a80),
297 8 => Self::from_canonical_u32(0x17668b8a),
298 9 => Self::from_canonical_u32(0x27ad539b),
299 10 => Self::from_canonical_u32(0x334d48c7),
300 11 => Self::from_canonical_u32(0x7744959c),
301 12 => Self::from_canonical_u32(0x768fc6fa),
302 13 => Self::from_canonical_u32(0x303964b2),
303 14 => Self::from_canonical_u32(0x3e687d4d),
304 15 => Self::from_canonical_u32(0x45a60e61),
305 16 => Self::from_canonical_u32(0x6e2f4d7a),
306 17 => Self::from_canonical_u32(0x163bd499),
307 18 => Self::from_canonical_u32(0x6c4a8a45),
308 19 => Self::from_canonical_u32(0x143ef899),
309 20 => Self::from_canonical_u32(0x514ddcad),
310 21 => Self::from_canonical_u32(0x484ef19b),
311 22 => Self::from_canonical_u32(0x205d63c3),
312 23 => Self::from_canonical_u32(0x68e7dd49),
313 24 => Self::from_canonical_u32(0x6ac49f88),
314 _ => unreachable!("Already asserted that bits <= Self::TWO_ADICITY"),
315 }
316 }
317}
318
319impl Add for KoalaBear {
320 type Output = Self;
321
322 #[inline]
323 fn add(self, rhs: Self) -> Self {
324 let mut sum = self.value + rhs.value;
325 let (corr_sum, over) = sum.overflowing_sub(P);
326 if !over {
327 sum = corr_sum;
328 }
329 Self { value: sum }
330 }
331}
332
333impl AddAssign for KoalaBear {
334 #[inline]
335 fn add_assign(&mut self, rhs: Self) {
336 *self = *self + rhs;
337 }
338}
339
340impl Sum for KoalaBear {
341 #[inline]
342 fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
343 let sum = iter.map(|x| (x.value as u64)).sum::<u64>();
348 Self {
349 value: (sum % P as u64) as u32,
350 }
351 }
352}
353
354impl Sub for KoalaBear {
355 type Output = Self;
356
357 #[inline]
358 fn sub(self, rhs: Self) -> Self {
359 let (mut diff, over) = self.value.overflowing_sub(rhs.value);
360 let corr = if over { P } else { 0 };
361 diff = diff.wrapping_add(corr);
362 Self { value: diff }
363 }
364}
365
366impl SubAssign for KoalaBear {
367 #[inline]
368 fn sub_assign(&mut self, rhs: Self) {
369 *self = *self - rhs;
370 }
371}
372
373impl Neg for KoalaBear {
374 type Output = Self;
375
376 #[inline]
377 fn neg(self) -> Self::Output {
378 Self::zero() - self
379 }
380}
381
382impl Mul for KoalaBear {
383 type Output = Self;
384
385 #[inline]
386 fn mul(self, rhs: Self) -> Self {
387 let long_prod = self.value as u64 * rhs.value as u64;
388 Self {
389 value: monty_reduce(long_prod),
390 }
391 }
392}
393
394impl MulAssign for KoalaBear {
395 #[inline]
396 fn mul_assign(&mut self, rhs: Self) {
397 *self = *self * rhs;
398 }
399}
400
401impl Product for KoalaBear {
402 #[inline]
403 fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
404 iter.reduce(|x, y| x * y).unwrap_or(Self::one())
405 }
406}
407
408impl Div for KoalaBear {
409 type Output = Self;
410
411 #[allow(clippy::suspicious_arithmetic_impl)]
412 #[inline]
413 fn div(self, rhs: Self) -> Self {
414 self * rhs.inverse()
415 }
416}
417
418#[inline]
419#[must_use]
420const fn to_monty(x: u32) -> u32 {
421 (((x as u64) << MONTY_BITS) % P as u64) as u32
422}
423
424#[inline]
427#[must_use]
428pub(crate) const fn to_koalabear_array<const N: usize>(input: [u32; N]) -> [KoalaBear; N] {
429 let mut output = [KoalaBear { value: 0 }; N];
430 let mut i = 0;
431 loop {
432 if i == N {
433 break;
434 }
435 output[i].value = to_monty(input[i]);
436 i += 1;
437 }
438 output
439}
440
441#[inline]
442#[must_use]
443const fn to_monty_64(x: u64) -> u32 {
444 (((x as u128) << MONTY_BITS) % P as u128) as u32
445}
446
447#[inline]
448#[must_use]
449const fn from_monty(x: u32) -> u32 {
450 monty_reduce(x as u64)
451}
452
453#[inline]
455#[must_use]
456pub(crate) const fn monty_reduce(x: u64) -> u32 {
457 let t = x.wrapping_mul(MONTY_MU as u64) & (MONTY_MASK as u64);
458 let u = t * (P as u64);
459
460 let (x_sub_u, over) = x.overflowing_sub(u);
461 let x_sub_u_hi = (x_sub_u >> MONTY_BITS) as u32;
462 let corr = if over { P } else { 0 };
463 x_sub_u_hi.wrapping_add(corr)
464}
465
#[cfg(test)]
mod tests {
    use p3_field_testing::{test_field, test_two_adic_field};

    use super::*;

    type F = KoalaBear;

    /// Encodes `x` as JSON and decodes the result back into a field element.
    fn json_round_trip(x: &F) -> F {
        let encoded = serde_json::to_string(x).unwrap();
        serde_json::from_str(&encoded).unwrap()
    }

    /// Every generator in the table must be the top generator squared the right number
    /// of times.
    #[test]
    fn test_koala_bear_two_adicity_generators() {
        let top_generator = KoalaBear::from_canonical_u32(0x6ac49f88);
        for bits in 0..=KoalaBear::TWO_ADICITY {
            let squarings = KoalaBear::TWO_ADICITY - bits;
            assert_eq!(
                KoalaBear::two_adic_generator(bits),
                top_generator.exp_power_of_2(squarings)
            );
        }
    }

    /// Spot checks for construction, arithmetic, exponentiation and serde round-trips.
    #[test]
    fn test_koala_bear() {
        // Construction and canonical read-back.
        let hundred = F::from_canonical_u32(100);
        assert_eq!(hundred.as_canonical_u64(), 100);

        let canonical_zero = F::from_canonical_u32(0);
        assert!(canonical_zero.is_zero());

        let wrapped_zero = F::from_wrapped_u32(F::ORDER_U32);
        assert!(wrapped_zero.is_zero());

        // Small-value arithmetic.
        let f_1 = F::one();
        let f_1_copy = F::from_canonical_u32(1);
        assert_eq!(f_1 - f_1_copy, F::zero());
        assert_eq!(f_1 + f_1_copy, F::two());

        let f_2 = F::from_canonical_u32(2);
        assert_eq!(f_1 + f_1_copy * f_2, F::from_canonical_u32(3));
        assert_eq!(f_1 + f_2 * f_2, F::from_canonical_u32(5));

        // Arithmetic that wraps around the modulus.
        let f_p_minus_1 = F::from_canonical_u32(F::ORDER_U32 - 1);
        assert_eq!(f_1 + f_p_minus_1, F::zero());

        let f_p_minus_2 = F::from_canonical_u32(F::ORDER_U32 - 2);
        assert_eq!(
            f_p_minus_1 + f_p_minus_2,
            F::from_canonical_u32(F::ORDER_U32 - 3)
        );
        assert_eq!(f_p_minus_1 - f_p_minus_2, F::from_canonical_u32(1));
        assert_eq!(f_p_minus_2 - f_p_minus_1, f_p_minus_1);
        assert_eq!(f_p_minus_1 - f_1, f_p_minus_2);

        // Multiplication of two fixed large values.
        let m1 = F::from_canonical_u32(0x34167c58);
        let m2 = F::from_canonical_u32(0x61f3207b);
        let expected_prod = F::from_canonical_u32(0x54b46b81);
        assert_eq!(m1 * m2, expected_prod);

        // Raising to 1420470955 and then cubing must return the input.
        for x in [m1, m2, f_2] {
            assert_eq!(x.exp_u64(1420470955).exp_const_u64::<3>(), x);
        }

        // Serde round-trips through JSON.
        assert_eq!(wrapped_zero, json_round_trip(&wrapped_zero));

        let f_1_deserialized = json_round_trip(&f_1);
        let f_1_deserialized_again = json_round_trip(&f_1_deserialized);
        assert_eq!(f_1, f_1_deserialized);
        assert_eq!(f_1, f_1_deserialized_again);

        assert_eq!(f_2, json_round_trip(&f_2));
        assert_eq!(f_p_minus_1, json_round_trip(&f_p_minus_1));
        assert_eq!(f_p_minus_2, json_round_trip(&f_p_minus_2));
        assert_eq!(m1, json_round_trip(&m1));
        assert_eq!(m2, json_round_trip(&m2));
    }

    test_field!(crate::KoalaBear);
    test_two_adic_field!(crate::KoalaBear);
}