rust_api/middleware.rs
1//! Request-lifecycle middleware utilities and Tower layer factories.
2//!
3//! This module provides utilities that operate at **request time**, not at
4//! route registration time. These are Tower layer factories — they produce
5//! middleware that wraps individual routes or entire routers.
6//!
7//! # Module Boundaries
8//!
9//! - `pipeline.rs` — build-time: compose routes into a `Router`
10//! - `controller.rs` — build-time: declare which handlers belong to a
11//! controller
12//! - `middleware.rs` — request-time: inspect/modify requests and responses
13//!
14//! # Protected Route Groups
15//!
16//! Auth is a **cross-cutting concern** — it belongs here as a router transform,
17//! not inside a controller handler. Apply it to a route group via `.map()`:
18//!
19//! ```ignore
20//! RouterPipeline::new()
21//! .mount_guarded::<AdminController, _>(admin_svc, || { /* config check */ })
22//! .map(require_bearer(admin_key))
23//! ```
24//!
25//! Or scoped to just a sub-group:
26//!
27//! ```ignore
28//! RouterPipeline::new()
29//! .group("/admin", |g| g
30//! .mount::<AdminController>(admin_svc)
31//! .map(require_bearer(admin_key)) // only admin routes are protected
32//! )
33//! ```
34
35use axum::{body::Body, http::Request, middleware::Next, response::IntoResponse, Router};
36
37// ---------------------------------------------------------------------------
38// require_bearer
39// ---------------------------------------------------------------------------
40
41/// Returns a `Router -> Router` transform that enforces `Authorization: Bearer
42/// <token>` on every request passing through the router it is applied to.
43///
44/// Returns `401 Unauthorized` if the header is absent, malformed, or if the
45/// token does not match `expected` (compared in **constant time** to prevent
46/// timing oracles).
47///
48/// # Usage
49///
50/// Pass directly to `.map()` — the function signature matches `.map()`'s
51/// expected `Fn(Router<()>) -> Router<()>`:
52///
53/// ```ignore
54/// use rust_api::prelude::*;
55///
56/// RouterPipeline::new()
57/// .mount_guarded::<AdminController, _>(admin_svc, || { /* config check */ })
58/// .map(require_bearer(admin_key))
59/// .build()?
60/// ```
61pub fn require_bearer(
62 expected: impl Into<String>,
63) -> impl Fn(Router<()>) -> Router<()> + Clone + Send + 'static {
64 let expected = expected.into();
65 move |router: Router<()>| {
66 let expected = expected.clone();
67 router.layer(axum::middleware::from_fn(
68 move |req: Request<Body>, next: Next| {
69 let expected = expected.clone();
70 async move {
71 let authorized = req
72 .headers()
73 .get(axum::http::header::AUTHORIZATION)
74 .and_then(|v| v.to_str().ok())
75 .and_then(|v| v.strip_prefix("Bearer "))
76 .map(|token| constant_time_eq(token.as_bytes(), expected.as_bytes()))
77 .unwrap_or(false);
78
79 if authorized {
80 next.run(req).await
81 } else {
82 axum::http::StatusCode::UNAUTHORIZED.into_response()
83 }
84 }
85 },
86 ))
87 }
88}
89
90// ---------------------------------------------------------------------------
91// guard
92// ---------------------------------------------------------------------------
93
94/// Returns a `Router -> Router` transform that guards every request with a
95/// predicate.
96///
97/// Returns `403 Forbidden` if `guard_fn(&request)` returns `false`. The
98/// predicate runs before any extractors, so it has access to headers, URI,
99/// and method.
100///
101/// For **authentication**, prefer [`require_bearer`] — it handles the
102/// `Authorization: Bearer` protocol correctly. `guard` is suited for
103/// non-auth predicates (e.g., IP allowlists, feature flags, method
104/// restrictions).
105///
106/// # Usage
107///
108/// Pass directly to `.map()` on the pipeline:
109///
110/// ```ignore
111/// use rust_api::prelude::*;
112///
113/// RouterPipeline::new()
114/// .mount::<MyController>(svc)
115/// .map(guard(|req| is_allowed_ip(req)))
116/// .build()?
117/// ```
118pub fn guard<G>(guard_fn: G) -> impl Fn(Router<()>) -> Router<()> + Clone + Send + 'static
119where
120 G: Fn(&Request<Body>) -> bool + Clone + Send + Sync + 'static,
121{
122 move |router: Router<()>| {
123 let guard_fn = guard_fn.clone();
124 router.layer(axum::middleware::from_fn(
125 move |req: Request<Body>, next: Next| {
126 let guard_fn = guard_fn.clone();
127 async move {
128 if guard_fn(&req) {
129 next.run(req).await
130 } else {
131 axum::http::StatusCode::FORBIDDEN.into_response()
132 }
133 }
134 },
135 ))
136 }
137}
138
139// ---------------------------------------------------------------------------
140// Internal helpers
141// ---------------------------------------------------------------------------
142
/// Byte-slice equality without a content-dependent early exit — intended to
/// avoid timing oracles when comparing secrets.
///
/// Every index up to the longer of the two lengths is visited. Missing bytes
/// of the shorter slice are read as `0`, and all byte differences are OR-ed
/// into one accumulator that is inspected only after the loop, so a mismatch
/// near the start does not end the comparison sooner than one near the end.
///
/// The accumulator is seeded with `a.len() ^ b.len()`. Without that seed the
/// zero-padding would make slices that differ only by trailing NUL bytes
/// (for example `b"abc"` and `b"abc\0"`) compare equal.
///
/// Returns `true` only when both slices have the same length and the same
/// bytes.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    let len = a.len().max(b.len());
    // Non-zero whenever the lengths differ; widened to `usize` so no bits are lost.
    let mut diff: usize = a.len() ^ b.len();
    for i in 0..len {
        let ab = a.get(i).copied().unwrap_or(0);
        let bb = b.get(i).copied().unwrap_or(0);
        diff |= usize::from(ab ^ bb);
    }
    diff == 0
}
158
159// ---------------------------------------------------------------------------
160// Tests
161// ---------------------------------------------------------------------------
162
#[cfg(test)]
mod tests {
    use axum::{
        body::Body,
        http::{Request, StatusCode},
        response::Response,
        routing::get,
        Router,
    };
    use http_body_util::BodyExt;
    use tower::ServiceExt;

