1use std::collections::HashMap;
65use std::future::Future;
66use std::pin::Pin;
67use std::sync::Arc;
68use std::task::{Context, Poll};
69
70use axum::extract::{FromRequestParts, Request};
71use axum::response::{IntoResponse, Response};
72use http::HeaderValue;
73use http::StatusCode;
74use http::header::{COOKIE, SET_COOKIE};
75use http::request::Parts;
76use thiserror::Error;
77use tokio::sync::RwLock;
78use tower::{Layer, Service};
79use uuid::Uuid;
80
/// Handle to the current request's session data.
///
/// Cloning is cheap: every clone shares the same state through an `Arc<RwLock<_>>`, so
/// changes made by a handler are visible to [`SessionService`], which inspects the shared
/// state after the handler returns to decide whether to save, destroy, or leave the session.
#[derive(Clone, Debug)]
pub struct Session {
    inner: Arc<RwLock<SessionInner>>,
}
94
/// Mutable state shared by all clones of a [`Session`].
#[derive(Debug)]
struct SessionInner {
    // Current session id; replaced by `Session::rotate_id`.
    id: String,
    // First id in effect before any `rotate_id` call; the middleware deletes it from the
    // store when persisting the rotated session.
    old_id: Option<String>,
    // Key/value payload persisted through the `SessionStore`.
    data: HashMap<String, String>,
    // Set by every mutation; the middleware only saves (and sets a cookie) when this is true.
    dirty: bool,
    // Set by `Session::destroy`; the middleware deletes the stored session and expires the cookie.
    destroyed: bool,
}
103
104impl Session {
105 #[doc(hidden)]
107 #[must_use]
108 pub fn new_for_test(id: String, data: HashMap<String, String>) -> Self {
109 Self::new(id, data)
110 }
111
112 fn new(id: String, data: HashMap<String, String>) -> Self {
113 Self {
114 inner: Arc::new(RwLock::new(SessionInner {
115 id,
116 old_id: None,
117 data,
118 dirty: false,
119 destroyed: false,
120 })),
121 }
122 }
123
124 pub async fn id(&self) -> String {
126 self.inner.read().await.id.clone()
127 }
128
129 pub async fn get(&self, key: &str) -> Option<String> {
131 self.inner.read().await.data.get(key).cloned()
132 }
133
134 pub async fn insert(&self, key: impl Into<String>, value: impl Into<String>) {
136 let mut inner = self.inner.write().await;
137 inner.data.insert(key.into(), value.into());
138 inner.dirty = true;
139 }
140
141 pub async fn remove(&self, key: &str) -> Option<String> {
143 let mut inner = self.inner.write().await;
144 let val = inner.data.remove(key);
145 if val.is_some() {
146 inner.dirty = true;
147 }
148 val
149 }
150
151 pub async fn clear(&self) {
153 let mut inner = self.inner.write().await;
154 inner.data.clear();
155 inner.dirty = true;
156 }
157
158 pub async fn rotate_id(&self) {
163 let mut inner = self.inner.write().await;
164 let new_id = Uuid::new_v4().to_string();
165 if inner.old_id.is_none() {
166 inner.old_id = Some(inner.id.clone());
167 }
168 inner.id = new_id;
169 inner.dirty = true;
170 }
171
172 pub async fn destroy(&self) {
175 let mut inner = self.inner.write().await;
176 inner.data.clear();
177 inner.destroyed = true;
178 inner.dirty = true;
179 }
180
181 pub async fn contains_key(&self, key: &str) -> bool {
183 self.inner.read().await.data.contains_key(key)
184 }
185}
186
187impl<S> FromRequestParts<S> for Session
188where
189 S: Send + Sync,
190{
191 type Rejection = std::convert::Infallible;
192
193 fn from_request_parts(
194 parts: &mut Parts,
195 _state: &S,
196 ) -> impl Future<Output = Result<Self, Self::Rejection>> + Send {
197 let session = parts
198 .extensions
199 .get::<Self>()
200 .cloned()
201 .expect("SessionLayer must be installed to use the Session extractor");
202 async move { Ok(session) }
203 }
204}
205
/// Persistence backend for session data.
///
/// The middleware loads the session named by the request cookie before the handler runs and
/// saves or destroys it afterwards. Any `Err` returned here makes the middleware answer with
/// `503 Service Unavailable` instead of the handler's response.
pub trait SessionStore: Send + Sync + 'static {
    /// Loads the data stored under `id`.
    ///
    /// Returns `Ok(None)` when no session exists under `id`; the middleware then starts a
    /// new session under a freshly generated id.
    fn load(
        &self,
        id: &str,
    ) -> impl Future<Output = Result<Option<HashMap<String, String>>, SessionStoreError>> + Send;

    /// Stores `data` under `id`.
    ///
    /// The middleware always passes the complete session map, so implementations are expected
    /// to replace any previous entry rather than merge into it.
    fn save(
        &self,
        id: &str,
        data: HashMap<String, String>,
    ) -> impl Future<Output = Result<(), SessionStoreError>> + Send;

    /// Deletes the session stored under `id`.
    ///
    /// The middleware calls this for destroyed sessions and for ids rotated away by
    /// `Session::rotate_id`.
    fn destroy(&self, id: &str) -> impl Future<Output = Result<(), SessionStoreError>> + Send;
}
230
/// Error reported by a [`SessionStore`] backend.
///
/// Its `Display` output is the stored message; the middleware logs it and turns it into
/// a `503` response.
#[derive(Debug, Clone, Error, PartialEq, Eq)]
#[error("{message}")]
pub struct SessionStoreError {
    // Human-readable description, e.g. "load failed: <cause>".
    message: String,
}
237
238impl SessionStoreError {
239 #[must_use]
241 pub fn backend(operation: &'static str, error: impl std::fmt::Display) -> Self {
242 Self {
243 message: format!("{operation} failed: {error}"),
244 }
245 }
246}
247
/// In-process [`SessionStore`] backed by a `HashMap` behind an async `RwLock`.
///
/// Clones share the same map through the `Arc`. Data lives only in this process's memory, so it
/// is not shared between instances and does not survive a restart. This is why
/// `SessionConfig::backend_plan` flags its use under a production profile.
#[derive(Clone, Debug, Default)]
pub struct MemoryStore {
    // session id -> session key/value data
    sessions: Arc<RwLock<HashMap<String, HashMap<String, String>>>>,
}
258
259impl MemoryStore {
260 #[must_use]
262 pub fn new() -> Self {
263 Self::default()
264 }
265}
266
267impl SessionStore for MemoryStore {
268 async fn load(&self, id: &str) -> Result<Option<HashMap<String, String>>, SessionStoreError> {
269 Ok(self.sessions.read().await.get(id).cloned())
270 }
271
272 async fn save(&self, id: &str, data: HashMap<String, String>) -> Result<(), SessionStoreError> {
273 self.sessions.write().await.insert(id.to_owned(), data);
274 Ok(())
275 }
276
277 async fn destroy(&self, id: &str) -> Result<(), SessionStoreError> {
278 self.sessions.write().await.remove(id);
279 Ok(())
280 }
281}
282
/// Type-erased future returned by [`BoxedSessionStore::boxed_load`], borrowing for `'a`.
pub(crate) type BoxedLoadFuture<'a> = Pin<
    Box<
        dyn Future<Output = Result<Option<HashMap<String, String>>, SessionStoreError>> + Send + 'a,
    >,
