use crate::p2::P2TcpStreamingState;
use crate::runtime::with_ambient_tokio_runtime;
use crate::sockets::util::{
ErrorCode, get_unicast_hop_limit, is_valid_address_family, is_valid_remote_address,
is_valid_unicast_address, receive_buffer_size, send_buffer_size, set_keep_alive_count,
set_keep_alive_idle_time, set_keep_alive_interval, set_receive_buffer_size,
set_send_buffer_size, set_unicast_hop_limit, tcp_bind,
};
use crate::sockets::{DEFAULT_TCP_BACKLOG, SocketAddressFamily, WasiSocketsCtx};
use io_lifetimes::AsSocketlike as _;
use io_lifetimes::views::SocketlikeView;
use rustix::io::Errno;
use rustix::net::sockopt;
use std::fmt::Debug;
use std::io;
use std::mem;
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll, Waker};
use std::time::Duration;
/// Future resolving to the outcome of an in-flight outbound connection attempt.
type PendingConnect = Pin<Box<dyn Future<Output = io::Result<tokio::net::TcpStream>> + Send>>;

/// Lifecycle state of a [`TcpSocket`]. Each variant owns whichever OS-level
/// object (socket, listener, stream) is valid in that phase.
enum TcpState {
    /// Freshly created; neither bound nor connected.
    Default(tokio::net::TcpSocket),
    /// `start_bind` performed the OS bind; waiting for `finish_bind`.
    BindStarted(tokio::net::TcpSocket),
    /// Bound to a local address.
    Bound(tokio::net::TcpSocket),
    /// p2 `start_listen_p2` accepted the socket; waiting for `finish_listen_p2`.
    ListenStarted(tokio::net::TcpSocket),
    /// Listening for incoming connections.
    Listening {
        listener: Arc<tokio::net::TcpListener>,
        /// Accept result obtained by `ready` that `accept` has not consumed yet.
        pending_accept: Option<io::Result<tokio::net::TcpStream>>,
    },
    /// Connection attempt outstanding. `None` while no future is installed:
    /// between `start_connect` and `set_pending_connect`, and again after
    /// `take_pending_connect` has handed out the result.
    Connecting(Option<PendingConnect>),
    /// The connect future was driven to completion by `ready`; the result is
    /// waiting to be collected by `take_pending_connect`.
    ConnectReady(io::Result<tokio::net::TcpStream>),
    /// Connected stream.
    Connected(Arc<tokio::net::TcpStream>),
    /// p3: connected stream that has entered receiving via `start_receive`.
    #[cfg(feature = "p3")]
    Receiving(Arc<tokio::net::TcpStream>),
    /// p2: connected stream whose streaming state has been installed.
    P2Streaming(Box<P2TcpStreamingState>),
    /// p3: socket that only reports a stored error.
    #[cfg(feature = "p3")]
    Error(io::Error),
    /// No usable socket remains (also used as a transient placeholder).
    Closed,
}
impl Debug for TcpState {
    /// Prints only the variant name. Payloads (sockets, futures, results) are
    /// intentionally omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let variant = match self {
            Self::Default(_) => "Default",
            Self::BindStarted(_) => "BindStarted",
            Self::Bound(_) => "Bound",
            Self::ListenStarted(_) => "ListenStarted",
            Self::Listening { .. } => "Listening",
            Self::Connecting(_) => "Connecting",
            Self::ConnectReady(_) => "ConnectReady",
            Self::Connected(_) => "Connected",
            #[cfg(feature = "p3")]
            Self::Receiving(_) => "Receiving",
            Self::P2Streaming(_) => "P2Streaming",
            #[cfg(feature = "p3")]
            Self::Error(_) => "Error",
            Self::Closed => "Closed",
        };
        // A field-less debug tuple renders as just its name.
        f.debug_tuple(variant).finish()
    }
}
/// A WASI TCP socket: the lifecycle state machine plus settings kept alongside
/// the OS socket.
pub struct TcpSocket {
/// Current lifecycle state; owns the underlying OS socket/listener/stream.
tcp_state: TcpState,
/// Backlog passed to `listen`; initialised to `DEFAULT_TCP_BACKLOG` in `from_state`.
listen_backlog_size: u32,
/// Address family chosen when the socket was created.
family: SocketAddressFamily,
/// Option values recorded by the setters so `accept` can re-apply them to
/// accepted connections through `NonInheritedOptions::apply` (a no-op
/// outside macOS).
options: NonInheritedOptions,
}
impl TcpSocket {
/// Creates an unbound socket of the given `family` in the `Default` state.
///
/// Fails if `ctx` does not permit TCP, or if the OS socket cannot be created or
/// configured. IPv6 sockets are made IPv6-only (`IPV6_V6ONLY`).
pub(crate) fn new(ctx: &WasiSocketsCtx, family: SocketAddressFamily) -> Result<Self, ErrorCode> {
    ctx.allowed_network_uses.check_allowed_tcp()?;
    with_ambient_tokio_runtime(|| -> Result<Self, ErrorCode> {
        let os_socket = match family {
            SocketAddressFamily::Ipv4 => tokio::net::TcpSocket::new_v4()?,
            SocketAddressFamily::Ipv6 => {
                let v6_socket = tokio::net::TcpSocket::new_v6()?;
                sockopt::set_ipv6_v6only(&v6_socket, true)?;
                v6_socket
            }
        };
        Ok(Self::from_state(TcpState::Default(os_socket), family))
    })
}
/// p3: builds a socket stuck in the `Error` state. State-dependent queries
/// such as `local_address` or `tcp_stream_arc` report `err`.
#[cfg(feature = "p3")]
pub(crate) fn new_error(err: io::Error, family: SocketAddressFamily) -> Self {
    let state = TcpState::Error(err);
    Self::from_state(state, family)
}
/// Turns the outcome of an `accept` on a listener into a new connected socket.
///
/// Certain platform-specific accept errors are rewritten before being
/// propagated (see the match below). On success `options` are applied to the
/// accepted stream, and the returned socket starts in `Connected` with the
/// given `family` and fresh (default) options of its own.
pub(crate) fn new_accept(
result: io::Result<tokio::net::TcpStream>,
options: &NonInheritedOptions,
family: SocketAddressFamily,
) -> io::Result<Self> {
let client = result.map_err(|err| match Errno::from_io_error(&err) {
// Windows: an in-progress error from accept is reported as an interrupted
// call. NOTE(review): presumably because wasi-sockets has no equivalent of
// EINPROGRESS for accept; confirm.
#[cfg(windows)]
Some(Errno::INPROGRESS) => Errno::INTR.into(),
// Linux: these network/protocol errors are all collapsed into
// ECONNABORTED. NOTE(review): Linux accept(2) documents that it reports
// already-pending network errors on the new socket as accept errors, unlike
// other BSD socket implementations; this looks like a normalization of that.
#[cfg(target_os = "linux")]
Some(
Errno::CONNRESET
| Errno::NETRESET
| Errno::HOSTUNREACH
| Errno::HOSTDOWN
| Errno::NETDOWN
| Errno::NETUNREACH
| Errno::PROTO
| Errno::NOPROTOOPT
| Errno::NONET
| Errno::OPNOTSUPP,
) => Errno::CONNABORTED.into(),
_ => err,
})?;
// Manually carry over options that the platform does not inherit from the
// listening socket (no-op where inheritance happens automatically).
options.apply(family, &client);
Ok(Self::from_state(
TcpState::Connected(Arc::new(client)),
family,
))
}
fn from_state(state: TcpState, family: SocketAddressFamily) -> Self {
Self {
tcp_state: state,
listen_backlog_size: DEFAULT_TCP_BACKLOG,
family,
options: Default::default(),
}
}
/// Borrows the underlying OS socket as a `std::net::TcpStream` view, for
/// socket-option calls.
///
/// Fails with `InvalidState` while connecting or once closed (no socket is
/// held in those states), and with the stored error for a p3 `Error` socket.
pub(crate) fn as_std_view(&self) -> Result<SocketlikeView<'_, std::net::TcpStream>, ErrorCode> {
    let view: SocketlikeView<'_, std::net::TcpStream> = match &self.tcp_state {
        TcpState::Default(sock)
        | TcpState::BindStarted(sock)
        | TcpState::Bound(sock)
        | TcpState::ListenStarted(sock) => sock.as_socketlike_view(),
        TcpState::Listening { listener, .. } => listener.as_socketlike_view(),
        TcpState::Connected(stream) => stream.as_socketlike_view(),
        #[cfg(feature = "p3")]
        TcpState::Receiving(stream) => stream.as_socketlike_view(),
        TcpState::P2Streaming(state) => state.stream.as_socketlike_view(),
        #[cfg(feature = "p3")]
        TcpState::Error(err) => return Err(err.into()),
        TcpState::Connecting(_) | TcpState::ConnectReady(_) | TcpState::Closed => {
            return Err(ErrorCode::InvalidState);
        }
    };
    Ok(view)
}
/// Begins binding the socket to `addr` (first half of a p2 bind).
///
/// The address must be unicast and of this socket's family, otherwise
/// `InvalidArgument`. The socket must be in `Default`, otherwise
/// `InvalidState` with the state left as it was. The OS-level bind happens
/// here. On success the socket advances to `BindStarted`; on failure it is put
/// back into `Default` and the bind error is returned.
pub(crate) fn start_bind(&mut self, addr: SocketAddr) -> Result<(), ErrorCode> {
    let ip = addr.ip();
    let acceptable = is_valid_unicast_address(ip) && is_valid_address_family(ip, self.family);
    if !acceptable {
        return Err(ErrorCode::InvalidArgument);
    }
    let sock = match mem::replace(&mut self.tcp_state, TcpState::Closed) {
        TcpState::Default(sock) => sock,
        untouched => {
            self.tcp_state = untouched;
            return Err(ErrorCode::InvalidState);
        }
    };
    match tcp_bind(&sock, addr) {
        Ok(_) => {
            self.tcp_state = TcpState::BindStarted(sock);
            Ok(())
        }
        Err(err) => {
            self.tcp_state = TcpState::Default(sock);
            Err(err)
        }
    }
}
/// Completes a bind begun by `start_bind`, moving `BindStarted` to `Bound`.
///
/// Any other state is left untouched and reported as `NotInProgress`.
pub(crate) fn finish_bind(&mut self) -> Result<(), ErrorCode> {
    if !matches!(self.tcp_state, TcpState::BindStarted(_)) {
        return Err(ErrorCode::NotInProgress);
    }
    let TcpState::BindStarted(socket) = mem::replace(&mut self.tcp_state, TcpState::Closed) else {
        unreachable!("state was checked to be BindStarted");
    };
    self.tcp_state = TcpState::Bound(socket);
    Ok(())
}
/// Begins an outbound connection to `addr`.
///
/// Only `Default` and `Bound` sockets may connect. A socket that already has a
/// connection attempt outstanding yields `ConcurrencyConflict`, whether its
/// future is still pending (`Connecting`) or already resolved but not yet
/// collected (`ConnectReady`). Every other state yields `InvalidState`.
/// Addresses that are not unicast, not a valid remote, or of the wrong family
/// yield `InvalidArgument`.
///
/// On success the socket moves to `Connecting(None)`. The raw socket is handed
/// back so the caller can build the connect future, e.g. to install it with
/// `set_pending_connect`.
pub(crate) fn start_connect(
    &mut self,
    addr: &SocketAddr,
) -> Result<tokio::net::TcpSocket, ErrorCode> {
    match self.tcp_state {
        TcpState::Default(..) | TcpState::Bound(..) => {}
        // `ConnectReady` differs from `Connecting` only in that `ready()` has
        // already driven the future to completion. The attempt has not been
        // collected yet, so it is still in progress and must be reported the
        // same way. Otherwise the error would depend on whether the guest
        // happened to wait on readiness first.
        TcpState::Connecting(..) | TcpState::ConnectReady(..) => {
            return Err(ErrorCode::ConcurrencyConflict);
        }
        _ => return Err(ErrorCode::InvalidState),
    };
    if !is_valid_unicast_address(addr.ip())
        || !is_valid_remote_address(*addr)
        || !is_valid_address_family(addr.ip(), self.family)
    {
        return Err(ErrorCode::InvalidArgument);
    };
    let (TcpState::Default(tokio_socket) | TcpState::Bound(tokio_socket)) =
        mem::replace(&mut self.tcp_state, TcpState::Connecting(None))
    else {
        unreachable!("state was checked to be Default or Bound above");
    };
    Ok(tokio_socket)
}
/// Installs the future that drives the connection begun by `start_connect`.
///
/// Only valid while in `Connecting(None)`; any other state yields
/// `InvalidState` and the future is dropped.
pub(crate) fn set_pending_connect(
    &mut self,
    future: impl Future<Output = io::Result<tokio::net::TcpStream>> + Send + 'static,
) -> Result<(), ErrorCode> {
    if let TcpState::Connecting(slot @ None) = &mut self.tcp_state {
        *slot = Some(Box::pin(future));
        Ok(())
    } else {
        Err(ErrorCode::InvalidState)
    }
}
/// Non-blockingly checks whether the connection attempt has finished.
///
/// Returns `Ok(Some(result))` once the outcome is known. That happens either
/// because `ready` already resolved the future (`ConnectReady`), or because a
/// single poll right now completes it. The socket is then left in
/// `Connecting(None)`, which is what `finish_connect` expects. Returns
/// `Ok(None)` if the future is still pending, and puts it back. Any other
/// state, including `Connecting(None)`, is restored unchanged and reported as
/// `NotInProgress`.
pub(crate) fn take_pending_connect(
&mut self,
) -> Result<Option<io::Result<tokio::net::TcpStream>>, ErrorCode> {
// `Connecting(None)` is the placeholder left behind when a result is handed out.
match mem::replace(&mut self.tcp_state, TcpState::Connecting(None)) {
TcpState::ConnectReady(result) => Ok(Some(result)),
TcpState::Connecting(Some(mut future)) => {
// Single poll with a no-op waker: no wakeup is registered here; `ready`
// is the path that actually awaits the future. The poll runs inside the
// ambient Tokio runtime, presumably because Tokio I/O needs a runtime
// context when polled.
let mut cx = Context::from_waker(Waker::noop());
match with_ambient_tokio_runtime(|| future.as_mut().poll(&mut cx)) {
Poll::Ready(result) => Ok(Some(result)),
Poll::Pending => {
self.tcp_state = TcpState::Connecting(Some(future));
Ok(None)
}
}
}
current_state => {
self.tcp_state = current_state;
Err(ErrorCode::NotInProgress)
}
}
}
/// Records the outcome of the connection attempt.
///
/// Requires `Connecting(None)`, otherwise `InvalidState` and nothing changes.
/// A successful stream moves the socket to `Connected`. A failure closes the
/// socket and returns the converted I/O error.
pub(crate) fn finish_connect(
    &mut self,
    result: io::Result<tokio::net::TcpStream>,
) -> Result<(), ErrorCode> {
    let TcpState::Connecting(None) = self.tcp_state else {
        return Err(ErrorCode::InvalidState);
    };
    let (next_state, outcome) = match result {
        Ok(stream) => (TcpState::Connected(Arc::new(stream)), Ok(())),
        Err(err) => (TcpState::Closed, Err(ErrorCode::from(err))),
    };
    self.tcp_state = next_state;
    outcome
}
/// p2 `start-listen`: moves a `Bound` socket to `ListenStarted`.
///
/// The OS-level `listen` is deferred to `finish_listen_p2`. Any other state is
/// kept and reported as `InvalidState`.
pub(crate) fn start_listen_p2(&mut self) -> Result<(), ErrorCode> {
    if !matches!(self.tcp_state, TcpState::Bound(_)) {
        return Err(ErrorCode::InvalidState);
    }
    let TcpState::Bound(socket) = mem::replace(&mut self.tcp_state, TcpState::Closed) else {
        unreachable!("state was checked to be Bound");
    };
    self.tcp_state = TcpState::ListenStarted(socket);
    Ok(())
}
/// p2 `finish-listen`: performs the OS-level listen for a `ListenStarted`
/// socket via `listen_common`.
///
/// Any other state is kept and reported as `NotInProgress`.
pub(crate) fn finish_listen_p2(&mut self) -> Result<(), ErrorCode> {
    if !matches!(self.tcp_state, TcpState::ListenStarted(_)) {
        return Err(ErrorCode::NotInProgress);
    }
    let TcpState::ListenStarted(socket) = mem::replace(&mut self.tcp_state, TcpState::Closed)
    else {
        unreachable!("state was checked to be ListenStarted");
    };
    self.listen_common(socket)
}
/// p3 `listen`: starts listening, implicitly binding first if necessary.
///
/// A `Bound` socket listens on its current address. A `Default` socket is
/// first bound to `implicit_bind_addr(self.family)`. If that bind fails, the
/// unbound socket is put back into `Default`, mirroring `start_bind`, so the
/// guest can retry or bind explicitly instead of being left with a closed
/// socket. Every other state is kept and reported as `InvalidState`. Failures
/// of the listen itself are handled by `listen_common`.
#[cfg(feature = "p3")]
pub(crate) fn listen_p3(&mut self) -> Result<(), ErrorCode> {
    let tokio_socket = match mem::replace(&mut self.tcp_state, TcpState::Closed) {
        TcpState::Bound(tokio_socket) => tokio_socket,
        TcpState::Default(tokio_socket) => {
            let implicit_addr = crate::sockets::util::implicit_bind_addr(self.family);
            if let Err(err) = tcp_bind(&tokio_socket, implicit_addr) {
                // A failed bind leaves the OS socket unbound, so `Default` is
                // still accurate; do not drop it.
                self.tcp_state = TcpState::Default(tokio_socket);
                return Err(err);
            }
            tokio_socket
        }
        previous_state => {
            self.tcp_state = previous_state;
            return Err(ErrorCode::InvalidState);
        }
    };
    self.listen_common(tokio_socket)
}
/// Calls `listen` with the configured backlog and enters `Listening`.
///
/// `tokio::net::TcpSocket::listen` consumes the socket, so on failure there is
/// nothing to restore and the state becomes `Closed`. On Windows a `MFILE`
/// error is reported as `NOBUFS`.
fn listen_common(&mut self, tokio_socket: tokio::net::TcpSocket) -> Result<(), ErrorCode> {
    let backlog = self.listen_backlog_size;
    let err = match with_ambient_tokio_runtime(|| tokio_socket.listen(backlog)) {
        Ok(listener) => {
            self.tcp_state = TcpState::Listening {
                listener: Arc::new(listener),
                pending_accept: None,
            };
            return Ok(());
        }
        Err(err) => err,
    };
    self.tcp_state = TcpState::Closed;
    let code: ErrorCode = match Errno::from_io_error(&err) {
        #[cfg(windows)]
        Some(Errno::MFILE) => Errno::NOBUFS.into(),
        _ => err.into(),
    };
    Err(code)
}
/// Non-blockingly accepts one incoming connection.
///
/// Uses the accept result stashed by `ready` if there is one. Otherwise polls
/// the listener once with a no-op waker. Returns:
/// - `Ok(None)` when no connection is pending;
/// - `Ok(Some(socket))` with a new `Connected` socket that has had this
///   listener's recorded options applied;
/// - `Err(InvalidState)` if the socket is not listening.
///
/// Accept failures are normalized by `new_accept` and converted into
/// `ErrorCode` by `?`.
pub(crate) fn accept(&mut self) -> Result<Option<Self>, ErrorCode> {
let TcpState::Listening {
listener,
pending_accept,
} = &mut self.tcp_state
else {
return Err(ErrorCode::InvalidState);
};
let result = match pending_accept.take() {
Some(result) => result,
None => {
// One poll without registering a real wakeup; `ready` is the path
// that actually waits for a connection.
let mut cx = std::task::Context::from_waker(Waker::noop());
match with_ambient_tokio_runtime(|| listener.poll_accept(&mut cx))
.map_ok(|(stream, _)| stream)
{
Poll::Ready(result) => result,
Poll::Pending => return Ok(None),
}
}
};
Ok(Some(Self::new_accept(result, &self.options, self.family)?))
}
/// p3: moves a `Connected` socket into `Receiving` and returns its stream.
///
/// Returns `None`, leaving the state untouched, for any other state.
#[cfg(feature = "p3")]
pub(crate) fn start_receive(&mut self) -> Option<&Arc<tokio::net::TcpStream>> {
    if !matches!(self.tcp_state, TcpState::Connected(_)) {
        return None;
    }
    let TcpState::Connected(stream) = mem::replace(&mut self.tcp_state, TcpState::Closed) else {
        unreachable!("state was checked to be Connected");
    };
    self.tcp_state = TcpState::Receiving(stream);
    match &self.tcp_state {
        TcpState::Receiving(stream) => Some(stream),
        _ => unreachable!("state was just set to Receiving"),
    }
}
/// Local address the socket is bound to.
///
/// Available from the moment the socket is bound, and while listening or
/// connected:
/// - `Bound`;
/// - `ListenStarted`, since `start_listen_p2` only accepts an already-bound
///   socket, so the address is already assigned;
/// - `Listening`, `Connected`, `Receiving` and `P2Streaming`.
///
/// A p3 `Error` socket reports its stored error. Unbound, bind-in-progress,
/// connecting and closed sockets yield `InvalidState`.
pub(crate) fn local_address(&self) -> Result<SocketAddr, ErrorCode> {
    match &self.tcp_state {
        // A socket in the middle of a p2 listen is still bound; previously
        // this state wrongly reported `InvalidState`.
        TcpState::Bound(socket) | TcpState::ListenStarted(socket) => Ok(socket.local_addr()?),
        TcpState::Connected(stream) => Ok(stream.local_addr()?),
        #[cfg(feature = "p3")]
        TcpState::Receiving(stream) => Ok(stream.local_addr()?),
        TcpState::P2Streaming(state) => Ok(state.stream.local_addr()?),
        TcpState::Listening { listener, .. } => Ok(listener.local_addr()?),
        #[cfg(feature = "p3")]
        TcpState::Error(err) => Err(err.into()),
        _ => Err(ErrorCode::InvalidState),
    }
}
/// Address of the connected peer.
///
/// Errors from `tcp_stream_arc` (e.g. `InvalidState` when the socket holds no
/// stream) and from `peer_addr` are propagated.
pub(crate) fn remote_address(&self) -> Result<SocketAddr, ErrorCode> {
    Ok(self.tcp_stream_arc()?.peer_addr()?)
}
/// Returns `true` exactly when the socket is in the `Listening` state.
pub(crate) fn is_listening(&self) -> bool {
    matches!(&self.tcp_state, TcpState::Listening { .. })
}
pub(crate) fn address_family(&self) -> SocketAddressFamily {
self.family
}
/// Sets the backlog used when the socket listens.
///
/// A value of 0 is rejected with `InvalidArgument`. Any other value is
/// saturated into `1..=i32::MAX`.
///
/// Before the socket is listening, the value is only recorded and takes effect
/// in `listen_common`. This includes while a bind or a p2 listen is in
/// progress. On a listening socket, `listen` is re-issued with the new backlog,
/// and `NotSupported` is returned if that fails.
///
/// Connecting, connected and closed sockets yield `InvalidState`. A p3 `Error`
/// socket reports its stored error.
pub(crate) fn set_listen_backlog_size(&mut self, value: u64) -> Result<(), ErrorCode> {
    const MIN_BACKLOG: u32 = 1;
    const MAX_BACKLOG: u32 = i32::MAX as u32;
    if value == 0 {
        return Err(ErrorCode::InvalidArgument);
    }
    let value = u32::try_from(value)
        .unwrap_or(MAX_BACKLOG)
        .clamp(MIN_BACKLOG, MAX_BACKLOG);
    match &self.tcp_state {
        // wasi-sockets p2 only lists `invalid-state` for the
        // connect-in-progress and connected states, so an in-progress bind or
        // listen must accept the hint too. `finish_listen_p2` ->
        // `listen_common` reads `listen_backlog_size`, so a value set during
        // `ListenStarted` is honoured.
        TcpState::Default(..)
        | TcpState::BindStarted(..)
        | TcpState::Bound(..)
        | TcpState::ListenStarted(..) => {}
        TcpState::Listening { listener, .. } => {
            // Presumably re-issuing `listen` on a listening socket adjusts the
            // backlog where the platform supports it; a failure means it does not.
            if rustix::net::listen(&listener, value.try_into().unwrap_or(i32::MAX)).is_err() {
                return Err(ErrorCode::NotSupported);
            }
        }
        #[cfg(feature = "p3")]
        TcpState::Error(err) => return Err(err.into()),
        _ => return Err(ErrorCode::InvalidState),
    }
    self.listen_backlog_size = value;
    Ok(())
}
/// Reads the `SO_KEEPALIVE` flag of the underlying socket.
pub(crate) fn keep_alive_enabled(&self) -> Result<bool, ErrorCode> {
    let view = self.as_std_view()?;
    Ok(sockopt::socket_keepalive(&*view)?)
}
/// Sets the `SO_KEEPALIVE` flag of the underlying socket.
pub(crate) fn set_keep_alive_enabled(&self, value: bool) -> Result<(), ErrorCode> {
    let view = self.as_std_view()?;
    Ok(sockopt::set_socket_keepalive(&*view, value)?)
}
/// Reads the TCP keep-alive idle time, in nanoseconds, saturating at `u64::MAX`.
pub(crate) fn keep_alive_idle_time(&self) -> Result<u64, ErrorCode> {
    let view = self.as_std_view()?;
    let idle = sockopt::tcp_keepidle(&*view)?;
    Ok(u64::try_from(idle.as_nanos()).unwrap_or(u64::MAX))
}
/// Sets the TCP keep-alive idle time (nanoseconds). The value returned by the
/// util helper is recorded for re-application to accepted sockets.
pub(crate) fn set_keep_alive_idle_time(&mut self, value: u64) -> Result<(), ErrorCode> {
    let applied = set_keep_alive_idle_time(&*self.as_std_view()?, value)?;
    self.options.set_keep_alive_idle_time(applied);
    Ok(())
}
/// Reads the TCP keep-alive probe interval, in nanoseconds, saturating at `u64::MAX`.
pub(crate) fn keep_alive_interval(&self) -> Result<u64, ErrorCode> {
    let view = self.as_std_view()?;
    let interval = sockopt::tcp_keepintvl(&*view)?;
    Ok(u64::try_from(interval.as_nanos()).unwrap_or(u64::MAX))
}
/// Sets the TCP keep-alive probe interval from a nanosecond count.
pub(crate) fn set_keep_alive_interval(&self, value: u64) -> Result<(), ErrorCode> {
    let interval = Duration::from_nanos(value);
    let view = self.as_std_view()?;
    set_keep_alive_interval(&*view, interval)?;
    Ok(())
}
/// Reads the number of unanswered keep-alive probes before the connection is dropped.
pub(crate) fn keep_alive_count(&self) -> Result<u32, ErrorCode> {
    let view = self.as_std_view()?;
    Ok(sockopt::tcp_keepcnt(&*view)?)
}
/// Sets the keep-alive probe count through the util helper.
pub(crate) fn set_keep_alive_count(&self, value: u32) -> Result<(), ErrorCode> {
    let view = self.as_std_view()?;
    set_keep_alive_count(&*view, value)?;
    Ok(())
}
/// Reads the unicast hop limit (TTL) for this socket's address family.
pub(crate) fn hop_limit(&self) -> Result<u8, ErrorCode> {
    let view = self.as_std_view()?;
    Ok(get_unicast_hop_limit(&*view, self.family)?)
}
/// Sets the unicast hop limit and records it for re-application to accepted sockets.
pub(crate) fn set_hop_limit(&mut self, value: u8) -> Result<(), ErrorCode> {
    set_unicast_hop_limit(&*self.as_std_view()?, self.family, value)?;
    self.options.set_hop_limit(value);
    Ok(())
}
/// Reads the receive buffer size via the util helper.
pub(crate) fn receive_buffer_size(&self) -> Result<u64, ErrorCode> {
    let view = self.as_std_view()?;
    Ok(receive_buffer_size(&*view)?)
}
/// Sets the receive buffer size. The value returned by the util helper is
/// recorded for re-application to accepted sockets.
pub(crate) fn set_receive_buffer_size(&mut self, value: u64) -> Result<(), ErrorCode> {
    let applied = set_receive_buffer_size(&*self.as_std_view()?, value)?;
    self.options.set_receive_buffer_size(applied);
    Ok(())
}
/// Reads the send buffer size via the util helper.
pub(crate) fn send_buffer_size(&self) -> Result<u64, ErrorCode> {
    let view = self.as_std_view()?;
    Ok(send_buffer_size(&*view)?)
}
/// Sets the send buffer size. The value returned by the util helper is
/// recorded for re-application to accepted sockets.
pub(crate) fn set_send_buffer_size(&mut self, value: u64) -> Result<(), ErrorCode> {
    let applied = set_send_buffer_size(&*self.as_std_view()?, value)?;
    self.options.set_send_buffer_size(applied);
    Ok(())
}
/// p3: option values recorded on this socket for re-application to accepted connections.
#[cfg(feature = "p3")]
pub(crate) fn non_inherited_options(&self) -> &NonInheritedOptions {
    let Self { options, .. } = self;
    options
}
/// p3: shared listener of a `Listening` socket.
///
/// An `Error` socket reports its stored error; anything else is `InvalidState`.
#[cfg(feature = "p3")]
pub(crate) fn tcp_listener_arc(&self) -> Result<&Arc<tokio::net::TcpListener>, ErrorCode> {
    // The whole method is already p3-only, so `Error` needs no extra cfg here.
    if let TcpState::Listening { listener, .. } = &self.tcp_state {
        return Ok(listener);
    }
    if let TcpState::Error(err) = &self.tcp_state {
        return Err(err.into());
    }
    Err(ErrorCode::InvalidState)
}
/// Shared connected stream, available in `Connected`, `Receiving` (p3) and
/// `P2Streaming`.
///
/// A p3 `Error` socket reports its stored error; other states are `InvalidState`.
pub(crate) fn tcp_stream_arc(&self) -> Result<&Arc<tokio::net::TcpStream>, ErrorCode> {
    let stream: &Arc<tokio::net::TcpStream> = match &self.tcp_state {
        TcpState::Connected(stream) => stream,
        #[cfg(feature = "p3")]
        TcpState::Receiving(stream) => stream,
        TcpState::P2Streaming(state) => &state.stream,
        #[cfg(feature = "p3")]
        TcpState::Error(err) => return Err(err.into()),
        _ => return Err(ErrorCode::InvalidState),
    };
    Ok(stream)
}
/// Streaming state of a `P2Streaming` socket.
///
/// A p3 `Error` socket reports its stored error; other states are `InvalidState`.
pub(crate) fn p2_streaming_state(&self) -> Result<&P2TcpStreamingState, ErrorCode> {
    let boxed = match &self.tcp_state {
        TcpState::P2Streaming(boxed) => boxed,
        #[cfg(feature = "p3")]
        TcpState::Error(err) => return Err(err.into()),
        _ => return Err(ErrorCode::InvalidState),
    };
    Ok(&**boxed)
}
/// Moves a `Connected` socket into `P2Streaming` with the given state.
///
/// Any other state yields `InvalidState` and `state` is dropped.
pub(crate) fn set_p2_streaming_state(
    &mut self,
    state: P2TcpStreamingState,
) -> Result<(), ErrorCode> {
    match self.tcp_state {
        TcpState::Connected(_) => {
            self.tcp_state = TcpState::P2Streaming(Box::new(state));
            Ok(())
        }
        _ => Err(ErrorCode::InvalidState),
    }
}
/// Waits until a state-dependent operation can make progress.
///
/// Only two states actually wait:
/// - an installed connect future (`Connecting(Some(_))`) is driven to
///   completion, and its result is parked in `ConnectReady` for
///   `take_pending_connect`;
/// - a listener with no stashed accept result waits for the next accept, and
///   the result is stored in `pending_accept` for `accept`.
///
/// Every other state returns immediately.
pub(crate) async fn ready(&mut self) {
match &mut self.tcp_state {
TcpState::Default(..)
| TcpState::BindStarted(..)
| TcpState::Bound(..)
| TcpState::ListenStarted(..)
| TcpState::ConnectReady(..)
| TcpState::Closed
| TcpState::Connected { .. }
| TcpState::Connecting(None)
| TcpState::Listening {
pending_accept: Some(_),
..
}
| TcpState::P2Streaming(_) => {}
#[cfg(feature = "p3")]
TcpState::Receiving(_) | TcpState::Error(_) => {}
TcpState::Connecting(Some(future)) => {
// Replacing the state drops the completed future.
self.tcp_state = TcpState::ConnectReady(future.as_mut().await);
}
TcpState::Listening {
listener,
pending_accept: slot @ None,
} => {
// Wait with the caller's context; the peer address is discarded.
let result = futures::future::poll_fn(|cx| {
listener.poll_accept(cx).map_ok(|(stream, _)| stream)
})
.await;
*slot = Some(result);
}
}
}
}
#[cfg(not(target_os = "macos"))]
pub use inherits_option::*;
#[cfg(not(target_os = "macos"))]
mod inherits_option {
    //! Non-macOS variant of `NonInheritedOptions`: nothing is recorded and
    //! nothing is re-applied. Judging by the module name, accepted sockets on
    //! these platforms already pick up the listener's options.

