use crate::util::{adc, mac, sbb};
use core::convert::TryFrom;
use core::fmt;
use core::ops::{Add, AddAssign, BitAnd, BitXor, Mul, MulAssign, Neg, Sub, SubAssign};
use rand::{CryptoRng, Rng};
use serde::{de::Visitor, Deserialize, Deserializer, Serialize, Serializer};
use std::borrow::Borrow;
use std::cmp::{Ord, Ordering, PartialOrd};
use std::iter::{Product, Sum};
use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption};
/// An element of the BLS12-381 scalar field.
///
/// The four little-endian 64-bit limbs hold the element in Montgomery form,
/// i.e. `a * R mod q` with `R = 2^256`; use `to_bytes`/`reduce` to obtain the
/// canonical integer.
#[derive(Copy, Clone, Eq)]
pub struct Scalar(pub [u64; 4]);
impl fmt::Debug for Scalar {
    /// Writes the canonical value as `0x` followed by 64 big-endian hex digits.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str("0x")?;
        // `to_bytes` is little-endian, so walk it backwards for big-endian hex.
        self.to_bytes()
            .iter()
            .rev()
            .try_for_each(|byte| write!(f, "{:02x}", byte))
    }
}
impl From<u64> for Scalar {
fn from(val: u64) -> Scalar {
Scalar([val, 0, 0, 0]) * R2
}
}
impl ConstantTimeEq for Scalar {
fn ct_eq(&self, other: &Self) -> Choice {
self.0[0].ct_eq(&other.0[0])
& self.0[1].ct_eq(&other.0[1])
& self.0[2].ct_eq(&other.0[2])
& self.0[3].ct_eq(&other.0[3])
}
}
impl PartialEq for Scalar {
#[inline]
fn eq(&self, other: &Self) -> bool {
self.ct_eq(other).unwrap_u8() == 1
}
}
impl PartialOrd for Scalar {
    /// Delegates to [`Ord::cmp`], so every pair of scalars is comparable.
    fn partial_cmp(&self, other: &Scalar) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
impl Ord for Scalar {
    /// Compares the stored limbs as one 256-bit integer, most significant
    /// limb first.
    ///
    /// NOTE(review): the stored limbs are in Montgomery form, so this is a
    /// total order that does not follow the canonical integer values; compare
    /// `reduce()`d values for that. It also returns on the first differing
    /// limb, so it is not constant time.
    fn cmp(&self, other: &Self) -> Ordering {
        self.0
            .iter()
            .rev()
            .zip(other.0.iter().rev())
            .map(|(lhs, rhs)| lhs.cmp(rhs))
            .find(|ord| *ord != Ordering::Equal)
            .unwrap_or(Ordering::Equal)
    }
}
impl ConditionallySelectable for Scalar {
    /// Selects `a` when `choice` is 0 and `b` when it is 1, limb by limb in
    /// constant time.
    fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
        let mut limbs = [0u64; 4];
        for (out, (x, y)) in limbs.iter_mut().zip(a.0.iter().zip(b.0.iter())) {
            *out = u64::conditional_select(x, y, choice);
        }
        Scalar(limbs)
    }
}
impl Serialize for Scalar {
    /// Serializes the canonical little-endian encoding from `to_bytes` as a
    /// fixed 32-element tuple of bytes.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        use serde::ser::SerializeTuple;
        let encoded = self.to_bytes();
        let mut tuple = serializer.serialize_tuple(encoded.len())?;
        encoded
            .iter()
            .try_for_each(|byte| tuple.serialize_element(byte))?;
        tuple.end()
    }
}
impl<'de> Deserialize<'de> for Scalar {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
struct ScalarVisitor;
impl<'de> Visitor<'de> for ScalarVisitor {
type Value = Scalar;
fn expecting(&self, formatter: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
formatter.write_str("a 32-byte cannonical Scalar from Bls12_381")
}
fn visit_seq<A>(self, mut seq: A) -> Result<Scalar, A::Error>
where
A: serde::de::SeqAccess<'de>,
{
let mut bytes = [0u8; 32];
for i in 0..32 {
bytes[i] = seq
.next_element()?
.ok_or(serde::de::Error::invalid_length(i, &"expected 32 bytes"))?;
}
let res = Scalar::from_bytes(&bytes);
if res.is_some().unwrap_u8() == 1u8 {
return Ok(res.unwrap());
} else {
return Err(serde::de::Error::custom(
&"scalar was not canonically encoded",
));
}
}
}
deserializer.deserialize_tuple(32, ScalarVisitor)
}
}
// Coordinates named GEN_X / GEN_Y, given as raw stored limbs.
// NOTE(review): what point these belong to is not visible from this file
// (presumably a generator of a curve defined over this scalar field) —
// confirm with the users of these constants. As with every `Scalar(...)`
// literal, the limbs are stored as-is, i.e. interpreted in Montgomery form.
#[allow(dead_code)]
pub const GEN_X: Scalar = Scalar([
0x1539098E9CBCC1D5,
0x0CCC77B0E1804E8D,
0x6EEF947A6FD0FB2C,
0xA3D063F54E10DDE9,
]);
#[allow(dead_code)]
pub const GEN_Y: Scalar = Scalar([
0x6540D21E7007DC60,
0x3B0D848E832A862F,
0xB53BB87E05DA8257,
// NOTE(review): this limb has only 15 hex digits (= 0x0CD482CC3FD6FF4D);
// verify it was not truncated when transcribed.
0xCD482CC3FD6FF4D,
]);
/// The scalar field modulus `q` as four little-endian 64-bit limbs. Unlike
/// other `Scalar` constants this is the plain integer, not a Montgomery form:
/// `from_bytes` compares raw input limbs against it.
pub const MODULUS: Scalar = Scalar([
0xffffffff00000001,
0x53bda402fffe5bfe,
0x3339d80809a1d805,
0x73eda753299d7d48,
]);
impl<'a> Neg for &'a Scalar {
type Output = Scalar;
#[inline]
fn neg(self) -> Scalar {
self.neg()
}
}
impl Neg for Scalar {
type Output = Scalar;
#[inline]
fn neg(self) -> Scalar {
-&self
}
}
impl<'a, 'b> Sub<&'b Scalar> for &'a Scalar {
type Output = Scalar;
#[inline]
fn sub(self, rhs: &'b Scalar) -> Scalar {
self.sub(rhs)
}
}
impl<'a, 'b> Add<&'b Scalar> for &'a Scalar {
type Output = Scalar;
#[inline]
fn add(self, rhs: &'b Scalar) -> Scalar {
self.add(rhs)
}
}
impl<'a, 'b> Mul<&'b Scalar> for &'a Scalar {
type Output = Scalar;
#[inline]
fn mul(self, rhs: &'b Scalar) -> Scalar {
self.mul(rhs)
}
}
impl<'a, 'b> BitXor<&'b Scalar> for &'a Scalar {
    type Output = Scalar;
    /// XORs the canonical integer values of both operands and maps the result
    /// back into the field; `from_raw` reduces it mod `q` if it lands above.
    fn bitxor(self, rhs: &'b Scalar) -> Scalar {
        let (lhs_limbs, rhs_limbs) = (self.reduce().0, rhs.reduce().0);
        let mut mixed = [0u64; 4];
        for ((out, a), b) in mixed.iter_mut().zip(lhs_limbs.iter()).zip(rhs_limbs.iter()) {
            *out = a ^ b;
        }
        Scalar::from_raw(mixed)
    }
}
impl BitXor<Scalar> for Scalar {
    type Output = Scalar;
    /// By-value XOR; forwards to the reference implementation.
    fn bitxor(self, rhs: Scalar) -> Scalar {
        BitXor::bitxor(&self, &rhs)
    }
}
impl<'a, 'b> BitAnd<&'b Scalar> for &'a Scalar {
    type Output = Scalar;
    /// ANDs the canonical integer values of both operands and converts the
    /// result back into Montgomery form.
    fn bitand(self, rhs: &'b Scalar) -> Scalar {
        let (lhs_limbs, rhs_limbs) = (self.reduce().0, rhs.reduce().0);
        let mut masked = [0u64; 4];
        for ((out, a), b) in masked.iter_mut().zip(lhs_limbs.iter()).zip(rhs_limbs.iter()) {
            *out = a & b;
        }
        Scalar::from_raw(masked)
    }
}
impl BitAnd<Scalar> for Scalar {
    type Output = Scalar;
    /// By-value AND; forwards to the reference implementation.
    fn bitand(self, rhs: Scalar) -> Scalar {
        BitAnd::bitand(&self, &rhs)
    }
}
// Generate the remaining operator impls for `Scalar` from the reference impls
// above. NOTE(review): the macros are defined outside this file; judging by
// the `+=`, `-=`, `*=` and by-value `+`/`-`/`*` uses in this file they provide
// the owned/mixed-reference and `*Assign` variants — confirm against the macros.
impl_binops_additive!(Scalar, Scalar);
impl_binops_multiplicative!(Scalar, Scalar);
/// `INV = -q^{-1} mod 2^64`, the per-limb factor used by `montgomery_reduce`
/// (recomputed and checked by `test_inv`).
const INV: u64 = 0xfffffffeffffffff;
/// `R = 2^256 mod q`, the Montgomery form of one (returned by `Scalar::one`).
const R: Scalar = Scalar([
0x00000001fffffffe,
0x5884b7fa00034802,
0x998c4fefecbc4ff5,
0x1824b159acc5056f,
]);
/// `R2 = R^2 mod q`; a Montgomery multiplication by it converts a canonical
/// integer into Montgomery form (see `from_raw` / `from_bytes`).
const R2: Scalar = Scalar([
0xc999e990f3f29c6d,
0x2b6cedcb87925c23,
0x05d314967254398f,
0x0748d9d99f59ff11,
]);
/// `R3 = R^3 mod q`; `from_u512` multiplies the high 256 bits by it so they are
/// weighted by an extra factor of `2^256`.
const R3: Scalar = Scalar([
0xc62c1807439b73af,
0x1b3e0d188cf06990,
0x73d13c71c7b5f418,
0x6e2a5bb9c8db33e9,
]);
/// Largest `s` with `2^s | q - 1`: the low limb of `q - 1` is
/// `0xffffffff00000000`. The name keeps its historical misspelling of
/// "adicity" because it is public.
pub const TWO_ADACITY: u32 = 32;
/// Starting value `z` for the Tonelli–Shanks loop in `sqrt`, in Montgomery form.
/// Presumably a primitive `2^TWO_ADACITY`-th root of unity — TODO confirm; the
/// value is not re-derived in this file.
pub const ROOT_OF_UNITY: Scalar = Scalar([
0xb9b58d8c5f0e466a,
0x5b1b4c801819d7ec,
0x0af53ae352a31e64,
0x5bf3adda19e9b27b,
]);
pub const GENERATOR: Scalar = Scalar([7, 0, 0, 0]);
impl<T> Product<T> for Scalar
where
    T: Borrow<Scalar>,
{
    /// Multiplies all items together; an empty iterator yields one.
    fn product<I>(iter: I) -> Self
    where
        I: Iterator<Item = T>,
    {
        let mut acc = Scalar::one();
        for item in iter {
            let factor: &Scalar = item.borrow();
            acc *= factor;
        }
        acc
    }
}
impl<T> Sum<T> for Scalar
where
    T: Borrow<Scalar>,
{
    /// Adds all items together; an empty iterator yields zero.
    fn sum<I>(iter: I) -> Self
    where
        I: Iterator<Item = T>,
    {
        let mut acc = Scalar::zero();
        for item in iter {
            let term: &Scalar = item.borrow();
            acc += term;
        }
        acc
    }
}
impl Default for Scalar {
    /// The default scalar is zero.
    #[inline]
    fn default() -> Self {
        Scalar::zero()
    }
}
impl Scalar {
    /// Returns the additive identity, zero.
    #[inline]
    pub const fn zero() -> Scalar {
        Scalar([0, 0, 0, 0])
    }

