use std::{
collections::BTreeMap,
fmt::{self, Debug, Formatter},
sync::Arc,
};
use parking_lot::RwLock;
use serde::{de::DeserializeOwned, Serialize};
use serde_json::Value;
use crate::{FromRequest, Request, RequestBody, Result};
/// Lifecycle state of a session, as tracked by the session's mutating methods.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SessionStatus {
    /// Entries were modified through `set`, `remove` or `clear`.
    Changed,
    /// `purge` was called: entries were cleared and later modifications are ignored.
    Purged,
    /// `renew` was called; later entry modifications keep this status.
    Renewed,
    /// Initial status: nothing has been modified since the session was created.
    Unchanged,
}
/// Mutable state of a `Session`, kept behind its shared lock.
struct SessionInner {
    /// Session data: entry name mapped to its JSON-serialized value.
    entries: BTreeMap<String, Value>,
    /// Current lifecycle status of the session.
    status: SessionStatus,
}
/// Handle to the data of one session.
///
/// The state lives behind `Arc<RwLock<..>>`, so every method takes `&self`. Cloning a
/// `Session` only clones the `Arc`; all clones observe and modify the same state.
#[derive(Clone)]
pub struct Session {
    /// Shared, lock-protected session state.
    inner: Arc<RwLock<SessionInner>>,
}
impl Debug for Session {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
let inner = self.inner.read();
f.debug_struct("Session")
.field("status", &inner.status)
.field("entries", &inner.entries)
.finish()
}
}
impl Default for Session {
fn default() -> Self {
Self::new(Default::default())
}
}
impl Session {
pub(crate) fn new(entries: BTreeMap<String, Value>) -> Self {
Self {
inner: Arc::new(RwLock::new(SessionInner {
status: SessionStatus::Unchanged,
entries,
})),
}
}
pub fn get<T: DeserializeOwned>(&self, name: &str) -> Option<T> {
let inner = self.inner.read();
inner
.entries
.get(name)
.and_then(|value| serde_json::from_value(value.clone()).ok())
}
pub fn set(&self, name: &str, value: impl Serialize) {
let mut inner = self.inner.write();
if inner.status != SessionStatus::Purged {
if let Ok(value) = serde_json::to_value(&value) {
inner.entries.insert(name.to_string(), value);
if inner.status != SessionStatus::Renewed {
inner.status = SessionStatus::Changed;
}
}
}
}
pub fn remove(&self, name: &str) {
let mut inner = self.inner.write();
if inner.status != SessionStatus::Purged {
inner.entries.remove(name);
if inner.status != SessionStatus::Renewed {
inner.status = SessionStatus::Changed;
}
}
}
pub fn is_empty(&self) -> bool {
let inner = self.inner.read();
inner.entries.is_empty()
}
pub fn entries(&self) -> BTreeMap<String, Value> {
let inner = self.inner.read();
inner.entries.clone()
}
pub fn clear(&self) {
let mut inner = self.inner.write();
if inner.status != SessionStatus::Purged {
inner.entries.clear();
if inner.status != SessionStatus::Renewed {
inner.status = SessionStatus::Changed;
}
}
}
pub fn renew(&self) {
let mut inner = self.inner.write();
if inner.status != SessionStatus::Purged {
inner.status = SessionStatus::Renewed;
}
}
pub fn purge(&self) {
let mut inner = self.inner.write();
if inner.status != SessionStatus::Purged {
inner.entries.clear();
inner.status = SessionStatus::Purged;
}
}
pub fn status(&self) -> SessionStatus {
let inner = self.inner.read();
inner.status
}
}
impl<'a> FromRequest<'a> for &'a Session {
    /// Borrows the [`Session`] stored in the request extensions.
    ///
    /// # Panics
    ///
    /// Panics if the request has no `Session` extension. Per the panic message, this
    /// means the `CookieSession` middleware is not installed.
    async fn from_request(req: &'a Request, _body: &mut RequestBody) -> Result<Self> {
        let extensions = req.extensions();
        let session = extensions
            .get::<Session>()
            .expect("To use the `Session` extractor, the `CookieSession` middleware is required.");
        Ok(session)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Snapshot of the session entries as an ordered list of `(name, value)` pairs.
    fn pairs(session: &Session) -> Vec<(String, Value)> {
        session.entries().into_iter().collect()
    }

    /// Builds an expected `(name, value)` pair with an integer value.
    fn kv(name: &str, value: i32) -> (String, Value) {
        (name.to_string(), value.into())
    }

    #[test]
    fn update_session() {
        let session = Session::default();

        session.set("a", 1);
        assert_eq!(session.status(), SessionStatus::Changed);
        assert_eq!(pairs(&session), vec![kv("a", 1)]);

        session.set("b", 2);
        assert_eq!(session.status(), SessionStatus::Changed);
        assert_eq!(pairs(&session), vec![kv("a", 1), kv("b", 2)]);

        // Once renewed, modifications keep the `Renewed` status.
        session.renew();
        session.set("c", 3);
        assert_eq!(session.status(), SessionStatus::Renewed);
        assert_eq!(pairs(&session), vec![kv("a", 1), kv("b", 2), kv("c", 3)]);

        session.remove("c");
        assert_eq!(session.status(), SessionStatus::Renewed);
        assert_eq!(pairs(&session), vec![kv("a", 1), kv("b", 2)]);

        session.clear();
        assert_eq!(session.status(), SessionStatus::Renewed);
        assert_eq!(pairs(&session), Vec::new());
    }

    #[test]
    fn purge_session() {
        let session = Session::default();
        session.set("a", 1);
        session.set("b", 2);
        assert_eq!(session.status(), SessionStatus::Changed);
        assert_eq!(pairs(&session), vec![kv("a", 1), kv("b", 2)]);

        // After a purge every mutation is ignored and the status stays `Purged`.
        session.purge();
        session.set("c", 3);
        assert_eq!(session.status(), SessionStatus::Purged);
        assert_eq!(pairs(&session), Vec::new());

        session.clear();
        assert_eq!(session.status(), SessionStatus::Purged);
        assert_eq!(pairs(&session), Vec::new());

        session.set("d", 4);
        assert_eq!(session.status(), SessionStatus::Purged);
        assert_eq!(pairs(&session), Vec::new());

        session.remove("d");
        assert_eq!(session.status(), SessionStatus::Purged);
        assert_eq!(pairs(&session), Vec::new());
    }

    #[test]
    fn get_session_values() {
        let session = Session::default();
        assert!(session.is_empty());
        assert_eq!(session.status(), SessionStatus::Unchanged);
        assert_eq!(session.get::<i32>("a"), None);

        session.set("a", 1);
        session.set("name", "alice");
        assert!(!session.is_empty());
        assert_eq!(session.get::<i32>("a"), Some(1));
        assert_eq!(session.get::<String>("name"), Some("alice".to_string()));
        // An entry that cannot be deserialized as the requested type yields `None`.
        assert_eq!(session.get::<String>("a"), None);
        assert_eq!(session.get::<i32>("name"), None);
        assert_eq!(session.get::<i32>("missing"), None);
    }

    #[test]
    fn renew_after_purge_keeps_purged() {
        let session = Session::default();
        session.purge();
        session.renew();
        assert_eq!(session.status(), SessionStatus::Purged);
        assert!(session.is_empty());
    }
}