use crate::errors::{JwtError, JwtOperation};
use crate::jwt::authority::JwtAuthority;
use crate::{Codec, Error, Result};
use std::collections::HashMap;
use std::collections::HashSet;
use std::marker::PhantomData;
use std::path::{Path, PathBuf};
use std::sync::RwLock;
use chrono::Utc;
use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation, decode_header};
use p384::elliptic_curve::rand_core::OsRng;
use p384::pkcs8::{EncodePrivateKey, EncodePublicKey, LineEnding};
use serde::{Deserialize, Serialize, de::DeserializeOwned};
use serde_with::skip_serializing_none;
use uuid::Uuid;
pub mod authority;
pub mod jwks;
pub mod remote_verifier;
pub mod validation_result;
pub mod validation_service;
/// Returns a `Validation` whose algorithm allow-list contains ES384 and nothing else.
fn validation_with_es384_only() -> Validation {
    let mut es384_rules = Validation::new(Algorithm::ES384);
    // Pin the allow-list explicitly rather than relying on what the constructor seeds.
    es384_rules.algorithms.clear();
    es384_rules.algorithms.push(Algorithm::ES384);
    es384_rules
}
/// Builds the JOSE header used for signed tokens: `alg` = ES384, `typ` = `JWT`, `kid` = `kid`.
fn canonical_es384_header_with_kid(kid: &str) -> Header {
    let mut jose_header = Header::new(Algorithm::ES384);
    jose_header.kid = Some(String::from(kid));
    jose_header.typ = Some(String::from("JWT"));
    jose_header
}
/// Filesystem locations of an ES384 key pair stored as two PEM files.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Es384KeyPairPaths {
    /// Location of the private key PEM file.
    private_key_path: PathBuf,
    /// Location of the public key PEM file.
    public_key_path: PathBuf,
}
impl Es384KeyPairPaths {
    /// Creates a path pair. Nothing is read from or checked on disk.
    pub fn new(private_key_path: impl Into<PathBuf>, public_key_path: impl Into<PathBuf>) -> Self {
        let private_key_path = private_key_path.into();
        let public_key_path = public_key_path.into();
        Self {
            private_key_path,
            public_key_path,
        }
    }

    /// Location of the private key PEM file.
    pub fn private_key_path(&self) -> &Path {
        self.private_key_path.as_path()
    }

    /// Location of the public key PEM file.
    pub fn public_key_path(&self) -> &Path {
        self.public_key_path.as_path()
    }

    /// Returns `true` only if both paths currently refer to regular files.
    ///
    /// The private path is checked first; the public path is not checked if it fails.
    pub fn both_exist(&self) -> bool {
        [&self.private_key_path, &self.public_key_path]
            .iter()
            .all(|candidate| candidate.is_file())
    }
}
/// An ES384 key pair held in memory as PEM bytes, together with the paths it belongs to.
///
/// `Debug` is implemented by hand so the private key never appears in debug output
/// (logs, panic messages, `dbg!`).
#[derive(Clone, PartialEq, Eq)]
pub struct Es384KeyPair {
    paths: Es384KeyPairPaths,
    private_key_pem: Vec<u8>,
    public_key_pem: Vec<u8>,
}
impl std::fmt::Debug for Es384KeyPair {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Es384KeyPair")
            .field("paths", &self.paths)
            // Never print secret key material.
            .field("private_key_pem", &"<redacted>")
            .field(
                "public_key_pem",
                &String::from_utf8_lossy(&self.public_key_pem),
            )
            .finish()
    }
}
impl Es384KeyPair {
    /// Wraps already-loaded PEM bytes. No parsing or validation is performed here.
    pub fn new(
        paths: Es384KeyPairPaths,
        private_key_pem: impl Into<Vec<u8>>,
        public_key_pem: impl Into<Vec<u8>>,
    ) -> Self {
        Self {
            paths,
            private_key_pem: private_key_pem.into(),
            public_key_pem: public_key_pem.into(),
        }
    }

    /// The file locations associated with this pair.
    pub fn paths(&self) -> &Es384KeyPairPaths {
        &self.paths
    }

