1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
use elliptic_curve::pkcs8::EncodePublicKey;
use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Validation};
use p384::ecdsa::{SigningKey, VerifyingKey};
use p384::pkcs8::{EncodePrivateKey, LineEnding};
use rand::rngs::OsRng;
use serde::de::DeserializeOwned;
use serde::Serialize;

/// Public half of a JWT signing keypair, tagged by algorithm.
///
/// Serialized with serde's internally-tagged representation: the variant name is written
/// under the `"alg"` key next to the variant's fields, e.g. `{"alg":"ES384","pub":"-----BEGIN ..."}`.
#[derive(serde::Serialize, serde::Deserialize, Clone, Debug, Eq, PartialEq)]
#[serde(tag = "alg")]
pub enum TokenPubKey {
    /// ECDSA with SHA2-384 variant
    ///
    /// `pub` holds the verifying key as PEM text (SubjectPublicKeyInfo, as written by
    /// `generate_ec384_keypair` and parsed with `DecodingKey::from_ec_pem` in `validate_jwt`).
    ES384 { r#pub: String },
}

/// Private half of a JWT signing keypair, tagged by algorithm.
///
/// Serialized with serde's internally-tagged representation (variant name under `"alg"`),
/// mirroring [`TokenPubKey`].
///
/// `Debug` is implemented by hand so the key material is redacted: formatting this value
/// (e.g. `{:?}` in a log line, `dbg!`, or an assertion failure message) never prints the PEM.
#[derive(serde::Serialize, serde::Deserialize, Clone)]
#[serde(tag = "alg")]
pub enum TokenPrivKey {
    /// ECDSA with SHA2-384 variant; `priv` holds the PKCS#8 PEM-encoded signing key.
    ES384 { r#priv: String },
}

impl std::fmt::Debug for TokenPrivKey {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // Keep the variant name for diagnostics, but never the secret itself.
            TokenPrivKey::ES384 { .. } => f
                .debug_struct("ES384")
                .field("priv", &"<redacted>")
                .finish(),
        }
    }
}

