pg_curve/
scalar.rs

1//! This module provides an implementation of the BLS12-381 scalar field $\mathbb{F}_q$
2//! where `q = 0x73eda753299d7d483339d80809a1d80553bda402fffe5bfeffffffff00000001`
3
4use core::fmt;
5use core::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};
6use rand_core::RngCore;
7
8use ff::{Field, PrimeField};
9use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption};
10
11#[cfg(feature = "bits")]
12use ff::{FieldBits, PrimeFieldBits};
13
14use crate::util::{adc, mac, sbb};
15
16/// Represents an element of the scalar field $\mathbb{F}_q$ of the BLS12-381 elliptic
17/// curve construction.
18// The internal representation of this type is four 64-bit unsigned
19// integers in little-endian order. `Scalar` values are always in
20// Montgomery form; i.e., Scalar(a) = aR mod q, with R = 2^256.
// The arithmetic in this file keeps the limbs fully reduced (below q), and
// equality compares limbs directly (`PartialEq` is hand-written on top of
// `ct_eq`; `Eq` is derived), so equal field elements have identical limbs.
// NOTE(review): the field is `pub(crate)`, so crate code that builds
// `Scalar(..)` directly must uphold both the Montgomery-form and the
// reduced-limb invariants.
21#[derive(Clone, Copy, Eq)]
22pub struct Scalar(pub(crate) [u64; 4]);
23
24impl fmt::Debug for Scalar {
25    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
26        let tmp = self.to_bytes();
27        write!(f, "0x")?;
28        for &b in tmp.iter().rev() {
29            write!(f, "{:02x}", b)?;
30        }
31        Ok(())
32    }
33}
34
35impl fmt::Display for Scalar {
36    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
37        write!(f, "{:?}", self)
38    }
39}
40
41impl From<u64> for Scalar {
42    fn from(val: u64) -> Scalar {
43        Scalar([val, 0, 0, 0]) * R2
44    }
45}
46
47impl ConstantTimeEq for Scalar {
48    fn ct_eq(&self, other: &Self) -> Choice {
49        self.0[0].ct_eq(&other.0[0])
50            & self.0[1].ct_eq(&other.0[1])
51            & self.0[2].ct_eq(&other.0[2])
52            & self.0[3].ct_eq(&other.0[3])
53    }
54}
55
56impl PartialEq for Scalar {
57    #[inline]
58    fn eq(&self, other: &Self) -> bool {
59        bool::from(self.ct_eq(other))
60    }
61}
62
63impl ConditionallySelectable for Scalar {
64    fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
65        Scalar([
66            u64::conditional_select(&a.0[0], &b.0[0], choice),
67            u64::conditional_select(&a.0[1], &b.0[1], choice),
68            u64::conditional_select(&a.0[2], &b.0[2], choice),
69            u64::conditional_select(&a.0[3], &b.0[3], choice),
70        ])
71    }
72}
73
74/// Constant representing the modulus
75/// q = 0x73eda753299d7d483339d80809a1d80553bda402fffe5bfeffffffff00000001
///
/// Unlike other `Scalar` values, this one is NOT in Montgomery form. The limbs
/// are the plain little-endian 64-bit limbs of q, used directly by the
/// carry-chain arithmetic (reduction, canonicity checks, negation).
76pub(crate) const MODULUS: Scalar = Scalar([
77    0xffff_ffff_0000_0001,
78    0x53bd_a402_fffe_5bfe,
79    0x3339_d808_09a1_d805,
80    0x73ed_a753_299d_7d48,
81]);
82
83/// The modulus as u32 limbs.
///
/// The limbs are in little-endian order. This is only compiled for
/// `PrimeFieldBits` on targets without 64-bit pointers.
84#[cfg(all(feature = "bits", not(target_pointer_width = "64")))]
85const MODULUS_LIMBS_32: [u32; 8] = [
86    0x0000_0001,
87    0xffff_ffff,
88    0xfffe_5bfe,
89    0x53bd_a402,
90    0x09a1_d805,
91    0x3339_d808,
92    0x299d_7d48,
93    0x73ed_a753,
94];
95
96// The number of bits needed to represent the modulus.
// The top limb 0x73ed_a753_... has bit 63 clear and bit 62 set, so 2^254 <= q < 2^255.
97const MODULUS_BITS: u32 = 255;
98
99// GENERATOR = 7 (generator of the multiplicative group of order q - 1, and also a quadratic nonresidue)
// Like every `Scalar`, it is stored in Montgomery form, i.e. as 7 * R mod q.
100const GENERATOR: Scalar = Scalar([
101    0x0000_000e_ffff_fff1,
102    0x17e3_63d3_0018_9c0f,
103    0xff9c_5787_6f84_57b0,
104    0x3513_3220_8fc5_a8c4,
105]);
106
107impl<'a> Neg for &'a Scalar {
108    type Output = Scalar;
109
110    #[inline]
111    fn neg(self) -> Scalar {
112        self.neg()
113    }
114}
115
116impl Neg for Scalar {
117    type Output = Scalar;
118
119    #[inline]
120    fn neg(self) -> Scalar {
121        -&self
122    }
123}
124
125impl<'a, 'b> Sub<&'b Scalar> for &'a Scalar {
126    type Output = Scalar;
127
128    #[inline]
129    fn sub(self, rhs: &'b Scalar) -> Scalar {
130        self.sub(rhs)
131    }
132}
133
134impl<'a, 'b> Add<&'b Scalar> for &'a Scalar {
135    type Output = Scalar;
136
137    #[inline]
138    fn add(self, rhs: &'b Scalar) -> Scalar {
139        self.add(rhs)
140    }
141}
142
143impl<'a, 'b> Mul<&'b Scalar> for &'a Scalar {
144    type Output = Scalar;
145
146    #[inline]
147    fn mul(self, rhs: &'b Scalar) -> Scalar {
148        self.mul(rhs)
149    }
150}
151
// `impl_binops_additive!` / `impl_binops_multiplicative!` are defined elsewhere in
// the crate. NOTE(review): they presumably derive the owned / mixed-borrow operator
// impls and the `*Assign` impls from the `&Scalar op &Scalar` impls above. This
// file relies on forms such as `Scalar * Scalar`, `Scalar + &Scalar` and `tmp *= &R2`.
152impl_binops_additive!(Scalar, Scalar);
153impl_binops_multiplicative!(Scalar, Scalar);
154
155/// INV = -(q^{-1} mod 2^64) mod 2^64
///
/// Used by `montgomery_reduce`. Choosing k = r_i * INV (mod 2^64) makes
/// r_i + k * q divisible by 2^64, which clears one limb per round.
/// The value is recomputed and checked in `test_inv`.
156const INV: u64 = 0xffff_fffe_ffff_ffff;
157
158/// R = 2^256 mod q
///
/// Because values are kept in Montgomery form, this is also the stored
/// representation of 1 (see `Scalar::one`).
159const R: Scalar = Scalar([
160    0x0000_0001_ffff_fffe,
161    0x5884_b7fa_0003_4802,
162    0x998c_4fef_ecbc_4ff5,
163    0x1824_b159_acc5_056f,
164]);
165
166/// R^2 = 2^512 mod q
///
/// Montgomery-multiplying a raw (non-Montgomery) integer `a` by `R2` gives
/// (a * R^2) / R = a * R, i.e. converts it into Montgomery form.
167const R2: Scalar = Scalar([
168    0xc999_e990_f3f2_9c6d,
169    0x2b6c_edcb_8792_5c23,
170    0x05d3_1496_7254_398f,
171    0x0748_d9d9_9f59_ff11,
172]);
173
174/// R^3 = 2^768 mod q
///
/// Used by `from_u512` to lift the upper 256-bit half of a 512-bit input,
/// which carries an extra factor of 2^256 = R.
175const R3: Scalar = Scalar([
176    0xc62c_1807_439b_73af,
177    0x1b3e_0d18_8cf0_6990,
178    0x73d1_3c71_c7b5_f418,
179    0x6e2a_5bb9_c8db_33e9,
180]);
181
182/// 2^-1
///
/// The inverse of 2 modulo q, in Montgomery form. `test_constants` checks
/// that 2 * TWO_INV == 1.
183const TWO_INV: Scalar = Scalar([
184    0x0000_0000_ffff_ffff,
185    0xac42_5bfd_0001_a401,
186    0xccc6_27f7_f65e_27fa,
187    0x0c12_58ac_d662_82b7,
188]);
189
190// 2^S * t = MODULUS - 1 with t odd
// The low limb of q - 1 is 0xffff_ffff_0000_0000, i.e. exactly 32 trailing zero bits.
191const S: u32 = 32;
192
193/// GENERATOR^t where t * 2^s + 1 = q
194/// with t odd. In other words, this
195/// is a 2^s root of unity.
196///
197/// `GENERATOR = 7 mod q` is a generator
198/// of the q - 1 order multiplicative
199/// subgroup.
///
/// Stored in Montgomery form. `test_constants` checks ROOT_OF_UNITY^(2^S) == 1.
200const ROOT_OF_UNITY: Scalar = Scalar([
201    0xb9b5_8d8c_5f0e_466a,
202    0x5b1b_4c80_1819_d7ec,
203    0x0af5_3ae3_52a3_1e64,
204    0x5bf3_adda_19e9_b27b,
205]);
206
207/// ROOT_OF_UNITY^-1
///
/// `test_constants` checks ROOT_OF_UNITY * ROOT_OF_UNITY_INV == 1.
208const ROOT_OF_UNITY_INV: Scalar = Scalar([
209    0x4256_481a_dcf3_219a,
210    0x45f3_7b7f_96b6_cad3,
211    0xf9c3_f1d7_5f7a_3b27,
212    0x2d2f_c049_658a_fd43,
213]);
214
215/// GENERATOR^{2^s} where t * 2^s + 1 = q with t odd.
216/// In other words, this is a t root of unity.
///
/// `test_constants` checks DELTA^t == 1.
217const DELTA: Scalar = Scalar([
218    0x70e3_10d3_d146_f96a,
219    0x4b64_c089_19e2_99e6,
220    0x51e1_1418_6a8b_970d,
221    0x6185_d066_27c0_67cb,
222]);
223
224impl Default for Scalar {
225    #[inline]
226    fn default() -> Self {
227        Self::zero()
228    }
229}
230
// Opt-in `zeroize` support. The marker trait tells the `zeroize` crate that
// overwriting a `Scalar` with `Default::default()` (the all-zero scalar) wipes it.
231#[cfg(feature = "zeroize")]
232impl zeroize::DefaultIsZeroes for Scalar {}
233
234impl Scalar {
235    /// Returns zero, the additive identity.
    ///
    /// Zero is its own Montgomery form (0 * R = 0), so every limb is zero.
236    #[inline]
237    pub const fn zero() -> Scalar {
238        Scalar([0, 0, 0, 0])
239    }
240
241    /// Returns one, the multiplicative identity.
    ///
    /// Stored as `R`, the Montgomery form of 1 (1 * R mod q).
242    #[inline]
243    pub const fn one() -> Scalar {
244        R
245    }
246
247    /// Doubles this field element.
    ///
    /// Computed as `self + self`, so the result is fully reduced modulo q.
248    #[inline]
249    pub const fn double(&self) -> Scalar {
250        // TODO: This can be achieved more efficiently with a bitshift.
251        self.add(self)
252    }
253
254    /// Attempts to convert a little-endian byte representation of
255    /// a scalar into a `Scalar`, failing if the input is not canonical.
    ///
    /// "Canonical" means the 256-bit integer is strictly less than q. The
    /// range check and the Montgomery conversion both run unconditionally
    /// (no branch on the input). Validity is reported only through the
    /// returned `CtOption`.
256    pub fn from_bytes(bytes: &[u8; 32]) -> CtOption<Scalar> {
257        let mut tmp = Scalar([0, 0, 0, 0]);
258
        // Each slice is exactly 8 bytes long, so these `unwrap`s cannot fail.
259        tmp.0[0] = u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[0..8]).unwrap());
260        tmp.0[1] = u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[8..16]).unwrap());
261        tmp.0[2] = u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[16..24]).unwrap());
262        tmp.0[3] = u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[24..32]).unwrap());
263
264        // Try to subtract the modulus
        // (the differences are discarded; only the final borrow is needed).
