use futures_core::ready;
use futures_util::future::{self, BoxFuture};
use http::Request;
use jsonwebtoken::TokenData;
use pin_project::pin_project;
use serde::de::DeserializeOwned;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use tower_layer::Layer;
use tower_service::Service;
use crate::authorizer::Authorizer;
use crate::AuthError;
/// Asynchronous authorization step applied to an incoming request.
///
/// Implementations take ownership of the request and return a future that either yields the
/// request to forward, or an [`AuthError`] describing why it was rejected.
pub trait Authorize<B> {
    /// Future returned by [`Authorize::authorize`].
    type Future: Future<Output = Result<Request<B>, AuthError>>;

    /// Begins authorizing `request`.
    fn authorize(&self, request: Request<B>) -> Self::Future;
}
impl<S, B, C> Authorize<B> for AuthorizationService<S, C>
where
    B: Send + 'static,
    C: Clone + DeserializeOwned + Send + Sync + 'static,
{
    type Future = BoxFuture<'static, Result<Request<B>, AuthError>>;

    /// Checks `request` against every authorizer able to extract a token from its headers.
    ///
    /// Token extraction happens eagerly, when this method is called. The returned future then
    /// checks each extracted token in order and stops at the first success.
    ///
    /// The future resolves to:
    /// - `Ok(request)` with the decoded `TokenData<C>` inserted into its extensions, on the first
    ///   successful check;
    /// - `Err(AuthError::MissingToken())` when no authorizer extracted a token;
    /// - otherwise, the error returned by the last authorizer tried.
    fn authorize(&self, mut request: Request<B>) -> Self::Future {
        // Pair each extracted token with its authorizer now. The returned future is `'static`,
        // so it cannot borrow `self`; it owns these `Arc` handles instead.
        let candidates: Vec<(String, Arc<Authorizer<C>>)> = self
            .auths
            .iter()
            .filter_map(|authorizer| {
                authorizer
                    .extract_token(request.headers())
                    .map(|token| (token, Arc::clone(authorizer)))
            })
            .collect();

        if candidates.is_empty() {
            return Box::pin(future::ready(Err(AuthError::MissingToken())));
        }

        Box::pin(async move {
            // Always overwritten below, because `candidates` is non-empty here.
            let mut outcome: Result<TokenData<C>, AuthError> = Err(AuthError::NoAuthorizer());
            for (token, authorizer) in candidates {
                outcome = authorizer.check_auth(token.as_str()).await;
                if outcome.is_ok() {
                    break;
                }
            }
            outcome.map(|token_data| {
                request.extensions_mut().insert(token_data);
                request
            })
        })
    }
}
/// Tower layer that wraps a service in an [`AuthorizationService`] backed by a shared list of
/// authorizers.
#[derive(Clone)]
pub struct AuthorizationLayer<C: Clone + DeserializeOwned + Send> {
    /// Authorizers handed (as `Arc` clones) to every service this layer wraps.
    auths: Vec<Arc<Authorizer<C>>>,
}
impl<C> AuthorizationLayer<C>
where
C: Clone + DeserializeOwned + Send,
{
pub fn new(auths: Vec<Arc<Authorizer<C>>>) -> AuthorizationLayer<C> {
Self { auths }
}
}
impl<S, C> Layer<S> for AuthorizationLayer<C>
where
    C: Clone + DeserializeOwned + Send + Sync,
{
    type Service = AuthorizationService<S, C>;

    /// Wraps `inner` in an [`AuthorizationService`].
    ///
    /// Cloning the `Vec<Arc<_>>` only bumps reference counts. The authorizers themselves are
    /// shared, not copied.
    fn layer(&self, inner: S) -> Self::Service {
        let shared_auths = self.auths.clone();
        AuthorizationService::new(inner, shared_auths)
    }
}
/// Location in an incoming request from which the JWT is read.
///
/// Derives `Debug`/`PartialEq`/`Eq` so the configuration can be logged and compared.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum JwtSource {
    /// The `Authorization` request header (presumably a `Bearer` token; confirm against the
    /// token-extraction code).
    AuthorizationHeader,
    /// A cookie. The `String` is presumably the cookie name; confirm against the extraction
    /// code.
    Cookie(String),
}
/// Middleware service that authorizes each request before forwarding it to `inner`.
#[derive(Clone)]
pub struct AuthorizationService<S, C: Clone + DeserializeOwned + Send> {
    /// Wrapped service that receives requests which passed authorization.
    pub inner: S,
    /// Authorizers consulted, in order, for each request.
    pub auths: Vec<Arc<Authorizer<C>>>,
}
impl<S, C: Clone + DeserializeOwned + Send> AuthorizationService<S, C> {
    /// Returns a shared reference to the wrapped service.
    pub fn get_ref(&self) -> &S {
        let Self { inner, .. } = self;
        inner
    }

    /// Returns a mutable reference to the wrapped service.
    pub fn get_mut(&mut self) -> &mut S {
        let Self { inner, .. } = self;
        inner
    }

