zeropool_bn/
arith.rs

1use rand::Rng;
2use core::cmp::Ordering;
3
4use byteorder::{BigEndian, ByteOrder};
5
6#[cfg(feature = "borsh")]
7use borsh::{BorshDeserialize, BorshSerialize};
8
/// Fixed-width 256-bit unsigned integer used for prime-field arithmetic.
///
/// The value lives on the stack as two `u128` limbs in little-endian order:
/// `self.0[0]` holds bits 0..128 and `self.0[1]` holds bits 128..256.
#[repr(C)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[cfg_attr(feature = "borsh", derive(BorshSerialize, BorshDeserialize))]
pub struct U256(pub [u128; 2]);
15
16impl From<[u64; 4]> for U256 {
17    fn from(d: [u64; 4]) -> Self {
18        let mut a = [0u128; 2];
19        a[0] = (d[1] as u128) << 64 | d[0] as u128;
20        a[1] = (d[3] as u128) << 64 | d[2] as u128;
21        U256(a)
22    }
23}
24
25impl From<u64> for U256 {
26    fn from(d: u64) -> Self {
27        U256::from([d, 0, 0, 0])
28    }
29}
30
/// Fixed-width 512-bit unsigned integer used for extension-field
/// serialization and for interpreting wide scalars.
///
/// Stored on the stack as four little-endian `u128` limbs (`self.0[0]` is the
/// least significant).
#[repr(C)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct U512(pub [u128; 4]);
36
37impl From<[u64; 8]> for U512 {
38    fn from(d: [u64; 8]) -> Self {
39        let mut a = [0u128; 4];
40        a[0] = (d[1] as u128) << 64 | d[0] as u128;
41        a[1] = (d[3] as u128) << 64 | d[2] as u128;
42        a[2] = (d[5] as u128) << 64 | d[4] as u128;
43        a[3] = (d[7] as u128) << 64 | d[6] as u128;
44        U512(a)
45    }
46}
47
48impl U512 {
49    /// Multiplies c1 by modulo, adds c0.
50    pub fn new(c1: &U256, c0: &U256, modulo: &U256) -> U512 {
51        let mut res = [0; 4];
52
53        debug_assert_eq!(c1.0.len(), 2);
54        unroll! {
55            for i in 0..2 {
56                mac_digit(i, &mut res, &modulo.0, c1.0[i]);
57            }
58        }
59
60        let mut carry = 0;
61
62        debug_assert_eq!(res.len(), 4);
63        unroll! {
64            for i in 0..2 {
65                res[i] = adc(res[i], c0.0[i], &mut carry);
66            }
67        }
68
69        unroll! {
70            for i in 0..2 {
71                let (a1, a0) = split_u128(res[i + 2]);
72                let (c, r0) = split_u128(a0 + carry);
73                let (c, r1) = split_u128(a1 + c);
74                carry = c;
75
76                res[i + 2] = combine_u128(r1, r0);
77            }
78        }
79
80        debug_assert!(0 == carry);
81
82        U512(res)
83    }
84
85    pub fn from_slice(s: &[u8]) -> Result<U512, Error> {
86        if s.len() != 64 {
87            return Err(Error::InvalidLength {
88                expected: 32,
89                actual: s.len(),
90            });
91        }
92
93        let mut n = [0; 4];
94        for (l, i) in (0..4).rev().zip((0..4).map(|i| i * 16)) {
95            n[l] = BigEndian::read_u128(&s[i..]);
96        }
97
98        Ok(U512(n))
99    }
100
101    /// Get a random U512
102    pub fn random<R: Rng>(rng: &mut R) -> U512 {
103        U512(rng.gen())
104    }
105
106    pub fn get_bit(&self, n: usize) -> Option<bool> {
107        if n >= 512 {
108            None
109        } else {
110            let part = n / 128;
111            let bit = n - (128 * part);
112
113            Some(self.0[part] & (1 << bit) > 0)
114        }
115    }
116
117    /// Divides self by modulo, returning remainder and, if
118    /// possible, a quotient smaller than the modulus.
119    pub fn divrem(&self, modulo: &U256) -> (Option<U256>, U256) {
120        let mut q = Some(U256::zero());
121        let mut r = U256::zero();
122
123        for i in (0..512).rev() {
124            // NB: modulo's first two bits are always unset
125            // so this will never destroy information
126            mul2(&mut r.0);
127            assert!(r.set_bit(0, self.get_bit(i).unwrap()));
128            if &r >= modulo {
129                sub_noborrow(&mut r.0, &modulo.0);
130                if q.is_some() && !q.as_mut().unwrap().set_bit(i, true) {
131                    q = None
132                }
133            }
134        }
135
136        if q.is_some() && (q.as_ref().unwrap() >= modulo) {
137            (None, r)
138        } else {
139            (q, r)
140        }
141    }
142
143    pub fn interpret(buf: &[u8; 64]) -> U512 {
144        let mut n = [0; 4];
145        for (l, i) in (0..4).rev().zip((0..4).map(|i| i * 16)) {
146            n[l] = BigEndian::read_u128(&buf[i..]);
147        }
148
149        U512(n)
150    }
151}
152
153impl Ord for U512 {
154    #[inline]
155    fn cmp(&self, other: &U512) -> Ordering {
156        for (a, b) in self.0.iter().zip(other.0.iter()).rev() {
157            if *a < *b {
158                return Ordering::Less;
159            } else if *a > *b {
160                return Ordering::Greater;
161            }
162        }
163
164        return Ordering::Equal;
165    }
166}
167
168impl PartialOrd for U512 {
169    #[inline]
170    fn partial_cmp(&self, other: &U512) -> Option<Ordering> {
171        Some(self.cmp(other))
172    }
173}
174
175impl Ord for U256 {
176    #[inline]
177    fn cmp(&self, other: &U256) -> Ordering {
178        for (a, b) in self.0.iter().zip(other.0.iter()).rev() {
179            if *a < *b {
180                return Ordering::Less;
181            } else if *a > *b {
182                return Ordering::Greater;
183            }
184        }
185
186        return Ordering::Equal;
187    }
188}
189
190impl PartialOrd for U256 {
191    #[inline]
192    fn partial_cmp(&self, other: &U256) -> Option<Ordering> {
193        Some(self.cmp(other))
194    }
195}
196
/// Errors produced when converting `U256`/`U512` values to or from byte slices.
#[derive(Debug)]
pub enum Error {
    /// The supplied byte slice had the wrong length.
    InvalidLength {
        /// Number of bytes the conversion requires.
        expected: usize,
        /// Number of bytes that were actually supplied.
        actual: usize,
    },
}
202
203impl U256 {
204    /// Initialize U256 from slice of bytes (big endian)
205    pub fn from_slice(s: &[u8]) -> Result<U256, Error> {
206        if s.len() != 32 {
207            return Err(Error::InvalidLength {
208                expected: 32,
209                actual: s.len(),
210            });
211        }
212
213        let mut n = [0; 2];
214        for (l, i) in (0..2).rev().zip((0..2).map(|i| i * 16)) {
215            n[l] = BigEndian::read_u128(&s[i..]);
216        }
217
218        Ok(U256(n))
219    }
220
221    pub fn to_big_endian(&self, s: &mut [u8]) -> Result<(), Error> {
222        if s.len() != 32 {
223            return Err(Error::InvalidLength {
224                expected: 32,
225                actual: s.len(),
226            });
227        }
228
229        for (l, i) in (0..2).rev().zip((0..2).map(|i| i * 16)) {
230            BigEndian::write_u128(&mut s[i..], self.0[l]);
231        }
232
233        Ok(())
234    }
235
236    #[inline]
237    pub fn zero() -> U256 {
238        U256([0, 0])
239    }
240
241    #[inline]
242    pub fn one() -> U256 {
243        U256([1, 0])
244    }
245
246    /// Produce a random number (mod `modulo`)
247    pub fn random<R: Rng>(rng: &mut R, modulo: &U256) -> U256 {
248        U512::random(rng).divrem(modulo).1
249    }
250
251    pub fn is_zero(&self) -> bool {
252        self.0[0] == 0 && self.0[1] == 0
253    }
254
255    pub fn set_bit(&mut self, n: usize, to: bool) -> bool {
256        if n >= 256 {
257            false
258        } else {
259            let part = n / 128;
260            let bit = n - (128 * part);
261
262            if to {
263                self.0[part] |= 1 << bit;
264            } else {
265                self.0[part] &= !(1 << bit);
266            }
267
268            true
269        }
270    }
271
272    pub fn get_bit(&self, n: usize) -> Option<bool> {
273        if n >= 256 {
274            None
275        } else {
276            let part = n / 128;
277            let bit = n - (128 * part);
278
279            Some(self.0[part] & (1 << bit) > 0)
280        }
281    }
282
283    /// Add `other` to `self` (mod `modulo`)
284    pub fn add(&mut self, other: &U256, modulo: &U256) {
285        add_nocarry(&mut self.0, &other.0);
286
287        if *self >= *modulo {
288            sub_noborrow(&mut self.0, &modulo.0);
289        }
290    }
291
292    /// Subtract `other` from `self` (mod `modulo`)
293    pub fn sub(&mut self, other: &U256, modulo: &U256) {
294        if *self < *other {
295            add_nocarry(&mut self.0, &modulo.0);
296        }
297
298        sub_noborrow(&mut self.0, &other.0);
299    }
300
301    /// Multiply `self` by `other` (mod `modulo`) via the Montgomery
302    /// multiplication method.
303    pub fn mul(&mut self, other: &U256, modulo: &U256, inv: u128) {
304        mul_reduce(&mut self.0, &other.0, &modulo.0, inv);
305
306        if *self >= *modulo {
307            sub_noborrow(&mut self.0, &modulo.0);
308        }
309    }
310
311    /// Turn `self` into its additive inverse (mod `modulo`)
312    pub fn neg(&mut self, modulo: &U256) {
313        if *self > Self::zero() {
314            let mut tmp = modulo.0;
315            sub_noborrow(&mut tmp, &self.0);
316
317            self.0 = tmp;
318        }
319    }
320
321    #[inline]
322    pub fn is_even(&self) -> bool {
323        self.0[0] & 1 == 0
324    }
325
326    /// Turn `self` into its multiplicative inverse (mod `modulo`)
327    pub fn invert(&mut self, modulo: &U256) {
328        // Guajardo Kumar Paar Pelzl
329        // Efficient Software-Implementation of Finite Fields with Applications to Cryptography
330        // Algorithm 16 (BEA for Inversion in Fp)
331
332        let mut u = *self;
333        let mut v = *modulo;
334        let mut b = U256::one();
335        let mut c = U256::zero();
336
337        while u != U256::one() && v != U256::one() {
338            while u.is_even() {
339                div2(&mut u.0);
340
341                if b.is_even() {
342                    div2(&mut b.0);
343                } else {
344                    add_nocarry(&mut b.0, &modulo.0);
345                    div2(&mut b.0);
346                }
347            }
348            while v.is_even() {
349                div2(&mut v.0);
350
351                if c.is_even() {
352                    div2(&mut c.0);
353                } else {
354                    add_nocarry(&mut c.0, &modulo.0);
355                    div2(&mut c.0);
356                }
357            }
358
359            if u >= v {
360                sub_noborrow(&mut u.0, &v.0);
361                b.sub(&c, modulo);
362            } else {
363                sub_noborrow(&mut v.0, &u.0);
364                c.sub(&b, modulo);
365            }
366        }
367
368        if u == U256::one() {
369            self.0 = b.0;
370        } else {
371            self.0 = c.0;
372        }
373    }
374
375    /// Return an Iterator<Item=bool> over all bits from
376    /// MSB to LSB.
377    pub fn bits(&self) -> BitIterator {
378        BitIterator { int: &self, n: 256 }
379    }
380}
381
382pub struct BitIterator<'a> {
383    int: &'a U256,
384    n: usize,
385}
386
387impl<'a> Iterator for BitIterator<'a> {
388    type Item = bool;
389
390    fn next(&mut self) -> Option<bool> {
391        if self.n == 0 {
392            None
393        } else {
394            self.n -= 1;
395
396            self.int.get_bit(self.n)
397        }
398    }
399}
400
/// Halves a 256-bit little-endian value in place (logical right shift by one).
#[inline]
fn div2(a: &mut [u128; 2]) {
    // The low bit of the high limb moves into the top bit of the low limb.
    let carried = (a[1] & 1) << 127;
    a[0] = (a[0] >> 1) | carried;
    a[1] >>= 1;
}
409
/// Doubles a 256-bit little-endian value in place (left shift by one); the
/// top bit of the high limb is discarded.
#[inline]
fn mul2(a: &mut [u128; 2]) {
    // The top bit of the low limb moves into the bottom of the high limb.
    let overflow_bit = a[0] >> 127;
    a[1] = (a[1] << 1) | overflow_bit;
    a[0] <<= 1;
}
418
/// Splits a `u128` into `(high 64 bits, low 64 bits)`, each widened to `u128`.
#[inline(always)]
fn split_u128(i: u128) -> (u128, u128) {
    let hi = i >> 64;
    let lo = i & u128::from(u64::MAX);
    (hi, lo)
}
423
/// Joins two 64-bit halves back into one `u128` (`hi` shifted up, `lo` OR-ed in).
#[inline(always)]
fn combine_u128(hi: u128, lo: u128) -> u128 {
    lo | hi << 64
}
428
429#[inline]
430fn adc(a: u128, b: u128, carry: &mut u128) -> u128 {
431    let (a1, a0) = split_u128(a);
432    let (b1, b0) = split_u128(b);
433    let (c, r0) = split_u128(a0 + b0 + *carry);
434    let (c, r1) = split_u128(a1 + b1 + c);
435    *carry = c;
436
437    combine_u128(r1, r0)
438}
439
440#[inline]
441fn add_nocarry(a: &mut [u128; 2], b: &[u128; 2]) {
442    let mut carry = 0;
443
444    for (a, b) in a.into_iter().zip(b.iter()) {
445        *a = adc(*a, *b, &mut carry);
446    }
447
448    debug_assert!(0 == carry);
449}
450
451#[inline]
452fn sub_noborrow(a: &mut [u128; 2], b: &[u128; 2]) {
453    #[inline]
454    fn sbb(a: u128, b: u128, borrow: &mut u128) -> u128 {
455        let (a1, a0) = split_u128(a);
456        let (b1, b0) = split_u128(b);
457        let (b, r0) = split_u128((1 << 64) + a0 - b0 - *borrow);
458        let (b, r1) = split_u128((1 << 64) + a1 - b1 - ((b == 0) as u128));
459
460        *borrow = (b == 0) as u128;
461
462        combine_u128(r1, r0)
463    }
464
465    let mut borrow = 0;
466
467    for (a, b) in a.into_iter().zip(b.iter()) {
468        *a = sbb(*a, *b, &mut borrow);
469    }
470
471    debug_assert!(0 == borrow);
472}
473
// TODO: Make `from_index` a const param
/// Multiply-accumulate: `acc += b * c * 2^(128 * from_index)`.
///
/// `acc` is a 4-limb (512-bit) little-endian accumulator, `b` a 2-limb value
/// and `c` a single 128-bit limb. The product limbs are added starting at
/// `acc[from_index]` and the remaining carry is rippled into the higher limbs.
/// `from_index` must be 0 or 1 (larger values index past `acc`). A carry out
/// of `acc[3]` is only checked by the trailing `debug_assert!`.
#[inline(always)]
fn mac_digit(from_index: usize, acc: &mut [u128; 4], b: &[u128; 2], c: u128) {
    /// Returns the low 128 bits of `a + b * c + *carry` and stores the high
    /// 128 bits back into `carry`.
    ///
    /// The 128x128-bit product is assembled from 64-bit halves so that no
    /// intermediate sum overflows `u128`.
    #[inline]
    fn mac_with_carry(a: u128, b: u128, c: u128, carry: &mut u128) -> u128 {
        let (b_hi, b_lo) = split_u128(b);
        let (c_hi, c_lo) = split_u128(c);

        let (a_hi, a_lo) = split_u128(a);
        let (carry_hi, carry_lo) = split_u128(*carry);
        // Bits 0..128: low*low partial product plus the low halves of a and carry.
        let (x_hi, x_lo) = split_u128(b_lo * c_lo + a_lo + carry_lo);
        // Cross products, each contributing at bit offset 64.
        let (y_hi, y_lo) = split_u128(b_lo * c_hi);
        let (z_hi, z_lo) = split_u128(b_hi * c_lo);
        // Brackets to allow better ILP
        let (r_hi, r_lo) = split_u128((x_hi + y_lo) + (z_lo + a_hi) + carry_hi);

        // Everything at bit offset >= 128 becomes the outgoing carry.
        *carry = (b_hi * c_hi) + r_hi + y_hi + z_hi;

        combine_u128(r_lo, x_lo)
    }

    // Multiplying by zero adds nothing.
    if c == 0 {
        return;
    }

    let mut carry = 0;

    debug_assert_eq!(acc.len(), 4);
    // Add b[0] * c and b[1] * c at limbs from_index and from_index + 1.
    unroll! {
        for i in 0..2 {
            let a_index = i + from_index;
            acc[a_index] = mac_with_carry(acc[a_index], b[i], c, &mut carry);
        }
    }
    // Ripple the carry into the limbs above the product (only those < 4).
    unroll! {
        for i in 0..2 {
            let a_index = i + from_index + 2;
            if a_index < 4 {
                let (a_hi, a_lo) = split_u128(acc[a_index]);
                let (carry_hi, carry_lo) = split_u128(carry);
                let (x_hi, x_lo) = split_u128(a_lo + carry_lo);
                let (r_hi, r_lo) = split_u128(x_hi + a_hi + carry_hi);

                carry = r_hi;

                acc[a_index] = combine_u128(r_lo, x_lo);
            }
        }
    }

    debug_assert!(carry == 0);
}
526
527#[inline]
528fn mul_reduce(this: &mut [u128; 2], by: &[u128; 2], modulus: &[u128; 2], inv: u128) {
529    // The Montgomery reduction here is based on Algorithm 14.32 in
530    // Handbook of Applied Cryptography
531    // <http://cacr.uwaterloo.ca/hac/about/chap14.pdf>.
532
533    let mut res = [0; 2 * 2];
534    unroll! {
535        for i in 0..2 {
536            mac_digit(i, &mut res, by, this[i]);
537        }
538    }
539
540    unroll! {
541        for i in 0..2 {
542            let k = inv.wrapping_mul(res[i]);
543            mac_digit(i, &mut res, modulus, k);
544        }
545    }
546
547    this.copy_from_slice(&res[2..]);
548}
549
/// Rebuilding a random `U256` bit by bit from `bits()` must reproduce it.
#[test]
fn setting_bits() {
    let modulo = U256::from([0xffffffffffffffff; 4]);
    let rng = &mut ::rand::thread_rng();

    let original = U256::random(rng, &modulo);

    // `bits()` walks MSB -> LSB, so the k-th yielded bit belongs at 255 - k.
    let mut rebuilt = U256::zero();
    original
        .bits()
        .enumerate()
        .for_each(|(k, bit)| assert!(rebuilt.set_bit(255 - k, bit)));

    assert_eq!(original, rebuilt);
}
563
/// A 32-byte big-endian buffer ending in `0x01` must parse as one.
#[test]
fn from_slice() {
    let mut bytes = [0u8; 32];
    *bytes.last_mut().unwrap() = 1;

    let parsed = U256::from_slice(&bytes)
        .expect("U256 should initialize ok from slice in `from_slice` test");

    assert_eq!(parsed, U256::one());
}
574
/// Serializing one must yield 31 zero bytes followed by `0x01`.
#[test]
fn to_big_endian() {
    let mut out = [0u8; 32];
    U256::one()
        .to_big_endian(&mut out)
        .expect("U256 should convert to bytes ok in `to_big_endian` test");

    let mut expected = [0u8; 32];
    expected[31] = 1;
    assert_eq!(out, expected);
}
590
/// Exercises `U512::divrem`:
/// - round-trips with `U512::new` on random inputs;
/// - checks hand-computed boundary values around `modulo` and `modulo^2`;
/// - checks the `None` quotient path for inputs at or above `modulo^2`.
#[test]
fn testing_divrem() {
    let rng = &mut ::rand::thread_rng();

    let modulo = U256::from([
        0x3c208c16d87cfd47,
        0x97816a916871ca8d,
        0xb85045b68181585d,
        0x30644e72e131a029,
    ]);

    // Round trip: divrem must invert U512::new for random c0, c1 < modulo.
    for _ in 0..100 {
        let c0 = U256::random(rng, &modulo);
        let c1 = U256::random(rng, &modulo);

        let c1q_plus_c0 = U512::new(&c1, &c0, &modulo);

        let (new_c1, new_c0) = c1q_plus_c0.divrem(&modulo);

        assert!(c1 == new_c1.unwrap());
        assert!(c0 == new_c0);
    }

    {
        // Modulus should become 1*q + 0
        let a = U512::from([
            0x3c208c16d87cfd47,
            0x97816a916871ca8d,
            0xb85045b68181585d,
            0x30644e72e131a029,
            0,
            0,
            0,
            0,
        ]);

        let (c1, c0) = a.divrem(&modulo);
        assert_eq!(c1.unwrap(), U256::one());
        assert_eq!(c0, U256::zero());
    }

    {
        // Modulus squared minus 1 should be (q-1) q + q-1
        // (the largest input that still yields a quotient).
        let a = U512::from([
            0x3b5458a2275d69b0,
            0xa602072d09eac101,
            0x4a50189c6d96cadc,
            0x04689e957a1242c8,
            0x26edfa5c34c6b38d,
            0xb00b855116375606,
            0x599a6f7c0348d21c,
            0x0925c4b8763cbf9c,
        ]);

        let (c1, c0) = a.divrem(&modulo);
        assert_eq!(
            c1.unwrap(),
            U256::from([
                0x3c208c16d87cfd46,
                0x97816a916871ca8d,
                0xb85045b68181585d,
                0x30644e72e131a029
            ])
        );
        assert_eq!(
            c0,
            U256::from([
                0x3c208c16d87cfd46,
                0x97816a916871ca8d,
                0xb85045b68181585d,
                0x30644e72e131a029
            ])
        );
    }

    {
        // Modulus squared minus 2 should be (q-1) q + q-2
        let a = U512::from([
            0x3b5458a2275d69af,
            0xa602072d09eac101,
            0x4a50189c6d96cadc,
            0x04689e957a1242c8,
            0x26edfa5c34c6b38d,
            0xb00b855116375606,
            0x599a6f7c0348d21c,
            0x0925c4b8763cbf9c,
        ]);

        let (c1, c0) = a.divrem(&modulo);

        assert_eq!(
            c1.unwrap(),
            U256::from([
                0x3c208c16d87cfd46,
                0x97816a916871ca8d,
                0xb85045b68181585d,
                0x30644e72e131a029
            ])
        );
        assert_eq!(
            c0,
            U256::from([
                0x3c208c16d87cfd45,
                0x97816a916871ca8d,
                0xb85045b68181585d,
                0x30644e72e131a029
            ])
        );
    }

    {
        // Ridiculously large number should fail: 2^512 - 1 is far above
        // modulo^2, so no quotient is returned, but the remainder is still
        // computed.
        let a = U512::from([
            0xffffffffffffffff,
            0xffffffffffffffff,
            0xffffffffffffffff,
            0xffffffffffffffff,
            0xffffffffffffffff,
            0xffffffffffffffff,
            0xffffffffffffffff,
            0xffffffffffffffff,
        ]);

        let (c1, c0) = a.divrem(&modulo);
        assert!(c1.is_none());
        assert_eq!(
            c0,
            U256::from([
                0xf32cfc5b538afa88,
                0xb5e71911d44501fb,
                0x47ab1eff0a417ff6,
                0x06d89f71cab8351f
            ])
        );
    }

    {
        // Modulus squared should fail (quotient would equal modulo itself)
        let a = U512::from([
            0x3b5458a2275d69b1,
            0xa602072d09eac101,
            0x4a50189c6d96cadc,
            0x04689e957a1242c8,
            0x26edfa5c34c6b38d,
            0xb00b855116375606,
            0x599a6f7c0348d21c,
            0x0925c4b8763cbf9c,
        ]);

        let (c1, c0) = a.divrem(&modulo);
        assert!(c1.is_none());
        assert_eq!(c0, U256::zero());
    }

    {
        // Modulus squared plus one should fail
        let a = U512::from([
            0x3b5458a2275d69b2,
            0xa602072d09eac101,
            0x4a50189c6d96cadc,
            0x04689e957a1242c8,
            0x26edfa5c34c6b38d,
            0xb00b855116375606,
            0x599a6f7c0348d21c,
            0x0925c4b8763cbf9c,
        ]);

        let (c1, c0) = a.divrem(&modulo);
        assert!(c1.is_none());
        assert_eq!(c0, U256::one());
    }

    {
        let modulo = U256::from([
            0x43e1f593f0000001,
            0x2833e84879b97091,
            0xb85045b68181585d,
            0x30644e72e131a029,
        ]);

        // Fr modulus masked off is valid: with the top five bits of the
        // 512-bit input cleared, a quotient must be returned and both the
        // quotient and the remainder must be reduced.
        let a = U512::from([
            0xffffffffffffffff,
            0xffffffffffffffff,
            0xffffffffffffffff,
            0xffffffffffffffff,
            0xffffffffffffffff,
            0xffffffffffffffff,
            0xffffffffffffffff,
            0x07ffffffffffffff,
        ]);

        let (c1, c0) = a.divrem(&modulo);

        assert!(c1.unwrap() < modulo);
        assert!(c0 < modulo);
    }
}