265        let (_, borrow) = sbb(tmp.0[0], MODULUS.0[0], 0);
266        let (_, borrow) = sbb(tmp.0[1], MODULUS.0[1], borrow);
267        let (_, borrow) = sbb(tmp.0[2], MODULUS.0[2], borrow);
268        let (_, borrow) = sbb(tmp.0[3], MODULUS.0[3], borrow);
269
270        // If the element is smaller than MODULUS then the
271        // subtraction will underflow, producing a borrow value
272        // of 0xffff...ffff. Otherwise, it'll be zero.
273        let is_some = (borrow as u8) & 1;
274
275        // Convert to Montgomery form by computing
276        // (a.R^0 * R^2) / R = a.R
277        tmp *= &R2;
278
279        CtOption::new(tmp, Choice::from(is_some))
280    }
281
282    /// Converts an element of `Scalar` into a byte representation in
283    /// little-endian byte order.
    ///
    /// The output is the canonical integer value (Montgomery factor removed),
    /// which is the encoding `from_bytes` accepts.
284    pub fn to_bytes(&self) -> [u8; 32] {
285        // Turn into canonical form by computing
286        // (a.R) / R = a
287        let tmp = Scalar::montgomery_reduce(self.0[0], self.0[1], self.0[2], self.0[3], 0, 0, 0, 0);
288
289        let mut res = [0; 32];
290        res[0..8].copy_from_slice(&tmp.0[0].to_le_bytes());
291        res[8..16].copy_from_slice(&tmp.0[1].to_le_bytes());
292        res[16..24].copy_from_slice(&tmp.0[2].to_le_bytes());
293        res[24..32].copy_from_slice(&tmp.0[3].to_le_bytes());
294
295        res
296    }
297
298    /// Converts a 512-bit little endian integer into
299    /// a `Scalar` by reducing by the modulus.
    ///
    /// The 64 bytes are read as eight little-endian `u64` limbs and passed to
    /// `from_u512`. Any input is accepted; this never fails.