    /// Location of the private key PEM file.
    pub fn private_key_path(&self) -> &Path {
        self.paths.private_key_path()
    }

    /// Location of the public key PEM file.
    pub fn public_key_path(&self) -> &Path {
        self.paths.public_key_path()
    }

    /// Raw private key PEM bytes.
    pub fn private_key_pem(&self) -> &[u8] {
        &self.private_key_pem
    }

    /// Raw public key PEM bytes.
    pub fn public_key_pem(&self) -> &[u8] {
        &self.public_key_pem
    }

    /// Builds signing-and-verification options from this pair.
    ///
    /// # Errors
    /// Propagates any PEM parsing or key-id derivation error from
    /// `JsonWebTokenOptions::from_es384_pem`.
    pub fn to_jwt_options(&self) -> Result<JsonWebTokenOptions> {
        JsonWebTokenOptions::from_es384_pem(&self.private_key_pem, &self.public_key_pem)
    }

    /// Builds a `JsonWebToken` codec from this pair.
    ///
    /// # Errors
    /// Same as [`Self::to_jwt_options`].
    pub fn to_codec<P>(&self) -> Result<JsonWebToken<P>> {
        Ok(JsonWebToken::new_with_options(self.to_jwt_options()?))
    }

    /// Builds a `JwtAuthority` from this pair.
    ///
    /// # Errors
    /// Propagates any error from `JwtAuthority::from_es384_pem`.
    pub fn to_authority<P>(&self) -> Result<JwtAuthority<P>>
    where
        P: Serialize + DeserializeOwned + Clone,
    {
        JwtAuthority::from_es384_pem(&self.private_key_pem, &self.public_key_pem)
    }
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Es384KeyPairLoader {
paths: Es384KeyPairPaths,
}
impl Es384KeyPairLoader {
pub fn new(private_key_path: impl Into<PathBuf>, public_key_path: impl Into<PathBuf>) -> Self {
Self {
paths: Es384KeyPairPaths::new(private_key_path, public_key_path),
}
}
pub fn paths(&self) -> &Es384KeyPairPaths {
&self.paths
}
pub async fn initialize_if_required(&self) -> Result<Es384KeyPair> {
let private_exists = self.paths.private_key_path.is_file();
let public_exists = self.paths.public_key_path.is_file();
match (private_exists, public_exists) {
(true, true) => self.load().await,
(false, false) => {
self.create_parent_directories().await?;
let (private_key_pem, public_key_pem) = generate_es384_key_pair_pem()?;
tokio::fs::write(&self.paths.private_key_path, &private_key_pem)
.await
.map_err(|error| {
Error::Jwt(JwtError::processing(
JwtOperation::Encode,
format!(
"failed to write ES384 private key `{}`: {error}",
self.paths.private_key_path.display()
),
))
})?;
tokio::fs::write(&self.paths.public_key_path, &public_key_pem)
.await
.map_err(|error| {
Error::Jwt(JwtError::processing(
JwtOperation::Encode,
format!(
"failed to write ES384 public key `{}`: {error}",
self.paths.public_key_path.display()
),
))
})?;
self.validate_loaded_pair(Es384KeyPair::new(
self.paths.clone(),
private_key_pem,
public_key_pem,
))
}
_ => Err(Error::Jwt(JwtError::processing(
JwtOperation::Encode,
format!(
"ES384 key initialization requires both key files to exist or neither to exist; private=`{}`, public=`{}`",
self.paths.private_key_path.display(),
self.paths.public_key_path.display()
),
))),
}
}
pub async fn load(&self) -> Result<Es384KeyPair> {
let private_key_pem = tokio::fs::read(&self.paths.private_key_path)
.await
.map_err(|error| {
Error::Jwt(JwtError::processing(
JwtOperation::Decode,
format!(
"failed to read ES384 private key `{}`: {error}",
self.paths.private_key_path.display()
),
))
})?;
let public_key_pem =
tokio::fs::read(&self.paths.public_key_path)
.await
.map_err(|error| {
Error::Jwt(JwtError::processing(
JwtOperation::Decode,
format!(
"failed to read ES384 public key `{}`: {error}",
self.paths.public_key_path.display()
),
))
})?;
self.validate_loaded_pair(Es384KeyPair::new(
self.paths.clone(),
private_key_pem,
public_key_pem,
))
}
async fn create_parent_directories(&self) -> Result<()> {
if let Some(parent) = self.paths.private_key_path.parent()
&& !parent.as_os_str().is_empty()
{
tokio::fs::create_dir_all(parent).await.map_err(|error| {
Error::Jwt(JwtError::processing(
JwtOperation::Encode,
format!(
"failed to create private key directory `{}`: {error}",
parent.display()
),
))
})?;
}
if let Some(parent) = self.paths.public_key_path.parent()
&& !parent.as_os_str().is_empty()
{
tokio::fs::create_dir_all(parent).await.map_err(|error| {
Error::Jwt(JwtError::processing(
JwtOperation::Encode,
format!(
"failed to create public key directory `{}`: {error}",
parent.display()
),
))
})?;
}
Ok(())
}
fn validate_loaded_pair(&self, key_pair: Es384KeyPair) -> Result<Es384KeyPair> {
key_pair.to_jwt_options()?;
Ok(key_pair)
}
}
fn generate_es384_key_pair_pem() -> Result<(Vec<u8>, Vec<u8>)> {
let signing_key = p384::ecdsa::SigningKey::random(&mut OsRng);
let verifying_key = signing_key.verifying_key();
let private_key_pem = signing_key
.to_pkcs8_pem(LineEnding::LF)
.map_err(|error| {
Error::Jwt(JwtError::processing(
JwtOperation::Encode,
format!("failed to encode ES384 private key as PKCS#8 PEM: {error}"),
))
})?
.to_string()
.into_bytes();
let public_key_pem = verifying_key
.to_public_key_pem(LineEnding::LF)
.map_err(|error| {
Error::Jwt(JwtError::processing(
JwtOperation::Encode,
format!("failed to encode ES384 public key as PEM: {error}"),
))
})?
.into_bytes();
Ok((private_key_pem, public_key_pem))
}
/// Registered JWT claims (`iss`, `sub`, `aud`, `exp`, `nbf`, `iat`, `jti`) plus the `sid` session id.
///
/// `Option` fields that are `None` are omitted from the serialized payload rather than being
/// emitted as `null`.
// `skip_serializing_none` must come BEFORE `#[derive(Serialize)]`. An attribute macro listed
// after a derive is expanded only after the derive has generated its impl, so the
// `skip_serializing_if` annotations it injects would be ignored and `None` would serialize as
// `null`.
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq)]
pub struct RegisteredClaims {
    #[serde(rename = "iss")]
    pub issuer: String,
    #[serde(rename = "sub")]
    pub subject: Option<String>,
    #[serde(rename = "aud")]
    pub audience: Option<HashSet<String>>,
    #[serde(rename = "exp")]
    pub expiration_time: u64,
    #[serde(rename = "nbf")]
    pub not_before_time: Option<u64>,
    #[serde(rename = "iat")]
    pub issued_at_time: u64,
    #[serde(rename = "jti")]
    pub jwt_id: Option<String>,
    #[serde(rename = "sid")]
    pub session_id: Option<String>,
}
impl RegisteredClaims {
    /// Creates claims for `issuer` that expire at `expiration_time` (as given by the caller).
    ///
    /// `iat` is the current UTC Unix timestamp in seconds (0 if the clock reads before the
    /// epoch), and `jti` is a fresh UUIDv7. All other optional claims start as `None`.
    pub fn new(issuer: &str, expiration_time: u64) -> Self {
        // A pre-epoch clock yields a negative timestamp; clamp to 0 instead of wrapping.
        let issued_at_time = u64::try_from(Utc::now().timestamp()).unwrap_or(0);
        Self {
            issuer: issuer.to_string(),
            subject: None,
            audience: None,
            expiration_time,
            not_before_time: None,
            issued_at_time,
            jwt_id: Some(Uuid::now_v7().to_string()),
            session_id: None,
        }
    }

