use base64::{engine::general_purpose, Engine as _};
use serde::{Deserialize, Serialize};
use std::fmt;
use std::str::FromStr;
use url::Url;
use crate::error::JwtError;
use crate::jws::Jws;
use crate::traits::JwsVerifiable;
use base64urlsafedata::Base64UrlSafeData;
/// A JSON Web Key Set (RFC 7517 §5): a JSON object whose `keys` member is an
/// array of JWKs.
#[derive(Debug, Serialize, Clone, Deserialize)]
pub struct JwkKeySet {
    /// The keys contained in the set.
    pub keys: Vec<Jwk>,
}
/// Elliptic curves accepted in the `crv` member of an EC [`Jwk`].
#[derive(Debug, Serialize, Clone, Deserialize, PartialEq)]
#[allow(non_camel_case_types)]
pub enum EcCurve {
    /// The P-256 curve, serialised as `"P-256"`.
    #[serde(rename = "P-256")]
    P256,
}
/// A JSON Web Key (RFC 7517). Only public-key members are modelled.
///
/// Serde uses the `kty` member as an internal tag, so the variant names are the
/// on-the-wire key types: `"EC"` and `"RSA"`. Optional members are omitted from
/// serialised output when they are `None`.
#[derive(Debug, Serialize, Clone, Deserialize, PartialEq)]
#[allow(non_camel_case_types)]
#[serde(tag = "kty")]
pub enum Jwk {
    /// Elliptic-curve public key parameters (RFC 7518 §6.2.1).
    EC {
        /// Curve of the key (`crv`).
        crv: EcCurve,
        /// The `x` coordinate of the public point.
        x: Base64UrlSafeData,
        /// The `y` coordinate of the public point.
        y: Base64UrlSafeData,
        /// Algorithm the key is intended for (`alg`).
        #[serde(skip_serializing_if = "Option::is_none")]
        alg: Option<JwaAlg>,
        /// Intended key use (`use`); named `use_` because `use` is a Rust keyword.
        #[serde(rename = "use", skip_serializing_if = "Option::is_none")]
        use_: Option<JwkUse>,
        /// Key identifier (`kid`).
        #[serde(skip_serializing_if = "Option::is_none")]
        kid: Option<String>,
    },
    /// RSA public key parameters (RFC 7518 §6.3.1).
    RSA {
        /// Modulus (`n`).
        n: Base64UrlSafeData,
        /// Public exponent (`e`).
        e: Base64UrlSafeData,
        /// Algorithm the key is intended for (`alg`).
        #[serde(skip_serializing_if = "Option::is_none")]
        alg: Option<JwaAlg>,
        /// Intended key use (`use`); named `use_` because `use` is a Rust keyword.
        #[serde(rename = "use", skip_serializing_if = "Option::is_none")]
        use_: Option<JwkUse>,
        /// Key identifier (`kid`).
        #[serde(skip_serializing_if = "Option::is_none")]
        kid: Option<String>,
    },
}
/// Value of a JWK's `use` member, serialised in snake case as `"sig"` or `"enc"`.
#[derive(Debug, Serialize, Clone, Deserialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum JwkUse {
    /// The key is used for signatures.
    Sig,
    /// The key is used for encryption.
    Enc,
}
/// JWS signing algorithms (RFC 7518 §3.1), serialised by variant name
/// (e.g. `"ES256"`). `Default` is `HS256`.
#[derive(Debug, Serialize, Copy, Clone, Deserialize, PartialEq, Default)]
#[allow(non_camel_case_types)]
pub enum JwaAlg {
    /// ECDSA using P-256 and SHA-256.
    ES256,
    /// RSASSA-PKCS1-v1_5 using SHA-256.
    RS256,
    /// HMAC using SHA-256.
    #[default]
    HS256,
}
/// The JOSE protected header of a JWS (RFC 7515 §4.1).
///
/// Optional members are omitted when serialising if they are `None`. The X.509
/// URL/thumbprint members are modelled as `Option<()>` and marked
/// `skip_deserializing`. They are therefore never read from input and remain
/// `None` after deserialising.
#[derive(Debug, Serialize, Clone, Deserialize, Default, PartialEq)]
pub struct ProtectedHeader {
    /// Signing algorithm (`alg`).
    pub(crate) alg: JwaAlg,
    /// URL of a JWK Set containing the signing key (`jku`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub(crate) jku: Option<Url>,
    /// Public key embedded directly in the header (`jwk`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub(crate) jwk: Option<Jwk>,
    /// Key identifier (`kid`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub(crate) kid: Option<String>,
    /// Critical extensions (`crit`). Parsing a compact JWS rejects any non-empty list.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub(crate) crit: Option<Vec<String>>,
    /// Media type of the complete token (`typ`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub(crate) typ: Option<String>,
    /// Content type of the payload (`cty`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub(crate) cty: Option<String>,
    /// X.509 URL (`x5u`): unsupported, never deserialised.
    #[serde(skip_deserializing, skip_serializing_if = "Option::is_none")]
    pub(crate) x5u: Option<()>,
    /// X.509 certificate chain (`x5c`).
    /// NOTE(review): presumably base64 DER certificates per RFC 7515 §4.1.6 — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub(crate) x5c: Option<Vec<String>>,
    /// X.509 SHA-1 thumbprint (`x5t`): unsupported, never deserialised.
    #[serde(skip_deserializing, skip_serializing_if = "Option::is_none")]
    pub(crate) x5t: Option<()>,
    /// X.509 SHA-256 thumbprint (`x5t#S256`): unsupported, never deserialised.
    #[serde(
        skip_deserializing,
        rename = "x5t#S256",
        skip_serializing_if = "Option::is_none"
    )]
    pub(crate) x5t_s256: Option<()>,
    /// Non-standard `ctx` member.
    /// NOTE(review): its meaning is not established in this file — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub(crate) ctx: Option<String>,
    /// Non-standard `use` header member, only present with the `msextensions` feature.
    /// NOTE(review): presumably a Microsoft-specific extension — confirm.
    #[cfg(feature = "msextensions")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub(crate) r#use: Option<String>,
}
/// A JWS in compact serialisation (`header.payload.signature`), parsed from text.
///
/// Parsing does not check the signature. Verification goes through the
/// `JwsVerifiable` implementation for this type.
#[derive(Clone)]
pub struct JwsCompact {
    /// Decoded protected header.
    pub(crate) header: ProtectedHeader,
    /// The header exactly as it appeared in the input (base64url text). It is kept
    /// verbatim because verification is handed these original bytes.
    pub(crate) hdr_b64: String,
    /// The payload as base64url text; it is not decoded at parse time.
    pub(crate) payload_b64: String,
    /// Decoded signature bytes.
    pub(crate) signature: Vec<u8>,
}
impl fmt::Debug for JwsCompact {
    /// Shows the parsed header and the still-encoded payload.
    ///
    /// The raw header text and the signature bytes are not printed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("JwsCompact");
        out.field("header", &self.header);
        out.field("payload", &self.payload_b64);
        out.finish()
    }
}
impl JwsCompact {
    /// Returns the `jku` (JWK Set URL) header member, if present.
    ///
    /// The value comes from the parsed header, which may not have been verified
    /// yet; treat it as untrusted.
    pub fn get_jwk_pubkey_url(&self) -> Option<&Url> {
        let hdr = &self.header;
        hdr.jku.as_ref()
    }

