1use std::cell::RefCell;
45use std::collections::HashMap;
46use std::convert::Infallible;
47use std::rc::Rc;
48
49use futures::future::{ok, Ready};
50use ntex::http::{Payload, RequestHead};
51use ntex::util::Extensions;
52use ntex::web::{Error, FromRequest, HttpRequest, WebRequest, WebResponse};
53use serde::de::DeserializeOwned;
54use serde::Serialize;
55
56#[cfg(feature = "cookie-session")]
57mod cookie;
58#[cfg(feature = "cookie-session")]
59pub use crate::cookie::CookieSession;
60
61pub struct Session(Rc<RefCell<SessionInner>>);
84
85pub trait UserSession {
87 fn get_session(&self) -> Session;
88}
89
90impl UserSession for HttpRequest {
91 fn get_session(&self) -> Session {
92 Session::get_session(&mut self.extensions_mut())
93 }
94}
95
96impl<Err> UserSession for WebRequest<Err> {
97 fn get_session(&self) -> Session {
98 Session::get_session(&mut self.extensions_mut())
99 }
100}
101
102impl UserSession for RequestHead {
103 fn get_session(&self) -> Session {
104 Session::get_session(&mut self.extensions_mut())
105 }
106}
107
/// Modification status of a [`Session`] during the current request.
///
/// It is reported together with the accumulated state by [`Session::get_changes`].
#[derive(PartialEq, Eq, Clone, Debug)]
pub enum SessionStatus {
    /// The state was modified by `set`, `remove` or `clear`.
    Changed,
    /// `purge` was called: the state was cleared and later writes are ignored.
    Purged,
    /// `renew` was called.
    Renewed,
    /// No modification has been recorded (the default).
    Unchanged,
}

impl Default for SessionStatus {
    /// A fresh session starts out [`SessionStatus::Unchanged`].
    #[inline]
    fn default() -> Self {
        Self::Unchanged
    }
}
122
123#[derive(Default)]
124struct SessionInner {
125 state: HashMap<String, String>,
126 pub status: SessionStatus,
127}
128
129impl Session {
130 pub fn get<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>, Error> {
132 if let Some(s) = self.0.borrow().state.get(key) {
133 Ok(Some(serde_json::from_str(s)?))
134 } else {
135 Ok(None)
136 }
137 }
138
139 pub fn set<T: Serialize>(&self, key: &str, value: T) -> Result<(), Error> {
141 let mut inner = self.0.borrow_mut();
142 if inner.status != SessionStatus::Purged {
143 inner.status = SessionStatus::Changed;
144 inner.state.insert(key.to_owned(), serde_json::to_string(&value)?);
145 }
146 Ok(())
147 }
148
149 pub fn remove(&self, key: &str) {
151 let mut inner = self.0.borrow_mut();
152 if inner.status != SessionStatus::Purged {
153 inner.status = SessionStatus::Changed;
154 inner.state.remove(key);
155 }
156 }
157
158 pub fn clear(&self) {
160 let mut inner = self.0.borrow_mut();
161 if inner.status != SessionStatus::Purged {
162 inner.status = SessionStatus::Changed;
163 inner.state.clear()
164 }
165 }
166
167 pub fn purge(&self) {
169 let mut inner = self.0.borrow_mut();
170 inner.status = SessionStatus::Purged;
171 inner.state.clear();
172 }
173
174 pub fn renew(&self) {
176 let mut inner = self.0.borrow_mut();
177 if inner.status != SessionStatus::Purged {
178 inner.status = SessionStatus::Renewed;
179 }
180 }
181
182 pub fn set_session<Err>(
183 data: impl Iterator<Item = (String, String)>,
184 req: &WebRequest<Err>,
185 ) {
186 let session = Session::get_session(&mut req.extensions_mut());
187 let mut inner = session.0.borrow_mut();
188 inner.state.extend(data);
189 }
190
191 pub fn get_changes(
192 res: &mut WebResponse,
193 ) -> (SessionStatus, Option<impl Iterator<Item = (String, String)>>) {
194 if let Some(s_impl) = res.request().extensions().get::<Rc<RefCell<SessionInner>>>() {
195 let state = std::mem::take(&mut s_impl.borrow_mut().state);
196 (s_impl.borrow().status.clone(), Some(state.into_iter()))
197 } else {
198 (SessionStatus::Unchanged, None)
199 }
200 }
201
202 fn get_session(extensions: &mut Extensions) -> Session {
203 if let Some(s_impl) = extensions.get::<Rc<RefCell<SessionInner>>>() {
204 return Session(Rc::clone(s_impl));
205 }
206 let inner = Rc::new(RefCell::new(SessionInner::default()));
207 extensions.insert(inner.clone());
208 Session(inner)
209 }
210}
211
212impl<Err> FromRequest<Err> for Session {
230 type Error = Infallible;
231 type Future = Ready<Result<Session, Infallible>>;
232
233 #[inline]
234 fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future {
235 ok(Session::get_session(&mut req.extensions_mut()))
236 }
237}
238
#[cfg(test)]
mod tests {
    use ntex::web::{test, HttpResponse};

