lambdaworks_math/field/fields/binary/
field.rs1use core::cmp::Ordering;
2use core::fmt;
3use core::iter::{Product, Sum};
4use core::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub};
5
/// Errors produced by binary tower field arithmetic.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BinaryFieldError {
    /// Attempted to invert the zero element, which has no multiplicative inverse.
    InverseOfZero,
}

impl fmt::Display for BinaryFieldError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            BinaryFieldError::InverseOfZero => {
                f.write_str("cannot invert the zero element of a binary field")
            }
        }
    }
}

// `std::error::Error` is only available when the crate is built with `std`.
#[cfg(feature = "std")]
impl std::error::Error for BinaryFieldError {}
22
/// An element of the binary tower field at level `num_level`.
///
/// The element's bits live in the low `2^num_level` bits of `value`; constructors
/// in this file keep `num_level` within `0..=7` so that every level fits in a `u128`.
/// The derived `Default` is the zero element at level 0.
#[derive(Clone, Copy, Debug, Default)]
pub struct TowerFieldElement {
    /// Bit representation of the element.
    pub value: u128,
    /// Tower level of the element.
    pub num_level: usize,
}
40
41impl TowerFieldElement {
42 pub fn new(val: u128, num_level: usize) -> Self {
45 let safe_level = if num_level > 7 { 7 } else { num_level };
47
48 let bits = 1 << safe_level;
50 let mask = if bits >= 128 {
51 u128::MAX
52 } else {
53 (1 << bits) - 1
54 };
55
56 Self {
57 value: val & mask,
59 num_level: safe_level,
60 }
61 }
62
63 pub fn is_zero(&self) -> bool {
65 self.value == 0
66 }
67
68 #[inline]
70 pub fn is_one(&self) -> bool {
71 self.value == 1
72 }
73
74 #[inline]
76 pub fn value(&self) -> u128 {
77 self.value
78 }
79
80 #[inline]
82 pub fn num_level(&self) -> usize {
83 self.num_level
84 }
85
86 #[inline]
89 pub fn num_bits(&self) -> usize {
90 1 << self.num_level()
91 }
92
93 #[cfg(feature = "std")]
95 pub fn to_binary_string(&self) -> String {
96 format!("{:0width$b}", self.value, width = self.num_bits())
97 }
98
99 pub fn split(&self) -> (Self, Self) {
103 let half_bits = self.num_bits() / 2;
104 let mask = (1 << half_bits) - 1;
105 let lo = self.value() & mask;
106 let hi = (self.value() >> half_bits) & mask;
107
108 (
109 Self::new(hi, self.num_level() - 1),
110 Self::new(lo, self.num_level() - 1),
111 )
112 }
113
114 pub fn join(&self, low: &Self) -> Self {
118 let joined = (self.value() << self.num_bits()) | low.value();
119 Self::new(joined, self.num_level() + 1)
120 }
121
122 pub fn extend_num_level(&mut self, new_level: usize) {
124 if self.num_level() < new_level {
125 self.num_level = new_level;
126 }
127 }
128
129 pub fn zero() -> Self {
131 Self::new(0, 0)
132 }
133
134 pub fn one() -> Self {
136 Self::new(1, 0)
137 }
138
139 fn add_elements(&self, other: &Self) -> Self {
141 let num_level = self.num_level().max(other.num_level());
142 Self::new(self.value() ^ other.value(), num_level)
143 }
144
145 fn mul(self, other: Self) -> Self {
156 match self.num_level().cmp(&other.num_level()) {
157 Ordering::Greater => {
158 let (a_hi, a_lo) = self.split();
160 a_hi.mul(other).join(&a_lo.mul(other))
162 }
163 Ordering::Less => {
164 other.mul(self)
166 }
167 Ordering::Equal => {
168 if self.num_level() == 0 {
170 return Self::new(self.value() & other.value(), 0);
172 }
173
174 let (a_high, a_low) = self.split();
176 let (b_high, b_low) = other.split();
177
178 let low_product = a_low.mul(b_low); let high_product = a_high.mul(b_high); let x_value = if self.num_level() == 1 {
184 Self::new(1, 0)
185 } else {
186 Self::new(1 << (self.num_bits() / 4), self.num_level() - 1)
187 };
188
189 let shifted_high_product = high_product.mul(x_value);
191
192 let sum_product = (a_low + a_high).mul(b_low + b_high);
196 let middle_term = sum_product - low_product - high_product;
197
198 (shifted_high_product + middle_term).join(&(high_product + low_product))
200 }
201 }
202 }
203
204 pub fn inv(&self) -> Result<Self, BinaryFieldError> {
209 if self.is_zero() {
210 return Err(BinaryFieldError::InverseOfZero);
211 }
212 if self.num_level() <= 1 || self.num_bits() <= 4 {
213 let exponent = (1 << self.num_bits()) - 2;
214 Ok(Self::pow(self, exponent as u32))
215 } else {
216 let (a_hi, a_lo) = self.split();
217 let two_pow_k_minus_one = Self::new(1 << (self.num_bits() / 4), self.num_level() - 1);
218 let a_lo_next = a_lo + a_hi * two_pow_k_minus_one;
221
222 let delta = a_lo * a_lo_next + a_hi * a_hi;
224
225 let delta_inverse = delta.inv()?;
227
228 let out_hi = delta_inverse * a_hi;
230 let out_lo = delta_inverse * a_lo_next;
231
232 Ok(out_hi.join(&out_lo))
234 }
235 }
236
237 pub fn pow(&self, exp: u32) -> Self {
239 let mut result = Self::one();
240 let mut base = *self;
241 let mut exp_val = exp;
242
243 while exp_val > 0 {
244 if exp_val & 1 == 1 {
245 result *= base;
246 }
247 base = base * base;
248 exp_val >>= 1;
249 }
250
251 result
252 }
253}
254
255impl PartialEq<TowerFieldElement> for TowerFieldElement {
256 fn eq(&self, other: &Self) -> bool {
257 self.value() == other.value()
258 }
259}
260
261impl Eq for TowerFieldElement {}
262
263impl Add for TowerFieldElement {
264 type Output = Self;
265
266 fn add(self, other: Self) -> Self {
267 self.add_elements(&other)
269 }
270}
271
272impl<'a> Add<&'a TowerFieldElement> for &'a TowerFieldElement {
273 type Output = TowerFieldElement;
274
275 fn add(self, other: &'a TowerFieldElement) -> TowerFieldElement {
276 self.add_elements(other)
278 }
279}
280
281impl AddAssign for TowerFieldElement {
282 fn add_assign(&mut self, other: Self) {
283 *self = *self + other;
284 }
285}
286#[allow(clippy::suspicious_arithmetic_impl)]
287impl Sub for TowerFieldElement {
288 type Output = Self;
289
290 fn sub(self, other: Self) -> Self {
291 self + other
293 }
294}
295
296impl Neg for TowerFieldElement {
297 type Output = Self;
298
299 fn neg(self) -> Self {
300 self
302 }
303}
304
305impl Mul for TowerFieldElement {
306 type Output = Self;
307
308 fn mul(self, other: Self) -> Self {
309 self.mul(other)
310 }
311}
312
313impl Mul<&TowerFieldElement> for &TowerFieldElement {
314 type Output = TowerFieldElement;
315
316 fn mul(self, other: &TowerFieldElement) -> TowerFieldElement {
317 <TowerFieldElement as Mul<TowerFieldElement>>::mul(*self, *other)
318 }
319}
320
321impl MulAssign for TowerFieldElement {
322 fn mul_assign(&mut self, other: Self) {
323 *self = *self * other;
324 }
325}
326
327impl Product for TowerFieldElement {
328 fn product<I>(iter: I) -> Self
329 where
330 I: Iterator<Item = Self>,
331 {
332 iter.fold(Self::one(), |acc, x| acc * x)
333 }
334}
335
336impl Sum for TowerFieldElement {
337 fn sum<I>(iter: I) -> Self
338 where
339 I: Iterator<Item = Self>,
340 {
341 iter.fold(Self::zero(), |acc, x| acc + x)
342 }
343}
344
345impl fmt::Display for TowerFieldElement {
346 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
347 write!(f, "{}", self.value)
348 }
349}
350
351impl From<u128> for TowerFieldElement {
352 fn from(val: u128) -> Self {
353 TowerFieldElement::new(val, 7)
354 }
355}
356
357impl From<u64> for TowerFieldElement {
358 fn from(val: u64) -> Self {
359 TowerFieldElement::new(val as u128, 6)
360 }
361}
362
363impl From<u32> for TowerFieldElement {
364 fn from(val: u32) -> Self {
365 TowerFieldElement::new(val as u128, 5)
366 }
367}
368
369impl From<u16> for TowerFieldElement {
370 fn from(val: u16) -> Self {
371 TowerFieldElement::new(val as u128, 4)
372 }
373}
374
375impl From<u8> for TowerFieldElement {
376 fn from(val: u8) -> Self {
377 TowerFieldElement::new(val as u128, 3)
378 }
379}
380
#[cfg(test)]
mod tests {
    use super::*;
    // proptest is only used by the std-gated property tests below; gating the import
    // (and the strategy) keeps no-std test builds free of unused/dead-code warnings.
    #[cfg(feature = "std")]
    use proptest::prelude::*;

