use std::net::SocketAddr;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use chrono::{DateTime, Utc};
use dashmap::DashMap;
use tokio::sync::broadcast;
use tracing::{debug, trace};
use crate::connection_core::ShardRouter;
use super::config::ConnectionPoolConfig;
/// Identifier assigned to each registered connection.
///
/// Ids are handed out by the pool from an increasing counter that starts at 1, and the id
/// (not the peer address) decides which shard stores the connection.
pub type ConnectionId = u64;
/// Lifecycle state stored on each `ShardedConnectionInfo`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ConnectionState {
    /// State given to every freshly registered connection.
    Active,
    /// Idle state. NOTE(review): nothing in this module assigns it; presumably set by callers.
    Idle,
    /// Set by `ShardedConnectionInfo::drain`.
    Draining,
    /// Set by `ShardedConnectionInfo::close`.
    Closing,
}
/// Bookkeeping for a single connection, stored in its shard and cloned out to callers.
#[derive(Clone, Debug)]
pub struct ShardedConnectionInfo {
    /// Pool-assigned identifier.
    pub id: ConnectionId,
    /// Peer address supplied at registration.
    pub peer_addr: SocketAddr,
    /// Index of the shard that owns this entry.
    pub shard_index: usize,
    /// Current lifecycle state.
    pub state: ConnectionState,
    /// Wall-clock registration time, for reporting.
    pub connected_at: DateTime<Utc>,
    /// Monotonic registration time; basis of `duration()`.
    connected_instant: Instant,
    /// Monotonic time of the most recent recorded request (or registration); basis of
    /// `idle_duration()`.
    last_activity_instant: Instant,
    /// Per-connection request tallies.
    pub requests: RequestCounters,
    /// Per-connection byte tallies.
    pub bytes: ByteCounters,
}
/// Request tallies for one connection.
#[derive(Default, Debug, Clone)]
pub struct RequestCounters {
    /// Every recorded request, successful or not.
    pub total: u64,
    /// Requests recorded as successful.
    pub success: u64,
    /// Requests recorded as failed.
    pub failed: u64,
}
/// Byte tallies for one connection.
#[derive(Default, Debug, Clone)]
pub struct ByteCounters {
    /// Sum of the `bytes_in` values recorded for this connection.
    pub received: u64,
    /// Sum of the `bytes_out` values recorded for this connection.
    pub sent: u64,
}
impl ShardedConnectionInfo {
    /// Creates an `Active` entry with zeroed counters, timestamped now.
    fn new(id: ConnectionId, peer_addr: SocketAddr, shard_index: usize) -> Self {
        let now = Instant::now();
        Self {
            id,
            peer_addr,
            shard_index,
            state: ConnectionState::Active,
            connected_at: Utc::now(),
            connected_instant: now,
            last_activity_instant: now,
            requests: RequestCounters::default(),
            bytes: ByteCounters::default(),
        }
    }

    /// Time elapsed since registration, measured on the monotonic clock.
    #[inline]
    pub fn duration(&self) -> Duration {
        self.connected_instant.elapsed()
    }

    /// Time elapsed since the last recorded request, or since registration if none.
    #[inline]
    pub fn idle_duration(&self) -> Duration {
        self.last_activity_instant.elapsed()
    }

    /// Returns `true` when the connection has been idle for strictly longer than `threshold`.
    #[inline]
    pub fn is_idle(&self, threshold: Duration) -> bool {
        self.idle_duration() > threshold
    }

    /// Records a successful request: bumps counters and refreshes the activity timestamp.
    ///
    /// An `Idle` connection is promoted back to `Active`. `Draining` and `Closing` are left
    /// untouched: a request completing on a connection that is being shut down must not
    /// silently undo `drain()` / `close()`.
    #[inline]
    pub fn record_success(&mut self, bytes_in: u64, bytes_out: u64) {
        self.requests.total += 1;
        self.requests.success += 1;
        self.bytes.received += bytes_in;
        self.bytes.sent += bytes_out;
        self.last_activity_instant = Instant::now();
        if self.state == ConnectionState::Idle {
            self.state = ConnectionState::Active;
        }
    }

