use ax_errno::{AxError, AxResult, LinuxError};
use axnet::options::{
Configurable, GetSocketOption, SetSocketOption, TcpInfo, TcpInfoOptions, TcpState,
};
use linux_raw_sys::net::{
IPPROTO_IPV6, IPV6_V6ONLY, TCP_INFO, TCPI_OPT_ECN, TCPI_OPT_ECN_SEEN, TCPI_OPT_SACK,
TCPI_OPT_SYN_DATA, TCPI_OPT_TIMESTAMPS, TCPI_OPT_WSCALE, socklen_t, tcp_info,
};
use starry_vm::vm_write_slice;
use crate::{
file::{FileLike, Socket, netlink::NetlinkSocket},
mm::{UserConstPtr, UserPtr},
};
const PROTO_TCP: u32 = linux_raw_sys::net::IPPROTO_TCP as u32;
const PROTO_IP: u32 = linux_raw_sys::net::IPPROTO_IP as u32;
fn read_int_sockopt(optval: UserConstPtr<u8>, optlen: socklen_t) -> AxResult<i32> {
if (optlen as usize) < size_of::<i32>() {
return Err(AxError::InvalidInput);
}
Ok(*optval.cast::<i32>().get_as_ref()?)
}
fn tcp_state_to_linux(state: TcpState) -> u8 {
match state {
TcpState::Established => 1,
TcpState::SynSent => 2,
TcpState::SynReceived => 3,
TcpState::FinWait1 => 4,
TcpState::FinWait2 => 5,
TcpState::TimeWait => 6,
TcpState::Closed => 7,
TcpState::CloseWait => 8,
TcpState::LastAck => 9,
TcpState::Listen => 10,
TcpState::Closing => 11,
}
}
fn tcp_options_to_linux(options: TcpInfoOptions) -> u8 {
let mut raw = 0u8;
if options.contains(TcpInfoOptions::TIMESTAMPS) {
raw |= TCPI_OPT_TIMESTAMPS as u8;
}
if options.contains(TcpInfoOptions::SACK) {
raw |= TCPI_OPT_SACK as u8;
}
if options.contains(TcpInfoOptions::WSCALE) {
raw |= TCPI_OPT_WSCALE as u8;
}
if options.contains(TcpInfoOptions::ECN) {
raw |= TCPI_OPT_ECN as u8;
}
if options.contains(TcpInfoOptions::ECN_SEEN) {
raw |= TCPI_OPT_ECN_SEEN as u8;
}
if options.contains(TcpInfoOptions::SYN_DATA) {
raw |= TCPI_OPT_SYN_DATA as u8;
}
raw
}
fn to_linux_tcp_info(info: TcpInfo) -> tcp_info {
let mut raw = unsafe { core::mem::zeroed::<tcp_info>() };
raw.tcpi_state = tcp_state_to_linux(info.state);
raw.tcpi_ca_state = info.ca_state;
raw.tcpi_retransmits = info.retransmits;
raw.tcpi_probes = info.probes;
raw.tcpi_backoff = info.backoff;
raw.tcpi_options = tcp_options_to_linux(info.options);
raw.set_tcpi_snd_wscale(info.snd_wscale);
raw.set_tcpi_rcv_wscale(info.rcv_wscale);
raw.tcpi_rto = info.rto_micros;
raw.tcpi_ato = info.ato_micros;
raw.tcpi_snd_mss = info.snd_mss;
raw.tcpi_rcv_mss = info.rcv_mss;
raw.tcpi_notsent_bytes = info.notsent_bytes;
raw.tcpi_pmtu = info.pmtu;
raw.tcpi_advmss = info.advmss;
raw.tcpi_snd_cwnd = info.snd_cwnd;
raw.tcpi_reordering = info.reordering;
raw.tcpi_rcv_space = info.rcv_space;
raw.tcpi_snd_wnd = info.snd_wnd;
raw.tcpi_rcv_wnd = info.rcv_wnd;
raw
}
fn write_tcp_info(socket: &Socket, optval: UserPtr<u8>, optlen: &mut socklen_t) -> AxResult<()> {
let mut info = TcpInfo::default();
socket.get_option(GetSocketOption::TcpInfo(&mut info))?;
let write_len = (*optlen as usize).min(size_of::<tcp_info>());
*optlen = write_len as socklen_t;
if write_len == 0 {
return Ok(());
}
let raw = to_linux_tcp_info(info);
let raw_bytes = unsafe {
core::slice::from_raw_parts(
(&raw as *const tcp_info).cast::<u8>(),
size_of::<tcp_info>(),
)
};
Ok(vm_write_slice(optval.as_ptr(), &raw_bytes[..write_len])?)
}
mod conv {
use ax_errno::{AxError, AxResult};
use axnet::options::UnixCredentials;
use linux_raw_sys::{general::timeval, net::ucred};
use crate::time::TimeValueLike;
pub struct Int<T>(T);
impl<T: TryFrom<i32> + TryInto<i32>> Int<T> {
pub fn sys_to_rust(val: i32) -> AxResult<T> {
T::try_from(val).map_err(|_| AxError::InvalidInput)
}
pub fn rust_to_sys(val: T) -> AxResult<i32> {
val.try_into().map_err(|_| AxError::InvalidInput)
}
}
pub struct IntBool;
impl IntBool {
pub fn sys_to_rust(val: i32) -> AxResult<bool> {
Ok(val != 0)
}
pub fn rust_to_sys(val: bool) -> AxResult<i32> {
Ok(val as _)
}
}
pub struct Duration;
impl Duration {
pub fn sys_to_rust(val: timeval) -> AxResult<core::time::Duration> {
val.try_into_time_value()
}
pub fn rust_to_sys(val: core::time::Duration) -> AxResult<timeval> {
Ok(timeval::from_time_value(val))
}
}
pub struct Ucred;
impl Ucred {
pub fn sys_to_rust(val: ucred) -> AxResult<UnixCredentials> {
Ok(UnixCredentials {
pid: val.pid,
uid: val.uid,
gid: val.gid,
})
}
pub fn rust_to_sys(val: UnixCredentials) -> AxResult<ucred> {
Ok(ucred {
pid: val.pid,
uid: val.uid,
gid: val.gid,
})
}
}
}
macro_rules! call_dispatch {
($dispatch:ident, $pat:expr) => {{
use conv::*;
use linux_raw_sys::net::*;
call_dispatch! {
$dispatch, $pat,
(SOL_SOCKET, SO_REUSEADDR) => ReuseAddress as IntBool,
(SOL_SOCKET, SO_ERROR) => Error,
(SOL_SOCKET, SO_DONTROUTE) => DontRoute as IntBool, (SOL_SOCKET, SO_SNDBUF) => SendBuffer as Int<usize>, (SOL_SOCKET, SO_RCVBUF) => ReceiveBuffer as Int<usize>, (SOL_SOCKET, SO_KEEPALIVE) => KeepAlive as IntBool,
(SOL_SOCKET, SO_RCVTIMEO) => ReceiveTimeout as Duration,
(SOL_SOCKET, SO_SNDTIMEO) => SendTimeout as Duration,
(SOL_SOCKET, SO_PASSCRED) => PassCredentials as IntBool, (SOL_SOCKET, SO_PEERCRED) => PeerCredentials as Ucred,
(SOL_SOCKET, SO_TYPE) => SocketType as Int<i32>, (SOL_SOCKET, SO_PROTOCOL) => SocketProtocol as Int<i32>, (SOL_SOCKET, SO_DOMAIN) => SocketDomain as Int<i32>,
(PROTO_TCP, TCP_NODELAY) => NoDelay as IntBool,
(PROTO_TCP, TCP_MAXSEG) => MaxSegment as Int<usize>, (PROTO_TCP, TCP_KEEPIDLE) => TcpKeepIdle as Int<u32>,
(PROTO_TCP, TCP_KEEPINTVL) => TcpKeepInterval as Int<u32>,
(PROTO_TCP, TCP_KEEPCNT) => TcpKeepCount as Int<u32>,
(PROTO_TCP, TCP_USER_TIMEOUT) => TcpUserTimeout as Int<u32>,
(PROTO_IP, IP_TTL) => Ttl as Int<u8>,
(PROTO_IP, IP_RECVERR) => RecvErr as IntBool, }
}};
($dispatch:ident, $in:expr, $($pat:pat => $which:ident $(as $conv:ty)?),* $(,)?) => {
match $in {
$(
$pat => {
dispatch!($which $(as $conv)?);
}
)*
_ => return Err(AxError::from(LinuxError::ENOPROTOOPT)),
}
}
}
pub fn sys_getsockopt(
fd: i32,
level: u32,
optname: u32,
optval: UserPtr<u8>,
optlen: UserPtr<socklen_t>,
) -> AxResult<isize> {
let optlen = optlen.get_as_mut()?;
debug!(
"sys_getsockopt <= fd: {}, level: {}, optname: {}, optval: {:?}, optlen: {}",
fd,
level,
optname,
optval.address(),
optlen,
);
fn get<'a, T: 'static>(val: UserPtr<u8>, len: &mut socklen_t) -> AxResult<&'a mut T> {
if (*len as usize) < size_of::<T>() {
return Err(AxError::InvalidInput);
}
*len = size_of::<T>() as socklen_t;
val.cast().get_as_mut()
}
let socket = Socket::from_fd(fd)?;
{
use axnet::Socket as SocketInner;
use linux_raw_sys::net::{SO_TYPE, SOCK_DGRAM, SOCK_RAW, SOCK_STREAM, SOL_SOCKET};
if level == SOL_SOCKET && optname == SO_TYPE {
if *optlen == 0 {
return Ok(0);
}
let so_type = match &**socket {
SocketInner::Tcp(_) => SOCK_STREAM,
SocketInner::Udp(_) => SOCK_DGRAM,
SocketInner::Raw(_) => SOCK_RAW,
SocketInner::Unix(_) => SOCK_STREAM,
#[cfg(feature = "vsock")]
SocketInner::Vsock(_) => SOCK_STREAM,
};
*get(optval, optlen)? = so_type as i32;
return Ok(0);
}
}
if level == IPPROTO_IPV6 as u32 && optname == IPV6_V6ONLY {
*get::<i32>(optval, optlen)? = 0;
return Ok(0);
}
if level == PROTO_TCP && optname == TCP_INFO {
write_tcp_info(&socket, optval, optlen)?;
return Ok(0);
}
macro_rules! dispatch {
($which:ident) => {
socket.get_option(GetSocketOption::$which(get(optval, optlen)?))?;
};
($which:ident as $conv:ty) => {
let mut val = Default::default();
socket.get_option(GetSocketOption::$which(&mut val))?;
*get(optval, optlen)? = <$conv>::rust_to_sys(val)?;
};
}
call_dispatch!(dispatch, (level, optname));
Ok(0)
}
pub fn sys_setsockopt(
fd: i32,
level: u32,
optname: u32,
optval: UserConstPtr<u8>,
optlen: socklen_t,
) -> AxResult<isize> {
debug!(
"sys_setsockopt <= fd: {}, level: {}, optname: {}, optval: {:?}, optlen: {}",
fd,
level,
optname,
optval.address(),
optlen
);
if let Ok(socket) = NetlinkSocket::from_fd(fd) {
use linux_raw_sys::net::{
SO_ATTACH_FILTER, SO_LOCK_FILTER, SO_PASSCRED, SO_RCVBUF, SO_RCVBUFFORCE, SOL_SOCKET,
};
match (level, optname) {
(SOL_SOCKET, SO_ATTACH_FILTER | SO_LOCK_FILTER) => {
return Ok(0);
}
(SOL_SOCKET, SO_RCVBUF | SO_RCVBUFFORCE) => {
let value = read_int_sockopt(optval, optlen)?;
socket.set_receive_buffer_size(value.max(0) as usize);
return Ok(0);
}
(SOL_SOCKET, SO_PASSCRED) => {
let value = read_int_sockopt(optval, optlen)?;
socket.set_passcred(value != 0);
return Ok(0);
}
_ => return Err(AxError::from(LinuxError::ENOPROTOOPT)),
}
}
{
use linux_raw_sys::net::{SO_BROADCAST, SOL_SOCKET};
if (level, optname) == (SOL_SOCKET, SO_BROADCAST) {
let _ = read_int_sockopt(optval, optlen)?;
return Ok(0);
}
}
fn get<'a, T: 'static>(val: UserConstPtr<u8>, len: socklen_t) -> AxResult<&'a T> {
if len as usize != size_of::<T>() {
return Err(AxError::InvalidInput);
}
val.cast().get_as_ref()
}
let socket = Socket::from_fd(fd)?;
if level == PROTO_TCP && optname == TCP_INFO {
return Err(AxError::from(LinuxError::ENOPROTOOPT));
}
if level == IPPROTO_IPV6 as u32 && optname == IPV6_V6ONLY {
let _ = *get::<i32>(optval, optlen)?;
return Ok(0);
}
macro_rules! dispatch {
($which:ident) => {
socket.set_option(SetSocketOption::$which(get(optval, optlen)?))?;
};
($which:ident as $conv:ty) => {
let mut val = <$conv>::sys_to_rust(*get(optval, optlen)?)?;
socket.set_option(SetSocketOption::$which(&mut val))?;
};
}
call_dispatch!(dispatch, (level, optname));
Ok(0)
}