    /// Returns the `jwk` (embedded public key) header member, if present.
    ///
    /// The value comes from the parsed header, which may not have been verified
    /// yet; treat it as untrusted.
    pub fn get_jwk_pubkey(&self) -> Option<&Jwk> {
        let hdr = &self.header;
        hdr.jwk.as_ref()
    }
}
impl FromStr for JwsCompact {
    type Err = JwtError;

    /// Parse a JWS in compact serialisation (RFC 7515 §7.1).
    ///
    /// The input has the form `header.payload.signature`, where each part is
    /// unpadded URL-safe base64.
    ///
    /// - The header is decoded and parsed as JSON, and any non-empty `crit` list is
    ///   rejected.
    /// - The payload is kept in its original base64 text form.
    /// - The signature is decoded to bytes.
    ///
    /// # Errors
    ///
    /// * `JwtError::InvalidCompactFormat` if the input does not consist of exactly
    ///   three `.`-separated parts.
    /// * `JwtError::InvalidBase64` if the header or signature is not valid base64.
    /// * `JwtError::InvalidHeaderFormat` if the header is not valid JSON for
    ///   `ProtectedHeader`.
    /// * `JwtError::CriticalExtension` if the header lists critical extensions.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // Use an unbounded `split`. With `splitn(3, ..)` the third item swallows every
        // remaining `.`, so the extra-fields check below could never fire. Inputs with
        // too many parts were then misreported as invalid base64 in the signature.
        let mut siter = s.split('.');

