#[cfg(all(feature = "pkcs8", feature = "sec1"))]
mod pkcs8;
use crate::{Curve, Error, FieldBytes, Result, ScalarValue};
use array::typenum::Unsigned;
use common::{Generate, InvalidKey, KeySizeUser, TryKeyInit};
use core::fmt::{self, Debug};
use rand_core::{CryptoRng, TryCryptoRng};
use subtle::{Choice, ConstantTimeEq, CtOption};
use zeroize::{Zeroize, ZeroizeOnDrop, Zeroizing};
#[cfg(feature = "ecdh")]
use crate::ecdh;
#[cfg(feature = "arithmetic")]
use crate::{CurveArithmetic, NonZeroScalar, PublicKey};
#[cfg(all(feature = "arithmetic", feature = "pem"))]
use alloc::string::String;
#[cfg(feature = "pem")]
use pem_rfc7468::{self as pem, PemLabel};
#[cfg(all(feature = "alloc", feature = "arithmetic", feature = "sec1"))]
use {
crate::{
AffinePoint,
sec1::{FromSec1Point, ToSec1Point},
},
alloc::vec::Vec,
sec1::der::Encode,
};
#[cfg(feature = "sec1")]
use {
crate::{
DecodeError, DecodeResult, FieldBytesSize,
sec1::{ModulusSize, Sec1Point, ValidatePublicKey},
},
sec1::der::{self, Decode, oid::AssociatedOid},
};
#[cfg(all(doc, feature = "pkcs8"))]
use {crate::pkcs8::DecodePrivateKey, core::str::FromStr};
/// Elliptic curve secret key.
///
/// Holds the secret scalar for curve `C` as a [`ScalarValue`]. The checked
/// constructors (`from_scalar`, `from_bytes`, `from_slice`) reject the zero
/// scalar, and the `Drop` impl zeroizes the scalar when the key is dropped.
///
/// `Debug` output prints only the type name (no scalar), and equality is
/// implemented on top of `ConstantTimeEq`.
#[derive(Clone)]
pub struct SecretKey<C: Curve> {
// The secret scalar. Private so it can only be set through the constructors
// and conversion impls in this module.
inner: ScalarValue<C>,
}
impl<C> SecretKey<C>
where
    C: Curve,
{
    /// Minimum allowed size of an elliptic curve secret key in bytes.
    ///
    /// [`SecretKey::from_slice`] accepts slices shorter than the curve's field
    /// size as long as they are at least this long, zero-padding them on the left.
    pub const MIN_SIZE: usize = 24;

    /// Create a secret key from a scalar value.
    ///
    /// The returned [`CtOption`] is none when the scalar is zero, since a zero
    /// scalar is not a valid secret key.
    pub fn from_scalar(scalar: impl Into<ScalarValue<C>>) -> CtOption<Self> {
        let inner = scalar.into();
        CtOption::new(Self { inner }, !inner.is_zero())
    }

    /// Borrow the inner [`ScalarValue`].
    pub fn as_scalar_value(&self) -> &ScalarValue<C> {
        &self.inner
    }

    /// Get the secret scalar as a [`NonZeroScalar`].
    ///
    /// The conversion goes through a conversion impl for `&SecretKey<C>` that is
    /// defined outside this module (presumably `From<&SecretKey<C>>` on
    /// `NonZeroScalar`).
    #[cfg(feature = "arithmetic")]
    pub fn to_nonzero_scalar(&self) -> NonZeroScalar<C>
    where
        C: CurveArithmetic,
    {
        self.into()
    }

    /// Compute the [`PublicKey`] corresponding to this secret key.
    #[cfg(feature = "arithmetic")]
    pub fn public_key(&self) -> PublicKey<C>
    where
        C: CurveArithmetic,
    {
        PublicKey::from_secret_scalar(&self.to_nonzero_scalar())
    }

    /// Deserialize a secret key from its serialized scalar representation.
    ///
    /// # Errors
    ///
    /// Returns [`Error`] if `bytes` do not decode to a valid [`ScalarValue`], or if
    /// the decoded scalar is zero.
    pub fn from_bytes(bytes: &FieldBytes<C>) -> Result<Self> {
        let inner = ScalarValue::<C>::from_bytes(bytes)
            .into_option()
            .ok_or(Error)?;

        if inner.is_zero().into() {
            return Err(Error);
        }

        Ok(Self { inner })
    }

    /// Deserialize a secret key from a byte slice.
    ///
    /// A slice of exactly the field size is decoded directly. A slice whose length
    /// lies in `MIN_SIZE..FieldBytesSize` is copied into the end of a zero-filled
    /// field-sized buffer (left-padded) before decoding. Any other length is
    /// rejected.
    ///
    /// # Errors
    ///
    /// Returns [`Error`] for an unsupported length, or if [`SecretKey::from_bytes`]
    /// rejects the (padded) bytes.
    pub fn from_slice(slice: &[u8]) -> Result<Self> {
        if let Ok(field_bytes) = <&FieldBytes<C>>::try_from(slice) {
            Self::from_bytes(field_bytes)
        } else if (Self::MIN_SIZE..C::FieldBytesSize::USIZE).contains(&slice.len()) {
            // `Zeroizing` wipes this temporary copy of the key bytes on drop.
            let mut bytes = Zeroizing::new(FieldBytes::<C>::default());
            let offset = C::FieldBytesSize::USIZE.saturating_sub(slice.len());
            bytes[offset..].copy_from_slice(slice);
            Self::from_bytes(&bytes)
        } else {
            Err(Error)
        }
    }

    /// Serialize the secret scalar as [`FieldBytes`].
    ///
    /// NOTE: the returned bytes are not wrapped in [`Zeroizing`]. Callers that
    /// keep them around are responsible for wiping them.
    pub fn to_bytes(&self) -> FieldBytes<C> {
        self.inner.to_bytes()
    }

    /// Compute an ECDH shared secret between this secret key and `public_key`.
    #[cfg(feature = "ecdh")]
    pub fn diffie_hellman(&self, public_key: &PublicKey<C>) -> ecdh::SharedSecret<C>
    where
        C: CurveArithmetic,
    {
        ecdh::diffie_hellman(self.to_nonzero_scalar(), public_key.as_affine())
    }

    /// Deserialize a secret key from ASN.1 DER, trying every enabled format.
    ///
    /// PKCS#8 is tried first (when the `pkcs8` feature is enabled), then SEC1
    /// `ECPrivateKey` (when the `sec1` feature is enabled). The first successful
    /// decode is returned.
    ///
    /// # Errors
    ///
    /// If every attempt fails, the PKCS#8 error is returned when PKCS#8 was tried;
    /// otherwise the SEC1 error is returned.
    #[cfg(any(feature = "pkcs8", feature = "sec1"))]
    #[allow(clippy::missing_panics_doc, reason = "should not panic")]
    pub fn from_der(der_bytes: &[u8]) -> DecodeResult<Self>
    where
        C: AssociatedOid + Curve + ValidatePublicKey,
        FieldBytesSize<C>: ModulusSize,
    {
        // When the PKCS#8 branch is compiled in it always overwrites this initial
        // `None` before it is read (hence `unused_assignments`). The `cfg(any(..))`
        // on this function guarantees at least one branch sets it before `expect`.
        #[allow(unused_assignments)]
        let mut err: Option<DecodeError> = None;

        #[cfg(feature = "pkcs8")]
        match ::pkcs8::DecodePrivateKey::from_pkcs8_der(der_bytes) {
            Ok(sk) => return Ok(sk),
            Err(e) => err = Some(e.into()),
        }

        #[cfg(feature = "sec1")]
        match Self::from_sec1_der(der_bytes) {
            Ok(sk) => return Ok(sk),
            Err(e) => {
                // Keep the PKCS#8 error if one was already recorded.
                let _ = err.get_or_insert(e);
            }
        }

        Err(err.expect("should be set"))
    }

    /// Deserialize a secret key from PEM, dispatching on the PEM label.
    ///
    /// A PKCS#8 private-key label is decoded as PKCS#8. A SEC1 `ECPrivateKey`
    /// label is decoded as SEC1.
    ///
    /// # Errors
    ///
    /// Returns a PEM label error for any other label, or the underlying
    /// PEM/PKCS#8/SEC1 decoding error.
    #[cfg(feature = "pem")]
    pub fn from_pem(pem: &str) -> DecodeResult<Self>
    where
        C: AssociatedOid + Curve + ValidatePublicKey,
        FieldBytesSize<C>: ModulusSize,
    {
        let label = pem_rfc7468::decode_label(pem.as_bytes()).map_err(DecodeError::Pem)?;

        if ::pkcs8::PrivateKeyInfoRef::validate_pem_label(label).is_ok() {
            return ::pkcs8::DecodePrivateKey::from_pkcs8_pem(pem).map_err(DecodeError::Pkcs8);
        } else if ::sec1::EcPrivateKey::validate_pem_label(label).is_ok() {
            return ::sec1::DecodeEcPrivateKey::from_sec1_pem(pem).map_err(DecodeError::Sec1);
        }

        Err(pem_rfc7468::Error::Label.into())
    }

    /// Deserialize a secret key from a SEC1 ASN.1 DER-encoded `ECPrivateKey`.
    ///
    /// # Errors
    ///
    /// Returns an error if `der_bytes` is not a well-formed `ECPrivateKey`, or
    /// [`DecodeError::Sec1`] if the parsed structure fails validation. Validation
    /// fails on a curve OID mismatch, an invalid private scalar, or a public key
    /// rejected by `C::validate_public_key`.
    #[cfg(feature = "sec1")]
    pub fn from_sec1_der(der_bytes: &[u8]) -> DecodeResult<Self>
    where
        C: AssociatedOid + Curve + ValidatePublicKey,
        FieldBytesSize<C>: ModulusSize,
    {
        let sec1_key = sec1::EcPrivateKey::try_from(der_bytes)?;
        Self::try_from(sec1_key).map_err(|e| DecodeError::Sec1(e.into()))
    }

    /// Serialize as a SEC1 ASN.1 DER-encoded `ECPrivateKey`.
    ///
    /// The encoding includes the curve's named-curve OID and the public key.
    /// The public key uses `to_sec1_point(false)`, presumably the uncompressed
    /// form. The output is wrapped in [`Zeroizing`] because it contains the
    /// secret scalar.
    ///
    /// # Errors
    ///
    /// Returns any DER encoding error.
    #[cfg(all(feature = "alloc", feature = "arithmetic", feature = "sec1"))]
    pub fn to_sec1_der(&self) -> der::Result<Zeroizing<Vec<u8>>>
    where
        C: AssociatedOid + CurveArithmetic,
        AffinePoint<C>: FromSec1Point<C> + ToSec1Point<C>,
        FieldBytesSize<C>: ModulusSize,
    {
        let private_key_bytes = Zeroizing::new(self.to_bytes());
        let public_key_bytes = self.public_key().to_sec1_point(false);
        let parameters = sec1::EcParameters::NamedCurve(C::OID);

        let ec_private_key = Zeroizing::new(
            sec1::EcPrivateKey {
                private_key: &private_key_bytes,
                parameters: Some(parameters),
                public_key: Some(public_key_bytes.as_bytes()),
            }
            .to_der()?,
        );

        Ok(ec_private_key)
    }

    /// Deserialize a secret key from a PEM-encoded SEC1 `ECPrivateKey`.
    ///
    /// # Errors
    ///
    /// Returns a PEM error if decoding fails or if the label is not
    /// `EcPrivateKey::PEM_LABEL`. Otherwise returns any error from
    /// [`SecretKey::from_sec1_der`].
    #[cfg(feature = "pem")]
    pub fn from_sec1_pem(s: &str) -> DecodeResult<Self>
    where
        C: AssociatedOid + Curve + ValidatePublicKey,
        FieldBytesSize<C>: ModulusSize,
    {
        let (label, der_bytes) = pem::decode_vec(s.as_bytes()).map_err(DecodeError::Pem)?;

        // The decoded DER contains the secret scalar. Wrap the heap buffer so it
        // is wiped on every exit path (including the label error below), matching
        // how the other encode/decode paths in this impl handle key bytes.
        let der_bytes = Zeroizing::new(der_bytes);

        if label != sec1::EcPrivateKey::PEM_LABEL {
            return Err(pem_rfc7468::Error::Label.into());
        }

        Self::from_sec1_der(&der_bytes)
    }

    /// Serialize as a PEM-encoded SEC1 `ECPrivateKey`, wrapped in [`Zeroizing`].
    ///
    /// # Errors
    ///
    /// Returns [`Error`] if DER or PEM encoding fails. The underlying error
    /// detail is discarded.
    #[cfg(feature = "pem")]
    pub fn to_sec1_pem(&self, line_ending: pem::LineEnding) -> Result<Zeroizing<String>>
    where
        C: AssociatedOid + CurveArithmetic,
        AffinePoint<C>: FromSec1Point<C> + ToSec1Point<C>,
        FieldBytesSize<C>: ModulusSize,
    {
        self.to_sec1_der()
            .ok()
            .and_then(|der| {
                pem::encode_string(sec1::EcPrivateKey::PEM_LABEL, line_ending, &der).ok()
            })
            .map(Zeroizing::new)
            .ok_or(Error)
    }

    /// Generate a random secret key.
    ///
    /// Deprecated in favor of the `Generate` trait; this simply delegates to
    /// `generate_from_rng`.
    #[deprecated(since = "0.14.0", note = "use the `Generate` trait instead")]
    pub fn random<R: CryptoRng + ?Sized>(rng: &mut R) -> Self {
        Self::generate_from_rng(rng)
    }
}
impl<C: Curve> ConstantTimeEq for SecretKey<C> {
    /// Compare two secret keys by delegating to the inner scalars'
    /// `ConstantTimeEq` implementation.
    fn ct_eq(&self, other: &Self) -> Choice {
        let (lhs, rhs) = (&self.inner, &other.inner);
        lhs.ct_eq(rhs)
    }
}
impl<C: Curve> Debug for SecretKey<C> {
    /// Format as the type name with no fields, so the secret scalar is never
    /// printed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let type_name = core::any::type_name::<Self>();
        let mut builder = f.debug_struct(type_name);
        builder.finish_non_exhaustive()
    }
}
impl<C: Curve> Drop for SecretKey<C> {
    /// Zeroize the inner scalar when the key is dropped.
    fn drop(&mut self) {
        let secret = &mut self.inner;
        secret.zeroize();
    }
}
/// Marker impl: `SecretKey` zeroizes its scalar in its `Drop` impl.
impl<C: Curve> ZeroizeOnDrop for SecretKey<C> {}
/// Marker impl: equality (backed by `ConstantTimeEq`) is a total equivalence.
impl<C> Eq for SecretKey<C> where C: Curve {}
impl<C: Curve> PartialEq for SecretKey<C> {
    /// Delegate to [`ConstantTimeEq::ct_eq`] and convert the resulting [`Choice`]
    /// into a `bool`.
    fn eq(&self, other: &Self) -> bool {
        bool::from(self.ct_eq(other))
    }
}
impl<C> Generate for SecretKey<C>
where
    C: Curve,
{
    /// Generate a random, non-zero secret key.
    ///
    /// Candidate scalars are drawn with `ScalarValue::try_generate_from_rng` and
    /// resampled until a non-zero one is found. The generated key therefore
    /// upholds the same non-zero invariant that `from_scalar` and `from_bytes`
    /// enforce. A degenerate RNG that only ever yields zero scalars would make
    /// this loop forever.
    ///
    /// # Errors
    ///
    /// Propagates any error reported by `rng`.
    fn try_generate_from_rng<R: TryCryptoRng + ?Sized>(
        rng: &mut R,
    ) -> core::result::Result<Self, R::Error> {
        loop {
            let inner = ScalarValue::<C>::try_generate_from_rng(&mut *rng)?;

            // Rejection sampling: a zero scalar is not a valid secret key. The
            // branch only reveals that a discarded (non-secret) zero was drawn.
            if !bool::from(inner.is_zero()) {
                return Ok(Self { inner });
            }
        }
    }
}
impl<C: Curve> KeySizeUser for SecretKey<C> {
    /// Serialized key size equals the curve's `FieldBytesSize`.
    type KeySize = C::FieldBytesSize;
}
impl<C: Curve> TryKeyInit for SecretKey<C> {
    /// Initialize a key with [`SecretKey::from_bytes`], reporting any failure as
    /// [`InvalidKey`].
    fn new(key_bytes: &FieldBytes<C>) -> core::result::Result<Self, InvalidKey> {
        let Ok(secret_key) = Self::from_bytes(key_bytes) else {
            return Err(InvalidKey);
        };
        Ok(secret_key)
    }
}
#[cfg(feature = "sec1")]
impl<C> sec1::DecodeEcPrivateKey for SecretKey<C>
where
    C: AssociatedOid + Curve + ValidatePublicKey,
    FieldBytesSize<C>: ModulusSize,
{
    /// Parse a DER-encoded SEC1 `ECPrivateKey`, then validate and convert it
    /// through `TryFrom<sec1::EcPrivateKey>`.
    fn from_sec1_der(bytes: &[u8]) -> sec1::Result<Self> {
        let parsed = sec1::EcPrivateKey::from_der(bytes)?;
        let secret_key: Self = parsed.try_into()?;
        Ok(secret_key)
    }
}
#[cfg(all(feature = "alloc", feature = "arithmetic", feature = "sec1"))]
impl<C> sec1::EncodeEcPrivateKey for SecretKey<C>
where
    C: AssociatedOid + CurveArithmetic,
    AffinePoint<C>: FromSec1Point<C> + ToSec1Point<C>,
    FieldBytesSize<C>: ModulusSize,
{
    /// Encode as a SEC1 `ECPrivateKey` inside a `der::SecretDocument`.
    ///
    /// The document carries the named-curve OID and the public key produced by
    /// `to_sec1_point(false)`.
    fn to_sec1_der(&self) -> sec1::Result<der::SecretDocument> {
        // Keep the raw scalar bytes in a self-wiping buffer while encoding.
        let secret_bytes = Zeroizing::new(self.to_bytes());
        let public_point = self.public_key().to_sec1_point(false);

        let ec_private_key = sec1::EcPrivateKey {
            private_key: &secret_bytes,
            parameters: Some(C::OID.into()),
            public_key: Some(public_point.as_bytes()),
        };

        let document = der::SecretDocument::encode_msg(&ec_private_key)?;
        Ok(document)
    }
}
#[cfg(feature = "sec1")]
impl<C> TryFrom<sec1::EcPrivateKey<'_>> for SecretKey<C>
where
    C: AssociatedOid + Curve + ValidatePublicKey,
    FieldBytesSize<C>: ModulusSize,
{
    type Error = der::Error;

    /// Validate a parsed SEC1 `ECPrivateKey` and convert it into a [`SecretKey`].
    ///
    /// Checks run in this order:
    /// 1. A named-curve OID, if present, must equal `C::OID`.
    /// 2. The private key octets must be accepted by [`SecretKey::from_slice`].
    /// 3. An embedded public key, if present, must parse as a `Sec1Point` and pass
    ///    `C::validate_public_key` against the secret key.
    fn try_from(sec1_private_key: sec1::EcPrivateKey<'_>) -> der::Result<Self> {
        let sec1::EcPrivateKey {
            private_key,
            parameters,
            public_key,
        } = sec1_private_key;

        // Absent parameters are accepted; an explicit curve OID must match `C`.
        if let Some(sec1::EcParameters::NamedCurve(curve_oid)) = parameters {
            if C::OID != curve_oid {
                return Err(der::Tag::ObjectIdentifier.value_error().into());
            }
        }

        let secret_key =
            Self::from_slice(private_key).map_err(|_| der::Tag::OctetString.value_error())?;

        // Without an embedded public key there is nothing further to check.
        let Some(pk_bytes) = public_key else {
            return Ok(secret_key);
        };

        let pk = Sec1Point::<C>::from_bytes(pk_bytes)
            .map_err(|_| der::Tag::BitString.value_error())?;

        if C::validate_public_key(&secret_key, &pk).is_err() {
            return Err(der::Tag::BitString.value_error().into());
        }

        Ok(secret_key)
    }
}
#[cfg(feature = "arithmetic")]
impl<C> From<NonZeroScalar<C>> for SecretKey<C>
where
    C: CurveArithmetic,
{
    /// Build a secret key from an owned [`NonZeroScalar`] by delegating to the
    /// by-reference conversion.
    fn from(scalar: NonZeroScalar<C>) -> SecretKey<C> {
        let borrowed: &NonZeroScalar<C> = &scalar;
        borrowed.into()
    }
}
#[cfg(feature = "arithmetic")]
impl<C> From<&NonZeroScalar<C>> for SecretKey<C>
where
    C: CurveArithmetic,
{
    /// Build a secret key from a borrowed [`NonZeroScalar`]. The scalar is
    /// already non-zero, so no validation is needed.
    fn from(scalar: &NonZeroScalar<C>) -> SecretKey<C> {
        let inner: ScalarValue<C> = scalar.into();
        Self { inner }
    }
}