1use std::pin::Pin;
11use std::sync::Arc;
12use std::task::{Context, Poll};
13
14use axum::body::Body;
15use axum::response::IntoResponse;
16use http::Request;
17use tower::{Layer, Service};
18
19use crate::Error;
20use crate::auth::apikey::ApiKeyMeta;
21use crate::auth::role::Role;
22
23pub fn require_role(roles: impl IntoIterator<Item = impl Into<String>>) -> RequireRoleLayer {
57 RequireRoleLayer {
58 roles: Arc::new(roles.into_iter().map(Into::into).collect()),
59 }
60}
61
/// Tower [`Layer`] produced by [`require_role`]; wraps a service in a [`RequireRoleService`].
///
/// The allowed role names sit behind an [`Arc`], so cloning the layer (and every service it
/// produces) shares one list instead of copying it.
#[derive(Clone, Debug)]
pub struct RequireRoleLayer {
    /// Role names a request's [`Role`] must equal to be admitted.
    roles: Arc<Vec<String>>,
}
74
75impl<S> Layer<S> for RequireRoleLayer {
76 type Service = RequireRoleService<S>;
77
78 fn layer(&self, inner: S) -> Self::Service {
79 RequireRoleService {
80 inner,
81 roles: self.roles.clone(),
82 }
83 }
84}
85
/// Service produced by [`RequireRoleLayer`]: checks the request's [`Role`] against an
/// allow-list before forwarding it to `inner`.
///
/// `Clone` is derived and therefore requires `S: Clone`, exactly like the bound tower needs for
/// the clone-and-replace pattern in `call`.
#[derive(Clone, Debug)]
pub struct RequireRoleService<S> {
    /// The wrapped service; only called for admitted requests.
    inner: S,
    /// Allowed role names, shared with the layer that created this service.
    roles: Arc<Vec<String>>,
}
100
101impl<S> Service<Request<Body>> for RequireRoleService<S>
102where
103 S: Service<Request<Body>, Response = http::Response<Body>> + Clone + Send + 'static,
104 S::Future: Send + 'static,
105 S::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send + 'static,
106{
107 type Response = http::Response<Body>;
108 type Error = S::Error;
109 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
110
111 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
112 self.inner.poll_ready(cx)
113 }
114
115 fn call(&mut self, request: Request<Body>) -> Self::Future {
116 let roles = self.roles.clone();
117 let mut inner = self.inner.clone();
118 std::mem::swap(&mut self.inner, &mut inner);
119
120 Box::pin(async move {
121 let role = match request.extensions().get::<Role>() {
122 Some(r) => r,
123 None => {
124 return Ok(Error::unauthorized("authentication required").into_response());
125 }
126 };
127
128 if !roles.iter().any(|allowed| allowed == role.as_str()) {
129 return Ok(Error::forbidden("insufficient role").into_response());
130 }
131
132 inner.call(request).await
133 })
134 }
135}
136
137pub fn require_authenticated() -> RequireAuthenticatedLayer {
168 RequireAuthenticatedLayer
169}
170
/// Tower [`Layer`] produced by [`require_authenticated`]; wraps a service in a
/// [`RequireAuthenticatedService`]. It carries no configuration.
#[derive(Clone, Debug, Default)]
pub struct RequireAuthenticatedLayer;
179
180impl<S> Layer<S> for RequireAuthenticatedLayer {
181 type Service = RequireAuthenticatedService<S>;
182
183 fn layer(&self, inner: S) -> Self::Service {
184 RequireAuthenticatedService { inner }
185 }
186}
187
/// Service produced by [`RequireAuthenticatedLayer`]: forwards a request to `inner` only when it
/// carries a [`Role`] extension.
///
/// `Clone` is derived and therefore requires `S: Clone`, matching the previous hand-written impl.
#[derive(Clone, Debug)]
pub struct RequireAuthenticatedService<S> {
    /// The wrapped service; only called for requests that carry a role.
    inner: S,
}
200
201impl<S> Service<Request<Body>> for RequireAuthenticatedService<S>
202where
203 S: Service<Request<Body>, Response = http::Response<Body>> + Clone + Send + 'static,
204 S::Future: Send + 'static,
205 S::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send + 'static,
206{
207 type Response = http::Response<Body>;
208 type Error = S::Error;
209 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
210
211 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
212 self.inner.poll_ready(cx)
213 }
214
215 fn call(&mut self, request: Request<Body>) -> Self::Future {
216 let mut inner = self.inner.clone();
217 std::mem::swap(&mut self.inner, &mut inner);
218
219 Box::pin(async move {
220 if request.extensions().get::<Role>().is_none() {
221 return Ok(Error::unauthorized("authentication required").into_response());
222 }
223
224 inner.call(request).await
225 })
226 }
227}
228
229pub fn require_scope(scope: &str) -> ScopeLayer {
265 ScopeLayer {
266 scope: scope.to_owned(),
267 }
268}
269
/// Tower [`Layer`] produced by [`require_scope`]; wraps a service in a [`ScopeMiddleware`].
#[derive(Clone, Debug)]
pub struct ScopeLayer {
    /// The scope the request's API key must carry.
    scope: String,
}
278
279impl<S> Layer<S> for ScopeLayer {
280 type Service = ScopeMiddleware<S>;
281
282 fn layer(&self, inner: S) -> Self::Service {
283 ScopeMiddleware {
284 inner,
285 scope: self.scope.clone(),
286 }
287 }
288}
289
/// Service produced by [`ScopeLayer`]: forwards a request to `inner` only when its
/// [`ApiKeyMeta`] lists the required scope.
///
/// `Clone` is derived and therefore requires `S: Clone`, matching the previous hand-written impl.
#[derive(Clone, Debug)]
pub struct ScopeMiddleware<S> {
    /// The wrapped service; only called when the scope check passes.
    inner: S,
    /// The scope the request's API key must carry.
    scope: String,
}
304
305impl<S> Service<Request<Body>> for ScopeMiddleware<S>
306where
307 S: Service<Request<Body>, Response = http::Response<Body>> + Clone + Send + 'static,
308 S::Future: Send + 'static,
309 S::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send + 'static,
310{
311 type Response = http::Response<Body>;
312 type Error = S::Error;
313 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
314
315 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
316 self.inner.poll_ready(cx)
317 }
318
319 fn call(&mut self, request: Request<Body>) -> Self::Future {
320 let scope = self.scope.clone();
321 let mut inner = self.inner.clone();
322 std::mem::swap(&mut self.inner, &mut inner);
323
324 Box::pin(async move {
325 let Some(meta) = request.extensions().get::<ApiKeyMeta>() else {
326 tracing::error!(
327 "require_scope guard reached without an API key in extensions; \
328 ApiKeyLayer must run before this guard"
329 );
330 return Ok(Error::internal("server misconfigured").into_response());
331 };
332
333 if !meta.scopes.iter().any(|s| s == &scope) {
334 return Ok(
335 Error::forbidden(format!("missing required scope: {scope}")).into_response()
336 );
337 }
338
339 inner.call(request).await
340 })
341 }
342}
343
#[cfg(test)]
mod tests {
    use super::*;
    use http::{Response, StatusCode};
    use std::convert::Infallible;
    use tower::ServiceExt;

