use crate::types::{AuthProvider, GrpcRequestParts, RequestParts};
use futures::future::BoxFuture;
use micromegas_tracing::prelude::*;
use std::sync::Arc;
use tonic::Status;
use tower::Service;
/// Tower middleware that authenticates requests before forwarding them to `inner`.
///
/// Cloning is cheap for the provider (an `Arc` bump); `inner` is cloned via its own
/// `Clone` impl.
#[derive(Clone)]
pub struct AuthService<S> {
    /// The wrapped service that receives requests after authentication.
    pub inner: S,
    /// Validator for incoming requests; `None` disables authentication and every request
    /// is passed straight through to `inner`.
    pub auth_provider: Option<Arc<dyn AuthProvider>>,
}
impl<S> Service<http::Request<tonic::body::Body>> for AuthService<S>
where
S: Service<http::Request<tonic::body::Body>> + Clone + Send + 'static,
S::Response: 'static,
S::Future: Send + 'static,
S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
{
type Response = S::Response;
type Error = Box<dyn std::error::Error + Send + Sync>;
type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
fn poll_ready(
&mut self,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx).map_err(Into::into)
}
fn call(&mut self, req: http::Request<tonic::body::Body>) -> Self::Future {
let clone = self.inner.clone();
let mut inner = std::mem::replace(&mut self.inner, clone);
let auth_provider = self.auth_provider.clone();
Box::pin(async move {
if let Some(provider) = auth_provider {
let (mut parts, body) = req.into_parts();
let request_parts = GrpcRequestParts {
metadata: tonic::metadata::MetadataMap::from_headers(parts.headers.clone()),
};
match provider
.validate_request(&request_parts as &dyn RequestParts)
.await
{
Ok(auth_ctx) => {
info!(
"authenticated: subject={} email={:?} issuer={} admin={}",
auth_ctx.subject, auth_ctx.email, auth_ctx.issuer, auth_ctx.is_admin
);
parts.headers.insert(
"x-user-id",
http::HeaderValue::from_str(&auth_ctx.subject)
.expect("valid user id header"),
);
if let Some(email) = &auth_ctx.email {
parts.headers.insert(
"x-user-email",
http::HeaderValue::from_str(email).expect("valid email header"),
);
}
parts.headers.insert(
"x-user-issuer",
http::HeaderValue::from_str(&auth_ctx.issuer)
.expect("valid issuer header"),
);
parts.extensions.insert(auth_ctx);
let req = http::Request::from_parts(parts, body);
inner.call(req).await.map_err(Into::into)
}
Err(e) => {
warn!("authentication failed: {e}");
Err(Box::new(Status::unauthenticated("invalid token"))
as Box<dyn std::error::Error + Send + Sync>)
}
}
} else {
inner.call(req).await.map_err(Into::into)
}
})
}
}