use {axum::{body::Body,
http::{Request,
StatusCode},
middleware::Next,
response::{IntoResponse,
Response}},
std::{collections::HashSet,
sync::Arc},
tracing::warn};
/// Credentials and routing options for HTTP Basic authentication.
///
/// Construct with [`BasicAuthConfig::new`] and refine with the chained builder methods.
#[derive(Clone)]
pub struct BasicAuthConfig {
    username: String,
    password: String,
    realm: String,
    excluded_paths: HashSet<String>,
}

impl BasicAuthConfig {
    /// Creates a configuration that accepts exactly `username` / `password`.
    ///
    /// The realm defaults to `"Restricted"` and no paths are excluded.
    pub fn new(username: impl Into<String>, password: impl Into<String>) -> Self {
        Self {
            username: username.into(),
            password: password.into(),
            realm: "Restricted".to_string(),
            excluded_paths: HashSet::new(),
        }
    }

    /// Sets the realm advertised in the `WWW-Authenticate` challenge.
    pub fn realm(mut self, realm: impl Into<String>) -> Self {
        self.realm = realm.into();
        self
    }

    /// Exempts one request path from authentication (exact match, no prefixes/patterns).
    pub fn exclude(mut self, path: impl Into<String>) -> Self {
        self.excluded_paths.insert(path.into());
        self
    }

    /// Exempts every path yielded by `paths`; matching is exact, as for [`Self::exclude`].
    pub fn exclude_paths(mut self, paths: impl IntoIterator<Item = impl Into<String>>) -> Self {
        for path in paths {
            self.excluded_paths.insert(path.into());
        }
        self
    }

    /// Returns `true` if `path` exactly equals one of the excluded paths.
    fn is_excluded(&self, path: &str) -> bool {
        self.excluded_paths.contains(path)
    }

    /// Checks the supplied credentials against the configured ones.
    ///
    /// Both comparisons always run (`&`, not `&&`), and each is constant-time in the byte
    /// contents. Response timing therefore does not reveal how much of a guess was correct.
    fn validate(&self, username: &str, password: &str) -> bool {
        let user_ok = Self::constant_time_eq(self.username.as_bytes(), username.as_bytes());
        let pass_ok = Self::constant_time_eq(self.password.as_bytes(), password.as_bytes());
        user_ok & pass_ok
    }

    /// Byte-wise equality whose running time depends only on the input lengths,
    /// not on the position of the first differing byte.
    ///
    /// Best-effort: lengths themselves are not hidden, and the optimizer is not formally
    /// prevented from reintroducing a data-dependent branch.
    fn constant_time_eq(expected: &[u8], actual: &[u8]) -> bool {
        // Fold a length mismatch into the accumulator instead of returning early, and
        // walk the longer input so the loop count does not depend on where bytes differ.
        let mut diff = expected.len() ^ actual.len();
        let len = expected.len().max(actual.len());
        for i in 0..len {
            let a = expected.get(i).copied().unwrap_or(0);
            let b = actual.get(i).copied().unwrap_or(0);
            diff |= usize::from(a ^ b);
        }
        diff == 0
    }
}
/// Tower layer (and function-middleware factory) that enforces HTTP Basic authentication.
///
/// The configuration lives behind an [`Arc`], so every service produced from this layer
/// shares a single copy.
#[derive(Clone)]
pub struct BasicAuthLayer {
    config: Arc<BasicAuthConfig>,
}

impl BasicAuthLayer {
    /// Creates a layer accepting exactly `username` / `password`, with default settings.
    pub fn new(username: impl Into<String>, password: impl Into<String>) -> Self {
        Self::from_config(BasicAuthConfig::new(username, password))
    }

    /// Wraps an already-built [`BasicAuthConfig`].
    pub fn from_config(config: BasicAuthConfig) -> Self {
        Self {
            config: Arc::new(config),
        }
    }

    /// Sets the realm advertised in the `WWW-Authenticate` challenge.
    pub fn realm(self, realm: impl Into<String>) -> Self {
        let realm = realm.into();
        self.update_config(|cfg| cfg.realm = realm)
    }

    /// Exempts one request path (exact match) from authentication.
    pub fn exclude(self, path: impl Into<String>) -> Self {
        let path = path.into();
        self.update_config(|cfg| {
            cfg.excluded_paths.insert(path);
        })
    }

    /// Exempts every path yielded by `paths` (exact matches) from authentication.
    pub fn exclude_paths(self, paths: impl IntoIterator<Item = impl Into<String>>) -> Self {
        self.update_config(|cfg| {
            paths.into_iter().for_each(|p| {
                cfg.excluded_paths.insert(p.into());
            });
        })
    }

