use std::{
future::Future,
marker::PhantomData,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use super::{extract_bearer_token, AuthContext, AuthError, Authenticator};
/// Newtype wrapper around the claims `C` of an authenticated caller.
///
/// Implements `Deref<Target = C>`, so fields and methods of the claims are
/// reachable directly on the wrapper (e.g. `user.sub`).
#[derive(Debug, Clone)]
pub struct AuthenticatedUser<C>(pub C);

impl<C> AuthenticatedUser<C> {
    /// Borrows the wrapped claims.
    pub fn claims(&self) -> &C {
        let AuthenticatedUser(claims) = self;
        claims
    }

    /// Consumes the wrapper and hands back the claims by value.
    pub fn into_inner(self) -> C {
        let AuthenticatedUser(claims) = self;
        claims
    }
}

impl<C> std::ops::Deref for AuthenticatedUser<C> {
    type Target = C;

    /// Same as [`AuthenticatedUser::claims`].
    fn deref(&self) -> &C {
        self.claims()
    }
}
/// Tower layer that wraps inner services in an [`AuthService`] sharing one
/// authenticator.
///
/// The authenticator lives behind an [`Arc`], so cloning the layer (or building
/// services from it) never clones `A` itself.
pub struct AuthLayer<A> {
    authenticator: Arc<A>,
}

// Implemented by hand rather than derived: `#[derive(Clone)]` would demand
// `A: Clone` even though only the `Arc` handle is cloned here.
impl<A> Clone for AuthLayer<A> {
    fn clone(&self) -> Self {
        Self {
            authenticator: Arc::clone(&self.authenticator),
        }
    }
}

impl<A> AuthLayer<A> {
    /// Creates a layer whose services all share `authenticator`.
    pub fn new(authenticator: A) -> Self {
        Self {
            authenticator: Arc::new(authenticator),
        }
    }
}
impl<S, A> tower::Layer<S> for AuthLayer<A>
where
A: Clone,
{
type Service = AuthService<S, A>;
fn layer(&self, inner: S) -> Self::Service {
AuthService {
inner,
authenticator: self.authenticator.clone(),
}
}
}
/// Service produced by `AuthLayer` / `OptionalAuthLayer`: forwards requests to
/// `inner`, consulting the shared authenticator first.
pub struct AuthService<S, A> {
    inner: S,
    authenticator: Arc<A>,
}

// Implemented by hand rather than derived: `#[derive(Clone)]` would add an
// `A: Clone` bound although only the `Arc` handle is cloned. Only the inner
// service actually needs to be `Clone`.
impl<S: Clone, A> Clone for AuthService<S, A> {
    fn clone(&self) -> Self {
        Self {
            inner: self.inner.clone(),
            authenticator: Arc::clone(&self.authenticator),
        }
    }
}
impl<S, A, ReqBody> tower::Service<hyper::Request<ReqBody>> for AuthService<S, A>
where
S: tower::Service<hyper::Request<ReqBody>> + Clone + Send + 'static,
S::Future: Send,
A: Authenticator + 'static,
ReqBody: Send + 'static,
{
type Response = S::Response;
type Error = S::Error;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, mut req: hyper::Request<ReqBody>) -> Self::Future {
let authenticator = self.authenticator.clone();
let mut inner = self.inner.clone();
Box::pin(async move {
if let Some(auth_header) = req.headers().get(hyper::header::AUTHORIZATION) {
if let Ok(header_str) = auth_header.to_str() {
if let Some(token) = extract_bearer_token(header_str) {
if let Ok(claims) = authenticator.authenticate(token).await {
let ctx = AuthContext::new(claims, token);
req.extensions_mut().insert(ctx);
}
}
}
}
inner.call(req).await
})
}
}
/// Tower layer that builds the same [`AuthService`] as `AuthLayer`.
///
/// Requests are enriched with an `AuthContext` when authentication succeeds and
/// forwarded unchanged otherwise. The authenticator is shared behind an
/// [`Arc`], so cloning the layer never clones `A`.
pub struct OptionalAuthLayer<A> {
    authenticator: Arc<A>,
}

// Implemented by hand rather than derived: `#[derive(Clone)]` would demand
// `A: Clone` even though only the `Arc` handle is cloned here.
impl<A> Clone for OptionalAuthLayer<A> {
    fn clone(&self) -> Self {
        Self {
            authenticator: Arc::clone(&self.authenticator),
        }
    }
}

impl<A> OptionalAuthLayer<A> {
    /// Creates a layer whose services all share `authenticator`.
    pub fn new(authenticator: A) -> Self {
        Self {
            authenticator: Arc::new(authenticator),
        }
    }
}
impl<S, A> tower::Layer<S> for OptionalAuthLayer<A>
where
A: Clone,
{
type Service = AuthService<S, A>;
fn layer(&self, inner: S) -> Self::Service {
AuthService {
inner,
authenticator: self.authenticator.clone(),
}
}
}
/// Convenience accessors for authentication data attached to a request.
pub trait AuthExt {
    /// Returns the [`AuthContext`] stored for claims type `C`, if any.
    fn auth_context<C: Clone + Send + Sync + 'static>(&self) -> Option<&AuthContext<C>>;

