Skip to main content

fraiseql_server/middleware/
auth.rs

1//! Authentication middleware.
2//!
3//! Provides bearer token authentication for protected endpoints.
4
5use std::sync::Arc;
6
7use axum::{
8    body::Body,
9    extract::State,
10    http::{Request, StatusCode, header},
11    middleware::Next,
12    response::{IntoResponse, Response},
13};
14use subtle::ConstantTimeEq as _;
15
/// State consumed by [`bearer_auth_middleware`]: the one token every request must present.
///
/// The token sits behind an [`Arc`], so cloning the state only bumps a reference
/// count instead of copying the secret.
#[derive(Clone)]
pub struct BearerAuthState {
    /// Token that the `Authorization: Bearer <token>` header is compared against.
    pub token: Arc<String>,
}

impl BearerAuthState {
    /// Wraps `token` in an [`Arc`] and returns the resulting state.
    #[must_use]
    pub fn new(token: String) -> Self {
        let token = Arc::new(token);
        Self { token }
    }
}
32
33/// Bearer token authentication middleware.
34///
35/// Validates that requests include a valid `Authorization: Bearer <token>` header.
36///
37/// # Response
38///
39/// - **401 Unauthorized**: Missing or malformed Authorization header
40/// - **403 Forbidden**: Invalid token
41///
42/// # Example
43///
44/// ```text
45/// // Requires: running Axum application with a route handler.
46/// use axum::{Router, middleware};
47/// use fraiseql_server::middleware::{bearer_auth_middleware, BearerAuthState};
48///
49/// let auth_state = BearerAuthState::new("my-secret-token".to_string());
50///
51/// let app = Router::new()
52///     .route("/protected", get(handler))
53///     .layer(middleware::from_fn_with_state(auth_state, bearer_auth_middleware));
54/// ```
55pub async fn bearer_auth_middleware(
56    State(auth_state): State<BearerAuthState>,
57    request: Request<Body>,
58    next: Next,
59) -> Response {
60    // Extract Authorization header
61    let auth_header = request
62        .headers()
63        .get(header::AUTHORIZATION)
64        .and_then(|value| value.to_str().ok());
65
66    match auth_header {
67        None => {
68            return (
69                StatusCode::UNAUTHORIZED,
70                [(header::WWW_AUTHENTICATE, "Bearer")],
71                "Missing Authorization header",
72            )
73                .into_response();
74        },
75        Some(header_value) => {
76            // Check for "Bearer " prefix
77            if !header_value.starts_with("Bearer ") {
78                return (
79                    StatusCode::UNAUTHORIZED,
80                    [(header::WWW_AUTHENTICATE, "Bearer")],
81                    "Invalid Authorization header format. Expected: Bearer <token>",
82                )
83                    .into_response();
84            }
85
86            // Extract token
87            let token = &header_value[7..]; // Skip "Bearer "
88
89            // Constant-time comparison to prevent timing attacks
90            if !constant_time_compare(token, &auth_state.token) {
91                return (StatusCode::FORBIDDEN, "Invalid token").into_response();
92            }
93        },
94    }
95
96    // Token is valid, proceed with request
97    next.run(request).await
98}
99
/// Pull the credentials out of an `Authorization` header value that uses the
/// `Bearer` scheme.
///
/// The value is split at its first space. When the part before that space is
/// exactly `Bearer` (case-sensitive), everything after the space is returned
/// unchanged, including an empty string for `"Bearer "`. Any other shape
/// (no space, another scheme such as `Basic`, leading whitespace, different
/// casing) yields `None`.
///
/// Exposed as `pub` for property testing.
pub fn extract_bearer_token(header_value: &str) -> Option<&str> {
    header_value
        .split_once(' ')
        .filter(|(scheme, _)| *scheme == "Bearer")
        .map(|(_, credentials)| credentials)
}
109
110/// Constant-time string comparison to prevent timing attacks.
111///
112/// Uses [`subtle::ConstantTimeEq`] to compare the byte representations of
113/// both strings, preventing the compiler from optimising the comparison into
114/// an early-exit branch that would leak information about where the strings
115/// differ (timing oracle, RFC 6749 §10.12).
116///
117/// Strings of different lengths return `false` without inspecting bytes;
118/// token lengths are considered non-secret (administrators choose them).
119fn constant_time_compare(a: &str, b: &str) -> bool {
120    a.as_bytes().ct_eq(b.as_bytes()).into()
121}
122
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)] // Reason: test code, panics acceptable
    #![allow(clippy::cast_precision_loss)] // Reason: test metrics reporting
    #![allow(clippy::cast_sign_loss)] // Reason: test data uses small positive integers
    #![allow(clippy::cast_possible_truncation)] // Reason: test data values are bounded
    #![allow(clippy::cast_possible_wrap)] // Reason: test data values are bounded
    #![allow(clippy::missing_panics_doc)] // Reason: test helpers
    #![allow(clippy::missing_errors_doc)] // Reason: test helpers
    #![allow(missing_docs)] // Reason: test code
    #![allow(clippy::items_after_statements)] // Reason: test helpers defined near use site

    use axum::{
        Router,
        body::Body,
        http::{HeaderValue, Request, StatusCode},
        middleware,
        routing::get,
    };
    use tower::ServiceExt;

    use super::*;

    /// Token the test application is configured with.
    const TEST_TOKEN: &str = "secret-token-12345";

    async fn protected_handler() -> &'static str {
        "secret data"
    }

    fn create_test_app(token: &str) -> Router {
        let auth_state = BearerAuthState::new(token.to_string());

        Router::new()
            .route("/protected", get(protected_handler))
            .layer(middleware::from_fn_with_state(auth_state, bearer_auth_middleware))
    }

    /// Wraps a static string as an `Authorization` header value.
    fn header_value(raw: &'static str) -> Option<HeaderValue> {
        Some(HeaderValue::from_static(raw))
    }

    /// Sends `GET /protected` to an app configured with [`TEST_TOKEN`], attaching
    /// `authorization` as the `Authorization` header when one is given.
    async fn call_protected(authorization: Option<HeaderValue>) -> axum::response::Response {
        let mut builder = Request::builder().uri("/protected");
        if let Some(value) = authorization {
            builder = builder.header("Authorization", value);
        }
        let request = builder.body(Body::empty()).unwrap();

        create_test_app(TEST_TOKEN).oneshot(request).await.unwrap()
    }

    // ── middleware status codes ──────────────────────────────────────────────

    #[tokio::test]
    async fn test_valid_token_allows_access() {
        let response = call_protected(header_value("Bearer secret-token-12345")).await;

        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_missing_auth_header_returns_401() {
        let response = call_protected(None).await;

        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
        assert!(response.headers().contains_key("www-authenticate"));
    }

    #[tokio::test]
    async fn test_invalid_auth_format_returns_401() {
        // Basic auth, not Bearer
        let response = call_protected(header_value("Basic dXNlcjpwYXNz")).await;

        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
        assert!(response.headers().contains_key("www-authenticate"));
    }

    #[tokio::test]
    async fn test_unreadable_auth_header_returns_401() {
        // 0xFF is a legal header byte but makes `HeaderValue::to_str` fail,
        // so the middleware treats the header as absent.
        let opaque = HeaderValue::from_bytes(b"Bearer \xff").unwrap();

        let response = call_protected(Some(opaque)).await;

        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
        assert!(response.headers().contains_key("www-authenticate"));
    }

    #[tokio::test]
    async fn test_wrong_token_returns_403() {
        let response = call_protected(header_value("Bearer wrong-token")).await;

        assert_eq!(response.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn test_token_prefix_or_extension_returns_403() {
        // One byte short of, and one byte longer than, the configured token.
        let shorter = call_protected(header_value("Bearer secret-token-1234")).await;
        assert_eq!(shorter.status(), StatusCode::FORBIDDEN);

        let longer = call_protected(header_value("Bearer secret-token-123456")).await;
        assert_eq!(longer.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn test_empty_bearer_token_returns_403() {
        let response = call_protected(header_value("Bearer ")).await;

        assert_eq!(response.status(), StatusCode::FORBIDDEN);
    }

    // ── extract_bearer_token ────────────────────────────────────────────────

    #[test]
    fn test_extract_bearer_token_accepts_bearer_prefix() {
        assert_eq!(extract_bearer_token("Bearer abc"), Some("abc"));
        assert_eq!(extract_bearer_token("Bearer "), Some(""));
        assert_eq!(extract_bearer_token("Bearer a b"), Some("a b"));
    }

    #[test]
    fn test_extract_bearer_token_rejects_other_formats() {
        assert_eq!(extract_bearer_token(""), None);
        assert_eq!(extract_bearer_token("Bearer"), None);
        assert_eq!(extract_bearer_token("bearer abc"), None);
        assert_eq!(extract_bearer_token("Basic dXNlcjpwYXNz"), None);
        assert_eq!(extract_bearer_token(" Bearer abc"), None);
    }

    // ── constant_time_compare ───────────────────────────────────────────────

    #[test]
    fn test_constant_time_compare_equal() {
        assert!(constant_time_compare("hello", "hello"));
        assert!(constant_time_compare("", ""));
        assert!(constant_time_compare("a-long-token-123", "a-long-token-123"));
    }

    #[test]
    fn test_constant_time_compare_not_equal() {
        assert!(!constant_time_compare("hello", "world"));
        assert!(!constant_time_compare("hello", "hello!"));
        assert!(!constant_time_compare("hello", "hell"));
        assert!(!constant_time_compare("abc", "abd"));
    }

    #[test]
    fn test_constant_time_compare_different_lengths() {
        assert!(!constant_time_compare("short", "longer-string"));
        assert!(!constant_time_compare("", "notempty"));
    }

    // ── subtle-based comparison tests (15-1) ────────────────────────────────

    #[test]
    fn test_subtle_compare_identical_tokens() {
        // Verify the subtle-based helper accepts identical tokens of various lengths.
        assert!(constant_time_compare("x", "x"));
        assert!(constant_time_compare(
            "super-secret-32-char-admin-token",
            "super-secret-32-char-admin-token"
        ));
    }

    #[test]
    fn test_subtle_compare_off_by_one_byte() {
        // A single byte difference anywhere must be rejected.
        assert!(!constant_time_compare("token-abc", "token-abd")); // last byte differs
        assert!(!constant_time_compare("Aoken-abc", "token-abc")); // first byte differs
    }

    #[test]
    fn test_subtle_compare_empty_strings() {
        // Two empty strings are equal; empty vs non-empty is not.
        assert!(constant_time_compare("", ""));
        assert!(!constant_time_compare("", "a"));
        assert!(!constant_time_compare("a", ""));
    }
}