use cfg_if::cfg_if;
cfg_if! {
if #[cfg(any(target_pointer_width = "32", feature = "force-32-bit"))] {
mod scalar_8x32;
use scalar_8x32::MODULUS;
use scalar_8x32::Scalar8x32 as ScalarImpl;
use scalar_8x32::WideScalar16x32 as WideScalarImpl;
} else if #[cfg(target_pointer_width = "64")] {
mod scalar_4x64;
use scalar_4x64::MODULUS;
use scalar_4x64::Scalar4x64 as ScalarImpl;
use scalar_4x64::WideScalar8x64 as WideScalarImpl;
}
}
use crate::{FieldBytes, Secp256k1};
use core::ops::{Add, AddAssign, Mul, MulAssign, Neg, Shr, Sub, SubAssign};
use elliptic_curve::{
ff::{Field, PrimeField},
generic_array::arr,
rand_core::{CryptoRng, RngCore},
subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption},
};
#[cfg(feature = "digest")]
use elliptic_curve::{consts::U32, Digest, FromDigest};
#[cfg(feature = "zeroize")]
use elliptic_curve::zeroize::Zeroize;
#[cfg(test)]
use num_bigint::{BigUint, ToBigUint};
/// A secp256k1 scalar that is guaranteed to be non-zero.
pub type NonZeroScalar = elliptic_curve::scalar::NonZeroScalar<Secp256k1>;
/// Bit representation of a secp256k1 scalar.
pub type ScalarBits = elliptic_curve::scalar::ScalarBits<Secp256k1>;
/// An element of the secp256k1 scalar field (integers modulo the curve
/// group order `n`). Thin newtype over the arithmetic backend selected at
/// compile time (8×u32 or 4×u64 limbs, see the `cfg_if!` above).
#[derive(Clone, Copy, Debug, Default)]
#[cfg_attr(docsrs, doc(cfg(feature = "arithmetic")))]
pub struct Scalar(ScalarImpl);
/// `ff::Field` implementation: thin delegation to the inherent methods /
/// backend arithmetic.
impl Field for Scalar {
    /// Returns a uniformly random scalar.
    fn random(rng: impl RngCore) -> Self {
        // Rejection sampling: variable-time, but the timing reveals nothing
        // about the value that is ultimately returned.
        Self::generate_vartime(rng)
    }

    /// The additive identity.
    fn zero() -> Self {
        Scalar::zero()
    }

    /// The multiplicative identity.
    fn one() -> Self {
        Scalar::one()
    }

    fn is_zero(&self) -> bool {
        self.0.is_zero().into()
    }

    #[must_use]
    fn square(&self) -> Self {
        Scalar::square(self)
    }

    /// Doubling via self-addition.
    #[must_use]
    fn double(&self) -> Self {
        self.add(self)
    }

    fn invert(&self) -> CtOption<Self> {
        Scalar::invert(self)
    }

    /// Not implemented yet — panics unconditionally if called.
    fn sqrt(&self) -> CtOption<Self> {
        todo!("see RustCrypto/elliptic-curves#170");
    }
}
impl PrimeField for Scalar {
    /// Canonical serialized form: 32 bytes.
    type Repr = FieldBytes;