    /// Returns just the claims of the stored [`AuthContext`] for `C`, if any.
    fn claims<C: Clone + Send + Sync + 'static>(&self) -> Option<&C> {
        let ctx = self.auth_context::<C>()?;
        Some(&ctx.claims)
    }
}
/// Looks the [`AuthContext`] up in the request's extensions map.
impl<B> AuthExt for hyper::Request<B> {
    fn auth_context<C: Clone + Send + Sync + 'static>(&self) -> Option<&AuthContext<C>> {
        let extensions = self.extensions();
        extensions.get::<AuthContext<C>>()
    }
}
/// Wrapper around an [`AuthError`] that can be used as a `std::error::Error`.
#[derive(Debug)]
pub struct AuthRejection {
    /// The wrapped authentication error.
    pub error: AuthError,
}

impl AuthRejection {
    /// Wraps `error` in a rejection.
    pub fn new(error: AuthError) -> Self {
        AuthRejection { error }
    }
}

impl std::fmt::Display for AuthRejection {
    /// Renders the wrapped error's `Display` output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let AuthRejection { error } = self;
        write!(f, "{error}")
    }
}

impl std::error::Error for AuthRejection {}
/// Type-level marker selecting the `Auth` accessors that assume claims exist.
#[derive(Clone, Copy, Debug)]
pub struct Required;

/// Type-level marker selecting the `Auth` accessors that return `Option`s.
#[derive(Clone, Copy, Debug)]
pub struct Optional;
/// Authentication state for a request, parameterised by a requirement marker.
///
/// `R` selects the accessor API:
/// - `Required` (the default): `claims()` and `token()` return values directly.
/// - `Optional`: `claims()` returns an `Option`, and `is_authenticated()` is
///   available.
#[derive(Debug, Clone)]
pub struct Auth<C, R = Required> {
    /// The authentication context, if one is present.
    pub context: Option<AuthContext<C>>,
    _requirement: PhantomData<R>,
}

impl<C: Clone> Auth<C, Required> {
    /// Returns the claims from the authentication context.
    ///
    /// # Panics
    ///
    /// Panics if `context` is `None`. A `Required` value is expected to carry a
    /// context, but `context` is a public field and can be cleared.
    pub fn claims(&self) -> &C {
        &self.required_context().claims
    }

    /// Returns the token stored in the authentication context.
    ///
    /// # Panics
    ///
    /// Panics if `context` is `None`, for the same reason as `claims`.
    pub fn token(&self) -> &str {
        self.required_context().token()
    }

    /// Enforces the `Required` invariant in one place. Both accessors then fail
    /// with the same descriptive message instead of a bare `unwrap`.
    fn required_context(&self) -> &AuthContext<C> {
        self.context
            .as_ref()
            .expect("Auth<_, Required> holds no AuthContext (context was None)")
    }
}

impl<C> Auth<C, Optional> {
    /// Returns the claims if an authentication context is present.
    pub fn claims(&self) -> Option<&C> {
        self.context.as_ref().map(|ctx| &ctx.claims)
    }

    /// Returns `true` if an authentication context is present.
    pub fn is_authenticated(&self) -> bool {
        self.context.is_some()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_authenticated_user() {
        #[derive(Clone, Debug, PartialEq)]
        struct Claims {
            sub: String,
        }

        let user = AuthenticatedUser(Claims {
            sub: String::from("user123"),
        });
        // The explicit accessor and Deref-based field access must agree.
        assert_eq!(user.claims().sub, "user123");
        assert_eq!(user.sub, "user123");

        let Claims { sub } = user.into_inner();
        assert_eq!(sub, "user123");
    }

    #[test]
    fn test_auth_rejection() {
        let message = AuthRejection::new(AuthError::MissingToken).to_string();
        assert!(message.contains("missing"));
    }

    #[test]
    fn test_auth_ext_trait() {
        #[derive(Clone)]
        struct Claims {
            sub: String,
        }

        let mut req = hyper::Request::builder().body(()).unwrap();

        // Nothing has been attached yet.
        assert!(req.auth_context::<Claims>().is_none());
        assert!(req.claims::<Claims>().is_none());

        let ctx = AuthContext::new(
            Claims {
                sub: String::from("user123"),
            },
            "token",
        );
        req.extensions_mut().insert(ctx);

        assert!(req.auth_context::<Claims>().is_some());
        let claims = req
            .claims::<Claims>()
            .expect("claims were just inserted into the extensions");
        assert_eq!(claims.sub, "user123");
    }
}