Skip to main content

dcrypt_algorithms/ec/bls12_381/
scalar.rs

1//! BLS12-381 scalar field F_q where q = 0x73eda753299d7d483339d80809a1d80553bda402fffe5bfeffffffff00000001
2
3use crate::error::{Error, Result};
4use crate::hash::{sha2::Sha256, HashFunction};
5use crate::types::{ByteSerializable, ConstantTimeEq as DcryptConstantTimeEq, SecureZeroingType};
6use core::fmt;
7use core::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};
8use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption};
9
// Arithmetic helpers

/// Adds `a + b + carry` with full carry propagation.
///
/// Returns `(low, high)`: the low 64 bits of the sum and the bits above them
/// (the carry-out, at most 2 when all three inputs are `u64::MAX`).
#[inline(always)]
const fn adc(a: u64, b: u64, carry: u64) -> (u64, u64) {
    let wide = a as u128 + b as u128 + carry as u128;
    let low = wide as u64;
    let high = (wide >> 64) as u64;
    (low, high)
}
17
/// Computes `a - (b + borrow_bit)`, where the incoming borrow is a mask.
///
/// `borrow` is expected to be either 0 or all-ones; only its top bit is used.
/// Returns `(difference, borrow_out)`. `borrow_out` is `u64::MAX` if the
/// subtraction underflowed and 0 otherwise, which is the same mask format, so
/// calls can be chained limb by limb.
#[inline(always)]
const fn sbb(a: u64, b: u64, borrow: u64) -> (u64, u64) {
    let borrow_bit = (borrow >> 63) as u128;
    let wide = (a as u128).wrapping_sub(b as u128 + borrow_bit);
    (wide as u64, (wide >> 64) as u64)
}
24
/// Multiply-accumulate: computes `a + b * c + carry`.
///
/// Returns `(low, high)`. The largest possible result is exactly `2^128 - 1`,
/// so the 128-bit intermediate never overflows.
#[inline(always)]
const fn mac(a: u64, b: u64, c: u64, carry: u64) -> (u64, u64) {
    let product = b as u128 * c as u128;
    let wide = product + a as u128 + carry as u128;
    (wide as u64, (wide >> 64) as u64)
}
31
32/// Scalar field element of BLS12-381
33/// Internal: Four 64-bit limbs in little-endian Montgomery form
34#[derive(Clone, Copy, Eq)]
35pub struct Scalar(pub(crate) [u64; 4]);
36
37impl fmt::Debug for Scalar {
38    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
39        let tmp = self.to_bytes();
40        write!(f, "0x")?;
41        for &b in tmp.iter().rev() {
42            write!(f, "{:02x}", b)?;
43        }
44        Ok(())
45    }
46}
47
48impl fmt::Display for Scalar {
49    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
50        write!(f, "{:?}", self)
51    }
52}
53
54impl From<u64> for Scalar {
55    fn from(val: u64) -> Scalar {
56        Scalar([val, 0, 0, 0]) * R2
57    }
58}
59
60impl ConstantTimeEq for Scalar {
61    fn ct_eq(&self, other: &Self) -> Choice {
62        self.0[0].ct_eq(&other.0[0])
63            & self.0[1].ct_eq(&other.0[1])
64            & self.0[2].ct_eq(&other.0[2])
65            & self.0[3].ct_eq(&other.0[3])
66    }
67}
68
69impl PartialEq for Scalar {
70    #[inline]
71    fn eq(&self, other: &Self) -> bool {
72        bool::from(subtle::ConstantTimeEq::ct_eq(self, other))
73    }
74}
75
76impl ConditionallySelectable for Scalar {
77    fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
78        Scalar([
79            u64::conditional_select(&a.0[0], &b.0[0], choice),
80            u64::conditional_select(&a.0[1], &b.0[1], choice),
81            u64::conditional_select(&a.0[2], &b.0[2], choice),
82            u64::conditional_select(&a.0[3], &b.0[3], choice),
83        ])
84    }
85}
86
87// Constants
88const MODULUS: Scalar = Scalar([
89    0xffff_ffff_0000_0001,
90    0x53bd_a402_fffe_5bfe,
91    0x3339_d808_09a1_d805,
92    0x73ed_a753_299d_7d48,
93]);
94
95/// INV = -(q^{-1} mod 2^64) mod 2^64
96const INV: u64 = 0xffff_fffe_ffff_ffff;
97
98/// R = 2^256 mod q
99const R: Scalar = Scalar([
100    0x0000_0001_ffff_fffe,
101    0x5884_b7fa_0003_4802,
102    0x998c_4fef_ecbc_4ff5,
103    0x1824_b159_acc5_056f,
104]);
105
106/// R^2 = 2^512 mod q
107const R2: Scalar = Scalar([
108    0xc999_e990_f3f2_9c6d,
109    0x2b6c_edcb_8792_5c23,
110    0x05d3_1496_7254_398f,
111    0x0748_d9d9_9f59_ff11,
112]);
113
114/// R^3 = 2^768 mod q
115const R3: Scalar = Scalar([
116    0xc62c_1807_439b_73af,
117    0x1b3e_0d18_8cf0_6990,
118    0x73d1_3c71_c7b5_f418,
119    0x6e2a_5bb9_c8db_33e9,
120]);
121
122// Constants for Tonelli-Shanks square root algorithm
123// 2-adicity of (r - 1)
124const S: u32 = 32;
125
126// T = (r - 1) / 2^S  (odd part)
127const TONELLI_T: [u64; 4] = [
128    0xfffe_5bfe_ffff_ffff,
129    0x09a1_d805_53bd_a402,
130    0x299d_7d48_3339_d808,
131    0x0000_0000_73ed_a753,
132];
133
134// (T + 1)/2, used to initialize x = a^((T+1)/2)
135const TONELLI_TP1_DIV2: [u64; 4] = [
136    0x7fff_2dff_8000_0000,
137    0x04d0_ec02_a9de_d201,
138    0x94ce_bea4_199c_ec04,
139    0x0000_0000_39f6_d3a9,
140];
141
142// Exponent (r-1)/2, the Legendre exponent
143#[allow(dead_code)]
144const LEGENDRE_EXP: [u64; 4] = [
145    0x7fff_ffff_8000_0000,
146    0xa9de_d201_7fff_2dff,
147    0x199c_ec04_04d0_ec02,
148    0x39f6_d3a9_94ce_bea4,
149];
150
151impl<'a> Neg for &'a Scalar {
152    type Output = Scalar;
153
154    #[inline]
155    fn neg(self) -> Scalar {
156        self.neg()
157    }
158}
159
160impl Neg for Scalar {
161    type Output = Scalar;
162
163    #[inline]
164    fn neg(self) -> Scalar {
165        -&self
166    }
167}
168
169impl<'a, 'b> Sub<&'b Scalar> for &'a Scalar {
170    type Output = Scalar;
171
172    #[inline]
173    fn sub(self, rhs: &'b Scalar) -> Scalar {
174        self.sub(rhs)
175    }
176}
177
178impl<'a, 'b> Add<&'b Scalar> for &'a Scalar {
179    type Output = Scalar;
180
181    #[inline]
182    fn add(self, rhs: &'b Scalar) -> Scalar {
183        self.add(rhs)
184    }
185}
186
187impl<'a, 'b> Mul<&'b Scalar> for &'a Scalar {
188    type Output = Scalar;
189
190    #[inline]
191    fn mul(self, rhs: &'b Scalar) -> Scalar {
192        self.mul(rhs)
193    }
194}
195
196// Binop implementations
197impl<'b> Add<&'b Scalar> for Scalar {
198    type Output = Scalar;
199    #[inline]
200    fn add(self, rhs: &'b Scalar) -> Scalar {
201        &self + rhs
202    }
203}
204
205impl<'a> Add<Scalar> for &'a Scalar {
206    type Output = Scalar;
207    #[inline]
208    fn add(self, rhs: Scalar) -> Scalar {
209        self + &rhs
210    }
211}
212
213impl Add<Scalar> for Scalar {
214    type Output = Scalar;
215    #[inline]
216    fn add(self, rhs: Scalar) -> Scalar {
217        &self + &rhs
218    }
219}
220
221impl<'b> Sub<&'b Scalar> for Scalar {
222    type Output = Scalar;
223    #[inline]
224    fn sub(self, rhs: &'b Scalar) -> Scalar {
225        &self - rhs
226    }
227}
228
229impl<'a> Sub<Scalar> for &'a Scalar {
230    type Output = Scalar;
231    #[inline]
232    fn sub(self, rhs: Scalar) -> Scalar {
233        self - &rhs
234    }
235}
236
237impl Sub<Scalar> for Scalar {
238    type Output = Scalar;
239    #[inline]
240    fn sub(self, rhs: Scalar) -> Scalar {
241        &self - &rhs
242    }
243}
244
245impl SubAssign<Scalar> for Scalar {
246    #[inline]
247    fn sub_assign(&mut self, rhs: Scalar) {
248        *self = &*self - &rhs;
249    }
250}
251
252impl AddAssign<Scalar> for Scalar {
253    #[inline]
254    fn add_assign(&mut self, rhs: Scalar) {
255        *self = &*self + &rhs;
256    }
257}
258
259impl<'b> SubAssign<&'b Scalar> for Scalar {
260    #[inline]
261    fn sub_assign(&mut self, rhs: &'b Scalar) {
262        *self = &*self - rhs;
263    }
264}
265
266impl<'b> AddAssign<&'b Scalar> for Scalar {
267    #[inline]
268    fn add_assign(&mut self, rhs: &'b Scalar) {
269        *self = &*self + rhs;
270    }
271}
272
273impl<'b> Mul<&'b Scalar> for Scalar {
274    type Output = Scalar;
275    #[inline]
276    fn mul(self, rhs: &'b Scalar) -> Scalar {
277        &self * rhs
278    }
279}
280
281impl<'a> Mul<Scalar> for &'a Scalar {
282    type Output = Scalar;
283    #[inline]
284    fn mul(self, rhs: Scalar) -> Scalar {
285        self * &rhs
286    }
287}
288
289impl Mul<Scalar> for Scalar {
290    type Output = Scalar;
291    #[inline]
292    fn mul(self, rhs: Scalar) -> Scalar {
293        &self * &rhs
294    }
295}
296
297impl MulAssign<Scalar> for Scalar {
298    #[inline]
299    fn mul_assign(&mut self, rhs: Scalar) {
300        *self = &*self * &rhs;
301    }
302}
303
304impl<'b> MulAssign<&'b Scalar> for Scalar {
305    #[inline]
306    fn mul_assign(&mut self, rhs: &'b Scalar) {
307        *self = &*self * rhs;
308    }
309}
310
311impl Default for Scalar {
312    #[inline]
313    fn default() -> Self {
314        Self::zero()
315    }
316}
317
318#[cfg(feature = "zeroize")]
319impl zeroize::DefaultIsZeroes for Scalar {}
320
321impl ByteSerializable for Scalar {
322    fn to_bytes(&self) -> Vec<u8> {
323        self.to_bytes().to_vec()
324    }
325
326    fn from_bytes(bytes: &[u8]) -> Result<Self> {
327        if bytes.len() != 32 {
328            return Err(Error::Length {
329                context: "Scalar::from_bytes",
330                expected: 32,
331                actual: bytes.len(),
332            });
333        }
334
335        let mut array = [0u8; 32];
336        array.copy_from_slice(bytes);
337
338        Scalar::from_bytes(&array)
339            .into_option()
340            .ok_or_else(|| Error::param("scalar_bytes", "non-canonical scalar"))
341    }
342}
343
344impl DcryptConstantTimeEq for Scalar {
345    fn ct_eq(&self, other: &Self) -> bool {
346        bool::from(subtle::ConstantTimeEq::ct_eq(self, other))
347    }
348}
349
350impl SecureZeroingType for Scalar {
351    fn zeroed() -> Self {
352        Self::zero()
353    }
354}
355
356impl Scalar {
357    /// Additive identity
358    #[inline]
359    pub const fn zero() -> Scalar {
360        Scalar([0, 0, 0, 0])
361    }
362
363    /// Multiplicative identity
364    #[inline]
365    pub const fn one() -> Scalar {
366        R
367    }
368
369    /// Check if element is zero.
370    #[inline]
371    pub fn is_zero(&self) -> Choice {
372        (self.0[0] | self.0[1] | self.0[2] | self.0[3]).ct_eq(&0)
373    }
374
375    /// Double this element
376    #[inline]
377    pub const fn double(&self) -> Scalar {
378        self.add(self)
379    }
380
381    /// Create from little-endian bytes if canonical
382    pub fn from_bytes(bytes: &[u8; 32]) -> CtOption<Scalar> {
383        let mut tmp = Scalar([0, 0, 0, 0]);
384
385        tmp.0[0] = u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[0..8]).unwrap());
386        tmp.0[1] = u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[8..16]).unwrap());
387        tmp.0[2] = u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[16..24]).unwrap());
388        tmp.0[3] = u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[24..32]).unwrap());
389
390        // Check canonical by subtracting modulus
391        let (_, borrow) = sbb(tmp.0[0], MODULUS.0[0], 0);
392        let (_, borrow) = sbb(tmp.0[1], MODULUS.0[1], borrow);
393        let (_, borrow) = sbb(tmp.0[2], MODULUS.0[2], borrow);
394        let (_, borrow) = sbb(tmp.0[3], MODULUS.0[3], borrow);
395
396        let is_some = (borrow as u8) & 1;
397
398        // Convert to Montgomery: (a * R^2) / R = aR
399        tmp *= &R2;
400
401        CtOption::new(tmp, Choice::from(is_some))
402    }
403
404    /// Convert to little-endian bytes
405    pub fn to_bytes(&self) -> [u8; 32] {
406        // Remove Montgomery: (aR) / R = a
407        let tmp = Scalar::montgomery_reduce(self.0[0], self.0[1], self.0[2], self.0[3], 0, 0, 0, 0);
408
409        let mut res = [0; 32];
410        res[0..8].copy_from_slice(&tmp.0[0].to_le_bytes());
411        res[8..16].copy_from_slice(&tmp.0[1].to_le_bytes());
412        res[16..24].copy_from_slice(&tmp.0[2].to_le_bytes());
413        res[24..32].copy_from_slice(&tmp.0[3].to_le_bytes());
414
415        res
416    }
417
418    /// Create from 512-bit little-endian integer mod q
419    pub fn from_bytes_wide(bytes: &[u8; 64]) -> Scalar {
420        Scalar::from_u512([
421            u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[0..8]).unwrap()),
422            u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[8..16]).unwrap()),
423            u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[16..24]).unwrap()),
424            u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[24..32]).unwrap()),
425            u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[32..40]).unwrap()),
426            u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[40..48]).unwrap()),
427            u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[48..56]).unwrap()),
428            u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[56..64]).unwrap()),
429        ])
430    }
431
432    fn expand_message_xmd(msg: &[u8], dst: &[u8], len_in_bytes: usize) -> Result<Vec<u8>> {
433        const MAX_DST_LENGTH: usize = 255;
434        const HASH_OUTPUT_SIZE: usize = 32;
435
436        if dst.len() > MAX_DST_LENGTH {
437            return Err(Error::param("dst", "domain separation tag too long"));
438        }
439
440        let ell = (len_in_bytes + HASH_OUTPUT_SIZE - 1) / HASH_OUTPUT_SIZE;
441
442        if ell > 255 {
443            return Err(Error::param("len_in_bytes", "requested output too long"));
444        }
445
446        let dst_prime_len = dst.len() as u8;
447
448        let mut hasher = Sha256::new();
449        hasher.update(&[0u8; HASH_OUTPUT_SIZE])?;
450        hasher.update(msg)?;
451        hasher.update(&((len_in_bytes as u16).to_be_bytes()))?;
452        hasher.update(&[0u8])?;
453        hasher.update(dst)?;
454        hasher.update(&[dst_prime_len])?;
455
456        let b_0 = hasher.finalize()?;
457
458        let mut uniform_bytes = Vec::with_capacity(len_in_bytes);
459        let mut b_i = vec![0u8; HASH_OUTPUT_SIZE];
460
461        for i in 1..=ell {
462            let mut hasher = Sha256::new();
463            if i == 1 {
464                hasher.update(&[0u8; HASH_OUTPUT_SIZE])?;
465            } else {
466                let mut xored = [0u8; HASH_OUTPUT_SIZE];
467                for j in 0..HASH_OUTPUT_SIZE {
468                    xored[j] = b_0.as_ref()[j] ^ b_i[j];
469                }
470                hasher.update(&xored)?;
471            }
472            hasher.update(&[i as u8])?;
473            hasher.update(dst)?;
474            hasher.update(&[dst_prime_len])?;
475            let digest = hasher.finalize()?;
476            b_i.copy_from_slice(digest.as_ref());
477            uniform_bytes.extend_from_slice(&b_i);
478        }
479
480        uniform_bytes.truncate(len_in_bytes);
481        Ok(uniform_bytes)
482    }
483
484    /// Hashes arbitrary data to a scalar field element using SHA-256.
485    ///
486    /// This function implements a standards-compliant hash-to-field method following
487    /// the IETF hash-to-curve specification using expand_message_xmd with SHA-256.
488    ///
489    /// # Arguments
490    /// * `data`: The input data to hash.
491    /// * `dst`: A Domain Separation Tag (DST) to ensure hashes are unique per application context.
492    ///
493    /// # Returns
494    /// A `Result` containing the `Scalar` or an error.
495    pub fn hash_to_field(data: &[u8], dst: &[u8]) -> Result<Self> {
496        let expanded = Self::expand_message_xmd(data, dst, 64)?;
497        let mut expanded_array = [0u8; 64];
498        expanded_array.copy_from_slice(&expanded);
499        Ok(Self::from_bytes_wide(&expanded_array))
500    }
501
502    fn from_u512(limbs: [u64; 8]) -> Scalar {
503        let d0 = Scalar([limbs[0], limbs[1], limbs[2], limbs[3]]);
504        let d1 = Scalar([limbs[4], limbs[5], limbs[6], limbs[7]]);
505        d0 * R2 + d1 * R3
506    }
507
508    /// Creates a scalar from four `u64` limbs (little-endian). This function will
509    /// convert the raw integer into Montgomery form.
510    pub const fn from_raw(val: [u64; 4]) -> Self {
511        (&Scalar(val)).mul(&R2)
512    }
513
514    /// Computes the square of this scalar.
515    #[inline]
516    pub const fn square(&self) -> Scalar {
517        let (r1, carry) = mac(0, self.0[0], self.0[1], 0);
518        let (r2, carry) = mac(0, self.0[0], self.0[2], carry);
519        let (r3, r4) = mac(0, self.0[0], self.0[3], carry);
520
521        let (r3, carry) = mac(r3, self.0[1], self.0[2], 0);
522        let (r4, r5) = mac(r4, self.0[1], self.0[3], carry);
523
524        let (r5, r6) = mac(r5, self.0[2], self.0[3], 0);
525
526        let r7 = r6 >> 63;
527        let r6 = (r6 << 1) | (r5 >> 63);
528        let r5 = (r5 << 1) | (r4 >> 63);
529        let r4 = (r4 << 1) | (r3 >> 63);
530        let r3 = (r3 << 1) | (r2 >> 63);
531        let r2 = (r2 << 1) | (r1 >> 63);
532        let r1 = r1 << 1;
533
534        let (r0, carry) = mac(0, self.0[0], self.0[0], 0);
535        let (r1, carry) = adc(0, r1, carry);
536        let (r2, carry) = mac(r2, self.0[1], self.0[1], carry);
537        let (r3, carry) = adc(0, r3, carry);
538        let (r4, carry) = mac(r4, self.0[2], self.0[2], carry);
539        let (r5, carry) = adc(0, r5, carry);
540        let (r6, carry) = mac(r6, self.0[3], self.0[3], carry);
541        let (r7, _) = adc(0, r7, carry);
542
543        Scalar::montgomery_reduce(r0, r1, r2, r3, r4, r5, r6, r7)
544    }
545
546    /// Computes `x` raised to the power of `2^k`.
547    #[inline]
548    pub fn pow2k(mut x: Scalar, mut k: u32) -> Scalar {
549        while k > 0 {
550            x = x.square();
551            k -= 1;
552        }
553        x
554    }
555
556    /// Variable-time exponentiation by a 256-bit little-endian exponent.
557    fn pow_vartime(&self, by: &[u64; 4]) -> Self {
558        let mut res = Self::one();
559        for limb in by.iter().rev() {
560            for i in (0..64).rev() {
561                res = res.square();
562                if ((limb >> i) & 1) == 1 {
563                    res *= self;
564                }
565            }
566        }
567        res
568    }
569
570    /// Computes the square root of this scalar using Tonelli-Shanks.
571    /// Returns `Some(s)` with `s^2 = self` if a square root exists, else `None`.
572    pub fn sqrt(&self) -> subtle::CtOption<Self> {
573        // Trivial case: sqrt(0) = 0
574        if bool::from(self.is_zero()) {
575            return subtle::CtOption::new(Scalar::zero(), subtle::Choice::from(1));
576        }
577
578        // Choose a fixed quadratic non-residue. For this field, 5 works.
579        let z = Scalar::from(5u64);
580
581        // Precompute values per Tonelli-Shanks
582        let mut c = z.pow_vartime(&TONELLI_T); // c = z^T
583        let mut t = self.pow_vartime(&TONELLI_T); // t = a^T
584        let mut x = self.pow_vartime(&TONELLI_TP1_DIV2); // x = a^((T+1)/2)
585        let mut m = S;
586
587        // If t == 1, we guessed the root correctly.
588        if bool::from(subtle::ConstantTimeEq::ct_eq(&t, &Scalar::one())) {
589            return subtle::CtOption::new(x, subtle::ConstantTimeEq::ct_eq(&x.square(), self));
590        }
591
592        // Main Tonelli-Shanks loop
593        loop {
594            // Find smallest i in [1, m) with t^(2^i) == 1
595            let mut i = 1u32;
596            let mut t2i = t.square();
597            while i < m && !bool::from(subtle::ConstantTimeEq::ct_eq(&t2i, &Scalar::one())) {
598                t2i = t2i.square();
599                i += 1;
600            }
601
602            // If i == m, then a is not a square root
603            if i == m {
604                return subtle::CtOption::new(Scalar::zero(), subtle::Choice::from(0));
605            }
606
607            // b = c^{2^(m - i - 1)}
608            let b = Scalar::pow2k(c, m - i - 1);
609
610            // Update variables
611            x = x * b;
612            let b2 = b.square();
613            t = t * b2;
614            c = b2;
615            m = i;
616
617            // If t is now 1, we are done
618            if bool::from(subtle::ConstantTimeEq::ct_eq(&t, &Scalar::one())) {
619                break;
620            }
621        }
622
623        // Final constant-time check to ensure correctness
624        subtle::CtOption::new(x, subtle::ConstantTimeEq::ct_eq(&x.square(), self))
625    }
626
627    /// Computes the multiplicative inverse of this scalar, if it is non-zero.
628    pub fn invert(&self) -> CtOption<Self> {
629        #[inline(always)]
630        fn square_assign_multi(n: &mut Scalar, num_times: usize) {
631            for _ in 0..num_times {
632                *n = n.square();
633            }
634        }
635        // Addition chain from github.com/kwantam/addchain
636        let mut t0 = self.square();
637        let mut t1 = t0 * self;
638        let mut t16 = t0.square();
639        let mut t6 = t16.square();
640        let mut t5 = t6 * t0;
641        t0 = t6 * t16;
642        let mut t12 = t5 * t16;
643        let mut t2 = t6.square();
644        let mut t7 = t5 * t6;
645        let mut t15 = t0 * t5;
646        let mut t17 = t12.square();
647        t1 *= t17;
648        let mut t3 = t7 * t2;
649        let t8 = t1 * t17;
650        let t4 = t8 * t2;
651        let t9 = t8 * t7;
652        t7 = t4 * t5;
653        let t11 = t4 * t17;
654        t5 = t9 * t17;
655        let t14 = t7 * t15;
656        let t13 = t11 * t12;
657        t12 = t11 * t17;
658        t15 *= &t12;
659        t16 *= &t15;
660        t3 *= &t16;
661        t17 *= &t3;
662        t0 *= &t17;
663        t6 *= &t0;
664        t2 *= &t6;
665        square_assign_multi(&mut t0, 8);
666        t0 *= &t17;
667        square_assign_multi(&mut t0, 9);
668        t0 *= &t16;
669        square_assign_multi(&mut t0, 9);
670        t0 *= &t15;
671        square_assign_multi(&mut t0, 9);
672        t0 *= &t15;
673        square_assign_multi(&mut t0, 7);
674        t0 *= &t14;
675        square_assign_multi(&mut t0, 7);
676        t0 *= &t13;
677        square_assign_multi(&mut t0, 10);
678        t0 *= &t12;
679        square_assign_multi(&mut t0, 9);
680        t0 *= &t11;
681        square_assign_multi(&mut t0, 8);
682        t0 *= &t8;
683        square_assign_multi(&mut t0, 8);
684        t0 *= self;
685        square_assign_multi(&mut t0, 14);
686        t0 *= &t9;
687        square_assign_multi(&mut t0, 10);
688        t0 *= &t8;
689        square_assign_multi(&mut t0, 15);
690        t0 *= &t7;
691        square_assign_multi(&mut t0, 10);
692        t0 *= &t6;
693        square_assign_multi(&mut t0, 8);
694        t0 *= &t5;
695        square_assign_multi(&mut t0, 16);
696        t0 *= &t3;
697        square_assign_multi(&mut t0, 8);
698        t0 *= &t2;
699        square_assign_multi(&mut t0, 7);
700        t0 *= &t4;
701        square_assign_multi(&mut t0, 9);
702        t0 *= &t2;
703        square_assign_multi(&mut t0, 8);
704        t0 *= &t3;
705        square_assign_multi(&mut t0, 8);
706        t0 *= &t2;
707        square_assign_multi(&mut t0, 8);
708        t0 *= &t2;
709        square_assign_multi(&mut t0, 8);
710        t0 *= &t2;
711        square_assign_multi(&mut t0, 8);
712        t0 *= &t3;
713        square_assign_multi(&mut t0, 8);
714        t0 *= &t2;
715        square_assign_multi(&mut t0, 8);
716        t0 *= &t2;
717        square_assign_multi(&mut t0, 5);
718        t0 *= &t1;
719        square_assign_multi(&mut t0, 5);
720        t0 *= &t1;
721
722        CtOption::new(t0, !subtle::ConstantTimeEq::ct_eq(self, &Self::zero()))
723    }
724
725    #[inline(always)]
726    const fn montgomery_reduce(
727        r0: u64,
728        r1: u64,
729        r2: u64,
730        r3: u64,
731        r4: u64,
732        r5: u64,
733        r6: u64,
734        r7: u64,
735    ) -> Self {
736        let k = r0.wrapping_mul(INV);
737        let (_, carry) = mac(r0, k, MODULUS.0[0], 0);
738        let (r1, carry) = mac(r1, k, MODULUS.0[1], carry);
739        let (r2, carry) = mac(r2, k, MODULUS.0[2], carry);
740        let (r3, carry) = mac(r3, k, MODULUS.0[3], carry);
741        let (r4, carry2) = adc(r4, 0, carry);
742
743        let k = r1.wrapping_mul(INV);
744        let (_, carry) = mac(r1, k, MODULUS.0[0], 0);
745        let (r2, carry) = mac(r2, k, MODULUS.0[1], carry);
746        let (r3, carry) = mac(r3, k, MODULUS.0[2], carry);
747        let (r4, carry) = mac(r4, k, MODULUS.0[3], carry);
748        let (r5, carry2) = adc(r5, carry2, carry);
749
750        let k = r2.wrapping_mul(INV);
751        let (_, carry) = mac(r2, k, MODULUS.0[0], 0);
752        let (r3, carry) = mac(r3, k, MODULUS.0[1], carry);
753        let (r4, carry) = mac(r4, k, MODULUS.0[2], carry);
754        let (r5, carry) = mac(r5, k, MODULUS.0[3], carry);
755        let (r6, carry2) = adc(r6, carry2, carry);
756
757        let k = r3.wrapping_mul(INV);
758        let (_, carry) = mac(r3, k, MODULUS.0[0], 0);
759        let (r4, carry) = mac(r4, k, MODULUS.0[1], carry);
760        let (r5, carry) = mac(r5, k, MODULUS.0[2], carry);
761        let (r6, carry) = mac(r6, k, MODULUS.0[3], carry);
762        let (r7, _) = adc(r7, carry2, carry);
763
764        (&Scalar([r4, r5, r6, r7])).sub(&MODULUS)
765    }
766
767    /// Multiplies this scalar by another.
768    #[inline]
769    pub const fn mul(&self, rhs: &Self) -> Self {
770        let (r0, carry) = mac(0, self.0[0], rhs.0[0], 0);
771        let (r1, carry) = mac(0, self.0[0], rhs.0[1], carry);
772        let (r2, carry) = mac(0, self.0[0], rhs.0[2], carry);
773        let (r3, r4) = mac(0, self.0[0], rhs.0[3], carry);
774
775        let (r1, carry) = mac(r1, self.0[1], rhs.0[0], 0);
776        let (r2, carry) = mac(r2, self.0[1], rhs.0[1], carry);
777        let (r3, carry) = mac(r3, self.0[1], rhs.0[2], carry);
778        let (r4, r5) = mac(r4, self.0[1], rhs.0[3], carry);
779
780        let (r2, carry) = mac(r2, self.0[2], rhs.0[0], 0);
781        let (r3, carry) = mac(r3, self.0[2], rhs.0[1], carry);
782        let (r4, carry) = mac(r4, self.0[2], rhs.0[2], carry);
783        let (r5, r6) = mac(r5, self.0[2], rhs.0[3], carry);
784
785        let (r3, carry) = mac(r3, self.0[3], rhs.0[0], 0);
786        let (r4, carry) = mac(r4, self.0[3], rhs.0[1], carry);
787        let (r5, carry) = mac(r5, self.0[3], rhs.0[2], carry);
788        let (r6, r7) = mac(r6, self.0[3], rhs.0[3], carry);
789
790        Scalar::montgomery_reduce(r0, r1, r2, r3, r4, r5, r6, r7)
791    }
792
793    /// Subtracts another scalar from this one.
794    #[inline]
795    pub const fn sub(&self, rhs: &Self) -> Self {
796        let (d0, borrow) = sbb(self.0[0], rhs.0[0], 0);
797        let (d1, borrow) = sbb(self.0[1], rhs.0[1], borrow);
798        let (d2, borrow) = sbb(self.0[2], rhs.0[2], borrow);
799        let (d3, borrow) = sbb(self.0[3], rhs.0[3], borrow);
800
801        let (d0, carry) = adc(d0, MODULUS.0[0] & borrow, 0);
802        let (d1, carry) = adc(d1, MODULUS.0[1] & borrow, carry);
803        let (d2, carry) = adc(d2, MODULUS.0[2] & borrow, carry);
804        let (d3, _) = adc(d3, MODULUS.0[3] & borrow, carry);
805
806        Scalar([d0, d1, d2, d3])
807    }
808
809    /// Adds another scalar to this one.
810    #[inline]
811    pub const fn add(&self, rhs: &Self) -> Self {
812        let (d0, carry) = adc(self.0[0], rhs.0[0], 0);
813        let (d1, carry) = adc(self.0[1], rhs.0[1], carry);
814        let (d2, carry) = adc(self.0[2], rhs.0[2], carry);
815        let (d3, _) = adc(self.0[3], rhs.0[3], carry);
816
817        (&Scalar([d0, d1, d2, d3])).sub(&MODULUS)
818    }
819
820    /// Computes the additive negation of this scalar.
821    #[inline]
822    pub const fn neg(&self) -> Self {
823        let (d0, borrow) = sbb(MODULUS.0[0], self.0[0], 0);
824        let (d1, borrow) = sbb(MODULUS.0[1], self.0[1], borrow);
825        let (d2, borrow) = sbb(MODULUS.0[2], self.0[2], borrow);
826        let (d3, _) = sbb(MODULUS.0[3], self.0[3], borrow);
827
828        let mask = (((self.0[0] | self.0[1] | self.0[2] | self.0[3]) == 0) as u64).wrapping_sub(1);
829
830        Scalar([d0 & mask, d1 & mask, d2 & mask, d3 & mask])
831    }
832}
833
834impl From<Scalar> for [u8; 32] {
835    fn from(value: Scalar) -> [u8; 32] {
836        value.to_bytes()
837    }
838}
839
840impl<'a> From<&'a Scalar> for [u8; 32] {
841    fn from(value: &'a Scalar) -> [u8; 32] {
842        value.to_bytes()
843    }
844}
845
846impl<T> core::iter::Sum<T> for Scalar
847where
848    T: core::borrow::Borrow<Scalar>,
849{
850    fn sum<I>(iter: I) -> Self
851    where
852        I: Iterator<Item = T>,
853    {
854        iter.fold(Self::zero(), |acc, item| acc + item.borrow())
855    }
856}
857
858impl<T> core::iter::Product<T> for Scalar
859where
860    T: core::borrow::Borrow<Scalar>,
861{
862    fn product<I>(iter: I) -> Self
863    where
864        I: Iterator<Item = T>,
865    {
866        iter.fold(Self::one(), |acc, item| acc * item.borrow())
867    }
868}
869
// Tests

