rocket_session_store/
lib.rs

1#![doc = include_str!("../README.md")]
2#![warn(missing_docs)]
3
4#[cfg(test)]
5mod test;
6
7pub mod memory;
8
9#[cfg(feature = "redis")]
10pub mod redis;
11
12use std::{
13	sync::Arc,
14	time::Duration,
15};
16
17use rand::{rngs::OsRng, Rng, TryRngCore};
18use rocket::{
19	fairing::{
20		Fairing,
21		Info,
22		Kind,
23	},
24	http::{
25		private::cookie::CookieBuilder,
26		Status,
27	},
28	request::{
29		FromRequest,
30		Outcome,
31	},
32	response::Responder,
33	tokio::sync::Mutex,
34	Build, Request, Response, Rocket, State,
35};
36use thiserror::Error;
37
38fn new_id(length: usize) -> SessionID {
39	SessionID(
40		OsRng
41			.unwrap_err()
42			.sample_iter(&rand::distr::Alphanumeric)
43			.take(length)
44			.map(char::from)
45			.collect(),
46	)
47}
48
49const ID_LENGTH: usize = 24;
50
51/// A generic store in which to write and retrive sessions either
52/// trough an in memory hashmap or a database connection.
53#[rocket::async_trait]
54pub trait Store: Send + Sync {
55	/// Type that is associated with sessions.
56	///
57	/// The store will store and retrieve values of this type.
58	type Value;
59	/// Get the value from the store
60	async fn get(&self, id: &str) -> SessionResult<Option<Self::Value>>;
61	/// Set the value from the store
62	async fn set(&self, id: &str, value: Self::Value, duration: Duration) -> SessionResult<()>;
63	/// Touch the value, refreshing its expiry time.
64	async fn touch(&self, id: &str, duration: Duration) -> SessionResult<()>;
65	/// Remove the value from the store.
66	async fn remove(&self, id: &str) -> SessionResult<()>;
67}
68
/// Newtype around the string that identifies a session.
#[derive(Debug, Clone)]
struct SessionID(String);

