use alloc::{boxed::Box, vec::Vec};
use core::{
any::Any,
fmt::{self, Debug},
net::SocketAddr,
task::Context,
};
#[cfg(feature = "vsock")]
use ax_driver::prelude::VsockAddr;
use ax_errno::{AxError, AxResult, LinuxError};
use ax_io::prelude::*;
use axpoll::{IoEvents, Pollable};
use bitflags::bitflags;
use enum_dispatch::enum_dispatch;
#[cfg(feature = "vsock")]
use crate::vsock::VsockSocket;
use crate::{
options::{Configurable, GetSocketOption, SetSocketOption},
tcp::TcpSocket,
udp::UdpSocket,
unix::{UnixSocket, UnixSocketAddr},
};
/// A socket address belonging to any of the address families this module
/// handles.
#[derive(Clone, Debug)]
pub enum SocketAddrEx {
/// An IP socket address ([`SocketAddr`]).
Ip(SocketAddr),
/// A Unix socket address ([`UnixSocketAddr`]).
Unix(UnixSocketAddr),
/// A vsock address; this variant only exists when the `vsock` feature is
/// enabled.
#[cfg(feature = "vsock")]
Vsock(VsockAddr),
}
impl SocketAddrEx {
    /// Consumes the address and returns the wrapped IP socket address.
    ///
    /// # Errors
    ///
    /// Returns `LinuxError::EAFNOSUPPORT` (converted into [`AxError`]) when the
    /// address is not the `Ip` variant.
    pub fn into_ip(self) -> AxResult<SocketAddr> {
        if let Self::Ip(ip) = self {
            Ok(ip)
        } else {
            Err(AxError::from(LinuxError::EAFNOSUPPORT))
        }
    }

    /// Consumes the address and returns the wrapped Unix socket address.
    ///
    /// # Errors
    ///
    /// Returns `LinuxError::EAFNOSUPPORT` (converted into [`AxError`]) when the
    /// address is not the `Unix` variant.
    pub fn into_unix(self) -> AxResult<UnixSocketAddr> {
        if let Self::Unix(unix) = self {
            Ok(unix)
        } else {
            Err(AxError::from(LinuxError::EAFNOSUPPORT))
        }
    }

    /// Consumes the address and returns the wrapped vsock address.
    ///
    /// Only available when the `vsock` feature is enabled.
    ///
    /// # Errors
    ///
    /// Returns `LinuxError::EAFNOSUPPORT` (converted into [`AxError`]) when the
    /// address is not the `Vsock` variant.
    #[cfg(feature = "vsock")]
    pub fn into_vsock(self) -> AxResult<VsockAddr> {
        if let Self::Vsock(vsock) = self {
            Ok(vsock)
        } else {
            Err(AxError::from(LinuxError::EAFNOSUPPORT))
        }
    }
}
bitflags! {
/// Flags for a send operation, carried in the `flags` field of
/// `SendOptions`.
///
/// No individual flags are defined at the moment.
#[derive(Default, Debug, Clone, Copy)]
pub struct SendFlags: u32 {
}
}
bitflags! {
/// Flags for a receive operation, carried in the `flags` field of
/// `RecvOptions`.
// NOTE(review): these bit values are crate-internal; they do not match the
// Linux `MSG_PEEK` (0x02) / `MSG_TRUNC` (0x20) constants, so raw user flags
// presumably get translated before reaching this type — confirm at the caller.
#[derive(Default, Debug, Clone, Copy)]
pub struct RecvFlags: u32 {
/// Peek flag.
// NOTE(review): presumably "return data without consuming it" — confirm
// against the socket implementations.
const PEEK = 0x01;
/// Truncate flag.
// NOTE(review): presumably "report the real message length even if it was
// truncated to fit the buffer" — confirm against the socket implementations.
const TRUNCATE = 0x02;
}
}
/// A boxed, type-erased `Send + Sync` value, used as the element type of the
/// `cmsg` lists in `SendOptions` and `RecvOptions`.
// NOTE(review): the name suggests socket control (ancillary) messages whose
// concrete type is recovered by downcasting through `Any` — confirm.
pub type CMsgData = Box<dyn Any + Send + Sync>;
/// Extra parameters passed to `SocketOps::send`.
///
/// The `Default` value has no destination, empty flags and no control data.
#[derive(Default, Debug)]
pub struct SendOptions {
/// Optional explicit destination address.
pub to: Option<SocketAddrEx>,
/// Flags for this send.
pub flags: SendFlags,
/// Type-erased messages accompanying the payload.
// NOTE(review): presumably ancillary/control messages — confirm.
pub cmsg: Vec<CMsgData>,
}
/// Extra parameters passed to `SocketOps::recv`.
///
/// `from` and `cmsg` are optional mutable borrows supplied by the caller; the
/// `Default` value leaves both unset and uses empty flags.
#[derive(Default)]
pub struct RecvOptions<'a> {
/// Optional caller-owned slot for an address.
// NOTE(review): presumably filled with the sender's address by the
// receiving socket — confirm against the implementations.
pub from: Option<&'a mut SocketAddrEx>,
/// Flags for this receive.
pub flags: RecvFlags,
/// Optional caller-owned list of type-erased messages.
// NOTE(review): presumably received control messages are appended here —
// confirm against the implementations.
pub cmsg: Option<&'a mut Vec<CMsgData>>,
}
impl Debug for RecvOptions<'_> {
    /// Formats all three fields (`from`, `flags`, `cmsg`) as a debug struct.
    ///
    /// `cmsg` is included so the output is consistent with the derived `Debug`
    /// of `SendOptions`, which also prints its control-message list.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("RecvOptions")
            .field("from", &self.from)
            .field("flags", &self.flags)
            // `dyn Any + Send + Sync` implements `Debug` (printed opaquely), so
            // the list can be shown; this still reveals whether a buffer was
            // supplied and how many entries it holds.
            .field("cmsg", &self.cmsg)
            .finish()
    }
}
/// Selects which direction(s) of a socket a shutdown request applies to.
#[derive(Debug, Clone, Copy)]
pub enum Shutdown {
    /// The read direction only.
    Read,
    /// The write direction only.
    Write,
    /// Both the read and the write direction.
    Both,
}

