jwt_compact_frame/
token.rs

1//! `Token` and closely related types.
2
3use base64ct::{Base64UrlUnpadded, Encoding};
4use bounded_collections::{BoundedVec, ConstU32};
5use parity_scale_codec::{Decode, Encode, MaxEncodedLen};
6use scale_info::TypeInfo;
7use serde::{
8	de::{DeserializeOwned, Error as DeError, Visitor},
9	Deserialize, Deserializer, Serialize, Serializer,
10};
11
12use core::{cmp, fmt};
13
14#[cfg(feature = "ciborium")]
15use crate::error::CborDeError;
16use crate::{
17	alloc::{format, Cow, String, Vec},
18	Algorithm, Claims, Empty, ParseError, ValidationError,
19};
20
/// Upper bound, in bytes, on a decoded token signature stored in an [`UntrustedToken`]
/// (`1 << 8` = 256 bytes).
///
/// Longer signatures are rejected while parsing with `ParseError::SignatureTooLarge`.
///
/// NOTE(review): an RSA signature is as long as the key modulus, so RSA keys larger than
/// 2048 bits produce signatures above this bound — confirm which algorithms must be supported.
const SIGNATURE_SIZE: u32 = 1 << 8;
23
/// Representation of a X.509 certificate thumbprint (`x5t` and `x5t#S256` fields in
/// the JWT [`Header`]).
///
/// As per the JWS spec in [RFC 7515], a certificate thumbprint (i.e., the SHA-1 / SHA-256
/// digest of the certificate) must be base64url-encoded. Some JWS implementations however
/// encode not the thumbprint itself, but rather its hex encoding, sometimes even
/// with additional chars spliced within. To account for these implementations,
/// a thumbprint is represented as an enum – either a properly encoded hash digest,
/// or an opaque string (which, like the digest, is base64url-encoded when serialized).
///
/// The const parameter `N` is the digest length in bytes (20 for the SHA-1 `x5t` field,
/// 32 for the SHA-256 `x5t#S256` field).
///
/// # Serialization
///
/// Both variants serialize as an unpadded base64url string of their bytes. When deserializing,
/// up to two trailing `=` padding chars are tolerated, and the variant is chosen by the
/// decoded length: exactly `N` bytes yields [`Self::Bytes`], more than `N` bytes yields
/// [`Self::String`] (the bytes must be valid UTF-8), and fewer than `N` bytes is an error.
/// Consequently, a [`Self::String`] whose UTF-8 encoding is exactly `N` bytes long is read back
/// as [`Self::Bytes`], and one shorter than `N` bytes cannot be read back at all.
///
/// [RFC 7515]: https://www.rfc-editor.org/rfc/rfc7515.html
///
/// # Examples
///
/// ```
/// # use assert_matches::assert_matches;
/// # use jwt_compact_frame::{
/// #     alg::{Hs256, Hs256Key}, AlgorithmExt, Claims, Header, Thumbprint, UntrustedToken,
/// # };
/// # fn main() -> anyhow::Result<()> {
/// let key = Hs256Key::new(b"super_secret_key_donut_steel");
///
/// // Creates a token with a custom-encoded SHA-1 thumbprint.
/// let thumbprint = "65:AF:69:09:B1:B0:75:8E:06:C6:E0:48:C4:60:02:B5:C6:95:E3:6B";
/// let header = Header::empty()
///     .with_key_id("my_key")
///     .with_certificate_sha1_thumbprint(thumbprint);
/// let token = Hs256.token(&header, &Claims::empty(), &key)?;
/// println!("{token}");
///
/// // Deserialize the token and check that its header fields are readable.
/// let token = UntrustedToken::new(&token)?;
/// let deserialized_thumbprint =
///     token.header().certificate_sha1_thumbprint.as_ref();
/// assert_matches!(
///     deserialized_thumbprint,
///     Some(Thumbprint::String(s)) if s == thumbprint
/// );
/// # Ok(())
/// # }
/// ```
// NOTE(review): the derived SCALE `Encode` / `Decode` impls index variants by declaration
// order, so reordering the variants below changes the binary encoding.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Decode, Encode, TypeInfo)]
#[non_exhaustive]
pub enum Thumbprint<const N: usize> {
	/// Byte representation of a SHA-1 or SHA-256 digest.
	Bytes([u8; N]),
	/// Opaque string representation of the thumbprint. It is the responsibility
	/// of an application to verify that this value is valid.
	String(String),
}
74
75impl<const N: usize> From<[u8; N]> for Thumbprint<N> {
76	fn from(value: [u8; N]) -> Self {
77		Self::Bytes(value)
78	}
79}
80
81impl<const N: usize> From<String> for Thumbprint<N> {
82	fn from(s: String) -> Self {
83		Self::String(s)
84	}
85}
86
87impl<const N: usize> From<&str> for Thumbprint<N> {
88	fn from(s: &str) -> Self {
89		Self::String(s.into())
90	}
91}
92
93impl<const N: usize> Serialize for Thumbprint<N> {
94	fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
95		let input = match self {
96			Self::Bytes(bytes) => bytes.as_slice(),
97			Self::String(s) => s.as_bytes(),
98		};
99		serializer.serialize_str(&Base64UrlUnpadded::encode_string(input))
100	}
101}
102
103impl<'de, const N: usize> Deserialize<'de> for Thumbprint<N> {
104	fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
105		struct Base64Visitor<const L: usize>;
106
107		impl<const L: usize> Visitor<'_> for Base64Visitor<L> {
108			type Value = Thumbprint<L>;
109
110			fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
111				write!(formatter, "base64url-encoded thumbprint")
112			}
113
114			fn visit_str<E: DeError>(self, mut value: &str) -> Result<Self::Value, E> {
115				// Allow for padding. RFC 7515 defines base64url encoding as one without padding:
116				//
117				// > Base64url Encoding: Base64 encoding using the URL- and filename-safe
118				// > character set defined in Section 5 of RFC 4648 [RFC4648], with all trailing '='
119				// > characters omitted [...]
120				//
121				// ...but it's easy to trim the padding, so we support it anyway.
122				//
123				// See: https://www.rfc-editor.org/rfc/rfc7515.html#section-2
124				for _ in 0..2 {
125					if value.as_bytes().last() == Some(&b'=') {
126						value = &value[..value.len() - 1];
127					}
128				}
129
130				let decoded_len = value.len() * 3 / 4;
131				match decoded_len.cmp(&L) {
132					cmp::Ordering::Less => Err(E::custom(format!("thumbprint must contain at least {L} bytes"))),
133					cmp::Ordering::Equal => {
134						let mut bytes = [0_u8; L];
135						let len = Base64UrlUnpadded::decode(value, &mut bytes).map_err(E::custom)?.len();
136						debug_assert_eq!(len, L);
137						Ok(bytes.into())
138					},
139					cmp::Ordering::Greater => {
140						let decoded = Base64UrlUnpadded::decode_vec(value).map_err(E::custom)?;
141						let decoded = String::from_utf8(decoded).map_err(|err| E::custom(err.utf8_error()))?;
142						Ok(decoded.into())
143					},
144				}
145			}
146		}
147
148		deserializer.deserialize_str(Base64Visitor)
149	}
150}
151
/// JWT header.
///
/// See [RFC 7515](https://tools.ietf.org/html/rfc7515#section-4.1) for the description
/// of the fields. The purpose of all fields except `token_type` is to determine
/// the verifying key. Since these values will be provided by the adversary in the case of
/// an attack, they require additional verification (e.g., a provided certificate might
/// be checked against the list of "acceptable" certificate authorities).
///
/// A `Header` can be created using the `Default` implementation (or [`Header::empty()`] /
/// [`Header::new()`]), none of which set any standard fields.
/// For added fluency, you may use `with_*` methods:
///
/// ```
/// # use jwt_compact_frame::Header;
/// use sha2::{digest::Digest, Sha256};
///
/// let my_key_cert = // DER-encoded key certificate
/// #   b"Hello, world!";
/// let thumbprint: [u8; 32] = Sha256::digest(my_key_cert).into();
/// let header = Header::empty()
///     .with_key_id("my-key-id")
///     .with_certificate_thumbprint(thumbprint);
/// ```
// NOTE(review): the derived SCALE `Encode` / `Decode` impls encode fields in declaration
// order, so reordering the fields below changes the binary encoding. Unset optional fields
// are omitted from the JSON form (`skip_serializing_if = "Option::is_none"`).
#[derive(Debug, Clone, Default, Serialize, Deserialize, Decode, Encode, TypeInfo, Eq, PartialEq)]
#[non_exhaustive]
pub struct Header<T = Empty> {
	/// URL of the JSON Web Key Set containing the key that has signed the token.
	/// This field is renamed to [`jku`] for serialization.
	///
	/// [`jku`]: https://www.rfc-editor.org/rfc/rfc7515.html#section-4.1.2
	#[serde(rename = "jku", default, skip_serializing_if = "Option::is_none")]
	pub key_set_url: Option<String>,

