use base64ct::{Base64UrlUnpadded, Encoding};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use smallvec::{smallvec, SmallVec};
use core::fmt;
use crate::{
alloc::{Cow, String, Vec},
Algorithm, Claims, ParseError, ValidationError,
};
/// Inline capacity, in bytes, of the `SmallVec` that stores a token's decoded signature
/// (the `signature` field of `UntrustedToken`). Longer signatures spill to the heap.
const SIGNATURE_SIZE: usize = 128;
/// JWT header fields other than the algorithm (`alg`) and content type (`cty`); those two
/// are carried by `CompleteHeader`, which flattens this struct into itself.
///
/// Every field is optional: a missing field deserializes to `None` (`default`), and a `None`
/// field is left out of the serialized output (`skip_serializing_if`), so a default header
/// serializes to `{}`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[non_exhaustive]
pub struct Header {
/// `jku` header field (named "JWK Set URL" in RFC 7515).
#[serde(rename = "jku", default, skip_serializing_if = "Option::is_none")]
pub key_set_url: Option<String>,
/// `kid` header field (named "Key ID" in RFC 7515).
#[serde(rename = "kid", default, skip_serializing_if = "Option::is_none")]
pub key_id: Option<String>,
/// `x5u` header field (named "X.509 URL" in RFC 7515).
#[serde(rename = "x5u", default, skip_serializing_if = "Option::is_none")]
pub certificate_url: Option<String>,
/// `x5t` header field: a 20-byte certificate thumbprint. On the wire it is an unpadded
/// base64url string (see the `base64url` module); values that do not decode to exactly
/// 20 bytes are rejected during deserialization.
#[serde(
rename = "x5t",
with = "base64url",
default,
skip_serializing_if = "Option::is_none"
)]
pub certificate_sha1_thumbprint: Option<[u8; 20]>,
/// `x5t#S256` header field: a 32-byte certificate thumbprint. On the wire it is an unpadded
/// base64url string (see the `base64url` module); values that do not decode to exactly
/// 32 bytes are rejected during deserialization.
#[serde(
rename = "x5t#S256",
with = "base64url",
default,
skip_serializing_if = "Option::is_none"
)]
pub certificate_thumbprint: Option<[u8; 32]>,
/// `typ` header field (named "Type" in RFC 7515).
#[serde(rename = "typ", default, skip_serializing_if = "Option::is_none")]
pub token_type: Option<String>,
}
impl Header {
    /// Returns this header with `key_set_url` (the `jku` field) set to the given value.
    #[must_use]
    pub fn with_key_set_url(self, key_set_url: impl Into<String>) -> Self {
        Self {
            key_set_url: Some(key_set_url.into()),
            ..self
        }
    }

    /// Returns this header with `key_id` (the `kid` field) set to the given value.
    #[must_use]
    pub fn with_key_id(self, key_id: impl Into<String>) -> Self {
        Self {
            key_id: Some(key_id.into()),
            ..self
        }
    }

    /// Returns this header with `certificate_url` (the `x5u` field) set to the given value.
    #[must_use]
    pub fn with_certificate_url(self, certificate_url: impl Into<String>) -> Self {
        Self {
            certificate_url: Some(certificate_url.into()),
            ..self
        }
    }

    /// Returns this header with `certificate_sha1_thumbprint` (the `x5t` field) set to the
    /// given 20-byte value.
    #[must_use]
    pub fn with_certificate_sha1_thumbprint(self, certificate_thumbprint: [u8; 20]) -> Self {
        Self {
            certificate_sha1_thumbprint: Some(certificate_thumbprint),
            ..self
        }
    }

    /// Returns this header with `certificate_thumbprint` (the `x5t#S256` field) set to the
    /// given 32-byte value.
    #[must_use]
    pub fn with_certificate_thumbprint(self, certificate_thumbprint: [u8; 32]) -> Self {
        Self {
            certificate_thumbprint: Some(certificate_thumbprint),
            ..self
        }
    }