    // The limb layout of `ReprBits` mirrors the arithmetic backend selected
    // at compile time (8×u32 or 4×u64 — 256 bits either way).
    cfg_if! {
        if #[cfg(any(target_pointer_width = "32", feature = "force-32-bit"))] {
            type ReprBits = [u32; 8];
        } else if #[cfg(target_pointer_width = "64")] {
            type ReprBits = [u64; 4];
        }
    }

    /// Scalars are 256-bit values.
    const NUM_BITS: u32 = 256;
    /// `NUM_BITS - 1`: bit width that can hold any arbitrary value.
    const CAPACITY: u32 = 255;
    /// Per the `PrimeField` contract, `S` is the 2-adicity of the field:
    /// `n - 1` is divisible by `2^6` but not `2^7`.
    const S: u32 = 6;

    /// Parses bytes as a canonical (i.e. `< n`) scalar; `None` for
    /// out-of-range encodings. (`CtOption` → `Option` via `.into()`.)
    fn from_repr(bytes: FieldBytes) -> Option<Self> {
        ScalarImpl::from_bytes(bytes.as_ref()).map(Self).into()
    }

    fn to_repr(&self) -> FieldBytes {
        self.to_bytes()
    }

    /// Little-endian bit representation (see `From<&Scalar> for ScalarBits`).
    fn to_le_bits(&self) -> ScalarBits {
        self.into()
    }

    fn is_odd(&self) -> bool {
        self.0.is_odd().into()
    }

    /// The field characteristic (group order `n`) as little-endian bits.
    fn char_le_bits() -> ScalarBits {
        MODULUS.into()
    }

    /// A fixed generator of the multiplicative group of the field.
    fn multiplicative_generator() -> Self {
        7u64.into()
    }

    /// Precomputed `2^S`-th primitive root of unity — per the `PrimeField`
    /// contract this equals `multiplicative_generator() ^ ((n - 1) / 2^S)`.
    fn root_of_unity() -> Self {
        Scalar::from_repr(arr![u8;
            0xc1, 0xdc, 0x06, 0x0e, 0x7a, 0x91, 0x98, 0x6d, 0xf9, 0x87, 0x9a, 0x3f, 0xbc, 0x48,
            0x3a, 0x89, 0x8b, 0xde, 0xab, 0x68, 0x07, 0x56, 0x04, 0x59, 0x92, 0xf4, 0xb5, 0x40,
            0x2b, 0x05, 0x2f, 0x2,
        ])
        .unwrap()
    }
}
impl From<u32> for Scalar {
fn from(k: u32) -> Self {
Self(ScalarImpl::from(k))
}
}
impl From<u64> for Scalar {
fn from(k: u64) -> Self {
Self(ScalarImpl::from(k))
}
}
impl Scalar {
    /// Returns the additive identity (zero).
    pub const fn zero() -> Self {
        Self(ScalarImpl::zero())
    }

    /// Returns the multiplicative identity (one).
    pub const fn one() -> Scalar {
        Self(ScalarImpl::one())
    }

    /// Constant-time check for zero.
    pub fn is_zero(&self) -> Choice {
        self.0.is_zero()
    }

    /// Truncates the scalar to a `u32` — presumably the lowest 32 bits of
    /// its integer value; exact semantics live in the backend (confirm there).
    pub fn truncate_to_u32(&self) -> u32 {
        self.0.truncate_to_u32()
    }

    /// Parses raw bytes **without** the range check (`value < n` is NOT
    /// verified). Crate-internal; intended for trusted precomputed constants.
    #[cfg(feature = "endomorphism-mul")]
    pub(crate) const fn from_bytes_unchecked(bytes: &[u8; 32]) -> Self {
        Self(ScalarImpl::from_bytes_unchecked(bytes))
    }

    /// Parses 32 bytes as a scalar, reducing modulo the group order `n`
    /// when the encoded value is out of range.
    pub fn from_bytes_reduced(bytes: &FieldBytes) -> Self {
        Self(ScalarImpl::from_bytes_reduced(bytes.as_ref()))
    }

    /// Serializes the scalar to its canonical 32-byte representation.
    pub fn to_bytes(&self) -> FieldBytes {
        self.0.to_bytes()
    }

    /// Whether the scalar lies in the high half of the range; the `is_high`
    /// unit test pins the boundary (values >= floor(n / 2) report high).
    pub fn is_high(&self) -> Choice {
        self.0.is_high()
    }

    /// Additive inverse modulo `n` (zero maps to zero).
    pub fn negate(&self) -> Self {
        Self(self.0.negate())
    }

