1use base64::engine::general_purpose::URL_SAFE_NO_PAD;
2use base64::Engine;
3use ring::constant_time::verify_slices_are_equal;
4use ring::hmac::{self, HMAC_SHA256};
5use serde::de::DeserializeOwned;
6use serde::de::{self, Visitor};
7use serde::Deserializer;
8use serde::{Deserialize, Serialize};
9use std::borrow::Cow;
10use std::fmt;
11use std::marker::PhantomData;
12use std::time::{SystemTime, UNIX_EPOCH};
13use thiserror::Error;
14
15#[derive(Debug, Error)]
16pub enum EncodeError {
17 #[error("{0}")]
18 Json(#[from] serde_json::Error),
19}
20
21#[derive(Debug, Error)]
22pub enum DecodeError {
23 #[error("Invalid token")]
25 InvalidToken,
26 #[error("Invalid signature")]
28 InvalidSignature,
29 #[error("Missing required claim: {0}")]
32 MissingRequiredClaim(String),
33 #[error("Expired signature")]
35 ExpiredSignature,
36 #[error("Invalid issuer")]
38 InvalidIssuer,
39 #[error("Invalid audience")]
41 InvalidAudience,
42 #[error("Invalid subject")]
44 InvalidSubject,
45 #[error("Immature signature")]
47 ImmatureSignature,
48 #[error("Invalid algorithm")]
51 InvalidAlgorithm,
52 #[error("{0}")]
55 Base64(#[from] base64::DecodeError),
56 #[error("{0}")]
58 Json(#[from] serde_json::Error),
59}
60
/// The only JWS `alg` header value that decoding accepts.
const ALGORITHM: &str = "HS256";
62
63fn b64_encode<T: AsRef<[u8]>>(input: T) -> String {
64 URL_SAFE_NO_PAD.encode(input)
65}
66
67fn b64_decode<T: AsRef<[u8]>>(input: T) -> Result<Vec<u8>, DecodeError> {
68 URL_SAFE_NO_PAD.decode(input).map_err(|e| e.into())
69}
70
71fn b64_encode_part<T: Serialize>(input: &T) -> serde_json::Result<String> {
73 let json = serde_json::to_vec(input)?;
74 Ok(b64_encode(json))
75}
76
77pub struct Jwt {
78 header: String,
79 key: hmac::Key,
80}
81
82impl Jwt {
83 pub fn new(secret: &[u8]) -> Self {
84 let header: String = b64_encode(b"{\"alg\":\"HS256\",\"typ\":\"JWT\"}");
85 let key = hmac::Key::new(HMAC_SHA256, secret);
86 Self { key, header }
87 }
88
89 fn sign(&self, message: &[u8]) -> String {
90 let digest = hmac::sign(&self.key, message);
91 b64_encode(digest)
92 }
93
94 fn verify(&self, signature: &str, message: &[u8]) -> bool {
95 let signed = self.sign(message);
96 verify_slices_are_equal(signature.as_bytes(), signed.as_bytes()).is_ok()
97 }
98
99 pub fn encode<T: Serialize>(&self, claims: &T) -> Result<String, EncodeError> {
100 let encoded_claims = b64_encode_part(claims)?;
101 let header = &self.header;
102 let mut message = String::with_capacity(header.len() + 1 + encoded_claims.len());
103 message.push_str(header);
104 message.push('.');
105 message.push_str(&encoded_claims);
106
107 let signature = self.sign(message.as_bytes());
108 message.reserve_exact(signature.len() + 1);
109 message.push('.');
110 message.push_str(&signature);
111
112 Ok(message)
113 }
114
115 pub fn decode<T: DeserializeOwned>(&self, token: &str) -> Result<T, DecodeError> {
116 match self.verify_signature(token) {
117 Err(e) => Err(e),
118 Ok((_, claims)) => {
119 let decoded_claims = DecodedJwtPartClaims::from_jwt_part_claims(claims)?;
120 let claims = decoded_claims.deserialize()?;
121 self.validate(decoded_claims.deserialize()?)?;
122 Ok(claims)
123 }
124 }
125 }
126
127 fn validate(&self, claims: ClaimsForValidation) -> Result<(), DecodeError> {
128 let now = SystemTime::now()
129 .duration_since(UNIX_EPOCH)
130 .expect("Time went backwards")
131 .as_secs();
132 if !matches!(claims.exp, TryParse::Parsed(_)) {
133 return Err(DecodeError::MissingRequiredClaim("exp".to_string()));
134 }
135
136 if matches!(claims.exp, TryParse::Parsed(exp) if exp < now - 60) {
137 return Err(DecodeError::ExpiredSignature);
138 }
139
140 Ok(())
141 }
142
143 fn verify_signature<'a>(&self, token: &'a str) -> Result<(Header, &'a str), DecodeError> {
144 let (message, signature) = match token.rsplit_once('.') {
145 Some(value) => value,
146 None => return Err(DecodeError::InvalidToken),
147 };
148
149 let (header, payload) = match message.rsplit_once('.') {
150 Some(value) => value,
151 None => return Err(DecodeError::InvalidToken),
152 };
153
154 let header = Header::from_encoded(header)?;
155
156 if header.alg != ALGORITHM {
157 return Err(DecodeError::InvalidAlgorithm);
158 }
159
160 if !self.verify(signature, message.as_bytes()) {
161 return Err(DecodeError::InvalidSignature);
162 }
163
164 Ok((header, payload))
165 }
166}
167
168#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
169struct Header {
170 pub typ: String,
171 pub alg: String,
172}
173
174impl Header {
175 fn from_encoded<T: AsRef<[u8]>>(encoded_part: T) -> Result<Self, DecodeError> {
177 let decoded = b64_decode(encoded_part)?;
178 Ok(serde_json::from_slice(&decoded)?)
179 }
180}
181
182impl Default for Header {
183 fn default() -> Self {
184 Self {
185 typ: "JWT".to_string(),
186 alg: ALGORITHM.to_string(),
187 }
188 }
189}
190
/// The claims (payload) segment of a token after base64url decoding.
struct DecodedJwtPartClaims {
    /// Raw JSON bytes of the payload, kept so they can be deserialized more than once.
    b64_decoded: Vec<u8>,
}
197
198impl DecodedJwtPartClaims {
199 fn from_jwt_part_claims(
200 encoded_jwt_part_claims: impl AsRef<[u8]>,
201 ) -> Result<Self, DecodeError> {
202 Ok(Self {
203 b64_decoded: b64_decode(encoded_jwt_part_claims)?,
204 })
205 }
206
207 fn deserialize<'a, T: Deserialize<'a>>(&'a self) -> Result<T, DecodeError> {
208 Ok(serde_json::from_slice(&self.b64_decoded)?)
209 }
210}
211
212#[derive(Deserialize)]
213struct ClaimsForValidation {
214 #[serde(deserialize_with = "numeric_type", default)]
215 exp: TryParse<u64>,
216}
217
/// Outcome of leniently reading an optional claim: parsed, malformed, or absent.
#[derive(Debug)]
enum TryParse<T> {
    /// The claim was present and successfully parsed into `T`.
    Parsed(T),
    /// The claim was present but could not be parsed into `T`.
    FailedToParse,
    /// The claim was not present.
    NotPresent,
}
224
225impl<'de, T: Deserialize<'de>> Deserialize<'de> for TryParse<T> {
226 fn deserialize<D: serde::Deserializer<'de>>(
227 deserializer: D,
228 ) -> std::result::Result<Self, D::Error> {
229 Ok(match Option::<T>::deserialize(deserializer) {
230 Ok(Some(value)) => TryParse::Parsed(value),
231 Ok(None) => TryParse::NotPresent,
232 Err(_) => TryParse::FailedToParse,
233 })
234 }
235}
236
237impl<T> Default for TryParse<T> {
238 fn default() -> Self {
239 Self::NotPresent
240 }
241}
242
243#[derive(Deserialize, PartialEq, Eq, Hash)]
247struct BorrowedCowIfPossible<'a>(#[serde(borrow)] Cow<'a, str>);
248impl std::borrow::Borrow<str> for BorrowedCowIfPossible<'_> {
249 fn borrow(&self) -> &str {
250 &self.0
251 }
252}
253
254fn numeric_type<'de, D>(deserializer: D) -> std::result::Result<TryParse<u64>, D::Error>
255where
256 D: Deserializer<'de>,
257{
258 struct NumericType(PhantomData<fn() -> TryParse<u64>>);
259
260 impl<'de> Visitor<'de> for NumericType {
261 type Value = TryParse<u64>;
262
263 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
264 formatter.write_str("A NumericType that can be reasonably coerced into a u64")
265 }
266
267 fn visit_f64<E>(self, value: f64) -> std::result::Result<Self::Value, E>
268 where
269 E: de::Error,
270 {
271 if value.is_finite() && value >= 0.0 && value < (u64::MAX as f64) {
272 Ok(TryParse::Parsed(value.round() as u64))
273 } else {
274 Err(serde::de::Error::custom(
275 "NumericType must be representable as a u64",
276 ))
277 }
278 }
279
280 fn visit_u64<E>(self, value: u64) -> std::result::Result<Self::Value, E>
281 where
282 E: de::Error,
283 {
284 Ok(TryParse::Parsed(value))
285 }
286 }
287
288 match deserializer.deserialize_any(NumericType(PhantomData)) {
289 Ok(ok) => Ok(ok),
290 Err(_) => Ok(TryParse::FailedToParse),
291 }
292}