1use std::{
2 cell::{Ref, RefCell},
3 error::Error as StdError,
4 mem,
5 rc::Rc,
6};
7
8use actix_utils::future::{ready, Ready};
9use actix_web::{
10 body::BoxBody,
11 dev::{Extensions, Payload, ServiceRequest, ServiceResponse},
12 error::Error,
13 FromRequest, HttpMessage, HttpRequest, HttpResponse, ResponseError,
14};
15use anyhow::Context;
16use derive_more::derive::{Display, From};
17use serde::{de::DeserializeOwned, Serialize};
18use serde_json::Value;
19
20use crate::storage::SessionState;
21
22#[derive(Clone)]
48pub struct Session(Rc<RefCell<SessionInner>>);
49
/// Lifecycle status of a [`Session`], tracking what has been requested for it so far.
#[derive(Debug, Clone, PartialEq, Eq)]
#[derive(Default)]
pub enum SessionStatus {
    /// The session state was modified (via `insert`, `remove` or `clear`).
    Changed,

    /// The session was flagged for deletion via `purge`; its state has been emptied.
    ///
    /// Once purged, `insert`, `remove`, `clear` and `renew` no longer have any effect.
    Purged,

    /// The session was flagged for renewal via `renew`.
    ///
    /// Later modifications keep this status rather than downgrading it to `Changed`.
    Renewed,

