use crate::config::GlobalExecutor;
use dashmap::DashMap;
use std::{
collections::{BTreeMap, VecDeque},
io,
net::SocketAddr,
sync::{Arc, LazyLock, RwLock},
time::Duration,
};
use super::Socket;
use crate::simulation::{RealTime, TimeSource, VirtualTime};
/// Largest payload, in bytes, that `send_to` / `send_to_blocking` accept.
const MAX_PACKET_SIZE: usize = u16::MAX as usize;
/// Maps each socket address to the name of the in-memory network it belongs to.
static ADDRESS_NETWORKS: LazyLock<DashMap<SocketAddr, String>> = LazyLock::new(DashMap::new);
/// `VirtualTime` handle registered for each network, keyed by network name.
static NETWORK_TIME_SOURCES: LazyLock<DashMap<String, VirtualTime>> = LazyLock::new(DashMap::new);

/// Maps `addr` to `network_name`, overwriting any previous mapping for `addr`.
///
/// Must be called before a socket binds at `addr`.
pub fn register_address_network(addr: SocketAddr, network_name: &str) {
    let owner = String::from(network_name);
    ADDRESS_NETWORKS.insert(addr, owner);
}

/// Forgets the network mapping of `addr`, if there is one.
fn unregister_address_network(addr: &SocketAddr) {
    let _ = ADDRESS_NETWORKS.remove(addr);
}

/// Returns the name of the network `addr` is mapped to, if any.
fn get_address_network(addr: &SocketAddr) -> Option<String> {
    let entry = ADDRESS_NETWORKS.get(addr)?;
    Some(entry.value().to_owned())
}

/// Removes every address → network mapping.
pub fn clear_all_address_networks() {
    ADDRESS_NETWORKS.clear();
}

/// Removes every address mapped to `network_name`, leaving other networks intact.
pub fn clear_network_address_mappings(network_name: &str) {
    ADDRESS_NETWORKS.retain(|_addr, owner| owner.as_str() != network_name);
}
/// Removes every registered per-network `VirtualTime`.
pub fn clear_all_network_time_sources() {
    NETWORK_TIME_SOURCES.clear();
}

/// Records `virtual_time` as the clock that `SimulationSocket::bind` uses for
/// sockets in `network_name`, replacing any earlier registration.
pub fn register_network_time_source(network_name: &str, virtual_time: VirtualTime) {
    NETWORK_TIME_SOURCES.insert(network_name.to_owned(), virtual_time);
}

/// Forgets the `VirtualTime` registered for `network_name`, if any.
pub fn unregister_network_time_source(network_name: &str) {
    let _ = NETWORK_TIME_SOURCES.remove(network_name);
}

/// Returns a clone of the `VirtualTime` registered for `network_name`, if any.
fn get_network_time_source(network_name: &str) -> Option<VirtualTime> {
    NETWORK_TIME_SOURCES
        .get(network_name)
        .map(|entry| entry.value().clone())
}
/// What the installed delivery callback wants done with one outgoing packet.
#[derive(Debug)]
pub enum PacketDeliveryDecision {
    /// Hand the packet to the target's inbox immediately.
    Deliver,
    /// Deliver the packet after sleeping for this long on the socket's time source.
    DelayedDelivery(Duration),
    /// Pass the packet to the queue callback together with `deadline`.
    ///
    /// NOTE(review): the unit of `deadline` is defined by whoever installs the
    /// callbacks — confirm against the scheduler that consumes it.
    QueuedDelivery { deadline: u64 },
    /// Discard the packet (the send still reports success).
    Drop,
}
/// Delivery policy, invoked as `(network_name, from, to)` for every packet sent.
pub type PacketDeliveryCallback =
    Arc<dyn Fn(&str, SocketAddr, SocketAddr) -> PacketDeliveryDecision + Send + Sync>;

/// Receives packets for which the delivery policy chose
/// `PacketDeliveryDecision::QueuedDelivery`.
///
/// Invoked as `(network_name, deadline, data, from, target)`.
pub type QueuePacketCallback =
    Arc<dyn Fn(&str, u64, Vec<u8>, SocketAddr, SocketAddr) + Send + Sync>;

