use crate::{Request, Response};
use rama_core::{Context, Layer, Service};
use rama_utils::macros::define_inner_service_accessors;
/// Layer that wraps a service in [`AsyncRequireAuthorization`].
///
/// The layer stores an authorizer of type `T` and hands it to every service
/// it produces: cloned in `layer`, moved in `into_layer`.
#[derive(Debug, Clone)]
pub struct AsyncRequireAuthorizationLayer<T> {
/// Authorizer passed on to each wrapped service.
auth: T,
}
impl<T> AsyncRequireAuthorizationLayer<T> {
pub const fn new(auth: T) -> AsyncRequireAuthorizationLayer<T> {
Self { auth }
}
}
impl<S, T> Layer<S> for AsyncRequireAuthorizationLayer<T>
where
T: Clone,
{
type Service = AsyncRequireAuthorization<S, T>;
fn layer(&self, inner: S) -> Self::Service {
AsyncRequireAuthorization::new(inner, self.auth.clone())
}
fn into_layer(self, inner: S) -> Self::Service {
AsyncRequireAuthorization::new(inner, self.auth)
}
}
/// Middleware service that runs an asynchronous authorizer on every request
/// before handing it to the inner service.
///
/// When the authorizer rejects a request, the response it produced is
/// returned and the inner service is not called.
#[derive(Clone, Debug)]
pub struct AsyncRequireAuthorization<S, T> {
/// Service that receives requests accepted by the authorizer.
inner: S,
/// Authorizer consulted for each incoming request.
auth: T,
}
impl<S, T> AsyncRequireAuthorization<S, T> {
pub const fn new(inner: S, auth: T) -> AsyncRequireAuthorization<S, T> {
Self { inner, auth }
}
define_inner_service_accessors!();
}
impl<ReqBody, ResBody, S, State, Auth> Service<State, Request<ReqBody>>
    for AsyncRequireAuthorization<S, Auth>
where
    Auth: AsyncAuthorizeRequest<State, ReqBody, ResponseBody = ResBody> + Send + Sync + 'static,
    S: Service<State, Request<Auth::RequestBody>, Response = Response<ResBody>>,
    ReqBody: Send + 'static,
    ResBody: Send + 'static,
    State: Clone + Send + Sync + 'static,
{
    type Response = Response<ResBody>;
    type Error = S::Error;

    /// Authorizes the request, then either forwards it or rejects it.
    ///
    /// * If the authorizer accepts, the context and request it returns are
    ///   served by the inner service and that result is returned unchanged.
    /// * If the authorizer rejects, its response is returned as `Ok(..)`.
    ///   A rejection is a regular response, not a service error, so any
    ///   `Err` can only come from the inner service.
    async fn serve(
        &self,
        ctx: Context<State>,
        req: Request<ReqBody>,
    ) -> Result<Self::Response, Self::Error> {
        match self.auth.authorize(ctx, req).await {
            Ok((authorized_ctx, authorized_req)) => {
                self.inner.serve(authorized_ctx, authorized_req).await
            }
            Err(rejection) => Ok(rejection),
        }
    }
}
/// Asynchronously decides whether a request may proceed.
///
/// `S` is the state type carried by the [`Context`] and `B` is the body type
/// of the incoming request.
pub trait AsyncAuthorizeRequest<S, B> {
/// Body type of the request handed on after a successful authorization.
type RequestBody;
/// Body type of the response returned when authorization fails.
type ResponseBody;
/// Authorizes `request`.
///
/// Resolves to `Ok((ctx, request))` with the context and request to
/// continue with, or to `Err(response)` with the response to send back
/// instead. The returned future must be `Send` and may borrow from `self`.
fn authorize(
&self,
ctx: Context<S>,
request: Request<B>,
) -> impl Future<
Output = Result<(Context<S>, Request<Self::RequestBody>), Response<Self::ResponseBody>>,
> + Send
+ '_;
}
/// Lets any suitable function or closure act as an authorizer.
///
/// The callable receives the context and request, and returns a future that
/// resolves to either the `(context, request)` pair to continue with or the
/// rejection response.
impl<S, B, F, Fut, ReqBody, ResBody> AsyncAuthorizeRequest<S, B> for F
where
    F: Fn(Context<S>, Request<B>) -> Fut + Send + Sync + 'static,
    Fut:
        Future<Output = Result<(Context<S>, Request<ReqBody>), Response<ResBody>>> + Send + 'static,
    B: Send + 'static,
    S: Clone + Send + Sync + 'static,
    ReqBody: Send + 'static,
    ResBody: Send + 'static,
{
    type RequestBody = ReqBody;
    type ResponseBody = ResBody;

    /// Calls the wrapped function and awaits its verdict.
    async fn authorize(
        &self,
        ctx: Context<S>,
        request: Request<B>,
    ) -> Result<(Context<S>, Request<Self::RequestBody>), Response<Self::ResponseBody>> {
        // The callable is only invoked once this future is polled.
        let verdict = self(ctx, request);
        verdict.await
    }
}
#[cfg(test)]
mod tests {
    #[allow(unused_imports)]
    use super::*;
    use crate::{Body, StatusCode, header};
    use rama_core::error::BoxError;
    use rama_core::service::service_fn;

    /// Test authorizer that accepts exactly one hard-coded bearer token.
    #[derive(Clone, Copy)]
    struct MyAuth;

    impl<S, B> AsyncAuthorizeRequest<S, B> for MyAuth
    where
        S: Clone + Send + Sync + 'static,
        B: Send + 'static,
    {
        type RequestBody = B;
        type ResponseBody = Body;

        /// Accepts requests carrying `Authorization: Bearer 69420` and
        /// records a `UserId` in the context; anything else gets an empty
        /// `401 Unauthorized` response.
        async fn authorize(
            &self,
            mut ctx: Context<S>,
            request: Request<B>,
        ) -> Result<(Context<S>, Request<Self::RequestBody>), Response<Self::ResponseBody>>
        {
            let bearer_token = request
                .headers()
                .get(header::AUTHORIZATION)
                .and_then(|value: &rama_http_types::HeaderValue| value.to_str().ok())
                .and_then(|value| value.strip_prefix("Bearer "));

            // A missing header, non-UTF-8 value, wrong scheme or wrong token
            // all end up here.
            if bearer_token != Some("69420") {
                return Err(Response::builder()
                    .status(StatusCode::UNAUTHORIZED)
                    .body(Body::empty())
                    .unwrap());
            }

            ctx.insert(UserId("6969".to_owned()));
            Ok((ctx, request))
        }
    }

    /// Marker value stored in the context for authorized requests.
    #[derive(Debug, Clone)]
    #[allow(dead_code)]
    struct UserId(String);

    /// Inner service used by the tests: replies with the request body.
    async fn echo<B>(req: Request<B>) -> Result<Response<B>, BoxError> {
        let body = req.into_body();
        Ok(Response::new(body))
    }

    #[tokio::test]
    async fn require_async_auth_works() {
        let layer = AsyncRequireAuthorizationLayer::new(MyAuth);
        let svc = layer.layer(service_fn(echo));

        let req = Request::get("/")
            .header(header::AUTHORIZATION, "Bearer 69420")
            .body(Body::empty())
            .unwrap();

        let response = svc.serve(Context::default(), req).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn require_async_auth_401() {
        let layer = AsyncRequireAuthorizationLayer::new(MyAuth);
        let svc = layer.layer(service_fn(echo));

        let req = Request::get("/")
            .header(header::AUTHORIZATION, "Bearer deez")
            .body(Body::empty())
            .unwrap();

        let response = svc.serve(Context::default(), req).await.unwrap();
        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }
}