1use std::{
5 sync::Arc,
6 task::{Context, Poll},
7 time::Duration,
8};
9
10use async_session::{
11 base64,
12 hmac::{Hmac, Mac, NewMac},
13 sha2::Sha256,
14 SessionStore,
15};
16use axum::{
17 http::{
18 header::{HeaderValue, COOKIE, SET_COOKIE},
19 Request, StatusCode,
20 },
21 response::Response,
22};
23use axum_extra::extract::cookie::{Cookie, Key, SameSite};
24use futures::future::BoxFuture;
25use tokio::sync::RwLock;
26use tower::{Layer, Service};
27
28const BASE64_DIGEST_LEN: usize = 44;
29
30pub type SessionHandle = Arc<RwLock<async_session::Session>>;
40
/// Decides when a session that was not destroyed is written back to the store (which may in
/// turn issue a session cookie).
///
/// A session whose data changed during the request is stored under every policy; the variants
/// only differ in how they treat unchanged sessions.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum PersistencePolicy {
    /// Store every session on every request, whether or not it changed.
    Always,
    /// Store a session only when its data changed during the request.
    ChangedOnly,
    /// Store a session when its data changed, or when the request already carried a session
    /// cookie whose signature verified.
    ExistingOnly,
}
53
54#[derive(Clone)]
56pub struct SessionLayer<Store> {
57 store: Store,
58 cookie_path: String,
59 cookie_name: String,
60 cookie_domain: Option<String>,
61 persistence_policy: PersistencePolicy,
62 session_ttl: Option<Duration>,
63 same_site_policy: SameSite,
64 http_only: bool,
65 secure: bool,
66 key: Key,
67}
68
69impl<Store: SessionStore> SessionLayer<Store> {
70 #[deprecated(
106 since = "0.6.0",
107 note = "Development of axum-sessions has moved to the tower-sessions crate. Please \
108 consider migrating."
109 )]
110 pub fn new(store: Store, secret: &[u8]) -> Self {
111 if secret.len() < 64 {
112 panic!("`secret` must be at least 64 bytes.")
113 }
114
115 Self {
116 store,
117 persistence_policy: PersistencePolicy::Always,
118 cookie_path: "/".into(),
119 cookie_name: "sid".into(),
120 cookie_domain: None,
121 same_site_policy: SameSite::Strict,
122 session_ttl: Some(Duration::from_secs(24 * 60 * 60)),
123 http_only: true,
124 secure: true,
125 key: Key::from(secret),
126 }
127 }
128
129 pub fn with_persistence_policy(mut self, policy: PersistencePolicy) -> Self {
133 self.persistence_policy = policy;
134 self
135 }
136
137 pub fn with_cookie_path(mut self, cookie_path: impl AsRef<str>) -> Self {
139 self.cookie_path = cookie_path.as_ref().to_owned();
140 self
141 }
142
143 pub fn with_cookie_name(mut self, cookie_name: impl AsRef<str>) -> Self {
145 self.cookie_name = cookie_name.as_ref().to_owned();
146 self
147 }
148
149 pub fn with_cookie_domain(mut self, cookie_domain: impl AsRef<str>) -> Self {
151 self.cookie_domain = Some(cookie_domain.as_ref().to_owned());
152 self
153 }
154
155 fn should_store(&self, cookie_value: &Option<String>, session_data_changed: bool) -> bool {
157 session_data_changed
158 || matches!(self.persistence_policy, PersistencePolicy::Always)
159 || (matches!(self.persistence_policy, PersistencePolicy::ExistingOnly)
160 && cookie_value.is_some())
161 }
162
163 pub fn with_same_site_policy(mut self, policy: SameSite) -> Self {
166 self.same_site_policy = policy;
167 self
168 }
169
170 pub fn with_session_ttl(mut self, session_ttl: Option<Duration>) -> Self {
173 self.session_ttl = session_ttl;
174 self
175 }
176
177 pub fn with_http_only(mut self, http_only: bool) -> Self {
179 self.http_only = http_only;
180 self
181 }
182
183 pub fn with_secure(mut self, secure: bool) -> Self {
185 self.secure = secure;
186 self
187 }
188
189 async fn load_or_create(&self, cookie_value: Option<String>) -> SessionHandle {
190 let session = match cookie_value {
191 Some(cookie_value) => self.store.load_session(cookie_value).await.ok().flatten(),
192 None => None,
193 };
194
195 Arc::new(RwLock::new(
196 session
197 .and_then(async_session::Session::validate)
198 .unwrap_or_default(),
199 ))
200 }
201
202 fn build_cookie(&self, cookie_value: String) -> Cookie<'static> {
203 let mut cookie = Cookie::build(self.cookie_name.clone(), cookie_value)
204 .http_only(self.http_only)
205 .same_site(self.same_site_policy)
206 .secure(self.secure)
207 .path(self.cookie_path.clone())
208 .finish();
209
210 if let Some(ttl) = self.session_ttl {
211 cookie.set_expires(Some((std::time::SystemTime::now() + ttl).into()));
212 }
213
214 if let Some(cookie_domain) = self.cookie_domain.clone() {
215 cookie.set_domain(cookie_domain)
216 }
217
218 self.sign_cookie(&mut cookie);
219
220 cookie
221 }
222
223 fn build_removal_cookie(&self) -> Cookie<'static> {
224 let cookie = Cookie::build(self.cookie_name.clone(), "")
225 .http_only(true)
226 .path(self.cookie_path.clone());
227
228 let mut cookie = if let Some(cookie_domain) = self.cookie_domain.clone() {
229 cookie.domain(cookie_domain)
230 } else {
231 cookie
232 }
233 .finish();
234
235 cookie.make_removal();
236
237 self.sign_cookie(&mut cookie);
238
239 cookie
240 }
241
242 fn sign_cookie(&self, cookie: &mut Cookie<'_>) {
246 let mut mac = Hmac::<Sha256>::new_from_slice(self.key.signing()).expect("good key");
248 mac.update(cookie.value().as_bytes());
249
250 let mut new_value = base64::encode(mac.finalize().into_bytes());
252 new_value.push_str(cookie.value());
253 cookie.set_value(new_value);
254 }
255
256 fn verify_signature(&self, cookie_value: &str) -> Result<String, &'static str> {
262 if cookie_value.len() < BASE64_DIGEST_LEN {
263 return Err("length of value is <= BASE64_DIGEST_LEN");
264 }
265
266 let (digest_str, value) = cookie_value.split_at(BASE64_DIGEST_LEN);
268 let digest = base64::decode(digest_str).map_err(|_| "bad base64 digest")?;
269
270 let mut mac = Hmac::<Sha256>::new_from_slice(self.key.signing()).expect("good key");
272 mac.update(value.as_bytes());
273 mac.verify(&digest)
274 .map(|_| value.to_string())
275 .map_err(|_| "value did not verify")
276 }
277}
278
279impl<Inner, Store: SessionStore> Layer<Inner> for SessionLayer<Store> {
280 type Service = Session<Inner, Store>;
281
282 fn layer(&self, inner: Inner) -> Self::Service {
283 Session {
284 inner,
285 layer: self.clone(),
286 }
287 }
288}
289
290#[derive(Clone)]
292pub struct Session<Inner, Store: SessionStore> {
293 inner: Inner,
294 layer: SessionLayer<Store>,
295}
296
297impl<Inner, ReqBody, ResBody, Store: SessionStore> Service<Request<ReqBody>>
298 for Session<Inner, Store>
299where
300 Inner: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone + Send + 'static,
301 ResBody: Send + 'static,
302 ReqBody: Send + 'static,
303 Inner::Future: Send + 'static,
304{
305 type Response = Inner::Response;
306 type Error = Inner::Error;
307 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
308
309 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
310 self.inner.poll_ready(cx)
311 }
312
313 fn call(&mut self, mut request: Request<ReqBody>) -> Self::Future {
314 let session_layer = self.layer.clone();
315
316 let cookie_value = request
321 .headers()
322 .get_all(COOKIE)
323 .iter()
324 .filter_map(|cookie_header| cookie_header.to_str().ok())
325 .flat_map(|cookie_header| cookie_header.split(';'))
326 .filter_map(|cookie_header| Cookie::parse_encoded(cookie_header.trim()).ok())
327 .filter(|cookie| cookie.name() == session_layer.cookie_name)
328 .find_map(|cookie| self.layer.verify_signature(cookie.value()).ok());
329
330 let inner = self.inner.clone();
331 let mut inner = std::mem::replace(&mut self.inner, inner);
332 Box::pin(async move {
333 let session_handle = session_layer.load_or_create(cookie_value.clone()).await;
334
335 let mut session = session_handle.write().await;
336 if let Some(ttl) = session_layer.session_ttl {
337 (*session).expire_in(ttl);
338 }
339 drop(session);
340
341 request.extensions_mut().insert(session_handle.clone());
342 let mut response = inner.call(request).await?;
343
344 let session = session_handle.read().await;
345 let (session_is_destroyed, session_data_changed) =
346 (session.is_destroyed(), session.data_changed());
347 drop(session);
348
349 let session = RwLock::into_inner(
352 Arc::try_unwrap(session_handle).expect("Session handle still has owners."),
353 );
354 if session_is_destroyed {
355 if let Err(e) = session_layer.store.destroy_session(session).await {
356 tracing::error!("Failed to destroy session: {:?}", e);
357 *response.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
358 }
359
360 let removal_cookie = session_layer.build_removal_cookie();
361
362 response.headers_mut().append(
363 SET_COOKIE,
364 HeaderValue::from_str(&removal_cookie.to_string()).unwrap(),
365 );
366
367 } else if session_layer.should_store(&cookie_value, session_data_changed) {
374 match session_layer.store.store_session(session).await {
375 Ok(Some(cookie_value)) => {
376 let cookie = session_layer.build_cookie(cookie_value);
377 response.headers_mut().append(
378 SET_COOKIE,
379 HeaderValue::from_str(&cookie.to_string()).unwrap(),
380 );
381 }
382
383 Ok(None) => {}
384
385 Err(e) => {
386 tracing::error!("Failed to reach session storage: {:?}", e);
387 *response.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
388 }
389 }
390 }
391
392 Ok(response)
393 })
394 }
395}
396
// Tests drive `SessionLayer` wrapped around small handler functions, using an in-memory store
// and a random 64-byte secret per test.
#[cfg(test)]
mod tests {
    use async_session::{
        serde::{Deserialize, Serialize},
        serde_json,
    };
    use axum::http::{Request, Response};
    use http::{
        header::{COOKIE, SET_COOKIE},
        HeaderValue, StatusCode,
    };
    use hyper::Body;
    use rand::Rng;
    use tower::{BoxError, Service, ServiceBuilder, ServiceExt};

