Skip to main content

k256/arithmetic/
scalar.rs

1//! Scalar field arithmetic.
2
3use crate::{FieldBytes, NonZeroScalar, ORDER, ORDER_HEX, Secp256k1, WideBytes};
4use core::iter::{Product, Sum};
5use elliptic_curve::{
6    Curve, Error, Generate, ScalarValue,
7    bigint::{ArrayEncoding, Limb, U256, U512, Word, cpubits, modular::Retrieve},
8    ctutils,
9    ff::{self, Field, FromUniformBytes, PrimeField},
10    ops::{
11        Add, AddAssign, Invert, Mul, MulAssign, Neg, Reduce, ReduceNonZero, Shr, ShrAssign, Sub,
12        SubAssign,
13    },
14    rand_core::{CryptoRng, TryCryptoRng, TryRng},
15    scalar::{FromUintUnchecked, IsHigh},
16    subtle::{
17        Choice, ConditionallySelectable, ConstantTimeEq, ConstantTimeGreater, ConstantTimeLess,
18        CtOption,
19    },
20    zeroize::DefaultIsZeroes,
21};
22
// Select the `WideScalar` backend (512-bit intermediate values used for
// products and wide reductions) matching the word size chosen by `cpubits!`:
// the 32-bit arm loads `scalar/wide32.rs`, the 64-bit arm `scalar/wide64.rs`.
cpubits! {
    32 => {
        #[path = "scalar/wide32.rs"]
        mod wide;
    }
    64 => {
        #[path = "scalar/wide64.rs"]
        mod wide;
    }
}
// Both backends export the same `WideScalar` type name.
pub(crate) use self::wide::WideScalar;
34
35#[cfg(feature = "serde")]
36use serdect::serde::{Deserialize, Serialize, de, ser};
37#[cfg(feature = "bits")]
38use {crate::ScalarBits, elliptic_curve::group::ff::PrimeFieldBits};
39
40#[cfg(test)]
41use num_bigint::{BigUint, ToBigUint};
42
43/// Constant representing the modulus
44/// n = FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFE BAAEDCE6 AF48A03B BFD25E8C D0364141
45const MODULUS: [Word; U256::LIMBS] = ORDER.as_ref().to_words();
46
47/// Constant representing the modulus / 2
48const FRAC_MODULUS_2: U256 = ORDER.as_ref().shr_vartime(1);
49
50/// Scalars are elements in the finite field modulo n.
51///
52/// # Trait impls
53///
54/// Much of the important functionality of scalars is provided by traits from
55/// the [`ff`](https://docs.rs/ff/) crate, which is re-exported as
56/// `k256::elliptic_curve::ff`:
57///
58/// - [`Field`](https://docs.rs/ff/latest/ff/trait.Field.html) -
59///   represents elements of finite fields and provides:
60///   - [`Field::random`](https://docs.rs/ff/latest/ff/trait.Field.html#tymethod.random) -
61///     generate a random scalar
62///   - `double`, `square`, and `invert` operations
63///   - Bounds for [`Add`], [`Sub`], [`Mul`], and [`Neg`] (as well as `*Assign` equivalents)
64///   - Bounds for [`ConditionallySelectable`] from the `subtle` crate
65/// - [`PrimeField`](https://docs.rs/ff/latest/ff/trait.PrimeField.html) -
66///   represents elements of prime fields and provides:
67///   - `from_repr`/`to_repr` for converting field elements from/to big integers.
68///   - `multiplicative_generator` and `root_of_unity` constants.
69/// - [`PrimeFieldBits`](https://docs.rs/ff/latest/ff/trait.PrimeFieldBits.html) -
70///   operations over field elements represented as bits (requires `bits` feature)
71///
72/// Please see the documentation for the relevant traits for more information.
73///
74/// # `serde` support
75///
76/// When the `serde` feature of this crate is enabled, the `Serialize` and
77/// `Deserialize` traits are impl'd for this type.
78///
79/// The serialization is a fixed-width big endian encoding. When used with
80/// textual formats, the binary data is encoded as hexadecimal.
81#[derive(Clone, Copy, Debug, Default, PartialOrd, Ord)]
82pub struct Scalar(pub(crate) U256);
83
84impl AsRef<Scalar> for Scalar {
85    fn as_ref(&self) -> &Scalar {
86        self
87    }
88}
89
90impl Scalar {
91    /// Zero scalar.
92    pub const ZERO: Self = Self(U256::ZERO);
93
94    /// Multiplicative identity.
95    pub const ONE: Self = Self(U256::ONE);
96
97    /// Checks if the scalar is zero.
98    pub fn is_zero(&self) -> Choice {
99        self.0.is_zero().into()
100    }
101
102    /// Returns the SEC1 encoding of this scalar.
103    pub fn to_bytes(&self) -> FieldBytes {
104        self.0.to_be_byte_array()
105    }
106
107    /// Negates the scalar.
108    pub const fn negate(&self) -> Self {
109        Self(self.0.neg_mod(ORDER.as_nz_ref()))
110    }
111
112    /// Returns self + rhs mod n.
113    pub const fn add(&self, rhs: &Self) -> Self {
114        Self(self.0.add_mod(&rhs.0, ORDER.as_nz_ref()))
115    }
116
117    /// Returns self - rhs mod n.
118    pub const fn sub(&self, rhs: &Self) -> Self {
119        Self(self.0.sub_mod(&rhs.0, ORDER.as_nz_ref()))
120    }
121
122    /// Modulo multiplies two scalars.
123    pub fn mul(&self, rhs: &Scalar) -> Scalar {
124        WideScalar::mul_wide(self, rhs).reduce()
125    }
126
127    /// Modulo squares the scalar.
128    pub fn square(&self) -> Self {
129        self.mul(self)
130    }
131
132    /// Right shifts the scalar.
133    ///
134    /// Note: not constant-time with respect to the `shift` parameter.
135    pub fn shr_vartime(&self, shift: u32) -> Scalar {
136        Self(self.0.unbounded_shr_vartime(shift))
137    }
138
139    /// Returns the multiplicative inverse of self, if self is non-zero.
140    pub fn invert(&self) -> CtOption<Self> {
141        let inv = self.retrieve().invert_odd_mod(&ORDER);
142
143        CtOption::from(inv).map(Self::from_uint_unchecked)
144    }
145
146    /// Returns the multiplicative inverse of self in variable-time, if self is non-zero.
147    pub fn invert_vartime(&self) -> CtOption<Self> {
148        let inv = self.retrieve().invert_odd_mod_vartime(&ORDER);
149
150        CtOption::from(inv).map(Self::from_uint_unchecked)
151    }
152
153    /// Returns the scalar modulus as a `BigUint` object.
154    #[cfg(test)]
155    pub fn modulus_as_biguint() -> BigUint {
156        Self::ONE.negate().to_biguint().unwrap() + 1.to_biguint().unwrap()
157    }
158
159    /// Returns a (nearly) uniformly-random scalar, generated in constant time.
160    pub fn generate_biased_from_rng<R: CryptoRng + ?Sized>(rng: &mut R) -> Self {
161        let Ok(scalar) = Self::try_generate_biased_from_rng(rng);
162        scalar
163    }
164
165    /// Returns a (nearly) uniformly-random scalar, generated in constant time.
166    pub fn try_generate_biased_from_rng<R: TryCryptoRng + ?Sized>(
167        rng: &mut R,
168    ) -> Result<Self, R::Error> {
169        // We reduce a random 512-bit value into a 256-bit field, which results in a
170        // negligible bias from the uniform distribution, but the process is constant-time.
171        let mut buf = [0u8; 64];
172        rng.try_fill_bytes(&mut buf)?;
173        Ok(WideScalar::from_bytes(&buf).reduce())
174    }
175
176    /// Attempts to parse the given byte array as a scalar.
177    /// Does not check the result for being in the correct range.
178    pub(crate) const fn from_bytes_unchecked(bytes: &[u8; 32]) -> Self {
179        Self(U256::from_be_slice(bytes))
180    }
181}
182
impl Field for Scalar {
    const ZERO: Self = Self::ZERO;
    const ONE: Self = Self::ONE;

