extern crate alloc;
use alloc::vec::Vec;
#[cfg(feature = "std")]
use tee_crypto::rsa::RsaKeypair;
use tee_crypto::rsa::RsaPublic;
use super::hash::MdType;
use super::map::map_tee_err;
#[cfg(feature = "std")]
use super::rng::CryptoRng;
use super::rsa_ecdsa;
use super::sm2_raw;
use crate::crypto::CryptoError;
/// Maximum length constant for ECDSA data (256).
///
/// NOTE(review): the unit is not established in this file — presumably a byte
/// count used to size ECDSA signature buffers; confirm against callers.
pub const ECDSA_MAX_LEN: usize = 256;
/// Algorithm family of a key, as reported by `Pk::pk_type`.
///
/// The same value is reported for both the signing and the verification
/// form of a key.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum PkType {
    /// SM2 keys (used with `MdType::SM3` in this module).
    SM2,
    /// RSA keys (used with `MdType::Sha256` in this module).
    Rsa,
    /// Generic elliptic-curve key kind; `Pk::pk_type` in this file never returns it.
    Eckey,
    /// ECDSA keys (used with `MdType::Sha256` in this module).
    Ecdsa,
}
/// Concrete key material held by a `Pk`, one variant per (algorithm, role).
///
/// Verification variants are always compiled in. Signing variants carry
/// private material and only exist when the `std` feature is enabled.
enum PkInner {
    /// SM2 verification key, stored as SEC1-encoded point bytes.
    Sm2Verify { sec1: Vec<u8> },
    /// RSA verification key.
    RsaVerify(RsaPublic),
    /// ECDSA verification key, given as its `x` and `y` coordinate bytes.
    EcdsaVerify { x: Vec<u8>, y: Vec<u8> },
    /// SM2 signing key: the 32-byte private scalar.
    #[cfg(feature = "std")]
    Sm2Sign { scalar: [u8; 32] },
    /// RSA signing key pair.
    #[cfg(feature = "std")]
    RsaSign(RsaKeypair),
    /// ECDSA signing key: the 32-byte private scalar.
    #[cfg(feature = "std")]
    EcdsaSign { scalar: [u8; 32] },
}
/// Opaque key handle used for signing and signature verification.
///
/// The actual key material lives in the private `inner` field, so it can only
/// be reached through `Pk`'s methods.
pub struct Pk {
    inner: PkInner,
}
impl Pk {
    /// Returns the algorithm family of this key.
    ///
    /// Signing and verification keys of the same algorithm report the same
    /// [`PkType`]; this method never returns [`PkType::Eckey`].
    pub fn pk_type(&self) -> PkType {
        match &self.inner {
            #[cfg(feature = "std")]
            PkInner::Sm2Sign { .. } => PkType::SM2,
            PkInner::Sm2Verify { .. } => PkType::SM2,
            #[cfg(feature = "std")]
            PkInner::RsaSign(_) => PkType::Rsa,
            PkInner::RsaVerify(_) => PkType::Rsa,
            #[cfg(feature = "std")]
            PkInner::EcdsaSign { .. } => PkType::Ecdsa,
            PkInner::EcdsaVerify { .. } => PkType::Ecdsa,
        }
    }

    /// Returns the key length.
    ///
    /// SM2 and ECDSA keys always report 256. RSA keys report whatever
    /// `rsa_ecdsa::rsa_key_bits` / `rsa_ecdsa::rsa_pub_key_bits` return
    /// (a bit count, judging by the helper names).
    pub fn len(&self) -> usize {
        match &self.inner {
            #[cfg(feature = "std")]
            PkInner::Sm2Sign { .. } | PkInner::EcdsaSign { .. } => 256,
            PkInner::Sm2Verify { .. } | PkInner::EcdsaVerify { .. } => 256,
            #[cfg(feature = "std")]
            PkInner::RsaSign(k) => rsa_ecdsa::rsa_key_bits(k),
            PkInner::RsaVerify(k) => rsa_ecdsa::rsa_pub_key_bits(k),
        }
    }

