use std::future::Future;
use std::marker::PhantomData;
use std::pin::Pin;
use std::task::{Context, Poll};
use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode};
use serde::{Deserialize, Serialize};
use tower_layer::Layer;
use tower_service::Service;
/// Claims carried by a validated JWT.
///
/// After successful authentication, `AuthMiddleware` inserts this value into the
/// request's extensions so downstream handlers can read it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Claims {
/// JWT `sub` (subject) claim identifying the authenticated principal.
pub sub: String,
/// Expiry time as seconds since the Unix epoch; enforced by the JWT decoder.
pub exp: u64,
/// Role names carried by the token; not inspected by the code in this file.
pub roles: Vec<String>,
}
/// Reasons an incoming request can fail authentication.
///
/// Each variant's `Display` text (from the `#[error]` attribute) is used as the
/// gRPC status message when the request is rejected.
#[derive(Debug, thiserror::Error)]
pub enum AuthError {
/// No authorization token was supplied with the request.
#[error("missing authorization token")]
MissingToken,
/// A token was supplied but could not be accepted; the payload describes why.
#[error("invalid token: {0}")]
InvalidToken(String),
/// The token was well-formed and correctly signed but has expired.
#[error("token has expired")]
Expired,
/// Generic authorization failure; not constructed anywhere in this file's visible code.
#[error("unauthorized")]
Unauthorized,
}
impl From<AuthError> for tonic::Status {
    /// Maps every [`AuthError`] to a gRPC `UNAUTHENTICATED` status whose message is
    /// the error's `Display` text.
    fn from(err: AuthError) -> Self {
        let message = err.to_string();
        Self::unauthenticated(message)
    }
}
/// Asynchronously checks a raw credential and yields the authenticated [`Claims`].
///
/// `AuthMiddleware` passes the request's `Authorization` header value unchanged
/// (including any scheme prefix such as `Bearer `). The trait is object-safe and can
/// be used as `dyn AuthValidator`.
pub trait AuthValidator: Send + Sync {
/// Validates `token`, resolving to its claims or an [`AuthError`].
///
/// The returned future may borrow both `self` and `token` for `'a`.
fn validate<'a>(
&'a self,
token: &'a str,
) -> Pin<Box<dyn Future<Output = Result<Claims, AuthError>> + Send + 'a>>;
}
/// [`AuthValidator`] for `Bearer` JWTs signed with a shared HS256 (HMAC-SHA256) secret.
#[derive(Clone)]
pub struct BearerTokenValidator {
/// Key built from the shared secret, used to verify token signatures.
decoding_key: DecodingKey,
/// Decoding rules (accepted algorithm, expiry/audience checks, leeway).
validation: Validation,
}
impl BearerTokenValidator {
    /// Creates a validator that verifies HS256-signed JWTs with `secret`.
    ///
    /// Audience (`aud`) checking is disabled, and no clock-skew leeway is allowed when
    /// the decoder checks expiry.
    pub fn new(secret: &[u8]) -> Self {
        let validation = {
            let mut rules = Validation::new(Algorithm::HS256);
            rules.leeway = 0;
            rules.validate_aud = false;
            rules
        };
        let decoding_key = DecodingKey::from_secret(secret);
        Self {
            decoding_key,
            validation,
        }
    }
}
impl AuthValidator for BearerTokenValidator {
    /// Validates an `Authorization` header value of the form `Bearer <jwt>`.
    ///
    /// The scheme name is matched case-insensitively (RFC 7235 §2.1) and may be followed
    /// by one or more spaces (RFC 6750 §2.1). The JWT is then decoded and verified with
    /// the configured key and validation rules.
    ///
    /// # Errors
    ///
    /// * [`AuthError::InvalidToken`] if the value does not use the Bearer scheme, or if
    ///   the JWT fails to decode/verify for any reason other than expiry.
    /// * [`AuthError::Expired`] if the JWT library reports
    ///   `ErrorKind::ExpiredSignature`.
    fn validate<'a>(
        &'a self,
        token: &'a str,
    ) -> Pin<Box<dyn Future<Output = Result<Claims, AuthError>> + Send + 'a>> {
        // The future is allowed to borrow `self` and `token` for `'a`, so there is no need
        // to clone the decoding key / validation rules or copy the token per request.
        Box::pin(async move {
            let jwt = match token.split_once(' ') {
                Some((scheme, rest)) if scheme.eq_ignore_ascii_case("Bearer") => {
                    rest.trim_start_matches(' ')
                }
                _ => return Err(AuthError::InvalidToken("not a Bearer token".to_string())),
            };
            decode::<Claims>(jwt, &self.decoding_key, &self.validation)
                .map(|token_data| token_data.claims)
                .map_err(|e| {
                    use jsonwebtoken::errors::ErrorKind;
                    match e.kind() {
                        ErrorKind::ExpiredSignature => AuthError::Expired,
                        _ => AuthError::InvalidToken(e.to_string()),
                    }
                })
        })
    }
}
/// Tower [`Layer`] that wraps a service in [`AuthMiddleware`], giving each wrapped
/// service its own clone of the configured validator.
///
/// Trait bounds live on the `impl` blocks rather than on the struct, so the type can be
/// named and stored without repeating them (Rust API guidelines, C-STRUCT-BOUNDS).
#[derive(Clone)]
pub struct AuthMiddlewareLayer<V> {
    validator: V,
}
impl<V> AuthMiddlewareLayer<V>
where
V: AuthValidator + Clone + Send + Sync + 'static,
{
pub fn new(validator: V) -> Self {
Self { validator }
}
}
impl<S, V> Layer<S> for AuthMiddlewareLayer<V>
where
    V: AuthValidator + Clone + Send + Sync + 'static,
{
    type Service = AuthMiddleware<S, V>;

    /// Wraps `inner` so every request is authenticated before it reaches it.
    fn layer(&self, inner: S) -> Self::Service {
        let validator = V::clone(&self.validator);
        AuthMiddleware { validator, inner }
    }
}
/// Tower [`Service`] built by [`AuthMiddlewareLayer`].
///
/// It validates each request's `Authorization` header before forwarding the request to
/// `inner`. Trait bounds live on the `impl` blocks rather than on the struct
/// (Rust API guidelines, C-STRUCT-BOUNDS).
#[derive(Clone)]
pub struct AuthMiddleware<S, V> {
    /// The wrapped service that handles authenticated requests.
    inner: S,
    /// Validator applied to every incoming request.
    validator: V,
}
impl<S, B, ResBody, V> Service<http::Request<B>> for AuthMiddleware<S, V>
where
    S: Service<http::Request<B>, Response = http::Response<ResBody>> + Clone + Send + 'static,
    S::Future: Send + 'static,
    S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
    B: Send + 'static,
    ResBody: Default + Send + 'static,
    V: AuthValidator + Clone + Send + Sync + 'static,
{
    type Response = http::Response<ResBody>;
    type Error = S::Error;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    /// Readiness is delegated entirely to the wrapped service.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Authenticates `req` and, on success, forwards it to the inner service.
    ///
    /// The validated [`Claims`] are inserted into the request extensions. On failure,
    /// the request never reaches the inner service: the future resolves to `Ok` with a
    /// gRPC `UNAUTHENTICATED` response instead of an error.
    fn call(&mut self, mut req: http::Request<B>) -> Self::Future {
        // Distinguish an absent header from one that is present but contains bytes
        // `HeaderValue::to_str` rejects, so the rejection message says what went wrong.
        let token_result = match req.headers().get(http::header::AUTHORIZATION) {
            None => Err(AuthError::MissingToken),
            Some(value) => value.to_str().map(str::to_owned).map_err(|_| {
                AuthError::InvalidToken(
                    "authorization header contains invalid characters".to_string(),
                )
            }),
        };
        let validator = self.validator.clone();
        // `self.inner` is the instance `poll_ready` prepared. Move it into the future and
        // leave a fresh clone behind for the next request (standard tower pattern).
        let replacement = self.inner.clone();
        let mut inner = std::mem::replace(&mut self.inner, replacement);
        Box::pin(async move {
            let token = match token_result {
                Ok(t) => t,
                Err(e) => return Ok(build_unauthenticated_response(e.to_string())),
            };
            match validator.validate(&token).await {
                Ok(claims) => {
                    req.extensions_mut().insert(claims);
                    inner.call(req).await
                }
                Err(e) => Ok(build_unauthenticated_response(e.to_string())),
            }
        })
    }
}
/// Builds a response that rejects a gRPC call with `UNAUTHENTICATED` (code 16).
///
/// gRPC reports call outcomes over HTTP 200 via the `grpc-status` / `grpc-message`
/// headers, so the HTTP status is `200 OK` and the body is `ResBody::default()`.
fn build_unauthenticated_response<ResBody: Default>(message: String) -> http::Response<ResBody> {
    let status = tonic::Status::unauthenticated(message);
    let mut response = http::Response::new(ResBody::default());
    // gRPC errors travel in headers; the HTTP status itself is always 200.
    *response.status_mut() = http::StatusCode::OK;
    let headers = response.headers_mut();
    // Strict gRPC clients validate the content type before reading `grpc-status`.
    headers.insert(
        http::header::CONTENT_TYPE,
        http::HeaderValue::from_static("application/grpc"),
    );
    headers.insert("grpc-status", http::HeaderValue::from(status.code() as i32));
    // The gRPC spec requires `grpc-message` to be percent-encoded; clients decode it.
    let encoded = percent_encode_grpc_message(status.message());
    if let Ok(value) = http::HeaderValue::from_str(&encoded) {
        headers.insert("grpc-message", value);
    }
    response
}