/// Installed delivery policy; while `None`, every packet is delivered immediately.
static DELIVERY_CALLBACK: LazyLock<RwLock<Option<PacketDeliveryCallback>>> =
    LazyLock::new(|| RwLock::new(None));

/// Installed queue sink; while `None`, queued packets are not forwarded anywhere.
static QUEUE_PACKET_CALLBACK: LazyLock<RwLock<Option<QueuePacketCallback>>> =
    LazyLock::new(|| RwLock::new(None));

/// Installs (`Some`) or removes (`None`) the global packet delivery policy.
pub fn set_packet_delivery_callback(callback: Option<PacketDeliveryCallback>) {
    let mut installed = DELIVERY_CALLBACK.write().unwrap();
    *installed = callback;
}

/// Installs (`Some`) or removes (`None`) the global sink for queued packets.
pub fn set_queue_packet_callback(callback: Option<QueuePacketCallback>) {
    let mut installed = QUEUE_PACKET_CALLBACK.write().unwrap();
    *installed = callback;
}
/// Asks the installed delivery policy what to do with a packet sent on
/// `network_name` from `from` to `to`.
///
/// Returns [`PacketDeliveryDecision::Deliver`] when no policy is installed.
fn check_packet_delivery(
    network_name: &str,
    from: SocketAddr,
    to: SocketAddr,
) -> PacketDeliveryDecision {
    // Clone the `Arc` out and release the read guard *before* running the
    // callback. Holding the guard across foreign code would deadlock if the
    // callback re-enters this module. For example, it might replace itself via
    // `set_packet_delivery_callback`, which takes the write lock. It might also
    // trigger another `check_packet_delivery`, and std's `RwLock` does not
    // support recursive read locking.
    let callback = DELIVERY_CALLBACK.read().unwrap().clone();
    match callback {
        Some(cb) => cb(network_name, from, to),
        None => PacketDeliveryDecision::Deliver,
    }
}
fn queue_packet_for_delivery(
network_name: &str,
deadline: u64,
data: Vec<u8>,
from: SocketAddr,
target: SocketAddr,
) {
let callback = QUEUE_PACKET_CALLBACK.read().unwrap();
if let Some(cb) = callback.as_ref() {
cb(network_name, deadline, data, from, target);
}
}
/// A datagram waiting in a socket's inbox, together with its sender.
#[derive(Debug, Clone)]
struct ReceivedPacket {
    data: Vec<u8>,
    from: SocketAddr,
}

/// FIFO of packets delivered to one bound address, plus the `Notify` used to
/// wake a task waiting for the next packet.
struct SocketInbox {
    packets: std::sync::Mutex<VecDeque<ReceivedPacket>>,
    notify: Arc<tokio::sync::Notify>,
}

impl SocketInbox {
    /// Creates an empty inbox with a fresh notifier.
    fn new() -> Self {
        let packets = std::sync::Mutex::new(VecDeque::new());
        let notify = Arc::new(tokio::sync::Notify::new());
        Self { packets, notify }
    }

    /// Appends a packet and wakes one waiter.
    fn push(&self, data: Vec<u8>, from: SocketAddr) {
        let packet = ReceivedPacket { data, from };
        {
            // The queue lock is released before notifying.
            let mut queue = self.packets.lock().unwrap();
            queue.push_back(packet);
        }
        self.notify.notify_one();
    }

    /// Removes and returns the oldest queued packet, if any.
    fn pop(&self) -> Option<ReceivedPacket> {
        let mut queue = self.packets.lock().unwrap();
        queue.pop_front()
    }

    /// Returns a shared handle to this inbox's notifier.
    fn notifier(&self) -> Arc<tokio::sync::Notify> {
        Arc::clone(&self.notify)
    }
}
/// Bound sockets of one network, keyed by address.
#[derive(Default)]
struct SocketRegistry {
    sockets: BTreeMap<SocketAddr, Arc<SocketInbox>>,
}

impl SocketRegistry {
    /// Creates a fresh inbox for `addr` and returns a handle to it.
    ///
    /// An inbox already registered at `addr` is replaced, not rejected.
    fn register(&mut self, addr: SocketAddr) -> Arc<SocketInbox> {
        let inbox = Arc::new(SocketInbox::new());
        self.sockets.insert(addr, Arc::clone(&inbox));
        inbox
    }

