use axum::{
extract::Request, http::StatusCode, response::IntoResponse, response::Response, Json,
RequestExt,
};
use chrono::{DateTime, TimeDelta, Utc};
use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, Header, Validation};
use axum::extract::FromRequestParts;
use axum::http::request::Parts;
use futures_util::future::BoxFuture;
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::fmt::Display;
use std::task::{Context, Poll};
/// A thread-safe (`Send + Sync`) value that can report a subject.
///
/// NOTE(review): no implementor or caller of this trait is visible in this
/// file; presumably the subject corresponds to the JWT `sub` claim — confirm
/// against the implementing types.
pub trait JwtToken: Send + Sync {
/// Returns the subject as an owned `String`.
fn subject(&self) -> String;
}
/// Reads the signing secret from the `JWT_SECRET` environment variable.
///
/// # Panics
///
/// Panics when `JWT_SECRET` is unset or is not valid Unicode.
fn get_jwt_secret() -> String {
    const VARIABLE: &str = "JWT_SECRET";
    let lookup = std::env::var(VARIABLE);
    lookup.expect("JWT_SECRET must be set")
}
/// Reads the token issuer from the `JWT_ISSUER` environment variable.
///
/// # Panics
///
/// Panics when `JWT_ISSUER` is unset or is not valid Unicode.
fn get_jwt_issuer() -> String {
    const VARIABLE: &str = "JWT_ISSUER";
    let lookup = std::env::var(VARIABLE);
    lookup.expect("JWT_ISSUER must be set")
}
/// Reads the token audience from the `JWT_AUDIENCE` environment variable.
///
/// # Panics
///
/// Panics when `JWT_AUDIENCE` is unset or is not valid Unicode.
fn get_jwt_audience() -> String {
    const VARIABLE: &str = "JWT_AUDIENCE";
    let lookup = std::env::var(VARIABLE);
    lookup.expect("JWT_AUDIENCE must be set")
}
/// Decodes `token` and returns its [`Claims`].
///
/// The token is checked with an HS256 [`Validation`] whose accepted issuer and
/// audience come from `JWT_ISSUER` / `JWT_AUDIENCE`, using the shared secret
/// from `JWT_SECRET` as the decoding key.
///
/// # Errors
///
/// Returns the `jsonwebtoken` error produced by `decode` when the token is
/// rejected.
///
/// # Panics
///
/// Panics if any of `JWT_ISSUER`, `JWT_AUDIENCE` or `JWT_SECRET` is not set.
pub fn parse_jwt_token(token: &str) -> Result<Claims, jsonwebtoken::errors::Error> {
    // Environment lookups stay in issuer -> audience -> secret order.
    let mut rules = Validation::new(Algorithm::HS256);
    rules.set_issuer(&[get_jwt_issuer()]);
    rules.set_audience(&[get_jwt_audience()]);

    let secret = get_jwt_secret();
    let key = DecodingKey::from_secret(secret.as_bytes());

    decode::<Claims>(token, &key, &rules).map(|decoded| decoded.claims)
}
/// Value returned by [`create_jwt_token`]: the encoded token plus the claim
/// values that were written into it.
pub struct CreateJwtResult {
/// The encoded token string produced by `jsonwebtoken::encode`.
pub access_token: String,
/// Instant at which the token expires (creation time plus the requested
/// `expires_in`).
pub access_token_expires_at: DateTime<Utc>,
/// Scopes stored in the token's `scopes` claim.
pub scopes: Vec<String>,
/// Value stored in the token's `iss` claim.
pub issuer: String,
/// Value stored in the token's `aud` claim.
pub audience: String,
/// Value stored in the token's `sub` claim.
pub subject: String,
}
/// Caller-supplied inputs for [`create_jwt_token`].
pub struct CreateJwtConfig {
/// Lifetime of the token; added to the current UTC time to obtain the
/// expiry instant.
pub expires_in: TimeDelta,
/// Value to store in the token's `sub` claim.
pub subject: String,
/// Values to store in the token's `scopes` claim.
pub scopes: Vec<String>,
}
pub fn create_jwt_token(
config: CreateJwtConfig,
) -> Result<CreateJwtResult, jsonwebtoken::errors::Error> {
let jwt_secret = get_jwt_secret();
let iss = get_jwt_issuer();
let aud = get_jwt_audience();
let sub = config.subject;
let scopes = config.scopes;
let expires_in = config.expires_in;
let now = Utc::now();
let access_token_expires_at = now + expires_in;
let issued_at = now.timestamp() as u64;
let exp = access_token_expires_at.timestamp() as u64;
let claims = Claims {
iss,
sub,
issued_at,
exp,
aud,
scopes,
};
let encode_key = EncodingKey::from_secret(jwt_secret.as_bytes());
let access_token = encode(&Header::default(), &claims, &encode_key)?;
Ok(CreateJwtResult {
access_token,
access_token_expires_at,
scopes: claims.scopes,
issuer: claims.iss,
audience: claims.aud,
subject: claims.sub,
})
}
/// Claim set carried inside the access token.
///
/// Serialized and deserialized with serde using the field names below as the
/// JSON keys.
#[derive(Debug, Serialize, Deserialize)]
pub struct Claims {
/// Subject of the token. The `Display` impl labels it as an email and the
/// tests in this file use an email address here.
pub sub: String,
/// Audience the token is intended for.
pub aud: String,
/// Issuer of the token.
pub iss: String,
/// Creation time as seconds since the Unix epoch.
/// NOTE(review): this serializes under the key `issued_at`, not the
/// registered JWT claim name `iat` — confirm that is intentional.
pub issued_at: u64,
/// Expiry time as seconds since the Unix epoch.
pub exp: u64,
/// Scopes granted to the token holder.
pub scopes: Vec<String>,
}
impl Claims {
    /// Returns `true` when every entry of `expected_scopes` is present in
    /// `self.scopes`.
    ///
    /// Scopes are compared by exact string equality. An empty
    /// `expected_scopes` slice is always satisfied.
    pub fn has_scopes(&self, expected_scopes: &[String]) -> bool {
        // `scope` is already a `&String`, so compare by reference rather than
        // allocating a fresh `String` for every lookup.
        expected_scopes
            .iter()
            .all(|scope| self.scopes.contains(scope))
    }
}
impl Display for Claims {
    /// Renders the claims as `Email: <sub>`; no other field is printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let subject = &self.sub;
        f.write_fmt(format_args!("Email: {}", subject))
    }
}
/// Tower layer that wraps a service in [`RequireScopeMiddleware`], handing it
/// a copy of the configured scope list.
#[derive(Clone)]
pub struct RequireScopeLayer {
/// Scopes a request's claims must contain to reach the wrapped service.
required_scopes: Vec<String>,
}
impl RequireScopeLayer {
    /// Creates a layer that requires no scopes.
    pub fn new() -> Self {
        let required_scopes = Vec::new();
        Self { required_scopes }
    }

