use core::{
mem,
task::{Context, Poll},
};
use http::{Request, Response};
use http_body::Body;
use tower_layer::Layer;
use tower_service::Service;
use crate::middleware::futures::{UndefinedFuture, undefined::DefiningFuture};
/// Tower [`Layer`] that wraps a service in an [`AuthorizationService`], which
/// invokes the `authorize` callback on every incoming request.
///
/// `authorize` is a plain `fn` pointer, so the layer is trivially cheap to
/// clone. `Fut` is whatever future the callback returns; it must satisfy the
/// project's `DefiningFuture<B>` trait (see that trait for the expected output).
pub struct Authorization<B, Fut>
where
    Fut: DefiningFuture<B>,
{
    authorize: fn(Request<B>) -> Fut,
}

// Written by hand instead of `#[derive(Debug)]`: the derive would require
// `B: Debug` and `Fut: Debug`, and the anonymous futures returned by
// `async fn` authorizers never implement `Debug`, so the derived impl would
// almost never apply. A `fn` pointer is `Debug` unconditionally.
impl<B, Fut> core::fmt::Debug for Authorization<B, Fut>
where
    Fut: DefiningFuture<B>,
{
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.debug_struct("Authorization")
            .field("authorize", &self.authorize)
            .finish()
    }
}
impl<B, Fut> Clone for Authorization<B, Fut>
where
Fut: DefiningFuture<B>,
{
fn clone(&self) -> Self {
Self {
authorize: self.authorize,
}
}
}
impl<B, Fut> Authorization<B, Fut>
where
Fut: DefiningFuture<B>,
{
pub fn new(authorize: fn(Request<B>) -> Fut) -> Self {
Self { authorize }
}
}
impl<Svc, B, Fut> Layer<Svc> for Authorization<B, Fut>
where
    Svc: Clone,
    Fut: DefiningFuture<B>,
{
    type Service = AuthorizationService<Svc, B, Fut>;

    /// Wraps `inner` in an [`AuthorizationService`] that carries a copy of
    /// this layer's authorization callback.
    fn layer(&self, inner: Svc) -> Self::Service {
        let auth = self.clone();
        AuthorizationService { auth, inner }
    }
}
/// [`Service`] produced by the [`Authorization`] layer.
///
/// On each `call` it invokes the authorization callback with the request and
/// passes the resulting future, together with the wrapped `inner` service,
/// to `UndefinedFuture`.
pub struct AuthorizationService<Svc, B, Fut>
where
    Svc: Clone,
    Fut: DefiningFuture<B>,
{
    inner: Svc,
    auth: Authorization<B, Fut>,
}

// Written by hand instead of `#[derive(Debug)]`: the derive would require
// `B: Debug` and `Fut: Debug`, and futures returned by `async fn`
// authorizers never implement `Debug`. Only the inner service needs `Debug`
// here; the callback is printed as its (always-`Debug`) `fn` pointer.
impl<Svc, B, Fut> core::fmt::Debug for AuthorizationService<Svc, B, Fut>
where
    Svc: Clone + core::fmt::Debug,
    Fut: DefiningFuture<B>,
{
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.debug_struct("AuthorizationService")
            .field("inner", &self.inner)
            .field("authorize", &self.auth.authorize)
            .finish()
    }
}
// Hand-written so that only `Svc: Clone` is needed; a derive would also
// demand `B: Clone` and `Fut: Clone`.
impl<Svc, B, Fut> Clone for AuthorizationService<Svc, B, Fut>
where
    Svc: Clone,
    Fut: DefiningFuture<B>,
{
    fn clone(&self) -> Self {
        let Self { inner, auth } = self;
        Self {
            inner: inner.clone(),
            auth: auth.clone(),
        }
    }
}
impl<Svc, B, Fut> AuthorizationService<Svc, B, Fut>
where
Svc: Clone,
Fut: DefiningFuture<B>,
{
pub fn new(inner: Svc, auth: Authorization<B, Fut>) -> Self {
Self { inner, auth }
}
}
impl<Svc, ReqBody, ResBody, Fut> Service<Request<ReqBody>>
    for AuthorizationService<Svc, ReqBody, Fut>
where
    Svc: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone,
    ResBody: Body + Send + Default,
    ReqBody: Send + 'static,
    Fut: DefiningFuture<ReqBody>,
{
    type Response = Svc::Response;
    type Error = Svc::Error;
    type Future = UndefinedFuture<Svc, ReqBody>;

    /// Readiness is delegated entirely to the wrapped service.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Starts the authorization callback for `request` and hands its future,
    /// together with the ready inner service, to `UndefinedFuture::define`.
    ///
    /// NOTE(review): the ordering "authorize, then call `inner`" lives inside
    /// `UndefinedFuture` (not visible here); confirm there.
    fn call(&mut self, request: Request<ReqBody>) -> Self::Future {
        // `authorize` is a `Copy` fn pointer, so call it directly instead of
        // cloning the whole `Authorization` first.
        let auth_future = (self.auth.authorize)(request);

        // Tower contract: `poll_ready` was driven on `self.inner`, so that
        // exact instance is the one known to be ready. Move it into the
        // returned future and leave a fresh clone behind for the next
        // `poll_ready`/`call` cycle.
        let fresh = self.inner.clone();
        let ready = mem::replace(&mut self.inner, fresh);

        UndefinedFuture::define(Box::pin(auth_future), ready)
    }
}
#[cfg(test)]
mod test {
    use axum::{Router, routing::get};
    use http::{Request, StatusCode};
    use tower::ServiceExt;
    use ts_token::jwt::TokenType;

    use crate::{
        middleware::{authorization::Authorization, test::get_request},
        test::ResponseTestExt,
    };

    /// A pass-through authorizer must let the request reach the route handler.
    #[tokio::test]
    async fn axum() {
        // Authorizer that accepts every request unchanged.
        async fn allow_all<B>(request: Request<B>) -> Result<Request<B>, StatusCode> {
            Ok(request)
        }

        let layer = Authorization::new(allow_all);
        let request = get_request(Some(TokenType::Common));

        let router = Router::new()
            .route("/resource/id", get(|| async move { StatusCode::OK }))
            .layer(layer);

        let response = router.oneshot(request).await.unwrap();
        response.expect_status(StatusCode::OK);
    }
}