use crate::{backend, hash::HashInto, marker::*, op};
use backend::BackendScalar;
use core::marker::PhantomData;
use digest::{generic_array::typenum::U32, Digest};
use rand_core::{CryptoRng, RngCore};
/// A secp256k1 scalar (an integer modulo the curve order) wrapping the backend representation.
///
/// The two phantom type parameters are compile-time markers only; they carry no runtime data:
/// - `S` is the secrecy marker (defaults to `Secret`; `Public` is also used in this file).
/// - `Z` records whether the value may be zero (`Zero`) or is known to be non-zero
///   (`NonZero`, the default).
///
/// Equality is implemented manually (value-based, across marker types); `Eq` is derived on top.
#[derive(Clone, Eq)]
pub struct Scalar<S = Secret, Z = NonZero>(pub(crate) backend::Scalar, PhantomData<(Z, S)>);
/// `Hash` is only offered for scalars marked `Public`; the value is hashed through its
/// 32-byte encoding.
impl<Z> core::hash::Hash for Scalar<Public, Z> {
    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
        let encoded = self.to_bytes();
        encoded.hash(state)
    }
}
impl<Z, S> Scalar<S, Z> {
pub fn to_bytes(&self) -> [u8; 32] {
backend::Scalar::to_bytes(&self.0)
}
pub fn conditional_negate(&mut self, cond: bool) {
op::ScalarUnary::conditional_negate(self, cond)
}
pub fn is_high(&self) -> bool {
op::ScalarUnary::is_high(self)
}
pub fn is_zero(&self) -> bool {
op::ScalarUnary::is_zero(self)
}
pub(crate) fn from_inner(inner: backend::Scalar) -> Self {
Scalar(inner, PhantomData)
}
pub fn set_secrecy<SNew>(self) -> Scalar<SNew, Z> {
Scalar::from_inner(self.0)
}
}
impl<S> Scalar<S, NonZero> {
pub fn invert(&self) -> Self {
Self::from_inner(op::ScalarUnary::invert(self))
}
}
/// Constructors that always produce a non-zero secret scalar.
impl Scalar<Secret, NonZero> {
    /// Reduces `bytes` modulo the curve order and asserts that the result is non-zero.
    ///
    /// Shared by `random` and `from_hash`: their 32-byte inputs come from an RNG or a hash,
    /// so landing exactly on zero is treated as unreachable.
    fn reduce_to_non_zero(bytes: [u8; 32]) -> Self {
        Scalar::from_bytes_mod_order(bytes)
            .mark::<NonZero>()
            .expect("computationally unreachable")
    }

    /// Samples a scalar by reducing 32 bytes from `rng` modulo the curve order.
    ///
    /// # Panics
    /// Only if the reduced value is zero, which is treated as computationally unreachable.
    pub fn random<R: RngCore + CryptoRng>(rng: &mut R) -> Self {
        let mut bytes = [0u8; 32];
        rng.fill_bytes(&mut bytes);
        Self::reduce_to_non_zero(bytes)
    }

    /// Finalizes a 32-byte-output digest and reduces its output modulo the curve order.
    ///
    /// # Panics
    /// Only if the reduced value is zero, which is treated as computationally unreachable.
    pub fn from_hash(hash: impl Digest<OutputSize = U32>) -> Self {
        let digest_output = hash.finalize();
        let mut bytes = [0u8; 32];
        bytes.copy_from_slice(digest_output.as_slice());
        Self::reduce_to_non_zero(bytes)
    }

    /// Embeds a non-zero `u32` as a scalar; non-zero-ness is guaranteed by the argument type.
    pub fn from_non_zero_u32(int: core::num::NonZeroU32) -> Self {
        let value = int.get();
        Self::from_inner(backend::Scalar::from_u32(value))
    }

    /// The scalar `1`.
    pub fn one() -> Self {
        crate::s!(1)
    }

    /// The scalar `-1`, i.e. the curve order minus one.
    pub fn minus_one() -> Self {
        Self::from_inner(backend::Scalar::minus_one())
    }
}
/// Constructors for secret scalars that are allowed to be zero.
impl Scalar<Secret, Zero> {
    /// Interprets `bytes` as a big-endian integer and reduces it modulo the curve order.
    ///
    /// Every input is accepted; for example, the encoding of `order + 1` yields `1`.
    pub fn from_bytes_mod_order(bytes: [u8; 32]) -> Self {
        Self::from_inner(backend::Scalar::from_bytes_mod_order(bytes))
    }

