use std::future::Future;
use std::io;
use std::net::{SocketAddr, ToSocketAddrs};
use std::pin::Pin;
use std::task::{Context, Poll};
use crate::io::{AsyncRead, AsyncWrite};
/// Abstraction over an asynchronous TCP listener implementation.
///
/// Each implementor names the connection type it hands out through the
/// associated [`TcpListenerApi::Stream`] type.
pub trait TcpListenerApi: Sized + Send {
    /// Connection type produced by [`accept`](Self::accept) and
    /// [`poll_accept`](Self::poll_accept).
    type Stream: TcpStreamApi;

    /// Asynchronously creates a listener bound to `addr`.
    fn bind<A>(addr: A) -> impl Future<Output = io::Result<Self>> + Send
    where
        A: ToSocketAddrs + Send + 'static;

    /// Waits for the next incoming connection.
    ///
    /// Resolves to the connection plus a `SocketAddr` (conventionally the
    /// peer's address, as with `std::net::TcpListener::accept` — confirm per
    /// implementation).
    fn accept(&self) -> impl Future<Output = io::Result<(Self::Stream, SocketAddr)>> + Send;

    /// Poll-based counterpart of [`accept`](Self::accept), for use inside
    /// hand-written `poll_*` functions.
    fn poll_accept(&self, cx: &mut Context<'_>) -> Poll<io::Result<(Self::Stream, SocketAddr)>>;

    /// Returns the local address the listener is bound to.
    fn local_addr(&self) -> io::Result<SocketAddr>;

    /// Optional count of connections waiting to be accepted.
    ///
    /// The provided implementation reports `None`, meaning "unknown";
    /// implementations able to report a figure may override it.
    fn pending_connections(&self) -> Option<usize> {
        None
    }

    /// Sets the IP time-to-live value on the listening socket.
    fn set_ttl(&self, ttl: u32) -> io::Result<()>;
}
/// Abstraction over a connected, asynchronous TCP stream.
///
/// Byte I/O comes from the crate's `AsyncRead`/`AsyncWrite` supertraits; the
/// methods below expose connection metadata and socket options.
pub trait TcpStreamApi: AsyncRead + AsyncWrite + Sized + Send + Unpin {
    /// Asynchronously opens a connection to `addr`.
    fn connect<A>(addr: A) -> impl Future<Output = io::Result<Self>> + Send
    where
        A: ToSocketAddrs + Send + 'static;

    /// Address of the remote end of the connection.
    fn peer_addr(&self) -> io::Result<SocketAddr>;

    /// Address of the local end of the connection.
    fn local_addr(&self) -> io::Result<SocketAddr>;

    /// Shuts down the read half, the write half, or both, as chosen by `how`.
    fn shutdown(&self, how: std::net::Shutdown) -> io::Result<()>;

    /// Sets the `TCP_NODELAY` flag (expected to mirror
    /// `std::net::TcpStream::set_nodelay`).
    fn set_nodelay(&self, nodelay: bool) -> io::Result<()>;

    /// Reads back the `TCP_NODELAY` flag.
    fn nodelay(&self) -> io::Result<bool>;

    /// Sets the IP time-to-live value on the socket.
    fn set_ttl(&self, ttl: u32) -> io::Result<()>;

