Skip to main content

codetether_agent/server/auth/
middleware.rs

1use super::util::{constant_time_eq, provided_token};
2use super::{AuthState, extract_unverified_jwt_claims};
3use axum::{
4    body::Body,
5    http::{Request, StatusCode},
6    middleware::Next,
7    response::Response,
8};
9
10/// Axum middleware layer that enforces Bearer token auth on every request
11/// except public paths.
12///
13/// # Examples
14///
15/// ```rust
16/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
17/// use axum::{
18///     Router,
19///     body::Body,
20///     http::{Request, StatusCode},
21///     middleware,
22///     routing::get,
23/// };
24/// use codetether_agent::server::auth::{AuthState, require_auth};
25/// use tower::ServiceExt;
26///
27/// let app = Router::new()
28///     .route("/secure", get(|| async { "ok" }))
29///     .layer(middleware::from_fn(require_auth))
30///     .layer(axum::Extension(AuthState::with_token("example-token")));
31///
32/// let response = app
33///     .oneshot(
34///         Request::builder()
35///             .uri("/secure")
36///             .header("authorization", "Bearer example-token")
37///             .body(Body::empty())
38///             .expect("request"),
39///     )
40///     .await
41///     .expect("response");
42///
43/// assert_eq!(response.status(), StatusCode::OK);
44/// # });
45/// ```
46pub async fn require_auth(mut request: Request<Body>, next: Next) -> Result<Response, StatusCode> {
47    let path = request.uri().path();
48    let auth_state = request
49        .extensions()
50        .get::<AuthState>()
51        .cloned()
52        .ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;
53    if auth_state.is_public_path(path) {
54        return Ok(next.run(request).await);
55    }
56    let provided_token = provided_token(&request).ok_or(StatusCode::UNAUTHORIZED)?;
57    if !constant_time_eq(provided_token.as_bytes(), auth_state.token().as_bytes()) {
58        return Err(StatusCode::UNAUTHORIZED);
59    }
60    if let Some(claims) = extract_unverified_jwt_claims(provided_token.as_ref()) {
61        request.extensions_mut().insert(claims);
62    }
63    Ok(next.run(request).await)
64}