axum_sessions/
session.rs

1// Much of this code is lifted directly from
2// `tide::sessions::middleware::SessionMiddleware`. See: https://github.com/http-rs/tide/blob/20fe435a9544c10f64245e883847fc3cd1d50538/src/sessions/middleware.rs
3
4use std::{
5    sync::Arc,
6    task::{Context, Poll},
7    time::Duration,
8};
9
10use async_session::{
11    base64,
12    hmac::{Hmac, Mac, NewMac},
13    sha2::Sha256,
14    SessionStore,
15};
16use axum::{
17    http::{
18        header::{HeaderValue, COOKIE, SET_COOKIE},
19        Request, StatusCode,
20    },
21    response::Response,
22};
23use axum_extra::extract::cookie::{Cookie, Key, SameSite};
24use futures::future::BoxFuture;
25use tokio::sync::RwLock;
26use tower::{Layer, Service};
27
/// Length, in characters, of a padded base64 encoding of an HMAC-SHA256 tag.
///
/// The tag is 32 bytes; base64 emits 4 characters for every (started) group of
/// 3 input bytes, giving `4 * ceil(32 / 3) == 44`.
const BASE64_DIGEST_LEN: usize = 4 * ((32 + 2) / 3);
29
30/// A type alias which provides a handle to the underlying session.
31///
32/// This is provided via [`http::Extensions`](axum::http::Extensions). Most
33/// applications will use the
34/// [`ReadableSession`](crate::extractors::ReadableSession) and
35/// [`WritableSession`](crate::extractors::WritableSession) extractors rather
36/// than using the handle directly. A notable exception is when using this
37/// library as a generic Tower middleware: such use cases will consume the
38/// handle directly.
39pub type SessionHandle = Arc<RwLock<async_session::Session>>;
40
/// Controls how the session data is persisted and created.
///
/// The policy is a plain fieldless enum, so it is `Copy` and can be compared
/// with `==` or printed with `{:?}` when inspecting a layer's configuration.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum PersistencePolicy {
    /// Always ping the storage layer and store empty "guest" sessions.
    Always,
    /// Do not store empty "guest" sessions, only ping the storage layer if
    /// the session data changed.
    ChangedOnly,
    /// Do not store empty "guest" sessions, always ping the storage layer for
    /// existing sessions.
    ///
    /// A session counts as existing when the request carried a session cookie
    /// whose signature verified.
    ExistingOnly,
}
53
54/// Layer that provides cookie-based sessions.
55#[derive(Clone)]
56pub struct SessionLayer<Store> {
57    store: Store,
58    cookie_path: String,
59    cookie_name: String,
60    cookie_domain: Option<String>,
61    persistence_policy: PersistencePolicy,
62    session_ttl: Option<Duration>,
63    same_site_policy: SameSite,
64    http_only: bool,
65    secure: bool,
66    key: Key,
67}
68
69impl<Store: SessionStore> SessionLayer<Store> {
70    /// Creates a layer which will attach a [`SessionHandle`] to requests via an
71    /// extension. This session is derived from a cryptographically signed
72    /// cookie. When the client sends a valid, known cookie then the session is
73    /// hydrated from this. Otherwise a new cookie is created and returned in
74    /// the response.
75    ///
76    /// The default behaviour is to enable "guest" sessions with
77    /// [`PersistencePolicy::Always`].
78    ///
79    /// # Panics
80    ///
81    /// `SessionLayer::new` will panic if the secret is less than 64 bytes.
82    ///
83    /// # Customization
84    ///
85    /// The configuration of the session may be adjusted according to the needs
86    /// of your application:
87    ///
88    /// ```rust
89    /// # use axum_sessions::{PersistencePolicy, SessionLayer, async_session::MemoryStore, SameSite};
90    /// # use std::time::Duration;
91    /// SessionLayer::new(
92    ///     MemoryStore::new(),
93    ///     b"please do not hardcode your secret; instead use a
94    ///     cryptographically secure value",
95    /// )
96    /// .with_cookie_name("your.cookie.name")
97    /// .with_cookie_path("/some/path")
98    /// .with_cookie_domain("www.example.com")
99    /// .with_same_site_policy(SameSite::Lax)
100    /// .with_session_ttl(Some(Duration::from_secs(60 * 5)))
101    /// .with_persistence_policy(PersistencePolicy::Always)
102    /// .with_http_only(true)
103    /// .with_secure(true);
104    /// ```
105    #[deprecated(
106        since = "0.6.0",
107        note = "Development of axum-sessions has moved to the tower-sessions crate. Please \
108                consider migrating."
109    )]
110    pub fn new(store: Store, secret: &[u8]) -> Self {
111        if secret.len() < 64 {
112            panic!("`secret` must be at least 64 bytes.")
113        }
114
115        Self {
116            store,
117            persistence_policy: PersistencePolicy::Always,
118            cookie_path: "/".into(),
119            cookie_name: "sid".into(),
120            cookie_domain: None,
121            same_site_policy: SameSite::Strict,
122            session_ttl: Some(Duration::from_secs(24 * 60 * 60)),
123            http_only: true,
124            secure: true,
125            key: Key::from(secret),
126        }
127    }
128
129    /// When `true`, a session cookie will always be set. When `false` the
130    /// session data must be modified in order for it to be set. Defaults to
131    /// `true`.
132    pub fn with_persistence_policy(mut self, policy: PersistencePolicy) -> Self {
133        self.persistence_policy = policy;
134        self
135    }
136
137    /// Sets a cookie for the session. Defaults to `"/"`.
138    pub fn with_cookie_path(mut self, cookie_path: impl AsRef<str>) -> Self {
139        self.cookie_path = cookie_path.as_ref().to_owned();
140        self
141    }
142
143    /// Sets a cookie name for the session. Defaults to `"sid"`.
144    pub fn with_cookie_name(mut self, cookie_name: impl AsRef<str>) -> Self {
145        self.cookie_name = cookie_name.as_ref().to_owned();
146        self
147    }
148
149    /// Sets a cookie domain for the session. Defaults to `None`.
150    pub fn with_cookie_domain(mut self, cookie_domain: impl AsRef<str>) -> Self {
151        self.cookie_domain = Some(cookie_domain.as_ref().to_owned());
152        self
153    }
154
155    /// Decide if session is presented to the storage layer
156    fn should_store(&self, cookie_value: &Option<String>, session_data_changed: bool) -> bool {
157        session_data_changed
158            || matches!(self.persistence_policy, PersistencePolicy::Always)
159            || (matches!(self.persistence_policy, PersistencePolicy::ExistingOnly)
160                && cookie_value.is_some())
161    }
162
163    /// Sets a cookie same site policy for the session. Defaults to
164    /// `SameSite::Strict`.
165    pub fn with_same_site_policy(mut self, policy: SameSite) -> Self {
166        self.same_site_policy = policy;
167        self
168    }
169
170    /// Sets a cookie time-to-live (ttl) for the session. Defaults to
171    /// `Duration::from_secs(60 * 60 * 24)`; one day.
172    pub fn with_session_ttl(mut self, session_ttl: Option<Duration>) -> Self {
173        self.session_ttl = session_ttl;
174        self
175    }
176
177    /// Sets a cookie `HttpOnly` attribute for the session. Defaults to `true`.
178    pub fn with_http_only(mut self, http_only: bool) -> Self {
179        self.http_only = http_only;
180        self
181    }
182
183    /// Sets a cookie secure attribute for the session. Defaults to `true`.
184    pub fn with_secure(mut self, secure: bool) -> Self {
185        self.secure = secure;
186        self
187    }
188
189    async fn load_or_create(&self, cookie_value: Option<String>) -> SessionHandle {
190        let session = match cookie_value {
191            Some(cookie_value) => self.store.load_session(cookie_value).await.ok().flatten(),
192            None => None,
193        };
194
195        Arc::new(RwLock::new(
196            session
197                .and_then(async_session::Session::validate)
198                .unwrap_or_default(),
199        ))
200    }
201
202    fn build_cookie(&self, cookie_value: String) -> Cookie<'static> {
203        let mut cookie = Cookie::build(self.cookie_name.clone(), cookie_value)
204            .http_only(self.http_only)
205            .same_site(self.same_site_policy)
206            .secure(self.secure)
207            .path(self.cookie_path.clone())
208            .finish();
209
210        if let Some(ttl) = self.session_ttl {
211            cookie.set_expires(Some((std::time::SystemTime::now() + ttl).into()));
212        }
213
214        if let Some(cookie_domain) = self.cookie_domain.clone() {
215            cookie.set_domain(cookie_domain)
216        }
217
218        self.sign_cookie(&mut cookie);
219
220        cookie
221    }
222
223    fn build_removal_cookie(&self) -> Cookie<'static> {
224        let cookie = Cookie::build(self.cookie_name.clone(), "")
225            .http_only(true)
226            .path(self.cookie_path.clone());
227
228        let mut cookie = if let Some(cookie_domain) = self.cookie_domain.clone() {
229            cookie.domain(cookie_domain)
230        } else {
231            cookie
232        }
233        .finish();
234
235        cookie.make_removal();
236
237        self.sign_cookie(&mut cookie);
238
239        cookie
240    }
241
242    // the following is reused verbatim from
243    // https://github.com/SergioBenitez/cookie-rs/blob/master/src/secure/signed.rs#L33-L43
244    /// Signs the cookie's value providing integrity and authenticity.
245    fn sign_cookie(&self, cookie: &mut Cookie<'_>) {
246        // Compute HMAC-SHA256 of the cookie's value.
247        let mut mac = Hmac::<Sha256>::new_from_slice(self.key.signing()).expect("good key");
248        mac.update(cookie.value().as_bytes());
249
250        // Cookie's new value is [MAC | original-value].
251        let mut new_value = base64::encode(mac.finalize().into_bytes());
252        new_value.push_str(cookie.value());
253        cookie.set_value(new_value);
254    }
255
256    // the following is reused verbatim from
257    // https://github.com/SergioBenitez/cookie-rs/blob/master/src/secure/signed.rs#L45-L63
258    /// Given a signed value `str` where the signature is prepended to `value`,
259    /// verifies the signed value and returns it. If there's a problem, returns
260    /// an `Err` with a string describing the issue.
261    fn verify_signature(&self, cookie_value: &str) -> Result<String, &'static str> {
262        if cookie_value.len() < BASE64_DIGEST_LEN {
263            return Err("length of value is <= BASE64_DIGEST_LEN");
264        }
265
266        // Split [MAC | original-value] into its two parts.
267        let (digest_str, value) = cookie_value.split_at(BASE64_DIGEST_LEN);
268        let digest = base64::decode(digest_str).map_err(|_| "bad base64 digest")?;
269
270        // Perform the verification.
271        let mut mac = Hmac::<Sha256>::new_from_slice(self.key.signing()).expect("good key");
272        mac.update(value.as_bytes());
273        mac.verify(&digest)
274            .map(|_| value.to_string())
275            .map_err(|_| "value did not verify")
276    }
277}
278
279impl<Inner, Store: SessionStore> Layer<Inner> for SessionLayer<Store> {
280    type Service = Session<Inner, Store>;
281
282    fn layer(&self, inner: Inner) -> Self::Service {
283        Session {
284            inner,
285            layer: self.clone(),
286        }
287    }
288}
289
290/// Session service container.
291#[derive(Clone)]
292pub struct Session<Inner, Store: SessionStore> {
293    inner: Inner,
294    layer: SessionLayer<Store>,
295}
296
297impl<Inner, ReqBody, ResBody, Store: SessionStore> Service<Request<ReqBody>>
298    for Session<Inner, Store>
299where
300    Inner: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone + Send + 'static,
301    ResBody: Send + 'static,
302    ReqBody: Send + 'static,
303    Inner::Future: Send + 'static,
304{
305    type Response = Inner::Response;
306    type Error = Inner::Error;
307    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
308
309    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
310        self.inner.poll_ready(cx)
311    }
312
313    fn call(&mut self, mut request: Request<ReqBody>) -> Self::Future {
314        let session_layer = self.layer.clone();
315
316        // Multiple cookies may be all concatenated into a single Cookie header
317        // separated with semicolons (HTTP/1.1 behaviour) or into multiple separate
318        // Cookie headers (HTTP/2 behaviour). Search for the session cookie from
319        // all Cookie headers, assuming both forms are possible
320        let cookie_value = request
321            .headers()
322            .get_all(COOKIE)
323            .iter()
324            .filter_map(|cookie_header| cookie_header.to_str().ok())
325            .flat_map(|cookie_header| cookie_header.split(';'))
326            .filter_map(|cookie_header| Cookie::parse_encoded(cookie_header.trim()).ok())
327            .filter(|cookie| cookie.name() == session_layer.cookie_name)
328            .find_map(|cookie| self.layer.verify_signature(cookie.value()).ok());
329
330        let inner = self.inner.clone();
331        let mut inner = std::mem::replace(&mut self.inner, inner);
332        Box::pin(async move {
333            let session_handle = session_layer.load_or_create(cookie_value.clone()).await;
334
335            let mut session = session_handle.write().await;
336            if let Some(ttl) = session_layer.session_ttl {
337                (*session).expire_in(ttl);
338            }
339            drop(session);
340
341            request.extensions_mut().insert(session_handle.clone());
342            let mut response = inner.call(request).await?;
343
344            let session = session_handle.read().await;
345            let (session_is_destroyed, session_data_changed) =
346                (session.is_destroyed(), session.data_changed());
347            drop(session);
348
349            // Pull out the session so we can pass it to the store without `Clone` blowing
350            // away the `cookie_value`.
351            let session = RwLock::into_inner(
352                Arc::try_unwrap(session_handle).expect("Session handle still has owners."),
353            );
354            if session_is_destroyed {
355                if let Err(e) = session_layer.store.destroy_session(session).await {
356                    tracing::error!("Failed to destroy session: {:?}", e);
357                    *response.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
358                }
359
360                let removal_cookie = session_layer.build_removal_cookie();
361
362                response.headers_mut().append(
363                    SET_COOKIE,
364                    HeaderValue::from_str(&removal_cookie.to_string()).unwrap(),
365                );
366
367            // Store if
368            //  - We have guest sessions
369            //  - We received a valid cookie and we use the `ExistingOnly`
370            //    policy.
371            //  - If we use the `ChangedOnly` policy, only
372            //    `session.data_changed()` should trigger this branch.
373            } else if session_layer.should_store(&cookie_value, session_data_changed) {
374                match session_layer.store.store_session(session).await {
375                    Ok(Some(cookie_value)) => {
376                        let cookie = session_layer.build_cookie(cookie_value);
377                        response.headers_mut().append(
378                            SET_COOKIE,
379                            HeaderValue::from_str(&cookie.to_string()).unwrap(),
380                        );
381                    }
382
383                    Ok(None) => {}
384
385                    Err(e) => {
386                        tracing::error!("Failed to reach session storage: {:?}", e);
387                        *response.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
388                    }
389                }
390            }
391
392            Ok(response)
393        })
394    }
395}
396
#[cfg(test)]
// These tests deliberately exercise `SessionLayer::new`, which is marked
// `#[deprecated]`; silence the resulting warnings for this module only.
#[allow(deprecated)]
mod tests {
    use async_session::{
        serde::{Deserialize, Serialize},
        serde_json,
    };
    use axum::http::{Request, Response};
    use http::{
        header::{COOKIE, SET_COOKIE},
        HeaderValue, StatusCode,
    };
    use hyper::Body;
    use rand::Rng;
    use tower::{BoxError, Service, ServiceBuilder, ServiceExt};

