1use crate::{FieldBytes, NonZeroScalar, ORDER, ORDER_HEX, Secp256k1, WideBytes};
4use core::iter::{Product, Sum};
5use elliptic_curve::{
6 Curve, Error, Generate, ScalarValue,
7 bigint::{ArrayEncoding, Limb, U256, U512, Word, cpubits, modular::Retrieve},
8 ctutils,
9 ff::{self, Field, FromUniformBytes, PrimeField},
10 ops::{
11 Add, AddAssign, Invert, Mul, MulAssign, Neg, Reduce, ReduceNonZero, Shr, ShrAssign, Sub,
12 SubAssign,
13 },
14 rand_core::{CryptoRng, TryCryptoRng, TryRng},
15 scalar::{FromUintUnchecked, IsHigh},
16 subtle::{
17 Choice, ConditionallySelectable, ConstantTimeEq, ConstantTimeGreater, ConstantTimeLess,
18 CtOption,
19 },
20 zeroize::DefaultIsZeroes,
21};
22
// Select the "wide" (512-bit intermediate) scalar implementation that matches the limb width
// chosen by `cpubits!`: 32-bit limbs use `scalar/wide32.rs`, 64-bit limbs `scalar/wide64.rs`.
cpubits! {
    32 => {
        #[path = "scalar/wide32.rs"]
        mod wide;
    }
    64 => {
        #[path = "scalar/wide64.rs"]
        mod wide;
    }
}
// `WideScalar` holds 512-bit values; this file uses it for full-width multiplication
// (`mul_wide`) and for reducing 64-byte / `U512` inputs modulo `n`.
pub(crate) use self::wide::WideScalar;
34
35#[cfg(feature = "serde")]
36use serdect::serde::{Deserialize, Serialize, de, ser};
37#[cfg(feature = "bits")]
38use {crate::ScalarBits, elliptic_curve::group::ff::PrimeFieldBits};
39
40#[cfg(test)]
41use num_bigint::{BigUint, ToBigUint};
42
/// The scalar field modulus `n` (the secp256k1 group order) as an array of limb words,
/// in the order produced by `U256::to_words`.
///
/// NOTE(review): not referenced in the visible part of this file; presumably consumed by the
/// `wide` submodules through `super::MODULUS` — confirm before removing.
const MODULUS: [Word; U256::LIMBS] = ORDER.as_ref().to_words();

/// `floor(n / 2)`. Scalars strictly greater than this value are "high" (see the `IsHigh` impl).
const FRAC_MODULUS_2: U256 = ORDER.as_ref().shr_vartime(1);
49
/// An element of the scalar field of secp256k1, i.e. an integer modulo the group order `n`.
///
/// The wrapped `U256` is the canonical integer representative. Constructors in this file
/// either check or reduce so that it is `< n`. The exceptions are `from_uint_unchecked` and
/// `from_bytes_unchecked`, which trust the caller.
///
/// Equality (`PartialEq`) is implemented manually in constant time further down. Ordering is
/// derived from `U256`.
/// NOTE(review): confirm whether `U256`'s `Ord` is constant-time before using
/// comparisons on secret scalars.
#[derive(Clone, Copy, Debug, Default, PartialOrd, Ord)]
pub struct Scalar(pub(crate) U256);
83
84impl AsRef<Scalar> for Scalar {
85 fn as_ref(&self) -> &Scalar {
86 self
87 }
88}
89
90impl Scalar {
91 pub const ZERO: Self = Self(U256::ZERO);
93
94 pub const ONE: Self = Self(U256::ONE);
96
97 pub fn is_zero(&self) -> Choice {
99 self.0.is_zero().into()
100 }
101
102 pub fn to_bytes(&self) -> FieldBytes {
104 self.0.to_be_byte_array()
105 }
106
107 pub const fn negate(&self) -> Self {
109 Self(self.0.neg_mod(ORDER.as_nz_ref()))
110 }
111
112 pub const fn add(&self, rhs: &Self) -> Self {
114 Self(self.0.add_mod(&rhs.0, ORDER.as_nz_ref()))
115 }
116
117 pub const fn sub(&self, rhs: &Self) -> Self {
119 Self(self.0.sub_mod(&rhs.0, ORDER.as_nz_ref()))
120 }
121
122 pub fn mul(&self, rhs: &Scalar) -> Scalar {
124 WideScalar::mul_wide(self, rhs).reduce()
125 }
126
127 pub fn square(&self) -> Self {
129 self.mul(self)
130 }
131
132 pub fn shr_vartime(&self, shift: u32) -> Scalar {
136 Self(self.0.unbounded_shr_vartime(shift))
137 }
138
139 pub fn invert(&self) -> CtOption<Self> {
141 let inv = self.retrieve().invert_odd_mod(&ORDER);
142
143 CtOption::from(inv).map(Self::from_uint_unchecked)
144 }
145
146 pub fn invert_vartime(&self) -> CtOption<Self> {
148 let inv = self.retrieve().invert_odd_mod_vartime(&ORDER);
149
150 CtOption::from(inv).map(Self::from_uint_unchecked)
151 }
152
153 #[cfg(test)]
155 pub fn modulus_as_biguint() -> BigUint {
156 Self::ONE.negate().to_biguint().unwrap() + 1.to_biguint().unwrap()
157 }
158
159 pub fn generate_biased_from_rng<R: CryptoRng + ?Sized>(rng: &mut R) -> Self {
161 let Ok(scalar) = Self::try_generate_biased_from_rng(rng);
162 scalar
163 }
164
165 pub fn try_generate_biased_from_rng<R: TryCryptoRng + ?Sized>(
167 rng: &mut R,
168 ) -> Result<Self, R::Error> {
169 let mut buf = [0u8; 64];
172 rng.try_fill_bytes(&mut buf)?;
173 Ok(WideScalar::from_bytes(&buf).reduce())
174 }
175
176 pub(crate) const fn from_bytes_unchecked(bytes: &[u8; 32]) -> Self {
179 Self(U256::from_be_slice(bytes))
180 }
181}
182
impl Field for Scalar {
    const ZERO: Self = Self::ZERO;
    const ONE: Self = Self::ONE;

