//! basic-jwt 0.4.0
//!
//! Basic JWT signing and verification library (ES384 / ECDSA P-384 keys).
use elliptic_curve::pkcs8::EncodePublicKey;
use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Validation};
use p384::ecdsa::signature::rand_core::OsRng;
use p384::ecdsa::{SigningKey, VerifyingKey};
use p384::pkcs8::{EncodePrivateKey, LineEnding};
use serde::Serialize;
use serde::de::DeserializeOwned;
use std::str::FromStr;

/// Public (verification) half of a JWT key pair.
///
/// Serialized as an internally tagged object whose `alg` field names the
/// variant, e.g. `{"alg": "ES384", "pub": "<PEM>"}`.
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[serde(tag = "alg")]
pub enum JWTPublicKey {
    /// ECDSA with SHA2-384 (`ES384`).
    ES384 {
        /// PEM-encoded public key; stored under the `pub` key when serialized.
        #[serde(rename = "pub")]
        public: String,
    },
}

/// Private (signing) half of a JWT key pair.
///
/// Serialized as an internally tagged object whose `alg` field names the
/// variant, e.g. `{"alg": "ES384", "priv": "<PEM>"}`.
///
/// `Debug` is implemented by hand rather than derived so that the secret key
/// material never ends up in logs or panic messages: the PEM is printed as
/// `<redacted>`.
#[derive(serde::Serialize, serde::Deserialize, Clone)]
#[serde(tag = "alg")]
pub enum JWTPrivateKey {
    /// ECDSA with SHA2-384 (`ES384`); `priv` is expected to hold a PEM-encoded
    /// PKCS#8 private key.
    ES384 { r#priv: String },
}

impl std::fmt::Debug for JWTPrivateKey {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // Deliberately never format the key material itself.
            JWTPrivateKey::ES384 { .. } => f
                .debug_struct("ES384")
                .field("priv", &"<redacted>")
                .finish(),
        }
    }
}

