use core::{
mem,
task::{Context, Poll},
};
use std::sync::LazyLock;
use http::{Request, Response, StatusCode, header::AUTHORIZATION};
use http_body::Body;
use parking_lot::RwLock;
use tower_layer::Layer;
use tower_service::Service;
use ts_error::{IntoReport, LogError};
use ts_token::{
JsonWebToken,
jwks::{JsonWebKeyCache, JsonWebKeySetProvider},
jwt::TokenType,
};
use crate::middleware::futures::UndefinedFuture;
/// Process-wide signing-key cache consulted by [`TokenAuth::authenticate`].
///
/// This is a single `static`, so every `TokenAuth` reads and fills the same cache, whatever
/// its provider type. Lookups are keyed by the token's `kid` alone.
/// NOTE(review): a key fetched through one provider will therefore also be accepted by a
/// `TokenAuth` configured with a different provider; confirm this sharing is intended.
static CACHE: LazyLock<RwLock<JsonWebKeyCache>> = LazyLock::new(|| {
    let cache = JsonWebKeyCache::new();
    RwLock::new(cache)
});
/// Bearer-token authentication layer whose signing keys come from a JSON Web Key Set
/// provider.
#[derive(Debug, Clone)]
pub struct TokenAuth<P: JsonWebKeySetProvider> {
    /// Controls requests that carry no `Authorization` header:
    /// - `true`: they are rejected with `401 Unauthorized`.
    /// - `false`: they are passed through without a token.
    pub is_token_required: bool,
    /// Source of the key set used to verify token signatures.
    pub provider: P,
}
impl<P> TokenAuth<P>
where
    P: JsonWebKeySetProvider,
{
    /// Creates a layer that rejects requests lacking an `Authorization` header with
    /// `401 Unauthorized`.
    pub fn required(provider: P) -> Self {
        Self {
            is_token_required: true,
            provider,
        }
    }

    /// Creates a layer that lets requests without an `Authorization` header through
    /// unauthenticated. A header that *is* present must still carry a valid token.
    pub fn optional(provider: P) -> Self {
        Self {
            is_token_required: false,
            provider,
        }
    }

    /// Authenticates `request` using its `Authorization: Bearer <token>` header.
    ///
    /// On success, the decoded [`JsonWebToken`] is inserted into the request's extensions
    /// and the request is returned. Without a header, an optional layer returns the request
    /// unchanged. If the token's `kid` is not in the shared key cache, the key set is first
    /// fetched from the provider and added to the cache.
    ///
    /// # Errors
    ///
    /// - `401 Unauthorized` when any of the following holds:
    ///   - the header is missing and a token is required;
    ///   - the header is not a well-formed bearer credential;
    ///   - the token cannot be decoded or its claims are invalid;
    ///   - no cached key matches its `kid`;
    ///   - its signature does not verify.
    /// - `403 Forbidden` when a consent token's `act` claim differs from this request's
    ///   `"<METHOD> <path>"`.
    /// - `500 Internal Server Error` when the key set cannot be fetched from the provider.
    pub async fn authenticate<T>(self, mut request: Request<T>) -> Result<Request<T>, StatusCode> {
        let Some(header) = request.headers().get(AUTHORIZATION) else {
            return if self.is_token_required {
                Err(StatusCode::UNAUTHORIZED)
            } else {
                Ok(request)
            };
        };
        let encoded_token = header
            .to_str()
            .ok()
            .and_then(Self::bearer_credentials)
            .ok_or(StatusCode::UNAUTHORIZED)?;
        let Some(token) = JsonWebToken::deserialize(encoded_token) else {
            return Err(StatusCode::UNAUTHORIZED);
        };
        if !token.claims.is_valid() {
            return Err(StatusCode::UNAUTHORIZED);
        }

        // Evict stale keys and probe for this token's key under one write lock. Every guard
        // is scoped so that none is held across the `.await` below.
        let is_key_cached = {
            let mut cache = CACHE.write();
            cache.remove_stale_keys();
            cache.get(&token.header.kid).is_some()
        };
        if !is_key_cached {
            // NOTE(review): every request with an unknown `kid` triggers a provider fetch;
            // consider rate-limiting this if the provider is remote.
            let jwks = self
                .provider
                .fetch()
                .await
                .into_report()
                .log_err()
                .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
            CACHE.write().insert(jwks);
        }

        // The read guard is a temporary of this statement, so it is released right after
        // verification instead of being held until the function returns.
        let is_signature_valid = CACHE
            .read()
            .get(&token.header.kid)
            .is_some_and(|key| key.verifies_signature(&token));
        if !is_signature_valid {
            return Err(StatusCode::UNAUTHORIZED);
        }

        // A consent token is only valid for the exact action it was issued for.
        if let TokenType::Consent { act } = &token.claims.typ {
            let expected_action = format!("{} {}", request.method(), request.uri().path());
            if *act != expected_action {
                return Err(StatusCode::FORBIDDEN);
            }
        }

        request.extensions_mut().insert(token);
        Ok(request)
    }

    /// Extracts the credentials from an `Authorization` header value that uses the `Bearer`
    /// scheme.
    ///
    /// The scheme name is matched case-insensitively. Per RFC 6750 (`"Bearer" 1*SP b64token`),
    /// one or more spaces may separate it from the token. Returns `None` for any other
    /// scheme or for an empty token.
    fn bearer_credentials(header: &str) -> Option<&str> {
        let (scheme, credentials) = header.split_once(' ')?;
        if !scheme.eq_ignore_ascii_case("bearer") {
            return None;
        }
        let credentials = credentials.trim_matches(' ');
        (!credentials.is_empty()).then_some(credentials)
    }
}
impl<S, P> Layer<S> for TokenAuth<P>
where
P: JsonWebKeySetProvider,
{
type Service = TokenAuthService<S, P>;
fn layer(&self, inner: S) -> Self::Service {
TokenAuthService {
inner,
auth: self.clone(),
}
}
}
/// Tower service produced by the [`TokenAuth`] layer.
///
/// Each call runs [`TokenAuth::authenticate`] and hands the resulting future, together with
/// the wrapped service, to `UndefinedFuture`.
#[derive(Debug, Clone)]
pub struct TokenAuthService<S, P: JsonWebKeySetProvider> {
    /// The wrapped service.
    inner: S,
    /// Authentication configuration applied to every request.
    auth: TokenAuth<P>,
}
impl<S, P> TokenAuthService<S, P>
where
P: JsonWebKeySetProvider,
{
pub fn new(inner: S, auth: TokenAuth<P>) -> Self {
Self { inner, auth }
}
}
impl<Svc, Prov, ReqBody, ResBody> Service<Request<ReqBody>> for TokenAuthService<Svc, Prov>
where
    Svc: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone + Send,
    ResBody: Body + Send + Default,
    ReqBody: Send + 'static,
    Prov: JsonWebKeySetProvider,
{
    type Response = Svc::Response;
    type Error = Svc::Error;
    type Future = UndefinedFuture<Svc, ReqBody>;

    /// Readiness is delegated entirely to the wrapped service.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Starts authenticating `request` and packages that future with the inner service.
    fn call(&mut self, request: Request<ReqBody>) -> Self::Future {
        let authentication = self.auth.clone().authenticate(request);
        // `poll_ready` was driven on `self.inner`, so that exact instance must serve this
        // request. A fresh clone is left in its place for subsequent calls.
        let replacement = self.inner.clone();
        let ready_inner = mem::replace(&mut self.inner, replacement);
        UndefinedFuture::define(Box::pin(authentication), ready_inner)
    }
}
#[cfg(test)]
mod test {
    use axum::{Extension, Router, routing::get};
    use bytes::Bytes;
    use http::{Request, Response, StatusCode, header::AUTHORIZATION};
    use http_body_util::Full;
    use tower::{ServiceBuilder, ServiceExt};
    use tower_http::BoxError;
    use ts_token::{
        JsonWebKey, JsonWebKeySet, JsonWebToken, jwks::StaticJsonWebKeySet, jwt::TokenType,
    };

