// dcrypt_algorithms/ec/bls12_381/scalar.rs

//! Arithmetic in the BLS12-381 scalar field F_q, where
//! q = 0x73eda753299d7d483339d80809a1d80553bda402fffe5bfeffffffff00000001.
2
3use crate::error::{Error, Result};
4use crate::hash::{sha2::Sha256, HashFunction};
5use crate::types::{
6    ByteSerializable, ConstantTimeEq as DcryptConstantTimeEq, SecureZeroingType,
7};
8use core::fmt;
9use core::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};
10use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption};
11
// Multi-precision helpers operating on 64-bit limbs via 128-bit intermediates.

/// Add with carry: returns the low 64 bits of `a + b + carry` and, as the
/// second element, the bits above position 64 (the carry-out).
#[inline(always)]
const fn adc(a: u64, b: u64, carry: u64) -> (u64, u64) {
    let wide = a as u128 + b as u128 + carry as u128;
    let low = wide as u64;
    let high = (wide >> 64) as u64;
    (low, high)
}
19
/// Subtract with borrow: computes `a - b - borrow_in` where `borrow_in` is the
/// most significant bit of `borrow`. Returns the low 64 bits of the difference
/// and the high 64 bits of the wrapped 128-bit result, which is `0` when no
/// borrow occurred and `u64::MAX` when one did.
#[inline(always)]
const fn sbb(a: u64, b: u64, borrow: u64) -> (u64, u64) {
    let borrow_in = (borrow >> 63) as u128;
    let diff = (a as u128).wrapping_sub(b as u128 + borrow_in);
    (diff as u64, (diff >> 64) as u64)
}
26
/// Multiply-accumulate: returns the low 64 bits of `a + b * c + carry` and
/// the high 64 bits as the outgoing carry. The sum never exceeds `2^128 - 1`.
#[inline(always)]
const fn mac(a: u64, b: u64, c: u64, carry: u64) -> (u64, u64) {
    let product = (b as u128) * (c as u128);
    let sum = product + a as u128 + carry as u128;
    (sum as u64, (sum >> 64) as u64)
}
33
34/// Scalar field element of BLS12-381
35/// Internal: Four 64-bit limbs in little-endian Montgomery form
36#[derive(Clone, Copy, Eq)]
37pub struct Scalar(pub(crate) [u64; 4]);
38
39impl fmt::Debug for Scalar {
40    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
41        let tmp = self.to_bytes();
42        write!(f, "0x")?;
43        for &b in tmp.iter().rev() {
44            write!(f, "{:02x}", b)?;
45        }
46        Ok(())
47    }
48}
49
50impl fmt::Display for Scalar {
51    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
52        write!(f, "{:?}", self)
53    }
54}
55
56impl From<u64> for Scalar {
57    fn from(val: u64) -> Scalar {
58        Scalar([val, 0, 0, 0]) * R2
59    }
60}
61
62impl ConstantTimeEq for Scalar {
63    fn ct_eq(&self, other: &Self) -> Choice {
64        self.0[0].ct_eq(&other.0[0])
65            & self.0[1].ct_eq(&other.0[1])
66            & self.0[2].ct_eq(&other.0[2])
67            & self.0[3].ct_eq(&other.0[3])
68    }
69}
70
71impl PartialEq for Scalar {
72    #[inline]
73    fn eq(&self, other: &Self) -> bool {
74        bool::from(subtle::ConstantTimeEq::ct_eq(self, other))
75    }
76}
77
78impl ConditionallySelectable for Scalar {
79    fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
80        Scalar([
81            u64::conditional_select(&a.0[0], &b.0[0], choice),
82            u64::conditional_select(&a.0[1], &b.0[1], choice),
83            u64::conditional_select(&a.0[2], &b.0[2], choice),
84            u64::conditional_select(&a.0[3], &b.0[3], choice),
85        ])
86    }
87}
88
89// Constants
90const MODULUS: Scalar = Scalar([
91    0xffff_ffff_0000_0001,
92    0x53bd_a402_fffe_5bfe,
93    0x3339_d808_09a1_d805,
94    0x73ed_a753_299d_7d48,
95]);
96
97/// INV = -(q^{-1} mod 2^64) mod 2^64
98const INV: u64 = 0xffff_fffe_ffff_ffff;
99
100/// R = 2^256 mod q
101const R: Scalar = Scalar([
102    0x0000_0001_ffff_fffe,
103    0x5884_b7fa_0003_4802,
104    0x998c_4fef_ecbc_4ff5,
105    0x1824_b159_acc5_056f,
106]);
107
108/// R^2 = 2^512 mod q
109const R2: Scalar = Scalar([
110    0xc999_e990_f3f2_9c6d,
111    0x2b6c_edcb_8792_5c23,
112    0x05d3_1496_7254_398f,
113    0x0748_d9d9_9f59_ff11,
114]);
115
116/// R^3 = 2^768 mod q
117const R3: Scalar = Scalar([
118    0xc62c_1807_439b_73af,
119    0x1b3e_0d18_8cf0_6990,
120    0x73d1_3c71_c7b5_f418,
121    0x6e2a_5bb9_c8db_33e9,
122]);
123
124impl<'a> Neg for &'a Scalar {
125    type Output = Scalar;
126
127    #[inline]
128    fn neg(self) -> Scalar {
129        self.neg()
130    }
131}
132
133impl Neg for Scalar {
134    type Output = Scalar;
135
136    #[inline]
137    fn neg(self) -> Scalar {
138        -&self
139    }
140}
141
142impl<'a, 'b> Sub<&'b Scalar> for &'a Scalar {
143    type Output = Scalar;
144
145    #[inline]
146    fn sub(self, rhs: &'b Scalar) -> Scalar {
147        self.sub(rhs)
148    }
149}
150
151impl<'a, 'b> Add<&'b Scalar> for &'a Scalar {
152    type Output = Scalar;
153
154    #[inline]
155    fn add(self, rhs: &'b Scalar) -> Scalar {
156        self.add(rhs)
157    }
158}
159
160impl<'a, 'b> Mul<&'b Scalar> for &'a Scalar {
161    type Output = Scalar;
162
163    #[inline]
164    fn mul(self, rhs: &'b Scalar) -> Scalar {
165        self.mul(rhs)
166    }
167}
168
169// Binop implementations
170impl<'b> Add<&'b Scalar> for Scalar {
171    type Output = Scalar;
172    #[inline]
173    fn add(self, rhs: &'b Scalar) -> Scalar {
174        &self + rhs
175    }
176}
177
178impl<'a> Add<Scalar> for &'a Scalar {
179    type Output = Scalar;
180    #[inline]
181    fn add(self, rhs: Scalar) -> Scalar {
182        self + &rhs
183    }
184}
185
186impl Add<Scalar> for Scalar {
187    type Output = Scalar;
188    #[inline]
189    fn add(self, rhs: Scalar) -> Scalar {
190        &self + &rhs
191    }
192}
193
194impl<'b> Sub<&'b Scalar> for Scalar {
195    type Output = Scalar;
196    #[inline]
197    fn sub(self, rhs: &'b Scalar) -> Scalar {
198        &self - rhs
199    }
200}
201
202impl<'a> Sub<Scalar> for &'a Scalar {
203    type Output = Scalar;
204    #[inline]
205    fn sub(self, rhs: Scalar) -> Scalar {
206        self - &rhs
207    }
208}
209
210impl Sub<Scalar> for Scalar {
211    type Output = Scalar;
212    #[inline]
213    fn sub(self, rhs: Scalar) -> Scalar {
214        &self - &rhs
215    }
216}
217
218impl SubAssign<Scalar> for Scalar {
219    #[inline]
220    fn sub_assign(&mut self, rhs: Scalar) {
221        *self = &*self - &rhs;
222    }
223}
224
225impl AddAssign<Scalar> for Scalar {
226    #[inline]
227    fn add_assign(&mut self, rhs: Scalar) {
228        *self = &*self + &rhs;
229    }
230}
231
232impl<'b> SubAssign<&'b Scalar> for Scalar {
233    #[inline]
234    fn sub_assign(&mut self, rhs: &'b Scalar) {
235        *self = &*self - rhs;
236    }
237}
238
239impl<'b> AddAssign<&'b Scalar> for Scalar {
240    #[inline]
241    fn add_assign(&mut self, rhs: &'b Scalar) {
242        *self = &*self + rhs;
243    }
244}
245
246impl<'b> Mul<&'b Scalar> for Scalar {
247    type Output = Scalar;
248    #[inline]
249    fn mul(self, rhs: &'b Scalar) -> Scalar {
250        &self * rhs
251    }
252}
253
254impl<'a> Mul<Scalar> for &'a Scalar {
255    type Output = Scalar;
256    #[inline]
257    fn mul(self, rhs: Scalar) -> Scalar {
258        self * &rhs
259    }
260}
261
262impl Mul<Scalar> for Scalar {
263    type Output = Scalar;
264    #[inline]
265    fn mul(self, rhs: Scalar) -> Scalar {
266        &self * &rhs
267    }
268}
269
270impl MulAssign<Scalar> for Scalar {
271    #[inline]
272    fn mul_assign(&mut self, rhs: Scalar) {
273        *self = &*self * &rhs;
274    }
275}
276
277impl<'b> MulAssign<&'b Scalar> for Scalar {
278    #[inline]
279    fn mul_assign(&mut self, rhs: &'b Scalar) {
280        *self = &*self * rhs;
281    }
282}
283
284impl Default for Scalar {
285    #[inline]
286    fn default() -> Self {
287        Self::zero()
288    }
289}
290
291#[cfg(feature = "zeroize")]
292impl zeroize::DefaultIsZeroes for Scalar {}
293
294impl ByteSerializable for Scalar {
295    fn to_bytes(&self) -> Vec<u8> {
296        self.to_bytes().to_vec()
297    }
298
299    fn from_bytes(bytes: &[u8]) -> Result<Self> {
300        if bytes.len() != 32 {
301            return Err(Error::Length {
302                context: "Scalar::from_bytes",
303                expected: 32,
304                actual: bytes.len(),
305            });
306        }
307
308        let mut array = [0u8; 32];
309        array.copy_from_slice(bytes);
310
311        Scalar::from_bytes(&array)
312            .into_option()  // Use into_option() instead of into()
313            .ok_or_else(|| Error::param("scalar_bytes", "non-canonical scalar"))
314    }
315}
316
317impl DcryptConstantTimeEq for Scalar {
318    fn ct_eq(&self, other: &Self) -> bool {
319        bool::from(subtle::ConstantTimeEq::ct_eq(self, other))
320    }
321}
322
323impl SecureZeroingType for Scalar {
324    fn zeroed() -> Self {
325        Self::zero()
326    }
327}
328
329impl Scalar {
330    /// Additive identity
331    #[inline]
332    pub const fn zero() -> Scalar {
333        Scalar([0, 0, 0, 0])
334    }
335
336    /// Multiplicative identity
337    #[inline]
338    pub const fn one() -> Scalar {
339        R
340    }
341
342    /// Double this element
343    #[inline]
344    pub const fn double(&self) -> Scalar {
345        self.add(self)
346    }
347
348    /// Create from little-endian bytes if canonical
349    pub fn from_bytes(bytes: &[u8; 32]) -> CtOption<Scalar> {
350        let mut tmp = Scalar([0, 0, 0, 0]);
351
352        tmp.0[0] = u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[0..8]).unwrap());
353        tmp.0[1] = u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[8..16]).unwrap());
354        tmp.0[2] = u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[16..24]).unwrap());
355        tmp.0[3] = u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[24..32]).unwrap());
356
357        // Check canonical by subtracting modulus
358        let (_, borrow) = sbb(tmp.0[0], MODULUS.0[0], 0);
359        let (_, borrow) = sbb(tmp.0[1], MODULUS.0[1], borrow);
360        let (_, borrow) = sbb(tmp.0[2], MODULUS.0[2], borrow);
361        let (_, borrow) = sbb(tmp.0[3], MODULUS.0[3], borrow);
362
363        let is_some = (borrow as u8) & 1;
364
365        // Convert to Montgomery: (a * R^2) / R = aR
366        tmp *= &R2;
367
368        CtOption::new(tmp, Choice::from(is_some))
369    }
370
371    /// Convert to little-endian bytes
372    pub fn to_bytes(&self) -> [u8; 32] {
373        // Remove Montgomery: (aR) / R = a
374        let tmp = Scalar::montgomery_reduce(self.0[0], self.0[1], self.0[2], self.0[3], 0, 0, 0, 0);
375
376        let mut res = [0; 32];
377        res[0..8].copy_from_slice(&tmp.0[0].to_le_bytes());
378        res[8..16].copy_from_slice(&tmp.0[1].to_le_bytes());
379        res[16..24].copy_from_slice(&tmp.0[2].to_le_bytes());
380        res[24..32].copy_from_slice(&tmp.0[3].to_le_bytes());
381
382        res
383    }
384
385    /// Create from 512-bit little-endian integer mod q
386    pub fn from_bytes_wide(bytes: &[u8; 64]) -> Scalar {
387        Scalar::from_u512([
388            u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[0..8]).unwrap()),
389            u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[8..16]).unwrap()),
390            u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[16..24]).unwrap()),
391            u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[24..32]).unwrap()),
392            u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[32..40]).unwrap()),
393            u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[40..48]).unwrap()),
394            u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[48..56]).unwrap()),
395            u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[56..64]).unwrap()),
396        ])
397    }
398
399    // ============================================================================
400    // START: Standards-Compliant Hash-to-Field Implementation with SHA-256
401    // ============================================================================
402
403    /// Expands a message using SHA-256 as per IETF hash-to-curve expand_message_xmd.
404    /// 
405    /// This follows Section 5.3.1 of draft-irtf-cfrg-hash-to-curve.
406    fn expand_message_xmd(msg: &[u8], dst: &[u8], len_in_bytes: usize) -> Result<Vec<u8>> {
407        const MAX_DST_LENGTH: usize = 255;
408        const HASH_OUTPUT_SIZE: usize = 32; // SHA-256 output size
409        
410        // Check DST length
411        if dst.len() > MAX_DST_LENGTH {
412            return Err(Error::param("dst", "domain separation tag too long"));
413        }
414        
415        // ell = ceil(len_in_bytes / b_in_bytes)
416        let ell = (len_in_bytes + HASH_OUTPUT_SIZE - 1) / HASH_OUTPUT_SIZE;
417        
418        // Check ell is within bounds (max 255 for single byte counter)
419        if ell > 255 {
420            return Err(Error::param("len_in_bytes", "requested output too long"));
421        }
422        
423        // DST_prime = DST || I2OSP(len(DST), 1)
424        // I2OSP(len(DST), 1) is just the length as a single byte
425        let dst_prime_len = dst.len() as u8;
426        
427        // msg_prime = Z_pad || msg || l_i_b_str || I2OSP(0, 1) || DST_prime
428        // Z_pad = I2OSP(0, s_in_bytes) where s_in_bytes = HASH_OUTPUT_SIZE
429        // l_i_b_str = I2OSP(len_in_bytes, 2)
430        let mut hasher = Sha256::new();
431        
432        // Z_pad: 32 zero bytes for SHA-256
433        hasher.update(&[0u8; HASH_OUTPUT_SIZE])?;
434        
435        // msg
436        hasher.update(msg)?;
437        
438        // l_i_b_str: len_in_bytes as 2 bytes big-endian
439        hasher.update(&((len_in_bytes as u16).to_be_bytes()))?;
440        
441        // I2OSP(0, 1)
442        hasher.update(&[0u8])?;
443        
444        // DST_prime
445        hasher.update(dst)?;
446        hasher.update(&[dst_prime_len])?;
447        
448        // b_0 = H(msg_prime)
449        let b_0 = hasher.finalize()?;
450        
451        let mut uniform_bytes = Vec::with_capacity(len_in_bytes);
452        let mut b_i = vec![0u8; HASH_OUTPUT_SIZE];
453        
454        for i in 1..=ell {
455            // b_i = H(strxor(b_0, b_{i-1}) || I2OSP(i, 1) || DST_prime)
456            let mut hasher = Sha256::new();
457            
458            // strxor(b_0, b_{i-1})
459            if i == 1 {
460                // b_0 = b_{i-1} for first iteration, so strxor is all zeros
461                hasher.update(&[0u8; HASH_OUTPUT_SIZE])?;
462            } else {
463                // XOR b_0 with previous b_i
464                let mut xored = [0u8; HASH_OUTPUT_SIZE];
465                for j in 0..HASH_OUTPUT_SIZE {
466                    xored[j] = b_0.as_ref()[j] ^ b_i[j];
467                }
468                hasher.update(&xored)?;
469            }
470            
471            // I2OSP(i, 1)
472            hasher.update(&[i as u8])?;
473            
474            // DST_prime
475            hasher.update(dst)?;
476            hasher.update(&[dst_prime_len])?;
477            
478            let digest = hasher.finalize()?;
479            b_i.copy_from_slice(digest.as_ref());
480            
481            // Append to uniform_bytes
482            uniform_bytes.extend_from_slice(&b_i);
483        }
484        
485        // Return first len_in_bytes bytes
486        uniform_bytes.truncate(len_in_bytes);
487        Ok(uniform_bytes)
488    }
489
490    /// Hashes arbitrary data to a scalar field element using SHA-256.
491    ///
492    /// This function implements a standards-compliant hash-to-field method following
493    /// the IETF hash-to-curve specification using expand_message_xmd with SHA-256.
494    ///
495    /// # Arguments
496    /// * `data`: The input data to hash.
497    /// * `dst`: A Domain Separation Tag (DST) to ensure hashes are unique per application context.
498    ///
499    /// # Returns
500    /// A `Result` containing the `Scalar` or an error.
501    pub fn hash_to_field(
502        data: &[u8],
503        dst: &[u8],
504    ) -> Result<Self> {
505        // Expand message to 64 bytes (512 bits) for bias reduction
506        // This provides ~128 bits of security when reducing modulo the ~255-bit scalar field
507        let expanded = Self::expand_message_xmd(data, dst, 64)?;
508        
509        // Convert to array for from_bytes_wide
510        let mut expanded_array = [0u8; 64];
511        expanded_array.copy_from_slice(&expanded);
512        
513        // Reduce modulo q
514        Ok(Self::from_bytes_wide(&expanded_array))
515    }
516
517    // ============================================================================
518    // END: Standards-Compliant Hash-to-Field Implementation
519    // ============================================================================
520
521    fn from_u512(limbs: [u64; 8]) -> Scalar {
522        let d0 = Scalar([limbs[0], limbs[1], limbs[2], limbs[3]]);
523        let d1 = Scalar([limbs[4], limbs[5], limbs[6], limbs[7]]);
524        d0 * R2 + d1 * R3
525    }
526
527    /// Create from raw values and convert to Montgomery
528    pub const fn from_raw(val: [u64; 4]) -> Self {
529        (&Scalar(val)).mul(&R2)
530    }
531
532    /// Square this element
533    #[inline]
534    pub const fn square(&self) -> Scalar {
535        let (r1, carry) = mac(0, self.0[0], self.0[1], 0);
536        let (r2, carry) = mac(0, self.0[0], self.0[2], carry);
537        let (r3, r4) = mac(0, self.0[0], self.0[3], carry);
538
539        let (r3, carry) = mac(r3, self.0[1], self.0[2], 0);
540        let (r4, r5) = mac(r4, self.0[1], self.0[3], carry);
541
542        let (r5, r6) = mac(r5, self.0[2], self.0[3], 0);
543
544        let r7 = r6 >> 63;
545        let r6 = (r6 << 1) | (r5 >> 63);
546        let r5 = (r5 << 1) | (r4 >> 63);
547        let r4 = (r4 << 1) | (r3 >> 63);
548        let r3 = (r3 << 1) | (r2 >> 63);
549        let r2 = (r2 << 1) | (r1 >> 63);
550        let r1 = r1 << 1;
551
552        let (r0, carry) = mac(0, self.0[0], self.0[0], 0);
553        let (r1, carry) = adc(0, r1, carry);
554        let (r2, carry) = mac(r2, self.0[1], self.0[1], carry);
555        let (r3, carry) = adc(0, r3, carry);
556        let (r4, carry) = mac(r4, self.0[2], self.0[2], carry);
557        let (r5, carry) = adc(0, r5, carry);
558        let (r6, carry) = mac(r6, self.0[3], self.0[3], carry);
559        let (r7, _) = adc(0, r7, carry);
560
561        Scalar::montgomery_reduce(r0, r1, r2, r3, r4, r5, r6, r7)
562    }
563
564    /// Multiplicative inverse
565    pub fn invert(&self) -> CtOption<Self> {
566        #[inline(always)]
567        fn square_assign_multi(n: &mut Scalar, num_times: usize) {
568            for _ in 0..num_times {
569                *n = n.square();
570            }
571        }
572        // Addition chain from github.com/kwantam/addchain
573        let mut t0 = self.square();
574        let mut t1 = t0 * self;
575        let mut t16 = t0.square();
576        let mut t6 = t16.square();
577        let mut t5 = t6 * t0;
578        t0 = t6 * t16;
579        let mut t12 = t5 * t16;
580        let mut t2 = t6.square();
581        let mut t7 = t5 * t6;
582        let mut t15 = t0 * t5;
583        let mut t17 = t12.square();
584        t1 *= t17;
585        let mut t3 = t7 * t2;
586        let t8 = t1 * t17;
587        let t4 = t8 * t2;
588        let t9 = t8 * t7;
589        t7 = t4 * t5;
590        let t11 = t4 * t17;
591        t5 = t9 * t17;
592        let t14 = t7 * t15;
593        let t13 = t11 * t12;
594        t12 = t11 * t17;
595        t15 *= &t12;
596        t16 *= &t15;
597        t3 *= &t16;
598        t17 *= &t3;
599        t0 *= &t17;
600        t6 *= &t0;
601        t2 *= &t6;
602        square_assign_multi(&mut t0, 8);
603        t0 *= &t17;
604        square_assign_multi(&mut t0, 9);
605        t0 *= &t16;
606        square_assign_multi(&mut t0, 9);
607        t0 *= &t15;
608        square_assign_multi(&mut t0, 9);
609        t0 *= &t15;
610        square_assign_multi(&mut t0, 7);
611        t0 *= &t14;
612        square_assign_multi(&mut t0, 7);
613        t0 *= &t13;
614        square_assign_multi(&mut t0, 10);
615        t0 *= &t12;
616        square_assign_multi(&mut t0, 9);
617        t0 *= &t11;
618        square_assign_multi(&mut t0, 8);
619        t0 *= &t8;
620        square_assign_multi(&mut t0, 8);
621        t0 *= self;
622        square_assign_multi(&mut t0, 14);
623        t0 *= &t9;
624        square_assign_multi(&mut t0, 10);
625        t0 *= &t8;
626        square_assign_multi(&mut t0, 15);
627        t0 *= &t7;
628        square_assign_multi(&mut t0, 10);
629        t0 *= &t6;
630        square_assign_multi(&mut t0, 8);
631        t0 *= &t5;
632        square_assign_multi(&mut t0, 16);
633        t0 *= &t3;
634        square_assign_multi(&mut t0, 8);
635        t0 *= &t2;
636        square_assign_multi(&mut t0, 7);
637        t0 *= &t4;
638        square_assign_multi(&mut t0, 9);
639        t0 *= &t2;
640        square_assign_multi(&mut t0, 8);
641        t0 *= &t3;
642        square_assign_multi(&mut t0, 8);
643        t0 *= &t2;
644        square_assign_multi(&mut t0, 8);
645        t0 *= &t2;
646        square_assign_multi(&mut t0, 8);
647        t0 *= &t2;
648        square_assign_multi(&mut t0, 8);
649        t0 *= &t3;
650        square_assign_multi(&mut t0, 8);
651        t0 *= &t2;
652        square_assign_multi(&mut t0, 8);
653        t0 *= &t2;
654        square_assign_multi(&mut t0, 5);
655        t0 *= &t1;
656        square_assign_multi(&mut t0, 5);
657        t0 *= &t1;
658
659        CtOption::new(t0, !subtle::ConstantTimeEq::ct_eq(self, &Self::zero()))
660    }
661
662    #[inline(always)]
663    const fn montgomery_reduce(
664        r0: u64,
665        r1: u64,
666        r2: u64,
667        r3: u64,
668        r4: u64,
669        r5: u64,
670        r6: u64,
671        r7: u64,
672    ) -> Self {
673        let k = r0.wrapping_mul(INV);
674        let (_, carry) = mac(r0, k, MODULUS.0[0], 0);
675        let (r1, carry) = mac(r1, k, MODULUS.0[1], carry);
676        let (r2, carry) = mac(r2, k, MODULUS.0[2], carry);
677        let (r3, carry) = mac(r3, k, MODULUS.0[3], carry);
678        let (r4, carry2) = adc(r4, 0, carry);
679
680        let k = r1.wrapping_mul(INV);
681        let (_, carry) = mac(r1, k, MODULUS.0[0], 0);
682        let (r2, carry) = mac(r2, k, MODULUS.0[1], carry);
683        let (r3, carry) = mac(r3, k, MODULUS.0[2], carry);
684        let (r4, carry) = mac(r4, k, MODULUS.0[3], carry);
685        let (r5, carry2) = adc(r5, carry2, carry);
686
687        let k = r2.wrapping_mul(INV);
688        let (_, carry) = mac(r2, k, MODULUS.0[0], 0);
689        let (r3, carry) = mac(r3, k, MODULUS.0[1], carry);
690        let (r4, carry) = mac(r4, k, MODULUS.0[2], carry);
691        let (r5, carry) = mac(r5, k, MODULUS.0[3], carry);
692        let (r6, carry2) = adc(r6, carry2, carry);
693
694        let k = r3.wrapping_mul(INV);
695        let (_, carry) = mac(r3, k, MODULUS.0[0], 0);
696        let (r4, carry) = mac(r4, k, MODULUS.0[1], carry);
697        let (r5, carry) = mac(r5, k, MODULUS.0[2], carry);
698        let (r6, carry) = mac(r6, k, MODULUS.0[3], carry);
699        let (r7, _) = adc(r7, carry2, carry);
700
701        (&Scalar([r4, r5, r6, r7])).sub(&MODULUS)
702    }
703
704    /// Multiply two scalars
705    #[inline]
706    pub const fn mul(&self, rhs: &Self) -> Self {
707        let (r0, carry) = mac(0, self.0[0], rhs.0[0], 0);
708        let (r1, carry) = mac(0, self.0[0], rhs.0[1], carry);
709        let (r2, carry) = mac(0, self.0[0], rhs.0[2], carry);
710        let (r3, r4) = mac(0, self.0[0], rhs.0[3], carry);
711
712        let (r1, carry) = mac(r1, self.0[1], rhs.0[0], 0);
713        let (r2, carry) = mac(r2, self.0[1], rhs.0[1], carry);
714        let (r3, carry) = mac(r3, self.0[1], rhs.0[2], carry);
715        let (r4, r5) = mac(r4, self.0[1], rhs.0[3], carry);
716
717        let (r2, carry) = mac(r2, self.0[2], rhs.0[0], 0);
718        let (r3, carry) = mac(r3, self.0[2], rhs.0[1], carry);
719        let (r4, carry) = mac(r4, self.0[2], rhs.0[2], carry);
720        let (r5, r6) = mac(r5, self.0[2], rhs.0[3], carry);
721
722        let (r3, carry) = mac(r3, self.0[3], rhs.0[0], 0);
723        let (r4, carry) = mac(r4, self.0[3], rhs.0[1], carry);
724        let (r5, carry) = mac(r5, self.0[3], rhs.0[2], carry);
725        let (r6, r7) = mac(r6, self.0[3], rhs.0[3], carry);
726
727        Scalar::montgomery_reduce(r0, r1, r2, r3, r4, r5, r6, r7)
728    }
729
730    /// Subtract rhs from self
731    #[inline]
732    pub const fn sub(&self, rhs: &Self) -> Self {
733        let (d0, borrow) = sbb(self.0[0], rhs.0[0], 0);
734        let (d1, borrow) = sbb(self.0[1], rhs.0[1], borrow);
735        let (d2, borrow) = sbb(self.0[2], rhs.0[2], borrow);
736        let (d3, borrow) = sbb(self.0[3], rhs.0[3], borrow);
737
738        let (d0, carry) = adc(d0, MODULUS.0[0] & borrow, 0);
739        let (d1, carry) = adc(d1, MODULUS.0[1] & borrow, carry);
740        let (d2, carry) = adc(d2, MODULUS.0[2] & borrow, carry);
741        let (d3, _) = adc(d3, MODULUS.0[3] & borrow, carry);
742
743        Scalar([d0, d1, d2, d3])
744    }
745
746    /// Add rhs to self
747    #[inline]
748    pub const fn add(&self, rhs: &Self) -> Self {
749        let (d0, carry) = adc(self.0[0], rhs.0[0], 0);
750        let (d1, carry) = adc(self.0[1], rhs.0[1], carry);
751        let (d2, carry) = adc(self.0[2], rhs.0[2], carry);
752        let (d3, _) = adc(self.0[3], rhs.0[3], carry);
753
754        (&Scalar([d0, d1, d2, d3])).sub(&MODULUS)
755    }
756
757    /// Negate self
758    #[inline]
759    pub const fn neg(&self) -> Self {
760        let (d0, borrow) = sbb(MODULUS.0[0], self.0[0], 0);
761        let (d1, borrow) = sbb(MODULUS.0[1], self.0[1], borrow);
762        let (d2, borrow) = sbb(MODULUS.0[2], self.0[2], borrow);
763        let (d3, _) = sbb(MODULUS.0[3], self.0[3], borrow);
764
765        let mask = (((self.0[0] | self.0[1] | self.0[2] | self.0[3]) == 0) as u64).wrapping_sub(1);
766
767        Scalar([d0 & mask, d1 & mask, d2 & mask, d3 & mask])
768    }
769}
770
771impl From<Scalar> for [u8; 32] {
772    fn from(value: Scalar) -> [u8; 32] {
773        value.to_bytes()
774    }
775}
776
777impl<'a> From<&'a Scalar> for [u8; 32] {
778    fn from(value: &'a Scalar) -> [u8; 32] {
779        value.to_bytes()
780    }
781}
782
783impl<T> core::iter::Sum<T> for Scalar
784where
785    T: core::borrow::Borrow<Scalar>,
786{
787    fn sum<I>(iter: I) -> Self
788    where
789        I: Iterator<Item = T>,
790    {
791        iter.fold(Self::zero(), |acc, item| acc + item.borrow())
792    }
793}
794
795impl<T> core::iter::Product<T> for Scalar
796where
797    T: core::borrow::Borrow<Scalar>,
798{
799    fn product<I>(iter: I) -> Self
800    where
801        I: Iterator<Item = T>,
802    {
803        iter.fold(Self::one(), |acc, item| acc * item.borrow())
804    }
805}
806
// Tests

