1use std::net::IpAddr;
2use std::sync::Arc;
3use std::time::{Duration, Instant};
4
5use dashmap::DashMap;
6use http::header::{AUTHORIZATION, WWW_AUTHENTICATE};
7use salvo::http::StatusCode;
8use salvo::{Depot, FlowCtrl, Request, Response, async_trait};
9use tracing::debug;
10
11use crate::config::BasicAuthUser;
12use crate::crypto::constant_time_eq;
13use crate::encoding::base64_decode;
14
15pub struct BasicAuthHoop {
32 users: Vec<AuthUser>,
33 realm: String,
34 brute_force_max: u32,
36 brute_force_window: Duration,
38 failure_map: Arc<DashMap<IpAddr, (u32, Instant)>>,
40}
41
/// Password storage scheme of a configured user, inferred from the prefix of
/// the stored string when the hoop is built.
enum HashType {
    /// No recognized prefix; compared against the password with `constant_time_eq`.
    Plaintext,
    /// bcrypt modular-crypt string (`$2a$`, `$2b$`, `$2y$`, ...).
    Bcrypt,
    /// Argon2 PHC string (`$argon2i$`, `$argon2id$`, ...).
    Argon2,
    /// scrypt PHC string (`$scrypt$`).
    Scrypt,
    /// PBKDF2 PHC string (`$pbkdf2-sha256$`, ...).
    Pbkdf2,
}
49
50struct AuthUser {
51 username: String,
52 password_hash: String,
53 hash_type: HashType,
54}
55
56impl BasicAuthHoop {
57 pub fn new(
58 users: &[BasicAuthUser],
59 brute_force_max: Option<u32>,
60 brute_force_window: Option<Duration>,
61 ) -> Self {
62 let users = users
63 .iter()
64 .map(|u| {
65 let hash_type = if u.password_hash.starts_with("$2b$")
66 || u.password_hash.starts_with("$2a$")
67 || u.password_hash.starts_with("$2y$")
68 {
69 HashType::Bcrypt
70 } else if u.password_hash.starts_with("$argon2id$")
71 || u.password_hash.starts_with("$argon2i$")
72 {
73 HashType::Argon2
74 } else if u.password_hash.starts_with("$scrypt$") {
75 HashType::Scrypt
76 } else if u.password_hash.starts_with("$pbkdf2-sha256$") {
77 HashType::Pbkdf2
78 } else {
79 HashType::Plaintext
80 };
81 AuthUser {
82 username: u.username.clone(),
83 password_hash: u.password_hash.clone(),
84 hash_type,
85 }
86 })
87 .collect();
88 Self {
89 users,
90 realm: "gatel".to_string(),
91 brute_force_max: brute_force_max.unwrap_or(5),
92 brute_force_window: brute_force_window.unwrap_or(Duration::from_secs(300)),
93 failure_map: Arc::new(DashMap::new()),
94 }
95 }
96}
97
98#[async_trait]
99impl salvo::Handler for BasicAuthHoop {
100 async fn handle(
101 &self,
102 req: &mut Request,
103 depot: &mut Depot,
104 res: &mut Response,
105 ctrl: &mut FlowCtrl,
106 ) {
107 let client_ip = super::client_addr(req).ip();
108
109 if let Some(entry) = self.failure_map.get(&client_ip) {
111 let (count, last_attempt) = *entry;
112 if count >= self.brute_force_max && last_attempt.elapsed() < self.brute_force_window {
113 debug!(
114 ip = %client_ip,
115 failures = count,
116 "IP blocked due to brute-force protection, returning 429"
117 );
118 res.status_code(StatusCode::TOO_MANY_REQUESTS);
119 res.body("Too Many Requests");
120 ctrl.skip_rest();
121 return;
122 }
123 }
124
125 let credentials = match extract_basic_credentials(req.headers()) {
127 Some(creds) => creds,
128 None => {
129 debug!("no Authorization header, returning 401");
130 record_failure(&self.failure_map, client_ip);
131 set_unauthorized(res, &self.realm);
132 ctrl.skip_rest();
133 return;
134 }
135 };
136
137 let authenticated = self
139 .users
140 .iter()
141 .any(|user| verify_user(user, &credentials.0, &credentials.1));
142
143 if !authenticated {
144 debug!(
145 username = credentials.0.as_str(),
146 "authentication failed, returning 401"
147 );
148 record_failure(&self.failure_map, client_ip);
149 set_unauthorized(res, &self.realm);
150 ctrl.skip_rest();
151 return;
152 }
153
154 self.failure_map.remove(&client_ip);
156 debug!(username = credentials.0.as_str(), "authenticated");
157 depot.insert("auth_user", credentials.0.clone());
158 ctrl.call_next(req, depot, res).await;
159 }
160}
161
162fn record_failure(map: &DashMap<IpAddr, (u32, Instant)>, ip: IpAddr) {
164 let mut entry = map.entry(ip).or_insert((0, Instant::now()));
165 entry.0 += 1;
166 entry.1 = Instant::now();
167}
168
169fn extract_basic_credentials(headers: &http::HeaderMap) -> Option<(String, String)> {
171 let header_value = headers.get(AUTHORIZATION)?.to_str().ok()?;
172 let encoded = header_value.strip_prefix("Basic ")?;
173
174 let decoded_bytes = base64_decode(encoded)?;
176 let decoded = String::from_utf8(decoded_bytes).ok()?;
177
178 let (username, password) = decoded.split_once(':')?;
180 Some((username.to_string(), password.to_string()))
181}
182
183fn set_unauthorized(res: &mut Response, realm: &str) {
185 res.status_code(StatusCode::UNAUTHORIZED);
186 let _ = res.add_header(
187 WWW_AUTHENTICATE,
188 format!("Basic realm=\"{realm}\", charset=\"UTF-8\""),
189 true,
190 );
191 res.body("Unauthorized");
192}
193
194fn verify_user(user: &AuthUser, username: &str, password: &str) -> bool {
196 if user.username != username {
197 return false;
198 }
199
200 match user.hash_type {
201 HashType::Bcrypt => {
202 #[cfg(feature = "bcrypt")]
203 {
204 bcrypt::verify(password, &user.password_hash).unwrap_or(false)
205 }
206 #[cfg(not(feature = "bcrypt"))]
207 {
208 tracing::warn!(
209 "bcrypt password hash found but bcrypt feature is not enabled, rejecting"
210 );
211 false
212 }
213 }
214 HashType::Argon2 => {
215 #[cfg(feature = "argon2")]
216 {
217 use argon2::Argon2;
218 use password_hash::{PasswordHash, PasswordVerifier};
219 let parsed = match PasswordHash::new(&user.password_hash) {
220 Ok(h) => h,
221 Err(_) => return false,
222 };
223 Argon2::default()
224 .verify_password(password.as_bytes(), &parsed)
225 .is_ok()
226 }
227 #[cfg(not(feature = "argon2"))]
228 {
229 tracing::warn!(
230 "argon2 password hash found but argon2 feature is not enabled, rejecting"
231 );
232 false
233 }
234 }
235 HashType::Scrypt => {
236 #[cfg(feature = "scrypt")]
237 {
238 use password_hash::{PasswordHash, PasswordVerifier};
239 use scrypt::Scrypt;
240 let parsed = match PasswordHash::new(&user.password_hash) {
241 Ok(h) => h,
242 Err(_) => return false,
243 };
244 Scrypt.verify_password(password.as_bytes(), &parsed).is_ok()
245 }
246 #[cfg(not(feature = "scrypt"))]
247 {
248 tracing::warn!(
249 "scrypt password hash found but scrypt feature is not enabled, rejecting"
250 );
251 false
252 }
253 }
254 HashType::Pbkdf2 => {
255 #[cfg(feature = "pbkdf2")]
256 {
257 use password_hash::{PasswordHash, PasswordVerifier};
258 use pbkdf2::Pbkdf2;
259 let parsed = match PasswordHash::new(&user.password_hash) {
260 Ok(h) => h,
261 Err(_) => return false,
262 };
263 Pbkdf2.verify_password(password.as_bytes(), &parsed).is_ok()
264 }
265 #[cfg(not(feature = "pbkdf2"))]
266 {
267 tracing::warn!(
268 "pbkdf2 password hash found but pbkdf2 feature is not enabled, rejecting"
269 );
270 false
271 }
272 }
273 HashType::Plaintext => {
274 constant_time_eq(password.as_bytes(), user.password_hash.as_bytes())
276 }
277 }
278}