use std::collections::HashMap;
use std::net::IpAddr;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
use tokio::time::interval;
use tracing::{debug, info};
/// Limits enforced by a connection limiter.
///
/// A count limit of `0` disables the corresponding check.
#[derive(Debug, Clone)]
pub struct ConnectionLimitConfig {
    /// Maximum simultaneously tracked connections from one IP (`0` = unlimited).
    pub max_connections_per_ip: usize,
    /// Maximum simultaneously tracked connections overall (`0` = unlimited).
    pub max_total_connections: usize,
    /// A connection with no recorded activity for longer than this is considered idle.
    pub idle_timeout: Duration,
    /// How often the background reaper looks for idle connections.
    pub reaper_interval: Duration,
}

impl Default for ConnectionLimitConfig {
    /// 10 connections per IP, 1000 in total, a 5-minute idle timeout and a
    /// reaper pass every minute.
    fn default() -> Self {
        const SECS_PER_MINUTE: u64 = 60;
        Self {
            max_connections_per_ip: 10,
            max_total_connections: 1000,
            idle_timeout: Duration::from_secs(5 * SECS_PER_MINUTE),
            reaper_interval: Duration::from_secs(SECS_PER_MINUTE),
        }
    }
}
/// Bookkeeping for a single tracked connection.
#[derive(Debug, Clone)]
struct ConnectionInfo {
    /// Limiter-assigned identifier.
    #[allow(dead_code)]
    id: u64,
    /// Client address; used to maintain per-IP counts.
    ip: IpAddr,
    /// When the record was created.
    #[allow(dead_code)]
    established_at: Instant,
    /// Most recent recorded activity; drives idle detection.
    last_activity: Instant,
}

impl ConnectionInfo {
    /// Creates a record whose establishment and last activity are both "now".
    fn new(id: u64, ip: IpAddr) -> Self {
        let started = Instant::now();
        Self {
            last_activity: started,
            established_at: started,
            ip,
            id,
        }
    }

    /// Stamps the current instant as the latest activity.
    fn update_activity(&mut self) {
        self.last_activity = Instant::now();
    }

    /// Returns `true` once strictly more than `timeout` has passed since the
    /// last recorded activity.
    fn is_idle(&self, timeout: Duration) -> bool {
        let quiet_for = self.last_activity.elapsed();
        quiet_for > timeout
    }
}
/// Snapshot of connection-limiter counters (all zero by default).
#[derive(Debug, Clone, Default)]
pub struct ConnectionStats {
    /// Connections currently tracked.
    pub current_connections: usize,
    /// Highest value `current_connections` has reached.
    pub peak_connections: usize,
    /// Connections ever admitted (wraps on overflow).
    pub total_connections: u64,
    /// Rejected connection attempts that the limiter has recorded.
    pub total_rejected: u64,
    /// Connections removed by idle reaping.
    pub total_reaped: u64,
}
/// Mutable bookkeeping shared by all clones of a limiter, kept behind its lock.
///
/// Each connection is indexed twice: by id in `connections`, and grouped by
/// client address in `connections_per_ip` so per-IP counts are cheap.
struct ConnectionLimiterState {
    /// Active limits.
    config: ConnectionLimitConfig,
    /// Id to hand out to the next registered connection (starts at 1).
    next_id: u64,
    /// Every tracked connection, keyed by id.
    connections: HashMap<u64, ConnectionInfo>,
    /// Ids of tracked connections per client IP; empty entries are removed.
    connections_per_ip: HashMap<IpAddr, Vec<u64>>,
    /// Running counters.
    stats: ConnectionStats,
}

impl ConnectionLimiterState {
    /// Creates empty state enforcing `config`; ids start at 1.
    fn new(config: ConnectionLimitConfig) -> Self {
        Self {
            stats: ConnectionStats::default(),
            connections_per_ip: HashMap::new(),
            connections: HashMap::new(),
            next_id: 1,
            config,
        }
    }

    /// Checks whether one more connection from `ip` fits within the limits.
    ///
    /// The total limit is checked first, then the per-IP limit; a limit of `0`
    /// skips that check. On rejection a human-readable reason is returned.
    fn can_accept(&self, ip: IpAddr) -> Result<(), String> {
        let total_limit = self.config.max_total_connections;
        if total_limit != 0 && self.connections.len() >= total_limit {
            return Err(format!(
                "Maximum total connections ({}) reached",
                total_limit
            ));
        }

        let per_ip_limit = self.config.max_connections_per_ip;
        if per_ip_limit == 0 {
            return Ok(());
        }

        let from_ip = self.connections_per_ip.get(&ip).map_or(0, Vec::len);
        if from_ip >= per_ip_limit {
            Err(format!(
                "Maximum connections per IP ({}) reached for {}",
                per_ip_limit, ip
            ))
        } else {
            Ok(())
        }
    }

