use alloc::{
collections::{BTreeMap, BTreeSet},
string::{String, ToString},
};
use core::{marker::PhantomData, ops::Deref};
use mediatype::MediaType;
use serde::Deserialize;
use serde_json::{Map, Value};
mod builder;
mod error;
mod parameters;
mod types;
mod value;
use self::parameters::Parameters;
#[doc(inline)]
pub use self::{
builder::{JoseHeaderBuilder, JoseHeaderBuilderError},
error::Error,
types::*,
value::*,
};
use crate::{
format::Format,
jwa::{JsonWebContentEncryptionAlgorithm, JsonWebEncryptionAlgorithm, JsonWebSigningAlgorithm},
uri::BorrowedUri,
JsonWebKey, UntypedAdditionalProperties,
};
/// A JOSE header whose parameters are each tagged as living in the protected or the
/// unprotected part of the header.
///
/// `F` is the serialization format (only carried as a zero-sized marker). `T` selects the
/// header flavor (e.g. `Jws` or `Jwe`), which determines the type-specific parameters kept in
/// `parameters.specific`.
#[derive(Debug)]
pub struct JoseHeader<F, T> {
    /// All header parameters, registered ones as typed fields, the rest in `additional`.
    parameters: Parameters<T>,
    /// Marker tying the header to its format `F` without storing a value of that type.
    _format: PhantomData<F>,
}
impl<F> JoseHeader<F, Jws>
where
F: Format,
{
pub(crate) fn overwrite_alg_and_key_id(
&mut self,
alg: JsonWebSigningAlgorithm,
kid: Option<&str>,
) {
let is_protected = matches!(self.algorithm(), HeaderValue::Protected(_));
let alg = if is_protected {
HeaderValue::Protected(alg)
} else {
HeaderValue::Unprotected(alg)
};
let kid = kid.map(|s| {
let kid = s.to_string();
if is_protected {
HeaderValue::Protected(kid)
} else {
HeaderValue::Unprotected(kid)
}
});
self.parameters.key_id = kid;
self.parameters.specific.algorithm = alg;
}
}
impl<F, T> JoseHeader<F, T>
where
    F: Format,
    T: Type,
{
    /// Starts building a new header from scratch.
    pub fn builder() -> JoseHeaderBuilder<F, T> {
        JoseHeaderBuilder::default()
    }

    /// Turns this header back into a builder so it can be modified.
    pub fn into_builder(self) -> JoseHeaderBuilder<F, T> {
        JoseHeaderBuilder::from_header(self)
    }

    /// The `jku` (JWK Set URL) parameter, if present.
    pub fn jwk_set_url(&self) -> Option<HeaderValue<BorrowedUri<'_>>> {
        let url = self.parameters.jwk_set_url.as_ref()?;
        Some(url.as_ref().map(|uri| uri.borrow()))
    }

    /// The `jwk` (embedded JSON Web Key) parameter, if present.
    pub fn json_web_key(&self) -> Option<HeaderValue<&JsonWebKey<UntypedAdditionalProperties>>> {
        let jwk = self.parameters.json_web_key.as_ref()?;
        Some(jwk.as_ref())
    }

    /// The `kid` (key ID) parameter, if present.
    pub fn key_identifier(&self) -> Option<HeaderValue<&str>> {
        let kid = self.parameters.key_id.as_ref()?;
        Some(kid.as_deref())
    }

    /// The `x5u` (X.509 URL) parameter, if present.
    pub fn x509_url(&self) -> Option<HeaderValue<BorrowedUri<'_>>> {
        let url = self.parameters.x509_url.as_ref()?;
        Some(url.as_ref().map(|uri| uri.borrow()))
    }

    /// The `x5c` (X.509 certificate chain) parameter, yielding each certificate's bytes.
    pub fn x509_certificate_chain(&self) -> Option<HeaderValue<impl Iterator<Item = &[u8]>>> {
        let chain = self.parameters.x509_certificate_chain.as_ref()?;
        Some(chain.as_deref().map(|certs| certs.iter().map(Deref::deref)))
    }

    /// The `x5t` (X.509 certificate SHA-1 thumbprint) parameter, if present.
    pub fn x509_certificate_sha1_thumbprint(&self) -> Option<HeaderValue<&[u8; 20]>> {
        let thumbprint = self.parameters.x509_certificate_sha1_thumbprint.as_ref()?;
        Some(thumbprint.as_ref())
    }

    /// The `x5t#S256` (X.509 certificate SHA-256 thumbprint) parameter, if present.
    pub fn x509_certificate_sha256_thumbprint(&self) -> Option<HeaderValue<&[u8; 32]>> {
        let thumbprint = self.parameters.x509_certificate_sha256_thumbprint.as_ref()?;
        Some(thumbprint.as_ref())
    }

    /// The `typ` parameter as a media type, if present.
    pub fn typ(&self) -> Option<HeaderValue<MediaType<'_>>> {
        let typ = self.parameters.typ.as_ref()?;
        Some(typ.as_ref().map(|media_type| media_type.0.to_ref()))
    }

    /// The `cty` (content type) parameter as a media type, if present.
    pub fn content_type(&self) -> Option<HeaderValue<MediaType<'_>>> {
        let cty = self.parameters.content_type.as_ref()?;
        Some(cty.as_ref().map(|media_type| media_type.0.to_ref()))
    }

    /// Looks up a parameter that has no dedicated typed field, by name.
    pub fn additional(&self, parameter_name: impl AsRef<str>) -> Option<HeaderValue<&Value>> {
        let name = parameter_name.as_ref();
        self.parameters.additional.get(name).map(HeaderValue::as_ref)
    }

    /// The names listed in the `crit` parameter; empty when the parameter is absent.
    pub fn critical_headers(&self) -> impl Iterator<Item = &'_ str> {
        self.parameters
            .critical_headers
            .as_ref()
            .into_iter()
            .flat_map(|headers| headers.iter())
            .map(Deref::deref)
    }
}
impl<F> JoseHeader<F, Jws>
where
    F: Format,
{
    /// The signing algorithm, tagged with the header part it appears in.
    pub fn algorithm(&self) -> HeaderValue<&JsonWebSigningAlgorithm> {
        let algorithm = &self.parameters.specific.algorithm;
        algorithm.as_ref()
    }

    /// Whether the payload is base64url-encoded (presumably the RFC 7797 `b64` parameter).
    ///
    /// An absent value means `true`; only an explicit `false` disables the encoding.
    pub fn payload_base64_url_encoded(&self) -> bool {
        self.parameters.specific.payload_base64_url_encoded != Some(false)
    }
}
impl<F> JoseHeader<F, Jwe>
where
F: Format,
{
pub fn algorithm(&self) -> HeaderValue<&JsonWebEncryptionAlgorithm> {
self.parameters.specific.algorithm.as_ref()
}
pub fn content_encryption_algorithm(&self) -> HeaderValue<&JsonWebContentEncryptionAlgorithm> {
self.parameters
.specific
.content_encryption_algorithm
.as_ref()
}
}
impl<F, T> JoseHeader<F, T>
where
F: Format,
T: Type,
{
pub(crate) fn from_values(
protected: Option<Map<String, Value>>,
unprotected: Option<Map<String, Value>>,
) -> Result<Self, Error> {
let de = HeaderDeserializer::from_values(protected, unprotected)?;
let (specific, mut de) = T::from_deserializer(de).map_err(|(e, _)| e)?;
Ok(Self {
parameters: Parameters {
critical_headers: de
.deserialize_field("crit")
.transpose()?
.map(|v| v.protected().ok_or(Error::ExpectedProtected))
.transpose()?
.map(|v: BTreeSet<_>| {
if v.is_empty() {
return Err(Error::EmptyCriticalHeaders);
}
for forbidden in T::forbidden_critical_headers() {
if v.contains(*forbidden) {
return Err(Error::ForbiddenHeader(forbidden.to_string()));
}
}
Ok(v)
})
.transpose()?,
jwk_set_url: de.deserialize_field("jku").transpose()?,
json_web_key: de.deserialize_field("jwk").transpose()?,
key_id: de.deserialize_field("kid").transpose()?,
x509_url: de.deserialize_field("x5u").transpose()?,
x509_certificate_chain: de.deserialize_field("x5c").transpose()?,
x509_certificate_sha1_thumbprint: de.deserialize_field("x5t").transpose()?,
x509_certificate_sha256_thumbprint: de.deserialize_field("x5t#S256").transpose()?,
typ: de.deserialize_field("typ").transpose()?,
content_type: de.deserialize_field("cty").transpose()?,
specific,
additional: de.additional(),
},
_format: PhantomData,
})
}
#[allow(clippy::type_complexity)]
pub(crate) fn into_values(
self,
) -> Result<(Option<Map<String, Value>>, Option<Map<String, Value>>), Error> {
let parameters = self.parameters;
let mut collected_parameters = parameters.additional;
if let Some(crit) = parameters.critical_headers {
if !crit.is_empty() {
collected_parameters.insert(
"crit".to_string(),
HeaderValue::Protected(serde_json::to_value(crit)?),
);
}
} else {
collected_parameters.remove("crit");
}
macro_rules! insert {
($($name:literal: $value:expr),+,) => {
$(if let Some(value) = $value {
collected_parameters.insert(
$name.to_string(),
value.map(serde_json::to_value).transpose()?,
);
} else {
collected_parameters.remove($name);
})+
};
}
insert! {
"jku": parameters.jwk_set_url,
"jwk": parameters.json_web_key,
"kid": parameters.key_id,
"x5u": parameters.x509_url,
"x5c": parameters.x509_certificate_chain,
"x5t": parameters.x509_certificate_sha1_thumbprint,
"x5t#S256": parameters.x509_certificate_sha256_thumbprint,
"typ": parameters.typ,
"cty": parameters.content_type,
}
let mut protected = Map::new();
let mut unprotected = Map::new();
for (key, value) in collected_parameters
.into_iter()
.chain(parameters.specific.into_map()?)
{
match value {
HeaderValue::Protected(value) => protected.insert(key, value),
HeaderValue::Unprotected(value) => unprotected.insert(key, value),
};
}
let protected = match protected.is_empty() {
true => None,
false => Some(protected),
};
let unprotected = match unprotected.is_empty() {
true => None,
false => Some(unprotected),
};
Ok((protected, unprotected))
}
}
/// Holds the raw protected and unprotected header objects while their parameters are pulled
/// out one by one.
///
/// Each parameter is taken with `deserialize_field`, which removes it. Whatever remains is
/// collected by `additional`.
#[derive(Debug)]
pub struct HeaderDeserializer {
    /// Remaining parameters of the protected header.
    protected: Map<String, Value>,
    /// Remaining parameters of the unprotected header; `from_values` rejects names that also
    /// occur in `protected`.
    unprotected: Map<String, Value>,
}
impl HeaderDeserializer {
    /// Creates a deserializer from the raw protected and unprotected header objects.
    ///
    /// A missing part is treated as an empty object.
    ///
    /// # Errors
    ///
    /// * [`Error::EmptyHeader`] if a part is present but contains no parameters.
    /// * [`Error::NoHeader`] if both parts are missing.
    /// * [`Error::NotDisjoint`] if a parameter name occurs in both parts.
    fn from_values(
        protected: Option<Map<String, Value>>,
        unprotected: Option<Map<String, Value>>,
    ) -> Result<Self, Error> {
        // A part that is present must carry at least one parameter.
        let is_empty_part =
            |part: &Option<Map<String, Value>>| matches!(part, Some(map) if map.is_empty());
        if is_empty_part(&protected) || is_empty_part(&unprotected) {
            return Err(Error::EmptyHeader);
        }

        let (protected, unprotected) = match (protected, unprotected) {
            (None, None) => return Err(Error::NoHeader),
            (protected, unprotected) => (
                protected.unwrap_or_else(Map::new),
                unprotected.unwrap_or_else(Map::new),
            ),
        };

        // Each parameter name may occur in at most one part. Probing the other map directly
        // avoids allocating two temporary key sets.
        if protected.keys().any(|name| unprotected.contains_key(name)) {
            return Err(Error::NotDisjoint);
        }

        Ok(Self {
            protected,
            unprotected,
        })
    }

