// lambdaworks_math/field/fields/binary/field.rs

1use core::cmp::Ordering;
2use core::fmt;
3use core::iter::{Product, Sum};
4use core::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub};
5
// Implementation of binary fields of the form GF(2^{2^n}) (i.e. finite fields with 2^{2^n} elements), built as a tower of field extensions.
// The basic idea is to represent an element of each field as a multivariate polynomial with binary coefficients in GF(2) = {0, 1}.
// The coefficients of each polynomial are stored as bits in a `u128` integer.
// The tower is built recursively: each level is a degree-2 extension of the previous one.
// At level n, polynomials have n variables x_1, ..., x_n that satisfy
// (x_i)² = x_i * x_{i-1} + 1, with the convention x_0 = 1.

// For more details, see:
// - Lambdaclass blog post about the use of binary fields in SNARKs: https://blog.lambdaclass.com/snarks-on-binary-fields-binius/
// - Vitalik Buterin's Binius: https://vitalik.eth.limo/general/2024/04/29/binius.html
16
/// Errors produced by binary tower field operations.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BinaryFieldError {
    /// Attempt to compute the inverse of zero, which has no multiplicative inverse.
    InverseOfZero,
}

impl fmt::Display for BinaryFieldError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            BinaryFieldError::InverseOfZero => write!(f, "attempted to invert the zero element"),
        }
    }
}
22
/// An element of the tower of binary field extensions, from level 0 up to level 7.
///
/// Arithmetic takes place in the finite field GF(2^{2^n}), where n is the level of the
/// element in the tower.
///
/// Internally, the coefficients of the polynomial representing the element are packed as
/// the bits of a `u128`.
#[derive(Clone, Copy, Debug, Default)]
pub struct TowerFieldElement {
    /// Bit-packed polynomial coefficients of the element.
    /// For example, value = 0b1101 encodes p = xy + y + 1, and value = 0b0110 encodes p = y + x.
    pub value: u128,
    /// Level of the element in the tower, i.e. the field extension it belongs to.
    /// Ranges from 0 (the two-element base field) to 7 (the extension with 2^128 elements).
    pub num_level: usize,
}
40
41impl TowerFieldElement {
42    /// Constructor that always succeeds by masking the value if it is too big for the given
43    /// num_level, and limiting the level so that is not greater than 7.
44    pub fn new(val: u128, num_level: usize) -> Self {
45        // Limit num_level to a maximum valid value for u128.
46        let safe_level = if num_level > 7 { 7 } else { num_level };
47
48        // The number of bits needed for the given level
49        let bits = 1 << safe_level;
50        let mask = if bits >= 128 {
51            u128::MAX
52        } else {
53            (1 << bits) - 1
54        };
55
56        Self {
57            // We take just the lsb of val that fit in the extension field we are.
58            value: val & mask,
59            num_level: safe_level,
60        }
61    }
62
63    /// Returns true if the element is zero
64    pub fn is_zero(&self) -> bool {
65        self.value == 0
66    }
67
68    /// Returns true if this element is one
69    #[inline]
70    pub fn is_one(&self) -> bool {
71        self.value == 1
72    }
73
74    /// Returns the underlying value
75    #[inline]
76    pub fn value(&self) -> u128 {
77        self.value
78    }
79
80    /// Returns level number in the tower.
81    #[inline]
82    pub fn num_level(&self) -> usize {
83        self.num_level
84    }
85
86    /// Returns the number of bits needed for that level (2^num_levels).
87    /// Note that the order of the extension field in that level is 2^num_bits.
88    #[inline]
89    pub fn num_bits(&self) -> usize {
90        1 << self.num_level()
91    }
92
93    /// Returns binary string representation
94    #[cfg(feature = "std")]
95    pub fn to_binary_string(&self) -> String {
96        format!("{:0width$b}", self.value, width = self.num_bits())
97    }
98
99    /// Splits element into high and low parts.
100    /// For example, if a = xy + y + x, then a = (x + 1)y + x and
101    /// therefore, a_hi = x + 1 and a_lo = x.
102    pub fn split(&self) -> (Self, Self) {
103        let half_bits = self.num_bits() / 2;
104        let mask = (1 << half_bits) - 1;
105        let lo = self.value() & mask;
106        let hi = (self.value() >> half_bits) & mask;
107
108        (
109            Self::new(hi, self.num_level() - 1),
110            Self::new(lo, self.num_level() - 1),
111        )
112    }
113
114    /// Joins the hi and low part making a new element of a bigger level.
115    /// For example, if a_hi = x and a_low = 1
116    /// then a = xy + 1.
117    pub fn join(&self, low: &Self) -> Self {
118        let joined = (self.value() << self.num_bits()) | low.value();
119        Self::new(joined, self.num_level() + 1)
120    }
121
122    // It embeds an element in an extension changing the level number.
123    pub fn extend_num_level(&mut self, new_level: usize) {
124        if self.num_level() < new_level {
125            self.num_level = new_level;
126        }
127    }
128
129    /// Create a zero element
130    pub fn zero() -> Self {
131        Self::new(0, 0)
132    }
133
134    /// Create a one element
135    pub fn one() -> Self {
136        Self::new(1, 0)
137    }
138
139    /// Addition between elements of same or different levels.
140    fn add_elements(&self, other: &Self) -> Self {
141        let num_level = self.num_level().max(other.num_level());
142        Self::new(self.value() ^ other.value(), num_level)
143    }
144
145    // Multiplies a and b in the following way:
146    //
147    // - If a and b are from the same level:
148    // a = a_hi * x_n + a_lo
149    // b = b_hi * x_n + b_lo
150    // Then a * b = (b_hi * a_hi * x_{n-1} + b_hi * a_lo + a_hi * b_lo ) * x_n + b_hi * a_hi + a_lo * b_lo.
151    // We calculate each product in the equation below using recursion.
152    //
153    // - if a's level is larger than b's level, we partition a until we have parts of the size of b and
154    // multiply each part by b.
155    fn mul(self, other: Self) -> Self {
156        match self.num_level().cmp(&other.num_level()) {
157            Ordering::Greater => {
158                // We split a into two parts and call the same method to multiply each part by b.
159                let (a_hi, a_lo) = self.split();
160                // Join a_hi * b and a_lo * b.
161                a_hi.mul(other).join(&a_lo.mul(other))
162            }
163            Ordering::Less => {
164                // If b is larger than a, we swap the arguments and call the same method.
165                other.mul(self)
166            }
167            Ordering::Equal => {
168                // Base case:
169                if self.num_level() == 0 {
170                    // In the binary base field, multiplication is the same as AND operation.
171                    return Self::new(self.value() & other.value(), 0);
172                }
173
174                // Split both elements into high and low parts
175                let (a_high, a_low) = self.split();
176                let (b_high, b_low) = other.split();
177
178                // Step 1: Compute sub-products
179                let low_product = a_low.mul(b_low); // a_low * b_low
180                let high_product = a_high.mul(b_high); // a_high * b_high
181
182                // Step 2: Get the polynomial x_{n-1} value
183                let x_value = if self.num_level() == 1 {
184                    Self::new(1, 0)
185                } else {
186                    Self::new(1 << (self.num_bits() / 4), self.num_level() - 1)
187                };
188
189                // Step 3: Compute high_product * x_{n-1}
190                let shifted_high_product = high_product.mul(x_value);
191
192                // Step 4: Karatsuba optimization for middle term
193                // Instead of computing a_high * b_low + a_low * b_high directly,
194                // we use (a_low + a_high) * (b_low + b_high) - low_product - high_product
195                let sum_product = (a_low + a_high).mul(b_low + b_high);
196                let middle_term = sum_product - low_product - high_product;
197
198                // Step 5: Join the parts according to the tower field multiplication formula
199                (shifted_high_product + middle_term).join(&(high_product + low_product))
200            }
201        }
202    }
203
204    /// Computes the multiplicative inverse using Fermat's little theorem.
205    /// Returns an error if the element is zero.
206    // Based on Ingoyama's implementation
207    // https://github.com/ingonyama-zk/smallfield-super-sumcheck/blob/a8c61beef39bc0c10a8f68d25eeac0a7190a7289/src/tower_fields/binius.rs#L116C5-L116C6
208    pub fn inv(&self) -> Result<Self, BinaryFieldError> {
209        if self.is_zero() {
210            return Err(BinaryFieldError::InverseOfZero);
211        }
212        if self.num_level() <= 1 || self.num_bits() <= 4 {
213            let exponent = (1 << self.num_bits()) - 2;
214            Ok(Self::pow(self, exponent as u32))
215        } else {
216            let (a_hi, a_lo) = self.split();
217            let two_pow_k_minus_one = Self::new(1 << (self.num_bits() / 4), self.num_level() - 1);
218            // a = a_hi * x^k + a_lo
219            // a_lo_next = a_hi * x^(k-1) + a_lo
220            let a_lo_next = a_lo + a_hi * two_pow_k_minus_one;
221
222            // Δ = a_lo * a_lo_next + a_hi^2
223            let delta = a_lo * a_lo_next + a_hi * a_hi;
224
225            // Compute inverse of delta recursively
226            let delta_inverse = delta.inv()?;
227
228            // Compute parts of the inverse
229            let out_hi = delta_inverse * a_hi;
230            let out_lo = delta_inverse * a_lo_next;
231
232            // Join the parts to get the final inverse
233            Ok(out_hi.join(&out_lo))
234        }
235    }
236
237    /// Calculate power.
238    pub fn pow(&self, exp: u32) -> Self {
239        let mut result = Self::one();
240        let mut base = *self;
241        let mut exp_val = exp;
242
243        while exp_val > 0 {
244            if exp_val & 1 == 1 {
245                result *= base;
246            }
247            base = base * base;
248            exp_val >>= 1;
249        }
250
251        result
252    }
253}
254
255impl PartialEq<TowerFieldElement> for TowerFieldElement {
256    fn eq(&self, other: &Self) -> bool {
257        self.value() == other.value()
258    }
259}
260
261impl Eq for TowerFieldElement {}
262
263impl Add for TowerFieldElement {
264    type Output = Self;
265
266    fn add(self, other: Self) -> Self {
267        // Use the helper method that takes references
268        self.add_elements(&other)
269    }
270}
271
272impl<'a> Add<&'a TowerFieldElement> for &'a TowerFieldElement {
273    type Output = TowerFieldElement;
274
275    fn add(self, other: &'a TowerFieldElement) -> TowerFieldElement {
276        // Directly use the helper method
277        self.add_elements(other)
278    }
279}
280
281impl AddAssign for TowerFieldElement {
282    fn add_assign(&mut self, other: Self) {
283        *self = *self + other;
284    }
285}
286#[allow(clippy::suspicious_arithmetic_impl)]
287impl Sub for TowerFieldElement {
288    type Output = Self;
289
290    fn sub(self, other: Self) -> Self {
291        // In binary fields, subtraction is the same as addition
292        self + other
293    }
294}
295
296impl Neg for TowerFieldElement {
297    type Output = Self;
298
299    fn neg(self) -> Self {
300        // In binary fields, negation is the identity
301        self
302    }
303}
304
305impl Mul for TowerFieldElement {
306    type Output = Self;
307
308    fn mul(self, other: Self) -> Self {
309        self.mul(other)
310    }
311}
312
313impl Mul<&TowerFieldElement> for &TowerFieldElement {
314    type Output = TowerFieldElement;
315
316    fn mul(self, other: &TowerFieldElement) -> TowerFieldElement {
317        <TowerFieldElement as Mul<TowerFieldElement>>::mul(*self, *other)
318    }
319}
320
321impl MulAssign for TowerFieldElement {
322    fn mul_assign(&mut self, other: Self) {
323        *self = *self * other;
324    }
325}
326
327impl Product for TowerFieldElement {
328    fn product<I>(iter: I) -> Self
329    where
330        I: Iterator<Item = Self>,
331    {
332        iter.fold(Self::one(), |acc, x| acc * x)
333    }
334}
335
336impl Sum for TowerFieldElement {
337    fn sum<I>(iter: I) -> Self
338    where
339        I: Iterator<Item = Self>,
340    {
341        iter.fold(Self::zero(), |acc, x| acc + x)
342    }
343}
344
345impl fmt::Display for TowerFieldElement {
346    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
347        write!(f, "{}", self.value)
348    }
349}
350
351impl From<u128> for TowerFieldElement {
352    fn from(val: u128) -> Self {
353        TowerFieldElement::new(val, 7)
354    }
355}
356
357impl From<u64> for TowerFieldElement {
358    fn from(val: u64) -> Self {
359        TowerFieldElement::new(val as u128, 6)
360    }
361}
362
363impl From<u32> for TowerFieldElement {
364    fn from(val: u32) -> Self {
365        TowerFieldElement::new(val as u128, 5)
366    }
367}
368
369impl From<u16> for TowerFieldElement {
370    fn from(val: u16) -> Self {
371        TowerFieldElement::new(val as u128, 4)
372    }
373}
374
375impl From<u8> for TowerFieldElement {
376    fn from(val: u8) -> Self {
377        TowerFieldElement::new(val as u128, 3)
378    }
379}
380
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    #[test]
    fn test_new_safe() {
        // Test with level too large
        let elem = TowerFieldElement::new(0, 8);
        assert_eq!(elem.num_level, 7); // Should be capped at 7

        // Test with value too large for level
        let elem = TowerFieldElement::new(4, 1); // Level 1 can only store 0-3
        assert_eq!(elem.value, 0); // Should mask to 0 (100 & 11 = 00)
    }

    #[test]
    fn test_addition() {
        let a = TowerFieldElement::new(5, 9); // Level 9 is capped to 7 (128 bits)
        let b = TowerFieldElement::new(3, 2); // 4 bits

        let c = a + b;
        // 5 (0101) + 3 (0011) should be 6 (0110) at level 7, the larger of both levels
        assert_eq!(c.value, 6);
        assert_eq!(c.num_level, 7);

        // Test commutative property
        let d = b + a;
        assert_eq!(d, c);
    }

    #[test]
    fn mul_in_level_0() {
        let a = TowerFieldElement::new(0, 0);
        let b = TowerFieldElement::new(1, 0);
        assert_eq!(a * a, a);
        assert_eq!(a * b, a);
        assert_eq!(b * b, b);
    }

    #[test]
    fn mul_in_level_1() {
        let a = TowerFieldElement::new(0b00, 1); // 0
        let b = TowerFieldElement::new(0b01, 1); // 1
        let c = TowerFieldElement::new(0b10, 1); // x
        let d = TowerFieldElement::new(0b11, 1); // x + 1
        assert_eq!(a * a, a);
        assert_eq!(a * b, a);
        assert_eq!(b * c, c);
        assert_eq!(c * d, b);
    }

    #[test]
    fn mul_in_level_2() {
        let a = TowerFieldElement::new(0b0000, 2); // 0
        let b = TowerFieldElement::new(0b0001, 2); // 1
        let c = TowerFieldElement::new(0b0010, 2); // x
        let d = TowerFieldElement::new(0b0011, 2); // x + 1
        let e = TowerFieldElement::new(0b0100, 2); // y
        let f = TowerFieldElement::new(0b0101, 2); // y + 1
        let g = TowerFieldElement::new(0b0110, 2); // y + x
        let h = TowerFieldElement::new(0b0111, 2); // y + x + 1
        let i = TowerFieldElement::new(0b1000, 2); // yx
        let j = TowerFieldElement::new(0b1001, 2); // yx + 1
        let k = TowerFieldElement::new(0b1010, 2); // yx + x
        let l = TowerFieldElement::new(0b1011, 2); // yx + x + 1
        let n = TowerFieldElement::new(0b1100, 2); // yx + y
        let m = TowerFieldElement::new(0b1101, 2); // yx + y + 1
        let o = TowerFieldElement::new(0b1110, 2); // yx + y + x
        let p = TowerFieldElement::new(0b1111, 2); // yx + y + x + 1

        assert_eq!(a * p, a); // 0 * (yx + y + x + 1) = 0
        assert_eq!(a * l, a); // 0 * (yx + x + 1) = 0
        assert_eq!(b * m, m); // 1 * (yx + y + 1) = yx + y + 1
        assert_eq!(c * e, i); // x * y = xy
        assert_eq!(c * c, d); // x * x = x + 1
        assert_eq!(g * h, n); //(y + x)(y + x + 1) = yx + y
        assert_eq!(k * j, b); // (yx + x)(yx + 1) = 1
        assert_eq!(j * f, d); // (yx + 1)(y + 1) = x + 1
        assert_eq!(e * e, j); // y * y = yx + 1
        assert_eq!(n * o, k); // (yx + y)(yx + y + x) = yx + x
    }

    #[test]
    fn mul_between_different_levels() {
        let a = TowerFieldElement::new(0b10, 1); // x
        let b = TowerFieldElement::new(0b0100, 2); // y
        let c = TowerFieldElement::new(0b1000, 2); // yx
        assert_eq!(a * b, c);
    }

    #[test]
    fn test_correct_level_mul() {
        let a = TowerFieldElement::new(0b1111, 5);
        let b = TowerFieldElement::new(0b1010, 2);
        assert_eq!((a * b).num_level, 5);
    }

    #[test]
    fn mul_is_associative() {
        let a = TowerFieldElement::new(83, 7);
        let b = TowerFieldElement::new(31, 5);
        let c = TowerFieldElement::new(3, 2);
        let ab = a * b;
        let bc = b * c;
        assert_eq!(ab * c, a * bc);
    }

    #[test]
    fn mul_is_commutative() {
        let a = TowerFieldElement::new(127, 7);
        let b = TowerFieldElement::new(6, 3);
        let ab = a * b;
        let ba = b * a;
        assert_eq!(ab, ba);
    }

    #[test]
    fn test_inverse() {
        let a0 = TowerFieldElement::new(1, 0);
        let inv_a0 = a0.inv().unwrap();
        assert_eq!(inv_a0.value, 1);
        assert_eq!(inv_a0.num_level, 0);

        let a1 = TowerFieldElement::new(2, 1);
        let inv_a1 = a1.inv().unwrap();
        assert_eq!(inv_a1.value, 3); // because 10 * 11 = 01.
        assert_eq!(inv_a1.num_level, 1);

        // Verify a * a^(-1) = 1
        let a2 = TowerFieldElement::new(15, 4);
        let inv_a2 = a2.inv().unwrap();
        let one = TowerFieldElement::new(1, 4);
        assert_eq!(a2 * inv_a2, one);

        let a3 = TowerFieldElement::new(30, 5);
        let inv_a3 = a3.inv().unwrap();
        let one = TowerFieldElement::new(1, 5);
        assert_eq!(a3 * inv_a3, one);

        let zero = TowerFieldElement::zero();
        assert!(matches!(zero.inv(), Err(BinaryFieldError::InverseOfZero)));
    }

    #[test]
    fn test_multiplication_overflow() {
        for level in 0..7 {
            let max_value = (1u128 << (1 << level)) - 1; // Maximum value for this level
            let a = TowerFieldElement::new(max_value, level);
            let b = TowerFieldElement::new(max_value, level);

            let result = a * b;

            // Result should be properly reduced
            assert!(result.value < (1u128 << result.num_bits()));
        }
    }

    #[test]
    fn test_split_join_consistency() {
        // Test that join and split are consistent operations
        for i in 0..20 {
            let original = TowerFieldElement::new(i, 3);
            let (hi, lo) = original.split();
            let rejoined = hi.join(&lo);

            assert_eq!(rejoined, original);
        }
    }
    #[cfg(feature = "std")]
    #[test]
    fn test_bin_representation() {
        let a = TowerFieldElement::new(0b1010, 5);
        assert_eq!(a.to_binary_string(), "00000000000000000000000000001010");
        let b = TowerFieldElement::new(0b1010, 4);
        assert_eq!(b.to_binary_string(), "0000000000001010");
    }

    // Strategy to generate a TowerFieldElement with a random level between 0 and 7.
    // For a given level:
    // - The number of bits is computed as 1 << level.
    // - For level 0, valid values are 0 to (1 << 1) - 1 = 1.
    // - For level > 0, valid values are 0 to (1 << (1 << level)) - 1
    //   (all of u128 at level 7, where the shift would overflow).
    fn arb_tower_element_any() -> impl Strategy<Value = TowerFieldElement> {
        (0usize..=7)
            .prop_flat_map(|level| {
                let max_val = if level == 0 {
                    1
                } else if (1usize << level) >= 128 {
                    u128::MAX
                } else {
                    (1u128 << (1 << level)) - 1
                };
                (Just(level), 0u128..=max_val)
            })
            .prop_map(|(level, val)| TowerFieldElement::new(val, level))
    }

    #[cfg(feature = "std")]
    proptest! {
        // Test that multiplication is commutative:
        // For any two randomly generated elements, a * b should equal b * a.
        #[test]
        fn test_mul_commutative(a in arb_tower_element_any(), b in arb_tower_element_any()) {
            prop_assert_eq!(a * b, b * a);
        }

        // Test that multiplication is associative:
        // For any three randomly generated elements, (a * b) * c should equal a * (b * c).
        #[test]
        fn test_mul_associative(a in arb_tower_element_any(), b in arb_tower_element_any(), c in arb_tower_element_any()) {
            prop_assert_eq!((a * b) * c, a * (b * c));
        }
    }
}