axum_security/cookie/
session.rs1use std::{convert::Infallible, hash::Hash};
2
3use axum::{
4 extract::{FromRequestParts, OptionalFromRequestParts},
5 http::{Extensions, StatusCode, request::Parts},
6};
7
8use crate::cookie::SessionId;
9
10#[derive(Clone, Debug)]
11pub struct CookieSession<S> {
12 pub session_id: SessionId,
13 pub created_at: u64,
14 pub state: S,
15}
16
17impl<S> CookieSession<S> {
18 pub fn new(id: SessionId, created_at: u64, value: S) -> Self {
19 Self {
20 session_id: id,
21 created_at,
22 state: value,
23 }
24 }
25
26 pub fn from_extensions(extensions: &mut Extensions) -> Option<Self>
27 where
28 S: Send + Sync + 'static,
29 {
30 extensions.remove()
31 }
32}
33
34impl<S> Hash for CookieSession<S> {
35 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
36 self.session_id.hash(state)
37 }
38}
39
40impl<S> Eq for CookieSession<S> {}
41
42impl<S> PartialEq for CookieSession<S> {
43 fn eq(&self, other: &Self) -> bool {
44 self.session_id == other.session_id
45 }
46}
47
48impl<S, T> FromRequestParts<S> for CookieSession<T>
49where
50 S: Send + Sync,
51 T: Send + Sync + 'static,
52{
53 type Rejection = StatusCode;
54
55 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, StatusCode> {
56 if let Some(session) = parts.extensions.remove() {
57 Ok(session)
58 } else {
59 Err(StatusCode::UNAUTHORIZED)
60 }
61 }
62}
63
64impl<S, T> OptionalFromRequestParts<S> for CookieSession<T>
65where
66 S: Send + Sync,
67 T: Send + Sync + 'static,
68{
69 type Rejection = Infallible;
70
71 async fn from_request_parts(
72 parts: &mut Parts,
73 _state: &S,
74 ) -> Result<Option<Self>, Self::Rejection> {
75 Ok(parts.extensions.remove())
76 }
77}
78
#[cfg(test)]
mod extract_cookie {
    use axum::{
        extract::{FromRequestParts, OptionalFromRequestParts},
        http::{Request, StatusCode, request::Parts},
    };

    use crate::cookie::{CookieSession, SessionId};

    /// Builds request parts whose extensions hold `session`, if one is given.
    fn parts_with(session: Option<CookieSession<()>>) -> Parts {
        let mut builder = Request::builder();
        if let Some(session) = session {
            builder = builder.extension(session);
        }
        builder.body(()).unwrap().into_parts().0
    }

    // Both extractor traits define `from_request_parts`, so calls are fully
    // qualified to pick the intended one.

    #[tokio::test]
    async fn extract() {
        let cookie = CookieSession::new(SessionId::new(), 0, ());
        let mut parts = parts_with(Some(cookie.clone()));

        let extracted =
            <CookieSession<()> as FromRequestParts<()>>::from_request_parts(&mut parts, &())
                .await
                .unwrap();

        assert_eq!(cookie.session_id, extracted.session_id);
        assert_eq!(cookie.created_at, extracted.created_at);
    }

    #[tokio::test]
    async fn extract_consumes_session() {
        let mut parts = parts_with(Some(CookieSession::new(SessionId::new(), 0, ())));

        <CookieSession<()> as FromRequestParts<()>>::from_request_parts(&mut parts, &())
            .await
            .unwrap();

        // The extractor removes the session, so it is no longer in the extensions.
        assert!(parts.extensions.get::<CookieSession<()>>().is_none());
    }

    #[tokio::test]
    async fn extract_rejection() {
        let mut parts = parts_with(None);

        let rejection =
            <CookieSession<()> as FromRequestParts<()>>::from_request_parts(&mut parts, &())
                .await
                .unwrap_err();

        assert_eq!(rejection, StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn extract_optional_present() {
        let cookie = CookieSession::new(SessionId::new(), 7, ());
        let mut parts = parts_with(Some(cookie.clone()));

        let extracted = <CookieSession<()> as OptionalFromRequestParts<()>>::from_request_parts(
            &mut parts,
            &(),
        )
        .await
        .unwrap();

        assert_eq!(extracted.map(|s| s.session_id), Some(cookie.session_id));
    }

    #[tokio::test]
    async fn extract_optional_absent() {
        let mut parts = parts_with(None);

        let extracted = <CookieSession<()> as OptionalFromRequestParts<()>>::from_request_parts(
            &mut parts,
            &(),
        )
        .await
        .unwrap();

        assert!(extracted.is_none());
    }
}