jwtea 0.1.0

Lean JWT library
Documentation
use std::fmt;

use aws_lc_rs::{rsa, signature::{self, ParsedPublicKey}};
use base64::{Engine as _, prelude::BASE64_URL_SAFE_NO_PAD};

use crate::{Alg, Error, Header, Jwk, SignatureValid, jwk};


/// A cryptographic key for verifying signatures.
///
/// Construct one with [`VerifyingKey::from_jwk`] and check signatures with
/// [`VerifyingKey::verify`].
#[derive(Debug, Clone)]
pub struct VerifyingKey(VerifyingKeyInner);

/// The key material behind a [`VerifyingKey`], split by key family.
#[derive(Debug, Clone)]
enum VerifyingKeyInner {
    /// An EC (P-256/P-384/P-521) or Ed25519 public key. The curve fixes the
    /// single JWT `alg` this key can verify.
    Curve {
        /// Public key already bound to the verification algorithm that was
        /// chosen from the curve.
        key: ParsedPublicKey,
        /// The only JWT `alg` this key accepts.
        alg: CurveAlg,
    },
    /// An RSA public key, kept as the base64url-decoded `n` (modulus) and `e`
    /// (exponent) members of the JWK.
    Rsa {
        key: rsa::PublicKeyComponents<Vec<u8>>,

        /// The `alg` field of the JWK, which might be unspecified. If it is,
        /// this key can in theory be used for JWTs with different RSA `alg`
        /// values.
        alg: Option<RsaAlg>,
    }
}

/// The JWT signature algorithms backed by a curve key: ECDSA on the NIST
/// curves (`ES256`, `ES384`, `ES512`) and `EdDSA`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum CurveAlg {
    ES256,
    ES384,
    ES512,
    EdDSA,
}

impl CurveAlg {
    /// Maps a JWT `alg` onto the matching curve algorithm, or `None` when
    /// `alg` is not curve-based.
    fn from_alg(alg: &Alg<'_>) -> Option<Self> {
        let curve = match alg {
            Alg::ES256 => Self::ES256,
            Alg::ES384 => Self::ES384,
            Alg::ES512 => Self::ES512,
            Alg::EdDSA => Self::EdDSA,
            _ => return None,
        };
        Some(curve)
    }

    /// Converts back into the JWT `alg` value; inverse of [`Self::from_alg`].
    fn to_alg(self) -> Alg<'static> {
        match self {
            CurveAlg::ES256 => Alg::ES256,
            CurveAlg::ES384 => Alg::ES384,
            CurveAlg::ES512 => Alg::ES512,
            CurveAlg::EdDSA => Alg::EdDSA,
        }
    }
}

/// The RSA-based JWT signature algorithms: PKCS#1 v1.5 (`RS*`) and
/// PSS (`PS*`), each with SHA-256/384/512.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum RsaAlg {
    RS256,
    RS384,
    RS512,
    PS256,
    PS384,
    PS512,
}

impl RsaAlg {
    /// Maps a JWT `alg` onto the matching RSA algorithm, or `None` when
    /// `alg` is not RSA-based.
    fn from_alg(alg: &Alg<'_>) -> Option<Self> {
        let rsa = match alg {
            Alg::RS256 => Self::RS256,
            Alg::RS384 => Self::RS384,
            Alg::RS512 => Self::RS512,
            Alg::PS256 => Self::PS256,
            Alg::PS384 => Self::PS384,
            Alg::PS512 => Self::PS512,
            _ => return None,
        };
        Some(rsa)
    }
}


/// Returns early from the enclosing function with
/// `Err(Error::InvalidJwk(..))`, whose message is built from the given
/// `format!`-style arguments.
macro_rules! bail {
    ($($fmt_args:tt)*) => {
        return ::core::result::Result::Err(
            $crate::Error::InvalidJwk(::std::format!($($fmt_args)*))
        )
    };
}

