// quincy_server/server/session.rs

//! User session registry for tracking active VPN connections.
//!
//! Provides a centralized, thread-safe registry that maps usernames to their
//! active connection sessions and shared rate limiters. The registry is only
//! accessed on connect/disconnect -- it is NOT in the packet forwarding hot path.
6
7use std::num::NonZeroU32;
8use std::sync::Arc;
9use std::time::Instant;
10
11use dashmap::DashMap;
12use governor::clock::DefaultClock;
13use governor::middleware::NoOpMiddleware;
14use governor::state::{InMemoryState, NotKeyed};
15use governor::{Quota, RateLimiter};
16use ipnet::IpNet;
17use tracing::info;
18
19use quincy::config::Bandwidth;
20
21/// Type alias for the governor rate limiter used for bandwidth limiting.
22pub type BandwidthLimiter = RateLimiter<NotKeyed, InMemoryState, DefaultClock, NoOpMiddleware>;
23
24/// Metadata for a single active QUIC connection.
25pub struct ConnectionSession {
26    /// Tunnel IP assigned to this connection.
27    pub client_address: IpNet,
28    /// When this connection was established.
29    pub connected_at: Instant,
30}
31
32/// Per-user state, potentially spanning multiple concurrent connections.
33pub struct UserSession {
34    /// All active connections for this user.
35    connections: Vec<ConnectionSession>,
36    /// Shared rate limiter across all connections and directions.
37    /// `None` means unlimited bandwidth.
38    rate_limiter: Option<Arc<BandwidthLimiter>>,
39}
40
41/// Thread-safe registry of active user sessions.
42pub struct UserSessionRegistry {
43    sessions: DashMap<String, UserSession>,
44}
45
46impl Default for UserSessionRegistry {
47    fn default() -> Self {
48        Self::new()
49    }
50}
51
52impl UserSessionRegistry {
53    /// Creates a new empty registry.
54    pub fn new() -> Self {
55        Self {
56            sessions: DashMap::new(),
57        }
58    }
59
60    /// Registers a new connection for the given user.
61    ///
62    /// On the user's first connection, creates a new `UserSession` and
63    /// optionally a rate limiter (if `bandwidth_limit` is `Some`). On
64    /// subsequent connections, the new connection joins the existing session
65    /// and shares the existing rate limiter.
66    ///
67    /// ### Arguments
68    /// - `username` - the authenticated username
69    /// - `session` - metadata for this connection
70    /// - `bandwidth_limit` - effective bandwidth limit for this user
71    ///
72    /// ### Returns
73    /// A cloned `Arc<BandwidthLimiter>` if the user has a bandwidth limit,
74    /// or `None` if the user has unlimited bandwidth.
75    pub fn add_connection(
76        &self,
77        username: &str,
78        session: ConnectionSession,
79        bandwidth_limit: Option<Bandwidth>,
80    ) -> Option<Arc<BandwidthLimiter>> {
81        let mut entry = self
82            .sessions
83            .entry(username.to_string())
84            .or_insert_with(|| {
85                let rate_limiter = bandwidth_limit.map(|bw| {
86                    let kib_per_sec = bw.kib_per_second();
87                    // kib_per_second() guarantees >= 1
88                    let rate = NonZeroU32::new(kib_per_sec).expect("kib_per_second returns >= 1");
89                    // Burst: at least 64 KiB or per-second rate, whichever is larger
90                    let burst = NonZeroU32::new(kib_per_sec.max(64)).expect("burst is >= 64");
91                    let quota = Quota::per_second(rate).allow_burst(burst);
92                    Arc::new(RateLimiter::direct(quota))
93                });
94
95                info!(
96                    "Created new session for user '{username}' (bandwidth limit: {})",
97                    bandwidth_limit
98                        .map(|bw| bw.to_string())
99                        .unwrap_or_else(|| "unlimited".to_string())
100                );
101
102                UserSession {
103                    connections: Vec::new(),
104                    rate_limiter,
105                }
106            });
107
108        entry.connections.push(session);
109        entry.rate_limiter.clone()
110    }
111
112    /// Removes a specific connection for the given user, identified by its
113    /// assigned tunnel IP address.
114    ///
115    /// If this was the user's last active connection, the entire `UserSession`
116    /// (including the rate limiter) is dropped.
117    pub fn remove_connection(&self, username: &str, client_address: &IpNet) {
118        if self
119            .sessions
120            .remove_if_mut(username, |_, session| {
121                session
122                    .connections
123                    .retain(|c| &c.client_address != client_address);
124                session.connections.is_empty()
125            })
126            .is_some()
127        {
128            info!("Removed last session for user '{username}'");
129        }
130    }
131
132    /// Returns the total number of active connections across all users.
133    pub fn active_connection_count(&self) -> usize {
134        self.sessions.iter().map(|e| e.connections.len()).sum()
135    }
136
137    /// Returns the number of users with at least one active connection.
138    pub fn active_user_count(&self) -> usize {
139        self.sessions.len()
140    }
141}
142
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::time::Instant;

    use ipnet::IpNet;

    use quincy::config::Bandwidth;

    use super::{ConnectionSession, UserSessionRegistry};

    /// Helper to create a `ConnectionSession` with the given IP string.
    fn make_session(ip: &str) -> ConnectionSession {
        ConnectionSession {
            client_address: ip.parse().unwrap(),
            connected_at: Instant::now(),
        }
    }

    #[test]
    fn add_first_connection_creates_session() {
        let registry = UserSessionRegistry::new();
        registry.add_connection("alice", make_session("10.0.0.2/24"), None);

        assert_eq!(registry.active_connection_count(), 1);
        assert_eq!(registry.active_user_count(), 1);
    }

    #[test]
    fn add_second_connection_shares_limiter() {
        let registry = UserSessionRegistry::new();
        let bw = Some(Bandwidth::from_bytes_per_second(1_250_000));

        let limiter1 = registry.add_connection("alice", make_session("10.0.0.2/24"), bw);
        let limiter2 = registry.add_connection("alice", make_session("10.0.0.3/24"), bw);

        assert!(limiter1.is_some());
        assert!(limiter2.is_some());
        assert!(Arc::ptr_eq(
            limiter1.as_ref().unwrap(),
            limiter2.as_ref().unwrap()
        ));

        assert_eq!(registry.active_connection_count(), 2);
        assert_eq!(registry.active_user_count(), 1);
    }

    #[test]
    fn add_connection_unlimited() {
        let registry = UserSessionRegistry::new();
        let limiter = registry.add_connection("bob", make_session("10.0.0.4/24"), None);

        assert!(limiter.is_none());
    }

    #[test]
    fn later_connection_cannot_change_limit() {
        // The first connection decides the session's limit; later ones join it.
        let registry = UserSessionRegistry::new();
        let first = registry.add_connection("alice", make_session("10.0.0.2/24"), None);
        let second = registry.add_connection(
            "alice",
            make_session("10.0.0.3/24"),
            Some(Bandwidth::from_bytes_per_second(1_000_000)),
        );

        assert!(first.is_none());
        assert!(second.is_none());
        assert_eq!(registry.active_connection_count(), 2);
    }

    #[test]
    fn different_users_get_separate_limiters() {
        let registry = UserSessionRegistry::new();
        let bw = Some(Bandwidth::from_bytes_per_second(1_000_000));

        let alice = registry
            .add_connection("alice", make_session("10.0.0.2/24"), bw)
            .expect("alice has a limit");
        let bob = registry
            .add_connection("bob", make_session("10.0.0.3/24"), bw)
            .expect("bob has a limit");

        assert!(!Arc::ptr_eq(&alice, &bob));
        assert_eq!(registry.active_user_count(), 2);
    }

    #[test]
    fn reconnect_after_full_disconnect_gets_fresh_limiter() {
        let registry = UserSessionRegistry::new();
        let bw = Some(Bandwidth::from_bytes_per_second(1_000_000));
        let addr: IpNet = "10.0.0.2/24".parse().unwrap();

        let before = registry
            .add_connection("alice", make_session("10.0.0.2/24"), bw)
            .expect("limit configured");
        registry.remove_connection("alice", &addr);
        assert_eq!(registry.active_user_count(), 0);

        let after = registry
            .add_connection("alice", make_session("10.0.0.2/24"), bw)
            .expect("limit configured");

        // The old session (and its limiter) was dropped from the registry, so
        // the new session must not reuse it.
        assert!(!Arc::ptr_eq(&before, &after));
    }

    #[test]
    fn remove_last_connection_drops_session() {
        let registry = UserSessionRegistry::new();
        let addr: IpNet = "10.0.0.2/24".parse().unwrap();

        registry.add_connection("alice", make_session("10.0.0.2/24"), None);
        assert_eq!(registry.active_connection_count(), 1);

        registry.remove_connection("alice", &addr);
        assert_eq!(registry.active_connection_count(), 0);
        assert_eq!(registry.active_user_count(), 0);
    }

    #[test]
    fn remove_one_of_two_connections() {
        let registry = UserSessionRegistry::new();
        let addr1: IpNet = "10.0.0.2/24".parse().unwrap();

        registry.add_connection("alice", make_session("10.0.0.2/24"), None);
        registry.add_connection("alice", make_session("10.0.0.3/24"), None);
        assert_eq!(registry.active_connection_count(), 2);

        registry.remove_connection("alice", &addr1);
        assert_eq!(registry.active_connection_count(), 1);
        assert_eq!(registry.active_user_count(), 1);
    }

    #[test]
    fn remove_nonexistent_connection_is_noop() {
        let registry = UserSessionRegistry::new();
        let addr: IpNet = "10.0.0.99/24".parse().unwrap();

        // Remove from unknown user — should not panic
        registry.remove_connection("nobody", &addr);

        // Remove unknown IP from existing user — should not panic
        registry.add_connection("alice", make_session("10.0.0.2/24"), None);
        registry.remove_connection("alice", &addr);

        assert_eq!(registry.active_connection_count(), 1);
    }

    #[tokio::test]
    async fn concurrent_add_remove() {
        let registry = Arc::new(UserSessionRegistry::new());
        let mut handles = Vec::new();

        for i in 0..20 {
            let registry = registry.clone();
            handles.push(tokio::spawn(async move {
                let ip = format!("10.0.{}.{}/24", i / 256, i % 256);
                let username = format!("user_{}", i % 5);
                let bw = if i % 2 == 0 {
                    Some(Bandwidth::from_bytes_per_second(1_000_000))
                } else {
                    None
                };

                registry.add_connection(&username, make_session(&ip), bw);

                // Yield to let other tasks interleave
                tokio::task::yield_now().await;

                let addr: IpNet = ip.parse().unwrap();
                registry.remove_connection(&username, &addr);
            }));
        }

        for handle in handles {
            handle.await.expect("task should not panic");
        }

        // All connections added and removed — counts should be zero
        assert_eq!(registry.active_connection_count(), 0);
        assert_eq!(registry.active_user_count(), 0);
    }
}