Skip to main content

allowthem_server/
bearer.rs

1use axum::extract::{FromRef, FromRequestParts};
2use axum::http::header::AUTHORIZATION;
3use axum::http::request::Parts;
4
5use allowthem_core::{AllowThem, User};
6
7use crate::error::AuthExtractError;
8
9/// Axum extractor that validates an API bearer token.
10///
11/// Reads the `Authorization: Bearer <token>` header, validates the token
12/// against the database, and returns the authenticated user.
13///
14/// Rejects with 401 if the header is absent, malformed, the token is unknown
15/// or expired, or the user is inactive.
16///
17/// This extractor requires `AllowThem: FromRef<S>` (not `Arc<dyn AuthClient>`)
18/// because API tokens are an embedded-mode feature not part of the auth trait.
19///
20/// Usage: `BearerAuthUser(user): BearerAuthUser` in handler arguments.
21pub struct BearerAuthUser(pub User);
22
23impl<S> FromRequestParts<S> for BearerAuthUser
24where
25    AllowThem: FromRef<S>,
26    S: Send + Sync,
27{
28    type Rejection = AuthExtractError;
29
30    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
31        let ath = AllowThem::from_ref(state);
32
33        let auth_header = parts
34            .headers
35            .get(AUTHORIZATION)
36            .and_then(|v| v.to_str().ok())
37            .ok_or(AuthExtractError::Unauthenticated)?;
38
39        let raw_token = auth_header
40            .strip_prefix("Bearer ")
41            .ok_or(AuthExtractError::Unauthenticated)?;
42
43        let user_id = ath
44            .db()
45            .validate_api_token(raw_token)
46            .await
47            .map_err(AuthExtractError::Internal)?
48            .ok_or(AuthExtractError::Unauthenticated)?;
49
50        let user = ath.db().get_user(user_id).await.map_err(|e| match e {
51            allowthem_core::AuthError::NotFound => AuthExtractError::Unauthenticated,
52            other => AuthExtractError::Internal(other),
53        })?;
54
55        if !user.is_active {
56            return Err(AuthExtractError::Unauthenticated);
57        }
58
59        Ok(BearerAuthUser(user))
60    }
61}
62
#[cfg(test)]
mod tests {
    use axum::Json;
    use axum::Router;
    use axum::extract::FromRef;
    use axum::http::{Request, StatusCode};
    use axum::routing::get;
    use chrono::{Duration, Utc};
    use tower::ServiceExt;

    use allowthem_core::{AllowThem, AllowThemBuilder, Email};

    use super::*;

    type HttpResponse = axum::http::Response<axum::body::Body>;

    /// Router state holding only `AllowThem`.
    ///
    /// Other server extractors also need an `Arc<dyn AuthClient>` in state.
    /// The bearer extractor does not, so the test state stays minimal.
    #[derive(Clone)]
    struct TestState {
        ath: AllowThem,
    }

    impl FromRef<TestState> for AllowThem {
        fn from_ref(state: &TestState) -> Self {
            state.ath.clone()
        }
    }

    /// Builds a fresh in-memory `AllowThem` instance with no users.
    async fn fresh_instance() -> AllowThem {
        AllowThemBuilder::new("sqlite::memory:")
            .cookie_secure(false)
            .build()
            .await
            .unwrap()
    }

    /// Seeds `bearer@example.com` plus an API token named `test-token` (no
    /// expiry argument), returning the instance and the raw token.
    async fn seeded() -> (AllowThem, String) {
        let ath = fresh_instance().await;

        let address = Email::new("bearer@example.com".into()).unwrap();
        let owner = ath
            .db()
            .create_user(address, "password123", None)
            .await
            .unwrap();
        let (token, _) = ath
            .db()
            .create_api_token(owner.id, "test-token", None)
            .await
            .unwrap();

        (ath, token)
    }

    /// Single-route app that mounts the extractor at `/bearer`.
    fn router(ath: AllowThem) -> Router {
        Router::new()
            .route("/bearer", get(echo_email))
            .with_state(TestState { ath })
    }

    /// Echoes the authenticated user's email; reachable only when
    /// `BearerAuthUser` extraction succeeds.
    async fn echo_email(BearerAuthUser(user): BearerAuthUser) -> Json<serde_json::Value> {
        Json(serde_json::json!({"email": user.email}))
    }

    /// Sends `GET /bearer`, attaching `authorization` as the `Authorization`
    /// header when it is `Some`.
    async fn call(app: Router, authorization: Option<&str>) -> HttpResponse {
        let mut request = Request::builder().uri("/bearer");
        if let Some(value) = authorization {
            request = request.header(AUTHORIZATION, value);
        }
        let request = request.body(axum::body::Body::empty()).unwrap();
        app.oneshot(request).await.unwrap()
    }

    /// Collects the response body and parses it as JSON.
    async fn json_body(response: HttpResponse) -> serde_json::Value {
        let bytes = axum::body::to_bytes(response.into_body(), usize::MAX)
            .await
            .unwrap();
        serde_json::from_slice(&bytes).unwrap()
    }

    #[tokio::test]
    async fn test_no_auth_header_returns_401() {
        let (ath, _) = seeded().await;

        let response = call(router(ath), None).await;

        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_malformed_bearer_returns_401() {
        let (ath, _) = seeded().await;

        let response = call(router(ath), Some("Token abc123")).await;

        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_invalid_token_returns_401() {
        let (ath, _) = seeded().await;

        let response = call(router(ath), Some("Bearer garbage-token-xyz")).await;

        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_valid_bearer_returns_user() {
        let (ath, token) = seeded().await;
        let header = format!("Bearer {token}");

        let response = call(router(ath), Some(header.as_str())).await;

        assert_eq!(response.status(), StatusCode::OK);
        let body = json_body(response).await;
        assert_eq!(body["email"], "bearer@example.com");
    }

    #[tokio::test]
    async fn test_expired_token_returns_401() {
        let ath = fresh_instance().await;

        let address = Email::new("expired-bearer@example.com".into()).unwrap();
        let owner = ath
            .db()
            .create_user(address, "password123", None)
            .await
            .unwrap();

        let an_hour_ago = Utc::now() - Duration::hours(1);
        let (token, _) = ath
            .db()
            .create_api_token(owner.id, "expired", Some(an_hour_ago))
            .await
            .unwrap();
        let header = format!("Bearer {token}");

        let response = call(router(ath), Some(header.as_str())).await;

        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_inactive_user_returns_401() {
        let (ath, token) = seeded().await;

        let address = Email::new("bearer@example.com".into()).unwrap();
        let owner = ath.db().get_user_by_email(&address).await.unwrap();
        ath.db().update_user_active(owner.id, false).await.unwrap();
        let header = format!("Bearer {token}");

        let response = call(router(ath), Some(header.as_str())).await;

        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
        let body = json_body(response).await;
        assert_eq!(body["error"], "unauthenticated");
    }
}