use alloc::{vec, vec::Vec};
use core::{
net::{Ipv4Addr, SocketAddr},
sync::atomic::{AtomicBool, AtomicI32, AtomicU32, Ordering},
task::Context,
};
use ax_errno::{AxError, AxResult, LinuxError, ax_bail, ax_err_type};
use ax_io::prelude::*;
use ax_sync::Mutex;
use axpoll::{IoEvents, PollSet, Pollable};
use hashbrown::HashMap;
use smoltcp::{
iface::SocketHandle,
socket::tcp as smol,
time::Duration,
wire::{IpEndpoint, IpListenEndpoint},
};
use spin::Lazy;
use crate::{
LISTEN_TABLE, RecvFlags, RecvOptions, SOCKET_SET, SendOptions, Shutdown, Socket, SocketAddrEx,
SocketOps,
consts::{TCP_RX_BUF_LEN, TCP_TX_BUF_LEN},
general::GeneralOptions,
get_service,
options::{Configurable, GetSocketOption, SetSocketOption},
poll_interfaces,
state::*,
};
/// Allocates a smoltcp TCP socket whose receive and transmit rings are zero-filled
/// buffers of `TCP_RX_BUF_LEN` and `TCP_TX_BUF_LEN` bytes respectively.
pub(crate) fn new_tcp_socket() -> smol::Socket<'static> {
    let rx_storage = vec![0u8; TCP_RX_BUF_LEN];
    let tx_storage = vec![0u8; TCP_TX_BUF_LEN];
    let rx_ring = smol::SocketBuffer::new(rx_storage);
    let tx_ring = smol::SocketBuffer::new(tx_storage);
    smol::Socket::new(rx_ring, tx_ring)
}
// Defaults and upper bounds for the per-socket TCP keep-alive / user-timeout
// options, grouped per option. NOTE(review): the values appear to mirror Linux's
// tcp_keepalive_time/intvl/probes defaults and MAX_TCP_KEEP* limits — confirm.

/// Initial TCP_KEEPIDLE value, in seconds.
const TCP_KEEPIDLE_DEFAULT_SECS: u32 = 7200;
/// Largest TCP_KEEPIDLE value accepted by `set_option`, in seconds.
const TCP_KEEPIDLE_MAX_SECS: u32 = 32767;

/// Initial TCP_KEEPINTVL value, in seconds.
const TCP_KEEPINTVL_DEFAULT_SECS: u32 = 75;
/// Largest TCP_KEEPINTVL value accepted by `set_option`, in seconds.
const TCP_KEEPINTVL_MAX_SECS: u32 = 32767;

/// Initial TCP_KEEPCNT value.
const TCP_KEEPCNT_DEFAULT: u32 = 9;
/// Largest TCP_KEEPCNT value accepted by `set_option`.
const TCP_KEEPCNT_MAX: u32 = 127;