300    pub fn from_bytes_wide(bytes: &[u8; 64]) -> Scalar {
301        Scalar::from_u512([
302            u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[0..8]).unwrap()),
303            u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[8..16]).unwrap()),
304            u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[16..24]).unwrap()),
305            u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[24..32]).unwrap()),
306            u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[32..40]).unwrap()),
307            u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[40..48]).unwrap()),
308            u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[48..56]).unwrap()),
309            u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[56..64]).unwrap()),
310        ])
311    }
312
    /// Reduces a 512-bit integer, given as eight little-endian limbs, into a
    /// `Scalar` in Montgomery form.
313    fn from_u512(limbs: [u64; 8]) -> Scalar {
314        // We reduce an arbitrary 512-bit number by decomposing it into two 256-bit digits
315        // with the higher bits multiplied by 2^256. Thus, we perform two reductions
316        //
317        // 1. the lower bits are multiplied by R^2, as normal
318        // 2. the upper bits are multiplied by R^2 * 2^256 = R^3
319        //
320        // and computing their sum in the field. It remains to see that arbitrary 256-bit
321        // numbers can be placed into Montgomery form safely using the reduction. The
322        // reduction works so long as the product is less than R=2^256 multiplied by
323        // the modulus. This holds because for any `c` smaller than the modulus, we have
324        // that (2^256 - 1)*c is an acceptable product for the reduction. Therefore, the
325        // reduction always works so long as `c` is in the field; in this case it is either the
326        // constant `R2` or `R3`.
327        let d0 = Scalar([limbs[0], limbs[1], limbs[2], limbs[3]]);
328        let d1 = Scalar([limbs[4], limbs[5], limbs[6], limbs[7]]);
329        // Convert to Montgomery form
330        d0 * R2 + d1 * R3
331    }
332
333    /// Converts from an integer represented in little endian
334    /// into its (congruent) `Scalar` representation.
    ///
    /// The limbs are Montgomery-multiplied by `R2`, which both reduces the
    /// value modulo q and converts it to Montgomery form. The input does not
    /// have to be below q: by the bound argued in `from_u512`, any 256-bit
    /// value times `R2` is a valid input to the reduction.
335    pub const fn from_raw(val: [u64; 4]) -> Self {
336        (&Scalar(val)).mul(&R2)
337    }
338
339    /// Squares this element.
    ///
    /// Uses the squaring shortcut. Each cross product a_i * a_j (i < j) is
    /// computed once, and the partial sum is doubled with a one-bit left
    /// shift across the limbs. The diagonal squares a_i^2 are then added in
    /// before the Montgomery reduction.
340    #[inline]
341    pub const fn square(&self) -> Scalar {
        // Cross products a_i * a_j for i < j.
342        let (r1, carry) = mac(0, self.0[0], self.0[1], 0);
343        let (r2, carry) = mac(0, self.0[0], self.0[2], carry);
344        let (r3, r4) = mac(0, self.0[0], self.0[3], carry);
345
346        let (r3, carry) = mac(r3, self.0[1], self.0[2], 0);
347        let (r4, r5) = mac(r4, self.0[1], self.0[3], carry);
348
349        let (r5, r6) = mac(r5, self.0[2], self.0[3], 0);
350
        // Double the cross products: shift the 512-bit value r1..r6 left by one bit.
351        let r7 = r6 >> 63;
352        let r6 = (r6 << 1) | (r5 >> 63);
353        let r5 = (r5 << 1) | (r4 >> 63);
354        let r4 = (r4 << 1) | (r3 >> 63);
355        let r3 = (r3 << 1) | (r2 >> 63);
356        let r2 = (r2 << 1) | (r1 >> 63);
357        let r1 = r1 << 1;
358
        // Add the diagonal squares a_i^2 into limbs 2i and 2i + 1.
359        let (r0, carry) = mac(0, self.0[0], self.0[0], 0);
360        let (r1, carry) = adc(0, r1, carry);
361        let (r2, carry) = mac(r2, self.0[1], self.0[1], carry);
362        let (r3, carry) = adc(0, r3, carry);
363        let (r4, carry) = mac(r4, self.0[2], self.0[2], carry);
364        let (r5, carry) = adc(0, r5, carry);
365        let (r6, carry) = mac(r6, self.0[3], self.0[3], carry);
366        let (r7, _) = adc(0, r7, carry);
367
368        Scalar::montgomery_reduce(r0, r1, r2, r3, r4, r5, r6, r7)
369    }
370
371    /// Exponentiates `self` by `by`, where `by` is a
372    /// little-endian order integer exponent.
    ///
    /// Bits are processed from most to least significant. A multiplication
    /// is performed for every bit and then kept or discarded with
    /// `conditional_assign`, so the sequence of field operations does not
    /// depend on the exponent's bits.
373    pub fn pow(&self, by: &[u64; 4]) -> Self {
374        let mut res = Self::one();
375        for e in by.iter().rev() {
376            for i in (0..64).rev() {
377                res = res.square();
378                let mut tmp = res;
379                tmp *= self;
380                res.conditional_assign(&tmp, (((*e >> i) & 0x1) as u8).into());
381            }
382        }
383        res
384    }
385
386    /// Exponentiates `self` by `by`, where `by` is a
387    /// little-endian order integer exponent.
388    ///
389    /// **This operation is variable time with respect
390    /// to the exponent.** If the exponent is fixed,
391    /// this operation is effectively constant time.
    ///
    /// Unlike `pow`, the multiplication is skipped entirely for zero bits.
392    pub fn pow_vartime(&self, by: &[u64; 4]) -> Self {
393        let mut res = Self::one();
394        for e in by.iter().rev() {
395            for i in (0..64).rev() {
396                res = res.square();
397
398                if ((*e >> i) & 1) == 1 {
399                    res.mul_assign(self);
400                }
401            }
402        }
403        res
404    }
405
406    /// Computes the multiplicative inverse of this element,
407    /// failing if the element is zero.
    ///
    /// Evaluates a fixed addition chain of squarings and multiplications
    /// (generated with the `addchain` tool, per the comment below), so the
    /// same operations run for every input.
    /// NOTE(review): the chain presumably computes self^(q - 2) (Fermat's
    /// little theorem); confirm against the addchain output.
    /// For zero the chain evaluates to zero, and the returned `CtOption` is none.
