use std::{fmt, io::Result, net::SocketAddr, sync::Arc};
use tracing::instrument;
use crate::net::{IpProtocol::Tcp, *};
/// A TCP listener bound to a local address.
///
/// Created with [`TcpListener::bind`]; incoming connections are queued
/// internally and handed out one at a time by [`TcpListener::accept`].
#[cfg_attr(docsrs, doc(cfg(madsim)))]
pub struct TcpListener {
    /// Bind handle for this listener. A clone of this `Arc` is stored in every
    /// stream returned by `accept`, so presumably the binding outlives the
    /// listener while accepted streams are alive — confirm against `BindGuard`.
    guard: Arc<BindGuard>,
    /// Receiving end of the unbounded queue that `TcpListenerSocket::new_connection`
    /// pushes newly established streams into.
    rx: async_channel::Receiver<TcpStream>,
}
impl fmt::Debug for TcpListener {
    /// Renders the listener as `TcpListener { addr: .. }`; only the bound
    /// address is shown, the internal queue is omitted.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("TcpListener");
        builder.field("addr", &self.guard.addr);
        builder.finish()
    }
}
impl TcpListener {
    /// Creates a listener bound to `addr`.
    ///
    /// Incoming connections are queued on an unbounded channel until they are
    /// taken with [`TcpListener::accept`].
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `BindGuard::bind` for `addr`.
    #[instrument]
    pub async fn bind<A: ToSocketAddrs>(addr: A) -> Result<TcpListener> {
        let (sender, receiver) = async_channel::unbounded();
        let socket = Arc::new(TcpListenerSocket { tx: sender });
        let bound = BindGuard::bind(addr, Tcp, socket).await?;
        Ok(TcpListener {
            guard: Arc::new(bound),
            rx: receiver,
        })
    }

    /// Waits for the next incoming connection and returns it together with
    /// the peer's address.
    ///
    /// The network's `rand_delay` is awaited before waiting on the queue. The
    /// returned stream keeps a clone of this listener's bind guard.
    ///
    /// # Errors
    ///
    /// Propagates any error from `rand_delay`, and returns an error of kind
    /// `ConnectionReset` if the connection queue has been closed.
    #[instrument]
    pub async fn accept(&self) -> Result<(TcpStream, SocketAddr)> {
        self.guard.net.rand_delay().await?;
        let mut conn = match self.rx.recv().await {
            Ok(conn) => conn,
            Err(closed) => {
                return Err(io::Error::new(io::ErrorKind::ConnectionReset, closed));
            }
        };
        let peer = conn.peer;
        trace!(peer_addr = ?peer, "accept tcp connection");
        // Accepted streams share the listener's guard.
        conn.guard = Some(Arc::clone(&self.guard));
        Ok((conn, peer))
    }

    /// Returns the local address this listener is bound to.
    ///
    /// This never returns an error; the `Result` mirrors the std API shape.
    pub fn local_addr(&self) -> Result<SocketAddr> {
        let addr = self.guard.addr;
        Ok(addr)
    }
}
/// Socket object handed to `BindGuard::bind` by [`TcpListener::bind`].
///
/// Each new connection it is notified of is wrapped in a [`TcpStream`] and
/// forwarded to the owning listener's accept queue.
struct TcpListenerSocket {
    /// Sending end of the listener's unbounded accept queue.
    tx: async_channel::Sender<TcpStream>,
}
impl Socket for TcpListenerSocket {
    /// Wraps a newly established connection in a [`TcpStream`] and queues it
    /// for [`TcpListener::accept`].
    ///
    /// `addr` is the local end and `peer` the remote end; `tx`/`rx` carry the
    /// connection's payloads. The stream starts with empty buffers and no
    /// guard; `accept` attaches the listener's guard later.
    fn new_connection(
        &self,
        peer: SocketAddr,
        addr: SocketAddr,
        tx: PayloadSender,
        rx: PayloadReceiver,
    ) {
        let incoming = TcpStream {
            addr,
            peer,
            tx,
            rx,
            read_buf: Default::default(),
            write_buf: Default::default(),
            guard: None,
        };
        // The queue is unbounded, so sending can only fail once the receiving
        // side (the listener) is closed. In that case the stream is simply
        // dropped here.
        let _ = self.tx.try_send(incoming);
    }
}