/// Initial TCP_USER_TIMEOUT value, in milliseconds.
const TCP_USER_TIMEOUT_DEFAULT_MS: u32 = 0;
/// A TCP socket backed by a smoltcp socket stored in the global `SOCKET_SET`.
pub struct TcpSocket {
    /// Socket-level state machine (`Idle`, `Connecting`, `Connected`,
    /// `Listening`, `Closed`, or the transient `Busy`).
    state: StateLock,
    /// Handle of the backing smoltcp socket inside `SOCKET_SET`.
    handle: SocketHandle,
    /// Local endpoint; a port of 0 means "not bound yet". An `addr` of `None`
    /// means the wildcard address.
    bound_endpoint: Mutex<IpListenEndpoint>,
    /// Whether `bound_endpoint` is currently recorded in `TCP_BOUND_PORTS`.
    bound_registered: AtomicBool,
    /// Options shared with other socket kinds (non-blocking, reuse-address,
    /// device mask, wakers, ...).
    general: GeneralOptions,
    /// Linux errno reported through SO_ERROR; 0 means no pending error.
    pending_error: AtomicI32,
    /// TCP_KEEPIDLE in seconds; also used as the smoltcp keep-alive period.
    keep_idle_secs: AtomicU32,
    /// TCP_KEEPINTVL in seconds (stored and reported; not passed to smoltcp here).
    keep_interval_secs: AtomicU32,
    /// TCP_KEEPCNT (stored and reported; not passed to smoltcp here).
    keep_count: AtomicU32,
    /// TCP_USER_TIMEOUT in milliseconds (stored and reported; not passed to smoltcp here).
    user_timeout_millis: AtomicU32,
    /// Set once the read half has been shut down.
    rx_closed: AtomicBool,
    /// Wakers waiting for `RDHUP`; woken when the read half is shut down.
    poll_rx_closed: PollSet,
}
// SAFETY: NOTE(review): this asserts `TcpSocket` can be shared between threads.
// Presumably some field type (e.g. `StateLock`, `GeneralOptions` or `PollSet`) is
// not automatically `Sync`; mutable state here lives in atomics, `Mutex`,
// `StateLock` and the global `SOCKET_SET`. The exact invariant that makes sharing
// sound is not visible in this file — confirm and document it.
unsafe impl Sync for TcpSocket {}
impl TcpSocket {
pub fn new() -> Self {
Self {
state: StateLock::new(State::Idle),
handle: SOCKET_SET.add(new_tcp_socket()),
bound_endpoint: Mutex::new(empty_endpoint()),
bound_registered: AtomicBool::new(false),
general: GeneralOptions::new(),
pending_error: AtomicI32::new(0),
keep_idle_secs: AtomicU32::new(TCP_KEEPIDLE_DEFAULT_SECS),
keep_interval_secs: AtomicU32::new(TCP_KEEPINTVL_DEFAULT_SECS),
keep_count: AtomicU32::new(TCP_KEEPCNT_DEFAULT),
user_timeout_millis: AtomicU32::new(TCP_USER_TIMEOUT_DEFAULT_MS),
rx_closed: AtomicBool::new(false),
poll_rx_closed: PollSet::new(),
}
}
fn new_connected(handle: SocketHandle) -> Self {
let result = Self {
state: StateLock::new(State::Connected),
handle,
bound_endpoint: Mutex::new(empty_endpoint()),
bound_registered: AtomicBool::new(false),
general: GeneralOptions::new(),
pending_error: AtomicI32::new(0),
keep_idle_secs: AtomicU32::new(TCP_KEEPIDLE_DEFAULT_SECS),
keep_interval_secs: AtomicU32::new(TCP_KEEPINTVL_DEFAULT_SECS),
keep_count: AtomicU32::new(TCP_KEEPCNT_DEFAULT),
user_timeout_millis: AtomicU32::new(TCP_USER_TIMEOUT_DEFAULT_MS),
rx_closed: AtomicBool::new(false),
poll_rx_closed: PollSet::new(),
};
let endpoint = result.with_smol_socket(|socket| socket_bound_endpoint(socket));
*result.bound_endpoint.lock() = endpoint;
result
.general
.set_device_mask(get_service().device_mask_for(&endpoint));
result
}
}
impl Default for TcpSocket {
fn default() -> Self {
Self::new()
}
}
impl TcpSocket {
    /// Current socket-level state.
    fn state(&self) -> State {
        self.state.get()
    }

    /// Whether the socket is in the `Listening` state.
    #[inline]
    fn is_listening(&self) -> bool {
        self.state() == State::Listening
    }

    /// Runs `f` with mutable access to the backing smoltcp socket in `SOCKET_SET`.
    fn with_smol_socket<R>(&self, f: impl FnOnce(&mut smol::Socket) -> R) -> R {
        SOCKET_SET.with_socket_mut::<smol::Socket, _, _>(self.handle, f)
    }

    /// Period handed to smoltcp's `set_keep_alive`.
    ///
    /// `set_keep_alive` takes a single `Duration`, so the TCP_KEEPIDLE value is
    /// used for it; TCP_KEEPINTVL / TCP_KEEPCNT are only stored in this file.
    fn keep_alive_interval(&self) -> Duration {
        Duration::from_secs(self.keep_idle_secs.load(Ordering::Relaxed) as u64)
    }

    /// Returns the bound local endpoint, or `InvalidInput` if none (port 0).
    fn bound_endpoint(&self) -> AxResult<IpListenEndpoint> {
        let endpoint = *self.bound_endpoint.lock();
        if endpoint.port == 0 {
            ax_bail!(InvalidInput, "not bound");
        }
        Ok(endpoint)
    }

    /// Records `err` so it is reported by SO_ERROR and `connect_error`.
    fn store_pending_error(&self, err: LinuxError) {
        self.pending_error.store(err.code(), Ordering::Release);
    }

    /// Forgets any pending error.
    fn clear_pending_error(&self) {
        self.pending_error.store(0, Ordering::Release);
    }

    /// Returns the pending errno (0 if none) and clears it.
    fn take_pending_error(&self) -> i32 {
        self.pending_error.swap(0, Ordering::AcqRel)
    }