    use super::PersistencePolicy;
    use crate::{async_session::MemoryStore, SessionHandle, SessionLayer};

    #[derive(Deserialize, Serialize, PartialEq, Debug)]
    struct Counter {
        counter: i32,
    }

    enum ExpectedResult {
        Some,
        None,
    }

    #[tokio::test]
    async fn sets_session_cookie() {
        let secret = rand::thread_rng().gen::<[u8; 64]>();
        let store = MemoryStore::new();
        let session_layer = SessionLayer::new(store, &secret);
        let mut service = ServiceBuilder::new().layer(session_layer).service_fn(echo);

        let request = Request::get("/").body(Body::empty()).unwrap();

        let res = service.ready().await.unwrap().call(request).await.unwrap();
        assert_eq!(res.status(), StatusCode::OK);

        assert!(res
            .headers()
            .get(SET_COOKIE)
            .unwrap()
            .to_str()
            .unwrap()
            .starts_with("sid="))
    }

    #[tokio::test]
    async fn uses_valid_session() {
        let secret = rand::thread_rng().gen::<[u8; 64]>();
        let store = MemoryStore::new();
        let session_layer = SessionLayer::new(store, &secret);
        let mut service = ServiceBuilder::new()
            .layer(session_layer)
            .service_fn(increment);

        let request = Request::get("/").body(Body::empty()).unwrap();

        let res = service.ready().await.unwrap().call(request).await.unwrap();
        let session_cookie = res.headers().get(SET_COOKIE).unwrap().clone();

        assert_eq!(res.status(), StatusCode::OK);

        let json_bs = &hyper::body::to_bytes(res.into_body()).await.unwrap()[..];
        let counter: Counter = serde_json::from_slice(json_bs).unwrap();
        assert_eq!(counter, Counter { counter: 0 });

        let mut request = Request::get("/").body(Body::empty()).unwrap();
        request.headers_mut().insert(COOKIE, session_cookie);
        let res = service.ready().await.unwrap().call(request).await.unwrap();
        assert_eq!(res.status(), StatusCode::OK);

        let json_bs = &hyper::body::to_bytes(res.into_body()).await.unwrap()[..];
        let counter: Counter = serde_json::from_slice(json_bs).unwrap();
        assert_eq!(counter, Counter { counter: 1 });
    }