    /// Returns these claims with `sid` set to `session_id`.
    #[must_use]
    pub fn with_session_id(mut self, session_id: impl Into<String>) -> Self {
        self.session_id = Some(session_id.into());
        self
    }
}
/// A full JWT payload: the registered claims and application-specific claims, both flattened
/// into a single JSON object.
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq)]
pub struct JwtClaims<CustomClaims> {
    #[serde(flatten)]
    pub registered_claims: RegisteredClaims,
    #[serde(flatten)]
    pub custom_claims: CustomClaims,
}
impl<CustomClaims> JwtClaims<CustomClaims> {
    /// Pairs application claims with registered claims.
    pub fn new(custom_claims: CustomClaims, registered_claims: RegisteredClaims) -> Self {
        Self {
            registered_claims,
            custom_claims,
        }
    }

    /// Returns `true` when the `iss` claim equals `issuer` exactly.
    pub fn has_issuer(&self, issuer: &str) -> bool {
        issuer == self.registered_claims.issuer
    }
}
/// Configuration consumed by `JsonWebToken::new_with_options`.
///
/// Holds an optional signing key, the `kid` written into signed headers, and the keys used to
/// verify incoming tokens.
#[derive(Debug, Clone)]
pub struct JsonWebTokenOptions {
    /// Signing key; `None` for verification-only options.
    encoding_key: Option<EncodingKey>,
    /// Key id stamped into the header of tokens signed with `encoding_key`.
    key_id: String,
    /// Verification keys, looked up by the `kid` in an incoming token's header.
    decoding_keys_by_kid: HashMap<String, DecodingKey>,
    /// Key used for tokens whose header carries no `kid`; `None` rejects such tokens.
    fallback_decoding_key: Option<DecodingKey>,
    /// Header template used when signing.
    header: Header,
    /// Validation rules applied when decoding.
    validation: Validation,
}
// Development-only ES384 private key used by `JsonWebTokenOptions::default()` (and therefore by
// `JsonWebToken::default()`).
//
// SECURITY NOTE(review): this private key is committed in source. Anyone with access to the code
// can mint tokens that verify against options built from it. It must never back a production
// codec; load real keys (e.g. via `Es384KeyPairLoader`) or pass them to
// `JsonWebTokenOptions::from_es384_pem` instead.
const DEV_ES384_PRIVATE_KEY_PEM: &[u8] = br#"-----BEGIN PRIVATE KEY-----
MIG2AgEAMBAGByqGSM49AgEGBSuBBAAiBIGeMIGbAgEBBDCFT7MfRqWZfNgVX/cH
bxFTlPkBeCKqjsLkZXD/J3ZYHV1EtQksdrKtOzTr2hMs6pmhZANiAASyND9eQ5Qk
7ZteSEPMpExbVJenRWwyobExJMb62mmp3eA7Fszy8uBbLj8HRB16y3QbLcTxCBoo
ldBXfNFzM133OuTV2bBWXq5h34l+A0h4gU/odZ678LfAgnrRYMG4ZjU=
-----END PRIVATE KEY-----
"#;
// Development-only public key PEM, paired with `DEV_ES384_PRIVATE_KEY_PEM` in
// `JsonWebTokenOptions::default()`.
const DEV_ES384_PUBLIC_KEY_PEM: &[u8] = br#"-----BEGIN PUBLIC KEY-----
MHYwEAYHKoZIzj0CAQYFK4EEACIDYgAEsjQ/XkOUJO2bXkhDzKRMW1SXp0VsMqGx
MSTG+tppqd3gOxbM8vLgWy4/B0Qdest0Gy3E8QgaKJXQV3zRczNd9zrk1dmwVl6u
Yd+JfgNIeIFP6HWeu/C3wIJ60WDBuGY1
-----END PUBLIC KEY-----
"#;
impl Default for JsonWebTokenOptions {
    /// Builds signing-and-verification options from the built-in development key pair.
    ///
    /// # Panics
    /// Panics if the embedded development PEMs cannot be turned into options.
    fn default() -> Self {
        Self::from_es384_pem(DEV_ES384_PRIVATE_KEY_PEM, DEV_ES384_PUBLIC_KEY_PEM).unwrap_or_else(
            |error| panic!("failed to initialize default ES384 JWT options: {error}"),
        )
    }
}
impl JsonWebTokenOptions {
    /// Builds options that can both sign and verify ES384 tokens.
    ///
    /// `private_key_pem` must be an EC private key PEM accepted by `EncodingKey::from_ec_pem`.
    /// The key id is derived from `public_key_pem` via `jwks::es384_kid_from_public_key_pem`.
    /// The public key is registered under that kid and also used as the fallback for tokens
    /// whose header has no `kid`. This constructor does not check that the two keys belong
    /// together.
    ///
    /// # Errors
    /// Returns `Error::Jwt` if either PEM cannot be parsed or the kid cannot be derived. The
    /// private key is parsed first, so its error wins when both are invalid.
    pub fn from_es384_pem(private_key_pem: &[u8], public_key_pem: &[u8]) -> Result<Self> {
        let encoding_key = EncodingKey::from_ec_pem(private_key_pem).map_err(|error| {
            Error::Jwt(JwtError::processing(
                JwtOperation::Encode,
                format!("failed to parse ES384 private key: {error}"),
            ))
        })?;
        // Verification setup is identical to the verification-only constructor; reuse it and
        // attach the signing key.
        let verification_options = Self::for_es384_verification_only(public_key_pem)?;
        Ok(Self {
            encoding_key: Some(encoding_key),
            ..verification_options
        })
    }