    /// Slice variant of `from_bytes_mod_order`.
    ///
    /// Returns `None` if `slice` is not exactly 32 bytes long; otherwise always `Some`.
    pub fn from_slice_mod_order(slice: &[u8]) -> Option<Self> {
        // The array conversion fails exactly when `slice.len() != 32`.
        let bytes = <[u8; 32] as core::convert::TryFrom<&[u8]>>::try_from(slice).ok()?;
        Some(Self::from_bytes_mod_order(bytes))
    }

    /// Decodes a canonical 32-byte big-endian encoding without reducing it.
    ///
    /// Returns `None` when the backend rejects `bytes` as non-canonical (e.g. an all-`0xff`
    /// input, which is not below the curve order).
    pub fn from_bytes(bytes: [u8; 32]) -> Option<Self> {
        backend::Scalar::from_bytes(bytes).map(Self::from_inner)
    }

    /// Slice variant of `from_bytes`.
    ///
    /// Returns `None` if `slice` is not exactly 32 bytes long or is not a canonical encoding.
    pub fn from_slice(slice: &[u8]) -> Option<Self> {
        // The array conversion fails exactly when `slice.len() != 32`.
        let bytes = <[u8; 32] as core::convert::TryFrom<&[u8]>>::try_from(slice).ok()?;
        Self::from_bytes(bytes)
    }