    #[test]
    fn test_new_safe() {
        // Levels above 7 are clamped to the top level.
        let elem = TowerFieldElement::new(0, 8);
        assert_eq!(elem.num_level, 7);

        // Values are truncated to the level's width: 0b100 does not fit in 2 bits.
        let elem = TowerFieldElement::new(4, 1);
        assert_eq!(elem.value, 0);
    }

    #[test]
    fn test_addition() {
        let a = TowerFieldElement::new(5, 9); // level clamped to 7
        let b = TowerFieldElement::new(3, 2);
        let c = a + b;
        // Addition is XOR, and the sum lives at the larger of the two levels.
        assert_eq!(c.value, 6);
        assert_eq!(c.num_level, 7);

        let d = b + a;
        assert_eq!(d, c);
    }

    #[test]
    fn mul_in_level_0() {
        let a = TowerFieldElement::new(0, 0);
        let b = TowerFieldElement::new(1, 0);
        assert_eq!(a * a, a);
        assert_eq!(a * b, a);
        assert_eq!(b * b, b);
    }

    #[test]
    fn mul_in_level_1() {
        let a = TowerFieldElement::new(0b00, 1);
        let b = TowerFieldElement::new(0b01, 1);
        let c = TowerFieldElement::new(0b10, 1);
        let d = TowerFieldElement::new(0b11, 1);
        assert_eq!(a * a, a);
        assert_eq!(a * b, a);
        assert_eq!(b * c, c);
        assert_eq!(c * d, b);
    }