    /// Removes the inbox registered at `addr`, if any.
    fn unregister(&mut self, addr: &SocketAddr) {
        let _ = self.sockets.remove(addr);
    }

    /// Pushes a packet into the inbox at `target`.
    ///
    /// Returns `false` (and traces) when nothing is registered at `target`.
    fn deliver_packet(&self, target: SocketAddr, data: Vec<u8>, from: SocketAddr) -> bool {
        let Some(inbox) = self.sockets.get(&target) else {
            tracing::trace!(target = %target, "No socket registered at target address");
            return false;
        };
        inbox.push(data, from);
        true
    }

    /// Whether an inbox is registered at `addr`.
    fn is_registered(&self, addr: &SocketAddr) -> bool {
        self.sockets.contains_key(addr)
    }
}
/// Socket registry of every in-memory network, keyed by network name.
static SOCKET_REGISTRIES: LazyLock<DashMap<String, Arc<RwLock<SocketRegistry>>>> =
    LazyLock::new(DashMap::new);

/// Returns the registry for `network_name`, creating an empty one on first use.
fn get_or_create_registry(network_name: &str) -> Arc<RwLock<SocketRegistry>> {
    let slot = SOCKET_REGISTRIES
        .entry(network_name.to_owned())
        .or_insert_with(|| Arc::new(RwLock::new(SocketRegistry::default())));
    Arc::clone(slot.value())
}

/// Returns the registry for `network_name` if one exists; never creates one.
fn get_registry(network_name: &str) -> Option<Arc<RwLock<SocketRegistry>>> {
    let slot = SOCKET_REGISTRIES.get(network_name)?;
    Some(Arc::clone(slot.value()))
}
/// Delivers `data` from `from` to the socket bound at `target` on `network_name`.
///
/// Returns `false` when the network has no registry (a warning is logged) or
/// no socket is bound at `target`.
pub fn deliver_packet_to_network(
    network_name: &str,
    target: SocketAddr,
    data: Vec<u8>,
    from: SocketAddr,
) -> bool {
    match get_registry(network_name) {
        Some(registry) => registry.read().unwrap().deliver_packet(target, data, from),
        None => {
            tracing::warn!(
                network = %network_name,
                "Attempted to deliver packet to non-existent network"
            );
            false
        }
    }
}

/// Whether a socket is currently bound at `addr` on `network_name`.
pub fn is_socket_registered(network_name: &str, addr: &SocketAddr) -> bool {
    get_registry(network_name).is_some_and(|registry| registry.read().unwrap().is_registered(addr))
}

/// Removes the socket bound at `addr` from `network_name`'s registry, if present.
pub fn unregister_socket(network_name: &str, addr: &SocketAddr) {
    let Some(registry) = get_registry(network_name) else {
        return;
    };
    registry.write().unwrap().unregister(addr);
}

/// Removes every socket from `network_name`'s registry, keeping the registry itself.
pub fn clear_network_sockets(network_name: &str) {
    let Some(registry) = get_registry(network_name) else {
        return;
    };
    registry.write().unwrap().sockets.clear();
}

/// Drops `network_name`'s registry entirely.
pub fn remove_network_socket_registry(network_name: &str) {
    let _ = SOCKET_REGISTRIES.remove(network_name);
}

/// Drops the registries of all networks.
pub fn clear_all_socket_registries() {
    SOCKET_REGISTRIES.clear();
}
/// Datagram socket on an in-memory network.
///
/// Packets travel through the per-network registry of the network the bound
/// address was mapped to; `T` is the clock used for delayed deliveries
/// (`RealTime` by default).
pub struct InMemorySocket<T: TimeSource = RealTime> {
    /// Network the bound address was mapped to at bind time.
    network_name: String,
    /// Address this socket is bound to.
    addr: SocketAddr,
    /// Queue of packets delivered to `addr`.
    inbox: Arc<SocketInbox>,
    /// Notifier shared with `inbox`, used to wait for new packets.
    notify: Arc<tokio::sync::Notify>,
    /// Clock used to sleep before delayed deliveries.
    time_source: T,
}