	/// Identifier of the key that has signed the token. This field is renamed to [`kid`]
	/// for serialization.
	///
	/// [`kid`]: https://www.rfc-editor.org/rfc/rfc7515.html#section-4.1.4
	#[serde(rename = "kid", default, skip_serializing_if = "Option::is_none")]
	pub key_id: Option<String>,

	/// URL of the X.509 certificate for the signing key. This field is renamed to [`x5u`]
	/// for serialization.
	///
	/// [`x5u`]: https://www.rfc-editor.org/rfc/rfc7515.html#section-4.1.5
	#[serde(rename = "x5u", default, skip_serializing_if = "Option::is_none")]
	pub certificate_url: Option<String>,

	/// SHA-1 thumbprint of the X.509 certificate for the signing key.
	/// This field is renamed to [`x5t`] for serialization.
	///
	/// [`x5t`]: https://www.rfc-editor.org/rfc/rfc7515.html#section-4.1.7
	#[serde(rename = "x5t", default, skip_serializing_if = "Option::is_none")]
	pub certificate_sha1_thumbprint: Option<Thumbprint<20>>,

	/// SHA-256 thumbprint of the X.509 certificate for the signing key.
	/// This field is renamed to [`x5t#S256`] for serialization.
	///
	/// [`x5t#S256`]: https://www.rfc-editor.org/rfc/rfc7515.html#section-4.1.8
	#[serde(rename = "x5t#S256", default, skip_serializing_if = "Option::is_none")]
	pub certificate_thumbprint: Option<Thumbprint<32>>,