    /// Samples a uniformly random scalar by rejection sampling.
    ///
    /// 32 random bytes are interpreted as a big-endian integer. If the value is not below `n`
    /// it is rejected and fresh bytes are drawn. Since `n` is within roughly `2^128` of `2^256`,
    /// a rejection happens with probability on the order of `2^-128`.
    fn try_from_rng<R: TryRng + ?Sized>(rng: &mut R) -> Result<Self, R::Error> {
        let mut bytes = FieldBytes::default();

        loop {
            // RNG failures are propagated; otherwise keep drawing until the bytes are canonical.
            rng.try_fill_bytes(&mut bytes)?;
            if let Some(scalar) = Scalar::from_repr(bytes).into() {
                return Ok(scalar);
            }
        }
    }

    fn square(&self) -> Self {
        Scalar::square(self)
    }

    fn double(&self) -> Self {
        self.add(self)
    }

    fn invert(&self) -> CtOption<Self> {
        Scalar::invert(self)
    }

    /// Constant-time Tonelli–Shanks square root.
    ///
    /// This has the same structure as `ff::helpers::sqrt_tonelli_shanks`, with the exponent
    /// `(t - 1) / 2` precomputed, where `t = (n - 1) >> S` is the odd part of `n - 1`.
    /// The result is "none" when `self` is not a quadratic residue.
    #[allow(clippy::many_single_char_names)]
    fn sqrt(&self) -> CtOption<Self> {
        // w = self^((t - 1) / 2). The exponent is a public constant, so `pow_vartime`
        // (variable-time only in the exponent) does not leak anything about `self`.
        let w = self.pow_vartime([
            0x777fa4bd19a06c82,
            0xfd755db9cd5e9140,
            0xffffffffffffffff,
            0x1ffffffffffffff,
        ]);

        let mut v = Self::S;
        // x = self^((t + 1) / 2): the running root candidate.
        let mut x = *self * w;
        // b = self^t: the "error" term; x is a root exactly once b has been driven to 1.
        let mut b = x * w;
        // z starts as the 2^S-th root of unity and is squared down as the search proceeds.
        let mut z = Self::ROOT_OF_UNITY;

        // Loop bounds depend only on the constant `S`, and every data-dependent decision is a
        // `conditional_select`, so the running time does not depend on `self`.
        for max_v in (1..=Self::S).rev() {
            let mut k = 1;
            let mut tmp = b.square();
            let mut j_less_than_v = Choice::from(1);

            // Constant-time search over the powers b^(2^j), updating `k` and the matching
            // power of `z` without branching.
            for j in 2..max_v {
                let tmp_is_one = tmp.ct_eq(&Self::ONE);
                let squared = Self::conditional_select(&tmp, &z, tmp_is_one).square();
                tmp = Self::conditional_select(&squared, &tmp, tmp_is_one);
                let new_z = Self::conditional_select(&z, &squared, tmp_is_one);
                j_less_than_v &= !ConstantTimeEq::ct_eq(&j, &v);
                k = u32::conditional_select(&j, &k, tmp_is_one);
                z = Self::conditional_select(&z, &new_z, j_less_than_v);
            }

            let result = x * z;
            // Apply the correction to x only while b has not yet reached 1.
            x = Self::conditional_select(&result, &x, b.ct_eq(&Self::ONE));
            z = z.square();
            b *= z;
            v = k;
        }

        // Only report success if x really squares to self (i.e. self is a residue).
        CtOption::new(x, x.square().ct_eq(self))
    }

