jsonwebtoken_hs256/
lib.rs

1use base64::engine::general_purpose::URL_SAFE_NO_PAD;
2use base64::Engine;
3use ring::constant_time::verify_slices_are_equal;
4use ring::hmac::{self, HMAC_SHA256};
5use serde::de::DeserializeOwned;
6use serde::de::{self, Visitor};
7use serde::Deserializer;
8use serde::{Deserialize, Serialize};
9use std::borrow::Cow;
10use std::fmt;
11use std::marker::PhantomData;
12use std::time::{SystemTime, UNIX_EPOCH};
13use thiserror::Error;
14
15#[derive(Debug, Error)]
16pub enum EncodeError {
17    #[error("{0}")]
18    Json(#[from] serde_json::Error),
19}
20
21#[derive(Debug, Error)]
22pub enum DecodeError {
23    /// When a token doesn't have a valid JWT shape
24    #[error("Invalid token")]
25    InvalidToken,
26    /// When the signature doesn't match
27    #[error("Invalid signature")]
28    InvalidSignature,
29    // Validation errors
30    /// When a claim required by the validation is not present
31    #[error("Missing required claim: {0}")]
32    MissingRequiredClaim(String),
33    /// When a token’s `exp` claim indicates that it has expired
34    #[error("Expired signature")]
35    ExpiredSignature,
36    /// When a token’s `iss` claim does not match the expected issuer
37    #[error("Invalid issuer")]
38    InvalidIssuer,
39    /// When a token’s `aud` claim does not match one of the expected audience values
40    #[error("Invalid audience")]
41    InvalidAudience,
42    /// When a token’s `sub` claim does not match one of the expected subject values
43    #[error("Invalid subject")]
44    InvalidSubject,
45    /// When a token’s `nbf` claim represents a time in the future
46    #[error("Immature signature")]
47    ImmatureSignature,
48    /// When the algorithm in the header doesn't match the one passed to `decode` or the encoding/decoding key
49    /// used doesn't match the alg requested
50    #[error("Invalid algorithm")]
51    InvalidAlgorithm,
52    // 3rd party errors
53    /// An error happened when decoding some base64 text
54    #[error("{0}")]
55    Base64(#[from] base64::DecodeError),
56    /// An error happened while deserializing JSON
57    #[error("{0}")]
58    Json(#[from] serde_json::Error),
59}
60
/// The only JWS `alg` value this crate issues in its header and accepts when decoding.
const ALGORITHM: &str = "HS256";
62
63fn b64_encode<T: AsRef<[u8]>>(input: T) -> String {
64    URL_SAFE_NO_PAD.encode(input)
65}
66
67fn b64_decode<T: AsRef<[u8]>>(input: T) -> Result<Vec<u8>, DecodeError> {
68    URL_SAFE_NO_PAD.decode(input).map_err(|e| e.into())
69}
70
71/// Serializes a struct to JSON and encodes it in base64
72fn b64_encode_part<T: Serialize>(input: &T) -> serde_json::Result<String> {
73    let json = serde_json::to_vec(input)?;
74    Ok(b64_encode(json))
75}
76
77pub struct Jwt {
78    header: String,
79    key: hmac::Key,
80}
81
82impl Jwt {
83    pub fn new(secret: &[u8]) -> Self {
84        let header: String = b64_encode(b"{\"alg\":\"HS256\",\"typ\":\"JWT\"}");
85        let key = hmac::Key::new(HMAC_SHA256, secret);
86        Self { key, header }
87    }
88
89    fn sign(&self, message: &[u8]) -> String {
90        let digest = hmac::sign(&self.key, message);
91        b64_encode(digest)
92    }
93
94    fn verify(&self, signature: &str, message: &[u8]) -> bool {
95        let signed = self.sign(message);
96        verify_slices_are_equal(signature.as_bytes(), signed.as_bytes()).is_ok()
97    }
98
99    pub fn encode<T: Serialize>(&self, claims: &T) -> Result<String, EncodeError> {
100        let encoded_claims = b64_encode_part(claims)?;
101        let header = &self.header;
102        let mut message = String::with_capacity(header.len() + 1 + encoded_claims.len());
103        message.push_str(header);
104        message.push('.');
105        message.push_str(&encoded_claims);
106
107        let signature = self.sign(message.as_bytes());
108        message.reserve_exact(signature.len() + 1);
109        message.push('.');
110        message.push_str(&signature);
111
112        Ok(message)
113    }
114
115    pub fn decode<T: DeserializeOwned>(&self, token: &str) -> Result<T, DecodeError> {
116        match self.verify_signature(token) {
117            Err(e) => Err(e),
118            Ok((_, claims)) => {
119                let decoded_claims = DecodedJwtPartClaims::from_jwt_part_claims(claims)?;
120                let claims = decoded_claims.deserialize()?;
121                self.validate(decoded_claims.deserialize()?)?;
122                Ok(claims)
123            }
124        }
125    }
126
127    fn validate(&self, claims: ClaimsForValidation) -> Result<(), DecodeError> {
128        let now = SystemTime::now()
129            .duration_since(UNIX_EPOCH)
130            .expect("Time went backwards")
131            .as_secs();
132        if !matches!(claims.exp, TryParse::Parsed(_)) {
133            return Err(DecodeError::MissingRequiredClaim("exp".to_string()));
134        }
135
136        if matches!(claims.exp, TryParse::Parsed(exp) if exp < now - 60) {
137            return Err(DecodeError::ExpiredSignature);
138        }
139
140        Ok(())
141    }
142
143    fn verify_signature<'a>(&self, token: &'a str) -> Result<(Header, &'a str), DecodeError> {
144        let (message, signature) = match token.rsplit_once('.') {
145            Some(value) => value,
146            None => return Err(DecodeError::InvalidToken),
147        };
148
149        let (header, payload) = match message.rsplit_once('.') {
150            Some(value) => value,
151            None => return Err(DecodeError::InvalidToken),
152        };
153
154        let header = Header::from_encoded(header)?;
155
156        if header.alg != ALGORITHM {
157            return Err(DecodeError::InvalidAlgorithm);
158        }
159
160        if !self.verify(signature, message.as_bytes()) {
161            return Err(DecodeError::InvalidSignature);
162        }
163
164        Ok((header, payload))
165    }
166}
167
168#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
169struct Header {
170    pub typ: String,
171    pub alg: String,
172}
173
174impl Header {
175    /// Converts an encoded part into the Header struct if possible
176    fn from_encoded<T: AsRef<[u8]>>(encoded_part: T) -> Result<Self, DecodeError> {
177        let decoded = b64_decode(encoded_part)?;
178        Ok(serde_json::from_slice(&decoded)?)
179    }
180}
181
182impl Default for Header {
183    fn default() -> Self {
184        Self {
185            typ: "JWT".to_string(),
186            alg: ALGORITHM.to_string(),
187        }
188    }
189}
190
/// Keeps the base64-decoded claims segment so the same bytes can be deserialized twice:
/// - into the caller's own claims type, and
/// - into [`ClaimsForValidation`], which this crate validates.
struct DecodedJwtPartClaims {
    /// Decoded bytes of the claims segment (expected to be JSON).
    b64_decoded: Vec<u8>,
}
197
198impl DecodedJwtPartClaims {
199    fn from_jwt_part_claims(
200        encoded_jwt_part_claims: impl AsRef<[u8]>,
201    ) -> Result<Self, DecodeError> {
202        Ok(Self {
203            b64_decoded: b64_decode(encoded_jwt_part_claims)?,
204        })
205    }
206
207    fn deserialize<'a, T: Deserialize<'a>>(&'a self) -> Result<T, DecodeError> {
208        Ok(serde_json::from_slice(&self.b64_decoded)?)
209    }
210}
211
212#[derive(Deserialize)]
213struct ClaimsForValidation {
214    #[serde(deserialize_with = "numeric_type", default)]
215    exp: TryParse<u64>,
216}
217
/// Outcome of reading a claim whose malformed value is recorded rather than returned as
/// a deserialization error.
#[derive(Debug)]
enum TryParse<T> {
    /// The claim was present and parsed successfully.
    Parsed(T),
    /// The claim was present but its value could not be parsed.
    FailedToParse,
    /// The claim was absent (or `null`, when read through the generic `Deserialize` impl).
    NotPresent,
}
224
225impl<'de, T: Deserialize<'de>> Deserialize<'de> for TryParse<T> {
226    fn deserialize<D: serde::Deserializer<'de>>(
227        deserializer: D,
228    ) -> std::result::Result<Self, D::Error> {
229        Ok(match Option::<T>::deserialize(deserializer) {
230            Ok(Some(value)) => TryParse::Parsed(value),
231            Ok(None) => TryParse::NotPresent,
232            Err(_) => TryParse::FailedToParse,
233        })
234    }
235}
236
237impl<T> Default for TryParse<T> {
238    fn default() -> Self {
239        Self::NotPresent
240    }
241}
242
243/// Usually #[serde(borrow)] on `Cow` enables deserializing with no allocations where
244/// possible (no escapes in the original str) but it does not work on e.g. `HashSet<Cow<str>>`
245/// We use this struct in this case.
246#[derive(Deserialize, PartialEq, Eq, Hash)]
247struct BorrowedCowIfPossible<'a>(#[serde(borrow)] Cow<'a, str>);
248impl std::borrow::Borrow<str> for BorrowedCowIfPossible<'_> {
249    fn borrow(&self) -> &str {
250        &self.0
251    }
252}
253
254fn numeric_type<'de, D>(deserializer: D) -> std::result::Result<TryParse<u64>, D::Error>
255where
256    D: Deserializer<'de>,
257{
258    struct NumericType(PhantomData<fn() -> TryParse<u64>>);
259
260    impl<'de> Visitor<'de> for NumericType {
261        type Value = TryParse<u64>;
262
263        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
264            formatter.write_str("A NumericType that can be reasonably coerced into a u64")
265        }
266
267        fn visit_f64<E>(self, value: f64) -> std::result::Result<Self::Value, E>
268        where
269            E: de::Error,
270        {
271            if value.is_finite() && value >= 0.0 && value < (u64::MAX as f64) {
272                Ok(TryParse::Parsed(value.round() as u64))
273            } else {
274                Err(serde::de::Error::custom(
275                    "NumericType must be representable as a u64",
276                ))
277            }
278        }
279
280        fn visit_u64<E>(self, value: u64) -> std::result::Result<Self::Value, E>
281        where
282            E: de::Error,
283        {
284            Ok(TryParse::Parsed(value))
285        }
286    }
287
288    match deserializer.deserialize_any(NumericType(PhantomData)) {
289        Ok(ok) => Ok(ok),
290        Err(_) => Ok(TryParse::FailedToParse),
291    }
292}