jwt_compact/
traits.rs

1//! Key traits defined by the crate.
2
3use base64ct::{Base64UrlUnpadded, Encoding};
4use serde::{de::DeserializeOwned, Serialize};
5
6use core::{marker::PhantomData, num::NonZeroUsize};
7
8#[cfg(feature = "ciborium")]
9use crate::error::CborSerError;
10use crate::{
11    alloc::{Cow, String, ToOwned, Vec},
12    token::CompleteHeader,
13    Claims, CreationError, Header, SignedToken, Token, UntrustedToken, ValidationError,
14};
15
16/// Signature for a certain JWT signing [`Algorithm`].
17///
18/// We require that signature can be restored from a byte slice,
19/// and can be represented as a byte slice.
20pub trait AlgorithmSignature: Sized {
21    /// Constant byte length of signatures supported by the [`Algorithm`], or `None` if
22    /// the signature length is variable.
23    ///
24    /// - If this value is `Some(_)`, the signature will be first checked for its length
25    ///   during token verification. An [`InvalidSignatureLen`] error will be raised if the length
26    ///   is invalid. [`Self::try_from_slice()`] will thus always receive a slice with
27    ///   the expected length.
28    /// - If this value is `None`, no length check is performed before calling
29    ///   [`Self::try_from_slice()`].
30    ///
31    /// [`InvalidSignatureLen`]: crate::ValidationError::InvalidSignatureLen
32    const LENGTH: Option<NonZeroUsize> = None;
33
34    /// Attempts to restore a signature from a byte slice. This method may fail
35    /// if the slice is malformed.
36    fn try_from_slice(slice: &[u8]) -> anyhow::Result<Self>;
37
38    /// Represents this signature as bytes.
39    fn as_bytes(&self) -> Cow<'_, [u8]>;
40}
41
42/// JWT signing algorithm.
43pub trait Algorithm {
44    /// Key used when issuing new tokens.
45    type SigningKey;
46    /// Key used when verifying tokens. May coincide with [`Self::SigningKey`] for symmetric
47    /// algorithms (e.g., `HS*`).
48    type VerifyingKey;
49    /// Signature produced by the algorithm.
50    type Signature: AlgorithmSignature;
51
52    /// Returns the name of this algorithm, as mentioned in the `alg` field of the JWT header.
53    fn name(&self) -> Cow<'static, str>;
54
55    /// Signs a `message` with the `signing_key`.
56    fn sign(&self, signing_key: &Self::SigningKey, message: &[u8]) -> Self::Signature;
57
58    /// Verifies the `message` against the `signature` and `verifying_key`.
59    fn verify_signature(
60        &self,
61        signature: &Self::Signature,
62        verifying_key: &Self::VerifyingKey,
63        message: &[u8],
64    ) -> bool;
65}
66
/// Wrapper around an [`Algorithm`] that issues and validates tokens under a custom name.
///
/// Signing and verification are delegated unchanged to the wrapped algorithm; only the
/// name written to (and expected in) the `alg` header field differs.
///
/// # Examples
///
/// ```
/// use jwt_compact::{alg::{Hs256, Hs256Key}, prelude::*, Empty, Renamed};
///
/// # fn main() -> anyhow::Result<()> {
/// let alg = Renamed::new(Hs256, "HS2");
/// let key = Hs256Key::new(b"super_secret_key_donut_steel");
/// let token_string = alg.token(&Header::empty(), &Claims::empty(), &key)?;
///
/// let token = UntrustedToken::new(&token_string)?;
/// assert_eq!(token.algorithm(), "HS2");
/// // The token cannot be verified with the original algorithm, because the
/// // algorithm name recorded in its header doesn't match.
/// assert!(Hs256.validator::<Empty>(&key).validate(&token).is_err());
///
/// // ...but the renamed algorithm accepts it as expected.
/// assert!(alg.validator::<Empty>(&key).validate(&token).is_ok());
/// # Ok(())
/// # }
/// ```
#[derive(Debug, Clone, Copy)]
pub struct Renamed<A> {
    /// Algorithm that performs the actual signing and verification.
    inner: A,
    /// Name reported in place of the wrapped algorithm's own name.
    name: &'static str,
}
95
96impl<A: Algorithm> Renamed<A> {
97    /// Creates a renamed algorithm.
98    pub fn new(algorithm: A, new_name: &'static str) -> Self {
99        Self {
100            inner: algorithm,
101            name: new_name,
102        }
103    }
104}
105
106impl<A: Algorithm> Algorithm for Renamed<A> {
107    type SigningKey = A::SigningKey;
108    type VerifyingKey = A::VerifyingKey;
109    type Signature = A::Signature;
110
111    fn name(&self) -> Cow<'static, str> {
112        Cow::Borrowed(self.name)
113    }
114
115    fn sign(&self, signing_key: &Self::SigningKey, message: &[u8]) -> Self::Signature {
116        self.inner.sign(signing_key, message)
117    }
118
119    fn verify_signature(
120        &self,
121        signature: &Self::Signature,
122        verifying_key: &Self::VerifyingKey,
123        message: &[u8],
124    ) -> bool {
125        self.inner
126            .verify_signature(signature, verifying_key, message)
127    }
128}
129
130/// Automatically implemented extensions of the `Algorithm` trait.
131pub trait AlgorithmExt: Algorithm {
132    /// Creates a new token and serializes it to string.
133    fn token<T>(
134        &self,
135        header: &Header<impl Serialize>,
136        claims: &Claims<T>,
137        signing_key: &Self::SigningKey,
138    ) -> Result<String, CreationError>
139    where
140        T: Serialize;
141
142    /// Creates a new token with CBOR-encoded claims and serializes it to string.
143    #[cfg(feature = "ciborium")]
144    #[cfg_attr(docsrs, doc(cfg(feature = "ciborium")))]
145    fn compact_token<T>(
146        &self,
147        header: &Header<impl Serialize>,
148        claims: &Claims<T>,
149        signing_key: &Self::SigningKey,
150    ) -> Result<String, CreationError>
151    where
152        T: Serialize;
153
154    /// Creates a JWT validator for the specified verifying key and the claims type.
155    /// The validator can then be used to validate integrity of one or more tokens.
156    fn validator<'a, T>(&'a self, verifying_key: &'a Self::VerifyingKey) -> Validator<'a, Self, T>;
157
158    /// Validates the token integrity against the provided `verifying_key`.
159    #[deprecated = "Use `.validator().validate()` for added flexibility"]
160    fn validate_integrity<T>(
161        &self,
162        token: &UntrustedToken<'_>,
163        verifying_key: &Self::VerifyingKey,
164    ) -> Result<Token<T>, ValidationError>
165    where
166        T: DeserializeOwned;
167
168    /// Validates the token integrity against the provided `verifying_key`.
169    ///
170    /// Unlike [`validate_integrity`](#tymethod.validate_integrity), this method retains more
171    /// information about the original token, in particular, its signature.
172    #[deprecated = "Use `.validator().validate_for_signed_token()` for added flexibility"]
173    fn validate_for_signed_token<T>(
174        &self,
175        token: &UntrustedToken<'_>,
176        verifying_key: &Self::VerifyingKey,
177    ) -> Result<SignedToken<Self, T>, ValidationError>
178    where
179        T: DeserializeOwned;
180}
181
182impl<A: Algorithm> AlgorithmExt for A {
183    fn token<T>(
184        &self,
185        header: &Header<impl Serialize>,
186        claims: &Claims<T>,
187        signing_key: &Self::SigningKey,
188    ) -> Result<String, CreationError>
189    where
190        T: Serialize,
191    {
192        let complete_header = CompleteHeader {
193            algorithm: self.name(),
194            content_type: None,
195            inner: header,
196        };
197        let header = serde_json::to_string(&complete_header).map_err(CreationError::Header)?;
198        let mut buffer = Vec::new();
199        encode_base64_buf(&header, &mut buffer);
200
201        let claims = serde_json::to_string(claims).map_err(CreationError::Claims)?;
202        buffer.push(b'.');
203        encode_base64_buf(&claims, &mut buffer);
204
205        let signature = self.sign(signing_key, &buffer);
206        buffer.push(b'.');
207        encode_base64_buf(signature.as_bytes(), &mut buffer);
208
209        // SAFETY: safe by construction: base64 alphabet and `.` char are valid UTF-8.
210        Ok(unsafe { String::from_utf8_unchecked(buffer) })
211    }
212
213    #[cfg(feature = "ciborium")]
214    fn compact_token<T>(
215        &self,
216        header: &Header<impl Serialize>,
217        claims: &Claims<T>,
218        signing_key: &Self::SigningKey,
219    ) -> Result<String, CreationError>
220    where
221        T: Serialize,
222    {
223        let complete_header = CompleteHeader {
224            algorithm: self.name(),
225            content_type: Some("CBOR".to_owned()),
226            inner: header,
227        };
228        let header = serde_json::to_string(&complete_header).map_err(CreationError::Header)?;
229        let mut buffer = Vec::new();
230        encode_base64_buf(&header, &mut buffer);
231
232        let mut serialized_claims = vec![];
233        ciborium::into_writer(claims, &mut serialized_claims).map_err(|err| {
234            CreationError::CborClaims(match err {
235                CborSerError::Value(message) => CborSerError::Value(message),
236                CborSerError::Io(_) => unreachable!(), // writing to a `Vec` always succeeds
237            })
238        })?;
239        buffer.push(b'.');
240        encode_base64_buf(&serialized_claims, &mut buffer);
241
242        let signature = self.sign(signing_key, &buffer);
243        buffer.push(b'.');
244        encode_base64_buf(signature.as_bytes(), &mut buffer);
245
246        // SAFETY: safe by construction: base64 alphabet and `.` char are valid UTF-8.
247        Ok(unsafe { String::from_utf8_unchecked(buffer) })
248    }
249
250    fn validator<'a, T>(&'a self, verifying_key: &'a Self::VerifyingKey) -> Validator<'a, Self, T> {
251        Validator {
252            algorithm: self,
253            verifying_key,
254            _claims: PhantomData,
255        }
256    }
257
258    fn validate_integrity<T>(
259        &self,
260        token: &UntrustedToken<'_>,
261        verifying_key: &Self::VerifyingKey,
262    ) -> Result<Token<T>, ValidationError>
263    where
264        T: DeserializeOwned,
265    {
266        self.validator::<T>(verifying_key).validate(token)
267    }
268
269    fn validate_for_signed_token<T>(
270        &self,
271        token: &UntrustedToken<'_>,
272        verifying_key: &Self::VerifyingKey,
273    ) -> Result<SignedToken<Self, T>, ValidationError>
274    where
275        T: DeserializeOwned,
276    {
277        self.validator::<T>(verifying_key)
278            .validate_for_signed_token(token)
279    }
280}
281
282/// Validator for a certain signing [`Algorithm`] associated with a specific verifying key
283/// and a claims type. Produced by the [`AlgorithmExt::validator()`] method.
284#[derive(Debug)]
285pub struct Validator<'a, A: Algorithm + ?Sized, T> {
286    algorithm: &'a A,
287    verifying_key: &'a A::VerifyingKey,
288    _claims: PhantomData<fn() -> T>,
289}
290
291impl<A: Algorithm + ?Sized, T> Clone for Validator<'_, A, T> {
292    fn clone(&self) -> Self {
293        *self
294    }
295}
296
297impl<A: Algorithm + ?Sized, T> Copy for Validator<'_, A, T> {}
298
299impl<A: Algorithm + ?Sized, T: DeserializeOwned> Validator<'_, A, T> {
300    /// Validates the token integrity against a verifying key enclosed in this validator.
301    pub fn validate<H: Clone>(
302        self,
303        token: &UntrustedToken<'_, H>,
304    ) -> Result<Token<T, H>, ValidationError> {
305        self.validate_for_signed_token(token)
306            .map(|signed| signed.token)
307    }
308
309    /// Validates the token integrity against a verifying key enclosed in this validator,
310    /// and returns the validated [`Token`] together with its signature.
311    pub fn validate_for_signed_token<H: Clone>(
312        self,
313        token: &UntrustedToken<'_, H>,
314    ) -> Result<SignedToken<A, T, H>, ValidationError> {
315        let expected_alg = self.algorithm.name();
316        if expected_alg != token.algorithm() {
317            return Err(ValidationError::AlgorithmMismatch {
318                expected: expected_alg.into_owned(),
319                actual: token.algorithm().to_owned(),
320            });
321        }
322
323        let signature = token.signature_bytes();
324        if let Some(expected_len) = A::Signature::LENGTH {
325            if signature.len() != expected_len.get() {
326                return Err(ValidationError::InvalidSignatureLen {
327                    expected: expected_len.get(),
328                    actual: signature.len(),
329                });
330            }
331        }
332
333        let signature =
334            A::Signature::try_from_slice(signature).map_err(ValidationError::MalformedSignature)?;
335        // We assume that parsing claims is less computationally demanding than
336        // validating a signature.
337        let claims = token.deserialize_claims_unchecked::<T>()?;
338        if !self
339            .algorithm
340            .verify_signature(&signature, self.verifying_key, &token.signed_data)
341        {
342            return Err(ValidationError::InvalidSignature);
343        }
344
345        Ok(SignedToken {
346            signature,
347            token: Token::new(token.header().clone(), claims),
348        })
349    }
350}
351
352fn encode_base64_buf(source: impl AsRef<[u8]>, buffer: &mut Vec<u8>) {
353    let source = source.as_ref();
354    let previous_len = buffer.len();
355    let claims_len = Base64UrlUnpadded::encoded_len(source);
356    buffer.resize(previous_len + claims_len, 0);
357    Base64UrlUnpadded::encode(source, &mut buffer[previous_len..])
358        .expect("miscalculated base64-encoded length; this should never happen");
359}