use alloc::vec;
use core::{
net::{IpAddr, Ipv4Addr, SocketAddr},
task::Context,
};
use ax_errno::{AxError, AxResult, ax_bail, ax_err_type};
use ax_io::prelude::*;
use ax_sync::Mutex;
use axpoll::{IoEvents, Pollable};
use smoltcp::{
iface::SocketHandle,
phy::PacketMeta,
socket::udp::{self as smol, UdpMetadata},
storage::PacketMetadata,
wire::{IpAddress, IpEndpoint, IpListenEndpoint},
};
use spin::RwLock;
use crate::{
RecvFlags, RecvOptions, SOCKET_SET, SendOptions, Shutdown, SocketAddrEx, SocketOps,
consts::{UDP_RX_BUF_LEN, UDP_TX_BUF_LEN},
general::GeneralOptions,
get_service,
options::{Configurable, GetSocketOption, SetSocketOption},
poll_interfaces,
};
/// Builds a fresh smoltcp UDP socket with heap-allocated receive and
/// transmit packet buffers.
///
/// Each buffer has 256 packet-metadata slots; the payload storage is
/// `UDP_RX_BUF_LEN` bytes for receive and `UDP_TX_BUF_LEN` bytes for transmit.
pub(crate) fn new_udp_socket() -> smol::Socket<'static> {
    /// Number of packet metadata slots in each direction.
    const METADATA_SLOTS: usize = 256;

    let rx_ring = smol::PacketBuffer::new(
        vec![PacketMetadata::EMPTY; METADATA_SLOTS],
        vec![0; UDP_RX_BUF_LEN],
    );
    let tx_ring = smol::PacketBuffer::new(
        vec![PacketMetadata::EMPTY; METADATA_SLOTS],
        vec![0; UDP_TX_BUF_LEN],
    );
    smol::Socket::new(rx_ring, tx_ring)
}
/// A UDP socket backed by a smoltcp UDP socket registered in `SOCKET_SET`.
pub struct UdpSocket {
/// Handle of the underlying smoltcp socket inside `SOCKET_SET`.
handle: SocketHandle,
/// Local endpoint recorded by `bind`; `None` until the socket is bound.
local_addr: RwLock<Option<IpEndpoint>>,
/// Remote endpoint recorded by `connect`, paired with the source address
/// selected for reaching it; `None` until the socket is connected.
peer_addr: RwLock<Option<(IpEndpoint, IpAddress)>>,
/// Shared option state; also provides the send/recv pollers and waker
/// registration used by this socket.
general: GeneralOptions,
}
impl UdpSocket {
    /// Creates a new UDP socket that is neither bound nor connected.
    ///
    /// The underlying smoltcp socket is added to the global `SOCKET_SET`
    /// and referenced through the returned handle.
    #[allow(clippy::new_without_default)]
    pub fn new() -> Self {
        Self {
            handle: SOCKET_SET.add(new_udp_socket()),
            local_addr: RwLock::new(None),
            peer_addr: RwLock::new(None),
            general: GeneralOptions::new(),
        }
    }

    /// Runs `f` with mutable access to the underlying smoltcp socket and
    /// returns whatever `f` returns.
    fn with_smol_socket<R>(&self, f: impl FnOnce(&mut smol::Socket) -> R) -> R {
        SOCKET_SET.with_socket_mut::<smol::Socket, _, _>(self.handle, f)
    }

