matrix_http_rendezvous/
sessions.rs1use std::{
16 collections::BTreeMap,
17 future::Future,
18 sync::Arc,
19 time::{Duration, SystemTime},
20};
21
22use axum::TypedHeader;
23use base64ct::Encoding;
24use bytes::Bytes;
25use headers::{ContentType, ETag, Expires, LastModified};
26use mime::Mime;
27use sha2::Digest;
28use tokio::sync::{Mutex, RwLock, RwLockMappedWriteGuard, RwLockReadGuard, RwLockWriteGuard};
29use ulid::Ulid;
30
31pub struct Session {
32 hash: [u8; 32],
33 data: Bytes,
34 content_type: Mime,
35 last_modified: SystemTime,
36 expires: SystemTime,
37}
38
39impl Session {
40 fn new(data: Bytes, content_type: Mime, ttl: Duration) -> Self {
41 let hash = sha2::Sha256::digest(&data).into();
42 let now = SystemTime::now();
43 Self {
44 hash,
45 data,
46 content_type,
47 expires: now + ttl,
48 last_modified: now,
49 }
50 }
51
52 pub fn update(&mut self, data: Bytes, content_type: Mime) {
53 self.hash = sha2::Sha256::digest(&data).into();
54 self.data = data;
55 self.content_type = content_type;
56 self.last_modified = SystemTime::now();
57 }
58
59 pub fn content_type(&self) -> ContentType {
60 self.content_type.clone().into()
61 }
62
63 pub fn etag(&self) -> ETag {
64 let encoded = base64ct::Base64Url::encode_string(&self.hash);
65 format!("\"{encoded}\"")
67 .parse()
68 .expect("base64-encoded hash should be URL-safe")
69 }
70
71 pub fn data(&self) -> Bytes {
72 self.data.clone()
73 }
74
75 fn last_modified(&self) -> LastModified {
76 self.last_modified.into()
77 }
78
79 fn expires(&self) -> Expires {
80 self.expires.into()
81 }
82
83 pub fn typed_headers(
84 &self,
85 ) -> (
86 TypedHeader<ETag>,
87 TypedHeader<Expires>,
88 TypedHeader<LastModified>,
89 ) {
90 (
91 TypedHeader(self.etag()),
92 TypedHeader(self.expires()),
93 TypedHeader(self.last_modified()),
94 )
95 }
96}
97
98#[derive(Clone, Default)]
99pub struct Sessions {
100 inner: Arc<RwLock<BTreeMap<Ulid, Session>>>,
101 generator: Arc<Mutex<ulid::Generator>>,
102 capacity: usize,
103 hard_capacity: usize,
104 ttl: Duration,
105}
106
/// Removes the entries with the smallest keys until at most `capacity` remain.
///
/// With `Ulid` keys, the smallest keys are the earliest-generated IDs, because ULIDs
/// sort by their timestamp prefix. This therefore drops the oldest sessions first.
///
/// Does nothing if the map already holds `capacity` entries or fewer.
fn evict<K: Copy + Ord, V>(sessions: &mut BTreeMap<K, V>, capacity: usize) {
    // `saturating_sub` instead of `-`: a caller may check the length under one lock
    // acquisition and evict under another. By then the map can already be below
    // `capacity`. A plain subtraction would then panic in debug builds. In release
    // builds it would wrap and remove every entry.
    let excess = sessions.len().saturating_sub(capacity);
    if excess == 0 {
        return;
    }

    // Collect first: the map cannot be mutated while its keys are being iterated.
    let keys: Vec<K> = sessions.keys().take(excess).copied().collect();
    for key in keys {
        sessions.remove(&key);
    }
}
124
125impl Sessions {
126 #[must_use]
128 pub fn new(ttl: Duration, capacity: usize) -> Self {
129 Self {
130 inner: Arc::new(RwLock::new(BTreeMap::new())),
131 generator: Arc::new(Mutex::new(ulid::Generator::new())),
132 ttl,
133 capacity,
134 hard_capacity: capacity * 2,
135 }
136 }
137
138 pub async fn new_session(
140 &self,
141 payload: Bytes,
142 content_type: Mime,
143 ) -> (Ulid, RwLockReadGuard<Session>) {
144 let id = self.generate_id().await;
145 let session = Session::new(payload, content_type, self.ttl);
146 let session = self.insert(id, session, self.ttl).await;
147 (id, session)
148 }
149
150 pub async fn get_session(&self, id: Ulid) -> Option<RwLockReadGuard<Session>> {
152 let sessions = self.inner.read().await;
153 RwLockReadGuard::try_map(sessions, |sessions| sessions.get(&id)).ok()
154 }
155
156 pub async fn get_session_mut(&self, id: Ulid) -> Option<RwLockMappedWriteGuard<Session>> {
158 let sessions = self.inner.write().await;
159 RwLockWriteGuard::try_map(sessions, |sessions| sessions.get_mut(&id)).ok()
160 }
161
162 pub async fn delete_session(&self, id: Ulid) -> bool {
164 self.inner.write().await.remove(&id).is_some()
165 }
166
167 async fn insert(&self, id: Ulid, session: Session, ttl: Duration) -> RwLockReadGuard<Session> {
168 let mut sessions = self.inner.write().await;
169
170 if sessions.len() + 1 >= self.hard_capacity {
173 evict(&mut sessions, self.capacity);
174 }
175
176 sessions.insert(id, session);
177
178 let session = RwLockReadGuard::map(sessions.downgrade(), |sessions| {
181 sessions
182 .get(&id)
183 .expect("Session should be in the map just after insertion")
184 });
185
186 let this = self.clone();
187 tokio::task::spawn(async move {
189 tokio::time::sleep(ttl).await;
190 this.delete_session(id).await;
191 });
192
193 session
194 }
195
196 async fn generate_id(&self) -> Ulid {
197 self.generator
198 .lock()
199 .await
200 .generate()
201 .expect("Failed to generate random ID")
204 }
205
206 pub fn eviction_task(
208 &self,
209 interval: Duration,
210 ) -> impl Future<Output = ()> + Send + Sync + 'static {
211 let this = self.clone();
212 async move {
213 let mut interval = tokio::time::interval(interval);
214 interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
215
216 loop {
217 interval.tick().await;
218 this.evict().await;
219 }
220 }
221 }
222
223 pub(crate) async fn evict(&self) {
226 if self.inner.read().await.len() > self.capacity {
227 let mut sessions = self.inner.write().await;
228 evict(&mut sessions, self.capacity);
229 }
230 }
231
232 pub async fn fill_for_mem_check(&self, entry_size: usize) {
239 let mut sessions = self.inner.write().await;
240 let mut generator = self.generator.lock().await;
241 assert!(sessions.is_empty());
242
243 let data: Vec<u8> = std::iter::repeat(42).take(entry_size).collect();
244 sessions.extend((0..self.capacity).map(|_| {
245 let data = Bytes::from(data.clone());
246 let id = generator.generate().unwrap();
247 let session = Session::new(data, mime::APPLICATION_OCTET_STREAM, self.ttl);
248 (id, session)
249 }));
250
251 let ttl = self.ttl;
253 for &key in sessions.keys() {
254 let inner = self.inner.clone();
255 tokio::task::spawn(async move {
256 tokio::time::sleep(ttl).await;
257 inner.write().await.remove(&key);
258 });
259 }
260 }
261}