    /// Returns this header with `token_type` (the `typ` field) set to the given value.
    #[must_use]
    pub fn with_token_type(self, token_type: impl Into<String>) -> Self {
        Self {
            token_type: Some(token_type.into()),
            ..self
        }
    }
}
/// Complete header as it is (de)serialized: the mandatory `alg` field, the optional `cty`
/// field, and all `Header` fields flattened into the same JSON object.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct CompleteHeader<'a> {
/// `alg` header field; the only field without a default, so it must be present.
#[serde(rename = "alg")]
pub algorithm: Cow<'a, str>,
/// `cty` header field; omitted from the output when `None`.
#[serde(rename = "cty", default, skip_serializing_if = "Option::is_none")]
pub content_type: Option<String>,
/// Remaining header fields, stored at the same JSON level (`flatten`).
#[serde(flatten)]
pub inner: Header,
}
/// Encoding of the serialized claims. Selected from the `cty` header field when a token is
/// parsed and used to pick the deserializer in `deserialize_claims_unchecked`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ContentType {
/// JSON claims; also chosen when the header has no `cty` field.
Json,
/// CBOR claims; only available with the `serde_cbor` feature.
#[cfg(feature = "serde_cbor")]
Cbor,
}
/// Token that has been parsed from its string form. Parsing only decodes the parts and reads
/// the header; nothing in this type's construction checks the signature.
#[derive(Debug, Clone)]
pub struct UntrustedToken<'a> {
/// Bytes of the token string that precede the last `.`, i.e. `header.claims` in
/// their original encoded form.
pub(crate) signed_data: Cow<'a, [u8]>,
/// Header fields other than `alg` and `cty`.
header: Header,
/// Value of the `alg` header field.
algorithm: String,
/// Claims encoding derived from the `cty` header field.
content_type: ContentType,
/// Base64url-decoded claims, not yet deserialized.
serialized_claims: Vec<u8>,
/// Base64url-decoded signature; kept inline for up to `SIGNATURE_SIZE` bytes.
signature: SmallVec<[u8; SIGNATURE_SIZE]>,
}
/// Token consisting of a header and claims with custom payload type `T`.
#[derive(Debug, Clone)]
pub struct Token<T> {
/// Token header (without the `alg` / `cty` fields).
header: Header,
/// Token claims.
claims: Claims<T>,
}
impl<T> Token<T> {
pub(crate) fn new(header: Header, claims: Claims<T>) -> Self {
Self { header, claims }
}
pub fn header(&self) -> &Header {
&self.header
}
pub fn claims(&self) -> &Claims<T> {
&self.claims
}
}
/// Token paired with a signature of the type defined by algorithm `A`.
#[non_exhaustive]
pub struct SignedToken<A: Algorithm + ?Sized, T> {
/// Signature value, typed as `A::Signature`.
pub signature: A::Signature,
/// Token header and claims.
pub token: Token<T>,
}
impl<A, T> fmt::Debug for SignedToken<A, T>
where
    A: Algorithm,
    A::Signature: fmt::Debug,
    T: fmt::Debug,
{
    /// Formats the value as a `SignedToken` struct, listing `token` before `signature`.
    ///
    /// Implemented by hand so that only `A::Signature` (not `A` itself) has to be `Debug`.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { signature, token } = self;
        let mut output = formatter.debug_struct("SignedToken");
        output.field("token", token);
        output.field("signature", signature);
        output.finish()
    }
}
impl<A, T> Clone for SignedToken<A, T>
where
    A: Algorithm,
    A::Signature: Clone,
    T: Clone,
{
    /// Clones the signature first and the token second.
    ///
    /// Implemented by hand so that only `A::Signature` (not `A` itself) has to be `Clone`.
    fn clone(&self) -> Self {
        let signature = self.signature.clone();
        let token = self.token.clone();
        Self { signature, token }
    }
}
impl<'a> TryFrom<&'a str> for UntrustedToken<'a> {
    type Error = ParseError;

    /// Parses a token from its compact `header.claims.signature` string form.
    ///
    /// The three parts are base64url-decoded (unpadded) and the header is parsed as JSON.
    /// The claims are kept as raw bytes, and the signature is decoded but not verified.
    /// The returned token borrows the `header.claims` prefix of `s` as its signed data.
    ///
    /// # Errors
    ///
    /// - `ParseError::InvalidTokenStructure` if `s` does not consist of exactly three
    ///   `.`-separated parts.
    /// - `ParseError::InvalidBase64Encoding` if the header, claims or signature (checked
    ///   in that order) is not valid unpadded base64url.
    /// - `ParseError::MalformedHeader` if the decoded header cannot be deserialized.
    /// - `ParseError::UnsupportedContentType` if the `cty` field is present and is neither
    ///   `json` nor (with the `serde_cbor` feature) `cbor`, compared ASCII-case-insensitively.
    fn try_from(s: &'a str) -> Result<Self, Self::Error> {
        // Split from both ends rather than collecting the parts into a `Vec`: this needs
        // no allocation and yields the signed `header.claims` prefix directly, so no
        // second scan (and no `unwrap`) is needed to recover it later.
        let (signed_data, signature) = s
            .rsplit_once('.')
            .ok_or(ParseError::InvalidTokenStructure)?;
        let (header, claims) = signed_data
            .split_once('.')
            .ok_or(ParseError::InvalidTokenStructure)?;
        // A separator left inside the middle part means the token had more than three parts.
        if claims.contains('.') {
            return Err(ParseError::InvalidTokenStructure);
        }

        let header = Base64UrlUnpadded::decode_vec(header)
            .map_err(|_| ParseError::InvalidBase64Encoding)?;
        let serialized_claims = Base64UrlUnpadded::decode_vec(claims)
            .map_err(|_| ParseError::InvalidBase64Encoding)?;

        // Allocate an upper bound for the decoded length, then shrink to what was written.
        let mut decoded_signature = smallvec![0; 3 * (signature.len() + 3) / 4];
        let signature_len = Base64UrlUnpadded::decode(signature, &mut decoded_signature[..])
            .map_err(|_| ParseError::InvalidBase64Encoding)?
            .len();
        decoded_signature.truncate(signature_len);

        let header: CompleteHeader<'_> =
            serde_json::from_slice(&header).map_err(ParseError::MalformedHeader)?;
        let content_type = match header.content_type {
            None => ContentType::Json,
            Some(ref cty) if cty.eq_ignore_ascii_case("json") => ContentType::Json,
            #[cfg(feature = "serde_cbor")]
            Some(ref cty) if cty.eq_ignore_ascii_case("cbor") => ContentType::Cbor,
            Some(unsupported) => return Err(ParseError::UnsupportedContentType(unsupported)),
        };

        Ok(Self {
            signed_data: Cow::Borrowed(signed_data.as_bytes()),
            header: header.inner,
            algorithm: header.algorithm.into_owned(),
            content_type,
            serialized_claims,
            signature: decoded_signature,
        })
    }
}
impl<'a> UntrustedToken<'a> {
    /// Parses a token from anything that can be viewed as a string slice.
    ///
    /// Delegates to the `TryFrom<&str>` implementation.
    ///
    /// # Errors
    ///
    /// Returns the `ParseError` produced by the `TryFrom<&str>` implementation.
    pub fn new<S: AsRef<str> + ?Sized>(s: &'a S) -> Result<Self, ParseError> {
        let raw: &'a str = s.as_ref();
        Self::try_from(raw)
    }

