jwt_compact_frame/
traits.rs

1//! Key traits defined by the crate.
2
3use base64ct::{Base64UrlUnpadded, Encoding};
4use parity_scale_codec::Encode;
5use scale_info::TypeInfo;
6use serde::{de::DeserializeOwned, Serialize};
7
8use core::{marker::PhantomData, num::NonZeroUsize};
9
10#[cfg(feature = "ciborium")]
11use crate::error::CborSerError;
12use crate::{
13	alloc::{Cow, String, ToOwned, Vec},
14	token::CompleteHeader,
15	Claims, CreationError, Header, SignedToken, Token, UntrustedToken, ValidationError,
16};
17
18/// Signature for a certain JWT signing [`Algorithm`].
19///
20/// We require that signature can be restored from a byte slice,
21/// and can be represented as a byte slice.
22pub trait AlgorithmSignature: Sized {
23	/// Constant byte length of signatures supported by the [`Algorithm`], or `None` if
24	/// the signature length is variable.
25	///
26	/// - If this value is `Some(_)`, the signature will be first checked for its length
27	///   during token verification. An [`InvalidSignatureLen`] error will be raised if the length
28	///   is invalid. [`Self::try_from_slice()`] will thus always receive a slice with
29	///   the expected length.
30	/// - If this value is `None`, no length check is performed before calling
31	///   [`Self::try_from_slice()`].
32	///
33	/// [`InvalidSignatureLen`]: crate::ValidationError::InvalidSignatureLen
34	const LENGTH: Option<NonZeroUsize> = None;
35
36	/// Attempts to restore a signature from a byte slice. This method may fail
37	/// if the slice is malformed.
38	fn try_from_slice(slice: &[u8]) -> anyhow::Result<Self>;
39
40	/// Represents this signature as bytes.
41	fn as_bytes(&self) -> Cow<'_, [u8]>;
42}
43
44/// JWT signing algorithm.
45pub trait Algorithm {
46	/// Key used when issuing new tokens.
47	type SigningKey;
48	/// Key used when verifying tokens. May coincide with [`Self::SigningKey`] for symmetric
49	/// algorithms (e.g., `HS*`).
50	type VerifyingKey;
51	/// Signature produced by the algorithm.
52	type Signature: AlgorithmSignature;
53
54	/// Returns the name of this algorithm, as mentioned in the `alg` field of the JWT header.
55	fn name(&self) -> Cow<'static, str>;
56
57	/// Signs a `message` with the `signing_key`.
58	fn sign(&self, signing_key: &Self::SigningKey, message: &[u8]) -> Self::Signature;
59
60	/// Verifies the `message` against the `signature` and `verifying_key`.
61	fn verify_signature(&self, signature: &Self::Signature, verifying_key: &Self::VerifyingKey, message: &[u8])
62		-> bool;
63}
64
/// Algorithm that uses a custom name when creating and validating tokens.
///
/// # Examples
///
/// ```
/// use jwt_compact_frame::{alg::{Hs256, Hs256Key}, prelude::*, Empty, Renamed};
///
/// # fn main() -> anyhow::Result<()> {
/// let alg = Renamed::new(Hs256, "HS2");
/// let key = Hs256Key::new(b"super_secret_key_donut_steel");
/// let token_string = alg.token(&Header::empty(), &Claims::empty(), &key)?;
///
/// let token = UntrustedToken::new(&token_string)?;
/// assert_eq!(token.algorithm(), "HS2");
/// // Note that the created token cannot be verified against the original algorithm
/// // since the algorithm name recorded in the token header doesn't match.
/// assert!(Hs256.validator::<Empty>(&key).validate(&token).is_err());
///
/// // ...but the modified alg is working as expected.
/// assert!(alg.validator::<Empty>(&key).validate(&token).is_ok());
/// # Ok(())
/// # }
/// ```
#[derive(Debug, Clone, Copy, TypeInfo, Encode)]
pub struct Renamed<A> {
	// Wrapped algorithm; the `Algorithm` impl for `Renamed` forwards signing and
	// verification to it unchanged.
	inner: A,
	// Name returned from `Algorithm::name()` instead of the wrapped algorithm's own name.
	// NOTE(review): field order presumably determines the derived SCALE `Encode` output and
	// `TypeInfo` layout — keep it stable.
	name: &'static str,
}
93
94impl<A: Algorithm> Renamed<A> {
95	/// Creates a renamed algorithm.
96	pub const fn new(algorithm: A, new_name: &'static str) -> Self {
97		Self { inner: algorithm, name: new_name }
98	}
99}
100
101impl<A: Algorithm> Algorithm for Renamed<A> {
102	type Signature = A::Signature;
103	type SigningKey = A::SigningKey;
104	type VerifyingKey = A::VerifyingKey;
105
106	fn name(&self) -> Cow<'static, str> {
107		Cow::Borrowed(self.name)
108	}
109
110	fn sign(&self, signing_key: &Self::SigningKey, message: &[u8]) -> Self::Signature {
111		self.inner.sign(signing_key, message)
112	}
113
114	fn verify_signature(
115		&self,
116		signature: &Self::Signature,
117		verifying_key: &Self::VerifyingKey,
118		message: &[u8],
119	) -> bool {
120		self.inner.verify_signature(signature, verifying_key, message)
121	}
122}
123
124/// Automatically implemented extensions of the `Algorithm` trait.
125pub trait AlgorithmExt: Algorithm {
126	/// Creates a new token and serializes it to string.
127	fn token<T>(
128		&self,
129		header: &Header<impl Serialize>,
130		claims: &Claims<T>,
131		signing_key: &Self::SigningKey,
132	) -> Result<String, CreationError>
133	where
134		T: Serialize;
135
136	/// Creates a new token with CBOR-encoded claims and serializes it to string.
137	#[cfg(feature = "ciborium")]
138	#[cfg_attr(docsrs, doc(cfg(feature = "ciborium")))]
139	fn compact_token<T>(
140		&self,
141		header: &Header<impl Serialize>,
142		claims: &Claims<T>,
143		signing_key: &Self::SigningKey,
144	) -> Result<String, CreationError>
145	where
146		T: Serialize;
147
148	/// Creates a JWT validator for the specified verifying key and the claims type.
149	/// The validator can then be used to validate integrity of one or more tokens.
150	fn validator<'a, T>(&'a self, verifying_key: &'a Self::VerifyingKey) -> Validator<'a, Self, T>;
151
152	/// Validates the token integrity against the provided `verifying_key`.
153	#[deprecated = "Use `.validator().validate()` for added flexibility"]
154	fn validate_integrity<T>(
155		&self,
156		token: &UntrustedToken,
157		verifying_key: &Self::VerifyingKey,
158	) -> Result<Token<T>, ValidationError>
159	where
160		T: DeserializeOwned;
161
162	/// Validates the token integrity against the provided `verifying_key`.
163	///
164	/// Unlike [`validate_integrity`](#tymethod.validate_integrity), this method retains more
165	/// information about the original token, in particular, its signature.
166	#[deprecated = "Use `.validator().validate_for_signed_token()` for added flexibility"]
167	fn validate_for_signed_token<T>(
168		&self,
169		token: &UntrustedToken,
170		verifying_key: &Self::VerifyingKey,
171	) -> Result<SignedToken<Self, T>, ValidationError>
172	where
173		T: DeserializeOwned;
174}
175
176impl<A: Algorithm> AlgorithmExt for A {
177	fn token<T>(
178		&self,
179		header: &Header<impl Serialize>,
180		claims: &Claims<T>,
181		signing_key: &Self::SigningKey,
182	) -> Result<String, CreationError>
183	where
184		T: Serialize,
185	{
186		let complete_header = CompleteHeader { algorithm: self.name(), content_type: None, inner: header };
187		let header = serde_json::to_string(&complete_header).map_err(CreationError::Header)?;
188		let mut buffer = Vec::new();
189		encode_base64_buf(&header, &mut buffer);
190
191		let claims = serde_json::to_string(claims).map_err(CreationError::Claims)?;
192		buffer.push(b'.');
193		encode_base64_buf(&claims, &mut buffer);
194
195		let signature = self.sign(signing_key, &buffer);
196		buffer.push(b'.');
197		encode_base64_buf(signature.as_bytes(), &mut buffer);
198
199		// SAFETY: safe by construction: base64 alphabet and `.` char are valid UTF-8.
200		Ok(unsafe { String::from_utf8_unchecked(buffer) })
201	}
202
203	#[cfg(feature = "ciborium")]
204	fn compact_token<T>(
205		&self,
206		header: &Header<impl Serialize>,
207		claims: &Claims<T>,
208		signing_key: &Self::SigningKey,
209	) -> Result<String, CreationError>
210	where
211		T: Serialize,
212	{
213		let complete_header =
214			CompleteHeader { algorithm: self.name(), content_type: Some("CBOR".to_owned()), inner: header };
215		let header = serde_json::to_string(&complete_header).map_err(CreationError::Header)?;
216		let mut buffer = Vec::new();
217		encode_base64_buf(&header, &mut buffer);
218
219		let mut serialized_claims = vec![];
220		ciborium::into_writer(claims, &mut serialized_claims).map_err(|err| {
221			CreationError::CborClaims(match err {
222				CborSerError::Value(message) => CborSerError::Value(message),
223				CborSerError::Io(_) => unreachable!(), // writing to a `Vec` always succeeds
224			})
225		})?;
226		buffer.push(b'.');
227		encode_base64_buf(&serialized_claims, &mut buffer);
228
229		let signature = self.sign(signing_key, &buffer);
230		buffer.push(b'.');
231		encode_base64_buf(signature.as_bytes(), &mut buffer);
232
233		// SAFETY: safe by construction: base64 alphabet and `.` char are valid UTF-8.
234		Ok(unsafe { String::from_utf8_unchecked(buffer) })
235	}
236
237	fn validator<'a, T>(&'a self, verifying_key: &'a Self::VerifyingKey) -> Validator<'a, Self, T> {
238		Validator { algorithm: self, verifying_key, _claims: PhantomData }
239	}
240
241	fn validate_integrity<T>(
242		&self,
243		token: &UntrustedToken,
244		verifying_key: &Self::VerifyingKey,
245	) -> Result<Token<T>, ValidationError>
246	where
247		T: DeserializeOwned,
248	{
249		self.validator::<T>(verifying_key).validate(token)
250	}
251
252	fn validate_for_signed_token<T>(
253		&self,
254		token: &UntrustedToken,
255		verifying_key: &Self::VerifyingKey,
256	) -> Result<SignedToken<Self, T>, ValidationError>
257	where
258		T: DeserializeOwned,
259	{
260		self.validator::<T>(verifying_key).validate_for_signed_token(token)
261	}
262}
263
264/// Validator for a certain signing [`Algorithm`] associated with a specific verifying key
265/// and a claims type. Produced by the [`AlgorithmExt::validator()`] method.
266#[derive(Debug)]
267pub struct Validator<'a, A: Algorithm + ?Sized, T> {
268	algorithm: &'a A,
269	verifying_key: &'a A::VerifyingKey,
270	_claims: PhantomData<fn() -> T>,
271}
272
273impl<A: Algorithm + ?Sized, T> Clone for Validator<'_, A, T> {
274	fn clone(&self) -> Self {
275		*self
276	}
277}
278
279impl<A: Algorithm + ?Sized, T> Copy for Validator<'_, A, T> {}
280
281impl<A: Algorithm + ?Sized, T: DeserializeOwned> Validator<'_, A, T> {
282	/// Validates the token integrity against a verifying key enclosed in this validator.
283	pub fn validate<H: Clone>(self, token: &UntrustedToken<H>) -> Result<Token<T, H>, ValidationError> {
284		self.validate_for_signed_token(token).map(|signed| signed.token)
285	}
286
287	/// Validates the token integrity against a verifying key enclosed in this validator,
288	/// and returns the validated [`Token`] together with its signature.
289	pub fn validate_for_signed_token<H: Clone>(
290		self,
291		token: &UntrustedToken<H>,
292	) -> Result<SignedToken<A, T, H>, ValidationError> {
293		let expected_alg = self.algorithm.name();
294		if expected_alg != token.algorithm() {
295			return Err(ValidationError::AlgorithmMismatch {
296				expected: expected_alg.into_owned(),
297				actual: token.algorithm().to_owned(),
298			});
299		}
300
301		let signature = token.signature_bytes();
302		if let Some(expected_len) = A::Signature::LENGTH {
303			if signature.len() != expected_len.get() {
304				return Err(ValidationError::InvalidSignatureLen {
305					expected: expected_len.get(),
306					actual: signature.len(),
307				});
308			}
309		}
310
311		let signature = A::Signature::try_from_slice(signature).map_err(ValidationError::MalformedSignature)?;
312		// We assume that parsing claims is less computationally demanding than
313		// validating a signature.
314		let claims = token.deserialize_claims_unchecked::<T>()?;
315		if !self.algorithm.verify_signature(&signature, self.verifying_key, &token.signed_data) {
316			return Err(ValidationError::InvalidSignature);
317		}
318
319		Ok(SignedToken { signature, token: Token::new(token.header().clone(), claims) })
320	}
321}
322
323fn encode_base64_buf(source: impl AsRef<[u8]>, buffer: &mut Vec<u8>) {
324	let source = source.as_ref();
325	let previous_len = buffer.len();
326	let claims_len = Base64UrlUnpadded::encoded_len(source);
327	buffer.resize(previous_len + claims_len, 0);
328	Base64UrlUnpadded::encode(source, &mut buffer[previous_len..])
329		.expect("miscalculated base64-encoded length; this should never happen");
330}