    /// Modular addition: `(self + rhs) mod n`.
    pub fn add(&self, rhs: &Scalar) -> Scalar {
        Self(self.0.add(&(rhs.0)))
    }

    /// Modular subtraction: `(self - rhs) mod n`.
    pub fn sub(&self, rhs: &Scalar) -> Scalar {
        Self(self.0.sub(&(rhs.0)))
    }

    /// Modular multiplication: `(self * rhs) mod n`.
    pub fn mul(&self, rhs: &Scalar) -> Scalar {
        Self(self.0.mul(&(rhs.0)))
    }

    /// Modular squaring: `self^2 mod n`.
    pub fn square(&self) -> Self {
        self.mul(&self)
    }

    /// Right-shifts the integer value by `shift` bits.
    pub fn rshift(&self, shift: usize) -> Scalar {
        Self(self.0.rshift(shift))
    }

    /// Raises the scalar to the power `2^k` by squaring `k` times.
    fn pow2k(&self, k: usize) -> Self {
        let mut x = *self;
        for _j in 0..k {
            x = x.square();
        }
        x
    }

    /// Multiplicative inverse via a fixed addition-chain exponentiation
    /// (Fermat's little theorem: `a^(n-2) ≡ a^-1 (mod n)` for nonzero `a`).
    /// The chain is fixed, so the run time does not depend on the value.
    /// Returns `CtOption` with `is_some == false` for zero input.
    pub fn invert(&self) -> CtOption<Self> {
        // Small odd powers of self; names encode the binary exponent,
        // e.g. `x_1011` = self^0b1011 = self^11.
        let x_1 = *self;
        let x_10 = self.pow2k(1);
        let x_11 = x_10.mul(&x_1);
        let x_101 = x_10.mul(&x_11);
        let x_111 = x_10.mul(&x_101);
        let x_1001 = x_10.mul(&x_111);
        let x_1011 = x_10.mul(&x_1001);
        let x_1101 = x_10.mul(&x_1011);
        // Larger all-ones windows: x6 = self^(2^6 - 1), x8 = self^(2^8 - 1),
        // then doubled-width blocks x14, x28, x56 built the same way.
        let x6 = x_1101.pow2k(2).mul(&x_1011);
        let x8 = x6.pow2k(2).mul(&x_11);
        let x14 = x8.pow2k(6).mul(&x6);
        let x28 = x14.pow2k(14).mul(&x14);
        let x56 = x28.pow2k(28).mul(&x28);
        // Square-and-multiply over the precomputed windows to reach the
        // exponent n - 2. Do not reorder: each step depends on the previous.
        #[rustfmt::skip]
        let res = x56
            .pow2k(56).mul(&x56)
            .pow2k(14).mul(&x14)
            .pow2k(3).mul(&x_101)
            .pow2k(4).mul(&x_111)
            .pow2k(4).mul(&x_101)
            .pow2k(5).mul(&x_1011)
            .pow2k(4).mul(&x_1011)
            .pow2k(4).mul(&x_111)
            .pow2k(5).mul(&x_111)
            .pow2k(6).mul(&x_1101)
            .pow2k(4).mul(&x_101)
            .pow2k(3).mul(&x_111)
            .pow2k(5).mul(&x_1001)
            .pow2k(6).mul(&x_101)
            .pow2k(10).mul(&x_111)
            .pow2k(4).mul(&x_111)
            .pow2k(9).mul(&x8)
            .pow2k(5).mul(&x_1001)
            .pow2k(6).mul(&x_1011)
            .pow2k(4).mul(&x_1101)
            .pow2k(5).mul(&x_11)
            .pow2k(6).mul(&x_1101)
            .pow2k(10).mul(&x_1101)
            .pow2k(4).mul(&x_1001)
            .pow2k(6).mul(&x_1)
            .pow2k(8).mul(&x6);
        CtOption::new(res, !self.is_zero())
    }

