1#![doc(html_favicon_url = "https://salvo.rs/favicon-32x32.png")]
64#![doc(html_logo_url = "https://salvo.rs/images/logo.svg")]
65#![cfg_attr(docsrs, feature(doc_cfg))]
66
67pub use async_session::{CookieStore, MemoryStore, Session, SessionStore};
68
69use std::fmt::{self, Formatter};
70use std::time::Duration;
71
72use async_session::base64;
73use async_session::hmac::{Hmac, Mac, NewMac};
74use async_session::sha2::Sha256;
75use cookie::{Cookie, Key, SameSite};
76use salvo_core::http::uri::Scheme;
77use salvo_core::{Depot, Error, FlowCtrl, Handler, Request, Response, async_trait};
78
/// Depot key under which the current request's [`Session`] is stored.
pub const SESSION_KEY: &str = "::salvo::session";
/// Length of a base64-encoded (padded) HMAC-SHA256 digest: 32 digest bytes
/// become `ceil(32 / 3) * 4 = 44` base64 characters.
const BASE64_DIGEST_LEN: usize = (32 + 2) / 3 * 4;
82
83pub trait SessionDepotExt {
85 fn set_session(&mut self, session: Session) -> &mut Self;
87 fn take_session(&mut self) -> Option<Session>;
89 fn session(&self) -> Option<&Session>;
91 fn session_mut(&mut self) -> Option<&mut Session>;
93}
94
95impl SessionDepotExt for Depot {
96 #[inline]
97 fn set_session(&mut self, session: Session) -> &mut Self {
98 self.insert(SESSION_KEY, session);
99 self
100 }
101 #[inline]
102 fn take_session(&mut self) -> Option<Session> {
103 self.remove(SESSION_KEY).ok()
104 }
105 #[inline]
106 fn session(&self) -> Option<&Session> {
107 self.get(SESSION_KEY).ok()
108 }
109 #[inline]
110 fn session_mut(&mut self) -> Option<&mut Session> {
111 self.get_mut(SESSION_KEY).ok()
112 }
113}
114
115pub struct HandlerBuilder<S> {
117 store: S,
118 cookie_path: String,
119 cookie_name: String,
120 cookie_domain: Option<String>,
121 session_ttl: Option<Duration>,
122 save_unchanged: bool,
123 same_site_policy: SameSite,
124 key: Key,
125 fallback_keys: Vec<Key>,
126}
127impl<S: SessionStore> fmt::Debug for HandlerBuilder<S> {
128 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
129 f.debug_struct("HandlerBuilder")
130 .field("store", &self.store)
131 .field("cookie_path", &self.cookie_path)
132 .field("cookie_name", &self.cookie_name)
133 .field("cookie_domain", &self.cookie_domain)
134 .field("session_ttl", &self.session_ttl)
135 .field("same_site_policy", &self.same_site_policy)
136 .field("key", &"..")
137 .field("fallback_keys", &"..")
138 .field("save_unchanged", &self.save_unchanged)
139 .finish()
140 }
141}
142
143impl<S> HandlerBuilder<S>
144where
145 S: SessionStore,
146{
147 #[inline]
149 #[must_use]
150 pub fn new(store: S, secret: &[u8]) -> Self {
151 Self {
152 store,
153 save_unchanged: true,
154 cookie_path: "/".into(),
155 cookie_name: "salvo.session.id".into(),
156 cookie_domain: None,
157 same_site_policy: SameSite::Lax,
158 session_ttl: Some(Duration::from_secs(24 * 60 * 60)),
159 key: Key::from(secret),
160 fallback_keys: vec![],
161 }
162 }
163
164 #[inline]
168 #[must_use]
169 pub fn cookie_path(mut self, cookie_path: impl Into<String>) -> Self {
170 self.cookie_path = cookie_path.into();
171 self
172 }
173
174 #[inline]
180 #[must_use]
181 pub fn session_ttl(mut self, session_ttl: Option<Duration>) -> Self {
182 self.session_ttl = session_ttl;
183 self
184 }
185
186 #[inline]
192 #[must_use]
193 pub fn cookie_name(mut self, cookie_name: impl Into<String>) -> Self {
194 self.cookie_name = cookie_name.into();
195 self
196 }
197
198 #[inline]
208 #[must_use]
209 pub fn save_unchanged(mut self, value: bool) -> Self {
210 self.save_unchanged = value;
211 self
212 }
213
214 #[inline]
219 #[must_use]
220 pub fn same_site_policy(mut self, policy: SameSite) -> Self {
221 self.same_site_policy = policy;
222 self
223 }
224
225 #[inline]
227 #[must_use]
228 pub fn cookie_domain(mut self, cookie_domain: impl AsRef<str>) -> Self {
229 self.cookie_domain = Some(cookie_domain.as_ref().to_owned());
230 self
231 }
232 #[inline]
234 #[must_use]
235 pub fn fallback_keys(mut self, keys: Vec<impl Into<Key>>) -> Self {
236 self.fallback_keys = keys.into_iter().map(|s| s.into()).collect();
237 self
238 }
239
240 #[inline]
242 #[must_use]
243 pub fn add_fallback_key(mut self, key: impl Into<Key>) -> Self {
244 self.fallback_keys.push(key.into());
245 self
246 }
247
248 pub fn build(self) -> Result<SessionHandler<S>, Error> {
250 let Self {
251 store,
252 save_unchanged,
253 cookie_path,
254 cookie_name,
255 cookie_domain,
256 session_ttl,
257 same_site_policy,
258 key,
259 fallback_keys,
260 } = self;
261 let hmac = Hmac::<Sha256>::new_from_slice(key.signing())
262 .map_err(|_| Error::Other("invalid key length".into()))?;
263 let fallback_hmacs = fallback_keys
264 .iter()
265 .map(|key| Hmac::<Sha256>::new_from_slice(key.signing()))
266 .collect::<Result<Vec<_>, _>>()
267 .map_err(|_| Error::Other("invalid key length".into()))?;
268 Ok(SessionHandler {
269 store,
270 save_unchanged,
271 cookie_path,
272 cookie_name,
273 cookie_domain,
274 session_ttl,
275 same_site_policy,
276 hmac,
277 fallback_hmacs,
278 })
279 }
280}
281
282pub struct SessionHandler<S> {
284 store: S,
285 cookie_path: String,
286 cookie_name: String,
287 cookie_domain: Option<String>,
288 session_ttl: Option<Duration>,
289 save_unchanged: bool,
290 same_site_policy: SameSite,
291 hmac: Hmac<Sha256>,
292 fallback_hmacs: Vec<Hmac<Sha256>>,
293}
294impl<S: SessionStore> fmt::Debug for SessionHandler<S> {
295 #[inline]
296 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
297 f.debug_struct("SessionHandler")
298 .field("store", &self.store)
299 .field("cookie_path", &self.cookie_path)
300 .field("cookie_name", &self.cookie_name)
301 .field("cookie_domain", &self.cookie_domain)
302 .field("session_ttl", &self.session_ttl)
303 .field("same_site_policy", &self.same_site_policy)
304 .field("key", &"..")
305 .field("fallback_keys", &"..")
306 .field("save_unchanged", &self.save_unchanged)
307 .finish()
308 }
309}
310#[async_trait]
311impl<S> Handler for SessionHandler<S>
312where
313 S: SessionStore,
314{
315 async fn handle(
316 &self,
317 req: &mut Request,
318 depot: &mut Depot,
319 res: &mut Response,
320 ctrl: &mut FlowCtrl,
321 ) {
322 let cookie = req.cookies().get(&self.cookie_name);
323 let cookie_value = cookie.and_then(|cookie| self.verify_signature(cookie.value()).ok());
324
325 let mut session = self.load_or_create(cookie_value).await;
326
327 if let Some(ttl) = self.session_ttl {
328 session.expire_in(ttl);
329 }
330
331 depot.set_session(session);
332
333 ctrl.call_next(req, depot, res).await;
334 if ctrl.is_ceased() {
335 return;
336 }
337
338 let session = depot.take_session().expect("session should exist in depot");
339 if session.is_destroyed() {
340 if let Err(e) = self.store.destroy_session(session).await {
341 tracing::error!(error = ?e, "unable to destroy session");
342 }
343 res.remove_cookie(&self.cookie_name);
344 } else if self.save_unchanged || session.data_changed() {
345 match self.store.store_session(session).await {
346 Ok(cookie_value) => {
347 if let Some(cookie_value) = cookie_value {
348 let secure_cookie = req.uri().scheme() == Some(&Scheme::HTTPS);
349 let cookie = self.build_cookie(secure_cookie, cookie_value);
350 res.add_cookie(cookie);
351 }
352 }
353 Err(e) => {
354 tracing::error!(error = ?e, "store session error");
355 }
356 }
357 }
358 }
359}
360
361impl<S> SessionHandler<S>
362where
363 S: SessionStore,
364{
365 pub fn builder(store: S, secret: &[u8]) -> HandlerBuilder<S> {
367 HandlerBuilder::new(store, secret)
368 }
369 #[inline]
370 async fn load_or_create(&self, cookie_value: Option<String>) -> Session {
371 let session = match cookie_value {
372 Some(cookie_value) => self.store.load_session(cookie_value).await.ok().flatten(),
373 None => None,
374 };
375
376 session
377 .and_then(|session| session.validate())
378 .unwrap_or_default()
379 }
380 fn verify_signature(&self, cookie_value: &str) -> Result<String, Error> {
386 if cookie_value.len() < BASE64_DIGEST_LEN {
387 return Err(Error::Other(
388 "length of value is <= BASE64_DIGEST_LEN".into(),
389 ));
390 }
391
392 let (digest_str, value) = cookie_value.split_at(BASE64_DIGEST_LEN);
394 let digest =
395 base64::decode(digest_str).map_err(|_| Error::Other("bad base64 digest".into()))?;
396
397 let mut hmac = self.hmac.clone();
399 hmac.update(value.as_bytes());
400 if hmac.verify(&digest).is_ok() {
401 return Ok(value.to_owned());
402 }
403 for hmac in &self.fallback_hmacs {
404 let mut hmac = hmac.clone();
405 hmac.update(value.as_bytes());
406 if hmac.verify(&digest).is_ok() {
407 return Ok(value.to_owned());
408 }
409 }
410 Err(Error::Other("value did not verify".into()))
411 }
412 fn build_cookie(&self, secure: bool, cookie_value: String) -> Cookie<'static> {
413 let mut cookie = Cookie::build((self.cookie_name.clone(), cookie_value))
414 .http_only(true)
415 .same_site(self.same_site_policy)
416 .secure(secure)
417 .path(self.cookie_path.clone())
418 .build();
419
420 if let Some(ttl) = self.session_ttl {
421 cookie.set_expires(Some((std::time::SystemTime::now() + ttl).into()));
422 }
423
424 if let Some(cookie_domain) = self.cookie_domain.clone() {
425 cookie.set_domain(cookie_domain)
426 }
427
428 self.sign_cookie(&mut cookie);
429
430 cookie
431 }
432 fn sign_cookie(&self, cookie: &mut Cookie<'_>) {
436 let mut mac = self.hmac.clone();
438 mac.update(cookie.value().as_bytes());
439
440 let mut new_value = base64::encode(mac.finalize().into_bytes());
442 new_value.push_str(cookie.value());
443 cookie.set_value(new_value);
444 }
445}
446
#[cfg(test)]
mod tests {
    use salvo_core::http::Method;
    use salvo_core::http::header::*;
    use salvo_core::prelude::*;
    use salvo_core::test::{ResponseExt, TestClient};