>;
/// Type-erased future returned by the unit-result methods of [`BoxedSessionStore`].
pub(crate) type BoxedUnitFuture<'a> =
    Pin<Box<dyn Future<Output = Result<(), SessionStoreError>> + Send + 'a>>;
301
/// Object-safe mirror of [`SessionStore`].
///
/// `SessionStore` returns `impl Future` from its methods, which makes it unusable as
/// `dyn SessionStore`. This trait returns boxed futures instead, so a custom store can be held
/// as `Arc<dyn BoxedSessionStore>`. Every `SessionStore` gets it through the blanket impl below.
pub(crate) trait BoxedSessionStore: Send + Sync + 'static {
    /// Boxed form of [`SessionStore::load`].
    fn boxed_load<'a>(&'a self, id: &'a str) -> BoxedLoadFuture<'a>;

    /// Boxed form of [`SessionStore::save`].
    fn boxed_save<'a>(&'a self, id: &'a str, data: HashMap<String, String>) -> BoxedUnitFuture<'a>;

    /// Boxed form of [`SessionStore::destroy`].
    fn boxed_destroy<'a>(&'a self, id: &'a str) -> BoxedUnitFuture<'a>;
}
309
310impl<S: SessionStore> BoxedSessionStore for S {
311 fn boxed_load<'a>(&'a self, id: &'a str) -> BoxedLoadFuture<'a> {
312 Box::pin(SessionStore::load(self, id))
313 }
314
315 fn boxed_save<'a>(&'a self, id: &'a str, data: HashMap<String, String>) -> BoxedUnitFuture<'a> {
316 Box::pin(SessionStore::save(self, id, data))
317 }
318
319 fn boxed_destroy<'a>(&'a self, id: &'a str) -> BoxedUnitFuture<'a> {
320 Box::pin(SessionStore::destroy(self, id))
321 }
322}
323
/// Adapter that turns a type-erased `Arc<dyn BoxedSessionStore>` back into a [`SessionStore`].
///
/// This lets a custom store installed at runtime be plugged into [`SessionLayer`].
#[derive(Clone)]
pub(crate) struct ArcSessionStore(pub(crate) Arc<dyn BoxedSessionStore>);
326
327impl SessionStore for ArcSessionStore {
328 async fn load(&self, id: &str) -> Result<Option<HashMap<String, String>>, SessionStoreError> {
329 self.0.boxed_load(id).await
330 }
331
332 async fn save(&self, id: &str, data: HashMap<String, String>) -> Result<(), SessionStoreError> {
333 self.0.boxed_save(id, data).await
334 }
335
336 async fn destroy(&self, id: &str) -> Result<(), SessionStoreError> {
337 self.0.boxed_destroy(id).await
338 }
339}
340
/// Settings for the session middleware (the `session.*` configuration keys).
///
/// Every field has a serde default, so an empty section deserializes to the same values as
/// [`SessionConfig::default`].
#[derive(Debug, Clone, serde::Deserialize)]
pub struct SessionConfig {
    /// Storage backend; defaults to [`SessionBackend::Memory`].
    #[serde(default)]
    pub backend: SessionBackend,

    /// Name of the session cookie (default `"autumn.sid"`).
    #[serde(default = "default_cookie_name")]
    pub cookie_name: String,

    /// `Max-Age` of the session cookie, in seconds (default `86400`, one day).
    #[serde(default = "default_max_age_secs")]
    pub max_age_secs: u64,

    /// Whether the session cookie carries the `Secure` attribute (default `true`).
    #[serde(default = "default_true")]
    pub secure: bool,

    /// Value written verbatim into the cookie's `SameSite` attribute (default `"Lax"`).
    ///
    /// NOTE(review): not validated here; browsers reject `SameSite=None` cookies that are
    /// not also `Secure`.
    #[serde(default = "default_same_site")]
    pub same_site: String,

    /// Whether the session cookie carries the `HttpOnly` attribute (default `true`).
    #[serde(default = "default_true")]
    pub http_only: bool,

    /// `Path` attribute of the session cookie (default `"/"`).
    #[serde(default = "default_path")]
    pub path: String,

    /// Silences the warning logged when the memory backend runs under a `prod`/`production`
    /// profile (default `false`).
    #[serde(default)]
    pub allow_memory_in_production: bool,

    /// Connection settings used when `backend` is `redis`.
    #[serde(default)]
    pub redis: SessionRedisConfig,
}
394
/// Which [`SessionStore`] implementation the config-driven setup installs.
///
/// Deserialized from lowercase names (`"memory"`, `"redis"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, serde::Deserialize)]
#[serde(rename_all = "lowercase")]
#[non_exhaustive]
pub enum SessionBackend {
    /// In-process [`MemoryStore`]; the default.
    #[default]
    Memory,
    /// Redis-backed store; needs the `redis` cargo feature and `session.redis.url`.
    Redis,
}
406
407impl SessionBackend {
408 pub(crate) fn from_env_value(value: &str) -> Option<Self> {
409 match value.trim().to_ascii_lowercase().as_str() {
410 "memory" => Some(Self::Memory),
411 "redis" => Some(Self::Redis),
412 _ => None,
413 }
414 }
415}
416
/// Redis settings for the session backend (`session.redis.*`).
#[derive(Debug, Clone, serde::Deserialize)]
pub struct SessionRedisConfig {
    /// Connection URL; required (non-blank) when `session.backend=redis`.
    #[serde(default)]
    pub url: Option<String>,

