use base64::Engine;
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
use http::StatusCode;
use http::request::Parts;
use serde::de::DeserializeOwned;
use crate::extractors::FromRequest;
use crate::extractors::FromRequestParts;
use crate::responder::Responder;
use crate::types::Request;
/// A raw bearer token taken from the `Authorization` header.
///
/// Extracting a `Jwt` only parses the header; the token itself is not decoded or
/// validated. Call [`Jwt::validate_format`], [`Jwt::validate_expiration`] and
/// [`Jwt::validate_not_before`] to check it.
///
/// `Debug` is intentionally not derived so the credential cannot end up in logs by
/// accident.
#[doc(alias = "jwt")]
#[derive(Clone)]
pub struct Jwt {
    /// The compact-serialized token (the text following the `Bearer` scheme).
    pub token: String,
    /// The full `Authorization` header value as received.
    pub header: String,
}
/// The claims (payload) of the bearer token, deserialized into `T`.
///
/// NOTE(review): the extractor in this module checks the token's shape and its
/// `exp`/`nbf` window but does not verify the signature. Treat these claims as
/// untrusted unless the signature is verified elsewhere.
#[doc(alias = "jwt_claims")]
#[derive(Debug, Clone)]
pub struct JwtClaims<T>(pub T);
/// Reasons the JWT extractors can reject a request.
///
/// The [`Display`](std::fmt::Display) text matches the message used in the HTTP
/// response. The one exception is [`JwtError::ClaimsDeserializationError`], whose
/// `Display` output also includes the serde error detail.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum JwtError {
    /// No `Authorization` header was sent.
    MissingAuthHeader,
    /// The `Authorization` header value could not be read as a string.
    InvalidAuthHeader,
    /// The `Authorization` header is not a `Bearer` credential.
    InvalidBearerFormat,
    /// Nothing followed the `Bearer` scheme.
    EmptyToken,
    /// The token is not made of three non-empty, dot-separated segments.
    InvalidJwtFormat,
    /// The header segment is not base64url-encoded JSON.
    InvalidJwtHeader,
    /// The claims segment is malformed (e.g. not base64url-encoded JSON).
    InvalidJwtClaims,
    /// The signature segment is not valid base64url.
    InvalidJwtSignature,
    /// The claims JSON did not fit the requested claims type; holds the serde error text.
    ClaimsDeserializationError(String),
    /// The `exp` claim has passed.
    TokenExpired,
    /// The `nbf` claim lies in the future.
    TokenNotYetValid,
}

impl std::fmt::Display for JwtError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            JwtError::MissingAuthHeader => "Missing Authorization header",
            JwtError::InvalidAuthHeader => "Invalid Authorization header",
            JwtError::InvalidBearerFormat => "Authorization header is not Bearer token",
            JwtError::EmptyToken => "JWT token is empty",
            JwtError::InvalidJwtFormat => "Invalid JWT token format",
            JwtError::InvalidJwtHeader => "Invalid JWT header section",
            JwtError::InvalidJwtClaims => "Invalid JWT claims section",
            JwtError::InvalidJwtSignature => "Invalid JWT signature section",
            JwtError::ClaimsDeserializationError(detail) => {
                // Unlike the HTTP response, diagnostics benefit from the serde detail.
                return write!(f, "JWT claims deserialization failed: {detail}");
            }
            JwtError::TokenExpired => "JWT token has expired",
            JwtError::TokenNotYetValid => "JWT token is not yet valid",
        };
        f.write_str(message)
    }
}