/// Percent-encodes `message` for the `grpc-message` header.
///
/// Per the gRPC-over-HTTP/2 spec, bytes 0x20..=0x7E pass through unchanged except `%`.
/// Every other byte, including each byte of a multi-byte UTF-8 character, becomes `%XX`.
fn percent_encode_grpc_message(message: &str) -> String {
    use std::fmt::Write as _;
    let mut out = String::with_capacity(message.len());
    for &byte in message.as_bytes() {
        if (0x20..=0x7E).contains(&byte) && byte != b'%' {
            out.push(char::from(byte));
        } else {
            // Writing to a String cannot fail.
            let _ = write!(out, "%{byte:02X}");
        }
    }
    out
}
#[cfg(test)]
mod tests {
    use super::*;
    use jsonwebtoken::{Algorithm, EncodingKey, Header, encode};
    use std::time::{SystemTime, UNIX_EPOCH};
    use tower_service::Service as _;

    /// Signs an HS256 JWT for `test-user` whose `exp` is `exp_offset_secs` from now.
    /// Negative offsets produce an already-expired token.
    fn make_jwt(secret: &[u8], exp_offset_secs: i64) -> String {
        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("system clock before UNIX epoch")
            .as_secs();
        let exp = if exp_offset_secs >= 0 {
            now + exp_offset_secs as u64
        } else {
            now.saturating_sub(exp_offset_secs.unsigned_abs())
        };
        let claims = Claims {
            sub: "test-user".to_string(),
            exp,
            roles: vec!["admin".to_string()],
        };
        encode(
            &Header::new(Algorithm::HS256),
            &claims,
            &EncodingKey::from_secret(secret),
        )
        .expect("JWT encoding should not fail in tests")
    }

