use std::net::IpAddr;
use std::sync::Arc;
use std::time::{Duration, Instant};
use dashmap::DashMap;
use http::header::{AUTHORIZATION, WWW_AUTHENTICATE};
use salvo::http::StatusCode;
use salvo::{Depot, FlowCtrl, Request, Response, async_trait};
use tracing::debug;
use crate::config::BasicAuthUser;
use crate::crypto::constant_time_eq;
use crate::encoding::base64_decode;
pub struct BasicAuthHoop {
users: Vec<AuthUser>,
realm: String,
brute_force_max: u32,
brute_force_window: Duration,
failure_map: Arc<DashMap<IpAddr, (u32, Instant)>>,
}
enum HashType {
Plaintext,
Bcrypt,
Argon2,
Scrypt,
Pbkdf2,
}
struct AuthUser {
username: String,
password_hash: String,
hash_type: HashType,
}
impl BasicAuthHoop {
pub fn new(
users: &[BasicAuthUser],
brute_force_max: Option<u32>,
brute_force_window: Option<Duration>,
) -> Self {
let users = users
.iter()
.map(|u| {
let hash_type = if u.password_hash.starts_with("$2b$")
|| u.password_hash.starts_with("$2a$")
|| u.password_hash.starts_with("$2y$")
{
HashType::Bcrypt
} else if u.password_hash.starts_with("$argon2id$")
|| u.password_hash.starts_with("$argon2i$")
{
HashType::Argon2
} else if u.password_hash.starts_with("$scrypt$") {
HashType::Scrypt
} else if u.password_hash.starts_with("$pbkdf2-sha256$") {
HashType::Pbkdf2
} else {
HashType::Plaintext
};
AuthUser {
username: u.username.clone(),
password_hash: u.password_hash.clone(),
hash_type,
}
})
.collect();
Self {
users,
realm: "gatel".to_string(),
brute_force_max: brute_force_max.unwrap_or(5),
brute_force_window: brute_force_window.unwrap_or(Duration::from_secs(300)),
failure_map: Arc::new(DashMap::new()),
}
}
}
#[async_trait]
impl salvo::Handler for BasicAuthHoop {
async fn handle(
&self,
req: &mut Request,
depot: &mut Depot,
res: &mut Response,
ctrl: &mut FlowCtrl,
) {
let client_ip = super::client_addr(req).ip();
if let Some(entry) = self.failure_map.get(&client_ip) {
let (count, last_attempt) = *entry;
if count >= self.brute_force_max && last_attempt.elapsed() < self.brute_force_window {
debug!(
ip = %client_ip,
failures = count,
"IP blocked due to brute-force protection, returning 429"
);
res.status_code(StatusCode::TOO_MANY_REQUESTS);
res.body("Too Many Requests");
ctrl.skip_rest();
return;
}
}
let credentials = match extract_basic_credentials(req.headers()) {
Some(creds) => creds,
None => {
debug!("no Authorization header, returning 401");
record_failure(&self.failure_map, client_ip);
set_unauthorized(res, &self.realm);
ctrl.skip_rest();
return;
}
};
let authenticated = self
.users
.iter()
.any(|user| verify_user(user, &credentials.0, &credentials.1));
if !authenticated {
debug!(
username = credentials.0.as_str(),
"authentication failed, returning 401"
);
record_failure(&self.failure_map, client_ip);
set_unauthorized(res, &self.realm);
ctrl.skip_rest();
return;
}
self.failure_map.remove(&client_ip);
debug!(username = credentials.0.as_str(), "authenticated");
depot.insert("auth_user", credentials.0.clone());
ctrl.call_next(req, depot, res).await;
}
}
fn record_failure(map: &DashMap<IpAddr, (u32, Instant)>, ip: IpAddr) {
let mut entry = map.entry(ip).or_insert((0, Instant::now()));
entry.0 += 1;
entry.1 = Instant::now();
}
fn extract_basic_credentials(headers: &http::HeaderMap) -> Option<(String, String)> {
let header_value = headers.get(AUTHORIZATION)?.to_str().ok()?;
let encoded = header_value.strip_prefix("Basic ")?;
let decoded_bytes = base64_decode(encoded)?;
let decoded = String::from_utf8(decoded_bytes).ok()?;
let (username, password) = decoded.split_once(':')?;
Some((username.to_string(), password.to_string()))
}
fn set_unauthorized(res: &mut Response, realm: &str) {
res.status_code(StatusCode::UNAUTHORIZED);
let _ = res.add_header(
WWW_AUTHENTICATE,
format!("Basic realm=\"{realm}\", charset=\"UTF-8\""),
true,
);
res.body("Unauthorized");
}
fn verify_user(user: &AuthUser, username: &str, password: &str) -> bool {
if user.username != username {
return false;
}
match user.hash_type {
HashType::Bcrypt => {
#[cfg(feature = "bcrypt")]
{
bcrypt::verify(password, &user.password_hash).unwrap_or(false)
}
#[cfg(not(feature = "bcrypt"))]
{
tracing::warn!(
"bcrypt password hash found but bcrypt feature is not enabled, rejecting"
);
false
}
}
HashType::Argon2 => {
#[cfg(feature = "argon2")]
{
use argon2::Argon2;
use password_hash::{PasswordHash, PasswordVerifier};
let parsed = match PasswordHash::new(&user.password_hash) {
Ok(h) => h,
Err(_) => return false,
};
Argon2::default()
.verify_password(password.as_bytes(), &parsed)
.is_ok()
}
#[cfg(not(feature = "argon2"))]
{
tracing::warn!(
"argon2 password hash found but argon2 feature is not enabled, rejecting"
);
false
}
}
HashType::Scrypt => {
#[cfg(feature = "scrypt")]
{
use password_hash::{PasswordHash, PasswordVerifier};
use scrypt::Scrypt;
let parsed = match PasswordHash::new(&user.password_hash) {
Ok(h) => h,
Err(_) => return false,
};
Scrypt.verify_password(password.as_bytes(), &parsed).is_ok()
}
#[cfg(not(feature = "scrypt"))]
{
tracing::warn!(
"scrypt password hash found but scrypt feature is not enabled, rejecting"
);
false
}
}
HashType::Pbkdf2 => {
#[cfg(feature = "pbkdf2")]
{
use password_hash::{PasswordHash, PasswordVerifier};
use pbkdf2::Pbkdf2;
let parsed = match PasswordHash::new(&user.password_hash) {
Ok(h) => h,
Err(_) => return false,
};
Pbkdf2.verify_password(password.as_bytes(), &parsed).is_ok()
}
#[cfg(not(feature = "pbkdf2"))]
{
tracing::warn!(
"pbkdf2 password hash found but pbkdf2 feature is not enabled, rejecting"
);
false
}
}
HashType::Plaintext => {
constant_time_eq(password.as_bytes(), user.password_hash.as_bytes())
}
}
}