    /// Inner service that always succeeds with a default response whose body is "ok".
    async fn ok_handler(_req: Request<Body>) -> Result<Response<Body>, Infallible> {
        Ok(Response::new(Body::from("ok")))
    }

    // ---- require_role ----

    #[tokio::test]
    async fn require_role_passes_when_role_in_list() {
        let layer = require_role(["admin", "owner"]);
        let svc = layer.layer(tower::service_fn(ok_handler));

        let mut req = Request::builder().body(Body::empty()).unwrap();
        req.extensions_mut().insert(Role("admin".into()));
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn require_role_403_when_role_not_in_list() {
        let layer = require_role(["admin", "owner"]);
        let svc = layer.layer(tower::service_fn(ok_handler));

        let mut req = Request::builder().body(Body::empty()).unwrap();
        req.extensions_mut().insert(Role("viewer".into()));
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn require_role_401_when_role_missing() {
        let layer = require_role(["admin"]);
        let svc = layer.layer(tower::service_fn(ok_handler));

        let req = Request::builder().body(Body::empty()).unwrap();
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn require_role_403_when_empty_roles_list() {
        let layer = require_role(std::iter::empty::<String>());
        let svc = layer.layer(tower::service_fn(ok_handler));

        let mut req = Request::builder().body(Body::empty()).unwrap();
        req.extensions_mut().insert(Role("admin".into()));
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn require_role_empty_string_matches() {
        let layer = require_role([""]);
        let svc = layer.layer(tower::service_fn(ok_handler));

        let mut req = Request::builder().body(Body::empty()).unwrap();
        req.extensions_mut().insert(Role("".into()));
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn require_role_does_not_call_inner_on_reject() {
        use std::sync::atomic::{AtomicBool, Ordering};

        let called = Arc::new(AtomicBool::new(false));
        let called_clone = called.clone();

        let layer = require_role(["admin"]);
        let svc = layer.layer(tower::service_fn(move |_req: Request<Body>| {
            let called = called_clone.clone();
            async move {
                called.store(true, Ordering::SeqCst);
                Ok::<_, Infallible>(Response::new(Body::from("should not reach")))
            }
        }));

        let mut req = Request::builder().body(Body::empty()).unwrap();
        req.extensions_mut().insert(Role("viewer".into()));
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
        assert!(!called.load(Ordering::SeqCst));
    }

    // ---- require_authenticated ----

    #[tokio::test]
    async fn require_authenticated_passes_when_role_present() {
        let layer = require_authenticated();
        let svc = layer.layer(tower::service_fn(ok_handler));

        let mut req = Request::builder().body(Body::empty()).unwrap();
        req.extensions_mut().insert(Role("viewer".into()));
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn require_authenticated_401_when_role_missing() {
        let layer = require_authenticated();
        let svc = layer.layer(tower::service_fn(ok_handler));

        let req = Request::builder().body(Body::empty()).unwrap();
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn require_authenticated_does_not_call_inner_on_reject() {
        use std::sync::atomic::{AtomicBool, Ordering};

        let called = Arc::new(AtomicBool::new(false));
        let called_clone = called.clone();

        let layer = require_authenticated();
        let svc = layer.layer(tower::service_fn(move |_req: Request<Body>| {
            let called = called_clone.clone();
            async move {
                called.store(true, Ordering::SeqCst);
                Ok::<_, Infallible>(Response::new(Body::from("should not reach")))
            }
        }));

        let req = Request::builder().body(Body::empty()).unwrap();
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
        assert!(!called.load(Ordering::SeqCst));
    }

    // ---- require_scope ----

    /// Builds an [`ApiKeyMeta`] fixture whose only interesting field is `scopes`.
    fn meta_with_scopes(scopes: &[&str]) -> ApiKeyMeta {
        ApiKeyMeta {
            id: "01HX".into(),
            tenant_id: "t".into(),
            name: "test key".into(),
            scopes: scopes.iter().map(|s| (*s).into()).collect(),
            expires_at: None,
            last_used_at: None,
            created_at: "2026-01-01T00:00:00Z".into(),
        }
    }

    #[tokio::test]
    async fn require_scope_passes_when_scope_present() {
        let layer = require_scope("read:orders");
        let svc = layer.layer(tower::service_fn(ok_handler));

        let mut req = Request::builder().body(Body::empty()).unwrap();
        req.extensions_mut()
            .insert(meta_with_scopes(&["read:orders", "write:orders"]));
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn require_scope_403_when_scope_absent() {
        let layer = require_scope("admin:all");
        let svc = layer.layer(tower::service_fn(ok_handler));

        let mut req = Request::builder().body(Body::empty()).unwrap();
        req.extensions_mut()
            .insert(meta_with_scopes(&["read:orders"]));
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn require_scope_500_when_apikey_meta_missing() {
        let layer = require_scope("read:orders");
        let svc = layer.layer(tower::service_fn(ok_handler));

        let req = Request::builder().body(Body::empty()).unwrap();
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);
    }

    /// Parity with the role/authenticated guards: a scope rejection must short-circuit and never
    /// invoke the wrapped service.
    #[tokio::test]
    async fn require_scope_does_not_call_inner_on_reject() {
        use std::sync::atomic::{AtomicBool, Ordering};

        let called = Arc::new(AtomicBool::new(false));
        let called_clone = called.clone();

        let layer = require_scope("admin:all");
        let svc = layer.layer(tower::service_fn(move |_req: Request<Body>| {
            let called = called_clone.clone();
            async move {
                called.store(true, Ordering::SeqCst);
                Ok::<_, Infallible>(Response::new(Body::from("should not reach")))
            }
        }));

        let mut req = Request::builder().body(Body::empty()).unwrap();
        req.extensions_mut()
            .insert(meta_with_scopes(&["read:orders"]));
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
        assert!(!called.load(Ordering::SeqCst));
    }
}