    #[test]
    fn mul_in_level_2() {
        let a = TowerFieldElement::new(0b0000, 2);
        let b = TowerFieldElement::new(0b0001, 2);
        let c = TowerFieldElement::new(0b0010, 2);
        let d = TowerFieldElement::new(0b0011, 2);
        let e = TowerFieldElement::new(0b0100, 2);
        let f = TowerFieldElement::new(0b0101, 2);
        let g = TowerFieldElement::new(0b0110, 2);
        let h = TowerFieldElement::new(0b0111, 2);
        let i = TowerFieldElement::new(0b1000, 2);
        let j = TowerFieldElement::new(0b1001, 2);
        let k = TowerFieldElement::new(0b1010, 2);
        let l = TowerFieldElement::new(0b1011, 2);
        let n = TowerFieldElement::new(0b1100, 2);
        let m = TowerFieldElement::new(0b1101, 2);
        let o = TowerFieldElement::new(0b1110, 2);
        let p = TowerFieldElement::new(0b1111, 2);

        assert_eq!(a * p, a);
        assert_eq!(a * l, a);
        assert_eq!(b * m, m);
        assert_eq!(c * e, i);
        assert_eq!(c * c, d);
        assert_eq!(g * h, n);
        assert_eq!(k * j, b);
        assert_eq!(j * f, d);
        assert_eq!(e * e, j);
        assert_eq!(n * o, k);
    }

    #[test]
    fn mul_between_different_levels() {
        let a = TowerFieldElement::new(0b10, 1);
        let b = TowerFieldElement::new(0b0100, 2);
        let c = TowerFieldElement::new(0b1000, 2);
        assert_eq!(a * b, c);
    }