	/// Application-specific [token type]. This field is renamed to `typ` for serialization.
	///
	/// [token type]: https://tools.ietf.org/html/rfc7519#section-5.1
	#[serde(rename = "typ", default, skip_serializing_if = "Option::is_none")]
	pub token_type: Option<String>,

	/// Other fields encoded in the header. These fields may be used by agreement between
	/// the producer and consumer of the token to pass additional information.
	/// See Sections 4.2 and 4.3 of [RFC 7515](https://www.rfc-editor.org/rfc/rfc7515#section-4.2)
	/// for details.
	///
	/// For the token creation and validation to work properly, the fields type must [`Serialize`]
	/// to a JSON object (it is flattened into the header object).
	///
	/// Note that these fields do not include the signing algorithm (`alg`) and the token
	/// content type (`cty`) since both these fields have predefined semantics and are used
	/// internally by the crate logic.
	#[serde(flatten)]
	pub other_fields: T,
}
232
233impl Header {
234	/// Creates an empty header.
235	pub const fn empty() -> Self {
236		Self {
237			key_set_url: None,
238			key_id: None,
239			certificate_url: None,
240			certificate_sha1_thumbprint: None,
241			certificate_thumbprint: None,
242			token_type: None,
243			other_fields: Empty {},
244		}
245	}
246}
247
248impl<T> Header<T> {
249	/// Creates a header with the specified custom fields.
250	pub const fn new(fields: T) -> Self {
251		Self {
252			key_set_url: None,
253			key_id: None,
254			certificate_url: None,
255			certificate_sha1_thumbprint: None,
256			certificate_thumbprint: None,
257			token_type: None,
258			other_fields: fields,
259		}
260	}
261
262	/// Sets the `key_set_url` field for this header.
263	#[must_use]
264	pub fn with_key_set_url(mut self, key_set_url: impl Into<String>) -> Self {
265		self.key_set_url = Some(key_set_url.into());
266		self
267	}
268
269	/// Sets the `key_id` field for this header.
270	#[must_use]
271	pub fn with_key_id(mut self, key_id: impl Into<String>) -> Self {
272		self.key_id = Some(key_id.into());
273		self
274	}
275
276	/// Sets the `certificate_url` field for this header.
277	#[must_use]
278	pub fn with_certificate_url(mut self, certificate_url: impl Into<String>) -> Self {
279		self.certificate_url = Some(certificate_url.into());
280		self
281	}
282
283	/// Sets the `certificate_sha1_thumbprint` field for this header.
284	#[must_use]
285	pub fn with_certificate_sha1_thumbprint(mut self, certificate_thumbprint: impl Into<Thumbprint<20>>) -> Self {
286		self.certificate_sha1_thumbprint = Some(certificate_thumbprint.into());
287		self
288	}
289
290	/// Sets the `certificate_thumbprint` field for this header.
291	#[must_use]
292	pub fn with_certificate_thumbprint(mut self, certificate_thumbprint: impl Into<Thumbprint<32>>) -> Self {
293		self.certificate_thumbprint = Some(certificate_thumbprint.into());
294		self
295	}
296
297	/// Sets the `token_type` field for this header.
298	#[must_use]
299	pub fn with_token_type(mut self, token_type: impl Into<String>) -> Self {
300		self.token_type = Some(token_type.into());
301		self
302	}
303}
304
/// Header object as it appears in an encoded token: the mandatory `alg` field, the optional
/// `cty` field, and all remaining fields flattened from `inner` (typically a [`Header`]).
///
/// When parsing, `alg` is required, while a missing `cty` deserializes as `None`
/// (`#[serde(default)]`).
#[derive(Debug, Clone, Serialize, Deserialize, Decode, Encode, TypeInfo, Eq, PartialEq)]
pub struct CompleteHeader<'a, T> {
	/// Name of the integrity algorithm (`alg` field).
	#[serde(rename = "alg")]
	pub algorithm: Cow<'a, str>,
	/// Content type of the claims (`cty` field); omitted from the output when `None`.
	#[serde(rename = "cty", default, skip_serializing_if = "Option::is_none")]
	pub content_type: Option<String>,
	/// All other header fields, flattened into the same JSON object.
	#[serde(flatten)]
	pub inner: T,
}
314
/// Encoding of the serialized token claims, derived from the `cty` header field when parsing
/// an [`UntrustedToken`] (a missing `cty` means JSON; matching is ASCII case-insensitive).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Decode, Encode, TypeInfo, MaxEncodedLen, Serialize, Deserialize)]
enum ContentType {
	/// Claims are a JSON object.
	Json,
	/// Claims are CBOR-encoded; only available with the `ciborium` feature.
	#[cfg(feature = "ciborium")]
	Cbor,
}
321
/// Parsed, but unvalidated token.
///
/// The type param ([`Empty`] by default) corresponds to the [additional information] enclosed
/// in the token [`Header`].
///
/// An `UntrustedToken` can be parsed from a string using the [`TryFrom`] implementation.
/// This checks that a token is well-formed (has a header, claims and a signature),
/// but does not validate the signature.
/// As a shortcut, a token without additional header info can be created using [`Self::new()`].
///
/// All data is stored in bounded buffers; parsing fails if the signed data, decoded claims,
/// decoded signature or algorithm name exceed their respective bounds.
///
/// [additional information]: Header#other_fields
///
/// # Examples
///
/// ```
/// # use jwt_compact_frame::UntrustedToken;
/// let token_str = "eyJ0eXAiOiJKV1QiLA0KICJhbGciOiJIUzI1NiJ9.eyJp\
///     c3MiOiJqb2UiLA0KICJleHAiOjEzMDA4MTkzODAsDQogImh0dHA6Ly9leG\
///     FtcGxlLmNvbS9pc19yb290Ijp0cnVlfQ.dBjftJeZ4CVP-mB92K27uhbUJ\
///     U1p1r_wW1gFWFOEjXk";
/// let token: UntrustedToken = token_str.try_into()?;
/// // The same operation using a shortcut:
/// let same_token = UntrustedToken::new(token_str)?;
/// // Token header can be accessed to select the verifying key etc.
/// let key_id: Option<&str> = token.header().key_id.as_deref();
/// # Ok::<_, anyhow::Error>(())
/// ```
///
/// ## Handling tokens with custom header fields
///
/// ```
/// # use serde::Deserialize;
/// # use jwt_compact_frame::UntrustedToken;
/// #[derive(Debug, Clone, Deserialize)]
/// struct HeaderExtensions {
///     custom: String,
/// }
///
/// let token_str = "eyJhbGciOiJIUzI1NiIsImtpZCI6InRlc3Rfa2V5Iiwid\
///     HlwIjoiSldUIiwiY3VzdG9tIjoiY3VzdG9tIn0.eyJzdWIiOiIxMjM0NTY\
///     3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9._27Fb6nF\
///     Tg-HSt3vO4ylaLGcU_ZV2VhMJR4HL7KaQik";
/// let token: UntrustedToken<HeaderExtensions> = token_str.try_into()?;
/// let extensions = &token.header().other_fields;
/// println!("{}", extensions.custom);
/// # Ok::<_, anyhow::Error>(())
/// ```
#[derive(Debug, Clone, Decode, Encode, TypeInfo, Eq, PartialEq, MaxEncodedLen, Serialize, Deserialize)]
pub struct UntrustedToken<H = Empty> {
	// Encoded `header.claims` prefix of the token string (everything before the last '.'),
	// i.e. the data the signature is expected to cover.
	// TODO: Find a reasonable upper bound for the signed data size.
	pub(crate) signed_data: BoundedVec<u8, ConstU32<2048>>,
	// Parsed header without the `alg` / `cty` fields, which are stored separately.
	header: Header<H>,
	// The algorithm is a very short string.
	// NOTE(review): parsing from a string stores UTF-8 here, but the derived `Decode` /
	// `Deserialize` impls accept arbitrary bytes — readers must not assume valid UTF-8.
	algorithm: BoundedVec<u8, ConstU32<8>>,
	// Encoding of `serialized_claims`, taken from the `cty` header field.
	content_type: ContentType,
	// Base64url-decoded claims, not yet deserialized.
	// TODO: Find a reasonable upper bound for the claims size.
	serialized_claims: BoundedVec<u8, ConstU32<2048>>,
	// Base64url-decoded signature bytes; not verified.
	signature: BoundedVec<u8, ConstU32<SIGNATURE_SIZE>>,
}
381
/// Token with validated integrity.
///
/// Claims encoded in the token can be verified by invoking [`Claims`] methods
/// via [`Self::claims()`].
///
/// The type param `T` is the custom claims type and `H` the custom header fields type
/// ([`Empty`] by default).
#[derive(Debug, Clone, Serialize, Deserialize, Decode, Encode, TypeInfo, Eq, PartialEq)]
pub struct Token<T, H = Empty> {
	// Token header (without `alg` / `cty`).
	header: Header<H>,
	// Deserialized token claims.
	claims: Claims<T>,
}
391
392impl<T, H> Token<T, H> {
393	pub(crate) const fn new(header: Header<H>, claims: Claims<T>) -> Self {
394		Self { header, claims }
395	}
396
397	/// Gets token header.
398	pub const fn header(&self) -> &Header<H> {
399		&self.header
400	}
401
402	/// Gets token claims.
403	pub const fn claims(&self) -> &Claims<T> {
404		&self.claims
405	}
406
407	/// Splits the `Token` into the respective `Header` and `Claims` while consuming it.
408	pub fn into_parts(self) -> (Header<H>, Claims<T>) {
409		(self.header, self.claims)
410	}
411}
412
/// `Token` together with the validated token signature.
///
/// `Debug` and `Clone` are implemented by hand rather than derived, so that they only require
/// the corresponding traits on `A::Signature`, `T` and `H` (not on the algorithm type `A`).
///
/// # Examples
///
/// ```
/// # use jwt_compact_frame::{alg::{Hs256, Hs256Key, Hs256Signature}, prelude::*};
/// # use chrono::Duration;
/// # use serde::{Deserialize, Serialize};
/// #
/// #[derive(Serialize, Deserialize)]
/// struct MyClaims {
///     // Custom claims in the token...
/// }
///
/// # fn main() -> anyhow::Result<()> {
/// # let key = Hs256Key::new(b"super_secret_key");
/// # let claims = Claims::new(MyClaims {})
/// #     .set_duration_and_issuance(&TimeOptions::default(), Duration::days(7));
/// let token_string: String = // token from an external source
/// #   Hs256.token(&Header::empty(), &claims, &key)?;
/// let token = UntrustedToken::new(&token_string)?;
/// let signed = Hs256.validator::<MyClaims>(&key)
///     .validate_for_signed_token(&token)?;
///
/// // `signature` is strongly typed.
/// let signature: Hs256Signature = signed.signature;
/// // Token itself is available via `token` field.
/// let claims = signed.token.claims();
/// claims.validate_expiration(&TimeOptions::default())?;
/// // Process the claims...
/// # Ok(())
/// # } // end main()
/// ```
#[non_exhaustive]
#[derive(Encode, Decode, MaxEncodedLen, Serialize, Deserialize)]
pub struct SignedToken<A: Algorithm + ?Sized, T, H = Empty> {
	/// Token signature.
	pub signature: A::Signature,
	/// Verified token.
	pub token: Token<T, H>,
}
454
455impl<A, T, H> fmt::Debug for SignedToken<A, T, H>
456where
457	A: Algorithm,
458	A::Signature: fmt::Debug,
459	T: fmt::Debug,
460	H: fmt::Debug,
461{
462	fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
463		formatter.debug_struct("SignedToken").field("token", &self.token).field("signature", &self.signature).finish()
464	}
465}
466
467impl<A, T, H> Clone for SignedToken<A, T, H>
468where
469	A: Algorithm,
470	A::Signature: Clone,
471	T: Clone,
472	H: Clone,
473{
474	fn clone(&self) -> Self {
475		Self { signature: self.signature.clone(), token: self.token.clone() }
476	}
477}
478
479impl<'a, H: DeserializeOwned> TryFrom<&'a str> for UntrustedToken<H> {
480	type Error = ParseError;
481
482	fn try_from(s: &'a str) -> Result<Self, Self::Error> {
483		let token_parts: Vec<_> = s.splitn(4, '.').collect();
484		match &token_parts[..] {
485			[header, claims, signature] => {
486				let header = Base64UrlUnpadded::decode_vec(header).map_err(|_| ParseError::InvalidBase64Encoding)?;
487				let serialized_claims = Base64UrlUnpadded::decode_vec(claims)
488					.map_err(|_| ParseError::InvalidBase64Encoding)
489					.and_then(|claims| BoundedVec::try_from(claims).map_err(|_| ParseError::ClaimsTooLarge))?;
490
491				// Decode into a temporary Vec<u8>
492				let signature_len =
493					Base64UrlUnpadded::decode_vec(signature).map_err(|_| ParseError::InvalidBase64Encoding)?;
494
495				// Convert the Vec<u8> into a BoundedVec<u8>
496				let bounded_decoded_signature =
497					BoundedVec::try_from(signature_len).map_err(|_| ParseError::SignatureTooLarge)?;
498
499				let header: CompleteHeader<_> = serde_json::from_slice(&header).map_err(ParseError::MalformedHeader)?;
500				let content_type = match header.content_type {
501					None => ContentType::Json,
502					Some(s) if s.eq_ignore_ascii_case("json") => ContentType::Json,
503					#[cfg(feature = "ciborium")]
504					Some(s) if s.eq_ignore_ascii_case("cbor") => ContentType::Cbor,
505					Some(s) => return Err(ParseError::UnsupportedContentType(s)),
506				};
507				let signed_data = s.rsplit_once('.').unwrap().0.as_bytes();
508				let bounded_signed_data =
509					BoundedVec::try_from(signed_data.to_vec()).map_err(|_| ParseError::SignedDataTooLarge)?;
510				let bounded_header_algorithm = BoundedVec::try_from(header.algorithm.as_bytes().to_vec())
511					.map_err(|_| ParseError::SignedDataTooLarge)?;
512				Ok(Self {
513					signed_data: bounded_signed_data,
514					header: header.inner,
515					algorithm: bounded_header_algorithm,
516					content_type,
517					serialized_claims,
518					signature: bounded_decoded_signature,
519				})
520			},
521			_ => Err(ParseError::InvalidTokenStructure),
522		}
523	}
524}
525
526impl<'a> UntrustedToken {
527	/// Creates an untrusted token from a string. This is a shortcut for calling the [`TryFrom`]
528	/// conversion.
529	pub fn new<S: AsRef<str> + ?Sized>(s: &'a S) -> Result<Self, ParseError> {
530		Self::try_from(s.as_ref())
531	}
532}
533
534impl<H> UntrustedToken<H> {
535	/// Converts this token to an owned form.
536	pub fn into_owned(self) -> Self {
537		Self {
538			signed_data: self.signed_data,
539			header: self.header,
540			algorithm: self.algorithm,
541			content_type: self.content_type,
542			serialized_claims: self.serialized_claims,
543			signature: self.signature,
544		}
545	}
546
547	/// Gets the token header.
548	pub const fn header(&self) -> &Header<H> {
549		&self.header
550	}
551
552	/// Gets the integrity algorithm used to secure the token.
553	pub fn algorithm(&self) -> &str {
554		core::str::from_utf8(&self.algorithm).expect("algorithm is always valid UTF-8")
555	}
556
557	/// Returns signature bytes from the token. These bytes are **not** guaranteed to form a valid
558	/// signature.
559	pub fn signature_bytes(&self) -> &[u8] {
560		&self.signature
561	}
562
563	/// Deserializes claims from this token without checking token integrity. The resulting
564	/// claims are thus **not** guaranteed to be valid.
565	pub fn deserialize_claims_unchecked<T>(&self) -> Result<Claims<T>, ValidationError>
566	where
567		T: DeserializeOwned,
568	{
569		match self.content_type {
570			ContentType::Json => {
571				serde_json::from_slice(&self.serialized_claims).map_err(ValidationError::MalformedClaims)
572			},
573
574			#[cfg(feature = "ciborium")]
575			ContentType::Cbor => {
576				ciborium::from_reader(&self.serialized_claims[..]).map_err(|err| {
577					ValidationError::MalformedCborClaims(match err {
578						CborDeError::Io(err) => CborDeError::Io(anyhow::anyhow!(err)),
579						// ^ In order to be able to use `anyhow!` in both std and no-std envs,
580						// we inline the error transform directly here.
581						CborDeError::Syntax(offset) => CborDeError::Syntax(offset),
582						CborDeError::Semantic(offset, description) => CborDeError::Semantic(offset, description),
583						CborDeError::RecursionLimitExceeded => CborDeError::RecursionLimitExceeded,
584					})
585				})
586			},
587		}
588	}
589}
590
// Unit tests for token parsing, header / thumbprint (de)serialization and signature length checks.
#[cfg(test)]
mod tests {
	use assert_matches::assert_matches;
	use base64ct::{Base64UrlUnpadded, Encoding};