    /// Test helper: the group order `n` as a `BigUint`, computed as
    /// `(-1 mod n) + 1`.
    #[cfg(test)]
    pub fn modulus_as_biguint() -> BigUint {
        Self::one().negate().to_biguint().unwrap() + 1.to_biguint().unwrap()
    }

    /// Random scalar via reduction of 512 random bits modulo `n`; the
    /// statistical bias from the reduction is negligible.
    pub fn generate_biased(mut rng: impl CryptoRng + RngCore) -> Self {
        let mut buf = [0u8; 64];
        rng.fill_bytes(&mut buf);
        Scalar(WideScalarImpl::from_bytes(&buf).reduce())
    }

    /// Uniformly random scalar via rejection sampling. Variable-time, but
    /// the timing is independent of the value that is finally returned.
    pub fn generate_vartime(mut rng: impl RngCore) -> Self {
        let mut bytes = FieldBytes::default();
        loop {
            rng.fill_bytes(&mut bytes);
            if let Some(scalar) = Scalar::from_repr(bytes) {
                return scalar;
            }
        }
    }

    /// Conditionally (on `flag`) adds `2^bit` to the scalar.
    /// NOTE(review): overflow/reduction semantics are defined by the
    /// backend — confirm there before relying on edge cases.
    pub fn conditional_add_bit(&self, bit: usize, flag: Choice) -> Self {
        Self(self.0.conditional_add_bit(bit, flag))
    }

    /// Multiplies by `b` and shifts the product right by `shift` bits, in
    /// variable time (per the `_var` suffix convention).
    pub fn mul_shift_var(&self, b: &Scalar, shift: usize) -> Self {
        Self(self.0.mul_shift_var(&(b.0), shift))
    }
}
#[cfg(feature = "digest")]
#[cfg_attr(docsrs, doc(cfg(feature = "digest")))]
impl FromDigest<Secp256k1> for Scalar {
    /// Interprets a 32-byte digest output as a scalar, reducing it modulo
    /// the group order.
    fn from_digest<D>(digest: D) -> Self
    where
        D: Digest<OutputSize = U32>,
    {
        let output = digest.finalize();
        Self::from_bytes_reduced(&output)
    }
}
impl Shr<usize> for Scalar {
type Output = Self;
fn shr(self, rhs: usize) -> Self::Output {
self.rshift(rhs)
}
}
impl Shr<usize> for &Scalar {
type Output = Scalar;
fn shr(self, rhs: usize) -> Self::Output {
self.rshift(rhs)
}
}
impl ConditionallySelectable for Scalar {
    /// Constant-time selection: yields `a` when `choice` is unset, `b` when set.
    fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
        let inner = ScalarImpl::conditional_select(&a.0, &b.0, choice);
        Self(inner)
    }
}

impl ConstantTimeEq for Scalar {
    /// Constant-time equality comparison.
    fn ct_eq(&self, other: &Self) -> Choice {
        ScalarImpl::ct_eq(&self.0, &other.0)
    }
}
impl PartialEq for Scalar {
    /// Equality routed through the constant-time comparison.
    fn eq(&self, other: &Self) -> bool {
        bool::from(self.ct_eq(other))
    }
}

impl Eq for Scalar {}
impl Neg for Scalar {
type Output = Scalar;
fn neg(self) -> Scalar {
self.negate()
}
}
impl Neg for &Scalar {
type Output = Scalar;
fn neg(self) -> Scalar {
self.negate()
}
}
/// Binary `+`: modular addition, provided for all four owned/borrowed
/// operand combinations.
impl Add<Scalar> for Scalar {
    type Output = Scalar;

    fn add(self, other: Scalar) -> Scalar {
        Scalar::add(&self, &other)
    }
}

impl Add<&Scalar> for &Scalar {
    type Output = Scalar;

    fn add(self, other: &Scalar) -> Scalar {
        Scalar::add(self, other)
    }
}

impl Add<Scalar> for &Scalar {
    type Output = Scalar;