    /// Replaces (does not extend) the required scopes with `require_scope`
    /// and returns the updated layer.
    pub fn with(mut self, require_scope: Vec<&str>) -> Self {
        self.required_scopes = require_scope.into_iter().map(String::from).collect();
        self
    }
}
impl Default for RequireScopeLayer {
fn default() -> Self {
Self::new()
}
}
impl<S> Layer<S> for RequireScopeLayer {
    type Service = RequireScopeMiddleware<S>;

    /// Wraps `inner` in a [`RequireScopeMiddleware`] carrying its own copy of
    /// this layer's required scopes.
    fn layer(&self, inner: S) -> Self::Service {
        let required_scopes = self.required_scopes.clone();
        RequireScopeMiddleware {
            inner,
            required_scopes,
        }
    }
}
/// Service wrapper that forwards a request to `inner` only when the request's
/// [`Claims`] contain every required scope.
#[derive(Clone)]
pub struct RequireScopeMiddleware<S> {
/// The wrapped service that receives authorized requests.
inner: S,
/// Scopes that must all be present in the request's claims.
required_scopes: Vec<String>,
}
impl<S> Service<Request> for RequireScopeMiddleware<S>
where
    S: Service<Request, Response = Response> + Clone + Send + 'static,
    S::Future: Send + 'static,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;

    /// Readiness is delegated to the wrapped service.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Extracts [`Claims`] from the request and forwards it to the wrapped
    /// service only when all required scopes are present.
    ///
    /// * Claims cannot be extracted: responds with `AuthError::InvalidToken`.
    /// * A required scope is missing: responds with
    ///   `AuthError::NotSufficientScopes`.
    ///
    /// Rejections are returned as `Ok` responses; only the wrapped service can
    /// produce `Self::Error`.
    fn call(&mut self, mut request: Request) -> Self::Future {
        let required_scopes = self.required_scopes.clone();

        // `poll_ready` was driven on `self.inner`, so that exact instance has
        // to handle the request; a fresh clone is not guaranteed to be ready.
        // Move the ready service into the future and leave the clone behind
        // for the next `poll_ready`/`call` cycle.
        let fresh = self.inner.clone();
        let mut inner = std::mem::replace(&mut self.inner, fresh);

        Box::pin(async move {
            let Ok(claims) = request.extract_parts::<Claims>().await else {
                return Ok(AuthError::InvalidToken.into_response());
            };

            if !claims.has_scopes(&required_scopes) {
                return Ok(AuthError::NotSufficientScopes.into_response());
            }

            inner.call(request).await
        })
    }
}
#[cfg(not(any(test, feature = "mock_jwt")))]
use axum::RequestPartsExt;
#[cfg(not(any(test, feature = "mock_jwt")))]
use axum_extra::{
headers::{authorization::Bearer, Authorization},
TypedHeader,
};
use derive_more::Display;
use thiserror::Error;
use tower::{Layer, Service};
/// Production extractor: reads the `Authorization: Bearer <token>` header and
/// validates the token with [`parse_jwt_token`].
///
/// Both a missing/malformed header and a token that fails validation are
/// rejected with [`AuthError::InvalidToken`].
#[cfg(not(any(test, feature = "mock_jwt")))]
impl<S> FromRequestParts<S> for Claims
where
    S: Send + Sync,
{
    type Rejection = AuthError;

    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
        let authorization = parts
            .extract::<TypedHeader<Authorization<Bearer>>>()
            .await;
        let Ok(TypedHeader(Authorization(bearer))) = authorization else {
            return Err(AuthError::InvalidToken);
        };

        match parse_jwt_token(bearer.token()) {
            Ok(claims) => Ok(claims),
            Err(_) => Err(AuthError::InvalidToken),
        }
    }
}
/// Mock extractor, compiled only for tests or with the `mock_jwt` feature.
///
/// Instead of validating a bearer token it builds [`Claims`] directly from
/// request headers:
///
/// | Header                | Claim       |
/// |-----------------------|-------------|
/// | `X-Claims-Subject`    | `sub`       |
/// | `X-Claims-Issuer`     | `iss`       |
/// | `X-Claims-Audience`   | `aud`       |
/// | `X-Claims-Issued-At`  | `issued_at` |
/// | `X-Claims-Expiration` | `exp`       |
/// | `X-Claims-Scopes`     | `scopes` (comma-separated) |
///
/// A missing header, a header that is not visible ASCII, or a non-numeric
/// timestamp is rejected with [`AuthError::InvalidToken`] rather than
/// panicking, so the rejection path can be exercised under the mock too.
#[cfg(any(test, feature = "mock_jwt"))]
impl<S> FromRequestParts<S> for Claims
where
    S: Send + Sync,
{
    type Rejection = AuthError;

    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
        /// Fetches header `name` as a string slice, or rejects the request.
        fn header<'a>(parts: &'a Parts, name: &str) -> Result<&'a str, AuthError> {
            parts
                .headers
                .get(name)
                .and_then(|value| value.to_str().ok())
                .ok_or(AuthError::InvalidToken)
        }

        // Only shared access is needed from here on.
        let parts: &Parts = parts;

        let sub = header(parts, "X-Claims-Subject")?.to_string();
        let iss = header(parts, "X-Claims-Issuer")?.to_string();
        let aud = header(parts, "X-Claims-Audience")?.to_string();
        let issued_at = header(parts, "X-Claims-Issued-At")?
            .parse::<u64>()
            .map_err(|_| AuthError::InvalidToken)?;
        let exp = header(parts, "X-Claims-Expiration")?
            .parse::<u64>()
            .map_err(|_| AuthError::InvalidToken)?;
        let scopes = header(parts, "X-Claims-Scopes")?
            .split(',')
            .map(|scope| scope.to_string())
            .collect();

        Ok(Claims {
            sub,
            aud,
            iss,
            issued_at,
            exp,
            scopes,
        })
    }
}
impl IntoResponse for AuthError {
    /// Converts the error into a JSON response of the form
    /// `{"error": "<message>"}`:
    ///
    /// * `InvalidToken` -> 401 Unauthorized, `"Invalid token"`
    /// * `NotSufficientScopes` -> 403 Forbidden, `"Not sufficient scopes"`
    fn into_response(self) -> axum::response::Response {
        let (status, message) = match self {
            Self::InvalidToken => (StatusCode::UNAUTHORIZED, "Invalid token"),
            Self::NotSufficientScopes => (StatusCode::FORBIDDEN, "Not sufficient scopes"),
        };
        let payload = json!({
            "error": message,
        });
        (status, Json(payload)).into_response()
    }
}
/// Authentication/authorization failures produced by the claims extractor and
/// the scope middleware in this file.
#[derive(Debug, Error, Display)]
pub enum AuthError {
/// Claims could not be obtained from the request; rendered as
/// 401 Unauthorized by the `IntoResponse` impl.
InvalidToken,
/// Claims were obtained but lack a required scope; rendered as
/// 403 Forbidden by the `IntoResponse` impl.
NotSufficientScopes,
}
#[cfg(test)]
mod tests {
    use super::*;
    use fake::faker::internet::en::FreeEmail;
    use fake::Fake;