408    pub fn invert(&self) -> CtOption<Self> {
        /// Squares `n` in place `num_times` times, i.e. n <- n^(2^num_times).
409        #[inline(always)]
410        fn square_assign_multi(n: &mut Scalar, num_times: usize) {
411            for _ in 0..num_times {
412                *n = n.square();
413            }
414        }
415        // found using https://github.com/kwantam/addchain
416        let mut t0 = self.square();
417        let mut t1 = t0 * self;
418        let mut t16 = t0.square();
419        let mut t6 = t16.square();
420        let mut t5 = t6 * t0;
421        t0 = t6 * t16;
422        let mut t12 = t5 * t16;
423        let mut t2 = t6.square();
424        let mut t7 = t5 * t6;
425        let mut t15 = t0 * t5;
426        let mut t17 = t12.square();
427        t1 *= t17;
428        let mut t3 = t7 * t2;
429        let t8 = t1 * t17;
430        let t4 = t8 * t2;
431        let t9 = t8 * t7;
432        t7 = t4 * t5;
433        let t11 = t4 * t17;
434        t5 = t9 * t17;
435        let t14 = t7 * t15;
436        let t13 = t11 * t12;
437        t12 = t11 * t17;
438        t15 *= &t12;
439        t16 *= &t15;
440        t3 *= &t16;
441        t17 *= &t3;
442        t0 *= &t17;
443        t6 *= &t0;
444        t2 *= &t6;
445        square_assign_multi(&mut t0, 8);
446        t0 *= &t17;
447        square_assign_multi(&mut t0, 9);
448        t0 *= &t16;
449        square_assign_multi(&mut t0, 9);
450        t0 *= &t15;
451        square_assign_multi(&mut t0, 9);
452        t0 *= &t15;
453        square_assign_multi(&mut t0, 7);
454        t0 *= &t14;
455        square_assign_multi(&mut t0, 7);
456        t0 *= &t13;
457        square_assign_multi(&mut t0, 10);
458        t0 *= &t12;
459        square_assign_multi(&mut t0, 9);
460        t0 *= &t11;
461        square_assign_multi(&mut t0, 8);
462        t0 *= &t8;
463        square_assign_multi(&mut t0, 8);
464        t0 *= self;
465        square_assign_multi(&mut t0, 14);
466        t0 *= &t9;
467        square_assign_multi(&mut t0, 10);
468        t0 *= &t8;
469        square_assign_multi(&mut t0, 15);
470        t0 *= &t7;
471        square_assign_multi(&mut t0, 10);
472        t0 *= &t6;
473        square_assign_multi(&mut t0, 8);
474        t0 *= &t5;
475        square_assign_multi(&mut t0, 16);
476        t0 *= &t3;
477        square_assign_multi(&mut t0, 8);
478        t0 *= &t2;
479        square_assign_multi(&mut t0, 7);
480        t0 *= &t4;
481        square_assign_multi(&mut t0, 9);
482        t0 *= &t2;
483        square_assign_multi(&mut t0, 8);
484        t0 *= &t3;
485        square_assign_multi(&mut t0, 8);
486        t0 *= &t2;
487        square_assign_multi(&mut t0, 8);
488        t0 *= &t2;
489        square_assign_multi(&mut t0, 8);
490        t0 *= &t2;
491        square_assign_multi(&mut t0, 8);
492        t0 *= &t3;
493        square_assign_multi(&mut t0, 8);
494        t0 *= &t2;
495        square_assign_multi(&mut t0, 8);
496        t0 *= &t2;
497        square_assign_multi(&mut t0, 5);
498        t0 *= &t1;
499        square_assign_multi(&mut t0, 5);
500        t0 *= &t1;
501
502        CtOption::new(t0, !self.ct_eq(&Self::zero()))
503    }
504
    /// Montgomery reduction. Takes a 512-bit value T as little-endian limbs
    /// r0..r7 and returns T * R^{-1} mod q.
    ///
    /// Each of the four rounds picks k = r_i * INV (mod 2^64), so adding
    /// k * q zeroes limb r_i, which is then dropped. `carry2` holds the
    /// overflow out of each round's top limb and is folded into the next
    /// round. Requires T < q * R, so that the single conditional subtraction
    /// at the end yields a value below q.
505    #[inline(always)]
506    const fn montgomery_reduce(
507        r0: u64,
508        r1: u64,
509        r2: u64,
510        r3: u64,
511        r4: u64,
512        r5: u64,
513        r6: u64,
514        r7: u64,
515    ) -> Self {
516        // The Montgomery reduction here is based on Algorithm 14.32 in
517        // Handbook of Applied Cryptography
518        // <http://cacr.uwaterloo.ca/hac/about/chap14.pdf>.
519
        // Round 1: clear r0.
520        let k = r0.wrapping_mul(INV);
521        let (_, carry) = mac(r0, k, MODULUS.0[0], 0);
522        let (r1, carry) = mac(r1, k, MODULUS.0[1], carry);
523        let (r2, carry) = mac(r2, k, MODULUS.0[2], carry);
524        let (r3, carry) = mac(r3, k, MODULUS.0[3], carry);
525        let (r4, carry2) = adc(r4, 0, carry);
526
        // Round 2: clear r1.
527        let k = r1.wrapping_mul(INV);
528        let (_, carry) = mac(r1, k, MODULUS.0[0], 0);
529        let (r2, carry) = mac(r2, k, MODULUS.0[1], carry);
530        let (r3, carry) = mac(r3, k, MODULUS.0[2], carry);
531        let (r4, carry) = mac(r4, k, MODULUS.0[3], carry);
532        let (r5, carry2) = adc(r5, carry2, carry);
533
        // Round 3: clear r2.
534        let k = r2.wrapping_mul(INV);
535        let (_, carry) = mac(r2, k, MODULUS.0[0], 0);
536        let (r3, carry) = mac(r3, k, MODULUS.0[1], carry);
537        let (r4, carry) = mac(r4, k, MODULUS.0[2], carry);
538        let (r5, carry) = mac(r5, k, MODULUS.0[3], carry);
539        let (r6, carry2) = adc(r6, carry2, carry);
540
        // Round 4: clear r3.
541        let k = r3.wrapping_mul(INV);
542        let (_, carry) = mac(r3, k, MODULUS.0[0], 0);
543        let (r4, carry) = mac(r4, k, MODULUS.0[1], carry);
544        let (r5, carry) = mac(r5, k, MODULUS.0[2], carry);
545        let (r6, carry) = mac(r6, k, MODULUS.0[3], carry);
546        let (r7, _) = adc(r7, carry2, carry);
547
548        // Result may be within MODULUS of the correct value
        // (i.e. below 2q). `sub` subtracts q and adds it back on underflow.
549        (&Scalar([r4, r5, r6, r7])).sub(&MODULUS)
550    }
551
552    /// Multiplies `rhs` by `self`, returning the result.
    ///
    /// Forms the full 512-bit schoolbook product of the limbs, then applies
    /// `montgomery_reduce`. For Montgomery-form inputs aR and bR this yields abR.
