use rsa::RsaPublicKey;
use serde::de::DeserializeOwned;
use crate::crypto::verify;
use crate::errors::{new_error, ErrorKind, Result};
use crate::header::Header;
use crate::serialization::from_jwt_part_claims;
use crate::validation::{validate, Validation};
use base64::{engine::general_purpose::STANDARD, Engine};
/// The result of decoding a token: its header together with the
/// deserialized claims of type `T`.
#[derive(Debug)]
pub struct TokenData<T> {
    /// The decoded token header.
    pub header: Header,
    /// The claims section, deserialized into the caller-chosen type.
    pub claims: T,
}
/// Pulls exactly two items out of an iterator expression and yields them as
/// a `(first, second)` tuple.
///
/// If the iterator produces fewer or more than two items, the macro makes
/// the *enclosing function* return `Err(ErrorKind::InvalidToken)`.
macro_rules! expect_two {
    ($iter:expr) => {{
        let mut parts = $iter;
        // Pull three times, in order: the third pull must come back empty.
        let head = parts.next();
        let tail = parts.next();
        let extra = parts.next();
        match (head, tail, extra) {
            (Some(first), Some(second), None) => (first, second),
            _ => return Err(new_error(ErrorKind::InvalidToken)),
        }
    }};
}
#[derive(Debug, Clone, PartialEq)]
pub enum DecodingKey {
Hmac(Vec<u8>),
Rsa(rsa::RsaPublicKey),
}
impl DecodingKey {
    /// Builds an HMAC key by copying the given secret bytes.
    pub fn from_hmac_secret(secret: &[u8]) -> Self {
        Self::Hmac(secret.to_vec())
    }

    /// Builds an HMAC key from a secret encoded with the standard base64
    /// alphabet.
    ///
    /// # Errors
    ///
    /// Propagates the error raised when `secret` cannot be base64-decoded.
    pub fn from_base64_hmac_secret(secret: &str) -> Result<Self> {
        let raw = STANDARD.decode(secret)?;
        Ok(Self::Hmac(raw))
    }

    /// Wraps an already-constructed RSA public key.
    ///
    /// This always returns `Ok`.
    pub fn from_rsa(key: rsa::RsaPublicKey) -> Result<Self> {
        Ok(Self::Rsa(key))
    }

