use anyhow::Context;
use elliptic_curve::pkcs8::EncodePublicKey;
use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Validation};
use p384::ecdsa::signature::rand_core::OsRng;
use p384::ecdsa::{SigningKey, VerifyingKey};
use p384::pkcs8::{EncodePrivateKey, LineEnding};
use serde::Serialize;
use serde::de::DeserializeOwned;
use std::str::FromStr;
/// Public half of a JWT signing key, used to verify tokens.
///
/// Serialized with serde's internally-tagged representation (`tag = "alg"`):
/// the variant name is written into an `"alg"` field alongside the variant's
/// own fields, e.g. `{"alg": "ES384", "pub": "-----BEGIN PUBLIC KEY-----\n..."}`
/// in JSON.
#[derive(serde::Serialize, serde::Deserialize, Clone, Debug, Eq, PartialEq)]
#[serde(tag = "alg")]
pub enum JWTPublicKey {
    /// ECDSA over NIST P-384 (JWA algorithm `ES384`).
    ES384 {
        /// PEM-encoded public key (`-----BEGIN PUBLIC KEY-----`), as produced by
        /// `JWTPrivateKey::to_public_key`. Serialized under the key `"pub"`
        /// (a reserved word in Rust, hence the rename).
        #[serde(rename = "pub")]
        public: String,
    },
}
/// Private half of a JWT signing key, used to sign tokens.
///
/// Serialized with serde's internally-tagged representation (`tag = "alg"`),
/// e.g. `{"alg": "ES384", "priv": "-----BEGIN PRIVATE KEY-----\n..."}` in JSON.
///
/// `Debug` is implemented by hand rather than derived, so formatting this value
/// with `{:?}` (logs, `dbg!`, panic messages) never reveals the key material.
#[derive(serde::Serialize, serde::Deserialize, Clone)]
#[serde(tag = "alg")]
pub enum JWTPrivateKey {
    /// ECDSA over NIST P-384 (JWA algorithm `ES384`).
    ES384 {
        /// PKCS#8 PEM-encoded private key (`-----BEGIN PRIVATE KEY-----`).
        /// Serialized under the key `"priv"` (serde strips the `r#` prefix).
        r#priv: String,
    },
}