    /// Signs the whole message `msg` (not a digest) with an SM2 signing key.
    ///
    /// The signature produced by `sm2_raw::sm2_sign_message` is copied to the
    /// start of `sig`, and the number of bytes written is returned.
    ///
    /// # Errors
    /// - `CryptoError::UnsupportedAlgorithm` if `md` is not `MdType::SM3`.
    /// - `CryptoError::InvalidKey` if this is not an SM2 signing key.
    /// - `CryptoError::InvalidLength` if `sig` is shorter than the signature.
    /// - Any error propagated from `sm2_raw::sm2_sign_message`.
    #[cfg(feature = "std")]
    pub fn sm2_sign(
        &mut self,
        md: MdType,
        msg: &[u8],
        sig: &mut [u8],
        rng: &mut dyn CryptoRng,
    ) -> Result<usize, CryptoError> {
        if md != MdType::SM3 {
            return Err(CryptoError::UnsupportedAlgorithm);
        }
        // Borrow the private scalar rather than copying it out by value: a copy
        // would leave an extra, never-wiped instance of the secret on the stack.
        let scalar = match &self.inner {
            PkInner::Sm2Sign { scalar } => scalar,
            _ => return Err(CryptoError::InvalidKey),
        };
        let der = sm2_raw::sm2_sign_message(scalar, msg, rng)?;
        if der.len() > sig.len() {
            return Err(CryptoError::InvalidLength);
        }
        sig[..der.len()].copy_from_slice(&der);
        Ok(der.len())
    }

    /// Verifies an SM2 signature `sig` over the whole message `msg`.
    ///
    /// # Errors
    /// - `CryptoError::UnsupportedAlgorithm` if `md` is not `MdType::SM3`.
    /// - `CryptoError::InvalidKey` if this is not an SM2 verification key.
    /// - Any error from `sm2_raw::sm2_verify_message_sec1` (including a bad signature).
    pub fn sm2_verify(&mut self, md: MdType, msg: &[u8], sig: &[u8]) -> Result<(), CryptoError> {
        if md != MdType::SM3 {
            return Err(CryptoError::UnsupportedAlgorithm);
        }
        let sec1 = match &self.inner {
            PkInner::Sm2Verify { sec1 } => sec1.as_slice(),
            _ => return Err(CryptoError::InvalidKey),
        };
        sm2_raw::sm2_verify_message_sec1(sec1, msg, sig)
    }

    /// Signs a precomputed 32-byte `digest` with this signing key.
    ///
    /// The digest algorithm must match the key: `MdType::SM3` for SM2, and
    /// `MdType::Sha256` for RSA and ECDSA. The signature is copied to the start
    /// of `sig`, and the number of bytes written is returned.
    ///
    /// # Errors
    /// - `CryptoError::InvalidLength` if `digest` is not exactly 32 bytes, or if
    ///   `sig` is shorter than the produced signature.
    /// - `CryptoError::InvalidKey` if this is a verification-only key.
    /// - `CryptoError::UnsupportedAlgorithm` if `md` does not match the key type.
    /// - Any error propagated from the underlying signing backend.
    #[cfg(feature = "std")]
    pub fn sign(
        &mut self,
        md: MdType,
        digest: &[u8],
        sig: &mut [u8],
        rng: &mut dyn CryptoRng,
    ) -> Result<usize, CryptoError> {
        let digest: [u8; 32] = digest
            .try_into()
            .map_err(|_| CryptoError::InvalidLength)?;
        let out = match &self.inner {
            PkInner::Sm2Sign { scalar } => {
                if md != MdType::SM3 {
                    return Err(CryptoError::UnsupportedAlgorithm);
                }
                sm2_raw::sm2_sign_digest(scalar, &digest, rng)?
            }
            PkInner::RsaSign(key) => {
                if md != MdType::Sha256 {
                    return Err(CryptoError::UnsupportedAlgorithm);
                }
                rsa_ecdsa::rsa_sign_digest(key, &digest, rng)?
            }
            PkInner::EcdsaSign { scalar } => {
                if md != MdType::Sha256 {
                    return Err(CryptoError::UnsupportedAlgorithm);
                }
                rsa_ecdsa::ecdsa_sign_digest(scalar, &digest, rng)?
            }
            // Verification keys cannot sign. They are listed explicitly rather than
            // via `_` so that adding a new key variant forces this match to be revisited.
            PkInner::Sm2Verify { .. } | PkInner::RsaVerify(_) | PkInner::EcdsaVerify { .. } => {
                return Err(CryptoError::InvalidKey);
            }
        };
        if out.len() > sig.len() {
            return Err(CryptoError::InvalidLength);
        }
        sig[..out.len()].copy_from_slice(&out);
        Ok(out.len())
    }