impl<T: TimeSource> std::fmt::Debug for InMemorySocket<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only the identifying fields are shown; inbox and clock are elided.
        let Self { network_name, addr, .. } = self;
        f.debug_struct("InMemorySocket")
            .field("network_name", network_name)
            .field("addr", addr)
            .finish_non_exhaustive()
    }
}
impl<T: TimeSource> InMemorySocket<T> {
    /// Binds a socket at `addr`, using `time_source` as the clock for delayed
    /// deliveries.
    ///
    /// `addr` must already be mapped to a network with
    /// [`register_address_network`]. The socket joins that network's registry
    /// (created on demand). A socket already registered at `addr` is replaced
    /// rather than rejected.
    ///
    /// # Errors
    ///
    /// Returns an [`io::ErrorKind::Other`] error when `addr` has no network mapping.
    pub async fn bind_with_time_source(addr: SocketAddr, time_source: T) -> io::Result<Self> {
        let Some(network_name) = get_address_network(&addr) else {
            return Err(io::Error::other(format!(
                "No network registered for address {}. Call register_address_network() before binding InMemorySocket.",
                addr
            )));
        };
        let inbox = get_or_create_registry(&network_name)
            .write()
            .unwrap()
            .register(addr);
        let notify = inbox.notifier();
        tracing::debug!(network = %network_name, addr = %addr, "InMemorySocket bound");
        Ok(Self {
            network_name,
            addr,
            inbox,
            notify,
            time_source,
        })
    }

    /// Waits for the next packet and copies it into `buf`.
    ///
    /// At most `buf.len()` bytes are copied; the rest of a larger packet is
    /// discarded. Returns `(bytes_copied, sender)`.
    pub async fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
        let packet = loop {
            match self.inbox.pop() {
                Some(packet) => break packet,
                None => self.notify.notified().await,
            }
        };
        let copied = buf.len().min(packet.data.len());
        buf[..copied].copy_from_slice(&packet.data[..copied]);
        Ok((copied, packet.from))
    }

    /// Rejects payloads larger than [`MAX_PACKET_SIZE`].
    fn ensure_packet_fits(buf: &[u8]) -> io::Result<()> {
        if buf.len() <= MAX_PACKET_SIZE {
            Ok(())
        } else {
            Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "packet too large",
            ))
        }
    }

    /// Sends `buf` to `target` according to the installed delivery policy.
    ///
    /// A delayed delivery is spawned on `GlobalExecutor` and sleeps on this
    /// socket's time source first. Once the size check passes, this always
    /// returns `Ok(buf.len())`, even when the packet is dropped or nothing is
    /// bound at `target`.
    ///
    /// # Errors
    ///
    /// [`io::ErrorKind::InvalidInput`] if `buf` exceeds [`MAX_PACKET_SIZE`].
    pub async fn send_to(&self, buf: &[u8], target: SocketAddr) -> io::Result<usize> {
        Self::ensure_packet_fits(buf)?;
        let payload = buf.to_vec();
        let from = self.addr;
        match check_packet_delivery(&self.network_name, from, target) {
            PacketDeliveryDecision::Drop => {}
            PacketDeliveryDecision::Deliver => {
                deliver_packet_to_network(&self.network_name, target, payload, from);
            }
            PacketDeliveryDecision::QueuedDelivery { deadline } => {
                queue_packet_for_delivery(&self.network_name, deadline, payload, from, target);
            }
            PacketDeliveryDecision::DelayedDelivery(delay) => {
                let network_name = self.network_name.clone();
                let time_source = self.time_source.clone();
                GlobalExecutor::spawn(async move {
                    time_source.sleep(delay).await;
                    deliver_packet_to_network(&network_name, target, payload, from);
                });
            }
        }
        Ok(buf.len())
    }

    /// Synchronous variant of [`Self::send_to`].
    ///
    /// NOTE(review): a `DelayedDelivery` blocks the calling thread with
    /// `std::thread::sleep`, i.e. wall-clock time, without consulting
    /// `time_source`.
    ///
    /// # Errors
    ///
    /// [`io::ErrorKind::InvalidInput`] if `buf` exceeds [`MAX_PACKET_SIZE`].
    pub fn send_to_blocking(&self, buf: &[u8], target: SocketAddr) -> io::Result<usize> {
        Self::ensure_packet_fits(buf)?;
        let payload = buf.to_vec();
        let from = self.addr;
        match check_packet_delivery(&self.network_name, from, target) {
            PacketDeliveryDecision::Drop => {}
            PacketDeliveryDecision::Deliver => {
                deliver_packet_to_network(&self.network_name, target, payload, from);
            }
            PacketDeliveryDecision::QueuedDelivery { deadline } => {
                queue_packet_for_delivery(&self.network_name, deadline, payload, from, target);
            }
            PacketDeliveryDecision::DelayedDelivery(delay) => {
                std::thread::sleep(delay);
                deliver_packet_to_network(&self.network_name, target, payload, from);
            }
        }
        Ok(buf.len())
    }
}
/// `Socket` implementation for in-memory sockets driven by `RealTime`.
///
/// Each method forwards to the inherent method of the same name through a
/// fully-qualified `InMemorySocket::<RealTime>::…` path, so the call
/// unambiguously targets the inherent implementation rather than this trait.
impl Socket for InMemorySocket<RealTime> {
    /// Binds at `addr` with a fresh `RealTime` clock
    /// (see `InMemorySocket::bind_with_time_source`).
    async fn bind(addr: SocketAddr) -> io::Result<Self> {
        Self::bind_with_time_source(addr, RealTime::new()).await
    }
    async fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
        InMemorySocket::<RealTime>::recv_from(self, buf).await
    }
    async fn send_to(&self, buf: &[u8], target: SocketAddr) -> io::Result<usize> {
        InMemorySocket::<RealTime>::send_to(self, buf, target).await
    }
    fn send_to_blocking(&self, buf: &[u8], target: SocketAddr) -> io::Result<usize> {
        InMemorySocket::<RealTime>::send_to_blocking(self, buf, target)
    }
}
impl<T: TimeSource> Drop for InMemorySocket<T> {
fn drop(&mut self) {
unregister_socket(&self.network_name, &self.addr);
unregister_address_network(&self.addr);
tracing::debug!(
network = %self.network_name,
addr = %self.addr,
"InMemorySocket dropped"
);
}
}
/// In-memory socket whose delayed deliveries run on the `VirtualTime`
/// registered for its network.
pub struct SimulationSocket(InMemorySocket<VirtualTime>);