    fn add(self, other: Scalar) -> Scalar {
        Scalar::add(self, &other)
    }
}

impl Add<&Scalar> for Scalar {
    type Output = Scalar;

    fn add(self, other: &Scalar) -> Scalar {
        Scalar::add(&self, other)
    }
}

/// In-place `+=`.
impl AddAssign<Scalar> for Scalar {
    fn add_assign(&mut self, rhs: Scalar) {
        *self = Scalar::add(self, &rhs);
    }
}

impl AddAssign<&Scalar> for Scalar {
    fn add_assign(&mut self, rhs: &Scalar) {
        // `rhs` is already a reference — the previous `&rhs` only compiled
        // via deref coercion and tripped clippy's `needless_borrow`.
        *self = Scalar::add(self, rhs);
    }
}
/// Binary `-`: modular subtraction, provided for all four owned/borrowed
/// operand combinations (previously `&Scalar - Scalar` was missing,
/// inconsistently with the `Add` impls above).
impl Sub<Scalar> for Scalar {
    type Output = Scalar;

    fn sub(self, other: Scalar) -> Scalar {
        Scalar::sub(&self, &other)
    }
}

impl Sub<&Scalar> for &Scalar {
    type Output = Scalar;

    fn sub(self, other: &Scalar) -> Scalar {
        Scalar::sub(self, other)
    }
}

impl Sub<Scalar> for &Scalar {
    type Output = Scalar;

    fn sub(self, other: Scalar) -> Scalar {
        Scalar::sub(self, &other)
    }
}

impl Sub<&Scalar> for Scalar {
    type Output = Scalar;

    fn sub(self, other: &Scalar) -> Scalar {
        Scalar::sub(&self, other)
    }
}

/// In-place `-=`.
impl SubAssign<Scalar> for Scalar {
    fn sub_assign(&mut self, rhs: Scalar) {
        *self = Scalar::sub(self, &rhs);
    }
}

impl SubAssign<&Scalar> for Scalar {
    fn sub_assign(&mut self, rhs: &Scalar) {
        *self = Scalar::sub(self, rhs);
    }
}
/// Binary `*`: modular multiplication, provided for all four owned/borrowed
/// operand combinations (previously `&Scalar * Scalar` was missing,
/// inconsistently with the `Add` impls above).
impl Mul<Scalar> for Scalar {
    type Output = Scalar;

    fn mul(self, other: Scalar) -> Scalar {
        Scalar::mul(&self, &other)
    }
}

impl Mul<&Scalar> for &Scalar {
    type Output = Scalar;

    fn mul(self, other: &Scalar) -> Scalar {
        Scalar::mul(self, other)
    }
}

impl Mul<Scalar> for &Scalar {
    type Output = Scalar;

    fn mul(self, other: Scalar) -> Scalar {
        Scalar::mul(self, &other)
    }
}

impl Mul<&Scalar> for Scalar {
    type Output = Scalar;

    fn mul(self, other: &Scalar) -> Scalar {
        Scalar::mul(&self, other)
    }
}

/// In-place `*=`.
impl MulAssign<Scalar> for Scalar {
    fn mul_assign(&mut self, rhs: Scalar) {
        *self = Scalar::mul(self, &rhs);
    }
}

impl MulAssign<&Scalar> for Scalar {
    fn mul_assign(&mut self, rhs: &Scalar) {
        *self = Scalar::mul(self, rhs);
    }
}
impl From<&Scalar> for ScalarBits {
fn from(scalar: &Scalar) -> ScalarBits {
scalar.0.into()
}
}
impl From<Scalar> for FieldBytes {
fn from(scalar: Scalar) -> Self {
scalar.to_bytes()
}
}
impl From<&Scalar> for FieldBytes {
fn from(scalar: &Scalar) -> Self {
scalar.to_bytes()
}
}
#[cfg(feature = "zeroize")]
impl Zeroize for Scalar {
    /// Clears the scalar's limbs from memory (delegated to the backend so
    /// the wipe is not optimized away).
    fn zeroize(&mut self) {
        self.0.zeroize()
    }
}
#[cfg(test)]
mod tests {
    //! Unit and property-based (proptest) tests, cross-checked against
    //! `num_bigint` reference arithmetic.
    use super::Scalar;
    use crate::arithmetic::dev::{biguint_to_bytes, bytes_to_biguint};
    use elliptic_curve::ff::PrimeField;
    use num_bigint::{BigUint, ToBigUint};
    use proptest::prelude::*;

