Skip to main content

kellnr_web_ui/
session.rs

1use axum::RequestPartsExt;
2use axum::extract::{FromRequestParts, OptionalFromRequestParts, Request, State};
3use axum::http::StatusCode;
4use axum::http::request::Parts;
5use axum::middleware::Next;
6use axum::response::Response;
7use axum_extra::extract::PrivateCookieJar;
8use axum_extra::extract::cookie::Cookie;
9use cookie::{SameSite, time};
10use kellnr_appstate::AppStateData;
11use kellnr_common::util::generate_rand_string;
12use kellnr_settings::constants::COOKIE_SESSION_ID;
13use time::Duration;
14use tracing::error;
15
16use crate::error::RouteError;
17
18/// Creates a new session for the user and returns a cookie jar with the session cookie set.
19/// Generates a token, persists it via db, and adds the cookie using `app_state` settings.
20pub(crate) async fn create_session_jar(
21    cookies: PrivateCookieJar,
22    app_state: &AppStateData,
23    username: &str,
24) -> Result<PrivateCookieJar, RouteError> {
25    let session_token = generate_rand_string(12);
26    app_state
27        .db
28        .add_session_token(username, &session_token)
29        .await
30        .map_err(|e| {
31            error!("Failed to create session: {e}");
32            RouteError::Status(StatusCode::INTERNAL_SERVER_ERROR)
33        })?;
34    let session_age_seconds = app_state.settings.registry.session_age_seconds as i64;
35    Ok(cookies.add(
36        Cookie::build((COOKIE_SESSION_ID, session_token))
37            .max_age(Duration::seconds(session_age_seconds))
38            .same_site(SameSite::Strict)
39            .path("/"),
40    ))
41}
42
/// Common interface for newtype wrappers that carry a username.
pub trait Name {
    /// Builds the wrapper around the given username.
    fn new(name: String) -> Self;

