use std::sync::Arc;
use std::task::{Context, Poll};
use http::{Request, Response};
use tonic::body::Body;
use tower::{Layer, Service};
use crate::middleware::bearer::extract_bearer_token;
use crate::middleware::context::{AuthIdentity, RequestContext};
use crate::middleware::error::MiddlewareError;
use crate::middleware::stack::MiddlewareStack;
/// Tower [`Layer`] that wraps services in a [`GrpcAuthService`].
///
/// Every service produced by this layer shares the same [`MiddlewareStack`]
/// through the reference-counted handle held here.
#[derive(Clone)]
pub struct GrpcAuthLayer {
    /// Middleware stack run for every request seen by the wrapped services.
    stack: Arc<MiddlewareStack>,
}
impl GrpcAuthLayer {
pub fn new(stack: Arc<MiddlewareStack>) -> Self {
Self { stack }
}
}
impl<S> Layer<S> for GrpcAuthLayer {
type Service = GrpcAuthService<S>;
fn layer(&self, inner: S) -> Self::Service {
GrpcAuthService {
inner,
stack: self.stack.clone(),
}
}
}
/// Service produced by [`GrpcAuthLayer`].
///
/// Runs the shared [`MiddlewareStack`] on each request before handing the
/// request to `inner`.
#[derive(Clone)]
pub struct GrpcAuthService<S> {
    /// The wrapped service that receives requests which pass the stack.
    inner: S,
    /// Middleware stack shared with the originating layer.
    stack: Arc<MiddlewareStack>,
}
impl<S> Service<Request<Body>> for GrpcAuthService<S>
where
S: Service<Request<Body>, Response = Response<Body>> + Clone + Send + 'static,
S::Future: Send,
{
type Response = Response<Body>;
type Error = S::Error;
type Future = std::pin::Pin<
Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>,
>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, mut req: Request<Body>) -> Self::Future {
let stack = self.stack.clone();
let mut inner = self.inner.clone();
Box::pin(async move {
let headers = req.headers().clone();
if stack.is_empty() {
req.extensions_mut().insert(RequestContext {
bearer_token: None,
headers,
identity: AuthIdentity::Anonymous,
extensions: Default::default(),
});
return inner.call(req).await;
}
let bearer_token = headers
.get(http::header::AUTHORIZATION)
.and_then(|v| v.to_str().ok())
.and_then(extract_bearer_token);
let mut ctx = RequestContext {
bearer_token,
headers: headers.clone(),
identity: AuthIdentity::Anonymous,
extensions: Default::default(),
};
match stack.before_request(&mut ctx).await {
Ok(()) => {
req.extensions_mut().insert(ctx);
inner.call(req).await
}
Err(err) => Ok(middleware_error_to_grpc_response(&err)),
}
})
}
}
/// Converts a middleware rejection into an HTTP response carrying a gRPC status.
///
/// Mapping:
/// - `Unauthenticated` and `HttpChallenge` become `UNAUTHENTICATED`;
/// - `Forbidden` becomes `PERMISSION_DENIED`;
/// - `Internal` becomes `INTERNAL`.
///
/// The status message is the `Debug` rendering of `err`.
// NOTE(review): `{err:?}` is sent to the client verbatim. For `Internal` this
// may expose implementation details. Consider a generic message once the error
// can be logged server-side. Confirm what `Internal` carries.
fn middleware_error_to_grpc_response(err: &MiddlewareError) -> Response<Body> {
    let code = match err {
        MiddlewareError::Unauthenticated(_) | MiddlewareError::HttpChallenge { .. } => {
            tonic::Code::Unauthenticated
        }
        MiddlewareError::Forbidden(_) => tonic::Code::PermissionDenied,
        MiddlewareError::Internal(_) => tonic::Code::Internal,
    };
    tonic::Status::new(code, format!("{err:?}")).into_http()
}