    fn sqrt_ratio(num: &Self, div: &Self) -> (Choice, Self) {
        ff::helpers::sqrt_ratio_generic(num, div)
    }
}
266
267impl Generate for Scalar {
268 fn try_generate_from_rng<R: TryCryptoRng + ?Sized>(rng: &mut R) -> Result<Self, R::Error> {
269 Self::try_from_rng(rng)
270 }
271}
272
273impl PrimeField for Scalar {
274 type Repr = FieldBytes;
275
276 const MODULUS: &'static str = ORDER_HEX;
277 const NUM_BITS: u32 = 256;
278 const CAPACITY: u32 = 255;
279 const TWO_INV: Self = Self(U256::from_be_hex(
280 "7fffffffffffffffffffffffffffffff5d576e7357a4501ddfe92f46681b20a1",
281 ));
282 const MULTIPLICATIVE_GENERATOR: Self = Self(U256::from_u8(7));
283 const S: u32 = 6;
284 const ROOT_OF_UNITY: Self = Self(U256::from_be_hex(
285 "0c1dc060e7a91986df9879a3fbc483a898bdeab680756045992f4b5402b052f2",
286 ));
287 const ROOT_OF_UNITY_INV: Self = Self(U256::from_be_hex(
288 "fd3ae181f12d7096efc7b0c75b8cbb7277a275910aa413c3b6fb30a0884f0d1c",
289 ));
290 const DELTA: Self = Self(U256::from_be_hex(
291 "0000000000000000000cbc21fe4561c8d63b78e780e1341e199417c8c0bb7601",
292 ));
293
294 fn from_repr(bytes: FieldBytes) -> CtOption<Self> {
299 let inner = U256::from_be_byte_array(bytes);
300 CtOption::new(
301 Self(inner),
302 ConstantTimeLess::ct_lt(&inner, &Secp256k1::ORDER),
303 )
304 }
305
306 fn to_repr(&self) -> FieldBytes {
307 self.to_bytes()
308 }
309
310 fn to_le_repr(&self) -> Self::Repr {
311 self.0.to_le_byte_array()
312 }
313
314 fn is_odd(&self) -> Choice {
315 self.0.is_odd().into()
316 }
317}
318
// Compile-time guard: when `cpubits!` selects 64-bit limbs on a target whose pointer width is
// 32 bits, the `bits` feature is rejected.
// NOTE(review): presumably because the `PrimeFieldBits` representation below assumes the
// limb width matches the target word size — confirm against the `ff` crate requirements.
cpubits! {
    64 => {
        #[cfg(all(feature = "bits", target_pointer_width = "32"))]
        compile_error!("the 'bits' feature is not supported on this target");
    }
}
326
/// Little-endian bit views of scalars and of the modulus, using limbs of the selected width.
#[cfg(feature = "bits")]
impl PrimeFieldBits for Scalar {
    cpubits! {
        32 => { type ReprBits = [u32; 8]; }
        64 => { type ReprBits = [u64; 4]; }
    }

    fn to_le_bits(&self) -> ScalarBits {
        ScalarBits::from(self)
    }