	use super::*;
	use crate::{
		alg::{Hs256, Hs256Key},
		alloc::{ToOwned, ToString},
		AlgorithmExt, Empty,
	};

	// Untyped JSON object, used for custom header fields and claims.
	type Obj = serde_json::Map<String, serde_json::Value>;

	// HS256-signed reference token and its base64url-encoded key.
	// NOTE(review): these appear to be the HS256 example from RFC 7515, Appendix A.1 — confirm.
	const HS256_TOKEN: &str = "eyJ0eXAiOiJKV1QiLA0KICJhbGciOiJIUzI1NiJ9.\
                               eyJpc3MiOiJqb2UiLA0KICJleHAiOjEzMDA4MTkzODAsDQogImh0dHA6Ly9leGFt\
                               cGxlLmNvbS9pc19yb290Ijp0cnVlfQ.\
                               dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk";
	const HS256_KEY: &str = "AyM1SysPpbyDfgZld3umj1qzKObwVMkoqQ-EstJQLr_T-1qS0gZH75\
                             aKtMN3Yj0iPS4hcgUuTwjAzZr1Z9CAow";

	// Tokens with one, two or four dot-separated segments must be rejected.
	#[test]
	fn invalid_token_structure() {
		// No dots at all: a single segment.
		let mangled_str = HS256_TOKEN.replace('.', "");
		assert_matches!(UntrustedToken::new(&mangled_str).unwrap_err(), ParseError::InvalidTokenStructure);

		// Signature removed: two segments.
		let mut mangled_str = HS256_TOKEN.to_owned();
		let signature_start = mangled_str.rfind('.').unwrap();
		mangled_str.truncate(signature_start);
		assert_matches!(UntrustedToken::new(&mangled_str).unwrap_err(), ParseError::InvalidTokenStructure);

		// Trailing dot: four segments.
		let mut mangled_str = HS256_TOKEN.to_owned();
		mangled_str.push('.');
		assert_matches!(UntrustedToken::new(&mangled_str).unwrap_err(), ParseError::InvalidTokenStructure);
	}

