use std::sync::Arc;
use std::task::{Context, Poll};
use futures_util::future::BoxFuture;
use http::{Request, Response};
use tonic::Status;
use tonic::body::BoxBody;
use tower::{Layer, Service};
use super::{AuthCtx, AuthError, CURRENT_AUTH, TokenExtractor, TokenVerifier};
/// Tower [`Layer`] that authenticates each incoming request before handing it
/// to the wrapped service.
///
/// Cloning is cheap: the extractor and verifier are shared behind [`Arc`]s.
#[derive(Clone)]
pub struct AuthLayer {
    /// When `true`, a missing token yields an anonymous context instead of an error.
    optional: bool,
    /// Reads the raw token out of the request metadata.
    extractor: Arc<dyn TokenExtractor>,
    /// Validates an extracted token and produces the caller's auth context.
    verifier: Arc<dyn TokenVerifier>,
}
impl AuthLayer {
    /// Creates a layer that rejects any request lacking a valid token.
    ///
    /// `extractor` reads the token from request metadata and `verifier`
    /// validates it. Both are stored behind `Arc`s, so the layer (and every
    /// service it produces) can be cloned cheaply.
    #[must_use]
    pub fn new<E, V>(extractor: E, verifier: V) -> Self
    where
        E: TokenExtractor,
        V: TokenVerifier,
    {
        Self {
            extractor: Arc::new(extractor),
            verifier: Arc::new(verifier),
            optional: false,
        }
    }

    /// Switches the layer to optional authentication.
    ///
    /// Requests for which the extractor reports `AuthError::MissingToken`
    /// proceed with `AuthCtx::anonymous()`. Requests that carry a token which
    /// fails verification, or that hit any other extraction error, are still
    /// rejected.
    // `must_use`: this consumes `self`; ignoring the return value drops the layer.
    #[must_use = "`optional` consumes the layer and returns the reconfigured one"]
    pub fn optional(mut self) -> Self {
        self.optional = true;
        self
    }
}
impl<S> Layer<S> for AuthLayer {
type Service = AuthService<S>;
fn layer(&self, inner: S) -> Self::Service {
AuthService {
inner,
extractor: self.extractor.clone(),
verifier: self.verifier.clone(),
optional: self.optional,
}
}
}
/// Service produced by [`AuthLayer`]: authenticates a request, then forwards
/// it to `inner` with the resulting auth context attached.
#[derive(Clone)]
pub struct AuthService<S> {
    /// Shared token extractor (same instance as the originating layer).
    extractor: Arc<dyn TokenExtractor>,
    /// Shared token verifier (same instance as the originating layer).
    verifier: Arc<dyn TokenVerifier>,
    /// When `true`, a missing token yields an anonymous context instead of an error.
    optional: bool,
    /// The wrapped downstream service.
    inner: S,
}
impl<S> Service<Request<BoxBody>> for AuthService<S>
where
S: Service<Request<BoxBody>, Response = Response<BoxBody>> + Clone + Send + 'static,
S::Error: Send + 'static,
S::Future: Send + 'static,
{
type Response = Response<BoxBody>;
type Error = S::Error;
type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, mut req: Request<BoxBody>) -> Self::Future {
let mut inner = self.inner.clone();
let extractor = self.extractor.clone();
let verifier = self.verifier.clone();
let optional = self.optional;
Box::pin(async move {
let metadata = metadata_from_headers(req.headers());
let ctx = match extractor.extract(&metadata) {
Ok(token) => match verifier.verify(&token).await {
Ok(ctx) => ctx,
Err(e) => return Ok(error_response(e)),
},
Err(AuthError::MissingToken) if optional => AuthCtx::anonymous(),
Err(e) => return Ok(error_response(e)),
};
req.extensions_mut().insert(ctx.clone());
CURRENT_AUTH.scope(ctx, inner.call(req)).await
})
}
}
/// Builds a tonic `MetadataMap` from a copy of the given HTTP headers, leaving
/// the request's own header map untouched.
fn metadata_from_headers(h: &http::HeaderMap) -> tonic::metadata::MetadataMap {
    let owned_headers = http::HeaderMap::clone(h);
    tonic::metadata::MetadataMap::from_headers(owned_headers)
}
/// Converts an authentication failure into a gRPC error response by way of
/// tonic's `Status`.
fn error_response(e: AuthError) -> Response<BoxBody> {
    Into::<Status>::into(e).into_http()
}