axum_security/cookie/
builder.rs1use std::{sync::Arc, time::Duration};
2
3use cookie_monster::{Cookie, CookieBuilder, SameSite};
4
5use crate::cookie::{
6 CookieContext, CookieContextInner, CookieStore, expiry::SessionExpiry, store::ErasedStore,
7};
8
/// Name used for the regular session cookie unless the caller picks another one.
const DEFAULT_SESSION_COOKIE_NAME: &str = "session";
/// Name used for the development session cookie unless the caller picks another one.
const DEFAULT_DEV_SESSION_COOKIE_NAME: &str = "dev-session";
11
12pub struct CookieSessionBuilder<S> {
13 store: S,
14 pub(crate) dev: bool,
15 pub(crate) dev_cookie: CookieBuilder,
16 pub(crate) cookie: CookieBuilder,
17 pub(crate) expiry: Option<SessionExpiry>,
18}
19
20impl CookieSessionBuilder<()> {
21 pub fn new() -> CookieSessionBuilder<()> {
22 Self {
23 store: (),
24 dev: false,
25 expiry: None,
26 dev_cookie: Cookie::named(DEFAULT_DEV_SESSION_COOKIE_NAME)
28 .path("/")
29 .same_site(SameSite::Lax),
30 cookie: Cookie::named(DEFAULT_SESSION_COOKIE_NAME)
31 .same_site(SameSite::Strict)
32 .http_only()
33 .secure(),
34 }
35 }
36}
37
38impl<S> CookieSessionBuilder<S> {
39 pub fn cookie(mut self, f: impl FnOnce(CookieBuilder) -> CookieBuilder) -> Self {
40 self.cookie = f(Cookie::named(DEFAULT_SESSION_COOKIE_NAME));
41 self
42 }
43
44 pub fn dev_cookie(mut self, f: impl FnOnce(CookieBuilder) -> CookieBuilder) -> Self {
45 self.dev_cookie = f(Cookie::named(DEFAULT_DEV_SESSION_COOKIE_NAME));
46 self
47 }
48
49 pub fn use_dev_cookie(mut self, dev: bool) -> Self {
50 self.dev = dev;
51 self
52 }
53
54 pub fn use_normal_cookie(self, prod: bool) -> Self {
55 self.use_dev_cookie(!prod)
56 }
57
58 pub fn expires_max_age(mut self) -> Self {
59 self.expiry = Some(SessionExpiry::CookieMaxAge);
60 self
61 }
62
63 pub fn expires_after(mut self, session_duration: Duration) -> Self {
64 self.expiry = Some(SessionExpiry::Duration(session_duration));
65 self
66 }
67
68 pub fn expires_none(mut self) -> Self {
69 self.expiry = None;
70 self
71 }
72
73 pub fn store<S1>(self, store: S1) -> CookieSessionBuilder<S1> {
74 CookieSessionBuilder {
75 store,
76 dev: self.dev,
77 dev_cookie: self.dev_cookie,
78 cookie: self.cookie,
79 expiry: self.expiry,
80 }
81 }
82}
83
84impl<S> CookieSessionBuilder<S> {
85 pub fn build<T>(self) -> CookieContext<T>
86 where
87 T: Send + Sync + 'static,
88 S: CookieStore<State = T>,
89 {
90 let session_expiry = self.expiry.map(|e| match e {
91 SessionExpiry::CookieMaxAge => self.cookie.get_max_age().expect("No max-age set"),
92 SessionExpiry::Duration(duration) => duration,
93 });
94
95 let store = ErasedStore::new(self.store);
96
97 let handle = if let Some(expiry) = session_expiry
98 && store.spawn_maintenance_task()
99 {
100 let this = store.clone();
101 Some(tokio::spawn(super::expiry::maintenance_task(this, expiry)))
102 } else {
103 None
104 };
105
106 CookieContext(Arc::new(CookieContextInner {
107 store,
108 cookie_opts: if self.dev {
109 self.dev_cookie
110 } else {
111 self.cookie
112 },
113 handle,
114 }))
115 }
116}
117
118impl Default for CookieSessionBuilder<()> {
119 fn default() -> Self {
120 Self::new()
121 }
122}
123
#[cfg(test)]
mod cookie {
    use cookie_monster::CookieJar;

    use crate::cookie::{CookieContext, MemStore};

    /// Minimal session state used by the tests below.
    #[derive(Clone)]
    struct User {
        id: i32,
    }

    /// A created session can be loaded back through a cookie jar.
    #[tokio::test]
    async fn create() {
        let cookie_context = CookieContext::builder()
            .store(MemStore::new())
            .build::<User>();

        let test_user = User { id: 1 };
        let test_user_id = test_user.id;

        let cookie = cookie_context.create_session(test_user).await.unwrap();

        let mut jar = CookieJar::new();
        jar.add(cookie);

        let session = cookie_context
            .load_from_jar(&jar)
            .await
            .unwrap()
            .expect("session should be loadable from the jar");

        assert_eq!(session.state.id, test_user_id);
    }

    /// Removing a session returns it once and makes later loads come back empty.
    #[tokio::test]
    async fn delete() {
        let cookie_context = CookieContext::builder()
            .store(MemStore::new())
            .build::<User>();

        let test_user = User { id: 1 };
        let test_user_id = test_user.id;

        let cookie = cookie_context.create_session(test_user).await.unwrap();

        let removed = cookie_context
            .remove_session_cookie(&cookie)
            .await
            .unwrap()
            .expect("removing an existing session should return it");

        assert_eq!(removed.state.id, test_user_id);

        let after = cookie_context.load_from_cookie(&cookie).await.unwrap();
        assert!(after.is_none(), "session should be gone after removal");
    }

    /// The cookie name follows the defaults, the dev switch, and custom cookie closures.
    #[tokio::test]
    async fn defaults() {
        // Regular cookie with its default name.
        let cookie = CookieContext::builder()
            .store(MemStore::new())
            .build::<()>()
            .create_session(())
            .await
            .unwrap();

        assert_eq!(cookie.name(), "session");

        // Dev cookie with its default name.
        let cookie = CookieContext::builder()
            .store(MemStore::new())
            .use_dev_cookie(true)
            .build::<()>()
            .create_session(())
            .await
            .unwrap();

        assert_eq!(cookie.name(), "dev-session");

        // Custom names, regular cookie selected.
        let cookie = CookieContext::builder()
            .store(MemStore::new())
            .cookie(|c| c.name("test"))
            .dev_cookie(|c| c.name("not-test"))
            .build::<()>()
            .create_session(())
            .await
            .unwrap();

        assert_eq!(cookie.name(), "test");

        // Custom names, dev cookie selected.
        let cookie = CookieContext::builder()
            .store(MemStore::new())
            .cookie(|c| c.name("not-test"))
            .dev_cookie(|c| c.name("test"))
            .use_dev_cookie(true)
            .build::<()>()
            .create_session(())
            .await
            .unwrap();

        assert_eq!(cookie.name(), "test");
    }
}