    /// Reads back the IP time-to-live value.
    fn ttl(&self) -> io::Result<u32>;
}
/// Builder for a listener whose socket options are applied before `bind` and
/// `listen`.
///
/// Start with `TcpListenerBuilder::new`, chain the option setters, then call
/// `TcpListenerBuilder::bind`.
#[derive(Debug, Clone)]
pub struct TcpListenerBuilder<A> {
    /// Address (or anything resolvable to addresses) to listen on.
    addr: A,
    /// Backlog passed to `listen`; `None` selects the built-in default.
    backlog: Option<u32>,
    /// Whether to request `SO_REUSEADDR`.
    reuse_addr: bool,
    /// Whether to request `SO_REUSEPORT` (only honoured on some Unix targets).
    reuse_port: bool,
    /// Whether to request `IPV6_V6ONLY` on IPv6 sockets.
    only_v6: bool,
}
impl<A: ToSocketAddrs + Send + 'static> TcpListenerBuilder<A> {
pub fn new(addr: A) -> Self {
Self {
addr,
backlog: None,
reuse_addr: false,
reuse_port: false,
only_v6: false,
}
}
#[must_use]
pub fn backlog(mut self, n: u32) -> Self {
self.backlog = Some(n);
self
}
#[must_use]
pub fn reuse_addr(mut self, enable: bool) -> Self {
self.reuse_addr = enable;
self
}
#[must_use]
pub fn reuse_port(mut self, enable: bool) -> Self {
self.reuse_port = enable;
self
}
#[must_use]
pub fn only_v6(mut self, enable: bool) -> Self {
self.only_v6 = enable;
self
}
pub async fn bind(self) -> io::Result<super::listener::TcpListener> {
#[cfg(target_arch = "wasm32")]
{
let _ = self;
Err(super::browser_tcp_unsupported("TcpListenerBuilder::bind"))
}
#[cfg(not(target_arch = "wasm32"))]
{
use crate::net::lookup_all;
use socket2::{Domain, Protocol, Socket, Type};
let addrs = lookup_all(self.addr).await?;
if addrs.is_empty() {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"no socket addresses found",
));
}
let mut last_err = None;
for addr in addrs {
let domain = if addr.is_ipv4() {
Domain::IPV4
} else {
Domain::IPV6
};
let socket = match Socket::new(domain, Type::STREAM, Some(Protocol::TCP)) {
Ok(s) => s,
Err(e) => {
last_err = Some(e);
continue;
}
};
if self.reuse_addr {
if let Err(e) = socket.set_reuse_address(true) {
last_err = Some(e);
continue;
}
}
#[cfg(all(unix, not(target_os = "solaris"), not(target_os = "illumos")))]
if self.reuse_port {
if let Err(e) = socket.set_reuse_port(true) {
last_err = Some(e);
continue;
}
}
if addr.is_ipv6() && self.only_v6 {
if let Err(e) = socket.set_only_v6(true) {
last_err = Some(e);
continue;
}
}
if let Err(e) = socket.bind(&addr.into()) {
last_err = Some(e);
continue;
}
let backlog = i32::try_from(self.backlog.unwrap_or(128)).unwrap_or(i32::MAX);
if let Err(e) = socket.listen(backlog) {
last_err = Some(e);
continue;
}
if let Err(e) = socket.set_nonblocking(true) {
last_err = Some(e);
continue;
}
let listener: std::net::TcpListener = socket.into();
return super::listener::TcpListener::from_std(listener);
}
Err(last_err.unwrap_or_else(|| io::Error::other("failed to bind any address")))
}
}
}
/// Convenience helpers available on every [`TcpListenerApi`] implementor.
pub trait TcpListenerExt: TcpListenerApi {
    /// Accepts connections forever and runs `handler` on each one to
    /// completion before accepting the next, so only one connection is
    /// served at a time.
    ///
    /// # Errors
    ///
    /// Resolves with the first error returned by [`TcpListenerApi::accept`].
    /// It never resolves otherwise.
    fn serve_sequential<F, Fut>(&self, handler: F) -> impl Future<Output = io::Result<()>> + Send
    where
        F: Fn(Self::Stream, SocketAddr) -> Fut + Send + Sync + Clone + 'static,
        Fut: Future<Output = ()> + Send + 'static,
        Self::Stream: 'static,
        Self: Sync,
    {
        async move {
            loop {
                let (stream, addr) = self.accept().await?;
                // An `Fn` can be called through a shared reference, so no
                // per-connection clone is needed. The `Clone` bound is kept
                // only so existing callers' signatures stay valid.
                handler(stream, addr).await;
            }
        }
    }

    /// Returns a stream of accepted connections that borrows this listener.
    /// The accepted `SocketAddr` is dropped from each item.
    fn incoming_stream(&self) -> IncomingStream<'_, Self>
    where
        Self: Sized,
    {
        IncomingStream::new(self)
    }
}
/// Blanket implementation: every [`TcpListenerApi`] type gets the
/// [`TcpListenerExt`] helpers without writing any code.
impl<T> TcpListenerExt for T
where
    T: TcpListenerApi,
{
}
/// Stream of connections accepted from a borrowed [`TcpListenerApi`] listener.
///
/// Each item is either an accepted connection (its `SocketAddr` is dropped)
/// or the accept error. The stream never signals end-of-stream on its own.
#[must_use = "streams do nothing unless polled"]
pub struct IncomingStream<'a, L: TcpListenerApi> {
    /// Listener polled for new connections.
    listener: &'a L,
}
impl<'a, L: TcpListenerApi> IncomingStream<'a, L> {
pub fn new(listener: &'a L) -> Self {
Self { listener }
}
}
impl<L> std::fmt::Debug for IncomingStream<'_, L>
where
    L: TcpListenerApi,
{
    /// Prints only the type name: `L` is not required to implement `Debug`,
    /// so the listener field is left out and marked as elided.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("IncomingStream");
        out.finish_non_exhaustive()
    }
}
impl<L> crate::stream::Stream for IncomingStream<'_, L>
where
    L: TcpListenerApi,
{
    type Item = io::Result<L::Stream>;

    /// Polls the listener for the next connection.
    ///
    /// A ready accept (success or error) is always wrapped in `Some`, so this
    /// never yields `None`; the accepted `SocketAddr` is discarded.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        self.listener
            .poll_accept(cx)
            .map(|accepted| Some(accepted.map(|(conn, _peer)| conn)))
    }
}
/// Unit tests for the listener builder's defaults and option setters.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn builder_creates_with_defaults() {
        let builder = TcpListenerBuilder::new("127.0.0.1:0");
        assert_eq!(builder.backlog, None);
        assert!(!builder.reuse_addr);
        assert!(!builder.reuse_port);
        assert!(!builder.only_v6);
    }

    #[test]
    fn builder_chain_works() {
        let builder = TcpListenerBuilder::new("127.0.0.1:0")
            .backlog(256)
            .reuse_addr(true)
            .reuse_port(true)
            .only_v6(true);
        assert_eq!(builder.backlog, Some(256));
        assert!(builder.reuse_addr);
        assert!(builder.reuse_port);
        assert!(builder.only_v6);
    }

    #[test]
    fn builder_setters_can_disable_again() {
        // Later calls overwrite earlier ones, so options can be turned back off.
        let builder = TcpListenerBuilder::new("127.0.0.1:0")
            .reuse_addr(true)
            .reuse_addr(false)
            .reuse_port(true)
            .reuse_port(false)
            .only_v6(true)
            .only_v6(false);
        assert!(!builder.reuse_addr);
        assert!(!builder.reuse_port);
        assert!(!builder.only_v6);
    }
}