    /// Turns the layer into a closure of the `(Request, Next) -> future` shape;
    /// presumably meant for `axum::middleware::from_fn`. Every call shares this
    /// layer's configuration.
    pub fn into_middleware(
        self,
    ) -> impl Fn(Request<Body>, Next) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send>>
           + Clone
           + Send
           + 'static {
        let shared = self.config;
        move |request: Request<Body>, next: Next| {
            let cfg = Arc::clone(&shared);
            Box::pin(async move { basic_auth_check(request, next, &cfg).await })
        }
    }

    /// Applies `edit` to the configuration. `Arc::make_mut` clones it first only when
    /// it is currently shared.
    fn update_config(mut self, edit: impl FnOnce(&mut BasicAuthConfig)) -> Self {
        edit(Arc::make_mut(&mut self.config));
        self
    }
}
impl<S> tower::Layer<S> for BasicAuthLayer {
type Service = BasicAuthMiddleware<S>;
fn layer(&self, inner: S) -> Self::Service {
BasicAuthMiddleware {
inner,
config: self.config.clone(),
}
}
}
#[derive(Clone)]
pub struct BasicAuthMiddleware<S> {
inner: S,
config: Arc<BasicAuthConfig>,
}
impl<S> tower::Service<Request<Body>> for BasicAuthMiddleware<S>
where
S: tower::Service<Request<Body>, Response = Response> + Clone + Send + 'static,
S::Future: Send,
{
type Error = S::Error;
type Future = std::pin::Pin<Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>>;
type Response = S::Response;
fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> std::task::Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, request: Request<Body>) -> Self::Future {
let config = self.config.clone();
let mut inner = self.inner.clone();
Box::pin(async move {
if config.is_excluded(request.uri().path()) {
return inner.call(request).await;
}
if let Some(auth_result) = validate_basic_auth(&request, &config) {
if auth_result {
return inner.call(request).await;
}
}
Ok(unauthorized_response(&config.realm))
})
}
}
/// Extracts HTTP Basic credentials from `request` and checks them against `config`.
///
/// Returns:
/// - `None` if the `Authorization` header is missing or not visible ASCII, or if the
///   token is not valid base64, not UTF-8, or lacks a `:` separator;
/// - `Some(false)` if the header uses another scheme or the credentials don't match;
/// - `Some(true)` only for matching credentials.
///
/// Both callers in this module treat anything other than `Some(true)` as a rejection.
fn validate_basic_auth<B>(request: &Request<B>, config: &BasicAuthConfig) -> Option<bool> {
    let auth_header = request
        .headers()
        .get("Authorization")
        .and_then(|h| h.to_str().ok())?
        .trim();
    // Authentication schemes are case-insensitive (RFC 7235 §2.1). Accept "basic",
    // "BASIC", etc., and tolerate extra whitespace between scheme and token.
    let (scheme, encoded) = match auth_header.split_once(' ') {
        Some(parts) => parts,
        None => return Some(false),
    };
    if !scheme.eq_ignore_ascii_case("Basic") {
        return Some(false);
    }
    let decoded = data_encoding::BASE64.decode(encoded.trim().as_bytes()).ok()?;
    let credentials = String::from_utf8(decoded).ok()?;
    // The user-id cannot contain ':' (RFC 7617), so split on the first colon.
    // The password may contain further colons.
    let (username, password) = credentials.split_once(':')?;
    Some(config.validate(username, password))
}
/// Builds the `401 Unauthorized` response carrying a Basic `WWW-Authenticate` challenge.
///
/// The realm is emitted as an HTTP quoted-string: `"` and `\` are backslash-escaped and
/// control characters are dropped. Without this, a quote would corrupt the challenge, and a
/// CR/LF would make the header value invalid, producing an error response instead of
/// the intended 401.
fn unauthorized_response(realm: &str) -> Response {
    let mut quoted_realm = String::with_capacity(realm.len());
    for c in realm.chars().filter(|c| !c.is_control()) {
        if c == '"' || c == '\\' {
            quoted_realm.push('\\');
        }
        quoted_realm.push(c);
    }
    (
        StatusCode::UNAUTHORIZED,
        [("WWW-Authenticate", format!("Basic realm=\"{}\"", quoted_realm))],
        "Unauthorized",
    )
        .into_response()
}
/// Function-middleware entry point that takes its configuration from axum `State`.
///
/// Its `(State, Request, Next) -> Response` shape matches what
/// `axum::middleware::from_fn_with_state` expects. All checking is delegated to
/// `basic_auth_check`.
pub async fn basic_auth_middleware(
    axum::extract::State(config): axum::extract::State<Arc<BasicAuthConfig>>,
    request: Request<Body>,
    next: Next,
) -> Response {
    let settings: &BasicAuthConfig = &*config;
    basic_auth_check(request, next, settings).await
}
/// Shared gate for the function-middleware paths.
///
/// Runs `next` when the request path is excluded or the Basic credentials match.
/// Otherwise it logs a warning and returns the 401 challenge.
async fn basic_auth_check(request: Request<Body>, next: Next, config: &BasicAuthConfig) -> Response {
    // `||` short-circuits, so excluded paths never have their credentials inspected.
    let authorized = config.is_excluded(request.uri().path())
        || validate_basic_auth(&request, config) == Some(true);
    if authorized {
        return next.run(request).await;
    }
    warn!(path = %request.uri().path(), "Unauthorized access attempt");
    unauthorized_response(&config.realm)
}