use std::pin::Pin;
use std::sync::{Arc, Mutex, PoisonError, Weak};
use std::thread;
use std::time::{Duration, Instant};
use futures_util::future::{self, FutureExt};
use linked_hash_map::LinkedHashMap;
use log::trace;
use crate::middleware::session::backend::{
Backend, GetSessionFuture, NewBackend, SetSessionFuture,
};
use crate::middleware::session::SessionIdentifier;
use crate::state::State;
/// Session storage shared by every clone of a [`MemoryBackend`].
///
/// Keys are session identifier values; each value pairs the time the session was last
/// written or read with its serialized bytes. The map is kept in access order (least
/// recently used at the front), which lets the cleanup thread inspect only the front entry.
type MemoryMap = Mutex<LinkedHashMap<String, (Instant, Vec<u8>)>>;

/// An in-memory session backend.
///
/// Sessions live only in this process's memory and are lost on restart. Cloning the
/// backend is cheap and yields a handle to the same underlying storage.
#[derive(Clone)]
pub struct MemoryBackend {
    // Shared with the background cleanup thread, which holds only a `Weak` reference.
    storage: Arc<MemoryMap>,
}
impl MemoryBackend {
pub fn new(ttl: Duration) -> MemoryBackend {
let storage = Arc::new(Mutex::new(LinkedHashMap::new()));
{
let storage = Arc::downgrade(&storage);
thread::spawn(move || cleanup_loop(storage, ttl));
}
MemoryBackend { storage }
}
}
impl Default for MemoryBackend {
fn default() -> MemoryBackend {
MemoryBackend::new(Duration::from_secs(3600))
}
}
impl NewBackend for MemoryBackend {
type Instance = MemoryBackend;
fn new_backend(&self) -> anyhow::Result<Self::Instance> {
Ok(self.clone())
}
}
impl Backend for MemoryBackend {
    /// Stores `content` under `identifier`, replacing any existing data for that session
    /// and stamping the entry with the current time.
    ///
    /// The returned future is already resolved.
    ///
    /// # Panics
    ///
    /// Panics if the storage mutex has been poisoned.
    fn persist_session(
        &self,
        _: &State,
        identifier: SessionIdentifier,
        content: &[u8],
    ) -> Pin<Box<SetSessionFuture>> {
        let mut sessions = self.storage.lock().unwrap_or_else(|_poisoned| {
            unreachable!("session memory backend lock poisoned, HashMap panicked?")
        });
        sessions.insert(identifier.value, (Instant::now(), content.to_vec()));
        future::ok(()).boxed()
    }

    /// Looks up the session stored under `identifier`.
    ///
    /// On a hit, the entry is moved to the most-recently-used end of the map and its
    /// timestamp is reset to now; a copy of the stored bytes is returned. A miss resolves
    /// to `Ok(None)`.
    ///
    /// # Panics
    ///
    /// Panics if the storage mutex has been poisoned.
    fn read_session(&self, _: &State, identifier: SessionIdentifier) -> Pin<Box<GetSessionFuture>> {
        let mut sessions = self.storage.lock().unwrap_or_else(|_poisoned| {
            unreachable!("session memory backend lock poisoned, HashMap panicked?")
        });
        let content = sessions.get_refresh(&identifier.value).map(|entry| {
            // Touching the entry restarts its TTL clock.
            entry.0 = Instant::now();
            entry.1.clone()
        });
        future::ok(content).boxed()
    }

    /// Removes the session stored under `identifier`, if any.
    ///
    /// # Panics
    ///
    /// Panics if the storage mutex has been poisoned.
    fn drop_session(&self, _: &State, identifier: SessionIdentifier) -> Pin<Box<SetSessionFuture>> {
        let mut sessions = self.storage.lock().unwrap_or_else(|_poisoned| {
            unreachable!("session memory backend lock poisoned, HashMap panicked?")
        });
        sessions.remove(&identifier.value);
        future::ok(()).boxed()
    }
}
/// Body of the background eviction thread spawned by `MemoryBackend::new`.
///
/// Repeatedly runs [`cleanup_once`] and sleeps for the interval it returns. When a pass
/// evicts a session (returns `None`), the next pass runs immediately.
///
/// The thread holds only a `Weak` reference, so it never keeps the storage alive on its own.
/// It exits when the storage can no longer be upgraded, because every backend sharing it
/// has been dropped, or when the storage mutex is poisoned.
fn cleanup_loop(storage: Weak<MemoryMap>, ttl: Duration) {
    loop {
        let strong = match storage.upgrade() {
            None => break,
            Some(strong) => strong,
        };
        // The lock guard is a temporary of this statement and is released at its end.
        let duration = match strong.lock() {
            Err(PoisonError { .. }) => break,
            Ok(mut map) => cleanup_once(&mut map, ttl),
        };
        // Give up our strong reference *before* sleeping. Otherwise, dropping the last
        // backend would leave the whole session map allocated until this thread next
        // woke up, which can be up to `ttl` later.
        drop(strong);
        if let Some(duration) = duration {
            thread::sleep(duration);
        }
    }
}
/// Performs a single eviction pass over `storage`, which is ordered least recently used
/// first.
///
/// If the front (least recently used) session is at least `ttl` old, it is removed and
/// `None` is returned, so the caller can run another pass immediately.
///
/// Otherwise nothing is evicted. The map's spare capacity is reclaimed if it has grown far
/// beyond the number of live sessions. The function then returns how long the caller may
/// sleep before the next session could expire:
/// - the front session's remaining lifetime, but never less than one second; or
/// - the full `ttl` when the map is empty.
fn cleanup_once(
    storage: &mut LinkedHashMap<String, (Instant, Vec<u8>)>,
    ttl: Duration,
) -> Option<Duration> {
    let sleep_for = match storage.front() {
        Some((_, &(instant, _))) => {
            let age = instant.elapsed();
            if age >= ttl {
                if let Some((key, _)) = storage.pop_front() {
                    trace!(" expired session {} and removed from MemoryBackend", key);
                }
                return None;
            }
            // `age < ttl` here, so the subtraction cannot underflow.
            ::std::cmp::max(ttl - age, Duration::from_secs(1))
        }
        // A session inserted from now on cannot expire sooner than `ttl` from now.
        None => ttl,
    };

    // Reclaim memory after a burst of sessions has gone away. This also runs when the map
    // is empty, which is exactly the state left behind after a mass expiry.
    let cap = storage.capacity();
    let len = storage.len();
    if cap >= 65536 && cap / 8 > len {
        storage.shrink_to_fit();
        trace!(
            " session backend had capacity {} and {} sessions, new capacity: {}",
            cap,
            len,
            storage.capacity()
        );
    }

    Some(sleep_for)
}
#[cfg(test)]
mod tests {
    use super::*;