impl VerifyingKey {
    /// Builds a verifying key out of a JWK.
    ///
    /// The JWK is checked first: a present `use` member must be `sig`, and a
    /// present `alg` member must be compatible with the key type. Any private
    /// key members of the JWK are not read.
    ///
    /// Supported keys are EC keys on P-256, P-384 and P-521 (verifying
    /// `ES256`, `ES384` and `ES512` respectively), OKP keys on Ed25519
    /// (`EdDSA`), and RSA keys (`RS*` / `PS*`).
    ///
    /// # Errors
    ///
    /// - `Error::InvalidJwk` if a check fails, the curve is unsupported, an
    ///   EC key has no `y` coordinate, or key material is not valid base64url.
    /// - `Error::UnsupportedAlg` for symmetric (`oct`) keys.
    /// - Failures of `ParsedPublicKey::new` are propagated via `?`.
    pub fn from_jwk(jwk: &Jwk<'_>) -> Result<Self, Error> {
        let wrong_usage =
            matches!(&jwk.usage, Some(usage) if *usage != jwk::KeyUsage::Signature);
        if wrong_usage {
            bail!("Field `use` of key is not 'sig'");
        }

        let parsed = match &jwk.key_data {
            jwk::KeyData::Ec { crv, x, y, .. } => {
                // The curve alone decides which JWT `alg` the key can verify.
                let (curve_alg, ecdsa_alg) = match crv {
                    jwk::EcCurve::P256 => (CurveAlg::ES256, &signature::ECDSA_P256_SHA256_FIXED),
                    jwk::EcCurve::P384 => (CurveAlg::ES384, &signature::ECDSA_P384_SHA384_FIXED),
                    jwk::EcCurve::P521 => (CurveAlg::ES512, &signature::ECDSA_P521_SHA512_FIXED),
                    _ => bail!("Curve type '{crv}' not supported"),
                };
                assert_alg(&jwk.alg, curve_alg.to_alg(), crv)?;

                let Some(y) = y else {
                    bail!("curve is missing y coordinate");
                };
                let sec1_point = ecdsa_sec1_key(x, y)?;

                VerifyingKeyInner::Curve {
                    alg: curve_alg,
                    key: ParsedPublicKey::new(ecdsa_alg, sec1_point)?,
                }
            }

            jwk::KeyData::Okp { crv: crv @ jwk::OkpCurve::Ed25519, x, .. } => {
                assert_alg(&jwk.alg, Alg::EdDSA, crv)?;
                let raw_key = BASE64_URL_SAFE_NO_PAD.decode(x.as_bytes()).map_err(base64_err)?;

                VerifyingKeyInner::Curve {
                    alg: CurveAlg::EdDSA,
                    key: ParsedPublicKey::new(&aws_lc_rs::signature::ED25519, raw_key)?,
                }
            }
            jwk::KeyData::Okp { crv, .. } => bail!("Curve type '{crv}' not supported"),

            jwk::KeyData::Rsa { n, e, .. } => {
                // Without an `alg` member the key stays usable for every RSA alg.
                let alg = match &jwk.alg {
                    None => None,
                    Some(declared) => {
                        let Some(rsa_alg) = RsaAlg::from_alg(declared) else {
                            bail!("alg is '{declared}', but kty is 'RSA', which is incompatible");
                        };
                        Some(rsa_alg)
                    }
                };

                let modulus = BASE64_URL_SAFE_NO_PAD.decode(n.as_bytes()).map_err(base64_err)?;
                let exponent = BASE64_URL_SAFE_NO_PAD.decode(e.as_bytes()).map_err(base64_err)?;

                VerifyingKeyInner::Rsa {
                    key: rsa::PublicKeyComponents { n: modulus, e: exponent },
                    alg,
                }
            }

            jwk::KeyData::Oct { .. } => return Err(Error::UnsupportedAlg),
        };

        Ok(Self(parsed))
    }