    /// Key prefix for stored sessions (default `"autumn:sessions"`).
    #[serde(default = "default_redis_key_prefix")]
    pub key_prefix: String,
}
428
429impl Default for SessionRedisConfig {
430 fn default() -> Self {
431 Self {
432 url: None,
433 key_prefix: default_redis_key_prefix(),
434 }
435 }
436}
437
/// Serde default for `session.redis.key_prefix`.
fn default_redis_key_prefix() -> String {
    String::from("autumn:sessions")
}
441
/// Validated backend choice produced by [`SessionConfig::backend_plan`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SessionBackendPlan {
    /// Use the in-memory store.
    Memory {
        /// True when running under a `prod`/`production` profile without
        /// `allow_memory_in_production`; the caller should log a warning.
        warn_in_production: bool,
    },
    /// Use the Redis store.
    Redis {
        /// Non-blank Redis URL that `redis::Client::open` accepted.
        url: String,
        /// Prefix for session keys.
        key_prefix: String,
    },
}
458
/// Reasons the session backend configuration cannot be turned into a working store.
#[derive(Debug, Error, PartialEq, Eq)]
#[non_exhaustive]
pub enum SessionBackendConfigError {
    /// `backend = "redis"` but `redis.url` is unset or blank.
    #[error("session.backend=redis requires session.redis.url")]
    MissingRedisUrl,
    /// The Redis client rejected the URL; carries the client's error message.
    #[error("session.redis.url is not a valid Redis URL: {0}")]
    InvalidRedisUrl(String),
    /// Redis was requested but the crate was built without the `redis` feature.
    #[error("session.backend=redis requires the `redis` feature")]
    RedisFeatureDisabled,
}
473
/// Serde default for `session.cookie_name`.
fn default_cookie_name() -> String {
    String::from("autumn.sid")
}

/// Serde default for `session.max_age_secs`: one day.
const fn default_max_age_secs() -> u64 {
    24 * 60 * 60
}

/// Serde default for `session.same_site`.
fn default_same_site() -> String {
    String::from("Lax")
}

/// Serde default for boolean flags that are on unless explicitly disabled.
const fn default_true() -> bool {
    true
}

