use std::collections::HashMap;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};
use crate::ConnectionEvent;
use super::ConnectionEventFn;
/// Guard for one connection slot held by a single peer.
///
/// Created by [`ConnectionTracker::acquire`]. Dropping it decrements the peer's
/// entry in `counts` (removing the entry when it reaches zero) and decrements
/// the global `total`.
pub(crate) struct ConnectionTracker {
    /// Shared per-peer count of live trackers; this guard's `peer` entry is
    /// decremented on drop.
    counts: Arc<Mutex<HashMap<iroh::PublicKey, usize>>>,
    /// Key into `counts` that this guard counts against.
    peer: iroh::PublicKey,
    /// Copied into the `peer_id` field of emitted [`ConnectionEvent`]s.
    peer_id_str: String,
    /// Optional callback invoked on the peer's 0 -> 1 and 1 -> 0 transitions.
    on_event: Option<ConnectionEventFn>,
    /// Count of live trackers across all peers.
    total: Arc<AtomicUsize>,
}
impl ConnectionTracker {
    /// Tries to take one more connection slot for `peer`, returning a guard that
    /// gives the slot back when dropped.
    ///
    /// * `counts` - shared per-peer counts; the returned guard keeps a clone of the `Arc`.
    /// * `peer` - key into `counts`.
    /// * `peer_id_str` - copied into the `peer_id` field of emitted [`ConnectionEvent`]s.
    /// * `max_per_peer` - maximum number of simultaneous guards for one peer;
    ///   `0` rejects every attempt.
    /// * `on_event` - optional callback; receives a `connected: true` event when this
    ///   call moves the peer's count from 0 to 1.
    /// * `total` - global counter, incremented on success and decremented by the
    ///   guard's `Drop`.
    ///
    /// Returns `None` when the peer already holds `max_per_peer` slots. In that case
    /// neither `counts` nor `total` is modified.
    ///
    /// A poisoned `counts` lock is recovered (the inner map is used as-is) rather
    /// than propagated as a panic.
    #[must_use = "dropping the tracker immediately releases the connection slot"]
    pub(crate) fn acquire(
        counts: &Arc<Mutex<HashMap<iroh::PublicKey, usize>>>,
        peer: iroh::PublicKey,
        peer_id_str: String,
        max_per_peer: usize,
        on_event: Option<ConnectionEventFn>,
        total: Arc<AtomicUsize>,
    ) -> Option<Self> {
        let mut map = counts.lock().unwrap_or_else(|e| e.into_inner());
        // Look up before inserting. A rejected attempt must not leave a `0` entry
        // behind (e.g. with `max_per_peer == 0`); otherwise the map would grow by
        // one entry for every distinct peer that is turned away.
        let current = map.get(&peer).copied().unwrap_or(0);
        if current >= max_per_peer {
            return None;
        }
        // `current < max_per_peer <= usize::MAX`, so `current + 1` cannot overflow.
        map.insert(peer, current + 1);
        drop(map);
        total.fetch_add(1, Ordering::Relaxed);
        // The callback runs after the lock is released. The peer's count cannot
        // return to 0 until the guard built below is dropped, so the matching
        // "disconnected" event cannot overtake this one.
        if current == 0 {
            if let Some(cb) = &on_event {
                cb(ConnectionEvent {
                    peer_id: peer_id_str.clone(),
                    connected: true,
                });
            }
        }
        Some(ConnectionTracker {
            counts: counts.clone(),
            peer,
            peer_id_str,
            on_event,
            total,
        })
    }
}
impl Drop for ConnectionTracker {
    /// Gives this guard's slot back.
    ///
    /// Decrements `total`, then decrements the peer's entry in `counts`. When that
    /// entry reaches zero it is removed, and `on_event` (if set) receives a
    /// `connected: false` event.
    ///
    /// The callback runs while the `counts` lock is still held. That lock is what
    /// serializes it against `acquire` for the same peer: a subsequent 0 -> 1
    /// increment, and therefore its "connected" event, can only happen after this
    /// callback returns.
    ///
    /// The flip side is that the callback must not lock `counts` itself, because
    /// `std::sync::Mutex` is not reentrant.
    fn drop(&mut self) {
        self.total.fetch_sub(1, Ordering::Relaxed);
        let mut guard = self
            .counts
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());

        let Some(remaining) = guard.get_mut(&self.peer) else {
            return;
        };
        *remaining = remaining.saturating_sub(1);
        if *remaining != 0 {
            return;
        }

