use alloc::sync::Arc;
use core::task::{Context, Poll};
use http::{HeaderName, HeaderValue, Request, Response, StatusCode};
use http_body::Body;
use tower_layer::Layer;
use tower_service::Service;
use crate::middleware::futures::DefinedFuture;
/// Tower [`Layer`] that authenticates requests by looking up an API key in a
/// configurable request header and checking it against a list of allowed keys.
///
/// Both fields are reference-counted, so cloning this value (which happens once per
/// wrapped service) only bumps two `Arc` counters.
#[derive(Debug, Clone)]
pub struct ApiKeyAuth {
    /// Name of the request header that carries the API key.
    pub header: Arc<HeaderName>,
    /// Keys that are accepted; a request whose key is not listed is rejected.
    pub allowed_keys: Arc<Vec<String>>,
}
impl ApiKeyAuth {
    /// Creates an authenticator that reads the API key from `header` and accepts
    /// any value contained in `allowed_keys`.
    pub fn new(header: HeaderName, allowed_keys: Vec<String>) -> Self {
        Self {
            allowed_keys: Arc::new(allowed_keys),
            header: Arc::new(header),
        }
    }

    /// Returns the value of the configured API key header, if there is a usable one.
    ///
    /// Yields `None` when the header is absent, or when its value cannot be viewed as
    /// a `&str` (`HeaderValue::to_str` fails for values that are not visible ASCII).
    /// If the header is repeated, only the first value is inspected.
    pub fn get_header<'a, T>(&self, request: &'a Request<T>) -> Option<&'a str> {
        request
            .headers()
            .get(self.header.as_ref())
            .map(HeaderValue::to_str)
            .and_then(Result::ok)
    }

    /// Reports whether `key` equals one of the configured allowed keys.
    ///
    /// The key comes straight from an untrusted request header, so the comparison
    /// avoids the usual early exits:
    /// - every allowed key is checked, even after a match;
    /// - each equal-length comparison examines every byte.
    ///
    /// This keeps response timing from revealing how long a correct prefix of a guess
    /// is. The protection is best-effort: the length of `key` relative to each allowed
    /// key can still be observed, and the optimizer is only discouraged, not
    /// forbidden, from reintroducing short-circuits.
    pub fn is_allowed_key(&self, key: &str) -> bool {
        let candidate = key.as_bytes();
        self.allowed_keys
            .iter()
            // Non-short-circuiting `|` so the scan never stops at the first match.
            .fold(false, |found, allowed| {
                found | Self::constant_time_eq(allowed.as_bytes(), candidate)
            })
    }

    /// Byte-wise equality that does not stop at the first differing byte.
    /// Inputs of different lengths are rejected immediately.
    fn constant_time_eq(lhs: &[u8], rhs: &[u8]) -> bool {
        if lhs.len() != rhs.len() {
            return false;
        }
        let diff = lhs.iter().zip(rhs).fold(0u8, |acc, (a, b)| {
            // `black_box` discourages the optimizer from noticing a saturated
            // accumulator and exiting early; it is a hint, not a guarantee.
            core::hint::black_box(acc | (a ^ b))
        });
        diff == 0
    }
}
impl<S> Layer<S> for ApiKeyAuth {
    type Service = ApiKeyAuthService<S>;

    /// Wraps `inner` in an [`ApiKeyAuthService`] that shares this layer's
    /// configuration. The clone only copies two `Arc` handles.
    fn layer(&self, inner: S) -> Self::Service {
        let auth = self.clone();
        ApiKeyAuthService { auth, inner }
    }
}
/// Service produced by [`ApiKeyAuth`]: checks each request's API key and either
/// forwards the request to `inner` or answers it directly with an error status.
#[derive(Debug, Clone)]
pub struct ApiKeyAuthService<S> {
    /// The wrapped service that handles authorized requests.
    inner: S,
    /// Header name and allowed keys used to authorize requests.
    auth: ApiKeyAuth,
}
impl<S> ApiKeyAuthService<S> {
pub fn new(inner: S, auth: ApiKeyAuth) -> Self {
Self { inner, auth }
}
}
impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for ApiKeyAuthService<S>
where
    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
    ResBody: Body + Default,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = DefinedFuture<S::Future>;

    /// Readiness is delegated entirely to the inner service.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Authorizes `request` and dispatches it:
    /// - missing or unreadable key header → `401 Unauthorized`;
    /// - key present but not allowed → `403 Forbidden`;
    /// - allowed key → forwarded to the inner service.
    fn call(&mut self, request: Request<ReqBody>) -> Self::Future {
        // `None` = no usable header; `Some(ok)` = header found, `ok` says whether it is allowed.
        let verdict = self
            .auth
            .get_header(&request)
            .map(|key| self.auth.is_allowed_key(key));

        match verdict {
            None => DefinedFuture::return_status(StatusCode::UNAUTHORIZED),
            Some(false) => DefinedFuture::return_status(StatusCode::FORBIDDEN),
            Some(true) => DefinedFuture::proceed(self.inner.call(request)),
        }
    }
}
#[cfg(test)]
mod test {
    use axum::{Router, routing::get};
    use bytes::Bytes;
    use http::{HeaderName, HeaderValue, Request, Response, StatusCode};
    use http_body_util::Full;
    use tower::{ServiceBuilder, ServiceExt};
    use tower_http::BoxError;

    use crate::{middleware::api_key::ApiKeyAuth, test::ResponseTestExt};

    /// Header the middleware under test reads the API key from.
    const API_KEY_HEADER: &str = "x-api-key";

    /// Middleware configured with exactly one allowed key, `api-key-1`.
    fn single_key_auth() -> ApiKeyAuth {
        ApiKeyAuth::new(
            HeaderName::from_static(API_KEY_HEADER),
            vec!["api-key-1".to_string()],
        )
    }

    /// Inner service that returns the request body unchanged with a `200 OK`.
    async fn echo(req: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, BoxError> {
        Ok(Response::new(req.into_body()))
    }

    /// Sends an empty-body request through the middleware wrapping `echo`.
    /// When `key` is `Some`, it is placed in the API key header.
    async fn send(key: Option<&'static str>) -> Response<Full<Bytes>> {
        let mut service = ServiceBuilder::new()
            .layer(single_key_auth())
            .service_fn(echo);

        let mut request = Request::new(Full::new(Bytes::new()));
        if let Some(value) = key {
            request
                .headers_mut()
                .insert(API_KEY_HEADER, HeaderValue::from_static(value));
        }

        service.ready().await.unwrap().oneshot(request).await.unwrap()
    }

    #[tokio::test]
    async fn axum_compat() {
        let router = Router::new()
            .route("/", get(|| async { StatusCode::OK }))
            .layer(single_key_auth());

        let request = Request::builder()
            .uri("/")
            .body(axum::body::Body::empty())
            .unwrap();

        router
            .oneshot(request)
            .await
            .unwrap()
            .expect_status(StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn blocks_no_header() {
        send(None).await.expect_status(StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn blocks_not_allowed() {
        send(Some("not-allowed-key"))
            .await
            .expect_status(StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn allows_allowed() {
        send(Some("api-key-1")).await.expect_status(StatusCode::OK);
    }
}