Skip to main content

tower_sessions_ext/
service.rs

1//! A middleware that provides [`Session`] as a request extension.
2use std::{
3    borrow::Cow,
4    fmt,
5    future::Future,
6    pin::Pin,
7    sync::Arc,
8    task::{Context, Poll},
9};
10
11use http::{Request, Response};
12use time::OffsetDateTime;
13#[cfg(any(feature = "signed", feature = "private"))]
14use tower_cookies::Key;
15use tower_cookies::{Cookie, CookieManager, Cookies, cookie::SameSite};
16use tower_layer::Layer;
17use tower_service::Service;
18use tracing::Instrument;
19
20use crate::{
21    Session, SessionStore,
22    session::{self, Expiry},
23};
24
25#[doc(hidden)]
26pub trait CookieController: Clone + Send + 'static {
27    fn get(&self, cookies: &Cookies, name: &str) -> Option<Cookie<'static>>;
28    fn add(&self, cookies: &Cookies, cookie: Cookie<'static>);
29    fn remove(&self, cookies: &Cookies, cookie: Cookie<'static>);
30}
31
32#[doc(hidden)]
33#[derive(Debug, Clone)]
34pub struct PlaintextCookie;
35
36impl CookieController for PlaintextCookie {
37    fn get(&self, cookies: &Cookies, name: &str) -> Option<Cookie<'static>> {
38        cookies.get(name).map(Cookie::into_owned)
39    }
40
41    fn add(&self, cookies: &Cookies, cookie: Cookie<'static>) {
42        cookies.add(cookie)
43    }
44
45    fn remove(&self, cookies: &Cookies, cookie: Cookie<'static>) {
46        cookies.remove(cookie)
47    }
48}
49
/// [`CookieController`] that routes the session cookie through the signed jar view
/// ([`tower_cookies::SignedCookies`]), keyed by `key`.
#[doc(hidden)]
#[cfg(feature = "signed")]
#[derive(Debug, Clone)]
pub struct SignedCookie {
    key: Key,
}

#[cfg(feature = "signed")]
impl CookieController for SignedCookie {
    fn get(&self, cookies: &Cookies, name: &str) -> Option<Cookie<'static>> {
        let jar = cookies.signed(&self.key);
        jar.get(name).map(Cookie::into_owned)
    }

    fn add(&self, cookies: &Cookies, cookie: Cookie<'static>) {
        let jar = cookies.signed(&self.key);
        jar.add(cookie);
    }

    fn remove(&self, cookies: &Cookies, cookie: Cookie<'static>) {
        let jar = cookies.signed(&self.key);
        jar.remove(cookie);
    }
}
71
/// [`CookieController`] that routes the session cookie through the encrypted jar view
/// ([`tower_cookies::PrivateCookies`]), keyed by `key`.
#[doc(hidden)]
#[cfg(feature = "private")]
#[derive(Debug, Clone)]
pub struct PrivateCookie {
    key: Key,
}