    /// Returns the wrapped username as an owned `String`.
    fn name(&self) -> String;
}
47
48pub struct AdminUser(pub String);
49
50impl AdminUser {
51    pub fn name(&self) -> &str {
52        &self.0
53    }
54}
55
56impl Name for AdminUser {
57    fn name(&self) -> String {
58        self.0.clone()
59    }
60    fn new(name: String) -> Self {
61        Self(name)
62    }
63}
64
65impl FromRequestParts<AppStateData> for AdminUser {
66    type Rejection = RouteError;
67
68    async fn from_request_parts(
69        parts: &mut Parts,
70        state: &AppStateData,
71    ) -> Result<Self, Self::Rejection> {
72        let jar: PrivateCookieJar = parts.extract_with_state(state).await.unwrap();
73        let session_cookie = jar.get(COOKIE_SESSION_ID);
74        match session_cookie {
75            Some(cookie) => match state.db.validate_session(cookie.value()).await {
76                Ok((name, true)) => Ok(Self(name)),
77                Ok((_, false)) => Err(RouteError::InsufficientPrivileges),
78                Err(_) => Err(RouteError::Status(StatusCode::UNAUTHORIZED)),
79            },
80            None => Err(RouteError::Status(StatusCode::UNAUTHORIZED)),
81        }
82    }
83}
84
85pub struct NormalUser(pub String);
86impl Name for NormalUser {
87    fn name(&self) -> String {
88        self.0.clone()
89    }
90    fn new(name: String) -> Self {
91        Self(name)
92    }
93}
94
95pub struct AnyUser(pub String);
96impl Name for AnyUser {
97    fn name(&self) -> String {
98        self.0.clone()
99    }
100    fn new(name: String) -> Self {
101        Self(name)
102    }
103}
104
/// A logged-in user, tagged by whether the session carries the admin flag.
#[derive(Debug)]
pub enum MaybeUser {
    // TODO: consider carrying a db user model instead of just the username.
    /// Session without the admin flag; holds the username.
    Normal(String),
    /// Session with the admin flag; holds the username.
    Admin(String),
}
111
112impl MaybeUser {
113    pub fn name(&self) -> &str {
114        match self {
115            Self::Normal(name) | Self::Admin(name) => name,
116        }
117    }
118
119    pub fn assert_normal(&self) -> Result<(), RouteError> {
120        match self {
121            MaybeUser::Normal(_) => Ok(()),
122            MaybeUser::Admin(_) => Err(RouteError::InsufficientPrivileges),
123        }
124    }
125
126    pub fn assert_admin(&self) -> Result<(), RouteError> {
127        match self {
128            MaybeUser::Normal(_) => Err(RouteError::InsufficientPrivileges),
129            MaybeUser::Admin(_) => Ok(()),
130        }
131    }
132}
133
134impl FromRequestParts<AppStateData> for MaybeUser {
135    type Rejection = RouteError;
136
137    async fn from_request_parts(
138        parts: &mut Parts,
139        state: &AppStateData,
140    ) -> Result<Self, Self::Rejection> {
141        let jar: PrivateCookieJar = parts.extract_with_state(state).await.unwrap();
142        let session_cookie = jar.get(COOKIE_SESSION_ID);
143        match session_cookie {
144            Some(cookie) => match state.db.validate_session(cookie.value()).await {
145                // admin
146                Ok((name, true)) => Ok(Self::Admin(name)),
147                // not admin
148                Ok((name, false)) => Ok(Self::Normal(name)),
149                Err(_) => Err(RouteError::Status(StatusCode::UNAUTHORIZED)),
150            },
151            None => Err(RouteError::Status(StatusCode::UNAUTHORIZED)),
152        }
153    }
154}
155
156impl OptionalFromRequestParts<AppStateData> for MaybeUser {
157    type Rejection = RouteError;
158
159    async fn from_request_parts(
160        parts: &mut Parts,
161        state: &AppStateData,
162    ) -> Result<Option<Self>, Self::Rejection> {
163        let jar: PrivateCookieJar = parts.extract_with_state(state).await.unwrap();
164        let session_cookie = jar.get(COOKIE_SESSION_ID);
165        match session_cookie {
166            Some(cookie) => match state.db.validate_session(cookie.value()).await {
167                // admin
168                Ok((name, true)) => Ok(Some(Self::Admin(name))),
169                // not admin
170                Ok((name, false)) => Ok(Some(Self::Normal(name))),
171                Err(_) => Err(RouteError::Status(StatusCode::UNAUTHORIZED)),
172            },
173            None => Ok(None),
174        }
175    }
176}
177
178/// Middleware that checks if a user is logged in when `settings.registry.auth_required` is `true`
179/// If the user is not logged in, a 401 is returned.
180pub async fn session_auth_when_required(
181    State(state): State<AppStateData>,
182    jar: PrivateCookieJar,
183    request: Request,
184    next: Next,
185) -> Result<Response, RouteError> {
186    if !state.settings.registry.auth_required {
187        // If "auth_required" is "false", pass through.
188        return Ok(next.run(request).await);
189    }
190    let session_cookie = jar.get(COOKIE_SESSION_ID);
191    match session_cookie {
192        Some(cookie) => match state.db.validate_session(cookie.value()).await {
193            // user is logged in
194            Ok(_) => Ok(next.run(request).await),
195            // user is not logged in
196            Err(_) => Err(RouteError::Status(StatusCode::UNAUTHORIZED)),
197        },
198        // user is not logged in
199        None => Err(RouteError::Status(StatusCode::UNAUTHORIZED)),
200    }
201}
202
#[cfg(test)]
mod session_tests {
    use std::sync::Arc;

    use axum::Router;
    use axum::body::Body;
    use axum::http::header;
    use axum::routing::get;
    use cookie::Key;
    use kellnr_db::DbProvider;
    use kellnr_db::error::DbError;
    use kellnr_db::mock::MockDb;
    use kellnr_storage::cached_crate_storage::DynStorage;
    use kellnr_storage::fs_storage::FSStorage;
    use kellnr_storage::kellnr_crate_storage::KellnrCrateStorage;
    use mockall::predicate::eq;
    use tower::ServiceExt;

    use super::*;
    use crate::test_helper::encode_cookies;

