1use std::{
3 borrow::Cow,
4 fmt,
5 future::Future,
6 pin::Pin,
7 sync::Arc,
8 task::{Context, Poll},
9};
10
11use http::{Request, Response};
12use time::OffsetDateTime;
13#[cfg(any(feature = "signed", feature = "private"))]
14use tower_cookies::Key;
15use tower_cookies::{Cookie, CookieManager, Cookies, cookie::SameSite};
16use tower_layer::Layer;
17use tower_service::Service;
18use tracing::Instrument;
19
20use crate::{
21 Session, SessionStore,
22 session::{self, Expiry, OnExpireCallback},
23};
24
/// Strategy for reading, writing and removing the session cookie in a request's
/// [`Cookies`] jar. This file implements it as plaintext, signed and private variants.
#[doc(hidden)]
pub trait CookieController: Clone + Send + 'static {
    /// Returns an owned copy of the cookie called `name`, if the jar yields one.
    fn get(&self, cookies: &Cookies, name: &str) -> Option<Cookie<'static>>;
    /// Adds `cookie` to the jar.
    fn add(&self, cookies: &Cookies, cookie: Cookie<'static>);
    /// Removes `cookie` from the jar.
    fn remove(&self, cookies: &Cookies, cookie: Cookie<'static>);
}
31
32#[doc(hidden)]
33#[derive(Debug, Clone)]
34pub struct PlaintextCookie;
35
36impl CookieController for PlaintextCookie {
37 fn get(&self, cookies: &Cookies, name: &str) -> Option<Cookie<'static>> {
38 cookies.get(name).map(Cookie::into_owned)
39 }
40
41 fn add(&self, cookies: &Cookies, cookie: Cookie<'static>) {
42 cookies.add(cookie)
43 }
44
45 fn remove(&self, cookies: &Cookies, cookie: Cookie<'static>) {
46 cookies.remove(cookie)
47 }
48}
49
/// Cookie controller that goes through the jar's signed view
/// (`tower_cookies` `SignedCookies`) using `key`.
#[doc(hidden)]
#[cfg(feature = "signed")]
#[derive(Debug, Clone)]
pub struct SignedCookie {
    /// Key used for signing and verifying the session cookie.
    key: Key,
}

#[cfg(feature = "signed")]
impl CookieController for SignedCookie {
    fn get(&self, cookies: &Cookies, name: &str) -> Option<Cookie<'static>> {
        let jar = cookies.signed(&self.key);
        jar.get(name).map(|found| found.into_owned())
    }

    fn add(&self, cookies: &Cookies, cookie: Cookie<'static>) {
        let jar = cookies.signed(&self.key);
        jar.add(cookie);
    }

    fn remove(&self, cookies: &Cookies, cookie: Cookie<'static>) {
        let jar = cookies.signed(&self.key);
        jar.remove(cookie);
    }
}
71
/// Cookie controller that goes through the jar's private view
/// (`tower_cookies` `PrivateCookies`) using `key`.
#[doc(hidden)]
#[cfg(feature = "private")]
#[derive(Debug, Clone)]
pub struct PrivateCookie {
    /// Key used for encrypting and decrypting the session cookie.
    key: Key,
}

