use axum::http::{HeaderMap, StatusCode};
use subtle::ConstantTimeEq;
use super::ProxyConfig;
/// Reports whether `a` and `b` hold exactly the same bytes.
///
/// The comparison is delegated to [`ConstantTimeEq::ct_eq`] on the UTF-8
/// byte slices of both strings, and the resulting choice is converted to a
/// plain `bool`.
///
/// NOTE(review): the `subtle` documentation states that slice comparison
/// returns early when the lengths differ, so the *length* of a secret may
/// still be observable through timing — confirm this is acceptable.
pub(crate) fn constant_time_eq(a: &str, b: &str) -> bool {
    let (lhs, rhs) = (a.as_bytes(), b.as_bytes());
    bool::from(lhs.ct_eq(rhs))
}
/// Checks the request headers against the proxy's configured API key.
///
/// Returns `None` when the request may proceed and
/// `Some(StatusCode::UNAUTHORIZED)` when it must be rejected.
///
/// Rules, in order:
///
/// * If `config.proxy_api_key` is `None`, authentication is disabled and
///   every request is accepted.
/// * The credential is read from the `x-api-key` header. The `authorization`
///   header is consulted only when `x-api-key` is absent; a present but
///   wrong `x-api-key` is not rescued by a valid `authorization` header.
/// * A header value that `to_str` cannot convert is treated as missing.
/// * The credential is accepted when it equals the configured key, or when
///   it is the literal prefix `Bearer ` (case-sensitive, single space)
///   followed by the configured key. Both forms are accepted from either
///   header.
pub(crate) fn check_auth(config: &ProxyConfig, headers: &HeaderMap) -> Option<StatusCode> {
    // No configured key means auth is off; `?` returns `None`, i.e. "allowed".
    let expected = config.proxy_api_key.as_deref()?;

    let provided = headers
        .get("x-api-key")
        .or_else(|| headers.get("authorization"))
        .and_then(|value| value.to_str().ok());

    let Some(key) = provided else {
        return Some(StatusCode::UNAUTHORIZED);
    };

    // The `Bearer ` prefix is public, so it can be stripped with an ordinary
    // comparison; only the secret part goes through the constant-time check.
    // This avoids building a fresh `String` containing the secret per request.
    let authorized = constant_time_eq(key, expected)
        || matches!(
            key.strip_prefix("Bearer "),
            Some(token) if constant_time_eq(token, expected)
        );

    if authorized {
        None
    } else {
        Some(StatusCode::UNAUTHORIZED)
    }
}
#[cfg(test)]
mod tests {
    use axum::http::HeaderMap;

    use super::*;

    /// Builds a `ProxyConfig` in which only `proxy_api_key` is meaningful;
    /// every other field is left empty.
    fn config(proxy_api_key: Option<&str>) -> ProxyConfig {
        ProxyConfig {
            upstream_url: String::new(),
            upstream_api_key: String::new(),
            proxy_api_key: proxy_api_key.map(String::from),
            router: None,
        }
    }

    /// Builds a header map containing exactly one header.
    fn headers_with(name: &'static str, value: &str) -> HeaderMap {
        let mut headers = HeaderMap::new();
        headers.insert(name, value.parse().expect("test header value must be valid"));
        headers
    }

    #[test]
    fn test_should_return_none_when_auth_disabled() {
        let outcome = check_auth(&config(None), &HeaderMap::new());
        assert!(outcome.is_none());
    }

    #[test]
    fn test_should_accept_matching_x_api_key() {
        let headers = headers_with("x-api-key", "secret");
        assert!(check_auth(&config(Some("secret")), &headers).is_none());
    }

    #[test]
    fn test_should_accept_matching_bearer_token() {
        let headers = headers_with("authorization", "Bearer secret");
        assert!(check_auth(&config(Some("secret")), &headers).is_none());
    }

    #[test]
    fn test_should_reject_missing_key() {
        assert_eq!(
            check_auth(&config(Some("secret")), &HeaderMap::new()),
            Some(StatusCode::UNAUTHORIZED)
        );
    }

    #[test]
    fn test_should_reject_wrong_key() {
        let headers = headers_with("x-api-key", "wrong");
        assert_eq!(
            check_auth(&config(Some("secret")), &headers),
            Some(StatusCode::UNAUTHORIZED)
        );
    }

    // Negative counterpart of the bearer acceptance test: the `Bearer ` prefix
    // alone must not be enough to authenticate.
    #[test]
    fn test_should_reject_wrong_bearer_token() {
        let headers = headers_with("authorization", "Bearer wrong");
        assert_eq!(
            check_auth(&config(Some("secret")), &headers),
            Some(StatusCode::UNAUTHORIZED)
        );
    }

    #[test]
    fn test_constant_time_eq_should_match_equal_strings() {
        assert!(constant_time_eq("hello", "hello"));
    }

    #[test]
    fn test_constant_time_eq_should_reject_different_strings() {
        assert!(!constant_time_eq("hello", "hell0"));
    }

    #[test]
    fn test_constant_time_eq_should_reject_different_lengths() {
        assert!(!constant_time_eq("short", "much longer"));
    }
}