    use super::*;

    /// Seeded values are readable. `get_changes` reports the state left after
    /// `set`/`remove`, together with the `Changed` status.
    #[test]
    fn session() {
        let req = test::TestRequest::default().to_srv_request();

        Session::set_session(
            vec![("key".to_string(), "\"value\"".to_string())].into_iter(),
            &req,
        );
        let session = Session::get_session(&mut req.extensions_mut());
        // Seeding the state must not mark the session as modified.
        assert_eq!(session.0.borrow().status, SessionStatus::Unchanged);
        let res = session.get::<String>("key").unwrap();
        assert_eq!(res, Some("value".to_string()));

        session.set("key2", "value2".to_string()).unwrap();
        session.remove("key");

        let mut res = req.into_response(HttpResponse::Ok().finish());
        let (status, state) = Session::get_changes(&mut res);
        assert_eq!(status, SessionStatus::Changed);
        let changes: Vec<_> = state.unwrap().collect();
        assert_eq!(changes, [("key2".to_string(), "\"value2\"".to_string())]);
    }

    /// A request that never touched its session reports no changes.
    #[test]
    fn changes_without_session() {
        let req = test::TestRequest::default().to_srv_request();
        let mut res = req.into_response(HttpResponse::Ok().finish());
        let (status, state) = Session::get_changes(&mut res);
        assert_eq!(status, SessionStatus::Unchanged);
        assert!(state.is_none());
    }

    #[test]
    fn get_session() {
        let req = test::TestRequest::default().to_srv_request();

        Session::set_session(
            vec![("key".to_string(), "\"value\"".to_string())].into_iter(),
            &req,
        );

        let session = req.get_session();
        let res = session.get::<String>("key").unwrap();
        assert_eq!(res, Some("value".to_string()));
    }

    #[test]
    fn get_session_from_request_head() {
        let mut req = test::TestRequest::default().to_srv_request();

        Session::set_session(
            vec![("key".to_string(), "\"value\"".to_string())].into_iter(),
            &req,
        );

        let session = req.head_mut().get_session();
        let res = session.get::<String>("key").unwrap();
        assert_eq!(res, Some("value".to_string()));
    }

    /// Purging drops existing data, and later writes or renewals are ignored.
    #[test]
    fn purge_session() {
        let req = test::TestRequest::default().to_srv_request();
        let session = Session::get_session(&mut req.extensions_mut());
        assert_eq!(session.0.borrow().status, SessionStatus::Unchanged);

        session.set("key", "value").unwrap();
        session.purge();
        assert_eq!(session.0.borrow().status, SessionStatus::Purged);
        assert_eq!(session.get::<String>("key").unwrap(), None);

        session.set("key", "value").unwrap();
        session.renew();
        assert_eq!(session.get::<String>("key").unwrap(), None);
        assert_eq!(session.0.borrow().status, SessionStatus::Purged);
    }

    #[test]
    fn renew_session() {
        let req = test::TestRequest::default().to_srv_request();
        let session = Session::get_session(&mut req.extensions_mut());
        assert_eq!(session.0.borrow().status, SessionStatus::Unchanged);
        session.renew();
        assert_eq!(session.0.borrow().status, SessionStatus::Renewed);
    }
}