use base64::Engine;
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use crate::common::identity::{AnySignature, AnySignatureError, AnySigningKey, AnyVerifyingKey};
/// JOSE header of a compact JWT handled by this module.
///
/// serde serializes fields in declaration order, and the serialized header is
/// part of the signed bytes, so keep `alg` before `typ`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct JwtHeader {
    /// JWS algorithm name; `verify_compact` only accepts `"ES256"` or `"ES256K"`.
    pub alg: String,
    /// Token type; [`JwtHeader::for_signing_key`] sets this to `"JWT"`.
    pub typ: String,
}

impl JwtHeader {
    /// Builds the header for a token that `key` will sign.
    ///
    /// `alg` is taken from the key's `jwt_alg()` so header and signature agree;
    /// `typ` is always `"JWT"`.
    pub fn for_signing_key(key: &AnySigningKey) -> Self {
        let alg = key.jwt_alg().to_string();
        let typ = String::from("JWT");
        Self { alg, typ }
    }
}
/// Claims payload of a compact JWT as encoded/decoded by this module.
///
/// serde serializes fields in declaration order, so this order determines the
/// exact bytes of the claims segment that gets signed; do not reorder fields.
/// `verify_compact` in this file checks only the signature — none of these
/// claims are validated there, so callers must check them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JwtClaims {
    /// Issuer (`iss`); this file's tests use DIDs such as `did:web:example.com`.
    pub iss: String,
    /// Audience (`aud`); this file's tests use `did:plc:test`.
    pub aud: String,
    /// Expiration time (`exp`); presumably seconds since the Unix epoch
    /// (tests use `2000000000`) — TODO confirm with issuers/consumers.
    pub exp: i64,
    /// Issued-at time (`iat`); presumably seconds since the Unix epoch
    /// (tests use `1700000000`) — TODO confirm with issuers/consumers.
    pub iat: i64,
    /// Presumably the lexicon method the token is scoped to; tests use
    /// `com.atproto.moderation.createReport` — confirm against callers.
    pub lxm: String,
    /// Token identifier (`jti`); presumably a per-token nonce — verify against issuers.
    pub jti: String,
}
#[derive(Debug, Error)]
pub enum JwtError {
#[error("malformed compact JWT: expected three segments")]
MalformedCompact,
#[error("base64url decode failed for {segment}")]
Base64Decode {
segment: &'static str,
#[source]
source: base64::DecodeError,
},
#[error("JSON decode failed for {segment}")]
JsonDecode {
segment: &'static str,
#[source]
source: serde_json::Error,
},
#[error("JSON encode failed")]
JsonEncode(serde_json::Error),
#[error("signature was {actual} bytes; expected 64")]
SignatureLength {
actual: usize,
},
#[error("signature has invalid scalar values")]
InvalidSignatureScalar,
#[error("unsupported JWT alg `{alg}` (expected ES256 or ES256K)")]
UnsupportedAlg {
alg: String,
},
#[error("signature verification failed")]
SignatureVerify(#[from] AnySignatureError),
}
pub fn encode_compact(
header: &JwtHeader,
claims: &JwtClaims,
signer: &AnySigningKey,
) -> Result<String, JwtError> {
let header_json = serde_json::to_vec(header).map_err(JwtError::JsonEncode)?;
let claims_json = serde_json::to_vec(claims).map_err(JwtError::JsonEncode)?;
let header_b64 = URL_SAFE_NO_PAD.encode(&header_json);
let claims_b64 = URL_SAFE_NO_PAD.encode(&claims_json);
let signing_input = format!("{header_b64}.{claims_b64}");
let sig = signer.sign(signing_input.as_bytes());
let sig_bytes = sig.to_jws_bytes();
let sig_b64 = URL_SAFE_NO_PAD.encode(sig_bytes);
Ok(format!("{header_b64}.{claims_b64}.{sig_b64}"))
}
pub fn decode_compact(token: &str) -> Result<(JwtHeader, JwtClaims, Vec<u8>), JwtError> {
let parts: Vec<&str> = token.split('.').collect();
if parts.len() != 3 {
return Err(JwtError::MalformedCompact);
}
let header_b64 = parts[0];
let claims_b64 = parts[1];
let sig_b64 = parts[2];
let header_bytes =
URL_SAFE_NO_PAD
.decode(header_b64)
.map_err(|source| JwtError::Base64Decode {
segment: "header",
source,
})?;
let claims_bytes =
URL_SAFE_NO_PAD
.decode(claims_b64)
.map_err(|source| JwtError::Base64Decode {
segment: "claims",
source,
})?;
let sig_bytes = URL_SAFE_NO_PAD
.decode(sig_b64)
.map_err(|source| JwtError::Base64Decode {
segment: "signature",
source,
})?;
let header: JwtHeader =
serde_json::from_slice(&header_bytes).map_err(|source| JwtError::JsonDecode {
segment: "header",
source,
})?;
let claims: JwtClaims =
serde_json::from_slice(&claims_bytes).map_err(|source| JwtError::JsonDecode {
segment: "claims",
source,
})?;
Ok((header, claims, sig_bytes))
}
/// Decodes `token` and verifies its signature against `vkey`.
///
/// Checks run in this order, and the first failure is returned:
/// 1. decoding via [`decode_compact`];
/// 2. the header `alg` must equal the algorithm implied by the key's curve
///    (`ES256K` for K256, `ES256` for P256). The header never selects the curve;
/// 3. the decoded signature must be exactly 64 bytes;
/// 4. those bytes must parse as a signature on the key's curve;
/// 5. the SHA-256 digest of the `header.claims` signing input must verify under `vkey`.
///
/// Only the signature is authenticated. `exp`, `iat`, `iss`, `aud`, `lxm` and
/// `jti` are returned without any validation; callers must check them.
///
/// # Errors
///
/// Any error from [`decode_compact`], then [`JwtError::UnsupportedAlg`],
/// [`JwtError::SignatureLength`], [`JwtError::InvalidSignatureScalar`], or
/// [`JwtError::SignatureVerify`], matching the steps above.
pub fn verify_compact(
    token: &str,
    vkey: &AnyVerifyingKey,
) -> Result<(JwtHeader, JwtClaims), JwtError> {
    use sha2::{Digest, Sha256};

    let (header, claims, sig_bytes) = decode_compact(token)?;

    // The key, not the attacker-controlled header, decides the algorithm.
    let required_alg = match vkey {
        AnyVerifyingKey::K256(_) => "ES256K",
        AnyVerifyingKey::P256(_) => "ES256",
    };
    if header.alg != required_alg {
        return Err(JwtError::UnsupportedAlg { alg: header.alg });
    }

    let raw_sig: [u8; 64] = sig_bytes
        .as_slice()
        .try_into()
        .map_err(|_| JwtError::SignatureLength {
            actual: sig_bytes.len(),
        })?;

    let signature = match vkey {
        AnyVerifyingKey::K256(_) => k256::ecdsa::Signature::from_bytes(&raw_sig.into())
            .map(AnySignature::K256)
            .map_err(|_| JwtError::InvalidSignatureScalar)?,
        AnyVerifyingKey::P256(_) => p256::ecdsa::Signature::from_bytes(&raw_sig.into())
            .map(AnySignature::P256)
            .map_err(|_| JwtError::InvalidSignatureScalar)?,
    };

    // The signing input is the original token text up to (not including) the last dot.
    let (signing_input, _) = token
        .rsplit_once('.')
        .expect("three-segment token has a last dot");
    let digest: [u8; 32] = Sha256::digest(signing_input.as_bytes()).into();
    vkey.verify_prehash(&digest, &signature)?;
    Ok((header, claims))
}
#[cfg(test)]
mod tests {
    use super::*;
    use k256::ecdsa::SigningKey as K256SigningKey;
    use p256::ecdsa::SigningKey as P256SigningKey;