    /// Samples a uniformly random scalar by rejection sampling.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `rng`.
    fn try_from_rng<R: TryRng + ?Sized>(rng: &mut R) -> Result<Self, R::Error> {
        // Uses rejection sampling as the default random generation method,
        // which produces a uniformly random distribution of scalars.
        //
        // This method is not constant time, but should be secure so long as
        // rejected RNG outputs are unrelated to future ones (which is a
        // necessary property of a `CryptoRng`).
        // NOTE(review): the bound here is only `TryRng`, so that property has
        // to be supplied by the caller's choice of RNG.
        //
        // With an unbiased RNG, the probability of failing to complete after 4
        // iterations is vanishingly small: a 32-byte draw is rejected only when
        // it encodes a value >= n, and 2^256 - n is less than 2^129.
        let mut bytes = FieldBytes::default();

        // TODO: pre-generate several scalars to bring the probability of non-constant-timeness down?
        loop {
            rng.try_fill_bytes(&mut bytes)?;
            // `from_repr` only accepts canonical encodings (< n); otherwise retry.
            if let Some(scalar) = Scalar::from_repr(bytes).into() {
                return Ok(scalar);
            }
        }
    }

    /// Computes `self^2 mod n` via the inherent `Scalar::square`.
    fn square(&self) -> Self {
        Scalar::square(self)
    }

    /// Computes `(self + self) mod n` via the inherent `Scalar::add`.
    fn double(&self) -> Self {
        self.add(self)
    }

    /// Constant-time inverse; `None` when `self` is zero.
    fn invert(&self) -> CtOption<Self> {
        Scalar::invert(self)
    }

    /// Tonelli-Shanks square root for a modulus with q mod 16 = 1, in the
    /// constant-time form of
    /// <https://eprint.iacr.org/2012/685.pdf> (page 12, algorithm 5).
    ///
    /// Returns `None` when `self` is not a square: the candidate root is
    /// always checked by squaring it before it is returned.
    #[allow(clippy::many_single_char_names)]
    fn sqrt(&self) -> CtOption<Self> {
        // w = self^((t - 1) / 2), where n - 1 = 2^S * t with t odd. The limbs
        // below are (t - 1) / 2, least-significant 64-bit word first.
        // Note: `pow_vartime` is constant-time with respect to `self`
        let w = self.pow_vartime([
            0x777fa4bd19a06c82,
            0xfd755db9cd5e9140,
            0xffffffffffffffff,
            0x1ffffffffffffff,
        ]);

        // `v` carries the exponent bound from the previous round (initially S).
        let mut v = Self::S;
        // Candidate root: x = self^((t + 1) / 2).
        let mut x = *self * w;
        // Correction factor: b = self^t, so initially x^2 = self * b.
        let mut b = x * w;
        // Starts as the 2^S-th root of unity and is squared every round.
        let mut z = Self::ROOT_OF_UNITY;

        // A fixed number (S) of rounds, independent of `self`.
        for max_v in (1..=Self::S).rev() {
            let mut k = 1;
            let mut tmp = b.square();
            let mut j_less_than_v = Choice::from(1);

            // Branch-free inner loop: both alternatives are always computed
            // and combined with `conditional_select`. `k` keeps advancing to
            // `j` while `tmp != 1`; once `tmp == 1`, `z` is squared instead,
            // but only for the remaining iterations with `j < v`.
            for j in 2..max_v {
                let tmp_is_one = tmp.ct_eq(&Self::ONE);
                let squared = Self::conditional_select(&tmp, &z, tmp_is_one).square();
                tmp = Self::conditional_select(&squared, &tmp, tmp_is_one);
                let new_z = Self::conditional_select(&z, &squared, tmp_is_one);
                j_less_than_v &= !ConstantTimeEq::ct_eq(&j, &v);
                k = u32::conditional_select(&j, &k, tmp_is_one);
                z = Self::conditional_select(&z, &new_z, j_less_than_v);
            }

            // Apply the correction x *= z unless b is already 1.
            let result = x * z;
            x = Self::conditional_select(&result, &x, b.ct_eq(&Self::ONE));
            z = z.square();
            b *= z;
            v = k;
        }

        // Reject non-residues: the result is valid only if x^2 == self.
        CtOption::new(x, x.square().ct_eq(self))
    }

