quincy_server/server/
session.rs1use std::num::NonZeroU32;
8use std::sync::Arc;
9use std::time::Instant;
10
11use dashmap::DashMap;
12use governor::clock::DefaultClock;
13use governor::middleware::NoOpMiddleware;
14use governor::state::{InMemoryState, NotKeyed};
15use governor::{Quota, RateLimiter};
16use ipnet::IpNet;
17use tracing::info;
18
19use quincy::config::Bandwidth;
20
/// Direct (un-keyed) `governor` rate limiter with in-memory state and the default clock.
///
/// A single instance is shared through `Arc` by all connections of one user, so those
/// connections draw from one common quota (see `UserSessionRegistry::add_connection`).
pub type BandwidthLimiter = RateLimiter<NotKeyed, InMemoryState, DefaultClock, NoOpMiddleware>;
23
24pub struct ConnectionSession {
26 pub client_address: IpNet,
28 pub connected_at: Instant,
30}
31
32pub struct UserSession {
34 connections: Vec<ConnectionSession>,
36 rate_limiter: Option<Arc<BandwidthLimiter>>,
39}
40
41pub struct UserSessionRegistry {
43 sessions: DashMap<String, UserSession>,
44}
45
46impl Default for UserSessionRegistry {
47 fn default() -> Self {
48 Self::new()
49 }
50}
51
52impl UserSessionRegistry {
53 pub fn new() -> Self {
55 Self {
56 sessions: DashMap::new(),
57 }
58 }
59
60 pub fn add_connection(
76 &self,
77 username: &str,
78 session: ConnectionSession,
79 bandwidth_limit: Option<Bandwidth>,
80 ) -> Option<Arc<BandwidthLimiter>> {
81 let mut entry = self
82 .sessions
83 .entry(username.to_string())
84 .or_insert_with(|| {
85 let rate_limiter = bandwidth_limit.map(|bw| {
86 let kib_per_sec = bw.kib_per_second();
87 let rate = NonZeroU32::new(kib_per_sec).expect("kib_per_second returns >= 1");
89 let burst = NonZeroU32::new(kib_per_sec.max(64)).expect("burst is >= 64");
91 let quota = Quota::per_second(rate).allow_burst(burst);
92 Arc::new(RateLimiter::direct(quota))
93 });
94
95 info!(
96 "Created new session for user '{username}' (bandwidth limit: {})",
97 bandwidth_limit
98 .map(|bw| bw.to_string())
99 .unwrap_or_else(|| "unlimited".to_string())
100 );
101
102 UserSession {
103 connections: Vec::new(),
104 rate_limiter,
105 }
106 });
107
108 entry.connections.push(session);
109 entry.rate_limiter.clone()
110 }
111
112 pub fn remove_connection(&self, username: &str, client_address: &IpNet) {
118 if self
119 .sessions
120 .remove_if_mut(username, |_, session| {
121 session
122 .connections
123 .retain(|c| &c.client_address != client_address);
124 session.connections.is_empty()
125 })
126 .is_some()
127 {
128 info!("Removed last session for user '{username}'");
129 }
130 }
131
132 pub fn active_connection_count(&self) -> usize {
134 self.sessions.iter().map(|e| e.connections.len()).sum()
135 }
136
137 pub fn active_user_count(&self) -> usize {
139 self.sessions.len()
140 }
141}
142
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::time::Instant;

    use ipnet::IpNet;

    use quincy::config::Bandwidth;

    use super::{ConnectionSession, UserSessionRegistry};

    /// Builds a connection session for the given `addr/prefix` string.
    fn make_session(ip: &str) -> ConnectionSession {
        ConnectionSession {
            client_address: ip.parse().unwrap(),
            connected_at: Instant::now(),
        }
    }

    #[test]
    fn add_first_connection_creates_session() {
        let registry = UserSessionRegistry::new();
        registry.add_connection("alice", make_session("10.0.0.2/24"), None);

        assert_eq!(registry.active_connection_count(), 1);
        assert_eq!(registry.active_user_count(), 1);
    }

    #[test]
    fn add_second_connection_shares_limiter() {
        let registry = UserSessionRegistry::new();
        let bw = Some(Bandwidth::from_bytes_per_second(1_250_000));

        let limiter1 = registry.add_connection("alice", make_session("10.0.0.2/24"), bw);
        let limiter2 = registry.add_connection("alice", make_session("10.0.0.3/24"), bw);

        assert!(limiter1.is_some());
        assert!(limiter2.is_some());
        assert!(Arc::ptr_eq(
            limiter1.as_ref().unwrap(),
            limiter2.as_ref().unwrap()
        ));

        assert_eq!(registry.active_connection_count(), 2);
        assert_eq!(registry.active_user_count(), 1);
    }

    #[test]
    fn existing_session_keeps_original_limit() {
        let registry = UserSessionRegistry::new();

        // The first connection decides the user's limit; a limit supplied later is ignored.
        let first = registry.add_connection("alice", make_session("10.0.0.2/24"), None);
        let second = registry.add_connection(
            "alice",
            make_session("10.0.0.3/24"),
            Some(Bandwidth::from_bytes_per_second(1_250_000)),
        );

        assert!(first.is_none());
        assert!(second.is_none());
        assert_eq!(registry.active_user_count(), 1);
    }

    #[test]
    fn add_connection_unlimited() {
        let registry = UserSessionRegistry::new();
        let limiter = registry.add_connection("bob", make_session("10.0.0.4/24"), None);

        assert!(limiter.is_none());
    }

    #[test]
    fn remove_last_connection_drops_session() {
        let registry = UserSessionRegistry::new();
        let addr: IpNet = "10.0.0.2/24".parse().unwrap();

        registry.add_connection("alice", make_session("10.0.0.2/24"), None);
        assert_eq!(registry.active_connection_count(), 1);

        registry.remove_connection("alice", &addr);
        assert_eq!(registry.active_connection_count(), 0);
        assert_eq!(registry.active_user_count(), 0);
    }

    #[test]
    fn remove_one_of_two_connections() {
        let registry = UserSessionRegistry::new();
        let addr1: IpNet = "10.0.0.2/24".parse().unwrap();

        registry.add_connection("alice", make_session("10.0.0.2/24"), None);
        registry.add_connection("alice", make_session("10.0.0.3/24"), None);
        assert_eq!(registry.active_connection_count(), 2);

        registry.remove_connection("alice", &addr1);
        assert_eq!(registry.active_connection_count(), 1);
        assert_eq!(registry.active_user_count(), 1);
    }

    #[test]
    fn remove_nonexistent_connection_is_noop() {
        let registry = UserSessionRegistry::new();
        let addr: IpNet = "10.0.0.99/24".parse().unwrap();

        // Unknown user: nothing to remove, must not panic.
        registry.remove_connection("nobody", &addr);

        // Known user, unknown address: the existing connection must survive.
        registry.add_connection("alice", make_session("10.0.0.2/24"), None);
        registry.remove_connection("alice", &addr);

        assert_eq!(registry.active_connection_count(), 1);
    }

    #[tokio::test]
    async fn concurrent_add_remove() {
        let registry = Arc::new(UserSessionRegistry::new());
        let mut handles = Vec::new();

        for i in 0..20 {
            let registry = registry.clone();
            handles.push(tokio::spawn(async move {
                let ip = format!("10.0.{}.{}/24", i / 256, i % 256);
                let username = format!("user_{}", i % 5);
                let bw = if i % 2 == 0 {
                    Some(Bandwidth::from_bytes_per_second(1_000_000))
                } else {
                    None
                };

                registry.add_connection(&username, make_session(&ip), bw);

                // Let other tasks interleave their adds/removes on the shared users.
                tokio::task::yield_now().await;

                let addr: IpNet = ip.parse().unwrap();
                registry.remove_connection(&username, &addr);
            }));
        }

        for handle in handles {
            handle.await.expect("task should not panic");
        }

        assert_eq!(registry.active_connection_count(), 0);
        assert_eq!(registry.active_user_count(), 0);
    }

    /// `#[tokio::test]` runs on a single-threaded runtime, so the test above never touches the
    /// registry from two OS threads at once. This one does, with several threads contending on
    /// the same small set of usernames.
    #[test]
    fn concurrent_add_remove_across_threads() {
        let registry = UserSessionRegistry::new();

        std::thread::scope(|scope| {
            for t in 0..4u32 {
                let registry = &registry;
                scope.spawn(move || {
                    for i in 0..50u32 {
                        // Unique address per (thread, iteration); usernames are shared across threads.
                        let ip = format!("10.{t}.0.{}/24", i + 2);
                        let username = format!("user_{}", i % 5);
                        let bw = if i % 2 == 0 {
                            Some(Bandwidth::from_bytes_per_second(1_000_000))
                        } else {
                            None
                        };

                        registry.add_connection(&username, make_session(&ip), bw);

                        let addr: IpNet = ip.parse().unwrap();
                        registry.remove_connection(&username, &addr);
                    }
                });
            }
        });

        assert_eq!(registry.active_connection_count(), 0);
        assert_eq!(registry.active_user_count(), 0);
    }
}