    fn char_le_bits() -> ScalarBits {
        let order_words = ORDER.to_words();
        order_words.into()
    }
}
342
// Opts into `zeroize`: a scalar is wiped by overwriting it with `Scalar::default()` (zero).
impl DefaultIsZeroes for Scalar {}
344
345impl From<u32> for Scalar {
346 fn from(k: u32) -> Self {
347 Self(k.into())
348 }
349}
350
351impl From<u64> for Scalar {
352 fn from(k: u64) -> Self {
353 Self(k.into())
354 }
355}
356
357impl From<u128> for Scalar {
358 fn from(k: u128) -> Self {
359 Self(k.into())
360 }
361}
362
363impl FromUniformBytes<64> for Scalar {
364 fn from_uniform_bytes(bytes: &[u8; 64]) -> Self {
365 WideScalar::from_bytes(bytes).reduce()
366 }
367}
368
369impl From<NonZeroScalar> for Scalar {
370 fn from(scalar: NonZeroScalar) -> Self {
371 *scalar.as_ref()
372 }
373}
374
375impl From<&NonZeroScalar> for Scalar {
376 fn from(scalar: &NonZeroScalar) -> Self {
377 *scalar.as_ref()
378 }
379}
380
381impl From<ScalarValue<Secp256k1>> for Scalar {
382 fn from(scalar: ScalarValue<Secp256k1>) -> Scalar {
383 Scalar(*scalar.as_uint())
384 }
385}
386
387impl From<&ScalarValue<Secp256k1>> for Scalar {
388 fn from(scalar: &ScalarValue<Secp256k1>) -> Scalar {
389 Scalar(*scalar.as_uint())
390 }
391}
392
393impl From<Scalar> for ScalarValue<Secp256k1> {
394 fn from(scalar: Scalar) -> ScalarValue<Secp256k1> {
395 ScalarValue::from(&scalar)
396 }
397}
398
399impl From<&Scalar> for ScalarValue<Secp256k1> {
400 fn from(scalar: &Scalar) -> ScalarValue<Secp256k1> {
401 ScalarValue::new(scalar.0).unwrap()
402 }
403}
404
405impl TryFrom<Scalar> for NonZeroScalar {
407 type Error = Error;
408
409 fn try_from(scalar: Scalar) -> Result<Self, Error> {
410 NonZeroScalar::new(scalar).into_option().ok_or(Error)
411 }
412}
413
414impl FromUintUnchecked for Scalar {
415 type Uint = U256;
416
417 fn from_uint_unchecked(uint: Self::Uint) -> Self {
418 Self(uint)
419 }
420}
421
422impl Invert for Scalar {
423 type Output = CtOption<Self>;
424
425 fn invert(&self) -> CtOption<Self> {
426 Scalar::invert(self)
427 }
428
429 fn invert_vartime(&self) -> CtOption<Self> {
430 Scalar::invert_vartime(self)
431 }
432}
433
434impl IsHigh for Scalar {
435 fn is_high(&self) -> Choice {
436 ConstantTimeGreater::ct_gt(&self.0, &FRAC_MODULUS_2)
437 }
438}
439
440impl Shr<usize> for Scalar {
441 type Output = Self;
442
443 fn shr(self, rhs: usize) -> Self::Output {
444 self.shr_vartime(rhs as u32)
445 }
446}
447
448impl Shr<usize> for &Scalar {
449 type Output = Scalar;
450
451 fn shr(self, rhs: usize) -> Self::Output {
452 self.shr_vartime(rhs as u32)
453 }
454}
455
456impl ShrAssign<usize> for Scalar {
457 fn shr_assign(&mut self, rhs: usize) {
458 *self = *self >> rhs;
459 }
460}
461
462impl ConditionallySelectable for Scalar {
463 fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
464 Self(U256::conditional_select(&a.0, &b.0, choice))
465 }
466}
467
468impl ConstantTimeEq for Scalar {
469 fn ct_eq(&self, other: &Self) -> Choice {
470 ConstantTimeEq::ct_eq(&self.0, &other.0)
471 }
472}
473
474impl ctutils::CtEq for Scalar {
475 fn ct_eq(&self, other: &Self) -> ctutils::Choice {
476 ConstantTimeEq::ct_eq(self, other).into()
477 }
478}
479
480impl ctutils::CtSelect for Scalar {
481 fn ct_select(&self, other: &Self, choice: ctutils::Choice) -> Self {
482 ConditionallySelectable::conditional_select(self, other, choice.into())
483 }
484}
485
486impl PartialEq for Scalar {
487 fn eq(&self, other: &Self) -> bool {
488 ConstantTimeEq::ct_eq(self, other).into()
489 }
490}
491
492impl Eq for Scalar {}
493
494impl Neg for Scalar {
495 type Output = Scalar;
496
497 fn neg(self) -> Scalar {
498 self.negate()
499 }
500}
501
502impl Neg for &Scalar {
503 type Output = Scalar;
504
505 fn neg(self) -> Scalar {
506 self.negate()
507 }
508}
509
510impl Add<Scalar> for Scalar {
511 type Output = Scalar;
512
513 fn add(self, other: Scalar) -> Scalar {
514 Scalar::add(&self, &other)
515 }
516}
517
518impl Add<&Scalar> for &Scalar {
519 type Output = Scalar;
520
521 fn add(self, other: &Scalar) -> Scalar {
522 Scalar::add(self, other)
523 }
524}
525
526impl Add<Scalar> for &Scalar {
527 type Output = Scalar;
528
529 fn add(self, other: Scalar) -> Scalar {
530 Scalar::add(self, &other)
531 }
532}
533
534impl Add<&Scalar> for Scalar {
535 type Output = Scalar;
536
537 fn add(self, other: &Scalar) -> Scalar {
538 Scalar::add(&self, other)
539 }
540}
541
542impl AddAssign<Scalar> for Scalar {
543 #[inline]
544 fn add_assign(&mut self, rhs: Scalar) {
545 *self = Scalar::add(self, &rhs);
546 }
547}
548
549impl AddAssign<&Scalar> for Scalar {
550 fn add_assign(&mut self, rhs: &Scalar) {
551 *self = Scalar::add(self, rhs);
552 }
553}
554
555impl Sub<Scalar> for Scalar {
556 type Output = Scalar;
557
558 fn sub(self, other: Scalar) -> Scalar {
559 Scalar::sub(&self, &other)
560 }
561}
562
563impl Sub<&Scalar> for &Scalar {
564 type Output = Scalar;
565
566 fn sub(self, other: &Scalar) -> Scalar {
567 Scalar::sub(self, other)
568 }
569}
570
571impl Sub<&Scalar> for Scalar {
572 type Output = Scalar;
573
574 fn sub(self, other: &Scalar) -> Scalar {
575 Scalar::sub(&self, other)
576 }
577}
578
579impl SubAssign<Scalar> for Scalar {
580 fn sub_assign(&mut self, rhs: Scalar) {
581 *self = Scalar::sub(self, &rhs);
582 }
583}
584
585impl SubAssign<&Scalar> for Scalar {
586 fn sub_assign(&mut self, rhs: &Scalar) {
587 *self = Scalar::sub(self, rhs);
588 }
589}
590
591impl Mul<Scalar> for Scalar {
592 type Output = Scalar;
593
594 fn mul(self, other: Scalar) -> Scalar {
595 Scalar::mul(&self, &other)
596 }
597}
598
599impl Mul<&Scalar> for &Scalar {
600 type Output = Scalar;
601
602 fn mul(self, other: &Scalar) -> Scalar {
603 Scalar::mul(self, other)
604 }
605}
606
607impl Mul<&Scalar> for Scalar {
608 type Output = Scalar;
609
610 fn mul(self, other: &Scalar) -> Scalar {
611 Scalar::mul(&self, other)
612 }
613}
614
// Generates additional `Mul` impls for `Scalar` provided by `elliptic_curve`.
// NOTE(review): presumably products with `NonZeroScalar<Secp256k1>` operands — confirm in the
// macro definition before adding further `Mul` impls here, to avoid overlapping impls.
elliptic_curve::scalar_mul_impls!(Secp256k1, Scalar);
616
617impl MulAssign<Scalar> for Scalar {
618 fn mul_assign(&mut self, rhs: Scalar) {
619 *self = Scalar::mul(self, &rhs);
620 }
621}
622
623impl MulAssign<&Scalar> for Scalar {
624 fn mul_assign(&mut self, rhs: &Scalar) {
625 *self = Scalar::mul(self, rhs);
626 }
627}
628
629impl Reduce<U256> for Scalar {
630 fn reduce(w: &U256) -> Self {
631 let (r, underflow) = w.borrowing_sub(&ORDER, Limb::ZERO);
632 let underflow = Choice::from((underflow.0 >> (Limb::BITS - 1)) as u8);
633 Self(U256::conditional_select(w, &r, !underflow))
634 }
635}
636
637impl Reduce<FieldBytes> for Scalar {
638 #[inline]
639 fn reduce(bytes: &FieldBytes) -> Self {
640 Self::reduce(&U256::from_be_byte_array(*bytes))
641 }
642}
643
644impl Reduce<U512> for Scalar {
645 fn reduce(w: &U512) -> Self {
646 WideScalar(*w).reduce()
647 }
648}
649
650impl Reduce<WideBytes> for Scalar {
651 fn reduce(bytes: &WideBytes) -> Self {
652 Self::reduce(&U512::from_be_byte_array(*bytes))
653 }
654}
655
656impl ReduceNonZero<U256> for Scalar {
657 fn reduce_nonzero(w: &U256) -> Self {
658 const ORDER_MINUS_ONE: U256 = ORDER.as_ref().wrapping_sub(&U256::ONE);
659 let (r, underflow) = w.borrowing_sub(&ORDER_MINUS_ONE, Limb::ZERO);
660 let underflow = Choice::from((underflow.0 >> (Limb::BITS - 1)) as u8);
661 Self(U256::conditional_select(w, &r, !underflow).wrapping_add(&U256::ONE))
662 }
663}
664
665impl ReduceNonZero<FieldBytes> for Scalar {
666 #[inline]
667 fn reduce_nonzero(bytes: &FieldBytes) -> Self {
668 Self::reduce_nonzero(&U256::from_be_byte_array(*bytes))
669 }
670}
671
672impl ReduceNonZero<U512> for Scalar {
673 fn reduce_nonzero(w: &U512) -> Self {
674 WideScalar(*w).reduce_nonzero()
675 }
676}
677
678impl ReduceNonZero<WideBytes> for Scalar {
679 #[inline]
680 fn reduce_nonzero(bytes: &WideBytes) -> Self {
681 Self::reduce_nonzero(&U512::from_be_byte_array(*bytes))
682 }
683}
684
685impl Retrieve for Scalar {
686 type Output = U256;
687
688 fn retrieve(&self) -> U256 {
689 U256::from_be_byte_array(self.to_bytes())
690 }
691}
692
693impl Sum for Scalar {
694 fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
695 iter.reduce(Add::add).unwrap_or(Self::ZERO)
696 }
697}
698
699impl<'a> Sum<&'a Scalar> for Scalar {
700 fn sum<I: Iterator<Item = &'a Scalar>>(iter: I) -> Self {
701 iter.copied().sum()
702 }
703}
704
705impl Product for Scalar {
706 fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
707 iter.reduce(Mul::mul).unwrap_or(Self::ONE)
708 }
709}
710
711impl<'a> Product<&'a Scalar> for Scalar {
712 fn product<I: Iterator<Item = &'a Scalar>>(iter: I) -> Self {
713 iter.copied().product()
714 }
715}
716
717#[cfg(feature = "bits")]
718impl From<&Scalar> for ScalarBits {
719 fn from(scalar: &Scalar) -> ScalarBits {
720 scalar.0.to_words().into()
721 }
722}
723
724impl From<Scalar> for FieldBytes {
725 fn from(scalar: Scalar) -> Self {
726 scalar.to_bytes()
727 }
728}
729
730impl From<&Scalar> for FieldBytes {
731 fn from(scalar: &Scalar) -> Self {
732 scalar.to_bytes()
733 }
734}
735
736impl From<Scalar> for U256 {
737 fn from(scalar: Scalar) -> Self {
738 scalar.0
739 }
740}
741
742impl From<&Scalar> for U256 {
743 fn from(scalar: &Scalar) -> Self {
744 scalar.0
745 }
746}
747
/// Serializes via `ScalarValue<Secp256k1>`.
#[cfg(feature = "serde")]
impl Serialize for Scalar {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: ser::Serializer,
    {
        let value: ScalarValue<Secp256k1> = self.into();
        value.serialize(serializer)
    }
}