/// Generate a new ES384 (ECDSA over P-384 with SHA-384) keypair.
///
/// The signing key is drawn from `OsRng`. The private key is returned as PKCS#8 PEM and the
/// public key as SubjectPublicKeyInfo PEM, both with LF line endings. These are the forms that
/// `sign_jwt` and `validate_jwt` load via `EncodingKey::from_ec_pem` / `DecodingKey::from_ec_pem`.
///
/// # Errors
///
/// Returns an error if either key fails to encode as PEM.
pub fn generate_ec384_keypair() -> anyhow::Result<(TokenPubKey, TokenPrivKey)> {
    let signing_key = SigningKey::random(&mut OsRng);

    // `to_pkcs8_pem` performs the DER encoding and the "PRIVATE KEY" PEM armoring in one step.
    // It yields a `Zeroizing<String>` (wiped when this temporary drops); the copy into a plain
    // `String` is what `TokenPrivKey` stores.
    let priv_pem = signing_key.to_pkcs8_pem(LineEnding::LF)?.to_string();

    let pub_pem = VerifyingKey::from(signing_key).to_public_key_pem(LineEnding::LF)?;

    Ok((
        TokenPubKey::ES384 { r#pub: pub_pem },
        TokenPrivKey::ES384 { r#priv: priv_pem },
    ))
}

/// Sign `claims` as a compact JWT using the given private key.
///
/// The header's `alg` is chosen from the key variant (`ES384` for [`TokenPrivKey::ES384`]).
/// The resulting token can be checked with the matching [`TokenPubKey`] via `validate_jwt`.
///
/// # Errors
///
/// Returns an error if the private-key PEM cannot be parsed as an EC key, or if serializing
/// `claims` or signing fails.
pub fn sign_jwt<C: Serialize>(key: &TokenPrivKey, claims: &C) -> anyhow::Result<String> {
    match key {
        TokenPrivKey::ES384 { r#priv } => {
            let encoding_key = EncodingKey::from_ec_pem(r#priv.as_bytes())?;
            let header = jsonwebtoken::Header::new(Algorithm::ES384);

            // `claims` is already a reference; pass it straight through instead of
            // re-borrowing it into a `&&C`.
            Ok(jsonwebtoken::encode(&header, claims, &encoding_key)?)
        }
    }
}

/// Verify `token` against `key` and return its decoded claims.
///
/// Only tokens signed with the algorithm that matches the key variant are accepted. All other
/// checks use `jsonwebtoken::Validation::new` defaults, which include checking `exp`
/// (see the `jwt_expired` test).
///
/// # Errors
///
/// Returns an error if the public-key PEM cannot be parsed, or if the token is malformed,
/// has a bad signature, or fails claim validation.
pub fn validate_jwt<E: DeserializeOwned>(key: &TokenPubKey, token: &str) -> anyhow::Result<E> {
    // Resolve the key variant into the concrete key + algorithm pair jsonwebtoken needs.
    let (decoding_key, algorithm) = match key {
        TokenPubKey::ES384 { r#pub: pem } => {
            (DecodingKey::from_ec_pem(pem.as_bytes())?, Algorithm::ES384)
        }
    };

    let validation = Validation::new(algorithm);
    let token_data = jsonwebtoken::decode::<E>(token, &decoding_key, &validation)?;
    Ok(token_data.claims)
}

#[cfg(test)]
mod test {
    // `super::` rather than `crate::`, so these imports keep resolving if this file is ever
    // mounted as a submodule instead of being the crate root.
    use super::{generate_ec384_keypair, sign_jwt, validate_jwt};
    use std::time::{SystemTime, UNIX_EPOCH};

    use serde::{Deserialize, Serialize};

    /// Current Unix time in whole seconds.
    fn time() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs()
    }

    /// Minimal claim set: a subject plus the `exp` claim that validation checks.
    #[derive(Debug, Serialize, Deserialize, Eq, PartialEq)]
    pub struct Claims {
        sub: String,
        exp: u64,
    }

    impl Default for Claims {
        /// Claims that expire 100 seconds from now.
        fn default() -> Self {
            Self {
                sub: "my-sub".to_string(),
                exp: time() + 100,
            }
        }
    }

    /// A token signed with a key validates against the matching public key and round-trips.
    #[test]
    fn jwt_encode_sign_verify_valid() {
        let (pub_key, priv_key) = generate_ec384_keypair().unwrap();
        let claims = Claims::default();
        let jwt = sign_jwt(&priv_key, &claims).expect("Failed to sign JWT!");
        let claims_out = validate_jwt::<Claims>(&pub_key, &jwt).expect("Failed to validate JWT!");

        assert_eq!(claims, claims_out);
    }

    /// A token must not validate against an unrelated public key.
    #[test]
    fn jwt_encode_sign_verify_invalid_key() {
        let (_pub_key, priv_key) = generate_ec384_keypair().unwrap();
        let (pub_key_2, _priv_key_2) = generate_ec384_keypair().unwrap();
        let claims = Claims::default();
        let jwt = sign_jwt(&priv_key, &claims).expect("Failed to sign JWT!");
        validate_jwt::<Claims>(&pub_key_2, &jwt).expect_err("JWT should not have validated!");
    }

    /// Arbitrary non-JWT input is rejected rather than panicking.
    #[test]
    fn jwt_verify_random_string() {
        let (pub_key, _priv_key) = generate_ec384_keypair().unwrap();
        validate_jwt::<Claims>(&pub_key, "random_string")
            .expect_err("JWT should not have validated!");
    }

    /// An already-expired token is rejected.
    #[test]
    fn jwt_expired() {
        let (pub_key, priv_key) = generate_ec384_keypair().unwrap();
        // NOTE(review): 100 s in the past assumes the validator's default `exp` leeway is
        // below 100 s; confirm against jsonwebtoken's `Validation` defaults if it changes.
        let claims = Claims {
            exp: time() - 100,
            ..Default::default()
        };
        let jwt = sign_jwt(&priv_key, &claims).expect("Failed to sign JWT!");
        validate_jwt::<Claims>(&pub_key, &jwt).expect_err("JWT should not have validated!");
    }

    /// Tampering with the signature segment invalidates the token.
    #[test]
    fn jwt_invalid_signature() {
        let (pub_key, priv_key) = generate_ec384_keypair().unwrap();
        let claims = Claims::default();
        let jwt = sign_jwt(&priv_key, &claims).expect("Failed to sign JWT!");
        validate_jwt::<Claims>(&pub_key, &format!("{jwt}bad"))
            .expect_err("JWT should not have validated!");
    }
}