use http::{Request, Response};
use std::future::Future;
use tower_async_layer::Layer;
use tower_async_service::Service;
/// [`Layer`] that wraps services in [`AsyncRequireAuthorization`], giving each wrapped
/// service its own clone of `auth`.
#[derive(Clone, Debug)]
pub struct AsyncRequireAuthorizationLayer<T> {
    /// Authorizer handed to every service this layer produces.
    auth: T,
}
impl<T> AsyncRequireAuthorizationLayer<T> {
pub fn new(auth: T) -> AsyncRequireAuthorizationLayer<T> {
Self { auth }
}
}
/// Wraps `inner` in an [`AsyncRequireAuthorization`] that owns a clone of this layer's
/// authorizer; `layer` only borrows `self`, so the same layer can wrap many services.
impl<S, T> Layer<S> for AsyncRequireAuthorizationLayer<T>
where
    T: Clone,
{
    type Service = AsyncRequireAuthorization<S, T>;

    fn layer(&self, inner: S) -> Self::Service {
        let auth = T::clone(&self.auth);
        AsyncRequireAuthorization::new(inner, auth)
    }
}
/// Middleware that runs `auth` on each request before handing it to `inner`.
///
/// Build it directly with [`AsyncRequireAuthorization::new`] or through
/// [`AsyncRequireAuthorizationLayer`].
#[derive(Debug, Clone)]
pub struct AsyncRequireAuthorization<S, T> {
    /// The wrapped service that receives authorized requests.
    inner: S,
    /// The authorizer consulted before every call to `inner`.
    auth: T,
}
impl<S, T> AsyncRequireAuthorization<S, T> {
    // Project-local macro that, going by its name, generates accessor methods for the
    // wrapped `inner` service. NOTE(review): exact method set (e.g. get_ref/get_mut/
    // into_inner) is defined by the macro elsewhere in the crate — confirm there.
    define_inner_service_accessors!();
}
impl<S, T> AsyncRequireAuthorization<S, T> {
pub fn new(inner: S, auth: T) -> AsyncRequireAuthorization<S, T> {
Self { inner, auth }
}
pub fn layer(auth: T) -> AsyncRequireAuthorizationLayer<T> {
AsyncRequireAuthorizationLayer::new(auth)
}
}
/// Runs the authorizer, then forwards the authorized request to the inner service.
///
/// If [`AsyncAuthorizeRequest::authorize`] resolves to `Err(response)`, that response is
/// returned as `Ok(response)` and the inner service is never called: a rejection is an
/// ordinary HTTP response, not a service error.
impl<ReqBody, ResBody, S, Auth> Service<Request<ReqBody>> for AsyncRequireAuthorization<S, Auth>
where
    Auth: AsyncAuthorizeRequest<ReqBody, ResponseBody = ResBody>,
    // `call` only needs `&self.inner`, so the inner service is not required to be `Clone`.
    S: Service<Request<Auth::RequestBody>, Response = Response<ResBody>>,
{
    type Response = Response<ResBody>;
    type Error = S::Error;

    async fn call(&self, req: Request<ReqBody>) -> Result<Self::Response, Self::Error> {
        let req = match self.auth.authorize(req).await {
            Ok(req) => req,
            // Short-circuit with the authorizer's response without touching `inner`.
            Err(res) => return Ok(res),
        };
        self.inner.call(req).await
    }
}
pub trait AsyncAuthorizeRequest<B> {
type RequestBody;
type ResponseBody;
fn authorize(
&self,
request: Request<B>,
) -> impl std::future::Future<
Output = Result<Request<Self::RequestBody>, Response<Self::ResponseBody>>,
>;
}
/// Lets any function or closure `Fn(Request<B>) -> Fut` act as an authorizer, where `Fut`
/// resolves to `Ok(request)` to allow the request or `Err(response)` to reject it.
impl<B, F, Fut, ReqBody, ResBody> AsyncAuthorizeRequest<B> for F
where
    F: Fn(Request<B>) -> Fut,
    Fut: Future<Output = Result<Request<ReqBody>, Response<ResBody>>>,
{
    type RequestBody = ReqBody;
    type ResponseBody = ResBody;

    async fn authorize(
        &self,
        request: Request<B>,
    ) -> Result<Request<Self::RequestBody>, Response<Self::ResponseBody>> {
        // Call the wrapped function and await the future it returns.
        let verdict = (self)(request);
        verdict.await
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::Body;
    use http::{header, StatusCode};
    use tower_async::{BoxError, ServiceBuilder};

    /// Test authorizer: accepts only `Authorization: Bearer 69420` and, on success,
    /// attaches a [`UserId`] extension to the request before it is forwarded.
    #[derive(Clone, Copy)]
    struct MyAuth;

    impl<B> AsyncAuthorizeRequest<B> for MyAuth
    where
        B: Send + 'static,
    {
        type RequestBody = B;
        type ResponseBody = Body;

        async fn authorize(
            &self,
            mut request: Request<B>,
        ) -> Result<Request<Self::RequestBody>, Response<Self::ResponseBody>> {
            let authorized = request
                .headers()
                .get(header::AUTHORIZATION)
                .and_then(|value| value.to_str().ok())
                .and_then(|value| value.strip_prefix("Bearer "))
                == Some("69420");

            if authorized {
                request.extensions_mut().insert(UserId("6969".to_owned()));
                Ok(request)
            } else {
                Err(Response::builder()
                    .status(StatusCode::UNAUTHORIZED)
                    .body(Body::empty())
                    .unwrap())
            }
        }
    }

    /// Request extension inserted by [`MyAuth`] for authorized requests.
    #[derive(Debug, Clone)]
    struct UserId(String);

    #[tokio::test]
    async fn require_async_auth_works() {
        let service = ServiceBuilder::new()
            .layer(AsyncRequireAuthorizationLayer::new(MyAuth))
            .service_fn(echo);

        let request = Request::get("/")
            .header(header::AUTHORIZATION, "Bearer 69420")
            .body(Body::empty())
            .unwrap();

        let res = service.call(request).await.unwrap();

        assert_eq!(res.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn require_async_auth_forwards_extensions() {
        let service = ServiceBuilder::new()
            .layer(AsyncRequireAuthorizationLayer::new(MyAuth))
            .service_fn(|req: Request<Body>| async move {
                // The extension added by the authorizer must reach the inner service.
                let user_id = req
                    .extensions()
                    .get::<UserId>()
                    .expect("authorized request should carry a UserId extension");
                assert_eq!(user_id.0, "6969");
                Ok::<_, BoxError>(Response::new(req.into_body()))
            });

        let request = Request::get("/")
            .header(header::AUTHORIZATION, "Bearer 69420")
            .body(Body::empty())
            .unwrap();

        let res = service.call(request).await.unwrap();

        assert_eq!(res.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn require_async_auth_401() {
        let service = ServiceBuilder::new()
            .layer(AsyncRequireAuthorizationLayer::new(MyAuth))
            .service_fn(echo);

        let request = Request::get("/")
            .header(header::AUTHORIZATION, "Bearer deez")
            .body(Body::empty())
            .unwrap();

        let res = service.call(request).await.unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn require_async_auth_401_without_header() {
        let service = ServiceBuilder::new()
            .layer(AsyncRequireAuthorizationLayer::new(MyAuth))
            .service_fn(echo);

        let request = Request::get("/").body(Body::empty()).unwrap();

        let res = service.call(request).await.unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }

    /// Inner service that echoes the request body back as the response body.
    async fn echo<B>(req: Request<B>) -> Result<Response<B>, BoxError> {
        Ok(Response::new(req.into_body()))
    }
}