    /// Starts tracking a new connection from `ip`, updates the counters, and
    /// returns its freshly assigned id. Limits are not checked here.
    fn register_connection(&mut self, ip: IpAddr) -> u64 {
        let id = self.next_id;
        self.next_id = id.wrapping_add(1);

        self.connections.insert(id, ConnectionInfo::new(id, ip));
        self.connections_per_ip.entry(ip).or_default().push(id);

        let stats = &mut self.stats;
        stats.current_connections = self.connections.len();
        stats.total_connections = stats.total_connections.wrapping_add(1);
        stats.peak_connections = stats.peak_connections.max(stats.current_connections);
        id
    }

    /// Stops tracking connection `id`; unknown ids are ignored.
    fn unregister_connection(&mut self, id: u64) {
        let removed = match self.connections.remove(&id) {
            Some(info) => info,
            None => return,
        };
        let ip = removed.ip;

        if let Some(ids) = self.connections_per_ip.get_mut(&ip) {
            ids.retain(|&other| other != id);
            if ids.is_empty() {
                // Drop the empty bucket so the map doesn't grow with every IP seen.
                self.connections_per_ip.remove(&ip);
            }
        }
        self.stats.current_connections = self.connections.len();
    }

    /// Refreshes the activity timestamp of connection `id`, if it is tracked.
    fn update_activity(&mut self, id: u64) {
        if let Some(info) = self.connections.get_mut(&id) {
            info.update_activity();
        }
    }

    /// Stops tracking every connection idle longer than `config.idle_timeout`.
    ///
    /// Returns the removed ids and adds their count to `stats.total_reaped`.
    fn reap_idle_connections(&mut self) -> Vec<u64> {
        let timeout = self.config.idle_timeout;
        // Collect first: we cannot unregister while iterating `connections`.
        let idle: Vec<u64> = self
            .connections
            .iter()
            .filter(|(_, info)| info.is_idle(timeout))
            .map(|(&id, _)| id)
            .collect();

        for id in &idle {
            self.unregister_connection(*id);
            self.stats.total_reaped = self.stats.total_reaped.wrapping_add(1);
        }
        idle
    }

    /// Returns a copy of the current counters.
    fn get_stats(&self) -> ConnectionStats {
        self.stats.clone()
    }
}
#[derive(Clone)]
#[allow(dead_code)]
pub struct ConnectionLimiter {
state: Arc<RwLock<ConnectionLimiterState>>,
}
impl ConnectionLimiter {
#[allow(dead_code)]
pub fn new(config: ConnectionLimitConfig) -> Self {
Self {
state: Arc::new(RwLock::new(ConnectionLimiterState::new(config))),
}
}
#[allow(dead_code)]
pub fn with_defaults() -> Self {
Self::new(ConnectionLimitConfig::default())
}
#[allow(dead_code)]
pub async fn acquire(&self, ip: IpAddr) -> Result<ConnectionGuard, String> {
let mut state = self.state.write().await;
state.can_accept(ip)?;
let id = state.register_connection(ip);
debug!(
"Connection accepted: id={}, ip={}, current={}, peak={}",
id, ip, state.stats.current_connections, state.stats.peak_connections
);
Ok(ConnectionGuard {
id,
limiter: self.clone(),
})
}
#[allow(dead_code)]
pub async fn update_activity(&self, id: u64) {
let mut state = self.state.write().await;
state.update_activity(id);
}
#[allow(dead_code)]
pub async fn get_stats(&self) -> ConnectionStats {
let state = self.state.read().await;
state.get_stats()
}
#[allow(dead_code)]
pub async fn update_config(&self, config: ConnectionLimitConfig) {
let mut state = self.state.write().await;
info!(
"Updating connection limits: max_per_ip={}, max_total={}, idle_timeout={:?}",
config.max_connections_per_ip, config.max_total_connections, config.idle_timeout
);
state.config = config;
}
#[allow(dead_code)]
pub async fn get_config(&self) -> ConnectionLimitConfig {
let state = self.state.read().await;
state.config.clone()
}
#[allow(dead_code)]
pub async fn reap_idle(&self) -> usize {
let mut state = self.state.write().await;
let reaped = state.reap_idle_connections();
let count = reaped.len();
if count > 0 {
info!("Reaped {} idle connections", count);
debug!("Reaped connection IDs: {:?}", reaped);
}
count
}
#[allow(dead_code)]
pub fn start_reaper(self) -> tokio::task::JoinHandle<()> {
tokio::spawn(async move {
let reaper_interval = {
let state = self.state.read().await;
state.config.reaper_interval
};
let mut ticker = interval(reaper_interval);
ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
loop {
ticker.tick().await;
let count = self.reap_idle().await;
if count > 0 {
let stats = self.get_stats().await;
info!(
"Connection reaper: reaped={}, current={}, total_reaped={}",
count, stats.current_connections, stats.total_reaped
);
}
}
})
}
#[allow(dead_code)]
async fn record_rejection(&self) {
let mut state = self.state.write().await;
state.stats.total_rejected = state.stats.total_rejected.wrapping_add(1);
}
}
#[allow(dead_code)]
pub struct ConnectionGuard {
id: u64,
limiter: ConnectionLimiter,
}
impl ConnectionGuard {
#[allow(dead_code)]
pub fn id(&self) -> u64 {
self.id
}
#[allow(dead_code)]
pub async fn update_activity(&self) {
self.limiter.update_activity(self.id).await;
}
#[allow(dead_code)]
pub fn into_id(self) -> u64 {
let id = self.id;
std::mem::forget(self); id
}
#[allow(dead_code)]
pub async fn unregister(limiter: &ConnectionLimiter, id: u64) {
let mut state = limiter.state.write().await;
state.unregister_connection(id);
debug!(
"Connection released: id={}, current={}",
id, state.stats.current_connections
);
}
}
impl Drop for ConnectionGuard {
fn drop(&mut self) {
let limiter = self.limiter.clone();
let id = self.id;
tokio::spawn(async move {
ConnectionGuard::unregister(&limiter, id).await;
});
}
}
/// Fluent builder for [`ConnectionLimitConfig`], seeded with its defaults.
pub struct ConnectionLimitConfigBuilder {
    max_connections_per_ip: usize,
    max_total_connections: usize,
    idle_timeout: Duration,
    reaper_interval: Duration,
}

