use base64ct::{Base64UrlUnpadded, Encoding};
use serde::{
de::{DeserializeOwned, Error as DeError, Visitor},
Deserialize, Deserializer, Serialize, Serializer,
};
use smallvec::{smallvec, SmallVec};
use core::{cmp, fmt};
#[cfg(feature = "ciborium")]
use crate::error::CborDeError;
use crate::{
alloc::{format, Cow, String, Vec},
Algorithm, Claims, Empty, ParseError, ValidationError,
};
/// Inline capacity, in bytes, of the `SmallVec` holding a decoded token signature
/// (see the `signature` field of [`UntrustedToken`]). Longer signatures are stored on the heap.
const SIGNATURE_SIZE: usize = 128;
/// Certificate thumbprint as stored in the `x5t` / `x5t#S256` header fields.
///
/// `N` is the length in bytes of a binary thumbprint.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum Thumbprint<const N: usize> {
    /// Binary thumbprint of exactly `N` bytes.
    Bytes([u8; N]),
    /// Textual thumbprint (e.g., a hex string).
    String(String),
}

/// Wraps a byte array as a binary thumbprint.
impl<const N: usize> From<[u8; N]> for Thumbprint<N> {
    fn from(bytes: [u8; N]) -> Self {
        Thumbprint::Bytes(bytes)
    }
}

/// Wraps an owned string as a textual thumbprint.
impl<const N: usize> From<String> for Thumbprint<N> {
    fn from(text: String) -> Self {
        Thumbprint::String(text)
    }
}

/// Copies a string slice into a textual thumbprint.
impl<const N: usize> From<&str> for Thumbprint<N> {
    fn from(text: &str) -> Self {
        Thumbprint::String(String::from(text))
    }
}
impl<const N: usize> Serialize for Thumbprint<N> {
    /// Writes the thumbprint as an unpadded base64url string.
    ///
    /// Binary thumbprints are encoded from their bytes; textual thumbprints are encoded
    /// from the UTF-8 bytes of the string.
    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        let raw: &[u8] = match self {
            Self::String(text) => text.as_bytes(),
            Self::Bytes(bytes) => &bytes[..],
        };
        let encoded = Base64UrlUnpadded::encode_string(raw);
        serializer.serialize_str(&encoded)
    }
}
impl<'de, const N: usize> Deserialize<'de> for Thumbprint<N> {
fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
struct Base64Visitor<const L: usize>;
impl<const L: usize> Visitor<'_> for Base64Visitor<L> {
type Value = Thumbprint<L>;
fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(formatter, "base64url-encoded thumbprint")
}
fn visit_str<E: DeError>(self, mut value: &str) -> Result<Self::Value, E> {
for _ in 0..2 {
if value.as_bytes().last() == Some(&b'=') {
value = &value[..value.len() - 1];
}
}
let decoded_len = value.len() * 3 / 4;
match decoded_len.cmp(&L) {
cmp::Ordering::Less => Err(E::custom(format!(
"thumbprint must contain at least {L} bytes"
))),
cmp::Ordering::Equal => {
let mut bytes = [0_u8; L];
let len = Base64UrlUnpadded::decode(value, &mut bytes)
.map_err(E::custom)?
.len();
debug_assert_eq!(len, L);
Ok(bytes.into())
}
cmp::Ordering::Greater => {
let decoded = Base64UrlUnpadded::decode_vec(value).map_err(E::custom)?;
let decoded = String::from_utf8(decoded)
.map_err(|err| E::custom(err.utf8_error()))?;
Ok(decoded.into())
}
}
}
}
deserializer.deserialize_str(Base64Visitor)
}
}
/// Token header fields, plus user-defined extra fields of type `T`.
///
/// The `alg` and `cty` fields are not part of this struct; the parser reads them through a
/// separate wrapper. Every standard field is optional, defaults to `None` when absent, and is
/// omitted from serialized output when unset.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[non_exhaustive]
pub struct Header<T = Empty> {
    /// URL of the key set (`jku` field).
    #[serde(rename = "jku", default, skip_serializing_if = "Option::is_none")]
    pub key_set_url: Option<String>,
    /// Key identifier (`kid` field).
    #[serde(rename = "kid", default, skip_serializing_if = "Option::is_none")]
    pub key_id: Option<String>,
    /// URL of the X.509 certificate (`x5u` field).
    #[serde(rename = "x5u", default, skip_serializing_if = "Option::is_none")]
    pub certificate_url: Option<String>,
    /// SHA-1 thumbprint of the X.509 certificate (`x5t` field); binary form is 20 bytes.
    #[serde(rename = "x5t", default, skip_serializing_if = "Option::is_none")]
    pub certificate_sha1_thumbprint: Option<Thumbprint<20>>,
    /// SHA-256 thumbprint of the X.509 certificate (`x5t#S256` field); binary form is 32 bytes.
    #[serde(rename = "x5t#S256", default, skip_serializing_if = "Option::is_none")]
    pub certificate_thumbprint: Option<Thumbprint<32>>,
    /// Token type (`typ` field).
    #[serde(rename = "typ", default, skip_serializing_if = "Option::is_none")]
    pub token_type: Option<String>,
    /// Additional header fields, flattened into the same object as the standard ones.
    #[serde(flatten)]
    pub other_fields: T,
}
impl Header {
    /// Creates a header with every standard field unset and no extra fields.
    pub const fn empty() -> Self {
        Self::new(Empty {})
    }
}
impl<T> Header<T> {
    /// Creates a header with every standard field unset and the given extra `fields`.
    pub const fn new(fields: T) -> Header<T> {
        Header {
            other_fields: fields,
            key_set_url: None,
            key_id: None,
            certificate_url: None,
            certificate_sha1_thumbprint: None,
            certificate_thumbprint: None,
            token_type: None,
        }
    }

