use alloc::{format, string::String, vec, vec::Vec};
use thiserror::Error;
use crate::{
crypto,
format::{
Compact, DecodeFormat, DecodeFormatWithContext, Format, JsonFlattened, JsonGeneral,
JsonGeneralSignature,
},
header, Base64UrlString, JoseHeader,
};
mod builder;
mod sign;
mod verify;
#[doc(inline)]
pub use {builder::*, sign::*, verify::*};
/// How a payload is carried by a JWS.
///
/// In [`JsonWebSignature::sign`] both variants contribute their encoded data
/// to the signing input; they differ only in whether the payload is also
/// written into the serialized JWS.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub enum PayloadKind {
    /// The payload is embedded in the serialized JWS.
    Attached(PayloadData),
    /// The payload is signed but left out of the serialized JWS; a decoder
    /// has to obtain it separately (see [`FromRawPayload::from_detached`]).
    Detached(PayloadData),
}
/// The encoded representation of a payload.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub enum PayloadData {
    /// A base64url-encoded payload; its text is used verbatim as the payload
    /// segment of the signing input.
    Standard(Base64UrlString),
}
/// Conversion of a value into the payload that gets signed.
///
/// Called by [`JsonWebSignature::sign`] and [`JsonWebSignature::sign_many`]
/// before any signature is computed; a failure is reported as
/// [`SignError::Payload`].
pub trait IntoPayload {
    /// Error returned when the conversion fails.
    type Error;
    /// Consumes `self` and produces the attached or detached payload.
    fn into_payload(self) -> Result<PayloadKind, Self::Error>;
}
/// Reconstruction of a payload value while decoding a JWS.
///
/// `Context` is caller-supplied state forwarded by `decode_with_context`;
/// the plain `decode` entry points in this module require `Context = ()`.
pub trait FromRawPayload: Sized {
    /// Error returned when the payload cannot be reconstructed; the decoders
    /// wrap it in their `Payload` error variant.
    type Error;
    /// Extra state passed (by reference) to every constructor.
    type Context;
    /// Builds the value from a payload that is present in the input.
    fn from_attached(context: &Self::Context, payload: PayloadData) -> Result<Self, Self::Error>;
    /// Builds the value when the input carries no payload (detached content),
    /// given the parsed header of a compact or flattened JSON JWS.
    ///
    /// Besides the value, returns the raw payload data; its encoded text is
    /// used as the payload segment of the signing input that the decoder
    /// stores for later verification.
    fn from_detached<F, T>(
        context: &Self::Context,
        header: &JoseHeader<F, T>,
    ) -> Result<(Self, PayloadData), Self::Error>;
    /// Like `from_detached`, but for a general JSON JWS: receives the parsed
    /// header of every signature.
    fn from_detached_many<F, T>(
        context: &Self::Context,
        headers: &[JoseHeader<F, T>],
    ) -> Result<(Self, PayloadData), Self::Error>;
}
/// Errors that can occur while signing a [`JsonWebSignature`].
///
/// `P` is the error type of the payload conversion ([`IntoPayload::Error`]).
#[derive(Debug, Error)]
pub enum SignError<P> {
    /// Returned by `sign_many` when there are no headers or when the number of
    /// signers differs from the number of headers.
    #[error("the number of headers does not match the number of signers")]
    HeaderCountMismatch,
    /// A protected header could not be serialized to JSON. Also used for
    /// failures reported by `Format::finalize` in `sign`.
    #[error("failed to serialize header: {0}")]
    SerializeHeader(#[source] serde_json::Error),
    /// A compact JWS had no protected header. Not constructed in the code
    /// visible here; NOTE(review): presumably produced by the `Compact`
    /// format's header serialization — confirm.
    #[error("the protected header was empty on a compact JWS")]
    EmptyProtectedHeader,
    /// The header could not be split into protected / unprotected values.
    #[error("invalid JWS header: {0}")]
    InvalidHeader(#[source] header::Error),
    /// The signer failed to produce a signature.
    #[error(transparent)]
    Sign(crypto::Error),
    /// Converting the payload via [`IntoPayload::into_payload`] failed.
    #[error(transparent)]
    Payload(P),
}
/// A JSON Web Signature: header storage plus a payload value.
///
/// `F` selects the serialization format, which also fixes the header type:
/// a single [`JoseHeader`] for [`Compact`] and [`JsonFlattened`], and a
/// `Vec` with one header per signature for [`JsonGeneral`]. `T` is the
/// payload type (`()` while building).
#[derive(Debug)]
pub struct JsonWebSignature<F: Format, T = ()> {
    // Format-specific header storage (`F::JwsHeader`).
    header: F::JwsHeader,
    // The payload value.
    payload: T,
}
impl<F: Format> JsonWebSignature<F, ()> {
    /// Starts building a new [`JsonWebSignature`] serialized in format `F`.
    pub fn builder() -> JsonWebSignatureBuilder<F> {
        JsonWebSignatureBuilder::<F>::new()
    }
}
impl<F: Format, T> JsonWebSignature<F, T> {
pub(crate) fn new(header: F::JwsHeader, payload: T) -> Self {
Self { header, payload }
}
pub fn payload(&self) -> &T {
&self.payload
}
}
impl<T> JsonWebSignature<Compact, T> {
    /// Borrows the single JOSE header of a compact JWS.
    pub fn header(&self) -> &JoseHeader<Compact, header::Jws> {
        let Self { header: hdr, .. } = self;
        hdr
    }
}
impl<T> JsonWebSignature<JsonFlattened, T> {
    /// Borrows the single JOSE header of a flattened JSON JWS.
    pub fn header(&self) -> &JoseHeader<JsonFlattened, header::Jws> {
        let Self { header: hdr, .. } = self;
        hdr
    }
}
impl<T> JsonWebSignature<JsonGeneral, T> {
    /// Borrows the headers of a general JSON JWS, one per signature.
    pub fn header(&self) -> &Vec<JoseHeader<JsonGeneral, header::Jws>> {
        let Self { header: hdrs, .. } = self;
        hdrs
    }
}
impl<F: Format, T: IntoPayload> JsonWebSignature<F, T> {
    /// Signs this JWS with `signer` and serializes it in format `F`.
    ///
    /// The header is first updated from the signer by `Format::update_header`
    /// and then serialized. The signing input is the format's message bytes for
    /// that serialized header (empty if the format provides none), a `.`, and
    /// the encoded payload. Attached payloads are passed on to
    /// `Format::finalize` for inclusion; detached payloads are signed but not
    /// passed on.
    ///
    /// # Errors
    ///
    /// - Any error from `Format::serialize_header` is forwarded unchanged.
    /// - [`SignError::Payload`] if [`IntoPayload::into_payload`] fails.
    /// - [`SignError::Sign`] if the signer fails.
    /// - [`SignError::SerializeHeader`] if `Format::finalize` fails.
    pub fn sign<S: AsRef<[u8]>>(
        mut self,
        signer: &mut dyn Signer<S>,
    ) -> Result<Signed<F>, SignError<T::Error>> {
        F::update_header(&mut self.header, signer);

        // `serialize_header`'s payload error type is uninhabited, so its error is
        // re-tagged variant by variant into `SignError<T::Error>`.
        let serialized_header = match F::serialize_header(self.header) {
            Ok(serialized) => serialized,
            Err(err) => {
                return Err(match err {
                    SignError::HeaderCountMismatch => SignError::HeaderCountMismatch,
                    SignError::SerializeHeader(e) => SignError::SerializeHeader(e),
                    SignError::EmptyProtectedHeader => SignError::EmptyProtectedHeader,
                    SignError::InvalidHeader(e) => SignError::InvalidHeader(e),
                    SignError::Sign(e) => SignError::Sign(e),
                    SignError::Payload(never) => match never {},
                });
            }
        };

        let mut signing_input = F::message_from_header(&serialized_header)
            .map(|prefix| prefix.to_vec())
            .unwrap_or_default();
        signing_input.push(b'.');

        let (encoded_payload, attached) =
            match self.payload.into_payload().map_err(SignError::Payload)? {
                PayloadKind::Attached(PayloadData::Standard(b64)) => (b64, true),
                PayloadKind::Detached(PayloadData::Standard(b64)) => (b64, false),
            };
        // The payload is always signed, whether or not it ends up in the output.
        signing_input.extend(encoded_payload.as_bytes());
        let payload = attached.then_some(PayloadData::Standard(encoded_payload));

        let signature = signer.sign(&signing_input).map_err(SignError::Sign)?;
        let value = F::finalize(serialized_header, payload, signature.as_ref())
            .map_err(SignError::SerializeHeader)?;
        Ok(Signed { value })
    }
}
impl<T: IntoPayload> JsonWebSignature<JsonGeneral, T> {
    /// Signs a general JSON JWS once per header, pairing the `i`-th header with
    /// the `i`-th signer.
    ///
    /// Each header has its algorithm and key id overwritten from its signer and
    /// is then split into protected and unprotected parts. The signing input of
    /// every signature is `BASE64URL(protected header) || '.' || payload`, where
    /// an absent protected header contributes an empty first segment.
    ///
    /// Detached payloads are signed exactly like attached ones, but the
    /// resulting JWS carries no `payload` (it is `None`).
    ///
    /// # Errors
    ///
    /// - [`SignError::HeaderCountMismatch`] if there are no headers, or if the
    ///   number of signers differs from the number of headers.
    /// - [`SignError::Payload`] if [`IntoPayload::into_payload`] fails.
    /// - [`SignError::InvalidHeader`] if a header cannot be split into parts.
    /// - [`SignError::SerializeHeader`] if a protected header cannot be
    ///   serialized to JSON.
    /// - [`SignError::Sign`] if a signer fails.
    pub fn sign_many<'s, S: AsRef<[u8]> + 's>(
        self,
        signers: impl IntoIterator<Item = &'s mut dyn Signer<S>>,
    ) -> Result<Signed<JsonGeneral>, SignError<T::Error>> {
        if self.header.is_empty() {
            return Err(SignError::HeaderCountMismatch);
        }
        let signers = signers.into_iter().collect::<Vec<_>>();
        if signers.len() != self.header.len() {
            return Err(SignError::HeaderCountMismatch);
        }

        let payload = self.payload.into_payload().map_err(SignError::Payload)?;
        // Attached and detached payloads both go into every signing input
        // (RFC 7515, Appendix F); they differ only in whether the payload is
        // written into the serialized JWS.
        let (encoded_payload, attached) = match payload {
            PayloadKind::Attached(PayloadData::Standard(b64)) => (b64, true),
            PayloadKind::Detached(PayloadData::Standard(b64)) => (b64, false),
        };

        let mut signatures = Vec::with_capacity(signers.len());
        for (mut hdr, signer) in self.header.into_iter().zip(signers) {
            hdr.overwrite_alg_and_key_id(signer.algorithm(), signer.key_id());
            let (protected, unprotected) = hdr.into_values().map_err(SignError::InvalidHeader)?;

            let mut msg = Vec::new();
            let protected = match protected {
                Some(values) => {
                    let json =
                        serde_json::to_string(&values).map_err(SignError::SerializeHeader)?;
                    let encoded = Base64UrlString::encode(json);
                    msg.extend(encoded.as_bytes());
                    Some(encoded)
                }
                None => None,
            };
            msg.push(b'.');
            msg.extend(encoded_payload.as_bytes());

            let signature = signer.sign(&msg).map_err(SignError::Sign)?;
            signatures.push(JsonGeneralSignature {
                protected,
                header: unprotected,
                signature: Base64UrlString::encode(signature.as_ref()),
            });
        }

        Ok(Signed {
            value: JsonGeneral {
                payload: attached.then_some(encoded_payload),
                signatures,
            },
        })
    }
}
/// Errors that can occur while decoding a compact JWS.
///
/// `P` is the payload reconstruction error ([`FromRawPayload::Error`]).
#[derive(Debug, Error)]
pub enum ParseCompactError<P> {
    /// The header lists critical extensions that are not supported. Not
    /// constructed in the code visible here; NOTE(review): presumably raised
    /// elsewhere in the crate — confirm.
    #[error("encountered unsupported critical headers (crit header field)")]
    UnsupportedCriticalHeader,
    /// The decoded header bytes were not valid UTF-8.
    #[error("one of the parts was an invalid UTF-8 byte sequence")]
    InvalidUtf8Encoding,
    /// The decoded header was not valid JSON (or not a JSON object).
    #[error("one of the parts was an invalid json string")]
    InvalidJson(#[source] serde_json::Error),
    /// The header JSON was rejected by `JoseHeader::from_values`.
    #[error("invalid JWS header: {0}")]
    InvalidHeader(#[source] header::Error),
    /// The input did not consist of exactly three parts.
    #[error("got compact representation that didn't have 3 parts")]
    InvalidLength,
    /// Reconstructing the payload via [`FromRawPayload`] failed.
    #[error(transparent)]
    Payload(P),
}
// Implements the crate's `sealed::Sealed` marker for every `JsonWebSignature`.
// NOTE(review): presumably `Sealed` is a supertrait of the crate's decode traits
// (`DecodeFormat` / `DecodeFormatWithContext`) so they cannot be implemented
// outside the crate — confirm.
impl<F: Format, T> crate::sealed::Sealed for JsonWebSignature<F, T> {}
impl<T: FromRawPayload<Context = ()>> DecodeFormat<Compact> for JsonWebSignature<Compact, T> {
type Decoded<D> = Unverified<D>;
type Error = ParseCompactError<T::Error>;
fn decode(input: Compact) -> Result<Self::Decoded<Self>, Self::Error> {
Self::decode_with_context(input, &())
}
}
impl<C, T: FromRawPayload<Context = C>> DecodeFormatWithContext<Compact, C>
    for JsonWebSignature<Compact, T>
{
    type Decoded<D> = Unverified<D>;
    type Error = ParseCompactError<T::Error>;

    /// Decodes a compact `header.payload.signature` JWS without verifying it.
    ///
    /// The header segment is base64url-decoded, parsed as a JSON object and
    /// turned into a [`JoseHeader`]. A non-empty payload segment is passed to
    /// [`FromRawPayload::from_attached`]; an empty one means detached content,
    /// which [`FromRawPayload::from_detached`] must supply. The result records
    /// the decoded signature and the message `header.payload` built from the
    /// raw base64url segments.
    ///
    /// # Errors
    ///
    /// [`ParseCompactError::InvalidLength`] unless there are exactly three
    /// parts; the UTF-8 / JSON / header variants for a malformed header; and
    /// [`ParseCompactError::Payload`] if the payload cannot be reconstructed.
    fn decode_with_context(input: Compact, context: &C) -> Result<Unverified<Self>, Self::Error> {
        if input.len() != 3 {
            return Err(ParseCompactError::InvalidLength);
        }

        let raw_header = input.part(0).expect("`len()` is checked above to be 3");
        let header_json = String::from_utf8(raw_header.decode())
            .map_err(|_| ParseCompactError::InvalidUtf8Encoding)?;
        let header_values: serde_json::Map<String, serde_json::Value> =
            serde_json::from_str(&header_json).map_err(ParseCompactError::InvalidJson)?;
        let header = JoseHeader::from_values(Some(header_values), None)
            .map_err(ParseCompactError::InvalidHeader)?;

        // An empty payload segment signals detached content.
        let raw_body = input.part(1).expect("`len()` is checked above to be 3");
        let (payload, PayloadData::Standard(raw_payload)) = if raw_body.is_empty() {
            T::from_detached(context, &header).map_err(ParseCompactError::Payload)?
        } else {
            let data = PayloadData::Standard(raw_body.clone());
            let value =
                T::from_attached(context, data.clone()).map_err(ParseCompactError::Payload)?;
            (value, data)
        };

        let raw_signature = input.part(2).expect("`len()` is checked above to be 3");
        let signing_input = format!("{}.{}", raw_header, raw_payload);
        Ok(Unverified {
            value: JsonWebSignature { header, payload },
            signature: raw_signature.decode(),
            msg: signing_input.into_bytes(),
        })
    }
}
/// Builds the JOSE header of one JSON-serialized signature.
///
/// `protected` is the base64url-encoded protected header (if any); it is
/// decoded, checked for UTF-8 and parsed as a JSON object. `header` is the
/// already-parsed unprotected header. Both are combined by
/// `JoseHeader::from_values`.
///
/// # Errors
///
/// [`ParseJsonError::InvalidUtf8Encoding`] / [`ParseJsonError::InvalidJson`]
/// for a malformed protected header, [`ParseJsonError::InvalidHeader`] if the
/// combined values are rejected.
fn parse_json_header<F: Format, E>(
    protected: Option<&Base64UrlString>,
    header: Option<serde_json::Map<String, serde_json::Value>>,
) -> Result<JoseHeader<F, header::Jws>, ParseJsonError<E>> {
    let protected_values = if let Some(encoded) = protected {
        let json = String::from_utf8(encoded.decode())
            .map_err(|_| ParseJsonError::InvalidUtf8Encoding)?;
        let values: serde_json::Map<String, serde_json::Value> =
            serde_json::from_str(&json).map_err(ParseJsonError::InvalidJson)?;
        Some(values)
    } else {
        None
    };
    JoseHeader::from_values(protected_values, header).map_err(ParseJsonError::InvalidHeader)
}
/// Errors that can occur while decoding a flattened or general JSON JWS.
///
/// `P` is the payload reconstruction error ([`FromRawPayload::Error`]).
#[derive(Debug, Error)]
pub enum ParseJsonError<P> {
    /// A general JSON JWS contained no signatures.
    #[error("the signatures array was empty")]
    EmptySignatures,
    /// The header values were rejected by `JoseHeader::from_values`.
    #[error("invalid JWS header: {0}")]
    InvalidHeader(#[source] header::Error),
    /// The decoded protected header was not valid UTF-8. (In the code visible
    /// here it is only produced for the protected header, although the
    /// message also mentions the signature.)
    #[error("protected header or signature contained invalid UTF-8")]
    InvalidUtf8Encoding,
    /// The decoded protected header was not valid JSON (or not a JSON object).
    #[error("protected header contained invalid JSON")]
    InvalidJson(#[source] serde_json::Error),
    /// Reconstructing the payload via [`FromRawPayload`] failed.
    #[error(transparent)]
    Payload(P),
}
impl<T: FromRawPayload<Context = ()>> DecodeFormat<JsonFlattened>
for JsonWebSignature<JsonFlattened, T>
{
type Decoded<D> = Unverified<D>;
type Error = ParseJsonError<T::Error>;
fn decode(input: JsonFlattened) -> Result<Self::Decoded<Self>, Self::Error> {
Self::decode_with_context(input, &())
}
}
impl<C, T: FromRawPayload<Context = C>> DecodeFormatWithContext<JsonFlattened, C>
    for JsonWebSignature<JsonFlattened, T>
{
    type Decoded<D> = Unverified<D>;
    type Error = ParseJsonError<T::Error>;

    /// Decodes a flattened JSON JWS without verifying it.
    ///
    /// The protected and unprotected headers are combined into one
    /// [`JoseHeader`]. A present `payload` goes to
    /// [`FromRawPayload::from_attached`]; an absent one (detached content) must
    /// be supplied by [`FromRawPayload::from_detached`]. The result records the
    /// decoded signature and the message `protected.payload`, where a missing
    /// protected header contributes an empty segment.
    ///
    /// # Errors
    ///
    /// Header errors from parsing the protected/unprotected headers, and
    /// [`ParseJsonError::Payload`] if the payload cannot be reconstructed.
    fn decode_with_context(
        JsonFlattened {
            payload,
            protected,
            header,
            signature,
        }: JsonFlattened,
        context: &C,
    ) -> Result<Self::Decoded<Self>, Self::Error> {
        let header = parse_json_header(protected.as_ref(), header)?;

        let (payload, PayloadData::Standard(raw_payload)) = match payload {
            None => T::from_detached(context, &header).map_err(ParseJsonError::Payload)?,
            Some(b64) => {
                let value = T::from_attached(context, PayloadData::Standard(b64.clone()))
                    .map_err(ParseJsonError::Payload)?;
                (value, PayloadData::Standard(b64))
            }
        };

        // `protected` is no longer borrowed, so consume it rather than cloning.
        let protected_str = protected.unwrap_or_default().into_inner();
        let msg = format!("{}.{}", protected_str, raw_payload);
        Ok(Unverified {
            value: JsonWebSignature { header, payload },
            signature: signature.decode(),
            msg: msg.into_bytes(),
        })
    }
}
impl<T: FromRawPayload<Context = ()>> DecodeFormat<JsonGeneral>
for JsonWebSignature<JsonGeneral, T>
{
type Decoded<D> = ManyUnverified<D>;
type Error = ParseJsonError<T::Error>;
fn decode(input: JsonGeneral) -> Result<Self::Decoded<Self>, Self::Error> {
Self::decode_with_context(input, &())
}
}
impl<C, T: FromRawPayload<Context = C>> DecodeFormatWithContext<JsonGeneral, C>
    for JsonWebSignature<JsonGeneral, T>
{
    type Decoded<D> = ManyUnverified<D>;
    type Error = ParseJsonError<T::Error>;

    /// Decodes a general JSON JWS (one or more signatures) without verifying it.
    ///
    /// Every signature's headers are parsed, in order, into a [`JoseHeader`].
    /// A present `payload` goes to [`FromRawPayload::from_attached`]; an absent
    /// one must be supplied by [`FromRawPayload::from_detached_many`], which
    /// sees all headers. For each signature the result records the message
    /// `protected.payload` (empty protected segment if missing) together with
    /// the decoded signature bytes.
    ///
    /// # Errors
    ///
    /// [`ParseJsonError::EmptySignatures`] if there are no signatures, header
    /// errors for the first malformed header, and [`ParseJsonError::Payload`]
    /// if the payload cannot be reconstructed.
    fn decode_with_context(
        JsonGeneral {
            payload,
            signatures,
        }: JsonGeneral,
        context: &C,
    ) -> Result<Self::Decoded<Self>, Self::Error> {
        if signatures.is_empty() {
            return Err(ParseJsonError::EmptySignatures);
        }

        let count = signatures.len();
        let mut headers = Vec::with_capacity(count);
        let mut pending = Vec::with_capacity(count);
        for entry in signatures {
            headers.push(parse_json_header(entry.protected.as_ref(), entry.header)?);
            pending.push((entry.protected.unwrap_or_default(), entry.signature.decode()));
        }

        let (payload, PayloadData::Standard(raw_payload)) = match payload {
            None => T::from_detached_many(context, &headers).map_err(ParseJsonError::Payload)?,
            Some(b64) => {
                let value = T::from_attached(context, PayloadData::Standard(b64.clone()))
                    .map_err(ParseJsonError::Payload)?;
                (value, PayloadData::Standard(b64))
            }
        };

        let signatures = pending
            .into_iter()
            .map(|(protected, signature)| {
                let signing_input = format!("{}.{}", protected, raw_payload);
                (signing_input.into_bytes(), signature)
            })
            .collect();

        Ok(ManyUnverified {
            value: JsonWebSignature {
                header: headers,
                payload,
            },
            signatures,
        })
    }
}