codetether_agent/server/auth/middleware.rs
1use super::util::{constant_time_eq, provided_token};
2use super::{AuthState, extract_unverified_jwt_claims};
3use axum::{
4 body::Body,
5 http::{Request, StatusCode},
6 middleware::Next,
7 response::Response,
8};
9
10/// Axum middleware layer that enforces Bearer token auth on every request
11/// except public paths.
12///
13/// # Examples
14///
15/// ```rust
16/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
17/// use axum::{
18/// Router,
19/// body::Body,
20/// http::{Request, StatusCode},
21/// middleware,
22/// routing::get,
23/// };
24/// use codetether_agent::server::auth::{AuthState, require_auth};
25/// use tower::ServiceExt;
26///
27/// let app = Router::new()
28/// .route("/secure", get(|| async { "ok" }))
29/// .layer(middleware::from_fn(require_auth))
30/// .layer(axum::Extension(AuthState::with_token("example-token")));
31///
32/// let response = app
33/// .oneshot(
34/// Request::builder()
35/// .uri("/secure")
36/// .header("authorization", "Bearer example-token")
37/// .body(Body::empty())
38/// .expect("request"),
39/// )
40/// .await
41/// .expect("response");
42///
43/// assert_eq!(response.status(), StatusCode::OK);
44/// # });
45/// ```
46pub async fn require_auth(mut request: Request<Body>, next: Next) -> Result<Response, StatusCode> {
47 let path = request.uri().path();
48 let auth_state = request
49 .extensions()
50 .get::<AuthState>()
51 .cloned()
52 .ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;
53 if auth_state.is_public_path(path) {
54 return Ok(next.run(request).await);
55 }
56 let provided_token = provided_token(&request).ok_or(StatusCode::UNAUTHORIZED)?;
57 if !constant_time_eq(provided_token.as_bytes(), auth_state.token().as_bytes()) {
58 return Err(StatusCode::UNAUTHORIZED);
59 }
60 if let Some(claims) = extract_unverified_jwt_claims(provided_token.as_ref()) {
61 request.extensions_mut().insert(claims);
62 }
63 Ok(next.run(request).await)
64}