#[cfg(feature = "private")]
impl CookieController for PrivateCookie {
    fn get(&self, cookies: &Cookies, name: &str) -> Option<Cookie<'static>> {
        let jar = cookies.private(&self.key);
        jar.get(name).map(|found| found.into_owned())
    }

    fn add(&self, cookies: &Cookies, cookie: Cookie<'static>) {
        let jar = cookies.private(&self.key);
        jar.add(cookie);
    }

    fn remove(&self, cookies: &Cookies, cookie: Cookie<'static>) {
        let jar = cookies.private(&self.key);
        jar.remove(cookie);
    }
}
93
94#[derive(Clone)]
95struct SessionConfig<'a> {
96 name: Cow<'a, str>,
97 http_only: bool,
98 same_site: SameSite,
99 expiry: Option<Expiry>,
100 secure: bool,
101 path: Cow<'a, str>,
102 domain: Option<Cow<'a, str>>,
103 always_save: bool,
104 on_expire: Option<OnExpireCallback>,
105}
106
107impl fmt::Debug for SessionConfig<'_> {
108 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
109 f.debug_struct("SessionConfig")
110 .field("name", &self.name)
111 .field("http_only", &self.http_only)
112 .field("same_site", &self.same_site)
113 .field("expiry", &self.expiry)
114 .field("secure", &self.secure)
115 .field("path", &self.path)
116 .field("domain", &self.domain)
117 .field("always_save", &self.always_save)
118 .field("on_expire", &self.on_expire.as_ref().map(|_| "Some(_)"))
119 .finish()
120 }
121}
122
123impl<'a> SessionConfig<'a> {
124 fn build_cookie(self, session_id: session::Id, expiry: Option<Expiry>) -> Cookie<'a> {
125 let mut cookie_builder = Cookie::build((self.name, session_id.to_string()))
126 .http_only(self.http_only)
127 .same_site(self.same_site)
128 .secure(self.secure)
129 .path(self.path);
130
131 cookie_builder = match expiry {
132 Some(Expiry::OnInactivity(duration)) => cookie_builder.max_age(duration),
133 Some(Expiry::AtDateTime(datetime)) => {
134 cookie_builder.max_age(datetime - OffsetDateTime::now_utc())
135 }
136 Some(Expiry::OnSessionEnd(_)) | None => cookie_builder,
138 };
139
140 if let Some(domain) = self.domain {
141 cookie_builder = cookie_builder.domain(domain);
142 }
143
144 cookie_builder.build()
145 }
146}
147
148impl Default for SessionConfig<'_> {
149 fn default() -> Self {
150 Self {
151 name: "id".into(), http_only: true,
153 same_site: SameSite::Strict,
154 expiry: None, secure: true,
156 path: "/".into(),
157 domain: None,
158 always_save: false,
159 on_expire: None,
160 }
161 }
162}
163
/// Tower service that attaches a [`Session`] to each request and, after the inner
/// service responds, persists it and updates the session cookie.
#[derive(Debug, Clone)]
pub struct SessionManager<S, Store: SessionStore, C: CookieController = PlaintextCookie> {
    /// The wrapped service.
    inner: S,
    /// Session store, shared through an `Arc` between clones of this service.
    session_store: Arc<Store>,
    /// Cookie and session settings applied to every request.
    session_config: SessionConfig<'static>,
    /// How the session cookie is read from and written to the cookie jar.
    cookie_controller: C,
}
172
173impl<S, Store: SessionStore> SessionManager<S, Store> {
174 pub fn new(inner: S, session_store: Store) -> Self {
176 Self {
177 inner,
178 session_store: Arc::new(session_store),
179 session_config: Default::default(),
180 cookie_controller: PlaintextCookie,
181 }
182 }
183}
184
// Per-request flow: load the session named by the request's cookie, expose it to the
// inner service through the request extensions, then save it and update the cookie
// depending on what the handler did.
impl<ReqBody, ResBody, S, Store: SessionStore, C: CookieController> Service<Request<ReqBody>>
    for SessionManager<S, Store, C>
