Skip to main content

network_protocol/transport/
session_cache.rs

1//! # TLS Session Cache
2//!
3//! This module provides in-memory TLS session ticket caching for session resumption.
4//! Session resumption reduces the handshake overhead by ~50-70%, allowing clients to
5//! reconnect without full TLS 1.3 handshakes.
6//!
7//! ## Features
8//! - **Thread-safe**: Uses Arc<Mutex<>> for safe concurrent access
9//! - **TTL-based expiration**: Sessions expire after configurable duration
10//! - **Memory-bounded**: Configurable maximum entries to prevent unbounded growth
//! - **Async lookups**: Retrieval awaits a shared async mutex rather than blocking the thread; lookups and stores are serialized by that single lock
12//!
13//! ## Performance
14//! - Lookup: ~100-200ns (HashMap)
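//! - Insertion: ~500-1000ns (with lock contention) for the map insert itself; every `store` also sweeps expired entries, which is O(n)
//! - Eviction: O(n) for the expired-entry sweep and for oldest-entry eviction at capacity; normal lookups are typically O(1)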
17//!
18//! ## Usage
19//! ```ignore
20//! use network_protocol::transport::session_cache::SessionCache;
21//! use std::time::Duration;
22//!
23//! // Create a cache with 1000 max entries and 1-hour TTL
24//! let cache = SessionCache::new(1000, Duration::from_secs(3600));
25//!
26//! // Store a session (typically handled internally by TLS layer)
27//! cache.store(session_id.clone(), ticket.clone()).await;
28//!
29//! // Retrieve for resumption (returned as Arc<Vec<u8>>)
30//! if let Some(ticket) = cache.get(&session_id).await {
31//!     // Use ticket for session resumption
32//! }
33//! ```
34
35use std::collections::HashMap;
36use std::sync::Arc;
37use std::time::{Duration, SystemTime};
38
39use tokio::sync::Mutex;
40use tracing::{debug, trace};
41
/// One cached TLS session ticket together with the bookkeeping needed to expire it.
#[derive(Debug, Clone)]
struct SessionEntry {
    /// Serialized session ticket (as produced by rustls); reference-counted so that
    /// cloning an entry or handing the ticket to a caller does not copy the bytes
    ticket: Arc<Vec<u8>>,
    /// Wall-clock time at which the entry was created
    created_at: SystemTime,
    /// How long after `created_at` the entry stays usable
    ttl: Duration,
}
52
53impl SessionEntry {
54    /// Check if this entry has expired
55    fn is_expired(&self) -> bool {
56        match self.created_at.elapsed() {
57            Ok(elapsed) => elapsed > self.ttl,
58            Err(_) => true, // System time went backward, treat as expired
59        }
60    }
61}
62
63/// Thread-safe in-memory TLS session cache
64///
65/// Stores session tickets for TLS 1.3 resumption. This enables clients to reconnect
66/// without performing full handshakes, reducing latency by 50-70%.
67#[derive(Clone)]
68pub struct SessionCache {
69    /// Maximum number of sessions to cache
70    max_entries: usize,
71    /// Default TTL for new sessions
72    default_ttl: Duration,
73    /// Inner cache protected by mutex
74    inner: Arc<Mutex<SessionCacheInner>>,
75}
76
77struct SessionCacheInner {
78    /// Session ID -> cached ticket
79    sessions: HashMap<String, SessionEntry>,
80    /// Metadata for eviction policies
81    total_inserts: u64,
82}
83
84impl SessionCache {
85    /// Create a new session cache
86    ///
87    /// # Arguments
88    /// * `max_entries` - Maximum number of sessions to cache (e.g., 1000)
89    /// * `default_ttl` - Default time-to-live for each session (e.g., 1 hour)
90    ///
91    /// # Example
92    /// ```ignore
93    /// let cache = SessionCache::new(1000, Duration::from_secs(3600));
94    /// ```
95    pub fn new(max_entries: usize, default_ttl: Duration) -> Self {
96        Self {
97            max_entries,
98            default_ttl,
99            inner: Arc::new(Mutex::new(SessionCacheInner {
100                sessions: HashMap::with_capacity(max_entries),
101                total_inserts: 0,
102            })),
103        }
104    }
105
106    /// Store a session ticket in the cache
107    ///
108    /// This is typically called by the TLS layer after establishing a connection.
109    /// Automatically manages eviction when cache is full.
110    ///
111    /// # Arguments
112    /// * `session_id` - Unique session identifier
113    /// * `ticket` - Serialized TLS session ticket
114    pub async fn store<S: Into<String>>(&self, session_id: S, ticket: Vec<u8>) {
115        let mut inner = self.inner.lock().await;
116
117        let session_id = session_id.into();
118        let entry = SessionEntry {
119            ticket: Arc::new(ticket),
120            created_at: SystemTime::now(),
121            ttl: self.default_ttl,
122        };
123
124        // Clean expired entries before checking capacity
125        self.evict_expired(&mut inner);
126
127        // Store the session
128        inner.sessions.insert(session_id.clone(), entry);
129        inner.total_inserts += 1;
130
131        // Evict oldest if we exceed capacity
132        if inner.sessions.len() > self.max_entries {
133            self.evict_oldest(&mut inner);
134        }
135
136        trace!(
137            session_count = inner.sessions.len(),
138            "Session stored in cache"
139        );
140    }
141
142    /// Retrieve a session ticket from the cache
143    ///
144    /// Returns the ticket if found and not expired, None otherwise.
145    ///
146    /// # Example
147    /// ```ignore
148    /// if let Some(ticket) = cache.get("session-123").await {
149    ///     // Use ticket for resumption
150    /// }
151    /// ```
152    pub async fn get(&self, session_id: &str) -> Option<Arc<Vec<u8>>> {
153        let mut inner = self.inner.lock().await;
154
155        // Check if session exists and is not expired
156        if let Some(entry) = inner.sessions.get(session_id) {
157            if !entry.is_expired() {
158                trace!("Session cache hit");
159                return Some(entry.ticket.clone());
160            }
161        }
162
163        // Remove expired session
164        inner.sessions.remove(session_id);
165        trace!("Session cache miss or expired");
166        None
167    }
168
169    /// Clear all sessions from the cache
170    pub async fn clear(&self) {
171        let mut inner = self.inner.lock().await;
172        let count = inner.sessions.len();
173        inner.sessions.clear();
174        debug!(cleared_count = count, "Session cache cleared");
175    }
176
177    /// Get current cache statistics
178    pub async fn stats(&self) -> SessionCacheStats {
179        let inner = self.inner.lock().await;
180
181        let expired_count = inner.sessions.values().filter(|e| e.is_expired()).count();
182
183        SessionCacheStats {
184            total_entries: inner.sessions.len(),
185            max_entries: self.max_entries,
186            expired_count,
187            total_inserts: inner.total_inserts,
188        }
189    }
190
191    /// Evict all expired entries from the cache
192    #[allow(dead_code)]
193    async fn evict_expired_async(&self) {
194        let mut inner = self.inner.lock().await;
195        self.evict_expired(&mut inner);
196    }
197
198    /// Internal: Evict expired entries (called with lock held)
199    fn evict_expired(&self, inner: &mut SessionCacheInner) {
200        let before = inner.sessions.len();
201        inner.sessions.retain(|_, entry| !entry.is_expired());
202        let after = inner.sessions.len();
203
204        if before != after {
205            debug!(
206                removed_count = before - after,
207                remaining_count = after,
208                "Expired sessions evicted"
209            );
210        }
211    }
212
213    /// Internal: Evict oldest entry (called with lock held)
214    fn evict_oldest(&self, inner: &mut SessionCacheInner) {
215        if let Some(oldest_key) = inner
216            .sessions
217            .iter()
218            .min_by_key(|(_, entry)| entry.created_at)
219            .map(|(k, _)| k.clone())
220        {
221            inner.sessions.remove(&oldest_key);
222            debug!("Oldest session evicted to make room");
223        }
224    }
225}
226
/// Point-in-time statistics about the session cache
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SessionCacheStats {
    /// Number of sessions currently stored, including entries that have expired but
    /// have not been evicted yet (see `expired_count`)
    pub total_entries: usize,
    /// Maximum capacity
    pub max_entries: usize,
    /// Number of expired but not yet evicted entries
    pub expired_count: usize,
    /// Total number of `store` calls made against the cache
    pub total_inserts: u64,
}
239
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an entry stamped at the Unix epoch with a one-second TTL, i.e. one that is
    /// long expired by the time any test runs.
    fn stale_entry() -> SessionEntry {
        SessionEntry {
            ticket: Arc::new(vec![0xAA]),
            created_at: SystemTime::UNIX_EPOCH,
            ttl: Duration::from_secs(1),
        }
    }

    #[tokio::test]
    async fn test_store_and_retrieve() {
        let cache = SessionCache::new(10, Duration::from_secs(60));

        cache.store("session-1", vec![1, 2, 3, 4]).await;
        let ticket = cache.get("session-1").await;

        // Compare through the `Option` so a miss fails the assertion instead of panicking.
        assert_eq!(ticket.as_deref(), Some(&vec![1u8, 2, 3, 4]));
    }

    #[tokio::test]
    async fn test_missing_session() {
        let cache = SessionCache::new(10, Duration::from_secs(60));
        let ticket = cache.get("nonexistent").await;
        assert!(ticket.is_none());
    }

    #[tokio::test]
    async fn test_capacity_eviction() {
        let cache = SessionCache::new(3, Duration::from_secs(60));

        for i in 0u8..5 {
            cache.store(format!("session-{i}"), vec![i]).await;
        }

        let stats = cache.stats().await;
        assert_eq!(stats.total_entries, 3);
        assert_eq!(stats.total_inserts, 5);
    }

    #[tokio::test]
    async fn test_clear() {
        let cache = SessionCache::new(10, Duration::from_secs(60));

        cache.store("session-1", vec![1, 2, 3]).await;
        cache.clear().await;

        let ticket = cache.get("session-1").await;
        assert!(ticket.is_none());
    }

    #[test]
    fn test_entry_expiry() {
        assert!(stale_entry().is_expired());

        let fresh = SessionEntry {
            ticket: Arc::new(Vec::new()),
            created_at: SystemTime::now(),
            ttl: Duration::from_secs(3600),
        };
        assert!(!fresh.is_expired());
    }

    #[tokio::test]
    async fn test_expired_session_removed_on_get() {
        let cache = SessionCache::new(10, Duration::from_secs(60));

        // Plant an already-expired entry directly; the public API always stamps "now".
        cache
            .inner
            .lock()
            .await
            .sessions
            .insert("stale".to_string(), stale_entry());

        assert!(cache.get("stale").await.is_none());
        assert_eq!(cache.stats().await.total_entries, 0);
    }

    #[tokio::test]
    async fn test_overwrite_keeps_other_sessions() {
        let cache = SessionCache::new(2, Duration::from_secs(60));

        cache.store("a", vec![1]).await;
        cache.store("b", vec![2]).await;
        // Re-storing an existing ID at capacity must replace it, not evict a neighbour.
        cache.store("a", vec![3]).await;

        let a = cache.get("a").await;
        let b = cache.get("b").await;
        assert_eq!(a.as_deref(), Some(&vec![3u8]));
        assert_eq!(b.as_deref(), Some(&vec![2u8]));
        assert_eq!(cache.stats().await.total_entries, 2);
    }
}