elliptic-curve 0.11.4

General purpose Elliptic Curve Cryptography (ECC) support, including types and traits for representing various elliptic curve forms, scalars, points, and public/secret keys composed thereof.
Documentation
//! Non-zero scalar type.

use crate::{
    bigint::Encoding as _,
    hex,
    ops::{Invert, Reduce, ReduceNonZero},
    rand_core::{CryptoRng, RngCore},
    Curve, Error, FieldBytes, IsHigh, Result, Scalar, ScalarArithmetic, ScalarCore, SecretKey,
};
use core::{
    fmt,
    ops::{Deref, Neg},
    str,
};
use crypto_bigint::{ArrayEncoding, Integer};
use ff::{Field, PrimeField};
use generic_array::GenericArray;
use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption};
use zeroize::Zeroize;

/// Non-zero scalar type.
///
/// This type ensures that its value is not zero, ala `core::num::NonZero*`.
/// It wraps the curve's `Scalar<C>`; checked constructors such as
/// [`NonZeroScalar::new`] reject a zero scalar by returning an empty
/// `CtOption`.
///
/// In the context of ECC, it's useful for ensuring that scalar multiplication
/// cannot result in the point at infinity.
#[cfg_attr(docsrs, doc(cfg(feature = "arithmetic")))]
#[derive(Clone)]
pub struct NonZeroScalar<C>
where
    C: Curve + ScalarArithmetic,
{
    // Wrapped scalar value; expected to be non-zero for the lifetime of `self`.
    scalar: Scalar<C>,
}

impl<C> NonZeroScalar<C>
where
    C: Curve + ScalarArithmetic,
{
    /// Generate a random `NonZeroScalar`.
    pub fn random(mut rng: impl CryptoRng + RngCore) -> Self {
        // Use rejection sampling to eliminate zero values.
        // While this method isn't constant-time, the attacker shouldn't learn
        // anything about unrelated outputs so long as `rng` is a secure `CryptoRng`.
        loop {
            if let Some(result) = Self::new(Field::random(&mut rng)).into() {
                break result;
            }
        }
    }

    /// Create a [`NonZeroScalar`] from a scalar.
    pub fn new(scalar: Scalar<C>) -> CtOption<Self> {
        CtOption::new(Self { scalar }, !scalar.is_zero())
    }

    /// Decode a [`NonZeroScalar`] from a big endian-serialized field element.
    pub fn from_repr(repr: FieldBytes<C>) -> CtOption<Self> {
        Scalar::<C>::from_repr(repr).and_then(Self::new)
    }

    /// Create a [`NonZeroScalar`] from a `C::UInt`.
    pub fn from_uint(uint: C::UInt) -> CtOption<Self> {
        ScalarCore::new(uint).and_then(|scalar| Self::new(scalar.into()))
    }
}

impl<C> AsRef<Scalar<C>> for NonZeroScalar<C>
where
    C: Curve + ScalarArithmetic,
{
    /// Borrow the wrapped scalar.
    fn as_ref(&self) -> &Scalar<C> {
        let Self { scalar } = self;
        scalar
    }
}

impl<C> ConditionallySelectable for NonZeroScalar<C>
where
    C: Curve + ScalarArithmetic,
{
    /// Select between `a` and `b` by delegating to the selection
    /// implemented for the wrapped scalars.
    fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
        let scalar = Scalar::<C>::conditional_select(&a.scalar, &b.scalar, choice);
        Self { scalar }
    }
}

impl<C> ConstantTimeEq for NonZeroScalar<C>
where
    C: Curve + ScalarArithmetic,
{
    fn ct_eq(&self, other: &Self) -> Choice {
        self.scalar.ct_eq(&other.scalar)
    }
}

// Marker impl: `NonZeroScalar<C>` may be copied bitwise for every curve `C`
// satisfying the bounds below (this relies on the wrapped `Scalar<C>` being
// `Copy` as well).
impl<C> Copy for NonZeroScalar<C> where C: Curve + ScalarArithmetic {}