    #[tokio::test]
    async fn multiple_cookies_in_single_header() {
        let secret = rand::thread_rng().gen::<[u8; 64]>();
        let store = MemoryStore::new();
        let session_layer = SessionLayer::new(store, &secret);
        let mut service = ServiceBuilder::new()
            .layer(session_layer)
            .service_fn(increment);

        let request = Request::get("/").body(Body::empty()).unwrap();

        let res = service.ready().await.unwrap().call(request).await.unwrap();
        let session_cookie = res.headers().get(SET_COOKIE).unwrap().clone();

        // build a Cookie header that contains two cookies: an unrelated dummy cookie,
        // and the given session cookie
        let request_cookie =
            HeaderValue::from_str(&format!("key=value; {}", session_cookie.to_str().unwrap()))
                .unwrap();

        assert_eq!(res.status(), StatusCode::OK);

        let json_bs = &hyper::body::to_bytes(res.into_body()).await.unwrap()[..];
        let counter: Counter = serde_json::from_slice(json_bs).unwrap();
        assert_eq!(counter, Counter { counter: 0 });

        let mut request = Request::get("/").body(Body::empty()).unwrap();
        request.headers_mut().insert(COOKIE, request_cookie);
        let res = service.ready().await.unwrap().call(request).await.unwrap();
        assert_eq!(res.status(), StatusCode::OK);

        let json_bs = &hyper::body::to_bytes(res.into_body()).await.unwrap()[..];
        let counter: Counter = serde_json::from_slice(json_bs).unwrap();
        assert_eq!(counter, Counter { counter: 1 });
    }