    /// Delegates to `ff`'s generic `sqrt_ratio` helper.
    fn sqrt_ratio(num: &Self, div: &Self) -> (Choice, Self) {
        ff::helpers::sqrt_ratio_generic(num, div)
    }
}
266
267impl Generate for Scalar {
268    fn try_generate_from_rng<R: TryCryptoRng + ?Sized>(rng: &mut R) -> Result<Self, R::Error> {
269        Self::try_from_rng(rng)
270    }
271}
272
273impl PrimeField for Scalar {
274    type Repr = FieldBytes;
275
276    const MODULUS: &'static str = ORDER_HEX;
277    const NUM_BITS: u32 = 256;
278    const CAPACITY: u32 = 255;
279    const TWO_INV: Self = Self(U256::from_be_hex(
280        "7fffffffffffffffffffffffffffffff5d576e7357a4501ddfe92f46681b20a1",
281    ));
282    const MULTIPLICATIVE_GENERATOR: Self = Self(U256::from_u8(7));
283    const S: u32 = 6;
284    const ROOT_OF_UNITY: Self = Self(U256::from_be_hex(
285        "0c1dc060e7a91986df9879a3fbc483a898bdeab680756045992f4b5402b052f2",
286    ));
287    const ROOT_OF_UNITY_INV: Self = Self(U256::from_be_hex(
288        "fd3ae181f12d7096efc7b0c75b8cbb7277a275910aa413c3b6fb30a0884f0d1c",
289    ));
290    const DELTA: Self = Self(U256::from_be_hex(
291        "0000000000000000000cbc21fe4561c8d63b78e780e1341e199417c8c0bb7601",
292    ));
293
294    /// Attempts to parse the given byte array as an SEC1-encoded scalar.
295    ///
296    /// Returns None if the byte array does not contain a big-endian integer in the range
297    /// [0, p).
298    fn from_repr(bytes: FieldBytes) -> CtOption<Self> {
299        let inner = U256::from_be_byte_array(bytes);
300        CtOption::new(
301            Self(inner),
302            ConstantTimeLess::ct_lt(&inner, &Secp256k1::ORDER),
303        )
304    }
305
306    fn to_repr(&self) -> FieldBytes {
307        self.to_bytes()
308    }
309
310    fn to_le_repr(&self) -> Self::Repr {
311        self.0.to_le_byte_array()
312    }
313
314    fn is_odd(&self) -> Choice {
315        self.0.is_odd().into()
316    }
317}
318
// Detect mismatch between our word size and bitvec's word size: reject the
// `bits` feature when `cpubits!` selects the 64-bit arm on a target whose
// pointer width is 32 bits.
cpubits! {
    64 => {
        #[cfg(all(feature = "bits", target_pointer_width = "32"))]
        compile_error!("the 'bits' feature is not supported on this target");
    }
}
326
#[cfg(feature = "bits")]
impl PrimeFieldBits for Scalar {
    // The bit-storage words follow the limb width selected by `cpubits!`.
    cpubits! {
        32 => { type ReprBits = [u32; 8]; }
        64 => { type ReprBits = [u64; 4]; }
    }

    /// Returns this scalar's value as little-endian bits.
    fn to_le_bits(&self) -> ScalarBits {
        let bits: ScalarBits = self.into();
        bits
    }