    type Result<T = (), E = Box<dyn std::error::Error>> = std::result::Result<T, E>;

    async fn admin_endpoint(user: MaybeUser) -> Result<(), RouteError> {
        user.assert_admin()?;
        Ok(())
    }

    async fn normal_endpoint(user: MaybeUser) -> Result<(), RouteError> {
        user.assert_normal()?;
        Ok(())
    }

    async fn any_endpoint(_user: MaybeUser) {}

    /// Exercises the `AdminUser` extractor directly (not via `MaybeUser::assert_admin`).
    async fn admin_extractor_endpoint(_user: AdminUser) {}

    /// Exercises `Option<MaybeUser>` (the `OptionalFromRequestParts` impl):
    /// 200 when a user was extracted, 204 when no session cookie was sent.
    async fn optional_endpoint(user: Option<MaybeUser>) -> StatusCode {
        match user {
            Some(_) => StatusCode::OK,
            None => StatusCode::NO_CONTENT,
        }
    }

    fn app(db: Arc<dyn DbProvider>) -> Router {
        let settings = kellnr_settings::test_settings();
        let storage = Box::new(FSStorage::new(&settings.crates_path()).unwrap()) as DynStorage;
        Router::new()
            .route("/admin", get(admin_endpoint))
            .route("/normal", get(normal_endpoint))
            .route("/any", get(any_endpoint))
            .route("/admin_extractor", get(admin_extractor_endpoint))
            .route("/optional", get(optional_endpoint))
            .with_state(AppStateData {
                db,
                signing_key: Key::from(crate::test_helper::TEST_KEY),
                crate_storage: Arc::new(KellnrCrateStorage::new(&settings, storage)),
                settings: Arc::new(settings),
                ..kellnr_appstate::test_state()
            })
    }

    fn c1234() -> String {
        encode_cookies([(COOKIE_SESSION_ID, "1234")])
    }

    // Admin-only endpoint tests (`MaybeUser` + `assert_admin`)

    #[tokio::test]
    async fn admin_auth_works() -> Result {
        let mut mock_db = MockDb::new();
        mock_db
            .expect_validate_session()
            .with(eq("1234"))
            .returning(|_st| Ok(("admin".to_string(), true)));

        let r = app(Arc::new(mock_db))
            .oneshot(
                Request::get("/admin")
                    .header(header::COOKIE, c1234())
                    .body(Body::empty())?,
            )
            .await?;
        assert!(r.status().is_success());

        Ok(())
    }

    #[tokio::test]
    async fn admin_auth_user_is_no_admin() -> Result {
        let mut mock_db = MockDb::new();
        mock_db
            .expect_validate_session()
            .with(eq("1234"))
            .returning(|_st| Ok(("admin".to_string(), false)));

        let r = app(Arc::new(mock_db))
            .oneshot(
                Request::get("/admin")
                    .header(header::COOKIE, c1234())
                    .body(Body::empty())?,
            )
            .await?;
        assert_eq!(r.status(), StatusCode::FORBIDDEN);

        Ok(())
    }

    #[tokio::test]
    async fn admin_auth_user_but_no_cookie_sent() -> Result {
        let mock_db = MockDb::new();

        let r = app(Arc::new(mock_db))
            .oneshot(Request::get("/admin").body(Body::empty())?)
            .await?;
        assert_eq!(r.status(), StatusCode::UNAUTHORIZED);

        Ok(())
    }

    #[tokio::test]
    async fn admin_auth_user_but_no_cookie_in_store() -> Result {
        let mut mock_db = MockDb::new();
        mock_db
            .expect_validate_session()
            .with(eq("1234"))
            .returning(|_st| Err(DbError::SessionNotFound));

        let r = app(Arc::new(mock_db))
            .oneshot(
                Request::get("/admin")
                    .header(header::COOKIE, c1234())
                    .body(Body::empty())?,
            )
            .await?;
        assert_eq!(r.status(), StatusCode::UNAUTHORIZED);

        Ok(())
    }

    // Normal-only endpoint tests (`MaybeUser` + `assert_normal`)

