fraiseql_server/middleware/
auth.rs1use std::sync::Arc;
6
7use axum::{
8 body::Body,
9 extract::State,
10 http::{Request, StatusCode, header},
11 middleware::Next,
12 response::{IntoResponse, Response},
13};
14use subtle::ConstantTimeEq as _;
15
/// Shared state for [`bearer_auth_middleware`]: the single token that requests must present.
///
/// Cloning is cheap: the token lives behind an [`Arc`], so every clone shares one allocation.
#[derive(Clone)]
pub struct BearerAuthState {
    /// The expected bearer token.
    pub token: Arc<String>,
}

impl BearerAuthState {
    /// Creates auth state that accepts exactly `token`.
    #[must_use]
    pub fn new(token: String) -> Self {
        Self {
            token: Arc::new(token),
        }
    }
}

/// Hand-written instead of derived so that `{:?}` (logs, panic messages, tracing spans)
/// never prints the secret token.
impl std::fmt::Debug for BearerAuthState {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("BearerAuthState")
            .field("token", &"<redacted>")
            .finish()
    }
}
32
33pub async fn bearer_auth_middleware(
56 State(auth_state): State<BearerAuthState>,
57 request: Request<Body>,
58 next: Next,
59) -> Response {
60 let auth_header = request
62 .headers()
63 .get(header::AUTHORIZATION)
64 .and_then(|value| value.to_str().ok());
65
66 match auth_header {
67 None => {
68 return (
69 StatusCode::UNAUTHORIZED,
70 [(header::WWW_AUTHENTICATE, "Bearer")],
71 "Missing Authorization header",
72 )
73 .into_response();
74 },
75 Some(header_value) => {
76 if !header_value.starts_with("Bearer ") {
78 return (
79 StatusCode::UNAUTHORIZED,
80 [(header::WWW_AUTHENTICATE, "Bearer")],
81 "Invalid Authorization header format. Expected: Bearer <token>",
82 )
83 .into_response();
84 }
85
86 let token = &header_value[7..]; if !constant_time_compare(token, &auth_state.token) {
91 return (StatusCode::FORBIDDEN, "Invalid token").into_response();
92 }
93 },
94 }
95
96 next.run(request).await
98}
99
/// Extracts the credential from an `Authorization` header value that uses the `Bearer` scheme.
///
/// Matching rules:
/// - The scheme name is matched case-insensitively (RFC 7235 §2.1).
/// - Any run of spaces between the scheme and the credential is skipped
///   (RFC 6750 §2.1: `"Bearer" 1*SP b64token`).
///
/// Returns `None` when the value has no space-separated scheme, or when the scheme is not
/// `Bearer`. Returns `Some("")` for a value such as `"Bearer "`, so callers must still reject
/// an empty token.
#[must_use]
pub fn extract_bearer_token(header_value: &str) -> Option<&str> {
    let (scheme, credentials) = header_value.split_once(' ')?;
    scheme
        .eq_ignore_ascii_case("Bearer")
        .then(|| credentials.trim_start_matches(' '))
}
109
110fn constant_time_compare(a: &str, b: &str) -> bool {
120 a.as_bytes().ct_eq(b.as_bytes()).into()
121}
122
#[cfg(test)]
mod tests {
    // Lint policy that is deliberately relaxed for test code: unwraps on known-good
    // fixtures, undocumented helpers, and similar.
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::cast_precision_loss)]
    #![allow(clippy::cast_sign_loss)]
    #![allow(clippy::cast_possible_truncation)]
    #![allow(clippy::cast_possible_wrap)]
    #![allow(clippy::missing_panics_doc)]
    #![allow(clippy::missing_errors_doc)]
    #![allow(missing_docs)]
    #![allow(clippy::items_after_statements)]

    use axum::{
        Router,
        body::Body,
        http::{Request, StatusCode},
        middleware,
        routing::get,
    };
    use tower::ServiceExt;

    use super::*;

    /// Token that every test app is configured to accept.
    const EXPECTED_TOKEN: &str = "secret-token-12345";

    async fn protected_handler() -> &'static str {
        "secret data"
    }

    /// Builds a router whose only route, `/protected`, sits behind the bearer middleware.
    fn create_test_app(token: &str) -> Router {
        Router::new()
            .route("/protected", get(protected_handler))
            .layer(middleware::from_fn_with_state(
                BearerAuthState::new(token.to_owned()),
                bearer_auth_middleware,
            ))
    }

    /// Sends `GET /protected` to a fresh app. When `authorization` is `Some`, its value is
    /// attached as the `Authorization` header.
    async fn call_protected(authorization: Option<&str>) -> axum::response::Response {
        let mut builder = Request::builder().uri("/protected");
        if let Some(value) = authorization {
            builder = builder.header("Authorization", value);
        }
        let request = builder.body(Body::empty()).unwrap();
        create_test_app(EXPECTED_TOKEN).oneshot(request).await.unwrap()
    }

    #[tokio::test]
    async fn test_valid_token_allows_access() {
        let response = call_protected(Some("Bearer secret-token-12345")).await;
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_missing_auth_header_returns_401() {
        let response = call_protected(None).await;
        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
        assert!(response.headers().contains_key("www-authenticate"));
    }

    #[tokio::test]
    async fn test_invalid_auth_format_returns_401() {
        let response = call_protected(Some("Basic dXNlcjpwYXNz")).await;
        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_wrong_token_returns_403() {
        let response = call_protected(Some("Bearer wrong-token")).await;
        assert_eq!(response.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn test_empty_bearer_token_returns_403() {
        let response = call_protected(Some("Bearer ")).await;
        assert_eq!(response.status(), StatusCode::FORBIDDEN);
    }

    #[test]
    fn test_constant_time_compare_equal() {
        for value in ["hello", "", "a-long-token-123"] {
            assert!(constant_time_compare(value, value), "{value:?} should equal itself");
        }
    }

    #[test]
    fn test_constant_time_compare_not_equal() {
        let pairs = [("hello", "world"), ("hello", "hello!"), ("hello", "hell"), ("abc", "abd")];
        for (left, right) in pairs {
            assert!(!constant_time_compare(left, right), "{left:?} should differ from {right:?}");
        }
    }

    #[test]
    fn test_constant_time_compare_different_lengths() {
        for (left, right) in [("short", "longer-string"), ("", "notempty")] {
            assert!(!constant_time_compare(left, right));
        }
    }

    #[test]
    fn test_subtle_compare_identical_tokens() {
        for token in ["x", "super-secret-32-char-admin-token"] {
            assert!(constant_time_compare(token, token));
        }
    }

    #[test]
    fn test_subtle_compare_off_by_one_byte() {
        // A difference in the last byte and a difference in the first byte must both be
        // detected.
        assert!(!constant_time_compare("token-abc", "token-abd"));
        assert!(!constant_time_compare("Aoken-abc", "token-abc"));
    }

    #[test]
    fn test_subtle_compare_empty_strings() {
        assert!(constant_time_compare("", ""));
        assert!(!constant_time_compare("", "a"));
        assert!(!constant_time_compare("a", ""));
    }
}
277}