1use axum::RequestPartsExt;
2use axum::extract::{FromRequestParts, OptionalFromRequestParts, Request, State};
3use axum::http::StatusCode;
4use axum::http::request::Parts;
5use axum::middleware::Next;
6use axum::response::Response;
7use axum_extra::extract::PrivateCookieJar;
8use axum_extra::extract::cookie::Cookie;
9use cookie::{SameSite, time};
10use kellnr_appstate::AppStateData;
11use kellnr_common::util::generate_rand_string;
12use kellnr_settings::constants::COOKIE_SESSION_ID;
13use time::Duration;
14use tracing::error;
15
16use crate::error::RouteError;
17
18pub(crate) async fn create_session_jar(
21 cookies: PrivateCookieJar,
22 app_state: &AppStateData,
23 username: &str,
24) -> Result<PrivateCookieJar, RouteError> {
25 let session_token = generate_rand_string(12);
26 app_state
27 .db
28 .add_session_token(username, &session_token)
29 .await
30 .map_err(|e| {
31 error!("Failed to create session: {e}");
32 RouteError::Status(StatusCode::INTERNAL_SERVER_ERROR)
33 })?;
34 let session_age_seconds = app_state.settings.registry.session_age_seconds as i64;
35 Ok(cookies.add(
36 Cookie::build((COOKIE_SESSION_ID, session_token))
37 .max_age(Duration::seconds(session_age_seconds))
38 .same_site(SameSite::Strict)
39 .path("/"),
40 ))
41}
42
/// Common interface for newtype wrappers that carry a user name.
pub trait Name {
    /// Builds the wrapper from an owned user name.
    fn new(name: String) -> Self;

    /// Returns the wrapped user name as an owned `String`.
    fn name(&self) -> String;
}
47
/// The name of a user whose session has been verified to belong to an administrator.
pub struct AdminUser(pub String);