impl SimulationSocket {
    /// Binds at `addr`.
    ///
    /// Requires `addr` to be mapped via `register_address_network` and its
    /// network to have a `VirtualTime` registered via
    /// `register_network_time_source`.
    pub async fn bind(addr: SocketAddr) -> io::Result<Self> {
        <Self as Socket>::bind(addr).await
    }

    /// Waits for the next packet; see `InMemorySocket::recv_from`.
    pub async fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
        let Self(inner) = self;
        inner.recv_from(buf).await
    }

    /// Sends a packet; see `InMemorySocket::send_to`.
    pub async fn send_to(&self, buf: &[u8], target: SocketAddr) -> io::Result<usize> {
        let Self(inner) = self;
        inner.send_to(buf, target).await
    }
}

impl std::fmt::Debug for SimulationSocket {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(inner) = self;
        f.debug_struct("SimulationSocket")
            .field("addr", &inner.addr)
            .field("network", &inner.network_name)
            .finish()
    }
}
impl Socket for SimulationSocket {
    /// Binds at `addr` using the `VirtualTime` registered for the network that
    /// `addr` is mapped to.
    ///
    /// # Errors
    ///
    /// [`io::ErrorKind::NotFound`] when `addr` has no network mapping or that
    /// network has no registered `VirtualTime`. Otherwise, any error from
    /// `InMemorySocket::bind_with_time_source` is propagated.
    async fn bind(addr: SocketAddr) -> io::Result<Self> {
        let Some(network_name) = get_address_network(&addr) else {
            return Err(io::Error::new(
                io::ErrorKind::NotFound,
                format!(
                    "No network registered for address {}. \
                     Call register_address_network() before binding SimulationSocket.",
                    addr
                ),
            ));
        };
        let Some(virtual_time) = get_network_time_source(&network_name) else {
            return Err(io::Error::new(
                io::ErrorKind::NotFound,
                format!(
                    "No VirtualTime registered for network '{}'. \
                     SimNetwork should register VirtualTime before nodes bind sockets.",
                    network_name
                ),
            ));
        };
        tracing::debug!(
            addr = %addr,
            network = %network_name,
            "SimulationSocket binding with VirtualTime"
        );
        InMemorySocket::bind_with_time_source(addr, virtual_time)
            .await
            .map(Self)
    }