    /// Builds verify-only options (no signing key) from a single ES384 public key PEM.
    ///
    /// The key is registered under its derived kid and also used as the missing-`kid`
    /// fallback.
    ///
    /// # Errors
    /// Returns `Error::Jwt` if the PEM cannot be parsed or the kid cannot be derived.
    pub fn for_es384_verification_only(public_key_pem: &[u8]) -> Result<Self> {
        let decoding_key = DecodingKey::from_ec_pem(public_key_pem).map_err(|error| {
            Error::Jwt(JwtError::processing(
                JwtOperation::Decode,
                format!("failed to parse ES384 public key: {error}"),
            ))
        })?;
        let key_id = jwks::es384_kid_from_public_key_pem(public_key_pem)?;
        let header = canonical_es384_header_with_kid(&key_id);
        let validation = validation_with_es384_only();
        let mut decoding_keys_by_kid = HashMap::new();
        decoding_keys_by_kid.insert(key_id.clone(), decoding_key.clone());
        Ok(Self {
            encoding_key: None,
            key_id,
            decoding_keys_by_kid,
            fallback_decoding_key: Some(decoding_key),
            header,
            validation,
        })
    }

    /// Builds verify-only options from a JWKS key list.
    ///
    /// Every key is registered under its `kid` (a later duplicate kid replaces an earlier one).
    /// The first key's kid becomes `key_id`. No missing-`kid` fallback is configured, so tokens
    /// without a `kid` are rejected.
    ///
    /// # Errors
    /// Returns `Error::Jwt` if `keys` is empty or any key cannot be converted.
    pub fn for_es384_jwks_keys(keys: &[jwks::EcP384Jwk]) -> Result<Self> {
        let Some(first_key) = keys.first() else {
            return Err(Error::Jwt(JwtError::processing(
                JwtOperation::Validate,
                "JWKS key set is empty",
            )));
        };
        let mut decoding_keys_by_kid = HashMap::with_capacity(keys.len());
        for key in keys {
            decoding_keys_by_kid.insert(key.kid.clone(), key.to_decoding_key()?);
        }
        let key_id = first_key.kid.clone();
        let header = canonical_es384_header_with_kid(&key_id);
        let validation = validation_with_es384_only();
        Ok(Self {
            encoding_key: None,
            key_id,
            decoding_keys_by_kid,
            fallback_decoding_key: None,
            header,
            validation,
        })
    }