impl AdminUser {
    /// Borrows the administrator's user name.
    pub fn name(&self) -> &str {
        self.0.as_str()
    }
}
55
56impl Name for AdminUser {
57 fn name(&self) -> String {
58 self.0.clone()
59 }
60 fn new(name: String) -> Self {
61 Self(name)
62 }
63}
64
65impl FromRequestParts<AppStateData> for AdminUser {
66 type Rejection = RouteError;
67
68 async fn from_request_parts(
69 parts: &mut Parts,
70 state: &AppStateData,
71 ) -> Result<Self, Self::Rejection> {
72 let jar: PrivateCookieJar = parts.extract_with_state(state).await.unwrap();
73 let session_cookie = jar.get(COOKIE_SESSION_ID);
74 match session_cookie {
75 Some(cookie) => match state.db.validate_session(cookie.value()).await {
76 Ok((name, true)) => Ok(Self(name)),
77 Ok((_, false)) => Err(RouteError::InsufficientPrivileges),
78 Err(_) => Err(RouteError::Status(StatusCode::UNAUTHORIZED)),
79 },
80 None => Err(RouteError::Status(StatusCode::UNAUTHORIZED)),
81 }
82 }
83}
84
85pub struct NormalUser(pub String);
86impl Name for NormalUser {
87 fn name(&self) -> String {
88 self.0.clone()
89 }
90 fn new(name: String) -> Self {
91 Self(name)
92 }
93}
94
95pub struct AnyUser(pub String);
96impl Name for AnyUser {
97 fn name(&self) -> String {
98 self.0.clone()
99 }
100 fn new(name: String) -> Self {
101 Self(name)
102 }
103}
104
/// The user behind a request's session cookie, tagged by whether the session has admin rights.
#[derive(Debug)]
pub enum MaybeUser {
    /// A logged-in user whose session's admin flag is `false`.
    Normal(String),
    /// A logged-in user whose session's admin flag is `true`.
    Admin(String),
}
111
112impl MaybeUser {
113 pub fn name(&self) -> &str {
114 match self {
115 Self::Normal(name) | Self::Admin(name) => name,
116 }
117 }
118
119 pub fn assert_normal(&self) -> Result<(), RouteError> {
120 match self {
121 MaybeUser::Normal(_) => Ok(()),
122 MaybeUser::Admin(_) => Err(RouteError::InsufficientPrivileges),
123 }
124 }
125
126 pub fn assert_admin(&self) -> Result<(), RouteError> {
127 match self {
128 MaybeUser::Normal(_) => Err(RouteError::InsufficientPrivileges),
129 MaybeUser::Admin(_) => Ok(()),
130 }
131 }
132}
133
134impl FromRequestParts<AppStateData> for MaybeUser {
135 type Rejection = RouteError;
136
137 async fn from_request_parts(
138 parts: &mut Parts,
139 state: &AppStateData,
140 ) -> Result<Self, Self::Rejection> {
141 let jar: PrivateCookieJar = parts.extract_with_state(state).await.unwrap();
142 let session_cookie = jar.get(COOKIE_SESSION_ID);
143 match session_cookie {
144 Some(cookie) => match state.db.validate_session(cookie.value()).await {
145 Ok((name, true)) => Ok(Self::Admin(name)),
147 Ok((name, false)) => Ok(Self::Normal(name)),
149 Err(_) => Err(RouteError::Status(StatusCode::UNAUTHORIZED)),
150 },
151 None => Err(RouteError::Status(StatusCode::UNAUTHORIZED)),
152 }
153 }
154}
155
156impl OptionalFromRequestParts<AppStateData> for MaybeUser {
157 type Rejection = RouteError;
158
159 async fn from_request_parts(
160 parts: &mut Parts,
161 state: &AppStateData,
162 ) -> Result<Option<Self>, Self::Rejection> {
163 let jar: PrivateCookieJar = parts.extract_with_state(state).await.unwrap();
164 let session_cookie = jar.get(COOKIE_SESSION_ID);
165 match session_cookie {
166 Some(cookie) => match state.db.validate_session(cookie.value()).await {
167 Ok((name, true)) => Ok(Some(Self::Admin(name))),
169 Ok((name, false)) => Ok(Some(Self::Normal(name))),
171 Err(_) => Err(RouteError::Status(StatusCode::UNAUTHORIZED)),
172 },
173 None => Ok(None),
174 }
175 }
176}
177
178pub async fn session_auth_when_required(
181 State(state): State<AppStateData>,
182 jar: PrivateCookieJar,
183 request: Request,
184 next: Next,
185) -> Result<Response, RouteError> {
186 if !state.settings.registry.auth_required {
187 return Ok(next.run(request).await);
189 }
190 let session_cookie = jar.get(COOKIE_SESSION_ID);
191 match session_cookie {
192 Some(cookie) => match state.db.validate_session(cookie.value()).await {
193 Ok(_) => Ok(next.run(request).await),
195 Err(_) => Err(RouteError::Status(StatusCode::UNAUTHORIZED)),
197 },
198 None => Err(RouteError::Status(StatusCode::UNAUTHORIZED)),
200 }
201}
202
#[cfg(test)]
mod session_tests {
    //! Route-level tests for the `MaybeUser` extractors (required and optional forms).

    use std::sync::Arc;

    use axum::Router;
    use axum::body::Body;
    use axum::http::header;
    use axum::routing::get;
    use cookie::Key;
    use kellnr_db::DbProvider;
    use kellnr_db::error::DbError;
    use kellnr_db::mock::MockDb;
    use kellnr_storage::cached_crate_storage::DynStorage;
    use kellnr_storage::fs_storage::FSStorage;
    use kellnr_storage::kellnr_crate_storage::KellnrCrateStorage;
    use mockall::predicate::eq;
    use tower::ServiceExt;

    use super::*;
    use crate::test_helper::encode_cookies;

    type Result<T = (), E = Box<dyn std::error::Error>> = std::result::Result<T, E>;

    async fn admin_endpoint(user: MaybeUser) -> Result<(), RouteError> {
        user.assert_admin()?;
        Ok(())
    }

    async fn normal_endpoint(user: MaybeUser) -> Result<(), RouteError> {
        user.assert_normal()?;
        Ok(())
    }

    async fn any_endpoint(_user: MaybeUser) {}

    /// Echoes the user's name, or an empty body for anonymous requests. This exercises the
    /// `OptionalFromRequestParts` implementation of `MaybeUser`.
    async fn optional_endpoint(user: Option<MaybeUser>) -> String {
        user.map(|u| u.name().to_string()).unwrap_or_default()
    }