impl Shutdown {
    /// Returns `true` when the read direction is included, i.e. for `Read`
    /// and `Both`.
    pub fn has_read(&self) -> bool {
        match self {
            Self::Read | Self::Both => true,
            Self::Write => false,
        }
    }

    /// Returns `true` when the write direction is included, i.e. for `Write`
    /// and `Both`.
    pub fn has_write(&self) -> bool {
        match self {
            Self::Write | Self::Both => true,
            Self::Read => false,
        }
    }
}
/// Operations shared by every socket kind wrapped by [`Socket`].
///
/// Every method takes `&self`. The trait is annotated with `enum_dispatch`, and
/// [`Socket`] lists it in its own `enum_dispatch` attribute.
#[enum_dispatch]
pub trait SocketOps: Configurable {
/// Binds the socket to `local_addr`.
fn bind(&self, local_addr: SocketAddrEx) -> AxResult;
/// Connects the socket to `remote_addr`.
fn connect(&self, remote_addr: SocketAddrEx) -> AxResult;
/// Puts the socket into listening mode.
///
/// The default implementation returns `AxError::OperationNotSupported`, so
/// only socket kinds that support listening need to override it.
fn listen(&self) -> AxResult {
Err(AxError::OperationNotSupported)
}
/// Accepts an incoming connection and returns it as a new [`Socket`].
///
/// The default implementation returns `AxError::OperationNotSupported`, so
/// only socket kinds that support accepting need to override it.
fn accept(&self) -> AxResult<Socket> {
Err(AxError::OperationNotSupported)
}
/// Sends data taken from `src`, using the destination, flags and messages
/// in `options`.
// NOTE(review): the returned `usize` is presumably the number of bytes
// sent — confirm against the implementations.
fn send(&self, src: impl Read + IoBuf, options: SendOptions) -> AxResult<usize>;
/// Receives data into `dst`, using the flags and optional out-slots in
/// `options`.
// NOTE(review): the returned `usize` is presumably the number of bytes
// received — confirm against the implementations.
fn recv(&self, dst: impl Write + IoBufMut, options: RecvOptions<'_>) -> AxResult<usize>;
/// Returns the local address of the socket.
fn local_addr(&self) -> AxResult<SocketAddrEx>;
/// Returns the address of the peer of the socket.
fn peer_addr(&self) -> AxResult<SocketAddrEx>;
/// Shuts down the direction(s) of the socket selected by `how`.
fn shutdown(&self, how: Shutdown) -> AxResult;
}
/// A socket of any supported kind.
///
/// The `enum_dispatch` attribute names the `Configurable` and `SocketOps`
/// traits, so calls to those traits on a `Socket` are dispatched to the wrapped
/// concrete socket.
#[enum_dispatch(Configurable, SocketOps)]
pub enum Socket {
// UDP socket.
Udp(UdpSocket),
// TCP socket.
Tcp(TcpSocket),
// Unix socket.
Unix(UnixSocket),
// Vsock socket; only present when the `vsock` feature is enabled.
#[cfg(feature = "vsock")]
Vsock(VsockSocket),
}
impl Pollable for Socket {
    /// Delegates to the wrapped socket's `poll` and returns its [`IoEvents`].
    fn poll(&self) -> IoEvents {
        match self {
            Self::Udp(sock) => sock.poll(),
            Self::Tcp(sock) => sock.poll(),
            Self::Unix(sock) => sock.poll(),
            #[cfg(feature = "vsock")]
            Self::Vsock(sock) => sock.poll(),
        }
    }

    /// Delegates to the wrapped socket's `register`, passing `context` and
    /// `events` through unchanged.
    fn register(&self, context: &mut Context<'_>, events: IoEvents) {
        match self {
            Self::Udp(sock) => sock.register(context, events),
            Self::Tcp(sock) => sock.register(context, events),
            Self::Unix(sock) => sock.register(context, events),
            #[cfg(feature = "vsock")]
            Self::Vsock(sock) => sock.register(context, events),
        }
    }
}