	// '+' belongs to the standard base64 alphabet, not to base64url.
	#[test]
	fn base64_error_during_parsing() {
		let mangled_str = HS256_TOKEN.replace('0', "+");
		assert_matches!(UntrustedToken::new(&mangled_str).unwrap_err(), ParseError::InvalidBase64Encoding);
	}

	#[test]
	fn base64_padding_error_during_parsing() {
		let mut mangled_str = HS256_TOKEN.to_owned();
		mangled_str.pop();
		mangled_str.push('_'); // leads to non-zero padding for the last encoded byte
		assert_matches!(UntrustedToken::new(&mangled_str).unwrap_err(), ParseError::InvalidBase64Encoding);
	}

	// Unset optional header fields are omitted from JSON output.
	#[test]
	fn header_fields_are_not_serialized_if_not_present() {
		let header = Header::empty();
		let json = serde_json::to_string(&header).unwrap();
		assert_eq!(json, "{}");
	}

	// A 27-char base64url `x5t` decodes to a 20-byte digest and serializes back unchanged.
	#[test]
	fn header_with_x5t_field() {
		let header = r#"{"alg":"HS256","x5t":"lDpwLQbzRZmu4fjajvn3KWAx1pk"}"#;
		let header: CompleteHeader<Header<Empty>> = serde_json::from_str(header).unwrap();
		let thumbprint = header.inner.certificate_sha1_thumbprint.as_ref().unwrap();
		let Thumbprint::Bytes(thumbprint) = thumbprint else {
			unreachable!();
		};

		assert_eq!(thumbprint[0], 0x94);
		assert_eq!(thumbprint[19], 0x99);

		let json = serde_json::to_value(header).unwrap();
		assert_eq!(
			json,
			serde_json::json!({
				"alg": "HS256",
				"x5t": "lDpwLQbzRZmu4fjajvn3KWAx1pk",
			})
		);
	}