    use super::PersistencePolicy;
    use crate::{async_session::MemoryStore, SessionHandle, SessionLayer};

    // JSON body returned by the `increment` handler.
    #[derive(Deserialize, Serialize, PartialEq, Debug)]
    struct Counter {
        counter: i32,
    }

    // Whether a response is expected to carry a `Set-Cookie` header.
    enum ExpectedResult {
        Some,
        None,
    }

    // With the default `Always` policy, even a handler that never touches the session gets a
    // session cookie named `sid`.
    #[tokio::test]
    async fn sets_session_cookie() {
        let secret = rand::thread_rng().gen::<[u8; 64]>();
        let store = MemoryStore::new();
        let session_layer = SessionLayer::new(store, &secret);
        let mut service = ServiceBuilder::new().layer(session_layer).service_fn(echo);

        let request = Request::get("/").body(Body::empty()).unwrap();

        let res = service.ready().await.unwrap().call(request).await.unwrap();
        assert_eq!(res.status(), StatusCode::OK);

        assert!(res
            .headers()
            .get(SET_COOKIE)
            .unwrap()
            .to_str()
            .unwrap()
            .starts_with("sid="))
    }

    // Sending back the cookie from the first response loads the same session: the counter
    // goes from 0 to 1.
    #[tokio::test]
    async fn uses_valid_session() {
        let secret = rand::thread_rng().gen::<[u8; 64]>();
        let store = MemoryStore::new();
        let session_layer = SessionLayer::new(store, &secret);
        let mut service = ServiceBuilder::new()
            .layer(session_layer)
            .service_fn(increment);

        let request = Request::get("/").body(Body::empty()).unwrap();

        let res = service.ready().await.unwrap().call(request).await.unwrap();
        let session_cookie = res.headers().get(SET_COOKIE).unwrap().clone();

        assert_eq!(res.status(), StatusCode::OK);

        let json_bs = &hyper::body::to_bytes(res.into_body()).await.unwrap()[..];
        let counter: Counter = serde_json::from_slice(json_bs).unwrap();
        assert_eq!(counter, Counter { counter: 0 });

        // The whole `Set-Cookie` value (including attributes) is reused as the `Cookie` header;
        // the middleware splits on `;` and only picks the `sid` pair.
        let mut request = Request::get("/").body(Body::empty()).unwrap();
        request
            .headers_mut()
            .insert(COOKIE, session_cookie.to_owned());
        let res = service.ready().await.unwrap().call(request).await.unwrap();
        assert_eq!(res.status(), StatusCode::OK);

        let json_bs = &hyper::body::to_bytes(res.into_body()).await.unwrap()[..];
        let counter: Counter = serde_json::from_slice(json_bs).unwrap();
        assert_eq!(counter, Counter { counter: 1 });
    }