    /// Converts the pending errno into an `AxError`; falls back to
    /// `ConnectionRefused` when none is stored or it is not a known errno.
    fn connect_error(&self) -> AxError {
        LinuxError::try_from(self.pending_error.load(Ordering::Acquire))
            .map_or(AxError::ConnectionRefused, AxError::from)
    }

    /// Readiness while `Connecting`, driving the socket-level state machine.
    ///
    /// - Handshake in progress: no events.
    /// - Success: moves to `Connected` and reports `OUT`.
    /// - Failure: records ECONNREFUSED, moves to `Closed` and reports
    ///   `OUT | ERR | HUP`.
    fn poll_connect(&self) -> IoEvents {
        let mut events = IoEvents::empty();
        self.with_smol_socket(|socket| match socket.state() {
            smol::State::SynSent | smol::State::SynReceived => {
                // Handshake still in flight; nothing to report yet.
            }
            // `CloseWait` can only be reached after the handshake completed: the
            // peer sent its FIN (e.g. wrote a reply and closed) before this
            // socket was polled. The connect itself succeeded; buffered data and
            // EOF are surfaced through `recv`, so this must not be reported as a
            // refused connection.
            smol::State::Established | smol::State::CloseWait => {
                self.clear_pending_error();
                self.state.set(State::Connected);
                debug!(
                    "TCP socket {}: connected to {}",
                    self.handle,
                    socket.remote_endpoint().unwrap(),
                );
                events.set(IoEvents::OUT, true);
            }
            state => {
                self.store_pending_error(LinuxError::ECONNREFUSED);
                self.state.set(State::Closed);
                debug!(
                    "TCP socket {}: connect failed in state {:?}",
                    self.handle, state
                );
                events.set(IoEvents::OUT, true);
                events.set(IoEvents::ERR, true);
                events.set(IoEvents::HUP, true);
            }
        });
        events
    }

    /// Readiness of a data-carrying socket.
    ///
    /// `IN` is reported when the read half is not shut down and either data is
    /// queued or smoltcp will deliver no more (EOF). `OUT` is reported when
    /// there is send space or sending is no longer possible.
    fn poll_stream(&self) -> IoEvents {
        let mut events = IoEvents::empty();
        self.with_smol_socket(|socket| {
            events.set(
                IoEvents::IN,
                !self.rx_closed.load(Ordering::Acquire)
                    && (!socket.may_recv() || socket.can_recv()),
            );
            events.set(IoEvents::OUT, !socket.may_send() || socket.can_send());
        });
        events
    }

    /// Readiness of a listening socket: `IN` when a connection can be accepted.
    ///
    /// A concurrent `shutdown` may reset the bound endpoint and unlisten between
    /// `poll`'s state check and this call. In that case no events are reported
    /// instead of panicking inside `poll`.
    fn poll_listener(&self) -> IoEvents {
        let mut events = IoEvents::empty();
        let Ok(endpoint) = self.bound_endpoint() else {
            return events;
        };
        let sockets = SOCKET_SET.inner.lock();
        let acceptable = LISTEN_TABLE
            .can_accept(endpoint.port, &sockets)
            .unwrap_or(false);
        events.set(IoEvents::IN, acceptable);
        events
    }
}
impl Configurable for TcpSocket {
    /// Reads the option selected by `option` into its output slot.
    ///
    /// Resolution order:
    /// 1. `Error` is answered first, because reading it also clears it.
    /// 2. Generic options are handled by `general`.
    /// 3. The TCP-specific options below.
    ///
    /// Returns `Ok(false)` for options this socket does not handle.
    fn get_option_inner(&self, option: &mut GetSocketOption) -> AxResult<bool> {
        use GetSocketOption as O;

        if let O::Error(out) = option {
            **out = self.take_pending_error();
            return Ok(true);
        }
        if self.general.get_option_inner(option)? {
            return Ok(true);
        }

        let handled = match option {
            O::NoDelay(out) => {
                **out = self.with_smol_socket(|socket| !socket.nagle_enabled());
                true
            }
            O::KeepAlive(out) => {
                **out = self.with_smol_socket(|socket| socket.keep_alive().is_some());
                true
            }
            // Fixed value; not derived from the connection.
            O::MaxSegment(out) => {
                **out = 1460;
                true
            }
            O::TcpKeepIdle(out) => {
                **out = self.keep_idle_secs.load(Ordering::Relaxed);
                true
            }
            O::TcpKeepInterval(out) => {
                **out = self.keep_interval_secs.load(Ordering::Relaxed);
                true
            }
            O::TcpKeepCount(out) => {
                **out = self.keep_count.load(Ordering::Relaxed);
                true
            }
            O::TcpUserTimeout(out) => {
                **out = self.user_timeout_millis.load(Ordering::Relaxed);
                true
            }
            O::SendBuffer(out) => {
                **out = TCP_TX_BUF_LEN;
                true
            }
            O::ReceiveBuffer(out) => {
                **out = TCP_RX_BUF_LEN;
                true
            }
            // Recognised, but the output slot is left untouched.
            O::TcpInfo(_) => true,
            _ => false,
        };
        Ok(handled)
    }