    use crate::middleware::test::get_request;
    use crate::middleware::{test::JWK, token::TokenAuth};
    use crate::test::ResponseTestExt;

    /// Sends `request` through a `TokenAuth` layer wrapping a body-echoing service and
    /// returns the response. The layer is required or optional per `token_required`, and
    /// its provider holds only the shared test key.
    async fn send(request: Request<Full<Bytes>>, token_required: bool) -> Response<Full<Bytes>> {
        let key: JsonWebKey = serde_json::from_str(JWK).unwrap();
        let key_set = StaticJsonWebKeySet::new(JsonWebKeySet { keys: vec![key] });
        let layer = if token_required {
            TokenAuth::required(key_set)
        } else {
            TokenAuth::optional(key_set)
        };
        ServiceBuilder::new()
            .layer(layer)
            .service_fn(async |req: Request<Full<Bytes>>| {
                Ok::<_, BoxError>(Response::new(req.into_body()))
            })
            .ready()
            .await
            .expect("service should be ok")
            .oneshot(request)
            .await
            .unwrap()
    }

    /// Sends the request built by `get_request(token_type)` through [`send`].
    async fn test_service(
        token_type: Option<TokenType>,
        token_required: bool,
    ) -> Response<Full<Bytes>> {
        send(get_request(token_type), token_required).await
    }