where
    S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone + Send + 'static,
    S::Future: Send,
    ReqBody: Send + 'static,
    // `Default` is needed to build the bare responses returned on internal failures.
    ResBody: Default + Send,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    /// Readiness is delegated entirely to the inner service.
    #[inline]
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Runs one request through the session lifecycle.
    ///
    /// Normally returns the inner service's response; errors from the inner service are
    /// propagated unchanged. Instead of the inner response, it returns:
    /// - an empty default response when the cookie jar extension is missing;
    /// - an empty `500 Internal Server Error` response when saving the session fails or
    ///   the saved session has no id.
    fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
        let span = tracing::debug_span!("call");

        // Owned copies so the returned future does not borrow `self`.
        let session_store = self.session_store.clone();
        let session_config = self.session_config.clone();
        let cookie_controller = self.cookie_controller.clone();

        // Move the instance that `poll_ready` drove to readiness into the future and leave
        // a fresh clone behind; a clone is not guaranteed to be ready itself.
        //
        // See: https://docs.rs/tower/latest/tower/trait.Service.html#be-careful-when-cloning-inner-services
        let clone = self.inner.clone();
        let mut inner = std::mem::replace(&mut self.inner, clone);

        Box::pin(
            async move {
                // The extension type (`Cookies`) is inferred from its use with
                // `cookie_controller` below. `SessionManagerLayer` wraps this service in
                // `CookieManager`, which is expected to provide it.
                let Some(cookies) = req.extensions().get::<_>().cloned() else {
                    // NOTE(review): this yields an empty default (200 OK) response rather
                    // than an error — confirm that is intended.
                    tracing::error!("missing cookies request extension");
                    return Ok(Response::default());
                };

                // A cookie whose value is not a valid session id is logged and then
                // treated exactly like a missing cookie (`session_id` is `None`).
                let session_cookie = cookie_controller.get(&cookies, &session_config.name);
                let session_id = session_cookie.as_ref().and_then(|cookie| {
                    cookie
                        .value()
                        .parse::<session::Id>()
                        .map_err(|err| {
                            tracing::warn!(
                                err = %err,
                                "possibly suspicious activity: malformed session id"
                            )
                        })
                        .ok()
                });

                let session = Session::new(
                    session_id,
                    session_store,
                    session_config.expiry,
                    session_config.on_expire.clone(),
                );

                // The handler receives a clone; this copy is inspected once the inner
                // service has produced its response.
                req.extensions_mut().insert(session.clone());

                let res = inner.call(req).await?;

                let modified = session.is_modified();
                let empty = session.is_empty().await;

                tracing::trace!(
                    modified = modified,
                    empty = empty,
                    always_save = session_config.always_save,
                    "session response state",
                );

                match session_cookie {
                    // The client sent a cookie but the session is now empty: remove the
                    // cookie from the jar.
                    Some(mut cookie) if empty => {
                        tracing::debug!("removing session cookie");

                        // Path and domain must be set explicitly so a proper removal
                        // cookie is constructed.
                        //
                        // See: https://docs.rs/cookie/latest/cookie/struct.CookieJar.html#method.remove
                        cookie.set_path(session_config.path);
                        if let Some(domain) = session_config.domain {
                            cookie.set_domain(domain);
                        }

                        cookie_controller.remove(&cookies, cookie);
                    }

                    // Save non-empty sessions that changed (or always, when configured)
                    // unless the handler answered with a 5xx, then (re)issue the cookie.
                    _ if (modified || session_config.always_save)
                        && !empty
                        && !res.status().is_server_error() =>
                    {
                        tracing::debug!("saving session");
                        if let Err(err) = session.save().await {
                            tracing::error!(err = %err, "failed to save session");

                            let mut res = Response::default();
                            *res.status_mut() = http::StatusCode::INTERNAL_SERVER_ERROR;
                            return Ok(res);
                        }

                        let Some(session_id) = session.id() else {
                            tracing::error!("missing session id");

                            let mut res = Response::default();
                            *res.status_mut() = http::StatusCode::INTERNAL_SERVER_ERROR;
                            return Ok(res);
                        };

                        let expiry = session.expiry();
                        let session_cookie = session_config.build_cookie(session_id, expiry);

                        tracing::debug!("adding session cookie");
                        cookie_controller.add(&cookies, session_cookie);
                    }

                    // Otherwise the cookie is left untouched.
                    _ => (),
                };

                Ok(res)
            }
            .instrument(span),
        )
    }
}
313
/// Tower layer that wraps services in a [`SessionManager`] (inside a [`CookieManager`]).
/// It is configured through its `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct SessionManagerLayer<Store: SessionStore, C: CookieController = PlaintextCookie> {
    /// Session store shared with every service this layer produces.
    session_store: Arc<Store>,
    /// Cookie and session settings copied into every service this layer produces.
    session_config: SessionConfig<'static>,
    /// How the session cookie is read from and written to the cookie jar.
    cookie_controller: C,
}
321
322impl<Store: SessionStore, C: CookieController> SessionManagerLayer<Store, C> {
323 pub fn with_name<N: Into<Cow<'static, str>>>(mut self, name: N) -> Self {
335 self.session_config.name = name.into();
336 self
337 }
338
339 pub fn with_http_only(mut self, http_only: bool) -> Self {
357 self.session_config.http_only = http_only;
358 self
359 }
360
361 pub fn with_same_site(mut self, same_site: SameSite) -> Self {
374 self.session_config.same_site = same_site;
375 self
376 }
377
378 pub fn with_expiry(mut self, expiry: Expiry) -> Self {
392 self.session_config.expiry = Some(expiry);
393 self
394 }
395
396 pub fn with_secure(mut self, secure: bool) -> Self {
408 self.session_config.secure = secure;
409 self
410 }
411
412 pub fn with_path<P: Into<Cow<'static, str>>>(mut self, path: P) -> Self {
424 self.session_config.path = path.into();
425 self
426 }
427
428 pub fn with_domain<D: Into<Cow<'static, str>>>(mut self, domain: D) -> Self {
440 self.session_config.domain = Some(domain.into());
441 self
442 }
443
444 pub fn with_always_save(mut self, always_save: bool) -> Self {
471 self.session_config.always_save = always_save;
472 self
473 }
474
475 pub fn with_on_expire<F>(mut self, f: F) -> Self
498 where
499 F: Fn(session::Id) + Send + Sync + 'static,
500 {
501 self.session_config.on_expire = Some(Arc::new(f));
502 self
503 }
504
505 #[cfg(feature = "signed")]
523 pub fn with_signed(self, key: Key) -> SessionManagerLayer<Store, SignedCookie> {
524 SessionManagerLayer::<Store, SignedCookie> {
525 session_store: self.session_store,
526 session_config: self.session_config,
527 cookie_controller: SignedCookie { key },
528 }
529 }
530
531 #[cfg(feature = "private")]
549 pub fn with_private(self, key: Key) -> SessionManagerLayer<Store, PrivateCookie> {
550 SessionManagerLayer::<Store, PrivateCookie> {
551 session_store: self.session_store,
552 session_config: self.session_config,
553 cookie_controller: PrivateCookie { key },
554 }
555 }
556}
557
558impl<Store: SessionStore> SessionManagerLayer<Store> {
559 pub fn new(session_store: Store) -> Self {
571 let session_config = SessionConfig::default();
572
573 Self {
574 session_store: Arc::new(session_store),
575 session_config,
576 cookie_controller: PlaintextCookie,
577 }
578 }
579}
580
581impl<S, Store: SessionStore, C: CookieController> Layer<S> for SessionManagerLayer<Store, C> {
582 type Service = CookieManager<SessionManager<S, Store, C>>;
583
584 fn layer(&self, inner: S) -> Self::Service {
585 let session_manager = SessionManager {
586 inner,
587 session_store: self.session_store.clone(),
588 session_config: self.session_config.clone(),
589 cookie_controller: self.cookie_controller.clone(),
590 };
591
592 CookieManager::new(session_manager)
593 }
594}
595
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use anyhow::anyhow;
    use axum::body::Body;
    use tower::{ServiceBuilder, ServiceExt};
    use tower_sessions_ext_core::session::DEFAULT_DURATION;
    use tower_sessions_ext_memory_store::MemoryStore;

    use super::*;
    use crate::session::{Id, Record};

    /// Test handler that writes a value into the session, so the session is modified.
    async fn handler(req: Request<Body>) -> anyhow::Result<Response<Body>> {
        let session = req
            .extensions()
            .get::<Session>()
            // Build the error lazily; it is only needed when the session is missing.
            .ok_or_else(|| anyhow!("Missing session"))?;

        session.insert("foo", 42).await?;

        Ok(Response::new(Body::empty()))
    }

    /// Test handler that never touches the session.
    async fn noop_handler(_: Request<Body>) -> anyhow::Result<Response<Body>> {
        Ok(Response::new(Body::empty()))
    }

    #[tokio::test]
    async fn basic_service_test() -> anyhow::Result<()> {
        let session_store = MemoryStore::default();
        let session_layer = SessionManagerLayer::new(session_store);
        let svc = ServiceBuilder::new()
            .layer(session_layer)
            .service_fn(handler);

        let req = Request::builder().body(Body::empty())?;
        let res = svc.clone().oneshot(req).await?;

        let session = res
            .headers()
            .get(http::header::SET_COOKIE)
            .expect("first response should set the session cookie");

        let req = Request::builder()
            .header(http::header::COOKIE, session)
            .body(Body::empty())?;
        let res = svc.oneshot(req).await?;

        assert!(res.headers().get(http::header::SET_COOKIE).is_none());

        Ok(())
    }

    #[tokio::test]
    async fn bogus_cookie_test() -> anyhow::Result<()> {
        let session_store = MemoryStore::default();
        let session_layer = SessionManagerLayer::new(session_store);
        let svc = ServiceBuilder::new()
            .layer(session_layer)
            .service_fn(handler);

        let req = Request::builder().body(Body::empty())?;
        let res = svc.clone().oneshot(req).await?;

        assert!(res.headers().get(http::header::SET_COOKIE).is_some());

        let req = Request::builder()
            .header(http::header::COOKIE, "id=bogus")
            .body(Body::empty())?;
        let res = svc.oneshot(req).await?;

        assert!(res.headers().get(http::header::SET_COOKIE).is_some());

        Ok(())
    }

    #[tokio::test]
    async fn no_set_cookie_test() -> anyhow::Result<()> {
        let session_store = MemoryStore::default();
        let session_layer = SessionManagerLayer::new(session_store);
        let svc = ServiceBuilder::new()
            .layer(session_layer)
            .service_fn(noop_handler);

        let req = Request::builder().body(Body::empty())?;
        let res = svc.oneshot(req).await?;

        assert!(res.headers().get(http::header::SET_COOKIE).is_none());

        Ok(())
    }

    #[tokio::test]
    async fn name_test() -> anyhow::Result<()> {
        let session_store = MemoryStore::default();
        let session_layer = SessionManagerLayer::new(session_store).with_name("my.sid");
        let svc = ServiceBuilder::new()
            .layer(session_layer)
            .service_fn(handler);

        let req = Request::builder().body(Body::empty())?;
        let res = svc.oneshot(req).await?;

        assert!(cookie_value_matches(&res, |s| s.starts_with("my.sid=")));

        Ok(())
    }

    #[tokio::test]
    async fn http_only_test() -> anyhow::Result<()> {
        let session_store = MemoryStore::default();
        let session_layer = SessionManagerLayer::new(session_store);
        let svc = ServiceBuilder::new()
            .layer(session_layer)
            .service_fn(handler);

        let req = Request::builder().body(Body::empty())?;
        let res = svc.oneshot(req).await?;

        assert!(cookie_value_matches(&res, |s| s.contains("HttpOnly")));

        let session_store = MemoryStore::default();
        let session_layer = SessionManagerLayer::new(session_store).with_http_only(false);
        let svc = ServiceBuilder::new()
            .layer(session_layer)
            .service_fn(handler);

        let req = Request::builder().body(Body::empty())?;
        let res = svc.oneshot(req).await?;

        assert!(cookie_value_matches(&res, |s| !s.contains("HttpOnly")));

        Ok(())
    }

    #[tokio::test]
    async fn same_site_strict_test() -> anyhow::Result<()> {
        let session_store = MemoryStore::default();
        let session_layer =
            SessionManagerLayer::new(session_store).with_same_site(SameSite::Strict);
        let svc = ServiceBuilder::new()
            .layer(session_layer)
            .service_fn(handler);

        let req = Request::builder().body(Body::empty())?;
        let res = svc.oneshot(req).await?;

        assert!(cookie_value_matches(&res, |s| s.contains("SameSite=Strict")));

        Ok(())
    }

    #[tokio::test]
    async fn same_site_lax_test() -> anyhow::Result<()> {
        let session_store = MemoryStore::default();
        let session_layer = SessionManagerLayer::new(session_store).with_same_site(SameSite::Lax);
        let svc = ServiceBuilder::new()
            .layer(session_layer)
            .service_fn(handler);

        let req = Request::builder().body(Body::empty())?;
        let res = svc.oneshot(req).await?;

        assert!(cookie_value_matches(&res, |s| s.contains("SameSite=Lax")));

        Ok(())
    }

    #[tokio::test]
    async fn same_site_none_test() -> anyhow::Result<()> {
        let session_store = MemoryStore::default();
        let session_layer = SessionManagerLayer::new(session_store).with_same_site(SameSite::None);
        let svc = ServiceBuilder::new()
            .layer(session_layer)
            .service_fn(handler);

        let req = Request::builder().body(Body::empty())?;
        let res = svc.oneshot(req).await?;

        assert!(cookie_value_matches(&res, |s| s.contains("SameSite=None")));

        Ok(())
    }

    #[tokio::test]
    async fn expiry_on_session_end_test() -> anyhow::Result<()> {
        let session_store = MemoryStore::default();
        let session_layer = SessionManagerLayer::new(session_store)
            .with_expiry(Expiry::OnSessionEnd(DEFAULT_DURATION));
        let svc = ServiceBuilder::new()
            .layer(session_layer)
            .service_fn(handler);

        let req = Request::builder().body(Body::empty())?;
        let res = svc.oneshot(req).await?;

        assert!(cookie_value_matches(&res, |s| !s.contains("Max-Age")));

        Ok(())
    }

    #[tokio::test]
    async fn expiry_on_inactivity_test() -> anyhow::Result<()> {
        let session_store = MemoryStore::default();
        let inactivity_duration = time::Duration::hours(2);
        let session_layer = SessionManagerLayer::new(session_store)
            .with_expiry(Expiry::OnInactivity(inactivity_duration));
        let svc = ServiceBuilder::new()
            .layer(session_layer)
            .service_fn(handler);

        let req = Request::builder().body(Body::empty())?;
        let res = svc.oneshot(req).await?;

        let expected_max_age = inactivity_duration.whole_seconds();
        assert!(cookie_has_expected_max_age(&res, expected_max_age));

        Ok(())
    }

    #[tokio::test]
    async fn expiry_at_date_time_test() -> anyhow::Result<()> {
        let session_store = MemoryStore::default();
        let expiry_time = time::OffsetDateTime::now_utc() + time::Duration::weeks(1);
        let session_layer =
            SessionManagerLayer::new(session_store).with_expiry(Expiry::AtDateTime(expiry_time));
        let svc = ServiceBuilder::new()
            .layer(session_layer)
            .service_fn(handler);

        let req = Request::builder().body(Body::empty())?;
        let res = svc.oneshot(req).await?;

        let expected_max_age = (expiry_time - time::OffsetDateTime::now_utc()).whole_seconds();
        assert!(cookie_has_expected_max_age(&res, expected_max_age));

        Ok(())
    }

    #[tokio::test]
    async fn expiry_on_session_end_always_save_test() -> anyhow::Result<()> {
        let session_store = MemoryStore::default();
        let session_layer = SessionManagerLayer::new(session_store.clone())
            .with_expiry(Expiry::OnSessionEnd(DEFAULT_DURATION))
            .with_always_save(true);
        let mut svc = ServiceBuilder::new()
            .layer(session_layer)
            .service_fn(handler);

        let req1 = Request::builder().body(Body::empty())?;
        let res1 = svc.call(req1).await?;
        let sid1 = get_session_id(&res1);
        let rec1 = get_record(&session_store, &sid1).await;
        let req2 = Request::builder()
            .header(http::header::COOKIE, format!("id={}", sid1))
            .body(Body::empty())?;
        let res2 = svc.call(req2).await?;
        let sid2 = get_session_id(&res2);
        let rec2 = get_record(&session_store, &sid2).await;

        assert!(cookie_value_matches(&res2, |s| !s.contains("Max-Age")));
        assert_eq!(sid1, sid2);
        assert!(rec1.expiry_date < rec2.expiry_date);

        Ok(())
    }

    #[tokio::test]
    async fn expiry_on_inactivity_always_save_test() -> anyhow::Result<()> {
        let session_store = MemoryStore::default();
        let inactivity_duration = time::Duration::hours(2);
        let session_layer = SessionManagerLayer::new(session_store.clone())
            .with_expiry(Expiry::OnInactivity(inactivity_duration))
            .with_always_save(true);
        let mut svc = ServiceBuilder::new()
            .layer(session_layer)
            .service_fn(handler);

        let req1 = Request::builder().body(Body::empty())?;
        let res1 = svc.call(req1).await?;
        let sid1 = get_session_id(&res1);
        let rec1 = get_record(&session_store, &sid1).await;
        let req2 = Request::builder()
            .header(http::header::COOKIE, format!("id={}", sid1))
            .body(Body::empty())?;
        let res2 = svc.call(req2).await?;
        let sid2 = get_session_id(&res2);
        let rec2 = get_record(&session_store, &sid2).await;

        let expected_max_age = inactivity_duration.whole_seconds();
        assert!(cookie_has_expected_max_age(&res2, expected_max_age));
        assert_eq!(sid1, sid2);
        assert!(rec1.expiry_date < rec2.expiry_date);

        Ok(())
    }

    #[tokio::test]
    async fn expiry_at_date_time_always_save_test() -> anyhow::Result<()> {
        let session_store = MemoryStore::default();
        let expiry_time = time::OffsetDateTime::now_utc() + time::Duration::weeks(1);
        let session_layer = SessionManagerLayer::new(session_store.clone())
            .with_expiry(Expiry::AtDateTime(expiry_time))
            .with_always_save(true);
        let mut svc = ServiceBuilder::new()
            .layer(session_layer)
            .service_fn(handler);

        let req1 = Request::builder().body(Body::empty())?;
        let res1 = svc.call(req1).await?;
        let sid1 = get_session_id(&res1);
        let rec1 = get_record(&session_store, &sid1).await;
        let req2 = Request::builder()
            .header(http::header::COOKIE, format!("id={}", sid1))
            .body(Body::empty())?;
        let res2 = svc.call(req2).await?;
        let sid2 = get_session_id(&res2);
        let rec2 = get_record(&session_store, &sid2).await;

        let expected_max_age = (expiry_time - time::OffsetDateTime::now_utc()).whole_seconds();
        assert!(cookie_has_expected_max_age(&res2, expected_max_age));
        assert_eq!(sid1, sid2);
        assert_eq!(rec1.expiry_date, rec2.expiry_date);

        Ok(())
    }

    #[tokio::test]
    async fn secure_test() -> anyhow::Result<()> {
        let session_store = MemoryStore::default();
        let session_layer = SessionManagerLayer::new(session_store).with_secure(true);
        let svc = ServiceBuilder::new()
            .layer(session_layer)
            .service_fn(handler);

        let req = Request::builder().body(Body::empty())?;
        let res = svc.oneshot(req).await?;

        assert!(cookie_value_matches(&res, |s| s.contains("Secure")));

        let session_store = MemoryStore::default();
        let session_layer = SessionManagerLayer::new(session_store).with_secure(false);
        let svc = ServiceBuilder::new()
            .layer(session_layer)
            .service_fn(handler);

        let req = Request::builder().body(Body::empty())?;
        let res = svc.oneshot(req).await?;

        assert!(cookie_value_matches(&res, |s| !s.contains("Secure")));

        Ok(())
    }

    #[tokio::test]
    async fn path_test() -> anyhow::Result<()> {
        let session_store = MemoryStore::default();
        let session_layer = SessionManagerLayer::new(session_store).with_path("/foo/bar");
        let svc = ServiceBuilder::new()
            .layer(session_layer)
            .service_fn(handler);

        let req = Request::builder().body(Body::empty())?;
        let res = svc.oneshot(req).await?;

        assert!(cookie_value_matches(&res, |s| s.contains("Path=/foo/bar")));

        Ok(())
    }

    #[tokio::test]
    async fn domain_test() -> anyhow::Result<()> {
        let session_store = MemoryStore::default();
        let session_layer = SessionManagerLayer::new(session_store).with_domain("example.com");
        let svc = ServiceBuilder::new()
            .layer(session_layer)
            .service_fn(handler);

        let req = Request::builder().body(Body::empty())?;
        let res = svc.oneshot(req).await?;

        assert!(cookie_value_matches(&res, |s| s.contains("Domain=example.com")));

        Ok(())
    }

    #[cfg(feature = "signed")]
    #[tokio::test]
    async fn signed_test() -> anyhow::Result<()> {
        let key = Key::generate();
        let session_store = MemoryStore::default();
        let session_layer = SessionManagerLayer::new(session_store).with_signed(key);
        let svc = ServiceBuilder::new()
            .layer(session_layer)
            .service_fn(handler);

        let req = Request::builder().body(Body::empty())?;
        let res = svc.oneshot(req).await?;

        assert!(res.headers().get(http::header::SET_COOKIE).is_some());

        Ok(())
    }

    #[cfg(feature = "private")]
    #[tokio::test]
    async fn private_test() -> anyhow::Result<()> {
        let key = Key::generate();
        let session_store = MemoryStore::default();
        let session_layer = SessionManagerLayer::new(session_store).with_private(key);
        let svc = ServiceBuilder::new()
            .layer(session_layer)
            .service_fn(handler);

        let req = Request::builder().body(Body::empty())?;
        let res = svc.oneshot(req).await?;

        assert!(res.headers().get(http::header::SET_COOKIE).is_some());

        Ok(())
    }

    /// Returns whether the response has a textual `Set-Cookie` header satisfying `matcher`.
    fn cookie_value_matches<F>(res: &Response<Body>, matcher: F) -> bool
    where
        F: FnOnce(&str) -> bool,
    {
        res.headers()
            .get(http::header::SET_COOKIE)
            .is_some_and(|set_cookie| set_cookie.to_str().is_ok_and(matcher))
    }

    /// Returns whether the `Set-Cookie` header's `Max-Age` is within one second of
    /// `expected_value`. The tolerance absorbs clock drift between the request and the
    /// expectation computed afterwards.
    ///
    /// A missing or unparseable `Max-Age` counts as a mismatch. It is no longer silently
    /// read as zero.
    fn cookie_has_expected_max_age(res: &Response<Body>, expected_value: i64) -> bool {
        cookie_value_matches(res, |s| {
            s.split("Max-Age=")
                .nth(1)
                .and_then(|rest| rest.split(';').next())
                .and_then(|value| value.parse::<i64>().ok())
                .is_some_and(|max_age| (max_age - expected_value).abs() <= 1)
        })
    }

    /// Extracts the value of the `id` cookie from the response's `Set-Cookie` header.
    ///
    /// # Panics
    ///
    /// Panics if the header is missing, is not visible ASCII, or contains no `id=` pair.
    fn get_session_id(res: &Response<Body>) -> String {
        res.headers()
            .get(http::header::SET_COOKIE)
            .unwrap()
            .to_str()
            .unwrap()
            .split("id=")
            .nth(1)
            .unwrap()
            .split(';')
            .next()
            .unwrap()
            .to_string()
    }

    /// Loads the stored record for session `id`.
    ///
    /// # Panics
    ///
    /// Panics if `id` does not parse, the store errors, or no record exists.
    async fn get_record(store: &impl SessionStore, id: &str) -> Record {
        store
            .load(&Id::from_str(id).unwrap())
            .await
            .unwrap()
            .unwrap()
    }
}
1067}