    /// Removes `field` from whichever part holds it and deserializes its value as `V`.
    ///
    /// Returns `None` if the field is absent from both parts. Otherwise returns the
    /// deserialization result, tagged with the part the field came from.
    fn deserialize_field<'a, 'de, V>(
        &'a mut self,
        field: &'a str,
    ) -> Option<Result<HeaderValue<V>, serde_json::Error>>
    where
        V: Deserialize<'de>,
        'a: 'de,
    {
        if let Some(value) = self.protected.remove(field) {
            // `from_values` guarantees the parts are disjoint. The check is read-only so that
            // debug and release builds leave the maps in the same state.
            debug_assert!(!self.unprotected.contains_key(field));
            return Some(V::deserialize(value).map(HeaderValue::Protected));
        }
        self.unprotected
            .remove(field)
            .map(|value| V::deserialize(value).map(HeaderValue::Unprotected))
    }

    /// Consumes the deserializer, returning every parameter not yet taken by
    /// `deserialize_field`, tagged with the part it came from.
    fn additional(self) -> BTreeMap<String, HeaderValue<Value>> {
        let protected = self
            .protected
            .into_iter()
            .map(|(field, value)| (field, HeaderValue::Protected(value)));
        let unprotected = self
            .unprotected
            .into_iter()
            .map(|(field, value)| (field, HeaderValue::Unprotected(value)));
        protected.chain(unprotected).collect()
    }
}