Skip to main content

axum_security/cookie/
session.rs

1use std::{convert::Infallible, hash::Hash};
2
3use axum::{
4    extract::{FromRequestParts, OptionalFromRequestParts},
5    http::{Extensions, StatusCode, request::Parts},
6};
7
8use crate::cookie::SessionId;
9
10#[derive(Clone, Debug)]
11pub struct CookieSession<S> {
12    pub session_id: SessionId,
13    pub created_at: u64,
14    pub state: S,
15}
16
17impl<S> CookieSession<S> {
18    pub fn new(id: SessionId, created_at: u64, value: S) -> Self {
19        Self {
20            session_id: id,
21            created_at,
22            state: value,
23        }
24    }
25
26    pub fn from_extensions(extensions: &mut Extensions) -> Option<Self>
27    where
28        S: Send + Sync + 'static,
29    {
30        extensions.remove()
31    }
32}
33
34impl<S> Hash for CookieSession<S> {
35    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
36        self.session_id.hash(state)
37    }
38}
39
40impl<S> Eq for CookieSession<S> {}
41
42impl<S> PartialEq for CookieSession<S> {
43    fn eq(&self, other: &Self) -> bool {
44        self.session_id == other.session_id
45    }
46}
47
48impl<S, T> FromRequestParts<S> for CookieSession<T>
49where
50    S: Send + Sync,
51    T: Send + Sync + 'static,
52{
53    type Rejection = StatusCode;
54
55    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, StatusCode> {
56        if let Some(session) = parts.extensions.remove() {
57            Ok(session)
58        } else {
59            Err(StatusCode::UNAUTHORIZED)
60        }
61    }
62}
63
64impl<S, T> OptionalFromRequestParts<S> for CookieSession<T>
65where
66    S: Send + Sync,
67    T: Send + Sync + 'static,
68{
69    type Rejection = Infallible;
70
71    async fn from_request_parts(
72        parts: &mut Parts,
73        _state: &S,
74    ) -> Result<Option<Self>, Self::Rejection> {
75        Ok(parts.extensions.remove())
76    }
77}
78
#[cfg(test)]
mod extract_cookie {
    use axum::{
        extract::{FromRequestParts, OptionalFromRequestParts},
        http::{Request, StatusCode, request::Parts},
    };

    use crate::cookie::{CookieSession, SessionId};

    /// Session type used throughout these tests; the state is irrelevant, so it is `()`.
    type Session = CookieSession<()>;

    /// Builds request parts whose extensions contain `session`, if one is given.
    fn parts_with(session: Option<Session>) -> Parts {
        let mut builder = Request::builder();
        if let Some(session) = session {
            builder = builder.extension(session);
        }
        builder.body(()).unwrap().into_parts().0
    }

    /// Runs the mandatory extractor. Fully qualified syntax is required because
    /// `OptionalFromRequestParts` also defines a `from_request_parts` method.
    async fn extract_required(parts: &mut Parts) -> Result<Session, StatusCode> {
        <Session as FromRequestParts<()>>::from_request_parts(parts, &()).await
    }

    /// Runs the optional extractor; its rejection type is `Infallible`, so it cannot fail.
    async fn extract_optional(parts: &mut Parts) -> Option<Session> {
        match <Session as OptionalFromRequestParts<()>>::from_request_parts(parts, &()).await {
            Ok(session) => session,
            Err(never) => match never {},
        }
    }

    #[tokio::test]
    async fn extract() {
        let cookie = CookieSession::new(SessionId::new(), 0, ());
        let mut parts = parts_with(Some(cookie.clone()));

        let extracted_cookie = extract_required(&mut parts).await.unwrap();

        assert_eq!(cookie.session_id, extracted_cookie.session_id);
        assert_eq!(cookie.created_at, extracted_cookie.created_at);
    }

    #[tokio::test]
    async fn extract_rejection() {
        let mut parts = parts_with(None);

        let rejection = extract_required(&mut parts).await.unwrap_err();

        assert_eq!(rejection, StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn extract_consumes_session() {
        let mut parts = parts_with(Some(CookieSession::new(SessionId::new(), 0, ())));

        let _session = extract_required(&mut parts).await.unwrap();

        // The extractor moves the session out of the extensions instead of cloning it.
        assert!(parts.extensions.get::<Session>().is_none());
    }

    #[tokio::test]
    async fn optional_extract_present() {
        let cookie = CookieSession::new(SessionId::new(), 0, ());
        let mut parts = parts_with(Some(cookie.clone()));

        assert_eq!(extract_optional(&mut parts).await, Some(cookie));
    }

    #[tokio::test]
    async fn optional_extract_missing() {
        let mut parts = parts_with(None);

        assert_eq!(extract_optional(&mut parts).await, None);
    }
}