dcrypt_algorithms/ec/bls12_381/
scalar.rs

1//! BLS12-381 scalar field F_q where q = 0x73eda753299d7d483339d80809a1d80553bda402fffe5bfeffffffff00000001
2
3use crate::error::{Error, Result};
4use crate::hash::{sha2::Sha256, HashFunction};
5use crate::types::{
6    ByteSerializable, ConstantTimeEq as DcryptConstantTimeEq, SecureZeroingType,
7};
8use core::fmt;
9use core::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};
10use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption};
11
// Arithmetic helpers

/// Adds `a`, `b` and an incoming `carry`, returning `(low 64 bits, carry-out)`.
///
/// The sum is formed in 128-bit arithmetic, so it cannot overflow; the carry-out is
/// simply the upper half of the 128-bit result.
#[inline(always)]
const fn adc(a: u64, b: u64, carry: u64) -> (u64, u64) {
    let wide = a as u128 + b as u128 + carry as u128;
    let low = wide as u64;
    let high = (wide >> 64) as u64;
    (low, high)
}
19
/// Computes `a - (b + borrow)`, returning `(difference, borrow-out)`.
///
/// Borrows use a mask convention: only the top bit of the incoming `borrow` is
/// consulted, and the outgoing borrow is `u64::MAX` when the subtraction wrapped
/// and `0` otherwise, so callers can use it directly as an AND-mask.
#[inline(always)]
const fn sbb(a: u64, b: u64, borrow: u64) -> (u64, u64) {
    let borrow_bit = (borrow >> 63) as u128;
    let subtrahend = b as u128 + borrow_bit;
    let diff = (a as u128).wrapping_sub(subtrahend);
    (diff as u64, (diff >> 64) as u64)
}
26
/// Computes `a + b * c + carry`, returning `(low 64 bits, high 64 bits)`.
///
/// The result always fits in 128 bits, since
/// `(2^64 - 1) + (2^64 - 1)^2 + (2^64 - 1) = 2^128 - 1`.
#[inline(always)]
const fn mac(a: u64, b: u64, c: u64, carry: u64) -> (u64, u64) {
    let product = b as u128 * c as u128;
    let wide = product + a as u128 + carry as u128;
    (wide as u64, (wide >> 64) as u64)
}
33
34/// Scalar field element of BLS12-381
35/// Internal: Four 64-bit limbs in little-endian Montgomery form
36#[derive(Clone, Copy, Eq)]
37pub struct Scalar(pub(crate) [u64; 4]);
38
39impl fmt::Debug for Scalar {
40    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
41        let tmp = self.to_bytes();
42        write!(f, "0x")?;
43        for &b in tmp.iter().rev() {
44            write!(f, "{:02x}", b)?;
45        }
46        Ok(())
47    }
48}
49
50impl fmt::Display for Scalar {
51    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
52        write!(f, "{:?}", self)
53    }
54}
55
56impl From<u64> for Scalar {
57    fn from(val: u64) -> Scalar {
58        Scalar([val, 0, 0, 0]) * R2
59    }
60}
61
62impl ConstantTimeEq for Scalar {
63    fn ct_eq(&self, other: &Self) -> Choice {
64        self.0[0].ct_eq(&other.0[0])
65            & self.0[1].ct_eq(&other.0[1])
66            & self.0[2].ct_eq(&other.0[2])
67            & self.0[3].ct_eq(&other.0[3])
68    }
69}
70
71impl PartialEq for Scalar {
72    #[inline]
73    fn eq(&self, other: &Self) -> bool {
74        bool::from(subtle::ConstantTimeEq::ct_eq(self, other))
75    }
76}
77
78impl ConditionallySelectable for Scalar {
79    fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
80        Scalar([
81            u64::conditional_select(&a.0[0], &b.0[0], choice),
82            u64::conditional_select(&a.0[1], &b.0[1], choice),
83            u64::conditional_select(&a.0[2], &b.0[2], choice),
84            u64::conditional_select(&a.0[3], &b.0[3], choice),
85        ])
86    }
87}
88
// Constants
/// The field modulus
/// q = 0x73eda753299d7d483339d80809a1d80553bda402fffe5bfeffffffff00000001,
/// stored as a plain integer (NOT Montgomery form) in four little-endian 64-bit limbs,
/// least significant limb first.
const MODULUS: Scalar = Scalar([
    0xffff_ffff_0000_0001,
    0x53bd_a402_fffe_5bfe,
    0x3339_d808_09a1_d805,
    0x73ed_a753_299d_7d48,
]);
96
/// INV = -(q^{-1} mod 2^64) mod 2^64
///
/// Montgomery reduction constant: `montgomery_reduce` multiplies the lowest remaining
/// limb by INV to pick the multiple of q that clears that limb. Verified by `test_inv`.
const INV: u64 = 0xffff_fffe_ffff_ffff;
99
/// R = 2^256 mod q
///
/// This is also the Montgomery form of 1; `Scalar::one()` returns it.
const R: Scalar = Scalar([
    0x0000_0001_ffff_fffe,
    0x5884_b7fa_0003_4802,
    0x998c_4fef_ecbc_4ff5,
    0x1824_b159_acc5_056f,
]);
107
/// R^2 = 2^512 mod q
///
/// Montgomery-multiplying a plain integer by R2 yields its Montgomery form
/// (`a * R^2 / R = a * R`); used by `from_raw`, `from_bytes`, `from_u512` and `From<u64>`.
const R2: Scalar = Scalar([
    0xc999_e990_f3f2_9c6d,
    0x2b6c_edcb_8792_5c23,
    0x05d3_1496_7254_398f,
    0x0748_d9d9_9f59_ff11,
]);
115
/// R^3 = 2^768 mod q
///
/// Used by `from_u512`: Montgomery-multiplying the upper 256 bits `d1` by R3 gives
/// `d1 * R^2 = (d1 * 2^256) * R`, i.e. the Montgomery form of `d1 * 2^256`.
const R3: Scalar = Scalar([
    0xc62c_1807_439b_73af,
    0x1b3e_0d18_8cf0_6990,
    0x73d1_3c71_c7b5_f418,
    0x6e2a_5bb9_c8db_33e9,
]);
123
// Constants for Tonelli-Shanks square root algorithm.
// In these comments r denotes the scalar field modulus (called q elsewhere in this file).
// All multi-limb exponents are little-endian 64-bit limbs, as consumed by `pow_vartime`.
// 2-adicity of (r - 1): r - 1 = 2^S * T with T odd.
const S: u32 = 32;