    // The session cookie is found when it shares one `Cookie` header with an unrelated cookie.
    #[tokio::test]
    async fn multiple_cookies_in_single_header() {
        let secret = rand::thread_rng().gen::<[u8; 64]>();
        let store = MemoryStore::new();
        let session_layer = SessionLayer::new(store, &secret);
        let mut service = ServiceBuilder::new()
            .layer(session_layer)
            .service_fn(increment);

        let request = Request::get("/").body(Body::empty()).unwrap();

        let res = service.ready().await.unwrap().call(request).await.unwrap();
        let session_cookie = res.headers().get(SET_COOKIE).unwrap().clone();

        // Prepend an unrelated cookie so the session cookie is not the first pair.
        let request_cookie =
            HeaderValue::from_str(&format!("key=value; {}", session_cookie.to_str().unwrap()))
                .unwrap();

        assert_eq!(res.status(), StatusCode::OK);

        let json_bs = &hyper::body::to_bytes(res.into_body()).await.unwrap()[..];
        let counter: Counter = serde_json::from_slice(json_bs).unwrap();
        assert_eq!(counter, Counter { counter: 0 });

        let mut request = Request::get("/").body(Body::empty()).unwrap();
        request.headers_mut().insert(COOKIE, request_cookie);
        let res = service.ready().await.unwrap().call(request).await.unwrap();
        assert_eq!(res.status(), StatusCode::OK);

        let json_bs = &hyper::body::to_bytes(res.into_body()).await.unwrap()[..];
        let counter: Counter = serde_json::from_slice(json_bs).unwrap();
        assert_eq!(counter, Counter { counter: 1 });
    }

