1use std::{collections::HashMap, convert::Infallible, rc::Rc};
19
20use cookie::{Cookie, CookieJar, Key, SameSite};
21use derive_more::{Display, From};
22use ntex::http::{HttpMessage, header::HeaderValue, header::SET_COOKIE};
23use ntex::service::{Middleware, Service, ServiceCtx};
24use ntex::web::{DefaultError, ErrorRenderer, WebRequest, WebResponse, WebResponseError};
25use serde_json::error::Error as JsonError;
26use time::{Duration, OffsetDateTime};
27
28use crate::{Session, SessionStatus};
29
30#[derive(Debug, From, Display)]
32pub enum CookieSessionError {
33 #[display("Size of the serialized session is greater than 4000 bytes.")]
35 Overflow,
36 #[display("Fail to serialize session")]
38 Serialize(JsonError),
39}
40
// Empty impl: every `WebResponseError` method uses the trait's default, so a
// `CookieSessionError` is rendered without a custom status code or body.
impl WebResponseError<DefaultError> for CookieSessionError {}
42
/// How the session cookie is protected with the configured key.
///
/// Selected by `CookieSession::signed` / `CookieSession::private`.
enum CookieSecurity {
    /// Cookie is written through `CookieJar::signed_mut` and read back through
    /// `CookieJar::signed`.
    Signed,
    /// Cookie is written through `CookieJar::private_mut` and read back through
    /// `CookieJar::private`.
    Private,
}
47
/// Configuration shared by a `CookieSession` and every middleware it creates.
struct CookieSessionInner {
    /// Key used to sign or encrypt the cookie, depending on `security`.
    key: Key,
    /// Whether the cookie is signed or private.
    security: CookieSecurity,
    /// Name of the session cookie.
    name: String,
    /// Value of the cookie's `Path` attribute.
    path: String,
    /// Value of the `Domain` attribute; the attribute is omitted when `None`.
    domain: Option<String>,
    /// Whether the `Secure` attribute is set.
    secure: bool,
    /// Whether the `HttpOnly` attribute is set.
    http_only: bool,
    /// Value of the `Max-Age` attribute; omitted when `None`.
    max_age: Option<Duration>,
    /// When set, `Expires` is written as "now + expires_in" and unchanged
    /// sessions are re-sent so the expiry keeps moving forward.
    expires_in: Option<Duration>,
    /// Value of the `SameSite` attribute; omitted when `None`.
    same_site: Option<SameSite>,
}
60
61impl CookieSessionInner {
62 fn new(key: &[u8], security: CookieSecurity) -> Self {
63 CookieSessionInner {
64 security,
65 key: Key::derive_from(key),
66 name: "ntex-session".to_owned(),
67 path: "/".to_owned(),
68 domain: None,
69 secure: true,
70 http_only: true,
71 max_age: None,
72 expires_in: None,
73 same_site: None,
74 }
75 }
76
77 fn set_cookie(
78 &self,
79 res: &mut WebResponse,
80 state: impl Iterator<Item = (String, String)>,
81 ) -> Result<(), CookieSessionError> {
82 let state: HashMap<String, String> = state.collect();
83 let value = serde_json::to_string(&state).map_err(CookieSessionError::Serialize)?;
84 if value.len() > 4064 {
85 return Err(CookieSessionError::Overflow);
86 }
87
88 let mut cookie = Cookie::new(self.name.clone(), value);
89 cookie.set_path(self.path.clone());
90 cookie.set_secure(self.secure);
91 cookie.set_http_only(self.http_only);
92
93 if let Some(ref domain) = self.domain {
94 cookie.set_domain(domain.clone());
95 }
96
97 if let Some(expires_in) = self.expires_in {
98 cookie.set_expires(OffsetDateTime::now_utc() + expires_in);
99 }
100
101 if let Some(max_age) = self.max_age {
102 cookie.set_max_age(max_age);
103 }
104
105 if let Some(same_site) = self.same_site {
106 cookie.set_same_site(same_site);
107 }
108
109 let mut jar = CookieJar::new();
110
111 match self.security {
112 CookieSecurity::Signed => jar.signed_mut(&self.key).add(cookie),
113 CookieSecurity::Private => jar.private_mut(&self.key).add(cookie),
114 }
115
116 for cookie in jar.delta() {
117 let val = HeaderValue::from_str(&cookie.encoded().to_string()).unwrap();
118 res.headers_mut().append(SET_COOKIE, val);
119 }
120
121 Ok(())
122 }
123
124 fn remove_cookie(&self, res: &mut WebResponse) -> Result<(), Infallible> {
126 let mut cookie = Cookie::from(self.name.clone());
127 cookie.set_value("");
128 cookie.set_max_age(Duration::ZERO);
129 cookie.set_expires(OffsetDateTime::now_utc() - Duration::days(365));
130 cookie.set_path(&self.path);
131
132 if let Some(ref domain) = self.domain {
133 cookie.set_domain(domain);
134 }
135
136 let val = HeaderValue::from_str(&cookie.to_string()).unwrap();
137 res.headers_mut().append(SET_COOKIE, val);
138
139 Ok(())
140 }
141
142 fn load<Err>(&self, req: &WebRequest<Err>) -> (bool, HashMap<String, String>) {
143 if let Ok(cookies) = req.cookies() {
144 for cookie in cookies.iter() {
145 if cookie.name() == self.name {
146 let mut jar = CookieJar::new();
147 jar.add_original(cookie.clone());
148
149 let cookie_opt = match self.security {
150 CookieSecurity::Signed => jar.signed(&self.key).get(&self.name),
151 CookieSecurity::Private => jar.private(&self.key).get(&self.name),
152 };
153 if let Some(cookie) = cookie_opt {
154 if let Ok(val) = serde_json::from_str(cookie.value()) {
155 return (false, val);
156 }
157 }
158 }
159 }
160 }
161 (true, HashMap::new())
162 }
163}
164
/// Session middleware that keeps the whole session state in a single cookie.
///
/// Create it with `CookieSession::signed` or `CookieSession::private`, adjust it
/// with the builder methods, then pass it to `App::wrap`.
pub struct CookieSession(Rc<CookieSessionInner>);
206
207impl CookieSession {
208 pub fn signed(key: &[u8]) -> Self {
212 CookieSession(Rc::new(CookieSessionInner::new(key, CookieSecurity::Signed)))
213 }
214
215 pub fn private(key: &[u8]) -> Self {
219 CookieSession(Rc::new(CookieSessionInner::new(key, CookieSecurity::Private)))
220 }
221
222 pub fn path<S: Into<String>>(mut self, value: S) -> Self {
224 Rc::get_mut(&mut self.0).unwrap().path = value.into();
225 self
226 }
227
228 pub fn name<S: Into<String>>(mut self, value: S) -> Self {
230 Rc::get_mut(&mut self.0).unwrap().name = value.into();
231 self
232 }
233
234 pub fn domain<S: Into<String>>(mut self, value: S) -> Self {
236 Rc::get_mut(&mut self.0).unwrap().domain = Some(value.into());
237 self
238 }
239
240 pub fn secure(mut self, value: bool) -> Self {
245 Rc::get_mut(&mut self.0).unwrap().secure = value;
246 self
247 }
248
249 pub fn http_only(mut self, value: bool) -> Self {
251 Rc::get_mut(&mut self.0).unwrap().http_only = value;
252 self
253 }
254
255 pub fn same_site(mut self, value: SameSite) -> Self {
257 Rc::get_mut(&mut self.0).unwrap().same_site = Some(value);
258 self
259 }
260
261 pub fn max_age(self, seconds: i64) -> Self {
263 self.max_age_time(Duration::seconds(seconds))
264 }
265
266 pub fn max_age_time(mut self, value: time::Duration) -> Self {
268 Rc::get_mut(&mut self.0).unwrap().max_age = Some(value);
269 self
270 }
271
272 pub fn expires_in(self, seconds: i64) -> Self {
274 self.expires_in_time(Duration::seconds(seconds))
275 }
276
277 pub fn expires_in_time(mut self, value: Duration) -> Self {
279 Rc::get_mut(&mut self.0).unwrap().expires_in = Some(value);
280 self
281 }
282}
283
284impl<S, C> Middleware<S, C> for CookieSession {
285 type Service = CookieSessionMiddleware<S>;
286
287 fn create(&self, service: S, _: C) -> Self::Service {
288 CookieSessionMiddleware { service, inner: self.0.clone() }
289 }
290}
291
/// Service created by `CookieSession`. It loads the session cookie before
/// calling the wrapped service and writes or removes it afterwards.
pub struct CookieSessionMiddleware<S> {
    /// The wrapped service.
    service: S,
    /// Configuration shared with the `CookieSession` that created this middleware.
    inner: Rc<CookieSessionInner>,
}
297
298impl<S, Err> Service<WebRequest<Err>> for CookieSessionMiddleware<S>
299where
300 S: Service<WebRequest<Err>, Response = WebResponse>,
301 S::Error: 'static,
302 Err: ErrorRenderer,
303 Err::Container: From<CookieSessionError>,
304{
305 type Response = WebResponse;
306 type Error = S::Error;
307
308 ntex::forward_ready!(service);
309 ntex::forward_shutdown!(service);
310
311 async fn call(
317 &self,
318 req: WebRequest<Err>,
319 ctx: ServiceCtx<'_, Self>,
320 ) -> Result<Self::Response, Self::Error> {
321 let inner = self.inner.clone();
322 let (is_new, state) = self.inner.load(&req);
323 let prolong_expiration = self.inner.expires_in.is_some();
324 Session::set_session(state.into_iter(), &req);
325
326 ctx.call(&self.service, req).await.map(|mut res| {
327 match Session::get_changes(&mut res) {
328 (SessionStatus::Changed, Some(state))
329 | (SessionStatus::Renewed, Some(state)) => {
330 res.checked_expr::<Err, _, _>(|res| inner.set_cookie(res, state))
331 }
332 (SessionStatus::Unchanged, Some(state)) if prolong_expiration => {
333 res.checked_expr::<Err, _, _>(|res| inner.set_cookie(res, state))
334 }
335 (SessionStatus::Unchanged, _) =>
336 {
338 if is_new {
339 let state: HashMap<String, String> = HashMap::new();
340 res.checked_expr::<Err, _, _>(|res| {
341 inner.set_cookie(res, state.into_iter())
342 })
343 } else {
344 res
345 }
346 }
347 (SessionStatus::Purged, _) => {
348 let _ = inner.remove_cookie(&mut res);
349 res
350 }
351 _ => res,
352 }
353 })
354 }
355}
356
#[cfg(test)]
mod tests {
    use super::*;
    use ntex::web::{self, App, test};
    use ntex::{time, util::Bytes};