553    #[inline]
554    pub const fn mul(&self, rhs: &Self) -> Self {
555        // Schoolbook multiplication
        // (row i accumulates self.0[i] * rhs into limbs r_i..r_{i+4}).
556
557        let (r0, carry) = mac(0, self.0[0], rhs.0[0], 0);
558        let (r1, carry) = mac(0, self.0[0], rhs.0[1], carry);
559        let (r2, carry) = mac(0, self.0[0], rhs.0[2], carry);
560        let (r3, r4) = mac(0, self.0[0], rhs.0[3], carry);
561
562        let (r1, carry) = mac(r1, self.0[1], rhs.0[0], 0);
563        let (r2, carry) = mac(r2, self.0[1], rhs.0[1], carry);
564        let (r3, carry) = mac(r3, self.0[1], rhs.0[2], carry);
565        let (r4, r5) = mac(r4, self.0[1], rhs.0[3], carry);
566
567        let (r2, carry) = mac(r2, self.0[2], rhs.0[0], 0);
568        let (r3, carry) = mac(r3, self.0[2], rhs.0[1], carry);
569        let (r4, carry) = mac(r4, self.0[2], rhs.0[2], carry);
570        let (r5, r6) = mac(r5, self.0[2], rhs.0[3], carry);
571
572        let (r3, carry) = mac(r3, self.0[3], rhs.0[0], 0);
573        let (r4, carry) = mac(r4, self.0[3], rhs.0[1], carry);
574        let (r5, carry) = mac(r5, self.0[3], rhs.0[2], carry);
575        let (r6, r7) = mac(r6, self.0[3], rhs.0[3], carry);
576
577        Scalar::montgomery_reduce(r0, r1, r2, r3, r4, r5, r6, r7)
578    }
579
580    /// Subtracts `rhs` from `self`, returning the result.
    ///
    /// Expects both operands below q. The raw difference then lies in
    /// (-q, q), and adding q back on underflow keeps the result below q.
581    #[inline]
582    pub const fn sub(&self, rhs: &Self) -> Self {
583        let (d0, borrow) = sbb(self.0[0], rhs.0[0], 0);
584        let (d1, borrow) = sbb(self.0[1], rhs.0[1], borrow);
585        let (d2, borrow) = sbb(self.0[2], rhs.0[2], borrow);
586        let (d3, borrow) = sbb(self.0[3], rhs.0[3], borrow);
587
588        // If underflow occurred on the final limb, borrow = 0xfff...fff, otherwise
589        // borrow = 0x000...000. Thus, we use it as a mask to conditionally add the modulus.
590        let (d0, carry) = adc(d0, MODULUS.0[0] & borrow, 0);
591        let (d1, carry) = adc(d1, MODULUS.0[1] & borrow, carry);
592        let (d2, carry) = adc(d2, MODULUS.0[2] & borrow, carry);
593        let (d3, _) = adc(d3, MODULUS.0[3] & borrow, carry);
594
595        Scalar([d0, d1, d2, d3])
596    }
597
598    /// Adds `rhs` to `self`, returning the result.
    ///
    /// Expects both operands below q. Since q < 2^255, the 256-bit sum cannot
    /// overflow (the final carry is dropped), and one conditional subtraction
    /// of q is enough.
599    #[inline]
600    pub const fn add(&self, rhs: &Self) -> Self {
601        let (d0, carry) = adc(self.0[0], rhs.0[0], 0);
602        let (d1, carry) = adc(self.0[1], rhs.0[1], carry);
603        let (d2, carry) = adc(self.0[2], rhs.0[2], carry);
604        let (d3, _) = adc(self.0[3], rhs.0[3], carry);
605
606        // Attempt to subtract the modulus, to ensure the value
607        // is smaller than the modulus.
608        (&Scalar([d0, d1, d2, d3])).sub(&MODULUS)
609    }
610
611    /// Negates `self`.
612    #[inline]
613    pub const fn neg(&self) -> Self {
614        // Subtract `self` from `MODULUS` to negate. Ignore the final
615        // borrow because it cannot underflow; self is guaranteed to
616        // be in the field.
617        let (d0, borrow) = sbb(MODULUS.0[0], self.0[0], 0);
618        let (d1, borrow) = sbb(MODULUS.0[1], self.0[1], borrow);
619        let (d2, borrow) = sbb(MODULUS.0[2], self.0[2], borrow);
620        let (d3, _) = sbb(MODULUS.0[3], self.0[3], borrow);
621
622        // The difference (d0..d3) could be `MODULUS` if `self` was zero. Create a mask that is
623        // zero if `self` was zero, and `u64::max_value()` if self was nonzero.
624        let mask = (((self.0[0] | self.0[1] | self.0[2] | self.0[3]) == 0) as u64).wrapping_sub(1);
625
626        Scalar([d0 & mask, d1 & mask, d2 & mask, d3 & mask])
627    }
628}
629
630impl From<Scalar> for [u8; 32] {
631    fn from(value: Scalar) -> [u8; 32] {
632        value.to_bytes()
633    }
634}
635
636impl<'a> From<&'a Scalar> for [u8; 32] {
637    fn from(value: &'a Scalar) -> [u8; 32] {
638        value.to_bytes()
639    }
640}
641
642impl Field for Scalar {
643    const ZERO: Self = Self::zero();
644    const ONE: Self = Self::one();
645
646    fn random(mut rng: impl RngCore) -> Self {
647        let mut buf = [0; 64];
648        rng.fill_bytes(&mut buf);
649        Self::from_bytes_wide(&buf)
650    }
651
652    #[must_use]
653    fn square(&self) -> Self {
654        self.square()
655    }
656
657    #[must_use]
658    fn double(&self) -> Self {
659        self.double()
660    }
661
662    fn invert(&self) -> CtOption<Self> {
663        self.invert()
664    }
665
666    fn sqrt_ratio(num: &Self, div: &Self) -> (Choice, Self) {
667        ff::helpers::sqrt_ratio_generic(num, div)
668    }
669
670    fn sqrt(&self) -> CtOption<Self> {
671        // (t - 1) // 2 = 6104339283789297388802252303364915521546564123189034618274734669823
672        ff::helpers::sqrt_tonelli_shanks(
673            self,
674            &[
675                0x7fff_2dff_7fff_ffff,
676                0x04d0_ec02_a9de_d201,
677                0x94ce_bea4_199c_ec04,
678                0x0000_0000_39f6_d3a9,
679            ],
680        )
681    }
682
683    fn is_zero_vartime(&self) -> bool {
684        self.0 == Self::zero().0
685    }
686}
687
688impl PrimeField for Scalar {
689    type Repr = [u8; 32];
690
691    fn from_repr(r: Self::Repr) -> CtOption<Self> {
692        Self::from_bytes(&r)
693    }
694
695    fn to_repr(&self) -> Self::Repr {
696        self.to_bytes()
697    }
698
699    fn is_odd(&self) -> Choice {
700        Choice::from(self.to_bytes()[0] & 1)
701    }
702
703    const MODULUS: &'static str =
704        "0x73eda753299d7d483339d80809a1d80553bda402fffe5bfeffffffff00000001";
705    const NUM_BITS: u32 = MODULUS_BITS;
706    const CAPACITY: u32 = Self::NUM_BITS - 1;
707    const TWO_INV: Self = TWO_INV;
708    const MULTIPLICATIVE_GENERATOR: Self = GENERATOR;
709    const S: u32 = S;
710    const ROOT_OF_UNITY: Self = ROOT_OF_UNITY;
711    const ROOT_OF_UNITY_INV: Self = ROOT_OF_UNITY_INV;
712    const DELTA: Self = DELTA;
713}
714
// Limb layout of the little-endian bit representation exposed through
// `PrimeFieldBits`: eight `u32` limbs unless the target has 64-bit pointers,
// in which case it is four `u64` limbs (the same layout as `Scalar` itself).
715#[cfg(all(feature = "bits", not(target_pointer_width = "64")))]
716type ReprBits = [u32; 8];
717
718#[cfg(all(feature = "bits", target_pointer_width = "64"))]
719type ReprBits = [u64; 4];
720
#[cfg(feature = "bits")]
impl PrimeFieldBits for Scalar {
    type ReprBits = ReprBits;

