use std::{marker::PhantomData, sync::Arc};
use tonic::{metadata::MetadataMap, Request, Status};
use super::{extract_bearer_token, AuthContext, AuthError, Authenticator};
/// gRPC interceptor that authenticates requests using the bearer token carried in the
/// `authorization` metadata entry.
///
/// `A` is the authenticator and `C` is the claims type it produces on success.
pub struct AuthInterceptor<A, C> {
    /// Authenticator shared (via `Arc`) between all clones of this interceptor.
    authenticator: Arc<A>,
    /// When `true`, requests whose token is missing or rejected are refused; when `false`,
    /// they are forwarded without an auth context attached.
    required: bool,
    /// Records the claims type `C` in the interceptor's type without storing a value of it.
    _phantom: PhantomData<C>,
}
impl<A, C> Clone for AuthInterceptor<A, C> {
fn clone(&self) -> Self {
Self {
authenticator: self.authenticator.clone(),
required: self.required,
_phantom: PhantomData,
}
}
}
impl<A, C> AuthInterceptor<A, C>
where
A: Authenticator<Claims = C>,
C: Clone + Send + Sync + 'static,
{
pub fn new(authenticator: A) -> Self {
Self {
authenticator: Arc::new(authenticator),
required: true,
_phantom: PhantomData,
}
}
pub fn required(mut self) -> Self {
self.required = true;
self
}
pub fn optional(mut self) -> Self {
self.required = false;
self
}
fn authenticate_request<T>(&self, request: &Request<T>) -> Result<Option<C>, Status> {
let metadata = request.metadata();
let token = match extract_token_from_metadata(metadata) {
Some(token) => token,
None => {
if self.required {
return Err(auth_error_to_status(AuthError::MissingToken));
}
return Ok(None);
}
};
let authenticator = self.authenticator.clone();
let token_owned = token.to_string();
let result = tokio::task::block_in_place(|| {
tokio::runtime::Handle::current()
.block_on(async { authenticator.authenticate(&token_owned).await })
});
match result {
Ok(claims) => Ok(Some(claims)),
Err(e) => {
if self.required {
Err(auth_error_to_status(e))
} else {
Ok(None)
}
}
}
}
}
impl<A, C> tonic::service::Interceptor for AuthInterceptor<A, C>
where
    A: Authenticator<Claims = C> + 'static,
    C: Clone + Send + Sync + 'static,
{
    /// Authenticates `request` and, when claims are produced, stores an [`AuthContext`]
    /// holding the claims and the raw token in the request's extensions.
    ///
    /// Requests that yield no claims (optional mode) are forwarded untouched; failures in
    /// required mode short-circuit with the error `Status`.
    fn call(&mut self, mut request: Request<()>) -> Result<Request<()>, Status> {
        if let Some(claims) = self.authenticate_request(&request)? {
            // Claims only exist when a token was extracted above, so the empty-string
            // fallback is not expected to be hit.
            let raw_token = extract_token_from_metadata(request.metadata())
                .unwrap_or_default()
                .to_owned();
            request
                .extensions_mut()
                .insert(AuthContext::new(claims, raw_token));
        }
        Ok(request)
    }
}
fn extract_token_from_metadata(metadata: &MetadataMap) -> Option<&str> {
metadata
.get("authorization")
.and_then(|v| v.to_str().ok())
.and_then(extract_bearer_token)
}
fn auth_error_to_status(error: AuthError) -> Status {
match error {
AuthError::MissingToken => Status::unauthenticated("missing authentication token"),
AuthError::InvalidToken(msg) => Status::unauthenticated(format!("invalid token: {}", msg)),
AuthError::TokenExpired => Status::unauthenticated("token has expired"),
AuthError::InvalidSignature => Status::unauthenticated("invalid token signature"),
AuthError::InvalidIssuer => Status::unauthenticated("invalid token issuer"),
AuthError::InvalidAudience => Status::unauthenticated("invalid token audience"),
AuthError::ValidationFailed(msg) => {
Status::permission_denied(format!("validation failed: {}", msg))
}
AuthError::Internal(msg) => Status::internal(format!("auth error: {}", msg)),
}
}
/// Convenience accessors for authentication data attached to a gRPC request.
pub trait GrpcAuthExt<T> {
    /// Looks up the [`AuthContext`] carrying claims of type `C`, if one is attached.
    fn auth_context<C: Clone + Send + Sync + 'static>(&self) -> Option<&AuthContext<C>>;