impl AsRef<str> for SessionID {
	/// Borrows the ID as a string slice.
	fn as_ref(&self) -> &str {
		self.0.as_str()
	}
}
78
79/// A request guard implementing [FromRequest] to retrive the session
80/// based on the cookie from the user.
81pub struct Session<'s, T: Send + Sync + Clone + 'static> {
82	store: &'s State<SessionStore<T>>,
83	token: SessionID,
84	new_token: Arc<Mutex<Option<SessionID>>>,
85}
86
87impl<'s, T: Send + Sync + Clone + 'static> Session<'s, T> {
88	/// Get the session value from the store.
89	///
90	/// Returns [None] if there is no initialized session value
91	/// or if the value has expired.
92	pub async fn get(&self) -> SessionResult<Option<T>> {
93		self.store.store.get(self.token.as_ref()).await
94	}
95
96	/// Sets the session value from the store.
97	///
98	/// This will refresh the expiration timer.
99	pub async fn set(&self, value: T) -> SessionResult<()> {
100		self.store
101			.store
102			.set(self.token.as_ref(), value, self.store.duration)
103			.await
104	}
105
106	/// Refreshes the expiration timer on the sesion in the store.
107	pub async fn touch(&self) -> SessionResult<()> {
108		self.store
109			.store
110			.touch(self.token.as_ref(), self.store.duration)
111			.await
112	}
113
114	/// Removes the session from the store.
115	pub async fn remove(&self) -> SessionResult<()> {
116		self.store.store.remove(self.token.as_ref()).await
117	}
118
119	/// Regenerates the session token. The fairing will automatically add a cookie to the response with the new token.
120	///
121	/// It is important to regenerate the session token after a user is authenticated in order to prevent session fixation attacks.
122	///
123	/// This also has a side effect of refreshing the expiration timer on the session.
124	///
125	/// # Examples
126	///
127	/// ```rust
128	/// use rocket::{
129	/// 	http::private::cookie::CookieBuilder,
130	/// 	serde::{
131	/// 		Deserialize,
132	/// 		Serialize,
133	/// 	},
134	/// 	Build,
135	/// 	Rocket,
136	/// };
137	/// use rocket_session_store::{
138	/// 	memory::MemoryStore,
139	/// 	Session,
140	/// 	SessionError,
141	/// 	SessionStore,
142	/// };
143	///
144	/// #[macro_use]
145	/// extern crate rocket;
146	///
147	/// # fn main() { // Makes doc test happy for extern crate
148	/// #[launch]
149	/// fn rocket() -> Rocket<Build> {
150	/// 	let session_store = SessionStore::<SessionState> {
151	/// 		store: Box::new(MemoryStore::new()),
152	/// 		name: "session".into(),
153	/// 		duration: std::time::Duration::from_secs(24 * 60 * 60),
154	/// 		cookie_builder: CookieBuilder::new("", ""),
155	/// 	};
156	///
157	/// 	rocket::build()
158	/// 		.attach(session_store.fairing())
159	/// 		.mount("/", routes![login])
160	/// }
161	///
162	/// #[post("/login")]
163	/// async fn login(mut session: Session<'_, SessionState>) -> Result<(), SessionError> {
164	/// 	// Authenticate the user (check password, 2fa, etc)
165	/// 	// ...
166	///
167	/// 	let user_id = Some(1);
168	///
169	/// 	// Important! Regenerate _before_ updating the session for the authenticated user. We don't
170	/// 	// want to run into a scenario where updating the session works, but then regenerating the
171	/// 	// token fails for some reason leaving the old session still valid with the user logged in
172	/// 	// (eg due to an intermittent redis connection issue or something).
173	/// 	session.regenerate_token().await?;
174	/// 	session.set(SessionState { user_id }).await?;
175	///
176	/// 	Ok(())
177	/// }
178	///
179	/// #[derive(Serialize, Deserialize, Clone, Copy)]
180	/// #[serde(crate = "rocket::serde")]
181	/// struct SessionState {
182	/// 	user_id: Option<u32>,
183	/// }
184	/// # }
185	/// ```
186	pub async fn regenerate_token<'r>(&mut self) -> SessionResult<()> {
187		let mut new_token_opt = self.new_token.lock().await;
188		if new_token_opt.is_some() {
189			// If a new token has already been generated then there's no point regenerating it again.
190			return Ok(());
191		}
192
193		// Retrieve existing session, remove it under the current token, and add it under a new token.
194		let session_opt = self.get().await?;
195		self.remove().await?;
196		self.token = new_id(ID_LENGTH);
197		*new_token_opt = Some(self.token.clone());
198		if let Some(session) = session_opt {
199			self.set(session).await?;
200		}
201
202		Ok(())
203	}
204}
205
206#[rocket::async_trait]
207impl<T, 'r, 's> FromRequest<'r> for Session<'s, T>
208where
209	T: Send + Sync + 'static + Clone,
210	'r: 's,
211{
212	type Error = ();
213	async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
214		let store: &State<SessionStore<T>> = request
215			.guard()
216			.await
217			.expect("Session store must be set in fairing");
218		let (token, new_token) = request
219			.local_cache_async(async {
220				let cookies = request.cookies();
221				cookies.get(store.name.as_str()).map_or_else(
222					|| {
223						let token = new_id(ID_LENGTH);
224						(token.clone(), Arc::new(Mutex::new(Some(token))))
225					},
226					|c| {
227						(
228							SessionID(String::from(c.value())),
229							Arc::new(Mutex::new(None)),
230						)
231					},
232				)
233			})
234			.await
235			.clone();
236
237		let session = Session {
238			store,
239			token,
240			new_token,
241		};
242		Outcome::Success(session)
243	}
244}
245
246/// Store that keeps tracks of sessions
247pub struct SessionStore<T> {
248	/// The store that will keep track of sessions.
249	pub store: Box<dyn Store<Value = T>>,
250	/// The name of the cookie to be used for sessions.
251	///
252	/// This will be the name the cookie will be stored under in the browser.
253	pub name: String,
254	/// The duration of the session.
255	///
256	/// When so much time passes after storing or touching a session, it expires
257	/// and won't be accesible.
258	pub duration: Duration,
259	/// The cookie options.
260	///
261	/// This will be used in the fairing to build the cookie. Each time a cookie needs to
262	/// be set the CookieBuilder will be cloned and the name and value will be overwritten.
263	///
264	/// Note that Rocket defaults to setting the `Secure` attribute for cookies, so when doing local development over
265	/// HTTP without TLS `CookieBuilder::secure(false)` must be used to allow sending the session cookie over an
266	/// insecure connnection, but it is important that this is never done in production to prevent session hijacking.
267	pub cookie_builder: CookieBuilder<'static>,
268}
269
270impl<T> SessionStore<T> {
271	/// A function to turn the store into a [Fairing] to attach on a rocket.
272	pub fn fairing(self) -> SessionStoreFairing<T> {
273		SessionStoreFairing {
274			store: Mutex::new(Some(self)),
275		}
276	}
277}
278
279/// The fairing for the session store.
280///
281/// This shouldn't be created directly and you should
282/// instead use [SessionStore::fairing()] to create it
283pub struct SessionStoreFairing<T> {
284	store: Mutex<Option<SessionStore<T>>>,
285}
286
287#[rocket::async_trait]
288impl<T: Send + Sync + Clone + 'static> Fairing for SessionStoreFairing<T> {
289	fn info(&self) -> rocket::fairing::Info {
290		Info {
291			name: "Session Store",
292			kind: Kind::Ignite | Kind::Response | Kind::Singleton,
293		}
294	}
295
296	async fn on_ignite(&self, rocket: Rocket<Build>) -> Result<Rocket<Build>, Rocket<Build>> {
297		let mut lock = self.store.lock().await;
298		let store = lock.take().expect("Expected store");
299		let rocket = rocket.manage(store);
300		Ok(rocket)
301	}
302
303	async fn on_response<'r>(&self, request: &'r Request<'_>, response: &mut Response<'r>) {
304		// If there is a new session id, set the cookie
305		match Session::<T>::from_request(request).await {
306			Outcome::Success(session) => {
307				if let Some(new_token) = &*session.new_token.lock().await {
308					let mut cookie = session.store.cookie_builder.clone().build();
309					cookie.set_name(&session.store.name);
310					cookie.set_value(&new_token.0);
311					response.adjoin_header(cookie);
312				}
313			}
314			_ => (),
315		}
316	}
317}
318
319/// A result wrapper around [SessionError], allowing you to wrap the Result
320pub type SessionResult<T> = Result<T, SessionError>;
321
322/// Errors produced when accessing the session store.
323///
324/// These can be problems like a database connection drop.
325/// It implements [Responder], returning a 500 status error.
326#[derive(Error, Debug)]
327#[error("could not access the session store")]
328pub struct SessionError;
329
330impl<'r, 'o: 'r> Responder<'r, 'o> for SessionError {
331	fn respond_to(self, _request: &'r Request<'_>) -> rocket::response::Result<'o> {
332		Err(Status::InternalServerError)
333	}
334}