    #[tokio::test]
    async fn multiple_cookie_headers() {
        let secret = rand::thread_rng().gen::<[u8; 64]>();
        let store = MemoryStore::new();
        let session_layer = SessionLayer::new(store, &secret);
        let mut service = ServiceBuilder::new()
            .layer(session_layer)
            .service_fn(increment);

        let request = Request::get("/").body(Body::empty()).unwrap();

        let res = service.ready().await.unwrap().call(request).await.unwrap();
        let session_cookie = res.headers().get(SET_COOKIE).unwrap().clone();
        let dummy_cookie = HeaderValue::from_str("key=value").unwrap();

        assert_eq!(res.status(), StatusCode::OK);

        let json_bs = &hyper::body::to_bytes(res.into_body()).await.unwrap()[..];
        let counter: Counter = serde_json::from_slice(json_bs).unwrap();
        assert_eq!(counter, Counter { counter: 0 });

        let mut request = Request::get("/").body(Body::empty()).unwrap();
        request.headers_mut().append(COOKIE, dummy_cookie);
        request.headers_mut().append(COOKIE, session_cookie);
        let res = service.ready().await.unwrap().call(request).await.unwrap();
        assert_eq!(res.status(), StatusCode::OK);

        let json_bs = &hyper::body::to_bytes(res.into_body()).await.unwrap()[..];
        let counter: Counter = serde_json::from_slice(json_bs).unwrap();
        assert_eq!(counter, Counter { counter: 1 });
    }

