use std::{fmt, sync::Arc};
use tokio::sync::mpsc;
/// A notification sent over the status channel from a [`ConnectionTracker`]
/// to its [`ActiveConnectionCounter`].
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
enum ConnectionStatus {
/// Sent by `ConnectionTracker::mark_open`, at most once per tracker.
/// When processed, the counter moves one reservation into its open count.
Opened,
/// Sent when a `ConnectionTracker` is dropped.
/// When processed, the counter decrements its open count.
Closed,
}
/// Counts connections by draining the open/close notifications sent by the
/// [`ConnectionTracker`]s it hands out.
///
/// The counts only change when a tracker is created (a reservation is added)
/// or when `update_count` processes pending notifications.
pub struct ActiveConnectionCounter {
/// Connections whose `Opened` notification has been processed and whose
/// `Closed` notification has not.
count: usize,
/// Trackers that have been created but whose `Opened` notification has not
/// been processed yet.
reserved_count: usize,
/// The connection limit. In the code shown here it is only stored and
/// reported in log events and `Debug` output; nothing here enforces it.
limit: usize,
/// Label shared with every tracker and attached to log events.
label: Arc<str>,
/// Sender half of the status channel; cloned into each new tracker.
status_notification_tx: mpsc::UnboundedSender<ConnectionStatus>,
/// Receiver half of the status channel; drained by `update_count`.
status_notification_rx: mpsc::UnboundedReceiver<ConnectionStatus>,
/// Progress bar whose position is set to `count` by `update_count`.
/// Only present when the `progress-bar` feature is enabled.
#[cfg(feature = "progress-bar")]
connection_bar: howudoin::Tx,
}
impl fmt::Debug for ActiveConnectionCounter {
    /// Formats the counter as a struct showing its label, counts, and limit.
    ///
    /// The channel halves and the optional progress bar are left out of the
    /// output.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut summary = f.debug_struct("ActiveConnectionCounter");

        summary.field("label", &self.label);
        summary.field("count", &self.count);
        summary.field("reserved_count", &self.reserved_count);
        summary.field("limit", &self.limit);

        summary.finish()
    }
}
impl ActiveConnectionCounter {
    /// Builds a counter labelled `"Active Connections"` whose limit is
    /// `usize::MAX`.
    pub fn new_counter() -> Self {
        Self::new_counter_with(usize::MAX, "Active Connections")
    }

    /// Builds a counter with the given `limit` and `label`.
    ///
    /// The new counter starts with zero open and zero reserved connections.
    /// With the `progress-bar` feature enabled, a root progress bar carrying
    /// the same label is created as well.
    pub fn new_counter_with<S: ToString>(limit: usize, label: S) -> Self {
        // The counter keeps both channel halves: the sender is cloned into
        // every tracker, and the receiver is drained by `update_count`.
        let (tx, rx) = mpsc::unbounded_channel();

        let label_text = label.to_string();

        #[cfg(feature = "progress-bar")]
        let connection_bar = howudoin::new_root().label(label_text.clone());

        Self {
            count: 0,
            reserved_count: 0,
            limit,
            label: label_text.into(),
            status_notification_tx: tx,
            status_notification_rx: rx,
            #[cfg(feature = "progress-bar")]
            connection_bar,
        }
    }

    /// Creates a [`ConnectionTracker`] tied to this counter.
    ///
    /// Creating the tracker immediately adds one reservation to this counter.
    pub fn track_connection(&mut self) -> ConnectionTracker {
        ConnectionTracker::new(self)
    }

    /// Processes every pending tracker notification, then returns the number
    /// of open connections plus the number of reserved connections.
    ///
    /// With the `progress-bar` feature enabled, the bar position is set to
    /// the open connection count (reservations are not included).
    pub fn update_count(&mut self) -> usize {
        // Snapshot taken before draining: every log event emitted below
        // reports this same starting value as `previous_connections`.
        let count_before_update = self.count;

        loop {
            let status = match self.status_notification_rx.try_recv() {
                Ok(status) => status,
                // Any receive error ends the drain; the error itself is ignored.
                Err(_) => break,
            };

            match status {
                ConnectionStatus::Closed => {
                    self.count -= 1;

                    debug!(
                        open_connections = ?self.count,
                        previous_connections = ?count_before_update,
                        limit = ?self.limit,
                        label = ?self.label,
                        "a peer connection was closed",
                    );
                }
                ConnectionStatus::Opened => {
                    // The tracker's reservation turns into an open connection.
                    self.reserved_count -= 1;
                    self.count += 1;

                    debug!(
                        open_connections = ?self.count,
                        previous_connections = ?count_before_update,
                        limit = ?self.limit,
                        label = ?self.label,
                        "a peer connection was opened",
                    );
                }
            }
        }

        trace!(
            open_connections = ?self.count,
            previous_connections = ?count_before_update,
            limit = ?self.limit,
            label = ?self.label,
            "updated active connection count",
        );

        #[cfg(feature = "progress-bar")]
        {
            let position = u64::try_from(self.count).expect("fits in u64");
            self.connection_bar.set_pos(position);
        }

        self.count + self.reserved_count
    }
}
impl Drop for ActiveConnectionCounter {
/// Calls `close()` on the progress bar when the `progress-bar` feature is
/// enabled; without that feature this does nothing.
fn drop(&mut self) {
#[cfg(feature = "progress-bar")]
self.connection_bar.close();
}
}
/// A handle for one tracked connection, created by
/// `ActiveConnectionCounter::track_connection`.
///
/// Dropping the tracker sends a `Closed` notification to the counter.
pub struct ConnectionTracker {
/// Clone of the counter's sender, used to send status notifications.
status_notification_tx: mpsc::UnboundedSender<ConnectionStatus>,
/// Set once `Opened` has been sent, so it is sent at most once.
has_marked_open: bool,
/// Label copied from the counter; used in log events and `Debug` output.
label: Arc<str>,
}
impl fmt::Debug for ConnectionTracker {
    /// Formats the tracker as a tuple struct containing only its label.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut tuple = f.debug_tuple("ConnectionTracker");
        tuple.field(&self.label);
        tuple.finish()
    }
}
impl ConnectionTracker {
pub fn mark_open(&mut self) {
if !self.has_marked_open {
let _ = self.status_notification_tx.send(ConnectionStatus::Opened);
self.has_marked_open = true;
}
}
fn new(counter: &mut ActiveConnectionCounter) -> Self {
counter.reserved_count += 1;
debug!(
open_connections = ?counter.count,
limit = ?counter.limit,
label = ?counter.label,
"opening a new peer connection",
);
Self {
status_notification_tx: counter.status_notification_tx.clone(),
has_marked_open: false,
label: counter.label.clone(),
}
}
}
impl Drop for ConnectionTracker {
/// Notifies the counter that this connection has finished.
fn drop(&mut self) {
debug!(label = ?self.label, "closing a peer connection");
// Send `Opened` first if it was never sent, so the counter releases this
// tracker's reservation before it processes the `Closed` sent below.
self.mark_open();
// The send result is deliberately discarded.
let _ = self.status_notification_tx.send(ConnectionStatus::Closed);
}
}