    // Deterministic secp256k1 key shared by the K256 tests.
    fn k256_key() -> AnySigningKey {
        AnySigningKey::K256(K256SigningKey::from_slice(&[1u8; 32]).expect("valid seed"))
    }

    // Deterministic P-256 key shared by the P256 tests.
    fn p256_key() -> AnySigningKey {
        AnySigningKey::P256(P256SigningKey::from_slice(&[2u8; 32]).expect("valid seed"))
    }

    // Claims fixture; only `iss` and `jti` vary between tests.
    fn sample_claims(iss: &str, jti: &str) -> JwtClaims {
        JwtClaims {
            iss: iss.to_string(),
            aud: "did:plc:test".to_string(),
            exp: 2000000000,
            iat: 1700000000,
            lxm: "com.atproto.moderation.createReport".to_string(),
            jti: jti.to_string(),
        }
    }

    // Encodes `claims` with the header derived from `key`, signed by `key`.
    fn signed_token(key: &AnySigningKey, claims: &JwtClaims) -> String {
        encode_compact(&JwtHeader::for_signing_key(key), claims, key).expect("encode succeeds")
    }

    // Splits a token into exactly three segments, failing the test otherwise.
    fn segments(token: &str) -> [&str; 3] {
        let parts: Vec<&str> = token.split('.').collect();
        assert_eq!(parts.len(), 3, "token should have three segments: {token}");
        [parts[0], parts[1], parts[2]]
    }