    use super::*;

    /// Path served by the bearer-protected test router.
    const PROTECTED: &str = "/protected";
    /// Path served by the predicate-guarded test router.
    const GUARDED: &str = "/guarded";

    // -----------------------------------------------------------------------
    // Shared helpers
    // -----------------------------------------------------------------------

    /// Builds a body-less request for `uri` carrying the given header pairs.
    fn request_with(uri: &str, headers: &[(&str, &str)]) -> Request<Body> {
        let mut builder = Request::builder().uri(uri);
        for (name, value) in headers {
            builder = builder.header(*name, *value);
        }
        builder.body(Body::empty()).unwrap()
    }

    /// Drives `app` with exactly one request and hands back its response.
    async fn send(app: Router<()>, request: Request<Body>) -> Response {
        app.oneshot(request).await.unwrap()
    }

    // -----------------------------------------------------------------------
    // constant_time_eq
    // -----------------------------------------------------------------------

    #[test]
    fn ct_eq_identical_slices() {
        let equal = constant_time_eq(b"secret", b"secret");
        assert!(equal);
    }

    #[test]
    fn ct_eq_different_slices() {
        let equal = constant_time_eq(b"secret", b"wrong!");
        assert!(!equal);
    }

    #[test]
    fn ct_eq_empty_slices() {
        let equal = constant_time_eq(b"", b"");
        assert!(equal);
    }

    #[test]
    fn ct_eq_different_lengths_short_a() {
        let equal = constant_time_eq(b"abc", b"abcd");
        assert!(!equal);
    }

    #[test]
    fn ct_eq_different_lengths_short_b() {
        let equal = constant_time_eq(b"abcd", b"abc");
        assert!(!equal);
    }

    #[test]
    fn ct_eq_empty_vs_nonempty() {
        let equal = constant_time_eq(b"", b"x");
        assert!(!equal);
    }

    // -----------------------------------------------------------------------
    // require_bearer
    // -----------------------------------------------------------------------

    /// A one-route router wrapped in `require_bearer("correct-token")`.
    fn bearer_router() -> Router<()> {
        let routes = Router::new().route(PROTECTED, get(|| async { "ok" }));
        let protect = require_bearer("correct-token");
        protect(routes)
    }

    #[tokio::test]
    async fn bearer_accepts_correct_token() {
        let request = request_with(PROTECTED, &[("Authorization", "Bearer correct-token")]);
        let response = send(bearer_router(), request).await;

        assert_eq!(response.status(), StatusCode::OK);
        let payload = response.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(&payload[..], b"ok");
    }

    #[tokio::test]
    async fn bearer_rejects_wrong_token() {
        let request = request_with(PROTECTED, &[("Authorization", "Bearer wrong-token")]);
        let response = send(bearer_router(), request).await;

        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn bearer_rejects_missing_header() {
        let request = request_with(PROTECTED, &[]);
        let response = send(bearer_router(), request).await;

        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn bearer_rejects_malformed_header() {
        // Token present, but without the `Bearer ` scheme in front of it.
        let request = request_with(PROTECTED, &[("Authorization", "correct-token")]);
        let response = send(bearer_router(), request).await;

        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    // -----------------------------------------------------------------------
    // guard
    // -----------------------------------------------------------------------

    /// A one-route router wrapped in `guard(predicate)`.
    fn guard_router(
        predicate: impl Fn(&Request<Body>) -> bool + Clone + Send + Sync + 'static,
    ) -> Router<()> {
        let routes = Router::new().route(GUARDED, get(|| async { "ok" }));
        let protect = guard(predicate);
        protect(routes)
    }

    #[tokio::test]
    async fn guard_allows_request_when_predicate_is_true() {
        let app = guard_router(|_req| true);
        let response = send(app, request_with(GUARDED, &[])).await;

        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn guard_blocks_request_with_403_when_predicate_is_false() {
        let app = guard_router(|_req| false);
        let response = send(app, request_with(GUARDED, &[])).await;

        assert_eq!(response.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn guard_predicate_receives_live_request_headers() {
        let app = guard_router(|req| req.headers().contains_key("x-allowed"));

        // without header → 403
        let blocked = send(app.clone(), request_with(GUARDED, &[])).await;
        assert_eq!(blocked.status(), StatusCode::FORBIDDEN);

        // with header → 200
        let allowed = send(app, request_with(GUARDED, &[("x-allowed", "yes")])).await;
        assert_eq!(allowed.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn guard_predicate_receives_live_request_uri() {
        // predicate inspects the URI path
        let app = guard_router(|req| req.uri().path().starts_with("/guarded"));
        let response = send(app, request_with(GUARDED, &[])).await;

        assert_eq!(response.status(), StatusCode::OK);
    }
}