use crate::config::AuthConfig;
use crate::error::{ServiceError, ServiceResult};
use super::{AuthContext, JwtClaims};
use axum::body::Body;
use axum::response::{IntoResponse, Response};
use connectrpc::{ConnectError, ErrorCode};
use jsonwebtoken::{DecodingKey, EncodingKey, Validation};
use std::{
future::Future,
pin::Pin,
sync::Arc,
task::{Context as TaskContext, Poll},
};
use tower::{Layer, Service};
/// Validates and issues HS256-signed JWTs using a shared secret.
///
/// `Debug` is implemented by hand (not derived) so that the signing secret is
/// never rendered into logs, panic messages, or the `Debug` output of types
/// that embed this validator.
#[derive(Clone)]
pub struct JwtValidator {
    /// Shared HMAC secret used both to sign and to verify tokens.
    secret: String,
    /// Lifetime, in seconds, of tokens minted by `create_user_token`.
    expiry: u64,
    /// When set, required as the `iss` claim on validation and written into
    /// every token this validator creates.
    issuer: Option<String>,
    /// When set, required as the `aud` claim on validation and written into
    /// every token this validator creates.
    audience: Option<String>,
}

impl std::fmt::Debug for JwtValidator {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("JwtValidator")
            // Never print the real secret; it is enough to know one is configured.
            .field("secret", &"<redacted>")
            .field("expiry", &self.expiry)
            .field("issuer", &self.issuer)
            .field("audience", &self.audience)
            .finish()
    }
}
impl JwtValidator {
pub fn new(config: AuthConfig) -> Self {
Self {
secret: config.jwt_secret,
expiry: config.token_expiry,
issuer: None,
audience: None,
}
}
pub fn with_secret(secret: impl Into<String>) -> Self {
Self {
secret: secret.into(),
expiry: 3600, issuer: None,
audience: None,
}
}
pub fn with_issuer(mut self, issuer: impl Into<String>) -> Self {
self.issuer = Some(issuer.into());
self
}
pub fn with_audience(mut self, audience: impl Into<String>) -> Self {
self.audience = Some(audience.into());
self
}
pub fn validate(&self, token: &str) -> ServiceResult<JwtClaims> {
let decoding_key = DecodingKey::from_secret(self.secret.as_ref());
let mut validation = Validation::new(jsonwebtoken::Algorithm::HS256);
if let Some(ref issuer) = self.issuer {
validation.set_issuer(&[issuer.as_str()]);
}
if let Some(ref audience) = self.audience {
validation.set_audience(&[audience.as_str()]);
}
let token_data = jsonwebtoken::decode::<JwtClaims>(token, &decoding_key, &validation)
.map_err(|e| ServiceError::Unauthenticated(format!("Invalid token: {}", e)))?;
if token_data.claims.is_expired() {
return Err(ServiceError::Unauthenticated("Token expired".to_string()));
}
Ok(token_data.claims)
}
pub fn create_token(&self, claims: JwtClaims) -> ServiceResult<String> {
let encoding_key = EncodingKey::from_secret(self.secret.as_ref());
let header = jsonwebtoken::Header::new(jsonwebtoken::Algorithm::HS256);
let mut token_claims = claims;
if let Some(ref issuer) = self.issuer {
token_claims.iss = Some(issuer.clone());
}
if let Some(ref audience) = self.audience {
token_claims.aud = Some(audience.clone());
}
jsonwebtoken::encode(&header, &token_claims, &encoding_key)
.map_err(|e| ServiceError::Internal(format!("Failed to create token: {}", e)))
}
pub fn create_user_token(&self, user_id: impl Into<String>) -> ServiceResult<String> {
let now = chrono::Utc::now().timestamp();
let claims = JwtClaims {
sub: user_id.into(),
iat: now,
exp: now + self.expiry as i64,
iss: self.issuer.clone(),
aud: self.audience.clone(),
extra: std::collections::HashMap::new(),
};
self.create_token(claims)
}
}
fn unauthorized(message: &str) -> Response {
ConnectError::new(ErrorCode::Unauthenticated, message).into_response()
}
/// Tower layer that wraps services with JWT bearer-token authentication.
///
/// Every [`JwtService`] it produces shares a single validator through an `Arc`.
#[derive(Debug, Clone)]
pub struct JwtLayer {
    /// Validator shared by all services produced by this layer.
    validator: Arc<JwtValidator>,
}

impl JwtLayer {
    /// Creates a layer that authenticates requests with `validator`.
    pub fn new(validator: JwtValidator) -> Self {
        let shared = Arc::new(validator);
        Self { validator: shared }
    }

    /// Creates a layer whose validator is built from `config` via
    /// [`JwtValidator::new`].
    pub fn from_config(config: AuthConfig) -> Self {
        let validator = JwtValidator::new(config);
        Self::new(validator)
    }
}
impl<S> Layer<S> for JwtLayer {
    type Service = JwtService<S>;