    /// Returns the connected peer endpoint together with the source address
    /// recorded for it.
    ///
    /// Yields `AxError::NotConnected` when no peer has been set, and also
    /// when the lock cannot be taken immediately (`try_read` fails).
    fn remote_endpoint(&self) -> AxResult<(IpEndpoint, IpAddress)> {
        self.peer_addr
            .try_read()
            .and_then(|peer| *peer)
            .ok_or(AxError::NotConnected)
    }
}
impl Configurable for UdpSocket {
fn get_option_inner(&self, option: &mut GetSocketOption) -> AxResult<bool> {
use GetSocketOption as O;
if self.general.get_option_inner(option)? {
return Ok(true);
}
match option {
O::Ttl(ttl) => {
self.with_smol_socket(|socket| {
**ttl = socket.hop_limit().unwrap_or(64);
});
}
O::SendBuffer(size) => {
**size = UDP_TX_BUF_LEN;
}
O::ReceiveBuffer(size) => {
**size = UDP_RX_BUF_LEN;
}
_ => return Ok(false),
}
Ok(true)
}
fn set_option_inner(&self, option: SetSocketOption) -> AxResult<bool> {
use SetSocketOption as O;
if self.general.set_option_inner(option)? {
return Ok(true);
}
match option {
O::Ttl(ttl) => {
self.with_smol_socket(|socket| {
socket.set_hop_limit(Some(*ttl));
});
}
_ => return Ok(false),
}
Ok(true)
}
}
impl SocketOps for UdpSocket {
fn bind(&self, local_addr: SocketAddrEx) -> AxResult {
let mut local_addr = local_addr.into_ip()?;
let mut guard = self.local_addr.write();
if local_addr.port() == 0 {
local_addr.set_port(get_ephemeral_port()?);
}
if guard.is_some() {
ax_bail!(InvalidInput, "already bound");
}
let local_endpoint = IpEndpoint::from(local_addr);
let endpoint = IpListenEndpoint {
addr: (!local_endpoint.addr.is_unspecified()).then_some(local_endpoint.addr),
port: local_endpoint.port,
};
if !self.general.reuse_address() {
SOCKET_SET.bind_check(local_endpoint.addr, local_endpoint.port)?;
}
self.with_smol_socket(|socket| {
socket.bind(endpoint).map_err(|e| match e {
smol::BindError::InvalidState => ax_err_type!(InvalidInput, "already bound"),
smol::BindError::Unaddressable => ax_err_type!(ConnectionRefused, "unaddressable"),
})
})?;
self.general
.set_device_mask(get_service().device_mask_for(&endpoint));
*guard = Some(local_endpoint);
info!("UDP socket {}: bound on {}", self.handle, endpoint);
Ok(())
}
fn connect(&self, remote_addr: SocketAddrEx) -> AxResult {
let remote_addr = remote_addr.into_ip()?;
let mut guard = self.peer_addr.write();
if self.local_addr.read().is_none() {
self.bind(SocketAddrEx::Ip(SocketAddr::new(
IpAddr::V4(Ipv4Addr::UNSPECIFIED),
0,
)))?;
}
let remote_addr = IpEndpoint::from(remote_addr);
let src = get_service().get_source_address(&remote_addr.addr);
*guard = Some((remote_addr, src));
debug!("UDP socket {}: connected to {}", self.handle, remote_addr);
Ok(())
}
fn send(&self, mut src: impl Read + IoBuf, options: SendOptions) -> AxResult<usize> {
let (remote_addr, source_addr) = match options.to {
Some(addr) => {
let addr = IpEndpoint::from(addr.into_ip()?);
let src = get_service().get_source_address(&addr.addr);
(addr, src)
}
None => self.remote_endpoint()?,
};
if remote_addr.port == 0 || remote_addr.addr.is_unspecified() {
ax_bail!(InvalidInput, "invalid address");
}
if self.local_addr.read().is_none() {
self.bind(SocketAddrEx::Ip(SocketAddr::new(
IpAddr::V4(Ipv4Addr::UNSPECIFIED),
0,
)))?;
}
self.general.send_poller(self, || {
poll_interfaces();
self.with_smol_socket(|socket| {
if !socket.is_open() {
Err(ax_err_type!(NotConnected))
} else if !socket.can_send() {
Err(AxError::WouldBlock)
} else {
let buf = socket
.send(
src.remaining(),
UdpMetadata {
endpoint: remote_addr,
local_address: Some(source_addr),
meta: PacketMeta::default(),
},
)
.map_err(|e| match e {
smol::SendError::BufferFull => AxError::WouldBlock,
smol::SendError::Unaddressable => {
ax_err_type!(ConnectionRefused, "unaddressable")
}
})?;
let read = src.read(buf)?;
assert_eq!(read, buf.len());
Ok(read)
}
})
})
}
fn recv(&self, mut dst: impl Write, options: RecvOptions) -> AxResult<usize> {
if self.local_addr.read().is_none() {
ax_bail!(NotConnected);
}
enum ExpectedRemote<'a> {
Any(&'a mut SocketAddrEx),
Expecting(IpEndpoint),
}
let mut expected_remote = match options.from {
Some(addr) => ExpectedRemote::Any(addr),
None => ExpectedRemote::Expecting(self.remote_endpoint()?.0),
};
self.general.recv_poller(self, || {
poll_interfaces();
self.with_smol_socket(|socket| {
if !socket.is_open() {
Err(ax_err_type!(NotConnected))
} else if !socket.can_recv() {
Err(AxError::WouldBlock)
} else {
let result = if options.flags.contains(RecvFlags::PEEK) {
socket.peek().map(|(data, meta)| (data, *meta))
} else {
socket.recv()
};
match result {
Ok((src, meta)) => {
match &mut expected_remote {
ExpectedRemote::Any(remote_addr) => {
**remote_addr = SocketAddrEx::Ip(meta.endpoint.into());
}
ExpectedRemote::Expecting(expected) => {
if (!expected.addr.is_unspecified()
&& expected.addr != meta.endpoint.addr)
|| (expected.port != 0
&& expected.port != meta.endpoint.port)
{
return Err(AxError::WouldBlock);
}
}
}
let read = dst.write(src)?;
if read < src.len() {
warn!("UDP message truncated: {} -> {} bytes", src.len(), read);
}
Ok(if options.flags.contains(RecvFlags::TRUNCATE) {
src.len()
} else {
read
})
}
Err(smol::RecvError::Exhausted) => Err(AxError::WouldBlock),
Err(smol::RecvError::Truncated) => {
unreachable!("UDP socket recv never returns Err(Truncated)")
}
}
}
})
})
}
fn local_addr(&self) -> AxResult<SocketAddrEx> {
match self.local_addr.try_read() {
Some(addr) => addr
.map(Into::into)
.map(SocketAddrEx::Ip)
.ok_or(AxError::NotConnected),
None => Err(AxError::NotConnected),
}
}
fn peer_addr(&self) -> AxResult<SocketAddrEx> {
self.remote_endpoint()
.map(|it| it.0.into())
.map(SocketAddrEx::Ip)
}
fn shutdown(&self, _how: Shutdown) -> AxResult {
poll_interfaces();
self.with_smol_socket(|socket| {
debug!("UDP socket {}: shutting down", self.handle);
socket.close();
});
Ok(())
}
}
impl Pollable for UdpSocket {
    /// Reports the current readiness of the socket.
    ///
    /// An unbound socket reports no events at all; a bound socket reports
    /// `IN` when smoltcp can hand out a datagram and `OUT` when it can accept
    /// one for transmission.
    fn poll(&self) -> IoEvents {
        poll_interfaces();

        // NOTE(review): an unbound socket never reports `OUT` here even
        // though `send` would bind it on demand; confirm this is intended.
        let bound = self.local_addr.read().is_some();
        if !bound {
            return IoEvents::empty();
        }

        self.with_smol_socket(|socket| {
            let mut ready = IoEvents::empty();
            ready.set(IoEvents::IN, socket.can_recv());
            ready.set(IoEvents::OUT, socket.can_send());
            ready
        })
    }

