Skip to main content

gatel_core/hoops/
auth.rs

1use std::net::IpAddr;
2use std::sync::Arc;
3use std::time::{Duration, Instant};
4
5use dashmap::DashMap;
6use http::header::{AUTHORIZATION, WWW_AUTHENTICATE};
7use salvo::http::StatusCode;
8use salvo::{Depot, FlowCtrl, Request, Response, async_trait};
9use tracing::debug;
10
11use crate::config::BasicAuthUser;
12use crate::crypto::constant_time_eq;
13use crate::encoding::base64_decode;
14
15/// Basic authentication middleware.
16///
17/// Checks the `Authorization: Basic <base64>` header against a list of
18/// configured users. Passwords may be stored as:
19///   - Plaintext (if the hash does not start with a known prefix)
20///   - Bcrypt hashes (starting with `$2b$`, `$2a$`, or `$2y$`)
21///   - Argon2 hashes (starting with `$argon2id$` or `$argon2i$`; requires `argon2` feature)
22///   - Scrypt hashes (starting with `$scrypt$`; requires `scrypt` feature)
23///   - PBKDF2 hashes (starting with `$pbkdf2-sha256$`; requires `pbkdf2` feature)
24///
25/// Returns 401 Unauthorized with a `WWW-Authenticate` challenge if
26/// authentication fails.
27///
28/// Optionally enforces IP-based brute-force protection: after
29/// `brute_force_max` consecutive failures the client IP is blocked for
30/// `brute_force_window`.  Returns 429 Too Many Requests when blocked.
31pub struct BasicAuthHoop {
32    users: Vec<AuthUser>,
33    realm: String,
34    /// Maximum number of consecutive failures before lockout.
35    brute_force_max: u32,
36    /// Lockout duration after exceeding `brute_force_max`.
37    brute_force_window: Duration,
38    /// Failure counters per client IP: (consecutive_failures, last_attempt).
39    failure_map: Arc<DashMap<IpAddr, (u32, Instant)>>,
40}
41
/// Storage format of a configured password, detected from its prefix.
///
/// Fieldless and `Copy`, so it can be matched, compared, and logged freely.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum HashType {
    /// The stored value is the password itself.
    Plaintext,
    /// bcrypt hash (`$2a$`, `$2b$` or `$2y$` prefix).
    Bcrypt,
    /// Argon2 PHC string (`$argon2i$` or `$argon2id$` prefix).
    Argon2,
    /// scrypt PHC string (`$scrypt$` prefix).
    Scrypt,
    /// PBKDF2-SHA256 PHC string (`$pbkdf2-sha256$` prefix).
    Pbkdf2,
}
49
50struct AuthUser {
51    username: String,
52    password_hash: String,
53    hash_type: HashType,
54}
55
56impl BasicAuthHoop {
57    pub fn new(
58        users: &[BasicAuthUser],
59        brute_force_max: Option<u32>,
60        brute_force_window: Option<Duration>,
61    ) -> Self {
62        let users = users
63            .iter()
64            .map(|u| {
65                let hash_type = if u.password_hash.starts_with("$2b$")
66                    || u.password_hash.starts_with("$2a$")
67                    || u.password_hash.starts_with("$2y$")
68                {
69                    HashType::Bcrypt
70                } else if u.password_hash.starts_with("$argon2id$")
71                    || u.password_hash.starts_with("$argon2i$")
72                {
73                    HashType::Argon2
74                } else if u.password_hash.starts_with("$scrypt$") {
75                    HashType::Scrypt
76                } else if u.password_hash.starts_with("$pbkdf2-sha256$") {
77                    HashType::Pbkdf2
78                } else {
79                    HashType::Plaintext
80                };
81                AuthUser {
82                    username: u.username.clone(),
83                    password_hash: u.password_hash.clone(),
84                    hash_type,
85                }
86            })
87            .collect();
88        Self {
89            users,
90            realm: "gatel".to_string(),
91            brute_force_max: brute_force_max.unwrap_or(5),
92            brute_force_window: brute_force_window.unwrap_or(Duration::from_secs(300)),
93            failure_map: Arc::new(DashMap::new()),
94        }
95    }
96}
97
98#[async_trait]
99impl salvo::Handler for BasicAuthHoop {
100    async fn handle(
101        &self,
102        req: &mut Request,
103        depot: &mut Depot,
104        res: &mut Response,
105        ctrl: &mut FlowCtrl,
106    ) {
107        let client_ip = super::client_addr(req).ip();
108
109        // Check if this IP is currently blocked by brute-force protection.
110        if let Some(entry) = self.failure_map.get(&client_ip) {
111            let (count, last_attempt) = *entry;
112            if count >= self.brute_force_max && last_attempt.elapsed() < self.brute_force_window {
113                debug!(
114                    ip = %client_ip,
115                    failures = count,
116                    "IP blocked due to brute-force protection, returning 429"
117                );
118                res.status_code(StatusCode::TOO_MANY_REQUESTS);
119                res.body("Too Many Requests");
120                ctrl.skip_rest();
121                return;
122            }
123        }
124
125        // Extract and decode the Authorization header.
126        let credentials = match extract_basic_credentials(req.headers()) {
127            Some(creds) => creds,
128            None => {
129                debug!("no Authorization header, returning 401");
130                record_failure(&self.failure_map, client_ip);
131                set_unauthorized(res, &self.realm);
132                ctrl.skip_rest();
133                return;
134            }
135        };
136
137        // Verify credentials against configured users.
138        let authenticated = self
139            .users
140            .iter()
141            .any(|user| verify_user(user, &credentials.0, &credentials.1));
142
143        if !authenticated {
144            debug!(
145                username = credentials.0.as_str(),
146                "authentication failed, returning 401"
147            );
148            record_failure(&self.failure_map, client_ip);
149            set_unauthorized(res, &self.realm);
150            ctrl.skip_rest();
151            return;
152        }
153
154        // Successful auth — reset failure counter and store user in depot.
155        self.failure_map.remove(&client_ip);
156        debug!(username = credentials.0.as_str(), "authenticated");
157        depot.insert("auth_user", credentials.0.clone());
158        ctrl.call_next(req, depot, res).await;
159    }
160}
161
162/// Increment the consecutive failure counter for a client IP.
163fn record_failure(map: &DashMap<IpAddr, (u32, Instant)>, ip: IpAddr) {
164    let mut entry = map.entry(ip).or_insert((0, Instant::now()));
165    entry.0 += 1;
166    entry.1 = Instant::now();
167}
168
169/// Extract (username, password) from an `Authorization: Basic <base64>` header.
170fn extract_basic_credentials(headers: &http::HeaderMap) -> Option<(String, String)> {
171    let header_value = headers.get(AUTHORIZATION)?.to_str().ok()?;
172    let encoded = header_value.strip_prefix("Basic ")?;
173
174    // Decode base64.
175    let decoded_bytes = base64_decode(encoded)?;
176    let decoded = String::from_utf8(decoded_bytes).ok()?;
177
178    // Split on first ':' — username:password.
179    let (username, password) = decoded.split_once(':')?;
180    Some((username.to_string(), password.to_string()))
181}
182
183/// Set a 401 Unauthorized response with WWW-Authenticate header.
184fn set_unauthorized(res: &mut Response, realm: &str) {
185    res.status_code(StatusCode::UNAUTHORIZED);
186    let _ = res.add_header(
187        WWW_AUTHENTICATE,
188        format!("Basic realm=\"{realm}\", charset=\"UTF-8\""),
189        true,
190    );
191    res.body("Unauthorized");
192}
193
194/// Verify a user's password against their stored hash/plaintext.
195fn verify_user(user: &AuthUser, username: &str, password: &str) -> bool {
196    if user.username != username {
197        return false;
198    }
199
200    match user.hash_type {
201        HashType::Bcrypt => {
202            #[cfg(feature = "bcrypt")]
203            {
204                bcrypt::verify(password, &user.password_hash).unwrap_or(false)
205            }
206            #[cfg(not(feature = "bcrypt"))]
207            {
208                tracing::warn!(
209                    "bcrypt password hash found but bcrypt feature is not enabled, rejecting"
210                );
211                false
212            }
213        }
214        HashType::Argon2 => {
215            #[cfg(feature = "argon2")]
216            {
217                use argon2::Argon2;
218                use password_hash::{PasswordHash, PasswordVerifier};
219                let parsed = match PasswordHash::new(&user.password_hash) {
220                    Ok(h) => h,
221                    Err(_) => return false,
222                };
223                Argon2::default()
224                    .verify_password(password.as_bytes(), &parsed)
225                    .is_ok()
226            }
227            #[cfg(not(feature = "argon2"))]
228            {
229                tracing::warn!(
230                    "argon2 password hash found but argon2 feature is not enabled, rejecting"
231                );
232                false
233            }
234        }
235        HashType::Scrypt => {
236            #[cfg(feature = "scrypt")]
237            {
238                use password_hash::{PasswordHash, PasswordVerifier};
239                use scrypt::Scrypt;
240                let parsed = match PasswordHash::new(&user.password_hash) {
241                    Ok(h) => h,
242                    Err(_) => return false,
243                };
244                Scrypt.verify_password(password.as_bytes(), &parsed).is_ok()
245            }
246            #[cfg(not(feature = "scrypt"))]
247            {
248                tracing::warn!(
249                    "scrypt password hash found but scrypt feature is not enabled, rejecting"
250                );
251                false
252            }
253        }
254        HashType::Pbkdf2 => {
255            #[cfg(feature = "pbkdf2")]
256            {
257                use password_hash::{PasswordHash, PasswordVerifier};
258                use pbkdf2::Pbkdf2;
259                let parsed = match PasswordHash::new(&user.password_hash) {
260                    Ok(h) => h,
261                    Err(_) => return false,
262                };
263                Pbkdf2.verify_password(password.as_bytes(), &parsed).is_ok()
264            }
265            #[cfg(not(feature = "pbkdf2"))]
266            {
267                tracing::warn!(
268                    "pbkdf2 password hash found but pbkdf2 feature is not enabled, rejecting"
269                );
270                false
271            }
272        }
273        HashType::Plaintext => {
274            // Plaintext comparison — constant-time-ish via byte comparison.
275            constant_time_eq(password.as_bytes(), user.password_hash.as_bytes())
276        }
277    }
278}