    // The session cookie is found when it arrives in a second, separate `Cookie` header.
    #[tokio::test]
    async fn multiple_cookie_headers() {
        let secret = rand::thread_rng().gen::<[u8; 64]>();
        let store = MemoryStore::new();
        let session_layer = SessionLayer::new(store, &secret);
        let mut service = ServiceBuilder::new()
            .layer(session_layer)
            .service_fn(increment);

        let request = Request::get("/").body(Body::empty()).unwrap();

        let res = service.ready().await.unwrap().call(request).await.unwrap();
        let session_cookie = res.headers().get(SET_COOKIE).unwrap().clone();
        let dummy_cookie = HeaderValue::from_str("key=value").unwrap();

        assert_eq!(res.status(), StatusCode::OK);

        let json_bs = &hyper::body::to_bytes(res.into_body()).await.unwrap()[..];
        let counter: Counter = serde_json::from_slice(json_bs).unwrap();
        assert_eq!(counter, Counter { counter: 0 });

        // `append` (not `insert`) keeps both headers.
        let mut request = Request::get("/").body(Body::empty()).unwrap();
        request.headers_mut().append(COOKIE, dummy_cookie);
        request.headers_mut().append(COOKIE, session_cookie);
        let res = service.ready().await.unwrap().call(request).await.unwrap();
        assert_eq!(res.status(), StatusCode::OK);

        let json_bs = &hyper::body::to_bytes(res.into_body()).await.unwrap()[..];
        let counter: Counter = serde_json::from_slice(json_bs).unwrap();
        assert_eq!(counter, Counter { counter: 1 });
    }