/// Deserializes via `ScalarValue<Secp256k1>`, which carries the range check.
#[cfg(feature = "serde")]
impl<'de> Deserialize<'de> for Scalar {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: de::Deserializer<'de>,
    {
        let value = ScalarValue::<Secp256k1>::deserialize(deserializer)?;
        Ok(Self::from(value))
    }
}
767
#[cfg(test)]
mod tests {
    use super::Scalar;
    use crate::{
        FieldBytes, NonZeroScalar, ORDER, WideBytes,
        arithmetic::dev::{biguint_to_bytes, bytes_to_biguint},
    };
    use elliptic_curve::{
        array::Array,
        bigint::{ArrayEncoding, U256, U512},
        ff::{Field, PrimeField},
        ops::{BatchInvert, Reduce},
        scalar::IsHigh,
    };
    use num_bigint::{BigUint, ToBigUint};
    use num_traits::Zero;
    use proptest::prelude::*;

    #[cfg(feature = "getrandom")]
    use elliptic_curve::{Generate, common::getrandom::SysRng};

    /// Converts a `BigUint` that must already be `< n` into a scalar.
    impl From<&BigUint> for Scalar {
        fn from(x: &BigUint) -> Self {
            debug_assert!(x < &Scalar::modulus_as_biguint());
            let bytes = biguint_to_bytes(x);
            Self::from_repr(bytes.into()).unwrap()
        }
    }

