use std::num::NonZeroU32;
use std::sync::Arc;
use std::time::Instant;
use dashmap::DashMap;
use governor::clock::DefaultClock;
use governor::middleware::NoOpMiddleware;
use governor::state::{InMemoryState, NotKeyed};
use governor::{Quota, RateLimiter};
use ipnet::IpNet;
use tracing::info;
use quincy::config::Bandwidth;
/// Direct (un-keyed, `NotKeyed`) governor rate limiter with in-memory state and the default
/// clock. One instance is shared, via `Arc`, by all connections of a single user.
///
/// NOTE(review): the registry builds its quota from `Bandwidth::kib_per_second()`, so each
/// cell presumably represents one KiB of traffic — confirm against the call sites that consume
/// cells from this limiter.
pub type BandwidthLimiter = RateLimiter<NotKeyed, InMemoryState, DefaultClock, NoOpMiddleware>;
/// A single client connection that belongs to a user's session.
#[derive(Debug, Clone)]
pub struct ConnectionSession {
    /// Client address (with prefix) identifying this connection; the registry matches on it
    /// when the connection is removed.
    pub client_address: IpNet,
    /// Timestamp supplied by the caller — presumably when the connection was established.
    pub connected_at: Instant,
}
/// All live connections of one user, together with the rate limiter they share.
pub struct UserSession {
    /// Connections currently registered for this user. The registry drops the whole session
    /// as soon as this becomes empty.
    connections: Vec<ConnectionSession>,
    /// Limiter handed to every connection of this user; `None` means the user is unlimited.
    rate_limiter: Option<Arc<BandwidthLimiter>>,
}
/// Concurrent registry of per-user sessions, keyed by username.
///
/// Backed by a `DashMap`, so every method takes `&self` and the registry can be shared across
/// threads/tasks (for example behind an `Arc`).
pub struct UserSessionRegistry {
    /// Username -> that user's session.
    sessions: DashMap<String, UserSession>,
}
impl Default for UserSessionRegistry {
fn default() -> Self {
Self::new()
}
}
impl UserSessionRegistry {
    /// Creates an empty registry with no users and no connections.
    pub fn new() -> Self {
        Self {
            sessions: DashMap::new(),
        }
    }

    /// Registers `session` for `username` and returns the user's shared bandwidth limiter.
    ///
    /// The first connection of a user creates that user's session, and the limiter is built
    /// from `bandwidth_limit` at that moment. Every later connection of the same user joins the
    /// existing session and receives a clone of the same `Arc`, so all of the user's
    /// connections draw from one budget. The `bandwidth_limit` passed for those later
    /// connections is ignored. A new limit only takes effect after all of the user's
    /// connections have been removed and a fresh session is created.
    ///
    /// Returns `None` when the user's session is unlimited.
    ///
    /// # Panics
    ///
    /// Panics if `Bandwidth::kib_per_second()` returns 0. The code relies on it always being
    /// at least 1.
    pub fn add_connection(
        &self,
        username: &str,
        session: ConnectionSession,
        bandwidth_limit: Option<Bandwidth>,
    ) -> Option<Arc<BandwidthLimiter>> {
        let mut created = false;

        // Scope the entry guard so the shard's write lock is released before logging below.
        // The tracing subscriber may perform blocking I/O, which would otherwise stall every
        // other user hashed to the same shard.
        let rate_limiter = {
            let mut entry = self
                .sessions
                .entry(username.to_string())
                .or_insert_with(|| {
                    created = true;
                    UserSession {
                        connections: Vec::new(),
                        rate_limiter: bandwidth_limit.map(Self::build_rate_limiter),
                    }
                });
            entry.connections.push(session);
            entry.rate_limiter.clone()
        };

        if created {
            info!(
                "Created new session for user '{username}' (bandwidth limit: {})",
                bandwidth_limit
                    .map(|bw| bw.to_string())
                    .unwrap_or_else(|| "unlimited".to_string())
            );
        }

        rate_limiter
    }

