// network_protocol/transport/session_cache.rs

use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, SystemTime};

use tokio::sync::Mutex;
use tracing::{debug, trace};

/// One cached session ticket together with the bookkeeping needed to expire it.
#[derive(Debug, Clone)]
struct SessionEntry {
    /// Ticket bytes, reference-counted so lookups can hand them out without copying.
    ticket: Arc<Vec<u8>>,
    /// Wall-clock time at which the entry was stored.
    ///
    /// NOTE(review): `SystemTime` is not monotonic, so clock adjustments can shorten or
    /// lengthen an entry's effective lifetime; `std::time::Instant` would avoid that.
    created_at: SystemTime,
    /// How long after `created_at` the entry remains valid.
    ttl: Duration,
}

53impl SessionEntry {
54 fn is_expired(&self) -> bool {
56 match self.created_at.elapsed() {
57 Ok(elapsed) => elapsed > self.ttl,
58 Err(_) => true, }
60 }
61}
62
63#[derive(Clone)]
68pub struct SessionCache {
69 max_entries: usize,
71 default_ttl: Duration,
73 inner: Arc<Mutex<SessionCacheInner>>,
75}
76
77struct SessionCacheInner {
78 sessions: HashMap<String, SessionEntry>,
80 total_inserts: u64,
82}
83
84impl SessionCache {
85 pub fn new(max_entries: usize, default_ttl: Duration) -> Self {
96 Self {
97 max_entries,
98 default_ttl,
99 inner: Arc::new(Mutex::new(SessionCacheInner {
100 sessions: HashMap::with_capacity(max_entries),
101 total_inserts: 0,
102 })),
103 }
104 }
105
106 pub async fn store<S: Into<String>>(&self, session_id: S, ticket: Vec<u8>) {
115 let mut inner = self.inner.lock().await;
116
117 let session_id = session_id.into();
118 let entry = SessionEntry {
119 ticket: Arc::new(ticket),
120 created_at: SystemTime::now(),
121 ttl: self.default_ttl,
122 };
123
124 self.evict_expired(&mut inner);
126
127 inner.sessions.insert(session_id.clone(), entry);
129 inner.total_inserts += 1;
130
131 if inner.sessions.len() > self.max_entries {
133 self.evict_oldest(&mut inner);
134 }
135
136 trace!(
137 session_count = inner.sessions.len(),
138 "Session stored in cache"
139 );
140 }
141
142 pub async fn get(&self, session_id: &str) -> Option<Arc<Vec<u8>>> {
153 let mut inner = self.inner.lock().await;
154
155 if let Some(entry) = inner.sessions.get(session_id) {
157 if !entry.is_expired() {
158 trace!("Session cache hit");
159 return Some(entry.ticket.clone());
160 }
161 }
162
163 inner.sessions.remove(session_id);
165 trace!("Session cache miss or expired");
166 None
167 }
168
169 pub async fn clear(&self) {
171 let mut inner = self.inner.lock().await;
172 let count = inner.sessions.len();
173 inner.sessions.clear();
174 debug!(cleared_count = count, "Session cache cleared");
175 }
176
177 pub async fn stats(&self) -> SessionCacheStats {
179 let inner = self.inner.lock().await;
180
181 let expired_count = inner.sessions.values().filter(|e| e.is_expired()).count();
182
183 SessionCacheStats {
184 total_entries: inner.sessions.len(),
185 max_entries: self.max_entries,
186 expired_count,
187 total_inserts: inner.total_inserts,
188 }
189 }
190
191 #[allow(dead_code)]
193 async fn evict_expired_async(&self) {
194 let mut inner = self.inner.lock().await;
195 self.evict_expired(&mut inner);
196 }
197
198 fn evict_expired(&self, inner: &mut SessionCacheInner) {
200 let before = inner.sessions.len();
201 inner.sessions.retain(|_, entry| !entry.is_expired());
202 let after = inner.sessions.len();
203
204 if before != after {
205 debug!(
206 removed_count = before - after,
207 remaining_count = after,
208 "Expired sessions evicted"
209 );
210 }
211 }
212
213 fn evict_oldest(&self, inner: &mut SessionCacheInner) {
215 if let Some(oldest_key) = inner
216 .sessions
217 .iter()
218 .min_by_key(|(_, entry)| entry.created_at)
219 .map(|(k, _)| k.clone())
220 {
221 inner.sessions.remove(&oldest_key);
222 debug!("Oldest session evicted to make room");
223 }
224 }
225}
226
/// Point-in-time snapshot of a session cache, as returned by `SessionCache::stats`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SessionCacheStats {
    /// Entries currently held, including expired ones that have not been swept yet.
    pub total_entries: usize,
    /// Configured capacity limit of the cache.
    pub max_entries: usize,
    /// Entries past their TTL that are still present in the map.
    pub expired_count: usize,
    /// Total number of `store` calls since the cache was created.
    pub total_inserts: u64,
}

#[cfg(test)]
mod tests {
    use super::*;

    /// TTL long enough that nothing expires during a test run.
    const TTL: Duration = Duration::from_secs(60);

    #[tokio::test]
    async fn test_store_and_retrieve() {
        let cache = SessionCache::new(10, TTL);

        cache.store("session-1", vec![1, 2, 3, 4]).await;
        let ticket = cache.get("session-1").await;

        assert_eq!(ticket.as_deref().map(Vec::as_slice), Some(&[1, 2, 3, 4][..]));
    }

    #[tokio::test]
    async fn test_missing_session() {
        let cache = SessionCache::new(10, TTL);
        let ticket = cache.get("nonexistent").await;
        assert!(ticket.is_none());
    }

    #[tokio::test]
    async fn test_capacity_eviction() {
        let cache = SessionCache::new(3, TTL);

        for i in 0..5u8 {
            cache.store(format!("session-{i}"), vec![i]).await;
        }

        let stats = cache.stats().await;
        assert_eq!(stats.total_entries, 3);
        assert_eq!(stats.total_inserts, 5);
        // The entry stored last must never be the one evicted to make room.
        assert!(cache.get("session-4").await.is_some());
    }

    #[tokio::test]
    async fn test_zero_capacity_retains_nothing() {
        let cache = SessionCache::new(0, TTL);

        cache.store("session-1", vec![1]).await;

        assert!(cache.get("session-1").await.is_none());
        let stats = cache.stats().await;
        assert_eq!(stats.total_entries, 0);
        assert_eq!(stats.total_inserts, 1);
    }

    #[tokio::test]
    async fn test_overwrite_replaces_ticket_without_eviction() {
        let cache = SessionCache::new(2, TTL);

        cache.store("session-1", vec![1]).await;
        cache.store("session-2", vec![2]).await;
        cache.store("session-1", vec![9]).await;

        let ticket = cache.get("session-1").await;
        assert_eq!(ticket.as_deref().map(Vec::as_slice), Some(&[9][..]));
        assert!(cache.get("session-2").await.is_some());
        assert_eq!(cache.stats().await.total_entries, 2);
    }

    #[tokio::test]
    async fn test_clear() {
        let cache = SessionCache::new(10, TTL);

        cache.store("session-1", vec![1, 2, 3]).await;
        cache.clear().await;

        let ticket = cache.get("session-1").await;
        assert!(ticket.is_none());
    }
}