impl std::fmt::Debug for JWTPrivateKey {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // A derived `Debug` would print the secret PEM verbatim; keep the
            // shape of the output but redact the key itself.
            JWTPrivateKey::ES384 { .. } => f
                .debug_struct("ES384")
                .field("priv", &"<redacted>")
                .finish(),
        }
    }
}
impl JWTPrivateKey {
    /// Generates a new random ES384 (ECDSA over NIST P-384) signing key.
    ///
    /// The key is drawn from [`OsRng`] and stored as a PKCS#8 PEM document
    /// (`-----BEGIN PRIVATE KEY-----`) with LF line endings. This is the format
    /// `to_public_key` and `sign_jwt` read back.
    ///
    /// # Errors
    ///
    /// Returns an error if the key cannot be encoded as PKCS#8 PEM.
    pub fn generate_ec384_signing_key() -> anyhow::Result<Self> {
        let signing_key = SigningKey::random(&mut OsRng);
        // `to_pkcs8_pem` is the one-step form of
        // `to_pkcs8_der()?.to_pem("PRIVATE KEY", ..)`. It returns a
        // `Zeroizing<String>`, but the variant's field is a plain `String`, so the
        // copy kept in `Self` is NOT wiped on drop.
        let priv_pem = signing_key
            .to_pkcs8_pem(LineEnding::LF)
            .context("failed to encode ES384 private key as PKCS#8 PEM")?
            .to_string();
        Ok(Self::ES384 { r#priv: priv_pem })
    }

    /// Derives the matching public key as a PEM document (`-----BEGIN PUBLIC KEY-----`).
    ///
    /// # Errors
    ///
    /// Returns an error if the stored private key cannot be parsed as a P-384
    /// key, or if the public key cannot be PEM-encoded.
    pub fn to_public_key(&self) -> anyhow::Result<JWTPublicKey> {
        match self {
            JWTPrivateKey::ES384 { r#priv } => {
                // The parse error from `SigningKey::from_str` carries little detail
                // on its own, so record which step failed.
                let signing_key = SigningKey::from_str(r#priv)
                    .context("failed to parse ES384 private key PEM")?;
                let pub_key = VerifyingKey::from(signing_key);
                let pub_pem = pub_key
                    .to_public_key_pem(LineEnding::LF)
                    .context("failed to encode ES384 public key as PEM")?;
                Ok(JWTPublicKey::ES384 { public: pub_pem })
            }
        }
    }

    /// Signs `claims` as a compact JWT whose header declares `alg: ES384`.
    ///
    /// # Errors
    ///
    /// Returns an error if the stored key cannot be loaded as an EC private key,
    /// or if `jsonwebtoken::encode` fails to serialize or sign `claims`.
    pub fn sign_jwt<C: Serialize>(&self, claims: &C) -> anyhow::Result<String> {
        match self {
            JWTPrivateKey::ES384 { r#priv } => {
                let encoding_key = EncodingKey::from_ec_pem(r#priv.as_bytes())
                    .context("failed to load ES384 private key for signing")?;
                // `claims` is already `&C`, which is exactly what `encode` takes.
                let jwt = jsonwebtoken::encode(
                    &jsonwebtoken::Header::new(Algorithm::ES384),
                    claims,
                    &encoding_key,
                )?;
                Ok(jwt)
            }
        }
    }
}
impl JWTPublicKey {
    /// Verifies `jwt` with this public key and returns its decoded claims.
    ///
    /// The token is checked by `jsonwebtoken::decode` using
    /// `Validation::new(Algorithm::ES384)`. It therefore accepts only tokens
    /// whose header declares `ES384`, and otherwise applies jsonwebtoken's
    /// default validation (which, per its docs, includes rejecting tokens
    /// whose `exp` has passed).
    ///
    /// # Errors
    ///
    /// Returns an error if the stored PEM cannot be loaded as an EC public key,
    /// or if the token fails to parse, verify, or validate.
    pub fn validate_jwt<E: DeserializeOwned + Clone>(&self, jwt: &str) -> anyhow::Result<E> {
        // Single-variant enum: destructure directly. Adding a variant turns this
        // into a compile error, just like a non-exhaustive `match` would.
        let JWTPublicKey::ES384 { public } = self;

        let rules = Validation::new(Algorithm::ES384);
        let verifier = DecodingKey::from_ec_pem(public.as_bytes())?;
        let token = jsonwebtoken::decode::<E>(jwt, &verifier, &rules)?;
        Ok(token.claims)
    }
}
#[cfg(test)]
mod test {
    use crate::{JWTPrivateKey, JWTPublicKey};
    use serde::{Deserialize, Serialize};
    use std::time::{SystemTime, UNIX_EPOCH};

    /// Seconds since the Unix epoch, according to the system clock.
    fn time() -> u64 {
        let elapsed = SystemTime::now().duration_since(UNIX_EPOCH).unwrap();
        elapsed.as_secs()
    }

    /// A freshly generated ES384 private key together with its derived public key.
    fn keypair() -> (JWTPrivateKey, JWTPublicKey) {
        let signer = JWTPrivateKey::generate_ec384_signing_key().unwrap();
        let verifier = signer.to_public_key().unwrap();
        (signer, verifier)
    }

    /// Claim set shared by the tests; `exp` is a Unix timestamp in seconds.
    #[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Clone)]
    pub struct Claims {
        sub: String,
        exp: u64,
    }

    impl Default for Claims {
        /// Subject `my-sub`, expiring 100 seconds from now.
        fn default() -> Self {
            let exp = time() + 100;
            Claims {
                sub: String::from("my-sub"),
                exp,
            }
        }
    }

    /// A token verifies under its own key pair and round-trips the claims unchanged.
    #[test]
    fn jwt_encode_sign_verify_valid() {
        let (signer, verifier) = keypair();
        let expected = Claims::default();
        let token = signer.sign_jwt(&expected).expect("Failed to sign JWT!");
        let decoded: Claims = verifier
            .validate_jwt(&token)
            .expect("Failed to validate JWT!");
        assert_eq!(expected, decoded);
    }

    /// A token signed by one key is rejected by an unrelated public key.
    #[test]
    fn jwt_encode_sign_verify_invalid_key() {
        let signer = JWTPrivateKey::generate_ec384_signing_key().unwrap();
        let (_, unrelated_verifier) = keypair();
        let token = signer
            .sign_jwt(&Claims::default())
            .expect("Failed to sign JWT!");
        unrelated_verifier
            .validate_jwt::<Claims>(&token)
            .expect_err("JWT should not have validated!");
    }

    /// Input that is not a JWT at all is rejected.
    #[test]
    fn jwt_verify_random_string() {
        let (_, verifier) = keypair();
        verifier
            .validate_jwt::<Claims>("random_string")
            .expect_err("JWT should not have validated!");
    }

    /// A correctly signed token whose `exp` lies in the past is rejected.
    #[test]
    fn jwt_expired() {
        let (signer, verifier) = keypair();
        let mut stale = Claims::default();
        stale.exp = time() - 100;
        let token = signer.sign_jwt(&stale).expect("Failed to sign JWT!");
        verifier
            .validate_jwt::<Claims>(&token)
            .expect_err("JWT should not have validated!");
    }

    /// Appending garbage to the signature segment makes verification fail.
    #[test]
    fn jwt_invalid_signature() {
        let (signer, verifier) = keypair();
        let token = signer
            .sign_jwt(&Claims::default())
            .expect("Failed to sign JWT!");
        let tampered = format!("{token}bad");
        verifier
            .validate_jwt::<Claims>(&tampered)
            .expect_err("JWT should not have validated!");
    }
}