    /// Applies an option; generic options are delegated to `general` first.
    ///
    /// Keep-alive idle/interval/count values must lie in `1..=max` for their
    /// respective limits, otherwise `InvalidInput` is returned. Returns
    /// `Ok(false)` for options this socket does not handle.
    fn set_option_inner(&self, option: SetSocketOption) -> AxResult<bool> {
        use SetSocketOption as O;

        /// Passes `value` through when it lies in `1..=max`.
        fn in_range(value: u32, max: u32) -> AxResult<u32> {
            if (1..=max).contains(&value) {
                Ok(value)
            } else {
                Err(AxError::InvalidInput)
            }
        }

        if self.general.set_option_inner(option)? {
            return Ok(true);
        }

        match option {
            O::NoDelay(no_delay) => {
                let nagle = !*no_delay;
                self.with_smol_socket(|socket| socket.set_nagle_enabled(nagle));
            }
            O::KeepAlive(keep_alive) => {
                let interval = self.keep_alive_interval();
                let period = (*keep_alive).then_some(interval);
                self.with_smol_socket(|socket| socket.set_keep_alive(period));
            }
            O::TcpKeepIdle(keep_idle) => {
                let secs = in_range(*keep_idle, TCP_KEEPIDLE_MAX_SECS)?;
                self.keep_idle_secs.store(secs, Ordering::Relaxed);
                // Refresh the smoltcp period only if keep-alive is currently on.
                let period = Duration::from_secs(u64::from(secs));
                self.with_smol_socket(|socket| {
                    if socket.keep_alive().is_some() {
                        socket.set_keep_alive(Some(period));
                    }
                });
            }
            O::TcpKeepInterval(keep_interval) => {
                let secs = in_range(*keep_interval, TCP_KEEPINTVL_MAX_SECS)?;
                self.keep_interval_secs.store(secs, Ordering::Relaxed);
            }
            O::TcpKeepCount(keep_count) => {
                let count = in_range(*keep_count, TCP_KEEPCNT_MAX)?;
                self.keep_count.store(count, Ordering::Relaxed);
            }
            O::TcpUserTimeout(user_timeout) => {
                self.user_timeout_millis
                    .store(*user_timeout, Ordering::Relaxed);
            }
            _ => return Ok(false),
        }
        Ok(true)
    }
}
impl SocketOps for TcpSocket {
    /// Binds the socket to `local_addr`; the socket stays in `Idle`.
    ///
    /// Endpoint handling:
    /// - Port 0 is replaced by an ephemeral port.
    /// - An unspecified IP is stored as the wildcard (`addr: None`).
    /// - The endpoint is recorded in the TCP bound-port registry.
    /// - The device mask is derived from the endpoint.
    ///
    /// # Errors
    /// - `InvalidInput` if the socket is not `Idle` or already has a bound endpoint.
    /// - `AddrInUse` if no ephemeral port is free, the listen table rejects the
    ///   port (checked only without address reuse), or the registry has a
    ///   conflicting binding.
    fn bind(&self, local_addr: SocketAddrEx) -> AxResult {
        let mut local_addr = local_addr.into_ip()?;
        self.state
            .lock(State::Idle)
            .map_err(|_| ax_err_type!(InvalidInput, "already bound"))?
            .transit(State::Idle, || {
                if local_addr.port() == 0 {
                    local_addr.set_port(get_ephemeral_port()?);
                }
                if !self.general.reuse_address() && !LISTEN_TABLE.can_listen(local_addr.port()) {
                    return Err(AxError::AddrInUse);
                }
                let endpoint = IpListenEndpoint {
                    addr: if local_addr.ip().is_unspecified() {
                        None
                    } else {
                        Some(local_addr.ip().into())
                    },
                    port: local_addr.port(),
                };
                if self.bound_endpoint.lock().port != 0 {
                    return Err(AxError::InvalidInput);
                }
                self.register_bound_endpoint(endpoint)?;
                *self.bound_endpoint.lock() = endpoint;
                self.general
                    .set_device_mask(get_service().device_mask_for(&endpoint));
                debug!("TCP socket {}: binding to {}", self.handle, local_addr);
                Ok(())
            })
    }