    use super::*;

    #[test]
    fn test_session_data() {
        let builder = SessionHandler::builder(
            async_session::CookieStore,
            b"secretabsecretabsecretabsecretabsecretabsecretabsecretabsecretab",
        )
        .cookie_domain("test.domain")
        .cookie_name("test_cookie")
        .cookie_path("/abc")
        .same_site_policy(SameSite::Strict)
        .session_ttl(Some(Duration::from_secs(30)));
        assert!(format!("{builder:?}").contains("test_cookie"));

        let handler = builder.build().unwrap();
        assert!(format!("{handler:?}").contains("test_cookie"));
        assert_eq!(handler.cookie_domain, Some("test.domain".into()));
        assert_eq!(handler.cookie_name, "test_cookie");
        assert_eq!(handler.cookie_path, "/abc");
        assert_eq!(handler.same_site_policy, SameSite::Strict);
        assert_eq!(handler.session_ttl, Some(Duration::from_secs(30)));
    }

    #[tokio::test]
    async fn test_session_login() {
        #[handler]
        pub async fn login(req: &mut Request, depot: &mut Depot, res: &mut Response) {
            if req.method() == Method::POST {
                let mut session = Session::new();
                session
                    .insert("username", req.form::<String>("username").await.unwrap())
                    .unwrap();
                depot.set_session(session);
                res.render(Redirect::other("/"));
            } else {
                res.render(Text::Html("login page"));
            }
        }

        #[handler]
        pub async fn logout(depot: &mut Depot, res: &mut Response) {
            if let Some(session) = depot.session_mut() {
                session.remove("username");
            }
            res.render(Redirect::other("/"));
        }

        #[handler]
        pub async fn home(depot: &mut Depot, res: &mut Response) {
            let mut content = r#"home"#.into();
            if let Some(session) = depot.session_mut() {
                if let Some(username) = session.get::<String>("username") {
                    content = username;
                }
            }
            res.render(Text::Html(content));
        }

        let session_handler = SessionHandler::builder(
            MemoryStore::new(),
            b"secretabsecretabsecretabsecretabsecretabsecretabsecretabsecretab",
        )
        .build()
        .unwrap();
        let router = Router::new()
            .hoop(session_handler)
            .get(home)
            .push(Router::with_path("login").get(login).post(login))
            .push(Router::with_path("logout").get(logout));
        let service = Service::new(router);

        let response = TestClient::post("http://127.0.0.1:8698/login")
            .raw_form("username=salvo")
            .send(&service)
            .await;
        assert_eq!(response.status_code, Some(StatusCode::SEE_OTHER));
        let cookie = response.headers().get(SET_COOKIE).unwrap();

        let mut response = TestClient::get("http://127.0.0.1:8698/")
            .add_header(COOKIE, cookie, true)
            .send(&service)
            .await;
        assert_eq!(response.take_string().await.unwrap(), "salvo");

        // Log out with the same cookie so the stored session is actually modified;
        // without it, logout would only touch a brand-new empty session.
        let response = TestClient::get("http://127.0.0.1:8698/logout")
            .add_header(COOKIE, cookie, true)
            .send(&service)
            .await;
        assert_eq!(response.status_code, Some(StatusCode::SEE_OTHER));

        // The same cookie must now resolve to a session without a username.
        let mut response = TestClient::get("http://127.0.0.1:8698/")
            .add_header(COOKIE, cookie, true)
            .send(&service)
            .await;
        assert_eq!(response.take_string().await.unwrap(), "home");
    }
}