1use axum::body::Body;
2use axum::extract::FromRef;
3use axum::http::{Request, StatusCode, header::COOKIE};
4use axum::middleware::Next;
5use axum::response::{IntoResponse, Response};
6use serde_json::json;
7
8use std::sync::Arc;
9
10use allowthem_core::{AuthClient, AuthError, PermissionName, RoleName, User, parse_session_cookie};
11
12pub async fn require_auth<S>(
20 state: axum::extract::State<S>,
21 mut request: Request<Body>,
22 next: Next,
23) -> Response
24where
25 Arc<dyn AuthClient>: FromRef<S>,
26 S: Send + Sync + Clone,
27{
28 let client = <Arc<dyn AuthClient>>::from_ref(&*state);
29 let headers = request.headers().clone();
32
33 let user = match authenticate(&*client, &headers).await {
34 Ok(u) => u,
35 Err(r) => return r,
36 };
37
38 request.extensions_mut().insert(user);
39 next.run(request).await
40}
41
42pub fn require_role<S>(
60 role: impl Into<String>,
61) -> impl Fn(
62 axum::extract::State<S>,
63 Request<Body>,
64 Next,
65) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send>>
66+ Clone
67+ Send
68+ 'static
69where
70 Arc<dyn AuthClient>: FromRef<S>,
71 S: Send + Sync + Clone + 'static,
72{
73 let role_name = role.into();
74 move |state, request, next| {
75 let role_name = role_name.clone();
76 Box::pin(require_role_inner(state, request, next, role_name))
77 }
78}
79
80async fn require_role_inner<S>(
81 state: axum::extract::State<S>,
82 mut request: Request<Body>,
83 next: Next,
84 role_name: String,
85) -> Response
86where
87 Arc<dyn AuthClient>: FromRef<S>,
88 S: Send + Sync + Clone,
89{
90 let client = <Arc<dyn AuthClient>>::from_ref(&*state);
91 let headers = request.headers().clone();
92
93 let user = match authenticate(&*client, &headers).await {
94 Ok(u) => u,
95 Err(r) => return r,
96 };
97
98 let rn = RoleName::new(role_name);
99 match client.check_role(&user.id, &rn).await {
100 Ok(true) => {}
101 Ok(false) => {
102 return (
103 StatusCode::FORBIDDEN,
104 axum::Json(json!({"error": "forbidden"})),
105 )
106 .into_response();
107 }
108 Err(e) => return internal_error(e),
109 }
110
111 request.extensions_mut().insert(user);
112 next.run(request).await
113}
114
115pub fn require_permission<S>(
122 permission: impl Into<String>,
123) -> impl Fn(
124 axum::extract::State<S>,
125 Request<Body>,
126 Next,
127) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send>>
128+ Clone
129+ Send
130+ 'static
131where
132 Arc<dyn AuthClient>: FromRef<S>,
133 S: Send + Sync + Clone + 'static,
134{
135 let perm_name = permission.into();
136 move |state, request, next| {
137 let perm_name = perm_name.clone();
138 Box::pin(require_permission_inner(state, request, next, perm_name))
139 }
140}
141
142async fn require_permission_inner<S>(
143 state: axum::extract::State<S>,
144 mut request: Request<Body>,
145 next: Next,
146 perm_name: String,
147) -> Response
148where
149 Arc<dyn AuthClient>: FromRef<S>,
150 S: Send + Sync + Clone,
151{
152 let client = <Arc<dyn AuthClient>>::from_ref(&*state);
153 let headers = request.headers().clone();
154
155 let user = match authenticate(&*client, &headers).await {
156 Ok(u) => u,
157 Err(r) => return r,
158 };
159
160 let pn = PermissionName::new(perm_name);
161 match client.check_permission(&user.id, &pn).await {
162 Ok(true) => {}
163 Ok(false) => {
164 return (
165 StatusCode::FORBIDDEN,
166 axum::Json(json!({"error": "forbidden"})),
167 )
168 .into_response();
169 }
170 Err(e) => return internal_error(e),
171 }
172
173 request.extensions_mut().insert(user);
174 next.run(request).await
175}
176
177async fn authenticate(
184 client: &dyn AuthClient,
185 headers: &axum::http::HeaderMap,
186) -> Result<User, Response> {
187 let cookie_header = headers
188 .get(COOKIE)
189 .and_then(|v| v.to_str().ok())
190 .ok_or_else(unauthenticated)?
191 .to_string();
192
193 let token = parse_session_cookie(&cookie_header, client.session_cookie_name())
194 .ok_or_else(unauthenticated)?;
195
196 let user = client
197 .validate_session(&token)
198 .await
199 .map_err(internal_error)?
200 .ok_or_else(unauthenticated)?;
201
202 Ok(user)
203}
204
205fn unauthenticated() -> Response {
206 (
207 StatusCode::UNAUTHORIZED,
208 axum::Json(json!({"error": "unauthenticated"})),
209 )
210 .into_response()
211}
212
213fn internal_error(err: AuthError) -> Response {
214 tracing::error!("auth middleware error: {err}");
215 (
216 StatusCode::INTERNAL_SERVER_ERROR,
217 axum::Json(json!({"error": "internal error"})),
218 )
219 .into_response()
220}
221
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use allowthem_core::{
        AllowThem, AllowThemBuilder, AuthClient, Email, EmbeddedAuthClient, generate_token,
        hash_token,
    };
    use axum::extract::FromRef;
    use axum::http::StatusCode;
    use axum::routing::get;
    use axum::{Router, middleware};
    use chrono::{Duration, Utc};
    use tower::ServiceExt;

    /// Email address of the single user seeded by `test_setup`.
    const TEST_EMAIL: &str = "user@example.com";

    /// Minimal router state. It exposes the auth client through `FromRef`,
    /// which is the bound the middleware functions place on their state.
    #[derive(Clone)]
    struct TestState {
        auth: Arc<dyn AuthClient>,
    }

    impl FromRef<TestState> for Arc<dyn AuthClient> {
        fn from_ref(state: &TestState) -> Self {
            Arc::clone(&state.auth)
        }
    }

    /// Parses the seeded user's email address.
    fn test_email() -> Email {
        Email::new(TEST_EMAIL.into()).unwrap()
    }

    /// Prepares an in-memory `AllowThem` for the tests.
    ///
    /// It creates one user and a session expiring in 24 hours. It returns the
    /// handle together with the cookie text a client would send.
    async fn test_setup() -> (AllowThem, String) {
        let ath = AllowThemBuilder::new("sqlite::memory:")
            .cookie_secure(false)
            .build()
            .await
            .unwrap();

        let user = ath
            .db()
            .create_user(test_email(), "password123", None)
            .await
            .unwrap();

        let token = generate_token();
        let token_hash = hash_token(&token);
        let expires_at = Utc::now() + Duration::hours(24);
        ath.db()
            .create_session(user.id, token_hash, None, None, expires_at)
            .await
            .unwrap();

        // Keep only the text before the first `;` (presumably `name=value`,
        // with cookie attributes after it).
        let set_cookie = ath.session_cookie(&token);
        let request_cookie = set_cookie.split(';').next().unwrap().to_string();
        (ath, request_cookie)
    }

    /// Terminal handler: reaching it means the middleware let the request through.
    async fn ok_handler() -> StatusCode {
        StatusCode::OK
    }

    /// Wraps `ath` in an `EmbeddedAuthClient` (login path `/login`) as router state.
    fn test_state(ath: AllowThem) -> TestState {
        TestState {
            auth: Arc::new(EmbeddedAuthClient::new(ath, "/login")),
        }
    }

    /// `/protected` guarded by `require_auth`.
    fn auth_app(ath: AllowThem) -> Router {
        let state = test_state(ath);
        Router::new()
            .route("/protected", get(ok_handler))
            .layer(middleware::from_fn_with_state(
                state.clone(),
                require_auth::<TestState>,
            ))
            .with_state(state)
    }

    /// `/protected` guarded by `require_role(role)`.
    fn role_app(ath: AllowThem, role: &str) -> Router {
        let required = role.to_string();
        let state = test_state(ath);
        Router::new()
            .route("/protected", get(ok_handler))
            .layer(middleware::from_fn_with_state(
                state.clone(),
                require_role::<TestState>(required),
            ))
            .with_state(state)
    }

    /// `/protected` guarded by `require_permission(perm)`.
    fn perm_app(ath: AllowThem, perm: &str) -> Router {
        let required = perm.to_string();
        let state = test_state(ath);
        Router::new()
            .route("/protected", get(ok_handler))
            .layer(middleware::from_fn_with_state(
                state.clone(),
                require_permission::<TestState>(required),
            ))
            .with_state(state)
    }

    /// GET `/protected`, with a `Cookie` header only when `cookie` is given.
    fn make_request(cookie: Option<&str>) -> axum::http::Request<Body> {
        let builder = axum::http::Request::builder().uri("/protected");
        let builder = match cookie {
            Some(value) => builder.header(COOKIE, value),
            None => builder,
        };
        builder.body(Body::empty()).unwrap()
    }

    #[tokio::test]
    async fn authenticated_request_passes_through() {
        let (ath, cookie) = test_setup().await;
        let response = auth_app(ath)
            .oneshot(make_request(Some(&cookie)))
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn unauthenticated_request_returns_401() {
        let (ath, _) = test_setup().await;
        let response = auth_app(ath).oneshot(make_request(None)).await.unwrap();
        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn require_role_with_correct_role_passes() {
        let (ath, cookie) = test_setup().await;

        let admin = allowthem_core::RoleName::new("admin");
        let created = ath.db().create_role(&admin, None).await.unwrap();
        let email = test_email();
        let member = ath.db().get_user_by_email(&email).await.unwrap();
        ath.db().assign_role(&member.id, &created.id).await.unwrap();

        let response = role_app(ath, "admin")
            .oneshot(make_request(Some(&cookie)))
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn require_role_with_wrong_role_returns_403() {
        let (ath, cookie) = test_setup().await;
        let response = role_app(ath, "admin")
            .oneshot(make_request(Some(&cookie)))
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn require_permission_with_correct_permission_passes() {
        let (ath, cookie) = test_setup().await;

        let write_posts = allowthem_core::PermissionName::new("posts:write");
        let created = ath.db().create_permission(&write_posts, None).await.unwrap();
        let email = test_email();
        let member = ath.db().get_user_by_email(&email).await.unwrap();
        ath.db()
            .assign_permission_to_user(&member.id, &created.id)
            .await
            .unwrap();

        let response = perm_app(ath, "posts:write")
            .oneshot(make_request(Some(&cookie)))
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn require_permission_with_missing_permission_returns_403() {
        let (ath, cookie) = test_setup().await;
        let response = perm_app(ath, "posts:write")
            .oneshot(make_request(Some(&cookie)))
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::FORBIDDEN);
    }
}