bytedocs_rs/core/session_auth.rs

1use crate::core::types::AuthConfig;
2use askama::Template;
3use axum::{
4    extract::Request,
5    http::StatusCode,
6    middleware::Next,
7    response::{Html, IntoResponse, Response},
8};
9use dashmap::DashMap;
10use serde::Deserialize;
11use std::sync::Arc;
12use std::time::{SystemTime, UNIX_EPOCH};
13use subtle::ConstantTimeEq;
14use tokio::time::{interval, Duration};
15use uuid::Uuid;
16
17#[derive(Template)]
18#[template(path = "auth/login.html")]
19struct LoginTemplate {
20    error: String,
21}
22
23#[derive(Template)]
24#[template(path = "auth/banned.html")]
25struct BannedTemplate {
26    max_attempts: i32,
27    ban_duration: i32,
28    client_ip: String,
29    blocked_at: String,
30}
31
32#[derive(Template)]
33#[template(path = "auth/config_error.html")]
34struct ConfigErrorTemplate {
35    error_title: String,
36    error_message: String,
37    error_details: Vec<String>,
38}
39
40#[derive(Deserialize)]
41struct LoginForm {
42    password: String,
43}
44
45pub struct SessionAuthMiddleware {
46    config: AuthConfig,
47    sessions: Arc<DashMap<String, i64>>,  // session ID -> auth time
48    ip_bans: Arc<DashMap<String, i64>>,   // IP -> ban expiry time
49    attempts: Arc<DashMap<String, i32>>,  // IP -> attempt count
50}
51
52impl SessionAuthMiddleware {
53    pub async fn new(config: &AuthConfig) -> anyhow::Result<Self> {
54        if config.r#type != "session" {
55            return Err(anyhow::anyhow!("invalid config for session auth"));
56        }
57
58        let middleware = Self {
59            config: config.clone(),
60            sessions: Arc::new(DashMap::new()),
61            ip_bans: Arc::new(DashMap::new()),
62            attempts: Arc::new(DashMap::new()),
63        };
64
65        // Start cleanup routine
66        middleware.start_cleanup_routine().await;
67
68        Ok(middleware)
69    }
70
71    pub async fn handle(&self, request: Request, next: Next) -> Result<Response, StatusCode> {
72        // Skip auth if disabled
73        if !self.config.enabled {
74            return Ok(next.run(request).await);
75        }
76
77        // Validate that password is configured
78        if self.config.password.is_empty() {
79            return Ok(self.render_config_error().into_response());
80        }
81
82        let ip = self.get_client_ip(&request);
83        let session_id = self.get_session_id(&request);
84
85        // Check if IP is banned
86        if self.is_ip_banned(&ip) {
87            return Ok(self.render_banned(&ip).into_response());
88        }
89
90        // Check if already authenticated
91        if self.is_authenticated(&session_id) {
92            return Ok(next.run(request).await);
93        }
94
95        // Handle POST request (login form submission)
96        if request.method() == "POST" {
97            // Extract form data from request body
98            let body = axum::body::to_bytes(request.into_body(), usize::MAX).await
99                .map_err(|_| StatusCode::BAD_REQUEST)?;
100
101            if let Ok(form_data) = serde_urlencoded::from_bytes::<LoginForm>(&body) {
102                if !form_data.password.is_empty() {
103                    return self.handle_login_simple(next, &ip, &session_id, &form_data.password).await;
104                }
105            }
106        }
107
108        // Show login form
109        Ok(self.render_login("").into_response())
110    }
111
112    fn get_client_ip(&self, request: &Request) -> String {
113        // Check X-Forwarded-For header
114        if let Some(forwarded) = request.headers().get("x-forwarded-for") {
115            if let Ok(forwarded_str) = forwarded.to_str() {
116                if let Some(first_ip) = forwarded_str.split(',').next() {
117                    return first_ip.trim().to_string();
118                }
119            }
120        }
121
122        // Check X-Real-IP header
123        if let Some(real_ip) = request.headers().get("x-real-ip") {
124            if let Ok(real_ip_str) = real_ip.to_str() {
125                return real_ip_str.to_string();
126            }
127        }
128
129        // Fallback to connection info (this would need to be passed in real implementation)
130        "127.0.0.1".to_string()
131    }
132
133    fn get_session_id(&self, request: &Request) -> String {
134        if let Some(cookie_header) = request.headers().get("cookie") {
135            if let Ok(cookie_str) = cookie_header.to_str() {
136                for cookie in cookie_str.split(';') {
137                    let cookie = cookie.trim();
138                    if cookie.starts_with("bytedocs_session=") {
139                        return cookie[17..].to_string();
140                    }
141                }
142            }
143        }
144        String::new()
145    }
146
147    fn is_ip_banned(&self, ip: &str) -> bool {
148        if !self.config.ip_ban_enabled {
149            return false;
150        }
151
152        // Check if IP is whitelisted
153        for whitelist_ip in &self.config.admin_whitelist_ips {
154            if ip == whitelist_ip {
155                return false;
156            }
157        }
158
159        if let Some(ban_expiry) = self.ip_bans.get(ip) {
160            let now = SystemTime::now()
161                .duration_since(UNIX_EPOCH)
162                .unwrap()
163                .as_secs() as i64;
164
165            // Check if ban has expired
166            if now > *ban_expiry {
167                self.ip_bans.remove(ip);
168                self.attempts.remove(ip);
169                return false;
170            }
171            return true;
172        }
173
174        false
175    }
176
177    fn is_authenticated(&self, session_id: &str) -> bool {
178        if session_id.is_empty() {
179            return false;
180        }
181
182        if let Some(auth_time) = self.sessions.get(session_id) {
183            let now = SystemTime::now()
184                .duration_since(UNIX_EPOCH)
185                .unwrap()
186                .as_secs() as i64;
187
188            // Check session expiration
189            let expiration_time = *auth_time + (self.config.session_expire as i64 * 60);
190            if now > expiration_time {
191                self.sessions.remove(session_id);
192                return false;
193            }
194            return true;
195        }
196
197        false
198    }
199
200    async fn handle_login_simple(
201        &self,
202        _next: Next,
203        ip: &str,
204        session_id: &str,
205        password: &str,
206    ) -> Result<Response, StatusCode> {
207        // Check password using constant-time comparison
208        if password.as_bytes().ct_eq(self.config.password.as_bytes()).unwrap_u8() == 1 {
209            // Success - clear attempts and set session
210            self.attempts.remove(ip);
211
212            // Generate session ID if not exists
213            let session_id = if session_id.is_empty() {
214                self.generate_session_id()
215            } else {
216                session_id.to_string()
217            };
218
219            let now = SystemTime::now()
220                .duration_since(UNIX_EPOCH)
221                .unwrap()
222                .as_secs() as i64;
223
224            self.sessions.insert(session_id.clone(), now);
225
226            // Create a redirect response with session cookie
227            let mut response = axum::response::Response::new("Login successful".into());
228            response.headers_mut().insert("location", "/docs".parse().unwrap());
229            *response.status_mut() = StatusCode::SEE_OTHER;
230
231            // Add session cookie to headers
232            let cookie = format!(
233                "bytedocs_session={}; Path=/; HttpOnly; Max-Age={}",
234                session_id,
235                self.config.session_expire * 60
236            );
237
238            response.headers_mut().insert("set-cookie", cookie.parse().unwrap());
239
240            return Ok(response);
241        }
242
243        // Failed login - increment attempts
244        let attempts = self.attempts.entry(ip.to_string())
245            .and_modify(|v| *v += 1)
246            .or_insert(1);
247
248        let current_attempts = *attempts;
249
250        // Ban IP if max attempts reached (unless whitelisted)
251        if current_attempts >= self.config.ip_ban_max_attempts && self.config.ip_ban_enabled {
252            let is_whitelisted = self.config.admin_whitelist_ips.contains(&ip.to_string());
253
254            if !is_whitelisted {
255                let ban_expiry = SystemTime::now()
256                    .duration_since(UNIX_EPOCH)
257                    .unwrap()
258                    .as_secs() as i64 + (self.config.ip_ban_duration as i64 * 60);
259
260                self.ip_bans.insert(ip.to_string(), ban_expiry);
261                self.attempts.remove(ip);
262
263                return Ok(self.render_banned(ip).into_response());
264            } else {
265                // If IP is whitelisted, just reset attempts instead of banning
266                self.attempts.remove(ip);
267            }
268        }
269
270        // Show error
271        let remaining_attempts = self.config.ip_ban_max_attempts - current_attempts;
272        let error_message = format!("Password salah. Sisa percobaan: {}", remaining_attempts);
273
274        Ok(self.render_login(&error_message).into_response())
275    }
276
277    fn render_login(&self, error: &str) -> Html<String> {
278        let template = LoginTemplate {
279            error: error.to_string(),
280        };
281        Html(template.render().unwrap_or_else(|_| "Template error".to_string()))
282    }
283
284    fn render_banned(&self, ip: &str) -> Html<String> {
285        let blocked_at = chrono::Utc::now().format("%Y-%m-%d %H:%M:%S").to_string();
286        let template = BannedTemplate {
287            max_attempts: self.config.ip_ban_max_attempts,
288            ban_duration: self.config.ip_ban_duration,
289            client_ip: ip.to_string(),
290            blocked_at,
291        };
292        Html(template.render().unwrap_or_else(|_| "Template error".to_string()))
293    }
294
295    fn render_config_error(&self) -> Html<String> {
296        let template = ConfigErrorTemplate {
297            error_title: "Authentication Not Configured".to_string(),
298            error_message: "Bytedocs authentication is enabled but no password is configured.".to_string(),
299            error_details: vec![
300                "Please set BYTEDOCS_AUTH_PASSWORD in your environment variables".to_string(),
301                "Or disable authentication by setting BYTEDOCS_AUTH_ENABLED=false".to_string(),
302                "Check your configuration settings".to_string(),
303            ],
304        };
305        Html(template.render().unwrap_or_else(|_| "Template error".to_string()))
306    }
307
308    async fn start_cleanup_routine(&self) {
309        let sessions = Arc::clone(&self.sessions);
310        let ip_bans = Arc::clone(&self.ip_bans);
311        let attempts = Arc::clone(&self.attempts);
312        let session_expire = self.config.session_expire;
313
314        tokio::spawn(async move {
315            let mut interval = interval(Duration::from_secs(600)); // 10 minutes
316
317            loop {
318                interval.tick().await;
319
320                let now = SystemTime::now()
321                    .duration_since(UNIX_EPOCH)
322                    .unwrap()
323                    .as_secs() as i64;
324
325                // Clean up expired sessions
326                sessions.retain(|_, auth_time| {
327                    now <= *auth_time + (session_expire as i64 * 60)
328                });
329
330                // Clean up expired bans
331                ip_bans.retain(|ip, ban_expiry| {
332                    if now > *ban_expiry {
333                        attempts.remove(ip);
334                        false
335                    } else {
336                        true
337                    }
338                });
339            }
340        });
341    }
342
343    fn generate_session_id(&self) -> String {
344        Uuid::new_v4().to_string()
345    }
346}