	// Trailing '=' padding is tolerated when deserializing a thumbprint.
	#[test]
	fn header_with_padded_x5t_field() {
		let header = r#"{"alg":"HS256","x5t":"lDpwLQbzRZmu4fjajvn3KWAx1pk=="}"#;
		let header: CompleteHeader<Header<Empty>> = serde_json::from_str(header).unwrap();
		let thumbprint = header.inner.certificate_sha1_thumbprint.as_ref().unwrap();
		let Thumbprint::Bytes(thumbprint) = thumbprint else { unreachable!() };

		assert_eq!(thumbprint[0], 0x94);
		assert_eq!(thumbprint[19], 0x99);
	}

	// A base64url-encoded hex string decodes to more than 20 bytes, so it becomes `Thumbprint::String`.
	#[test]
	fn header_with_hex_x5t_field() {
		let header = r#"{"alg":"HS256","x5t":"NjVBRjY5MDlCMUIwNzU4RTA2QzZFMDQ4QzQ2MDAyQjVDNjk1RTM2Qg"}"#;
		let header: CompleteHeader<Header<Empty>> = serde_json::from_str(header).unwrap();
		let thumbprint = header.inner.certificate_sha1_thumbprint.as_ref().unwrap();
		let Thumbprint::String(thumbprint) = thumbprint else { unreachable!() };

		assert_eq!(thumbprint, "65AF6909B1B0758E06C6E048C46002B5C695E36B");

		let json = serde_json::to_value(header).unwrap();
		assert_eq!(
			json,
			serde_json::json!({
				"alg": "HS256",
				"x5t": "NjVBRjY5MDlCMUIwNzU4RTA2QzZFMDQ4QzQ2MDAyQjVDNjk1RTM2Qg",
			})
		);
	}