    /// Wraps `inner` in a [`JwtService`] that shares this layer's validator.
    /// Only the `Arc` reference count is bumped; the validator is not copied.
    fn layer(&self, inner: S) -> Self::Service {
        let validator = Arc::clone(&self.validator);
        JwtService { inner, validator }
    }
}
/// Middleware service produced by [`JwtLayer`].
///
/// Behaviour by request type:
/// - Bearer token that validates: an authenticated [`AuthContext`] is inserted
///   into the request extensions and the request is forwarded to `inner`.
/// - No bearer token: an unauthenticated context is inserted and the request
///   is forwarded to `inner`.
/// - Bearer token that fails validation: the request is answered immediately
///   with an Unauthenticated error and never reaches `inner`.
#[derive(Debug, Clone)]
pub struct JwtService<S> {
    /// The wrapped downstream service.
    inner: S,
    /// Validator shared with the layer that created this service.
    validator: Arc<JwtValidator>,
}

impl<S> Service<http::Request<Body>> for JwtService<S>
where
    S: Service<http::Request<Body>, Response = Response> + Clone + Send + 'static,
    S::Future: Send + 'static,
    S::Error: Send + 'static,
{
    type Response = Response;
    type Error = S::Error;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;

    /// Readiness is delegated entirely to the wrapped service.
    fn poll_ready(&mut self, cx: &mut TaskContext<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, mut req: http::Request<Body>) -> Self::Future {
        let auth_ctx = if let Some(token) = extract_jwt_claims(req.headers()) {
            match self.validator.validate(&token) {
                Ok(claims) => {
                    let exp = claims.exp;
                    let subject = claims.sub.clone();
                    AuthContext::authenticated(subject, Some(claims)).with_exp(exp)
                }
                // A token was supplied but is unusable: reject outright rather
                // than silently downgrading to an anonymous request.
                Err(_) => {
                    let rejection = unauthorized("invalid or expired token");
                    return Box::pin(async move { Ok(rejection) });
                }
            }
        } else {
            // No bearer credentials: forward as anonymous; downstream code can
            // inspect the `AuthContext` to decide what is allowed.
            AuthContext::unauthenticated()
        };
        req.extensions_mut().insert(auth_ctx);
        Box::pin(self.inner.call(req))
    }
}
pub fn extract_jwt_claims(headers: &http::HeaderMap) -> Option<String> {
headers
.get("Authorization")
.and_then(|v| v.to_str().ok())
.and_then(|v| v.strip_prefix("Bearer ").map(str::to_string))
}
#[cfg(test)]
mod tests {
    use super::*;
    // `Body`, `IntoResponse`, `Response`, `Service` and `Future` already arrive
    // through `super::*`; only the tower driver helpers need importing here.
    use tower::{ServiceBuilder, ServiceExt};

    /// Empty `200 OK` response returned by the stub inner services.
    fn ok_response() -> Response {
        http::Response::builder()
            .status(http::StatusCode::OK)
            .body(Body::empty())
            .unwrap()
            .into_response()
    }

    /// `GET /` request carrying an optional `Authorization` header value.
    fn request(authorization: Option<&str>) -> http::Request<Body> {
        let mut builder = http::Request::builder().uri("/");
        if let Some(value) = authorization {
            builder = builder.header("Authorization", value);
        }
        builder.body(Body::empty()).unwrap()
    }

    /// Inner service that ignores its request and always answers `200 OK`.
    fn echo_service() -> impl Service<
        http::Request<Body>,
        Response = Response,
        Error = std::convert::Infallible,
        Future = impl Future<Output = Result<Response, std::convert::Infallible>>,
    > + Clone {
        tower::service_fn(|_req: http::Request<Body>| async {
            Ok::<_, std::convert::Infallible>(ok_response())
        })
    }

    /// Claims for `user-1` with the given timestamps and no `iss`/`aud`.
    fn claims_with_times(iat: i64, exp: i64) -> JwtClaims {
        JwtClaims {
            sub: "user-1".to_string(),
            iat,
            exp,
            iss: None,
            aud: None,
            extra: std::collections::HashMap::new(),
        }
    }

    #[test]
    fn test_jwt_validator_new() {
        let config = AuthConfig::default();
        let validator = JwtValidator::new(config);
        assert_eq!(validator.expiry, 3600);
    }

    #[test]
    fn test_jwt_claims_is_expired() {
        assert!(claims_with_times(0, 0).is_expired());
    }

    #[test]
    fn test_jwt_claims_not_expired() {
        let now = chrono::Utc::now().timestamp();
        assert!(!claims_with_times(now, now + 3600).is_expired());
    }

    #[test]
    fn test_extract_jwt_claims_bearer() {
        let mut headers = http::HeaderMap::new();
        headers.insert("Authorization", http::HeaderValue::from_static("Bearer mytoken"));
        assert_eq!(extract_jwt_claims(&headers), Some("mytoken".to_string()));
    }

    #[test]
    fn test_extract_jwt_claims_no_bearer() {
        let mut headers = http::HeaderMap::new();
        headers.insert("Authorization", http::HeaderValue::from_static("Basic creds"));
        assert_eq!(extract_jwt_claims(&headers), None);
    }