    fn app(db: Arc<dyn DbProvider>) -> Router {
        let settings = kellnr_settings::test_settings();
        let storage = Box::new(FSStorage::new(&settings.crates_path()).unwrap()) as DynStorage;
        Router::new()
            .route("/admin", get(admin_endpoint))
            .route("/normal", get(normal_endpoint))
            .route("/any", get(any_endpoint))
            .route("/optional", get(optional_endpoint))
            .with_state(AppStateData {
                db,
                signing_key: Key::from(crate::test_helper::TEST_KEY),
                crate_storage: Arc::new(KellnrCrateStorage::new(&settings, storage)),
                settings: Arc::new(settings),
                ..kellnr_appstate::test_state()
            })
    }

    /// Cookie header carrying the session id `1234`.
    fn c1234() -> String {
        encode_cookies([(COOKIE_SESSION_ID, "1234")])
    }

    #[tokio::test]
    async fn admin_auth_works() -> Result {
        let mut mock_db = MockDb::new();
        mock_db
            .expect_validate_session()
            .with(eq("1234"))
            .returning(|_st| Ok(("admin".to_string(), true)));

        let r = app(Arc::new(mock_db))
            .oneshot(
                Request::get("/admin")
                    .header(header::COOKIE, c1234())
                    .body(Body::empty())?,
            )
            .await?;
        assert_eq!(r.status(), StatusCode::OK);

        Ok(())
    }

    #[tokio::test]
    async fn admin_auth_user_is_no_admin() -> Result {
        let mut mock_db = MockDb::new();
        mock_db
            .expect_validate_session()
            .with(eq("1234"))
            .returning(|_st| Ok(("admin".to_string(), false)));

        let r = app(Arc::new(mock_db))
            .oneshot(
                Request::get("/admin")
                    .header(header::COOKIE, c1234())
                    .body(Body::empty())?,
            )
            .await?;
        assert_eq!(r.status(), StatusCode::FORBIDDEN);

        Ok(())
    }

    #[tokio::test]
    async fn admin_auth_user_but_no_cookie_sent() -> Result {
        let mock_db = MockDb::new();

        let r = app(Arc::new(mock_db))
            .oneshot(Request::get("/admin").body(Body::empty())?)
            .await?;
        assert_eq!(r.status(), StatusCode::UNAUTHORIZED);

        Ok(())
    }

    #[tokio::test]
    async fn admin_auth_user_but_no_cookie_in_store() -> Result {
        let mut mock_db = MockDb::new();
        mock_db
            .expect_validate_session()
            .with(eq("1234"))
            .returning(|_st| Err(DbError::SessionNotFound));

        let r = app(Arc::new(mock_db))
            .oneshot(
                Request::get("/admin")
                    .header(header::COOKIE, c1234())
                    .body(Body::empty())?,
            )
            .await?;
        assert_eq!(r.status(), StatusCode::UNAUTHORIZED);

        Ok(())
    }

    #[tokio::test]
    async fn normal_auth_works() -> Result {
        let mut mock_db = MockDb::new();
        mock_db
            .expect_validate_session()
            .with(eq("1234"))
            .returning(|_st| Ok(("normal".to_string(), false)));

        let r = app(Arc::new(mock_db))
            .oneshot(
                Request::get("/normal")
                    .header(header::COOKIE, c1234())
                    .body(Body::empty())?,
            )
            .await?;
        assert_eq!(r.status(), StatusCode::OK);

        Ok(())
    }

    #[tokio::test]
    async fn normal_auth_user_is_admin() -> Result {
        let mut mock_db = MockDb::new();
        mock_db
            .expect_validate_session()
            .with(eq("1234"))
            .returning(|_st| Ok(("normal".to_string(), true)));

        let r = app(Arc::new(mock_db))
            .oneshot(
                Request::get("/normal")
                    .header(header::COOKIE, c1234())
                    .body(Body::empty())?,
            )
            .await?;
        assert_eq!(r.status(), StatusCode::FORBIDDEN);

        Ok(())
    }

    #[tokio::test]
    async fn normal_auth_user_but_no_cookie_sent() -> Result {
        let mock_db = MockDb::new();

        let r = app(Arc::new(mock_db))
            .oneshot(Request::get("/normal").body(Body::empty())?)
            .await?;
        assert_eq!(r.status(), StatusCode::UNAUTHORIZED);

        Ok(())
    }