    /// Returns the `grpc-status` header as text, or `"missing"` if absent or unreadable.
    fn grpc_status_of<B>(resp: &http::Response<B>) -> &str {
        resp.headers()
            .get("grpc-status")
            .and_then(|v| v.to_str().ok())
            .unwrap_or("missing")
    }

    /// Validator that accepts any token and returns fixed claims.
    #[derive(Clone)]
    struct AlwaysOkValidator;

    impl AuthValidator for AlwaysOkValidator {
        fn validate<'a>(
            &'a self,
            _token: &'a str,
        ) -> Pin<Box<dyn Future<Output = Result<Claims, AuthError>> + Send + 'a>> {
            Box::pin(async {
                Ok(Claims {
                    sub: "always-ok".to_string(),
                    exp: u64::MAX,
                    roles: vec![],
                })
            })
        }
    }

    /// Inner service that echoes the authenticated subject (or `no-claims`) as the body.
    #[derive(Clone)]
    struct EchoService;

    impl Service<http::Request<String>> for EchoService {
        type Response = http::Response<String>;
        type Error = Box<dyn std::error::Error + Send + Sync>;
        type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, req: http::Request<String>) -> Self::Future {
            let claims = req.extensions().get::<Claims>().cloned();
            Box::pin(async move {
                let body = match claims {
                    Some(c) => format!("sub={}", c.sub),
                    None => "no-claims".to_string(),
                };
                Ok(http::Response::new(body))
            })
        }
    }

    #[tokio::test]
    async fn test_auth_rejects_missing_token() {
        let layer = AuthMiddlewareLayer::new(AlwaysOkValidator);
        let mut svc = layer.layer(EchoService);
        let req = http::Request::builder()
            .body(String::new())
            .expect("request builder should not fail");
        let resp = svc.call(req).await.expect("service call should not error");
        assert_eq!(
            grpc_status_of(&resp),
            "16",
            "expected grpc-status=16 (UNAUTHENTICATED)"
        );
    }

    #[tokio::test]
    async fn test_auth_accepts_valid_token() {
        let secret = b"supersecret";
        let jwt = make_jwt(secret, 3600);
        let bearer = format!("Bearer {jwt}");
        let layer = AuthMiddlewareLayer::new(BearerTokenValidator::new(secret));
        let mut svc = layer.layer(EchoService);
        let req = http::Request::builder()
            .header(http::header::AUTHORIZATION, &bearer)
            .body(String::new())
            .expect("request builder should not fail");
        let resp = svc.call(req).await.expect("service call should not error");
        assert_eq!(resp.body(), "sub=test-user");
        assert!(
            resp.headers().get("grpc-status").is_none(),
            "should not have grpc-status on success"
        );
    }

    #[tokio::test]
    async fn test_auth_rejects_invalid_token_via_middleware() {
        let jwt = make_jwt(b"signing-secret", 3600);
        let bearer = format!("Bearer {jwt}");
        let layer = AuthMiddlewareLayer::new(BearerTokenValidator::new(b"different-secret"));
        let mut svc = layer.layer(EchoService);
        let req = http::Request::builder()
            .header(http::header::AUTHORIZATION, &bearer)
            .body(String::new())
            .expect("request builder should not fail");
        let resp = svc.call(req).await.expect("service call should not error");
        assert_eq!(
            grpc_status_of(&resp),
            "16",
            "expected grpc-status=16 (UNAUTHENTICATED)"
        );
        assert!(
            resp.body().is_empty(),
            "inner service must not be reached when authentication fails"
        );
    }

    #[tokio::test]
    async fn test_auth_custom_validator() {
        let layer = AuthMiddlewareLayer::new(AlwaysOkValidator);
        let mut svc = layer.layer(EchoService);
        let req = http::Request::builder()
            .header(http::header::AUTHORIZATION, "Bearer anything")
            .body(String::new())
            .expect("request builder should not fail");
        let resp = svc.call(req).await.expect("service call should not error");
        assert_eq!(resp.body(), "sub=always-ok");
    }

    #[tokio::test]
    async fn test_bearer_validator_rejects_non_bearer_prefix() {
        let validator = BearerTokenValidator::new(b"secret");
        let result = validator.validate("Token abc123").await;
        assert!(
            matches!(result, Err(AuthError::InvalidToken(_))),
            "expected InvalidToken for non-Bearer scheme"
        );
    }

    #[tokio::test]
    async fn test_bearer_validator_rejects_wrong_secret() {
        let jwt = make_jwt(b"signing-secret", 3600);
        let bearer = format!("Bearer {jwt}");
        let validator = BearerTokenValidator::new(b"different-secret");
        let result = validator.validate(&bearer).await;
        assert!(
            matches!(result, Err(AuthError::InvalidToken(_))),
            "expected InvalidToken for a JWT signed with another secret, got: {:?}",
            result
        );
    }

    #[tokio::test]
    async fn test_bearer_validator_rejects_expired() {
        let secret = b"expiry-test-secret";
        let jwt = make_jwt(secret, -1);
        let bearer = format!("Bearer {jwt}");
        let validator = BearerTokenValidator::new(secret);
        let result = validator.validate(&bearer).await;
        assert!(
            matches!(result, Err(AuthError::Expired)),
            "expected Expired error for expired JWT, got: {:?}",
            result
        );
    }

    #[test]
    fn test_auth_validator_is_object_safe() {
        let validator = BearerTokenValidator::new(b"test");
        let _dyn_ref: &dyn AuthValidator = &validator;
    }

    #[test]
    fn test_layer_construction() {
        let _layer: AuthMiddlewareLayer<AlwaysOkValidator> =
            AuthMiddlewareLayer::new(AlwaysOkValidator);
    }
}