    /// Nothing has been requested for the session yet; this is the initial status.
    #[default]
    Unchanged,
}
72
73#[derive(Default)]
74struct SessionInner {
75 state: SessionState,
76 status: SessionStatus,
77}
78
79impl Session {
80 pub fn mock(state: SessionState, status: SessionStatus) -> Self {
82 Self(Rc::new(RefCell::new(SessionInner { state, status })))
83 }
84
85 pub fn get_value(&self, key: &str) -> Option<Value> {
87 self.0.borrow().state.get(key).cloned()
88 }
89
90 pub fn get<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>, SessionGetError> {
94 if let Some(value) = self.0.borrow().state.get(key) {
95 Ok(Some(
96 serde_json::from_value(value.to_owned())
97 .with_context(|| {
98 format!(
99 "Failed to deserialize the JSON-encoded session data attached to key \
100 `{}` as a `{}` type",
101 key,
102 std::any::type_name::<T>()
103 )
104 })
105 .map_err(SessionGetError)?,
106 ))
107 } else {
108 Ok(None)
109 }
110 }
111
112 pub fn entries(&self) -> Ref<'_, SessionState> {
116 Ref::map(self.0.borrow(), |inner| &inner.state)
117 }
118
119 pub fn status(&self) -> SessionStatus {
121 Ref::map(self.0.borrow(), |inner| &inner.status).clone()
122 }
123
124 pub fn insert<T: Serialize>(
131 &self,
132 key: impl Into<String>,
133 value: T,
134 ) -> Result<(), SessionInsertError> {
135 let mut inner = self.0.borrow_mut();
136
137 if inner.status != SessionStatus::Purged {
138 if inner.status != SessionStatus::Renewed {
139 inner.status = SessionStatus::Changed;
140 }
141
142 let key = key.into();
143 let val = serde_json::to_value(&value)
144 .with_context(|| {
145 format!(
146 "Failed to serialize the provided `{}` type instance as JSON in order to \
147 attach as session data to the `{}` key",
148 std::any::type_name::<T>(),
149 &key
150 )
151 })
152 .map_err(SessionInsertError)?;
153
154 inner.state.insert(key, val);
155 }
156
157 Ok(())
158 }
159
160 pub fn remove(&self, key: &str) -> Option<Value> {
164 let mut inner = self.0.borrow_mut();
165
166 if inner.status != SessionStatus::Purged {
167 if inner.status != SessionStatus::Renewed {
168 inner.status = SessionStatus::Changed;
169 }
170 return inner.state.remove(key);
171 }
172
173 None
174 }
175
176 pub fn remove_as<T: DeserializeOwned>(&self, key: &str) -> Option<Result<T, Value>> {
181 self.remove(key)
182 .map(|val| match serde_json::from_value(val.clone()) {
183 Ok(val) => Ok(val),
184 Err(_err) => {
185 tracing::debug!(
186 "Removed value (key: {}) could not be deserialized as {}",
187 key,
188 std::any::type_name::<T>()
189 );
190
191 Err(val)
192 }
193 })
194 }
195
196 pub fn clear(&self) {
198 let mut inner = self.0.borrow_mut();
199
200 if inner.status != SessionStatus::Purged {
201 if inner.status != SessionStatus::Renewed {
202 inner.status = SessionStatus::Changed;
203 }
204 inner.state.clear()
205 }
206 }
207
208 pub fn purge(&self) {
210 let mut inner = self.0.borrow_mut();
211 inner.status = SessionStatus::Purged;
212 inner.state.clear();
213 }
214
215 pub fn renew(&self) {
217 let mut inner = self.0.borrow_mut();
218
219 if inner.status != SessionStatus::Purged {
220 inner.status = SessionStatus::Renewed;
221 }
222 }
223
224 #[allow(clippy::needless_pass_by_ref_mut)]
229 pub(crate) fn set_session(
230 req: &mut ServiceRequest,
231 data: impl IntoIterator<Item = (String, Value)>,
232 ) {
233 let session = Session::get_session(&mut req.extensions_mut());
234 let mut inner = session.0.borrow_mut();
235 inner
236 .state
237 .extend(data.into_iter().map(|(k, v)| (k, v.into())));
238 }
239
240 #[allow(clippy::needless_pass_by_ref_mut)]
246 pub(crate) fn get_changes<B>(res: &mut ServiceResponse<B>) -> (SessionStatus, SessionState) {
247 if let Some(s_impl) = res
248 .request()
249 .extensions()
250 .get::<Rc<RefCell<SessionInner>>>()
251 {
252 let state = mem::take(&mut s_impl.borrow_mut().state);
253 (s_impl.borrow().status.clone(), state)
254 } else {
255 (SessionStatus::Unchanged, SessionState::new())
256 }
257 }
258
259 pub(crate) fn get_session(extensions: &mut Extensions) -> Session {
260 if let Some(s_impl) = extensions.get::<Rc<RefCell<SessionInner>>>() {
261 return Session(Rc::clone(s_impl));
262 }
263
264 let inner = Rc::new(RefCell::new(SessionInner::default()));
265 extensions.insert(inner.clone());
266
267 Session(inner)
268 }
269}
270
271impl FromRequest for Session {
292 type Error = Error;
293 type Future = Ready<Result<Session, Error>>;
294
295 #[inline]
296 fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future {
297 ready(Ok(Session::get_session(&mut req.extensions_mut())))
298 }
299}
300
301#[derive(Debug, Display, From)]
303#[display("{_0}")]
304pub struct SessionGetError(anyhow::Error);
305
306impl StdError for SessionGetError {
307 fn source(&self) -> Option<&(dyn StdError + 'static)> {
308 Some(self.0.as_ref())
309 }
310}
311
312impl ResponseError for SessionGetError {
313 fn error_response(&self) -> HttpResponse<BoxBody> {
314 HttpResponse::new(self.status_code())
315 }
316}
317
318#[derive(Debug, Display, From)]
320#[display("{_0}")]
321pub struct SessionInsertError(anyhow::Error);
322
323impl StdError for SessionInsertError {
324 fn source(&self) -> Option<&(dyn StdError + 'static)> {
325 Some(self.0.as_ref())
326 }
327}
328
329impl ResponseError for SessionInsertError {
330 fn error_response(&self) -> HttpResponse<BoxBody> {
331 HttpResponse::new(self.status_code())
332 }
333}