    #[tokio::test]
    async fn no_cookie_stored_when_no_session_is_required() {
        let secret = rand::thread_rng().gen::<[u8; 64]>();
        let store = MemoryStore::new();
        let session_layer = SessionLayer::new(store, &secret)
            .with_persistence_policy(PersistencePolicy::ChangedOnly);
        let mut service = ServiceBuilder::new().layer(session_layer).service_fn(echo);

        let request = Request::get("/").body(Body::empty()).unwrap();

        let res = service.ready().await.unwrap().call(request).await.unwrap();
        assert_eq!(res.status(), StatusCode::OK);

        assert!(res.headers().get(SET_COOKIE).is_none());
    }

    /// Sends a cookie-less request, then a request carrying an unsigned (hence
    /// invalid) session cookie, and checks whether each response sets a cookie.
    async fn invalid_session_check_cookie_result(
        persistence_policy: PersistencePolicy,
        change_data: bool,
        expect_cookie_header: (ExpectedResult, ExpectedResult),
    ) {
        let (expect_cookie_header_first, expect_cookie_header_second) = expect_cookie_header;
        let secret = rand::thread_rng().gen::<[u8; 64]>();
        let store = MemoryStore::new();
        let session_layer =
            SessionLayer::new(store, &secret).with_persistence_policy(persistence_policy);
        let mut service = ServiceBuilder::new()
            .layer(&session_layer)
            .service_fn(echo_read_session);

        let request = Request::get("/").body(Body::empty()).unwrap();

        let res = service.ready().await.unwrap().call(request).await.unwrap();
        assert_eq!(res.status(), StatusCode::OK);

        match expect_cookie_header_first {
            ExpectedResult::Some => assert!(
                res.headers().get(SET_COOKIE).is_some(),
                "Set-Cookie must be present for first response"
            ),
            ExpectedResult::None => assert!(
                res.headers().get(SET_COOKIE).is_none(),
                "Set-Cookie must not be present for first response"
            ),
        }

        let mut service =
            ServiceBuilder::new()
                .layer(session_layer)
                .service_fn(move |req| async move {
                    if change_data {
                        echo_with_session_change(req).await
                    } else {
                        echo_read_session(req).await
                    }
                });
        let mut request = Request::get("/").body(Body::empty()).unwrap();
        request
            .headers_mut()
            .insert(COOKIE, "sid=aW52YWxpZC1zZXNzaW9uLWlk".parse().unwrap());
        let res = service.ready().await.unwrap().call(request).await.unwrap();
        match expect_cookie_header_second {
            ExpectedResult::Some => assert!(
                res.headers().get(SET_COOKIE).is_some(),
                "Set-Cookie must be present for second response"
            ),
            ExpectedResult::None => assert!(
                res.headers().get(SET_COOKIE).is_none(),
                "Set-Cookie must not be present for second response"
            ),
        }
    }