    #[tokio::test]
    async fn normal_auth_works() -> Result {
        let mut mock_db = MockDb::new();
        mock_db
            .expect_validate_session()
            .with(eq("1234"))
            .returning(|_st| Ok(("normal".to_string(), false)));

        let r = app(Arc::new(mock_db))
            .oneshot(
                Request::get("/normal")
                    .header(header::COOKIE, c1234())
                    .body(Body::empty())?,
            )
            .await?;
        assert_eq!(r.status(), StatusCode::OK);

        Ok(())
    }

    #[tokio::test]
    async fn normal_auth_user_is_admin() -> Result {
        let mut mock_db = MockDb::new();
        mock_db
            .expect_validate_session()
            .with(eq("1234"))
            .returning(|_st| Ok(("normal".to_string(), true)));

        let r = app(Arc::new(mock_db))
            .oneshot(
                Request::get("/normal")
                    .header(header::COOKIE, c1234())
                    .body(Body::empty())?,
            )
            .await?;
        assert_eq!(r.status(), StatusCode::FORBIDDEN);

        Ok(())
    }

    #[tokio::test]
    async fn normal_auth_user_but_no_cookie_sent() -> Result {
        let mock_db = MockDb::new();

        let r = app(Arc::new(mock_db))
            .oneshot(Request::get("/normal").body(Body::empty())?)
            .await?;
        assert_eq!(r.status(), StatusCode::UNAUTHORIZED);

        Ok(())
    }

    #[tokio::test]
    async fn normal_auth_user_but_no_cookie_in_store() -> Result {
        let mut mock_db = MockDb::new();
        mock_db
            .expect_validate_session()
            .with(eq("1234"))
            .returning(|_st| Err(DbError::SessionNotFound));

        let r = app(Arc::new(mock_db))
            .oneshot(
                Request::get("/normal")
                    .header(header::COOKIE, c1234())
                    .body(Body::empty())?,
            )
            .await?;
        assert_eq!(r.status(), StatusCode::UNAUTHORIZED);

        Ok(())
    }

    // Any authenticated user tests (plain `MaybeUser`)

    #[tokio::test]
    async fn any_auth_user_is_normal() -> Result {
        let mut mock_db = MockDb::new();
        mock_db
            .expect_validate_session()
            .with(eq("1234"))
            .returning(|_st| Ok(("guest".to_string(), false)));

        let r = app(Arc::new(mock_db))
            .oneshot(
                Request::get("/any")
                    .header(header::COOKIE, c1234())
                    .body(Body::empty())?,
            )
            .await?;
        assert_eq!(r.status(), StatusCode::OK);

        Ok(())
    }

    #[tokio::test]
    async fn any_auth_user_is_admin() -> Result {
        let mut mock_db = MockDb::new();
        mock_db
            .expect_validate_session()
            .with(eq("1234"))
            .returning(|_st| Ok(("guest".to_string(), true)));

        let r = app(Arc::new(mock_db))
            .oneshot(
                Request::get("/any")
                    .header(header::COOKIE, c1234())
                    .body(Body::empty())?,
            )
            .await?;
        assert_eq!(r.status(), StatusCode::OK);

        Ok(())
    }

    #[tokio::test]
    async fn any_auth_user_but_no_cookie_sent() -> Result {
        let mock_db = MockDb::new();

        let r = app(Arc::new(mock_db))
            .oneshot(Request::get("/any").body(Body::empty())?)
            .await?;
        assert_eq!(r.status(), StatusCode::UNAUTHORIZED);
        Ok(())
    }

    #[tokio::test]
    async fn any_auth_user_but_no_cookie_in_store() -> Result {
        let mut mock_db = MockDb::new();
        mock_db
            .expect_validate_session()
            .with(eq("1234"))
            .returning(|_st| Err(DbError::SessionNotFound));

        let r = app(Arc::new(mock_db))
            .oneshot(
                Request::get("/any")
                    .header(header::COOKIE, c1234())
                    .body(Body::empty())?,
            )
            .await?;

        assert_eq!(r.status(), StatusCode::UNAUTHORIZED);
        Ok(())
    }