	#[test]
	fn header_with_padded_hex_x5t_field() {
		let header = r#"{"alg":"HS256","x5t":"NjVBRjY5MDlCMUIwNzU4RTA2QzZFMDQ4QzQ2MDAyQjVDNjk1RTM2Qg=="}"#;
		let header: CompleteHeader<Header<Empty>> = serde_json::from_str(header).unwrap();
		let thumbprint = header.inner.certificate_sha1_thumbprint.as_ref().unwrap();
		let Thumbprint::String(thumbprint) = thumbprint else { unreachable!() };

		assert_eq!(thumbprint, "65AF6909B1B0758E06C6E048C46002B5C695E36B");
	}

	// "aGk=" decodes to 2 bytes, below the 20-byte SHA-1 digest length.
	#[test]
	fn header_with_overly_short_x5t_field() {
		let header = r#"{"alg":"HS256","x5t":"aGk="}"#;
		let err = serde_json::from_str::<CompleteHeader<Header<Empty>>>(header).unwrap_err();
		let err = err.to_string();
		assert!(err.contains("thumbprint must contain at least 20 bytes"), "{err}");
	}

	// Characters outside the base64url alphabet surface as base64 decoding errors.
	#[test]
	fn header_with_non_base64_x5t_field() {
		let headers = [
			r#"{"alg":"HS256","x5t":"lDpwLQbzRZmu4fjajvn3KWAx1p?"}"#,
			r#"{"alg":"HS256","x5t":"NjVBRjY5MDlCMUIwNzU4RTA2QzZFMDQ4QzQ2MDAyQjVDNjk!RTM2Qg"}"#,
		];
		for header in headers {
			let err = serde_json::from_str::<CompleteHeader<Header<Empty>>>(header).unwrap_err();
			let err = err.to_string();
			assert!(err.contains("Base64"), "{err}");
		}
	}

	// A 43-char base64url `x5t#S256` decodes to a 32-byte digest and serializes back unchanged.
	#[test]
	fn header_with_x5t_sha256_field() {
		let header = r#"{"alg":"HS256","x5t#S256":"MV9b23bQeMQ7isAGTkoBZGErH853yGk0W_yUx1iU7dM"}"#;
		let header: CompleteHeader<Header<Empty>> = serde_json::from_str(header).unwrap();
		let thumbprint = header.inner.certificate_thumbprint.as_ref().unwrap();
		let Thumbprint::Bytes(thumbprint) = thumbprint else { unreachable!() };

		assert_eq!(thumbprint[0], 0x31);
		assert_eq!(thumbprint[31], 0xd3);

		let json = serde_json::to_value(header).unwrap();
		assert_eq!(
			json,
			serde_json::json!({
				"alg": "HS256",
				"x5t#S256": "MV9b23bQeMQ7isAGTkoBZGErH853yGk0W_yUx1iU7dM",
			})
		);
	}