    // `JwtClaims` does not implement `PartialEq`, so compare field by field.
    fn assert_claims_eq(actual: &JwtClaims, expected: &JwtClaims) {
        assert_eq!(actual.iss, expected.iss);
        assert_eq!(actual.aud, expected.aud);
        assert_eq!(actual.exp, expected.exp);
        assert_eq!(actual.iat, expected.iat);
        assert_eq!(actual.lxm, expected.lxm);
        assert_eq!(actual.jti, expected.jti);
    }

    // Replaces the signature with 64 zero bytes (r = s = 0) and expects the
    // scalar parse to reject it before any verification is attempted.
    fn assert_zero_signature_rejected(key: &AnySigningKey, claims: &JwtClaims) {
        let token = signed_token(key, claims);
        let [header_b64, claims_b64, _] = segments(&token);
        let zero_sig = URL_SAFE_NO_PAD.encode([0u8; 64]);
        let tampered = format!("{header_b64}.{claims_b64}.{zero_sig}");
        let result = verify_compact(&tampered, &key.verifying_key());
        assert!(matches!(result, Err(JwtError::InvalidSignatureScalar)));
    }

    #[test]
    fn encode_decode_roundtrip_k256() {
        let key = k256_key();
        let claims = sample_claims("did:web:127.0.0.1%3A5000", "0123456789abcdef");
        let token = signed_token(&key, &claims);
        let (header, decoded) =
            verify_compact(&token, &key.verifying_key()).expect("verify succeeds");
        assert_eq!(header.alg, "ES256K");
        assert_eq!(header.typ, "JWT");
        assert_claims_eq(&decoded, &claims);
    }

    #[test]
    fn encode_decode_roundtrip_p256() {
        let key = p256_key();
        let claims = sample_claims("did:web:example.com", "fedcba9876543210");
        let token = signed_token(&key, &claims);
        let (header, decoded) =
            verify_compact(&token, &key.verifying_key()).expect("verify succeeds");
        assert_eq!(header.alg, "ES256");
        assert_eq!(header.typ, "JWT");
        assert_claims_eq(&decoded, &claims);
    }

