use rustapi_core::{
middleware::{BoxedNext, MiddlewareLayer},
Request, Response, ResponseBody,
};
use std::collections::HashSet;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
/// Settings for the API-key middleware layer.
///
/// `keys` lives behind an `Arc`, so cloning a config shares the key set
/// instead of copying it.
#[derive(Clone)]
pub struct ApiKeyConfig {
    /// Accepted API keys; a request carrying a key in this set is let through.
    pub keys: Arc<HashSet<String>>,
    /// Name of the request header the key is read from.
    pub header_name: String,
    /// Optional query-string parameter consulted when the header yields no key.
    pub query_param_name: Option<String>,
    /// Paths exempt from the API-key check (each entry covers the path itself
    /// and paths beneath it).
    pub skip_paths: Vec<String>,
}

impl Default for ApiKeyConfig {
    /// Reads the key from the `X-API-Key` header, does not look at the query
    /// string, exempts `/health` and `/docs`, and starts with an empty key set
    /// (so every non-exempt request is rejected until keys are added).
    fn default() -> Self {
        let skip_paths: Vec<String> = ["/health", "/docs"]
            .iter()
            .copied()
            .map(String::from)
            .collect();
        Self {
            keys: Arc::default(),
            header_name: String::from("X-API-Key"),
            query_param_name: None,
            skip_paths,
        }
    }
}
#[derive(Clone)]
pub struct ApiKeyLayer {
config: ApiKeyConfig,
}
impl ApiKeyLayer {
pub fn new() -> Self {
Self {
config: ApiKeyConfig::default(),
}
}
pub fn header(mut self, name: impl Into<String>) -> Self {
self.config.header_name = name.into();
self
}
pub fn query_param(mut self, name: impl Into<String>) -> Self {
self.config.query_param_name = Some(name.into());
self
}
pub fn add_key(mut self, key: impl Into<String>) -> Self {
let keys = Arc::make_mut(&mut self.config.keys);
keys.insert(key.into());
self
}
pub fn add_keys(mut self, keys: Vec<String>) -> Self {
let key_set = Arc::make_mut(&mut self.config.keys);
for key in keys {
key_set.insert(key);
}
self
}
pub fn skip_path(mut self, path: impl Into<String>) -> Self {
self.config.skip_paths.push(path.into());
self
}
}
impl Default for ApiKeyLayer {
fn default() -> Self {
Self::new()
}
}
impl MiddlewareLayer for ApiKeyLayer {
    /// Authenticates `req` by API key before handing it to `next`.
    ///
    /// Requests whose path is covered by a skip path are forwarded unchecked.
    /// Otherwise the key is read from the configured header, falling back to
    /// the configured query parameter (if any) when the header is absent or
    /// cannot be read as a string. A key in the accepted set lets the request
    /// through; anything else short-circuits with a `401` JSON response
    /// ("Invalid API key" or "Missing API key").
    fn call(
        &self,
        req: Request,
        next: BoxedNext,
    ) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> {
        // The check only reads `req` and the immutable config, so decide now
        // rather than cloning the whole config into the future per request.
        let verdict = self.authorize(&req);
        Box::pin(async move {
            match verdict {
                Ok(()) => next(req).await,
                Err(message) => create_unauthorized_response(message),
            }
        })
    }

    fn clone_box(&self) -> Box<dyn MiddlewareLayer> {
        Box::new(self.clone())
    }
}

impl ApiKeyLayer {
    /// Decides whether `req` may proceed.
    ///
    /// Returns `Ok(())` when the path is exempt or the presented key is
    /// accepted, and `Err(message)` carrying the client-facing reason otherwise.
    fn authorize(&self, req: &Request) -> Result<(), &'static str> {
        let path = req.uri().path();
        let exempt = self
            .config
            .skip_paths
            .iter()
            .any(|prefix| path_is_under(path, prefix));
        if exempt {
            return Ok(());
        }
        match self.presented_key(req) {
            Some(key) if self.config.keys.contains(key) => Ok(()),
            Some(_) => Err("Invalid API key"),
            None => Err("Missing API key"),
        }
    }

    /// Returns the API key carried by `req`, preferring the header over the query string.
    fn presented_key<'r>(&self, req: &'r Request) -> Option<&'r str> {
        let config = &self.config;
        req.headers()
            .get(config.header_name.as_str())
            .and_then(|value| value.to_str().ok())
            .or_else(|| {
                let name = config.query_param_name.as_deref()?;
                query_value(req.uri().query()?, name)
            })
    }
}

/// Looks up `name` in a raw query string and returns everything after the first `=`.
///
/// Splitting only on the first `=` keeps values that themselves contain `=`
/// (e.g. base64 padding) intact. Pairs without `=` are ignored. Values are
/// returned exactly as they appear in the URI; no percent-decoding is done.
fn query_value<'q>(query: &'q str, name: &str) -> Option<&'q str> {
    query.split('&').find_map(|pair| {
        let (key, value) = pair.split_once('=')?;
        if key == name {
            Some(value)
        } else {
            None
        }
    })
}

