1use std::fmt::{Debug, Formatter};
2
3use base64::{Engine, engine::general_purpose::STANDARD};
4use serde::de::DeserializeOwned;
5
6use crate::Algorithm;
7use crate::algorithms::AlgorithmFamily;
8use crate::crypto::JwtVerifier;
9use crate::errors::{ErrorKind, Result, new_error};
10use crate::header::Header;
11use crate::jwk::{AlgorithmParameters, Jwk};
12#[cfg(feature = "use_pem")]
13use crate::pem::decoder::PemEncodedKey;
14use crate::serialization::{DecodedJwtPartClaims, b64_decode};
15use crate::validation::{Validation, validate};
16#[cfg(feature = "aws_lc_rs")]
18use crate::crypto::aws_lc::{
19 ecdsa::{Es256Verifier, Es384Verifier},
20 eddsa::EdDSAVerifier,
21 hmac::{Hs256Verifier, Hs384Verifier, Hs512Verifier},
22 rsa::{
23 Rsa256Verifier, Rsa384Verifier, Rsa512Verifier, RsaPss256Verifier, RsaPss384Verifier,
24 RsaPss512Verifier,
25 },
26};
27#[cfg(feature = "rust_crypto")]
28use crate::crypto::rust_crypto::{
29 ecdsa::{Es256Verifier, Es384Verifier},
30 eddsa::EdDSAVerifier,
31 hmac::{Hs256Verifier, Hs384Verifier, Hs512Verifier},
32 rsa::{
33 Rsa256Verifier, Rsa384Verifier, Rsa512Verifier, RsaPss256Verifier, RsaPss384Verifier,
34 RsaPss512Verifier,
35 },
36};
37
38#[derive(Debug)]
40pub struct TokenData<T> {
41 pub header: Header,
43 pub claims: T,
45}
46
47impl<T> Clone for TokenData<T>
48where
49 T: Clone,
50{
51 fn clone(&self) -> Self {
52 Self { header: self.header.clone(), claims: self.claims.clone() }
53 }
54}
55
/// Take exactly two items from an iterator, or make the *enclosing function* return
/// `Err(new_error(ErrorKind::InvalidToken))` if it yields fewer (or more) than two.
///
/// Every call site in this module passes `rsplitn(2, ..)`, which yields at most two
/// items, so in practice this rejects input that contains no `.` separator.
macro_rules! expect_two {
    ($iter:expr) => {{
        let mut pieces = $iter;
        let head = pieces.next();
        let tail = pieces.next();
        let extra = pieces.next();
        if let (Some(head), Some(tail), None) = (head, tail, extra) {
            (head, tail)
        } else {
            return Err(new_error(ErrorKind::InvalidToken));
        }
    }};
}
67
/// The storage form of the key material inside a `DecodingKey`.
#[derive(Clone)]
pub(crate) enum DecodingKeyKind {
    /// A single byte blob: an HMAC secret, a DER-encoded key, or raw public-key bytes
    /// (e.g. an uncompressed EC point or an Ed public key built from JWK components).
    SecretOrDer(Vec<u8>),
    /// An RSA public key kept as its separate modulus `n` and exponent `e`.
    RsaModulusExponent { n: Vec<u8>, e: Vec<u8> },
}