        let hdr_str = siter.next().ok_or_else(|| {
            debug!("invalid compact format - protected header not present");
            JwtError::InvalidCompactFormat
        })?;

        let header: ProtectedHeader = general_purpose::URL_SAFE_NO_PAD
            .decode(hdr_str)
            .map_err(|_| {
                debug!("invalid base64 while decoding header");
                JwtError::InvalidBase64
            })
            .and_then(|bytes| {
                serde_json::from_slice(&bytes).map_err(|e| {
                    debug!(?e, "invalid header format - invalid json");
                    JwtError::InvalidHeaderFormat
                })
            })?;

        // Keep the header text verbatim; verification is given these exact bytes.
        let hdr_b64 = hdr_str.to_string();

        // No critical extensions are supported, so any listed one must be rejected.
        if let Some(crit) = &header.crit {
            if !crit.is_empty() {
                error!("critical extension - unable to process critical extensions");
                return Err(JwtError::CriticalExtension);
            }
        }

        let payload_str = siter.next().ok_or_else(|| {
            debug!("invalid compact format - payload not present");
            JwtError::InvalidCompactFormat
        })?;

        let sig_str = siter.next().ok_or_else(|| {
            debug!("invalid compact format - signature not present");
            JwtError::InvalidCompactFormat
        })?;

        if siter.next().is_some() {
            debug!("invalid compact format - extra fields present");
            return Err(JwtError::InvalidCompactFormat);
        }

        let payload_b64 = payload_str.to_string();

        let signature = general_purpose::URL_SAFE_NO_PAD
            .decode(sig_str)
            .map_err(|_| {
                debug!("invalid base64 when decoding signature");
                JwtError::InvalidBase64
            })?;

        Ok(JwsCompact {
            header,
            hdr_b64,
            payload_b64,
            signature,
        })
    }
}
impl fmt::Display for JwsCompact {
    /// Renders the compact form `header.payload.signature`.
    ///
    /// The header and payload are emitted exactly as they were parsed. The signature
    /// is re-encoded as unpadded URL-safe base64.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let signature_b64 = general_purpose::URL_SAFE_NO_PAD.encode(&self.signature);
        f.write_str(&self.hdr_b64)?;
        f.write_str(".")?;
        f.write_str(&self.payload_b64)?;
        f.write_str(".")?;
        f.write_str(&signature_b64)
    }
}
impl Serialize for JwsCompact {
    /// Serialises as a single string holding the compact form produced by `Display`.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        serializer.serialize_str(self.to_string().as_str())
    }
}
/// Serde visitor that builds a [`JwsCompact`] from its compact string form.
struct JwsCompactVisitor;

