#[cfg(feature = "http-server")]
mod inner {
use axum::{
body::Body,
http::{Method, Request, Response, StatusCode, header},
};
use hmac::{Hmac, Mac};
use sha2::Sha256;
use std::{
fmt,
future::Future,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use subtle::ConstantTimeEq;
use tower::{Layer, Service};
/// HMAC-SHA256. API keys are stored and compared only as tags produced with this MAC under a
/// random per-configuration key, never as plaintext.
type HmacSha256 = Hmac<Sha256>;
/// Errors returned by [`ApiKeyConfig::new`] when an API-key configuration cannot be built.
#[derive(Debug, thiserror::Error)]
pub enum AuthConfigError {
/// No keys were supplied.
#[error("API key list must not be empty")]
EmptyKeyList,
/// At least one key contains an ASCII whitespace byte. Presented credentials are trimmed of
/// surrounding ASCII whitespace before comparison, so a key with leading or trailing
/// whitespace could never match.
#[error("API key must not contain whitespace")]
WhitespaceInKey,
/// The system RNG failed while generating the random HMAC key.
#[error("failed to seed HMAC key from system RNG: {0}")]
RngFailure(getrandom::Error),
}
/// Immutable verification state, shared behind an `Arc` by the layer and every service it
/// produces.
struct ApiKeyState {
/// Secret HMAC key; taken from `ApiKeyConfig::hmac_key`.
hmac_key: [u8; 32],
/// HMAC tags of the accepted API keys, computed under `hmac_key`.
tags: Vec<[u8; 32]>,
}
/// Validated API-key configuration, consumed by [`ApiKeyAuthLayer::new`].
///
/// Only HMAC tags of the keys are kept, never the plaintext keys; the `Debug` output redacts
/// both fields.
pub struct ApiKeyConfig {
/// HMAC-SHA256 tags of the accepted keys, one per key passed to `new`.
pub(crate) keys: Vec<[u8; 32]>,
/// Random HMAC key (drawn from the system RNG in `new`) under which `keys` were computed.
pub(crate) hmac_key: [u8; 32],
}
impl ApiKeyConfig {
    /// Builds a configuration from plaintext API keys.
    ///
    /// A fresh 32-byte HMAC key is drawn from the system RNG. Each raw key is stored only as
    /// its HMAC-SHA256 tag under that key; the plaintext keys are not retained.
    ///
    /// # Errors
    ///
    /// - [`AuthConfigError::EmptyKeyList`] if `raw_keys` is empty.
    /// - [`AuthConfigError::WhitespaceInKey`] if any key contains an ASCII whitespace byte.
    /// - [`AuthConfigError::RngFailure`] if the system RNG cannot be read.
    ///
    /// The checks run in that order, and the RNG is only touched once validation passes.
    /// Individual keys are not required to be non-empty.
    pub fn new(raw_keys: &[&str]) -> Result<Self, AuthConfigError> {
        if raw_keys.is_empty() {
            return Err(AuthConfigError::EmptyKeyList);
        }

        let contains_whitespace =
            |key: &&str| key.as_bytes().iter().any(u8::is_ascii_whitespace);
        if raw_keys.iter().any(contains_whitespace) {
            return Err(AuthConfigError::WhitespaceInKey);
        }

        let mut hmac_key = [0u8; 32];
        if let Err(err) = getrandom::fill(&mut hmac_key) {
            return Err(AuthConfigError::RngFailure(err));
        }

        let mut keys = Vec::with_capacity(raw_keys.len());
        for raw in raw_keys {
            keys.push(hmac_tag(&hmac_key, raw.as_bytes()));
        }
        Ok(Self { keys, hmac_key })
    }
}
/// Redacting `Debug`: reports how many key tags are configured but never prints the tags or
/// the HMAC key.
impl fmt::Debug for ApiKeyConfig {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let key_summary = format!("[{} redacted tags]", self.keys.len());
        let mut out = f.debug_struct("ApiKeyConfig");
        out.field("keys", &key_summary);
        out.field("hmac_key", &"[redacted]");
        out.finish()
    }
}
#[derive(Clone)]
pub struct ApiKeyAuthLayer {
inner: Arc<ApiKeyState>,
}
impl ApiKeyAuthLayer {
pub fn new(config: ApiKeyConfig) -> Self {
Self {
inner: Arc::new(ApiKeyState {
hmac_key: config.hmac_key,
tags: config.keys,
}),
}
}
}
impl<S> Layer<S> for ApiKeyAuthLayer {
type Service = ApiKeyAuthService<S>;
fn layer(&self, inner: S) -> Self::Service {
ApiKeyAuthService {
inner,
state: self.inner.clone(),
}
}
}
/// Service produced by [`ApiKeyAuthLayer`]: checks the request's API key before forwarding it
/// to the wrapped service.
#[derive(Clone)]
pub struct ApiKeyAuthService<S> {
/// The wrapped service that receives authenticated (and `OPTIONS`) requests.
inner: S,
/// Shared HMAC key and accepted-key tags.
state: Arc<ApiKeyState>,
}
impl<S> Service<Request<Body>> for ApiKeyAuthService<S>
where
S: Service<Request<Body>, Response = Response<Body>> + Clone + Send + 'static,
S::Future: Send + 'static,
{
type Response = Response<Body>;
type Error = S::Error;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, req: Request<Body>) -> Self::Future {
if req.method() == Method::OPTIONS {
let fut = self.inner.call(req);
return Box::pin(fut);
}
let state = self.state.clone();
let mut inner = self.inner.clone();
Box::pin(async move {
match extract_token(&req) {
Some(candidate) if matches_any(&state, candidate) => inner.call(req).await,
_ => Ok(unauthorized_response()),
}
})
}
}
fn hmac_tag(key: &[u8; 32], data: &[u8]) -> [u8; 32] {
let mut mac = HmacSha256::new_from_slice(key)
.expect("HMAC-SHA256 accepts any key length");
mac.update(data);
mac.finalize().into_bytes().into()
}
/// Returns `true` when `candidate` equals one of the configured API keys.
///
/// The candidate is HMAC'd under the configuration's key and compared against every stored tag
/// with `subtle`'s constant-time equality. The loop always visits all tags and never exits
/// early, so timing does not reveal which tag (if any) matched.
///
/// An empty candidate never matches. `ApiKeyConfig::new` accepts `""` as a key, and a present
/// but blank `X-PJS-API-Key` header is extracted as an empty slice. Without this guard, such a
/// header would authenticate against that key. The early return only depends on the length of
/// attacker-supplied input, not on any secret.
fn matches_any(state: &ApiKeyState, candidate: &[u8]) -> bool {
    if candidate.is_empty() {
        return false;
    }
    let candidate_tag = hmac_tag(&state.hmac_key, candidate);
    let mut acc: u8 = 0;
    for tag in &state.tags {
        acc |= tag.ct_eq(&candidate_tag).unwrap_u8();
    }
    acc == 1
}
/// Extracts the presented API key from `req`, with surrounding ASCII whitespace trimmed.
///
/// `Authorization: Bearer <key>` is preferred. Its scheme name is matched case-insensitively,
/// as RFC 7235 §2.1 requires for auth-schemes, so `bearer`/`BEARER` clients are not rejected.
/// If that header is absent or uses another scheme, the `x-pjs-api-key` header is used instead.
/// Returns `None` when neither source yields a credential.
fn extract_token(req: &Request<Body>) -> Option<&[u8]> {
    let headers = req.headers();
    if let Some(token) = headers
        .get(header::AUTHORIZATION)
        .and_then(|value| strip_bearer_scheme(value.as_bytes()))
    {
        return Some(trim_ascii(token));
    }
    headers
        .get("x-pjs-api-key")
        .map(|value| trim_ascii(value.as_bytes()))
}