/// Serde default for `session.path`.
fn default_path() -> String {
    String::from("/")
}
489
490impl Default for SessionConfig {
491 fn default() -> Self {
492 Self {
493 backend: SessionBackend::default(),
494 cookie_name: default_cookie_name(),
495 max_age_secs: default_max_age_secs(),
496 secure: true,
497 same_site: default_same_site(),
498 http_only: default_true(),
499 path: default_path(),
500 allow_memory_in_production: false,
501 redis: SessionRedisConfig::default(),
502 }
503 }
504}
505
506impl SessionConfig {
507 pub fn backend_plan(
514 &self,
515 profile: Option<&str>,
516 ) -> Result<SessionBackendPlan, SessionBackendConfigError> {
517 match self.backend {
518 SessionBackend::Memory => Ok(SessionBackendPlan::Memory {
519 warn_in_production: is_production_profile(profile)
520 && !self.allow_memory_in_production,
521 }),
522 SessionBackend::Redis => {
523 let Some(url) = self.redis.url.clone().filter(|url| !url.trim().is_empty()) else {
524 return Err(SessionBackendConfigError::MissingRedisUrl);
525 };
526
527 #[cfg(feature = "redis")]
528 {
529 if let Err(error) = redis::Client::open(url.clone()) {
530 return Err(SessionBackendConfigError::InvalidRedisUrl(
531 error.to_string(),
532 ));
533 }
534
535 Ok(SessionBackendPlan::Redis {
536 url,
537 key_prefix: self.redis.key_prefix.clone(),
538 })
539 }
540
541 #[cfg(not(feature = "redis"))]
542 {
543 let _ = url;
544 Err(SessionBackendConfigError::RedisFeatureDisabled)
545 }
546 }
547 }
548 }
549}
550
/// True when the active profile is exactly `prod` or `production` (case-sensitive).
fn is_production_profile(profile: Option<&str>) -> bool {
    profile.is_some_and(|name| name == "prod" || name == "production")
}
554
555fn get_cookie(headers: &http::HeaderMap, name: &str) -> Option<String> {
559 let mut found_token = None;
560
561 for cookie_header in headers.get_all(COOKIE) {
562 let Ok(cookie_str) = cookie_header.to_str() else {
563 continue;
564 };
565
566 for pair in cookie_str.split(';') {
567 let pair = pair.trim();
568 let Some((k, v)) = pair.split_once('=') else {
569 continue;
570 };
571
572 if k.trim() != name {
573 continue;
574 }
575
576 if found_token.is_some() {
577 return None;
581 }
582
583 found_token = Some(v.trim().to_owned());
584 }
585 }
586 found_token
587}
588
589fn build_set_cookie(config: &SessionConfig, session_id: &str) -> String {
591 use std::fmt::Write;
592 let mut cookie = format!(
593 "{}={}; Path={}",
594 config.cookie_name, session_id, config.path
595 );
596 let _ = write!(cookie, "; Max-Age={}", config.max_age_secs);
597 if config.http_only {
598 cookie.push_str("; HttpOnly");
599 }
600 if config.secure {
601 cookie.push_str("; Secure");
602 }
603 let _ = write!(cookie, "; SameSite={}", config.same_site);
604 cookie
605}
606
607fn build_expire_cookie(config: &SessionConfig) -> String {
609 format!(
610 "{}=; Path={}; Max-Age=0; HttpOnly; SameSite={}",
611 config.cookie_name, config.path, config.same_site
612 )
613}
614
/// Tower layer that loads a [`Session`] before each request and persists it afterwards.
///
/// With signing keys configured, the cookie value is `{id}.{signature}` and cookies whose
/// signature does not verify are treated as absent.
#[derive(Clone)]
pub struct SessionLayer<S: SessionStore> {
    // Backend shared by every service instance built from this layer.
    store: Arc<S>,
    // Cookie and backend settings.
    config: Arc<SessionConfig>,
    // Optional keys used to sign/verify the session cookie.
    signing_keys: Option<Arc<crate::security::config::ResolvedSigningKeys>>,
}
635
636impl<S: SessionStore> SessionLayer<S> {
637 pub fn new(store: S, config: SessionConfig) -> Self {
639 Self {
640 store: Arc::new(store),
641 config: Arc::new(config),
642 signing_keys: None,
643 }
644 }
645
646 #[must_use]
653 pub fn with_signing_keys(
654 mut self,
655 keys: Arc<crate::security::config::ResolvedSigningKeys>,
656 ) -> Self {
657 self.signing_keys = Some(keys);
658 self
659 }
660}
661
662impl<S: SessionStore + Clone, Inner> Layer<Inner> for SessionLayer<S> {
663 type Service = SessionService<S, Inner>;
664
665 fn layer(&self, inner: Inner) -> Self::Service {
666 SessionService {
667 inner,
668 store: Arc::clone(&self.store),
669 config: Arc::clone(&self.config),
670 signing_keys: self.signing_keys.clone(),
671 }
672 }
673}
674
/// Middleware service produced by [`SessionLayer`].
///
/// For each request it resolves the session from the cookie, exposes it as a request
/// extension, calls `inner`, then saves or destroys the session and sets the matching cookie.
#[derive(Clone)]
pub struct SessionService<S: SessionStore, Inner> {
    // Wrapped service (the rest of the middleware stack / router).
    inner: Inner,
    store: Arc<S>,
    config: Arc<SessionConfig>,
    signing_keys: Option<Arc<crate::security::config::ResolvedSigningKeys>>,
}
683
684impl<St, Inner> Service<Request> for SessionService<St, Inner>
685where
686 St: SessionStore + Clone,
687 Inner: Service<Request, Response = Response> + Clone + Send + 'static,
688 Inner::Future: Send + 'static,
689 Inner::Error: Send + 'static,
690{
691 type Response = Response;
692 type Error = Inner::Error;
693 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
694
695 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
696 self.inner.poll_ready(cx)
697 }
698
699 fn call(&mut self, mut req: Request) -> Self::Future {
700 let store = Arc::clone(&self.store);
701 let config = Arc::clone(&self.config);
702 let signing_keys = self.signing_keys.clone();
703 let mut inner = self.inner.clone();
704 std::mem::swap(&mut self.inner, &mut inner);
706
707 Box::pin(async move {
708 let raw_cookie = get_cookie(req.headers(), &config.cookie_name);
710 let existing_id: Option<String> = match (raw_cookie, &signing_keys) {
711 (None, _) => None,
712 (Some(raw), None) => Some(raw),
713 (Some(raw), Some(keys)) => {
714 if let Some((id, sig)) = raw.split_once('.') {
716 if keys.verify(id.as_bytes(), sig) {
717 Some(id.to_owned())
718 } else {
719 None }
721 } else {
722 None }
724 }
725 };
726
727 let (session_id, data) = if let Some(ref id) = existing_id {
728 match store.load(id).await {
729 Ok(Some(data)) => (id.clone(), data),
730 Ok(None) => (Uuid::new_v4().to_string(), HashMap::new()),
731 Err(error) => return Ok(session_store_unavailable_response(&error)),
732 }
733 } else {
734 (Uuid::new_v4().to_string(), HashMap::new())
735 };
736
737 let session = Session::new(session_id.clone(), data);
739 req.extensions_mut().insert(session.clone());
740
741 let mut response = inner.call(req).await?;
743
744 let inner_guard = session.inner.read().await;
746 if inner_guard.destroyed {
747 if let Err(error) = store.destroy(&session_id).await {
748 return Ok(session_store_unavailable_response(&error));
749 }
750 if let Ok(val) = HeaderValue::from_str(&build_expire_cookie(&config)) {
751 response.headers_mut().append(SET_COOKIE, val);
752 }
753 } else if inner_guard.dirty {
754 let data = inner_guard.data.clone();
755 let sid = inner_guard.id.clone();
756 if let Some(ref old_id) = inner_guard.old_id
757 && let Err(error) = store.destroy(old_id).await
758 {
759 return Ok(session_store_unavailable_response(&error));
760 }
761 drop(inner_guard);
762 if let Err(error) = store.save(&sid, data).await {
763 return Ok(session_store_unavailable_response(&error));
764 }
765 let cookie_value = signing_keys.as_ref().map_or_else(
767 || sid.clone(),
768 |keys| format!("{sid}.{}", keys.sign(sid.as_bytes())),
769 );
770 if let Ok(val) = HeaderValue::from_str(&build_set_cookie(&config, &cookie_value)) {
771 response.headers_mut().append(SET_COOKIE, val);
772 }
773 }
774
775 Ok(response)
776 })
777 }
778}
779
780fn session_store_unavailable_response(error: &SessionStoreError) -> Response {
781 tracing::error!("session store unavailable: {error}");
782 (StatusCode::SERVICE_UNAVAILABLE, "Session store unavailable").into_response()
783}
784
785pub(crate) fn apply_session_layer<S: Clone + Send + Sync + 'static>(
786 router: axum::Router<S>,
787 config: &SessionConfig,
788 profile: Option<&str>,
789 custom_store: Option<Arc<dyn BoxedSessionStore>>,
790 signing_keys: Option<Arc<crate::security::config::ResolvedSigningKeys>>,
791) -> Result<axum::Router<S>, SessionBackendConfigError> {
792 if let Some(store) = custom_store {
793 tracing::debug!(
794 "Custom session store installed via with_session_store(); skipping config-driven backend selection"
795 );
796 let mut layer = SessionLayer::new(ArcSessionStore(store), config.clone());
797 if let Some(keys) = signing_keys {
798 layer = layer.with_signing_keys(keys);
799 }
800 return Ok(router.layer(layer));
801 }
802
803 match config.backend_plan(profile)? {
804 SessionBackendPlan::Memory { warn_in_production } => {
805 if warn_in_production {
806 tracing::warn!(
807 "prod profile is using in-memory sessions; set session.backend=redis or \
808 session.allow_memory_in_production=true to acknowledge the risk"
809 );
810 }
811 let mut layer = SessionLayer::new(MemoryStore::new(), config.clone());
812 if let Some(keys) = signing_keys {
813 layer = layer.with_signing_keys(keys);
814 }
815 Ok(router.layer(layer))
816 }
817 SessionBackendPlan::Redis { .. } => {
818 #[cfg(feature = "redis")]
819 {
820 let store = crate::session_redis::RedisStore::from_config(config)?;
821 let mut layer = SessionLayer::new(store, config.clone());
822 if let Some(keys) = signing_keys {
823 layer = layer.with_signing_keys(keys);
824 }
825 Ok(router.layer(layer))
826 }
827
828 #[cfg(not(feature = "redis"))]
829 {
830 let _ = router;
831 Err(SessionBackendConfigError::RedisFeatureDisabled)
832 }
833 }
834 }
835}
836
837#[cfg(test)]
838mod tests {
839 use super::*;
840 use axum::Router;
841 use axum::body::Body;
842 use axum::routing::get;
843 use http::Request as HttpRequest;
844 use tower::ServiceExt;
845
846 #[derive(Clone, Default)]
850 struct SentinelStore {
851 load_calls: Arc<RwLock<u32>>,
852 }
853
854 impl SessionStore for SentinelStore {
855 async fn load(
856 &self,
857 _id: &str,
858 ) -> Result<Option<HashMap<String, String>>, SessionStoreError> {
859 *self.load_calls.write().await += 1;
860 let mut data = HashMap::new();
863 data.insert("from".to_owned(), "sentinel".to_owned());
864 Ok(Some(data))
865 }
866
867 async fn save(
868 &self,
869 _id: &str,
870 _data: HashMap<String, String>,
871 ) -> Result<(), SessionStoreError> {
872 Ok(())
873 }
874
875 async fn destroy(&self, _id: &str) -> Result<(), SessionStoreError> {
876 Ok(())
877 }
878 }
879
880 #[tokio::test]
881 async fn arc_session_store_wrapper_delegates_to_inner_session_store() {
882 let inner = SentinelStore::default();
883 let load_counter = inner.load_calls.clone();
884 let arc: Arc<dyn BoxedSessionStore> = Arc::new(inner);
885 let wrapper = ArcSessionStore(arc);
886
887 let result = wrapper
888 .load("session-id")
889 .await
890 .expect("wrapped store should succeed");
891
892 assert_eq!(*load_counter.read().await, 1);
893 assert_eq!(
894 result
895 .as_ref()
896 .and_then(|m| m.get("from"))
897 .map(String::as_str),
898 Some("sentinel"),
899 "wrapper must return data from the wrapped impl, not a default"
900 );
901 }
902
903 #[tokio::test]
904 async fn boxed_session_store_blanket_impl_works_for_any_session_store() {
905 let store = SentinelStore::default();
907 let boxed: Arc<dyn BoxedSessionStore> = Arc::new(store);
908 let result = boxed.boxed_load("session-id").await.unwrap();
909 assert!(result.is_some());
910 }
911
912 #[derive(Clone)]
913 struct FailingStore {
914 fail_on_load: bool,
915 fail_on_save: bool,
916 fail_on_destroy: bool,
917 }
918
919 impl SessionStore for FailingStore {
920 async fn load(
921 &self,
922 _id: &str,
923 ) -> Result<Option<HashMap<String, String>>, SessionStoreError> {
924 if self.fail_on_load {
925 Err(SessionStoreError::backend("load", "boom"))
926 } else {
927 Ok(None)
928 }
929 }
930
931 async fn save(
932 &self,
933 _id: &str,
934 _data: HashMap<String, String>,
935 ) -> Result<(), SessionStoreError> {
936 if self.fail_on_save {
937 Err(SessionStoreError::backend("save", "boom"))
938 } else {
939 Ok(())
940 }
941 }
942
943 async fn destroy(&self, _id: &str) -> Result<(), SessionStoreError> {
944 if self.fail_on_destroy {
945 Err(SessionStoreError::backend("destroy", "boom"))
946 } else {
947 Ok(())
948 }
949 }
950 }
951
952 #[tokio::test]
953 async fn memory_store_save_and_load() {
954 let store = MemoryStore::new();
955 let mut data = HashMap::new();
956 data.insert("user".into(), "alice".into());
957 store.save("sess1", data).await.unwrap();
958
959 let loaded = store.load("sess1").await.unwrap();
960 assert!(loaded.is_some());
961 assert_eq!(loaded.unwrap().get("user").unwrap(), "alice");
962 }
963
964 #[tokio::test]
965 async fn memory_store_destroy() {
966 let store = MemoryStore::new();
967 store.save("sess1", HashMap::new()).await.unwrap();
968 store.destroy("sess1").await.unwrap();
969 assert!(store.load("sess1").await.unwrap().is_none());
970 }
971
972 #[tokio::test]
973 async fn memory_store_load_missing() {
974 let store = MemoryStore::new();
975 assert!(store.load("nonexistent").await.unwrap().is_none());
976 }
977
978 #[tokio::test]
979 async fn session_insert_and_get() {
980 let session = Session::new("test".into(), HashMap::new());
981 session.insert("key", "value").await;
982 assert_eq!(session.get("key").await, Some("value".to_owned()));
983 }
984
985 #[tokio::test]
986 async fn session_remove() {
987 let mut data = HashMap::new();
988 data.insert("key".into(), "value".into());
989 let session = Session::new("test".into(), data);
990 let removed = session.remove("key").await;
991 assert_eq!(removed, Some("value".to_owned()));
992 assert!(session.get("key").await.is_none());
993 }
994
995 #[tokio::test]
996 async fn session_clear() {
997 let mut data = HashMap::new();
998 data.insert("a".into(), "1".into());
999 data.insert("b".into(), "2".into());
1000 let session = Session::new("test".into(), data);
1001 session.clear().await;
1002 assert!(session.get("a").await.is_none());
1003 assert!(session.get("b").await.is_none());
1004 }
1005
1006 #[tokio::test]
1007 async fn session_contains_key() {
1008 let mut data = HashMap::new();
1009 data.insert("exists".into(), "yes".into());
1010 let session = Session::new("test".into(), data);
1011 assert!(session.contains_key("exists").await);
1012 assert!(!session.contains_key("missing").await);
1013 }
1014
1015 #[tokio::test]
1016 async fn session_destroy_marks_destroyed() {
1017 let session = Session::new("test".into(), HashMap::new());
1018 session.insert("key", "value").await;
1019 session.destroy().await;
1020 let inner = session.inner.read().await;
1021 let destroyed = inner.destroyed;
1022 let empty = inner.data.is_empty();
1023 drop(inner);
1024 assert!(destroyed);
1025 assert!(empty);
1026 }
1027
1028 #[test]
1029 fn get_cookie_extracts_value() {
1030 let mut headers = http::HeaderMap::new();
1031 headers.insert(
1032 COOKIE,
1033 HeaderValue::from_static("autumn.sid=abc123; other=xyz"),
1034 );
1035 assert_eq!(get_cookie(&headers, "autumn.sid"), Some("abc123".into()));
1036 assert_eq!(get_cookie(&headers, "other"), Some("xyz".into()));
1037 assert_eq!(get_cookie(&headers, "missing"), None);
1038 }
1039
1040 #[test]
1041 fn get_cookie_rejects_multiple_cookies() {
1042 let mut headers = http::HeaderMap::new();
1043 headers.insert(
1044 COOKIE,
1045 HeaderValue::from_static("autumn.sid=abc123; autumn.sid=xyz456"),
1046 );
1047 assert_eq!(get_cookie(&headers, "autumn.sid"), None);
1048
1049 let mut headers2 = http::HeaderMap::new();
1050 headers2.append(COOKIE, HeaderValue::from_static("autumn.sid=abc123"));
1051 headers2.append(COOKIE, HeaderValue::from_static("autumn.sid=xyz456"));
1052 assert_eq!(get_cookie(&headers2, "autumn.sid"), None);
1053 }
1054
1055 #[test]
1056 fn build_set_cookie_contains_required_parts() {
1057 let config = SessionConfig::default();
1058 let cookie = build_set_cookie(&config, "test-id");
1059 assert!(cookie.contains("autumn.sid=test-id"));
1060 assert!(cookie.contains("Path=/"));
1061 assert!(cookie.contains("HttpOnly"));
1062 assert!(cookie.contains("SameSite=Lax"));
1063 assert!(cookie.contains("Max-Age=86400"));
1064 }
1065
1066 #[test]
1067 fn build_expire_cookie_has_zero_max_age() {
1068 let config = SessionConfig::default();
1069 let cookie = build_expire_cookie(&config);
1070 assert!(cookie.contains("Max-Age=0"));
1071 }
1072
1073 #[test]
1074 fn session_config_defaults() {
1075 let config = SessionConfig::default();
1076 assert_eq!(config.backend, SessionBackend::Memory);
1077 assert_eq!(config.cookie_name, "autumn.sid");
1078 assert_eq!(config.max_age_secs, 86400);
1079 assert!(config.secure);
1080 assert_eq!(config.same_site, "Lax");
1081 assert!(config.http_only);
1082 assert_eq!(config.path, "/");
1083 assert!(!config.allow_memory_in_production);
1084 assert!(config.redis.url.is_none());
1085 assert_eq!(config.redis.key_prefix, "autumn:sessions");
1086 }
1087
1088 #[test]
1089 fn session_backend_plan_warns_for_prod_memory_without_ack() {
1090 let config = SessionConfig::default();
1091 let plan = config.backend_plan(Some("prod")).unwrap();
1092 assert_eq!(
1093 plan,
1094 SessionBackendPlan::Memory {
1095 warn_in_production: true
1096 }
1097 );
1098 }
1099
1100 #[test]
1101 fn session_backend_plan_suppresses_prod_warning_when_acknowledged() {
1102 let config = SessionConfig {
1103 allow_memory_in_production: true,
1104 ..SessionConfig::default()
1105 };
1106 let plan = config.backend_plan(Some("prod")).unwrap();
1107 assert_eq!(
1108 plan,
1109 SessionBackendPlan::Memory {
1110 warn_in_production: false
1111 }
1112 );
1113 }
1114
1115 #[test]
1116 fn session_backend_plan_requires_redis_url() {
1117 let config = SessionConfig {
1118 backend: SessionBackend::Redis,
1119 ..SessionConfig::default()
1120 };
1121 let error = config.backend_plan(None).unwrap_err();
1122 assert_eq!(error, SessionBackendConfigError::MissingRedisUrl);
1123 }
1124
1125 #[tokio::test]
1126 async fn session_layer_sets_cookie_on_new_session() {
1127 use crate::state::AppState;
1128 async fn handler(session: Session) -> String {
1129 session.insert("visited", "true").await;
1130 "ok".to_owned()
1131 }
1132
1133 let state = AppState {
1134 extensions: std::sync::Arc::new(std::sync::RwLock::new(
1135 std::collections::HashMap::new(),
1136 )),
1137 #[cfg(feature = "db")]
1138 pool: None,
1139 #[cfg(feature = "db")]
1140 replica_pool: None,
1141 profile: None,
1142 started_at: std::time::Instant::now(),
1143 health_detailed: false,
1144 probes: crate::probe::ProbeState::ready_for_test(),
1145 metrics: crate::middleware::MetricsCollector::new(),
1146 log_levels: crate::actuator::LogLevels::new("info"),
1147 task_registry: crate::actuator::TaskRegistry::new(),
1148 job_registry: crate::actuator::JobRegistry::new(),
1149 config_props: crate::actuator::ConfigProperties::default(),
1150 #[cfg(feature = "ws")]
1151 channels: crate::channels::Channels::new(32),
1152 #[cfg(feature = "ws")]
1153 shutdown: tokio_util::sync::CancellationToken::new(),
1154 policy_registry: crate::authorization::PolicyRegistry::default(),
1155 forbidden_response: crate::authorization::ForbiddenResponse::default(),
1156 auth_session_key: "user_id".to_owned(),
1157 shared_cache: None,
1158 };
1159
1160 let app = Router::new()
1161 .route("/", get(handler))
1162 .layer(SessionLayer::new(
1163 MemoryStore::new(),
1164 SessionConfig::default(),
1165 ))
1166 .with_state(state);
1167
1168 let response = app
1169 .oneshot(HttpRequest::builder().uri("/").body(Body::empty()).unwrap())
1170 .await
1171 .unwrap();
1172
1173 assert_eq!(response.status(), http::StatusCode::OK);
1174 let set_cookie = response
1175 .headers()
1176 .get(SET_COOKIE)
1177 .expect("should set session cookie");
1178 let cookie_str = set_cookie.to_str().unwrap();
1179 assert!(cookie_str.contains("autumn.sid="));
1180 }
1181
1182 fn test_state() -> crate::state::AppState {
1183 crate::state::AppState {
1184 extensions: Arc::new(std::sync::RwLock::new(HashMap::new())),
1185 #[cfg(feature = "db")]
1186 pool: None,
1187 #[cfg(feature = "db")]
1188 replica_pool: None,
1189 profile: None,
1190 started_at: std::time::Instant::now(),
1191 health_detailed: false,
1192 probes: crate::probe::ProbeState::ready_for_test(),
1193 metrics: crate::middleware::MetricsCollector::new(),
1194 log_levels: crate::actuator::LogLevels::new("info"),
1195 task_registry: crate::actuator::TaskRegistry::new(),
1196 job_registry: crate::actuator::JobRegistry::new(),
1197 config_props: crate::actuator::ConfigProperties::default(),
1198 #[cfg(feature = "ws")]
1199 channels: crate::channels::Channels::new(32),
1200 #[cfg(feature = "ws")]
1201 shutdown: tokio_util::sync::CancellationToken::new(),
1202 policy_registry: crate::authorization::PolicyRegistry::default(),
1203 forbidden_response: crate::authorization::ForbiddenResponse::default(),
1204 auth_session_key: "user_id".to_owned(),
1205 shared_cache: None,
1206 }
1207 }
1208
1209 #[tokio::test]
1210 async fn session_layer_persists_data_across_requests() {
1211 async fn write_handler(session: Session) -> String {
1212 session.insert("user", "alice").await;
1213 "saved".to_owned()
1214 }
1215
1216 async fn read_handler(session: Session) -> String {
1217 session.get("user").await.unwrap_or_default()
1218 }
1219
1220 let store = MemoryStore::new();
1221 let config = SessionConfig::default();
1222 let state = test_state();
1223
1224 let app = Router::new()
1225 .route("/write", get(write_handler))
1226 .route("/read", get(read_handler))
1227 .layer(SessionLayer::new(store, config))
1228 .with_state(state);
1229
1230 let resp1 = app
1232 .clone()
1233 .oneshot(
1234 HttpRequest::builder()
1235 .uri("/write")
1236 .body(Body::empty())
1237 .unwrap(),
1238 )
1239 .await
1240 .unwrap();
1241
1242 let cookie = resp1
1243 .headers()
1244 .get(SET_COOKIE)
1245 .unwrap()
1246 .to_str()
1247 .unwrap()
1248 .to_owned();
1249 let session_cookie = cookie.split(';').next().unwrap();
1251
1252 let resp2 = app
1254 .oneshot(
1255 HttpRequest::builder()
1256 .uri("/read")
1257 .header(COOKIE, session_cookie)
1258 .body(Body::empty())
1259 .unwrap(),
1260 )
1261 .await
1262 .unwrap();
1263
1264 let body = axum::body::to_bytes(resp2.into_body(), usize::MAX)
1265 .await
1266 .unwrap();
1267 assert_eq!(std::str::from_utf8(&body).unwrap(), "alice");
1268 }
1269
1270 #[tokio::test]
1271 async fn session_destroy_expires_cookie() {
1272 async fn handler(session: Session) -> String {
1273 session.destroy().await;
1274 "destroyed".to_owned()
1275 }
1276
1277 let state = test_state();
1278
1279 let store = MemoryStore::new();
1280 store
1281 .save("existing-id", HashMap::from([("k".into(), "v".into())]))
1282 .await
1283 .unwrap();
1284
1285 let app = Router::new()
1286 .route("/", get(handler))
1287 .layer(SessionLayer::new(store.clone(), SessionConfig::default()))
1288 .with_state(state);
1289
1290 let response = app
1291 .oneshot(
1292 HttpRequest::builder()
1293 .uri("/")
1294 .header(COOKIE, "autumn.sid=existing-id")
1295 .body(Body::empty())
1296 .unwrap(),
1297 )
1298 .await
1299 .unwrap();
1300
1301 let cookie = response
1302 .headers()
1303 .get(SET_COOKIE)
1304 .unwrap()
1305 .to_str()
1306 .unwrap();
1307 assert!(cookie.contains("Max-Age=0"), "cookie should be expired");
1308
1309 assert!(store.load("existing-id").await.unwrap().is_none());
1311 }
1312
1313 #[tokio::test]
1314 async fn session_layer_returns_503_when_store_load_fails() {
1315 let state = test_state();
1316
1317 let app = Router::new()
1318 .route("/", get(|| async { "ok" }))
1319 .layer(SessionLayer::new(
1320 FailingStore {
1321 fail_on_load: true,
1322 fail_on_save: false,
1323 fail_on_destroy: false,
1324 },
1325 SessionConfig::default(),
1326 ))
1327 .with_state(state);
1328
1329 let response = app
1330 .oneshot(
1331 HttpRequest::builder()
1332 .uri("/")
1333 .header(COOKIE, "autumn.sid=existing-id")
1334 .body(Body::empty())
1335 .unwrap(),
1336 )
1337 .await
1338 .unwrap();
1339
1340 assert_eq!(response.status(), StatusCode::SERVICE_UNAVAILABLE);
1341 }
1342
1343 #[tokio::test]
1344 async fn session_layer_returns_503_when_store_save_fails() {
1345 let state = test_state();
1346
1347 let app = Router::new()
1348 .route(
1349 "/",
1350 get(|session: Session| async move {
1351 session.insert("user", "alice").await;
1352 "ok"
1353 }),
1354 )
1355 .layer(SessionLayer::new(
1356 FailingStore {
1357 fail_on_load: false,
1358 fail_on_save: true,
1359 fail_on_destroy: false,
1360 },
1361 SessionConfig::default(),
1362 ))
1363 .with_state(state);
1364
1365 let response = app
1366 .oneshot(HttpRequest::builder().uri("/").body(Body::empty()).unwrap())
1367 .await
1368 .unwrap();
1369
1370 assert_eq!(response.status(), StatusCode::SERVICE_UNAVAILABLE);
1371 }
1372
1373 #[tokio::test]
1376 async fn session_cookie_is_signed_when_signing_keys_set() {
1377 use crate::security::config::{SigningSecretConfig, resolve_signing_keys};
1378 use std::sync::Arc;
1379
1380 let config = SigningSecretConfig {
1381 secret: Some("k".repeat(32)),
1382 previous_secrets: vec![],
1383 };
1384 let keys = Arc::new(resolve_signing_keys(&config));
1385
1386 let app = Router::new()
1387 .route(
1388 "/",
1389 get(|s: Session| async move {
1390 s.insert("k", "v").await;
1391 "ok"
1392 }),
1393 )
1394 .layer(
1395 SessionLayer::new(MemoryStore::new(), SessionConfig::default())
1396 .with_signing_keys(keys),
1397 );
1398
1399 let req = HttpRequest::builder().uri("/").body(Body::empty()).unwrap();
1400 let resp = app.oneshot(req).await.unwrap();
1401
1402 let set_cookie = resp
1403 .headers()
1404 .get("set-cookie")
1405 .expect("should set cookie")
1406 .to_str()
1407 .unwrap();
1408 let cookie_value = set_cookie
1409 .split('=')
1410 .nth(1)
1411 .unwrap()
1412 .split(';')
1413 .next()
1414 .unwrap()
1415 .trim();
1416
1417 assert!(
1418 cookie_value.contains('.'),
1419 "signed session cookie must be {{id}}.{{hmac}}, got: {cookie_value}"
1420 );
1421 let (id_part, sig_part) = cookie_value.split_once('.').unwrap();
1422 assert!(!id_part.is_empty());
1423 assert_eq!(sig_part.len(), 64, "HMAC-SHA256 hex must be 64 chars");
1424 }
1425
1426 #[tokio::test]
1427 async fn session_layer_rejects_tampered_cookie() {
1428 use crate::security::config::{SigningSecretConfig, resolve_signing_keys};
1429 use std::sync::Arc;
1430
1431 let keys = Arc::new(resolve_signing_keys(&SigningSecretConfig {
1432 secret: Some("k".repeat(32)),
1433 previous_secrets: vec![],
1434 }));
1435
1436 let store = MemoryStore::new();
1437 let session_id = Uuid::new_v4().to_string();
1438 let mut data = HashMap::new();
1439 data.insert("user".to_owned(), "alice".to_owned());
1440 store.save(&session_id, data).await.unwrap();
1441
1442 let app = Router::new()
1443 .route(
1444 "/",
1445 get(|s: Session| async move {
1446 s.get("user").await.unwrap_or_else(|| "none".to_owned())
1447 }),
1448 )
1449 .layer(SessionLayer::new(store, SessionConfig::default()).with_signing_keys(keys));
1450
1451 let bad_sig = "0".repeat(64);
1453 let bad_cookie = format!("autumn.sid={session_id}.{bad_sig}");
1454 let req = HttpRequest::builder()
1455 .uri("/")
1456 .header("cookie", bad_cookie)
1457 .body(Body::empty())
1458 .unwrap();
1459 let resp = app.oneshot(req).await.unwrap();
1460 let body = axum::body::to_bytes(resp.into_body(), 64).await.unwrap();
1461 assert_eq!(&body[..], b"none", "tampered cookie must not load session");
1462 }
1463
1464 #[tokio::test]
1465 async fn session_layer_accepts_previous_key_signed_cookie() {
1466 use crate::security::config::{
1467 ResolvedSigningKeys, SigningSecretConfig, resolve_signing_keys,
1468 };
1469 use std::sync::Arc;
1470
1471 let old_secret = "old-key".repeat(5); let old_keys = resolve_signing_keys(&SigningSecretConfig {
1473 secret: Some(old_secret.clone()),
1474 previous_secrets: vec![],
1475 });
1476
1477 let session_id = Uuid::new_v4().to_string();
1478 let old_sig = old_keys.sign(session_id.as_bytes());
1479 let signed_value = format!("{session_id}.{old_sig}");
1480
1481 let new_keys = Arc::new(ResolvedSigningKeys::new(
1482 "new-key".repeat(5).into_bytes(),
1483 vec![old_secret.into_bytes()],
1484 ));
1485
1486 let store = MemoryStore::new();
1487 let mut data = HashMap::new();
1488 data.insert("user".to_owned(), "bob".to_owned());
1489 store.save(&session_id, data).await.unwrap();
1490
1491 let app = Router::new()
1492 .route(
1493 "/",
1494 get(|s: Session| async move {
1495 s.get("user").await.unwrap_or_else(|| "none".to_owned())
1496 }),
1497 )
1498 .layer(SessionLayer::new(store, SessionConfig::default()).with_signing_keys(new_keys));
1499
1500 let req = HttpRequest::builder()
1501 .uri("/")
1502 .header("cookie", format!("autumn.sid={signed_value}"))
1503 .body(Body::empty())
1504 .unwrap();
1505 let resp = app.oneshot(req).await.unwrap();
1506 let body = axum::body::to_bytes(resp.into_body(), 64).await.unwrap();
1507 assert_eq!(
1508 &body[..],
1509 b"bob",
1510 "previous-key-signed cookie must load session"
1511 );
1512 }
1513}