use {axum::{http::Request,
middleware::Next,
response::Response},
std::{future::Future,
pin::Pin,
task::{Context,
Poll}},
tower::{Layer,
Service}};
/// Tower [`Layer`] that wraps an inner service in an [`AuthMiddleware`].
///
/// The layer holds no state of its own; each call to [`Layer::layer`] simply
/// moves the given service into a new middleware value.
#[derive(Clone)]
pub struct AuthLayer;

impl<S> Layer<S> for AuthLayer {
    type Service = AuthMiddleware<S>;

    fn layer(&self, service: S) -> Self::Service {
        AuthMiddleware { inner: service }
    }
}
/// Tower [`Service`] produced by `AuthLayer` that wraps an inner service.
///
/// NOTE(review): as written, this middleware forwards every request to the
/// inner service unchanged — it performs no authentication check of its own.
/// Confirm whether it is meant to enforce the same rule as
/// `validate_bearer_token`.
#[derive(Clone)]
pub struct AuthMiddleware<S> {
    /// The wrapped service that every request is forwarded to.
    inner: S,
}

impl<S, ReqBody> Service<Request<ReqBody>> for AuthMiddleware<S>
where
    S: Service<Request<ReqBody>, Response = Response> + Clone + Send + 'static,
    S::Future: Send,
    ReqBody: Send + 'static,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    /// Readiness is delegated entirely to the inner service.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Forwards `req` to the inner service.
    fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
        // `poll_ready` was driven on `self.inner`, so that exact instance must
        // handle this request. Calling a fresh clone instead would use a service
        // that was never polled ready. Leave the clone behind for the next
        // `poll_ready`/`call` cycle and move the ready instance into the future.
        let replacement = self.inner.clone();
        let mut ready = std::mem::replace(&mut self.inner, replacement);
        Box::pin(async move { ready.call(req).await })
    }
}
/// Axum `from_fn`-style middleware that requires a `Bearer` credential in the
/// `Authorization` request header before running the rest of the stack.
///
/// # Errors
///
/// Returns [`axum::http::StatusCode::UNAUTHORIZED`] when:
/// - the `Authorization` header is missing;
/// - the header cannot be read as a string (`HeaderValue::to_str` fails);
/// - the header does not start with `"Bearer "`;
/// - the token after the prefix is empty or only whitespace.
pub async fn validate_bearer_token(
    req: Request<axum::body::Body>,
    next: Next,
) -> Result<Response, axum::http::StatusCode> {
    let token = req
        .headers()
        .get(axum::http::header::AUTHORIZATION)
        .and_then(|value| value.to_str().ok())
        .and_then(|value| value.strip_prefix("Bearer "))
        .map(str::trim)
        // A bare "Bearer " with nothing after it is not a credential.
        .filter(|token| !token.is_empty());

    match token {
        // NOTE(review): the token is only checked for presence; it is never
        // verified (signature, expiry, lookup). Any non-empty token is accepted.
        Some(_token) => Ok(next.run(req).await),
        None => Err(axum::http::StatusCode::UNAUTHORIZED),
    }
}