    /// The `kid` written into headers of tokens signed with these options.
    pub fn key_id(&self) -> &str {
        &self.key_id
    }

    /// Number of kid-indexed verification keys.
    pub fn verification_key_count(&self) -> usize {
        self.decoding_keys_by_kid.len()
    }

    /// Whether tokens without a `kid` header can still be verified.
    pub fn allows_missing_kid_fallback(&self) -> bool {
        self.fallback_decoding_key.is_some()
    }

    /// Replaces the signing `kid` and rebuilds the header template.
    ///
    /// The verification map is not re-keyed. If these options must also verify tokens they
    /// sign, register the verification key under the new kid too.
    pub fn with_key_id(mut self, key_id: impl Into<String>) -> Self {
        let key_id = key_id.into();
        self.key_id = key_id.clone();
        self.header = canonical_es384_header_with_kid(&key_id);
        self
    }

    /// Replaces all verification keys.
    ///
    /// When `allow_missing_kid_fallback` is set, tokens without a `kid` are verified with a
    /// deterministic choice. That is the key registered under the current `key_id`, or, if
    /// absent, the key with the lexicographically smallest kid. Otherwise such tokens are
    /// rejected.
    pub fn with_verification_keys(
        mut self,
        keys: HashMap<String, DecodingKey>,
        allow_missing_kid_fallback: bool,
    ) -> Self {
        self.fallback_decoding_key = if allow_missing_kid_fallback {
            Self::preferred_fallback_key(&keys, &self.key_id)
        } else {
            None
        };
        self.decoding_keys_by_kid = keys;
        self
    }

