use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
/// Thread-safe cache of opaque session-ticket bytes, keyed by host name.
///
/// Cloning a `SessionCache` is cheap and yields a handle to the *same* underlying
/// storage (the state lives behind an `Arc<Mutex<_>>`), so clones observe each
/// other's inserts and removals.
///
/// Expired tickets are evicted lazily: [`SessionCache::get_ticket`] drops an expired
/// entry when it is looked up, and [`SessionCache::cleanup_expired`] purges all of them.
///
/// All methods panic if the internal mutex has been poisoned (a thread panicked
/// while holding the lock).
#[derive(Debug, Clone)]
pub struct SessionCache {
    inner: Arc<Mutex<SessionCacheInner>>,
}

/// Mutex-protected state shared by all clones of a [`SessionCache`].
#[derive(Debug)]
struct SessionCacheInner {
    /// Cached tickets, keyed by the `host` string passed to `store_ticket`.
    tickets: HashMap<String, SessionTicket>,
    /// Lifetime given to tickets stored without an explicit `max_age`.
    max_age: Duration,
}

/// A single cached ticket together with its own expiry policy.
#[derive(Debug, Clone)]
struct SessionTicket {
    /// Opaque ticket bytes; the cache never inspects them.
    data: Vec<u8>,
    /// When the ticket was inserted into the cache.
    received_at: Instant,
    /// How long after `received_at` the ticket stays usable.
    max_age: Duration,
}

impl SessionTicket {
    /// Returns `true` while the ticket is strictly younger than its `max_age`.
    ///
    /// Shared by lookup and bulk cleanup so both apply exactly the same expiry rule.
    fn is_fresh(&self) -> bool {
        self.received_at.elapsed() < self.max_age
    }
}

impl SessionCache {
    /// Default ticket lifetime used by [`SessionCache::new`]: 24 hours.
    const DEFAULT_MAX_AGE: Duration = Duration::from_secs(24 * 60 * 60);

    /// Creates an empty cache whose tickets expire after 24 hours by default.
    #[must_use]
    pub fn new() -> Self {
        Self::with_max_age(Self::DEFAULT_MAX_AGE)
    }

    /// Creates an empty cache whose tickets expire after `max_age` by default.
    ///
    /// Individual tickets can still override this through the `max_age` argument
    /// of [`SessionCache::store_ticket`].
    #[must_use]
    pub fn with_max_age(max_age: Duration) -> Self {
        Self {
            inner: Arc::new(Mutex::new(SessionCacheInner {
                tickets: HashMap::new(),
                max_age,
            })),
        }
    }

    /// Stores (or replaces) the ticket for `host`, stamping it with the current time.
    ///
    /// `max_age` overrides the cache-wide default lifetime for this ticket only;
    /// `None` uses the default.
    pub fn store_ticket(&self, host: &str, ticket_data: Vec<u8>, max_age: Option<Duration>) {
        let mut inner = self.lock_inner();
        let max_age = max_age.unwrap_or(inner.max_age);
        inner.tickets.insert(
            host.to_string(),
            SessionTicket {
                data: ticket_data,
                received_at: Instant::now(),
                max_age,
            },
        );
    }

    /// Returns a copy of the ticket for `host` if one exists and has not expired.
    ///
    /// An expired ticket found by this lookup is removed from the cache as a side
    /// effect.
    pub fn get_ticket(&self, host: &str) -> Option<Vec<u8>> {
        let mut inner = self.lock_inner();
        let ticket = inner.tickets.get(host)?;
        if ticket.is_fresh() {
            return Some(ticket.data.clone());
        }
        // Evict lazily so a stale ticket is not rechecked on every later lookup.
        inner.tickets.remove(host);
        None
    }

    /// Removes every cached ticket, expired or not.
    pub fn clear(&self) {
        self.lock_inner().tickets.clear();
    }

    /// Removes all tickets whose lifetime has elapsed.
    pub fn cleanup_expired(&self) {
        self.lock_inner().tickets.retain(|_, ticket| ticket.is_fresh());
    }

    /// Returns the number of stored tickets.
    ///
    /// This includes expired tickets that have not yet been evicted by
    /// [`SessionCache::get_ticket`] or [`SessionCache::cleanup_expired`].
    #[must_use]
    pub fn len(&self) -> usize {
        self.lock_inner().tickets.len()
    }

    /// Returns `true` if no tickets (expired or not) are stored.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Locks the shared state.
    ///
    /// # Panics
    ///
    /// Panics if a previous holder of the lock panicked (mutex poisoning).
    fn lock_inner(&self) -> std::sync::MutexGuard<'_, SessionCacheInner> {
        self.inner.lock().expect("Session cache mutex poisoned")
    }
}

impl Default for SessionCache {
    /// Equivalent to [`SessionCache::new`].
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_session_cache_store_and_retrieve() {
        let cache = SessionCache::new();
        cache.store_ticket("example.com", vec![1, 2, 3], None);
        assert_eq!(cache.get_ticket("example.com"), Some(vec![1, 2, 3]));
        assert_eq!(cache.get_ticket("other.com"), None);
    }

    #[test]
    fn test_session_cache_expiration() {
        // Sub-second lifetimes keep the suite fast; sleeping for twice the max age
        // leaves slack for scheduler jitter.
        let cache = SessionCache::with_max_age(Duration::from_millis(250));
        cache.store_ticket("example.com", vec![1, 2, 3], None);
        assert_eq!(cache.get_ticket("example.com"), Some(vec![1, 2, 3]));
        std::thread::sleep(Duration::from_millis(500));
        assert_eq!(cache.get_ticket("example.com"), None);
        // The expired lookup above must also have evicted the entry.
        assert_eq!(cache.len(), 0);
    }

    #[test]
    fn test_session_cache_per_ticket_max_age_overrides_default() {
        // A zero default makes every ticket stored with `None` expire immediately,
        // so no sleeping is needed.
        let cache = SessionCache::with_max_age(Duration::from_secs(0));
        cache.store_ticket("long.example", vec![7], Some(Duration::from_secs(3600)));
        cache.store_ticket("default.example", vec![8], None);
        assert_eq!(cache.get_ticket("long.example"), Some(vec![7]));
        assert_eq!(cache.get_ticket("default.example"), None);
    }

    #[test]
    fn test_session_cache_cleanup_expired() {
        let cache = SessionCache::with_max_age(Duration::from_secs(0));
        cache.store_ticket("stale.example", vec![1], None);
        cache.store_ticket("fresh.example", vec![2], Some(Duration::from_secs(3600)));
        // Expired entries are still counted until something evicts them.
        assert_eq!(cache.len(), 2);
        cache.cleanup_expired();
        assert_eq!(cache.len(), 1);
        assert_eq!(cache.get_ticket("fresh.example"), Some(vec![2]));
    }

    #[test]
    fn test_session_cache_clear() {
        let cache = SessionCache::new();
        cache.store_ticket("example.com", vec![1, 2, 3], None);
        cache.store_ticket("other.com", vec![4, 5, 6], None);
        assert_eq!(cache.len(), 2);
        cache.clear();
        assert_eq!(cache.len(), 0);
        assert!(cache.is_empty());
    }
}