    /// Returns the multiplicative identity, one, whose Montgomery form is `R`.
    #[inline]
    pub const fn one() -> Scalar {
        R
    }

    /// Returns `Choice(1)` if `self` is zero, in constant time.
    pub fn is_zero(&self) -> Choice {
        self.ct_eq(&Scalar::zero())
    }

    /// Returns `Choice(1)` if `self` is one, in constant time.
    pub fn is_one(&self) -> Choice {
        self.ct_eq(&Scalar::one())
    }

    /// Borrows the stored little-endian limbs, which are in Montgomery form.
    pub const fn internal_repr(&self) -> &[u64; 4] {
        &self.0
    }

    /// Returns `2 * self mod q`.
    #[inline]
    pub const fn double(&self) -> Scalar {
        self.add(self)
    }

    /// Decodes the canonical 32-byte little-endian encoding produced by
    /// [`Scalar::to_bytes`].
    ///
    /// The result is none (decided in constant time) unless the encoded
    /// integer is strictly less than the modulus `q`.
    pub fn from_bytes(bytes: &[u8; 32]) -> CtOption<Scalar> {
        // The `try_from(..).unwrap()` calls cannot fail: every slice is 8 bytes.
        let mut tmp = Scalar([0, 0, 0, 0]);
        tmp.0[0] = u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[0..8]).unwrap());
        tmp.0[1] = u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[8..16]).unwrap());
        tmp.0[2] = u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[16..24]).unwrap());
        tmp.0[3] = u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[24..32]).unwrap());

        // Compute tmp - q only for its final borrow, which is set exactly when
        // tmp < q, i.e. when the encoding is canonical.
        let (_, borrow) = sbb(tmp.0[0], MODULUS.0[0], 0);
        let (_, borrow) = sbb(tmp.0[1], MODULUS.0[1], borrow);
        let (_, borrow) = sbb(tmp.0[2], MODULUS.0[2], borrow);
        let (_, borrow) = sbb(tmp.0[3], MODULUS.0[3], borrow);
        let is_some = (borrow as u8) & 1;

        // Convert the canonical integer into Montgomery form.
        tmp *= &R2;

        CtOption::new(tmp, Choice::from(is_some))
    }

    /// Encodes the canonical value (not the Montgomery limbs) as 32
    /// little-endian bytes.
    pub fn to_bytes(&self) -> [u8; 32] {
        // One Montgomery reduction of (limbs, 0) strips the factor R.
        let tmp = Scalar::montgomery_reduce(self.0[0], self.0[1], self.0[2], self.0[3], 0, 0, 0, 0);
        let mut res = [0; 32];
        res[0..8].copy_from_slice(&tmp.0[0].to_le_bytes());
        res[8..16].copy_from_slice(&tmp.0[1].to_le_bytes());
        res[16..24].copy_from_slice(&tmp.0[2].to_le_bytes());
        res[24..32].copy_from_slice(&tmp.0[3].to_le_bytes());
        res
    }

    /// Returns the 256 bits of the canonical value, least significant first,
    /// with each bit stored as a `0` or `1` byte.
    pub fn to_bits(&self) -> [u8; 256] {
        let mut res = [0u8; 256];
        let bytes = self.to_bytes();
        for (byte, bits) in bytes.iter().zip(res.chunks_mut(8)) {
            bits.iter_mut()
                .enumerate()
                .for_each(|(i, bit)| *bit = (byte >> i) & 1)
        }
        res
    }

    /// Reduces a 64-byte little-endian integer modulo `q`.
    pub fn from_bytes_wide(bytes: &[u8; 64]) -> Scalar {
        Scalar::from_u512([
            u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[0..8]).unwrap()),
            u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[8..16]).unwrap()),
            u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[16..24]).unwrap()),
            u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[24..32]).unwrap()),
            u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[32..40]).unwrap()),
            u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[40..48]).unwrap()),
            u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[48..56]).unwrap()),
            u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[56..64]).unwrap()),
        ])
    }

    /// Maps the 512-bit integer `d0 + d1 * 2^256` (little-endian limbs) into the
    /// field. Montgomery-multiplying `d0` by `R2` yields `d0 * R` and `d1` by
    /// `R3` yields `d1 * 2^256 * R`, so their sum is the Montgomery form.
    fn from_u512(limbs: [u64; 8]) -> Scalar {
        let d0 = Scalar([limbs[0], limbs[1], limbs[2], limbs[3]]);
        let d1 = Scalar([limbs[4], limbs[5], limbs[6], limbs[7]]);
        d0 * R2 + d1 * R3
    }

    /// Converts canonical little-endian limbs into Montgomery form; inputs
    /// greater than or equal to `q` are reduced modulo `q`.
    pub const fn from_raw(val: [u64; 4]) -> Self {
        (&Scalar(val)).mul(&R2)
    }

    /// Samples a uniformly random scalar from a cryptographically secure RNG.
    ///
    /// 64 random bytes are reduced modulo `q` with [`Scalar::from_bytes_wide`],
    /// whose distance from uniform is negligible. Sampling only 254 bits (and
    /// rejecting nothing) would never produce a value in `[2^254, q)`, which is
    /// roughly 45% of the field — a bias that is unacceptable for nonces and
    /// blinding factors.
    pub fn random<T>(rand: &mut T) -> Scalar
    where
        T: Rng + CryptoRng,
    {
        let mut bytes = [0u8; 64];
        rand.fill_bytes(&mut bytes);
        Scalar::from_bytes_wide(&bytes)
    }

    /// Returns a `Scalar` whose limbs hold the canonical integer value
    /// (one Montgomery reduction).
    ///
    /// The result is a plain 256-bit integer container, not a Montgomery-form
    /// element; pass its limbs to [`Scalar::from_raw`] to get a field element back.
    pub fn reduce(&self) -> Scalar {
        Scalar::montgomery_reduce(self.0[0], self.0[1], self.0[2], self.0[3], 0, 0, 0, 0)
    }

    /// Returns `self^2`.
    ///
    /// The off-diagonal limb products are computed once and doubled with a
    /// one-bit shift, the diagonal squares are added in, and the 512-bit result
    /// is Montgomery-reduced.
    #[inline]
    pub const fn square(&self) -> Scalar {
        let (r1, carry) = mac(0, self.0[0], self.0[1], 0);
        let (r2, carry) = mac(0, self.0[0], self.0[2], carry);
        let (r3, r4) = mac(0, self.0[0], self.0[3], carry);
        let (r3, carry) = mac(r3, self.0[1], self.0[2], 0);
        let (r4, r5) = mac(r4, self.0[1], self.0[3], carry);
        let (r5, r6) = mac(r5, self.0[2], self.0[3], 0);
        // Double the cross terms (shift the 448-bit value left by one).
        let r7 = r6 >> 63;
        let r6 = (r6 << 1) | (r5 >> 63);
        let r5 = (r5 << 1) | (r4 >> 63);
        let r4 = (r4 << 1) | (r3 >> 63);
        let r3 = (r3 << 1) | (r2 >> 63);
        let r2 = (r2 << 1) | (r1 >> 63);
        let r1 = r1 << 1;
        // Add the squares of each limb on the diagonal.
        let (r0, carry) = mac(0, self.0[0], self.0[0], 0);
        let (r1, carry) = adc(0, r1, carry);
        let (r2, carry) = mac(r2, self.0[1], self.0[1], carry);
        let (r3, carry) = adc(0, r3, carry);
        let (r4, carry) = mac(r4, self.0[2], self.0[2], carry);
        let (r5, carry) = adc(0, r5, carry);
        let (r6, carry) = mac(r6, self.0[3], self.0[3], carry);
        let (r7, _) = adc(0, r7, carry);
        Scalar::montgomery_reduce(r0, r1, r2, r3, r4, r5, r6, r7)
    }

    /// Returns a square root of `self`, or none if `self` is not a square.
    ///
    /// A constant-time Tonelli–Shanks variant: the loop count depends only on
    /// `TWO_ADACITY`, with data-dependent steps done through conditional
    /// selects. The result is some exactly when `x * x == self`.
    pub fn sqrt(&self) -> CtOption<Self> {
        // Exponent (t - 1) / 2, where q - 1 = t * 2^32 with t odd.
        let w = self.pow_vartime(&[
            0x7fff2dff7fffffff,
            0x04d0ec02a9ded201,
            0x94cebea4199cec04,
            0x0000000039f6d3a9,
        ]);
        let mut v = TWO_ADACITY;
        let mut x = self * w;
        let mut b = x * w;
        let mut z = ROOT_OF_UNITY;
        for max_v in (1..=TWO_ADACITY).rev() {
            let mut k = 1;
            let mut tmp = b.square();
            let mut j_less_than_v: Choice = 1.into();
            for j in 2..max_v {
                let tmp_is_one = tmp.ct_eq(&Scalar::one());
                let squared = Scalar::conditional_select(&tmp, &z, tmp_is_one).square();
                tmp = Scalar::conditional_select(&squared, &tmp, tmp_is_one);
                let new_z = Scalar::conditional_select(&z, &squared, tmp_is_one);
                j_less_than_v &= !j.ct_eq(&v);
                k = u32::conditional_select(&j, &k, tmp_is_one);
                z = Scalar::conditional_select(&z, &new_z, j_less_than_v);
            }
            let result = x * z;
            x = Scalar::conditional_select(&result, &x, b.ct_eq(&Scalar::one()));
            z = z.square();
            b *= z;
            v = k;
        }
        CtOption::new(
            x,
            // Only report success if x really is a square root.
            (x * x).ct_eq(self),
        )
    }

    /// Returns `self^by`, where `by` is a little-endian 256-bit exponent.
    ///
    /// Every step performs the multiplication and keeps it through a
    /// constant-time conditional assignment, so the exponent bits do not
    /// steer branches.
    pub fn pow(&self, by: &[u64; 4]) -> Self {
        let mut res = Self::one();
        for e in by.iter().rev() {
            for i in (0..64).rev() {
                res = res.square();
                let mut tmp = res;
                tmp *= self;
                res.conditional_assign(&tmp, (((*e >> i) & 0x1) as u8).into());
            }
        }
        res
    }

    /// Returns `self^by`, where `by` is a little-endian 256-bit exponent.
    ///
    /// Variable time: branches on the exponent bits, so only use it with
    /// public exponents.
    pub fn pow_vartime(&self, by: &[u64; 4]) -> Self {
        let mut res = Self::one();
        for e in by.iter().rev() {
            for i in (0..64).rev() {
                res = res.square();
                if ((*e >> i) & 1) == 1 {
                    res.mul_assign(self);
                }
            }
        }
        res
    }

    /// Returns `2^by mod q`, using square-and-double with constant-time
    /// conditional assignment over all 64 bits of `by`.
    pub fn pow_of_2(by: u64) -> Self {
        let mut res = Self::one();
        for i in (0..64).rev() {
            res = res.square();
            // Doubling is a cheap field addition and gives the same reduced
            // value as a Montgomery multiplication by 2.
            let doubled = res.double();
            res.conditional_assign(&doubled, (((by >> i) & 0x1) as u8).into());
        }
        res
    }

    /// Returns the multiplicative inverse of `self`, or none if `self` is zero.
    ///
    /// Computes `self^(q - 2)` with a fixed addition chain, so the running time
    /// does not depend on `self`.
    pub fn invert(&self) -> CtOption<Self> {
        #[inline(always)]
        fn square_assign_multi(n: &mut Scalar, num_times: usize) {
            for _ in 0..num_times {
                *n = n.square();
            }
        }
        let mut t0 = self.square();
        let mut t1 = t0 * self;
        let mut t16 = t0.square();
        let mut t6 = t16.square();
        let mut t5 = t6 * t0;
        t0 = t6 * t16;
        let mut t12 = t5 * t16;
        let mut t2 = t6.square();
        let mut t7 = t5 * t6;
        let mut t15 = t0 * t5;
        let mut t17 = t12.square();
        t1 *= t17;
        let mut t3 = t7 * t2;
        let t8 = t1 * t17;
        let t4 = t8 * t2;
        let t9 = t8 * t7;
        t7 = t4 * t5;
        let t11 = t4 * t17;
        t5 = t9 * t17;
        let t14 = t7 * t15;
        let t13 = t11 * t12;
        t12 = t11 * t17;
        t15 *= &t12;
        t16 *= &t15;
        t3 *= &t16;
        t17 *= &t3;
        t0 *= &t17;
        t6 *= &t0;
        t2 *= &t6;
        square_assign_multi(&mut t0, 8);
        t0 *= &t17;
        square_assign_multi(&mut t0, 9);
        t0 *= &t16;
        square_assign_multi(&mut t0, 9);
        t0 *= &t15;
        square_assign_multi(&mut t0, 9);
        t0 *= &t15;
        square_assign_multi(&mut t0, 7);
        t0 *= &t14;
        square_assign_multi(&mut t0, 7);
        t0 *= &t13;
        square_assign_multi(&mut t0, 10);
        t0 *= &t12;
        square_assign_multi(&mut t0, 9);
        t0 *= &t11;
        square_assign_multi(&mut t0, 8);
        t0 *= &t8;
        square_assign_multi(&mut t0, 8);
        t0 *= self;
        square_assign_multi(&mut t0, 14);
        t0 *= &t9;
        square_assign_multi(&mut t0, 10);
        t0 *= &t8;
        square_assign_multi(&mut t0, 15);
        t0 *= &t7;
        square_assign_multi(&mut t0, 10);
        t0 *= &t6;
        square_assign_multi(&mut t0, 8);
        t0 *= &t5;
        square_assign_multi(&mut t0, 16);
        t0 *= &t3;
        square_assign_multi(&mut t0, 8);
        t0 *= &t2;
        square_assign_multi(&mut t0, 7);
        t0 *= &t4;
        square_assign_multi(&mut t0, 9);
        t0 *= &t2;
        square_assign_multi(&mut t0, 8);
        t0 *= &t3;
        square_assign_multi(&mut t0, 8);
        t0 *= &t2;
        square_assign_multi(&mut t0, 8);
        t0 *= &t2;
        square_assign_multi(&mut t0, 8);
        t0 *= &t2;
        square_assign_multi(&mut t0, 8);
        t0 *= &t3;
        square_assign_multi(&mut t0, 8);
        t0 *= &t2;
        square_assign_multi(&mut t0, 8);
        t0 *= &t2;
        square_assign_multi(&mut t0, 5);
        t0 *= &t1;
        square_assign_multi(&mut t0, 5);
        t0 *= &t1;
        CtOption::new(t0, !self.ct_eq(&Self::zero()))
    }

    /// Montgomery reduction of the 512-bit value `T = r0 + r1*2^64 + ... + r7*2^448`.
    ///
    /// Returns `T * 2^-256 mod q`. Each round clears one low limb by adding
    /// `k * q` with `k = r_i * INV`, and one final conditional subtraction of `q`
    /// follows. Assumes `T < q * 2^256`, which holds for products of two
    /// reduced elements — NOTE(review): callers must keep inputs in range.
    #[inline(always)]
    const fn montgomery_reduce(
        r0: u64,
        r1: u64,
        r2: u64,
        r3: u64,
        r4: u64,
        r5: u64,
        r6: u64,
        r7: u64,
    ) -> Self {
        let k = r0.wrapping_mul(INV);
        let (_, carry) = mac(r0, k, MODULUS.0[0], 0);
        let (r1, carry) = mac(r1, k, MODULUS.0[1], carry);
        let (r2, carry) = mac(r2, k, MODULUS.0[2], carry);
        let (r3, carry) = mac(r3, k, MODULUS.0[3], carry);
        let (r4, carry2) = adc(r4, 0, carry);
        let k = r1.wrapping_mul(INV);
        let (_, carry) = mac(r1, k, MODULUS.0[0], 0);
        let (r2, carry) = mac(r2, k, MODULUS.0[1], carry);
        let (r3, carry) = mac(r3, k, MODULUS.0[2], carry);
        let (r4, carry) = mac(r4, k, MODULUS.0[3], carry);
        let (r5, carry2) = adc(r5, carry2, carry);
        let k = r2.wrapping_mul(INV);
        let (_, carry) = mac(r2, k, MODULUS.0[0], 0);
        let (r3, carry) = mac(r3, k, MODULUS.0[1], carry);
        let (r4, carry) = mac(r4, k, MODULUS.0[2], carry);
        let (r5, carry) = mac(r5, k, MODULUS.0[3], carry);
        let (r6, carry2) = adc(r6, carry2, carry);
        let k = r3.wrapping_mul(INV);
        let (_, carry) = mac(r3, k, MODULUS.0[0], 0);
        let (r4, carry) = mac(r4, k, MODULUS.0[1], carry);
        let (r5, carry) = mac(r5, k, MODULUS.0[2], carry);
        let (r6, carry) = mac(r6, k, MODULUS.0[3], carry);
        let (r7, _) = adc(r7, carry2, carry);
        // Result may be in [q, 2q); `sub` brings it back below q.
        (&Scalar([r4, r5, r6, r7])).sub(&MODULUS)
    }

    /// Montgomery multiplication: schoolbook 4x4-limb product followed by
    /// [`Scalar::montgomery_reduce`].
    #[inline]
    pub const fn mul(&self, rhs: &Self) -> Self {
        let (r0, carry) = mac(0, self.0[0], rhs.0[0], 0);
        let (r1, carry) = mac(0, self.0[0], rhs.0[1], carry);
        let (r2, carry) = mac(0, self.0[0], rhs.0[2], carry);
        let (r3, r4) = mac(0, self.0[0], rhs.0[3], carry);
        let (r1, carry) = mac(r1, self.0[1], rhs.0[0], 0);
        let (r2, carry) = mac(r2, self.0[1], rhs.0[1], carry);
        let (r3, carry) = mac(r3, self.0[1], rhs.0[2], carry);
        let (r4, r5) = mac(r4, self.0[1], rhs.0[3], carry);
        let (r2, carry) = mac(r2, self.0[2], rhs.0[0], 0);
        let (r3, carry) = mac(r3, self.0[2], rhs.0[1], carry);
        let (r4, carry) = mac(r4, self.0[2], rhs.0[2], carry);
        let (r5, r6) = mac(r5, self.0[2], rhs.0[3], carry);
        let (r3, carry) = mac(r3, self.0[3], rhs.0[0], 0);
        let (r4, carry) = mac(r4, self.0[3], rhs.0[1], carry);
        let (r5, carry) = mac(r5, self.0[3], rhs.0[2], carry);
        let (r6, r7) = mac(r6, self.0[3], rhs.0[3], carry);
        Scalar::montgomery_reduce(r0, r1, r2, r3, r4, r5, r6, r7)
    }

    /// Returns `self - rhs mod q`, adding `q` back when the raw subtraction
    /// borrows. Branch-free: the borrow is used as an all-ones/zero mask.
    #[inline]
    pub const fn sub(&self, rhs: &Self) -> Self {
        let (d0, borrow) = sbb(self.0[0], rhs.0[0], 0);
        let (d1, borrow) = sbb(self.0[1], rhs.0[1], borrow);
        let (d2, borrow) = sbb(self.0[2], rhs.0[2], borrow);
        let (d3, borrow) = sbb(self.0[3], rhs.0[3], borrow);
        let (d0, carry) = adc(d0, MODULUS.0[0] & borrow, 0);
        let (d1, carry) = adc(d1, MODULUS.0[1] & borrow, carry);
        let (d2, carry) = adc(d2, MODULUS.0[2] & borrow, carry);
        let (d3, _) = adc(d3, MODULUS.0[3] & borrow, carry);
        Scalar([d0, d1, d2, d3])
    }

    /// Returns `self + rhs mod q`: a full 256-bit add followed by subtracting
    /// `q` once (which `sub` undoes if the sum was already below `q`).
    #[inline]
    pub const fn add(&self, rhs: &Self) -> Self {
        let (d0, carry) = adc(self.0[0], rhs.0[0], 0);
        let (d1, carry) = adc(self.0[1], rhs.0[1], carry);
        let (d2, carry) = adc(self.0[2], rhs.0[2], carry);
        let (d3, _) = adc(self.0[3], rhs.0[3], carry);
        (&Scalar([d0, d1, d2, d3])).sub(&MODULUS)
    }

    /// Returns `q - self`, masked to zero when `self` is zero so the result
    /// stays below `q`.
    #[inline]
    pub const fn neg(&self) -> Self {
        let (d0, borrow) = sbb(MODULUS.0[0], self.0[0], 0);
        let (d1, borrow) = sbb(MODULUS.0[1], self.0[1], borrow);
        let (d2, borrow) = sbb(MODULUS.0[2], self.0[2], borrow);
        let (d3, _) = sbb(MODULUS.0[3], self.0[3], borrow);
        // mask is 0 when self == 0, all-ones otherwise.
        let mask = (((self.0[0] | self.0[1] | self.0[2] | self.0[3]) == 0) as u64).wrapping_sub(1);
        Scalar([d0 & mask, d1 & mask, d2 & mask, d3 & mask])
    }

    /// Logically shifts the stored 256-bit limbs right by `n` bits; `n >= 256`
    /// yields zero.
    ///
    /// NOTE(review): this acts on `self.0` exactly as stored. For a regular
    /// (Montgomery-form) scalar this is *not* division of the field element by
    /// `2^n`; it computes `floor(x / 2^n)` only when `self` holds canonical
    /// limbs, e.g. a value returned by [`Scalar::reduce`].
    #[inline]
    pub fn divn(&mut self, mut n: u32) {
        if n >= 256 {
            *self = Self::zero();
            return;
        }
        // Whole-limb shifts: move every limb one position towards index 0.
        while n >= 64 {
            let mut t = 0;
            for limb in self.0.iter_mut().rev() {
                core::mem::swap(&mut t, limb);
            }
            n -= 64;
        }
        // Remaining sub-limb shift, carrying low bits into the limb below.
        if n > 0 {
            let mut t = 0;
            for limb in self.0.iter_mut().rev() {
                let t2 = *limb << (64 - n);
                *limb >>= n;
                *limb |= t;
                t = t2;
            }
        }
    }
}
impl<'a> From<&'a Scalar> for [u8; 32] {
fn from(value: &'a Scalar) -> [u8; 32] {
value.to_bytes()
}
}
#[test]
fn test_inv() {
    // Raise q's low limb to 2^63 - 1 via 63 square-and-multiply steps (its
    // inverse mod 2^64), then negate; this must match INV.
    let inv = (0..63)
        .fold(1u64, |acc, _| acc.wrapping_mul(acc).wrapping_mul(MODULUS.0[0]))
        .wrapping_neg();
    assert_eq!(inv, INV);
}
#[cfg(feature = "std")]
#[test]
fn test_debug() {
    // Debug prints the canonical value: R2 is the Montgomery form of R.
    let cases = [
        (
            Scalar::zero(),
            "0x0000000000000000000000000000000000000000000000000000000000000000",
        ),
        (
            Scalar::one(),
            "0x0000000000000000000000000000000000000000000000000000000000000001",
        ),
        (
            R2,
            "0x1824b159acc5056f998c4fefecbc4ff55884b7fa0003480200000001fffffffe",
        ),
    ];
    for (scalar, expected) in cases.iter() {
        assert_eq!(format!("{:?}", scalar), *expected);
    }
}
#[test]
fn test_equality() {
    for s in [Scalar::zero(), Scalar::one(), R2].iter() {
        assert_eq!(*s, *s);
    }
    assert_ne!(Scalar::zero(), Scalar::one());
    assert_ne!(Scalar::one(), R2);
}
// Encoding is the canonical little-endian value: zero, one, R (the value held
// by R2) and q - 1 (= -1).
#[test]
fn test_to_bytes() {
assert_eq!(
Scalar::zero().to_bytes(),
[
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0
]
);
assert_eq!(
Scalar::one().to_bytes(),
[
1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0
]
);
assert_eq!(
R2.to_bytes(),
[
254, 255, 255, 255, 1, 0, 0, 0, 2, 72, 3, 0, 250, 183, 132, 88, 245, 79, 188, 236, 239,
79, 140, 153, 111, 5, 197, 172, 89, 177, 36, 24
]
);
assert_eq!(
(-&Scalar::one()).to_bytes(),
[
0, 0, 0, 0, 255, 255, 255, 255, 254, 91, 254, 255, 2, 164, 189, 83, 5, 216, 161, 9, 8,
216, 57, 51, 72, 125, 157, 41, 83, 167, 237, 115
]
);
}
// Decoding round-trips the vectors above, accepts q - 1, and rejects q, q + 1,
// q + 2^168 (byte 21 bumped) and q + 2^248 (last byte bumped).
#[test]
fn test_from_bytes() {
assert_eq!(
Scalar::from_bytes(&[
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0
])
.unwrap(),
Scalar::zero()
);
assert_eq!(
Scalar::from_bytes(&[
1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0
])
.unwrap(),
Scalar::one()
);
assert_eq!(
Scalar::from_bytes(&[
254, 255, 255, 255, 1, 0, 0, 0, 2, 72, 3, 0, 250, 183, 132, 88, 245, 79, 188, 236, 239,
79, 140, 153, 111, 5, 197, 172, 89, 177, 36, 24
])
.unwrap(),
R2
);
assert!(
Scalar::from_bytes(&[
0, 0, 0, 0, 255, 255, 255, 255, 254, 91, 254, 255, 2, 164, 189, 83, 5, 216, 161, 9, 8,
216, 57, 51, 72, 125, 157, 41, 83, 167, 237, 115
])
.is_some()
.unwrap_u8()
== 1
);
assert!(
Scalar::from_bytes(&[
1, 0, 0, 0, 255, 255, 255, 255, 254, 91, 254, 255, 2, 164, 189, 83, 5, 216, 161, 9, 8,
216, 57, 51, 72, 125, 157, 41, 83, 167, 237, 115
])
.is_none()
.unwrap_u8()
== 1
);
assert!(
Scalar::from_bytes(&[
2, 0, 0, 0, 255, 255, 255, 255, 254, 91, 254, 255, 2, 164, 189, 83, 5, 216, 161, 9, 8,
216, 57, 51, 72, 125, 157, 41, 83, 167, 237, 115
])
.is_none()
.unwrap_u8()
== 1
);
assert!(
Scalar::from_bytes(&[
1, 0, 0, 0, 255, 255, 255, 255, 254, 91, 254, 255, 2, 164, 189, 83, 5, 216, 161, 9, 8,
216, 58, 51, 72, 125, 157, 41, 83, 167, 237, 115
])
.is_none()
.unwrap_u8()
== 1
);
assert!(
Scalar::from_bytes(&[
1, 0, 0, 0, 255, 255, 255, 255, 254, 91, 254, 255, 2, 164, 189, 83, 5, 216, 161, 9, 8,
216, 57, 51, 72, 125, 157, 41, 83, 167, 237, 116
])
.is_none()
.unwrap_u8()
== 1
);
}
#[test]
fn test_from_u512_zero() {
    // The modulus in the low half reduces to zero.
    let mut limbs = [0u64; 8];
    limbs[..4].copy_from_slice(&MODULUS.0);
    assert_eq!(Scalar::zero(), Scalar::from_u512(limbs));
}
#[test]
fn test_from_u512_r() {
    // The integer 1 maps to its Montgomery form R.
    let limbs = [1, 0, 0, 0, 0, 0, 0, 0];
    assert_eq!(R, Scalar::from_u512(limbs));
}
#[test]
fn test_from_u512_r2() {
    // 2^256 maps to the Montgomery form of R, i.e. R2.
    let limbs = [0, 0, 0, 0, 1, 0, 0, 0];
    assert_eq!(R2, Scalar::from_u512(limbs));
}
#[test]
fn test_from_u512_max() {
    // 2^512 - 1 = (2^256 - 1) * (2^256 + 1), whose Montgomery form is R3 - R.
    assert_eq!(R3 - R, Scalar::from_u512([u64::MAX; 8]));
}
// A 64-byte input whose high half is zero decodes like the 32-byte vector.
#[test]
fn test_from_bytes_wide_r2() {
assert_eq!(
R2,
Scalar::from_bytes_wide(&[
254, 255, 255, 255, 1, 0, 0, 0, 2, 72, 3, 0, 250, 183, 132, 88, 245, 79, 188, 236, 239,
79, 140, 153, 111, 5, 197, 172, 89, 177, 36, 24, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
])
);
}
// q - 1 (zero-extended to 64 bytes) decodes to -1.
#[test]
fn test_from_bytes_wide_negative_one() {
assert_eq!(
-&Scalar::one(),
Scalar::from_bytes_wide(&[
0, 0, 0, 0, 255, 255, 255, 255, 254, 91, 254, 255, 2, 164, 189, 83, 5, 216, 161, 9, 8,
216, 57, 51, 72, 125, 157, 41, 83, 167, 237, 115, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
])
);
}
// 2^512 - 1 reduced mod q; the expected value is given as stored
// (Montgomery-form) limbs and is not re-derived here.
#[test]
fn test_from_bytes_wide_maximum() {
assert_eq!(
Scalar([
0xc62c1805439b73b1,
0xc2b9551e8ced218e,
0xda44ec81daf9a422,
0x5605aa601c162e79
]),
Scalar::from_bytes_wide(&[0xff; 64])
);
}
#[test]
fn test_zero() {
    let zero = Scalar::zero();
    // Zero is its own negation and is preserved by +, -, and *.
    assert_eq!(zero, -&zero);
    assert_eq!(zero, zero + zero);
    assert_eq!(zero, zero - zero);
    assert_eq!(zero, zero * zero);
}
/// Limbs of `q - 1`, the largest value below the modulus (identical to
/// `MODULUS` except for the low limb).
#[cfg(test)]
const LARGEST: Scalar = Scalar([
    0xffffffff00000000,
    0x53bda402fffe5bfe,
    0x3339d80809a1d805,
    0x73eda753299d7d48,
]);
#[test]
fn test_addition() {
    // (q - 1) + (q - 1) = q - 2 ...
    let mut sum = LARGEST;
    sum += &LARGEST;
    let q_minus_two = Scalar([
        0xfffffffeffffffff,
        0x53bda402fffe5bfe,
        0x3339d80809a1d805,
        0x73eda753299d7d48,
    ]);
    assert_eq!(sum, q_minus_two);
    // ... and (q - 1) + 1 wraps around to zero.
    let mut wrapped = LARGEST;
    wrapped += &Scalar([1, 0, 0, 0]);
    assert_eq!(wrapped, Scalar::zero());
}
#[test]
fn test_negation() {
    let one_limb = Scalar([1, 0, 0, 0]);
    // Raw limbs 1 and q - 1 negate to each other; zero negates to itself.
    assert_eq!(-&LARGEST, one_limb);
    assert_eq!(-&Scalar::zero(), Scalar::zero());
    assert_eq!(-&one_limb, LARGEST);
}
#[test]
fn test_subtraction() {
    let mut diff = LARGEST;
    diff -= &LARGEST;
    assert_eq!(diff, Scalar::zero());
    // 0 - (q - 1) must agree with q - (q - 1).
    let mut from_zero = Scalar::zero();
    from_zero -= &LARGEST;
    let mut from_modulus = MODULUS;
    from_modulus -= &LARGEST;
    assert_eq!(from_zero, from_modulus);
}
/// Reference product computed only with field additions: MSB-first
/// double-and-add of `base` over the canonical bits of `multiplier`.
#[cfg(test)]
fn mul_by_double_and_add(base: &Scalar, multiplier: &Scalar) -> Scalar {
    let mut acc = Scalar::zero();
    for byte in multiplier.to_bytes().iter().rev() {
        for i in (0..8).rev() {
            let prev = acc;
            acc += &prev;
            if ((byte >> i) & 1u8) == 1u8 {
                acc += base;
            }
        }
    }
    acc
}
#[test]
fn test_multiplication() {
    let mut cur = LARGEST;
    for _ in 0..100 {
        assert_eq!(cur * cur, mul_by_double_and_add(&cur, &cur));
        cur += &LARGEST;
    }
}
#[test]
fn test_squaring() {
    let mut cur = LARGEST;
    for _ in 0..100 {
        assert_eq!(cur.square(), mul_by_double_and_add(&cur, &cur));
        cur += &LARGEST;
    }
}
#[test]
fn test_inversion() {
    assert_eq!(Scalar::zero().invert().is_none().unwrap_u8(), 1);
    let one = Scalar::one();
    let minus_one = -&one;
    // 1 and -1 are their own inverses.
    assert_eq!(one.invert().unwrap(), one);
    assert_eq!(minus_one.invert().unwrap(), minus_one);
    let mut x = R2;
    for _ in 0..100 {
        assert_eq!(x.invert().unwrap() * x, Scalar::one());
        x += &R2;
    }
}
#[test]
fn test_invert_is_pow() {
    // Little-endian limbs of q - 2 (Fermat inversion exponent).
    let q_minus_2 = [
        0xfffffffeffffffff,
        0x53bda402fffe5bfe,
        0x3339d80809a1d805,
        0x73eda753299d7d48,
    ];
    let mut x = R;
    for _ in 0..100 {
        let via_invert = x.invert().unwrap();
        let via_pow_vartime = x.pow_vartime(&q_minus_2);
        let via_pow = x.pow(&q_minus_2);
        assert_eq!(via_invert, via_pow_vartime);
        assert_eq!(via_pow_vartime, via_pow);
        // Next base: the inverse just computed, shifted by R (one).
        x = via_invert + R;
    }
}
#[test]
fn test_sqrt() {
    assert_eq!(Scalar::zero().sqrt().unwrap(), Scalar::zero());

    // Walk down 100 consecutive values; every found root must square back,
    // and the number of non-residues in this window is pinned at 49.
    let mut candidate = Scalar([
        0x46cd85a5f273077e,
        0x1d30c47dd68fc735,
        0x77f656f60beca0eb,
        0x494aa01bdf32468d,
    ]);
    let mut non_residues = 0;
    for _ in 0..100 {
        let root = candidate.sqrt();
        if root.is_none().unwrap_u8() == 1 {
            non_residues += 1;
        } else {
            let r = root.unwrap();
            assert_eq!(r * r, candidate);
        }
        candidate -= Scalar::one();
    }
    assert_eq!(49, non_residues);
}
#[test]
fn test_from_raw() {
    // 2^256 - 1 reduced mod q is R - 1, whose limbs are spelled out here.
    let r_minus_one = [
        0x1fffffffd,
        0x5884b7fa00034802,
        0x998c4fefecbc4ff5,
        0x1824b159acc5056f,
    ];
    assert_eq!(Scalar::from_raw(r_minus_one), Scalar::from_raw([u64::MAX; 4]));
    // The modulus reduces to zero, and 1 maps to its Montgomery form R.
    assert_eq!(Scalar::from_raw(MODULUS.0), Scalar::zero());
    assert_eq!(Scalar::from_raw([1, 0, 0, 0]), R);
}
#[test]
fn test_double() {
    let a = Scalar::from_raw([
        0x1fff3231233ffffd,
        0x4884b7fa00034802,
        0x998c4fefecbc4ff3,
        0x1824b159acc50562,
    ]);
    let sum = a + a;
    assert_eq!(a.double(), sum);
}
#[test]
fn test_partial_ord() {
    let one = Scalar::one();
    let minus_one = -one;
    assert!(one < minus_one);
}
#[test]
fn test_xor() {
    // 500 = 0b1_1111_0100 and 499 = 0b1_1111_0011 differ only in the low three bits.
    let lhs = Scalar::from(500u64);
    let rhs = Scalar::from(499u64);
    assert_eq!(&lhs ^ &rhs, Scalar::from(7u64));
}
#[test]
fn test_and() {
    let one = Scalar::one();
    assert_eq!(&one & &one, one);
    // 1 & (q - 1) is zero because q - 1 is even.
    assert_eq!(one & -one, Scalar::zero());
}
#[test]
fn test_iter_sum() {
    let one = Scalar::one();
    let total: Scalar = [one, one].iter().sum();
    assert_eq!(total, one + one);
}
#[test]
fn test_iter_prod() {
    let two = Scalar::one() + Scalar::one();
    let total: Scalar = [two, two].iter().product();
    assert_eq!(total, Scalar::from(4u64));
}
#[test]
fn serde_bincode_scalar_roundtrip() {
    use bincode;
    let scalar = -Scalar::from(3u64);
    let encoded = bincode::serialize(&scalar).unwrap();
    // Fixed-size tuple: exactly the 32 encoding bytes.
    assert_eq!(encoded.len(), 32);
    let decoded: Scalar = bincode::deserialize(&encoded).unwrap();
    assert_eq!(decoded, scalar);
    // The serialized form is the canonical `to_bytes` encoding.
    let from_canonical: Scalar = bincode::deserialize(&scalar.to_bytes()).unwrap();
    assert_eq!(scalar, from_canonical);
}
#[test]
fn random_scalar_generation() {
    let mut rng = rand::thread_rng();
    for _ in 0..5000 {
        let _ = Scalar::random(&mut rng);
    }
}
// `to_bits` is least-significant-bit first over the canonical value.
#[test]
fn bit_repr() {
// 2^128: only bit index 128 is set.
let two_pow_128 = Scalar::from(2u64).pow(&[128, 0, 0, 0]);
let two_pow_128_bits = [
0u8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 1, 0u8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
];
assert_eq!(&two_pow_128.to_bits()[..], &two_pow_128_bits[..]);
// 2^128 - 7568589 (= 2^128 - 0x737CCD): the low 128 bits are the two's
// complement pattern ...FFFF8C8333, checked bit by bit.
let two_pow_128_minus_rand = Scalar::from(2u64).pow(&[128, 0, 0, 0]) - Scalar::from(7568589u64);
let two_pow_128_bits = [
1u8, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1,
];
assert_eq!(
&two_pow_128_minus_rand.to_bits()[..128],
&two_pow_128_bits[..]
)
}
#[test]
fn pow_of_two_test() {
    // The dedicated power-of-two routine must agree with generic exponentiation.
    let two = Scalar::from(2u64);
    for exp in 0u64..1000 {
        assert_eq!(Scalar::pow_of_2(exp), two.pow(&[exp, 0, 0, 0]));
    }
}