1use crate::config::AuthConfig;
43use base64::Engine;
44use std::collections::HashMap;
45use std::sync::Arc;
46use tokio::sync::RwLock;
47use warp::{Filter, Rejection, Reply};
48
/// Shared state for HTTP Basic authentication: the configured credentials and
/// lockout policy, plus a per-username table of recent failed login attempts.
///
/// `Clone` clones the config, while the failed-attempt table sits behind an
/// `Arc`, so every clone shares (and updates) the same lockout history.
#[derive(Clone)]
pub struct AuthState {
    /// Credentials and lockout policy (`enabled`, `username`, `password_hash`,
    /// `max_failed_attempts`, `lockout_duration` are read by the impl).
    config: AuthConfig,
    /// Username -> (failed attempts since the last success, time of the most
    /// recent failure).
    failed_attempts: Arc<RwLock<HashMap<String, (u32, std::time::Instant)>>>,
}
55
56impl AuthState {
57 pub fn new(config: AuthConfig) -> Self {
58 Self {
59 config,
60 failed_attempts: Arc::new(RwLock::new(HashMap::new())),
61 }
62 }
63
64 pub fn is_enabled(&self) -> bool {
66 self.config.enabled
67 }
68
69 pub async fn verify_credentials(&self, username: &str, password: &str) -> bool {
71 if !self.config.enabled {
72 return true; }
74
75 if self.is_locked_out(username).await {
77 return false;
78 }
79
80 let valid = username == self.config.username && self.verify_password(password);
81
82 if !valid {
83 self.record_failed_attempt(username).await;
84 } else {
85 self.clear_failed_attempts(username).await;
86 }
87
88 valid
89 }
90
91 fn verify_password(&self, password: &str) -> bool {
93 #[cfg(feature = "auth")]
94 {
95 bcrypt::verify(password, &self.config.password_hash).unwrap_or(false)
96 }
97 #[cfg(not(feature = "auth"))]
98 {
99 password == self.config.password_hash
101 }
102 }
103
104 async fn is_locked_out(&self, username: &str) -> bool {
106 let attempts = self.failed_attempts.read().await;
107 if let Some((count, last_attempt)) = attempts.get(username) {
108 if *count >= self.config.max_failed_attempts {
109 let elapsed = last_attempt.elapsed();
110 return elapsed < self.config.lockout_duration;
111 }
112 }
113 false
114 }
115
116 async fn record_failed_attempt(&self, username: &str) {
118 let mut attempts = self.failed_attempts.write().await;
119 let default_entry = (0, std::time::Instant::now());
120 let (count, _) = attempts.get(username).unwrap_or(&default_entry);
121 let new_count = *count;
122 attempts.insert(
123 username.to_string(),
124 (new_count + 1, std::time::Instant::now()),
125 );
126 }
127
128 async fn clear_failed_attempts(&self, username: &str) {
130 let mut attempts = self.failed_attempts.write().await;
131 attempts.remove(username);
132 }
133
134 pub async fn cleanup_expired_attempts(&self) {
136 let mut attempts = self.failed_attempts.write().await;
137 let now = std::time::Instant::now();
138 attempts.retain(|_, (_, last_attempt)| {
139 now.duration_since(*last_attempt) < self.config.lockout_duration * 2
140 });
141 }
142}
143
144pub fn extract_basic_auth(auth_header: &str) -> Option<(String, String)> {
171 if !auth_header.starts_with("Basic ") {
172 return None;
173 }
174
175 let encoded = &auth_header[6..];
176 let decoded = ::base64::prelude::BASE64_STANDARD.decode(encoded).ok()?;
177 let decoded_str = String::from_utf8(decoded).ok()?;
178
179 let mut parts = decoded_str.splitn(2, ':');
180 let username = parts.next()?.to_string();
181 let password = parts.next()?.to_string();
182
183 Some((username, password))
184}
185
186pub fn auth_filter(
188 auth_state: AuthState,
189) -> impl Filter<Extract = ((),), Error = Rejection> + Clone {
190 warp::header::optional::<String>("authorization").and_then(
191 move |auth_header: Option<String>| {
192 let auth_state = auth_state.clone();
193 async move {
194 if !auth_state.is_enabled() {
195 return Ok::<_, Rejection>(());
196 }
197
198 let auth_header = auth_header
199 .ok_or_else(|| warp::reject::custom(AuthError::MissingCredentials))?;
200
201 let (username, password) = extract_basic_auth(&auth_header)
202 .ok_or_else(|| warp::reject::custom(AuthError::InvalidFormat))?;
203
204 if auth_state.verify_credentials(&username, &password).await {
205 Ok(())
206 } else {
207 Err(warp::reject::custom(AuthError::InvalidCredentials))
208 }
209 }
210 },
211 )
212}
213
/// Reasons the authentication filter can reject a request.
///
/// Each variant is mapped to an HTTP response by [`handle_auth_rejection`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AuthError {
    /// No `Authorization` header was sent.
    MissingCredentials,
    /// An `Authorization` header was sent but is not a parseable Basic credential.
    InvalidFormat,
    /// The username/password pair did not match the configured credentials.
    InvalidCredentials,
    /// Intended for a username that is temporarily locked out after repeated failures.
    AccountLocked,
}
222
// Marker impl so `AuthError` can be wrapped in a warp `Rejection` with
// `warp::reject::custom` and recovered later with `Rejection::find`.
impl warp::reject::Reject for AuthError {}
224
225pub async fn handle_auth_rejection(
227 err: Rejection,
228) -> Result<Box<dyn Reply>, std::convert::Infallible> {
229 if let Some(auth_error) = err.find::<AuthError>() {
230 match auth_error {
231 AuthError::MissingCredentials => {
232 let response = warp::reply::with_header(
233 warp::reply::with_status(
234 "Authentication required",
235 warp::http::StatusCode::UNAUTHORIZED,
236 ),
237 "WWW-Authenticate",
238 "Basic realm=\"Hammerwork Dashboard\"",
239 );
240 Ok(Box::new(response))
241 }
242 AuthError::InvalidFormat => {
243 let error_response = serde_json::json!({"error": "Invalid authentication format"});
244 Ok(Box::new(warp::reply::with_status(
245 warp::reply::json(&error_response),
246 warp::http::StatusCode::BAD_REQUEST,
247 )))
248 }
249 AuthError::InvalidCredentials => {
250 let error_response = serde_json::json!({"error": "Invalid credentials"});
251 Ok(Box::new(warp::reply::with_status(
252 warp::reply::json(&error_response),
253 warp::http::StatusCode::UNAUTHORIZED,
254 )))
255 }
256 AuthError::AccountLocked => {
257 let error_response = serde_json::json!({"error": "Account temporarily locked"});
258 Ok(Box::new(warp::reply::with_status(
259 warp::reply::json(&error_response),
260 warp::http::StatusCode::UNAUTHORIZED,
261 )))
262 }
263 }
264 } else {
265 let error_response = serde_json::json!({"error": "Internal server error"});
267 Ok(Box::new(warp::reply::with_status(
268 warp::reply::json(&error_response),
269 warp::http::StatusCode::INTERNAL_SERVER_ERROR,
270 )))
271 }
272}
273
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Builds an enabled configuration for `user` whose stored secret is `secret`;
    /// every other setting keeps its default.
    fn enabled_config(user: &str, secret: &str) -> AuthConfig {
        AuthConfig {
            enabled: true,
            username: user.to_string(),
            password_hash: secret.to_string(),
            ..Default::default()
        }
    }

    #[tokio::test]
    async fn test_auth_state_creation() {
        let state = AuthState::new(enabled_config("testuser", "testhash"));
        assert!(state.is_enabled());
    }

    #[tokio::test]
    async fn test_disabled_auth() {
        let state = AuthState::new(AuthConfig {
            enabled: false,
            ..Default::default()
        });

        assert!(!state.is_enabled());
        // With auth off, any credentials are accepted.
        assert!(state.verify_credentials("anyone", "anything").await);
    }

    #[tokio::test]
    async fn test_failed_attempts_tracking() {
        let state = AuthState::new(AuthConfig {
            max_failed_attempts: 3,
            lockout_duration: Duration::from_secs(60),
            ..enabled_config("admin", "wronghash")
        });

        // Exactly `max_failed_attempts` bad logins should trigger the lockout.
        for attempt in 1..=3 {
            let accepted = state.verify_credentials("admin", "wrongpass").await;
            assert!(!accepted, "attempt {} should have been rejected", attempt);
        }

        assert!(state.is_locked_out("admin").await);
    }

    #[test]
    fn test_extract_basic_auth() {
        // "YWRtaW46cGFzc3dvcmQ=" is base64 for "admin:password".
        let parsed = extract_basic_auth("Basic YWRtaW46cGFzc3dvcmQ=");
        assert_eq!(
            parsed,
            Some(("admin".to_string(), "password".to_string()))
        );
    }

    #[test]
    fn test_extract_basic_auth_invalid() {
        // A non-Basic scheme and a malformed base64 payload must both be refused.
        for header in &["Bearer token", "Basic invalid"] {
            assert!(extract_basic_auth(header).is_none());
        }
    }
}