use std::{
future::Future,
io,
net::SocketAddr,
os::fd::{AsRawFd, FromRawFd},
pin::Pin,
sync::{atomic::Ordering, Arc},
task::{Context, Poll},
time::Duration,
};
use futures_util::{AsyncReadExt, AsyncWriteExt, FutureExt};
use smol::net::{TcpListener, TcpStream, UdpSocket};
#[cfg(feature = "compat")]
use tokio_util::compat::FuturesAsyncWriteCompatExt;
use crate::{
net::{Net, TcpStreamOwnedReadHalf, TcpStreamOwnedWriteHalf, ToSocketAddrs},
Runtime,
};
use super::SmolRuntime;
#[cfg(feature = "quinn")]
pub use super::quinn_::SmolRuntime as SmolQuinnRuntime;
#[derive(Debug, Default, Clone, Copy)]
pub struct SmolNet;
impl Net for SmolNet {
type TcpListener = SmolTcpListener;
type TcpStream = SmolTcpStream;
type UdpSocket = SmolUdpSocket;
#[cfg(feature = "quinn")]
type Quinn = SmolQuinnRuntime;
}
#[derive(Debug)]
#[repr(transparent)]
pub struct SmolTcpListener {
ln: TcpListener,
}
impl crate::net::TcpListener for SmolTcpListener {
type Stream = SmolTcpStream;
type Runtime = SmolRuntime;
fn bind<A: ToSocketAddrs<Self::Runtime>>(addr: A) -> impl Future<Output = io::Result<Self>> + Send
where
Self: Sized,
{
async move {
let mut addrs = addr.to_socket_addrs().await?;
let res = if addrs.size_hint().0 <= 1 {
if let Some(addr) = addrs.next() {
TcpListener::bind(addr).await
} else {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"invalid socket address",
));
}
} else {
TcpListener::bind(addrs.collect::<Vec<_>>().as_slice()).await
};
res.map(|ln| Self { ln })
}
}
fn accept(&self) -> impl Future<Output = io::Result<(Self::Stream, SocketAddr)>> + Send {
async move {
self
.ln
.accept()
.await
.map(|(stream, addr)| (SmolTcpStream { stream }, addr))
}
}
fn local_addr(&self) -> io::Result<SocketAddr> {
self.ln.local_addr()
}
}
#[derive(Debug)]
#[repr(transparent)]
pub struct SmolTcpStream {
stream: TcpStream,
}
impl futures_util::AsyncRead for SmolTcpStream {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
Pin::new(&mut (&mut self.stream)).poll_read(cx, buf)
}
}
impl futures_util::AsyncWrite for SmolTcpStream {
fn poll_write(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> std::task::Poll<io::Result<usize>> {
Pin::new(&mut (&mut self.stream)).poll_write(cx, buf)
}
fn poll_flush(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<io::Result<()>> {
Pin::new(&mut (&mut self.stream)).poll_flush(cx)
}
fn poll_close(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<io::Result<()>> {
Pin::new(&mut (&mut self.stream)).poll_close(cx)
}
}
#[cfg(feature = "tokio-compat")]
impl tokio::io::AsyncRead for SmolTcpStream {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> Poll<io::Result<()>> {
Pin::new(&mut tokio_util::compat::FuturesAsyncReadCompatExt::compat(
self.get_mut(),
))
.poll_read(cx, buf)
}
}
#[cfg(feature = "tokio-compat")]
impl tokio::io::AsyncWrite for SmolTcpStream {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, io::Error>> {
Pin::new(&mut tokio_util::compat::FuturesAsyncWriteCompatExt::compat_write(self.get_mut()))
.poll_write(cx, buf)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
Pin::new(&mut tokio_util::compat::FuturesAsyncWriteCompatExt::compat_write(self.get_mut()))
.poll_flush(cx)
}
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
Pin::new(&mut tokio_util::compat::FuturesAsyncWriteCompatExt::compat_write(self.get_mut()))
.poll_shutdown(cx)
}
}
#[derive(Debug)]
pub struct ReuniteError(
pub SmolTcpStreamOwnedReadHalf,
pub SmolTcpStreamOwnedWriteHalf,
);
impl core::fmt::Display for ReuniteError {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
write!(
f,
"tried to reunite halves that are not from the same socket"
)
}
}
impl std::error::Error for ReuniteError {}
#[derive(Debug)]
pub struct SmolTcpStreamOwnedReadHalf {
stream: TcpStream,
id: usize,
}
#[derive(Debug)]
pub struct SmolTcpStreamOwnedWriteHalf {
stream: TcpStream,
shutdown_on_drop: bool,
id: usize,
}
impl SmolTcpStreamOwnedWriteHalf {
pub fn forget(mut self) {
self.shutdown_on_drop = false;
drop(self);
}
}
impl Drop for SmolTcpStreamOwnedWriteHalf {
fn drop(&mut self) {
if self.shutdown_on_drop {
let _ = self.stream.shutdown(std::net::Shutdown::Write);
}
}
}
impl futures_util::AsyncRead for SmolTcpStreamOwnedReadHalf {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
Pin::new(&mut (&mut self.stream)).poll_read(cx, buf)
}
}
impl futures_util::AsyncWrite for SmolTcpStreamOwnedWriteHalf {
fn poll_write(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> std::task::Poll<io::Result<usize>> {
Pin::new(&mut (&mut self.stream)).poll_write(cx, buf)
}
fn poll_flush(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<io::Result<()>> {
Pin::new(&mut (&mut self.stream)).poll_flush(cx)
}
fn poll_close(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<io::Result<()>> {
Pin::new(&mut (&mut self.stream)).poll_close(cx)
}
}
#[cfg(feature = "tokio-compat")]
impl tokio::io::AsyncRead for SmolTcpStreamOwnedReadHalf {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> Poll<io::Result<()>> {
Pin::new(&mut tokio_util::compat::FuturesAsyncReadCompatExt::compat(
self.get_mut(),
))
.poll_read(cx, buf)
}
}
#[cfg(feature = "tokio-compat")]
impl tokio::io::AsyncWrite for SmolTcpStreamOwnedWriteHalf {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, io::Error>> {
Pin::new(&mut tokio_util::compat::FuturesAsyncWriteCompatExt::compat_write(self.get_mut()))
.poll_write(cx, buf)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
Pin::new(&mut tokio_util::compat::FuturesAsyncWriteCompatExt::compat_write(self.get_mut()))
.poll_flush(cx)
}
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
Pin::new(&mut tokio_util::compat::FuturesAsyncWriteCompatExt::compat_write(self.get_mut()))
.poll_shutdown(cx)
}
}
impl TcpStreamOwnedReadHalf for SmolTcpStreamOwnedReadHalf {
type Runtime = SmolRuntime;
fn local_addr(&self) -> io::Result<SocketAddr> {
self.stream.local_addr()
}
fn peer_addr(&self) -> io::Result<SocketAddr> {
self.stream.peer_addr()
}
}
impl TcpStreamOwnedWriteHalf for SmolTcpStreamOwnedWriteHalf {
type Runtime = SmolRuntime;
fn forget(mut self) {
self.shutdown_on_drop = false;
drop(self);
}
fn local_addr(&self) -> io::Result<SocketAddr> {
self.stream.local_addr()
}
fn peer_addr(&self) -> io::Result<SocketAddr> {
self.stream.peer_addr()
}
}
impl crate::net::TcpStream for SmolTcpStream {
type Runtime = SmolRuntime;
type OwnedReadHalf = SmolTcpStreamOwnedReadHalf;
type OwnedWriteHalf = SmolTcpStreamOwnedWriteHalf;
type ReuniteError = ReuniteError;
fn connect<A: ToSocketAddrs<Self::Runtime>>(
addr: A,
) -> impl Future<Output = io::Result<Self>> + Send
where
Self: Sized,
{
async move {
let mut addrs = addr.to_socket_addrs().await?;
let res = if addrs.size_hint().0 <= 1 {
if let Some(addr) = addrs.next() {
TcpStream::connect(addr).await
} else {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"invalid socket address",
));
}
} else {
TcpStream::connect(&addrs.collect::<Vec<_>>().as_slice()).await
};
res.map(|stream| Self { stream })
}
}
fn local_addr(&self) -> io::Result<SocketAddr> {
self.stream.local_addr()
}
fn peer_addr(&self) -> io::Result<SocketAddr> {
self.stream.peer_addr()
}
fn set_ttl(&self, ttl: u32) -> io::Result<()> {
self.stream.set_ttl(ttl)
}
fn ttl(&self) -> io::Result<u32> {
self.stream.ttl()
}
fn set_nodelay(&self, nodelay: bool) -> io::Result<()> {
self.stream.set_nodelay(nodelay)
}
fn nodelay(&self) -> io::Result<bool> {
self.stream.nodelay()
}
fn into_split(self) -> (Self::OwnedReadHalf, Self::OwnedWriteHalf) {
let id = &self.stream as *const _ as usize;
(
SmolTcpStreamOwnedReadHalf {
stream: self.stream.clone(),
id,
},
SmolTcpStreamOwnedWriteHalf {
stream: self.stream,
shutdown_on_drop: true,
id,
},
)
}
fn reunite(
read: Self::OwnedReadHalf,
mut write: Self::OwnedWriteHalf,
) -> Result<Self, Self::ReuniteError>
where
Self: Sized,
{
if read.id == write.id {
write.shutdown_on_drop = false;
Ok(Self {
stream: read.stream,
})
} else {
Err(ReuniteError(read, write))
}
}
fn shutdown(&self, how: std::net::Shutdown) -> io::Result<()> {
unsafe { socket2::Socket::from_raw_fd(self.stream.as_raw_fd()) }.shutdown(how)
}
}
#[derive(Debug)]
#[repr(transparent)]
pub struct SmolUdpSocket {
socket: UdpSocket,
}
impl crate::net::UdpSocket for SmolUdpSocket {
type Runtime = SmolRuntime;
fn bind<A: ToSocketAddrs<Self::Runtime>>(addr: A) -> impl Future<Output = io::Result<Self>> + Send
where
Self: Sized,
{
async move {
let mut addrs = addr.to_socket_addrs().await?;
let res = if addrs.size_hint().0 <= 1 {
if let Some(addr) = addrs.next() {
UdpSocket::bind(addr).await
} else {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"invalid socket address",
));
}
} else {
UdpSocket::bind(&addrs.collect::<Vec<_>>().as_slice()).await
};
res.map(|socket| Self { socket })
}
}
fn connect<A: ToSocketAddrs<Self::Runtime>>(
&self,
addr: A,
) -> impl Future<Output = io::Result<()>> + Send {
async move {
let mut addrs = addr.to_socket_addrs().await?;
if addrs.size_hint().0 <= 1 {
if let Some(addr) = addrs.next() {
self.socket.connect(addr).await
} else {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"invalid socket address",
));
}
} else {
self
.socket
.connect(&addrs.collect::<Vec<_>>().as_slice())
.await
}
}
}
fn recv(&self, buf: &mut [u8]) -> impl Future<Output = io::Result<usize>> + Send {
async move { self.socket.recv(buf).await }
}
fn recv_from(
&self,
buf: &mut [u8],
) -> impl Future<Output = io::Result<(usize, SocketAddr)>> + Send {
async move { self.socket.recv_from(buf).await }
}
fn send(&self, buf: &[u8]) -> impl Future<Output = io::Result<usize>> + Send {
async move { self.socket.send(buf).await }
}
fn send_to<A: ToSocketAddrs<Self::Runtime>>(
&self,
buf: &[u8],
target: A,
) -> impl Future<Output = io::Result<usize>> + Send {
async move {
let mut addrs = target.to_socket_addrs().await?;
if let Some(addr) = addrs.next() {
self.socket.send_to(buf, addr).await
} else {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"invalid socket address",
));
}
}
}
fn set_ttl(&self, ttl: u32) -> io::Result<()> {
self.socket.set_ttl(ttl)
}
fn ttl(&self) -> io::Result<u32> {
self.socket.ttl()
}
fn set_broadcast(&self, broadcast: bool) -> io::Result<()> {
self.socket.set_broadcast(broadcast)
}
fn broadcast(&self) -> io::Result<bool> {
self.socket.broadcast()
}
fn set_read_buffer(&self, size: usize) -> io::Result<()> {
#[cfg(not(any(unix, windows)))]
{
panic!("unsupported platform");
}
#[cfg(all(unix, feature = "socket2"))]
{
use std::os::fd::AsRawFd;
return crate::net::set_read_buffer(self.socket.as_raw_fd(), size);
}
#[cfg(all(windows, feature = "socket2"))]
{
use std::os::windows::io::AsRawSocket;
return crate::net::set_read_buffer(self.socket.as_raw_socket(), size);
}
let _ = size;
Ok(())
}
fn set_write_buffer(&self, size: usize) -> io::Result<()> {
#[cfg(not(any(unix, windows)))]
{
panic!("unsupported platform");
}
#[cfg(all(unix, feature = "socket2"))]
{
use std::os::fd::AsRawFd;
return crate::net::set_write_buffer(self.socket.as_raw_fd(), size);
}
#[cfg(all(windows, feature = "socket2"))]
{
use std::os::windows::io::AsRawSocket;
return crate::net::set_write_buffer(self.socket.as_raw_socket(), size);
}
let _ = size;
Ok(())
}
fn poll_recv_from(
&self,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<(usize, SocketAddr)>> {
let fut = self.socket.recv_from(buf);
futures_util::pin_mut!(fut);
fut.poll_unpin(cx)
}
fn poll_send_to(
&self,
cx: &mut Context<'_>,
buf: &[u8],
target: SocketAddr,
) -> Poll<io::Result<usize>> {
let fut = self.socket.send_to(buf, target);
futures_util::pin_mut!(fut);
fut.poll_unpin(cx)
}
fn local_addr(&self) -> io::Result<SocketAddr> {
self.socket.local_addr()
}
}