    #[tokio::test]
    async fn invalid_session_always_sets_guest_cookie() {
        invalid_session_check_cookie_result(
            PersistencePolicy::Always,
            false,
            (ExpectedResult::Some, ExpectedResult::Some),
        )
        .await;
    }

    #[tokio::test]
    async fn invalid_session_sets_new_session_cookie_when_data_changes() {
        invalid_session_check_cookie_result(
            PersistencePolicy::ExistingOnly,
            true,
            (ExpectedResult::None, ExpectedResult::Some),
        )
        .await;
    }

    #[tokio::test]
    async fn invalid_session_sets_no_cookie_when_no_data_changes() {
        invalid_session_check_cookie_result(
            PersistencePolicy::ExistingOnly,
            false,
            (ExpectedResult::None, ExpectedResult::None),
        )
        .await;
    }

    #[tokio::test]
    async fn invalid_session_changedonly_sets_cookie_when_changed() {
        invalid_session_check_cookie_result(
            PersistencePolicy::ChangedOnly,
            true,
            (ExpectedResult::None, ExpectedResult::Some),
        )
        .await;
    }

    #[tokio::test]
    async fn destroyed_sessions_sets_removal_cookie() {
        let secret = rand::thread_rng().gen::<[u8; 64]>();
        let store = MemoryStore::new();
        let session_layer = SessionLayer::new(store, &secret);
        let mut service = ServiceBuilder::new()
            .layer(session_layer)
            .service_fn(destroy);

        let request = Request::get("/").body(Body::empty()).unwrap();

        let res = service.ready().await.unwrap().call(request).await.unwrap();
        assert_eq!(res.status(), StatusCode::OK);

        let session_cookie = res
            .headers()
            .get(SET_COOKIE)
            .unwrap()
            .to_str()
            .unwrap()
            .to_string();
        let mut request = Request::get("/destroy").body(Body::empty()).unwrap();
        request
            .headers_mut()
            .insert(COOKIE, session_cookie.parse().unwrap());
        let res = service.ready().await.unwrap().call(request).await.unwrap();

        let removal_cookie = res
            .headers()
            .get(SET_COOKIE)
            .unwrap()
            .to_str()
            .unwrap();
        // The removal cookie targets the session cookie and expires immediately.
        assert!(removal_cookie.starts_with("sid="));
        assert!(removal_cookie.contains("Max-Age=0"));
        assert_eq!(removal_cookie.len(), 116);
    }