impl<'de> serde::de::Visitor<'de> for JwsCompactVisitor {
    type Value = JwsCompact;

    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
        formatter.write_str("a compact JWS which consists of three base64 url safe unpadded strings separated with '.'")
    }

    /// Parses the string with `FromStr`. Any parse failure is reported as an invalid
    /// string value; the specific `JwtError` is not propagated.
    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
    where
        E: serde::de::Error,
    {
        match v.parse::<JwsCompact>() {
            Ok(jws) => Ok(jws),
            Err(_) => Err(E::invalid_value(serde::de::Unexpected::Str(v), &self)),
        }
    }
}
impl<'de> Deserialize<'de> for JwsCompact {
    /// Deserialises from a string containing the compact serialisation.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        let visitor = JwsCompactVisitor;
        deserializer.deserialize_str(visitor)
    }
}
impl JwsVerifiable for JwsCompact {
type Verified = Jws;
fn data(&self) -> JwsCompactVerifyData {
JwsCompactVerifyData {
header: &self.header,
hdr_bytes: self.hdr_b64.as_bytes(),
payload_bytes: self.payload_b64.as_bytes(),
signature_bytes: self.signature.as_slice(),
}
}
fn alg(&self) -> JwaAlg {
self.header.alg
}
fn kid(&self) -> Option<&str> {
self.header.kid.as_deref()
}
fn post_process(&self, value: Jws) -> Result<Self::Verified, JwtError> {
Ok(value)
}
}
/// Borrowed view of a [`JwsCompact`] that is handed to signature verification.
///
/// Every field carries `#[allow(dead_code)]`. In this file the only reader is
/// `release`, which is compiled only with the `unsafe_release_without_verify` or
/// `openssl` feature, so without those features the fields could appear unused.
/// NOTE(review): other verifiers in the crate may also read them — confirm.
pub struct JwsCompactVerifyData<'a> {
    /// The decoded protected header.
    #[allow(dead_code)]
    pub(crate) header: &'a ProtectedHeader,
    /// The header's original base64url text, as bytes.
    #[allow(dead_code)]
    pub(crate) hdr_bytes: &'a [u8],
    /// The payload's base64url text, as bytes (still encoded).
    #[allow(dead_code)]
    pub(crate) payload_bytes: &'a [u8],
    /// The decoded signature.
    #[allow(dead_code)]
    pub(crate) signature_bytes: &'a [u8],
}
#[cfg(any(feature = "unsafe_release_without_verify", feature = "openssl"))]
impl JwsCompactVerifyData<'_> {
    /// Decodes the payload and pairs it with a clone of the header to form a [`Jws`].
    ///
    /// This method does not check the signature itself.
    ///
    /// # Errors
    ///
    /// Returns `JwtError::InvalidBase64` if the payload is not unpadded URL-safe
    /// base64.
    pub(crate) fn release(&self) -> Result<Jws, JwtError> {
        let payload = general_purpose::URL_SAFE_NO_PAD
            .decode(self.payload_bytes)
            .map_err(|_| {
                debug!("invalid base64 while decoding payload");
                JwtError::InvalidBase64
            })?;
        let header = self.header.clone();
        Ok(Jws { header, payload })
    }
}
/// JWE key-management algorithms (RFC 7518 §4.1). `Default` is `A128KW`.
///
/// `#[allow(non_camel_case_types)]` is needed because several variant names contain
/// underscores; their on-the-wire names are set with `#[serde(rename)]`.
#[derive(Debug, Serialize, Copy, Clone, Deserialize, PartialEq, Default)]
#[allow(non_camel_case_types)]
pub enum JweAlg {
    /// AES Key Wrap with a 128-bit key (`"A128KW"`).
    #[default]
    A128KW,
    /// AES Key Wrap with a 256-bit key (`"A256KW"`).
    A256KW,
    /// ECDH-ES key agreement with the CEK wrapped by A128KW (`"ECDH-ES+A128KW"`).
    #[serde(rename = "ECDH-ES+A128KW")]
    ECDH_ES_A128KW,
    /// RSAES-OAEP key encryption (`"RSA-OAEP"`).
    #[serde(rename = "RSA-OAEP")]
    RSA_OAEP,
    /// Direct use of a shared symmetric key as the CEK (`"dir"`).
    #[serde(rename = "dir")]
    DIRECT,
}
/// JWE content-encryption algorithms (RFC 7518 §5.1). `Default` is `A128GCM`.
///
/// `#[allow(non_camel_case_types)]` is needed for the underscored `A128CBC_HS256`.
#[derive(Debug, Serialize, Copy, Clone, Deserialize, PartialEq, Default)]
#[allow(non_camel_case_types)]
pub enum JweEnc {
    /// AES-GCM with a 128-bit key (`"A128GCM"`).
    #[default]
    A128GCM,
    /// AES-GCM with a 256-bit key (`"A256GCM"`).
    A256GCM,
    /// AES-128-CBC with HMAC-SHA-256 (`"A128CBC-HS256"`).
    #[serde(rename = "A128CBC-HS256")]
    A128CBC_HS256,
}
/// The JOSE protected header of a JWE (RFC 7516 §4.1).
///
/// Optional members are omitted when serialising if they are `None`. The X.509
/// URL/thumbprint members are modelled as `Option<()>` and marked
/// `skip_deserializing`. They are therefore never read from input and remain
/// `None` after deserialising.
#[derive(Debug, Serialize, Clone, Deserialize, Default, PartialEq)]
pub struct JweProtectedHeader {
    /// Key-management algorithm (`alg`).
    pub(crate) alg: JweAlg,
    /// Content-encryption algorithm (`enc`).
    pub(crate) enc: JweEnc,
    /// Ephemeral public key (`epk`), used by ECDH-ES key agreement.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub(crate) epk: Option<Jwk>,
    /// URL of a JWK Set containing the key (`jku`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub(crate) jku: Option<Url>,
    /// Public key embedded directly in the header (`jwk`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub(crate) jwk: Option<Jwk>,
    /// Key identifier (`kid`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub(crate) kid: Option<String>,
    /// Critical extensions (`crit`). Parsing a compact JWE rejects any non-empty list.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub(crate) crit: Option<Vec<String>>,
    /// Media type of the complete token (`typ`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub(crate) typ: Option<String>,
    /// Content type of the encrypted payload (`cty`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub(crate) cty: Option<String>,
    /// X.509 URL (`x5u`): unsupported, never deserialised.
    #[serde(skip_deserializing, skip_serializing_if = "Option::is_none")]
    pub(crate) x5u: Option<()>,
    /// X.509 certificate chain (`x5c`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub(crate) x5c: Option<Vec<String>>,
    /// X.509 SHA-1 thumbprint (`x5t`): unsupported, never deserialised.
    #[serde(skip_deserializing, skip_serializing_if = "Option::is_none")]
    pub(crate) x5t: Option<()>,
    /// X.509 SHA-256 thumbprint (`x5t#S256`): unsupported, never deserialised.
    #[serde(
        skip_deserializing,
        rename = "x5t#S256",
        skip_serializing_if = "Option::is_none"
    )]
    pub(crate) x5t_s256: Option<()>,
    /// Non-standard `ctx` member.
    /// NOTE(review): its meaning is not established in this file — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub(crate) ctx: Option<String>,
}
/// A JWE in compact serialisation
/// (`header.encrypted_key.iv.ciphertext.authentication_tag`), parsed from text.
///
/// Parsing only decodes the parts; nothing is decrypted.
#[derive(Clone)]
pub struct JweCompact {
    /// Decoded protected header.
    pub(crate) header: JweProtectedHeader,
    /// The header exactly as it appeared in the input (base64url text).
    /// NOTE(review): presumably kept so the original bytes can serve as the AEAD
    /// additional authenticated data (RFC 7516 §5.1) — confirm.
    pub(crate) hdr_b64: String,
    /// Decoded JWE Encrypted Key. It is empty when the input part was empty.
    pub(crate) content_enc_key: Vec<u8>,
    /// Decoded initialisation vector.
    pub(crate) iv: Vec<u8>,
    /// Decoded ciphertext.
    pub(crate) ciphertext: Vec<u8>,
    /// Decoded authentication tag.
    pub(crate) authentication_tag: Vec<u8>,
}
impl fmt::Debug for JweCompact {
    /// Shows the parsed header and the ciphertext length only.
    ///
    /// No key, IV, ciphertext or tag bytes are printed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("JweCompact");
        out.field("header", &self.header);
        out.field("encrypted_payload_length", &self.ciphertext.len());
        out.finish()
    }
}
impl JweCompact {
    /// Returns the `kid` (key identifier) header member, if present.
    pub fn kid(&self) -> Option<&str> {
        let hdr = &self.header;
        hdr.kid.as_deref()
    }

