use pkcs1::RsaPublicKey as AsnRsaPublicKey;
use rustls::crypto::WebPkiSupportedAlgorithms;
use rustls::pki_types::{AlgorithmIdentifier, InvalidSignature, SignatureVerificationAlgorithm};
use rustls::SignatureScheme;
use rustls_pki_types::alg_id;
use symcrypt::ecc::{CurveType, EcKey, EcKeyUsage};
use symcrypt::hash::{sha256, sha384, sha512, HashAlgorithm};
use symcrypt::rsa::{RsaKey, RsaKeyUsage};
/// Decodes a DER-encoded PKCS#1 `RSAPublicKey` into its two integers.
///
/// Returns `(modulus, exponent)`, each as the integer bytes exposed by the
/// `pkcs1` decoder's `as_bytes()` (DER integers are big-endian).
///
/// # Errors
///
/// Returns [`InvalidSignature`] if `pub_key` is not a decodable `RSAPublicKey`.
fn extract_rsa_public_key(pub_key: &[u8]) -> Result<(Vec<u8>, Vec<u8>), InvalidSignature> {
    match AsnRsaPublicKey::try_from(pub_key) {
        Ok(parsed) => Ok((
            parsed.modulus.as_bytes().to_vec(),
            parsed.public_exponent.as_bytes().to_vec(),
        )),
        Err(_) => Err(InvalidSignature),
    }
}
/// Returns the coordinate bytes of an uncompressed elliptic-curve point.
///
/// For the NIST curves, `pub_key` must begin with the `0x04` uncompressed-point
/// marker. The marker is removed and the remaining bytes are returned unchanged;
/// their length is not checked here. Curve25519 keys are never accepted.
///
/// # Errors
///
/// Returns [`InvalidSignature`] in two cases:
/// - `curve_type` is Curve25519;
/// - a NIST-curve key does not start with `0x04`.
fn extract_ecc_public_key(
    pub_key: &[u8],
    curve_type: CurveType,
) -> Result<Vec<u8>, InvalidSignature> {
    match curve_type {
        CurveType::NistP256 | CurveType::NistP384 | CurveType::NistP521 => pub_key
            .strip_prefix(&[0x04_u8])
            .map(|coordinates| coordinates.to_vec())
            .ok_or(InvalidSignature),
        // This provider only handles uncompressed NIST points.
        CurveType::Curve25519 => Err(InvalidSignature),
    }
}
/// Converts a DER-encoded ECDSA signature into the fixed-width `r || s` layout.
///
/// The input is decoded as a DER `SEQUENCE` of two `INTEGER`s. An ECDSA
/// `Ecdsa-Sig-Value` (`r`, `s`) has the same shape as a PKCS#1 `RSAPublicKey`
/// (`modulus`, `publicExponent`), so the `pkcs1` decoder is reused here:
/// - its `modulus` field carries `r`;
/// - its `public_exponent` field carries `s`.
///
/// Each component is left-padded with zero bytes to `curve.get_size()` bytes.
/// The two are then concatenated, so the result is `2 * curve.get_size()` bytes
/// long.
///
/// # Errors
///
/// Returns [`InvalidSignature`] if the DER cannot be decoded, or if either
/// component is wider than `curve.get_size()` bytes.
fn extract_ecc_signature(signature: &[u8], curve: CurveType) -> Result<Vec<u8>, InvalidSignature> {
    // Same DER shape as RSAPublicKey: SEQUENCE { INTEGER, INTEGER }.
    let der_pair = AsnRsaPublicKey::try_from(signature).map_err(|_| InvalidSignature)?;
    let width = curve.get_size() as usize;
    let r = der_pair.modulus.as_bytes();
    let s = der_pair.public_exponent.as_bytes();

    let fits = r.len() <= width && s.len() <= width;
    if !fits {
        return Err(InvalidSignature);
    }

    // Start from all zeros, then copy each component right-aligned into its half.
    let mut raw = vec![0u8; 2 * width];
    let (r_half, s_half) = raw.split_at_mut(width);
    r_half[width - r.len()..].copy_from_slice(r);
    s_half[width - s.len()..].copy_from_slice(s);
    Ok(raw)
}
/// Signature verification algorithms that this SymCrypt-backed provider exposes
/// to rustls.
///
/// - `all` lists every `SignatureVerificationAlgorithm` static defined in this
///   file: ECDSA over P-256/P-384/P-521, RSA PKCS#1 v1.5 and RSA-PSS.
/// - `mapping` pairs each TLS `SignatureScheme` with the algorithms that may be
///   used to check a signature made under that scheme.
///
/// No Ed25519 scheme is listed.
pub static SUPPORTED_SIG_ALGS: WebPkiSupportedAlgorithms = WebPkiSupportedAlgorithms {
    all: &[
        ECDSA_P256_SHA256,
        ECDSA_P256_SHA384,
        ECDSA_P384_SHA256,
        ECDSA_P384_SHA384,
        ECDSA_P521_SHA256,
        ECDSA_P521_SHA384,
        ECDSA_P521_SHA512,
        RSA_PKCS1_SHA256,
        RSA_PKCS1_SHA384,
        RSA_PKCS1_SHA512,
        RSA_PSS_SHA256,
        RSA_PSS_SHA384,
        RSA_PSS_SHA512,
    ],
    // Each ECDSA scheme below names a hash. The algorithm whose curve matches the
    // scheme is listed first, followed by the other curves using the same hash.
    // NOTE(review): rustls appears to use only the first entry when verifying
    // TLS 1.3 handshake signatures (where the scheme also fixes the curve), so
    // this ordering likely matters. Confirm against the rustls version in use.
    mapping: &[
        (
            SignatureScheme::ECDSA_NISTP384_SHA384,
            &[ECDSA_P384_SHA384, ECDSA_P256_SHA384, ECDSA_P521_SHA384],
        ),
        (
            SignatureScheme::ECDSA_NISTP256_SHA256,
            &[ECDSA_P256_SHA256, ECDSA_P384_SHA256, ECDSA_P521_SHA256],
        ),
        (SignatureScheme::ECDSA_NISTP521_SHA512, &[ECDSA_P521_SHA512]),
        // RSA schemes map one-to-one onto the algorithm with the same padding and hash.
        (SignatureScheme::RSA_PSS_SHA512, &[RSA_PSS_SHA512]),
        (SignatureScheme::RSA_PSS_SHA384, &[RSA_PSS_SHA384]),
        (SignatureScheme::RSA_PSS_SHA256, &[RSA_PSS_SHA256]),
        (SignatureScheme::RSA_PKCS1_SHA512, &[RSA_PKCS1_SHA512]),
        (SignatureScheme::RSA_PKCS1_SHA384, &[RSA_PKCS1_SHA384]),
        (SignatureScheme::RSA_PKCS1_SHA256, &[RSA_PKCS1_SHA256]),
    ],
};
/// Computes the SHA-256 digest of `data` and returns it as an owned byte vector.
///
/// Used as the `hasher` of the SHA-256 based algorithms.
fn hash_sha256(data: &[u8]) -> Vec<u8> {
    let digest = sha256(data);
    Vec::from(&digest[..])
}
/// Computes the SHA-384 digest of `data` and returns it as an owned byte vector.
///
/// Used as the `hasher` of the SHA-384 based algorithms.
fn hash_sha384(data: &[u8]) -> Vec<u8> {
    let digest = sha384(data);
    Vec::from(&digest[..])
}
/// Computes the SHA-512 digest of `data` and returns it as an owned byte vector.
///
/// Used as the `hasher` of the SHA-512 based algorithms.
fn hash_sha512(data: &[u8]) -> Vec<u8> {
    let digest = sha512(data);
    Vec::from(&digest[..])
}
/// Selects which SymCrypt verification path `verify_signature` takes, and carries
/// the parameters that path needs.
#[derive(Debug)]
enum KeyType {
    /// RSA with PKCS#1 v1.5 padding (`pkcs1_verify`).
    RsaPkcs1(RsaPkcs1),
    /// RSA with PSS padding (`pss_verify`).
    RsaPss(RsaPss),
    /// ECDSA (`ecdsa_verify`).
    Ecc(Ecc),
}
/// Parameters for RSA PKCS#1 v1.5 signature verification.
#[derive(Debug)]
struct RsaPkcs1 {
    /// Hash algorithm passed to `pkcs1_verify`. It must describe the same digest
    /// that the enclosing algorithm's `hasher` produces.
    hash_algorithm: HashAlgorithm,
}
/// Parameters for RSA-PSS signature verification.
#[derive(Debug)]
struct RsaPss {
    /// Hash algorithm passed to `pss_verify`. It must describe the same digest
    /// that the enclosing algorithm's `hasher` produces.
    hash_algorithm: HashAlgorithm,
    /// Salt length passed to `pss_verify` (converted to `usize` there). The
    /// statics in this file set it to the digest size: 32, 48 or 64.
    salt_length: u32,
}
/// Parameters for ECDSA signature verification.
#[derive(Debug)]
struct Ecc {
    /// Curve of the public key. It is used in three places:
    /// - unwrapping the public point;
    /// - choosing the width of each signature component;
    /// - importing the key into SymCrypt.
    curve: CurveType,
}
/// ECDSA verification for NIST P-256 keys with SHA-256 message digests.
pub static ECDSA_P256_SHA256: &dyn SignatureVerificationAlgorithm = &SymCryptAlgorithm {
    public_key_alg_id: alg_id::ECDSA_P256,
    signature_alg_id: alg_id::ECDSA_SHA256,
    hasher: hash_sha256,
    key_type: KeyType::Ecc(Ecc {
        curve: CurveType::NistP256,
    }),
};
/// ECDSA verification for NIST P-256 keys with SHA-384 message digests.
pub static ECDSA_P256_SHA384: &dyn SignatureVerificationAlgorithm = &SymCryptAlgorithm {
    public_key_alg_id: alg_id::ECDSA_P256,
    signature_alg_id: alg_id::ECDSA_SHA384,
    hasher: hash_sha384,
    key_type: KeyType::Ecc(Ecc {
        curve: CurveType::NistP256,
    }),
};
/// ECDSA verification for NIST P-384 keys with SHA-256 message digests.
pub static ECDSA_P384_SHA256: &dyn SignatureVerificationAlgorithm = &SymCryptAlgorithm {
    public_key_alg_id: alg_id::ECDSA_P384,
    signature_alg_id: alg_id::ECDSA_SHA256,
    hasher: hash_sha256,
    key_type: KeyType::Ecc(Ecc {
        curve: CurveType::NistP384,
    }),
};
/// ECDSA verification for NIST P-384 keys with SHA-384 message digests.
pub static ECDSA_P384_SHA384: &dyn SignatureVerificationAlgorithm = &SymCryptAlgorithm {
    public_key_alg_id: alg_id::ECDSA_P384,
    signature_alg_id: alg_id::ECDSA_SHA384,
    hasher: hash_sha384,
    key_type: KeyType::Ecc(Ecc {
        curve: CurveType::NistP384,
    }),
};
/// ECDSA verification for NIST P-521 keys with SHA-256 message digests.
pub static ECDSA_P521_SHA256: &dyn SignatureVerificationAlgorithm = &SymCryptAlgorithm {
    public_key_alg_id: alg_id::ECDSA_P521,
    signature_alg_id: alg_id::ECDSA_SHA256,
    hasher: hash_sha256,
    key_type: KeyType::Ecc(Ecc {
        curve: CurveType::NistP521,
    }),
};
/// ECDSA verification for NIST P-521 keys with SHA-384 message digests.
pub static ECDSA_P521_SHA384: &dyn SignatureVerificationAlgorithm = &SymCryptAlgorithm {
    public_key_alg_id: alg_id::ECDSA_P521,
    signature_alg_id: alg_id::ECDSA_SHA384,
    hasher: hash_sha384,
    key_type: KeyType::Ecc(Ecc {
        curve: CurveType::NistP521,
    }),
};
/// ECDSA verification for NIST P-521 keys with SHA-512 message digests.
pub static ECDSA_P521_SHA512: &dyn SignatureVerificationAlgorithm = &SymCryptAlgorithm {
    public_key_alg_id: alg_id::ECDSA_P521,
    signature_alg_id: alg_id::ECDSA_SHA512,
    hasher: hash_sha512,
    key_type: KeyType::Ecc(Ecc {
        curve: CurveType::NistP521,
    }),
};
/// RSA PKCS#1 v1.5 signature verification with SHA-256, for `rsaEncryption` keys.
pub static RSA_PKCS1_SHA256: &dyn SignatureVerificationAlgorithm = &SymCryptAlgorithm {
    public_key_alg_id: alg_id::RSA_ENCRYPTION,
    signature_alg_id: alg_id::RSA_PKCS1_SHA256,
    hasher: hash_sha256,
    // Must agree with `hasher` above.
    key_type: KeyType::RsaPkcs1(RsaPkcs1 {
        hash_algorithm: HashAlgorithm::Sha256,
    }),
};
/// RSA PKCS#1 v1.5 signature verification with SHA-384, for `rsaEncryption` keys.
pub static RSA_PKCS1_SHA384: &dyn SignatureVerificationAlgorithm = &SymCryptAlgorithm {
    public_key_alg_id: alg_id::RSA_ENCRYPTION,
    signature_alg_id: alg_id::RSA_PKCS1_SHA384,
    hasher: hash_sha384,
    // Must agree with `hasher` above.
    key_type: KeyType::RsaPkcs1(RsaPkcs1 {
        hash_algorithm: HashAlgorithm::Sha384,
    }),
};
/// RSA PKCS#1 v1.5 signature verification with SHA-512, for `rsaEncryption` keys.
pub static RSA_PKCS1_SHA512: &dyn SignatureVerificationAlgorithm = &SymCryptAlgorithm {
    public_key_alg_id: alg_id::RSA_ENCRYPTION,
    signature_alg_id: alg_id::RSA_PKCS1_SHA512,
    hasher: hash_sha512,
    // Must agree with `hasher` above.
    key_type: KeyType::RsaPkcs1(RsaPkcs1 {
        hash_algorithm: HashAlgorithm::Sha512,
    }),
};
/// RSA-PSS signature verification with SHA-256, for `rsaEncryption` keys.
/// The salt length is 32, which equals the SHA-256 digest size.
pub static RSA_PSS_SHA256: &dyn SignatureVerificationAlgorithm = &SymCryptAlgorithm {
    public_key_alg_id: alg_id::RSA_ENCRYPTION,
    signature_alg_id: alg_id::RSA_PSS_SHA256,
    hasher: hash_sha256,
    key_type: KeyType::RsaPss(RsaPss {
        hash_algorithm: HashAlgorithm::Sha256,
        salt_length: 32,
    }),
};
/// RSA-PSS signature verification with SHA-384, for `rsaEncryption` keys.
/// The salt length is 48, which equals the SHA-384 digest size.
pub static RSA_PSS_SHA384: &dyn SignatureVerificationAlgorithm = &SymCryptAlgorithm {
    public_key_alg_id: alg_id::RSA_ENCRYPTION,
    signature_alg_id: alg_id::RSA_PSS_SHA384,
    hasher: hash_sha384,
    key_type: KeyType::RsaPss(RsaPss {
        hash_algorithm: HashAlgorithm::Sha384,
        salt_length: 48,
    }),
};
/// RSA-PSS signature verification with SHA-512, for `rsaEncryption` keys.
/// The salt length is 64, which equals the SHA-512 digest size.
pub static RSA_PSS_SHA512: &dyn SignatureVerificationAlgorithm = &SymCryptAlgorithm {
    public_key_alg_id: alg_id::RSA_ENCRYPTION,
    signature_alg_id: alg_id::RSA_PSS_SHA512,
    hasher: hash_sha512,
    key_type: KeyType::RsaPss(RsaPss {
        hash_algorithm: HashAlgorithm::Sha512,
        salt_length: 64,
    }),
};
/// One signature verification algorithm: a public-key type, a signature type,
/// a digest function and a SymCrypt verification path.
#[derive(Debug)]
struct SymCryptAlgorithm {
    /// Identifier returned from `public_key_alg_id()`.
    public_key_alg_id: AlgorithmIdentifier,
    /// Identifier returned from `signature_alg_id()`.
    signature_alg_id: AlgorithmIdentifier,
    /// Digests the message before it is handed to SymCrypt for verification.
    hasher: fn(&[u8]) -> Vec<u8>,
    /// Which verification path to use and its parameters.
    key_type: KeyType,
}
impl SignatureVerificationAlgorithm for SymCryptAlgorithm {
fn verify_signature(
&self,
public_key: &[u8],
message: &[u8],
signature: &[u8],
) -> Result<(), InvalidSignature> {
match &self.key_type {
KeyType::Ecc(ecc) => {
let key = extract_ecc_public_key(public_key, ecc.curve)?;
let sig = extract_ecc_signature(signature, ecc.curve)?;
let ec_key = EcKey::set_public_key(ecc.curve, &key, EcKeyUsage::EcDsa)
.map_err(|_| InvalidSignature)?;
let hashed_message = (self.hasher)(message);
ec_key
.ecdsa_verify(&sig, &hashed_message)
.map_err(|_| InvalidSignature)
}
KeyType::RsaPkcs1(rsa_pkcs1) => {
let (modulus, exponent) = extract_rsa_public_key(public_key)?;
let rsa_key = RsaKey::set_public_key(&modulus, &exponent, RsaKeyUsage::Sign)
.map_err(|_| InvalidSignature)?;
let hashed_message = (self.hasher)(message);
rsa_key
.pkcs1_verify(&hashed_message, signature, rsa_pkcs1.hash_algorithm)
.map_err(|_| InvalidSignature)
}
KeyType::RsaPss(rsa_pss) => {
let (modulus, exponent) = extract_rsa_public_key(public_key)?;
let rsa_key = RsaKey::set_public_key(&modulus, &exponent, RsaKeyUsage::Sign)
.map_err(|_| InvalidSignature)?;
let hashed_message = (self.hasher)(message);
rsa_key
.pss_verify(
&hashed_message,
signature,
rsa_pss.hash_algorithm,
rsa_pss.salt_length as usize,
)
.map_err(|_| InvalidSignature)
}
}
}
fn public_key_alg_id(&self) -> AlgorithmIdentifier {
self.public_key_alg_id
}
fn signature_alg_id(&self) -> AlgorithmIdentifier {
self.signature_alg_id
}
fn fips(&self) -> bool {
true
}
}