use alloc::{sync::Arc, task::Wake, vec, vec::Vec};
use core::{
net::{Ipv4Addr, SocketAddr},
sync::atomic::{AtomicBool, AtomicI32, AtomicU32, Ordering},
task::{Context, Waker},
};
use ax_errno::{AxError, AxResult, LinuxError, ax_bail, ax_err_type};
use ax_io::prelude::*;
use ax_sync::Mutex;
use axpoll::{IoEvents, PollSet, Pollable};
use hashbrown::HashMap;
use smoltcp::{
iface::SocketHandle,
socket::tcp as smol,
time::Duration,
wire::{IpEndpoint, IpListenEndpoint},
};
use spin::LazyLock;
use crate::{
LISTEN_TABLE, RecvFlags, RecvOptions, SOCKET_SET, SendOptions, Shutdown, Socket, SocketAddrEx,
SocketOps,
config::{DeviceBinding, InterfaceId},
consts::{TCP_RX_BUF_LEN, TCP_TX_BUF_LEN},
endpoint_from_ip_endpoint,
general::GeneralOptions,
get_control, get_service, interface_by_id,
options::{Configurable, GetSocketOption, SetSocketOption, TcpInfo, TcpInfoOptions, TcpState},
request_poll,
state::*,
};
/// Creates a fresh smoltcp TCP socket backed by zero-filled receive and
/// transmit buffers of `TCP_RX_BUF_LEN` and `TCP_TX_BUF_LEN` bytes.
pub(crate) fn new_tcp_socket() -> smol::Socket<'static> {
    let rx_buffer = smol::SocketBuffer::new(vec![0; TCP_RX_BUF_LEN]);
    let tx_buffer = smol::SocketBuffer::new(vec![0; TCP_TX_BUF_LEN]);
    smol::Socket::new(rx_buffer, tx_buffer)
}
// Initial values stored in every new socket's keep-alive / user-timeout
// fields (see `TcpSocket::new` and `TcpSocket::new_connected`).
// NOTE(review): these look like the usual Linux defaults — confirm.
const TCP_KEEPIDLE_DEFAULT_SECS: u32 = 7200;
const TCP_KEEPINTVL_DEFAULT_SECS: u32 = 75;
const TCP_KEEPCNT_DEFAULT: u32 = 9;
const TCP_USER_TIMEOUT_DEFAULT_MS: u32 = 0;
// Inclusive upper bounds enforced by `set_option_inner`; zero is rejected too.
const TCP_KEEPIDLE_MAX_SECS: u32 = 32767;
const TCP_KEEPINTVL_MAX_SECS: u32 = 32767;
const TCP_KEEPCNT_MAX: u32 = 127;
// Fixed values reported by `tcp_info_snapshot`; they are not read from the
// smoltcp socket.
const TCP_INFO_DEFAULT_MSS: u32 = 1460;
const TCP_INFO_DEFAULT_PMTU: u32 = 1500;
// Reported as the RTO when the smoltcp socket has no timeout configured.
const TCP_INFO_INITIAL_RTO_MICROS: u32 = 1_000_000;
const TCP_INFO_DEFAULT_REORDERING: u32 = 3;
/// A TCP socket backed by a smoltcp socket stored in the global `SOCKET_SET`.
pub struct TcpSocket {
// Lifecycle state (`Idle`, `Connecting`, `Connected`, `Listening`, `Closed`, `Busy`).
state: StateLock,
// Handle of the underlying smoltcp socket inside `SOCKET_SET`.
handle: SocketHandle,
// Local endpoint; a port of 0 means "not bound" (see `bound_endpoint()`).
bound_endpoint: Mutex<IpListenEndpoint>,
// Remote endpoint, recorded once a connection is established or accepted.
peer_endpoint: Mutex<Option<IpEndpoint>>,
// Whether `bound_endpoint` is currently registered in `TCP_BOUND_PORTS`.
bound_registered: AtomicBool,
// Options and poller helpers shared with other socket kinds.
general: GeneralOptions,
// Error code stored by `store_pending_error`; 0 means none. Read and cleared
// through the `Error` socket option.
pending_error: AtomicI32,
// Keep-alive parameters as set through the socket options.
keep_idle_secs: AtomicU32,
keep_interval_secs: AtomicU32,
keep_count: AtomicU32,
// User timeout in milliseconds; stored and reported, not otherwise used in this file.
user_timeout_millis: AtomicU32,
// Set once the read side has been shut down.
rx_closed: AtomicBool,
// Poll sets woken (via `TcpPollWake`) when smoltcp signals receive / send readiness.
poll_rx: Arc<PollSet>,
poll_tx: Arc<PollSet>,
// Poll set woken by `shutdown` when the read side is closed.
poll_rx_closed: PollSet,
}
// NOTE(review): manual `Sync` assertion. Presumably required because one of the
// fields (e.g. `PollSet`) is not automatically `Sync`; the invariant that makes
// this sound is not visible in this file — confirm and document it.
unsafe impl Sync for TcpSocket {}
/// Waker payload that forwards a wake-up to a poll set.
struct TcpPollWake {
// Poll set handed to `crate::defer_poll_wake` on wake.
poll: Arc<PollSet>,
// Readiness mask passed along with the wake-up.
ready: IoEvents,
}
impl Wake for TcpPollWake {
fn wake(self: Arc<Self>) {
self.wake_by_ref();
}
fn wake_by_ref(self: &Arc<Self>) {
crate::defer_poll_wake(self.poll.clone(), self.ready);
}
}
impl TcpSocket {
/// Creates an idle, unbound TCP socket.
///
/// A new smoltcp socket is added to `SOCKET_SET`; keep-alive and
/// user-timeout settings start at their `TCP_*_DEFAULT*` values.
pub fn new() -> Self {
    let handle = SOCKET_SET.add(new_tcp_socket());
    Self {
        state: StateLock::new(State::Idle),
        handle,
        bound_endpoint: Mutex::new(empty_endpoint()),
        peer_endpoint: Mutex::new(None),
        bound_registered: AtomicBool::new(false),
        // NOTE(review): the meaning of (1, 2, 6) is defined by `GeneralOptions::new`
        // and not visible here; 6 matches IPPROTO_TCP — confirm.
        general: GeneralOptions::new(1, 2, 6),
        pending_error: AtomicI32::new(0),
        keep_idle_secs: AtomicU32::new(TCP_KEEPIDLE_DEFAULT_SECS),
        keep_interval_secs: AtomicU32::new(TCP_KEEPINTVL_DEFAULT_SECS),
        keep_count: AtomicU32::new(TCP_KEEPCNT_DEFAULT),
        user_timeout_millis: AtomicU32::new(TCP_USER_TIMEOUT_DEFAULT_MS),
        rx_closed: AtomicBool::new(false),
        poll_rx: Arc::new(PollSet::new()),
        poll_tx: Arc::new(PollSet::new()),
        poll_rx_closed: PollSet::new(),
    }
}
/// Binds the socket to the interface identified by `interface_id`.
///
/// # Errors
///
/// Returns `AxError::NoSuchDevice` when `interface_by_id` does not know the
/// interface.
pub fn bind_device(&self, interface_id: InterfaceId) -> AxResult {
    if interface_by_id(interface_id).is_some() {
        let binding = DeviceBinding {
            bound_if: Some(interface_id),
        };
        self.general.set_device_binding(binding);
        Ok(())
    } else {
        Err(AxError::NoSuchDevice)
    }
}
/// Wraps an already-established smoltcp socket (as produced by `accept`) in
/// a `TcpSocket` that starts out in `State::Connected`.
///
/// The local endpoint becomes the bound endpoint, the remote endpoint is
/// recorded as the peer, and the device binding is whatever
/// `local_binding_for` reports for the local endpoint (default on failure).
fn new_connected(
    handle: SocketHandle,
    local_endpoint: IpEndpoint,
    remote_endpoint: IpEndpoint,
) -> Self {
    let listen_endpoint = endpoint_from_ip_endpoint(local_endpoint);
    let socket = Self {
        state: StateLock::new(State::Connected),
        handle,
        bound_endpoint: Mutex::new(listen_endpoint),
        peer_endpoint: Mutex::new(Some(remote_endpoint)),
        bound_registered: AtomicBool::new(false),
        // NOTE(review): the meaning of (1, 2, 6) is defined by `GeneralOptions::new`
        // and not visible here — confirm.
        general: GeneralOptions::new(1, 2, 6),
        pending_error: AtomicI32::new(0),
        keep_idle_secs: AtomicU32::new(TCP_KEEPIDLE_DEFAULT_SECS),
        keep_interval_secs: AtomicU32::new(TCP_KEEPINTVL_DEFAULT_SECS),
        keep_count: AtomicU32::new(TCP_KEEPCNT_DEFAULT),
        user_timeout_millis: AtomicU32::new(TCP_USER_TIMEOUT_DEFAULT_MS),
        rx_closed: AtomicBool::new(false),
        poll_rx: Arc::new(PollSet::new()),
        poll_tx: Arc::new(PollSet::new()),
        poll_rx_closed: PollSet::new(),
    };
    socket.general.set_device_binding(
        get_control()
            .local_binding_for(&listen_endpoint)
            .unwrap_or_default(),
    );
    socket
}
}
impl Default for TcpSocket {
fn default() -> Self {
Self::new()
}
}
impl TcpSocket {
/// Returns the socket's current lifecycle state as reported by the state lock.
fn state(&self) -> State {
self.state.get()
}
/// Returns `true` while the socket is in `State::Listening`.
#[inline]
fn is_listening(&self) -> bool {
    matches!(self.state(), State::Listening)
}
/// Runs `f` with mutable access to the underlying smoltcp socket (looked up
/// in `SOCKET_SET` by this socket's handle) and returns whatever `f` returns.
fn with_smol_socket<R>(&self, f: impl FnOnce(&mut smol::Socket) -> R) -> R {
SOCKET_SET.with_socket_mut::<smol::Socket, _, _>(self.handle, f)
}
/// Returns the keep-alive interval handed to smoltcp, derived from the
/// currently configured `keep_idle_secs` value.
fn keep_alive_interval(&self) -> Duration {
    let idle_secs = self.keep_idle_secs.load(Ordering::Relaxed);
    Duration::from_secs(idle_secs as u64)
}
/// Builds the `TcpInfo` reported through the `TcpInfo` socket option.
///
/// State, timestamp option, RTO, ACK delay and unsent bytes come from the
/// smoltcp socket; MSS, PMTU and reordering are the fixed `TCP_INFO_DEFAULT_*`
/// constants, `snd_wnd` is reported as 0 and all remaining fields keep their
/// `Default` values.
fn tcp_info_snapshot(&self) -> TcpInfo {
    self.with_smol_socket(|socket| {
        let options = if socket.timestamp_enabled() {
            TcpInfoOptions::TIMESTAMPS
        } else {
            TcpInfoOptions::empty()
        };
        // Fall back to the initial RTO when smoltcp has no timeout configured.
        let rto_micros = socket
            .timeout()
            .map_or(TCP_INFO_INITIAL_RTO_MICROS, duration_micros_u32);
        let ato_micros = socket.ack_delay().map_or(0, duration_micros_u32);
        let unsent = saturating_u32(socket.send_queue());
        let mss = TCP_INFO_DEFAULT_MSS;

        TcpInfo {
            state: tcp_state_info(socket.state()),
            options,
            rto_micros,
            ato_micros,
            snd_mss: mss,
            rcv_mss: mss,
            notsent_bytes: unsent,
            pmtu: TCP_INFO_DEFAULT_PMTU,
            advmss: mss,
            reordering: TCP_INFO_DEFAULT_REORDERING,
            snd_wnd: 0,
            ..Default::default()
        }
    })
}
/// Returns the local endpoint the socket is bound to.
///
/// # Errors
///
/// Fails with `InvalidInput` ("not bound") while the stored port is 0.
fn bound_endpoint(&self) -> AxResult<IpListenEndpoint> {
    let endpoint = *self.bound_endpoint.lock();
    if endpoint.port != 0 {
        return Ok(endpoint);
    }
    ax_bail!(InvalidInput, "not bound");
}
/// Records `err` as the socket's pending error code.
fn store_pending_error(&self, err: LinuxError) {
    let code = err.code();
    self.pending_error.store(code, Ordering::Release);
}
/// Resets the pending error code to 0 (no error).
fn clear_pending_error(&self) {
self.pending_error.store(0, Ordering::Release);
}
/// Returns the pending error code and atomically resets it to 0.
fn take_pending_error(&self) -> i32 {
self.pending_error.swap(0, Ordering::AcqRel)
}
/// Translates the pending error code into the error reported by a failed
/// `connect`, without clearing it.
///
/// Codes that do not convert to a `LinuxError` are reported as
/// `AxError::ConnectionRefused`.
fn connect_error(&self) -> AxError {
    let code = self.pending_error.load(Ordering::Acquire);
    match LinuxError::try_from(code) {
        Ok(linux_error) => AxError::from(linux_error),
        Err(_) => AxError::ConnectionRefused,
    }
}
/// Advances an in-progress connect according to the smoltcp socket state and
/// returns the resulting readiness.
///
/// * `SynSent` / `SynReceived`: still connecting, no events.
/// * `Established`: clears the pending error, switches to `State::Connected`,
///   records the peer and reports `OUT`.
/// * any other state: forgets the peer, stores `ECONNREFUSED`, switches to
///   `State::Closed` and reports `OUT | ERR | HUP`.
fn poll_connect(&self) -> IoEvents {
    self.with_smol_socket(|socket| match socket.state() {
        smol::State::SynSent | smol::State::SynReceived => IoEvents::empty(),
        smol::State::Established => {
            self.clear_pending_error();
            self.state.set(State::Connected);
            let peer = socket.remote_endpoint();
            *self.peer_endpoint.lock() = peer;
            debug!(
                "TCP socket {}: connected to {}",
                self.handle,
                peer.unwrap(),
            );
            IoEvents::OUT
        }
        failed_state => {
            *self.peer_endpoint.lock() = None;
            self.store_pending_error(LinuxError::ECONNREFUSED);
            self.state.set(State::Closed);
            debug!(
                "TCP socket {}: connect failed in state {:?}",
                self.handle, failed_state
            );
            IoEvents::OUT | IoEvents::ERR | IoEvents::HUP
        }
    })
}
/// Computes readiness for a stream (non-listening, non-connecting) socket.
///
/// `IN` is reported unless the read side was shut down, whenever data can be
/// received or the socket can no longer receive at all. `OUT` is reported
/// whenever data can be sent or the socket can no longer send at all.
fn poll_stream(&self) -> IoEvents {
    self.with_smol_socket(|socket| {
        let read_closed = self.rx_closed.load(Ordering::Acquire);
        let readable = !read_closed && (!socket.may_recv() || socket.can_recv());
        let writable = !socket.may_send() || socket.can_send();

        let mut events = IoEvents::empty();
        events.set(IoEvents::IN, readable);
        events.set(IoEvents::OUT, writable);
        events
    })
}
/// Computes readiness for a listening socket.
///
/// Reports `IoEvents::IN` when `LISTEN_TABLE.can_accept` says a connection
/// is ready for the bound endpoint.
///
/// The caller samples the socket state before calling this method, so a
/// concurrent `shutdown` may already have cleared the bound endpoint or
/// removed the listen-table entry. Both situations are reported as "no
/// events" rather than panicking.
fn poll_listener(&self) -> IoEvents {
    let mut events = IoEvents::empty();
    let Ok(endpoint) = self.bound_endpoint() else {
        // No longer bound: nothing can be accepted.
        return events;
    };
    let sockets = SOCKET_SET.inner.lock();
    let acceptable = LISTEN_TABLE
        .can_accept(endpoint, &sockets)
        .unwrap_or(false);
    events.set(IoEvents::IN, acceptable);
    events
}
}
impl Configurable for TcpSocket {
/// Reads a socket option into the slot carried by `option`.
///
/// The `Error` option is answered first (it returns and clears the pending
/// error), then the shared `GeneralOptions` get a chance, and finally the
/// TCP-specific options are handled here.
///
/// Returns `Ok(true)` when the option was handled and `Ok(false)` when it is
/// not known to this socket type.
fn get_option_inner(&self, option: &mut GetSocketOption) -> AxResult<bool> {
    use GetSocketOption as O;

    if let O::Error(error) = option {
        **error = self.take_pending_error();
        return Ok(true);
    }
    if self.general.get_option_inner(option)? {
        return Ok(true);
    }

    let handled = match option {
        O::NoDelay(no_delay) => {
            **no_delay = self.with_smol_socket(|socket| !socket.nagle_enabled());
            true
        }
        O::KeepAlive(keep_alive) => {
            **keep_alive = self.with_smol_socket(|socket| socket.keep_alive().is_some());
            true
        }
        O::MaxSegment(max_segment) => {
            // Fixed value; not read from the smoltcp socket.
            **max_segment = 1460;
            true
        }
        O::TcpKeepIdle(keep_idle) => {
            **keep_idle = self.keep_idle_secs.load(Ordering::Relaxed);
            true
        }
        O::TcpKeepInterval(keep_interval) => {
            **keep_interval = self.keep_interval_secs.load(Ordering::Relaxed);
            true
        }
        O::TcpKeepCount(keep_count) => {
            **keep_count = self.keep_count.load(Ordering::Relaxed);
            true
        }
        O::TcpUserTimeout(user_timeout) => {
            **user_timeout = self.user_timeout_millis.load(Ordering::Relaxed);
            true
        }
        O::SendBuffer(size) => {
            **size = TCP_TX_BUF_LEN;
            true
        }
        O::ReceiveBuffer(size) => {
            **size = TCP_RX_BUF_LEN;
            true
        }
        O::TcpInfo(info) => {
            **info = self.tcp_info_snapshot();
            true
        }
        _ => false,
    };
    Ok(handled)
}
/// Applies a socket option.
///
/// The shared `GeneralOptions` are tried first; the TCP-specific options are
/// handled here. Returns `Ok(true)` when the option was handled and
/// `Ok(false)` when it is not known to this socket type.
///
/// # Errors
///
/// `AxError::InvalidInput` when a keep-alive idle time, interval or count is
/// zero or above its `TCP_KEEP*_MAX*` limit.
fn set_option_inner(&self, option: SetSocketOption) -> AxResult<bool> {
    use SetSocketOption as O;

    if self.general.set_option_inner(option)? {
        return Ok(true);
    }

    let handled = match option {
        O::NoDelay(no_delay) => {
            self.with_smol_socket(|socket| {
                socket.set_nagle_enabled(!no_delay);
            });
            true
        }
        O::KeepAlive(keep_alive) => {
            let interval = self.keep_alive_interval();
            self.with_smol_socket(|socket| {
                socket.set_keep_alive(keep_alive.then_some(interval));
            });
            true
        }
        O::TcpKeepIdle(keep_idle) => {
            if !(1..=TCP_KEEPIDLE_MAX_SECS).contains(&*keep_idle) {
                return Err(AxError::InvalidInput);
            }
            self.keep_idle_secs.store(*keep_idle, Ordering::Relaxed);
            let interval = Duration::from_secs(*keep_idle as u64);
            // Only refresh smoltcp's timer when keep-alive is already enabled.
            self.with_smol_socket(|socket| {
                if socket.keep_alive().is_some() {
                    socket.set_keep_alive(Some(interval));
                }
            });
            true
        }
        O::TcpKeepInterval(keep_interval) => {
            if !(1..=TCP_KEEPINTVL_MAX_SECS).contains(&*keep_interval) {
                return Err(AxError::InvalidInput);
            }
            self.keep_interval_secs
                .store(*keep_interval, Ordering::Relaxed);
            true
        }
        O::TcpKeepCount(keep_count) => {
            if !(1..=TCP_KEEPCNT_MAX).contains(&*keep_count) {
                return Err(AxError::InvalidInput);
            }
            self.keep_count.store(*keep_count, Ordering::Relaxed);
            true
        }
        O::TcpUserTimeout(user_timeout) => {
            self.user_timeout_millis
                .store(*user_timeout, Ordering::Relaxed);
            true
        }
        _ => false,
    };
    Ok(handled)
}
}
impl SocketOps for TcpSocket {
/// Binds the socket to `local_addr`.
///
/// A port of 0 selects an ephemeral port; an unspecified IP address binds to
/// all addresses. On success the endpoint is registered in the bound-port
/// registry and, if the address maps to a specific interface, the device
/// binding is updated. The socket stays in `State::Idle`.
///
/// # Errors
///
/// * `InvalidInput` when the socket is not idle or is already bound.
/// * `AddrInUse` when the endpoint conflicts with an existing listener
///   (unless address reuse is enabled) or with another bound socket, or when
///   no ephemeral port is available.
/// * Whatever `into_ip` / `local_binding_for` report for unusable addresses.
fn bind(&self, local_addr: SocketAddrEx) -> AxResult {
    let mut local_addr = local_addr.into_ip()?;
    self.state
        .lock(State::Idle)
        .map_err(|_| ax_err_type!(InvalidInput, "already bound"))?
        .transit(State::Idle, || {
            // Reject a second bind before touching the ephemeral-port
            // allocator, so a failed re-bind neither consumes a port nor
            // reports an unrelated `AddrInUse`.
            if self.bound_endpoint.lock().port != 0 {
                return Err(AxError::InvalidInput);
            }
            if local_addr.port() == 0 {
                local_addr.set_port(get_ephemeral_port()?);
            }
            let endpoint = IpListenEndpoint {
                addr: if local_addr.ip().is_unspecified() {
                    None
                } else {
                    Some(local_addr.ip().into())
                },
                port: local_addr.port(),
            };
            if !self.general.reuse_address() && !LISTEN_TABLE.can_listen(endpoint) {
                return Err(AxError::AddrInUse);
            }
            let binding = get_control().local_binding_for(&endpoint)?;
            self.register_bound_endpoint(endpoint)?;
            *self.bound_endpoint.lock() = endpoint;
            // Keep an existing device binding when the address does not
            // identify an interface.
            if binding.bound_if.is_some() {
                self.general.set_device_binding(binding);
            }
            debug!("TCP socket {}: binding to {}", self.handle, local_addr);
            Ok(())
        })
}
/// Connects to `remote_addr`.
///
/// Starts the handshake via `start_connect`, then waits through the send
/// poller until `poll_connect` reports `OUT`.
///
/// # Errors
///
/// `WouldBlock` is returned from the poll closure while the handshake is
/// still pending; once it has finished without reaching
/// `State::Connected`, the error from `connect_error` is returned.
fn connect(&self, remote_addr: SocketAddrEx) -> AxResult {
    let remote = remote_addr.into_ip()?;
    self.start_connect(remote)?;
    request_poll();
    self.general.send_poller(self, || {
        request_poll();
        if !self.poll_connect().contains(IoEvents::OUT) {
            return Err(AxError::WouldBlock);
        }
        if self.state() == State::Connected {
            return Ok(());
        }
        Err(self.connect_error())
    })
}
/// Starts listening on the bound endpoint with the given `backlog`.
///
/// An idle socket is moved to `State::Listening`. If it has no port yet, an
/// ephemeral one is chosen. The endpoint is registered in the bound-port
/// registry (unless `bind` already did so) and in `LISTEN_TABLE`; a failure
/// of the latter rolls the registration back.
fn listen(&self, backlog: usize) -> AxResult {
if let Ok(guard) = self.state.lock(State::Idle) {
guard.transit(State::Listening, || {
let mut bound_endpoint = *self.bound_endpoint.lock();
// An unbound socket (port 0) gets an ephemeral port before listening.
if bound_endpoint.port == 0 {
bound_endpoint.port = get_ephemeral_port()?;
}
let binding = get_control().local_binding_for(&bound_endpoint)?;
// Only claim the endpoint in the registry if it is not registered yet.
let register_bound = !self.bound_registered.load(Ordering::Acquire);
if register_bound {
register_tcp_bound(bound_endpoint)?;
}
if let Err(err) = LISTEN_TABLE.listen(bound_endpoint, backlog) {
// Undo the registration made just above.
if register_bound {
unregister_tcp_bound(bound_endpoint);
}
return Err(err);
}
*self.bound_endpoint.lock() = bound_endpoint;
if register_bound {
self.bound_registered.store(true, Ordering::Release);
}
if binding.bound_if.is_some() {
self.general.set_device_binding(binding);
}
debug!("listening on {}", bound_endpoint);
Ok(())
})?;
} else {
// NOTE(review): a socket that is not `Idle` (already listening, connected,
// closed, ...) falls through and `listen` still returns `Ok(())` — confirm
// this is intended.
}
Ok(())
}
/// Accepts a pending connection on a listening socket.
///
/// Waits through the receive poller until `LISTEN_TABLE.accept` yields a
/// connection for the bound endpoint, then wraps it in a new connected
/// `TcpSocket`.
///
/// # Errors
///
/// `InvalidInput` when the socket is not listening or not bound; otherwise
/// whatever the listen table or the poller report.
fn accept(&self) -> AxResult<Socket> {
    if !self.is_listening() {
        ax_bail!(InvalidInput, "not listening");
    }
    let listen_endpoint = self.bound_endpoint()?;
    self.general.recv_poller(self, || {
        request_poll();
        // Hold the socket-set lock only for the table lookup itself.
        let mut sockets = SOCKET_SET.inner.lock();
        let accepted = LISTEN_TABLE.accept(listen_endpoint, &mut sockets)?;
        drop(sockets);

        let socket = TcpSocket::new_connected(
            accepted.handle,
            accepted.local_endpoint,
            accepted.remote_endpoint,
        );
        debug!(
            "accepted connection from {}, {}",
            accepted.handle, accepted.remote_endpoint
        );
        Ok(socket.into())
    })
}
/// Sends data read from `src` and returns the number of bytes queued.
///
/// Runs through the send poller; each attempt fails with `NotConnected` if
/// the smoltcp socket is not active and with `WouldBlock` if it cannot send
/// right now. Otherwise `src` is read directly into the socket's transmit
/// buffer. A network poll is requested after a successful send.
fn send(&self, mut src: impl Read, options: SendOptions) -> AxResult<usize> {
// NOTE(review): `DONTWAIT` is forwarded to the poller as an extra flag;
// presumably it makes this single call non-blocking — confirm in `send_poller_with`.
let extra_nb = options.flags.contains(crate::SendFlags::DONTWAIT);
let result = self.general.send_poller_with(self, extra_nb, || {
request_poll();
self.with_smol_socket(|socket| {
if !socket.is_active() {
Err(AxError::NotConnected)
} else if !socket.can_send() {
Err(AxError::WouldBlock)
} else {
let len = socket
.send(|buffer| {
// A failed read commits 0 bytes; the read error itself is
// carried out through the closure's second tuple element.
let result = src.read(buffer);
let len = result.unwrap_or(0);
(len, result)
})
// Outer `?`: smoltcp send error; inner `?`: error from `src.read`.
.map_err(|_| ax_err_type!(NotConnected, "not connected?"))??;
Ok(len)
}
})
});
if result.is_ok() {
request_poll();
}
result
}
/// Receives data into `dst` and returns the number of bytes written.
///
/// Fails immediately with `NotConnected` when the read side was shut down or
/// the socket is closed. Otherwise runs through the receive poller: queued
/// data is copied (or only peeked with `RecvFlags::PEEK`), `Ok(0)` is
/// returned when nothing is queued and the socket can no longer receive, and
/// `WouldBlock` when nothing is queued yet.
fn recv(&self, mut dst: impl Write + IoBufMut, options: RecvOptions<'_>) -> AxResult<usize> {
if self.rx_closed.load(Ordering::Acquire) {
return Err(AxError::NotConnected);
}
if self.state() == State::Closed {
return Err(AxError::NotConnected);
}
// NOTE(review): `DONTWAIT` is forwarded to the poller as an extra flag;
// presumably it makes this single call non-blocking — confirm in `recv_poller_with`.
let extra_nb = options.flags.contains(RecvFlags::DONTWAIT);
self.general.recv_poller_with(self, extra_nb, || {
request_poll();
self.with_smol_socket(|socket| {
if socket.recv_queue() > 0 {
if options.flags.contains(RecvFlags::PEEK) {
// Peek copies up to `dst`'s remaining capacity without dequeuing.
dst.write(
socket
.peek(dst.remaining_mut())
.map_err(|_| ax_err_type!(NotConnected, "not connected?"))?,
)
} else {
let mut total = 0;
// Keep draining until the queue is empty, `dst` is full, or a
// pass makes no progress.
while socket.recv_queue() > 0 && dst.remaining_mut() > 0 {
let len = socket
.recv(|buf| {
// A failed write dequeues 0 bytes; the write error is
// carried out through the second tuple element.
let result = dst.write(buf);
let len = result.unwrap_or(0);
(len, result)
})
// Outer `?`: smoltcp recv error; inner `?`: error from `dst.write`.
.map_err(|_| ax_err_type!(NotConnected, "not connected?"))??;
if len == 0 {
break;
}
total += len;
}
Ok(total)
}
} else if !socket.may_recv() {
// Nothing queued and nothing more can arrive: end of stream.
Ok(0)
} else {
Err(AxError::WouldBlock)
}
})
})
}
/// Returns the number of bytes currently queued for receiving.
///
/// When nothing is queued, a network poll is requested and the queue is
/// checked once more before answering.
///
/// # Errors
///
/// `InvalidInput` for listening sockets.
fn recv_available(&self) -> AxResult<usize> {
    if self.is_listening() {
        return Err(AxError::InvalidInput);
    }
    let mut queued = self.with_smol_socket(|socket| socket.recv_queue());
    if queued == 0 {
        request_poll();
        queued = self.with_smol_socket(|socket| socket.recv_queue());
    }
    Ok(queued)
}
fn local_addr(&self) -> AxResult<SocketAddrEx> {
let endpoint = self.with_smol_socket(|socket| {
socket
.local_endpoint()
.map(endpoint_from_ip_endpoint)
.unwrap_or_else(|| *self.bound_endpoint.lock())
});
Ok(SocketAddrEx::Ip(SocketAddr::new(
endpoint
.addr
.map_or_else(|| Ipv4Addr::UNSPECIFIED.into(), Into::into),
endpoint.port,
)))
}
/// Returns the address of the connected peer.
///
/// The smoltcp socket's remote endpoint is preferred; the endpoint recorded
/// when the connection was established serves as a fallback.
///
/// # Errors
///
/// `NotConnected` when neither source knows a peer.
fn peer_addr(&self) -> AxResult<SocketAddrEx> {
    let peer = self.with_smol_socket(|socket| {
        socket
            .remote_endpoint()
            .or_else(|| *self.peer_endpoint.lock())
    });
    match peer {
        Some(endpoint) => Ok(SocketAddrEx::Ip(endpoint.into())),
        None => Err(AxError::NotConnected),
    }
}
/// Shuts down the read side, the write side, or both.
///
/// * Read: marks the receive side closed and wakes `poll_rx_closed` waiters.
/// * Connected socket, both directions: closes the smoltcp socket, releases
///   the bound endpoint and moves to `State::Closed`.
/// * Connected socket, write only: closes the smoltcp socket but keeps the
///   socket in `State::Connected`.
/// * Listening socket: removes the listen-table entry, releases the bound
///   endpoint and moves to `State::Closed`, regardless of `how`.
fn shutdown(&self, how: Shutdown) -> AxResult {
if how.has_read() {
self.rx_closed.store(true, Ordering::Release);
// NOTE(review): `PollSet::wake` is `unsafe`; its safety contract is not
// visible in this file — confirm it is upheld here.
unsafe { self.poll_rx_closed.wake(IoEvents::RDHUP | IoEvents::IN) };
}
if let Ok(guard) = self.state.lock(State::Connected) {
if how.has_read() && how.has_write() {
guard.transit(State::Closed, || {
self.with_smol_socket(|socket| {
debug!("TCP socket {}: shutting down", self.handle);
socket.close();
});
// Unregister first: it reads `bound_endpoint`, which is reset right after.
self.unregister_bound_endpoint();
*self.bound_endpoint.lock() = empty_endpoint();
request_poll();
Ok(())
})?;
} else if how.has_write() {
self.with_smol_socket(|socket| {
debug!("TCP socket {}: shutting down write side", self.handle);
socket.close();
});
request_poll();
}
}
if let Ok(guard) = self.state.lock(State::Listening) {
guard.transit(State::Closed, || {
LISTEN_TABLE.unlisten(self.bound_endpoint()?);
// Unregister first: it reads `bound_endpoint`, which is reset right after.
self.unregister_bound_endpoint();
*self.bound_endpoint.lock() = empty_endpoint();
request_poll();
Ok(())
})?;
}
Ok(())
}
}
impl Pollable for TcpSocket {
/// Returns the socket's current readiness.
///
/// Requests a network poll, then dispatches on the lifecycle state:
/// connecting sockets report connect progress, listeners report pending
/// connections, a busy socket reports nothing and every other state uses the
/// stream readiness. `RDHUP` mirrors whether the read side was shut down.
fn poll(&self) -> IoEvents {
    request_poll();
    let mut events = match self.state() {
        State::Busy => IoEvents::empty(),
        State::Listening => self.poll_listener(),
        State::Connecting => self.poll_connect(),
        State::Idle | State::Connected | State::Closed => self.poll_stream(),
    };
    let read_closed = self.rx_closed.load(Ordering::Acquire);
    events.set(IoEvents::RDHUP, read_closed);
    events
}
/// Registers the task behind `context` to be woken for `events`.
///
/// * Listening sockets interested in `IN`/`RDHUP` register with the listen
///   table's accept poll set and with the pending connections' wakers.
/// * `IN`/`RDHUP` interest registers with `poll_rx` and installs a smoltcp
///   receive waker that forwards to it.
/// * `OUT` interest registers with `poll_tx` and installs a smoltcp send
///   waker that forwards to it.
/// * Any of `IN`/`OUT`/`RDHUP` also registers with the general options'
///   waker; `RDHUP` additionally registers with `poll_rx_closed`.
///
/// NOTE(review): `PollSet::register` is `unsafe`; its safety contract is not
/// visible in this file — confirm it is upheld at each call below.
fn register(&self, context: &mut Context<'_>, events: IoEvents) {
    let wants_read = events.intersects(IoEvents::IN | IoEvents::RDHUP);

    let mut accept_registration = None;
    if self.is_listening() && wants_read {
        // Snapshot the endpoint under a single lock acquisition so the port
        // check and the endpoint handed to the listen table cannot disagree
        // if the socket is shut down concurrently.
        let endpoint = *self.bound_endpoint.lock();
        if endpoint.port != 0 {
            if let Some(accept_poll) = LISTEN_TABLE.accept_poll(endpoint) {
                unsafe { accept_poll.register(context.waker(), IoEvents::IN) };
                let accept_waker = LISTEN_TABLE.accept_waker(accept_poll.clone());
                accept_registration = Some((endpoint, accept_poll, accept_waker));
            }
        }
    }

    // Register with the poll sets before installing the smoltcp wakers that
    // forward to them.
    let recv_waker = if wants_read {
        unsafe {
            self.poll_rx
                .register(context.waker(), IoEvents::IN | IoEvents::RDHUP)
        };
        Some(Waker::from(Arc::new(TcpPollWake {
            poll: self.poll_rx.clone(),
            ready: IoEvents::IN | IoEvents::RDHUP,
        })))
    } else {
        None
    };
    let send_waker = if events.contains(IoEvents::OUT) {
        unsafe { self.poll_tx.register(context.waker(), IoEvents::OUT) };
        Some(Waker::from(Arc::new(TcpPollWake {
            poll: self.poll_tx.clone(),
            ready: IoEvents::OUT,
        })))
    } else {
        None
    };

    if let Some((endpoint, accept_poll, accept_waker)) = accept_registration.as_ref() {
        let mut sockets = SOCKET_SET.inner.lock();
        LISTEN_TABLE.register_pending_accept_wakers(
            *endpoint,
            &mut sockets,
            accept_poll,
            accept_waker,
        );
    }

    self.with_smol_socket(|socket| {
        if let Some(waker) = recv_waker.as_ref() {
            socket.register_recv_waker(waker);
        }
        if let Some(waker) = send_waker.as_ref() {
            socket.register_send_waker(waker);
        }
    });

    if events.intersects(IoEvents::IN | IoEvents::OUT | IoEvents::RDHUP) {
        self.general.register_waker(context.waker());
    }
    if events.contains(IoEvents::RDHUP) {
        unsafe {
            self.poll_rx_closed
                .register(context.waker(), IoEvents::RDHUP | IoEvents::IN)
        };
    }
}
}
/// Dropping a socket shuts it down and releases its resources.
///
/// A socket whose smoltcp state shows a connection (or connection teardown)
/// in progress, or that still has unsent data, is handed to the orphan list
/// instead of being removed from `SOCKET_SET` right away.
impl Drop for TcpSocket {
fn drop(&mut self) {
// Decide before `shutdown`, which closes the smoltcp socket and thereby
// changes the state inspected here.
let should_orphan = self.with_smol_socket(|socket| {
matches!(
socket.state(),
smol::State::Established
| smol::State::CloseWait
| smol::State::FinWait1
| smol::State::FinWait2
| smol::State::Closing
| smol::State::LastAck
| smol::State::TimeWait
) || socket.send_queue() > 0
});
if let Err(err) = self.shutdown(Shutdown::Both) {
warn!("TCP socket {}: shutdown failed: {}", self.handle, err);
}
// Covers the states `shutdown` leaves untouched (e.g. bound but idle).
self.unregister_bound_endpoint();
if should_orphan {
// Monotonic time converted from nanoseconds to smoltcp's microseconds.
let timestamp = smoltcp::time::Instant::from_micros_const(
(ax_hal::time::monotonic_time_nanos() / 1_000) as i64,
);
crate::orphan::add_orphan(self.handle, timestamp);
} else {
SOCKET_SET.remove(self.handle);
}
crate::request_poll();
}
}
/// Converts `value` to `u32`, clamping anything above `u32::MAX`.
fn saturating_u32(value: usize) -> u32 {
    if value > u32::MAX as usize {
        u32::MAX
    } else {
        value as u32
    }
}
/// Returns `value` in whole microseconds, clamped to `u32::MAX`.
fn duration_micros_u32(value: Duration) -> u32 {
    let micros = value.total_micros();
    if micros > u32::MAX as u64 {
        u32::MAX
    } else {
        micros as u32
    }
}
fn tcp_state_info(state: smol::State) -> TcpState {
match state {
smol::State::Closed => TcpState::Closed,
smol::State::Listen => TcpState::Listen,
smol::State::SynSent => TcpState::SynSent,
smol::State::SynReceived => TcpState::SynReceived,
smol::State::Established => TcpState::Established,
smol::State::FinWait1 => TcpState::FinWait1,
smol::State::FinWait2 => TcpState::FinWait2,
smol::State::CloseWait => TcpState::CloseWait,
smol::State::Closing => TcpState::Closing,
smol::State::LastAck => TcpState::LastAck,
smol::State::TimeWait => TcpState::TimeWait,
}
}
/// Returns the "not bound" endpoint: no address and port 0.
const fn empty_endpoint() -> IpListenEndpoint {
    const UNBOUND: IpListenEndpoint = IpListenEndpoint {
        addr: None,
        port: 0,
    };
    UNBOUND
}
impl TcpSocket {
/// Starts a connection attempt to `remote_addr` without waiting for it to
/// complete.
///
/// Moves an idle socket to `State::Connecting`, fills in a missing local
/// address (from the route to the peer) and a missing port (ephemeral),
/// registers the local endpoint if needed and asks smoltcp to connect.
///
/// Fails with `InProgress` when a connect is already under way and with
/// `AlreadyConnected` for any other non-idle state.
fn start_connect(&self, remote_addr: SocketAddr) -> AxResult {
self.state
.lock(State::Idle)
.map_err(|state| {
if state == State::Connecting {
AxError::InProgress
} else {
ax_err_type!(AlreadyConnected)
}
})?
.transit(State::Connecting, || {
self.clear_pending_error();
let remote_endpoint = IpEndpoint::from(remote_addr);
let mut bound_endpoint = *self.bound_endpoint.lock();
// Remember the pre-connect situation; it decides below whether the
// device binding may be replaced.
let was_unbound_or_unspecified =
bound_endpoint.addr.is_none_or(|addr| addr.is_unspecified());
let had_explicit_device_binding = self.general.device_binding().bound_if.is_some();
// No usable local address yet: take the source address of the route
// to the peer, honouring the current device binding.
if bound_endpoint.addr.is_none_or(|addr| addr.is_unspecified()) {
bound_endpoint.addr = Some(
get_control()
.select_route_with_binding(
&remote_endpoint.addr,
self.general.device_binding(),
)?
.source,
);
}
if bound_endpoint.port == 0 {
bound_endpoint.port = get_ephemeral_port()?;
}
info!(
"TCP connection from {} to {}",
bound_endpoint, remote_endpoint
);
// Only claim the endpoint in the registry if it is not registered yet.
let register_bound = !self.bound_registered.load(Ordering::Acquire);
if register_bound {
register_tcp_bound(bound_endpoint)?;
}
let result = {
let mut service = get_service();
let context = service.iface.context();
self.with_smol_socket(|socket| {
socket
.connect(context, remote_endpoint, bound_endpoint)
.map_err(|e| match e {
smol::ConnectError::InvalidState => {
ax_err_type!(AlreadyConnected)
}
smol::ConnectError::Unaddressable => {
ax_err_type!(ConnectionRefused, "unaddressable")
}
})?;
Ok::<(), AxError>(())
})
};
if let Err(err) = result {
// Undo the registration made just above.
if register_bound {
unregister_tcp_bound(bound_endpoint);
}
return Err(err);
}
*self.bound_endpoint.lock() = bound_endpoint;
if register_bound {
self.bound_registered.store(true, Ordering::Release);
}
// Derive the device binding from the chosen local endpoint only when
// neither the caller nor an earlier bind pinned one.
if !had_explicit_device_binding && was_unbound_or_unspecified {
self.general
.set_device_binding(get_control().local_binding_for(&bound_endpoint)?);
}
Ok(())
})
}
/// Registers `endpoint` in the bound-port registry unless this socket
/// already holds a registration.
///
/// # Errors
///
/// Propagates `AddrInUse` from `register_tcp_bound`.
fn register_bound_endpoint(&self, endpoint: IpListenEndpoint) -> AxResult {
    if self.bound_registered.load(Ordering::Acquire) {
        return Ok(());
    }
    register_tcp_bound(endpoint)?;
    self.bound_registered.store(true, Ordering::Release);
    Ok(())
}
/// Removes this socket's entry from the bound-port registry, if it has one.
///
/// The registration flag is cleared atomically, so repeated calls
/// unregister at most once.
fn unregister_bound_endpoint(&self) {
    let was_registered = self.bound_registered.swap(false, Ordering::AcqRel);
    if was_registered {
        let endpoint = *self.bound_endpoint.lock();
        unregister_tcp_bound(endpoint);
    }
}
}
// Registry of bound local endpoints: maps a port to the list of addresses
// bound on it (`None` = no specific address). Entries are added by
// `register_tcp_bound` and removed by `unregister_tcp_bound`.
static TCP_BOUND_PORTS: LazyLock<Mutex<HashMap<u16, Vec<Option<smoltcp::wire::IpAddress>>>>> =
LazyLock::new(|| Mutex::new(HashMap::new()));
/// Records `endpoint` in the bound-port registry.
///
/// Endpoints with port 0 are ignored.
///
/// # Errors
///
/// `AddrInUse` when an address already registered on the same port conflicts
/// with the new one (see `listen_addrs_conflict`).
fn register_tcp_bound(endpoint: IpListenEndpoint) -> AxResult {
    if endpoint.port == 0 {
        return Ok(());
    }
    let mut registry = TCP_BOUND_PORTS.lock();
    let addrs = registry.entry(endpoint.port).or_default();
    for &existing in addrs.iter() {
        if listen_addrs_conflict(existing, endpoint.addr) {
            return Err(AxError::AddrInUse);
        }
    }
    addrs.push(endpoint.addr);
    Ok(())
}
/// Removes one registration of `endpoint` from the bound-port registry and
/// drops the port's entry once no address is left on it.
///
/// Endpoints with port 0 and unknown ports are ignored.
fn unregister_tcp_bound(endpoint: IpListenEndpoint) {
    if endpoint.port == 0 {
        return;
    }
    let mut registry = TCP_BOUND_PORTS.lock();
    let Some(addrs) = registry.get_mut(&endpoint.port) else {
        return;
    };
    if let Some(index) = addrs.iter().position(|&addr| addr == endpoint.addr) {
        addrs.swap_remove(index);
    }
    if addrs.is_empty() {
        registry.remove(&endpoint.port);
    }
}
/// Returns `true` when `port` is free: the listen table accepts a wildcard
/// listener on it and no socket has registered it as bound.
fn tcp_port_available(port: u16) -> bool {
    let wildcard = IpListenEndpoint { addr: None, port };
    if !LISTEN_TABLE.can_listen(wildcard) {
        return false;
    }
    !TCP_BOUND_PORTS.lock().contains_key(&port)
}
/// Tells whether two local addresses on the same port conflict.
///
/// A missing address conflicts with everything; two specific addresses
/// conflict only when they are equal.
fn listen_addrs_conflict(
    a: Option<smoltcp::wire::IpAddress>,
    b: Option<smoltcp::wire::IpAddress>,
) -> bool {
    match (a, b) {
        (Some(lhs), Some(rhs)) => lhs == rhs,
        _ => true,
    }
}
/// Picks a free port from the range `0xc000..=0xffff`.
///
/// A shared cursor remembers where the previous search stopped; every port
/// in the range is tried at most once per call, wrapping around at the end.
///
/// # Errors
///
/// `AddrInUse` ("no available ports") when every port in the range is taken.
fn get_ephemeral_port() -> AxResult<u16> {
    const PORT_START: u16 = 0xc000;
    const PORT_END: u16 = 0xffff;
    static CURR: Mutex<u16> = Mutex::new(PORT_START);

    let mut cursor = CURR.lock();
    for _ in 0..=(PORT_END - PORT_START) {
        let candidate = *cursor;
        // Advance the cursor first so the next call starts after this port.
        *cursor = if candidate == PORT_END {
            PORT_START
        } else {
            candidate + 1
        };
        if tcp_port_available(candidate) {
            return Ok(candidate);
        }
    }
    ax_bail!(AddrInUse, "no available ports");
}
// Unit tests for TCP info reporting and for how `bind` / `start_connect`
// choose the device binding. They run against the split-route test network
// set up by `init_split_route_network`.
#[cfg(test)]
mod tests {
use core::net::{IpAddr, SocketAddr};
use super::*;
use crate::{
options::{Configurable, GetSocketOption, SetSocketOption, TcpState},
test_support::{
LOCAL_ADDR, LOCAL_IF, PEER_ADDR, PEER_IF, init_split_route_network, network_test_guard,
},
};
// A freshly created socket reports the closed state, the fixed MSS/PMTU
// defaults and zero for all queue and window figures.
#[test]
fn tcp_info_reports_default_socket_metrics() {
let _guard = network_test_guard();
init_split_route_network();
let socket = TcpSocket::new();
let mut info = TcpInfo::default();
socket
.get_option(GetSocketOption::TcpInfo(&mut info))
.unwrap();
assert_eq!(info.state, TcpState::Closed);
assert_eq!(info.snd_mss, TCP_INFO_DEFAULT_MSS);
assert_eq!(info.rcv_mss, TCP_INFO_DEFAULT_MSS);
assert_eq!(info.pmtu, TCP_INFO_DEFAULT_PMTU);
assert_eq!(info.notsent_bytes, 0);
assert_eq!(info.snd_wnd, 0);
assert_eq!(info.snd_cwnd, 0);
assert_eq!(info.rcv_space, 0);
assert_eq!(info.rcv_wnd, 0);
}
// Binding to a specific local address pins the interface, and a later
// connect must not replace that binding.
#[test]
fn connect_preserves_bound_interface() {
let _guard = network_test_guard();
init_split_route_network();
let socket = TcpSocket::new();
let nonblocking = true;
socket
.set_option(SetSocketOption::NonBlocking(&nonblocking))
.unwrap();
socket
.bind(SocketAddrEx::Ip(SocketAddr::new(IpAddr::V4(LOCAL_ADDR), 0)))
.unwrap();
assert_eq!(
socket.general.device_binding(),
DeviceBinding {
bound_if: Some(LOCAL_IF)
}
);
socket
.start_connect(SocketAddr::new(IpAddr::V4(PEER_ADDR), 80))
.unwrap();
assert_eq!(
socket.general.device_binding(),
DeviceBinding {
bound_if: Some(LOCAL_IF)
}
);
}
// Binding to the unspecified address leaves the interface open, so connect
// picks the interface of the route to the peer.
#[test]
fn connect_uses_peer_route_when_unbound() {
let _guard = network_test_guard();
init_split_route_network();
let socket = TcpSocket::new();
let nonblocking = true;
socket
.set_option(SetSocketOption::NonBlocking(&nonblocking))
.unwrap();
socket
.bind(SocketAddrEx::Ip(SocketAddr::new(
IpAddr::V4(Ipv4Addr::UNSPECIFIED),
0,
)))
.unwrap();
socket
.start_connect(SocketAddr::new(IpAddr::V4(PEER_ADDR), 80))
.unwrap();
assert_eq!(
socket.general.device_binding(),
DeviceBinding {
bound_if: Some(PEER_IF)
}
);
}
// With an explicit device binding to the local interface, connecting to the
// peer address fails and the explicit binding stays in place.
#[test]
fn connect_rejects_unroutable_bound_device() {
let _guard = network_test_guard();
init_split_route_network();
let socket = TcpSocket::new();
let nonblocking = true;
socket
.set_option(SetSocketOption::NonBlocking(&nonblocking))
.unwrap();
socket.bind_device(LOCAL_IF).unwrap();
socket
.bind(SocketAddrEx::Ip(SocketAddr::new(
IpAddr::V4(Ipv4Addr::UNSPECIFIED),
0,
)))
.unwrap();
assert!(
socket
.start_connect(SocketAddr::new(IpAddr::V4(PEER_ADDR), 80))
.is_err()
);
assert_eq!(
socket.general.device_binding(),
DeviceBinding {
bound_if: Some(LOCAL_IF)
}
);
}
}