impl<C> Deref for NonZeroScalar<C>
where
    C: Curve + ScalarArithmetic,
{
    type Target = Scalar<C>;

    /// Dereference to the wrapped scalar, giving read-only access to its API.
    fn deref(&self) -> &Self::Target {
        let Self { scalar } = self;
        scalar
    }
}

impl<C> From<NonZeroScalar<C>> for FieldBytes<C>
where
    C: Curve + ScalarArithmetic,
{
    /// Serialize an owned non-zero scalar via the wrapped scalar's `to_repr`.
    fn from(scalar: NonZeroScalar<C>) -> FieldBytes<C> {
        scalar.to_repr()
    }
}

impl<C> From<&NonZeroScalar<C>> for FieldBytes<C>
where
    C: Curve + ScalarArithmetic,
{
    fn from(scalar: &NonZeroScalar<C>) -> FieldBytes<C> {
        scalar.to_repr()
    }
}

impl<C> From<NonZeroScalar<C>> for ScalarCore<C>
where
    C: Curve + ScalarArithmetic,
{
    /// Convert an owned non-zero scalar into a `ScalarCore` by delegating to
    /// the by-reference conversion.
    fn from(scalar: NonZeroScalar<C>) -> ScalarCore<C> {
        Self::from(&scalar)
    }
}

impl<C> From<&NonZeroScalar<C>> for ScalarCore<C>
where
    C: Curve + ScalarArithmetic,
{
    /// Convert a borrowed non-zero scalar into a `ScalarCore` by round-tripping
    /// through its serialized representation.
    fn from(scalar: &NonZeroScalar<C>) -> ScalarCore<C> {
        let repr = scalar.to_repr();
        // The bytes were produced by an existing scalar, so decoding them is
        // expected to succeed; a failure here panics.
        ScalarCore::<C>::from_be_bytes(repr).unwrap()
    }
}

impl<C> From<SecretKey<C>> for NonZeroScalar<C>
where
    C: Curve + ScalarArithmetic,
{
    /// Extract the scalar from an owned secret key by delegating to the
    /// by-reference conversion.
    fn from(sk: SecretKey<C>) -> NonZeroScalar<C> {
        (&sk).into()
    }
}

impl<C> From<&SecretKey<C>> for NonZeroScalar<C>
where
    C: Curve + ScalarArithmetic,
{
    fn from(sk: &SecretKey<C>) -> NonZeroScalar<C> {
        let scalar = sk.as_scalar_core().to_scalar();
        debug_assert!(!bool::from(scalar.is_zero()));
        Self { scalar }
    }
}

impl<C> Invert for NonZeroScalar<C>
where
    C: Curve + ScalarArithmetic,
{
    type Output = Scalar<C>;

    /// Perform a scalar inversion
    fn invert(&self) -> CtOption<Self::Output> {
        ff::Field::invert(&self.scalar)
    }
}

impl<C> IsHigh for NonZeroScalar<C>
where
    C: Curve + ScalarArithmetic,
{
    fn is_high(&self) -> Choice {
        self.scalar.is_high()
    }
}

impl<C> Neg for NonZeroScalar<C>
where
    C: Curve + ScalarArithmetic,
{
    type Output = NonZeroScalar<C>;

    /// Negate the wrapped scalar.
    ///
    /// The result is re-wrapped without a checked constructor; a debug
    /// assertion verifies that it is still non-zero.
    fn neg(self) -> Self::Output {
        let scalar = Neg::neg(self.scalar);
        debug_assert!(!bool::from(scalar.is_zero()));
        Self { scalar }
    }
}

/// Note: this reduction is identical to the `ReduceNonZero` one, since the
/// result must never be zero.
impl<C, I> Reduce<I> for NonZeroScalar<C>
where
    C: Curve + ScalarArithmetic,
    I: Integer + ArrayEncoding,
    Scalar<C>: ReduceNonZero<I>,
{
    /// Reduce `n` into a non-zero scalar by forwarding to
    /// `from_uint_reduced_nonzero`.
    fn from_uint_reduced(n: I) -> Self {
        <Self as ReduceNonZero<I>>::from_uint_reduced_nonzero(n)
    }
}

