use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use axum::body::Body;
use axum::response::IntoResponse;
use http::Request;
use tower::{Layer, Service};
use crate::Error;
use crate::auth::session::Session;
use super::claims::Claims;
use super::decoder::JwtDecoder;
use super::error::JwtError;
use super::service::JwtSessionService;
use super::source::{BearerSource, TokenSource};
use crate::auth::session::token::SessionToken;
/// Tower [`Layer`] that authenticates each request with a JWT before it reaches the
/// wrapped service.
///
/// The token-source list is held behind an `Arc`, so cloning the layer shares it rather
/// than copying it.
#[derive(Clone)]
pub struct JwtLayer {
    /// Token extractors, tried in order; the first one that yields a token wins.
    sources: Arc<[Arc<dyn TokenSource>]>,
    /// Decodes and verifies the extracted token into [`Claims`].
    decoder: JwtDecoder,
    /// When present, enables the audience check and optional stateful session lookup.
    service: Option<JwtSessionService>,
}
impl JwtLayer {
    /// Creates a layer that only decodes/verifies tokens with `decoder`.
    ///
    /// No audience or session checks are performed because no [`JwtSessionService`] is
    /// attached. Tokens are read from the `Authorization: Bearer` header by default; use
    /// [`JwtLayer::with_sources`] to look elsewhere.
    pub fn new(decoder: JwtDecoder) -> Self {
        Self {
            decoder,
            sources: Self::default_sources(),
            service: None,
        }
    }

    /// Creates a layer backed by `service`, reusing the service's decoder.
    ///
    /// With a service attached, the middleware additionally requires `aud == "access"` and,
    /// when the service config enables `stateful_validation`, a matching stored session.
    pub fn from_service(service: JwtSessionService) -> Self {
        let decoder = service.decoder().clone();
        Self {
            service: Some(service),
            ..Self::new(decoder)
        }
    }

    /// Replaces (does not extend) the token sources.
    ///
    /// Sources are tried in the given order and the first one that yields a token wins.
    /// An empty list means no token is ever found, so every request is rejected with 401.
    #[must_use]
    pub fn with_sources(mut self, sources: Vec<Arc<dyn TokenSource>>) -> Self {
        self.sources = Arc::from(sources);
        self
    }