// T = (r - 1) / 2^S  (odd part)
const TONELLI_T: [u64; 4] = [
    0xfffe_5bfe_ffff_ffff,
    0x09a1_d805_53bd_a402,
    0x299d_7d48_3339_d808,
    0x0000_0000_73ed_a753,
];

// (T + 1)/2, used to initialize x = a^((T+1)/2)
const TONELLI_TP1_DIV2: [u64; 4] = [
    0x7fff_2dff_8000_0000,
    0x04d0_ec02_a9de_d201,
    0x94ce_bea4_199c_ec04,
    0x0000_0000_39f6_d3a9,
];

// Exponent (r-1)/2, the Legendre exponent (Euler's criterion: a^((r-1)/2) is 1 for
// non-zero squares and -1 for non-squares).
// NOTE(review): not referenced by any code visible in this file, hence `dead_code` is allowed.
#[allow(dead_code)]
const LEGENDRE_EXP: [u64; 4] = [
    0x7fff_ffff_8000_0000,
    0xa9de_d201_7fff_2dff,
    0x199c_ec04_04d0_ec02,
    0x39f6_d3a9_94ce_bea4,
];
152
153impl<'a> Neg for &'a Scalar {
154    type Output = Scalar;
155
156    #[inline]
157    fn neg(self) -> Scalar {
158        self.neg()
159    }
160}
161
162impl Neg for Scalar {
163    type Output = Scalar;
164
165    #[inline]
166    fn neg(self) -> Scalar {
167        -&self
168    }
169}
170
171impl<'a, 'b> Sub<&'b Scalar> for &'a Scalar {
172    type Output = Scalar;
173
174    #[inline]
175    fn sub(self, rhs: &'b Scalar) -> Scalar {
176        self.sub(rhs)
177    }
178}
179
180impl<'a, 'b> Add<&'b Scalar> for &'a Scalar {
181    type Output = Scalar;
182
183    #[inline]
184    fn add(self, rhs: &'b Scalar) -> Scalar {
185        self.add(rhs)
186    }
187}
188
189impl<'a, 'b> Mul<&'b Scalar> for &'a Scalar {
190    type Output = Scalar;
191
192    #[inline]
193    fn mul(self, rhs: &'b Scalar) -> Scalar {
194        self.mul(rhs)
195    }
196}
197
198// Binop implementations
199impl<'b> Add<&'b Scalar> for Scalar {
200    type Output = Scalar;
201    #[inline]
202    fn add(self, rhs: &'b Scalar) -> Scalar {
203        &self + rhs
204    }
205}
206
207impl<'a> Add<Scalar> for &'a Scalar {
208    type Output = Scalar;
209    #[inline]
210    fn add(self, rhs: Scalar) -> Scalar {
211        self + &rhs
212    }
213}
214
215impl Add<Scalar> for Scalar {
216    type Output = Scalar;
217    #[inline]
218    fn add(self, rhs: Scalar) -> Scalar {
219        &self + &rhs
220    }
221}
222
223impl<'b> Sub<&'b Scalar> for Scalar {
224    type Output = Scalar;
225    #[inline]
226    fn sub(self, rhs: &'b Scalar) -> Scalar {
227        &self - rhs
228    }
229}
230
231impl<'a> Sub<Scalar> for &'a Scalar {
232    type Output = Scalar;
233    #[inline]
234    fn sub(self, rhs: Scalar) -> Scalar {
235        self - &rhs
236    }
237}
238
239impl Sub<Scalar> for Scalar {
240    type Output = Scalar;
241    #[inline]
242    fn sub(self, rhs: Scalar) -> Scalar {
243        &self - &rhs
244    }
245}
246
247impl SubAssign<Scalar> for Scalar {
248    #[inline]
249    fn sub_assign(&mut self, rhs: Scalar) {
250        *self = &*self - &rhs;
251    }
252}
253
254impl AddAssign<Scalar> for Scalar {
255    #[inline]
256    fn add_assign(&mut self, rhs: Scalar) {
257        *self = &*self + &rhs;
258    }
259}
260
261impl<'b> SubAssign<&'b Scalar> for Scalar {
262    #[inline]
263    fn sub_assign(&mut self, rhs: &'b Scalar) {
264        *self = &*self - rhs;
265    }
266}
267
268impl<'b> AddAssign<&'b Scalar> for Scalar {
269    #[inline]
270    fn add_assign(&mut self, rhs: &'b Scalar) {
271        *self = &*self + rhs;
272    }
273}
274
275impl<'b> Mul<&'b Scalar> for Scalar {
276    type Output = Scalar;
277    #[inline]
278    fn mul(self, rhs: &'b Scalar) -> Scalar {
279        &self * rhs
280    }
281}
282
283impl<'a> Mul<Scalar> for &'a Scalar {
284    type Output = Scalar;
285    #[inline]
286    fn mul(self, rhs: Scalar) -> Scalar {
287        self * &rhs
288    }
289}
290
291impl Mul<Scalar> for Scalar {
292    type Output = Scalar;
293    #[inline]
294    fn mul(self, rhs: Scalar) -> Scalar {
295        &self * &rhs
296    }
297}
298
299impl MulAssign<Scalar> for Scalar {
300    #[inline]
301    fn mul_assign(&mut self, rhs: Scalar) {
302        *self = &*self * &rhs;
303    }
304}
305
306impl<'b> MulAssign<&'b Scalar> for Scalar {
307    #[inline]
308    fn mul_assign(&mut self, rhs: &'b Scalar) {
309        *self = &*self * rhs;
310    }
311}
312
313impl Default for Scalar {
314    #[inline]
315    fn default() -> Self {
316        Self::zero()
317    }
318}
319
// Opt-in `zeroize` support: `DefaultIsZeroes` zeroizes a value by overwriting it with
// `Default::default()`, which for `Scalar` is zero.
#[cfg(feature = "zeroize")]
impl zeroize::DefaultIsZeroes for Scalar {}
322
323impl ByteSerializable for Scalar {
324    fn to_bytes(&self) -> Vec<u8> {
325        self.to_bytes().to_vec()
326    }
327
328    fn from_bytes(bytes: &[u8]) -> Result<Self> {
329        if bytes.len() != 32 {
330            return Err(Error::Length {
331                context: "Scalar::from_bytes",
332                expected: 32,
333                actual: bytes.len(),
334            });
335        }
336
337        let mut array = [0u8; 32];
338        array.copy_from_slice(bytes);
339
340        Scalar::from_bytes(&array)
341            .into_option()
342            .ok_or_else(|| Error::param("scalar_bytes", "non-canonical scalar"))
343    }
344}
345
346impl DcryptConstantTimeEq for Scalar {
347    fn ct_eq(&self, other: &Self) -> bool {
348        bool::from(subtle::ConstantTimeEq::ct_eq(self, other))
349    }
350}
351
352impl SecureZeroingType for Scalar {
353    fn zeroed() -> Self {
354        Self::zero()
355    }
356}
357
358impl Scalar {
359    /// Additive identity
360    #[inline]
361    pub const fn zero() -> Scalar {
362        Scalar([0, 0, 0, 0])
363    }
364
365    /// Multiplicative identity
366    #[inline]
367    pub const fn one() -> Scalar {
368        R
369    }
370
371    /// Check if element is zero.
372    #[inline]
373    pub fn is_zero(&self) -> Choice {
374        (self.0[0] | self.0[1] | self.0[2] | self.0[3]).ct_eq(&0)
375    }
376
377    /// Double this element
378    #[inline]
379    pub const fn double(&self) -> Scalar {
380        self.add(self)
381    }
382
383    /// Create from little-endian bytes if canonical
384    pub fn from_bytes(bytes: &[u8; 32]) -> CtOption<Scalar> {
385        let mut tmp = Scalar([0, 0, 0, 0]);
386
387        tmp.0[0] = u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[0..8]).unwrap());
388        tmp.0[1] = u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[8..16]).unwrap());
389        tmp.0[2] = u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[16..24]).unwrap());
390        tmp.0[3] = u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[24..32]).unwrap());
391
392        // Check canonical by subtracting modulus
393        let (_, borrow) = sbb(tmp.0[0], MODULUS.0[0], 0);
394        let (_, borrow) = sbb(tmp.0[1], MODULUS.0[1], borrow);
395        let (_, borrow) = sbb(tmp.0[2], MODULUS.0[2], borrow);
396        let (_, borrow) = sbb(tmp.0[3], MODULUS.0[3], borrow);
397
398        let is_some = (borrow as u8) & 1;
399
400        // Convert to Montgomery: (a * R^2) / R = aR
401        tmp *= &R2;
402
403        CtOption::new(tmp, Choice::from(is_some))
404    }
405
406    /// Convert to little-endian bytes
407    pub fn to_bytes(&self) -> [u8; 32] {
408        // Remove Montgomery: (aR) / R = a
409        let tmp = Scalar::montgomery_reduce(self.0[0], self.0[1], self.0[2], self.0[3], 0, 0, 0, 0);
410
411        let mut res = [0; 32];
412        res[0..8].copy_from_slice(&tmp.0[0].to_le_bytes());
413        res[8..16].copy_from_slice(&tmp.0[1].to_le_bytes());
414        res[16..24].copy_from_slice(&tmp.0[2].to_le_bytes());
415        res[24..32].copy_from_slice(&tmp.0[3].to_le_bytes());
416
417        res
418    }
419
420    /// Create from 512-bit little-endian integer mod q
421    pub fn from_bytes_wide(bytes: &[u8; 64]) -> Scalar {
422        Scalar::from_u512([
423            u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[0..8]).unwrap()),
424            u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[8..16]).unwrap()),
425            u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[16..24]).unwrap()),
426            u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[24..32]).unwrap()),
427            u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[32..40]).unwrap()),
428            u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[40..48]).unwrap()),
429            u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[48..56]).unwrap()),
430            u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[56..64]).unwrap()),
431        ])
432    }
433
434    fn expand_message_xmd(msg: &[u8], dst: &[u8], len_in_bytes: usize) -> Result<Vec<u8>> {
435        const MAX_DST_LENGTH: usize = 255;
436        const HASH_OUTPUT_SIZE: usize = 32;
437
438        if dst.len() > MAX_DST_LENGTH {
439            return Err(Error::param("dst", "domain separation tag too long"));
440        }
441
442        let ell = (len_in_bytes + HASH_OUTPUT_SIZE - 1) / HASH_OUTPUT_SIZE;
443
444        if ell > 255 {
445            return Err(Error::param("len_in_bytes", "requested output too long"));
446        }
447
448        let dst_prime_len = dst.len() as u8;
449
450        let mut hasher = Sha256::new();
451        hasher.update(&[0u8; HASH_OUTPUT_SIZE])?;
452        hasher.update(msg)?;
453        hasher.update(&((len_in_bytes as u16).to_be_bytes()))?;
454        hasher.update(&[0u8])?;
455        hasher.update(dst)?;
456        hasher.update(&[dst_prime_len])?;
457
458        let b_0 = hasher.finalize()?;
459
460        let mut uniform_bytes = Vec::with_capacity(len_in_bytes);
461        let mut b_i = vec![0u8; HASH_OUTPUT_SIZE];
462
463        for i in 1..=ell {
464            let mut hasher = Sha256::new();
465            if i == 1 {
466                hasher.update(&[0u8; HASH_OUTPUT_SIZE])?;
467            } else {
468                let mut xored = [0u8; HASH_OUTPUT_SIZE];
469                for j in 0..HASH_OUTPUT_SIZE {
470                    xored[j] = b_0.as_ref()[j] ^ b_i[j];
471                }
472                hasher.update(&xored)?;
473            }
474            hasher.update(&[i as u8])?;
475            hasher.update(dst)?;
476            hasher.update(&[dst_prime_len])?;
477            let digest = hasher.finalize()?;
478            b_i.copy_from_slice(digest.as_ref());
479            uniform_bytes.extend_from_slice(&b_i);
480        }
481
482        uniform_bytes.truncate(len_in_bytes);
483        Ok(uniform_bytes)
484    }
485
486    /// Hashes arbitrary data to a scalar field element using SHA-256.
487    ///
488    /// This function implements a standards-compliant hash-to-field method following
489    /// the IETF hash-to-curve specification using expand_message_xmd with SHA-256.
490    ///
491    /// # Arguments
492    /// * `data`: The input data to hash.
493    /// * `dst`: A Domain Separation Tag (DST) to ensure hashes are unique per application context.
494    ///
495    /// # Returns
496    /// A `Result` containing the `Scalar` or an error.
497    pub fn hash_to_field(data: &[u8], dst: &[u8]) -> Result<Self> {
498        let expanded = Self::expand_message_xmd(data, dst, 64)?;
499        let mut expanded_array = [0u8; 64];
500        expanded_array.copy_from_slice(&expanded);
501        Ok(Self::from_bytes_wide(&expanded_array))
502    }
503
504    fn from_u512(limbs: [u64; 8]) -> Scalar {
505        let d0 = Scalar([limbs[0], limbs[1], limbs[2], limbs[3]]);
506        let d1 = Scalar([limbs[4], limbs[5], limbs[6], limbs[7]]);
507        d0 * R2 + d1 * R3
508    }
509
510    /// Creates a scalar from four `u64` limbs (little-endian). This function will
511    /// convert the raw integer into Montgomery form.
512    pub const fn from_raw(val: [u64; 4]) -> Self {
513        (&Scalar(val)).mul(&R2)
514    }
515
516    /// Computes the square of this scalar.
517    #[inline]
518    pub const fn square(&self) -> Scalar {
519        let (r1, carry) = mac(0, self.0[0], self.0[1], 0);
520        let (r2, carry) = mac(0, self.0[0], self.0[2], carry);
521        let (r3, r4) = mac(0, self.0[0], self.0[3], carry);
522
523        let (r3, carry) = mac(r3, self.0[1], self.0[2], 0);
524        let (r4, r5) = mac(r4, self.0[1], self.0[3], carry);
525
526        let (r5, r6) = mac(r5, self.0[2], self.0[3], 0);
527
528        let r7 = r6 >> 63;
529        let r6 = (r6 << 1) | (r5 >> 63);
530        let r5 = (r5 << 1) | (r4 >> 63);
531        let r4 = (r4 << 1) | (r3 >> 63);
532        let r3 = (r3 << 1) | (r2 >> 63);
533        let r2 = (r2 << 1) | (r1 >> 63);
534        let r1 = r1 << 1;
535
536        let (r0, carry) = mac(0, self.0[0], self.0[0], 0);
537        let (r1, carry) = adc(0, r1, carry);
538        let (r2, carry) = mac(r2, self.0[1], self.0[1], carry);
539        let (r3, carry) = adc(0, r3, carry);
540        let (r4, carry) = mac(r4, self.0[2], self.0[2], carry);
541        let (r5, carry) = adc(0, r5, carry);
542        let (r6, carry) = mac(r6, self.0[3], self.0[3], carry);
543        let (r7, _) = adc(0, r7, carry);
544
545        Scalar::montgomery_reduce(r0, r1, r2, r3, r4, r5, r6, r7)
546    }
547
548    /// Computes `x` raised to the power of `2^k`.
549    #[inline]
550    pub fn pow2k(mut x: Scalar, mut k: u32) -> Scalar {
551        while k > 0 {
552            x = x.square();
553            k -= 1;
554        }
555        x
556    }
557
558    /// Variable-time exponentiation by a 256-bit little-endian exponent.
559    fn pow_vartime(&self, by: &[u64; 4]) -> Self {
560        let mut res = Self::one();
561        for limb in by.iter().rev() {
562            for i in (0..64).rev() {
563                res = res.square();
564                if ((limb >> i) & 1) == 1 {
565                    res *= self;
566                }
567            }
568        }
569        res
570    }
571
572    /// Computes the square root of this scalar using Tonelli-Shanks.
573    /// Returns `Some(s)` with `s^2 = self` if a square root exists, else `None`.
574    pub fn sqrt(&self) -> subtle::CtOption<Self> {
575        // Trivial case: sqrt(0) = 0
576        if bool::from(self.is_zero()) {
577            return subtle::CtOption::new(Scalar::zero(), subtle::Choice::from(1));
578        }
579
580        // Choose a fixed quadratic non-residue. For this field, 5 works.
581        let z = Scalar::from(5u64);
582
583        // Precompute values per Tonelli-Shanks
584        let mut c = z.pow_vartime(&TONELLI_T); // c = z^T
585        let mut t = self.pow_vartime(&TONELLI_T); // t = a^T
586        let mut x = self.pow_vartime(&TONELLI_TP1_DIV2); // x = a^((T+1)/2)
587        let mut m = S;
588
589        // If t == 1, we guessed the root correctly.
590        if bool::from(subtle::ConstantTimeEq::ct_eq(&t, &Scalar::one())) {
591            return subtle::CtOption::new(x, subtle::ConstantTimeEq::ct_eq(&x.square(), self));
592        }
593
594        // Main Tonelli-Shanks loop
595        loop {
596            // Find smallest i in [1, m) with t^(2^i) == 1
597            let mut i = 1u32;
598            let mut t2i = t.square();
599            while i < m && !bool::from(subtle::ConstantTimeEq::ct_eq(&t2i, &Scalar::one())) {
600                t2i = t2i.square();
601                i += 1;
602            }
603
604            // If i == m, then a is not a square root
605            if i == m {
606                return subtle::CtOption::new(Scalar::zero(), subtle::Choice::from(0));
607            }
608
609            // b = c^{2^(m - i - 1)}
610            let b = Scalar::pow2k(c, m - i - 1);
611
612            // Update variables
613            x = x * b;
614            let b2 = b.square();
615            t = t * b2;
616            c = b2;
617            m = i;
618
619            // If t is now 1, we are done
620            if bool::from(subtle::ConstantTimeEq::ct_eq(&t, &Scalar::one())) {
621                break;
622            }
623        }
624
625        // Final constant-time check to ensure correctness
626        subtle::CtOption::new(x, subtle::ConstantTimeEq::ct_eq(&x.square(), self))
627    }
628
629    /// Computes the multiplicative inverse of this scalar, if it is non-zero.
630    pub fn invert(&self) -> CtOption<Self> {
631        #[inline(always)]
632        fn square_assign_multi(n: &mut Scalar, num_times: usize) {
633            for _ in 0..num_times {
634                *n = n.square();
635            }
636        }
637        // Addition chain from github.com/kwantam/addchain
638        let mut t0 = self.square();
639        let mut t1 = t0 * self;
640        let mut t16 = t0.square();
641        let mut t6 = t16.square();
642        let mut t5 = t6 * t0;
643        t0 = t6 * t16;
644        let mut t12 = t5 * t16;
645        let mut t2 = t6.square();
646        let mut t7 = t5 * t6;
647        let mut t15 = t0 * t5;
648        let mut t17 = t12.square();
649        t1 *= t17;
650        let mut t3 = t7 * t2;
651        let t8 = t1 * t17;
652        let t4 = t8 * t2;
653        let t9 = t8 * t7;
654        t7 = t4 * t5;
655        let t11 = t4 * t17;
656        t5 = t9 * t17;
657        let t14 = t7 * t15;
658        let t13 = t11 * t12;
659        t12 = t11 * t17;
660        t15 *= &t12;
661        t16 *= &t15;
662        t3 *= &t16;
663        t17 *= &t3;
664        t0 *= &t17;
665        t6 *= &t0;
666        t2 *= &t6;
667        square_assign_multi(&mut t0, 8);
668        t0 *= &t17;
669        square_assign_multi(&mut t0, 9);
670        t0 *= &t16;
671        square_assign_multi(&mut t0, 9);
672        t0 *= &t15;
673        square_assign_multi(&mut t0, 9);
674        t0 *= &t15;
675        square_assign_multi(&mut t0, 7);
676        t0 *= &t14;
677        square_assign_multi(&mut t0, 7);
678        t0 *= &t13;
679        square_assign_multi(&mut t0, 10);
680        t0 *= &t12;
681        square_assign_multi(&mut t0, 9);
682        t0 *= &t11;
683        square_assign_multi(&mut t0, 8);
684        t0 *= &t8;
685        square_assign_multi(&mut t0, 8);
686        t0 *= self;
687        square_assign_multi(&mut t0, 14);
688        t0 *= &t9;
689        square_assign_multi(&mut t0, 10);
690        t0 *= &t8;
691        square_assign_multi(&mut t0, 15);
692        t0 *= &t7;
693        square_assign_multi(&mut t0, 10);
694        t0 *= &t6;
695        square_assign_multi(&mut t0, 8);
696        t0 *= &t5;
697        square_assign_multi(&mut t0, 16);
698        t0 *= &t3;
699        square_assign_multi(&mut t0, 8);
700        t0 *= &t2;
701        square_assign_multi(&mut t0, 7);
702        t0 *= &t4;
703        square_assign_multi(&mut t0, 9);
704        t0 *= &t2;
705        square_assign_multi(&mut t0, 8);
706        t0 *= &t3;
707        square_assign_multi(&mut t0, 8);
708        t0 *= &t2;
709        square_assign_multi(&mut t0, 8);
710        t0 *= &t2;
711        square_assign_multi(&mut t0, 8);
712        t0 *= &t2;
713        square_assign_multi(&mut t0, 8);
714        t0 *= &t3;
715        square_assign_multi(&mut t0, 8);
716        t0 *= &t2;
717        square_assign_multi(&mut t0, 8);
718        t0 *= &t2;
719        square_assign_multi(&mut t0, 5);
720        t0 *= &t1;
721        square_assign_multi(&mut t0, 5);
722        t0 *= &t1;
723
724        CtOption::new(t0, !subtle::ConstantTimeEq::ct_eq(self, &Self::zero()))
725    }
726
727    #[inline(always)]
728    const fn montgomery_reduce(
729        r0: u64,
730        r1: u64,
731        r2: u64,
732        r3: u64,
733        r4: u64,
734        r5: u64,
735        r6: u64,
736        r7: u64,
737    ) -> Self {
738        let k = r0.wrapping_mul(INV);
739        let (_, carry) = mac(r0, k, MODULUS.0[0], 0);
740        let (r1, carry) = mac(r1, k, MODULUS.0[1], carry);
741        let (r2, carry) = mac(r2, k, MODULUS.0[2], carry);
742        let (r3, carry) = mac(r3, k, MODULUS.0[3], carry);
743        let (r4, carry2) = adc(r4, 0, carry);
744
745        let k = r1.wrapping_mul(INV);
746        let (_, carry) = mac(r1, k, MODULUS.0[0], 0);
747        let (r2, carry) = mac(r2, k, MODULUS.0[1], carry);
748        let (r3, carry) = mac(r3, k, MODULUS.0[2], carry);
749        let (r4, carry) = mac(r4, k, MODULUS.0[3], carry);
750        let (r5, carry2) = adc(r5, carry2, carry);
751
752        let k = r2.wrapping_mul(INV);
753        let (_, carry) = mac(r2, k, MODULUS.0[0], 0);
754        let (r3, carry) = mac(r3, k, MODULUS.0[1], carry);
755        let (r4, carry) = mac(r4, k, MODULUS.0[2], carry);
756        let (r5, carry) = mac(r5, k, MODULUS.0[3], carry);
757        let (r6, carry2) = adc(r6, carry2, carry);
758
759        let k = r3.wrapping_mul(INV);
760        let (_, carry) = mac(r3, k, MODULUS.0[0], 0);
761        let (r4, carry) = mac(r4, k, MODULUS.0[1], carry);
762        let (r5, carry) = mac(r5, k, MODULUS.0[2], carry);
763        let (r6, carry) = mac(r6, k, MODULUS.0[3], carry);
764        let (r7, _) = adc(r7, carry2, carry);
765
766        (&Scalar([r4, r5, r6, r7])).sub(&MODULUS)
767    }
768
769    /// Multiplies this scalar by another.
770    #[inline]
771    pub const fn mul(&self, rhs: &Self) -> Self {
772        let (r0, carry) = mac(0, self.0[0], rhs.0[0], 0);
773        let (r1, carry) = mac(0, self.0[0], rhs.0[1], carry);
774        let (r2, carry) = mac(0, self.0[0], rhs.0[2], carry);
775        let (r3, r4) = mac(0, self.0[0], rhs.0[3], carry);
776
777        let (r1, carry) = mac(r1, self.0[1], rhs.0[0], 0);
778        let (r2, carry) = mac(r2, self.0[1], rhs.0[1], carry);
779        let (r3, carry) = mac(r3, self.0[1], rhs.0[2], carry);
780        let (r4, r5) = mac(r4, self.0[1], rhs.0[3], carry);
781
782        let (r2, carry) = mac(r2, self.0[2], rhs.0[0], 0);
783        let (r3, carry) = mac(r3, self.0[2], rhs.0[1], carry);
784        let (r4, carry) = mac(r4, self.0[2], rhs.0[2], carry);
785        let (r5, r6) = mac(r5, self.0[2], rhs.0[3], carry);
786
787        let (r3, carry) = mac(r3, self.0[3], rhs.0[0], 0);
788        let (r4, carry) = mac(r4, self.0[3], rhs.0[1], carry);
789        let (r5, carry) = mac(r5, self.0[3], rhs.0[2], carry);
790        let (r6, r7) = mac(r6, self.0[3], rhs.0[3], carry);
791
792        Scalar::montgomery_reduce(r0, r1, r2, r3, r4, r5, r6, r7)
793    }
794
795    /// Subtracts another scalar from this one.
796    #[inline]
797    pub const fn sub(&self, rhs: &Self) -> Self {
798        let (d0, borrow) = sbb(self.0[0], rhs.0[0], 0);
799        let (d1, borrow) = sbb(self.0[1], rhs.0[1], borrow);
800        let (d2, borrow) = sbb(self.0[2], rhs.0[2], borrow);
801        let (d3, borrow) = sbb(self.0[3], rhs.0[3], borrow);
802
803        let (d0, carry) = adc(d0, MODULUS.0[0] & borrow, 0);
804        let (d1, carry) = adc(d1, MODULUS.0[1] & borrow, carry);
805        let (d2, carry) = adc(d2, MODULUS.0[2] & borrow, carry);
806        let (d3, _) = adc(d3, MODULUS.0[3] & borrow, carry);
807
808        Scalar([d0, d1, d2, d3])
809    }
810
811    /// Adds another scalar to this one.
812    #[inline]
813    pub const fn add(&self, rhs: &Self) -> Self {
814        let (d0, carry) = adc(self.0[0], rhs.0[0], 0);
815        let (d1, carry) = adc(self.0[1], rhs.0[1], carry);
816        let (d2, carry) = adc(self.0[2], rhs.0[2], carry);
817        let (d3, _) = adc(self.0[3], rhs.0[3], carry);
818
819        (&Scalar([d0, d1, d2, d3])).sub(&MODULUS)
820    }
821
822    /// Computes the additive negation of this scalar.
823    #[inline]
824    pub const fn neg(&self) -> Self {
825        let (d0, borrow) = sbb(MODULUS.0[0], self.0[0], 0);
826        let (d1, borrow) = sbb(MODULUS.0[1], self.0[1], borrow);
827        let (d2, borrow) = sbb(MODULUS.0[2], self.0[2], borrow);
828        let (d3, _) = sbb(MODULUS.0[3], self.0[3], borrow);
829
830        let mask = (((self.0[0] | self.0[1] | self.0[2] | self.0[3]) == 0) as u64).wrapping_sub(1);
831
832        Scalar([d0 & mask, d1 & mask, d2 & mask, d3 & mask])
833    }
834}
835
836impl From<Scalar> for [u8; 32] {
837    fn from(value: Scalar) -> [u8; 32] {
838        value.to_bytes()
839    }
840}
841
842impl<'a> From<&'a Scalar> for [u8; 32] {
843    fn from(value: &'a Scalar) -> [u8; 32] {
844        value.to_bytes()
845    }
846}
847
848impl<T> core::iter::Sum<T> for Scalar
849where
850    T: core::borrow::Borrow<Scalar>,
851{
852    fn sum<I>(iter: I) -> Self
853    where
854        I: Iterator<Item = T>,
855    {
856        iter.fold(Self::zero(), |acc, item| acc + item.borrow())
857    }
858}
859
860impl<T> core::iter::Product<T> for Scalar
861where
862    T: core::borrow::Borrow<Scalar>,
863{
864    fn product<I>(iter: I) -> Self
865    where
866        I: Iterator<Item = T>,
867    {
868        iter.fold(Self::one(), |acc, item| acc * item.borrow())
869    }
870}
871
// Tests
#[test]
fn test_inv() {
    // After k rounds of `x <- x^2 * q0` we hold q0^(2^k - 1). With 63 rounds that is
    // q0^{-1} mod 2^64, because the order of any odd number modulo 2^64 divides 2^62.
    let q0_inverse = (0..63).fold(1u64, |acc: u64, _| {
        acc.wrapping_mul(acc).wrapping_mul(MODULUS.0[0])
    });
    assert_eq!(q0_inverse.wrapping_neg(), INV);
}
884
#[cfg(feature = "std")]
#[test]
fn test_debug() {
    let zero_hex = "0x0000000000000000000000000000000000000000000000000000000000000000";
    let one_hex = "0x0000000000000000000000000000000000000000000000000000000000000001";

    assert_eq!(format!("{:?}", Scalar::zero()), zero_hex);
    assert_eq!(format!("{:?}", Scalar::one()), one_hex);
    // R is 1 in Montgomery form; Debug converts out of Montgomery form, so it prints 1.
    assert_eq!(format!("{:?}", R), one_hex);
}
903
#[test]
fn test_equality() {
    let zero = Scalar::zero();
    let one = Scalar::one();

    assert_eq!(zero, Scalar::zero());
    assert_eq!(one, Scalar::one());
    #[allow(clippy::eq_op)]
    {
        assert_eq!(R2, R2);
    }

    assert_ne!(zero, one);
    assert_ne!(one, R2);
}
916
#[test]
fn test_to_bytes() {
    let mut one_le = [0u8; 32];
    one_le[0] = 1;

    assert_eq!(Scalar::zero().to_bytes(), [0u8; 32]);
    assert_eq!(Scalar::one().to_bytes(), one_le);

    // R is the Montgomery form of 1; to_bytes() converts back, so it also encodes 1.
    assert_eq!(R.to_bytes(), one_le);

    // -1 is q - 1, encoded little-endian.
    let minus_one_le: [u8; 32] = [
        0, 0, 0, 0, 255, 255, 255, 255, 254, 91, 254, 255, 2, 164, 189, 83, 5, 216, 161, 9, 8,
        216, 57, 51, 72, 125, 157, 41, 83, 167, 237, 115,
    ];
    assert_eq!((-&Scalar::one()).to_bytes(), minus_one_le);
}
953
#[test]
fn test_from_bytes() {
    let mut a = R2;

    for _ in 0..100 {
        // Round-trip the value and then its negation through the canonical encoding.
        for candidate in [a, -a] {
            let decoded = Scalar::from_bytes(&candidate.to_bytes()).unwrap();
            assert_eq!(candidate, decoded);
        }

        a = a.square();
    }
}
971
// Limbs of q - 1, the largest reduced value (compare MODULUS, whose low limb ends in 1);
// test_addition uses it to exercise wrap-around at the modulus.
#[cfg(test)]
const LARGEST: Scalar = Scalar([
    0xffff_ffff_0000_0000,
    0x53bd_a402_fffe_5bfe,
    0x3339_d808_09a1_d805,
    0x73ed_a753_299d_7d48,
]);
979
#[test]
fn test_addition() {
    // (q - 1) + (q - 1) = 2q - 2, which reduces to q - 2.
    let doubled = LARGEST + LARGEST;
    let q_minus_two = Scalar([
        0xffff_fffe_ffff_ffff,
        0x53bd_a402_fffe_5bfe,
        0x3339_d808_09a1_d805,
        0x73ed_a753_299d_7d48,
    ]);
    assert_eq!(doubled, q_minus_two);

    // (q - 1) + 1 wraps around to zero.
    let wrapped = LARGEST + Scalar([1, 0, 0, 0]);
    assert_eq!(wrapped, Scalar::zero());
}
1000
#[test]
fn test_inversion() {
    assert!(bool::from(Scalar::zero().invert().is_none()));

    let one = Scalar::one();
    let minus_one = -&one;
    assert_eq!(one.invert().unwrap(), one);
    assert_eq!(minus_one.invert().unwrap(), minus_one);

    let mut x = R2;
    for _ in 0..100 {
        let product = x.invert().unwrap() * x;
        assert_eq!(product, Scalar::one());
        x += &R2;
    }
}
1018
#[test]
fn test_sqrt() {
    // In a prime field the only square roots of x^2 are x and -x.
    let is_plus_or_minus = |candidate: Scalar, x: Scalar| candidate == x || candidate == -x;

    // Zero and one are their own square roots.
    assert_eq!(Scalar::zero().sqrt().unwrap(), Scalar::zero());
    assert_eq!(Scalar::one().sqrt().unwrap(), Scalar::one());

    // A small perfect square: sqrt(4) must be +/-2.
    let four = Scalar::from(4u64);
    let root_of_four = four.sqrt().unwrap();
    assert!(is_plus_or_minus(root_of_four, Scalar::from(2u64)));
    assert_eq!(root_of_four.square(), four);

    // An arbitrary square: sqrt(s^2) must be +/-s.
    let s = Scalar::from(123456789u64);
    let s_squared = s.square();
    let root_of_s_squared = s_squared.sqrt().unwrap();
    assert!(is_plus_or_minus(root_of_s_squared, s));
    assert_eq!(root_of_s_squared.square(), s_squared);

    // 5 is a quadratic non-residue mod q, so it has no square root.
    assert!(bool::from(Scalar::from(5u64).sqrt().is_none()));

    // q = 1 (mod 4), hence -1 is a quadratic residue.
    let minus_one = -Scalar::one();
    assert_eq!(minus_one.sqrt().unwrap().square(), minus_one);

    // Round trip sqrt(x^2) = +/-x, starting at R2 and adding R (the Montgomery
    // encoding of one) after each step.
    let mut x = R2;
    for _ in 0..100 {
        let root = x.square().sqrt().unwrap();
        assert!(is_plus_or_minus(root, x));
        x += R;
    }
}
1063
#[test]
fn test_from_raw() {
    // `from_raw` reduces modulo q: 2^256 - 1 must equal 2^256 - 1 - 2q.
    let all_ones = Scalar::from_raw([u64::MAX; 4]);
    let reduced = Scalar::from_raw([
        0x0000_0001_ffff_fffd,
        0x5884_b7fa_0003_4802,
        0x998c_4fef_ecbc_4ff5,
        0x1824_b159_acc5_056f,
    ]);
    assert_eq!(reduced, all_ones);

    // The modulus itself reduces to zero.
    assert_eq!(Scalar::from_raw(MODULUS.0), Scalar::zero());

    // The raw integer 1 becomes R, the Montgomery encoding of one.
    assert_eq!(Scalar::from_raw([1, 0, 0, 0]), R);
}
1080
#[test]
fn test_scalar_hash_to_field() {
    let data1 = b"some input data";
    let data2 = b"different input data";
    let dst1 = b"BLS_SIG_BLS12381G2_XMD:SHA-256_SSWU_RO_NUL_"; // Standard DST format
    let dst2 = b"BLS_SIG_BLS12381G2_XMD:SHA-256_SSWU_RO_POP_";

    // 1. Different data should produce different scalars
    let s1 = Scalar::hash_to_field(data1, dst1).unwrap();
    let s2 = Scalar::hash_to_field(data2, dst1).unwrap();
    assert_ne!(s1, s2);

    // 2. Same data with different DSTs should produce different scalars
    let s3 = Scalar::hash_to_field(data1, dst1).unwrap();
    let s4 = Scalar::hash_to_field(data1, dst2).unwrap();
    assert_ne!(s3, s4);

    // 3. Hashing should be deterministic
    let s5 = Scalar::hash_to_field(data1, dst1).unwrap();
    assert_eq!(s3, s5);

    // 4. Output must always be a canonical (fully reduced) scalar: a byte round
    //    trip through `from_bytes` only succeeds for values below the modulus.
    for test_case in &[b"" as &[u8], b"a", b"test", &[0xFF; 100], &[0x00; 64]] {
        let scalar = Scalar::hash_to_field(test_case, dst1).unwrap();
        let bytes = scalar.to_bytes();
        let scalar2 = Scalar::from_bytes(&bytes).unwrap();
        assert_eq!(scalar, scalar2, "Output should be a valid reduced scalar");
    }

    // 5. Collision smoke test: 100 distinct inputs must map to 100 distinct
    //    scalars. This does NOT measure bias; it only catches gross failures such
    //    as ignoring the message or truncating the output.
    let scalars: Vec<Scalar> = (0u32..100)
        .map(|i| Scalar::hash_to_field(&i.to_le_bytes(), dst1).unwrap())
        .collect();
    for (i, a) in scalars.iter().enumerate() {
        for (j, b) in scalars.iter().enumerate().skip(i + 1) {
            assert_ne!(a, b, "Unexpected collision at {} and {}", i, j);
        }
    }

    // 6. Empty data with an empty DST must still be deterministic.
    //    NOTE(review): RFC 9380 section 3.1 requires DSTs to have nonzero length;
    //    this asserts that the implementation currently accepts an empty DST.
    //    Confirm that this is intended.
    let s_empty = Scalar::hash_to_field(b"", b"").unwrap();
    let s_empty2 = Scalar::hash_to_field(b"", b"").unwrap();
    assert_eq!(s_empty, s_empty2, "Empty input should still be deterministic");

    // 7. DSTs of different lengths must give different outputs. The two tags also
    //    differ in content, so this alone cannot prove that the DST length is mixed
    //    into the hash. The long tag is built from a repeat expression so that its
    //    length (50) is exact rather than hand-counted.
    let dst_short = b"A";
    let dst_long = [b'A'; 50];
    let s_short = Scalar::hash_to_field(data1, dst_short).unwrap();
    let s_long = Scalar::hash_to_field(data1, &dst_long).unwrap();
    assert_ne!(s_short, s_long, "DST length should affect output");

    // 8. Sanity check that the low bit of the canonical encoding varies across
    //    inputs. This is not a uniformity test; it only rules out outputs that are
    //    always even or always odd.
    let mut has_odd = false;
    let mut has_even = false;
    for i in 0u8..20 {
        let s = Scalar::hash_to_field(&[i], dst1).unwrap();
        // Least significant bit of the little-endian canonical bytes.
        if s.to_bytes()[0] & 1 == 0 {
            has_even = true;
        } else {
            has_odd = true;
        }
    }
    assert!(has_odd && has_even, "Hash output should have both odd and even values");

    // 9. expand_message_xmd sanity checks: requested output length, message
    //    sensitivity and determinism. These are NOT conformance vectors.
    //    NOTE(review): consider pinning the RFC 9380 Appendix K.1 SHA-256 vectors
    //    (DST "QUUX-V01-CS02-with-expander-SHA256-128") to prove conformance.
    let expanded = Scalar::expand_message_xmd(b"", b"QUUX-V01-CS02-with-SHA256", 32).unwrap();
    assert_eq!(expanded.len(), 32);

    let expanded1 = Scalar::expand_message_xmd(b"msg1", b"dst", 64).unwrap();
    let expanded2 = Scalar::expand_message_xmd(b"msg2", b"dst", 64).unwrap();
    assert_eq!(expanded1.len(), 64);
    assert_eq!(expanded2.len(), 64);
    assert_ne!(expanded1, expanded2);

    let expanded1_again = Scalar::expand_message_xmd(b"msg1", b"dst", 64).unwrap();
    assert_eq!(expanded1, expanded1_again, "Expansion should be deterministic");
}
1167
#[cfg(feature = "zeroize")]
#[test]
fn test_zeroize() {
    use zeroize::Zeroize;

    // Start from an arbitrary non-zero value; zeroizing must wipe it to zero.
    let mut secret = Scalar::from_raw([
        0x1fff_3231_233f_fffd,
        0x4884_b7fa_0003_4802,
        0x998c_4fef_ecbc_4ff3,
        0x1824_b159_acc5_0562,
    ]);
    secret.zeroize();

    // Call subtle's `ct_eq` through a fully qualified path. The file also imports
    // dcrypt's `ConstantTimeEq` (as `DcryptConstantTimeEq`), which could make a
    // bare `.ct_eq()` method call ambiguous.
    let is_zero: bool = <Scalar as subtle::ConstantTimeEq>::ct_eq(&secret, &Scalar::zero()).into();
    assert!(is_zero);
}
1185}