    // `AdminUser` extractor tests

    #[tokio::test]
    async fn admin_extractor_accepts_admin() -> Result {
        let mut mock_db = MockDb::new();
        mock_db
            .expect_validate_session()
            .with(eq("1234"))
            .returning(|_st| Ok(("admin".to_string(), true)));

        let r = app(Arc::new(mock_db))
            .oneshot(
                Request::get("/admin_extractor")
                    .header(header::COOKIE, c1234())
                    .body(Body::empty())?,
            )
            .await?;
        assert_eq!(r.status(), StatusCode::OK);

        Ok(())
    }

    #[tokio::test]
    async fn admin_extractor_rejects_normal_user() -> Result {
        let mut mock_db = MockDb::new();
        mock_db
            .expect_validate_session()
            .with(eq("1234"))
            .returning(|_st| Ok(("normal".to_string(), false)));

        let r = app(Arc::new(mock_db))
            .oneshot(
                Request::get("/admin_extractor")
                    .header(header::COOKIE, c1234())
                    .body(Body::empty())?,
            )
            .await?;
        assert_eq!(r.status(), StatusCode::FORBIDDEN);

        Ok(())
    }

    #[tokio::test]
    async fn admin_extractor_without_cookie() -> Result {
        let mock_db = MockDb::new();

        let r = app(Arc::new(mock_db))
            .oneshot(Request::get("/admin_extractor").body(Body::empty())?)
            .await?;
        assert_eq!(r.status(), StatusCode::UNAUTHORIZED);

        Ok(())
    }

    // `Option<MaybeUser>` (OptionalFromRequestParts) tests

    #[tokio::test]
    async fn optional_user_without_cookie_is_none() -> Result {
        let mock_db = MockDb::new();

        let r = app(Arc::new(mock_db))
            .oneshot(Request::get("/optional").body(Body::empty())?)
            .await?;
        assert_eq!(r.status(), StatusCode::NO_CONTENT);

        Ok(())
    }

    #[tokio::test]
    async fn optional_user_with_valid_cookie_is_some() -> Result {
        let mut mock_db = MockDb::new();
        mock_db
            .expect_validate_session()
            .with(eq("1234"))
            .returning(|_st| Ok(("guest".to_string(), false)));

        let r = app(Arc::new(mock_db))
            .oneshot(
                Request::get("/optional")
                    .header(header::COOKIE, c1234())
                    .body(Body::empty())?,
            )
            .await?;
        assert_eq!(r.status(), StatusCode::OK);

        Ok(())
    }

    #[tokio::test]
    async fn optional_user_with_invalid_cookie_is_rejected() -> Result {
        let mut mock_db = MockDb::new();
        mock_db
            .expect_validate_session()
            .with(eq("1234"))
            .returning(|_st| Err(DbError::SessionNotFound));

        let r = app(Arc::new(mock_db))
            .oneshot(
                Request::get("/optional")
                    .header(header::COOKIE, c1234())
                    .body(Body::empty())?,
            )
            .await?;
        assert_eq!(r.status(), StatusCode::UNAUTHORIZED);

        Ok(())
    }
}
482
#[cfg(test)]
mod auth_middleware_tests {
    use std::sync::Arc;

    use axum::Router;
    use axum::body::Body;
    use axum::http::header;
    use axum::middleware::from_fn_with_state;
    use axum::routing::get;
    use cookie::Key;
    use kellnr_db::DbProvider;
    use kellnr_db::error::DbError;
    use kellnr_db::mock::MockDb;
    use kellnr_settings::Settings;
    use mockall::predicate::eq;
    use tower::ServiceExt;

    use super::*;
    use crate::test_helper::encode_cookies;