    /// Returns the `jku` (JWK Set URL) header member, if present.
    ///
    /// The header has not been authenticated at this point; treat the value as
    /// untrusted.
    pub fn get_jwk_pubkey_url(&self) -> Option<&Url> {
        let hdr = &self.header;
        hdr.jku.as_ref()
    }

    /// Returns the `jwk` (embedded public key) header member, if present.
    ///
    /// The header has not been authenticated at this point; treat the value as
    /// untrusted.
    pub fn get_jwk_pubkey(&self) -> Option<&Jwk> {
        let hdr = &self.header;
        hdr.jwk.as_ref()
    }

    /// Returns the key-management algorithm and content-encryption algorithm pair.
    pub fn get_alg_enc(&self) -> (JweAlg, JweEnc) {
        let hdr = &self.header;
        let alg = hdr.alg;
        let enc = hdr.enc;
        (alg, enc)
    }
}
impl FromStr for JweCompact {
    type Err = JwtError;

    /// Parse a JWE in compact serialisation (RFC 7516 §7.1).
    ///
    /// The input has the form `header.encrypted_key.iv.ciphertext.authentication_tag`,
    /// where each part is unpadded URL-safe base64.
    ///
    /// - The header is decoded and parsed as JSON, and any non-empty `crit` list is
    ///   rejected.
    /// - The other four parts are decoded to bytes; an empty part decodes to an
    ///   empty `Vec`.
    ///
    /// # Errors
    ///
    /// * `JwtError::InvalidCompactFormat` if the input does not consist of exactly
    ///   five `.`-separated parts.
    /// * `JwtError::InvalidBase64` if any part is not valid base64.
    /// * `JwtError::InvalidHeaderFormat` if the header is not valid JSON for
    ///   `JweProtectedHeader`.
    /// * `JwtError::CriticalExtension` if the header lists critical extensions.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // Use an unbounded `split`. With `splitn(5, ..)` the fifth item swallows every
        // remaining `.`, so the extra-fields check below could never fire. Inputs with
        // too many parts were then misreported as invalid base64 in the tag.
        let mut siter = s.split('.');