    /// Starts a connection to `remote_addr`, then waits through
    /// `general.send_poller` until `poll_connect` reports `OUT`.
    ///
    /// Returns `Ok(())` once `Connected`, otherwise the recorded connect error.
    fn connect(&self, remote_addr: SocketAddrEx) -> AxResult {
        let remote_addr = remote_addr.into_ip()?;
        self.start_connect(remote_addr)?;
        ax_task::yield_now();
        self.general.send_poller(self, || {
            poll_interfaces();
            let events = self.poll_connect();
            if !events.contains(IoEvents::OUT) {
                Err(AxError::WouldBlock)
            } else if self.state() == State::Connected {
                Ok(())
            } else {
                Err(self.connect_error())
            }
        })
    }

    /// Moves an `Idle` socket to `Listening` on its bound endpoint.
    ///
    /// An unbound socket gets an ephemeral port. In any other state (e.g.
    /// already listening) the call is accepted as a no-op.
    fn listen(&self, backlog: usize) -> AxResult {
        let Ok(guard) = self.state.lock(State::Idle) else {
            return Ok(());
        };
        guard.transit(State::Listening, || {
            let mut bound_endpoint = *self.bound_endpoint.lock();
            if bound_endpoint.port == 0 {
                bound_endpoint.port = get_ephemeral_port()?;
            }
            let register_bound = !self.bound_registered.load(Ordering::Acquire);
            if register_bound {
                register_tcp_bound(bound_endpoint)?;
            }
            if let Err(err) = LISTEN_TABLE.listen(bound_endpoint, backlog) {
                // Roll back a registration made by this call only.
                if register_bound {
                    unregister_tcp_bound(bound_endpoint);
                }
                return Err(err);
            }
            *self.bound_endpoint.lock() = bound_endpoint;
            if register_bound {
                self.bound_registered.store(true, Ordering::Release);
            }
            self.general
                .set_device_mask(get_service().device_mask_for(&bound_endpoint));
            debug!("listening on {}", bound_endpoint);
            Ok(())
        })?;
        Ok(())
    }

    /// Accepts a pending connection from the listen table.
    ///
    /// Waits through `general.recv_poller` while none is pending. Fails with
    /// `InvalidInput` if the socket is not listening.
    fn accept(&self) -> AxResult<Socket> {
        if !self.is_listening() {
            ax_bail!(InvalidInput, "not listening");
        }
        let bound_port = self.bound_endpoint()?.port;
        self.general.recv_poller(self, || {
            poll_interfaces();
            let handle = {
                let sockets = SOCKET_SET.inner.lock();
                LISTEN_TABLE.accept(bound_port, &sockets)?
            };
            Ok({
                let socket = TcpSocket::new_connected(handle);
                debug!(
                    "accepted connection from {}, {}",
                    handle,
                    socket.with_smol_socket(|socket| socket.remote_endpoint().unwrap())
                );
                socket.into()
            })
        })
    }

    /// Copies bytes read from `src` into the smoltcp transmit buffer.
    ///
    /// `DONTWAIT` makes this call non-blocking. After a successful send the
    /// interfaces are polled once more.
    fn send(&self, mut src: impl Read, options: SendOptions) -> AxResult<usize> {
        let extra_nb = options.flags.contains(crate::SendFlags::DONTWAIT);
        let result = self.general.send_poller_with(self, extra_nb, || {
            poll_interfaces();
            self.with_smol_socket(|socket| {
                if !socket.is_active() {
                    Err(AxError::NotConnected)
                } else if !socket.can_send() {
                    Err(AxError::WouldBlock)
                } else {
                    let len = socket
                        .send(|buffer| {
                            let result = src.read(buffer);
                            let len = result.unwrap_or(0);
                            (len, result)
                        })
                        .map_err(|_| ax_err_type!(NotConnected, "not connected?"))??;
                    Ok(len)
                }
            })
        });
        if result.is_ok() {
            poll_interfaces();
        }
        result
    }

