Skip to main content

hammerwork_web/
auth.rs

1//! Authentication middleware for the web dashboard.
2//!
3//! This module provides authentication functionality including basic auth verification,
4//! rate limiting for failed attempts, and account lockout mechanisms.
5//!
6//! # Examples
7//!
8//! ## Basic Authentication Setup
9//!
10//! ```rust
11//! use hammerwork_web::auth::{AuthState, extract_basic_auth};
12//! use hammerwork_web::config::AuthConfig;
13//!
14//! let auth_config = AuthConfig {
15//!     enabled: true,
16//!     username: "admin".to_string(),
17//!     password_hash: "plain_password".to_string(), // Use bcrypt in production
18//!     ..Default::default()
19//! };
20//!
21//! let auth_state = AuthState::new(auth_config);
22//! assert!(auth_state.is_enabled());
23//! ```
24//!
25//! ## Extracting Basic Auth Credentials
26//!
27//! ```rust
28//! use hammerwork_web::auth::extract_basic_auth;
29//!
30//! // "admin:password" in base64 is "YWRtaW46cGFzc3dvcmQ="
31//! let auth_header = "Basic YWRtaW46cGFzc3dvcmQ=";
32//! let (username, password) = extract_basic_auth(auth_header).unwrap();
33//!
34//! assert_eq!(username, "admin");
35//! assert_eq!(password, "password");
36//!
37//! // Invalid format returns None
38//! let result = extract_basic_auth("Bearer token123");
39//! assert!(result.is_none());
40//! ```
41
42use crate::config::AuthConfig;
43use base64::Engine;
44use std::collections::HashMap;
45use std::sync::Arc;
46use tokio::sync::RwLock;
47use warp::{Filter, Rejection, Reply};
48
49/// Authentication middleware state
50#[derive(Clone)]
51pub struct AuthState {
52    config: AuthConfig,
53    failed_attempts: Arc<RwLock<HashMap<String, (u32, std::time::Instant)>>>,
54}
55
56impl AuthState {
57    pub fn new(config: AuthConfig) -> Self {
58        Self {
59            config,
60            failed_attempts: Arc::new(RwLock::new(HashMap::new())),
61        }
62    }
63
64    /// Check if authentication is enabled
65    pub fn is_enabled(&self) -> bool {
66        self.config.enabled
67    }
68
69    /// Verify credentials
70    pub async fn verify_credentials(&self, username: &str, password: &str) -> bool {
71        if !self.config.enabled {
72            return true; // No auth required
73        }
74
75        // Check if user is locked out
76        if self.is_locked_out(username).await {
77            return false;
78        }
79
80        let valid = username == self.config.username && self.verify_password(password);
81
82        if !valid {
83            self.record_failed_attempt(username).await;
84        } else {
85            self.clear_failed_attempts(username).await;
86        }
87
88        valid
89    }
90
91    /// Verify password against stored hash
92    fn verify_password(&self, password: &str) -> bool {
93        #[cfg(feature = "auth")]
94        {
95            bcrypt::verify(password, &self.config.password_hash).unwrap_or(false)
96        }
97        #[cfg(not(feature = "auth"))]
98        {
99            // Fallback to plain text comparison (not recommended for production)
100            password == self.config.password_hash
101        }
102    }
103
104    /// Check if user is currently locked out
105    async fn is_locked_out(&self, username: &str) -> bool {
106        let attempts = self.failed_attempts.read().await;
107        if let Some((count, last_attempt)) = attempts.get(username) {
108            if *count >= self.config.max_failed_attempts {
109                let elapsed = last_attempt.elapsed();
110                return elapsed < self.config.lockout_duration;
111            }
112        }
113        false
114    }
115
116    /// Record a failed login attempt
117    async fn record_failed_attempt(&self, username: &str) {
118        let mut attempts = self.failed_attempts.write().await;
119        let default_entry = (0, std::time::Instant::now());
120        let (count, _) = attempts.get(username).unwrap_or(&default_entry);
121        let new_count = *count;
122        attempts.insert(
123            username.to_string(),
124            (new_count + 1, std::time::Instant::now()),
125        );
126    }
127
128    /// Clear failed attempts for successful login
129    async fn clear_failed_attempts(&self, username: &str) {
130        let mut attempts = self.failed_attempts.write().await;
131        attempts.remove(username);
132    }
133
134    /// Clean up old failed attempts periodically
135    pub async fn cleanup_expired_attempts(&self) {
136        let mut attempts = self.failed_attempts.write().await;
137        let now = std::time::Instant::now();
138        attempts.retain(|_, (_, last_attempt)| {
139            now.duration_since(*last_attempt) < self.config.lockout_duration * 2
140        });
141    }
142}
143
144/// Extract basic auth credentials from request.
145///
146/// Parses a Basic Authentication header and returns the username and password.
147/// The header format should be: `Basic <base64-encoded-credentials>`
148/// where credentials are in the format `username:password`.
149///
150/// # Examples
151///
152/// ```rust
153/// use hammerwork_web::auth::extract_basic_auth;
154///
155/// // Valid basic auth header
156/// let auth_header = "Basic YWRtaW46cGFzc3dvcmQ="; // admin:password
157/// let (username, password) = extract_basic_auth(auth_header).unwrap();
158/// assert_eq!(username, "admin");
159/// assert_eq!(password, "password");
160///
161/// // Invalid format returns None
162/// assert!(extract_basic_auth("Bearer token123").is_none());
163/// assert!(extract_basic_auth("Basic invalid_base64").is_none());
164/// ```
165///
166/// # Returns
167///
168/// - `Some((username, password))` if the header is valid
169/// - `None` if the header is malformed or not a Basic auth header
170pub fn extract_basic_auth(auth_header: &str) -> Option<(String, String)> {
171    if !auth_header.starts_with("Basic ") {
172        return None;
173    }
174
175    let encoded = &auth_header[6..];
176    let decoded = ::base64::prelude::BASE64_STANDARD.decode(encoded).ok()?;
177    let decoded_str = String::from_utf8(decoded).ok()?;
178
179    let mut parts = decoded_str.splitn(2, ':');
180    let username = parts.next()?.to_string();
181    let password = parts.next()?.to_string();
182
183    Some((username, password))
184}
185
186/// Authentication filter for Warp
187pub fn auth_filter(
188    auth_state: AuthState,
189) -> impl Filter<Extract = ((),), Error = Rejection> + Clone {
190    warp::header::optional::<String>("authorization").and_then(
191        move |auth_header: Option<String>| {
192            let auth_state = auth_state.clone();
193            async move {
194                if !auth_state.is_enabled() {
195                    return Ok::<_, Rejection>(());
196                }
197
198                let auth_header = auth_header
199                    .ok_or_else(|| warp::reject::custom(AuthError::MissingCredentials))?;
200
201                let (username, password) = extract_basic_auth(&auth_header)
202                    .ok_or_else(|| warp::reject::custom(AuthError::InvalidFormat))?;
203
204                if auth_state.verify_credentials(&username, &password).await {
205                    Ok(())
206                } else {
207                    Err(warp::reject::custom(AuthError::InvalidCredentials))
208                }
209            }
210        },
211    )
212}
213
214/// Custom authentication errors
215#[derive(Debug)]
216pub enum AuthError {
217    MissingCredentials,
218    InvalidFormat,
219    InvalidCredentials,
220    AccountLocked,
221}
222
223impl warp::reject::Reject for AuthError {}
224
225/// Handle authentication rejections
226pub async fn handle_auth_rejection(
227    err: Rejection,
228) -> Result<Box<dyn Reply>, std::convert::Infallible> {
229    if let Some(auth_error) = err.find::<AuthError>() {
230        match auth_error {
231            AuthError::MissingCredentials => {
232                let response = warp::reply::with_header(
233                    warp::reply::with_status(
234                        "Authentication required",
235                        warp::http::StatusCode::UNAUTHORIZED,
236                    ),
237                    "WWW-Authenticate",
238                    "Basic realm=\"Hammerwork Dashboard\"",
239                );
240                Ok(Box::new(response))
241            }
242            AuthError::InvalidFormat => {
243                let error_response = serde_json::json!({"error": "Invalid authentication format"});
244                Ok(Box::new(warp::reply::with_status(
245                    warp::reply::json(&error_response),
246                    warp::http::StatusCode::BAD_REQUEST,
247                )))
248            }
249            AuthError::InvalidCredentials => {
250                let error_response = serde_json::json!({"error": "Invalid credentials"});
251                Ok(Box::new(warp::reply::with_status(
252                    warp::reply::json(&error_response),
253                    warp::http::StatusCode::UNAUTHORIZED,
254                )))
255            }
256            AuthError::AccountLocked => {
257                let error_response = serde_json::json!({"error": "Account temporarily locked"});
258                Ok(Box::new(warp::reply::with_status(
259                    warp::reply::json(&error_response),
260                    warp::http::StatusCode::UNAUTHORIZED,
261                )))
262            }
263        }
264    } else {
265        // Not an auth error, return generic error
266        let error_response = serde_json::json!({"error": "Internal server error"});
267        Ok(Box::new(warp::reply::with_status(
268            warp::reply::json(&error_response),
269            warp::http::StatusCode::INTERNAL_SERVER_ERROR,
270        )))
271    }
272}
273
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Builds an enabled configuration for `username` with the given stored hash;
    /// every other field comes from `AuthConfig::default()`.
    fn enabled_config(username: &str, password_hash: &str) -> AuthConfig {
        AuthConfig {
            enabled: true,
            username: username.to_string(),
            password_hash: password_hash.to_string(),
            ..Default::default()
        }
    }

    #[tokio::test]
    async fn test_auth_state_creation() {
        let auth_state = AuthState::new(enabled_config("testuser", "testhash"));
        assert!(auth_state.is_enabled());
    }

    #[tokio::test]
    async fn test_disabled_auth() {
        let auth_state = AuthState::new(AuthConfig {
            enabled: false,
            ..Default::default()
        });

        assert!(!auth_state.is_enabled());
        // With auth disabled, any pair is accepted.
        assert!(auth_state.verify_credentials("anyone", "anything").await);
    }

    #[tokio::test]
    async fn test_failed_attempts_tracking() {
        let auth_state = AuthState::new(AuthConfig {
            max_failed_attempts: 3,
            lockout_duration: Duration::from_secs(60),
            ..enabled_config("admin", "wronghash")
        });

        // Exhaust the allowance with consecutive bad passwords.
        for attempt in 1..=3 {
            assert!(
                !auth_state.verify_credentials("admin", "wrongpass").await,
                "attempt {} should be rejected",
                attempt
            );
        }

        assert!(auth_state.is_locked_out("admin").await);
    }

    #[test]
    fn test_extract_basic_auth() {
        // "admin:password" in base64 is "YWRtaW46cGFzc3dvcmQ="
        let (username, password) =
            extract_basic_auth("Basic YWRtaW46cGFzc3dvcmQ=").expect("header should parse");
        assert_eq!(username, "admin");
        assert_eq!(password, "password");
    }

    #[test]
    fn test_extract_basic_auth_invalid() {
        for header in ["Bearer token", "Basic invalid"] {
            assert!(
                extract_basic_auth(header).is_none(),
                "{:?} should be rejected",
                header
            );
        }
    }
}