impl std::error::Error for JwtError {}
/// Turns every [`JwtError`] into a `401 Unauthorized` response carrying a short,
/// fixed message.
///
/// The serde detail held by [`JwtError::ClaimsDeserializationError`] is not included
/// in the response.
impl Responder for JwtError {
    fn into_response(self) -> crate::types::Response {
        // All variants share one status code; only the message differs.
        let reason: &'static str = match self {
            JwtError::MissingAuthHeader => "Missing Authorization header",
            JwtError::InvalidAuthHeader => "Invalid Authorization header",
            JwtError::InvalidBearerFormat => "Authorization header is not Bearer token",
            JwtError::EmptyToken => "JWT token is empty",
            JwtError::InvalidJwtFormat => "Invalid JWT token format",
            JwtError::InvalidJwtHeader => "Invalid JWT header section",
            JwtError::InvalidJwtClaims => "Invalid JWT claims section",
            JwtError::InvalidJwtSignature => "Invalid JWT signature section",
            JwtError::ClaimsDeserializationError(_) => "JWT claims deserialization failed",
            JwtError::TokenExpired => "JWT token has expired",
            JwtError::TokenNotYetValid => "JWT token is not yet valid",
        };
        (StatusCode::UNAUTHORIZED, reason).into_response()
    }
}
impl Jwt {
    /// Parses a bearer credential out of the `Authorization` header.
    ///
    /// The scheme name is compared case-insensitively, because RFC 7235 §2.1 defines
    /// it as a case-insensitive token. Whitespace around the token is ignored.
    ///
    /// # Errors
    ///
    /// * [`JwtError::MissingAuthHeader`] — the header is absent.
    /// * [`JwtError::InvalidAuthHeader`] — `HeaderValue::to_str` fails on the value.
    /// * [`JwtError::InvalidBearerFormat`] — the scheme is not `Bearer`.
    /// * [`JwtError::EmptyToken`] — nothing follows the `Bearer` scheme.
    fn extract_from_headers(headers: &http::HeaderMap) -> Result<Self, JwtError> {
        let auth_str = headers
            .get("Authorization")
            .ok_or(JwtError::MissingAuthHeader)?
            .to_str()
            .map_err(|_| JwtError::InvalidAuthHeader)?;
        // `credentials = auth-scheme [ 1*SP token68 ]`: the scheme ends at the first space.
        let (scheme, rest) = auth_str.split_once(' ').unwrap_or((auth_str, ""));
        if !scheme.eq_ignore_ascii_case("Bearer") {
            return Err(JwtError::InvalidBearerFormat);
        }
        let token = rest.trim();
        if token.is_empty() {
            return Err(JwtError::EmptyToken);
        }
        Ok(Jwt {
            token: token.to_owned(),
            header: auth_str.to_owned(),
        })
    }

    /// Splits the token into its header, claims and signature segments.
    ///
    /// Only the segment count is checked here; individual segments may be empty.
    fn segments(&self) -> Result<[&str; 3], JwtError> {
        let mut pieces = self.token.split('.');
        // Tuple operands are evaluated left to right, so this pulls four pieces in order.
        match (pieces.next(), pieces.next(), pieces.next(), pieces.next()) {
            (Some(header), Some(claims), Some(signature), None) => {
                Ok([header, claims, signature])
            }
            _ => Err(JwtError::InvalidJwtFormat),
        }
    }

    /// Decodes an unpadded base64url segment and parses it as JSON, mapping any failure
    /// to `err`.
    fn decode_json_segment(segment: &str, err: JwtError) -> Result<serde_json::Value, JwtError> {
        URL_SAFE_NO_PAD
            .decode(segment)
            .ok()
            .and_then(|bytes| serde_json::from_slice(&bytes).ok())
            .ok_or(err)
    }

    /// Checks that the token is three non-empty, dot-separated segments.
    ///
    /// Only the shape is checked; the segments are not decoded. An empty signature
    /// segment is rejected too, so unsecured (`alg: "none"`) tokens fail here.
    ///
    /// # Errors
    ///
    /// [`JwtError::InvalidJwtFormat`] if the shape is wrong.
    pub fn validate_format(&self) -> Result<(), JwtError> {
        if self.segments()?.iter().any(|segment| segment.is_empty()) {
            return Err(JwtError::InvalidJwtFormat);
        }
        Ok(())
    }

    /// Decodes the header segment as JSON.
    ///
    /// # Errors
    ///
    /// * [`JwtError::InvalidJwtFormat`] — the token is not three segments.
    /// * [`JwtError::InvalidJwtHeader`] — the segment is not unpadded base64url JSON.
    pub fn header(&self) -> Result<serde_json::Value, JwtError> {
        let [header, _, _] = self.segments()?;
        Self::decode_json_segment(header, JwtError::InvalidJwtHeader)
    }

    /// Decodes the claims segment as JSON.
    ///
    /// # Errors
    ///
    /// * [`JwtError::InvalidJwtFormat`] — the token is not three segments.
    /// * [`JwtError::InvalidJwtClaims`] — the segment is not unpadded base64url JSON.
    pub fn claims(&self) -> Result<serde_json::Value, JwtError> {
        let [_, claims, _] = self.segments()?;
        Self::decode_json_segment(claims, JwtError::InvalidJwtClaims)
    }

    /// Decodes the signature segment into raw bytes. The signature is NOT verified.
    ///
    /// # Errors
    ///
    /// * [`JwtError::InvalidJwtFormat`] — the token is not three segments.
    /// * [`JwtError::InvalidJwtSignature`] — the segment is not unpadded base64url.
    pub fn signature(&self) -> Result<Vec<u8>, JwtError> {
        let [_, _, signature] = self.segments()?;
        URL_SAFE_NO_PAD
            .decode(signature)
            .map_err(|_| JwtError::InvalidJwtSignature)
    }