    /// Sets the environment variables the JWT helpers read.
    fn setup() {
        let variables = [
            ("JWT_SECRET", "secret"),
            ("JWT_ISSUER", "issuer"),
            ("JWT_AUDIENCE", "audience"),
        ];
        for (name, value) in variables {
            std::env::set_var(name, value);
        }
    }

    /// A freshly created token round-trips through `parse_jwt_token` and
    /// expires roughly seven days from now.
    #[test]
    fn test_create_token() {
        setup();
        let subject: String = FreeEmail().fake();

        let issued = create_jwt_token(CreateJwtConfig {
            subject: subject.clone(),
            scopes: vec!["customers:read".to_string()],
            expires_in: chrono::Duration::days(7),
        })
        .expect("Failed to create JWT token");
        assert_eq!(issued.scopes, vec!["customers:read"]);

        // Allow 30 seconds of slack between token creation and this check.
        let earliest_expected_expiry =
            (chrono::Utc::now() + chrono::Duration::days(7)) - chrono::Duration::seconds(30);
        assert!(issued.access_token_expires_at > earliest_expected_expiry);

        let parsed = parse_jwt_token(&issued.access_token).unwrap();
        assert_eq!(vec!["customers:read".to_string()], parsed.scopes);
        assert_eq!(subject, parsed.sub);
    }

    /// Appending a character to the encoded token must make parsing fail.
    #[test]
    fn test_invalid_token() {
        setup();
        let subject: String = FreeEmail().fake();

        let mut issued = create_jwt_token(CreateJwtConfig {
            subject,
            scopes: vec![],
            expires_in: chrono::Duration::days(7),
        })
        .expect("Failed to create JWT token");
        issued.access_token.push('a');

        let outcome = parse_jwt_token(&issued.access_token);
        assert!(outcome.is_err());
    }
}