use std::sync::Arc;
use axum::extract::{Request, State};
use axum::http::{HeaderValue, StatusCode, header};
use axum::middleware::Next;
use axum::response::{IntoResponse, Response};
use base64::Engine;
use base64::engine::general_purpose::STANDARD;
/// Default administrator username.
///
/// NOTE(review): presumably applied when no username is configured — confirm at the call site.
pub(crate) const DEFAULT_ADMIN_USER: &str = "admin";
/// Default administrator password.
///
/// This is a well-known, hard-coded value. [`BasicAuth::uses_default_password`] compares the
/// configured password against it so callers can detect that the default is still in effect.
pub(crate) const DEFAULT_ADMIN_PASSWORD: &str = "flusso";
/// Expected HTTP Basic credentials, checked per request by [`require_basic_auth`].
///
/// `Debug` is implemented by hand (not derived) so that the password is never included in
/// debug output.
#[derive(Clone)]
pub(crate) struct BasicAuth {
/// Username a client must present.
user: String,
/// Password a client must present; compared byte-wise via `ct_eq`.
password: String,
}
impl std::fmt::Debug for BasicAuth {
    /// Renders `BasicAuth { user: .., .. }`.
    ///
    /// The password field is intentionally left out; `finish_non_exhaustive` prints the
    /// trailing `..` to signal that some fields were omitted.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = formatter.debug_struct("BasicAuth");
        builder.field("user", &self.user);
        builder.finish_non_exhaustive()
    }
}
impl BasicAuth {
    /// Creates a verifier that accepts exactly the given `user` / `password` pair.
    pub(crate) fn new(user: String, password: String) -> Self {
        Self { user, password }
    }

    /// Returns `true` if the configured password equals [`DEFAULT_ADMIN_PASSWORD`].
    ///
    /// Callers can use this to warn that the well-known default password is still in use.
    pub(crate) fn uses_default_password(&self) -> bool {
        self.password == DEFAULT_ADMIN_PASSWORD
    }

    /// Returns `true` if `header` carries HTTP Basic credentials matching the configured pair.
    ///
    /// The header is expected to be `<scheme> <base64("user:password")>`. `<scheme>` is
    /// matched against `Basic` case-insensitively, because auth-scheme tokens are
    /// case-insensitive (RFC 9110 §11.1, RFC 7617 §2).
    ///
    /// Any malformed input yields `false` rather than an error. That covers a header that
    /// is not visible ASCII, a missing space after the scheme, invalid base64, a non-UTF-8
    /// payload, or a payload without a `:` separator.
    fn check(&self, header: &HeaderValue) -> bool {
        let Ok(value) = header.to_str() else {
            return false;
        };
        // Split off the scheme token instead of using a fixed prefix, so that `basic` and
        // `BASIC` are accepted as well as `Basic`.
        let Some((scheme, encoded)) = value.split_once(' ') else {
            return false;
        };
        if !scheme.eq_ignore_ascii_case("Basic") {
            return false;
        }
        // `trim` tolerates extra whitespace between the scheme and the credentials.
        let Ok(decoded) = STANDARD.decode(encoded.trim()) else {
            return false;
        };
        let Ok(decoded) = String::from_utf8(decoded) else {
            return false;
        };
        // RFC 7617 forbids `:` in the user-id, so splitting at the first colon keeps any
        // colons that are part of the password.
        let Some((user, password)) = decoded.split_once(':') else {
            return false;
        };
        // Non-short-circuiting `&`: both comparisons always run, so a wrong username
        // cannot be distinguished by timing from a wrong password.
        ct_eq(user.as_bytes(), self.user.as_bytes())
            & ct_eq(password.as_bytes(), self.password.as_bytes())
    }
}
/// Compares two byte slices for equality without stopping at the first differing byte.
///
/// Every byte pair is XOR-ed and OR-ed into a single accumulator, so the running time for
/// equal-length inputs does not depend on where they differ. Slices of different lengths
/// are rejected up front, which means the length of the expected value is not hidden.
fn ct_eq(a: &[u8], b: &[u8]) -> bool {
    a.len() == b.len() && a.iter().zip(b).fold(0u8, |acc, (x, y)| acc | (x ^ y)) == 0
}
/// Axum middleware that only lets requests with valid HTTP Basic credentials through.
///
/// If the `Authorization` header is present and accepted by [`BasicAuth::check`], the
/// request is forwarded to `next`. Otherwise the middleware answers `401 Unauthorized`
/// with a `WWW-Authenticate` challenge, so browsers prompt for credentials.
pub(crate) async fn require_basic_auth(
    State(auth): State<Arc<BasicAuth>>,
    request: Request,
    next: Next,
) -> Response {
    let authorized = request
        .headers()
        .get(header::AUTHORIZATION)
        .is_some_and(|value| auth.check(value));
    if authorized {
        return next.run(request).await;
    }

    let challenge = [(
        header::WWW_AUTHENTICATE,
        r#"Basic realm="flusso", charset="UTF-8""#,
    )];
    (StatusCode::UNAUTHORIZED, challenge, "unauthorized\n").into_response()
}
// Unit tests live in the out-of-line `tests` submodule and are compiled only for test builds.
#[cfg(test)]
// NOTE(review): presumably the crate denies `clippy::unwrap_used`; test code is exempted so
// assertions can unwrap freely.
#[allow(clippy::unwrap_used)]
mod tests;