    /// Sets the key set URL (`jku` field).
    #[must_use]
    pub fn with_key_set_url(self, key_set_url: impl Into<String>) -> Self {
        Self {
            key_set_url: Some(key_set_url.into()),
            ..self
        }
    }

    /// Sets the key identifier (`kid` field).
    #[must_use]
    pub fn with_key_id(self, key_id: impl Into<String>) -> Self {
        Self {
            key_id: Some(key_id.into()),
            ..self
        }
    }

    /// Sets the X.509 certificate URL (`x5u` field).
    #[must_use]
    pub fn with_certificate_url(self, certificate_url: impl Into<String>) -> Self {
        Self {
            certificate_url: Some(certificate_url.into()),
            ..self
        }
    }

    /// Sets the SHA-1 certificate thumbprint (`x5t` field).
    #[must_use]
    pub fn with_certificate_sha1_thumbprint(
        self,
        certificate_thumbprint: impl Into<Thumbprint<20>>,
    ) -> Self {
        Self {
            certificate_sha1_thumbprint: Some(certificate_thumbprint.into()),
            ..self
        }
    }

    /// Sets the SHA-256 certificate thumbprint (`x5t#S256` field).
    #[must_use]
    pub fn with_certificate_thumbprint(
        self,
        certificate_thumbprint: impl Into<Thumbprint<32>>,
    ) -> Self {
        Self {
            certificate_thumbprint: Some(certificate_thumbprint.into()),
            ..self
        }
    }