    /// Adds (or replaces) a single verification key under `kid`. The fallback is unchanged.
    pub fn with_added_verification_key(mut self, kid: impl Into<String>, key: DecodingKey) -> Self {
        self.decoding_keys_by_kid.insert(kid.into(), key);
        self
    }

    /// Installs custom validation rules, forcing the algorithm allow-list back to ES384 only.
    pub fn with_validation(self, validation: Validation) -> Self {
        let mut validation = validation;
        validation.algorithms = vec![Algorithm::ES384];
        Self { validation, ..self }
    }

    /// Picks the missing-`kid` fallback key deterministically.
    ///
    /// Prefers `preferred_kid`, else the smallest kid. `HashMap` iteration order is unspecified,
    /// so "first value" would be arbitrary.
    fn preferred_fallback_key(
        keys: &HashMap<String, DecodingKey>,
        preferred_kid: &str,
    ) -> Option<DecodingKey> {
        keys.get(preferred_kid)
            .or_else(|| {
                keys.iter()
                    .min_by_key(|(kid, _)| *kid)
                    .map(|(_, key)| key)
            })
            .cloned()
    }
}
#[derive(Clone)]
pub struct JsonWebToken<P> {
enc_key: Option<EncodingKey>,
key_id: String,
verification_state: std::sync::Arc<RwLock<VerificationState>>,
header: Header,
validation: Validation,
phantom_payload: PhantomData<P>,
}
#[derive(Clone)]
struct VerificationState {
dec_keys_by_kid: HashMap<String, DecodingKey>,
fallback_dec_key: Option<DecodingKey>,
}
impl<P> JsonWebToken<P> {
pub fn new_with_options(options: JsonWebTokenOptions) -> Self {
let JsonWebTokenOptions {
encoding_key,
key_id,
decoding_keys_by_kid,
fallback_decoding_key,
header,
validation,
} = options;
Self {
enc_key: encoding_key,
key_id,
verification_state: std::sync::Arc::new(RwLock::new(VerificationState {
dec_keys_by_kid: decoding_keys_by_kid,
fallback_dec_key: fallback_decoding_key,
})),
header,
validation,
phantom_payload: PhantomData,
}
}
pub fn key_id(&self) -> &str {
&self.key_id
}
pub fn verification_key_count(&self) -> usize {
self.verification_state
.read()
.map(|state| state.dec_keys_by_kid.len())
.unwrap_or(0)
}
pub fn has_verification_key(&self, kid: &str) -> bool {
self.verification_state
.read()
.map(|state| state.dec_keys_by_kid.contains_key(kid))
.unwrap_or(false)
}
pub fn allows_missing_kid_fallback(&self) -> bool {
self.verification_state
.read()
.map(|state| state.fallback_dec_key.is_some())
.unwrap_or(false)
}
pub fn replace_verification_keys(
&self,
keys: HashMap<String, DecodingKey>,
allow_missing_kid_fallback: bool,
) {
if let Ok(mut state) = self.verification_state.write() {
let fallback = if allow_missing_kid_fallback {
keys.values().next().cloned()
} else {
None
};
state.dec_keys_by_kid = keys;
state.fallback_dec_key = fallback;
}
}
pub fn replace_verification_keys_from_jwks(
&self,
keys: &[jwks::EcP384Jwk],
allow_missing_kid_fallback: bool,
) -> Result<()> {
let mut decoding_keys = HashMap::new();
for key in keys {
let decoding_key = key.to_decoding_key()?;
decoding_keys.insert(key.kid.clone(), decoding_key);
}
self.replace_verification_keys(decoding_keys, allow_missing_kid_fallback);
Ok(())
}
fn decoding_key_for_header(&self, header: &Header) -> Result<DecodingKey> {
if header.alg != Algorithm::ES384 {
return Err(Error::Jwt(JwtError::processing(
JwtOperation::Validate,
format!(
"JWT header algorithm mismatch: expected ES384 but got {:?}",
header.alg
),
)));
}
if header.typ.as_deref() != Some("JWT") {
return Err(Error::Jwt(JwtError::processing(
JwtOperation::Validate,
"JWT header `typ` must be `JWT`",
)));
}
let state = self.verification_state.read().map_err(|_| {
Error::Jwt(JwtError::processing(
JwtOperation::Validate,
"verification key state lock is poisoned",
))
})?;
if let Some(kid) = header.kid.as_deref() {
return state.dec_keys_by_kid.get(kid).cloned().ok_or_else(|| {
Error::Jwt(JwtError::processing(
JwtOperation::Validate,
format!("JWT `kid` `{kid}` is not configured for verification"),
))
});
}
state.fallback_dec_key.clone().ok_or_else(|| {
Error::Jwt(JwtError::processing(
JwtOperation::Validate,
"JWT header is missing `kid` and no fallback verification key is configured",
))
})
}
}
impl<P> Default for JsonWebToken<P> {
    /// Builds a codec from `JsonWebTokenOptions::default()`, i.e. the development key pair.
    fn default() -> Self {
        let options = JsonWebTokenOptions::default();
        Self::new_with_options(options)
    }
}
impl<P> Codec for JsonWebToken<P>
where
P: Serialize + DeserializeOwned + Clone,
{
type Payload = P;
fn encode(&self, payload: &Self::Payload) -> Result<Vec<u8>> {
let Some(enc_key) = &self.enc_key else {
return Err(Error::Jwt(JwtError::processing(
JwtOperation::Encode,
"JWT encoding key is not configured for this codec",
)));
};
let token = jsonwebtoken::encode(&self.header, payload, enc_key).map_err(|error| {
Error::Jwt(JwtError::processing(
JwtOperation::Encode,
format!("JWT encoding failed: {error}"),
))
})?;
Ok(token.into_bytes())
}
fn decode(&self, encoded_value: &[u8]) -> Result<Self::Payload> {
let header = decode_header(std::str::from_utf8(encoded_value).map_err(|error| {
Error::Jwt(JwtError::processing(
JwtOperation::Decode,
format!("JWT bytes are not valid UTF-8: {error}"),
))
})?)
.map_err(|error| {
Error::Jwt(JwtError::processing_with_preview(
JwtOperation::Decode,
format!("JWT header decoding failed: {error}"),
Some(format!("token_len={}", encoded_value.len())),
))
})?;
let decoding_key = self.decoding_key_for_header(&header)?;
let claims =
jsonwebtoken::decode::<Self::Payload>(encoded_value, &decoding_key, &self.validation)
.map_err(|error| {
Error::Jwt(JwtError::processing_with_preview(
JwtOperation::Decode,
format!("JWT decoding failed: {error}"),
Some(format!("token_len={}", encoded_value.len())),
))
})?;
Ok(claims.claims)
}
}