#[cfg(feature = "private")]
impl CookieController for PrivateCookie {
    fn get(&self, cookies: &Cookies, name: &str) -> Option<Cookie<'static>> {
        let jar = cookies.private(&self.key);
        jar.get(name).map(Cookie::into_owned)
    }

    fn add(&self, cookies: &Cookies, cookie: Cookie<'static>) {
        let jar = cookies.private(&self.key);
        jar.add(cookie);
    }

    fn remove(&self, cookies: &Cookies, cookie: Cookie<'static>) {
        let jar = cookies.private(&self.key);
        jar.remove(cookie);
    }
}
93
94#[derive(Clone)]
95struct SessionConfig<'a> {
96    name: Cow<'a, str>,
97    http_only: bool,
98    same_site: SameSite,
99    expiry: Option<Expiry>,
100    secure: bool,
101    path: Cow<'a, str>,
102    domain: Option<Cow<'a, str>>,
103    always_save: bool,
104}
105
106impl fmt::Debug for SessionConfig<'_> {
107    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
108        f.debug_struct("SessionConfig")
109            .field("name", &self.name)
110            .field("http_only", &self.http_only)
111            .field("same_site", &self.same_site)
112            .field("expiry", &self.expiry)
113            .field("secure", &self.secure)
114            .field("path", &self.path)
115            .field("domain", &self.domain)
116            .field("always_save", &self.always_save)
117            .finish()
118    }
119}
120
121impl<'a> SessionConfig<'a> {
122    fn build_cookie(self, session_id: session::Id, expiry: Option<Expiry>) -> Cookie<'a> {
123        let mut cookie_builder = Cookie::build((self.name, session_id.to_string()))
124            .http_only(self.http_only)
125            .same_site(self.same_site)
126            .secure(self.secure)
127            .path(self.path);
128
129        cookie_builder = match expiry {
130            Some(Expiry::OnInactivity(duration)) => cookie_builder.max_age(duration),
131            Some(Expiry::AtDateTime(datetime)) => {
132                cookie_builder.max_age(datetime - OffsetDateTime::now_utc())
133            }
134            // Session cookie: no Max-Age so the browser treats it as ending when the session ends.
135            Some(Expiry::OnSessionEnd(_)) | None => cookie_builder,
136        };
137
138        if let Some(domain) = self.domain {
139            cookie_builder = cookie_builder.domain(domain);
140        }
141
142        cookie_builder.build()
143    }
144}
145
146impl Default for SessionConfig<'_> {
147    fn default() -> Self {
148        Self {
149            name: "id".into(), /* See: https://cheatsheetseries.owasp.org/cheatsheets/Session_Management_Cheat_Sheet.html#session-id-name-fingerprinting */
150            http_only: true,
151            same_site: SameSite::Strict,
152            expiry: None, // TODO: Is `Max-Age: "Session"` the right default?
153            secure: true,
154            path: "/".into(),
155            domain: None,
156            always_save: false,
157        }
158    }
159}
160
161/// A middleware that provides [`Session`] as a request extension.
162#[derive(Debug, Clone)]
163pub struct SessionManager<S, Store: SessionStore, C: CookieController = PlaintextCookie> {
164    inner: S,
165    session_store: Arc<Store>,
166    session_config: SessionConfig<'static>,
167    cookie_controller: C,
168}
169
170impl<S, Store: SessionStore> SessionManager<S, Store> {
171    /// Create a new [`SessionManager`].
172    pub fn new(inner: S, session_store: Store) -> Self {
173        Self {
174            inner,
175            session_store: Arc::new(session_store),
176            session_config: Default::default(),
177            cookie_controller: PlaintextCookie,
178        }
179    }
180}
181
182impl<ReqBody, ResBody, S, Store: SessionStore, C: CookieController> Service<Request<ReqBody>>
183    for SessionManager<S, Store, C>
184where
185    S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone + Send + 'static,
186    S::Future: Send,
187    ReqBody: Send + 'static,
188    ResBody: Default + Send,
189{
190    type Response = S::Response;
191    type Error = S::Error;
192    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
193
194    #[inline]
195    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
196        self.inner.poll_ready(cx)
197    }
198
199    fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
200        let span = tracing::debug_span!("call");
201
202        let session_store = self.session_store.clone();
203        let session_config = self.session_config.clone();
204        let cookie_controller = self.cookie_controller.clone();
205
206        // Because the inner service can panic until ready, we need to ensure we only
207        // use the ready service.
208        //
209        // See: https://docs.rs/tower/latest/tower/trait.Service.html#be-careful-when-cloning-inner-services
210        let clone = self.inner.clone();
211        let mut inner = std::mem::replace(&mut self.inner, clone);
212
213        Box::pin(
214            async move {
215                let Some(cookies) = req.extensions().get::<_>().cloned() else {
216                    // In practice this should never happen because we wrap `CookieManager`
217                    // directly.
218                    tracing::error!("missing cookies request extension");
219                    return Ok(Response::default());
220                };
221
222                let session_cookie = cookie_controller.get(&cookies, &session_config.name);
223                let session_id = session_cookie.as_ref().and_then(|cookie| {
224                    cookie
225                        .value()
226                        .parse::<session::Id>()
227                        .map_err(|err| {
228                            tracing::warn!(
229                                err = %err,
230                                "possibly suspicious activity: malformed session id"
231                            )
232                        })
233                        .ok()
234                });
235
236                let session = Session::new(session_id, session_store, session_config.expiry);
237
238                req.extensions_mut().insert(session.clone());
239
240                let res = inner.call(req).await?;
241
242                let modified = session.is_modified();
243                let empty = session.is_empty().await;
244
245                tracing::trace!(
246                    modified = modified,
247                    empty = empty,
248                    always_save = session_config.always_save,
249                    "session response state",
250                );
251
252                match session_cookie {
253                    Some(mut cookie) if empty => {
254                        tracing::debug!("removing session cookie");
255
256                        // Path and domain must be manually set to ensure a proper removal cookie is
257                        // constructed.
258                        //
259                        // See: https://docs.rs/cookie/latest/cookie/struct.CookieJar.html#method.remove
260                        cookie.set_path(session_config.path);
261                        if let Some(domain) = session_config.domain {
262                            cookie.set_domain(domain);
263                        }
264
265                        cookie_controller.remove(&cookies, cookie);
266                    }
267
268                    _ if (modified || session_config.always_save)
269                        && !empty
270                        && !res.status().is_server_error() =>
271                    {
272                        tracing::debug!("saving session");
273                        if let Err(err) = session.save().await {
274                            tracing::error!(err = %err, "failed to save session");
275
276                            let mut res = Response::default();
277                            *res.status_mut() = http::StatusCode::INTERNAL_SERVER_ERROR;
278                            return Ok(res);
279                        }
280
281                        let Some(session_id) = session.id() else {
282                            tracing::error!("missing session id");
283
284                            let mut res = Response::default();
285                            *res.status_mut() = http::StatusCode::INTERNAL_SERVER_ERROR;
286                            return Ok(res);
287                        };
288
289                        let expiry = session.expiry();
290                        let session_cookie = session_config.build_cookie(session_id, expiry);
291
292                        tracing::debug!("adding session cookie");
293                        cookie_controller.add(&cookies, session_cookie);
294                    }
295
296                    _ => (),
297                };
298
299                Ok(res)
300            }
301            .instrument(span),
302        )
303    }
304}
305
306/// A layer for providing [`Session`] as a request extension.
307#[derive(Debug, Clone)]
308pub struct SessionManagerLayer<Store: SessionStore, C: CookieController = PlaintextCookie> {
309    session_store: Arc<Store>,
310    session_config: SessionConfig<'static>,
311    cookie_controller: C,
312}
313
314impl<Store: SessionStore, C: CookieController> SessionManagerLayer<Store, C> {
315    /// Configures the name of the cookie used for the session.
316    /// The default value is `"id"`.
317    ///
318    /// # Examples
319    ///
320    /// ```rust
321    /// use tower_sessions_ext::{MemoryStore, SessionManagerLayer};
322    ///
323    /// let session_store = MemoryStore::default();
324    /// let session_service = SessionManagerLayer::new(session_store).with_name("my.sid");
325    /// ```
326    pub fn with_name<N: Into<Cow<'static, str>>>(mut self, name: N) -> Self {
327        self.session_config.name = name.into();
328        self
329    }
330
331    /// Configures the `"HttpOnly"` attribute of the cookie used for the
332    /// session.
333    ///
334    /// # ⚠️ **Warning: Cross-site scripting risk**
335    ///
336    /// Applications should generally **not** override the default value of
337    /// `true`. If you do, you are exposing your application to increased risk
338    /// of cookie theft via techniques like cross-site scripting.
339    ///
340    /// # Examples
341    ///
342    /// ```rust
343    /// use tower_sessions_ext::{MemoryStore, SessionManagerLayer};
344    ///
345    /// let session_store = MemoryStore::default();
346    /// let session_service = SessionManagerLayer::new(session_store).with_http_only(true);
347    /// ```
348    pub fn with_http_only(mut self, http_only: bool) -> Self {
349        self.session_config.http_only = http_only;
350        self
351    }
352
353    /// Configures the `"SameSite"` attribute of the cookie used for the
354    /// session.
355    /// The default value is [`SameSite::Strict`].
356    ///
357    /// # Examples
358    ///
359    /// ```rust
360    /// use tower_sessions_ext::{MemoryStore, SessionManagerLayer, cookie::SameSite};
361    ///
362    /// let session_store = MemoryStore::default();
363    /// let session_service = SessionManagerLayer::new(session_store).with_same_site(SameSite::Lax);
364    /// ```
365    pub fn with_same_site(mut self, same_site: SameSite) -> Self {
366        self.session_config.same_site = same_site;
367        self
368    }
369
370    /// Configures the `"Max-Age"` attribute of the cookie used for the session.
371    /// The default value is `None`.
372    ///
373    /// # Examples
374    ///
375    /// ```rust
376    /// use time::Duration;
377    /// use tower_sessions_ext::{Expiry, MemoryStore, SessionManagerLayer};
378    ///
379    /// let session_store = MemoryStore::default();
380    /// let session_expiry = Expiry::OnInactivity(Duration::hours(1));
381    /// let session_service = SessionManagerLayer::new(session_store).with_expiry(session_expiry);
382    /// ```
383    pub fn with_expiry(mut self, expiry: Expiry) -> Self {
384        self.session_config.expiry = Some(expiry);
385        self
386    }
387
388    /// Configures the `"Secure"` attribute of the cookie used for the session.
389    /// The default value is `true`.
390    ///
391    /// # Examples
392    ///
393    /// ```rust
394    /// use tower_sessions_ext::{MemoryStore, SessionManagerLayer};
395    ///
396    /// let session_store = MemoryStore::default();
397    /// let session_service = SessionManagerLayer::new(session_store).with_secure(true);
398    /// ```
399    pub fn with_secure(mut self, secure: bool) -> Self {
400        self.session_config.secure = secure;
401        self
402    }
403
404    /// Configures the `"Path"` attribute of the cookie used for the session.
405    /// The default value is `"/"`.
406    ///
407    /// # Examples
408    ///
409    /// ```rust
410    /// use tower_sessions_ext::{MemoryStore, SessionManagerLayer};
411    ///
412    /// let session_store = MemoryStore::default();
413    /// let session_service = SessionManagerLayer::new(session_store).with_path("/some/path");
414    /// ```
415    pub fn with_path<P: Into<Cow<'static, str>>>(mut self, path: P) -> Self {
416        self.session_config.path = path.into();
417        self
418    }
419
420    /// Configures the `"Domain"` attribute of the cookie used for the session.
421    /// The default value is `None`.
422    ///
423    /// # Examples
424    ///
425    /// ```rust
426    /// use tower_sessions_ext::{MemoryStore, SessionManagerLayer};
427    ///
428    /// let session_store = MemoryStore::default();
429    /// let session_service = SessionManagerLayer::new(session_store).with_domain("localhost");
430    /// ```
431    pub fn with_domain<D: Into<Cow<'static, str>>>(mut self, domain: D) -> Self {
432        self.session_config.domain = Some(domain.into());
433        self
434    }
435
436    /// Configures whether unmodified session should be saved on read or not.
437    /// When the value is `true`, the session will be saved even if it was not
438    /// changed.
439    ///
440    /// This is useful when you want to reset [`Session`] expiration time
441    /// on any valid request at the cost of higher [`SessionStore`] write
442    /// activity and transmitting `set-cookie` header with each response.
443    ///
444    /// It makes sense to use this setting with relative session expiration
445    /// values, such as `Expiry::OnInactivity(Duration)`. This setting will
446    /// _not_ cause session id to be cycled on save.
447    ///
448    /// The default value is `false`.
449    ///
450    /// # Examples
451    ///
452    /// ```rust
453    /// use time::Duration;
454    /// use tower_sessions_ext::{Expiry, MemoryStore, SessionManagerLayer};
455    ///
456    /// let session_store = MemoryStore::default();
457    /// let session_expiry = Expiry::OnInactivity(Duration::hours(1));
458    /// let session_service = SessionManagerLayer::new(session_store)
459    ///     .with_expiry(session_expiry)
460    ///     .with_always_save(true);
461    /// ```
462    pub fn with_always_save(mut self, always_save: bool) -> Self {
463        self.session_config.always_save = always_save;
464        self
465    }
466
467    /// Manages the session cookie via a signed interface.
468    ///
469    /// See [`SignedCookies`](tower_cookies::SignedCookies).
470    ///
471    /// ```rust
472    /// use tower_sessions_ext::{MemoryStore, SessionManagerLayer, cookie::Key};
473    ///
474    /// # /*
475    /// let key = { /* a cryptographically random key >= 64 bytes */ };
476    /// # */
477    /// # let key: &Vec<u8> = &(0..64).collect();
478    /// # let key: &[u8] = &key[..];
479    /// # let key = Key::try_from(key).unwrap();
480    ///
481    /// let session_store = MemoryStore::default();
482    /// let session_service = SessionManagerLayer::new(session_store).with_signed(key);
483    /// ```
484    #[cfg(feature = "signed")]
485    pub fn with_signed(self, key: Key) -> SessionManagerLayer<Store, SignedCookie> {
486        SessionManagerLayer::<Store, SignedCookie> {
487            session_store: self.session_store,
488            session_config: self.session_config,
489            cookie_controller: SignedCookie { key },
490        }
491    }
492
493    /// Manages the session cookie via an encrypted interface.
494    ///
495    /// See [`PrivateCookies`](tower_cookies::PrivateCookies).
496    ///
497    /// ```rust
498    /// use tower_sessions_ext::{MemoryStore, SessionManagerLayer, cookie::Key};
499    ///
500    /// # /*
501    /// let key = { /* a cryptographically random key >= 64 bytes */ };
502    /// # */
503    /// # let key: &Vec<u8> = &(0..64).collect();
504    /// # let key: &[u8] = &key[..];
505    /// # let key = Key::try_from(key).unwrap();
506    ///
507    /// let session_store = MemoryStore::default();
508    /// let session_service = SessionManagerLayer::new(session_store).with_private(key);
509    /// ```
510    #[cfg(feature = "private")]
511    pub fn with_private(self, key: Key) -> SessionManagerLayer<Store, PrivateCookie> {
512        SessionManagerLayer::<Store, PrivateCookie> {
513            session_store: self.session_store,
514            session_config: self.session_config,
515            cookie_controller: PrivateCookie { key },
516        }
517    }
518}
519
520impl<Store: SessionStore> SessionManagerLayer<Store> {
521    /// Create a new [`SessionManagerLayer`] with the provided session store
522    /// and default cookie configuration.
523    ///
524    /// # Examples
525    ///
526    /// ```rust
527    /// use tower_sessions_ext::{MemoryStore, SessionManagerLayer};
528    ///
529    /// let session_store = MemoryStore::default();
530    /// let session_service = SessionManagerLayer::new(session_store);
531    /// ```
532    pub fn new(session_store: Store) -> Self {
533        let session_config = SessionConfig::default();
534
535        Self {
536            session_store: Arc::new(session_store),
537            session_config,
538            cookie_controller: PlaintextCookie,
539        }
540    }
541}
542
543impl<S, Store: SessionStore, C: CookieController> Layer<S> for SessionManagerLayer<Store, C> {
544    type Service = CookieManager<SessionManager<S, Store, C>>;
545
546    fn layer(&self, inner: S) -> Self::Service {
547        let session_manager = SessionManager {
548            inner,
549            session_store: self.session_store.clone(),
550            session_config: self.session_config.clone(),
551            cookie_controller: self.cookie_controller.clone(),
552        };
553
554        CookieManager::new(session_manager)
555    }
556}
557
558#[cfg(test)]
559mod tests {
560    use std::str::FromStr;
561
562    use anyhow::anyhow;
563    use axum::body::Body;
564    use tower::{ServiceBuilder, ServiceExt};
565    use tower_sessions_ext_core::session::DEFAULT_DURATION;
566    use tower_sessions_ext_memory_store::MemoryStore;
567
568    use super::*;
569    use crate::session::{Id, Record};
570
571    async fn handler(req: Request<Body>) -> anyhow::Result<Response<Body>> {
572        let session = req
573            .extensions()
574            .get::<Session>()
575            .ok_or(anyhow!("Missing session"))?;
576
577        session.insert("foo", 42).await?;
578
579        Ok(Response::new(Body::empty()))
580    }
581
582    async fn noop_handler(_: Request<Body>) -> anyhow::Result<Response<Body>> {
583        Ok(Response::new(Body::empty()))
584    }
585
586    #[tokio::test]
587    async fn basic_service_test() -> anyhow::Result<()> {
588        let session_store = MemoryStore::default();
589        let session_layer = SessionManagerLayer::new(session_store);
590        let svc = ServiceBuilder::new()
591            .layer(session_layer)
592            .service_fn(handler);
593
594        let req = Request::builder().body(Body::empty())?;
595        let res = svc.clone().oneshot(req).await?;
596
597        let session = res.headers().get(http::header::SET_COOKIE);
598        assert!(session.is_some());
599
600        let req = Request::builder()
601            .header(http::header::COOKIE, session.unwrap())
602            .body(Body::empty())?;
603        let res = svc.oneshot(req).await?;
604
605        assert!(res.headers().get(http::header::SET_COOKIE).is_none());
606
607        Ok(())
608    }
609
610    #[tokio::test]
611    async fn bogus_cookie_test() -> anyhow::Result<()> {
612        let session_store = MemoryStore::default();
613        let session_layer = SessionManagerLayer::new(session_store);
614        let svc = ServiceBuilder::new()
615            .layer(session_layer)
616            .service_fn(handler);
617
618        let req = Request::builder().body(Body::empty())?;
619        let res = svc.clone().oneshot(req).await?;
620
621        assert!(res.headers().get(http::header::SET_COOKIE).is_some());
622
623        let req = Request::builder()
624            .header(http::header::COOKIE, "id=bogus")
625            .body(Body::empty())?;
626        let res = svc.oneshot(req).await?;
627
628        assert!(res.headers().get(http::header::SET_COOKIE).is_some());
629
630        Ok(())
631    }
632
633    #[tokio::test]
634    async fn no_set_cookie_test() -> anyhow::Result<()> {
635        let session_store = MemoryStore::default();
636        let session_layer = SessionManagerLayer::new(session_store);
637        let svc = ServiceBuilder::new()
638            .layer(session_layer)
639            .service_fn(noop_handler);
640
641        let req = Request::builder().body(Body::empty())?;
642        let res = svc.oneshot(req).await?;
643
644        assert!(res.headers().get(http::header::SET_COOKIE).is_none());
645
646        Ok(())
647    }
648
649    #[tokio::test]
650    async fn name_test() -> anyhow::Result<()> {
651        let session_store = MemoryStore::default();
652        let session_layer = SessionManagerLayer::new(session_store).with_name("my.sid");
653        let svc = ServiceBuilder::new()
654            .layer(session_layer)
655            .service_fn(handler);
656
657        let req = Request::builder().body(Body::empty())?;
658        let res = svc.oneshot(req).await?;
659
660        assert!(cookie_value_matches(&res, |s| s.starts_with("my.sid=")));
661
662        Ok(())
663    }
664
665    #[tokio::test]
666    async fn http_only_test() -> anyhow::Result<()> {
667        let session_store = MemoryStore::default();
668        let session_layer = SessionManagerLayer::new(session_store);
669        let svc = ServiceBuilder::new()
670            .layer(session_layer)
671            .service_fn(handler);
672
673        let req = Request::builder().body(Body::empty())?;
674        let res = svc.oneshot(req).await?;
675
676        assert!(cookie_value_matches(&res, |s| s.contains("HttpOnly")));
677
678        let session_store = MemoryStore::default();
679        let session_layer = SessionManagerLayer::new(session_store).with_http_only(false);
680        let svc = ServiceBuilder::new()
681            .layer(session_layer)
682            .service_fn(handler);
683
684        let req = Request::builder().body(Body::empty())?;
685        let res = svc.oneshot(req).await?;
686
687        assert!(cookie_value_matches(&res, |s| !s.contains("HttpOnly")));
688
689        Ok(())
690    }
691
692    #[tokio::test]
693    async fn same_site_strict_test() -> anyhow::Result<()> {
694        let session_store = MemoryStore::default();
695        let session_layer =
696            SessionManagerLayer::new(session_store).with_same_site(SameSite::Strict);
697        let svc = ServiceBuilder::new()
698            .layer(session_layer)
699            .service_fn(handler);
700
701        let req = Request::builder().body(Body::empty())?;
702        let res = svc.oneshot(req).await?;
703
704        assert!(cookie_value_matches(&res, |s| s.contains("SameSite=Strict")));
705
706        Ok(())
707    }
708
709    #[tokio::test]
710    async fn same_site_lax_test() -> anyhow::Result<()> {
711        let session_store = MemoryStore::default();
712        let session_layer = SessionManagerLayer::new(session_store).with_same_site(SameSite::Lax);
713        let svc = ServiceBuilder::new()
714            .layer(session_layer)
715            .service_fn(handler);
716
717        let req = Request::builder().body(Body::empty())?;
718        let res = svc.oneshot(req).await?;
719
720        assert!(cookie_value_matches(&res, |s| s.contains("SameSite=Lax")));
721
722        Ok(())
723    }
724
725    #[tokio::test]
726    async fn same_site_none_test() -> anyhow::Result<()> {
727        let session_store = MemoryStore::default();
728        let session_layer = SessionManagerLayer::new(session_store).with_same_site(SameSite::None);
729        let svc = ServiceBuilder::new()
730            .layer(session_layer)
731            .service_fn(handler);
732
733        let req = Request::builder().body(Body::empty())?;
734        let res = svc.oneshot(req).await?;
735
736        assert!(cookie_value_matches(&res, |s| s.contains("SameSite=None")));
737
738        Ok(())
739    }
740
741    #[tokio::test]
742    async fn expiry_on_session_end_test() -> anyhow::Result<()> {
743        let session_store = MemoryStore::default();
744        let session_layer = SessionManagerLayer::new(session_store)
745            .with_expiry(Expiry::OnSessionEnd(DEFAULT_DURATION));
746        let svc = ServiceBuilder::new()
747            .layer(session_layer)
748            .service_fn(handler);
749
750        let req = Request::builder().body(Body::empty())?;
751        let res = svc.oneshot(req).await?;
752
753        assert!(cookie_value_matches(&res, |s| !s.contains("Max-Age")));
754
755        Ok(())
756    }
757
758    #[tokio::test]
759    async fn expiry_on_inactivity_test() -> anyhow::Result<()> {
760        let session_store = MemoryStore::default();
761        let inactivity_duration = time::Duration::hours(2);
762        let session_layer = SessionManagerLayer::new(session_store)
763            .with_expiry(Expiry::OnInactivity(inactivity_duration));
764        let svc = ServiceBuilder::new()
765            .layer(session_layer)
766            .service_fn(handler);
767
768        let req = Request::builder().body(Body::empty())?;
769        let res = svc.oneshot(req).await?;
770
771        let expected_max_age = inactivity_duration.whole_seconds();
772        assert!(cookie_has_expected_max_age(&res, expected_max_age));
773
774        Ok(())
775    }
776
777    #[tokio::test]
778    async fn expiry_at_date_time_test() -> anyhow::Result<()> {
779        let session_store = MemoryStore::default();
780        let expiry_time = time::OffsetDateTime::now_utc() + time::Duration::weeks(1);
781        let session_layer =
782            SessionManagerLayer::new(session_store).with_expiry(Expiry::AtDateTime(expiry_time));
783        let svc = ServiceBuilder::new()
784            .layer(session_layer)
785            .service_fn(handler);
786
787        let req = Request::builder().body(Body::empty())?;
788        let res = svc.oneshot(req).await?;
789
790        let expected_max_age = (expiry_time - time::OffsetDateTime::now_utc()).whole_seconds();
791        assert!(cookie_has_expected_max_age(&res, expected_max_age));
792
793        Ok(())
794    }
795
796    #[tokio::test]
797    async fn expiry_on_session_end_always_save_test() -> anyhow::Result<()> {
798        let session_store = MemoryStore::default();
799        let session_layer = SessionManagerLayer::new(session_store.clone())
800            .with_expiry(Expiry::OnSessionEnd(DEFAULT_DURATION))
801            .with_always_save(true);
802        let mut svc = ServiceBuilder::new()
803            .layer(session_layer)
804            .service_fn(handler);
805
806        let req1 = Request::builder().body(Body::empty())?;
807        let res1 = svc.call(req1).await?;
808        let sid1 = get_session_id(&res1);
809        let rec1 = get_record(&session_store, &sid1).await;
810        let req2 = Request::builder()
811            .header(http::header::COOKIE, format!("id={}", sid1))
812            .body(Body::empty())?;
813        let res2 = svc.call(req2).await?;
814        let sid2 = get_session_id(&res2);
815        let rec2 = get_record(&session_store, &sid2).await;
816
817        assert!(cookie_value_matches(&res2, |s| !s.contains("Max-Age")));
818        assert!(sid1 == sid2);
819        assert!(rec1.expiry_date < rec2.expiry_date);
820
821        Ok(())
822    }
823
824    #[tokio::test]
825    async fn expiry_on_inactivity_always_save_test() -> anyhow::Result<()> {
826        let session_store = MemoryStore::default();
827        let inactivity_duration = time::Duration::hours(2);
828        let session_layer = SessionManagerLayer::new(session_store.clone())
829            .with_expiry(Expiry::OnInactivity(inactivity_duration))
830            .with_always_save(true);
831        let mut svc = ServiceBuilder::new()
832            .layer(session_layer)
833            .service_fn(handler);
834
835        let req1 = Request::builder().body(Body::empty())?;
836        let res1 = svc.call(req1).await?;
837        let sid1 = get_session_id(&res1);
838        let rec1 = get_record(&session_store, &sid1).await;
839        let req2 = Request::builder()
840            .header(http::header::COOKIE, format!("id={}", sid1))
841            .body(Body::empty())?;
842        let res2 = svc.call(req2).await?;
843        let sid2 = get_session_id(&res2);
844        let rec2 = get_record(&session_store, &sid2).await;
845
846        let expected_max_age = inactivity_duration.whole_seconds();
847        assert!(cookie_has_expected_max_age(&res2, expected_max_age));
848        assert!(sid1 == sid2);
849        assert!(rec1.expiry_date < rec2.expiry_date);
850
851        Ok(())
852    }
853
854    #[tokio::test]
855    async fn expiry_at_date_time_always_save_test() -> anyhow::Result<()> {
856        let session_store = MemoryStore::default();
857        let expiry_time = time::OffsetDateTime::now_utc() + time::Duration::weeks(1);
858        let session_layer = SessionManagerLayer::new(session_store.clone())
859            .with_expiry(Expiry::AtDateTime(expiry_time))
860            .with_always_save(true);
861        let mut svc = ServiceBuilder::new()
862            .layer(session_layer)
863            .service_fn(handler);
864
865        let req1 = Request::builder().body(Body::empty())?;
866        let res1 = svc.call(req1).await?;
867        let sid1 = get_session_id(&res1);
868        let rec1 = get_record(&session_store, &sid1).await;
869        let req2 = Request::builder()
870            .header(http::header::COOKIE, format!("id={}", sid1))
871            .body(Body::empty())?;
872        let res2 = svc.call(req2).await?;
873        let sid2 = get_session_id(&res2);
874        let rec2 = get_record(&session_store, &sid2).await;
875
876        let expected_max_age = (expiry_time - time::OffsetDateTime::now_utc()).whole_seconds();
877        assert!(cookie_has_expected_max_age(&res2, expected_max_age));
878        assert!(sid1 == sid2);
879        assert!(rec1.expiry_date == rec2.expiry_date);
880
881        Ok(())
882    }
883
884    #[tokio::test]
885    async fn secure_test() -> anyhow::Result<()> {
886        let session_store = MemoryStore::default();
887        let session_layer = SessionManagerLayer::new(session_store).with_secure(true);
888        let svc = ServiceBuilder::new()
889            .layer(session_layer)
890            .service_fn(handler);
891
892        let req = Request::builder().body(Body::empty())?;
893        let res = svc.oneshot(req).await?;
894
895        assert!(cookie_value_matches(&res, |s| s.contains("Secure")));
896
897        let session_store = MemoryStore::default();
898        let session_layer = SessionManagerLayer::new(session_store).with_secure(false);
899        let svc = ServiceBuilder::new()
900            .layer(session_layer)
901            .service_fn(handler);
902
903        let req = Request::builder().body(Body::empty())?;
904        let res = svc.oneshot(req).await?;
905
906        assert!(cookie_value_matches(&res, |s| !s.contains("Secure")));
907
908        Ok(())
909    }
910
911    #[tokio::test]
912    async fn path_test() -> anyhow::Result<()> {
913        let session_store = MemoryStore::default();
914        let session_layer = SessionManagerLayer::new(session_store).with_path("/foo/bar");
915        let svc = ServiceBuilder::new()
916            .layer(session_layer)
917            .service_fn(handler);
918
919        let req = Request::builder().body(Body::empty())?;
920        let res = svc.oneshot(req).await?;
921
922        assert!(cookie_value_matches(&res, |s| s.contains("Path=/foo/bar")));
923
924        Ok(())
925    }
926
927    #[tokio::test]
928    async fn domain_test() -> anyhow::Result<()> {
929        let session_store = MemoryStore::default();
930        let session_layer = SessionManagerLayer::new(session_store).with_domain("example.com");
931        let svc = ServiceBuilder::new()
932            .layer(session_layer)
933            .service_fn(handler);
934
935        let req = Request::builder().body(Body::empty())?;
936        let res = svc.oneshot(req).await?;
937
938        assert!(cookie_value_matches(&res, |s| s.contains("Domain=example.com")));
939
940        Ok(())
941    }
942
943    #[cfg(feature = "signed")]
944    #[tokio::test]
945    async fn signed_test() -> anyhow::Result<()> {
946        let key = Key::generate();
947        let session_store = MemoryStore::default();
948        let session_layer = SessionManagerLayer::new(session_store).with_signed(key);
949        let svc = ServiceBuilder::new()
950            .layer(session_layer)
951            .service_fn(handler);
952
953        let req = Request::builder().body(Body::empty())?;
954        let res = svc.oneshot(req).await?;
955
956        assert!(res.headers().get(http::header::SET_COOKIE).is_some());
957
958        Ok(())
959    }
960
961    #[cfg(feature = "private")]
962    #[tokio::test]
963    async fn private_test() -> anyhow::Result<()> {
964        let key = Key::generate();
965        let session_store = MemoryStore::default();
966        let session_layer = SessionManagerLayer::new(session_store).with_private(key);
967        let svc = ServiceBuilder::new()
968            .layer(session_layer)
969            .service_fn(handler);
970
971        let req = Request::builder().body(Body::empty())?;
972        let res = svc.oneshot(req).await?;
973
974        assert!(res.headers().get(http::header::SET_COOKIE).is_some());
975
976        Ok(())
977    }
978
979    fn cookie_value_matches<F>(res: &Response<Body>, matcher: F) -> bool
980    where
981        F: FnOnce(&str) -> bool,
982    {
983        res.headers()
984            .get(http::header::SET_COOKIE)
985            .is_some_and(|set_cookie| set_cookie.to_str().is_ok_and(matcher))
986    }
987
988    fn cookie_has_expected_max_age(res: &Response<Body>, expected_value: i64) -> bool {
989        res.headers()
990            .get(http::header::SET_COOKIE)
991            .is_some_and(|set_cookie| {
992                set_cookie.to_str().is_ok_and(|s| {
993                    let max_age_value = s
994                        .split("Max-Age=")
995                        .nth(1)
996                        .unwrap_or_default()
997                        .split(';')
998                        .next()
999                        .unwrap_or_default()
1000                        .parse::<i64>()
1001                        .unwrap_or_default();
1002                    (max_age_value - expected_value).abs() <= 1
1003                })
1004            })
1005    }
1006
1007    fn get_session_id(res: &Response<Body>) -> String {
1008        res.headers()
1009            .get(http::header::SET_COOKIE)
1010            .unwrap()
1011            .to_str()
1012            .unwrap()
1013            .split("id=")
1014            .nth(1)
1015            .unwrap()
1016            .split(";")
1017            .next()
1018            .unwrap()
1019            .to_string()
1020    }
1021
1022    async fn get_record(store: &impl SessionStore, id: &str) -> Record {
1023        store
1024            .load(&Id::from_str(id).unwrap())
1025            .await
1026            .unwrap()
1027            .unwrap()
1028    }
1029}