    // Under `ChangedOnly`, a handler that never touches the session yields no cookie.
    #[tokio::test]
    async fn no_cookie_stored_when_no_session_is_required() {
        let secret = rand::thread_rng().gen::<[u8; 64]>();
        let store = MemoryStore::new();
        let session_layer = SessionLayer::new(store, &secret)
            .with_persistence_policy(PersistencePolicy::ChangedOnly);
        let mut service = ServiceBuilder::new().layer(session_layer).service_fn(echo);

        let request = Request::get("/").body(Body::empty()).unwrap();

        let res = service.ready().await.unwrap().call(request).await.unwrap();
        assert_eq!(res.status(), StatusCode::OK);

        assert!(res.headers().get(SET_COOKIE).is_none());
    }

    // Shared driver: sends a cookie-less request that only reads the session, then a request
    // with an invalid `sid` cookie (optionally changing session data). It asserts whether each
    // response carries `Set-Cookie` for the given policy.
    async fn invalid_session_check_cookie_result(
        persistence_policy: PersistencePolicy,
        change_data: bool,
        expect_cookie_header: (ExpectedResult, ExpectedResult),
    ) {
        let (expect_cookie_header_first, expect_cookie_header_second) = expect_cookie_header;
        let secret = rand::thread_rng().gen::<[u8; 64]>();
        let store = MemoryStore::new();
        let session_layer =
            SessionLayer::new(store, &secret).with_persistence_policy(persistence_policy);
        // Layered by reference so `session_layer` can be reused for the second service below.
        let mut service = ServiceBuilder::new()
            .layer(&session_layer)
            .service_fn(echo_read_session);

        let request = Request::get("/").body(Body::empty()).unwrap();

        let res = service.ready().await.unwrap().call(request).await.unwrap();
        assert_eq!(res.status(), StatusCode::OK);

        match expect_cookie_header_first {
            ExpectedResult::Some => assert!(
                res.headers().get(SET_COOKIE).is_some(),
                "Set-Cookie must be present for first response"
            ),
            ExpectedResult::None => assert!(
                res.headers().get(SET_COOKIE).is_none(),
                "Set-Cookie must not be present for first response"
            ),
        }

        let mut service =
            ServiceBuilder::new()
                .layer(session_layer)
                .service_fn(move |req| async move {
                    if change_data {
                        echo_with_session_change(req).await
                    } else {
                        echo_read_session(req).await
                    }
                });
        // The value is base64 of "invalid-session-id": 24 chars, shorter than the 44-char
        // signature prefix. Verification therefore fails and the request is handled as if it
        // carried no session cookie.
        let mut request = Request::get("/").body(Body::empty()).unwrap();
        request
            .headers_mut()
            .insert(COOKIE, "sid=aW52YWxpZC1zZXNzaW9uLWlk".parse().unwrap());
        let res = service.ready().await.unwrap().call(request).await.unwrap();
        match expect_cookie_header_second {
            ExpectedResult::Some => assert!(
                res.headers().get(SET_COOKIE).is_some(),
                "Set-Cookie must be present for second response"
            ),
            ExpectedResult::None => assert!(
                res.headers().get(SET_COOKIE).is_none(),
                "Set-Cookie must not be present for second response"
            ),
        }
    }

    #[tokio::test]
    async fn invalid_session_always_sets_guest_cookie() {
        invalid_session_check_cookie_result(
            PersistencePolicy::Always,
            false,
            (ExpectedResult::Some, ExpectedResult::Some),
        )
        .await;
    }

    #[tokio::test]
    async fn invalid_session_sets_new_session_cookie_when_data_changes() {
        invalid_session_check_cookie_result(
            PersistencePolicy::ExistingOnly,
            true,
            (ExpectedResult::None, ExpectedResult::Some),
        )
        .await;
    }

    #[tokio::test]
    async fn invalid_session_sets_no_cookie_when_no_data_changes() {
        invalid_session_check_cookie_result(
            PersistencePolicy::ExistingOnly,
            false,
            (ExpectedResult::None, ExpectedResult::None),
        )
        .await;
    }

