Skip to main content

axum_security/cookie/
builder.rs

1use std::{sync::Arc, time::Duration};
2
3use cookie_monster::{Cookie, CookieBuilder, SameSite};
4
5use crate::cookie::{
6    CookieContext, CookieContextInner, CookieStore, expiry::SessionExpiry, store::ErasedStore,
7};
8
/// Default name of the normal (non-dev) session cookie.
const DEFAULT_SESSION_COOKIE_NAME: &str = "session";
/// Default name of the dev-mode session cookie.
const DEFAULT_DEV_SESSION_COOKIE_NAME: &str = "dev-session";
11
/// Builder for a [`CookieContext`].
///
/// `S` is the session store type. Builders created by [`CookieSessionBuilder::new`]
/// start with `S = ()`; [`CookieSessionBuilder::store`] swaps in a real store.
pub struct CookieSessionBuilder<S> {
    /// Session store; `build` wraps it in an `ErasedStore`.
    store: S,
    /// When `true`, `build` uses `dev_cookie` as the cookie template instead of `cookie`.
    pub(crate) dev: bool,
    /// Cookie template used in dev mode.
    pub(crate) dev_cookie: CookieBuilder,
    /// Cookie template used when dev mode is off.
    pub(crate) cookie: CookieBuilder,
    /// Optional session expiry; with `None`, `build` spawns no maintenance task.
    pub(crate) expiry: Option<SessionExpiry>,
}
19
20impl CookieSessionBuilder<()> {
21    pub fn new() -> CookieSessionBuilder<()> {
22        Self {
23            store: (),
24            dev: false,
25            expiry: None,
26            // Make sure to use "/" as path so all paths can see the cookie in dev mode.
27            dev_cookie: Cookie::named(DEFAULT_DEV_SESSION_COOKIE_NAME)
28                .path("/")
29                .same_site(SameSite::Lax),
30            cookie: Cookie::named(DEFAULT_SESSION_COOKIE_NAME)
31                .same_site(SameSite::Strict)
32                .http_only()
33                .secure(),
34        }
35    }
36}
37
38impl<S> CookieSessionBuilder<S> {
39    pub fn cookie(mut self, f: impl FnOnce(CookieBuilder) -> CookieBuilder) -> Self {
40        self.cookie = f(Cookie::named(DEFAULT_SESSION_COOKIE_NAME));
41        self
42    }
43
44    pub fn dev_cookie(mut self, f: impl FnOnce(CookieBuilder) -> CookieBuilder) -> Self {
45        self.dev_cookie = f(Cookie::named(DEFAULT_DEV_SESSION_COOKIE_NAME));
46        self
47    }
48
49    pub fn use_dev_cookie(mut self, dev: bool) -> Self {
50        self.dev = dev;
51        self
52    }
53
54    pub fn use_normal_cookie(self, prod: bool) -> Self {
55        self.use_dev_cookie(!prod)
56    }
57
58    pub fn expires_max_age(mut self) -> Self {
59        self.expiry = Some(SessionExpiry::CookieMaxAge);
60        self
61    }
62
63    pub fn expires_after(mut self, session_duration: Duration) -> Self {
64        self.expiry = Some(SessionExpiry::Duration(session_duration));
65        self
66    }
67
68    pub fn expires_none(mut self) -> Self {
69        self.expiry = None;
70        self
71    }
72
73    pub fn store<S1>(self, store: S1) -> CookieSessionBuilder<S1> {
74        CookieSessionBuilder {
75            store,
76            dev: self.dev,
77            dev_cookie: self.dev_cookie,
78            cookie: self.cookie,
79            expiry: self.expiry,
80        }
81    }
82}
83
84impl<S> CookieSessionBuilder<S> {
85    pub fn build<T>(self) -> CookieContext<T>
86    where
87        T: Send + Sync + 'static,
88        S: CookieStore<State = T>,
89    {
90        let session_expiry = self.expiry.map(|e| match e {
91            SessionExpiry::CookieMaxAge => self.cookie.get_max_age().expect("No max-age set"),
92            SessionExpiry::Duration(duration) => duration,
93        });
94
95        let store = ErasedStore::new(self.store);
96
97        let handle = if let Some(expiry) = session_expiry
98            && store.spawn_maintenance_task()
99        {
100            let this = store.clone();
101            Some(tokio::spawn(super::expiry::maintenance_task(this, expiry)))
102        } else {
103            None
104        };
105
106        CookieContext(Arc::new(CookieContextInner {
107            store,
108            cookie_opts: if self.dev {
109                self.dev_cookie
110            } else {
111                self.cookie
112            },
113            handle,
114        }))
115    }
116}
117
118impl Default for CookieSessionBuilder<()> {
119    fn default() -> Self {
120        Self::new()
121    }
122}
123
#[cfg(test)]
mod cookie {
    use cookie_monster::CookieJar;

    use crate::cookie::{CookieContext, MemStore};

    /// Minimal session state used by the tests.
    #[derive(Clone)]
    struct User {
        id: i32,
    }

    /// A freshly created session can be loaded back through a cookie jar.
    #[tokio::test]
    async fn create() {
        let ctx = CookieContext::builder()
            .store(MemStore::new())
            .build::<User>();

        let expected_id = 1;
        let session_cookie = ctx.create_session(User { id: expected_id }).await.unwrap();

        let mut jar = CookieJar::new();
        jar.add(session_cookie);

        let loaded = ctx
            .load_from_jar(&jar)
            .await
            .unwrap()
            .expect("session should be present in the jar");
        assert_eq!(loaded.state.id, expected_id);
    }

    /// Removing a session returns its state and makes it unloadable afterwards.
    #[tokio::test]
    async fn delete() {
        let ctx = CookieContext::builder()
            .store(MemStore::new())
            .build::<User>();

        let expected_id = 1;
        let session_cookie = ctx.create_session(User { id: expected_id }).await.unwrap();

        let removed = ctx
            .remove_session_cookie(&session_cookie)
            .await
            .unwrap()
            .expect("removing a live session should return its state");
        assert_eq!(removed.state.id, expected_id);

        let after = ctx.load_from_cookie(&session_cookie).await.unwrap();
        assert!(after.is_none(), "session should be gone after removal");
    }

    /// Cookie names follow the defaults and the dev/normal selection.
    #[tokio::test]
    async fn defaults() {
        let normal = CookieContext::builder()
            .store(MemStore::new())
            .build::<()>()
            .create_session(())
            .await
            .unwrap();
        assert_eq!(normal.name(), "session");

        let dev = CookieContext::builder()
            .store(MemStore::new())
            .use_dev_cookie(true)
            .build::<()>()
            .create_session(())
            .await
            .unwrap();
        assert_eq!(dev.name(), "dev-session");

        let custom_normal = CookieContext::builder()
            .store(MemStore::new())
            .cookie(|c| c.name("test"))
            .dev_cookie(|c| c.name("not-test"))
            .build::<()>()
            .create_session(())
            .await
            .unwrap();
        assert_eq!(custom_normal.name(), "test");

        let custom_dev = CookieContext::builder()
            .store(MemStore::new())
            .cookie(|c| c.name("not-test"))
            .dev_cookie(|c| c.name("test"))
            .use_dev_cookie(true)
            .build::<()>()
            .create_session(())
            .await
            .unwrap();
        assert_eq!(custom_dev.name(), "test");
    }
}