#[cfg(feature = "api-key")]
use crate::server::auth::ApiKeyConfig;
#[cfg(feature = "api-key")]
pub struct ApiKeyMiddleware {
config: ApiKeyConfig,
}
#[cfg(feature = "api-key")]
impl ApiKeyMiddleware {
#[must_use]
pub fn new(config: ApiKeyConfig) -> Self {
Self { config }
}
#[must_use]
pub fn validate_request(
&self,
headers: &std::collections::HashMap<String, String>,
query_params: Option<&std::collections::HashMap<String, String>>,
) -> bool {
if !self.config.enabled {
return true;
}
if let Some(key) = headers.get(&self.config.header_name) {
return self.config.is_valid_key(key);
}
if self.config.allow_query_param {
if let Some(params) = query_params {
if let Some(key) = params.get(&self.config.query_param_name) {
return self.config.is_valid_key(key);
}
}
}
false
}
#[must_use]
pub fn extract_key(
&self,
headers: &std::collections::HashMap<String, String>,
query_params: Option<&std::collections::HashMap<String, String>>,
) -> Option<String> {
if let Some(key) = headers.get(&self.config.header_name) {
return Some(key.clone());
}
if self.config.allow_query_param {
if let Some(params) = query_params {
if let Some(key) = params.get(&self.config.query_param_name) {
return Some(key.clone());
}
}
}
None
}
#[must_use]
pub fn is_enabled(&self) -> bool {
self.config.enabled
}
}
#[derive(Debug, Clone)]
pub struct AuthError {
pub message: String,
pub www_authenticate: Option<String>,
}
impl AuthError {
#[must_use]
pub fn new(message: impl Into<String>) -> Self {
Self {
message: message.into(),
www_authenticate: None,
}
}
#[must_use]
pub fn unauthorized() -> Self {
Self {
message: "Unauthorized: API key required".to_string(),
www_authenticate: Some("ApiKey realm=\"crates-docs\"".to_string()),
}
}
#[must_use]
pub fn invalid_key() -> Self {
Self {
message: "Unauthorized: Invalid API key".to_string(),
www_authenticate: Some(
"ApiKey realm=\"crates-docs\" error=\"invalid_key\"".to_string(),
),
}
}
}
impl std::fmt::Display for AuthError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.message)
}
}
impl std::error::Error for AuthError {}
#[cfg(not(feature = "api-key"))]
pub struct NoOpMiddleware;
#[cfg(not(feature = "api-key"))]
impl NoOpMiddleware {
#[must_use]
pub fn validate_request(
&self,
_headers: &std::collections::HashMap<String, String>,
_query_params: Option<&std::collections::HashMap<String, String>>,
) -> bool {
true
}
#[must_use]
pub fn is_enabled(&self) -> bool {
false
}
}
#[cfg(test)]
#[cfg(feature = "api-key")]
mod tests {
use super::*;
fn create_test_config() -> (ApiKeyConfig, String) {
let generator = ApiKeyConfig::default();
let generated = generator
.generate_key()
.expect("failed to generate API key");
(
ApiKeyConfig {
enabled: true,
keys: vec![generated.hash],
..Default::default()
},
generated.key,
)
}
#[test]
fn test_middleware_disabled() {
let config = ApiKeyConfig::default();
let middleware = ApiKeyMiddleware::new(config);
let headers = std::collections::HashMap::new();
assert!(middleware.validate_request(&headers, None));
}
#[test]
fn test_middleware_valid_key_header() {
let (config, api_key) = create_test_config();
let middleware = ApiKeyMiddleware::new(config);
let mut headers = std::collections::HashMap::new();
headers.insert("X-API-Key".to_string(), api_key);
assert!(middleware.validate_request(&headers, None));
}
#[test]
fn test_middleware_invalid_key_header() {
let (config, _) = create_test_config();
let middleware = ApiKeyMiddleware::new(config);
let mut headers = std::collections::HashMap::new();
headers.insert("X-API-Key".to_string(), "invalid_key".to_string());
assert!(!middleware.validate_request(&headers, None));
}
#[test]
fn test_middleware_missing_key() {
let (config, _) = create_test_config();
let middleware = ApiKeyMiddleware::new(config);
let headers = std::collections::HashMap::new();
assert!(!middleware.validate_request(&headers, None));
}
#[test]
fn test_middleware_query_param_allowed() {
let (mut config, api_key) = create_test_config();
config.allow_query_param = true;
let middleware = ApiKeyMiddleware::new(config);
let headers = std::collections::HashMap::new();
let mut query_params = std::collections::HashMap::new();
query_params.insert("api_key".to_string(), api_key);
assert!(middleware.validate_request(&headers, Some(&query_params)));
}
#[test]
fn test_middleware_query_param_not_allowed() {
let (config, api_key) = create_test_config();
let middleware = ApiKeyMiddleware::new(config);
let headers = std::collections::HashMap::new();
let mut query_params = std::collections::HashMap::new();
query_params.insert("api_key".to_string(), api_key);
assert!(!middleware.validate_request(&headers, Some(&query_params)));
}
#[test]
fn test_extract_key() {
let (config, api_key) = create_test_config();
let middleware = ApiKeyMiddleware::new(config);
let mut headers = std::collections::HashMap::new();
headers.insert("X-API-Key".to_string(), api_key.clone());
let key = middleware.extract_key(&headers, None);
assert_eq!(key, Some(api_key));
}
#[test]
fn test_auth_error() {
let error = AuthError::unauthorized();
assert_eq!(error.message, "Unauthorized: API key required");
assert!(error.www_authenticate.is_some());
}
}