    fn app_required_auth(db: Arc<dyn DbProvider>) -> Router {
        let settings = Settings::default();
        let state = AppStateData {
            db,
            signing_key: Key::from(crate::test_helper::TEST_KEY),
            settings: Arc::new(Settings {
                registry: kellnr_settings::Registry {
                    auth_required: true,
                    ..kellnr_settings::Registry::default()
                },
                ..settings
            }),
            ..kellnr_appstate::test_state()
        };
        Router::new()
            .route("/guarded", get(StatusCode::OK))
            .route_layer(from_fn_with_state(
                state.clone(),
                session_auth_when_required,
            ))
            .route("/not_guarded", get(StatusCode::OK))
            .with_state(state)
    }

    fn app_not_required_auth(db: Arc<dyn DbProvider>) -> Router {
        // Disable auth explicitly rather than relying on `Settings::default()`, so these
        // tests keep testing the "not required" path even if the default changes.
        let mut settings = Settings::default();
        settings.registry.auth_required = false;
        let state = AppStateData {
            db,
            signing_key: Key::from(crate::test_helper::TEST_KEY),
            settings: Arc::new(settings),
            ..kellnr_appstate::test_state()
        };
        Router::new()
            .route("/guarded", get(StatusCode::OK))
            .route_layer(from_fn_with_state(
                state.clone(),
                session_auth_when_required,
            ))
            .with_state(state)
    }

    type Result<T = ()> = std::result::Result<T, Box<dyn std::error::Error>>;

    fn c1234() -> String {
        encode_cookies([(COOKIE_SESSION_ID, "1234")])
    }

    #[tokio::test]
    async fn guarded_route_with_valid_cookie() -> Result {
        let mut mock_db = MockDb::new();
        mock_db
            .expect_validate_session()
            .with(eq("1234"))
            .returning(|_st| Ok(("guest".to_string(), false)));

        let r = app_required_auth(Arc::new(mock_db))
            .oneshot(
                Request::get("/guarded")
                    .header(header::COOKIE, c1234())
                    .body(Body::empty())?,
            )
            .await?;
        assert_eq!(r.status(), StatusCode::OK);

        Ok(())
    }

    #[tokio::test]
    async fn guarded_route_with_invalid_cookie() -> Result {
        let mut mock_db = MockDb::new();
        mock_db
            .expect_validate_session()
            .with(eq("1234"))
            .returning(|_st| Err(DbError::SessionNotFound));

        let r = app_required_auth(Arc::new(mock_db))
            .oneshot(
                Request::get("/guarded")
                    .header(header::COOKIE, c1234())
                    .body(Body::empty())?,
            )
            .await?;
        assert_eq!(r.status(), StatusCode::UNAUTHORIZED);

        Ok(())
    }

    #[tokio::test]
    async fn guarded_route_without_cookie() -> Result {
        let mock_db = MockDb::new();

        let r = app_required_auth(Arc::new(mock_db))
            .oneshot(Request::get("/guarded").body(Body::empty())?)
            .await?;
        assert_eq!(r.status(), StatusCode::UNAUTHORIZED);

        Ok(())
    }

    #[tokio::test]
    async fn not_guarded_route_without_cookie() -> Result {
        let mock_db = MockDb::new();

        let r = app_required_auth(Arc::new(mock_db))
            .oneshot(Request::get("/not_guarded").body(Body::empty())?)
            .await?;
        assert_eq!(r.status(), StatusCode::OK);

        Ok(())
    }

    #[tokio::test]
    async fn app_not_required_auth_with_guarded_route() -> Result {
        let mock_db = MockDb::new();

        let r = app_not_required_auth(Arc::new(mock_db))
            .oneshot(Request::get("/guarded").body(Body::empty())?)
            .await?;
        assert_eq!(r.status(), StatusCode::OK);

        Ok(())
    }

    #[tokio::test]
    async fn app_not_required_auth_does_not_validate_cookie() -> Result {
        // No expectations are set: any `validate_session` call would make the mock panic,
        // proving the middleware skips the DB entirely when auth is not required.
        let mock_db = MockDb::new();

        let r = app_not_required_auth(Arc::new(mock_db))
            .oneshot(
                Request::get("/guarded")
                    .header(header::COOKIE, c1234())
                    .body(Body::empty())?,
            )
            .await?;
        assert_eq!(r.status(), StatusCode::OK);

        Ok(())
    }
}