impl<C, I> ReduceNonZero<I> for NonZeroScalar<C>
where
    C: Curve + ScalarArithmetic,
    I: Integer + ArrayEncoding,
    Scalar<C>: ReduceNonZero<I>,
{
    /// Reduce `n` using the scalar's own non-zero reduction and wrap the
    /// result.
    ///
    /// # Panics
    ///
    /// Panics if the underlying reduction yields zero.
    fn from_uint_reduced_nonzero(n: I) -> Self {
        let scalar = <Scalar<C> as ReduceNonZero<I>>::from_uint_reduced_nonzero(n);
        debug_assert!(!bool::from(scalar.is_zero()));

        let checked = Self::new(scalar);
        checked.unwrap()
    }
}

impl<C> TryFrom<&[u8]> for NonZeroScalar<C>
where
    C: Curve + ScalarArithmetic,
{
    type Error = Error;

    fn try_from(bytes: &[u8]) -> Result<Self> {
        if bytes.len() == C::UInt::BYTE_SIZE {
            Option::from(NonZeroScalar::from_repr(GenericArray::clone_from_slice(
                bytes,
            )))
            .ok_or(Error)
        } else {
            Err(Error)
        }
    }
}

impl<C> Zeroize for NonZeroScalar<C>
where
    C: Curve + ScalarArithmetic,
{
    fn zeroize(&mut self) {
        // Use zeroize's volatile writes to ensure value is cleared.
        self.scalar.zeroize();

        // Write a 1 instead of a 0 to ensure this type's non-zero invariant
        // is upheld.
        self.scalar = Scalar::<C>::one();
    }
}

impl<C> fmt::Display for NonZeroScalar<C>
where
    C: Curve + ScalarArithmetic,
{
    /// Display the scalar using its upper-case hexadecimal formatting.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_fmt(format_args!("{:X}", self))
    }
}

impl<C> fmt::LowerHex for NonZeroScalar<C>
where
    C: Curve + ScalarArithmetic,
{
    /// Write the scalar's serialized representation as lower-case hex.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let repr = self.to_repr();
        hex::write_lower(&repr, f)
    }
}

impl<C> fmt::UpperHex for NonZeroScalar<C>
where
    C: Curve + ScalarArithmetic,
{
    /// Write the scalar's serialized representation as upper-case hex.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let repr = self.to_repr();
        hex::write_upper(&repr, f)
    }
}

impl<C> str::FromStr for NonZeroScalar<C>
where
    C: Curve + ScalarArithmetic,
{
    type Err = Error;

    fn from_str(hex: &str) -> Result<Self> {
        let mut bytes = FieldBytes::<C>::default();
        hex::decode(hex, &mut bytes)?;
        Option::from(Self::from_repr(bytes)).ok_or(Error)
    }
}

#[cfg(all(test, feature = "dev"))]
mod tests {
    use crate::dev::{NonZeroScalar, Scalar};
    use ff::{Field, PrimeField};
    use hex_literal::hex;
    use zeroize::Zeroize;

    /// Decoding a serialized scalar and re-encoding it yields the same bytes.
    #[test]
    fn round_trip() {
        let expected = hex!("c9afa9d845ba75166b5c215767b1d6934e50c3db36e89b127b8a622b120f6721");
        let decoded = NonZeroScalar::from_repr(expected.into()).unwrap();
        let encoded = decoded.to_repr();
        assert_eq!(&expected, encoded.as_slice());
    }

    /// Zeroizing leaves the value `1` behind rather than `0`.
    #[test]
    fn zeroize() {
        let initial = Scalar::from(42u64);
        let mut scalar = NonZeroScalar::new(initial).unwrap();
        scalar.zeroize();
        assert_eq!(*scalar, Scalar::one());
    }
}