use std::fmt;
use aws_lc_rs::{rsa, signature::{self, ParsedPublicKey}};
use base64::{Engine as _, prelude::BASE64_URL_SAFE_NO_PAD};
use crate::{Alg, Error, Header, Jwk, SignatureValid, jwk};
/// A public key that can check JWT signatures.
///
/// Build one from a JSON Web Key with [`VerifyingKey::from_jwk`]. The concrete key material
/// (an elliptic/Edwards-curve key or an RSA key) is kept private.
#[derive(Clone, Debug)]
pub struct VerifyingKey(VerifyingKeyInner);
/// Key material backing a [`VerifyingKey`].
#[derive(Clone, Debug)]
enum VerifyingKeyInner {
    /// An ECDSA (P-256 / P-384 / P-521) or Ed25519 key, usable with exactly one JWS algorithm.
    Curve {
        /// Public key already parsed against its aws-lc verification algorithm.
        key: ParsedPublicKey,
        /// The single JWS algorithm this key accepts.
        alg: CurveAlg,
    },
    /// An RSA public key kept as its raw `n` / `e` components.
    Rsa {
        /// Modulus and exponent, base64url-decoded from the JWK.
        key: rsa::PublicKeyComponents<Vec<u8>>,
        /// Algorithm pinned by the JWK's `alg` field; `None` means any RSA algorithm is accepted.
        alg: Option<RsaAlg>,
    },
}
/// JWS algorithms that can be verified with an elliptic-curve or Edwards-curve key.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum CurveAlg {
    /// ECDSA over P-256 with SHA-256.
    ES256,
    /// ECDSA over P-384 with SHA-384.
    ES384,
    /// ECDSA over P-521 with SHA-512.
    ES512,
    /// EdDSA with an Ed25519 key.
    EdDSA,
}
impl CurveAlg {
fn from_alg(alg: &Alg<'_>) -> Option<Self> {
match alg {
Alg::ES256 => Some(Self::ES256),
Alg::ES384 => Some(Self::ES384),
Alg::ES512 => Some(Self::ES512),
Alg::EdDSA => Some(Self::EdDSA),
_ => None,
}
}
fn to_alg(self) -> Alg<'static> {
match self {
Self::ES256 => Alg::ES256,
Self::ES384 => Alg::ES384,
Self::ES512 => Alg::ES512,
Self::EdDSA => Alg::EdDSA,
}
}
}
/// JWS algorithms that can be verified with an RSA key.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum RsaAlg {
    /// RSASSA-PKCS1-v1_5 with SHA-256.
    RS256,
    /// RSASSA-PKCS1-v1_5 with SHA-384.
    RS384,
    /// RSASSA-PKCS1-v1_5 with SHA-512.
    RS512,
    /// RSASSA-PSS with SHA-256.
    PS256,
    /// RSASSA-PSS with SHA-384.
    PS384,
    /// RSASSA-PSS with SHA-512.
    PS512,
}
impl RsaAlg {
    /// Maps a JWS algorithm to its RSA counterpart, or `None` if it is not an RSA algorithm.
    fn from_alg(alg: &Alg<'_>) -> Option<Self> {
        let rsa = match alg {
            Alg::RS256 => Self::RS256,
            Alg::RS384 => Self::RS384,
            Alg::RS512 => Self::RS512,
            Alg::PS256 => Self::PS256,
            Alg::PS384 => Self::PS384,
            Alg::PS512 => Self::PS512,
            _ => return None,
        };
        Some(rsa)
    }
}
/// Returns early from the enclosing function with `Err(Error::InvalidJwk(..))`.
///
/// Accepts `format!`-style arguments for the error message.
macro_rules! bail {
    ($($fmt_args:tt)*) => {
        return Err(Error::InvalidJwk(format!($($fmt_args)*)))
    };
}
impl VerifyingKey {
pub fn from_jwk(jwk: &Jwk<'_>) -> Result<Self, Error> {
if jwk.usage.as_ref().is_some_and(|usage| *usage != jwk::KeyUsage::Signature) {
bail!("Field `use` of key is not 'sig'");
}
let inner = match &jwk.key_data {
jwk::KeyData::Ec { crv, x, y, .. } => {
let (alg, crypto_algo) = match crv {
jwk::EcCurve::P256 => (CurveAlg::ES256, &signature::ECDSA_P256_SHA256_FIXED),
jwk::EcCurve::P384 => (CurveAlg::ES384, &signature::ECDSA_P384_SHA384_FIXED),
jwk::EcCurve::P521 => (CurveAlg::ES512, &signature::ECDSA_P521_SHA512_FIXED),
_ => bail!("Curve type '{crv}' not supported"),
};
assert_alg(&jwk.alg, alg.to_alg(), crv)?;
let Some(y) = y else { bail!("curve is missing y coordinate"); };
let key = ParsedPublicKey::new(crypto_algo, ecdsa_sec1_key(x, y)?)?;
VerifyingKeyInner::Curve { alg, key }
}
jwk::KeyData::Okp { crv: crv @ jwk::OkpCurve::Ed25519, x, .. } => {
assert_alg(&jwk.alg, Alg::EdDSA, crv)?;
VerifyingKeyInner::Curve {
alg: CurveAlg::EdDSA,
key: ParsedPublicKey::new(
&aws_lc_rs::signature::ED25519,
BASE64_URL_SAFE_NO_PAD.decode(x.as_bytes()).map_err(base64_err)?,
)?,
}
}
jwk::KeyData::Okp { crv, .. } => bail!("Curve type '{crv}' not supported"),
jwk::KeyData::Rsa { n, e, .. } => {
let alg = match &jwk.alg {
Some(alg) => {
match RsaAlg::from_alg(alg) {
Some(alg) => Some(alg),
None => bail!("alg is '{alg}', but kty is 'RSA', which is incompatible"),
}
},
None => None,
};
let key = rsa::PublicKeyComponents {
n: BASE64_URL_SAFE_NO_PAD.decode(n.as_bytes()).map_err(base64_err)?,
e: BASE64_URL_SAFE_NO_PAD.decode(e.as_bytes()).map_err(base64_err)?,
};
VerifyingKeyInner::Rsa { key, alg }
},
jwk::KeyData::Oct { .. } => return Err(Error::UnsupportedAlg),
};
Ok(Self(inner))
}
pub fn supports_alg(&self, alg: &Alg<'_>) -> bool {
let query_alg = alg;
match &self.0 {
VerifyingKeyInner::Curve { alg, .. } => CurveAlg::from_alg(query_alg) == Some(*alg),
VerifyingKeyInner::Rsa { alg, .. } => {
let Some(query_alg) = RsaAlg::from_alg(query_alg) else { return false };
alg.is_none_or(|alg| alg == query_alg)
}
}
}
pub fn verify<E>(
&self,
header: &Header<'_, E>,
message: &str,
signature: &[u8],
) -> Result<SignatureValid, Error> {
if !self.supports_alg(&header.alg) {
return Err(Error::AlgoMismatch);
}
match &self.0 {
VerifyingKeyInner::Curve { key, .. } => {
key.verify_sig(message.as_bytes(), signature)
.map_err(|_| Error::InvalidSignature)
.map(|_| SignatureValid::unchecked_create_proof())
}
VerifyingKeyInner::Rsa { key, .. } => {
let Some(jwt_alg) = RsaAlg::from_alg(&header.alg) else {
unreachable!("`supports_alg` should have taken care of this");
};
let params = match jwt_alg {
RsaAlg::RS256 => &signature::RSA_PKCS1_2048_8192_SHA256,
RsaAlg::RS384 => &signature::RSA_PKCS1_2048_8192_SHA384,
RsaAlg::RS512 => &signature::RSA_PKCS1_2048_8192_SHA512,
RsaAlg::PS256 => &signature::RSA_PSS_2048_8192_SHA256,
RsaAlg::PS384 => &signature::RSA_PSS_2048_8192_SHA384,
RsaAlg::PS512 => &signature::RSA_PSS_2048_8192_SHA512,
};
key.verify(params, message.as_bytes(), signature)
.map_err(|_| Error::InvalidSignature)
.map(|_| SignatureValid::unchecked_create_proof())
}
}
}
}
fn base64_err(source: base64::DecodeError) -> Error {
Error::InvalidJwk(format!("invalid base64: {source}"))
}
fn ecdsa_sec1_key(x: &str, y: &str) -> Result<Vec<u8>, Error> {
let mut out = vec![4]; BASE64_URL_SAFE_NO_PAD.decode_vec(x, &mut out).map_err(base64_err)?;
BASE64_URL_SAFE_NO_PAD.decode_vec(y, &mut out).map_err(base64_err)?;
Ok(out)
}
/// Checks that a JWK's optional `alg` field agrees with the algorithm implied by its key type.
///
/// An absent `alg` always passes. `key_type` is used only to describe the key in the error
/// message.
///
/// # Errors
///
/// Returns [`Error::InvalidJwk`] when `actual` is present and differs from `expected`.
fn assert_alg(
    actual: &Option<Alg<'_>>,
    expected: Alg<'_>,
    key_type: impl fmt::Display,
) -> Result<(), Error> {
    let Some(actual) = actual else {
        return Ok(());
    };
    if *actual != expected {
        // Format the algorithm itself (Display), not the wrapping `Option`, so the message does
        // not contain a Debug-formatted `Some(..)`.
        bail!("key type {key_type} does not match 'alg' field '{actual}'");
    }
    Ok(())
}