    /// Converts this token into one that owns its signed data and therefore has the
    /// `'static` lifetime. All other fields are moved over unchanged.
    pub fn into_owned(self) -> UntrustedToken<'static> {
        let Self {
            signed_data,
            header,
            algorithm,
            content_type,
            serialized_claims,
            signature,
        } = self;

        UntrustedToken {
            signed_data: Cow::Owned(signed_data.into_owned()),
            header,
            algorithm,
            content_type,
            serialized_claims,
            signature,
        }
    }

    /// Borrows the token header (the fields other than `alg` and `cty`).
    pub fn header(&self) -> &Header {
        let Self { header, .. } = self;
        header
    }

    /// Returns the value of the `alg` header field.
    pub fn algorithm(&self) -> &str {
        self.algorithm.as_str()
    }

    /// Returns the base64url-decoded signature bytes.
    pub fn signature_bytes(&self) -> &[u8] {
        &self.signature[..]
    }

    /// Deserializes the claims using the encoding selected by the `cty` header field.
    ///
    /// This method only reads the stored claim bytes; it performs no signature check,
    /// hence the `_unchecked` suffix.
    ///
    /// # Errors
    ///
    /// Returns `ValidationError::MalformedClaims` if JSON deserialization fails, or
    /// `ValidationError::MalformedCborClaims` if CBOR deserialization fails.
    pub fn deserialize_claims_unchecked<T>(&self) -> Result<Claims<T>, ValidationError>
    where
        T: DeserializeOwned,
    {
        let raw_claims: &[u8] = &self.serialized_claims;
        match self.content_type {
            ContentType::Json => {
                serde_json::from_slice(raw_claims).map_err(ValidationError::MalformedClaims)
            }
            #[cfg(feature = "serde_cbor")]
            ContentType::Cbor => {
                serde_cbor::from_slice(raw_claims).map_err(ValidationError::MalformedCborClaims)
            }
        }
    }
}
mod base64url {
use base64ct::{Base64UrlUnpadded, Encoding};
use serde::{
de::{Error as DeError, Visitor},
Deserializer, Serializer,
};
use core::{fmt, marker::PhantomData};
#[allow(clippy::option_if_let_else)] pub fn serialize<T, S>(value: &Option<T>, serializer: S) -> Result<S::Ok, S::Error>
where
T: AsRef<[u8]>,
S: Serializer,
{
if let Some(value) = value {
let bytes = value.as_ref();
serializer.serialize_str(&Base64UrlUnpadded::encode_string(bytes))
} else {
serializer.serialize_none()
}
}
pub fn deserialize<'de, T, D>(deserializer: D) -> Result<Option<T>, D::Error>
where
T: Default + AsMut<[u8]>,
D: Deserializer<'de>,
{
struct Base64Visitor<V>(PhantomData<V>);
impl<V> Visitor<'_> for Base64Visitor<V>
where
V: Default + AsMut<[u8]>,
{
type Value = V;
fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(formatter, "base64url-encoded digest")
}
fn visit_str<E: DeError>(self, value: &str) -> Result<Self::Value, E> {
let mut bytes = V::default();
let expected_len = bytes.as_mut().len();
let decoded_len = value.len() * 3 / 4;
if decoded_len != expected_len {
return Err(E::invalid_length(decoded_len, &self));
}
let len = Base64UrlUnpadded::decode(value, bytes.as_mut())
.map_err(E::custom)?
.len();
if len != expected_len {
return Err(E::invalid_length(len, &self));
}
Ok(bytes)
}
}
deserializer
.deserialize_str(Base64Visitor(PhantomData))
.map(Some)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        alg::{Hs256, Hs256Key},
        alloc::ToOwned,
        AlgorithmExt, Empty,
    };
    use assert_matches::assert_matches;
    use base64ct::{Base64UrlUnpadded, Encoding};

    type Obj = serde_json::Map<String, serde_json::Value>;

    /// Reference HS256 token that the tests below mangle in various ways.
    const HS256_TOKEN: &str = "eyJ0eXAiOiJKV1QiLA0KICJhbGciOiJIUzI1NiJ9.\
        eyJpc3MiOiJqb2UiLA0KICJleHAiOjEzMDA4MTkzODAsDQogImh0dHA6Ly9leGFt\
        cGxlLmNvbS9pc19yb290Ijp0cnVlfQ.\
        dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk";
    /// Base64url-encoded key matching `HS256_TOKEN`.
    const HS256_KEY: &str = "AyM1SysPpbyDfgZld3umj1qzKObwVMkoqQ-EstJQLr_T-1qS0gZH75\
        aKtMN3Yj0iPS4hcgUuTwjAzZr1Z9CAow";

    /// Returns `HS256_TOKEN` with its header part replaced by the encoding of `header`.
    fn token_with_header(header: &[u8]) -> String {
        let encoded_header = Base64UrlUnpadded::encode_string(header);
        let header_end = HS256_TOKEN.find('.').unwrap();
        let mut token = HS256_TOKEN.to_owned();
        token.replace_range(..header_end, &encoded_header);
        token
    }

    #[test]
    fn invalid_token_structure() {
        // One, two and four `.`-separated parts respectively.
        let no_separators = HS256_TOKEN.replace('.', "");
        let missing_signature = HS256_TOKEN.rsplit_once('.').unwrap().0.to_owned();
        let mut trailing_separator = HS256_TOKEN.to_owned();
        trailing_separator.push('.');

        for mangled_str in [no_separators, missing_signature, trailing_separator].iter() {
            assert_matches!(
                UntrustedToken::new(mangled_str).unwrap_err(),
                ParseError::InvalidTokenStructure
            );
        }
    }

    #[test]
    fn base64_error_during_parsing() {
        // '+' is not part of the base64url alphabet.
        let mangled_str = HS256_TOKEN.replace('0', "+");
        let err = UntrustedToken::new(&mangled_str).unwrap_err();
        assert_matches!(err, ParseError::InvalidBase64Encoding);
    }

    #[test]
    fn base64_padding_error_during_parsing() {
        // Swap the final character of the signature for '_'.
        let mut mangled_str = HS256_TOKEN[..HS256_TOKEN.len() - 1].to_owned();
        mangled_str.push('_');
        let err = UntrustedToken::new(&mangled_str).unwrap_err();
        assert_matches!(err, ParseError::InvalidBase64Encoding);
    }

    #[test]
    fn header_fields_are_not_serialized_if_not_present() {
        let json = serde_json::to_string(&Header::default()).unwrap();
        assert_eq!(json, "{}");
    }

    #[test]
    fn header_with_x5t_field() {
        let raw_header = r#"{"alg":"HS256","x5t":"lDpwLQbzRZmu4fjajvn3KWAx1pk"}"#;
        let header: CompleteHeader<'_> = serde_json::from_str(raw_header).unwrap();
        let thumbprint = header.inner.certificate_sha1_thumbprint.unwrap();
        assert_eq!((thumbprint[0], thumbprint[19]), (0x94, 0x99));

        // Serializing must reproduce exactly the fields of the input.
        let expected: serde_json::Value = serde_json::from_str(raw_header).unwrap();
        assert_eq!(serde_json::to_value(header).unwrap(), expected);
    }

    #[test]
    fn header_with_x5t_sha256_field() {
        let raw_header =
            r#"{"alg":"HS256","x5t#S256":"MV9b23bQeMQ7isAGTkoBZGErH853yGk0W_yUx1iU7dM"}"#;
        let header: CompleteHeader<'_> = serde_json::from_str(raw_header).unwrap();
        let thumbprint = header.inner.certificate_thumbprint.unwrap();
        assert_eq!((thumbprint[0], thumbprint[31]), (0x31, 0xd3));

        // Serializing must reproduce exactly the fields of the input.
        let expected: serde_json::Value = serde_json::from_str(raw_header).unwrap();
        assert_eq!(serde_json::to_value(header).unwrap(), expected);
    }

    #[test]
    fn malformed_header() {
        let mangled_headers = [
            // Missing closing brace
            r#"{"alg":"HS256""#,
            // Missing `alg` field
            "{}",
            // `alg` field is not a string
            r#"{"alg":5}"#,
            r#"{"alg":[1,"foo"]}"#,
            r#"{"alg":false}"#,
            // Duplicate `alg` field
            r#"{"alg":"HS256","alg":"none"}"#,
            // Invalid thumbprint fields
            r#"{"alg":"HS256","x5t":"lDpwLQbzRZmu4fjajvn3KWAx1p"}"#,
            r#"{"alg":"HS256","x5t":["lDpwLQbzRZmu4fjajvn3KWAx1pk"]}"#,
            r#"{"alg":"HS256","x5t":"lDpwLQbzRZmu4fjajvn3KWAx1 k"}"#,
            r#"{"alg":"HS256","x5t":"lDpwLQbzRZmu4fjajvn3KWAx1pk==="}"#,
            r#"{"alg":"HS256","x5t":"lDpwLQbzRZmu4fjajvn3KWAx1pkk"}"#,
            r#"{"alg":"HS256","x5t":"MV9b23bQeMQ7isAGTkoBZGErH853yGk0W_yUx1iU7dM"}"#,
            r#"{"alg":"HS256","x5t#S256":"lDpwLQbzRZmu4fjajvn3KWAx1pk"}"#,
        ];

        for mangled_header in mangled_headers.iter() {
            let mangled_str = token_with_header(mangled_header.as_bytes());
            let err = UntrustedToken::new(&mangled_str).unwrap_err();
            assert_matches!(err, ParseError::MalformedHeader(_));
        }
    }

    #[test]
    fn unsupported_content_type() {
        let mangled_str = token_with_header(br#"{"alg":"HS256","cty":"txt"}"#);
        let err = UntrustedToken::new(&mangled_str).unwrap_err();
        assert_matches!(
            err,
            ParseError::UnsupportedContentType(ref s) if s == "txt"
        );
    }

    #[test]
    fn malformed_json_claims() {
        let malformed_claims = [
            // Missing closing brace
            r#"{"exp":1500000000"#,
            // `exp` claim is not a number
            r#"{"exp":"1500000000"}"#,
            r#"{"exp":false}"#,
            // Duplicate `exp` claim
            r#"{"exp":1500000000,"nbf":1400000000,"exp":1510000000}"#,
            // `exp` value does not fit the timestamp type
            r#"{"exp":1500000000000000000000000000000000}"#,
        ];

        let claims_range = (HS256_TOKEN.find('.').unwrap() + 1)..HS256_TOKEN.rfind('.').unwrap();
        let key_bytes = Base64UrlUnpadded::decode_vec(HS256_KEY).unwrap();
        let key = Hs256Key::new(&key_bytes);

        for claims in malformed_claims.iter() {
            let encoded_claims = Base64UrlUnpadded::encode_string(claims.as_bytes());
            let mut mangled_str = HS256_TOKEN.to_owned();
            mangled_str.replace_range(claims_range.clone(), &encoded_claims);

            let token = UntrustedToken::new(&mangled_str).unwrap();
            assert_matches!(
                Hs256.validate_integrity::<Obj>(&token, &key).unwrap_err(),
                ValidationError::MalformedClaims(_),
                "Failing claims: {}",
                claims
            );
        }
    }

    /// Asserts that validating `mangled_str` fails because its signature is `actual_len`
    /// bytes long instead of the 32 bytes expected for HS256.
    fn test_invalid_signature_len(mangled_str: &str, actual_len: usize) {
        let token = UntrustedToken::new(mangled_str).unwrap();
        let key_bytes = Base64UrlUnpadded::decode_vec(HS256_KEY).unwrap();
        let key = Hs256Key::new(&key_bytes);

        let err = Hs256.validate_integrity::<Empty>(&token, &key).unwrap_err();
        assert_matches!(
            err,
            ValidationError::InvalidSignatureLen { actual, expected: 32 }
                if actual == actual_len
        );
    }

    #[test]
    fn short_signature_error() {
        let shortened = &HS256_TOKEN[..HS256_TOKEN.len() - 3];
        test_invalid_signature_len(shortened, 30);
    }

    #[test]
    fn long_signature_error() {
        let mut lengthened = HS256_TOKEN.to_owned();
        lengthened.push('a');
        test_invalid_signature_len(&lengthened, 33);
    }
}