use bytes::Bytes;
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use tokio::sync::Notify;
use crate::error::{Error, Result};
/// Number of independently locked partitions of the pending-request table (2^6 = 64).
const SHARDS: usize = 1 << 6;
/// Correlates outgoing UDP requests with their responses by request id.
///
/// Pending requests are spread across `SHARDS` partitions, each behind its own
/// mutex, so ids that map to different shards never contend on the same lock.
pub struct UdpCore {
    /// Fixed-size array of shards; a request id selects its shard by modulo.
    shards: Box<[Shard; SHARDS]>,
    /// Counters reported through `UdpCore::stats`.
    stats: CoreStats,
}
/// Internal lock-free counters backing [`TransportStats`] snapshots.
///
/// `Default` yields both counters at zero; `Debug` prints their current values.
#[derive(Debug, Default)]
struct CoreStats {
    /// Number of responses handed to a registered slot.
    delivered: AtomicU64,
    /// Number of waits that ended because their deadline passed.
    expired: AtomicU64,
}
/// Point-in-time copy of the transport counters.
///
/// `Default` is the all-zero snapshot; `PartialEq`/`Eq` allow comparing
/// snapshots directly (e.g. in tests or change detection).
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct TransportStats {
    /// Responses delivered to a registered request slot.
    pub delivered: u64,
    /// Waits that ended because their deadline passed.
    pub expired: u64,
}
/// One independently locked partition of the pending-request table.
struct Shard {
    /// Response slots keyed by request id.
    pending: Mutex<HashMap<i32, ResponseSlot>>,
}
struct ResponseSlot {
response: Option<(Bytes, SocketAddr)>,
deadline: Instant,
notify: Arc<Notify>,
}
impl UdpCore {
pub fn new() -> Self {
let shards: Vec<Shard> = (0..SHARDS)
.map(|_| Shard {
pending: Mutex::new(HashMap::new()),
})
.collect();
Self {
shards: shards
.try_into()
.unwrap_or_else(|_| unreachable!("Vec has exactly SHARDS elements")),
stats: CoreStats {
delivered: AtomicU64::new(0),
expired: AtomicU64::new(0),
},
}
}
fn shard(&self, request_id: i32) -> &Shard {
&self.shards[request_id as usize % SHARDS]
}
pub fn register(&self, request_id: i32, timeout: Duration) {
let shard = self.shard(request_id);
let slot = ResponseSlot {
response: None,
deadline: Instant::now() + timeout,
notify: Arc::new(Notify::new()),
};
shard.pending.lock().unwrap().insert(request_id, slot);
}
pub fn deliver(&self, request_id: i32, data: Bytes, source: SocketAddr) -> bool {
let shard = self.shard(request_id);
let mut pending = shard.pending.lock().unwrap();
if let Some(slot) = pending.get_mut(&request_id) {
slot.response = Some((data, source));
let notify = slot.notify.clone();
drop(pending);
notify.notify_one();
self.stats.delivered.fetch_add(1, Ordering::Relaxed);
return true;
}
false
}
pub async fn wait_for_response(
&self,
request_id: i32,
target: SocketAddr,
) -> Result<(Bytes, SocketAddr)> {
let shard = self.shard(request_id);
loop {
let (notify, deadline) = {
let mut pending = shard.pending.lock().unwrap();
if let Some(slot) = pending.get_mut(&request_id) {
if let Some(response) = slot.response.take() {
pending.remove(&request_id);
return Ok(response);
}
(slot.notify.clone(), slot.deadline)
} else {
tracing::debug!(target: "async_snmp::transport::udp", { request_id, %target, elapsed = ?Duration::ZERO }, "transport timeout (slot missing)");
return Err(Error::Timeout {
target,
elapsed: Duration::ZERO,
retries: 0,
}
.boxed());
}
};
let now = Instant::now();
if now >= deadline {
self.unregister(request_id);
self.stats.expired.fetch_add(1, Ordering::Relaxed);
let elapsed = now.saturating_duration_since(deadline - Duration::from_secs(1));
tracing::debug!(target: "async_snmp::transport::udp", { request_id, %target, ?elapsed }, "transport timeout");
return Err(Error::Timeout {
target,
elapsed,
retries: 0,
}
.boxed());
}
tokio::select! {
_ = notify.notified() => {
}
_ = tokio::time::sleep_until(tokio::time::Instant::from_std(deadline)) => {
}
}
}
}
pub fn stats(&self) -> TransportStats {
TransportStats {
delivered: self.stats.delivered.load(Ordering::Relaxed),
expired: self.stats.expired.load(Ordering::Relaxed),
}
}
pub fn unregister(&self, request_id: i32) {
let shard = self.shard(request_id);
shard.pending.lock().unwrap().remove(&request_id);
}
pub fn cleanup_expired(&self) {
let now = Instant::now();
for shard in self.shards.iter() {
let mut pending = shard.pending.lock().unwrap();
pending.retain(|_, slot| slot.deadline > now);
}
}
}
impl Default for UdpCore {
fn default() -> Self {
Self::new()
}
}