    /// Registers the task's waker when the caller is interested in
    /// readability or writability; other event kinds are ignored.
    fn register(&self, context: &mut Context<'_>, events: IoEvents) {
        let interesting = IoEvents::IN | IoEvents::OUT;
        if !events.intersects(interesting) {
            return;
        }
        self.general.register_waker(context.waker());
    }
}
impl Drop for UdpSocket {
    /// Closes the socket and removes it from the global socket set.
    fn drop(&mut self) {
        // Best effort: a failure to shut down must not prevent the removal
        // of the socket from `SOCKET_SET`.
        let _ = self.shutdown(Shutdown::Both);
        SOCKET_SET.remove(self.handle);
    }
}
/// Hands out the next port from the ephemeral range `0xc000..=0xffff`.
///
/// Ports are issued sequentially from a process-wide counter that wraps back
/// to the start of the range after the last port. The function does not
/// check whether the returned port is currently in use, and it always
/// returns `Ok`.
fn get_ephemeral_port() -> AxResult<u16> {
    const PORT_START: u16 = 0xc000;
    const PORT_END: u16 = 0xffff;
    static CURR: Mutex<u16> = Mutex::new(PORT_START);

    let mut next = CURR.lock();
    let allocated = *next;
    // Advance the counter, wrapping around at the end of the range.
    *next = match allocated {
        PORT_END => PORT_START,
        in_range => in_range + 1,
    };
    Ok(allocated)
}