	// Each header below replaces the header segment of HS256_TOKEN and must yield `MalformedHeader`.
	#[test]
	fn malformed_header() {
		let mangled_headers = [
			// Missing closing brace
			r#"{"alg":"HS256""#,
			// Missing necessary `alg` field
			"{}",
			// `alg` field is not a string
			r#"{"alg":5}"#,
			r#"{"alg":[1,"foo"]}"#,
			r#"{"alg":false}"#,
			// Duplicate `alg` field
			r#"{"alg":"HS256","alg":"none"}"#,
			// Invalid thumbprint fields
			r#"{"alg":"HS256","x5t":"lDpwLQbzRZmu4fjajvn3KWAx1p"}"#,
			r#"{"alg":"HS256","x5t":["lDpwLQbzRZmu4fjajvn3KWAx1pk"]}"#,
			r#"{"alg":"HS256","x5t":"lDpwLQbzRZmu4fjajvn3KWAx1 k"}"#,
			r#"{"alg":"HS256","x5t":"lDpwLQbzRZmu4fjajvn3KWAx1pk==="}"#,
			r#"{"alg":"HS256","x5t":"lDpwLQbzRZmu4fjajvn3KWAx1pkk"}"#,
			r#"{"alg":"HS256","x5t":"MV9b23bQeMQ7isAGTkoBZGErH853yGk0W_yUx1iU7dM"}"#,
			r#"{"alg":"HS256","x5t#S256":"lDpwLQbzRZmu4fjajvn3KWAx1pk"}"#,
		];

		for mangled_header in &mangled_headers {
			let mangled_header = Base64UrlUnpadded::encode_string(mangled_header.as_bytes());
			let mut mangled_str = HS256_TOKEN.to_owned();
			mangled_str.replace_range(..mangled_str.find('.').unwrap(), &mangled_header);
			assert_matches!(UntrustedToken::new(&mangled_str).unwrap_err(), ParseError::MalformedHeader(_));
		}
	}

	// Only JSON (and CBOR with the `ciborium` feature) content types are accepted.
	#[test]
	fn unsupported_content_type() {
		let mangled_header = br#"{"alg":"HS256","cty":"txt"}"#;
		let mangled_header = Base64UrlUnpadded::encode_string(mangled_header);
		let mut mangled_str = HS256_TOKEN.to_owned();
		mangled_str.replace_range(..mangled_str.find('.').unwrap(), &mangled_header);
		assert_matches!(
			UntrustedToken::new(&mangled_str).unwrap_err(),
			ParseError::UnsupportedContentType(s) if s == "txt"
		);
	}

	// Unknown header fields end up in `other_fields`; standard ones are not duplicated there.
	#[test]
	fn extracting_custom_header_fields() {
		let header = r#"{"alg":"HS256","custom":[1,"field"],"x5t":"lDpwLQbzRZmu4fjajvn3KWAx1pk"}"#;
		let header: CompleteHeader<Header<Obj>> = serde_json::from_str(header).unwrap();
		assert_eq!(header.algorithm, "HS256");
		assert!(header.inner.certificate_sha1_thumbprint.is_some());
		assert_eq!(header.inner.other_fields.len(), 1);
		assert!(header.inner.other_fields["custom"].is_array());
	}

	// Each claims payload replaces the claims segment of HS256_TOKEN; parsing succeeds, but
	// validation must report `MalformedClaims`.
	#[test]
	fn malformed_json_claims() {
		let malformed_claims = [
			// Missing closing brace
			r#"{"exp":1500000000"#,
			// `exp` claim is not a number
			r#"{"exp":"1500000000"}"#,
			r#"{"exp":false}"#,
			// Duplicate `exp` claim
			r#"{"exp":1500000000,"nbf":1400000000,"exp":1510000000}"#,
			// Too large `exp` value
			r#"{"exp":1500000000000000000000000000000000}"#,
		];

		let claims_start = HS256_TOKEN.find('.').unwrap() + 1;
		let claims_end = HS256_TOKEN.rfind('.').unwrap();
		let key = Base64UrlUnpadded::decode_vec(HS256_KEY).unwrap();
		let key = Hs256Key::new(key);

		for claims in &malformed_claims {
			let encoded_claims = Base64UrlUnpadded::encode_string(claims.as_bytes());
			let mut mangled_str = HS256_TOKEN.to_owned();
			mangled_str.replace_range(claims_start..claims_end, &encoded_claims);
			let token = UntrustedToken::new(&mangled_str).unwrap();
			assert_matches!(
				Hs256.validator::<Obj>(&key).validate(&token).unwrap_err(),
				ValidationError::MalformedClaims(_),
				"Failing claims: {claims}"
			);
		}
	}

	// Asserts that HS256 validation of `mangled_str` fails with a signature length of
	// `actual_len` bytes against the expected 32.
	fn test_invalid_signature_len(mangled_str: &str, actual_len: usize) {
		let token = UntrustedToken::new(&mangled_str).unwrap();
		let key = Base64UrlUnpadded::decode_vec(HS256_KEY).unwrap();
		let key = Hs256Key::new(key);

		let err = Hs256.validator::<Empty>(&key).validate(&token).unwrap_err();
		assert_matches!(
			err,
			ValidationError::InvalidSignatureLen { actual, expected: 32 }
				if actual == actual_len
		);
	}

	// Dropping 3 chars leaves a 40-char signature segment, i.e. 30 decoded bytes.
	#[test]
	fn short_signature_error() {
		test_invalid_signature_len(&HS256_TOKEN[..HS256_TOKEN.len() - 3], 30);
	}

	// Appending a char gives a 44-char signature segment, i.e. 33 decoded bytes.
	#[test]
	fn long_signature_error() {
		let mut mangled_string = HS256_TOKEN.to_owned();
		mangled_string.push('a');
		test_invalid_signature_len(&mangled_string, 33);
	}
}