1#![doc(html_favicon_url = "https://salvo.rs/favicon-32x32.png")]
62#![doc(html_logo_url = "https://salvo.rs/images/logo.svg")]
63#![cfg_attr(docsrs, feature(doc_cfg))]
64
65use std::fmt::{self, Formatter};
66use std::time::Duration;
67
68use cookie::{Cookie, Key, SameSite};
69use salvo_core::http::uri::Scheme;
70use salvo_core::{Depot, Error, FlowCtrl, Handler, Request, Response, async_trait};
71use saysion::base64::Engine as _;
72use saysion::base64::engine::general_purpose;
73use saysion::hmac::{Hmac, Mac};
74use saysion::sha2::Sha256;
75pub use saysion::{CookieStore, MemoryStore, Session, SessionStore};
76
/// Key under which the current [`Session`] is stored in the [`Depot`].
pub const SESSION_KEY: &str = "::salvo::session";

/// Length of the base64-encoded HMAC-SHA256 digest that prefixes a signed cookie value.
///
/// A SHA-256 digest is 32 bytes and padded base64 emits 4 characters for every
/// (started) group of 3 input bytes, which gives 44 characters.
const BASE64_DIGEST_LEN: usize = (32 + 2) / 3 * 4;
80
81pub trait SessionDepotExt {
83 fn set_session(&mut self, session: Session) -> &mut Self;
85 fn take_session(&mut self) -> Option<Session>;
87 fn session(&self) -> Option<&Session>;
89 fn session_mut(&mut self) -> Option<&mut Session>;
91}
92
93impl SessionDepotExt for Depot {
94 #[inline]
95 fn set_session(&mut self, session: Session) -> &mut Self {
96 self.insert(SESSION_KEY, session);
97 self
98 }
99 #[inline]
100 fn take_session(&mut self) -> Option<Session> {
101 self.remove(SESSION_KEY).ok()
102 }
103 #[inline]
104 fn session(&self) -> Option<&Session> {
105 self.get(SESSION_KEY).ok()
106 }
107 #[inline]
108 fn session_mut(&mut self) -> Option<&mut Session> {
109 self.get_mut(SESSION_KEY).ok()
110 }
111}
112
113pub struct HandlerBuilder<S> {
115 store: S,
116 cookie_path: String,
117 cookie_name: String,
118 cookie_domain: Option<String>,
119 session_ttl: Option<Duration>,
120 save_unchanged: bool,
121 same_site_policy: SameSite,
122 key: Key,
123 fallback_keys: Vec<Key>,
124}
125impl<S> fmt::Debug for HandlerBuilder<S>
126where
127 S: SessionStore + fmt::Debug,
128{
129 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
130 f.debug_struct("HandlerBuilder")
131 .field("store", &self.store)
132 .field("cookie_path", &self.cookie_path)
133 .field("cookie_name", &self.cookie_name)
134 .field("cookie_domain", &self.cookie_domain)
135 .field("session_ttl", &self.session_ttl)
136 .field("same_site_policy", &self.same_site_policy)
137 .field("key", &"..")
138 .field("fallback_keys", &"..")
139 .field("save_unchanged", &self.save_unchanged)
140 .finish()
141 }
142}
143
144impl<S> HandlerBuilder<S>
145where
146 S: SessionStore,
147{
148 #[inline]
150 #[must_use]
151 pub fn new(store: S, secret: &[u8]) -> Self {
152 Self {
153 store,
154 save_unchanged: true,
155 cookie_path: "/".into(),
156 cookie_name: "salvo.session.id".into(),
157 cookie_domain: None,
158 same_site_policy: SameSite::Lax,
159 session_ttl: Some(Duration::from_secs(24 * 60 * 60)),
160 key: Key::from(secret),
161 fallback_keys: vec![],
162 }
163 }
164
165 #[inline]
169 #[must_use]
170 pub fn cookie_path(mut self, cookie_path: impl Into<String>) -> Self {
171 self.cookie_path = cookie_path.into();
172 self
173 }
174
175 #[inline]
181 #[must_use]
182 pub fn session_ttl(mut self, session_ttl: Option<Duration>) -> Self {
183 self.session_ttl = session_ttl;
184 self
185 }
186
187 #[inline]
193 #[must_use]
194 pub fn cookie_name(mut self, cookie_name: impl Into<String>) -> Self {
195 self.cookie_name = cookie_name.into();
196 self
197 }
198
199 #[inline]
209 #[must_use]
210 pub fn save_unchanged(mut self, value: bool) -> Self {
211 self.save_unchanged = value;
212 self
213 }
214
215 #[inline]
220 #[must_use]
221 pub fn same_site_policy(mut self, policy: SameSite) -> Self {
222 self.same_site_policy = policy;
223 self
224 }
225
226 #[inline]
228 #[must_use]
229 pub fn cookie_domain(mut self, cookie_domain: impl AsRef<str>) -> Self {
230 self.cookie_domain = Some(cookie_domain.as_ref().to_owned());
231 self
232 }
233 #[inline]
235 #[must_use]
236 pub fn fallback_keys(mut self, keys: Vec<impl Into<Key>>) -> Self {
237 self.fallback_keys = keys.into_iter().map(|s| s.into()).collect();
238 self
239 }
240
241 #[inline]
243 #[must_use]
244 pub fn add_fallback_key(mut self, key: impl Into<Key>) -> Self {
245 self.fallback_keys.push(key.into());
246 self
247 }
248
249 pub fn build(self) -> Result<SessionHandler<S>, Error> {
251 let Self {
252 store,
253 save_unchanged,
254 cookie_path,
255 cookie_name,
256 cookie_domain,
257 session_ttl,
258 same_site_policy,
259 key,
260 fallback_keys,
261 } = self;
262 let hmac = Hmac::<Sha256>::new_from_slice(key.signing())
263 .map_err(|_| Error::Other("invalid key length".into()))?;
264 let fallback_hmacs = fallback_keys
265 .iter()
266 .map(|key| Hmac::<Sha256>::new_from_slice(key.signing()))
267 .collect::<Result<Vec<_>, _>>()
268 .map_err(|_| Error::Other("invalid key length".into()))?;
269 Ok(SessionHandler {
270 store,
271 save_unchanged,
272 cookie_path,
273 cookie_name,
274 cookie_domain,
275 session_ttl,
276 same_site_policy,
277 hmac,
278 fallback_hmacs,
279 })
280 }
281}
282
283pub struct SessionHandler<S> {
285 store: S,
286 cookie_path: String,
287 cookie_name: String,
288 cookie_domain: Option<String>,
289 session_ttl: Option<Duration>,
290 save_unchanged: bool,
291 same_site_policy: SameSite,
292 hmac: Hmac<Sha256>,
293 fallback_hmacs: Vec<Hmac<Sha256>>,
294}
295impl<S> fmt::Debug for SessionHandler<S>
296where
297 S: SessionStore + fmt::Debug,
298{
299 #[inline]
300 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
301 f.debug_struct("SessionHandler")
302 .field("store", &self.store)
303 .field("cookie_path", &self.cookie_path)
304 .field("cookie_name", &self.cookie_name)
305 .field("cookie_domain", &self.cookie_domain)
306 .field("session_ttl", &self.session_ttl)
307 .field("same_site_policy", &self.same_site_policy)
308 .field("key", &"..")
309 .field("fallback_keys", &"..")
310 .field("save_unchanged", &self.save_unchanged)
311 .finish()
312 }
313}
314#[async_trait]
315impl<S> Handler for SessionHandler<S>
316where
317 S: SessionStore + Send + Sync + 'static,
318{
319 async fn handle(
320 &self,
321 req: &mut Request,
322 depot: &mut Depot,
323 res: &mut Response,
324 ctrl: &mut FlowCtrl,
325 ) {
326 let cookie = req.cookies().get(&self.cookie_name);
327 let cookie_value = cookie.and_then(|cookie| self.verify_signature(cookie.value()).ok());
328
329 let mut session = self.load_or_create(cookie_value).await;
330
331 if let Some(ttl) = self.session_ttl {
332 session.expire_in(ttl);
333 }
334
335 depot.set_session(session);
336
337 ctrl.call_next(req, depot, res).await;
338 if ctrl.is_ceased() {
339 return;
340 }
341
342 let session = depot.take_session().expect("session should exist in depot");
343 if session.is_destroyed() {
344 if let Err(e) = self.store.destroy_session(session).await {
345 tracing::error!(error = ?e, "unable to destroy session");
346 }
347 res.remove_cookie(&self.cookie_name);
348 } else if self.save_unchanged || session.data_changed() {
349 match self.store.store_session(session).await {
350 Ok(cookie_value) => {
351 if let Some(cookie_value) = cookie_value {
352 let secure_cookie = req.uri().scheme() == Some(&Scheme::HTTPS);
353 let cookie = self.build_cookie(secure_cookie, cookie_value);
354 res.add_cookie(cookie);
355 }
356 }
357 Err(e) => {
358 tracing::error!(error = ?e, "store session error");
359 }
360 }
361 }
362 }
363}
364
365impl<S> SessionHandler<S>
366where
367 S: SessionStore + Send + Sync + 'static,
368{
369 pub fn builder(store: S, secret: &[u8]) -> HandlerBuilder<S> {
371 HandlerBuilder::new(store, secret)
372 }
373 #[inline]
374 async fn load_or_create(&self, cookie_value: Option<String>) -> Session {
375 let session = match cookie_value {
376 Some(cookie_value) => self.store.load_session(cookie_value).await.ok().flatten(),
377 None => None,
378 };
379
380 session
381 .and_then(|session| session.validate())
382 .unwrap_or_default()
383 }
384 fn verify_signature(&self, cookie_value: &str) -> Result<String, Error> {
390 if cookie_value.len() < BASE64_DIGEST_LEN {
391 return Err(Error::Other(
392 "length of value is <= BASE64_DIGEST_LEN".into(),
393 ));
394 }
395
396 let (digest_str, value) = cookie_value.split_at(BASE64_DIGEST_LEN);
398 let digest = general_purpose::STANDARD
399 .decode(digest_str)
400 .map_err(|_| Error::Other("bad base64 digest".into()))?;
401
402 let mut hmac = self.hmac.clone();
404 hmac.update(value.as_bytes());
405 if hmac.verify_slice(&digest).is_ok() {
406 return Ok(value.to_owned());
407 }
408 for hmac in &self.fallback_hmacs {
409 let mut hmac = hmac.clone();
410 hmac.update(value.as_bytes());
411 if hmac.verify_slice(&digest).is_ok() {
412 return Ok(value.to_owned());
413 }
414 }
415 Err(Error::Other("value did not verify".into()))
416 }
417 fn build_cookie(&self, secure: bool, cookie_value: String) -> Cookie<'static> {
418 let mut cookie = Cookie::build((self.cookie_name.clone(), cookie_value))
419 .http_only(true)
420 .same_site(self.same_site_policy)
421 .secure(secure)
422 .path(self.cookie_path.clone())
423 .build();
424
425 if let Some(ttl) = self.session_ttl {
426 cookie.set_expires(Some((std::time::SystemTime::now() + ttl).into()));
427 }
428
429 if let Some(cookie_domain) = self.cookie_domain.clone() {
430 cookie.set_domain(cookie_domain)
431 }
432
433 self.sign_cookie(&mut cookie);
434
435 cookie
436 }
437 fn sign_cookie(&self, cookie: &mut Cookie<'_>) {
441 let mut mac = self.hmac.clone();
443 mac.update(cookie.value().as_bytes());
444
445 let mut new_value = general_purpose::STANDARD.encode(mac.finalize().into_bytes());
447 new_value.push_str(cookie.value());
448 cookie.set_value(new_value);
449 }
450}
451
#[cfg(test)]
mod tests {
    use salvo_core::http::Method;
    use salvo_core::http::header::*;
    use salvo_core::prelude::*;
    use salvo_core::test::{ResponseExt, TestClient};