    #[tokio::test]
    async fn normal_auth_user_but_no_cookie_in_store() -> Result {
        let mut mock_db = MockDb::new();
        mock_db
            .expect_validate_session()
            .with(eq("1234"))
            .returning(|_st| Err(DbError::SessionNotFound));

        let r = app(Arc::new(mock_db))
            .oneshot(
                Request::get("/normal")
                    .header(header::COOKIE, c1234())
                    .body(Body::empty())?,
            )
            .await?;
        assert_eq!(r.status(), StatusCode::UNAUTHORIZED);

        Ok(())
    }

    #[tokio::test]
    async fn any_auth_user_is_normal() -> Result {
        let mut mock_db = MockDb::new();
        mock_db
            .expect_validate_session()
            .with(eq("1234"))
            .returning(|_st| Ok(("guest".to_string(), false)));

        let r = app(Arc::new(mock_db))
            .oneshot(
                Request::get("/any")
                    .header(header::COOKIE, c1234())
                    .body(Body::empty())?,
            )
            .await?;
        assert_eq!(r.status(), StatusCode::OK);

        Ok(())
    }

    #[tokio::test]
    async fn any_auth_user_is_admin() -> Result {
        let mut mock_db = MockDb::new();
        mock_db
            .expect_validate_session()
            .with(eq("1234"))
            .returning(|_st| Ok(("guest".to_string(), true)));

        let r = app(Arc::new(mock_db))
            .oneshot(
                Request::get("/any")
                    .header(header::COOKIE, c1234())
                    .body(Body::empty())?,
            )
            .await?;
        assert_eq!(r.status(), StatusCode::OK);

        Ok(())
    }

    #[tokio::test]
    async fn any_auth_user_but_no_cookie_sent() -> Result {
        let mock_db = MockDb::new();

        let r = app(Arc::new(mock_db))
            .oneshot(Request::get("/any").body(Body::empty())?)
            .await?;
        assert_eq!(r.status(), StatusCode::UNAUTHORIZED);
        Ok(())
    }

    #[tokio::test]
    async fn any_auth_user_but_no_cookie_in_store() -> Result {
        let mut mock_db = MockDb::new();
        mock_db
            .expect_validate_session()
            .with(eq("1234"))
            .returning(|_st| Err(DbError::SessionNotFound));

        let r = app(Arc::new(mock_db))
            .oneshot(
                Request::get("/any")
                    .header(header::COOKIE, c1234())
                    .body(Body::empty())?,
            )
            .await?;

        assert_eq!(r.status(), StatusCode::UNAUTHORIZED);
        Ok(())
    }

    #[tokio::test]
    async fn optional_auth_no_cookie_sent() -> Result {
        // No expectations: the extractor must not touch the database without a cookie.
        let mock_db = MockDb::new();

        let r = app(Arc::new(mock_db))
            .oneshot(Request::get("/optional").body(Body::empty())?)
            .await?;
        assert_eq!(r.status(), StatusCode::OK);
        let body = axum::body::to_bytes(r.into_body(), usize::MAX).await?;
        assert!(body.is_empty());

        Ok(())
    }

    #[tokio::test]
    async fn optional_auth_user_is_normal() -> Result {
        let mut mock_db = MockDb::new();
        mock_db
            .expect_validate_session()
            .with(eq("1234"))
            .returning(|_st| Ok(("normal".to_string(), false)));

        let r = app(Arc::new(mock_db))
            .oneshot(
                Request::get("/optional")
                    .header(header::COOKIE, c1234())
                    .body(Body::empty())?,
            )
            .await?;
        assert_eq!(r.status(), StatusCode::OK);
        let body = axum::body::to_bytes(r.into_body(), usize::MAX).await?;
        assert_eq!(&body[..], b"normal");

        Ok(())
    }

    #[tokio::test]
    async fn optional_auth_user_but_no_cookie_in_store() -> Result {
        let mut mock_db = MockDb::new();
        mock_db
            .expect_validate_session()
            .with(eq("1234"))
            .returning(|_st| Err(DbError::SessionNotFound));

        let r = app(Arc::new(mock_db))
            .oneshot(
                Request::get("/optional")
                    .header(header::COOKIE, c1234())
                    .body(Body::empty())?,
            )
            .await?;
        assert_eq!(r.status(), StatusCode::UNAUTHORIZED);

        Ok(())
    }
}
482
#[cfg(test)]
mod auth_middleware_tests {
    //! Tests for `session_auth_when_required` with `auth_required` switched on and off.

    use std::sync::Arc;

