rocket_session_store/
lib.rs1#![doc = include_str!("../README.md")]
2#![warn(missing_docs)]
3
4#[cfg(test)]
5mod test;
6
7pub mod memory;
8
9#[cfg(feature = "redis")]
10pub mod redis;
11
12use std::{
13 sync::Arc,
14 time::Duration,
15};
16
17use rand::{rngs::OsRng, Rng, TryRngCore};
18use rocket::{
19 fairing::{
20 Fairing,
21 Info,
22 Kind,
23 },
24 http::{
25 private::cookie::CookieBuilder,
26 Status,
27 },
28 request::{
29 FromRequest,
30 Outcome,
31 },
32 response::Responder,
33 tokio::sync::Mutex,
34 Build, Request, Response, Rocket, State,
35};
36use thiserror::Error;
37
38fn new_id(length: usize) -> SessionID {
39 SessionID(
40 OsRng
41 .unwrap_err()
42 .sample_iter(&rand::distr::Alphanumeric)
43 .take(length)
44 .map(char::from)
45 .collect(),
46 )
47}
48
49const ID_LENGTH: usize = 24;
50
51#[rocket::async_trait]
54pub trait Store: Send + Sync {
55 type Value;
59 async fn get(&self, id: &str) -> SessionResult<Option<Self::Value>>;
61 async fn set(&self, id: &str, value: Self::Value, duration: Duration) -> SessionResult<()>;
63 async fn touch(&self, id: &str, duration: Duration) -> SessionResult<()>;
65 async fn remove(&self, id: &str) -> SessionResult<()>;
67}
68
/// Opaque session identifier, either generated by [`new_id`] or read verbatim
/// from the session cookie.
#[derive(Debug, Clone)]
struct SessionID(String);

impl AsRef<str> for SessionID {
    /// Borrows the identifier as a string slice.
    fn as_ref(&self) -> &str {
        self.0.as_str()
    }
}
78
79pub struct Session<'s, T: Send + Sync + Clone + 'static> {
82 store: &'s State<SessionStore<T>>,
83 token: SessionID,
84 new_token: Arc<Mutex<Option<SessionID>>>,
85}
86
87impl<'s, T: Send + Sync + Clone + 'static> Session<'s, T> {
88 pub async fn get(&self) -> SessionResult<Option<T>> {
93 self.store.store.get(self.token.as_ref()).await
94 }
95
96 pub async fn set(&self, value: T) -> SessionResult<()> {
100 self.store
101 .store
102 .set(self.token.as_ref(), value, self.store.duration)
103 .await
104 }
105
106 pub async fn touch(&self) -> SessionResult<()> {
108 self.store
109 .store
110 .touch(self.token.as_ref(), self.store.duration)
111 .await
112 }
113
114 pub async fn remove(&self) -> SessionResult<()> {
116 self.store.store.remove(self.token.as_ref()).await
117 }
118
119 pub async fn regenerate_token<'r>(&mut self) -> SessionResult<()> {
187 let mut new_token_opt = self.new_token.lock().await;
188 if new_token_opt.is_some() {
189 return Ok(());
191 }
192
193 let session_opt = self.get().await?;
195 self.remove().await?;
196 self.token = new_id(ID_LENGTH);
197 *new_token_opt = Some(self.token.clone());
198 if let Some(session) = session_opt {
199 self.set(session).await?;
200 }
201
202 Ok(())
203 }
204}
205
206#[rocket::async_trait]
207impl<T, 'r, 's> FromRequest<'r> for Session<'s, T>
208where
209 T: Send + Sync + 'static + Clone,
210 'r: 's,
211{
212 type Error = ();
213 async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
214 let store: &State<SessionStore<T>> = request
215 .guard()
216 .await
217 .expect("Session store must be set in fairing");
218 let (token, new_token) = request
219 .local_cache_async(async {
220 let cookies = request.cookies();
221 cookies.get(store.name.as_str()).map_or_else(
222 || {
223 let token = new_id(ID_LENGTH);
224 (token.clone(), Arc::new(Mutex::new(Some(token))))
225 },
226 |c| {
227 (
228 SessionID(String::from(c.value())),
229 Arc::new(Mutex::new(None)),
230 )
231 },
232 )
233 })
234 .await
235 .clone();
236
237 let session = Session {
238 store,
239 token,
240 new_token,
241 };
242 Outcome::Success(session)
243 }
244}
245
246pub struct SessionStore<T> {
248 pub store: Box<dyn Store<Value = T>>,
250 pub name: String,
254 pub duration: Duration,
259 pub cookie_builder: CookieBuilder<'static>,
268}
269
270impl<T> SessionStore<T> {
271 pub fn fairing(self) -> SessionStoreFairing<T> {
273 SessionStoreFairing {
274 store: Mutex::new(Some(self)),
275 }
276 }
277}
278
279pub struct SessionStoreFairing<T> {
284 store: Mutex<Option<SessionStore<T>>>,
285}
286
287#[rocket::async_trait]
288impl<T: Send + Sync + Clone + 'static> Fairing for SessionStoreFairing<T> {
289 fn info(&self) -> rocket::fairing::Info {
290 Info {
291 name: "Session Store",
292 kind: Kind::Ignite | Kind::Response | Kind::Singleton,
293 }
294 }
295
296 async fn on_ignite(&self, rocket: Rocket<Build>) -> Result<Rocket<Build>, Rocket<Build>> {
297 let mut lock = self.store.lock().await;
298 let store = lock.take().expect("Expected store");
299 let rocket = rocket.manage(store);
300 Ok(rocket)
301 }
302
303 async fn on_response<'r>(&self, request: &'r Request<'_>, response: &mut Response<'r>) {
304 match Session::<T>::from_request(request).await {
306 Outcome::Success(session) => {
307 if let Some(new_token) = &*session.new_token.lock().await {
308 let mut cookie = session.store.cookie_builder.clone().build();
309 cookie.set_name(&session.store.name);
310 cookie.set_value(&new_token.0);
311 response.adjoin_header(cookie);
312 }
313 }
314 _ => (),
315 }
316 }
317}
318
319pub type SessionResult<T> = Result<T, SessionError>;
321
322#[derive(Error, Debug)]
327#[error("could not access the session store")]
328pub struct SessionError;
329
330impl<'r, 'o: 'r> Responder<'r, 'o> for SessionError {
331 fn respond_to(self, _request: &'r Request<'_>) -> rocket::response::Result<'o> {
332 Err(Status::InternalServerError)
333 }
334}