    #[ntex::test]
    async fn cookie_session() {
        let app = test::init_service(
            App::new().wrap(CookieSession::signed(&[0; 32]).secure(false)).service(
                web::resource("/").to(|ses: Session| async move {
                    let _ = ses.set("counter", 100);
                    "test"
                }),
            ),
        )
        .await;

        let request = test::TestRequest::get().to_request();
        let response = app.call(request).await.unwrap();
        assert!(response.response().cookies().any(|c| c.name() == "ntex-session"));
    }

    #[ntex::test]
    async fn private_cookie() {
        let app = test::init_service(
            App::new().wrap(CookieSession::private(&[0; 32]).secure(false)).service(
                web::resource("/").to(|ses: Session| async move {
                    let _ = ses.set("counter", 100);
                    "test"
                }),
            ),
        )
        .await;

        let request = test::TestRequest::get().to_request();
        let response = app.call(request).await.unwrap();
        assert!(response.response().cookies().any(|c| c.name() == "ntex-session"));
    }

    #[ntex::test]
    async fn cookie_session_extractor() {
        let app = test::init_service(
            App::new().wrap(CookieSession::signed(&[0; 32]).secure(false)).service(
                web::resource("/").to(|ses: Session| async move {
                    let _ = ses.set("counter", 100);
                    "test"
                }),
            ),
        )
        .await;

        let request = test::TestRequest::get().to_request();
        let response = app.call(request).await.unwrap();
        assert!(response.response().cookies().any(|c| c.name() == "ntex-session"));
    }