    /// Current time in (fractional) seconds since the UNIX epoch.
    ///
    /// A system clock set before 1970 yields a negative value instead of panicking.
    fn unix_now() -> f64 {
        match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
            Ok(elapsed) => elapsed.as_secs_f64(),
            Err(before_epoch) => -before_epoch.duration().as_secs_f64(),
        }
    }

    /// Reads an optional NumericDate claim (RFC 7519 §2), such as `exp` or `nbf`.
    ///
    /// Returns `Ok(None)` when the claim is absent. A claim that is present but not a
    /// JSON number yields [`JwtError::InvalidJwtClaims`]; silently ignoring it would
    /// disable the time check.
    fn numeric_date(claims: &serde_json::Value, name: &str) -> Result<Option<f64>, JwtError> {
        match claims.get(name) {
            None => Ok(None),
            Some(value) => value.as_f64().map(Some).ok_or(JwtError::InvalidJwtClaims),
        }
    }

    /// Rejects the token once its `exp` claim has been reached.
    ///
    /// Tokens without an `exp` claim pass.
    ///
    /// # Errors
    ///
    /// * [`JwtError::TokenExpired`] — `now >= exp`.
    /// * [`JwtError::InvalidJwtClaims`] — the claims cannot be decoded, or `exp` is not
    ///   a number.
    /// * [`JwtError::InvalidJwtFormat`] — the token is not three segments.
    pub fn validate_expiration(&self) -> Result<(), JwtError> {
        let claims = self.claims()?;
        match Self::numeric_date(&claims, "exp")? {
            // RFC 7519 §4.1.4: the token MUST NOT be accepted *on or after* `exp`.
            Some(exp) if Self::unix_now() >= exp => Err(JwtError::TokenExpired),
            _ => Ok(()),
        }
    }

    /// Rejects the token while its `nbf` claim is still in the future.
    ///
    /// Tokens without an `nbf` claim pass.
    ///
    /// # Errors
    ///
    /// * [`JwtError::TokenNotYetValid`] — `now < nbf`.
    /// * [`JwtError::InvalidJwtClaims`] — the claims cannot be decoded, or `nbf` is not
    ///   a number.
    /// * [`JwtError::InvalidJwtFormat`] — the token is not three segments.
    pub fn validate_not_before(&self) -> Result<(), JwtError> {
        let claims = self.claims()?;
        match Self::numeric_date(&claims, "nbf")? {
            // RFC 7519 §4.1.5: the token MUST NOT be accepted *before* `nbf`.
            Some(nbf) if Self::unix_now() < nbf => Err(JwtError::TokenNotYetValid),
            _ => Ok(()),
        }
    }
}
impl<T> JwtClaims<T>
where
T: DeserializeOwned,
{
fn extract_from_headers(headers: &http::HeaderMap) -> Result<Self, JwtError> {
let jwt = Jwt::extract_from_headers(headers)?;
jwt.validate_format()?;
jwt.validate_expiration()?;
jwt.validate_not_before()?;
let claims_json = jwt.claims()?;
let claims: T = serde_json::from_value(claims_json)
.map_err(|e| JwtError::ClaimsDeserializationError(e.to_string()))?;
Ok(JwtClaims(claims))
}
}
impl<'a> FromRequest<'a> for Jwt {
    type Error = JwtError;

    /// Reads the bearer token from the request headers; the returned future is
    /// already resolved.
    fn from_request(
        req: &'a mut Request,
    ) -> impl Future<Output = Result<Self, Self::Error>> + Send + 'a {
        let outcome = Self::extract_from_headers(req.headers());
        futures_util::future::ready(outcome)
    }
}
impl<'a> FromRequestParts<'a> for Jwt {
    type Error = JwtError;

    /// Reads the bearer token from the request head (`http::request::Parts`); the
    /// returned future is already resolved.
    fn from_request_parts(
        parts: &'a mut Parts,
    ) -> impl Future<Output = Result<Self, Self::Error>> + Send + 'a {
        let header_map = &parts.headers;
        let outcome = Self::extract_from_headers(header_map);
        futures_util::future::ready(outcome)
    }
}
impl<'a, T> FromRequest<'a> for JwtClaims<T>
where
    T: DeserializeOwned + Send + 'a,
{
    type Error = JwtError;

    /// Decodes typed claims from the request headers; the returned future is already
    /// resolved.
    fn from_request(
        req: &'a mut Request,
    ) -> impl Future<Output = Result<Self, Self::Error>> + Send + 'a {
        let extracted = Self::extract_from_headers(req.headers());
        futures_util::future::ready(extracted)
    }
}
impl<'a, T> FromRequestParts<'a> for JwtClaims<T>
where
    T: DeserializeOwned + Send + 'a,
{
    type Error = JwtError;

    /// Decodes typed claims from the request head (`http::request::Parts`); the
    /// returned future is already resolved.
    fn from_request_parts(
        parts: &'a mut Parts,
    ) -> impl Future<Output = Result<Self, Self::Error>> + Send + 'a {
        let header_map = &parts.headers;
        let extracted = Self::extract_from_headers(header_map);
        futures_util::future::ready(extracted)
    }
}