    /// Sets the token type (`typ` field).
    #[must_use]
    pub fn with_token_type(self, token_type: impl Into<String>) -> Self {
        Self {
            token_type: Some(token_type.into()),
            ..self
        }
    }
}
/// Header as it appears in the encoded token: the algorithm (`alg`), the optional content
/// type (`cty`), and the remaining fields from `inner`, flattened into the same object.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct CompleteHeader<'a, T> {
    /// Signing algorithm name (`alg` field); `Cow` permits either a borrowed or an owned name.
    #[serde(rename = "alg")]
    pub algorithm: Cow<'a, str>,
    /// Content type of the claims (`cty` field); omitted from output when `None`.
    #[serde(rename = "cty", default, skip_serializing_if = "Option::is_none")]
    pub content_type: Option<String>,
    /// All other header fields.
    #[serde(flatten)]
    pub inner: T,
}
/// Encoding of the token claims, selected from the `cty` header field during parsing.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ContentType {
    /// JSON claims; used when `cty` is absent or equals `json` (case-insensitive).
    Json,
    /// CBOR claims (`cty` equal to `cbor`, case-insensitive); needs the `ciborium` feature.
    #[cfg(feature = "ciborium")]
    Cbor,
}
/// Token that has been parsed but whose signature has not been verified.
///
/// `'a` is the lifetime of the source string when `signed_data` borrows from it.
#[derive(Debug, Clone)]
pub struct UntrustedToken<'a, H = Empty> {
    /// The `header.claims` prefix of the token string, i.e. everything before the last `.`.
    pub(crate) signed_data: Cow<'a, [u8]>,
    /// Decoded header, excluding the `alg` and `cty` fields.
    header: Header<H>,
    /// Algorithm name from the `alg` header field.
    algorithm: String,
    /// Encoding of `serialized_claims`, derived from the `cty` header field.
    content_type: ContentType,
    /// Base64url-decoded claims bytes, not yet deserialized.
    serialized_claims: Vec<u8>,
    /// Base64url-decoded signature bytes.
    signature: SmallVec<[u8; SIGNATURE_SIZE]>,
}
/// Token consisting of a header and deserialized claims.
#[derive(Debug, Clone)]
pub struct Token<T, H = Empty> {
    /// Token header.
    header: Header<H>,
    /// Token claims with custom claims of type `T`.
    claims: Claims<T>,
}
impl<T, H> Token<T, H> {
pub(crate) fn new(header: Header<H>, claims: Claims<T>) -> Self {
Self { header, claims }
}
pub fn header(&self) -> &Header<H> {
&self.header
}
pub fn claims(&self) -> &Claims<T> {
&self.claims
}
pub fn into_parts(self) -> (Header<H>, Claims<T>) {
(self.header, self.claims)
}
}
/// Token paired with a signature of algorithm `A`.
#[non_exhaustive]
pub struct SignedToken<A: Algorithm + ?Sized, T, H = Empty> {
    /// Signature, in the signature type of algorithm `A`.
    pub signature: A::Signature,
    /// Token header and claims.
    pub token: Token<T, H>,
}
// Implemented by hand: `#[derive(Debug)]` would also require `A: Debug`, although only
// `A::Signature` is actually stored.
impl<A, T, H> fmt::Debug for SignedToken<A, T, H>
where
    A: Algorithm,
    A::Signature: fmt::Debug,
    T: fmt::Debug,
    H: fmt::Debug,
{
    /// Formats the token first, then the signature.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { signature, token } = self;
        let mut builder = f.debug_struct("SignedToken");
        builder.field("token", token);
        builder.field("signature", signature);
        builder.finish()
    }
}
impl<A, T, H> Clone for SignedToken<A, T, H>
where
A: Algorithm,
A::Signature: Clone,
T: Clone,
H: Clone,
{
fn clone(&self) -> Self {
Self {
signature: self.signature.clone(),
token: self.token.clone(),
}
}
}
impl<'a, H: DeserializeOwned> TryFrom<&'a str> for UntrustedToken<'a, H> {
    type Error = ParseError;

    /// Parses a compact token `header.claims.signature`.
    ///
    /// Only the structure, base64url encoding, header JSON and content type are checked here.
    /// The signature and claims are not validated.
    ///
    /// # Errors
    ///
    /// - [`ParseError::InvalidTokenStructure`] if `s` does not consist of exactly three
    ///   `.`-separated segments.
    /// - [`ParseError::InvalidBase64Encoding`] if any segment is not valid unpadded base64url.
    /// - [`ParseError::MalformedHeader`] if the header cannot be deserialized.
    /// - [`ParseError::UnsupportedContentType`] if the `cty` header field is not supported.
    fn try_from(s: &'a str) -> Result<Self, Self::Error> {
        // A well-formed token has exactly two dots. Splitting off the last segment first
        // yields the signed `header.claims` prefix directly. This needs no intermediate `Vec`
        // and no second scan guarded by `unwrap()`.
        let (signed_data, signature) = s
            .rsplit_once('.')
            .ok_or(ParseError::InvalidTokenStructure)?;
        let (header, claims) = signed_data
            .split_once('.')
            .ok_or(ParseError::InvalidTokenStructure)?;
        if claims.contains('.') {
            return Err(ParseError::InvalidTokenStructure);
        }

        let header = Base64UrlUnpadded::decode_vec(header)
            .map_err(|_| ParseError::InvalidBase64Encoding)?;
        let serialized_claims = Base64UrlUnpadded::decode_vec(claims)
            .map_err(|_| ParseError::InvalidBase64Encoding)?;

        // Allocate an upper bound of the decoded length, then trim to the actual length.
        let mut decoded_signature = smallvec![0; 3 * (signature.len() + 3) / 4];
        let signature_len = Base64UrlUnpadded::decode(signature, &mut decoded_signature[..])
            .map_err(|_| ParseError::InvalidBase64Encoding)?
            .len();
        decoded_signature.truncate(signature_len);

        let header: CompleteHeader<_> =
            serde_json::from_slice(&header).map_err(ParseError::MalformedHeader)?;
        let content_type = match header.content_type {
            None => ContentType::Json,
            Some(cty) if cty.eq_ignore_ascii_case("json") => ContentType::Json,
            #[cfg(feature = "ciborium")]
            Some(cty) if cty.eq_ignore_ascii_case("cbor") => ContentType::Cbor,
            Some(cty) => return Err(ParseError::UnsupportedContentType(cty)),
        };

        Ok(Self {
            signed_data: Cow::Borrowed(signed_data.as_bytes()),
            header: header.inner,
            algorithm: header.algorithm.into_owned(),
            content_type,
            serialized_claims,
            signature: decoded_signature,
        })
    }
}
impl<'a> UntrustedToken<'a> {
    /// Parses a token with the default ([`Empty`]) extra header fields.
    ///
    /// # Errors
    ///
    /// Returns a [`ParseError`] under the same conditions as the `TryFrom<&str>` implementation.
    pub fn new<S: AsRef<str> + ?Sized>(s: &'a S) -> Result<Self, ParseError> {
        let token_str: &'a str = s.as_ref();
        <Self as TryFrom<&'a str>>::try_from(token_str)
    }
}
impl<H> UntrustedToken<'_, H> {
    /// Converts into a token that owns its signed data, detaching it from the source string.
    pub fn into_owned(self) -> UntrustedToken<'static, H> {
        let Self {
            signed_data,
            header,
            algorithm,
            content_type,
            serialized_claims,
            signature,
        } = self;
        UntrustedToken {
            signed_data: Cow::Owned(signed_data.into_owned()),
            header,
            algorithm,
            content_type,
            serialized_claims,
            signature,
        }
    }

    /// Returns the parsed header (without the `alg` and `cty` fields).
    pub fn header(&self) -> &Header<H> {
        &self.header
    }

    /// Returns the algorithm name from the `alg` header field.
    pub fn algorithm(&self) -> &str {
        self.algorithm.as_str()
    }

    /// Returns the decoded signature bytes.
    pub fn signature_bytes(&self) -> &[u8] {
        &self.signature[..]
    }

    /// Deserializes the claims without verifying the signature or checking any claim values.
    ///
    /// # Errors
    ///
    /// Returns [`ValidationError::MalformedClaims`] if JSON claims cannot be deserialized.
    /// With the `ciborium` feature, returns `ValidationError::MalformedCborClaims` for
    /// CBOR claims that cannot be deserialized.
    pub fn deserialize_claims_unchecked<T>(&self) -> Result<Claims<T>, ValidationError>
    where
        T: DeserializeOwned,
    {
        let raw_claims = self.serialized_claims.as_slice();
        match self.content_type {
            ContentType::Json => {
                serde_json::from_slice(raw_claims).map_err(ValidationError::MalformedClaims)
            }
            #[cfg(feature = "ciborium")]
            ContentType::Cbor => ciborium::from_reader(raw_claims).map_err(|err| {
                // The reader's I/O error is wrapped into `anyhow`; other variants carry over.
                let converted = match err {
                    CborDeError::Io(err) => CborDeError::Io(anyhow::anyhow!(err)),
                    CborDeError::Syntax(offset) => CborDeError::Syntax(offset),
                    CborDeError::Semantic(offset, description) => {
                        CborDeError::Semantic(offset, description)
                    }
                    CborDeError::RecursionLimitExceeded => CborDeError::RecursionLimitExceeded,
                };
                ValidationError::MalformedCborClaims(converted)
            }),
        }
    }
}
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;
    use base64ct::{Base64UrlUnpadded, Encoding};

    use super::*;
    use crate::{
        alg::{Hs256, Hs256Key},
        alloc::{ToOwned, ToString},
        AlgorithmExt, Empty,
    };

    type Obj = serde_json::Map<String, serde_json::Value>;

    // HS256 example token and key; these match RFC 7515, Appendix A.1.
    const HS256_TOKEN: &str = "eyJ0eXAiOiJKV1QiLA0KICJhbGciOiJIUzI1NiJ9.\
                               eyJpc3MiOiJqb2UiLA0KICJleHAiOjEzMDA4MTkzODAsDQogImh0dHA6Ly9leGFt\
                               cGxlLmNvbS9pc19yb290Ijp0cnVlfQ.\
                               dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk";
    const HS256_KEY: &str = "AyM1SysPpbyDfgZld3umj1qzKObwVMkoqQ-EstJQLr_T-1qS0gZH75\
                             aKtMN3Yj0iPS4hcgUuTwjAzZr1Z9CAow";

    #[test]
    fn parsing_valid_token() {
        let token = UntrustedToken::new(HS256_TOKEN).unwrap();
        assert_eq!(token.algorithm(), "HS256");
        assert_eq!(token.header().token_type.as_deref(), Some("JWT"));
        assert_eq!(token.signature_bytes().len(), 32);
        // The signed data is everything before the last dot.
        let signed_end = HS256_TOKEN.rfind('.').unwrap();
        assert_eq!(&*token.signed_data, HS256_TOKEN[..signed_end].as_bytes());
    }

    #[test]
    fn invalid_token_structure() {
        let mangled_str = HS256_TOKEN.replace('.', "");
        assert_matches!(
            UntrustedToken::new(&mangled_str).unwrap_err(),
            ParseError::InvalidTokenStructure
        );

        let mut mangled_str = HS256_TOKEN.to_owned();
        let signature_start = mangled_str.rfind('.').unwrap();
        mangled_str.truncate(signature_start);
        assert_matches!(
            UntrustedToken::new(&mangled_str).unwrap_err(),
            ParseError::InvalidTokenStructure
        );

        let mut mangled_str = HS256_TOKEN.to_owned();
        mangled_str.push('.');
        assert_matches!(
            UntrustedToken::new(&mangled_str).unwrap_err(),
            ParseError::InvalidTokenStructure
        );
    }

    #[test]
    fn base64_error_during_parsing() {
        let mangled_str = HS256_TOKEN.replace('0', "+");
        assert_matches!(
            UntrustedToken::new(&mangled_str).unwrap_err(),
            ParseError::InvalidBase64Encoding
        );
    }

    #[test]
    fn base64_padding_error_during_parsing() {
        let mut mangled_str = HS256_TOKEN.to_owned();
        mangled_str.pop();
        mangled_str.push('_');
        assert_matches!(
            UntrustedToken::new(&mangled_str).unwrap_err(),
            ParseError::InvalidBase64Encoding
        );
    }

    #[test]
    fn header_fields_are_not_serialized_if_not_present() {
        let header = Header::empty();
        let json = serde_json::to_string(&header).unwrap();
        assert_eq!(json, "{}");
    }

    #[test]
    fn header_with_x5t_field() {
        let header = r#"{"alg":"HS256","x5t":"lDpwLQbzRZmu4fjajvn3KWAx1pk"}"#;
        let header: CompleteHeader<Header<Empty>> = serde_json::from_str(header).unwrap();
        let thumbprint = header.inner.certificate_sha1_thumbprint.as_ref().unwrap();
        let Thumbprint::Bytes(thumbprint) = thumbprint else {
            unreachable!();
        };

        assert_eq!(thumbprint[0], 0x94);
        assert_eq!(thumbprint[19], 0x99);

        let json = serde_json::to_value(header).unwrap();
        assert_eq!(
            json,
            serde_json::json!({
                "alg": "HS256",
                "x5t": "lDpwLQbzRZmu4fjajvn3KWAx1pk",
            })
        );
    }

    #[test]
    fn header_with_padded_x5t_field() {
        let header = r#"{"alg":"HS256","x5t":"lDpwLQbzRZmu4fjajvn3KWAx1pk=="}"#;
        let header: CompleteHeader<Header<Empty>> = serde_json::from_str(header).unwrap();
        let thumbprint = header.inner.certificate_sha1_thumbprint.as_ref().unwrap();
        let Thumbprint::Bytes(thumbprint) = thumbprint else {
            unreachable!()
        };

        assert_eq!(thumbprint[0], 0x94);
        assert_eq!(thumbprint[19], 0x99);
    }

    #[test]
    fn header_with_hex_x5t_field() {
        let header =
            r#"{"alg":"HS256","x5t":"NjVBRjY5MDlCMUIwNzU4RTA2QzZFMDQ4QzQ2MDAyQjVDNjk1RTM2Qg"}"#;
        let header: CompleteHeader<Header<Empty>> = serde_json::from_str(header).unwrap();
        let thumbprint = header.inner.certificate_sha1_thumbprint.as_ref().unwrap();
        let Thumbprint::String(thumbprint) = thumbprint else {
            unreachable!()
        };

        assert_eq!(thumbprint, "65AF6909B1B0758E06C6E048C46002B5C695E36B");

        let json = serde_json::to_value(header).unwrap();
        assert_eq!(
            json,
            serde_json::json!({
                "alg": "HS256",
                "x5t": "NjVBRjY5MDlCMUIwNzU4RTA2QzZFMDQ4QzQ2MDAyQjVDNjk1RTM2Qg",
            })
        );
    }

    #[test]
    fn header_with_padded_hex_x5t_field() {
        let header =
            r#"{"alg":"HS256","x5t":"NjVBRjY5MDlCMUIwNzU4RTA2QzZFMDQ4QzQ2MDAyQjVDNjk1RTM2Qg=="}"#;
        let header: CompleteHeader<Header<Empty>> = serde_json::from_str(header).unwrap();
        let thumbprint = header.inner.certificate_sha1_thumbprint.as_ref().unwrap();
        let Thumbprint::String(thumbprint) = thumbprint else {
            unreachable!()
        };

        assert_eq!(thumbprint, "65AF6909B1B0758E06C6E048C46002B5C695E36B");
    }

    #[test]
    fn header_with_overly_short_x5t_field() {
        let header = r#"{"alg":"HS256","x5t":"aGk="}"#;
        let err = serde_json::from_str::<CompleteHeader<Header<Empty>>>(header).unwrap_err();
        let err = err.to_string();
        assert!(
            err.contains("thumbprint must contain at least 20 bytes"),
            "{err}"
        );
    }

    #[test]
    fn header_with_non_base64_x5t_field() {
        let headers = [
            r#"{"alg":"HS256","x5t":"lDpwLQbzRZmu4fjajvn3KWAx1p?"}"#,
            r#"{"alg":"HS256","x5t":"NjVBRjY5MDlCMUIwNzU4RTA2QzZFMDQ4QzQ2MDAyQjVDNjk!RTM2Qg"}"#,
        ];
        for header in headers {
            let err = serde_json::from_str::<CompleteHeader<Header<Empty>>>(header).unwrap_err();
            let err = err.to_string();
            assert!(err.contains("Base64"), "{err}");
        }
    }

    #[test]
    fn header_with_x5t_sha256_field() {
        let header = r#"{"alg":"HS256","x5t#S256":"MV9b23bQeMQ7isAGTkoBZGErH853yGk0W_yUx1iU7dM"}"#;
        let header: CompleteHeader<Header<Empty>> = serde_json::from_str(header).unwrap();
        let thumbprint = header.inner.certificate_thumbprint.as_ref().unwrap();
        let Thumbprint::Bytes(thumbprint) = thumbprint else {
            unreachable!()
        };

        assert_eq!(thumbprint[0], 0x31);
        assert_eq!(thumbprint[31], 0xd3);

        let json = serde_json::to_value(header).unwrap();
        assert_eq!(
            json,
            serde_json::json!({
                "alg": "HS256",
                "x5t#S256": "MV9b23bQeMQ7isAGTkoBZGErH853yGk0W_yUx1iU7dM",
            })
        );
    }

    #[test]
    fn malformed_header() {
        let mangled_headers = [
            // Missing closing brace
            r#"{"alg":"HS256""#,
            // Missing necessary `alg` field
            "{}",
            // `alg` field is not a string
            r#"{"alg":5}"#,
            r#"{"alg":[1,"foo"]}"#,
            r#"{"alg":false}"#,
            // Duplicate `alg` field
            r#"{"alg":"HS256","alg":"none"}"#,
            // Invalid thumbprints
            r#"{"alg":"HS256","x5t":"lDpwLQbzRZmu4fjajvn3KWAx1p"}"#,
            r#"{"alg":"HS256","x5t":["lDpwLQbzRZmu4fjajvn3KWAx1pk"]}"#,
            r#"{"alg":"HS256","x5t":"lDpwLQbzRZmu4fjajvn3KWAx1 k"}"#,
            r#"{"alg":"HS256","x5t":"lDpwLQbzRZmu4fjajvn3KWAx1pk==="}"#,
            r#"{"alg":"HS256","x5t":"lDpwLQbzRZmu4fjajvn3KWAx1pkk"}"#,
            r#"{"alg":"HS256","x5t":"MV9b23bQeMQ7isAGTkoBZGErH853yGk0W_yUx1iU7dM"}"#,
            r#"{"alg":"HS256","x5t#S256":"lDpwLQbzRZmu4fjajvn3KWAx1pk"}"#,
        ];

        for mangled_header in &mangled_headers {
            let mangled_header = Base64UrlUnpadded::encode_string(mangled_header.as_bytes());
            let mut mangled_str = HS256_TOKEN.to_owned();
            mangled_str.replace_range(..mangled_str.find('.').unwrap(), &mangled_header);
            assert_matches!(
                UntrustedToken::new(&mangled_str).unwrap_err(),
                ParseError::MalformedHeader(_)
            );
        }
    }

    #[test]
    fn unsupported_content_type() {
        let mangled_header = br#"{"alg":"HS256","cty":"txt"}"#;
        let mangled_header = Base64UrlUnpadded::encode_string(mangled_header);
        let mut mangled_str = HS256_TOKEN.to_owned();
        mangled_str.replace_range(..mangled_str.find('.').unwrap(), &mangled_header);
        assert_matches!(
            UntrustedToken::new(&mangled_str).unwrap_err(),
            ParseError::UnsupportedContentType(s) if s == "txt"
        );
    }

    #[test]
    fn extracting_custom_header_fields() {
        let header = r#"{"alg":"HS256","custom":[1,"field"],"x5t":"lDpwLQbzRZmu4fjajvn3KWAx1pk"}"#;
        let header: CompleteHeader<Header<Obj>> = serde_json::from_str(header).unwrap();
        assert_eq!(header.algorithm, "HS256");
        assert!(header.inner.certificate_sha1_thumbprint.is_some());
        assert_eq!(header.inner.other_fields.len(), 1);
        assert!(header.inner.other_fields["custom"].is_array());
    }

    #[test]
    fn malformed_json_claims() {
        let malformed_claims = [
            // Missing closing brace
            r#"{"exp":1500000000"#,
            // `exp` claim is not a number
            r#"{"exp":"1500000000"}"#,
            r#"{"exp":false}"#,
            // Duplicate `exp` claim
            r#"{"exp":1500000000,"nbf":1400000000,"exp":1510000000}"#,
            // Too large `exp` value
            r#"{"exp":1500000000000000000000000000000000}"#,
        ];

        let claims_start = HS256_TOKEN.find('.').unwrap() + 1;
        let claims_end = HS256_TOKEN.rfind('.').unwrap();
        let key = Base64UrlUnpadded::decode_vec(HS256_KEY).unwrap();
        let key = Hs256Key::new(key);

        for claims in &malformed_claims {
            let encoded_claims = Base64UrlUnpadded::encode_string(claims.as_bytes());
            let mut mangled_str = HS256_TOKEN.to_owned();
            mangled_str.replace_range(claims_start..claims_end, &encoded_claims);
            let token = UntrustedToken::new(&mangled_str).unwrap();
            assert_matches!(
                Hs256.validator::<Obj>(&key).validate(&token).unwrap_err(),
                ValidationError::MalformedClaims(_),
                "Failing claims: {claims}"
            );
        }
    }

    /// Checks that validating `mangled_str` with the HS256 key fails with
    /// `InvalidSignatureLen { actual: actual_len, expected: 32 }`.
    fn test_invalid_signature_len(mangled_str: &str, actual_len: usize) {
        let token = UntrustedToken::new(mangled_str).unwrap();
        let key = Base64UrlUnpadded::decode_vec(HS256_KEY).unwrap();
        let key = Hs256Key::new(key);

        let err = Hs256.validator::<Empty>(&key).validate(&token).unwrap_err();
        assert_matches!(
            err,
            ValidationError::InvalidSignatureLen { actual, expected: 32 }
                if actual == actual_len
        );
    }

    #[test]
    fn short_signature_error() {
        test_invalid_signature_len(&HS256_TOKEN[..HS256_TOKEN.len() - 3], 30);
    }

    #[test]
    fn long_signature_error() {
        let mut mangled_string = HS256_TOKEN.to_owned();
        mangled_string.push('a');
        test_invalid_signature_len(&mangled_string, 33);
    }
}