        let hdr_str = siter.next().ok_or_else(|| {
            debug!("invalid compact format - protected header not present");
            JwtError::InvalidCompactFormat
        })?;

        let header: JweProtectedHeader = general_purpose::URL_SAFE_NO_PAD
            .decode(hdr_str)
            .map_err(|_| {
                debug!("invalid base64 while decoding header");
                JwtError::InvalidBase64
            })
            .and_then(|bytes| {
                serde_json::from_slice(&bytes).map_err(|e| {
                    debug!(?e, "invalid header format - invalid json");
                    JwtError::InvalidHeaderFormat
                })
            })?;

        // Keep the header text verbatim alongside the parsed form.
        let hdr_b64 = hdr_str.to_string();

        // No critical extensions are supported, so any listed one must be rejected.
        if let Some(crit) = &header.crit {
            if !crit.is_empty() {
                error!("critical extension - unable to process critical extensions");
                return Err(JwtError::CriticalExtension);
            }
        }

        let content_enc_key_str = siter.next().ok_or_else(|| {
            debug!("invalid compact format - content encryption key not present");
            JwtError::InvalidCompactFormat
        })?;

        let iv_str = siter.next().ok_or_else(|| {
            debug!("invalid compact format - iv not present");
            JwtError::InvalidCompactFormat
        })?;

        let ciphertext_str = siter.next().ok_or_else(|| {
            debug!("invalid compact format - ciphertext not present");
            JwtError::InvalidCompactFormat
        })?;

        let authentication_tag_str = siter.next().ok_or_else(|| {
            debug!("invalid compact format - authentication tag not present");
            JwtError::InvalidCompactFormat
        })?;

        if siter.next().is_some() {
            debug!("invalid compact format - extra fields present");
            return Err(JwtError::InvalidCompactFormat);
        }

        let content_enc_key = general_purpose::URL_SAFE_NO_PAD
            .decode(content_enc_key_str)
            .map_err(|_| {
                debug!("invalid base64 when decoding content encryption key");
                JwtError::InvalidBase64
            })?;

        let iv = general_purpose::URL_SAFE_NO_PAD
            .decode(iv_str)
            .map_err(|_| {
                debug!("invalid base64 when decoding iv");
                JwtError::InvalidBase64
            })?;

        let ciphertext = general_purpose::URL_SAFE_NO_PAD
            .decode(ciphertext_str)
            .map_err(|_| {
                debug!("invalid base64 when decoding ciphertext");
                JwtError::InvalidBase64
            })?;

        let authentication_tag = general_purpose::URL_SAFE_NO_PAD
            .decode(authentication_tag_str)
            .map_err(|_| {
                debug!("invalid base64 when decoding authentication tag");
                JwtError::InvalidBase64
            })?;

        Ok(JweCompact {
            header,
            hdr_b64,
            content_enc_key,
            iv,
            ciphertext,
            authentication_tag,
        })
    }
}
impl fmt::Display for JweCompact {
    /// Renders the compact form
    /// `header.encrypted_key.iv.ciphertext.authentication_tag`.
    ///
    /// The header is emitted exactly as parsed. Every other part is re-encoded as
    /// unpadded URL-safe base64.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let engine = general_purpose::URL_SAFE_NO_PAD;
        let key_b64 = engine.encode(self.content_enc_key.as_slice());
        let iv_b64 = engine.encode(self.iv.as_slice());
        let ciphertext_b64 = engine.encode(self.ciphertext.as_slice());
        let tag_b64 = engine.encode(self.authentication_tag.as_slice());
        write!(
            f,
            "{}.{}.{}.{}.{}",
            self.hdr_b64, key_b64, iv_b64, ciphertext_b64, tag_b64
        )
    }
}