1use boring::ssl::SslSession;
8use std::collections::HashMap;
9use std::sync::{Arc, Mutex};
10use std::time::{Duration, Instant};
11use tokio::sync::Notify;
12
/// Lookup key for a cached TLS session: the peer's host name and port.
///
/// Build keys with [`SessionCacheKey::new`] so the host is normalised; a key
/// assembled directly from its public fields is stored and compared verbatim.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct SessionCacheKey {
    /// Host name component of the key.
    pub host: String,
    /// Port component of the key.
    pub port: u16,
}
19
20impl SessionCacheKey {
21 pub fn new(host: &str, port: u16) -> Self {
22 Self {
23 host: host.trim_end_matches('.').to_ascii_lowercase(),
24 port,
25 }
26 }
27}
28
/// One cache entry: the stored session bytes plus the data needed to expire it.
#[derive(Clone, Debug)]
struct CachedSession {
    /// Session bytes as supplied by the caller; `get_session` parses them with
    /// `SslSession::from_der` and `get_ticket` returns them as-is.
    der: Vec<u8>,
    /// Flag supplied at store time and reported back by `supports_zero_rtt`.
    early_data_capable: bool,
    /// How long after `received_at` the entry is still treated as live.
    max_age: Duration,
    /// Moment the entry was inserted into the cache.
    received_at: Instant,
}
36
37#[derive(Debug, Clone)]
39pub struct SessionCache {
40 inner: Arc<Mutex<SessionCacheInner>>,
41 session_stored: Arc<Notify>,
42}
43
/// Mutable state of a `SessionCache`; only ever accessed through its mutex.
#[derive(Debug)]
struct SessionCacheInner {
    /// Cached entries. Expired entries stay in the map until a lookup for
    /// their key, `cleanup_expired`, or `clear` removes them.
    sessions: HashMap<SessionCacheKey, CachedSession>,
    /// Lifetime given to an entry when `store_session` is called with
    /// `max_age: None`.
    default_max_age: Duration,
}
49
50impl SessionCache {
51 pub fn new() -> Self {
53 Self {
54 inner: Arc::new(Mutex::new(SessionCacheInner {
55 sessions: HashMap::new(),
56 default_max_age: Duration::from_secs(86400),
57 })),
58 session_stored: Arc::new(Notify::new()),
59 }
60 }
61
62 pub fn with_max_age(max_age: Duration) -> Self {
64 Self {
65 inner: Arc::new(Mutex::new(SessionCacheInner {
66 sessions: HashMap::new(),
67 default_max_age: max_age,
68 })),
69 session_stored: Arc::new(Notify::new()),
70 }
71 }
72
73 pub fn store_session(
75 &self,
76 key: SessionCacheKey,
77 der: Vec<u8>,
78 early_data_capable: bool,
79 max_age: Option<Duration>,
80 ) {
81 {
82 let mut inner = self.inner.lock().expect("Session cache mutex poisoned");
83 let max_age = max_age.unwrap_or(inner.default_max_age);
84 inner.sessions.insert(
85 key,
86 CachedSession {
87 der,
88 early_data_capable,
89 max_age,
90 received_at: Instant::now(),
91 },
92 );
93 }
94 self.session_stored.notify_waiters();
95 }
96
97 pub fn store_ticket(&self, host: &str, ticket_data: Vec<u8>, max_age: Option<Duration>) {
99 self.store_session(SessionCacheKey::new(host, 443), ticket_data, false, max_age);
100 }
101
102 pub fn get_session(&self, key: &SessionCacheKey) -> Option<SslSession> {
104 let mut inner = self.inner.lock().expect("Session cache mutex poisoned");
105 let entry = inner.sessions.get(key)?.clone();
106 if entry.received_at.elapsed() >= entry.max_age {
107 inner.sessions.remove(key);
108 return None;
109 }
110 SslSession::from_der(&entry.der).ok()
111 }
112
113 pub async fn wait_for_session(&self, key: &SessionCacheKey, timeout: Duration) -> bool {
115 tokio::time::timeout(timeout, async {
116 loop {
117 if self.has_session(key) {
118 return;
119 }
120 let notified = self.session_stored.notified();
121 tokio::pin!(notified);
122 notified.as_mut().enable();
123 if self.has_session(key) {
124 return;
125 }
126 notified.await;
127 }
128 })
129 .await
130 .is_ok()
131 }
132
133 fn has_session(&self, key: &SessionCacheKey) -> bool {
134 let mut inner = self.inner.lock().expect("Session cache mutex poisoned");
135 let Some(entry) = inner.sessions.get(key) else {
136 return false;
137 };
138 if entry.received_at.elapsed() >= entry.max_age {
139 inner.sessions.remove(key);
140 return false;
141 }
142 true
143 }
144
145 pub fn supports_zero_rtt(&self, key: &SessionCacheKey) -> bool {
147 let mut inner = self.inner.lock().expect("Session cache mutex poisoned");
148 let Some(entry) = inner.sessions.get(key) else {
149 return false;
150 };
151 if entry.received_at.elapsed() >= entry.max_age {
152 inner.sessions.remove(key);
153 return false;
154 }
155 entry.early_data_capable
156 }
157
158 pub fn get_ticket(&self, host: &str) -> Option<Vec<u8>> {
160 let key = SessionCacheKey::new(host, 443);
161 let mut inner = self.inner.lock().expect("Session cache mutex poisoned");
162 let entry = inner.sessions.get(&key)?.clone();
163 if entry.received_at.elapsed() >= entry.max_age {
164 inner.sessions.remove(&key);
165 return None;
166 }
167 Some(entry.der.clone())
168 }
169
170 pub fn clear(&self) {
172 let mut inner = self.inner.lock().expect("Session cache mutex poisoned");
173 inner.sessions.clear();
174 }
175
176 pub fn cleanup_expired(&self) {
178 let mut inner = self.inner.lock().expect("Session cache mutex poisoned");
179 inner
180 .sessions
181 .retain(|_, entry| entry.received_at.elapsed() < entry.max_age);
182 }
183
184 pub fn len(&self) -> usize {
186 let inner = self.inner.lock().expect("Session cache mutex poisoned");
187 inner.sessions.len()
188 }
189
190 pub fn is_empty(&self) -> bool {
192 self.len() == 0
193 }
194}
195
196impl Default for SessionCache {
197 fn default() -> Self {
198 Self::new()
199 }
200}
201
#[cfg(test)]
mod tests {
    use super::*;