    /// Records a failed request: bumps counters and refreshes the activity timestamp.
    ///
    /// The lifecycle state is not changed.
    #[inline]
    pub fn record_failure(&mut self, bytes_in: u64) {
        self.requests.total += 1;
        self.requests.failed += 1;
        self.bytes.received += bytes_in;
        self.last_activity_instant = Instant::now();
    }

    /// Marks the connection as `Draining`.
    #[inline]
    pub fn drain(&mut self) {
        self.state = ConnectionState::Draining;
    }

    /// Marks the connection as `Closing`.
    #[inline]
    pub fn close(&mut self) {
        self.state = ConnectionState::Closing;
    }
}
/// One partition of the pool: its connection entries plus a peer-address → id index.
pub struct ConnectionShard {
    /// Connection entries keyed by id.
    connections: DashMap<ConnectionId, ShardedConnectionInfo>,
    /// Reverse index consulted by `get_by_addr`; holds at most one id per address.
    addr_to_id: DashMap<SocketAddr, ConnectionId>,
    /// Entry count, maintained by `insert` / `remove`.
    active_count: AtomicUsize,
    /// Position of this shard within the pool.
    index: usize,
}
impl ConnectionShard {
    /// Creates an empty shard with both maps pre-sized for `estimated_capacity` entries.
    ///
    /// The capacity is only a preallocation hint; the shard never refuses inserts.
    fn new(index: usize, estimated_capacity: usize) -> Self {
        Self {
            connections: DashMap::with_capacity(estimated_capacity),
            addr_to_id: DashMap::with_capacity(estimated_capacity),
            active_count: AtomicUsize::new(0),
            index,
        }
    }

    /// Number of entries currently held by this shard.
    #[inline]
    pub fn len(&self) -> usize {
        self.active_count.load(Ordering::Relaxed)
    }

    /// `true` when the shard holds no entries.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Inserts a fresh entry for `id` and returns a copy of it.
    ///
    /// If another live connection already uses `peer_addr`, the address index is
    /// re-pointed at the newer `id` (the index keeps one id per address).
    fn insert(&self, id: ConnectionId, peer_addr: SocketAddr) -> ShardedConnectionInfo {
        let info = ShardedConnectionInfo::new(id, peer_addr, self.index);
        self.connections.insert(id, info.clone());
        self.addr_to_id.insert(peer_addr, id);
        self.active_count.fetch_add(1, Ordering::Relaxed);
        info
    }

    /// Removes the entry for `id`, returning it if it was present.
    fn remove(&self, id: ConnectionId) -> Option<ShardedConnectionInfo> {
        let (_, info) = self.connections.remove(&id)?;
        // Only drop the address mapping if it still points at *this* connection. When two
        // connections share a peer address the index holds the newer id; removing it
        // unconditionally would make that surviving connection unreachable by address.
        self.addr_to_id
            .remove_if(&info.peer_addr, |_, mapped| *mapped == id);
        self.active_count.fetch_sub(1, Ordering::Relaxed);
        Some(info)
    }

    /// Returns a copy of the entry for `id`.
    fn get(&self, id: ConnectionId) -> Option<ShardedConnectionInfo> {
        self.connections.get(&id).map(|entry| entry.value().clone())
    }

    /// Returns a copy of the entry currently indexed under `addr`.
    fn get_by_addr(&self, addr: &SocketAddr) -> Option<ShardedConnectionInfo> {
        // Copy the id out first so the address-index guard is released before the
        // `connections` map is locked.
        let id = *self.addr_to_id.get(addr)?;
        self.get(id)
    }

    /// Applies `f` to the entry for `id` while holding its `get_mut` guard.
    ///
    /// Returns `false` (and does not call `f`) when no such entry exists.
    fn update<F>(&self, id: ConnectionId, f: F) -> bool
    where
        F: FnOnce(&mut ShardedConnectionInfo),
    {
        match self.connections.get_mut(&id) {
            Some(mut entry) => {
                f(entry.value_mut());
                true
            }
            None => false,
        }
    }

