use std::collections::HashMap;
use std::result;
use base64::{Engine, engine::general_purpose::STANDARD};
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use crate::algorithms::Algorithm;
use crate::errors::Result;
use crate::jwk::Jwk;
use crate::serialization::b64_decode;
/// Serialized form of [`Zip::Deflate`] in the `zip` header parameter.
const ZIP_SERIAL_DEFLATE: &str = "DEF";
/// Serialized form of [`Enc::A128CBC_HS256`] in the `enc` header parameter.
const ENC_A128CBC_HS256: &str = "A128CBC-HS256";
/// Serialized form of [`Enc::A192CBC_HS384`] in the `enc` header parameter.
const ENC_A192CBC_HS384: &str = "A192CBC-HS384";
/// Serialized form of [`Enc::A256CBC_HS512`] in the `enc` header parameter.
const ENC_A256CBC_HS512: &str = "A256CBC-HS512";
/// Serialized form of [`Enc::A128GCM`] in the `enc` header parameter.
const ENC_A128GCM: &str = "A128GCM";
/// Serialized form of [`Enc::A192GCM`] in the `enc` header parameter.
const ENC_A192GCM: &str = "A192GCM";
/// Serialized form of [`Enc::A256GCM`] in the `enc` header parameter.
const ENC_A256GCM: &str = "A256GCM";
/// Value of the `enc` header parameter (JWE content-encryption algorithm).
///
/// The six identifiers registered in RFC 7518 §5.1 get dedicated variants.
/// Any other string deserializes to [`Enc::Other`] and serializes back
/// unchanged, so unknown values round-trip.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
// Variant names mirror the registered identifiers (with `-` replaced by `_`),
// so the CamelCase / acronym lints are silenced deliberately.
#[allow(clippy::upper_case_acronyms, non_camel_case_types)]
pub enum Enc {
/// Serialized as `"A128CBC-HS256"`.
A128CBC_HS256,
/// Serialized as `"A192CBC-HS384"`.
A192CBC_HS384,
/// Serialized as `"A256CBC-HS512"`.
A256CBC_HS512,
/// Serialized as `"A128GCM"`.
A128GCM,
/// Serialized as `"A192GCM"`.
A192GCM,
/// Serialized as `"A256GCM"`.
A256GCM,
/// Any other `enc` string, stored and re-emitted verbatim.
Other(String),
}
impl Serialize for Enc {
    /// Writes the variant as a plain string.
    ///
    /// Known variants use their registered identifier; [`Enc::Other`] writes
    /// its stored string exactly as held.
    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let name: &str = match self {
            Enc::A128CBC_HS256 => ENC_A128CBC_HS256,
            Enc::A192CBC_HS384 => ENC_A192CBC_HS384,
            Enc::A256CBC_HS512 => ENC_A256CBC_HS512,
            Enc::A128GCM => ENC_A128GCM,
            Enc::A192GCM => ENC_A192GCM,
            Enc::A256GCM => ENC_A256GCM,
            Enc::Other(custom) => custom.as_str(),
        };
        serializer.serialize_str(name)
    }
}
impl<'de> Deserialize<'de> for Enc {
    /// Reads a string and maps it onto an [`Enc`] variant.
    ///
    /// Registered identifiers become their dedicated variant. Any other
    /// string is kept as [`Enc::Other`], so the only possible error is the one
    /// produced while reading the string itself.
    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let raw = String::deserialize(deserializer)?;
        let enc = match raw.as_str() {
            ENC_A128CBC_HS256 => Enc::A128CBC_HS256,
            ENC_A192CBC_HS384 => Enc::A192CBC_HS384,
            ENC_A256CBC_HS512 => Enc::A256CBC_HS512,
            ENC_A128GCM => Enc::A128GCM,
            ENC_A192GCM => Enc::A192GCM,
            ENC_A256GCM => Enc::A256GCM,
            _ => Enc::Other(raw),
        };
        Ok(enc)
    }
}
/// Value of the `zip` header parameter.
///
/// `"DEF"` maps to [`Zip::Deflate`]. Any other string deserializes to
/// [`Zip::Other`] and serializes back unchanged.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Zip {
/// Serialized as `"DEF"` (DEFLATE).
Deflate,
/// Any other `zip` string, stored and re-emitted verbatim.
Other(String),
}
impl Serialize for Zip {
    /// Writes `"DEF"` for [`Zip::Deflate`] and the stored string for
    /// [`Zip::Other`].
    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let token: &str = match self {
            Zip::Other(custom) => custom.as_str(),
            Zip::Deflate => ZIP_SERIAL_DEFLATE,
        };
        serializer.serialize_str(token)
    }
}
impl<'de> Deserialize<'de> for Zip {
fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
match s.as_str() {
ZIP_SERIAL_DEFLATE => Ok(Zip::Deflate),
_ => Ok(Zip::Other(s)),
}
}
}
/// A JOSE header: the registered parameters as named fields, plus any
/// additional parameters collected in `extras`.
///
/// Every `Option` field is omitted from the serialized JSON when `None`.
/// `alg` is always emitted and must be present when deserializing.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Header {
/// `typ` (type) parameter. [`Header::new`] sets this to `"JWT"`.
#[serde(skip_serializing_if = "Option::is_none")]
pub typ: Option<String>,
/// `alg` (algorithm) parameter. Not optional: always serialized and
/// required on deserialization.
pub alg: Algorithm,
/// `cty` (content type) parameter.
#[serde(skip_serializing_if = "Option::is_none")]
pub cty: Option<String>,
/// `jku` (JWK Set URL) parameter.
#[serde(skip_serializing_if = "Option::is_none")]
pub jku: Option<String>,
/// `jwk` (JSON Web Key) parameter.
#[serde(skip_serializing_if = "Option::is_none")]
pub jwk: Option<Jwk>,
/// `kid` (key ID) parameter.
#[serde(skip_serializing_if = "Option::is_none")]
pub kid: Option<String>,
/// `x5u` (X.509 URL) parameter.
#[serde(skip_serializing_if = "Option::is_none")]
pub x5u: Option<String>,
/// `x5c` (X.509 certificate chain) parameter, kept as the encoded strings.
/// Use [`Header::x5c_der`] to get the decoded bytes.
#[serde(skip_serializing_if = "Option::is_none")]
pub x5c: Option<Vec<String>>,
/// `x5t` (X.509 certificate SHA-1 thumbprint) parameter.
#[serde(skip_serializing_if = "Option::is_none")]
pub x5t: Option<String>,
/// `x5t#S256` (X.509 certificate SHA-256 thumbprint) parameter. The Rust
/// name differs because `#` is not valid in an identifier.
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(rename = "x5t#S256")]
pub x5t_s256: Option<String>,
/// `crit` (critical) parameter: names of extensions that must be understood.
#[serde(skip_serializing_if = "Option::is_none")]
pub crit: Option<Vec<String>>,
/// `enc` (encryption algorithm) parameter.
#[serde(skip_serializing_if = "Option::is_none")]
pub enc: Option<Enc>,
/// `zip` (compression algorithm) parameter.
#[serde(skip_serializing_if = "Option::is_none")]
pub zip: Option<Zip>,
/// `url` parameter (defined for ACME requests in RFC 8555 §6.4).
#[serde(skip_serializing_if = "Option::is_none")]
pub url: Option<String>,
/// `nonce` parameter (defined for ACME requests in RFC 8555 §6.5).
#[serde(skip_serializing_if = "Option::is_none")]
pub nonce: Option<String>,
/// Every parameter not matched by a named field above, flattened into the
/// top-level JSON object.
///
/// NOTE(review): because the values are `String`, deserializing a header
/// whose extra parameter is not a JSON string (e.g. `"b64": false`) fails.
/// Also, on serialization a key here that duplicates a named field (e.g.
/// `"alg"`) is written a second time rather than replacing it. Confirm
/// whether callers rely on either behavior.
#[serde(flatten)]
pub extras: HashMap<String, String>,
}
impl Header {
pub fn new(algorithm: Algorithm) -> Self {
Header {
typ: Some("JWT".to_string()),
alg: algorithm,
cty: None,
jku: None,
jwk: None,
kid: None,
x5u: None,
x5c: None,
x5t: None,
x5t_s256: None,
crit: None,
enc: None,
zip: None,
url: None,
nonce: None,
extras: Default::default(),
}
}
pub(crate) fn from_encoded<T: AsRef<[u8]>>(encoded_part: T) -> Result<Self> {
let decoded = b64_decode(encoded_part)?;
Ok(serde_json::from_slice(&decoded)?)
}
pub fn x5c_der(&self) -> Result<Option<Vec<Vec<u8>>>> {
Ok(self
.x5c
.as_ref()
.map(|b64_certs| {
b64_certs.iter().map(|x| STANDARD.decode(x)).collect::<result::Result<_, _>>()
})
.transpose()?)
}
}
impl Default for Header {
    /// Same as [`Header::new`] called with `Algorithm::default()`.
    fn default() -> Self {
        let algorithm = Algorithm::default();
        Self::new(algorithm)
    }
}