    use crate::sockets::SocketAddressFamily;
    use tokio::net::TcpStream;

    /// Zero-sized stand-in. Every setter discards its argument and `apply`
    /// does nothing.
    #[derive(Default, Clone)]
    pub struct NonInheritedOptions;

    impl NonInheritedOptions {
        pub fn set_hop_limit(&mut self, _hop_limit: u8) {}

        pub fn set_keep_alive_idle_time(&mut self, _idle_nanos: u64) {}

        pub fn set_receive_buffer_size(&mut self, _size: usize) {}

        pub fn set_send_buffer_size(&mut self, _size: usize) {}

        pub(crate) fn apply(&self, _family: SocketAddressFamily, _stream: &TcpStream) {}
    }
}
#[cfg(target_os = "macos")]
pub use does_not_inherit_options::*;
#[cfg(target_os = "macos")]
mod does_not_inherit_options {
    //! macOS variant of `NonInheritedOptions`: option values set on a socket
    //! are recorded so they can be applied by hand to sockets it accepts.

    use crate::sockets::SocketAddressFamily;
    use rustix::net::sockopt;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU8, AtomicU64, AtomicUsize, Ordering::Relaxed};
    use std::time::Duration;
    use tokio::net::TcpStream;

    /// Recorded option values. The state lives behind an `Arc`, so clones
    /// share it.
    #[derive(Default, Clone)]
    pub struct NonInheritedOptions(Arc<Inner>);

    /// Raw recorded values. Zero (the default) means "never set" and is
    /// skipped by `apply`.
    #[derive(Default)]
    struct Inner {
        receive_buffer_size: AtomicUsize,
        send_buffer_size: AtomicUsize,
        hop_limit: AtomicU8,
        keep_alive_idle_time: AtomicU64,
    }

    impl NonInheritedOptions {
        pub fn set_receive_buffer_size(&mut self, size: usize) {
            self.0.receive_buffer_size.store(size, Relaxed);
        }

        pub fn set_send_buffer_size(&mut self, size: usize) {
            self.0.send_buffer_size.store(size, Relaxed);
        }

        pub fn set_hop_limit(&mut self, hop_limit: u8) {
            self.0.hop_limit.store(hop_limit, Relaxed);
        }

        pub fn set_keep_alive_idle_time(&mut self, idle_nanos: u64) {
            self.0.keep_alive_idle_time.store(idle_nanos, Relaxed);
        }

        /// Applies every explicitly recorded option to `stream`.
        ///
        /// Best-effort: each setsockopt failure is deliberately ignored. The
        /// hop limit is only re-applied for IPv6 sockets.
        pub(crate) fn apply(&self, family: SocketAddressFamily, stream: &TcpStream) {
            let inner = &*self.0;

            let recv_size = inner.receive_buffer_size.load(Relaxed);
            if recv_size != 0 {
                let _ = sockopt::set_socket_recv_buffer_size(stream, recv_size);
            }

            let send_size = inner.send_buffer_size.load(Relaxed);
            if send_size != 0 {
                let _ = sockopt::set_socket_send_buffer_size(stream, send_size);
            }

            if family == SocketAddressFamily::Ipv6 {
                match inner.hop_limit.load(Relaxed) {
                    0 => {}
                    hops => {
                        let _ = sockopt::set_ipv6_unicast_hops(stream, Some(hops));
                    }
                }
            }

            let idle_nanos = inner.keep_alive_idle_time.load(Relaxed);
            if idle_nanos != 0 {
                let _ = sockopt::set_tcp_keepidle(stream, Duration::from_nanos(idle_nanos));
            }
        }
    }
}