    /// Builds the limiter shared by all connections of a bandwidth-capped user.
    ///
    /// The quota replenishes `kib_per_second()` units per second. Its burst is never smaller
    /// than 64 units, even for very low limits.
    fn build_rate_limiter(bw: Bandwidth) -> Arc<BandwidthLimiter> {
        let kib_per_sec = bw.kib_per_second();
        let rate = NonZeroU32::new(kib_per_sec).expect("kib_per_second returns >= 1");
        let burst = NonZeroU32::new(kib_per_sec.max(64)).expect("burst is >= 64");
        let quota = Quota::per_second(rate).allow_burst(burst);
        Arc::new(RateLimiter::direct(quota))
    }

    /// Removes every connection of `username` whose address equals `client_address`.
    ///
    /// If that leaves the user with no connections, the whole session, including its
    /// limiter, is removed from the registry. The user's next connection then starts a
    /// fresh session. Unknown users and unknown addresses are a no-op.
    pub fn remove_connection(&self, username: &str, client_address: &IpNet) {
        // The retain and the emptiness check run together inside `remove_if_mut`, so the
        // decision to drop the session is made atomically for this key.
        if self
            .sessions
            .remove_if_mut(username, |_, session| {
                session
                    .connections
                    .retain(|c| &c.client_address != client_address);
                session.connections.is_empty()
            })
            .is_some()
        {
            info!("Removed last session for user '{username}'");
        }
    }

    /// Total number of registered connections across all users.
    ///
    /// Not an atomic snapshot. Shards are visited one at a time, so concurrent adds and
    /// removes may be only partially reflected.
    pub fn active_connection_count(&self) -> usize {
        self.sessions.iter().map(|e| e.connections.len()).sum()
    }

    /// Number of users that currently have at least one connection.
    pub fn active_user_count(&self) -> usize {
        self.sessions.len()
    }
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::time::Instant;

    use ipnet::IpNet;
    use quincy::config::Bandwidth;

    use super::{ConnectionSession, UserSessionRegistry};

    /// Builds a connection record for `ip` (CIDR notation), timestamped now.
    fn make_session(ip: &str) -> ConnectionSession {
        ConnectionSession {
            client_address: ip.parse().unwrap(),
            connected_at: Instant::now(),
        }
    }

    #[test]
    fn add_first_connection_creates_session() {
        let registry = UserSessionRegistry::new();
        registry.add_connection("alice", make_session("10.0.0.2/24"), None);
        assert_eq!(registry.active_connection_count(), 1);
        assert_eq!(registry.active_user_count(), 1);
    }

    #[test]
    fn add_second_connection_shares_limiter() {
        let registry = UserSessionRegistry::new();
        let bw = Some(Bandwidth::from_bytes_per_second(1_250_000));
        let limiter1 = registry.add_connection("alice", make_session("10.0.0.2/24"), bw);
        let limiter2 = registry.add_connection("alice", make_session("10.0.0.3/24"), bw);
        assert!(limiter1.is_some());
        assert!(limiter2.is_some());
        assert!(Arc::ptr_eq(
            limiter1.as_ref().unwrap(),
            limiter2.as_ref().unwrap()
        ));
        assert_eq!(registry.active_connection_count(), 2);
        assert_eq!(registry.active_user_count(), 1);
    }

    #[test]
    fn add_connection_unlimited() {
        let registry = UserSessionRegistry::new();
        let limiter = registry.add_connection("bob", make_session("10.0.0.4/24"), None);
        assert!(limiter.is_none());
    }

    #[test]
    fn different_users_get_independent_limiters() {
        let registry = UserSessionRegistry::new();
        let bw = Some(Bandwidth::from_bytes_per_second(1_250_000));
        let alice = registry
            .add_connection("alice", make_session("10.0.0.2/24"), bw)
            .unwrap();
        let bob = registry
            .add_connection("bob", make_session("10.0.0.3/24"), bw)
            .unwrap();
        assert!(!Arc::ptr_eq(&alice, &bob));
        assert_eq!(registry.active_user_count(), 2);
    }