    /// Consumes the middleware and returns the wrapped service. The authorizers are dropped.
    pub fn into_inner(self) -> S {
        let Self { inner, .. } = self;
        inner
    }
}
impl<S, C> AuthorizationService<S, C>
where
C: Clone + DeserializeOwned + Send + Sync,
{
pub fn new(inner: S, auths: Vec<Arc<Authorizer<C>>>) -> AuthorizationService<S, C> {
Self { inner, auths }
}
}
impl<S, C, B> Service<Request<B>> for AuthorizationService<S, C>
where
    B: Send + 'static,
    S: Service<Request<B>> + Clone,
    S::Response: From<AuthError>,
    C: Clone + DeserializeOwned + Send + Sync + 'static,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = ResponseFuture<S, C, B>;

    /// Readiness is delegated entirely to the wrapped service.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Starts authorizing `req`.
    ///
    /// The returned future forwards the request to the inner service on success. On failure it
    /// resolves to `Ok` of the `AuthError` converted into a response.
    fn call(&mut self, req: Request<B>) -> Self::Future {
        // `poll_ready` readied *this* instance, not an arbitrary clone. Leave a fresh clone
        // behind and move the ready instance into the response future.
        let replacement = self.inner.clone();
        let ready_service = std::mem::replace(&mut self.inner, replacement);
        ResponseFuture {
            state: State::Authorize {
                auth_fut: self.authorize(req),
            },
            service: ready_service,
        }
    }
}
/// Future returned by [`AuthorizationService`].
///
/// It first drives the authorization future, then the inner service's call.
#[pin_project]
pub struct ResponseFuture<S, C, B>
where
    S: Service<Request<B>>,
    B: Send + 'static,
    C: Clone + DeserializeOwned + Send + Sync + 'static,
{
    /// Current phase: authorizing, or awaiting the inner service.
    #[pin]
    state: State<<AuthorizationService<S, C> as Authorize<B>>::Future, S::Future>,
    /// Ready service instance taken in `call`, invoked once authorization succeeds.
    service: S,
}
/// Phases of a `ResponseFuture`. Both futures are structurally pinned.
#[pin_project(project = StateProj)]
enum State<AuthFut, SvcFut> {
    /// Waiting for the authorization future to resolve.
    Authorize { #[pin] auth_fut: AuthFut },
    /// Authorization succeeded; waiting for the inner service's response future.
    Authorized { #[pin] svc_fut: SvcFut },
}
impl<S, C, B> Future for ResponseFuture<S, C, B>
where
    B: Send,
    S: Service<Request<B>>,
    S::Response: From<AuthError>,
    C: Clone + DeserializeOwned + Send + Sync,
{
    type Output = Result<S::Response, S::Error>;

    /// Drives the two-phase state machine.
    ///
    /// While authorizing, a rejection is logged and turned into a successful response via
    /// `From<AuthError>`. An accepted request is passed to the inner service, and the loop
    /// immediately polls the resulting future.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut this = self.project();
        loop {
            let svc_fut = match this.state.as_mut().project() {
                StateProj::Authorized { svc_fut } => return svc_fut.poll(cx),
                StateProj::Authorize { auth_fut } => match ready!(auth_fut.poll(cx)) {
                    Ok(req) => this.service.call(req),
                    Err(rejection) => {
                        tracing::info!("err: {:?}", rejection);
                        return Poll::Ready(Ok(rejection.into()));
                    }
                },
            };
            this.state.set(State::Authorized { svc_fut });
        }
    }
}
#[cfg(test)]
mod tests {
    use super::AuthorizationLayer;
    use crate::{authorizer::Authorizer, IntoLayer, JwtAuthorizer, RegisteredClaims};

    /// A single authorizer converts into a layer holding exactly one authorizer.
    #[tokio::test]
    async fn auth_into_layer() {
        let authorizer: Authorizer = JwtAuthorizer::from_secret("aaa").build().await.unwrap();
        let layer = authorizer.into_layer();
        assert_eq!(layer.auths.len(), 1);
    }

    /// An array of authorizers converts into a layer holding all of them.
    #[tokio::test]
    async fn auths_into_layer() {
        let first = JwtAuthorizer::from_secret("aaa").build().await.unwrap();
        let second = JwtAuthorizer::from_secret("bbb").build().await.unwrap();
        let layer: AuthorizationLayer<RegisteredClaims> = [first, second].into_layer();
        assert_eq!(layer.auths.len(), 2);
    }

    /// A `Vec` of authorizers converts into a layer holding all of them.
    #[tokio::test]
    async fn vec_auths_into_layer() {
        let first = JwtAuthorizer::from_secret("aaa").build().await.unwrap();
        let second = JwtAuthorizer::from_secret("bbb").build().await.unwrap();
        let layer: AuthorizationLayer<RegisteredClaims> = vec![first, second].into_layer();
        assert_eq!(layer.auths.len(), 2);
    }

    /// The legacy `JwtAuthorizer::layer` entry point still yields a one-authorizer layer.
    #[tokio::test]
    async fn jwt_auth_to_layer() {
        let builder: JwtAuthorizer = JwtAuthorizer::from_secret("aaa");
        // Intentionally exercises the deprecated `layer()` API to keep it covered.
        #[allow(deprecated)]
        let layer = builder.layer().await.unwrap();
        assert_eq!(layer.auths.len(), 1);
    }
}