    impl From<BigUint> for Scalar {
        fn from(x: BigUint) -> Self {
            Self::from(&x)
        }
    }

    impl ToBigUint for Scalar {
        fn to_biguint(&self) -> Option<BigUint> {
            Some(bytes_to_biguint(self.to_bytes().as_ref()))
        }
    }

    /// `t = (n - 1) >> S`, the odd part of `n - 1`, as little-endian 64-bit limbs.
    const T: [u64; 4] = [
        0xeeff497a3340d905,
        0xfaeabb739abd2280,
        0xffffffffffffffff,
        0x03ffffffffffffff,
    ];

    #[test]
    fn two_inv_constant() {
        assert_eq!(Scalar::from(2u32) * Scalar::TWO_INV, Scalar::ONE);
    }

    #[test]
    fn root_of_unity_constant() {
        // ROOT_OF_UNITY^(2^S) == 1
        assert_eq!(
            Scalar::ROOT_OF_UNITY.pow_vartime([1u64 << Scalar::S, 0, 0, 0]),
            Scalar::ONE
        );

        // MULTIPLICATIVE_GENERATOR^t == ROOT_OF_UNITY
        assert_eq!(
            Scalar::MULTIPLICATIVE_GENERATOR.pow_vartime(T),
            Scalar::ROOT_OF_UNITY
        )
    }

    #[test]
    fn root_of_unity_inv_constant() {
        assert_eq!(
            Scalar::ROOT_OF_UNITY * Scalar::ROOT_OF_UNITY_INV,
            Scalar::ONE
        );
    }

    #[test]
    fn delta_constant() {
        // DELTA^t == 1
        assert_eq!(Scalar::DELTA.pow_vartime(T), Scalar::ONE);
    }

    #[test]
    fn is_high() {
        // 0 is not high
        let high: bool = Scalar::ZERO.is_high().into();
        assert!(!high);

        // 1 is not high
        let one = 1.to_biguint().unwrap();
        let high: bool = Scalar::from(&one).is_high().into();
        assert!(!high);

        let m = Scalar::modulus_as_biguint();
        let m_by_2 = &m >> 1;

        // M / 2 is not high
        let high: bool = Scalar::from(&m_by_2).is_high().into();
        assert!(!high);

        // M / 2 + 1 is high
        let high: bool = Scalar::from(&m_by_2 + &one).is_high().into();
        assert!(high);

        // MODULUS - 1 is high
        let high: bool = Scalar::from(&m - &one).is_high().into();
        assert!(high);
    }

    /// Basic tests that sqrt works.
    #[test]
    fn sqrt() {
        for &n in &[1u64, 4, 9, 16, 25, 36, 49, 64] {
            let scalar = Scalar::from(n);
            let sqrt = scalar.sqrt().unwrap();
            assert_eq!(sqrt.square(), scalar);
        }
    }

    /// Basic tests that constant-time field inversion works.
    #[test]
    fn invert() {
        assert_eq!(Scalar::ONE, Scalar::ONE.invert().unwrap());

        let three = Scalar::from(3u64);
        let inv_three = three.invert().unwrap();
        assert_eq!(three * inv_three, Scalar::ONE);

        let minus_three = -three;
        let inv_minus_three = minus_three.invert().unwrap();
        assert_eq!(inv_minus_three, -inv_three);
        assert_eq!(three * inv_minus_three, -Scalar::ONE);

        assert!(bool::from(Scalar::ZERO.invert().is_none()));
        assert_eq!(Scalar::from(2u64).invert().unwrap(), Scalar::TWO_INV);
        // Exercise the constant-time path here; `invert_vartime` has its own test below.
        assert_eq!(
            Scalar::ROOT_OF_UNITY.invert().unwrap(),
            Scalar::ROOT_OF_UNITY_INV
        );
    }

    /// Basic tests that variable-time field inversion works.
    #[test]
    fn invert_vartime() {
        assert_eq!(Scalar::ONE, Scalar::ONE.invert_vartime().unwrap());

        let three = Scalar::from(3u64);
        let inv_three = three.invert_vartime().unwrap();
        assert_eq!(three * inv_three, Scalar::ONE);

        let minus_three = -three;
        let inv_minus_three = minus_three.invert_vartime().unwrap();
        assert_eq!(inv_minus_three, -inv_three);
        assert_eq!(three * inv_minus_three, -Scalar::ONE);

        assert!(bool::from(Scalar::ZERO.invert_vartime().is_none()));
        assert_eq!(
            Scalar::from(2u64).invert_vartime().unwrap(),
            Scalar::TWO_INV
        );
        assert_eq!(
            Scalar::ROOT_OF_UNITY.invert_vartime().unwrap(),
            Scalar::ROOT_OF_UNITY_INV
        );
    }