impl ConnectionLimitConfigBuilder {
    /// Starts from the values of [`ConnectionLimitConfig::default`].
    pub fn new() -> Self {
        let ConnectionLimitConfig {
            max_connections_per_ip,
            max_total_connections,
            idle_timeout,
            reaper_interval,
        } = ConnectionLimitConfig::default();
        Self {
            max_connections_per_ip,
            max_total_connections,
            idle_timeout,
            reaper_interval,
        }
    }

    /// Sets the per-IP connection limit (`0` disables the check).
    pub fn max_connections_per_ip(self, max: usize) -> Self {
        Self {
            max_connections_per_ip: max,
            ..self
        }
    }

    /// Sets the total connection limit (`0` disables the check).
    #[allow(dead_code)]
    pub fn max_total_connections(self, max: usize) -> Self {
        Self {
            max_total_connections: max,
            ..self
        }
    }

    /// Sets how long a connection may go without activity before it is idle.
    #[allow(dead_code)]
    pub fn idle_timeout(self, timeout: Duration) -> Self {
        Self {
            idle_timeout: timeout,
            ..self
        }
    }

    /// Sets how often the background reaper runs.
    #[allow(dead_code)]
    pub fn reaper_interval(self, period: Duration) -> Self {
        Self {
            reaper_interval: period,
            ..self
        }
    }

    /// Produces the finished configuration.
    pub fn build(self) -> ConnectionLimitConfig {
        let Self {
            max_connections_per_ip,
            max_total_connections,
            idle_timeout,
            reaper_interval,
        } = self;
        ConnectionLimitConfig {
            max_connections_per_ip,
            max_total_connections,
            idle_timeout,
            reaper_interval,
        }
    }
}

impl Default for ConnectionLimitConfigBuilder {
    /// Same as [`ConnectionLimitConfigBuilder::new`].
    fn default() -> Self {
        Self::new()
    }
}
pub fn from_server_config(config: &rusmes_config::ServerConfig) -> ConnectionLimiter {
let mut builder = ConnectionLimitConfigBuilder::new();
if let Some(conn_limits) = &config.connection_limits {
builder = builder.max_connections_per_ip(conn_limits.max_connections_per_ip);
builder = builder.max_total_connections(conn_limits.max_total_connections);
if let Ok(idle_secs) = conn_limits.idle_timeout_seconds() {
builder = builder.idle_timeout(Duration::from_secs(idle_secs));
}
if let Ok(reaper_secs) = conn_limits.reaper_interval_seconds() {
builder = builder.reaper_interval(Duration::from_secs(reaper_secs));
}
} else {
if let Some(rate_limit) = &config.smtp.rate_limit {
builder = builder.max_connections_per_ip(rate_limit.max_connections_per_ip);
}
builder = builder.idle_timeout(Duration::from_secs(300));
}
let limit_config = builder.build();
ConnectionLimiter::new(limit_config)
}
#[allow(dead_code)]
fn create_connection_limiter_from_config(
config: &rusmes_config::ServerConfig,
) -> ConnectionLimiter {
from_server_config(config)
}