    #[tokio::test]
    async fn invalid_session_changedonly_sets_cookie_when_changed() {
        invalid_session_check_cookie_result(
            PersistencePolicy::ChangedOnly,
            true,
            (ExpectedResult::None, ExpectedResult::Some),
        )
        .await;
    }

    // The first request (no cookie) creates and stores a session. The second sends that
    // cookie, the `destroy` handler destroys the session, and the response must carry the
    // removal cookie.
    #[tokio::test]
    async fn destroyed_sessions_sets_removal_cookie() {
        let secret = rand::thread_rng().gen::<[u8; 64]>();
        let store = MemoryStore::new();
        let session_layer = SessionLayer::new(store, &secret);
        let mut service = ServiceBuilder::new()
            .layer(session_layer)
            .service_fn(destroy);

        let request = Request::get("/").body(Body::empty()).unwrap();

        let res = service.ready().await.unwrap().call(request).await.unwrap();
        assert_eq!(res.status(), StatusCode::OK);

        let session_cookie = res
            .headers()
            .get(SET_COOKIE)
            .unwrap()
            .to_str()
            .unwrap()
            .to_string();
        let mut request = Request::get("/destroy").body(Body::empty()).unwrap();
        request
            .headers_mut()
            .insert(COOKIE, session_cookie.parse().unwrap());
        let res = service.ready().await.unwrap().call(request).await.unwrap();
        // The length is fixed: `sid=` plus the 44-char signature of the empty value, followed by
        // HttpOnly, Path, Max-Age=0 and a fixed-width Expires date. It must be updated if the
        // removal cookie's attributes change.
        assert_eq!(
            res.headers()
                .get(SET_COOKIE)
                .unwrap()
                .to_str()
                .unwrap()
                .len(),
            116
        );
    }

    // `SessionLayer::new` rejects secrets shorter than 64 bytes by panicking.
    #[test]
    #[should_panic]
    fn too_short_secret() {
        let store = MemoryStore::new();
        SessionLayer::new(store, b"");
    }

    // Handler that ignores the session and echoes the request body.
    async fn echo(req: Request<Body>) -> Result<Response<Body>, BoxError> {
        Ok(Response::new(req.into_body()))
    }

    // Handler that reads (but does not change) the session, then echoes the body. The inner
    // scope ends the borrow of `req` before `req.into_body()`.
    async fn echo_read_session(req: Request<Body>) -> Result<Response<Body>, BoxError> {
        {
            let session_handle = req.extensions().get::<SessionHandle>().unwrap();
            let session = session_handle.write().await;
            let _ = session.get::<String>("signed_in").unwrap_or_default();
        }
        Ok(Response::new(req.into_body()))
    }

    // Handler that writes `signed_in = true` into the session, marking its data as changed.
    async fn echo_with_session_change(req: Request<Body>) -> Result<Response<Body>, BoxError> {
        {
            let session_handle = req.extensions().get::<SessionHandle>().unwrap();
            let mut session = session_handle.write().await;
            session.insert("signed_in", true).unwrap();
        }
        Ok(Response::new(req.into_body()))
    }

    // Handler that destroys the session, but only when the request carried a `Cookie` header.
    async fn destroy(req: Request<Body>) -> Result<Response<Body>, BoxError> {
        if req.headers().get(COOKIE).is_some() {
            let session_handle = req.extensions().get::<SessionHandle>().unwrap();
            let mut session = session_handle.write().await;
            session.destroy();
        }

        Ok(Response::new(req.into_body()))
    }

    // Handler that stores `counter` in the session (0 when absent, otherwise the previous value
    // + 1) and returns it as a JSON `Counter` body.
    async fn increment(mut req: Request<Body>) -> Result<Response<Body>, BoxError> {
        let mut counter = 0;

        {
            let session_handle = req.extensions().get::<SessionHandle>().unwrap();
            let mut session = session_handle.write().await;
            counter = session
                .get("counter")
                .map(|count: i32| count + 1)
                .unwrap_or(counter);
            session.insert("counter", counter).unwrap();
        }

        let body = serde_json::to_string(&Counter { counter }).unwrap();
        *req.body_mut() = Body::from(body);

        Ok(Response::new(req.into_body()))
    }
}