    #[test]
    #[cfg(feature = "getrandom")]
    fn batch_invert_array() {
        let k: Scalar = Scalar::generate();
        let l: Scalar = Scalar::generate();

        let expected = [k.invert().unwrap(), l.invert().unwrap()];
        assert_eq!(
            <Scalar as BatchInvert<_>>::batch_invert([k, l]).unwrap(),
            expected
        );
    }

    #[test]
    #[cfg(all(feature = "alloc", feature = "getrandom"))]
    fn batch_invert() {
        let k: Scalar = Scalar::generate();
        let l: Scalar = Scalar::generate();

        let expected = vec![k.invert().unwrap(), l.invert().unwrap()];
        let scalars = vec![k, l];
        let res = <Scalar as BatchInvert<_>>::batch_invert(scalars).unwrap();
        assert_eq!(res, expected);
    }

    #[test]
    fn negate() {
        let zero_neg = -Scalar::ZERO;
        assert_eq!(zero_neg, Scalar::ZERO);

        let m = Scalar::modulus_as_biguint();
        let one = 1.to_biguint().unwrap();
        let m_minus_one = &m - &one;
        let m_by_2 = &m >> 1;

        let one_neg = -Scalar::ONE;
        assert_eq!(one_neg, Scalar::from(&m_minus_one));

        let frac_modulus_2_neg = -Scalar::from(&m_by_2);
        let frac_modulus_2_plus_one = Scalar::from(&m_by_2 + &one);
        assert_eq!(frac_modulus_2_neg, frac_modulus_2_plus_one);

        let modulus_minus_one_neg = -Scalar::from(&m - &one);
        assert_eq!(modulus_minus_one_neg, Scalar::ONE);
    }

    #[test]
    fn add_result_within_256_bits() {
        // A regression for a bug where reduction was not applied
        // when the unreduced result of addition was in the range `[modulus, 2^256)`.
        let t = 1.to_biguint().unwrap() << 255;
        let one = 1.to_biguint().unwrap();

        let a = Scalar::from(&t - &one);
        let b = Scalar::from(&t);
        let res = a + b;

        let m = Scalar::modulus_as_biguint();
        let res_ref = Scalar::from((&t + &t - &one) % &m);

        assert_eq!(res, res_ref);
    }

    /// Generated scalars must be canonical (`< n`) and survive an encode/decode round trip.
    #[cfg(feature = "getrandom")]
    #[test]
    fn try_generate_biased_from_rng() {
        let m = Scalar::modulus_as_biguint();
        let a = Scalar::try_generate_biased_from_rng(&mut SysRng).unwrap();
        assert!(a.to_biguint().unwrap() < m);
        assert_eq!(Scalar::from_repr(a.to_repr()).unwrap(), a);
    }

    /// Generated scalars must be canonical (`< n`) and survive an encode/decode round trip.
    #[cfg(feature = "getrandom")]
    #[test]
    fn try_generate_from_rng() {
        let m = Scalar::modulus_as_biguint();
        let a = Scalar::try_generate_from_rng(&mut SysRng).unwrap();
        assert!(a.to_biguint().unwrap() < m);
        assert_eq!(Scalar::from_repr(a.to_repr()).unwrap(), a);
    }

    #[test]
    fn from_bytes_reduced() {
        let m = Scalar::modulus_as_biguint();

        fn reduce<T: Reduce<FieldBytes>>(arr: &[u8]) -> T {
            T::reduce(&Array::try_from(arr).unwrap())
        }

        // 2^256 - 1 > the modulus, so it must be reduced.
        let s = reduce::<Scalar>(&[0xffu8; 32]).to_biguint().unwrap();
        assert!(s < m);

        let s = reduce::<Scalar>(&[0u8; 32]).to_biguint().unwrap();
        assert!(s.is_zero());

        let s = reduce::<Scalar>(&ORDER.to_be_byte_array())
            .to_biguint()
            .unwrap();
        assert!(s.is_zero());

        // 2^256 - 1 > the modulus, so it must be reduced.
        let s = reduce::<NonZeroScalar>(&[0xffu8; 32]).to_biguint().unwrap();
        assert!(s < m);

        let s = reduce::<NonZeroScalar>(&[0u8; 32]).to_biguint().unwrap();
        assert!(s < m);
        assert!(!s.is_zero());

        let s = reduce::<NonZeroScalar>(&ORDER.to_be_byte_array())
            .to_biguint()
            .unwrap();
        assert!(s < m);
        assert!(!s.is_zero());

        let s = reduce::<NonZeroScalar>(&(ORDER.wrapping_sub(&U256::ONE)).to_be_byte_array())
            .to_biguint()
            .unwrap();
        assert!(s < m);
        assert!(!s.is_zero());
    }