#[test]
fn test_inv() {
    // Raising q0 to 2^63 - 1 (square-and-multiply 63 times) gives q0^{-1} mod 2^64,
    // because the unit group mod 2^64 has exponent 2^62. INV is its negation.
    let q0 = MODULUS.0[0];
    let inverse = (0..63).fold(1u64, |acc, _| acc.wrapping_mul(acc).wrapping_mul(q0));
    assert_eq!(inverse.wrapping_neg(), INV);
}
819
#[cfg(feature = "std")]
#[test]
fn test_debug() {
    let cases = [
        (
            Scalar::zero(),
            "0x0000000000000000000000000000000000000000000000000000000000000000",
        ),
        (
            Scalar::one(),
            "0x0000000000000000000000000000000000000000000000000000000000000001",
        ),
        (
            R2,
            "0x1824b159acc5056f998c4fefecbc4ff55884b7fa0003480200000001fffffffe",
        ),
    ];

    for (value, expected) in cases.iter() {
        assert_eq!(format!("{:?}", value), *expected);
    }
}
836
#[test]
fn test_equality() {
    let zero = Scalar::zero();
    let one = Scalar::one();

    assert_eq!(zero, Scalar::zero());
    assert_eq!(one, Scalar::one());
    #[allow(clippy::eq_op)]
    {
        assert_eq!(R2, R2);
    }

    assert_ne!(zero, one);
    assert_ne!(one, R2);
}
849
#[test]
fn test_to_bytes() {
    let mut one_bytes = [0u8; 32];
    one_bytes[0] = 1;

    let cases: [(Scalar, [u8; 32]); 4] = [
        (Scalar::zero(), [0u8; 32]),
        (Scalar::one(), one_bytes),
        (
            R2,
            [
                254, 255, 255, 255, 1, 0, 0, 0, 2, 72, 3, 0, 250, 183, 132, 88, 245, 79, 188, 236,
                239, 79, 140, 153, 111, 5, 197, 172, 89, 177, 36, 24,
            ],
        ),
        (
            -&Scalar::one(),
            [
                0, 0, 0, 0, 255, 255, 255, 255, 254, 91, 254, 255, 2, 164, 189, 83, 5, 216, 161, 9,
                8, 216, 57, 51, 72, 125, 157, 41, 83, 167, 237, 115,
            ],
        ),
    ];

    for (scalar, expected) in cases.iter() {
        assert_eq!(scalar.to_bytes(), *expected);
    }
}
884
#[test]
fn test_from_bytes() {
    let mut a = R2;

    for _ in 0..100 {
        // Both a value and its negation must survive an encode/decode round trip.
        for candidate in [a, -a].iter() {
            let decoded = Scalar::from_bytes(&candidate.to_bytes()).unwrap();
            assert_eq!(*candidate, decoded);
        }

        a = a.square();
    }
}
902
/// Limbs equal to q - 1, the largest canonical value.
#[cfg(test)]
const LARGEST: Scalar = Scalar([
    MODULUS.0[0] - 1,
    MODULUS.0[1],
    MODULUS.0[2],
    MODULUS.0[3],
]);
910
#[test]
fn test_addition() {
    // (q - 1) + (q - 1) = 2q - 2, which reduces to q - 2.
    let q_minus_two = Scalar([
        0xffff_fffe_ffff_ffff,
        0x53bd_a402_fffe_5bfe,
        0x3339_d808_09a1_d805,
        0x73ed_a753_299d_7d48,
    ]);
    let mut doubled = LARGEST;
    doubled += &LARGEST;
    assert_eq!(doubled, q_minus_two);

    // (q - 1) + 1 wraps around to zero.
    let mut wrapped = LARGEST;
    wrapped += &Scalar([1, 0, 0, 0]);
    assert_eq!(wrapped, Scalar::zero());
}
931
#[test]
fn test_inversion() {
    assert!(bool::from(Scalar::zero().invert().is_none()));

    let one = Scalar::one();
    let minus_one = -&one;
    assert_eq!(one.invert().unwrap(), one);
    assert_eq!(minus_one.invert().unwrap(), minus_one);

    let mut x = R2;
    for _ in 0..100 {
        let mut product = x.invert().unwrap();
        product.mul_assign(&x);
        assert_eq!(product, one);

        x.add_assign(&R2);
    }
}
949
#[test]
fn test_from_raw() {
    // 2^256 - 1 is congruent to R - 1 mod q, so both inputs denote the same element.
    let r_minus_one = Scalar::from_raw([
        0x0001_ffff_fffd,
        0x5884_b7fa_0003_4802,
        0x998c_4fef_ecbc_4ff5,
        0x1824_b159_acc5_056f,
    ]);
    let all_ones = Scalar::from_raw([u64::MAX; 4]);
    assert_eq!(r_minus_one, all_ones);

    // q itself reduces to zero, and 1 maps to its Montgomery form R.
    assert_eq!(Scalar::from_raw(MODULUS.0), Scalar::zero());
    assert_eq!(Scalar::from_raw([1, 0, 0, 0]), R);
}
966
#[test]
fn test_scalar_hash_to_field() {
    fn h(data: &[u8], dst: &[u8]) -> Scalar {
        Scalar::hash_to_field(data, dst).unwrap()
    }

    let data1: &[u8] = b"some input data";
    let data2: &[u8] = b"different input data";
    let dst1: &[u8] = b"BLS_SIG_BLS12381G2_XMD:SHA-256_SSWU_RO_NUL_"; // Standard DST format
    let dst2: &[u8] = b"BLS_SIG_BLS12381G2_XMD:SHA-256_SSWU_RO_POP_";

    // 1. Distinct messages under one DST give distinct scalars.
    assert_ne!(h(data1, dst1), h(data2, dst1));

    // 2. One message under distinct DSTs gives distinct scalars.
    let base = h(data1, dst1);
    assert_ne!(base, h(data1, dst2));

    // 3. Hashing is deterministic.
    assert_eq!(base, h(data1, dst1));

    // 4. Every output is a canonical (fully reduced) scalar that round-trips.
    let inputs: [&[u8]; 5] = [b"", b"a", b"test", &[0xFF; 100], &[0x00; 64]];
    for input in inputs.iter() {
        let scalar = h(input, dst1);
        let roundtrip = Scalar::from_bytes(&scalar.to_bytes()).unwrap();
        assert_eq!(scalar, roundtrip, "Output should be a valid reduced scalar");
    }

    // 5. No collisions among 100 small distinct inputs.
    let scalars: Vec<Scalar> = (0u32..100).map(|i| h(&i.to_le_bytes(), dst1)).collect();
    for (i, a) in scalars.iter().enumerate() {
        for (j, b) in scalars.iter().enumerate().skip(i + 1) {
            assert_ne!(a, b, "Unexpected collision at {} and {}", i, j);
        }
    }

    // 6. Empty data with an empty DST is accepted and deterministic.
    assert_eq!(h(b"", b""), h(b"", b""), "Empty input should still be deterministic");

    // 7. The DST (including its length byte) influences the output.
    let dst_short = b"A";
    let dst_long = b"AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"; // a long run of 'A's
    assert_ne!(h(data1, dst_short), h(data1, dst_long), "DST length should affect output");

    // 8. Outputs are not stuck on one parity.
    let parities: Vec<bool> = (0u8..20)
        .map(|i| h(&[i], dst1).to_bytes()[0] & 1 == 1)
        .collect();
    let has_odd = parities.iter().any(|&odd| odd);
    let has_even = parities.iter().any(|&odd| !odd);
    assert!(has_odd && has_even, "Hash output should have both odd and even values");

    // 9. Basic expand_message_xmd sanity checks.
    let expanded = Scalar::expand_message_xmd(b"", b"QUUX-V01-CS02-with-SHA256", 32).unwrap();
    assert_eq!(expanded.len(), 32);
    assert_ne!(
        Scalar::expand_message_xmd(b"msg1", b"dst", 64).unwrap(),
        Scalar::expand_message_xmd(b"msg2", b"dst", 64).unwrap()
    );
}
1055
#[cfg(feature = "zeroize")]
#[test]
fn test_zeroize() {
    use zeroize::Zeroize;

    let mut secret = Scalar::from_raw([
        0x1fff_3231_233f_fffd,
        0x4884_b7fa_0003_4802,
        0x998c_4fef_ecbc_4ff3,
        0x1824_b159_acc5_0562,
    ]);
    secret.zeroize();

    // Fully qualified so the `subtle` trait is chosen over the crate's own ct_eq.
    let is_zero = subtle::ConstantTimeEq::ct_eq(&secret, &Scalar::zero());
    assert!(bool::from(is_zero));
}