#[test]
fn test_inv() {
    // Rebuild INV = -(q^{-1}) mod 2^64 from the low limb q0 of q. After k rounds
    // of `inv = inv^2 · q0`, inv = q0^(2^k - 1). The unit group mod 2^64 has
    // exponent 2^62, so k = 63 gives q0^{-1}.
    let q0 = MODULUS.0[0];
    let inv = (0..63).fold(1u64, |acc, _| acc.wrapping_mul(acc).wrapping_mul(q0));
    assert_eq!(inv.wrapping_neg(), INV);
}
882
#[cfg(feature = "std")]
#[test]
fn test_debug() {
    const ZERO_HEX: &str = "0x0000000000000000000000000000000000000000000000000000000000000000";
    const ONE_HEX: &str = "0x0000000000000000000000000000000000000000000000000000000000000001";

    assert_eq!(format!("{:?}", Scalar::zero()), ZERO_HEX);
    assert_eq!(format!("{:?}", Scalar::one()), ONE_HEX);
    // R is 1 in Montgomery form. Debug prints the canonical value, so this is "1" too.
    assert_eq!(format!("{:?}", R), ONE_HEX);
}
901
#[test]
fn test_equality() {
    // Reflexivity on representative values.
    assert_eq!(Scalar::zero(), Scalar::zero());
    assert_eq!(Scalar::one(), Scalar::one());
    #[allow(clippy::eq_op)]
    {
        assert_eq!(R2, R2);
    }

    // Distinct values must compare unequal.
    assert_ne!(Scalar::zero(), Scalar::one());
    assert_ne!(Scalar::one(), R2);
}
914
#[test]
fn test_to_bytes() {
    let mut one_le = [0u8; 32];
    one_le[0] = 1;

    assert_eq!(Scalar::zero().to_bytes(), [0u8; 32]);
    assert_eq!(Scalar::one().to_bytes(), one_le);

    // R is the Montgomery form of 1, so it must also serialize as 1.
    assert_eq!(R.to_bytes(), one_le);

    // -1 = q - 1, little-endian.
    let minus_one_le: [u8; 32] = [
        0, 0, 0, 0, 255, 255, 255, 255, 254, 91, 254, 255, 2, 164, 189, 83, 5, 216, 161, 9, 8,
        216, 57, 51, 72, 125, 157, 41, 83, 167, 237, 115,
    ];
    assert_eq!((-&Scalar::one()).to_bytes(), minus_one_le);
}
951
#[test]
fn test_from_bytes() {
    // Round trip: every element survives to_bytes -> from_bytes, and so does
    // its negation.
    let mut a = R2;

    for _ in 0..100 {
        let bytes = a.to_bytes();
        let b = Scalar::from_bytes(&bytes).unwrap();
        assert_eq!(a, b);

        let bytes = (-a).to_bytes();
        let b = Scalar::from_bytes(&bytes).unwrap();
        assert_eq!(-a, b);

        a = a.square();
    }

    // Canonicity: q itself and 2^256 - 1 must be rejected.
    let mut modulus_le = [0u8; 32];
    for (chunk, limb) in modulus_le.chunks_exact_mut(8).zip(MODULUS.0.iter()) {
        chunk.copy_from_slice(&limb.to_le_bytes());
    }
    assert!(bool::from(Scalar::from_bytes(&modulus_le).is_none()));
    assert!(bool::from(Scalar::from_bytes(&[0xff; 32]).is_none()));

    // q - 1, the largest canonical value, must be accepted and decode to -1.
    let mut q_minus_one_le = modulus_le;
    q_minus_one_le[0] -= 1;
    assert_eq!(Scalar::from_bytes(&q_minus_one_le).unwrap(), -Scalar::one());
}
969
/// `q - 1`, the largest canonical value, as raw limbs (not converted to Montgomery form).
#[cfg(test)]
const LARGEST: Scalar = Scalar([
    MODULUS.0[0] - 1,
    MODULUS.0[1],
    MODULUS.0[2],
    MODULUS.0[3],
]);
977
#[test]
fn test_addition() {
    // (q - 1) + (q - 1) = 2q - 2 ≡ q - 2 (mod q).
    let q_minus_two = Scalar([
        MODULUS.0[0] - 2,
        MODULUS.0[1],
        MODULUS.0[2],
        MODULUS.0[3],
    ]);
    assert_eq!(LARGEST + LARGEST, q_minus_two);

    // (q - 1) + 1 wraps around to zero.
    assert_eq!(LARGEST + Scalar([1, 0, 0, 0]), Scalar::zero());
}
998
#[test]
fn test_inversion() {
    // Zero has no inverse; ±1 are their own inverses.
    assert!(bool::from(Scalar::zero().invert().is_none()));
    let one = Scalar::one();
    let minus_one = -&one;
    assert_eq!(one.invert().unwrap(), one);
    assert_eq!(minus_one.invert().unwrap(), minus_one);

    // x^{-1} · x = 1 for x = R2, 2·R2, ..., 100·R2.
    let mut x = R2;
    for _ in 0..100 {
        let product = x.invert().unwrap() * x;
        assert_eq!(product, one);
        x += R2;
    }
}
1016
#[test]
fn test_sqrt() {
    let zero = Scalar::zero();
    let one = Scalar::one();

    // Trivial roots: sqrt(0) = 0 and sqrt(1) = 1.
    assert_eq!(zero.sqrt().unwrap(), zero);
    assert_eq!(one.sqrt().unwrap(), one);

    // sqrt(4) must be +/-2 and must square back to 4.
    let four = Scalar::from(4u64);
    let two = Scalar::from(2u64);
    let root_of_four = four.sqrt().unwrap();
    assert!(root_of_four == two || root_of_four == -two);
    assert_eq!(root_of_four.square(), four);

    // An arbitrary fixed value s: sqrt(s^2) must be +/-s and must square back to s^2.
    let s = Scalar::from(123456789u64);
    let s_squared = s.square();
    let root_of_s_squared = s_squared.sqrt().unwrap();
    assert!(root_of_s_squared == s || root_of_s_squared == -s);
    assert_eq!(root_of_s_squared.square(), s_squared);

    // 5 is a quadratic non-residue: q = 3 (mod 5), so by quadratic reciprocity
    // (5/q) = (q/5) = (3/5) = -1.
    assert!(bool::from(Scalar::from(5u64).sqrt().is_none()));

    // q = 1 (mod 4), so -1 is a quadratic residue.
    let minus_one = -one;
    assert_eq!(minus_one.sqrt().unwrap().square(), minus_one);

    // Round trip sqrt(v^2) = +/-v for v = R2, R2 + R, R2 + 2R, ...
    let mut v = R2;
    for _ in 0..100 {
        let root = v.square().sqrt().unwrap();
        assert!(root == v || root == -v);
        v += R;
    }
}
1061
#[test]
fn test_from_raw() {
    // from_raw must reduce its input mod q: the all-ones 256-bit integer and the
    // limbs below, which encode (2^256 - 1) mod q, map to the same scalar.
    let reduced = Scalar::from_raw([
        0x0001_ffff_fffd,
        0x5884_b7fa_0003_4802,
        0x998c_4fef_ecbc_4ff5,
        0x1824_b159_acc5_056f,
    ]);
    let all_ones = Scalar::from_raw([0xffff_ffff_ffff_ffff; 4]);
    assert_eq!(reduced, all_ones);

    // The modulus itself reduces to zero.
    assert_eq!(Scalar::from_raw(MODULUS.0), Scalar::zero());

    // The raw integer 1 maps to the constant R.
    assert_eq!(Scalar::from_raw([1, 0, 0, 0]), R);
}
1078
#[test]
fn test_scalar_hash_to_field() {
    let data1 = b"some input data";
    let data2 = b"different input data";
    // Ciphersuite-style DSTs borrowed from the BLS signature spec; here they only
    // need to be distinct byte strings.
    let dst1 = b"BLS_SIG_BLS12381G2_XMD:SHA-256_SSWU_RO_NUL_";
    let dst2 = b"BLS_SIG_BLS12381G2_XMD:SHA-256_SSWU_RO_POP_";

    // 1. Different data should produce different scalars.
    let s1 = Scalar::hash_to_field(data1, dst1).unwrap();
    let s2 = Scalar::hash_to_field(data2, dst1).unwrap();
    assert_ne!(s1, s2);

    // 2. Same data with different DSTs should produce different scalars.
    let s3 = Scalar::hash_to_field(data1, dst1).unwrap();
    let s4 = Scalar::hash_to_field(data1, dst2).unwrap();
    assert_ne!(s3, s4);

    // 3. Hashing should be deterministic.
    let s5 = Scalar::hash_to_field(data1, dst1).unwrap();
    assert_eq!(s3, s5);

    // 4. Output must always be a canonical (reduced) scalar, so a
    //    to_bytes -> from_bytes round trip has to succeed and be lossless.
    for test_case in &[b"" as &[u8], b"a", b"test", &[0xFF; 100], &[0x00; 64]] {
        let scalar = Scalar::hash_to_field(test_case, dst1).unwrap();
        let bytes = scalar.to_bytes();
        let scalar2 = Scalar::from_bytes(&bytes).unwrap();
        assert_eq!(scalar, scalar2, "Output should be a valid reduced scalar");
    }

    // 5. No collisions in a small sample of distinct inputs.
    //    NOTE(review): bias is presumably negligible because hash_to_field expands
    //    to well over 255 bits before reducing; confirm against its implementation.
    let scalars: Vec<Scalar> = (0u32..100)
        .map(|i| Scalar::hash_to_field(&i.to_le_bytes(), dst1).unwrap())
        .collect();
    for (i, a) in scalars.iter().enumerate() {
        for (j, b) in scalars.iter().enumerate().skip(i + 1) {
            assert_ne!(a, b, "Unexpected collision at {} and {}", i, j);
        }
    }

    // 6. Empty data with an empty DST must still be accepted and deterministic.
    //    NOTE(review): RFC 9380 requires a non-empty DST; this asserts the current
    //    permissive behavior of the implementation.
    let s_empty = Scalar::hash_to_field(b"", b"").unwrap();
    let s_empty2 = Scalar::hash_to_field(b"", b"").unwrap();
    assert_eq!(
        s_empty, s_empty2,
        "Empty input should still be deterministic"
    );

    // 7. DSTs of different lengths must give different outputs. This checks
    //    domain separation by DST; it does not isolate the DST length byte.
    let dst_short = b"A";
    let dst_long = [b'A'; 50];
    let s_short = Scalar::hash_to_field(data1, dst_short).unwrap();
    let s_long = Scalar::hash_to_field(data1, &dst_long).unwrap();
    assert_ne!(s_short, s_long, "DST length should affect output");

    // 8. Sanity check on spread: the least significant bit should take both values
    //    over a handful of inputs (true uniformity is not tested here).
    let mut has_odd = false;
    let mut has_even = false;
    for i in 0u8..20 {
        let s = Scalar::hash_to_field(&[i], dst1).unwrap();
        if s.to_bytes()[0] & 1 == 0 {
            has_even = true;
        } else {
            has_odd = true;
        }
    }
    assert!(
        has_odd && has_even,
        "Hash output should have both odd and even values"
    );

    // 9. Known-answer tests for expand_message_xmd with SHA-256, taken from
    //    RFC 9380, Appendix K.1 (DST = "QUUX-V01-CS02-with-expander-SHA256-128",
    //    len_in_bytes = 0x20).
    let rfc_dst = b"QUUX-V01-CS02-with-expander-SHA256-128";

    let expanded = Scalar::expand_message_xmd(b"", rfc_dst, 32).unwrap();
    assert_eq!(expanded.len(), 32);
    let expected_empty: [u8; 32] = [
        0x68, 0xa9, 0x85, 0xb8, 0x7e, 0xb6, 0xb4, 0x69, 0x52, 0x12, 0x89, 0x11, 0xf2, 0xa4,
        0x41, 0x2b, 0xbc, 0x30, 0x2a, 0x9d, 0x75, 0x96, 0x67, 0xf8, 0x7f, 0x7a, 0x21, 0xd8,
        0x03, 0xf0, 0x72, 0x35,
    ];
    assert_eq!(&expanded[..], &expected_empty[..], "RFC 9380 K.1 vector (msg = \"\")");

    let expanded_abc = Scalar::expand_message_xmd(b"abc", rfc_dst, 32).unwrap();
    let expected_abc: [u8; 32] = [
        0xd8, 0xcc, 0xab, 0x23, 0xb5, 0x98, 0x5c, 0xce, 0xa8, 0x65, 0xc6, 0xc9, 0x7b, 0x6e,
        0x5b, 0x83, 0x50, 0xe7, 0x94, 0xe6, 0x03, 0xb4, 0xb9, 0x79, 0x02, 0xf5, 0x3a, 0x8a,
        0x0d, 0x60, 0x56, 0x15,
    ];
    assert_eq!(&expanded_abc[..], &expected_abc[..], "RFC 9380 K.1 vector (msg = \"abc\")");

    // Different messages must produce different expansions.
    let expanded1 = Scalar::expand_message_xmd(b"msg1", b"dst", 64).unwrap();
    let expanded2 = Scalar::expand_message_xmd(b"msg2", b"dst", 64).unwrap();
    assert_ne!(expanded1, expanded2);
}
1171
#[cfg(feature = "zeroize")]
#[test]
fn test_zeroize() {
    use zeroize::Zeroize;

    let mut secret = Scalar::from_raw([
        0x1fff_3231_233f_fffd,
        0x4884_b7fa_0003_4802,
        0x998c_4fef_ecbc_4ff3,
        0x1824_b159_acc5_0562,
    ]);
    secret.zeroize();

    // This module also imports the crate's own trait as `DcryptConstantTimeEq`,
    // so the call names `subtle::ConstantTimeEq` explicitly to select its `ct_eq`.
    let is_zero = <Scalar as subtle::ConstantTimeEq>::ct_eq(&secret, &Scalar::zero());
    assert!(bool::from(is_zero));
}