    /// Reads received bytes into `dst`.
    ///
    /// Behaviour:
    /// - Fails with `NotConnected` if the read half was shut down.
    /// - Returns `Ok(0)` once smoltcp will deliver no more data.
    /// - With `PEEK`, copies without consuming.
    /// - `DONTWAIT` makes this call non-blocking.
    fn recv(&self, mut dst: impl Write + IoBufMut, options: RecvOptions<'_>) -> AxResult<usize> {
        if self.rx_closed.load(Ordering::Acquire) {
            return Err(AxError::NotConnected);
        }
        let extra_nb = options.flags.contains(RecvFlags::DONTWAIT);
        self.general.recv_poller_with(self, extra_nb, || {
            poll_interfaces();
            self.with_smol_socket(|socket| {
                if !socket.is_active() {
                    Err(AxError::NotConnected)
                } else if !socket.may_recv() {
                    Ok(0)
                } else if socket.recv_queue() == 0 {
                    Err(AxError::WouldBlock)
                } else if options.flags.contains(RecvFlags::PEEK) {
                    dst.write(
                        socket
                            .peek(dst.remaining_mut())
                            .map_err(|_| ax_err_type!(NotConnected, "not connected?"))?,
                    )
                } else {
                    socket
                        .recv(|buf| {
                            let result = dst.write(buf);
                            let len = result.unwrap_or(0);
                            (len, result)
                        })
                        .map_err(|_| ax_err_type!(NotConnected, "not connected?"))?
                }
            })
        })
    }

    /// Local address.
    ///
    /// Prefers the smoltcp socket's local endpoint and falls back to the bound
    /// endpoint. A wildcard address is reported as `0.0.0.0`.
    fn local_addr(&self) -> AxResult<SocketAddrEx> {
        let endpoint = self.with_smol_socket(|socket| {
            socket
                .local_endpoint()
                .map(endpoint_from_ip_endpoint)
                .unwrap_or_else(|| *self.bound_endpoint.lock())
        });
        Ok(SocketAddrEx::Ip(SocketAddr::new(
            endpoint
                .addr
                .map_or_else(|| Ipv4Addr::UNSPECIFIED.into(), Into::into),
            endpoint.port,
        )))
    }

    /// Remote address, or `NotConnected` if smoltcp has none.
    fn peer_addr(&self) -> AxResult<SocketAddrEx> {
        self.with_smol_socket(|socket| {
            Ok(SocketAddrEx::Ip(
                socket
                    .remote_endpoint()
                    .ok_or(AxError::NotConnected)?
                    .into(),
            ))
        })
    }

    /// Shuts down the halves selected by `how`.
    ///
    /// - Read: marks the read half closed and wakes `RDHUP` waiters.
    /// - Write on a connected socket: closes the smoltcp socket, releases the
    ///   bound-port registration and moves to `Closed`.
    /// - Listening socket: stops listening and moves to `Closed`.
    fn shutdown(&self, how: Shutdown) -> AxResult {
        if how.has_read() {
            self.rx_closed.store(true, Ordering::Release);
            self.poll_rx_closed.wake();
        }
        // Only shutting down the write half ends a connected socket. A read-only
        // shutdown must leave it `Connected`. Otherwise a later write shutdown
        // (including the `Shutdown::Both` issued on drop) would no longer match
        // `State::Connected`, and `socket.close()` would never be called for
        // this connection.
        if how.has_write() {
            if let Ok(guard) = self.state.lock(State::Connected) {
                guard.transit(State::Closed, || {
                    self.with_smol_socket(|socket| {
                        debug!("TCP socket {}: shutting down", self.handle);
                        socket.close();
                    });
                    self.unregister_bound_endpoint();
                    *self.bound_endpoint.lock() = empty_endpoint();
                    poll_interfaces();
                    Ok(())
                })?;
            }
        }
        if let Ok(guard) = self.state.lock(State::Listening) {
            guard.transit(State::Closed, || {
                LISTEN_TABLE.unlisten(self.bound_endpoint()?.port);
                self.unregister_bound_endpoint();
                *self.bound_endpoint.lock() = empty_endpoint();
                poll_interfaces();
                Ok(())
            })?;
        }
        Ok(())
    }
}
impl Pollable for TcpSocket {
    /// Polls the interfaces once, then reports readiness for the current state.
    ///
    /// `RDHUP` reflects whether the read half has been shut down, whatever the
    /// state.
    fn poll(&self) -> IoEvents {
        poll_interfaces();
        let mut ready = match self.state() {
            State::Busy => IoEvents::empty(),
            State::Listening => self.poll_listener(),
            State::Connecting => self.poll_connect(),
            State::Idle | State::Connected | State::Closed => self.poll_stream(),
        };
        let read_half_closed = self.rx_closed.load(Ordering::Acquire);
        ready.set(IoEvents::RDHUP, read_half_closed);
        ready
    }

