matrix_http_rendezvous/
sessions.rs

1// Copyright 2022 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::{
16    collections::BTreeMap,
17    future::Future,
18    sync::Arc,
19    time::{Duration, SystemTime},
20};
21
22use axum::TypedHeader;
23use base64ct::Encoding;
24use bytes::Bytes;
25use headers::{ContentType, ETag, Expires, LastModified};
26use mime::Mime;
27use sha2::Digest;
28use tokio::sync::{Mutex, RwLock, RwLockMappedWriteGuard, RwLockReadGuard, RwLockWriteGuard};
29use ulid::Ulid;
30
31pub struct Session {
32    hash: [u8; 32],
33    data: Bytes,
34    content_type: Mime,
35    last_modified: SystemTime,
36    expires: SystemTime,
37}
38
39impl Session {
40    fn new(data: Bytes, content_type: Mime, ttl: Duration) -> Self {
41        let hash = sha2::Sha256::digest(&data).into();
42        let now = SystemTime::now();
43        Self {
44            hash,
45            data,
46            content_type,
47            expires: now + ttl,
48            last_modified: now,
49        }
50    }
51
52    pub fn update(&mut self, data: Bytes, content_type: Mime) {
53        self.hash = sha2::Sha256::digest(&data).into();
54        self.data = data;
55        self.content_type = content_type;
56        self.last_modified = SystemTime::now();
57    }
58
59    pub fn content_type(&self) -> ContentType {
60        self.content_type.clone().into()
61    }
62
63    pub fn etag(&self) -> ETag {
64        let encoded = base64ct::Base64Url::encode_string(&self.hash);
65        // SAFETY: Base64 encoding is URL-safe, so ETag-safe
66        format!("\"{encoded}\"")
67            .parse()
68            .expect("base64-encoded hash should be URL-safe")
69    }
70
71    pub fn data(&self) -> Bytes {
72        self.data.clone()
73    }
74
75    fn last_modified(&self) -> LastModified {
76        self.last_modified.into()
77    }
78
79    fn expires(&self) -> Expires {
80        self.expires.into()
81    }
82
83    pub fn typed_headers(
84        &self,
85    ) -> (
86        TypedHeader<ETag>,
87        TypedHeader<Expires>,
88        TypedHeader<LastModified>,
89    ) {
90        (
91            TypedHeader(self.etag()),
92            TypedHeader(self.expires()),
93            TypedHeader(self.last_modified()),
94        )
95    }
96}
97
98#[derive(Clone, Default)]
99pub struct Sessions {
100    inner: Arc<RwLock<BTreeMap<Ulid, Session>>>,
101    generator: Arc<Mutex<ulid::Generator>>,
102    capacity: usize,
103    hard_capacity: usize,
104    ttl: Duration,
105}
106
/// Evict entries from the head of a [`BTreeMap`] until at most `capacity`
/// entries remain.
///
/// Session IDs are ULIDs, which increase monotonically, so the head of the map
/// holds the oldest entries and those are the ones removed. A map that already
/// holds `capacity` entries or fewer is left untouched.
fn evict<K: Copy + Ord, V>(sessions: &mut BTreeMap<K, V>, capacity: usize) {
    // `saturating_sub` guards against `len < capacity`: a plain subtraction
    // panics in debug builds and, in release builds, wraps around so that
    // `take` would evict every entry.
    let excess = sessions.len().saturating_sub(capacity);
    if excess == 0 {
        return;
    }

    // Collect the keys first: the map cannot be mutated while borrowed by
    // `keys()`.
    let keys: Vec<K> = sessions.keys().take(excess).copied().collect();

    for key in keys {
        sessions.remove(&key);
    }
}
124
125impl Sessions {
126    /// Create a new session store with the given parameters
127    #[must_use]
128    pub fn new(ttl: Duration, capacity: usize) -> Self {
129        Self {
130            inner: Arc::new(RwLock::new(BTreeMap::new())),
131            generator: Arc::new(Mutex::new(ulid::Generator::new())),
132            ttl,
133            capacity,
134            hard_capacity: capacity * 2,
135        }
136    }
137
138    /// Create and insert a new session in the store
139    pub async fn new_session(
140        &self,
141        payload: Bytes,
142        content_type: Mime,
143    ) -> (Ulid, RwLockReadGuard<Session>) {
144        let id = self.generate_id().await;
145        let session = Session::new(payload, content_type, self.ttl);
146        let session = self.insert(id, session, self.ttl).await;
147        (id, session)
148    }
149
150    /// Find a session in the store
151    pub async fn get_session(&self, id: Ulid) -> Option<RwLockReadGuard<Session>> {
152        let sessions = self.inner.read().await;
153        RwLockReadGuard::try_map(sessions, |sessions| sessions.get(&id)).ok()
154    }
155
156    /// Get a mutable reference to a session from the store
157    pub async fn get_session_mut(&self, id: Ulid) -> Option<RwLockMappedWriteGuard<Session>> {
158        let sessions = self.inner.write().await;
159        RwLockWriteGuard::try_map(sessions, |sessions| sessions.get_mut(&id)).ok()
160    }
161
162    /// Delete a session from the store
163    pub async fn delete_session(&self, id: Ulid) -> bool {
164        self.inner.write().await.remove(&id).is_some()
165    }
166
167    async fn insert(&self, id: Ulid, session: Session, ttl: Duration) -> RwLockReadGuard<Session> {
168        let mut sessions = self.inner.write().await;
169
170        // When inserting, we check if we will hit the 'hard' capacity, so that we never
171        // go over that capacity
172        if sessions.len() + 1 >= self.hard_capacity {
173            evict(&mut sessions, self.capacity);
174        }
175
176        sessions.insert(id, session);
177
178        // Downgrade the write lock to a read lock, to get a reference to the
179        // just-inserted session
180        let session = RwLockReadGuard::map(sessions.downgrade(), |sessions| {
181            sessions
182                .get(&id)
183                .expect("Session should be in the map just after insertion")
184        });
185
186        let this = self.clone();
187        // TODO: cancel this task when an item gets deleted
188        tokio::task::spawn(async move {
189            tokio::time::sleep(ttl).await;
190            this.delete_session(id).await;
191        });
192
193        session
194    }
195
196    async fn generate_id(&self) -> Ulid {
197        self.generator
198            .lock()
199            .await
200            .generate()
201            // This would panic the thread if too many IDs (more than 2^40) are generated on the
202            // same millisecond, which is very unlikely
203            .expect("Failed to generate random ID")
204    }
205
206    /// A loop which evicts keys if the capacity is reached
207    pub fn eviction_task(
208        &self,
209        interval: Duration,
210    ) -> impl Future<Output = ()> + Send + Sync + 'static {
211        let this = self.clone();
212        async move {
213            let mut interval = tokio::time::interval(interval);
214            interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
215
216            loop {
217                interval.tick().await;
218                this.evict().await;
219            }
220        }
221    }
222
223    /// Trigger an eviction, removing the oldest entries if the soft capacity
224    /// has been reached
225    pub(crate) async fn evict(&self) {
226        if self.inner.read().await.len() > self.capacity {
227            let mut sessions = self.inner.write().await;
228            evict(&mut sessions, self.capacity);
229        }
230    }
231
232    /// Fill the sessions storage to check how much memory it might use on max
233    /// capacity
234    ///
235    /// # Panics
236    ///
237    /// It panics if the session storage is not empty
238    pub async fn fill_for_mem_check(&self, entry_size: usize) {
239        let mut sessions = self.inner.write().await;
240        let mut generator = self.generator.lock().await;
241        assert!(sessions.is_empty());
242
243        let data: Vec<u8> = std::iter::repeat(42).take(entry_size).collect();
244        sessions.extend((0..self.capacity).map(|_| {
245            let data = Bytes::from(data.clone());
246            let id = generator.generate().unwrap();
247            let session = Session::new(data, mime::APPLICATION_OCTET_STREAM, self.ttl);
248            (id, session)
249        }));
250
251        // Start the deletion tasks for all the sessions
252        let ttl = self.ttl;
253        for &key in sessions.keys() {
254            let inner = self.inner.clone();
255            tokio::task::spawn(async move {
256                tokio::time::sleep(ttl).await;
257                inner.write().await.remove(&key);
258            });
259        }
260    }
261}