Skip to main content

tuitbot_server/auth/
middleware.rs

//! Multi-strategy authentication middleware.
//!
//! Requests whose path appears in `AUTH_EXEMPT_PATHS` skip authentication.
//! All other requests are checked in order:
//! 1. `Authorization: Bearer <token>` header → matches file-based API token
//! 2. `tuitbot_session` cookie → SHA-256 hash lookup in sessions table
//! 3. Neither succeeds → 401 Unauthorized
//!
//! For cookie-authenticated requests, mutating methods (POST/PATCH/DELETE/PUT)
//! require a valid `X-CSRF-Token` header matching the session's CSRF token;
//! a missing or mismatched token yields 403 Forbidden.
10
11use std::sync::Arc;
12
13use axum::extract::{Request, State};
14use axum::http::{HeaderMap, Method, StatusCode};
15use axum::middleware::Next;
16use axum::response::{IntoResponse, Response};
17use serde_json::json;
18use tuitbot_core::auth::session;
19
20use crate::state::AppState;
21
22/// Extract the session cookie value from headers.
23fn extract_session_cookie(headers: &HeaderMap) -> Option<String> {
24    headers
25        .get("cookie")
26        .and_then(|v| v.to_str().ok())
27        .and_then(|cookies| {
28            cookies.split(';').find_map(|c| {
29                let c = c.trim();
30                c.strip_prefix("tuitbot_session=").map(|v| v.to_string())
31            })
32        })
33}
34
/// Request paths that bypass authentication entirely.
///
/// `auth_middleware` compares the request's URI path against this list by
/// exact string equality. There is no prefix or trailing-slash matching, and
/// the query string is not part of the path. Every route is listed both bare
/// and under `/api`. Presumably this lets the list work whether or not the
/// `/api` prefix is visible to the middleware (confirm against the router
/// setup).
const AUTH_EXEMPT_PATHS: &[&str] = &[
    // Health check.
    "/health",
    "/api/health",
    // Settings: status, init, test-llm.
    "/settings/status",
    "/api/settings/status",
    "/settings/init",
    "/api/settings/init",
    "/settings/test-llm",
    "/api/settings/test-llm",
    // WebSocket endpoint.
    // NOTE(review): confirm the WS handler performs its own authentication.
    "/ws",
    "/api/ws",
    // Login and auth-status endpoints.
    "/auth/login",
    "/api/auth/login",
    "/auth/status",
    "/api/auth/status",
    // Google Drive connector callback.
    "/connectors/google-drive/callback",
    "/api/connectors/google-drive/callback",
    // Media files are served with path-traversal protection (`is_safe_media_path`)
    // and must stay exempt so <img>/<video> `src` attributes work without auth headers.
    "/media/file",
    "/api/media/file",
];
58
59/// Axum middleware that enforces multi-strategy authentication.
60pub async fn auth_middleware(
61    State(state): State<Arc<AppState>>,
62    headers: HeaderMap,
63    request: Request,
64    next: Next,
65) -> Response {
66    let path = request.uri().path();
67
68    // Skip auth for exempt endpoints.
69    if AUTH_EXEMPT_PATHS.contains(&path) {
70        return next.run(request).await;
71    }
72
73    // Strategy 1: Bearer token
74    let bearer_ok = headers
75        .get("authorization")
76        .and_then(|v| v.to_str().ok())
77        .and_then(|v| v.strip_prefix("Bearer "))
78        .is_some_and(|token| token == state.api_token);
79
80    if bearer_ok {
81        return next.run(request).await;
82    }
83
84    // Strategy 2: Session cookie
85    if let Some(session_token) = extract_session_cookie(&headers) {
86        match session::validate_session(&state.db, &session_token).await {
87            Ok(Some(sess)) => {
88                // CSRF check for mutating methods
89                let method = request.method().clone();
90                if method == Method::POST
91                    || method == Method::PATCH
92                    || method == Method::DELETE
93                    || method == Method::PUT
94                {
95                    let csrf_ok = headers
96                        .get("x-csrf-token")
97                        .and_then(|v| v.to_str().ok())
98                        .is_some_and(|t| t == sess.csrf_token);
99
100                    if !csrf_ok {
101                        return (
102                            StatusCode::FORBIDDEN,
103                            axum::Json(json!({"error": "missing or invalid CSRF token"})),
104                        )
105                            .into_response();
106                    }
107                }
108                return next.run(request).await;
109            }
110            Ok(None) => { /* session not found or expired — fall through to 401 */ }
111            Err(e) => {
112                tracing::error!(error = %e, "Session validation failed");
113            }
114        }
115    }
116
117    // Neither strategy succeeded.
118    (
119        StatusCode::UNAUTHORIZED,
120        axum::Json(json!({"error": "unauthorized"})),
121    )
122        .into_response()
123}