1use std::cell::RefCell;
45use std::collections::HashMap;
46use std::convert::Infallible;
47use std::rc::Rc;
48
49use ntex::http::{Payload, RequestHead};
50use ntex::util::Extensions;
51use ntex::web::{Error, FromRequest, HttpRequest, WebRequest, WebResponse};
52use serde::Serialize;
53use serde::de::DeserializeOwned;
54
55#[cfg(feature = "cookie-session")]
56mod cookie;
57#[cfg(feature = "cookie-session")]
58pub use crate::cookie::CookieSession;
59
/// Handle to the session data attached to the current request.
///
/// Every `Session` obtained for the same request shares a single inner state
/// through `Rc<RefCell<..>>`, so modifications made through one handle are
/// visible through all of them. Because of the `Rc`, handles are not `Send`.
pub struct Session(Rc<RefCell<SessionInner>>);
83
/// Extension trait for obtaining the [`Session`] attached to a request.
///
/// The implementations in this module look the session up in the request's
/// extensions and create an empty one on first access.
pub trait UserSession {
    /// Returns a handle to this request's session.
    fn get_session(&self) -> Session;
}
88
89impl UserSession for HttpRequest {
90 fn get_session(&self) -> Session {
91 Session::get_session(&mut self.extensions_mut())
92 }
93}
94
95impl<Err> UserSession for WebRequest<Err> {
96 fn get_session(&self) -> Session {
97 Session::get_session(&mut self.extensions_mut())
98 }
99}
100
101impl UserSession for RequestHead {
102 fn get_session(&self) -> Session {
103 Session::get_session(&mut self.extensions_mut())
104 }
105}
106
/// How a [`Session`] was modified while handling a request.
///
/// The default value is [`SessionStatus::Unchanged`].
#[derive(PartialEq, Eq, Clone, Debug, Default)]
pub enum SessionStatus {
    /// Entries were inserted, removed or cleared.
    Changed,
    /// The session was purged: its entries were discarded and further
    /// modifications are ignored.
    Purged,
    /// A renewal of the session was requested.
    Renewed,
    /// No modification was recorded.
    #[default]
    Unchanged,
}
121
122#[derive(Default)]
123struct SessionInner {
124 state: HashMap<String, String>,
125 pub status: SessionStatus,
126}
127
128impl Session {
129 pub fn get<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>, Error> {
131 if let Some(s) = self.0.borrow().state.get(key) {
132 Ok(Some(serde_json::from_str(s)?))
133 } else {
134 Ok(None)
135 }
136 }
137
138 pub fn set<T: Serialize>(&self, key: &str, value: T) -> Result<(), Error> {
140 let mut inner = self.0.borrow_mut();
141 if inner.status != SessionStatus::Purged {
142 inner.status = SessionStatus::Changed;
143 inner.state.insert(key.to_owned(), serde_json::to_string(&value)?);
144 }
145 Ok(())
146 }
147
148 pub fn remove(&self, key: &str) {
150 let mut inner = self.0.borrow_mut();
151 if inner.status != SessionStatus::Purged {
152 inner.status = SessionStatus::Changed;
153 inner.state.remove(key);
154 }
155 }
156
157 pub fn clear(&self) {
159 let mut inner = self.0.borrow_mut();
160 if inner.status != SessionStatus::Purged {
161 inner.status = SessionStatus::Changed;
162 inner.state.clear()
163 }
164 }
165
166 pub fn purge(&self) {
168 let mut inner = self.0.borrow_mut();
169 inner.status = SessionStatus::Purged;
170 inner.state.clear();
171 }
172
173 pub fn renew(&self) {
175 let mut inner = self.0.borrow_mut();
176 if inner.status != SessionStatus::Purged {
177 inner.status = SessionStatus::Renewed;
178 }
179 }
180
181 pub fn set_session<Err>(
182 data: impl Iterator<Item = (String, String)>,
183 req: &WebRequest<Err>,
184 ) {
185 let session = Session::get_session(&mut req.extensions_mut());
186 let mut inner = session.0.borrow_mut();
187 inner.state.extend(data);
188 }
189
190 pub fn get_changes(
191 res: &mut WebResponse,
192 ) -> (SessionStatus, Option<impl Iterator<Item = (String, String)> + use<>>) {
193 if let Some(s_impl) = res.request().extensions().get::<Rc<RefCell<SessionInner>>>() {
194 let state = std::mem::take(&mut s_impl.borrow_mut().state);
195 (s_impl.borrow().status.clone(), Some(state.into_iter()))
196 } else {
197 (SessionStatus::Unchanged, None)
198 }
199 }
200
201 fn get_session(extensions: &mut Extensions) -> Session {
202 if let Some(s_impl) = extensions.get::<Rc<RefCell<SessionInner>>>() {
203 return Session(Rc::clone(s_impl));
204 }
205 let inner = Rc::new(RefCell::new(SessionInner::default()));
206 extensions.insert(inner.clone());
207 Session(inner)
208 }
209}
210
211impl<Err> FromRequest<Err> for Session {
229 type Error = Infallible;
230
231 #[inline]
232 async fn from_request(req: &HttpRequest, _: &mut Payload) -> Result<Session, Infallible> {
233 Ok(Session::get_session(&mut req.extensions_mut()))
234 }
235}
236
#[cfg(test)]
mod tests {
    use ntex::web::{HttpResponse, test};

    use super::*;

    /// Seeds `req`'s session with a single JSON-encoded `key => "value"` entry.
    fn seed_session<Err>(req: &WebRequest<Err>) {
        let entry = ("key".to_string(), "\"value\"".to_string());
        Session::set_session(std::iter::once(entry), req);
    }

    /// Asserts that `session` holds the seeded `"value"` under `key`.
    fn assert_seeded(session: &Session) {
        let value: Option<String> = session.get("key").unwrap();
        assert_eq!(value, Some("value".to_string()));
    }

    /// Reads the current status of `session`.
    fn status_of(session: &Session) -> SessionStatus {
        session.0.borrow().status.clone()
    }

    #[test]
    fn session() {
        let req = test::TestRequest::default().to_srv_request();
        seed_session(&req);

        let session = Session::get_session(&mut req.extensions_mut());
        assert_seeded(&session);

        session.set("key2", "value2".to_string()).unwrap();
        session.remove("key");

        let mut response = req.into_response(HttpResponse::Ok().finish());
        let (_status, changes) = Session::get_changes(&mut response);
        let changes: Vec<(String, String)> = changes.unwrap().collect();
        assert_eq!(changes, [("key2".to_string(), "\"value2\"".to_string())]);
    }

    #[test]
    fn get_session() {
        let req = test::TestRequest::default().to_srv_request();
        seed_session(&req);

        assert_seeded(&req.get_session());
    }

    #[test]
    fn get_session_from_request_head() {
        let mut req = test::TestRequest::default().to_srv_request();
        seed_session(&req);

        assert_seeded(&req.head_mut().get_session());
    }

    #[test]
    fn purge_session() {
        let req = test::TestRequest::default().to_srv_request();
        let session = Session::get_session(&mut req.extensions_mut());
        assert_eq!(status_of(&session), SessionStatus::Unchanged);

        session.purge();
        assert_eq!(status_of(&session), SessionStatus::Purged);
    }

    #[test]
    fn renew_session() {
        let req = test::TestRequest::default().to_srv_request();
        let session = Session::get_session(&mut req.extensions_mut());
        assert_eq!(status_of(&session), SessionStatus::Unchanged);

        session.renew();
        assert_eq!(status_of(&session), SessionStatus::Renewed);
    }
}