    /// Builds an RSA key from its encoded modulus `n` and exponent `e`.
    ///
    /// Both components are decoded with the crate's `b64_decode` helper and
    /// then read as big-endian unsigned integers.
    ///
    /// # Errors
    ///
    /// Propagates the decoding error of either component, and returns
    /// `ErrorKind::InvalidKeyFormat` when the `rsa` crate rejects the
    /// resulting `(n, e)` pair.
    pub fn from_rsa_components(n: &str, e: &str) -> Result<Self> {
        use crate::serialization::b64_decode;

        let modulus = rsa::BigUint::from_bytes_be(&b64_decode(n)?);
        let exponent = rsa::BigUint::from_bytes_be(&b64_decode(e)?);
        let public_key = RsaPublicKey::new(modulus, exponent)
            .map_err(|_| new_error(ErrorKind::InvalidKeyFormat))?;
        Ok(Self::Rsa(public_key))
    }
}
/// Decodes a token, verifies its signature and validates its claims.
///
/// The steps run in this order:
/// 1. split the token into header, claims and signature,
/// 2. decode the header,
/// 3. check the header's algorithm against `validation.algorithms`
///    (skipped when that list is empty),
/// 4. verify the signature over `header.claims` with `key`,
/// 5. deserialize the claims into `T` and run `validate` on them.
///
/// # Errors
///
/// * `ErrorKind::InvalidToken` if the token has fewer than two `.`
///   separators.
/// * `ErrorKind::InvalidAlgorithm` if `validation.algorithms` is non-empty
///   and does not contain the header's algorithm.
/// * `ErrorKind::InvalidSignature` if `verify` reports a mismatch.
/// * Any error propagated from header decoding, `verify`, claims
///   deserialization or `validate`.
pub fn decode<T: DeserializeOwned>(
    token: &str,
    key: &DecodingKey,
    validation: &Validation,
) -> Result<TokenData<T>> {
    // `rsplitn` walks from the right: the signature comes out first and the
    // remainder (`header.claims`) is exactly the message that gets verified.
    let (signature, message) = expect_two!(token.rsplitn(2, '.'));
    let (claims, header) = expect_two!(message.rsplitn(2, '.'));
    let header = Header::from_encoded(header)?;

    // An empty allow-list places no restriction on the algorithm.
    if !validation.algorithms.is_empty() && !validation.algorithms.contains(&header.alg) {
        return Err(new_error(ErrorKind::InvalidAlgorithm));
    }

    if !verify(signature, message, key, header.alg)? {
        return Err(new_error(ErrorKind::InvalidSignature));
    }

    // Claims are only parsed once the signature has been accepted.
    let (decoded_claims, claims_map): (T, _) = from_jwt_part_claims(claims)?;
    validate(&claims_map, validation)?;

    Ok(TokenData { header, claims: decoded_claims })
}
/// Decodes a token's header and claims WITHOUT verifying the signature and
/// without running any claim validation.
///
/// The signature segment is split off and discarded.
///
/// # Errors
///
/// Returns `ErrorKind::InvalidToken` if the token has fewer than two `.`
/// separators, and propagates any error from decoding the header or
/// deserializing the claims.
pub fn dangerous_insecure_decode<T: DeserializeOwned>(token: &str) -> Result<TokenData<T>> {
    let (_signature, signed_part) = expect_two!(token.rsplitn(2, '.'));
    let (raw_claims, raw_header) = expect_two!(signed_part.rsplitn(2, '.'));

    let header = Header::from_encoded(raw_header)?;
    let (claims, _): (T, _) = from_jwt_part_claims(raw_claims)?;

    Ok(TokenData { header, claims })
}
/// Decodes a token and validates its claims WITHOUT verifying the
/// signature.
///
/// The signature segment is split off and discarded. The header's algorithm
/// is still checked against `validation.algorithms` (skipped when that list
/// is empty), and the claims are still passed through `validate`.
///
/// # Errors
///
/// * `ErrorKind::InvalidToken` if the token has fewer than two `.`
///   separators.
/// * `ErrorKind::InvalidAlgorithm` if `validation.algorithms` is non-empty
///   and does not contain the header's algorithm.
/// * Any error propagated from header decoding, claims deserialization or
///   `validate`.
pub fn dangerous_insecure_decode_with_validation<T: DeserializeOwned>(
    token: &str,
    validation: &Validation,
) -> Result<TokenData<T>> {
    let (_, message) = expect_two!(token.rsplitn(2, '.'));
    let (claims, header) = expect_two!(message.rsplitn(2, '.'));
    let header = Header::from_encoded(header)?;

    // An empty allow-list places no restriction on the algorithm.
    if !validation.algorithms.is_empty() && !validation.algorithms.contains(&header.alg) {
        return Err(new_error(ErrorKind::InvalidAlgorithm));
    }

    let (decoded_claims, claims_map): (T, _) = from_jwt_part_claims(claims)?;
    validate(&claims_map, validation)?;

    Ok(TokenData { header, claims: decoded_claims })
}
/// Deprecated alias that forwards directly to [`dangerous_insecure_decode`].
#[deprecated(
    note = "This function has been renamed to `dangerous_insecure_decode` and will be removed in a later version."
)]
pub fn dangerous_unsafe_decode<T: DeserializeOwned>(token: &str) -> Result<TokenData<T>> {
    dangerous_insecure_decode::<T>(token)
}
/// Decodes only the header of a token.
///
/// The signature and claims segments are split off and ignored; nothing is
/// verified or validated.
///
/// # Errors
///
/// Returns `ErrorKind::InvalidToken` if the token has fewer than two `.`
/// separators, and propagates any error from decoding the header.
pub fn decode_header(token: &str) -> Result<Header> {
    let (_signature, signed_part) = expect_two!(token.rsplitn(2, '.'));
    let (_claims, raw_header) = expect_two!(signed_part.rsplitn(2, '.'));
    Header::from_encoded(raw_header)
}