    /// Ids of entries idle for strictly longer than `threshold`.
    fn get_idle(&self, threshold: Duration) -> Vec<ConnectionId> {
        self.connections
            .iter()
            .filter(|entry| entry.value().is_idle(threshold))
            .map(|entry| *entry.key())
            .collect()
    }
}
/// Notification broadcast to receivers obtained from `ShardedConnectionPool::subscribe`.
#[derive(Clone, Debug)]
pub enum PoolEvent {
    /// A connection was admitted and stored in shard `shard_index`.
    Connected {
        connection_id: ConnectionId,
        peer_addr: SocketAddr,
        shard_index: usize,
    },
    /// A connection was removed; `duration` and `requests_total` are its final figures.
    Disconnected {
        connection_id: ConnectionId,
        peer_addr: SocketAddr,
        duration: Duration,
        requests_total: u64,
    },
    /// A registration was refused; `reason` is a human-readable explanation.
    Rejected {
        peer_addr: SocketAddr,
        reason: String,
    },
    /// Shard rebalance report. NOTE(review): not emitted anywhere in this module.
    Rebalance {
        shard_index: usize,
        count_before: usize,
        count_after: usize,
    },
}
/// Owning handle for a registered connection; dropping it unregisters the connection.
pub struct ConnectionHandle {
    /// Pool the connection belongs to; the handle keeps it alive.
    pool: Arc<ShardedConnectionPool>,
    /// Id of the connection this handle owns.
    id: ConnectionId,
    /// Shard that stores the connection's entry.
    shard_index: usize,
}
impl ConnectionHandle {
#[inline]
pub fn id(&self) -> ConnectionId {
self.id
}
#[inline]
pub fn shard_index(&self) -> usize {
self.shard_index
}
#[inline]
pub fn record_success(&self, bytes_in: u64, bytes_out: u64) {
self.pool.record_success(self.id, bytes_in, bytes_out);
}
#[inline]
pub fn record_failure(&self, bytes_in: u64) {
self.pool.record_failure(self.id, bytes_in);
}
pub fn info(&self) -> Option<ShardedConnectionInfo> {
self.pool.get(self.id)
}
pub fn drain(&self) {
self.pool.shards[self.shard_index].update(self.id, |info| {
info.drain();
});
}
}
impl Drop for ConnectionHandle {
    /// Unregisters the connection from its pool.
    ///
    /// `unregister` yields `None` (and adjusts no counters) when the entry is already gone,
    /// e.g. after `cleanup_idle` evicted it, so a late drop is harmless.
    fn drop(&mut self) {
        let _ = self.pool.unregister(self.id);
    }
}
/// Snapshot returned by `ShardedConnectionPool::statistics`.
///
/// Each counter is loaded independently, so under concurrent activity the fields are not
/// guaranteed to be mutually consistent.
#[derive(Default, Debug, Clone)]
pub struct PoolStatistics {
    /// Connections admitted since the pool was created.
    pub total_connections: u64,
    /// Connections currently registered.
    pub active_connections: usize,
    /// Registrations refused because the pool was at capacity.
    pub total_rejected: u64,
    /// Requests recorded pool-wide (successes and failures).
    pub total_requests: u64,
    /// Sum of all recorded `bytes_in`.
    pub total_bytes_received: u64,
    /// Sum of all recorded `bytes_out`.
    pub total_bytes_sent: u64,
    /// One entry per shard, in shard order.
    pub per_shard: Vec<ShardStatistics>,
    /// Time since the pool was constructed.
    pub uptime: Duration,
}
/// Per-shard entry inside `PoolStatistics`.
#[derive(Default, Debug, Clone)]
pub struct ShardStatistics {
    /// Shard position.
    pub index: usize,
    /// Connections currently held by the shard.
    pub active_connections: usize,
    /// `active_connections / connections_per_shard`. Shards are not hard-capped, so this
    /// can exceed 1.0.
    pub utilization: f64,
}
/// Connection registry split across `ConnectionShard`s, with pool-wide counters and a
/// broadcast channel of `PoolEvent`s.
pub struct ShardedConnectionPool {
    /// Configuration the pool was built from.
    config: ConnectionPoolConfig,
    /// Shard storage; a connection lives in the shard its id routes to.
    shards: Vec<ConnectionShard>,
    /// Maps ids (and addresses) to shard indices.
    router: ShardRouter,
    /// Next id to hand out; starts at 1.
    next_id: AtomicU64,
    /// Connections admitted since creation.
    total_connections: AtomicU64,
    /// Registrations refused at capacity.
    total_rejected: AtomicU64,
    /// Connections currently registered; compared against `max_connections` on admission.
    active_connections: AtomicUsize,
    /// Requests recorded pool-wide.
    total_requests: AtomicU64,
    /// Sum of recorded `bytes_in`.
    total_bytes_received: AtomicU64,
    /// Sum of recorded `bytes_out`.
    total_bytes_sent: AtomicU64,
    /// Construction time, for `uptime()`.
    created_at: Instant,
    /// Sender side of the event channel.
    event_tx: broadcast::Sender<PoolEvent>,
    /// Admission cap, copied from `config.max_connections`.
    max_connections: usize,
}
impl ShardedConnectionPool {
    /// Builds a pool with `config.shard_count` shards, each pre-sized for
    /// `config.connections_per_shard` entries, plus a 4096-slot event channel.
    ///
    /// The initial receiver is dropped immediately, so events sent before anyone calls
    /// `subscribe` have no receivers and are discarded.
    pub fn new(config: ConnectionPoolConfig) -> Arc<Self> {
        let shard_count = config.shard_count;
        let connections_per_shard = config.connections_per_shard;
        let shards: Vec<_> = (0..shard_count)
            .map(|i| ConnectionShard::new(i, connections_per_shard))
            .collect();
        let (event_tx, _) = broadcast::channel(4096);
        Arc::new(Self {
            max_connections: config.max_connections,
            config,
            shards,
            router: ShardRouter::new(shard_count),
            next_id: AtomicU64::new(1),
            total_connections: AtomicU64::new(0),
            total_rejected: AtomicU64::new(0),
            active_connections: AtomicUsize::new(0),
            total_requests: AtomicU64::new(0),
            total_bytes_received: AtomicU64::new(0),
            total_bytes_sent: AtomicU64::new(0),
            created_at: Instant::now(),
            event_tx,
        })
    }