    #[test]
    fn test_correct_level_mul() {
        let a = TowerFieldElement::new(0b1111, 5);
        let b = TowerFieldElement::new(0b1010, 2);
        assert_eq!((a * b).num_level, 5);
    }

    #[test]
    fn mul_is_associative() {
        let a = TowerFieldElement::new(83, 7);
        let b = TowerFieldElement::new(31, 5);
        let c = TowerFieldElement::new(3, 2);
        let ab = a * b;
        let bc = b * c;
        assert_eq!(ab * c, a * bc);
    }

    #[test]
    fn mul_is_commutative() {
        let a = TowerFieldElement::new(127, 7);
        let b = TowerFieldElement::new(6, 3);
        let ab = a * b;
        let ba = b * a;
        assert_eq!(ab, ba);
    }

    #[test]
    fn test_inverse() {
        let a0 = TowerFieldElement::new(1, 0);
        let inv_a0 = a0.inv().unwrap();
        assert_eq!(inv_a0.value, 1);
        assert_eq!(inv_a0.num_level, 0);

        let a1 = TowerFieldElement::new(2, 1);
        let inv_a1 = a1.inv().unwrap();
        assert_eq!(inv_a1.value, 3);
        assert_eq!(inv_a1.num_level, 1);

        let a2 = TowerFieldElement::new(15, 4);
        let inv_a2 = a2.inv().unwrap();
        let one = TowerFieldElement::new(1, 4);
        assert_eq!(a2 * inv_a2, one);

        let a3 = TowerFieldElement::new(30, 5);
        let inv_a3 = a3.inv().unwrap();
        let one = TowerFieldElement::new(1, 5);
        assert_eq!(a3 * inv_a3, one);

        // Top level: exercises the recursive inversion all the way down.
        let a7 = TowerFieldElement::new(0x0123_4567_89ab_cdef_fedc_ba98_7654_3210, 7);
        let inv_a7 = a7.inv().unwrap();
        assert_eq!(a7 * inv_a7, TowerFieldElement::new(1, 7));

        let zero = TowerFieldElement::zero();
        assert!(matches!(zero.inv(), Err(BinaryFieldError::InverseOfZero)));
    }

    #[test]
    fn test_multiplication_overflow() {
        for level in 0..7 {
            let max_value = (1u128 << (1 << level)) - 1;
            let a = TowerFieldElement::new(max_value, level);
            let b = TowerFieldElement::new(max_value, level);

            let result = a * b;

            assert!(result.value < (1u128 << result.num_bits()));
        }
    }

    #[test]
    fn test_split_join_consistency() {
        for i in 0..20 {
            let original = TowerFieldElement::new(i, 3);
            let (hi, lo) = original.split();
            let rejoined = hi.join(&lo);

            assert_eq!(rejoined, original);
        }
    }

    #[cfg(feature = "std")]
    #[test]
    fn test_bin_representation() {
        let a = TowerFieldElement::new(0b1010, 5);
        assert_eq!(a.to_binary_string(), "00000000000000000000000000001010");
        let b = TowerFieldElement::new(0b1010, 4);
        assert_eq!(b.to_binary_string(), "0000000000001010");
    }

    /// Strategy producing an element at a random level 0..=7 with a value that fits it.
    #[cfg(feature = "std")]
    fn arb_tower_element_any() -> impl Strategy<Value = TowerFieldElement> {
        (0usize..=7)
            .prop_flat_map(|level| {
                // Largest value at `level`: the low 2^level bits set (all bits at level 7).
                let max_val = if level == 7 {
                    u128::MAX
                } else {
                    (1u128 << (1 << level)) - 1
                };
                (Just(level), 0u128..=max_val)
            })
            .prop_map(|(level, val)| TowerFieldElement::new(val, level))
    }

    #[cfg(feature = "std")]
    proptest! {
        #[test]
        fn test_mul_commutative(a in arb_tower_element_any(), b in arb_tower_element_any()) {
            prop_assert_eq!(a * b, b * a);
        }

        #[test]
        fn test_mul_associative(
            a in arb_tower_element_any(),
            b in arb_tower_element_any(),
            c in arb_tower_element_any()
        ) {
            prop_assert_eq!((a * b) * c, a * (b * c));
        }
    }
}