    /// Reports whether this key is able to verify signatures made with `alg`.
    ///
    /// A curve key (`ES*` / `EdDSA`) accepts exactly one algorithm. An RSA
    /// key accepts any `RS*` / `PS*` algorithm, unless its JWK pinned a
    /// specific one via the `alg` member, in which case only that one.
    pub fn supports_alg(&self, alg: &Alg<'_>) -> bool {
        match &self.0 {
            VerifyingKeyInner::Curve { alg: key_alg, .. } => {
                CurveAlg::from_alg(alg).is_some_and(|requested| requested == *key_alg)
            }
            VerifyingKeyInner::Rsa { alg: key_alg, .. } => match RsaAlg::from_alg(alg) {
                Some(requested) => key_alg.is_none_or(|pinned| pinned == requested),
                None => false,
            },
        }
    }

    /// Checks `signature` over `message` using the algorithm in `header.alg`.
    ///
    /// # Errors
    ///
    /// - `Error::AlgoMismatch` if this key does not support `header.alg`
    ///   (see [`Self::supports_alg`]).
    /// - `Error::InvalidSignature` if the cryptographic check fails.
    pub fn verify<E>(
        &self,
        header: &Header<'_, E>,
        message: &str,
        signature: &[u8],
    ) -> Result<SignatureValid, Error> {
        let supported = self.supports_alg(&header.alg);
        if !supported {
            return Err(Error::AlgoMismatch);
        }

        let signed_bytes = message.as_bytes();
        match &self.0 {
            VerifyingKeyInner::Curve { key, .. } => {
                // The key was parsed with its curve's algorithm, so no
                // per-call algorithm selection is needed.
                key.verify_sig(signed_bytes, signature)
                    .map_err(|_| Error::InvalidSignature)?;
            }

            VerifyingKeyInner::Rsa { key, .. } => {
                let Some(jwt_alg) = RsaAlg::from_alg(&header.alg) else {
                    unreachable!("`supports_alg` should have taken care of this");
                };

                let params = match jwt_alg {
                    RsaAlg::RS256 => &signature::RSA_PKCS1_2048_8192_SHA256,
                    RsaAlg::RS384 => &signature::RSA_PKCS1_2048_8192_SHA384,
                    RsaAlg::RS512 => &signature::RSA_PKCS1_2048_8192_SHA512,
                    RsaAlg::PS256 => &signature::RSA_PSS_2048_8192_SHA256,
                    RsaAlg::PS384 => &signature::RSA_PSS_2048_8192_SHA384,
                    RsaAlg::PS512 => &signature::RSA_PSS_2048_8192_SHA512,
                };
                key.verify(params, signed_bytes, signature)
                    .map_err(|_| Error::InvalidSignature)?;
            }
        }

        Ok(SignatureValid::unchecked_create_proof())
    }
}

fn base64_err(source: base64::DecodeError) -> Error {
    Error::InvalidJwk(format!("invalid base64: {source}"))
}

/// Returns the SEC1 uncompressed form of the ECDSA key specified by the
/// two base64 x and y coordinates.
fn ecdsa_sec1_key(x: &str, y: &str) -> Result<Vec<u8>, Error> {
    let mut out = vec![4]; // Header for SEC1 uncompressed form
    BASE64_URL_SAFE_NO_PAD.decode_vec(x, &mut out).map_err(base64_err)?;
    BASE64_URL_SAFE_NO_PAD.decode_vec(y, &mut out).map_err(base64_err)?;
    Ok(out)
}

/// Makes sure the JWK's `alg` member, if present, is the one dictated by the
/// key type.
///
/// `key_type` is only used in the error message (callers pass the curve).
///
/// # Errors
///
/// Returns `Error::InvalidJwk` if `actual` is `Some` and differs from
/// `expected`. A missing `alg` (`None`) is always accepted.
fn assert_alg(
    actual: &Option<Alg<'_>>,
    expected: Alg<'_>,
    key_type: impl fmt::Display,
) -> Result<(), Error> {
    match actual {
        // Show the offending alg quoted via `Display` (like the RSA alg error
        // in `from_jwk`) rather than the `Debug` form of the `Option`,
        // which would read `Some(..)`.
        Some(actual) if *actual != expected => {
            bail!("key type {key_type} does not match 'alg' field '{actual}'")
        }
        _ => Ok(()),
    }
}