Skip to main content

specter/transport/
session.rs

1//! TLS session resumption and caching for TCP-TLS connections.
2//!
3//! BoringSSL does not retain client sessions in its internal cache. Callers must
4//! install `SSL_CTX_sess_set_new_cb`, store tickets externally, and replay them
5//! with `SSL_set_session` on subsequent dials.
6
7use boring::ssl::SslSession;
8use std::collections::HashMap;
9use std::sync::{Arc, Mutex};
10use std::time::{Duration, Instant};
11use tokio::sync::Notify;
12
/// Cache key for a TLS session ticket.
///
/// Build keys with [`SessionCacheKey::new`] so that host spellings which differ
/// only in ASCII case or trailing dots (`Example.COM.` vs `example.com`) share
/// one cache entry.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct SessionCacheKey {
    /// Host name. When built via `new`, it is ASCII-lowercased with trailing
    /// dots removed.
    pub host: String,
    /// Destination port.
    pub port: u16,
}

impl SessionCacheKey {
    /// Build a key for `host:port`.
    ///
    /// Every trailing `.` is stripped from `host` (fully-qualified form) and
    /// the remainder is lowercased (ASCII only) before it is stored.
    pub fn new(host: &str, port: u16) -> Self {
        let normalized = host.trim_end_matches('.').to_ascii_lowercase();
        Self {
            host: normalized,
            port,
        }
    }
}
28
/// One stored session together with the metadata used to decide whether it
/// may still be handed out.
#[derive(Clone, Debug)]
struct CachedSession {
    /// Serialized session bytes as supplied by the caller; decoded with
    /// `SslSession::from_der` on lookup.
    der: Vec<u8>,
    /// Caller-supplied flag reporting TLS 1.3 early-data (0-RTT) support.
    early_data_capable: bool,
    /// How long after `received_at` the entry remains valid.
    max_age: Duration,
    /// Moment the entry was inserted into the cache.
    received_at: Instant,
}
36
37/// Host-keyed TLS session ticket cache shared across connector clones.
38#[derive(Debug, Clone)]
39pub struct SessionCache {
40    inner: Arc<Mutex<SessionCacheInner>>,
41    session_stored: Arc<Notify>,
42}
43
44#[derive(Debug)]
45struct SessionCacheInner {
46    sessions: HashMap<SessionCacheKey, CachedSession>,
47    default_max_age: Duration,
48}
49
50impl SessionCache {
51    /// Create a new session cache with default max age (24 hours).
52    pub fn new() -> Self {
53        Self {
54            inner: Arc::new(Mutex::new(SessionCacheInner {
55                sessions: HashMap::new(),
56                default_max_age: Duration::from_secs(86400),
57            })),
58            session_stored: Arc::new(Notify::new()),
59        }
60    }
61
62    /// Create a session cache with custom default max age.
63    pub fn with_max_age(max_age: Duration) -> Self {
64        Self {
65            inner: Arc::new(Mutex::new(SessionCacheInner {
66                sessions: HashMap::new(),
67                default_max_age: max_age,
68            })),
69            session_stored: Arc::new(Notify::new()),
70        }
71    }
72
73    /// Store a serialized TLS session for later resumption.
74    pub fn store_session(
75        &self,
76        key: SessionCacheKey,
77        der: Vec<u8>,
78        early_data_capable: bool,
79        max_age: Option<Duration>,
80    ) {
81        {
82            let mut inner = self.inner.lock().expect("Session cache mutex poisoned");
83            let max_age = max_age.unwrap_or(inner.default_max_age);
84            inner.sessions.insert(
85                key,
86                CachedSession {
87                    der,
88                    early_data_capable,
89                    max_age,
90                    received_at: Instant::now(),
91                },
92            );
93        }
94        self.session_stored.notify_waiters();
95    }
96
97    /// Legacy host-only store API retained for compatibility.
98    pub fn store_ticket(&self, host: &str, ticket_data: Vec<u8>, max_age: Option<Duration>) {
99        self.store_session(SessionCacheKey::new(host, 443), ticket_data, false, max_age);
100    }
101
102    /// Load a cached session if still valid.
103    pub fn get_session(&self, key: &SessionCacheKey) -> Option<SslSession> {
104        let mut inner = self.inner.lock().expect("Session cache mutex poisoned");
105        let entry = inner.sessions.get(key)?.clone();
106        if entry.received_at.elapsed() >= entry.max_age {
107            inner.sessions.remove(key);
108            return None;
109        }
110        SslSession::from_der(&entry.der).ok()
111    }
112
113    /// Wait until a session for `key` is stored or `timeout` elapses.
114    pub async fn wait_for_session(&self, key: &SessionCacheKey, timeout: Duration) -> bool {
115        tokio::time::timeout(timeout, async {
116            loop {
117                if self.has_session(key) {
118                    return;
119                }
120                let notified = self.session_stored.notified();
121                tokio::pin!(notified);
122                notified.as_mut().enable();
123                if self.has_session(key) {
124                    return;
125                }
126                notified.await;
127            }
128        })
129        .await
130        .is_ok()
131    }
132
133    fn has_session(&self, key: &SessionCacheKey) -> bool {
134        let mut inner = self.inner.lock().expect("Session cache mutex poisoned");
135        let Some(entry) = inner.sessions.get(key) else {
136            return false;
137        };
138        if entry.received_at.elapsed() >= entry.max_age {
139            inner.sessions.remove(key);
140            return false;
141        }
142        true
143    }
144
145    /// Whether a cached session advertises TLS 1.3 early-data support.
146    pub fn supports_zero_rtt(&self, key: &SessionCacheKey) -> bool {
147        let mut inner = self.inner.lock().expect("Session cache mutex poisoned");
148        let Some(entry) = inner.sessions.get(key) else {
149            return false;
150        };
151        if entry.received_at.elapsed() >= entry.max_age {
152            inner.sessions.remove(key);
153            return false;
154        }
155        entry.early_data_capable
156    }
157
158    /// Legacy host-only lookup API retained for compatibility.
159    pub fn get_ticket(&self, host: &str) -> Option<Vec<u8>> {
160        let key = SessionCacheKey::new(host, 443);
161        let mut inner = self.inner.lock().expect("Session cache mutex poisoned");
162        let entry = inner.sessions.get(&key)?.clone();
163        if entry.received_at.elapsed() >= entry.max_age {
164            inner.sessions.remove(&key);
165            return None;
166        }
167        Some(entry.der.clone())
168    }
169
170    /// Clear all cached sessions.
171    pub fn clear(&self) {
172        let mut inner = self.inner.lock().expect("Session cache mutex poisoned");
173        inner.sessions.clear();
174    }
175
176    /// Remove expired sessions.
177    pub fn cleanup_expired(&self) {
178        let mut inner = self.inner.lock().expect("Session cache mutex poisoned");
179        inner
180            .sessions
181            .retain(|_, entry| entry.received_at.elapsed() < entry.max_age);
182    }
183
184    /// Number of cached sessions.
185    pub fn len(&self) -> usize {
186        let inner = self.inner.lock().expect("Session cache mutex poisoned");
187        inner.sessions.len()
188    }
189
190    /// Whether the cache is empty.
191    pub fn is_empty(&self) -> bool {
192        self.len() == 0
193    }
194}
195
196impl Default for SessionCache {
197    fn default() -> Self {
198        Self::new()
199    }
200}
201
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_session_cache_store_and_retrieve() {
        let cache = SessionCache::new();
        cache.store_session(
            SessionCacheKey::new("example.com", 443),
            vec![1, 2, 3],
            false,
            None,
        );

        assert_eq!(
            cache
                .get_ticket("example.com")
                .expect("legacy lookup should work"),
            vec![1, 2, 3]
        );
        assert!(cache
            .get_session(&SessionCacheKey::new("other.com", 443))
            .is_none());
    }

    #[test]
    fn test_session_cache_clear() {
        let cache = SessionCache::new();
        cache.store_ticket("example.com", vec![1, 2, 3], None);
        cache.store_ticket("other.com", vec![4, 5, 6], None);

        assert_eq!(cache.len(), 2);
        cache.clear();
        assert_eq!(cache.len(), 0);
    }

    #[test]
    fn session_cache_key_normalizes_host() {
        assert_eq!(
            SessionCacheKey::new("Example.COM.", 443),
            SessionCacheKey::new("example.com", 443)
        );
        assert_ne!(
            SessionCacheKey::new("example.com", 443),
            SessionCacheKey::new("example.com", 8443)
        );
    }

    #[test]
    fn expired_sessions_are_reported_missing_and_evicted() {
        // A zero default lifetime means every entry is expired on first read.
        let cache = SessionCache::with_max_age(Duration::ZERO);
        let key = SessionCacheKey::new("example.com", 443);
        cache.store_session(key.clone(), vec![1, 2, 3], true, None);
        assert_eq!(cache.len(), 1, "len counts entries not yet evicted");

        assert!(!cache.supports_zero_rtt(&key));
        assert!(cache.is_empty(), "expired entry should be evicted on lookup");
        assert!(cache.get_ticket("example.com").is_none());
    }

    #[test]
    fn explicit_max_age_overrides_default() {
        let cache = SessionCache::new();
        cache.store_ticket("example.com", vec![1, 2, 3], Some(Duration::ZERO));

        assert!(cache.get_ticket("example.com").is_none());
        assert!(cache.is_empty());
    }

    #[test]
    fn cleanup_expired_keeps_live_sessions() {
        let cache = SessionCache::new();
        cache.store_ticket("live.com", vec![1], None);
        cache.store_ticket("stale.com", vec![2], Some(Duration::ZERO));

        cache.cleanup_expired();

        assert_eq!(cache.len(), 1);
        assert_eq!(cache.get_ticket("live.com"), Some(vec![1]));
        assert!(cache.get_ticket("stale.com").is_none());
    }

    #[test]
    fn supports_zero_rtt_reflects_stored_flag() {
        let cache = SessionCache::new();
        let capable = SessionCacheKey::new("a.com", 443);
        let plain = SessionCacheKey::new("b.com", 443);
        cache.store_session(capable.clone(), vec![1], true, None);
        cache.store_session(plain.clone(), vec![2], false, None);

        assert!(cache.supports_zero_rtt(&capable));
        assert!(!cache.supports_zero_rtt(&plain));
        assert!(!cache.supports_zero_rtt(&SessionCacheKey::new("c.com", 443)));
    }

    #[tokio::test]
    async fn wait_for_session_observes_preexisting_session() {
        let cache = SessionCache::new();
        let key = SessionCacheKey::new("example.com", 443);
        cache.store_session(key.clone(), vec![1, 2, 3], false, None);

        assert!(cache.wait_for_session(&key, Duration::from_millis(1)).await);
    }

    #[tokio::test]
    async fn wait_for_session_times_out_without_store() {
        let cache = SessionCache::new();
        let key = SessionCacheKey::new("example.com", 443);

        assert!(!cache.wait_for_session(&key, Duration::from_millis(10)).await);
    }

    #[tokio::test]
    async fn store_session_notifies_after_releasing_cache_lock() {
        let cache = SessionCache::new();
        let key = SessionCacheKey::new("example.com", 443);
        let waiter = {
            let cache = cache.clone();
            let key = key.clone();
            tokio::spawn(async move {
                assert!(cache.wait_for_session(&key, Duration::from_secs(1)).await);
                let _guard = cache.inner.lock().expect("Session cache mutex poisoned");
            })
        };

        cache.store_session(key, vec![1, 2, 3], false, None);

        tokio::time::timeout(Duration::from_secs(1), waiter)
            .await
            .expect("waiter must not block on a notification sent while the cache lock is held")
            .expect("waiter task must not panic");
    }
}