use std::{
borrow::Borrow,
collections::{HashMap, hash_map::Entry},
hash::{Hash, Hasher},
num::NonZeroUsize,
sync::Arc,
};
use btls::ssl::{SslSession, SslVersion};
use lru::LruCache;
use crate::{conn::descriptor::ConnectionId, sync::Mutex, tls::TlsVersion};
/// Cache key under which TLS sessions are grouped.
///
/// Wraps a [`ConnectionId`]; every session stored with an equal `Key` shares one
/// per-key LRU inside [`LruTlsSessionCache`].
#[derive(Hash, Eq, PartialEq, Clone)]
pub struct Key(pub(super) ConnectionId);
/// A resumable TLS session, wrapping a btls [`SslSession`].
///
/// Identity (equality, hashing, `Borrow<[u8]>`) is defined by the session ID
/// rather than by the wrapped object.
#[derive(Clone)]
pub struct TlsSession(pub(super) SslSession);
/// Storage for TLS sessions that can be offered for resumption, keyed by [`Key`].
///
/// Implementations must be `Send + Sync`; both methods take `&self`, so any
/// mutation has to go through interior mutability.
pub trait TlsSessionCache: Send + Sync {
    /// Records `session` as a resumption candidate for connections identified by `key`.
    fn put(&self, key: Key, session: TlsSession);

    /// Returns a cached session for `key`, or `None` if there is none.
    ///
    /// Despite the name, an implementation may keep the returned session cached
    /// (the LRU implementation in this module only removes TLS 1.3 sessions).
    fn pop(&self, key: &Key) -> Option<TlsSession>;
}
// Invokes the crate's `impl_into_shared!` macro (defined outside this file) to
// declare the `IntoTlsSessionCache` trait for `TlsSessionCache`.
// NOTE(review): judging by the macro name, this generates a conversion into a
// shared cache handle (possibly `Arc`-based) — confirm against the macro definition.
impl_into_shared!(
pub trait IntoTlsSessionCache => TlsSessionCache
);
/// A [`TlsSessionCache`] that keeps a separate LRU of sessions for every [`Key`].
///
/// All state sits behind a single mutex.
pub struct LruTlsSessionCache {
    /// Mutable cache state: the per-key LRUs plus the session -> key index.
    inner: Mutex<Inner>,
    /// Capacity used when a key's LRU is first created.
    per_host_session_capacity: usize,
}
/// Lock-protected state of [`LruTlsSessionCache`].
struct Inner {
    /// Maps each cached session back to the key it was stored under.
    reverse: HashMap<TlsSession, Key>,
    /// One LRU per key. The `()` values are unused: each LRU acts as an
    /// ordered set of sessions.
    per_host_sessions: HashMap<Key, LruCache<TlsSession, ()>>,
}
impl TlsSession {
#[inline]
pub fn id(&self) -> &[u8] {
self.0.id()
}
#[inline]
pub fn time(&self) -> u64 {
self.0.time()
}
#[inline]
pub fn timeout(&self) -> u32 {
self.0.timeout()
}
#[inline]
pub fn protocol_version(&self) -> TlsVersion {
let version = self.0.protocol_version();
if version == SslVersion::SSL3 {
unreachable!(
"Encountered unsupported protocol: SSLv3 (SSL 3.0) is obsolete and not accepted by btls"
);
}
TlsVersion(version)
}
}
impl Eq for TlsSession {}
impl PartialEq for TlsSession {
#[inline]
fn eq(&self, other: &TlsSession) -> bool {
self.0.id() == other.0.id()
}
}
impl Hash for TlsSession {
#[inline]
fn hash<H: Hasher>(&self, state: &mut H) {
self.0.id().hash(state);
}
}
/// Exposes the session ID as the borrowed form of a `TlsSession`.
///
/// Collections keyed by `TlsSession` can therefore be queried with `&[u8]`.
/// This is sound because `Eq`/`Hash` are defined on the same ID bytes.
impl Borrow<[u8]> for TlsSession {
    #[inline]
    fn borrow(&self) -> &[u8] {
        self.id()
    }
}
impl LruTlsSessionCache {
    /// Creates an empty cache.
    ///
    /// `per_host_session_capacity` is the limit applied to each [`Key`]'s LRU when it
    /// is first created. A value of zero is special-cased at that point.
    pub fn new(per_host_session_capacity: usize) -> Self {
        let state = Inner {
            per_host_sessions: HashMap::new(),
            reverse: HashMap::new(),
        };
        Self {
            per_host_session_capacity,
            inner: Mutex::new(state),
        }
    }
}
impl TlsSessionCache for LruTlsSessionCache {
    /// Stores `session` as the most recently used entry for `key`.
    ///
    /// Each key's [`LruCache`] is created on first use:
    /// - with a non-zero `per_host_session_capacity`, it is bounded by that capacity;
    /// - with zero, it is unbounded.
    ///
    /// Eviction is delegated to [`LruCache::push`], which returns whichever entry it
    /// displaced. That is either the least recently used session (when the LRU is full)
    /// or the previous entry with the same session ID. In the second case the entry is
    /// refreshed instead of evicting an unrelated session.
    fn put(&self, key: Key, session: TlsSession) {
        let mut guard = self.inner.lock();
        // Reborrow as `&mut Inner` so the two maps can be borrowed independently.
        let inner = &mut *guard;

        let per_host_sessions = inner
            .per_host_sessions
            .entry(key.clone())
            .or_insert_with(|| {
                NonZeroUsize::new(self.per_host_session_capacity)
                    .map_or_else(LruCache::unbounded, LruCache::new)
            });

        if let Some((displaced, ())) = per_host_sessions.push(session.clone(), ()) {
            // Only forget the reverse mapping if it still refers to this key.
            if inner.reverse.get(&displaced) == Some(&key) {
                inner.reverse.remove(&displaced);
            }
        }
        inner.reverse.insert(session, key);
    }

    /// Returns the least recently used session cached for `key`, without changing
    /// its recency, or `None` if `key` has no sessions.
    ///
    /// A TLS 1.3 session is removed from `key`'s LRU before it is returned. Clients
    /// should not reuse a TLS 1.3 ticket (RFC 8446, Appendix C.4), and removing it
    /// here prevents concurrent handshakes from being handed the same session. The
    /// per-key LRU is dropped once it becomes empty. Sessions of any other protocol
    /// version stay cached and may be returned again.
    fn pop(&self, key: &Key) -> Option<TlsSession> {
        let mut guard = self.inner.lock();
        // Reborrow as `&mut Inner` so the two maps can be borrowed independently.
        let inner = &mut *guard;

        let session = inner.per_host_sessions.get(key)?.peek_lru()?.0.clone();

        if session.protocol_version() == TlsVersion::TLS_1_3 {
            // Remove from the LRU the session was actually found in, rather than from
            // whichever key the reverse index happens to point at.
            if let Entry::Occupied(mut per_host) = inner.per_host_sessions.entry(key.clone()) {
                per_host.get_mut().pop(&session);
                if per_host.get().is_empty() {
                    per_host.remove();
                }
            }
            if inner.reverse.get(&session) == Some(key) {
                inner.reverse.remove(&session);
            }
        }

        Some(session)
    }
}