    /// The zero scalar.
    pub fn zero() -> Self {
        Self::from_inner(backend::Scalar::zero())
    }
}
impl<Z1, Z2, S1, S2> PartialEq<Scalar<S2, Z2>> for Scalar<S1, Z1> {
fn eq(&self, rhs: &Scalar<S2, Z2>) -> bool {
crate::op::ScalarBinary::eq((self, rhs))
}
}
impl From<u32> for Scalar<Secret, Zero> {
fn from(int: u32) -> Self {
Self::from_inner(backend::Scalar::from_u32(int))
}
}
// Parsing/deserialization support for `Scalar<S, NonZero>` (as the macro name suggests —
// confirm against the macro definition). The 32 bytes are decoded with `Scalar::from_bytes`,
// re-tagged with the requested secrecy `S`, then `mark::<NonZero>()` is applied, which
// presumably yields `None` (i.e. rejects the input) for the zero scalar.
crate::impl_fromstr_deserailize! {
name => "non-zero secp256k1 scalar",
fn from_bytes<S>(bytes: [u8;32]) -> Option<Scalar<S,NonZero>> {
Scalar::from_bytes(bytes).and_then(|scalar| scalar.set_secrecy::<S>().mark::<NonZero>())
}
}
// Display/Debug/serialization support for every `Scalar<S, Z>` (as the macro name suggests —
// confirm against the macro definition), all driven by the 32-byte encoding from
// `Scalar::to_bytes`.
crate::impl_display_debug_serialize! {
fn to_bytes<Z,S>(scalar: &Scalar<S,Z>) -> [u8;32] {
scalar.to_bytes()
}
}
// Parsing/deserialization support for possibly-zero `Scalar<S, Zero>` (as the macro name
// suggests — confirm against the macro definition). The 32 bytes are decoded with
// `Scalar::from_bytes`, which returns `None` for encodings the backend rejects, and then
// re-tagged with the requested secrecy `S`. Unlike the `NonZero` variant, zero is accepted.
crate::impl_fromstr_deserailize! {
name => "secp256k1 scalar",
fn from_bytes<S>(bytes: [u8;32]) -> Option<Scalar<S,Zero>> {
Scalar::from_bytes(bytes).map(|scalar| scalar.set_secrecy::<S>())
}
}
impl<S, Z> core::ops::Neg for Scalar<S, Z> {
type Output = Scalar<S, Z>;
fn neg(self) -> Self::Output {
use crate::op::ScalarUnary;
Scalar::from_inner(ScalarUnary::negate(&self))
}
}
impl<S, Z> core::ops::Neg for &Scalar<S, Z> {
type Output = Scalar<S, Z>;
fn neg(self) -> Self::Output {
use crate::op::ScalarUnary;
Scalar::from_inner(ScalarUnary::negate(self))
}
}
/// Feeds the scalar's 32-byte encoding into a digest. Implemented only for the default
/// marker combination, `Scalar<Secret, NonZero>`.
impl HashInto for Scalar {
    fn hash_into(&self, hash: &mut impl digest::Digest) {
        let encoded = self.to_bytes();
        hash.update(&encoded)
    }
}
/// A possibly-zero scalar defaults to zero, tagged with the requested secrecy marker.
impl<S: Secrecy> Default for Scalar<S, Zero> {
    fn default() -> Self {
        let zero = Scalar::zero();
        zero.mark::<S>()
    }
}
/// A non-zero scalar cannot default to zero, so it defaults to one, tagged with the
/// requested secrecy marker.
impl<S: Secrecy> Default for Scalar<S, NonZero> {
    fn default() -> Self {
        let one = Scalar::one();
        one.mark::<S>()
    }
}
// Unit tests for `Scalar`. Most are wrapped in `crate::test_plus_wasm!`, which by its name
// presumably also registers them for wasm targets — confirm against the macro definition.
#[cfg(test)]
mod test {
use super::*;
use crate::{op, s};
// Round-trips a random scalar through bincode (only built with the `serde` feature).
#[cfg(feature = "serde")]
#[test]
fn scalar_serde_rountrip() {
let original = Scalar::random(&mut rand::thread_rng());
let serialized = bincode::serialize(&original).unwrap();
let deserialized = bincode::deserialize::<Scalar>(&serialized[..]).unwrap();
assert_eq!(deserialized, original)
}
crate::test_plus_wasm! {
// Two independent random draws should differ.
fn random() {
let scalar_1 = Scalar::random(&mut rand::thread_rng());
let scalar_2 = Scalar::random(&mut rand::thread_rng());
assert_ne!(scalar_1, scalar_2);
}
// x * x^-1 == 1.
fn invert() {
let x = Scalar::random(&mut rand::thread_rng());
assert!(s!(x * {x.invert()}) == Scalar::from(1));
}
// x - x == 0, and negating zero gives zero.
fn neg() {
let x = Scalar::random(&mut rand::thread_rng());
assert_eq!(s!(x - x), Scalar::zero());
assert_eq!(-Scalar::zero(), Scalar::zero())
}
// `one`/`minus_one` agree with `from(1)` and negation; 3 * (-1) == -3.
fn one() {
assert_eq!(Scalar::one(), Scalar::from(1));
assert_eq!(Scalar::minus_one(), -Scalar::one());
assert_eq!(op::scalar_mul(&s!(3), &Scalar::minus_one()), -s!(3));
}
// `zero()` agrees with `from(0)`.
fn zero() {
assert_eq!(Scalar::zero(), Scalar::from(0));
}
// Accepts exactly 32 bytes, and rejects an all-0xff value that is not below the curve order.
fn from_slice() {
assert!(Scalar::from_slice(b"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx".as_ref()).is_some());
assert!(Scalar::from_slice(b"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx".as_ref()).is_none());
assert!(Scalar::from_slice(
hex_literal::hex!("ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff")
.as_ref()
)
.is_none());
}
// A value already below the curve order is unchanged; order + 1 reduces to 1.
fn from_slice_mod_order() {
assert_eq!(
Scalar::from_slice_mod_order(b"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx".as_ref())
.unwrap()
.to_bytes(),
*b"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
);
assert_eq!(
Scalar::from_slice_mod_order(
hex_literal::hex!(
"FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364142"
)
.as_ref()
)
.unwrap(),
Scalar::from(1)
)
}
// `minus_one` encodes order - 1; 2 - 3 == -1 while 3 - 2 == 1 (from `From<u32>` scalars).
fn scalar_subtraction_is_not_commutative() {
let two = Scalar::from(2);
let three = Scalar::from(3);
let minus_1 = Scalar::minus_one();
let one = Scalar::from(1);
assert_eq!(
minus_1,
Scalar::from_bytes_mod_order(hex_literal::hex!(
"FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364140"
))
);
assert_eq!(s!(two - three), minus_1);
assert_eq!(s!(three - two), one);
}
// Same subtraction checks, using `s!` literals instead of `From<u32>`.
fn nz_scalar_to_scalar_subtraction_is_not_commutative() {
let two = s!(2);
let three = s!(3);
let minus_1 = Scalar::minus_one();
let one = Scalar::from(1);
assert_eq!(s!(two - three), minus_1);
assert_eq!(s!(three - two), one);
}
}
}