    use axum::Router;
    use axum::body::Body;
    use axum::http::header;
    use axum::middleware::from_fn_with_state;
    use axum::routing::get;
    use cookie::Key;
    use kellnr_db::DbProvider;
    use kellnr_db::error::DbError;
    use kellnr_db::mock::MockDb;
    use kellnr_settings::Settings;
    use mockall::predicate::eq;
    use tower::ServiceExt;

    use super::*;
    use crate::test_helper::encode_cookies;

    type Result<T = ()> = std::result::Result<T, Box<dyn std::error::Error>>;

    /// Wraps `db` and `settings` into an app state with the shared test signing key.
    fn state_with(db: Arc<dyn DbProvider>, settings: Settings) -> AppStateData {
        AppStateData {
            db,
            signing_key: Key::from(crate::test_helper::TEST_KEY),
            settings: Arc::new(settings),
            ..kellnr_appstate::test_state()
        }
    }

    /// App with `auth_required = true`. `/guarded` is wrapped by the middleware, while
    /// `/not_guarded` is registered after `route_layer` and is therefore not wrapped.
    fn app_required_auth(db: Arc<dyn DbProvider>) -> Router {
        let defaults = Settings::default();
        let settings = Settings {
            registry: kellnr_settings::Registry {
                auth_required: true,
                ..kellnr_settings::Registry::default()
            },
            ..defaults
        };
        let state = state_with(db, settings);
        Router::new()
            .route("/guarded", get(StatusCode::OK))
            .route_layer(from_fn_with_state(state.clone(), session_auth_when_required))
            .route("/not_guarded", get(StatusCode::OK))
            .with_state(state)
    }

    /// App using the default settings, with `/guarded` wrapped by the middleware.
    fn app_not_required_auth(db: Arc<dyn DbProvider>) -> Router {
        let state = state_with(db, Settings::default());
        Router::new()
            .route("/guarded", get(StatusCode::OK))
            .route_layer(from_fn_with_state(state.clone(), session_auth_when_required))
            .with_state(state)
    }

    /// Cookie header carrying the session id `1234`.
    fn c1234() -> String {
        encode_cookies([(COOKIE_SESSION_ID, "1234")])
    }

    /// Sends a GET to `uri`, attaching the `1234` session cookie if `with_cookie` is set,
    /// and returns the response status.
    async fn status_of(app: Router, uri: &str, with_cookie: bool) -> Result<StatusCode> {
        let mut builder = Request::get(uri);
        if with_cookie {
            builder = builder.header(header::COOKIE, c1234());
        }
        let response = app.oneshot(builder.body(Body::empty())?).await?;
        Ok(response.status())
    }

    #[tokio::test]
    async fn guarded_route_with_valid_cookie() -> Result {
        let mut db = MockDb::new();
        db.expect_validate_session()
            .with(eq("1234"))
            .returning(|_| Ok(("guest".to_string(), false)));

        let status = status_of(app_required_auth(Arc::new(db)), "/guarded", true).await?;
        assert_eq!(status, StatusCode::OK);
        Ok(())
    }

    #[tokio::test]
    async fn guarded_route_with_invalid_cookie() -> Result {
        let mut db = MockDb::new();
        db.expect_validate_session()
            .with(eq("1234"))
            .returning(|_| Err(DbError::SessionNotFound));

        let status = status_of(app_required_auth(Arc::new(db)), "/guarded", true).await?;
        assert_eq!(status, StatusCode::UNAUTHORIZED);
        Ok(())
    }

    #[tokio::test]
    async fn guarded_route_without_cookie() -> Result {
        let status =
            status_of(app_required_auth(Arc::new(MockDb::new())), "/guarded", false).await?;
        assert_eq!(status, StatusCode::UNAUTHORIZED);
        Ok(())
    }

    #[tokio::test]
    async fn not_guarded_route_without_cookie() -> Result {
        let status =
            status_of(app_required_auth(Arc::new(MockDb::new())), "/not_guarded", false).await?;
        assert_eq!(status, StatusCode::OK);
        Ok(())
    }

    #[tokio::test]
    async fn app_not_required_auth_with_guarded_route() -> Result {
        let status =
            status_of(app_not_required_auth(Arc::new(MockDb::new())), "/guarded", false).await?;
        assert_eq!(status, StatusCode::OK);
        Ok(())
    }
}