    // Test-only conversion; panics (debug_assert / unwrap) if `x` is not
    // strictly below the group order.
    impl From<&BigUint> for Scalar {
        fn from(x: &BigUint) -> Self {
            debug_assert!(x < &Scalar::modulus_as_biguint());
            let bytes = biguint_to_bytes(x);
            Self::from_repr(bytes.into()).unwrap()
        }
    }

    impl From<BigUint> for Scalar {
        fn from(x: BigUint) -> Self {
            Self::from(&x)
        }
    }

    // Test-only conversion back to a reference big integer.
    impl ToBigUint for Scalar {
        fn to_biguint(&self) -> Option<BigUint> {
            Some(bytes_to_biguint(self.to_bytes().as_ref()))
        }
    }

    // Pins the `is_high` boundary: values >= floor(n / 2) report high.
    #[test]
    fn is_high() {
        let high: bool = Scalar::zero().is_high().into();
        assert!(!high);
        let m = Scalar::modulus_as_biguint();
        let m_by_2 = &m >> 1;
        let one = 1.to_biguint().unwrap();
        let high: bool = Scalar::from(&m_by_2 - &one).is_high().into();
        assert!(!high);
        let high: bool = Scalar::from(&m_by_2).is_high().into();
        assert!(high);
        let high: bool = Scalar::from(&m - &one).is_high().into();
        assert!(high);
    }

    // Negation at the edges: 0, 1, n/2, n-1.
    #[test]
    fn negate() {
        let zero_neg = -Scalar::zero();
        assert_eq!(zero_neg, Scalar::zero());
        let m = Scalar::modulus_as_biguint();
        let one = 1.to_biguint().unwrap();
        let m_minus_one = &m - &one;
        let m_by_2 = &m >> 1;
        let one_neg = -Scalar::one();
        assert_eq!(one_neg, Scalar::from(&m_minus_one));
        let frac_modulus_2_neg = -Scalar::from(&m_by_2);
        let frac_modulus_2_plus_one = Scalar::from(&m_by_2 + &one);
        assert_eq!(frac_modulus_2_neg, frac_modulus_2_plus_one);
        let modulus_minus_one_neg = -Scalar::from(&m - &one);
        assert_eq!(modulus_minus_one_neg, Scalar::one());
    }

    // Checks that adding two ~2^255 operands reduces correctly instead of
    // overflowing 256 bits.
    #[test]
    fn add_result_within_256_bits() {
        let t = 1.to_biguint().unwrap() << 255;
        let one = 1.to_biguint().unwrap();
        let a = Scalar::from(&t - &one);
        let b = Scalar::from(&t);
        let res = &a + &b;
        let m = Scalar::modulus_as_biguint();
        let res_ref = Scalar::from((&t + &t - &one) % &m);
        assert_eq!(res, res_ref);
    }

    // Smoke test: generated scalar behaves like a field element (a - a = 0).
    #[test]
    fn generate_biased() {
        use elliptic_curve::rand_core::OsRng;
        let a = Scalar::generate_biased(&mut OsRng);
        assert_eq!((a - &a).is_zero().unwrap_u8(), 1);
    }

    #[test]
    fn generate_vartime() {
        use elliptic_curve::rand_core::OsRng;
        let a = Scalar::generate_vartime(&mut OsRng);
        assert_eq!((a - &a).is_zero().unwrap_u8(), 1);
    }