    /// Returns only the claims from [`auth_context`](Self::auth_context), if present.
    fn claims<C: Clone + Send + Sync + 'static>(&self) -> Option<&C> {
        let context = self.auth_context::<C>()?;
        Some(&context.claims)
    }

    /// Returns the claims of type `C`, treating their absence as an error.
    ///
    /// # Errors
    ///
    /// `Status::unauthenticated("authentication required")` when no claims of type `C`
    /// are attached.
    fn require_auth<C: Clone + Send + Sync + 'static>(&self) -> Result<&C, Status> {
        match self.claims::<C>() {
            Some(found) => Ok(found),
            None => Err(Status::unauthenticated("authentication required")),
        }
    }
}
impl<T> GrpcAuthExt<T> for Request<T> {
    /// Fetches the [`AuthContext<C>`] from this request's extensions; lookups are keyed
    /// by type, so a context stored with a different claims type is not returned.
    fn auth_context<C: Clone + Send + Sync + 'static>(&self) -> Option<&AuthContext<C>> {
        let extensions = self.extensions();
        extensions.get()
    }
}
/// Builds a closure interceptor that authenticates requests with `authenticator`.
///
/// With `required == true`, requests lacking a valid bearer token are refused; with
/// `required == false`, they are forwarded without an [`AuthContext`]. The returned
/// closure is `Clone`; clones share the same authenticator.
pub fn auth_interceptor<A, C>(
    authenticator: A,
    required: bool,
) -> impl FnMut(Request<()>) -> Result<Request<()>, Status> + Clone
where
    A: Authenticator<Claims = C> + 'static,
    C: Clone + Send + Sync + 'static,
{
    let base = AuthInterceptor::new(authenticator);
    let mut interceptor = if required {
        base.required()
    } else {
        base.optional()
    };
    // The interceptor keeps no per-request state, so the captured instance is called
    // directly instead of being cloned (an atomic Arc refcount round-trip) per request.
    move |req| tonic::service::Interceptor::call(&mut interceptor, req)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_extract_token_from_metadata() {
        let mut metadata = MetadataMap::new();
        assert!(extract_token_from_metadata(&metadata).is_none());
        metadata.insert("authorization", "Bearer test123".parse().unwrap());
        assert_eq!(extract_token_from_metadata(&metadata), Some("test123"));
    }

    #[test]
    fn test_extract_token_reads_only_authorization_key() {
        // A bearer-looking value under any other key must not be picked up.
        let mut metadata = MetadataMap::new();
        metadata.insert("x-api-key", "Bearer test123".parse().unwrap());
        assert!(extract_token_from_metadata(&metadata).is_none());
    }

    #[test]
    fn test_auth_error_to_status() {
        // Every AuthError variant, with the exact code and message sent to clients.
        let cases = [
            (
                AuthError::MissingToken,
                tonic::Code::Unauthenticated,
                "missing authentication token",
            ),
            (
                AuthError::InvalidToken("bad".into()),
                tonic::Code::Unauthenticated,
                "invalid token: bad",
            ),
            (
                AuthError::TokenExpired,
                tonic::Code::Unauthenticated,
                "token has expired",
            ),
            (
                AuthError::InvalidSignature,
                tonic::Code::Unauthenticated,
                "invalid token signature",
            ),
            (
                AuthError::InvalidIssuer,
                tonic::Code::Unauthenticated,
                "invalid token issuer",
            ),
            (
                AuthError::InvalidAudience,
                tonic::Code::Unauthenticated,
                "invalid token audience",
            ),
            (
                AuthError::ValidationFailed("test".into()),
                tonic::Code::PermissionDenied,
                "validation failed: test",
            ),
            (
                AuthError::Internal("error".into()),
                tonic::Code::Internal,
                "auth error: error",
            ),
        ];
        for (error, code, message) in cases {
            let status = auth_error_to_status(error);
            assert_eq!(status.code(), code);
            assert_eq!(status.message(), message);
        }
    }

    #[test]
    fn test_grpc_auth_ext() {
        #[derive(Clone)]
        struct Claims {
            sub: String,
        }
        #[derive(Clone)]
        struct OtherClaims;

        let mut request = Request::new(());
        assert!(request.auth_context::<Claims>().is_none());
        assert!(request.claims::<Claims>().is_none());
        assert!(request.require_auth::<Claims>().is_err());
        request.extensions_mut().insert(AuthContext::new(
            Claims {
                sub: "user123".to_string(),
            },
            "token",
        ));
        assert!(request.auth_context::<Claims>().is_some());
        assert_eq!(request.claims::<Claims>().unwrap().sub, "user123");
        assert!(request.require_auth::<Claims>().is_ok());
        // Extension lookups are keyed by claims type.
        assert!(request.claims::<OtherClaims>().is_none());
        assert!(request.require_auth::<OtherClaims>().is_err());
    }
}