    /// Registers the caller's waker for the requested `events`.
    ///
    /// Any interest in IN, OUT or RDHUP goes to the general waker list. RDHUP
    /// interest is additionally registered with the read-shutdown poll set.
    fn register(&self, context: &mut Context<'_>, events: IoEvents) {
        let waker = context.waker();
        let stream_interest = IoEvents::IN | IoEvents::OUT | IoEvents::RDHUP;
        if events.intersects(stream_interest) {
            self.general.register_waker(waker);
        }
        if events.contains(IoEvents::RDHUP) {
            self.poll_rx_closed.register(waker);
        }
    }
}
impl Drop for TcpSocket {
    /// Tears the socket down in four steps:
    /// 1. Shuts down both halves; a failure is logged, not propagated.
    /// 2. Releases any bound-port registration.
    /// 3. Removes the smoltcp socket from `SOCKET_SET`.
    /// 4. Polls the interfaces once more.
    fn drop(&mut self) {
        let handle = self.handle;
        if let Err(reason) = self.shutdown(Shutdown::Both) {
            warn!("TCP socket {}: shutdown failed: {}", handle, reason);
        }
        self.unregister_bound_endpoint();
        SOCKET_SET.remove(handle);
        poll_interfaces();
    }
}
/// The "not bound" endpoint: wildcard address with port 0.
const fn empty_endpoint() -> IpListenEndpoint {
    IpListenEndpoint { port: 0, addr: None }
}
fn endpoint_from_ip_endpoint(endpoint: IpEndpoint) -> IpListenEndpoint {
IpListenEndpoint {
addr: Some(endpoint.addr),
port: endpoint.port,
}
}
fn socket_bound_endpoint(socket: &smol::Socket<'_>) -> IpListenEndpoint {
socket
.local_endpoint()
.map(endpoint_from_ip_endpoint)
.unwrap_or_else(|| socket.listen_endpoint())
}
impl TcpSocket {
fn start_connect(&self, remote_addr: SocketAddr) -> AxResult {
self.state
.lock(State::Idle)
.map_err(|state| {
if state == State::Connecting {
AxError::InProgress
} else {
ax_err_type!(AlreadyConnected)
}
})?
.transit(State::Connecting, || {
self.clear_pending_error();
let remote_endpoint = IpEndpoint::from(remote_addr);
let mut bound_endpoint = *self.bound_endpoint.lock();
if bound_endpoint.addr.is_none() {
bound_endpoint.addr =
Some(get_service().get_source_address(&remote_endpoint.addr));
}
if bound_endpoint.port == 0 {
bound_endpoint.port = get_ephemeral_port()?;
}
info!(
"TCP connection from {} to {}",
bound_endpoint, remote_endpoint
);
let register_bound = !self.bound_registered.load(Ordering::Acquire);
if register_bound {
register_tcp_bound(bound_endpoint)?;
}
let result = {
let mut service = get_service();
let context = service.iface.context();
self.with_smol_socket(|socket| {
socket
.connect(context, remote_endpoint, bound_endpoint)
.map_err(|e| match e {
smol::ConnectError::InvalidState => {
ax_err_type!(AlreadyConnected)
}
smol::ConnectError::Unaddressable => {
ax_err_type!(ConnectionRefused, "unaddressable")
}
})?;
Ok::<(), AxError>(())
})
};
if let Err(err) = result {
if register_bound {
unregister_tcp_bound(bound_endpoint);
}
return Err(err);
}
*self.bound_endpoint.lock() = bound_endpoint;
if register_bound {
self.bound_registered.store(true, Ordering::Release);
}
self.general.set_device_mask(
get_service().device_mask_for(&endpoint_from_ip_endpoint(remote_endpoint)),
);
Ok(())
})
}
fn register_bound_endpoint(&self, endpoint: IpListenEndpoint) -> AxResult {
if !self.bound_registered.load(Ordering::Acquire) {
register_tcp_bound(endpoint)?;
self.bound_registered.store(true, Ordering::Release);
}
Ok(())
}
fn unregister_bound_endpoint(&self) {
if self.bound_registered.swap(false, Ordering::AcqRel) {
unregister_tcp_bound(*self.bound_endpoint.lock());
}
}
}
/// Registry of locally bound TCP endpoints, keyed by port.
///
/// Each entry lists the addresses bound on that port; `None` is the wildcard
/// address, which conflicts with every other binding on the same port (see
/// `listen_addrs_conflict`). Ports with no bindings have no entry.
static TCP_BOUND_PORTS: Lazy<Mutex<HashMap<u16, Vec<Option<smoltcp::wire::IpAddress>>>>> =
    Lazy::new(|| Mutex::new(HashMap::new()));