    #[test]
    fn later_connection_keeps_existing_limit() {
        // The limit chosen for a user's first connection applies to the whole session;
        // a different limit passed for a later connection is ignored.
        let registry = UserSessionRegistry::new();
        let bw = Some(Bandwidth::from_bytes_per_second(1_250_000));
        let first = registry.add_connection("alice", make_session("10.0.0.2/24"), None);
        let second = registry.add_connection("alice", make_session("10.0.0.3/24"), bw);
        assert!(first.is_none());
        assert!(second.is_none());
    }

    #[test]
    fn remove_last_connection_drops_session() {
        let registry = UserSessionRegistry::new();
        let addr: IpNet = "10.0.0.2/24".parse().unwrap();
        registry.add_connection("alice", make_session("10.0.0.2/24"), None);
        assert_eq!(registry.active_connection_count(), 1);
        registry.remove_connection("alice", &addr);
        assert_eq!(registry.active_connection_count(), 0);
        assert_eq!(registry.active_user_count(), 0);
    }

    #[test]
    fn reconnect_after_full_removal_creates_fresh_limiter() {
        let registry = UserSessionRegistry::new();
        let bw = Some(Bandwidth::from_bytes_per_second(1_250_000));
        let addr: IpNet = "10.0.0.2/24".parse().unwrap();
        let before = registry
            .add_connection("alice", make_session("10.0.0.2/24"), bw)
            .unwrap();
        registry.remove_connection("alice", &addr);
        let after = registry
            .add_connection("alice", make_session("10.0.0.2/24"), bw)
            .unwrap();
        // `before` is still alive, so a reused allocation cannot make these compare equal.
        assert!(!Arc::ptr_eq(&before, &after));
    }

    #[test]
    fn remove_one_of_two_connections() {
        let registry = UserSessionRegistry::new();
        let addr1: IpNet = "10.0.0.2/24".parse().unwrap();
        registry.add_connection("alice", make_session("10.0.0.2/24"), None);
        registry.add_connection("alice", make_session("10.0.0.3/24"), None);
        assert_eq!(registry.active_connection_count(), 2);
        registry.remove_connection("alice", &addr1);
        assert_eq!(registry.active_connection_count(), 1);
        assert_eq!(registry.active_user_count(), 1);
    }

    #[test]
    fn remove_nonexistent_connection_is_noop() {
        let registry = UserSessionRegistry::new();
        let addr: IpNet = "10.0.0.99/24".parse().unwrap();
        registry.remove_connection("nobody", &addr);
        registry.add_connection("alice", make_session("10.0.0.2/24"), None);
        registry.remove_connection("alice", &addr);
        assert_eq!(registry.active_connection_count(), 1);
    }

    #[test]
    fn concurrent_add_remove() {
        // Real OS threads, so registry operations genuinely run in parallel.
        // (A default `#[tokio::test]` runtime is single-threaded and only interleaves tasks.)
        let registry = UserSessionRegistry::new();
        std::thread::scope(|scope| {
            for i in 0..20 {
                let registry = &registry;
                scope.spawn(move || {
                    let ip = format!("10.0.{}.{}/24", i / 256, i % 256);
                    let username = format!("user_{}", i % 5);
                    let bw = if i % 2 == 0 {
                        Some(Bandwidth::from_bytes_per_second(1_000_000))
                    } else {
                        None
                    };
                    registry.add_connection(&username, make_session(&ip), bw);
                    std::thread::yield_now();
                    let addr: IpNet = ip.parse().unwrap();
                    registry.remove_connection(&username, &addr);
                });
            }
        });
        // `scope` joins every thread and re-panics if any of them panicked.
        assert_eq!(registry.active_connection_count(), 0);
        assert_eq!(registry.active_user_count(), 0);
    }
}