    // A session older than the TTL is removed by a single cleanup pass.
    #[test]
    fn cleanup_test() {
        let mut storage = LinkedHashMap::new();
        storage.insert(
            "abcd".to_owned(),
            // Backdated by 2s against a 1s TTL, so it is already expired.
            (Instant::now() - Duration::from_secs(2), vec![]),
        );
        cleanup_once(&mut storage, Duration::from_secs(1));
        assert!(storage.is_empty());
    }

    // The cleanup thread terminates (so `join` returns) once the last strong reference
    // to the storage is dropped. The 1ms TTL keeps its sleeps short.
    #[test]
    fn cleanup_join_test() {
        let storage = Arc::new(Mutex::new(LinkedHashMap::new()));
        let weak = Arc::downgrade(&storage);
        let handle = thread::spawn(move || cleanup_loop(weak, Duration::from_millis(1)));
        drop(storage);
        handle.join().unwrap();
    }

    // Data persisted through one backend instance can be read back through another
    // instance created from the same `MemoryBackend`.
    #[test]
    fn memory_backend_test() {
        let new_backend = MemoryBackend::new(Duration::from_millis(100));
        let bytes: Vec<u8> = (0..64).map(|_| rand::random()).collect();
        let state = State::new();
        let identifier = SessionIdentifier {
            value: "totally_random_identifier".to_owned(),
        };
        futures_executor::block_on(
            new_backend
                .new_backend()
                .expect("can't create backend for write")
                .persist_session(&state, identifier.clone(), &bytes[..]),
        )
        .expect("failed to persist");
        let received = futures_executor::block_on(
            new_backend
                .new_backend()
                .expect("can't create backend for read")
                .read_session(&state, identifier),
        )
        .expect("no response from backend")
        .expect("session data missing");
        assert_eq!(bytes, received);
    }

    // Reading a session moves it from the least-recently-used end (front) of the
    // storage to the most-recently-used end (back).
    #[test]
    fn memory_backend_refresh_test() {
        let new_backend = MemoryBackend::new(Duration::from_millis(100));
        let bytes: Vec<u8> = (0..64).map(|_| rand::random()).collect();
        let state = State::new();
        let identifier = SessionIdentifier {
            value: "totally_random_identifier".to_owned(),
        };
        let bytes2: Vec<u8> = (0..64).map(|_| rand::random()).collect();
        let identifier2 = SessionIdentifier {
            value: "another_totally_random_identifier".to_owned(),
        };
        let backend = new_backend
            .new_backend()
            .expect("can't create backend for write");
        futures_executor::block_on(backend.persist_session(&state, identifier.clone(), &bytes[..]))
            .expect("failed to persist");
        futures_executor::block_on(backend.persist_session(
            &state,
            identifier2.clone(),
            &bytes2[..],
        ))
        .expect("failed to persist");
        // Insertion order: `identifier` first (front), `identifier2` last (back).
        {
            let storage = backend.storage.lock().expect("couldn't lock storage");
            assert_eq!(
                storage.front().expect("no front element").0,
                &identifier.value
            );
            assert_eq!(
                storage.back().expect("no back element").0,
                &identifier2.value
            );
        }
        futures_executor::block_on(backend.read_session(&state, identifier.clone()))
            .expect("failed to read session");
        // After reading `identifier`, the order is reversed.
        {
            let storage = backend.storage.lock().expect("couldn't lock storage");
            assert_eq!(
                storage.front().expect("no front element").0,
                &identifier2.value
            );
            assert_eq!(
                storage.back().expect("no back element").0,
                &identifier.value
            );
        }
    }
}