    /// Pool built from `ConnectionPoolConfig::default()`.
    pub fn default_medium() -> Arc<Self> {
        Self::new(ConnectionPoolConfig::default())
    }

    /// Shard that stores connection `id` (placement is by id).
    #[inline]
    fn shard_index(&self, id: ConnectionId) -> usize {
        self.router.index_for_id(id)
    }

    /// Router's shard guess for an address.
    ///
    /// Connections are placed by id, so this is only a probe-order hint for `get_by_addr`.
    #[inline]
    fn shard_index_for_addr(&self, addr: &SocketAddr) -> usize {
        self.router.index_for_addr(addr)
    }

    /// Registers `peer_addr` if the pool is below `max_connections`.
    ///
    /// Emits `PoolEvent::Connected` on success. Emits `PoolEvent::Rejected` and bumps the
    /// rejection counter when full, returning `None`. The returned handle unregisters the
    /// connection when dropped.
    pub fn try_register(self: &Arc<Self>, peer_addr: SocketAddr) -> Option<ConnectionHandle> {
        let max = self.max_connections;
        // Reserve the slot with one atomic read-modify-write. A separate load / compare /
        // increment lets concurrent callers all see `current < max` and overshoot the cap.
        let reservation = self
            .active_connections
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
                if current < max {
                    Some(current + 1)
                } else {
                    None
                }
            });
        if let Err(current) = reservation {
            self.total_rejected.fetch_add(1, Ordering::Relaxed);
            let _ = self.event_tx.send(PoolEvent::Rejected {
                peer_addr,
                reason: format!("Pool at capacity ({}/{})", current, self.max_connections),
            });
            trace!(
                peer = %peer_addr,
                current,
                max = self.max_connections,
                "Connection rejected: pool at capacity"
            );
            return None;
        }
        let id = self.next_id.fetch_add(1, Ordering::Relaxed);
        let shard_index = self.shard_index(id);
        let info = self.shards[shard_index].insert(id, peer_addr);
        self.total_connections.fetch_add(1, Ordering::Relaxed);
        let _ = self.event_tx.send(PoolEvent::Connected {
            connection_id: id,
            peer_addr,
            shard_index,
        });
        debug!(
            peer = %peer_addr,
            connection_id = id,
            shard = shard_index,
            active = self.active_connections.load(Ordering::Relaxed),
            "Connection registered"
        );
        Some(ConnectionHandle {
            pool: Arc::clone(self),
            id,
            shard_index: info.shard_index,
        })
    }

    /// Like `try_register`, but panics instead of returning `None`.
    ///
    /// # Panics
    ///
    /// Panics with "Connection pool at capacity" when the pool is full.
    pub fn register(self: &Arc<Self>, peer_addr: SocketAddr) -> ConnectionHandle {
        self.try_register(peer_addr)
            .expect("Connection pool at capacity")
    }

    /// Removes connection `id`, emitting `PoolEvent::Disconnected` with its final figures.
    ///
    /// Returns `None` (changing no counters) if the connection is not registered.
    pub fn unregister(&self, id: ConnectionId) -> Option<ShardedConnectionInfo> {
        let shard_index = self.shard_index(id);
        let info = self.shards[shard_index].remove(id)?;
        self.active_connections.fetch_sub(1, Ordering::Relaxed);
        // Sample once so the event and the log report the same lifetime.
        let lifetime = info.duration();
        let _ = self.event_tx.send(PoolEvent::Disconnected {
            connection_id: id,
            peer_addr: info.peer_addr,
            duration: lifetime,
            requests_total: info.requests.total,
        });
        debug!(
            peer = %info.peer_addr,
            connection_id = id,
            duration_ms = lifetime.as_millis(),
            requests = info.requests.total,
            "Connection unregistered"
        );
        Some(info)
    }

    /// Snapshot of connection `id`, if registered.
    pub fn get(&self, id: ConnectionId) -> Option<ShardedConnectionInfo> {
        let shard_index = self.shard_index(id);
        self.shards[shard_index].get(id)
    }

    /// Snapshot of a connection registered under `addr`, if any.
    ///
    /// Placement is by id, so the address-routed shard is only a first guess; the remaining
    /// shards are then scanned in index order.
    pub fn get_by_addr(&self, addr: &SocketAddr) -> Option<ShardedConnectionInfo> {
        let likely_shard = self.shard_index_for_addr(addr);
        self.shards[likely_shard].get_by_addr(addr).or_else(|| {
            self.shards
                .iter()
                .enumerate()
                .filter(|&(i, _)| i != likely_shard)
                .find_map(|(_, shard)| shard.get_by_addr(addr))
        })
    }

    /// Records a successful request on `id` and adds it to the pool-wide totals.
    ///
    /// NOTE(review): totals are incremented even if `id` is no longer registered (e.g. after
    /// `cleanup_idle`); confirm that is the intended accounting.
    #[inline]
    pub fn record_success(&self, id: ConnectionId, bytes_in: u64, bytes_out: u64) {
        let shard_index = self.shard_index(id);
        self.shards[shard_index].update(id, |info| {
            info.record_success(bytes_in, bytes_out);
        });
        self.total_requests.fetch_add(1, Ordering::Relaxed);
        self.total_bytes_received
            .fetch_add(bytes_in, Ordering::Relaxed);
        self.total_bytes_sent
            .fetch_add(bytes_out, Ordering::Relaxed);
    }

    /// Records a failed request on `id` and adds it to the pool-wide totals.
    ///
    /// As with `record_success`, totals are incremented even if `id` is not registered.
    #[inline]
    pub fn record_failure(&self, id: ConnectionId, bytes_in: u64) {
        let shard_index = self.shard_index(id);
        self.shards[shard_index].update(id, |info| {
            info.record_failure(bytes_in);
        });
        self.total_requests.fetch_add(1, Ordering::Relaxed);
        self.total_bytes_received
            .fetch_add(bytes_in, Ordering::Relaxed);
    }

    /// Number of currently registered connections (including reserved-but-inserting slots).
    #[inline]
    pub fn active_count(&self) -> usize {
        self.active_connections.load(Ordering::Relaxed)
    }

    /// Admission cap.
    #[inline]
    pub fn max_connections(&self) -> usize {
        self.max_connections
    }

    /// Number of shards, as reported by the router.
    #[inline]
    pub fn shard_count(&self) -> usize {
        self.router.shard_count()
    }

    /// Time since the pool was constructed.
    #[inline]
    pub fn uptime(&self) -> Duration {
        self.created_at.elapsed()
    }

    /// New receiver for pool events sent from now on.
    pub fn subscribe(&self) -> broadcast::Receiver<PoolEvent> {
        self.event_tx.subscribe()
    }

    /// Ids of connections idle for strictly longer than `threshold`, shard by shard.
    pub fn get_idle_connections(&self, threshold: Duration) -> Vec<ConnectionId> {
        self.shards
            .iter()
            .flat_map(|shard| shard.get_idle(threshold))
            .collect()
    }

    /// Unregisters every connection idle for longer than `threshold`.
    ///
    /// Returns how many connections were actually removed by this call. Ids that were
    /// unregistered concurrently between the scan and the removal are not counted.
    /// NOTE(review): a connection that becomes active after the scan is still removed.
    pub fn cleanup_idle(&self, threshold: Duration) -> usize {
        let mut removed = 0;
        for id in self.get_idle_connections(threshold) {
            if self.unregister(id).is_some() {
                removed += 1;
            }
        }
        if removed > 0 {
            debug!(count = removed, "Cleaned up idle connections");
        }
        removed
    }

    /// Snapshot of pool-wide counters and per-shard occupancy.
    ///
    /// Each counter is loaded independently, so the result is not an atomic snapshot under
    /// concurrent activity. Per-shard utilization divides by `connections_per_shard`, which
    /// is only a sizing hint, so it can exceed 1.0.
    pub fn statistics(&self) -> PoolStatistics {
        let per_shard: Vec<_> = self
            .shards
            .iter()
            .enumerate()
            .map(|(i, shard)| {
                let active = shard.len();
                ShardStatistics {
                    index: i,
                    active_connections: active,
                    utilization: active as f64 / self.config.connections_per_shard as f64,
                }
            })
            .collect();
        PoolStatistics {
            total_connections: self.total_connections.load(Ordering::Relaxed),
            active_connections: self.active_count(),
            total_rejected: self.total_rejected.load(Ordering::Relaxed),
            total_requests: self.total_requests.load(Ordering::Relaxed),
            total_bytes_received: self.total_bytes_received.load(Ordering::Relaxed),
            total_bytes_sent: self.total_bytes_sent.load(Ordering::Relaxed),
            per_shard,
            uptime: self.uptime(),
        }
    }

    /// Configuration the pool was built from.
    pub fn config(&self) -> &ConnectionPoolConfig {
        &self.config
    }

    /// Iterates over clones of every registered connection, shard by shard.
    ///
    /// The iterator holds DashMap read guards while it walks a shard. Per DashMap's locking
    /// rules, unregistering on the same thread before the iterator is dropped may deadlock;
    /// that includes dropping a `ConnectionHandle`.
    pub fn iter_all(&self) -> impl Iterator<Item = ShardedConnectionInfo> + '_ {
        self.shards
            .iter()
            .flat_map(|shard| shard.connections.iter().map(|r| r.value().clone()))
    }

    /// Fraction of `max_connections` in use (NaN when `max_connections` is 0).
    #[inline]
    pub fn utilization(&self) -> f64 {
        self.active_count() as f64 / self.max_connections as f64
    }

    /// `true` when no further registrations would be admitted right now.
    #[inline]
    pub fn is_full(&self) -> bool {
        self.active_count() >= self.max_connections
    }

    /// Remaining admission slots (saturating at 0).
    #[inline]
    pub fn available(&self) -> usize {
        self.max_connections.saturating_sub(self.active_count())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{IpAddr, Ipv4Addr};

    /// Loopback address on `port`.
    fn make_addr(port: u16) -> SocketAddr {
        SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), port)
    }

    /// Pool with the given cap and shard count; shards are sized to fit the cap.
    fn make_pool(max_connections: usize, shard_count: usize) -> Arc<ShardedConnectionPool> {
        let config = ConnectionPoolConfig {
            max_connections,
            shard_count,
            connections_per_shard: max_connections / shard_count + 1,
            idle_timeout: Duration::from_secs(60),
            health_check_interval: Duration::from_secs(30),
            enable_metrics: true,
        };
        ShardedConnectionPool::new(config)
    }

    #[test]
    fn test_basic_registration() {
        let pool = make_pool(100, 4);
        let handle = pool.try_register(make_addr(1001));
        assert!(handle.is_some());
        let handle = handle.unwrap();
        assert_eq!(pool.active_count(), 1);
        let info = pool.get(handle.id());
        assert!(info.is_some());
        assert_eq!(info.unwrap().peer_addr.port(), 1001);
    }

    #[test]
    fn test_capacity_limit() {
        let pool = make_pool(3, 2);
        let h1 = pool.try_register(make_addr(1001));
        let h2 = pool.try_register(make_addr(1002));
        let h3 = pool.try_register(make_addr(1003));
        assert!(h1.is_some());
        assert!(h2.is_some());
        assert!(h3.is_some());
        let h4 = pool.try_register(make_addr(1004));
        assert!(h4.is_none());
        let stats = pool.statistics();
        assert_eq!(stats.total_rejected, 1);
    }

    #[test]
    #[should_panic(expected = "Connection pool at capacity")]
    fn test_register_panics_when_full() {
        let pool = make_pool(1, 1);
        let _first = pool.register(make_addr(1001));
        let _second = pool.register(make_addr(1002));
    }

    #[test]
    fn test_automatic_unregister_on_drop() {
        let pool = make_pool(100, 4);
        {
            let _handle = pool.try_register(make_addr(1001)).unwrap();
            assert_eq!(pool.active_count(), 1);
        }
        assert_eq!(pool.active_count(), 0);
    }

    #[test]
    fn test_request_tracking() {
        let pool = make_pool(100, 4);
        let handle = pool.try_register(make_addr(1001)).unwrap();
        handle.record_success(100, 200);
        handle.record_success(150, 250);
        handle.record_failure(50);
        let info = handle.info().unwrap();
        assert_eq!(info.requests.total, 3);
        assert_eq!(info.requests.success, 2);
        assert_eq!(info.requests.failed, 1);
        assert_eq!(info.bytes.received, 300);
        assert_eq!(info.bytes.sent, 450);
    }

    #[test]
    fn test_shard_distribution() {
        let pool = make_pool(1000, 8);
        let handles: Vec<_> = (0..100)
            .map(|i| pool.try_register(make_addr(1000 + i)).unwrap())
            .collect();
        let stats = pool.statistics();
        let total_in_shards: usize = stats.per_shard.iter().map(|s| s.active_connections).sum();
        assert_eq!(total_in_shards, 100);
        let non_empty_shards = stats
            .per_shard
            .iter()
            .filter(|s| s.active_connections > 0)
            .count();
        assert!(
            non_empty_shards >= 4,
            "Expected connections distributed across shards"
        );
        drop(handles);
    }

    #[test]
    fn test_idle_detection() {
        let pool = make_pool(100, 4);
        let handle = pool.try_register(make_addr(1001)).unwrap();
        // A freshly registered connection is not idle under a generous threshold.
        assert!(pool
            .get_idle_connections(Duration::from_millis(100))
            .is_empty());
        std::thread::sleep(Duration::from_millis(10));
        assert!(pool.get_idle_connections(Duration::from_secs(60)).is_empty());
        // After sleeping past a tiny threshold the connection must actually be reported.
        let idle = pool.get_idle_connections(Duration::from_millis(1));
        assert_eq!(idle, vec![handle.id()]);
    }

    #[test]
    fn test_cleanup_idle_then_drop_handle() {
        let pool = make_pool(100, 4);
        let handle = pool.try_register(make_addr(1001)).unwrap();
        std::thread::sleep(Duration::from_millis(10));
        assert_eq!(pool.cleanup_idle(Duration::from_millis(1)), 1);
        assert_eq!(pool.active_count(), 0);
        assert!(handle.info().is_none());
        // Dropping the handle after eviction must not decrement the counter a second time.
        drop(handle);
        assert_eq!(pool.active_count(), 0);
        assert_eq!(pool.cleanup_idle(Duration::from_millis(1)), 0);
    }

    #[test]
    fn test_get_by_addr() {
        let pool = make_pool(100, 4);
        let addr = make_addr(1001);
        let _handle = pool.try_register(addr).unwrap();
        let info = pool.get_by_addr(&addr);
        assert!(info.is_some());
        assert_eq!(info.unwrap().peer_addr, addr);
        let missing = pool.get_by_addr(&make_addr(9999));
        assert!(missing.is_none());
    }

    #[test]
    fn test_statistics() {
        let pool = make_pool(100, 4);
        let h1 = pool.try_register(make_addr(1001)).unwrap();
        let h2 = pool.try_register(make_addr(1002)).unwrap();
        h1.record_success(100, 200);
        h2.record_failure(50);
        let stats = pool.statistics();
        assert_eq!(stats.active_connections, 2);
        assert_eq!(stats.total_connections, 2);
        assert_eq!(stats.total_requests, 2);
        assert_eq!(stats.total_bytes_received, 150);
        assert_eq!(stats.total_bytes_sent, 200);
        assert_eq!(stats.per_shard.len(), 4);
    }

    #[test]
    fn test_utilization() {
        let pool = make_pool(100, 4);
        assert_eq!(pool.utilization(), 0.0);
        assert!(!pool.is_full());
        assert_eq!(pool.available(), 100);
        let handles: Vec<_> = (0..50)
            .map(|i| pool.try_register(make_addr(1000 + i)).unwrap())
            .collect();
        assert_eq!(pool.utilization(), 0.5);
        assert!(!pool.is_full());
        assert_eq!(pool.available(), 50);
        drop(handles);
    }

    #[tokio::test]
    async fn test_event_subscription() {
        let pool = make_pool(100, 4);
        let mut rx = pool.subscribe();
        let addr = make_addr(1001);
        let handle = pool.try_register(addr).unwrap();
        let id = handle.id();
        let event = rx.recv().await.unwrap();
        match event {
            PoolEvent::Connected {
                connection_id,
                peer_addr,
                ..
            } => {
                assert_eq!(connection_id, id);
                assert_eq!(peer_addr, addr);
            }
            _ => panic!("Expected Connected event"),
        }
        drop(handle);
        let event = rx.recv().await.unwrap();
        match event {
            PoolEvent::Disconnected {
                connection_id,
                peer_addr,
                ..
            } => {
                assert_eq!(connection_id, id);
                assert_eq!(peer_addr, addr);
            }
            _ => panic!("Expected Disconnected event"),
        }
    }

    #[test]
    fn test_power_of_two_sharding() {
        let pool = make_pool(100, 8);
        assert_eq!(pool.router.shard_count(), 8);
        assert_eq!(pool.shard_index(7), 7);
        assert_eq!(pool.shard_index(8), 0);
        let pool = make_pool(100, 6);
        assert_eq!(pool.router.shard_count(), 6);
        assert_eq!(pool.shard_index(7), 1);
        assert_eq!(pool.shard_index(12), 0);
    }
}