    /// Verifies `sig` over a precomputed 32-byte `digest` with this verification key.
    ///
    /// The digest algorithm must match the key: `MdType::SM3` for SM2, and
    /// `MdType::Sha256` for RSA and ECDSA.
    ///
    /// # Errors
    /// - `CryptoError::InvalidLength` if `digest` is not exactly 32 bytes.
    /// - `CryptoError::InvalidKey` if this is a signing key.
    /// - `CryptoError::UnsupportedAlgorithm` if `md` does not match the key type.
    /// - Any error from the underlying verifier (including a bad signature).
    pub fn verify(&mut self, md: MdType, digest: &[u8], sig: &[u8]) -> Result<(), CryptoError> {
        let digest: [u8; 32] = digest
            .try_into()
            .map_err(|_| CryptoError::InvalidLength)?;
        match &self.inner {
            PkInner::Sm2Verify { sec1 } => {
                if md != MdType::SM3 {
                    return Err(CryptoError::UnsupportedAlgorithm);
                }
                tee_crypto::sm2::sm2_verify_digest_sec1(sec1, &digest, sig).map_err(map_tee_err)
            }
            PkInner::RsaVerify(key) => {
                if md != MdType::Sha256 {
                    return Err(CryptoError::UnsupportedAlgorithm);
                }
                rsa_ecdsa::rsa_verify_digest(key, &digest, sig)
            }
            PkInner::EcdsaVerify { x, y } => {
                if md != MdType::Sha256 {
                    return Err(CryptoError::UnsupportedAlgorithm);
                }
                rsa_ecdsa::ecdsa_verify_digest(x, y, &digest, sig)
            }
            // Signing keys cannot verify. They are listed explicitly rather than via `_`
            // so that new variants must be handled. The arm only exists when the
            // std-gated signing variants do.
            #[cfg(feature = "std")]
            PkInner::Sm2Sign { .. } | PkInner::RsaSign(_) | PkInner::EcdsaSign { .. } => {
                Err(CryptoError::InvalidKey)
            }
        }
    }
}
impl Pk {
#[cfg(feature = "std")]
pub(crate) fn from_signing_scalar(scalar: [u8; 32]) -> Self {
Self {
inner: PkInner::Sm2Sign { scalar },
}
}
pub(crate) fn from_sm2_sec1(sec1: Vec<u8>) -> Self {
Self {
inner: PkInner::Sm2Verify { sec1 },
}
}
#[cfg(feature = "std")]
pub(crate) fn from_rsa_sign(key: RsaKeypair) -> Self {
Self {
inner: PkInner::RsaSign(key),
}
}
pub(crate) fn from_rsa_verify(key: RsaPublic) -> Self {
Self {
inner: PkInner::RsaVerify(key),
}
}
#[cfg(feature = "std")]
pub(crate) fn from_ecdsa_sign(scalar: [u8; 32]) -> Self {
Self {
inner: PkInner::EcdsaSign { scalar },
}
}
pub(crate) fn from_ecdsa_verify(x: Vec<u8>, y: Vec<u8>) -> Self {
Self {
inner: PkInner::EcdsaVerify { x, y },
}
}
}