allowthem_server/
oauth_bearer.rs1use axum::extract::FromRequestParts;
2use axum::http::StatusCode;
3use axum::http::header::AUTHORIZATION;
4use axum::http::request::Parts;
5use axum::response::{IntoResponse, Response};
6
7use allowthem_core::{AccessTokenClaims, AccessTokenError, AllowThem, AuthError};
8
/// Reasons a request can be rejected by the OAuth bearer-token extractor.
///
/// Every variant is converted into an HTTP response that carries a
/// `WWW-Authenticate: Bearer …` challenge by this type's `IntoResponse`
/// implementation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum OAuthBearerError {
    /// The request carried no usable bearer credentials: the `Authorization`
    /// header was absent, was not readable as text, or did not use the
    /// `Bearer` scheme. Reported as `401` with a realm-only challenge.
    Missing,
    /// The access token was rejected because it has expired. Reported as
    /// `401` with `error="invalid_token"`.
    Expired,
    /// The access token was rejected as invalid. The payload is a
    /// human-readable reason that is surfaced to the client in the
    /// `error_description` parameter of the challenge.
    InvalidToken(String),
    /// Validation could not be performed because of a server-side problem
    /// (for example the `AllowThem` handle was not present in the request
    /// extensions). Reported as `500`.
    Internal,
}
22
23impl IntoResponse for OAuthBearerError {
24 fn into_response(self) -> Response {
25 let (status, www_auth) = match self {
26 Self::Missing => (
27 StatusCode::UNAUTHORIZED,
28 "Bearer realm=\"allowthem\"".to_string(),
29 ),
30 Self::Expired => (
31 StatusCode::UNAUTHORIZED,
32 "Bearer realm=\"allowthem\", error=\"invalid_token\", \
33 error_description=\"token expired\""
34 .to_string(),
35 ),
36 Self::InvalidToken(desc) => (
37 StatusCode::UNAUTHORIZED,
38 format!(
39 "Bearer realm=\"allowthem\", error=\"invalid_token\", \
40 error_description=\"{desc}\""
41 ),
42 ),
43 Self::Internal => {
44 tracing::error!("internal error during OAuth bearer validation");
45 (
46 StatusCode::INTERNAL_SERVER_ERROR,
47 "Bearer realm=\"allowthem\", error=\"server_error\"".to_string(),
48 )
49 }
50 };
51
52 let mut response = status.into_response();
53 if let Ok(value) = www_auth.parse() {
54 response.headers_mut().insert("WWW-Authenticate", value);
55 }
56 response
57 }
58}
59
60pub struct OAuthBearerToken(pub AccessTokenClaims);
67
68impl<S: Send + Sync> FromRequestParts<S> for OAuthBearerToken {
69 type Rejection = OAuthBearerError;
70
71 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
72 let ath = parts
73 .extensions
74 .get::<AllowThem>()
75 .cloned()
76 .ok_or(OAuthBearerError::Internal)?;
77
78 let auth_header = parts
79 .headers
80 .get(AUTHORIZATION)
81 .and_then(|v| v.to_str().ok())
82 .ok_or(OAuthBearerError::Missing)?;
83
84 let jwt = auth_header
85 .strip_prefix("Bearer ")
86 .ok_or(OAuthBearerError::Missing)?;
87
88 let base_url = ath.base_url().map_err(|_| OAuthBearerError::Internal)?;
89
90 let claims = ath
91 .db()
92 .validate_access_token(jwt, base_url)
93 .await
94 .map_err(|e| match e {
95 AuthError::AccessToken(AccessTokenError::Expired) => OAuthBearerError::Expired,
96 AuthError::AccessToken(AccessTokenError::InvalidSignature) => {
97 OAuthBearerError::InvalidToken("invalid signature".into())
98 }
99 AuthError::AccessToken(AccessTokenError::UnknownKid(_)) => {
100 OAuthBearerError::InvalidToken("unknown signing key".into())
101 }
102 AuthError::AccessToken(AccessTokenError::InvalidClaims(msg)) => {
103 OAuthBearerError::InvalidToken(msg)
104 }
105 AuthError::AccessToken(AccessTokenError::MalformedToken(msg)) => {
106 OAuthBearerError::InvalidToken(msg)
107 }
108 _ => OAuthBearerError::Internal,
109 })?;
110
111 Ok(OAuthBearerToken(claims))
112 }
113}
114
#[cfg(test)]
mod tests {
    use axum::Router;
    use axum::http::{Request, StatusCode};
    use axum::response::Json;
    use axum::routing::get;
    use tower::ServiceExt;

    use allowthem_core::{AllowThem, AllowThemBuilder};

    use super::*;

    /// Builds an in-memory `AllowThem` instance and a router with a single
    /// `/test` route guarded by the `OAuthBearerToken` extractor.
    async fn test_setup() -> (AllowThem, Router) {
        let ath = AllowThemBuilder::new("sqlite::memory:")
            .cookie_secure(false)
            .base_url("https://auth.example.com")
            .build()
            .await
            .unwrap();

        let app = Router::new()
            .route(
                "/test",
                get(|OAuthBearerToken(claims): OAuthBearerToken| async move {
                    Json(serde_json::json!({"sub": claims.sub.to_string()}))
                }),
            )
            .layer(axum::middleware::from_fn_with_state(
                ath.clone(),
                crate::cors::inject_ath_into_extensions,
            ))
            .with_state(ath.clone());

        (ath, app)
    }

    /// Sends `GET /test`, optionally with the given `Authorization` header
    /// value, and returns the response.
    async fn send(app: Router, authorization: Option<&str>) -> axum::response::Response {
        let mut builder = Request::builder().uri("/test");
        if let Some(value) = authorization {
            builder = builder.header(AUTHORIZATION, value);
        }
        let req = builder.body(axum::body::Body::empty()).unwrap();
        app.oneshot(req).await.unwrap()
    }

    /// Returns the `WWW-Authenticate` challenge of a response, panicking if
    /// the header is absent or not valid text.
    fn challenge(resp: &axum::response::Response) -> &str {
        resp.headers()
            .get("WWW-Authenticate")
            .expect("response must carry a WWW-Authenticate challenge")
            .to_str()
            .expect("challenge must be valid header text")
    }

    #[tokio::test]
    async fn test_missing_auth_header_returns_401_with_www_authenticate() {
        let (_, app) = test_setup().await;

        let resp = send(app, None).await;

        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
        // No credentials were presented, so no error code is advertised.
        assert_eq!(challenge(&resp), "Bearer realm=\"allowthem\"");
    }

    #[tokio::test]
    async fn test_malformed_bearer_returns_401() {
        let (_, app) = test_setup().await;

        let resp = send(app, Some("Token abc123")).await;

        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
        // A non-Bearer scheme is treated like absent credentials.
        assert_eq!(challenge(&resp), "Bearer realm=\"allowthem\"");
    }

    #[tokio::test]
    async fn test_invalid_jwt_returns_401() {
        let (_, app) = test_setup().await;

        let resp = send(app, Some("Bearer not.a.jwt")).await;

        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
        let www_auth = challenge(&resp);
        assert!(www_auth.starts_with("Bearer realm=\"allowthem\""));
        assert!(www_auth.contains("error=\"invalid_token\""));
    }
}