impl JWTPrivateKey {
    /// Generate a new random ES384 (ECDSA P-384) signing key.
    ///
    /// The key is drawn from `OsRng` and stored as a PKCS#8 PEM string
    /// (`PRIVATE KEY` label, LF line endings).
    ///
    /// # Errors
    ///
    /// Returns an error if the key cannot be encoded as PKCS#8 PEM.
    pub fn generate_ec384_signing_key() -> anyhow::Result<Self> {
        let signing_key = SigningKey::random(&mut OsRng);
        // One-step equivalent of DER-encoding and then PEM-armoring with the
        // "PRIVATE KEY" label; mirrors `to_public_key_pem` in `to_public_key`.
        let priv_pem = signing_key.to_pkcs8_pem(LineEnding::LF)?.to_string();

        Ok(Self::ES384 { r#priv: priv_pem })
    }

    /// Derive the public key that matches this private key.
    ///
    /// # Errors
    ///
    /// Returns an error if the stored private key cannot be parsed, or if the
    /// derived public key cannot be PEM-encoded.
    pub fn to_public_key(&self) -> anyhow::Result<JWTPublicKey> {
        match self {
            JWTPrivateKey::ES384 { r#priv } => {
                // `SigningKey`'s `FromStr` impl parses the PEM produced by
                // `generate_ec384_signing_key`.
                let signing_key = SigningKey::from_str(r#priv)?;
                let pub_key = VerifyingKey::from(signing_key);
                let pub_pem = pub_key.to_public_key_pem(LineEnding::LF)?;

                Ok(JWTPublicKey::ES384 { public: pub_pem })
            }
        }
    }

    /// Sign `claims` into a compact JWT using the ES384 algorithm.
    ///
    /// # Errors
    ///
    /// Returns an error if the stored private key is not a valid EC PEM key,
    /// or if `claims` cannot be serialized or signed.
    pub fn sign_jwt<C: Serialize>(&self, claims: &C) -> anyhow::Result<String> {
        match self {
            JWTPrivateKey::ES384 { r#priv } => {
                let encoding_key = EncodingKey::from_ec_pem(r#priv.as_bytes())?;
                let header = jsonwebtoken::Header::new(Algorithm::ES384);

                // `claims` is already a reference; pass it through directly
                // instead of re-borrowing it as `&&C`.
                Ok(jsonwebtoken::encode(&header, claims, &encoding_key)?)
            }
        }
    }
}

impl JWTPublicKey {
    /// Verify `jwt` against this public key and return its decoded claims.
    ///
    /// Validation uses `Validation::new(Algorithm::ES384)`, so only tokens
    /// whose header names ES384 are accepted. Under the `jsonwebtoken`
    /// defaults, an `exp` claim is also required and checked, so expired
    /// tokens are rejected. Check the exact leeway against the crate docs.
    ///
    /// `E` only needs to be deserializable; no `Clone` bound is required.
    ///
    /// # Errors
    ///
    /// Returns an error if any of these fail:
    /// - the stored public key is not valid EC PEM;
    /// - `jwt` is well-formed;
    /// - the signature verifies;
    /// - every validated claim passes its check;
    /// - the payload deserializes into `E`.
    pub fn validate_jwt<E: DeserializeOwned>(&self, jwt: &str) -> anyhow::Result<E> {
        match self {
            JWTPublicKey::ES384 { public } => {
                let decoding_key = DecodingKey::from_ec_pem(public.as_bytes())?;
                let validation = Validation::new(Algorithm::ES384);
                let token = jsonwebtoken::decode::<E>(jwt, &decoding_key, &validation)?;

                Ok(token.claims)
            }
        }
    }
}

#[cfg(test)]
mod test {
    use std::time::{SystemTime, UNIX_EPOCH};

    use crate::{JWTPrivateKey, JWTPublicKey};
    use serde::{Deserialize, Serialize};

    /// Seconds elapsed since the Unix epoch, according to the system clock.
    fn time() -> u64 {
        let elapsed = SystemTime::now().duration_since(UNIX_EPOCH).unwrap();
        elapsed.as_secs()
    }

    /// Fresh ES384 key pair: the signing key and its matching public key.
    fn key_pair() -> (JWTPrivateKey, JWTPublicKey) {
        let signer = JWTPrivateKey::generate_ec384_signing_key().unwrap();
        let verifier = signer.to_public_key().unwrap();
        (signer, verifier)
    }

    /// Claim set used throughout these tests: a subject and an expiry time.
    #[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Clone)]
    pub struct Claims {
        sub: String,
        exp: u64,
    }

    impl Default for Claims {
        /// Subject `"my-sub"`, expiring 100 seconds from now.
        fn default() -> Self {
            let sub = "my-sub".to_string();
            let exp = time() + 100;
            Claims { sub, exp }
        }
    }

    /// A token verifies under the matching public key and its claims
    /// round-trip unchanged.
    #[test]
    fn jwt_encode_sign_verify_valid() {
        let (signer, verifier) = key_pair();
        let expected = Claims::default();

        let token = signer.sign_jwt(&expected).expect("Failed to sign JWT!");
        let decoded: Claims = verifier
            .validate_jwt(&token)
            .expect("Failed to validate JWT!");

        assert_eq!(expected, decoded)
    }

    /// A token must be rejected by the public key of an unrelated key pair.
    #[test]
    fn jwt_encode_sign_verify_invalid_key() {
        let signer = JWTPrivateKey::generate_ec384_signing_key().unwrap();
        let (_, unrelated_verifier) = key_pair();

        let token = signer
            .sign_jwt(&Claims::default())
            .expect("Failed to sign JWT!");
        unrelated_verifier
            .validate_jwt::<Claims>(&token)
            .expect_err("JWT should not have validated!");
    }

    /// Arbitrary non-JWT input must fail validation.
    #[test]
    fn jwt_verify_random_string() {
        let (_, verifier) = key_pair();

        verifier
            .validate_jwt::<Claims>("random_string")
            .expect_err("JWT should not have validated!");
    }

    /// A token whose `exp` lies 100 seconds in the past must be rejected.
    #[test]
    fn jwt_expired() {
        let (signer, verifier) = key_pair();
        let stale = Claims {
            exp: time() - 100,
            ..Default::default()
        };

        let token = signer.sign_jwt(&stale).expect("Failed to sign JWT!");
        verifier
            .validate_jwt::<Claims>(&token)
            .expect_err("JWT should not have validated!");
    }

    /// Appending junk to the signature segment must make validation fail.
    #[test]
    fn jwt_invalid_signature() {
        let (signer, verifier) = key_pair();

        let token = signer
            .sign_jwt(&Claims::default())
            .expect("Failed to sign JWT!");
        let tampered = format!("{token}bad");
        verifier
            .validate_jwt::<Claims>(&tampered)
            .expect_err("JWT should not have validated!");
    }
}