    #[test]
    fn test_extract_jwt_claims_no_header() {
        let headers = http::HeaderMap::new();
        assert_eq!(extract_jwt_claims(&headers), None);
    }

    #[test]
    fn test_jwt_roundtrip() {
        let validator = JwtValidator::with_secret("my-secret");
        let token = validator.create_user_token("user-1").unwrap();
        assert!(!token.is_empty());
        let claims = validator.validate(&token).unwrap();
        assert_eq!(claims.sub, "user-1");
    }

    #[test]
    fn test_jwt_roundtrip_with_issuer_and_audience() {
        let validator = JwtValidator::with_secret("my-secret")
            .with_issuer("auth-service")
            .with_audience("api");
        let token = validator.create_user_token("user-1").unwrap();
        let claims = validator.validate(&token).unwrap();
        assert_eq!(claims.sub, "user-1");
        assert_eq!(claims.iss.as_deref(), Some("auth-service"));
        assert_eq!(claims.aud.as_deref(), Some("api"));
    }

    #[test]
    fn test_jwt_invalid_signature() {
        let validator = JwtValidator::with_secret("my-secret");
        let other_validator = JwtValidator::with_secret("other-secret");
        let token = other_validator.create_user_token("user-1").unwrap();
        let result = validator.validate(&token);
        assert!(matches!(result, Err(ServiceError::Unauthenticated(_))));
    }

    #[test]
    fn test_jwt_issuer_mismatch_rejected() {
        let signer = JwtValidator::with_secret("my-secret").with_issuer("issuer-a");
        let validator = JwtValidator::with_secret("my-secret").with_issuer("issuer-b");
        let token = signer.create_user_token("user-1").unwrap();
        let result = validator.validate(&token);
        assert!(matches!(result, Err(ServiceError::Unauthenticated(_))));
    }

    #[test]
    fn test_jwt_audience_mismatch_rejected() {
        let signer = JwtValidator::with_secret("my-secret").with_audience("aud-a");
        let validator = JwtValidator::with_secret("my-secret").with_audience("aud-b");
        let token = signer.create_user_token("user-1").unwrap();
        let result = validator.validate(&token);
        assert!(matches!(result, Err(ServiceError::Unauthenticated(_))));
    }

    #[test]
    fn test_jwt_expired_token_rejected() {
        let validator = JwtValidator::with_secret("my-secret");
        let token = validator.create_token(claims_with_times(0, 1)).unwrap();
        let result = validator.validate(&token);
        assert!(matches!(result, Err(ServiceError::Unauthenticated(_))));
    }

    #[tokio::test]
    async fn test_jwt_layer_missing_token_passes_unauthenticated() {
        let layer = JwtLayer::new(JwtValidator::with_secret("secret"));
        let mut svc = ServiceBuilder::new().layer(layer).service(tower::service_fn(
            |req: http::Request<Body>| async move {
                let ctx = req.extensions().get::<AuthContext>().cloned().unwrap();
                assert!(!ctx.is_authenticated());
                Ok::<_, std::convert::Infallible>(ok_response())
            },
        ));
        let resp = svc.ready().await.unwrap().call(request(None)).await.unwrap();
        assert_eq!(resp.status(), http::StatusCode::OK);
    }

    #[tokio::test]
    async fn test_jwt_layer_valid_token_injects_context() {
        let validator = JwtValidator::with_secret("secret");
        let token = validator.create_user_token("alice").unwrap();
        let layer = JwtLayer::new(validator);
        let mut svc = ServiceBuilder::new().layer(layer).service(tower::service_fn(
            |req: http::Request<Body>| async move {
                let ctx = req.extensions().get::<AuthContext>().cloned().unwrap();
                assert!(ctx.is_authenticated());
                assert_eq!(ctx.subject(), Some(&"alice".to_string()));
                Ok::<_, std::convert::Infallible>(ok_response())
            },
        ));
        let authorization = format!("Bearer {}", token);
        let resp = svc
            .ready()
            .await
            .unwrap()
            .call(request(Some(&authorization)))
            .await
            .unwrap();
        assert_eq!(resp.status(), http::StatusCode::OK);
    }

    #[tokio::test]
    async fn test_jwt_layer_invalid_token_returns_401() {
        let layer = JwtLayer::new(JwtValidator::with_secret("secret"));
        let mut svc = ServiceBuilder::new().layer(layer).service(echo_service());
        let req = request(Some("Bearer not.a.valid.token"));
        let resp = svc.ready().await.unwrap().call(req).await.unwrap();
        assert_eq!(resp.status(), http::StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_jwt_layer_wrong_secret_returns_401() {
        let signer = JwtValidator::with_secret("other-secret");
        let token = signer.create_user_token("bob").unwrap();
        let layer = JwtLayer::new(JwtValidator::with_secret("secret"));
        let mut svc = ServiceBuilder::new().layer(layer).service(echo_service());
        let authorization = format!("Bearer {}", token);
        let resp = svc
            .ready()
            .await
            .unwrap()
            .call(request(Some(&authorization)))
            .await
            .unwrap();
        assert_eq!(resp.status(), http::StatusCode::UNAUTHORIZED);
    }
}