    /// Returns the field characteristic `n` as little-endian bits.
    fn char_le_bits() -> ScalarBits {
        let order_words = ORDER.to_words();
        order_words.into()
    }
}
342
343impl DefaultIsZeroes for Scalar {}
344
345impl From<u32> for Scalar {
346    fn from(k: u32) -> Self {
347        Self(k.into())
348    }
349}
350
351impl From<u64> for Scalar {
352    fn from(k: u64) -> Self {
353        Self(k.into())
354    }
355}
356
357impl From<u128> for Scalar {
358    fn from(k: u128) -> Self {
359        Self(k.into())
360    }
361}
362
363impl FromUniformBytes<64> for Scalar {
364    fn from_uniform_bytes(bytes: &[u8; 64]) -> Self {
365        WideScalar::from_bytes(bytes).reduce()
366    }
367}
368
369impl From<NonZeroScalar> for Scalar {
370    fn from(scalar: NonZeroScalar) -> Self {
371        *scalar.as_ref()
372    }
373}
374
375impl From<&NonZeroScalar> for Scalar {
376    fn from(scalar: &NonZeroScalar) -> Self {
377        *scalar.as_ref()
378    }
379}
380
381impl From<ScalarValue<Secp256k1>> for Scalar {
382    fn from(scalar: ScalarValue<Secp256k1>) -> Scalar {
383        Scalar(*scalar.as_uint())
384    }
385}
386
387impl From<&ScalarValue<Secp256k1>> for Scalar {
388    fn from(scalar: &ScalarValue<Secp256k1>) -> Scalar {
389        Scalar(*scalar.as_uint())
390    }
391}
392
393impl From<Scalar> for ScalarValue<Secp256k1> {
394    fn from(scalar: Scalar) -> ScalarValue<Secp256k1> {
395        ScalarValue::from(&scalar)
396    }
397}
398
399impl From<&Scalar> for ScalarValue<Secp256k1> {
400    fn from(scalar: &Scalar) -> ScalarValue<Secp256k1> {
401        ScalarValue::new(scalar.0).unwrap()
402    }
403}
404
405/// The constant-time alternative is available at [`NonZeroScalar::new()`].
406impl TryFrom<Scalar> for NonZeroScalar {
407    type Error = Error;
408
409    fn try_from(scalar: Scalar) -> Result<Self, Error> {
410        NonZeroScalar::new(scalar).into_option().ok_or(Error)
411    }
412}
413
414impl FromUintUnchecked for Scalar {
415    type Uint = U256;
416
417    fn from_uint_unchecked(uint: Self::Uint) -> Self {
418        Self(uint)
419    }
420}
421
422impl Invert for Scalar {
423    type Output = CtOption<Self>;
424
425    fn invert(&self) -> CtOption<Self> {
426        Scalar::invert(self)
427    }
428
429    fn invert_vartime(&self) -> CtOption<Self> {
430        Scalar::invert_vartime(self)
431    }
432}
433
434impl IsHigh for Scalar {
435    fn is_high(&self) -> Choice {
436        ConstantTimeGreater::ct_gt(&self.0, &FRAC_MODULUS_2)
437    }
438}
439
440impl Shr<usize> for Scalar {
441    type Output = Self;
442
443    fn shr(self, rhs: usize) -> Self::Output {
444        self.shr_vartime(rhs as u32)
445    }
446}
447
448impl Shr<usize> for &Scalar {
449    type Output = Scalar;
450
451    fn shr(self, rhs: usize) -> Self::Output {
452        self.shr_vartime(rhs as u32)
453    }
454}
455
456impl ShrAssign<usize> for Scalar {
457    fn shr_assign(&mut self, rhs: usize) {
458        *self = *self >> rhs;
459    }
460}
461
462impl ConditionallySelectable for Scalar {
463    fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
464        Self(U256::conditional_select(&a.0, &b.0, choice))
465    }
466}
467
468impl ConstantTimeEq for Scalar {
469    fn ct_eq(&self, other: &Self) -> Choice {
470        ConstantTimeEq::ct_eq(&self.0, &other.0)
471    }
472}
473
474impl ctutils::CtEq for Scalar {
475    fn ct_eq(&self, other: &Self) -> ctutils::Choice {
476        ConstantTimeEq::ct_eq(self, other).into()
477    }
478}
479
480impl ctutils::CtSelect for Scalar {
481    fn ct_select(&self, other: &Self, choice: ctutils::Choice) -> Self {
482        ConditionallySelectable::conditional_select(self, other, choice.into())
483    }
484}
485
486impl PartialEq for Scalar {
487    fn eq(&self, other: &Self) -> bool {
488        ConstantTimeEq::ct_eq(self, other).into()
489    }
490}
491
492impl Eq for Scalar {}
493
494impl Neg for Scalar {
495    type Output = Scalar;
496
497    fn neg(self) -> Scalar {
498        self.negate()
499    }
500}
501
502impl Neg for &Scalar {
503    type Output = Scalar;
504
505    fn neg(self) -> Scalar {
506        self.negate()
507    }
508}
509
510impl Add<Scalar> for Scalar {
511    type Output = Scalar;
512
513    fn add(self, other: Scalar) -> Scalar {
514        Scalar::add(&self, &other)
515    }
516}
517
518impl Add<&Scalar> for &Scalar {
519    type Output = Scalar;
520
521    fn add(self, other: &Scalar) -> Scalar {
522        Scalar::add(self, other)
523    }
524}
525
526impl Add<Scalar> for &Scalar {
527    type Output = Scalar;
528
529    fn add(self, other: Scalar) -> Scalar {
530        Scalar::add(self, &other)
531    }
532}
533
534impl Add<&Scalar> for Scalar {
535    type Output = Scalar;
536
537    fn add(self, other: &Scalar) -> Scalar {
538        Scalar::add(&self, other)
539    }
540}
541
542impl AddAssign<Scalar> for Scalar {
543    #[inline]
544    fn add_assign(&mut self, rhs: Scalar) {
545        *self = Scalar::add(self, &rhs);
546    }
547}
548
549impl AddAssign<&Scalar> for Scalar {
550    fn add_assign(&mut self, rhs: &Scalar) {
551        *self = Scalar::add(self, rhs);
552    }
553}
554
555impl Sub<Scalar> for Scalar {
556    type Output = Scalar;
557
558    fn sub(self, other: Scalar) -> Scalar {
559        Scalar::sub(&self, &other)
560    }
561}
562
563impl Sub<&Scalar> for &Scalar {
564    type Output = Scalar;
565
566    fn sub(self, other: &Scalar) -> Scalar {
567        Scalar::sub(self, other)
568    }
569}
570
571impl Sub<&Scalar> for Scalar {
572    type Output = Scalar;
573
574    fn sub(self, other: &Scalar) -> Scalar {
575        Scalar::sub(&self, other)
576    }
577}
578
579impl SubAssign<Scalar> for Scalar {
580    fn sub_assign(&mut self, rhs: Scalar) {
581        *self = Scalar::sub(self, &rhs);
582    }
583}
584
585impl SubAssign<&Scalar> for Scalar {
586    fn sub_assign(&mut self, rhs: &Scalar) {
587        *self = Scalar::sub(self, rhs);
588    }
589}
590
591impl Mul<Scalar> for Scalar {
592    type Output = Scalar;
593
594    fn mul(self, other: Scalar) -> Scalar {
595        Scalar::mul(&self, &other)
596    }
597}
598
599impl Mul<&Scalar> for &Scalar {
600    type Output = Scalar;
601
602    fn mul(self, other: &Scalar) -> Scalar {
603        Scalar::mul(self, other)
604    }
605}
606
607impl Mul<&Scalar> for Scalar {
608    type Output = Scalar;
609
610    fn mul(self, other: &Scalar) -> Scalar {
611        Scalar::mul(&self, other)
612    }
613}
614
615elliptic_curve::scalar_mul_impls!(Secp256k1, Scalar);
616
617impl MulAssign<Scalar> for Scalar {
618    fn mul_assign(&mut self, rhs: Scalar) {
619        *self = Scalar::mul(self, &rhs);
620    }
621}
622
623impl MulAssign<&Scalar> for Scalar {
624    fn mul_assign(&mut self, rhs: &Scalar) {
625        *self = Scalar::mul(self, rhs);
626    }
627}
628
629impl Reduce<U256> for Scalar {
630    fn reduce(w: &U256) -> Self {
631        let (r, underflow) = w.borrowing_sub(&ORDER, Limb::ZERO);
632        let underflow = Choice::from((underflow.0 >> (Limb::BITS - 1)) as u8);
633        Self(U256::conditional_select(w, &r, !underflow))
634    }
635}
636
637impl Reduce<FieldBytes> for Scalar {
638    #[inline]
639    fn reduce(bytes: &FieldBytes) -> Self {
640        Self::reduce(&U256::from_be_byte_array(*bytes))
641    }
642}
643
644impl Reduce<U512> for Scalar {
645    fn reduce(w: &U512) -> Self {
646        WideScalar(*w).reduce()
647    }
648}
649
650impl Reduce<WideBytes> for Scalar {
651    fn reduce(bytes: &WideBytes) -> Self {
652        Self::reduce(&U512::from_be_byte_array(*bytes))
653    }
654}
655
656impl ReduceNonZero<U256> for Scalar {
657    fn reduce_nonzero(w: &U256) -> Self {
658        const ORDER_MINUS_ONE: U256 = ORDER.as_ref().wrapping_sub(&U256::ONE);
659        let (r, underflow) = w.borrowing_sub(&ORDER_MINUS_ONE, Limb::ZERO);
660        let underflow = Choice::from((underflow.0 >> (Limb::BITS - 1)) as u8);
661        Self(U256::conditional_select(w, &r, !underflow).wrapping_add(&U256::ONE))
662    }
663}
664
665impl ReduceNonZero<FieldBytes> for Scalar {
666    #[inline]
667    fn reduce_nonzero(bytes: &FieldBytes) -> Self {
668        Self::reduce_nonzero(&U256::from_be_byte_array(*bytes))
669    }
670}
671
672impl ReduceNonZero<U512> for Scalar {
673    fn reduce_nonzero(w: &U512) -> Self {
674        WideScalar(*w).reduce_nonzero()
675    }
676}
677
678impl ReduceNonZero<WideBytes> for Scalar {
679    #[inline]
680    fn reduce_nonzero(bytes: &WideBytes) -> Self {
681        Self::reduce_nonzero(&U512::from_be_byte_array(*bytes))
682    }
683}
684
685impl Retrieve for Scalar {
686    type Output = U256;
687
688    fn retrieve(&self) -> U256 {
689        U256::from_be_byte_array(self.to_bytes())
690    }
691}
692
693impl Sum for Scalar {
694    fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
695        iter.reduce(Add::add).unwrap_or(Self::ZERO)
696    }
697}
698
699impl<'a> Sum<&'a Scalar> for Scalar {
700    fn sum<I: Iterator<Item = &'a Scalar>>(iter: I) -> Self {
701        iter.copied().sum()
702    }
703}
704
705impl Product for Scalar {
706    fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
707        iter.reduce(Mul::mul).unwrap_or(Self::ONE)
708    }
709}
710
711impl<'a> Product<&'a Scalar> for Scalar {
712    fn product<I: Iterator<Item = &'a Scalar>>(iter: I) -> Self {
713        iter.copied().product()
714    }
715}
716
717#[cfg(feature = "bits")]
718impl From<&Scalar> for ScalarBits {
719    fn from(scalar: &Scalar) -> ScalarBits {
720        scalar.0.to_words().into()
721    }
722}
723
724impl From<Scalar> for FieldBytes {
725    fn from(scalar: Scalar) -> Self {
726        scalar.to_bytes()
727    }
728}
729
730impl From<&Scalar> for FieldBytes {
731    fn from(scalar: &Scalar) -> Self {
732        scalar.to_bytes()
733    }
734}
735
736impl From<Scalar> for U256 {
737    fn from(scalar: Scalar) -> Self {
738        scalar.0
739    }
740}
741
742impl From<&Scalar> for U256 {
743    fn from(scalar: &Scalar) -> Self {
744        scalar.0
745    }
746}
747
#[cfg(feature = "serde")]
impl Serialize for Scalar {
    /// Serializes through [`ScalarValue`] as fixed-width big-endian bytes.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: ser::Serializer,
    {
        let value = ScalarValue::<Secp256k1>::from(self);
        value.serialize(serializer)
    }
}