    #[test]
    fn encode_decode_roundtrip_tampered_claims_fails() {
        // The forged claims segment is well-formed JSON, so decoding succeeds
        // and the rejection must come from signature verification itself.
        let key = k256_key();
        let claims = sample_claims("did:web:127.0.0.1%3A5000", "0123456789abcdef");
        let token = signed_token(&key, &claims);
        let [header_b64, _, sig_b64] = segments(&token);
        let mut forged = claims.clone();
        forged.aud = "did:plc:attacker".to_string();
        let forged_b64 =
            URL_SAFE_NO_PAD.encode(serde_json::to_vec(&forged).expect("claims serialize"));
        let tampered = format!("{header_b64}.{forged_b64}.{sig_b64}");
        let result = verify_compact(&tampered, &key.verifying_key());
        assert!(matches!(result, Err(JwtError::SignatureVerify(_))));
    }

    #[test]
    fn verify_compact_non_json_claims_fails() {
        let key = k256_key();
        let token = signed_token(
            &key,
            &sample_claims("did:web:127.0.0.1%3A5000", "0123456789abcdef"),
        );
        let [header_b64, _, sig_b64] = segments(&token);
        // "YWJj" is base64url for "abc", which is not JSON.
        let tampered = format!("{header_b64}.YWJj.{sig_b64}");
        let result = verify_compact(&tampered, &key.verifying_key());
        assert!(matches!(
            result,
            Err(JwtError::JsonDecode {
                segment: "claims",
                ..
            })
        ));
    }

    #[test]
    fn decode_compact_malformed_empty() {
        assert!(matches!(decode_compact(""), Err(JwtError::MalformedCompact)));
    }

    #[test]
    fn decode_compact_malformed_two_segments() {
        let result = decode_compact("header.claims");
        assert!(matches!(result, Err(JwtError::MalformedCompact)));
    }

    #[test]
    fn decode_compact_malformed_four_segments() {
        let result = decode_compact("YQ.Yg.Yw.ZA");
        assert!(matches!(result, Err(JwtError::MalformedCompact)));
    }

    #[test]
    fn decode_compact_invalid_base64() {
        let result = decode_compact("!!!.claims.sig");
        assert!(matches!(
            result,
            Err(JwtError::Base64Decode {
                segment: "header",
                ..
            })
        ));
    }

    #[test]
    fn verify_compact_curve_mismatch() {
        let token = signed_token(
            &k256_key(),
            &sample_claims("did:web:test", "0123456789abcdef"),
        );
        let result = verify_compact(&token, &p256_key().verifying_key());
        assert!(matches!(
            result,
            Err(JwtError::UnsupportedAlg { ref alg }) if alg == "ES256K"
        ));
    }

    #[test]
    fn verify_compact_wrong_signature_length() {
        let key = k256_key();
        let token = signed_token(&key, &sample_claims("did:web:test", "0123456789abcdef"));
        let [header_b64, claims_b64, _] = segments(&token);
        let short_sig = URL_SAFE_NO_PAD.encode([0u8; 63]);
        let tampered = format!("{header_b64}.{claims_b64}.{short_sig}");
        let result = verify_compact(&tampered, &key.verifying_key());
        assert!(matches!(
            result,
            Err(JwtError::SignatureLength { actual: 63 })
        ));
    }

    #[test]
    fn encode_compact_produces_valid_structure() {
        let token = signed_token(
            &k256_key(),
            &sample_claims("did:web:test", "0123456789abcdef"),
        );
        let parts = segments(&token);
        for (segment_name, segment) in ["header", "claims", "signature"].iter().zip(parts.iter()) {
            let result = URL_SAFE_NO_PAD.decode(segment);
            assert!(
                result.is_ok(),
                "segment {segment_name} failed to decode as base64url"
            );
        }
        let sig = URL_SAFE_NO_PAD.decode(parts[2]).expect("signature decodes");
        assert_eq!(sig.len(), 64);
    }

    #[test]
    fn verify_compact_invalid_signature_scalar_k256() {
        assert_zero_signature_rejected(
            &k256_key(),
            &sample_claims("did:web:127.0.0.1%3A5000", "0123456789abcdef"),
        );
    }

    #[test]
    fn verify_compact_invalid_signature_scalar_p256() {
        assert_zero_signature_rejected(
            &p256_key(),
            &sample_claims("did:web:example.com", "fedcba9876543210"),
        );
    }
}