    async fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
        let Self(inner) = self;
        inner.recv_from(buf).await
    }

    async fn send_to(&self, buf: &[u8], target: SocketAddr) -> io::Result<usize> {
        let Self(inner) = self;
        inner.send_to(buf, target).await
    }

    fn send_to_blocking(&self, buf: &[u8], target: SocketAddr) -> io::Result<usize> {
        let Self(inner) = self;
        inner.send_to_blocking(buf, target)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::simulation::VirtualTime;

    /// Parses a hard-coded `ip:port` literal used by the tests below.
    fn parse_addr(text: &str) -> SocketAddr {
        text.parse().unwrap()
    }

    #[tokio::test]
    async fn test_socket_bind_and_send() {
        let network = "test-bind-send";
        clear_network_sockets(network);
        let sender_addr = parse_addr("127.0.0.1:10001");
        let receiver_addr = parse_addr("127.0.0.1:10002");
        for addr in [sender_addr, receiver_addr] {
            register_address_network(addr, network);
        }
        let sender = InMemorySocket::bind(sender_addr).await.unwrap();
        let receiver = InMemorySocket::bind(receiver_addr).await.unwrap();

        let payload = b"hello";
        sender.send_to(payload, receiver_addr).await.unwrap();

        let mut buf = [0u8; 100];
        let (received, origin) = receiver.recv_from(&mut buf).await.unwrap();
        assert_eq!(&buf[..received], payload);
        assert_eq!(origin, sender_addr);
    }

    #[tokio::test]
    async fn test_network_isolation() {
        let (net_a, net_b) = ("test-isolation-1", "test-isolation-2");
        clear_network_sockets(net_a);
        clear_network_sockets(net_b);
        let addr_a = parse_addr("127.0.0.1:20001");
        let addr_b = parse_addr("127.0.0.1:20002");

        register_address_network(addr_a, net_a);
        let socket_a = InMemorySocket::bind(addr_a).await.unwrap();
        register_address_network(addr_b, net_b);
        let socket_b = InMemorySocket::bind(addr_b).await.unwrap();

        // Each socket is visible only in the network its address was mapped to.
        assert!(is_socket_registered(net_a, &addr_a));
        assert!(is_socket_registered(net_b, &addr_b));
        assert!(!is_socket_registered(net_a, &addr_b));
        assert!(!is_socket_registered(net_b, &addr_a));

        drop(socket_a);
        drop(socket_b);
    }

    #[tokio::test]
    async fn test_bind_without_registration_fails() {
        let addr = parse_addr("127.0.0.1:30001");
        unregister_address_network(&addr);
        let err = InMemorySocket::bind(addr)
            .await
            .expect_err("binding an unmapped address must fail");
        assert_eq!(err.kind(), io::ErrorKind::Other);
        assert!(err.to_string().contains("No network registered"));
    }

    #[tokio::test]
    async fn test_socket_with_virtual_time() {
        let network = "test-virtual-time";
        clear_network_sockets(network);
        let sender_addr = parse_addr("127.0.0.1:40001");
        let receiver_addr = parse_addr("127.0.0.1:40002");
        let clock = VirtualTime::new();
        for addr in [sender_addr, receiver_addr] {
            register_address_network(addr, network);
        }
        let sender = InMemorySocket::bind_with_time_source(sender_addr, clock.clone())
            .await
            .unwrap();
        let receiver = InMemorySocket::bind_with_time_source(receiver_addr, clock.clone())
            .await
            .unwrap();

        let payload = b"hello from virtual time";
        sender.send_to(payload, receiver_addr).await.unwrap();

        let mut buf = [0u8; 100];
        let (received, origin) = receiver.recv_from(&mut buf).await.unwrap();
        assert_eq!(&buf[..received], payload);
        assert_eq!(origin, sender_addr);
    }
}