#[cfg(feature = "serde")]
impl<'de> Deserialize<'de> for Scalar {
    /// Deserializes through [`ScalarValue`] and converts the result.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: de::Deserializer<'de>,
    {
        let value = ScalarValue::<Secp256k1>::deserialize(deserializer)?;
        Ok(Scalar::from(value))
    }
}
767
768#[cfg(test)]
769mod tests {
770    use super::Scalar;
771    use crate::{
772        FieldBytes, NonZeroScalar, ORDER, WideBytes,
773        arithmetic::dev::{biguint_to_bytes, bytes_to_biguint},
774    };
775    use elliptic_curve::{
776        array::Array,
777        bigint::{ArrayEncoding, U256, U512},
778        ff::{Field, PrimeField},
779        ops::{BatchInvert, Reduce},
780        scalar::IsHigh,
781    };
782    use num_bigint::{BigUint, ToBigUint};
783    use num_traits::Zero;
784    use proptest::prelude::*;
785
786    #[cfg(feature = "getrandom")]
787    use elliptic_curve::{Generate, common::getrandom::SysRng};
788
789    impl From<&BigUint> for Scalar {
790        fn from(x: &BigUint) -> Self {
791            debug_assert!(x < &Scalar::modulus_as_biguint());
792            let bytes = biguint_to_bytes(x);
793            Self::from_repr(bytes.into()).unwrap()
794        }
795    }
796
797    impl From<BigUint> for Scalar {
798        fn from(x: BigUint) -> Self {
799            Self::from(&x)
800        }
801    }
802
803    impl ToBigUint for Scalar {
804        fn to_biguint(&self) -> Option<BigUint> {
805            Some(bytes_to_biguint(self.to_bytes().as_ref()))
806        }
807    }
808
809    /// t = (modulus - 1) >> S
810    const T: [u64; 4] = [
811        0xeeff497a3340d905,
812        0xfaeabb739abd2280,
813        0xffffffffffffffff,
814        0x03ffffffffffffff,
815    ];
816
817    #[test]
818    fn two_inv_constant() {
819        assert_eq!(Scalar::from(2u32) * Scalar::TWO_INV, Scalar::ONE);
820    }
821
822    #[test]
823    fn root_of_unity_constant() {
824        // ROOT_OF_UNITY^{2^s} mod m == 1
825        assert_eq!(
826            Scalar::ROOT_OF_UNITY.pow_vartime([1u64 << Scalar::S, 0, 0, 0]),
827            Scalar::ONE
828        );
829
830        // MULTIPLICATIVE_GENERATOR^{t} mod m == ROOT_OF_UNITY
831        assert_eq!(
832            Scalar::MULTIPLICATIVE_GENERATOR.pow_vartime(T),
833            Scalar::ROOT_OF_UNITY
834        )
835    }
836
837    #[test]
838    fn root_of_unity_inv_constant() {
839        assert_eq!(
840            Scalar::ROOT_OF_UNITY * Scalar::ROOT_OF_UNITY_INV,
841            Scalar::ONE
842        );
843    }
844
845    #[test]
846    fn delta_constant() {
847        // DELTA^{t} mod m == 1
848        assert_eq!(Scalar::DELTA.pow_vartime(T), Scalar::ONE);
849    }
850
851    #[test]
852    fn is_high() {
853        // 0 is not high
854        let high: bool = Scalar::ZERO.is_high().into();
855        assert!(!high);
856
857        // 1 is not high
858        let one = 1.to_biguint().unwrap();
859        let high: bool = Scalar::from(&one).is_high().into();
860        assert!(!high);
861
862        let m = Scalar::modulus_as_biguint();
863        let m_by_2 = &m >> 1;
864
865        // M / 2 is not high
866        let high: bool = Scalar::from(&m_by_2).is_high().into();
867        assert!(!high);
868
869        // M / 2 + 1 is high
870        let high: bool = Scalar::from(&m_by_2 + &one).is_high().into();
871        assert!(high);
872
873        // MODULUS - 1 is high
874        let high: bool = Scalar::from(&m - &one).is_high().into();
875        assert!(high);
876    }
877
878    /// Basic tests that sqrt works.
879    #[test]
880    fn sqrt() {
881        for &n in &[1u64, 4, 9, 16, 25, 36, 49, 64] {
882            let scalar = Scalar::from(n);
883            let sqrt = scalar.sqrt().unwrap();
884            assert_eq!(sqrt.square(), scalar);
885        }
886    }
887
888    /// Basic tests that `invert` works.
889    #[test]
890    fn invert() {
891        assert_eq!(Scalar::ONE, Scalar::ONE.invert().unwrap());
892
893        let three = Scalar::from(3u64);
894        let inv_three = three.invert().unwrap();
895        assert_eq!(three * inv_three, Scalar::ONE);
896
897        let minus_three = -three;
898        let inv_minus_three = minus_three.invert().unwrap();
899        assert_eq!(inv_minus_three, -inv_three);
900        assert_eq!(three * inv_minus_three, -Scalar::ONE);
901
902        assert!(bool::from(Scalar::ZERO.invert().is_none()));
903        assert_eq!(Scalar::from(2u64).invert().unwrap(), Scalar::TWO_INV);
904        assert_eq!(
905            Scalar::ROOT_OF_UNITY.invert_vartime().unwrap(),
906            Scalar::ROOT_OF_UNITY_INV
907        );
908    }
909
910    /// Basic tests that `invert_vartime` works.
911    #[test]
912    fn invert_vartime() {
913        assert_eq!(Scalar::ONE, Scalar::ONE.invert_vartime().unwrap());
914
915        let three = Scalar::from(3u64);
916        let inv_three = three.invert_vartime().unwrap();
917        assert_eq!(three * inv_three, Scalar::ONE);
918
919        let minus_three = -three;
920        let inv_minus_three = minus_three.invert_vartime().unwrap();
921        assert_eq!(inv_minus_three, -inv_three);
922        assert_eq!(three * inv_minus_three, -Scalar::ONE);
923
924        assert!(bool::from(Scalar::ZERO.invert_vartime().is_none()));
925        assert_eq!(
926            Scalar::from(2u64).invert_vartime().unwrap(),
927            Scalar::TWO_INV
928        );
929        assert_eq!(
930            Scalar::ROOT_OF_UNITY.invert_vartime().unwrap(),
931            Scalar::ROOT_OF_UNITY_INV
932        );
933    }
934
935    #[test]
936    #[cfg(feature = "getrandom")]
937    fn batch_invert_array() {
938        let k: Scalar = Scalar::generate();
939        let l: Scalar = Scalar::generate();
940
941        let expected = [k.invert().unwrap(), l.invert().unwrap()];
942        assert_eq!(
943            <Scalar as BatchInvert<_>>::batch_invert([k, l]).unwrap(),
944            expected
945        );
946    }
947
948    #[test]
949    #[cfg(all(feature = "alloc", feature = "getrandom"))]
950    fn batch_invert() {
951        let k: Scalar = Scalar::generate();
952        let l: Scalar = Scalar::generate();
953
954        let expected = vec![k.invert().unwrap(), l.invert().unwrap()];
955        let scalars = vec![k, l];
956        let res = <Scalar as BatchInvert<_>>::batch_invert(scalars).unwrap();
957        assert_eq!(res, expected);
958    }
959
960    #[test]
961    fn negate() {
962        let zero_neg = -Scalar::ZERO;
963        assert_eq!(zero_neg, Scalar::ZERO);
964
965        let m = Scalar::modulus_as_biguint();
966        let one = 1.to_biguint().unwrap();
967        let m_minus_one = &m - &one;
968        let m_by_2 = &m >> 1;
969
970        let one_neg = -Scalar::ONE;
971        assert_eq!(one_neg, Scalar::from(&m_minus_one));
972
973        let frac_modulus_2_neg = -Scalar::from(&m_by_2);
974        let frac_modulus_2_plus_one = Scalar::from(&m_by_2 + &one);
975        assert_eq!(frac_modulus_2_neg, frac_modulus_2_plus_one);
976
977        let modulus_minus_one_neg = -Scalar::from(&m - &one);
978        assert_eq!(modulus_minus_one_neg, Scalar::ONE);
979    }
980
981    #[test]
982    fn add_result_within_256_bits() {
983        // A regression for a bug where reduction was not applied
984        // when the unreduced result of addition was in the range `[modulus, 2^256)`.
985        let t = 1.to_biguint().unwrap() << 255;
986        let one = 1.to_biguint().unwrap();
987
988        let a = Scalar::from(&t - &one);
989        let b = Scalar::from(&t);
990        let res = a + b;
991
992        let m = Scalar::modulus_as_biguint();
993        let res_ref = Scalar::from((&t + &t - &one) % &m);
994
995        assert_eq!(res, res_ref);
996    }
997
998    #[cfg(feature = "getrandom")]
999    #[allow(clippy::op_ref)]
1000    #[test]
1001    fn try_generate_biased_from_rng() {
1002        let a = Scalar::try_generate_biased_from_rng(&mut SysRng).unwrap();
1003        // just to make sure `a` is not optimized out by the compiler
1004        assert_eq!((a - &a).is_zero().unwrap_u8(), 1);
1005    }
1006
1007    #[cfg(feature = "getrandom")]
1008    #[test]
1009    fn try_generate_from_rng() {
1010        let a = Scalar::try_generate_from_rng(&mut SysRng).unwrap();
1011        // just to make sure `a` is not optimized out by the compiler
1012        assert_eq!((a - &a).is_zero().unwrap_u8(), 1);
1013    }
1014
1015    #[test]
1016    fn from_bytes_reduced() {
1017        let m = Scalar::modulus_as_biguint();
1018
1019        fn reduce<T: Reduce<FieldBytes>>(arr: &[u8]) -> T {
1020            T::reduce(&Array::try_from(arr).unwrap())
1021        }
1022
1023        // Regular reduction
1024
1025        let s = reduce::<Scalar>(&[0xffu8; 32]).to_biguint().unwrap();
1026        assert!(s < m);
1027
1028        let s = reduce::<Scalar>(&[0u8; 32]).to_biguint().unwrap();
1029        assert!(s.is_zero());
1030
1031        let s = reduce::<Scalar>(&ORDER.to_be_byte_array())
1032            .to_biguint()
1033            .unwrap();
1034        assert!(s.is_zero());
1035
1036        // Reduction to a non-zero scalar
1037
1038        let s = reduce::<NonZeroScalar>(&[0xffu8; 32]).to_biguint().unwrap();
1039        assert!(s < m);
1040
1041        let s = reduce::<NonZeroScalar>(&[0u8; 32]).to_biguint().unwrap();
1042        assert!(s < m);
1043        assert!(!s.is_zero());
1044
1045        let s = reduce::<NonZeroScalar>(&ORDER.to_be_byte_array())
1046            .to_biguint()
1047            .unwrap();
1048        assert!(s < m);
1049        assert!(!s.is_zero());
1050
1051        let s = reduce::<NonZeroScalar>(&(ORDER.wrapping_sub(&U256::ONE)).to_be_byte_array())
1052            .to_biguint()
1053            .unwrap();
1054        assert!(s < m);
1055        assert!(!s.is_zero());
1056    }
1057
1058    #[test]
1059    fn from_wide_bytes_reduced() {
1060        let m = Scalar::modulus_as_biguint();
1061
1062        fn reduce<T: Reduce<WideBytes>>(slice: &[u8]) -> T {
1063            let mut bytes = WideBytes::default();
1064            bytes[(64 - slice.len())..].copy_from_slice(slice);
1065            T::reduce(&bytes)
1066        }
1067
1068        // Regular reduction
1069
1070        let s = reduce::<Scalar>(&[0xffu8; 64]).to_biguint().unwrap();
1071        assert!(s < m);
1072
1073        let s = reduce::<Scalar>(&[0u8; 64]).to_biguint().unwrap();
1074        assert!(s.is_zero());
1075
1076        let s = reduce::<Scalar>(&ORDER.to_be_byte_array())
1077            .to_biguint()
1078            .unwrap();
1079        assert!(s.is_zero());
1080
1081        // Reduction to a non-zero scalar
1082
1083        let s = reduce::<NonZeroScalar>(&[0xffu8; 64]).to_biguint().unwrap();
1084        assert!(s < m);
1085
1086        let s = reduce::<NonZeroScalar>(&[0u8; 64]).to_biguint().unwrap();
1087        assert!(s < m);
1088        assert!(!s.is_zero());
1089
1090        let s = reduce::<NonZeroScalar>(&ORDER.to_be_byte_array())
1091            .to_biguint()
1092            .unwrap();
1093        assert!(s < m);
1094        assert!(!s.is_zero());
1095
1096        let s = reduce::<NonZeroScalar>(&(ORDER.wrapping_sub(&U256::ONE)).to_be_byte_array())
1097            .to_biguint()
1098            .unwrap();
1099        assert!(s < m);
1100        assert!(!s.is_zero());
1101    }
1102
1103    prop_compose! {
1104        fn scalar()(bytes in any::<[u8; 32]>()) -> Scalar {
1105            <Scalar as Reduce<FieldBytes>>::reduce(&bytes.into())
1106        }
1107    }
1108
1109    proptest! {
1110        #[test]
1111        fn fuzzy_roundtrip_to_bytes(a in scalar()) {
1112            let a_back = Scalar::from_repr(a.to_bytes()).unwrap();
1113            assert_eq!(a, a_back);
1114        }
1115
1116        #[test]
1117        fn fuzzy_roundtrip_to_bytes_unchecked(a in scalar()) {
1118            let bytes = a.to_bytes();
1119            let a_back = Scalar::from_bytes_unchecked(bytes.as_ref());
1120            assert_eq!(a, a_back);
1121        }
1122
1123        #[test]
1124        fn fuzzy_add(a in scalar(), b in scalar()) {
1125            let a_bi = a.to_biguint().unwrap();
1126            let b_bi = b.to_biguint().unwrap();
1127
1128            let res_bi = (&a_bi + &b_bi) % &Scalar::modulus_as_biguint();
1129            let res_ref = Scalar::from(&res_bi);
1130            let res_test = a.add(&b);
1131
1132            assert_eq!(res_ref, res_test);
1133        }
1134
1135        #[test]
1136        fn fuzzy_sub(a in scalar(), b in scalar()) {
1137            let a_bi = a.to_biguint().unwrap();
1138            let b_bi = b.to_biguint().unwrap();
1139
1140            let m = Scalar::modulus_as_biguint();
1141            let res_bi = (&m + &a_bi - &b_bi) % &m;
1142            let res_ref = Scalar::from(&res_bi);
1143            let res_test = a.sub(&b);
1144
1145            assert_eq!(res_ref, res_test);
1146        }
1147
1148        #[test]
1149        fn fuzzy_neg(a in scalar()) {
1150            let a_bi = a.to_biguint().unwrap();
1151
1152            let m = Scalar::modulus_as_biguint();
1153            let res_bi = (&m - &a_bi) % &m;
1154            let res_ref = Scalar::from(&res_bi);
1155            let res_test = -a;
1156
1157            assert_eq!(res_ref, res_test);
1158        }
1159
1160        #[test]
1161        fn fuzzy_mul(a in scalar(), b in scalar()) {
1162            let a_bi = a.to_biguint().unwrap();
1163            let b_bi = b.to_biguint().unwrap();
1164
1165            let res_bi = (&a_bi * &b_bi) % &Scalar::modulus_as_biguint();
1166            let res_ref = Scalar::from(&res_bi);
1167            let res_test = a.mul(&b);
1168
1169            assert_eq!(res_ref, res_test);
1170        }
1171
1172        #[test]
1173        fn fuzzy_rshift(a in scalar(), b in 0usize..512) {
1174            let a_bi = a.to_biguint().unwrap();
1175
1176            let res_bi = &a_bi >> b;
1177            let res_ref = Scalar::from(&res_bi);
1178            let res_test = a >> b;
1179
1180            assert_eq!(res_ref, res_test);
1181        }
1182
1183        #[test]
1184        fn fuzzy_invert(
1185            a in scalar()
1186        ) {
1187            let a = if bool::from(a.is_zero()) { Scalar::ONE } else { a };
1188            let a_bi = a.to_biguint().unwrap();
1189            let inv = a.invert().unwrap();
1190            let inv_bi = inv.to_biguint().unwrap();
1191            let m = Scalar::modulus_as_biguint();
1192            assert_eq!((&inv_bi * &a_bi) % &m, 1.to_biguint().unwrap());
1193        }
1194
1195        #[test]
1196        fn fuzzy_invert_vartime(w in scalar()) {
1197            let inv: Option<Scalar> = w.invert().into();
1198            let inv_vartime: Option<Scalar> = w.invert_vartime().into();
1199            assert_eq!(inv, inv_vartime);
1200        }
1201
1202        #[test]
1203        fn fuzzy_from_wide_bytes_reduced(bytes_hi in any::<[u8; 32]>(), bytes_lo in any::<[u8; 32]>()) {
1204            let m = Scalar::modulus_as_biguint();
1205            let mut bytes = [0u8; 64];
1206            bytes[0..32].clone_from_slice(&bytes_hi);
1207            bytes[32..64].clone_from_slice(&bytes_lo);
1208            let s = <Scalar as Reduce<U512>>::reduce(&U512::from_be_slice(&bytes));
1209            let s_bu = s.to_biguint().unwrap();
1210            assert!(s_bu < m);
1211        }
1212    }
1213}