    use super::*;

    /// 64-byte secret shared by all tests.
    const SECRET: &[u8] = b"secretabsecretabsecretabsecretabsecretabsecretabsecretabsecretab";

    /// Every builder option must end up unchanged in the built handler.
    #[test]
    fn test_session_data() {
        let builder = SessionHandler::builder(saysion::CookieStore, SECRET)
            .cookie_domain("test.domain")
            .cookie_name("test_cookie")
            .cookie_path("/abc")
            .same_site_policy(SameSite::Strict)
            .session_ttl(Some(Duration::from_secs(30)));
        let builder_debug = format!("{builder:?}");
        assert!(builder_debug.contains("test_cookie"));

        let handler = builder.build().unwrap();
        let handler_debug = format!("{handler:?}");
        assert!(handler_debug.contains("test_cookie"));

        assert_eq!(handler.cookie_domain, Some("test.domain".into()));
        assert_eq!(handler.cookie_name, "test_cookie");
        assert_eq!(handler.cookie_path, "/abc");
        assert_eq!(handler.same_site_policy, SameSite::Strict);
        assert_eq!(handler.session_ttl, Some(Duration::from_secs(30)));
    }

    /// Logs in, reads the user name back with the issued cookie, then checks the anonymous view.
    #[tokio::test]
    async fn test_session_login() {
        #[handler]
        pub async fn login(req: &mut Request, depot: &mut Depot, res: &mut Response) {
            if req.method() != Method::POST {
                res.render(Text::Html("login page"));
                return;
            }

            let mut session = Session::new();
            let username = req.form::<String>("username").await.unwrap();
            session.insert("username", username).unwrap();
            depot.set_session(session);
            res.render(Redirect::other("/"));
        }

        #[handler]
        pub async fn logout(depot: &mut Depot, res: &mut Response) {
            if let Some(session) = depot.session_mut() {
                session.remove("username");
            }
            res.render(Redirect::other("/"));
        }

        #[handler]
        pub async fn home(depot: &mut Depot, res: &mut Response) {
            // Falls back to a fixed text when nobody is logged in.
            let content: String = depot
                .session_mut()
                .and_then(|session| session.get::<String>("username"))
                .unwrap_or_else(|| "home".into());
            res.render(Text::Html(content));
        }

        let session_handler = SessionHandler::builder(MemoryStore::new(), SECRET)
            .build()
            .unwrap();
        let login_router = Router::with_path("login").get(login).post(login);
        let logout_router = Router::with_path("logout").get(logout);
        let router = Router::new()
            .hoop(session_handler)
            .get(home)
            .push(login_router)
            .push(logout_router);
        let service = Service::new(router);

        // Logging in redirects and issues the session cookie.
        let login_response = TestClient::post("http://127.0.0.1:8698/login")
            .raw_form("username=salvo")
            .send(&service)
            .await;
        assert_eq!(login_response.status_code, Some(StatusCode::SEE_OTHER));
        let cookie = login_response.headers().get(SET_COOKIE).unwrap();

        // With the cookie the home page shows the stored user name.
        let mut home_response = TestClient::get("http://127.0.0.1:8698/")
            .add_header(COOKIE, cookie, true)
            .send(&service)
            .await;
        assert_eq!(home_response.take_string().await.unwrap(), "salvo");

        let logout_response = TestClient::get("http://127.0.0.1:8698/logout")
            .send(&service)
            .await;
        assert_eq!(logout_response.status_code, Some(StatusCode::SEE_OTHER));

        // Without a cookie the home page shows the anonymous text.
        let mut anonymous_response = TestClient::get("http://127.0.0.1:8698/")
            .send(&service)
            .await;
        assert_eq!(anonymous_response.take_string().await.unwrap(), "home");
    }
}