Skip to main content

allowthem_server/
oauth_bearer.rs

1use axum::extract::FromRequestParts;
2use axum::http::StatusCode;
3use axum::http::header::AUTHORIZATION;
4use axum::http::request::Parts;
5use axum::response::{IntoResponse, Response};
6
7use allowthem_core::{AccessTokenClaims, AccessTokenError, AllowThem, AuthError};
8
/// Rejection type for the [`OAuthBearerToken`] extractor.
///
/// Every variant converts into a response carrying a `WWW-Authenticate: Bearer`
/// challenge (RFC 6750 §3). All variants produce `401 Unauthorized` except
/// [`OAuthBearerError::Internal`], which produces `500 Internal Server Error`.
#[derive(Debug)]
pub enum OAuthBearerError {
    /// The `Authorization` header is absent, is not visible ASCII, or does not
    /// present a `Bearer` credential. The challenge carries no `error` attribute.
    Missing,
    /// Token validation rejected the token as expired.
    Expired,
    /// Token validation rejected the token for another reason: invalid signature,
    /// unknown signing key, invalid claims, or malformed structure. The string is
    /// sent to the client as the challenge's `error_description`.
    InvalidToken(String),
    /// A server-side failure prevented validation from running or completing.
    Internal,
}
22
23impl IntoResponse for OAuthBearerError {
24    fn into_response(self) -> Response {
25        let (status, www_auth) = match self {
26            Self::Missing => (
27                StatusCode::UNAUTHORIZED,
28                "Bearer realm=\"allowthem\"".to_string(),
29            ),
30            Self::Expired => (
31                StatusCode::UNAUTHORIZED,
32                "Bearer realm=\"allowthem\", error=\"invalid_token\", \
33                 error_description=\"token expired\""
34                    .to_string(),
35            ),
36            Self::InvalidToken(desc) => (
37                StatusCode::UNAUTHORIZED,
38                format!(
39                    "Bearer realm=\"allowthem\", error=\"invalid_token\", \
40                     error_description=\"{desc}\""
41                ),
42            ),
43            Self::Internal => {
44                tracing::error!("internal error during OAuth bearer validation");
45                (
46                    StatusCode::INTERNAL_SERVER_ERROR,
47                    "Bearer realm=\"allowthem\", error=\"server_error\"".to_string(),
48                )
49            }
50        };
51
52        let mut response = status.into_response();
53        if let Ok(value) = www_auth.parse() {
54            response.headers_mut().insert("WWW-Authenticate", value);
55        }
56        response
57    }
58}
59
60/// Axum extractor that validates an OAuth2 RS256 access token.
61///
62/// Reads `Authorization: Bearer <jwt>`, validates the RS256 signature,
63/// checks expiry and issuer, and returns the validated claims.
64///
65/// Rejects with 401 and `WWW-Authenticate: Bearer` header per RFC 6750.
66pub struct OAuthBearerToken(pub AccessTokenClaims);
67
68impl<S: Send + Sync> FromRequestParts<S> for OAuthBearerToken {
69    type Rejection = OAuthBearerError;
70
71    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
72        let ath = parts
73            .extensions
74            .get::<AllowThem>()
75            .cloned()
76            .ok_or(OAuthBearerError::Internal)?;
77
78        let auth_header = parts
79            .headers
80            .get(AUTHORIZATION)
81            .and_then(|v| v.to_str().ok())
82            .ok_or(OAuthBearerError::Missing)?;
83
84        let jwt = auth_header
85            .strip_prefix("Bearer ")
86            .ok_or(OAuthBearerError::Missing)?;
87
88        let base_url = ath.base_url().map_err(|_| OAuthBearerError::Internal)?;
89
90        let claims = ath
91            .db()
92            .validate_access_token(jwt, base_url)
93            .await
94            .map_err(|e| match e {
95                AuthError::AccessToken(AccessTokenError::Expired) => OAuthBearerError::Expired,
96                AuthError::AccessToken(AccessTokenError::InvalidSignature) => {
97                    OAuthBearerError::InvalidToken("invalid signature".into())
98                }
99                AuthError::AccessToken(AccessTokenError::UnknownKid(_)) => {
100                    OAuthBearerError::InvalidToken("unknown signing key".into())
101                }
102                AuthError::AccessToken(AccessTokenError::InvalidClaims(msg)) => {
103                    OAuthBearerError::InvalidToken(msg)
104                }
105                AuthError::AccessToken(AccessTokenError::MalformedToken(msg)) => {
106                    OAuthBearerError::InvalidToken(msg)
107                }
108                _ => OAuthBearerError::Internal,
109            })?;
110
111        Ok(OAuthBearerToken(claims))
112    }
113}
114
115#[cfg(test)]
116mod tests {
117    use axum::Router;
118    use axum::http::{Request, StatusCode};
119    use axum::response::Json;
120    use axum::routing::get;
121    use tower::ServiceExt;
122
123    use allowthem_core::{AllowThem, AllowThemBuilder};
124
125    use super::*;
126
127    async fn test_setup() -> (AllowThem, Router) {
128        let ath = AllowThemBuilder::new("sqlite::memory:")
129            .cookie_secure(false)
130            .base_url("https://auth.example.com")
131            .build()
132            .await
133            .unwrap();
134
135        let app = Router::new()
136            .route(
137                "/test",
138                get(|OAuthBearerToken(claims): OAuthBearerToken| async move {
139                    Json(serde_json::json!({"sub": claims.sub.to_string()}))
140                }),
141            )
142            .layer(axum::middleware::from_fn_with_state(
143                ath.clone(),
144                crate::cors::inject_ath_into_extensions,
145            ))
146            .with_state(ath.clone());
147
148        (ath, app)
149    }
150
151    #[tokio::test]
152    async fn test_missing_auth_header_returns_401_with_www_authenticate() {
153        let (_, app) = test_setup().await;
154
155        let req = Request::builder()
156            .uri("/test")
157            .body(axum::body::Body::empty())
158            .unwrap();
159        let resp = app.oneshot(req).await.unwrap();
160
161        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
162        let www_auth = resp
163            .headers()
164            .get("WWW-Authenticate")
165            .unwrap()
166            .to_str()
167            .unwrap();
168        assert!(www_auth.contains("Bearer realm=\"allowthem\""));
169    }
170
171    #[tokio::test]
172    async fn test_malformed_bearer_returns_401() {
173        let (_, app) = test_setup().await;
174
175        let req = Request::builder()
176            .uri("/test")
177            .header(AUTHORIZATION, "Token abc123")
178            .body(axum::body::Body::empty())
179            .unwrap();
180        let resp = app.oneshot(req).await.unwrap();
181
182        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
183    }
184
185    #[tokio::test]
186    async fn test_invalid_jwt_returns_401() {
187        let (_, app) = test_setup().await;
188
189        let req = Request::builder()
190            .uri("/test")
191            .header(AUTHORIZATION, "Bearer not.a.jwt")
192            .body(axum::body::Body::empty())
193            .unwrap();
194        let resp = app.oneshot(req).await.unwrap();
195
196        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
197    }
198}