    /// Little-endian bits of the canonical (non-Montgomery) value, packed
    /// into limbs of the platform's preferred width.
    fn to_le_bits(&self) -> FieldBits<Self::ReprBits> {
        let bytes = self.to_bytes();

        #[cfg(not(target_pointer_width = "64"))]
        let limbs = {
            let mut packed = [0u32; 8];
            for (limb, chunk) in packed.iter_mut().zip(bytes.chunks_exact(4)) {
                *limb = u32::from_le_bytes(chunk.try_into().unwrap());
            }
            packed
        };

        #[cfg(target_pointer_width = "64")]
        let limbs = {
            let mut packed = [0u64; 4];
            for (limb, chunk) in packed.iter_mut().zip(bytes.chunks_exact(8)) {
                *limb = u64::from_le_bytes(chunk.try_into().unwrap());
            }
            packed
        };

        FieldBits::new(limbs)
    }

    /// Little-endian bits of the field characteristic q.
    fn char_le_bits() -> FieldBits<Self::ReprBits> {
        #[cfg(not(target_pointer_width = "64"))]
        let modulus_limbs: Self::ReprBits = MODULUS_LIMBS_32;

        #[cfg(target_pointer_width = "64")]
        let modulus_limbs: Self::ReprBits = MODULUS.0;

        FieldBits::new(modulus_limbs)
    }
}
761
762impl<T> core::iter::Sum<T> for Scalar
763where
764    T: core::borrow::Borrow<Scalar>,
765{
766    fn sum<I>(iter: I) -> Self
767    where
768        I: Iterator<Item = T>,
769    {
770        iter.fold(Self::zero(), |acc, item| acc + item.borrow())
771    }
772}
773
774impl<'a> From<&'a Scalar> for [u64; 4] {
775    fn from(value: &'a Scalar) -> [u64; 4] {
776        let res =
777            Scalar::montgomery_reduce(value.0[0], value.0[1], value.0[2], value.0[3], 0, 0, 0, 0);
778        res.0
779    }
780}
781
782impl<T> core::iter::Product<T> for Scalar
783where
784    T: core::borrow::Borrow<Scalar>,
785{
786    fn product<I>(iter: I) -> Self
787    where
788        I: Iterator<Item = T>,
789    {
790        iter.fold(Self::one(), |acc, item| acc * item.borrow())
791    }
792}
793
#[test]
fn test_constants() {
    assert_eq!(
        Scalar::MODULUS,
        "0x73eda753299d7d483339d80809a1d80553bda402fffe5bfeffffffff00000001",
    );

    // 2 * 2^{-1} == 1
    let two = Scalar::from(2);
    assert_eq!(two * Scalar::TWO_INV, Scalar::ONE);

    // ROOT_OF_UNITY and ROOT_OF_UNITY_INV are mutual inverses.
    assert_eq!(Scalar::ROOT_OF_UNITY * Scalar::ROOT_OF_UNITY_INV, Scalar::ONE);

    // ROOT_OF_UNITY^{2^s} mod m == 1
    let two_pow_s: [u64; 4] = [1u64 << Scalar::S, 0, 0, 0];
    assert_eq!(Scalar::ROOT_OF_UNITY.pow(&two_pow_s), Scalar::ONE);

    // DELTA^{t} mod m == 1, where q - 1 = 2^S * t (t as little-endian limbs).
    let t: [u64; 4] = [
        0xfffe_5bfe_ffff_ffff,
        0x09a1_d805_53bd_a402,
        0x299d_7d48_3339_d808,
        0x0000_0000_73ed_a753,
    ];
    assert_eq!(Scalar::DELTA.pow(&t), Scalar::ONE);
}
825
#[test]
fn test_inv() {
    // q^(2^63 - 1) = q^(totient(2^64) - 1) = q^{-1} (mod 2^64). The fold is
    // square-and-multiply over 63 one-bits; negating gives -(q^{-1}) mod 2^64.
    let q_low = MODULUS.0[0];
    let q_inv = (0..63).fold(1u64, |acc, _| acc.wrapping_mul(acc).wrapping_mul(q_low));

    assert_eq!(q_inv.wrapping_neg(), INV);
}
840
#[cfg(feature = "std")]
#[test]
fn test_debug() {
    // `R2` prints as R because `Debug` shows the canonical value (R^2 / R).
    let cases = [
        (
            Scalar::zero(),
            "0x0000000000000000000000000000000000000000000000000000000000000000",
        ),
        (
            Scalar::one(),
            "0x0000000000000000000000000000000000000000000000000000000000000001",
        ),
        (
            R2,
            "0x1824b159acc5056f998c4fefecbc4ff55884b7fa0003480200000001fffffffe",
        ),
    ];

    for (value, expected) in cases.iter() {
        assert_eq!(format!("{:?}", value), *expected);
    }
}
857
#[test]
fn test_equality() {
    let zero = Scalar::zero();
    let one = Scalar::one();

    // Every value equals itself.
    for x in [zero, one, R2].iter() {
        assert_eq!(x, x);
    }

    // Distinct values compare unequal.
    assert_ne!(zero, one);
    assert_ne!(one, R2);
}
867
// Checks `to_bytes` against known canonical little-endian encodings.
868#[test]
869fn test_to_bytes() {
870    assert_eq!(
871        Scalar::zero().to_bytes(),
872        [
873            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
874            0, 0, 0
875        ]
876    );
877
878    assert_eq!(
879        Scalar::one().to_bytes(),
880        [
881            1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
882            0, 0, 0
883        ]
884    );
885
    // `R2` holds R^2 in Montgomery form, whose value is R^2 / R = R, so the
    // expected bytes are the little-endian limbs of R.
886    assert_eq!(
887        R2.to_bytes(),
888        [
889            254, 255, 255, 255, 1, 0, 0, 0, 2, 72, 3, 0, 250, 183, 132, 88, 245, 79, 188, 236, 239,
890            79, 140, 153, 111, 5, 197, 172, 89, 177, 36, 24
891        ]
892    );
893
    // -1 is encoded as q - 1.
894    assert_eq!(
895        (-&Scalar::one()).to_bytes(),
896        [
897            0, 0, 0, 0, 255, 255, 255, 255, 254, 91, 254, 255, 2, 164, 189, 83, 5, 216, 161, 9, 8,
898            216, 57, 51, 72, 125, 157, 41, 83, 167, 237, 115
899        ]
900    );
901}
902
// Checks that `from_bytes` accepts canonical encodings (< q) and rejects q and
// anything larger.
903#[test]
904fn test_from_bytes() {
905    assert_eq!(
906        Scalar::from_bytes(&[
907            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
908            0, 0, 0
909        ])
910        .unwrap(),
911        Scalar::zero()
912    );
913
914    assert_eq!(
915        Scalar::from_bytes(&[
916            1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
917            0, 0, 0
918        ])
919        .unwrap(),
920        Scalar::one()
921    );
922
    // The integer R lifts to R * R = R^2 in Montgomery form.
923    assert_eq!(
924        Scalar::from_bytes(&[
925            254, 255, 255, 255, 1, 0, 0, 0, 2, 72, 3, 0, 250, 183, 132, 88, 245, 79, 188, 236, 239,
926            79, 140, 153, 111, 5, 197, 172, 89, 177, 36, 24
927        ])
928        .unwrap(),
929        R2
930    );
931
932    // -1 should work
    // (q - 1 is the largest canonical value).
933    assert!(bool::from(
934        Scalar::from_bytes(&[
935            0, 0, 0, 0, 255, 255, 255, 255, 254, 91, 254, 255, 2, 164, 189, 83, 5, 216, 161, 9, 8,
936            216, 57, 51, 72, 125, 157, 41, 83, 167, 237, 115
937        ])
938        .is_some()
939    ));
940
941    // modulus is invalid
942    assert!(bool::from(
943        Scalar::from_bytes(&[
944            1, 0, 0, 0, 255, 255, 255, 255, 254, 91, 254, 255, 2, 164, 189, 83, 5, 216, 161, 9, 8,
945            216, 57, 51, 72, 125, 157, 41, 83, 167, 237, 115
946        ])
947        .is_none()
948    ));
949
950    // Anything larger than the modulus is invalid
    // (q + 1, then q with a middle byte bumped, then q with the top byte bumped).
951    assert!(bool::from(
952        Scalar::from_bytes(&[
953            2, 0, 0, 0, 255, 255, 255, 255, 254, 91, 254, 255, 2, 164, 189, 83, 5, 216, 161, 9, 8,
954            216, 57, 51, 72, 125, 157, 41, 83, 167, 237, 115
955        ])
956        .is_none()
957    ));
958    assert!(bool::from(
959        Scalar::from_bytes(&[
960            1, 0, 0, 0, 255, 255, 255, 255, 254, 91, 254, 255, 2, 164, 189, 83, 5, 216, 161, 9, 8,
961            216, 58, 51, 72, 125, 157, 41, 83, 167, 237, 115
962        ])
963        .is_none()
964    ));
965    assert!(bool::from(
966        Scalar::from_bytes(&[
967            1, 0, 0, 0, 255, 255, 255, 255, 254, 91, 254, 255, 2, 164, 189, 83, 5, 216, 161, 9, 8,
968            216, 57, 51, 72, 125, 157, 41, 83, 167, 237, 116
969        ])
970        .is_none()
971    ));
972}
973
/// Reducing the modulus `q`, zero-extended to eight limbs, must yield zero.
#[test]
fn test_from_u512_zero() {
    let [q0, q1, q2, q3] = MODULUS.0;
    let reduced = Scalar::from_u512([q0, q1, q2, q3, 0, 0, 0, 0]);
    assert_eq!(Scalar::zero(), reduced);
}
990
/// The 512-bit integer `1` must reduce to the constant `R`.
#[test]
fn test_from_u512_r() {
    let wide_one = [1, 0, 0, 0, 0, 0, 0, 0];
    let reduced = Scalar::from_u512(wide_one);
    assert_eq!(R, reduced);
}
995
/// The 512-bit integer `2^256` (a single set bit in limb 4) must reduce to `R2`.
#[test]
fn test_from_u512_r2() {
    let two_pow_256 = [0, 0, 0, 0, 1, 0, 0, 0];
    let reduced = Scalar::from_u512(two_pow_256);
    assert_eq!(R2, reduced);
}
1000
/// The all-ones 512-bit input, `2^512 - 1`, must reduce to `R3 - R`.
#[test]
fn test_from_u512_max() {
    let expected = R3 - R;
    let all_ones = [u64::MAX; 8];
    assert_eq!(expected, Scalar::from_u512(all_ones));
}
1009
/// A 64-byte input whose low half encodes `R2` and whose high half is zero must
/// reduce to `R2`.
#[test]
fn test_from_bytes_wide_r2() {
    let mut wide = [0u8; 64];
    wide[..32].copy_from_slice(&[
        254, 255, 255, 255, 1, 0, 0, 0, 2, 72, 3, 0, 250, 183, 132, 88, 245, 79, 188, 236, 239,
        79, 140, 153, 111, 5, 197, 172, 89, 177, 36, 24,
    ]);
    assert_eq!(R2, Scalar::from_bytes_wide(&wide));
}
1021
/// A 64-byte input whose low half encodes `q - 1` and whose high half is zero must
/// reduce to `-1`.
#[test]
fn test_from_bytes_wide_negative_one() {
    let mut wide = [0u8; 64];
    wide[..32].copy_from_slice(&[
        0, 0, 0, 0, 255, 255, 255, 255, 254, 91, 254, 255, 2, 164, 189, 83, 5, 216, 161, 9, 8,
        216, 57, 51, 72, 125, 157, 41, 83, 167, 237, 115,
    ]);
    assert_eq!(-&Scalar::one(), Scalar::from_bytes_wide(&wide));
}
1033
/// Sixty-four `0xff` bytes (`2^512 - 1`) must reduce to the fixed internal limbs below.
#[test]
fn test_from_bytes_wide_maximum() {
    let expected = Scalar([
        0xc62c_1805_439b_73b1,
        0xc2b9_551e_8ced_218e,
        0xda44_ec81_daf9_a422,
        0x5605_aa60_1c16_2e79,
    ]);
    let all_ff = [0xff; 64];
    assert_eq!(expected, Scalar::from_bytes_wide(&all_ff));
}
1046
/// Zero is its own negation, and `0 + 0`, `0 - 0` and `0 * 0` are all zero.
#[test]
fn test_zero() {
    let zero = Scalar::zero();
    assert_eq!(zero, -&zero);
    assert_eq!(zero, zero + zero);
    assert_eq!(zero, zero - zero);
    assert_eq!(zero, zero * zero);
}
1054
/// Test fixture whose internal limbs are `q - 1`, the largest value below the modulus.
/// The limbs are written straight into the tuple constructor, so no Montgomery
/// conversion is applied: arithmetic tests on this value exercise wrap-around at `q`.
#[cfg(test)]
const LARGEST: Scalar = Scalar([
    0xffff_ffff_0000_0000,
    0x53bd_a402_fffe_5bfe,
    0x3339_d808_09a1_d805,
    0x73ed_a753_299d_7d48,
]);
1062
/// Addition on internal limbs wraps at the modulus: `(q - 1) + (q - 1)` becomes
/// `q - 2`, and `(q - 1) + 1` becomes zero.
#[test]
fn test_addition() {
    let q_minus_two = Scalar([
        0xffff_fffe_ffff_ffff,
        0x53bd_a402_fffe_5bfe,
        0x3339_d808_09a1_d805,
        0x73ed_a753_299d_7d48,
    ]);

    let mut doubled = LARGEST;
    doubled += &LARGEST;
    assert_eq!(doubled, q_minus_two);

    let mut wrapped = LARGEST;
    wrapped += &Scalar([1, 0, 0, 0]);
    assert_eq!(wrapped, Scalar::zero());
}
1083
/// Negation swaps the limbs `q - 1` and `1`, and leaves zero unchanged.
#[test]
fn test_negation() {
    let limb_one = Scalar([1, 0, 0, 0]);

    assert_eq!(-&LARGEST, limb_one);
    assert_eq!(-&Scalar::zero(), Scalar::zero());
    assert_eq!(-&limb_one, LARGEST);
}
1095
/// `x - x` is zero, and `0 - (q - 1)` agrees with `MODULUS - (q - 1)`.
#[test]
fn test_subtraction() {
    let mut self_difference = LARGEST;
    self_difference -= &LARGEST;
    assert_eq!(self_difference, Scalar::zero());

    let mut from_zero = Scalar::zero();
    from_zero -= &LARGEST;

    let mut from_modulus = MODULUS;
    from_modulus -= &LARGEST;

    assert_eq!(from_zero, from_modulus);
}
1111
/// Cross-checks `*=` against a double-and-add reference over the bits of the
/// multiplier's `to_bytes` encoding, for 100 successive multiples of `LARGEST`.
#[test]
fn test_multiplication() {
    let mut cur = LARGEST;

    for _ in 0..100 {
        let mut product = cur;
        product *= &cur;

        // Visit bits most-significant first: bytes from last to first, and within
        // each byte bits 7 down to 0. Each step doubles the accumulator, then adds
        // `cur` when the bit is set.
        let bytes = cur.to_bytes();
        let expected = bytes
            .iter()
            .rev()
            .flat_map(|&byte| (0..8).rev().map(move |shift| ((byte >> shift) & 1u8) == 1u8))
            .fold(Scalar::zero(), |acc, bit| {
                let mut next = acc;
                next.add_assign(&acc);
                if bit {
                    next.add_assign(&cur);
                }
                next
            });

        assert_eq!(product, expected);

        cur.add_assign(&LARGEST);
    }
}
1140
/// Cross-checks `square` against a double-and-add reference over the bits of the
/// operand's `to_bytes` encoding, for 100 successive multiples of `LARGEST`.
#[test]
fn test_squaring() {
    let mut cur = LARGEST;

    for _ in 0..100 {
        // `square` takes `&self`; no mutable copy of `cur` is needed.
        let squared = cur.square();

        // Visit bits most-significant first: bytes from last to first, and within
        // each byte bits 7 down to 0. Each step doubles the accumulator, then adds
        // `cur` when the bit is set, producing `cur * cur`.
        let bytes = cur.to_bytes();
        let expected = bytes
            .iter()
            .rev()
            .flat_map(|&byte| (0..8).rev().map(move |shift| ((byte >> shift) & 1u8) == 1u8))
            .fold(Scalar::zero(), |acc, bit| {
                let mut next = acc;
                next.add_assign(&acc);
                if bit {
                    next.add_assign(&cur);
                }
                next
            });

        assert_eq!(squared, expected);

        cur.add_assign(&LARGEST);
    }
}
1169
/// `invert` returns none for zero, maps 1 and -1 to themselves, and satisfies
/// `x^{-1} * x == 1` for 100 successive multiples of `R2`.
#[test]
fn test_inversion() {
    assert!(bool::from(Scalar::zero().invert().is_none()));

    let one = Scalar::one();
    let minus_one = -&one;
    assert_eq!(one.invert().unwrap(), one);
    assert_eq!(minus_one.invert().unwrap(), minus_one);

    let mut x = R2;
    for _ in 0..100 {
        let mut check = x.invert().unwrap();
        check.mul_assign(&x);
        assert_eq!(check, Scalar::one());

        x.add_assign(&R2);
    }
}
1187
/// `invert`, `pow_vartime(q - 2)` and `pow(q - 2)` must agree on 100 bases, starting
/// at `R` and stepping to `x^{-1} + R` after each round.
#[test]
fn test_invert_is_pow() {
    // Little-endian limbs of the exponent q - 2.
    let q_minus_2 = [
        0xffff_fffe_ffff_ffff,
        0x53bd_a402_fffe_5bfe,
        0x3339_d808_09a1_d805,
        0x73ed_a753_299d_7d48,
    ];

    let mut x = R;

    for _ in 0..100 {
        let via_invert = x.invert().unwrap();
        let via_pow_vartime = x.pow_vartime(&q_minus_2);
        let via_pow = x.pow(&q_minus_2);

        assert_eq!(via_invert, via_pow_vartime);
        assert_eq!(via_pow_vartime, via_pow);

        // Add R so the next round checks a different base.
        x = via_invert;
        x.add_assign(&R);
    }
}
1214
/// Checks `sqrt` on zero, then on 100 consecutive elements stepping down by one from a
/// fixed start. Every returned root must square back to its input, and exactly 49 of
/// the inputs must have no root (`sqrt` returns none).
#[test]
fn test_sqrt() {
    assert_eq!(Scalar::zero().sqrt().unwrap(), Scalar::zero());

    let mut square = Scalar([
        0x46cd_85a5_f273_077e,
        0x1d30_c47d_d68f_c735,
        0x77f6_56f6_0bec_a0eb,
        0x494a_a01b_df32_468d,
    ]);

    let mut none_count = 0;

    for _ in 0..100 {
        let square_root = square.sqrt();
        if bool::from(square_root.is_none()) {
            none_count += 1;
        } else {
            // Unwrap once and reuse the root rather than unwrapping the CtOption twice.
            let root = square_root.unwrap();
            assert_eq!(root * root, square);
        }
        square -= Scalar::one();
    }

    assert_eq!(49, none_count);
}
1242
/// `from_raw` on all-ones limbs must match the fixed expected limbs. The modulus must
/// map to zero, and the integer 1 must map to `R`.
#[test]
fn test_from_raw() {
    let expected = Scalar::from_raw([
        0x0001_ffff_fffd,
        0x5884_b7fa_0003_4802,
        0x998c_4fef_ecbc_4ff5,
        0x1824_b159_acc5_056f,
    ]);
    let all_ones = Scalar::from_raw([u64::MAX; 4]);
    assert_eq!(expected, all_ones);

    let from_modulus = Scalar::from_raw(MODULUS.0);
    assert_eq!(from_modulus, Scalar::zero());

    let from_one = Scalar::from_raw([1, 0, 0, 0]);
    assert_eq!(from_one, R);
}
1259
/// `double` must agree with adding a value to itself.
#[test]
fn test_double() {
    let value = Scalar::from_raw([
        0x1fff_3231_233f_fffd,
        0x4884_b7fa_0003_4802,
        0x998c_4fef_ecbc_4ff3,
        0x1824_b159_acc5_0562,
    ]);

    let doubled = value.double();
    assert_eq!(doubled, value + value);
}
1271
/// After `zeroize`, a nonzero scalar must report `is_zero`.
#[cfg(feature = "zeroize")]
#[test]
fn test_zeroize() {
    use zeroize::Zeroize;

    let mut secret = Scalar::from_raw([
        0x1fff_3231_233f_fffd,
        0x4884_b7fa_0003_4802,
        0x998c_4fef_ecbc_4ff3,
        0x1824_b159_acc5_0562,
    ]);
    secret.zeroize();

    let cleared = bool::from(secret.is_zero());
    assert!(cleared);
}