    // proptest strategy: arbitrary 32 bytes, reduced below the modulus.
    // (A single subtraction suffices since the value is < 2^256 < 2n.)
    prop_compose! {
        fn scalar()(bytes in any::<[u8; 32]>()) -> Scalar {
            let mut res = bytes_to_biguint(&bytes);
            let m = Scalar::modulus_as_biguint();
            if res >= m {
                res -= m;
            }
            Scalar::from(&res)
        }
    }

    proptest! {
        // to_bytes / from_repr round-trip.
        #[test]
        fn fuzzy_roundtrip_to_bytes(a in scalar()) {
            let a_back = Scalar::from_repr(a.to_bytes()).unwrap();
            assert_eq!(a, a_back);
        }

        // Same round-trip through the unchecked parser (inputs are canonical).
        #[test]
        #[cfg(feature = "endomorphism-mul")]
        fn fuzzy_roundtrip_to_bytes_unchecked(a in scalar()) {
            let bytes = a.to_bytes();
            let a_back = Scalar::from_bytes_unchecked(bytes.as_ref());
            assert_eq!(a, a_back);
        }

        // Each arithmetic op is compared against BigUint reference math.
        #[test]
        fn fuzzy_add(a in scalar(), b in scalar()) {
            let a_bi = a.to_biguint().unwrap();
            let b_bi = b.to_biguint().unwrap();
            let res_bi = (&a_bi + &b_bi) % &Scalar::modulus_as_biguint();
            let res_ref = Scalar::from(&res_bi);
            let res_test = a.add(&b);
            assert_eq!(res_ref, res_test);
        }

        #[test]
        fn fuzzy_sub(a in scalar(), b in scalar()) {
            let a_bi = a.to_biguint().unwrap();
            let b_bi = b.to_biguint().unwrap();
            let m = Scalar::modulus_as_biguint();
            // `+ m` keeps the intermediate non-negative (BigUint is unsigned).
            let res_bi = (&m + &a_bi - &b_bi) % &m;
            let res_ref = Scalar::from(&res_bi);
            let res_test = a.sub(&b);
            assert_eq!(res_ref, res_test);
        }

        #[test]
        fn fuzzy_neg(a in scalar()) {
            let a_bi = a.to_biguint().unwrap();
            let m = Scalar::modulus_as_biguint();
            let res_bi = (&m - &a_bi) % &m;
            let res_ref = Scalar::from(&res_bi);
            let res_test = -a;
            assert_eq!(res_ref, res_test);
        }

        #[test]
        fn fuzzy_mul(a in scalar(), b in scalar()) {
            let a_bi = a.to_biguint().unwrap();
            let b_bi = b.to_biguint().unwrap();
            let res_bi = (&a_bi * &b_bi) % &Scalar::modulus_as_biguint();
            let res_ref = Scalar::from(&res_bi);
            let res_test = a.mul(&b);
            assert_eq!(res_ref, res_test);
        }

        // Shifts beyond 256 bits must yield zero, hence the 0..512 range.
        #[test]
        fn fuzzy_rshift(a in scalar(), b in 0usize..512) {
            let a_bi = a.to_biguint().unwrap();
            let res_bi = &a_bi >> b;
            let res_ref = Scalar::from(&res_bi);
            let res_test = a >> b;
            assert_eq!(res_ref, res_test);
        }

        // Checks a * a^-1 == 1 (zero input is remapped to one first).
        #[test]
        fn fuzzy_invert(
            a in scalar()
        ) {
            let a = if bool::from(a.is_zero()) { Scalar::one() } else { a };
            let a_bi = a.to_biguint().unwrap();
            let inv = a.invert().unwrap();
            let inv_bi = inv.to_biguint().unwrap();
            let m = Scalar::modulus_as_biguint();
            assert_eq!((&inv_bi * &a_bi) % &m, 1.to_biguint().unwrap());
        }
    }
}