    #[ntex::test]
    async fn basics() {
        let app = test::init_service(
            App::new()
                .wrap(
                    CookieSession::signed(&[0; 32])
                        .path("/test/")
                        .name("ntex-test")
                        .domain("localhost")
                        .http_only(true)
                        .same_site(SameSite::Lax)
                        .max_age(100),
                )
                .service(web::resource("/").to(|ses: Session| async move {
                    let _ = ses.set("counter", 100);
                    "test"
                }))
                .service(web::resource("/test/").to(|ses: Session| async move {
                    let val: usize = ses.get("counter").unwrap().unwrap();
                    format!("counter: {}", val)
                })),
        )
        .await;

        let request = test::TestRequest::get().to_request();
        let response = app.call(request).await.unwrap();
        let cookie = response
            .response()
            .cookies()
            .find(|c| c.name() == "ntex-test")
            .unwrap()
            .into_owned();
        assert_eq!(cookie.path().unwrap(), "/test/");

        let request = test::TestRequest::with_uri("/test/").cookie(cookie).to_request();
        let body = test::read_response(&app, request).await;
        assert_eq!(body, Bytes::from_static(b"counter: 100"));
    }

    /// When `expires_in` is configured, an unchanged session is re-issued with a
    /// later `Expires`.
    ///
    /// The second request must carry the cookie from the first response.
    /// Without it the middleware treats the client as new and issues a fresh
    /// empty session, so the prolongation path would never be exercised.
    #[ntex::test]
    async fn prolong_expiration() {
        let app = test::init_service(
            App::new()
                .wrap(CookieSession::signed(&[0; 32]).secure(false).expires_in(60))
                .service(web::resource("/").to(|ses: Session| async move {
                    let _ = ses.set("counter", 100);
                    "test"
                }))
                .service(web::resource("/test/").to(|| async move { "no-changes-in-session" })),
        )
        .await;

        let request = test::TestRequest::get().to_request();
        let response = app.call(request).await.unwrap();
        let cookie = response
            .response()
            .cookies()
            .find(|c| c.name() == "ntex-session")
            .expect("Cookie is set")
            .into_owned();
        let expires_1 = cookie.expires().expect("Expiration is set");

        time::sleep(time::Seconds::ONE).await;

        let request = test::TestRequest::with_uri("/test/").cookie(cookie).to_request();
        let response = app.call(request).await.unwrap();
        let expires_2 = response
            .response()
            .cookies()
            .find(|c| c.name() == "ntex-session")
            .expect("Cookie is set")
            .expires()
            .expect("Expiration is set");

        assert!(
            expires_2.datetime().unwrap() - expires_1.datetime().unwrap()
                >= Duration::seconds(1)
        );
    }
}