impl Debug for DecodingKeyKind {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // Never print key bytes: HMAC secrets must not leak into logs or panic messages.
        const REDACTED: &str = "[redacted]";
        match self {
            Self::RsaModulusExponent { .. } => {
                let mut out = f.debug_struct("RsaModulusExponent");
                out.field("n", &REDACTED);
                out.field("e", &REDACTED);
                out.finish()
            }
            Self::SecretOrDer(_) => {
                let mut out = f.debug_tuple("SecretOrDer");
                out.field(&REDACTED);
                out.finish()
            }
        }
    }
}
86
87#[derive(Clone, Debug)]
90pub struct DecodingKey {
91 pub(crate) family: AlgorithmFamily,
92 pub(crate) kind: DecodingKeyKind,
93}
94
95impl DecodingKey {
96 pub fn family(&self) -> AlgorithmFamily {
98 self.family
99 }
100
101 pub fn from_secret(secret: &[u8]) -> Self {
103 DecodingKey {
104 family: AlgorithmFamily::Hmac,
105 kind: DecodingKeyKind::SecretOrDer(secret.to_vec()),
106 }
107 }
108
109 pub fn from_base64_secret(secret: &str) -> Result<Self> {
111 let out = STANDARD.decode(secret)?;
112 Ok(DecodingKey { family: AlgorithmFamily::Hmac, kind: DecodingKeyKind::SecretOrDer(out) })
113 }
114
115 #[cfg(feature = "use_pem")]
118 pub fn from_rsa_pem(key: &[u8]) -> Result<Self> {
119 let pem_key = PemEncodedKey::new(key)?;
120 let content = pem_key.as_rsa_key()?;
121 Ok(DecodingKey {
122 family: AlgorithmFamily::Rsa,
123 kind: DecodingKeyKind::SecretOrDer(content.to_vec()),
124 })
125 }
126
127 pub fn from_rsa_components(modulus: &str, exponent: &str) -> Result<Self> {
129 let n = b64_decode(modulus)?;
130 let e = b64_decode(exponent)?;
131 Ok(DecodingKey {
132 family: AlgorithmFamily::Rsa,
133 kind: DecodingKeyKind::RsaModulusExponent { n, e },
134 })
135 }
136
137 pub fn from_rsa_raw_components(modulus: &[u8], exponent: &[u8]) -> Self {
139 DecodingKey {
140 family: AlgorithmFamily::Rsa,
141 kind: DecodingKeyKind::RsaModulusExponent { n: modulus.to_vec(), e: exponent.to_vec() },
142 }
143 }
144
145 #[cfg(feature = "use_pem")]
148 pub fn from_ec_pem(key: &[u8]) -> Result<Self> {
149 let pem_key = PemEncodedKey::new(key)?;
150 let content = pem_key.as_ec_public_key()?;
151 Ok(DecodingKey {
152 family: AlgorithmFamily::Ec,
153 kind: DecodingKeyKind::SecretOrDer(content.to_vec()),
154 })
155 }
156
157 pub fn from_ec_components(x: &str, y: &str) -> Result<Self> {
159 let x_cmp = b64_decode(x)?;
160 let y_cmp = b64_decode(y)?;
161
162 let mut public_key = Vec::with_capacity(1 + x.len() + y.len());
163 public_key.push(0x04);
164 public_key.extend_from_slice(&x_cmp);
165 public_key.extend_from_slice(&y_cmp);
166
167 Ok(DecodingKey {
168 family: AlgorithmFamily::Ec,
169 kind: DecodingKeyKind::SecretOrDer(public_key),
170 })
171 }
172
173 #[cfg(feature = "use_pem")]
176 pub fn from_ed_pem(key: &[u8]) -> Result<Self> {
177 let pem_key = PemEncodedKey::new(key)?;
178 let content = pem_key.as_ed_public_key()?;
179 Ok(DecodingKey {
180 family: AlgorithmFamily::Ed,
181 kind: DecodingKeyKind::SecretOrDer(content.to_vec()),
182 })
183 }
184
185 pub fn from_rsa_der(der: &[u8]) -> Self {
187 DecodingKey {
188 family: AlgorithmFamily::Rsa,
189 kind: DecodingKeyKind::SecretOrDer(der.to_vec()),
190 }
191 }
192
193 pub fn from_ec_der(der: &[u8]) -> Self {
195 DecodingKey {
196 family: AlgorithmFamily::Ec,
197 kind: DecodingKeyKind::SecretOrDer(der.to_vec()),
198 }
199 }
200
201 pub fn from_ed_der(der: &[u8]) -> Self {
203 DecodingKey {
204 family: AlgorithmFamily::Ed,
205 kind: DecodingKeyKind::SecretOrDer(der.to_vec()),
206 }
207 }
208
209 pub fn from_ed_components(x: &str) -> Result<Self> {
211 let x_decoded = b64_decode(x)?;
212 Ok(DecodingKey {
213 family: AlgorithmFamily::Ed,
214 kind: DecodingKeyKind::SecretOrDer(x_decoded),
215 })
216 }
217
218 pub fn from_jwk(jwk: &Jwk) -> Result<Self> {
220 match &jwk.algorithm {
221 AlgorithmParameters::RSA(params) => {
222 DecodingKey::from_rsa_components(¶ms.n, ¶ms.e)
223 }
224 AlgorithmParameters::EllipticCurve(params) => {
225 DecodingKey::from_ec_components(¶ms.x, ¶ms.y)
226 }
227 AlgorithmParameters::OctetKeyPair(params) => DecodingKey::from_ed_components(¶ms.x),
228 AlgorithmParameters::OctetKey(params) => {
229 let out = b64_decode(¶ms.value)?;
230 Ok(DecodingKey {
231 family: AlgorithmFamily::Hmac,
232 kind: DecodingKeyKind::SecretOrDer(out),
233 })
234 }
235 }
236 }
237
238 pub(crate) fn as_bytes(&self) -> &[u8] {
239 match &self.kind {
240 DecodingKeyKind::SecretOrDer(b) => b,
241 DecodingKeyKind::RsaModulusExponent { .. } => unreachable!(),
242 }
243 }
244
245 pub(crate) fn try_get_hmac_secret(&self) -> Result<&[u8]> {
246 if self.family == AlgorithmFamily::Hmac {
247 Ok(self.as_bytes())
248 } else {
249 Err(new_error(ErrorKind::InvalidKeyFormat))
250 }
251 }
252}
253
254impl TryFrom<&Jwk> for DecodingKey {
255 type Error = crate::errors::Error;
256
257 fn try_from(jwk: &Jwk) -> Result<Self> {
258 Self::from_jwk(jwk)
259 }
260}
261
262pub fn decode<T: DeserializeOwned>(
281 token: impl AsRef<[u8]>,
282 key: &DecodingKey,
283 validation: &Validation,
284) -> Result<TokenData<T>> {
285 let token = token.as_ref();
286 let header = decode_header(token)?;
287
288 if validation.validate_signature && !validation.algorithms.contains(&header.alg) {
289 return Err(new_error(ErrorKind::InvalidAlgorithm));
290 }
291
292 let verifying_provider = jwt_verifier_factory(&header.alg, key)?;
293
294 let (header, claims) = verify_signature(token, validation, verifying_provider)?;
295
296 let decoded_claims = DecodedJwtPartClaims::from_jwt_part_claims(claims)?;
297 let claims = decoded_claims.deserialize()?;
298 validate(decoded_claims.deserialize()?, validation)?;
299
300 Ok(TokenData { header, claims })
301}
302
303pub fn insecure_decode<T: DeserializeOwned>(token: impl AsRef<[u8]>) -> Result<TokenData<T>> {
307 let token = token.as_ref();
308
309 let (_, message) = expect_two!(token.rsplitn(2, |b| *b == b'.'));
310 let (payload, header) = expect_two!(message.rsplitn(2, |b| *b == b'.'));
311
312 let header = Header::from_encoded(header)?;
313 let claims = DecodedJwtPartClaims::from_jwt_part_claims(payload)?.deserialize()?;
314
315 Ok(TokenData { header, claims })
316}
317
318pub fn jwt_verifier_factory(
320 algorithm: &Algorithm,
321 key: &DecodingKey,
322) -> Result<Box<dyn JwtVerifier>> {
323 let jwt_encoder = match algorithm {
324 Algorithm::HS256 => Box::new(Hs256Verifier::new(key)?) as Box<dyn JwtVerifier>,
325 Algorithm::HS384 => Box::new(Hs384Verifier::new(key)?) as Box<dyn JwtVerifier>,
326 Algorithm::HS512 => Box::new(Hs512Verifier::new(key)?) as Box<dyn JwtVerifier>,
327 Algorithm::ES256 => Box::new(Es256Verifier::new(key)?) as Box<dyn JwtVerifier>,
328 Algorithm::ES384 => Box::new(Es384Verifier::new(key)?) as Box<dyn JwtVerifier>,
329 Algorithm::RS256 => Box::new(Rsa256Verifier::new(key)?) as Box<dyn JwtVerifier>,
330 Algorithm::RS384 => Box::new(Rsa384Verifier::new(key)?) as Box<dyn JwtVerifier>,
331 Algorithm::RS512 => Box::new(Rsa512Verifier::new(key)?) as Box<dyn JwtVerifier>,
332 Algorithm::PS256 => Box::new(RsaPss256Verifier::new(key)?) as Box<dyn JwtVerifier>,
333 Algorithm::PS384 => Box::new(RsaPss384Verifier::new(key)?) as Box<dyn JwtVerifier>,
334 Algorithm::PS512 => Box::new(RsaPss512Verifier::new(key)?) as Box<dyn JwtVerifier>,
335 Algorithm::EdDSA => Box::new(EdDSAVerifier::new(key)?) as Box<dyn JwtVerifier>,
336 };
337
338 Ok(jwt_encoder)
339}
340
341pub fn decode_header(token: impl AsRef<[u8]>) -> Result<Header> {
352 let token = token.as_ref();
353 let (_, message) = expect_two!(token.rsplitn(2, |b| *b == b'.'));
354 let (_, header) = expect_two!(message.rsplitn(2, |b| *b == b'.'));
355 Header::from_encoded(header)
356}
357
358pub(crate) fn verify_signature_body(
359 message: &[u8],
360 signature: &[u8],
361 header: &Header,
362 validation: &Validation,
363 verifying_provider: Box<dyn JwtVerifier>,
364) -> Result<()> {
365 if validation.validate_signature && validation.algorithms.is_empty() {
366 return Err(new_error(ErrorKind::MissingAlgorithm));
367 }
368
369 if validation.validate_signature {
370 for alg in &validation.algorithms {
371 if verifying_provider.algorithm().family() != alg.family() {
372 return Err(new_error(ErrorKind::InvalidAlgorithm));
373 }
374 }
375 }
376
377 if validation.validate_signature && !validation.algorithms.contains(&header.alg) {
378 return Err(new_error(ErrorKind::InvalidAlgorithm));
379 }
380
381 if validation.validate_signature
382 && verifying_provider.verify(message, &b64_decode(signature)?).is_err()
383 {
384 return Err(new_error(ErrorKind::InvalidSignature));
385 }
386
387 Ok(())
388}
389
390fn verify_signature<'a>(
394 token: &'a [u8],
395 validation: &Validation,
396 verifying_provider: Box<dyn JwtVerifier>,
397) -> Result<(Header, &'a [u8])> {
398 let (signature, message) = expect_two!(token.rsplitn(2, |b| *b == b'.'));
399 let (payload, header) = expect_two!(message.rsplitn(2, |b| *b == b'.'));
400 let header = Header::from_encoded(header)?;
401 verify_signature_body(message, signature, &header, validation, verifying_provider)?;
402
403 Ok((header, payload))
404}