/// Returns the bytes after a case-insensitive `Bearer ` prefix, or `None` if `value` does not
/// start with that scheme.
fn strip_bearer_scheme(value: &[u8]) -> Option<&[u8]> {
    const SCHEME: &[u8] = b"Bearer ";
    let (scheme, rest) = value.split_at_checked(SCHEME.len())?;
    scheme.eq_ignore_ascii_case(SCHEME).then_some(rest)
}
/// Returns `bytes` with leading and trailing ASCII whitespace removed.
///
/// Delegates to the standard library's `<[u8]>::trim_ascii`, which uses the same
/// `u8::is_ascii_whitespace` definition (space, `\t`, `\n`, `\x0C`, `\r`; not `\x0B`) that
/// the previous hand-rolled scan used. An all-whitespace or empty input yields an empty slice.
fn trim_ascii(bytes: &[u8]) -> &[u8] {
    bytes.trim_ascii()
}
fn unauthorized_response() -> Response<Body> {
let body = serde_json::json!({ "error": "Unauthorized" }).to_string();
Response::builder()
.status(StatusCode::UNAUTHORIZED)
.header(header::CONTENT_TYPE, "application/json")
.body(Body::from(body))
.expect("static unauthorized response is always valid")
}
#[cfg(feature = "http-auth-jwt")]
pub use jwt::{JwtAuthLayer, JwtAuthService, JwtConfig};
#[cfg(feature = "http-auth-jwt")]
mod jwt {
use axum::{
body::Body,
http::{Method, Request, Response, StatusCode, header},
};
use jsonwebtoken::{DecodingKey, Validation};
use serde::de::DeserializeOwned;
use std::{
future::Future,
marker::PhantomData,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use tower::{Layer, Service};
/// Settings for [`JwtAuthLayer`]: how bearer tokens are verified.
pub struct JwtConfig {
/// Key passed to `jsonwebtoken::decode` to verify token signatures.
pub decoding_key: DecodingKey,
/// Validation rules passed to `jsonwebtoken::decode` alongside the key.
pub validation: Validation,
}
/// Tower layer that requires a valid JWT in `Authorization: Bearer …` on non-`OPTIONS`
/// requests.
///
/// `C` is the claims type each token must deserialize into for it to be accepted.
pub struct JwtAuthLayer<C> {
/// Verification settings shared by every service this layer produces.
inner: Arc<JwtState>,
// `fn() -> C` records `C` without storing a value of it; function-pointer types are always
// `Send + Sync`, so the layer's auto traits do not depend on `C`.
_claims: PhantomData<fn() -> C>,
}
impl<C> Clone for JwtAuthLayer<C> {
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
_claims: PhantomData,
}
}
}
/// Immutable verification settings shared behind an `Arc` by the JWT layer and its services.
struct JwtState {
/// Signature verification key (from `JwtConfig::decoding_key`).
decoding_key: DecodingKey,
/// Validation rules (from `JwtConfig::validation`).
validation: Validation,
}
impl<C> JwtAuthLayer<C>
where
C: DeserializeOwned + Send + Sync + 'static,
{
pub fn new(config: JwtConfig) -> Self {
Self {
inner: Arc::new(JwtState {
decoding_key: config.decoding_key,
validation: config.validation,
}),
_claims: PhantomData,
}
}
}
impl<S, C> Layer<S> for JwtAuthLayer<C>
where
C: DeserializeOwned + Send + Sync + 'static,
{
type Service = JwtAuthService<S, C>;
fn layer(&self, inner: S) -> Self::Service {
JwtAuthService {
inner,
state: self.inner.clone(),
_claims: PhantomData,
}
}
}
/// Service produced by [`JwtAuthLayer`]: verifies the request's bearer JWT before forwarding
/// it to the wrapped service.
pub struct JwtAuthService<S, C> {
/// The wrapped service that receives authenticated (and `OPTIONS`) requests.
inner: S,
/// Shared verification key and rules.
state: Arc<JwtState>,
// Claims type marker; see `JwtAuthLayer` for why `fn() -> C` is used.
_claims: PhantomData<fn() -> C>,
}
impl<S, C> Clone for JwtAuthService<S, C>
where
S: Clone,
{
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
state: self.state.clone(),
_claims: PhantomData,
}
}
}
impl<S, C> Service<Request<Body>> for JwtAuthService<S, C>
where
S: Service<Request<Body>, Response = Response<Body>> + Clone + Send + 'static,
S::Future: Send + 'static,
C: DeserializeOwned + Send + Sync + 'static,
{
type Response = Response<Body>;
type Error = S::Error;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, req: Request<Body>) -> Self::Future {
if req.method() == Method::OPTIONS {
let fut = self.inner.call(req);
return Box::pin(fut);
}
let state = self.state.clone();
let mut inner = self.inner.clone();
Box::pin(async move {
let token = match extract_bearer(&req) {
Some(t) => t,
None => return Ok(jwt_unauthorized_response()),
};
let token_str = match std::str::from_utf8(token) {
Ok(s) => s,
Err(_) => return Ok(jwt_unauthorized_response()),
};
match jsonwebtoken::decode::<C>(
token_str,
&state.decoding_key,
&state.validation,
) {
Ok(_) => inner.call(req).await,
Err(_) => Ok(jwt_unauthorized_response()),
}
})
}
}
/// Extracts the token from an `Authorization: Bearer <token>` header.
///
/// The scheme name is matched case-insensitively (RFC 7235 §2.1). Surrounding ASCII whitespace
/// is trimmed from the token, so extra padding after the scheme does not cause a spurious
/// rejection. Returns `None` when the header is missing or uses a different scheme.
fn extract_bearer(req: &Request<Body>) -> Option<&[u8]> {
    const SCHEME: &[u8] = b"Bearer ";
    let value = req.headers().get(header::AUTHORIZATION)?.as_bytes();
    let (scheme, token) = value.split_at_checked(SCHEME.len())?;
    scheme
        .eq_ignore_ascii_case(SCHEME)
        .then_some(token.trim_ascii())
}
fn jwt_unauthorized_response() -> Response<Body> {
let body = serde_json::json!({ "error": "Unauthorized" }).to_string();
Response::builder()
.status(StatusCode::UNAUTHORIZED)
.header(header::CONTENT_TYPE, "application/json")
.body(Body::from(body))
.expect("static unauthorized response is always valid")
}
}
// Unit tests for the API-key layer. `ApiKeyConfig::new` draws a random HMAC key, so the tests
// only assert match / no-match outcomes and status codes, never concrete tag bytes.
#[cfg(test)]
mod tests {
use super::*;
use axum::{
body::Body,
http::{Method, Request, StatusCode},
};
use tower::{Service, ServiceExt};
// --- Configuration validation and redaction ---
#[test]
fn api_key_config_debug_redacts_key_material() {
let config = ApiKeyConfig::new(&["test-key-one"]).unwrap();
let debug = format!("{config:?}");
assert!(
debug.contains("redacted"),
"debug output must redact keys: {debug}"
);
// A derived `Debug` would print the array as `hmac_key: [..]`.
assert!(
!debug.contains("hmac_key: ["),
"hmac_key must not appear as raw bytes: {debug}"
);
}
#[test]
fn empty_key_list_is_rejected() {
let err = ApiKeyConfig::new(&[]).unwrap_err();
assert!(matches!(err, AuthConfigError::EmptyKeyList));
}
#[test]
fn key_with_whitespace_is_rejected() {
// One bad key anywhere in the list fails the whole configuration.
let err = ApiKeyConfig::new(&["valid-key", "bad key"]).unwrap_err();
assert!(matches!(err, AuthConfigError::WhitespaceInKey));
}
#[test]
fn key_with_leading_whitespace_is_rejected() {
let err = ApiKeyConfig::new(&[" leading"]).unwrap_err();
assert!(matches!(err, AuthConfigError::WhitespaceInKey));
}
#[test]
fn key_with_trailing_whitespace_is_rejected() {
let err = ApiKeyConfig::new(&["trailing "]).unwrap_err();
assert!(matches!(err, AuthConfigError::WhitespaceInKey));
}
#[test]
fn valid_single_key_is_accepted() {
assert!(ApiKeyConfig::new(&["valid-key"]).is_ok());
}
#[test]
fn valid_multiple_keys_are_accepted() {
assert!(ApiKeyConfig::new(&["key-one", "key-two", "key-three"]).is_ok());
}
// --- End-to-end service behaviour ---
// Named fn-pointer type for the stub inner service, so `TestSvc` can be spelled out. The
// non-capturing closure in `make_service` is coerced to it via the expected return type.
type OkFn = fn(
Request<Body>,
)
-> std::future::Ready<Result<Response<Body>, std::convert::Infallible>>;
type TestSvc = ApiKeyAuthService<tower::util::ServiceFn<OkFn>>;
// Wraps a stub that always answers 200 OK, so any non-200 status comes from the auth layer.
fn make_service(key: &str) -> TestSvc {
let config = ApiKeyConfig::new(&[key]).expect("valid key");
let layer = ApiKeyAuthLayer::new(config);
layer.layer(tower::service_fn(|_req: Request<Body>| {
std::future::ready(Ok::<_, std::convert::Infallible>(
Response::builder()
.status(StatusCode::OK)
.body(Body::empty())
.unwrap(),
))
}))
}
#[tokio::test]
async fn valid_bearer_token_returns_200() {
let mut svc = make_service("my-secret-key");
let req = Request::builder()
.method(Method::GET)
.header("Authorization", "Bearer my-secret-key")
.body(Body::empty())
.unwrap();
let resp = svc.ready().await.unwrap().call(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
}
#[tokio::test]
async fn valid_x_pjs_api_key_returns_200() {
let mut svc = make_service("my-secret-key");
// Mixed-case header name: header lookups are case-insensitive.
let req = Request::builder()
.method(Method::GET)
.header("X-PJS-API-Key", "my-secret-key")
.body(Body::empty())
.unwrap();
let resp = svc.ready().await.unwrap().call(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
}
#[tokio::test]
async fn wrong_token_returns_401() {
let mut svc = make_service("my-secret-key");
let req = Request::builder()
.method(Method::GET)
.header("Authorization", "Bearer wrong-key")
.body(Body::empty())
.unwrap();
let resp = svc.ready().await.unwrap().call(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
}
#[tokio::test]
async fn missing_header_returns_401() {
let mut svc = make_service("my-secret-key");
let req = Request::builder()
.method(Method::GET)
.body(Body::empty())
.unwrap();
let resp = svc.ready().await.unwrap().call(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
}
#[tokio::test]
async fn options_bypasses_auth() {
// No credentials at all, yet the stub's 200 comes back: OPTIONS skips the check.
let mut svc = make_service("my-secret-key");
let req = Request::builder()
.method(Method::OPTIONS)
.body(Body::empty())
.unwrap();
let resp = svc.ready().await.unwrap().call(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
}
// --- Internal helpers ---
#[test]
fn matches_any_correct_single_key() {
let config = ApiKeyConfig::new(&["secret"]).unwrap();
let state = ApiKeyState {
hmac_key: config.hmac_key,
tags: config.keys,
};
assert!(matches_any(&state, b"secret"));
assert!(!matches_any(&state, b"wrong"));
}
#[test]
fn matches_any_correct_multiple_keys() {
// Every configured key matches; a key outside the list does not.
let config = ApiKeyConfig::new(&["key-a", "key-b", "key-c"]).unwrap();
let state = ApiKeyState {
hmac_key: config.hmac_key,
tags: config.keys,
};
assert!(matches_any(&state, b"key-a"));
assert!(matches_any(&state, b"key-b"));
assert!(matches_any(&state, b"key-c"));
assert!(!matches_any(&state, b"key-d"));
}
#[test]
fn extract_token_bearer() {
let req = Request::builder()
.header("Authorization", "Bearer test-token")
.body(Body::empty())
.unwrap();
assert_eq!(extract_token(&req), Some(b"test-token".as_slice()));
}
#[test]
fn extract_token_x_pjs_api_key() {
let req = Request::builder()
.header("X-PJS-API-Key", "test-token")
.body(Body::empty())
.unwrap();
assert_eq!(extract_token(&req), Some(b"test-token".as_slice()));
}
#[test]
fn extract_token_none_when_absent() {
let req = Request::builder().body(Body::empty()).unwrap();
assert_eq!(extract_token(&req), None);
}
#[test]
fn extract_token_bearer_preferred_over_x_pjs() {
// When both headers are present, the Authorization bearer token wins.
let req = Request::builder()
.header("Authorization", "Bearer bearer-val")
.header("X-PJS-API-Key", "api-key-val")
.body(Body::empty())
.unwrap();
assert_eq!(extract_token(&req), Some(b"bearer-val".as_slice()));
}
}
}
#[cfg(feature = "http-server")]
pub use inner::{ApiKeyAuthLayer, ApiKeyAuthService, ApiKeyConfig, AuthConfigError};
#[cfg(all(feature = "http-server", feature = "http-auth-jwt"))]
pub use inner::{JwtAuthLayer, JwtAuthService, JwtConfig};