/// Returns `true` if `path` equals `prefix` or lies beneath it on a `/` boundary.
///
/// A skip path of `/health` covers `/health` and `/health/live` but not
/// `/healthz`. A bare `starts_with` would exempt any path that merely shares
/// the prefix, which lets unrelated routes bypass authentication.
fn path_is_under(path: &str, prefix: &str) -> bool {
    match path.strip_prefix(prefix) {
        Some(rest) => rest.is_empty() || rest.starts_with('/') || prefix.ends_with('/'),
        None => false,
    }
}

#[cfg(test)]
mod api_key_helper_tests {
    use super::{path_is_under, query_value};

    #[test]
    fn skip_paths_match_on_segment_boundaries() {
        assert!(path_is_under("/health", "/health"));
        assert!(path_is_under("/health/live", "/health"));
        assert!(path_is_under("/static/app.js", "/static/"));
        assert!(!path_is_under("/healthz", "/health"));
        assert!(!path_is_under("/documents/secret", "/docs"));
        assert!(!path_is_under("/api/health", "/health"));
    }

    #[test]
    fn query_value_keeps_equals_signs_in_value() {
        assert_eq!(query_value("api_key=abc==", "api_key"), Some("abc=="));
        assert_eq!(query_value("a=1&api_key=k&b=2", "api_key"), Some("k"));
        assert_eq!(query_value("api_key&other=1", "api_key"), None);
        assert_eq!(query_value("api_key=", "api_key"), Some(""));
        assert_eq!(query_value("xapi_key=1", "api_key"), None);
    }
}
/// Builds a `401` response whose JSON body is
/// `{"error": {"type": "unauthorized", "message": <message>}}`,
/// served with `Content-Type: application/json`.
fn create_unauthorized_response(message: &str) -> Response {
    let payload = serde_json::json!({
        "error": {
            "type": "unauthorized",
            "message": message
        }
    });
    // Serializing an in-memory `json!` value is not expected to fail; an empty
    // body is used as a fallback rather than panicking.
    let encoded = bytes::Bytes::from(serde_json::to_vec(&payload).unwrap_or_default());
    let body = ResponseBody::Full(http_body_util::Full::new(encoded));
    http::Response::builder()
        .status(401)
        .header("Content-Type", "application/json")
        .body(body)
        .expect("a 401 status with a static Content-Type header is always valid")
}
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use std::sync::Arc;

    /// Downstream handler stub that always answers `200` with body `OK`.
    fn ok_next() -> BoxedNext {
        Arc::new(|_req: Request| {
            Box::pin(async {
                http::Response::builder()
                    .status(200)
                    .body(ResponseBody::Full(http_body_util::Full::new(Bytes::from("OK"))))
                    .unwrap()
            }) as Pin<Box<dyn Future<Output = Response> + Send + 'static>>
        })
    }

    /// Builds a `GET` request for `uri`, optionally carrying an `X-API-Key` header.
    fn get_request(uri: &str, api_key: Option<&str>) -> Request {
        let mut builder = http::Request::builder().method("GET").uri(uri);
        if let Some(key) = api_key {
            builder = builder.header("X-API-Key", key);
        }
        let req = builder.body(()).unwrap();
        Request::from_http_request(req, Bytes::new())
    }

    /// Layer accepting the single key `test-key-123` via the `X-API-Key` header.
    fn header_layer() -> ApiKeyLayer {
        ApiKeyLayer::new().header("X-API-Key").add_key("test-key-123")
    }

    #[tokio::test]
    async fn api_key_valid_header() {
        let req = get_request("/api/users", Some("test-key-123"));
        let response = header_layer().call(req, ok_next()).await;
        assert_eq!(response.status(), 200);
    }

    #[tokio::test]
    async fn api_key_invalid_header() {
        let req = get_request("/api/users", Some("wrong-key"));
        let response = header_layer().call(req, ok_next()).await;
        assert_eq!(response.status(), 401);
    }

    #[tokio::test]
    async fn api_key_missing() {
        let req = get_request("/api/users", None);
        let response = header_layer().call(req, ok_next()).await;
        assert_eq!(response.status(), 401);
    }

    #[tokio::test]
    async fn api_key_skips_health_check() {
        let req = get_request("/health", None);
        let response = header_layer().call(req, ok_next()).await;
        assert_eq!(response.status(), 200);
    }

    #[tokio::test]
    async fn api_key_query_param() {
        let layer = ApiKeyLayer::new()
            .query_param("api_key")
            .add_key("test-key-123");
        let req = get_request("/api/users?api_key=test-key-123", None);
        let response = layer.call(req, ok_next()).await;
        assert_eq!(response.status(), 200);
    }
}