/// Reserves `endpoint` in `TCP_BOUND_PORTS`.
///
/// Port 0 means "not bound" and reserves nothing. Fails with `AddrInUse` if an
/// existing binding on the same port conflicts with `endpoint.addr`.
fn register_tcp_bound(endpoint: IpListenEndpoint) -> AxResult {
    if endpoint.port == 0 {
        return Ok(());
    }
    let mut registry = TCP_BOUND_PORTS.lock();
    let addrs = registry.entry(endpoint.port).or_default();
    let clash = addrs
        .iter()
        .copied()
        .any(|existing| listen_addrs_conflict(existing, endpoint.addr));
    if clash {
        return Err(AxError::AddrInUse);
    }
    addrs.push(endpoint.addr);
    Ok(())
}
/// Removes one record matching `endpoint` from `TCP_BOUND_PORTS`.
///
/// The port's entry is dropped once it has no bindings left. Port 0 and unknown
/// endpoints are ignored.
fn unregister_tcp_bound(endpoint: IpListenEndpoint) {
    if endpoint.port == 0 {
        return;
    }
    let mut registry = TCP_BOUND_PORTS.lock();
    let Some(addrs) = registry.get_mut(&endpoint.port) else {
        return;
    };
    if let Some(idx) = addrs.iter().position(|addr| *addr == endpoint.addr) {
        addrs.swap_remove(idx);
    }
    if addrs.is_empty() {
        registry.remove(&endpoint.port);
    }
}
/// A port is available when the listen table accepts it and no TCP socket has
/// it registered as bound.
fn tcp_port_available(port: u16) -> bool {
    if !LISTEN_TABLE.can_listen(port) {
        return false;
    }
    !TCP_BOUND_PORTS.lock().contains_key(&port)
}
/// Two bindings on the same port conflict unless both name specific, different
/// addresses. A wildcard (`None`) conflicts with everything.
fn listen_addrs_conflict(
    a: Option<smoltcp::wire::IpAddress>,
    b: Option<smoltcp::wire::IpAddress>,
) -> bool {
    match (a, b) {
        (Some(left), Some(right)) => left == right,
        _ => true,
    }
}
/// Picks a free port from the range `0xc000..=0xffff`.
///
/// A process-wide cursor advances past each port handed out or rejected, wrapping
/// at the end of the range. Every port is tried at most once per call; if none is
/// available, fails with `AddrInUse`.
fn get_ephemeral_port() -> AxResult<u16> {
    const PORT_START: u16 = 0xc000;
    const PORT_END: u16 = 0xffff;
    static CURR: Mutex<u16> = Mutex::new(PORT_START);

    let mut cursor = CURR.lock();
    // One iteration per port in the range.
    for _ in PORT_START..=PORT_END {
        let candidate = *cursor;
        *cursor = if candidate == PORT_END {
            PORT_START
        } else {
            candidate + 1
        };
        if tcp_port_available(candidate) {
            return Ok(candidate);
        }
    }
    ax_bail!(AddrInUse, "no available ports");
}
#[cfg(test)]
mod tests {
    use core::net::{IpAddr, SocketAddr};

    use super::*;
    use crate::{
        options::{Configurable, SetSocketOption},
        test_support::{LOCAL_ADDR, LOCAL_MASK, PEER_ADDR, PEER_MASK, init_split_route_network},
    };

    /// After `bind` the device mask follows the local address. Once a connect is
    /// started it must follow the route towards the peer instead.
    #[test]
    fn connect_uses_peer_route_for_device_mask() {
        init_split_route_network();

        let local = SocketAddr::new(IpAddr::V4(LOCAL_ADDR), 0);
        let peer = SocketAddr::new(IpAddr::V4(PEER_ADDR), 80);

        let tcp = TcpSocket::new();
        tcp.set_option(SetSocketOption::NonBlocking(&true)).unwrap();

        tcp.bind(SocketAddrEx::Ip(local)).unwrap();
        assert_eq!(tcp.general.device_mask(), LOCAL_MASK);

        tcp.start_connect(peer).unwrap();
        assert_eq!(tcp.general.device_mask(), PEER_MASK);
    }
}