    #[tokio::test]
    async fn axum() {
        let key: JsonWebKey = serde_json::from_str(JWK).unwrap();
        let key_set = StaticJsonWebKeySet::new(JsonWebKeySet { keys: vec![key] });
        Router::new()
            .route(
                "/resource/id",
                get(|Extension(token): Extension<JsonWebToken>| async move {
                    assert_eq!("subject", token.claims.sub);
                    StatusCode::OK
                }),
            )
            .layer(TokenAuth::required(key_set))
            .oneshot(get_request(Some(TokenType::Common)))
            .await
            .unwrap()
            .expect_status(StatusCode::OK);
    }

    #[tokio::test]
    async fn consent_token() {
        let cases = [
            ("DELETE /resource/id", StatusCode::FORBIDDEN),
            ("GET /resource/id", StatusCode::OK),
            ("GET /resource/id2", StatusCode::FORBIDDEN),
        ];
        for (act, expected) in cases {
            test_service(
                Some(TokenType::Consent {
                    act: act.to_string(),
                }),
                true,
            )
            .await
            .expect_status(expected);
        }
    }

    #[tokio::test]
    async fn requirement() {
        let cases = [
            (None, true, StatusCode::UNAUTHORIZED),
            (None, false, StatusCode::OK),
            (Some(TokenType::Common), true, StatusCode::OK),
            (Some(TokenType::Common), false, StatusCode::OK),
        ];
        for (token_type, token_required, expected) in cases {
            test_service(token_type, token_required)
                .await
                .expect_status(expected);
        }
    }

    #[tokio::test]
    async fn token_validity() {
        for token_type in [TokenType::Provisioning, TokenType::Common] {
            test_service(Some(token_type), true)
                .await
                .expect_status(StatusCode::OK);
        }

        // The payload decodes to
        // `{"tid":"token-id","exp":2,"iat":1,"sub":"subject-id","typ":"common"}`,
        // so the token has long expired.
        const INVALID_JWT: &str = r#"bearer eyJhbGciOiJFZDI1NTE5IiwidHlwIjoiSldUIiwia2lkIjoiVU1JaTBoZGxCQmNJRzhvQ09tQmlfMGJ2UWZsaXZneHA5REtlMkw2UGpiRSJ9.eyJ0aWQiOiJ0b2tlbi1pZCIsImV4cCI6MiwiaWF0IjoxLCJzdWIiOiJzdWJqZWN0LWlkIiwidHlwIjoiY29tbW9uIn0.f7PHRouKc9DYxbRNZdUdrdmM6gC-HdmlorxZHPv5s21oqmbJMsOXXFpnh_52fXPbgY-rNPCvwHFyVKsovk51CA"#;
        let expired = Request::builder()
            .uri("/resource/id")
            .header(AUTHORIZATION, INVALID_JWT)
            .body(Full::<Bytes>::default())
            .expect("request should be valid");
        send(expired, true)
            .await
            .expect_status(StatusCode::UNAUTHORIZED);
    }
}