        guard.remove(&self.peer);
        if let Some(callback) = self.on_event.as_ref() {
            callback(ConnectionEvent {
                connected: false,
                peer_id: self.peer_id_str.clone(),
            });
        }
    }
}
/// Guard for one request.
///
/// Dropping it decrements both `counter` and `in_flight`. The drop that moves
/// `in_flight` from 1 to 0 wakes the `drain_notify` waiters.
pub(crate) struct RequestTracker {
    /// Decremented (relaxed ordering) on drop.
    // NOTE(review): how this differs in scope from `in_flight` is not visible
    // here; confirm at the construction sites.
    counter: Arc<AtomicUsize>,
    /// Decremented on drop; reaching zero triggers `drain_notify`.
    in_flight: Arc<AtomicUsize>,
    /// Woken via `notify_waiters` when `in_flight` reaches zero.
    drain_notify: Arc<tokio::sync::Notify>,
}
impl RequestTracker {
    /// Wraps the request counters in a guard that decrements them when dropped.
    ///
    /// This constructor does not modify the counters. The caller is presumably
    /// expected to have already incremented `counter` and `in_flight` for this
    /// request (the unit test sets both to 1 before constructing). TODO: confirm
    /// at the call sites.
    #[must_use = "dropping the tracker immediately decrements the request counters"]
    pub(crate) fn new(
        counter: Arc<AtomicUsize>,
        in_flight: Arc<AtomicUsize>,
        drain_notify: Arc<tokio::sync::Notify>,
    ) -> Self {
        Self {
            counter,
            in_flight,
            drain_notify,
        }
    }
}
impl Drop for RequestTracker {
    /// Releases this request from both counters.
    ///
    /// Only the drop that moves `in_flight` from 1 to 0 wakes `drain_notify`.
    ///
    /// `notify_waiters` stores no permit. Per the tokio docs, it only wakes
    /// `Notified` futures that already exist at the time of the call.
    /// NOTE(review): drain waiters should therefore create their `Notified`
    /// before re-checking `in_flight`; confirm on the waiting side.
    fn drop(&mut self) {
        self.counter.fetch_sub(1, Ordering::Relaxed);
        let was_last_in_flight = self.in_flight.fetch_sub(1, Ordering::AcqRel) == 1;
        if was_last_in_flight {
            self.drain_notify.notify_waiters();
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh random peer key, so tests never share map entries.
    fn dummy_peer() -> iroh::PublicKey {
        iroh::SecretKey::generate().public()
    }

    #[test]
    fn connection_tracker_increments_and_decrements_total() {
        let total = Arc::new(AtomicUsize::new(0));
        let counts = Arc::new(Mutex::new(HashMap::new()));
        let peer = dummy_peer();
        {
            let _t =
                ConnectionTracker::acquire(&counts, peer, "p".to_string(), 4, None, total.clone())
                    .expect("acquire should succeed under cap");
            assert_eq!(total.load(Ordering::Relaxed), 1);
        }
        assert_eq!(total.load(Ordering::Relaxed), 0);
        assert!(counts.lock().unwrap().is_empty());
    }

    #[test]
    fn connection_tracker_enforces_per_peer_cap() {
        let total = Arc::new(AtomicUsize::new(0));
        let counts = Arc::new(Mutex::new(HashMap::new()));
        let peer = dummy_peer();
        let _a =
            ConnectionTracker::acquire(&counts, peer, "p".into(), 1, None, total.clone()).unwrap();
        let b = ConnectionTracker::acquire(&counts, peer, "p".into(), 1, None, total.clone());
        assert!(b.is_none(), "second acquire over cap must fail");
        assert_eq!(total.load(Ordering::Relaxed), 1);
        // A rejected attempt must leave the existing count untouched.
        assert_eq!(counts.lock().unwrap().get(&peer), Some(&1));
    }

    #[test]
    fn connection_tracker_keeps_entry_until_last_guard_drops() {
        let total = Arc::new(AtomicUsize::new(0));
        let counts = Arc::new(Mutex::new(HashMap::new()));
        let peer = dummy_peer();
        let first =
            ConnectionTracker::acquire(&counts, peer, "p".into(), 2, None, total.clone()).unwrap();
        let second =
            ConnectionTracker::acquire(&counts, peer, "p".into(), 2, None, total.clone()).unwrap();
        assert_eq!(total.load(Ordering::Relaxed), 2);
        assert_eq!(counts.lock().unwrap().get(&peer), Some(&2));

        drop(first);
        assert_eq!(total.load(Ordering::Relaxed), 1);
        assert_eq!(counts.lock().unwrap().get(&peer), Some(&1));

        drop(second);
        assert_eq!(total.load(Ordering::Relaxed), 0);
        assert!(counts.lock().unwrap().is_empty());
    }

    #[tokio::test]
    async fn request_tracker_notifies_when_in_flight_reaches_zero() {
        let counter = Arc::new(AtomicUsize::new(1));
        let in_flight = Arc::new(AtomicUsize::new(1));
        let drain = Arc::new(tokio::sync::Notify::new());
        // Per the tokio docs, a `Notified` future receives `notify_waiters` wakeups
        // from the moment it is created, even before it is first polled. Creating
        // it here, before the drop, removes any dependence on task-scheduling
        // order (the previous spawn + `yield_now` approach relied on that ordering).
        let notified = drain.notified();
        let tracker = RequestTracker::new(counter.clone(), in_flight.clone(), drain.clone());
        drop(tracker);
        tokio::time::timeout(std::time::Duration::from_millis(100), notified)
            .await
            .expect("waiter must be notified");
        assert_eq!(counter.load(Ordering::Relaxed), 0);
        assert_eq!(in_flight.load(Ordering::Relaxed), 0);
    }
}