use crate::{mocks, Error, IoBufs};
use commonware_utils::{channel::mpsc, sync::Mutex};
use std::{
collections::HashMap,
net::{IpAddr, Ipv4Addr, SocketAddr},
ops::Range,
sync::Arc,
};
/// Ports that [`Network::dial`] assigns, in order, to dialers on `127.0.0.1`.
///
/// Binding a listener to a `127.0.0.1` address whose port lies in this range is rejected, so
/// listener and dialer addresses stay disjoint. The bounds match Linux's default
/// `ip_local_port_range` (32768..=60999).
const EPHEMERAL_PORT_RANGE: Range<u16> = Range {
    start: 32_768,
    end: 61_000,
};
pub struct Sink {
sender: mocks::Sink,
}
impl crate::Sink for Sink {
async fn send(&mut self, bufs: impl Into<IoBufs> + Send) -> Result<(), Error> {
self.sender.send(bufs).await.map_err(|_| Error::SendFailed)
}
}
pub struct Stream {
receiver: mocks::Stream,
}
impl crate::Stream for Stream {
async fn recv(&mut self, len: usize) -> Result<IoBufs, Error> {
self.receiver.recv(len).await.map_err(|_| Error::RecvFailed)
}
fn peek(&self, max_len: usize) -> &[u8] {
self.receiver.peek(max_len)
}
}
pub struct Listener {
address: SocketAddr,
listener: mpsc::UnboundedReceiver<(SocketAddr, mocks::Sink, mocks::Stream)>,
}
impl crate::Listener for Listener {
type Sink = Sink;
type Stream = Stream;
async fn accept(&mut self) -> Result<(SocketAddr, Self::Sink, Self::Stream), Error> {
let (socket, sender, receiver) = self.listener.recv().await.ok_or(Error::ReadFailed)?;
Ok((socket, Sink { sender }, Stream { receiver }))
}
fn local_addr(&self) -> Result<SocketAddr, std::io::Error> {
Ok(self.address)
}
}
/// Sending half of a listener's accept queue.
///
/// Each queued item is the dialer's address plus the listener's two ends of the new
/// connection: a sink carrying Listener -> Dialer traffic and a stream carrying
/// Dialer -> Listener traffic.
type Dialable = mpsc::UnboundedSender<(SocketAddr, mocks::Sink, mocks::Stream)>;

/// A simulated network backed by in-process channels.
///
/// All state lives behind `Arc`s, so clones of a `Network` share the same port counter and
/// the same table of bound listeners.
#[derive(Clone)]
pub struct Network {
    /// Next port to assign to a dialer; starts at the beginning of [`EPHEMERAL_PORT_RANGE`].
    ephemeral: Arc<Mutex<u16>>,
    /// Accept-queue senders of bound listeners, keyed by the address they were bound to.
    listeners: Arc<Mutex<HashMap<SocketAddr, Dialable>>>,
}

impl Default for Network {
    /// Creates a network with no bound listeners and the dialer port counter at the start of
    /// [`EPHEMERAL_PORT_RANGE`].
    fn default() -> Self {
        let ephemeral = Arc::new(Mutex::new(EPHEMERAL_PORT_RANGE.start));
        let listeners = Arc::new(Mutex::new(HashMap::new()));
        Self {
            ephemeral,
            listeners,
        }
    }
}
impl crate::Network for Network {
type Listener = Listener;
async fn bind(&self, socket: SocketAddr) -> Result<Self::Listener, Error> {
if socket.ip() == IpAddr::V4(Ipv4Addr::LOCALHOST)
&& EPHEMERAL_PORT_RANGE.contains(&socket.port())
{
return Err(Error::BindFailed);
}
let mut listeners = self.listeners.lock();
if listeners.contains_key(&socket) {
return Err(Error::BindFailed);
}
let (sender, receiver) = mpsc::unbounded_channel();
listeners.insert(socket, sender);
Ok(Listener {
address: socket,
listener: receiver,
})
}
async fn dial(&self, socket: SocketAddr) -> Result<(Sink, Stream), Error> {
let dialer = {
let mut ephemeral = self.ephemeral.lock();
let dialer = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), *ephemeral);
*ephemeral = ephemeral
.checked_add(1)
.expect("ephemeral port range exhausted");
dialer
};
let sender = {
let listeners = self.listeners.lock();
let sender = listeners.get(&socket).ok_or(Error::ConnectionFailed)?;
sender.clone()
};
let (dialer_sender, dialer_receiver) = mocks::Channel::init();
let (listener_sender, listener_receiver) = mocks::Channel::init();
sender
.send((dialer, dialer_sender, listener_receiver))
.map_err(|_| Error::ConnectionFailed)?;
Ok((
Sink {
sender: listener_sender,
},
Stream {
receiver: dialer_receiver,
},
))
}
}
#[cfg(test)]
mod tests {
    //! Runs the shared `crate::network::tests` suites against the deterministic network.

    use crate::network::{deterministic, tests};
    use commonware_macros::test_group;

    /// Runs the shared network-trait test suite against a fresh deterministic `Network`.
    #[tokio::test]
    async fn test_trait() {
        let make_network = deterministic::Network::default;
        tests::test_network_trait(make_network).await;
    }

    /// Runs the shared network-trait stress suite (grouped as "slow") against a fresh
    /// deterministic `Network`.
    #[test_group("slow")]
    #[tokio::test]
    async fn test_stress_trait() {
        let make_network = deterministic::Network::default;
        tests::stress_test_network_trait(make_network).await;
    }
}