    #[test]
    fn from_wide_bytes_reduced() {
        let m = Scalar::modulus_as_biguint();

        fn reduce<T: Reduce<WideBytes>>(slice: &[u8]) -> T {
            let mut bytes = WideBytes::default();
            bytes[(64 - slice.len())..].copy_from_slice(slice);
            T::reduce(&bytes)
        }

        // 2^512 - 1 > the modulus, so it must be reduced.
        let s = reduce::<Scalar>(&[0xffu8; 64]).to_biguint().unwrap();
        assert!(s < m);

        let s = reduce::<Scalar>(&[0u8; 64]).to_biguint().unwrap();
        assert!(s.is_zero());

        let s = reduce::<Scalar>(&ORDER.to_be_byte_array())
            .to_biguint()
            .unwrap();
        assert!(s.is_zero());

        // 2^512 - 1 > the modulus, so it must be reduced.
        let s = reduce::<NonZeroScalar>(&[0xffu8; 64]).to_biguint().unwrap();
        assert!(s < m);

        let s = reduce::<NonZeroScalar>(&[0u8; 64]).to_biguint().unwrap();
        assert!(s < m);
        assert!(!s.is_zero());

        let s = reduce::<NonZeroScalar>(&ORDER.to_be_byte_array())
            .to_biguint()
            .unwrap();
        assert!(s < m);
        assert!(!s.is_zero());

        let s = reduce::<NonZeroScalar>(&(ORDER.wrapping_sub(&U256::ONE)).to_be_byte_array())
            .to_biguint()
            .unwrap();
        assert!(s < m);
        assert!(!s.is_zero());
    }

    prop_compose! {
        fn scalar()(bytes in any::<[u8; 32]>()) -> Scalar {
            <Scalar as Reduce<FieldBytes>>::reduce(&bytes.into())
        }
    }

    proptest! {
        #[test]
        fn fuzzy_roundtrip_to_bytes(a in scalar()) {
            let a_back = Scalar::from_repr(a.to_bytes()).unwrap();
            assert_eq!(a, a_back);
        }

        #[test]
        fn fuzzy_roundtrip_to_bytes_unchecked(a in scalar()) {
            let bytes = a.to_bytes();
            let a_back = Scalar::from_bytes_unchecked(bytes.as_ref());
            assert_eq!(a, a_back);
        }

        #[test]
        fn fuzzy_add(a in scalar(), b in scalar()) {
            let a_bi = a.to_biguint().unwrap();
            let b_bi = b.to_biguint().unwrap();

            let res_bi = (&a_bi + &b_bi) % &Scalar::modulus_as_biguint();
            let res_ref = Scalar::from(&res_bi);
            let res_test = a.add(&b);

            assert_eq!(res_ref, res_test);
        }

        #[test]
        fn fuzzy_sub(a in scalar(), b in scalar()) {
            let a_bi = a.to_biguint().unwrap();
            let b_bi = b.to_biguint().unwrap();

            let m = Scalar::modulus_as_biguint();
            let res_bi = (&m + &a_bi - &b_bi) % &m;
            let res_ref = Scalar::from(&res_bi);
            let res_test = a.sub(&b);

            assert_eq!(res_ref, res_test);
        }

        #[test]
        fn fuzzy_neg(a in scalar()) {
            let a_bi = a.to_biguint().unwrap();

            let m = Scalar::modulus_as_biguint();
            let res_bi = (&m - &a_bi) % &m;
            let res_ref = Scalar::from(&res_bi);
            let res_test = -a;

            assert_eq!(res_ref, res_test);
        }

        #[test]
        fn fuzzy_mul(a in scalar(), b in scalar()) {
            let a_bi = a.to_biguint().unwrap();
            let b_bi = b.to_biguint().unwrap();

            let res_bi = (&a_bi * &b_bi) % &Scalar::modulus_as_biguint();
            let res_ref = Scalar::from(&res_bi);
            let res_test = a.mul(&b);

            assert_eq!(res_ref, res_test);
        }

        #[test]
        fn fuzzy_rshift(a in scalar(), b in 0usize..512) {
            let a_bi = a.to_biguint().unwrap();

            let res_bi = &a_bi >> b;
            let res_ref = Scalar::from(&res_bi);
            let res_test = a >> b;

            assert_eq!(res_ref, res_test);
        }

        #[test]
        fn fuzzy_invert(
            a in scalar()
        ) {
            let a = if bool::from(a.is_zero()) { Scalar::ONE } else { a };
            let a_bi = a.to_biguint().unwrap();
            let inv = a.invert().unwrap();
            let inv_bi = inv.to_biguint().unwrap();
            let m = Scalar::modulus_as_biguint();
            assert_eq!((&inv_bi * &a_bi) % &m, 1.to_biguint().unwrap());
        }

        #[test]
        fn fuzzy_invert_vartime(w in scalar()) {
            let inv: Option<Scalar> = w.invert().into();
            let inv_vartime: Option<Scalar> = w.invert_vartime().into();
            assert_eq!(inv, inv_vartime);
        }

        #[test]
        fn fuzzy_from_wide_bytes_reduced(bytes_hi in any::<[u8; 32]>(), bytes_lo in any::<[u8; 32]>()) {
            let m = Scalar::modulus_as_biguint();
            let mut bytes = [0u8; 64];
            bytes[0..32].clone_from_slice(&bytes_hi);
            bytes[32..64].clone_from_slice(&bytes_lo);
            let s = <Scalar as Reduce<U512>>::reduce(&U512::from_be_slice(&bytes));
            let s_bu = s.to_biguint().unwrap();
            assert!(s_bu < m);
        }
    }
}