    #[test]
    #[should_panic]
    fn too_short_secret() {
        let store = MemoryStore::new();
        SessionLayer::new(store, b"");
    }

    async fn echo(req: Request<Body>) -> Result<Response<Body>, BoxError> {
        Ok(Response::new(req.into_body()))
    }

    async fn echo_read_session(req: Request<Body>) -> Result<Response<Body>, BoxError> {
        {
            let session_handle = req.extensions().get::<SessionHandle>().unwrap();
            // The session is only inspected, so a shared read lock suffices.
            let session = session_handle.read().await;
            let _ = session.get::<String>("signed_in").unwrap_or_default();
        }
        Ok(Response::new(req.into_body()))
    }

    async fn echo_with_session_change(req: Request<Body>) -> Result<Response<Body>, BoxError> {
        {
            let session_handle = req.extensions().get::<SessionHandle>().unwrap();
            let mut session = session_handle.write().await;
            session.insert("signed_in", true).unwrap();
        }
        Ok(Response::new(req.into_body()))
    }

    async fn destroy(req: Request<Body>) -> Result<Response<Body>, BoxError> {
        // Destroy the session if we received a session cookie.
        if req.headers().get(COOKIE).is_some() {
            let session_handle = req.extensions().get::<SessionHandle>().unwrap();
            let mut session = session_handle.write().await;
            session.destroy();
        }

        Ok(Response::new(req.into_body()))
    }

    /// Increments the `counter` stored in the session (starting at 0) and
    /// echoes the new value as JSON.
    async fn increment(mut req: Request<Body>) -> Result<Response<Body>, BoxError> {
        let counter = {
            let session_handle = req.extensions().get::<SessionHandle>().unwrap();
            let mut session = session_handle.write().await;
            let counter = session
                .get::<i32>("counter")
                .map_or(0, |count| count + 1);
            session.insert("counter", counter).unwrap();
            counter
        };

        let body = serde_json::to_string(&Counter { counter }).unwrap();
        *req.body_mut() = Body::from(body);

        Ok(Response::new(req.into_body()))
    }
}