    /// Default source list shared by all constructors: a single [`BearerSource`].
    fn default_sources() -> Arc<[Arc<dyn TokenSource>]> {
        Arc::from(vec![Arc::new(BearerSource) as Arc<dyn TokenSource>])
    }
}
impl<Svc> Layer<Svc> for JwtLayer {
type Service = JwtMiddleware<Svc>;
fn layer(&self, inner: Svc) -> Self::Service {
JwtMiddleware {
inner,
decoder: self.decoder.clone(),
sources: self.sources.clone(),
service: self.service.clone(),
}
}
}
/// Service produced by [`JwtLayer`]: authenticates the request, then forwards it to `inner`.
pub struct JwtMiddleware<Svc> {
    /// Token extractors, tried in order (shared with the originating layer).
    sources: Arc<[Arc<dyn TokenSource>]>,
    /// Decodes and verifies the extracted token.
    decoder: JwtDecoder,
    /// Optional session service enabling audience and stateful-session checks.
    service: Option<JwtSessionService>,
    /// The wrapped service that receives authenticated requests.
    inner: Svc,
}
impl<Svc: Clone> Clone for JwtMiddleware<Svc> {
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
decoder: self.decoder.clone(),
sources: self.sources.clone(),
service: self.service.clone(),
}
}
}
impl<Svc> Service<Request<Body>> for JwtMiddleware<Svc>
where
Svc: Service<Request<Body>, Response = http::Response<Body>> + Clone + Send + 'static,
Svc::Future: Send + 'static,
Svc::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send + 'static,
{
type Response = http::Response<Body>;
type Error = Svc::Error;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, request: Request<Body>) -> Self::Future {
let decoder = self.decoder.clone();
let sources = self.sources.clone();
let service = self.service.clone();
let mut inner = self.inner.clone();
std::mem::swap(&mut self.inner, &mut inner);
Box::pin(async move {
let (mut parts, body) = request.into_parts();
let token = sources.iter().find_map(|s| s.extract(&parts));
let token = match token {
Some(t) => t,
None => {
let err = Error::unauthorized("unauthorized")
.chain(JwtError::MissingToken)
.with_code(JwtError::MissingToken.code());
return Ok(err.into_response());
}
};
let claims: Claims = match decoder.decode(&token) {
Ok(c) => c,
Err(e) => return Ok(e.into_response()),
};
if let Some(svc) = service {
if claims.aud.as_deref() != Some("access") {
let err = Error::unauthorized("unauthorized").with_code("auth:aud_mismatch");
return Ok(err.into_response());
}
if svc.config().stateful_validation {
let jti = match claims.jti.as_deref() {
Some(j) => j,
None => {
let err = Error::unauthorized("unauthorized")
.with_code("auth:session_not_found");
return Ok(err.into_response());
}
};
let session_token = match SessionToken::from_raw(jti) {
Some(t) => t,
None => {
let err = Error::unauthorized("unauthorized")
.with_code("auth:session_not_found");
return Ok(err.into_response());
}
};
let lookup = svc.store().read_by_token_hash(&session_token.hash()).await;
let raw = match lookup {
Err(e) => return Ok(e.into_response()),
Ok(None) => {
let err = Error::unauthorized("unauthorized")
.with_code("auth:session_not_found");
return Ok(err.into_response());
}
Ok(Some(row)) => row,
};
parts.extensions.insert(Session::from(raw));
}
}
parts.extensions.insert(claims);
let request = Request::from_parts(parts, body);
inner.call(request).await
})
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use http::{Response, StatusCode};
    use std::convert::Infallible;
    use tower::ServiceExt;
    use crate::auth::session::jwt::{Claims, JwtEncoder, JwtSessionsConfig};

    fn test_config() -> JwtSessionsConfig {
        JwtSessionsConfig {
            signing_secret: "test-secret-key-at-least-32-bytes-long!".into(),
            ..JwtSessionsConfig::default()
        }
    }

    fn now_secs() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs()
    }

    /// Encodes a token for `user_1` that expires one hour from now.
    fn make_token(config: &JwtSessionsConfig) -> String {
        let encoder = JwtEncoder::from_config(config);
        let claims = Claims::new().with_sub("user_1").with_exp(now_secs() + 3600);
        encoder.encode(&claims).unwrap()
    }

    /// Builds a request carrying `token` in the `Authorization: Bearer` header.
    fn bearer_request(token: &str) -> Request<Body> {
        Request::builder()
            .header("Authorization", format!("Bearer {token}"))
            .body(Body::empty())
            .unwrap()
    }

    /// Inner handler: 200 "ok" when the middleware inserted `Claims`, otherwise
    /// 500 "no-claims", so a status assertion alone proves claims were attached.
    async fn echo_handler(req: Request<Body>) -> Result<Response<Body>, Infallible> {
        let has_claims = req.extensions().get::<Claims>().is_some();
        let body = if has_claims { "ok" } else { "no-claims" };
        let mut resp = Response::new(Body::from(body));
        if !has_claims {
            *resp.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
        }
        Ok(resp)
    }

    #[tokio::test]
    async fn valid_token_passes_through() {
        let config = test_config();
        let decoder = JwtDecoder::from_config(&config);
        let token = make_token(&config);
        let layer = JwtLayer::new(decoder);
        let svc = layer.layer(tower::service_fn(echo_handler));
        let resp = svc.oneshot(bearer_request(&token)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn missing_header_returns_401() {
        let config = test_config();
        let decoder = JwtDecoder::from_config(&config);
        let layer = JwtLayer::new(decoder);
        let svc = layer.layer(tower::service_fn(echo_handler));
        let req = Request::builder().body(Body::empty()).unwrap();
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn expired_token_returns_401() {
        let config = test_config();
        let encoder = JwtEncoder::from_config(&config);
        let decoder = JwtDecoder::from_config(&config);
        let claims = Claims::new().with_exp(now_secs() - 10);
        let token = encoder.encode(&claims).unwrap();
        let layer = JwtLayer::new(decoder);
        let svc = layer.layer(tower::service_fn(echo_handler));
        let resp = svc.oneshot(bearer_request(&token)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn tampered_token_returns_401() {
        let config = test_config();
        let decoder = JwtDecoder::from_config(&config);
        let token = make_token(&config);
        // Flip one character in the middle of the signature segment.
        let dot = token.rfind('.').unwrap();
        let mid = dot + (token.len() - dot) / 2;
        let mut bytes = token.into_bytes();
        bytes[mid] = if bytes[mid] == b'A' { b'Z' } else { b'A' };
        let token = String::from_utf8(bytes).unwrap();
        let layer = JwtLayer::new(decoder);
        let svc = layer.layer(tower::service_fn(echo_handler));
        let resp = svc.oneshot(bearer_request(&token)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn token_signed_with_other_secret_returns_401() {
        let config = test_config();
        let other = JwtSessionsConfig {
            signing_secret: "a-completely-different-secret-of-32-plus-bytes".into(),
            ..JwtSessionsConfig::default()
        };
        let token = make_token(&other);
        let layer = JwtLayer::new(JwtDecoder::from_config(&config));
        let svc = layer.layer(tower::service_fn(echo_handler));
        let resp = svc.oneshot(bearer_request(&token)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn claims_inserted_into_extensions() {
        let config = test_config();
        let decoder = JwtDecoder::from_config(&config);
        let token = make_token(&config);
        let layer = JwtLayer::new(decoder);
        let inner = tower::service_fn(|req: Request<Body>| async move {
            let claims = req.extensions().get::<Claims>().unwrap();
            assert_eq!(claims.subject(), Some("user_1"));
            Ok::<_, Infallible>(Response::new(Body::empty()))
        });
        let svc = layer.layer(inner);
        let resp = svc.oneshot(bearer_request(&token)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn custom_token_source_works() {
        let config = test_config();
        let decoder = JwtDecoder::from_config(&config);
        let token = make_token(&config);
        let layer = JwtLayer::new(decoder).with_sources(vec![Arc::new(
            super::super::source::QuerySource("token"),
        ) as Arc<dyn TokenSource>]);
        let svc = layer.layer(tower::service_fn(echo_handler));
        let req = Request::builder()
            .uri(format!("/path?token={token}"))
            .body(Body::empty())
            .unwrap();
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn with_sources_replaces_default_bearer_source() {
        let config = test_config();
        let token = make_token(&config);
        let layer = JwtLayer::new(JwtDecoder::from_config(&config)).with_sources(vec![Arc::new(
            super::super::source::QuerySource("token"),
        )
            as Arc<dyn TokenSource>]);
        let svc = layer.layer(tower::service_fn(echo_handler));
        // Only the query source is configured, so a bearer header must be ignored.
        let resp = svc.oneshot(bearer_request(&token)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }
}