// jwtk 0.2.4 (documentation source listing)
//
// JWT signing (JWS) and verification, with first-class JWK and JWK Set (JWKS) support.
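//! JWK and JWK Set.
//!
//! Public keys are fully supported. Private keys can be loaded from a JWK for
//! signing (see [`Jwk::to_signing_key`]), but private key support is otherwise
//! limited for now.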

use std::collections::{BTreeMap, HashMap};

use crate::{
    ecdsa::{EcdsaAlgorithm, EcdsaPrivateKey, EcdsaPublicKey},
    eddsa::{Ed25519PrivateKey, Ed25519PublicKey},
    rsa::{RsaAlgorithm, RsaPrivateKey, RsaPublicKey},
    some::SomePublicKey,
    url_safe_trailing_bits, verify, verify_only, Error, Header, HeaderAndClaims, PublicKeyToJwk,
    Result, SigningKey, SomePrivateKey, VerificationKey,
};
use openssl::{
    bn::BigNum,
    hash::{hash, MessageDigest},
    pkey::PKey,
    rsa::{Rsa, RsaPrivateKeyBuilder},
};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use serde_json::Value;

// TODO: private key jwk.

/// JWK Representation.
///
/// Field names follow the JWK parameter names (RFC 7517 / RFC 7518). Key
/// material members hold base64url-encoded strings exactly as they appear in
/// the JSON. Optional members are omitted when serializing if they are `None`
/// (or empty, for the list members).
#[non_exhaustive]
#[derive(Debug, Deserialize, Serialize, Default)]
pub struct Jwk {
    /// Key type, e.g. `"RSA"`, `"EC"` or `"OKP"`.
    pub kty: String,
    /// Intended key use (`use` in JSON). Verification only accepts `None` or `"sig"`.
    #[serde(rename = "use", skip_serializing_if = "Option::is_none")]
    pub use_: Option<String>,
    /// Permitted key operations. If non-empty, verification requires `"verify"`.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub key_ops: Vec<String>,
    /// Algorithm name. For RSA keys it restricts the key to that algorithm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub alg: Option<String>,
    /// Curve name, used by `EC` and `OKP` keys.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub crv: Option<String>,
    /// Key id.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub kid: Option<String>,

    /// RSA modulus.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<String>,
    /// RSA public exponent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub e: Option<String>,
    /// EC x coordinate, or the public key bytes for `OKP` keys.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub x: Option<String>,
    /// EC y coordinate.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub y: Option<String>,
    /// Private key part: RSA private exponent, EC private key or OKP private key.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub d: Option<String>,

    // RSA private key.
    /// First prime factor.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub p: Option<String>,
    /// Second prime factor.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub q: Option<String>,
    /// First factor CRT exponent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub dp: Option<String>,
    /// Second factor CRT exponent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub dq: Option<String>,
    /// First CRT coefficient.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub qi: Option<String>,
    /// Other primes info (multi-prime RSA). `to_signing_key` rejects keys
    /// where this is non-empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub oth: Vec<Value>,
}

impl Jwk {
    pub fn to_verification_key(&self) -> Result<SomePublicKey> {
        // Check `use` and `key_ops`.
        if !matches!(self.use_.as_deref(), None | Some("sig")) {
            return Err(Error::UnsupportedOrInvalidKey);
        }
        if !(self.key_ops.is_empty() || self.key_ops.iter().any(|ops| ops == "verify")) {
            return Err(Error::UnsupportedOrInvalidKey);
        }

        // If let would be too long.
        #[allow(clippy::single_match)]
        match &*self.kty {
            "RSA" => match (self.alg.as_deref(), &self.n, &self.e) {
                (alg, Some(ref n), Some(ref e)) => {
                    let n = base64::decode_config(n, url_safe_trailing_bits())?;
                    let e = base64::decode_config(e, url_safe_trailing_bits())?;
                    // If `alg` is specified, the key will only verify
                    // signatures generated by ONLY this specific `alg`,
                    // otherwise it will verify signatures generated by ANY RSA
                    // algorithm.
                    let alg = if let Some(alg) = alg {
                        Some(RsaAlgorithm::from_name(alg)?)
                    } else {
                        None
                    };
                    return Ok(SomePublicKey::Rsa(RsaPublicKey::from_components(
                        &n, &e, alg,
                    )?));
                }
                _ => {}
            },
            "EC" => match (self.crv.as_deref(), &self.x, &self.y) {
                // For EC keys `crv` is required.
                (Some(crv), Some(ref x), Some(ref y)) => {
                    let x = base64::decode_config(x, url_safe_trailing_bits())?;
                    let y = base64::decode_config(y, url_safe_trailing_bits())?;
                    let alg = EcdsaAlgorithm::from_curve_name(crv)?;
                    return Ok(SomePublicKey::Ecdsa(EcdsaPublicKey::from_coordinates(
                        &x, &y, alg,
                    )?));
                }
                _ => {}
            },
            "OKP" => match (self.crv.as_deref(), &self.x) {
                (Some(crv), Some(ref x)) => {
                    let x = base64::decode_config(x, url_safe_trailing_bits())?;
                    match crv {
                        "Ed25519" => {
                            return Ok(SomePublicKey::Ed25519(Ed25519PublicKey::from_bytes(&x)?));
                        }
                        _ => {}
                    }
                }
                _ => {}
            },
            _ => {}
        }

        Err(Error::UnsupportedOrInvalidKey)
    }

    #[allow(clippy::many_single_char_names)]
    pub fn to_signing_key(&self, rsa_fallback_algorithm: RsaAlgorithm) -> Result<SomePrivateKey> {
        match &*self.kty {
            "RSA" => {
                let alg = if let Some(ref alg) = self.alg {
                    RsaAlgorithm::from_name(alg)?
                } else {
                    rsa_fallback_algorithm
                };
                match (self.d.as_deref(), self.n.as_deref(), self.e.as_deref()) {
                    (Some(d), Some(n), Some(e)) => {
                        fn decode(x: &str) -> Result<BigNum> {
                            Ok(BigNum::from_slice(&base64::decode_config(
                                x,
                                url_safe_trailing_bits(),
                            )?)?)
                        }
                        let d = decode(d)?;
                        let n = decode(n)?;
                        let e = decode(e)?;
                        match (
                            self.p.as_deref(),
                            self.q.as_deref(),
                            self.dp.as_deref(),
                            self.dq.as_deref(),
                            self.qi.as_deref(),
                            self.oth.is_empty(),
                        ) {
                            (None, None, None, None, None, true) => {
                                let rsa = RsaPrivateKeyBuilder::new(n, e, d)?.build();
                                let pkey = PKey::from_rsa(rsa)?;
                                RsaPrivateKey::from_pkey_without_check(pkey, alg).map(Into::into)
                            }
                            (Some(p), Some(q), Some(dp), Some(dq), Some(qi), true) => {
                                let p = decode(p)?;
                                let q = decode(q)?;
                                let dp = decode(dp)?;
                                let dq = decode(dq)?;
                                let qi = decode(qi)?;
                                let rsa = Rsa::from_private_components(n, e, d, p, q, dp, dq, qi)?;
                                let pkey = PKey::from_rsa(rsa)?;
                                RsaPrivateKey::from_pkey(pkey, alg).map(Into::into)
                            }
                            _ => Err(Error::UnsupportedOrInvalidKey),
                        }
                    }
                    _ => Err(Error::UnsupportedOrInvalidKey),
                }
            }
            "EC" => {
                match (
                    self.crv.as_deref(),
                    self.d.as_deref(),
                    self.x.as_deref(),
                    self.y.as_deref(),
                ) {
                    (Some(crv), Some(d), Some(x), Some(y)) => {
                        let alg = EcdsaAlgorithm::from_curve_name(crv)?;
                        let d = base64::decode_config(d, url_safe_trailing_bits())?;
                        let x = base64::decode_config(x, url_safe_trailing_bits())?;
                        let y = base64::decode_config(y, url_safe_trailing_bits())?;
                        EcdsaPrivateKey::from_private_components(alg, &d, &x, &y).map(Into::into)
                    }
                    _ => Err(Error::UnsupportedOrInvalidKey),
                }
            }
            "OKP" => match (self.crv.as_deref(), self.d.as_deref()) {
                (Some("Ed25519"), Some(d)) => {
                    let d = base64::decode_config(d, url_safe_trailing_bits())?;
                    Ed25519PrivateKey::from_bytes(&d).map(Into::into)
                }
                _ => Err(Error::UnsupportedOrInvalidKey),
            },
            _ => Err(Error::UnsupportedOrInvalidKey),
        }
    }

    /// Get key thumbprint (rfc 7638) with SHA-256.
    pub fn get_thumbprint_sha256(&self) -> Result<[u8; 32]> {
        let as_json = match &*self.kty {
            "RSA" => {
                let mut v = BTreeMap::new();
                v.insert(
                    "e",
                    self.e.as_deref().ok_or(Error::UnsupportedOrInvalidKey)?,
                );
                v.insert("kty", "RSA");
                v.insert(
                    "n",
                    self.n.as_deref().ok_or(Error::UnsupportedOrInvalidKey)?,
                );
                serde_json::to_string(&v)?
            }
            "EC" => {
                let mut v = BTreeMap::new();
                v.insert(
                    "crv",
                    self.crv.as_deref().ok_or(Error::UnsupportedOrInvalidKey)?,
                );
                v.insert("kty", "EC");
                v.insert(
                    "x",
                    self.x.as_deref().ok_or(Error::UnsupportedOrInvalidKey)?,
                );
                v.insert(
                    "y",
                    self.y.as_deref().ok_or(Error::UnsupportedOrInvalidKey)?,
                );
                serde_json::to_string(&v)?
            }
            "OKP" => {
                let mut v = BTreeMap::new();
                v.insert(
                    "crv",
                    self.crv.as_deref().ok_or(Error::UnsupportedOrInvalidKey)?,
                );
                v.insert("kty", "OKP");
                v.insert(
                    "x",
                    self.x.as_deref().ok_or(Error::UnsupportedOrInvalidKey)?,
                );
                serde_json::to_string(&v)?
            }
            _ => return Err(Error::UnsupportedOrInvalidKey),
        };
        let hash = hash(MessageDigest::sha256(), as_json.as_bytes())?;
        let mut out = [0u8; 32];
        out.copy_from_slice(&hash[..]);
        Ok(out)
    }

    /// Get key thumbprint with SHA-256, base64url-encoded.
    pub fn get_thumbprint_sha256_base64(&self) -> Result<String> {
        Ok(base64::encode_config(
            self.get_thumbprint_sha256()?,
            url_safe_trailing_bits(),
        ))
    }
}

/// JWK Set Representation.
///
/// Serialized as a JSON object with a single `keys` array.
#[derive(Debug, Serialize, Deserialize)]
pub struct JwkSet {
    /// The keys contained in this set.
    pub keys: Vec<Jwk>,
}

impl JwkSet {
    pub fn verifier(&self) -> JwkSetVerifier {
        let mut prepared = JwkSetVerifier {
            keys: HashMap::new(),
            require_kid: true,
        };
        for k in self.keys.iter() {
            if let Some(ref kid) = k.kid {
                if let Ok(vk) = k.to_verification_key() {
                    prepared.keys.insert(kid.clone(), vk);
                }
            }
        }
        prepared
    }
}

/// Jwk set parsed and converted, ready to verify tokens.
///
/// Created with `JwkSet::verifier`, which only keeps keys that have a `kid`
/// and convert to a verification key.
pub struct JwkSetVerifier {
    // Verification keys indexed by their `kid`.
    keys: HashMap<String, SomePublicKey>,
    // When `true`, tokens without a `kid` header are rejected instead of
    // being tried against every key.
    require_kid: bool,
}

impl JwkSetVerifier {
    /// If called with `false`, subsequent `verify` and `verify_only` calls will
    /// try all keys from the key set if a `kid` is not specified in the token.
    pub fn set_require_kid(&mut self, required: bool) {
        self.require_kid = required;
    }

    pub fn find(&self, kid: &str) -> Option<&SomePublicKey> {
        if let Some(vk) = self.keys.get(kid) {
            Some(vk)
        } else {
            None
        }
    }

    /// Decode and verify token with keys from this JWK set.
    ///
    /// The `alg`, `exp` and `nbf` fields are automatically checked.
    pub fn verify<ExtraClaims: DeserializeOwned>(
        &self,
        token: &str,
    ) -> Result<HeaderAndClaims<ExtraClaims>> {
        self.find_and_verify(token, verify)
    }

    /// Decode and verify token with keys from this JWK set. Won't check `exp` and `nbf`.
    pub fn verify_only<ExtraClaims: DeserializeOwned>(
        &self,
        token: &str,
    ) -> Result<HeaderAndClaims<ExtraClaims>> {
        self.find_and_verify(token, verify_only)
    }

    /// Find and verify token with keys from this JWK set.
    ///
    /// restrict_kid is true will only match keys with the same `kid`.
    fn find_and_verify<ExtraClaims: DeserializeOwned>(
        &self,
        token: &str,
        verifier: fn(&str, &dyn VerificationKey) -> Result<HeaderAndClaims<ExtraClaims>>,
    ) -> Result<HeaderAndClaims<ExtraClaims>> {
        let mut parts = token.split('.');

        let mut header = parts.next().ok_or(Error::InvalidToken)?.as_bytes();

        let header_r = base64::read::DecoderReader::new(&mut header, url_safe_trailing_bits());
        let header: Header = serde_json::from_reader(header_r)?;

        if let Some(kid) = header.kid {
            let k = self.find(&kid).ok_or(Error::NoKey)?;
            verifier(token, k)
        } else if !self.require_kid {
            if let Some(res) = self
                .keys
                .iter()
                .map(|(_, key)| verifier(token, key))
                .find_map(|res| res.ok())
            {
                Ok(res)
            } else {
                Err(Error::NoKey)
            }
        } else {
            Err(Error::NoKey)
        }
    }
}

/// A key associated with a key id (`kid`).
///
/// When the key is used for signing, `kid` is automatically set.
///
/// `WithKid<S>` is `Clone` whenever the wrapped key `S` is `Clone`.
#[derive(Debug, Clone)]
pub struct WithKid<S> {
    // Key id reported for the wrapped key.
    kid: String,
    // The wrapped key.
    inner: S,
}

impl<S> WithKid<S> {
    pub fn new(kid: String, inner: S) -> Self {
        Self { kid, inner }
    }

    /// Use key thumbprint as key id.
    pub fn new_with_thumbprint_id(inner: S) -> Result<Self>
    where
        S: PublicKeyToJwk,
    {
        Ok(Self {
            kid: inner.public_key_to_jwk()?.get_thumbprint_sha256_base64()?,
            inner,
        })
    }

    pub fn kid(&self) -> &str {
        &self.kid
    }

    pub fn set_kid(&mut self, kid: impl Into<String>) {
        self.kid = kid.into();
    }

    pub fn as_inner(&self) -> &S {
        &self.inner
    }

    pub fn into_inner(self) -> S {
        self.inner
    }

    pub fn as_inner_mut(&mut self) -> &mut S {
        &mut self.inner
    }
}

impl<S: SigningKey> SigningKey for WithKid<S> {
    /// Always reports the wrapper's key id, regardless of the inner key's `kid`.
    fn kid(&self) -> Option<&str> {
        Some(self.kid.as_str())
    }

    /// Delegates signing to the wrapped key.
    fn sign(&self, v: &[u8]) -> Result<smallvec::SmallVec<[u8; 64]>> {
        <S as SigningKey>::sign(&self.inner, v)
    }

    /// The wrapped key's algorithm name.
    fn alg(&self) -> &'static str {
        <S as SigningKey>::alg(&self.inner)
    }
}

impl<S: VerificationKey> VerificationKey for WithKid<S> {
    /// Delegates verification to the wrapped key; the key id plays no role here.
    fn verify(&self, v: &[u8], sig: &[u8], alg: &str) -> Result<()> {
        <S as VerificationKey>::verify(&self.inner, v, sig, alg)
    }
}

impl<K: PublicKeyToJwk> PublicKeyToJwk for WithKid<K> {
    fn public_key_to_jwk(&self) -> Result<Jwk> {
        let mut jwk = self.inner.public_key_to_jwk()?;
        jwk.kid = Some(self.kid.clone());
        Ok(jwk)
    }
}

/// Cached, already converted key set of a `RemoteJwksVerifier`.
#[cfg(feature = "remote-jwks")]
struct JWKSCache {
    /// Verifier built from the most recently fetched key set.
    jwks: JwkSetVerifier,
    /// Point in time after which the cached set is stale and is fetched again.
    valid_until: std::time::Instant,
}

/// A JWK Set served from a remote url. Automatically fetched and cached.
#[cfg(feature = "remote-jwks")]
pub struct RemoteJwksVerifier {
    // URL the JWK Set is fetched from.
    url: String,
    // HTTP client used for fetching.
    client: reqwest::Client,
    // How long a fetched key set is reused before it is fetched again.
    cache_duration: std::time::Duration,
    // Cached verifier; `None` until the first successful fetch.
    cache: tokio::sync::RwLock<Option<JWKSCache>>,
    // Applied to every verifier built from a fetched set (see `set_require_kid`).
    require_kid: bool,
}

#[cfg(feature = "remote-jwks")]
impl RemoteJwksVerifier {
    /// Creates a verifier for the JWK Set served at `url`.
    ///
    /// The set is fetched lazily on first use and reused for `cache_duration`
    /// before it is fetched again. If `client` is `None`, a default
    /// `reqwest::Client` is used. Tokens must carry a `kid` unless
    /// [`Self::set_require_kid`] is called with `false`.
    pub fn new(
        url: String,
        client: Option<reqwest::Client>,
        cache_duration: std::time::Duration,
    ) -> Self {
        Self {
            url,
            client: client.unwrap_or_default(),
            cache_duration,
            cache: tokio::sync::RwLock::new(None),
            require_kid: true,
        }
    }

    /// If called with `false`, subsequent `verify` and `verify_only` calls will
    /// try all keys from the key set if a `kid` is not specified in the token.
    ///
    /// The setting applies to the currently cached key set as well as to sets
    /// fetched later.
    pub fn set_require_kid(&mut self, required: bool) {
        self.require_kid = required;
        if let Some(cached) = self.cache.get_mut() {
            cached.jwks.require_kid = required;
        }
    }

    /// Whether `cache` holds a key set whose validity period has not ended.
    fn is_fresh(cache: &Option<JWKSCache>) -> bool {
        cache.as_ref().map_or(false, |c| {
            c.valid_until
                .checked_duration_since(std::time::Instant::now())
                .is_some()
        })
    }

    /// Returns the cached verifier, fetching the JWK Set first if the cache is
    /// empty or expired.
    ///
    /// # Errors
    ///
    /// Fails if the request fails, the server answers with a non-success HTTP
    /// status, or the body is not a valid JWK Set. In that case the previous
    /// cache entry (if any) is left untouched.
    async fn get_verifier(&self) -> Result<tokio::sync::RwLockReadGuard<'_, JwkSetVerifier>> {
        // Fast path: a shared lock suffices while the cache is fresh.
        let cache = self.cache.read().await;
        if Self::is_fresh(&cache) {
            return Ok(tokio::sync::RwLockReadGuard::map(cache, |c| {
                &c.as_ref().expect("a fresh cache is populated").jwks
            }));
        }
        drop(cache);

        let mut cache = self.cache.write().await;
        // Another task may have refreshed the set while we waited for the
        // write lock; only fetch if it is still stale.
        if !Self::is_fresh(&cache) {
            let response = self
                .client
                .get(&self.url)
                .header("accept", "application/json")
                .send()
                .await?
                // Treat 4xx/5xx answers as errors instead of trying to parse
                // an error page as a JWK Set.
                .error_for_status()?;
            let jwks: JwkSet = response.json().await?;

            let mut verifier = jwks.verifier();
            verifier.require_kid = self.require_kid;
            *cache = Some(JWKSCache {
                jwks: verifier,
                valid_until: std::time::Instant::now() + self.cache_duration,
            });
        }

        Ok(tokio::sync::RwLockReadGuard::map(cache.downgrade(), |c| {
            &c.as_ref().expect("the cache was populated above").jwks
        }))
    }

    /// Decode and verify token with keys from the remote JWK set, fetching
    /// the set first if needed. The `alg`, `exp` and `nbf` fields are checked.
    pub async fn verify<E: DeserializeOwned>(&self, token: &str) -> Result<HeaderAndClaims<E>> {
        let v = self.get_verifier().await?;
        v.verify(token)
    }

    /// Like [`Self::verify`], but won't check `exp` and `nbf`.
    pub async fn verify_only<E: DeserializeOwned>(
        &self,
        token: &str,
    ) -> Result<HeaderAndClaims<E>> {
        let v = self.get_verifier().await?;
        v.verify_only(token)
    }
}

#[cfg(test)]
mod tests {
    use crate::{
        ecdsa::{EcdsaAlgorithm, EcdsaPrivateKey},
        eddsa::Ed25519PrivateKey,
        rsa::RsaPrivateKey,
        sign,
    };

    use super::*;

    /// Keys restricted to non-signature usage must not become verification keys.
    #[test]
    fn test_jwk() -> Result<()> {
        let for_encryption = Jwk {
            kty: "RSA".to_string(),
            use_: Some("enc".into()),
            ..Default::default()
        };
        assert!(for_encryption.to_verification_key().is_err());

        let encryption_ops = Jwk {
            kty: "RSA".to_string(),
            key_ops: vec!["encryption".into()],
            ..Default::default()
        };
        assert!(encryption_ops.to_verification_key().is_err());

        Ok(())
    }

    /// Thumbprints can be computed for every supported public key type.
    #[test]
    fn test_thumbprint() -> Result<()> {
        let public_jwks = [
            RsaPrivateKey::generate(2048, RsaAlgorithm::RS256)?.public_key_to_jwk()?,
            EcdsaPrivateKey::generate(EcdsaAlgorithm::ES256)?.public_key_to_jwk()?,
            Ed25519PrivateKey::generate()?.public_key_to_jwk()?,
        ];
        for jwk in &public_jwks {
            jwk.get_thumbprint_sha256_base64()?;
        }
        Ok(())
    }

    #[derive(Serialize, Deserialize)]
    struct MyClaim {
        foo: String,
    }

    #[test]
    fn test_jwks_verify() -> Result<()> {
        fn claims() -> HeaderAndClaims<MyClaim> {
            HeaderAndClaims::with_claims(MyClaim { foo: "bar".into() })
        }

        let key = EcdsaPrivateKey::generate(EcdsaAlgorithm::ES512)?;
        let keyed = WithKid::new("my key".into(), key.clone());
        let jwks = JwkSet {
            keys: vec![keyed.public_key_to_jwk()?],
        };
        let mut verifier = jwks.verifier();

        // Token carrying the set's `kid`: accepted.
        {
            let mut jwt = claims();
            jwt.set_kid("my key");
            let token = sign(&mut jwt, &key)?;

            verifier.verify_only::<MyClaim>(&token)?;
            let verified = verifier.verify::<MyClaim>(&token)?;
            assert_eq!(verified.claims.extra.foo, "bar");
        }

        // Token carrying an unknown `kid`: rejected.
        {
            let mut jwt = claims();
            jwt.set_kid("my key2");
            let token = sign(&mut jwt, &key)?;

            assert!(verifier.verify_only::<MyClaim>(&token).is_err());
        }

        // Signing through `WithKid` overrides the token's `kid`: accepted.
        {
            let mut jwt = claims();
            jwt.set_kid("my key2");
            let token = sign(&mut jwt, &keyed)?;

            verifier.verify_only::<MyClaim>(&token)?;
            let verified = verifier.verify::<MyClaim>(&token)?;
            assert_eq!(verified.claims.extra.foo, "bar");
        }

        // Token without `kid` while the verifier requires one: rejected.
        {
            let token = sign(&mut claims(), &key)?;
            assert!(verifier.verify_only::<MyClaim>(&token).is_err());
        }

        // Token without `kid` once the verifier no longer requires one: accepted.
        {
            let token = sign(&mut claims(), &key)?;

            verifier.set_require_kid(false);
            verifier.verify::<MyClaim>(&token)?;
            let verified = verifier.verify_only::<MyClaim>(&token)?;
            assert_eq!(verified.claims.extra.foo, "bar");
        }

        Ok(())
    }
}