    /// A stored session is visible through the host-only lookup, and a lookup
    /// for a different host misses.
    #[test]
    fn test_session_cache_store_and_retrieve() {
        let cache = SessionCache::new();
        cache.store_session(
            SessionCacheKey::new("example.com", 443),
            vec![1, 2, 3],
            false,
            None,
        );

        assert_eq!(
            cache
                .get_ticket("example.com")
                .expect("legacy lookup should work"),
            vec![1, 2, 3]
        );
        assert!(cache
            .get_session(&SessionCacheKey::new("other.com", 443))
            .is_none());
    }

    /// `clear` removes every stored entry.
    #[test]
    fn test_session_cache_clear() {
        let cache = SessionCache::new();
        cache.store_ticket("example.com", vec![1, 2, 3], None);
        cache.store_ticket("other.com", vec![4, 5, 6], None);

        assert_eq!(cache.len(), 2);
        cache.clear();
        assert_eq!(cache.len(), 0);
    }

    /// A session stored before the wait starts is observed without needing a
    /// notification.
    #[tokio::test]
    async fn wait_for_session_observes_preexisting_session() {
        let cache = SessionCache::new();
        let key = SessionCacheKey::new("example.com", 443);
        cache.store_session(key.clone(), vec![1, 2, 3], false, None);

        assert!(cache.wait_for_session(&key, Duration::from_millis(1)).await);
    }

    /// A waiter that is already parked is woken by `store_session` and can
    /// then take the cache lock itself.
    #[tokio::test]
    async fn store_session_notifies_after_releasing_cache_lock() {
        let cache = SessionCache::new();
        let key = SessionCacheKey::new("example.com", 443);
        let waiter = {
            let cache = cache.clone();
            let key = key.clone();
            tokio::spawn(async move {
                assert!(cache.wait_for_session(&key, Duration::from_secs(1)).await);
                let _guard = cache.inner.lock().expect("Session cache mutex poisoned");
            })
        };

        // `#[tokio::test]` runs on a current-thread runtime, so the spawned
        // task does not start until this task yields. Yield once so the waiter
        // is parked on the notifier before the store; otherwise it would just
        // find the session already present and the notification path would go
        // untested.
        tokio::task::yield_now().await;

        cache.store_session(key, vec![1, 2, 3], false, None);

        tokio::time::timeout(Duration::from_secs(1), waiter)
            .await
            .expect("waiter must not block on a notification sent while the cache lock is held")
            .expect("waiter task must not panic");
    }
}